Lucidrains Series Source Code Analysis (38)
Data source
The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/
.\lucidrains\g-mlp-pytorch\g_mlp_pytorch\autoregressive_wrapper.py
import torch
from torch import nn
import torch.nn.functional as F
# decorator that switches the model to eval mode for the duration of the call, then restores its previous training state
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
# top-k filtering: keep the top (1 - thres) fraction of logits and set the rest to -inf
def top_k(logits, thres = 0.9):
k = int((1 - thres) * logits.shape[-1])
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(1, ind, val)
return probs
# wrapper that equips an autoregressive model with generation and loss computation
class AutoregressiveWrapper(nn.Module):
def __init__(self, net, ignore_index = -100, pad_value = 0):
super().__init__()
self.pad_value = pad_value
self.ignore_index = ignore_index
self.net = net
self.max_seq_len = net.seq_len
# sampling loop for autoregressive generation
@torch.no_grad()
@eval_decorator
def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
device = start_tokens.device
num_dims = len(start_tokens.shape)
if num_dims == 1:
start_tokens = start_tokens[None, :]
b, t = start_tokens.shape
out = start_tokens
for _ in range(seq_len):
x = out[:, -self.max_seq_len:]
logits = self.net(x, **kwargs)[:, -1, :]
filtered_logits = filter_logits_fn(logits, thres = filter_thres)
probs = F.softmax(filtered_logits / temperature, dim=-1)
sample = torch.multinomial(probs, 1)
out = torch.cat((out, sample), dim=-1)
if eos_token is not None and (sample == eos_token).all():
break
out = out[:, t:]
if num_dims == 1:
out = out.squeeze(0)
return out
# forward pass: next-token cross entropy over the shifted sequence
def forward(self, x, **kwargs):
xi, xo = x[:, :-1], x[:, 1:]
out = self.net(xi, **kwargs)
loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
return loss
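Before moving on, a minimal usage sketch for this wrapper (token counts and shapes here are illustrative, assuming the causal gMLP defined in g_mlp_pytorch.py below): wrap the network, compute the training loss on a sequence that is one token longer than seq_len, then sample a continuation from a prompt.
import torch
from g_mlp_pytorch import gMLP
from g_mlp_pytorch.autoregressive_wrapper import AutoregressiveWrapper
model = AutoregressiveWrapper(
gMLP(num_tokens = 256, dim = 512, depth = 4, seq_len = 1024, causal = True)
)
seq = torch.randint(0, 256, (1, 1025)) # training sequence, seq_len + 1 tokens
loss = model(seq) # next-token cross entropy
loss.backward()
prime = torch.randint(0, 256, (1, 8)) # prompt
sampled = model.generate(prime, 64) # (1, 64) newly generated tokens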
.\lucidrains\g-mlp-pytorch\g_mlp_pytorch\g_mlp_pytorch.py
# imports
from random import randrange
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
# functions
# returns whether a value is not None
def exists(val):
return val is not None
# casts a value to a pair (tuple) if it is not already one
def pair(val):
return (val, val) if not isinstance(val, tuple) else val
# stochastic depth: randomly drops whole layers during training with survival probability prob_survival
def dropout_layers(layers, prob_survival):
if prob_survival == 1:
return layers
num_layers = len(layers)
to_drop = torch.zeros(num_layers).uniform_(0., 1.) > prob_survival
# make sure at least one layer survives
if all(to_drop):
rand_index = randrange(num_layers)
to_drop[rand_index] = False
layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
return layers
# shifts a tensor along the sequence dimension by the given amount (token shifting)
def shift(t, amount, mask = None):
if amount == 0:
return t
return F.pad(t, (0, 0, amount, -amount), value = 0.)
# helper classes
# residual connection
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x) + x
# shifts feature chunks of the input by different amounts before calling the wrapped function
class PreShiftTokens(nn.Module):
def __init__(self, shifts, fn):
super().__init__()
self.fn = fn
self.shifts = tuple(shifts)
def forward(self, x, **kwargs):
if self.shifts == (0,):
return self.fn(x, **kwargs)
shifts = self.shifts
segments = len(shifts)
feats_per_shift = x.shape[-1] // segments
splitted = x.split(feats_per_shift, dim = -1)
segments_to_shift, rest = splitted[:segments], splitted[segments:]
segments_to_shift = list(map(lambda args: shift(*args), zip(segments_to_shift, shifts)))
x = torch.cat((*segments_to_shift, *rest), dim = -1)
return self.fn(x, **kwargs)
# applies layer norm before the wrapped function (pre-norm)
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
# tiny single-head attention, used as the optional aMLP branch
class Attention(nn.Module):
def __init__(self, dim_in, dim_out, dim_inner, causal = False):
super().__init__()
self.scale = dim_inner ** -0.5
self.causal = causal
self.to_qkv = nn.Linear(dim_in, dim_inner * 3, bias = False)
self.to_out = nn.Linear(dim_inner, dim_out)
def forward(self, x):
device = x.device
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
if self.causal:
mask = torch.ones(sim.shape[-2:], device = device).triu(1).bool()
sim.masked_fill_(mask[None, ...], -torch.finfo(q.dtype).max)
attn = sim.softmax(dim = -1)
out = einsum('b i j, b j d -> b i d', attn, v)
return self.to_out(out)
# spatial gating unit (SGU)
class SpatialGatingUnit(nn.Module):
def __init__(
self,
dim,
dim_seq,
causal = False,
act = nn.Identity(),
heads = 1,
init_eps = 1e-3,
circulant_matrix = False
):
super().__init__()
dim_out = dim // 2
self.heads = heads
self.causal = causal
self.norm = nn.LayerNorm(dim_out)
self.act = act
# parameters
if circulant_matrix:
self.circulant_pos_x = nn.Parameter(torch.ones(heads, dim_seq))
self.circulant_pos_y = nn.Parameter(torch.ones(heads, dim_seq))
self.circulant_matrix = circulant_matrix
shape = (heads, dim_seq,) if circulant_matrix else (heads, dim_seq, dim_seq)
weight = torch.zeros(shape)
self.weight = nn.Parameter(weight)
init_eps /= dim_seq
nn.init.uniform_(self.weight, -init_eps, init_eps)
self.bias = nn.Parameter(torch.ones(heads, dim_seq))
# forward pass; gate_res is the optional tiny-attention output (aMLP) added to the gate
def forward(self, x, gate_res = None):
# device, sequence length and number of heads
device, n, h = x.device, x.shape[1], self.heads
# split the input into a residual half and a gate half
res, gate = x.chunk(2, dim = -1)
# normalize the gate half
gate = self.norm(gate)
weight, bias = self.weight, self.bias
if self.circulant_matrix:
# build the full circulant matrix from the (heads, dim_seq) weight vector
dim_seq = weight.shape[-1]
weight = F.pad(weight, (0, dim_seq), value = 0)
weight = repeat(weight, '... n -> ... (r n)', r = dim_seq)
weight = weight[:, :-dim_seq].reshape(h, dim_seq, 2 * dim_seq - 1)
weight = weight[:, :, (dim_seq - 1):]
# give the circulant matrix absolute position awareness
pos_x, pos_y = self.circulant_pos_x, self.circulant_pos_y
weight = weight * rearrange(pos_x, 'h i -> h i ()') * rearrange(pos_y, 'h j -> h () j')
if self.causal:
# crop the weight and bias to the current sequence length
weight, bias = weight[:, :n, :n], bias[:, :n]
# causal mask so each position only mixes with itself and earlier positions
mask = torch.ones(weight.shape[-2:], device = device).triu_(1).bool()
mask = rearrange(mask, 'i j -> () i j')
weight = weight.masked_fill(mask, 0.)
# split the gate into heads
gate = rearrange(gate, 'b n (h d) -> b h n d', h = h)
# mix along the sequence dimension with the learned token-mixing weights
gate = einsum('b h n d, h m n -> b h m d', gate, weight)
# add the bias
gate = gate + rearrange(bias, 'h n -> () h n ()')
# merge heads back
gate = rearrange(gate, 'b h n d -> b n (h d)')
# add the attention branch output, if present
if exists(gate_res):
gate = gate + gate_res
# apply the activation to the gate and multiply with the residual half
return self.act(gate) * res
# gMLP block
class gMLPBlock(nn.Module):
def __init__(
self,
*,
dim, # input dimension
dim_ff, # feedforward (hidden) dimension
seq_len, # sequence length
heads = 1, # number of heads for the spatial gating unit
attn_dim = None, # dimension of the optional tiny attention (aMLP)
causal = False, # whether to use a causal mask
act = nn.Identity(), # activation for the spatial gate, identity by default
circulant_matrix = False # whether to use a circulant weight matrix
):
super().__init__()
# input projection with GELU
self.proj_in = nn.Sequential(
nn.Linear(dim, dim_ff),
nn.GELU()
)
# optional tiny attention branch (aMLP)
self.attn = Attention(dim, dim_ff // 2, attn_dim, causal) if exists(attn_dim) else None
# spatial gating unit
self.sgu = SpatialGatingUnit(dim_ff, seq_len, causal, act, heads, circulant_matrix = circulant_matrix)
# output projection
self.proj_out = nn.Linear(dim_ff // 2, dim)
def forward(self, x):
gate_res = self.attn(x) if exists(self.attn) else None
x = self.proj_in(x)
x = self.sgu(x, gate_res = gate_res)
x = self.proj_out(x)
return x
# main classes
class gMLP(nn.Module):
def __init__(
self,
*,
num_tokens = None, # vocabulary size (None for feature inputs)
dim, # model dimension
depth, # number of gMLP blocks
seq_len, # sequence length
heads = 1, # number of heads for the spatial gating unit
ff_mult = 4, # feedforward expansion factor
attn_dim = None, # dimension of the optional tiny attention (aMLP)
prob_survival = 1., # survival probability for stochastic layer dropout
causal = False, # whether to use a causal mask
circulant_matrix = False, # whether to use a circulant weight matrix
shift_tokens = 0, # amount of token shifting
act = nn.Identity() # activation for the spatial gate
):
super().__init__()
assert (dim % heads) == 0, 'dimension must be divisible by number of heads'
dim_ff = dim * ff_mult
self.seq_len = seq_len
self.prob_survival = prob_survival
# token embedding (identity if num_tokens is not given)
self.to_embed = nn.Embedding(num_tokens, dim) if exists(num_tokens) else nn.Identity()
token_shifts = tuple(range(0 if causal else -shift_tokens, shift_tokens + 1))
# stack of residual, pre-normed, token-shifted gMLP blocks
self.layers = nn.ModuleList([Residual(PreNorm(dim, PreShiftTokens(token_shifts, gMLPBlock(dim = dim, heads = heads, dim_ff = dim_ff, seq_len = seq_len, attn_dim = attn_dim, causal = causal, act = act, circulant_matrix = circulant_matrix)))) for i in range(depth)])
# projection to logits
self.to_logits = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_tokens)
) if exists(num_tokens) else nn.Identity()
def forward(self, x):
x = self.to_embed(x)
# stochastic layer dropout during training
layers = self.layers if not self.training else dropout_layers(self.layers, self.prob_survival)
out = nn.Sequential(*layers)(x)
return self.to_logits(out)
# gMLP for image classification
class gMLPVision(nn.Module):
def __init__(
self,
*,
image_size, # image size (int or (height, width))
patch_size, # patch size (int or (height, width))
num_classes, # number of classes
dim, # model dimension
depth, # number of gMLP blocks
heads = 1, # number of heads for the spatial gating unit
ff_mult = 4, # feedforward expansion factor
channels = 3, # number of image channels
attn_dim = None, # dimension of the optional tiny attention (aMLP)
prob_survival = 1. # survival probability for stochastic layer dropout
):
super().__init__()
assert (dim % heads) == 0, 'dimension must be divisible by number of heads'
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert (image_height % patch_height) == 0 and (image_width % patch_width) == 0, 'image height and width must be divisible by patch size'
num_patches = (image_height // patch_height) * (image_width // patch_width)
dim_ff = dim * ff_mult
# patch embedding
self.to_patch_embed = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_height, p2 = patch_width),
nn.Linear(channels * patch_height * patch_width, dim)
)
self.prob_survival = prob_survival
# stack of residual, pre-normed gMLP blocks over the patches
self.layers = nn.ModuleList([Residual(PreNorm(dim, gMLPBlock(dim = dim, heads = heads, dim_ff = dim_ff, seq_len = num_patches, attn_dim = attn_dim))) for i in range(depth)])
# mean-pool the patches and project to class logits
self.to_logits = nn.Sequential(
nn.LayerNorm(dim),
Reduce('b n d -> b d', 'mean'),
nn.Linear(dim, num_classes)
)
def forward(self, x):
x = self.to_patch_embed(x)
# stochastic layer dropout during training
layers = self.layers if not self.training else dropout_layers(self.layers, self.prob_survival)
x = nn.Sequential(*layers)(x)
return self.to_logits(x)
.\lucidrains\g-mlp-pytorch\g_mlp_pytorch\__init__.py
# re-export the public classes
from g_mlp_pytorch.g_mlp_pytorch import gMLP, gMLPVision, gMLPBlock, SpatialGatingUnit

gMLP - Pytorch
Implementation of gMLP, an all-MLP replacement for Transformers, in Pytorch
Install
$ pip install g-mlp-pytorch
Usage
For masked language modelling
import torch
from torch import nn
from g_mlp_pytorch import gMLP
model = gMLP(
num_tokens = 20000,
dim = 512,
depth = 6,
seq_len = 256,
circulant_matrix = True, # use circulant weight matrix for linear increase in parameters in respect to sequence length
act = nn.Tanh() # activation for spatial gate (defaults to identity)
)
x = torch.randint(0, 20000, (1, 256))
logits = model(x) # (1, 256, 20000)
For image classification
import torch
from g_mlp_pytorch import gMLPVision
model = gMLPVision(
image_size = 256,
patch_size = 16,
num_classes = 1000,
dim = 512,
depth = 6
)
img = torch.randn(1, 3, 256, 256)
logits = model(img) # (1, 1000)
You can also add a tiny amount of attention (one-headed) to boost performance, as mentioned in the paper as aMLP, with the addition of one extra keyword attn_dim. This applies to both gMLPVision and gMLP
import torch
from g_mlp_pytorch import gMLPVision
model = gMLPVision(
image_size = 256,
patch_size = 16,
num_classes = 1000,
dim = 512,
depth = 6,
attn_dim = 64
)
img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)
Non-square images and patch sizes
import torch
from g_mlp_pytorch import gMLPVision
model = gMLPVision(
image_size = (256, 128),
patch_size = (16, 8),
num_classes = 1000,
dim = 512,
depth = 6,
attn_dim = 64
)
img = torch.randn(1, 3, 256, 128)
pred = model(img) # (1, 1000)
Experimental
An independent researcher proposes using a multi-headed approach for gMLPs in a blog post on Zhihu. To do so, just set heads to be greater than 1
import torch
from torch import nn
from g_mlp_pytorch import gMLP
model = gMLP(
num_tokens = 20000,
dim = 512,
depth = 6,
seq_len = 256,
causal = True,
circulant_matrix = True,
heads = 4 # 4 heads
)
x = torch.randint(0, 20000, (1, 256))
logits = model(x) # (1, 256, 20000)
Citations
@misc{liu2021pay,
title = {Pay Attention to MLPs},
author = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
year = {2021},
eprint = {2105.08050},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@software{peng_bo_2021_5196578,
author = {PENG Bo},
title = {BlinkDL/RWKV-LM: 0.01},
month = aug,
year = 2021,
publisher = {Zenodo},
version = {0.01},
doi = {10.5281/zenodo.5196578},
url = {https://doi.org/10.5281/zenodo.5196578}
}
.\lucidrains\g-mlp-pytorch\setup.py
# package metadata for the g-mlp-pytorch distribution
from setuptools import setup, find_packages
setup(
name = 'g-mlp-pytorch',
packages = find_packages(),
version = '0.1.5',
license='MIT',
description = 'gMLP - Pytorch',
author = 'Phil Wang',
author_email = 'lucidrains@gmail.com',
url = 'https://github.com/lucidrains/g-mlp-pytorch',
keywords = [
'artificial intelligence',
'deep learning',
'multi-layered-preceptrons'
],
install_requires=[
'einops>=0.3',
'torch>=1.6'
],
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\g-mlp-pytorch\train.py
# imports
from g_mlp_pytorch import gMLP
from g_mlp_pytorch.autoregressive_wrapper import AutoregressiveWrapper
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 768
SEQ_LEN = 768
# helpers
# cycle a dataloader endlessly
def cycle(loader):
while True:
for data in loader:
yield data
# decode a single token to a printable character
def decode_token(token):
return str(chr(max(32, token)))
# decode a sequence of tokens to a string
def decode_tokens(tokens):
return ''.join(list(map(decode_token, tokens)))
# instantiate a GPT-like causal decoder
model = gMLP(
num_tokens = 256,
dim = 512,
seq_len = SEQ_LEN,
depth = 8,
causal = True
)
model = AutoregressiveWrapper(model)
model.cuda()
# prepare enwik8 data
with gzip.open('./data/enwik8.gz') as file:
X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
trX, vaX = np.split(X, [int(90e6)])
data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
# dataset that samples random crops from the token sequence
class TextSamplerDataset(Dataset):
def __init__(self, data, seq_len):
super().__init__()
self.data = data
self.seq_len = seq_len
def __getitem__(self, index):
rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
return full_seq.cuda()
def __len__(self):
return self.data.size(0) // self.seq_len
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
# optimizer
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# training
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
model.train()
for __ in range(GRADIENT_ACCUMULATE_EVERY):
loss = model(next(train_loader))
loss.backward()
print(f'training loss: {loss.item()}')
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optim.step()
optim.zero_grad()
if i % VALIDATE_EVERY == 0:
model.eval()
with torch.no_grad():
loss = model(next(val_loader))
print(f'validation loss: {loss.item()}')
if i % GENERATE_EVERY == 0:
model.eval()
inp = random.choice(val_dataset)[:-1]
prime = decode_tokens(inp)
print(f'{prime} \n\n {"*" * 100}')
sample = model.generate(inp, GENERATE_LENGTH)
output_str = decode_tokens(sample)
print(output_str)
Data source
The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/
.\lucidrains\gated-state-spaces-pytorch\gated_state_spaces_pytorch\autoregressive_wrapper.py
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
# helper to check whether a value exists
def exists(val):
return val is not None
# decorator that runs the wrapped call in eval mode and restores the previous training state
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
# top-k filtering of logits
def top_k(logits, thres=0.9):
k = int((1 - thres) * logits.shape[-1])
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(1, ind, val)
return probs
# autoregressive wrapper for generation and training
class AutoregressiveWrapper(nn.Module):
def __init__(self, net, pad_value=0, max_seq_len=4096):
super().__init__()
self.max_seq_len = max_seq_len
self.pad_value = pad_value
self.net = net
# sampling loop for autoregressive generation
@torch.no_grad()
@eval_decorator
def generate(
self,
start_tokens,
seq_len,
eos_token=None,
temperature=1.0,
filter_thres=0.9,
**kwargs
):
b, n, device = *start_tokens.shape, start_tokens.device
out = start_tokens
for _ in range(seq_len):
logits = self.net(
out[:, -self.max_seq_len:],
**kwargs
)[:, -1]
filtered_logits = top_k(logits, thres=filter_thres)
probs = F.softmax(filtered_logits / temperature, dim=-1)
sample = torch.multinomial(probs, 1)
out = torch.cat((out, sample), dim=-1)
if exists(eos_token):
is_eos_token = out == eos_token
if is_eos_token.any(dim=-1).all():
# mask out everything after the eos tokens
shifted_is_eos_tokens = F.pad(is_eos_token, (1, -1))
mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
out = out.masked_fill(mask, self.pad_value)
break
return out[:, n:]
# forward pass for training: feed the inputs and shifted labels to the network
def forward(self, x, **kwargs):
inp, labels = x[:, :-1], x[:, 1:]
return self.net(inp, labels=labels, **kwargs)
.\lucidrains\gated-state-spaces-pytorch\gated_state_spaces_pytorch\dsconv.py
import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.fft import rfft, irfft
from einops import rearrange
from scipy.fftpack import next_fast_len
# functions
# returns whether a value is not None
def exists(val):
return val is not None
# appends num_dims trailing singleton dimensions to a tensor
def append_dims(x, num_dims):
if num_dims <= 0:
return x
return x.view(*x.shape, *((1,) * num_dims))
# O(N log N) 1d convolution using the FFT trick
def conv1d_fft(x, weights, dim = -2, weight_dim = -1):
assert weight_dim >= dim
N = x.shape[dim]
M = weights.shape[weight_dim]
fast_len = next_fast_len(N + M - 1)
f_x = torch.fft.rfft(x, n = fast_len, dim = dim)
f_weight = torch.fft.rfft(weights, n = fast_len, dim = weight_dim)
f_v_weight = f_x * append_dims(f_weight.conj(), weight_dim - dim)
out = torch.fft.irfft(f_v_weight, fast_len, dim = dim)
out = out.roll(-1, dims = (dim,))
indices = torch.arange(start = fast_len - N, end = fast_len, dtype = torch.long, device = x.device)
out = out.index_select(dim, indices)
return out
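# Sanity-check sketch (standalone, not part of the library): with a kernel as long as the
# sequence and supplied time-reversed, conv1d_fft above matches a naive causal convolution
# out[t] = sum_{p<=t} x[t - p] * w[n - 1 - p], which is how the modules below use it.
def _conv1d_fft_check(n = 16):
    x = torch.randn(1, n, 1)   # (batch, seq, dim)
    w = torch.randn(1, n)      # (batch, seq) time-reversed kernel
    fast = conv1d_fft(x, w, dim = -2, weight_dim = -1)
    naive = torch.zeros_like(x)
    for t in range(n):
        for p in range(t + 1):
            naive[:, t] += x[:, t - p] * w[:, n - 1 - p, None]
    assert torch.allclose(fast, naive, atol = 1e-5)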
# classes
# efficient depthwise-separable convolution module
class EfficientDsConv(nn.Module):
def __init__(
self,
*,
dim,
heads
):
super().__init__()
assert (dim % heads) == 0
self.heads = heads
self.norm = nn.LayerNorm(dim)
self.to_weight = nn.Linear(dim, heads, bias = False)
# parameter D (learned weighted residual)
self.param_D = nn.Parameter(torch.randn(dim))
def forward(self, x):
device, seq_len = x.device, x.shape[1]
u = self.norm(x)
# learned weighted residual
residual = u * self.param_D
# the dsconv kernel depends on the sequence length (one weight per position, derived from the input)
K = self.to_weight(x)
K = torch.flip(K, dims = (1,))
# 1d convolution via FFT, O(n log n)
u = rearrange(u, '... (h d) -> ... h d', h = self.heads)
out = conv1d_fft(u, K, dim = -3, weight_dim = -2)
out = rearrange(out, '... h d -> ... (h d)')
return out + residual
# gated depthwise-separable convolution block
class GatedDsConv(nn.Module):
""" Pseudocode 3.2 """
""" except state spaces replaced with regular learned convolution kernel """
def __init__(
self,
*,
dim,
heads = 8,
dim_dsconv = 512,
dim_expansion_factor = 4,
reverse_seq = False
):
super().__init__()
assert (dim_dsconv % heads) == 0
self.reverse_seq = reverse_seq
self.norm = nn.LayerNorm(dim)
dim_hidden = int(dim_expansion_factor * dim)
self.to_u = nn.Sequential(nn.Linear(dim, dim_hidden, bias = False), nn.GELU())
self.to_v = nn.Sequential(nn.Linear(dim, dim_dsconv, bias = False), nn.GELU())
self.dsconv = EfficientDsConv(dim = dim_dsconv, heads = heads)
self.to_gate = nn.Linear(dim_dsconv, dim_hidden, bias = False)
self.to_out = nn.Linear(dim_hidden, dim)
def forward(self, x):
if self.reverse_seq:
x = torch.flip(x, dims = (1,))
residual, x = x.clone(), self.norm(x)
u = self.to_u(x)
v = self.to_v(x)
v = self.dsconv(v)
uc = self.to_gate(v)
out = self.to_out(uc * u)
out = out + residual
if self.reverse_seq:
out = torch.flip(out, dims = (1,))
return out
# gated depthwise-separable convolution language model
class GatedDsConvLM(nn.Module):
def __init__(
self,
*,
num_tokens,
dim,
depth,
heads = 8,
dim_dsconv = 512,
max_seq_len = 2048,
dim_expansion_factor = 4,
):
# embed tokens, stack GatedDsConv layers, project to logits
super().__init__()
self.token_emb = nn.Embedding(num_tokens, dim)
self.max_seq_len = max_seq_len
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
GatedDsConv(
dim = dim,
heads = heads,
dim_dsconv = dim_dsconv,
dim_expansion_factor = dim_expansion_factor
)
)
self.to_logits = nn.Linear(dim, num_tokens, bias = False)
def forward(self, x, labels = None):
assert x.shape[1] <= self.max_seq_len
x = self.token_emb(x)
for dsconv in self.layers:
x = dsconv(x)
logits = self.to_logits(x)
# if no labels are given, return the logits
if not exists(labels):
return logits
# otherwise rearrange for cross entropy and return the loss
logits = rearrange(logits, 'b n c -> b c n')
return F.cross_entropy(logits, labels)
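For completeness, a usage sketch for this depthwise-conv variant (shapes and hyperparameters are illustrative, mirroring the GSS usage shown in the readme further below):
import torch
from gated_state_spaces_pytorch import GatedDsConvLM
lm = GatedDsConvLM(
num_tokens = 256,
dim = 512,
depth = 8,
max_seq_len = 2048
)
ids = torch.randint(0, 256, (1, 2048))
loss = lm(ids[:, :-1], labels = ids[:, 1:]) # scalar cross entropy loss
loss.backward()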
.\lucidrains\gated-state-spaces-pytorch\gated_state_spaces_pytorch\gss.py
import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.fft import rfft, irfft
from einops import rearrange
# functions
# returns whether a value is not None
def exists(val):
return val is not None
# classes
# diagonal state space (DSS) layer
class DSS(nn.Module):
def __init__(
self,
*,
dim,
kernel_N = 512,
dss_kernel_lambda_imag_exp = True
):
super().__init__()
self.norm = nn.LayerNorm(dim)
# Lambda (diagonal state matrix), stored as separate real and imaginary parts
self.Lambda_real = nn.Parameter(torch.randn(kernel_N))
self.Lambda_imag = nn.Parameter(torch.randn(kernel_N))
# C (output projection), real and imaginary parts
self.C_real = nn.Parameter(torch.randn(dim, kernel_N))
self.C_imag = nn.Parameter(torch.randn(dim, kernel_N))
# parameter D (learned weighted residual)
self.param_D = nn.Parameter(torch.randn(dim))
# whether to exponentiate the imaginary part of Lambda when building the kernel
self.dss_kernel_lambda_imag_exp = dss_kernel_lambda_imag_exp
def forward(self, x):
"""
einstein notation:
b - batch
l - sequence length
d - dimension
"""
device, seq_len = x.device, x.shape[1]
u = self.norm(x)
# learned weighted residual
residual = u * self.param_D
# derive the simple DSS kernel
Lambda_imag = self.Lambda_imag.exp() if self.dss_kernel_lambda_imag_exp else self.Lambda_imag
Lambda = -self.Lambda_real.exp() + 1j * Lambda_imag
C = self.C_real + 1j * self.C_imag
arange = torch.arange(seq_len, device = device)
S = (rearrange(Lambda, 'n -> n 1') * rearrange(arange, 'l -> 1 l')).exp()
C = C * (Lambda.exp() - 1) / Lambda
K = einsum('h n, n l -> l h', C, S).real
# conv1d fft O(nlog(n))
u_f = rfft(u, n = seq_len * 2, dim = -2)
K_f = rfft(K, n = seq_len * 2, dim = -2)
y = irfft(u_f * K_f, seq_len * 2, dim = -2)[..., :seq_len, :]
return y + residual
# gated state space (GSS) block
class GSS(nn.Module):
""" Pseudocode 3.2 """
def __init__(
self,
*,
dim,
dim_expansion_factor = 4,
dss_kernel_N = 512,
dss_kernel_H = 256,
reverse_seq = False,
dss_kernel_lambda_imag_exp = True
):
super().__init__()
self.reverse_seq = reverse_seq
self.norm = nn.LayerNorm(dim)
dim_hidden = int(dim_expansion_factor * dim)
self.to_u = nn.Sequential(nn.Linear(dim, dim_hidden, bias = False), nn.GELU())
self.to_v = nn.Sequential(nn.Linear(dim, dss_kernel_H, bias = False), nn.GELU())
self.dss = DSS(dim = dss_kernel_H, kernel_N = dss_kernel_N, dss_kernel_lambda_imag_exp = dss_kernel_lambda_imag_exp)
self.to_gate = nn.Linear(dss_kernel_H, dim_hidden, bias = False)
self.to_out = nn.Linear(dim_hidden, dim)
def forward(self, x):
if self.reverse_seq:
x = torch.flip(x, dims = (1,))
residual, x = x.clone(), self.norm(x)
u = self.to_u(x)
v = self.to_v(x)
v = self.dss(v)
uc = self.to_gate(v)
out = self.to_out(uc * u)
out = out + residual
if self.reverse_seq:
out = torch.flip(out, dims = (1,))
return out
# Gated State Spaces language model
class GatedStateSpacesLM(nn.Module):
def __init__(
self,
*,
num_tokens,
dim,
depth,
dim_expansion_factor = 4,
dss_kernel_N = 512,
dss_kernel_H = 256,
dss_kernel_lambda_imag_exp = True
):
super().__init__()
# token embedding
self.token_emb = nn.Embedding(num_tokens, dim)
# stack of GSS blocks
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
GSS(
dim = dim,
dss_kernel_H = dss_kernel_H,
dss_kernel_N = dss_kernel_N,
dim_expansion_factor = dim_expansion_factor,
dss_kernel_lambda_imag_exp = dss_kernel_lambda_imag_exp
)
)
# project to logits
self.to_logits = nn.Linear(dim, num_tokens, bias = False)
def forward(self, x, labels = None):
x = self.token_emb(x)
for gss in self.layers:
x = gss(x)
logits = self.to_logits(x)
# if no labels are given, return the logits
if not exists(labels):
return logits
# otherwise compute the cross entropy loss
logits = rearrange(logits, 'b n c -> b c n')
return F.cross_entropy(logits, labels)
.\lucidrains\gated-state-spaces-pytorch\gated_state_spaces_pytorch\mhesa.py
import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.fft import rfft, irfft
from einops import rearrange
from scipy.fftpack import next_fast_len
# functions
# returns whether a value is not None
def exists(val):
return val is not None
# appends num_dims trailing singleton dimensions to a tensor
def append_dims(x, num_dims):
if num_dims <= 0:
return x
return x.view(*x.shape, *((1,) * num_dims))
# O(N log N) 1d convolution using the FFT trick
def conv1d_fft(x, weights, dim = -2, weight_dim = -1):
assert weight_dim >= dim
N = x.shape[dim]
M = weights.shape[weight_dim]
fast_len = next_fast_len(N + M - 1)
f_x = torch.fft.rfft(x, n = fast_len, dim = dim)
f_weight = torch.fft.rfft(weights, n = fast_len, dim = weight_dim)
f_v_weight = f_x * append_dims(f_weight.conj(), weight_dim - dim)
out = torch.fft.irfft(f_v_weight, fast_len, dim = dim)
out = out.roll(-1, dims = (dim,))
indices = torch.arange(start = fast_len - N, end = fast_len, dtype = torch.long, device = x.device)
out = out.index_select(dim, indices)
return out
# classes
# multi-head exponential smoothing attention (MHESA)
class MHESA(nn.Module):
""" used for time-series in ETSFormer https://arxiv.org/abs/2202.01381 """
def __init__(
self,
*,
dim,
heads,
reverse_seq = False
):
super().__init__()
assert (dim % heads) == 0
self.reverse_seq = reverse_seq
self.heads = heads
self.norm = nn.LayerNorm(dim)
self.alphas = nn.Parameter(torch.randn(heads))
self.dampen_factors = nn.Parameter(torch.randn(heads))
# params D
self.param_D = nn.Parameter(torch.randn(dim))
def forward(self, x):
"""
einstein notation:
b - batch
h - heads
l - sequence length
d - dimension
"""
if self.reverse_seq:
x = torch.flip(x, dims = (1,))
device, seq_len = x.device, x.shape[1]
u = self.norm(x)
# learned weighted residual
residual = u * self.param_D
# weights derived from alphas (learned exponential smoothing decay rate)
alphas = self.alphas.sigmoid()
dampen_factors = self.dampen_factors.sigmoid()
reversed_powers = torch.arange(seq_len - 1, -1, -1, device = device)
K = alphas * (((1 - alphas) * dampen_factors) ** rearrange(reversed_powers, '... l -> ... l 1'))
# conv1d fft O(nlog(n))
u = rearrange(u, '... (h d) -> ... h d', h = self.heads)
out = conv1d_fft(u, K, dim = -3, weight_dim = -2)
out = rearrange(out, '... h d -> ... (h d)')
out = out + residual
if self.reverse_seq:
out = torch.flip(out, dims = (1,))
return out
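# Intuition check (a standalone sketch with toy values, not part of the module): the kernel
# built above, K[l] = alpha * ((1 - alpha) * dampen) ** (seq_len - 1 - l), is the time-reversed
# impulse response of the exponential smoothing recurrence y[t] = beta * y[t-1] + alpha * x[t]
# with beta = (1 - alpha) * dampen, so the FFT convolution reproduces that recurrence.
def _ema_kernel_check(seq_len = 16, alpha = 0.3, dampen = 0.9):
    beta = (1 - alpha) * dampen
    x = torch.randn(seq_len)
    # recurrence form
    y_rec, prev = [], torch.tensor(0.)
    for t in range(seq_len):
        prev = beta * prev + alpha * x[t]
        y_rec.append(prev)
    y_rec = torch.stack(y_rec)
    # convolution with the time-reversed kernel
    K = alpha * (beta ** torch.arange(seq_len - 1, -1, -1, dtype = torch.float))
    y_conv = torch.stack([(x[:t + 1] * K[seq_len - 1 - t:]).sum() for t in range(seq_len)])
    assert torch.allclose(y_rec, y_conv, atol = 1e-5)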
# gated MHESA block
class GatedMHESA(nn.Module):
""" Pseudocode 3.2 """
""" except state spaces replaced with multi-head exponential smoothing with learned alpha """
""" used for time-series in ETSFormer https://arxiv.org/abs/2202.01381 """
def __init__(
self,
*,
dim,
heads = 8,
dim_mhesa = 512,
dim_expansion_factor = 4,
):
super().__init__()
assert (dim_mhesa % heads) == 0
self.norm = nn.LayerNorm(dim)
dim_hidden = int(dim_expansion_factor * dim)
self.to_u = nn.Sequential(nn.Linear(dim, dim_hidden, bias = False), nn.GELU())
self.to_v = nn.Sequential(nn.Linear(dim, dim_mhesa, bias = False), nn.GELU())
self.mhesa = MHESA(dim = dim_mhesa, heads = heads)
self.to_gate = nn.Linear(dim_mhesa, dim_hidden, bias = False)
self.to_out = nn.Linear(dim_hidden, dim)
def forward(self, x):
residual, x = x.clone(), self.norm(x)
u = self.to_u(x)
v = self.to_v(x)
v = self.mhesa(v)
uc = self.to_gate(v)
out = self.to_out(uc * u)
return out + residual
# Gated Exponential Smoothing language model
class GatedExponentialSmoothingLM(nn.Module):
def __init__(
self,
*,
num_tokens, # vocabulary size
dim, # model dimension
depth, # number of layers
heads = 8, # number of exponential smoothing heads
dim_mhesa = 512, # dimension of the MHESA module
dim_expansion_factor = 4, # hidden expansion factor
):
super().__init__()
# token embedding
self.token_emb = nn.Embedding(num_tokens, dim)
# stack of GatedMHESA blocks
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
GatedMHESA(
dim = dim,
heads = heads,
dim_mhesa = dim_mhesa,
dim_expansion_factor = dim_expansion_factor
)
)
# project to logits
self.to_logits = nn.Linear(dim, num_tokens, bias = False)
def forward(self, x, labels = None):
x = self.token_emb(x)
for mhesa in self.layers:
x = mhesa(x)
logits = self.to_logits(x)
# if no labels are given, return the logits
if not exists(labels):
return logits
# otherwise compute the cross entropy loss
logits = rearrange(logits, 'b n c -> b c n')
return F.cross_entropy(logits, labels)
.\lucidrains\gated-state-spaces-pytorch\gated_state_spaces_pytorch\__init__.py
# re-export the public classes
from gated_state_spaces_pytorch.gss import GSS, GatedStateSpacesLM
from gated_state_spaces_pytorch.dsconv import GatedDsConv, GatedDsConvLM
from gated_state_spaces_pytorch.mhesa import GatedExponentialSmoothingLM, GatedMHESA

Gated State Spaces - Pytorch
Implementation of Gated State Spaces, from the paper Long Range Language Modeling via Gated State Spaces, in Pytorch. In particular, it will contain the hybrid version that combines local self attention with the long-range GSS.
It will also contain a few more settings to compare state spaces to a sequence-wise GLU depthwise conv, and even simpler, a parameterized exponential moving average along the sequence dimension. So we get to the bottom of whether state spaces are worth it, or whether it is really all about the O(L log(L)) FFT convolution trick. Results will be shared in the readme.
I will also pit the GSS module against the Path-X challenge and see how well it does.
Update: This paper has beaten S4 on LRA using multi-headed EMA + single head attention.
Install
$ pip install gated-state-spaces-pytorch
Usage
import torch
from gated_state_spaces_pytorch import GSS
gss = GSS(
dim = 512, # dimension
dim_expansion_factor = 4, # hidden dimension (expansion factor x dim) = 2048
dss_kernel_N = 512,
dss_kernel_H = 256
)
x = torch.randn(1, 65536, 512)
out = gss(x) # (1, 65536, 512)
Gated state spaces language model
import torch
from gated_state_spaces_pytorch import GatedStateSpacesLM
gss_lm = GatedStateSpacesLM(
num_tokens = 20000,
depth = 12,
dim = 512,
dim_expansion_factor = 4,
dss_kernel_N = 512,
dss_kernel_H = 256
)
ids = torch.randint(0, 20000, (1, 1024))
logits = gss_lm(ids) # (1, 1024, 20000)
Todo
- enwik8
- gss lm class
- add dsconv + learned ema
- add attention.
Citations
@inproceedings{Mehta2022LongRL,
title = {Long Range Language Modeling via Gated State Spaces},
author = {Harsh Mehta and Ankit Gupta and Ashok Cutkosky and Behnam Neyshabur},
year = {2022}
}
@misc{woo2022etsformer,
title = {ETSformer: Exponential Smoothing Transformers for Time-series Forecasting},
author = {Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi},
year = {2022},
eprint = {2202.01381},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
.\lucidrains\gated-state-spaces-pytorch\setup.py
# package metadata for the gated-state-spaces-pytorch distribution
from setuptools import setup, find_packages
setup(
name = 'gated-state-spaces-pytorch',
packages = find_packages(exclude=[]),
version = '0.1.0',
license='MIT',
description = 'Gated State Spaces - GSS - Pytorch',
author = 'Phil Wang',
author_email = 'lucidrains@gmail.com',
long_description_content_type = 'text/markdown',
url = 'https://github.com/lucidrains/gated-state-spaces-pytorch',
keywords = [
'artificial intelligence',
'deep learning',
'state spaces',
'long context'
],
install_requires=[
'einops>=0.4',
'scipy',
'torch>=1.6',
],
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\gated-state-spaces-pytorch\train.py
# imports
import gzip
import random
import numpy as np
import torch
import torch.optim as optim
import tqdm
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# project modules
from gated_state_spaces_pytorch import GatedStateSpacesLM
from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper
# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 1024
SEQ_LEN = 4096
# helpers
# cycle a dataloader endlessly
def cycle(loader):
while True:
for data in loader:
yield data
# decode a single token to a printable character
def decode_token(token):
return str(chr(max(32, token)))
# decode a sequence of tokens to a string
def decode_tokens(tokens):
return "".join(list(map(decode_token, tokens)))
# instantiate a GPT-like causal decoder
model = GatedStateSpacesLM(
num_tokens = 256,
dim = 512,
depth = 8
)
model = AutoregressiveWrapper(model)
model.cuda()
# prepare enwik8 data
with gzip.open("./data/enwik8.gz") as file:
X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
trX, vaX = np.split(X, [int(90e6)])
data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
# dataset that samples random crops from the token sequence
class TextSamplerDataset(Dataset):
def __init__(self, data, seq_len):
super().__init__()
self.data = data
self.seq_len = seq_len
def __getitem__(self, index):
rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
return full_seq.cuda()
def __len__(self):
return self.data.size(0) // self.seq_len
# create training and validation dataloaders
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))
# optimizer
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# training
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
model.train()
for __ in range(GRADIENT_ACCUMULATE_EVERY):
loss = model(next(train_loader))
loss.backward()
print(f"training loss: {loss.item()}")
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optim.step()
optim.zero_grad()
if i % VALIDATE_EVERY == 0:
model.eval()
with torch.no_grad():
loss = model(next(val_loader))
print(f"validation loss: {loss.item()}")
if i % GENERATE_EVERY == 0:
model.eval()
inp = random.choice(val_dataset)[:-1]
prime = decode_tokens(inp)
print(f"%s \n\n %s", (prime, "*" * 100))
sample = model.generate(inp[None, ...], GENERATE_LENGTH)
output_str = decode_tokens(sample[0])
print(output_str)
Data source
The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/
.\lucidrains\gateloop-transformer\gateloop_transformer\associative_scan.py
# code adapted from the S5-pytorch repository
# https://github.com/i404788/s5-pytorch/blob/74e2fdae00b915a62c914bf3615c0b8a4279eb84/s5/jax_compat.py#L51-L134
# will be adapted to test GateLoop at small scale https://arxiv.org/abs/2311.01927
import torch
from torch import Tensor
import torch.nn.functional as F
from typing import Tuple, Callable
# helper functions
def pad_at_dim(t, pad, dim = -1, value = 0.):
# pads a tensor at the given dimension
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = ((0, 0) * dims_from_right)
return F.pad(t, (*zeros, *pad), value = value)
# PyTorch implementation of jax.lax.associative_scan
# specialized to scan over axis 1 (the token sequence axis for autoregressive modeling)
def associative_scan(
operator: Callable,
elems: Tuple[Tensor, Tensor]
):
num_elems = int(elems[0].shape[1])
if not all(int(elem.shape[1]) == num_elems for elem in elems[1:]):
raise ValueError('Array inputs to associative_scan must have the same '
'first dimension. (saw: {})'
.format([elem.shape for elem in elems]))
def _scan(elems):
"""对 `elems` 执行扫描操作."""
num_elems = elems[0].shape[1]
if num_elems < 2:
return elems
# combine adjacent pairs of elements
reduced_elems = operator(
[elem[:, :-1:2] for elem in elems],
[elem[:, 1::2] for elem in elems])
# recursively compute the scan of the partially reduced tensors
odd_elems = _scan(reduced_elems)
if num_elems % 2 == 0:
even_elems = operator(
[e[:, :-1] for e in odd_elems],
[e[:, 2::2] for e in elems])
else:
even_elems = operator(
odd_elems,
[e[:, 2::2] for e in elems])
# the first element of the scan is the same as the first element of the original `elems`
even_elems = [
torch.cat([elem[:, :1], result], dim=1)
for (elem, result) in zip(elems, even_elems)]
return list(map(_interleave, even_elems, odd_elems))
return _scan(elems)
def _interleave(a, b):
a_axis_len, b_axis_len = a.shape[1], b.shape[1]
output_axis_len = a_axis_len + b_axis_len
if (a_axis_len == (b_axis_len + 1)):
b = pad_at_dim(b, (0, 1), dim = 1)
stacked = torch.stack([a, b], dim=2)
interleaved = torch.flatten(stacked, start_dim=1, end_dim=2)
return interleaved[:, :output_axis_len]
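To make the scan's semantics concrete, here is a small self-contained check (shapes and values are illustrative) that associative_scan, driven by the gateloop-style binary operator, reproduces the sequential first-order recurrence h[t] = a[t] * h[t-1] + x[t]:
import torch

def op(left, right):
    a_i, x_i = left
    a_j, x_j = right
    return a_j * a_i, a_j * x_i + x_j

a = torch.rand(1, 8, 1)    # decay gates in (0, 1)
x = torch.randn(1, 8, 1)

_, h = associative_scan(op, (a, x))

h_ref, prev = [], torch.zeros(1, 1)
for t in range(8):
    prev = a[:, t] * prev + x[:, t]
    h_ref.append(prev)
h_ref = torch.stack(h_ref, dim = 1)

assert torch.allclose(h, h_ref, atol = 1e-5)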
.\lucidrains\gateloop-transformer\gateloop_transformer\gateloop_transformer.py
from functools import partial
import torch
from torch.nn import Module, ModuleList
from torch import nn, einsum, Tensor
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from rotary_embedding_torch import RotaryEmbedding
from gateloop_transformer.associative_scan import associative_scan
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
# builds an nn.Sequential, filtering out any None modules
def Sequential(*modules):
modules = list(filter(exists, modules))
num_modules = len(modules)
if num_modules == 0:
return nn.Identity()
elif num_modules == 1:
return modules[0]
return nn.Sequential(*modules)
# rms norm
class RMSNorm(Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
# l2-normalize the features, then rescale by the learned gamma
return F.normalize(x, dim=-1) * self.scale * self.gamma
# norm wrappers
class PreNorm(Module):
def __init__(self, dim, fn: Module):
super().__init__()
self.fn = fn
self.norm = RMSNorm(dim)
def forward(self, x, **kwargs):
# normalize first, apply the wrapped module, then add the residual
return self.fn(self.norm(x), **kwargs) + x
class PostNorm(Module):
def __init__(self, dim, fn: Module):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
def forward(self, x, **kwargs):
# apply the wrapped module, add the residual, then normalize
return self.norm(self.fn(x, **kwargs) + x)
# feedforward
def FeedForward(dim, mult=4):
dim_inner = dim * mult
return nn.Sequential(
nn.Linear(dim, dim_inner),
nn.GELU(),
nn.Linear(dim_inner, dim)
)
# attention
class CausalFullAttention(Module):
def __init__(
self,
dim,
*,
dim_head=64,
heads=8,
rotary_emb=False,
add_swish_gating=False,
data_dependent_rel_pos=False,
frac_gradient_data_dependent_rel_pos=0.5,
softmax_normalize=None
):
super().__init__()
dim_inner = dim_head * heads
self.softmax_normalize = default(softmax_normalize, not data_dependent_rel_pos)
self.scale = dim_head ** -0.5
self.rotary_emb = RotaryEmbedding(dim_head) if rotary_emb else None
# project to queries, keys and values, then split into heads
self.to_qkv = nn.Sequential(
nn.Linear(dim, dim_inner * 3, bias=False),
Rearrange('b n (qkv h d) -> qkv b h n d', h=heads, qkv=3)
)
# data-dependent relative positions (learned state transitions)
self.data_dependent_rel_pos = data_dependent_rel_pos
self.frac_gradient_data_dependent_rel_pos = frac_gradient_data_dependent_rel_pos
if data_dependent_rel_pos:
self.to_a = nn.Sequential(
nn.Linear(dim, dim_inner, bias=False),
Rearrange('b n (h d c) -> b h n d c', h=heads, c=2)
)
# optional swish (SiLU) output gating
self.to_gates = None
if add_swish_gating:
self.to_gates = nn.Sequential(
nn.Linear(dim, dim_inner, bias=False),
nn.SiLU(),
Rearrange('b n (h d) -> b h n d', h=heads)
)
# merge heads and project back out
self.to_out = nn.Sequential(
Rearrange('b h n d -> b n (h d)'),
nn.Linear(dim_inner, dim)
)
def forward(
self,
x,
ablate_complex=False,
ablate_state_transition=False
):
q, k, v = self.to_qkv(x)
if exists(self.rotary_emb):
q = self.rotary_emb.rotate_queries_or_keys(q)
k = self.rotary_emb.rotate_queries_or_keys(k)
q = q * self.scale
# data-dependent relative positions: scale queries and keys by a cumulative product
# of learned (complex) state-transition gates
if self.data_dependent_rel_pos and not ablate_state_transition:
frac_gradient = self.frac_gradient_data_dependent_rel_pos
a = self.to_a(x)
# let the data-dependent state transitions receive only a fraction of the gradient
a = a * frac_gradient + a.detach() * (1 - frac_gradient)
# treat the projection as complex numbers
a = torch.view_as_complex(a)
if ablate_complex:
a = a.real + 0.j
# sigmoid-squash the magnitude, keep the phase
magnitude, phase = a.abs(), a.angle()
a = torch.polar(magnitude.sigmoid(), phase)
a = rearrange(a, '... -> ... 1')
# cumulative product of the gates along the sequence
a_cumprod = a.cumprod(dim=-2)
# clamp the real part away from zero so the inverse is stable
a_cumprod_real = a_cumprod.real.clamp(min=1e-10)
a_cumprod_real_inverse = 1. / a_cumprod_real
# scale queries by the cumulative product and keys by its inverse
q, k = map(lambda t: rearrange(t, '... (d c) -> ... d c', c=2), (q, k))
q = q * a_cumprod_real
k = k * a_cumprod_real_inverse
q, k = map(lambda t: rearrange(t, '... d c -> ... (d c)'), (q, k))
# attention scores with a causal mask
sim = einsum('b h i d, b h j d -> b h i j', q, k)
i, j = sim.shape[2:]
causal_mask = torch.ones((i, j), dtype=torch.bool, device=x.device).triu(j - i + 1)
if self.softmax_normalize:
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
attn = sim.softmax(dim=-1)
else:
# un-normalized (linear) attention: just zero out future positions
attn = sim.masked_fill(causal_mask, 0.)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
# optional swish gating of the output
if exists(self.to_gates):
out = out * self.to_gates(x)
return self.to_out(out)
# data-gated linear attention with the "gateloop operator"
def gate_loop_operator(q, k, v, a):
"""
the pseudocode in section 3.2 of the paper
"""
# outer product of keys and values, promoted to complex
kv = einsum('b n d, b n e -> b n d e', k, v)
kv = kv + 0.j
# binary operator for the first-order linear recurrence S_t = a_t * S_{t-1} + kv_t
def binary_operator(a, b):
a_i, kv_i = a
a_j, kv_j = b
return a_j * a_i, a_j * kv_i + kv_j
# parallel prefix scan over the sequence
_, kv = associative_scan(binary_operator, (a, kv))
# read out with the queries (real part only)
return einsum('b n d, b n d e -> b n e', q, kv.real)
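# For reference (a sketch, not part of the library): the scan above evaluates, in parallel,
# the sequential recurrence S_t = a_t * S_{t-1} + k_t^T v_t with readout y_t = q_t S_t.
# A naive per-step version, assuming the same shapes as above (q, k, v: (b, n, d), a: (b, n, 1, 1) complex):
def gate_loop_operator_naive(q, k, v, a):
    b, n, d = k.shape
    e = v.shape[-1]
    state = torch.zeros(b, d, e, dtype = a.dtype, device = k.device)
    outs = []
    for t in range(n):
        kv_t = einsum('b d, b e -> b d e', k[:, t], v[:, t]) + 0.j
        state = a[:, t] * state + kv_t   # a[:, t] is (b, 1, 1) and broadcasts over the state
        outs.append(einsum('b d, b d e -> b e', q[:, t], state.real))
    return torch.stack(outs, dim = 1)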
# gate looped attention
class GateLoopedAttention(Module):
def __init__(
self,
dim,
heads = None,
dim_inner = None,
checkpoint_gate_looped_attn = True,
add_swish_gating = True,
sub_ln = False,
frac_gradient_state_transition = 0.9
):
super().__init__()
self.frac_gradient_state_transition = frac_gradient_state_transition
self.checkpoint_gate_looped_attn = checkpoint_gate_looped_attn
dim_inner = default(dim_inner, dim)
heads = default(heads, dim_inner)
assert (dim_inner % heads) == 0, f'dimension for gate looped attention {dim_inner} must be divisible by number of gate loop heads {heads}'
# split the feature dimension into heads, folded into the batch
self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
# projections to queries, keys and values
self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
# projection to the complex state-transition gates a (two numbers per head: real and imaginary)
self.to_a = nn.Sequential(
nn.Linear(dim, heads * 2),
Rearrange('b n (h c) -> (b h) n 1 1 c', h = heads, c = 2)
)
# merge heads back
self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
# optional sub layernorm
self.maybe_sub_ln = nn.LayerNorm(dim_inner) if sub_ln else nn.Identity()
# optional swish (SiLU) output gating
self.to_gates = None
if add_swish_gating:
self.to_gates = nn.Sequential(
nn.Linear(dim, dim_inner, bias = False),
nn.SiLU()
)
# output projection
self.to_out = nn.Linear(dim_inner, dim, bias = False) if dim_inner != dim or add_swish_gating else nn.Identity()
def forward(
self,
x,
ablate_complex = False,
ablate_state_transition = False
):
frac_gradient = self.frac_gradient_state_transition
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(self.split_heads, (q, k, v))
# state transition gates, receiving only a fraction of the gradient
a = self.to_a(x)
a = a * frac_gradient + a.detach() * (1 - frac_gradient)
# treat as complex numbers
a = torch.view_as_complex(a)
if ablate_complex:
a = a.real + 0.j
if ablate_state_transition:
# ablation: no data-dependent state transitions (gates fixed to 1)
a = torch.ones_like(a.real) + 0.j
else:
# activation for the state transitions:
# sigmoid on the magnitude, identity on the phase
magnitude, phase = a.abs(), a.angle()
a = torch.polar(magnitude.sigmoid(), phase)
# checkpoint the gate looped attention when gradients are needed
need_backwards = any([t.requires_grad for t in (q, k, v, a)])
fn = partial(checkpoint, gate_loop_operator) if need_backwards and self.checkpoint_gate_looped_attn else gate_loop_operator
out = fn(q, k, v, a)
out = self.merge_heads(out)
out = self.maybe_sub_ln(out)
# optional swish gating of the output
if exists(self.to_gates):
out = self.to_gates(x) * out
return self.to_out(out)
# transformer
class Transformer(Module):
def __init__(
self,
dim,
*,
num_tokens,
depth,
dim_head = 64,
heads = 8,
ff_mult = 4,
checkpoint_gate_looped_attn = True,
use_gate_looped_attn = True,
gate_loop_heads = None,
attn_add_swish_gating = True,
dim_gate_looped_attn = None,
attn_softmax_normalize = None,
data_dependent_rel_pos = False,
frac_gradient_state_transition = 0.9,
ablate_complex = False,
ablate_state_transition = False,
rotary_emb = False,
post_ln_norm = False,
sub_ln = False
):
super().__init__()
self.ablate_complex = ablate_complex
self.ablate_state_transition = ablate_state_transition
# token embedding
self.token_emb = nn.Embedding(num_tokens, dim)
layers = ModuleList([])
# pre-norm by default, post-norm if requested
layer_wrapper = PreNorm if not post_ln_norm else PostNorm
for _ in range(depth):
# spatial (token) mixer: gate looped attention or causal full attention
if use_gate_looped_attn:
spatial_mixer = GateLoopedAttention(
dim = dim,
heads = gate_loop_heads,
dim_inner = dim_gate_looped_attn,
add_swish_gating = attn_add_swish_gating,
sub_ln = sub_ln,
checkpoint_gate_looped_attn = checkpoint_gate_looped_attn,
frac_gradient_state_transition = frac_gradient_state_transition
)
else:
spatial_mixer = CausalFullAttention(
dim = dim,
dim_head = dim_head,
heads = heads,
rotary_emb = rotary_emb,
add_swish_gating = attn_add_swish_gating,
softmax_normalize = attn_softmax_normalize,
data_dependent_rel_pos = data_dependent_rel_pos,
frac_gradient_data_dependent_rel_pos = frac_gradient_state_transition
)
# channel (feature) mixer
channelwise_mixer = FeedForward(
dim = dim,
mult = ff_mult
)
layers.append(ModuleList([
layer_wrapper(dim, spatial_mixer),
layer_wrapper(dim, channelwise_mixer)
]))
self.layers = ModuleList(layers)
# final norm and projection to logits
self.to_logits = Sequential(
RMSNorm(dim) if not post_ln_norm else None,
nn.Linear(dim, num_tokens, bias = False)
)
def forward(
self,
x,
return_loss = False,
ablate_complex = None,
ablate_state_transition = None
):
ablate_complex = default(ablate_complex, self.ablate_complex)
ablate_state_transition = default(ablate_state_transition, self.ablate_state_transition)
# when returning the loss, split the sequence into inputs and shifted labels
if return_loss:
x, labels = x[:, :-1], x[:, 1:]
x = self.token_emb(x)
# alternate spatial mixing (gateloop / attention) and channelwise feedforward
for attn, ff in self.layers:
x = attn(
x,
ablate_complex = ablate_complex,
ablate_state_transition = ablate_state_transition
)
x = ff(x)
logits = self.to_logits(x)
if not return_loss:
return logits
# cross entropy against the shifted labels
logits = rearrange(logits, 'b n c -> b c n')
return F.cross_entropy(logits, labels)