Lucidrains-系列项目源码解析-三十四-

179 阅读21分钟

Lucidrains 系列项目源码解析(三十四)

.\lucidrains\equiformer-pytorch\setup.py

# Import setup and package-discovery helpers
from setuptools import setup, find_packages

# Execute the version file so __version__ is defined in this scope
exec(open('equiformer_pytorch/version.py').read())

# Package metadata
setup(
  name = 'equiformer-pytorch',  # package name
  packages = find_packages(exclude=[]),  # discover all packages
  version = __version__,  # version string loaded above
  license='MIT',  # license
  description = 'Equiformer - SE3/E3 Graph Attention Transformer for Molecules and Proteins',  # short description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  long_description_content_type = 'text/markdown',  # long-description format
  url = 'https://github.com/lucidrains/equiformer-pytorch',  # project URL
  keywords = [  # search keywords
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'equivariance',
    'molecules',
    'proteins'
  ],
  install_requires=[  # runtime dependencies
    'beartype',
    'einops>=0.6',
    'einx',
    'filelock',
    'opt-einsum',
    'taylor-series-linear-attention>=0.1.4',
    'torch>=1.6',
  ],
  setup_requires=[  # build-time dependencies
    'pytest-runner',
  ],
  tests_require=[  # test-only dependencies
    'pytest'
  ],
  include_package_data = True,  # include non-code package data
  classifiers=[  # PyPI classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\equiformer-pytorch\tests\test_edges.py

# 导入 pytest 库
import pytest

# 导入 torch 库
import torch
# 从 equiformer_pytorch 包中导入 Equiformer 类
from equiformer_pytorch.equiformer_pytorch import Equiformer
# 从 equiformer_pytorch 包中导入 rot 函数
from equiformer_pytorch.irr_repr import rot
# 从 equiformer_pytorch 包中导入 torch_default_dtype 函数
from equiformer_pytorch.utils import torch_default_dtype

# test equivariance when edge (bond) features are supplied

@pytest.mark.parametrize('l2_dist_attention', [True, False])
@pytest.mark.parametrize('reversible', [True, False])
def test_edges_equivariance(
    l2_dist_attention,
    reversible
):
    """Rotating the coordinates before the model must equal rotating after."""
    model = Equiformer(
        num_tokens = 28,
        dim = 64,
        num_edge_tokens = 4,
        edge_dim = 16,
        depth = 2,
        input_degrees = 1,
        num_degrees = 3,
        l2_dist_attention = l2_dist_attention,
        reversible = reversible,
        init_out_zero = False,
        reduce_dim_out = True
    )

    # random atoms, bonds, coordinates, and a full (all-True) mask
    atoms = torch.randint(0, 28, (2, 32))
    bonds = torch.randint(0, 4, (2, 32, 32))
    coors = torch.randn(2, 32, 3)
    mask  = torch.ones(2, 32).bool()

    # random rotation applied to input coords vs. to the type-1 output
    R = rot(*torch.randn(3))
    rotate_before = model(atoms, coors @ R, mask, edges = bonds)[1]
    rotate_after  = model(atoms, coors, mask, edges = bonds)[1] @ R

    assert torch.allclose(rotate_before, rotate_after, atol = 1e-4), 'is not equivariant'

# test equivariance when a sparse adjacency matrix drives attention

@pytest.mark.parametrize('l2_dist_attention', [True, False])
@pytest.mark.parametrize('reversible', [True, False])
def test_adj_mat_equivariance(
    l2_dist_attention,
    reversible
):
    """Rotating the coordinates before the model must equal rotating after."""
    model = Equiformer(
        dim = 32,
        heads = 8,
        depth = 1,
        dim_head = 64,
        num_degrees = 2,
        valid_radius = 10,
        l2_dist_attention = l2_dist_attention,
        reversible = reversible,
        attend_sparse_neighbors = True,
        num_neighbors = 0,
        num_adj_degrees_embed = 2,
        max_sparse_neighbors = 8,
        init_out_zero = False,
        reduce_dim_out = True
    )

    # random features, coordinates, and a full mask
    feats = torch.randn(1, 128, 32)
    coors = torch.randn(1, 128, 3)
    mask  = torch.ones(1, 128).bool()

    # banded adjacency: each node connects to its immediate neighbors
    idx = torch.arange(128)
    adj_mat = (idx[:, None] <= (idx[None, :] + 1)) & (idx[:, None] >= (idx[None, :] - 1))

    # random rotation applied to input coords vs. to the type-1 output
    R = rot(*torch.randn(3))
    rotate_before = model(feats, coors @ R, mask, adj_mat = adj_mat)[1]
    rotate_after  = model(feats, coors, mask, adj_mat = adj_mat)[1] @ R

    assert torch.allclose(rotate_before, rotate_after, atol = 1e-4), 'is not equivariant'

.\lucidrains\equiformer-pytorch\tests\test_equivariance.py

# 导入 pytest 库
import pytest

# 导入 torch 库
import torch
# 导入 Equiformer 类
from equiformer_pytorch.equiformer_pytorch import Equiformer
# 导入 rot 函数
from equiformer_pytorch.irr_repr import rot

# 导入 utils 模块中的函数
from equiformer_pytorch.utils import (
    torch_default_dtype,
    cast_tuple,
    to_order,
    exists
)

# smoke test for output shape

@pytest.mark.parametrize('dim', [32])
def test_transformer(dim):
    """The invariant (type-0) output must keep the input feature shape."""
    model = Equiformer(
        dim = dim,
        depth = 2,
        num_degrees = 3,
        init_out_zero = False
    )

    # random features, coordinates, and a full mask
    feats = torch.randn(1, 32, dim)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    invariant_out, _ = model(feats, coors, mask)
    assert invariant_out.shape == (1, 32, dim), 'output must be of the right shape'

# test equivariance across input-degree configurations

@pytest.mark.parametrize('dim', [32, (4, 8, 16)])
@pytest.mark.parametrize('dim_in', [32, (32, 32)])
@pytest.mark.parametrize('l2_dist_attention', [True, False])
@pytest.mark.parametrize('reversible', [True, False])
def test_equivariance(
    dim,
    dim_in,
    l2_dist_attention,
    reversible
):
    """Rotating type-1 inputs and coordinates must rotate the type-1 output."""
    dim_in = cast_tuple(dim_in)

    model = Equiformer(
        dim = dim,
        dim_in = dim_in,
        input_degrees = len(dim_in),
        depth = 2,
        l2_dist_attention = l2_dist_attention,
        reversible = reversible,
        num_degrees = 3,
        reduce_dim_out = True,
        init_out_zero = False
    )

    # one random feature tensor per input degree, shaped (batch, seq, dim, order)
    feats = {deg: torch.randn(1, 32, d, to_order(deg)) for deg, d in enumerate(dim_in)}
    type0, type1 = feats[0], feats.get(1, None)

    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    R = rot(*torch.randn(3))

    # degree-0 features are invariant; degree-1 features rotate with R
    maybe_rotated_feats = {0: type0}
    if exists(type1):
        maybe_rotated_feats[1] = type1 @ R

    _, out1 = model(maybe_rotated_feats, coors @ R, mask)
    out2 = model(feats, coors, mask)[1] @ R

    assert torch.allclose(out1, out2, atol = 1e-4), 'is not equivariant'

.\lucidrains\ESBN-pytorch\esbn_pytorch\esbn_pytorch.py

# 导入 torch 库
import torch
# 从 functools 库中导入 partial 函数
from functools import partial
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 einops 库中导入 repeat 和 rearrange 函数
from einops import repeat, rearrange

# helper functions

# returns True when the value is not None
def exists(val):
    return not (val is None)

# concatenates el onto t, treating a missing (None) accumulator as empty
def safe_cat(t, el, dim = 0):
    return el if not exists(t) else torch.cat((t, el), dim = dim)

# curries fn with fixed args; the returned callable maps fn over its inputs lazily
def map_fn(fn, *args, **kwargs):
    def inner(*arr):
        apply_one = lambda t: fn(t, *args, **kwargs)
        return map(apply_one, arr)
    return inner

# classes

# Emergent Symbol Binding Network: a recurrent controller that interacts with
# image encodings only indirectly, through an external key/value memory.
class ESBN(nn.Module):
    def __init__(
        self,
        *,
        value_dim = 64,
        key_dim = 64,
        hidden_dim = 512,
        output_dim = 4,
        encoder = None
    ):
        """
        value_dim:  dimension of the encoded image (stored as a memory value)
        key_dim:    dimension of the learned memory keys (a confidence scalar
                    is appended during retrieval, hence the +1 below)
        hidden_dim: LSTM controller hidden size
        output_dim: per-step output size
        encoder:    optional image encoder; defaults to a small conv net
        """
        super().__init__()
        # initial hidden state, cell state and key (zeros, not learned)
        self.h0 = torch.zeros(hidden_dim)
        self.c0 = torch.zeros(hidden_dim)
        self.k0 = torch.zeros(key_dim + 1)

        # controller plus projection heads: gate, next key, and output
        self.rnn = nn.LSTMCell(key_dim + 1, hidden_dim)
        self.to_gate = nn.Linear(hidden_dim, 1)
        self.to_key = nn.Linear(hidden_dim, key_dim)
        self.to_output = nn.Linear(hidden_dim, output_dim)

        # default conv encoder; the Linear(4 * 64, ...) implies the flattened
        # conv output has 256 features -- assumes 3x32x32 inputs, TODO confirm
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size = 4, stride = 2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size = 4, stride = 2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size = 4, stride = 2),
            nn.Flatten(1),
            nn.Linear(4 * 64, value_dim)
        ) if not exists(encoder) else encoder

        # maps the raw retrieval similarity to a confidence scalar
        self.to_confidence = nn.Linear(1, 1)

    def forward(self, images):
        """images: (seq, batch, c, h, w) -> outputs: (seq, batch, output_dim)."""
        # batch size (the sequence dimension comes first)
        b = images.shape[1]
        Mk = None  # memory keys, one slot appended per step
        Mv = None  # memory values, one slot appended per step

        # broadcast the initial states across the batch
        hx, cx, kx, k0 = map_fn(repeat, 'd -> b d', b = b)(self.h0, self.c0, self.k0, self.k0)
        out = []

        for ind, image in enumerate(images):
            is_first = ind == 0
            z = self.encoder(image)
            hx, cx = self.rnn(kx, (hx, cx))
            y, g, kw = self.to_output(hx), self.to_gate(hx), self.to_key(hx)

            if is_first:
                # no memory yet -- feed the initial key on the next step
                kx = k0
            else:
                # retrieval: similarity of the current encoding to stored values
                sim = einsum('b n d, b d -> b n', Mv, z)
                wk = sim.softmax(dim = -1)

                # confidence that the retrieval is a genuine match
                sim, wk = map_fn(rearrange, 'b n -> b n ()')(sim, wk)
                ck = self.to_confidence(sim).sigmoid()

                # append confidence to each stored key, take the attention-
                # weighted sum, and gate the result via the controller
                kx = g.sigmoid() * (wk * torch.cat((Mk, ck), dim = -1)).sum(dim = 1)

            # write the freshly generated key and the encoding into memory
            kw, z = map_fn(rearrange, 'b d -> b () d')(kw, z)
            Mk = safe_cat(Mk, kw, dim = 1)
            Mv = safe_cat(Mv, z, dim = 1)
            out.append(y)

        # stack per-step outputs into (seq, batch, output_dim)
        return torch.stack(out)

