Lucidrains-系列项目源码解析-七十八-

112 阅读15分钟

Lucidrains 系列项目源码解析(七十八)

.\lucidrains\progen\train.py

# load environment variables (e.g. wandb credentials) from a local .env file
from dotenv import load_dotenv
# NOTE(review): the accompanying comment stated load_dotenv() was invoked, but
# the call itself was missing from this chunk - restored so .env is applied
load_dotenv()

# cli, formatting, templating and numeric utilities
import click
import humanize
from jinja2 import Template
from pathlib import Path
import tqdm
import numpy as np

# toml for reading model configuration files
import toml

# jax / optax training stack
import jax
from jax import nn, random, jit, tree_util, tree_map
from optax import adamw, clip_by_global_norm, chain, apply_updates, apply_every

# haiku PRNG key sequence
from haiku import PRNGSequence

# progen transformer model, data pipeline and helpers
from progen_transformer import ProGen
from progen_transformer.data import decode_tokens, iterator_from_tfrecords_folder
from progen_transformer.utils import sample, get_loss_fn, set_hardware_rng_, confirm, exists
from progen_transformer.checkpoint import get_checkpoint_fns

# experiment tracking
import wandb

# template used to render sampled sequences as html for wandb logging
sample_tmpl = Template("""<i>{{prime_str}}</i><br/><br/><div style="overflow-wrap: break-word;">{{sampled_str}}</div>""")

# switch jax to hardware-accelerated random number generation
set_hardware_rng_(jax)

# training entry point: trains ProGen with jax/haiku/optax, with checkpoint
# resume, wandb tracking, periodic validation and sampling
@click.command()
@click.option('--seed', default = 42)
@click.option('--batch_size', default = 4)
@click.option('--grad_accum_every', default = 4)
@click.option('--learning_rate', default = 2e-4)
@click.option('--weight_decay', default = 1e-3)
@click.option('--data_parallel', default = False, is_flag = True)
@click.option('--max_grad_norm', default = 0.5)
@click.option('--validate_every', default = 100)
@click.option('--sample_every', default = 500)
@click.option('--checkpoint_every', default = 1000)
@click.option('--checkpoint_path', default = './ckpts')
@click.option('--checkpoint_keep_n', default = 500)
@click.option('--config_path', default = './configs/model')
@click.option('--model_name', default = 'default')
@click.option('--prime_length', default = 25)
@click.option('--seq_len', default = 1024)
@click.option('--mixed_precision', default = False, is_flag = True)
@click.option('--data_path', default = './train_data')
@click.option('--wandb_off', default = False, is_flag = True)
@click.option('--wandb_project_name', default = 'progen-training')
@click.option('--new', default = False, is_flag = True)
def main(
    seed,
    batch_size,
    grad_accum_every,
    learning_rate,
    weight_decay,
    data_parallel,
    max_grad_norm,
    validate_every,
    sample_every,
    checkpoint_every,
    checkpoint_path,
    checkpoint_keep_n,
    config_path,
    model_name,
    prime_length,
    seq_len,
    mixed_precision,
    data_path,
    wandb_off,
    wandb_project_name,
    new
):
    # prepare folders

    # checkpoint helpers: wipe all, fetch latest, save one
    reset_checkpoint, get_last_checkpoint, save_checkpoint = get_checkpoint_fns(checkpoint_path)

    # --new wipes every checkpoint and restarts training, after confirmation
    if new:
        if not confirm('are you sure you want to clear all your checkpoints and restart training?'):
            exit()
        reset_checkpoint()

    # initialize all states, or load from checkpoint

    last_checkpoint = get_last_checkpoint()

    if not exists(last_checkpoint):
        # fresh start - read model config from ./configs/model/<model_name>.toml
        config_folder_path = Path(config_path)
        config_path = config_folder_path / f'{model_name}.toml'
        assert config_path.exists(), f'path to your model config {str(config_path)} does not exist'
        model_kwargs = toml.loads(config_path.read_text())
    else:
        # resume - reuse the model config stored inside the checkpoint
        model_kwargs = last_checkpoint['model_config']

    # setup model and params

    model = ProGen(**{
        **model_kwargs,
        'mixed_precision': mixed_precision
    })

    # jit-compile the forward pass for sampling
    model_apply = jit(model.apply)
    # deterministic stream of PRNG keys seeded by --seed
    rng = PRNGSequence(seed)
    loss_fn = get_loss_fn(model, data_parallel = data_parallel)

    # optimizer

    # weight-decay mask: only tensors with ndim > 1 decay (norms / biases excluded)
    exclude_norm_and_bias_params = lambda p: tree_map(lambda x: x.ndim > 1, p)

    # clip -> adamw -> apply updates only every grad_accum_every steps
    optim = chain(
        clip_by_global_norm(max_grad_norm),
        adamw(learning_rate, weight_decay = weight_decay, mask = exclude_norm_and_bias_params),
        apply_every(grad_accum_every)
    )

    # get params and optim_state, either restored or freshly initialized

    if exists(last_checkpoint):
        params = last_checkpoint['params']
        optim_state = last_checkpoint['optim_state']
        start_seq_index = last_checkpoint['next_seq_index']
    else:
        # initialize parameters by tracing the model over a zeroed mock sequence
        mock_data = np.zeros((model_kwargs['seq_len'],), dtype = np.uint8)
        params = model.init(next(rng), mock_data)
        optim_state = optim.init(params)
        start_seq_index = 0

    # experiment tracker

    seq_len = model_kwargs['seq_len']
    # total parameter count, and a human readable rendering of it
    num_params = tree_util.tree_reduce(lambda acc, el: acc + el.size, params, 0)
    num_params_readable = humanize.naturalsize(num_params)

    # NOTE(review): wandb.config is assigned before wandb.init below - wandb
    # supports pre-init config staging, but confirm against the wandb version in use
    wandb.config.num_params = num_params

    # --wandb_off runs wandb in disabled mode
    wandb_kwargs = {'mode': 'disabled'} if wandb_off else {}

    # resume the previous wandb run if the checkpoint recorded a run id
    if exists(last_checkpoint) and exists(last_checkpoint['run_id']):
        run_id = last_checkpoint['run_id']
        wandb_kwargs = {**wandb_kwargs, 'id': run_id, 'resume': 'allow'}

    wandb.init(project = wandb_project_name, **wandb_kwargs)
    wandb_run_id = wandb.run.id if not wandb_off else None

    # get tf dataset

    total_train_seqs, get_train_dataset = iterator_from_tfrecords_folder(data_path, data_type = 'train')
    total_valid_seqs, get_valid_dataset = iterator_from_tfrecords_folder(data_path, data_type = 'valid',)

    assert total_train_seqs > 0, 'no protein sequences found for training'
    assert total_valid_seqs > 0, 'no protein sequences found for validation'

    # training iterator skips the sequences consumed before the checkpoint
    train_dataset = get_train_dataset(
        seq_len = seq_len,
        batch_size = batch_size,
        skip = start_seq_index
    )

    # validation iterator loops forever
    valid_dataset = get_valid_dataset(
        seq_len = seq_len,
        batch_size = batch_size,
        loop = True
    )

    # print run summary

    print(f'params: {num_params_readable}')
    print(f'sequence length: {seq_len}')
    print(f'num sequences: {total_train_seqs}')
    print(f'starting from sequence {start_seq_index}')

    # training

    # one outer step consumes batch_size * grad_accum_every sequences
    effective_batch_size = batch_size * grad_accum_every
    seq_index_ranges = range(start_seq_index, total_train_seqs, effective_batch_size)    

    for i, seq_index in tqdm.tqdm(enumerate(seq_index_ranges), mininterval = 10., desc = 'training', total = len(seq_index_ranges)):
        # accumulate over micro-batches; optax apply_every defers the actual
        # parameter update until grad_accum_every update calls have been made
        for _ in range(grad_accum_every):
            data = next(train_dataset)

            loss, grads = loss_fn(params, next(rng), data)
            updates, optim_state = optim.update(grads, optim_state, params)
            params = apply_updates(params, updates)

        # logs the loss of the last micro-batch only
        print(f'loss: {loss.item()}')
        wandb.log({'loss': loss.item()})

        if i % checkpoint_every == 0:
            # persist everything needed to resume: params, optimizer state,
            # model config, wandb run id and the next sequence index
            package = {
                'next_seq_index': seq_index + effective_batch_size,
                'params': params,
                'optim_state': optim_state,
                'model_config': model_kwargs,
                'run_id': wandb_run_id
            }

            save_checkpoint(package, checkpoint_keep_n)
            print(f"checkpoint to start at sequence index of {package['next_seq_index']}")

        if i % validate_every == 0:
            # validation loss on a single held-out batch
            valid_data = next(valid_dataset)
            loss, _ = loss_fn(params, next(rng), valid_data)
            print(f'valid_loss: {loss.item()}')
            wandb.log({'valid_loss': loss.item()})

        if i % sample_every == 0:
            # sample a continuation from a validation prime and log it as html
            valid_data = next(valid_dataset)[0]
            prime = valid_data[:prime_length]
            prime_str = decode_tokens(prime)

            sampled = sample(rng, model_apply, params, prime, seq_len, top_k = 25)
            sampled_str = decode_tokens(sampled[prime_length:])

            print(prime_str, "\n", "*" * 40, "\n", sampled_str)
            wandb.log({'samples': wandb.Html(sample_tmpl.render(prime_str = prime_str, sampled_str = sampled_str))})
