Lucidrains-系列项目源码解析-一百零七-

33 阅读28分钟

Lucidrains 系列项目源码解析(一百零七)

.\lucidrains\vit-pytorch\vit_pytorch\cvt.py

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helper methods

# 根据条件将字典分组
def group_dict_by_key(cond, d):
    """Partition the entries of ``d`` by a predicate evaluated on each key.

    Args:
        cond: callable taking a key; its result is interpreted by truthiness.
        d: dictionary to split.

    Returns:
        A 2-tuple ``(matching, rest)`` of dicts: entries whose key satisfies
        ``cond`` and all remaining entries, each in original order.
    """
    matching, rest = {}, {}
    for key, value in d.items():
        bucket = matching if cond(key) else rest
        bucket[key] = value
    return matching, rest

# 根据前缀分组并移除前缀
def group_by_key_prefix_and_remove_prefix(prefix, d):
    """Pull out the entries of ``d`` whose key starts with ``prefix``.

    Args:
        prefix: string prefix to look for at the start of each key.
        d: dictionary with string keys.

    Returns:
        A 2-tuple ``(kwargs_without_prefix, kwargs)``: the matching entries
        with ``prefix`` stripped from their keys, and the untouched
        remaining entries.
    """
    kwargs_with_prefix, kwargs = group_dict_by_key(lambda key: key.startswith(prefix), d)
    # drop the leading prefix from every matching key
    kwargs_without_prefix = {key[len(prefix):]: value for key, value in kwargs_with_prefix.items()}
    return kwargs_without_prefix, kwargs

# classes

# 自定义 LayerNorm 类
class LayerNorm(nn.Module):
    """Layer normalisation computed over dim 1 (the channel axis).

    The learnable gain ``g`` and bias ``b`` have shape ``(1, dim, 1, 1)``,
    so they broadcast over a 4-D ``(batch, dim, height, width)`` input.
    """

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        # statistics are taken per position across dim 1; variance is the biased estimate
        mean = torch.mean(x, dim = 1, keepdim = True)
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        std = (var + self.eps).sqrt()
        return (x - mean) / std * self.g + self.b

# 自定义 FeedForward 类
class FeedForward(nn.Module):
    """Feed-forward block made of 1x1 convolutions.

    Pipeline: channel LayerNorm -> 1x1 conv expanding to ``dim * mult``
    channels -> GELU -> dropout -> 1x1 conv back to ``dim`` -> dropout.
    """

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden_dim = dim * mult
        stages = [
            LayerNorm(dim),
            nn.Conv2d(dim, hidden_dim, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden_dim, dim, 1),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

# 自定义 DepthWiseConv2d 类
class DepthWiseConv2d(nn.Module):
    """Depthwise convolution, batch norm, then a 1x1 pointwise convolution.

    The first convolution uses ``groups = dim_in`` (one filter per input
    channel); the pointwise convolution maps ``dim_in`` to ``dim_out``.
    ``bias`` is forwarded to both convolutions.
    """

    def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
        super().__init__()
        depthwise = nn.Conv2d(
            dim_in,
            dim_in,
            kernel_size = kernel_size,
            padding = padding,
            groups = dim_in,
            stride = stride,
            bias = bias
        )
        batch_norm = nn.BatchNorm2d(dim_in)
        pointwise = nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        self.net = nn.Sequential(depthwise, batch_norm, pointwise)

    def forward(self, x):
        return self.net(x)

# 自定义 Attention 类
class Attention(nn.Module):
    """Multi-head self-attention over a feature map with convolutional projections.

    Queries come from a stride-1 depthwise-separable convolution; keys and
    values share one such convolution with stride ``kv_proj_stride`` and are
    split in half along the channel axis.
    """

    def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # both projections share kernel size, padding and the absence of a bias
        shared = dict(padding = proj_kernel // 2, bias = False)
        self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, stride = 1, **shared)
        self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, stride = kv_proj_stride, **shared)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # unpacking four values also requires the input to be 4-D
        _, _, _, width = x.shape
        heads = self.heads

        x = self.norm(x)
        q = self.to_q(x)
        k, v = self.to_kv(x).chunk(2, dim = 1)

        def split_heads(t):
            # fold heads into the batch axis and flatten the spatial grid into a sequence
            return rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = heads)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        attn = self.dropout(self.attend(sim))

        out = einsum('b i j, b j d -> b i d', attn, v)
        # restore the 2-D layout using the width of the incoming feature map
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = heads, y = width)
        return self.to_out(out)

# 自定义 Transformer 类
class Transformer(nn.Module):
    """Stack of ``depth`` (Attention, FeedForward) pairs with residual connections."""

    def __init__(self, dim, proj_kernel, kv_proj_stride, depth, heads, dim_head = 64, mlp_mult = 4, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_mult, dropout = dropout)
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        for attn, ff in self.layers:
            # residual around the attention sub-layer
            attended = attn(x)
            x = attended + x
            # residual around the feed-forward sub-layer
            transformed = ff(x)
            x = transformed + x
        return x
# 定义一个名为 CvT 的神经网络模型,继承自 nn.Module 类
class CvT(nn.Module):
    """Three-stage convolutional vision transformer classifier.

    Each stage (``s1``, ``s2``, ``s3``) is a strided convolutional embedding,
    a channel-wise LayerNorm and a Transformer. The output of the last stage
    is average-pooled to one vector per sample and projected to
    ``num_classes`` logits.
    """
    # constructor; every argument is keyword-only
    def __init__(
        self,
        *,
        num_classes,  # number of output classes
        s1_emb_dim = 64,  # stage 1: embedding dimension
        s1_emb_kernel = 7,  # stage 1: embedding convolution kernel size
        s1_emb_stride = 4,  # stage 1: embedding convolution stride
        s1_proj_kernel = 3,  # stage 1: attention projection kernel size
        s1_kv_proj_stride = 2,  # stage 1: key/value projection stride
        s1_heads = 1,  # stage 1: number of attention heads
        s1_depth = 1,  # stage 1: number of transformer layers
        s1_mlp_mult = 4,  # stage 1: feed-forward expansion factor
        s2_emb_dim = 192,  # stage 2: embedding dimension
        s2_emb_kernel = 3,  # stage 2: embedding convolution kernel size
        s2_emb_stride = 2,  # stage 2: embedding convolution stride
        s2_proj_kernel = 3,  # stage 2: attention projection kernel size
        s2_kv_proj_stride = 2,  # stage 2: key/value projection stride
        s2_heads = 3,  # stage 2: number of attention heads
        s2_depth = 2,  # stage 2: number of transformer layers
        s2_mlp_mult = 4,  # stage 2: feed-forward expansion factor
        s3_emb_dim = 384,  # stage 3: embedding dimension
        s3_emb_kernel = 3,  # stage 3: embedding convolution kernel size
        s3_emb_stride = 2,  # stage 3: embedding convolution stride
        s3_proj_kernel = 3,  # stage 3: attention projection kernel size
        s3_kv_proj_stride = 2,  # stage 3: key/value projection stride
        s3_heads = 6,  # stage 3: number of attention heads
        s3_depth = 10,  # stage 3: number of transformer layers
        s3_mlp_mult = 4,  # stage 3: feed-forward expansion factor
        dropout = 0.,  # dropout probability passed to every Transformer
        channels = 3  # number of input image channels
    ):
        # initialise the nn.Module base class
        super().__init__()
        # snapshot the constructor's local names (the arguments) into a dict;
        # this must run before any other local is created so the per-stage
        # settings can be pulled out by prefix below
        kwargs = dict(locals())

        # running channel count; starts at the number of input channels
        dim = channels
        # one nn.Sequential per stage
        layers = []

        # build the three stages in order
        for prefix in ('s1', 's2', 's3'):
            # extract this stage's settings (prefix stripped) and keep the rest for later stages
            config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs)

            # stage = strided conv embedding -> LayerNorm -> Transformer
            layers.append(nn.Sequential(
                nn.Conv2d(dim, config['emb_dim'], kernel_size = config['emb_kernel'], padding = (config['emb_kernel'] // 2), stride = config['emb_stride']),
                LayerNorm(config['emb_dim']),
                Transformer(dim = config['emb_dim'], proj_kernel = config['proj_kernel'], kv_proj_stride = config['kv_proj_stride'], depth = config['depth'], heads = config['heads'], mlp_mult = config['mlp_mult'], dropout = dropout)
            ))

            # the next stage consumes this stage's embedding dimension
            dim = config['emb_dim']

        # chain all stages
        self.layers = nn.Sequential(*layers)

        # classification head: global average pool, drop the two 1x1 spatial axes, linear projection
        self.to_logits = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Rearrange('... () () -> ...'),
            nn.Linear(dim, num_classes)
        )

    # forward pass: run the stages, then the classification head
    def forward(self, x):
        # feature map produced by the last stage
        latents = self.layers(x)
        # pooled and projected to class logits
        return self.to_logits(latents)

.\lucidrains\vit-pytorch\vit_pytorch\deepvit.py

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# 定义一个前馈神经网络类
class FeedForward(nn.Module):
    """Two-layer MLP applied on the last axis, preceded by a LayerNorm.

    Pipeline: LayerNorm -> Linear(dim, hidden_dim) -> GELU -> dropout ->
    Linear(hidden_dim, dim) -> dropout.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

# 定义一个注意力机制类
class Attention(nn.Module):
    """Multi-head self-attention with a learned re-attention step.

    After the softmax, the attention maps are mixed across heads with the
    learnable ``heads x heads`` matrix ``reattn_weights`` and normalised by
    a LayerNorm over the head axis before being applied to the values.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        # one projection yields queries, keys and values concatenated on the last axis
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.dropout = nn.Dropout(dropout)

        # head-mixing matrix used by the re-attention step
        self.reattn_weights = nn.Parameter(torch.randn(heads, heads))

        # move heads to the last axis so LayerNorm normalises across heads, then move them back
        head_norm_stages = [
            Rearrange('b h i j -> b i j h'),
            nn.LayerNorm(heads),
            Rearrange('b i j h -> b h i j'),
        ]
        self.reattn_norm = nn.Sequential(*head_norm_stages)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # unpacking three values also requires the input to be 3-D
        _batch, _tokens, _dim = x.shape
        heads = self.heads

        x = self.norm(x)

        q, k, v = (
            rearrange(t, 'b n (h d) -> b h n d', h = heads)
            for t in self.to_qkv(x).chunk(3, dim = -1)
        )

        # scaled dot-product attention
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = self.dropout(sim.softmax(dim=-1))

        # re-attention: mix attention maps across heads, then normalise over heads
        attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights)
        attn = self.reattn_norm(attn)

        # weight the values, merge heads, project out
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# 定义一个 Transformer 类
class Transformer(nn.Module):
    """Stack of ``depth`` (Attention, FeedForward) pairs with residual connections."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        for attn, ff in self.layers:
            # residual around the attention sub-layer
            attended = attn(x)
            x = attended + x
            # residual around the feed-forward sub-layer
            transformed = ff(x)
            x = transformed + x
        return x

# 定义一个 DeepViT 类
class DeepViT(nn.Module):
    """Patch-based vision transformer classifier built on this file's Transformer.

    The image is cut into ``patch_size x patch_size`` patches, each patch is
    normalised and linearly embedded, a class token is prepended, learned
    position embeddings are added, and the transformer output is pooled
    (class token or mean) and fed to a LayerNorm + Linear head.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        patches_per_side = image_size // patch_size
        num_patches = patches_per_side ** 2
        patch_dim = channels * patch_size ** 2
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # image -> flattened patches -> normalised embeddings of size ``dim``
        patch_stages = [
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        ]
        self.to_patch_embedding = nn.Sequential(*patch_stages)

        # one position per patch plus one for the class token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        """Return class logits for a batch of images."""
        tokens = self.to_patch_embedding(img)
        batch, num_tokens, _ = tokens.shape

        # prepend one copy of the class token per sample
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = batch)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens += self.pos_embedding[:, :(num_tokens + 1)]

        tokens = self.transformer(self.dropout(tokens))

        # pool either by averaging all tokens or by taking the class token
        if self.pool == 'mean':
            pooled = tokens.mean(dim = 1)
        else:
            pooled = tokens[:, 0]

        return self.mlp_head(self.to_latent(pooled))

