Lucidrains 系列项目源码解析(三十七)
.\lucidrains\flash-genomics-model\flash_genomics_model\__init__.py
# 从flash_genomics_model.flash_genomics_model模块中导入FlashGenomicsModel类
from flash_genomics_model.flash_genomics_model import FlashGenomicsModel
Flash Genomics Model (FGM)
My own attempt at a long context genomics model, leveraging recent advances in long context attention modeling (Flash Attention + other hierarchical methods).
If you would like to collaborate and are not averse to having the final model completely open sourced, get in touch. My goal is to simply figure out if long context is what holds us back from a new SOTA model.
Update: Question has been answered, but I'll probably still continue with the project.
Todo
- add ability to combine with hyena, or even hyena alone, to fully evaluate what is necessary
Citations
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
@article {Dalla-Torre2023.01.11.523679,
author = {Hugo Dalla-Torre and Liam Gonzalez and Javier Mendoza Revilla and Nicolas Lopez Carranza and Adam Henryk Grzywaczewski and Francesco Oteri and Christian Dallago and Evan Trop and Hassan Sirelkhatim and Guillaume Richard and Marcin Skwark and Karim Beguir and Marie Lopez and Thomas Pierrot},
title = {The Nucleotide Transformer: Building and Evaluating Robust Foundation Models for Human Genomics},
elocation-id = {2023.01.11.523679},
year = {2023},
doi = {10.1101/2023.01.11.523679},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2023/01/15/2023.01.11.523679},
eprint = {https://www.biorxiv.org/content/early/2023/01/15/2023.01.11.523679.full.pdf},
journal = {bioRxiv}
}
@article {Benegas2022.08.22.504706,
author = {Gonzalo Benegas and Sanjit Singh Batra and Yun S. Song},
title = {DNA language models are powerful zero-shot predictors of genome-wide variant effects},
elocation-id = {2022.08.22.504706},
year = {2023},
doi = {10.1101/2022.08.22.504706},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2023/04/12/2022.08.22.504706},
eprint = {https://www.biorxiv.org/content/early/2023/04/12/2022.08.22.504706.full.pdf},
journal = {bioRxiv}
}
@article{Nguyen2023HyenaDNALG,
title = {HyenaDNA: Long-Range Genomic Sequence Modeling at Single Nucleotide Resolution},
author = {Eric D Nguyen and Michael Poli and Marjan Faizi and Armin W. Thomas and Callum Jacob Birch-sykes and Michael Wornow and Aman Patel and Clayton M. Rabideau and Stefano Massaroli and Yoshua Bengio and Stefano Ermon and Stephen A. Baccus and Christopher R{\'e}},
journal = {ArXiv},
year = {2023},
volume = {abs/2306.15794}
}
.\lucidrains\flash-genomics-model\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# 设置包的元数据
# Package metadata for the flash-genomics-model distribution
setup(
    name = 'flash-genomics-model',                            # package name on PyPI
    packages = find_packages(exclude=[]),                     # include every discovered package
    version = '0.0.1',                                        # release version
    license='MIT',                                            # license identifier
    description = 'Flash Genomics Model (FGM)',               # short description
    author = 'Phil Wang',                                     # author
    author_email = 'lucidrains@gmail.com',                    # author contact
    long_description_content_type = 'text/markdown',          # README is markdown
    url = 'https://github.com/lucidrains/flash-genomics-model', # project homepage
    keywords = [                                              # search keywords
        'artificial intelligence',
        'deep learning',
        'transformers',
        'attention mechanism',
        'long context',
        'genomics',
        'pre-training'
    ],
    install_requires=[                                        # runtime dependencies
        'einops>=0.6.1',
        'MEGABYTE-pytorch',
        'torch>=1.6',
    ],
    classifiers=[                                             # PyPI trove classifiers
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
Data source
The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/
.\lucidrains\FLASH-pytorch\flash_pytorch\autoregressive_wrapper.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 torch 库中导入 nn.functional 模块,并重命名为 F
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 函数
from einops import rearrange
# 辅助函数
# 判断值是否存在的函数
def exists(val):
    """Return True when `val` is anything other than None."""
    return not (val is None)
# 评估装饰器函数
def eval_decorator(fn):
    """Run `fn` with the model in eval mode, restoring the prior mode afterwards."""
    def inner(model, *args, **kwargs):
        prior_mode = model.training
        model.eval()
        result = fn(model, *args, **kwargs)
        # put the model back exactly as we found it
        model.train(prior_mode)
        return result
    return inner
# top k 过滤
# 根据阈值过滤 logits 中的 top k 值
def top_k(logits, thres = 0.9):
    """Keep only the top (1 - thres) fraction of logits; set the rest to -inf."""
    num_keep = int((1 - thres) * logits.shape[-1])
    topk_vals, topk_idx = logits.topk(num_keep, dim = -1)
    filtered = torch.full_like(logits, float('-inf'))
    # write the surviving logits back at their original positions
    filtered.scatter_(1, topk_idx, topk_vals)
    return filtered
# 自回归包装器类
class AutoregressiveWrapper(nn.Module):
    """Wraps a language model for next-token-prediction training and sampling.

    `forward` returns the cross-entropy loss between shifted inputs and
    targets; `generate` autoregressively samples `seq_len` new tokens.
    """

    def __init__(self, net, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value  # token used to pad past an EOS token
        self.net = net

    @torch.no_grad()
    @eval_decorator
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_thres = 0.9, **kwargs):
        b, t, device = *start_tokens.shape, start_tokens.device

        out = start_tokens

        for _ in range(seq_len):
            # only the logits for the final position are needed
            logits = self.net(out, **kwargs)[:, -1, :]
            filtered_logits = top_k(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)
            out = torch.cat((out, sample), dim=-1)

            if exists(eos_token):
                is_eos_token = (out == eos_token)
                # stop once every sequence in the batch contains an EOS token
                if is_eos_token.any(dim = -1).all():
                    # mask everything strictly after the first EOS token.
                    # fix: original referenced the undefined name `is_eos_tokens`,
                    # raising NameError whenever this branch was hit
                    shifted_is_eos_tokens = F.pad(is_eos_token, (1, -1))
                    mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
                    out = out.masked_fill(mask, self.pad_value)
                    break

        # strip the prompt; return only the newly generated tokens
        out = out[:, t:]
        return out

    def forward(self, x, **kwargs):
        # inputs are tokens [0, n-1]; targets are tokens [1, n]
        x_inp, x_labels = x[:, :-1], x[:, 1:]
        logits = self.net(x_inp, **kwargs)
        # swap the last two dims so the class dimension is second,
        # as F.cross_entropy expects
        return F.cross_entropy(rearrange(logits, 'b c n -> b n c'), x_labels)
.\lucidrains\FLASH-pytorch\flash_pytorch\flash_pytorch.py
# 导入数学库和 PyTorch 库
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
# 导入 einops 库中的 rearrange 函数和 rotary_embedding_torch 库中的 RotaryEmbedding 类
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding
# 辅助函数
# 判断值是否存在
def exists(val):
    """True unless `val` is literally None."""
    return val is not None
# 如果值存在则返回该值,否则返回默认值
def default(val, d):
    """Return `val` when it is not None, otherwise the fallback `d`."""
    if val is None:
        return d
    return val
# 将数字 n 填充到最接近的 mult 的倍数
def padding_to_multiple_of(n, mult):
    """Amount to add to `n` so it becomes a multiple of `mult` (0 if it already is)."""
    # Python's modulo on a negated operand yields the shortfall directly
    return -n % mult
# scalenorm
# 缩放归一化层
class ScaleNorm(nn.Module):
    """L2 normalization scaled by 1/sqrt(dim) with a single learned gain."""

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        denom = torch.norm(x, dim = -1, keepdim = True) * self.scale
        # clamp avoids division by zero for all-zero rows
        denom = denom.clamp(min = self.eps)
        return (x / denom) * self.g
# absolute positional encodings
# 缩放的正弦嵌入层
class ScaledSinuEmbedding(nn.Module):
    """Fixed sinusoidal positional embedding with a single learned scalar scale."""

    def __init__(self, dim):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1,))
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        seq_len = x.shape[1]
        positions = torch.arange(seq_len, device = x.device).type_as(self.inv_freq)
        # outer product of position indices and inverse frequencies
        angles = positions[:, None] * self.inv_freq[None, :]
        emb = torch.cat((angles.sin(), angles.cos()), dim = -1)
        return emb * self.scale
# T5 relative positional bias
# T5 相对位置偏置层
class T5RelativePositionBias(nn.Module):
    """Single-head T5-style relative position bias.

    Produces an (i, j) additive bias for attention logits: relative
    distances are mapped to a fixed number of buckets (exact for small
    distances, log-spaced up to `max_distance` beyond that), each bucket
    holding one learned scalar.
    """

    def __init__(
        self,
        scale,
        causal = False,
        num_buckets = 32,
        max_distance = 128
    ):
        super().__init__()
        self.scale = scale          # multiplier applied to the learned bias
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        # one scalar bias per bucket (single head)
        self.relative_attention_bias = nn.Embedding(num_buckets, 1)

    @staticmethod
    def _relative_position_bucket(
        relative_position,
        causal = True,
        num_buckets = 32,
        max_distance = 128
    ):
        """Map signed relative positions to bucket indices."""
        ret = 0
        n = -relative_position
        if not causal:
            # non-causal: split buckets between the two signs of the offset
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            # causal: future offsets are clamped to distance 0
            n = torch.max(n, torch.zeros_like(n))

        # first half of the buckets cover exact small distances
        max_exact = num_buckets // 2
        is_small = n < max_exact

        # larger distances fall into logarithmically spaced buckets,
        # capped at the final bucket index
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, x):
        # x is used only for its last two dims (query len i, key len j) and device
        i, j, device = *x.shape[-2:], x.device
        q_pos = torch.arange(i, dtype = torch.long, device = device)
        k_pos = torch.arange(j, dtype = torch.long, device = device)
        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
        rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        bias = rearrange(values, 'i j 1 -> i j')
        return bias * self.scale
# class
# 偏移缩放层
class OffsetScale(nn.Module):
    """Per-head affine transform: y_h = x * gamma_h + beta_h, one output per head."""

    def __init__(self, dim, heads = 1):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(heads, dim))
        self.beta = nn.Parameter(torch.zeros(heads, dim))
        nn.init.normal_(self.gamma, std = 0.02)

    def forward(self, x):
        # broadcast over a new head axis instead of einsum
        scaled = x.unsqueeze(-2) * self.gamma + self.beta
        return scaled.unbind(dim = -2)
