Lucidrains 系列项目源码解析(一)
.\lucidrains\Adan-pytorch\adan_pytorch\adan.py
import math
import torch
from torch.optim import Optimizer
# Helper: tells whether a value has been provided (anything other than None).
def exists(val):
    """Return True unless ``val`` is ``None``."""
    if val is None:
        return False
    return True
# Adan optimizer, implemented as a subclass of torch.optim.Optimizer.
class Adan(Optimizer):
    """Adan (ADAptive Nesterov momentum) optimizer.

    Per parameter it keeps three running averages in the optimizer state:
    ``m`` (of the gradient), ``v`` (of the difference between the current and
    previous gradient) and ``n`` (of the squared, difference-corrected
    gradient), plus the previous gradient and a step counter.

    NOTE(review): indentation was lost in the dump this block was recovered
    from. The nesting below (in particular the extent of the ``if step > 0:``
    body) is the most plausible reading - confirm against upstream.
    """

    # Constructor: checks the betas and registers the per-group defaults.
    def __init__(
        self,
        params,
        lr = 1e-3,
        betas = (0.02, 0.08, 0.01),
        eps = 1e-8,
        weight_decay = 0,
        restart_cond: callable = None
    ):
        """
        Args:
            params: parameters (or parameter groups) forwarded to ``Optimizer``.
            lr: learning rate.
            betas: three averaging weights ``(beta1, beta2, beta3)`` used for
                ``m``, ``v`` and ``n`` respectively.
            eps: constant added to the square-root denominator.
            weight_decay: weight-decay coefficient; the parameter is divided by
                ``1 + weight_decay * lr`` after every update.
            restart_cond: optional callable receiving the per-parameter state
                dict; a truthy return triggers a restart of the moments.
        """
        assert len(betas) == 3
        # gather the hyperparameters into the defaults dict
        defaults = dict(
            lr = lr,
            betas = betas,
            eps = eps,
            weight_decay = weight_decay,
            restart_cond = restart_cond
        )
        # let the base Optimizer set up the parameter groups
        super().__init__(params, defaults)

    # Performs a single optimization step.
    def step(self, closure = None):
        """Update every parameter that has a gradient.

        Args:
            closure: optional callable; when given it is called once and its
                return value is returned as the loss.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        # evaluate the closure (if any) to obtain the loss
        if exists(closure):
            loss = closure()
        # iterate over parameter groups
        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2, beta3 = group['betas']
            weight_decay = group['weight_decay']
            eps = group['eps']
            restart_cond = group['restart_cond']
            # iterate over the parameters of this group
            for p in group['params']:
                if not exists(p.grad):
                    continue
                data, grad = p.data, p.grad.data
                assert not grad.is_sparse
                state = self.state[p]
                # lazily create the state on the first visit of this parameter
                if len(state) == 0:
                    state['step'] = 0
                    state['prev_grad'] = torch.zeros_like(grad)
                    state['m'] = torch.zeros_like(grad)
                    state['v'] = torch.zeros_like(grad)
                    state['n'] = torch.zeros_like(grad)
                step, m, v, n, prev_grad = state['step'], state['m'], state['v'], state['n'], state['prev_grad']
                # NOTE(review): the moment updates are assumed to be guarded by
                # `step > 0` (no previous gradient exists on the first call) - confirm.
                if step > 0:
                    prev_grad = state['prev_grad']
                    # main algorithm
                    # m <- (1 - beta1) * m + beta1 * grad
                    m.mul_(1 - beta1).add_(grad, alpha = beta1)
                    grad_diff = grad - prev_grad
                    # v <- (1 - beta2) * v + beta2 * (grad - prev_grad)
                    v.mul_(1 - beta2).add_(grad_diff, alpha = beta2)
                    next_n = (grad + (1 - beta2) * grad_diff) ** 2
                    # n <- (1 - beta3) * n + beta3 * next_n
                    n.mul_(1 - beta3).add_(next_n, alpha = beta3)
                # bias correction terms: 1 / (1 - (1 - beta) ** step) for each beta
                step += 1
                correct_m, correct_v, correct_n = map(lambda n: 1 / (1 - (1 - n) ** step), (beta1, beta2, beta3))
                # gradient step: in-place parameter update followed by the weight-decay division
                def grad_step_(data, m, v, n):
                    weighted_step_size = lr / (n * correct_n).sqrt().add_(eps)
                    denom = 1 + weight_decay * lr
                    data.addcmul_(weighted_step_size, (m * correct_m + (1 - beta2) * v * correct_v), value = -1.).div_(denom)
                grad_step_(data, m, v, n)
                # restart condition: reset the moments from the current gradient and step again
                if exists(restart_cond) and restart_cond(state):
                    m.data.copy_(grad)
                    v.zero_()
                    n.data.copy_(grad ** 2)
                    grad_step_(data, m, v, n)
                # remember the gradient and store the incremented step count
                prev_grad.copy_(grad)
                state['step'] = step
        return loss
.\lucidrains\Adan-pytorch\adan_pytorch\__init__.py
# 从 adan_pytorch.adan 模块中导入 Adan 类
from adan_pytorch.adan import Adan

Adan - Pytorch
Implementation of the Adan (ADAptive Nesterov momentum algorithm) Optimizer in Pytorch.
Explanation from Davis Blalock
Install
$ pip install adan-pytorch
Usage
from adan_pytorch import Adan
# mock model
import torch
from torch import nn
model = torch.nn.Sequential(
nn.Linear(16, 16),
nn.GELU()
)
# instantiate Adan with model parameters
optim = Adan(
model.parameters(),
lr = 1e-3, # learning rate (can be much higher than Adam, up to 5-10x)
betas = (0.02, 0.08, 0.01), # beta 1-2-3 as described in paper - author says most sensitive to beta3 tuning
weight_decay = 0.02 # weight decay 0.02 is optimal per author
)
# train
for _ in range(10):
loss = model(torch.randn(16)).sum()
loss.backward()
optim.step()
optim.zero_grad()
Citations
@article{Xie2022AdanAN,
title = {Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models},
author = {Xingyu Xie and Pan Zhou and Huan Li and Zhouchen Lin and Shuicheng Yan},
journal = {ArXiv},
year = {2022},
volume = {abs/2208.06677}
}
.\lucidrains\Adan-pytorch\setup.py
# 导入设置工具和查找包工具
from setuptools import setup, find_packages
# Package metadata passed to setuptools, grouped by purpose.
setup(
    # identity
    name='adan-pytorch',
    version='0.1.0',
    license='MIT',
    description='Adan - (ADAptive Nesterov momentum algorithm) Optimizer in Pytorch',
    long_description_content_type='text/markdown',
    # authorship and project home
    author='Phil Wang',
    author_email='lucidrains@gmail.com',
    url='https://github.com/lucidrains/Adan-pytorch',
    # contents and runtime requirements
    packages=find_packages(exclude=[]),
    install_requires=[
        'torch>=1.6',
    ],
    # index metadata
    keywords=[
        'artificial intelligence',
        'deep learning',
        'optimizer',
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\adjacent-attention-network\adjacent_attention_network\adjacent_attention_network.py
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from isab_pytorch import ISAB
# helpers
# Helper that reports whether a value is present (not None).
def exists(val):
    """Return False only when ``val`` is ``None``."""
    is_missing = val is None
    return not is_missing
# Helper: for every batch element, pick rows of `values` (dim 1) listed in `indices`.
def batched_index_select(values, indices):
    """Gather rows along dim 1 of ``values``.

    ``indices`` is indexed as a 2-D tensor and broadcast over the last
    dimension of ``values``, so ``values`` must be 3-D for the gather to work.
    The result has the first two dims of ``indices`` and the last dim of
    ``values``.
    """
    feature_dim = values.shape[-1]
    gather_idx = indices.unsqueeze(2).expand(-1, -1, feature_dim)
    return torch.gather(values, 1, gather_idx)
# helper classes
# Residual wrapper: adds the input back onto the wrapped callable's output.
class Residual(nn.Module):
    """Compute ``fn(x, **kwargs) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        branch_out = self.fn(x, **kwargs)
        return branch_out + x