.\lucidrains\vit-pytorch\vit_pytorch\dino.py

# 导入所需的库
import copy
import random
from functools import wraps, partial

import torch
from torch import nn
import torch.nn.functional as F

from torchvision import transforms as T

# 辅助函数

# 检查值是否存在
def exists(val):
    """Return ``True`` for any value other than ``None``."""
    if val is None:
        return False
    return True

# 如果值存在,则返回该值,否则返回默认值
def default(val, default):
    """Return ``val`` unless it is ``None``, in which case return ``default``."""
    if exists(val):
        return val
    return default

# 单例装饰器,用于缓存结果
def singleton(cache_key):
    """Decorator factory caching a method's result on the instance.

    The wrapped method runs only while ``getattr(self, cache_key)`` is
    ``None``; its result is then stored under ``cache_key`` and returned by
    every later call, whatever the arguments. The attribute must already
    exist on the instance.
    """
    def decorator(fn):
        @wraps(fn)
        def cached_call(self, *args, **kwargs):
            cached = getattr(self, cache_key)
            if cached is None:
                cached = fn(self, *args, **kwargs)
                setattr(self, cache_key, cached)
            return cached
        return cached_call
    return decorator

# 获取模块所在设备
def get_module_device(module):
    """Return the device of the first parameter yielded by ``module.parameters()``."""
    first_param = next(module.parameters())
    return first_param.device

# 设置模型参数是否需要梯度
def set_requires_grad(model, val):
    """Set ``requires_grad`` to ``val`` on every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad = val

# 损失函数(论文中的算法1)

def loss_fn(
    teacher_logits,
    student_logits,
    teacher_temp,
    student_temp,
    centers,
    eps = 1e-20
):
    """Cross-entropy between the teacher and student softmax distributions.

    The teacher logits are detached, shifted by ``centers`` and divided by
    ``teacher_temp`` before the softmax; the student logits are divided by
    ``student_temp``. ``eps`` is added inside the logarithm. The result is
    summed over the last axis and averaged over the remaining ones.
    """
    # no gradient flows into the teacher
    teacher_logits = teacher_logits.detach()

    centered_teacher = (teacher_logits - centers) / teacher_temp
    teacher_probs = F.softmax(centered_teacher, dim = -1)
    student_probs = F.softmax(student_logits / student_temp, dim = -1)

    weighted_log_probs = teacher_probs * torch.log(student_probs + eps)
    return - weighted_log_probs.sum(dim = -1).mean()

# 数据增强工具类

class RandomApply(nn.Module):
    """Apply ``fn`` to the input unless a uniform random draw exceeds ``p``."""

    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        # draw once per call; pass the input through untouched when the draw is above p
        skip = random.random() > self.p
        return x if skip else self.fn(x)

# 指数移动平均

class EMA:
    """Exponential moving average helper with decay factor ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Blend ``old`` and ``new``; when ``old`` is ``None`` return ``new`` unchanged."""
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new

# 更新移动平均值
def update_moving_average(ema_updater, ma_model, current_model):
    """Update every parameter of ``ma_model`` from ``current_model`` in place.

    Parameters are paired positionally; each moving-average parameter's
    ``.data`` is replaced by ``ema_updater.update_average(old, new)``.
    """
    param_pairs = zip(current_model.parameters(), ma_model.parameters())
    for current_param, ma_param in param_pairs:
        ma_param.data = ema_updater.update_average(ma_param.data, current_param.data)

# MLP类用于投影器和预测器

class L2Norm(nn.Module):
    """Divide the input by its L2 norm along dim 1, with the norm floored at ``eps``."""

    def forward(self, x, eps = 1e-6):
        lengths = x.norm(dim = 1, keepdim = True)
        # the floor avoids dividing by zero for all-zero rows
        safe_lengths = lengths.clamp(min = eps)
        return x / safe_lengths

class MLP(nn.Module):
    """MLP head: hidden Linear layers, L2 normalisation, then an output Linear.

    Args:
        dim: input feature dimension.
        dim_out: output dimension of the final Linear layer.
        num_layers: the head contains ``num_layers - 1`` hidden Linear layers
            of width ``hidden_size`` before the L2 normalisation.
        hidden_size: width of the hidden layers and input width of the
            final Linear layer.

    Every hidden Linear layer is followed by GELU, except the last one,
    which is followed by ``nn.Identity`` so the state-dict indices stay the
    same as with an activation in that slot.
    """

    def __init__(self, dim, dim_out, num_layers, hidden_size = 256):
        super().__init__()

        dims = (dim, *((hidden_size,) * (num_layers - 1)))
        layer_dims = list(zip(dims[:-1], dims[1:]))
        # index of the final (in, out) pair. The earlier comparison against
        # len(dims) - 1 could never match an index produced by enumerate, so
        # the Identity branch was unreachable.
        last_ind = len(layer_dims) - 1

        layers = []
        for ind, (layer_dim_in, layer_dim_out) in enumerate(layer_dims):
            is_last = ind == last_ind

            layers.extend([
                nn.Linear(layer_dim_in, layer_dim_out),
                nn.GELU() if not is_last else nn.Identity()
            ])

        self.net = nn.Sequential(
            *layers,
            L2Norm(),
            nn.Linear(hidden_size, dim_out)
        )

    def forward(self, x):
        return self.net(x)

# 用于基础神经网络的包装类
# 将管理隐藏层输出的拦截并将其传递到投影器和预测器网络中

class NetWrapper(nn.Module):
    """Wrap a backbone, capture one hidden layer's output and project it.

    A forward hook on the selected layer stores its flattened output. On
    the first projected forward pass an MLP projector is built lazily from
    the captured feature dimension and cached on the instance.

    Args:
        net: the backbone module.
        output_dim: output dimension of the projector.
        projection_hidden_size: hidden width of the projector MLP.
        projection_num_layers: ``num_layers`` passed to the projector MLP.
        layer: hidden layer to tap, either a module name (str, looked up in
            ``net.named_modules()``) or an index (int, into
            ``net.children()``). ``-1`` means the backbone's own output is
            used and no hook is registered.
    """

    def __init__(self, net, output_dim, projection_hidden_size, projection_num_layers, layer = -2):
        super().__init__()
        self.net = net
        self.layer = layer

        # built lazily by _get_projector once the hidden dimension is known
        self.projector = None
        self.projection_hidden_size = projection_hidden_size
        self.projection_num_layers = projection_num_layers
        self.output_dim = output_dim

        # outputs captured by the forward hook, keyed by the device of the layer's input
        self.hidden = {}
        self.hook_registered = False

    def _find_layer(self):
        """Resolve ``self.layer`` to a module, or ``None`` when it cannot be resolved."""
        if isinstance(self.layer, str):
            modules = dict(self.net.named_modules())
            return modules.get(self.layer, None)
        if isinstance(self.layer, int):
            children = list(self.net.children())
            return children[self.layer]
        return None

    def _hook(self, _, input, output):
        """Forward hook: store the flattened layer output under the input's device."""
        device = input[0].device
        self.hidden[device] = output.flatten(1)

    def _register_hook(self):
        """Attach the forward hook to the selected hidden layer."""
        layer = self._find_layer()
        assert layer is not None, f'hidden layer ({self.layer}) not found'
        layer.register_forward_hook(self._hook)
        self.hook_registered = True

    @singleton('projector')
    def _get_projector(self, hidden):
        """Build the projector MLP from the hidden feature size (cached after the first call)."""
        _, dim = hidden.shape
        projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
        # match the dtype and device of the captured features
        return projector.to(hidden)

    def get_embedding(self, x):
        """Return the tapped hidden representation for ``x``."""
        # layer -1 means the backbone output itself is the embedding
        if self.layer == -1:
            return self.net(x)

        if not self.hook_registered:
            self._register_hook()

        # run the backbone only for the hook's side effect
        self.hidden.clear()
        _ = self.net(x)
        # .get lets a hook that never fired reach the assertion below
        # instead of raising a bare KeyError
        hidden = self.hidden.get(x.device)
        self.hidden.clear()

        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
        return hidden

    def forward(self, x, return_projection = True):
        """Return ``(projection, embedding)``, or only the embedding when ``return_projection`` is False."""
        embed = self.get_embedding(x)
        if not return_projection:
            return embed

        projector = self._get_projector(embed)
        return projector(embed), embed
