Lucidrains-系列项目源码解析-一百零九-

65 阅读26分钟

Lucidrains 系列项目源码解析(一百零九)

.\lucidrains\vit-pytorch\vit_pytorch\parallel_vit.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn

# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 从 einops.layers.torch 库中导入 Rearrange 类
from einops.layers.torch import Rearrange

# 辅助函数

# Helper: normalize a size argument into a (height, width) style 2-tuple.
def pair(t):
    """Return t unchanged when it already is a tuple, else duplicate it into (t, t)."""
    return (t, t) if not isinstance(t, tuple) else t

# 类定义

# Runs several modules on the same input and sums their outputs.
class Parallel(nn.Module):
    """Sum of parallel branches: forward(x) = fn_1(x) + fn_2(x) + ..."""

    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        # every branch sees the identical input; outputs are added elementwise
        outputs = [fn(x) for fn in self.fns]
        return sum(outputs)

# Pre-norm MLP block: LayerNorm -> expand -> GELU -> dropout -> contract -> dropout.
class FeedForward(nn.Module):
    """Shape-preserving feed-forward network: (..., dim) -> (..., dim)."""

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

# Multi-head self-attention with pre-norm, as used in ViT.
class Attention(nn.Module):
    """Scaled dot-product multi-head self-attention over (b, n, dim) tokens."""

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        # a single head whose width equals `dim` needs no output projection
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        if project_out:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        x = self.norm(x)

        # one joint projection, split into q / k / v with heads separated
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads)
                   for t in self.to_qkv(x).chunk(3, dim = -1))

        # scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.dropout(self.attend(scores))

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# Encoder stack: each layer sums several parallel attention branches and
# several parallel feed-forward branches (residual around each group).
class Transformer(nn.Module):
    """Transformer with `num_parallel_branches` parallel blocks per sub-layer."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_parallel_branches = 2, dropout = 0.):
        super().__init__()

        def make_attn():
            return Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)

        def make_ff():
            return FeedForward(dim, mlp_dim, dropout = dropout)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Parallel(*(make_attn() for _ in range(num_parallel_branches))),
                Parallel(*(make_ff() for _ in range(num_parallel_branches))),
            ]))

    def forward(self, x):
        # residual connection around the parallel attention and MLP groups
        for parallel_attn, parallel_ff in self.layers:
            x = parallel_attn(x) + x
            x = parallel_ff(x) + x
        return x

# ViT backbone whose transformer uses parallel attention / feed-forward branches.
class ViT(nn.Module):
    # Constructor: builds patch embedding, positional embedding, transformer and head.
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', num_parallel_branches = 2, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        """Vision Transformer with parallel branches per layer.

        Args:
            image_size: int or (height, width) of the input image.
            patch_size: int or (height, width) of each patch.
            num_classes: classifier output size.
            dim: token embedding dimension.
            depth: number of transformer layers.
            heads / dim_head: attention configuration.
            mlp_dim: feed-forward hidden width.
            pool: 'cls' to classify from the CLS token, 'mean' for mean pooling.
            num_parallel_branches: parallel attention/FF branches per layer.
            channels: input image channels.
            dropout / emb_dropout: transformer / post-embedding dropout rates.
        """
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        # the image must tile exactly into patches
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # number of tokens and the flattened size of one patch
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        # only two pooling strategies are supported
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # split the image into patches and project each to `dim`
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        # learned positional embedding (one slot per patch plus the CLS token)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        # transformer encoder with parallel branches
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_parallel_branches, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        # classification head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    # Forward: img (b, channels, h, w) -> logits (b, num_classes)
    def forward(self, img):
        # tokenize the image into patch embeddings
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # prepend one CLS token per batch element
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        # run the transformer encoder
        x = self.transformer(x)

        # pool: mean over all tokens, or take the CLS token
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        # identity hook point, then classify
        x = self.to_latent(x)
        return self.mlp_head(x)

.\lucidrains\vit-pytorch\vit_pytorch\pit.py

# 从 math 模块中导入 sqrt 函数
from math import sqrt

# 导入 torch 模块及相关子模块
import torch
from torch import nn, einsum
import torch.nn.functional as F

# 导入 einops 模块中的 rearrange 和 repeat 函数,以及 torch 子模块中的 Rearrange 类
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# 定义辅助函数

# Helper: broadcast a scalar into a tuple of `num` copies; tuples pass through.
def cast_tuple(val, num):
    """Return val if it is a tuple, otherwise (val,) repeated num times."""
    return (val,) * num if not isinstance(val, tuple) else val

# Helper: spatial output size of a convolution (standard conv arithmetic).
def conv_output_size(image_size, kernel_size, stride, padding = 0):
    """floor((image_size - kernel_size + 2 * padding) / stride) + 1."""
    return (image_size - kernel_size + 2 * padding) // stride + 1

# 定义类

# Pre-norm feed-forward block used between attention layers.
class FeedForward(nn.Module):
    """Two-layer MLP with pre-LayerNorm, GELU and dropout; preserves shape."""

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        norm = nn.LayerNorm(dim)
        expand = nn.Linear(dim, hidden_dim)
        contract = nn.Linear(hidden_dim, dim)
        self.net = nn.Sequential(
            norm,
            expand,
            nn.GELU(),
            nn.Dropout(dropout),
            contract,
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # (..., dim) -> (..., dim)
        return self.net(x)

# Multi-head self-attention with pre-norm (einsum implementation).
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        """Args:
            dim: token dimension.
            heads / dim_head: number of heads and per-head width.
            dropout: attention and output dropout rate.
        """
        super().__init__()
        inner_dim = dim_head *  heads
        # a single head of width `dim` needs no output projection
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        # 1 / sqrt(d) attention scaling
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # x: (batch, tokens, dim)
        b, n, _, h = *x.shape, self.heads

        x = self.norm(x)
        # joint projection split into query / key / value per head
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        # scaled dot-product attention scores
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        # merge heads back into the feature dimension
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# Standard pre-norm transformer encoder: depth x (attention + MLP), residual.
class Transformer(nn.Module):
    """Stack of residual attention / feed-forward layers."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        # residual around each sub-layer
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

# Depthwise-separable convolution used by the pooling layer.
class DepthWiseConv2d(nn.Module):
    """Grouped (groups = dim_in) spatial conv mapping dim_in -> dim_out, then a
    1x1 pointwise conv over dim_out. Requires dim_out to be a multiple of dim_in."""

    def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
        super().__init__()
        depthwise = nn.Conv2d(dim_in, dim_out, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias)
        pointwise = nn.Conv2d(dim_out, dim_out, kernel_size = 1, bias = bias)
        self.net = nn.Sequential(depthwise, pointwise)

    def forward(self, x):
        return self.net(x)

# Pooling stage: halves the spatial token grid with a strided depthwise conv
# and doubles the channel dim; the CLS token gets its own linear projection.
class Pool(nn.Module):
    """(b, 1 + n, dim) -> (b, 1 + n/4, 2 * dim); n must be a perfect square."""

    def __init__(self, dim):
        super().__init__()
        self.downsample = DepthWiseConv2d(dim, dim * 2, kernel_size = 3, stride = 2, padding = 1)
        self.cls_ff = nn.Linear(dim, dim * 2)

    def forward(self, x):
        # first token is CLS; the rest form a square feature map
        cls_token = self.cls_ff(x[:, :1])

        tokens = x[:, 1:]
        side = int(sqrt(tokens.shape[1]))
        tokens = rearrange(tokens, 'b (h w) c -> b c h w', h = side)
        tokens = self.downsample(tokens)
        tokens = rearrange(tokens, 'b c h w -> b (h w) c')

        return torch.cat((cls_token, tokens), dim = 1)