.\lucidrains\ESBN-pytorch\esbn_pytorch\__init__.py

# 从 esbn_pytorch.esbn_pytorch 模块中导入 ESBN 类
from esbn_pytorch.esbn_pytorch import ESBN

Emerging Symbol Binding Network (ESBN) - Pytorch

Usable implementation of Emerging Symbol Binding Network (ESBN), in Pytorch. They propose to have the main recurrent neural network interact with the input image representations only through a set of memory key / values.

The input image representations are cast as memory values and are explicitly bound to memory keys generated by the network. The network generates each new memory key after receiving a sum of all previous memory keys, weighted by the similarity of the incoming representation to the set of memory values in storage.

This decoupling / indirection of sensory to abstract processing allows the network to outperform all previous approaches, including transformers.

Usage

import torch
from esbn_pytorch import ESBN

model = ESBN(
    value_dim = 64,
    key_dim = 64,
    hidden_dim = 512,
    output_dim = 4
)

images = torch.randn(3, 2, 3, 32, 32) # (n, b, c, h, w)
model(images) # (3, 2, 4) # (n, b, o)

Citations

@misc{webb2020emergent,
    title={Emergent Symbols through Binding in External Memory}, 
    author={Taylor W. Webb and Ishan Sinha and Jonathan D. Cohen},
    year={2020},
    eprint={2012.14601},
    archivePrefix={arXiv},
    primaryClass={cs.AI}
}

.\lucidrains\ESBN-pytorch\setup.py

# Import setup and package-discovery helpers
from setuptools import setup, find_packages

# Package metadata
setup(
  # package name
  name = 'esbn-pytorch',
  # discover and include all packages
  packages = find_packages(),
  # version string
  version = '0.0.4',
  # license
  license='MIT',
  # short description
  description = 'Emergent Symbol Binding Network - Pytorch',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # project URL
  url = 'https://github.com/lucidrains/ESBN-pytorch',
  # search keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'neuro-symbolic',
    'abstract reasoning',
    'memory'
  ],
  # runtime dependencies
  install_requires=[
    'torch>=1.6',
    'einops>=0.3'
  ],
  # PyPI classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\esbn-transformer\esbn_transformer\esbn_transformer.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch 库中导入 nn.functional 模块,并重命名为 F
import torch.nn.functional as F

# 从 einops 库中导入 rearrange、repeat、reduce 函数
from einops import rearrange, repeat, reduce

# 从 einops.layers.torch 库中导入 Rearrange 类
from einops.layers.torch import Rearrange

# helper functions

# whether a value is present (i.e. not None)
def exists(val):
    return not (val is None)

# returns the value when present, otherwise the fallback
def default(val, d):
    if exists(val):
        return val
    return d

# most negative finite value representable in the tensor's dtype (for masking)
def max_neg_value(t):
    info = torch.finfo(t.dtype)
    return -info.max

# lazily applies the same einops rearrangement to every tensor
def rearrange_all(tensors, *args, **kwargs):
    def _apply(t):
        return rearrange(t, *args, **kwargs)
    return map(_apply, tensors)

# feedforward

# layer norm applied per channel-group, input layout (batch, groups * dim, n)
class GroupLayerNorm(nn.Module):
    def __init__(self, dim, groups = 1, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.groups = groups
        self.g = nn.Parameter(torch.ones(1, groups, dim, 1))
        self.b = nn.Parameter(torch.zeros(1, groups, dim, 1))

    def forward(self, x):
        grouped = rearrange(x, 'b (g d) n -> b g d n', g = self.groups)
        mu = grouped.mean(dim = 2, keepdim = True)
        sigma = grouped.var(dim = 2, unbiased = False, keepdim = True).sqrt()
        normalized = (grouped - mu) / (sigma + self.eps) * self.g + self.b
        return rearrange(normalized, 'b g d n -> b (g d) n')

# wraps a module so its input is group-layer-normalized first
class PreNorm(nn.Module):
    def __init__(
        self,
        dim,
        fn,
        groups = 1
    ):
        super().__init__()
        self.norm = GroupLayerNorm(dim, groups = groups)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)

# grouped pointwise feedforward: 1x1 convolutions expand by mult, then project back
class FeedForward(nn.Module):
    def __init__(
        self,
        *,
        dim,
        mult = 4,
        groups = 1
    ):
        super().__init__()
        total_in = dim * groups
        total_hidden = dim * mult * groups

        self.net = nn.Sequential(
            nn.Conv1d(total_in, total_hidden, 1, groups = groups),
            nn.GELU(),
            nn.Conv1d(total_hidden, total_in, 1, groups = groups)
        )

    def forward(self, x):
        return self.net(x)

# Multi-head attention over channel-first sequences (batch, dim * groups, n).
# Each group is attended independently, except that with groups > 1 the
# per-group attention logits are cumulatively summed across the group axis,
# letting the later (symbol) stream reuse the sensory stream's attention.
class Attention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        causal = False,
        groups = 1
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.groups = groups
        self.heads = heads
        self.causal = causal
        input_dim = dim * groups
        inner_dim = dim_head * heads * groups

        # 1x1 convolutions act as per-position linear projections
        self.to_q = nn.Conv1d(input_dim, inner_dim, 1, bias = False)
        self.to_kv = nn.Conv1d(input_dim, inner_dim * 2, 1, bias = False)
        self.to_out = nn.Conv1d(inner_dim, input_dim, 1)

    def forward(self, x, mask = None, context = None):
        # mask: (batch, n) booleans; context defaults to self-attention
        n, device, h, g, causal = x.shape[2], x.device, self.heads, self.groups, self.causal
        context = default(context, x)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = 1))
        q, k, v = rearrange_all((q, k, v), 'b (g h d) n -> (b g h) n d', g = g, h = h)

        q = q * self.scale

        sim = einsum('b i d, b j d -> b i j', q, k)

        if g > 1:
            # with symbols present, allow the network to bind symbols using the
            # attention matrix from the sensory side (cumsum over the group axis)
            sim = rearrange(sim, '(b g h) i j -> b g h i j', g = g, h = h)
            sim = sim.cumsum(dim = 1)
            sim = rearrange(sim, 'b g h i j -> (b g h) i j')

        if exists(mask):
            # expand the padding mask to every group/head, then make it pairwise
            mask = repeat(mask, 'b n -> (b g h) n', h = h, g = g)
            mask = rearrange(mask, 'b n -> b n ()') * rearrange(mask, 'b n -> b () n')
            mask_value = max_neg_value(sim)
            sim = sim.masked_fill(~mask, mask_value)

        if causal:
            # mask out attention to future positions
            causal_mask = torch.ones((n, n), device = device).triu(1).bool()
            mask_value = max_neg_value(sim)
            sim = sim.masked_fill(causal_mask, mask_value)

        attn = sim.softmax(dim = -1)
        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b g h) n d -> b (g h d) n', h = h, g = g)
        return self.to_out(out)