# Pre-normalization wrapper: LayerNorm the input, then call the wrapped callable.
class PreNorm(nn.Module):
    """Apply ``nn.LayerNorm(dim)`` to the input before ``fn``; kwargs go to ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
# Feed-forward block: Linear -> GELU -> Dropout -> Linear.
class FeedForward(nn.Module):
    """Two-layer MLP whose hidden width is ``dim * mult``."""

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden_dim = dim * mult
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x, **kwargs):
        # extra keyword arguments are accepted but not used
        return self.net(x)
# adjacent attention class
# Multi-head attention in which each node attends only to its gathered neighbors.
class AdjacentAttention(nn.Module):
    """Attention over per-node neighbor lists.

    Keys and values are gathered with ``adj_kv_indices``; a learned "null"
    key/value pair is prepended so a node can attend to nothing.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 4,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = dim_head * heads

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        # learned null key / value, one per head
        self.null_k = nn.Parameter(torch.randn(heads, dim_head))
        self.null_v = nn.Parameter(torch.randn(heads, dim_head))

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,
        adj_kv_indices,
        mask
    ):
        """
        Args:
            x: node features, unpacked as a 3-D tensor (batch, nodes, dim).
            adj_kv_indices: neighbor indices, rearranged as 'b n a'.
            mask: neighbor mask, rearranged as 'b n a'; falsy entries are
                treated as padding.

        Returns:
            Tensor with the same leading dims as ``x`` and last dim ``dim``.
        """
        batch, num_nodes, _ = x.shape
        heads = self.heads

        # one flat index list per (batch, head) pair
        flat_indices = repeat(adj_kv_indices, 'b n a -> (b h) (n a)', h = heads)

        def gather_neighbors(t):
            # split heads, fold them into the batch, gather, then unfold again
            t = rearrange(t, 'b n (h d) -> b h n d', h = heads)
            t = rearrange(t, 'b h n d -> (b h) n d')
            t = batched_index_select(t, flat_indices)
            return rearrange(t, '(b h) (n a) d -> b h n a d', h = heads, n = num_nodes)

        def prepend_null(null, t):
            # broadcast the per-head null vector to every node and put it first
            null = rearrange(null, 'h d -> () h () () d').expand(batch, -1, num_nodes, 1, -1)
            return torch.cat((null, t), dim = -2)

        # derive query, key, value
        queries, keys, values = self.to_qkv(x).chunk(3, dim = -1)
        queries = rearrange(queries, 'b n (h d) -> b h n d', h = heads)

        # gather keys and values according to the adjacency, then add the null slot
        keys = prepend_null(self.null_k, gather_neighbors(keys))
        values = prepend_null(self.null_v, gather_neighbors(values))

        # the null slot is always attendable
        padded_mask = F.pad(mask, (1, 0), value = 1)

        # similarity of each node to its neighbors
        sim = einsum('b h n d, b h n a d -> b h n a', queries, keys) * self.scale

        # mask out neighbors that are just padding
        fill_value = -torch.finfo(sim.dtype).max
        keep = rearrange(padded_mask.bool(), 'b n a -> b () n a')
        sim.masked_fill_(~keep, fill_value)

        # attention weights, with dropout
        attn = self.dropout(sim.softmax(dim = -1))

        # weighted average of the neighbor values, heads merged back
        out = einsum('b h n a, b h n a d -> b h n d', attn, values)
        merged = rearrange(out, 'b h n d -> b n (h d)')

        # combine output
        return self.to_out(merged)
