Lucidrains Series Projects: Source Code Analysis (Part 14)

.\lucidrains\CoLT5-attention\colt5_attention\vit.py

import torch
from torch import nn

from einops import rearrange, pack, unpack
from einops.layers.torch import Rearrange, Reduce

from colt5_attention.transformer_block import (
    ConditionalRoutedImageAttention,
    ConditionalRoutedFeedForward
)

# helpers

# return t if it is already a tuple, otherwise duplicate it into the pair (t, t)
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# 2d sinusoidal (sincos) positional embedding
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
    # unpack the patch grid shape, device and dtype
    _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    # build the 2d grid of coordinates
    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
    # the feature dimension must be divisible by 4 (sin and cos for each of x and y)
    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
    # compute the inverse frequencies (omega)
    omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
    omega = 1. / (temperature ** omega)

    # compute the positional encodings
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :] 
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    pe = pe.type(dtype)
    return rearrange(pe, '(h w) d -> h w d', h = h, w = w)
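
# quick shape sanity check of the helper above (illustrative addition, not in the original file):
# a (batch, height, width, dim) grid of patch embeddings yields a (height, width, dim) positional embedding
#
# example_patches = torch.randn(2, 8, 8, 256)
# posemb_sincos_2d(example_patches).shape   # torch.Size([8, 8, 256])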

# classes

# transformer with conditionally routed attention and feedforward
class Transformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        attn_num_heavy_tokens_q,
        attn_num_heavy_tokens_kv,
        attn_light_dim_head,
        attn_light_heads,
        attn_light_window_size,
        attn_heavy_dim_head,
        attn_heavy_heads,
        ff_num_heavy_tokens,
        ff_light_mult,
        ff_heavy_mult,
        router_straight_through = True,
        router_kwargs: dict = {},
        router_use_triton = False,
        flash_attn = True,
        attn_num_routed_kv = 1
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):

            # conditionally routed feedforward
            ff = ConditionalRoutedFeedForward(
                dim,
                num_heavy_tokens = ff_num_heavy_tokens,
                light_ff_mult = ff_light_mult,
                heavy_ff_mult = ff_heavy_mult,
                router_straight_through = router_straight_through,
                router_kwargs = router_kwargs,
                use_triton = router_use_triton
            )

            # conditionally routed image attention
            attn = ConditionalRoutedImageAttention(
                dim,
                num_heavy_tokens_q = attn_num_heavy_tokens_q,
                num_heavy_tokens_kv = attn_num_heavy_tokens_kv,
                num_routed_kv = attn_num_routed_kv,
                light_dim_head = attn_light_dim_head,
                light_heads = attn_light_heads,
                light_window_size = attn_light_window_size,
                heavy_dim_head = attn_heavy_dim_head,
                heavy_heads = attn_heavy_heads,
                router_straight_through = router_straight_through,
                router_kwargs = router_kwargs,
                use_triton = router_use_triton,
                use_flash_attn = flash_attn,
                channel_first = False,
                use_null_q_tokens = True
            )

            self.layers.append(nn.ModuleList([attn, ff]))

    # forward pass, packing the 2d patch grid into a sequence for the feedforward router
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x

            x, ps = pack([x], 'b * d')
            x = ff(x) + x            
            x, = unpack(x, ps, 'b * d')

        return x

# simple ViT using conditionally routed attention and feedforward
class ConditionalRoutedViT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        attn_num_heavy_tokens_q,
        attn_num_heavy_tokens_kv,
        attn_heavy_dim_head,
        attn_heavy_heads,
        attn_light_dim_head,
        attn_light_heads,
        attn_light_window_size,
        ff_num_heavy_tokens,
        ff_heavy_mult,
        ff_light_mult,
        channels = 3,
        router_straight_through = True,
        router_kwargs: dict = {},
        router_use_triton = False,
        flash_attn = True,
        attn_num_routed_kv = 1,
        default_coor_descent_eps = 1.
    ):
        super().__init__()
        # image height and width
        image_height, image_width = pair(image_size)
        # patch height and width
        patch_height, patch_width = pair(patch_size)

        # image dimensions must be divisible by the patch size
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # number of patches and the dimension of each flattened patch
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width

        # module converting the image into a grid of patch embeddings
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # router kwargs, with a default epsilon for the coordinate descent
        router_kwargs = {'eps': default_coor_descent_eps, **router_kwargs}

        # the conditionally routed transformer
        self.transformer = Transformer(
            dim,
            depth,
            attn_num_heavy_tokens_q,
            attn_num_heavy_tokens_kv,
            attn_light_dim_head,
            attn_light_heads,
            attn_light_window_size,
            attn_heavy_dim_head,
            attn_heavy_heads,
            ff_num_heavy_tokens,
            ff_light_mult,
            ff_heavy_mult,
            router_straight_through,
            router_kwargs,
            router_use_triton,
            flash_attn,
            attn_num_routed_kv
        )

        # classification head: mean pool over the patch grid, then project to class logits
        self.linear_head = nn.Sequential(
            Reduce('b h w c -> b c', 'mean'),
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    # forward pass
    def forward(self, img):
        # image height, width and dtype
        *_, h, w, dtype = *img.shape, img.dtype

        # convert the image into patch embeddings
        x = self.to_patch_embedding(img)
        # add the 2d sincos positional embedding
        x = x + posemb_sincos_2d(x)

        # run through the routed transformer
        x = self.transformer(x)

        # pool and classify
        return self.linear_head(x)

.\lucidrains\CoLT5-attention\colt5_attention\__init__.py

# import the conditionally routed modules and the coordinate descent router from transformer_block

from colt5_attention.transformer_block import (
    ConditionalRoutedFeedForward,
    ConditionalRoutedAttention,
    ConditionalRoutedImageAttention,
    ConditionalRoutedAutoregressiveAttention,
    ConditionalRoutedCrossAttention,
    ConditionalRoutedTransformerBlock,
    CoordinateDescentRouter
)

# coordinate descent routine

from colt5_attention.coor_descent import coor_descent

# differentiable topk built on top of coordinate descent

from colt5_attention.topk import topk

# simple ViT variant using the routed attention / feedforward

from colt5_attention.vit import ConditionalRoutedViT

CoLT5 Attention - Pytorch

Implementation of the conditionally routed efficient attention in the proposed CoLT5 architecture, in Pytorch.

They used coordinate descent from this paper (main algorithm originally from Wright et al) to route a subset of tokens for 'heavier' branches of the feedforward and attention blocks.

Update: unsure of how the routing normalized scores for the key-values are used. Did some improvising there, scaling the projected values, but if you think you know the answer, please open an issue

Update 2: seems to work well with the improvisation above
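
As a rough mental model of the routing (a simplified sketch of my own, with a hard top-k standing in for the soft coordinate descent scores, and nothing here taken from the repo's internals): the router scores tokens, gathers the highest-scoring ones for the heavy branch, and scatters the heavy outputs back in, weighted by the gates.

import torch

tokens = torch.randn(2, 1024, 512)
scores = torch.einsum('b n d, d -> b n', tokens, torch.randn(512))   # router scores per token

gates, indices = scores.sigmoid().topk(64, dim = -1)                 # hard top-k stand-in for coordinate descent

routed = tokens.gather(1, indices[..., None].expand(-1, -1, 512))    # (2, 64, 512) tokens sent to the heavy branch
heavy_out = routed * 2                                               # placeholder for the heavy attention / feedforward

out = torch.zeros_like(tokens)
out.scatter_(1, indices[..., None].expand(-1, -1, 512), heavy_out * gates[..., None])   # scattered back, gated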

Appreciation

  • Stability.ai for the generous sponsorship to work on cutting edge artificial intelligence research

  • einops for making my life easy

  • Triton for allowing me to speed up coordinate descent with a fused implementation in just 2 days, sparing me from having to write a thousand lines of CUDA code

Install

$ pip install colt5-attention

Usage

import torch

from colt5_attention import (
    ConditionalRoutedFeedForward,
    ConditionalRoutedAttention,
    ConditionalRoutedTransformerBlock
)

# mock input, say it is 32768 length

tokens = torch.randn(2, 32768, 512)
mask = torch.ones(2, 32768).bool()  # can handle variable lengthed sequences

# feedforward

ff = ConditionalRoutedFeedForward(
    dim = 512,
    light_ff_mult = 0.5,      # hidden dimension ratio of light branch
    heavy_ff_mult = 4,        # hidden dimension ratio of heavy branch
    num_heavy_tokens = 1024   # heavy branch receives only 1024 routed tokens of 32768
)

ff_out = ff(tokens, mask = mask)  # (2, 32768, 512) - light and heavy branch summed

# attention

attn = ConditionalRoutedAttention(
    dim = 512,
    light_dim_head = 64,       # attention head dimension of light branch
    light_heads = 8,           # number of attention heads for light branch
    light_window_size = 128,   # local attention receptive field for light
    heavy_dim_head = 64,       # attention head dimension of heavy branch
    heavy_heads = 8,           # number of attention heads for heavy branch
    num_heavy_tokens_q = 1024, # heavy branch receives only 1024 routed tokens of 32768
    num_heavy_tokens_kv = 1024 # heavy branch receives only 1024 routed tokens of 32768
)

attn_out = attn(tokens, mask = mask) # (2, 32768, 512) - light and heavy branch summed

# both attention and feedforward with residual
# the complete transformer block
# a stack of these would constitute the encoder of CoLT5

block = ConditionalRoutedTransformerBlock(
    dim = 512,
    light_dim_head = 64,
    light_heads = 8,
    light_window_size = 128,
    heavy_dim_head = 64,
    heavy_heads = 8,
    light_ff_mult = 0.5,
    heavy_ff_mult = 4,
    num_heavy_ff_tokens = 1024,
    num_heavy_attn_tokens_q = 1024,
    num_heavy_attn_tokens_kv = 1024
)

block_out = block(tokens, mask = mask) # (2, 32768, 512)

Also included a variation of the conditionally routed attention for cross attention, to be tried with long context memories in a transformer-xl

import torch
from colt5_attention import ConditionalRoutedCrossAttention

# mock input, let us say it is a transformer of 1024 length attending to 1 million context past memories

tokens = torch.randn(1, 1024, 512).cuda()
tokens_mask = torch.ones(1, 1024).bool().cuda()

memories = torch.randn(1, 1_048_576, 512).cuda()
memories_mask = torch.ones(1, 1_048_576).bool().cuda()

# conditionally routed cross attention

cross_attn = ConditionalRoutedCrossAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    num_tokens_q = 512,         # only 512 routed from 1024
    num_tokens_kv = 1024,       # only 1024 routed from 1 million
    kv_routing_tokens = 2,      # say you want 2 routing tokens to route different sets of key / values to the queries. 4 attention heads will be allocated to each routed set in this example (8 / 2)
    use_triton = True,          # use cuda kernel
    route_block_size = 131072   # route in blocks of 131072
).cuda()

cross_attn_out = cross_attn(
    tokens,
    context = memories,
    mask = tokens_mask,
    context_mask = memories_mask
)

cross_attn_out.shape # (1, 1024, 512) - same as tokens

This repository also has an improvised version for autoregressive attention. The way this was achieved was by viewing the sequence in windows. Each window can only attend to windows of key / values into the past. The local attention of the light branch covers the intra-window attention.
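
A minimal sketch of the window-level causality described above (my own illustration, assuming the heavy routed branch attends only to strictly-past windows while the light local attention covers the current window):

import torch

seq_len, window_size = 16, 4
window_ids = torch.arange(seq_len // window_size).repeat_interleave(window_size)   # window index per token

# True where a query token may attend a key token through the heavy routed branch
block_causal = window_ids[:, None] > window_ids[None, :]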

The coordinate descent is made viable through a CUDA kernel written in Triton. Finally, to get autoregressive generation to work well, I had to make sure that the unrouted tokens (for queries) output a learned output embedding rather than just zeros.

Currently I am seeing occasional differences between the gradients (as high as 1e-1 for a very small fraction of elements) once the number of iterations exceeds 20. However, enwik8 seems to train well and I can see the effects of the routing. Training is surprisingly stable too.

ex.

import torch
from colt5_attention import ConditionalRoutedAutoregressiveAttention

# mock input, say it is 8192 length

tokens = torch.randn(2, 8192, 512).cuda()

# attention

attn = ConditionalRoutedAutoregressiveAttention(
    dim = 512,
    light_dim_head = 64,          # attention head dimension of light branch
    light_heads = 8,              # number of attention heads for light branch
    light_window_size = 128,      # local attention receptive field for light
    heavy_window_size = 128,      # the windowing for the routed heavy attention, by default, will be equal to the light window size. be aware if this is any greater than the light window size, there may be tokens that would be missed by attention
    heavy_dim_head = 64,          # attention head dimension of heavy branch
    heavy_heads = 8,              # number of attention heads for heavy branch
    num_heavy_tokens_q = 32,      # heavy branch receives only 32 out of 128 of the windowed queries (1024 query tokens total)
    num_heavy_tokens_kv = 1024,   # heavy branch receives only 1024 routed tokens for key-values
    num_routed_kv = 2,            # one can split the attention heads so that groups of heads attend to different sets of key - values (2 routing tokens in this case)
    use_triton = True,            # will need to use Triton for this to be viable, otherwise it is too slow and memory inefficient with the number of iterations
    use_flash_attn = True         # use flash attention in heavy branch
).cuda()

attn_out = attn(tokens) + tokens # (2, 8192, 512) - output of attention with residual (prenorm is included)

Finally, this repository contains a version for image feature maps. Typically a lot of research papers cannot do attention on image feature maps with dimensions greater than 32 by 32. This routed attention will use a local window patch for the light branch, and routed attention for the heavy branch.

ex.

import torch
from colt5_attention import ConditionalRoutedImageAttention

attn = ConditionalRoutedImageAttention(
    dim = 32,
    light_dim_head = 64,       # attention head dimension of light branch
    light_heads = 8,           # number of attention heads for light branch
    light_window_size = 32,    # height and width of local window attention on the image feature map
    channel_first = True,      # whether to accept images with channel first than last
    heavy_dim_head = 64,       # attention head dimension of heavy branch
    heavy_heads = 8,           # number of attention heads for heavy branch
    num_heavy_tokens_q = 1024, # heavy branch receives only 1024 routed tokens of 65536
    num_heavy_tokens_kv = 1024 # heavy branch receives only 1024 routed tokens of 65536
).cuda()

fmap = torch.randn(1, 32, 256, 256).cuda() # image feature map is too large for attention, given 256 ^ 2  == 65536 tokens

out = attn(fmap)

Simple ViT using coordinate descent routed attention and feedforward

import torch
from colt5_attention.vit import ConditionalRoutedViT

vit = ConditionalRoutedViT(
    image_size = 256,                # image size
    patch_size = 32,                 # patch size
    num_classes = 1000,              # number of output classes
    dim = 1024,                      # feature dimension
    depth = 6,                       # depth
    attn_num_heavy_tokens_q = 16,    # number of routed queries for heavy attention
    attn_num_heavy_tokens_kv = 16,   # number of routed key/values for heavy attention
    attn_heavy_dim_head = 64,        # dimension per attention head for heavy
    attn_heavy_heads = 8,            # number of attention heads for heavy
    attn_light_window_size = 4,      # the local windowed attention for light branch
    attn_light_dim_head = 32,        # dimension per head for local light attention
    attn_light_heads = 4,            # number of attention heads for local windowed attention
    ff_num_heavy_tokens = 16,        # number of tokens routed for heavy feedforward
    ff_heavy_mult = 4,               # the expansion factor of the heavy feedforward branch
    ff_light_mult = 2                # expansion factor of the light feedforward branch
)

images = torch.randn(1, 3, 256, 256)

logits = vit(images) # (1, 1000)

Differentiable Topk

Use a small wrapper around coordinate descent for differentiable topk

import torch
from colt5_attention import topk

x = torch.randn(1024, 512)

values, indices, coor_descent_values, gates = topk(x, k = 10, fused = True)

# you can either use the topk indices + gates, or use the values directly (values have already been multiplied with the gates within the function)
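
If you want a hard selection in the forward pass while keeping gradients flowing through the soft gates, a generic straight-through pattern can be applied on top (an assumption about downstream usage, including the 0.5 threshold, not something the function itself prescribes):

hard_gates = (gates > 0.5).float()                # illustrative hard threshold on the soft gates
st_gates = hard_gates + gates - gates.detach()    # hard values forward, soft gradients backward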

Todo

  • add the coordinate descent method as another router
  • allow for multi-headed routing (multiple routing tokens), only for key-values
  • add an autoregressive version of the conditionally routed attention
  • test out the autoregressive version and verify that more routed key / value tokens lead to better results - it works
  • make flash attention compatible
  • create a variant of CoLT5 for high resolution feature maps (image attention) - then try out for diffusion
  • fused coordinate descent kernel using triton
    • forwards
    • backwards
    • benchmark triton vs plain pytorch coor_descent - 50 iterations with 4 segments - 18.5x faster for forward (7.23 vs 0.39), 7.2x faster for backwards (5.77 vs 0.80)
    • fall back on plain coordinate descent for cpu
    • handle edge case for when a row is completely masked out for triton, or simply enforce it never to be so
    • fix masking in coordinate descent
    • simplified some logic within the triton kernel and the problem went away. probably some tiny quirk with the compiler
    • maximum block size in triton allowed is 131k, make sure at least quarter of million sequence length can be reached. to get around this initially, one can fold a million token sequence into ~9 131k and uniformly route. offer uniform routing scheme within router itself
    • remove sinkhorn and cumulative softmax approaches and cleanup; neither can work as well as coordinate descent
    • allow for saving intermediates every number of iterations - trading memory for recompute efficiency during backwards
    • in-place write to checkpointed a and b tensor for potentially savings on forward when recompute segments is high

Citations

@inproceedings{Ainslie2023CoLT5FL,
    title   = {CoLT5: Faster Long-Range Transformers with Conditional Computation},
    author  = {Joshua Ainslie and Tao Lei and Michiel de Jong and Santiago Ontan'on and Siddhartha Brahma and Yury Zemlyanskiy and David Uthus and Mandy Guo and James Lee-Thorp and Yi Tay and Yun-Hsuan Sung and Sumit Sanghai},
    year    = {2023}
}
@article{Tillet2019TritonAI,
    title   = {Triton: an intermediate language and compiler for tiled neural network computations},
    author  = {Philippe Tillet and H. Kung and D. Cox},
    journal = {Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages},
    year    = {2019}
}
@inproceedings{dao2022flashattention,
    title     = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author    = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year      = {2022}
}
@article{Lei2023ConditionalAP,
    title   = {Conditional Adapters: Parameter-efficient Transfer Learning with Fast Inference},
    author  = {Tao Lei and Junwen Bai and Siddhartha Brahma and Joshua Ainslie and Kenton Lee and Yanqi Zhou and Nan Du and Vincent Zhao and Yuexin Wu and Bo Li and Yu Zhang and Ming-Wei Chang},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2304.04947}
}
@article{Beyer2022BetterPV,
    title   = {Better plain ViT baselines for ImageNet-1k},
    author  = {Lucas Beyer and Xiaohua Zhai and Alexander Kolesnikov},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.01580}
}

.\lucidrains\CoLT5-attention\setup.py

# import setuptools helpers for packaging
from setuptools import setup, find_packages

# package metadata
setup(
  # package name
  name = 'CoLT5-attention',
  # find and include all packages
  packages = find_packages(),
  # version number
  version = '0.10.20',
  # license
  license='MIT',
  # short description
  description = 'Conditionally Routed Attention',
  # content type of the long description
  long_description_content_type = 'text/markdown',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # project url
  url = 'https://github.com/lucidrains/CoLT5-attention',
  # keywords
  keywords = [
    'artificial intelligence',
    'attention mechanism',
    'dynamic routing'
  ],
  # install dependencies
  install_requires=[
    'einops>=0.6.1',
    'local-attention>=1.8.6',
    'packaging',
    'torch>=1.10'
  ],
  # classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\complex-valued-transformer\complex_valued_transformer\attend.py

from functools import partial

import torch
from torch import nn, einsum, Tensor
import torch.nn.functional as F

from collections import namedtuple
from functools import wraps
from packaging import version

from einops import rearrange, repeat

# config namedtuple controlling which scaled dot product attention kernels are enabled
EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return val if it exists, otherwise the default d
def default(val, d):
    return val if exists(val) else d

# decorator that ensures a function only ever runs once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# print a message only once
print_once = once(print)

# tensor helpers

# create a causal mask (True marks positions to be masked out)
def create_causal_mask(i, j, device):
    return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)
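
# illustrative example (not part of the original file): the offset triu makes the mask right-aligned,
# which matters when there are more keys than queries (kv caching). for i = 2 queries and j = 5 keys:
#
# create_causal_mask(2, 5, device = torch.device('cpu'))
# tensor([[False, False, False, False,  True],
#         [False, False, False, False, False]])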

# main class

class Attend(nn.Module):
    def __init__(
        self,
        *,
        dropout=0.,
        causal=False,
        heads=None,
        scale=None,
        flash=False,
    ):
        super().__init__()
        self.scale = scale

        self.causal = causal
        self.create_causal_mask = create_causal_mask

        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        # flash attention

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = EfficientAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        major, minor = device_properties.major, device_properties.minor

        if (major, minor) == (8, 0):
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(True, False, False)
        elif (major, minor) == (9, 0):
            print_once('H100 GPU detected, using flash attention')
            self.cuda_config = EfficientAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(False, True, True)

    def flash_attn(
        self,
        q, k, v,
        mask=None
    ):
        # unpack q's shape, along with the key length, whether the tensors are on cuda, and the device
        batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # check if a mask exists and expand it to a compatible shape
        # the mask is B L, so it has to be expanded to B H N L

        causal = self.causal

        # in the case of kv caching with a single token (q_len == 1), simply turn off causal masking
        # in speculative decoding this may go up to 5-6, so a right-aligned causal mask would be needed there

        if q_len == 1 and causal:
            causal = False

        # expand the key padding mask

        if exists(mask):
            assert mask.ndim == 4
            mask = mask.expand(batch, heads, q_len, k_len)

        # handle the kv cache - this should be bypassable with newer versions of flash attention 2

        if k_len > q_len and causal:
            causal_mask = self.create_causal_mask(q_len, k_len, device=device)
            if not exists(mask):
                mask = ~causal_mask
            else:
                mask = mask & ~causal_mask
            causal = False

        # manually handle the causal mask, if another mask was given

        row_is_entirely_masked = None

        if exists(mask) and causal:
            causal_mask = self.create_causal_mask(q_len, k_len, device=device)
            mask = mask & ~causal_mask

            # prevent an entire row from being masked out

            row_is_entirely_masked = ~mask.any(dim=-1)
            mask[..., 0] = mask[..., 0] | row_is_entirely_masked

            causal = False

        # pick the sdp kernel config depending on whether the tensors are on cuda

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
        
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0., 
                is_causal=causal
            )

        # for rows that are entirely masked out, zero out the outputs of those tokens

        if exists(row_is_entirely_masked):
            out = out.masked_fill(row_is_entirely_masked[..., None], 0.)

        return out

    def forward(
        self,
        q, k, v,
        mask=None
    ):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device

        scale = default(self.scale, q.shape[-1] ** -0.5)

        if self.flash:
            return self.flash_attn(q, k, v, mask=mask)

        kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

        sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale

        i, j, dtype = *sim.shape[-2:], sim.dtype

        mask_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            sim = sim.masked_fill(~mask, mask_value)

        if self.causal and n > 1:
            causal_mask = self.create_causal_mask(i, j, device=device)
            sim = sim.masked_fill(causal_mask, mask_value)

        attn = sim.softmax(dim=-1)
        attn = attn.type(dtype)

        attn = self.attn_dropout(attn)

        out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)

        return out
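
A minimal usage sketch of the module above (my own illustration; the shapes are an assumption about how the caller splits heads before invoking it):

import torch
from complex_valued_transformer.attend import Attend

attend = Attend(causal = True, flash = False)

q = torch.randn(1, 8, 128, 64)   # (batch, heads, seq_len, dim_head)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

out = attend(q, k, v)            # (1, 8, 128, 64)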

.\lucidrains\complex-valued-transformer\complex_valued_transformer\autoregressive_wrapper.py

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# identity function
def identity(t):
    return t

# decorator that runs the wrapped call in eval mode, then restores the model's previous training state
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# top k filtering

# keep only the top k logits (k derived from the threshold), filling the rest with the most negative value
def top_k(logits, thres = 0.9):
    # number of logits to keep
    k = int((1 - thres) * logits.shape[-1])
    # top k values and their indices
    val, ind = torch.topk(logits, k)
    # tensor filled with the most negative representable value
    probs = torch.full_like(logits, -torch.finfo(logits.dtype).max)
    # scatter the kept values back into place
    probs.scatter_(1, ind, val)
    return probs
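
# worked example (illustrative, not in the original file): with a 256-way vocabulary and thres = 0.9,
# k = int((1 - 0.9) * 256) = 25, so only the 25 highest logits survive; the rest are set to the most
# negative finite value and effectively receive zero probability after the softmax in generate()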

# autoregressive wrapper class
class AutoregressiveWrapper(nn.Module):
    def __init__(
        self,
        net,        
        seq_len,
        pad_value = 0,
        logits_fn = identity
    ):
        super().__init__()
        self.seq_len = seq_len
        self.pad_value = pad_value
        self.net = net
        self.logits_fn = logits_fn

    # generate a sequence autoregressively from a prompt
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        prompt,
        seq_len,
        temperature = 1.0,
        filter_thres = 0.9,
        **kwargs
    ):
        # prompt shape and device
        b, t, device = *prompt.shape, prompt.device

        out = prompt

        # sample one token at a time
        for _ in range(seq_len):
            # feed the last seq_len tokens and take the logits at the final position
            logits = self.net(out[:, -self.seq_len:], **kwargs)[:, -1]
            logits = self.logits_fn(logits)

            # filter to the top k logits
            filtered_logits = top_k(logits, thres = filter_thres)
            # temperature-scaled softmax
            probs = F.softmax(filtered_logits / temperature, dim = -1)

            # sample from the distribution
            sample = torch.multinomial(probs, 1)
            # append the sample to the running output
            out = torch.cat((out, sample), dim = -1)

        return out[:, t:]

    # training forward pass
    def forward(self, x, **kwargs):
        # split into inputs and next-token labels
        x, labels = x[:, :-1], x[:, 1:]
        # model logits
        logits = self.net(x, **kwargs)
        # move the class dimension into position for cross entropy
        logits = rearrange(self.logits_fn(logits), "b c n -> b n c")
        # cross entropy loss
        return F.cross_entropy(logits, labels)
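
A minimal training-style usage sketch of this wrapper (my own illustration; the tiny hyperparameters are arbitrary, and the same pattern appears in train.py further below):

import torch
from complex_valued_transformer.complex_valued_transformer import ComplexTransformer
from complex_valued_transformer.autoregressive_wrapper import AutoregressiveWrapper

net = ComplexTransformer(num_tokens = 256, dim = 64, depth = 1, causal = True, flash_attn = False)  # flash off to avoid the torch 2.0 requirement
wrapper = AutoregressiveWrapper(net, seq_len = 128, logits_fn = lambda logits: logits.real)

ids = torch.randint(0, 256, (1, 129))   # seq_len + 1, since forward shifts by one token for the labels
loss = wrapper(ids)                     # scalar cross entropy on the real component of the logits
loss.backward()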

.\lucidrains\complex-valued-transformer\complex_valued_transformer\complex_valued_transformer.py

from typing import Optional
from functools import partial

import torch
from torch import cfloat
import torch.nn.functional as F
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList

from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange

from complex_valued_transformer.attend import Attend

# helpers

# check whether a value exists (is not None)
def exists(v):
    return v is not None

# return v if it exists, otherwise the default d
def default(v, d):
    return v if exists(v) else d

# helper tensor functions

# modulate the input x by a rotation derived from m
def modulate_with_rotation(x, m):
    if m.dtype == cfloat:
        m = m.abs()

    rot = m.cos() + 1.j * m.sin()
    return x * rot

# complex attention
# https://arxiv.org/abs/2306.09827

# complex attention using only the real component (Eilers et al.)
def complex_attention_real(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attend: Attend,
    mask: Optional[Tensor] = None
):
    """
    section 4.1 equation 8
    """

    assert all([t.dtype == cfloat for t in (q, k, v)])
    q, k, v = map(torch.view_as_real, (q, k, v))
    q, k, v = map(lambda t: rearrange(t, '... d c -> ... (d c)'), (q, k, v))

    o = attend(q, k, v, mask = mask)

    o = rearrange(o, '... (d c) -> ... d c', c = 2)
    return torch.view_as_complex(o)
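
# note on why the flattening above suffices (illustrative, not in the original file): the real part of the
# Hermitian inner product equals the real dot product of the interleaved real/imag components, e.g.
#
# q, k = torch.randn(4, dtype = torch.cfloat), torch.randn(4, dtype = torch.cfloat)
# (q * k.conj()).sum().real == torch.view_as_real(q).flatten() @ torch.view_as_real(k).flatten()   # up to float error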

# complex attention - Yang et al
# https://arxiv.org/abs/1910.10202

# complete complex attention (Yang et al.)
def complex_attention_complete(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attend: Attend,
    mask: Optional[Tensor] = None
):
    """
    section 3.2 equation 3
    """
    batch, device = q.shape[0], q.device

    assert all([t.dtype == cfloat for t in (q, k, v)])
    q, k, v = map(torch.view_as_real, (q, k, v))

    # complex attention =    (MH(A, A, A) − MH(A, B, B) − MH(B, A, B) − MH(B, B, A))
    #                     + i(MH(A, A, B) + MH(A, B, A) + MH(B, A, A) − MH(B, B, B))

    q = repeat(q, 'b h n d c -> (c r b) h n d', r = 2)
    k = repeat(k, 'b h n d c -> (r c b) h n d', r = 2)
    v = repeat(v, 'b h n d c -> (r b) h n (d c)', r = 4)

    if exists(mask):
        mask = repeat(mask, 'b ... -> (r b) ...', r = 4)

    o = attend(q, k, v, mask = mask)

    o = rearrange(o, '(r b) ... (d c) -> (r c) b ... d', r = 4, c = 2)

    indices = torch.tensor([0, 3, 5, 6, 1, 2, 4, 7], dtype = torch.long, device = device)

    o = rearrange(o[indices], '(r c) ... -> ... c r', c = 2)

    sign = torch.tensor([
        [1., -1., -1., -1.],   # real component
        [1.,  1.,  1., -1.]    # imag component
    ], dtype = o.dtype, device = device)

    o = (o * sign).sum(dim = -1)

    return torch.view_as_complex(o)

# complex multihead attention

class ComplexMultiheadAttention(Module):
    def __init__(
        self,
        dim,
        *,
        causal = False,
        dim_head = 32,
        heads = 8,
        complete_complex = False, # whether to use complete complex formulation (Yang et al.) or just the real component, which reduces down to usual dot product on real and imaginary components flattened into the feature dimension
        flash = False
    ):
        super().__init__()
        dim_inner = heads * dim_head

        self.to_q = nn.Linear(dim, dim_inner, bias = False, dtype = cfloat)
        self.to_kv = nn.Linear(dim, dim_inner * 2, bias = False, dtype = cfloat)
        self.to_out = nn.Linear(dim_inner, dim, bias = False, dtype = cfloat)

        maybe_flash_attn = Attend(
            causal = causal,
            heads = heads,
            flash = flash
        )

        complex_attention = complex_attention_complete if complete_complex else complex_attention_real
        self.attend = partial(complex_attention, attend = maybe_flash_attn)

        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
        self.merge_heads = Rearrange('b h n d -> b n (h d)')

    def forward(
        self,
        x,
        context = None,
        mask = None,
        rotary_emb = None
        ):
        # check whether cross-attention context was passed in
        has_context = exists(context)
        # default the context to the input itself (self attention)
        context = default(context, x)

        # project the input to queries, and the context to keys and values
        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        # split out the attention heads
        q, k, v = map(self.split_heads, (q, k, v))

        # if rotary embeddings are given, rotate queries and keys by complex multiplication
        if exists(rotary_emb):
            q = q * rotary_emb
            k = k * rotary_emb

        # compute attention
        o = self.attend(q, k, v, mask = mask)

        # merge the heads back together
        o = self.merge_heads(o)
        # project back out
        return self.to_out(o)

# complex RMS norm
class ComplexRMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        # scale by the inverse square root of the dimension
        self.scale = dim ** -0.5
        # learnable complex-valued gain
        self.gamma = nn.Parameter(torch.ones(dim, dtype=cfloat))

    def forward(self, x):
        # l2-normalize over the feature dimension, then apply the gain and scale
        return F.normalize(x, dim=-1) * self.gamma * self.scale

# modReLU activation
class ModReLU(Module):
    def __init__(self, relu_squared=False):
        super().__init__()
        # optionally square the relu output
        self.pow = 2 if relu_squared else 1
        # learnable bias added to the modulus
        self.bias = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        # relu applied to the modulus plus bias, optionally squared
        real = F.relu(torch.abs(x) + self.bias) ** self.pow
        # unit-modulus phase factor of the input
        imag = torch.exp(1.j * torch.angle(x))
        # combine the two terms
        return real + imag

# complex feedforward: linear -> modReLU -> linear, all in the complex domain
def ComplexFeedForward(dim, mult=4, relu_squared=False):
    # inner (hidden) dimension
    dim_inner = dim * mult
    return nn.Sequential(
        nn.Linear(dim, dim_inner, dtype=cfloat),
        ModReLU(relu_squared=relu_squared),
        nn.Linear(dim_inner, dim, dtype=cfloat)
    )

# rotary embedding, formulated in the complex domain
class RotaryEmbedding(Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        # inverse frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    @property
    def device(self):
        return self.inv_freq.device

    def forward(self, seq_len):
        # outer product of positions and inverse frequencies, returned as complex rotations
        t = torch.arange(seq_len, device=self.device).type_as(self.inv_freq)
        freqs = einsum('i, j -> i j', t, self.inv_freq)
        return torch.cos(freqs) + 1.j * torch.sin(freqs)
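
# illustrative usage (not in the original file): positions become unit-modulus complex rotations that are
# applied to queries / keys by elementwise complex multiplication in ComplexMultiheadAttention above,
# so token magnitudes are preserved and only phases are shifted
#
# rot = RotaryEmbedding(dim = 32)(128)                       # (128, 32) complex, |rot| == 1
# q = torch.randn(1, 8, 128, 32, dtype = torch.cfloat)
# q_rotated = q * rot                                        # same magnitudes, position-dependent phases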

# complex transformer
class ComplexTransformer(Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        num_tokens: Optional[int] = None,
        causal=False,
        dim_head=32,
        heads=8,
        ff_mult=4,
        relu_squared=True,
        complete_complex=False,
        rotary_emb=True,
        flash_attn=True
    ):
        super().__init__()

        # whether a token embedding table is needed
        self.has_embed = exists(num_tokens)

        # complex-valued token embedding table
        if exists(num_tokens):
            self.embed = nn.Parameter(torch.randn((num_tokens, dim), dtype=cfloat))

        # optional rotary embedding
        self.rotary_emb = None
        if rotary_emb:
            self.rotary_emb = RotaryEmbedding(dim_head)

        # stack of (norm, attention, norm, feedforward) layers
        self.layers = ModuleList([])
        for _ in range(depth):
            self.layers.append(ModuleList([
                ComplexRMSNorm(dim),
                ComplexMultiheadAttention(dim=dim, dim_head=dim_head, heads=heads, causal=causal, complete_complex=complete_complex, flash=flash_attn),
                ComplexRMSNorm(dim),
                ComplexFeedForward(dim=dim, mult=ff_mult, relu_squared=relu_squared)
            ]))

        # final norm
        self.norm = ComplexRMSNorm(dim)

        # projection to logits
        self.to_logits = nn.Linear(dim, num_tokens, dtype=cfloat)

    # forward pass
    def forward(
        self,
        x,
        context=None,
        mask=None,
        return_abs_logits=False,
        return_real_logits=False
    ):
        # embed token ids if an embedding table exists
        if self.has_embed:
            x = self.embed[x]

        # sequence length
        seq_len = x.shape[-2]
        rotary_emb = None

        # compute rotary embeddings if enabled
        if exists(self.rotary_emb):
            rotary_emb = self.rotary_emb(seq_len)

        # attention and feedforward layers, each with a residual
        for attn_norm, attn, ff_norm, ff in self.layers:
            x = attn(attn_norm(x), context=context, mask=mask, rotary_emb=rotary_emb) + x
            x = ff(ff_norm(x)) + x

        # final norm
        x = self.norm(x)

        # without an embedding table, return the normalized features directly
        if not self.has_embed:
            return x

        # project to logits
        logits = self.to_logits(x)

        # choose how to return the complex logits: absolute value or real component
        assert (int(return_abs_logits) + int(return_real_logits)) <= 1
        if return_abs_logits:
            logits = logits.abs()
        elif return_real_logits:
            logits = logits.real

        return logits

.\lucidrains\complex-valued-transformer\complex_valued_transformer\__init__.py

# import the public functions and classes from complex_valued_transformer.complex_valued_transformer
from complex_valued_transformer.complex_valued_transformer import (
    ComplexMultiheadAttention,
    ComplexRMSNorm,
    ComplexFeedForward,
    ComplexTransformer,
    complex_attention_real,
    complex_attention_complete,
    modulate_with_rotation
)

Data source

The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/

Complex Valued Transformer

Implementation of the transformer proposed in Building Blocks for a Complex-Valued Transformer Architecture, plus a few other proposals from related papers. The full architecture will be evaluated on enwik8 character level language modeling as well as some algorithmic tasks (parity, binary addition).

Will not bother with complex layernorm, as RMS norm is now much more popular.

Update: It trains, seems to tolerate a much higher learning rate. Surprisingly stable, even when using softmax for complete complex formulation from Yang et al. This is likely because both papers are using the original transformer architecture with post-normalization instead of the recent pre-normalization.

Update 2: No difference between Eilers (just real component) vs Yang (real and imaginary) complex attention, at least for enwik8

Update 3: I am not seeing anything remarkable. YMMV

Install

$ pip install complex-valued-transformer

Usage

import torch
from complex_valued_transformer import ComplexTransformer

transformer = ComplexTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 4,
    dim_head = 32,
    heads = 8,
    causal = True,
    complete_complex = True
)

ids = torch.randint(0, 256, (2, 1024))

logits = transformer(ids) # (2, 1024, 256)

Todo

  • add rotary embeddings, formulated in complex domain

  • flash attention v1 compat

  • consider integrating with BS-RoFormer

  • craft a few algorithmic tasks, and explore layers that modulate rotations, see if giving that inductive bias makes a difference

Citations

@article{Eilers2023BuildingBF,
    title   = {Building Blocks for a Complex-Valued Transformer Architecture},
    author  = {Florian Eilers and Xiaoyi Jiang},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2306.09827},
    url     = {https://api.semanticscholar.org/CorpusID:258542729}
}
@article{Yang2019ComplexTA,
    title    = {Complex Transformer: A Framework for Modeling Complex-Valued Sequence},
    author   = {Muqiao Yang and Martin Q. Ma and Dongyu Li and Yao-Hung Hubert Tsai and Ruslan Salakhutdinov},
    journal  = {ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
    year     = {2019},
    pages    = {4232-4236},
    url      = {https://api.semanticscholar.org/CorpusID:204838137}
}
@article{Dong2021SignalTC,
    title   = {Signal Transformer: Complex-valued Attention and Meta-Learning for Signal Recognition},
    author  = {Yihong Dong and Ying Peng and Muqiao Yang and Songtao Lu and Qingjiang Shi},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2106.04392},
    url     = {https://api.semanticscholar.org/CorpusID:235367992}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@article{So2021PrimerSF,
    title   = {Primer: Searching for Efficient Transformers for Language Modeling},
    author  = {David R. So and Wojciech Ma'nke and Hanxiao Liu and Zihang Dai and Noam M. Shazeer and Quoc V. Le},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2109.08668},
    url     = {https://api.semanticscholar.org/CorpusID:237563187}
}

.\lucidrains\complex-valued-transformer\setup.py

# import setuptools helpers for packaging
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'complex-valued-transformer',  # package name
  packages = find_packages(exclude=[]),  # find all packages
  version = '0.0.14',  # version number
  license='MIT',  # license
  description = 'Complex Valued Transformer / Attention',  # short description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  long_description_content_type = 'text/markdown',  # content type of the long description
  url = 'https://github.com/lucidrains/complex-valued-transformer',  # project url
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'attention mechanisms',
    'transformers',
    'complex domain'
  ],
  install_requires=[  # install dependencies
    'einops>=0.7.0',
    'torch>=1.12'
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\complex-valued-transformer\train.py

# imports
import gzip
import random
import tqdm
import numpy as np

import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# project modules
from complex_valued_transformer.autoregressive_wrapper import AutoregressiveWrapper
from complex_valued_transformer.complex_valued_transformer import ComplexTransformer

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-1
VALIDATE_EVERY = 100
PRIME_LENGTH = 128
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 512

# helper functions

def cycle(loader):
    # yield batches from the dataloader forever
    while True:
        for data in loader:
            yield data

def decode_token(token):
    # decode a single token id into a printable character
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    # decode a sequence of token ids into a string
    return "".join(list(map(decode_token, tokens)))

# instantiate the transformer model

model = ComplexTransformer(
    num_tokens = 256,
    dim = 256,
    dim_head = 32,
    depth = 8,
    causal = True,
    complete_complex = True # setting this to True increases the compute of the multi-head attention (Yang et al.)
)

model = AutoregressiveWrapper(
    model,
    seq_len = SEQ_LEN,
    logits_fn = lambda logits: logits.real
).cuda()

# prepare enwik8 data
with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

# dataset
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# create the train / validation datasets and dataloaders
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# optimizer
optim = Adam(model.parameters(), lr = LEARNING_RATE)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model.train()

    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward(loss / GRADIENT_ACCUMULATE_EVERY)

    print(f"training loss: {loss.item()}")
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f"validation loss: {loss.item()}")

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        prime = decode_tokens(inp)
        print(f"%s \n\n %s", (prime, "*" * 100))

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str, "\n")

.\lucidrains\compositional-attention-pytorch\compositional_attention_pytorch\compositional_attention_pytorch.py

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange
from einops_exts import rearrange_many

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# numerically stable softmax
def stable_softmax(t, dim = -1):
    t = t - t.amax(dim = dim, keepdim = True).detach()
    return t.softmax(dim = dim)

# compositional attention module
class CompositionalAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        num_searches = 8,
        num_retrievals = 2,
        dropout = 0.,
        prenorm = False,
        causal = False
    ):
        super().__init__()
        # use LayerNorm if prenorm is requested, otherwise Identity
        self.norm = nn.LayerNorm(dim) if prenorm else nn.Identity()

        self.scale = dim_head ** -0.5
        inner_search_dim = dim_head * num_searches
        inner_retrieval_dim = dim_head * num_retrievals

        self.num_searches = num_searches
        self.num_retrievals = num_retrievals

        # project the input to search queries, search keys and retrieval values
        self.to_searches_queries = nn.Linear(dim, inner_search_dim, bias = False)
        self.to_searches_keys = nn.Linear(dim, inner_search_dim, bias = False)
        self.to_retrieval_values = nn.Linear(dim, inner_retrieval_dim, bias = False)

        # project the input to retrieval queries, and the retrieved values to retrieval keys
        self.to_retrieval_queries = nn.Linear(dim, inner_search_dim, bias = False)
        self.to_retrieval_keys = nn.Linear(dim_head, dim_head, bias = False)

        # project the concatenated search outputs back to the model dimension
        self.to_out = nn.Linear(inner_search_dim, dim, bias = False)

        self.search_dropout = nn.Dropout(dropout)
        self.retrieval_dropout = nn.Dropout(dropout)

        # whether to use an autoregressive (causal) variant - own experimentation
        self.causal = causal

    def forward(self, x, mask = None):
        """
        einstein notation:
        b - batch
        n - sequence dimension
        i - sequence dimension (source)
        j - sequence dimension (target, aggregation dimension)
        s - number of searches
        r - number of retrievals
        d - feature dimension
        """
        x = self.norm(x)

        s = self.num_searches
        r = self.num_retrievals

        # search queries and keys
        sq, sk = self.to_searches_queries(x), self.to_searches_keys(x)
        sq, sk = rearrange_many((sq, sk), 'b n (s d) -> b s n d', s = s)

        sq = sq * self.scale

        # search similarities and attention
        search_sim = einsum('b s i d, b s j d -> b s i j', sq, sk)

        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            search_sim = search_sim.masked_fill(~mask, -torch.finfo(search_sim.dtype).max)

        if self.causal:
            i, j = search_sim.shape[-2:]
            causal_mask = torch.ones((i, j), device = x.device, dtype = torch.bool).triu(j - i + 1)
            search_sim = search_sim.masked_fill(causal_mask, -torch.finfo(search_sim.dtype).max)

        search_attn = stable_softmax(search_sim, dim = -1)
        search_attn = self.search_dropout(search_attn)

        # retrieval values
        rv = self.to_retrieval_values(x)
        rv = rearrange(rv, 'b n (r d) -> b r n d', r = r)

        retrieved = einsum('b s i j, b r j d -> b s r i d', search_attn, rv)

        # retrieval queries and keys
        rq, rk = self.to_retrieval_queries(x), self.to_retrieval_keys(retrieved)
        rq = rearrange(rq, 'b n (s d) -> b s n d', s = s)
        rq = rq * self.scale

        # retrieval attention
        retrieval_sim = einsum('b s n d , b s r n d -> b s n r', rq, rk)

        retrieval_attn = stable_softmax(retrieval_sim, dim = -1)
        retrieval_attn = self.retrieval_dropout(retrieval_attn)

        # aggregate the retrievals
        out = einsum('b s n r, b s r n d -> b s n d', retrieval_attn, retrieved)

        # combine the searches and project out
        out = rearrange(out, 'b s n d -> b n (s d)')
        return self.to_out(out)

.\lucidrains\compositional-attention-pytorch\compositional_attention_pytorch\__init__.py

# import the CompositionalAttention class from the compositional_attention_pytorch package
from compositional_attention_pytorch.compositional_attention_pytorch import CompositionalAttention

Compositional Attention - Pytorch

Implementation of Compositional Attention from MILA. They reframe the "heads" of multi-head attention as "searches", and once the multi-headed/searched values are aggregated, there is an extra retrieval step (using attention) off the searched results. They then show this variant of attention yields better OOD results on a toy task. Their ESBN results still leave a lot to be desired, but I like the general direction of the paper.

Install

$ pip install compositional-attention-pytorch

Usage

import torch
from compositional_attention_pytorch import CompositionalAttention

attn = CompositionalAttention(
    dim = 1024,            # input dimension
    dim_head = 64,         # dimension per attention 'head' - head is now either search or retrieval
    num_searches = 8,      # number of searches
    num_retrievals = 2,    # number of retrievals
    dropout = 0.,          # dropout of attention of search and retrieval
)

tokens = torch.randn(1, 512, 1024)  # tokens
mask = torch.ones((1, 512)).bool()  # mask

out = attn(tokens, mask = mask) # (1, 512, 1024)

Citations

@article{Mittal2021CompositionalAD,
    title   = {Compositional Attention: Disentangling Search and Retrieval},
    author  = {Sarthak Mittal and Sharath Chandra Raparthy and Irina Rish and Yoshua Bengio and Guillaume Lajoie},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2110.09419}
}

.\lucidrains\compositional-attention-pytorch\setup.py

# import setuptools helpers for packaging
from setuptools import setup, find_packages

# package metadata
setup(
  # package name
  name = 'compositional-attention-pytorch',
  # find all packages, excluding none
  packages = find_packages(exclude=[]),
  # version number
  version = '0.0.1',
  # license
  license='MIT',
  # short description
  description = 'Compositional Attention - Pytorch',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # project url
  url = 'https://github.com/lucidrains/compositional-attention-pytorch',
  # keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'attention mechanism'
  ],
  # install dependencies
  install_requires=[
    'einops>=0.4',
    'einops-exts',
    'torch>=1.6',
  ],
  # classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\compressive-transformer-pytorch\compressive_transformer_pytorch\autoregressive_wrapper.py

import math
from functools import partial
from collections import namedtuple

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

# named tuple yielded during training: the loss, the auxiliary loss, and whether this is the last batch
Return = namedtuple('Return', ['loss', 'aux_loss', 'is_last_batch'])

# helper functions

# nucleus (top-p) filtering of logits by cumulative probability threshold
def top_p(logits, thres = 0.9):
    # sort logits in descending order
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    # cumulative probabilities
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # indices whose cumulative probability exceeds the threshold get removed
    sorted_indices_to_remove = cum_probs > (1 - thres)
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0

    # set the removed logits to negative infinity
    sorted_logits[sorted_indices_to_remove] = float('-inf')
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)

# keep only the top k logits (k derived from the threshold), filling the rest with -inf
def top_k(logits, thres = 0.9):
    # number of logits to keep
    k = int((1 - thres) * logits.shape[-1])
    # top k values and their indices
    val, ind = torch.topk(logits, k)
    # tensor filled with negative infinity
    probs = torch.full_like(logits, float('-inf'))
    # scatter the kept values back into place
    probs.scatter_(1, ind, val)
    return probs

# main class

# autoregressive wrapper, inheriting from nn.Module
class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, ignore_index = -100, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = ignore_index

        self.net = net
        self.seq_len = net.seq_len

    # 生成函数,用于生成序列
    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        # 保存网络是否处于训练状态
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        b, t = start_tokens.shape

        # 将网络设置为评估模式
        self.net.eval()

        out = start_tokens

        # 处理默认的masking

        full_mask_like = lambda x: torch.full_like(x, True, dtype=torch.bool, device=x.device)

        mask = kwargs.pop('mask', None)
        if mask is None:
            mask = full_mask_like(out)

        # 处理任意长度的primed序列

        mem = None
        *primes, out = out.split(self.seq_len, dim=1)
        *prime_masks, mask = mask.split(self.seq_len, dim=1)

        for prime, prime_mask in zip(primes, prime_masks):
            _, mem, _ = self.net(prime, memories = mem, mask = prime_mask, **kwargs)

        # 生成直到达到序列长度

        input_len = out.shape[1]

        for _ in range(seq_len):
            logits, mem, aux_loss = self.net(out[:, -input_len:], memories = mem, mask = mask[:, -input_len:], **kwargs)
            logits = logits[:, -1, :]
            filtered_logits = filter_logits_fn(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

            # 将采样得到的 token 追加到累积输出中,并同步扩展 mask

            out = torch.cat((out, sample), dim=-1)
            mask = F.pad(mask, (0, 1), value=True)

            # 不同于大多数模型,一旦填满完整序列长度,输入长度会回绕、从 1 重新开始累积(上下文由记忆保存)

            input_len = input_len % self.seq_len
            input_len += 1

            if eos_token is not None and (sample == eos_token).all():
                break

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        # 恢复网络训练状态
        self.net.train(was_training)
        return out
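
    # generate 的调用示意(注释示例,非原文内容;其中的模型与各超参数均为假设,仅说明调用方式):
    #
    #   model   = CompressiveTransformer(num_tokens = 20000, dim = 512, seq_len = 512, depth = 6)
    #   wrapper = AutoregressiveWrapper(model)
    #   prime   = torch.randint(0, 20000, (1, 16))                 # 起始 token,可短于 seq_len
    #   sample  = wrapper.generate(prime, seq_len = 256, eos_token = 1, filter_thres = 0.9)
    #   # sample 形状为 (1, <=256);当采样全部命中 eos_token 时会提前结束
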
    # 定义一个前向传播函数,接受输入 x,最大批处理大小 max_batch_size,默认不返回损失,**kwargs 为其他参数
    def forward(self, x, max_batch_size = None, return_loss = False, **kwargs):
        # 定义一个填充函数,将输入序列填充到相同长度
        pad = partial(pad_sequence, batch_first = True, padding_value = self.pad_value)

        # 如果不需要返回损失
        if not return_loss:
            # 如果输入不是张量,则进行填充
            if not isinstance(x, torch.Tensor):
                x = pad(x)
            # 返回网络输出结果
            return self.net(x, **kwargs)

        # 如果需要返回损失
        if isinstance(x, torch.Tensor):
            # 将输入序列拆分为输入和输出序列
            xi = x[:, :-1]
            xo = x[:, 1:]
        else:
            # 对输入序列进行填充和拆分
            xi = pad(list(map(lambda t: t[:-1], x)))
            xo = pad(list(map(lambda t: t[1:], x)))

        # 处理输入掩码,解决自回归模型中输入掩码与源序列长度不匹配的问题
        mask = kwargs.pop('mask', None)
        if mask is not None and mask.shape[1] == x.shape[1]:
            mask = mask[:, :-1]

        # 定义一个函数,用于将序列分段
        segment_fn = lambda x: x.split(self.seq_len, dim=1)
        # 将输入和输出序列分段
        (xi, xo) = map(segment_fn, (xi, xo))

        # 获取序列段数
        num_segments = len(xi)
        # 如果存在掩码,则对掩码进行分段处理
        mask = segment_fn(mask) if mask is not None else ((None,) * num_segments)

        # 如果未指定最大批处理大小,则设为输入的 batch 大小(即不再拆分子批次)
        max_batch_size = x.shape[0] if max_batch_size is None else max_batch_size
        # 定义一个函数,用于将序列按照最大批处理大小分割
        split_batch_fn = lambda x: x.split(max_batch_size, dim=0)

        # 计算梯度累积次数
        grad_accumulate_every = math.ceil(x.shape[0] / max_batch_size)
        # 初始化记忆列表
        mems = [None] * grad_accumulate_every

        # 遍历每个序列段
        for xi_seg, xo_seg, mask_seg in zip(xi, xo, mask):
            # 将输入和输出序列按照最大批处理大小分割
            xi_seg, xo_seg = map(split_batch_fn, (xi_seg, xo_seg))
            mask_seg = split_batch_fn(mask_seg) if mask_seg is not None else ((None,) * grad_accumulate_every)

            new_mems = []
            # 遍历每个分割后的序列段
            for ind, (xi_seg_b, xo_seg_b, mask_seg_b, mem) in enumerate(zip(xi_seg, xo_seg, mask_seg, mems)):
                is_last = ind == (grad_accumulate_every - 1)

                # 获取网络输出结果、新记忆和辅助损失
                logits, new_mem, aux_loss = self.net(xi_seg_b, mask = mask_seg_b, memories = mem, **kwargs)
                new_mems.append(new_mem)

                # 计算交叉熵损失
                loss = F.cross_entropy(logits.transpose(1, 2), xo_seg_b, ignore_index = self.ignore_index)
                # 返回损失、辅助损失和是否为最后一个序列段的标志
                yield Return(loss, aux_loss, is_last)

            mems = new_mems
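
# forward 以生成器方式逐段返回 (loss, aux_loss, is_last),便于在段和子批次上做梯度累积。
# 下面是一个注释形式的训练循环示意(非原文件内容,wrapper、seq、optim 等名称均为假设):
#
#   for loss, aux_loss, is_last in wrapper(seq, max_batch_size = 4, return_loss = True):
#       (loss + aux_loss).backward()      # 每个 (序列段, 子批次) 各自反向传播
#       if is_last:                       # 同一序列段内的最后一个子批次
#           optim.step()
#           optim.zero_grad()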

.\lucidrains\compressive-transformer-pytorch\compressive_transformer_pytorch\compressive_transformer_pytorch.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 torch.nn 模块中导入 F 函数
import torch.nn.functional as F

# 从 mogrifier 模块中导入 Mogrifier 类
from mogrifier import Mogrifier

# 导入 math 库
import math
# 从 collections 模块中导入 namedtuple 类
from collections import namedtuple
# 从 functools 模块中导入 partial 函数
from functools import partial
# 从 inspect 模块中导入 isfunction 函数(下方的 default 辅助函数会用到)
from inspect import isfunction

# 定义 Memory 命名元组
Memory = namedtuple('Memory', ['mem', 'compressed_mem'])

# 辅助函数

# 定义 to 函数,返回包含数据类型和设备信息的字典
def to(t):
    return {'dtype': t.dtype, 'device': t.device}

# 定义 cast_tuple 函数,将元素转换为元组
def cast_tuple(el):
    return el if isinstance(el, tuple) else (el,)

# 定义 default 函数,如果 x 不为 None,则返回 x,否则返回 val 或 val() 的结果
def default(x, val):
    if x is not None:
        return x
    return val if not isfunction(val) else val()

# 定义 max_neg_value 函数,返回给定张量的最大负值
def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max

# 定义 reshape_dim 函数,根据给定维度和分割维度对张量进行重塑
def reshape_dim(t, dim, split_dims):
    shape = list(t.shape)
    num_dims = len(shape)
    dim = (dim + num_dims) % num_dims
    shape[dim:dim+1] = split_dims
    return t.reshape(shape)

# 定义 split_at_index 函数,根据给定维度和索引将张量分割成两部分
def split_at_index(dim, index, t):
    pre_slices = (slice(None),) * dim
    l = (*pre_slices, slice(None, index))
    r = (*pre_slices, slice(index, None))
    return t[l], t[r]

# 定义 queue_fifo 函数,实现先进先出队列操作
def queue_fifo(*args, length, dim=-2):
    queue = torch.cat(args, dim=dim)
    if length > 0:
        return split_at_index(dim, -length, queue)

    device = queue.device
    shape = list(queue.shape)
    shape[dim] = 0
    return queue, torch.empty(shape, device=device)
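
# queue_fifo 的行为示意(非原文内容):把新输入追加到记忆队列尾部并按 length 截断,
# 返回 (被挤出的最旧部分, 保留下来的新队列)。
mem_demo = torch.zeros(1, 3, 2)                    # 旧记忆,长度 3
x_demo = torch.ones(1, 2, 2)                       # 新输入,长度 2
old_demo, new_demo = queue_fifo(mem_demo, x_demo, length = 4, dim = 1)
print(old_demo.shape, new_demo.shape)              # -> torch.Size([1, 1, 2]) torch.Size([1, 4, 2])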

# 定义 shift 函数,实现张量的位移操作
def shift(x):
    *_, i, j = x.shape
    zero_pad = torch.zeros((*_, i, i), **to(x))
    x = torch.cat([x, zero_pad], -1)
    l = i + j - 1
    x = x.view(*_, -1)
    zero_pad = torch.zeros(*_, -x.size(-1) % l, **to(x))
    shifted = torch.cat([x, zero_pad], -1).view(*_, -1, l)
    return shifted[..., :i, i - 1:]
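
# shift 的行为示意(非原文内容):这是相对位置编码里常用的错位对齐技巧,
# 把每一行的相对位置得分与对应的 query 位置对齐,输出形状与输入保持一致。
pos_dots_demo = torch.randn(1, 2, 4, 10)           # (batch, heads, i = 4, j = 10)
print(shift(pos_dots_demo).shape)                  # -> torch.Size([1, 2, 4, 10])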

# 定义 iterate_tensor 函数,实现对张量的迭代操作
def iterate_tensor(t):
    length = t.shape[0]
    for ind in range(length):
        yield t[ind]

# full attention 用于计算辅助重构损失

# 定义 full_attn 函数,实现全连接注意力机制
def full_attn(q, k, v, dropout_fn=None):
    *_, dim = q.shape
    dots = torch.einsum('bhid,bhjd->bhij', q, k) * (dim ** -0.5)
    attn = dots.softmax(dim=-1)
    if dropout_fn is not None:
        attn = dropout_fn(attn)
    return torch.einsum('bhij,bhjd->bhid', attn, v)
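
# full_attn 的形状检查示意(非原文内容):标准的缩放点积注意力,
# q 为 (batch, heads, i, dim_head),k / v 为 (batch, heads, j, dim_head),输出与 q 同形状。
q_demo = torch.randn(1, 2, 4, 8)
k_demo = torch.randn(1, 2, 6, 8)
v_demo = torch.randn(1, 2, 6, 8)
print(full_attn(q_demo, k_demo, v_demo).shape)     # -> torch.Size([1, 2, 4, 8])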

# 辅助类

# 定义 Residual 类,实现残差连接
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        out = cast_tuple(out)
        ret = (out[0] + x), *out[1:]
        return ret

# 定义 GRUGating 类,实现 GRU 门控机制
class GRUGating(nn.Module):
    def __init__(self, dim, fn, mogrify=False):
        super().__init__()
        self.dim = dim
        self.fn = fn
        self.gru = nn.GRUCell(dim, dim)
        self.mogrify = Mogrifier(dim, factorize_k=dim // 4) if mogrify else None

    def forward(self, x, **kwargs):
        batch, dim = x.shape[0], self.dim
        out = self.fn(x, **kwargs)
        (y, *rest) = cast_tuple(out)

        if self.mogrify is not None:
            y, x = self.mogrify(y, x)

        gated_output = self.gru(
            y.reshape(-1, dim),
            x.reshape(-1, dim)
        )

        gated_output = gated_output.reshape(batch, -1, dim)
        ret = gated_output, *rest
        return ret

# 定义 PreNorm 类,实现预层归一化
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

# 定义 ConvCompress 类,实现卷积压缩
class ConvCompress(nn.Module):
    def __init__(self, dim, ratio=4):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, ratio, stride=ratio)

    def forward(self, mem):
        mem = mem.transpose(1, 2)
        compressed_mem = self.conv(mem)
        return compressed_mem.transpose(1, 2)
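
# ConvCompress 的压缩效果示意(非原文内容):步长为 ratio 的一维卷积,
# 会把长度为 L 的记忆压缩到 L // ratio。
compress_demo = ConvCompress(dim = 16, ratio = 4)
mem_demo = torch.randn(2, 8, 16)                   # (batch, mem_len = 8, dim)
print(compress_demo(mem_demo).shape)               # -> torch.Size([2, 2, 16])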

# feedforward

# 定义 GELU_ 类,实现 GELU 激活函数
class GELU_(nn.Module):
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

# 如果 nn 模块中存在 GELU 函数,则使用 nn.GELU,否则使用 GELU_ 类
GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_

# 定义 FeedForward 类
class FeedForward(nn.Module):
    # 初始化神经网络模块,设置输入维度、倍数、dropout率、激活函数和是否使用GLU
    def __init__(self, dim, mult = 4, dropout = 0., activation = None, glu = False):
        # 调用父类的初始化方法
        super().__init__()
        # 设置默认激活函数为GELU
        activation = default(activation, GELU)

        # 是否使用GLU
        self.glu = glu
        # 第一层线性变换,输入维度为dim,输出维度为dim * mult * (2 if glu else 1)
        self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1))
        # 激活函数
        self.act = activation()
        # dropout层
        self.dropout = nn.Dropout(dropout)
        # 第二层线性变换,输入维度为dim * mult,输出维度为dim
        self.w2 = nn.Linear(dim * mult, dim)

    # 前向传播函数
    def forward(self, x, **kwargs):
        # 如果不使用GLU
        if not self.glu:
            # 第一层线性变换
            x = self.w1(x)
            # 激活函数
            x = self.act(x)
        else:
            # 使用GLU
            # 将第一层线性变换的输出分成两部分
            x, v = self.w1(x).chunk(2, dim=-1)
            # 激活函数作用在其中一部分上,另一部分保持不变
            x = self.act(x) * v

        # dropout层
        x = self.dropout(x)
        # 第二层线性变换
        x = self.w2(x)
        # 返回结果
        return x
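
# FeedForward 的示意用法(非原文内容):开启 glu 时第一层线性输出维度翻倍、再做门控,
# 但模块整体的输入输出形状保持 (batch, seq, dim) 不变。
ff_demo = FeedForward(dim = 32, mult = 4, glu = True)
print(ff_demo(torch.randn(2, 10, 32)).shape)       # -> torch.Size([2, 10, 32])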
# 定义 SelfAttention 类,继承自 nn.Module
class SelfAttention(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(self, dim, seq_len, mem_len, cmem_len, cmem_ratio = 4, heads = 8, attn_dropout = 0., dropout = 0., reconstruction_attn_dropout = 0.):
        super().__init__()
        # 断言确保维度能够被头数整除
        assert (dim % heads) == 0, 'dimension must be divisible by the number of heads'

        # 初始化各个参数
        self.heads = heads
        self.dim_head = dim // heads
        self.seq_len = seq_len
        self.mem_len = mem_len
        self.cmem_len = cmem_len
        self.cmem_ratio = cmem_ratio
        self.scale = self.dim_head ** (-0.5)

        # 创建 ConvCompress 对象,用于压缩记忆
        self.compress_mem_fn = ConvCompress(dim, cmem_ratio)

        # 创建线性层,用于计算查询、键和值
        self.to_q = nn.Linear(dim, dim, bias = False)
        self.to_kv = nn.Linear(dim, dim * 2, bias = False)
        self.to_out = nn.Linear(dim, dim)

        # 创建 Dropout 层,用于注意力机制的 dropout 和整体的 dropout
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.dropout = nn.Dropout(dropout)

        # 创建 Dropout 层,用于重构注意力机制的 dropout
        self.reconstruction_attn_dropout = nn.Dropout(reconstruction_attn_dropout)
    # 定义前向传播函数,接受输入 x 和一些可选参数
    def forward(self, x, memories = None, pos_emb = None, input_mask = None, calc_memory = True, **kwargs):
        # 获取输入 x 的形状信息
        b, t, e, h, dim_h = *x.shape, self.heads, self.dim_head

        # 初始化记忆
        memories = default(memories, (None, None))
        mem, cmem = memories

        # 初始化空的记忆
        init_empty_mem = lambda: torch.empty(b, 0, e, **to(x))
        mem = default(mem, init_empty_mem)
        cmem = default(cmem, init_empty_mem)

        # 获取记忆的长度
        mem_len = mem.shape[1]
        cmem_len = cmem.shape[1]

        # 计算查询向量 q
        q = self.to_q(x)

        # 将记忆和输入 x 连接起来,获取键值对 k, v
        kv_input = torch.cat((cmem, mem, x), dim=1)
        kv_len = kv_input.shape[1]
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        # 合并多头注意力的维度
        merge_heads = lambda x: reshape_dim(x, -1, (-1, dim_h)).transpose(1, 2)
        q, k, v = map(merge_heads, (q, k, v))

        # 扩展键值对 k, v 的维度
        k, v = map(lambda x: x.expand(-1, h, -1, -1), (k, v))

        # 计算点积注意力
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        mask_value = max_neg_value(dots)

        # 添加位置编码
        if pos_emb is not None:
            pos_emb = pos_emb[:, -kv_len:].type(q.dtype)
            pos_dots = torch.einsum('bhid,hjd->bhij', q, pos_emb) * self.scale
            pos_dots = shift(pos_dots)
            dots = dots + pos_dots

        # 添加输入掩码
        if input_mask is not None:
            mask = input_mask[:, None, :, None] * input_mask[:, None, None, :]
            mask = F.pad(mask, (mem_len + cmem_len, 0), value = True)
            dots.masked_fill_(~mask, mask_value)

        # 创建掩码矩阵
        total_mem_len = mem_len + cmem_len
        mask = torch.ones(t, t + total_mem_len, **to(x)).triu_(diagonal = 1 + total_mem_len).bool()
        dots.masked_fill_(mask[None, None, ...], mask_value)

        # 计算注意力权重
        attn = dots.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # 计算输出
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = out.transpose(1, 2).reshape(b, t, -1)
        logits = self.to_out(out)
        logits = self.dropout(logits)

        # 复制记忆
        new_mem = mem
        new_cmem = cmem
        aux_loss = torch.zeros(1, requires_grad = True, **to(q))

        # 如果序列长度小于设定值或者不需要计算记忆,则直接返回结果
        if self.seq_len > t or not calc_memory:
            return logits, Memory(new_mem, new_cmem), aux_loss

        # 计算记忆和压缩记忆
        old_mem, new_mem = queue_fifo(mem, x, length = self.mem_len, dim = 1)
        old_mem_padding = old_mem.shape[1] % self.cmem_ratio

        # 对旧记忆进行填充
        if old_mem_padding != 0:
            old_mem = F.pad(old_mem, (0, 0, old_mem_padding, 0), value = 0.)

        # 如果旧记忆为空或者压缩记忆长度小于等于0,则直接返回结果
        if old_mem.shape[1] == 0 or self.cmem_len <= 0:
            return logits, Memory(new_mem, new_cmem), aux_loss

        # 压缩记忆
        compressed_mem = self.compress_mem_fn(old_mem.detach())
        old_cmem, new_cmem = split_at_index(1, -self.cmem_len, torch.cat((cmem, compressed_mem), dim=1))

        # 如果不处于训练状态,则直接返回结果
        if not self.training:
            return logits, Memory(new_mem, new_cmem), aux_loss

        # 计算训练时的压缩记忆辅助损失
        self.to_kv.weight.detach_()

        cmem_k, cmem_v = self.to_kv(compressed_mem).chunk(2, dim=-1)
        cmem_k, cmem_v = map(merge_heads, (cmem_k, cmem_v))
        cmem_k, cmem_v = map(lambda x: x.expand(-1, h, -1, -1), (cmem_k, cmem_v))

        old_mem_range = slice(- min(mem_len, self.mem_len) - self.seq_len, -self.seq_len)
        old_mem_k, old_mem_v = map(lambda x: x[:, :, old_mem_range].clone(), (k, v))

        q, old_mem_k, old_mem_v = map(torch.detach, (q, old_mem_k, old_mem_v))

        attn_fn = partial(full_attn, dropout_fn = self.reconstruction_attn_dropout)

        aux_loss = F.mse_loss(
            attn_fn(q, old_mem_k, old_mem_v),
            attn_fn(q, cmem_k, cmem_v)
        )

        return logits, Memory(new_mem, new_cmem), aux_loss
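
# SelfAttention 单层的示意用法(非原文内容):不传记忆时内部会初始化空记忆,
# 返回 (输出, Memory(mem, cmem), 重构辅助损失);本例中旧记忆为空,因此 cmem 仍为空。
attn_demo = SelfAttention(dim = 64, seq_len = 16, mem_len = 16, cmem_len = 4, heads = 4)
out_demo, (mem_demo, cmem_demo), aux_demo = attn_demo(torch.randn(2, 16, 64))
print(out_demo.shape, mem_demo.shape, cmem_demo.shape)   # 形状分别为 (2, 16, 64)、(2, 16, 64)、(2, 0, 64)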
# 定义一个压缩变换器类,继承自 nn.Module
class CompressiveTransformer(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(
        self,
        num_tokens,  # 标记的数量
        dim,  # 维度
        seq_len,  # 序列长度
        depth,  # 深度
        emb_dim = None,  # 嵌入维度,默认为 None
        memory_layers = None,  # 记忆层,默认为 None
        enhanced_recurrence = True,  # 增强循环,默认为 True
        mem_len = None,  # 记忆长度,默认为 None
        cmem_len = None,  # 压缩记忆长度,默认为 None
        cmem_ratio = 4,  # 压缩记忆比率,默认为 4
        heads = 8,  # 头数,默认为 8
        gru_gated_residual = True,  # GRU 门控残差,默认为 True
        mogrify_gru = False,  # Mogrify GRU,默认为 False
        attn_dropout = 0.,  # 注意力丢弃率,默认为 0
        ff_glu = False,  # FeedForward GLU,默认为 False
        ff_dropout = 0.,  # FeedForward 丢弃率,默认为 0
        attn_layer_dropout = 0.,  # 注意力层丢弃率,默认为 0
        reconstruction_attn_dropout = 0.,  # 重构注意力丢弃率,默认为 0
        reconstruction_loss_weight = 1.  # 重构损失权重,默认为 1
    ):
        super().__init__()  # 调用父类的初始化函数
        emb_dim = default(emb_dim, dim)  # 如果嵌入维度为 None,则使用维度
        mem_len = default(mem_len, seq_len)  # 如果记忆长度为 None,则使用序列长度
        cmem_len = default(cmem_len, mem_len // cmem_ratio)  # 如果压缩记忆长度为 None,则使用记忆长度除以压缩比率
        memory_layers = default(memory_layers, list(range(1, depth + 1)))  # 如果记忆层为 None,则使用范围为 1 到深度的列表

        assert mem_len >= seq_len, 'length of memory should be at least the sequence length'  # 断言记忆长度至少应该等于序列长度
        assert cmem_len >= (mem_len // cmem_ratio), f'length of compressed memory should be at least the memory length divided by the compression ratio {int(mem_len // cmem_ratio)}'  # 断言压缩记忆长度至少应该等于记忆长度除以压缩比率
        assert all([layer > 0 and layer <= depth for layer in memory_layers]), 'one of the indicated memory layers is invalid'  # 断言所有指定的记忆层都在有效范围内

        self.seq_len = seq_len  # 保存序列长度

        self.depth = depth  # 保存深度
        self.memory_layers = list(memory_layers)  # 保存记忆层列表
        self.enhanced_recurrence = enhanced_recurrence  # 保存增强循环标志

        self.token_emb = nn.Embedding(num_tokens, emb_dim)  # 创建标记嵌入层
        self.to_model_dim = nn.Identity() if emb_dim == dim else nn.Linear(emb_dim, dim)  # 如果嵌入维度等于维度,则使用恒等映射,否则使用线性映射

        seq_and_mem_len = seq_len + mem_len + cmem_len  # 计算序列和记忆长度之和
        self.pos_emb = nn.Parameter(torch.zeros(heads, seq_and_mem_len, dim // heads))  # 创建位置嵌入参数

        self.to_logits = nn.Sequential(
            nn.Identity() if emb_dim == dim else nn.Linear(dim, emb_dim),  # 如果嵌入维度等于维度,则使用恒等映射,否则使用线性映射
            nn.Linear(emb_dim, num_tokens)  # 线性映射到标记数量
        )

        wrapper = partial(GRUGating, dim, mogrify = mogrify_gru) if gru_gated_residual else Residual  # 根据 GRU 门控残差标志选择包装器

        self.attn_layers = nn.ModuleList([wrapper(PreNorm(dim, SelfAttention(dim, seq_len, mem_len, cmem_len, cmem_ratio, heads, dropout = attn_layer_dropout, attn_dropout = attn_dropout, reconstruction_attn_dropout = reconstruction_attn_dropout))) for _ in range(depth)])  # 创建注意力层列表
        self.ff_layers = nn.ModuleList([wrapper(PreNorm(dim, FeedForward(dim, dropout = ff_dropout, glu = ff_glu))) for _ in range(depth)])  # 创建前馈层列表

        self.reconstruction_loss_weight = reconstruction_loss_weight  # 保存重构损失权重
    # 前向传播函数,接受输入 x,记忆 memories 和掩码 mask
    def forward(self, x, memories = None, mask = None):
        # 对输入进行 token embedding
        x = self.token_emb(x)
        # 调整输入维度到模型维度
        x = self.to_model_dim(x)
        b, t, d = x.shape

        # 断言输入序列长度不超过指定的最大序列长度
        assert t <= self.seq_len, f'input contains a sequence length {t} that is greater than the designated maximum sequence length {self.seq_len}'

        # 初始化记忆
        memories = default(memories, (None, None))
        mem, cmem = memories

        num_memory_layers = len(self.memory_layers)
        # 初始化空记忆
        init_empty_mem = lambda: torch.empty(num_memory_layers, b, 0, d, **to(x))
        mem = default(mem, init_empty_mem)
        cmem = default(cmem, init_empty_mem)

        total_len = mem.shape[2] + cmem.shape[2] + self.seq_len
        # 获取位置编码
        pos_emb = self.pos_emb[:, (self.seq_len - t):total_len]

        next_mem = []
        next_cmem = []
        aux_loss = torch.tensor(0., requires_grad = True, **to(x))

        # 如果启用增强循环
        if self.enhanced_recurrence:
            mem = torch.roll(mem, -1, 0)
            cmem = torch.roll(cmem, -1, 0)

        # 迭代记忆
        mem_iter, cmem_iter = map(iterate_tensor, (mem, cmem))

        # 遍历注意力层和前馈层
        for ind, (attn, ff) in enumerate(zip(self.attn_layers, self.ff_layers)):
            layer_num = ind + 1

            use_memory = layer_num in self.memory_layers
            memories = (next(mem_iter), next(cmem_iter)) if use_memory else None

            # 执行注意力机制和前馈网络
            x, (mem_out, cmem_out), layer_aux_loss = attn(x, memories = memories, calc_memory = use_memory, input_mask = mask, pos_emb = pos_emb)
            x,  = ff(x)

            aux_loss = aux_loss + layer_aux_loss

            # 如果不使用记忆,则跳过
            if not use_memory:
                continue

            next_mem.append(mem_out)
            next_cmem.append(cmem_out)

        # 获取输出结果
        out = self.to_logits(x)

        # 将下一步记忆和压缩记忆堆叠并分离梯度
        next_mem, next_cmem = map(torch.stack, (next_mem, next_cmem))
        next_mem, next_cmem = map(torch.detach, (next_mem, next_cmem))

        # 计算辅助损失
        aux_loss = aux_loss * self.reconstruction_loss_weight / num_memory_layers
        # 返回输出、记忆和辅助损失
        return out, Memory(mem = next_mem, compressed_mem = next_cmem), aux_loss
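
# 端到端的使用示意(非原文内容):把上面两个文件组合起来,训练一步并采样。
# 导入路径按文档中给出的文件路径推断,具体以该仓库的 __init__.py / README 为准。
import torch
from compressive_transformer_pytorch.compressive_transformer_pytorch import CompressiveTransformer
from compressive_transformer_pytorch.autoregressive_wrapper import AutoregressiveWrapper

model = CompressiveTransformer(
    num_tokens = 256,        # 词表大小
    dim = 128,
    seq_len = 64,
    depth = 2,
    mem_len = 64,            # 记忆长度,默认与 seq_len 相同
    heads = 4
)
wrapper = AutoregressiveWrapper(model)
optim = torch.optim.Adam(wrapper.parameters(), lr = 3e-4)

x = torch.randint(0, 256, (2, 129))                # 128 个训练 token,内部会被拆成两个 seq_len 片段
for loss, aux_loss, is_last in wrapper(x, return_loss = True):
    (loss + aux_loss).backward()                   # 语言模型损失 + 压缩记忆重构辅助损失
    if is_last:
        optim.step()
        optim.zero_grad()

sample = wrapper.generate(x[:, :8], seq_len = 32)  # 以前 8 个 token 为 prime 继续生成 32 个 token
print(sample.shape)                                # -> torch.Size([2, 32])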