Lucidrains Series Projects Source Code Analysis (111)
.\lucidrains\vit-pytorch\vit_pytorch\vit_with_patch_dropout.py
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
# return t if it is already a tuple, otherwise duplicate it into a pair
def pair(t):
    return t if isinstance(t, tuple) else (t, t)
# classes
# PatchDropout: during training, randomly keep only a subset of the patch tokens
class PatchDropout(nn.Module):
    def __init__(self, prob):
        super().__init__()
        # the drop probability must lie in [0, 1)
        assert 0 <= prob < 1.
        self.prob = prob
    def forward(self, x):
        # no-op at inference time or when the probability is 0
        if not self.training or self.prob == 0.:
            return x
        # unpack batch size, number of patches and device
        b, n, _, device = *x.shape, x.device
        # batch indices of shape (b, 1) for advanced indexing
        batch_indices = torch.arange(b, device = device)
        batch_indices = rearrange(batch_indices, '... -> ... 1')
        # number of patches to keep (at least one)
        num_patches_keep = max(1, int(n * (1 - self.prob)))
        # choose a random subset of patch indices per sample
        patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices
        return x[batch_indices, patch_indices_keep]
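A minimal, illustrative sketch (not part of the source file) of what PatchDropout does to the token dimension; the shapes below are made up for demonstration:

dropout = PatchDropout(0.25)
dropout.train()
tokens = torch.randn(2, 64, 128)   # (batch, patches, dim)
print(dropout(tokens).shape)       # torch.Size([2, 48, 128]) -- 64 * (1 - 0.25) patches kept
dropout.eval()
print(dropout(tokens).shape)       # torch.Size([2, 64, 128]) -- identity at inference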
# FeedForward: pre-norm MLP block
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)
# Attention: pre-norm multi-head self-attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()
    def forward(self, x):
        x = self.norm(x)
        # project to queries, keys, values and split out the heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        # scaled dot-product attention
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        attn = self.dropout(attn)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# Transformer: a stack of residual attention + feedforward layers
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x
# ViT with patch dropout
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., patch_dropout = 0.25):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        # split the image into patches and project each patch to the model dimension
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )
        # learned positional embedding (one per patch) and CLS token
        self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.patch_dropout = PatchDropout(patch_dropout)
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )
    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape
        # add positional embeddings, then drop a random subset of patches
        x += self.pos_embedding
        x = self.patch_dropout(x)
        # prepend the CLS token after patch dropout, so it is never dropped
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.dropout(x)
        x = self.transformer(x)
        # pool: mean over tokens, or take the CLS token
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)
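A hedged usage sketch for the class above; the hyperparameter values are illustrative only:

model = ViT(
    image_size = 224,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 1024,
    patch_dropout = 0.5   # fraction of patch tokens dropped while training
)
img = torch.randn(1, 3, 224, 224)
preds = model(img)        # (1, 1000)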
.\lucidrains\vit-pytorch\vit_pytorch\vit_with_patch_merger.py
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
# helpers
# check whether a value is not None
def exists(val):
    return val is not None
# return val if it exists, otherwise the default d
def default(val, d):
    return val if exists(val) else d
# cast a value to a pair (tuple) if it is not one already
def pair(t):
    return t if isinstance(t, tuple) else (t, t)
# patch merger class
# PatchMerger: reduces n tokens down to num_tokens_out tokens by attending to learned queries
class PatchMerger(nn.Module):
    def __init__(self, dim, num_tokens_out):
        super().__init__()
        self.scale = dim ** -0.5
        self.norm = nn.LayerNorm(dim)
        self.queries = nn.Parameter(torch.randn(num_tokens_out, dim))
    def forward(self, x):
        x = self.norm(x)
        # similarity between the learned queries and the incoming tokens
        sim = torch.matmul(self.queries, x.transpose(-1, -2)) * self.scale
        attn = sim.softmax(dim = -1)
        # each output token is an attention-weighted mixture of the input tokens
        return torch.matmul(attn, x)
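A short, illustrative sketch (not part of the file) of the token reduction PatchMerger performs; the shapes are arbitrary:

merger = PatchMerger(dim = 256, num_tokens_out = 8)
tokens = torch.randn(2, 196, 256)   # (batch, tokens, dim), e.g. a 14 x 14 patch grid
print(merger(tokens).shape)         # torch.Size([2, 8, 256]) -- sequence shrunk to 8 tokens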
# classes
# FeedForward: pre-norm MLP block
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)
# Attention: pre-norm multi-head self-attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()
    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        attn = self.dropout(attn)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# Transformer with a PatchMerger inserted part-way through the stack
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., patch_merge_layer = None, patch_merge_num_tokens = 8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        self.patch_merge_layer_index = default(patch_merge_layer, depth // 2) - 1 # default to mid-way through transformer, as shown in paper
        self.patch_merger = PatchMerger(dim = dim, num_tokens_out = patch_merge_num_tokens)
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
    def forward(self, x):
        for index, (attn, ff) in enumerate(self.layers):
            x = attn(x) + x
            x = ff(x) + x
            # merge the patch tokens after the configured layer
            if index == self.patch_merge_layer_index:
                x = self.patch_merger(x)
        return self.norm(x)
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, patch_merge_layer = None, patch_merge_num_tokens = 8, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        # split the image into patches and project to the model dimension
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, patch_merge_layer, patch_merge_num_tokens)
        # classification head: mean-pool the remaining tokens, then project to class logits
        self.mlp_head = nn.Sequential(
            Reduce('b n d -> b d', 'mean'),
            nn.Linear(dim, num_classes)
        )
    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape
        # add positional embeddings
        x += self.pos_embedding[:, :n]
        x = self.dropout(x)
        x = self.transformer(x)
        return self.mlp_head(x)
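A hedged usage sketch for this variant; the values are illustrative and only the keyword names come from the constructor above:

model = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 12,
    heads = 8,
    mlp_dim = 1024,
    patch_merge_layer = 6,        # merge after the 6th transformer layer
    patch_merge_num_tokens = 8    # merge the sequence down to 8 tokens
)
img = torch.randn(1, 3, 256, 256)
preds = model(img)                # (1, 1000)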
.\lucidrains\vit-pytorch\vit_pytorch\vivit.py
import torch
from torch import nn
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
# helpers
def exists(val):
    return val is not None
def pair(t):
    return t if isinstance(t, tuple) else (t, t)
# classes
# FeedForward: pre-norm MLP block
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)
# Attention: pre-norm multi-head self-attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()
    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        attn = self.dropout(attn)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# Transformer: residual attention + feedforward stack with a final norm
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)
# ViViT: factorized spatial / temporal transformer over video
class ViT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        image_patch_size,
        frames,
        frame_patch_size,
        num_classes,
        dim,
        spatial_depth,
        temporal_depth,
        heads,
        mlp_dim,
        pool = 'cls',
        channels = 3,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.
    ):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(image_patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
        num_image_patches = (image_height // patch_height) * (image_width // patch_width)
        num_frame_patches = (frames // frame_patch_size)
        patch_dim = channels * patch_height * patch_width * frame_patch_size
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        self.global_average_pool = pool == 'mean'
        # split the video into spatio-temporal tubelets and project each to the model dimension
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )
        # positional embedding over (frame patches, image patches)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_frame_patches, num_image_patches, dim))
        self.dropout = nn.Dropout(emb_dropout)
        # separate CLS tokens for the spatial and temporal transformers (unused with mean pooling)
        self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
        self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
        self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
        self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Linear(dim, num_classes)
    def forward(self, video):
        x = self.to_patch_embedding(video)
        b, f, n, _ = x.shape
        # add positional embeddings
        x = x + self.pos_embedding[:, :f, :n]
        # prepend a spatial CLS token per frame, if configured
        if exists(self.spatial_cls_token):
            spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
            x = torch.cat((spatial_cls_tokens, x), dim = 2)
        x = self.dropout(x)
        # fold frames into the batch and attend over space
        x = rearrange(x, 'b f n d -> (b f) n d')
        x = self.spatial_transformer(x)
        x = rearrange(x, '(b f) n d -> b f n d', b = b)
        # take the spatial CLS token (or mean-pool) to get one token per frame
        x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
        # prepend a temporal CLS token, if configured
        if exists(self.temporal_cls_token):
            temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d -> b 1 d', b = b)
            x = torch.cat((temporal_cls_tokens, x), dim = 1)
        # attend over time
        x = self.temporal_transformer(x)
        # take the temporal CLS token (or mean-pool over frames)
        x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
        x = self.to_latent(x)
        return self.mlp_head(x)
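A hedged usage sketch for the factorized (spatial then temporal) ViViT above; all values are illustrative:

model = ViT(
    image_size = 128,
    image_patch_size = 16,
    frames = 16,
    frame_patch_size = 2,
    num_classes = 1000,
    dim = 512,
    spatial_depth = 6,
    temporal_depth = 6,
    heads = 8,
    mlp_dim = 1024
)
video = torch.randn(4, 3, 16, 128, 128)   # (batch, channels, frames, height, width)
preds = model(video)                      # (4, 1000)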
.\lucidrains\vit-pytorch\vit_pytorch\xcit.py
from random import randrange
import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def exists(val):
    return val is not None
# pack a single tensor according to the given einops pattern
def pack_one(t, pattern):
    return pack([t], pattern)
# inverse of pack_one
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]
# l2-normalize along the last dimension
def l2norm(t):
    return F.normalize(t, dim = -1, p = 2)
# layer dropout (stochastic depth): randomly drop whole layers during training
def dropout_layers(layers, dropout):
    if dropout == 0:
        return layers
    num_layers = len(layers)
    to_drop = torch.zeros(num_layers).uniform_(0., 1.) < dropout
    # make sure at least one layer survives
    if all(to_drop):
        rand_index = randrange(num_layers)
        to_drop[rand_index] = False
    layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
    return layers
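A quick, illustrative sketch of dropout_layers: on average a fraction `dropout` of the layers is skipped on each call, and at least one layer always survives (the list of strings is just a stand-in for the ModuleLists built below):

layers = ['layer0', 'layer1', 'layer2', 'layer3']
print(dropout_layers(layers, dropout = 0.5))   # e.g. ['layer0', 'layer3'] -- a random, never-empty subset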
# classes
# LayerScale: scale each residual branch by a small learned per-channel factor
class LayerScale(Module):
    def __init__(self, dim, fn, depth):
        super().__init__()
        # the initial scale depends on network depth, following the paper
        if depth <= 18:
            init_eps = 0.1
        elif 18 < depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6
        self.fn = fn
        self.scale = nn.Parameter(torch.full((dim,), init_eps))
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale
# FeedForward: pre-norm MLP block
class FeedForward(Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)
# Attention: token attention, optionally attending over extra context (used by the class-attention stage)
class Attention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x, context = None):
        h = self.heads
        x = self.norm(x)
        # keys/values come from the query tokens concatenated with the context, if given
        context = x if not exists(context) else torch.cat((x, context), dim = 1)
        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = self.attend(sim)
        attn = self.dropout(attn)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# XCAttention: cross-covariance attention -- attends over feature channels instead of tokens
class XCAttention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        # learned per-head temperature in place of the usual 1/sqrt(d) scale
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        h = self.heads
        # flatten any spatial dims into one token dimension, remembering how to undo it
        x, ps = pack_one(x, 'b * d')
        x = self.norm(x)
        # project to queries, keys, values
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        # note the transposed layout: heads x channels x tokens
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h d n', h=h), (q, k, v))
        # l2-normalize queries and keys, then compute channel-by-channel similarities scaled by the temperature
        q, k = map(l2norm, (q, k))
        sim = einsum('b h i n, b h j n -> b h i j', q, k) * self.temperature.exp()
        # attention over channels
        attn = self.attend(sim)
        attn = self.dropout(attn)
        # aggregate the values
        out = einsum('b h i j, b h j n -> b h i n', attn, v)
        out = rearrange(out, 'b h d n -> b n (h d)')
        # restore the original (spatial) shape
        out = unpack_one(out, ps, 'b * d')
        return self.to_out(out)
# LocalPatchInteraction: depthwise convolutions that let neighbouring patches exchange information
class LocalPatchInteraction(Module):
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        # the kernel size must be odd so padding keeps the spatial size unchanged
        assert (kernel_size % 2) == 1
        padding = kernel_size // 2
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b h w c -> b c h w'),                                        # to channels-first for the convolutions
            nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),      # depthwise conv
            nn.BatchNorm2d(dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),      # depthwise conv
            Rearrange('b c h w -> b h w c'),                                        # back to channels-last
        )
    def forward(self, x):
        return self.net(x)
# Transformer for the class-attention stage: attention + feedforward, both wrapped in LayerScale
class Transformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = ModuleList([])
        self.layer_dropout = layer_dropout
        for ind in range(depth):
            layer = ind + 1
            self.layers.append(ModuleList([
                LayerScale(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = layer),
                LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = layer)
            ]))
    def forward(self, x, context = None):
        # layer dropout (stochastic depth) over whole layers
        layers = dropout_layers(self.layers, dropout = self.layer_dropout)
        for attn, ff in layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x
# XCATransformer: per layer, cross-covariance attention, local patch interaction, then feedforward
class XCATransformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size = 3, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = ModuleList([])
        self.layer_dropout = layer_dropout
        for ind in range(depth):
            layer = ind + 1
            self.layers.append(ModuleList([
                LayerScale(dim, XCAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = layer),
                LayerScale(dim, LocalPatchInteraction(dim, local_patch_kernel_size), depth = layer),
                LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = layer)
            ]))
    def forward(self, x):
        layers = dropout_layers(self.layers, dropout = self.layer_dropout)
        for cross_covariance_attn, local_patch_interaction, ff in layers:
            x = cross_covariance_attn(x) + x
            x = local_patch_interaction(x) + x
            x = ff(x) + x
        return x
# XCiT: cross-covariance image transformer
class XCiT(Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        cls_depth,
        heads,
        mlp_dim,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.,
        local_patch_kernel_size = 3,
        layer_dropout = 0.
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2
        # split the image into patches, keeping the 2d (h, w) layout for the local patch interaction
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )
        # positional embedding and CLS token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        self.cls_token = nn.Parameter(torch.randn(dim))
        self.dropout = nn.Dropout(emb_dropout)
        # main cross-covariance attention trunk
        self.xcit_transformer = XCATransformer(dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size, dropout, layer_dropout)
        self.final_norm = nn.LayerNorm(dim)
        # class-attention stage: the CLS token attends to the patch tokens
        self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )
    def forward(self, img):
        x = self.to_patch_embedding(img)
        # temporarily flatten the (h, w) patch grid to add positional embeddings
        x, ps = pack_one(x, 'b * d')
        b, n, _ = x.shape
        x += self.pos_embedding[:, :n]
        x = unpack_one(x, ps, 'b * d')
        x = self.dropout(x)
        # run the cross-covariance attention trunk
        x = self.xcit_transformer(x)
        x = self.final_norm(x)
        # run the class-attention stage with the patch tokens as context
        cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b = b)
        x = rearrange(x, 'b ... d -> b (...) d')
        cls_tokens = self.cls_transformer(cls_tokens, context = x)
        return self.mlp_head(cls_tokens[:, 0])
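A hedged usage sketch for XCiT; the hyperparameters are illustrative only:

model = XCiT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 12,          # depth of the cross-covariance attention trunk
    cls_depth = 2,       # depth of the class-attention stage
    heads = 16,
    mlp_dim = 2048,
    layer_dropout = 0.05
)
img = torch.randn(1, 3, 256, 256)
preds = model(img)       # (1, 1000)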
.\lucidrains\vit-pytorch\vit_pytorch\__init__.py
# re-export the main classes of the package
from vit_pytorch.vit import ViT
from vit_pytorch.simple_vit import SimpleViT
from vit_pytorch.mae import MAE
from vit_pytorch.dino import Dino
.\lucidrains\VN-transformer\denoise.py
import torch
import torch.nn.functional as F
from torch.optim import Adam
from einops import rearrange, repeat
import sidechainnet as scn
from VN_transformer import VNTransformer
# training constants
BATCH_SIZE = 1
GRADIENT_ACCUMULATE_EVERY = 16
MAX_SEQ_LEN = 256
DEFAULT_TYPE = torch.float64
torch.set_default_dtype(DEFAULT_TYPE)
# infinite data generator, skipping sequences longer than len_thres
def cycle(loader, len_thres = MAX_SEQ_LEN):
    while True:
        for data in loader:
            if data.seqs.shape[1] > len_thres:
                continue
            yield data
# instantiate the VN-Transformer
transformer = VNTransformer(
    num_tokens = 24,
    dim = 64,
    depth = 4,
    dim_head = 64,
    heads = 8,
    dim_feat = 64,
    bias_epsilon = 1e-6,
    l2_dist_attn = True,
    flash_attn = False
).cuda()
# load the sidechainnet protein dataset
data = scn.load(
    casp_version = 12,
    thinning = 30,
    with_pytorch = 'dataloaders',
    batch_size = BATCH_SIZE,
    dynamic_batching = False
)
dl = cycle(data['train'])
optim = Adam(transformer.parameters(), lr = 1e-4)
# training loop
for _ in range(10000):
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        batch = next(dl)
        seqs, coords, masks = batch.seqs, batch.crds, batch.msks
        # residue identities: move to gpu and convert one-hot vectors to indices
        seqs = seqs.cuda().argmax(dim = -1)
        # coordinates to gpu, cast to the default dtype
        coords = coords.cuda().type(torch.get_default_dtype())
        # mask to gpu, as booleans
        masks = masks.cuda().bool()
        l = seqs.shape[1]
        # keep only the 3 backbone atoms out of the 14 atoms per residue
        coords = rearrange(coords, 'b (l s) c -> b l s c', s = 14)
        coords = coords[:, :, 0:3, :]
        coords = rearrange(coords, 'b l s c -> b (l s) c')
        # repeat the residue tokens and the mask so they line up with the 3 backbone atoms
        seq = repeat(seqs, 'b n -> b (n c)', c = 3)
        masks = repeat(masks, 'b n -> b (n c)', c = 3)
        # corrupt the coordinates with gaussian noise
        noised_coords = coords + torch.randn_like(coords).cuda()
        # the transformer predicts a type-1 (vector) correction
        type1_out, _ = transformer(
            noised_coords,
            feats = seq,
            mask = masks
        )
        denoised_coords = noised_coords + type1_out
        # mean squared error against the clean backbone coordinates
        loss = F.mse_loss(denoised_coords[masks], coords[masks])
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()
    print('loss:', loss.item())
    optim.step()
    optim.zero_grad()
VN (Vector Neuron) Transformer
A Transformer made of Rotation-equivariant Attention using Vector Neurons
Appreciation
- StabilityAI for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.
Install
$ pip install VN-transformer
Usage
import torch
from VN_transformer import VNTransformer
model = VNTransformer(
    dim = 64,
    depth = 2,
    dim_head = 64,
    heads = 8,
    dim_feat = 64,       # will default to early fusion, since this was the best performing
    bias_epsilon = 1e-6  # in this paper, they propose breaking equivariance with a tiny bit of bias noise in the VN linear. they claim this leads to improved stability. setting this to 0 would turn off the epsilon approximate equivariance
)
coors = torch.randn(1, 32, 3) # (batch, sequence, spatial coordinates)
feats = torch.randn(1, 32, 64)
coors_out, feats_out = model(coors, feats = feats) # (1, 32, 3), (1, 32, 64)
Tests
Confidence in equivariance
$ python setup.py test
Example
First install sidechainnet
$ pip install sidechainnet
Then run the protein backbone denoising task
$ python denoise.py
It does not perform as well as En-Transformer or Equiformer
Citations
@inproceedings{Assaad2022VNTransformerRA,
title = {VN-Transformer: Rotation-Equivariant Attention for Vector Neurons},
author = {Serge Assaad and C. Downey and Rami Al-Rfou and Nigamaa Nayakanti and Benjamin Sapp},
year = {2022}
}
@article{Deng2021VectorNA,
title = {Vector Neurons: A General Framework for SO(3)-Equivariant Networks},
author = {Congyue Deng and Or Litany and Yueqi Duan and Adrien Poulenard and Andrea Tagliasacchi and Leonidas J. Guibas},
journal = {2021 IEEE/CVF International Conference on Computer Vision (ICCV)},
year = {2021},
pages = {12180-12189},
url = {https://api.semanticscholar.org/CorpusID:233394028}
}
@inproceedings{Kim2020TheLC,
title = {The Lipschitz Constant of Self-Attention},
author = {Hyunjik Kim and George Papamakarios and Andriy Mnih},
booktitle = {International Conference on Machine Learning},
year = {2020},
url = {https://api.semanticscholar.org/CorpusID:219530837}
}
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
.\lucidrains\VN-transformer\setup.py
from setuptools import setup, find_packages
# package metadata
setup(
    name = 'VN-transformer',                          # package name
    packages = find_packages(exclude=[]),             # include all packages
    version = '0.1.0',                                # version
    license='MIT',                                    # license
    description = 'Vector Neuron Transformer (VN-Transformer)',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    long_description_content_type = 'text/markdown',
    url = 'https://github.com/lucidrains/VN-transformer',
    keywords = [                                      # keywords
        'artificial intelligence',
        'deep learning',
        'equivariance',
        'vector neurons',
        'transformers',
        'attention mechanism'
    ],
    install_requires=[                                # runtime dependencies
        'einops>=0.6.0',
        'torch>=1.6'
    ],
    setup_requires=[                                  # setup dependencies
        'pytest-runner',
    ],
    tests_require=[                                   # test dependencies
        'pytest'
    ],
    classifiers=[                                     # classifiers
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.8',
    ],
)
.\lucidrains\VN-transformer\tests\test.py
import pytest
import torch
from VN_transformer.VN_transformer import VNTransformer, VNInvariant, VNAttention
from VN_transformer.rotations import rot
# set the default dtype to float64, so the equivariance checks are numerically tight
torch.set_default_dtype(torch.float64)
# test the invariant layer
def test_vn_invariant():
    layer = VNInvariant(64)
    coors = torch.randn(1, 32, 64, 3)
    # random rotation matrix
    R = rot(*torch.randn(3))
    out1 = layer(coors)
    out2 = layer(coors @ R)
    # the output must be unchanged when the input is rotated
    assert torch.allclose(out1, out2, atol = 1e-6)
# test rotation equivariance of the full transformer
@pytest.mark.parametrize('l2_dist_attn', [True, False])
def test_equivariance(l2_dist_attn):
    model = VNTransformer(
        dim = 64,
        depth = 2,
        dim_head = 64,
        heads = 8,
        l2_dist_attn = l2_dist_attn
    )
    coors = torch.randn(1, 32, 3)
    mask = torch.ones(1, 32).bool()
    R = rot(*torch.randn(3))
    out1 = model(coors @ R, mask = mask)
    out2 = model(coors, mask = mask) @ R
    # rotating the input then applying the model must equal applying the model then rotating
    assert torch.allclose(out1, out2, atol = 1e-6), 'is not equivariant'
# test equivariance of the perceiver-style VN attention (latents attending to the sequence)
@pytest.mark.parametrize('l2_dist_attn', [True, False])
def test_perceiver_vn_attention_equivariance(l2_dist_attn):
    model = VNAttention(
        dim = 64,
        dim_head = 64,
        heads = 8,
        num_latents = 2,
        l2_dist_attn = l2_dist_attn
    )
    coors = torch.randn(1, 32, 64, 3)
    mask = torch.ones(1, 32).bool()
    R = rot(*torch.randn(3))
    out1 = model(coors @ R, mask = mask)
    out2 = model(coors, mask = mask) @ R
    # the output sequence is pooled down to the 2 latents
    assert out1.shape[1] == 2
    assert torch.allclose(out1, out2, atol = 1e-6), 'is not equivariant'
# test SO(3) equivariance with early-fused features
@pytest.mark.parametrize('l2_dist_attn', [True, False])
def test_equivariance_with_early_fusion(l2_dist_attn):
    model = VNTransformer(
        dim = 64,
        depth = 2,
        dim_head = 64,
        heads = 8,
        dim_feat = 64,
        l2_dist_attn = l2_dist_attn
    )
    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask = torch.ones(1, 32).bool()
    R = rot(*torch.randn(3))
    out1, _ = model(coors @ R, feats = feats, mask = mask, return_concatted_coors_and_feats = False)
    out2, _ = model(coors, feats = feats, mask = mask, return_concatted_coors_and_feats = False)
    out2 = out2 @ R
    assert torch.allclose(out1, out2, atol = 1e-6), 'is not equivariant'
# test SE(3) equivariance (rotation plus translation) with early-fused features
@pytest.mark.parametrize('l2_dist_attn', [True, False])
def test_se3_equivariance_with_early_fusion(l2_dist_attn):
    model = VNTransformer(
        dim = 64,
        depth = 2,
        dim_head = 64,
        heads = 8,
        dim_feat = 64,
        translation_equivariance = True,
        l2_dist_attn = l2_dist_attn
    )
    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask = torch.ones(1, 32).bool()
    # random translation and rotation
    T = torch.randn(3)
    R = rot(*torch.randn(3))
    out1, _ = model((coors + T) @ R, feats = feats, mask = mask, return_concatted_coors_and_feats = False)
    out2, _ = model(coors, feats = feats, mask = mask, return_concatted_coors_and_feats = False)
    out2 = (out2 + T) @ R
    assert torch.allclose(out1, out2, atol = 1e-6), 'is not equivariant'
.\lucidrains\VN-transformer\VN_transformer\attend.py
from functools import wraps
from packaging import version
from collections import namedtuple
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, reduce
# config namedtuple selecting which scaled-dot-product-attention backends to enable
FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
# helpers
def exists(val):
    return val is not None
# decorator that makes a function run only on its first call
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner
print_once = once(print)
# main class
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        flash = False,
        l2_dist = False
    ):
        super().__init__()
        assert not (flash and l2_dist), 'flash attention is not compatible with l2 distance'
        self.l2_dist = l2_dist
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)
        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
        # determine the efficient attention configs for cuda and cpu
        self.cpu_config = FlashAttentionConfig(True, True, True)
        self.cuda_config = None
        if not torch.cuda.is_available() or not flash:
            return
        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = FlashAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = FlashAttentionConfig(False, True, True)
    # flash attention path
    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
        # expand the mask from (b, 1, 1, j) to the (b, h, n, j) shape the kernel expects
        if exists(mask):
            mask = mask.expand(-1, heads, q_len, -1)
        # pick the backend config for the current device
        config = self.cuda_config if is_cuda else self.cpu_config
        # call pytorch 2.0 scaled_dot_product_attention under the chosen backend
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.
            )
        return out
    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
        scale = q.shape[-1] ** -0.5
        # broadcast a (b, j) key-padding mask to (b, 1, 1, j)
        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, 'b j -> b 1 1 j')
        if self.flash:
            return self.flash_attn(q, k, v, mask = mask)
        # similarity
        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
        # l2 distance
        if self.l2_dist:
            # -cdist squared == (-q^2 + 2qk - k^2), so simply build on the qk term above
            q_squared = reduce(q ** 2, 'b h i d -> b h i 1', 'sum')
            k_squared = reduce(k ** 2, 'b h j d -> b h 1 j', 'sum')
            sim = sim * 2 - q_squared - k_squared
        # key padding mask
        if exists(mask):
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
        # attention
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)
        # aggregate values
        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
        return out
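The l2_dist branch above relies on the identity -||q - k||^2 = 2 q·k - ||q||^2 - ||k||^2. A small, illustrative check of that algebra (not part of the module), reusing the einsum/reduce imports from this file; the shapes are arbitrary:

q, k = torch.randn(1, 2, 8, 16), torch.randn(1, 2, 8, 16)
qk = einsum('b h i d, b h j d -> b h i j', q, k)
q_sq = reduce(q ** 2, 'b h i d -> b h i 1', 'sum')
k_sq = reduce(k ** 2, 'b h j d -> b h 1 j', 'sum')
# negative squared pairwise distances, computed directly by broadcasting
sq_dist = ((q.unsqueeze(-2) - k.unsqueeze(-3)) ** 2).sum(dim = -1)
assert torch.allclose(2 * qk - q_sq - k_sq, -sq_dist, atol = 1e-5)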
.\lucidrains\VN-transformer\VN_transformer\rotations.py
import torch
from torch import sin, cos, atan2, acos
# rotation about the z axis by angle gamma
def rot_z(gamma):
    return torch.tensor([
        [cos(gamma), -sin(gamma), 0],
        [sin(gamma), cos(gamma), 0],
        [0, 0, 1]
    ], dtype=gamma.dtype)
# rotation about the y axis by angle beta
def rot_y(beta):
    return torch.tensor([
        [cos(beta), 0, sin(beta)],
        [0, 1, 0],
        [-sin(beta), 0, cos(beta)]
    ], dtype=beta.dtype)
# general rotation from z-y-z Euler angles alpha, beta, gamma
def rot(alpha, beta, gamma):
    return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
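A brief, illustrative check (not part of the module) that rot returns a proper rotation matrix, i.e. orthogonal with determinant +1:

R = rot(*torch.randn(3))
assert torch.allclose(R.T @ R, torch.eye(3, dtype = R.dtype), atol = 1e-5)
assert torch.allclose(torch.linalg.det(R), torch.tensor(1., dtype = R.dtype), atol = 1e-5)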
.\lucidrains\VN-transformer\VN_transformer\VN_transformer.py
import torch
import torch.nn.functional as F
from torch import nn, einsum, Tensor
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange, Reduce
from VN_transformer.attend import Attend
# helpers
def exists(val):
    return val is not None
def default(val, d):
    return val if exists(val) else d
# inner product over the coordinate dimension
def inner_dot_product(x, y, *, dim = -1, keepdim = True):
    return (x * y).sum(dim = dim, keepdim = keepdim)
# layernorm
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))
    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
# equivariant modules
# VNLinear: mixes vector channels without touching the coordinate axis, so it commutes with rotations
class VNLinear(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        bias_epsilon = 0.
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim_out, dim_in))
        self.bias = None
        self.bias_epsilon = bias_epsilon
        # the paper proposes quasi-equivariance via a tiny bias controlled by epsilon, claiming better stability and results
        if bias_epsilon > 0.:
            self.bias = nn.Parameter(torch.randn(dim_out))
    def forward(self, x):
        out = einsum('... i c, o i -> ... o c', x, self.weight)
        if exists(self.bias):
            bias = F.normalize(self.bias, dim = -1) * self.bias_epsilon
            out = out + rearrange(bias, '... -> ... 1')
        return out
# VNReLU: vector-neuron nonlinearity -- keep q where it agrees with the learned gate direction k, otherwise remove its component along k
class VNReLU(nn.Module):
    def __init__(self, dim, eps = 1e-6):
        super().__init__()
        self.eps = eps
        self.W = nn.Parameter(torch.randn(dim, dim))
        self.U = nn.Parameter(torch.randn(dim, dim))
    def forward(self, x):
        q = einsum('... i c, o i -> ... o c', x, self.W)
        k = einsum('... i c, o i -> ... o c', x, self.U)
        qk = inner_dot_product(q, k)
        k_norm = k.norm(dim = -1, keepdim = True).clamp(min = self.eps)
        q_projected_on_k = q - inner_dot_product(q, k / k_norm) * k
        out = torch.where(
            qk >= 0.,
            q,
            q_projected_on_k
        )
        return out
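Because VNLinear only mixes channels and never touches the coordinate axis, it commutes with any rotation R (exactly so when bias_epsilon = 0). A minimal, hedged check, reusing rot from the rotations module above; the shapes are arbitrary:

from VN_transformer.rotations import rot
layer = VNLinear(32, 64)                  # bias_epsilon defaults to 0, so exactly equivariant
x = torch.randn(2, 16, 32, 3)             # (batch, seq, vector channels, xyz)
R = rot(*torch.randn(3))
assert torch.allclose(layer(x @ R), layer(x) @ R, atol = 1e-4)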
# VNAttention: rotation-equivariant multi-head attention over vector-neuron features
class VNAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        dim_coor = 3,
        bias_epsilon = 0.,
        l2_dist_attn = False,
        flash = False,
        num_latents = None # setting this would enable perceiver-like cross attention from latents to the sequence, with the latents derived through VNWeightedPool
    ):
        super().__init__()
        assert not (l2_dist_attn and flash), 'l2 distance attention is not compatible with flash attention'
        self.scale = (dim_coor * dim_head) ** -0.5
        dim_inner = dim_head * heads
        self.heads = heads
        self.to_q_input = None
        if exists(num_latents):
            self.to_q_input = VNWeightedPool(dim, num_pooled_tokens = num_latents, squeeze_out_pooled_dim = False)
        self.to_q = VNLinear(dim, dim_inner, bias_epsilon = bias_epsilon)
        self.to_k = VNLinear(dim, dim_inner, bias_epsilon = bias_epsilon)
        self.to_v = VNLinear(dim, dim_inner, bias_epsilon = bias_epsilon)
        self.to_out = VNLinear(dim_inner, dim, bias_epsilon = bias_epsilon)
        if l2_dist_attn and not exists(num_latents):
            # for l2 distance attention, queries and keys are tied, unless the attention is perceiver-like
            self.to_k = self.to_q
        self.attend = Attend(flash = flash, l2_dist = l2_dist_attn)
    def forward(self, x, mask = None):
        """
        einstein notation
        b - batch
        n - sequence
        h - heads
        d - feature dimension (channels)
        c - coordinate dimension (3 for 3d space)
        i - source sequence dimension
        j - target sequence dimension
        """
        c = x.shape[-1]
        # if latents were configured, pool the sequence down into the latent queries
        if exists(self.to_q_input):
            q_input = self.to_q_input(x, mask = mask)
        else:
            q_input = x
        q, k, v = self.to_q(q_input), self.to_k(x), self.to_v(x)
        # fold the coordinate dimension into the head feature dimension before attending
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) c -> b h n (d c)', h = self.heads), (q, k, v))
        out = self.attend(q, k, v, mask = mask)
        # unfold the heads and the coordinate dimension again
        out = rearrange(out, 'b h n (d c) -> b n (h d) c', c = c)
        return self.to_out(out)
# VNFeedForward: VNLinear -> VNReLU -> VNLinear
def VNFeedForward(dim, mult = 4, bias_epsilon = 0.):
    dim_inner = int(dim * mult)
    return nn.Sequential(
        VNLinear(dim, dim_inner, bias_epsilon = bias_epsilon),
        VNReLU(dim_inner),
        VNLinear(dim_inner, dim, bias_epsilon = bias_epsilon)
    )
# VNLayerNorm: layernorm the vector norms, then rescale the unit vectors
class VNLayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-6):
        super().__init__()
        self.eps = eps
        self.ln = LayerNorm(dim)
    def forward(self, x):
        norms = x.norm(dim = -1)
        x = x / rearrange(norms.clamp(min = self.eps), '... -> ... 1')
        ln_out = self.ln(norms)
        return x * rearrange(ln_out, '... -> ... 1')
# VNWeightedPool: masked mean pool over the sequence, followed by a learned projection per pooled token
class VNWeightedPool(nn.Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        num_pooled_tokens = 1,
        squeeze_out_pooled_dim = True
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.weight = nn.Parameter(torch.randn(num_pooled_tokens, dim, dim_out))
        self.squeeze_out_pooled_dim = num_pooled_tokens == 1 and squeeze_out_pooled_dim
    def forward(self, x, mask = None):
        if exists(mask):
            # masked mean over the sequence
            mask = rearrange(mask, 'b n -> b n 1 1')
            x = x.masked_fill(~mask, 0.)
            numer = reduce(x, 'b n d c -> b d c', 'sum')
            denom = mask.sum(dim = 1)
            mean_pooled = numer / denom.clamp(min = 1e-6)
        else:
            mean_pooled = reduce(x, 'b n d c -> b d c', 'mean')
        out = einsum('b d c, m d e -> b m e c', mean_pooled, self.weight)
        if not self.squeeze_out_pooled_dim:
            return out
        out = rearrange(out, 'b 1 d c -> b d c')
        return out
# VNTransformerEncoder: a stack of VNAttention + VNFeedForward blocks with post-layernorms
class VNTransformerEncoder(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head = 64,
        heads = 8,
        dim_coor = 3,
        ff_mult = 4,
        final_norm = False,
        bias_epsilon = 0.,
        l2_dist_attn = False,
        flash_attn = False
    ):
        super().__init__()
        self.dim = dim
        self.dim_coor = dim_coor
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                VNAttention(dim = dim, dim_head = dim_head, heads = heads, bias_epsilon = bias_epsilon, l2_dist_attn = l2_dist_attn, flash = flash_attn),
                VNLayerNorm(dim),
                VNFeedForward(dim = dim, mult = ff_mult, bias_epsilon = bias_epsilon),
                VNLayerNorm(dim)
            ]))
        self.norm = VNLayerNorm(dim) if final_norm else nn.Identity()
    def forward(
        self,
        x,
        mask = None
    ):
        *_, d, c = x.shape
        assert x.ndim == 4 and d == self.dim and c == self.dim_coor, f'input needs to be in the shape of (batch, seq, dim ({self.dim}), coordinate dim ({self.dim_coor}))'
        for attn, attn_post_ln, ff, ff_post_ln in self.layers:
            x = attn_post_ln(attn(x, mask = mask)) + x
            x = ff_post_ln(ff(x)) + x
        return self.norm(x)
# VNInvariant: map equivariant vector features to rotation-invariant outputs
class VNInvariant(nn.Module):
    def __init__(
        self,
        dim,
        dim_coor = 3,
    ):
        super().__init__()
        self.mlp = nn.Sequential(
            VNLinear(dim, dim_coor),
            VNReLU(dim_coor),
            Rearrange('... d e -> ... e d')
        )
    def forward(self, x):
        return einsum('b n d i, b n i o -> b n o', x, self.mlp(x))
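The contraction above is over the coordinate index: when the input is rotated, both x and the MLP branch pick up the same rotation R, which cancels (R Rᵀ = I), leaving per-token rotation-invariant scalars. A small, illustrative check (double precision just keeps the tolerance tight):

from VN_transformer.rotations import rot
inv = VNInvariant(64).double()
x = torch.randn(2, 16, 64, 3, dtype = torch.double)
R = rot(*torch.randn(3, dtype = torch.double))
print(inv(x).shape)                                    # torch.Size([2, 16, 3])
assert torch.allclose(inv(x @ R), inv(x), atol = 1e-6)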
# VNTransformer: full model wrapping the encoder with input/output projections and optional early feature fusion
class VNTransformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        num_tokens = None,
        dim_feat = None,
        dim_head = 64,
        heads = 8,
        dim_coor = 3,
        reduce_dim_out = True,
        bias_epsilon = 0.,
        l2_dist_attn = False,
        flash_attn = False,
        translation_equivariance = False,
        translation_invariant = False
    ):
        super().__init__()
        # optional embedding, used when features are given as token indices
        self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None
        dim_feat = default(dim_feat, 0)
        self.dim_feat = dim_feat
        # total per-token width: the 3 coordinates plus any early-fused feature channels
        self.dim_coor_total = dim_coor + dim_feat
        # at most one of translation equivariance / invariance may be enabled
        assert (int(translation_equivariance) + int(translation_invariant)) <= 1
        self.translation_equivariance = translation_equivariance
        self.translation_invariant = translation_invariant
        # project each scalar channel up to dim vector-neuron channels
        self.vn_proj_in = nn.Sequential(
            Rearrange('... c -> ... 1 c'),
            VNLinear(1, dim, bias_epsilon = bias_epsilon)
        )
        self.encoder = VNTransformerEncoder(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            bias_epsilon = bias_epsilon,
            dim_coor = self.dim_coor_total,
            l2_dist_attn = l2_dist_attn,
            flash_attn = flash_attn
        )
        # optionally project the vector-neuron channels back down to a single channel
        if reduce_dim_out:
            self.vn_proj_out = nn.Sequential(
                VNLayerNorm(dim),
                VNLinear(dim, 1, bias_epsilon = bias_epsilon),
                Rearrange('... 1 c -> ... c')
            )
        else:
            self.vn_proj_out = nn.Identity()
    def forward(
        self,
        coors,
        *,
        feats = None,
        mask = None,
        return_concatted_coors_and_feats = False
    ):
        # for translation equivariance or invariance, center the coordinates first
        if self.translation_equivariance or self.translation_invariant:
            coors_mean = reduce(coors, '... c -> c', 'mean')
            coors = coors - coors_mean
        x = coors
        # early fusion: concatenate the features onto the coordinates
        if exists(feats):
            if feats.dtype == torch.long:
                assert exists(self.token_emb), 'num_tokens must be given to the VNTransformer (to build the Embedding), if the features are to be given as indices'
                feats = self.token_emb(feats)
            assert feats.shape[-1] == self.dim_feat, f'dim_feat should be set to {feats.shape[-1]}'
            x = torch.cat((x, feats), dim = -1)
        assert x.shape[-1] == self.dim_coor_total
        # project in, encode, project out
        x = self.vn_proj_in(x)
        x = self.encoder(x, mask = mask)
        x = self.vn_proj_out(x)
        # split back into coordinates and features
        coors_out, feats_out = x[..., :3], x[..., 3:]
        # for translation equivariance, add the mean back
        if self.translation_equivariance:
            coors_out = coors_out + coors_mean
        # with no features, return just the coordinates
        if not exists(feats):
            return coors_out
        # otherwise return coordinates and features, concatenated or as a tuple
        if return_concatted_coors_and_feats:
            return torch.cat((coors_out, feats_out), dim = -1)
        return coors_out, feats_out
.\lucidrains\VN-transformer\VN_transformer\__init__.py
# re-export the public classes and functions of the package
from VN_transformer.VN_transformer import (
    VNTransformer,
    VNLinear,
    VNLayerNorm,
    VNFeedForward,
    VNAttention,
    VNWeightedPool,
    VNTransformerEncoder,
    VNInvariant
)