# PiT: Pooling-based Vision Transformer.
class PiT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.,
        channels = 3
    ):
        """Pooling-based Vision Transformer.

        Args:
            image_size: input image side length (square input assumed).
            patch_size: patch side length; patches overlap (stride = patch_size // 2).
            num_classes: classifier output size.
            dim: initial token dimension (doubled after every Pool stage).
            depth: tuple — transformer blocks per stage, one entry per stage.
            heads: int or tuple — attention heads per stage.
            mlp_dim: feed-forward hidden width.
            dim_head: per-head width.
            dropout / emb_dropout: transformer / embedding dropout rates.
            channels: input image channels.
        """
        super().__init__()
        # image must tile into patches
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        # one depth entry per stage (blocks before each downsizing)
        assert isinstance(depth, tuple), 'depth must be a tuple of integers, specifying the number of blocks before each downsizing'
        # broadcast a scalar head count across all stages
        heads = cast_tuple(heads, len(depth))

        # flattened size of one (channels x patch x patch) patch
        patch_dim = channels * patch_size ** 2

        # overlapping patch extraction (Unfold with stride patch_size // 2),
        # then a linear projection to `dim`
        self.to_patch_embedding = nn.Sequential(
            nn.Unfold(kernel_size = patch_size, stride = patch_size // 2),
            Rearrange('b c n -> b n c'),
            nn.Linear(patch_dim, dim)
        )

        # side length of the unfolded token grid and total token count
        output_size = conv_output_size(image_size, patch_size, patch_size // 2)
        num_patches = output_size ** 2

        # learned positional embedding and CLS token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        layers = []

        # alternate transformer stages with Pool layers; dim doubles at each Pool
        for ind, (layer_depth, layer_heads) in enumerate(zip(depth, heads)):
            not_last = ind < (len(depth) - 1)

            layers.append(Transformer(dim, layer_depth, layer_heads, dim_head, mlp_dim, dropout))

            if not_last:
                layers.append(Pool(dim))
                dim *= 2

        self.layers = nn.Sequential(*layers)

        # classification head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    # Forward: img (b, channels, image_size, image_size) -> logits (b, num_classes)
    def forward(self, img):
        # tokenize into overlapping patch embeddings
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # prepend the CLS token and add positional embeddings
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :n+1]
        x = self.dropout(x)

        # transformer stages interleaved with pooling
        x = self.layers(x)

        # classify from the CLS token
        return self.mlp_head(x[:, 0])

.\lucidrains\vit-pytorch\vit_pytorch\recorder.py

# 从 functools 模块导入 wraps 装饰器
from functools import wraps
# 导入 torch 模块
import torch
# 从 torch 模块导入 nn 模块
from torch import nn

# 从 vit_pytorch.vit 模块导入 Attention 类
from vit_pytorch.vit import Attention

# Collect every submodule (including the root) that is an instance of `type`.
# The parameter name shadows the `type` builtin but is kept for API compatibility.
def find_modules(nn_module, type):
    """Return a list of all submodules of nn_module matching the given class."""
    found = []
    for module in nn_module.modules():
        if isinstance(module, type):
            found.append(module)
    return found

# Records the attention maps produced by every Attention module inside a ViT.
class Recorder(nn.Module):
    def __init__(self, vit, device = None):
        """Wrap `vit` so that forward() also returns the captured attention maps.

        Args:
            vit: a vit_pytorch ViT instance (must expose `.transformer`).
            device: optional device to move recordings to (defaults to the
                input image's device).
        """
        super().__init__()
        self.vit = vit

        self.data = None
        self.recordings = []
        self.hooks = []
        self.hook_registered = False
        self.ejected = False
        self.device = device

    # Forward hook: stores a detached copy of each attention softmax output.
    def _hook(self, _, input, output):
        self.recordings.append(output.clone().detach())

    # Attach a forward hook to the softmax (`attend`) of every Attention module.
    def _register_hook(self):
        modules = find_modules(self.vit.transformer, Attention)
        for module in modules:
            handle = module.attend.register_forward_hook(self._hook)
            self.hooks.append(handle)
        self.hook_registered = True

    # Remove all hooks and hand back the wrapped ViT; the recorder becomes unusable.
    def eject(self):
        self.ejected = True
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
        return self.vit

    # Drop any recordings collected so far.
    def clear(self):
        self.recordings.clear()

    # Manually append an attention tensor (forward() itself records via hooks).
    def record(self, attn):
        recording = attn.clone().detach()
        self.recordings.append(recording)

    def forward(self, img):
        """Run the wrapped ViT on `img`.

        Returns:
            (pred, attns) — attns stacks one recording per Attention module
            along dim 1, or is None when nothing was captured.
        """
        assert not self.ejected, 'recorder has been ejected, cannot be used anymore'
        self.clear()
        # lazily register hooks on first use
        if not self.hook_registered:
            self._register_hook()

        pred = self.vit(img)

        # move recordings to the requested device (default: the input's device)
        target_device = self.device if self.device is not None else img.device
        recordings = tuple(map(lambda t: t.to(target_device), self.recordings))

        # stack per-layer attention maps along dim 1 when any were captured
        attns = torch.stack(recordings, dim = 1) if len(recordings) > 0 else None
        return pred, attns

.\lucidrains\vit-pytorch\vit_pytorch\regionvit.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch 库中导入 nn.functional 模块,并重命名为 F
# (this import was missing although the comment announced it; F.pad is used below)
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 函数和 Rearrange、Reduce 类
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce

# 辅助函数

# Helper: a value "exists" unless it is None.
def exists(val):
    """True for any value other than None."""
    if val is None:
        return False
    return True

# Helper: fall back to d when val is None (inlines the `exists` check).
def default(val, d):
    """Return val when it is not None, otherwise the fallback d."""
    return d if val is None else val

# Helper: promote a scalar to a tuple of `length` copies; tuples pass through.
def cast_tuple(val, length = 1):
    """Return val unchanged if it is a tuple, else (val,) * length."""
    if isinstance(val, tuple):
        return val
    return (val,) * length

# Helper: whether d evenly divides val.
def divisible_by(val, d):
    """True when val is an exact multiple of d."""
    return val % d == 0

# 辅助类

# Spatial downsampling: a 3x3 stride-2 convolution (halves height and width).
class Downsample(nn.Module):
    """(b, dim_in, h, w) -> (b, dim_out, ceil(h/2), ceil(w/2))."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.conv = nn.Conv2d(dim_in, dim_out, 3, stride = 2, padding = 1)

    def forward(self, x):
        downsampled = self.conv(x)
        return downsampled

# PEG: Positional Encoding Generator — a residual depthwise convolution that
# injects a position-dependent signal into the feature map.
class PEG(nn.Module):
    """Shape-preserving: (b, dim, h, w) -> (b, dim, h, w)."""

    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)

    def forward(self, x):
        return x + self.proj(x)

# transformer 类

# Feed-forward (MLP) block used after attention.
# Fix: the original passed a stray positional `1` to each nn.Linear, which lands
# in the `bias` parameter (truthy, so behavior was unchanged). It looked like a
# leftover kernel-size argument from a Conv2d port and has been removed.
def FeedForward(dim, mult = 4, dropout = 0.):
    """Pre-norm MLP: LayerNorm -> Linear(dim, dim * mult) -> GELU -> Dropout -> Linear back to dim."""
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, dim * mult),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim * mult, dim)
    )

# Multi-head self-attention (pre-norm) with an optional additive relative
# position bias on the attention logits.
class Attention(nn.Module):
    """Self-attention over (b, n, dim) tokens; rel_pos_bias is added to the logits."""

    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 32,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = heads * dim_head

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, rel_pos_bias = None):
        # pre-norm
        normed = self.norm(x)

        # joint qkv projection, split per head: (b, n, h*d) -> (b, h, n, d)
        q, k, v = self.to_qkv(normed).chunk(3, dim = -1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in (q, k, v))

        # scale queries, then compute pairwise similarities
        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # optional relative positional bias (used for windowed local tokens)
        if exists(rel_pos_bias):
            sim = sim + rel_pos_bias

        attn = self.dropout(sim.softmax(dim = -1))

        # weighted sum of values, then merge heads
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        return self.to_out(rearrange(out, 'b h n d -> b n (h d)'))

# R2LTransformer: region-to-local transformer stack.
class R2LTransformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        window_size,
        depth = 4,
        heads = 4,
        dim_head = 32,
        attn_dropout = 0.,
        ff_dropout = 0.,
    ):
        """Alternates region self-attention with windowed local attention where
        each window additionally attends to its own region token.

        Args:
            dim: token dimension shared by local and region tokens.
            window_size: local tokens per region token, per axis.
            depth: number of (attention, feed-forward) layers.
            heads / dim_head: attention configuration.
            attn_dropout / ff_dropout: dropout rates.
        """
        super().__init__()
        self.layers = nn.ModuleList([])

        self.window_size = window_size
        # relative offsets per axis range over [-(window_size-1), window_size-1]
        rel_positions = 2 * window_size - 1
        self.local_rel_pos_bias = nn.Embedding(rel_positions ** 2, heads)

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout),
                FeedForward(dim, dropout = ff_dropout)
            ]))
    # Forward: local_tokens (b, c, lh, lw) and region_tokens (b, c, rh, rw) maps.
    def forward(self, local_tokens, region_tokens):
        device = local_tokens.device
        # spatial extents of both token maps
        lh, lw = local_tokens.shape[-2:]
        rh, rw = region_tokens.shape[-2:]
        # how many local tokens fall under each region token, per axis
        window_size_h, window_size_w = lh // rh, lw // rw

        # flatten both maps into token sequences
        local_tokens = rearrange(local_tokens, 'b c h w -> b (h w) c')
        region_tokens = rearrange(region_tokens, 'b c h w -> b (h w) c')

        # build the (1, heads, i, j) relative position bias for one local window
        h_range = torch.arange(window_size_h, device = device)
        w_range = torch.arange(window_size_w, device = device)
        grid_x, grid_y = torch.meshgrid(h_range, w_range, indexing = 'ij')
        grid = torch.stack((grid_x, grid_y))
        grid = rearrange(grid, 'c h w -> c (h w)')
        # pairwise offsets, shifted to be non-negative embedding indices
        grid = (grid[:, :, None] - grid[:, None, :]) + (self.window_size - 1)
        bias_indices = (grid * torch.tensor([1, self.window_size * 2 - 1], device = device)[:, None, None]).sum(dim = 0)
        rel_pos_bias = self.local_rel_pos_bias(bias_indices)
        rel_pos_bias = rearrange(rel_pos_bias, 'i j h -> () h i j')
        # pad a zero row/column at index 0 for the prepended region token
        # NOTE(review): F must be torch.nn.functional — confirm the module imports it.
        rel_pos_bias = F.pad(rel_pos_bias, (1, 0, 1, 0), value = 0)

        # region-to-local transformer layers
        for attn, ff in self.layers:
            # region tokens attend to each other first
            region_tokens = attn(region_tokens) + region_tokens

            # carve local tokens into windows; one region token joins each window
            local_tokens = rearrange(local_tokens, 'b (h w) d -> b h w d', h = lh)
            local_tokens = rearrange(local_tokens, 'b (h p1) (w p2) d -> (b h w) (p1 p2) d', p1 = window_size_h, p2 = window_size_w)
            region_tokens = rearrange(region_tokens, 'b n d -> (b n) () d')

            # windowed attention over [region token, local tokens] with relative bias
            region_and_local_tokens = torch.cat((region_tokens, local_tokens), dim = 1)
            region_and_local_tokens = attn(region_and_local_tokens, rel_pos_bias = rel_pos_bias) + region_and_local_tokens

            # shared feed-forward over the concatenated tokens
            region_and_local_tokens = ff(region_and_local_tokens) + region_and_local_tokens

            # split the region token (position 0) back off and restore layouts
            region_tokens, local_tokens = region_and_local_tokens[:, :1], region_and_local_tokens[:, 1:]
            local_tokens = rearrange(local_tokens, '(b h w) (p1 p2) d -> b (h p1 w p2) d', h = lh // window_size_h, w = lw // window_size_w, p1 = window_size_h)
            region_tokens = rearrange(region_tokens, '(b n) () d -> b n d', n = rh * rw)

        # back to (b, c, h, w) feature maps
        local_tokens = rearrange(local_tokens, 'b (h w) c -> b c h w', h = lh, w = lw)
        region_tokens = rearrange(region_tokens, 'b (h w) c -> b c h w', h = rh, w = rw)
        # both updated token maps are returned
        return local_tokens, region_tokens
# RegionViT: vision transformer with regional-to-local attention.
class RegionViT(nn.Module):
    # Constructor
    def __init__(
        self,
        *,
        dim = (64, 128, 256, 512),  # per-stage token dimensions
        depth = (2, 2, 8, 2),  # transformer layers per stage
        window_size = 7,  # local tokens per region token, per axis
        num_classes = 1000,  # classifier output size
        tokenize_local_3_conv = False,  # use a deeper 3-conv stem for local tokens
        local_patch_size = 4,  # spatial downsampling of the local token stem
        use_peg = False,  # insert a PEG positional conv in stages 2-4
        attn_dropout = 0.,  # attention dropout
        ff_dropout = 0.,  # feed-forward dropout
        channels = 3,  # input image channels
    ):
        super().__init__()
        dim = cast_tuple(dim, 4)  # broadcast scalar dim across 4 stages
        depth = cast_tuple(depth, 4)  # broadcast scalar depth across 4 stages
        assert len(dim) == 4, 'dim needs to be a single value or a tuple of length 4'
        assert len(depth) == 4, 'depth needs to be a single value or a tuple of length 4'

        self.local_patch_size = local_patch_size

        # each region token summarizes a window_size x window_size grid of local tokens
        region_patch_size = local_patch_size * window_size
        self.region_patch_size = local_patch_size * window_size

        init_dim, *_, last_dim = dim  # first and last stage dims

        # local and region token encoders
        # NOTE(review): input channels are hard-coded to 3 below instead of `channels`,
        # and nn.LayerNorm applied to (b, c, h, w) maps normalizes the trailing
        # (width) dimension — confirm both are intended.

        if tokenize_local_3_conv:
            self.local_encoder = nn.Sequential(
                nn.Conv2d(3, init_dim, 3, 2, 1),
                nn.LayerNorm(init_dim),
                nn.GELU(),
                nn.Conv2d(init_dim, init_dim, 3, 2, 1),
                nn.LayerNorm(init_dim),
                nn.GELU(),
                nn.Conv2d(init_dim, init_dim, 3, 1, 1)
            )
        else:
            # single 8x8 conv with stride 4: downsamples by local_patch_size = 4
            self.local_encoder = nn.Conv2d(3, init_dim, 8, 4, 3)

        # region tokens: non-overlapping region patches -> 1x1 conv projection
        self.region_encoder = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = region_patch_size, p2 = region_patch_size),
            nn.Conv2d((region_patch_size ** 2) * channels, init_dim, 1)
        )

        # stages

        current_dim = init_dim
        self.layers = nn.ModuleList([])  # one entry per stage

        # NOTE: `dim` is rebound to the per-stage value inside this loop
        for ind, dim, num_layers in zip(range(4), dim, depth):
            not_first = ind != 0
            need_downsample = not_first  # stages 2-4 halve the resolution
            need_peg = not_first and use_peg  # optional positional conv

            self.layers.append(nn.ModuleList([
                Downsample(current_dim, dim) if need_downsample else nn.Identity(),
                PEG(dim) if need_peg else nn.Identity(),
                R2LTransformer(dim, depth = num_layers, window_size = window_size, attn_dropout = attn_dropout, ff_dropout = ff_dropout)
            ]))

            current_dim = dim  # next stage's input dim

        # classification head: global average pool over the region token map
        self.to_logits = nn.Sequential(
            Reduce('b c h w -> b c', 'mean'),
            nn.LayerNorm(last_dim),
            nn.Linear(last_dim, num_classes)
        )

    # Forward: x (b, channels, h, w) -> logits (b, num_classes)
    def forward(self, x):
        *_, h, w = x.shape
        # both token grids must tile the input exactly
        assert divisible_by(h, self.region_patch_size) and divisible_by(w, self.region_patch_size), 'height and width must be divisible by region patch size'
        assert divisible_by(h, self.local_patch_size) and divisible_by(w, self.local_patch_size), 'height and width must be divisible by local patch size'

        local_tokens = self.local_encoder(x)  # (b, init_dim, h/4, w/4)
        region_tokens = self.region_encoder(x)  # (b, init_dim, h/rps, w/rps)

        for down, peg, transformer in self.layers:
            # downsample both token maps (identity in the first stage)
            local_tokens, region_tokens = down(local_tokens), down(region_tokens)
            local_tokens = peg(local_tokens)  # optional positional conv
            local_tokens, region_tokens = transformer(local_tokens, region_tokens)

        # classify from the pooled region tokens
        return self.to_logits(region_tokens)

.\lucidrains\vit-pytorch\vit_pytorch\rvt.py

# 从 math 模块中导入 sqrt, pi, log 函数
# 从 torch 模块中导入 nn, einsum, F
# 从 einops 模块中导入 rearrange, repeat
# 从 einops.layers.torch 模块中导入 Rearrange 类
from math import sqrt, pi, log

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# 旋转嵌入

# Rotate feature pairs: every consecutive pair (a, b) becomes (-b, a).
# Together with sin/cos tables this implements the complex-plane rotation
# used by rotary position embeddings.
def rotate_every_two(x):
    """(..., 2d) -> (..., 2d) with each 2-slice rotated by 90 degrees."""
    pairs = rearrange(x, '... (d j) -> ... d j', j = 2)
    a, b = pairs.unbind(dim = -1)
    rotated = torch.stack((-b, a), dim = -1)
    return rearrange(rotated, '... d j -> ... (d j)')

# Axial 2-D rotary positional embedding.
class AxialRotaryEmbedding(nn.Module):
    """Produces sin / cos tables of shape (1, n*n, dim) for an n x n token grid
    (n inferred from the sequence length). The first half of the feature dim
    encodes the x coordinate, the second half the y coordinate, with every
    value duplicated pairwise to match rotate_every_two.

    Fixes a syntax error in the original: the `self.scales[...]` indexing
    expression was missing its closing parenthesis.
    """

    def __init__(self, dim, max_freq = 10):
        super().__init__()
        self.dim = dim
        # dim // 4 frequencies: features split over {x, y} axes and pairwise duplication
        scales = torch.linspace(1., max_freq / 2, self.dim // 4)
        self.register_buffer('scales', scales)

    def forward(self, x):
        # x: (b, n*n, d) — only the sequence length and device/dtype are used
        device, n = x.device, int(sqrt(x.shape[-2]))

        # coordinates in [-1, 1] along one axis: (n, 1)
        seq = torch.linspace(-1., 1., steps = n, device = device)
        seq = seq.unsqueeze(-1)

        # (1, dim // 4), matched to x's dtype/device
        scales = self.scales[None, :].to(x)

        seq = seq * scales * pi  # (n, dim // 4)

        # broadcast to the full grid: x varies along rows, y along columns
        x_sinu = seq[:, None, :].expand(n, n, -1)
        y_sinu = seq[None, :, :].expand(n, n, -1)

        sin = torch.cat((x_sinu.sin(), y_sinu.sin()), dim = -1)  # (n, n, dim // 2)
        cos = torch.cat((x_sinu.cos(), y_sinu.cos()), dim = -1)

        # flatten the grid and duplicate each feature pairwise:
        # (n, n, dim // 2) -> (1, n*n, dim)
        sin = torch.repeat_interleave(sin.reshape(n * n, -1), 2, dim = -1).unsqueeze(0)
        cos = torch.repeat_interleave(cos.reshape(n * n, -1), 2, dim = -1).unsqueeze(0)
        return sin, cos

# Depthwise-separable convolution: per-channel spatial conv, then a 1x1
# pointwise conv mapping dim_in -> dim_out.
class DepthWiseConv2d(nn.Module):
    """(b, dim_in, h, w) -> (b, dim_out, h', w') with h'/w' set by stride/padding."""

    def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
        super().__init__()
        spatial = nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias)
        pointwise = nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        self.net = nn.Sequential(spatial, pointwise)

    def forward(self, x):
        return self.net(x)

# 辅助类

# Applies a depthwise-separable conv to the spatial tokens (reshaped back to a
# feature map) while the CLS token goes through its own projection.
class SpatialConv(nn.Module):
    """Conv over the token grid; CLS token bypasses the conv.

    Bug fix: the `bias` argument was previously ignored (the inner conv was
    hard-coded to bias = False). It is now forwarded; the default is still
    False, so existing callers are unaffected.
    """

    def __init__(self, dim_in, dim_out, kernel, bias = False):
        super().__init__()
        self.conv = DepthWiseConv2d(dim_in, dim_out, kernel, padding = kernel // 2, bias = bias)
        # the CLS token is not spatial — map it with a plain linear (or identity)
        self.cls_proj = nn.Linear(dim_in, dim_out) if dim_in != dim_out else nn.Identity()

    def forward(self, x, fmap_dims):
        # x: (b, 1 + h*w, d) with the CLS token first; fmap_dims supplies h and w
        cls_token, x = x[:, :1], x[:, 1:]
        x = rearrange(x, 'b (h w) d -> b d h w', **fmap_dims)
        x = self.conv(x)
        x = rearrange(x, 'b d h w -> b (h w) d')
        cls_token = self.cls_proj(cls_token)
        return torch.cat((cls_token, x), dim = 1)

# GEGLU activation: split the last dim in half; the second half gates the
# first half through a GELU.
class GEGLU(nn.Module):
    """(..., 2d) -> (..., d): values * gelu(gates)."""

    def forward(self, x):
        values, gates = x.chunk(2, dim = -1)
        return values * F.gelu(gates)

# Pre-norm MLP; with use_glu the first projection doubles its width and a
# GEGLU gate replaces the plain GELU.
class FeedForward(nn.Module):
    """Shape-preserving feed-forward block: (..., dim) -> (..., dim)."""

    def __init__(self, dim, hidden_dim, dropout = 0., use_glu = True):
        super().__init__()
        first_out = hidden_dim * 2 if use_glu else hidden_dim
        activation = GEGLU() if use_glu else nn.GELU()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, first_out),
            activation,
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# Attention with optional conv-based queries and 2-D rotary position embeddings.
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_rotary = True, use_ds_conv = True, conv_query_kernel = 5):
        """Args:
            dim: token dimension.
            heads / dim_head: attention layout.
            dropout: attention and output dropout rate.
            use_rotary: rotate q/k with axial rotary embeddings (CLS excluded).
            use_ds_conv: compute queries with a depthwise-separable spatial conv.
            conv_query_kernel: kernel size for that query conv.
        """
        super().__init__()
        inner_dim = dim_head *  heads
        self.use_rotary = use_rotary
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.use_ds_conv = use_ds_conv

        # queries: spatial conv over the token feature map, or a plain linear
        self.to_q = SpatialConv(dim, inner_dim, conv_query_kernel, bias = False) if use_ds_conv else nn.Linear(dim, inner_dim, bias = False)

        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )
    # Forward: x (b, 1 + h*w, dim), pos_emb = (sin, cos) tables, fmap_dims = {'h', 'w'}.
    def forward(self, x, pos_emb, fmap_dims):
        # unpack shape info and head count
        b, n, _, h = *x.shape, self.heads

        # the conv-based query path needs the feature-map dims to unflatten tokens
        to_q_kwargs = {'fmap_dims': fmap_dims} if self.use_ds_conv else {}

        # pre-norm
        x = self.norm(x)

        # queries via conv (or linear)
        q = self.to_q(x, **to_q_kwargs)

        # keys and values come from a joint linear projection
        qkv = (q, *self.to_kv(x).chunk(2, dim = -1))

        # fold heads into the batch dim: (b, n, h*d) -> (b*h, n, d)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

        # optionally apply 2-D rotary embeddings to queries and keys
        if self.use_rotary:
            # rotate q and k, excluding the CLS token
            sin, cos = pos_emb
            dim_rotary = sin.shape[-1]

            # split off the CLS token
            (q_cls, q), (k_cls, k) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k))

            # handle the rotary dim being smaller than the head dim
            (q, q_pass), (k, k_pass) = map(lambda t: (t[..., :dim_rotary], t[..., dim_rotary:]), (q, k))
            q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
            q, k = map(lambda t: torch.cat(t, dim = -1), ((q, q_pass), (k, k_pass)))

            # re-attach the CLS token
            q = torch.cat((q_cls, q), dim = 1)
            k = torch.cat((k_cls, k), dim = 1)

        # scaled dot-product attention scores
        dots = einsum('b i d, b j d -> b i j', q, k) * self.scale

        # softmax + dropout
        attn = self.attend(dots)
        attn = self.dropout(attn)

        # weighted sum of values
        out = einsum('b i j, b j d -> b i d', attn, v)
        # unfold heads: (b*h, n, d) -> (b, n, h*d)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        # final projection
        return self.to_out(out)
# Transformer encoder whose attention layers consume rotary position embeddings.
class Transformer(nn.Module):
    # Constructor
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, image_size, dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
        super().__init__()
        self.layers = nn.ModuleList([])
        # axial rotary embedding shared by every layer (max_freq tied to image size)
        self.pos_emb = AxialRotaryEmbedding(dim_head, max_freq = image_size)
        # build `depth` (attention, feed-forward) layer pairs
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv),
                FeedForward(dim, mlp_dim, dropout = dropout, use_glu = use_glu)
            ]))
    # Forward: x (b, 1 + h*w, dim); fmap_dims carries the token grid's h and w.
    def forward(self, x, fmap_dims):
        # rotary tables computed from the spatial tokens only (CLS excluded)
        pos_emb = self.pos_emb(x[:, 1:])
        # residual attention + feed-forward layers
        for attn, ff in self.layers:
            x = attn(x, pos_emb = pos_emb, fmap_dims = fmap_dims) + x
            x = ff(x) + x
        return x

# RvT: Vision Transformer with rotary position embeddings.
class RvT(nn.Module):
    # Constructor
    # NOTE(review): emb_dropout is accepted but never used in this class.
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
        super().__init__()
        # the image must tile exactly into patches
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        # token count and flattened patch size
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2

        # patch embedding: non-overlapping patches, flattened and projected to dim
        self.patch_size = patch_size
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.Linear(patch_dim, dim),
        )

        # CLS token (positions come from rotary embeddings, not a learned pos-emb)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # transformer with rotary-aware attention
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, image_size, dropout, use_rotary, use_ds_conv, use_glu)

        # classification head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    # Forward: img (b, channels, h, w) -> logits (b, num_classes)
    def forward(self, img):
        # unpack shape info and the patch size
        b, _, h, w, p = *img.shape, self.patch_size

        # tokenize into patch embeddings
        x = self.to_patch_embedding(img)
        n = x.shape[1]

        # prepend the CLS token
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)

        # token-grid dimensions, needed by the conv query path
        fmap_dims = {'h': h // p, 'w': w // p}
        # run the transformer
        x = self.transformer(x, fmap_dims = fmap_dims)

        # classify from the CLS token
        return self.mlp_head(x[:, 0])

.\lucidrains\vit-pytorch\vit_pytorch\scalable_vit.py

# 导入必要的库
from functools import partial
import torch
from torch import nn

# import functions and layers from the einops library
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce

# helper functions

# predicate: is the value present (i.e. not None)?
def exists(val):
    """Return True when `val` is not None."""
    return val is not None

# fall back to `d` when `val` is absent
def default(val, d):
    """Return `val` when it is not None, otherwise the fallback `d`."""
    if val is None:
        return d
    return val

# promote a scalar to a 2-tuple; tuples pass through untouched
def pair(t):
    """Return `t` unchanged if it is a tuple, else `(t, t)`."""
    if isinstance(t, tuple):
        return t
    return (t, t)

# broadcast a scalar into a tuple of the requested length
def cast_tuple(val, length = 1):
    """Pass tuples through unchanged; repeat any other value `length` times."""
    if isinstance(val, tuple):
        return val
    return (val,) * length

# helper classes

# channel-wise layernorm
class ChanLayerNorm(nn.Module):
    """LayerNorm over the channel dimension of a (b, c, h, w) feature map,
    with learnable per-channel scale `g` and shift `b`."""
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        # biased (population) statistics over the channel axis
        mean = x.mean(dim = 1, keepdim = True)
        var = x.var(dim = 1, unbiased = False, keepdim = True)
        normed = (x - mean) / (var + self.eps).sqrt()
        return normed * self.g + self.b

# spatial downsampling
class Downsample(nn.Module):
    """Halve spatial resolution with a stride-2 3x3 conv while changing channel count."""
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.conv = nn.Conv2d(dim_in, dim_out, 3, stride = 2, padding = 1)

    def forward(self, x):
        out = self.conv(x)
        return out

# positional encoding generator
class PEG(nn.Module):
    """Depthwise conv whose output is added residually, acting as an implicit positional encoding."""
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)

    def forward(self, x):
        return x + self.proj(x)

# feedforward network
class FeedForward(nn.Module):
    """Pre-norm 1x1-conv MLP with GELU, operating on (b, c, h, w) feature maps."""
    def __init__(self, dim, expansion_factor = 4, dropout = 0.):
        super().__init__()
        hidden = dim * expansion_factor
        self.net = nn.Sequential(
            ChanLayerNorm(dim),
            nn.Conv2d(dim, hidden, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# attention

# scalable self-attention: k/v are spatially reduced before global attention
class ScalableSelfAttention(nn.Module):
    """Global multi-head self-attention over a (b, c, h, w) feature map.

    Keys and values are produced by a strided conv (kernel == stride ==
    `reduction_factor`), shrinking the k/v token count and hence attention cost.
    """
    def __init__(
        self,
        dim,
        heads = 8,
        dim_key = 32,
        dim_value = 32,
        dropout = 0.,
        reduction_factor = 1
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_key ** -0.5  # 1/sqrt(d_k) softmax temperature
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.norm = ChanLayerNorm(dim)
        self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
        # k/v downsampled by reduction_factor (non-overlapping patches)
        self.to_k = nn.Conv2d(dim, dim_key * heads, reduction_factor, stride = reduction_factor, bias = False)
        self.to_v = nn.Conv2d(dim, dim_value * heads, reduction_factor, stride = reduction_factor, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(dim_value * heads, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # x: (b, c, h, w)
        height, width, heads = *x.shape[-2:], self.heads

        x = self.norm(x)

        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

        # split heads and flatten the spatial grid into a token sequence

        q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))

        # similarity (q has h*w tokens; k/v carry the reduced token count)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # attention weights

        attn = self.attend(dots)
        attn = self.dropout(attn)

        # aggregate values

        out = torch.matmul(attn, v)

        # merge heads; out has q's sequence length, so it folds back to (height, width)

        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = height, y = width)
        return self.to_out(out)

# interactive windowed self-attention
class InteractiveWindowedSelfAttention(nn.Module):
    """Window-partitioned multi-head self-attention over a (b, c, h, w) map, plus a
    3x3 conv "local interactive module" applied to the full (un-windowed) value map,
    whose output is added back to the windowed attention output.
    """
    def __init__(
        self,
        dim,
        window_size,
        heads = 8,
        dim_key = 32,
        dim_value = 32,
        dropout = 0.
    ):
        super().__init__()
        # head count and 1/sqrt(d_k) softmax temperature
        self.heads = heads
        self.scale = dim_key ** -0.5
        self.window_size = window_size
        # attention softmax and dropout
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # channel layernorm, and the 3x3 conv that lets values interact across window borders
        self.norm = ChanLayerNorm(dim)
        self.local_interactive_module = nn.Conv2d(dim_value * heads, dim_value * heads, 3, padding = 1)

        # 1x1 conv projections to queries, keys and values
        self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
        self.to_k = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
        self.to_v = nn.Conv2d(dim, dim_value * heads, 1, bias = False)

        # output projection back to `dim` channels, plus dropout
        self.to_out = nn.Sequential(
            nn.Conv2d(dim_value * heads, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # feature-map height/width, head count and configured window size
        height, width, heads, wsz = *x.shape[-2:], self.heads, self.window_size

        x = self.norm(x)

        # window_size of None degrades to a single full-map window (see `default` above)
        wsz_h, wsz_w = default(wsz, height), default(wsz, width)
        assert (height % wsz_h) == 0 and (width % wsz_w) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz_h}, {wsz_w})'

        # project to queries, keys, values
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

        # local interaction path, computed on the full value map before windowing
        local_out = self.local_interactive_module(v)

        # partition q, k, v into windows (and split out heads) for efficient windowed attention
        q, k, v = map(lambda t: rearrange(t, 'b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d', h = heads, w1 = wsz_h, w2 = wsz_w), (q, k, v))

        # similarity
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # attention
        attn = self.attend(dots)
        attn = self.dropout(attn)

        # aggregate values
        out = torch.matmul(attn, v)

        # fold windows back into the full feature map (and merge heads)
        out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz_h, y = width // wsz_w, w1 = wsz_h, w2 = wsz_w)

        # add the local interaction path
        out = out + local_out

        return self.to_out(out)
class Transformer(nn.Module):
    """One Scalable-ViT stage.

    Each of the `depth` blocks applies, with residual connections:
    scalable self-attention -> feedforward -> (PEG, first block only)
    -> interactive windowed self-attention -> feedforward.
    """
    def __init__(
        self,
        dim,
        depth,
        heads = 8,
        ff_expansion_factor = 4,
        dropout = 0.,
        ssa_dim_key = 32,
        ssa_dim_value = 32,
        ssa_reduction_factor = 1,
        iwsa_dim_key = 32,
        iwsa_dim_value = 32,
        iwsa_window_size = None,
        norm_output = True
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for ind in range(depth):
            is_first = ind == 0

            # FIX: modules must be stored in the same order they are unpacked and applied
            # in forward(). Previously the list was [ssa, ff, peg, ff, iwsa] while forward
            # unpacked (ssa, ff1, peg, iwsa, ff2), so the name `iwsa` was bound to a
            # FeedForward and `ff2` to the windowed attention — the windowed attention ran
            # after its feedforward instead of before it.
            self.layers.append(nn.ModuleList([
                ScalableSelfAttention(dim, heads = heads, dim_key = ssa_dim_key, dim_value = ssa_dim_value, reduction_factor = ssa_reduction_factor, dropout = dropout),
                FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout),
                PEG(dim) if is_first else None,
                InteractiveWindowedSelfAttention(dim, heads = heads, dim_key = iwsa_dim_key, dim_value = iwsa_dim_value, window_size = iwsa_window_size, dropout = dropout),
                FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)
            ]))

        # final channel layernorm (identity when this is not the last stage)
        self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        for ssa, ff1, peg, iwsa, ff2 in self.layers:
            x = ssa(x) + x
            x = ff1(x) + x

            # PEG carries its own residual internally, so no `+ x` here
            if exists(peg):
                x = peg(x)

            x = iwsa(x) + x
            x = ff2(x) + x

        return self.norm(x)

class ScalableViT(nn.Module):
    """Scalable vision transformer: conv patch stem, then `len(depth)` stages of
    Transformer blocks, each followed by a 2x spatial downsample (except the last),
    with channel width doubling per stage; classified via global average pooling.
    """
    def __init__(
        self,
        *,
        num_classes,
        dim,
        depth,
        heads,
        reduction_factor,
        window_size = None,
        iwsa_dim_key = 32,
        iwsa_dim_value = 32,
        ssa_dim_key = 32,
        ssa_dim_value = 32,
        ff_expansion_factor = 4,
        channels = 3,
        dropout = 0.
    ):
        super().__init__()
        # patch stem: 7x7 conv, stride 4 (4x spatial reduction)
        self.to_patches = nn.Conv2d(channels, dim, 7, stride = 4, padding = 3)

        # depth must give the number of transformer blocks for every stage
        assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'

        # per-stage channel widths: dim, 2*dim, 4*dim, ...
        num_stages = len(depth)
        dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))

        # per-stage hyperparameters (scalars are broadcast across stages via cast_tuple)
        hyperparams_per_stage = [
            heads,
            ssa_dim_key,
            ssa_dim_value,
            reduction_factor,
            iwsa_dim_key,
            iwsa_dim_value,
            window_size,
        ]

        hyperparams_per_stage = list(map(partial(cast_tuple, length = num_stages), hyperparams_per_stage))
        assert all(tuple(map(lambda arr: len(arr) == num_stages, hyperparams_per_stage)))

        self.layers = nn.ModuleList([])

        # build each stage: transformer blocks, then (except last) a stride-2 downsample
        # that doubles the channel count
        for ind, (layer_dim, layer_depth, layer_heads, layer_ssa_dim_key, layer_ssa_dim_value, layer_ssa_reduction_factor, layer_iwsa_dim_key, layer_iwsa_dim_value, layer_window_size) in enumerate(zip(dims, depth, *hyperparams_per_stage)):
            is_last = ind == (num_stages - 1)

            self.layers.append(nn.ModuleList([
                Transformer(dim = layer_dim, depth = layer_depth, heads = layer_heads, ff_expansion_factor = ff_expansion_factor, dropout = dropout, ssa_dim_key = layer_ssa_dim_key, ssa_dim_value = layer_ssa_dim_value, ssa_reduction_factor = layer_ssa_reduction_factor, iwsa_dim_key = layer_iwsa_dim_key, iwsa_dim_value = layer_iwsa_dim_value, iwsa_window_size = layer_window_size, norm_output = not is_last),
                Downsample(layer_dim, layer_dim * 2) if not is_last else None
            ]))

        # classification head: global average pool over space, layernorm, linear
        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            nn.LayerNorm(dims[-1]),
            nn.Linear(dims[-1], num_classes)
        )

    def forward(self, img):
        # img: (b, channels, h, w)
        x = self.to_patches(img)

        for transformer, downsample in self.layers:
            x = transformer(x)

            # downsample between stages (absent after the final stage)
            if exists(downsample):
                x = downsample(x)

        return self.mlp_head(x)

.\lucidrains\vit-pytorch\vit_pytorch\sep_vit.py

# import the required libraries
from functools import partial

import torch
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce

# helper functions

def cast_tuple(val, length = 1):
    """Pass tuples through untouched; repeat any other value `length` times."""
    return val if isinstance(val, tuple) else (val,) * length

# helper classes

class ChanLayerNorm(nn.Module):
    """Channel-wise layernorm for (b, c, h, w) tensors with learnable scale `g` and bias `b`."""
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        # biased variance + mean over the channel axis, in one call
        var, mean = torch.var_mean(x, dim = 1, unbiased = False, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

class OverlappingPatchEmbed(nn.Module):
    """Downsample with an overlapping conv: kernel = 2*stride - 1, half-kernel padding."""
    def __init__(self, dim_in, dim_out, stride = 2):
        super().__init__()
        kernel = stride * 2 - 1
        self.conv = nn.Conv2d(dim_in, dim_out, kernel, stride = stride, padding = kernel // 2)

    def forward(self, x):
        return self.conv(x)

class PEG(nn.Module):
    """Positional encoding generator: residual depthwise convolution."""
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)

    def forward(self, x):
        # depthwise conv output added back onto the input
        return x + self.proj(x)

# feedforward

class FeedForward(nn.Module):
    """Channel-layernormed 1x1-conv MLP with GELU, for (b, c, h, w) maps."""
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        inner_dim = int(dim * mult)
        blocks = [
            ChanLayerNorm(dim),
            nn.Conv2d(dim, inner_dim, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        return self.net(x)

# attention

class DSSA(nn.Module):
    """Windowed self-attention with one learned token per window (the "depthwise"
    step), followed by attention *across* windows using those window tokens as
    queries/keys to re-aggregate the windowed feature maps (the "pointwise" step).
    """
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 32,
        dropout = 0.,
        window_size = 7
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5  # 1/sqrt(d_k) softmax temperature
        self.window_size = window_size
        inner_dim = dim_head * heads

        self.norm = ChanLayerNorm(dim)

        # softmax + dropout for the within-window attention
        self.attend = nn.Sequential(
            nn.Softmax(dim = -1),
            nn.Dropout(dropout)
        )

        self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)

        # learned window token (one vector of `dim` channels, shared across windows)

        self.window_tokens = nn.Parameter(torch.randn(dim))

        # prenorm and non-linearity for window tokens,
        # then projection of window tokens to queries and keys

        self.window_tokens_to_qk = nn.Sequential(
            nn.LayerNorm(dim_head),
            nn.GELU(),
            Rearrange('b h n c -> b (h c) n'),
            nn.Conv1d(inner_dim, inner_dim * 2, 1),
            Rearrange('b (h c) n -> b h n c', h = heads),
        )

        # softmax + dropout for the cross-window attention

        self.window_attend = nn.Sequential(
            nn.Softmax(dim = -1),
            nn.Dropout(dropout)
        )

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        """
        einstein notation

        b - batch
        c - channels
        w1 - window size (height)
        w2 - also window size (width)
        i - sequence dimension (source)
        j - sequence dimension (target dimension to be reduced)
        h - heads
        x - height of feature map divided by window size
        y - width of feature map divided by window size
        """

        # shapes and window bookkeeping
        batch, height, width, heads, wsz = x.shape[0], *x.shape[-2:], self.heads, self.window_size
        # height and width must tile evenly into windows
        assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'
        num_windows = (height // wsz) * (width // wsz)

        x = self.norm(x)

        # fold in windows for the "depthwise" (within-window) attention
        x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz)

        # prepend a learned window token to each window's sequence
        w = repeat(self.window_tokens, 'c -> b c 1', b = x.shape[0])
        x = torch.cat((w, x), dim = -1)

        # project for queries, keys, values
        q, k, v = self.to_qkv(x).chunk(3, dim = 1)

        # split out heads
        q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))

        # scale queries
        q = q * self.scale

        # similarity
        dots = einsum('b h i d, b h j d -> b h i j', q, k)

        # attention
        attn = self.attend(dots)

        # aggregate values
        out = torch.matmul(attn, v)

        # split off the window token (index 0) from the windowed feature map
        window_tokens, windowed_fmaps = out[:, :, 0], out[:, :, 1:]

        # early return when there is only one window (nothing to attend across)
        if num_windows == 1:
            fmap = rearrange(windowed_fmaps, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
            return self.to_out(fmap)

        # carry out the "pointwise" cross-window attention, the main novelty of the paper
        window_tokens = rearrange(window_tokens, '(b x y) h d -> b h (x y) d', x = height // wsz, y = width // wsz)
        windowed_fmaps = rearrange(windowed_fmaps, '(b x y) h n d -> b h (x y) n d', x = height // wsz, y = width // wsz)

        # window-token queries and keys (after the prenorm + activation above)
        w_q, w_k = self.window_tokens_to_qk(window_tokens).chunk(2, dim = -1)

        # scale
        w_q = w_q * self.scale

        # similarities between windows
        w_dots = einsum('b h i d, b h j d -> b h i j', w_q, w_k)

        w_attn = self.window_attend(w_dots)

        # aggregate the feature maps produced by the "depthwise" attention step
        aggregated_windowed_fmap = einsum('b h i j, b h j w d -> b h i w d', w_attn, windowed_fmaps)

        # fold the windows back and combine heads for the output projection
        fmap = rearrange(aggregated_windowed_fmap, 'b h (x y) (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
        return self.to_out(fmap)
class Transformer(nn.Module):
    """SepViT stage transformer: `depth` residual blocks of (DSSA + FeedForward),
    with an optional final channel layernorm.

    GENERALIZATION: `window_size` was previously unconfigurable — DSSA always fell
    back to its own default of 7. The new keyword (default 7, so behavior is
    unchanged for existing callers) threads it through.
    """
    def __init__(
        self,
        dim,
        depth,
        dim_head = 32,
        heads = 8,
        ff_mult = 4,
        dropout = 0.,
        norm_output = True,
        window_size = 7
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for ind in range(depth):
            # each block: windowed depthwise/pointwise attention followed by a conv feedforward
            self.layers.append(nn.ModuleList([
                DSSA(dim, heads = heads, dim_head = dim_head, dropout = dropout, window_size = window_size),
                FeedForward(dim, mult = ff_mult, dropout = dropout),
            ]))

        # final channel layernorm (identity when this is not the last stage)
        self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

class SepViT(nn.Module):
    """SepViT backbone: stages of (overlapping patch embed -> PEG -> Transformer),
    channel width doubling per stage, classified via global average pooling.
    """
    def __init__(
        self,
        *,
        num_classes,
        dim,
        depth,
        heads,
        window_size = 7,
        dim_head = 32,
        ff_mult = 4,
        channels = 3,
        dropout = 0.
    ):
        super().__init__()
        # depth must give the number of transformer blocks for every stage
        assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'

        num_stages = len(depth)

        # per-stage widths: dim, 2*dim, 4*dim, ..., prefixed with the raw input channels
        dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
        dims = (channels, *dims)
        dim_pairs = tuple(zip(dims[:-1], dims[1:]))

        # 4x spatial reduction in the stem stage, 2x in each later stage
        strides = (4, *((2,) * (num_stages - 1)))

        # broadcast scalar hyperparameters across stages
        hyperparams_per_stage = [heads, window_size]
        hyperparams_per_stage = list(map(partial(cast_tuple, length = num_stages), hyperparams_per_stage))
        assert all(tuple(map(lambda arr: len(arr) == num_stages, hyperparams_per_stage)))

        self.layers = nn.ModuleList([])

        for ind, ((layer_dim_in, layer_dim), layer_depth, layer_stride, layer_heads, layer_window_size) in enumerate(zip(dim_pairs, depth, strides, *hyperparams_per_stage)):
            is_last = ind == (num_stages - 1)

            # FIX: `dim_head` was accepted by this constructor but never forwarded, so the
            # attention silently used DSSA's default head dimension regardless of the value
            # passed in.
            # NOTE(review): `layer_window_size` is still unused here — the stage Transformer
            # in this file does not expose a window_size parameter to thread it to DSSA.
            self.layers.append(nn.ModuleList([
                OverlappingPatchEmbed(layer_dim_in, layer_dim, stride = layer_stride),
                PEG(layer_dim),
                Transformer(dim = layer_dim, depth = layer_depth, heads = layer_heads, dim_head = dim_head, ff_mult = ff_mult, dropout = dropout, norm_output = not is_last),
            ]))

        # classification head: global average pool over space, layernorm, linear
        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            nn.LayerNorm(dims[-1]),
            nn.Linear(dims[-1], num_classes)
        )

    def forward(self, x):
        # x: (b, channels, h, w)
        for ope, peg, transformer in self.layers:
            x = ope(x)
            x = peg(x)
            x = transformer(x)

        return self.mlp_head(x)

.\lucidrains\vit-pytorch\vit_pytorch\simmim.py

import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat

class SimMIM(nn.Module):
    """SimMIM pretraining wrapper.

    Masks a ratio of patch tokens, replaces them with a learned mask token (plus
    position embedding), runs the wrapped encoder's transformer, and regresses the
    raw pixels of the masked patches with an L1 loss.
    """
    def __init__(
        self,
        *,
        encoder,
        masking_ratio = 0.5
    ):
        super().__init__()
        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
        self.masking_ratio = masking_ratio

        # extract some hyperparameters and functions from encoder (vision transformer to be trained)
        # NOTE(review): assumes encoder.to_patch_embedding is laid out as
        # [Rearrange, LayerNorm, Linear, ...] so that index [2] is the Linear whose
        # in-features equal the pixel count per patch — confirm for other encoders

        self.encoder = encoder
        num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]

        self.to_patch = encoder.to_patch_embedding[0]
        self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])

        pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]

        # learned mask token and a simple linear head back to pixel space

        self.mask_token = nn.Parameter(torch.randn(encoder_dim))
        self.to_pixels = nn.Linear(encoder_dim, pixel_values_per_patch)

    def forward(self, img):
        """Return the scalar masked-patch reconstruction loss for a batch of images."""
        device = img.device

        # get patches

        patches = self.to_patch(img)
        batch, num_patches, *_ = patches.shape

        # batch index column, used to gather per-sample masked positions below

        batch_range = torch.arange(batch, device = device)[:, None]

        # get positions, skipping the encoder's cls-token slot at index 0

        pos_emb = self.encoder.pos_embedding[:, 1:(num_patches + 1)]

        # patch to encoder tokens and add positions

        tokens = self.patch_to_emb(patches)
        tokens = tokens + pos_emb

        # prepare mask tokens (same learned vector at every position, plus position embedding)

        mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_patches)
        mask_tokens = mask_tokens + pos_emb

        # calculate how many patches to mask, and pick random positions (indices) per sample

        num_masked = int(self.masking_ratio * num_patches)
        masked_indices = torch.rand(batch, num_patches, device = device).topk(k = num_masked, dim = -1).indices
        masked_bool_mask = torch.zeros((batch, num_patches), device = device).scatter_(-1, masked_indices, 1).bool()

        # swap in the mask token at masked positions

        tokens = torch.where(masked_bool_mask[..., None], mask_tokens, tokens)

        # attend with vision transformer

        encoded = self.encoder.transformer(tokens)

        # get the masked tokens

        encoded_mask_tokens = encoded[batch_range, masked_indices]

        # small linear projection for predicted pixel values

        pred_pixel_values = self.to_pixels(encoded_mask_tokens)

        # get the masked patches for the final reconstruction loss

        masked_patches = patches[batch_range, masked_indices]

        # calculate reconstruction loss
        # NOTE(review): F.l1_loss already averages over all elements; the extra division
        # by num_masked further scales the loss down — verify this weighting is intended

        recon_loss = F.l1_loss(pred_pixel_values, masked_patches) / num_masked
        return recon_loss

.\lucidrains\vit-pytorch\vit_pytorch\simple_flash_attn_vit.py

# import the required libraries
from collections import namedtuple
from packaging import version

import torch
import torch.nn.functional as F
from torch import nn

from einops import rearrange
from einops.layers.torch import Rearrange

# constants

# kwargs bundle for torch.backends.cuda.sdp_kernel, selecting which SDP backends may run
Config = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helpers

# return t as-is when it is already a tuple, otherwise duplicate it into a pair
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# build a fixed 2d sine/cosine positional embedding for a (b, h, w, dim) patch grid,
# returned as an (h * w, dim) tensor on the patches' device and dtype
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
    # note: the `dtype` parameter is shadowed here by patches.dtype
    _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
    # dim // 4 geometrically spaced inverse frequencies
    omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
    omega = 1. / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    return pe.type(dtype)

# main class

class Attend(nn.Module):
    """Dispatch to F.scaled_dot_product_attention (flash) or manual attention.

    Expects q, k, v of shape (batch, heads, seq, dim_head); returns the same shape.
    """
    def __init__(self, use_flash = False):
        super().__init__()
        self.use_flash = use_flash
        assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # pick efficient-attention backend configs for cpu and cuda

        self.cpu_config = Config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not use_flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        # sm80 (A100) uses the flash kernel; other devices fall back to math / mem-efficient
        if device_properties.major == 8 and device_properties.minor == 0:
            self.cuda_config = Config(True, False, False)
        else:
            self.cuda_config = Config(False, True, True)

    def flash_attn(self, q, k, v):
        config = self.cuda_config if q.is_cuda else self.cpu_config

        # Flash Attention - https://arxiv.org/abs/2205.14135

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(q, k, v)

        return out

    def forward(self, q, k, v):
        if self.use_flash:
            return self.flash_attn(q, k, v)

        scale = q.shape[-1] ** -0.5

        # similarity
        # FIX: the transcribed version called an un-imported `einsum` with 3-d patterns
        # ("b j d") against 4-d (b, h, n, d) tensors; use batched matmul instead
        sim = torch.matmul(q, k.transpose(-1, -2)) * scale

        # attention

        attn = sim.softmax(dim = -1)

        # aggregate values

        out = torch.matmul(attn, v)

        return out

# classes

class FeedForward(nn.Module):
    """Pre-norm two-layer MLP with GELU (no dropout in this variant)."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    """Pre-norm multi-head self-attention; the attention math is delegated to Attend."""
    def __init__(self, dim, heads = 8, dim_head = 64, use_flash = True):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = Attend(use_flash = use_flash)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        batch, seq, _ = x.shape

        x = self.norm(x)

        # project and split per-head q, k, v: (b, n, h*d) -> (b, h, n, d)
        q, k, v = (
            t.reshape(batch, seq, self.heads, -1).transpose(1, 2)
            for t in self.to_qkv(x).chunk(3, dim = -1)
        )

        out = self.attend(q, k, v)

        # merge heads back: (b, h, n, d) -> (b, n, h*d)
        out = out.transpose(1, 2).reshape(batch, seq, -1)
        return self.to_out(out)

class Transformer(nn.Module):
    """Depth-stacked pre-norm attention + feedforward blocks with residual connections."""
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, use_flash):
        super().__init__()
        # one (attention, feedforward) pair per layer
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, use_flash = use_flash),
                FeedForward(dim, mlp_dim)
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        for attn, ff in self.layers:
            x = x + attn(x)
            x = x + ff(x)
        return x
# simple ViT with optional flash attention
class SimpleViT(nn.Module):
    """ViT without cls token or learned positions: fixed 2d sincos positional
    embedding, mean pooling over patch tokens, linear classification head.
    """
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, use_flash = True):
        super().__init__()
        # image and patch sizes may be ints or (height, width) tuples
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        # the image must tile exactly into patches
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # patch grid size and flattened pixel count per patch
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width

        # (b, c, H, W) -> (b, h, w, patch_dim) -> (b, h, w, dim)
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, use_flash)

        self.to_latent = nn.Identity()
        # layernorm + linear classifier
        self.linear_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        # h, w and dtype are unpacked for parity with other variants but unused below
        *_, h, w, dtype = *img.shape, img.dtype

        x = self.to_patch_embedding(img)
        # fixed sincos positional embedding, computed from x's (h, w) patch grid
        pe = posemb_sincos_2d(x)
        # flatten the grid into a token sequence and add positions
        x = rearrange(x, 'b ... d -> b (...) d') + pe

        x = self.transformer(x)
        # mean pool over all patch tokens
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        return self.linear_head(x)

.\lucidrains\vit-pytorch\vit_pytorch\simple_vit.py

# import the torch library
import torch
# import the nn module from torch
from torch import nn

# import rearrange and the Rearrange layer from einops
from einops import rearrange
from einops.layers.torch import Rearrange

# helper functions

def pair(t):
    """Return `t` as-is when it is a tuple, otherwise duplicate it into `(t, t)`."""
    return t if isinstance(t, tuple) else (t, t)

# 2d sine/cosine positional embedding
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    """Build an (h * w, dim) fixed sincos positional embedding for an h x w grid."""
    # dim is split into four equal groups: x-sin, x-cos, y-sin, y-cos
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"

    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")

    # dim // 4 geometrically spaced inverse frequencies
    n_bands = dim // 4
    omega = torch.arange(n_bands) / (n_bands - 1)
    omega = 1.0 / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)

# classes

# pre-norm feedforward block
class FeedForward(nn.Module):
    """LayerNorm -> Linear -> GELU -> Linear, applied token-wise."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        modules = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        ]
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x)