# adjacent network (layers of adjacent attention)
# Stack of adjacent-attention layers, each optionally followed by global (ISAB) attention.
class AdjacentAttentionNetwork(nn.Module):
    """Layers of neighbor-restricted attention driven by an adjacency matrix.

    Each layer is: adjacent attention (residual), an optional ISAB global
    attention (residual added in ``forward``), and a feed-forward (residual).
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 4,
        num_neighbors_cutoff = None,
        num_global_nodes = 0,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        """
        Args:
            dim: feature dimension of the nodes.
            depth: number of layers.
            dim_head: per-head dimension of the adjacent attention.
            heads: number of attention heads.
            num_neighbors_cutoff: optional hard limit on neighbors per node.
            num_global_nodes: number of ISAB induced points; 0 disables the
                global attention.
            attn_dropout: dropout on the adjacent-attention weights.
            ff_dropout: dropout inside the feed-forward.
        """
        super().__init__()
        self.num_neighbors_cutoff = num_neighbors_cutoff
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            global_attn = PreNorm(dim, ISAB(
                dim = dim,
                heads = heads,
                num_induced_points = num_global_nodes
            )) if num_global_nodes > 0 else None

            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, AdjacentAttention(
                    dim = dim,
                    dim_head = dim_head,
                    heads = heads,
                    dropout = attn_dropout
                ))),
                global_attn,
                Residual(PreNorm(dim, FeedForward(
                    dim = dim,
                    dropout = ff_dropout
                )))
            ]))

    def forward(self, x, adjacency_mat, mask = None):
        """
        Args:
            x: node features; ``x.shape[1]`` is taken as the number of nodes.
            adjacency_mat: adjacency matrix combined with boolean masks via
                ``|`` / ``&`` (a boolean tensor in the README examples). It is
                NOT modified; earlier versions updated it in place.
            mask: optional node mask, indexed as (batch, nodes).

        Returns:
            Updated node features.
        """
        device, n = x.device, x.shape[1]

        diag = torch.eye(adjacency_mat.shape[-1], device = device).bool()

        # nodes should pay attention to themselves (self-interacting);
        # out-of-place so the caller's adjacency matrix is left untouched
        adjacency_mat = adjacency_mat | diag

        # zero out points on adjacency matrix
        # where the nodes are just padding
        if exists(mask):
            adjacency_mat = adjacency_mat & (mask[:, :, None] * mask[:, None, :])

        adj_mat = adjacency_mat.float()

        # if we don't set a hard limit to the number of neighbors:
        #   - get the maximum number of neighbors and pad the rest of the nodes with less than that number of neighbors
        # else:
        #   - randomly sample the cutoff number of neighbors for any node that exceeds the max
        #   - this would be similar to random sparse attention (bigbird)

        # get the maximum number of neighbors
        max_neighbors = int(adj_mat.sum(dim = -1).max())

        if exists(self.num_neighbors_cutoff) and max_neighbors > self.num_neighbors_cutoff:
            # to randomly sample the neighbors, add a small uniform noise to the mask and topk
            noise = torch.empty((n, n), device = device).uniform_(-0.01, 0.01)
            adj_mat = adj_mat + noise

            adj_mask, adj_kv_indices = adj_mat.topk(dim = -1, k = self.num_neighbors_cutoff)

            # cast the mask back to 0s and 1s
            adj_mask = (adj_mask > 0.5).float()
        else:
            # todo - get distribution of number of neighbors, and strategically break up attention (message passing) to multiple steps
            #      - start with a bimodal num neighbors test case, then generalize

            # use topk to get all the neighbors
            # also pass the mask into the attention, as some neighbors will be just padding and not actually neighbors
            adj_mask, adj_kv_indices = adj_mat.topk(dim = -1, k = max_neighbors)

        for attn, global_attn, ff in self.layers:
            x = attn(
                x,
                adj_kv_indices = adj_kv_indices,
                mask = adj_mask
            )

            if exists(global_attn):
                out, _ = global_attn(x, mask = mask)
                x = x + out

            x = ff(x)

        return x
.\lucidrains\adjacent-attention-network\adjacent_attention_network\__init__.py
# 从相邻注意力网络模块中导入相邻注意力网络类
from adjacent_attention_network.adjacent_attention_network import AdjacentAttentionNetwork
Adjacent Attention Network
An implementation of a simple transformer that is equivalent to a graph neural network where the message passing is done with multi-head attention at each successive layer. Since the name Graph Attention Network is already taken, I decided to name it Adjacent Attention Network instead. The design will be more transformer-centric. Instead of using the square root inverse adjacency matrix trick by Kipf and Welling, in this framework the adjacency matrix is simply translated to the proper attention mask at each layer.
This repository is for my own exploration into the graph neural network field. My gut tells me the transformers architecture can generalize and outperform graph neural networks.
Install
$ pip install adjacent-attention-network
Usage
Basically a transformer where each node pays attention to its neighbors as defined by the adjacency matrix. Complexity is O(n * max_neighbors), where max_neighbors is the maximum number of neighbors of any node in the adjacency matrix.
The following example will have a complexity of ~ 1024 * 100
import torch
from adjacent_attention_network import AdjacentAttentionNetwork
model = AdjacentAttentionNetwork(
dim = 512,
depth = 6,
heads = 4
)
adj_mat = torch.empty(1, 1024, 1024).uniform_(0, 1) < 0.1
nodes = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()
model(nodes, adj_mat, mask = mask) # (1, 1024, 512)
If the number of neighbors contains outliers, then the above will lead to wasteful computation, since a lot of nodes will be doing attention on padding. You can use the following stop-gap measure to account for these outliers.
import torch
from adjacent_attention_network import AdjacentAttentionNetwork
model = AdjacentAttentionNetwork(
dim = 512,
depth = 6,
heads = 4,
num_neighbors_cutoff = 100
).cuda()
adj_mat = torch.empty(1, 1024, 1024).uniform_(0, 1).cuda() < 0.1
nodes = torch.randn(1, 1024, 512).cuda()
mask = torch.ones(1, 1024).bool().cuda()
# for some reason, one of the nodes is fully connected to all others
adj_mat[:, 0] = 1.
model(nodes, adj_mat, mask = mask) # (1, 1024, 512)
For non-local attention, I've decided to use a trick from the Set Transformers paper, the Induced Set Attention Block (ISAB). From the lens of graph neural net literature, this would be analogous to having global nodes for message passing non-locally.
import torch
from adjacent_attention_network import AdjacentAttentionNetwork
model = AdjacentAttentionNetwork(
dim = 512,
depth = 6,
heads = 4,
num_global_nodes = 5
).cuda()
adj_mat = torch.empty(1, 1024, 1024).uniform_(0, 1).cuda() < 0.1
nodes = torch.randn(1, 1024, 512).cuda()
mask = torch.ones(1, 1024).bool().cuda()
model(nodes, adj_mat, mask = mask) # (1, 1024, 512)
.\lucidrains\adjacent-attention-network\setup.py
# 导入设置安装和查找包的函数
from setuptools import setup, find_packages
# Package metadata passed to setuptools, grouped by purpose.
setup(
    # identity
    name='adjacent-attention-pytorch',
    version='0.0.12',
    license='MIT',
    description='Adjacent Attention Network - Pytorch',
    long_description_content_type='text/markdown',
    # authorship and project home
    author='Phil Wang',
    author_email='lucidrains@gmail.com',
    url='https://github.com/lucidrains/adjacent-attention-pytorch',
    # contents and runtime requirements
    packages=find_packages(),
    install_requires=[
        'einops>=0.3',
        'torch>=1.6',
        'isab-pytorch<0.2',
    ],
    # index metadata
    keywords=[
        'artificial intelligence',
        'attention mechanism',
        'graph neural network',
        'transformers',
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\agent-attention-pytorch\agent_attention_pytorch\agent_attention_pytorch.py
# 导入 torch 库
import torch
# 从 torch.nn 模块中导入 Module 类
from torch.nn import Module
# 从 torch 模块中导入 nn、einsum、Tensor
from torch import nn, einsum, Tensor
# 从 einops 库中导入 rearrange、repeat
from einops import rearrange, repeat
# 从 einops.layers.torch 中导入 Rearrange 类
# functions
# Helper that reports whether a value has been supplied (not None).
def exists(v):
    """Return True for every value except ``None``."""
    if v is None:
        return False
    return True
# main class
# Agent self-attention: inputs exchange information through a small set of agent tokens.
class AgentSelfAttention(Module):
    """Self-attention routed through agent tokens.

    The agent tokens first attend over the keys/values of the sequence, then
    the sequence queries attend over what the agents gathered.
    """

    def __init__(
        self,
        dim,
        *,
        num_agent_tokens,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        talking_heads = True,
        gate = True,
        combine_agent_tokens = False
    ):
        """
        Args:
            dim: input / output feature dimension.
            num_agent_tokens: number of learned agent tokens per head.
            dim_head: dimension of each attention head.
            heads: number of heads.
            dropout: dropout applied to both attention maps.
            talking_heads: mix heads with a 1x1 convolution after the softmax.
            gate: multiply the output by a per-head sigmoid gate of the input.
            combine_agent_tokens: accepted for interface compatibility; not
                used by this implementation.
        """
        super().__init__()

        # `Rearrange` is used below but is not imported at the top of this
        # module; import it here so constructing the layer does not raise
        # a NameError.
        from einops.layers.torch import Rearrange

        self.scale = dim_head ** -0.5
        dim_inner = dim_head * heads

        # project the input to queries, keys and values, split per head
        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias = False),
            Rearrange('b n (qkv h d) -> qkv b h n d', h = heads, qkv = 3)
        )

        # per-head output gates computed from the input
        self.to_gates = nn.Sequential(
            nn.Linear(dim, heads),
            Rearrange('b n h -> b h n 1'),
            nn.Sigmoid()
        ) if gate else None

        # learned agent tokens, initialised with a small normal
        self.agent_tokens = nn.Parameter(torch.zeros(heads, num_agent_tokens, dim_head))
        nn.init.normal_(self.agent_tokens, std = 0.02)

        # talking heads for the query->agent and agent->key attention maps
        self.qa_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()
        self.ak_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()

        # dropout for both attention maps
        self.qa_dropout = nn.Dropout(dropout)
        self.ak_dropout = nn.Dropout(dropout)

        # merge heads and project back to the model dimension
        self.to_out = nn.Sequential(
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias = False)
        )

    def forward(
        self,
        x,
        mask = None,
        agent_tokens = None,
        return_agent_tokens = False
    ):
        """
        Args:
            x: input of shape (batch, seq, dim), per the 'b n (qkv h d)' projection.
            mask: optional boolean mask rearranged as 'b j'; False positions
                are excluded as keys and zeroed in the output.
            agent_tokens: optional agent tokens to use instead of the learned
                ones; combined with queries shaped 'b h i d' in the einsum.
            return_agent_tokens: also return what the agents gathered.

        Returns:
            ``out`` or ``(out, agent_gathered_tokens)``.
        """
        batch = x.shape[0]

        q, k, v = self.to_qkv(x)

        if exists(agent_tokens):
            a = agent_tokens
        else:
            a = repeat(self.agent_tokens, 'h m d -> b h m d', b = batch)

        # the agents act as queries/keys; scale them once
        a = a * self.scale

        qa_sim = einsum('b h i d, b h j d -> b h i j', q, a)
        ak_sim = einsum('b h i d, b h j d -> b h i j', a, k)

        if exists(mask):
            # masked sequence positions must not be attended to by the agents
            max_neg_value = -torch.finfo(qa_sim.dtype).max
            ak_sim = ak_sim.masked_fill(~rearrange(mask, 'b j -> b 1 1 j'), max_neg_value)

        qa_attn = qa_sim.softmax(dim = -1)
        ak_attn = ak_sim.softmax(dim = -1)

        qa_attn = self.qa_dropout(qa_attn)
        ak_attn = self.ak_dropout(ak_attn)

        qa_attn = self.qa_talking_heads(qa_attn)
        ak_attn = self.ak_talking_heads(ak_attn)

        # agents gather from the values, then the queries read from the agents
        agent_gathered_tokens = einsum('b h i j, b h j d -> b h i d', ak_attn, v)

        out = einsum('b h i j, b h j d -> b h i d', qa_attn, agent_gathered_tokens)

        if exists(mask):
            # zero the output at masked positions
            out = out.masked_fill(~rearrange(mask, 'b n -> b 1 n 1'), 0.)

        if exists(self.to_gates):
            out = out * self.to_gates(x)

        out = self.to_out(out)

        if not return_agent_tokens:
            return out

        return out, agent_gathered_tokens
.\lucidrains\agent-attention-pytorch\agent_attention_pytorch\agent_transformer.py
# 导入 torch 库
import torch
# 导入 torch 中的函数库
import torch.nn.functional as F
# 从 torch.nn 中导入 Module 和 ModuleList
from torch.nn import Module, ModuleList
# 从 torch 中导入 nn, einsum, Tensor
from torch import nn, einsum, Tensor
# 从 einops 中导入 rearrange, repeat, pack, unpack
from einops import rearrange, repeat, pack, unpack
# 从 einops.layers.torch 中导入 Rearrange
# functions
# Helper that reports whether a value is present (not None).
def exists(v):
    """Return False only for ``None``."""
    absent = v is None
    return not absent
# normalization
# RMS norm: L2-normalise the last dim, rescale by sqrt(dim) and a learned gain.
class RMSNorm(Module):
    """Return ``F.normalize(x, dim = -1) * dim ** 0.5 * gamma``."""

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        return unit * self.scale * self.gamma
# feedforward
# Builds the feed-forward stack: RMSNorm -> Linear -> GELU -> Linear.
def FeedForward(dim, mult = 4):
    """Return an ``nn.Sequential`` MLP with hidden width ``int(dim * mult)``."""
    dim_inner = int(dim * mult)
    stages = (
        RMSNorm(dim),
        nn.Linear(dim, dim_inner),
        nn.GELU(),
        nn.Linear(dim_inner, dim),
    )
    return nn.Sequential(*stages)
# main class
# Agent self-attention used by the transformer: agent tokens are passed in by the caller.
class AgentSelfAttention(Module):
    """Agent attention layer with a shared pre-norm and separate output heads
    for the sequence and for the agent tokens."""

    def __init__(
        self,
        dim,
        *,
        num_agent_tokens,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        talking_heads = True,
        gate = True,
        sub_layernorm = False
    ):
        """
        Args:
            dim: input / output feature dimension.
            num_agent_tokens: accepted for interface compatibility; the agent
                tokens themselves are supplied to ``forward``.
            dim_head: dimension of each attention head.
            heads: number of heads.
            dropout: dropout applied to both attention maps.
            talking_heads: mix heads with a 1x1 convolution after the softmax.
            gate: multiply outputs by a per-head sigmoid gate.
            sub_layernorm: apply a LayerNorm over ``dim_head`` before the
                output projections.
        """
        super().__init__()

        # `Rearrange` is used below but is not imported at the top of this
        # module; import it here so constructing the layer does not raise
        # a NameError.
        from einops.layers.torch import Rearrange

        self.scale = dim_head ** -0.5
        dim_inner = dim_head * heads

        self.norm = RMSNorm(dim)

        # joint projection to queries, keys and values, split per head
        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias = False),
            Rearrange('b n (qkv h d) -> qkv b h n d', h = heads, qkv = 3)
        )

        # per-head output gates
        self.to_gates = nn.Sequential(
            nn.Linear(dim, heads),
            Rearrange('b n h -> b h n 1'),
            nn.Sigmoid()
        ) if gate else None

        self.qa_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()
        self.ak_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()

        self.qa_dropout = nn.Dropout(dropout)
        self.ak_dropout = nn.Dropout(dropout)

        # output projection for the agent tokens
        self.to_agent_out = nn.Sequential(
            nn.LayerNorm(dim_head) if sub_layernorm else nn.Identity(),
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias = False)
        )

        # output projection for the sequence
        self.to_out = nn.Sequential(
            nn.LayerNorm(dim_head) if sub_layernorm else nn.Identity(),
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias = False)
        )

    def forward(
        self,
        x,
        *,
        agent_tokens,
        mask = None,
        return_agent_tokens = False
    ):
        """
        Args:
            x: sequence, packed with the agents along the token axis ('b * d').
            agent_tokens: agent tokens, same batch and feature dims as ``x``.
            mask: optional boolean mask rearranged as 'b j' over the sequence;
                False positions are excluded as keys and zeroed in the output.
            return_agent_tokens: also return the projected agent output.

        Returns:
            ``out`` or ``(out, agent_out)``.
        """
        # both streams share the same pre-norm
        x = self.norm(x)
        a = self.norm(agent_tokens)

        # project agents and sequence in one pass, then split them again
        x_and_agents, xa_ps = pack([a, x], 'b * d')
        qkv = self.to_qkv(x_and_agents)

        qkv_agent, qkv_input = unpack(qkv, xa_ps, 'qkv b h * d')

        q, k, v = qkv_input
        agent_queries, agent_keys, _ = qkv_agent

        q = q * self.scale
        agent_queries = agent_queries * self.scale

        qa_sim = einsum('b h i d, b h j d -> b h i j', q, agent_keys)
        ak_sim = einsum('b h i d, b h j d -> b h i j', agent_queries, k)

        if exists(mask):
            # masked sequence positions must not be attended to by the agents
            max_neg_value = -torch.finfo(qa_sim.dtype).max
            ak_sim = ak_sim.masked_fill(~rearrange(mask, 'b j -> b 1 1 j'), max_neg_value)

        qa_attn = qa_sim.softmax(dim = -1)
        ak_attn = ak_sim.softmax(dim = -1)

        qa_attn = self.qa_dropout(qa_attn)
        ak_attn = self.ak_dropout(ak_attn)

        qa_attn = self.qa_talking_heads(qa_attn)
        ak_attn = self.ak_talking_heads(ak_attn)

        # agents gather from the values, then the queries read from the agents
        agent_out = einsum('b h i j, b h j d -> b h i d', ak_attn, v)

        out = einsum('b h i j, b h j d -> b h i d', qa_attn, agent_out)

        if exists(mask):
            # zero the output at masked positions
            out = out.masked_fill(~rearrange(mask, 'b n -> b 1 n 1'), 0.)

        if exists(self.to_gates):
            # gates are computed from the normalised sequence / agent tokens
            out = out * self.to_gates(x)
            agent_out = agent_out * self.to_gates(a)

        out = self.to_out(out)
        agent_out = self.to_agent_out(agent_out)

        if not return_agent_tokens:
            return out

        return out, agent_out
# transformer
# Stack of agent-attention and feed-forward layers sharing learned agent tokens.
class AgentTransformer(Module):
    """Transformer whose layers communicate through learned agent tokens.

    Sequence and agent tokens both receive residual updates from the attention
    and are passed jointly through each feed-forward.
    """

    def __init__(
        self,
        dim,
        *,
        num_agent_tokens,
        depth,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        final_norm = True,
        **attn_kwargs: dict
    ):
        super().__init__()

        # learned agent tokens, initialised with a small normal
        self.agent_tokens = nn.Parameter(torch.zeros(num_agent_tokens, dim))
        nn.init.normal_(self.agent_tokens, std = 0.02)

        self.layers = ModuleList([])

        for _ in range(depth):
            attention = AgentSelfAttention(
                dim = dim,
                heads = heads,
                dim_head = dim_head,
                num_agent_tokens = num_agent_tokens,
                **attn_kwargs
            )
            feedforward = FeedForward(dim = dim, mult = ff_mult)
            self.layers.append(ModuleList([attention, feedforward]))

        self.final_norm = RMSNorm(dim) if final_norm else None

    def forward(
        self,
        x,
        mask = None,
        return_agent_tokens = False
    ):
        """Run all layers; return ``x`` or ``(x, agent_tokens)`` when
        ``return_agent_tokens`` is true."""
        # one copy of the agent tokens per batch element
        agents = repeat(self.agent_tokens, 'm d -> b m d', b = x.shape[0])

        for attn, ff in self.layers:
            attn_out, agent_out = attn(
                x,
                agent_tokens = agents,
                mask = mask,
                return_agent_tokens = True
            )

            # residual updates for both streams
            agents = agents + agent_out
            x = x + attn_out

            # feed-forward over agents and sequence together, then split again
            joint, ps = pack([agents, x], 'b * d')
            joint = ff(joint) + joint
            agents, x = unpack(joint, ps, 'b * d')

        if exists(self.final_norm):
            x = self.final_norm(x)
            agents = self.final_norm(agents)

        if return_agent_tokens:
            return x, agents

        return x
.\lucidrains\agent-attention-pytorch\agent_attention_pytorch\__init__.py
# 从 agent_attention_pytorch 包中导入 AgentSelfAttention 类
from agent_attention_pytorch.agent_attention_pytorch import (
AgentSelfAttention
)
# 从 agent_attention_pytorch 包中导入 AgentTransformer 类
from agent_attention_pytorch.agent_transformer import (
AgentTransformer
)

Agent Attention - Pytorch
Implementation of Agent Attention in Pytorch.
This work seems to be an elegant simplification of ISAB architecture from the Set Transformers paper (requires only one attention block rather than two). While ISAB works, I have found it to be a bit unstable, thus wondering if the simplification in this work resolves that issue.
This repository will add support for variable sequence lengths (masking) and post-softmax talking heads.
Appreciation
- A16Z Open Source AI Grant Program and 🤗 Huggingface for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research
Install
$ pip install agent-attention-pytorch
Usage
import torch
from agent_attention_pytorch import AgentSelfAttention
attn = AgentSelfAttention(
dim = 512,
num_agent_tokens = 256, # number of "agent" tokens
dim_head = 64, # attention head dimension
heads = 8 # number of heads
)
x = torch.randn(2, 65536, 512)
mask = torch.ones(2, 65536).bool()
out = attn(x, mask = mask)
assert out.shape == x.shape
For a full fledged linear transformer based on agent tokens, just import AgentTransformer
import torch
from agent_attention_pytorch import AgentTransformer
transformer = AgentTransformer(
dim = 512,
depth = 6,
num_agent_tokens = 128,
dim_head = 64,
heads = 8
)
x = torch.randn(2, 65536, 512)
mask = torch.ones(2, 65536).bool()
out, agent_tokens = transformer(x, mask = mask, return_agent_tokens = True)
# (2, 65536, 512), (2, 128, 512)
assert out.shape == x.shape
Citations
@inproceedings{Han2023AgentAO,
title = {Agent Attention: On the Integration of Softmax and Linear Attention},
author = {Dongchen Han and Tianzhu Ye and Yizeng Han and Zhuofan Xia and Shiji Song and Gao Huang},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:266210414}
}
@misc{shazeer2020talkingheads,
title = {Talking-Heads Attention},
author = {Noam Shazeer and Zhenzhong Lan and Youlong Cheng and Nan Ding and Le Hou},
year = {2020},
eprint = {2003.02436},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@article{Bondarenko2023QuantizableTR,
title = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
author = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
journal = {ArXiv},
year = {2023},
volume = {abs/2306.12929},
url = {https://api.semanticscholar.org/CorpusID:259224568}
}
@article{Wang2022FoundationT,
title = {Foundation Transformers},
author = {Hongyu Wang and Shuming Ma and Shaohan Huang and Li Dong and Wenhui Wang and Zhiliang Peng and Yu Wu and Payal Bajaj and Saksham Singhal and Alon Benhaim and Barun Patra and Zhun Liu and Vishrav Chaudhary and Xia Song and Furu Wei},
journal = {ArXiv},
year = {2022},
volume = {abs/2210.06423},
url = {https://api.semanticscholar.org/CorpusID:252846241}
}
.\lucidrains\agent-attention-pytorch\setup.py
# 导入设置和查找包的函数
from setuptools import setup, find_packages
# Package metadata passed to setuptools, grouped by purpose.
setup(
    # identity
    name='agent-attention-pytorch',
    version='0.1.7',
    license='MIT',
    description='Agent Attention - Pytorch',
    long_description_content_type='text/markdown',
    # authorship and project home
    author='Phil Wang',
    author_email='lucidrains@gmail.com',
    url='https://github.com/lucidrains/agent-attention-pytorch',
    # contents and runtime requirements
    packages=find_packages(exclude=[]),
    install_requires=[
        'einops>=0.7.0',
        'torch>=2.0',
    ],
    # index metadata
    keywords=[
        'artificial intelligence',
        'deep learning',
        'attention',
        'linear attention',
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\all-normalization-transformer\all_normalization_transformer\all_normalization_transformer.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 torch.nn 模块中导入 functional 模块
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 函数
from einops import rearrange
# Running mean along the last dimension.
def cum_mean(t):
    """Return the cumulative mean of ``t`` over its last dim.

    Element ``i`` is the mean of ``t[..., :i + 1]``.
    """
    # 1, 2, ..., L on the same device as the input
    counts = torch.arange(1, t.shape[-1] + 1, device=t.device)
    return t.cumsum(dim=-1) / counts
# Normalize the last dimension to zero mean and unit variance.
def normalize(t, eps=1e-8):
    """Standardise ``t`` over its last dim.

    The input is left untouched: the previous in-place ``t -= mean`` modified
    the caller's tensor and raised on leaf tensors that require grad.

    Args:
        t: input tensor.
        eps: constant added to the variance before the reciprocal square root.

    Returns:
        A new tensor ``(t - mean) * rsqrt(var + eps)``.
    """
    # subtract the mean (out of place)
    centered = t - t.mean(dim=-1, keepdim=True)
    # mean of squares of the centered values, i.e. the variance
    variance = (centered ** 2).mean(dim=-1, keepdim=True)
    return centered * torch.rsqrt(variance + eps)
# Causal variant of normalize: row i only uses statistics of its first i + 1 entries.
def causal_normalize(t, eps=1e-8):
    """Standardise each row using its causal (prefix) mean and variance.

    For row ``i`` of the last two dims, the statistics are the diagonal of the
    cumulative mean, i.e. the mean over columns ``0..i``. The input is left
    untouched: the previous in-place ``t -= ...`` modified the caller's tensor.

    Args:
        t: input tensor; the last two dims are used as rows / columns.
        eps: constant added to the variance before the reciprocal square root.

    Returns:
        A new tensor with the same shape as ``t``.
    """
    # subtract the causal mean (out of place)
    centered = t - cum_mean(t).diagonal(dim1=-2, dim2=-1)[..., None]
    # causal mean of squares of the centered values
    variance = cum_mean(centered ** 2).diagonal(dim1=-2, dim2=-1)[..., None]
    return centered * torch.rsqrt(variance + eps)
# Residual wrapper: adds the input back onto the wrapped callable's output.
class Residual(nn.Module):
    """Compute ``fn(x, *args, **kwargs) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        branch_out = self.fn(x, *args, **kwargs)
        return branch_out + x