# Standard pre-norm attention + feedforward block with residual connections.
class TransformerBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,  # input dimension (per group)
        causal = False,  # whether to use causal attention
        dim_head = 64,  # dimension of each attention head
        heads = 8,  # number of attention heads
        ff_mult = 4,  # feedforward expansion multiple
        groups = 1  # number of parallel streams (e.g. sensory + symbols)
    ):
        super().__init__()
        # pre-normed attention sublayer
        self.attn = PreNorm(dim, Attention(dim = dim, dim_head = dim_head, heads = heads, causal = causal, groups = groups), groups = groups)
        # pre-normed feedforward sublayer
        self.ff = PreNorm(dim, FeedForward(dim = dim, mult = ff_mult, groups = groups), groups = groups)

    def forward(self, x, mask = None):
        # residual attention
        x = self.attn(x, mask = mask) + x
        # residual feedforward
        x = self.ff(x) + x
        return x
# main class

class EsbnTransformer(nn.Module):
    """Transformer whose middle layers jointly process a sensory stream and a
    stream of learned symbols (ESBN-style binding via grouped attention).

    The two streams are stacked along a group dimension; after the grouped
    layers only the symbol stream is kept and decoded into token logits.
    """

    def __init__(
        self,
        *,
        dim,            # model dimension
        depth,          # number of grouped (sensory + symbol) blocks
        num_tokens,     # vocabulary size
        max_seq_len,    # maximum sequence length (also the symbol count)
        causal = False, # whether attention is causal
        dim_head = 64,  # dimension per attention head
        heads = 8,      # number of attention heads
        ff_mult = 4     # feedforward expansion factor
    ):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.token_emb = nn.Embedding(num_tokens, dim)  # token embeddings
        self.pos_emb = nn.Embedding(max_seq_len, dim)   # absolute position embeddings

        self.layers = nn.ModuleList([])
        # single-stream block run before the symbols are introduced
        self.pre_transformer_block = TransformerBlock(dim = dim, causal = causal, dim_head = dim_head, heads = heads)

        # one learned symbol vector per sequence position: (max_seq_len, dim)
        self.symbols = nn.Parameter(torch.randn(max_seq_len, dim))

        for _ in range(depth):
            # groups = 2 -> sensory stream and symbol stream processed together
            self.layers.append(TransformerBlock(dim = dim, causal = causal, dim_head = dim_head, heads = heads, groups = 2))

        # single-stream block run on the surviving symbol stream
        self.post_transformer_block = TransformerBlock(dim = dim, causal = causal, dim_head = dim_head, heads = heads)

        self.to_logits = nn.Sequential(
            Rearrange('b d n -> b n d'),  # back to (batch, seq, dim)
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    def forward(self, x, mask = None):
        """x: (batch, seq) token ids -> (batch, seq, num_tokens) logits."""
        b, n, d, device = *x.shape, self.dim, x.device
        x = self.token_emb(x)

        pos_emb = self.pos_emb(torch.arange(n, device = device))
        pos_emb = rearrange(pos_emb, 'n d -> () n d')

        x = x + pos_emb
        x = rearrange(x, 'b n d -> b d n')  # conv-style channel-first layout

        x = self.pre_transformer_block(x, mask = mask)

        x = rearrange(x, 'b d n -> b () d n')

        # bug fix: slice the *sequence* axis of the (max_seq_len, dim) symbol
        # table; the previous `self.symbols[:, :n]` sliced the feature axis,
        # which only happened to work when n == dim == max_seq_len
        symbols = self.symbols[:n]

        symbols = repeat(symbols, 'n d -> b () d n', b = b)
        x = torch.cat((x, symbols), dim = 1)  # (b, 2, d, n): sensory + symbols
        x = rearrange(x, 'b ... n -> b (...) n')

        for block in self.layers:
            x = block(x, mask = mask)

        # split the two streams apart again and keep only the symbol stream
        x = rearrange(x, 'b (s d) n -> b s d n', s = 2)
        x = x[:, 1]

        x = self.post_transformer_block(x, mask = mask)
        return self.to_logits(x)

.\lucidrains\esbn-transformer\esbn_transformer\__init__.py

# 从 esbn_transformer.esbn_transformer 模块中导入 EsbnTransformer 类
from esbn_transformer.esbn_transformer import EsbnTransformer

ESBN Transformer (wip)

An attempt to merge ESBN with Transformers, to endow Transformers with the ability to emergently bind symbols and improve extrapolation. The resulting architecture will be benchmarked with the Give-N task, as outlined in the Dulberg et al. paper cited below, which is commonly used to assess whether a child has acquired an understanding of counting.

Usage

import torch
from esbn_transformer import EsbnTransformer

model = EsbnTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 4,
    max_seq_len = 512
)

x = torch.randint(0, 256, (1, 512))
out = model(x) # (1, 512, 256)

Citations

@misc{webb2020emergent,
    title   = {Emergent Symbols through Binding in External Memory}, 
    author  = {Taylor W. Webb and Ishan Sinha and Jonathan D. Cohen},
    year    = {2020},
    eprint  = {2012.14601},
    archivePrefix = {arXiv},
    primaryClass = {cs.AI}
}
@misc{dulberg2021modelling,
    title   = {Modelling the development of counting with memory-augmented neural networks}, 
    author  = {Zack Dulberg and Taylor Webb and Jonathan Cohen},
    year    = {2021},
    eprint  = {2105.10577},
    archivePrefix = {arXiv},
    primaryClass = {cs.AI}
}

.\lucidrains\ETSformer-pytorch\etsformer_pytorch\etsformer_pytorch.py

# 从 math 模块中导入 pi 常数
from math import pi
# 从 collections 模块中导入 namedtuple 类
from collections import namedtuple

# 导入 torch 库
import torch
# 从 torch.nn.functional 模块中导入 F
import torch.nn.functional as F
# 从 torch 模块中导入 nn 和 einsum
from torch import nn, einsum

# 从 scipy.fftpack 模块中导入 next_fast_len 函数
from scipy.fftpack import next_fast_len
# 从 einops 模块中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 从 einops.layers.torch 模块中导入 Rearrange 类
from einops.layers.torch import Rearrange

# Named tuple bundling the encoder outputs: per-layer growth latents,
# per-layer seasonal latents, and the smoothed level sequence.
Intermediates = namedtuple('Intermediates', ['growth_latents', 'seasonal_latents', 'level_output'])

# returns True when the value is not None
def exists(val):
    return not (val is None)

# Extrapolate a signal beyond its observed range by evaluating its discrete
# Fourier series at time steps [start, end).
def fourier_extrapolate(signal, start, end):
    """signal: (..., n) real tensor -> (..., end - start) real tensor.

    The DFT of the last axis is re-evaluated at (possibly out-of-range)
    integer time points, i.e. the fitted Fourier basis is continued forward.
    """
    device = signal.device
    # complex spectrum of the signal
    fhat = torch.fft.fft(signal)
    fhat_len = fhat.shape[-1]
    # integer time points to evaluate, complex dtype for the exponential below
    time = torch.linspace(start, end - 1, end - start, device=device, dtype=torch.complex64)
    # frequency bin indices 0 .. n-1
    freqs = torch.linspace(0, fhat_len - 1, fhat_len, device=device, dtype=torch.complex64)
    # inverse DFT at the requested times: sum_k fhat_k * e^{2*pi*i*k*t/n} / n
    res = fhat[..., None, :] * (1.j * 2 * pi * freqs[..., None, :] * time[..., :, None] / fhat_len).exp() / fhat_len
    return res.sum(dim=-1).real

# embeds a (batch, time, features) series into model_dim via a same-padded 1d conv
def InputEmbedding(time_features, model_dim, kernel_size=3, dropout=0.):
    pad = kernel_size // 2
    layers = [
        Rearrange('b n d -> b d n'),
        nn.Conv1d(time_features, model_dim, kernel_size=kernel_size, padding=pad),
        nn.Dropout(dropout),
        Rearrange('b d n -> b n d'),
    ]
    return nn.Sequential(*layers)

