Lucidrains Series Projects Source Code Analysis (Part 30)


Electra - Pytorch

A simple working wrapper for fast pretraining of language models as detailed in this paper. It speeds up training (in comparison to normal masked language modeling) by a factor of 4x, and eventually reaches better performance if trained for even longer. Special thanks to Erik Nijkamp for taking the time to replicate the results for GLUE.

Install

$ pip install electra-pytorch

Usage

The following example uses reformer-pytorch, which is available to be pip installed.

import torch
from torch import nn
from reformer_pytorch import ReformerLM

from electra_pytorch import Electra

# (1) instantiate the generator and discriminator, making sure that the generator is roughly a quarter to a half of the size of the discriminator

generator = ReformerLM(
    num_tokens = 20000,
    emb_dim = 128,
    dim = 256,              # smaller hidden dimension
    heads = 4,              # fewer heads
    ff_mult = 2,            # smaller feed forward intermediate dimension
    dim_head = 64,
    depth = 12,
    max_seq_len = 1024
)

discriminator = ReformerLM(
    num_tokens = 20000,
    emb_dim = 128,
    dim = 1024,
    dim_head = 64,
    heads = 16,
    depth = 12,
    ff_mult = 4,
    max_seq_len = 1024
)

# (2) weight tie the token and positional embeddings of generator and discriminator

generator.token_emb = discriminator.token_emb
generator.pos_emb = discriminator.pos_emb
# weight tie any other embeddings if available, token type embeddings, etc.

# (3) instantiate electra

trainer = Electra(
    generator,
    discriminator,
    discr_dim = 1024,           # the embedding dimension of the discriminator
    discr_layer = 'reformer',   # the layer name in the discriminator whose output will be used to predict whether each token is original or replaced
    mask_token_id = 2,          # the token id reserved for masking
    pad_token_id = 0,           # the token id for padding
    mask_prob = 0.15,           # masking probability for masked language modeling
    mask_ignore_token_ids = []  # ids of tokens to ignore for mask modeling ex. (cls, sep)
)

# (4) train

data = torch.randint(0, 20000, (1, 1024))

results = trainer(data)
results.loss.backward()

# after much training, the discriminator should have improved

torch.save(discriminator, f'./pretrained-model.pt')

If you would rather not have the framework auto-magically intercept the hidden output of the discriminator, you can pass in the discriminator (with an extra linear layer of shape [dim x 1]) yourself, as follows.

import torch
from torch import nn
from reformer_pytorch import ReformerLM

from electra_pytorch import Electra

# (1) instantiate the generator and discriminator, making sure that the generator is roughly a quarter to a half of the size of the discriminator

generator = ReformerLM(
    num_tokens = 20000,
    emb_dim = 128,
    dim = 256,              # smaller hidden dimension
    heads = 4,              # fewer heads
    ff_mult = 2,            # smaller feed forward intermediate dimension
    dim_head = 64,
    depth = 12,
    max_seq_len = 1024
)

discriminator = ReformerLM(
    num_tokens = 20000,
    emb_dim = 128,
    dim = 1024,
    dim_head = 64,
    heads = 16,
    depth = 12,
    ff_mult = 4,
    max_seq_len = 1024,
    return_embeddings = True
)

# (2) weight tie the token and positional embeddings of generator and discriminator

generator.token_emb = discriminator.token_emb
generator.pos_emb = discriminator.pos_emb
# weight tie any other embeddings if available, token type embeddings, etc.

# (3) instantiate electra

discriminator_with_adapter = nn.Sequential(discriminator, nn.Linear(1024, 1))

trainer = Electra(
    generator,
    discriminator_with_adapter,
    mask_token_id = 2,          # the token id reserved for masking
    pad_token_id = 0,           # the token id for padding
    mask_prob = 0.15,           # masking probability for masked language modeling
    mask_ignore_token_ids = []  # ids of tokens to ignore for mask modeling ex. (cls, sep)
)

# (4) train

data = torch.randint(0, 20000, (1, 1024))

results = trainer(data)
results.loss.backward()

# after much training, the discriminator should have improved

torch.save(discriminator, f'./pretrained-model.pt')

Important details for successful training

The generator should be roughly a quarter to at most one half of the discriminator's size for effective training. Any greater and the generator will be too good and the adversarial game collapses. This was done by reducing the hidden dimension, feed forward hidden dimension, and number of attention heads in the paper.
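One way to keep this ratio straight is to derive the generator's hyperparameters from the discriminator's. The sketch below is illustrative only (the scaling `factor` of 4 and the particular kwargs are assumptions, reusing the ReformerLM settings from the usage example above); it reduces the hidden dimension, feedforward multiplier and head count together.

from reformer_pytorch import ReformerLM

# discriminator hyperparameters, same as in the usage example above
disc_kwargs = dict(num_tokens = 20000, emb_dim = 128, dim = 1024, dim_head = 64, heads = 16, ff_mult = 4, depth = 12, max_seq_len = 1024)

# hypothetical scaling factor of 4, putting the generator at roughly a quarter of the discriminator
factor = 4

gen_kwargs = dict(
    disc_kwargs,
    dim = disc_kwargs['dim'] // factor,      # smaller hidden dimension
    heads = disc_kwargs['heads'] // factor,  # fewer attention heads
    ff_mult = disc_kwargs['ff_mult'] // 2    # smaller feedforward multiplier
)

generator = ReformerLM(**gen_kwargs)
discriminator = ReformerLM(**disc_kwargs)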

Testing

$ python setup.py test

Training

  1. Download the OpenWebText dataset.
$ mkdir data
$ cd data
$ pip3 install gdown
$ gdown --id 1EA5V0oetDCOke7afsktL_JDQ-ETtNOvx
$ tar -xf openwebtext.tar.xz
$ wget https://storage.googleapis.com/electra-data/vocab.txt
$ cd ..
  2. Tokenize dataset.
$ python pretraining/openwebtext/preprocess.py
  3. Pre-train.
$ python pretraining/openwebtext/pretrain.py
  4. Download GLUE dataset.
$ python examples/glue/download.py
  5. Fine-tune on the MRPC sub-task of the GLUE benchmark.
$ python examples/glue/run.py --model_name_or_path output/yyyy-mm-dd-hh-mm-ss/ckpt/200000

Citations

@misc{clark2020electra,
    title={ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators},
    author={Kevin Clark and Minh-Thang Luong and Quoc V. Le and Christopher D. Manning},
    year={2020},
    eprint={2003.10555},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}

.\lucidrains\electra-pytorch\setup.py

# import setuptools' setup and package discovery helpers
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'electra-pytorch',    # package name
  packages = find_packages(),  # find all packages
  version = '0.1.2',           # version number
  license='MIT',               # license
  description = 'Electra - Pytorch',  # description
  author = 'Erik Nijkamp, Phil Wang', # authors
  author_email = 'erik.nijkamp@gmail.com, lucidrains@gmail.com',  # author emails
  url = 'https://github.com/lucidrains/electra-pytorch',  # project url
  keywords = [
    'transformers',
    'artificial intelligence',
    'pretraining'
  ],
  install_requires=[
    'torch>=1.6.0',
    'transformers==3.0.2',
    'scipy',
    'sklearn'
  ],
  setup_requires=[
    'pytest-runner'
  ],
  tests_require=[
    'pytest',
    'reformer-pytorch'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.7',
  ],
)

.\lucidrains\electra-pytorch\tests\test_electra_pytorch.py

# import torch
import torch
# import the nn module from torch
from torch import nn
# import the ReformerLM class from reformer_pytorch
from reformer_pytorch import ReformerLM
# import the Electra class from electra_pytorch
from electra_pytorch import Electra
# test the Electra model end to end
def test_electra():
    # create the generator ReformerLM model
    generator = ReformerLM(
        num_tokens = 20000,
        dim = 512,
        depth = 1,
        max_seq_len = 1024
    )

    # create the discriminator ReformerLM model
    discriminator = ReformerLM(
        num_tokens = 20000,
        dim = 512,
        depth = 2,
        max_seq_len = 1024
    )

    # weight tie the generator's token embeddings to the discriminator's
    generator.token_emb = discriminator.token_emb
    # weight tie the positional embeddings as well
    generator.pos_emb = discriminator.pos_emb
    # create the Electra trainer
    trainer = Electra(
        generator,
        discriminator,
        num_tokens = 20000,
        discr_dim = 512,
        discr_layer = 'reformer',
        pad_token_id = 1,
        mask_ignore_token_ids = [2, 3]
    )

    # generate random token ids
    data = torch.randint(0, 20000, (1, 1024))
    # run a training forward pass
    results = trainer(data)
    # compute the loss and backpropagate
    results.loss.backward()

# test Electra without the automatic hidden-layer interception
def test_electra_without_magic():
    # create the generator ReformerLM model
    generator = ReformerLM(
        num_tokens = 20000,
        dim = 512,
        depth = 1,
        max_seq_len = 1024
    )

    # create the discriminator ReformerLM model
    discriminator = ReformerLM(
        num_tokens = 20000,
        dim = 512,
        depth = 2,
        max_seq_len = 1024,
        return_embeddings = True
    )

    # weight tie the generator's token embeddings to the discriminator's
    generator.token_emb = discriminator.token_emb
    # weight tie the positional embeddings as well
    generator.pos_emb = discriminator.pos_emb

    # wrap the discriminator with a linear adapter that outputs one logit per token
    discriminator_with_adapter = nn.Sequential(
        discriminator,
        nn.Linear(512, 1),
        nn.Sigmoid()
    )

    # create the Electra trainer
    trainer = Electra(
        generator,
        discriminator_with_adapter,
        num_tokens = 20000,
        pad_token_id = 1,
        mask_ignore_token_ids = [2, 3]
    )

    # generate random token ids
    data = torch.randint(0, 20000, (1, 1024))
    # run a training forward pass
    results = trainer(data)
    # compute the loss and backpropagate
    results.loss.backward()
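If you would rather invoke these tests directly (assuming pytest and reformer-pytorch are installed in your environment), an equivalent invocation is:

$ pip install pytest reformer-pytorch
$ pytest tests/test_electra_pytorch.py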

.\lucidrains\ema-pytorch\ema_pytorch\ema_pytorch.py

# import deepcopy and partial
from copy import deepcopy
from functools import partial

# import torch
import torch
# import nn and Tensor from torch
from torch import nn, Tensor
# import the Module class from torch.nn
from torch.nn import Module

# import beartype for runtime type checking
from beartype import beartype
# import the Set and Optional types from beartype.typing
from beartype.typing import Set, Optional

# helper to check whether a value exists
def exists(val):
    return val is not None

# get the device a module lives on
def get_module_device(m: Module):
    return next(m.parameters()).device

# copy a tensor in place
def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False):
    if auto_move_device:
        src = src.to(tgt.device)

    tgt.copy_(src)

# linearly interpolate a tensor in place
def inplace_lerp(tgt: Tensor, src: Tensor, weight, *, auto_move_device = False):
    if auto_move_device:
        src = src.to(tgt.device)

    tgt.lerp_(src, weight)

# EMA class, which maintains an exponential moving average shadow of a model
class EMA(Module):
    """
    Implements exponential moving average shadowing for your model.

    Utilizes an inverse decay schedule to manage longer term training runs.
    By adjusting the power, you can control how fast EMA will ramp up to your specified beta.

    @crowsonkb's notes on EMA Warmup:

    If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
    good values for models you plan to train for a million or more steps (reaches decay
    factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
    you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
    215.4k steps).

    Args:
        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
        power (float): Exponential factor of EMA warmup. Default: 2/3.
        min_value (float): The minimum EMA decay rate. Default: 0.
    """

    # use the beartype decorator to type-check the constructor arguments
    @beartype
    def __init__(
        self,
        model: Module,
        ema_model: Optional[Module] = None,           # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
        beta = 0.9999,
        update_after_step = 100,
        update_every = 10,
        inv_gamma = 1.0,
        power = 2 / 3,
        min_value = 0.0,
        param_or_buffer_names_no_ema: Set[str] = set(),
        ignore_names: Set[str] = set(),
        ignore_startswith_names: Set[str] = set(),
        include_online_model = True,                  # set this to False if you do not wish for the online model to be saved along with the ema model (managed externally)
        allow_different_devices = False               # if the EMA model is on a different device (say CPU), automatically move the tensor
    ):
        # call the parent constructor
        super().__init__()
        # store the beta decay value
        self.beta = beta

        # a beta of exactly 1. means the EMA model is frozen (never updated)
        self.is_frozen = beta == 1.

        # whether to include the online model in the module tree, so that state_dict also saves it
        self.include_online_model = include_online_model

        if include_online_model:
            self.online_model = model
        else:
            self.online_model = [model] # hack

        # the EMA model
        self.ema_model = ema_model

        if not exists(self.ema_model):
            try:
                self.ema_model = deepcopy(model)
            except Exception as e:
                print(f'Error: While trying to deepcopy model: {e}')
                print('Your model was not copyable. Please make sure you are not using any LazyLinear')
                exit()

        self.ema_model.requires_grad_(False)

        # parameter and buffer names
        self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
        self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}

        # tensor update functions
        self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices)
        self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices)

        # update hyperparameters
        self.update_every = update_every
        self.update_after_step = update_after_step
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value

        assert isinstance(param_or_buffer_names_no_ema, (set, list))
        self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema # parameter or buffer

        self.ignore_names = ignore_names
        self.ignore_startswith_names = ignore_startswith_names

        # whether to handle the EMA model being kept on a different device
        self.allow_different_devices = allow_different_devices

        # init and step states
        self.register_buffer('initted', torch.tensor(False))
        self.register_buffer('step', torch.tensor(0))

    @property
    def model(self):
        return self.online_model if self.include_online_model else self.online_model[0]

    def eval(self):
        return self.ema_model.eval()
    
    def restore_ema_model_device(self):
        device = self.initted.device
        self.ema_model.to(device)

    def get_params_iter(self, model):
        for name, param in model.named_parameters():
            if name not in self.parameter_names:
                continue
            yield name, param

    def get_buffers_iter(self, model):
        for name, buffer in model.named_buffers():
            if name not in self.buffer_names:
                continue
            yield name, buffer

    def copy_params_from_model_to_ema(self):
        copy = self.inplace_copy

        for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
            copy(ma_params.data, current_params.data)

        for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
            copy(ma_buffers.data, current_buffers.data)

    def copy_params_from_ema_to_model(self):
        copy = self.inplace_copy

        for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
            copy(current_params.data, ma_params.data)

        for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
            copy(current_buffers.data, ma_buffers.data)
    # get the current decay value
    def get_current_decay(self):
        # compute the effective epoch, clamped so it is never negative
        epoch = (self.step - self.update_after_step - 1).clamp(min=0.)
        # compute the decay from the inverse decay schedule
        value = 1 - (1 + epoch / self.inv_gamma) ** -self.power

        # if the epoch is not yet positive, return 0
        if epoch.item() <= 0:
            return 0.

        # clamp the decay between min_value and beta
        return value.clamp(min=self.min_value, max=self.beta).item()

    # update step
    def update(self):
        # get the current step count
        step = self.step.item()
        # increment the step
        self.step += 1

        # only actually update every `update_every` steps
        if (step % self.update_every) != 0:
            return

        # before `update_after_step`, simply copy the online weights into the EMA model
        if step <= self.update_after_step:
            self.copy_params_from_model_to_ema()
            return

        # on the first real update, copy the online weights and mark as initialized
        if not self.initted.item():
            self.copy_params_from_model_to_ema()
            self.initted.data.copy_(torch.tensor(True))

        # update the exponential moving average
        self.update_moving_average(self.ema_model, self.model)

    # update the exponential moving average model
    @torch.no_grad()
    def update_moving_average(self, ma_model, current_model):
        # if the EMA model is frozen, do nothing
        if self.is_frozen:
            return

        # grab the in-place copy and lerp functions
        copy, lerp = self.inplace_copy, self.inplace_lerp
        # get the current decay value
        current_decay = self.get_current_decay()

        # iterate over the parameters of the online model and the EMA model
        for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model), self.get_params_iter(ma_model)):
            # skip parameters in the ignore list
            if name in self.ignore_names:
                continue

            # skip parameters whose name starts with an ignored prefix
            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
                continue

            # parameters excluded from EMA are copied over directly
            if name in self.param_or_buffer_names_no_ema:
                copy(ma_params.data, current_params.data)
                continue

            # otherwise, linearly interpolate towards the online parameters
            lerp(ma_params.data, current_params.data, 1. - current_decay)

        # iterate over the buffers of the online model and the EMA model
        for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model), self.get_buffers_iter(ma_model)):
            # skip buffers in the ignore list
            if name in self.ignore_names:
                continue

            # skip buffers whose name starts with an ignored prefix
            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
                continue

            # buffers excluded from EMA are copied over directly
            if name in self.param_or_buffer_names_no_ema:
                copy(ma_buffer.data, current_buffer.data)
                continue

            # otherwise, linearly interpolate towards the online buffers
            lerp(ma_buffer.data, current_buffer.data, 1. - current_decay)

    # calling the wrapper invokes the EMA model
    def __call__(self, *args, **kwargs):
        return self.ema_model(*args, **kwargs)
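To make the inverse decay warmup concrete, here is a small standalone sketch (not part of the library) that reproduces the schedule in get_current_decay above for a few step counts, assuming the default inv_gamma = 1.0, power = 2/3 and beta = 0.9999.

# standalone sketch of the warmup schedule used by EMA.get_current_decay above
def current_decay(step, update_after_step = 100, inv_gamma = 1.0, power = 2 / 3, min_value = 0.0, beta = 0.9999):
    epoch = max(step - update_after_step - 1, 0)
    if epoch <= 0:
        return 0.
    value = 1 - (1 + epoch / inv_gamma) ** -power
    return min(max(value, min_value), beta)

# the decay ramps up from 0 towards beta as training progresses
for step in (100, 1_000, 10_000, 100_000, 1_000_000):
    print(step, round(current_decay(step), 5))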

.\lucidrains\ema-pytorch\ema_pytorch\post_hoc_ema.py

# import the necessary modules
from pathlib import Path
from copy import deepcopy
from functools import partial

import torch
from torch import nn, Tensor
from torch.nn import Module, ModuleList

import numpy as np

from beartype import beartype
from beartype.typing import Set, Tuple, Optional

# check whether a value exists
def exists(val):
    return val is not None

# return a default value when none is given
def default(val, d):
    return val if exists(val) else d

# return the first element of an array
def first(arr):
    return arr[0]

# get the device a module lives on
def get_module_device(m: Module):
    return next(m.parameters()).device

# copy a tensor in place
def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False):
    if auto_move_device:
        src = src.to(tgt.device)

    tgt.copy_(src)

# linearly interpolate a tensor in place
def inplace_lerp(tgt: Tensor, src: Tensor, weight, *, auto_move_device = False):
    if auto_move_device:
        src = src.to(tgt.device)

    tgt.lerp_(src, weight)

# convert the relative standard deviation (sigma_rel) into gamma
def sigma_rel_to_gamma(sigma_rel):
    t = sigma_rel ** -2
    return np.roots([1, 7, 16 - t, 12 - t]).real.max().item()

# EMA module using the hyperparameters from the paper https://arxiv.org/abs/2312.02696
class KarrasEMA(Module):
    """
    exponential moving average module that uses hyperparameters from the paper https://arxiv.org/abs/2312.02696
    can either use gamma or sigma_rel from paper
    """

    @beartype
    def __init__(
        self,
        model: Module,
        sigma_rel: Optional[float] = None,
        gamma: Optional[float] = None,
        ema_model: Optional[Module] = None,           # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
        update_every: int = 100,
        frozen: bool = False,
        param_or_buffer_names_no_ema: Set[str] = set(),
        ignore_names: Set[str] = set(),
        ignore_startswith_names: Set[str] = set(),
        allow_different_devices = False               # if the EMA model is on a different device (say CPU), automatically move the tensor
    ):
        super().__init__()

        assert exists(sigma_rel) ^ exists(gamma), 'either sigma_rel or gamma is given. gamma is derived from sigma_rel as in the paper, then beta is derived from gamma'

        if exists(sigma_rel):
            gamma = sigma_rel_to_gamma(sigma_rel)

        self.gamma = gamma
        self.frozen = frozen

        self.online_model = [model]

        # ema model

        self.ema_model = ema_model

        if not exists(self.ema_model):
            try:
                self.ema_model = deepcopy(model)
            except Exception as e:
                print(f'Error: While trying to deepcopy model: {e}')
                print('Your model was not copyable. Please make sure you are not using any LazyLinear')
                exit()

        self.ema_model.requires_grad_(False)

        # parameter and buffer names

        self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
        self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}

        # tensor update functions

        self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices)
        self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices)

        # updating hyperparameters

        self.update_every = update_every

        assert isinstance(param_or_buffer_names_no_ema, (set, list))
        self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema # parameter or buffer

        self.ignore_names = ignore_names
        self.ignore_startswith_names = ignore_startswith_names

        # whether to manage if EMA model is kept on a different device

        self.allow_different_devices = allow_different_devices

        # init and step states

        self.register_buffer('initted', torch.tensor(False))
        self.register_buffer('step', torch.tensor(0))

    @property
    def model(self):
        return first(self.online_model)
    
    @property
    # compute the beta decay used to update the moving average
    def beta(self):
        return (1 - 1 / (self.step + 1)) ** (1 + self.gamma)

    # put the EMA model into eval mode
    def eval(self):
        return self.ema_model.eval()
    
    # move the EMA model back to the device tracked by this wrapper
    def restore_ema_model_device(self):
        device = self.initted.device
        self.ema_model.to(device)

    # iterate over the model's tracked parameters
    def get_params_iter(self, model):
        for name, param in model.named_parameters():
            if name not in self.parameter_names:
                continue
            yield name, param

    # iterate over the model's tracked buffers
    def get_buffers_iter(self, model):
        for name, buffer in model.named_buffers():
            if name not in self.buffer_names:
                continue
            yield name, buffer

    # copy parameters from the online model to the EMA model
    def copy_params_from_model_to_ema(self):
        copy = self.inplace_copy

        for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
            copy(ma_params.data, current_params.data)

        for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
            copy(ma_buffers.data, current_buffers.data)

    # copy parameters from the EMA model back to the online model
    def copy_params_from_ema_to_model(self):
        copy = self.inplace_copy

        for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
            copy(current_params.data, ma_params.data)

        for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
            copy(current_buffers.data, ma_buffers.data)

    # advance the step counter and perform a moving average update
    def update(self):
        step = self.step.item()
        self.step += 1

        if (step % self.update_every) != 0:
            return

        if not self.initted.item():
            self.copy_params_from_model_to_ema()
            self.initted.data.copy_(torch.tensor(True))

        self.update_moving_average(self.ema_model, self.model)

    # iterate over all EMA parameters and buffers that participate in averaging
    def iter_all_ema_params_and_buffers(self):
        for name, ma_params in self.get_params_iter(self.ema_model):
            if name in self.ignore_names:
                continue

            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
                continue

            if name in self.param_or_buffer_names_no_ema:
                continue

            yield ma_params

        for name, ma_buffer in self.get_buffers_iter(self.ema_model):
            if name in self.ignore_names:
                continue

            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
                continue

            if name in self.param_or_buffer_names_no_ema:
                continue

            yield ma_buffer

    # update the moving average model
    @torch.no_grad()
    def update_moving_average(self, ma_model, current_model):
        if self.frozen:
            return

        copy, lerp = self.inplace_copy, self.inplace_lerp
        current_decay = self.beta

        for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model), self.get_params_iter(ma_model)):
            if name in self.ignore_names:
                continue

            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
                continue

            if name in self.param_or_buffer_names_no_ema:
                copy(ma_params.data, current_params.data)
                continue

            lerp(ma_params.data, current_params.data, 1. - current_decay)

        for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model), self.get_buffers_iter(ma_model)):
            if name in self.ignore_names:
                continue

            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
                continue

            if name in self.param_or_buffer_names_no_ema:
                copy(ma_buffer.data, current_buffer.data)
                continue

            lerp(ma_buffer.data, current_buffer.data, 1. - current_decay)
    # calling the wrapper invokes the EMA model
    def __call__(self, *args, **kwargs):
        # call the ema_model with the given arguments
        return self.ema_model(*args, **kwargs)

# post-hoc EMA wrapper

# solves for the weights that combine all saved checkpoints into a newly synthesized EMA with the desired gamma
# algorithm 3 from the paper, re-implemented in torch

# dot product between two EMA profiles
def p_dot_p(t_a, gamma_a, t_b, gamma_b):
    t_ratio = t_a / t_b
    t_exp = torch.where(t_a < t_b , gamma_b , -gamma_a)
    t_max = torch.maximum(t_a , t_b)
    num = (gamma_a + 1) * (gamma_b + 1) * t_ratio ** t_exp
    den = (gamma_a + gamma_b + 1) * t_max
    return num / den

# solve for the combination weights
def solve_weights(t_i, gamma_i, t_r, gamma_r):
    rv = lambda x: x.double().reshape(-1, 1)
    cv = lambda x: x.double().reshape(1, -1)
    A = p_dot_p(rv(t_i), rv(gamma_i), cv(t_i), cv(gamma_i))
    b = p_dot_p(rv(t_i), rv(gamma_i), cv(t_r), cv(gamma_r))
    return torch.linalg.solve(A, b)

# post-hoc EMA class
class PostHocEMA(Module):

    # constructor
    @beartype
    def __init__(
        self,
        model: Module,
        sigma_rels: Optional[Tuple[float, ...]] = None,
        gammas: Optional[Tuple[float, ...]] = None,
        checkpoint_every_num_steps: int = 1000,
        checkpoint_folder: str = './post-hoc-ema-checkpoints',
        **kwargs
    ):
        super().__init__()
        assert exists(sigma_rels) ^ exists(gammas)

        if exists(sigma_rels):
            gammas = tuple(map(sigma_rel_to_gamma, sigma_rels))

        assert len(gammas) > 1, 'at least 2 ema models with different gammas in order to synthesize new ema models of a different gamma'
        assert len(set(gammas)) == len(gammas), 'calculated gammas must be all unique'

        self.gammas = gammas
        self.num_ema_models = len(gammas)

        self._model = [model]
        self.ema_models = ModuleList([KarrasEMA(model, gamma = gamma, **kwargs) for gamma in gammas])

        self.checkpoint_folder = Path(checkpoint_folder)
        self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
        assert self.checkpoint_folder.is_dir()

        self.checkpoint_every_num_steps = checkpoint_every_num_steps
        self.ema_kwargs = kwargs

    # return the online model
    @property
    def model(self):
        return first(self._model)

    # return the current step count
    @property
    def step(self):
        return first(self.ema_models).step

    # return the device
    @property
    def device(self):
        return self.step.device

    # copy parameters from the EMA models back to the online model
    def copy_params_from_ema_to_model(self):
        for ema_model in self.ema_models:
            ema_model.copy_params_from_ema_to_model()

    # update every EMA model
    def update(self):
        for ema_model in self.ema_models:
            ema_model.update()

        if not (self.step.item() % self.checkpoint_every_num_steps):
            self.checkpoint()

    # write a checkpoint for each EMA model
    def checkpoint(self):
        step = self.step.item()

        for ind, ema_model in enumerate(self.ema_models):
            filename = f'{ind}.{step}.pt'
            path = self.checkpoint_folder / filename

            pkg = deepcopy(ema_model).half().state_dict()
            torch.save(pkg, str(path))

    # synthesize a new EMA model
    @beartype
    def synthesize_ema_model(
        self,
        gamma: Optional[float] = None,
        sigma_rel: Optional[float] = None,
        step: Optional[int] = None,
    ) -> KarrasEMA:
        # exactly one of gamma or sigma_rel may be given
        assert exists(gamma) ^ exists(sigma_rel)
        # get the device
        device = self.device

        # if sigma_rel is given, convert it to gamma
        if exists(sigma_rel):
            gamma = sigma_rel_to_gamma(sigma_rel)

        # create the synthesized EMA model object
        synthesized_ema_model = KarrasEMA(
            model = self.model,
            gamma = gamma,
            **self.ema_kwargs
        )

        # gather all checkpoints

        gammas = []
        timesteps = []
        checkpoints = [*self.checkpoint_folder.glob('*.pt')]

        # parse the gamma index and timestep out of each checkpoint filename
        for file in checkpoints:
            gamma_ind, timestep = map(int, file.stem.split('.'))
            gamma = self.gammas[gamma_ind]

            gammas.append(gamma)
            timesteps.append(timestep)

        # default the step to the latest timestep
        step = default(step, max(timesteps))
        # the requested step cannot exceed the latest checkpointed timestep
        assert step <= max(timesteps), f'you can only synthesize for a timestep that is less than the max timestep {max(timesteps)}'

        # line up with Algorithm 3 from the paper

        gamma_i = Tensor(gammas, device = device)
        t_i = Tensor(timesteps, device = device)

        gamma_r = Tensor([gamma], device = device)
        t_r = Tensor([step], device = device)

        # solve the least squares problem for the weights that combine all checkpoints into the synthesized one

        weights = solve_weights(t_i, gamma_i, t_r, gamma_r)
        weights = weights.squeeze(-1)

        # add the weighted checkpoints into the synthesized model, one at a time

        tmp_ema_model = KarrasEMA(
            model = self.model,
            gamma = gamma,
            **self.ema_kwargs
        )

        for ind, (checkpoint, weight) in enumerate(zip(checkpoints, weights.tolist())):
            is_first = ind == 0

            # load the checkpoint into the temporary EMA model

            ckpt_state_dict = torch.load(str(checkpoint))
            tmp_ema_model.load_state_dict(ckpt_state_dict)

            # add the weighted checkpoint tensors into the synthesized model

            for ckpt_tensor, synth_tensor in zip(tmp_ema_model.iter_all_ema_params_and_buffers(), synthesized_ema_model.iter_all_ema_params_and_buffers()):
                if is_first:
                    synth_tensor.zero_()

                synth_tensor.add_(ckpt_tensor * weight)

        # return the synthesized model

        return synthesized_ema_model

    # calling the wrapper invokes every underlying EMA model
    def __call__(self, *args, **kwargs):
        return tuple(ema_model(*args, **kwargs) for ema_model in self.ema_models)
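As a small illustration of the least-squares step above, independent of any model weights, the following sketch uses solve_weights (which internally builds the Gram matrix with p_dot_p) to obtain per-checkpoint combination weights. The timesteps and sigma_rel values are made up for the example; only the import path ema_pytorch.post_hoc_ema comes from the file listed above.

import torch
from ema_pytorch.post_hoc_ema import solve_weights, sigma_rel_to_gamma

# hypothetical checkpoints: two EMA profiles (sigma_rel 0.05 and 0.3), each saved at three timesteps
gammas    = torch.tensor([sigma_rel_to_gamma(0.05), sigma_rel_to_gamma(0.3)]).repeat_interleave(3)
timesteps = torch.tensor([1000., 2000., 3000.]).repeat(2)

# target profile to synthesize (sigma_rel 0.15), reconstructed at the last timestep
gamma_r = torch.tensor([sigma_rel_to_gamma(0.15)])
t_r     = torch.tensor([3000.])

# one combination weight per stored checkpoint, as used in synthesize_ema_model above
weights = solve_weights(timesteps, gammas, t_r, gamma_r).squeeze(-1)
print(weights)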

.\lucidrains\ema-pytorch\ema_pytorch\__init__.py

# import the EMA class from the ema_pytorch module
from ema_pytorch.ema_pytorch import EMA

# import the KarrasEMA and PostHocEMA classes from the post_hoc_ema module
from ema_pytorch.post_hoc_ema import (
    KarrasEMA,
    PostHocEMA
)

EMA - Pytorch

A simple way to keep track of an Exponential Moving Average (EMA) version of your pytorch model

Install

$ pip install ema-pytorch

Usage

import torch
from ema_pytorch import EMA

# your neural network as a pytorch module

net = torch.nn.Linear(512, 512)

# wrap your neural network, specify the decay (beta)

ema = EMA(
    net,
    beta = 0.9999,              # exponential moving average factor
    update_after_step = 100,    # only after this number of .update() calls will it start updating
    update_every = 10,          # how often to actually update, to save on compute (updates every 10th .update() call)
)

# mutate your network, with SGD or otherwise

with torch.no_grad():
    net.weight.copy_(torch.randn_like(net.weight))
    net.bias.copy_(torch.randn_like(net.bias))

# you will call the update function on your moving average wrapper

ema.update()

# then, later on, you can invoke the EMA model the same way as your network

data = torch.randn(1, 512)

output     = net(data)
ema_output = ema(data)

# if you want to save your ema model, it is recommended you save the entire wrapper
# as it contains the number of steps taken (there is a warmup logic in there, recommended by @crowsonkb, validated for a number of projects now)
# however, if you wish to access the copy of your model with EMA, then it will live at ema.ema_model

In order to use the post-hoc synthesized EMA, proposed by Karras et al. in a recent paper, follow the example below

import torch
from ema_pytorch import PostHocEMA

# your neural network as a pytorch module

net = torch.nn.Linear(512, 512)

# wrap your neural network, specify the sigma_rels or gammas

emas = PostHocEMA(
    net,
    sigma_rels = (0.05, 0.3),           # a tuple with the hyperparameter for the multiple EMAs. you need at least 2 here to synthesize a new one
    update_every = 10,                  # how often to actually update, to save on compute (updates every 10th .update() call)
    checkpoint_every_num_steps = 10,
    checkpoint_folder = './post-hoc-ema-checkpoints'  # the folder of saved checkpoints for each sigma_rel (gamma) across timesteps with the hparam above, used to synthesize a new EMA model after training
)

net.train()

for _ in range(1000):
    # mutate your network, with SGD or otherwise

    with torch.no_grad():
        net.weight.copy_(torch.randn_like(net.weight))
        net.bias.copy_(torch.randn_like(net.bias))

    # you will call the update function on your moving average wrapper

    emas.update()

# now that you have a few checkpoints
# you can synthesize an EMA model with a different sigma_rel (say 0.15)

synthesized_ema = emas.synthesize_ema_model(sigma_rel = 0.15)

# output with synthesized EMA

data = torch.randn(1, 512)

synthesized_ema_output = synthesized_ema(data)
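Internally, each sigma_rel is converted to the paper's gamma constant with sigma_rel_to_gamma (defined in ema_pytorch/post_hoc_ema.py above); a quick sanity check of that mapping, if you just want to inspect the values, could look like this.

from ema_pytorch.post_hoc_ema import sigma_rel_to_gamma

# each sigma_rel corresponds to one gamma; PostHocEMA performs this conversion for you
for sigma_rel in (0.05, 0.15, 0.3):
    print(sigma_rel, sigma_rel_to_gamma(sigma_rel))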

Citations

@article{Karras2023AnalyzingAI,
    title   = {Analyzing and Improving the Training Dynamics of Diffusion Models},
    author  = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2312.02696},
    url     = {https://api.semanticscholar.org/CorpusID:265659032}
}

.\lucidrains\ema-pytorch\setup.py

# import setuptools' setup and package discovery helpers
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'ema-pytorch',                  # package name
  packages = find_packages(exclude=[]),  # find all packages
  version = '0.4.3',                     # version number
  license='MIT',                         # license
  description = 'Easy way to keep track of exponential moving average version of your pytorch module',  # description
  author = 'Phil Wang',                  # author
  author_email = 'lucidrains@gmail.com', # author email
  long_description_content_type = 'text/markdown',  # long description content type
  url = 'https://github.com/lucidrains/ema-pytorch',  # project url
  keywords = [
    'artificial intelligence',
    'deep learning',
    'exponential moving average'
  ],
  install_requires=[
    'beartype',
    'torch>=1.6',
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\En-transformer\denoise.py

# import PyTorch
import torch
# import PyTorch's functional API
import torch.nn.functional as F
# import the nn module from torch
from torch import nn
# import the Adam optimizer from torch.optim
from torch.optim import Adam

# import rearrange and repeat from einops
from einops import rearrange, repeat
# import sidechainnet and rename it to scn
import sidechainnet as scn
# import the EnTransformer class from the en_transformer module
from en_transformer.en_transformer import EnTransformer

# set the default tensor dtype to float64
torch.set_default_dtype(torch.float64)

# batch size of 1
BATCH_SIZE = 1
# accumulate gradients over this many batches
GRADIENT_ACCUMULATE_EVERY = 16

# infinite generator over the dataloader, skipping overly long sequences
def cycle(loader, len_thres = 200):
    while True:
        for data in loader:
            # skip sequences longer than the length threshold
            if data.seqs.shape[1] > len_thres:
                continue
            # yield the batch
            yield data

# create the EnTransformer model instance
transformer = EnTransformer(
    num_tokens = 21,
    dim = 32,
    dim_head = 64,
    heads = 4,
    depth = 4,
    rel_pos_emb = True, # there is an inherent order in the sequence (backbone atoms of the amino acid chain)
    neighbors = 16
)

# load the dataset
data = scn.load(
    casp_version = 12,
    thinning = 30,
    with_pytorch = 'dataloaders',
    batch_size = BATCH_SIZE,
    dynamic_batching = False
)

# create the cycling dataloader
dl = cycle(data['train'])
# use the Adam optimizer for the EnTransformer parameters
optim = Adam(transformer.parameters(), lr=1e-3)
# move the model to the GPU
transformer = transformer.cuda()

# training loop
for _ in range(10000):
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        # fetch one batch of data
        batch = next(dl)
        seqs, coords, masks = batch.seqs, batch.crds, batch.msks

        # move the sequences to the GPU and convert one-hot vectors to token indices
        seqs = seqs.cuda().argmax(dim = -1)
        # move the coordinates to the GPU as float64
        coords = coords.cuda().type(torch.float64)
        # move the masks to the GPU as booleans
        masks = masks.cuda().bool()

        # sequence length
        l = seqs.shape[1]
        # regroup the coordinates by residue (14 atoms per residue)
        coords = rearrange(coords, 'b (l s) c -> b l s c', s = 14)

        # keep only the backbone coordinates

        coords = coords[:, :, 0:3, :]
        coords = rearrange(coords, 'b l s c -> b (l s) c')

        # repeat the sequence tokens and masks, one copy per backbone atom
        seq = repeat(seqs, 'b n -> b (n c)', c = 3)
        masks = repeat(masks, 'b n -> b (n c)', c = 3)

        # add gaussian noise to the coordinates
        noised_coords = coords + torch.randn_like(coords)

        # run the transformer to extract features and denoise the coordinates
        feats, denoised_coords = transformer(seq, noised_coords, mask = masks)

        # mean squared error between the denoised and original coordinates
        loss = F.mse_loss(denoised_coords[masks], coords[masks])

        # backpropagate (scaled for gradient accumulation)
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    # print the loss
    print('loss:', loss.item())
    # optimizer step
    optim.step()
    # clear the gradients
    optim.zero_grad()
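The coordinate bookkeeping in the loop above is easy to get wrong, so here is a tiny standalone shape check (random tensors, no dataset or GPU required) of the same rearrange/repeat pattern: the 14 atoms per residue are regrouped, the first 3 backbone atoms are kept, and the sequence tokens are repeated to match.

import torch
from einops import rearrange, repeat

b, l = 2, 5                                      # toy batch: 2 sequences of 5 residues
coords = torch.randn(b, l * 14, 3)               # sidechainnet packs 14 atoms per residue
seqs   = torch.randint(0, 21, (b, l))

coords = rearrange(coords, 'b (l s) c -> b l s c', s = 14)   # (2, 5, 14, 3)
coords = coords[:, :, 0:3, :]                                # keep the 3 backbone atoms
coords = rearrange(coords, 'b l s c -> b (l s) c')           # (2, 15, 3)

seq = repeat(seqs, 'b n -> b (n c)', c = 3)                  # (2, 15), one token per backbone atom

assert coords.shape == (b, l * 3, 3) and seq.shape == (b, l * 3)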

.\lucidrains\En-transformer\en_transformer\en_transformer.py

# import torch
import torch
# import torch's functional API
import torch.nn.functional as F
# import nn and einsum from torch
from torch import nn, einsum
# import checkpoint_sequential from torch.utils.checkpoint
from torch.utils.checkpoint import checkpoint_sequential
# import get_at from einx
from einx import get_at
# import rearrange, repeat, reduce from einops, and the Rearrange layer from einops.layers.torch
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
# import the TaylorSeriesLinearAttn class from taylor_series_linear_attention
from taylor_series_linear_attention import TaylorSeriesLinearAttn

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the maximum negative value for a given dtype
def max_neg_value(t):
    return -torch.finfo(t.dtype).max

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# l2 normalize a tensor along the last dimension
def l2norm(t):
    return F.normalize(t, dim = -1)

# initialize an nn.Linear with a small standard deviation
def small_init_(t: nn.Linear):
    nn.init.normal_(t.weight, std = 0.02)
    nn.init.zeros_(t.bias)

# dynamic position bias

class DynamicPositionBias(nn.Module):
    def __init__(
        self,
        dim,
        *,
        heads,
        depth,
        dim_head,
        input_dim = 1,
        norm = True
    ):
        super().__init__()
        assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
        self.mlp = nn.ModuleList([])

        self.mlp.append(nn.Sequential(
            nn.Linear(input_dim, dim),
            nn.LayerNorm(dim) if norm else nn.Identity(),
            nn.SiLU()
        ))

        for _ in range(depth - 1):
            self.mlp.append(nn.Sequential(
                nn.Linear(dim, dim),
                nn.LayerNorm(dim) if norm else nn.Identity(),
                nn.SiLU()
            ))

        self.heads = heads
        self.qk_pos_head = nn.Linear(dim, heads)
        self.value_pos_head = nn.Linear(dim, dim_head * heads)

    def forward(self, pos):
        for layer in self.mlp:
            pos = layer(pos)

        qk_pos = self.qk_pos_head(pos)
        value_pos = self.value_pos_head(pos)

        qk_pos = rearrange(qk_pos, 'b 1 i j h -> b h i j')
        value_pos = rearrange(value_pos, 'b 1 i j (h d) -> b h i j d', h = self.heads)
        return qk_pos, value_pos

# classes

# this class follows the normalization strategy from SE3 Transformers
# https://github.com/lucidrains/se3-transformer-pytorch/blob/main/se3_transformer_pytorch/se3_transformer_pytorch.py#L95

# layer norm class
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# coordinate normalization class
class CoorsNorm(nn.Module):
    def __init__(self, eps = 1e-8, scale_init = 1.):
        super().__init__()
        self.eps = eps
        scale = torch.zeros(1).fill_(scale_init)
        self.scale = nn.Parameter(scale)

    def forward(self, coors):
        norm = coors.norm(dim = -1, keepdim = True)
        normed_coors = coors / norm.clamp(min = self.eps)
        return normed_coors * self.scale

# residual connection class
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, feats, coors, **kwargs):
        feats_out, coors_delta = self.fn(feats, coors, **kwargs)
        return feats + feats_out, coors + coors_delta

# GEGLU activation class
class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

# feedforward network class
class FeedForward(nn.Module):
    def __init__(
        self,
        *,
        dim,
        mult = 4,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = int(dim * mult * 2 / 3)

        self.net = nn.Sequential(
            LayerNorm(dim),
            nn.Linear(dim, inner_dim * 2, bias = False),
            GEGLU(),
            LayerNorm(inner_dim),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim, bias = False)
        )

    def forward(self, feats, coors):
        return self.net(feats), 0

class EquivariantAttention(nn.Module):
    # constructor: hyperparameters for the equivariant attention layer
    def __init__(
        self,
        *,
        dim,                                    # input feature dimension
        dim_head = 64,                          # dimension per attention head
        heads = 4,                              # number of attention heads
        edge_dim = 0,                           # edge feature dimension
        coors_hidden_dim = 16,                  # hidden dimension of the coordinate MLP
        neighbors = 0,                          # number of nearest neighbors to attend to
        only_sparse_neighbors = False,          # whether to only use sparse (adjacency-defined) neighbors
        valid_neighbor_radius = float('inf'),   # radius within which neighbors are considered valid
        init_eps = 1e-3,                        # small value for weight initialization
        rel_pos_emb = None,                     # relative positional embedding
        edge_mlp_mult = 2,                      # multiplier for the edge MLP hidden dimension
        norm_rel_coors = True,                  # whether to normalize relative coordinates
        norm_coors_scale_init = 1.,             # initial scale for coordinate normalization
        use_cross_product = False,              # whether to use cross product vectors
        talking_heads = False,                  # whether to use talking heads
        dropout = 0.,                           # dropout probability
        num_global_linear_attn_heads = 0,       # number of global linear attention heads
        linear_attn_dim_head = 8,               # head dimension for linear attention
        gate_outputs = True,                    # whether to gate the outputs
        gate_init_bias = 10.                    # initial bias for the output gate
    ):
        # call the parent constructor
        super().__init__()
        # attention scale factor
        self.scale = dim_head ** -0.5
        # normalize the input features
        self.norm = LayerNorm(dim)

        # neighbor related settings
        self.neighbors = neighbors
        self.only_sparse_neighbors = only_sparse_neighbors
        self.valid_neighbor_radius = valid_neighbor_radius

        # inner dimension of the attention
        attn_inner_dim = heads * dim_head
        self.heads = heads

        # whether any global linear attention heads are used
        self.has_linear_attn = num_global_linear_attn_heads > 0

        # global linear attention
        self.linear_attn = TaylorSeriesLinearAttn(
            dim = dim,
            dim_head = linear_attn_dim_head,
            heads = num_global_linear_attn_heads,
            gate_value_heads = True,
            combine_heads = False
        )

        # project the input to queries, keys and values
        self.to_qkv = nn.Linear(dim, attn_inner_dim * 3, bias = False)
        # project the attention output back to the model dimension
        self.to_out = nn.Linear(attn_inner_dim + self.linear_attn.dim_hidden, dim)

        # whether to gate the outputs
        self.gate_outputs = gate_outputs
        if gate_outputs:
            # initialize the gating linear layer
            gate_linear = nn.Linear(dim, 2 * heads)
            nn.init.zeros_(gate_linear.weight)
            nn.init.constant_(gate_linear.bias, gate_init_bias)

            # output gates
            self.to_output_gates = nn.Sequential(
                gate_linear,
                nn.Sigmoid(),
                Rearrange('b n (l h) -> l b h n 1', h = heads)
            )

        # talking heads, if enabled
        self.talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else None

        # edge MLP
        self.edge_mlp = None
        has_edges = edge_dim > 0

        if has_edges:
            edge_input_dim = heads + edge_dim
            edge_hidden = edge_input_dim * edge_mlp_mult

            # set up the edge MLP
            self.edge_mlp = nn.Sequential(
                nn.Linear(edge_input_dim, edge_hidden, bias = False),
                nn.GELU(),
                nn.Linear(edge_hidden, heads, bias = False)
            )

            # coordinate MLP
            self.coors_mlp = nn.Sequential(
                nn.GELU(),
                nn.Linear(heads, heads, bias = False)
            )
        else:
            # coordinate MLP
            self.coors_mlp = nn.Sequential(
                nn.Linear(heads, coors_hidden_dim, bias = False),
                nn.GELU(),
                nn.Linear(coors_hidden_dim, heads, bias = False)
            )

        # coordinate gate
        self.coors_gate = nn.Linear(heads, heads)
        small_init_(self.coors_gate)

        # whether to use the cross product
        self.use_cross_product = use_cross_product
        if use_cross_product:
            # cross product coordinate MLP
            self.cross_coors_mlp = nn.Sequential(
                nn.Linear(heads, coors_hidden_dim, bias = False),
                nn.GELU(),
                nn.Linear(coors_hidden_dim, heads * 2, bias = False)
            )

            # cross product coordinate gates
            self.cross_coors_gate_i = nn.Linear(heads, heads)
            self.cross_coors_gate_j = nn.Linear(heads, heads)

            small_init_(self.cross_coors_gate_i)
            small_init_(self.cross_coors_gate_j)

        # relative coordinate normalization
        self.norm_rel_coors = CoorsNorm(scale_init = norm_coors_scale_init) if norm_rel_coors else nn.Identity()

        # parameters for combining the per-head coordinate updates
        num_coors_combine_heads = (2 if use_cross_product else 1) * heads
        self.coors_combine = nn.Parameter(torch.randn(num_coors_combine_heads))

        # positional embedding
        # for relative distances, either within the sequence or between residues/atoms

        self.rel_pos_emb = rel_pos_emb

        # dynamic position bias MLP
        self.dynamic_pos_bias_mlp = DynamicPositionBias(
            dim = dim // 2,
            heads = heads,
            dim_head = dim_head,
            depth = 3,
            input_dim = (2 if rel_pos_emb else 1)
        )

        # dropout layers

        self.node_dropout = nn.Dropout(dropout)
        self.coor_dropout = nn.Dropout(dropout)

        # initialization

        self.init_eps = init_eps
        self.apply(self.init_)

    # initialize linear layer weights with a small standard deviation
    def init_(self, module):
        if type(module) in {nn.Linear}:
            # initialize the linear layer weights
            nn.init.normal_(module.weight, std = self.init_eps)

    # forward pass (body not shown in this listing)
    def forward(
        self,
        feats,
        coors,
        edges = None,
        mask = None,
        adj_mat = None
    ):

# a transformer block, pairing equivariant attention with a feedforward
class Block(nn.Module):
    def __init__(self, attn, ff):
        super().__init__()
        self.attn = attn
        self.ff = ff

    # forward pass: takes the packed inputs and returns the processed feats, coors, mask, edges and adjacency matrix
    def forward(self, inp, coor_changes = None):
        feats, coors, mask, edges, adj_mat = inp
        feats, coors = self.attn(feats, coors, edges = edges, mask = mask, adj_mat = adj_mat)
        feats, coors = self.ff(feats, coors)
        return (feats, coors, mask, edges, adj_mat)

# the full E(n)-equivariant transformer encoder
class EnTransformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        num_tokens = None,
        rel_pos_emb = False,
        dim_head = 64,
        heads = 8,
        num_edge_tokens = None,
        edge_dim = 0,
        coors_hidden_dim = 16,
        neighbors = 0,
        only_sparse_neighbors = False,
        num_adj_degrees = None,
        adj_dim = 0,
        valid_neighbor_radius = float('inf'),
        init_eps = 1e-3,
        norm_rel_coors = True,
        norm_coors_scale_init = 1.,
        use_cross_product = False,
        talking_heads = False,
        checkpoint = False,
        attn_dropout = 0.,
        ff_dropout = 0.,
        num_global_linear_attn_heads = 0,
        gate_outputs = True
    ):
        super().__init__()
        # the dimension per head should be at least 32 for rotary embeddings to work well
        assert dim_head >= 32, 'your dimension per head should be greater than 32 for rotary embeddings to work well'
        # adjacency degrees, if given, must be at least 1
        assert not (exists(num_adj_degrees) and num_adj_degrees < 1), 'make sure adjacent degrees is greater than 1'

        # if only sparse neighbors are used, default the adjacency degrees to 1
        if only_sparse_neighbors:
            num_adj_degrees = default(num_adj_degrees, 1)

        # token and edge embedding layers
        self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None
        self.edge_emb = nn.Embedding(num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None

        # adjacency matrix embedding layer
        self.num_adj_degrees = num_adj_degrees
        self.adj_emb = nn.Embedding(num_adj_degrees + 1, adj_dim) if exists(num_adj_degrees) and adj_dim > 0 else None
        adj_dim = adj_dim if exists(num_adj_degrees) else 0

        self.checkpoint = checkpoint
        self.layers = nn.ModuleList([])

        # create the stack of transformer blocks
        for ind in range(depth):
            self.layers.append(Block(
                Residual(EquivariantAttention(
                    dim = dim,
                    dim_head = dim_head,
                    heads = heads,
                    coors_hidden_dim = coors_hidden_dim,
                    edge_dim = (edge_dim + adj_dim),
                    neighbors = neighbors,
                    only_sparse_neighbors = only_sparse_neighbors,
                    valid_neighbor_radius = valid_neighbor_radius,
                    init_eps = init_eps,
                    rel_pos_emb = rel_pos_emb,
                    norm_rel_coors = norm_rel_coors,
                    norm_coors_scale_init = norm_coors_scale_init,
                    use_cross_product = use_cross_product,
                    talking_heads = talking_heads,
                    dropout = attn_dropout,
                    num_global_linear_attn_heads = num_global_linear_attn_heads,
                    gate_outputs = gate_outputs
                )),
                Residual(FeedForward(
                    dim = dim,
                    dropout = ff_dropout
                ))
            ))

    # forward pass: takes features, coordinates, edges, mask and adjacency matrix, and returns the processed results
    def forward(
        self,
        feats,
        coors,
        edges = None,
        mask = None,
        adj_mat = None,
        return_coor_changes = False,
        **kwargs
    ):
        # batch size
        b = feats.shape[0]

        # if a token embedding exists, embed the features
        if exists(self.token_emb):
            feats = self.token_emb(feats)

        # if an edge embedding exists, embed the edges
        if exists(self.edge_emb):
            assert exists(edges), 'edges must be passed in as (batch x seq x seq) indicating edge type'
            edges = self.edge_emb(edges)

        # an adjacency matrix only makes sense when num_adj_degrees is set
        assert not (exists(adj_mat) and (not exists(self.num_adj_degrees) or self.num_adj_degrees == 0)), 'num_adj_degrees must be greater than 0 if you are passing in an adjacency matrix'

        # if num_adj_degrees is set
        if exists(self.num_adj_degrees):
            assert exists(adj_mat), 'adjacency matrix must be passed in (keyword argument adj_mat)'

            # if the adjacency matrix is 2-dimensional, expand it across the batch
            if len(adj_mat.shape) == 2:
                adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b)

            # clone the adjacency matrix as long integer degree indices
            adj_indices = adj_mat.clone().long()

            # derive higher-degree adjacencies, one degree at a time
            for ind in range(self.num_adj_degrees - 1):
                degree = ind + 2

                # compute the next-degree adjacency matrix
                next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
                next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool()
                adj_indices.masked_fill_(next_degree_mask, degree)
                adj_mat = next_degree_adj_mat.clone()

            # if an adjacency embedding exists, embed the degree indices and concatenate to the edges
            if exists(self.adj_emb):
                adj_emb = self.adj_emb(adj_indices)
                edges = torch.cat((edges, adj_emb), dim = -1) if exists(edges) else adj_emb

        # returning coordinate changes is only supported in eval mode
        assert not (return_coor_changes and self.training), 'you must be eval mode in order to return coordinates'

        # go through the layers
        coor_changes = [coors]
        inp = (feats, coors, mask, edges, adj_mat)

        # if training with checkpointing enabled, checkpoint across blocks to save memory
        if self.training and self.checkpoint:
            inp = checkpoint_sequential(self.layers, len(self.layers), inp)
        else:
            # otherwise iterate over the blocks
            for layer in self.layers:
                inp = layer(inp)
                coor_changes.append(inp[1]) # append coordinates for visualization

        # return
        feats, coors, *_ = inp

        # if requested, return the features, coordinates and coordinate changes
        if return_coor_changes:
            return feats, coors, coor_changes

        # otherwise return just the features and coordinates
        return feats, coors
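The degree-expansion loop in the forward pass above (deriving higher-degree neighbors from a first-degree adjacency matrix by boolean matrix multiplication) can be exercised in isolation. The snippet below is a standalone sketch on a made-up 4-node chain, mirroring that loop verbatim.

import torch

# hypothetical first-degree adjacency for a 4-node chain 0-1-2-3 (no self loops)
adj_mat = torch.tensor([
    [0, 1, 0, 0],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0],
]).bool()

adj_indices = adj_mat.clone().long()   # 1 marks first-degree neighbors

num_adj_degrees = 2
for ind in range(num_adj_degrees - 1):
    degree = ind + 2

    # pairs reachable with one more hop
    next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
    next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool()
    adj_indices.masked_fill_(next_degree_mask, degree)
    adj_mat = next_degree_adj_mat.clone()

# adj_indices now holds the degree labels, computed exactly as in EnTransformer.forward above
print(adj_indices)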

.\lucidrains\En-transformer\en_transformer\utils.py

# import torch
import torch
# import sin, cos, atan2, acos from torch
from torch import sin, cos, atan2, acos

# rotation about the z axis by angle gamma
def rot_z(gamma):
    # return the z-axis rotation matrix
    return torch.tensor([
        [cos(gamma), -sin(gamma), 0],
        [sin(gamma), cos(gamma), 0],
        [0, 0, 1]
    ], dtype = gamma.dtype)

# rotation about the y axis by angle beta
def rot_y(beta):
    # return the y-axis rotation matrix
    return torch.tensor([
        [cos(beta), 0, sin(beta)],
        [0, 1, 0],
        [-sin(beta), 0, cos(beta)]
    ], dtype = beta.dtype)

# general rotation from the three Euler angles alpha, beta, gamma
def rot(alpha, beta, gamma):
    # compose z-, y-, and z-axis rotations
    return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
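These helpers are handy for spot-checking E(n)-equivariance properties; below is a minimal sketch (random angles and coordinates, tolerances chosen arbitrarily) that relies only on rot from the file above.

import torch
from en_transformer.utils import rot

# sample random Euler angles and build a rotation matrix
alpha, beta, gamma = torch.randn(3)
R = rot(alpha, beta, gamma)

# a rotation matrix is orthonormal: R @ R.T should be the identity
assert torch.allclose(R @ R.t(), torch.eye(3), atol = 1e-5)

# rotating a point cloud preserves pairwise distances
coors = torch.randn(16, 3)
rotated = coors @ R.t()
assert torch.allclose(torch.cdist(coors, coors), torch.cdist(rotated, rotated), atol = 1e-4)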

.\lucidrains\En-transformer\en_transformer\__init__.py

# import the EquivariantAttention and EnTransformer classes from en_transformer
from en_transformer.en_transformer import EquivariantAttention, EnTransformer

E(n)-Equivariant Transformer

Implementation of E(n)-Equivariant Transformer, which extends the ideas from Welling's E(n)-Equivariant Graph Neural Network with attention mechanisms and ideas from transformer architecture.

Update: Used for the design of CDR loops in antibodies!

Install

$ pip install En-transformer

Usage

import torch
from en_transformer import EnTransformer

model = EnTransformer(
    dim = 512,
    depth = 4,                       # depth
    dim_head = 64,                   # dimension per head
    heads = 8,                       # number of heads
    edge_dim = 4,                    # dimension of edge feature
    neighbors = 64,                  # only do attention between coordinates N nearest neighbors - set to 0 to turn off
    talking_heads = True,            # use Shazeer's talking heads https://arxiv.org/abs/2003.02436
    checkpoint = True,               # use checkpointing so one can increase depth at little memory cost (and increase neighbors attended to)
    use_cross_product = True,        # use cross product vectors (idea by @MattMcPartlon)
    num_global_linear_attn_heads = 4 # if your number of neighbors above is low, you can assign a certain number of attention heads to weakly attend globally to all other nodes through linear attention (https://arxiv.org/abs/1812.01243)
)

feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
edges = torch.randn(1, 1024, 1024, 4)

mask = torch.ones(1, 1024).bool()

feats, coors = model(feats, coors, edges, mask = mask)  # (1, 1024, 512), (1, 1024, 3)

Letting the network take care of both atomic and bond type embeddings

import torch
from en_transformer import EnTransformer

model = EnTransformer(
    num_tokens = 10,       # number of unique nodes, say atoms
    rel_pos_emb = True,    # set this to true if your sequence is not an unordered set. it will accelerate convergence
    num_edge_tokens = 5,   # number of unique edges, say bond types
    dim = 128,
    edge_dim = 16,
    depth = 3,
    heads = 4,
    dim_head = 32,
    neighbors = 8
)

atoms = torch.randint(0, 10, (1, 16))    # 10 different types of atoms
bonds = torch.randint(0, 5, (1, 16, 16)) # 5 different types of bonds (n x n)
coors = torch.randn(1, 16, 3)            # atomic spatial coordinates

feats_out, coors_out = model(atoms, coors, edges = bonds) # (1, 16, 512), (1, 16, 3)

If you would like to only attend to sparse neighbors, as defined by an adjacency matrix (say for atoms), you have to set one more flag and then pass in the N x N adjacency matrix.

import torch
from en_transformer import EnTransformer

model = EnTransformer(
    num_tokens = 10,
    dim = 512,
    depth = 1,
    heads = 4,
    dim_head = 32,
    neighbors = 0,
    only_sparse_neighbors = True,    # must be set to true
    num_adj_degrees = 2,             # the number of degrees to derive from 1st degree neighbors passed in
    adj_dim = 8                      # whether to pass the adjacency degree information as an edge embedding
)

atoms = torch.randint(0, 10, (1, 16))
coors = torch.randn(1, 16, 3)

# naively assume a single chain of atoms
i = torch.arange(atoms.shape[1])
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))

# adjacency matrix must be passed in
feats_out, coors_out = model(atoms, coors, adj_mat = adj_mat) # (1, 16, 512), (1, 16, 3)

Edges

If you need to pass in continuous edges

import torch
from en_transformer import EnTransformer
from en_transformer.utils import rot

model = EnTransformer(
    dim = 512,
    depth = 1,
    heads = 4,
    dim_head = 32,
    edge_dim = 4,
    neighbors = 0,
    only_sparse_neighbors = True
)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)

i = torch.arange(feats.shape[1])
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))

feats1, coors1 = model(feats, coors, adj_mat = adj_mat, edges = edges)

Example

To run a protein backbone coordinate denoising toy task, first install sidechainnet

$ pip install sidechainnet

Then

$ python denoise.py

Todo

Citations

@misc{satorras2021en,
    title 	= {E(n) Equivariant Graph Neural Networks}, 
    author 	= {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
    year 	= {2021},
    eprint 	= {2102.09844},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{shazeer2020talkingheads,
    title   = {Talking-Heads Attention}, 
    author  = {Noam Shazeer and Zhenzhong Lan and Youlong Cheng and Nan Ding and Le Hou},
    year    = {2020},
    eprint  = {2003.02436},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@inproceedings{Kim2020TheLC,
    title   = {The Lipschitz Constant of Self-Attention},
    author  = {Hyunjik Kim and George Papamakarios and Andriy Mnih},
    booktitle = {International Conference on Machine Learning},
    year    = {2020},
    url     = {https://api.semanticscholar.org/CorpusID:219530837}
}
@article {Mahajan2023.07.15.549154,
    author  = {Sai Pooja Mahajan and Jeffrey A. Ruffolo and Jeffrey J. Gray},
    title   = {Contextual protein and antibody encodings from equivariant graph transformers},
    elocation-id = {2023.07.15.549154},
    year    = {2023},
    doi     = {10.1101/2023.07.15.549154},
    publisher = {Cold Spring Harbor Laboratory},
    URL     = {https://www.biorxiv.org/content/early/2023/07/29/2023.07.15.549154},
    eprint  = {https://www.biorxiv.org/content/early/2023/07/29/2023.07.15.549154.full.pdf},
    journal = {bioRxiv}
}
@article{Bondarenko2023QuantizableTR,
    title   = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
    author  = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2306.12929},
    url     = {https://api.semanticscholar.org/CorpusID:259224568}
}
@inproceedings{Arora2023ZoologyMA,
    title   = {Zoology: Measuring and Improving Recall in Efficient Language Models},
    author  = {Simran Arora and Sabri Eyuboglu and Aman Timalsina and Isys Johnson and Michael Poli and James Zou and Atri Rudra and Christopher R'e},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:266149332}
}