
Lucidrains Series Projects Source Code Analysis (114)

.\lucidrains\x-transformers\setup.py

# import the setup and package-finding utilities from setuptools
from setuptools import setup, find_packages

# define the package metadata
setup(
  # package name
  name = 'x-transformers',
  # include every package except 'examples'
  packages = find_packages(exclude=['examples']),
  # version number
  version = '1.27.19',
  # license type
  license='MIT',
  # short description
  description = 'X-Transformers - Pytorch',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # project URL
  url = 'https://github.com/lucidrains/x-transformers',
  # content type of the long description
  long_description_content_type = 'text/markdown',
  # keywords
  keywords = [
    'artificial intelligence',
    'attention mechanism',
    'transformers'
  ],
  # installation dependencies
  install_requires=[
    'torch>=1.6',
    'einops>=0.7.0'
  ],
  # classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
)
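
With the package installed (`pip install x-transformers`), the canonical entry point looks like the snippet below, a minimal sketch based on the repository README, using `TransformerWrapper` with a `Decoder`:

```py
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)  # (1, 1024, 20000)
```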

.\lucidrains\x-transformers\x_transformers\attend.py

# import the required modules and libraries
from functools import partial
from typing import Optional, Tuple
import torch
from torch import nn, einsum, Tensor
import torch.nn.functional as F
from collections import namedtuple
from functools import wraps
from packaging import version
from dataclasses import dataclass
from einops import rearrange, repeat

# the Intermediates dataclass holds a few optional Tensor fields
@dataclass
class Intermediates:
    qk_similarities: Optional[Tensor] = None
    pre_softmax_attn: Optional[Tensor] = None
    post_softmax_attn: Optional[Tensor] = None
    cached_kv: Optional[Tuple[Tensor, Tensor]] = None

    # convert the fields to a tuple
    def to_tuple(self):
        return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# filter out values that do not exist
def compact(arr):
    return [*filter(exists, arr)]

# ensure a function is only ever invoked once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# a print that only prints once
print_once = once(print)

# create a causal mask
def create_causal_mask(i, j, device):
    return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)

# create a causal mask for ONNX CPU export
def onnx_create_causal_mask(i, j, device):
    r = torch.arange(i, device=device)
    causal_mask = rearrange(r, 'i -> i 1') < rearrange(r, 'j -> 1 j')
    causal_mask = F.pad(causal_mask, (j - i, 0), value=False)
    return causal_mask
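
A quick illustration of how `create_causal_mask` behaves when the key length exceeds the query length (as happens with a kv cache); `True` entries are the positions that get masked out:

```py
import torch

i, j = 3, 5
mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False,  True],
#         [False, False, False, False, False]])
# queries are right-aligned against the keys, so query 0 may
# attend to keys 0-2, query 1 to keys 0-3, and so on
```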

# main Attend class
class Attend(nn.Module):
    def __init__(
        self,
        *,
        dropout=0.,
        causal=False,
        heads=None,
        talking_heads=False,
        sparse_topk=None,
        scale=None,
        qk_norm=False,
        flash=False,
        add_zero_kv=False,
        onnxable=False,
        sdp_kwargs: dict = dict(
            enable_flash=True,
            enable_math=True,
            enable_mem_efficient=True
        )
    ):
        super().__init__()
        self.scale = scale

        self.causal = causal
        self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask

        self.attn_fn = partial(F.softmax, dtype=torch.float32) if not qk_norm else F.softmax

        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        # talking heads
        assert not (flash and talking_heads), 'talking heads not compatible with flash attention'
        self.talking_heads = talking_heads
        if talking_heads:
            self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias=False)
            self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias=False)

        # sparse topk
        assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
        self.sparse_topk = sparse_topk

        # add a key / value token composed of zeros, helps control outliers
        self.add_zero_kv = add_zero_kv

        # flash attention
        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        self.sdp_kwargs = sdp_kwargs

    def flash_attn(
        self,
        q, k, v,
        mask=None,
        attn_bias=None
    ):
        # unpack shapes and other attributes of the input tensors
        batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # Tri Dao's recommended approach for multi-query single-key-value attention
        # expand the key / values from torch.Size([1, 512, 64]) to torch.Size([1, 8, 512, 64])

        if k.ndim == 3:
            k = repeat(k, 'b ... -> b h ...', h = q.shape[1])

        if v.ndim == 3:
            v = repeat(v, 'b ... -> b h ...', h = q.shape[1])

        # handle scale - by default it scales by dim_head ** -0.5, but needs care when using cosine similarity attention

        if exists(self.scale):
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)

        # check whether a mask exists and expand it to a compatible shape
        # the mask is B L, so it must be expanded to B H N L

        causal = self.causal

        # in the case of kv caching with one token (q_len == 1), just turn off causal masking
        # in speculative decoding, this may go up to 5-6, so a right-aligned causal mask will be needed there

        if q_len == 1 and causal:
            causal = False

        # expand the key padding mask

        if exists(mask):
            assert mask.ndim == 4
            mask = mask.expand(batch, heads, q_len, k_len)

        # handle the kv cache - this should be bypassable with flash attention 2

        if k_len > q_len and causal:
            causal_mask = self.create_causal_mask(q_len, k_len, device = device)
            if not exists(mask):
                mask = ~causal_mask
            else:
                mask = mask & ~causal_mask
            causal = False

        # manually handle the causal mask if another mask was given

        row_is_entirely_masked = None

        if exists(mask) and causal:
            causal_mask = self.create_causal_mask(q_len, k_len, device = device)
            mask = mask & ~causal_mask

            # protect against an entire row being masked out

            row_is_entirely_masked = ~mask.any(dim = -1)
            mask[..., 0] = mask[..., 0] | row_is_entirely_masked

            causal = False

        # handle the alibi positional bias
        # convert from bool to float

        if exists(attn_bias):
            attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)

            # if a mask was given, it already contains the causal logic from above
            # otherwise, if no mask was given but it is still causal, mask out the alibi positional bias with a large negative number

            mask_value = -torch.finfo(q.dtype).max

            if exists(mask):
                attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
            elif causal:
                causal_mask = self.create_causal_mask(q_len, k_len, device = device)
                attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
                causal = False

            # scaled_dot_product_attention handles attn_mask either as a boolean or as additive bias
            # here it is used as an additive bias

            mask = attn_bias

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

        with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.,
                is_causal = causal
            )

        # for rows that are entirely masked out, zero out that row token's output

        if exists(row_is_entirely_masked):
            out = out.masked_fill(row_is_entirely_masked[..., None], 0.)

        return out, Intermediates()

    def forward(
        self,
        q, k, v,
        mask = None,
        attn_bias = None,
        prev_attn = None
    ):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # grab shape information from the input tensors
        n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device

        # the scale defaults to the inverse square root of the feature dimension
        scale = default(self.scale, q.shape[-1] ** -0.5)

        # whether causal attention is enabled
        causal = self.causal

        # handle cached key / value decoding
        if n == 1 and causal:
            causal = False

        # handle grouped multi-query attention
        if kv_heads == 1:
            k, v = map(lambda t: rearrange(t, 'b 1 n d -> b n d'), (k, v))
        elif kv_heads < heads:
            k, v = map(lambda t: repeat(t, 'b kvh n d -> b (r kvh) n d', r = heads // kv_heads), (k, v))

        # handle the zero key / value pair, which lets the network attend to nothing
        if self.add_zero_kv:
            k, v = map(lambda t: F.pad(t, (0, 0, 1, 0), value = 0.), (k, v))

            if exists(mask):
                mask = F.pad(mask, (1, 0), value = True)

            if exists(attn_bias):
                attn_bias = F.pad(attn_bias, (1, 0), value = 0.)

        # if flash attention is enabled, return the flash attention result
        if self.flash:
            assert not exists(prev_attn), 'residual attention not compatible with flash attention'
            return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)

        # pick the einsum equation based on the dimensionality of the keys / values
        kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

        # compute the dot products
        dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale

        # if previous attention exists, add it in (residual attention)
        if exists(prev_attn):
            dots = dots + prev_attn

        # keep a copy of the raw query-key similarities
        qk_similarities = dots.clone()

        # if talking heads is enabled, mix the heads before the softmax
        if self.talking_heads:
            dots = self.pre_softmax_talking_heads(dots)

        # if an attention bias exists, add it in
        if exists(attn_bias):
            dots = dots + attn_bias

        # grab shape information from the dot products
        i, j, dtype = *dots.shape[-2:], dots.dtype

        # the mask value is the largest negative number for the dtype
        mask_value = -torch.finfo(dots.dtype).max

        # if sparse topk is set and less than j, keep only the topk values
        if exists(self.sparse_topk) and self.sparse_topk < j:
            top_values, _ = dots.topk(self.sparse_topk, dim = -1)
            sparse_topk_mask = dots < top_values[..., -1:]
            mask = (mask & sparse_topk_mask) if exists(mask) else sparse_topk_mask

        # if a mask exists, fill the masked positions with the mask value
        if exists(mask):
            dots = dots.masked_fill(~mask, mask_value)

        # if causal attention is enabled, fill according to the causal mask
        if causal:
            causal_mask = self.create_causal_mask(i, j, device = device)
            dots = dots.masked_fill(causal_mask, mask_value)

        # keep a copy of the pre-softmax attention values
        pre_softmax_attn = dots.clone()

        # softmax to get the attention weights
        attn = self.attn_fn(dots, dim = -1)
        attn = attn.type(dtype)

        # keep a copy of the post-softmax attention weights
        post_softmax_attn = attn.clone()

        # apply dropout to the attention weights
        attn = self.attn_dropout(attn)

        # if talking heads is enabled, mix the heads after the softmax
        if self.talking_heads:
            attn = self.post_softmax_talking_heads(attn)

        # aggregate the values to produce the output
        out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)

        # save the intermediate results
        intermediates = Intermediates(
            qk_similarities = qk_similarities,
            pre_softmax_attn = pre_softmax_attn,
            post_softmax_attn = post_softmax_attn
        )

        # return the output and the intermediates
        return out, intermediates
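
A minimal sketch of driving `Attend` directly (shapes assumed from the forward signature above; leaving `flash` off keeps the manual path so the intermediates are populated):

```py
import torch
from x_transformers.attend import Attend

attend = Attend(causal = True, dropout = 0.1)

# (batch, heads, seq_len, dim_head)
q = torch.randn(2, 8, 1024, 64)
k = torch.randn(2, 8, 1024, 64)
v = torch.randn(2, 8, 1024, 64)

out, intermediates = attend(q, k, v)
print(out.shape)                              # torch.Size([2, 8, 1024, 64])
print(intermediates.post_softmax_attn.shape)  # torch.Size([2, 8, 1024, 1024])
```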

.\lucidrains\x-transformers\x_transformers\autoregressive_wrapper.py

# import ceil and log from the math module
# import the Optional, Union, Tuple, Callable type hints from typing
# import torch and its submodules
# import the nn, Tensor, Module classes
# import torch.nn.functional
# import rearrange, pack, unpack from einops
from math import ceil, log
from typing import Optional, Union, Tuple, Callable

import torch
from torch import nn, Tensor
from torch.nn import Module
import torch.nn.functional as F

from einops import rearrange, pack, unpack

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# identity function
def identity(t, *args, **kwargs):
    return t

# cast the input to a tuple
def cast_tuple(t, length = 1):
    return t if isinstance(t, tuple) else (t,) * length

# decorator that runs a function in eval mode, then restores the training state
def eval_decorator(fn):
    def inner(self, *args, **kwargs):
        was_training = self.training
        self.eval()
        out = fn(self, *args, **kwargs)
        self.train(was_training)
        return out
    return inner

# right-align variable-length prefixes
def align_right(t, lens, pad_id = 0):
    batch, seq_len, device, dtype = *t.shape, t.device, t.dtype

    assert lens.ndim == 1 and lens.shape[0] == batch
    assert lens.amax() <= seq_len

    pad_lens = seq_len - lens
    max_pad_len = pad_lens.amax()

    batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
    prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)

    t = F.pad(t, (max_pad_len, 0), value = 0)
    offset = max_pad_len - pad_lens

    aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
    return aligned
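
A small worked example of `align_right`: with two prompts of lengths 2 and 3 in a batch padded to length 4, the real tokens end up flush against the right edge:

```py
import torch

t = torch.tensor([[5, 6, 0, 0],
                  [1, 2, 3, 0]])
lens = torch.tensor([2, 3])

aligned = align_right(t, lens)
# tensor([[0, 0, 5, 6],
#         [0, 1, 2, 3]])
```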

# nucleus (top-p) filtering function
def top_p(logits, thres = 0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending = True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)

    sorted_indices_to_remove = cum_probs > thres
    sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value = False)

    sorted_logits[sorted_indices_to_remove] = float('-inf')
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)

# top-k filtering function
def top_k(logits, frac_num_tokens = 0.1, k = None):
    num_tokens = logits.shape[-1]

    k = default(k, ceil(frac_num_tokens * num_tokens))
    k = min(k, num_tokens)

    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# top-a filtering function
def top_a(logits, min_p_pow = 2.0, min_p_ratio = 0.02):
    probs = F.softmax(logits, dim = -1)
    max_probs = torch.amax(probs, dim = -1, keepdim = True)
    limit = torch.pow(max_probs, min_p_pow) * min_p_ratio
    return torch.where(probs < limit, float('-inf'), logits)
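
These filters all share the same contract: take raw logits of shape `(batch, vocab)`, push the disallowed entries to `-inf`, and let the caller sample from what remains. A hedged sketch of a single sampling step:

```py
import torch
import torch.nn.functional as F

logits = torch.randn(1, 50000)               # (batch, vocab)

filtered = top_p(logits, thres = 0.9)        # or top_k(logits, k = 50) / top_a(logits)
probs = F.softmax(filtered / 1.0, dim = -1)  # 1.0 is the temperature

token = torch.multinomial(probs, 1)          # (1, 1) sampled token id
```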

# contrastive decoding function
def contrastive_decode_fn(
    expert_logits,
    amateur_logits,
    alpha = 0.1,
    beta = 0.5
):
    """
    Appendix A Algorithm 2
    https://arxiv.org/abs/2309.09117
    """

    cutoff = log(alpha) + expert_logits.amax(dim = -1, keepdim = True)
    diffs = (1 + beta) * expert_logits - beta * amateur_logits
    contrastive_decode_logits = diffs.masked_fill(expert_logits < cutoff, -torch.finfo(expert_logits.dtype).max)
    return contrastive_decode_logits
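
Per the referenced paper, expert tokens whose logit falls more than `log(alpha)` below the expert's maximum are pruned, and the survivors are rescored as `(1 + beta) * expert - beta * amateur`. A quick sketch:

```py
import torch

expert_logits = torch.randn(1, 100)
amateur_logits = torch.randn(1, 100)

out = contrastive_decode_fn(expert_logits, amateur_logits, alpha = 0.1, beta = 0.5)
# entries failing the expert cutoff become a large negative number; sample from the rest
```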

# autoregressive wrapper class
class AutoregressiveWrapper(Module):
    def __init__(
        self,
        net,
        ignore_index = -100,
        pad_value = 0,
        mask_prob = 0.,
        add_attn_z_loss = False
    ):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = ignore_index

        self.net = net
        self.max_seq_len = net.max_seq_len

        # paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432
        assert mask_prob < 1.
        self.mask_prob = mask_prob

        # whether to add the attention router z-loss
        self.add_attn_z_loss = add_attn_z_loss

    @torch.no_grad()
    @eval_decorator
    # generate a sequence of tokens
    def generate(
        self,
        prompts,  # the input prompts
        seq_len,  # length of the sequence to generate
        eos_token = None,  # end-of-sequence token
        temperature = 1.,  # sampling temperature
        prompt_lens: Optional[Tensor] = None,  # lengths of the prompts
        filter_logits_fn: Callable = top_k,  # function for filtering the logits
        restrict_to_max_seq_len = True,  # whether to restrict to the maximum sequence length
        amateur_model: Optional[Union[Module, Tuple[Module]]] = None,  # amateur model, for contrastive decoding
        filter_kwargs: dict = dict(),  # kwargs for the filter function
        contrastive_decode_kwargs: Union[dict, Tuple[dict]] = dict(  # contrastive decoding kwargs
            beta = 0.5,
            alpha = 0.1
        ),
        cache_kv = True,  # whether to cache key / values
        **kwargs  # remaining kwargs
    ):
        ...  # the body of the sampling loop is not included in this excerpt
    def forward(self, x, return_outputs = False, **kwargs):
        seq, ignore_index, add_attn_z_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss

        # input and target sequences
        inp, target = x[:, :-1], x[:, 1:]
        inp = torch.where(inp == ignore_index, self.pad_value, inp)

        # if mask_prob is set, apply masking
        if self.mask_prob > 0.:
            rand = torch.randn(inp.shape, device = x.device)
            rand[:, 0] = -torch.finfo(rand.dtype).max  # the first token should never be masked out
            num_mask = min(int(seq * self.mask_prob), seq - 1)
            indices = rand.topk(num_mask, dim = -1).indices
            mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()
            kwargs.update(self_attn_kv_mask = mask)

        # get the logits and the cache
        logits, cache = self.net(
            inp,
            return_intermediates = True,
            return_attn_z_loss = add_attn_z_loss,
            **kwargs
        )

        # compute the cross entropy loss
        loss = F.cross_entropy(
            rearrange(logits, 'b n c -> b c n'),
            target,
            ignore_index = ignore_index
        )

        # if the attention z-loss is enabled, add it on
        if add_attn_z_loss:
            loss = loss + cache.attn_z_loss

        # if outputs are not requested, return just the loss
        if not return_outputs:
            return loss

        # otherwise return the loss along with the outputs
        return loss, (logits, cache)
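
Training with the wrapper then reduces to feeding it token sequences; a minimal sketch following the repository README (`AutoregressiveWrapper` shifts inputs and targets internally):

```py
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)
wrapper = AutoregressiveWrapper(model)

seq = torch.randint(0, 20000, (1, 1024))
loss = wrapper(seq)   # cross entropy over next-token prediction
loss.backward()

# sampling: prompt of 1 token, generate 100 more
prompt = torch.randint(0, 20000, (1, 1))
sampled = wrapper.generate(prompt, 100)
```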

.\lucidrains\x-transformers\x_transformers\continuous.py

# import torch
import torch
# import the nn module from torch
from torch import nn
# import torch.nn.functional as F
import torch.nn.functional as F

# import the pack, repeat, unpack functions from einops
from einops import pack, repeat, unpack

# import the following classes and functions from x_transformers.x_transformers
from x_transformers.x_transformers import (
    AttentionLayers,
    ScaledSinusoidalEmbedding,
    AbsolutePositionalEmbedding,
    LayerNorm,
    always,
    pad_at_dim
)

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# main classes

# continuous transformer wrapper class
class ContinuousTransformerWrapper(nn.Module):
    def __init__(
        self,
        *,
        max_seq_len,
        attn_layers: AttentionLayers,
        dim_in = None,
        dim_out = None,
        emb_dim = None,
        max_mem_len = 0,
        num_memory_tokens = None,
        post_emb_norm = False,
        emb_dropout = 0.,
        use_abs_pos_emb = True,
        scaled_sinu_pos_emb = False
    ):
        super().__init__()
        dim = attn_layers.dim

        # maximum sequence length
        self.max_seq_len = max_seq_len

        # maximum memory length
        self.max_mem_len = max_mem_len

        # whether there is no absolute positional embedding
        no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb)

        if no_abs_pos_emb:
            # without an absolute positional embedding, the positional embedding is a constant 0
            self.pos_emb = always(0)
        elif scaled_sinu_pos_emb:
            # with scaled sinusoidal positional embeddings, create a ScaledSinusoidalEmbedding
            self.pos_emb = ScaledSinusoidalEmbedding(dim)
        else:
            # otherwise create an AbsolutePositionalEmbedding
            self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)

        # post-embedding layer norm
        self.post_emb_norm = LayerNorm(dim) if post_emb_norm else nn.Identity()
        # embedding dropout
        self.emb_dropout = nn.Dropout(emb_dropout)

        # memory tokens

        # the number of memory tokens defaults to 0
        num_memory_tokens = default(num_memory_tokens, 0)
        # whether there are memory tokens
        self.has_memory_tokens = num_memory_tokens > 0

        if num_memory_tokens > 0:
            # if there are memory tokens, create them as an nn.Parameter
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

        # attention layers

        self.attn_layers = attn_layers

        # project in and out

        # input projection
        self.project_in = nn.Linear(dim_in, dim, bias = False) if exists(dim_in) else nn.Identity()
        # output projection
        self.project_out = nn.Linear(dim, dim_out, bias = False) if exists(dim_out) else nn.Identity()

    def forward(
        self,
        x,
        return_embeddings = False,
        return_intermediates = False,
        return_mems = False,
        mask = None,
        return_attn = False,
        mems = None,
        mem_masks = None,
        pos = None,
        prepend_embeds = None,
        prepend_mask = None,
        **kwargs
    ):
        # unpack the shape of the input tensor x into batch, seq, device
        batch, seq, device = *x.shape[:2], x.device

        # project the input tensor
        x = self.project_in(x)
        # add the positional embedding
        x = x + self.pos_emb(x, pos = pos)

        # normalize the projected tensor
        x = self.post_emb_norm(x)

        # handle the memory tokens

        if self.has_memory_tokens:
            # repeat the memory tokens across the batch dimension
            m = repeat(self.memory_tokens, 'm d -> b m d', b = batch)
            # pack the memory tokens together with the input tensor
            x, mem_ps = pack([m, x], 'b * d')

            # if a mask exists, extend it to cover the memory tokens
            if exists(mask):
                num_mems = m.shape[-2]
                mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)

        # whether to prepend embeddings, as in PaLI with image embeddings

        if exists(prepend_embeds):
            prepend_seq, prepend_dim = prepend_embeds.shape[1:]

            # the prepended embedding dimension must match the model dimension
            assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'

            # concatenate along the sequence dimension
            x = torch.cat((prepend_embeds, x), dim = -2)

            # if prepend_mask or mask exists, combine the masks
            if exists(prepend_mask) or exists(mask):
                mask = default(mask, lambda: torch.ones((batch, seq), device = device, dtype = torch.bool))
                prepend_mask = default(prepend_mask, lambda: torch.ones((batch, prepend_seq), device = device, dtype = torch.bool))

                mask = torch.cat((prepend_mask, mask), dim = -1)

        # apply dropout to the embeddings
        x = self.emb_dropout(x)

        # attention layers

        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, return_hiddens = True, **kwargs)

        # splice out the memory tokens

        if self.has_memory_tokens:
            m, x = unpack(x, mem_ps, 'b * d')
            intermediates.memory_tokens = m

        # project the output
        out = self.project_out(x) if not return_embeddings else x

        # if intermediates were requested
        if return_intermediates:
            return out, intermediates

        # if memories were requested
        if return_mems:
            hiddens = intermediates.hiddens
            new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
            return out, new_mems

        # if attention maps were requested
        if return_attn:
            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
            return out, attn_maps

        return out

# continuous autoregressive wrapper class
class ContinuousAutoregressiveWrapper(nn.Module):
    def __init__(
        self,
        net: ContinuousTransformerWrapper,  # takes a ContinuousTransformerWrapper network
        ignore_index = -100,  # index to ignore, defaults to -100
        pad_value = 0,  # padding value, defaults to 0
        loss_fn = nn.MSELoss(reduction = 'none')  # loss function, defaults to unreduced mean squared error
    ):
        super().__init__()
        self.net = net  # store the network
        self.max_seq_len = net.max_seq_len  # grab the network's maximum sequence length
        self.loss_fn = loss_fn  # store the loss function

    # generation, with gradients disabled
    @torch.no_grad()
    def generate(self, start_tokens, seq_len, **kwargs):
        device = start_tokens.device  # device of the start tokens
        was_training = self.net.training  # remember whether the network was training
        num_dims = len(start_tokens.shape)  # number of dimensions of the start tokens

        assert num_dims >= 2, 'number of dimensions of your start tokens must be greater or equal to 2'

        if num_dims == 2:
            start_tokens = start_tokens[None, :]  # add a leading batch dimension if there are only 2 dimensions

        b, t, _, device = *start_tokens.shape, start_tokens.device  # shape and device of the start tokens

        self.net.eval()  # put the network into eval mode
        out = start_tokens  # initialize the output with the start tokens

        for _ in range(seq_len):
            x = out[:, -self.max_seq_len:]  # take the last self.max_seq_len tokens

            last = self.net(x, **kwargs)[:, -1:]  # predict the next token with the network
            out = torch.cat((out, last), dim = -2)  # append the prediction to the output

        out = out[:, t:]  # strip off the start tokens

        if num_dims == 2:
            out = out.squeeze(0)  # remove the batch dimension if the input had 2 dimensions

        self.net.train(was_training)  # restore the network's training state
        return out  # return the generated sequence

    # forward pass
    def forward(self, x, **kwargs):
        inp, target = x[:, :-1], x[:, 1:]  # input and target sequences

        assert 'prepend_embeds' not in kwargs  # 'prepend_embeds' must not be passed in

        mask = kwargs.get('mask', None)  # grab the mask, None if absent
        if exists(mask) and mask.shape[1] == x.shape[1]:  # if the mask exists and covers the full sequence length
            mask = mask[:, :-1]  # drop the mask for the last token
            kwargs['mask'] = mask  # update the mask in kwargs

        out = self.net(inp, **kwargs)  # run the network

        loss = self.loss_fn(out, target)  # compute the loss

        if exists(mask):  # if a mask exists
            assert loss.ndim > 1, 'loss should not be reduced if mask is passed in'
            loss = loss[mask]  # select the loss at unmasked positions

        return loss.mean()  # return the mean loss
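
A minimal sketch of the continuous wrapper on raw feature vectors instead of token ids, following the README's continuous example (dimensions are illustrative):

```py
import torch
from x_transformers import ContinuousTransformerWrapper, Decoder
from x_transformers.continuous import ContinuousAutoregressiveWrapper

model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

x = torch.randn(1, 1024, 32)
out = model(x)  # (1, 1024, 32)

# autoregressive training on continuous sequences
wrapper = ContinuousAutoregressiveWrapper(model)
loss = wrapper(torch.randn(1, 1024, 32))  # MSE against the shifted targets
loss.backward()
```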

.\lucidrains\x-transformers\x_transformers\dpo.py

# import the necessary libraries
from copy import deepcopy
import torch
from torch.nn import Module
import torch.nn.functional as F
from x_transformers.x_transformers import TransformerWrapper
from einops import rearrange

# helper functions

# check whether a value exists
def exists(v):
    return v is not None

# freeze all layers of a model
def freeze_all_layers_(module):
    for param in module.parameters():
        param.requires_grad = False

# get the log probabilities of a sequence from a model
def log_prob_from_model_and_seq(model, seq):
    logits = model(seq)
    log_prob = logits.log_softmax(dim = -1)
    indices = rearrange(seq, '... -> ... 1')
    log_probs = log_prob.gather(-1, indices)
    return rearrange(log_probs, '... 1 -> ...')

# compute a masked mean
def masked_mean(log_probs, mask = None):
    if not exists(mask):
        return log_probs.mean(dim = -1)

    log_probs = log_probs.masked_fill(~mask, 0.)
    num = log_probs.sum(dim = -1)
    den = mask.sum(dim = -1)
    return num / den.clamp(min = 1e-5)

# logical-and together any masks that exist
def maybe_and_mask(*masks):
    masks = [*filter(exists, masks)]
    if len(masks) == 0:
        return None

    mask, *rest_masks = masks
    for rest_mask in rest_masks:
        mask = mask & rest_mask

    return mask

# main class

class DPO(Module):
    def __init__(
        self,
        model: TransformerWrapper,
        *,
        beta = 0.1,
        pad_id = None
    ):
        super().__init__()
        self.policy_model = model

        self.ref_model = deepcopy(model)
        freeze_all_layers_(self.ref_model)

        self.beta = beta
        self.pad_id = pad_id

    def parameters(self):
        return self.policy_model.parameters()

    def forward(
        self,
        preferred_seq,
        unpreferred_seq,
        *,
        prompt_mask,
        preferred_seq_mask = None,
        unpreferred_seq_mask = None,
    ):
        assert preferred_seq.ndim == 2
        assert preferred_seq.shape == unpreferred_seq.shape

        if exists(self.pad_id):
            if not exists(preferred_seq_mask):
                preferred_seq_mask = preferred_seq != self.pad_id

            if not exists(unpreferred_seq_mask):
                unpreferred_seq_mask = unpreferred_seq != self.pad_id

        """
        Following Appendix B in https://arxiv.org/abs/2305.18290
        """

        with torch.no_grad():
            self.ref_model.eval()
            ref_preferred_logprob = log_prob_from_model_and_seq(self.ref_model, preferred_seq)
            ref_unpreferred_logprob = log_prob_from_model_and_seq(self.ref_model, unpreferred_seq)

        policy_preferred_logprob = log_prob_from_model_and_seq(self.policy_model, preferred_seq)
        policy_unpreferred_logprob = log_prob_from_model_and_seq(self.policy_model, unpreferred_seq)

        # masked mean of the log probabilities

        preferred_seq_mask = maybe_and_mask(~prompt_mask, preferred_seq_mask)
        unpreferred_seq_mask = maybe_and_mask(~prompt_mask, unpreferred_seq_mask)

        ref_preferred_logprob, policy_preferred_logprob = map(lambda t: masked_mean(t, preferred_seq_mask), (ref_preferred_logprob, policy_preferred_logprob))
        ref_unpreferred_logprob, policy_unpreferred_logprob = map(lambda t: masked_mean(t, unpreferred_seq_mask), (ref_unpreferred_logprob, policy_unpreferred_logprob))

        # main DPO formula

        policy_logratios = policy_preferred_logprob - policy_unpreferred_logprob
        ref_logratios = ref_preferred_logprob - ref_unpreferred_logprob

        losses = -F.logsigmoid(self.beta * (policy_logratios - ref_logratios))

        return losses.mean()
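
A hedged sketch of a DPO training step over one preference pair; the `prompt_mask` is `True` over prompt positions so that only response tokens contribute to the log-probabilities:

```py
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.dpo import DPO

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 512,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

dpo = DPO(model, beta = 0.1, pad_id = 0)

preferred   = torch.randint(1, 20000, (2, 512))
unpreferred = torch.randint(1, 20000, (2, 512))
prompt_mask = torch.zeros(2, 512, dtype = torch.bool)
prompt_mask[:, :100] = True  # first 100 tokens are the prompt

loss = dpo(preferred, unpreferred, prompt_mask = prompt_mask)
loss.backward()
```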

.\lucidrains\x-transformers\x_transformers\nonautoregressive_wrapper.py

import math
from random import random
from contextlib import nullcontext
from collections import namedtuple

import torch
import torch.nn.functional as F
from torch import nn

from einops import rearrange, repeat, pack, unpack

from x_transformers.x_transformers import TransformerWrapper
from typing import Optional

# the Losses namedtuple has loss, generator_loss, and critic_loss fields
Losses = namedtuple('Losses', ['loss', 'generator_loss', 'critic_loss'])

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# sampling helpers

# keep only the top-k probabilities of the logits
def top_k(logits, thres = 0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = logits.topk(k, dim = -1)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(2, ind, val)
    return probs

# numerically safe log
def log(t, eps = 1e-10):
    return torch.log(t + eps)

# generate gumbel noise
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# sample with gumbel noise
def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)

# probability helpers

# sample a boolean with the given probability
def sample_prob(prob):
    return random() < prob

# coin flip, returns True or False
def coin_flip():
    return sample_prob(0.5)

# tensor helpers

# get a subset mask from a mask, according to a probability
def get_mask_subset_prob(mask, prob, min_mask = 0):
    batch, seq, device = *mask.shape, mask.device
    num_to_mask = (mask.sum(dim = -1, keepdim = True) * prob).clamp(min = min_mask)
    logits = torch.rand((batch, seq), device = device)
    logits = logits.masked_fill(~mask, -1)

    randperm = logits.argsort(dim = -1).argsort(dim = -1).float()

    num_padding = (~mask).sum(dim = -1, keepdim = True)
    randperm -= num_padding

    subset_mask = randperm < num_to_mask
    subset_mask.masked_fill_(~mask, False)
    return subset_mask

# schedule functions

# linear schedule
def linear_schedule(t):
    return 1 - t

# cosine schedule
def cosine_schedule(t):
    """ https://arxiv.org/abs/2202.04200 """
    return torch.cos(t * math.pi / 2)
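
The schedule maps a time `t` in `[0, 1]` to the fraction of tokens still masked; for example, with 10 generation steps the cosine schedule yields the per-step mask fractions below:

```py
import math
import torch

times = torch.linspace(0., 1., 10 + 1)
frac_masked = torch.cos(times[1:] * math.pi / 2)
# tensor([0.9877, 0.9511, 0.8910, 0.8090, 0.7071,
#         0.5878, 0.4540, 0.3090, 0.1564, 0.0000])
```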

# self-token critic class
# inspired by Nijkamp et al. - https://aclanthology.org/2021.naacl-main.409/

class SelfCritic(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net

        dim = net.attn_layers.dim
        self.to_logits = nn.Linear(dim, 1)

    def forward(self, x):
        embed = self.net(x, return_embeddings = True)
        return self.to_logits(embed)

# non-autoregressive wrapper class
# see https://arxiv.org/abs/1904.09324 and https://arxiv.org/abs/2202.04200

class NonAutoregressiveWrapper(nn.Module):
    def __init__(
        self,
        net,
        *,
        mask_id,
        steps = 18,
        self_cond = False,
        self_cond_train_prob = 0.75,
        no_replace_prob = 0.15,          # what percentage of tokens are kept unchanged, as done in the original MLM paper
        random_token_prob = 0.1,         # what percentage of tokens are replaced with a random token, as done in the original MLM paper
        schedule = 'linear',
        can_mask_prev_unmasked = False,  # when unmasking, whether previously unmasked tokens can be remasked
        token_critic: Optional[TransformerWrapper] = None,
        self_token_critic = False,
        critic_loss_weight = 1.
    ):
        # call the parent constructor
        super().__init__()
        # assert that self_token_critic and an external token_critic are not both given
        assert not (self_token_critic and exists(token_critic))

        # store the network
        self.net = net

        # grab the embedding dimension and vocabulary size
        dim = net.emb_dim
        self.dim = dim
        self.num_tokens = net.num_tokens

        # id of the mask token
        self.mask_id = mask_id

        # probabilities for keeping tokens unchanged and for random replacement
        self.no_replace_prob = no_replace_prob
        self.random_token_prob = random_token_prob

        # maximum sequence length and number of steps
        self.max_seq_len = net.max_seq_len
        self.steps = steps

        # pick the schedule function according to the schedule argument
        if callable(schedule):
            self.schedule_fn = schedule
        elif schedule == 'linear':
            self.schedule_fn = linear_schedule
        elif schedule == 'cosine':
            self.schedule_fn = cosine_schedule
        else:
            raise ValueError(f'invalid schedule {schedule}')

        # whether previously unmasked tokens can be masked again
        self.can_mask_prev_unmasked = can_mask_prev_unmasked

        # self-conditioning
        self.self_cond = self_cond

        # if self-conditioning, initialize the null embedding and a linear projection
        if self_cond:
            self.null_embed = nn.Parameter(torch.randn(dim))
            self.to_self_cond = nn.Linear(dim, dim, bias=False) if self_cond else None
            self.self_cond_train_prob = self_cond_train_prob

        # token critic
        self.token_critic = token_critic

        # if self_token_critic, instantiate a SelfCritic over the network
        if self_token_critic:
            self.token_critic = SelfCritic(net)

        # weight of the critic loss
        self.critic_loss_weight = critic_loss_weight

    # generation, with gradients disabled
    @torch.no_grad()
    def generate(
        self,
        batch_size=None,
        start_temperature=1.,
        filter_thres=0.7,
        noise_level_scale=1.,
        **kwargs
    ):
        # check whether batch_size was given; if not, sample a single sequence
        sample_one = not exists(batch_size)
        batch_size = default(batch_size, 1)

        # device of the network parameters
        device = next(self.net.parameters()).device

        # remember the training state and switch to eval mode
        was_training = self.training
        self.eval()

        # generate self.steps + 1 time points between 0 and 1
        times = torch.linspace(0., 1., self.steps + 1)

        # initialize the sequence to all mask_id and the mask to all True
        shape = (batch_size, self.max_seq_len)
        seq = torch.full(shape, self.mask_id, device=device)
        mask = torch.full(shape, True, device=device)

        # number of tokens that remain masked at each step
        all_mask_num_tokens = (self.schedule_fn(times[1:]) * self.max_seq_len).long()

        # whether self-conditioning is used
        has_self_cond = self.self_cond
        last_embed = self.null_embed if has_self_cond else None

        # progressively unmask
        for mask_num_tokens, steps_until_x0 in zip(all_mask_num_tokens.tolist(), reversed(range(self.steps))):

            # compute the self-conditioning, if enabled
            self_cond = self.to_self_cond(last_embed) if has_self_cond else None

            # get the logits and embeddings from the network
            logits, embeds = self.net(
                seq,
                sum_embeds=self_cond,
                return_logits_and_embeddings=True,
                **kwargs
            )

            # update last_embed for the next round of self-conditioning
            if has_self_cond:
                last_embed = embeds

            # if filter_thres is given, apply top_k filtering to the logits
            if exists(filter_thres):
                logits = top_k(logits, filter_thres)

            # compute the annealed temperature and probabilities
            annealing_scale = steps_until_x0 / self.steps
            temperature = start_temperature * annealing_scale
            probs = (logits / max(temperature, 1e-3)).softmax(dim=-1)

            # sample token ids from the logits
            sampled_ids = gumbel_sample(logits, temperature=max(temperature, 1e-3))

            # fill in the currently masked positions with the sampled ids
            seq = torch.where(mask, sampled_ids, seq)

            # if a token critic exists, use it to score the tokens
            if exists(self.token_critic):
                scores = self.token_critic(seq)
                scores = rearrange(scores, 'b n 1 -> b n')
                scores = scores + noise_level_scale * gumbel_noise(scores) * annealing_scale
            else:
                scores = 1 - logits.softmax(dim=-1)
                scores = scores.gather(2, rearrange(sampled_ids, 'b n -> b n 1'))
                scores = rearrange(scores, 'b n 1 -> b n')

            # when mask_num_tokens is 0, nothing gets remasked
            if mask_num_tokens == 0:
                pass

            # if previously unmasked tokens may not be remasked, set their scores to the minimum
            if not self.can_mask_prev_unmasked:
                scores = scores.masked_fill(~mask, -torch.finfo(scores.dtype).max)

            # remask the top-scoring positions
            mask_indices = scores.topk(mask_num_tokens, dim=-1).indices
            mask = torch.zeros_like(scores, dtype=torch.bool).scatter(1, mask_indices, True)
            seq = seq.masked_fill(mask, self.mask_id)

        # restore the training state
        self.train(was_training)

        # if a single sample was requested, remove the batch dimension
        if sample_one:
            seq = rearrange(seq, '1 n -> n')

        # return the generated sequence
        return seq

    # forward pass
    def forward(
        self,
        x,
        only_train_generator=False,
        only_train_critic=False,
        generator_sample_temperature=None,
        **kwargs
    ):
        # shape of the input tensor x and its device
        b, n, device = *x.shape, x.device
        # the sequence length n must equal self.max_seq_len
        assert n == self.max_seq_len

        # keep a copy of the original sequence
        orig_seq = x.clone()

        # sample random times in [0, 1]
        rand_times = torch.empty(b, device = device).uniform_(0, 1)
        # random permutation indices per batch element
        batched_randperm = torch.rand((b, n), device = device).argsort(dim = -1).float()

        # get the masking probabilities from the schedule function
        rand_probs = self.schedule_fn(rand_times)
        # number of tokens to mask per sample
        num_tokens_mask = (rand_probs * n).clamp(min = 1.)
        # build the mask for randomly masking tokens
        mask = batched_randperm < rearrange(num_tokens_mask, 'b -> b 1')

        # to ensure all tokens produce embeddings, instead of just the [mask] input ids, as done in the classic BERT MLM paper
        # may be needed for self-conditioning (on the embeddings) to work well
        replace_mask_id_mask = mask.clone()
        frac_seq_left = 1.

        # if no_replace_prob is positive and the coin flip succeeds
        if self.no_replace_prob > 0. and coin_flip():
            frac_seq_left -= self.no_replace_prob

            # mask of tokens that are kept unchanged
            no_replace_prob_mask = get_mask_subset_prob(mask, self.no_replace_prob)
            replace_mask_id_mask &= ~no_replace_prob_mask

        # if random_token_prob is positive and the coin flip succeeds
        if self.random_token_prob > 0. and coin_flip():
            # mask of tokens to replace with random tokens
            random_token_prob_mask = get_mask_subset_prob(replace_mask_id_mask, self.random_token_prob * frac_seq_left)
            # generate the random tokens
            random_tokens = torch.randint(0, self.num_tokens, (b, n), device = device)

            # replace the tokens according to random_token_prob_mask
            x = torch.where(random_token_prob_mask, random_tokens, x)
            replace_mask_id_mask &= ~random_token_prob_mask

        # apply the masking, replacing tokens with self.mask_id according to replace_mask_id_mask
        masked = torch.where(replace_mask_id_mask, self.mask_id, x)

        # self conditioning

        # if self-conditioning is enabled
        if self.self_cond:
            self_cond = self.null_embed

            # with probability self_cond_train_prob
            if sample_prob(self.self_cond_train_prob):
                with torch.no_grad():
                    # get the self-conditioning embeddings from the network
                    self_cond = self.net(masked, return_embeddings = True, **kwargs).detach()

            # update kwargs with the sum_embeds
            kwargs.update(sum_embeds = self.to_self_cond(self_cond))

        # logits

        # decide the context depending on only_train_critic
        context = torch.no_grad if only_train_critic else nullcontext

        with context():
            # get the logits
            logits = self.net(masked, **kwargs)

        # cross entropy loss
        loss = F.cross_entropy(
            logits[mask],
            orig_seq[mask]
        )

        # if no token critic exists, or only the generator is being trained
        if not exists(self.token_critic) or only_train_generator:
            return Losses(loss, loss, None)

        # sample the generated tokens
        sampled_ids = gumbel_sample(logits, temperature = default(generator_sample_temperature, random()))
        generated = torch.where(mask, sampled_ids, orig_seq)

        # get the critic logits and labels
        critic_logits = self.token_critic(generated)
        critic_labels = (sampled_ids != orig_seq).float()

        # critic loss
        critic_loss = F.binary_cross_entropy_with_logits(
            rearrange(critic_logits, '... 1 -> ...'),
            critic_labels
        )

        # decide which losses to return, depending on what is being trained
        if only_train_critic:
            total_loss = critic_loss
            loss = None
        else:
            total_loss = loss + critic_loss * self.critic_loss_weight

        return Losses(total_loss, loss, critic_loss)
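
A hedged sketch of wiring the wrapper up with an encoder, reserving one extra vocabulary id for the mask token (token ids and dimensions are illustrative):

```py
import torch
from x_transformers import TransformerWrapper, Encoder
from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 256 + 1,      # one extra id reserved for [mask]
    max_seq_len = 256,
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
)

nar = NonAutoregressiveWrapper(
    model,
    mask_id = 256,             # the reserved id
    steps = 18,
    schedule = 'cosine'
)

seq = torch.randint(0, 256, (2, 256))
losses = nar(seq)
losses.loss.backward()

sampled = nar.generate()       # (256,) - one fully unmasked sequence
```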

.\lucidrains\x-transformers\x_transformers\xl_autoregressive_wrapper.py

# import ceil from the math module
from math import ceil

# import torch and related submodules
import torch
from torch import nn
import torch.nn.functional as F

# import rearrange, pack, unpack from einops
from einops import rearrange, pack, unpack
# import top_p, top_k, eval_decorator from x_transformers.autoregressive_wrapper
from x_transformers.autoregressive_wrapper import top_p, top_k, eval_decorator

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# check whether one number is divisible by another
def divisible_by(numer, denom):
    return (numer % denom) == 0

# xl autoregressive wrapper class

class XLAutoregressiveWrapper(nn.Module):
    def __init__(
        self,
        net,
        ignore_index = -100,
        pad_value = 0
    ):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = ignore_index

        self.net = net
        self.max_seq_len = net.max_seq_len

    # sequence generation, decorated with torch.no_grad() and eval_decorator
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        start_tokens,
        seq_len,
        eos_token = None,
        temperature = 1.,
        filter_logits_fn = top_k,
        filter_thres = 0.9,
        mems = None,
        **kwargs
    ):
        device, max_seq_len = start_tokens.device, self.max_seq_len

        start_tokens, ps = pack([start_tokens], '* n')

        b, t = start_tokens.shape

        *all_leading_tokens, _ = start_tokens.split(max_seq_len, dim = -1)

        # catch the memories up to the current segment

        for leading_tokens in all_leading_tokens:
            _, mems = self.net(
                leading_tokens,
                mems = mems,
                return_mems = True,
                **kwargs
            )

        # now start sampling from the current segment

        curr_pos = len(all_leading_tokens) * max_seq_len
        curr_mems = mems

        cache = None
        out = start_tokens

        for _ in range(seq_len):
            curr_segment_len = out.shape[-1]
            is_last_segment_tokens = divisible_by(curr_segment_len, max_seq_len)

            x = out[:, curr_pos:]

            logits, cache = self.net(
                x,
                mems = curr_mems,
                cache = cache,
                return_mems = True,
                return_intermediates = True,
                **kwargs
            )

            mems = cache.mems

            logits = logits[:, -1]
            filtered_logits = filter_logits_fn(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)

            sample = torch.multinomial(probs, 1)

            if is_last_segment_tokens:
                curr_pos = curr_segment_len
                curr_mems = mems

            out = torch.cat((out, sample), dim=-1)

            if exists(eos_token):
                is_eos_tokens = (out == eos_token)

                if is_eos_tokens.any(dim = -1).all():
                    # mask out everything after the eos token
                    shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
                    mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
                    out = out.masked_fill(mask, self.pad_value)
                    break

        out = out[:, t:]

        out, = unpack(out, ps, '* n')

        return out

    def forward(
        self,
        x,
        mems = None,
        **kwargs
    ):
        # grab ignore_index and max_seq_len from self
        ignore_index, max_seq_len = self.ignore_index, self.max_seq_len

        # the labels are x shifted by one position
        x, labels = x[:, :-1], x[:, 1:]

        # sequence length of x
        seq_len = x.shape[1]

        # prepare chunks

        # split x and labels into chunks of max_seq_len
        split_x = x.split(max_seq_len, dim = -1)
        split_labels = labels.split(max_seq_len, dim = -1)
        # loss weight for each chunk
        loss_weights = tuple(map(lambda t: t.shape[-1] / seq_len, split_x))

        # go through each chunk and derive the weighted losses

        # initialize the total loss
        total_loss = 0.

        for chunk, chunk_labels, loss_weight in zip(split_x, split_labels, loss_weights):

            # feed the current chunk through the network, getting the logits and the memories mems
            logits, mems = self.net(
                chunk,
                mems = mems,
                return_mems = True,
                **kwargs
            )

            # compute the cross entropy loss
            loss = F.cross_entropy(
                rearrange(logits, 'b n c -> b c n'),
                chunk_labels,
                ignore_index = ignore_index
            )

            # accumulate the weighted loss
            total_loss = total_loss + loss * loss_weight

        # return the total loss
        return total_loss
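
A hedged sketch of training over a sequence longer than the model's window: the wrapper splits it into `max_seq_len` segments and carries the memories forward (the `max_mem_len` on the wrapped model is an assumption needed for memories to be retained):

```py
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.xl_autoregressive_wrapper import XLAutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 512,
    max_mem_len = 512,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8, rotary_pos_emb = True)
)

xl_wrapper = XLAutoregressiveWrapper(model)

seq = torch.randint(0, 20000, (1, 2048))  # 4 segments of 512
loss = xl_wrapper(seq)                    # loss is length-weighted across segments
loss.backward()
```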

.\lucidrains\x-transformers\x_transformers\xval.py

"""
定义了一个基于离散标记的常规变换器,但对于数字是连续的
更好地泛化了算术
https://arxiv.org/abs/2310.02989
"""

# 导入所需的库
import torch
from torch import nn, Tensor
import torch.nn.functional as F

from typing import Callable
from collections import namedtuple

from einops import rearrange
from einops.layers.torch import Rearrange

from x_transformers.x_transformers import (
    AttentionLayers,
    TokenEmbedding,
    ScaledSinusoidalEmbedding,
    AbsolutePositionalEmbedding
)

from x_transformers.autoregressive_wrapper import (
    top_k,
    top_p
)

# 常量

# 定义一个命名元组,用于表示损失的细分
LossBreakdown = namedtuple('LossBreakdown', ['cross_entropy_loss', 'numerical_mse_loss'])

# 定义一个命名元组,用于表示生成的返回结果
GenerateReturn = namedtuple('GenerateReturn', ['sampled_token_ids', 'sampled_numbers', 'is_number_mask'])

# 辅助函数

# 判断变量是否存在
def exists(val):
    return val is not None

# 返回默认值
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# main classes

class XValTransformerWrapper(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        max_seq_len,
        numerical_token_id,
        attn_layers: AttentionLayers,
        emb_dim = None,
        logits_dim = None,
        tie_embedding = False,
        max_mem_len = 0,
        num_memory_tokens = None,
        emb_dropout = 0.,
        use_abs_pos_emb = True,
        scaled_sinu_pos_emb = False
    ):
        super().__init__()
        dim = attn_layers.dim
        emb_dim = default(emb_dim, dim)

        self.emb_dim = emb_dim
        self.token_emb = TokenEmbedding(emb_dim, num_tokens)

        self.numerical_token_id = numerical_token_id

        self.max_seq_len = max_seq_len

        self.max_mem_len = max_mem_len

        if not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb):
            self.pos_emb = always(0)  # no absolute positional embedding (or it is disabled), so use a constant 0
        elif scaled_sinu_pos_emb:
            self.pos_emb = ScaledSinusoidalEmbedding(dim)  # use scaled sinusoidal positional embeddings
        else:
            self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)  # otherwise use absolute positional embeddings

        self.emb_dropout = nn.Dropout(emb_dropout)

        # memory tokens

        num_memory_tokens = default(num_memory_tokens, 0)
        self.has_memory_tokens = num_memory_tokens > 0

        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))  # initialize the memory tokens

        # attention layers

        self.attn_layers = attn_layers

        # to logits

        logits_dim = default(logits_dim, num_tokens)
        self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()

        self.to_numerical_output = nn.Sequential(
            nn.Linear(dim, 1),
            Rearrange('... 1 -> ...')
        )

    def forward(
        self,
        x: Tensor,
        x_num: Tensor,
        return_embeddings = False,
        return_intermediates = False,
        return_mems = False,
        mask = None,
        return_attn = False,
        mems = None,
        pos = None,
        prepend_embeds = None,
        **kwargs
    ):
        # the shape of the input tensor x must match x_num
        assert x.shape == x_num.shape

        # batch size
        batch = x.shape[0]

        # mask of which positions are the numerical token
        is_number_mask = x == self.numerical_token_id

        # token embedding of the input
        x = self.token_emb(x)

        # at numerical token positions, scale the embedding by the number's value
        scale = torch.where(is_number_mask, x_num, 1.)
        # add a trailing dimension
        scale = rearrange(scale, '... -> ... 1')

        # scale the embeddings
        x = x * scale

        # add the positional embedding
        x = x + self.pos_emb(x, pos = pos)

        # memory tokens

        if self.has_memory_tokens:
            # repeat the memory tokens across the batch dimension
            m = repeat(self.memory_tokens, 'm d -> b m d', b = batch)
            # pack the memory tokens together with the input
            x, mem_ps = pack([m, x], 'b * d')

            if exists(mask):
                num_mems = m.shape[-2]
                # pad the mask to cover the memory tokens
                mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)

        # whether to prepend embeddings, as in PaLI with image embeddings
        if exists(prepend_embeds):
            _, prepend_dim = prepend_embeds.shape[1:]
            # the prepended embedding dimension must match the model dimension
            assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'

            # concatenate along the sequence dimension
            x = torch.cat((prepend_embeds, x), dim = -2)

        # embedding dropout
        x = self.emb_dropout(x)

        # attention layers

        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)

        # splice out the memory tokens

        if self.has_memory_tokens:
            m, x = unpack(x, mem_ps, 'b * d')
            intermediates.memory_tokens = m

        # if embeddings are not requested, produce the logits and the numerical predictions
        if not return_embeddings:
            logits = self.to_logits(x)
            numerical_pred = self.to_numerical_output(x)
            out = (logits, numerical_pred)
        else:
            out = x

        # if intermediates were requested
        if return_intermediates:
            return out, intermediates

        # if memories were requested
        if return_mems:
            hiddens = intermediates.hiddens
            new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
            return out, new_mems

        # if attention maps were requested
        if return_attn:
            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
            return out, attn_maps

        return out

class XValAutoregressiveWrapper(nn.Module):
    def __init__(
        self,
        net: XValTransformerWrapper,
        ignore_index = -100,
        pad_value = 0,
        numerical_loss_weight = 1.
    ):
        # takes the network net, plus ignore_index, pad_value and numerical_loss_weight
        super().__init__()
        self.net = net  # store the network
        self.max_seq_len = net.max_seq_len  # grab the network's maximum sequence length
        self.numerical_loss_weight = numerical_loss_weight  # weight of the numerical loss
        self.ignore_index = ignore_index  # index to ignore

    @torch.no_grad()
    def generate(
        self,
        start_tokens: Tensor,
        start_numbers: Tensor,
        seq_len,
        filter_logits_fn: Callable = top_k,
        filter_kwargs: dict = dict(),
        temperature = 1.,
        **kwargs
    ):
        # generation, given the start tokens, start numbers, and a sequence length
        device = start_tokens.device  # device of the start tokens
        was_training = self.net.training  # remember whether the network was training
        num_dims = len(start_tokens.shape)  # number of dimensions of the start tokens

        assert num_dims >= 2, 'number of dimensions of your start tokens must be greater or equal to 2'
        assert start_tokens.shape == start_numbers.shape  # the start tokens and start numbers must have the same shape

        b, t, device = *start_tokens.shape, start_tokens.device  # shape and device of the start tokens
        self.net.eval()  # put the network into eval mode
        out = start_tokens
        num_out = start_numbers  # initialize the token and number outputs

        for _ in range(seq_len):
            # take the last max_seq_len tokens and numbers
            x = out[:, -self.max_seq_len:]
            x_num = num_out[:, -self.max_seq_len:]

            # run the network to get the logits and the numerical predictions
            logits, numerical_pred = self.net(x, x_num, **kwargs)

            # take the last logits and the last numerical prediction
            last_logits = logits[:, -1]
            last_num_pred = numerical_pred[:, -1:]

            # filter the logits
            filtered_logits = filter_logits_fn(last_logits, **filter_kwargs)

            # softmax probabilities
            probs = F.softmax(filtered_logits / temperature, dim=-1)

            # sample a token from the distribution
            sample = torch.multinomial(probs, 1)

            # append the new token and number to the outputs
            out = torch.cat((out, sample), dim = -1)
            num_out = torch.cat((num_out, last_num_pred), dim = -1)

        # strip off the start tokens
        out = out[:, t:]
        num_out = num_out[:, t:]

        # which outputs are the numerical token
        is_number = out == self.net.numerical_token_id
        # set the numbers at non-numerical positions to NaN
        num_out = torch.where(is_number, num_out, float('nan'))

        # restore the network's training state, then return the generated tokens and numbers
        self.net.train(was_training)
        return GenerateReturn(out, num_out, is_number)

    def forward(
        self,
        x: Tensor,
        x_num: Tensor,
        return_loss_breakdown = False,
        **kwargs
    ):
        # forward pass, taking the token input x, the numerical input x_num, and remaining kwargs
        inp, target = x[:, :-1], x[:, 1:]  # input and target sequences
        x_num_inp, x_num_target = x_num[:, :-1], x_num[:, 1:]  # numerical input and numerical target

        # handle the mask
        mask = kwargs.get('mask', None)
        if exists(mask) and mask.shape[1] == x.shape[1]:
            mask = mask[:, :-1]
            kwargs['mask'] = mask

        # run the network
        logits, numerical_pred = self.net(inp, x_num_inp, **kwargs)

        # rearrange the logits for cross entropy
        logits = rearrange(logits, 'b n c -> b c n')

        # cross entropy loss
        cross_entropy_loss = F.cross_entropy(logits, target, reduction = 'none', ignore_index = self.ignore_index)

        # target mask
        target_mask = target != self.ignore_index

        # numerical mean squared error loss
        numerical_mse_loss = F.mse_loss(numerical_pred, x_num_target, reduction = 'none')

        # zero out the numerical loss at ignored positions
        numerical_mse_loss = numerical_mse_loss * target_mask

        # total loss
        loss = cross_entropy_loss + numerical_mse_loss * self.numerical_loss_weight

        # select the loss at unmasked positions
        if exists(mask):
            loss = loss[mask]

        # mean loss
        loss = loss.mean()

        # without a breakdown requested, return just the total loss
        if not return_loss_breakdown:
            return loss

        # return the total loss plus the loss breakdown
        return loss, LossBreakdown(cross_entropy_loss, numerical_mse_loss)
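
A hedged sketch of the xVal setup: a tiny vocabulary where one id is reserved as the numerical token, whose embedding gets scaled by the actual number (ids and dimensions are illustrative):

```py
import torch
from x_transformers import Decoder
from x_transformers.xval import XValTransformerWrapper, XValAutoregressiveWrapper

model = XValTransformerWrapper(
    num_tokens = 4,
    numerical_token_id = 3,     # id 3 stands in for "a number goes here"
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

wrapper = XValAutoregressiveWrapper(model)

ids  = torch.randint(0, 4, (1, 1024))
nums = torch.randn(1, 1024)     # only read where ids == numerical_token_id

loss = wrapper(ids, nums)
loss.backward()
```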

.\lucidrains\x-transformers\x_transformers\x_transformers.py

# import the math library
import math
# import the random function from the random module
from random import random
# import the Dict type hint from typing
from typing import Dict
# import version information from packaging
from packaging import version

# import torch
import torch
# import nn, einsum, Tensor from torch
from torch import nn, einsum, Tensor
# import the functional module from torch.nn
import torch.nn.functional as F
# import autocast from torch.cuda.amp
from torch.cuda.amp import autocast

# import partial, wraps from functools
from functools import partial, wraps
# import namedtuple from collections
from collections import namedtuple
# import the dataclass decorator from dataclasses
from dataclasses import dataclass
# import the List, Callable, Optional, Union type hints from typing
from typing import List, Callable, Optional, Union

# import rearrange, repeat, reduce, pack, unpack from einops
from einops import rearrange, repeat, reduce, pack, unpack
# import the Rearrange layer from einops.layers.torch
from einops.layers.torch import Rearrange

# import the Attend, Intermediates classes from x_transformers.attend
from x_transformers.attend import Attend, Intermediates
# import the AutoregressiveWrapper class from x_transformers.autoregressive_wrapper
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

# constants

# default head dimension
DEFAULT_DIM_HEAD = 64

# the LayerIntermediates dataclass
@dataclass
class LayerIntermediates:
    hiddens:            Optional[List[Tensor]] = None   # all hiddens, before the final norm (in a pre-norm architecture)
    last_hidden:        Optional[Tensor] = None         # very last hidden after all attention layers, after the final norm
    attn_intermediates: Optional[List[Intermediates]] = None
    layer_hiddens:      Optional[List[Tensor]] = None
    attn_z_loss:        Optional[Tensor] = None
    mems:               Optional[Tensor] = None
    memory_tokens:      Optional[Tensor] = None
# helper functions

# check whether a value is not None
def exists(val):
    return val is not None

# return val if it exists, otherwise the default (calling it if callable)
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# cast a value to a tuple of the given depth
def cast_tuple(val, depth):
    return val if isinstance(val, tuple) else (val,) * depth

# check divisibility
def divisible_by(num, den):
    return (num % den) == 0

# apply fn only if the input exists
def maybe(fn):
    @wraps(fn)
    def inner(x, *args, **kwargs):
        if not exists(x):
            return x
        return fn(x, *args, **kwargs)
    return inner

# at most one of the given booleans may be True
def at_most_one_of(*bools):
    return sum(map(int, bools)) <= 1

# callable that always returns the same value
class always():
    def __init__(self, val):
        self.val = val
    def __call__(self, *args, **kwargs):
        return self.val

# callable testing inequality against a fixed value
class not_equals():
    def __init__(self, val):
        self.val = val
    def __call__(self, x, *args, **kwargs):
        return x != self.val

# callable testing equality against a fixed value
class equals():
    def __init__(self, val):
        self.val = val
    def __call__(self, x, *args, **kwargs):
        return x == self.val

# build an nn.Sequential, filtering out None entries
def Sequential(*modules):
    return nn.Sequential(*filter(exists, modules))

# tensor helpers

# the most negative value representable in the tensor's dtype
def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max

# L2-normalize a tensor, optionally in groups along the last dimension
def l2norm(t, groups = 1):
    t = rearrange(t, '... (g d) -> ... g d', g = groups)
    t = F.normalize(t, p = 2, dim = -1)
    return rearrange(t, '... g d -> ... (g d)')

# pad a tensor along an arbitrary dimension
def pad_at_dim(t, pad, dim = -1, value = 0.):
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

# logical OR over a list of masks
def or_reduce(masks):
    head, *body = masks
    for rest in body:
        head = head | rest
    return head
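
Since F.pad counts pad pairs from the last dimension backwards, pad_at_dim only has to prepend the right number of (0, 0) pairs. A standalone sketch, with illustrative shapes:

import torch

t = torch.randn(2, 5, 8)    # (batch, seq, dim)

# pad two positions at the front of the sequence dimension (dim = -2);
# internally this becomes F.pad(t, (0, 0, 2, 0))
padded = pad_at_dim(t, (2, 0), dim = -2)
print(padded.shape)         # torch.Size([2, 7, 8])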

# auxiliary loss functions

# attention z-loss
def calc_z_loss(
    pre_softmax_attns: List[Tensor],
    mask = None,
    weight = 1.
):
    # the same loss applied to the mixture-of-experts router logits in https://arxiv.org/abs/2202.08906
    # the paper mentions in a small footnote that applying it to the attention logits has a stabilizing effect
    # also used in PaLM as one of the stabilization measures

    lse = 0.

    for attn in pre_softmax_attns:
        lse = lse + attn.logsumexp(dim = -1)

    loss = torch.square(lse)
    loss = reduce(loss, 'b h n -> b n', 'sum')

    if not exists(mask):
        return loss.mean() * weight

    loss = loss[mask].sum() / mask.sum().clamp(min = 1e-5)
    return loss * weight
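
A sketch of how the z-loss would be invoked, assuming the definitions above are in scope; the shapes (batch 2, 8 heads, query and key length 16) are illustrative:

import torch

# pre-softmax attention logits from two layers: (batch, heads, q_len, k_len)
attns = [torch.randn(2, 8, 16, 16) for _ in range(2)]
mask = torch.ones(2, 16, dtype = torch.bool)   # (batch, q_len)

# a small scalar, added onto the main loss during training
z_loss = calc_z_loss(attns, mask = mask, weight = 1e-4)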

# init helpers

# initialize a layer's weights (and bias, if present) to zero
def init_zero_(layer):
    nn.init.constant_(layer.weight, 0.)
    if exists(layer.bias):
        nn.init.constant_(layer.bias, 0.)

# keyword argument helpers

# pop the given keys out of a dict and return them as a new dict
def pick_and_pop(keys, d):
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))

# split a dict in two according to a predicate on its keys
def group_dict_by_key(cond, d):
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

# check whether a string starts with the given prefix
def string_begins_with(prefix, str):
    return str.startswith(prefix)

# group a dict's entries by whether their keys carry the prefix
def group_by_key_prefix(prefix, d):
    return group_dict_by_key(partial(string_begins_with, prefix), d)

# group by prefix, then strip the prefix from the matching keys
def groupby_prefix_and_trim(prefix, d):
    # split the dict by prefix
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    # strip the prefix to form the new dict
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs
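
This helper is what later lets XTransformer route `enc_`- and `dec_`-prefixed keyword arguments to the encoder and decoder respectively. A minimal sketch:

kwargs = dict(enc_num_tokens = 256, enc_depth = 6, dec_depth = 6, tie_token_emb = True)

enc_kwargs, rest = groupby_prefix_and_trim('enc_', kwargs)
# enc_kwargs -> {'num_tokens': 256, 'depth': 6}
# rest       -> {'dec_depth': 6, 'tie_token_emb': True}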

# structured dropout of whole sequence positions, more effective than traditional attention dropout
def dropout_seq(seq, mask, dropout):
    # unpack the sequence's shape and device
    b, n, *_, device = *seq.shape, seq.device
    # sample standard normal logits used to rank positions
    logits = torch.randn(b, n, device = device)

    if exists(mask):
        # push masked-out positions to the bottom of the ranking
        mask_value = max_neg_value(logits)
        logits = logits.masked_fill(~mask, mask_value)

    # number of positions to keep
    keep_prob = 1. - dropout
    num_keep = max(1, int(keep_prob * n))
    keep_indices = logits.topk(num_keep, dim = 1).indices

    # batch indices for advanced indexing
    batch_indices = torch.arange(b, device = device)
    batch_indices = rearrange(batch_indices, 'b -> b 1')

    # gather the kept subset of the sequence
    seq = seq[batch_indices, keep_indices]

    if exists(mask):
        # per-sample count of valid positions
        seq_counts = mask.sum(dim = -1)
        # how many of those positions survive the dropout
        seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
        keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')

        # update the mask accordingly
        mask = mask[batch_indices, keep_indices] & keep_mask

    return seq, mask
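
A standalone sketch of the structured dropout with illustrative shapes; unlike ordinary dropout, the sequence dimension itself shrinks:

import torch

seq = torch.randn(2, 10, 512)                  # (batch, seq, dim)
mask = torch.ones(2, 10, dtype = torch.bool)   # all positions valid

# keep roughly 75% of the positions per sample
kept_seq, kept_mask = dropout_seq(seq, mask, dropout = 0.25)
print(kept_seq.shape, kept_mask.shape)         # torch.Size([2, 7, 512]) torch.Size([2, 7])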

# activation functions
class ReluSquared(nn.Module):
    def forward(self, x):
        return F.relu(x) ** 2

# token embedding
class TokenEmbedding(nn.Module):
    def __init__(self, dim, num_tokens, l2norm_embed=False):
        super().__init__()
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(num_tokens, dim)

    def forward(self, x):
        token_emb = self.emb(x.long())
        return l2norm(token_emb) if self.l2norm_embed else token_emb

# absolute positional embedding
class AbsolutePositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len, l2norm_embed=False):
        super().__init__()
        self.scale = dim ** -0.5 if not l2norm_embed else 1.
        self.max_seq_len = max_seq_len
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos=None, seq_start_pos=None):
        seq_len, device = x.shape[1], x.device
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        if not exists(pos):
            pos = torch.arange(seq_len, device=device)

        if exists(seq_start_pos):
            pos = (pos - seq_start_pos[..., None]).clamp(min=0)

        pos_emb = self.emb(pos)
        pos_emb = pos_emb * self.scale
        return l2norm(pos_emb) if self.l2norm_embed else pos_emb

# scaled sinusoidal positional embedding
class ScaledSinusoidalEmbedding(nn.Module):
    def __init__(self, dim, theta=10000):
        super().__init__()
        assert divisible_by(dim, 2)
        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

        half_dim = dim // 2
        freq_seq = torch.arange(half_dim).float() / half_dim
        inv_freq = theta ** -freq_seq
        self.register_buffer('inv_freq', inv_freq, persistent=False)

    def forward(self, x, pos=None, seq_start_pos=None):
        seq_len, device = x.shape[1], x.device

        if not exists(pos):
            pos = torch.arange(seq_len, device=device)

        if exists(seq_start_pos):
            pos = pos - seq_start_pos[..., None]

        emb = einsum('i, j -> i j', pos, self.inv_freq)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb * self.scale
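
A quick shape check; the input tensor is only read for its sequence length. The positions are passed in explicitly as floats here so the einsum sees the same dtype as inv_freq:

import torch

sinu = ScaledSinusoidalEmbedding(dim = 64)

x = torch.zeros(2, 10, dtype = torch.long)          # only x.shape[1] matters
pos_emb = sinu(x, pos = torch.arange(10).float())
print(pos_emb.shape)                                # torch.Size([10, 64]): sin half, cos half, times a learned scale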

class RelativePositionBias(nn.Module):
    # T5-style relative position bias
    def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        # embedding from bucket index to a per-head bias
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    # map relative positions to bucket indices
    @staticmethod
    def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
        ret = 0
        n = -relative_position
        # in the bidirectional case, use half the buckets for each direction
        if not causal:
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))

        # half of the buckets cover small distances exactly
        max_exact = num_buckets // 2
        is_small = n < max_exact

        # larger distances are binned logarithmically, up to max_distance
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        # pick the exact bucket for small distances, the log bucket otherwise
        ret += torch.where(is_small, n, val_if_large)
        return ret

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, i, j):
        device = self.device
        # query and key positions
        q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
        k_pos = torch.arange(j, dtype = torch.long, device = device)
        # relative positions and their buckets
        rel_pos = k_pos[None, :] - q_pos[:, None]
        rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
        # look up the per-head bias and rearrange to (heads, i, j)
        values = self.relative_attention_bias(rp_bucket)
        bias = rearrange(values, 'i j h -> h i j')
        return bias * self.scale
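
To see the bucketing behave, the static method can be called on its own: in the bidirectional case half of the buckets serve each sign of the distance, small distances map one-to-one, and larger ones are binned logarithmically up to max_distance. An illustrative call:

import torch

rel_pos = torch.arange(-4, 5)    # relative distances from -4 to 4
buckets = RelativePositionBias._relative_position_bucket(
    rel_pos, causal = False, num_buckets = 8, max_distance = 16
)
print(buckets)   # one bucket index per relative distance
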
class DynamicPositionBias(nn.Module):
    # dynamic position bias: a small MLP maps relative distance to a per-head bias
    def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
        super().__init__()
        assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
        self.log_distance = log_distance

        self.mlp = nn.ModuleList([])

        # first layer projects the scalar distance up to dim
        self.mlp.append(Sequential(
            nn.Linear(1, dim),
            LayerNorm(dim) if norm else None,
            nn.SiLU()
        ))

        # remaining hidden layers
        for _ in range(depth - 1):
            self.mlp.append(Sequential(
                nn.Linear(dim, dim),
                nn.LayerNorm(dim) if norm else None,
                nn.SiLU()
            ))

        # final projection to one bias value per head
        self.mlp.append(nn.Linear(dim, heads))

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, i, j):
        assert i == j
        n, device = j, self.device

        # get the (n x n) matrix of distances
        seq_arange = torch.arange(n, device = device)
        context_arange = torch.arange(n, device = device)
        indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
        indices += (n - 1)

        # input to the continuous-position MLP: all distances from -(n-1) to (n-1)
        pos = torch.arange(-n + 1, n, device = device).float()
        pos = rearrange(pos, '... -> ... 1')

        if self.log_distance:
            pos = torch.sign(pos) * torch.log(pos.abs() + 1)  # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)

        for layer in self.mlp:
            pos = layer(pos)

        # index the per-distance biases into an (h, i, j) tensor
        bias = pos[indices]
        bias = rearrange(bias, 'i j h -> h i j')
        return bias

class AlibiPositionalBias(nn.Module):
    # ALiBi positional bias: linear penalties on attention by distance, one slope per head
    def __init__(self, heads, total_heads, **kwargs):
        super().__init__()
        self.heads = heads
        self.total_heads = total_heads

        # precompute the per-head slopes and cache them as (non-persistent) buffers
        slopes = Tensor(self._get_slopes(heads))
        slopes = rearrange(slopes, 'h -> h 1 1')
        self.register_buffer('slopes', slopes, persistent = False)
        self.register_buffer('bias', None, persistent = False)

    def get_bias(self, i, j, device):
        # negative absolute distance between key and query positions
        i_arange = torch.arange(j - i, j, device = device)
        j_arange = torch.arange(j, device = device)
        bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
        return bias

    @staticmethod
    def _get_slopes(heads):
        # geometric sequence of slopes when the head count is a power of 2
        def get_slopes_power_of_2(n):
            start = (2**(-2**-(math.log2(n)-3)))
            ratio = start
            return [start*ratio**i for i in range(n)]

        if math.log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        # otherwise, interleave slopes derived from the closest power of 2
        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]

    @property
    def device(self):
        return next(self.buffers()).device

    def forward(self, i, j):
        h, device = self.total_heads, self.device

        # reuse the cached bias if it is large enough
        if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
            return self.bias[..., -i:, -j:]

        # otherwise compute it and scale by the slopes
        bias = self.get_bias(i, j, device)
        bias = bias * self.slopes

        # pad with zeros for any heads left without an ALiBi slope, then cache
        num_heads_unalibied = h - bias.shape[0]
        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = 0)
        self.register_buffer('bias', bias, persistent = False)

        return self.bias
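
For a power-of-two head count the slopes form a plain geometric sequence, which can be verified directly:

# slopes for 8 heads: 1/2, 1/4, ... down to 1/256
print(AlibiPositionalBias._get_slopes(8))
# [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]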

class RotaryEmbedding(nn.Module):
    # rotary positional embedding, with optional xpos length extrapolation
    def __init__(
        self,
        dim,
        use_xpos = False,
        scale_base = 512,
        interpolation_factor = 1.,
        base = 10000,
        base_rescale_factor = 1.
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence lengths without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        base *= base_rescale_factor ** (dim / (dim - 2))

        # inverse frequencies, cached as a buffer
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if not use_xpos:
            # without xpos, no scale is needed
            self.register_buffer('scale', None)
            return

        # xpos scale schedule
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = scale_base
        self.register_buffer('scale', scale)

    # convenience: build frequencies for a whole sequence of a given length
    def forward_from_seq_len(self, seq_len):
        device = self.inv_freq.device

        t = torch.arange(seq_len, device = device)
        return self.forward(t)

    # disable autocast so the frequencies stay in full precision
    @autocast(enabled = False)
    def forward(self, t):
        max_pos = t.max() + 1

        # outer product of positions and inverse frequencies
        freqs = torch.einsum('i , j -> i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
        freqs = torch.cat((freqs, freqs), dim = -1)

        if not exists(self.scale):
            return freqs, 1.

        # xpos: the scale decays with distance from the sequence midpoint
        power = (t - (max_pos // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim = -1)

        return freqs, scale
# split the feature dimension in two and rotate: (x1, x2) -> (-x2, x1)
def rotate_half(x):
    x = rearrange(x, '... (j d) -> ... j d', j = 2)
    x1, x2 = x.unbind(dim = -2)
    return torch.cat((-x2, x1), dim = -1)

# apply rotary positional embedding to a tensor t
@autocast(enabled = False)
def apply_rotary_pos_emb(t, freqs, scale = 1):
    rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
    # keep only the frequencies for the trailing seq_len positions
    freqs = freqs[-seq_len:, :]
    scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale

    # broadcast the frequencies over the head dimension if needed
    if t.ndim == 4 and freqs.ndim == 3:
        freqs = rearrange(freqs, 'b n d -> b 1 n d')

    # partial rotary embeddings, Wang et al. GPT-J
    # only the first rot_dim feature dimensions are rotated; the rest pass through
    t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
    return torch.cat((t, t_unrotated), dim = -1)
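
A sketch of applying rotary embeddings to queries and keys shaped (batch, heads, seq, dim_head). With dim = 32 only the first 32 feature dimensions are rotated (partial rotary); with xpos enabled, keys would conventionally be applied with the reciprocal scale:

import torch

rotary = RotaryEmbedding(dim = 32)
freqs, scale = rotary.forward_from_seq_len(128)   # scale is 1. without xpos

q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)

q = apply_rotary_pos_emb(q, freqs, scale)
k = apply_rotary_pos_emb(k, freqs, scale)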

# norms

# scale the output of a wrapped function by a fixed value
class Scale(nn.Module):
    def __init__(self, value, fn):
        super().__init__()
        self.value = value
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        scale_fn = lambda t: t * self.value

        # scale either the lone output, or the first element of a tuple output
        if not isinstance(out, tuple):
            return scale_fn(out)

        return (scale_fn(out[0]), *out[1:])

# ScaleNorm: normalize by the vector norm, with a single learned gain
class ScaleNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5))

    def forward(self, x):
        norm = torch.norm(x, dim = -1, keepdim = True)
        return x / norm.clamp(min = self.eps) * self.g

# bias-less LayerNorm
class LayerNorm(nn.Module):
    def __init__(self, dim):
        """
        bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
        """
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# from PyTorch 2.1.0 onward, nn.LayerNorm supports bias = False natively, so use it directly
if version.parse(torch.__version__) >= version.parse('2.1.0'):
    LayerNorm = partial(nn.LayerNorm, bias = False)

# RMSNorm: normalize to unit RMS, with a learned per-dimension gain
class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * self.g

# SimpleRMSNorm: RMSNorm without any learned parameters
class SimpleRMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale
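
All three norms preserve the input shape and differ only in what they learn: LayerNorm has a gain (with the bias fixed at zero), RMSNorm has a gain, and SimpleRMSNorm has no parameters at all. A quick comparison with illustrative shapes:

import torch

x = torch.randn(2, 16, 512)

for norm in (LayerNorm(512), RMSNorm(512), SimpleRMSNorm(512)):
    print(type(norm).__name__, norm(x).shape)   # (2, 16, 512) each time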

# residual and residual gates

# plain residual connection, optionally scaled
class Residual(nn.Module):
    def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
        super().__init__()
        self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
        self.scale_residual_constant = scale_residual_constant

    def forward(self, x, residual):
        # learned per-dimension scaling of the residual, if enabled
        if exists(self.residual_scale):
            residual = residual * self.residual_scale

        # constant scaling of the residual, if not 1
        if self.scale_residual_constant != 1:
            residual = residual * self.scale_residual_constant

        return x + residual

# GRU-gated residual connection
class GRUGating(nn.Module):
    def __init__(self, dim, scale_residual = False, **kwargs):
        super().__init__()
        self.gru = nn.GRUCell(dim, dim)
        self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None

    def forward(self, x, residual):
        if exists(self.residual_scale):
            residual = residual * self.residual_scale

        # run the GRU cell with the branch output as input and the residual as hidden state
        gated_output = self.gru(
            rearrange(x, 'b n d -> (b n) d'),
            rearrange(residual, 'b n d -> (b n) d')
        )

        # reshape back to the input's shape
        return gated_output.reshape_as(x)

# token shifting

# shift a tensor along the sequence dimension by the given amount
def shift(t, amount, mask = None):
    if amount == 0:
        return t
    else:
        # cap the shift at the sequence length
        amount = min(amount, t.shape[1])

    # zero out padded positions before shifting
    if exists(mask):
        t = t.masked_fill(~mask[..., None], 0.)

    # pad on one side and trim the other to realize the shift
    return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.)

# split the features into segments and shift each by a different amount before calling fn
class ShiftTokens(nn.Module):
    def __init__(self, shifts, fn):
        super().__init__()
        self.fn = fn
        self.shifts = tuple(shifts)

    def forward(self, x, **kwargs):
        mask = kwargs.get('mask', None)
        shifts = self.shifts
        segments = len(shifts)
        # number of feature channels in each shifted segment
        feats_per_shift = x.shape[-1] // segments
        # split off one segment per shift; anything left over is untouched
        splitted = x.split(feats_per_shift, dim = -1)
        segments_to_shift, rest = splitted[:segments], splitted[segments:]
        # shift each segment by its own amount
        segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
        # concatenate the shifted segments and the remainder back together
        x = torch.cat((*segments_to_shift, *rest), dim = -1)
        return self.fn(x, **kwargs)
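
A standalone sketch, wrapping nn.Identity so the shifting itself is visible: half of the feature channels stay put, the other half are pulled from the previous timestep (front-padded with zeros):

import torch
import torch.nn as nn

shift_tokens = ShiftTokens(shifts = (0, 1), fn = nn.Identity())

x = torch.arange(8.).reshape(1, 4, 2)   # (batch, seq, dim), two segments of one channel each
out = shift_tokens(x)
print(out[0])   # second channel is shifted down by one position
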
# GLU: gated linear unit
class GLU(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        activation: Callable,
        mult_bias = False
    ):
        super().__init__()
        self.act = activation
        self.proj = nn.Linear(dim_in, dim_out * 2)
        self.mult_bias = nn.Parameter(torch.ones(dim_out)) if mult_bias else 1.

    def forward(self, x):
        # project, then split into value and gate halves
        x, gate = self.proj(x).chunk(2, dim = -1)
        # gate the value, optionally with a learned multiplicative bias
        return x * self.act(gate) * self.mult_bias

# feedforward block
class FeedForward(nn.Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        mult = 4,
        glu = False,
        glu_mult_bias = False,
        swish = False,
        relu_squared = False,
        post_act_ln = False,
        dropout = 0.,
        no_bias = False,
        zero_init_output = False
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)

        # choose the activation
        if relu_squared:
            activation = ReluSquared()
        elif swish:
            activation = nn.SiLU()
        else:
            activation = nn.GELU()

        # choose between a gated (GLU) or plain input projection
        if glu:
            project_in = GLU(dim, inner_dim, activation, mult_bias = glu_mult_bias)
        else:
            project_in = nn.Sequential(
                nn.Linear(dim, inner_dim, bias = not no_bias),
                activation
            )

        # assemble the feedforward network
        self.ff = Sequential(
            project_in,
            LayerNorm(inner_dim) if post_act_ln else None,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out, bias = not no_bias)
        )

        # optionally zero-init the output projection
        if zero_init_output:
            init_zero_(self.ff[-1])

    def forward(self, x):
        return self.ff(x)
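
An illustrative construction of a SwiGLU-style block out of the options above (glu = True combined with swish = True):

import torch

ff = FeedForward(dim = 512, mult = 4, glu = True, swish = True, dropout = 0.1)

x = torch.randn(2, 16, 512)
print(ff(x).shape)   # torch.Size([2, 16, 512])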

# attention, with many optional features
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = DEFAULT_DIM_HEAD,
        dim_context = None,
        heads = 8,
        causal = False,
        flash = False,
        talking_heads = False,
        head_scale = False,
        sparse_topk = None,
        num_mem_kv = 0,
        dropout = 0.,
        on_attn = False,
        gate_value_heads = False,
        swiglu_values = False,
        gate_values = False,
        zero_init_output = False,
        max_attend_past = None,
        qk_norm = False,
        qk_norm_groups = 1,
        qk_norm_scale = 10,
        qk_norm_dim_scale = False,
        one_kv_head = False,
        kv_heads = None,
        shared_kv = False,
        value_dim_head = None,
        tensor_product = False,      # https://arxiv.org/abs/2208.06061
        add_zero_kv = False,         # same as add_zero_attn in pytorch
        rotary_embed_values = False,
        onnxable = False
    ):
        ...  # body omitted in this excerpt

    # forward pass
    def forward(
        self,
        x,
        context = None,
        mask = None,
        context_mask = None,
        attn_mask = None,
        rel_pos = None,
        rotary_pos_emb = None,
        prev_attn = None,
        mem = None,
        mem_mask = None,
        return_intermediates = False,
        cache: Optional[Intermediates] = None,
    ):
        ...  # body omitted in this excerpt

class AttentionLayers(nn.Module):
    # the main layer stack: builds the sequence of attention / feedforward blocks
    def __init__(
        self,
        dim,
        depth,
        heads = 8,
        causal = False,
        cross_attend = False,
        only_cross = False,
        use_scalenorm = False,
        use_rmsnorm = False,
        use_simple_rmsnorm = False,
        alibi_pos_bias = False,
        alibi_num_heads = None,
        rel_pos_bias = False,
        rel_pos_num_buckets = 32,
        rel_pos_max_distance = 128,
        dynamic_pos_bias = False,
        dynamic_pos_bias_log_distance = False,
        dynamic_pos_bias_mlp_depth = 2,
        dynamic_pos_bias_norm = False,
        rotary_pos_emb = False,
        rotary_emb_dim = None,
        rotary_xpos = False,
        rotary_interpolation_factor = 1.,
        rotary_xpos_scale_base = 512,
        rotary_base_rescale_factor = 1.,
        custom_layers = None,
        sandwich_coef = None,
        par_ratio = None,
        weight_tie_layers = False,   # Albert - https://arxiv.org/abs/1909.11942
        layers_execute_order = None, # generalizes weight tying, can do arbitrary layer execution orders
        residual_attn = False,
        cross_residual_attn = False,
        macaron = False,
        pre_norm = True,
        pre_norm_has_final_norm = True,
        gate_residual = False,
        scale_residual = False,
        scale_residual_constant = 1.,
        shift_tokens = 0,
        sandwich_norm = False,
        resi_dual = False,
        resi_dual_scale = 1.,
        zero_init_branch_output = False,
        layer_dropout = 0.,
        cross_attn_tokens_dropout = 0.,
        disable_abs_pos_emb = None,
        **kwargs
    ):
        ...  # body omitted in this excerpt

    # forward pass, running the input through the layer stack
    def forward(
        self,
        x,
        context = None,
        mask = None,
        context_mask = None,
        attn_mask = None,
        self_attn_kv_mask = None,
        mems = None,
        mem_masks = None,
        seq_start_pos: Optional[Tensor] = None,
        cache: Optional[LayerIntermediates] = None,
        cache_age = 1,
        return_hiddens = False,
        rotary_pos_emb = None
    ):
        ...  # body omitted in this excerpt

class Encoder(AttentionLayers):
    # encoder: a non-causal stack of attention layers
    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on encoder'
        super().__init__(causal = False, **kwargs)

class Decoder(AttentionLayers):
    # decoder: a causal stack of attention layers
    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on decoder'
        super().__init__(causal = True, **kwargs)

class PrefixDecoder(AttentionLayers):
    # prefix decoder: bidirectional attention over a prefix, causal attention after it
    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on decoder'
        super().__init__(causal = False, **kwargs)

    def forward(
        self,
        x,
        *args,
        attn_mask = None,
        prefix_attn_len = None,
        **kwargs
    ):
        b, n, device = x.shape[0], x.shape[1], x.device
        # boolean causal mask: True above the diagonal (the positions to be blocked)
        causal_mask = torch.ones((n, n), device = device, dtype = torch.bool).triu(1)

        # invert to get the allowed (causal) positions
        forwarded_mask = ~causal_mask

        if exists(prefix_attn_len):
            # broadcast an int prefix length over the batch
            if isinstance(prefix_attn_len, int):
                prefix_attn_len = torch.full((b,), prefix_attn_len, device = device)

            # positions inside the prefix may be attended to bidirectionally
            prefix_mask = torch.arange(n, device = device) < rearrange(prefix_attn_len, 'b -> b 1 1 1')

            forwarded_mask = forwarded_mask | prefix_mask

        # combine with any user-supplied attention mask
        if exists(attn_mask):
            forwarded_mask = forwarded_mask & attn_mask

        # run the underlying layer stack with the combined mask
        return super().forward(x, *args, attn_mask = forwarded_mask, **kwargs)
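
The mask construction above can be reproduced in isolation. For n = 5 and a prefix of length 2, the first two key columns are visible to every query, while the remaining positions follow the causal triangle:

import torch
from einops import rearrange

n, prefix_len = 5, torch.tensor([2])

causal_mask = torch.ones((n, n), dtype = torch.bool).triu(1)
forwarded_mask = ~causal_mask
prefix_mask = torch.arange(n) < rearrange(prefix_len, 'b -> b 1 1 1')

combined = forwarded_mask | prefix_mask
print(combined[0, 0].int())   # rows are queries, columns are keys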

class CrossAttender(AttentionLayers):
    # cross-attention-only stack
    def __init__(self, **kwargs):
        super().__init__(cross_attend = True, only_cross = True, **kwargs)

class ViTransformerWrapper(nn.Module):
    # vision transformer wrapper around an Encoder
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        attn_layers: Encoder,
        channels = 3,
        num_classes = None,
        post_emb_norm = False,
        num_register_tokens = 0,
        emb_dropout = 0.
    ):
        super().__init__()
        assert divisible_by(image_size, patch_size), 'image dimensions must be divisible by the patch size'
        dim = attn_layers.dim
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size

        # learned positional embedding over the patches
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))

        # optional register tokens, packed in alongside the patch tokens
        has_register_tokens = num_register_tokens > 0
        self.has_register_tokens = has_register_tokens

        if has_register_tokens:
            self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))

        # patch -> embedding projection
        self.patch_to_embedding = nn.Sequential(
            LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            LayerNorm(dim)
        )

        # optional norm and dropout after the embedding
        self.post_emb_norm = LayerNorm(dim) if post_emb_norm else nn.Identity()
        self.dropout = nn.Dropout(emb_dropout)

        self.attn_layers = attn_layers

        # classification head, or identity when no class count is given
        self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity()

    def forward(
        self,
        img,
        return_embeddings = False,
        return_logits_and_embeddings = False
    ):
        b, p = img.shape[0], self.patch_size

        # split the image into flattened patches: (b, h*w, p1*p2*c)
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)

        n = x.shape[1]

        # add the positional embeddings
        x = x + self.pos_embedding[:, :n]

        x = self.post_emb_norm(x)
        x = self.dropout(x)

        # pack in the register tokens, if any
        if self.has_register_tokens:
            r = repeat(self.register_tokens, 'n d -> b n d', b = b)
            x, ps = pack((x, r), 'b * d')

        embed = self.attn_layers(x)

        # unpack the register tokens back out
        if self.has_register_tokens:
            embed, _ = unpack(embed, ps, 'b * d')

        assert at_most_one_of(return_embeddings, return_logits_and_embeddings)

        # return the raw embeddings if requested, or if there is no head
        if not exists(self.mlp_head) or return_embeddings:
            return embed

        # mean-pool over the patch dimension and classify
        pooled = embed.mean(dim = -2)
        logits = self.mlp_head(pooled)

        if not return_logits_and_embeddings:
            return logits

        return logits, embed
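
Putting the wrapper together with the Encoder defined earlier, along the lines of the project README (the hyperparameters here are illustrative):

import torch

model = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

img = torch.randn(1, 3, 256, 256)
preds = model(img)   # (1, 1000)
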
class TransformerWrapper(nn.Module):
    # wraps a stack of attention layers with token + positional embeddings and a logits head
    def __init__(
        self,
        *,
        num_tokens,                                 # vocabulary size
        max_seq_len,                                # maximum sequence length
        attn_layers: AttentionLayers,               # the attention layer stack
        embed_num_tokens: Dict[str, int] = dict(),  # extra categorical embeddings, name -> cardinality
        emb_dim = None,                             # embedding dimension, defaults to the model dimension
        max_mem_len = 0,                            # maximum memory length
        shift_mem_down = 0,                         # how many layers to shift memories down
        emb_dropout = 0.,                           # dropout on the embeddings
        post_emb_norm = False,                      # whether to normalize after embedding
        num_memory_tokens = None,                   # number of memory tokens
        memory_tokens_interspersed_every = None,    # intersperse memory tokens every n positions
        tie_embedding = False,                      # tie the output projection to the token embedding
        logits_dim = None,                          # logits dimension, defaults to num_tokens
        use_abs_pos_emb = True,                     # whether to use absolute positional embeddings
        scaled_sinu_pos_emb = False,                # whether to use scaled sinusoidal positional embeddings
        l2norm_embed = False,                       # whether to l2-normalize the embeddings
        emb_frac_gradient = 1.,                     # fraction of the gradient that flows to the embedding
        attn_z_loss_weight = 1e-4,                  # weight of the attention z-loss
    ):
        super().__init__()

        # the model dimension comes from the attention layers
        dim = attn_layers.dim
        emb_dim = default(emb_dim, dim)
        self.emb_dim = emb_dim
        self.num_tokens = num_tokens

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len
        self.shift_mem_down = shift_mem_down

        self.l2norm_embed = l2norm_embed
        # token embedding
        self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)

        # decide whether absolute positional embeddings are needed at all
        no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb)

        # pick the positional embedding scheme
        if no_abs_pos_emb:
            self.pos_emb = always(0)
        elif scaled_sinu_pos_emb:
            self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
        else:
            self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed)

        # additional categorical embeddings, if requested
        self.embeds = None

        if len(embed_num_tokens) > 0:
            self.embeds = nn.ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})

        # fraction of the gradient that flows to the embedding
        self.emb_frac_gradient = emb_frac_gradient

        # post-embedding norm and dropout
        self.post_emb_norm = LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
        self.emb_dropout = nn.Dropout(emb_dropout)

        # project the embedding up to the model dimension if they differ
        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        self.attn_layers = attn_layers

        # initialize parameters
        self.init_()

        # logits head, optionally weight-tied to the token embedding
        logits_dim = default(logits_dim, num_tokens)
        self.to_logits = nn.Linear(dim, logits_dim, bias = False) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()

        # memory tokens
        num_memory_tokens = default(num_memory_tokens, 0)
        self.num_memory_tokens = num_memory_tokens
        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

        self.memory_tokens_interspersed_every = memory_tokens_interspersed_every

        # whether key/value caching at decode time is possible
        self.can_cache_kv = self.num_memory_tokens == 0
        self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb

    # initialize the embedding weights, depending on whether they are l2-normalized
    def init_(self):
        if self.l2norm_embed:
            nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)
            if not isinstance(self.pos_emb, always):
                nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
            return

        nn.init.kaiming_normal_(self.token_emb.emb.weight)

    # forward pass
    def forward(
        self,
        x,                                           # input token ids
        return_embeddings = False,                   # return embeddings instead of logits
        return_logits_and_embeddings = False,        # return both logits and embeddings
        return_intermediates = False,                # return intermediate values
        mask = None,                                 # attention mask
        return_mems = False,                         # return memories
        return_attn = False,                         # return attention maps
        mems = None,                                 # memories
        mem_masks = None,                            # memory masks
        pos = None,                                  # explicit positions for the positional embedding
        prepend_embeds = None,                       # embeddings to prepend
        prepend_mask = None,                         # mask for the prepended embeddings
        embed_ids: Dict[str, Tensor] = dict(),       # ids for the extra categorical embeddings
        sum_embeds = None,                           # embeddings to sum onto the input
        return_attn_z_loss = False,                  # return the attention z-loss
        attn_z_loss_weight = 1e-4,                   # weight of the attention z-loss
        seq_start_pos = None,                        # start position of the sequence
        cache: Optional[LayerIntermediates] = None,  # cached intermediates for decoding
        **kwargs                                     # remaining keyword arguments
    ):
        ...  # body omitted in this excerpt

class XTransformer(nn.Module):
    # full encoder/decoder transformer
    def __init__(
        self,
        *,
        dim,
        tie_token_emb = False,
        ignore_index = -100,
        pad_value = 0,
        cross_attn_tokens_dropout = 0.,
        **kwargs
    ):
        super().__init__()

        # route keyword arguments by prefix: enc_* to the encoder, dec_* to the decoder
        enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
        dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)

        assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'

        # pull out the TransformerWrapper-level arguments, with defaults
        enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
        enc_transformer_kwargs['emb_dropout'] = enc_kwargs.pop('emb_dropout', 0)
        enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)
        enc_transformer_kwargs['scaled_sinu_pos_emb'] = enc_kwargs.pop('scaled_sinu_pos_emb', False)
        enc_transformer_kwargs['use_abs_pos_emb'] = enc_kwargs.pop('use_abs_pos_emb', True)

        dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
        dec_transformer_kwargs['emb_dropout'] = dec_kwargs.pop('emb_dropout', 0)
        dec_transformer_kwargs['scaled_sinu_pos_emb'] = dec_kwargs.pop('scaled_sinu_pos_emb', False)
        dec_transformer_kwargs['use_abs_pos_emb'] = dec_kwargs.pop('use_abs_pos_emb', True)

        # structured dropout of the encoder output tokens the decoder cross-attends to
        self.cross_attn_tokens_dropout = cross_attn_tokens_dropout

        # build the encoder and decoder wrappers
        self.encoder = TransformerWrapper(
            **enc_transformer_kwargs,
            attn_layers = Encoder(dim = dim, **enc_kwargs)
        )

        self.decoder = TransformerWrapper(
            **dec_transformer_kwargs,
            attn_layers = Decoder(dim = dim, cross_attend = True, **dec_kwargs)
        )

        # optionally share the token embedding between encoder and decoder
        if tie_token_emb:
            self.decoder.token_emb = self.encoder.token_emb

        # wrap the decoder for autoregressive training and generation
        self.decoder = AutoregressiveWrapper(self.decoder, ignore_index = ignore_index, pad_value = pad_value)

    @torch.no_grad()
    def generate(self, seq_in, seq_out_start, seq_len, mask = None, attn_mask = None, **kwargs):
        # encode the input sequence, then autoregressively decode conditioned on it
        encodings = self.encoder(seq_in, mask = mask, attn_mask = attn_mask, return_embeddings = True)
        return self.decoder.generate(seq_out_start, seq_len, context = encodings, context_mask = mask, **kwargs)

    def forward(self, src, tgt, mask = None, attn_mask = None, src_prepend_embeds = None):

        # encode the source sequence
        enc = self.encoder(src, mask = mask, attn_mask = attn_mask, prepend_embeds = src_prepend_embeds, return_embeddings = True)

        # if embeddings were prepended to the source, extend the mask to cover them
        if exists(src_prepend_embeds) and exists(mask):
            mask = pad_at_dim(mask, (src_prepend_embeds.shape[-2], 0), dim = -1, value = True)

        # during training, optionally drop out some of the encoded tokens
        if self.training and self.cross_attn_tokens_dropout > 0:
            enc, mask = dropout_seq(enc, mask, self.cross_attn_tokens_dropout)

        # decode against the encoded context
        out = self.decoder(tgt, context = enc, context_mask = mask)
        return out
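
Finally, an end-to-end usage sketch along the lines of the project README; the enc_/dec_ prefixes are routed by groupby_prefix_and_trim above, and the AutoregressiveWrapper makes the forward call return a training loss:

import torch

model = XTransformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024,
    tie_token_emb = True
)

src = torch.randint(0, 256, (1, 1024))
tgt = torch.randint(0, 256, (1, 1024))
src_mask = torch.ones_like(src).bool()

loss = model(src, tgt, mask = src_mask)
loss.backward()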