
Lucidrains Series: Source Code Walkthrough (Part 108)

.\lucidrains\vit-pytorch\vit_pytorch\mae.py

# import the PyTorch library
import torch
# import the nn module from torch
from torch import nn
# import torch.nn.functional as F
import torch.nn.functional as F
# import the repeat function from einops
from einops import repeat

# import the Transformer class from vit_pytorch.vit
from vit_pytorch.vit import Transformer

# MAE (masked autoencoder) wrapper, defined as an nn.Module
class MAE(nn.Module):
    # initializer, taking keyword-only hyperparameters
    def __init__(
        self,
        *,
        encoder,
        decoder_dim,
        masking_ratio = 0.75,
        decoder_depth = 1,
        decoder_heads = 8,
        decoder_dim_head = 64
    ):
        super().__init__()
        # make sure the masking ratio is strictly between 0 and 1
        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
        # store the masking ratio on the module
        self.masking_ratio = masking_ratio

        # extract some hyperparameters and functions from the encoder (the vision transformer to be trained)

        # store the encoder module
        self.encoder = encoder
        # read off the number of patches and the encoder dimension from the positional embedding
        num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]

        # the layer that splits the image into patches
        self.to_patch = encoder.to_patch_embedding[0]
        # the remaining layers that map patches to embeddings
        self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])

        # number of pixel values per patch
        pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]

        # decoder parameters
        # store the decoder dimension
        self.decoder_dim = decoder_dim
        # project encoder to decoder dimensions if they differ, otherwise use the identity
        self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
        # learnable mask token
        self.mask_token = nn.Parameter(torch.randn(decoder_dim))
        # a Transformer used as the decoder
        self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
        # embedding layer for the decoder positional embedding
        self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
        # linear layer mapping decoder outputs back to pixel values
        self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)
    # forward pass, taking the input image
    def forward(self, img):
        # device on which the input image lives
        device = img.device

        # get the image patches

        patches = self.to_patch(img)
        batch, num_patches, *_ = patches.shape

        # project patches to encoder tokens and add positional information

        tokens = self.patch_to_emb(patches)
        if self.encoder.pool == "cls":
            tokens += self.encoder.pos_embedding[:, 1:(num_patches + 1)]
        elif self.encoder.pool == "mean":
            tokens += self.encoder.pos_embedding.to(device, dtype=tokens.dtype) 

        # compute how many patches to mask, draw random indices, and split them into masked and unmasked sets

        num_masked = int(self.masking_ratio * num_patches)
        rand_indices = torch.rand(batch, num_patches, device=device).argsort(dim=-1)
        masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]

        # gather the unmasked tokens that will be encoded

        batch_range = torch.arange(batch, device=device)[:, None]
        tokens = tokens[batch_range, unmasked_indices]

        # gather the masked patches, used later for the reconstruction loss

        masked_patches = patches[batch_range, masked_indices]

        # attend with the vision transformer

        encoded_tokens = self.encoder.transformer(tokens)

        # project encoder tokens to the decoder dimension if the two differ - the paper notes a smaller dimension suffices for the decoder

        decoder_tokens = self.enc_to_dec(encoded_tokens)

        # re-apply the decoder positional embedding to the unmasked tokens

        unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)

        # repeat the mask token to the number of masked patches, and add positions using the masked indices from above

        mask_tokens = repeat(self.mask_token, 'd -> b n d', b=batch, n=num_masked)
        mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)

        # scatter the mask tokens and unmasked decoder tokens back into their original positions, then attend with the decoder

        decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device=device)
        decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
        decoder_tokens[batch_range, masked_indices] = mask_tokens
        decoded_tokens = self.decoder(decoder_tokens)

        # pull out the mask tokens and project them to pixel values

        mask_tokens = decoded_tokens[batch_range, masked_indices]
        pred_pixel_values = self.to_pixels(mask_tokens)

        # compute the reconstruction loss

        recon_loss = F.mse_loss(pred_pixel_values, masked_patches)
        return recon_loss
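
To round out mae.py, here is a minimal usage sketch (not part of the file itself; the hyperparameters are illustrative). It wraps the standard ViT from this package with MAE and backpropagates the reconstruction loss.

import torch
from vit_pytorch import ViT
from vit_pytorch.mae import MAE

# encoder to be pretrained; sizes are illustrative
v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mae = MAE(
    encoder = v,
    masking_ratio = 0.75,   # the MAE paper recommends masking 75% of patches
    decoder_dim = 512,
    decoder_depth = 6
)

images = torch.randn(8, 3, 256, 256)
loss = mae(images)          # scalar reconstruction loss over the masked patches
loss.backward()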

.\lucidrains\vit-pytorch\vit_pytorch\max_vit.py

# import necessary libraries
from functools import partial

import torch
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# cast a value to a tuple, repeating it if it is not already one
def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# helper classes

# residual connection
class Residual(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

# feedforward network
class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# MBConv

# squeeze-and-excitation module
class SqueezeExcitation(nn.Module):
    def __init__(self, dim, shrinkage_rate = 0.25):
        super().__init__()
        hidden_dim = int(dim * shrinkage_rate)

        self.gate = nn.Sequential(
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(dim, hidden_dim, bias = False),
            nn.SiLU(),
            nn.Linear(hidden_dim, dim, bias = False),
            nn.Sigmoid(),
            Rearrange('b c -> b c 1 1')
        )

    def forward(self, x):
        return x * self.gate(x)

# MBConv residual block
class MBConvResidual(nn.Module):
    def __init__(self, fn, dropout = 0.):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        out = self.fn(x)
        out = self.dropsample(out)
        return out + x

# drop-sample (stochastic depth over the batch dimension)
class Dropsample(nn.Module):
    def __init__(self, prob = 0):
        super().__init__()
        self.prob = prob
  
    def forward(self, x):
        device = x.device

        if self.prob == 0. or (not self.training):
            return x

        keep_mask = torch.empty((x.shape[0], 1, 1, 1), device = device).uniform_() > self.prob  # per-sample keep mask
        return x * keep_mask / (1 - self.prob)

# MBConv block constructor
def MBConv(
    dim_in,
    dim_out,
    *,
    downsample,
    expansion_rate = 4,
    shrinkage_rate = 0.25,
    dropout = 0.
):
    hidden_dim = int(expansion_rate * dim_out)
    stride = 2 if downsample else 1

    net = nn.Sequential(
        nn.Conv2d(dim_in, hidden_dim, 1),
        nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
        nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
        nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
        SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
        nn.Conv2d(hidden_dim, dim_out, 1),
        nn.BatchNorm2d(dim_out)
    )

    if dim_in == dim_out and not downsample:
        net = MBConvResidual(net, dropout = dropout)

    return net

# attention related classes

# attention mechanism
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 32,
        dropout = 0.,
        window_size = 7
    ):
        # call the parent constructor
        super().__init__()
        # the dimension must be divisible by the per-head dimension
        assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'

        # number of attention heads
        self.heads = dim // dim_head
        # scaling factor
        self.scale = dim_head ** -0.5

        # LayerNorm layer
        self.norm = nn.LayerNorm(dim)
        # linear projection producing queries, keys and values
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)

        # attention activation
        self.attend = nn.Sequential(
            nn.Softmax(dim = -1),  # softmax activation
            nn.Dropout(dropout)  # dropout layer
        )

        # output projection
        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias = False),  # linear projection
            nn.Dropout(dropout)  # dropout layer
        )

        # relative positional bias

        # embedding table storing the relative positional bias
        self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)

        # compute the relative position indices
        pos = torch.arange(window_size)
        grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
        grid = rearrange(grid, 'c i j -> (i j) c')
        rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
        rel_pos += window_size - 1
        rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

        # register the relative position indices as a (non-persistent) buffer
        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)

    def forward(self, x):
        # unpack the input shape information
        batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads

        # LayerNorm
        x = self.norm(x)

        # flatten the windows
        x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')

        # project to queries, keys and values
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # split heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # scale
        q = q * self.scale

        # similarity
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # add the relative positional bias
        bias = self.rel_pos_bias(self.rel_pos_indices)
        sim = sim + rearrange(bias, 'i j h -> h i j')

        # attention
        attn = self.attend(sim)

        # aggregate values
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge heads back into the window dimensions
        out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = window_height, w2 = window_width)

        # combine heads and project out
        out = self.to_out(out)
        return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width)
# the MaxViT model, an nn.Module
class MaxViT(nn.Module):
    # initializer, taking keyword-only hyperparameters
    def __init__(
        self,
        *,
        num_classes,
        dim,
        depth,
        dim_head = 32,
        dim_conv_stem = None,
        window_size = 7,
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        dropout = 0.1,
        channels = 3
    ):
        # call the parent constructor
        super().__init__()
        # depth must be a tuple giving the number of transformer blocks at each stage
        assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'

        # convolutional stem

        # if dim_conv_stem is not given, fall back to dim
        dim_conv_stem = default(dim_conv_stem, dim)

        # define the convolutional stem
        self.conv_stem = nn.Sequential(
            nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
            nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
        )

        # variables

        # number of stages
        num_stages = len(depth)

        # dimension at each stage
        dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
        dims = (dim_conv_stem, *dims)
        dim_pairs = tuple(zip(dims[:-1], dims[1:]))

        # layers start out as an empty ModuleList
        self.layers = nn.ModuleList([])

        # shorthand for the window size, used for efficient block-like and grid-like attention
        w = window_size

        # iterate through the stages
        for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
            for stage_ind in range(layer_depth):
                is_first = stage_ind == 0
                stage_dim_in = layer_dim_in if is_first else layer_dim

                # define one block
                block = nn.Sequential(
                    MBConv(
                        stage_dim_in,
                        layer_dim,
                        downsample = is_first,
                        expansion_rate = mbconv_expansion_rate,
                        shrinkage_rate = mbconv_shrinkage_rate
                    ),
                    Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w),  # block-like attention
                    Residual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
                    Residual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                    Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),

                    Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w),  # grid-like attention
                    Residual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
                    Residual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                    Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
                )

                # append the block to the layers
                self.layers.append(block)

        # MLP head

        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            nn.LayerNorm(dims[-1]),
            nn.Linear(dims[-1], num_classes)
        )

    # forward pass
    def forward(self, x):
        # convolutional stem
        x = self.conv_stem(x)

        # run through the blocks of every stage
        for stage in self.layers:
            x = stage(x)

        # MLP head
        return self.mlp_head(x)
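
A brief usage sketch for max_vit.py (not part of the source; the hyperparameters below are illustrative):

import torch
from vit_pytorch.max_vit import MaxViT

v = MaxViT(
    num_classes = 1000,
    dim_conv_stem = 64,
    dim = 96,
    dim_head = 32,
    depth = (2, 2, 5, 2),            # number of blocks per stage
    window_size = 7,
    mbconv_expansion_rate = 4,
    mbconv_shrinkage_rate = 0.25,
    dropout = 0.1
)

# 224 works because after the stride-2 stem and the four stride-2 stages
# the feature map is 7x7, matching the window size
img = torch.randn(2, 3, 224, 224)
preds = v(img)   # (2, 1000)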

.\lucidrains\vit-pytorch\vit_pytorch\max_vit_with_registers.py

# import necessary libraries
from functools import partial

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module, ModuleList, Sequential

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# pack a single tensor according to the given pattern
def pack_one(x, pattern):
    return pack([x], pattern)

# unpack back into a single tensor
def unpack_one(x, ps, pattern):
    return unpack(x, ps, pattern)[0]

# cast a value to a tuple, repeating it if it is not already one
def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# helper classes

# feedforward network, built as a plain Sequential
def FeedForward(dim, mult = 4, dropout = 0.):
    inner_dim = int(dim * mult)
    return Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(inner_dim, dim),
        nn.Dropout(dropout)
    )

# MBConv

# squeeze-and-excitation module
class SqueezeExcitation(Module):
    def __init__(self, dim, shrinkage_rate = 0.25):
        super().__init__()
        hidden_dim = int(dim * shrinkage_rate)

        self.gate = Sequential(
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(dim, hidden_dim, bias = False),
            nn.SiLU(),
            nn.Linear(hidden_dim, dim, bias = False),
            nn.Sigmoid(),
            Rearrange('b c -> b c 1 1')
        )

    def forward(self, x):
        return x * self.gate(x)

# MBConv residual module
class MBConvResidual(Module):
    def __init__(self, fn, dropout = 0.):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        out = self.fn(x)
        out = self.dropsample(out)
        return out + x

# drop-sample module (stochastic depth over the batch dimension)
class Dropsample(Module):
    def __init__(self, prob = 0):
        super().__init__()
        self.prob = prob
  
    def forward(self, x):
        device = x.device

        if self.prob == 0. or (not self.training):
            return x

        keep_mask = torch.empty((x.shape[0], 1, 1, 1), device = device).uniform_() > self.prob  # per-sample keep mask
        return x * keep_mask / (1 - self.prob)

# MBConv block constructor
def MBConv(
    dim_in,
    dim_out,
    *,
    downsample,
    expansion_rate = 4,
    shrinkage_rate = 0.25,
    dropout = 0.
):
    hidden_dim = int(expansion_rate * dim_out)
    stride = 2 if downsample else 1

    net = Sequential(
        nn.Conv2d(dim_in, hidden_dim, 1),
        nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
        nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
        nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
        SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
        nn.Conv2d(hidden_dim, dim_out, 1),
        nn.BatchNorm2d(dim_out)
    )

    if dim_in == dim_out and not downsample:
        net = MBConvResidual(net, dropout = dropout)

    return net

# attention related classes

# attention mechanism (with register tokens)
class Attention(Module):
    def __init__(
        self,
        dim,
        dim_head = 32,
        dropout = 0.,
        window_size = 7,
        num_registers = 1
    ):
        # call the parent constructor
        super().__init__()
        # there must be at least one register token
        assert num_registers > 0
        # the dimension must be divisible by the per-head dimension
        assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'

        # number of attention heads
        self.heads = dim // dim_head
        # scaling factor
        self.scale = dim_head ** -0.5

        # LayerNorm layer
        self.norm = nn.LayerNorm(dim)
        # linear projection producing queries, keys and values
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)

        # attention activation
        self.attend = nn.Sequential(
            nn.Softmax(dim = -1),  # softmax activation
            nn.Dropout(dropout)  # dropout layer
        )

        # output projection
        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias = False),  # linear projection
            nn.Dropout(dropout)  # dropout layer
        )

        # relative positional bias

        # number of relative positional bias entries
        num_rel_pos_bias = (2 * window_size - 1) ** 2

        # embedding table storing the relative positional bias (one extra entry for the registers)
        self.rel_pos_bias = nn.Embedding(num_rel_pos_bias + 1, self.heads)

        # generate the relative position indices
        pos = torch.arange(window_size)
        grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
        grid = rearrange(grid, 'c i j -> (i j) c')
        rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
        rel_pos += window_size - 1
        rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

        # pad the index grid so register tokens map to the extra bias entry
        rel_pos_indices = F.pad(rel_pos_indices, (num_registers, 0, num_registers, 0), value = num_rel_pos_bias)
        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)

    def forward(self, x):
        # device, number of heads and relative position bias indices
        device, h, bias_indices = x.device, self.heads, self.rel_pos_indices

        # LayerNorm
        x = self.norm(x)

        # project to queries, keys and values
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # split heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # scale
        q = q * self.scale

        # similarity
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # add the positional bias
        bias = self.rel_pos_bias(bias_indices)
        sim = sim + rearrange(bias, 'i j h -> h i j')

        # attention
        attn = self.attend(sim)

        # aggregate values
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge heads and project out
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
class MaxViT(Module):
    # the MaxViT model with register tokens
    def __init__(
        self,
        *,
        num_classes,
        dim,
        depth,
        dim_head = 32,
        dim_conv_stem = None,
        window_size = 7,
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        dropout = 0.1,
        channels = 3,
        num_register_tokens = 4
    ):
        # initializer, taking keyword-only hyperparameters
        super().__init__()

        # depth must be a tuple, and there must be at least one register token
        assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
        assert num_register_tokens > 0

        # convolutional stem

        dim_conv_stem = default(dim_conv_stem, dim)
        # if dim_conv_stem is None, fall back to dim

        self.conv_stem = Sequential(
            nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
            nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
        )
        # two convolutions form the convolutional stem

        # variables

        num_stages = len(depth)
        # the length of depth gives the number of stages

        dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
        dims = (dim_conv_stem, *dims)
        # dimension at each stage

        dim_pairs = tuple(zip(dims[:-1], dims[1:]))
        # pair up consecutive dimensions

        self.layers = nn.ModuleList([])
        # ModuleList holding the network layers

        # window size

        self.window_size = window_size
        # store the window size

        self.register_tokens = nn.ParameterList([])
        # ParameterList holding the register tokens

        # iterate through stages

        for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
            # iterate through the stages
            for stage_ind in range(layer_depth):
                is_first = stage_ind == 0
                stage_dim_in = layer_dim_in if is_first else layer_dim
                # the first block in each stage downsamples

                conv = MBConv(
                    stage_dim_in,
                    layer_dim,
                    downsample = is_first,
                    expansion_rate = mbconv_expansion_rate,
                    shrinkage_rate = mbconv_shrinkage_rate
                )
                # MBConv block

                block_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
                block_ff = FeedForward(dim = layer_dim, dropout = dropout)
                # block attention and feedforward

                grid_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
                grid_ff = FeedForward(dim = layer_dim, dropout = dropout)
                # grid attention and feedforward

                register_tokens = nn.Parameter(torch.randn(num_register_tokens, layer_dim))
                # randomly initialized register tokens for this block

                self.layers.append(ModuleList([
                    conv,
                    ModuleList([block_attn, block_ff]),
                    ModuleList([grid_attn, grid_ff])
                ]))
                # append the conv, attention and feedforward modules to the layers

                self.register_tokens.append(register_tokens)
                # append the register tokens for this block

        # mlp head out

        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            nn.LayerNorm(dims[-1]),
            nn.Linear(dims[-1], num_classes)
        )
        # MLP head for classification

    # forward pass, taking an input tensor x
    def forward(self, x):
        # batch size and window size
        b, w = x.shape[0], self.window_size

        # convolutional stem
        x = self.conv_stem(x)

        # run through every layer: conv, block attention, grid attention
        for (conv, (block_attn, block_ff), (grid_attn, grid_ff)), register_tokens in zip(self.layers, self.register_tokens):
            # MBConv
            x = conv(x)

            # block-like attention

            # rearrange into windows
            x = rearrange(x, 'b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w)

            # prepare the register tokens

            # repeat the register tokens to match the windowed shape of x
            r = repeat(register_tokens, 'n d -> b x y n d', b = b, x = x.shape[1],y = x.shape[2])
            r, register_batch_ps = pack_one(r, '* n d')

            x, window_ps = pack_one(x, 'b x y * d')
            x, batch_ps  = pack_one(x, '* n d')
            x, register_ps = pack([r, x], 'b * d')

            # block attention with a residual connection
            x = block_attn(x) + x
            # block feedforward with a residual connection
            x = block_ff(x) + x

            r, x = unpack(x, register_ps, 'b * d')

            x = unpack_one(x, batch_ps, '* n d')
            x = unpack_one(x, window_ps, 'b x y * d')
            x = rearrange(x, 'b x y w1 w2 d -> b d (x w1) (y w2)')

            r = unpack_one(r, register_batch_ps, '* n d')

            # grid-like attention

            # rearrange into a grid layout
            x = rearrange(x, 'b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w)

            # prepare the register tokens

            # reduce the register tokens by averaging over the windows
            r = reduce(r, 'b x y n d -> b n d', 'mean')
            r = repeat(r, 'b n d -> b x y n d', x = x.shape[1], y = x.shape[2])
            r, register_batch_ps = pack_one(r, '* n d')

            x, window_ps = pack_one(x, 'b x y * d')
            x, batch_ps  = pack_one(x, '* n d')
            x, register_ps = pack([r, x], 'b * d')

            # grid attention with a residual connection
            x = grid_attn(x) + x

            r, x = unpack(x, register_ps, 'b * d')

            # grid feedforward with a residual connection
            x = grid_ff(x) + x

            x = unpack_one(x, batch_ps, '* n d')
            x = unpack_one(x, window_ps, 'b x y * d')
            x = rearrange(x, 'b x y w1 w2 d -> b d (w1 x) (w2 y)')

        # return the output of the MLP head
        return self.mlp_head(x)
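
As above, a minimal usage sketch for the register-token variant (the hyperparameters are illustrative, not prescribed by the source):

import torch
from vit_pytorch.max_vit_with_registers import MaxViT

v = MaxViT(
    num_classes = 1000,
    dim = 96,
    depth = (2, 2, 5, 2),        # number of blocks per stage
    window_size = 7,
    num_register_tokens = 4      # register tokens per block
)

img = torch.randn(2, 3, 224, 224)
preds = v(img)   # (2, 1000)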

.\lucidrains\vit-pytorch\vit_pytorch\mobile_vit.py

import torch
import torch.nn as nn

from einops import rearrange
from einops.layers.torch import Reduce

# helpers

# 1x1 convolution + batch norm + SiLU activation
def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )

# nxn convolution + batch norm + SiLU activation
def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )

# classes

# feedforward network
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# attention
class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim=-1)

        q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b p h n d -> b p n (h d)')
        return self.to_out(out)

# transformer
class Transformer(nn.Module):
    """Transformer block described in ViT.
    Paper: https://arxiv.org/abs/2010.11929
    Based on: https://github.com/lucidrains/vit-pytorch
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads, dim_head, dropout),
                FeedForward(dim, mlp_dim, dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

# MV2 block
class MV2Block(nn.Module):
    """MV2 block described in MobileNetV2.
    Paper: https://arxiv.org/pdf/1801.04381
    Based on: https://github.com/tonylins/pytorch-mobilenet-v2
    """

    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expansion)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expansion == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
                          1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
                          1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
    # forward pass, taking an input x
    def forward(self, x):
        # run the input through the convolutional stack
        out = self.conv(x)
        # if the residual connection applies
        if self.use_res_connect:
            # add the input back onto the output
            out = out + x
        # return the processed output
        return out
class MobileViTBlock(nn.Module):
    def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size

        # conv1: local representation
        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        # conv2: project the local representation up to the transformer dimension
        self.conv2 = conv_1x1_bn(channel, dim)

        # transformer for the global representation
        self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)

        # conv3: project the global representation back to the channel dimension
        self.conv3 = conv_1x1_bn(dim, channel)
        # conv4: fuse the local and global representations
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

    def forward(self, x):
        y = x.clone()

        # local representations
        x = self.conv1(x)
        x = self.conv2(x)

        # global representations
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)        
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)

        # fuse the local and global representations
        x = self.conv3(x)
        x = torch.cat((x, y), 1)
        x = self.conv4(x)
        return x

class MobileViT(nn.Module):
    """MobileViT.
    Paper: https://arxiv.org/abs/2110.02178
    Based on: https://github.com/chinhsuanwu/mobilevit-pytorch
    """

    def __init__(
        self,
        image_size,
        dims,
        channels,
        num_classes,
        expansion=4,
        kernel_size=3,
        patch_size=(2, 2),
        depths=(2, 4, 3)
    ):
        super().__init__()
        assert len(dims) == 3, 'dims must be a tuple of 3'
        assert len(depths) == 3, 'depths must be a tuple of 3'

        ih, iw = image_size
        ph, pw = patch_size
        assert ih % ph == 0 and iw % pw == 0

        init_dim, *_, last_dim = channels

        # initial convolution on the input image
        self.conv1 = conv_nxn_bn(3, init_dim, stride=2)

        # stem: a stack of MV2 blocks
        self.stem = nn.ModuleList([])
        self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
        self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
        self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
        self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))

        # trunk: MV2 blocks followed by MobileViT blocks
        self.trunk = nn.ModuleList([])
        self.trunk.append(nn.ModuleList([
            MV2Block(channels[3], channels[4], 2, expansion),
            MobileViTBlock(dims[0], depths[0], channels[5],
                           kernel_size, patch_size, int(dims[0] * 2))
        ]))

        self.trunk.append(nn.ModuleList([
            MV2Block(channels[5], channels[6], 2, expansion),
            MobileViTBlock(dims[1], depths[1], channels[7],
                           kernel_size, patch_size, int(dims[1] * 4))
        ]))

        self.trunk.append(nn.ModuleList([
            MV2Block(channels[7], channels[8], 2, expansion),
            MobileViTBlock(dims[2], depths[2], channels[9],
                           kernel_size, patch_size, int(dims[2] * 4))
        ]))

        # output head: 1x1 conv, global pooling and a linear classifier
        self.to_logits = nn.Sequential(
            conv_1x1_bn(channels[-2], last_dim),
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(channels[-1], num_classes, bias=False)
        )

    def forward(self, x):
        x = self.conv1(x)

        # stem MV2 blocks
        for conv in self.stem:
            x = conv(x)

        # trunk MV2 + MobileViT blocks
        for conv, attn in self.trunk:
            x = conv(x)
            x = attn(x)

        return self.to_logits(x)
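
A usage sketch for mobile_vit.py; the dims/channels configuration below is one plausible XS-style setting, not something fixed by the source:

import torch
from vit_pytorch.mobile_vit import MobileViT

mbvit_xs = MobileViT(
    image_size = (256, 256),
    dims = [96, 120, 144],
    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
    num_classes = 1000
)

img = torch.randn(1, 3, 256, 256)
pred = mbvit_xs(img)   # (1, 1000)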

.\lucidrains\vit-pytorch\vit_pytorch\mp3.py

# import torch
import torch
# import the nn module and einsum from torch
from torch import nn, einsum
# import torch.nn.functional as F
import torch.nn.functional as F

# import rearrange and repeat from einops, and the Rearrange layer from einops.layers.torch

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# cast the input to a pair tuple
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# positional embedding

# 2d sinusoidal (sincos) positional embedding
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
    # read off the shape, device and dtype of the patches
    _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    # build the y and x coordinate grids
    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
    # the feature dimension must be a multiple of 4 for the sincos embedding
    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
    # frequencies
    omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
    omega = 1. / (temperature ** omega)

    # compute the positional embedding
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    return pe.type(dtype)

# feedforward network

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# (cross) attention

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.norm = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context = None):
        b, n, _, h = *x.shape, self.heads

        x = self.norm(x)

        context = self.norm(context) if exists(context) else x

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    # initializer, defining the model's parameters and structure
    def __init__(self, *, num_classes, image_size, patch_size, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0.):
        # call the parent constructor
        super().__init__()
        # image height and width
        image_height, image_width = pair(image_size)
        # patch height and width
        patch_height, patch_width = pair(patch_size)

        # image dimensions must be divisible by the patch dimensions
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # number of patches
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # dimension of each flattened patch
        patch_dim = channels * patch_height * patch_width

        # store the model dimension and patch count
        self.dim = dim
        self.num_patches = num_patches

        # layers that turn the image into patch embeddings
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # the transformer
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        # layers mapping the latent representation to output classes
        self.to_latent = nn.Identity()
        self.linear_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    # forward pass
    def forward(self, img):
        # image shape and dtype
        *_, h, w, dtype = *img.shape, img.dtype

        # turn the image into patch embeddings
        x = self.to_patch_embedding(img)
        # generate the positional encoding
        pe = posemb_sincos_2d(x)
        # flatten the spatial dimensions and add the positional encoding
        x = rearrange(x, 'b ... d -> b (...) d') + pe

        # run through the transformer
        x = self.transformer(x)
        # mean-pool over the patches
        x = x.mean(dim = 1)

        # map to the latent representation
        x = self.to_latent(x)
        # linear head producing the class logits
        return self.linear_head(x)
# Masked Position Prediction pretraining
class MP3(nn.Module):
    # initializer, taking a ViT model and the masking ratio
    def __init__(self, vit: ViT, masking_ratio):
        super().__init__()
        self.vit = vit

        # the masking ratio must be strictly between 0 and 1
        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
        self.masking_ratio = masking_ratio

        # dimension of the ViT model
        dim = vit.dim
        # MLP head: LayerNorm followed by a Linear layer predicting the patch position
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, vit.num_patches)
        )

    # forward pass, taking an image
    def forward(self, img):
        # device of the input image
        device = img.device
        # turn the image into tokens
        tokens = self.vit.to_patch_embedding(img)
        # flatten the spatial dimensions of the tokens
        tokens = rearrange(tokens, 'b ... d -> b (...) d')

        # batch size and number of patches
        batch, num_patches, *_ = tokens.shape

        # Masking
        # number of patches to mask
        num_masked = int(self.masking_ratio * num_patches)
        # generate random indices by argsorting random noise, then split into masked and unmasked sets
        rand_indices = torch.rand(batch, num_patches, device=device).argsort(dim=-1)
        masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]

        # batch range indices for gathering
        batch_range = torch.arange(batch, device=device)[:, None]
        # gather the unmasked tokens
        tokens_unmasked = tokens[batch_range, unmasked_indices]

        # attend with the ViT transformer, using the unmasked tokens as context
        attended_tokens = self.vit.transformer(tokens, tokens_unmasked)
        # pass the output through the MLP head to get position logits
        logits = rearrange(self.mlp_head(attended_tokens), 'b n d -> (b n) d')
        
        # the label for each token is its own position index
        labels = repeat(torch.arange(num_patches, device=device), 'n -> (b n)', b=batch)
        # cross entropy loss
        loss = F.cross_entropy(logits, labels)

        return loss
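
A minimal pretraining sketch for mp3.py, wrapping the ViT defined in this same file (sizes are illustrative):

import torch
from vit_pytorch.mp3 import ViT, MP3

v = ViT(
    num_classes = 1000,
    image_size = 256,
    patch_size = 8,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mp3 = MP3(vit = v, masking_ratio = 0.75)

images = torch.randn(8, 3, 256, 256)
loss = mp3(images)   # cross-entropy loss on the predicted patch positions
loss.backward()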

.\lucidrains\vit-pytorch\vit_pytorch\mpp.py

# import the math library
import math

# import PyTorch
import torch
from torch import nn
import torch.nn.functional as F

# import functions from einops
from einops import rearrange, repeat, reduce

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# boolean mask where each position is kept with the given probability
def prob_mask_like(t, prob):
    batch, seq_length, _ = t.shape
    return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob

# sample a fixed-size subset of positions to mask, according to the given probability
def get_mask_subset_with_prob(patched_input, prob):
    batch, seq_len, _, device = *patched_input.shape, patched_input.device
    max_masked = math.ceil(prob * seq_len)

    rand = torch.rand((batch, seq_len), device=device)
    _, sampled_indices = rand.topk(max_masked, dim=-1)

    new_mask = torch.zeros((batch, seq_len), device=device)
    new_mask.scatter_(1, sampled_indices, 1)
    return new_mask.bool()


# MPP loss

class MPPLoss(nn.Module):
    def __init__(
        self,
        patch_size,
        channels,
        output_channel_bits,
        max_pixel_val,
        mean,
        std
    ):
        super().__init__()
        self.patch_size = patch_size
        self.channels = channels
        self.output_channel_bits = output_channel_bits
        self.max_pixel_val = max_pixel_val

        self.mean = torch.tensor(mean).view(-1, 1, 1) if mean else None
        self.std = torch.tensor(std).view(-1, 1, 1) if std else None

    def forward(self, predicted_patches, target, mask):
        p, c, mpv, bits, device = self.patch_size, self.channels, self.max_pixel_val, self.output_channel_bits, target.device
        bin_size = mpv / (2 ** bits)

        # un-normalize the target
        if exists(self.mean) and exists(self.std):
            target = target * self.std + self.mean

        # reshape the target into patches
        target = target.clamp(max=mpv)  # clamp for safety
        avg_target = reduce(target, 'b c (h p1) (w p2) -> b (h w) c', 'mean', p1=p, p2=p).contiguous()

        channel_bins = torch.arange(bin_size, mpv, bin_size, device=device)
        discretized_target = torch.bucketize(avg_target, channel_bins)

        bin_mask = (2 ** bits) ** torch.arange(0, c, device=device).long()
        bin_mask = rearrange(bin_mask, 'c -> () () c')

        target_label = torch.sum(bin_mask * discretized_target, dim=-1)

        loss = F.cross_entropy(predicted_patches[mask], target_label[mask])
        return loss


# main class

class MPP(nn.Module):
    def __init__(
        self,
        transformer,
        patch_size,
        dim,
        output_channel_bits=3,
        channels=3,
        max_pixel_val=1.0,
        mask_prob=0.15,
        replace_prob=0.5,
        random_patch_prob=0.5,
        mean=None,
        std=None
    ):
        super().__init__()
        self.transformer = transformer
        self.loss = MPPLoss(patch_size, channels, output_channel_bits,
                            max_pixel_val, mean, std)

        # patch-to-embedding layers, reused from the transformer
        self.patch_to_emb = nn.Sequential(transformer.to_patch_embedding[1:])

        # output projection to the discretized pixel bins
        self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))

        # ViT related dimensions
        self.patch_size = patch_size

        # MPP related probabilities
        self.mask_prob = mask_prob
        self.replace_prob = replace_prob
        self.random_patch_prob = random_patch_prob

        # learnable mask token
        self.mask_token = nn.Parameter(torch.randn(1, 1, channels * patch_size ** 2))
    # forward pass, taking the input and any extra keyword arguments
    def forward(self, input, **kwargs):
        # the wrapped transformer
        transformer = self.transformer
        # clone the original image for the loss computation
        img = input.clone().detach()

        # reshape the raw image into patches
        p = self.patch_size
        input = rearrange(input,
                          'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                          p1=p,
                          p2=p)

        # sample the subset of patches to mask
        mask = get_mask_subset_with_prob(input, self.mask_prob)

        # replace masked patches with probability replace_prob (with probability 1 - replace_prob the patch is left unchanged)
        masked_input = input.clone().detach()

        # if the random patch probability > 0, swap in random patches for mpp
        if self.random_patch_prob > 0:
            random_patch_sampling_prob = self.random_patch_prob / (
                1 - self.replace_prob)
            random_patch_prob = prob_mask_like(input,
                                               random_patch_sampling_prob).to(mask.device)

            bool_random_patch_prob = mask * (random_patch_prob == True)
            random_patches = torch.randint(0,
                                           input.shape[1],
                                           (input.shape[0], input.shape[1]),
                                           device=input.device)
            randomized_input = masked_input[
                torch.arange(masked_input.shape[0]).unsqueeze(-1),
                random_patches]
            masked_input[bool_random_patch_prob] = randomized_input[
                bool_random_patch_prob]

        # [mask] the input
        replace_prob = prob_mask_like(input, self.replace_prob).to(mask.device)
        bool_mask_replace = (mask * replace_prob) == True
        masked_input[bool_mask_replace] = self.mask_token

        # linear embedding of the patches
        masked_input = self.patch_to_emb(masked_input)

        # prepend the cls token to the input sequence
        b, n, _ = masked_input.shape
        cls_tokens = repeat(transformer.cls_token, '() n d -> b n d', b=b)
        masked_input = torch.cat((cls_tokens, masked_input), dim=1)

        # add the positional embedding to the input
        masked_input += transformer.pos_embedding[:, :(n + 1)]
        masked_input = transformer.dropout(masked_input)

        # run the transformer and compute the mpp loss
        masked_input = transformer.transformer(masked_input, **kwargs)
        cls_logits = self.to_bits(masked_input)
        logits = cls_logits[:, 1:, :]

        mpp_loss = self.loss(logits, img, mask)

        return mpp_loss
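
A usage sketch for mpp.py, pairing it with the standard ViT from this package (the probabilities and sizes are illustrative):

import torch
from vit_pytorch import ViT
from vit_pytorch.mpp import MPP

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

mpp_trainer = MPP(
    transformer = model,
    patch_size = 32,
    dim = 1024,
    mask_prob = 0.15,          # probability a patch takes part in the masked prediction task
    random_patch_prob = 0.30,  # probability of swapping in a random patch
    replace_prob = 0.50        # probability of replacing a selected patch with the mask token
)

images = torch.randn(8, 3, 256, 256)
loss = mpp_trainer(images)
loss.backward()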

.\lucidrains\vit-pytorch\vit_pytorch\na_vit.py

from functools import partial
from typing import List, Union

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# return a function that always returns the given value
def always(val):
    return lambda *args: val

# cast the input to a pair tuple
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# check whether one number is divisible by another
def divisible_by(numer, denom):
    return (numer % denom) == 0

# auto grouping images

# group images so that the total token count per group stays under the maximum sequence length
def group_images_by_max_seq_len(
    images: List[Tensor],
    patch_size: int,
    calc_token_dropout = None,
    max_seq_len = 2048

) -> List[List[Tensor]]:

    calc_token_dropout = default(calc_token_dropout, always(0.))

    groups = []
    group = []
    seq_len = 0

    if isinstance(calc_token_dropout, (float, int)):
        calc_token_dropout = always(calc_token_dropout)

    for image in images:
        assert isinstance(image, Tensor)

        image_dims = image.shape[-2:]
        ph, pw = map(lambda t: t // patch_size, image_dims)

        image_seq_len = (ph * pw)
        image_seq_len = int(image_seq_len * (1 - calc_token_dropout(*image_dims)))

        assert image_seq_len <= max_seq_len, f'image with dimensions {image_dims} exceeds maximum sequence length'

        if (seq_len + image_seq_len) > max_seq_len:
            groups.append(group)
            group = []
            seq_len = 0

        group.append(image)
        seq_len += image_seq_len

    if len(group) > 0:
        groups.append(group)

    return groups
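
Before moving on, a small sketch of how this grouping helper might be called (the image resolutions below are arbitrary):

import torch
from vit_pytorch.na_vit import group_images_by_max_seq_len

# three variable-resolution images, channels first
images = [
    torch.randn(3, 256, 256),
    torch.randn(3, 128, 64),
    torch.randn(3, 64, 256)
]

# with patch_size = 32 the images contribute 64 + 8 + 16 = 88 tokens,
# well under max_seq_len, so all three end up in a single group
groups = group_images_by_max_seq_len(images, patch_size = 32, max_seq_len = 2048)
assert len(groups) == 1 and len(groups[0]) == 3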

# normalization
# they use layernorm without bias, something that pytorch does not offer

# custom LayerNorm (no bias)
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# they use a query-key normalization that is equivalent to rms norm (no mean-centering, learned gamma), from vit 22B paper

# custom RMSNorm
class RMSNorm(nn.Module):
    def __init__(self, heads, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))

    def forward(self, x):
        normed = F.normalize(x, dim = -1)
        return normed * self.scale * self.gamma

# feedforward

# feedforward network
def FeedForward(dim, hidden_dim, dropout = 0.):
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, dim),
        nn.Dropout(dropout)
    )

# attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.norm = LayerNorm(dim)

        self.q_norm = RMSNorm(heads, dim_head)
        self.k_norm = RMSNorm(heads, dim_head)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_mask = None
        ):
        # normalize the input
        x = self.norm(x)
        # keys and values come from the context if given, otherwise from x itself
        kv_input = default(context, x)

        # project to queries, keys and values
        qkv = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1))

        # rearrange queries, keys and values for multi-head attention
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        # normalize queries and keys (qk rmsnorm)
        q = self.q_norm(q)
        k = self.k_norm(k)

        # dot product of queries and keys
        dots = torch.matmul(q, k.transpose(-1, -2))

        # apply the key padding mask if given
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)

        # apply the attention mask if given
        if exists(attn_mask):
            dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)

        # attention
        attn = self.attend(dots)
        # dropout on the attention weights
        attn = self.dropout(attn)

        # weighted sum of the values
        out = torch.matmul(attn, v)
        # merge heads and project out
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# the Transformer, an nn.Module
class Transformer(nn.Module):
    # initializer, taking the dimension, depth, number of heads, head dimension, MLP dimension and dropout
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        # call the parent constructor
        super().__init__()
        # build the list of layers
        self.layers = nn.ModuleList([])
        # create one attention / feedforward pair per layer of depth
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

        # final LayerNorm
        self.norm = LayerNorm(dim)

    # forward pass
    def forward(
        self,
        x,
        mask = None,
        attn_mask = None
    ):
        for attn, ff in self.layers:
            # attention with a residual connection
            x = attn(x, mask = mask, attn_mask = attn_mask) + x
            # feedforward with a residual connection
            x = ff(x) + x

        # return the LayerNorm'd result
        return self.norm(x)

# the NaViT model, an nn.Module
class NaViT(nn.Module):
    # initializer, taking the image size, patch size, number of classes, dimension, depth, heads, MLP dimension, etc.
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., token_dropout_prob = None):
        # call the parent constructor
        super().__init__()
        image_height, image_width = pair(image_size)

        # what percent of tokens to dropout
        # if int or float given, then assume constant dropout prob
        # otherwise accept a callback that in turn calculates dropout prob from height and width

        self.calc_token_dropout = None

        if callable(token_dropout_prob):
            self.calc_token_dropout = token_dropout_prob

        elif isinstance(token_dropout_prob, (float, int)):
            assert 0. < token_dropout_prob < 1.
            token_dropout_prob = float(token_dropout_prob)
            self.calc_token_dropout = lambda height, width: token_dropout_prob

        # calculate patching related stuff

        assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'

        patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
        patch_dim = channels * (patch_size ** 2)

        self.channels = channels
        self.patch_size = patch_size

        self.to_patch_embedding = nn.Sequential(
            LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            LayerNorm(dim),
        )

        self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
        self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))

        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        # final attention pooling queries

        self.attn_pool_queries = nn.Parameter(torch.randn(dim))
        self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)

        # output to logits

        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            LayerNorm(dim),
            nn.Linear(dim, num_classes, bias = False)
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        batched_images: Union[List[Tensor], List[List[Tensor]]], # assume different resolution images already grouped correctly
        group_images = False,
        group_max_seq_len = 2048

.\lucidrains\vit-pytorch\vit_pytorch\nest.py

# import necessary libraries
from functools import partial
import torch
from torch import nn, einsum

# import rearrange from einops, and the Rearrange and Reduce layers from einops.layers.torch
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce

# helper that casts a value to a tuple of the given depth
def cast_tuple(val, depth):
    return val if isinstance(val, tuple) else ((val,) * depth)

# LayerNorm over the channel dimension
class LayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

# feedforward network (1x1 convolutions)
class FeedForward(nn.Module):
    def __init__(self, dim, mlp_mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            LayerNorm(dim),
            nn.Conv2d(dim, dim * mlp_mult, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(dim * mlp_mult, dim, 1),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dropout = 0.):
        super().__init__()
        dim_head = dim // heads
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        b, c, h, w, heads = *x.shape, self.heads

        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        return self.to_out(out)

# aggregation: conv + norm + max-pool, halving the spatial resolution
def Aggregate(dim, dim_out):
    return nn.Sequential(
        nn.Conv2d(dim, dim_out, 3, padding = 1),
        LayerNorm(dim_out),
        nn.MaxPool2d(3, stride = 2, padding = 1)
    )

# transformer
class Transformer(nn.Module):
    def __init__(self, dim, seq_len, depth, heads, mlp_mult, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.pos_emb = nn.Parameter(torch.randn(seq_len))

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dropout = dropout),
                FeedForward(dim, mlp_mult, dropout = dropout)
            ]))
    def forward(self, x):
        *_, h, w = x.shape

        pos_emb = self.pos_emb[:(h * w)]
        pos_emb = rearrange(pos_emb, '(h w) -> () () h w', h = h, w = w)
        x = x + pos_emb

        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

# the NesT model
class NesT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        heads,
        num_hierarchies,
        block_repeats,
        mlp_mult = 4,
        channels = 3,
        dim_head = 64,
        dropout = 0.
    ):
        # call the parent constructor
        super().__init__()
        # the image size must be divisible by the patch size
        assert (image_size % patch_size) == 0, 'Image dimensions must be divisible by the patch size.'
        # total number of patches
        num_patches = (image_size // patch_size) ** 2
        # dimension of each patch
        patch_dim = channels * patch_size ** 2
        # size of the feature map
        fmap_size = image_size // patch_size
        # number of blocks at the finest hierarchy
        blocks = 2 ** (num_hierarchies - 1)

        # sequence length stays constant across hierarchies
        seq_len = (fmap_size // blocks) ** 2
        # list of hierarchy levels, coarsest last
        hierarchies = list(reversed(range(num_hierarchies)))
        # channel multiplier at each hierarchy
        mults = [2 ** i for i in reversed(hierarchies)]

        # number of heads at each hierarchy
        layer_heads = list(map(lambda t: t * heads, mults))
        # dimension at each hierarchy
        layer_dims = list(map(lambda t: t * dim, mults))
        # final dimension
        last_dim = layer_dims[-1]

        # repeat the last dimension so the final stage has an output pair
        layer_dims = [*layer_dims, layer_dims[-1]]
        # consecutive dimension pairs
        dim_pairs = zip(layer_dims[:-1], layer_dims[1:])

        # layers turning the image into patch embeddings
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = patch_size, p2 = patch_size),
            LayerNorm(patch_dim),
            nn.Conv2d(patch_dim, layer_dims[0], 1),
            LayerNorm(layer_dims[0])
        )

        # cast block_repeats to a tuple, one entry per hierarchy
        block_repeats = cast_tuple(block_repeats, num_hierarchies)

        # list of hierarchy layers
        self.layers = nn.ModuleList([])

        # iterate over the hierarchies, heads, dimension pairs and block repeats
        for level, heads, (dim_in, dim_out), block_repeat in zip(hierarchies, layer_heads, dim_pairs, block_repeats):
            is_last = level == 0
            depth = block_repeat

            # append a Transformer and an Aggregate module (identity at the last level)
            self.layers.append(nn.ModuleList([
                Transformer(dim_in, seq_len, depth, heads, mlp_mult, dropout),
                Aggregate(dim_in, dim_out) if not is_last else nn.Identity()
            ]))


        # MLP head
        self.mlp_head = nn.Sequential(
            LayerNorm(last_dim),
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(last_dim, num_classes)
        )

    def forward(self, img):
        # turn the image into patch embeddings
        x = self.to_patch_embedding(img)
        b, c, h, w = x.shape

        # number of hierarchies
        num_hierarchies = len(self.layers)

        # run through the hierarchies, applying the Transformer and Aggregate modules
        for level, (transformer, aggregate) in zip(reversed(range(num_hierarchies)), self.layers):
            block_size = 2 ** level
            x = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size)
            x = transformer(x)
            x = rearrange(x, '(b b1 b2) c h w -> b c (b1 h) (b2 w)', b1 = block_size, b2 = block_size)
            x = aggregate(x)

        # apply the MLP head and return the logits
        return self.mlp_head(x)
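
Finally, a usage sketch for nest.py (the hyperparameters are illustrative):

import torch
from vit_pytorch.nest import NesT

nest = NesT(
    image_size = 224,
    patch_size = 4,
    dim = 96,
    heads = 3,
    num_hierarchies = 3,        # number of hierarchies
    block_repeats = (2, 2, 8),  # transformer blocks at each hierarchy, from the bottom up
    num_classes = 1000
)

img = torch.randn(1, 3, 224, 224)
pred = nest(img)   # (1, 1000)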