# 主类定义
class Dino(nn.Module):
    """Self-distillation wrapper around a backbone network.

    A student encoder (the wrapped backbone plus projector) is fed local
    crops and a teacher encoder is fed global crops of two augmented views.
    The teacher is a deep copy of the student with gradients disabled,
    created lazily on first use. ``forward`` returns the averaged
    cross-view loss; ``update_moving_average`` moves the teacher weights and
    the teacher centers towards the latest values.
    """
    # constructor
    def __init__(
        self,
        net,
        image_size,
        hidden_layer = -2,
        projection_hidden_size = 256,
        num_classes_K = 65336,  # NOTE(review): possibly a typo for 65536 (2 ** 16) — confirm before changing
        projection_layers = 4,
        student_temp = 0.9,
        teacher_temp = 0.04,
        local_upper_crop_scale = 0.4,
        global_lower_crop_scale = 0.5,
        moving_average_decay = 0.9,
        center_moving_average_decay = 0.9,
        augment_fn = None,
        augment_fn2 = None
    ):
        # initialise the nn.Module base class
        super().__init__()
        # keep a reference to the backbone
        self.net = net

        # default augmentation pipeline (described as the BYOL augmentation in the original comment)
        DEFAULT_AUG = torch.nn.Sequential(
            RandomApply(
                T.ColorJitter(0.8, 0.8, 0.8, 0.2),
                p = 0.3
            ),
            T.RandomGrayscale(p=0.2),
            T.RandomHorizontalFlip(),
            RandomApply(
                T.GaussianBlur((3, 3), (1.0, 2.0)),
                p = 0.2
            ),
            T.Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225])),
        )

        # caller-supplied augmentations take precedence over the default pipeline
        self.augment1 = default(augment_fn, DEFAULT_AUG)
        self.augment2 = default(augment_fn2, DEFAULT_AUG)

        # local crops cover a small area fraction, global crops a large one; both are resized to image_size
        self.local_crop = T.RandomResizedCrop((image_size, image_size), scale = (0.05, local_upper_crop_scale))
        self.global_crop = T.RandomResizedCrop((image_size, image_size), scale = (global_lower_crop_scale, 1.))

        # student encoder: backbone plus projector producing num_classes_K outputs
        self.student_encoder = NetWrapper(net, num_classes_K, projection_hidden_size, projection_layers, layer = hidden_layer)

        # teacher is created lazily by _get_teacher_encoder
        self.teacher_encoder = None
        self.teacher_ema_updater = EMA(moving_average_decay)

        # buffers: the running teacher centers and the most recent batch mean of the teacher outputs
        self.register_buffer('teacher_centers', torch.zeros(1, num_classes_K))
        self.register_buffer('last_teacher_centers',  torch.zeros(1, num_classes_K))

        self.teacher_centering_ema_updater = EMA(center_moving_average_decay)

        self.student_temp = student_temp
        self.teacher_temp = teacher_temp

        # move this wrapper to the device the backbone lives on
        device = get_module_device(net)
        self.to(device)

        # run one dummy forward pass so the lazily built modules (projector, teacher) get instantiated
        self.forward(torch.randn(2, 3, image_size, image_size, device=device))

    # lazily create the teacher as a frozen deep copy of the student (cached under 'teacher_encoder')
    @singleton('teacher_encoder')
    def _get_teacher_encoder(self):
        teacher_encoder = copy.deepcopy(self.student_encoder)
        set_requires_grad(teacher_encoder, False)
        return teacher_encoder

    # drop the teacher so that the next forward pass rebuilds it from the current student
    def reset_moving_average(self):
        del self.teacher_encoder
        self.teacher_encoder = None

    # EMA-update the teacher weights from the student and the teacher centers from the last batch mean
    def update_moving_average(self):
        assert self.teacher_encoder is not None, 'target encoder has not been created yet'
        # this name resolves to the module-level function, not to this method
        update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder)

        new_teacher_centers = self.teacher_centering_ema_updater.update_average(self.teacher_centers, self.last_teacher_centers)
        self.teacher_centers.copy_(new_teacher_centers)

    # forward pass: returns the loss, or the student encoder output when return_embedding is set
    def forward(
        self,
        x,
        return_embedding = False,
        return_projection = True,
        student_temp = None,
        teacher_temp = None
        ):
        # embedding mode: skip augmentation and loss, just run the student encoder
        if return_embedding:
            return self.student_encoder(x, return_projection = return_projection)

        # two augmented views of the same batch
        image_one, image_two = self.augment1(x), self.augment2(x)

        # local crops of each view (fed to the student below)
        local_image_one, local_image_two   = self.local_crop(image_one),  self.local_crop(image_two)
        # global crops of each view (fed to the teacher below)
        global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_two)

        # student projections of the local crops
        student_proj_one, _ = self.student_encoder(local_image_one)
        student_proj_two, _ = self.student_encoder(local_image_two)

        # teacher projections of the global crops, computed without gradient tracking
        with torch.no_grad():
            teacher_encoder = self._get_teacher_encoder()
            teacher_proj_one, _ = teacher_encoder(global_image_one)
            teacher_proj_two, _ = teacher_encoder(global_image_two)

        # bind temperatures (call-time overrides win over the stored values) and the current centers
        loss_fn_ = partial(
            loss_fn,
            student_temp = default(student_temp, self.student_temp),
            teacher_temp = default(teacher_temp, self.teacher_temp),
            centers = self.teacher_centers
        )

        # remember the batch mean of the teacher outputs; update_moving_average folds it into the centers
        teacher_logits_avg = torch.cat((teacher_proj_one, teacher_proj_two)).mean(dim = 0)
        self.last_teacher_centers.copy_(teacher_logits_avg)

        # average of the two cross-view losses (teacher of one view against student of the other)
        loss = (loss_fn_(teacher_proj_one, student_proj_two) + loss_fn_(teacher_proj_two, student_proj_one)) / 2
        return loss

.\lucidrains\vit-pytorch\vit_pytorch\distill.py

import torch  # 导入 PyTorch 库
import torch.nn.functional as F  # 导入 PyTorch 中的函数模块
from torch import nn  # 从 PyTorch 中导入 nn 模块
from vit_pytorch.vit import ViT  # 从 vit_pytorch 库中导入 ViT 类
from vit_pytorch.t2t import T2TViT  # 从 vit_pytorch 库中导入 T2TViT 类
from vit_pytorch.efficient import ViT as EfficientViT  # 从 vit_pytorch 库中导入 EfficientViT 类

from einops import rearrange, repeat  # 从 einops 库中导入 rearrange 和 repeat 函数

# helpers

def exists(val):
    """Return ``True`` for any value other than ``None``."""
    if val is None:
        return False
    return True

# classes

class DistillMixin:
    """Mixin providing a ``forward`` that can carry an extra distillation token.

    The host class must supply ``to_patch_embedding``, ``cls_token``,
    ``pos_embedding``, ``pool``, ``to_latent``, ``mlp_head`` and an
    ``_attend`` method.
    """

    def forward(self, img, distill_token = None):
        """Return logits, or ``(logits, distill_tokens)`` when ``distill_token`` is given."""
        use_distill = exists(distill_token)

        tokens = self.to_patch_embedding(img)
        batch, num_patches, _ = tokens.shape

        # prepend the class token, then add position embeddings
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = batch)
        tokens = torch.cat((cls_tokens, tokens), dim = 1)
        tokens += self.pos_embedding[:, :(num_patches + 1)]

        if use_distill:
            # the distillation token goes last and receives no position embedding
            distill_tokens = repeat(distill_token, '() n d -> b n d', b = batch)
            tokens = torch.cat((tokens, distill_tokens), dim = 1)

        tokens = self._attend(tokens)

        if use_distill:
            # split the trailing distillation token off again
            tokens, distill_tokens = tokens[:, :-1], tokens[:, -1]

        if self.pool == 'mean':
            pooled = tokens.mean(dim = 1)
        else:
            pooled = tokens[:, 0]

        out = self.mlp_head(self.to_latent(pooled))

        if use_distill:
            return out, distill_tokens
        return out

class DistillableViT(DistillMixin, ViT):
    """ViT that supports the distillation-token forward pass of DistillMixin.

    The constructor arguments are remembered so that ``to_vit`` can build
    a plain ViT with identical configuration. ``dim`` and ``num_classes``
    must be passed as keyword arguments.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.args, self.kwargs = args, kwargs
        self.dim = kwargs['dim']
        self.num_classes = kwargs['num_classes']

    def to_vit(self):
        """Return a plain ViT loaded with this model's weights."""
        plain = ViT(*self.args, **self.kwargs)
        plain.load_state_dict(self.state_dict())
        return plain

    def _attend(self, x):
        # embedding dropout followed by the transformer
        return self.transformer(self.dropout(x))