# activation functions
# ReLU 平方激活函数
class ReLUSquared(nn.Module):
    """Squared ReLU activation: relu(x) ** 2."""

    def forward(self, x):
        rectified = x.clamp(min = 0)
        return rectified * rectified
# 拉普拉斯注意力函数
class LaplacianAttnFn(nn.Module):
    """ https://arxiv.org/abs/2209.10655 claims this is more stable than Relu squared """

    def forward(self, x):
        # Gaussian CDF with mean sqrt(0.5) and std sqrt(1 / (4 * pi))
        mean = math.sqrt(0.5)
        std = math.sqrt((4 * math.pi) ** -1)
        return 0.5 * (1 + torch.special.erf((x - mean) / (std * math.sqrt(2))))
# 定义了一个名为GAU的类,表示门控注意力单元
class GAU(nn.Module):
    """Gated Attention Unit: single-head attention fused with SiLU gating,
    intended as a drop-in replacement for multi-headed attention.

    Args:
        dim: input feature dimension.
        query_key_dim: shared query/key dimension.
        expansion_factor: hidden (value/gate) width multiplier.
        add_residual: add the input back onto the output.
        causal: apply an autoregressive mask.
        dropout: dropout on attention weights and output projection.
        laplace_attn_fn: use the Laplacian attention fn instead of ReLU^2.
        rel_pos_bias: NOTE(review) this flag is never read below — the T5
            bias is always instantiated regardless; confirm intent.
        norm_klass: normalization class applied before projections.
    """

    def __init__(
        self,
        *,
        dim,
        query_key_dim = 128,
        expansion_factor = 2.,
        add_residual = True,
        causal = False,
        dropout = 0.,
        laplace_attn_fn = False,
        rel_pos_bias = False,
        norm_klass = nn.LayerNorm
    ):
        super().__init__()
        hidden_dim = int(expansion_factor * dim)

        self.norm = norm_klass(dim)
        self.causal = causal
        self.dropout = nn.Dropout(dropout)

        # attention non-linearity (applied instead of softmax)
        self.attn_fn = ReLUSquared() if not laplace_attn_fn else LaplacianAttnFn()

        # one-headed T5 relative positional bias
        self.rel_pos_bias = T5RelativePositionBias(scale = dim ** 0.5, causal = causal)

        # joint projection to values and gate
        self.to_hidden = nn.Sequential(
            nn.Linear(dim, hidden_dim * 2),
            nn.SiLU()
        )

        # shared query/key projection; queries and keys are then derived
        # from it by per-head offset + scale below
        self.to_qk = nn.Sequential(
            nn.Linear(dim, query_key_dim),
            nn.SiLU()
        )

        self.offsetscale = OffsetScale(query_key_dim, heads = 2)

        self.to_out = nn.Sequential(
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

        self.add_residual = add_residual

    def forward(
        self,
        x,
        rel_pos_bias = None,   # optional externally supplied additive bias
        mask = None            # boolean key mask, shape (batch, seq)
    ):
        seq_len, device = x.shape[-2], x.device

        normed_x = self.norm(x)
        v, gate = self.to_hidden(normed_x).chunk(2, dim = -1)

        qk = self.to_qk(normed_x)
        # derive queries and keys via per-head affine transforms
        q, k = self.offsetscale(qk)

        sim = einsum('b i d, b j d -> b i j', q, k)

        if exists(self.rel_pos_bias):
            sim = sim + self.rel_pos_bias(sim)

        if exists(rel_pos_bias):
            sim = sim + rel_pos_bias

        # normalize by sequence length rather than softmax
        attn = self.attn_fn(sim / seq_len)
        attn = self.dropout(attn)

        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 j')
            attn = attn.masked_fill(~mask, 0.)

        if self.causal:
            # zero out attention to future positions
            causal_mask = torch.ones((seq_len, seq_len), dtype = torch.bool, device = device).triu(1)
            attn = attn.masked_fill(causal_mask, 0.)

        out = einsum('b i j, b j d -> b i d', attn, v)
        # gate the aggregated values
        out = out * gate

        out = self.to_out(out)

        if self.add_residual:
            out = out + x

        return out
# 定义了一个名为FLASH的类,表示快速自注意力流水线
class FLASH(nn.Module):
    """FLASH layer: quadratic gated attention within fixed-size groups plus
    linear attention across groups (Hua et al., "Transformer Quality in
    Linear Time").
    """

    def __init__(
        self,
        *,
        dim,                                 # input feature dimension
        group_size = 256,                    # tokens per quadratic-attention group
        query_key_dim = 128,                 # query / key dimension
        expansion_factor = 2.,               # hidden width multiplier for values / gate
        causal = False,                      # autoregressive masking
        dropout = 0.,                        # attention dropout probability
        rotary_pos_emb = None,               # optional rotary embedding module
        norm_klass = nn.LayerNorm,           # normalization class
        shift_tokens = False,                # shift half the features one step along the sequence
        laplace_attn_fn = False,             # Laplacian attention fn instead of ReLU^2
        reduce_group_non_causal_attn = True  # sum linear-attention context over groups when non-causal
    ):
        super().__init__()
        hidden_dim = int(dim * expansion_factor)

        self.group_size = group_size
        self.causal = causal
        self.shift_tokens = shift_tokens

        self.attn_fn = ReLUSquared() if not laplace_attn_fn else LaplacianAttnFn()

        # positional embeddings
        self.rotary_pos_emb = rotary_pos_emb
        self.rel_pos_bias = T5RelativePositionBias(query_key_dim ** 0.5, causal = causal)

        # normalization and dropout
        self.norm = norm_klass(dim)
        self.dropout = nn.Dropout(dropout)

        self.reduce_group_non_causal_attn = reduce_group_non_causal_attn

        # joint projection to values and gate
        self.to_hidden = nn.Sequential(
            nn.Linear(dim, hidden_dim * 2),
            nn.SiLU()
        )

        # shared query/key projection; the four q/k tensors (quadratic and
        # linear branches) are derived from it by per-head offset + scale
        self.to_qk = nn.Sequential(
            nn.Linear(dim, query_key_dim),
            nn.SiLU()
        )

        self.qk_offset_scale = OffsetScale(query_key_dim, heads = 4)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(
        self,
        x,
        *,
        mask = None  # boolean key mask, shape (batch, seq)
    ):
        """
        b - batch
        n - sequence length (within groups)
        g - group dimension
        d - feature dimension (keys)
        e - feature dimension (values)
        i - sequence dimension (source)
        j - sequence dimension (target)
        """

        b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size

        # pre-norm
        normed_x = self.norm(x)

        # token shift: half the features are shifted forward one position
        if self.shift_tokens:
            x_shift, x_pass = normed_x.chunk(2, dim = -1)
            x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
            normed_x = torch.cat((x_shift, x_pass), dim = -1)

        # initial projections
        v, gate = self.to_hidden(normed_x).chunk(2, dim = -1)
        qk = self.to_qk(normed_x)

        # offset and scale into quadratic / linear queries and keys
        quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)

        # zero out linear-attention keys at masked (padded) positions
        if exists(mask):
            lin_mask = rearrange(mask, '... -> ... 1')
            lin_k = lin_k.masked_fill(~lin_mask, 0.)

        # rotate queries and keys
        if exists(self.rotary_pos_emb):
            quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))

        # pad the sequence to a multiple of the group size
        padding = padding_to_multiple_of(n, g)

        if padding > 0:
            quad_q, quad_k, lin_q, lin_k, v = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v))

            mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
            mask = F.pad(mask, (0, padding), value = False)

        # split the sequence into groups
        quad_q, quad_k, lin_q, lin_k, v = map(lambda t: rearrange(t, 'b (n g) d -> b n g d', g = self.group_size), (quad_q, quad_k, lin_q, lin_k, v))

        if exists(mask):
            mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)

        # quadratic attention within each group (normalized by group size,
        # not softmax)
        sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g

        sim = sim + self.rel_pos_bias(sim)

        attn = self.attn_fn(sim)
        attn = self.dropout(attn)

        if exists(mask):
            attn = attn.masked_fill(~mask, 0.)

        if self.causal:
            causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
            attn = attn.masked_fill(causal_mask, 0.)

        quad_out = einsum('... i j, ... j d -> ... i d', attn, v)

        # linear attention across groups
        if self.causal:
            lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g

            # exclusive cumulative sum over the group dimension so each group
            # only sees context from strictly earlier groups
            lin_kv = lin_kv.cumsum(dim = 1)
            lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value = 0.)

            lin_out = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
        else:
            # non-causal: optionally reduce the key/value context over groups
            context_einsum_eq = 'b d e' if self.reduce_group_non_causal_attn else 'b g d e'
            lin_kv = einsum(f'b g n d, b g n e -> {context_einsum_eq}', lin_k, v) / n
            lin_out = einsum(f'b g n d, {context_einsum_eq} -> b g n e', lin_q, lin_kv)

        # fold groups back into the full sequence and drop padding
        quad_attn_out, lin_attn_out = map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out, lin_out))

        # gate the combined attention outputs
        out = gate * (quad_attn_out + lin_attn_out)

        # project out, with residual connection
        return self.to_out(out) + x