# Post-normalization wrapper: call the wrapped callable, then LayerNorm its output.
class PostNorm(nn.Module):
    """Return ``LayerNorm(fn(x))`` with ``nn.LayerNorm(dim)``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.norm(self.fn(x))
# Pre-normalization wrapper: LayerNorm the input, then call the wrapped callable.
class PreNorm(nn.Module):
    """Return ``fn(LayerNorm(x))`` with ``nn.LayerNorm(dim)``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))
# Feed-forward block: Linear -> GELU -> Linear.
class FeedForward(nn.Module):
    """Two-layer MLP with a GELU in between.

    Args:
        dim: input and output feature dimension.
        mult: hidden-width multiplier; the hidden layer has ``dim * mult``
            units. (Previously this argument was ignored and the width was
            always ``dim * 4``; the default keeps that behavior.)
    """

    def __init__(self, dim, mult = 4):
        super().__init__()
        hidden_dim = dim * mult
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim)
        )

    def forward(self, x):
        return self.net(x)
# Attention whose scores are standardised (zero mean, unit variance) instead of soft-maxed.
class Attention(nn.Module):
    """Multi-head attention using normalization in place of softmax.

    The score matrix is standardised over its last dim (causally when
    ``causal`` is set) and then scaled / shifted by the per-head parameters
    ``norm_g`` and ``norm_b``. With ``shared_kv`` the keys double as values.
    """

    def __init__(self, dim, heads = 8, causal = False, shared_kv = False):
        super().__init__()
        self.causal = causal
        self.heads = heads
        self.scale = dim ** -0.5
        self.shared_kv = shared_kv

        # two projections (q, k) when keys are reused as values, otherwise three
        self.num_qkv = 2 if shared_kv else 3

        self.to_qkv = nn.Linear(dim, dim * self.num_qkv, bias = False)
        self.to_out = nn.Linear(dim, dim)

        # per-head gain and bias applied to the normalised scores
        self.norm_g = nn.Parameter(torch.ones(1, heads, 1, 1))
        self.norm_b = nn.Parameter(torch.zeros(1, heads, 1, 1))

    def forward(self, x):
        _, seq_len, _ = x.shape

        projected = rearrange(
            self.to_qkv(x),
            'b n (qkv h d) -> qkv b h n d',
            qkv = self.num_qkv,
            h = self.heads
        )

        if self.shared_kv:
            q, k = projected
            v = k
        else:
            q, k, v = projected

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        # strictly-upper-triangular mask marks the future positions
        future = None
        if self.causal:
            future = torch.ones(seq_len, seq_len, device = x.device).triu_(1).bool()
            dots.masked_fill_(future, 0.)

        normalize_fn = causal_normalize if self.causal else normalize
        attn = normalize_fn(dots) * self.norm_g + self.norm_b

        if self.causal:
            # the bias may have leaked into future positions; zero them again
            attn.masked_fill_(future, 0.)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# Transformer body: `depth` layers of (attention, feed-forward).