class DistillableT2TViT(DistillMixin, T2TViT):
    """T2TViT that supports the distillation-token forward pass of DistillMixin.

    The constructor arguments are remembered so that ``to_vit`` can build
    a plain T2TViT with identical configuration. ``dim`` and
    ``num_classes`` must be passed as keyword arguments.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.args, self.kwargs = args, kwargs
        self.dim = kwargs['dim']
        self.num_classes = kwargs['num_classes']

    def to_vit(self):
        """Return a plain T2TViT loaded with this model's weights."""
        plain = T2TViT(*self.args, **self.kwargs)
        plain.load_state_dict(self.state_dict())
        return plain

    def _attend(self, x):
        # embedding dropout followed by the transformer
        return self.transformer(self.dropout(x))

class DistillableEfficientViT(DistillMixin, EfficientViT):
    """EfficientViT that supports the distillation-token forward pass of DistillMixin.

    The constructor arguments are remembered so that ``to_vit`` can build
    a plain EfficientViT with identical configuration. ``dim`` and
    ``num_classes`` must be passed as keyword arguments.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.args, self.kwargs = args, kwargs
        self.dim = kwargs['dim']
        self.num_classes = kwargs['num_classes']

    def to_vit(self):
        """Return a plain EfficientViT loaded with this model's weights."""
        plain = EfficientViT(*self.args, **self.kwargs)
        plain.load_state_dict(self.state_dict())
        return plain

    def _attend(self, x):
        # this variant applies the transformer directly, with no dropout step
        attended = self.transformer(x)
        return attended

# knowledge distillation wrapper

class DistillWrapper(nn.Module):
    """Knowledge-distillation wrapper pairing a teacher with a distillable student.

    Owns a learnable distillation token and a small head (LayerNorm +
    Linear) that turns the student's distillation-token output into logits.
    ``forward`` returns a weighted sum of the student's classification loss
    and the distillation loss.
    """

    def __init__(
        self,
        *,
        teacher,
        student,
        temperature = 1.,
        alpha = 0.5,
        hard = False
    ):
        super().__init__()
        distillable_types = (DistillableViT, DistillableT2TViT, DistillableEfficientViT)
        assert (isinstance(student, distillable_types)) , 'student must be a vision transformer'

        self.teacher = teacher
        self.student = student

        dim = student.dim
        num_classes = student.num_classes
        self.temperature = temperature
        self.alpha = alpha
        self.hard = hard

        # learnable token appended to the student's token sequence
        self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))

        # head mapping the distillation-token output to class logits
        self.distill_mlp = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
        """Return ``(1 - alpha) * student_loss + alpha * distill_loss``.

        ``temperature`` and ``alpha`` override the values stored at
        construction when given; extra keyword arguments are forwarded to
        the student.
        """
        # unpacking requires img to have at least one dimension
        _batch, *_ = img.shape

        if not exists(alpha):
            alpha = self.alpha
        if exists(temperature):
            T = temperature
        else:
            T = self.temperature

        # the teacher only provides targets, so no gradients are tracked for it
        with torch.no_grad():
            teacher_logits = self.teacher(img)

        student_logits, distill_tokens = self.student(img, distill_token = self.distillation_token, **kwargs)
        distill_logits = self.distill_mlp(distill_tokens)

        # supervised loss on the student's own logits
        loss = F.cross_entropy(student_logits, labels)

        if self.hard:
            # hard distillation: the teacher's argmax is used as the label
            teacher_labels = teacher_logits.argmax(dim = -1)
            distill_loss = F.cross_entropy(distill_logits, teacher_labels)
        else:
            # soft distillation: KL divergence between temperature-scaled distributions
            student_log_probs = F.log_softmax(distill_logits / T, dim = -1)
            teacher_probs = F.softmax(teacher_logits / T, dim = -1).detach()
            distill_loss = F.kl_div(student_log_probs, teacher_probs, reduction = 'batchmean')
            distill_loss *= T ** 2

        return loss * (1 - alpha) + distill_loss * alpha

.\lucidrains\vit-pytorch\vit_pytorch\efficient.py

import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# 定义一个函数,用于确保输入是一个元组
def pair(t):
    """Return ``t`` unchanged when it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)

# 定义一个名为 ViT 的类,继承自 nn.Module
class ViT(nn.Module):
    """Patch-based vision transformer classifier with a caller-supplied transformer.

    ``image_size`` may be an int or a ``(height, width)`` tuple; both sides
    must be divisible by ``patch_size``. ``transformer`` is any module
    applied to the embedded token sequence.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3):
        super().__init__()
        image_size_h, image_size_w = pair(image_size)
        assert image_size_h % patch_size == 0 and image_size_w % patch_size == 0, 'image dimensions must be divisible by the patch size'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        patches_high = image_size_h // patch_size
        patches_wide = image_size_w // patch_size
        num_patches = patches_high * patches_wide
        patch_dim = channels * patch_size ** 2

        # image -> flattened patches -> normalised embeddings of size ``dim``
        patch_stages = [
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        ]
        self.to_patch_embedding = nn.Sequential(*patch_stages)

        # one position per patch plus one for the class token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.transformer = transformer

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        """Return class logits for a batch of images."""
        tokens = self.to_patch_embedding(img)
        batch, num_tokens, _ = tokens.shape

        # prepend one copy of the class token per sample
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = batch)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens += self.pos_embedding[:, :(num_tokens + 1)]
        tokens = self.transformer(tokens)

        # pool either by averaging all tokens or by taking the class token
        if self.pool == 'mean':
            pooled = tokens.mean(dim = 1)
        else:
            pooled = tokens[:, 0]

        return self.mlp_head(self.to_latent(pooled))

.\lucidrains\vit-pytorch\vit_pytorch\es_vit.py

# 导入所需的库
import copy
import random
from functools import wraps, partial
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torchvision import transforms as T
from einops import rearrange, reduce, repeat

# 辅助函数

# 检查值是否存在
def exists(val):
    """Tell whether ``val`` holds a value, i.e. is not ``None``."""
    if val is None:
        return False
    return True

# 如果值存在则返回该值,否则返回默认值
def default(val, default):
    """Fall back to ``default`` when ``val`` is ``None``; otherwise return ``val``."""
    if exists(val):
        return val
    return default

# 单例装饰器,用于缓存结果
def singleton(cache_key):
    """Decorator factory that memoizes a method's result on the instance.

    The wrapped method's return value is stored under the attribute named
    `cache_key`. While that attribute is not None the stored value is returned
    and the method is not called again. The attribute must already exist on
    the instance, because it is read with a plain `getattr`.
    """
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            cached = getattr(self, cache_key)
            if cached is None:
                cached = fn(self, *args, **kwargs)
                setattr(self, cache_key, cached)
            return cached
        return wrapper
    return inner_fn

# 获取模块所在设备
def get_module_device(module):
    """Return the device of the first parameter yielded by `module.parameters()`."""
    first_param = next(module.parameters())
    return first_param.device

# 设置模型参数是否需要梯度
def set_requires_grad(model, val):
    """Set `requires_grad` to `val` on every parameter of `model`."""
    for param in model.parameters():
        param.requires_grad = val

# 张量相关的辅助函数

# 对张量取对数
def log(t, eps = 1e-20):
    """Elementwise natural log of `t`, with `eps` added first so that zeros do not yield -inf."""
    shifted = t + eps
    return torch.log(shifted)

# 损失函数

# 视图损失函数
def view_loss_fn(
    teacher_logits,
    student_logits,
    teacher_temp,
    student_temp,
    centers,
    eps = 1e-20
):
    """Cross-entropy of the student distribution against the centered teacher distribution.

    The teacher logits are detached, shifted by `centers` and divided by
    `teacher_temp`; the student logits are divided by `student_temp`. Both are
    softmaxed over the last dimension. The cross-entropy is summed over that
    dimension and averaged over the rest.
    """
    target_logits = (teacher_logits.detach() - centers) / teacher_temp
    target_probs = target_logits.softmax(dim = -1)

    student_probs = (student_logits / student_temp).softmax(dim = -1)

    summed = (target_probs * log(student_probs, eps)).sum(dim = -1)
    return -summed.mean()

# 区域损失函数
def region_loss_fn(
    teacher_logits,
    student_logits,
    teacher_latent,
    student_latent,
    teacher_temp,
    student_temp,
    centers,
    eps = 1e-20
):
    """Region-level cross-entropy using the best-matching teacher region per student region.

    For every student region the teacher region with the largest dot-product
    similarity (computed on the latents) is selected. That teacher region's
    centered, temperature-scaled probabilities are the target for the student
    region's probabilities.
    """
    target_probs = ((teacher_logits.detach() - centers) / teacher_temp).softmax(dim = -1)
    student_probs = (student_logits / student_temp).softmax(dim = -1)

    # (b, student regions, teacher regions) similarity, then argmax over teacher regions
    similarity = einsum('b i d, b j d -> b i j', student_latent, teacher_latent)
    best_match = similarity.max(dim = -1).indices

    # broadcast the chosen region index across the class dimension so gather can select whole rows
    gather_index = repeat(best_match, 'b n -> b n k', k = target_probs.shape[-1])
    matched_probs = target_probs.gather(1, gather_index)

    summed = (matched_probs * log(student_probs, eps)).sum(dim = -1)
    return -summed.mean()

# 数据增强工具

# 随机应用函数
class RandomApply(nn.Module):
    """Apply `fn` to the input unless a uniform random draw exceeds `p`."""

    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        # a single draw from `random.random()` decides whether the transform is skipped
        skip = random.random() > self.p
        return x if skip else self.fn(x)

