Lucidrains Series Projects: Source Code Walkthrough (98)


Lucidrains Series Projects: Source Code Walkthrough (Part 98)

.\lucidrains\tf-bind-transformer\tf_bind_transformer\gene_utils.py

# code for fetching transcription factor sequences

# gene identifier remapping, e.g. 'RXR' maps to 'RXRA'
GENE_IDENTIFIER_MAP = {
    'RXR': 'RXRA'
}

# gene names that legitimately contain hyphens
NAMES_WITH_HYPHENS = {
    'NKX3-1',
    'NKX2-1',
    'NKX2-5',
    'SS18-SSX'
}

# parse a gene name into a tuple of subunit names
def parse_gene_name(name):
    # if the name has no hyphen, or is a known hyphenated name, return it as a single-element tuple (after remapping)
    if '-' not in name or name in NAMES_WITH_HYPHENS:
        name = GENE_IDENTIFIER_MAP.get(name, name)

        # if the name contains an underscore, only keep the target factor name to its left
        if '_' in name:
            name, *_ = name.split('_')

        return (name,)

    # otherwise the hyphen separates subunits; expand single-letter suffixes against the first name
    first, *rest = name.split('-')

    parsed_rest = []

    for name in rest:
        if len(name) == 1:
            name = f'{first[:-1]}{name}'
        parsed_rest.append(name)

    return tuple([first, *parsed_rest])
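
A quick sanity check of the parsing rules above, with illustrative inputs (the gene names here are just examples):

# illustrative usage of parse_gene_name
print(parse_gene_name('RXR'))         # ('RXRA',)            - remapped via GENE_IDENTIFIER_MAP
print(parse_gene_name('NKX2-1'))      # ('NKX2-1',)          - hyphen kept for whitelisted names
print(parse_gene_name('PAX3-FOXO1'))  # ('PAX3', 'FOXO1')    - hyphen splits the complex into subunits
print(parse_gene_name('STAT5A-B'))    # ('STAT5A', 'STAT5B') - single-letter suffix expanded from the first name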

.\lucidrains\tf-bind-transformer\tf_bind_transformer\optimizer.py

# import the AdamW optimizer from torch.optim
from torch.optim import AdamW

# split parameters into those that should and should not receive weight decay
def separate_weight_decayable_params(params):
    # parameters with fewer than 2 dimensions (biases, norm weights) get no weight decay
    no_wd_params = set([param for param in params if param.ndim < 2])
    # the remaining parameters do receive weight decay
    wd_params = set(params) - no_wd_params
    return wd_params, no_wd_params

# build an AdamW optimizer from parameters and hyperparameters
def get_optimizer(params, lr = 3e-4, wd = 1e-1, filter_by_requires_grad = False):
    # optionally keep only parameters with requires_grad = True
    if filter_by_requires_grad:
        params = list(filter(lambda t: t.requires_grad, params))

    # convert the parameters to a set
    params = set(params)
    # split into weight-decayable and non-weight-decayable parameters
    wd_params, no_wd_params = separate_weight_decayable_params(params)

    # build parameter groups: the first uses the default weight decay, the second disables it
    param_groups = [
        {'params': list(wd_params)},
        {'params': list(no_wd_params), 'weight_decay': 0},
    ]

    # return an AdamW optimizer over the parameter groups with the given lr and wd
    return AdamW(param_groups, lr = lr, weight_decay = wd)
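
A minimal usage sketch (assuming the package is installed as tf_bind_transformer); parameters with fewer than two dimensions, such as biases and LayerNorm weights, end up in the no-weight-decay group:

import torch.nn as nn
from tf_bind_transformer.optimizer import get_optimizer

toy_model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 1))  # illustrative model
optim = get_optimizer(toy_model.parameters(), lr = 3e-4, wd = 1e-1)               # AdamW with two parameter groups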

.\lucidrains\tf-bind-transformer\tf_bind_transformer\protein_utils.py

# required imports
import torch
import os
import re
from pathlib import Path
from functools import partial
import esm
from torch.nn.utils.rnn import pad_sequence
from transformers import AlbertTokenizer, AutoModelForMaskedLM, logging
from tf_bind_transformer.cache_utils import cache_fn, run_once, md5_hash_fn

# helper to check whether a value exists
def exists(val):
    return val is not None

# apply a function to every value in a dictionary
def map_values(fn, dictionary):
    return {k: fn(v) for k, v in dictionary.items()}

# move a tensor to the given device
def to_device(t, *, device):
    return t.to(device)

# cast a value to a tuple if it is not one already
def cast_tuple(t):
    return (t,) if not isinstance(t, tuple) else t

# check whether the PROTEIN_EMBED_USE_CPU environment variable is set
PROTEIN_EMBED_USE_CPU = os.getenv('PROTEIN_EMBED_USE_CPU', None) is not None

# if so, note that protein embeddings will be computed on cpu
if PROTEIN_EMBED_USE_CPU:
    print('calculating protein embed only on cpu')

# global state for the lazily initialized model and tokenizer
GLOBAL_VARIABLES = {
    'model': None,
    'tokenizer': None
}

# compute protein representations, concatenating subunit embeddings and padding across the batch
def calc_protein_representations_with_subunits(proteins, get_repr_fn, *, device):
    representations = []

    for subunits in proteins:
        subunits = cast_tuple(subunits)
        subunits_representations = list(map(get_repr_fn, subunits))
        subunits_representations = list(map(partial(to_device, device=device), subunits_representations))
        subunits_representations = torch.cat(subunits_representations, dim=0)
        representations.append(subunits_representations)

    lengths = [seq_repr.shape[0] for seq_repr in representations]
    masks = torch.arange(max(lengths), device=device)[None, :] < torch.tensor(lengths, device=device)[:, None]
    padded_representations = pad_sequence(representations, batch_first=True)

    return padded_representations.to(device), masks.to(device)
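
To illustrate the padding and masking step above, here is a small standalone sketch with dummy per-protein embeddings (sizes are illustrative):

import torch
from torch.nn.utils.rnn import pad_sequence

reprs = [torch.randn(5, 8), torch.randn(3, 8)]   # two proteins, lengths 5 and 3
lengths = [r.shape[0] for r in reprs]
masks = torch.arange(max(lengths))[None, :] < torch.tensor(lengths)[:, None]  # (2, 5) boolean mask
padded = pad_sequence(reprs, batch_first = True)                              # (2, 5, 8), shorter one zero-padded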

# ESM related constants and functions
ESM_MAX_LENGTH = 1024
ESM_EMBED_DIM = 1280

# map from integer index to amino acid character
INT_TO_AA_STR_MAP = {
    0: 'A',
    1: 'C',
    2: 'D',
    3: 'E',
    4: 'F',
    5: 'G',
    6: 'H',
    7: 'I',
    8: 'K',
    9: 'L',
    10: 'M',
    11: 'N',
    12: 'P',
    13: 'Q',
    14: 'R',
    15: 'S',
    16: 'T',
    17: 'V',
    18: 'W',
    19: 'Y',
    20: '_'
}

# convert an integer tensor into amino acid strings
def tensor_to_aa_str(t):
    str_seqs = []
    for int_seq in t.unbind(dim=0):
        str_seq = list(map(lambda t: INT_TO_AA_STR_MAP[t] if t != 20 else '', int_seq.tolist()))
        str_seqs.append(''.join(str_seq))
    return str_seqs

# initialize the ESM model (runs only once)
@run_once('init_esm')
def init_esm():
    model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
    batch_converter = alphabet.get_batch_converter()

    if not PROTEIN_EMBED_USE_CPU:
        model = model.cuda()

    GLOBAL_VARIABLES['model'] = (model, batch_converter)

# get the ESM representation for a single protein
def get_single_esm_repr(protein_str):
    init_esm()
    model, batch_converter = GLOBAL_VARIABLES['model']

    data = [('protein', protein_str)]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)

    if batch_tokens.shape[1] > ESM_MAX_LENGTH:
        print(f'warning max length protein esm: {protein_str}')

    batch_tokens = batch_tokens[:, :ESM_MAX_LENGTH]

    if not PROTEIN_EMBED_USE_CPU:
        batch_tokens = batch_tokens.cuda()

    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33])

    token_representations = results['representations'][33]
    representation = token_representations[0][1: len(protein_str) + 1]
    return representation

# get (cached) ESM representations for a batch of proteins
def get_esm_repr(proteins, device):
    if isinstance(proteins, torch.Tensor):
        proteins = tensor_to_aa_str(proteins)

    get_protein_repr_fn = cache_fn(get_single_esm_repr, path='esm/proteins')

    return calc_protein_representations_with_subunits(proteins, get_protein_repr_fn, device=device)

# PROT-ALBERT, 2048 context length
PROT_ALBERT_PATH = 'Rostlab/prot_albert'
PROT_ALBERT_DIM = 4096
PROT_ALBERT_MAX_LENGTH = 2048

# replace rare amino acid codes (U, Z, O, B) with X and put spaces between residues
def protein_str_with_spaces(protein_str):
    protein_str = re.sub(r"[UZOB]", 'X', protein_str)
    return ' '.join([*protein_str])
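
For example, rare residue codes are mapped to X and the residues become space-separated, which is the input format the PROT-ALBERT tokenizer expects:

print(protein_str_with_spaces('MKTUV'))   # 'M K T X V'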

# initialize the PROT-ALBERT model (runs only once)
@run_once('init_prot_albert')
def init_prot_albert():
    GLOBAL_VARIABLES['tokenizer'] = AlbertTokenizer.from_pretrained(PROT_ALBERT_PATH, do_lower_case=False)
    # load the pretrained ALBERT masked language modeling model
    model = AutoModelForMaskedLM.from_pretrained(PROT_ALBERT_PATH)
    
    # move the model to GPU unless protein embeddings run on cpu
    if not PROTEIN_EMBED_USE_CPU:
        model = model.cuda()
    
    # store the loaded model in the global state
    GLOBAL_VARIABLES['model'] = model

# get the ALBERT representation for a single protein
def get_single_prot_albert_repr(
    protein_str,
    max_length = PROT_ALBERT_MAX_LENGTH,
    hidden_state_index = -1
):
    # initialize the ALBERT model
    init_prot_albert()
    # fetch the model and tokenizer from the global state
    model = GLOBAL_VARIABLES['model']
    tokenizer = GLOBAL_VARIABLES['tokenizer']

    # tokenize the protein string
    encoding = tokenizer.batch_encode_plus(
        [protein_str_with_spaces(protein_str)],
        add_special_tokens = True,
        padding = True,
        truncation = True,
        max_length = max_length,
        return_attention_mask = True,
        return_tensors = 'pt'
    )

    # move the encoding to GPU unless running on cpu
    if not PROTEIN_EMBED_USE_CPU:
        encoding = map_values(lambda t: t.cuda(), encoding)

    # put the model in evaluation mode
    model.eval()
    # disable gradient computation
    with torch.no_grad():
        # run the model
        outputs = model(**encoding, output_hidden_states = True)

    # take the requested hidden state
    hidden_state = outputs.hidden_states[hidden_state_index][0]
    return hidden_state

# get (cached) ALBERT representations for a batch of proteins
def get_prot_albert_repr(
    proteins,
    device,
    max_length = PROT_ALBERT_MAX_LENGTH,
    hidden_state_index = -1
):
    # if the input is a string, wrap it in a list
    if isinstance(proteins, str):
        proteins = [proteins]

    # if the input is a tensor, convert it to amino acid strings
    if isinstance(proteins, torch.Tensor):
        proteins = tensor_to_aa_str(proteins)

    # cache the single-protein ALBERT representation function
    get_protein_repr_fn = cache_fn(get_single_prot_albert_repr, path = f'proteins/prot_albert')

    # compute the protein representations
    return calc_protein_representations_with_subunits(proteins, get_protein_repr_fn, device = device)

# alphafold2 functions

# maximum length and embedding dimension
AF2_MAX_LENGTH = 2500
AF2_EMBEDDING_DIM = 384

# directory holding precomputed alphafold2 embeddings
AF2_DIRECTORY = os.getenv('TF_BIND_AF2_DIRECTORY', os.path.expanduser('~/.cache.tf.bind.transformer/.af2_embeddings'))
AF2_DIRECTORY_PATH = Path(AF2_DIRECTORY)

# get the alphafold2 representation for a single protein
def get_single_alphafold2_repr(
    protein_str,
    max_length = AF2_MAX_LENGTH,
):
    # the embedding file is keyed by the MD5 hash of the protein string
    md5 = md5_hash_fn(protein_str)
    embedding_path = AF2_DIRECTORY_PATH / f'{md5}.pt'
    assert embedding_path.exists(), f'af2 embedding not found for {protein_str}'

    # load the embedding tensor
    tensor = torch.load(str(embedding_path))
    return tensor[:max_length]

# get alphafold2 representations for a batch of proteins
def get_alphafold2_repr(
    proteins,
    device,
    max_length = AF2_MAX_LENGTH,
    **kwargs
):
    representations = []

    for subunits in proteins:
        subunits = cast_tuple(subunits)
        subunits = list(map(lambda t: get_single_alphafold2_repr(t, max_length = max_length), subunits))
        subunits = torch.cat(subunits, dim = 0)
        representations.append(subunits)

    lengths = [seq_repr.shape[0] for seq_repr in representations]
    masks = torch.arange(max(lengths), device = device)[None, :] <  torch.tensor(lengths, device = device)[:, None]
    padded_representations = pad_sequence(representations, batch_first = True)

    return padded_representations.to(device), masks.to(device)

# factory function

# configuration for each protein representation backend
PROTEIN_REPR_CONFIG = {
    'esm': {
        'dim': ESM_EMBED_DIM,
        'fn': get_esm_repr
    },
    'protalbert': {
        'dim': PROT_ALBERT_DIM,
        'fn': get_prot_albert_repr
    },
    'alphafold2': {
        'dim': AF2_EMBEDDING_DIM,
        'fn': get_alphafold2_repr
    }
}

# look up a protein embedder by name
def get_protein_embedder(name):
    allowed_protein_embedders = list(PROTEIN_REPR_CONFIG.keys())
    assert name in allowed_protein_embedders, f"must be one of {', '.join(allowed_protein_embedders)}"

    config = PROTEIN_REPR_CONFIG[name]
    return config
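
A minimal usage sketch; actually calling the returned function requires the corresponding model weights (ESM-1b here) to be available locally or downloadable:

config = get_protein_embedder('esm')
print(config['dim'])   # 1280
# embeddings, mask = config['fn'](['MKTAYIAK'], device = torch.device('cpu'))  # heavy: loads ESM-1b; sequence is illustrative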

.\lucidrains\tf-bind-transformer\tf_bind_transformer\tf_bind_transformer.py

# required imports
import copy
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from functools import wraps

# einops functions and layers
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce

# contextmanager from contextlib
from contextlib import contextmanager

# the Enformer model and related functions
from enformer_pytorch import Enformer
from enformer_pytorch.modeling_enformer import poisson_loss, pearson_corr_coef
from enformer_pytorch.finetune import freeze_batchnorms_, freeze_all_but_layernorms_, unfreeze_last_n_layers_, unfreeze_all_layers_

# logavgexp
from logavgexp_pytorch import logavgexp

# caching and other project utilities
from tf_bind_transformer.cache_utils import cache_fn
from tf_bind_transformer.protein_utils import get_protein_embedder
from tf_bind_transformer.context_utils import get_text_repr, get_contextual_dim

# attention building blocks
from tf_bind_transformer.attention import FeedForward, JointCrossAttentionBlock, CrossAttention, SelfAttentionBlock

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return a default if the value does not exist
def default(val, d):
    return val if exists(val) else d

# identity: return the given function unchanged
def identity(fn, *args, **kwargs):
    return fn

# a no-op context manager
@contextmanager
def null_context():
    yield

# tensor helpers

# l2 normalize a tensor along the last dimension
def l2norm(t):
    return F.normalize(t, dim = -1)

# boolean mask where each entry is True with the given probability
def prob_mask_like(t, prob):
    return torch.zeros_like(t).float().uniform_(0, 1) < prob

# fourier encode the input
def fourier_encode(x, dims, theta = 20000):
    device, dtype = x.device, x.dtype
    emb = math.log(theta) / (dims // 2)
    emb = torch.exp(torch.arange(dims // 2, device = device) * -emb)
    emb = rearrange(x, 'n -> n 1') * rearrange(emb, 'd -> 1 d')
    emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
    return emb
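
A quick shape check for fourier_encode, with illustrative values:

x = torch.tensor([1., 10., 100.])
enc = fourier_encode(x, dims = 256)
print(enc.shape)   # torch.Size([3, 256]) - sin and cos features concatenated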

# correlation coefficient loss (1 - pearson r)
def corr_coef_loss(pred, target):
    return 1 - pearson_corr_coef(pred, target).mean()

# decorator for caching Enformer forward passes

def cache_enformer_forward(fn):
    cached_forward = cache_fn(fn, clear = True, path = 'genetic')

    @wraps(fn)
    def inner(seqs, *args, **kwargs):
        if seqs.ndim == 3:
            seqs = seqs.argmax(dim = -1)

        seq_list = seqs.unbind(dim = 0)
        seq_cache_keys = [''.join(list(map(str, one_seq.tolist()))) for one_seq in seq_list]
        outputs = [cached_forward(one_seq, *args, __cache_key = seq_cache_key, **kwargs) for one_seq, seq_cache_key in zip(seq_list, seq_cache_keys)]
        return torch.stack(outputs)

    return inner

# models

# FiLM module
class FiLM(nn.Module):
    def __init__(
        self,
        dim,
        conditioned_dim
    ):
        super().__init__()
        self.to_gamma = nn.Linear(dim, conditioned_dim)
        self.to_bias = nn.Linear(dim, conditioned_dim)

    def forward(self, x, condition, mask = None):
        gamma = self.to_gamma(condition)
        bias = self.to_bias(condition)

        x = x * rearrange(gamma, 'b d -> b 1 d')
        x = x + rearrange(bias, 'b d -> b 1 d')
        return x
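
A minimal sketch of how FiLM is used for conditioning (dimensions are illustrative): the condition produces a per-channel scale and shift applied across the whole sequence.

film = FiLM(dim = 32, conditioned_dim = 64)
x = torch.randn(2, 10, 64)        # sequence features to be modulated
condition = torch.randn(2, 32)    # e.g. a contextual embedding
out = film(x, condition)          # (2, 10, 64)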

# SqueezeExcitation module
class SqueezeExcitation(nn.Module):
    def __init__(
        self,
        dim,
        conditioned_dim,
        eps = 1e-8
    ):
        super().__init__()
        self.eps = eps
        self.to_gate = nn.Linear(dim + conditioned_dim, conditioned_dim)

    def forward(self, x, condition, mask = None):
        if exists(mask):
            numer = x.masked_fill(mask[..., None], 0.).sum(dim = 1)
            denom = mask.sum(dim = 1)[..., None].clamp(min = self.eps)
            mean_x = numer / denom
        else:
            mean_x = x.mean(dim = 1)

        condition = torch.cat((condition, mean_x), dim = -1)
        gate = self.to_gate(condition)

        x = x * rearrange(gate, 'b d -> b 1 d').sigmoid()
        return x

# ReadValueMLP, used for the auxiliary read-value loss
class ReadValueMLP(nn.Module):
    def __init__(
        self,
        dim,
        *,
        fourier_dims = 256,
        norm_factor_fourier = 50,
        norm_factor_linear = 8000,
        eps = 1e-20
    ):
        # call the parent constructor
        super().__init__()
        # store hyperparameters
        self.eps = eps
        self.fourier_dims = fourier_dims
        self.norm_factor_fourier = norm_factor_fourier
        self.norm_factor_linear = norm_factor_linear

        # normalization applied to the logits
        self.logits_norm = nn.Sequential(
            Reduce('b n d -> b d', 'mean'),  # mean pool the logits over the sequence dimension
            nn.LayerNorm(dim)  # followed by a LayerNorm
        )

        # MLP that maps logits plus encoded peak counts to a read value prediction
        self.mlp = nn.Sequential(
            nn.Linear(dim + fourier_dims + 2, dim * 2),  # linear layer
            nn.GELU(),  # GELU activation
            nn.Linear(dim * 2, 1),  # project down to a scalar
            Rearrange('... 1 -> ...')  # squeeze the trailing dimension
        )

    # forward pass
    def forward(self, logits, peaks_nr, read_value):
        # normalize the logits
        logits = self.logits_norm(logits)

        # peak counts in log space
        peaks_nr_log_space = torch.log(peaks_nr + self.eps)

        # flatten the peak counts
        peaks_nr = rearrange(peaks_nr, '... -> (...)')
        # fourier encode the peak counts
        peaks_nr_encoded = fourier_encode(peaks_nr / self.norm_factor_fourier, self.fourier_dims)
        # linearly normalized peak counts
        peaks_nr_normed = rearrange(peaks_nr, '... -> ... 1') / self.norm_factor_linear

        # concatenate the normalized, log-space and fourier-encoded peak counts
        peaks_nr_encoded_with_self = torch.cat((peaks_nr_normed, peaks_nr_log_space, peaks_nr_encoded), dim = -1)

        # concatenate the logits with the encoded peak counts
        logits_with_peaks = torch.cat((logits, peaks_nr_encoded_with_self), dim = -1)

        # predict the read value with the MLP
        pred = self.mlp(logits_with_peaks)
        # flatten the target read value
        read_value = rearrange(read_value, '... -> (...)')

        # return the smooth L1 loss
        return F.smooth_l1_loss(pred, read_value)

# the HypergridLinear module
class HypergridLinear(nn.Module):
    # takes an input dimension dim, output dimension dim_out and a context dimension context_dim
    def __init__(
        self,
        dim,
        dim_out,
        *,
        context_dim
    ):
        super().__init__()
        # randomly initialized projection weights
        self.weights = nn.Parameter(torch.randn(dim, dim_out))
        # linear projection of the context into a gating matrix
        self.contextual_projection = nn.Linear(context_dim, dim * dim_out)

    # forward pass over the input x and the context
    def forward(self, x, context):
        # derive the contextual gating, as in the hypergrid paper
        gating = self.contextual_projection(context).sigmoid()
        gating = rearrange(gating, 'b (i o) -> b i o', i = int(math.sqrt(gating.shape[-1])))

        # gate the projection weights with the context, then project
        to_logits_w = rearrange(self.weights, 'i o -> 1 i o') * gating
        return einsum('b n d, b d e -> b n e', x, to_logits_w)
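
A minimal shape sketch (illustrative sizes). Note that the sqrt-based rearrange above implicitly assumes dim == dim_out, which holds for how it is used below (latent_heads to latent_heads):

hyper = HypergridLinear(32, 32, context_dim = 512)
x = torch.randn(2, 100, 32)      # per-position interaction features
context = torch.randn(2, 512)    # contextual embedding
out = hyper(x, context)          # (2, 100, 32)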

# the FILIP module
class FILIP(nn.Module):
    # takes an input dimension dim, context dimension context_dim, number of heads, head dimension and dropout probability
    def __init__(
        self,
        dim,
        context_dim,
        heads,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        inner_latent_dim = heads * dim_head

        # weights and bias projecting the input into the latent space
        self.to_latent_w = nn.Parameter(torch.randn(dim, inner_latent_dim))
        self.to_latent_b = nn.Parameter(torch.randn(inner_latent_dim))

        self.pre_attn_dropout = dropout

        # null context token, plus weights and bias projecting the context into the latent space
        self.null_context = nn.Parameter(torch.randn(heads, dim_head))
        self.context_to_latent_w = nn.Parameter(torch.randn(context_dim, inner_latent_dim))
        self.context_to_latent_b = nn.Parameter(torch.randn(inner_latent_dim))

    # forward pass over the input x, the context and an optional context mask
    def forward(
        self,
        x,
        context,
        context_mask = None
    ):
        b, heads, device = x.shape[0], self.heads, x.device

        x = einsum('b n d, d e -> b n e', x, self.to_latent_w)
        x = x + self.to_latent_b

        x = rearrange(x, 'b n (h d) -> b h n d', h = heads)

        context = einsum('b n d, d e -> b n e', context, self.context_to_latent_w)
        context = context + self.context_to_latent_b

        context = rearrange(context, 'b n (h d) -> b h n d', h = heads)

        context, x = map(l2norm, (context, x))

        # fine-grained interaction between DNA and protein sequences, as in the FILIP paper
        if x.shape[0] == 1:
            x = rearrange(x, '1 ... -> ...')
            einsum_eq = 'h i d, b h j d -> b h i j'
        else:
            einsum_eq = 'b h i d, b h j d -> b h i j'

        # if no context mask is given, default to all True
        if not exists(context_mask):
            context_mask = torch.ones((b, context.shape[-1]), device = device).bool()

        # during training, randomly drop context positions according to the dropout probability
        if self.training:
            keep_mask = prob_mask_like(context_mask, 1 - self.pre_attn_dropout)
            context_mask = context_mask & keep_mask

        # add the null context and pad the mask accordingly
        context_mask = F.pad(context_mask, (1, 0), value = True)
        context_mask = rearrange(context_mask, 'b j -> b 1 1 j')

        null_context = repeat(self.null_context, 'h d -> b h 1 d', b = b)
        context = torch.cat((null_context, context), dim = -2)

        # differentiable max via logavgexp, as in the FILIP paper
        interactions = einsum(einsum_eq, x, context)
        interactions = logavgexp(interactions, mask = context_mask, dim = -1, temp = 0.05)
        interactions = rearrange(interactions, 'b h i -> b i h')
        return interactions

# the AdapterModel
class AdapterModel(nn.Module):
    # constructor, setting up all model hyperparameters
    def __init__(
        self,
        *,
        enformer,  # the enformer model
        latent_dim = 64,  # latent dimension
        latent_heads = 32,  # number of latent heads
        aa_embed_dim = None,  # amino acid embedding dimension
        aa_embed_encoder = 'esm',  # amino acid embedding encoder
        contextual_embed_dim = None,  # contextual embedding dimension
        use_aa_embeds = False,  # whether to use amino acid embeddings
        use_free_text_context = False,  # whether to use free text context
        free_text_context_encoder = 'pubmed',  # free text context encoder
        free_text_embed_method = 'cls',  # free text embedding method
        dropout = 0.,  # dropout rate
        binary_target = False,  # whether the target is binary
        target_mse_loss = False,  # whether to use a mean squared error loss
        aux_read_value_loss = False,  # whether to use the auxiliary read-value loss
        read_value_aux_loss_weight = 0.05,  # weight of the read-value auxiliary loss
        joint_cross_attn_depth = 1,  # depth of the joint cross attention
        genome_self_attn_depth = 0,  # depth of the genome self attention
        fourier_dims = 256,  # fourier dimensions
        condition_squeeze_excite = False,  # whether to condition with squeeze-excitation
        condition_film = False,  # whether to condition with FiLM
        condition_hypergrid = True,  # whether to condition with the hypergrid linear
        use_corr_coef_loss = False,  # whether to use the correlation coefficient loss
        finetune_output_heads = None,  # output heads to finetune
        **kwargs  # remaining keyword arguments
        ):
            # call the parent constructor
            super().__init__()
            # enformer must be an instance of Enformer
            assert isinstance(enformer, Enformer), 'enformer must be an instance of Enformer'
            # store the passed-in enformer
            self.enformer = enformer
            # enformer_dim is twice enformer.dim
            enformer_dim = enformer.dim * 2

            # if finetune_output_heads is given, add those heads to the enformer
            if exists(finetune_output_heads):
                self.enformer.add_heads(**finetune_output_heads)

            # LayerNorm over the sequence embedding of dimension enformer_dim
            self.norm_seq_embed = nn.LayerNorm(enformer_dim)

            # contextual embedding related variables

            # free_text_embed_method must be either 'cls' or 'mean_pool'
            assert free_text_embed_method in {'cls', 'mean_pool'}, 'must be either cls or mean_pool'
            # store the free text embedding method
            self.free_text_embed_method = free_text_embed_method
            # store whether free text context is used
            self.use_free_text_context = use_free_text_context

            if use_free_text_context:
                # if using free text context, derive the contextual embedding dimension from the encoder
                contextual_embed_dim = get_contextual_dim(free_text_context_encoder)
            else:
                # otherwise the contextual embedding dimension must be given explicitly
                assert exists(contextual_embed_dim), 'contextual embedding dimension must be given if not using transformer encoder'

            # protein embedding related variables

            # store whether amino acid embeddings are used
            self.use_aa_embeds = use_aa_embeds
            # look up the protein embedder configuration
            self.aa_embed_config = get_protein_embedder(aa_embed_encoder)
            # and its embedding function
            self.get_aa_embed = self.aa_embed_config['fn']

            if use_aa_embeds:
                # if using amino acid embeddings, take the dimension from the embedder config
                aa_embed_dim = self.aa_embed_config['dim']
            else:
                # otherwise the AA embedding dimension must be set explicitly
                assert exists(aa_embed_dim), 'AA embedding dimensions must be set if not using ESM'

            # conditioning

            self.cond_genetic = None
            self.cond_protein = None

            if condition_squeeze_excite or condition_film:
                # choose between SqueezeExcitation and FiLM conditioning
                condition_klass = SqueezeExcitation if condition_squeeze_excite else FiLM

                # condition both the genetic and the protein branch
                self.cond_genetic  = condition_klass(contextual_embed_dim, enformer_dim)
                self.cond_protein  = condition_klass(contextual_embed_dim, aa_embed_dim)

            # genome self attention

            # initialize genome_self_attns as an empty ModuleList
            self.genome_self_attns = nn.ModuleList([])

            for _ in range(genome_self_attn_depth):
                # create a SelfAttentionBlock for each layer and append it
                attn = SelfAttentionBlock(
                    dim = enformer_dim,
                    dropout = dropout
                )
                self.genome_self_attns.append(attn)

            # joint attention

            # initialize joint_cross_attns as an empty ModuleList
            self.joint_cross_attns = nn.ModuleList([])

            for _ in range(joint_cross_attn_depth):
                # create a JointCrossAttentionBlock for each layer and append it
                attn = JointCrossAttentionBlock(
                    dim = enformer_dim,
                    context_dim = aa_embed_dim,
                    dropout = dropout
                )

                self.joint_cross_attns.append(attn)

            # latents

            # the FILIP fine-grained interaction module
            self.filip = FILIP(
                dim = enformer_dim,
                context_dim = aa_embed_dim,
                dim_head = latent_dim,
                heads = latent_heads,
                dropout = dropout
            )

            # hypergrid conditioning

            if condition_hypergrid:
                # if hypergrid conditioning is enabled, project the latents with a HypergridLinear
                self.linear_with_hypergrid = HypergridLinear(latent_heads, latent_heads, context_dim = contextual_embed_dim)
            else:
                # otherwise use a plain Linear layer to the logits
                self.linear_to_logits = nn.Linear(latent_heads, latent_heads)

            # to predictions

            # store the target type and auxiliary loss settings
            self.binary_target = binary_target
            self.aux_read_value_loss = aux_read_value_loss
            self.read_value_aux_loss_weight = read_value_aux_loss_weight

            if binary_target:
                # for a binary target, use binary cross entropy (or MSE if requested)
                self.loss_fn = F.binary_cross_entropy_with_logits if not target_mse_loss else F.mse_loss

                # prediction head
                self.to_pred = nn.Sequential(
                    Reduce('... n d -> ... d', 'mean'),
                    nn.LayerNorm(latent_heads),
                    nn.Linear(latent_heads, 1),
                    Rearrange('... 1 -> ...')
                )

                # auxiliary read-value loss module
                self.to_read_value_aux_loss = ReadValueMLP(
                    dim = latent_heads,
                    fourier_dims = fourier_dims
                )

            else:
                # otherwise use a poisson loss (or the correlation coefficient loss if requested)
                self.loss_fn = poisson_loss if not use_corr_coef_loss else corr_coef_loss

                # prediction head
                self.to_pred = nn.Sequential(
                    nn.Linear(latent_heads, 1),
                    Rearrange('... 1 -> ...'),
                    nn.Softplus()
                )

    # combine the main and auxiliary losses, returning just the main loss if the auxiliary loss is disabled
    def combine_losses(self, loss, aux_loss):
        if not self.aux_read_value_loss:
            return loss

        return loss + self.read_value_aux_loss_weight * aux_loss

    # forward pass through one of the enformer's output heads
    def forward_enformer_head(
        self,
        seq_embed,
        *,
        head,
        target = None,
        return_corr_coef = False
    ):
        # finetuning on tracks is not possible when binary_target training is turned on
        assert not self.binary_target, 'cannot finetune on tracks if binary_target training is turned on'

        # unfreeze the enformer output heads
        unfreeze_all_layers_(self.enformer._heads)

        # the requested head must exist on the enformer
        assert head in self.enformer._heads, f'{head} head not found in enformer'

        # run the sequence embedding through the requested head
        pred = self.enformer._heads[head](seq_embed)

        # if no target is given, return the prediction
        if not exists(target):
            return pred

        # the number of predicted tracks must match the target
        assert pred.shape[-1] == target.shape[-1], f'{head} head on enformer produced {pred.shape[-1]} tracks, but the supplied target only has {target.shape[-1]}'

        # if a target is given and the correlation coefficient is requested, return it
        if exists(target) and return_corr_coef:
            return pearson_corr_coef(pred, target)

        # otherwise compute and return the loss
        return self.loss_fn(pred, target)

    # main forward pass over the sequence and all conditioning inputs
    def forward(
        self,
        seq,
        *,
        aa = None,
        aa_embed = None,
        contextual_embed = None,
        contextual_free_text = None,
        aa_mask = None,
        target = None,
        read_value = None,
        peaks_nr = None,
        return_corr_coef = False,
        finetune_enformer = False,
        finetune_enformer_ln_only = False,
        unfreeze_enformer_last_n_layers = 0,
        head = None

.\lucidrains\tf-bind-transformer\tf_bind_transformer\training_utils.py

# import torch
import torch
# import nn from torch
from torch import nn
# import get_optimizer from tf_bind_transformer.optimizer
from tf_bind_transformer.optimizer import get_optimizer
# import data helpers from tf_bind_transformer.data
from tf_bind_transformer.data import read_bed, collate_dl_outputs, get_dataloader, remap_df_add_experiment_target_cell
# import the dataset classes from tf_bind_transformer.data
from tf_bind_transformer.data import RemapAllPeakDataset, NegativePeakDataset, ScopedNegativePeakDataset

# exists helper, checks whether a value is not None
def exists(val):
    return val is not None

# default helper, returns a fallback when the value does not exist
def default(val, d):
    return val if exists(val) else d

# accum_log accumulates values into the log across gradient accumulation steps
def accum_log(log, new_logs):
    for key, new_value in new_logs.items():
        old_value = log.get(key, 0.)
        log[key] = old_value + new_value
    return log

# simple Trainer class
class Trainer(nn.Module):
    def __init__(
        self,
        model,
        *,
        remap_bed_file,
        negative_bed_file,
        factor_fasta_folder,
        fasta_file,
        train_chromosome_ids,
        valid_chromosome_ids,
        batch_size,
        context_length,
        lr = 3e-4,
        wd = 0.1,
        validate_every = 250,
        grad_clip_norm = None,
        grad_accum_every = 1,
        held_out_targets = [],
        held_out_cell_types = [],
        exclude_targets = [],
        exclude_cell_types = [],
        shuffle = False,
        train_sample_frac = 1.,
        valid_sample_frac = 1.,
        remap_sample_frac = 1.,
        shift_aug_range = (-2, 2),
        rc_aug = False,
        experiments_json_path = None,
        read_value_aux_loss = False,
        checkpoint_filename = './checkpoint.pt',
        include_scoped_negs = False,
        scoped_negs_remap_bed_path = None,
        scoped_negs_path = None,
        scoped_negs_exts = '.bed.bool.npy',
        include_biotypes_metadata_in_context = False,
        biotypes_metadata_path = None,
        include_biotypes_metadata_columns = ['germ_layer', 'cellline_cat'],
        biotypes_metadata_delimiter = ' | ',
        balance_sampling_by_target = True,
        valid_balance_sampling_by_target = None,
    # forward method, runs a single training step (with periodic validation)
    def forward(
        self,
        finetune_enformer_ln_only = True,
        **kwargs
        ):
            # number of gradient accumulation steps
            grad_accum_every = self.grad_accum_every
            # current step
            curr_step = int(self.steps.item())
            # put the model into training mode
            self.model.train()

            # initialize the log dictionary
            log = {}

            # loop over the gradient accumulation steps
            for _ in range(self.grad_accum_every):
                # draw a batch from the positive and negative dataloaders
                dl_outputs = [next(self.dl), next(self.neg_dl)]

                # if scoped negatives are included, draw from that dataloader as well
                if self.include_scoped_negs:
                    dl_outputs.append(next(self.scoped_neg_dl))

                # collate the dataloader outputs into the format the model expects
                seq, tf_aa, contextual_texts, peaks_nr, read_value, binary_target = collate_dl_outputs(*dl_outputs)
                seq, binary_target, read_value, peaks_nr = seq.cuda(), binary_target.cuda(), read_value.cuda(), peaks_nr.cuda()

                # compute the model losses
                loss, aux_loss = self.model(
                    seq,
                    target = binary_target,
                    aa = tf_aa,
                    contextual_free_text = contextual_texts,
                    finetune_enformer_ln_only = finetune_enformer_ln_only,
                    read_value = read_value,
                    peaks_nr = peaks_nr,
                    **kwargs
                )

                # combine into the total loss
                total_loss = self.model.combine_losses(loss, aux_loss)

                # update the log
                log = accum_log(log, {
                    'loss': loss.item() / grad_accum_every,
                    'aux_loss': aux_loss.item() / grad_accum_every,
                    'total_loss': total_loss.item() / grad_accum_every
                })

                # backpropagate
                (total_loss / self.grad_accum_every).backward()

            # print the total loss at the current step
            print(f'{curr_step} loss: {log["total_loss"]}')

            # clip gradients if a clipping norm is set
            if exists(self.grad_clip_norm):
                nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)

            # optimizer step and gradient reset
            self.optim.step()
            self.optim.zero_grad()

            # validate every validate_every steps
            if (curr_step % self.validate_every) == 0:
                # put the model into evaluation mode
                self.model.eval()

                # run validation for the same number of accumulation steps
                for _ in range(self.grad_accum_every):
                    # draw a batch from the validation dataloaders
                    seq, tf_aa, contextual_texts, peaks_nr, read_value, binary_target = collate_dl_outputs(next(self.valid_dl), next(self.valid_neg_dl))
                    seq, binary_target = seq.cuda(), binary_target.cuda()

                    # get predictions on the validation batch
                    valid_logits = self.model(
                        seq,
                        aa = tf_aa,
                        contextual_free_text = contextual_texts,
                    )

                    # compute validation loss and accuracy
                    valid_loss = self.model.loss_fn(valid_logits, binary_target.float())
                    valid_accuracy = ((valid_logits.sigmoid() > 0.5).int() == binary_target).sum() / (binary_target.numel())

                    # update the log
                    log = accum_log(log, {
                        'valid_loss': valid_loss.item() / grad_accum_every,
                        'valid_accuracy': valid_accuracy.item() / grad_accum_every
                    })

                # print the validation loss and accuracy
                print(f'{curr_step} valid loss: {log["valid_loss"]}')
                print(f'{curr_step} valid accuracy: {log["valid_accuracy"]}')

                # save the model weights after the first step
                if curr_step > 0:
                    torch.save(self.model.state_dict(), self.checkpoint_filename)

            # increment the step counter
            self.steps += 1
            # return the log
            return log

.\lucidrains\tf-bind-transformer\tf_bind_transformer\training_utils_bigwig.py

import torch
from torch import nn
from tf_bind_transformer.optimizer import get_optimizer
from tf_bind_transformer.data_bigwig import BigWigDataset, BigWigTracksOnlyDataset, get_bigwig_dataloader, get_bigwig_tracks_dataloader
from enformer_pytorch.modeling_enformer import poisson_loss, pearson_corr_coef

def exists(val):
    # check whether a value exists
    return val is not None

def default(val, d):
    # return the value if it exists, otherwise the default
    return val if exists(val) else d

# helpers for logging and accumulating values across gradient steps

def accum_log(log, new_logs):
    # accumulate values into the log
    for key, new_value in new_logs.items():
        old_value = log.get(key, 0.)
        log[key] = old_value + new_value
    return log

# simple Trainer class

class BigWigTrainer(nn.Module):
    def __init__(
        self,
        model,
        *,
        human_factor_fasta_folder,
        annot_file_path,
        human_loci_path,
        mouse_loci_path,
        human_fasta_file,
        mouse_fasta_file,
        batch_size,
        bigwig_tracks_only_folder_path = None,
        bigwig_folder_path = None,
        train_chromosome_ids = None,
        valid_chromosome_ids = None,
        mouse_factor_fasta_folder = None,
        downsample_factor = 128,
        target_length = 896,
        lr = 3e-4,
        wd = 0.1,
        validate_every = 250,
        grad_clip_norm = None,
        grad_accum_every = 1,
        held_out_targets_human = [],
        held_out_targets_mouse = [],
        held_out_cell_types_human = [],
        held_out_cell_types_mouse = [],
        context_length = 4096,
        shuffle = False,
        shift_aug_range = (-2, 2),
        rc_aug = False,
        checkpoint_filename = './checkpoint.pt',
        include_biotypes_metadata_in_context = False,
        biotypes_metadata_path = None,
        include_biotypes_metadata_columns = ['germ_layer', 'cellline_cat'],
        biotypes_metadata_delimiter = ' | ',
        bigwig_reduction_type = 'sum',
        enformer_train_valid_split = True
    def forward(
        self,
        finetune_enformer_ln_only = True,
        **kwargs

.\lucidrains\tf-bind-transformer\tf_bind_transformer\__init__.py

# import the AdapterModel class
from tf_bind_transformer.tf_bind_transformer import AdapterModel
# import the Trainer class
from tf_bind_transformer.training_utils import Trainer
# import the BigWigTrainer class
from tf_bind_transformer.training_utils_bigwig import BigWigTrainer

TimeSformer - Pytorch

Implementation of TimeSformer, from Facebook AI. A pure and simple attention-based solution for reaching SOTA on video classification. This repository will only house the best performing variant, 'Divided Space-Time Attention', which is nothing more than attention along the time axis before the spatial.

Press release

Install

$ pip install timesformer-pytorch

Usage

import torch
from timesformer_pytorch import TimeSformer

model = TimeSformer(
    dim = 512,
    image_size = 224,
    patch_size = 16,
    num_frames = 8,
    num_classes = 10,
    depth = 12,
    heads = 8,
    dim_head =  64,
    attn_dropout = 0.1,
    ff_dropout = 0.1
)

video = torch.randn(2, 8, 3, 224, 224) # (batch x frames x channels x height x width)
mask = torch.ones(2, 8).bool() # (batch x frame) - use a mask if there are variable length videos in the same batch

pred = model(video, mask = mask) # (2, 10)

Citations

@misc{bertasius2021spacetime,
    title   = {Is Space-Time Attention All You Need for Video Understanding?}, 
    author  = {Gedas Bertasius and Heng Wang and Lorenzo Torresani},
    year    = {2021},
    eprint  = {2102.05095},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@article{tokshift2021,
    title   = {Token Shift Transformer for Video Classification},
    author  = {Hao Zhang and Yanbin Hao and Chong-Wah Ngo},
    journal = {ACM Multimedia 2021},
}

.\lucidrains\TimeSformer-pytorch\setup.py

# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'timesformer-pytorch',  # package name
  packages = find_packages(),  # find and include all packages
  version = '0.4.1',  # version number
  license='MIT',  # license
  description = 'TimeSformer - Pytorch',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/TimeSformer-pytorch',  # project url
  keywords = [  # keywords
    'artificial intelligence',
    'attention mechanism',
    'transformers',
    'video classification',
  ],
  install_requires=[  # dependencies
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\TimeSformer-pytorch\timesformer_pytorch\rotary.py

# import log and pi from math
# import nn, einsum and functional from torch
# import rearrange and repeat from einops
from math import log, pi
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat

# rotate every pair of elements in the input tensor
def rotate_every_two(x):
    # group the last dimension into pairs
    x = rearrange(x, '... (d j) -> ... d j', j = 2)
    # split each pair into two tensors
    x1, x2 = x.unbind(dim = -1)
    # rotate each pair: (x1, x2) -> (-x2, x1)
    x = torch.stack((-x2, x1), dim = -1)
    # merge the pairs back into the last dimension
    return rearrange(x, '... d j -> ... (d j)')

# apply rotary embeddings to the queries and keys
def apply_rot_emb(q, k, rot_emb):
    # unpack the rotary embedding into sine and cosine parts
    sin, cos = rot_emb
    # number of dimensions that receive the rotation
    rot_dim = sin.shape[-1]
    # split queries and keys into a rotated part and a pass-through part
    (q, q_pass), (k, k_pass) = map(lambda t: (t[..., :rot_dim], t[..., rot_dim:]), (q, k))
    # rotate the first part of the queries and keys
    q, k = map(lambda t: t * cos + rotate_every_two(t) * sin, (q, k))
    # concatenate the rotated and pass-through parts back together
    q, k = map(lambda t: torch.cat(t, dim = -1), ((q, q_pass), (k, k_pass)))
    return q, k

# axial rotary embedding over image height and width
class AxialRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_freq = 10):
        super().__init__()
        self.dim = dim
        # compute the frequency scales
        scales = torch.logspace(0., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
        # store the scales as a buffer
        self.register_buffer('scales', scales)

    def forward(self, h, w, device):
        # add a leading dimension to the scales
        scales = rearrange(self.scales, '... -> () ...')
        # move the scales to the target device
        scales = scales.to(device)

        # positions along the height
        h_seq = torch.linspace(-1., 1., steps = h, device = device)
        h_seq = h_seq.unsqueeze(-1)

        # positions along the width
        w_seq = torch.linspace(-1., 1., steps = w, device = device)
        w_seq = w_seq.unsqueeze(-1)

        # scale the height and width positions by the frequencies and pi
        h_seq = h_seq * scales * pi
        w_seq = w_seq * scales * pi

        # broadcast the axial sinusoidal inputs over the full grid
        x_sinu = repeat(h_seq, 'i d -> i j d', j = w)
        y_sinu = repeat(w_seq, 'j d -> i j d', i = h)

        # concatenate the sine and cosine components of both axes
        sin = torch.cat((x_sinu.sin(), y_sinu.sin()), dim = -1)
        cos = torch.cat((x_sinu.cos(), y_sinu.cos()), dim = -1)

        # flatten the grid and repeat each frequency for the paired rotation
        sin, cos = map(lambda t: rearrange(t, 'i j d -> (i j) d'), (sin, cos))
        sin, cos = map(lambda t: repeat(t, 'n d -> () n (d j)', j = 2), (sin, cos))
        return sin, cos

# standard rotary embedding along the sequence (frame) dimension
class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # inverse frequencies
        inv_freqs = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        # store the inverse frequencies as a buffer
        self.register_buffer('inv_freqs', inv_freqs)

    def forward(self, n, device):
        # positions 0..n-1
        seq = torch.arange(n, device = device)
        # outer product of positions and inverse frequencies
        freqs = einsum('i, j -> i j', seq, self.inv_freqs)
        freqs = torch.cat((freqs, freqs), dim = -1)
        freqs = rearrange(freqs, 'n d -> () n d')
        return freqs.sin(), freqs.cos()
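
A quick shape check for the two rotary embedding variants defined above (dim_head = 64 and the grid sizes are illustrative):

frame_rot = RotaryEmbedding(64)
image_rot = AxialRotaryEmbedding(64)

sin_t, cos_t = frame_rot(8, device = torch.device('cpu'))       # each (1, 8, 64), one entry per frame
sin_s, cos_s = image_rot(14, 14, device = torch.device('cpu'))  # each (1, 196, 64), one entry per patch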

.\lucidrains\TimeSformer-pytorch\timesformer_pytorch\timesformer_pytorch.py

import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat

from timesformer_pytorch.rotary import apply_rot_emb, AxialRotaryEmbedding, RotaryEmbedding

# required imports

# helpers

def exists(val):
    return val is not None

# helper to check whether a value exists

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, *args, **kwargs):
        x = self.norm(x)
        return self.fn(x, *args, **kwargs)

# pre-normalization layer: a LayerNorm applied to the input of the wrapped function

# time token shift

def shift(t, amt):
    if amt == 0:
        return t
    return F.pad(t, (0, 0, 0, 0, amt, -amt))

# shift a tensor along the frame (time) dimension

class PreTokenShift(nn.Module):
    def __init__(self, frames, fn):
        super().__init__()
        self.frames = frames
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        f, dim = self.frames, x.shape[-1]
        cls_x, x = x[:, :1], x[:, 1:]
        x = rearrange(x, 'b (f n) d -> b f n d', f = f)

        # shift along time frame before and after

        dim_chunk = (dim // 3)
        chunks = x.split(dim_chunk, dim = -1)
        chunks_to_shift, rest = chunks[:3], chunks[3:]
        shifted_chunks = tuple(map(lambda args: shift(*args), zip(chunks_to_shift, (-1, 0, 1))))
        x = torch.cat((*shifted_chunks, *rest), dim = -1)

        x = rearrange(x, 'b f n d -> b (f n) d')
        x = torch.cat((cls_x, x), dim = 1)
        return self.fn(x, *args, **kwargs)

# pre-token-shift layer: shifts token chunks along the time dimension before calling the wrapped function

# feedforward

class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

# GEGLU activation

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)

# feedforward layer made of a linear projection, GEGLU activation, dropout and an output projection

# attention

def attn(q, k, v, mask = None):
    sim = einsum('b i d, b j d -> b i j', q, k)

    if exists(mask):
        max_neg_value = -torch.finfo(sim.dtype).max
        sim.masked_fill_(~mask, max_neg_value)

    attn = sim.softmax(dim = -1)
    out = einsum('b i j, b j d -> b i d', attn, v)
    return out

# attention helper: compute the attention weights and apply them to the values
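
A quick shape check for the attention helper above (batch-of-heads layout; sizes are illustrative):

q = torch.randn(4, 10, 64)   # (batch * heads, query positions, dim_head)
k = torch.randn(4, 20, 64)
v = torch.randn(4, 20, 64)
out = attn(q, k, v)          # (4, 10, 64)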

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

# attention layer: linear projections for queries, keys and values, plus an output projection and dropout
    # forward pass: x is rearranged from einops_from to einops_to; mask and cls_mask are optional attention masks,
    # rot_emb is an optional rotary embedding, and **einops_dims supplies the dimensions for the rearrangement
    def forward(self, x, einops_from, einops_to, mask = None, cls_mask = None, rot_emb = None, **einops_dims):
        # number of heads
        h = self.heads
        # project the input x into queries, keys and values
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        # rearrange queries, keys and values to (b h) n d
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

        # scale the queries
        q = q * self.scale

        # split off the classification token (the first token) from the rest
        (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k, v))

        # let the classification token attend to the keys/values of all patches across time and space
        cls_out = attn(cls_q, k, v, mask = cls_mask)

        # rearrange across time or space according to einops_from / einops_to
        q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_))

        # apply the rotary embedding if given
        if exists(rot_emb):
            q_, k_ = apply_rot_emb(q_, k_, rot_emb)

        # expand the classification token keys and values across time or space and prepend them
        r = q_.shape[0] // cls_k.shape[0]
        cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r = r), (cls_k, cls_v))

        k_ = torch.cat((cls_k, k_), dim = 1)
        v_ = torch.cat((cls_v, v_), dim = 1)

        # attention
        out = attn(q_, k_, v_, mask = mask)

        # merge time or space back into the original shape
        out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)

        # concatenate the classification token back onto the output
        out = torch.cat((cls_out, out), dim = 1)

        # merge the heads back into the feature dimension
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)

        # combine the head outputs with the output projection
        return self.to_out(out)
# main class

class TimeSformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        num_frames,
        num_classes,
        image_size = 224,
        patch_size = 16,
        channels = 3,
        depth = 12,
        heads = 8,
        dim_head = 64,
        attn_dropout = 0.,
        ff_dropout = 0.,
        rotary_emb = True,
        shift_tokens = False
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_size // patch_size) ** 2
        num_positions = num_frames * num_patches
        patch_dim = channels * patch_size ** 2

        self.heads = heads
        self.patch_size = patch_size
        self.to_patch_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, dim))

        self.use_rotary_emb = rotary_emb
        if rotary_emb:
            self.frame_rot_emb = RotaryEmbedding(dim_head)
            self.image_rot_emb = AxialRotaryEmbedding(dim_head)
        else:
            self.pos_emb = nn.Embedding(num_positions + 1, dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            ff = FeedForward(dim, dropout = ff_dropout)
            time_attn = Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
            spatial_attn = Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)

            if shift_tokens:
                time_attn, spatial_attn, ff = map(lambda t: PreTokenShift(num_frames, t), (time_attn, spatial_attn, ff))

            time_attn, spatial_attn, ff = map(lambda t: PreNorm(dim, t), (time_attn, spatial_attn, ff))

            self.layers.append(nn.ModuleList([time_attn, spatial_attn, ff]))

        self.to_out = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, video, mask = None):
        b, f, _, h, w, *_, device, p = *video.shape, video.device, self.patch_size
        assert h % p == 0 and w % p == 0, f'height {h} and width {w} of video must be divisible by the patch size {p}'

        # number of patches along the height and width, and the total number of patches (n)

        hp, wp = (h // p), (w // p)
        n = hp * wp

        # convert the video into patch embeddings

        video = rearrange(video, 'b f c (h p1) (w p2) -> b (f h w) (p1 p2 c)', p1 = p, p2 = p)
        tokens = self.to_patch_embedding(video)

        # prepend the class token

        cls_token = repeat(self.cls_token, 'n d -> b n d', b = b)
        x =  torch.cat((cls_token, tokens), dim = 1)

        # positional embeddings

        frame_pos_emb = None
        image_pos_emb = None
        if not self.use_rotary_emb:
            x += self.pos_emb(torch.arange(x.shape[1], device = device))
        else:
            frame_pos_emb = self.frame_rot_emb(f, device = device)
            image_pos_emb = self.image_rot_emb(hp, wp, device = device)

        # build masks for variable numbers of frames

        frame_mask = None
        cls_attn_mask = None
        if exists(mask):
            mask_with_cls = F.pad(mask, (1, 0), value = True)

            frame_mask = repeat(mask_with_cls, 'b f -> (b h n) () f', n = n, h = self.heads)

            cls_attn_mask = repeat(mask, 'b f -> (b h) () (f n)', n = n, h = self.heads)
            cls_attn_mask = F.pad(cls_attn_mask, (1, 0), value = True)

        # time and space attention

        for (time_attn, spatial_attn, ff) in self.layers:
            x = time_attn(x, 'b (f n) d', '(b n) f d', n = n, mask = frame_mask, cls_mask = cls_attn_mask, rot_emb = frame_pos_emb) + x
            x = spatial_attn(x, 'b (f n) d', '(b f) n d', f = f, cls_mask = cls_attn_mask, rot_emb = image_pos_emb) + x
            x = ff(x) + x

        cls_token = x[:, 0]
        return self.to_out(cls_token)

.\lucidrains\TimeSformer-pytorch\timesformer_pytorch\__init__.py

# import the TimeSformer class from timesformer_pytorch.timesformer_pytorch
from timesformer_pytorch.timesformer_pytorch import TimeSformer

Data source

The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/

Token Shift GPT

Implementation of Token Shift GPT - An autoregressive model that relies solely on shifting along the sequence dimension and feedforwards.

Update: Inexplicably, it actually works quite well. The feedforward module follows the same design as gMLP, except the feature dimension of the gate tensor is divided up into log2(seq_len) chunks, and the mean pool of the past consecutive segments (length 1, 2, 4, 8, etc. into the past) are shifted into each chunk before a projection along the feature dimension.

Install

$ pip install token-shift-gpt

Usage

import torch
from token_shift_gpt import TokenShiftGPT

model = TokenShiftGPT(
    num_tokens = 256,
    dim = 512,
    max_seq_len = 1024,
    depth = 12,
    ff_mult = 8   # when working with small model dimensions, you may want to increase the intermediate feedforward dimension (here, 8x instead of the usual 4x), so the learning is not bottlenecked by the dimensions of the shifted chunk
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x) # (1, 1024, 256)

To use the discounted cumulative sum approach (which only uses one chunk and seems to be just as effective as the above), just set use_discounted_cumsum = True

First install an additional library

$ pip install torch-discounted-cumsum

Then

import torch
from token_shift_gpt import TokenShiftGPT

model = TokenShiftGPT(
    num_tokens = 256,
    dim = 512,
    max_seq_len = 1024,
    depth = 12,
    ff_mult = 8,
    use_discounted_cumsum = True,
    discounted_gamma = 0.9              # gamma factor for discount
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x) # (1, 1024, 256)

Citations

@misc{yu2021s2mlp,
    title   = {S$^2$-MLP: Spatial-Shift MLP Architecture for Vision}, 
    author  = {Tan Yu and Xu Li and Yunfeng Cai and Mingming Sun and Ping Li},
    year    = {2021},
    eprint  = {2106.07477},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{liu2021pay,
    title   = {Pay Attention to MLPs}, 
    author  = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
    year    = {2021},
    eprint  = {2105.08050},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@software{peng_bo_2021_5196578,
    author       = {PENG Bo},
    title        = {BlinkDL/RWKV-LM: 0.01},
    month        = {aug},
    year         = {2021},
    publisher    = {Zenodo},
    version      = {0.01},
    doi          = {10.5281/zenodo.5196578},
    url          = {https://doi.org/10.5281/zenodo.5196578}
}

.\lucidrains\token-shift-gpt\setup.py

# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'token-shift-gpt',  # package name
  packages = find_packages(),  # find all packages
  version = '0.0.3',  # version number
  license='MIT',  # license
  description = 'Token Shift GPT - Pytorch',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/token-shift-gpt',  # project url
  keywords = [
    'artificial intelligence',  # keyword
    'deep learning',  # keyword
    'autoregressive language modeling'  # keyword
  ],
  install_requires=[
    'einops>=0.3',  # dependency
    'torch>=1.6'  # dependency
  ],
  classifiers=[
    'Development Status :: 4 - Beta',  # classifier
    'Intended Audience :: Developers',  # classifier
    'Topic :: Scientific/Engineering :: Artificial Intelligence',  # classifier
    'License :: OSI Approved :: MIT License',  # classifier
    'Programming Language :: Python :: 3.6',  # classifier
  ],
)

.\lucidrains\token-shift-gpt\token_shift_gpt\autoregressive_wrapper.py

import torch
from torch import nn
import torch.nn.functional as F

# decorator that puts the model into eval mode for the duration of the call, then restores its previous mode
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# top-k filtering of the logits
def top_k(logits, thres = 0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# wrapper class for autoregressive training and sampling
class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, ignore_index = -100, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = ignore_index

        self.net = net
        self.max_seq_len = net.seq_len

    # generation: autoregressively sample a sequence
    @torch.no_grad()
    @eval_decorator
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        device = start_tokens.device
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        b, t = start_tokens.shape

        out = start_tokens

        for _ in range(seq_len):
            x = out[:, -self.max_seq_len:]

            logits = self.net(x, **kwargs)[:, -1, :]

            filtered_logits = top_k(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)

            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)

            if eos_token is not None and (sample == eos_token).all():
                break

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        return out

    # forward pass, computes the next-token cross entropy loss
    def forward(self, x, **kwargs):
        xi, xo = x[:, :-1], x[:, 1:]
        out = self.net(xi, **kwargs)
        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        return loss
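
A minimal sketch of pairing this wrapper with TokenShiftGPT from the next file (hyperparameters are illustrative; the wrapper only assumes the wrapped net exposes seq_len and returns per-token logits):

import torch
from token_shift_gpt import TokenShiftGPT
from token_shift_gpt.autoregressive_wrapper import AutoregressiveWrapper

model = AutoregressiveWrapper(TokenShiftGPT(num_tokens = 256, dim = 512, max_seq_len = 1024, depth = 4))

x = torch.randint(0, 256, (1, 1024))
loss = model(x)                        # cross entropy on next-token prediction
loss.backward()

prime = torch.randint(0, 256, (1, 16))
sampled = model.generate(prime, 100)   # (1, 100) sampled token ids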

.\lucidrains\token-shift-gpt\token_shift_gpt\token_shift_gpt.py

# import log2 and ceil from math
# import torch, nn, einsum and functional
from math import log2, ceil
import torch
from torch import nn, einsum
import torch.nn.functional as F

# import rearrange from einops
from einops import rearrange

# helper to check whether a value exists
def exists(val):
    return val is not None

# shift the input by amt along the given dimension, padding with zeros
def shift(x, amt, dim = -1):
    return F.pad(x, (*((0, 0) * (-dim - 1)), amt, -amt), value = 0.)

# token shift: mix normalized sums over past segments of exponentially growing length into the gate chunks
def shift_tokens(x, amt, eps = 1e-5):
    n, device = x.shape[1], x.device

    # cumulative sum along the sequence
    cumsum = x.cumsum(dim = 1)
    *x, x_pass = x.chunk(amt + 1, dim = -1)
    *x_cumsum, _ = cumsum.chunk(amt + 1, dim = -1)

    # shift amounts / segment lengths: powers of two
    amts = 2 ** torch.arange(amt)
    amts = amts.tolist()

    shifts = []
    denom = torch.arange(n, device = device)

    for x_chunk, x_cumsum_chunk, amt in zip(x, x_cumsum, amts):
        # windowed past sum via differences of shifted cumulative sums, normalized by the segment length
        shifted_chunk = shift(x_cumsum_chunk, amt, dim = -2) - shift(x_cumsum_chunk, 2 * amt, dim = -2)
        shifted_denom = shift(denom, amt, dim = -1) - shift(denom, 2 * amt, dim = -1)
        shifted_denom = rearrange(shifted_denom, 'n -> () n ()')
        normed_shifted_x = shifted_chunk /  (shifted_denom + eps)
        shifts.append(normed_shifted_x)

    return torch.cat((*shifts, x_pass), dim = -1)
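
A quick standalone check with illustrative sizes: the feature dimension is split into num_shifts chunks plus a pass-through chunk, so the output shape matches the input.

g = torch.randn(1, 16, 32)        # (batch, seq, features); 32 splits evenly into 3 + 1 chunks of 8
out = shift_tokens(g, 3)
print(out.shape)                  # torch.Size([1, 16, 32])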

# discounted cumulative sum (requires the torch-discounted-cumsum package)
def discounted_cumsum(t, gamma):
    try:
        from torch_discounted_cumsum import discounted_cumsum_left
    except ImportError:
        print('unable to import torch_discounted_cumsum - please run `pip install torch-discounted-cumsum`')

    b, n, d = t.shape
    t = rearrange(t, 'b n d -> (b d) n')
    t = discounted_cumsum_left(t, gamma)
    t = rearrange(t, '(b d) n -> b n d', b = b)
    return t

# residual module
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

# feedforward module
class FeedForward(nn.Module):
    def __init__(
        self,
        *,
        dim,
        max_seq_len,
        num_shifts,
        mult = 4,
        eps = 1e-3,
        use_discounted_cumsum = False,
        discount_gamma = 0.9
    ):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

        self.project_in = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU()
        )

        self.num_shifts = num_shifts
        hidden_dim = dim * mult // 2

        self.gate_norm = nn.LayerNorm(hidden_dim)
        self.to_gate = nn.Linear(hidden_dim, hidden_dim)

        nn.init.constant_(self.to_gate.weight, eps)
        nn.init.constant_(self.to_gate.bias, 1.)

        self.project_out = nn.Linear(hidden_dim, dim)

        # settings for the discounted cumulative sum variant

        self.use_discounted_cumsum = use_discounted_cumsum
        self.discount_gamma = discount_gamma

    def forward(self, x):
        x = self.norm(x)

        x = self.project_in(x)

        x, gate = x.chunk(2, dim = -1)

        gate = self.gate_norm(gate)

        if self.use_discounted_cumsum:
            gate = shift(gate, 1, dim = -2)
            gate = discounted_cumsum(gate, self.discount_gamma)
        else:
            gate = shift_tokens(gate, self.num_shifts)

        x = x * self.to_gate(gate)
        return self.project_out(x)

# 定义一个 TokenShiftGPT 模块
class TokenShiftGPT(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        max_seq_len,
        depth,
        ff_mult = 4,
        use_discounted_cumsum = False,
        discount_gamma = 0.9
    ):
        super().__init__()
        self.seq_len = max_seq_len
        num_shifts = ceil(log2(max_seq_len)) - 1

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.net = nn.Sequential(
            *[Residual(FeedForward(dim = dim, num_shifts = num_shifts, mult = ff_mult, max_seq_len = max_seq_len, use_discounted_cumsum = use_discounted_cumsum, discount_gamma = discount_gamma)) for _ in range(depth)],
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )
    # forward pass over the input token ids x
    def forward(self, x):
        # embed the tokens
        x = self.token_emb(x)
        # positional embeddings for the sequence length of x, on the same device
        pos_emb = self.pos_emb(torch.arange(x.shape[1], device = x.device))
        # add the positional embeddings to the token embeddings
        x = x + rearrange(pos_emb, 'n d -> () n d')
        # run the result through the network to get the logits
        return self.net(x)