class Transformer(nn.Module):
    """Stack of residual post-norm attention and residual pre-norm feed-forward.

    With ``only_norm`` the feed-forward of every layer is replaced by an
    identity.
    """

    def __init__(self, dim, depth, heads = 8, causal = False, only_norm = False, shared_kv = False):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            attention = Residual(PostNorm(dim, Attention(dim, heads, causal = causal, shared_kv = shared_kv)))

            if only_norm:
                feedforward = nn.Identity()
            else:
                feedforward = Residual(PreNorm(dim, FeedForward(dim)))

            self.layers.append(nn.ModuleList([attention, feedforward]))

    def forward(self, x):
        for attention, feedforward in self.layers:
            x = feedforward(attention(x))
        return x
# Language model: token + position embeddings -> Transformer -> vocabulary logits.
class TransformerLM(nn.Module):
    """Token-level language model built on ``Transformer``.

    ``forward`` expects a 2-D tensor of token ids and returns logits with a
    trailing ``num_tokens`` dimension. Extra keyword arguments are accepted
    but not used.
    """

    def __init__(self, *, num_tokens, dim, depth, max_seq_len, heads = 8, causal = False, only_norm = False, shared_kv = False):
        super().__init__()
        self.max_seq_len = max_seq_len

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.transformer = Transformer(
            dim,
            depth,
            heads,
            causal = causal,
            only_norm = only_norm,
            shared_kv = shared_kv
        )
        self.to_logits = nn.Linear(dim, num_tokens)

    def forward(self, x, **kwargs):
        _, seq_len = x.shape

        hidden = self.token_emb(x)
        # learned absolute positions 0 .. seq_len - 1, broadcast over the batch
        positions = torch.arange(seq_len, device = hidden.device)
        hidden += self.pos_emb(positions)

        hidden = self.transformer(hidden)
        return self.to_logits(hidden)
.\lucidrains\all-normalization-transformer\all_normalization_transformer\autoregressive_wrapper.py
# 导入必要的库
from functools import partial
import torch
import random
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
# Helper: fall back to a default when no value was given.
def default(value, default):
    """Return ``value`` unless it is ``None``, in which case return ``default``."""
    if value is None:
        return default
    return value
# Helper: logarithm with a small offset so that zeros do not produce -inf.
def log(t, eps=1e-9):
    """Return ``log(t + eps)`` element-wise."""
    shifted = t + eps
    return shifted.log()
# Nucleus filtering: keep the highest-probability logits until the cumulative
# probability exceeds `1 - thres`; all others become -inf.
def top_p(logits, thres = 0.9):
    """Return ``logits`` with everything outside the nucleus set to ``-inf``.

    The highest-probability token is always kept. ``logits`` is indexed as a
    2-D tensor (rows are independent distributions).
    """
    ranked_logits, ranked_idx = logits.sort(descending=True)
    cumulative = F.softmax(ranked_logits, dim=-1).cumsum(dim=-1)

    drop = cumulative > 1.0 - thres
    # shift right by one so the token that crosses the cutoff is itself kept,
    # and the first (most likely) token is never dropped
    never_drop_first = torch.zeros_like(drop[:, :1])
    drop = torch.cat((never_drop_first, drop[:, :-1]), dim=1)

    ranked_logits = ranked_logits.masked_fill(drop, float('-inf'))
    # undo the sort: put every logit back at its original position
    return ranked_logits.scatter(1, ranked_idx, ranked_logits)
