Lucidrains Series Source Code Walkthrough (110)

.\lucidrains\vit-pytorch\vit_pytorch\simple_vit_1d.py

import torch
from torch import nn

from einops import rearrange
from einops.layers.torch import Rearrange

# generate fixed 1d sinusoidal position embeddings
def posemb_sincos_1d(patches, temperature = 10000, dtype = torch.float32):
    # unpack the shape, device and dtype of the patch tokens
    _, n, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    # sequence of positions 0 .. n-1
    n = torch.arange(n, device = device)
    # the feature dimension must be even so it splits into sin / cos halves
    assert (dim % 2) == 0, 'feature dimension must be multiple of 2 for sincos emb'
    # frequencies scaled by the temperature
    omega = torch.arange(dim // 2, device = device) / (dim // 2 - 1)
    omega = 1. / (temperature ** omega)

    # outer product of positions and frequencies, then concatenate sin and cos
    n = n.flatten()[:, None] * omega[None, :]
    pe = torch.cat((n.sin(), n.cos()), dim = 1)
    return pe.type(dtype)

# feedforward block (pre-norm MLP)
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

# multi-head self attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# transformer encoder
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

# SimpleViT adapted to 1d sequences
class SimpleViT(nn.Module):
    def __init__(self, *, seq_len, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
        super().__init__()

        assert seq_len % patch_size == 0

        num_patches = seq_len // patch_size
        patch_dim = channels * patch_size

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (n p) -> b n (p c)', p = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.to_latent = nn.Identity()
        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, series):
        *_, n, dtype = *series.shape, series.dtype

        x = self.to_patch_embedding(series)
        pe = posemb_sincos_1d(x)
        x = rearrange(x, 'b ... d -> b (...) d') + pe

        x = self.transformer(x)
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        return self.linear_head(x)

# quick smoke test when run as a script
if __name__ == '__main__':

    v = SimpleViT(
        seq_len = 256,
        patch_size = 16,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 8,
        mlp_dim = 2048
    )

    # random time series: (batch, channels, sequence length)
    time_series = torch.randn(4, 3, 256)
    # feed the series through the model to obtain class logits
    logits = v(time_series) # (4, 1000)

.\lucidrains\vit-pytorch\vit_pytorch\simple_vit_3d.py

import torch
import torch.nn.functional as F
from torch import nn

from einops import rearrange
from einops.layers.torch import Rearrange

# helpers

# return t if it is already a tuple, otherwise duplicate it into a pair
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# generate 3d fixed sinusoidal position embeddings
def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
    # unpack the shape, device and dtype of the patch tokens
    _, f, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    # build a 3d grid of (frame, height, width) coordinates
    z, y, x = torch.meshgrid(
        torch.arange(f, device = device),
        torch.arange(h, device = device),
        torch.arange(w, device = device),
    indexing = 'ij')

    # each of the 6 sin/cos components gets dim // 6 features
    fourier_dim = dim // 6

    # frequencies scaled by the temperature
    omega = torch.arange(fourier_dim, device = device) / (fourier_dim - 1)
    omega = 1. / (temperature ** omega)

    # outer products of coordinates and frequencies
    z = z.flatten()[:, None] * omega[None, :]
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :] 

    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)

    # pad when the feature dimension is not divisible by 6
    pe = F.pad(pe, (0, dim - (fourier_dim * 6)))
    return pe.type(dtype)

# classes

# feedforward block
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

# multi-head self attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# transformer encoder
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

class SimpleViT(nn.Module):
    # set up the patch embedding, transformer and classification head
    def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
        super().__init__()
        # spatial size of the input frames
        image_height, image_width = pair(image_size)
        # spatial size of each patch
        patch_height, patch_width = pair(image_patch_size)

        # the image must tile exactly into patches
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        # the frame count must tile exactly into frame patches
        assert frames % frame_patch_size == 0, 'Frames must be divisible by the frame patch size'

        # total number of space-time patches
        num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
        # flattened dimension of one space-time patch
        patch_dim = channels * patch_height * patch_width * frame_patch_size

        # split the video into patches and project them to the model dimension
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # transformer encoder
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        # identity mapping kept as a hook before the classification head
        self.to_latent = nn.Identity()
        # linear classification head
        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, video):
        # unpack the video shape and dtype
        *_, h, w, dtype = *video.shape, video.dtype

        # embed the space-time patches
        x = self.to_patch_embedding(video)
        # fixed 3d sinusoidal position embedding
        pe = posemb_sincos_3d(x)
        # flatten the patch grid and add the position embedding
        x = rearrange(x, 'b ... d -> b (...) d') + pe

        # run through the transformer
        x = self.transformer(x)
        # mean pool over all tokens
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        # project to class logits
        return self.linear_head(x)
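
For reference, a minimal usage sketch in the style of the `__main__` block in simple_vit_1d.py; it is not part of the original file, and the hyperparameters are only illustrative (any frame count and image size divisible by the patch sizes will do):

if __name__ == '__main__':

    v = SimpleViT(
        image_size = 128,          # frame height and width
        image_patch_size = 16,     # spatial patch size
        frames = 16,               # number of frames per clip
        frame_patch_size = 2,      # temporal patch size
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 8,
        mlp_dim = 2048
    )

    video = torch.randn(4, 3, 16, 128, 128)  # (batch, channels, frames, height, width)
    preds = v(video)                          # (4, 1000)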

.\lucidrains\vit-pytorch\vit_pytorch\simple_vit_with_fft.py

import torch
# 2d fast fourier transform for the frequency branch
from torch.fft import fft2
from torch import nn

from einops import rearrange, reduce, pack, unpack
from einops.layers.torch import Rearrange

# helpers

# return t if it is already a tuple, otherwise duplicate it into a pair
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# generate 2d fixed sinusoidal position embeddings
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    # 2d grid of (y, x) coordinates
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    # the feature dimension must split evenly into 4 sin/cos components
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    # frequencies scaled by the temperature
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature ** omega)

    # outer products of coordinates and frequencies, concatenated as sin / cos
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)

# classes

# feedforward block
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

# multi-head self attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# transformer encoder
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

# SimpleViT with an extra frequency (FFT) branch
class SimpleViT(nn.Module):
    def __init__(self, *, image_size, patch_size, freq_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
        super().__init__()
        # spatial size of the input image
        image_height, image_width = pair(image_size)
        # patch size for the pixel branch
        patch_height, patch_width = pair(patch_size)
        # patch size for the frequency branch
        freq_patch_height, freq_patch_width = pair(freq_patch_size)

        # the image must tile exactly into pixel patches
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        # and also into frequency patches
        assert image_height % freq_patch_height == 0 and image_width % freq_patch_width == 0, 'Image dimensions must be divisible by the freq patch size.'

        # flattened dimension of one pixel patch
        patch_dim = channels * patch_height * patch_width
        # flattened dimension of one frequency patch (x2 for the real and imaginary parts)
        freq_patch_dim = channels * 2 * freq_patch_height * freq_patch_width

        # embed the pixel patches
        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # embed the frequency patches
        self.to_freq_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) ri -> b (h w) (p1 p2 ri c)", p1 = freq_patch_height, p2 = freq_patch_width),
            nn.LayerNorm(freq_patch_dim),
            nn.Linear(freq_patch_dim, dim),
            nn.LayerNorm(dim)
        )

        # fixed 2d sinusoidal position embedding for the pixel branch
        self.pos_embedding = posemb_sincos_2d(
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        )

        # fixed 2d sinusoidal position embedding for the frequency branch
        self.freq_pos_embedding = posemb_sincos_2d(
            h = image_height // freq_patch_height,
            w = image_width // freq_patch_width,
            dim = dim
        )

        # transformer encoder shared by both branches
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        # mean pooling over the pixel tokens
        self.pool = "mean"
        # identity mapping kept as a hook before the classification head
        self.to_latent = nn.Identity()

        # linear classification head
        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        device, dtype = img.device, img.dtype

        # embed the pixel patches
        x = self.to_patch_embedding(img)
        # 2d FFT of the image, real and imaginary parts stacked in a trailing dimension
        freqs = torch.view_as_real(fft2(img))

        # embed the frequency patches
        f = self.to_freq_embedding(freqs)

        # add the position embeddings of both branches
        x += self.pos_embedding.to(device, dtype = dtype)
        f += self.freq_pos_embedding.to(device, dtype = dtype)

        # concatenate frequency tokens and pixel tokens into one sequence
        x, ps = pack((f, x), 'b * d')

        # run through the transformer
        x = self.transformer(x)

        # split the sequence back and keep only the pixel tokens
        _, x = unpack(x, ps, 'b * d')
        # mean pool over the pixel tokens
        x = reduce(x, 'b n d -> b d', 'mean')

        x = self.to_latent(x)
        # project to class logits
        return self.linear_head(x)
# quick smoke test when run as a script
if __name__ == '__main__':
    # instantiate the model with both an image patch size and a frequency patch size
    vit = SimpleViT(
        num_classes = 1000,
        image_size = 256,
        patch_size = 8,
        freq_patch_size = 8,
        dim = 1024,
        depth = 1,
        heads = 8,
        mlp_dim = 2048,
    )

    # a batch of 8 random images, 3 channels, 256 x 256
    images = torch.randn(8, 3, 256, 256)

    # forward pass yields the class logits
    logits = vit(images) # (8, 1000)

.\lucidrains\vit-pytorch\vit_pytorch\simple_vit_with_patch_dropout.py

import torch
from torch import nn

from einops import rearrange
from einops.layers.torch import Rearrange

# helpers

# return t if it is already a tuple, otherwise duplicate it into a pair
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# generate 2d fixed sinusoidal position embeddings from the patch grid
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
    # unpack the shape, device and dtype of the patch tokens
    _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    # 2d grid of (y, x) coordinates
    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
    # the feature dimension must split evenly into 4 sin/cos components
    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
    # frequencies scaled by the temperature
    omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
    omega = 1. / (temperature ** omega)

    # outer products of coordinates and frequencies, concatenated as sin / cos
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :] 
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    return pe.type(dtype)

# patch dropout

# randomly drops a fraction of the patch tokens during training (keeps all of them at eval time)
class PatchDropout(nn.Module):
    def __init__(self, prob):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob

    def forward(self, x):
        # no-op when evaluating or when the drop probability is zero
        if not self.training or self.prob == 0.:
            return x

        # unpack the shape and device of the token tensor
        b, n, _, device = *x.shape, x.device

        # batch indices used for gathering the kept patches
        batch_indices = torch.arange(b, device = device)
        batch_indices = rearrange(batch_indices, '... -> ... 1')
        # number of patches to keep
        num_patches_keep = max(1, int(n * (1 - self.prob)))
        # pick a random subset of patch indices for every sample
        patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices

        return x[batch_indices, patch_indices_keep]

# classes

# feedforward block
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

# multi-head self attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# transformer encoder
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

# SimpleViT with patch dropout
class SimpleViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, patch_dropout = 0.5):
        super().__init__()
        # spatial size of the input image
        image_height, image_width = pair(image_size)
        # spatial size of each patch
        patch_height, patch_width = pair(patch_size)

        # the image must tile exactly into patches
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # total number of patches and the flattened dimension of one patch
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width

        # split the image into patches and project them to the model dimension
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        # randomly drop patch tokens during training
        self.patch_dropout = PatchDropout(patch_dropout)

        # transformer encoder
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        # identity mapping kept as a hook before the classification head
        self.to_latent = nn.Identity()
        # linear classification head
        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        # unpack the image shape and dtype
        *_, h, w, dtype = *img.shape, img.dtype

        # embed the patches
        x = self.to_patch_embedding(img)
        # fixed 2d sinusoidal position embedding
        pe = posemb_sincos_2d(x)
        # flatten the patch grid and add the position embedding
        x = rearrange(x, 'b ... d -> b (...) d') + pe

        # drop a subset of the patch tokens
        x = self.patch_dropout(x)

        # run through the transformer
        x = self.transformer(x)
        # mean pool over the remaining tokens
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        # project to class logits
        return self.linear_head(x)
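
A minimal usage sketch (not part of the original file; the values are illustrative). Note that patch dropout only takes effect while the module is in training mode:

if __name__ == '__main__':

    v = SimpleViT(
        image_size = 256,
        patch_size = 16,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 8,
        mlp_dim = 2048,
        patch_dropout = 0.5    # drop half of the patch tokens during training
    )

    images = torch.randn(4, 3, 256, 256)
    logits = v(images)  # (4, 1000)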

.\lucidrains\vit-pytorch\vit_pytorch\simple_vit_with_qk_norm.py

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange
from einops.layers.torch import Rearrange

# helpers

# return t if it is already a tuple, otherwise duplicate it into a pair
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# generate 2d fixed sinusoidal position embeddings
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)

# they use a query-key normalization that is equivalent to rms norm (no mean-centering, learned gamma), from vit 22B paper

# in latest tweet, seem to claim more stable training at higher learning rates
# unsure if this has taken off within Brain, or it has some hidden drawback

# RMS normalization with a learned per-head gamma (no mean-centering)
class RMSNorm(nn.Module):
    def __init__(self, heads, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, 1, dim) / self.scale)

    def forward(self, x):
        normed = F.normalize(x, dim = -1)
        return normed * self.scale * self.gamma

# classes

# feedforward block
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

# multi-head self attention with RMS-normalized queries and keys
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.q_norm = RMSNorm(heads, dim_head)
        self.k_norm = RMSNorm(heads, dim_head)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        q = self.q_norm(q)
        k = self.k_norm(k)

        dots = torch.matmul(q, k.transpose(-1, -2))

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# transformer encoder
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

class SimpleViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
        super().__init__()
        # spatial size of the input image and of each patch
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        # the image must tile exactly into patches
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # flattened dimension of one patch
        patch_dim = channels * patch_height * patch_width

        # split the image into patches and project them to the model dimension
        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # fixed 2d sinusoidal position embedding
        self.pos_embedding = posemb_sincos_2d(
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        )

        # transformer encoder
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        # mean pooling over the patch tokens
        self.pool = "mean"
        # identity mapping kept as a hook before the classification head
        self.to_latent = nn.Identity()

        # classification head: a final layer norm followed by a linear projection to the class logits
        self.linear_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        device = img.device

        # embed the patches
        x = self.to_patch_embedding(img)
        # add the fixed position embedding
        x += self.pos_embedding.to(device, dtype=x.dtype)

        # run through the transformer
        x = self.transformer(x)
        # mean pool over all tokens
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        # project to class logits
        return self.linear_head(x)
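
A minimal usage sketch (not part of the original file; hyperparameters are illustrative). The model is called exactly like the plain SimpleViT, with the q / k normalization applied internally:

if __name__ == '__main__':

    v = SimpleViT(
        image_size = 256,
        patch_size = 16,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 8,
        mlp_dim = 2048
    )

    images = torch.randn(4, 3, 256, 256)
    logits = v(images)  # (4, 1000)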

.\lucidrains\vit-pytorch\vit_pytorch\simple_vit_with_register_tokens.py

"""
    Vision Transformers Need Registers
    https://arxiv.org/abs/2309.16588
"""

import torch
from torch import nn

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

# helpers

# return t if it is already a tuple, otherwise duplicate it into a pair
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# generate 2d fixed sinusoidal position embeddings
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w))
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)

# classes

# feedforward block
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

# multi-head self attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# transformer encoder
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

# SimpleViT with register tokens
class SimpleViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_register_tokens = 4, channels = 3, dim_head = 64):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))

        self.pos_embedding = posemb_sincos_2d(
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        ) 

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.pool = "mean"
        self.to_latent = nn.Identity()

        self.linear_head = nn.Linear(dim, num_classes)
    def forward(self, img):
        batch, device = img.shape[0], img.device

        # embed the patches
        x = self.to_patch_embedding(img)
        # add the fixed position embedding
        x += self.pos_embedding.to(device, dtype=x.dtype)

        # repeat the learned register tokens for every sample in the batch
        r = repeat(self.register_tokens, 'n d -> b n d', b = batch)

        # pack patch tokens and register tokens into one sequence
        x, ps = pack([x, r], 'b * d')

        # run through the transformer
        x = self.transformer(x)

        # unpack and keep only the patch tokens
        x, _ = unpack(x, ps, 'b * d')

        # mean pool over the patch tokens
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        # project to class logits
        return self.linear_head(x)
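
A minimal usage sketch (not part of the original file; hyperparameters are illustrative). The register tokens participate in attention but are discarded before pooling:

if __name__ == '__main__':

    v = SimpleViT(
        image_size = 256,
        patch_size = 16,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 8,
        mlp_dim = 2048,
        num_register_tokens = 4   # extra learned tokens appended to the patch sequence
    )

    images = torch.randn(4, 3, 256, 256)
    logits = v(images)  # (4, 1000)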

.\lucidrains\vit-pytorch\vit_pytorch\t2t.py

import math
import torch
from torch import nn

# reuse the Transformer from the base ViT implementation
from vit_pytorch.vit import Transformer

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helper to check whether a value is not None
def exists(val):
    return val is not None

# compute the spatial output size of a convolution / unfold
def conv_output_size(image_size, kernel_size, stride, padding):
    return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)

# reshape a flat sequence of tokens back into a square feature map
class RearrangeImage(nn.Module):
    def forward(self, x):
        return rearrange(x, 'b (h w) c -> b c h w', h = int(math.sqrt(x.shape[1])))

# the main T2TViT class (Tokens-to-Token ViT)
class T2TViT(nn.Module):
    def __init__(self, *, image_size, num_classes, dim, depth = None, heads = None, mlp_dim = None, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., transformer = None, t2t_layers = ((7, 4), (3, 2), (3, 2))):
        super().__init__()
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        layers = []
        layer_dim = channels
        output_image_size = image_size

        # build the tokens-to-token stages from the (kernel_size, stride) pairs
        for i, (kernel_size, stride) in enumerate(t2t_layers):
            layer_dim *= kernel_size ** 2
            is_first = i == 0
            is_last = i == (len(t2t_layers) - 1)
            output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2)

            # soft-split with Unfold, re-tokenize, then run a lightweight transformer (skipped on the last stage)
            layers.extend([
                RearrangeImage() if not is_first else nn.Identity(),
                nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2),
                Rearrange('b c n -> b n c'),
                Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout) if not is_last else nn.Identity(),
            ])

        layers.append(nn.Linear(layer_dim, dim))
        self.to_patch_embedding = nn.Sequential(*layers)

        # learned position embedding and cls token
        self.pos_embedding = nn.Parameter(torch.randn(1, output_image_size ** 2 + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        # use the supplied transformer if given, otherwise build one from depth / heads / mlp_dim
        if not exists(transformer):
            assert all([exists(depth), exists(heads), exists(mlp_dim)]), 'depth, heads, and mlp_dim must be supplied'
            self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        else:
            self.transformer = transformer

        self.pool = pool
        self.to_latent = nn.Identity()

        # MLP classification head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :n+1]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
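
A minimal usage sketch (not part of the original file; the hyperparameters are illustrative):

if __name__ == '__main__':

    v = T2TViT(
        dim = 512,
        image_size = 224,
        depth = 5,
        heads = 8,
        mlp_dim = 512,
        num_classes = 1000,
        t2t_layers = ((7, 4), (3, 2), (3, 2))  # (kernel size, stride) of each tokens-to-token stage
    )

    img = torch.randn(1, 3, 224, 224)
    preds = v(img)  # (1, 1000)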

.\lucidrains\vit-pytorch\vit_pytorch\twins_svt.py

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helper methods

# split a dict into two dicts depending on whether a predicate holds for each key
def group_dict_by_key(cond, d):
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

# collect kwargs that start with a given prefix, with the prefix stripped
def group_by_key_prefix_and_remove_prefix(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: x.startswith(prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs

# classes

# residual connection
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

# layer norm over the channel dimension of a feature map
class LayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

# feedforward block built from 1x1 convolutions
class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            LayerNorm(dim),
            nn.Conv2d(dim, dim * mult, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(dim * mult, dim, 1),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# patch embedding
class PatchEmbedding(nn.Module):
    def __init__(self, *, dim, dim_out, patch_size):
        super().__init__()
        self.dim = dim
        self.dim_out = dim_out
        self.patch_size = patch_size

        self.proj = nn.Sequential(
            LayerNorm(patch_size ** 2 * dim),
            nn.Conv2d(patch_size ** 2 * dim, dim_out, 1),
            LayerNorm(dim_out)
        )

    def forward(self, fmap):
        p = self.patch_size
        fmap = rearrange(fmap, 'b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = p, p2 = p)
        return self.proj(fmap)

# PEG - positional encoding generator (a depthwise convolution applied as a residual)
class PEG(nn.Module):
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        self.proj = Residual(nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1))

    def forward(self, x):
        return self.proj(x)

# local attention within non-overlapping patch windows
class LocalAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., patch_size = 7):
        super().__init__()
        inner_dim = dim_head *  heads
        self.patch_size = patch_size
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = LayerNorm(dim)
        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, fmap):
        fmap = self.norm(fmap)

        shape, p = fmap.shape, self.patch_size
        b, n, x, y, h = *shape, self.heads
        x, y = map(lambda t: t // p, (x, y))

        fmap = rearrange(fmap, 'b c (x p1) (y p2) -> (b x y) c p1 p2', p1 = p, p2 = p)

        q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1))
        q, k, v = map(lambda t: rearrange(t, 'b (h d) p1 p2 -> (b h) (p1 p2) d', h = h), (q, k, v))

        dots = einsum('b i d, b j d -> b i j', q, k) * self.scale

        attn = dots.softmax(dim = - 1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b x y h) (p1 p2) d -> b (h d) (x p1) (y p2)', h = h, x = x, y = y, p1 = p, p2 = p)
        return self.to_out(out)

# global (sub-sampled) attention
class GlobalAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., k = 7):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = LayerNorm(dim)

        # queries come from every spatial position
        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        # keys / values are subsampled by a strided k x k convolution
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, k, stride = k, bias = False)

        self.dropout = nn.Dropout(dropout)

        # output projection
        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = self.norm(x)

        shape = x.shape
        b, n, _, y, h = *shape, self.heads
        # project to queries, keys and values
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))

        # flatten the spatial dimensions and split out the heads
        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))

        # scaled dot products
        dots = einsum('b i d, b j d -> b i j', q, k) * self.scale

        # attention weights
        attn = dots.softmax(dim = -1)
        attn = self.dropout(attn)

        # aggregate the values and restore the feature map layout
        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
        return self.to_out(out)
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads = 8, dim_head = 64, mlp_mult = 4, local_patch_size = 7, global_k = 7, dropout = 0., has_local = True):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            # each block: (optional) local attention + feedforward, then global attention + feedforward
            self.layers.append(nn.ModuleList([
                Residual(LocalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, patch_size = local_patch_size)) if has_local else nn.Identity(),
                Residual(FeedForward(dim, mlp_mult, dropout = dropout)) if has_local else nn.Identity(),
                Residual(GlobalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, k = global_k)),
                Residual(FeedForward(dim, mlp_mult, dropout = dropout))
            ]))

    def forward(self, x):
        for local_attn, ff1, global_attn, ff2 in self.layers:
            x = local_attn(x)
            x = ff1(x)
            x = global_attn(x)
            x = ff2(x)
        return x

class TwinsSVT(nn.Module):
    def __init__(
        self,
        *,
        num_classes,
        s1_emb_dim = 64,
        s1_patch_size = 4,
        s1_local_patch_size = 7,
        s1_global_k = 7,
        s1_depth = 1,
        s2_emb_dim = 128,
        s2_patch_size = 2,
        s2_local_patch_size = 7,
        s2_global_k = 7,
        s2_depth = 1,
        s3_emb_dim = 256,
        s3_patch_size = 2,
        s3_local_patch_size = 7,
        s3_global_k = 7,
        s3_depth = 5,
        s4_emb_dim = 512,
        s4_patch_size = 2,
        s4_local_patch_size = 7,
        s4_global_k = 7,
        s4_depth = 4,
        peg_kernel_size = 3,
        dropout = 0.
    ):
        super().__init__()
        # capture all constructor arguments so they can be grouped per stage
        kwargs = dict(locals())

        dim = 3
        layers = []

        # build the four stages, pulling out the arguments prefixed with s1_ .. s4_
        for prefix in ('s1', 's2', 's3', 's4'):
            config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs)
            is_last = prefix == 's4'

            dim_next = config['emb_dim']

            # each stage: patch embedding, one transformer block, the PEG position encoder, then the remaining transformer blocks
            layers.append(nn.Sequential(
                PatchEmbedding(dim = dim, dim_out = dim_next, patch_size = config['patch_size']),
                Transformer(dim = dim_next, depth = 1, local_patch_size = config['local_patch_size'], global_k = config['global_k'], dropout = dropout, has_local = not is_last),
                PEG(dim = dim_next, kernel_size = peg_kernel_size),
                Transformer(dim = dim_next, depth = config['depth'],  local_patch_size = config['local_patch_size'], global_k = config['global_k'], dropout = dropout, has_local = not is_last)
            ))

            dim = dim_next

        # stack the stages, then global average pool and project to class logits
        self.layers = nn.Sequential(
            *layers,
            nn.AdaptiveAvgPool2d(1),
            Rearrange('... () () -> ...'),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        return self.layers(x)
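
A minimal usage sketch (not part of the original file), keeping every stage hyperparameter at its default; the 224 x 224 input is illustrative and must stay divisible by the local patch size and global k at every stage:

if __name__ == '__main__':

    v = TwinsSVT(num_classes = 1000)   # all s1_ .. s4_ settings keep the defaults above

    img = torch.randn(1, 3, 224, 224)
    preds = v(img)  # (1, 1000)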

.\lucidrains\vit-pytorch\vit_pytorch\vit.py

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helper: return t if it is already a tuple, otherwise duplicate it into a pair
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# feedforward block (pre-norm MLP with dropout)
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # layer norm, expand, GELU activation, project back, with dropout after each linear
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    # forward pass
    def forward(self, x):
        return self.net(x)

# multi-head self attention with dropout
class Attention(nn.Module):
    # arguments: model dimension, number of heads, per-head dimension, dropout
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    # forward pass
    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# transformer encoder
class Transformer(nn.Module):
    # stacks `depth` blocks of attention followed by a feedforward
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    # forward pass
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

# the standard ViT, with a learned position embedding and a cls token
class ViT(nn.Module):
    # keyword-only constructor: image / patch geometry, model dimensions, pooling mode and dropout rates
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)
    def forward(self, img):
        # embed the patches
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # prepend a cls token for every sample in the batch
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        # add the learned position embedding
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        # run through the transformer
        x = self.transformer(x)

        # either mean pool all tokens or take the cls token
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        # project to class logits
        return self.mlp_head(x)
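
A minimal usage sketch (not part of the original file; hyperparameters are illustrative):

if __name__ == '__main__':

    v = ViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

    img = torch.randn(1, 3, 256, 256)
    preds = v(img)  # (1, 1000)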

.\lucidrains\vit-pytorch\vit_pytorch\vit_1d.py

import torch
from torch import nn

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

# feedforward block (pre-norm MLP with dropout)
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),          # pre-normalization
            nn.Linear(dim, hidden_dim), # expand
            nn.GELU(),                  # GELU activation
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim), # project back
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# multi-head self attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)       # pre-normalization
        self.attend = nn.Softmax(dim = -1)  # attention softmax
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)  # joint projection to q, k, v

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),  # project back to the model dimension
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)  # split the joint projection into q, k, v
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)  # split out the heads

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # scaled dot products

        attn = self.attend(dots)   # attention weights
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)  # weighted sum of the values
        out = rearrange(out, 'b h n d -> b n (h d)')  # merge the heads
        return self.to_out(out)

# transformer encoder
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

# the 1d ViT, classified from a cls token
class ViT(nn.Module):
    def __init__(self, *, seq_len, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert (seq_len % patch_size) == 0

        num_patches = seq_len // patch_size
        patch_dim = channels * patch_size

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (n p) -> b n (p c)', p = patch_size),  # split the sequence into patches
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),   # project patches to the model dimension
            nn.LayerNorm(dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))  # learned position embedding
        self.cls_token = nn.Parameter(torch.randn(dim))                          # learned cls token
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)  # classification head
        )

    def forward(self, series):
        x = self.to_patch_embedding(series)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)  # one cls token per sample

        x, ps = pack([cls_tokens, x], 'b * d')  # prepend the cls token to the patch tokens

        x += self.pos_embedding[:, :(n + 1)]    # add the learned position embedding
        x = self.dropout(x)

        x = self.transformer(x)

        cls_tokens, _ = unpack(x, ps, 'b * d')  # recover the cls token

        return self.mlp_head(cls_tokens)        # classify from the cls token

if __name__ == '__main__':

    v = ViT(
        seq_len = 256,
        patch_size = 16,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 8,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

    time_series = torch.randn(4, 3, 256)
    logits = v(time_series) # (4, 1000)

.\lucidrains\vit-pytorch\vit_pytorch\vit_3d.py

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)  # return t if it is already a tuple, otherwise duplicate it into a pair

# classes

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),          # pre-normalization
            nn.Linear(dim, hidden_dim), # expand
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim), # project back
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)       # pre-normalization
        self.attend = nn.Softmax(dim = -1)  # attention softmax
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)  # joint projection to q, k, v

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),  # project back to the model dimension
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()  # identity when a projection is unnecessary

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)  # split the joint projection into q, k, v
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)  # split out the heads

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # scaled dot products

        attn = self.attend(dots)   # attention weights
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)  # weighted sum of the values
        out = rearrange(out, 'b h n d -> b n (h d)')  # merge the heads
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),  # attention block
                FeedForward(dim, mlp_dim, dropout = dropout)                            # feedforward block
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x  # residual around attention
            x = ff(x) + x    # residual around the feedforward
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(image_patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'

        num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
        patch_dim = channels * patch_height * patch_width * frame_patch_size

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),  # rearrange the video into flattened spatio-temporal patches
            nn.LayerNorm(patch_dim),  # layer normalization over the flattened patches
            nn.Linear(patch_dim, dim),  # linear projection to the model dimension
            nn.LayerNorm(dim),  # layer normalization over the embeddings
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))  # learned positional embedding
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))  # class token
        self.dropout = nn.Dropout(emb_dropout)  # embedding dropout

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)  # Transformer encoder

        self.pool = pool  # pooling strategy
        self.to_latent = nn.Identity()  # identity mapping into the latent space

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),  # layer normalization
            nn.Linear(dim, num_classes)  # linear classification layer
        )  # MLP head
    # forward pass, taking a video tensor as input
    def forward(self, video):
        # convert the video into patch embeddings
        x = self.to_patch_embedding(video)
        # get the batch size, number of patches and embedding dimension
        b, n, _ = x.shape

        # repeat the class token to match the batch size
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        # prepend the class token to the patch embeddings
        x = torch.cat((cls_tokens, x), dim=1)
        # add the positional embedding
        x += self.pos_embedding[:, :(n + 1)]
        # apply embedding dropout
        x = self.dropout(x)

        # run the tokens through the Transformer
        x = self.transformer(x)

        # pool either by averaging over tokens or by taking the class token
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        # map into the latent space (identity here)
        x = self.to_latent(x)
        # classify with the MLP head
        return self.mlp_head(x)
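
The file itself does not ship a usage example, so here is a minimal sketch consistent with the constructor defined above; the hyperparameter values are illustrative and the import path assumes the standard vit-pytorch package layout.

import torch
from vit_pytorch.vit_3d import ViT

v = ViT(
    image_size = 128,          # spatial size of each frame
    frames = 16,               # number of frames
    image_patch_size = 16,     # spatial patch size
    frame_patch_size = 2,      # temporal patch size
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

video = torch.randn(4, 3, 16, 128, 128)  # (batch, channels, frames, height, width)
preds = v(video)  # (4, 1000)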

.\lucidrains\vit-pytorch\vit_pytorch\vit_for_small_dataset.py

# import sqrt from the math module
from math import sqrt
# import the torch module
import torch
# import the functional module and the nn module from torch
import torch.nn.functional as F
from torch import nn
# import the rearrange and repeat functions from einops, and the Rearrange layer from einops.layers.torch
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helper function pair: return t if it is a tuple, otherwise return (t, t)
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# FeedForward class, inheriting from nn.Module
class FeedForward(nn.Module):
    # constructor, taking the dimension dim, the hidden dimension hidden_dim and a dropout rate
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # define the network structure
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    # forward pass
    def forward(self, x):
        return self.net(x)

# LSA (locality self-attention) class, inheriting from nn.Module
class LSA(nn.Module):
    # constructor, taking the dimension dim, number of heads, head dimension dim_head and a dropout rate
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.temperature = nn.Parameter(torch.log(torch.tensor(dim_head ** -0.5)))

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    # forward pass
    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.temperature.exp()

        mask = torch.eye(dots.shape[-1], device = dots.device, dtype = torch.bool)
        mask_value = -torch.finfo(dots.dtype).max
        dots = dots.masked_fill(mask, mask_value)

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
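
Compared to a standard attention block, LSA learns the softmax temperature (stored as a log value and exponentiated) and masks the diagonal of the attention matrix so that a token cannot attend to itself. Below is a minimal standalone sketch of that masking step, not taken from the file, using toy shapes.

import torch

# toy attention logits of shape (batch, heads, n, n)
dots = torch.randn(1, 1, 4, 4)

# fill the diagonal with a very large negative value before the softmax
mask = torch.eye(dots.shape[-1], dtype = torch.bool)
dots = dots.masked_fill(mask, -torch.finfo(dots.dtype).max)

attn = dots.softmax(dim = -1)
print(attn[0, 0].diag())  # diagonal entries are ~0, so each token ignores itself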

# Transformer class, inheriting from nn.Module
class Transformer(nn.Module):
    # constructor, taking the dimension dim, depth, number of heads, head dimension dim_head, MLP dimension mlp_dim and a dropout rate
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
    # forward pass
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

# SPT (shifted patch tokenization) class, inheriting from nn.Module
class SPT(nn.Module):
    # constructor, taking the dimension dim, patch size patch_size and number of channels
    def __init__(self, *, dim, patch_size, channels = 3):
        super().__init__()
        patch_dim = patch_size * patch_size * 5 * channels

        self.to_patch_tokens = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim)
        )

    # forward pass
    def forward(self, x):
        shifts = ((1, -1, 0, 0), (-1, 1, 0, 0), (0, 0, 1, -1), (0, 0, -1, 1))
        shifted_x = list(map(lambda shift: F.pad(x, shift), shifts))
        x_with_shifts = torch.cat((x, *shifted_x), dim = 1)
        return self.to_patch_tokens(x_with_shifts)
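
SPT concatenates the input with four one-pixel-shifted copies along the channel dimension before patching, which is why patch_dim equals patch_size * patch_size * 5 * channels. The following shape check is a sketch with illustrative sizes; it assumes the vit_pytorch package is installed so that SPT can be imported from this module.

import torch
from vit_pytorch.vit_for_small_dataset import SPT

spt = SPT(dim = 128, patch_size = 16, channels = 3)
img = torch.randn(2, 3, 64, 64)

tokens = spt(img)
print(tokens.shape)  # torch.Size([2, 16, 128]): 2 images, (64 // 16) ** 2 = 16 patches, dim 128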

# ViT class
class ViT(nn.Module):
    # constructor, setting up the model parameters
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        # call the parent constructor
        super().__init__()
        # get the image height and width
        image_height, image_width = pair(image_size)
        # get the patch height and width
        patch_height, patch_width = pair(patch_size)

        # check that the image dimensions are divisible by the patch size
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # compute the number of patches
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # compute the dimension of each patch
        patch_dim = channels * patch_height * patch_width
        # check that the pooling type is either 'cls' or 'mean'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # patch embedding via shifted patch tokenization
        self.to_patch_embedding = SPT(dim = dim, patch_size = patch_size, channels = channels)

        # learned positional embedding
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # class token
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # embedding dropout
        self.dropout = nn.Dropout(emb_dropout)

        # Transformer encoder
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        # pooling strategy
        self.pool = pool
        # identity mapping into the latent space
        self.to_latent = nn.Identity()

        # MLP head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    # forward pass
    def forward(self, img):
        # convert the image into patch tokens
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # repeat the class token to match the batch size
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        # prepend the class token to the patch tokens
        x = torch.cat((cls_tokens, x), dim=1)
        # add the positional embedding
        x += self.pos_embedding[:, :(n + 1)]
        # apply embedding dropout
        x = self.dropout(x)

        # run the tokens through the Transformer
        x = self.transformer(x)

        # pool either by averaging over tokens or by taking the class token
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        # map into the latent space (identity here)
        x = self.to_latent(x)
        # classify with the MLP head
        return self.mlp_head(x)
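
As with the 3D variant, a short usage sketch may help; the values below are illustrative and the import path assumes the standard vit-pytorch package layout, matching the constructor defined above.

import torch
from vit_pytorch.vit_for_small_dataset import ViT

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(4, 3, 256, 256)
preds = v(img)  # (4, 1000)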