Lucidrains-系列项目源码解析-三十一-

108 阅读26分钟

Lucidrains 系列项目源码解析(三十一)

.\lucidrains\En-transformer\setup.py

# setuptools helpers: setup() registers the package, find_packages() discovers sub-packages
from setuptools import setup, find_packages

# package metadata for distribution on PyPI
setup(
  name = 'En-transformer',  # distribution name
  packages = find_packages(),  # include every discovered package
  version = '1.6.5',  # release version
  license='MIT',  # license identifier
  description = 'E(n)-Equivariant Transformer',  # short summary
  long_description_content_type = 'text/markdown',  # README is markdown
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author contact
  url = 'https://github.com/lucidrains/En-transformer',  # project homepage
  keywords = [  # search keywords
    'artificial intelligence',
    'deep learning',
    'equivariance',
    'transformer'
  ],
  install_requires=[  # runtime dependencies
    'einops>=0.3',
    'einx',
    'taylor-series-linear-attention>=0.1.4',
    'torch>=1.7'
  ],
  setup_requires=[  # build-time dependencies
    'pytest-runner',
  ],
  tests_require=[  # test-only dependencies
    'pytest'
  ],
  classifiers=[  # PyPI trove classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\En-transformer\tests\test_equivariance.py

# 导入 torch 库
import torch
# 从 en_transformer.utils 模块中导入 rot 函数
from en_transformer.utils import rot
# 从 en_transformer 模块中导入 EnTransformer 类
from en_transformer import EnTransformer

# use float64 everywhere so the equivariance checks below are not limited by float32 precision
torch.set_default_dtype(torch.float64)

# smoke test: the README example must run end to end
def test_readme():
    model = EnTransformer(
        dim = 512,
        depth = 1,
        dim_head = 64,
        heads = 8,
        edge_dim = 4,
        neighbors = 6
    )

    # random node features, coordinates, edge features and a full mask
    node_feats = torch.randn(1, 32, 512)
    node_coors = torch.randn(1, 32, 3)
    edge_feats = torch.randn(1, 32, 1024, 4)
    node_mask = torch.ones(1, 32).bool()

    # a single forward pass is the whole test
    node_feats, node_coors = model(node_feats, node_coors, edge_feats, mask = node_mask)
    assert True, 'it runs'

# core property: features invariant, coordinates equivariant under rigid motion
def test_equivariance():
    model = EnTransformer(
        dim = 512,
        depth = 1,
        edge_dim = 4,
        rel_pos_emb = True
    )

    # sample a random rotation and translation
    rotation = rot(*torch.rand(3))
    translation = torch.randn(1, 1, 3)

    feats = torch.randn(1, 16, 512)
    coors = torch.randn(1, 16, 3)
    edges = torch.randn(1, 16, 16, 4)

    # run once on transformed coordinates, once on the originals
    feats_rot, coors_rot = model(feats, coors @ rotation + translation, edges)
    feats_base, coors_base = model(feats, coors, edges)

    # type-0 outputs must not change; type-1 outputs must transform the same way
    assert torch.allclose(feats_rot, feats_base, atol = 1e-6), 'type 0 features are invariant'
    assert torch.allclose(coors_rot, (coors_base @ rotation + translation), atol = 1e-6), 'type 1 features are equivariant'

# The remaining tests follow the same pattern as the two annotated above:
# build a model variant, then check feature invariance and coordinate equivariance.

def test_equivariance_with_cross_product():
    # equivariance must also hold with the cross product branch enabled
    model = EnTransformer(
        dim = 512,
        depth = 1,
        edge_dim = 4,
        rel_pos_emb = True,
        use_cross_product = True
    )

    rotation = rot(*torch.rand(3))
    translation = torch.randn(1, 1, 3)

    feats = torch.randn(1, 16, 512)
    coors = torch.randn(1, 16, 3)
    edges = torch.randn(1, 16, 16, 4)

    feats_rot, coors_rot = model(feats, coors @ rotation + translation, edges)
    feats_base, coors_base = model(feats, coors, edges)

    assert torch.allclose(feats_rot, feats_base, atol = 1e-6), 'type 0 features are invariant'
    assert torch.allclose(coors_rot, (coors_base @ rotation + translation), atol = 1e-6), 'type 1 features are equivariant'

def test_equivariance_with_nearest_neighbors():
    # equivariance with sparse nearest-neighbor attention
    model = EnTransformer(
        dim = 512,
        depth = 1,
        edge_dim = 4,
        neighbors = 5
    )

    rotation = rot(*torch.rand(3))
    translation = torch.randn(1, 1, 3)

    feats = torch.randn(1, 16, 512)
    coors = torch.randn(1, 16, 3)
    edges = torch.randn(1, 16, 16, 4)

    feats_rot, coors_rot = model(feats, coors @ rotation + translation, edges)
    feats_base, coors_base = model(feats, coors, edges)

    assert torch.allclose(feats_rot, feats_base, atol = 1e-6), 'type 0 features are invariant'
    assert torch.allclose(coors_rot, (coors_base @ rotation + translation), atol = 1e-6), 'type 1 features are equivariant'

def test_equivariance_with_sparse_neighbors():
    # equivariance when attention is restricted to an explicit adjacency matrix
    model = EnTransformer(
        dim = 512,
        depth = 1,
        heads = 4,
        dim_head = 32,
        neighbors = 0,
        only_sparse_neighbors = True
    )

    rotation = rot(*torch.rand(3))
    translation = torch.randn(1, 1, 3)

    feats = torch.randn(1, 16, 512)
    coors = torch.randn(1, 16, 3)

    # chain adjacency: each node connects to itself and its immediate sequence neighbors
    idx = torch.arange(feats.shape[1])
    adj_mat = (idx[:, None] <= (idx[None, :] + 1)) & (idx[:, None] >= (idx[None, :] - 1))

    feats_rot, coors_rot = model(feats, coors @ rotation + translation, adj_mat = adj_mat)
    feats_base, coors_base = model(feats, coors, adj_mat = adj_mat)

    assert torch.allclose(feats_rot, feats_base, atol = 1e-6), 'type 0 features are invariant'
    assert torch.allclose(coors_rot, (coors_base @ rotation + translation), atol = 1e-6), 'type 1 features are equivariant'

def test_depth():
    # a deep stack must remain numerically stable (no NaNs)
    model = EnTransformer(
        dim = 8,
        depth = 12,
        edge_dim = 4,
        neighbors = 16
    )

    feats = torch.randn(1, 128, 8)
    coors = torch.randn(1, 128, 3)
    edges = torch.randn(1, 128, 128, 4)

    feats, coors = model(feats, coors, edges)

    assert not torch.isnan(feats).any(), 'no NaN in features'
    assert not torch.isnan(coors).any(), 'no NaN in coordinates'

.\lucidrains\enformer-pytorch\enformer_pytorch\config_enformer.py

# 导入预训练配置类 PretrainedConfig 从 transformers 模块
from transformers import PretrainedConfig

class EnformerConfig(PretrainedConfig):
    """HuggingFace-style configuration object for the Enformer model.

    Stores all model hyperparameters as attributes; any extra keyword
    arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "enformer"

    def __init__(
        self,
        dim = 1536,                 # model (channel) dimension
        depth = 11,                 # number of transformer layers
        heads = 8,                  # attention heads
        output_heads = None,        # per-organism output head sizes; defaults to human = 5313, mouse = 1643
        target_length = 896,        # length of the trimmed output track
        attn_dim_key = 64,          # key dimension per attention head
        dropout_rate = 0.4,         # general dropout rate
        attn_dropout = 0.05,        # attention dropout rate
        pos_dropout = 0.01,         # positional feature dropout rate
        use_checkpointing = False,  # gradient checkpointing through the transformer
        use_convnext = False,       # accepted for backward compatibility; not stored or used here
        num_downsamples = 7,        # genetic sequence is downsampled 2 ** 7 == 128x by default; can be changed for higher resolution
        dim_divisible_by = 128,     # conv tower dims are rounded to multiples of this
        use_tf_gamma = False,       # use precomputed tensorflow gammas for positional features
        **kwargs,                   # forwarded to PretrainedConfig
    ):
        self.dim = dim
        self.depth = depth
        self.heads = heads
        # use a None sentinel instead of a mutable dict default argument,
        # so the default mapping is not shared across config instances
        self.output_heads = output_heads if output_heads is not None else dict(human = 5313, mouse = 1643)
        self.target_length = target_length
        self.attn_dim_key = attn_dim_key
        self.dropout_rate = dropout_rate
        self.attn_dropout = attn_dropout
        self.pos_dropout = pos_dropout
        self.use_checkpointing = use_checkpointing
        self.num_downsamples = num_downsamples
        self.dim_divisible_by = dim_divisible_by
        self.use_tf_gamma = use_tf_gamma

        super().__init__(**kwargs)

.\lucidrains\enformer-pytorch\enformer_pytorch\data.py

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

import polars as pl
import numpy as np
from random import randrange, random
from pathlib import Path

# `Fasta` is a class inside the pyfaidx package, so it must be brought in
# with a `from` import — `import pyfaidx.Fasta` fails (Fasta is not a
# submodule) and would not bind the bare `Fasta` name used below anyway
from pyfaidx import Fasta

# helper functions

def exists(val):
    """True when `val` is not None."""
    return val is not None

def identity(t):
    """Return the input unchanged."""
    return t

def cast_list(t):
    """Wrap a non-list value in a single-element list."""
    if isinstance(t, list):
        return t
    return [t]

def coin_flip():
    """Fair coin toss."""
    return 0.5 < random()

# genomic sequence transformations

# lookup tables indexed by ASCII code: base -> integer index and base -> one-hot row
_BASE_INDEX = dict(a = 0, c = 1, g = 2, t = 3, n = 4)
_BASE_ONE_HOT = dict(
    a = [1., 0., 0., 0.],
    c = [0., 1., 0., 0.],
    g = [0., 0., 1., 0.],
    t = [0., 0., 0., 1.],
    n = [0., 0., 0., 0.],
)

seq_indices_embed = torch.zeros(256).long()
one_hot_embed = torch.zeros(256, 4)

# populate both lower and upper case entries for each base
for _base, _index in _BASE_INDEX.items():
    for _ch in (_base, _base.upper()):
        seq_indices_embed[ord(_ch)] = _index
        one_hot_embed[ord(_ch)] = torch.Tensor(_BASE_ONE_HOT[_base])

# '.' marks padding: -1 index, uniform one-hot
seq_indices_embed[ord('.')] = -1
one_hot_embed[ord('.')] = torch.Tensor([0.25, 0.25, 0.25, 0.25])

# base-index map for reverse complement (a<->t, c<->g, n->n)
reverse_complement_map = torch.Tensor([3, 2, 1, 0, 4]).long()

# convert a string (or list of strings) into a tensor of uint8 character codes
def torch_fromstring(seq_strs):
    """Return a 1-D uint8 tensor of character codes, or a stacked batch for a list of strings."""
    batched = not isinstance(seq_strs, str)
    seq_strs = seq_strs if isinstance(seq_strs, list) else [seq_strs]
    # np.fromstring is deprecated (and removed for this use in NumPy 2.x);
    # np.frombuffer over the encoded bytes is the supported equivalent.
    # copy() makes the array writable, which torch.from_numpy requires
    np_seq_chrs = [np.frombuffer(s.encode(), dtype = np.uint8).copy() for s in seq_strs]
    seq_chrs = list(map(torch.from_numpy, np_seq_chrs))
    return torch.stack(seq_chrs) if batched else seq_chrs[0]

# map sequence string(s) to integer base indices
def str_to_seq_indices(seq_strs):
    """Return the per-character base indices for `seq_strs`."""
    char_codes = torch_fromstring(seq_strs)
    return seq_indices_embed[char_codes.long()]

# map sequence string(s) to one-hot encodings
def str_to_one_hot(seq_strs):
    """Return the per-character one-hot rows for `seq_strs`."""
    char_codes = torch_fromstring(seq_strs)
    return one_hot_embed[char_codes.long()]

# convert integer base indices into one-hot encodings
def seq_indices_to_one_hot(t, padding = -1):
    """One-hot encode base indices; positions equal to `padding` become uniform 0.25."""
    pad_mask = t == padding
    clamped = t.clamp(min = 0)
    encoded = F.one_hot(clamped, num_classes = 5)[..., :4].float()
    return encoded.masked_fill(pad_mask[..., None], 0.25)

# data augmentations

def seq_indices_reverse_complement(seq_indices):
    """Reverse-complement a sequence given as base indices."""
    complemented = reverse_complement_map[seq_indices.long()]
    return torch.flip(complemented, dims = (-1,))

# reverse-complement a one-hot encoded sequence
def one_hot_reverse_complement(one_hot):
    """Flip the sequence axis and the base axis (ACGT column order encodes complementation)."""
    assert one_hot.shape[-1] == 4, 'must be one hot encoding with last dimension equal to 4'
    return torch.flip(one_hot, (-1, -2))

# fasta interval extraction

class FastaInterval():
    """Reads subsequences of a fasta file by (chromosome, start, end), with
    optional random shift and reverse-complement augmentation."""

    def __init__(
        self,
        *,
        fasta_file,
        context_length = None,
        return_seq_indices = False,
        shift_augs = None,
        rc_aug = False
    ):
        fasta_file = Path(fasta_file)
        assert fasta_file.exists(), 'path to fasta file must exist'

        self.seqs = Fasta(str(fasta_file))
        self.return_seq_indices = return_seq_indices  # return base indices instead of one-hot
        self.context_length = context_length          # pad / extend intervals to this length
        self.shift_augs = shift_augs                  # (min_shift, max_shift) tuple, or None to disable
        self.rc_aug = rc_aug                          # randomly reverse-complement half the samples

    def __call__(self, chr_name, start, end, return_augs = False):
        """Return the encoded sequence for the interval, optionally with augmentation info."""
        interval_length = end - start
        chromosome = self.seqs[chr_name]
        chromosome_length = len(chromosome)

        # default shift of 0 so the augmentation info returned below is
        # well-defined even when shift augmentation is disabled
        # (previously `rand_shift` raised a NameError when return_augs = True
        # and shift_augs was None)
        rand_shift = 0

        if exists(self.shift_augs):
            min_shift, max_shift = self.shift_augs
            max_shift += 1  # randrange upper bound is exclusive

            # clamp the shift range so the interval stays within the chromosome
            min_shift = max(start + min_shift, 0) - start
            max_shift = min(end + max_shift, chromosome_length) - end

            rand_shift = randrange(min_shift, max_shift)
            start += rand_shift
            end += rand_shift

        left_padding = right_padding = 0

        # symmetrically extend the interval to the requested context length
        if exists(self.context_length) and interval_length < self.context_length:
            extra_seq = self.context_length - interval_length

            extra_left_seq = extra_seq // 2
            extra_right_seq = extra_seq - extra_left_seq

            start -= extra_left_seq
            end += extra_right_seq

        # pad with '.' wherever the interval runs off either end of the chromosome
        if start < 0:
            left_padding = -start
            start = 0

        if end > chromosome_length:
            right_padding = end - chromosome_length
            end = chromosome_length

        seq = ('.' * left_padding) + str(chromosome[start:end]) + ('.' * right_padding)

        # decide whether this sample gets reverse-complemented
        should_rc_aug = self.rc_aug and coin_flip()

        if self.return_seq_indices:
            seq = str_to_seq_indices(seq)

            if should_rc_aug:
                seq = seq_indices_reverse_complement(seq)

            return seq

        one_hot = str_to_one_hot(seq)

        if should_rc_aug:
            one_hot = one_hot_reverse_complement(one_hot)

        if not return_augs:
            return one_hot

        # also return the sampled shift and whether reverse complement was applied
        rand_shift_tensor = torch.tensor([rand_shift])
        rand_aug_bool_tensor = torch.tensor([should_rc_aug])

        return one_hot, rand_shift_tensor, rand_aug_bool_tensor
# dataset yielding one encoded fasta subsequence per interval of a .bed file
class GenomeIntervalDataset(Dataset):
    def __init__(
        self,
        bed_file,
        fasta_file,
        filter_df_fn = identity,
        chr_bed_to_fasta_map = dict(),
        context_length = None,
        return_seq_indices = False,
        shift_augs = None,
        rc_aug = False,
        return_augs = False
    ):
        """Pair a .bed interval file with a fasta file; each item is the encoded sequence."""
        super().__init__()

        bed_path = Path(bed_file)
        assert bed_path.exists(), 'path to .bed file must exist'

        # load the intervals and apply the caller-supplied filter
        self.df = filter_df_fn(pl.read_csv(str(bed_path), separator = '\t', has_header = False))

        # remaps chromosome names when the .bed file and fasta file keys disagree
        self.chr_bed_to_fasta_map = chr_bed_to_fasta_map

        # sequence extractor with augmentation settings
        self.fasta = FastaInterval(
            fasta_file = fasta_file,
            context_length = context_length,
            return_seq_indices = return_seq_indices,
            shift_augs = shift_augs,
            rc_aug = rc_aug
        )

        self.return_augs = return_augs

    def __len__(self):
        """Number of intervals in the filtered .bed file."""
        return len(self.df)

    def __getitem__(self, ind):
        """Fetch the encoded sequence for the interval at row `ind`."""
        chr_name, start, end, *_ = self.df.row(ind)
        # remap the chromosome name if needed
        chr_name = self.chr_bed_to_fasta_map.get(chr_name, chr_name)
        return self.fasta(chr_name, start, end, return_augs = self.return_augs)

.\lucidrains\enformer-pytorch\enformer_pytorch\finetune.py

import torch
from typing import Optional

from copy import deepcopy
from contextlib import contextmanager

import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from enformer_pytorch.modeling_enformer import Enformer, poisson_loss

# this import was documented by a comment but the line itself was missing,
# even though the wrapper classes below instantiate DiscreteKeyValueBottleneck
from discrete_key_value_bottleneck_pytorch import DiscreteKeyValueBottleneck
# helper functions

def exists(val):
    """True when `val` is not None."""
    return val is not None

def default(val, d):
    """`val` when present, otherwise the fallback `d`."""
    return d if val is None else val

# context manager that does nothing
@contextmanager
def null_context():
    yield

# nn.Sequential that silently drops any None entries
def Sequential(*modules):
    present = [m for m in modules if exists(m)]
    return nn.Sequential(*present)

# layer freezing helpers

def set_module_requires_grad_(module, requires_grad):
    """Toggle requires_grad on every parameter of `module`."""
    for p in module.parameters():
        p.requires_grad = requires_grad

def freeze_all_layers_(module):
    """Disable gradients for the whole module."""
    set_module_requires_grad_(module, False)

def unfreeze_all_layers_(module):
    """Enable gradients for the whole module."""
    set_module_requires_grad_(module, True)

def freeze_batchnorms_(model):
    """Put every BatchNorm1d into eval mode, stop running stats, and freeze its params."""
    for bn in (m for m in model.modules() if isinstance(m, nn.BatchNorm1d)):
        bn.eval()
        bn.track_running_stats = False
        set_module_requires_grad_(bn, False)

def freeze_all_but_layernorms_(model):
    """Freeze every submodule except LayerNorms."""
    for m in model.modules():
        set_module_requires_grad_(m, isinstance(m, nn.LayerNorm))

def freeze_all_but_last_n_layers_(enformer, n):
    """Freeze the whole enformer, then unfreeze its last `n` transformer blocks."""
    assert isinstance(enformer, Enformer)
    freeze_all_layers_(enformer)

    for block in enformer.transformer[-n:]:
        set_module_requires_grad_(block, True)

# retrieving embeddings from the enformer trunk

def get_enformer_embeddings(
    model,
    seq,
    freeze = False,
    train_layernorms_only = False,
    train_last_n_layers_only = None,
    enformer_kwargs: dict = {}
):
    """Run `model` on `seq` and return embeddings, with optional partial freezing.

    `freeze` runs the model under no_grad and detaches the output;
    the other two flags selectively unfreeze parts of the model and are
    incompatible with `freeze`.
    """
    freeze_batchnorms_(model)

    if train_layernorms_only:
        assert not freeze, 'you set the intent to train the layernorms of the enformer, yet also indicated you wanted to freeze the entire model'
        freeze_all_but_layernorms_(model)

    if exists(train_last_n_layers_only):
        assert not freeze, 'you set the intent to train last N layers of enformer, but also indicated you wanted to freeze the entire network'
        freeze_all_but_last_n_layers_(model, train_last_n_layers_only)

    # gradients flow unless the caller asked for a frozen trunk
    context = torch.no_grad() if freeze else null_context()

    with context:
        embeddings = model(seq, return_only_embeddings = True, **enformer_kwargs)

        if freeze:
            embeddings.detach_()

    return embeddings

# fine-tuning wrapper classes

# extra head projection, similar to how human and mouse tracks were trained

class HeadAdapterWrapper(nn.Module):
    # Wraps an Enformer and adds a new track-prediction head for fine-tuning.
    def __init__(
        self,
        *,
        enformer,
        num_tracks,
        post_transformer_embed = False, # whether to take the embeddings right after the transformer, instead of after the final pointwise convolution - this will add another layernorm
        discrete_key_value_bottleneck = False,
        bottleneck_num_memories = 256,
        bottleneck_num_codebooks = 4,
        bottleneck_decay = 0.9,
        transformer_embed_fn: nn.Module = nn.Identity(),
        output_activation: Optional[nn.Module] = nn.Softplus(),
        auto_set_target_length = True
        ):
        super().__init__()
        # only a true Enformer can be wrapped
        assert isinstance(enformer, Enformer)
        # embeddings after the final pointwise conv are 2x the model dim;
        # embeddings straight after the transformer keep the model dim
        enformer_hidden_dim = enformer.dim * (2 if not post_transformer_embed else 1)

        self.discrete_key_value_bottleneck = discrete_key_value_bottleneck

        # optionally route the encoder through a discrete key / value bottleneck
        if discrete_key_value_bottleneck:
            enformer = DiscreteKeyValueBottleneck(
                encoder = enformer,
                dim = enformer_hidden_dim,
                num_memory_codebooks = bottleneck_num_codebooks,
                num_memories = bottleneck_num_memories,
                dim_memory = enformer_hidden_dim // bottleneck_num_codebooks,
                decay = bottleneck_decay,
            )

        self.post_transformer_embed = post_transformer_embed

        self.enformer = enformer

        self.auto_set_target_length = auto_set_target_length

        # when taking post-transformer embeddings, copy the enformer and
        # strip the layers after the transformer
        if post_transformer_embed:
            self.enformer = deepcopy(enformer)
            self.enformer._trunk[-1] = nn.Identity()
            self.enformer.final_pointwise = nn.Identity()

        # optional transform applied to the embeddings before the head
        self.post_embed_transform = Sequential(
            transformer_embed_fn,
            nn.LayerNorm(enformer_hidden_dim) if post_transformer_embed else None
        )

        # linear head projecting embeddings to the new tracks
        self.to_tracks = Sequential(
            nn.Linear(enformer_hidden_dim, num_tracks),
            output_activation
        )

    # Returns track predictions, or the poisson loss when a target is given.
    def forward(
        self,
        seq,
        *,
        target = None,
        freeze_enformer = False,
        finetune_enformer_ln_only = False,
        finetune_last_n_layers_only = None
    ):
        enformer_kwargs = dict()

        # match the enformer output length to the target length
        if exists(target) and self.auto_set_target_length:
            enformer_kwargs = dict(target_length = target.shape[-2])

        # obtain the embeddings, either through the bottleneck wrapper or
        # through the (optionally partially frozen) enformer trunk
        if self.discrete_key_value_bottleneck:
            embeddings = self.enformer(seq, return_only_embeddings = True, **enformer_kwargs)
        else:
            embeddings = get_enformer_embeddings(self.enformer, seq, freeze = freeze_enformer, train_layernorms_only = finetune_enformer_ln_only, train_last_n_layers_only = finetune_last_n_layers_only, enformer_kwargs = enformer_kwargs)

        # project embeddings to the new tracks
        preds = self.to_tracks(embeddings)

        if not exists(target):
            return preds

        # train against the target with a poisson loss
        return poisson_loss(preds, target)
# wrapper that accepts a context embedding per track;
# the context is projected into the weights and bias of the head's linear projection (a hypernetwork)

class ContextAdapterWrapper(nn.Module):
    # Predicts one output track per context embedding via a context-generated linear head.
    def __init__(
        self,
        *,
        enformer,  # the Enformer model to wrap
        context_dim,  # dimension of the per-track context embedding
        discrete_key_value_bottleneck = False,  # whether to use a discrete key / value bottleneck
        bottleneck_num_memories = 256,  # number of bottleneck memories
        bottleneck_num_codebooks = 4,  # number of bottleneck codebooks
        bottleneck_decay = 0.9,  # bottleneck EMA decay
        auto_set_target_length = True,  # set enformer target_length from the target shape
        output_activation: Optional[nn.Module] = nn.Softplus()  # output activation, Softplus by default
    ):
        super().__init__()
        assert isinstance(enformer, Enformer)
        # embeddings after the final pointwise conv are twice the model dim
        enformer_hidden_dim = enformer.dim * 2

        self.discrete_key_value_bottleneck = discrete_key_value_bottleneck

        # optionally route the encoder through a discrete key / value bottleneck
        if discrete_key_value_bottleneck:
            enformer = DiscreteKeyValueBottleneck(
                encoder = enformer,
                dim = enformer_hidden_dim,
                num_memory_codebooks = bottleneck_num_codebooks,
                num_memories = bottleneck_num_memories,
                dim_memory = enformer_hidden_dim // bottleneck_num_codebooks,
                decay = bottleneck_decay,
            )

        self.enformer = enformer

        self.auto_set_target_length = auto_set_target_length

        self.to_context_weights = nn.Parameter(torch.randn(context_dim, enformer_hidden_dim))  # hypernetwork: context -> head weights
        self.to_context_bias = nn.Parameter(torch.randn(context_dim))  # hypernetwork: context -> head bias

        self.activation = default(output_activation, nn.Identity())  # output activation

    # Returns per-track predictions, or the poisson loss when a target is given.
    def forward(
        self,
        seq,  # input genetic sequence
        *,
        context,  # per-track context embeddings (tracks x context_dim)
        target = None,  # optional target tracks
        freeze_enformer = False,  # freeze the whole enformer
        finetune_enformer_ln_only = False,  # finetune only the enformer layernorms
        finetune_last_n_layers_only = None  # finetune only the last n layers
    ):
        enformer_kwargs = dict()

        # match the enformer output length to the target length
        if exists(target) and self.auto_set_target_length:
            enformer_kwargs = dict(target_length = target.shape[-2])

        if self.discrete_key_value_bottleneck:
            embeddings = self.enformer(seq, return_only_embeddings = True, **enformer_kwargs)
        else:
            embeddings = get_enformer_embeddings(self.enformer, seq, freeze = freeze_enformer, train_layernorms_only = finetune_enformer_ln_only, train_last_n_layers_only = finetune_last_n_layers_only, enformer_kwargs = enformer_kwargs)

        weights = einsum('t d, d e -> t e', context, self.to_context_weights)  # per-track head weights
        bias = einsum('t d, d -> t', context, self.to_context_bias)  # per-track head bias

        pred = einsum('b n d, t d -> b n t', embeddings, weights) + bias  # linear head per track

        pred = self.activation(pred)  # apply output activation

        if not exists(target):
            return pred

        return poisson_loss(pred, target)  # poisson loss against the target tracks

# wrapper that performs attention aggregation of the context, where the context can be a list of tokens (batch x seq x dim)

class ContextAttentionAdapterWrapper(nn.Module):
    """Cross-attends enformer embeddings (queries) to context tokens (keys / values)
    and predicts one output track per context; returns the poisson loss when a
    target is supplied."""

    def __init__(
        self,
        *,
        enformer,                               # the Enformer model to wrap
        context_dim,                            # dimension of the context tokens
        heads = 8,                              # attention heads
        dim_head = 64,                          # dimension per head
        discrete_key_value_bottleneck = False,  # whether to use a discrete key / value bottleneck
        bottleneck_num_memories = 256,          # number of bottleneck memories
        bottleneck_num_codebooks = 4,           # number of bottleneck codebooks
        bottleneck_decay = 0.9,                 # bottleneck EMA decay
        auto_set_target_length = True,          # set enformer target_length from the target shape
        output_activation: Optional[nn.Module] = nn.Softplus()  # output activation, Softplus by default
    ):
        super().__init__()
        assert isinstance(enformer, Enformer)
        # embeddings after the final pointwise conv are twice the model dim
        enformer_hidden_dim = enformer.dim * 2

        self.discrete_key_value_bottleneck = discrete_key_value_bottleneck

        # optionally route the encoder through a discrete key / value bottleneck
        if discrete_key_value_bottleneck:
            enformer = DiscreteKeyValueBottleneck(
                encoder = enformer,
                dim = enformer_hidden_dim,
                num_memory_codebooks = bottleneck_num_codebooks,
                num_memories = bottleneck_num_memories,
                dim_memory = enformer_hidden_dim // bottleneck_num_codebooks,
                decay = bottleneck_decay,
            )

        self.enformer = enformer

        self.auto_set_target_length = auto_set_target_length

        # pre-norms for queries (genetic embeddings) and keys / values (contexts)
        self.query_norm = nn.LayerNorm(enformer_hidden_dim)
        self.key_values_norm = nn.LayerNorm(context_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = heads * dim_head
        self.to_queries = nn.Linear(enformer_hidden_dim, inner_dim, bias = False)

        # learned null key / value so attention can fall back to "no context"
        self.null_key = nn.Parameter(torch.randn(inner_dim))
        self.null_value = nn.Parameter(torch.randn(inner_dim))

        self.to_key_values = nn.Linear(context_dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, enformer_hidden_dim)

        # per-position scalar prediction, one output track per context
        self.to_pred  = Sequential(
            nn.Linear(enformer_hidden_dim, 1),
            Rearrange('b c ... 1 -> b ... c'),
            output_activation
        )

    def forward(
        self,
        seq,
        *,
        context,
        context_mask = None,
        target = None,
        freeze_enformer = False,
        finetune_enformer_ln_only = False,
        finetune_last_n_layers_only = None
        ):
        """
        b - batch
        n - sequence length
        c - number of contexts (tracks)
        d - dimension
        i - sequence length (query embeddings)
        j - sequence length (keys / values contexts)
        h - attention heads
        """

        # number of attention heads — this assignment was missing (only its
        # comment survived) even though `h` is used in the rearranges below
        h = self.heads

        enformer_kwargs = dict()

        # match the enformer output length to the target length
        if exists(target) and self.auto_set_target_length:
            enformer_kwargs = dict(target_length = target.shape[-2])

        # obtain the genetic embeddings, either through the bottleneck wrapper
        # or through the (optionally partially frozen) enformer trunk
        if self.discrete_key_value_bottleneck:
            embeddings = self.enformer(seq, return_only_embeddings = True, **enformer_kwargs)
        else:
            embeddings = get_enformer_embeddings(self.enformer, seq, freeze = freeze_enformer, train_layernorms_only = finetune_enformer_ln_only, train_last_n_layers_only = finetune_last_n_layers_only, enformer_kwargs = enformer_kwargs)

        # perform cross attention from genetic embeddings -> context tokens

        # promote a single context vector per batch to a length-1 token sequence
        if context.ndim == 2:
            context = rearrange(context, 'b d -> b 1 d')

        q = self.to_queries(self.query_norm(embeddings))
        k, v = self.to_key_values(self.key_values_norm(context)).chunk(2, dim = -1)

        # expand the learned null key / value across the batch
        null_k, null_v = map(lambda t: repeat(t, 'd -> b 1 d', b = context.shape[0]), (self.null_key, self.null_value))

        k = torch.cat((null_k, k), dim = 1)
        v = torch.cat((null_v, v), dim = 1)

        # split out heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        sim = einsum('b h i d, c h j d -> b c h i j', q, k) * self.scale

        # masking — pad with True so the null key / value is always attendable
        if exists(context_mask):
            context_mask = F.pad(context_mask, (1, 0), value = True)
            context_mask = rearrange(context_mask, 'b j -> b 1 1 1 j')
            sim = sim.masked_fill(~context_mask, -torch.finfo(sim.dtype).max)

        # attention
        attn = sim.softmax(dim = -1)

        # aggregate values
        out = einsum('b c h i j, c h j d -> b c h i d', attn, v)
        out = rearrange(out, 'b c h n d -> b c n (h d)', h = h)

        # combine heads
        branch_out = self.to_out(out)

        # residual connection
        embeddings = embeddings + branch_out

        # to prediction
        pred = self.to_pred(embeddings)

        if not exists(target):
            return pred

        return poisson_loss(pred, target)

.\lucidrains\enformer-pytorch\enformer_pytorch\metrics.py

from torchmetrics import Metric
from typing import Optional
import torch

# torchmetrics Metric computing the mean Pearson correlation per channel,
# aggregated over the batch and position dimensions
class MeanPearsonCorrCoefPerChannel(Metric):
    # correlation is not differentiable here
    is_differentiable: Optional[bool] = False
    # a larger correlation is better
    higher_is_better: Optional[bool] = True

    def __init__(self, n_channels:int, dist_sync_on_step=False):
        """Calculates the mean pearson correlation across channels aggregated over regions"""
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        # sum over batch (0) and position (1) dimensions, keeping the channel dimension
        self.reduce_dims=(0, 1)
        # running sums needed to compute the correlation in a single pass:
        # sum of products, sums and squared sums of target / prediction, and the element count
        self.add_state("product", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum")
        self.add_state("true", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum")
        self.add_state("true_squared", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum")
        self.add_state("pred", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum")
        self.add_state("pred_squared", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum")
        self.add_state("count", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # predictions and targets must agree in shape
        assert preds.shape == target.shape

        # accumulate the per-channel running sums
        self.product += torch.sum(preds * target, dim=self.reduce_dims)
        self.true += torch.sum(target, dim=self.reduce_dims)
        self.true_squared += torch.sum(torch.square(target), dim=self.reduce_dims)
        self.pred += torch.sum(preds, dim=self.reduce_dims)
        self.pred_squared += torch.sum(torch.square(preds), dim=self.reduce_dims)
        self.count += torch.sum(torch.ones_like(target), dim=self.reduce_dims)

    def compute(self):
        # per-channel means of target and prediction
        true_mean = self.true / self.count
        pred_mean = self.pred / self.count

        # covariance from the running sums: sum(xy) - mean(x)*sum(y) - mean(y)*sum(x) + n*mean(x)*mean(y)
        covariance = (self.product
                    - true_mean * self.pred
                    - pred_mean * self.true
                    + self.count * true_mean * pred_mean)

        # variances, their geometric mean, and the resulting correlation coefficient
        true_var = self.true_squared - self.count * torch.square(true_mean)
        pred_var = self.pred_squared - self.count * torch.square(pred_mean)
        tp_var = torch.sqrt(true_var) * torch.sqrt(pred_var)
        correlation = covariance / tp_var
        return correlation

.\lucidrains\enformer-pytorch\enformer_pytorch\modeling_enformer.py

# 导入所需的库
import math
from pathlib import Path

import torch
from torch import nn, einsum
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.checkpoint import checkpoint_sequential

from einops import rearrange, reduce
from einops.layers.torch import Rearrange

from enformer_pytorch.data import str_to_one_hot, seq_indices_to_one_hot

from enformer_pytorch.config_enformer import EnformerConfig

from transformers import PreTrainedModel

# model constants
SEQUENCE_LENGTH = 196_608
TARGET_LENGTH = 896

# gamma positional features precomputed in tensorflow, loaded verbatim.
# torch.xlogy and tensorflow's xlogy differ numerically, so reproducing the
# official checkpoint requires the TF values (workaround courtesy of @johahi).
DIR = Path(__file__).parents[0]
# FIX: the original line was missing its closing parenthesis (syntax error)
TF_GAMMAS = torch.load(str(DIR / "precomputed" / "tf_gammas.pt"))

# 辅助函数

# 检查值是否存在
def exists(val):
    """Return True unless ``val`` is None (0, '' and [] all count as existing)."""
    return val is not None

# 如果值存在则返回该值,否则返回默认值
def default(val, d):
    """Return ``val`` when it is not None, otherwise the fallback ``d``."""
    if val is None:
        return d
    return val

# 返回始终为指定值的函数
def always(val):
    """Build a callable that ignores all arguments and always returns ``val``."""
    def constant(*_args, **_kwargs):
        return val
    return constant

# 对字典中的值应用函数
def map_values(fn, d):
    """Apply ``fn`` to every value of dict ``d``, preserving its keys."""
    out = {}
    for key, value in d.items():
        out[key] = fn(value)
    return out

# 在指数范围内生成整数序列
def exponential_linspace_int(start, end, num, divisible_by = 1):
    """Return ``num`` integers spaced geometrically from ``start`` to ``end``,
    each rounded to the nearest multiple of ``divisible_by``."""
    ratio = math.exp(math.log(end / start) / (num - 1))

    out = []
    for i in range(num):
        raw = start * ratio ** i
        out.append(int(round(raw / divisible_by) * divisible_by))
    return out

# 计算对数,避免值过小
def log(t, eps = 1e-20):
    """Numerically safe log: clamp ``t`` below at ``eps`` before taking the log."""
    clamped = t.clamp(min = eps)
    return torch.log(clamped)

# 可能用于同步批归一化,在分布式训练中
def MaybeSyncBatchnorm(is_distributed = None):
    """Select ``nn.SyncBatchNorm`` for multi-process training, else ``nn.BatchNorm1d``.

    When ``is_distributed`` is None, it is inferred from torch.distributed state.
    """
    if is_distributed is None:
        is_distributed = dist.is_initialized() and dist.get_world_size() > 1
    return nn.SyncBatchNorm if is_distributed else nn.BatchNorm1d

# 损失函数和指标

# Poisson 损失函数
def poisson_loss(pred, target):
    """Mean Poisson negative log-likelihood (up to a target-only constant):
    mean(pred - target * log(pred))."""
    # inlined clamp-then-log (module-level `log` helper, eps = 1e-20)
    safe_log_pred = torch.log(pred.clamp(min = 1e-20))
    return (pred - target * safe_log_pred).mean()

# 计算 Pearson 相关系数
def pearson_corr_coef(x, y, dim = 1, reduce_dims = (-1,)):
    """Pearson correlation along ``dim`` as centered cosine similarity,
    averaged over ``reduce_dims``."""
    x_c = x - x.mean(dim = dim, keepdim = True)
    y_c = y - y.mean(dim = dim, keepdim = True)
    sim = F.cosine_similarity(x_c, y_c, dim = dim)
    return sim.mean(dim = reduce_dims)

# 相对位置编码函数

# 获取指数衰减的位置特征
def get_positional_features_exponential(positions, features, seq_len, min_half_life = 3., dtype = torch.float):
    """Exponentially decaying basis over |position|, one half-life per feature."""
    max_range = math.log(seq_len) / math.log(2.)
    # half-lives spaced between 2^min_half_life and seq_len (log-uniform)
    half_life = 2 ** torch.linspace(min_half_life, max_range, features, device = positions.device)
    decay = half_life.unsqueeze(0)
    dist = positions.abs().unsqueeze(-1)
    return torch.exp(-math.log(2.) / decay * dist)

# 获取中心掩码位置特征
def get_positional_features_central_mask(positions, features, seq_len, dtype = torch.float):
    """Binary features: whether |position| falls inside widening central windows."""
    # window half-widths: 2^1 - 1, 2^2 - 1, ..., 2^features - 1
    widths = 2 ** torch.arange(1, features + 1, device = positions.device).to(dtype) - 1
    inside = widths.unsqueeze(0) > positions.abs().unsqueeze(-1)
    return inside.to(dtype)

# Gamma 分布概率密度函数
def gamma_pdf(x, concentration, rate):
    """Gamma density evaluated elementwise at ``x`` (computed in log-space)."""
    log_prob = torch.xlogy(concentration - 1., x) - rate * x
    log_z = torch.lgamma(concentration) - concentration * torch.log(rate)
    return torch.exp(log_prob - log_z)

# 获取 Gamma 分布位置特征
def get_positional_features_gamma(positions, features, seq_len, stddev = None, start_mean = None, eps = 1e-8, dtype = torch.float):
    """Gamma-pdf bumps over |position|, peak-normalized across the feature axis."""
    if stddev is None:
        stddev = seq_len / (2 * features)

    if start_mean is None:
        start_mean = seq_len / features

    # bump centers spread linearly across the sequence
    mean = torch.linspace(start_mean, seq_len, features, device = positions.device).unsqueeze(0)
    concentration = (mean / stddev) ** 2
    rate = mean / stddev ** 2

    x = positions.to(dtype).abs().unsqueeze(-1)

    # inlined gamma pdf (log-space for numerical stability)
    log_unnormalized = torch.xlogy(concentration - 1., x) - rate * x
    log_normalization = torch.lgamma(concentration) - concentration * torch.log(rate)
    probabilities = torch.exp(log_unnormalized - log_normalization) + eps

    # scale so the maximum over the feature axis is 1 at every position
    return probabilities / torch.amax(probabilities, dim = -1, keepdim = True)

# 获取位置嵌入
def get_positional_embed(seq_len, feature_size, device, use_tf_gamma, dtype = torch.float):
    """Build the relative-position embedding table of shape
    (2 * seq_len - 1, feature_size): symmetric basis features plus their
    sign-flipped (direction-aware) copies."""
    distances = torch.arange(-seq_len + 1, seq_len, device = device)

    assert not use_tf_gamma or seq_len == 1536, 'if using tf gamma, only sequence length of 1536 allowed for now'

    # one basis family per function; the TF-exported gammas replace the torch
    # gamma computation when reproducing the official checkpoint
    gamma_fn = get_positional_features_gamma if not use_tf_gamma else always(TF_GAMMAS.to(device))
    feature_functions = [
        get_positional_features_exponential,
        get_positional_features_central_mask,
        gamma_fn,
    ]

    # each family contributes a symmetric and an antisymmetric component
    num_components = len(feature_functions) * 2

    if (feature_size % num_components) != 0:
        raise ValueError(f'feature size is not divisible by number of components ({num_components})')

    num_basis_per_class = feature_size // num_components

    embeddings = [fn(distances, num_basis_per_class, seq_len, dtype = dtype) for fn in feature_functions]
    embeddings = torch.cat(embeddings, dim = -1)

    # append sign-weighted copies so the basis can encode direction
    signed = torch.sign(distances).unsqueeze(-1) * embeddings
    return torch.cat((embeddings, signed), dim = -1).to(dtype)
def relative_shift(x):
    """Transformer-XL shift trick: convert logits indexed by relative distance
    into the (query, key) attention layout via pad / reshape / trim."""
    pad = torch.zeros_like(x[..., :1])
    x = torch.cat((pad, x), dim = -1)

    _, heads, t1, t2 = x.shape
    # swapping the last two axes in the reshape slides each row by one
    x = x.reshape(-1, heads, t2, t1)[:, :, 1:, :]
    x = x.reshape(-1, heads, t1, t2 - 1)

    keep = (t2 + 1) // 2
    return x[..., :keep]

# classes

class Residual(nn.Module):
    """Wrap ``fn`` with a skip connection: x -> fn(x) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        return out + x

class GELU(nn.Module):
    """Sigmoid-approximated GELU activation: x * sigmoid(1.702 * x)."""

    def forward(self, x):
        return x * torch.sigmoid(1.702 * x)

class AttentionPool(nn.Module):
    """Attention-weighted pooling along the sequence axis.

    A 1x1 conv produces per-position logits (initialized to 2 * identity, so
    pooling starts as softmax(2x)); each window of ``pool_size`` positions is
    reduced by its softmax weights.
    """

    def __init__(self, dim, pool_size = 2):
        super().__init__()
        self.pool_size = pool_size
        # groups the sequence into windows: (b, d, n * p) -> (b, d, n, p)
        self.pool_fn = Rearrange('b d (n p) -> b d n p', p = pool_size)

        self.to_attn_logits = nn.Conv2d(dim, dim, 1, bias = False)

        # identity init, scaled by 2, reproduces softmax(2x) pooling at init
        nn.init.dirac_(self.to_attn_logits.weight)

        with torch.no_grad():
            self.to_attn_logits.weight.mul_(2)

    def forward(self, x):
        """Pool ``x`` of shape (batch, dim, n) to (batch, dim, ceil(n / pool_size))."""
        b, _, n = x.shape
        remainder = n % self.pool_size
        needs_padding = remainder > 0

        if needs_padding:
            # FIX: pad up to the next multiple of pool_size. The original padded
            # by `remainder`, which is only (coincidentally) correct when
            # pool_size == 2 and crashes for any other pool size.
            pad = self.pool_size - remainder
            x = F.pad(x, (0, pad), value = 0)
            mask = torch.zeros((b, 1, n), dtype = torch.bool, device = x.device)
            mask = F.pad(mask, (0, pad), value = True)

        x = self.pool_fn(x)
        logits = self.to_attn_logits(x)

        if needs_padding:
            # padded positions get the most negative logit so they receive ~0 weight
            mask_value = -torch.finfo(logits.dtype).max
            logits = logits.masked_fill(self.pool_fn(mask), mask_value)

        attn = logits.softmax(dim = -1)

        return (x * attn).sum(dim = -1)

class TargetLengthCrop(nn.Module):
    """Symmetrically crop the sequence axis (dim -2) down to ``target_length``.

    A target of -1 disables cropping; an input shorter than the target raises.
    """

    def __init__(self, target_length):
        super().__init__()
        self.target_length = target_length

    def forward(self, x):
        seq_len, target_len = x.shape[-2], self.target_length

        # -1 means "leave the sequence untouched"
        if target_len == -1:
            return x

        if seq_len < target_len:
            raise ValueError(f'sequence length {seq_len} is less than target length {target_len}')

        excess = seq_len - target_len
        if excess == 0:
            return x

        # for an odd excess this crops ceil(excess / 2) from each side, yielding
        # one position fewer than target_length (matches the original floor math)
        side = (excess + 1) // 2
        return x[:, side:-side]

def ConvBlock(dim, dim_out = None, kernel_size = 1, is_distributed = None):
    """Norm -> GELU -> Conv1d block (batchnorm is sync'd when distributed)."""
    norm_klass = MaybeSyncBatchnorm(is_distributed = is_distributed)
    out_channels = dim_out if dim_out is not None else dim

    return nn.Sequential(
        norm_klass(dim),
        GELU(),
        nn.Conv1d(dim, out_channels, kernel_size, padding = kernel_size // 2)
    )

# attention classes

class Attention(nn.Module):
    """Multi-head self-attention with Transformer-XL style relative position bias.

    Keys for the positional term come from projecting the relative-position
    basis produced by ``get_positional_embed``; learned content and position
    bias vectors follow the Transformer-XL formulation.
    """

    def __init__(
        self,
        dim,
        *,
        num_rel_pos_features,
        heads=8,
        dim_key=64,
        dim_value=64,
        dropout=0.,
        pos_dropout=0.,
        use_tf_gamma=False
    ):
        super().__init__()
        self.scale = dim_key ** -0.5
        self.heads = heads

        # linear projections for queries, keys and values
        self.to_q = nn.Linear(dim, dim_key * heads, bias=False)
        self.to_k = nn.Linear(dim, dim_key * heads, bias=False)
        self.to_v = nn.Linear(dim, dim_value * heads, bias=False)

        # output projection, zero-initialized so the residual branch starts as identity
        self.to_out = nn.Linear(dim_value * heads, dim)
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

        # relative positional encoding: projection of the positional basis plus
        # learned content / position bias terms (Transformer-XL)
        self.num_rel_pos_features = num_rel_pos_features
        self.to_rel_k = nn.Linear(num_rel_pos_features, dim_key * heads, bias=False)
        self.rel_content_bias = nn.Parameter(torch.randn(1, heads, 1, dim_key))
        self.rel_pos_bias = nn.Parameter(torch.randn(1, heads, 1, dim_key))

        # dropouts for positional features and attention weights
        self.pos_dropout = nn.Dropout(pos_dropout)
        self.attn_dropout = nn.Dropout(dropout)

        # whether to use the precomputed tensorflow gamma positional features
        self.use_tf_gamma = use_tf_gamma

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq, dim); returns the same shape."""
        n, h, device = x.shape[-2], self.heads, x.device

        # project input to queries, keys and values
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        # split heads: (b, n, h * d) -> (b, h, n, d)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # scale the queries
        q = q * self.scale

        # content attention logits (with learned content bias)
        content_logits = einsum('b h i d, b h j d -> b h i j', q + self.rel_content_bias, k)

        # build relative positional keys from the positional basis
        positions = get_positional_embed(n, self.num_rel_pos_features, device, use_tf_gamma = self.use_tf_gamma, dtype = self.to_rel_k.weight.dtype)
        positions = self.pos_dropout(positions)
        rel_k = self.to_rel_k(positions)

        # split heads on the positional keys
        rel_k = rearrange(rel_k, 'n (h d) -> h n d', h = h)
        # positional attention logits (with learned position bias)
        rel_logits = einsum('b h i d, h j d -> b h i j', q + self.rel_pos_bias, rel_k)
        # shift from relative-distance indexing to (query, key) indexing
        rel_logits = relative_shift(rel_logits)

        # combine content and positional logits, then softmax + dropout
        logits = content_logits + rel_logits
        attn = logits.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate values with attention weights and merge heads back
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# 主类 Enformer 继承自 PreTrainedModel
class Enformer(PreTrainedModel):
    """Enformer: conv stem + conv tower + relative-position transformer trunk,
    with per-organism output heads, packaged as a huggingface PreTrainedModel."""

    config_class = EnformerConfig
    base_model_prefix = "enformer"

    @staticmethod
    def from_hparams(**kwargs):
        """Construct an Enformer directly from keyword hyperparameters."""
        return Enformer(EnformerConfig(**kwargs))

    def __init__(self, config):
        super().__init__(config)
        self.dim = config.dim
        half_dim = config.dim // 2
        twice_dim = config.dim * 2

        # stem: wide conv -> residual pointwise conv block -> attention pooling
        self.stem = nn.Sequential(
            nn.Conv1d(4, half_dim, 15, padding=7),
            Residual(ConvBlock(half_dim)),
            AttentionPool(half_dim, pool_size=2)
        )

        # conv tower with exponentially growing channel counts
        filter_list = exponential_linspace_int(half_dim, config.dim, num=(config.num_downsamples - 1), divisible_by=config.dim_divisible_by)
        filter_list = [half_dim, *filter_list]

        conv_layers = []
        for dim_in, dim_out in zip(filter_list[:-1], filter_list[1:]):
            conv_layers.append(nn.Sequential(
                ConvBlock(dim_in, dim_out, kernel_size=5),
                Residual(ConvBlock(dim_out, dim_out, 1)),
                AttentionPool(dim_out, pool_size=2)
            ))

        self.conv_tower = nn.Sequential(*conv_layers)

        # whether to use the tensorflow-exported gamma positional features
        use_tf_gamma = config.use_tf_gamma
        self.use_tf_gamma = use_tf_gamma

        # transformer blocks: pre-norm residual attention + feedforward
        transformer = []
        for _ in range(config.depth):
            transformer.append(nn.Sequential(
                Residual(nn.Sequential(
                    nn.LayerNorm(config.dim),
                    Attention(
                        config.dim,
                        heads=config.heads,
                        dim_key=config.attn_dim_key,
                        dim_value=config.dim // config.heads,
                        dropout=config.attn_dropout,
                        pos_dropout=config.pos_dropout,
                        num_rel_pos_features=config.dim // config.heads,
                        use_tf_gamma=use_tf_gamma
                    ),
                    nn.Dropout(config.dropout_rate)
                )),
                Residual(nn.Sequential(
                    nn.LayerNorm(config.dim),
                    nn.Linear(config.dim, config.dim * 2),
                    nn.Dropout(config.dropout_rate),
                    nn.ReLU(),
                    nn.Linear(config.dim * 2, config.dim),
                    nn.Dropout(config.dropout_rate)
                ))
            ))

        self.transformer = nn.Sequential(*transformer)

        # crop the sequence axis to the configured target length
        self.target_length = config.target_length
        self.crop_final = TargetLengthCrop(config.target_length)

        # final pointwise block producing the (2 * dim)-wide embeddings
        self.final_pointwise = nn.Sequential(
            Rearrange('b n d -> b d n'),
            ConvBlock(filter_list[-1], twice_dim, 1),
            Rearrange('b d n -> b n d'),
            nn.Dropout(config.dropout_rate / 8),
            GELU()
        )

        # the full trunk, from one-hot input to embeddings
        self._trunk = nn.Sequential(
            Rearrange('b n d -> b d n'),
            self.stem,
            self.conv_tower,
            Rearrange('b d n -> b n d'),
            self.transformer,
            self.crop_final,
            self.final_pointwise
        )

        # per-organism output heads (e.g. human / mouse)
        self.add_heads(**config.output_heads)

        # gradient checkpointing over the transformer trunk
        self.use_checkpointing = config.use_checkpointing

    def add_heads(self, **kwargs):
        """(Re)create the output heads; kwargs map head name -> number of tracks."""
        self.output_heads = kwargs

        self._heads = nn.ModuleDict(map_values(lambda features: nn.Sequential(
            nn.Linear(self.dim * 2, features),
            nn.Softplus()
        ), kwargs))

    def set_target_length(self, target_length):
        """Adjust the final crop module to a new target length."""
        crop_module = self._trunk[-2]
        crop_module.target_length = target_length

    @property
    def trunk(self):
        return self._trunk

    @property
    def heads(self):
        return self._heads

    def trunk_checkpointed(self, x):
        """Run the trunk with gradient checkpointing over the transformer blocks."""
        x = rearrange(x, 'b n d -> b d n')
        x = self.stem(x)
        x = self.conv_tower(x)
        x = rearrange(x, 'b d n -> b n d')
        # checkpoint each transformer block, trading compute for memory
        x = checkpoint_sequential(self.transformer, len(self.transformer), x)
        x = self.crop_final(x)
        x = self.final_pointwise(x)
        return x

    def forward(
        self,
        x,
        target = None,
        return_corr_coef = False,
        return_embeddings = False,
        return_only_embeddings = False,
        head = None,
        target_length = None
    ):
        """Forward pass.

        ``x`` may be a list of sequence strings, a LongTensor of nucleotide
        indices, or a float one-hot tensor. Returns the head outputs by default,
        a Poisson loss (or Pearson R) when ``target`` and ``head`` are given,
        or the trunk embeddings when requested.
        """
        # normalize input to a one-hot float tensor
        if isinstance(x, list):
            x = str_to_one_hot(x)

        # exact type check on purpose: only plain long Tensors are index sequences
        elif type(x) == torch.Tensor and x.dtype == torch.long:
            x = seq_indices_to_one_hot(x)

        # FIX: Tensor.to is not in-place; the original `x.to(self.device)`
        # discarded its result, leaving x on the incoming device
        x = x.to(self.device)

        # add a batch dimension when given an unbatched (n, d) input
        no_batch = x.ndim == 2

        if no_batch:
            x = rearrange(x, '... -> () ...')

        if exists(target_length):
            self.set_target_length(target_length)

        trunk_fn = self.trunk_checkpointed if self.use_checkpointing else self._trunk
        x = trunk_fn(x)

        if no_batch:
            x = rearrange(x, '() ... -> ...')

        if return_only_embeddings:
            return x

        # run every head over the shared embeddings
        out = map_values(lambda fn: fn(x), self._heads)

        if exists(head):
            assert head in self._heads, f'head {head} not found'
            out = out[head]

        # compute loss (or correlation metric) against the target when supplied
        if exists(target):
            assert exists(head), 'head must be passed in if one were to calculate loss directly with targets'

            if return_corr_coef:
                return pearson_corr_coef(out, target)

            return poisson_loss(out, target)

        if return_embeddings:
            return out, x

        return out
# 从预训练模型加载模型
def from_pretrained(name, use_tf_gamma = None, **kwargs):
    """Load a pretrained Enformer by name.

    For the official checkpoint, default to the tensorflow-exported gamma
    positional features so the numerics match the original TF model.
    """
    enformer = Enformer.from_pretrained(name, **kwargs)

    if name == 'EleutherAI/enformer-official-rough':
        # official weights were trained against tensorflow's xlogy; use the
        # precomputed gammas unless the caller explicitly overrides
        if use_tf_gamma is None:
            use_tf_gamma = True

        # propagate the flag into every attention layer
        for module in enformer.modules():
            if isinstance(module, Attention):
                module.use_tf_gamma = use_tf_gamma

    return enformer

.\lucidrains\enformer-pytorch\enformer_pytorch\__init__.py

# 从enformer_pytorch包中导入EnformerConfig类
from enformer_pytorch.config_enformer import EnformerConfig
# 从enformer_pytorch包中导入Enformer、from_pretrained、SEQUENCE_LENGTH、AttentionPool类
from enformer_pytorch.modeling_enformer import Enformer, from_pretrained, SEQUENCE_LENGTH, AttentionPool
# 从enformer_pytorch包中导入seq_indices_to_one_hot、str_to_one_hot、GenomeIntervalDataset、FastaInterval类
from enformer_pytorch.data import seq_indices_to_one_hot, str_to_one_hot, GenomeIntervalDataset, FastaInterval

Enformer - Pytorch

Implementation of Enformer, Deepmind's attention network for predicting gene expression, in Pytorch. This repository also contains the means to fine tune pretrained models for your downstream tasks. The original tensorflow sonnet code can be found here.

Update: finetuned for predicting pseudobulk chromatin accessibility here

Install

$ pip install enformer-pytorch

Usage

import torch
from enformer_pytorch import Enformer

model = Enformer.from_hparams(
    dim = 1536,
    depth = 11,
    heads = 8,
    output_heads = dict(human = 5313, mouse = 1643),
    target_length = 896,
)
    
seq = torch.randint(0, 5, (1, 196_608)) # for ACGTN, in that order (-1 for padding)
output = model(seq)

output['human'] # (1, 896, 5313)
output['mouse'] # (1, 896, 1643)

You can also directly pass in the sequence as one-hot encodings, which must be float values

import torch
from enformer_pytorch import Enformer, seq_indices_to_one_hot

model = Enformer.from_hparams(
    dim = 1536,
    depth = 11,
    heads = 8,
    output_heads = dict(human = 5313, mouse = 1643),
    target_length = 896,
)

seq = torch.randint(0, 5, (1, 196_608))
one_hot = seq_indices_to_one_hot(seq)

output = model(one_hot)

output['human'] # (1, 896, 5313)
output['mouse'] # (1, 896, 1643)

Finally, one can fetch the embeddings, for fine-tuning and otherwise, by setting the return_embeddings flag to be True on forward

import torch
from enformer_pytorch import Enformer, seq_indices_to_one_hot

model = Enformer.from_hparams(
    dim = 1536,
    depth = 11,
    heads = 8,
    output_heads = dict(human = 5313, mouse = 1643),
    target_length = 896,
)

seq = torch.randint(0, 5, (1, 196_608))
one_hot = seq_indices_to_one_hot(seq)

output, embeddings = model(one_hot, return_embeddings = True)

embeddings # (1, 896, 3072)

For training, you can directly pass the head and target in to get the poisson loss

import torch
from enformer_pytorch import Enformer, seq_indices_to_one_hot

model = Enformer.from_hparams(
    dim = 1536,
    depth = 11,
    heads = 8,
    output_heads = dict(human = 5313, mouse = 1643),
    target_length = 200,
).cuda()

seq = torch.randint(0, 5, (196_608 // 2,)).cuda()
target = torch.randn(200, 5313).cuda()

loss = model(
    seq,
    head = 'human',
    target = target
)

loss.backward()

# after much training

corr_coef = model(
    seq,
    head = 'human',
    target = target,
    return_corr_coef = True
)

corr_coef # pearson R, used as a metric in the paper

Pretrained Model

Deepmind has released the weights for their tensorflow sonnet Enformer model! I have ported it over to Pytorch and uploaded it to 🤗 Huggingface (~1GB). There are still some rounding errors that seem to be accruing across the layers, resulting in an absolute error as high as 0.5. However, the correlation coefficients look good, so I am releasing the 'rough'ly working version. I will keep working on figuring out where the numerical errors are happening (it may be the attention pooling module, as I noticed the attention logits are pretty high).

Update: John St. John did some work and found that the enformer-official-rough model hits the reported marks in the paper - human pearson R of 0.625 for validation, and 0.65 for test.

Update: As of version 0.8.0, if one were to use the from_pretrained function to load the pretrained model, it should automatically use precomputed gamma positions to address a difference between tensorflow and pytorch xlogy. This should resolve the numerical discrepancy above. If you were to further finetune and not be using the from_pretrained function, please make sure to set use_tf_gamma = True when using .from_hparams to instantiate the Enformer

$ pip install enformer-pytorch>=0.5

Loading the model

from enformer_pytorch import from_pretrained

enformer = from_pretrained('EleutherAI/enformer-official-rough')

Quick sanity check on a single human validation point

$ python test_pretrained.py
# 0.5963 correlation coefficient on a validation sample

This is all made possible thanks to HuggingFace's custom model feature.

You can also load, with overriding of the target_length parameter, if you are working with shorter sequence lengths

from enformer_pytorch import from_pretrained

model = from_pretrained('EleutherAI/enformer-official-rough', target_length = 128, dropout_rate = 0.1)

# do your fine-tuning

To save on memory during fine-tuning a large Enformer model

from enformer_pytorch import from_pretrained

enformer = from_pretrained('EleutherAI/enformer-official-rough', use_checkpointing = True)

# finetune enformer on a limited budget

Fine-tuning

This repository will also allow for easy fine-tuning of Enformer.

Fine-tuning on new tracks

import torch
from enformer_pytorch import from_pretrained
from enformer_pytorch.finetune import HeadAdapterWrapper

enformer = from_pretrained('EleutherAI/enformer-official-rough')

model = HeadAdapterWrapper(
    enformer = enformer,
    num_tracks = 128,
    post_transformer_embed = False   # by default, embeddings are taken from after the final pointwise block w/ conv -> gelu - but if you'd like the embeddings right after the transformer block with a learned layernorm, set this to True
).cuda()

seq = torch.randint(0, 5, (1, 196_608 // 2,)).cuda()
target = torch.randn(1, 200, 128).cuda()  # 128 tracks

loss = model(seq, target = target)
loss.backward()

Finetuning on contextual data (cell type, transcription factor, etc)

import torch
from enformer_pytorch import from_pretrained
from enformer_pytorch.finetune import ContextAdapterWrapper

enformer = from_pretrained('EleutherAI/enformer-official-rough')
    
model = ContextAdapterWrapper(
    enformer = enformer,
    context_dim = 1024
).cuda()

seq = torch.randint(0, 5, (1, 196_608 // 2,)).cuda()

target = torch.randn(1, 200, 4).cuda()  # 4 tracks
context = torch.randn(4, 1024).cuda()   # 4 contexts for the different 'tracks'

loss = model(
    seq,
    context = context,
    target = target
)

loss.backward()

Finally, there is also a way to use attention aggregation from a set of context embeddings (or a single context embedding). Simply use the ContextAttentionAdapterWrapper

import torch
from enformer_pytorch import from_pretrained
from enformer_pytorch.finetune import ContextAttentionAdapterWrapper

enformer = from_pretrained('EleutherAI/enformer-official-rough')
    
model = ContextAttentionAdapterWrapper(
    enformer = enformer,
    context_dim = 1024,
    heads = 8,              # number of heads in the cross attention
    dim_head = 64           # dimension per head
).cuda()

seq = torch.randint(0, 5, (1, 196_608 // 2,)).cuda()

target = torch.randn(1, 200, 4).cuda()      # 4 tracks
context = torch.randn(4, 16, 1024).cuda()   # 4 contexts for the different 'tracks', each with 16 tokens

context_mask = torch.ones(4, 16).bool().cuda() # optional context mask, in example, include all context tokens

loss = model(
    seq,
    context = context,
    context_mask = context_mask,
    target = target
)

loss.backward()

Data

You can use the GenomeIntervalDataset to easily fetch sequences of any length from a .bed file, with greater context length dynamically computed if specified

import torch
import polars as pl
from enformer_pytorch import Enformer, GenomeIntervalDataset

filter_train = lambda df: df.filter(pl.col('column_4') == 'train')

ds = GenomeIntervalDataset(
    bed_file = './sequences.bed',                       # bed file - columns 0, 1, 2 must be <chromosome>, <start position>, <end position>
    fasta_file = './hg38.ml.fa',                        # path to fasta file
    filter_df_fn = filter_train,                        # filter dataframe function
    return_seq_indices = True,                          # return nucleotide indices (ACGTN) or one hot encodings
    shift_augs = (-2, 2),                               # random shift augmentations from -2 to +2 basepairs
    context_length = 196_608,
    # this can be longer than the interval designated in the .bed file,
    # in which case it will take care of lengthening the interval on either sides
    # as well as proper padding if at the end of the chromosomes
    chr_bed_to_fasta_map = {
        'chr1': 'chromosome1',  # if the chromosome name in the .bed file is different than the key name in the fasta file, you can rename them on the fly
        'chr2': 'chromosome2',
        'chr3': 'chromosome3',
        # etc etc
    }
)

model = Enformer.from_hparams(
    dim = 1536,
    depth = 11,
    heads = 8,
    output_heads = dict(human = 5313, mouse = 1643),
    target_length = 896,
)

seq = ds[0] # (196608,)
pred = model(seq, head = 'human') # (896, 5313)

To return the random shift value, as well as whether reverse complement was activated (in case you need to reverse the corresponding chip-seq target data), just set return_augs = True when initializing the GenomeIntervalDataset

import torch
import polars as pl
from enformer_pytorch import Enformer, GenomeIntervalDataset

filter_train = lambda df: df.filter(pl.col('column_4') == 'train')

ds = GenomeIntervalDataset(
    bed_file = './sequences.bed',                       # bed file - columns 0, 1, 2 must be <chromosome>, <start position>, <end position>
    fasta_file = './hg38.ml.fa',                        # path to fasta file
    filter_df_fn = filter_train,                        # filter dataframe function
    return_seq_indices = True,                          # return nucleotide indices (ACGTN) or one hot encodings
    shift_augs = (-2, 2),                               # random shift augmentations from -2 to +2 basepairs
    rc_aug = True,                                      # use reverse complement augmentation with 50% probability
    context_length = 196_608,
    return_augs = True                                  # return the augmentation meta data
)

seq, rand_shift_val, rc_bool = ds[0] # (196608,), (1,), (1,)

Appreciation

Special thanks goes out to EleutherAI for providing the resources to retrain the model, during a time when the official model from Deepmind had not been released yet.

Thanks also goes out to @johahi for finding out that there are numerical differences between the torch and tensorflow implementations of xlogy. He provided a fix for this difference, which is adopted in this repository in v0.8.0

Todo

  • script to load weights from trained tensorflow enformer model to pytorch model
  • add loss wrapper with poisson loss
  • move the metrics code over to pytorch as well
  • train enformer model
  • build context manager for fine-tuning with unfrozen enformer but with frozen batchnorm
  • allow for plain fine-tune with fixed static context
  • allow for fine tuning with only unfrozen layernorms (technique from fine tuning transformers)
  • fix handling of 'N' in sequence, figure out representation of N in basenji barnyard
  • take care of shift augmentation in GenomeIntervalDataset
  • speed up str_to_seq_indices
  • add to EleutherAI huggingface (done thanks to Niels)
  • offer some basic training utils, as gradient accumulation will be needed for fine tuning

Citations

@article {Avsec2021.04.07.438649,
    author  = {Avsec, {\v Z}iga and Agarwal, Vikram and Visentin, Daniel and Ledsam, Joseph R. and Grabska-Barwinska, Agnieszka and Taylor, Kyle R. and Assael, Yannis and Jumper, John and Kohli, Pushmeet and Kelley, David R.},
    title   = {Effective gene expression prediction from sequence by integrating long-range interactions},
    elocation-id = {2021.04.07.438649},
    year    = {2021},
    doi     = {10.1101/2021.04.07.438649},
    publisher = {Cold Spring Harbor Laboratory},
    URL     = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649},
    eprint  = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649.full.pdf},
    journal = {bioRxiv}
}
@misc{liu2022convnet,
    title   = {A ConvNet for the 2020s},
    author  = {Zhuang Liu and Hanzi Mao and Chao-Yuan Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
    year    = {2022},
    eprint  = {2201.03545},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

.\lucidrains\enformer-pytorch\scripts\tf_to_torch.py

# 导入 einops 模块中的 rearrange 函数
from einops import rearrange

# 复制 BatchNorm 层的参数到 PyTorch 模型中
def copy_bn(mod, vars, path):
    """Copy a TF cross-replica batchnorm's params at ``path`` into torch BatchNorm ``mod``."""
    offset = vars[f'{path}offset:0']
    scale = vars[f'{path}scale:0']

    # the exponential-moving-average statistics live beside the scale/offset
    ema_path = '/'.join(path.split('/')[:-1]) + '/'
    running_mean = vars[f'{ema_path}moving_mean/average:0']
    running_var = vars[f'{ema_path}moving_variance/average:0']

    # TF scale -> torch weight, TF offset -> torch bias
    mod.weight.data.copy_(scale)
    mod.bias.data.copy_(offset)

    # TF stores the averages as (1, 1, d); flatten to (d,)
    mod.running_var.data.copy_(running_var.reshape(-1))
    mod.running_mean.data.copy_(running_mean.reshape(-1))

# 复制卷积层的参数到 PyTorch 模型中
def copy_conv(mod, vars, path):
    """Copy a TF conv1d's kernel and bias at ``path`` into torch Conv1d ``mod``."""
    bias = vars[f'{path}b:0']
    kernel = vars[f'{path}w:0']
    # TF kernels are (k, in, out); torch Conv1d wants (out, in, k)
    mod.weight.data.copy_(kernel.permute(2, 1, 0))
    mod.bias.data.copy_(bias)

# 复制注意力池化层的参数到 PyTorch 模型中
def copy_attn_pool(mod, vars, path):
    """Copy the attention-pool projection at ``path`` into ``mod.to_attn_logits``."""
    proj = vars[path]
    # TF (in, out) linear weight -> torch (out, in, 1, 1) 1x1 conv2d weight
    mod.to_attn_logits.weight.data.copy_(proj.t().unsqueeze(-1).unsqueeze(-1))

# 复制全连接层的参数到 PyTorch 模型中
def copy_linear(mod, vars, path, has_bias = True):
    """Copy a TF dense layer at ``path`` into torch Linear ``mod``.

    Set ``has_bias = False`` for bias-free projections.
    """
    weight = vars[f'{path}w:0']
    # TF stores dense weights as (in, out); torch Linear wants (out, in)
    mod.weight.data.copy_(weight.t())

    if not has_bias:
        return

    mod.bias.data.copy_(vars[f'{path}b:0'])

# Copy a LayerNorm's parameters into a PyTorch LayerNorm module.
def copy_ln(mod, vars, path):
    # TF scale -> PyTorch weight, TF offset -> PyTorch bias.
    # Shapes already match, so no rearrangement is needed.
    mod.weight.data.copy_(vars[f'{path}scale:0'])
    mod.bias.data.copy_(vars[f'{path}offset:0'])

# Snapshot a TF/sonnet model's variables as a name -> torch.Tensor mapping.
def get_tf_vars(tf_model):
    """Return a dict mapping each TF variable name to a torch tensor.

    Each variable is materialized exactly once via ``v.numpy()`` (the original
    called it twice per variable, paying a redundant device-to-host copy) and
    wrapped with ``torch.from_numpy``.  Variables whose ``.numpy()`` is not an
    ndarray (e.g. python scalars) map to ``None``.
    """
    tf_vars = {}
    for v in tf_model.variables:
        value = v.numpy()  # call once: each .numpy() produces a fresh host copy
        tf_vars[v.name] = torch.from_numpy(value) if isinstance(value, np.ndarray) else None
    return tf_vars

# Copy every trained parameter of the TF Enformer into the PyTorch model.
def copy_tf_to_pytorch(tf_model, pytorch_model):
    """Port the sonnet/TF Enformer checkpoint into the PyTorch port, in place.

    Walks the fixed layout shared by both models (stem -> conv tower ->
    transformer -> final pointwise -> human/mouse heads) and copies each
    tensor via the copy_* helpers above.  Assumes ``pytorch_model`` follows
    the enformer-pytorch module layout exactly; any mismatch in submodule
    indexing or checkpoint variable naming will raise (KeyError or a shape
    mismatch in ``copy_``).  Mutates ``pytorch_model``; returns ``None``.
    """
    # Snapshot all TF variables as a name -> tensor dict
    tf_vars = get_tf_vars(tf_model)
    # Stem submodules: conv, pointwise conv block (bn at .fn[0], conv at .fn[2]), attention pool
    stem_conv = pytorch_model.stem[0]
    stem_point_bn = pytorch_model.stem[1].fn[0]
    stem_point_conv = pytorch_model.stem[1].fn[2]
    stem_attn_pool = pytorch_model.stem[2]

    # Copy the stem parameters
    copy_conv(stem_conv, tf_vars, 'enformer/trunk/stem/conv1_d/')
    copy_bn(stem_point_bn, tf_vars, 'enformer/trunk/stem/pointwise_conv_block/cross_replica_batch_norm/')
    copy_conv(stem_point_conv, tf_vars, 'enformer/trunk/stem/pointwise_conv_block/conv1_d/')
    copy_attn_pool(stem_attn_pool, tf_vars, 'enformer/trunk/stem/softmax_pooling/linear/w:0')

    # Walk each block of the conv tower
    for ind, tower_block in enumerate(pytorch_model.conv_tower):
        tower_bn = tower_block[0][0]
        tower_conv = tower_block[0][2]
        tower_point_bn = tower_block[1].fn[0]
        tower_point_conv = tower_block[1].fn[2]
        tower_attn_pool = tower_block[2]

        # Checkpoint variable paths for this tower block
        conv_path = f'enformer/trunk/conv_tower/conv_tower_block_{ind}/conv_block/conv1_d/'
        bn_path = f'enformer/trunk/conv_tower/conv_tower_block_{ind}/conv_block/cross_replica_batch_norm/'
        point_conv_path = f'enformer/trunk/conv_tower/conv_tower_block_{ind}/pointwise_conv_block/conv1_d/'
        point_bn_path = f'enformer/trunk/conv_tower/conv_tower_block_{ind}/pointwise_conv_block/cross_replica_batch_norm/'
        attn_pool_path = f'enformer/trunk/conv_tower/conv_tower_block_{ind}/softmax_pooling/linear/w:0'

        # Copy this tower block's parameters
        copy_bn(tower_bn, tf_vars, bn_path)
        copy_conv(tower_conv, tf_vars, conv_path)
        copy_bn(tower_point_bn, tf_vars, point_bn_path)
        copy_conv(tower_point_conv, tf_vars, point_conv_path)
        copy_attn_pool(tower_attn_pool, tf_vars, attn_pool_path)
    # Walk each transformer block of the PyTorch model
    for ind, transformer_block in enumerate(pytorch_model.transformer):
        # Checkpoint variable paths for the attention sublayer
        attn_ln_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/layer_norm/'
        attn_q_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/q_layer/'
        attn_k_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/k_layer/'
        attn_r_k_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/r_k_layer/'
        attn_v_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/v_layer/'
        attn_out_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/embedding_layer/'

        attn_content_bias_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/r_w_bias:0'
        attn_rel_bias_path = f'enformer/trunk/transformer/transformer_block_{ind}/mha/attention_{ind}/r_r_bias:0'

        ff_ln_path = f'enformer/trunk/transformer/transformer_block_{ind}/mlp/layer_norm/'

        # NOTE(review): original comment said these paths may need editing so
        # the variables resolve — verify them against the actual checkpoint
        ff_linear1_path = f'enformer/trunk/transformer/transformer_block_{ind}/mlp/project_in/'
        ff_linear2_path = f'enformer/trunk/transformer/transformer_block_{ind}/mlp/project_out/'

        # Attention sublayer: pre-norm (fn[0]) followed by multi-head attention (fn[1])
        attn = transformer_block[0]
        attn_ln = attn.fn[0]
        mha = attn.fn[1]

        # Copy the q/k/v/relative-key/output projections (q/k/v/rel_k are bias-free)
        copy_linear(mha.to_q, tf_vars, attn_q_path, has_bias = False)
        copy_linear(mha.to_k, tf_vars, attn_k_path, has_bias = False)
        copy_linear(mha.to_rel_k, tf_vars, attn_r_k_path, has_bias = False)
        copy_linear(mha.to_v, tf_vars, attn_v_path, has_bias = False)
        copy_linear(mha.to_out, tf_vars, attn_out_path)

        # Copy the relative-position attention biases
        mha.rel_content_bias.data.copy_(tf_vars[attn_content_bias_path])
        mha.rel_pos_bias.data.copy_(tf_vars[attn_rel_bias_path])

        # Feedforward sublayer: pre-norm (fn[0]) and the two linear projections (fn[1], fn[4])
        ff = transformer_block[-1]
        ff_ln = ff.fn[0]
        ff_linear1 = ff.fn[1]
        ff_linear2 = ff.fn[4]

        # Copy the layer norms and feedforward projections
        copy_ln(attn_ln, tf_vars, attn_ln_path)

        copy_ln(ff_ln, tf_vars, ff_ln_path)
        copy_linear(ff_linear1, tf_vars, ff_linear1_path)
        copy_linear(ff_linear2, tf_vars, ff_linear2_path)

    # Final pointwise conv block: bn at [1][0], conv at [1][2]
    final_bn = pytorch_model.final_pointwise[1][0]
    final_conv = pytorch_model.final_pointwise[1][2]

    # Copy the final batchnorm and conv parameters
    copy_bn(final_bn, tf_vars, 'enformer/trunk/final_pointwise/conv_block/cross_replica_batch_norm/')
    copy_conv(final_conv, tf_vars, 'enformer/trunk/final_pointwise/conv_block/conv1_d/')

    # Per-organism output heads
    human_linear = pytorch_model._heads['human'][0]
    mouse_linear = pytorch_model._heads['mouse'][0]

    # Copy the head linear layers
    copy_linear(human_linear, tf_vars, 'enformer/heads/head_human/linear/')
    copy_linear(mouse_linear, tf_vars, 'enformer/heads/head_mouse/linear/')

    # Signal that every parameter was copied without raising
    print('success')