# FLASH Transformer 类定义
class FLASHTransformer(nn.Module):
    """Full FLASH transformer: token + scaled sinusoidal embeddings, a stack
    of FLASH layers, and a LayerNorm + linear head producing logits."""

    def __init__(
        self,
        *,
        dim,                                 # model feature dimension
        num_tokens,                          # vocabulary size
        depth,                               # number of FLASH layers
        group_size = 256,                    # group size for the FLASH layers
        query_key_dim = 128,                 # query / key dimension
        expansion_factor = 2.,               # hidden width multiplier
        causal = False,                      # autoregressive or not
        attn_dropout = 0.,                   # attention dropout
        norm_type = 'scalenorm',             # 'scalenorm' or 'layernorm'
        shift_tokens = True,                 # shift half the features along the sequence
        laplace_attn_fn = False,             # Laplacian attention fn instead of ReLU^2
        reduce_group_non_causal_attn = True  # reduce linear-attention context over groups
    ):
        super().__init__()
        assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'

        # pick the normalization class used inside each FLASH layer
        if norm_type == 'scalenorm':
            norm_klass = ScaleNorm
        elif norm_type == 'layernorm':
            norm_klass = nn.LayerNorm

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.abs_pos_emb = ScaledSinuEmbedding(dim)
        self.group_size = group_size

        # partial rotary embedding, capped at 32 dims (Wang et al. - GPT-J)
        rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))

        # shared rotary embedding across all FLASH layers
        self.layers = nn.ModuleList([FLASH(dim = dim, group_size = group_size, query_key_dim = query_key_dim, expansion_factor = expansion_factor, causal = causal, dropout = attn_dropout, rotary_pos_emb = rotary_pos_emb, norm_klass = norm_klass, shift_tokens = shift_tokens, reduce_group_non_causal_attn = reduce_group_non_causal_attn, laplace_attn_fn = laplace_attn_fn) for _ in range(depth)])

        # final norm + projection to vocabulary logits
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    def forward(
        self,
        x,          # token ids, shape (batch, seq)
        *,
        mask = None # boolean key mask, shape (batch, seq)
    ):
        x = self.token_emb(x)
        # add absolute positional information
        x = self.abs_pos_emb(x) + x

        for flash in self.layers:
            x = flash(x, mask = mask)

        return self.to_logits(x)
.\lucidrains\FLASH-pytorch\flash_pytorch\__init__.py
# 从 flash_pytorch.flash_pytorch 模块中导入 GAU, FLASH, FLASHTransformer 类
from flash_pytorch.flash_pytorch import GAU, FLASH, FLASHTransformer

