Lucidrains Series Projects Source Code Analysis (111)

.\lucidrains\vit-pytorch\vit_pytorch\vit_with_patch_dropout.py

# import torch
import torch
# import the nn module from torch
from torch import nn

# import the rearrange and repeat functions from einops, and the Rearrange layer
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helper functions

# if the input t is already a tuple, return it, otherwise return a 2-tuple of t
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

# PatchDropout: randomly drops a fraction of the patch tokens during training
class PatchDropout(nn.Module):
    # prob is the fraction of patches to drop
    def __init__(self, prob):
        super().__init__()
        # the probability must lie in [0, 1)
        assert 0 <= prob < 1.
        self.prob = prob

    # forward pass
    def forward(self, x):
        # no-op at inference time or when the drop probability is 0
        if not self.training or self.prob == 0.:
            return x

        # unpack the shape of x
        b, n, _, device = *x.shape, x.device

        # batch indices for advanced indexing
        batch_indices = torch.arange(b, device = device)
        batch_indices = rearrange(batch_indices, '... -> ... 1')
        # number of patches to keep
        num_patches_keep = max(1, int(n * (1 - self.prob)))
        # randomly choose which patch indices to keep, independently per sample
        patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices

        return x[batch_indices, patch_indices_keep]
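
A quick illustrative sketch (not part of the original file) of what PatchDropout does to the token count; the shapes below are made-up examples:

dropout = PatchDropout(0.25)
dropout.train()

tokens = torch.randn(2, 64, 128)   # (batch, num patches, dim)
kept = dropout(tokens)             # (2, 48, 128) -- 25% of the patches are dropped per sample

dropout.eval()
assert dropout(tokens).shape == tokens.shape   # no-op at inference time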

# FeedForward: pre-norm MLP block
class FeedForward(nn.Module):
    # dim: model dimension, hidden_dim: inner dimension, dropout: dropout rate
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # network: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    # forward pass
    def forward(self, x):
        return self.net(x)

# Attention: multi-head self-attention
class Attention(nn.Module):
    # dim: model dimension, heads: number of heads, dim_head: dimension per head, dropout: dropout rate
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    # forward pass
    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# Transformer: a stack of attention and feedforward blocks with residual connections
class Transformer(nn.Module):
    # dim, depth, heads, dim_head, mlp_dim and dropout configure the encoder
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        # create `depth` transformer layers
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
    # forward pass
    def forward(self, x):
        # iterate over the transformer layers
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

# ViT with patch dropout
class ViT(nn.Module):
    # initialize the model hyperparameters
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., patch_dropout = 0.25):
        super().__init__()
        # image height and width
        image_height, image_width = pair(image_size)
        # patch height and width
        patch_height, patch_width = pair(patch_size)

        # the image dimensions must be divisible by the patch size
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # number of patches
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # dimension of each flattened patch
        patch_dim = channels * patch_height * patch_width
        # pooling must be either 'cls' (cls token) or 'mean' (mean pooling)
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # convert the image into patch embeddings
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        # learned positional embedding (one per patch)
        self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))
        # learned CLS token
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        # patch dropout layer
        self.patch_dropout = PatchDropout(patch_dropout)
        # embedding dropout layer
        self.dropout = nn.Dropout(emb_dropout)

        # transformer encoder
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        # pooling type
        self.pool = pool
        # projection to the latent space (identity here)
        self.to_latent = nn.Identity()

        # MLP classification head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    # forward pass
    def forward(self, img):
        # convert the image into patch embeddings
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # add the positional embedding
        x += self.pos_embedding

        # drop a fraction of the patches (training only)
        x = self.patch_dropout(x)

        # repeat the CLS token across the batch
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)

        # prepend the CLS token to the patch tokens
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.dropout(x)

        # run the transformer
        x = self.transformer(x)

        # pool: mean pooling or take the CLS token
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        # project to the latent space and classify with the MLP head
        x = self.to_latent(x)
        return self.mlp_head(x)
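
A minimal usage sketch (not in the original file); the hyperparameters below are illustrative only:

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    patch_dropout = 0.25   # fraction of patch tokens dropped during training
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)   # (1, 1000)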

.\lucidrains\vit-pytorch\vit_pytorch\vit_with_patch_merger.py

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce

# helpers

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return val if it exists, otherwise the default d
def default(val, d):
    return val if exists(val) else d

# cast the input to a 2-tuple
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# patch merger class

# PatchMerger: learned queries attend over the patch tokens and merge them into num_tokens_out tokens
class PatchMerger(nn.Module):
    def __init__(self, dim, num_tokens_out):
        super().__init__()
        self.scale = dim ** -0.5
        self.norm = nn.LayerNorm(dim)
        self.queries = nn.Parameter(torch.randn(num_tokens_out, dim))

    def forward(self, x):
        x = self.norm(x)
        sim = torch.matmul(self.queries, x.transpose(-1, -2)) * self.scale
        attn = sim.softmax(dim = -1)
        return torch.matmul(attn, x)
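
A small sketch (not in the original file) showing how PatchMerger reduces the token count; the shapes are illustrative:

merger = PatchMerger(dim = 256, num_tokens_out = 8)

tokens = torch.randn(2, 64, 256)   # (batch, num patches, dim)
merged = merger(tokens)            # (2, 8, 256) -- 64 tokens merged down to 8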

# classes

# FeedForward: pre-norm MLP block
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# Attention: multi-head self-attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# Transformer: attention/feedforward stack with a PatchMerger inserted at a chosen layer
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., patch_merge_layer = None, patch_merge_num_tokens = 8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])

        self.patch_merge_layer_index = default(patch_merge_layer, depth // 2) - 1 # default to mid-way through transformer, as shown in paper
        self.patch_merger = PatchMerger(dim = dim, num_tokens_out = patch_merge_num_tokens)

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
    def forward(self, x):
        for index, (attn, ff) in enumerate(self.layers):
            x = attn(x) + x
            x = ff(x) + x

            if index == self.patch_merge_layer_index:
                x = self.patch_merger(x)

        return self.norm(x)

class ViT(nn.Module):
    # initialize the model hyperparameters
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, patch_merge_layer = None, patch_merge_num_tokens = 8, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        # image height and width
        image_height, image_width = pair(image_size)
        # patch height and width
        patch_height, patch_width = pair(patch_size)

        # the image dimensions must be divisible by the patch size
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # number of patches
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # dimension of each flattened patch
        patch_dim = channels * patch_height * patch_width

        # patch embedding: rearrange the image into patches, normalize, and project to dim
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        # learned positional embedding
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # embedding dropout
        self.dropout = nn.Dropout(emb_dropout)

        # transformer (with the patch merger inside)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, patch_merge_layer, patch_merge_num_tokens)

        # MLP head: mean pool over the remaining tokens, then classify
        self.mlp_head = nn.Sequential(
            Reduce('b n d -> b d', 'mean'),
            nn.Linear(dim, num_classes)
        )

    # forward pass
    def forward(self, img):
        # convert the image into patch embeddings
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # add the positional embedding to the patch embeddings
        x += self.pos_embedding[:, :n]
        x = self.dropout(x)

        # run the transformer (the tokens are merged part-way through)
        x = self.transformer(x)

        # classify with the MLP head
        return self.mlp_head(x)
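
A minimal usage sketch (not in the original file); the hyperparameters are illustrative:

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 12,
    heads = 8,
    mlp_dim = 2048,
    patch_merge_layer = 6,        # merge the patch tokens after the 6th layer
    patch_merge_num_tokens = 8    # merge down to 8 tokens
)

img = torch.randn(4, 3, 256, 256)
preds = v(img)   # (4, 1000)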

.\lucidrains\vit-pytorch\vit_pytorch\vivit.py

# import torch
import torch
# import the nn module from torch
from torch import nn

# import rearrange, repeat, reduce from einops
from einops import rearrange, repeat, reduce
# import the Rearrange layer from einops.layers.torch
from einops.layers.torch import Rearrange

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# cast the input to a 2-tuple
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

# feedforward block
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# attention block
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# transformer encoder
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

# ViViT: factorized spatial / temporal vision transformer for video
class ViT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        image_patch_size,
        frames,
        frame_patch_size,
        num_classes,
        dim,
        spatial_depth,
        temporal_depth,
        heads,
        mlp_dim,
        pool = 'cls',
        channels = 3,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.
    ):
        super().__init__()
        # unpack the image size and image patch size
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(image_patch_size)

        # the image dimensions must be divisible by the patch size
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        # the number of frames must be divisible by the frame patch size
        assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'

        # number of spatial patches and temporal (frame) patches
        num_image_patches = (image_height // patch_height) * (image_width // patch_width)
        num_frame_patches = (frames // frame_patch_size)

        # dimension of each flattened spatio-temporal patch
        patch_dim = channels * patch_height * patch_width * frame_patch_size

        # pooling must be either 'cls' (cls token) or 'mean' (mean pooling)
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # whether to use global average pooling instead of CLS tokens
        self.global_average_pool = pool == 'mean'

        # tubelet embedding: rearrange the video into patches, normalize, and project to dim
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        # learned positional embedding over (frame patches, image patches)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_frame_patches, num_image_patches, dim))
        self.dropout = nn.Dropout(emb_dropout)

        # spatial and temporal CLS tokens (only when not using global average pooling)
        self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
        self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None

        # spatial and temporal transformers
        self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
        self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)

        # pooling type and projection to the latent space
        self.pool = pool
        self.to_latent = nn.Identity()

        # linear classification head
        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, video):
        # convert the video to patch embeddings: (batch, frame patches, spatial patches, dim)
        x = self.to_patch_embedding(video)
        b, f, n, _ = x.shape

        # add the positional embedding
        x = x + self.pos_embedding[:, :f, :n]

        # prepend the spatial CLS token, if used
        if exists(self.spatial_cls_token):
            spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
            x = torch.cat((spatial_cls_tokens, x), dim = 2)

        # embedding dropout
        x = self.dropout(x)

        # fold the frame axis into the batch for spatial attention
        x = rearrange(x, 'b f n d -> (b f) n d')

        # attend across space
        x = self.spatial_transformer(x)

        # unfold the frame axis back out of the batch
        x = rearrange(x, '(b f) n d -> b f n d', b = b)

        # take the spatial CLS token, or global average pool, to get one token per frame
        x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')

        # prepend the temporal CLS token, if used
        if exists(self.temporal_cls_token):
            temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d -> b 1 d', b = b)
            x = torch.cat((temporal_cls_tokens, x), dim = 1)

        # attend across time
        x = self.temporal_transformer(x)

        # take the temporal CLS token, or global average pool
        x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')

        # project to the latent space and classify with the linear head
        x = self.to_latent(x)
        return self.mlp_head(x)
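
A minimal usage sketch (not in the original file); the hyperparameters are illustrative:

v = ViT(
    image_size = 128,
    frames = 16,
    image_patch_size = 16,
    frame_patch_size = 2,
    num_classes = 1000,
    dim = 1024,
    spatial_depth = 6,
    temporal_depth = 6,
    heads = 8,
    mlp_dim = 2048
)

video = torch.randn(4, 3, 16, 128, 128)   # (batch, channels, frames, height, width)
preds = v(video)                          # (4, 1000)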

.\lucidrains\vit-pytorch\vit_pytorch\xcit.py

# import randrange from the random module
from random import randrange

# import torch and related submodules
import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
import torch.nn.functional as F

# import einops functions and layers
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# pack a single tensor according to the given pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# unpack a single tensor according to the given pattern
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# L2-normalize a tensor along its last dimension
def l2norm(t):
    return F.normalize(t, dim = -1, p = 2)

# layer dropout (stochastic depth): randomly drop whole layers with probability `dropout`
def dropout_layers(layers, dropout):
    if dropout == 0:
        return layers

    num_layers = len(layers)
    to_drop = torch.zeros(num_layers).uniform_(0., 1.) < dropout

    # make sure at least one layer is kept
    if all(to_drop):
        rand_index = randrange(num_layers)
        to_drop[rand_index] = False

    layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
    return layers
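
A tiny illustration (not in the original file) of how dropout_layers behaves, using placeholder modules:

layers = [nn.Identity() for _ in range(6)]
kept = dropout_layers(layers, dropout = 0.5)
print(len(kept))   # roughly half the layers survive, and at least one is always kept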

# classes

# LayerScale: scales the output of the wrapped module by a learned per-channel factor
class LayerScale(Module):
    def __init__(self, dim, fn, depth):
        super().__init__()
        # the initial scale depends on the depth, following the paper
        if depth <= 18:
            init_eps = 0.1
        elif 18 < depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        self.fn = fn
        self.scale = nn.Parameter(torch.full((dim,), init_eps))

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

# FeedForward: pre-norm MLP block
class FeedForward(Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# Attention: standard multi-head attention with optional extra context (used for class attention)
class Attention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context = None):
        h = self.heads

        x = self.norm(x)
        context = x if not exists(context) else torch.cat((x, context), dim = 1)

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(sim)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# XCAttention: cross-covariance attention (attention across feature channels instead of tokens)
class XCAttention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.norm = nn.LayerNorm(dim)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )
    # forward pass
    def forward(self, x):
        # number of heads
        h = self.heads
        # pack x into shape (b, n, d), remembering the packing so it can be undone later
        x, ps = pack_one(x, 'b * d')

        # pre-norm
        x = self.norm(x)
        # project to queries, keys, values and split along the last dimension
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        # rearrange so that attention is computed across the feature dimension
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h d n', h=h), (q, k, v))

        # L2-normalize the queries and keys
        q, k = map(l2norm, (q, k))

        # cross-covariance similarity, scaled by a learned temperature
        sim = einsum('b h i n, b h j n -> b h i j', q, k) * self.temperature.exp()

        # attention weights
        attn = self.attend(sim)
        # attention dropout
        attn = self.dropout(attn)

        # aggregate the values
        out = einsum('b h i j, b h j n -> b h i n', attn, v)
        # rearrange back to (b, n, h * d)
        out = rearrange(out, 'b h d n -> b n (h d)')

        # undo the earlier packing
        out = unpack_one(out, ps, 'b * d')
        # output projection
        return self.to_out(out)

# LocalPatchInteraction: depthwise convolutions over the 2d grid of patch tokens
class LocalPatchInteraction(Module):
    def __init__(self, dim, kernel_size = 3):
        super().__init__()

        # the kernel size must be odd so that padding preserves the spatial size
        assert (kernel_size % 2) == 1
        padding = kernel_size // 2

        self.net = nn.Sequential(
            # layer norm over the channel dimension
            nn.LayerNorm(dim),
            # move channels to the front for the convolutions
            Rearrange('b h w c -> b c h w'),
            # depthwise convolution
            nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
            nn.BatchNorm2d(dim),
            nn.GELU(),
            # second depthwise convolution
            nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
            # move channels back to the last dimension
            Rearrange('b c h w -> b h w c'),
        )

    def forward(self, x):
        return self.net(x)

# Transformer used for the class-attention stage
class Transformer(Module):
    # dim, depth, heads, dim_head, mlp_dim, dropout and layer_dropout configure the stack
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = ModuleList([])

        # layer dropout (stochastic depth) rate
        self.layer_dropout = layer_dropout

        for ind in range(depth):
            layer = ind + 1
            # each layer: LayerScale-wrapped attention and feedforward
            self.layers.append(ModuleList([
                LayerScale(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = layer),
                LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = layer)
            ]))

    def forward(self, x, context = None):
        # randomly drop whole layers during training (layer dropout)
        layers = dropout_layers(self.layers, dropout = self.layer_dropout)

        for attn, ff in layers:
            # attention with a residual connection
            x = attn(x, context = context) + x
            # feedforward with a residual connection
            x = ff(x) + x

        return x

# XCATransformer: a stack of cross-covariance attention, local patch interaction and feedforward blocks
class XCATransformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size = 3, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = ModuleList([])

        # layer dropout (stochastic depth) rate
        self.layer_dropout = layer_dropout

        for ind in range(depth):
            layer = ind + 1
            # each layer: cross-covariance attention, local patch interaction and feedforward, all LayerScale-wrapped
            self.layers.append(ModuleList([
                LayerScale(dim, XCAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = layer),
                LayerScale(dim, LocalPatchInteraction(dim, local_patch_kernel_size), depth = layer),
                LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = layer)
            ]))

    def forward(self, x):
        # randomly drop whole layers during training (layer dropout)
        layers = dropout_layers(self.layers, dropout = self.layer_dropout)

        for cross_covariance_attn, local_patch_interaction, ff in layers:
            # cross-covariance attention with a residual connection
            x = cross_covariance_attn(x) + x
            # local patch interaction with a residual connection
            x = local_patch_interaction(x) + x
            # feedforward with a residual connection
            x = ff(x) + x

        return x

# XCiT: cross-covariance image transformer
class XCiT(Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        cls_depth,
        heads,
        mlp_dim,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.,
        local_patch_kernel_size = 3,
        layer_dropout = 0.
    ):
        super().__init__()
        # the image size must be divisible by the patch size
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'

        # number of patches and dimension of each flattened patch
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2

        # patch embedding: rearrange the image into a 2d grid of patches, normalize, project to dim
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        # learned positional embedding and CLS token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        self.cls_token = nn.Parameter(torch.randn(dim))

        # embedding dropout
        self.dropout = nn.Dropout(emb_dropout)

        # cross-covariance attention transformer over the patch tokens
        self.xcit_transformer = XCATransformer(dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size, dropout, layer_dropout)

        # final norm for the patch tokens
        self.final_norm = nn.LayerNorm(dim)

        # class-attention transformer: the CLS token attends over the patch tokens
        self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout)

        # MLP classification head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )
    # forward pass: process an input image
    def forward(self, img):
        # convert the image into a 2d grid of patch embeddings
        x = self.to_patch_embedding(img)

        # flatten the patch grid into a sequence
        x, ps = pack_one(x, 'b * d')

        # add the positional embedding
        b, n, _ = x.shape
        x += self.pos_embedding[:, :n]

        # restore the 2d patch grid
        x = unpack_one(x, ps, 'b * d')

        # embedding dropout
        x = self.dropout(x)

        # run the cross-covariance attention transformer on the patch grid
        x = self.xcit_transformer(x)

        # final norm
        x = self.final_norm(x)

        # repeat the CLS token across the batch
        cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b = b)

        # flatten the patch grid into a sequence to serve as context for class attention
        x = rearrange(x, 'b ... d -> b (...) d')
        # class-attention transformer: the CLS token attends over the patch tokens
        cls_tokens = self.cls_transformer(cls_tokens, context = x)

        # classify from the CLS token with the MLP head
        return self.mlp_head(cls_tokens[:, 0])
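
A minimal usage sketch (not in the original file); the hyperparameters are illustrative:

v = XCiT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 12,       # depth of the cross-covariance attention stage
    cls_depth = 2,    # depth of the class-attention stage
    heads = 16,
    mlp_dim = 2048,
    layer_dropout = 0.05
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)   # (1, 1000)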

.\lucidrains\vit-pytorch\vit_pytorch\__init__.py

# import the ViT class from vit_pytorch.vit
from vit_pytorch.vit import ViT
# import the SimpleViT class from vit_pytorch.simple_vit
from vit_pytorch.simple_vit import SimpleViT

# import the MAE class from vit_pytorch.mae
from vit_pytorch.mae import MAE
# import the Dino class from vit_pytorch.dino
from vit_pytorch.dino import Dino

.\lucidrains\VN-transformer\denoise.py

# import PyTorch
import torch
# functional API
import torch.nn.functional as F
# Adam optimizer
from torch.optim import Adam

# import rearrange and repeat from einops
from einops import rearrange, repeat

# import sidechainnet, and the VNTransformer class from VN_transformer
import sidechainnet as scn
from VN_transformer import VNTransformer

# constants
BATCH_SIZE = 1
GRADIENT_ACCUMULATE_EVERY = 16
MAX_SEQ_LEN = 256
DEFAULT_TYPE = torch.float64

# set the PyTorch default dtype
torch.set_default_dtype(DEFAULT_TYPE)

# infinite data generator, skipping sequences longer than len_thres
def cycle(loader, len_thres = MAX_SEQ_LEN):
    while True:
        for data in loader:
            # skip batches whose sequence length exceeds the threshold
            if data.seqs.shape[1] > len_thres:
                continue
            # yield the batch
            yield data

# instantiate the VNTransformer model
transformer = VNTransformer(
    num_tokens = 24,
    dim = 64,
    depth = 4,
    dim_head = 64,
    heads = 8,
    dim_feat = 64,
    bias_epsilon = 1e-6,
    l2_dist_attn = True,
    flash_attn = False
).cuda()

# load the sidechainnet dataset
data = scn.load(
    casp_version = 12,
    thinning = 30,
    with_pytorch = 'dataloaders',
    batch_size = BATCH_SIZE,
    dynamic_batching = False
)

# data generator over the training split
dl = cycle(data['train'])
# Adam optimizer
optim = Adam(transformer.parameters(), lr = 1e-4)

# training loop
for _ in range(10000):
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        # fetch one batch
        batch = next(dl)
        seqs, coords, masks = batch.seqs, batch.crds, batch.msks

        # move sequences to the GPU and convert from one-hot to indices
        seqs = seqs.cuda().argmax(dim = -1)
        # move coordinates to the GPU and cast to the default dtype
        coords = coords.cuda().type(torch.get_default_dtype())
        # move masks to the GPU and cast to bool
        masks = masks.cuda().bool()

        # sequence length
        l = seqs.shape[1]
        # reshape coordinates to (batch, residues, atoms, 3); sidechainnet stores 14 atoms per residue
        coords = rearrange(coords, 'b (l s) c -> b l s c', s = 14)

        # keep only the backbone atoms
        coords = coords[:, :, 0:3, :]
        coords = rearrange(coords, 'b l s c -> b (l s) c')

        # repeat the sequence and mask so there is one entry per backbone atom
        seq = repeat(seqs, 'b n -> b (n c)', c = 3)
        masks = repeat(masks, 'b n -> b (n c)', c = 3)

        # add gaussian noise to the coordinates
        noised_coords = coords + torch.randn_like(coords).cuda()

        # run the transformer
        type1_out, _ = transformer(
            noised_coords,
            feats = seq,
            mask = masks
        )

        # denoised coordinates
        denoised_coords = noised_coords + type1_out

        # mean squared error loss on the masked positions
        loss = F.mse_loss(denoised_coords[masks], coords[masks])
        # backpropagate the (scaled) loss
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    # print the current loss
    print('loss:', loss.item())
    # optimizer step
    optim.step()
    # clear the gradients
    optim.zero_grad()

VN (Vector Neuron) Transformer

A Transformer made of Rotation-equivariant Attention using Vector Neurons

Open Review

Appreciation

  • StabilityAI for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.

Install

$ pip install VN-transformer

Usage

import torch
from VN_transformer import VNTransformer

model = VNTransformer(
    dim = 64,
    depth = 2,
    dim_head = 64,
    heads = 8,
    dim_feat = 64,       # will default to early fusion, since this was the best performing
    bias_epsilon = 1e-6  # in this paper, they propose breaking equivariance with a tiny bit of bias noise in the VN linear. they claim this leads to improved stability. setting this to 0 would turn off the epsilon approximate equivariance
)

coors = torch.randn(1, 32, 3)    # (batch, sequence, spatial coordinates)
feats = torch.randn(1, 32, 64)

coors_out, feats_out = model(coors, feats = feats) # (1, 32, 3), (1, 32, 64)

Tests

Confidence in equivariance

$ python setup.py test

Example

First install sidechainnet

$ pip install sidechainnet

Then run the protein backbone denoising task

$ python denoise.py

It does not perform as well as the En-Transformer or Equiformer

Citations

@inproceedings{Assaad2022VNTransformerRA,
    title   = {VN-Transformer: Rotation-Equivariant Attention for Vector Neurons},
    author  = {Serge Assaad and C. Downey and Rami Al-Rfou and Nigamaa Nayakanti and Benjamin Sapp},
    year    = {2022}
}
@article{Deng2021VectorNA,
    title   = {Vector Neurons: A General Framework for SO(3)-Equivariant Networks},
    author  = {Congyue Deng and Or Litany and Yueqi Duan and Adrien Poulenard and Andrea Tagliasacchi and Leonidas J. Guibas},
    journal = {2021 IEEE/CVF International Conference on Computer Vision (ICCV)},
    year    = {2021},
    pages   = {12180-12189},
    url     = {https://api.semanticscholar.org/CorpusID:233394028}
}
@inproceedings{Kim2020TheLC,
    title   = {The Lipschitz Constant of Self-Attention},
    author  = {Hyunjik Kim and George Papamakarios and Andriy Mnih},
    booktitle = {International Conference on Machine Learning},
    year    = {2020},
    url     = {https://api.semanticscholar.org/CorpusID:219530837}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}

.\lucidrains\VN-transformer\setup.py

# import the setup and find_packages helpers
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'VN-transformer',  # package name
  packages = find_packages(exclude=[]),  # find all packages
  version = '0.1.0',  # version
  license='MIT',  # license
  description = 'Vector Neuron Transformer (VN-Transformer)',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  long_description_content_type = 'text/markdown',  # long description content type
  url = 'https://github.com/lucidrains/VN-transformer',  # project url
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'equivariance',
    'vector neurons',
    'transformers',
    'attention mechanism'
  ],
  install_requires=[  # runtime dependencies
    'einops>=0.6.0',
    'torch>=1.6'
  ],
  setup_requires=[  # setup-time dependencies
    'pytest-runner',
  ],
  tests_require=[  # test dependencies
    'pytest'
  ],
  classifiers=[  # trove classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.8',
  ],
)

.\lucidrains\VN-transformer\tests\test.py

# import pytest
import pytest

# import torch
import torch
# import VNTransformer, VNInvariant, VNAttention and the rot helper
from VN_transformer.VN_transformer import VNTransformer, VNInvariant, VNAttention
from VN_transformer.rotations import rot

# use float64 so the equivariance checks are numerically tight
torch.set_default_dtype(torch.float64)

# test the invariant layer
def test_vn_invariant():
    # VNInvariant layer with feature dimension 64
    layer = VNInvariant(64)

    # random tensor of shape (batch, seq, dim, 3)
    coors = torch.randn(1, 32, 64, 3)

    # random rotation matrix R
    R = rot(*torch.randn(3))
    # run the layer on the original input and on the rotated input
    out1 = layer(coors)
    out2 = layer(coors @ R)

    # the two outputs must match within tolerance (invariance)
    assert torch.allclose(out1, out2, atol = 1e-6)

# test equivariance of the full transformer
@pytest.mark.parametrize('l2_dist_attn', [True, False])
def test_equivariance(l2_dist_attn):

    # VNTransformer model
    model = VNTransformer(
        dim = 64,
        depth = 2,
        dim_head = 64,
        heads = 8,
        l2_dist_attn = l2_dist_attn
    )

    # random coordinates of shape (batch, seq, 3)
    coors = torch.randn(1, 32, 3)
    # all-True mask of shape (batch, seq)
    mask  = torch.ones(1, 32).bool()

    # random rotation matrix R
    R   = rot(*torch.randn(3))
    # rotating the input then running the model must equal running the model then rotating the output
    out1 = model(coors @ R, mask = mask)
    out2 = model(coors, mask = mask) @ R

    # the two outputs must match within tolerance (equivariance)
    assert torch.allclose(out1, out2, atol = 1e-6), 'is not equivariant'

# test equivariance of perceiver-like VN attention (latents attending to the sequence)
@pytest.mark.parametrize('l2_dist_attn', [True, False])
def test_perceiver_vn_attention_equivariance(l2_dist_attn):

    # VNAttention model with 2 latents
    model = VNAttention(
        dim = 64,
        dim_head = 64,
        heads = 8,
        num_latents = 2,
        l2_dist_attn = l2_dist_attn
    )

    # random tensor of shape (batch, seq, dim, 3)
    coors = torch.randn(1, 32, 64, 3)
    # all-True mask of shape (batch, seq)
    mask  = torch.ones(1, 32).bool()

    # random rotation matrix R
    R   = rot(*torch.randn(3))
    # rotate-then-run must equal run-then-rotate
    out1 = model(coors @ R, mask = mask)
    out2 = model(coors, mask = mask) @ R

    # the output should contain 2 latent tokens
    assert out1.shape[1] == 2
    # the two outputs must match within tolerance (equivariance)
    assert torch.allclose(out1, out2, atol = 1e-6), 'is not equivariant'

# test SO(3) equivariance with early fusion of features
@pytest.mark.parametrize('l2_dist_attn', [True, False])
def test_equivariance_with_early_fusion(l2_dist_attn):

    # VNTransformer model with a feature dimension for early fusion
    model = VNTransformer(
        dim = 64,
        depth = 2,
        dim_head = 64,
        heads = 8,
        dim_feat = 64,
        l2_dist_attn = l2_dist_attn
    )

    # random features of shape (batch, seq, dim_feat)
    feats = torch.randn(1, 32, 64)
    # random coordinates of shape (batch, seq, 3)
    coors = torch.randn(1, 32, 3)
    # all-True mask of shape (batch, seq)
    mask  = torch.ones(1, 32).bool()

    # random rotation matrix R
    R   = rot(*torch.randn(3))
    # rotate-then-run must equal run-then-rotate
    out1, _ = model(coors @ R, feats = feats, mask = mask, return_concatted_coors_and_feats = False)

    out2, _ = model(coors, feats = feats, mask = mask, return_concatted_coors_and_feats = False)
    out2 = out2 @ R

    # the two outputs must match within tolerance (equivariance)
    assert torch.allclose(out1, out2, atol = 1e-6), 'is not equivariant'

# test SE(3) equivariance (rotation and translation) with early fusion
@pytest.mark.parametrize('l2_dist_attn', [True, False])
def test_se3_equivariance_with_early_fusion(l2_dist_attn):

    # VNTransformer model with translation equivariance enabled
    model = VNTransformer(
        dim = 64,
        depth = 2,
        dim_head = 64,
        heads = 8,
        dim_feat = 64,
        translation_equivariance = True,
        l2_dist_attn = l2_dist_attn
    )

    # random features of shape (batch, seq, dim_feat)
    feats = torch.randn(1, 32, 64)
    # random coordinates of shape (batch, seq, 3)
    coors = torch.randn(1, 32, 3)
    # all-True mask of shape (batch, seq)
    mask  = torch.ones(1, 32).bool()

    # random translation T and rotation R
    T   = torch.randn(3)
    R   = rot(*torch.randn(3))
    # transform-then-run must equal run-then-transform
    out1, _ = model((coors + T) @ R, feats = feats, mask = mask, return_concatted_coors_and_feats = False)

    out2, _ = model(coors, feats = feats, mask = mask, return_concatted_coors_and_feats = False)
    out2 = (out2 + T) @ R

    # the two outputs must match within tolerance (equivariance)
    assert torch.allclose(out1, out2, atol = 1e-6), 'is not equivariant'

.\lucidrains\VN-transformer\VN_transformer\attend.py

# imports
from functools import wraps
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, reduce

# config namedtuple with three boolean flags for the sdp kernel
FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helper: check whether a value exists (is not None)
def exists(val):
    return val is not None

# decorator: make sure the wrapped function only runs once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# a print that only fires once
print_once = once(print)

# main Attend class
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        flash = False,
        l2_dist = False
    ):
        super().__init__()
        assert not (flash and l2_dist), 'flash attention is not compatible with l2 distance'
        self.l2_dist = l2_dist

        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = FlashAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = FlashAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = FlashAttentionConfig(False, True, True)

    # flash attention path
    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # if a mask is given, expand it to a shape compatible with scaled_dot_product_attention
        # the mask arrives as B 1 1 J and is expanded to B H N J

        if exists(mask):
            mask = mask.expand(-1, heads, q_len, -1)

        # pick the attention config depending on whether the tensor is on cuda

        config = self.cuda_config if is_cuda else self.cpu_config

        # call pytorch 2.0 flash / memory-efficient attention through the sdp_kernel context manager
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.
            )

        return out
    # forward pass over queries (q), keys (k), values (v) and an optional mask
    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # sequence lengths of the queries and keys, and the device
        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        # scale factor: inverse square root of the feature dimension
        scale = q.shape[-1] ** -0.5

        # if a mask is given and is not already 4d, rearrange it to (b, 1, 1, j)
        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, 'b j -> b 1 1 j')

        # if flash attention is enabled, dispatch to flash_attn
        if self.flash:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity

        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale

        # l2 distance

        if self.l2_dist:
            # -cdist squared == (-q^2 + 2qk - k^2)
            # so it can be derived from the qk similarity computed above
            q_squared = reduce(q ** 2, 'b h i d -> b h i 1', 'sum')
            k_squared = reduce(k ** 2, 'b h j d -> b h 1 j', 'sum')
            sim = sim * 2 - q_squared - k_squared

        # key padding mask

        if exists(mask):
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # attention

        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values

        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)

        return out
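
A small sketch (not in the original file) of calling Attend directly; the shapes are illustrative:

attend = Attend(dropout = 0., flash = False, l2_dist = True)

q = torch.randn(1, 8, 32, 64)   # (batch, heads, seq, dim_head)
k = torch.randn(1, 8, 32, 64)
v = torch.randn(1, 8, 32, 64)
mask = torch.ones(1, 32).bool()

out = attend(q, k, v, mask = mask)   # (1, 8, 32, 64)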

.\lucidrains\VN-transformer\VN_transformer\rotations.py

# import torch
import torch
# import sin, cos, atan2, acos from torch
from torch import sin, cos, atan2, acos

# rotation about the z axis by angle gamma
def rot_z(gamma):
    return torch.tensor([
        [cos(gamma), -sin(gamma), 0],
        [sin(gamma), cos(gamma), 0],
        [0, 0, 1]
    ], dtype=gamma.dtype)

# rotation about the y axis by angle beta
def rot_y(beta):
    return torch.tensor([
        [cos(beta), 0, sin(beta)],
        [0, 1, 0],
        [-sin(beta), 0, cos(beta)]
    ], dtype=beta.dtype)

# general rotation from Euler angles alpha, beta, gamma (z-y-z convention)
def rot(alpha, beta, gamma):
    return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
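
A quick sanity check (not in the original file) that rot produces a proper rotation matrix:

R = rot(*torch.randn(3))
print(torch.allclose(R @ R.T, torch.eye(3), atol = 1e-6))   # True: R is orthogonal
print(torch.det(R))                                         # ~1.0: a proper rotation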

.\lucidrains\VN-transformer\VN_transformer\VN_transformer.py

# import torch
import torch
# functional API
import torch.nn.functional as F
# import nn, einsum, Tensor from torch
from torch import nn, einsum, Tensor

# import rearrange, repeat, reduce from einops
from einops import rearrange, repeat, reduce
# import the Rearrange and Reduce layers from einops.layers.torch
from einops.layers.torch import Rearrange, Reduce
# import Attend from VN_transformer.attend
from VN_transformer.attend import Attend

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return val if it exists, otherwise the default d
def default(val, d):
    return val if exists(val) else d

# inner (dot) product of two vectors along the given dimension
def inner_dot_product(x, y, *, dim = -1, keepdim = True):
    return (x * y).sum(dim = dim, keepdim = keepdim)

# layernorm

# LayerNorm with a learned gain and a fixed (non-learned) zero bias
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# equivariant modules

# VNLinear: linear map over the vector-neuron channel dimension
class VNLinear(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        bias_epsilon = 0.
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim_out, dim_in))

        self.bias = None
        self.bias_epsilon = bias_epsilon

        # the paper proposes a small bias, controlled by epsilon, that breaks strict equivariance; they claim this gives better stability and results

        if bias_epsilon > 0.:
            self.bias = nn.Parameter(torch.randn(dim_out))

    def forward(self, x):
        out = einsum('... i c, o i -> ... o c', x, self.weight)

        if exists(self.bias):
            bias = F.normalize(self.bias, dim = -1) * self.bias_epsilon
            out = out + rearrange(bias, '... -> ... 1')

        return out

# VNReLU: rotation-equivariant ReLU over vector neurons
class VNReLU(nn.Module):
    def __init__(self, dim, eps = 1e-6):
        super().__init__()
        self.eps = eps
        self.W = nn.Parameter(torch.randn(dim, dim))
        self.U = nn.Parameter(torch.randn(dim, dim))

    def forward(self, x):
        q = einsum('... i c, o i -> ... o c', x, self.W)
        k = einsum('... i c, o i -> ... o c', x, self.U)

        qk = inner_dot_product(q, k)

        k_norm = k.norm(dim = -1, keepdim = True).clamp(min = self.eps)
        q_projected_on_k = q - inner_dot_product(q, k / k_norm) * k

        out = torch.where(
            qk >= 0.,
            q,
            q_projected_on_k
        )

        return out

# VNAttention: rotation-equivariant multi-head attention
class VNAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        dim_coor = 3,
        bias_epsilon = 0.,
        l2_dist_attn = False,
        flash = False,
        num_latents = None   # setting this enables perceiver-like cross attention from latents (derived with VNWeightedPool) to the sequence
    ):
        super().__init__()
        assert not (l2_dist_attn and flash), 'l2 distance attention is not compatible with flash attention'

        self.scale = (dim_coor * dim_head) ** -0.5
        dim_inner = dim_head * heads
        self.heads = heads

        self.to_q_input = None
        if exists(num_latents):
            self.to_q_input = VNWeightedPool(dim, num_pooled_tokens = num_latents, squeeze_out_pooled_dim = False)

        self.to_q = VNLinear(dim, dim_inner, bias_epsilon = bias_epsilon)
        self.to_k = VNLinear(dim, dim_inner, bias_epsilon = bias_epsilon)
        self.to_v = VNLinear(dim, dim_inner, bias_epsilon = bias_epsilon)
        self.to_out = VNLinear(dim_inner, dim, bias_epsilon = bias_epsilon)

        if l2_dist_attn and not exists(num_latents):
            # for l2 distance attention, queries and keys are shared (when not doing perceiver-like attention)
            self.to_k = self.to_q

        self.attend = Attend(flash = flash, l2_dist = l2_dist_attn)
    # forward pass with an optional mask
    def forward(self, x, mask = None):
        """
        einstein notation
        b - batch
        n - sequence
        h - heads
        d - feature dimension (channels)
        c - coordinate dimension (3 for 3d space)
        i - source sequence dimension
        j - target sequence dimension
        """

        # coordinate dimension (the last axis of x)
        c = x.shape[-1]

        # if latents are used, derive the query input by pooling; otherwise the queries come from x itself
        if exists(self.to_q_input):
            q_input = self.to_q_input(x, mask = mask)
        else:
            q_input = x

        # project q_input to queries, and x to keys and values
        q, k, v = self.to_q(q_input), self.to_k(x), self.to_v(x)
        # split heads and fold the coordinate dimension into the feature dimension: 'b h n (d c)'
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) c -> b h n (d c)', h = self.heads), (q, k, v))

        # attention
        out = self.attend(q, k, v, mask = mask)

        # unfold the coordinate dimension and merge the heads back: 'b n (h d) c'
        out = rearrange(out, 'b h n (d c) -> b n (h d) c', c = c)
        # output projection
        return self.to_out(out)
# VNFeedForward: a function returning a VNLinear -> VNReLU -> VNLinear block
def VNFeedForward(dim, mult = 4, bias_epsilon = 0.):
    # inner dimension
    dim_inner = int(dim * mult)
    # return the three layers as a sequential module
    return nn.Sequential(
        VNLinear(dim, dim_inner, bias_epsilon = bias_epsilon),  # VNLinear layer
        VNReLU(dim_inner),  # VNReLU activation
        VNLinear(dim_inner, dim, bias_epsilon = bias_epsilon)  # second VNLinear layer
    )

# VNLayerNorm: layernorm applied to the vector norms, preserving directions
class VNLayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-6):
        super().__init__()
        self.eps = eps
        self.ln = LayerNorm(dim)  # LayerNorm over the channel dimension

    def forward(self, x):
        norms = x.norm(dim = -1)
        x = x / rearrange(norms.clamp(min = self.eps), '... -> ... 1')
        ln_out = self.ln(norms)
        return x * rearrange(ln_out, '... -> ... 1')

# VNWeightedPool: masked mean pooling followed by a learned linear map to pooled tokens
class VNWeightedPool(nn.Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        num_pooled_tokens = 1,
        squeeze_out_pooled_dim = True
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.weight = nn.Parameter(torch.randn(num_pooled_tokens, dim, dim_out))  # learned pooling weight
        self.squeeze_out_pooled_dim = num_pooled_tokens == 1 and squeeze_out_pooled_dim

    def forward(self, x, mask = None):
        if exists(mask):
            mask = rearrange(mask, 'b n -> b n 1 1')
            x = x.masked_fill(~mask, 0.)
            numer = reduce(x, 'b n d c -> b d c', 'sum')
            denom = mask.sum(dim = 1)
            mean_pooled = numer / denom.clamp(min = 1e-6)
        else:
            mean_pooled = reduce(x, 'b n d c -> b d c', 'mean')

        out = einsum('b d c, m d e -> b m e c', mean_pooled, self.weight)

        if not self.squeeze_out_pooled_dim:
            return out

        out = rearrange(out, 'b 1 d c -> b d c')
        return out
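
A small sketch (not in the original file) of the pooling shapes:

pool = VNWeightedPool(dim = 64)

x = torch.randn(2, 32, 64, 3)   # (batch, seq, dim, coordinate dim)
pooled = pool(x)                # (2, 64, 3) -- single pooled token, squeezed out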

# VNTransformerEncoder: a stack of VNAttention, VNLayerNorm and VNFeedForward blocks
class VNTransformerEncoder(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head = 64,
        heads = 8,
        dim_coor = 3,
        ff_mult = 4,
        final_norm = False,
        bias_epsilon = 0.,
        l2_dist_attn = False,
        flash_attn = False
    ):
        super().__init__()
        self.dim = dim
        self.dim_coor = dim_coor

        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                VNAttention(dim = dim, dim_head = dim_head, heads = heads, bias_epsilon = bias_epsilon, l2_dist_attn = l2_dist_attn, flash = flash_attn),  # VNAttention layer
                VNLayerNorm(dim),  # post-attention VNLayerNorm
                VNFeedForward(dim = dim, mult = ff_mult, bias_epsilon = bias_epsilon),  # VNFeedForward layer
                VNLayerNorm(dim)  # post-feedforward VNLayerNorm
            ]))

        self.norm = VNLayerNorm(dim) if final_norm else nn.Identity()

    def forward(
        self,
        x,
        mask = None
    ):
        *_, d, c = x.shape

        assert x.ndim == 4 and d == self.dim and c == self.dim_coor, f'input needs to be in the shape of (batch, seq, dim ({self.dim}), coordinate dim ({self.dim_coor}))'

        for attn, attn_post_ln, ff, ff_post_ln in self.layers:
            x = attn_post_ln(attn(x, mask = mask)) + x
            x = ff_post_ln(ff(x)) + x

        return self.norm(x)

# VNInvariant: maps equivariant vector features to rotation-invariant outputs via a small MLP
class VNInvariant(nn.Module):
    def __init__(
        self,
        dim,
        dim_coor = 3,

    ):
        super().__init__()
        self.mlp = nn.Sequential(
            VNLinear(dim, dim_coor),  # VNLinear layer
            VNReLU(dim_coor),  # VNReLU activation
            Rearrange('... d e -> ... e d')  # swap the last two dimensions
        )

    def forward(self, x):
        return einsum('b n d i, b n i o -> b n o', x, self.mlp(x))
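
A small sketch (not in the original file) showing the invariant readout shapes:

inv = VNInvariant(64)

x = torch.randn(1, 32, 64, 3)   # (batch, seq, dim, coordinate dim)
out = inv(x)                    # (1, 32, 3) -- rotation-invariant features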

# VNTransformer: the full model, wrapping the encoder with input / output projections
class VNTransformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        num_tokens = None,
        dim_feat = None,
        dim_head = 64,
        heads = 8,
        dim_coor = 3,
        reduce_dim_out = True,
        bias_epsilon = 0.,
        l2_dist_attn = False,
        flash_attn = False,
        translation_equivariance = False,
        translation_invariant = False
    ):
        super().__init__()
        # token embedding, if num_tokens is given
        self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None

        # feature dimension, defaulting to 0
        dim_feat = default(dim_feat, 0)
        self.dim_feat = dim_feat
        # total coordinate dimension: spatial coordinates plus features (early fusion)
        self.dim_coor_total = dim_coor + dim_feat

        # at most one of translation equivariance / translation invariance can be enabled
        assert (int(translation_equivariance) + int(translation_invariant)) <= 1
        self.translation_equivariance = translation_equivariance
        self.translation_invariant = translation_invariant

        # input projection
        self.vn_proj_in = nn.Sequential(
            Rearrange('... c -> ... 1 c'),
            VNLinear(1, dim, bias_epsilon = bias_epsilon)
        )

        # VNTransformer encoder
        self.encoder = VNTransformerEncoder(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            bias_epsilon = bias_epsilon,
            dim_coor = self.dim_coor_total,
            l2_dist_attn = l2_dist_attn,
            flash_attn = flash_attn
        )

        # output projection, reducing back to a single vector channel if requested
        if reduce_dim_out:
            self.vn_proj_out = nn.Sequential(
                VNLayerNorm(dim),
                VNLinear(dim, 1, bias_epsilon = bias_epsilon),
                Rearrange('... 1 c -> ... c')
            )
        else:
            self.vn_proj_out = nn.Identity()

    def forward(
        self,
        coors,
        *,
        feats = None,
        mask = None,
        return_concatted_coors_and_feats = False
    ):
        # for translation equivariance or invariance, subtract the mean coordinate
        if self.translation_equivariance or self.translation_invariant:
            coors_mean = reduce(coors, '... c -> c', 'mean')
            coors = coors - coors_mean

        x = coors

        # if features are given, concatenate them onto the coordinates (early fusion)
        if exists(feats):
            if feats.dtype == torch.long:
                assert exists(self.token_emb), 'num_tokens must be given to the VNTransformer (to build the Embedding), if the features are to be given as indices'
                feats = self.token_emb(feats)

            assert feats.shape[-1] == self.dim_feat, f'dim_feat should be set to {feats.shape[-1]}'
            x = torch.cat((x, feats), dim = -1)

        assert x.shape[-1] == self.dim_coor_total

        # input projection
        x = self.vn_proj_in(x)
        # encoder
        x = self.encoder(x, mask = mask)
        # output projection
        x = self.vn_proj_out(x)

        # split back into coordinates and features
        coors_out, feats_out = x[..., :3], x[..., 3:]

        # for translation equivariance, add the mean coordinate back
        if self.translation_equivariance:
            coors_out = coors_out + coors_mean

        # if no features were given, return only the coordinates
        if not exists(feats):
            return coors_out

        # optionally return the coordinates and features concatenated
        if return_concatted_coors_and_feats:
            return torch.cat((coors_out, feats_out), dim = -1)

        # otherwise return the coordinates and features separately
        return coors_out, feats_out

.\lucidrains\VN-transformer\VN_transformer\__init__.py

# import the following classes and functions from the VN_transformer.VN_transformer module
from VN_transformer.VN_transformer import (
    VNTransformer,         # the full VN-Transformer
    VNLinear,              # equivariant linear layer
    VNLayerNorm,           # equivariant layernorm
    VNFeedForward,         # equivariant feedforward
    VNAttention,           # equivariant attention
    VNWeightedPool,        # equivariant weighted pooling
    VNTransformerEncoder,  # encoder stack
    VNInvariant            # invariant readout
)