# 指数移动平均

class EMA():
    """Exponential moving average helper with decay factor `beta`."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Return `old * beta + (1 - beta) * new`, or `new` itself when `old` is None."""
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new

# 更新移动平均值
def update_moving_average(ema_updater, ma_model, current_model):
    """Move each parameter of `ma_model` toward its counterpart in `current_model`.

    Parameters are paired positionally via `parameters()`. The value returned
    by `ema_updater.update_average(old, new)` replaces the averaged
    parameter's `.data`.
    """
    param_pairs = zip(current_model.parameters(), ma_model.parameters())
    for source_param, averaged_param in param_pairs:
        averaged_param.data = ema_updater.update_average(averaged_param.data, source_param.data)

# MLP 类用于投影器和预测器

# L2范数
class L2Norm(nn.Module):
    """Normalize the input along dimension 1 using `F.normalize`."""

    def forward(self, x, eps = 1e-6):
        normalized = F.normalize(x, dim = 1, eps = eps)
        return normalized

# 多层感知机
class MLP(nn.Module):
    """Projection head: hidden Linear/GELU stack, L2 normalization, then an output Linear.

    Args:
        dim: input feature size.
        dim_out: size of the final linear output.
        num_layers: the hidden stack has `num_layers - 1` linear layers; the
            output linear is appended after the normalization.
        hidden_size: width of every hidden layer.
    """
    def __init__(self, dim, dim_out, num_layers, hidden_size = 256):
        super().__init__()

        layers = []
        dims = (dim, *((hidden_size,) * (num_layers - 1)))
        num_hidden = len(dims) - 1

        for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
            # zip yields `num_hidden` pairs, so the final index is num_hidden - 1.
            # Comparing against len(dims) - 1 could never match, which left the
            # Identity branch dead and applied GELU right before the L2 norm.
            is_last = ind == (num_hidden - 1)

            layers.extend([
                nn.Linear(layer_dim_in, layer_dim_out),
                nn.GELU() if not is_last else nn.Identity()
            ])

        self.net = nn.Sequential(
            *layers,
            L2Norm(),
            # dims[-1] equals hidden_size whenever a hidden stack exists, and falls
            # back to `dim` when num_layers == 1 so the shapes still line up
            nn.Linear(dims[-1], dim_out)
        )

    def forward(self, x):
        return self.net(x)

# 用于基础神经网络的包装类
# 将管理隐藏层输出的拦截
# 创建一个包装器类,用于将输入传递到投影器和预测器网络中
class NetWrapper(nn.Module):
    """Wrap a backbone, intercept one hidden layer's output, and project it.

    A forward hook on the selected layer captures its output. The projector
    MLPs for the view-level and region-level embeddings are created lazily on
    the first forward pass, once the hidden channel size is known.

    Args:
        net: the backbone network.
        output_dim: output size of both projectors.
        projection_hidden_size: hidden width of both projectors.
        projection_num_layers: `num_layers` passed to both projectors.
        layer: module name (str) or index into `net.children()` (int) of the
            layer to hook; `-1` uses the network output directly.
    """
    def __init__(self, net, output_dim, projection_hidden_size, projection_num_layers, layer = -2):
        super().__init__()
        self.net = net
        self.layer = layer

        # filled in by the @singleton getters on first use
        self.view_projector = None
        self.region_projector = None
        self.projection_hidden_size = projection_hidden_size
        self.projection_num_layers = projection_num_layers
        self.output_dim = output_dim

        # hook outputs, keyed by the device of the hooked layer's first input
        self.hidden = {}
        self.hook_registered = False

    def _find_layer(self):
        """Resolve `self.layer` to a module, or None when it cannot be resolved."""
        if isinstance(self.layer, str):
            modules = dict(self.net.named_modules())
            return modules.get(self.layer, None)
        if isinstance(self.layer, int):
            children = list(self.net.children())
            return children[self.layer]
        return None

    def _hook(self, _, input, output):
        """Forward hook: stash the layer output under the input's device."""
        device = input[0].device
        self.hidden[device] = output

    def _register_hook(self):
        """Attach `_hook` to the selected layer; raises if the layer is not found."""
        layer = self._find_layer()
        assert layer is not None, f'hidden layer ({self.layer}) not found'
        layer.register_forward_hook(self._hook)
        self.hook_registered = True

    @singleton('view_projector')
    def _get_view_projector(self, hidden):
        """Build (once) the projector for the pooled, view-level latent."""
        dim = hidden.shape[1]
        projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
        return projector.to(hidden)

    @singleton('region_projector')
    def _get_region_projector(self, hidden):
        """Build (once) the projector for the per-region latents."""
        dim = hidden.shape[1]
        projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
        return projector.to(hidden)

    def get_embedding(self, x):
        """Run the backbone and return the hooked layer's output for `x`."""
        if self.layer == -1:
            return self.net(x)

        if not self.hook_registered:
            self._register_hook()

        self.hidden.clear()
        _ = self.net(x)
        # .get() so a hook that never fired reaches the descriptive assert below
        # instead of surfacing as a bare KeyError
        hidden = self.hidden.get(x.device)
        self.hidden.clear()

        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
        return hidden

    def forward(self, x, return_projection = True):
        """Return latents, or their projections when `return_projection` is True.

        Returns:
            `(global_latent, region_latents)` when `return_projection` is
            False, with `region_latents` still laid out as `b c h w`.
            Otherwise `(view_projection, region_projection, region_latents)`
            with `region_latents` flattened to `b (h w) c`.
        """
        region_latents = self.get_embedding(x)
        global_latent = reduce(region_latents, 'b c h w -> b c', 'mean')

        if not return_projection:
            return global_latent, region_latents

        # both projectors read the channel size from shape[1], so the region
        # projector must be created before the latents are flattened below
        view_projector = self._get_view_projector(global_latent)
        region_projector = self._get_region_projector(region_latents)

        region_latents = rearrange(region_latents, 'b c h w -> b (h w) c')

        return view_projector(global_latent), region_projector(region_latents), region_latents

# 主类
class EsViTTrainer(nn.Module):
    """Self-supervised trainer combining a view-level and a region-level loss.

    A student `NetWrapper` around `net` is trained against a teacher that is a
    deep copy of the student with gradients disabled. `update_moving_average`
    blends the student's parameters into the teacher and updates the centers
    subtracted from the teacher logits. Calling the module on a batch of
    images returns the scalar training loss.
    """
    def __init__(
        self,
        net,
        image_size,
        hidden_layer = -2,
        projection_hidden_size = 256,
        num_classes_K = 65336,
        projection_layers = 4,
        student_temp = 0.9,
        teacher_temp = 0.04,
        local_upper_crop_scale = 0.4,
        global_lower_crop_scale = 0.5,
        moving_average_decay = 0.9,
        center_moving_average_decay = 0.9,
        augment_fn = None,
        augment_fn2 = None
    ):
        super().__init__()
        self.net = net

        # default BYOL augmentation

        DEFAULT_AUG = torch.nn.Sequential(
            # randomly apply color jitter
            RandomApply(
                T.ColorJitter(0.8, 0.8, 0.8, 0.2),
                p = 0.3
            ),
            # randomly convert to grayscale
            T.RandomGrayscale(p=0.2),
            # random horizontal flip
            T.RandomHorizontalFlip(),
            # randomly apply gaussian blur
            RandomApply(
                T.GaussianBlur((3, 3), (1.0, 2.0)),
                p = 0.2
            ),
            # normalize with fixed per-channel mean / std
            T.Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225])),
        )

        # the two augmentation pipelines; both fall back to the same DEFAULT_AUG instance
        self.augment1 = default(augment_fn, DEFAULT_AUG)
        self.augment2 = default(augment_fn2, DEFAULT_AUG)

        # local crops sample a small area fraction, global crops a large one
        self.local_crop = T.RandomResizedCrop((image_size, image_size), scale = (0.05, local_upper_crop_scale))
        self.global_crop = T.RandomResizedCrop((image_size, image_size), scale = (global_lower_crop_scale, 1.))

        # student encoder (backbone + lazily created projectors)
        self.student_encoder = NetWrapper(net, num_classes_K, projection_hidden_size, projection_layers, layer = hidden_layer)

        # teacher encoder is created on first use; its weights follow the student via EMA
        self.teacher_encoder = None
        self.teacher_ema_updater = EMA(moving_average_decay)

        # buffers holding the running teacher centers and the most recent batch averages
        self.register_buffer('teacher_view_centers', torch.zeros(1, num_classes_K))
        self.register_buffer('last_teacher_view_centers',  torch.zeros(1, num_classes_K))

        self.register_buffer('teacher_region_centers', torch.zeros(1, num_classes_K))
        self.register_buffer('last_teacher_region_centers',  torch.zeros(1, num_classes_K))

        # EMA updater for the teacher centers
        self.teacher_centering_ema_updater = EMA(center_moving_average_decay)

        self.student_temp = student_temp
        self.teacher_temp = teacher_temp

        # move this wrapper to the same device as the network
        device = get_module_device(net)
        self.to(device)

        # send a mock image tensor through once to instantiate the singleton parameters
        self.forward(torch.randn(2, 3, image_size, image_size, device=device))

    # lazily create the teacher as a frozen deep copy of the student
    @singleton('teacher_encoder')
    def _get_teacher_encoder(self):
        teacher_encoder = copy.deepcopy(self.student_encoder)
        set_requires_grad(teacher_encoder, False)
        return teacher_encoder

    # drop the teacher so the next forward pass re-copies it from the student
    def reset_moving_average(self):
        del self.teacher_encoder
        self.teacher_encoder = None

    # EMA-update the teacher parameters and both sets of teacher centers
    def update_moving_average(self):
        assert self.teacher_encoder is not None, 'target encoder has not been created yet'
        update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder)

        new_teacher_view_centers = self.teacher_centering_ema_updater.update_average(self.teacher_view_centers, self.last_teacher_view_centers)
        self.teacher_view_centers.copy_(new_teacher_view_centers)

        new_teacher_region_centers = self.teacher_centering_ema_updater.update_average(self.teacher_region_centers, self.last_teacher_region_centers)
        self.teacher_region_centers.copy_(new_teacher_region_centers)

    # forward pass: returns the training loss, or the student's embeddings when requested
    def forward(
        self,
        x,
        return_embedding = False,
        return_projection = True,
        student_temp = None,
        teacher_temp = None
        ):
        # embedding mode: just run the student encoder and return its output
        if return_embedding:
            return self.student_encoder(x, return_projection = return_projection)

        # two augmented versions of the input
        image_one, image_two = self.augment1(x), self.augment2(x)

        # a local crop and a global crop of each augmented image
        local_image_one, local_image_two   = self.local_crop(image_one),  self.local_crop(image_two)
        global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_two)

        # the student encodes the local crops
        student_view_proj_one, student_region_proj_one, student_latent_one = self.student_encoder(local_image_one)
        student_view_proj_two, student_region_proj_two, student_latent_two = self.student_encoder(local_image_two)

        # the teacher encodes the global crops, without tracking gradients
        with torch.no_grad():
            teacher_encoder = self._get_teacher_encoder()
            teacher_view_proj_one, teacher_region_proj_one, teacher_latent_one = teacher_encoder(global_image_one)
            teacher_view_proj_two, teacher_region_proj_two, teacher_latent_two = teacher_encoder(global_image_two)

        # bind temperatures and centers for the view-level and region-level losses;
        # per-call temperatures override the values stored on the module
        view_loss_fn_ = partial(
            view_loss_fn,
            student_temp = default(student_temp, self.student_temp),
            teacher_temp = default(teacher_temp, self.teacher_temp),
            centers = self.teacher_view_centers
        )

        region_loss_fn_ = partial(
            region_loss_fn,
            student_temp = default(student_temp, self.student_temp),
            teacher_temp = default(teacher_temp, self.teacher_temp),
            centers = self.teacher_region_centers
        )

        # record this batch's mean teacher logits; update_moving_average() folds them into the centers
        teacher_view_logits_avg = torch.cat((teacher_view_proj_one, teacher_view_proj_two)).mean(dim = 0)
        self.last_teacher_view_centers.copy_(teacher_view_logits_avg)

        teacher_region_logits_avg = torch.cat((teacher_region_proj_one, teacher_region_proj_two)).mean(dim = (0, 1))
        self.last_teacher_region_centers.copy_(teacher_region_logits_avg)

        # view-level loss, symmetrized by pairing each teacher view with the other student view
        view_loss = (view_loss_fn_(teacher_view_proj_one, student_view_proj_two) \
                   + view_loss_fn_(teacher_view_proj_two, student_view_proj_one)) / 2

        # region-level loss, symmetrized the same way
        region_loss = (region_loss_fn_(teacher_region_proj_one, student_region_proj_two, teacher_latent_one, student_latent_two) \
                     + region_loss_fn_(teacher_region_proj_two, student_region_proj_one, teacher_latent_two, student_latent_one)) / 2

        # average of the view-level and region-level losses
        return (view_loss + region_loss) / 2