# multi-head self-attention
class Attention(nn.Module):
    """Pre-norm scaled dot-product multi-head self-attention."""
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        batch, seq, _ = x.shape

        x = self.norm(x)

        # project, then split per-head q, k, v: (b, n, h*d) -> (b, h, n, d)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = (t.reshape(batch, seq, self.heads, -1).transpose(1, 2) for t in qkv)

        # scaled dot-product similarity
        sim = q @ k.transpose(-1, -2) * self.scale

        attn = self.attend(sim)

        out = attn @ v
        # merge heads back: (b, h, n, d) -> (b, n, h*d)
        out = out.transpose(1, 2).reshape(batch, seq, -1)
        return self.to_out(out)

# transformer encoder: residual pre-norm attention/feedforward blocks, final layernorm
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        for attn, ff in self.layers:
            x = x + attn(x)
            x = x + ff(x)
        return self.norm(x)

# simple ViT: fixed sincos positions, mean pooling, no cls token
class SimpleViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # flattened pixel count per patch
        patch_dim = channels * patch_height * patch_width

        # (b, c, H, W) -> (b, num_patches, patch_dim) -> (b, num_patches, dim)
        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # fixed (non-learned) 2d sincos positional embedding, precomputed once
        self.pos_embedding = posemb_sincos_2d(
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.pool = "mean"
        self.to_latent = nn.Identity()

        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        tokens = self.to_patch_embedding(img)
        # move the precomputed positions onto the activation's device/dtype, then add
        tokens = tokens + self.pos_embedding.to(img.device, dtype = tokens.dtype)

        tokens = self.transformer(tokens)
        # mean pool over all patch tokens
        pooled = tokens.mean(dim = 1)

        return self.linear_head(self.to_latent(pooled))