
Lucidrains 系列项目源码解析(一百零六)

.\lucidrains\vit-pytorch\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  # 包的名称
  name = 'vit-pytorch',
  # 查找除了 'examples' 文件夹之外的所有包
  packages = find_packages(exclude=['examples']),
  # 版本号
  version = '1.6.5',
  # 许可证类型
  license='MIT',
  # 描述
  description = 'Vision Transformer (ViT) - Pytorch',
  # 长描述内容类型
  long_description_content_type = 'text/markdown',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 项目链接
  url = 'https://github.com/lucidrains/vit-pytorch',
  # 关键词
  keywords = [
    'artificial intelligence',
    'attention mechanism',
    'image recognition'
  ],
  # 安装依赖
  install_requires=[
    'einops>=0.7.0',
    'torch>=1.10',
    'torchvision'
  ],
  # 设置需要的依赖
  setup_requires=[
    'pytest-runner',
  ],
  # 测试需要的依赖
  tests_require=[
    'pytest',
    'torch==1.12.1',
    'torchvision==0.13.1'
  ],
  # 分类
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
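
# 补充示例(非 setup.py 原文,仅作演示,假设当前目录为仓库根目录且已安装 pytest 与本包):
# setup.py 通过 tests_require 声明了测试依赖,测试本身可直接用 pytest 运行
import pytest

pytest.main(['tests/test.py'])   # 等价于在命令行执行 pytest tests/test.py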

.\lucidrains\vit-pytorch\tests\test.py

# 导入 torch 库
import torch
# 从 vit_pytorch 库中导入 ViT 类
from vit_pytorch import ViT

# 定义测试函数
def test():
    # 创建 ViT 模型对象
    v = ViT(
        image_size = 256,       # 图像大小
        patch_size = 32,        # patch 大小
        num_classes = 1000,     # 类别数
        dim = 1024,             # 特征维度
        depth = 6,              # 深度
        heads = 16,             # 注意力头数
        mlp_dim = 2048,         # MLP 隐藏层维度
        dropout = 0.1,          # dropout 概率
        emb_dropout = 0.1       # 嵌入层 dropout 概率
    )

    # 生成一个形状为 (1, 3, 256, 256) 的随机张量作为输入图像
    img = torch.randn(1, 3, 256, 256)

    # 将输入图像传入 ViT 模型进行预测
    preds = v(img)
    # 断言预测结果的形状为 (1, 1000),如果不符合则抛出异常信息 'correct logits outputted'
    assert preds.shape == (1, 1000), 'correct logits outputted'
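
# 补充示例(非原测试文件内容,仅作演示):同样的配置也适用于批量输入;
# 此处未显式传入 dropout 与 emb_dropout,假设使用 ViT 的默认值 0
def test_batched():
    v = ViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048
    )

    imgs = torch.randn(4, 3, 256, 256)
    preds = v(imgs)
    assert preds.shape == (4, 1000), 'correct logits outputted'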

.\lucidrains\vit-pytorch\vit_pytorch\ats_vit.py

# 导入 torch 库
import torch
# 导入 torch 中的函数库
import torch.nn.functional as F
# 从 torch.nn.utils.rnn 中导入 pad_sequence 函数
from torch.nn.utils.rnn import pad_sequence
# 从 torch 中导入 nn、einsum 模块
from torch import nn, einsum
# 从 einops 中导入 rearrange、repeat 函数和 Rearrange 类
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# 辅助函数

# 判断值是否存在
def exists(val):
    return val is not None

# 将输入转换为元组
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# 自适应令牌采样函数和类

# 计算输入张量的自然对数,避免输入为 0 时出现错误
def log(t, eps = 1e-6):
    return torch.log(t + eps)

# 生成服从 Gumbel 分布的随机数
def sample_gumbel(shape, device, dtype, eps = 1e-6):
    u = torch.empty(shape, device = device, dtype = dtype).uniform_(0, 1)
    return -log(-log(u, eps), eps)

# 在指定维度上对输入张量进行批量索引选择
def batched_index_select(values, indices, dim = 1):
    # 获取值张量和索引张量的维度信息
    value_dims = values.shape[(dim + 1):]
    values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
    # 将索引张量扩展到与值张量相同的维度
    # 将索引张量扩展到与值张量相同的维度(原文此处缺少一个右括号,已补上)
    indices = indices[(..., *((None,) * len(value_dims)))]
    indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
    value_expand_len = len(indices_shape) - (dim + 1)
    values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]

    value_expand_shape = [-1] * len(values.shape)
    expand_slice = slice(dim, (dim + value_expand_len))
    value_expand_shape[expand_slice] = indices.shape[expand_slice]
    values = values.expand(*value_expand_shape)

    dim += value_expand_len
    return values.gather(dim, indices)
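
# 补充示例(非原文件内容,变量名为演示自拟):batched_index_select 按 batch 独立地在指定维度上取索引
example_values  = torch.randn(2, 5, 4)
example_indices = torch.tensor([[0, 2, 4], [1, 1, 3]])
example_out     = batched_index_select(example_values, example_indices, dim = 1)
# example_out 形状为 (2, 3, 4):每个样本各自在 dim=1 上取了 3 个位置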

# 自适应令牌采样类
class AdaptiveTokenSampling(nn.Module):
    def __init__(self, output_num_tokens, eps = 1e-6):
        super().__init__()
        self.eps = eps
        self.output_num_tokens = output_num_tokens
    # 定义一个前向传播函数,接收注意力值、数值、掩码作为输入
    def forward(self, attn, value, mask):
        # 获取注意力值的头数、输出的标记数、eps值、设备和数据类型
        heads, output_num_tokens, eps, device, dtype = attn.shape[1], self.output_num_tokens, self.eps, attn.device, attn.dtype

        # 获取CLS标记到所有其他标记的注意力值
        cls_attn = attn[..., 0, 1:]

        # 计算数值的范数,用于加权得分,如论文中所述
        value_norms = value[..., 1:, :].norm(dim=-1)

        # 通过数值的范数加权注意力得分,对所有头求和
        cls_attn = einsum('b h n, b h n -> b n', cls_attn, value_norms)

        # 归一化为1
        normed_cls_attn = cls_attn / (cls_attn.sum(dim=-1, keepdim=True) + eps)

        # 不使用逆变换采样,而是反转softmax并使用gumbel-max采样
        pseudo_logits = log(normed_cls_attn)

        # 为gumbel-max采样屏蔽伪对数
        mask_without_cls = mask[:, 1:]
        mask_value = -torch.finfo(attn.dtype).max / 2
        pseudo_logits = pseudo_logits.masked_fill(~mask_without_cls, mask_value)

        # 扩展k次,k为自适应采样数
        pseudo_logits = repeat(pseudo_logits, 'b n -> b k n', k=output_num_tokens)
        pseudo_logits = pseudo_logits + sample_gumbel(pseudo_logits.shape, device=device, dtype=dtype)

        # gumbel-max采样并加一以保留0用于填充/掩码
        sampled_token_ids = pseudo_logits.argmax(dim=-1) + 1

        # 使用torch.unique计算唯一值,然后从右侧填充序列
        unique_sampled_token_ids_list = [torch.unique(t, sorted=True) for t in torch.unbind(sampled_token_ids)]
        unique_sampled_token_ids = pad_sequence(unique_sampled_token_ids_list, batch_first=True)

        # 基于填充计算新的掩码
        new_mask = unique_sampled_token_ids != 0

        # CLS标记永远不会被屏蔽(得到True值)
        new_mask = F.pad(new_mask, (1, 0), value=True)

        # 在前面添加一个0标记ID以保留CLS注意力得分
        unique_sampled_token_ids = F.pad(unique_sampled_token_ids, (1, 0), value=0)
        expanded_unique_sampled_token_ids = repeat(unique_sampled_token_ids, 'b n -> b h n', h=heads)

        # 收集新的注意力得分
        new_attn = batched_index_select(attn, expanded_unique_sampled_token_ids, dim=2)

        # 返回采样的注意力得分、新掩码(表示填充)以及采样的标记索引(用于残差)
        return new_attn, new_mask, unique_sampled_token_ids
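
# 补充示例(非原文件内容,变量名为演示自拟):假设 batch=1、heads=4、1 个 CLS + 16 个 patch token,
# 采样后的 token 数不超过 output_num_tokens + 1(含 CLS),实际数量取决于 gumbel-max 采样去重后的结果
example_ats   = AdaptiveTokenSampling(output_num_tokens = 8)
example_attn  = torch.randn(1, 4, 17, 17).softmax(dim = -1)
example_value = torch.randn(1, 4, 17, 64)
example_mask  = torch.ones(1, 17, dtype = torch.bool)
new_attn, new_mask, sampled_ids = example_ats(example_attn, example_value, example_mask)
# new_attn: (1, 4, m, 17),其中 m <= 9;new_mask: (1, m);sampled_ids: (1, m)
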
# 定义前馈神经网络类
class FeedForward(nn.Module):
    # 初始化函数,定义网络结构
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # 使用 nn.Sequential 定义网络层次结构
        self.net = nn.Sequential(
            nn.LayerNorm(dim),  # Layer normalization
            nn.Linear(dim, hidden_dim),  # 线性变换
            nn.GELU(),  # GELU 激活函数
            nn.Dropout(dropout),  # Dropout 正则化
            nn.Linear(hidden_dim, dim),  # 线性变换
            nn.Dropout(dropout)  # Dropout 正则化
        )
    # 前向传播函数
    def forward(self, x):
        return self.net(x)

# 定义注意力机制类
class Attention(nn.Module):
    # 初始化函数,定义注意力机制结构
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., output_num_tokens = None):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)  # Layer normalization
        self.attend = nn.Softmax(dim = -1)  # Softmax 注意力权重计算
        self.dropout = nn.Dropout(dropout)  # Dropout 正则化

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)  # 线性变换

        self.output_num_tokens = output_num_tokens
        self.ats = AdaptiveTokenSampling(output_num_tokens) if exists(output_num_tokens) else None

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),  # 线性变换
            nn.Dropout(dropout)  # Dropout 正则化
        )

    # 前向传播函数
    def forward(self, x, *, mask):
        num_tokens = x.shape[1]

        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        if exists(mask):
            dots_mask = rearrange(mask, 'b i -> b 1 i 1') * rearrange(mask, 'b j -> b 1 1 j')
            mask_value = -torch.finfo(dots.dtype).max
            dots = dots.masked_fill(~dots_mask, mask_value)

        attn = self.attend(dots)
        attn = self.dropout(attn)

        sampled_token_ids = None

        # 如果启用了自适应令牌采样,并且令牌数量大于输出令牌数量
        if exists(self.output_num_tokens) and (num_tokens - 1) > self.output_num_tokens:
            attn, mask, sampled_token_ids = self.ats(attn, v, mask = mask)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        return self.to_out(out), mask, sampled_token_ids

# 定义 Transformer 类
class Transformer(nn.Module):
    # 初始化函数,定义 Transformer 结构
    def __init__(self, dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        assert len(max_tokens_per_depth) == depth, 'max_tokens_per_depth must be a tuple of length that is equal to the depth of the transformer'
        assert sorted(max_tokens_per_depth, reverse = True) == list(max_tokens_per_depth), 'max_tokens_per_depth must be in decreasing order'
        assert min(max_tokens_per_depth) > 0, 'max_tokens_per_depth must have at least 1 token at any layer'

        self.layers = nn.ModuleList([])
        for _, output_num_tokens in zip(range(depth), max_tokens_per_depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
    # 定义前向传播函数,接受输入张量 x
    def forward(self, x):
        # 获取输入张量 x 的形状的前两个维度大小和设备信息
        b, n, device = *x.shape[:2], x.device

        # 使用掩码来跟踪填充位置,以便在采样标记时移除重复项
        mask = torch.ones((b, n), device=device, dtype=torch.bool)

        # 创建一个包含从 0 到 n-1 的张量,设备信息与输入张量 x 一致
        token_ids = torch.arange(n, device=device)
        token_ids = repeat(token_ids, 'n -> b n', b=b)

        # 遍历每个注意力层和前馈层
        for attn, ff in self.layers:
            # 调用注意力层的前向传播函数,获取注意力输出、更新后的掩码和采样的标记
            attn_out, mask, sampled_token_ids = attn(x, mask=mask)

            # 当进行标记采样时,需要使用采样的标记 id 从输入张量中选择对应的标记
            if exists(sampled_token_ids):
                x = batched_index_select(x, sampled_token_ids, dim=1)
                token_ids = batched_index_select(token_ids, sampled_token_ids, dim=1)

            # 更新输入张量,加上注意力输出
            x = x + attn_out

            # 经过前馈层处理后再加上原始输入,得到最终输出
            x = ff(x) + x

        # 返回最终输出张量和标记 id
        return x, token_ids
class ViT(nn.Module):
    # 定义 ViT 模型类,继承自 nn.Module
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, max_tokens_per_depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        # 初始化函数,接收参数 image_size, patch_size, num_classes, dim, depth, max_tokens_per_depth, heads, mlp_dim, channels, dim_head, dropout, emb_dropout
        super().__init__()
        # 调用父类的初始化函数

        image_height, image_width = pair(image_size)
        # 获取图像的高度和宽度
        patch_height, patch_width = pair(patch_size)
        # 获取补丁的高度和宽度

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        # 断言,确保图像的尺寸能够被补丁的尺寸整除

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # 计算补丁的数量
        patch_dim = channels * patch_height * patch_width
        # 计算每个补丁的维度

        self.to_patch_embedding = nn.Sequential(
            # 定义将图像转换为补丁嵌入的序列
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            # 重新排列图像的通道和补丁的维度
            nn.LayerNorm(patch_dim),
            # 对每个补丁进行 LayerNorm
            nn.Linear(patch_dim, dim),
            # 线性变换将每个补丁的维度映射到指定的维度 dim
            nn.LayerNorm(dim)
            # 对映射后的维度进行 LayerNorm
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # 初始化位置嵌入参数
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # 初始化类别标记参数
        self.dropout = nn.Dropout(emb_dropout)
        # 定义丢弃层,用于嵌入的丢弃

        self.transformer = Transformer(dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, dropout)
        # 初始化 Transformer 模型

        self.mlp_head = nn.Sequential(
            # 定义 MLP 头部
            nn.LayerNorm(dim),
            # 对输入进行 LayerNorm
            nn.Linear(dim, num_classes)
            # 线性变换将维度映射到类别数量
        )

    def forward(self, img, return_sampled_token_ids = False):
        # 定义前向传播函数,接收图像和是否返回采样的令牌 ID

        x = self.to_patch_embedding(img)
        # 将图像转换为补丁嵌入
        b, n, _ = x.shape
        # 获取 x 的形状信息

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        # 重复类别标记,使其与补丁嵌入的形状相同
        x = torch.cat((cls_tokens, x), dim=1)
        # 拼接类别标记和补丁嵌入
        x += self.pos_embedding[:, :(n + 1)]
        # 添加位置嵌入
        x = self.dropout(x)
        # 对输入进行丢弃

        x, token_ids = self.transformer(x)
        # 使用 Transformer 进行转换

        logits = self.mlp_head(x[:, 0])
        # 使用 MLP 头部生成输出

        if return_sampled_token_ids:
            # 如果需要返回采样的令牌 ID
            token_ids = token_ids[:, 1:] - 1
            # 移除类别标记并减去 1 以使 -1 成为填充
            return logits, token_ids
            # 返回输出和令牌 ID

        return logits
        # 返回输出
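
# 补充使用示例(非原文件内容,参数取值参考仓库 README,此处仅作演示):
# max_tokens_per_depth 必须与 depth 等长且单调不增,表示每一层之后最多保留的 token 数
v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    max_tokens_per_depth = (256, 128, 64, 32, 16, 8),
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(4, 3, 256, 256)

preds = v(img)                                                # (4, 1000)
preds, token_ids = v(img, return_sampled_token_ids = True)    # 同时返回最终保留的 patch 索引,-1 表示填充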

.\lucidrains\vit-pytorch\vit_pytorch\cait.py

from random import randrange
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

# 检查值是否存在
def exists(val):
    return val is not None

# 对层应用 dropout
def dropout_layers(layers, dropout):
    if dropout == 0:
        return layers

    num_layers = len(layers)
    to_drop = torch.zeros(num_layers).uniform_(0., 1.) < dropout

    # 确保至少有一层保留
    if all(to_drop):
        rand_index = randrange(num_layers)
        to_drop[rand_index] = False

    layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
    return layers
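
# 补充示例(非原文件内容,变量名为演示自拟):LayerDrop 式的层丢弃,每次调用随机跳过部分层,但至少保留一层
example_layers = [nn.Identity() for _ in range(5)]
kept = dropout_layers(example_layers, dropout = 0.5)
# len(kept) 介于 1 和 5 之间,取决于随机结果;dropout = 0 时原样返回全部层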

# classes

# 缩放层
class LayerScale(nn.Module):
    def __init__(self, dim, fn, depth):
        super().__init__()
        if depth <= 18:  # 根据深度选择初始化值,详见论文第2节
            init_eps = 0.1
        elif depth > 18 and depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

# 前馈神经网络
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# 注意力机制
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.mix_heads_pre_attn = nn.Parameter(torch.randn(heads, heads))
        self.mix_heads_post_attn = nn.Parameter(torch.randn(heads, heads))

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context = None):
        b, n, _, h = *x.shape, self.heads

        x = self.norm(x)
        context = x if not exists(context) else torch.cat((x, context), dim = 1)

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        dots = einsum('b h i j, h g -> b g i j', dots, self.mix_heads_pre_attn)    # talking heads, pre-softmax

        attn = self.attend(dots)
        attn = self.dropout(attn)

        attn = einsum('b h i j, h g -> b g i j', attn, self.mix_heads_post_attn)   # talking heads, post-softmax

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# Transformer 模型
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.layer_dropout = layer_dropout

        for ind in range(depth):
            self.layers.append(nn.ModuleList([
                LayerScale(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = ind + 1),
                LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = ind + 1)
            ]))
    def forward(self, x, context = None):
        layers = dropout_layers(self.layers, dropout = self.layer_dropout)

        for attn, ff in layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x

class CaiT(nn.Module):
    # 初始化函数,设置模型参数
    def __init__(
        self,
        *,
        image_size,  # 图像大小
        patch_size,  # 补丁大小
        num_classes,  # 类别数量
        dim,  # 特征维度
        depth,  # 深度
        cls_depth,  # 分类深度
        heads,  # 多头注意力头数
        mlp_dim,  # MLP隐藏层维度
        dim_head = 64,  # 头维度
        dropout = 0.,  # 丢弃率
        emb_dropout = 0.,  # 嵌入层丢弃率
        layer_dropout = 0.  # 层丢弃率
    ):
        # 调用父类初始化函数
        super().__init__()
        # 检查图像尺寸是否能被补丁大小整除
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        # 计算补丁数量
        num_patches = (image_size // patch_size) ** 2
        # 计算补丁维度
        patch_dim = 3 * patch_size ** 2

        # 补丁嵌入层
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        # 位置嵌入
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        # 分类令牌
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        # 丢弃层
        self.dropout = nn.Dropout(emb_dropout)

        # 补丁Transformer
        self.patch_transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, layer_dropout)
        # 分类Transformer
        self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout)

        # MLP头
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    # 前向传播函数
    def forward(self, img):
        # 补丁嵌入
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # 添加位置嵌入
        x += self.pos_embedding[:, :n]
        x = self.dropout(x)

        # 补丁Transformer
        x = self.patch_transformer(x)

        # 重复分类令牌
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        # 分类Transformer
        x = self.cls_transformer(cls_tokens, context = x)

        # 返回MLP头的结果
        return self.mlp_head(x[:, 0])
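
# 补充使用示例(非原文件内容,参数取值参考仓库 README,此处仅作演示):
v = CaiT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 12,             # patch transformer 的深度
    cls_depth = 2,          # CLS token 与 patch token 做交叉注意力的深度
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.05,
    emb_dropout = 0.05,
    layer_dropout = 0.05    # 论文中的随机层丢弃(LayerDrop),丢弃概率 5%
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)   # (1, 1000)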

.\lucidrains\vit-pytorch\vit_pytorch\cct.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch.nn 模块中导入 F 函数
import torch.nn.functional as F

# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat

# 定义辅助函数

# 判断变量是否存在的函数
def exists(val):
    return val is not None

# 如果变量存在则返回该变量,否则返回默认值的函数
def default(val, d):
    return val if exists(val) else d

# 将输入转换为元组的函数
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# CCT 模型

# 定义导出的 CCT 模型名称列表
__all__ = ['cct_2', 'cct_4', 'cct_6', 'cct_7', 'cct_8', 'cct_14', 'cct_16']

# 定义创建不同层数 CCT 模型的函数
def cct_2(*args, **kwargs):
    return _cct(num_layers=2, num_heads=2, mlp_ratio=1, embedding_dim=128,
                *args, **kwargs)

def cct_4(*args, **kwargs):
    return _cct(num_layers=4, num_heads=2, mlp_ratio=1, embedding_dim=128,
                *args, **kwargs)

def cct_6(*args, **kwargs):
    return _cct(num_layers=6, num_heads=4, mlp_ratio=2, embedding_dim=256,
                *args, **kwargs)

def cct_7(*args, **kwargs):
    return _cct(num_layers=7, num_heads=4, mlp_ratio=2, embedding_dim=256,
                *args, **kwargs)

def cct_8(*args, **kwargs):
    return _cct(num_layers=8, num_heads=4, mlp_ratio=2, embedding_dim=256,
                *args, **kwargs)

def cct_14(*args, **kwargs):
    return _cct(num_layers=14, num_heads=6, mlp_ratio=3, embedding_dim=384,
                *args, **kwargs)

def cct_16(*args, **kwargs):
    return _cct(num_layers=16, num_heads=6, mlp_ratio=3, embedding_dim=384,
                *args, **kwargs)

# 创建 CCT 模型的内部函数
def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
         kernel_size=3, stride=None, padding=None,
         *args, **kwargs):
    # 计算默认的步长和填充值
    stride = default(stride, max(1, (kernel_size // 2) - 1))
    padding = default(padding, max(1, (kernel_size // 2)))

    # 返回 CCT 模型
    return CCT(num_layers=num_layers,
               num_heads=num_heads,
               mlp_ratio=mlp_ratio,
               embedding_dim=embedding_dim,
               kernel_size=kernel_size,
               stride=stride,
               padding=padding,
               *args, **kwargs)

# 位置编码

# 创建正弦位置编码的函数
def sinusoidal_embedding(n_channels, dim):
    pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
                            for p in range(n_channels)])
    pe[:, 0::2] = torch.sin(pe[:, 0::2])
    pe[:, 1::2] = torch.cos(pe[:, 1::2])
    return rearrange(pe, '... -> 1 ...')
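
# 补充示例(非原文件内容,变量名为演示自拟):生成长度为 64、维度为 128 的正弦位置编码
example_pe = sinusoidal_embedding(n_channels = 64, dim = 128)
# example_pe 形状为 (1, 64, 128),偶数维为 sin、奇数维为 cos,前导的 1 便于与 (b, n, d) 的输入广播相加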

# 模块

# 定义注意力机制模块
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
        super().__init__()
        self.heads = num_heads
        head_dim = dim // self.heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.attn_drop = nn.Dropout(attention_dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(projection_dropout)

    def forward(self, x):
        B, N, C = x.shape

        qkv = self.qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        q = q * self.scale

        attn = einsum('b h i d, b h j d -> b h i j', q, k)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = einsum('b h i j, b h j d -> b h i d', attn, v)
        x = rearrange(x, 'b h n d -> b n (h d)')

        return self.proj_drop(self.proj(x))

# 定义 Transformer 编码器层模块
class TransformerEncoderLayer(nn.Module):
    """
    Inspired by torch.nn.TransformerEncoderLayer and
    rwightman's timm package.
    """
    # 初始化函数,定义了 Transformer Encoder 层的结构
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 attention_dropout=0.1, drop_path_rate=0.1):
        # 调用父类的初始化函数
        super().__init__()

        # 对输入进行 Layer Normalization
        self.pre_norm = nn.LayerNorm(d_model)
        # 定义自注意力机制
        self.self_attn = Attention(dim=d_model, num_heads=nhead,
                                   attention_dropout=attention_dropout, projection_dropout=dropout)

        # 第一个线性层
        self.linear1  = nn.Linear(d_model, dim_feedforward)
        # 第一个 Dropout 层
        self.dropout1 = nn.Dropout(dropout)
        # 第一个 Layer Normalization 层
        self.norm1    = nn.LayerNorm(d_model)
        # 第二个线性层
        self.linear2  = nn.Linear(dim_feedforward, d_model)
        # 第二个 Dropout 层
        self.dropout2 = nn.Dropout(dropout)

        # DropPath 模块
        self.drop_path = DropPath(drop_path_rate)

        # 激活函数为 GELU
        self.activation = F.gelu

    # 前向传播函数
    def forward(self, src, *args, **kwargs):
        # 使用自注意力机制处理输入,并加上 DropPath 模块
        src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
        # 对结果进行 Layer Normalization
        src = self.norm1(src)
        # 第一个线性层、激活函数、Dropout、第二个线性层的组合
        src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
        # 将结果与 DropPath 模块处理后的结果相加
        src = src + self.drop_path(self.dropout2(src2))
        # 返回处理后的结果
        return src
class DropPath(nn.Module):
    # 初始化 DropPath 类
    def __init__(self, drop_prob=None):
        # 调用父类的初始化方法
        super().__init__()
        # 将传入的 drop_prob 转换为浮点数
        self.drop_prob = float(drop_prob)

    # 前向传播方法
    def forward(self, x):
        # 获取输入 x 的批次大小、drop_prob、设备和数据类型
        batch, drop_prob, device, dtype = x.shape[0], self.drop_prob, x.device, x.dtype

        # 如果 drop_prob 小于等于 0 或者不处于训练模式,则直接返回输入 x
        if drop_prob <= 0. or not self.training:
            return x

        # 计算保留概率
        keep_prob = 1 - self.drop_prob
        # 构建形状元组
        shape = (batch, *((1,) * (x.ndim - 1)))

        # 生成保留掩码
        keep_mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < keep_prob
        # 对输入 x 进行 DropPath 操作
        output = x.div(keep_prob) * keep_mask.float()
        return output
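
# 补充示例(非原文件内容,变量名为演示自拟):DropPath(随机深度)在训练模式下按样本整条丢弃残差分支,
# 并用 1/keep_prob 缩放保留的样本;在 eval 模式下恒等返回输入
example_drop_path = DropPath(drop_prob = 0.2)
example_x = torch.randn(8, 16, 32)

example_drop_path.train()
out_train = example_drop_path(example_x)   # 约 20% 的样本被整体置零,其余样本放大 1/0.8 倍

example_drop_path.eval()
out_eval = example_drop_path(example_x)    # 与输入完全相同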

class Tokenizer(nn.Module):
    # 初始化 Tokenizer 类
    def __init__(self,
                 kernel_size, stride, padding,
                 pooling_kernel_size=3, pooling_stride=2, pooling_padding=1,
                 n_conv_layers=1,
                 n_input_channels=3,
                 n_output_channels=64,
                 in_planes=64,
                 activation=None,
                 max_pool=True,
                 conv_bias=False):
        # 调用父类的初始化方法
        super().__init__()

        # 构建卷积层的通道数列表
        n_filter_list = [n_input_channels] + \
                        [in_planes for _ in range(n_conv_layers - 1)] + \
                        [n_output_channels]

        # 构建通道数列表的配对
        n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])

        # 构建卷积层序列
        self.conv_layers = nn.Sequential(
            *[nn.Sequential(
                nn.Conv2d(chan_in, chan_out,
                          kernel_size=(kernel_size, kernel_size),
                          stride=(stride, stride),
                          padding=(padding, padding), bias=conv_bias),
                nn.Identity() if not exists(activation) else activation(),
                nn.MaxPool2d(kernel_size=pooling_kernel_size,
                             stride=pooling_stride,
                             padding=pooling_padding) if max_pool else nn.Identity()
            )
                for chan_in, chan_out in n_filter_list_pairs
            ])

        # 对模型参数进行初始化
        self.apply(self.init_weight)

    # 计算序列长度
    def sequence_length(self, n_channels=3, height=224, width=224):
        return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]

    # 前向传播方法
    def forward(self, x):
        # 对卷积层的输出进行重排列
        return rearrange(self.conv_layers(x), 'b c h w -> b (h w) c')

    # 初始化权重方法
    @staticmethod
    def init_weight(m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)


class TransformerClassifier(nn.Module):
    # 初始化函数,设置模型的各种参数
    def __init__(self,
                 seq_pool=True,  # 是否使用序列池化
                 embedding_dim=768,  # 嵌入维度
                 num_layers=12,  # 编码器层数
                 num_heads=12,  # 注意力头数
                 mlp_ratio=4.0,  # MLP 扩展比例
                 num_classes=1000,  # 类别数
                 dropout_rate=0.1,  # Dropout 比例
                 attention_dropout=0.1,  # 注意力 Dropout 比例
                 stochastic_depth_rate=0.1,  # 随机深度比例
                 positional_embedding='sine',  # 位置编码类型
                 sequence_length=None,  # 序列长度
                 *args, **kwargs):  # 其他参数
        super().__init__()  # 调用父类的初始化函数
        assert positional_embedding in {'sine', 'learnable', 'none'}  # 断言位置编码类型合法

        dim_feedforward = int(embedding_dim * mlp_ratio)  # 计算前馈网络维度
        self.embedding_dim = embedding_dim  # 设置嵌入维度
        self.sequence_length = sequence_length  # 设置序列长度
        self.seq_pool = seq_pool  # 设置是否使用序列池化

        # 断言序列长度存在或位置编码为 'none'(注释不能放在续行符 \ 之后,否则会导致语法错误,故移到此处)
        assert exists(sequence_length) or positional_embedding == 'none', \
            f"Positional embedding is set to {positional_embedding} and" \
            f" the sequence length was not specified."

        if not seq_pool:  # 如果不使用序列池化
            sequence_length += 1  # 序列长度加一
            self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim), requires_grad=True)  # 创建类别嵌入参数
        else:
            self.attention_pool = nn.Linear(self.embedding_dim, 1)  # 创建注意力池化层

        if positional_embedding == 'none':  # 如果位置编码为'none'
            self.positional_emb = None  # 不使用位置编码
        elif positional_embedding == 'learnable':  # 如果位置编码为'learnable'
            self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),  # 创建可学习位置编码参数
                                               requires_grad=True)
            nn.init.trunc_normal_(self.positional_emb, std=0.2)  # 对位置编码参数进行初始化
        else:
            self.positional_emb = nn.Parameter(sinusoidal_embedding(sequence_length, embedding_dim),  # 创建正弦位置编码参数
                                               requires_grad=False)

        self.dropout = nn.Dropout(p=dropout_rate)  # 创建 Dropout 层

        dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]  # 计算随机深度比例列表

        self.blocks = nn.ModuleList([  # 创建 Transformer 编码器层列表
            TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
                                    dim_feedforward=dim_feedforward, dropout=dropout_rate,
                                    attention_dropout=attention_dropout, drop_path_rate=layer_dpr)
            for layer_dpr in dpr])

        self.norm = nn.LayerNorm(embedding_dim)  # 创建 LayerNorm 层

        self.fc = nn.Linear(embedding_dim, num_classes)  # 创建全连接层
        self.apply(self.init_weight)  # 应用初始化权重函数

    # 前向传播函数
    def forward(self, x):
        b = x.shape[0]  # 获取 batch 大小

        if not exists(self.positional_emb) and x.size(1) < self.sequence_length:  # 如果位置编码不存在且序列长度小于指定长度
            x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)  # 对输入进行填充(注意:self.n_channels 并未在上面的 __init__ 中定义,疑似应为 self.sequence_length)

        if not self.seq_pool:  # 如果不使用序列池化
            cls_token = repeat(self.class_emb, '1 1 d -> b 1 d', b = b)  # 重复类别嵌入
            x = torch.cat((cls_token, x), dim=1)  # 拼接类别嵌入和输入

        if exists(self.positional_emb):  # 如果位置编码存在
            x += self.positional_emb  # 加上位置编码

        x = self.dropout(x)  # Dropout

        for blk in self.blocks:  # 遍历编码器层
            x = blk(x)  # 应用编码器层

        x = self.norm(x)  # LayerNorm

        if self.seq_pool:  # 如果使用序列池化
            attn_weights = rearrange(self.attention_pool(x), 'b n 1 -> b n')  # 注意力权重计算
            x = einsum('b n, b n d -> b d', attn_weights.softmax(dim = 1), x)  # 加权池化
        else:
            x = x[:, 0]  # 取第一个位置的输出作为结果

        return self.fc(x)  # 全连接层输出结果

    # 初始化权重函数
    @staticmethod
    def init_weight(m):
        if isinstance(m, nn.Linear):  # 如果是线性层
            nn.init.trunc_normal_(m.weight, std=.02)  # 初始化权重
            if isinstance(m, nn.Linear) and exists(m.bias):  # 如果是线性层且存在偏置
                nn.init.constant_(m.bias, 0)  # 初始化偏置为0
        elif isinstance(m, nn.LayerNorm):  # 如果是 LayerNorm 层
            nn.init.constant_(m.bias, 0)  # 初始化偏置为0
            nn.init.constant_(m.weight, 1.0)  # 初始化权重为1.0
# 定义 CCT 类,继承自 nn.Module
class CCT(nn.Module):
    # 初始化函数,设置各种参数
    def __init__(
        self,
        img_size=224,
        embedding_dim=768,
        n_input_channels=3,
        n_conv_layers=1,
        kernel_size=7,
        stride=2,
        padding=3,
        pooling_kernel_size=3,
        pooling_stride=2,
        pooling_padding=1,
        *args, **kwargs
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 获取图像的高度和宽度
        img_height, img_width = pair(img_size)

        # 初始化 Tokenizer 对象
        self.tokenizer = Tokenizer(n_input_channels=n_input_channels,
                                   n_output_channels=embedding_dim,
                                   kernel_size=kernel_size,
                                   stride=stride,
                                   padding=padding,
                                   pooling_kernel_size=pooling_kernel_size,
                                   pooling_stride=pooling_stride,
                                   pooling_padding=pooling_padding,
                                   max_pool=True,
                                   activation=nn.ReLU,
                                   n_conv_layers=n_conv_layers,
                                   conv_bias=False)

        # 初始化 TransformerClassifier 对象
        self.classifier = TransformerClassifier(
            sequence_length=self.tokenizer.sequence_length(n_channels=n_input_channels,
                                                           height=img_height,
                                                           width=img_width),
            embedding_dim=embedding_dim,
            seq_pool=True,
            dropout_rate=0.,
            attention_dropout=0.1,
            stochastic_depth=0.1,
            *args, **kwargs)

    # 前向传播函数
    def forward(self, x):
        # 对输入数据进行编码
        x = self.tokenizer(x)
        # 使用 Transformer 进行分类
        return self.classifier(x)
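
# 补充使用示例(非原文件内容,参数取值与仓库 README 中的 CCT-14/7x2 配置相近,此处仅作演示):
# cct_14 等工厂函数固定了层数/头数/MLP 比例/嵌入维度,其余关键字参数会透传给 CCT 与 TransformerClassifier
model = cct_14(
    img_size = 224,
    n_conv_layers = 2,
    kernel_size = 7,
    stride = 2,
    padding = 3,
    pooling_kernel_size = 3,
    pooling_stride = 2,
    pooling_padding = 1,
    num_classes = 1000,
    positional_embedding = 'learnable'
)

img = torch.randn(1, 3, 224, 224)
pred = model(img)   # (1, 1000)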

.\lucidrains\vit-pytorch\vit_pytorch\cct_3d.py

import torch  # 导入 PyTorch 库
from torch import nn, einsum  # 从 PyTorch 库中导入 nn 模块和 einsum 函数
import torch.nn.functional as F  # 从 PyTorch 库中导入 F 模块

from einops import rearrange, repeat  # 从 einops 库中导入 rearrange 和 repeat 函数

# helpers

def exists(val):
    return val is not None  # 判断变量是否存在的辅助函数

def default(val, d):
    return val if exists(val) else d  # 如果变量存在则返回变量,否则返回默认值的辅助函数

def pair(t):
    return t if isinstance(t, tuple) else (t, t)  # 如果输入是元组则返回输入,否则返回包含输入两次的元组

# CCT Models

__all__ = ['cct_2', 'cct_4', 'cct_6', 'cct_7', 'cct_8', 'cct_14', 'cct_16']  # 定义导出的模型名称列表

# 定义不同层数的 CCT 模型函数

def cct_2(*args, **kwargs):
    return _cct(num_layers=2, num_heads=2, mlp_ratio=1, embedding_dim=128,
                *args, **kwargs)  # 返回 2 层 CCT 模型

def cct_4(*args, **kwargs):
    return _cct(num_layers=4, num_heads=2, mlp_ratio=1, embedding_dim=128,
                *args, **kwargs)  # 返回 4 层 CCT 模型

def cct_6(*args, **kwargs):
    return _cct(num_layers=6, num_heads=4, mlp_ratio=2, embedding_dim=256,
                *args, **kwargs)  # 返回 6 层 CCT 模型

def cct_7(*args, **kwargs):
    return _cct(num_layers=7, num_heads=4, mlp_ratio=2, embedding_dim=256,
                *args, **kwargs)  # 返回 7 层 CCT 模型

def cct_8(*args, **kwargs):
    return _cct(num_layers=8, num_heads=4, mlp_ratio=2, embedding_dim=256,
                *args, **kwargs)  # 返回 8 层 CCT 模型

def cct_14(*args, **kwargs):
    return _cct(num_layers=14, num_heads=6, mlp_ratio=3, embedding_dim=384,
                *args, **kwargs)  # 返回 14 层 CCT 模型

def cct_16(*args, **kwargs):
    return _cct(num_layers=16, num_heads=6, mlp_ratio=3, embedding_dim=384,
                *args, **kwargs)  # 返回 16 层 CCT 模型

# 定义 CCT 模型函数

def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
         kernel_size=3, stride=None, padding=None,
         *args, **kwargs):
    stride = default(stride, max(1, (kernel_size // 2) - 1))  # 设置默认的步长
    padding = default(padding, max(1, (kernel_size // 2)))  # 设置默认的填充大小

    return CCT(num_layers=num_layers,
               num_heads=num_heads,
               mlp_ratio=mlp_ratio,
               embedding_dim=embedding_dim,
               kernel_size=kernel_size,
               stride=stride,
               padding=padding,
               *args, **kwargs)  # 返回 CCT 模型

# positional

def sinusoidal_embedding(n_channels, dim):
    pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
                            for p in range(n_channels)])  # 计算正弦余弦位置编码
    pe[:, 0::2] = torch.sin(pe[:, 0::2])  # 偶数列使用正弦函数
    pe[:, 1::2] = torch.cos(pe[:, 1::2])  # 奇数列使用余弦函数
    return rearrange(pe, '... -> 1 ...')  # 重新排列位置编码的维度

# modules

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
        super().__init__()
        self.heads = num_heads  # 设置注意力头数
        head_dim = dim // self.heads  # 计算每个头的维度
        self.scale = head_dim ** -0.5  # 缩放因子

        self.qkv = nn.Linear(dim, dim * 3, bias=False)  # 线性变换层
        self.attn_drop = nn.Dropout(attention_dropout)  # 注意力丢弃层
        self.proj = nn.Linear(dim, dim)  # 投影层
        self.proj_drop = nn.Dropout(projection_dropout)  # 投影丢弃层

    def forward(self, x):
        B, N, C = x.shape  # 获取输入张量的形状

        qkv = self.qkv(x).chunk(3, dim = -1)  # 将线性变换后的张量切分为 Q、K、V
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)  # 重排 Q、K、V 的维度

        q = q * self.scale  # 缩放 Q

        attn = einsum('b h i d, b h j d -> b h i j', q, k)  # 计算注意力分数
        attn = attn.softmax(dim=-1)  # 对注意力分数进行 softmax
        attn = self.attn_drop(attn)  # 使用注意力丢弃层

        x = einsum('b h i j, b h j d -> b h i d', attn, v)  # 计算加权后的 V
        x = rearrange(x, 'b h n d -> b n (h d)')  # 重排输出张量的维度

        return self.proj_drop(self.proj(x))  # 使用投影丢弃层进行投影

class TransformerEncoderLayer(nn.Module):
    """
    Inspired by torch.nn.TransformerEncoderLayer and
    rwightman's timm package.
    """
    # 初始化函数,定义了 Transformer Encoder 层的结构
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 attention_dropout=0.1, drop_path_rate=0.1):
        # 调用父类的初始化函数
        super().__init__()

        # 对输入进行 Layer Normalization
        self.pre_norm = nn.LayerNorm(d_model)
        # 定义自注意力机制
        self.self_attn = Attention(dim=d_model, num_heads=nhead,
                                   attention_dropout=attention_dropout, projection_dropout=dropout)

        # 第一个线性层
        self.linear1  = nn.Linear(d_model, dim_feedforward)
        # 第一个 Dropout 层
        self.dropout1 = nn.Dropout(dropout)
        # 第一个 Layer Normalization 层
        self.norm1    = nn.LayerNorm(d_model)
        # 第二个线性层
        self.linear2  = nn.Linear(dim_feedforward, d_model)
        # 第二个 Dropout 层
        self.dropout2 = nn.Dropout(dropout)

        # DropPath 模块
        self.drop_path = DropPath(drop_path_rate)

        # 激活函数为 GELU
        self.activation = F.gelu

    # 前向传播函数
    def forward(self, src, *args, **kwargs):
        # 使用自注意力机制对输入进行处理,并加上 DropPath 模块
        src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
        # 对结果进行 Layer Normalization
        src = self.norm1(src)
        # 第二个线性层的计算过程
        src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
        # 将第二个线性层的结果加上 DropPath 模块
        src = src + self.drop_path(self.dropout2(src2))
        # 返回处理后的结果
        return src
class DropPath(nn.Module):
    # 初始化 DropPath 类
    def __init__(self, drop_prob=None):
        # 调用父类的初始化方法
        super().__init__()
        # 将传入的 drop_prob 转换为浮点数
        self.drop_prob = float(drop_prob)

    # 前向传播方法
    def forward(self, x):
        # 获取输入 x 的批量大小、drop_prob、设备和数据类型
        batch, drop_prob, device, dtype = x.shape[0], self.drop_prob, x.device, x.dtype

        # 如果 drop_prob 小于等于 0 或者不处于训练模式,则直接返回输入 x
        if drop_prob <= 0. or not self.training:
            return x

        # 计算保留概率
        keep_prob = 1 - self.drop_prob
        # 构建形状元组
        shape = (batch, *((1,) * (x.ndim - 1)))

        # 生成保留掩码
        keep_mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < keep_prob
        # 对输入 x 进行处理并返回输出
        output = x.div(keep_prob) * keep_mask.float()
        return output

class Tokenizer(nn.Module):
    # 初始化 Tokenizer 类
    def __init__(
        self,
        frame_kernel_size,
        kernel_size,
        stride,
        padding,
        frame_stride=1,
        frame_pooling_stride=1,
        frame_pooling_kernel_size=1,
        pooling_kernel_size=3,
        pooling_stride=2,
        pooling_padding=1,
        n_conv_layers=1,
        n_input_channels=3,
        n_output_channels=64,
        in_planes=64,
        activation=None,
        max_pool=True,
        conv_bias=False
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 构建卷积层的通道数列表
        n_filter_list = [n_input_channels] + \
                        [in_planes for _ in range(n_conv_layers - 1)] + \
                        [n_output_channels]

        # 构建通道数列表的配对
        n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])

        # 构建卷积层序列
        self.conv_layers = nn.Sequential(
            *[nn.Sequential(
                nn.Conv3d(chan_in, chan_out,
                          kernel_size=(frame_kernel_size, kernel_size, kernel_size),
                          stride=(frame_stride, stride, stride),
                          padding=(frame_kernel_size // 2, padding, padding), bias=conv_bias),
                nn.Identity() if not exists(activation) else activation(),
                nn.MaxPool3d(kernel_size=(frame_pooling_kernel_size, pooling_kernel_size, pooling_kernel_size),
                             stride=(frame_pooling_stride, pooling_stride, pooling_stride),
                             padding=(frame_pooling_kernel_size // 2, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
            )
                for chan_in, chan_out in n_filter_list_pairs
            ])

        # 对模型进行权重初始化
        self.apply(self.init_weight)

    # 计算序列长度
    def sequence_length(self, n_channels=3, frames=8, height=224, width=224):
        return self.forward(torch.zeros((1, n_channels, frames, height, width))).shape[1]

    # 前向传播方法
    def forward(self, x):
        # 对输入 x 进行卷积操作并返回重排后的输出
        x = self.conv_layers(x)
        return rearrange(x, 'b c f h w -> b (f h w) c')

    # 初始化权重方法
    @staticmethod
    def init_weight(m):
        if isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight)


class TransformerClassifier(nn.Module):
    # 初始化 TransformerClassifier 类
    def __init__(
        self,
        seq_pool=True,
        embedding_dim=768,
        num_layers=12,
        num_heads=12,
        mlp_ratio=4.0,
        num_classes=1000,
        dropout_rate=0.1,
        attention_dropout=0.1,
        stochastic_depth_rate=0.1,
        positional_embedding='sine',
        sequence_length=None,
        *args, **kwargs
    ):
        # 调用父类的构造函数
        super().__init__()
        # 断言位置编码在{'sine', 'learnable', 'none'}中
        assert positional_embedding in {'sine', 'learnable', 'none'}

        # 计算前馈网络的维度
        dim_feedforward = int(embedding_dim * mlp_ratio)
        self.embedding_dim = embedding_dim
        self.sequence_length = sequence_length
        self.seq_pool = seq_pool

        # 断言序列长度存在或者位置编码为'none'
        assert exists(sequence_length) or positional_embedding == 'none', \
            f"Positional embedding is set to {positional_embedding} and" \
            f" the sequence length was not specified."

        # 如果不使用序列池化
        if not seq_pool:
            sequence_length += 1
            self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim))
        else:
            self.attention_pool = nn.Linear(self.embedding_dim, 1)

        # 根据位置编码类型初始化位置编码
        if positional_embedding == 'none':
            self.positional_emb = None
        elif positional_embedding == 'learnable':
            self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim))
            nn.init.trunc_normal_(self.positional_emb, std=0.2)
        else:
            self.register_buffer('positional_emb', sinusoidal_embedding(sequence_length, embedding_dim))

        # 初始化Dropout层
        self.dropout = nn.Dropout(p=dropout_rate)

        # 生成随机Drop Path率
        dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]

        # 创建Transformer编码器层
        self.blocks = nn.ModuleList([
            TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
                                    dim_feedforward=dim_feedforward, dropout=dropout_rate,
                                    attention_dropout=attention_dropout, drop_path_rate=layer_dpr)
            for layer_dpr in dpr])

        # 初始化LayerNorm层
        self.norm = nn.LayerNorm(embedding_dim)

        # 初始化全连接层
        self.fc = nn.Linear(embedding_dim, num_classes)
        # 应用初始化权重函数
        self.apply(self.init_weight)

    @staticmethod
    def init_weight(m):
        # 初始化线性层的权重
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            # 如果是线性层且存在偏置项,则初始化偏置项
            if isinstance(m, nn.Linear) and exists(m.bias):
                nn.init.constant_(m.bias, 0)
        # 初始化LayerNorm层的权重和偏置项
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        # 获取批量大小
        b = x.shape[0]

        # 如果位置编码不存在且输入序列长度小于设定的序列长度,则进行填充
        if not exists(self.positional_emb) and x.size(1) < self.sequence_length:
            x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)  # 注意:self.n_channels 并未在 __init__ 中定义,疑似应为 self.sequence_length

        # 如果不使用序列池化,则在输入序列前添加类别标记
        if not self.seq_pool:
            cls_token = repeat(self.class_emb, '1 1 d -> b 1 d', b=b)
            x = torch.cat((cls_token, x), dim=1)

        # 如果位置编码存在,则加上位置编码
        if exists(self.positional_emb):
            x += self.positional_emb

        # Dropout层
        x = self.dropout(x)

        # 遍历Transformer编码器层
        for blk in self.blocks:
            x = blk(x)

        # LayerNorm层
        x = self.norm(x)

        # 如果使用序列池化,则计算注意力权重并进行加权求和
        if self.seq_pool:
            attn_weights = rearrange(self.attention_pool(x), 'b n 1 -> b n')
            x = einsum('b n, b n d -> b d', attn_weights.softmax(dim=1), x)
        else:
            x = x[:, 0]

        # 全连接层
        return self.fc(x)
# 定义 CCT 类,继承自 nn.Module
class CCT(nn.Module):
    # 初始化函数,设置模型参数
    def __init__(
        self,
        img_size=224,  # 图像大小,默认为 224
        num_frames=8,  # 帧数,默认为 8
        embedding_dim=768,  # 嵌入维度,默认为 768
        n_input_channels=3,  # 输入通道数,默认为 3
        n_conv_layers=1,  # 卷积层数,默认为 1
        frame_stride=1,  # 帧步长,默认为 1
        frame_kernel_size=3,  # 帧卷积核大小,默认为 3
        frame_pooling_kernel_size=1,  # 帧池化核大小,默认为 1
        frame_pooling_stride=1,  # 帧池化步长,默认为 1
        kernel_size=7,  # 卷积核大小,默认为 7
        stride=2,  # 步长,默认为 2
        padding=3,  # 填充,默认为 3
        pooling_kernel_size=3,  # 池化核大小,默认为 3
        pooling_stride=2,  # 池化步长,默认为 2
        pooling_padding=1,  # 池化填充,默认为 1
        *args, **kwargs  # 其他参数
    ):
        super().__init__()  # 调用父类的初始化函数

        img_height, img_width = pair(img_size)  # 获取图像的高度和宽度

        # 初始化 Tokenizer 对象
        self.tokenizer = Tokenizer(
            n_input_channels=n_input_channels,
            n_output_channels=embedding_dim,
            frame_stride=frame_stride,
            frame_kernel_size=frame_kernel_size,
            frame_pooling_stride=frame_pooling_stride,
            frame_pooling_kernel_size=frame_pooling_kernel_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            pooling_kernel_size=pooling_kernel_size,
            pooling_stride=pooling_stride,
            pooling_padding=pooling_padding,
            max_pool=True,
            activation=nn.ReLU,
            n_conv_layers=n_conv_layers,
            conv_bias=False
        )

        # 初始化 TransformerClassifier 对象
        self.classifier = TransformerClassifier(
            sequence_length=self.tokenizer.sequence_length(
                n_channels=n_input_channels,
                frames=num_frames,
                height=img_height,
                width=img_width
            ),
            embedding_dim=embedding_dim,
            seq_pool=True,
            dropout_rate=0.,
            attention_dropout=0.1,
            stochastic_depth=0.1,
            *args, **kwargs
        )

    # 前向传播函数
    def forward(self, x):
        x = self.tokenizer(x)  # 对输入数据进行编码
        return self.classifier(x)  # 对编码后的数据进行分类
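
# 补充使用示例(非原文件内容,仅作演示,取较小尺寸以便快速运行):
# 3D 版 CCT 以视频为输入,形状为 (batch, channels, frames, height, width)
model = cct_2(
    img_size = 64,
    num_frames = 4,
    n_conv_layers = 1,
    frame_kernel_size = 3,
    kernel_size = 7,
    stride = 2,
    padding = 3,
    pooling_kernel_size = 3,
    pooling_stride = 2,
    pooling_padding = 1,
    num_classes = 10,
    positional_embedding = 'learnable'
)

video = torch.randn(1, 3, 4, 64, 64)
pred = model(video)   # (1, 10)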

.\lucidrains\vit-pytorch\vit_pytorch\crossformer.py

import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
import torch.nn.functional as F

# 辅助函数

def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# 交叉嵌入层

class CrossEmbedLayer(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        kernel_sizes,
        stride = 2
    ):
        super().__init__()
        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # 计算每个尺度的维度
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]

        self.convs = nn.ModuleList([])
        for kernel, dim_scale in zip(kernel_sizes, dim_scales):
            self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))

    def forward(self, x):
        # 对输入进行卷积操作,并将结果拼接在一起
        fmaps = tuple(map(lambda conv: conv(x), self.convs))
        return torch.cat(fmaps, dim = 1)
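
# 补充示例(非原文件内容,变量名为演示自拟):多个不同大小的卷积核在相同 stride 下得到相同的空间分辨率,
# 各自负责输出通道的一部分,按通道拼接后得到完整的 dim_out
example_cel = CrossEmbedLayer(dim_in = 3, dim_out = 64, kernel_sizes = (4, 8, 16, 32), stride = 4)
example_out = example_cel(torch.randn(1, 3, 224, 224))
# example_out 形状为 (1, 64, 56, 56);各尺度的通道数依次为 [32, 16, 8, 8],和为 64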

# 动态位置偏置

def DynamicPositionBias(dim):
    return nn.Sequential(
        nn.Linear(2, dim),
        nn.LayerNorm(dim),
        nn.ReLU(),
        nn.Linear(dim, dim),
        nn.LayerNorm(dim),
        nn.ReLU(),
        nn.Linear(dim, dim),
        nn.LayerNorm(dim),
        nn.ReLU(),
        nn.Linear(dim, 1),
        Rearrange('... () -> ...')
    )
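
# 补充示例(非原文件内容,变量名为演示自拟):DynamicPositionBias 把二维相对位移 (dy, dx) 映射为一个标量偏置
example_dpb = DynamicPositionBias(dim = 32)
example_rel_pos = torch.randn(225, 2)   # 假设 window_size = 7,下方 Attention.forward 会构造 15*15 个相对位移
example_bias = example_dpb(example_rel_pos)
# example_bias 形状为 (225,),随后按预先计算好的 rel_pos_indices 索引出 (i, j) 偏置矩阵并加到注意力分数上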

# transformer 类

class LayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

def FeedForward(dim, mult = 4, dropout = 0.):
    return nn.Sequential(
        LayerNorm(dim),
        nn.Conv2d(dim, dim * mult, 1),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Conv2d(dim * mult, dim, 1)
    )

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        attn_type,
        window_size,
        dim_head = 32,
        dropout = 0.
    ):
        super().__init__()
        assert attn_type in {'short', 'long'}, 'attention type must be one of local or distant'
        heads = dim // dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.attn_type = attn_type
        self.window_size = window_size

        self.norm = LayerNorm(dim)

        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

        # 位置

        self.dpb = DynamicPositionBias(dim // 4)

        # 计算和存储用于检索偏置的索引

        pos = torch.arange(window_size)
        grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
        grid = rearrange(grid, 'c i j -> (i j) c')
        rel_pos = grid[:, None] - grid[None, :]
        rel_pos += window_size - 1
        rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
    # 定义前向传播函数,接受输入 x
    def forward(self, x):
        # 解构 x 的形状,获取高度、宽度、头数、窗口大小和设备信息
        *_, height, width, heads, wsz, device = *x.shape, self.heads, self.window_size, x.device

        # 对输入进行预处理
        x = self.norm(x)

        # 根据不同的注意力类型重新排列输入,以便进行短距离或长距离注意力
        if self.attn_type == 'short':
            x = rearrange(x, 'b d (h s1) (w s2) -> (b h w) d s1 s2', s1 = wsz, s2 = wsz)
        elif self.attn_type == 'long':
            x = rearrange(x, 'b d (l1 h) (l2 w) -> (b h w) d l1 l2', l1 = wsz, l2 = wsz)

        # 将输入转换为查询、键、值
        q, k, v = self.to_qkv(x).chunk(3, dim = 1)

        # 将查询、键、值按头数进行分割
        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), (q, k, v))
        q = q * self.scale

        # 计算注意力矩阵
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # 添加动态位置偏置
        pos = torch.arange(-wsz, wsz + 1, device = device)
        rel_pos = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
        rel_pos = rearrange(rel_pos, 'c i j -> (i j) c')
        biases = self.dpb(rel_pos.float())
        rel_pos_bias = biases[self.rel_pos_indices]
        sim = sim + rel_pos_bias

        # 注意力权重归一化
        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        # 合并头部
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = wsz, y = wsz)
        out = self.to_out(out)

        # 根据不同的注意力类型重新排列输出
        if self.attn_type == 'short':
            out = rearrange(out, '(b h w) d s1 s2 -> b d (h s1) (w s2)', h = height // wsz, w = width // wsz)
        elif self.attn_type == 'long':
            out = rearrange(out, '(b h w) d l1 l2 -> b d (l1 h) (l2 w)', h = height // wsz, w = width // wsz)

        return out
# 定义一个名为 Transformer 的神经网络模块
class Transformer(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(
        self,
        dim,
        *,
        local_window_size,
        global_window_size,
        depth = 4,
        dim_head = 32,
        attn_dropout = 0.,
        ff_dropout = 0.,
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 初始化一个空的神经网络模块列表
        self.layers = nn.ModuleList([])

        # 循环创建指定深度的神经网络层
        for _ in range(depth):
            # 每层包含两个注意力机制和两个前馈神经网络
            self.layers.append(nn.ModuleList([
                Attention(dim, attn_type = 'short', window_size = local_window_size, dim_head = dim_head, dropout = attn_dropout),
                FeedForward(dim, dropout = ff_dropout),
                Attention(dim, attn_type = 'long', window_size = global_window_size, dim_head = dim_head, dropout = attn_dropout),
                FeedForward(dim, dropout = ff_dropout)
            ]))

    # 前向传播函数
    def forward(self, x):
        # 遍历每一层的注意力机制和前馈神经网络
        for short_attn, short_ff, long_attn, long_ff in self.layers:
            # 执行短程注意力机制和前馈神经网络
            x = short_attn(x) + x
            x = short_ff(x) + x
            # 执行长程注意力机制和前馈神经网络
            x = long_attn(x) + x
            x = long_ff(x) + x

        # 返回处理后的数据
        return x

# 定义一个名为 CrossFormer 的神经网络模块
class CrossFormer(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(
        self,
        *,
        dim = (64, 128, 256, 512),
        depth = (2, 2, 8, 2),
        global_window_size = (8, 4, 2, 1),
        local_window_size = 7,
        cross_embed_kernel_sizes = ((4, 8, 16, 32), (2, 4), (2, 4), (2, 4)),
        cross_embed_strides = (4, 2, 2, 2),
        num_classes = 1000,
        attn_dropout = 0.,
        ff_dropout = 0.,
        channels = 3
    ):
        # 调用父类的初始化函数
        super().__init__()

        # 将参数转换为元组形式
        dim = cast_tuple(dim, 4)
        depth = cast_tuple(depth, 4)
        global_window_size = cast_tuple(global_window_size, 4)
        local_window_size = cast_tuple(local_window_size, 4)
        cross_embed_kernel_sizes = cast_tuple(cross_embed_kernel_sizes, 4)
        cross_embed_strides = cast_tuple(cross_embed_strides, 4)

        # 断言确保参数长度为4
        assert len(dim) == 4
        assert len(depth) == 4
        assert len(global_window_size) == 4
        assert len(local_window_size) == 4
        assert len(cross_embed_kernel_sizes) == 4
        assert len(cross_embed_strides) == 4

        # 定义维度相关变量
        last_dim = dim[-1]
        dims = [channels, *dim]
        dim_in_and_out = tuple(zip(dims[:-1], dims[1:]))

        # 初始化一个空的神经网络模块列表
        self.layers = nn.ModuleList([])

        # 循环创建交叉嵌入层和 Transformer 层
        for (dim_in, dim_out), layers, global_wsz, local_wsz, cel_kernel_sizes, cel_stride in zip(dim_in_and_out, depth, global_window_size, local_window_size, cross_embed_kernel_sizes, cross_embed_strides):
            self.layers.append(nn.ModuleList([
                CrossEmbedLayer(dim_in, dim_out, cel_kernel_sizes, stride = cel_stride),
                Transformer(dim_out, local_window_size = local_wsz, global_window_size = global_wsz, depth = layers, attn_dropout = attn_dropout, ff_dropout = ff_dropout)
            ]))

        # 定义最终的逻辑层
        self.to_logits = nn.Sequential(
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(last_dim, num_classes)
        )

    # forward pass
    def forward(self, x):
        # run the input through every stage
        for cel, transformer in self.layers:
            # cross-scale embedding
            x = cel(x)
            # transformer
            x = transformer(x)

        # pool and project to class logits
        return self.to_logits(x)
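
A minimal usage sketch (not part of the original source): the hyperparameter values mirror the constructor defaults above, and the 224 x 224 input size is an assumption chosen so that the stage feature maps (56, 28, 14, 7) stay divisible by both the local window size 7 and the global window sizes (8, 4, 2, 1).

# minimal usage sketch with assumed values
import torch

model = CrossFormer(
    num_classes = 1000,
    dim = (64, 128, 256, 512),
    depth = (2, 2, 8, 2),
    global_window_size = (8, 4, 2, 1),
    local_window_size = 7
)

img = torch.randn(1, 3, 224, 224)
logits = model(img)  # (1, 1000)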

.\lucidrains\vit-pytorch\vit_pytorch\cross_vit.py

# import torch
import torch
# import the nn module and the einsum function from torch
from torch import nn, einsum
# import torch.nn.functional under the usual alias F
import torch.nn.functional as F

# import rearrange and repeat from einops
from einops import rearrange, repeat
# import the Rearrange layer from einops.layers.torch
from einops.layers.torch import Rearrange

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return the value if it exists, otherwise fall back to the default
def default(val, d):
    return val if exists(val) else d

# feedforward network

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # pre-norm MLP: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# attention

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context = None, kv_include_self = False):
        b, n, _, h = *x.shape, self.heads
        x = self.norm(x)
        context = default(context, x)

        if kv_include_self:
            context = torch.cat((x, context), dim = 1) # cross attention requires the CLS token to include itself as a key / value

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
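
To make the kv_include_self path concrete, here is an illustrative sketch (the shapes and values are assumptions, not from the source): a single CLS token queries the patch tokens of the other branch while also attending to itself.

# illustrative sketch with assumed shapes
import torch

attn = Attention(dim = 192, heads = 8, dim_head = 64)
cls_token = torch.randn(1, 1, 192)       # query: one CLS token
patch_tokens = torch.randn(1, 64, 192)   # context: patch tokens from the other branch
out = attn(cls_token, context = patch_tokens, kv_include_self = True)  # (1, 1, 192)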

# transformer encoder, used for both the small and large patch tokens

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.norm = nn.LayerNorm(dim)
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

# project the CLS token in and out, in case the small and large patch tokens have different dimensions

class ProjectInOut(nn.Module):
    def __init__(self, dim_in, dim_out, fn):
        super().__init__()
        self.fn = fn

        need_projection = dim_in != dim_out
        self.project_in = nn.Linear(dim_in, dim_out) if need_projection else nn.Identity()
        self.project_out = nn.Linear(dim_out, dim_in) if need_projection else nn.Identity()

    def forward(self, x, *args, **kwargs):
        x = self.project_in(x)
        x = self.fn(x, *args, **kwargs)
        x = self.project_out(x)
        return x

# cross attention transformer

class CrossTransformer(nn.Module):
    def __init__(self, sm_dim, lg_dim, depth, heads, dim_head, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                ProjectInOut(sm_dim, lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                ProjectInOut(lg_dim, sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout))
            ]))
    # forward pass over the small-patch and large-patch token sequences
    def forward(self, sm_tokens, lg_tokens):
        # split each sequence into its CLS token and its patch tokens
        (sm_cls, sm_patch_tokens), (lg_cls, lg_patch_tokens) = map(lambda t: (t[:, :1], t[:, 1:]), (sm_tokens, lg_tokens))

        # in each layer, the CLS token of one branch attends to the patch tokens of the other
        for sm_attend_lg, lg_attend_sm in self.layers:
            # small CLS attends to the large patch tokens (including itself as key / value), with a residual connection
            sm_cls = sm_attend_lg(sm_cls, context=lg_patch_tokens, kv_include_self=True) + sm_cls
            # large CLS attends to the small patch tokens (including itself as key / value), with a residual connection
            lg_cls = lg_attend_sm(lg_cls, context=sm_patch_tokens, kv_include_self=True) + lg_cls

        # re-attach the small CLS token to its patch tokens
        sm_tokens = torch.cat((sm_cls, sm_patch_tokens), dim=1)
        # re-attach the large CLS token to its patch tokens
        lg_tokens = torch.cat((lg_cls, lg_patch_tokens), dim=1)
        # return the updated token sequences
        return sm_tokens, lg_tokens
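
A minimal sketch of how CrossTransformer is called (the dimensions, depths, and token counts here are assumptions): each input is a CLS token followed by its patch tokens, and both shapes are preserved on output.

# minimal sketch with assumed dimensions
import torch

cross = CrossTransformer(sm_dim = 192, lg_dim = 384, depth = 2, heads = 8, dim_head = 64, dropout = 0.)
sm_tokens = torch.randn(1, 1 + 256, 192)  # CLS + 256 small-patch tokens
lg_tokens = torch.randn(1, 1 + 64, 384)   # CLS + 64 large-patch tokens
sm_tokens, lg_tokens = cross(sm_tokens, lg_tokens)  # shapes preserved
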
# multi-scale encoder
class MultiScaleEncoder(nn.Module):
    def __init__(
        self,
        *,
        depth,  # number of multi-scale encoding blocks
        sm_dim,  # dimension of the small-patch (high resolution) branch
        lg_dim,  # dimension of the large-patch (low resolution) branch
        sm_enc_params,  # parameters for the small-patch transformer encoder
        lg_enc_params,  # parameters for the large-patch transformer encoder
        cross_attn_heads,  # number of cross attention heads
        cross_attn_depth,  # depth of the cross attention transformer
        cross_attn_dim_head = 64,  # dimension per cross attention head
        dropout = 0.  # dropout rate
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Transformer(dim = sm_dim, dropout = dropout, **sm_enc_params),  # small-patch transformer encoder
                Transformer(dim = lg_dim, dropout = dropout, **lg_enc_params),  # large-patch transformer encoder
                CrossTransformer(sm_dim = sm_dim, lg_dim = lg_dim, depth = cross_attn_depth, heads = cross_attn_heads, dim_head = cross_attn_dim_head, dropout = dropout)  # cross-scale transformer
            ]))

    def forward(self, sm_tokens, lg_tokens):
        for sm_enc, lg_enc, cross_attend in self.layers:
            sm_tokens, lg_tokens = sm_enc(sm_tokens), lg_enc(lg_tokens)  # encode each branch with its own transformer
            sm_tokens, lg_tokens = cross_attend(sm_tokens, lg_tokens)  # exchange information between the branches via cross attention

        return sm_tokens, lg_tokens

# patch-based image-to-token embedder
class ImageEmbedder(nn.Module):
    def __init__(
        self,
        *,
        dim,  # token dimension
        image_size,  # image size
        patch_size,  # patch size
        dropout = 0.  # dropout rate
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),  # split the image into flattened patches
            nn.LayerNorm(patch_dim),  # layer norm
            nn.Linear(patch_dim, dim),  # linear projection to the token dimension
            nn.LayerNorm(dim)  # layer norm
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))  # positional embedding
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))  # CLS token
        self.dropout = nn.Dropout(dropout)  # dropout layer

    def forward(self, img):
        x = self.to_patch_embedding(img)  # convert the image into patch embeddings
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)  # repeat the CLS token across the batch
        x = torch.cat((cls_tokens, x), dim=1)  # prepend the CLS token to the patch embeddings
        x += self.pos_embedding[:, :(n + 1)]  # add the positional embeddings

        return self.dropout(x)  # apply dropout and return
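
A quick sketch of ImageEmbedder on its own (the values are assumptions): a 256 x 256 image cut into 16 x 16 patches yields 256 patch tokens plus the prepended CLS token.

# quick sketch with assumed values
import torch

embedder = ImageEmbedder(dim = 192, image_size = 256, patch_size = 16)
tokens = embedder(torch.randn(1, 3, 256, 256))  # (1, 257, 192)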

# CrossViT
class CrossViT(nn.Module):
    def __init__(
        self,
        *,
        image_size,  # image size
        num_classes,  # number of classes
        sm_dim,  # dimension of the small-patch (high resolution) branch
        lg_dim,  # dimension of the large-patch (low resolution) branch
        sm_patch_size = 12,  # small patch size
        sm_enc_depth = 1,  # small-patch encoder depth
        sm_enc_heads = 8,  # small-patch encoder heads
        sm_enc_mlp_dim = 2048,  # small-patch encoder MLP dimension
        sm_enc_dim_head = 64,  # small-patch encoder head dimension
        lg_patch_size = 16,  # large patch size
        lg_enc_depth = 4,  # large-patch encoder depth
        lg_enc_heads = 8,  # large-patch encoder heads
        lg_enc_mlp_dim = 2048,  # large-patch encoder MLP dimension
        lg_enc_dim_head = 64,  # large-patch encoder head dimension
        cross_attn_depth = 2,  # cross attention depth
        cross_attn_heads = 8,  # cross attention heads
        cross_attn_dim_head = 64,  # cross attention head dimension
        depth = 3,  # number of multi-scale encoding blocks
        dropout = 0.1,  # dropout rate
        emb_dropout = 0.1  # embedding dropout rate
    ):
        # call the parent constructor
        super().__init__()
        # image embedder for the small-patch (high resolution) branch
        self.sm_image_embedder = ImageEmbedder(dim = sm_dim, image_size = image_size, patch_size = sm_patch_size, dropout = emb_dropout)
        # image embedder for the large-patch (low resolution) branch
        self.lg_image_embedder = ImageEmbedder(dim = lg_dim, image_size = image_size, patch_size = lg_patch_size, dropout = emb_dropout)

        # multi-scale encoder that alternates per-branch encoding and cross attention
        self.multi_scale_encoder = MultiScaleEncoder(
            depth = depth,
            sm_dim = sm_dim,
            lg_dim = lg_dim,
            cross_attn_heads = cross_attn_heads,
            cross_attn_dim_head = cross_attn_dim_head,
            cross_attn_depth = cross_attn_depth,
            sm_enc_params = dict(
                depth = sm_enc_depth,
                heads = sm_enc_heads,
                mlp_dim = sm_enc_mlp_dim,
                dim_head = sm_enc_dim_head
            ),
            lg_enc_params = dict(
                depth = lg_enc_depth,
                heads = lg_enc_heads,
                mlp_dim = lg_enc_mlp_dim,
                dim_head = lg_enc_dim_head
            ),
            dropout = dropout
        )

        # MLP head for the small-patch branch
        self.sm_mlp_head = nn.Sequential(nn.LayerNorm(sm_dim), nn.Linear(sm_dim, num_classes))
        # MLP head for the large-patch branch
        self.lg_mlp_head = nn.Sequential(nn.LayerNorm(lg_dim), nn.Linear(lg_dim, num_classes))

    # forward pass
    def forward(self, img):
        # embed the image into small-patch tokens
        sm_tokens = self.sm_image_embedder(img)
        # embed the image into large-patch tokens
        lg_tokens = self.lg_image_embedder(img)

        # run the multi-scale encoder over both token sequences
        sm_tokens, lg_tokens = self.multi_scale_encoder(sm_tokens, lg_tokens)

        # take the CLS token from each branch
        sm_cls, lg_cls = map(lambda t: t[:, 0], (sm_tokens, lg_tokens))

        # classify with the small-branch MLP head
        sm_logits = self.sm_mlp_head(sm_cls)
        # classify with the large-branch MLP head
        lg_logits = self.lg_mlp_head(lg_cls)

        # return the sum of the two branches' logits
        return sm_logits + lg_logits
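
Finally, a minimal usage sketch for CrossViT (the hyperparameter values are assumptions, chosen so that both patch sizes divide the image size; all other arguments fall back to the defaults above):

# minimal usage sketch with assumed values
import torch

v = CrossViT(
    image_size = 256,
    num_classes = 1000,
    sm_dim = 192,        # small-patch (high resolution) branch dimension
    lg_dim = 384,        # large-patch (low resolution) branch dimension
    sm_patch_size = 16,
    lg_patch_size = 32
)

img = torch.randn(1, 3, 256, 256)
logits = v(img)  # (1, 1000), sum of the two branch heads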