.\lucidrains\vit-pytorch\vit_pytorch\extractor.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn

# 检查值是否存在
def exists(val):
    """Return True when `val` is anything other than None."""
    if val is None:
        return False
    return True

# 返回输入值
def identity(t):
    """Return the argument unchanged (the same object, not a copy)."""
    unchanged = t
    return unchanged

# 克隆并分离张量
def clone_and_detach(t):
    """Return a copy of `t` that is detached from the autograd graph."""
    copied = t.clone()
    return copied.detach()

# 应用函数到元组或单个值
def apply_tuple_or_single(fn, val):
    """Apply `fn` to `val`, or to each element when `val` is a tuple.

    A tuple input yields a tuple of results; any other input, including
    lists, is passed to `fn` as a single value.
    """
    if not isinstance(val, tuple):
        return fn(val)
    return tuple(fn(item) for item in val)

# 定义 Extractor 类,继承自 nn.Module
class Extractor(nn.Module):
    """Wrap a vision transformer and capture one layer's activations with a forward hook.

    Args:
        vit: the wrapped model.
        device: device the captured latents are moved to; defaults to the
            device of the input image when None.
        layer: module to hook; when None, the attribute `layer_name` of `vit`
            is hooked instead.
        layer_name: attribute name looked up on `vit` when `layer` is None.
        layer_save_input: capture the hooked layer's inputs instead of its output.
        return_embeddings_only: make `forward` return only the latents.
        detach: clone and detach the captured tensors when True.
    """
    def __init__(
        self,
        vit,
        device = None,
        layer = None,
        layer_name = 'transformer',
        layer_save_input = False,
        return_embeddings_only = False,
        detach = True
    ):
        super().__init__()
        self.vit = vit

        # capture state and hook bookkeeping
        self.data = None
        self.latents = None
        self.hooks = []
        self.hook_registered = False
        self.ejected = False
        self.device = device

        # which layer to hook and what to keep from it
        self.layer = layer
        self.layer_name = layer_name
        self.layer_save_input = layer_save_input
        self.return_embeddings_only = return_embeddings_only

        if detach:
            self.detach_fn = clone_and_detach
        else:
            self.detach_fn = identity

    def _hook(self, _, inputs, output):
        """Forward hook: store the layer's inputs or output, depending on `layer_save_input`."""
        captured = inputs if self.layer_save_input else output
        self.latents = apply_tuple_or_single(self.detach_fn, captured)

    def _register_hook(self):
        """Attach `_hook` to the explicit `layer`, or to `vit.<layer_name>` when none was given."""
        if exists(self.layer):
            target = self.layer
        else:
            assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
            target = getattr(self.vit, self.layer_name)

        self.hooks.append(target.register_forward_hook(self._hook))
        self.hook_registered = True

    def eject(self):
        """Remove all hooks, mark the extractor unusable, and return the wrapped model."""
        self.ejected = True
        for handle in self.hooks:
            handle.remove()
        self.hooks.clear()
        return self.vit

    def clear(self):
        """Forget the latents captured by the previous forward pass."""
        del self.latents
        self.latents = None

    def forward(
        self,
        img,
        return_embeddings_only = False
    ):
        """Run the wrapped model and return `(pred, latents)`, or only `latents` when requested."""
        assert not self.ejected, 'extractor has been ejected, cannot be used anymore'
        self.clear()
        if not self.hook_registered:
            self._register_hook()

        pred = self.vit(img)

        target_device = self.device if exists(self.device) else img.device
        latents = apply_tuple_or_single(lambda t: t.to(target_device), self.latents)

        embeddings_only = return_embeddings_only or self.return_embeddings_only
        if embeddings_only:
            return latents
        return pred, latents

.\lucidrains\vit-pytorch\vit_pytorch\learnable_memory_vit.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 torch.nn 模块中导入 functional 模块
import torch.nn.functional as F

# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 从 einops.layers.torch 库中导入 Rearrange 类
from einops.layers.torch import Rearrange

# 辅助函数

# 判断变量是否存在的函数
def exists(val):
    """Return True when `val` is anything other than None."""
    if val is None:
        return False
    return True

# 将输入转换为元组的函数
def pair(t):
    """Return `t` itself when it is a tuple, otherwise the 2-tuple `(t, t)`."""
    if isinstance(t, tuple):
        return t
    return (t, t)

# 控制层是否冻结的函数

# 设置模块参数是否需要梯度的函数
def set_module_requires_grad_(module, requires_grad):
    """Set the `requires_grad` flag on every parameter of `module`."""
    for parameter in module.parameters():
        parameter.requires_grad = requires_grad

# 冻结所有层的函数
def freeze_all_layers_(module):
    """Disable gradients for every parameter of `module`."""
    set_module_requires_grad_(module, requires_grad = False)

# 解冻所有层的函数
def unfreeze_all_layers_(module):
    """Enable gradients for every parameter of `module`."""
    set_module_requires_grad_(module, requires_grad = True)

# 类

# 前馈神经网络类
class FeedForward(nn.Module):
    """Pre-norm MLP block: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out

# 注意力机制类
class Attention(nn.Module):
    """Pre-norm multi-head attention whose keys/values can be extended with memory tokens.

    Queries come only from the input sequence. When `memories` is given, it is
    concatenated to the input along the sequence axis before the key/value
    projection.
    """
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, attn_mask = None, memories = None):
        x = self.norm(x)

        # keys / values see the normalized input, plus the memories when present
        kv_input = x
        if exists(memories):
            if memories.ndim == 2:
                # un-batched memories are repeated once per batch element
                memories = repeat(memories, 'n d -> b n d', b = x.shape[0])
            kv_input = torch.cat((kv_input, memories), dim = 1)

        queries = self.to_q(x)
        keys, values = self.to_kv(kv_input).chunk(2, dim = -1)
        q, k, v = (
            rearrange(t, 'b n (h d) -> b h n d', h = self.heads)
            for t in (queries, keys, values)
        )

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        if exists(attn_mask):
            # positions where the mask is False get the most negative finite value
            fill_value = -torch.finfo(dots.dtype).max
            dots = dots.masked_fill(~attn_mask, fill_value)

        attn = self.dropout(self.attend(dots))

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# Transformer 类
class Transformer(nn.Module):
    """Stack of `depth` (Attention, FeedForward) pairs, each applied with a residual connection."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ])
            for _ in range(depth)
        ])

    def forward(self, x, attn_mask = None, memories = None):
        for ind, (attn, ff) in enumerate(self.layers):
            # `memories`, when given, is indexed by layer position
            layer_memories = None
            if exists(memories):
                layer_memories = memories[ind]

            x = attn(x, attn_mask = attn_mask, memories = layer_memories) + x
            x = ff(x) + x
        return x