FLASH - Pytorch
Implementation of the Transformer variant proposed in the paper Transformer Quality in Linear Time
Install
$ pip install FLASH-pytorch
Usage
The main novel circuit in this paper is the "Gated Attention Unit", which they claim can replace multi-headed attention while reducing it to just one head.
It uses a relu squared activation in place of the softmax, the activation of which was first seen in the Primer paper, and the use of ReLU in ReLA Transformer. The gating style seems mostly inspired by gMLPs.
import torch
from flash_pytorch import GAU
gau = GAU(
dim = 512,
query_key_dim = 128, # query / key dimension
causal = True, # autoregressive or not
expansion_factor = 2, # hidden dimension = dim * expansion_factor
laplace_attn_fn = True # new Mega paper claims this is more stable than relu squared as attention function
)
x = torch.randn(1, 1024, 512)
out = gau(x) # (1, 1024, 512)
The authors then combine GAU with Katharopoulos linear attention, using grouping of the sequences to overcome a known issue with autoregressive linear attention.
This combination of the quadratic gated attention unit with grouped linear attention they named FLASH
You can also use this quite easily
import torch
from flash_pytorch import FLASH
flash = FLASH(
dim = 512,
group_size = 256, # group size
causal = True, # autoregressive or not
query_key_dim = 128, # query / key dimension
expansion_factor = 2., # hidden dimension = dim * expansion_factor
laplace_attn_fn = True # new Mega paper claims this is more stable than relu squared as attention function
)
x = torch.randn(1, 1111, 512) # sequence will be auto-padded to nearest group size
out = flash(x) # (1, 1111, 512)
Finally, you can use the full FLASH transformer as mentioned in the paper. This contains all the positional embeddings mentioned in the paper. Absolute positional embedding uses scaled sinusoidal. GAU quadratic attention will get one-headed T5 relative positional bias. On top of all this, both GAU attention as well as the linear attention will be rotary embedded (RoPE).
import torch
from flash_pytorch import FLASHTransformer
model = FLASHTransformer(
num_tokens = 20000, # number of tokens
dim = 512, # model dimension
depth = 12, # depth
causal = True, # autoregressive or not
group_size = 256, # size of the groups
query_key_dim = 128, # dimension of queries / keys
expansion_factor = 2., # hidden dimension = dim * expansion_factor
norm_type = 'scalenorm', # in the paper, they claimed scalenorm led to faster training at no performance hit. the other option is 'layernorm' (also default)
shift_tokens = True # discovered by an independent researcher in Shenzhen @BlinkDL, this simply shifts half of the feature space forward one step along the sequence dimension - greatly improved convergence even more in my local experiments
)
x = torch.randint(0, 20000, (1, 1024))
logits = model(x) # (1, 1024, 20000)
Test on Autoregressive Enwik8
$ python train.py
Citations
@article{Hua2022TransformerQI,
title = {Transformer Quality in Linear Time},
author = {Weizhe Hua and Zihang Dai and Hanxiao Liu and Quoc V. Le},
journal = {ArXiv},
year = {2022},
volume = {abs/2202.10447}
}
@software{peng_bo_2021_5196578,
author = {PENG Bo},
title = {BlinkDL/RWKV-LM: 0.01},
month = {aug},
year = {2021},
publisher = {Zenodo},
version = {0.01},
doi = {10.5281/zenodo.5196578},
url = {https://doi.org/10.5281/zenodo.5196578}
}
@inproceedings{Ma2022MegaMA,
title = {Mega: Moving Average Equipped Gated Attention},
author = {Xuezhe Ma and Chunting Zhou and Xiang Kong and Junxian He and Liangke Gui and Graham Neubig and Jonathan May and Luke Zettlemoyer},
year = {2022}
}
.\lucidrains\FLASH-pytorch\setup.py
# 导入设置安装和查找包的函数
from setuptools import setup, find_packages
# 设置包的元数据
# Package metadata for the FLASH-pytorch distribution
setup(
    # package name on PyPI
    name = 'FLASH-pytorch',
    # include every discovered package
    packages = find_packages(exclude=[]),
    # release version
    version = '0.1.9',
    # license identifier
    license='MIT',
    # short description
    description = 'FLASH - Transformer Quality in Linear Time - Pytorch',
    # author
    author = 'Phil Wang',
    # author contact
    author_email = 'lucidrains@gmail.com',
    # README is markdown
    long_description_content_type = 'text/markdown',
    # project homepage
    url = 'https://github.com/lucidrains/FLASH-pytorch',
    # search keywords
    keywords = [
        'artificial intelligence',
        'deep learning',
        'transformers',
        'attention mechanism'
    ],
    # runtime dependencies
    install_requires=[
        'einops>=0.4',
        'rotary-embedding-torch>=0.1.5',
        'torch>=1.9',
    ],
    # PyPI trove classifiers
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\FLASH-pytorch\train.py
# 导入所需的库
from flash_pytorch import FLASHTransformer
from flash_pytorch.autoregressive_wrapper import AutoregressiveWrapper
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# training hyperparameters
NUM_BATCHES = int(1e5)           # total optimizer steps
BATCH_SIZE = 4                   # sequences per micro-batch
GRADIENT_ACCUMULATE_EVERY = 4    # micro-batches per optimizer step
LEARNING_RATE = 2e-4
VALIDATE_EVERY  = 100            # steps between validation-loss reports
GENERATE_EVERY  = 500            # steps between sampling demos
GENERATE_LENGTH = 512            # tokens to sample per demo
SEQ_LEN = 1024                   # training sequence length

# helper functions
def cycle(loader):
    """Yield batches from `loader` forever, restarting it whenever exhausted."""
    while True:
        yield from loader
def decode_token(token):
    """Map a byte value to a printable character (control chars become space)."""
    return chr(max(32, token))

def decode_tokens(tokens):
    """Decode a sequence of byte values into a string."""
    return ''.join(decode_token(token) for token in tokens)
# instantiate a GPT-like decoder model
model = FLASHTransformer(
    num_tokens = 256,       # byte-level vocabulary
    dim = 512,
    depth = 8,
    causal = True,          # autoregressive
    group_size = 256,
    shift_tokens = True,
    laplace_attn_fn = True
)

# wrap for next-token-prediction training / sampling, move to GPU
model = AutoregressiveWrapper(model)
model.cuda()
# prepare enwik8 data: first 90M bytes for training, next 5M for validation
with gzip.open('./data/enwik8.gz') as file:
    # fix: np.fromstring is removed in modern numpy — frombuffer is the
    # replacement; .copy() makes the array writable so torch.from_numpy
    # does not operate on (or warn about) a read-only buffer
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
# 定义数据集类
class TextSamplerDataset(Dataset):
    """Yields random fixed-length subsequences (plus one target token) from `data`."""

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # the index is ignored — a fresh random window is drawn on each call
        start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        window = self.data[start: start + self.seq_len + 1].long()
        return window.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len
# datasets and infinite loaders over the enwik8 train / val splits
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer (NOTE: this rebinding shadows the imported `torch.optim` module name)
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# training loop: accumulate gradients, clip, step; periodically report
# validation loss and sample from the model
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        # fix: original passed a tuple to print alongside %-placeholders in
        # an f-string, printing the literal '%s \n\n %s' text instead of the prompt
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str)

Flexible Diffusion Modeling of Long Videos - Pytorch (wip)
Implementation of the video diffusion model and training scheme presented in the paper, Flexible Diffusion Modeling of Long Videos, in Pytorch. While the Unet architecture does not look that novel (quite similar to Space-time factored unets, where they do attention across time) they achieved up to 25 minutes of coherent video with their specific frame sampling conditioning scheme during training.
I will also attempt to push this approach even further by introducing a super-resoluting module on top identical to what was used in Imagen
Citations
@inproceedings{Harvey2022FlexibleDM,
title = {Flexible Diffusion Modeling of Long Videos},
author = {William Harvey and Saeid Naderiparizi and Vaden Masrani and Christian Weilbach and Frank Wood},
year = {2022}
}

FUSS - Nim (wip)
Implementation of FUSS (Fitness Uniform Selection), a selection method proposed by Marcus Hutter himself for maintaining diversity in evolutionary algorithms, in Nim
Basically will be a rewrite of FUSS in C
Citations
@article{Hutter_2006,
doi = {10.1109/tevc.2005.863127},
url = {https://doi.org/10.1109%2Ftevc.2005.863127},
year = 2006,
month = {oct},
publisher = {Institute of Electrical and Electronics Engineers ({IEEE})},
volume = {10},
number = {5},
pages = {568--589},
author = {M. Hutter and S. Legg},
title = {Fitness uniform optimization},
journal = {{IEEE} Transactions on Evolutionary Computation}
}
Data source
The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/
.\lucidrains\g-mlp-gpt\g_mlp_gpt\autoregressive_wrapper.py
import torch
from torch import nn
import torch.nn.functional as F
# 定义一个装饰器函数,用于在模型评估时切换为eval模式
def eval_decorator(fn):
    """Decorator: switch the model into eval mode for the call, then restore it."""
    def inner(model, *args, **kwargs):
        previously_training = model.training
        model.eval()
        output = fn(model, *args, **kwargs)
        model.train(previously_training)
        return output
    return inner
# 定义一个函数用于对logits进行top k过滤
def top_k(logits, thres = 0.9):
    """Zero out (to -inf) all but the top (1 - thres) fraction of logits."""
    keep = int((1 - thres) * logits.shape[-1])
    values, indices = logits.topk(keep, dim = -1)
    result = torch.full_like(logits, float('-inf'))
    result.scatter_(1, indices, values)
    return result
# 定义一个包装类,用于自回归模型
class AutoregressiveWrapper(nn.Module):
    """Wraps a language model for next-token-prediction training and sampling.

    `forward` computes cross-entropy on shifted inputs/targets; `generate`
    samples `seq_len` new tokens autoregressively, cropping the context to
    the wrapped model's `seq_len`.
    """

    def __init__(self, net, ignore_index = -100, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = ignore_index  # label value skipped by the loss
        self.net = net
        self.max_seq_len = net.seq_len

    @torch.no_grad()
    @eval_decorator
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        device = start_tokens.device
        num_dims = len(start_tokens.shape)

        # allow unbatched prompts
        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        b, t = start_tokens.shape

        out = start_tokens

        for _ in range(seq_len):
            # crop the context to the model's maximum sequence length
            x = out[:, -self.max_seq_len:]
            logits = self.net(x, **kwargs)[:, -1, :]
            # fix: honor the `filter_logits_fn` argument — it was accepted
            # but the call below hard-coded top_k, silently ignoring it
            filtered_logits = filter_logits_fn(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)
            out = torch.cat((out, sample), dim=-1)

            # stop early once every sequence just emitted the EOS token
            if eos_token is not None and (sample == eos_token).all():
                break

        # strip the prompt; return only the newly generated tokens
        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        return out

    def forward(self, x, **kwargs):
        # inputs are tokens [0, n-1]; targets are tokens [1, n]
        xi, xo = x[:, :-1], x[:, 1:]
        out = self.net(xi, **kwargs)
        # transpose to (batch, classes, seq) as F.cross_entropy expects
        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        return loss
.\lucidrains\g-mlp-gpt\g_mlp_gpt\g_mlp_gpt.py
# 从 math 模块中导入 ceil 函数,用于向上取整
# 从 functools 模块中导入 partial 函数,用于创建偏函数
# 从 random 模块中导入 randrange 函数,用于生成指定范围内的随机整数
# 导入 torch 模块
# 从 torch.nn.functional 模块中导入 F 别名
# 从 torch 模块中导入 nn、einsum 函数
from math import ceil
from functools import partial
from random import randrange
import torch
import torch.nn.functional as F
from torch import nn, einsum
# 从 einops 模块中导入 rearrange、repeat 函数
from einops import rearrange, repeat
# 从 g_mlp_gpt.reversible 模块中导入 ReversibleSequence、SequentialSequence 类
# functions
# 定义函数 exists,用于判断值是否存在
def exists(val):
    """Return True for any value other than None."""
    return not (val is None)
# 定义函数 cast_tuple,用于将值转换为元组
def cast_tuple(val, num):
    """Return `val` unchanged if already a tuple, else repeat it `num` times."""
    if isinstance(val, tuple):
        return val
    return (val,) * num
# 定义函数 pad_to_multiple,用于将张量填充到指定的倍数
def pad_to_multiple(tensor, multiple, dim = -1, value = 0):
    """Right-pad `tensor` along `dim` with `value` so its length is a multiple of `multiple`."""
    seqlen = tensor.shape[dim]
    shortfall = -seqlen % multiple
    if shortfall == 0:
        return tensor
    # F.pad takes pairs from the last dim backwards; emit zero pairs for
    # every dimension after `dim`
    leading_zeros = (0,) * ((-1 - dim) * 2)
    return F.pad(tensor, (*leading_zeros, 0, shortfall), value = value)
# 定义函数 dropout_layers,用于对层进行随机丢弃
def dropout_layers(layers, prob_survival):
    """Randomly drop whole layers (stochastic depth), always keeping at least one."""
    if prob_survival == 1:
        return layers

    total = len(layers)
    drop_mask = torch.zeros(total).uniform_(0., 1.) > prob_survival

    # never drop everything — resurrect one random layer if needed
    if all(drop_mask):
        drop_mask[randrange(total)] = False

    return [layer for layer, dropped in zip(layers, drop_mask) if not dropped]
# helper classes
# 定义类 Residual,实现残差连接
class Residual(nn.Module):
    """Adds the input back onto the wrapped module's output: fn(x) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        branch = self.fn(x)
        return branch + x
# 定义类 PreNorm,实现预层归一化
class PreNorm(nn.Module):
    """Applies LayerNorm to the input before invoking the wrapped module."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
# 定义类 GEGLU,实现门控线性单元
class GEGLU(nn.Module):
    """Gated GELU: split the last dim in half, gate one half by gelu of the other."""

    def forward(self, x):
        values, gates = x.chunk(2, dim = -1)
        return values * F.gelu(gates)
# 定义类 FeedForward,实现前馈神经网络
class FeedForward(nn.Module):
    """GEGLU feed-forward block; the 2/3 factor offsets the gate's width split."""

    def __init__(self, dim, mult = 4):
        super().__init__()
        inner_dim = int(dim * mult * 2 / 3)
        project_in = nn.Linear(dim, inner_dim * 2)
        project_out = nn.Linear(inner_dim, dim)
        self.net = nn.Sequential(project_in, GEGLU(), project_out)

    def forward(self, x):
        return self.net(x)
# 定义类 Attention,实现注意力机制
class Attention(nn.Module):
    """Single-head causal self-attention."""

    def __init__(self, dim_in, dim_out, dim_inner):
        super().__init__()
        self.scale = dim_inner ** -0.5
        self.to_qkv = nn.Linear(dim_in, dim_inner * 3, bias = False)
        self.to_out = nn.Linear(dim_inner, dim_out)

    def forward(self, x):
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

        # causal mask: forbid attending to strictly future positions
        future = torch.ones(sim.shape[-2:], device = x.device).triu(1).bool()
        sim.masked_fill_(future[None, ...], -torch.finfo(q.dtype).max)

        attn = sim.softmax(dim = -1)
        gathered = torch.einsum('b i j, b j d -> b i d', attn, v)
        return self.to_out(gathered)
# 定义类 LocalAttention,实现局部注意力机制
class LocalAttention(nn.Module):
    """Single-head causal local attention over non-overlapping windows.

    Each window of length `window` attends to itself and the previous
    window (keys/values are the concatenation of the two), keeping the
    attention cost linear in sequence length.
    """

    def __init__(self, dim_in, dim_inner, dim_out, window = 128):
        super().__init__()
        self.scale = dim_inner ** -0.5
        self.window = window
        self.to_qkv = nn.Linear(dim_in, dim_inner * 3, bias = False)
        self.to_out = nn.Linear(dim_inner, dim_out)

    def forward(self, x):
        b, n, *_, device, w = *x.shape, x.device, self.window

        # pad the sequence so it splits evenly into windows
        x = pad_to_multiple(x, w, dim = -2, value = 0.)

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # reshape into (batch, windows, window_len, dim)
        window_fn = lambda t: rearrange(t, 'b (w n) d -> b w n d', n = w)
        q, k, v = map(window_fn, (q, k, v))

        # prepend an empty window, then pair each window with its predecessor.
        # fix: the original concat lambda ignored its parameter `t` and built
        # both tensors from `k`, silently replacing the values with the keys
        k, v = map(lambda t: F.pad(t, (0, 0, 0, 0, 1, 0)), (k, v))
        k, v = map(lambda t: torch.cat((t[:, :-1], t[:, 1:]), dim = 2), (k, v))

        sim = einsum('b w i d, b w j d -> b w i j', q, k) * self.scale
        buckets, i, j = sim.shape[-3:]

        # causal mask over the (previous window + current window) key span
        mask_value = -torch.finfo(sim.dtype).max
        mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
        mask = repeat(mask, 'i j -> () u i j', u = buckets)

        sim.masked_fill_(mask, mask_value)

        attn = sim.softmax(dim = -1)

        out = einsum('b w i j, b w j d -> b w i d', attn, v)
        out = rearrange(out, 'b w n d -> b (w n) d')

        # drop the padded tail and project out
        out = self.to_out(out[:, :n])
        return out
# Causal Spatial Gating Unit (gMLP): half the channels act as a residual, the
# other half are mixed across the sequence dimension by a causally masked
# learned weight and then used as a multiplicative gate.
class CausalSGU(nn.Module):
    def __init__(
        self,
        dim,                  # total channel dim; split in half into (residual, gate)
        dim_seq,              # maximum sequence length covered by the spatial weight
        init_eps = 1e-3,      # near-zero uniform init scale for the spatial weight
        heads = 4,            # number of independent gating heads
        act = nn.Identity()   # activation applied to the gate before multiplying
    ):
        super().__init__()
        dim_out = dim // 2
        self.norm = nn.LayerNorm(dim_out)

        self.heads = heads
        # per-head (seq x seq) token-mixing weight and per-position bias
        self.weight = nn.Parameter(torch.zeros(heads, dim_seq, dim_seq))
        self.bias = nn.Parameter(torch.zeros(heads, dim_seq))

        # weights start near zero and bias at 1 so the unit is close to identity
        init_eps /= dim_seq
        nn.init.uniform_(self.weight, -init_eps, init_eps)
        nn.init.constant_(self.bias, 1.)

        self.act = act
        # lower-triangular (causal) mask: output position m may only mix from n <= m
        self.register_buffer('mask', ~torch.ones(dim_seq, dim_seq).triu_(1).bool())

    def forward(self, x, gate_res = None):
        device, n, h = x.device, x.shape[1], self.heads

        res, gate = x.chunk(2, dim = -1)
        gate = self.norm(gate)

        weight, bias = self.weight, self.bias
        # crop weight/bias to the actual sequence length
        weight, bias = weight[:, :n, :n], bias[:, :n]
        # zero out future source positions to keep the mixing causal
        weight = weight * self.mask[None, :n, :n].int().float()

        gate = rearrange(gate, 'b n (h d) -> b h n d', h = h)
        # token mixing: output position m sums over source positions n
        gate = einsum('b h n d, h m n -> b h m d', gate, weight)
        gate = gate + rearrange(bias, 'h n -> () h n ()')
        gate = rearrange(gate, 'b h n d -> b n (h d)')

        # optional extra gate signal (e.g. from the tiny attention branch)
        if exists(gate_res):
            gate = gate + gate_res

        return self.act(gate) * res
# Local (windowed) variant of CausalSGU: each window's gate mixes over the
# current plus the previous window (2 * window source positions), keeping the
# learned spatial weight O(window^2) instead of O(seq_len^2).
class CausalLocalSGU(nn.Module):
    def __init__(
        self,
        dim,                  # total channel dim; split in half into (residual, gate)
        dim_seq,              # accepted for signature parity with CausalSGU; unused here
        init_eps = 1e-3,      # near-zero uniform init scale for the spatial weight
        heads = 4,            # number of independent gating heads
        window = 128,         # local window size
        act = nn.Identity()   # activation applied to the gate before multiplying
    ):
        super().__init__()
        dim_out = dim // 2
        self.norm = nn.LayerNorm(dim_out)

        self.heads = heads
        self.window = window
        # per-head mixing weight: (window output positions) x (2*window source positions)
        self.weight = nn.Parameter(torch.zeros(heads, window, window * 2))
        self.bias = nn.Parameter(torch.zeros(heads, window))

        # weights start near zero and bias at 1 so the unit is close to identity
        init_eps /= window
        nn.init.uniform_(self.weight, -init_eps, init_eps)
        nn.init.constant_(self.bias, 1.)

        self.act = act
        # causal mask over the concatenated (previous + current) window sources
        self.register_buffer('mask', ~torch.ones(window, window * 2).triu_(window + 1).bool())

    def forward(self, x, gate_res = None):
        device, n, h, w = x.device, x.shape[1], self.heads, self.window

        res, gate = x.chunk(2, dim = -1)

        # bucket the gate stream into windows of size w
        gate = pad_to_multiple(gate, w, dim = -2)
        gate = rearrange(gate, 'b (w n) d -> b w n d', n = w)

        gate = self.norm(gate)

        # prepend a zero window, then pair each window with its predecessor so
        # every position can also mix from the previous window
        gate = F.pad(gate, (0, 0, 0, 0, 1, 0), value = 0.)
        gate = torch.cat((gate[:, :-1], gate[:, 1:]), dim = 2)

        weight, bias = self.weight, self.bias
        # zero out future source positions to keep the mixing causal
        weight = weight * self.mask[None, ...].int().float()

        gate = rearrange(gate, 'b w n (h d) -> b w h n d', h = h)
        # token mixing within the (previous + current) window span
        gate = einsum('b w h n d, h m n -> b w h m d', gate, weight)
        gate = gate + rearrange(bias, 'h n -> () () h n ()')

        gate = rearrange(gate, 'b w h n d -> b w n (h d)')
        gate = rearrange(gate, 'b w n d -> b (w n) d')
        # trim the window padding back to the original sequence length
        gate = gate[:, :n]

        # optional extra gate signal (e.g. from the tiny attention branch)
        if exists(gate_res):
            gate = gate + gate_res

        return self.act(gate) * res
# Axial folding: every `every`-th token forms an interleaved subsequence that
# is processed independently by `fn`; the results are re-interleaved and then
# blended across each group of `every` positions by a causal depthwise conv.
class AxiallyFold(nn.Module):
    def __init__(self, dim, every, fn):
        super().__init__()
        self.fn = fn
        self.every = every
        # depthwise conv merges each window of `every` folded outputs; unused when every == 1
        self.conv = nn.Conv1d(dim, dim, kernel_size = every, groups = dim) if every > 1 else None

    def forward(self, x):
        every = self.every
        # no folding requested: plain passthrough to fn
        if every <= 1:
            return self.fn(x)

        n = x.shape[1]
        x = pad_to_multiple(x, self.every, dim = -2)
        # stride tokens into `every` interleaved subsequences, each handled
        # independently by fn (batched along the batch dim)
        x = rearrange(x, 'b (n e) d -> (b e) n d', e = every)
        x = self.fn(x)

        # re-interleave, then left-pad so the depthwise conv only sees past positions
        x = rearrange(x, '(b e) n d -> b d (n e)', e = every)
        x = F.pad(x, (every - 1, 0), value = 0)
        out = self.conv(x)

        out = rearrange(out, 'b d n -> b n d')
        # trim the padding back to the original sequence length
        return out[:, :n]
# One gMLP block: project in -> spatial gating unit (global or windowed) ->
# project out, with an optional "tiny attention" branch added into the gate.
class gMLPBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,                 # model dimension
        seq_len,             # maximum sequence length
        dim_ff,              # hidden (feedforward) dimension
        heads = 4,           # gating heads for the SGU
        # NOTE(review): `causal` is passed positionally to the SGU below, where
        # the third positional parameter is init_eps — confirm this is intended
        causal = False,
        window = None,       # local window size; None (or >= seq_len) selects the global SGU
        attn_dim = None,     # if set, adds a tiny single-head attention branch of this inner dim
        act = nn.Identity()  # activation for the spatial gate
    ):
        super().__init__()
        is_windowed = exists(window) and window < seq_len

        # choose windowed vs global variants based on the window size
        SGU_klass = partial(CausalLocalSGU, window = window) if is_windowed else CausalSGU
        Attention_klass = partial(LocalAttention, window = window) if is_windowed else Attention

        # optional tiny attention whose output is added to the gate (gate_res)
        self.attn = Attention_klass(dim_in = dim, dim_inner = attn_dim, dim_out = dim_ff // 2) if exists(attn_dim) else None

        self.proj_in = nn.Sequential(
            nn.Linear(dim, dim_ff),
            nn.GELU()
        )

        self.sgu = SGU_klass(dim_ff, seq_len, causal, heads = heads, act = act)

        # the SGU halves the channel dim, hence dim_ff // 2 here
        self.proj_out = nn.Linear(dim_ff // 2, dim)

    def forward(self, x):
        # tiny attention branch (if configured) feeds into the spatial gate
        gate_res = self.attn(x) if exists(self.attn) else None

        x = self.proj_in(x)
        x = self.sgu(x, gate_res = gate_res)
        x = self.proj_out(x)
        return x
# main autoregressive gMLP language model
class gMLPGPT(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,          # vocabulary size
        dim,                 # model dimension
        depth,               # number of layers
        seq_len,             # maximum sequence length
        heads = 1,           # gating heads per SGU
        ff_mult = 4,         # feedforward expansion factor
        prob_survival = 1.,  # per-layer survival probability (stochastic depth)
        reversible = False,  # use reversible residual layers to save memory
        window = None,       # per-layer window config; each entry is int or (window, axial)
        attn_dim = None,     # per-layer tiny-attention dims (or a single value)
        act = nn.Identity()  # activation for the spatial gate
    ):
        super().__init__()
        dim_ff = dim * ff_mult
        self.seq_len = seq_len
        self.prob_survival = prob_survival

        self.to_embed = nn.Embedding(num_tokens, dim)

        # normalize the window config into a (window, axial) tuple per layer
        window = cast_tuple(window, depth)
        window = tuple(map(lambda t: t if isinstance(t, tuple) else (t, 1), window))

        attn_dims = cast_tuple(attn_dim, depth)

        assert len(window) == depth, f'num window sizes {len(window)} must be equal to depth {depth}'

        layers = nn.ModuleList([])

        for ind, (w, ax), attn_dim in zip(range(depth), window, attn_dims):
            # NOTE(review): exists(window) is always True here (window was just
            # rebuilt above); possibly meant exists(w) — confirm against upstream
            attn_dim = attn_dim if exists(window) else None

            # FIX: the closing parenthesis of PreNorm(...) was missing (SyntaxError)
            get_gmlp = lambda: PreNorm(dim, AxiallyFold(dim, ax, gMLPBlock(dim = dim, dim_ff = dim_ff, seq_len = seq_len, heads = heads, window = w, act = act, attn_dim = attn_dim)))

            layer_blocks = nn.ModuleList([
                get_gmlp()
            ])

            if reversible:
                # reversible layers need a second function (g); use a feedforward
                layer_blocks.append(FeedForward(dim, mult = ff_mult))

            layers.append(layer_blocks)

        # pick the executor: plain residual stack or reversible stack
        execute_klass = SequentialSequence if not reversible else ReversibleSequence
        self.net = execute_klass(layers)

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    def forward(self, x):
        # stochastic-depth drop rate; applied only during training by the executor
        layer_dropout = 1. - self.prob_survival

        x = self.to_embed(x)
        out = self.net(x, layer_dropout = layer_dropout)
        # return per-position logits over the vocabulary
        return self.to_logits(out)
.\lucidrains\g-mlp-gpt\g_mlp_gpt\reversible.py
# 导入 torch 库
import torch
# 导入 torch 中的神经网络模块
import torch.nn as nn
# 从 operator 模块中导入 itemgetter 函数
from operator import itemgetter
# 从 torch.autograd.function 模块中导入 Function 类
from torch.autograd.function import Function
# 从 torch.utils.checkpoint 模块中导入 get_device_states 和 set_device_states 函数
def route_args(router, args, depth):
    """Split keyword args into per-layer (f_kwargs, g_kwargs) dicts.

    For each key present in both `args` and `router`, the router entry is a
    sequence of (route_to_f, route_to_g) boolean pairs, one per layer; the
    key/value is copied into the corresponding dict wherever the flag is set.
    """
    per_layer = [({}, {}) for _ in range(depth)]

    for key in args.keys():
        if key not in router:
            continue
        value = args[key]
        for idx, ((f_kwargs, g_kwargs), (to_f, to_g)) in enumerate(zip(per_layer, router[key])):
            if to_f:
                f_kwargs = {**f_kwargs, key: value}
            if to_g:
                g_kwargs = {**g_kwargs, key: value}
            per_layer[idx] = (f_kwargs, g_kwargs)

    return per_layer
def layer_drop(layers, prob):
    """Randomly drop each layer with probability `prob`, keeping at least one.

    If every layer would be dropped, the first layer is kept as a fallback.
    """
    keep_mask = torch.empty(len(layers)).uniform_(0, 1) >= prob
    kept = [layer for layer, keep in zip(layers, keep_mask) if keep]
    return kept if kept else layers[:1]
# Wraps a module and can record the CPU/GPU RNG state at forward time, then
# replay the forward later under exactly that state — needed so stochastic ops
# (e.g. dropout) match between the forward pass and the recomputation done in
# the reversible backward pass.
class Deterministic(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        # saved CPU RNG state
        self.cpu_state = None
        # whether CUDA RNG states were captured during forward
        self.cuda_in_fwd = None
        # devices / per-device GPU RNG states captured at forward time
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        # snapshot RNG state; GPU states only if CUDA has been initialized
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        # normal path: just run the wrapped module
        if not set_rng:
            return self.net(*args, **kwargs)

        # replay path: restore the recorded RNG state inside a forked RNG
        # context so the surrounding global state is left untouched
        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)
# Reversible residual block: y1 = x1 + f(x2), y2 = x2 + g(y1). The inputs can
# be exactly reconstructed from the outputs, so activations need not be stored
# for backprop. Inspired by RevTorch
# (https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py);
# refactor and send a PR back upstream once multi-GPU works.
class ReversibleBlock(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        # wrap f and g so their RNG state can be recorded and replayed exactly
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, f_args = {}, g_args = {}):
        # split channels into the two reversible streams
        x1, x2 = torch.chunk(x, 2, dim=2)
        y1, y2 = None, None

        # no grad needed: the backward pass recomputes everything it needs
        with torch.no_grad():
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)

        return torch.cat([y1, y2], dim=2)

    def backward_pass(self, y, dy, f_args = {}, g_args = {}):
        # manual backward: reconstruct the inputs from the outputs, then
        # backprop through f and g with their recorded RNG states replayed
        y1, y2 = torch.chunk(y, 2, dim=2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=2)
        del dy

        # grads through g (y2 = x2 + g(y1)): accumulate into y1.grad
        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            # reconstruct x2 and fold y1's gradient into dx1
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        # grads through f (y1 = x1 + f(x2)): accumulate into x2.grad
        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)

        with torch.no_grad():
            # reconstruct x1 and fold x2's gradient into dx2
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            # reassemble the reconstructed input and its gradient
            x = torch.cat([x1, x2.detach()], dim=2)
            dx = torch.cat([dx1, dx2], dim=2)

        return x, dx
# Custom autograd Function that runs the reversible blocks without storing
# intermediate activations; the backward pass reconstructs them block by block.
class _ReversibleFunction(Function):
    @staticmethod
    def forward(ctx, x, blocks, args):
        # stash per-block kwargs for the backward pass
        ctx.args = args
        # run the blocks sequentially
        for block, kwarg in zip(blocks, args):
            x = block(x, **kwarg)
        # save only the final output (detached) — inputs are reconstructed later
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x

    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        args = ctx.args
        # walk the blocks in reverse, reconstructing inputs and gradients
        for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
            y, dy = block.backward_pass(y, dy, **kwargs)
        # gradient only w.r.t. x; `blocks` and `args` get None
        return dy, None, None
class SequentialSequence(nn.Module):
    """Runs layers one after another with residual connections.

    Each entry of `layers` is a one-element container holding the layer fn;
    kwargs are routed per layer via `args_route`, and whole layers may be
    randomly dropped during training (stochastic depth).
    """

    def __init__(self, layers, args_route = {}, layer_dropout = 0.):
        super().__init__()
        assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
        self.layers = layers
        self.args_route = args_route
        self.layer_dropout = layer_dropout

    def forward(self, x, **kwargs):
        routed = route_args(self.args_route, kwargs, len(self.layers))
        pairs = list(zip(self.layers, routed))

        # stochastic depth: drop whole layers at random while training
        if self.training and self.layer_dropout > 0:
            pairs = layer_drop(pairs, self.layer_dropout)

        for (fn,), (fn_kwargs, _) in pairs:
            x = x + fn(x, **fn_kwargs)
        return x
class ReversibleSequence(nn.Module):
    """Runs (f, g) block pairs reversibly, trading compute for memory.

    The input is duplicated along the feature dim, processed by reversible
    blocks through a custom autograd function, and the two halves are summed
    back together at the end.
    """

    def __init__(self, blocks, args_route = {}, layer_dropout = 0.):
        super().__init__()
        self.args_route = args_route
        self.layer_dropout = layer_dropout
        self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks])

    def forward(self, x, layer_dropout = 0., **kwargs):
        # duplicate features: one half flows through f, the other through g
        x = torch.cat([x, x], dim=-1)

        blocks = self.blocks
        routed = route_args(self.args_route, kwargs, len(blocks))
        routed = [{'f_args': f_args, 'g_args': g_args} for f_args, g_args in routed]

        pairs = list(zip(blocks, routed))

        # stochastic depth: drop whole blocks at random while training
        if self.training and layer_dropout > 0:
            pairs = layer_drop(pairs, layer_dropout)

        blocks = [block for block, _ in pairs]
        kwargs_list = [kw for _, kw in pairs]

        out = _ReversibleFunction.apply(x, blocks, kwargs_list)
        # merge the duplicated halves back by summing them
        return torch.stack(out.chunk(2, dim=-1)).sum(dim=0)