# Top-k filtering: keep the k largest logits, set the rest to -inf.
def top_k(logits, thres = 0.9):
    """Return ``logits`` with everything but the top ``k`` entries set to ``-inf``.

    ``k`` is ``int((1 - thres) * vocab_size)``, clamped to at least 1. Without
    the clamp a small vocabulary or a high threshold gave ``k == 0`` and
    every logit was masked out, making sampling degenerate.

    Args:
        logits: 2-D tensor; the last dim is the vocabulary.
        thres: fraction of the vocabulary to filter out.
    """
    k = max(int((1 - thres) * logits.shape[-1]), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
# Wrapper that adds autoregressive training loss and sampling to a language model.
class AutoregressiveWrapper(nn.Module):
    """Wrap a network exposing ``max_seq_len`` for next-token training and sampling.

    Args:
        net: the wrapped network; called with token ids, returns logits.
        ignore_index: target value ignored by the loss; defaults to ``pad_value``.
        pad_value: padding token id.
    """

    def __init__(self, net, ignore_index = None, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = default(ignore_index, pad_value)

        self.net = net
        self.max_seq_len = net.max_seq_len

    # Sample a continuation of `start_tokens`.
    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        """Generate up to ``seq_len`` tokens after ``start_tokens``.

        Args:
            start_tokens: 1-D or 2-D tensor of prompt token ids.
            seq_len: maximum number of tokens to sample.
            eos_token: stop early once every sequence sampled this token.
            temperature: divisor applied to the filtered logits.
            filter_logits_fn: logits filter called as ``fn(logits, thres = ...)``.
            filter_thres: threshold forwarded to ``filter_logits_fn``.
            **kwargs: forwarded to the network; ``src_mask`` is taken as the
                prompt mask (all True when absent).

        Returns:
            The sampled tokens only (prompt removed), with the same number of
            dims as ``start_tokens``.
        """
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        _, t = start_tokens.shape

        self.net.eval()
        out = start_tokens
        input_mask = kwargs.pop('src_mask', None)

        if input_mask is None:
            input_mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)

        for _ in range(seq_len):
            # only the most recent max_seq_len tokens fit in the network
            x = out[:, -self.max_seq_len:]
            input_mask = input_mask[:, -self.max_seq_len:]

            logits = self.net(x, src_mask=input_mask, **kwargs)
            logits = logits[:, -1, :]
            filtered_logits = filter_logits_fn(logits, thres = filter_thres)

            # sample via the Gumbel-max trick
            gumbel_noise = -log(-log(torch.zeros_like(filtered_logits).uniform_(0, 1)))
            sample = ((filtered_logits / temperature) + gumbel_noise).argmax(dim=-1)

            out = torch.cat((out, sample[:, None]), dim=-1)
            # the new token is appended on the right, so its mask entry must be
            # appended on the right too (left padding would shift the mask out
            # of alignment with the tokens)
            input_mask = F.pad(input_mask, (0, 1), value=True)

            if eos_token is not None and (sample == eos_token).all():
                break

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        self.net.train(was_training)
        return out

    # Training forward pass: next-token cross-entropy.
    def forward(self, x, *args, **kwargs):
        """Return the cross-entropy of predicting ``x[:, 1:]`` from ``x[:, :-1]``.

        An ``input_mask`` keyword, if given, must have the shape of ``x`` and
        is forwarded to the network with its last column removed.
        """
        m = kwargs.pop('input_mask', None)
        xi, xo = x[:, :-1], x[:, 1:]

        if m is not None:
            assert m.shape == x.shape[0:2], 'input mask must be the same shape as the input of the auto-regressive wrapper to automatically handle'
            kwargs.update(input_mask = m[:, :-1])

        out = self.net(xi, *args, **kwargs)

        # cross_entropy expects the class dim second: (batch, classes, seq)
        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        return loss
.\lucidrains\all-normalization-transformer\all_normalization_transformer\__init__.py
# 从 all_normalization_transformer 包中导入 TransformerLM 类
from all_normalization_transformer.all_normalization_transformer import TransformerLM
# 从 all_normalization_transformer 包中导入 AutoregressiveWrapper 类
from all_normalization_transformer.autoregressive_wrapper import AutoregressiveWrapper
Data source
The enwik8 data was downloaded from the Hutter Prize page: http://prize.hutter1.net/
Transformer with Normalized Attention
A Transformer that uses normalization as its sole non-linearity, as proposed in the paper "Normalized Attention Without Probability Cage". This repository builds on the paper's contributions and attempts to make the approach work for the auto-regressive case.
Update - It works. You can have an entire language model built on only matrix multiplies and normalization.
Pre-requisites
$ pip install -r requirements.txt
Train
$ python train_enwik8.py
Citations
@misc{richter2020normalized,
title={Normalized Attention Without Probability Cage},
author={Oliver Richter and Roger Wattenhofer},
year={2020},
eprint={2005.09561},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
.\lucidrains\all-normalization-transformer\train_enwik8.py
# 导入所需的模块
from all_normalization_transformer import TransformerLM
from all_normalization_transformer.autoregressive_wrapper import AutoregressiveWrapper
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# training hyper-parameters
NUM_BATCHES = 100_000
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 3e-4

# how often (in batches) to validate and to print a generated sample
VALIDATE_EVERY = 100
GENERATE_EVERY = 500

# sequence lengths for training windows and for generation
SEQ_LEN = 512
GENERATE_LENGTH = 512
# 定义辅助函数
# 从 token 解码为字符
def decode_token(token):
    """Turn one byte-level token id into a single character.

    Ids below 32 are clamped to 32, so they are rendered as a space.
    """
    code_point = max(32, token)
    return chr(code_point)
# 从 tokens 解码为字符串
def decode_tokens(tokens):
    """Decode an iterable of token ids into one string, one character per id."""
    return ''.join(decode_token(token) for token in tokens)
# instantiate the model: a TransformerLM wrapped for autoregressive training / sampling
model = AutoregressiveWrapper(
    TransformerLM(
        num_tokens = 256,
        dim = 512,
        depth = 12,
        max_seq_len = SEQ_LEN,
        heads = 8,
        causal = True,
        only_norm = True,
        shared_kv = True
    )
)

# move the model to the GPU
model.cuda()
# prepare enwik8 data: read the first 95M bytes, use 90M for training and the rest for validation
with gzip.open('./data/enwik8.gz') as file:
    # np.fromstring in binary mode is deprecated; frombuffer returns a read-only
    # view of the bytes, so copy it to obtain a writable array for torch.from_numpy
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
# 定义自定义数据集类
class TextSamplerDataset(Dataset):
    """Dataset that serves random fixed-length windows cut from a 1-D token tensor."""

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __len__(self):
        # nominal size: how many non-overlapping windows fit in the data
        return self.data.size(0) // self.seq_len

    def __getitem__(self, index):
        # `index` is ignored; every call draws a fresh random window of
        # seq_len + 1 tokens (one extra so inputs and targets can be shifted)
        window = self.seq_len + 1
        start = torch.randint(0, self.data.size(0) - window, (1,))
        sample = self.data[start: start + window].long()
        return sample.cuda()
# `cycle` is used below but is neither defined nor imported in this script.
# itertools.cycle is not a substitute: it would replay cached batches instead
# of re-iterating the loader.
def cycle(loader):
    """Yield batches from ``loader`` forever, restarting it whenever it is exhausted."""
    while True:
        for data in loader:
            yield data

# training and validation datasets
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)

# endless batch iterators over both datasets
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer (note: this name shadows the `torch.optim as optim` module alias imported above)
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# training
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    # accumulate gradients over several micro-batches before each optimizer step
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        # take a random validation window as the prompt
        inp = random.choice(val_dataset)[:-1]
        inp = inp[:SEQ_LEN]
        prime = decode_tokens(inp)
        # the previous call passed a format string and a tuple to print() without
        # ever interpolating them; format the prompt and separator explicitly
        separator = '*' * 100
        print(f'{prime} \n\n {separator}')

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)
.\lucidrains\alphafold2\alphafold2_pytorch\alphafold2.py
import torch
from torch import nn, einsum
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from inspect import isfunction
from functools import partial
from dataclasses import dataclass
import torch.nn.functional as F
from math import sqrt
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
from alphafold2_pytorch.utils import *
import alphafold2_pytorch.constants as constants
from alphafold2_pytorch.mlm import MLM
# structure module
from invariant_point_attention import IPABlock
from pytorch3d.transforms import quaternion_multiply, quaternion_to_matrix
# constants
@dataclass
class Recyclables:
    """Bundle of three required tensors: coordinates, a single MSA row and the pairwise representation."""

    coords: torch.Tensor
    single_msa_repr_row: torch.Tensor
    pairwise_repr: torch.Tensor
@dataclass
class ReturnValues:
    """Container for optional model outputs; every field defaults to ``None``."""

    # distance and angle outputs
    distance: torch.Tensor = None
    theta: torch.Tensor = None
    phi: torch.Tensor = None
    omega: torch.Tensor = None

    # auxiliary MSA masked-language-model loss
    msa_mlm_loss: torch.Tensor = None

    # Recyclables instance, when one is attached
    recyclables: Recyclables = None
# helpers
def exists(val):
    """Return ``True`` for anything except ``None`` (falsy values such as 0 still count)."""
    return not (val is None)
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    When ``d`` is a plain function (per ``inspect.isfunction``) it is called
    and its result is used, so costly fallbacks are only built when needed.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
def cast_tuple(val, depth = 1):
    """Return ``val`` unchanged if it is a tuple, else a tuple repeating it ``depth`` times."""
    if isinstance(val, tuple):
        return val
    return (val,) * depth
def init_zero_(layer):
    """Zero ``layer.weight`` in place, and ``layer.bias`` too when the layer has one."""
    nn.init.zeros_(layer.weight)

    bias = layer.bias
    if bias is not None:
        nn.init.zeros_(bias)
# helper classes
class Always(nn.Module):
    """Module that ignores its input and returns the value it was built with."""

    def __init__(self, val):
        super().__init__()
        self.val = val

    def forward(self, x):
        # the input is deliberately unused
        return self.val
# feed forward
class GEGLU(nn.Module):
    """GELU-gated unit: split the last dimension in two and scale one half by GELU of the other."""

    def forward(self, x):
        values, gates = torch.chunk(x, 2, dim = -1)
        return values * F.gelu(gates)
class FeedForward(nn.Module):
    """Pre-norm feed-forward block: LayerNorm, Linear, GEGLU, Dropout, Linear.

    Args:
        dim: input and output feature size.
        mult: expansion factor of the hidden layer.
        dropout: dropout probability applied after the GEGLU.
    """

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        inner_dim = dim * mult

        self.norm = nn.LayerNorm(dim)

        # the first projection is twice the hidden width because GEGLU
        # splits it into a value half and a gate half
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim)
        )

        # zero the output projection so the block initially returns zeros
        init_zero_(self.net[-1])

    def forward(self, x, **kwargs):
        # extra keyword arguments are accepted and ignored
        return self.net(self.norm(x))