# ViT class
class ViT(nn.Module):
    """Vision transformer with a class token, classified from that token's output.

    `img_to_tokens` is exposed separately so that other modules can reuse the
    patch embedding, class token, and positional embedding.
    """
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        grid_height = image_height // patch_height
        grid_width = image_width // patch_width
        num_patches = grid_height * grid_width
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # split the image into flattened patches, then embed each patch
        patchify = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width)
        self.to_patch_embedding = nn.Sequential(
            patchify,
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        # one positional row per patch plus one for the class token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def img_to_tokens(self, img):
        """Embed an image batch into class token + patch tokens with positions and dropout applied."""
        patches = self.to_patch_embedding(img)

        cls = repeat(self.cls_token, '1 n d -> b n d', b = patches.shape[0])
        tokens = torch.cat((cls, patches), dim = 1)

        tokens += self.pos_embedding
        return self.dropout(tokens)

    def forward(self, img):
        """Return class logits computed from the class token's final representation."""
        tokens = self.img_to_tokens(img)
        encoded = self.transformer(tokens)
        return self.mlp_head(encoded[:, 0])
# 适配器模块,具有每层可学习的记忆、记忆 CLS 标记和可学习的适配器头部

class Adapter(nn.Module):
    """Adapter over a frozen ViT: per-layer learnable memories, a memory CLS token, and its own head."""

    def __init__(
        self,
        *,
        vit,
        num_memories_per_layer = 10,
        num_classes = 2,   
    ):
        super().__init__()
        assert isinstance(vit, ViT)

        # sizes read off the wrapped model
        dim = vit.cls_token.shape[-1]
        num_layers = len(vit.transformer.layers)
        # length of the positional table, i.e. the ViT's own token count
        num_tokens = vit.pos_embedding.shape[-2]

        self.vit = vit

        # the ViT backbone is frozen here; the parameters created below stay trainable
        freeze_all_layers_(vit)

        # learnable parameters
        self.memory_cls_token = nn.Parameter(torch.randn(dim))
        self.memories_per_layer = nn.Parameter(torch.randn(num_layers, num_memories_per_layer, dim))

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

        # Specialized attention mask that preserves the original ViT's output:
        # the memory CLS token may attend to everything (including the
        # per-layer memories), but not vice versa.
        vit_to_vit = torch.ones((num_tokens, num_tokens), dtype = torch.bool)
        # original tokens cannot see the memory CLS column (left) or the memory columns (right)
        vit_rows = F.pad(vit_to_vit, (1, num_memories_per_layer), value = False)
        # prepend the memory CLS row, which can see every column
        attn_mask = F.pad(vit_rows, (0, 0, 1, 0), value = True)
        self.register_buffer('attn_mask', attn_mask)

    def forward(self, img):
        """Return logits from the adapter head applied to the memory CLS token."""
        batch = img.shape[0]

        tokens = self.vit.img_to_tokens(img)

        # prepend the task-specific memory CLS token
        memory_cls = repeat(self.memory_cls_token, 'd -> b 1 d', b = batch)
        tokens = torch.cat((memory_cls, tokens), dim = 1)        

        # run the transformer with the per-layer memories and the mask built in __init__
        out = self.vit.transformer(tokens, memories = self.memories_per_layer, attn_mask = self.attn_mask)

        # the memory CLS token sits at position 0
        return self.mlp_head(out[:, 0])

.\lucidrains\vit-pytorch\vit_pytorch\levit.py

# 从 math 模块中导入 ceil 函数
from math import ceil

# 导入 torch 模块及相关子模块
import torch
from torch import nn, einsum
import torch.nn.functional as F

# 导入 einops 模块中的 rearrange 和 repeat 函数,以及 torch 子模块中的 Rearrange 类
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# 辅助函数

# 判断变量是否存在的函数
def exists(val):
    """Return True when `val` is anything other than None."""
    if val is None:
        return False
    return True

# 返回默认值的函数
def default(val, d):
    """Return `val` unless it is None, in which case return `d`."""
    if exists(val):
        return val
    return d

# 将输入值转换为元组的函数
def cast_tuple(val, l = 3):
    """Return `val` as a tuple of at least `l` items, padded by repeating its last item.

    A non-tuple `val` is first wrapped in a 1-tuple. Tuples that already have
    `l` or more items are returned with their contents unchanged (nothing is
    truncated).
    """
    val = val if isinstance(val, tuple) else (val,)
    # the original return expression was missing its closing parenthesis
    padding = (val[-1],) * max(l - len(val), 0)
    return (*val, *padding)

# 返回固定值的函数
def always(val):
    """Return a callable that ignores its arguments and always returns `val`."""
    def constant(*args, **kwargs):
        return val
    return constant

# 类

# 前馈神经网络类
class FeedForward(nn.Module):
    """Pointwise-convolution MLP: 1x1 Conv -> Hardswish -> Dropout -> 1x1 Conv -> Dropout."""

    def __init__(self, dim, mult, dropout = 0.):
        super().__init__()
        hidden_dim = dim * mult
        stages = [
            nn.Conv2d(dim, hidden_dim, 1),
            nn.Hardswish(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden_dim, dim, 1),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out

# 注意力机制类
class Attention(nn.Module):
    """Convolutional multi-head attention over a feature map with a learned positional bias.

    Queries, keys and values are produced by 1x1 convolutions followed by
    batch norm. With `downsample = True` the query convolution uses stride 2,
    so the output feature map is smaller than the input.
    """
    def __init__(self, dim, fmap_size, heads = 8, dim_key = 32, dim_value = 64, dropout = 0., dim_out = None, downsample = False):
        super().__init__()
        inner_dim_key = dim_key *  heads
        inner_dim_value = dim_value *  heads
        dim_out = default(dim_out, dim)
        q_stride = 2 if downsample else 1

        self.heads = heads
        self.scale = dim_key ** -0.5

        self.to_q = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, stride = q_stride, bias = False), nn.BatchNorm2d(inner_dim_key))
        self.to_k = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, bias = False), nn.BatchNorm2d(inner_dim_key))
        self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value))

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # the output batch norm starts with a zero scale
        out_batch_norm = nn.BatchNorm2d(dim_out)
        nn.init.zeros_(out_batch_norm.weight)

        self.to_out = nn.Sequential(
            nn.GELU(),
            nn.Conv2d(inner_dim_value, dim_out, 1),
            out_batch_norm,
            nn.Dropout(dropout)
        )

        # positional bias: one learned value per head for each (|dx|, |dy|) offset

        self.pos_bias = nn.Embedding(fmap_size * fmap_size, heads)

        def grid_positions(coords):
            # (row, col) coordinates of every grid cell, flattened to (cells, 2)
            grid = torch.stack(torch.meshgrid(coords, coords, indexing = 'ij'), dim = -1)
            return rearrange(grid, 'i j c -> (i j) c')

        q_pos = grid_positions(torch.arange(0, fmap_size, step = q_stride))
        k_pos = grid_positions(torch.arange(fmap_size))

        # absolute offset between every query position and every key position
        rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()

        x_rel, y_rel = rel_pos.unbind(dim = -1)
        pos_indices = (x_rel * fmap_size) + y_rel

        self.register_buffer('pos_indices', pos_indices)

    def apply_pos_bias(self, fmap):
        """Add the learned positional bias, divided by `self.scale`, to the attention logits."""
        bias = self.pos_bias(self.pos_indices)
        bias = rearrange(bias, 'i j h -> () h i j')
        return fmap + (bias / self.scale)

    def forward(self, x):
        h = self.heads

        q = self.to_q(x)
        # width of the (possibly strided) query map, needed to unflatten the output
        y = q.shape[2]

        k, v = self.to_k(x), self.to_v(x)
        q, k, v = (rearrange(t, 'b (h d) ... -> b h (...) d', h = h) for t in (q, k, v))

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        dots = self.apply_pos_bias(dots)

        attn = self.dropout(self.attend(dots))

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)
        return self.to_out(out)

class Transformer(nn.Module):
    """Stack of `depth` (Attention, FeedForward) pairs operating on feature maps.

    The attention residual is used only when the block neither downsamples
    nor changes the channel count; the feed-forward residual is always used.
    """
    def __init__(self, dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult = 2, dropout = 0., dim_out = None, downsample = False):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim, fmap_size = fmap_size, heads = heads, dim_key = dim_key, dim_value = dim_value, dropout = dropout, downsample = downsample, dim_out = dim_out),
                FeedForward(dim_out, mlp_mult, dropout = dropout)
            ])
            for _ in range(depth)
        ])
        # a skip connection around attention only makes sense when shapes are preserved
        self.attn_residual = (not downsample) and dim == dim_out

    def forward(self, x):
        for attn, ff in self.layers:
            shortcut = x if self.attn_residual else 0
            x = attn(x) + shortcut
            x = ff(x) + x
        return x