# position-wise feedforward with sigmoid activation, as used by ETSformer
def FeedForward(dim, mult=4, dropout=0.):
    hidden = dim * mult
    return nn.Sequential(
        nn.Linear(dim, hidden),
        nn.Sigmoid(),
        nn.Dropout(dropout),
        nn.Linear(hidden, dim),
        nn.Dropout(dropout)
    )

# pre-norm feedforward with a residual connection, followed by a post-norm
class FeedForwardBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        **kwargs
    ):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, **kwargs)
        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x):
        normed = self.norm(x)
        return self.post_norm(normed + self.ff(normed))

# encoder classes

## multi-head exponential smoothing attention
# Causal weighted sum computed via FFT (Algorithm 3 in the ETSformer paper):
# O(n log n) instead of the naive O(n^2) matrix product.
def conv1d_fft(x, weights, dim=-2, weight_dim=-1):
    # Algorithm 3
    N = x.shape[dim]               # sequence length of the input
    M = weights.shape[weight_dim]  # length of the weight (kernel) vector

    # pad to an FFT-friendly length so the circular conv covers the linear one
    fast_len = next_fast_len(N + M - 1)

    f_x = torch.fft.rfft(x, n=fast_len, dim=dim)
    f_weight = torch.fft.rfft(weights, n=fast_len, dim=weight_dim)

    # conjugation turns convolution into correlation; the appended singleton
    # axis broadcasts the weights over the feature dimension
    f_v_weight = f_x * rearrange(f_weight.conj(), '... -> ... 1')
    out = torch.fft.irfft(f_v_weight, fast_len, dim=dim)
    out = out.roll(-1, dims=(dim,))

    # keep only the final N positions (the causal portion of the product)
    indices = torch.arange(start=fast_len - N, end=fast_len, dtype=torch.long, device=x.device)
    out = out.index_select(dim, indices)
    return out

