Lucidrains Series Source Code Walkthrough (Part 43)

.\lucidrains\gradnorm-pytorch\setup.py

# import setuptools' setup and package-discovery helpers
from setuptools import setup, find_packages

# package metadata
setup(
  # package name
  name = 'gradnorm-pytorch',
  # include every package, excluding none
  packages = find_packages(exclude=[]),
  # version number
  version = '0.0.26',
  # license
  license='MIT',
  # short description
  description = 'GradNorm - Pytorch',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # content type of the long description
  long_description_content_type = 'text/markdown',
  # project URL
  url = 'https://github.com/lucidrains/gradnorm-pytorch',
  # keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'loss balancing',
    'gradient normalization'
  ],
  # install-time dependencies
  install_requires=[
    'accelerate',
    'beartype',
    'einops>=0.7.0',
    'torch>=2.0'
  ],
  # trove classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
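
The file above only carries packaging metadata. For context, the loss-balancing idea the package implements is GradNorm: per-task loss weights are tuned so that the gradient norm each task produces on a shared layer tracks a common target. Below is a minimal plain-PyTorch sketch of that step (illustrative only, not the gradnorm-pytorch API; shared_layer, alpha and the toy data are assumptions):

import torch
from torch import nn
import torch.nn.functional as F

num_tasks, alpha = 2, 1.5
shared_layer = nn.Linear(16, 16)                         # shared weights whose gradient norms are balanced
heads = nn.ModuleList([nn.Linear(16, 1) for _ in range(num_tasks)])
loss_weights = nn.Parameter(torch.ones(num_tasks))

x = torch.randn(8, 16)
targets = [torch.randn(8, 1) for _ in range(num_tasks)]

feats = shared_layer(x)
task_losses = torch.stack([F.mse_loss(head(feats), t) for head, t in zip(heads, targets)])
initial_losses = task_losses.detach()                    # in practice, recorded once at the first training step

# per-task gradient norms of the weighted losses w.r.t. the shared weights
grad_norms = torch.stack([
    torch.autograd.grad(w * l, shared_layer.weight, retain_graph = True, create_graph = True)[0].norm()
    for w, l in zip(loss_weights, task_losses)
])

# target norms: mean gradient norm, scaled by each task's relative inverse training rate
loss_ratios = task_losses.detach() / initial_losses
inverse_train_rates = loss_ratios / loss_ratios.mean()
target_norms = (grad_norms.mean() * inverse_train_rates ** alpha).detach()

# this loss is backpropagated into loss_weights only; the weighted task losses train the network,
# and loss_weights are renormalized to sum to num_tasks after every optimizer step
gradnorm_loss = (grad_norms - target_norms).abs().sum()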

.\lucidrains\graph-transformer-pytorch\graph_transformer_pytorch\graph_transformer_pytorch.py

import torch
from torch import nn, einsum

from einops import rearrange, repeat

from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb

# helpers

# returns True when a value is not None
def exists(val):
    return val is not None

# returns the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# alias nn.ModuleList as List for brevity
List = nn.ModuleList

# normalizations

# pre-norm wrapper: LayerNorm applied before the wrapped function
class PreNorm(nn.Module):
    def __init__(
        self,
        dim,
        fn
    ):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, *args, **kwargs):
        x = self.norm(x)
        return self.fn(x, *args,**kwargs)

# gated residual

# plain residual connection
class Residual(nn.Module):
    def forward(self, x, res):
        return x + res

# gated residual connection
class GatedResidual(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(dim * 3, 1, bias = False),
            nn.Sigmoid()
        )

    def forward(self, x, res):
        gate_input = torch.cat((x, res, x - res), dim = -1)
        gate = self.proj(gate_input)
        return x * gate + res * (1 - gate)
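
# note: the gate is a per-node scalar in (0, 1) computed from (x, res, x - res), letting each node
# choose how much of the new representation to accept versus keeping its residual - this is the
# gated residual the README recommends for preventing over-smoothing in deep graph networks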

# attention

# attention layer with edge-conditioned keys and values
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        pos_emb = None,
        dim_head = 64,
        heads = 8,
        edge_dim = None
    ):
        super().__init__()
        edge_dim = default(edge_dim, dim)

        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.pos_emb = pos_emb

        self.to_q = nn.Linear(dim, inner_dim)
        self.to_kv = nn.Linear(dim, inner_dim * 2)
        self.edges_to_kv = nn.Linear(edge_dim, inner_dim)

        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, nodes, edges, mask = None):
        h = self.heads

        q = self.to_q(nodes)
        k, v = self.to_kv(nodes).chunk(2, dim = -1)

        e_kv = self.edges_to_kv(edges)

        q, k, v, e_kv = map(lambda t: rearrange(t, 'b ... (h d) -> (b h) ... d', h = h), (q, k, v, e_kv))

        if exists(self.pos_emb):
            freqs = self.pos_emb(torch.arange(nodes.shape[1], device = nodes.device))
            freqs = rearrange(freqs, 'n d -> () n d')
            q = apply_rotary_emb(freqs, q)
            k = apply_rotary_emb(freqs, k)

        ek, ev = e_kv, e_kv

        k, v = map(lambda t: rearrange(t, 'b j d -> b () j d '), (k, v))
        k = k + ek
        v = v + ev

        sim = einsum('b i d, b i j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b i -> b i ()') & rearrange(mask, 'b j -> b () j')
            mask = repeat(mask, 'b i j -> (b h) i j', h = h)
            max_neg_value = -torch.finfo(sim.dtype).max
            sim.masked_fill_(~mask, max_neg_value)

        attn = sim.softmax(dim = -1)
        out = einsum('b i j, b i j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        return self.to_out(out)

# optional feedforward

# optional feedforward layer
def FeedForward(dim, ff_mult = 4):
    return nn.Sequential(
        nn.Linear(dim, dim * ff_mult),
        nn.GELU(),
        nn.Linear(dim * ff_mult, dim)
    )

# classes

# the graph transformer model
class GraphTransformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        edge_dim = None,
        heads = 8,
        gated_residual = True,
        with_feedforwards = False,
        norm_edges = False,
        rel_pos_emb = False,
        accept_adjacency_matrix = False
    ):
        super().__init__()
        self.layers = List([])
        # edge dimension defaults to the node dimension
        edge_dim = default(edge_dim, dim)
        # optionally layer-normalize the edge features
        self.norm_edges = nn.LayerNorm(edge_dim) if norm_edges else nn.Identity()

        # if an adjacency matrix is accepted, embed its binary entries into edge features
        self.adj_emb = nn.Embedding(2, edge_dim) if accept_adjacency_matrix else None

        # rotary embeddings for relative positions, only used if the nodes are ordered
        pos_emb = RotaryEmbedding(dim_head) if rel_pos_emb else None

        # build the requested number of layers
        for _ in range(depth):
            self.layers.append(List([
                List([
                    # pre-norm attention with a gated residual
                    PreNorm(dim, Attention(dim, pos_emb = pos_emb, edge_dim = edge_dim, dim_head = dim_head, heads = heads)),
                    GatedResidual(dim)
                ]),
                List([
                    # optional pre-norm feedforward with a gated residual
                    PreNorm(dim, FeedForward(dim)),
                    GatedResidual(dim)
                ]) if with_feedforwards else None
            ]))

    # forward pass
    def forward(
        self,
        nodes,
        edges = None,
        adj_mat = None,
        mask = None
    ):
        batch, seq, _ = nodes.shape

        # normalize the edge features if they were provided
        if exists(edges):
            edges = self.norm_edges(edges)

        # embed the adjacency matrix into edge features, if one was passed in
        if exists(adj_mat):
            assert adj_mat.shape == (batch, seq, seq)
            assert exists(self.adj_emb), 'accept_adjacency_matrix must be set to True'
            adj_mat = self.adj_emb(adj_mat.long())

        # combine explicit edge features and adjacency-derived features
        all_edges = default(edges, 0) + default(adj_mat, 0)

        for attn_block, ff_block in self.layers:
            attn, attn_residual = attn_block
            # attention conditioned on the edges, followed by the gated residual
            nodes = attn_residual(attn(nodes, all_edges, mask = mask), nodes)

            # optional feedforward, also wrapped in a gated residual
            if exists(ff_block):
                ff, ff_residual = ff_block
                nodes = ff_residual(ff(nodes), nodes)

        # return the updated node features along with the (normalized) edge features
        return nodes, edges

.\lucidrains\graph-transformer-pytorch\graph_transformer_pytorch\__init__.py

# expose the GraphTransformer class at the package root
from graph_transformer_pytorch.graph_transformer_pytorch import GraphTransformer

Graph Transformer - Pytorch

Implementation of Graph Transformer in Pytorch, for potential use in replicating Alphafold2. This was recently used by both Costa et al. and Baker's lab for transforming MSA and pairwise embeddings into 3d coordinates.

Install

$ pip install graph-transformer-pytorch

Usage

import torch
from graph_transformer_pytorch import GraphTransformer

model = GraphTransformer(
    dim = 256,
    depth = 6,
    edge_dim = 512,             # optional - if left out, the edge dimension is assumed to be the same as the node dimension above
    with_feedforwards = True,   # whether to add a feedforward after each attention layer, suggested by literature to be needed
    gated_residual = True,      # to use the gated residual to prevent over-smoothing
    rel_pos_emb = True          # set to True if the nodes are ordered, defaults to False
)

nodes = torch.randn(1, 128, 256)
edges = torch.randn(1, 128, 128, 512)
mask = torch.ones(1, 128).bool()

nodes, edges = model(nodes, edges, mask = mask)

nodes.shape # (1, 128, 256) - project to R^3 for coordinates

If you want it to handle an adjacency matrix

import torch
from graph_transformer_pytorch import GraphTransformer

model = GraphTransformer(
    dim = 256,
    depth = 6,
    edge_dim = 512,
    with_feedforwards = True,
    gated_residual = True,
    rel_pos_emb = True,
    accept_adjacency_matrix = True  # set this to True
)

nodes = torch.randn(2, 128, 256)
adj_mat = torch.randint(0, 2, (2, 128, 128))
mask = torch.ones(2, 128).bool()

nodes, edges = model(nodes, adj_mat = adj_mat, mask = mask)

nodes.shape # (2, 128, 256) - project to R^3 for coordinates

Citations

@article {Costa2021.06.02.446809,
    author  = {Costa, Allan and Ponnapati, Manvitha and Jacobson, Joseph M. and Chatterjee, Pranam},
    title   = {Distillation of MSA Embeddings to Folded Protein Structures with Graph Transformers},
    year    = {2021},
    doi     = {10.1101/2021.06.02.446809},
    publisher = {Cold Spring Harbor Laboratory},
    URL     = {https://www.biorxiv.org/content/early/2021/06/02/2021.06.02.446809},
    eprint  = {https://www.biorxiv.org/content/early/2021/06/02/2021.06.02.446809.full.pdf},
    journal = {bioRxiv}
}
@article {Baek2021.06.14.448402,
    author  = {Baek, Minkyung and DiMaio, Frank and Anishchenko, Ivan and Dauparas, Justas and Ovchinnikov, Sergey and Lee, Gyu Rie and Wang, Jue and Cong, Qian and Kinch, Lisa N. and Schaeffer, R. Dustin and Mill{\'a}n, Claudia and Park, Hahnbeom and Adams, Carson and Glassman, Caleb R. and DeGiovanni, Andy and Pereira, Jose H. and Rodrigues, Andria V. and van Dijk, Alberdina A. and Ebrecht, Ana C. and Opperman, Diederik J. and Sagmeister, Theo and Buhlheller, Christoph and Pavkov-Keller, Tea and Rathinaswamy, Manoj K and Dalwadi, Udit and Yip, Calvin K and Burke, John E and Garcia, K. Christopher and Grishin, Nick V. and Adams, Paul D. and Read, Randy J. and Baker, David},
    title   = {Accurate prediction of protein structures and interactions using a 3-track network},
    year    = {2021},
    doi     = {10.1101/2021.06.14.448402},
    publisher = {Cold Spring Harbor Laboratory},
    URL     = {https://www.biorxiv.org/content/early/2021/06/15/2021.06.14.448402},
    eprint  = {https://www.biorxiv.org/content/early/2021/06/15/2021.06.14.448402.full.pdf},
    journal = {bioRxiv}
}
@misc{shi2021masked,
    title   = {Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification}, 
    author  = {Yunsheng Shi and Zhengjie Huang and Shikun Feng and Hui Zhong and Wenjin Wang and Yu Sun},
    year    = {2021},
    eprint  = {2009.03509},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

.\lucidrains\graph-transformer-pytorch\setup.py

# import setuptools' setup and package-discovery helpers
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'graph-transformer-pytorch',                  # package name
  packages = find_packages(),                          # discover all packages
  version = '0.1.1',                                   # version number
  license='MIT',                                       # license
  description = 'Graph Transformer - Pytorch',         # short description
  author = 'Phil Wang',                                # author
  author_email = 'lucidrains@gmail.com',               # author email
  url = 'https://github.com/lucidrains/graph-transformer-pytorch',  # project URL
  long_description_content_type = 'text/markdown',     # content type of the long description
  keywords = [                                         # keywords
    'artificial intelligence',
    'deep learning',
    'transformers',
    'graphs'
  ],
  install_requires=[                                   # install-time dependencies
    'einops>=0.3',
    'rotary-embedding-torch',
    'torch>=1.6'
  ],
  classifiers=[                                        # trove classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/

.\lucidrains\h-transformer-1d\h_transformer_1d\autoregressive_wrapper.py

import torch
from torch import nn
import torch.nn.functional as F

# helper functions

# returns True when a value is not None
def exists(val):
    return val is not None

# decorator that switches the model to eval mode for the call, then restores the previous training mode
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# top k filtering

# keep only the top k logits (k determined by thres), setting the rest to -inf
def top_k(logits, thres = 0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
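
# quick illustration: with thres = 0.9 and a vocabulary of 10, k = int((1 - 0.9) * 10) = 1,
# so only the single largest logit keeps its value and every other position becomes -inf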

# autoregressive wrapper, adding sampling and a next-token loss around a base network
class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, ignore_index = -100, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = ignore_index

        self.net = net
        self.max_seq_len = net.max_seq_len

    # sampling method, with configurable prompt, length, eos token, temperature and top-k filtering
    @torch.no_grad()
    @eval_decorator
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        device = start_tokens.device
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        b, t = start_tokens.shape

        out = start_tokens

        for _ in range(seq_len):
            x = out[:, -self.max_seq_len:]

            logits = self.net(x, **kwargs)[:, -1, :]

            filtered_logits = filter_logits_fn(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)

            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)

            if exists(eos_token):
                is_eos_token = (out == eos_token)

                if is_eos_token.any(dim = -1).all():
                    # mask out everything after the eos tokens
                    shifted_is_eos_tokens = F.pad(is_eos_token, (1, -1))
                    mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
                    out = out.masked_fill(mask, self.pad_value)
                    break

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        return out

    # forward pass: next-token cross entropy loss
    def forward(self, x, **kwargs):
        xi = x[:, :-1]
        xo = x[:, 1:]

        out = self.net(xi, **kwargs)
        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        return loss
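
For orientation, here is a minimal usage sketch of this wrapper around the HTransformer1D model defined in the next file (a sketch only; the hyperparameters are illustrative, and max_seq_len must stay divisible by block_size with a power-of-two number of blocks):

import torch
from h_transformer_1d import HTransformer1D
from h_transformer_1d.autoregressive_wrapper import AutoregressiveWrapper

model = AutoregressiveWrapper(HTransformer1D(
    num_tokens = 256,
    dim = 512,
    depth = 4,
    max_seq_len = 1024,   # 1024 / 128 block size = 8 blocks, a power of 2
    causal = True
))

x = torch.randint(0, 256, (1, 1024))
loss = model(x)                         # next-token cross entropy
loss.backward()

prime = torch.randint(0, 256, (1, 8))
sampled = model.generate(prime, 64)     # (1, 64) newly sampled tokens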

.\lucidrains\h-transformer-1d\h_transformer_1d\h_transformer_1d.py

from math import log2, ceil
from functools import wraps

import torch
from torch import nn, einsum, diagonal
import torch.nn.functional as F

from h_transformer_1d.reversible import ReversibleSequence, SequentialSequence

from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding

from einops import rearrange, reduce, repeat

# helpers

# returns True when a value is not None
def exists(val):
    return val is not None

# sum or mean over a dimension, counting only unmasked elements when a mask is given
def masked_aggregate(tensor, mask = None, dim = -1, average = True):
    if not exists(mask):
        fn = torch.sum if not average else torch.mean
        return fn(tensor, dim = dim)

    diff_len = len(tensor.shape) - len(mask.shape)
    mask = mask[(..., *((None,) * diff_len))]
    tensor = tensor.masked_fill(~mask, 0.)

    total_el = mask.sum(dim = dim)
    agg = tensor.sum(dim = dim)

    if average:
        agg = agg / total_el.clamp(min = 1.)

    agg.masked_fill_(total_el == 0, 0.)
    return agg
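
# quick illustration: masked_aggregate(torch.tensor([[1., 2., 3., 4.]]), torch.tensor([[True, True, False, False]]))
# averages only the unmasked elements and returns tensor([1.5])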

# shifts a tensor along the sequence dimension, zeroing masked positions first
def shift(t, amount, mask = None):
    if amount == 0:
        return t

    if exists(mask):
        t = t.masked_fill(~mask[..., None], 0.)

    return F.pad(t, (0, 0, amount, -amount), value = 0.)
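
# quick illustration: shift(t, 1) on a (batch, seq, dim) tensor moves every token one step later,
# zero-padding the front and dropping the final position - e.g. [1, 2, 3, 4] becomes [0, 1, 2, 3]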

# helper classes

# pre-layernorm wrapper
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

# feedforward network
class FeedForward(nn.Module):
    def __init__(
        self,
        dim,
        *,
        mult = 4
    ):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)

# token shifting

# shifts chunks of the feature dimension by different amounts along the sequence before running the wrapped function
class PreShiftTokens(nn.Module):
    def __init__(self, shifts, fn):
        super().__init__()
        self.fn = fn
        self.shifts = tuple(shifts)

    def forward(self, x, **kwargs):
        mask = kwargs.get('mask', None)
        shifts = self.shifts
        segments = len(shifts)
        feats_per_shift = x.shape[-1] // segments
        splitted = x.split(feats_per_shift, dim = -1)
        segments_to_shift, rest = splitted[:segments], splitted[segments:]
        segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
        x = torch.cat((*segments_to_shift, *rest), dim = -1)
        return self.fn(x, **kwargs)

# hierarchical attention helper functions

# casts a tensor to the given dtype for the duration of an op, restoring the original dtype afterwards
def cast_for_op(cast_type, fn):
    @wraps(fn)
    def inner(t, *args, **kwargs):
        orig_type = t.dtype
        t = t.type(cast_type)
        out = fn(t, *args, **kwargs)
        out = out.type(orig_type)
        return out
    return inner

# swaps every adjacent pair of blocks along the sequence dimension
def flip_every_two(t):
    t = rearrange(t, 'b (n r) ... -> b n r ...', r = 2)
    t = torch.flip(t, dims = (2,))                          # so we pay attention to the off-diagonal blocks in the attention matrix
    t = rearrange(t, 'b n r ... -> b (n r) ...')
    return t
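
# quick illustration: flip_every_two on positions [0, 1, 2, 3] returns [1, 0, 3, 2],
# swapping each adjacent pair of blocks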

# attention

# hierarchical 1d attention (non-causal); only the projections appear in this excerpt, the forward pass is not reproduced here
class HAttention1D(nn.Module):
    def __init__(
        self,
        dim,
        *,
        heads = 8,
        dim_head = 64,
        block_size = 16,
        pos_emb = None,
        eps = 1e-8,
        **kwargs
    ):
        super().__init__()
        self.eps = eps
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.block_size = block_size
        inner_dim = heads * dim_head

        self.pos_emb = pos_emb
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

# causal attention

# causal (autoregressive) hierarchical 1d attention
class CausalHAttention1D(nn.Module):
    def __init__(
        self,
        dim,
        *,
        max_seq_len,
        heads = 8,
        dim_head = 64,
        block_size = 16,
        eps = 1e-8,
        pos_emb = None
        ):
        super().__init__()
        # attention hyperparameters
        self.eps = eps
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.block_size = block_size
        inner_dim = heads * dim_head

        # positional embedding
        self.pos_emb = pos_emb

        # project the input to queries, keys and values
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        # project the attention output back to the model dimension
        self.to_out = nn.Linear(inner_dim, dim)

        # derive the causal mask across all hierarchy levels

        # number of coarsening levels
        num_levels = int(log2(max_seq_len // block_size)) - 1
        root_seq = torch.arange(max_seq_len)
        seqs = [root_seq]
        seq = root_seq

        # at each level, coarsen positions by taking the max over adjacent pairs, then expand back to full length
        for ind in range(num_levels):
            seq = rearrange(seq, '(n r) -> n r', r = 2)
            seq = seq.max(dim = -1).values
            expanded_mask_seq = repeat(seq, 'n -> (n r)', r = (2 ** (ind + 1)))
            seqs.append(expanded_mask_seq)

        # stack the per-level key positions
        seq_keys = torch.stack(seqs, dim = 0)
        # mask any key whose (coarsened) position lies in the future of the query position
        mask = seq_keys > rearrange(root_seq, 'n -> () n')
        # register the mask as a non-trainable buffer
        self.register_buffer('mask', mask)
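
        # quick illustration: with max_seq_len = 8 and block_size = 2, num_levels = 1 and the level-1
        # key positions are [1, 1, 3, 3, 5, 5, 7, 7] (max over adjacent pairs, repeated), so the mask
        # removes any coarse key block whose maximum covered position lies in the query's future
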
# main transformer class

class HTransformer1D(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,          # vocabulary size
        dim,                 # model dimension
        depth,               # number of layers
        max_seq_len,         # maximum sequence length
        causal = False,      # autoregressive or not
        heads = 8,           # number of attention heads
        dim_head = 64,       # dimension per head
        ff_mult = 4,         # feedforward expansion factor
        block_size = 128,    # block size, i.e. Nr
        pos_emb = None,      # positional embedding
        reversible = False,  # use reversible layers to save memory
        shift_tokens = False # whether to shift tokens
    ):
        super().__init__()
        assert (max_seq_len % block_size) == 0, 'maximum sequence length must be divisible by the block size'
        num_blocks = max_seq_len // block_size
        assert log2(max_seq_len // block_size).is_integer(), f'number of blocks {num_blocks} must be a power of 2'

        self.token_emb = nn.Embedding(num_tokens, dim)   # token embedding
        self.pos_emb = RotaryEmbedding(dim = dim_head)   # rotary positional embedding
        self.max_seq_len = max_seq_len

        layers = nn.ModuleList([])

        attn_class = CausalHAttention1D if causal else HAttention1D          # pick the attention class based on causality
        attn_kwargs = dict(max_seq_len = max_seq_len) if causal else dict()  # the causal variant needs the maximum sequence length

        shift_token_ranges = (0, 1) if shift_tokens else (-1, 0, 1)  # shift amounts used when token shifting is enabled

        for ind in range(depth):
            attn = attn_class(dim, dim_head = dim_head, heads = heads, block_size = block_size, pos_emb = self.pos_emb, **attn_kwargs)
            ff = FeedForward(dim, mult = ff_mult)

            if shift_tokens:
                # wrap attention and feedforward with token shifting
                attn, ff = map(lambda t: PreShiftTokens(shift_token_ranges, t), (attn, ff))

            # wrap both with pre-layernorm
            attn, ff = map(lambda t: PreNorm(dim, t), (attn, ff))
            layers.append(nn.ModuleList([attn, ff]))

        execute_type = ReversibleSequence if reversible else SequentialSequence  # reversible or plain sequential execution
        route_attn = ((True, False),) * depth  # route the mask to attention (True) but not feedforward (False)
        attn_route_map = {'mask': route_attn}

        self.layers = execute_type(layers, args_route = {**attn_route_map})

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),          # final layer norm
            nn.Linear(dim, num_tokens)  # project to vocabulary logits
        )

    def forward(self, x, mask = None):
        b, n, device = *x.shape, x.device
        assert n <= self.max_seq_len, 'sequence length must be less than or equal to the maximum sequence length'
        x = self.token_emb(x)            # embed tokens
        x = self.layers(x, mask = mask)  # run the (reversible or sequential) layers
        return self.to_logits(x)         # project to logits

.\lucidrains\h-transformer-1d\h_transformer_1d\reversible.py

import torch
import torch.nn as nn
from operator import itemgetter
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states

# routes keyword arguments to the f / g functions of each reversible layer
def route_args(router, args, depth):
    # one (f_args, g_args) pair per layer
    routed_args = [(dict(), dict()) for _ in range(depth)]
    # keys the router knows how to route
    matched_keys = [key for key in args.keys() if key in router]

    for key in matched_keys:
        val = args[key]
        for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
            new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
            routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
    return routed_args
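
# quick illustration: with router = {'mask': ((True, False), (True, False))} and depth = 2,
# route_args(router, {'mask': m}, 2) returns [({'mask': m}, {}), ({'mask': m}, {})] -
# the mask reaches every f (attention) block and is withheld from every g (feedforward) block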

# randomly drops layers with the given probability, always keeping at least one
def layer_drop(layers, prob):
    to_drop = torch.empty(len(layers)).uniform_(0, 1) < prob
    blocks = [block for block, drop in zip(layers, to_drop) if not drop]
    blocks = layers[:1] if len(blocks) == 0 else blocks
    return blocks

# records and restores RNG state so the wrapped module's activations can be recomputed deterministically
class Deterministic(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)

# reversible block, inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# once multi-GPU is working, refactor and send a PR back to the source
class ReversibleBlock(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, f_args = {}, g_args = {}):
        x1, x2 = torch.chunk(x, 2, dim=2)
        y1, y2 = None, None

        with torch.no_grad():
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)

        return torch.cat([y1, y2], dim=2)

    def backward_pass(self, y, dy, f_args = {}, g_args = {}):
        y1, y2 = torch.chunk(y, 2, dim=2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=2)
        del dy

        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)

        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=2)
            dx = torch.cat([dx1, dx2], dim=2)

        return x, dx

# custom autograd function implementing the reversible forward and backward passes
class _ReversibleFunction(Function):
    @staticmethod
    # forward: run every block, storing only the final (detached) output plus the blocks and args needed to invert it
    def forward(ctx, x, blocks, args):
        ctx.args = args
        for block, kwarg in zip(blocks, args):
            x = block(x, **kwarg)
        # detach the output and stash it for the backward pass
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x

    # backward: walk the blocks in reverse, reconstructing activations and gradients block by block
    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        args = ctx.args
        for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
            y, dy = block.backward_pass(y, dy, **kwargs)
        return dy, None, None
class SequentialSequence(nn.Module):
    # plain (non-reversible) sequential execution of (attention, feedforward) layer pairs
    def __init__(self, layers, args_route = {}, layer_dropout = 0.):
        super().__init__()
        assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
        self.layers = layers
        self.args_route = args_route
        self.layer_dropout = layer_dropout

    def forward(self, x, **kwargs):
        # route the keyword arguments to each layer
        args = route_args(self.args_route, kwargs, len(self.layers))
        layers_and_args = list(zip(self.layers, args))

        if self.training and self.layer_dropout > 0:
            # apply layer dropout during training
            layers_and_args = layer_drop(layers_and_args, self.layer_dropout)

        for (f, g), (f_args, g_args) in layers_and_args:
            # standard residual connections around attention (f) and feedforward (g)
            x = x + f(x, **f_args)
            x = x + g(x, **g_args)
        return x

class ReversibleSequence(nn.Module):
    # reversible execution of (f, g) block pairs, avoiding storage of intermediate activations
    def __init__(self, blocks, args_route = {}, layer_dropout = 0.):
        super().__init__()
        self.args_route = args_route
        self.layer_dropout = layer_dropout
        # wrap each (f, g) pair in a reversible block
        self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks])

    def forward(self, x, **kwargs):
        # duplicate the input along the feature dimension into the two reversible streams
        x = torch.cat([x, x], dim=-1)

        blocks = self.blocks
        # route the keyword arguments to each block
        args = route_args(self.args_route, kwargs, len(blocks))
        args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args))

        layers_and_args = list(zip(blocks, args))

        if self.training and self.layer_dropout > 0:
            # apply layer dropout during training
            layers_and_args = layer_drop(layers_and_args, self.layer_dropout)
            blocks, args = map(lambda ind: list(map(itemgetter(ind), layers_and_args)), (0, 1))

        # run the custom reversible autograd function
        out =  _ReversibleFunction.apply(x, blocks, args)
        # split the two streams back apart and sum them
        return torch.stack(out.chunk(2, dim=-1)).sum(dim=0)
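
As a quick sanity check on the reversible coupling used above (a minimal sketch with toy linear layers, independent of this repo's classes), the inputs can be reconstructed exactly from the outputs, which is what lets backward_pass recompute activations instead of storing them:

import torch
from torch import nn

f, g = nn.Linear(4, 4), nn.Linear(4, 4)
x1, x2 = torch.randn(1, 4), torch.randn(1, 4)

# forward coupling, as in ReversibleBlock.forward
y1 = x1 + f(x2)
y2 = x2 + g(y1)

# inversion, as in ReversibleBlock.backward_pass
x2_rec = y2 - g(y1)
x1_rec = y1 - f(x2_rec)

assert torch.allclose(x1, x1_rec, atol = 1e-6) and torch.allclose(x2, x2_rec, atol = 1e-6)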

.\lucidrains\h-transformer-1d\h_transformer_1d\__init__.py

# expose the HTransformer1D class at the package root
from h_transformer_1d.h_transformer_1d import HTransformer1D

H-Transformer-1D

Implementation of H-Transformer-1D, a Transformer using hierarchical attention for sequence learning with sub-quadratic costs. The encoder (non-autoregressive) flavor of this architecture currently holds the throne for Long Range Arena, a benchmark for efficient transformers.


Install

$ pip install h-transformer-1d

Usage

import torch
from h_transformer_1d import HTransformer1D

model = HTransformer1D(
    num_tokens = 256,          # number of tokens
    dim = 512,                 # dimension
    depth = 12,                # depth
    causal = False,            # autoregressive or not
    max_seq_len = 8192,        # maximum sequence length
    heads = 8,                 # heads
    dim_head = 64,             # dimension per head
    block_size = 128,          # block size
    reversible = True,         # use reversibility, to save on memory with increased depth
    shift_tokens = True        # whether to shift half the feature space by one along the sequence dimension, for faster convergence (experimental feature)
)

x = torch.randint(0, 256, (1, 8000))   # variable sequence length
mask = torch.ones((1, 8000)).bool()    # variable mask length

# network will automatically pad to power of 2, do hierarchical attention, etc

logits = model(x, mask = mask) # (1, 8000, 256)

Citations

@misc{zhu2021htransformer1d,
    title   = {H-Transformer-1D: Fast One-Dimensional Hierarchical Attention for Sequences}, 
    author  = {Zhenhai Zhu and Radu Soricut},
    year    = {2021},
    eprint  = {2107.11906},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@software{peng_bo_2021_5196578,
    author       = {PENG Bo},
    title        = {BlinkDL/RWKV-LM: 0.01},
    month        = {aug},
    year         = {2021},
    publisher    = {Zenodo},
    version      = {0.01},
    doi          = {10.5281/zenodo.5196578},
    url          = {https://doi.org/10.5281/zenodo.5196578}
}

.\lucidrains\h-transformer-1d\setup.py

# import setuptools' setup and package-discovery helpers
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'h-transformer-1d',                        # package name
  packages = find_packages(),                       # discover all packages
  version = '0.1.9',                                # version number
  license='MIT',                                    # license
  description = 'H-Transformer 1D - Pytorch',       # short description
  long_description_content_type = 'text/markdown',  # content type of the long description
  author = 'Phil Wang',                             # author
  author_email = 'lucidrains@gmail.com',            # author email
  url = 'https://github.com/lucidrains/h-transformer-1d',  # project URL
  keywords = [                                      # keywords
    'artificial intelligence',
    'deep learning',
    'transformers',
    'efficient attention'
  ],
  install_requires=[                                # install-time dependencies
    'einops>=0.3',
    'rotary-embedding-torch>=0.5.3',
    'torch>=1.6'
  ],
  classifiers=[                                     # trove classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\h-transformer-1d\train.py

# import the model and the autoregressive wrapper
from h_transformer_1d import HTransformer1D
from h_transformer_1d.autoregressive_wrapper import AutoregressiveWrapper

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 512
SEQ_LEN = 4096

# helpers

# infinitely cycle through a dataloader (used below for the train/val loaders)
def cycle(loader):
    while True:
        for data in loader:
            yield data

# decode a single token id back to a character
def decode_token(token):
    return str(chr(max(32, token)))

# decode a sequence of token ids back to a string
def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# instantiate a GPT-like decoder model

model = HTransformer1D(
    num_tokens = 256,
    dim = 512,
    max_seq_len = SEQ_LEN,
    depth = 8,
    heads = 8,
    causal = True,
    reversible = True
)

model = AutoregressiveWrapper(model)
model.cuda()

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

# dataset that samples random fixed-length windows of the text
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# create dataloaders for training and validation
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)

.\lucidrains\halonet-pytorch\halonet_pytorch\halonet_pytorch.py

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat

# relative positional embedding helpers

# returns a dict with the device and dtype of a tensor
def to(x):
    return {'device': x.device, 'dtype': x.dtype}

# casts a value to a pair tuple if it is not already one
def pair(x):
    return (x, x) if not isinstance(x, tuple) else x

# expands a tensor along a new dimension to size k
def expand_dim(t, dim, k):
    t = t.unsqueeze(dim = dim)
    expand_shape = [-1] * len(t.shape)
    expand_shape[dim] = k
    return t.expand(*expand_shape)

# converts logits indexed by relative position into logits indexed by absolute position
def rel_to_abs(x):
    b, l, m = x.shape
    r = (m + 1) // 2

    col_pad = torch.zeros((b, l, 1), **to(x))
    x = torch.cat((x, col_pad), dim = 2)
    flat_x = rearrange(x, 'b l c -> b (l c)')
    flat_pad = torch.zeros((b, m - l), **to(x))
    flat_x_padded = torch.cat((flat_x, flat_pad), dim = 1)
    final_x = flat_x_padded.reshape(b, l + 1, m)
    final_x = final_x[:, :l, -r:]
    return final_x
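
# quick note: rel_to_abs takes logits of shape (b, l, 2l - 1), indexed by relative offset, and returns
# shape (b, l, l), indexed by absolute position, using only padding and reshaping - the same "skewing"
# trick used for relative attention in Music Transformer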

# computes relative position logits along one spatial dimension
def relative_logits_1d(q, rel_k):
    b, h, w, _ = q.shape
    r = (rel_k.shape[0] + 1) // 2

    logits = einsum('b x y d, r d -> b x y r', q, rel_k)
    logits = rearrange(logits, 'b x y r -> (b x) y r')
    logits = rel_to_abs(logits)

    logits = logits.reshape(b, h, w, r)
    logits = expand_dim(logits, dim = 2, k = r)
    return logits

# 2d relative positional embedding, composed of two 1d embeddings (height and width)
class RelPosEmb(nn.Module):
    def __init__(
        self,
        block_size,
        rel_size,
        dim_head
    ):
        super().__init__()
        height = width = rel_size
        scale = dim_head ** -0.5

        self.block_size = block_size
        self.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)
        self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)

    def forward(self, q):
        block = self.block_size

        q = rearrange(q, 'b (x y) c -> b x y c', x = block)
        rel_logits_w = relative_logits_1d(q, self.rel_width)
        rel_logits_w = rearrange(rel_logits_w, 'b x i y j-> b (x y) (i j)')

        q = rearrange(q, 'b x y d -> b y x d')
        rel_logits_h = relative_logits_1d(q, self.rel_height)
        rel_logits_h = rearrange(rel_logits_h, 'b x i y j -> b (y x) (j i)')
        return rel_logits_w + rel_logits_h

# classes

class HaloAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        block_size,
        halo_size,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        assert halo_size > 0, 'halo size must be greater than 0'

        self.dim = dim
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.block_size = block_size
        self.halo_size = halo_size

        inner_dim = dim_head * heads

        self.rel_pos_emb = RelPosEmb(
            block_size = block_size,
            rel_size = block_size + (halo_size * 2),
            dim_head = dim_head
        )

        self.to_q  = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    # HaloAttention forward: block-local self-attention where keys/values come from a "haloed" neighborhood around each block
    def forward(self, x):
        # unpack shapes and hyperparameters
        b, c, h, w, block, halo, heads, device = *x.shape, self.block_size, self.halo_size, self.heads, x.device
        # the feature map must be evenly divisible into blocks
        assert h % block == 0 and w % block == 0, 'fmap dimensions must be divisible by the block size'
        assert c == self.dim, f'channels for input ({c}) does not equal to the correct dimension ({self.dim})'

        # get block neighborhoods, and prepare a halo-ed version (blocks with padding) for deriving keys and values

        # queries come from the non-overlapping blocks themselves
        q_inp = rearrange(x, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1 = block, p2 = block)

        # keys and values come from blocks enlarged by the halo on every side
        kv_inp = F.unfold(x, kernel_size = block + halo * 2, stride = block, padding = halo)
        kv_inp = rearrange(kv_inp, 'b (c j) i -> (b i) j c', c = c)

        # derive queries, keys, values

        q = self.to_q(q_inp)
        k, v = self.to_kv(kv_inp).chunk(2, dim = -1)

        # split heads

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = heads), (q, k, v))

        # scale

        q *= self.scale

        # attention logits

        sim = einsum('b i d, b j d -> b i j', q, k)

        # add relative positional bias

        sim += self.rel_pos_emb(q)

        # mask out padding (in the paper, they claim to not need masks, but what about padding?)

        # unfold a mask of ones the same way as the keys/values, so padded halo positions can be identified
        mask = torch.ones(1, 1, h, w, device = device)
        mask = F.unfold(mask, kernel_size = block + (halo * 2), stride = block, padding = halo)
        mask = repeat(mask, '() j i -> (b i h) () j', b = b, h = heads)
        mask = mask.bool()

        # fill the masked positions with the largest negative value
        max_neg_value = -torch.finfo(sim.dtype).max
        sim.masked_fill_(mask, max_neg_value)

        # attention

        attn = sim.softmax(dim = -1)

        # aggregate

        out = einsum('b i j, b j d -> b i d', attn, v)

        # merge and combine heads

        out = rearrange(out, '(b h) n d -> b n (h d)', h = heads)
        out = self.to_out(out)

        # merge the blocks back into the original feature map

        out = rearrange(out, '(b h w) (p1 p2) c -> b c (h p1) (w p2)', b = b, h = (h // block), w = (w // block), p1 = block, p2 = block)
        return out

.\lucidrains\halonet-pytorch\halonet_pytorch\__init__.py

# expose the HaloAttention class at the package root
from halonet_pytorch.halonet_pytorch import HaloAttention

HaloNet - Pytorch

Implementation of the Attention layer from the paper, Scaling Local Self-Attention For Parameter Efficient Visual Backbones. This repository will only house the attention layer and not much more.

Install

$ pip install halonet-pytorch

Usage

import torch
from halonet_pytorch import HaloAttention

attn = HaloAttention(
    dim = 512,         # dimension of feature map
    block_size = 8,    # neighborhood block size (feature map must be divisible by this)
    halo_size = 4,     # halo size (block receptive field)
    dim_head = 64,     # dimension of each head
    heads = 4          # number of attention heads
).cuda()

fmap = torch.randn(1, 512, 32, 32).cuda()
attn(fmap) # (1, 512, 32, 32)

Citations

@misc{vaswani2021scaling,
    title   = {Scaling Local Self-Attention For Parameter Efficient Visual Backbones}, 
    author  = {Ashish Vaswani and Prajit Ramachandran and Aravind Srinivas and Niki Parmar and Blake Hechtman and Jonathon Shlens},
    year    = {2021},
    eprint  = {2103.12731},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

.\lucidrains\halonet-pytorch\setup.py

# import setuptools' setup and package-discovery helpers
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'halonet-pytorch',              # package name
  packages = find_packages(),            # discover all packages
  version = '0.0.4',                     # version number
  license='MIT',                         # license
  description = 'HaloNet - Pytorch',     # short description
  author = 'Phil Wang',                  # author
  author_email = 'lucidrains@gmail.com', # author email
  url = 'https://github.com/lucidrains/halonet-pytorch', # project URL
  keywords = [                           # keywords
    'artificial intelligence',
    'deep learning',
    'attention mechanism'
  ],
  install_requires=[                     # install-time dependencies
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[                          # trove classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\hamburger-pytorch\hamburger_pytorch\hamburger_pytorch.py

import torch
from torch import nn, einsum
import torch.nn.functional as F
from contextlib import contextmanager
from einops import repeat, rearrange

# helper functions

# a no-op context manager
@contextmanager
def null_context():
    yield

# returns True when a value is not None
def exists(val):
    return val is not None

# returns the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# classes

# nonnegative matrix factorization with multiplicative updates (NMF-MU)
class NMF(nn.Module):
    def __init__(
        self,
        dim,
        n,
        ratio = 8,
        K = 6,
        eps = 2e-8
    ):
        super().__init__()
        r = dim // ratio

        # initialize D and C uniformly at random in [0, 1]
        D = torch.zeros(dim, r).uniform_(0, 1)
        C = torch.zeros(r, n).uniform_(0, 1)

        self.K = K
        self.D = nn.Parameter(D)
        self.C = nn.Parameter(C)

        self.eps = eps

    def forward(self, x):
        b, D, C, eps = x.shape[0], self.D, self.C, self.eps

        # NMF requires a nonnegative input
        x = F.relu(x)

        # broadcast D and C across the batch
        D = repeat(D, 'd r -> b d r', b = b)
        C = repeat(C, 'r n -> b r n', b = b)

        # transpose helper
        t = lambda tensor: rearrange(tensor, 'b i j -> b j i')

        for k in reversed(range(self.K)):
            # only compute gradients on the last iteration, per the 'one-step gradient' proposal
            context = null_context if k == 0 else torch.no_grad
            with context():
                C_new = C * ((t(D) @ x) / ((t(D) @ D @ C) + eps))
                D_new = D * ((x @ t(C)) / ((D @ C @ t(C)) + eps))
                C, D = C_new, D_new

        return D @ C
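
# note: these are the standard multiplicative update rules for NMF - since x, D and C are all
# nonnegative, the elementwise ratios keep D and C nonnegative while driving down the
# reconstruction error of x ≈ D @ C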

# the Hamburger module: a matrix-decomposition 'ham' sandwiched between two linear projections ('bread')
class Hamburger(nn.Module):
    def __init__(
        self,
        *,
        dim,
        n,
        inner_dim = None,
        ratio = 8,
        K = 6
    ):
        super().__init__()
        inner_dim = default(inner_dim, dim)

        # lower bread: pointwise conv projecting into the factorization dimension
        self.lower_bread = nn.Conv1d(dim, inner_dim, 1, bias = False)
        # ham: the NMF matrix decomposition
        self.ham = NMF(inner_dim, n, ratio = ratio, K = K)
        # upper bread: pointwise conv projecting back to the input dimension
        self.upper_bread = nn.Conv1d(inner_dim, dim, 1, bias = False)

    def forward(self, x):
        shape = x.shape
        # flatten the spatial dimensions into a single sequence dimension
        x = x.flatten(2)

        x = self.lower_bread(x)
        x = self.ham(x)
        x = self.upper_bread(x)
        # restore the original shape
        return x.reshape(shape)

.\lucidrains\hamburger-pytorch\hamburger_pytorch\__init__.py

# expose the Hamburger class at the package root
from hamburger_pytorch.hamburger_pytorch import Hamburger

🍔 - Pytorch

Pytorch implementation of the hamburger module from the ICLR 2021 paper Is Attention Better Than Matrix Decomposition?. Following Betteridge's law, the answer according to the paper is "No" for segmentation and GANs.

This repository will contain the NMF-MU (nonnegative matrix factorization w/ multiplicative update) module sandwiched by linear projections.

Update: I tried this, but did not get better results than just using linear attention

Install

$ pip install hamburger-pytorch

Usage

import torch
from hamburger_pytorch import Hamburger

hamburger = Hamburger(
    dim = 512,       # input dimension
    n = 32 * 32,     # n will be size of the sequence, in this case, height times width of the images
    ratio = 8,       # matrix factorization ratio, recommended to be at 8
    K = 6            # number of iterations, optimal at 6 as shown in paper
)

x = torch.randn(1, 512, 32, 32)
hamburger(x) + x # (1, 512, 32, 32)

Citations

@inproceedings{
    anonymous2021is,
    title={Is Attention Better Than Matrix Decomposition?},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=1FvkSpWosOl},
    note={under review}
}

.\lucidrains\hamburger-pytorch\setup.py

# import setuptools' setup and package-discovery helpers
from setuptools import setup, find_packages

# package metadata
setup(
  # package name
  name = 'hamburger-pytorch',
  # discover all packages
  packages = find_packages(),
  # version number
  version = '0.0.3',
  # license
  license='MIT',
  # short description
  description = 'Hamburger - Pytorch',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # project URL
  url = 'https://github.com/lucidrains/hamburger-pytorch',
  # keywords
  keywords = [
    'artificial intelligence',
    'attention mechanism',
    'matrix factorization'
  ],
  # install-time dependencies
  install_requires=[
    'torch',
    'einops>=0.3'
  ],
  # trove classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\holodeck-pytorch\holodeck_pytorch\holodeck_pytorch.py

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange

# helper functions

# returns True when a value is not None
def exists(val):
    return val is not None

# returns the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# attention

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        dim_context = None,
        heads = 8,
        norm_context = False
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        dim_context = default(dim_context, dim)

        self.norm = nn.LayerNorm(dim)
        self.context_norm = nn.LayerNorm(dim_context) if norm_context else nn.Identity()

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim_context, dim_head * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_bias = None
    ):
        if exists(context):
            context = self.context_norm(context)

        kv_input = default(context, x)

        x = self.norm(x)

        q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        q = q * self.scale

        sim = einsum('b h i d, b j d -> b h i j', q, k)

        mask_value = -torch.finfo(sim.dtype).max

        if exists(attn_bias):
            sim = sim + attn_bias

        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, mask_value)

        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
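
# quick usage illustration (shapes only): cross-attending a sequence to some context
# attn = Attention(dim = 512, dim_context = 256, norm_context = True)
# x, context = torch.randn(1, 32, 512), torch.randn(1, 64, 256)
# attn(x, context = context).shape  # -> torch.Size([1, 32, 512])
# note the keys/values are projected once (dim_head * 2) and shared across all query heads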

# main class

class Holodeck(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

.\lucidrains\holodeck-pytorch\holodeck_pytorch\__init__.py

# expose the Holodeck class at the package root
from holodeck_pytorch.holodeck_pytorch import Holodeck

Holodeck - Pytorch (wip)

Implementation of a holodeck, written in Pytorch.

Citations

@article{Wu20234DGS,
    title     = {4D Gaussian Splatting for Real-Time Dynamic Scene Rendering},
    author    = {Guanjun Wu and Taoran Yi and Jiemin Fang and Lingxi Xie and Xiaopeng Zhang and Wei Wei and Wenyu Liu and Qi Tian and Xinggang Wang},
    journal   = {ArXiv},
    year      = {2023},
    volume    = {abs/2310.08528},
    url       = {https://api.semanticscholar.org/CorpusID:263908793}
}
@inproceedings{Singer2023TextTo4DDS,
    title   = {Text-To-4D Dynamic Scene Generation},
    author  = {Uriel Singer and Shelly Sheynin and Adam Polyak and Oron Ashual and Iurii Makarov and Filippos Kokkinos and Naman Goyal and Andrea Vedaldi and Devi Parikh and Justin Johnson and Yaniv Taigman},
    year    = {2023}
}
@inproceedings{Bauer2023SpatialFS,
    title   = {Spatial Functa: Scaling Functa to ImageNet Classification and Generation},
    author  = {M. Bauer and Emilien Dupont and Andy Brock and Dan Rosenbaum and Jonathan Schwarz and Hyunjik Kim},
    year    = {2023}
}

.\lucidrains\holodeck-pytorch\setup.py

# import setuptools' setup and package-discovery helpers
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'holodeck-pytorch',                        # package name
  packages = find_packages(exclude=[]),             # discover all packages
  version = '0.0.1',                                # version number
  license='MIT',                                    # license
  description = 'Holodeck - Pytorch',               # short description
  author = 'Phil Wang',                             # author
  author_email = 'lucidrains@gmail.com',            # author email
  long_description_content_type = 'text/markdown',  # content type of the long description
  url = 'https://github.com/lucidrains/holodeck-pytorch',  # project URL
  keywords = [                                      # keywords
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'denoising diffusion',
    'temporal scene representations',
    'hypernetworks'
  ],
  install_requires=[                                # install-time dependencies
    'einops>=0.4',
    'torch>=1.6',
  ],
  classifiers=[                                     # trove classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/

.\lucidrains\hourglass-transformer-pytorch\hourglass_transformer_pytorch\autoregressive_wrapper.py

import torch
from torch import nn
import torch.nn.functional as F

# helper functions

# returns True when a value is not None
def exists(val):
    return val is not None

# decorator that switches the model to eval mode for the call, then restores the previous training mode
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# top k filtering

# keep only the top k logits (k determined by thres), setting the rest to -inf
def top_k(logits, thres = 0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# autoregressive wrapper, adding sampling and a next-token loss around a base network
class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, pad_value = 0):
        super().__init__()
        assert hasattr(net, 'max_seq_len'), 'your transformer class must have max_seq_len set to the maximum sequence length'

        self.pad_value = pad_value
        self.net = net
        self.max_seq_len = net.max_seq_len

    # sampling method
    @torch.no_grad()
    @eval_decorator
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_thres = 0.9, **kwargs):
        b, t, device = *start_tokens.shape, start_tokens.device

        out = start_tokens

        for _ in range(seq_len):
            x = out[:, -self.max_seq_len:]

            logits = self.net(x, **kwargs)[:, -1, :]

            filtered_logits = top_k(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)

            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)

            if exists(eos_token):
                is_eos_token = (out == eos_token)

                if is_eos_token.any(dim = -1).all():
                    # mask out everything after the eos tokens
                    shifted_is_eos_tokens = F.pad(is_eos_token, (1, -1))
                    mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
                    out = out.masked_fill(mask, self.pad_value)
                    break

        out = out[:, t:]
        return out

    # forward pass: next-token cross entropy loss
    def forward(self, x, **kwargs):
        x_inp, x_labels = x[:, :-1], x[:, 1:]
        logits = self.net(x_inp, **kwargs)
        return F.cross_entropy(logits.transpose(1, 2), x_labels, ignore_index = self.pad_value)

.\lucidrains\hourglass-transformer-pytorch\hourglass_transformer_pytorch\hourglass_transformer_pytorch.py

import math

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, reduce, repeat

# helpers

# returns True when a value is not None
def exists(val):
    return val is not None

# returns the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# pads a tensor along a dimension so its length becomes a multiple of `multiple`
def pad_to_multiple(tensor, multiple, dim = -1, value = 0):
    seq_len = tensor.shape[dim]
    m = seq_len / multiple
    if m.is_integer():
        return tensor
    remainder = math.ceil(m) * multiple - seq_len
    pad_offset = (0,) * (-1 - dim) * 2
    return F.pad(tensor, (*pad_offset, 0, remainder), value = value)
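
# quick illustration: pad_to_multiple(torch.randn(1, 10, 512), 4, dim = -2) zero-pads the sequence
# from length 10 up to 12, the next multiple of 4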

# casts a value to a tuple of the given depth if it is not already a tuple
def cast_tuple(val, depth = 1):
    return val if isinstance(val, tuple) else ((val,) * depth)

# factory

# builds either a plain Transformer (integer depth) or a nested HourglassTransformer (tuple depth)
def get_hourglass_transformer(
    dim,
    *,
    depth,
    shorten_factor,
    attn_resampling,
    updown_sample_type,
    **kwargs
):
    assert isinstance(depth, int) or (isinstance(depth, tuple)  and len(depth) == 3), 'depth must be either an integer or a tuple of 3, indicating (pre_transformer_depth, <nested-hour-glass-config>, post_transformer_depth)'
    assert not (isinstance(depth, int) and shorten_factor), 'there does not need to be a shortening factor when only a single transformer block is indicated (depth of one integer value)'

    if isinstance(depth, int):
        return Transformer(dim = dim, depth = depth, **kwargs)

    return HourglassTransformer(dim = dim, depth = depth, shorten_factor = shorten_factor, attn_resampling = attn_resampling, updown_sample_type = updown_sample_type, **kwargs)
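
The depth argument drives the recursion: an int yields a plain Transformer, while a 3-tuple yields an hourglass whose valley can itself be another (pre, valley, post) tuple. A hedged sketch of two valid configurations (all values are illustrative):

# usage sketch for the factory (illustrative, not part of the file)
# a single hourglass: 2 pre layers, 4 valley layers at half resolution, 2 post layers
t1 = get_hourglass_transformer(
    dim = 512, depth = (2, 4, 2), shorten_factor = 2,
    attn_resampling = True, updown_sample_type = 'naive', causal = True
)

# a nested hourglass: the valley is itself an hourglass with its own shorten factor
t2 = get_hourglass_transformer(
    dim = 512, depth = (2, (2, 4, 2), 2), shorten_factor = (2, 4),
    attn_resampling = True, updown_sample_type = 'naive', causal = True
)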

# up and down sample classes

# naive downsample class: average pool over groups of shorten_factor tokens
class NaiveDownsample(nn.Module):
    def __init__(self, shorten_factor):
        super().__init__()
        self.shorten_factor = shorten_factor

    def forward(self, x):
        return reduce(x, 'b (n s) d -> b n d', 'mean', s = self.shorten_factor)

# naive upsample class: repeat each token shorten_factor times
class NaiveUpsample(nn.Module):
    def __init__(self, shorten_factor):
        super().__init__()
        self.shorten_factor = shorten_factor

    def forward(self, x):
        return repeat(x, 'b n d -> b (n s) d', s = self.shorten_factor)

# linear downsample class: concatenate each group of tokens and project back to dim
class LinearDownsample(nn.Module):
    def __init__(self, dim, shorten_factor):
        super().__init__()
        self.proj = nn.Linear(dim * shorten_factor, dim)
        self.shorten_factor = shorten_factor

    def forward(self, x):
        x = rearrange(x, 'b (n s) d -> b n (s d)', s = self.shorten_factor)
        return self.proj(x)

# linear upsample class: project to shorten_factor * dim and unfold back to the full length
class LinearUpsample(nn.Module):
    def __init__(self, dim, shorten_factor):
        super().__init__()
        self.proj = nn.Linear(dim, dim * shorten_factor)
        self.shorten_factor = shorten_factor

    def forward(self, x):
        x = self.proj(x)
        return rearrange(x, 'b n (s d) -> b (n s) d', s = self.shorten_factor)
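
A small shape check contrasting the two resampling strategies (shapes are illustrative): both halve an 8-token sequence on the way down and restore its length on the way up; the linear variants additionally learn the projection.

# usage sketch for the resampling modules (illustrative, not part of the file)
x = torch.randn(1, 8, 512)

down, up = NaiveDownsample(2), NaiveUpsample(2)
print(down(x).shape, up(down(x)).shape)        # (1, 4, 512) and (1, 8, 512)

ldown, lup = LinearDownsample(512, 2), LinearUpsample(512, 2)
print(ldown(x).shape, lup(ldown(x)).shape)     # (1, 4, 512) and (1, 8, 512)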

# classes

# pre-norm residual wrapper class
class PreNormResidual(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs) + x

# multi-head attention class
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        causal = False
    ):
        super().__init__()
        self.heads = heads
        self.causal = causal
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.dropout = nn.Dropout(dropout)
    # forward pass, taking an input x, an optional context and an optional mask
    def forward(self, x, context = None, mask = None):
        # grab the number of heads and the device
        h, device = self.heads, x.device
        # if no context is given, use the input x itself as the key / value input
        kv_input = default(context, x)

        # project the input into queries q, keys k and values v
        q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)
        # rearrange q, k and v into multi-head form
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # scale the queries
        q = q * self.scale

        # compute the similarity between queries and keys
        sim = einsum('b h i d, b h j d -> b h i j', q, k)
        # masking value: the most negative representable value for the dtype
        mask_value = -torch.finfo(sim.dtype).max

        # if a mask is given, mask out the corresponding entries of the similarity matrix
        if exists(mask):
            mask = rearrange(mask, 'b j -> b () () j')
            sim = sim.masked_fill(~mask, mask_value)

        # if causal, build an upper-triangular mask so positions cannot attend to the future
        if self.causal:
            i, j = sim.shape[-2:]
            mask = torch.ones(i, j, device = device, dtype = torch.bool).triu_(j - i + 1)
            mask = rearrange(mask, 'i j -> () () i j')
            sim = sim.masked_fill(mask, mask_value)

        # softmax over the similarity matrix to get the attention weights
        attn = sim.softmax(dim = -1)
        # apply dropout to the attention weights
        attn = self.dropout(attn)

        # aggregate the values with the attention weights
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        # merge the heads back into the feature dimension
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        # project back to the model dimension and return
        return self.to_out(out)
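
A brief hedged usage sketch of this Attention block on its own (shapes are illustrative): with causal = True, position i only attends to positions at or before i, and the optional boolean mask marks the valid key positions.

# usage sketch for Attention (illustrative, not part of the file)
attn = Attention(dim = 512, heads = 8, dim_head = 64, causal = True)
x = torch.randn(2, 16, 512)
mask = torch.ones(2, 16, dtype = torch.bool)
out = attn(x, mask = mask)    # (2, 16, 512)
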
def FeedForward(dim, mult = 4, dropout = 0.):
    # return a sequential module: linear -> GELU -> dropout -> linear
    return nn.Sequential(
        nn.Linear(dim, dim * mult),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim * mult, dim)
    )

# transformer classes

class Transformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        causal = False,
        heads = 8,
        dim_head = 64,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        norm_out = False
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            # for each layer, create pre-norm residual blocks around attention and feedforward
            self.layers.append(nn.ModuleList([
                PreNormResidual(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout, causal = causal)),
                PreNormResidual(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout))
            ]))

        # apply LayerNorm to the output if requested, otherwise use Identity
        self.norm = nn.LayerNorm(dim) if norm_out else nn.Identity()

    def forward(self, x, context = None, mask = None):
        for attn, ff in self.layers:
            # run each layer in turn: attention followed by feedforward
            x = attn(x, context = context, mask = mask)
            x = ff(x)

        return self.norm(x)
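
This plain Transformer is also what the factory returns when depth is a single integer. A short hedged sketch (values are illustrative):

# usage sketch for Transformer (illustrative, not part of the file)
block = Transformer(dim = 512, depth = 2, causal = True, norm_out = True)
x = torch.randn(1, 32, 512)
print(block(x).shape)    # torch.Size([1, 32, 512])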

class HourglassTransformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        shorten_factor = 2,
        attn_resampling = True,
        updown_sample_type = 'naive',
        heads = 8,
        dim_head = 64,
        causal = False,
        norm_out = False
    ):
        super().__init__()
        assert len(depth) == 3, 'depth should be a tuple of length 3'
        assert updown_sample_type in {'naive', 'linear'}, 'downsample / upsample type must be either naive (average pool and repeat) or linear (linear projection and reshape)'

        pre_layers_depth, valley_depth, post_layers_depth = depth

        if isinstance(shorten_factor, (tuple, list)):
            shorten_factor, *rest_shorten_factor = shorten_factor
        elif isinstance(valley_depth, int):
            shorten_factor, rest_shorten_factor = shorten_factor, None
        else:
            shorten_factor, rest_shorten_factor = shorten_factor, shorten_factor

        transformer_kwargs = dict(
            dim = dim,
            heads = heads,
            dim_head = dim_head
        )

        self.causal = causal
        self.shorten_factor = shorten_factor

        if updown_sample_type == 'naive':
            # naive downsampling and upsampling (average pool and repeat)
            self.downsample = NaiveDownsample(shorten_factor)
            self.upsample   = NaiveUpsample(shorten_factor)
        elif updown_sample_type == 'linear':
            # learned linear downsampling and upsampling
            self.downsample = LinearDownsample(dim, shorten_factor)
            self.upsample   = LinearUpsample(dim, shorten_factor)
        else:
            raise ValueError(f'unknown updown_sample_type keyword value - must be either naive or linear for now')

        # build the transformer (or nested hourglass) used at the valley
        self.valley_transformer = get_hourglass_transformer(
            shorten_factor = rest_shorten_factor,
            depth = valley_depth,
            attn_resampling = attn_resampling,
            updown_sample_type = updown_sample_type,
            causal = causal,
            **transformer_kwargs
        )

        # if attention resampling is enabled, create the pre- and post-valley transformers
        self.attn_resampling_pre_valley = Transformer(depth = 1, **transformer_kwargs) if attn_resampling else None
        self.attn_resampling_post_valley = Transformer(depth = 1, **transformer_kwargs) if attn_resampling else None

        # create the transformers before and after the valley
        self.pre_transformer = Transformer(depth = pre_layers_depth, causal = causal, **transformer_kwargs)
        self.post_transformer = Transformer(depth = post_layers_depth, causal = causal, **transformer_kwargs)
        # apply LayerNorm to the output if requested, otherwise use Identity
        self.norm_out = nn.LayerNorm(dim) if norm_out else nn.Identity()
    def forward(self, x, mask = None):
        # notation: b is the batch size, n the sequence length, d the feature dimension, s the shorten factor

        s, b, n = self.shorten_factor, *x.shape[:2]

        # top half of the hourglass: the pre-valley transformer layers

        x = self.pre_transformer(x, mask = mask)

        # pad to a multiple of the shorten factor, in preparation for pooling

        x = pad_to_multiple(x, s, dim = -2)

        if exists(mask):
            padded_mask = pad_to_multiple(mask, s, dim = -1, value = False)

        # save the residual, which is also used for "attention resampling" at downsample and upsample

        x_residual = x.clone()

        # if autoregressive, shift the sequence right by shorten factor minus one

        if self.causal:
            shift = s - 1
            x = F.pad(x, (0, 0, shift, -shift), value = 0.)

            if exists(mask):
                padded_mask = F.pad(padded_mask, (shift, -shift), value = False)

        # downsample (naive average pooling or a learned linear projection)

        downsampled = self.downsample(x)

        if exists(mask):
            downsampled_mask = reduce(padded_mask, 'b (n s) -> b n', 'sum', s = s) > 0
        else:
            downsampled_mask = None

        # pre-valley "attention resampling" - each pooled token attends to the tokens of its bucket before pooling

        if exists(self.attn_resampling_pre_valley):
            if exists(mask):
                attn_resampling_mask = rearrange(padded_mask, 'b (n s) -> (b n) s', s = s)
            else:
                attn_resampling_mask = None

            downsampled = self.attn_resampling_pre_valley(
                rearrange(downsampled, 'b n d -> (b n) () d'),
                rearrange(x, 'b (n s) d -> (b n) s d', s = s),
                mask = attn_resampling_mask
            )

            downsampled = rearrange(downsampled, '(b n) () d -> b n d', b = b)

        # the "valley" - either a regular transformer or another hourglass

        x = self.valley_transformer(downsampled, mask = downsampled_mask)

        valley_out = x.clone()

        # upsample (naive repetition or a learned linear projection)

        x = self.upsample(x)

        # add the residual

        x = x + x_residual

        # post-valley "attention resampling"

        if exists(self.attn_resampling_post_valley):
            x = self.attn_resampling_post_valley(
                rearrange(x, 'b (n s) d -> (b n) s d', s = s),
                rearrange(valley_out, 'b n d -> (b n) () d')
            )

            x = rearrange(x, '(b n) s d -> b (n s) d', b = b)

        # bring the sequence back to its original length, in case it was padded for pooling

        x = x[:, :n]

        # bottom half of the hourglass: the post-valley transformer layers

        x = self.post_transformer(x, mask = mask)
        return self.norm_out(x)
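
To make the shape bookkeeping above concrete, a hedged walk-through with illustrative numbers: with shorten_factor = 3 and an input of length 10, the sequence is padded to 12, pooled to 4 tokens for the valley, upsampled back to 12, and finally cropped to the original 10 before the post layers.

# usage sketch for HourglassTransformer (illustrative, not part of the file)
hourglass = HourglassTransformer(dim = 512, depth = (1, 1, 1), shorten_factor = 3, causal = True)
x = torch.randn(1, 10, 512)
print(hourglass(x).shape)    # torch.Size([1, 10, 512])
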
# main class definition

class HourglassTransformerLM(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,  # vocabulary size
        dim,  # model dimension
        max_seq_len,  # maximum sequence length
        depth,  # depth (int or (pre, valley, post) tuple)
        shorten_factor = None,  # shorten factor, defaults to None
        heads = 8,  # number of attention heads, defaults to 8
        dim_head = 64,  # dimension per head, defaults to 64
        attn_resampling = True,  # attention resampling, defaults to True
        updown_sample_type = 'naive',  # downsample / upsample type, defaults to 'naive'
        causal = True  # causal (autoregressive), defaults to True
    ):
        super().__init__()
        self.max_seq_len = max_seq_len

        # token embedding layer
        self.token_emb = nn.Embedding(num_tokens, dim)
        # positional embedding layer
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # build the HourglassTransformer (or plain Transformer) via the factory
        self.transformer = get_hourglass_transformer(
            dim = dim,
            depth = depth,
            shorten_factor = shorten_factor,
            attn_resampling = attn_resampling,
            updown_sample_type = updown_sample_type,
            dim_head = dim_head,
            heads = heads,
            causal = causal,
            norm_out = True
        )

        # linear layer that produces the output logits
        self.to_logits = nn.Linear(dim, num_tokens)

    def forward(self, x, mask = None):
        device = x.device
        x = self.token_emb(x)
        pos_emb = self.pos_emb(torch.arange(x.shape[-2], device = device))
        x = x + rearrange(pos_emb, 'n d -> () n d')

        # run the transformer over the embedded input
        x = self.transformer(x, mask = mask)
        return self.to_logits(x)
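
Putting it all together, a hedged end-to-end sketch of the language model defined above (every hyperparameter is an illustrative choice); for training with the shifted cross entropy loss, wrap it in the AutoregressiveWrapper from the previous file.

# usage sketch for HourglassTransformerLM (illustrative, not part of the file)
lm = HourglassTransformerLM(
    num_tokens = 256,
    dim = 512,
    max_seq_len = 1024,
    depth = (2, (2, 4, 2), 2),    # nested hourglass in the valley
    shorten_factor = (2, 2),
    causal = True
)

tokens = torch.randint(0, 256, (1, 1024))
logits = lm(tokens)
print(logits.shape)    # torch.Size([1, 1024, 256])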