# 定义 LeViT 类,继承自 nn.Module
class LeViT(nn.Module):
    """LeViT classifier: convolutional stem, staged attention backbone, pooled linear head(s).

    `dim`, `depth` and `heads` may each be a single value or a tuple; they are
    expanded to one entry per stage with `cast_tuple`. When
    `num_distill_classes` is given, `forward` returns a
    `(logits, distill_logits)` pair instead of the logits alone.
    """
    def __init__(
        self,
        *,
        image_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_mult,
        stages = 3,
        dim_key = 32,
        dim_value = 64,
        dropout = 0.,
        num_distill_classes = None
    ):
        super().__init__()

        # one value per stage for each of dim / depth / heads
        dims = cast_tuple(dim, stages)
        depths = cast_tuple(depth, stages)
        layer_heads = cast_tuple(heads, stages)

        per_stage_settings = (dims, depths, layer_heads)
        assert all(len(setting) == stages for setting in per_stage_settings), 'dimensions, depths, and heads must be a tuple that is less than the designated number of stages'

        # four stride-2 3x3 convolutions
        stem_channels = (3, 32, 64, 128, dims[0])
        self.conv_embedding = nn.Sequential(*[
            nn.Conv2d(c_in, c_out, 3, stride = 2, padding = 1)
            for c_in, c_out in zip(stem_channels[:-1], stem_channels[1:])
        ])

        fmap_size = image_size // (2 ** 4)
        blocks = []
        last_stage = stages - 1

        for stage, (stage_dim, stage_depth, stage_heads) in enumerate(zip(dims, depths, layer_heads)):
            blocks.append(Transformer(stage_dim, fmap_size, stage_depth, stage_heads, dim_key, dim_value, mlp_mult, dropout))

            if stage == last_stage:
                continue

            # between stages: one downsampling block that also changes the channel count
            blocks.append(Transformer(stage_dim, fmap_size, 1, stage_heads * 2, dim_key, dim_value, dim_out = dims[stage + 1], downsample = True))
            fmap_size = ceil(fmap_size / 2)

        self.backbone = nn.Sequential(*blocks)

        # global average pool, then drop the two singleton spatial axes
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Rearrange('... () () -> ...')
        )

        # both heads read features of the last stage's width
        final_dim = dims[-1]
        self.distill_head = nn.Linear(final_dim, num_distill_classes) if exists(num_distill_classes) else always(None)
        self.mlp_head = nn.Linear(final_dim, num_classes)

    def forward(self, img):
        features = self.conv_embedding(img)
        features = self.backbone(features)        
        pooled = self.pool(features)

        out = self.mlp_head(pooled)
        distill = self.distill_head(pooled)

        if exists(distill):
            return out, distill
        return out

.\lucidrains\vit-pytorch\vit_pytorch\local_vit.py

# 从 math 模块中导入 sqrt 函数
from math import sqrt
# 导入 torch 模块
import torch
# 从 torch 模块中导入 nn, einsum
from torch import nn, einsum
# 从 torch.nn 模块中导入 functional 模块
import torch.nn.functional as F

# 从 einops 模块中导入 rearrange, repeat
from einops import rearrange, repeat
# 从 einops.layers.torch 模块中导入 Rearrange 类

# classes

# 定义 Residual 类,继承自 nn.Module
class Residual(nn.Module):
    """Wrap `fn` so that its input is added to its output."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        transformed = self.fn(x, **kwargs)
        return transformed + x

# 定义 ExcludeCLS 类,继承自 nn.Module
class ExcludeCLS(nn.Module):
    """Apply `fn` to every token except the first, then put the first token back in front."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # keep the leading token as a length-1 slice so it can be concatenated back
        cls_token = x[:, :1]
        rest = self.fn(x[:, 1:], **kwargs)
        return torch.cat((cls_token, rest), dim = 1)

# feed forward related classes

# 定义 DepthWiseConv2d 类,继承自 nn.Module
class DepthWiseConv2d(nn.Module):
    """Two stacked convolutions: a per-channel one, then a 1x1 one.

    Args:
        dim_in: number of input channels; also the ``groups`` value and the
            output channel count of the first convolution.
        dim_out: number of output channels of the 1x1 convolution.
        kernel_size: kernel size of the first convolution.
        padding: padding of the first convolution.
        stride: stride of the first convolution.
        bias: whether both convolutions carry a bias term.
    """

    def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
        super().__init__()
        # groups = dim_in gives every input channel its own filter.
        depthwise = nn.Conv2d(
            dim_in,
            dim_in,
            kernel_size = kernel_size,
            padding = padding,
            groups = dim_in,
            stride = stride,
            bias = bias
        )
        # The 1x1 convolution mixes channels and maps dim_in -> dim_out.
        pointwise = nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        self.net = nn.Sequential(depthwise, pointwise)

    def forward(self, x):
        """Apply both convolutions to ``x`` in sequence."""
        return self.net(x)

# 定义 FeedForward 类,继承自 nn.Module
class FeedForward(nn.Module):
    """Convolutional feed-forward block operating on a token sequence.

    The tokens are layer-normalized, folded into a square feature map, passed
    through 1x1 conv -> Hardswish -> DepthWiseConv2d (3x3) -> Hardswish ->
    dropout -> 1x1 conv -> dropout, and unfolded back into a sequence.

    Args:
        dim: channel size of the incoming and outgoing tokens.
        hidden_dim: channel size used between the two 1x1 convolutions.
        dropout: dropout probability used after the second activation and
            after the final convolution.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # The LayerNorm stays at index 0 of ``net`` so parameter names
        # (``net.0.*``, ``net.1.*``, ...) are unchanged; ``forward`` applies
        # it separately from the convolutional layers that follow it.
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Conv2d(dim, hidden_dim, 1),
            nn.Hardswish(),
            DepthWiseConv2d(hidden_dim, hidden_dim, 3, padding = 1),
            nn.Hardswish(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Transform tokens laid out as ``b (h w) c`` and return the same layout.

        The token count ``x.shape[-2]`` must be a perfect square, because the
        tokens are reshaped into an ``h x w`` map with ``h == w``.
        """
        norm, *conv_layers = self.net

        # nn.LayerNorm normalizes the trailing dimension, so it has to run
        # while channels are still last. Applying it to the ``b c h w`` map
        # would normalize over the width and fail whenever ``w != dim``.
        x = norm(x)

        h = w = int(sqrt(x.shape[-2]))
        x = rearrange(x, 'b (h w) c -> b c h w', h = h, w = w)
        for layer in conv_layers:
            x = layer(x)
        return rearrange(x, 'b c h w -> b (h w) c')

# attention

# 定义 Attention 类,继承自 nn.Module
class Attention(nn.Module):
    """Multi-head self-attention with a LayerNorm applied to its input.

    Args:
        dim: channel size of the input and of the output.
        heads: number of attention heads.
        dim_head: channel size of each head.
        dropout: dropout probability applied to the attention weights and to
            the output projection.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        self.heads = heads
        # Dot products are scaled by 1 / sqrt(dim_head).
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)
        # One bias-free projection yields queries, keys and values together.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Attend over the ``n`` axis of a ``b n d`` tensor; returns ``b n dim``."""
        # Unpacking into three names requires a 3-D input.
        _batch, _tokens, _channels = x.shape
        num_heads = self.heads

        normed = self.norm(x)
        projections = self.to_qkv(normed).chunk(3, dim = -1)
        # Split the channel axis of each projection into (heads, dim_head).
        q, k, v = (
            rearrange(t, 'b n (h d) -> b h n d', h = num_heads)
            for t in projections
        )

        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        weights = self.dropout(self.attend(scores))

        mixed = einsum('b h i j, b h j d -> b h i d', weights, v)
        # Merge the heads back into a single channel axis before projecting.
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.to_out(merged)

# 定义 Transformer 类,继承自 nn.Module
class Transformer(nn.Module):
    """Stack of ``depth`` layers, each an attention block then a feed-forward block.

    Args:
        dim: channel size of the tokens.
        depth: number of layers.
        heads: number of attention heads per layer.
        dim_head: channel size of each attention head.
        mlp_dim: hidden channel size of each feed-forward block.
        dropout: dropout probability passed to both sub-blocks.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        # Attention sees every token (residual around it); the feed-forward
        # block is residual as well but is wrapped in ExcludeCLS, so it only
        # processes the tokens after the first one.
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Residual(Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                ExcludeCLS(Residual(FeedForward(dim, mlp_dim, dropout = dropout)))
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        """Pass ``x`` through every (attention, feed-forward) pair in order."""
        for attention, feed_forward in self.layers:
            x = feed_forward(attention(x))
        return x

# main class

# 定义 LocalViT 类,继承自 nn.Module
class LocalViT(nn.Module):
    """Image classifier: patch embedding, class token, Transformer, linear head.

    Args:
        image_size: side length of the (square) input image.
        patch_size: side length of each square patch; must divide ``image_size``.
        num_classes: size of the output of the classification head.
        dim: channel size of the token embeddings.
        depth: number of Transformer layers.
        heads: number of attention heads.
        mlp_dim: hidden channel size of the feed-forward blocks.
        channels: number of image channels.
        dim_head: channel size of each attention head.
        dropout: dropout probability inside the Transformer.
        emb_dropout: dropout probability applied to the embedded tokens.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'

        patches_per_side = image_size // patch_size
        num_patches = patches_per_side ** 2
        # Each flattened patch holds patch_size x patch_size pixels per channel.
        patch_dim = channels * patch_size ** 2

        # Cut the image into patches, flatten each one, and project it to dim.
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # One positional slot per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        """Return the classification head's output for a batch of images."""
        tokens = self.to_patch_embedding(img)
        batch, num_tokens, _ = tokens.shape

        # Prepend one copy of the class token to every sequence in the batch.
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = batch)
        tokens = torch.cat((cls_tokens, tokens), dim = 1)

        # Add the positional embedding for the class token and each patch.
        tokens += self.pos_embedding[:, :(num_tokens + 1)]
        tokens = self.dropout(tokens)

        encoded = self.transformer(tokens)

        # Only the token at position 0 (the class token) is classified.
        return self.mlp_head(encoded[:, 0])