# attention
class Attention(nn.Module):
    """Multi-head attention with an input-dependent sigmoid gate on the output.

    Args:
        dim: feature size of the input and output.
        seq_len: stored on the module; not used by the code in this class.
        heads: number of attention heads.
        dim_head: feature size per head.
        dropout: dropout probability on the attention weights.
        gating: accepted for interface compatibility; the gate is always applied.
    """

    def __init__(self, dim, seq_len = None, heads = 8, dim_head = 64, dropout = 0., gating = True):
        super().__init__()
        inner_dim = dim_head * heads

        self.seq_len = seq_len
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        # zero-initialised output projection: the block initially outputs zeros
        self.to_out = nn.Linear(inner_dim, dim)
        init_zero_(self.to_out)

        # gate starts input-independent: zero weight and unit bias
        self.gating = nn.Linear(dim, inner_dim)
        nn.init.constant_(self.gating.weight, 0.)
        nn.init.constant_(self.gating.bias, 1.)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask = None, attn_bias = None, context = None, context_mask = None, tie_dim = None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when no context is given).

        Args:
            x: queries source, indexed as (batch, positions, features).
            mask: optional boolean mask over the query positions. Masking is
                only applied when this is given.
            attn_bias: optional tensor added to the attention logits.
            context: optional keys / values source; defaults to ``x``.
            context_mask: optional boolean mask over the context positions;
                only consulted when ``context`` is given.
            tie_dim: when given, the leading dimension is unfolded into
                ``(batch, tie_dim)`` rows and the queries are averaged over
                those rows before computing similarities.
        """
        device, h = x.device, self.heads
        has_context = exists(context)
        context = default(context, x)

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        # number of key positions, taken before the heads are split out
        j = k.shape[-2]

        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = h) for t in (q, k, v))

        # scale
        q = q * self.scale

        # query / key similarities
        if exists(tie_dim):
            # average the queries along the tied rows, then compare the shared
            # query against the keys of every row
            q, k = (rearrange(t, '(b r) ... -> b r ...', r = tie_dim) for t in (q, k))
            q = q.mean(dim = 1)

            dots = einsum('b h i d, b r h j d -> b r h i j', q, k)
            dots = rearrange(dots, 'b r ... -> (b r) ...')
        else:
            dots = einsum('b h i d, b h j d -> b h i j', q, k)

        # add attention bias, if supplied
        if exists(attn_bias):
            dots = dots + attn_bias

        # masking
        if exists(mask):
            if has_context:
                context_mask = default(context_mask, lambda: torch.ones(1, j, device = device).bool())
            else:
                context_mask = mask

            # a logit is kept only when both its query and its key are unmasked
            pair_mask = mask[:, None, :, None] * context_mask[:, None, None, :]
            mask_value = -torch.finfo(dots.dtype).max
            dots = dots.masked_fill(~pair_mask, mask_value)

        # attention
        attn = self.dropout(dots.softmax(dim = -1))

        # aggregate and merge heads
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        # gating
        out = out * self.gating(x).sigmoid()

        # combine to out
        return self.to_out(out)
class AxialAttention(nn.Module):
    """Attention along one axis of a 4-D ``(b, h, w, d)`` input.

    The other spatial axis is folded into the batch dimension before calling
    the inner ``Attention`` and unfolded afterwards.

    Args:
        dim: feature size.
        heads: number of attention heads.
        row_attn: enable row attention.
        col_attn: enable column attention.
        accept_edges: build a projection that turns ``edges`` into a per-head
            attention bias.
        global_query_attn: pass the folded axis size as ``tie_dim`` to the
            inner attention.
        **kwargs: forwarded to ``Attention``.
    """

    def __init__(
        self,
        dim,
        heads,
        row_attn = True,
        col_attn = True,
        accept_edges = False,
        global_query_attn = False,
        **kwargs
    ):
        super().__init__()
        assert row_attn or col_attn, 'row or column attention must be turned on'

        self.row_attn = row_attn
        self.col_attn = col_attn
        self.global_query_attn = global_query_attn

        self.norm = nn.LayerNorm(dim)
        self.attn = Attention(dim = dim, heads = heads, **kwargs)

        if accept_edges:
            self.edges_to_attn_bias = nn.Sequential(
                nn.Linear(dim, heads, bias = False),
                Rearrange('b i j h -> b h i j')
            )
        else:
            self.edges_to_attn_bias = None

    def forward(self, x, edges = None, mask = None):
        # exactly one of the two directions must be active at call time
        assert self.row_attn ^ self.col_attn, 'has to be either row or column attention, but not both'

        b, h, w, d = x.shape
        x = self.norm(x)

        # pick the fold patterns for the active axis
        if self.col_attn:
            axial_dim = w
            mask_fold_axial_eq, input_fold_eq, output_fold_eq = (
                'b h w -> (b w) h',
                'b h w d -> (b w) h d',
                '(b w) h d -> b h w d',
            )
        elif self.row_attn:
            axial_dim = h
            mask_fold_axial_eq, input_fold_eq, output_fold_eq = (
                'b h w -> (b h) w',
                'b h w d -> (b h) w d',
                '(b h) w d -> b h w d',
            )

        x = rearrange(x, input_fold_eq)
        if exists(mask):
            mask = rearrange(mask, mask_fold_axial_eq)

        # project the edges to one bias per head and repeat it for every folded slice
        attn_bias = None
        if exists(self.edges_to_attn_bias) and exists(edges):
            attn_bias = self.edges_to_attn_bias(edges)
            attn_bias = repeat(attn_bias, 'b h i j -> (b x) h i j', x = axial_dim)

        tie_dim = axial_dim if self.global_query_attn else None

        folded_out = self.attn(x, mask = mask, attn_bias = attn_bias, tie_dim = tie_dim)
        return rearrange(folded_out, output_fold_eq, h = h, w = w)
class TriangleMultiplicativeModule(nn.Module):
    """Gated multiplicative update over a square pairwise feature map.

    Args:
        dim: feature size of the input and output.
        hidden_dim: feature size of the left / right projections; defaults to ``dim``.
        mix: ``'outgoing'`` contracts the two projections over their shared
            second index (``i k, j k -> i j``); ``'ingoing'`` contracts over
            their shared first index (``k j, k i -> i j``).
    """

    def __init__(
        self,
        *,
        dim,
        hidden_dim = None,
        mix = 'ingoing'
    ):
        super().__init__()
        assert mix in {'ingoing', 'outgoing'}, 'mix must be either ingoing or outgoing'

        hidden_dim = default(hidden_dim, dim)
        self.norm = nn.LayerNorm(dim)

        self.left_proj = nn.Linear(dim, hidden_dim)
        self.right_proj = nn.Linear(dim, hidden_dim)

        self.left_gate = nn.Linear(dim, hidden_dim)
        self.right_gate = nn.Linear(dim, hidden_dim)
        self.out_gate = nn.Linear(dim, hidden_dim)

        # gates start input-independent: zero weight and unit bias
        for gate in (self.left_gate, self.right_gate, self.out_gate):
            nn.init.constant_(gate.weight, 0.)
            nn.init.constant_(gate.bias, 1.)

        # einsum equation used to mix the left and right projections
        if mix == 'outgoing':
            self.mix_einsum_eq = '... i k d, ... j k d -> ... i j d'
        elif mix == 'ingoing':
            self.mix_einsum_eq = '... k j d, ... k i d -> ... i j d'

        self.to_out_norm = nn.LayerNorm(hidden_dim)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self, x, mask = None):
        """Apply the update to ``x``; ``mask`` (``b i j``) zeroes masked pair projections."""
        assert x.shape[1] == x.shape[2], 'feature map must be symmetrical'

        # add a trailing singleton axis so the mask broadcasts over features
        if exists(mask):
            mask = rearrange(mask, 'b i j -> b i j ()')

        x = self.norm(x)

        left = self.left_proj(x)
        right = self.right_proj(x)

        if exists(mask):
            left = left * mask
            right = right * mask

        left_gate = self.left_gate(x).sigmoid()
        right_gate = self.right_gate(x).sigmoid()
        out_gate = self.out_gate(x).sigmoid()

        left = left * left_gate
        right = right * right_gate

        out = einsum(self.mix_einsum_eq, left, right)

        out = self.to_out_norm(out)
        out = out * out_gate
        return self.to_out(out)
# 定义 OuterMean 类,用于计算两个输入的外积均值
class OuterMean(nn.Module):
    """Mean over the MSA rows of the outer product between two projections.

    Args:
        dim: feature size of the input and output.
        hidden_dim: feature size of the two projections; defaults to ``dim``.
        eps: added to the valid-row count to avoid division by zero.
    """

    def __init__(
        self,
        dim,
        hidden_dim = None,
        eps = 1e-5
    ):
        super().__init__()
        self.eps = eps
        self.norm = nn.LayerNorm(dim)
        hidden_dim = default(hidden_dim, dim)

        self.left_proj = nn.Linear(dim, hidden_dim)
        self.right_proj = nn.Linear(dim, hidden_dim)
        self.proj_out = nn.Linear(hidden_dim, dim)

    def forward(self, x, mask = None):
        """Reduce ``x`` (``b m i d``) to a pairwise map (``b i j d``).

        ``mask`` (``b m i``, boolean) marks valid entries; masked entries are
        excluded from the mean.
        """
        x = self.norm(x)
        left = self.left_proj(x)
        right = self.right_proj(x)

        # element-wise product of every (i, j) pair, per MSA row
        outer = rearrange(left, 'b m i d -> b m i () d') * rearrange(right, 'b m j d -> b m () j d')

        if exists(mask):
            # a pair is valid only when both of its positions are valid in that row
            mask = rearrange(mask, 'b m i -> b m i () ()') * rearrange(mask, 'b m j -> b m () j ()')
            outer = outer.masked_fill(~mask, 0.)
            # masked mean = sum of valid entries / number of valid entries.
            # Using .mean() here would divide by the row count a second time.
            outer = outer.sum(dim = 1) / (mask.sum(dim = 1) + self.eps)
        else:
            outer = outer.mean(dim = 1)

        return self.proj_out(outer)
# 定义 PairwiseAttentionBlock 类,用于计算两个输入的注意力
class PairwiseAttentionBlock(nn.Module):
    """Residual update of the pairwise representation.

    Order of operations: optional outer-mean injection from the MSA, two
    triangle multiplicative updates, then two axial attentions whose bias is
    computed from the pairwise representation itself.

    ``seq_len`` and ``dropout`` are accepted but not used by this class.
    """

    def __init__(
        self,
        dim,
        seq_len,
        heads,
        dim_head,
        dropout = 0.,
        global_column_attn = False
    ):
        super().__init__()
        attn_kwargs = dict(dim = dim, heads = heads, dim_head = dim_head, accept_edges = True)

        self.outer_mean = OuterMean(dim)

        self.triangle_attention_outgoing = AxialAttention(row_attn = True, col_attn = False, **attn_kwargs)
        self.triangle_attention_ingoing = AxialAttention(row_attn = False, col_attn = True, global_query_attn = global_column_attn, **attn_kwargs)

        self.triangle_multiply_outgoing = TriangleMultiplicativeModule(dim = dim, mix = 'outgoing')
        self.triangle_multiply_ingoing = TriangleMultiplicativeModule(dim = dim, mix = 'ingoing')

    def forward(
        self,
        x,
        mask = None,
        msa_repr = None,
        msa_mask = None
    ):
        # inject MSA information, when an MSA representation is supplied
        if exists(msa_repr):
            x = x + self.outer_mean(msa_repr, mask = msa_mask)

        # residual triangle multiplicative updates: outgoing, then ingoing
        for triangle_multiply in (self.triangle_multiply_outgoing, self.triangle_multiply_ingoing):
            x = triangle_multiply(x, mask = mask) + x

        # residual axial attentions: outgoing, then ingoing; x doubles as the edges
        for triangle_attention in (self.triangle_attention_outgoing, self.triangle_attention_ingoing):
            x = triangle_attention(x, edges = x, mask = mask) + x

        return x
# MsaAttentionBlock: residual row attention followed by residual column attention over the MSA
class MsaAttentionBlock(nn.Module):
    """Residual row attention (biased by the pairwise representation) then residual column attention.

    ``seq_len`` and ``dropout`` are accepted but not used by this class.
    """

    def __init__(
        self,
        dim,
        seq_len,
        heads,
        dim_head,
        dropout = 0.
    ):
        super().__init__()
        shared = dict(dim = dim, heads = heads, dim_head = dim_head)

        self.row_attn = AxialAttention(row_attn = True, col_attn = False, accept_edges = True, **shared)
        self.col_attn = AxialAttention(row_attn = False, col_attn = True, **shared)

    def forward(
        self,
        x,
        mask = None,
        pairwise_repr = None
    ):
        row_update = self.row_attn(x, mask = mask, edges = pairwise_repr)
        x = row_update + x

        col_update = self.col_attn(x, mask = mask)
        x = col_update + x

        return x
# 定义 EvoformerBlock 类,包含 PairwiseAttentionBlock、FeedForward 和 MsaAttentionBlock
class EvoformerBlock(nn.Module):
    """One layer that updates the MSA representation first, then the pairwise one.

    ``forward`` takes and returns a single 4-tuple ``(x, m, mask, msa_mask)``
    so that layers can be chained one after another.
    """

    def __init__(
        self,
        *,
        dim,
        seq_len,
        heads,
        dim_head,
        attn_dropout,
        ff_dropout,
        global_column_attn = False
    ):
        super().__init__()

        pairwise_attn = PairwiseAttentionBlock(dim = dim, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout, global_column_attn = global_column_attn)
        pairwise_ff = FeedForward(dim = dim, dropout = ff_dropout)
        msa_attn = MsaAttentionBlock(dim = dim, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
        msa_ff = FeedForward(dim = dim, dropout = ff_dropout)

        self.layer = nn.ModuleList([pairwise_attn, pairwise_ff, msa_attn, msa_ff])

    def forward(self, inputs):
        x, m, mask, msa_mask = inputs
        pairwise_attn, pairwise_ff, msa_attn, msa_ff = self.layer

        # msa attention and transition
        m = msa_attn(m, mask = msa_mask, pairwise_repr = x)
        m = msa_ff(m) + m

        # pairwise attention and transition
        x = pairwise_attn(x, mask = mask, msa_repr = m, msa_mask = msa_mask)
        x = pairwise_ff(x) + x

        return x, m, mask, msa_mask
# 定义 Evoformer 类,包含多个 EvoformerBlock
class Evoformer(nn.Module):
    """Stack of ``depth`` EvoformerBlocks run through ``checkpoint_sequential``.

    Args:
        depth: number of blocks.
        **kwargs: forwarded unchanged to every ``EvoformerBlock``.
    """

    def __init__(
        self,
        *,
        depth,
        **kwargs
    ):
        super().__init__()
        blocks = [EvoformerBlock(**kwargs) for _ in range(depth)]
        self.layers = nn.ModuleList(blocks)

    def forward(
        self,
        x,
        m,
        mask = None,
        msa_mask = None
    ):
        # each block consumes and returns the same 4-tuple; only the two
        # updated representations are handed back to the caller
        outputs = checkpoint_sequential(self.layers, 1, (x, m, mask, msa_mask))
        x, m = outputs[0], outputs[1]
        return x, m
# 定义 Alphafold2 类,包含各种模型参数和结构相关的参数
class Alphafold2(nn.Module):
def __init__(
self,
*,
dim,
max_seq_len = 2048,
depth = 6,
heads = 8,
dim_head = 64,
max_rel_dist = 32,
num_tokens = constants.NUM_AMINO_ACIDS,
num_embedds = constants.NUM_EMBEDDS_TR,
max_num_msas = constants.MAX_NUM_MSA,
max_num_templates = constants.MAX_NUM_TEMPLATES,
extra_msa_evoformer_layers = 4,
attn_dropout = 0.,
ff_dropout = 0.,
templates_dim = 32,
templates_embed_layers = 4,
templates_angles_feats_dim = 55,
predict_angles = False,
symmetrize_omega = False,
predict_coords = False, # structure module related keyword arguments below
structure_module_depth = 4,
structure_module_heads = 1,
structure_module_dim_head = 4,
disable_token_embed = False,
mlm_mask_prob = 0.15,
mlm_random_replace_token_prob = 0.1,
mlm_keep_token_same_prob = 0.1,
mlm_exclude_token_ids = (0,),
recycling_distance_buckets = 32
):
# Build every sub-module of the model: embeddings, the extra-MSA and main Evoformer
# trunks, template embedders, optional angle heads, the MLM head, the distogram head,
# the structure-module layers and the recycling layers. All arguments are keyword-only.
# NOTE(review): max_num_msas, max_num_templates and structure_module_dim_head are not
# referenced anywhere in this constructor.
super().__init__()
# model feature size
self.dim = dim
# token embedding
# one extra embedding row beyond num_tokens (num_tokens is used as the MLM mask id below);
# when token embedding is disabled, a module that always returns 0 is used instead
self.token_emb = nn.Embedding(num_tokens + 1, dim) if not disable_token_embed else Always(0)
# linear projection from dim to twice dim
self.to_pairwise_repr = nn.Linear(dim, dim * 2)
# remember whether token embedding is disabled
self.disable_token_embed = disable_token_embed
# positional embedding
# maximum relative distance
self.max_rel_dist = max_rel_dist
# embedding table with 2 * max_rel_dist + 1 entries
self.pos_emb = nn.Embedding(max_rel_dist * 2 + 1, dim)
# extra msa embedding
# separate Evoformer stack for the extra MSAs, with global column attention enabled
self.extra_msa_evoformer = Evoformer(
dim = dim,
depth = extra_msa_evoformer_layers,
seq_len = max_seq_len,
heads = heads,
dim_head = dim_head,
attn_dropout = attn_dropout,
ff_dropout = ff_dropout,
global_column_attn = True
)
# template embedding
# linear projection from the template feature size to dim
self.to_template_embed = nn.Linear(templates_dim, dim)
self.templates_embed_layers = templates_embed_layers
# pairwise attention block applied to templates
self.template_pairwise_embedder = PairwiseAttentionBlock(
dim = dim,
dim_head = dim_head,
heads = heads,
seq_len = max_seq_len
)
# pointwise attention for templates
self.template_pointwise_attn = Attention(
dim = dim,
dim_head = dim_head,
heads = heads,
dropout = attn_dropout
)
# MLP over the template angle features
self.template_angle_mlp = nn.Sequential(
nn.Linear(templates_angles_feats_dim, dim),
nn.GELU(),
nn.Linear(dim, dim)
)
# projection for angles, if needed
# whether angle predictions are requested
self.predict_angles = predict_angles
self.symmetrize_omega = symmetrize_omega
if predict_angles:
# linear heads mapping dim to the number of buckets of each angle
self.to_prob_theta = nn.Linear(dim, constants.THETA_BUCKETS)
self.to_prob_phi = nn.Linear(dim, constants.PHI_BUCKETS)
self.to_prob_omega = nn.Linear(dim, constants.OMEGA_BUCKETS)
# custom embedding projection
# projects externally supplied embeddings of size num_embedds to dim
self.embedd_project = nn.Linear(num_embedds, dim)
# main trunk modules
# main Evoformer stack
self.net = Evoformer(
dim = dim,
depth = depth,
seq_len = max_seq_len,
heads = heads,
dim_head = dim_head,
attn_dropout = attn_dropout,
ff_dropout = ff_dropout
)
# MSA SSL MLM
# self-supervised masked language model over the MSA
self.mlm = MLM(
dim = dim,
num_tokens = num_tokens,
mask_id = num_tokens, # last token of the embedding is used for masking
mask_prob = mlm_mask_prob,
keep_token_same_prob = mlm_keep_token_same_prob,
random_replace_token_prob = mlm_random_replace_token_prob,
exclude_token_ids = mlm_exclude_token_ids
)
# calculate distogram logits
# LayerNorm followed by a linear head over the distogram buckets
self.to_distogram_logits = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, constants.DISTOGRAM_BUCKETS)
)
# to coordinate output
# whether coordinates are predicted
self.predict_coords = predict_coords
self.structure_module_depth = structure_module_depth
self.msa_to_single_repr_dim = nn.Linear(dim, dim)
self.trunk_to_pairwise_repr_dim = nn.Linear(dim, dim)
# NOTE(review): indentation is lost in this listing, so which of the following
# statements belong inside this `with` block cannot be determined here — confirm
with torch_default_dtype(torch.float32):
# IPA block
self.ipa_block = IPABlock(
dim = dim,
heads = structure_module_heads,
)
self.to_quaternion_update = nn.Linear(dim, 6)
# zero-initialise the output projection of the IPA attention
init_zero_(self.ipa_block.attn.to_out)
self.to_points = nn.Linear(dim, 3)
# aux confidence measure
# linear head producing a single value per position
self.lddt_linear = nn.Linear(dim, 1)
# recycling params
# norms and distance-bucket embedding used for recycled inputs
self.recycling_msa_norm = nn.LayerNorm(dim)
self.recycling_pairwise_norm = nn.LayerNorm(dim)
self.recycling_distance_embed = nn.Embedding(recycling_distance_buckets, dim)
self.recycling_distance_buckets = recycling_distance_buckets
def forward(
self,
seq,
msa = None,
mask = None,
msa_mask = None,
extra_msa = None,
extra_msa_mask = None,
seq_index = None,
seq_embed = None,
msa_embed = None,
templates_feats = None,
templates_mask = None,
templates_angles = None,
embedds = None,
recyclables = None,
return_trunk = False,
return_confidence = False,
return_recyclables = False,
return_aux_logits = False