# run the cli entry point when executed as a script
if __name__ == '__main__':
    main()

.\lucidrains\protein-bert-pytorch\protein_bert_pytorch\protein_bert_pytorch.py

# math and torch primitives
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
# einops layers for declarative reshaping inside nn.Sequential
from einops.layers.torch import Rearrange, Reduce
# NOTE(review): rearrange / repeat are used throughout this module but the
# announced import line had gone missing - restored here
from einops import rearrange, repeat
# helpers

# helper predicate for optional arguments
def exists(val):
    """Return True when an optional value has been provided (is not None)."""
    return val is not None

# most negative finite value for a tensor's dtype, used as -inf when masking logits
def max_neg_value(t):
    """Return the largest-magnitude negative value representable in `t`'s dtype."""
    finfo = torch.finfo(t.dtype)
    return -finfo.max

# helper classes

# skip connection wrapper: output of the wrapped module is added to its input
class Residual(nn.Module):
    """Wrap `fn` so the forward pass computes fn(x) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        out = self.fn(x)
        return out + x

# global linear self attention over the sequence track (linear in sequence length)
class GlobalLinearSelfAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head,
        heads
    ):
        """Linear self attention.

        Args:
            dim: input/output feature dimension
            dim_head: per-head dimension
            heads: number of attention heads
        """
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, feats, mask = None):
        # feats: (batch, seq, dim); mask: optional (batch, seq) boolean — TODO confirm with callers
        h = self.heads
        q, k, v = self.to_qkv(feats).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        if exists(mask):
            mask = rearrange(mask, 'b n -> b () n ()')
            # masked key positions get the most negative value so the softmax ignores them
            k = k.masked_fill(~mask, -torch.finfo(k.dtype).max)

        # kernelized attention: queries softmax over features, keys over positions
        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        if exists(mask):
            # masked positions contribute nothing to the value aggregation
            v = v.masked_fill(~mask, 0.)

        # aggregate a global context from keys/values, then distribute it to queries
        context = einsum('b h n d, b h n e -> b h d e', k, v)
        out = einsum('b h d e, b h n d -> b h n e', context, q)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# cross attention: queries from one track attend onto keys/values of another
class CrossAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_keys,
        dim_out,
        heads,
        dim_head = 64,
        qk_activation = nn.Tanh()
    ):
        """Cross attention from `dim`-sized queries onto a `dim_keys`-sized context.

        Args:
            dim: feature dimension of the query input
            dim_keys: feature dimension of the context (keys / values)
            dim_out: output feature dimension
            heads: number of attention heads
            dim_head: per-head dimension
            qk_activation: activation applied to both queries and keys before the dot product
        """
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.qk_activation = qk_activation

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim_keys, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim_out)

        # learned null key/value, prepended so attention always has a fallback target
        self.null_key = nn.Parameter(torch.randn(dim_head))
        self.null_value = nn.Parameter(torch.randn(dim_head))

    def forward(self, x, context, mask = None, context_mask = None):
        b, h, device = x.shape[0], self.heads, x.device

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # prepend the null key/value for every batch element and head
        null_k, null_v = map(lambda t: repeat(t, 'd -> b h () d', b = b, h = h), (self.null_key, self.null_value))
        k = torch.cat((null_k, k), dim = -2)
        v = torch.cat((null_v, v), dim = -2)

        q, k = map(lambda t: self.qk_activation(t), (q, k))

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if exists(mask) or exists(context_mask):
            i, j = sim.shape[-2:]

            if not exists(mask):
                mask = torch.ones(b, i, dtype = torch.bool, device = device)

            if exists(context_mask):
                # pad with True so the prepended null position stays attendable
                context_mask = F.pad(context_mask, (1, 0), value = True)
            else:
                context_mask = torch.ones(b, j, dtype = torch.bool, device = device)

            # combine query-side and context-side masks into a full attention mask
            mask = rearrange(mask, 'b i -> b () i ()') * rearrange(context_mask, 'b j -> b () () j')
            sim.masked_fill_(~mask, max_neg_value(sim))

        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Layer(nn.Module):
    # one dual-track block: convolutions + optional linear self attention on the
    # local (sequence) track, with information exchange to and from the global track
    def __init__(
        self,
        *,
        dim,
        dim_global,
        narrow_conv_kernel = 9,
        wide_conv_kernel = 9,
        wide_conv_dilation = 5,
        attn_heads = 8,
        attn_dim_head = 64,
        attn_qk_activation = nn.Tanh(),
        local_to_global_attn = False,
        local_self_attn = False,
        glu_conv = False
    ):
        """Args:
            dim: local (sequence) feature dimension
            dim_global: global (annotation) feature dimension
            narrow_conv_kernel: kernel size of the narrow conv branch
            wide_conv_kernel: kernel size of the dilated (wide) conv branch
            wide_conv_dilation: dilation of the wide conv branch
            attn_heads / attn_dim_head: attention head configuration
            attn_qk_activation: activation for queries/keys in global->local cross attention
            local_to_global_attn: use cross attention (instead of mean pooling) to pull in global info
            local_self_attn: add global linear self attention over the sequence
            glu_conv: use gated linear units in the conv branches
        """
        super().__init__()

        # optional linear self attention over the sequence track
        self.seq_self_attn = GlobalLinearSelfAttention(dim = dim, dim_head = attn_dim_head, heads = attn_heads) if local_self_attn else None

        # GLU halves the channels, so double the conv output channels when it is used
        conv_mult = 2 if glu_conv else 1

        # narrow (small receptive field) conv branch
        self.narrow_conv = nn.Sequential(
            nn.Conv1d(dim, dim * conv_mult, narrow_conv_kernel, padding = narrow_conv_kernel // 2),
            nn.GELU() if not glu_conv else nn.GLU(dim = 1)
        )

        # 'same' padding for the dilated convolution
        wide_conv_padding = (wide_conv_kernel + (wide_conv_kernel - 1) * (wide_conv_dilation - 1)) // 2

        # wide (dilated, large receptive field) conv branch
        self.wide_conv = nn.Sequential(
            nn.Conv1d(dim, dim * conv_mult, wide_conv_kernel, dilation = wide_conv_dilation, padding = wide_conv_padding),
            nn.GELU() if not glu_conv else nn.GLU(dim = 1)
        )

        self.local_to_global_attn = local_to_global_attn

        # global -> local injection: either cross attention from tokens onto the
        # global track, or a mean-pooled linear projection of the global track
        if local_to_global_attn:
            self.extract_global_info = CrossAttention(
                dim = dim,
                dim_keys = dim_global,
                dim_out = dim,
                heads = attn_heads,
                dim_head = attn_dim_head
            )
        else:
            self.extract_global_info = nn.Sequential(
                Reduce('b n d -> b d', 'mean'),
                nn.Linear(dim_global, dim),
                nn.GELU(),
                Rearrange('b d -> b () d')
            )

        # post-sum normalization of the local track
        self.local_norm = nn.LayerNorm(dim)

        # local feedforward with residual connection
        self.local_feedforward = nn.Sequential(
            Residual(nn.Sequential(
                nn.Linear(dim, dim),
                nn.GELU(),
            )),
            nn.LayerNorm(dim)
        )

        # local -> global injection: global track cross-attends onto the tokens
        self.global_attend_local = CrossAttention(dim = dim_global, dim_out = dim_global, dim_keys = dim, heads = attn_heads, dim_head = attn_dim_head, qk_activation = attn_qk_activation)

        # dense transform of the global track
        self.global_dense = nn.Sequential(
            nn.Linear(dim_global, dim_global),
            nn.GELU()
        )

        self.global_norm = nn.LayerNorm(dim_global)

        # global feedforward with residual connection
        self.global_feedforward = nn.Sequential(
            Residual(nn.Sequential(
                nn.Linear(dim_global, dim_global),
                nn.GELU()
            )),
            nn.LayerNorm(dim_global),
        )

    # one exchange step; returns the updated (tokens, annotation) pair
    def forward(self, tokens, annotation, mask = None):
        # pull information from the global track into the local track
        if self.local_to_global_attn:
            global_info = self.extract_global_info(tokens, annotation, mask = mask)
        else:
            global_info = self.extract_global_info(annotation)

        # process local (protein sequence)

        # optional sequence self attention contribution (0 when disabled)
        global_linear_attn = self.seq_self_attn(tokens) if exists(self.seq_self_attn) else 0

        # conv layers operate channels-first
        conv_input = rearrange(tokens, 'b n d -> b d n')

        # zero out padded positions before convolving
        if exists(mask):
            conv_input_mask = rearrange(mask, 'b n -> b () n')
            conv_input = conv_input.masked_fill(~conv_input_mask, 0.)

        narrow_out = self.narrow_conv(conv_input)
        narrow_out = rearrange(narrow_out, 'b d n -> b n d')
        wide_out = self.wide_conv(conv_input)
        wide_out = rearrange(wide_out, 'b d n -> b n d')

        # sum all local contributions, then normalize and feedforward
        tokens = tokens + narrow_out + wide_out + global_info + global_linear_attn
        tokens = self.local_norm(tokens)

        tokens = self.local_feedforward(tokens)

        # process global (annotations)

        # global track attends over the (masked) updated tokens
        annotation = self.global_attend_local(annotation, tokens, context_mask = mask)
        annotation = self.global_dense(annotation)
        annotation = self.global_norm(annotation)
        annotation = self.global_feedforward(annotation)

        return tokens, annotation
# main model
class ProteinBERT(nn.Module):
    """ProteinBERT: a two-track network processing a local token track (the
    protein sequence) and a global track (annotations), exchanging information
    between the two at every layer.

    Args:
        num_tokens: vocabulary size for sequence tokens
        num_annotation: number of annotation classes
        dim: local (per-token) feature dimension
        dim_global: global (annotation) feature dimension
        depth: number of stacked Layer blocks
        narrow_conv_kernel / wide_conv_kernel / wide_conv_dilation: conv settings per Layer
        attn_heads / attn_dim_head / attn_qk_activation: attention settings per Layer
        local_to_global_attn: use cross attention (instead of mean pooling) for global -> local info
        local_self_attn: add global linear self attention over the sequence track
        num_global_tokens: number of global tokens the annotation vector is projected into
        glu_conv: use gated linear units in the conv branches
    """

    def __init__(
        self,
        *,
        num_tokens = 26,
        num_annotation = 8943,
        dim = 512,
        dim_global = 256,
        depth = 6,
        narrow_conv_kernel = 9,
        wide_conv_kernel = 9,
        wide_conv_dilation = 5,
        attn_heads = 8,
        attn_dim_head = 64,
        attn_qk_activation = nn.Tanh(),
        local_to_global_attn = False,
        local_self_attn = False,
        num_global_tokens = 1,
        glu_conv = False
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.token_emb = nn.Embedding(num_tokens, dim)

        self.num_global_tokens = num_global_tokens
        # project the multi-hot annotation vector into num_global_tokens global tokens
        self.to_global_emb = nn.Linear(num_annotation, num_global_tokens * dim_global)

        # stack of dual-track layers
        # NOTE(review): attn_heads / attn_dim_head were previously accepted but
        # never forwarded to Layer (which silently used its own defaults); they
        # are now forwarded - behavior is unchanged for default arguments
        self.layers = nn.ModuleList([
            Layer(
                dim = dim,
                dim_global = dim_global,
                narrow_conv_kernel = narrow_conv_kernel,
                wide_conv_dilation = wide_conv_dilation,
                wide_conv_kernel = wide_conv_kernel,
                attn_heads = attn_heads,
                attn_dim_head = attn_dim_head,
                attn_qk_activation = attn_qk_activation,
                local_to_global_attn = local_to_global_attn,
                local_self_attn = local_self_attn,
                glu_conv = glu_conv
            ) for _ in range(depth)  # loop variable unused (was `layer`, shadowing the name used in forward)
        ])

        # per-token classification head
        self.to_token_logits = nn.Linear(dim, num_tokens)

        # annotation head: mean-pool global tokens, then project to annotation logits
        self.to_annotation_logits = nn.Sequential(
            Reduce('b n d -> b d', 'mean'),
            nn.Linear(dim_global, num_annotation)
        )

    def forward(self, seq, annotation, mask = None):
        """Return (token_logits, annotation_logits) for a batch of sequences."""
        tokens = self.token_emb(seq)

        annotation = self.to_global_emb(annotation)
        annotation = rearrange(annotation, 'b (n d) -> b n d', n = self.num_global_tokens)

        # run both tracks through every dual-track layer
        for layer in self.layers:
            tokens, annotation = layer(tokens, annotation, mask = mask)

        tokens = self.to_token_logits(tokens)
        annotation = self.to_annotation_logits(annotation)
        return tokens, annotation

# pretraining wrapper - wraps ProteinBERT with the denoising objective
class PretrainingWrapper(nn.Module):
    def __init__(
        self,
        model,
        random_replace_token_prob = 0.05,  # fraction of tokens to corrupt with a random token
        remove_annotation_prob = 0.25,  # fraction of present annotations to drop
        add_annotation_prob = 0.01,  # probability of switching on an absent annotation
        remove_all_annotations_prob = 0.5,  # fraction of batch rows whose annotations are wiped entirely
        seq_loss_weight = 1.,  # weight of the sequence reconstruction loss
        annotation_loss_weight = 1.,  # weight of the annotation reconstruction loss
        exclude_token_ids = (0, 1, 2)   # token ids never corrupted (padding, start, end)
    ):
        super().__init__()
        assert isinstance(model, ProteinBERT), 'model must be an instance of ProteinBERT'

        self.model = model

        # corruption probabilities
        self.random_replace_token_prob = random_replace_token_prob
        self.remove_annotation_prob = remove_annotation_prob
        self.add_annotation_prob = add_annotation_prob
        self.remove_all_annotations_prob = remove_all_annotations_prob

        # loss weights
        self.seq_loss_weight = seq_loss_weight
        self.annotation_loss_weight = annotation_loss_weight

        self.exclude_token_ids = exclude_token_ids
    # corrupt the inputs, denoise with the model, and return the combined loss
    def forward(self, seq, annotation, mask = None):
        batch_size, device = seq.shape[0], seq.device

        # uncorrupted inputs serve as the reconstruction targets
        seq_labels = seq
        annotation_labels = annotation

        # default mask: every position is valid
        if not exists(mask):
            mask = torch.ones_like(seq).bool()

        # prepare masks for noising sequence

        excluded_tokens_mask = mask

        # never corrupt positions holding excluded token ids (padding / start / end)
        for token_id in self.exclude_token_ids:
            excluded_tokens_mask = excluded_tokens_mask & (seq != token_id)

        # NOTE(review): get_mask_subset_with_prob is not defined in this chunk -
        # presumably a module-level helper selecting a random subset of a boolean
        # mask with the given probability; confirm it exists at file scope
        random_replace_token_prob_mask = get_mask_subset_with_prob(excluded_tokens_mask, self.random_replace_token_prob)

        # prepare masks for noising annotation

        batch_mask = torch.ones(batch_size, device = device, dtype = torch.bool)
        batch_mask = rearrange(batch_mask, 'b -> b ()')
        remove_annotation_from_batch_mask = get_mask_subset_with_prob(batch_mask, self.remove_all_annotations_prob)

        annotation_mask = annotation > 0
        remove_annotation_prob_mask = get_mask_subset_with_prob(annotation_mask, self.remove_annotation_prob)
        add_annotation_prob_mask = get_mask_subset_with_prob(~annotation_mask, self.add_annotation_prob)
        remove_annotation_mask = remove_annotation_from_batch_mask & remove_annotation_prob_mask

        # generate random tokens

        random_tokens = torch.randint(0, self.model.num_tokens, seq.shape, device=seq.device)

        # make sure replacements never introduce the excluded (pad / start / end) ids
        for token_id in self.exclude_token_ids:
            random_replace_token_prob_mask = random_replace_token_prob_mask & (random_tokens != token_id)

        # noise the sequence: replace the selected positions with random tokens

        noised_seq = torch.where(random_replace_token_prob_mask, random_tokens, seq)

        # noise the annotation: add sampled annotations, then apply the removal mask

        noised_annotation = annotation + add_annotation_prob_mask.type(annotation.dtype)
        noised_annotation = noised_annotation * remove_annotation_mask.type(annotation.dtype)

        # denoise with the model

        seq_logits, annotation_logits = self.model(noised_seq, noised_annotation, mask = mask)

        # compute losses over valid (unmasked) positions only

        seq_logits = seq_logits[mask]
        seq_labels = seq_labels[mask]

        seq_loss = F.cross_entropy(seq_logits, seq_labels, reduction = 'sum')
        annotation_loss = F.binary_cross_entropy_with_logits(annotation_logits, annotation_labels, reduction = 'sum')

        # weighted sum of both reconstruction losses
        return seq_loss * self.seq_loss_weight + annotation_loss * self.annotation_loss_weight

.\lucidrains\protein-bert-pytorch\protein_bert_pytorch\__init__.py

# 从 protein_bert_pytorch 包中导入 ProteinBERT 和 PretrainingWrapper 类
from protein_bert_pytorch.protein_bert_pytorch import ProteinBERT, PretrainingWrapper

ProteinBERT - Pytorch (wip)

Implementation of ProteinBERT in Pytorch.

Original Repository

Install

$ pip install protein-bert-pytorch

Usage

import torch
from protein_bert_pytorch import ProteinBERT

model = ProteinBERT(
    num_tokens = 21,
    num_annotation = 8943,
    dim = 512,
    dim_global = 256,
    depth = 6,
    narrow_conv_kernel = 9,
    wide_conv_kernel = 9,
    wide_conv_dilation = 5,
    attn_heads = 8,
    attn_dim_head = 64
)

seq = torch.randint(0, 21, (2, 2048))
mask = torch.ones(2, 2048).bool()
annotation = torch.randint(0, 1, (2, 8943)).float()

seq_logits, annotation_logits = model(seq, annotation, mask = mask) # (2, 2048, 21), (2, 8943)

To use for pretraining

import torch
from protein_bert_pytorch import ProteinBERT, PretrainingWrapper

model = ProteinBERT(
    num_tokens = 21,
    num_annotation = 8943,
    dim = 512,
    dim_global = 256,
    depth = 6,
    narrow_conv_kernel = 9,
    wide_conv_kernel = 9,
    wide_conv_dilation = 5,
    attn_heads = 8,
    attn_dim_head = 64,
    local_to_global_attn = False,
    local_self_attn = True,
    num_global_tokens = 2,
    glu_conv = False
)

learner = PretrainingWrapper(
    model,
    random_replace_token_prob = 0.05,    # what percentage of the tokens to replace with a random one, defaults to 5% as in paper
    remove_annotation_prob = 0.25,       # what percentage of annotations to remove, defaults to 25%
    add_annotation_prob = 0.01,          # probability to add an annotation randomly, defaults to 1%
    remove_all_annotations_prob = 0.5,   # what percentage of batch items to remove annotations for completely, defaults to 50%
    seq_loss_weight = 1.,                # weight on loss of sequence
    annotation_loss_weight = 1.,         # weight on loss of annotation
    exclude_token_ids = (0, 1, 2)        # for excluding padding, start, and end tokens from being masked
)

# do the following in a loop for a lot of sequences and annotations

seq        = torch.randint(0, 21, (2, 2048))
annotation = torch.randint(0, 1, (2, 8943)).float()
mask       = torch.ones(2, 2048).bool()

loss = learner(seq, annotation, mask = mask) # scalar pretraining loss (weighted sequence + annotation terms)
loss.backward()

# save your model and evaluate it

torch.save(model, './improved-protein-bert.pt')

Citations

@article {Brandes2021.05.24.445464,
    author      = {Brandes, Nadav and Ofer, Dan and Peleg, Yam and Rappoport, Nadav and Linial, Michal},
    title       = {ProteinBERT: A universal deep-learning model of protein sequence and function},
    year        = {2021},
    doi         = {10.1101/2021.05.24.445464},
    publisher   = {Cold Spring Harbor Laboratory},
    URL         = {https://www.biorxiv.org/content/early/2021/05/25/2021.05.24.445464},
    eprint      = {https://www.biorxiv.org/content/early/2021/05/25/2021.05.24.445464.full.pdf},
    journal     = {bioRxiv}
}

.\lucidrains\protein-bert-pytorch\setup.py

# packaging script for protein-bert-pytorch
from setuptools import setup, find_packages

setup(
    name="protein-bert-pytorch",  # distribution name on PyPI
    packages=find_packages(),  # include every package discovered in the tree
    version="0.1.0",
    license="MIT",
    description="ProteinBERT - Pytorch",
    author="Phil Wang",
    author_email="lucidrains@gmail.com",
    url="https://github.com/lucidrains/protein-bert-pytorch",
    keywords=[  # search keywords
        "artificial intelligence",
        "deep learning",
        "attention mechanism",
        "protein sequences",
        "unsupervised learning"
    ],
    install_requires=[  # runtime dependencies
        "einops>=0.3",
        "torch>=1.6",
    ],
    classifiers=[  # PyPI trove classifiers
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.6",
    ],
)

.\lucidrains\pytorch-custom-utils\pytorch_custom_utils\accelerate_utils.py

# 导入必要的模块
from functools import partial, wraps
from typing import Optional, Callable
from contextlib import nullcontext, contextmanager

from torch.nn import Module

from accelerate import Accelerator
from accelerate.tracking import WandBTracker

# 辅助函数

# helper predicate for optional values
def exists(v):
    """Return True when `v` carries a value (is not None)."""
    return v is not None

# merge two context-manager factories into one context manager
@contextmanager
def combine_contexts(a, b):
    """Enter `a()` and `b()` together and yield the tuple of their `as` values."""
    with a() as first, b() as second:
        yield (first, second)

# linear scan for the first element satisfying `cond`
def find_first(cond: Callable, arr):
    """Return the first item of `arr` for which `cond` is truthy, else None."""
    return next((item for item in arr if cond(item)), None)

# class decorator adding a wandb-tracking context manager bound to the
# instance's Accelerator

def add_wandb_tracker_contextmanager(
    accelerator_instance_name = 'accelerator',
    tracker_hps_instance_name = 'tracker_hps'
):
    """Return a class decorator that installs a `wandb_tracking(project, run, hps)`
    context manager, which initializes wandb tracking on the instance's
    Accelerator and calls end_training when the context exits.

    Args:
        accelerator_instance_name: attribute name holding the Accelerator
        tracker_hps_instance_name: attribute name holding tracker hyperparameters
    """
    def decorator(klass):

        @contextmanager
        def wandb_tracking(
            self,
            project: str,
            run: Optional[str] = None,
            hps: Optional[dict] = None
        ):
            maybe_accelerator = getattr(self, accelerator_instance_name, None)

            assert exists(maybe_accelerator) and isinstance(maybe_accelerator, Accelerator), f'Accelerator instance not found at self.{accelerator_instance_name}'

            # an instance-level hps attribute takes precedence over the argument
            hps = getattr(self, tracker_hps_instance_name, hps)

            maybe_accelerator.init_trackers(project, config = hps)

            # locate the wandb tracker among the accelerator's trackers
            wandb_tracker = find_first(lambda el: isinstance(el, WandBTracker), maybe_accelerator.trackers)

            assert exists(wandb_tracker), 'wandb tracking was not enabled. you need to set `log_with = "wandb"` on your accelerate kwargs'

            # optionally rename the wandb run
            if exists(run):
                assert exists(wandb_tracker)
                wandb_tracker.run.name = run

            yield

            # finalize tracking when the context exits
            maybe_accelerator.end_training() 

        # do not clobber an existing wandb_tracking attribute
        if not hasattr(klass, 'wandb_tracking'):
            klass.wandb_tracking = wandb_tracking

        return klass

    return decorator

# attribute proxy: prefer `parent`, fall back to `child`
# (used to transparently reach through a possibly DDP-wrapped model)
class ForwardingWrapper:
    def __init__(self, parent, child):
        self.parent = parent
        self.child = child

    def __getattr__(self, key):
        # __getattr__ only fires for names not found via normal lookup
        source = self.parent if hasattr(self.parent, key) else self.child
        return getattr(source, key)

    def __call__(self, *args, **kwargs):
        call_fn = self.__getattr__('__call__')
        return call_fn(*args, **kwargs)

def auto_unwrap_model(
    accelerator_instance_name = 'accelerator',
    model_instance_name = 'model'
):
    """Class decorator: after __init__ runs, replace the model attribute with a
    ForwardingWrapper around its accelerator-unwrapped version, so missing
    attributes resolve against the unwrapped model."""
    def decorator(klass):
        _orig_init = klass.__init__

        @wraps(_orig_init)
        def __init__(self, *args, **kwargs):
            _orig_init(self, *args, **kwargs)

            accelerator = getattr(self, accelerator_instance_name)
            model = getattr(self, model_instance_name)
            assert isinstance(accelerator, Accelerator)

            unwrapped = accelerator.unwrap_model(model)
            setattr(self, model_instance_name, ForwardingWrapper(model, unwrapped))

        klass.__init__ = __init__
        return klass

    return decorator

# gradient accumulation helper: yields one context factory per accumulation
# step, wrapping every step except the last in accelerator.no_sync so the
# gradient all-reduce only happens on the final step

def model_forward_contexts(
    accelerator: Accelerator,
    model: Module,
    grad_accum_steps: int = 1
):
    last_step = grad_accum_steps - 1

    for step in range(grad_accum_steps):
        sync_context = nullcontext if step == last_step else partial(accelerator.no_sync, model)

        yield partial(combine_contexts, accelerator.autocast, sync_context)

.\lucidrains\pytorch-custom-utils\pytorch_custom_utils\get_adam_optimizer.py

# Tuple type for the betas hyperparameter annotation
from typing import Tuple
# NOTE(review): the original comment announced importing AdamW and Adam from
# torch.optim but the import statement itself was missing; both names are
# used below, so the import is restored here
from torch.optim import AdamW, Adam

# optimizer

# partition parameters by dimensionality: matrices and above receive weight
# decay, vectors/scalars (norm scales, biases) do not
def separate_weight_decayable_params(params):
    """Split `params` into (decayable, non-decayable) lists by tensor ndim."""
    params = list(params)
    wd_params = [p for p in params if p.ndim >= 2]
    no_wd_params = [p for p in params if p.ndim < 2]
    return wd_params, no_wd_params

# build an Adam / AdamW optimizer, optionally excluding gammas and betas
# (norm scales and biases, i.e. tensors with ndim < 2) from weight decay
def get_adam_optimizer(
    params,
    lr: float = 1e-4,
    wd: float = 1e-2,
    betas: Tuple[float, float] = (0.9, 0.99),  # fixed annotation: betas are floats, not ints
    eps: float = 1e-8,
    filter_by_requires_grad = False,
    omit_gammas_and_betas_from_wd = True,
    **kwargs
):
    """Construct an Adam or AdamW optimizer for `params`.

    Args:
        params: iterable of parameters to optimize
        lr: learning rate
        wd: weight decay; a value <= 0 selects plain Adam
        betas: Adam beta coefficients
        eps: numerical-stability epsilon
        filter_by_requires_grad: drop parameters with requires_grad == False
        omit_gammas_and_betas_from_wd: exclude ndim < 2 tensors from weight decay

    Returns:
        torch.optim.Adam or torch.optim.AdamW
    """
    has_weight_decay = wd > 0.

    if filter_by_requires_grad:
        params = [t for t in params if t.requires_grad]

    # NOTE(review): **kwargs is accepted but not forwarded to the optimizer -
    # preserved as-is for interface compatibility
    opt_kwargs = dict(
        lr = lr,
        betas = betas,
        eps = eps
    )

    # no weight decay -> plain Adam
    if not has_weight_decay:
        return Adam(params, **opt_kwargs)

    opt_kwargs = {'weight_decay': wd, **opt_kwargs}

    # decay everything uniformly
    if not omit_gammas_and_betas_from_wd:
        return AdamW(params, **opt_kwargs)

    # early transformer practice: omit betas and gammas (ndim < 2 tensors)
    # from weight decay by putting them in a zero-decay param group
    wd_params, no_wd_params = separate_weight_decayable_params(params)

    params = [
        {'params': wd_params},
        {'params': no_wd_params, 'weight_decay': 0},
    ]

    return AdamW(params, **opt_kwargs)

.\lucidrains\pytorch-custom-utils\pytorch_custom_utils\module_device.py

# 导入必要的模块
from functools import wraps
from typing import List
from optree import tree_flatten, tree_unflatten

import torch
from torch import is_tensor
from torch.nn import Module

# Gives a torch.nn.Module subclass a `.device` property by registering a
# tiny non-persistent scalar buffer and reporting the device it sits on.

def module_device(
    device_property_name = 'device'
):
    def decorator(klass):
        assert issubclass(klass, Module), 'should decorate a subclass of torch.nn.Module'

        orig_init = klass.__init__

        @wraps(orig_init)
        def patched_init(self, *args, **kwargs):
            orig_init(self, *args, **kwargs)

            # scalar buffer that tracks the module across .to() / .cuda() moves
            self.register_buffer('_dummy', torch.tensor(0), persistent = False)

        def read_device(self):
            # wherever the dummy buffer lives is where the module lives
            return self._dummy.device

        klass.__init__ = patched_init
        setattr(klass, device_property_name, property(read_device))
        return klass

    return decorator

# Class decorator that moves every tensor argument of the listed methods to
# the module's device before invoking the original method.

def autocast_device(
    methods: List[str] = ['forward']
):
    def decorator(klass):
        assert issubclass(klass, Module), 'should decorate a subclass of torch.nn.Module'

        def make_wrapper(orig_fn):
            # fix: a factory is required here — defining the closure directly in
            # the loop below would late-bind `orig_fn`, making every decorated
            # method dispatch to the last one in `methods`
            @wraps(orig_fn)
            def fn(self, *args, **kwargs):
                # determine device: prefer the `_dummy` buffer registered by the
                # `module_device` decorator, else fall back to the first parameter
                if hasattr(self, '_dummy'):
                    device = self._dummy.device
                else:
                    device = next(self.parameters()).device

                # flatten args/kwargs so nested tensors are found too
                flattened_args, tree_spec = tree_flatten([args, kwargs])

                # move any tensor leaf to the target device
                maybe_transformed_args = []

                for flattened_arg in flattened_args:
                    if is_tensor(flattened_arg):
                        flattened_arg = flattened_arg.to(device)

                    maybe_transformed_args.append(flattened_arg)

                # restore the original structure
                args, kwargs = tree_unflatten(tree_spec, maybe_transformed_args)

                # fix: return the original method's result (it was silently dropped)
                return orig_fn(self, *args, **kwargs)

            return fn

        for method in methods:
            setattr(klass, method, make_wrapper(getattr(klass, method)))

        return klass

    return decorator

.\lucidrains\pytorch-custom-utils\pytorch_custom_utils\optimizer_scheduler_warmup.py

# 导入所需的模块和类
from contextlib import nullcontext
from typing import Optional, Type
from accelerate import Accelerator
from functools import partial
from torch import nn
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, _LRScheduler
import pytorch_warmup as warmup

def exists(v):
    """Return True when `v` holds a value, i.e. it is not None."""
    return v is not None

# A do-nothing schedule: LambdaLR with a constant multiplicative factor of 1.0,
# so the learning rate is left exactly as the optimizer set it every step.
ConstantLRScheduler = partial(LambdaLR, lr_lambda = lambda step: 1.)

# Bundles an optimizer with an LR scheduler, optional linear warmup, and
# optional gradient clipping, all prepared through HF Accelerate.
class OptimizerWithWarmupSchedule(nn.Module):
    def __init__(
        self,
        accelerator: Accelerator,
        optimizer: Optimizer,
        scheduler: Optional[Type[_LRScheduler]] = None,
        scheduler_kwargs: dict = dict(),
        warmup_steps: int = 0,
        max_grad_norm: Optional[float] = None
    ):
        super().__init__()
        self.max_grad_norm = max_grad_norm

        # linear warmup only when a positive warmup period was requested
        self.warmup = warmup.LinearWarmup(optimizer, warmup_period = warmup_steps) if warmup_steps > 0 else None

        # fall back to a constant schedule when no scheduler class is given
        if exists(scheduler):
            self.scheduler = scheduler(optimizer, **scheduler_kwargs)
        else:
            self.scheduler = ConstantLRScheduler(optimizer)

        self.optimizer = optimizer

        # let accelerate wrap both so they cooperate with mixed precision / DDP
        self.optimizer, self.scheduler = accelerator.prepare(self.optimizer, self.scheduler)
        self.accelerator = accelerator

    def state_dict(self):
        """Package optimizer / scheduler (and warmup, when present) states."""
        pkg = dict(
            optimizer = self.optimizer.state_dict(),
            scheduler = self.scheduler.state_dict()
        )

        if exists(self.warmup):
            pkg.update(warmup = self.warmup.state_dict())

        return pkg

    def load_state_dict(self, pkg):
        """Restore the states saved by `state_dict`."""
        self.optimizer.load_state_dict(pkg['optimizer'])
        self.scheduler.load_state_dict(pkg['scheduler'])

        if exists(self.warmup):
            self.warmup.load_state_dict(pkg['warmup'])

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step(self):
        # clip each param group before stepping, if a max norm was configured
        if exists(self.max_grad_norm):
            for group in self.optimizer.param_groups:
                self.accelerator.clip_grad_norm_(group['params'], self.max_grad_norm)

        self.optimizer.step()

        # only advance the schedule when accelerate did not skip the step
        # (e.g. grad-scaler overflow); dampen the LR while warming up
        if not self.accelerator.optimizer_step_was_skipped:
            ctx = self.warmup.dampening if exists(self.warmup) else nullcontext

            with ctx():
                self.scheduler.step()

.\lucidrains\pytorch-custom-utils\pytorch_custom_utils\save_load.py

# 导入所需的模块
import pickle
from functools import wraps
from pathlib import Path
from packaging import version
import torch
from torch.nn import Module
from beartype import beartype
from beartype.typing import Optional

def exists(v):
    """True unless `v` is None."""
    return v is not None

# Class decorator adding save / load / init_and_load methods to a
# torch.nn.Module subclass. Constructor args are pickled on init so the
# model can later be rebuilt from its checkpoint alone.
@beartype
def save_load(
    save_method_name = 'save',
    load_method_name = 'load',
    config_instance_var_name = '_config',
    init_and_load_classmethod_name = 'init_and_load',
    version: Optional[str] = None
):
    def _save_load(klass):
        assert issubclass(klass, Module), 'save_load should decorate a subclass of torch.nn.Module'

        _orig_init = klass.__init__

        @wraps(_orig_init)
        def __init__(self, *args, **kwargs):
            # capture constructor arguments so init_and_load can rebuild the model
            _config = pickle.dumps((args, kwargs))
            setattr(self, config_instance_var_name, _config)
            _orig_init(self, *args, **kwargs)

        def _save(self, path, overwrite = True):
            """Save state dict, pickled config, and package version to `path`."""
            path = Path(path)
            assert overwrite or not path.exists()

            pkg = dict(
                model = self.state_dict(),
                config = getattr(self, config_instance_var_name),
                version = version,
            )

            torch.save(pkg, str(path))

        def _load(self, path, strict = True):
            """Load weights from `path` into this instance, warning on version skew."""
            path = Path(path)
            assert path.exists()

            pkg = torch.load(str(path), map_location = 'cpu')

            # fix: the decorator's `version` str parameter shadowed the
            # `packaging.version` module, so the old `version.parse(version)`
            # raised AttributeError; the old message also referenced an
            # undefined `__version__` and an unguaranteed `self.print`
            saved_version = pkg.get('version')

            if exists(version) and exists(saved_version):
                from packaging.version import parse as _parse_version

                if _parse_version(version) != _parse_version(saved_version):
                    print(f'loading saved model at version {saved_version}, but current package version is {version}')

            self.load_state_dict(pkg['model'], strict = strict)

        @classmethod
        def _init_and_load_from(cls, path, strict = True):
            """Instantiate the class from the saved config, then load weights."""
            path = Path(path)
            assert path.exists()
            pkg = torch.load(str(path), map_location = 'cpu')

            assert 'config' in pkg, 'model configs were not found in this saved checkpoint'

            args, kwargs = pickle.loads(pkg['config'])
            model = cls(*args, **kwargs)

            _load(model, path, strict = strict)
            return model

        # install the patched constructor and the new methods
        klass.__init__ = __init__
        setattr(klass, save_method_name, _save)
        setattr(klass, load_method_name, _load)
        setattr(klass, init_and_load_classmethod_name, _init_and_load_from)

        return klass

    return _save_load

.\lucidrains\pytorch-custom-utils\pytorch_custom_utils\total_parameters.py

# 从 torch.nn 模块中导入 Module 类
from torch.nn import Module

# Adds a parameter-count property (default name `total_parameters`) to a
# torch.nn.Module subclass, summing numel() over all (or only trainable) params.

def total_parameters(
    count_only_requires_grad = False,
    total_parameters_property_name = 'total_parameters'
):
    def decorator(klass):
        assert issubclass(klass, Module), 'should decorate a subclass of torch.nn.Module'

        if count_only_requires_grad:
            # count only trainable parameters
            def count(self):
                return sum(p.numel() for p in self.parameters() if p.requires_grad)
        else:
            # count every parameter
            def count(self):
                return sum(p.numel() for p in self.parameters())

        setattr(klass, total_parameters_property_name, property(count))
        return klass

    return decorator

.\lucidrains\pytorch-custom-utils\pytorch_custom_utils\utils.py

# 导入所需的模块
from typing import Tuple
import torch.nn.functional as F

def exists(v):
    """Return True when `v` is not None."""
    return v is not None

# pad and slice helpers

def pad_at_dim(t, pad: Tuple[int, int], *, dim = -1, value = 0.):
    """Pad tensor `t` by (left, right) amounts along `dim`, filling with `value`."""
    # F.pad takes (left, right) pairs ordered from the LAST dim backwards,
    # so prepend a (0, 0) pair for every dim to the right of the target dim
    if dim < 0:
        num_dims_after = -dim - 1
    else:
        num_dims_after = t.ndim - dim - 1

    pad_spec = (0, 0) * num_dims_after + tuple(pad)
    return F.pad(t, pad_spec, value = value)

def slice_at_dim(t, dim_slice: slice, *, dim):
    """Apply `dim_slice` along dimension `dim` of tensor `t`."""
    # normalize a negative dim, then index with slice(None) everywhere
    # except the target dimension
    if dim < 0:
        dim = dim + t.ndim

    index = [slice(None)] * t.ndim
    index[dim] = dim_slice
    return t[tuple(index)]

def pad_or_slice_to(t, length, *, dim, pad_value = 0):
    """Force dimension `dim` of `t` to exactly `length`.

    Right-pads with `pad_value` when too short, truncates when too long,
    and returns `t` untouched when already the right size.
    """
    current = t.shape[dim]

    if current < length:
        return pad_at_dim(t, (0, length - current), dim = dim, value = pad_value)

    if current > length:
        return slice_at_dim(t, slice(0, length), dim = dim)

    return t

# mask related

def masked_mean(tensor, mask, dim = -1, eps = 1e-5):
    """Mean of `tensor` over `dim`, counting only positions where `mask` is True.

    Rows whose mask is entirely False yield 0. When `mask` is None this is a
    plain mean. The input tensor is left unmodified.
    """
    if mask is None:
        return tensor.mean(dim = dim)

    # fix: use the out-of-place masked_fill so the caller's tensor is not
    # mutated as a side effect (the old masked_fill_ zeroed it in place)
    tensor = tensor.masked_fill(~mask, 0.)

    total_el = mask.sum(dim = dim)
    num = tensor.sum(dim = dim)
    # clamp guards the division when a row has no unmasked elements
    den = total_el.float().clamp(min = eps)
    mean = num / den
    # rows with no unmasked elements are defined to be 0
    mean = mean.masked_fill(total_el == 0, 0.)
    return mean

# logical-AND together any number of optional masks

def maybe_and_mask(*masks):
    """AND-combine the given masks, skipping Nones; returns None if all are None."""
    present = [m for m in masks if m is not None]

    if not present:
        return None

    combined = present[0]
    for other in present[1:]:
        combined = combined & other

    return combined

.\lucidrains\pytorch-custom-utils\pytorch_custom_utils\__init__.py

# 从 pytorch_custom_utils.module_device 模块中导入 module_device 和 autocast_device 函数
from pytorch_custom_utils.module_device import (
    module_device,
    autocast_device
)

# 从 pytorch_custom_utils.save_load 模块中导入 save_load 函数
from pytorch_custom_utils.save_load import save_load

# 从 pytorch_custom_utils.total_parameters 模块中导入 total_parameters 函数
from pytorch_custom_utils.total_parameters import total_parameters

# 从 pytorch_custom_utils.get_adam_optimizer 模块中导入 get_adam_optimizer 函数
from pytorch_custom_utils.get_adam_optimizer import get_adam_optimizer

# 从 pytorch_custom_utils.optimizer_scheduler_warmup 模块中导入 OptimizerWithWarmupSchedule 类
from pytorch_custom_utils.optimizer_scheduler_warmup import OptimizerWithWarmupSchedule

# 从 pytorch_custom_utils.accelerate_utils 模块中导入 add_wandb_tracker_contextmanager 和 auto_unwrap_model 函数
from pytorch_custom_utils.accelerate_utils import (
    add_wandb_tracker_contextmanager,
    auto_unwrap_model
)

Pytorch Custom Utils (wip)

Just some miscellaneous utility functions / decorators / modules related to Pytorch and Accelerate to help speed up implementation of new AI research

Install

$ pip install pytorch-custom-utils

Quick save and load

Class decorator for adding a quick save and load method to the module instance. Can also initialize the entire network with a class method, init_and_load.

ex.

import torch
from torch import nn

from pytorch_custom_utils import save_load

# decorate the entire class with `save_load` class decorator

@save_load()
class MLP(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)

# instantiated mlp

mlp = MLP(dim = 512)

# now you have a save and load method

mlp.save('./mlp.pt')
mlp.load('./mlp.pt')

# you can also directly initialize from the checkpoint, without having to save the corresponding hyperparameters (in this case, dim = 512)

mlp = MLP.init_and_load('./mlp.pt')

Keep track of device on module

ex.

import torch
from torch import nn

from pytorch_custom_utils import module_device

# decorate the class with `module_device` class decorator

@module_device()
class MLP(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Linear(dim, dim)

    def forward(self, x):
        return self.net(x)

# instantiated mlp

mlp = MLP(dim = 512)
mlp.to(torch.device('mps'))

# now you have a convenient .device

mlp.device # mps:0

.\lucidrains\pytorch-custom-utils\setup.py

# 导入设置工具和查找包工具
from setuptools import setup, find_packages

# Package metadata for distribution via setuptools
setup(
  name = 'pytorch-custom-utils',  # distribution name on PyPI
  packages = find_packages(exclude=[]),  # include every package in the repo
  version = '0.0.18',  # release version
  license='MIT',  # license identifier
  description = 'Pytorch Custom Utils',  # short summary
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author contact
  long_description_content_type = 'text/markdown',  # README is markdown
  url = 'https://github.com/lucidrains/pytorch-custom-utils',  # project homepage
  keywords = [
    'pytorch',  # search keyword
    'accelerate'  # search keyword
  ],
  install_requires=[
    'accelerate',  # runtime dependency
    'optree',  # runtime dependency
    'pytorch-warmup',  # runtime dependency
    'torch>=2.0'  # runtime dependency
  ],
  classifiers=[
    'Development Status :: 4 - Beta',  # trove classifier
    'Intended Audience :: Developers',  # trove classifier
    'Topic :: Scientific/Engineering :: Artificial Intelligence',  # trove classifier
    'License :: OSI Approved :: MIT License',  # trove classifier
    'Programming Language :: Python :: 3.6',  # trove classifier
  ],
)

.\lucidrains\q-transformer\q_transformer\agent.py

# 导入必要的库
import sys
from pathlib import Path

# 导入 numpy 的相关模块
from numpy.lib.format import open_memmap

# 导入 torch 相关模块
import torch
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
from torch.utils.data import Dataset

# 导入 einops 库
from einops import rearrange

# 导入自定义的 QRoboticTransformer 类
from q_transformer.q_robotic_transformer import QRoboticTransformer

# 导入 torchtyping 库
from torchtyping import TensorType

# 导入 beartype 库
from beartype import beartype
from beartype.typing import Iterator, Tuple, Union

# 导入 tqdm 库
from tqdm import tqdm

# memmap-backed replay storage can exceed 2GB, which requires a 64-bit platform
assert sys.maxsize > (2 ** 32), 'you need to be on 64 bit system to store > 2GB experience for your q-transformer agent'

# filenames of the memory-mapped replay buffers shared by Agent (writer)
# and ReplayMemoryDataset (reader)
TEXT_EMBEDS_FILENAME = 'text_embeds.memmap.npy'
STATES_FILENAME = 'states.memmap.npy'
ACTIONS_FILENAME = 'actions.memmap.npy'
REWARDS_FILENAME = 'rewards.memmap.npy'
DONES_FILENAME = 'dones.memmap.npy'

# default location for the replay memories on disk
DEFAULT_REPLAY_MEMORIES_FOLDER = './replay_memories_data'

# helper functions
def exists(v):
    """A value counts as existing unless it is None."""
    return v is not None

def cast_tuple(t):
    """Wrap non-tuples into a single-element tuple; pass tuples through."""
    if isinstance(t, tuple):
        return t
    return (t,)

# dataset over the memory-mapped replay buffers written by the Agent below
class ReplayMemoryDataset(Dataset):
    @beartype
    def __init__(
        self,
        folder: str = DEFAULT_REPLAY_MEMORIES_FOLDER,
        num_timesteps: int = 1
    ):
        # each sample is a window of at least one timestep
        assert num_timesteps >= 1
        self.is_single_timestep = num_timesteps == 1
        self.num_timesteps = num_timesteps

        # the replay folder must have been created (by a prior Agent run)
        folder = Path(folder)
        assert folder.exists() and folder.is_dir()

        # open all memmapped buffers read-only
        text_embeds_path = folder / TEXT_EMBEDS_FILENAME
        states_path = folder / STATES_FILENAME
        actions_path = folder / ACTIONS_FILENAME
        rewards_path = folder / REWARDS_FILENAME
        dones_path = folder / DONES_FILENAME

        self.text_embeds = open_memmap(str(text_embeds_path), dtype='float32', mode='r')
        self.states = open_memmap(str(states_path), dtype='float32', mode='r')
        self.actions = open_memmap(str(actions_path), dtype='int', mode='r')
        self.rewards = open_memmap(str(rewards_path), dtype='float32', mode='r')
        self.dones = open_memmap(str(dones_path), dtype='bool', mode='r')

        # NOTE(review): duplicate assignment — num_timesteps was already stored above
        self.num_timesteps = num_timesteps

        # episode length = number of steps before the first done flag, plus one
        self.episode_length = (self.dones.cumsum(axis=-1) == 0).sum(axis=-1) + 1

        # keep only episodes long enough to yield a full num_timesteps window
        trainable_episode_indices = self.episode_length >= num_timesteps

        self.text_embeds = self.text_embeds[trainable_episode_indices]
        self.states = self.states[trainable_episode_indices]
        self.actions = self.actions[trainable_episode_indices]
        self.rewards = self.rewards[trainable_episode_indices]
        self.dones = self.dones[trainable_episode_indices]

        self.episode_length = self.episode_length[trainable_episode_indices]

        # at least one usable episode must remain
        assert self.dones.size > 0, 'no trainable episodes'

        self.num_episodes, self.max_episode_len = self.dones.shape

        timestep_arange = torch.arange(self.max_episode_len)

        # all (episode, timestep) coordinate pairs
        timestep_indices = torch.stack(torch.meshgrid(
            torch.arange(self.num_episodes),
            timestep_arange
        ), dim=-1)

        # a start index is trainable only when a full window fits inside the episode
        trainable_mask = timestep_arange < rearrange(torch.from_numpy(self.episode_length) - num_timesteps, 'e -> e 1')
        self.indices = timestep_indices[trainable_mask]

    # number of (episode, start-timestep) samples
    def __len__(self):
        return self.indices.shape[0]
    # fetch one window of experience: (text_embeds, states, actions, next_state, rewards, dones)
    def __getitem__(self, idx):
        # locate the episode and the window's starting timestep
        episode_index, timestep_index = self.indices[idx]

        # window [timestep_index, timestep_index + num_timesteps)
        timestep_slice = slice(timestep_index, (timestep_index + self.num_timesteps))

        # .copy() detaches each slice from the read-only memmap
        text_embeds = self.text_embeds[episode_index, timestep_slice].copy()
        states = self.states[episode_index, timestep_slice].copy()
        actions = self.actions[episode_index, timestep_slice].copy()
        rewards = self.rewards[episode_index, timestep_slice].copy()
        dones = self.dones[episode_index, timestep_slice].copy()

        # NOTE(review): this reads the state at the window's *start* (clamped to the
        # last step); a true "next state" would presumably sit at
        # timestep_index + num_timesteps — confirm the intended indexing
        next_state = self.states[episode_index, min(timestep_index, self.max_episode_len - 1)].copy()

        return text_embeds, states, actions, next_state, rewards, dones
# base environment class, meant to be subclassed by concrete environments
class BaseEnvironment(Module):
    # runtime type-checked constructor
    @beartype
    def __init__(
        self,
        *,
        state_shape: Tuple[int, ...],
        text_embed_shape: Union[int, Tuple[int, ...]]
    ):
        super().__init__()
        # shapes of observations and text-conditioning embeddings,
        # consumed by Agent when allocating its replay memmaps
        self.state_shape = state_shape
        self.text_embed_shape = cast_tuple(text_embed_shape)
        # dummy buffer whose only job is tracking the module's device
        self.register_buffer('dummy', torch.zeros(0), persistent=False)

    # device the environment module currently lives on
    @property
    def device(self):
        return self.dummy.device

    # reset the environment; must return (instruction text, initial state)
    def init(self) -> Tuple[str, Tensor]:
        raise NotImplementedError

    # step the environment with `actions`; must return (reward, next state, done)
    def forward(
        self,
        actions: Tensor
    ) -> Tuple[
        TensorType[(), float],     # reward
        Tensor,                    # next state
        TensorType[(), bool]       # done
    ]:
        raise NotImplementedError

# agent that rolls out episodes in the environment with epsilon-greedy
# actions and records the transitions into memory-mapped replay buffers
class Agent(Module):
    # runtime type-checked constructor
    @beartype
    def __init__(
        self,
        q_transformer: QRoboticTransformer,
        *,
        environment: BaseEnvironment,
        memories_dataset_folder: str = DEFAULT_REPLAY_MEMORIES_FOLDER,
        num_episodes: int = 1000,
        max_num_steps_per_episode: int = 10000,
        epsilon_start: float = 0.25,
        epsilon_end: float = 0.001,
        num_steps_to_target_epsilon: int = 1000
    ):
        super().__init__()
        self.q_transformer = q_transformer
        # whether actions are conditioned on a text instruction
        condition_on_text = q_transformer.condition_on_text
        self.condition_on_text = condition_on_text
        self.environment = environment

        # environment must expose the shapes needed to size the memmaps
        assert hasattr(environment, 'state_shape') and hasattr(environment, 'text_embed_shape')

        # epsilon-greedy exploration schedule bounds
        assert 0. <= epsilon_start <= 1.
        assert 0. <= epsilon_end <= 1.
        assert epsilon_start >= epsilon_end

        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.num_steps_to_target_epsilon = num_steps_to_target_epsilon
        # negative slope: epsilon decays linearly from start to end
        self.epsilon_slope = (epsilon_end - epsilon_start) / num_steps_to_target_epsilon

        self.num_episodes = num_episodes
        self.max_num_steps_per_episode = max_num_steps_per_episode

        # folder that will hold the memory-mapped replay buffers
        mem_path = Path(memories_dataset_folder)
        self.memories_dataset_folder = mem_path

        mem_path.mkdir(exist_ok=True, parents=True)
        assert mem_path.is_dir()

        # file paths for states, actions, rewards and done flags
        states_path = mem_path / STATES_FILENAME
        actions_path = mem_path / ACTIONS_FILENAME
        rewards_path = mem_path / REWARDS_FILENAME
        dones_path = mem_path / DONES_FILENAME

        # leading (episode, step) shape shared by all buffers
        prec_shape = (num_episodes, max_num_steps_per_episode)
        num_actions = q_transformer.num_actions
        state_shape = environment.state_shape

        # when text-conditioned, also allocate a buffer for instruction embeddings
        if condition_on_text:
            text_embeds_path = mem_path / TEXT_EMBEDS_FILENAME
            text_embed_shape = environment.text_embed_shape
            self.text_embed_shape = text_embed_shape
            # writable memmap for the per-step text embeddings
            self.text_embeds = open_memmap(str(text_embeds_path), dtype='float32', mode='w+', shape=(*prec_shape, *text_embed_shape))

        # pre-allocate writable memmaps for states / actions / rewards / dones
        self.states = open_memmap(str(states_path), dtype='float32', mode='w+', shape=(*prec_shape, *state_shape))
        self.actions = open_memmap(str(actions_path), dtype='int', mode='w+', shape=(*prec_shape, num_actions))
        self.rewards = open_memmap(str(rewards_path), dtype='float32', mode='w+', shape=prec_shape)
        self.dones = open_memmap(str(dones_path), dtype='bool', mode='w+', shape=prec_shape)

    # linearly decayed epsilon for the given step, clamped below at epsilon_end
    def get_epsilon(self, step):
        return max(self.epsilon_end, self.epsilon_slope * float(step) + self.epsilon_start)

    # roll out episodes and record every transition; gradients are not needed
    @beartype
    @torch.no_grad()
    # run the data-collection loop over all configured episodes
    def forward(self):
        # inference mode for the Q-transformer
        self.q_transformer.eval()

        for episode in range(self.num_episodes):
            # log which episode is running
            print(f'episode {episode}')

            # reset environment, obtaining the instruction and initial state
            instruction, curr_state = self.environment.init()

            for step in tqdm(range(self.max_num_steps_per_episode)):
                # whether this is the final permitted step of the episode
                last_step = step == (self.max_num_steps_per_episode - 1)

                # exploration probability for this step
                epsilon = self.get_epsilon(step)

                # no text embedding unless the model is text-conditioned
                text_embed = None

                # embed the instruction when conditioning on text
                if self.condition_on_text:
                    # embed the instruction once per step
                    text_embed = self.q_transformer.embed_texts([instruction])

                # pick actions (random with probability epsilon)
                actions = self.q_transformer.get_actions(
                    rearrange(curr_state, '... -> 1 ...'),
                    text_embeds = text_embed,
                    prob_random_action = epsilon
                )

                # step the environment
                reward, next_state, done = self.environment(actions)

                # force episode termination on the final permitted step
                done = done | last_step

                # persist the transition to the memmaps for later replay learning

                # store the text embedding when conditioning on text
                if self.condition_on_text:
                    # sanity-check the embedding shape against the buffer
                    assert text_embed.shape[1:] == self.text_embed_shape
                    # write the embedding at (episode, step)
                    self.text_embeds[episode, step] = text_embed

                # write state, actions, reward and done flag at (episode, step)
                self.states[episode, step]      = curr_state
                self.actions[episode, step]     = actions
                self.rewards[episode, step]     = reward
                self.dones[episode, step]       = done

                # stop this episode once done
                if done:
                    break

                # advance to the next state
                curr_state = next_state

            # flush buffered writes to disk after each episode
            if self.condition_on_text:
                # flush the text-embedding buffer
                self.text_embeds.flush()

            # flush the remaining buffers
            self.states.flush()
            self.actions.flush()
            self.rewards.flush()
            self.dones.flush()

        # drop references so the memmaps are released

        # release the text-embedding buffer when present
        if self.condition_on_text:
            # delete the text-embedding memmap reference
            del self.text_embeds

        # delete the remaining memmap references
        del self.states
        del self.actions
        del self.rewards
        del self.dones

        # report where the collected memories were written
        print(f'completed, memories stored to {self.memories_dataset_folder.resolve()}')

.\lucidrains\q-transformer\q_transformer\attend.py

# 导入所需的模块和函数
from functools import wraps
from packaging import version

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, reduce

# decorator ensuring the wrapped function runs at most once;
# subsequent calls return None without invoking the function

def once(fn):
    has_run = False

    @wraps(fn)
    def inner(x):
        nonlocal has_run
        if has_run:
            return
        has_run = True
        return fn(x)

    return inner

# a print that only ever fires the first time it is called
print_once = once(print)

# helper: a value exists unless it is None
def exists(val):
    """Return True when `val` is not None."""
    return val is not None

# fall back to `d` only when `val` is None
def default(val, d):
    """Return `val` unless it is None, in which case return `d`."""
    return d if val is None else val

# AND-combine any number of optional masks into a single mask

def maybe_reduce_mask_and(*maybe_masks):
    """Logical-AND the given masks, skipping Nones; None if all are None."""
    present = [m for m in maybe_masks if m is not None]

    if not present:
        return None

    combined = present[0]
    for other in present[1:]:
        combined = combined & other

    return combined

# main attention module, with an optional pytorch 2.0 flash-attention path
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        flash = False,
        causal = False,
        flash_config: dict = dict(
            enable_flash = True,
            enable_math = True,
            enable_mem_efficient = True
        )
    ):
        super().__init__()
        # raw rate kept separately so the flash path can pass it directly
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        self.flash = flash
        # flash attention relies on F.scaled_dot_product_attention (torch >= 2.0)
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        if flash:
            print_once('using memory efficient attention')

        # kernel backends allowed inside the sdp_kernel context
        self.flash_config = flash_config

    # flash-attention path via pytorch 2.0 scaled_dot_product_attention;
    # expects q/k/v shaped (batch, heads, seq, dim_head)
    def flash_attn(self, q, k, v, mask = None, attn_mask = None):
        _, heads, q_len, dim_head, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # expand the (b, 1, 1, j) key-padding mask across heads and queries
        if exists(mask):
            mask = mask.expand(-1, heads, q_len, -1)

        # AND in the attention mask, if any (None stays None)
        mask = maybe_reduce_mask_and(mask, attn_mask)

        # run attention under the configured sdp kernel context
        with torch.backends.cuda.sdp_kernel(**self.flash_config):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                is_causal = self.causal,
                dropout_p = self.dropout if self.training else 0.
            )

        return out

    # full attention forward; falls through to flash_attn when enabled
    def forward(self, q, k, v, mask = None, attn_mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        # standard scaled dot-product scale factor (1 / sqrt(dim_head))
        scale = q.shape[-1] ** -0.5

        # broadcast a (b, j) key-padding mask to (b, 1, 1, j)
        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, 'b j -> b 1 1 j')

        if self.flash:
            return self.flash_attn(q, k, v, mask = mask, attn_mask = attn_mask)

        # similarity logits
        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale

        # causal mask: a query may not attend to keys after its position
        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # key padding mask
        if exists(mask):
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # arbitrary attention mask
        if exists(attn_mask):
            sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)

        # attention weights
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)

        return out