# Multi-Head Exponential Smoothing Attention: replaces dot-product attention
# with per-head exponential smoothing over first differences of the sequence.
class MHESA(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads=8,
        dropout=0.,
        norm_heads=False
    ):
        super().__init__()
        self.heads = heads
        # learned initial state, one (dim // heads) vector per head
        self.initial_state = nn.Parameter(torch.randn(heads, dim // heads))

        self.dropout = nn.Dropout(dropout)
        # pre-sigmoid smoothing coefficient per head
        self.alpha = nn.Parameter(torch.randn(heads))

        # optional GroupNorm across heads (channel-first layout for GroupNorm)
        self.norm_heads = nn.Sequential(
            Rearrange('b n (h d) -> b (h d) n', h=heads),
            nn.GroupNorm(heads, dim),
            Rearrange('b (h d) n -> b n (h d)', h=heads)
        ) if norm_heads else nn.Identity()

        self.project_in = nn.Linear(dim, dim)
        self.project_out = nn.Linear(dim, dim)

    # Naive O(n^2) exponential smoothing (Appendix A.1, Algorithm 2);
    # serves as the reference implementation for the FFT path in forward.
    def naive_Aes(self, x, weights):
        n, h = x.shape[-2], self.heads

        # In appendix A.1 - Algorithm 2

        arange = torch.arange(n, device=x.device)

        # build the full (t, l) weight matrix by rotating the weight vector
        weights = repeat(weights, '... l -> ... t l', t=n)
        indices = repeat(arange, 'l -> h t l', h=h, t=n)

        indices = (indices - rearrange(arange + 1, 't -> 1 t 1')) % n

        weights = weights.gather(-1, indices)
        weights = self.dropout(weights)

        # causality: zero contributions from future positions

        weights = weights.tril()

        # weighted sum over past positions

        output = einsum('b h n d, h m n -> b h m d', x, weights)
        return output
    # forward pass; `naive` toggles the O(n^2) reference path over the FFT path
    def forward(self, x, naive = False):
        # x: (batch, n, dim)
        b, n, d, h, device = *x.shape, self.heads, x.device

        # input projection
        x = self.project_in(x)

        # split the feature dimension into heads
        x = rearrange(x, 'b n (h d) -> b h n d', h = h)

        # prepend the learned initial state, then take first differences
        x = torch.cat((
            repeat(self.initial_state, 'h d -> b h 1 d', b = b),
            x
        ), dim = -2)

        x = x[:, :, 1:] - x[:, :, :-1]

        # smoothing coefficient in (0, 1), one per head
        alpha = self.alpha.sigmoid()
        alpha = rearrange(alpha, 'h -> h 1')

        # geometric decay weights alpha * (1 - alpha)^k, newest position last
        arange = torch.arange(n, device = device)
        weights = alpha * (1 - alpha) ** torch.flip(arange, dims = (0,))

        # choose between the naive reference and the FFT implementation
        if naive:
            output = self.naive_Aes(x, weights)
        else:
            output = conv1d_fft(x, weights)

        # add the decayed contribution of the initial state at every step
        init_weight = (1 - alpha) ** (arange + 1)
        init_output = rearrange(init_weight, 'h n -> h n 1') * rearrange(self.initial_state, 'h d -> h 1 d')

        output = output + init_output

        # merge the heads back together
        output = rearrange(output, 'b h n d -> b n (h d)')

        # optional per-head normalization
        output = self.norm_heads(output)

        # output projection
        return self.project_out(output)
## frequency attention

# 定义频率注意力模块
class FrequencyAttention(nn.Module):
    def __init__(
        self,
        *,
        K = 4,
        dropout = 0.
    ):
        super().__init__()
        self.K = K
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # 对输入数据进行傅立叶变换
        freqs = torch.fft.rfft(x, dim = 1)

        # 获取振幅

        amp = freqs.abs()
        amp = self.dropout(amp)

        # 获取前K个振幅值 - 用于季节性,被标记为注意力

        topk_amp, _ = amp.topk(k = self.K, dim = 1, sorted = True)

        # 掩盖所有振幅低于前K个最小值的频率

        topk_freqs = freqs.masked_fill(amp < topk_amp[:, -1:], 0.+0.j)

        # 反向傅立叶变换

        return torch.fft.irfft(topk_freqs, dim = 1)

## level module

# Tracks the smoothed level of the raw series: exponentially smooths the
# de-seasonalised signal and adds a smoothed growth correction (Appendix A.2).
class Level(nn.Module):
    def __init__(self, time_features, model_dim):
        super().__init__()
        # pre-sigmoid smoothing coefficient
        self.alpha = nn.Parameter(torch.Tensor([0.]))
        # project latents back into the raw time-feature space
        self.to_growth = nn.Linear(model_dim, time_features)
        self.to_seasonal = nn.Linear(model_dim, time_features)

    def forward(self, x, latent_growth, latent_seasonal):
        # per the equations in appendix A.2

        n, device = x.shape[1], x.device

        alpha = self.alpha.sigmoid()

        arange = torch.arange(n, device = device)
        powers = torch.flip(arange, dims = (0,))

        # exponential smoothing of the raw signal with its seasonal term
        # (produced by the frequency attention) subtracted out

        seasonal =self.to_seasonal(latent_seasonal)
        Aes_weights = alpha * (1 - alpha) ** powers
        seasonal_normalized_term = conv1d_fft(x - seasonal, Aes_weights)

        # auxiliary smoothed growth term

        growth = self.to_growth(latent_growth)
        growth_smoothing_weights = (1 - alpha) ** powers
        growth_term = conv1d_fft(growth, growth_smoothing_weights)

        return seasonal_normalized_term + growth_term

# decoder classes

# broadcasts the final observed level across every forecast step
class LevelStack(nn.Module):
    def forward(self, x, num_steps_forecast):
        last_level = x[:, -1]
        return repeat(last_level, 'b d -> b n d', n = num_steps_forecast)

# Dampens the last growth latent over the forecast horizon (Eq (2) in the
# paper): each head's growth is scaled by the cumulative sum of powers of a
# learned, sigmoided dampening factor.
class GrowthDampening(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8
    ):
        super().__init__()
        self.heads = heads
        # pre-sigmoid dampening factor, one per head
        self.dampen_factor = nn.Parameter(torch.randn(heads))

    def forward(self, growth, *, num_steps_forecast):
        device, h = growth.device, self.heads

        dampen_factor = self.dampen_factor.sigmoid()

        # like the level stack, forecasting starts from the last growth latent

        last_growth = growth[:, -1]
        last_growth = rearrange(last_growth, 'b l (h d) -> b l 1 h d', h = h)

        # prepare per-head dampening factors and their step powers

        dampen_factor = rearrange(dampen_factor, 'h -> 1 1 1 h 1')
        powers = (torch.arange(num_steps_forecast, device = device) + 1)
        powers = rearrange(powers, 'n -> 1 1 n 1 1')

        # following Eq (2) of the paper: cumulative sum of dampened powers

        dampened_growth = last_growth * (dampen_factor ** powers).cumsum(dim = 2)
        return rearrange(dampened_growth, 'b l n h d -> b l n (h d)')

# main class

# ETSformer: exponential-smoothing transformer for time-series forecasting.
# Each encoder layer peels a seasonal component (frequency attention) and a
# growth component (exponential smoothing attention) off the latent, while a
# Level module tracks the smoothed level of the raw series.  The decoder
# extrapolates seasonality (Fourier), dampens growth, and stacks the last
# level to produce the forecast.
class ETSFormer(nn.Module):
    def __init__(
        self,
        *,
        model_dim,
        time_features = 1,
        embed_kernel_size = 3,
        layers = 2,
        heads = 8,
        K = 4,
        dropout = 0.
    ):
        """
        model_dim:          latent dimension
        time_features:      number of input time features (channels)
        embed_kernel_size:  kernel size of the input embedding conv
        layers:             number of encoder layers
        heads:              smoothing-attention heads
        K:                  number of retained frequencies per layer
        dropout:            dropout probability used throughout
        """
        super().__init__()
        # each head must receive an integral number of channels
        assert (model_dim % heads) == 0, 'model dimension must be divisible by number of heads'
        self.model_dim = model_dim
        self.time_features = time_features

        # conv embedding of the raw series into the latent space
        self.embed = InputEmbedding(time_features, model_dim, kernel_size = embed_kernel_size, dropout = dropout)

        self.encoder_layers = nn.ModuleList([])

        for ind in range(layers):
            is_last_layer = ind == (layers - 1)

            # frequency attention, smoothing attention, an optional
            # feedforward (skipped on the last layer), and the level tracker
            self.encoder_layers.append(nn.ModuleList([
                FrequencyAttention(K = K, dropout = dropout),
                MHESA(dim = model_dim, heads = heads, dropout = dropout),
                FeedForwardBlock(dim = model_dim) if not is_last_layer else None,
                Level(time_features = time_features, model_dim = model_dim)
            ]))

        # dampens growth latents over the forecast horizon
        self.growth_dampening_module = GrowthDampening(dim = model_dim, heads = heads)

        # maps summed latents back into time-feature space
        self.latents_to_time_features = nn.Linear(model_dim, time_features)
        # repeats the last level across the forecast horizon
        self.level_stack = LevelStack()

    def forward(
        self,
        x,
        *,
        num_steps_forecast = 0,
        return_latents = False
    ):
        """x: (batch, n) or (batch, n, time_features).

        With num_steps_forecast == 0 the encoder latents (Intermediates)
        are returned; otherwise the forecast (optionally with latents).
        """
        # promote a single time feature to an explicit channel dimension
        one_time_feature = x.ndim == 2

        if one_time_feature:
            x = rearrange(x, 'b n -> b n 1')

        z = self.embed(x)

        latent_growths = []
        latent_seasonals = []

        for freq_attn, mhes_attn, ff_block, level in self.encoder_layers:
            # peel the seasonal component off the latent
            latent_seasonal = freq_attn(z)
            z = z - latent_seasonal

            # peel the growth component off the latent
            latent_growth = mhes_attn(z)
            z = z - latent_growth

            if exists(ff_block):
                z = ff_block(z)

            # update the running level of the raw series
            x = level(x, latent_growth, latent_seasonal)

            latent_growths.append(latent_growth)
            latent_seasonals.append(latent_seasonal)

        # stack per-layer latents along a new layer axis
        latent_growths = torch.stack(latent_growths, dim = -2)
        latent_seasonals = torch.stack(latent_seasonals, dim = -2)

        latents = Intermediates(latent_growths, latent_seasonals, x)

        if num_steps_forecast == 0:
            return latents

        # extrapolate seasonality into the future via the Fourier basis
        latent_seasonals = rearrange(latent_seasonals, 'b n l d -> b l d n')
        extrapolated_seasonals = fourier_extrapolate(latent_seasonals, x.shape[1], x.shape[1] + num_steps_forecast)
        extrapolated_seasonals = rearrange(extrapolated_seasonals, 'b l d n -> b l n d')

        # dampen growth over the horizon; repeat the last level across it
        dampened_growths = self.growth_dampening_module(latent_growths, num_steps_forecast = num_steps_forecast)
        level = self.level_stack(x, num_steps_forecast = num_steps_forecast)

        # sum layer contributions, map back to time features, add the level
        summed_latents = dampened_growths.sum(dim = 1) + extrapolated_seasonals.sum(dim = 1)
        forecasted = level + self.latents_to_time_features(summed_latents)

        if one_time_feature:
            forecasted = rearrange(forecasted, 'b n 1 -> b n')

        if return_latents:
            return forecasted, latents

        return forecasted
# 分类包装器

class MultiheadLayerNorm(nn.Module):
    def __init__(self, dim, heads = 1, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(heads, 1, dim))  # 初始化可学习参数 g
        self.b = nn.Parameter(torch.zeros(heads, 1, dim))  # 初始化可学习参数 b

    def forward(self, x):
        std = torch.var(x, dim = -1, unbiased = False, keepdim = True).sqrt()  # 计算标准差
        mean = torch.mean(x, dim = -1, keepdim = True)  # 计算均值
        return (x - mean) / (std + self.eps) * self.g + self.b  # 返回归一化后的结果

class ClassificationWrapper(nn.Module):
    def __init__(
        self,
        *,
        etsformer,
        num_classes = 10,
        heads = 16,
        dim_head = 32,
        level_kernel_size = 3,
        growth_kernel_size = 3,
        seasonal_kernel_size = 3,
        dropout = 0.
    ):
        super().__init__()
        assert isinstance(etsformer, ETSFormer)
        self.etsformer = etsformer
        model_dim = etsformer.model_dim
        time_features = etsformer.time_features

        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.dropout = nn.Dropout(dropout)

        self.queries = nn.Parameter(torch.randn(heads, dim_head))  # 初始化查询参数

        self.growth_to_kv = nn.Sequential(
            Rearrange('b n d -> b d n'),  # 重新排列张量维度
            nn.Conv1d(model_dim, inner_dim * 2, growth_kernel_size, bias = False, padding = growth_kernel_size // 2),  # 一维卷积层
            Rearrange('... (kv h d) n -> ... (kv h) n d', kv = 2, h = heads),  # 重新排列张量维度
            MultiheadLayerNorm(dim_head, heads = 2 * heads),  # 多头层归一化
        )

        self.seasonal_to_kv = nn.Sequential(
            Rearrange('b n d -> b d n'),  # 重新排列张量维度
            nn.Conv1d(model_dim, inner_dim * 2, seasonal_kernel_size, bias = False, padding = seasonal_kernel_size // 2),  # 一维卷积层
            Rearrange('... (kv h d) n -> ... (kv h) n d', kv = 2, h = heads),  # 重新排列张量维度
            MultiheadLayerNorm(dim_head, heads = 2 * heads),  # 多头层归一化
        )

        self.level_to_kv = nn.Sequential(
            Rearrange('b n t -> b t n'),  # 重新排列张量维度
            nn.Conv1d(time_features, inner_dim * 2, level_kernel_size, bias = False, padding = level_kernel_size // 2),  # 一维卷积层
            Rearrange('b (kv h d) n -> b (kv h) n d', kv = 2, h = heads),  # 重新排列张量维度
            MultiheadLayerNorm(dim_head, heads = 2 * heads),  # 多头层归一化
        )

        self.to_out = nn.Linear(inner_dim, model_dim)  # 线性变换层

        self.to_logits = nn.Sequential(
            nn.LayerNorm(model_dim),  # 层归一化
            nn.Linear(model_dim, num_classes)  # 线性变换层
        )

    def forward(self, timeseries):
        latent_growths, latent_seasonals, level_output = self.etsformer(timeseries)  # 获取ETSFormer的输出

        latent_growths = latent_growths.mean(dim = -2)  # 沿着指定维度计算均值
        latent_seasonals = latent_seasonals.mean(dim = -2)  # 沿着指定维度计算均值

        # queries, key, values

        q = self.queries * self.scale  # 缩放查询参数

        kvs = torch.cat((
            self.growth_to_kv(latent_growths),  # 经过growth_to_kv处理
            self.seasonal_to_kv(latent_seasonals),  # 经过seasonal_to_kv处理
            self.level_to_kv(level_output)  # 经过level_to_kv处理
        ), dim = -2)

        k, v = kvs.chunk(2, dim = 1)  # 按维度切分张量

        # cross attention pooling

        sim = einsum('h d, b h j d -> b h j', q, k)  # 执行张量乘法
        sim = sim - sim.amax(dim = -1, keepdim = True).detach()  # 减去最大值并断开梯度

        attn = sim.softmax(dim = -1)  # softmax操作
        attn = self.dropout(attn)  # dropout操作

        out = einsum('b h j, b h j d -> b h d', attn, v)  # 执行张量乘法
        out = rearrange(out, 'b ... -> b (...)')  # 重新排列张量维度

        out = self.to_out(out)  # 线性变换

        # project to logits

        return self.to_logits(out)  # 返回logits

.\lucidrains\ETSformer-pytorch\etsformer_pytorch\__init__.py

# 从etsformer_pytorch模块中导入ETSFormer、ClassificationWrapper和MHESA类
from etsformer_pytorch.etsformer_pytorch import (
    ETSFormer,
    ClassificationWrapper,
    MHESA
)

ETSformer - Pytorch

Implementation of ETSformer, state of the art time-series Transformer, in Pytorch

Install

$ pip install etsformer-pytorch

Usage

import torch
from etsformer_pytorch import ETSFormer

model = ETSFormer(
    time_features = 4,
    model_dim = 512,                # in paper they use 512
    embed_kernel_size = 3,          # kernel size for 1d conv for input embedding
    layers = 2,                     # number of encoder and corresponding decoder layers
    heads = 8,                      # number of exponential smoothing attention heads
    K = 4,                          # num frequencies with highest amplitude to keep (attend to)
    dropout = 0.2                   # dropout (in paper they did 0.2)
)

timeseries = torch.randn(1, 1024, 4)

pred = model(timeseries, num_steps_forecast = 32) # (1, 32, 4) - (batch, num steps forecast, num time features)

For using ETSFormer for classification, using cross attention pooling on all latents and level output

import torch
from etsformer_pytorch import ETSFormer, ClassificationWrapper

etsformer = ETSFormer(
    time_features = 1,
    model_dim = 512,
    embed_kernel_size = 3,
    layers = 2,
    heads = 8,
    K = 4,
    dropout = 0.2
)

adapter = ClassificationWrapper(
    etsformer = etsformer,
    dim_head = 32,
    heads = 16,
    dropout = 0.2,
    level_kernel_size = 5,
    num_classes = 10
)

timeseries = torch.randn(1, 1024)

logits = adapter(timeseries) # (1, 10)

Citation

@misc{woo2022etsformer,
    title   = {ETSformer: Exponential Smoothing Transformers for Time-series Forecasting}, 
    author  = {Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi},
    year    = {2022},
    eprint  = {2202.01381},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

.\lucidrains\ETSformer-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  # 包的名称
  name = 'ETSformer-pytorch',
  # 查找所有包,不排除任何包
  packages = find_packages(exclude=[]),
  # 版本号
  version = '0.1.1',
  # 许可证类型
  license='MIT',
  # 包的描述
  description = 'ETSTransformer - Exponential Smoothing Transformer for Time-Series Forecasting - Pytorch',
  # 长描述内容类型
  long_description_content_type = 'text/markdown',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 项目链接
  url = 'https://github.com/lucidrains/ETSformer-pytorch',
  # 关键词列表
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'time-series',
    'forecasting'
  ],
  # 安装依赖
  install_requires=[
    'einops>=0.4',
    'scipy',
    'torch>=1.6',
  ],
  # 分类标签
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\evolutionary-design-molecules\evolutionary_design_molecules\evolutionary_design_molecules.py

import torch
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList

from beartype import beartype
from einops import rearrange

from vector_quantize_pytorch import LFQ

# helper functions

# 检查变量是否存在
def exists(v):
    return v is not None

# genetic algorithm

# 进化函数,根据给定的初始种群、计算适应度函数,进行遗传算法进化
@beartype
def evolve(
    init_pool: Tensor,
    calc_fitness: Callable,
    generations = 1e5,
    population = 100,
    mutation_rate = 0.04,
    frac_survive_fittest = 0.25,
    frac_tournament = 0.25,
    frac_elite = 0.05,
):

    keep_fittest_len = int(population * frac_survive_fittest)
    num_elite = int(frac_elite * population)
    num_repro_and_mutate = keep_fittest_len - num_elite
    num_tournament_contenders = int(num_repro_and_mutate * FRAC_TOURNAMENT)
    num_children = population - keep_fittest_len
    num_mutate = mutation_rate * gene_length

    assert num_tournament_contenders >= 2

    # genetic algorithm

    generation = 1

    pool = init_pool

    for generation in generations:
        print(f"\n\ngeneration {generation}\n")

        # sort population by fitness

        fitnesses = calc_fitness(pool)

        indices = fitnesses.sort(descending = True).indices
        pool, fitnesses = pool[indices], fitnesses[indices]

        # keep the fittest

        pool, fitnesses = pool[:keep_fittest_len], fitnesses[:keep_fittest_len]

        # display every generation

        for gene, fitness in zip(pool, fitnesses):
            print(f"{decode(gene)} ({fitness.item():.3f})")

        # solved if any fitness is inf

        if (fitnesses == float('inf')).any():
            break

        # elites can pass directly to next generation

        elites, pool = pool[:num_elite], pool[num_elite:]
        elites_fitnesses, fitnesses = fitnesses[:num_elite], fitnesses[num_elite:]

        # deterministic tournament selection - let top 2 winners become parents

        contender_ids = torch.randn((num_children, num_repro_and_mutate)).argsort(dim = -1)[..., :num_tournament_contenders]
        participants, tournaments = pool[contender_ids], fitnesses[contender_ids]
        top2_winners = tournaments.topk(2, dim = -1, largest = True, sorted = False).indices
        top2_winners = repeat(top2_winners, '... -> ... g', g = gene_length)
        parents = participants.gather(1, top2_winners)

        # cross over recombination of parents

        parent1, parent2 = parents.unbind(dim = 1)
        children = torch.cat((parent1[:, :gene_midpoint], parent2[:, gene_midpoint:]), dim = -1)

        pool = torch.cat((pool, children))

        # mutate genes in population

        mutate_mask = torch.randn(pool.shape).argsort(dim = -1) < num_mutate
        noise = torch.randint(0, 2, pool.shape) * 2 - 1
        pool = torch.where(mutate_mask, pool + noise, pool)
        pool.clamp_(0, 255)

        # add back the elites

        pool = torch.cat((elites, pool))

        generation += 1

    return pool

# autoencoder

# 分子自编码器类
class MolecularAutoencoder(Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        raise NotImplementedError

# main class

# 主类,用于设计分子
class EvolveDesignMoleculesInSilico(Module):
    def __init__(self):
        super().__init__()

    def forward(self):
        raise NotImplementedError

.\lucidrains\evolutionary-design-molecules\evolutionary_design_molecules\__init__.py

# 定义一个名为calculate_area的函数,用于计算矩形的面积
def calculate_area(length, width):
    # 计算矩形的面积
    area = length * width
    # 返回计算得到的面积
    return area

Evolutionary Design of Molecules with Deep Learning and Genetic Algorithms (wip)

Unofficial implementation of the paper Evolutionary design of molecules based on deep learning and a genetic algorithm.

There are a few improvements that will be improvised on top of the general idea. (1) Use an equivariant attention network for the autoencoder. (2) Latent space will be using independent binary codes, the lookup-free quantization proposed by Yu et al. (3) Will bring in a few ideas to maintain greater diversity in the genetic pool

May also include a policy network for choosing the parents, as proposed in the paper Reinforced Genetic Algorithm for Structure-based Drug Design.

Citations

@article{article,
	author 	= {Kwon, Youngchun and Kang, Seokho and Choi, Youn-Suk and Kim, Inkoo},
	year 	= {2021},
	month 	= {08},
	title 	= {Evolutionary design of molecules based on deep learning and a genetic algorithm},
	journal = {Scientific Reports},
	doi 	= {10.1038/s41598-021-96812-8}
}
@article{Yu2023LanguageMB,
	title 	= {Language Model Beats Diffusion - Tokenizer is Key to Visual Generation},
	author 	= {Lijun Yu and Jos'e Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David C. Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
	journal = {ArXiv},
	year 	= {2023},
	volume 	= {abs/2310.05737},
	url 	= {https://api.semanticscholar.org/CorpusID:263830733}
}

.\lucidrains\evolutionary-design-molecules\setup.py

# 导入设置工具和查找包
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'evolutionary-design-molecules',  # 包的名称
  packages = find_packages(exclude=[]),  # 查找并包含所有包
  version = '0.0.1',  # 版本号
  license='MIT',  # 许可证
  description = 'Evolutionary Design of Molecules',  # 描述
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  long_description_content_type = 'text/markdown',  # 长描述内容类型
  url = 'https://github.com/lucidrains/evolutionary-design-molecules',  # 项目链接
  keywords = [  # 关键词列表
    'artificial intelligence',
    'deep learning',
    'evolutionary algorithms'
  ],
  install_requires=[  # 安装依赖
    'beartype',
    'einops>=0.7.0',
    'torch>=2.0',
    'vector-quantize-pytorch>=1.12.1'
  ],
  classifiers=[  # 分类器
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\fast-transformer-pytorch\fast_transformer_pytorch\fast_transformer_pytorch.py

import torch  # 导入 PyTorch 库
import torch.nn.functional as F  # 导入 PyTorch 中的函数模块
from torch import nn, einsum  # 从 PyTorch 中导入 nn 和 einsum 模块

from einops import rearrange, reduce  # 从 einops 库中导入 rearrange 和 reduce 函数
from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding  # 从 rotary_embedding_torch 库中导入 apply_rotary_emb 和 RotaryEmbedding 类

# helper functions

def exists(val):
    return val is not None  # 判断值是否为 None 的辅助函数

def default(val, d):
    return val if exists(val) else d  # 如果值存在则返回该值,否则返回默认值的辅助函数

# helper classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)  # 对输入进行 LayerNorm 归一化
        self.fn = fn  # 传入的函数

    def forward(self, x, **kwargs):
        x = self.norm(x)  # 对输入进行归一化
        return self.fn(x, **kwargs)  # 调用传入的函数处理归一化后的输入

# blocks

def FeedForward(dim, mult = 4):
    return nn.Sequential(
        nn.Linear(dim, dim * mult),  # 线性变换
        nn.GELU(),  # GELU 激活函数
        nn.Linear(dim * mult, dim)  # 线性变换
    )

class FastAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        heads = 8,
        dim_head = 64,
        max_seq_len = None,
        pos_emb = None
    ):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)  # 线性变换将输入转换为查询、键、值

        # rotary positional embedding

        assert not (exists(pos_emb) and not exists(max_seq_len)), 'max_seq_len must be passed in if to use rotary positional embeddings'  # 断言语句,确保条件成立

        self.pos_emb = pos_emb  # 位置编码
        self.max_seq_len = max_seq_len  # 最大序列长度

        # if using relative positional encoding, make sure to reduce pairs of consecutive feature dimension before doing projection to attention logits

        kv_attn_proj_divisor = 1 if not exists(pos_emb) else 2  # 如果使用相对位置编码,则将连续特征维度减少一半再进行注意力机制的投影

        self.to_q_attn_logits = nn.Linear(dim_head, 1, bias = False)  # 用于将查询投影到查询注意力得分的线性变换
        self.to_k_attn_logits = nn.Linear(dim_head // kv_attn_proj_divisor, 1, bias = False)  # 用于将键��影到键注意力得分的线性变换

        # final transformation of values to "r" as in the paper

        self.to_r = nn.Linear(dim_head // kv_attn_proj_divisor, dim_head)  # 将值最终转换为 "r",与论文中描述的一致

        self.to_out = nn.Linear(inner_dim, dim)  # 最终输出的线性变换
    # 定义前向传播函数,接受输入张量 x 和可选的 mask 参数
    def forward(self, x, mask = None):
        # 获取输入张量 x 的形状信息
        n, device, h, use_rotary_emb = x.shape[1], x.device, self.heads, exists(self.pos_emb)

        # 将输入张量 x 经过线性变换得到 qkv,并按照通道数分割为 q、k、v
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        # 初始化 mask_value 为 x 数据类型的最小值
        mask_value = -torch.finfo(x.dtype).max
        # 将 mask 重排为 'b () n' 形状
        mask = rearrange(mask, 'b n -> b () n')

        # 如果需要使用相对位置编码
        if use_rotary_emb:
            # 获取位置编码频率信息
            freqs = self.pos_emb(torch.arange(self.max_seq_len, device = device), cache_key = self.max_seq_len)
            freqs = rearrange(freqs[:n], 'n d -> () () n d')
            # 对 q、k、v 应用旋转编码
            q_aggr, k_aggr, v_aggr = map(lambda t: apply_rotary_emb(freqs, t), (q, k, v))
        else:
            q_aggr, k_aggr, v_aggr = q, k, v

        # 计算查询注意力 logits
        q_attn_logits = rearrange(self.to_q_attn_logits(q), 'b h n () -> b h n') * self.scale
        q_attn_logits = q_attn_logits.masked_fill(~mask, mask_value)
        q_attn = q_attn_logits.softmax(dim = -1)

        # 计算全局查询 token
        global_q = einsum('b h n, b h n d -> b h d', q_attn, q_aggr)
        global_q = rearrange(global_q, 'b h d -> b h () d')

        # 用全局查询 token 偏置键
        k = k * global_q

        # 如果使用旋转编码,对特征维度中相邻对进行内积
        if use_rotary_emb:
            k = reduce(k, 'b h n (d r) -> b h n d', 'sum', r = 2)

        # 计算键注意力 logits
        k_attn_logits = rearrange(self.to_k_attn_logits(k), 'b h n () -> b h n') * self.scale
        k_attn_logits = k_attn_logits.masked_fill(~mask, mask_value)
        k_attn = k_attn_logits.softmax(dim = -1)

        # 计算全局键 token
        global_k = einsum('b h n, b h n d -> b h d', k_attn, k_aggr)
        global_k = rearrange(global_k, 'b h d -> b h () d')

        # 偏置值
        u = v_aggr * global_k

        # 如果使用旋转编码,对特征维度中相邻对进行内积
        if use_rotary_emb:
            u = reduce(u, 'b h n (d r) -> b h n d', 'sum', r = 2)

        # 转换步骤
        r = self.to_r(u)

        # 论文中指出将查询作为残差添加
        r = r + q

        # 合并头部
        r = rearrange(r, 'b h n d -> b n (h d)')
        # 返回输出结果
        return self.to_out(r)
# 主类 FastTransformer
class FastTransformer(nn.Module):
    # 初始化函数
    def __init__(
        self,
        *,
        num_tokens,  # 标记数量
        dim,  # 维度
        depth,  # 深度
        max_seq_len,  # 最大序列长度
        heads = 8,  # 头数
        dim_head = 64,  # 头的维度
        ff_mult = 4,  # FeedForward 的倍数
        absolute_pos_emb = False  # 是否使用绝对位置编码
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)  # 标记嵌入层

        # 位置编码
        self.abs_pos_emb = nn.Embedding(max_seq_len, dim) if absolute_pos_emb else None

        layer_pos_emb = None
        if not absolute_pos_emb:
            assert (dim_head % 4) == 0, 'dimension of the head must be divisible by 4 to use rotary embeddings'
            layer_pos_emb = RotaryEmbedding(dim_head // 2)

        # 层
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            attn = FastAttention(dim, dim_head = dim_head, heads = heads, pos_emb = layer_pos_emb, max_seq_len = max_seq_len)  # 快速注意力机制
            ff = FeedForward(dim, mult = ff_mult)  # 前馈网络

            self.layers.append(nn.ModuleList([
                PreNorm(dim, attn),  # 预归一化
                PreNorm(dim, ff)  # 预归一化
            ]))

        # 在所有层之间进行权重绑定投影
        first_block, _ = self.layers[0]
        for block, _ in self.layers[1:]:
            block.fn.to_q_attn_logits = first_block.fn.to_q_attn_logits
            block.fn.to_k_attn_logits = first_block.fn.to_k_attn_logits

        # 转换为 logits
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),  # 层归一化
            nn.Linear(dim, num_tokens)  # 线性层
        )

    # 前向传播函数
    def forward(
        self,
        x,
        mask = None
    ):
        n, device = x.shape[1], x.device
        x = self.token_emb(x)  # 标记嵌入

        if exists(self.abs_pos_emb):
            pos_emb = self.abs_pos_emb(torch.arange(n, device = device))
            x = x + rearrange(pos_emb, 'n d -> () n d')  # 重排位置编码

        for attn, ff in self.layers:
            x = attn(x, mask = mask) + x  # 注意力机制
            x = ff(x) + x  # 前馈网络

        return self.to_logits(x)  # 返回 logits

.\lucidrains\fast-transformer-pytorch\fast_transformer_pytorch\__init__.py

# 从 fast_transformer_pytorch.fast_transformer_pytorch 模块中导入 FastTransformer 类
from fast_transformer_pytorch.fast_transformer_pytorch import FastTransformer

Fast Transformer - Pytorch

Implementation of Fast Transformer in Pytorch. This only work as an encoder.

Yannic video

AI Epiphany

Install

$ pip install fast-transformer-pytorch

Usage

import torch
from fast_transformer_pytorch import FastTransformer

model = FastTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 2,
    max_seq_len = 4096,
    absolute_pos_emb = True   # default uses relative positional encoding, but if that isn't working, then turn on absolute positional embedding by setting this to True
)

x = torch.randint(0, 20000, (1, 4096))
mask = torch.ones(1, 4096).bool()

logits = model(x, mask = mask) # (1, 4096, 20000)

Citations

@misc{wu2021fastformer,
    title   = {Fastformer: Additive Attention is All You Need}, 
    author  = {Chuhan Wu and Fangzhao Wu and Tao Qi and Yongfeng Huang},
    year    = {2021},
    eprint  = {2108.09084},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

.\lucidrains\fast-transformer-pytorch\setup.py

# 导入设置工具和查找包工具
from setuptools import setup, find_packages

# 设置包的信息
setup(
  # 包名
  name = 'fast-transformer-pytorch',
  # 查找所有包
  packages = find_packages(),
  # 版本号
  version = '0.0.4',
  # 许可证
  license='MIT',
  # 描述
  description = 'Fast Transformer - Pytorch',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 项目链接
  url = 'https://github.com/lucidrains/fast-transformer-pytorch',
  # 关键词
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers'
  ],
  # 安装依赖
  install_requires=[
    'einops>=0.3',
    'rotary-embedding-torch',
    'torch>=1.6'
  ],
  # 分类
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\feedback-transformer-pytorch\feedback_transformer_pytorch\feedback_transformer_pytorch.py

# 导入数学库
import math
# 导入命名元组
from collections import namedtuple

# 导入 PyTorch 库
import torch
# 导入神经网络模块、矩阵乘法函数
from torch import nn, einsum
# 导入 PyTorch 中的函数库
import torch.nn.functional as F
# 从 einops 库中导入重新排列函数
from einops import rearrange

# 定义命名元组 Memory,包含 keys 和 values 两个字段
Memory = namedtuple('Memory', ['keys', 'values'])

# 辅助函数

# 判断变量是否存在
def exists(val):
    return val is not None

# 如果变量存在则返回其值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 安全地拼接张量
def safe_cat(arr, el, dim = 1):
    if not exists(arr):
        return el
    return torch.cat((arr, el), dim = dim)

# 位置嵌入

# 定义相对位置偏置类
class RelativePositionBias(nn.Module):
    def __init__(
        self,
        causal = False,
        num_buckets = 32,
        max_distance = 128,
        heads = 8
    ):
        super().__init__()
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    # 静态方法,计算相对位置的桶索引
    @staticmethod
    def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
        ret = 0
        n = -relative_position
        if not causal:
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    # 前向传播函数
    def forward(self, qk_dots):
        i, j, device = *qk_dots.shape[-2:], qk_dots.device
        q_pos = torch.arange(i, dtype = torch.long, device = device)
        k_pos = torch.arange(j, dtype = torch.long, device = device)
        rel_pos = k_pos[None, :] - q_pos[:, None]
        rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        bias = rearrange(values, 'i j h -> () h i j')
        return bias

# 辅助类

# 残差连接类
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

# 预层归一化类
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

# 如果条件成立则跳过的类
class SkipIf(nn.Module):
    def __init__(self, cond, fn):
        super().__init__()
        self.cond = cond
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        if self.cond(x, *args, **kwargs):
            return x
        return self.fn(x, *args, **kwargs)

# 前馈网络

# GEGLU 激活函数类
class GEGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return F.gelu(gate) * x

# 前馈网络类
class FeedForward(nn.Module):
    def __init__(
        self,
        *,
        dim,
        mult = 4,
        dropout = 0.
    ):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)

# 注意力机制

# 注意力类
class Attention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5

        inner_dim = dim_head * heads
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.dropout = nn.Dropout(dropout)
    # 定义前向传播函数,接受输入 x、记忆 memory 和位置编码 pos_emb
    def forward(self, x, memory, pos_emb = None):
        # 获取头数 h、序列长度 n 和设备信息
        h, n, device = self.heads, x.shape[1], x.device

        # 判断是否进行自注意力计算,只有在大于1个标记时才进行自注意力计算
        self_attend = n > 1 

        # 将输入 x 转换为查询向量 q,并乘以缩放因子
        q = self.to_q(x) * self.scale

        # 解包记忆 memory 中的键 k 和值 v,如果不存在则设为 None
        k, v = memory if exists(memory) else (None, None)

        # 如果需要进行自注意力计算
        if self_attend:
            # 将输入 x 转换为键 k 和值 v
            self_k, self_v = self.to_kv(x).chunk(2, dim = -1)
            # 将自注意力计算得到的键 k 和值 v 与原有的键 k 和值 v 进行拼接
            k = safe_cat(k, self_k, dim = 1)
            v = safe_cat(v, self_v, dim = 1)

        # 将查询 q、键 k 和值 v 重排维度,以适应多头注意力计算
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # 计算注意力分数矩阵 sim
        sim = einsum('b h i d, b h j d -> b h i j', q, k)
        i, j = sim.shape[-2:]

        # 如果存在位置编码 pos_emb,则加上位置编码
        if exists(pos_emb):
            sim = sim + pos_emb(sim)

        # 如果需要进行自注意力计算
        if self_attend:
            # 生成因果掩码,用于屏蔽未来信息
            causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
            causal_mask = rearrange(causal_mask, 'i j -> () () i j')
            mask_value = -torch.finfo(q.dtype).max
            sim.masked_fill_(causal_mask, mask_value)

        # 对注意力分数矩阵进行 softmax 操作
        attn = sim.softmax(dim = -1)
        # 对注意力分数矩阵应用 dropout
        attn = self.dropout(attn)

        # 计算加权后的值向量 out
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        # 重排维度,恢复原始形状
        out = rearrange(out, 'b h n d -> b n (h d)')
        # 将输出结果传递给输出层
        return self.to_out(out)
# 主类定义

class FeedbackTransformer(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,  # 标记数量
        dim,  # 维度
        depth,  # 深度
        mem_len,  # 记忆长度
        seq_len = 2,  # 序列长度,默认为2
        heads = 8,  # 头数
        dim_head = 64,  # 头维度
        attn_dropout = 0.,  # 注意力机制的dropout
        ff_dropout = 0.,  # 前馈网络的dropout
        keep_last_hidden = False  # 是否保留最后一个隐藏层
    ):
        super().__init__()
        self.seq_len = seq_len
        self.mem_len = mem_len

        self.token_emb = nn.Embedding(num_tokens, dim)  # 标记嵌入层
        self.pos_emb = RelativePositionBias(causal = True, heads = heads)  # 相对位置偏置

        # 主要层

        self.layers = nn.ModuleList([])
        shared_kv_proj = None

        for _ in range(depth):
            attn = Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = attn_dropout)  # 注意力机制
            ff = FeedForward(dim = dim, dropout = ff_dropout)  # 前馈网络

            shared_kv_proj = default(shared_kv_proj, attn.to_kv)  # 共享键值投影
            attn.to_kv = shared_kv_proj

            attn, ff = map(lambda fn: Residual(PreNorm(dim, fn)), (attn, ff))  # 添加残差连接和层归一化

            if seq_len == 1:
                memory_is_empty = lambda *args, **kwargs: not exists(kwargs['memory'])
                attn = SkipIf(memory_is_empty, attn)  # 如果记忆为空,则跳过

            self.layers.append(nn.ModuleList([
                attn,
                ff
            ]))

        # 记忆参数

        self.layer_weight = nn.Parameter(torch.ones(depth + 1))  # 层权重
        self.shared_kv_proj = shared_kv_proj
        self.keep_last_hidden = keep_last_hidden

        # 最终投影到logits

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),  # 层归一化
            nn.Linear(dim, num_tokens)  # 线性层
        )

    def forward(self, x, memory = None, return_memory = False):
        b, n, device = *x.shape, x.device

        x = self.token_emb(x)  # 标记嵌入

        memory_keys = None
        memory_values = None

        if exists(memory):
            memory_keys, memory_values = memory

        outputs = []

        # 计算层的权重以存储到记忆中

        layer_weight = self.layer_weight.softmax(dim = -1)
        layer_weight = rearrange(layer_weight, 'd -> d () () ()')

        for x in x.split(self.seq_len, dim = 1):
            hiddens = [x]

            # 准备用于注意力的记忆,如果存在

            memory = None
            if exists(memory_keys):
                memory = (memory_keys, memory_values)

            for attn, ff in self.layers:

                x = attn(x, memory = memory, pos_emb = self.pos_emb)  # 注意力机制
                x = ff(x)  # 前馈网络

                hiddens.append(x)

            outputs.append(x)

            # 计算新的记忆键/值并存储到FIFO队列

            if self.keep_last_hidden:  # 保留最后一个隐藏层
                agg_hiddens = hiddens[-1]
            else:
                hiddens = torch.stack(hiddens)
                agg_hiddens = (hiddens * layer_weight).sum(dim = 0)

            # 预先计算记忆键/值并存储到缓冲区

            mem_k, mem_v = self.shared_kv_proj(agg_hiddens).chunk(2, dim = -1)
            memory_keys = safe_cat(memory_keys, mem_k, dim = 1)
            memory_values = safe_cat(memory_values, mem_v, dim = 1)

            # 强制在记忆缓冲区上施加最大长度限制

            memory_keys = memory_keys[:, -self.mem_len:]
            memory_values = memory_values[:, -self.mem_len:]

        x = torch.cat((outputs), dim = 1)
        out = self.to_logits(x)

        if not return_memory:
            return out

        return out, Memory(memory_keys, memory_values)