.\lucidrains\g-mlp-gpt\g_mlp_gpt\__init__.py
# 从 g_mlp_gpt.g_mlp_gpt 模块中导入 gMLPGPT 类
from g_mlp_gpt.g_mlp_gpt import gMLPGPT
GPT - gMLP
This repository will attempt to crack long context autoregressive language modeling (GPT) using variations of gMLPs. Specifically, it will contain a variant that does gMLP for local sliding windows. The hope is to be able to stretch a single GPU to be able to train context lengths of 4096 and above efficiently and well.
You can also add the "tiny" attention (as described in the paper) with the attn_dim keyword argument, which corresponds to the dimension of the single head (64 is recommended). You can pass in a tuple to customize different dimension per layer.
Install
$ pip install g-mlp-gpt
Usage
import torch
from g_mlp_gpt import gMLPGPT
model = gMLPGPT(
num_tokens = 20000,
dim = 512,
depth = 4,
seq_len = 1024,
window = (128, 256, 512, 1024) # window sizes for each depth
)
x = torch.randint(0, 20000, (1, 1000))
logits = model(x) # (1, 1000, 20000)
16k context length
import torch
from torch import nn
from g_mlp_gpt import gMLPGPT
model = gMLPGPT(
num_tokens = 20000,
dim = 512,
seq_len = 16384,
reversible = True, # reversible networks
act = nn.Tanh(), # tanh activation for spatial gating
depth = 12,
window = (
128,
128,
256,
512,
1024,
1024,
(2048, 2), # window size of 2048, axial of 2
(2048, 2),
(4096, 4),
(4096, 4),
(8192, 8), # window size of 8192, axial of 8
(8192, 8)
)
).cuda()
x = torch.randint(0, 20000, (1, 16384)).cuda()
logits = model(x) # (1, 16384, 20000)
Citations
@misc{liu2021pay,
title = {Pay Attention to MLPs},
author = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
year = {2021},
eprint = {2105.08050},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
.\lucidrains\g-mlp-gpt\setup.py
# import setup utilities and package discovery
from setuptools import setup, find_packages

# package metadata and build configuration
setup(
    # distribution name on PyPI
    name = 'g-mlp-gpt',
    # include every package found in the source tree
    packages = find_packages(),
    # release version
    version = '0.0.15',
    # license
    license='MIT',
    # short description
    description = 'gMLP - GPT',
    # author
    author = 'Phil Wang',
    # author contact
    author_email = 'lucidrains@gmail.com',
    # project homepage
    url = 'https://github.com/lucidrains/g-mlp-gpt',
    # search keywords (note: 'preceptrons' typo is published package metadata; left as-is)
    keywords = [
        'artificial intelligence',
        'deep learning',
        'multi-layered-preceptrons'
    ],
    # runtime dependencies
    install_requires=[
        'einops>=0.3',
        'torch>=1.6'
    ],
    # PyPI trove classifiers
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\g-mlp-gpt\train.py
# 导入必要的库
from g_mlp_gpt import gMLPGPT
from g_mlp_gpt.autoregressive_wrapper import AutoregressiveWrapper
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# training constants
NUM_BATCHES = int(1e5)            # total number of optimizer steps
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4     # accumulate gradients over this many mini-batches
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100              # run validation every N steps
GENERATE_EVERY = 500              # sample from the model every N steps
GENERATE_LENGTH = 768             # number of tokens to generate per sample
SEQ_LEN = 768                     # training sequence length
# helper functions for turning byte-level token ids back into text

def decode_token(token):
    """Map a byte-level token id to a printable character (floor at space)."""
    return chr(max(32, token))

def decode_tokens(tokens):
    """Decode a sequence of token ids into a string."""
    return ''.join(decode_token(token) for token in tokens)
# instantiate a GPT-like decoder model with per-depth local windows
model = gMLPGPT(
    num_tokens = 256,
    dim = 512,
    seq_len = SEQ_LEN,
    depth = 8,
    window = (16, 32, 64, 128, 256, 512, 768, SEQ_LEN),
    attn_dim = 16
)

model = AutoregressiveWrapper(model)
model.cuda()

# prepare enwik8 data: first 90M bytes for training, next 5M for validation
with gzip.open('./data/enwik8.gz') as file:
    # FIX: np.fromstring is deprecated (and removed in recent NumPy); use
    # np.frombuffer, with a copy so the array is writable for torch.from_numpy
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
class TextSamplerDataset(Dataset):
    """Yields random (seq_len + 1)-long slices of a 1-D token tensor, moved to GPU."""

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # sample a random window; +1 token so input/target pairs can be formed
        start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        seq = self.data[start: start + self.seq_len + 1].long()
        return seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len
# FIX: `cycle` was referenced below but never defined in this file — restore
# the standard endless-dataloader helper
def cycle(loader):
    while True:
        for data in loader:
            yield data

# build endless train / validation dataloaders
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    # gradient accumulation: sum gradients over several mini-batches
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    # periodic validation
    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    # periodic sampling from the model
    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        # FIX: was print(f'%s \n\n %s', (prime, ...)) which printed the literal
        # format string followed by the tuple instead of interpolating the values
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)