Lucidrains Series Source Code Walkthrough (Part 30)

Electra - Pytorch
A simple working wrapper for fast pretraining of language models as detailed in this paper. It speeds up training (in comparison to normal masked language modeling) by a factor of 4x, and eventually reaches better performance if trained for even longer. Special thanks to Erik Nijkamp for taking the time to replicate the results for GLUE.
Install
$ pip install electra-pytorch
Usage
The following example uses reformer-pytorch, which is available to be pip installed.
import torch
from torch import nn
from reformer_pytorch import ReformerLM
from electra_pytorch import Electra
# (1) instantiate the generator and discriminator, making sure that the generator is roughly a quarter to a half of the size of the discriminator
generator = ReformerLM(
num_tokens = 20000,
emb_dim = 128,
dim = 256, # smaller hidden dimension
heads = 4, # fewer heads
ff_mult = 2, # smaller feed forward intermediate dimension
dim_head = 64,
depth = 12,
max_seq_len = 1024
)
discriminator = ReformerLM(
num_tokens = 20000,
emb_dim = 128,
dim = 1024,
dim_head = 64,
heads = 16,
depth = 12,
ff_mult = 4,
max_seq_len = 1024
)
# (2) weight tie the token and positional embeddings of generator and discriminator
generator.token_emb = discriminator.token_emb
generator.pos_emb = discriminator.pos_emb
# weight tie any other embeddings if available, token type embeddings, etc.
# (3) instantiate electra
trainer = Electra(
generator,
discriminator,
discr_dim = 1024, # the embedding dimension of the discriminator
discr_layer = 'reformer', # the name of the layer in the discriminator whose output will be used to predict whether each token is original or replaced
mask_token_id = 2, # the token id reserved for masking
pad_token_id = 0, # the token id for padding
mask_prob = 0.15, # masking probability for masked language modeling
mask_ignore_token_ids = [] # ids of tokens to ignore for mask modeling ex. (cls, sep)
)
# (4) train
data = torch.randint(0, 20000, (1, 1024))
results = trainer(data)
results.loss.backward()
# after much training, the discriminator should have improved
torch.save(discriminator, f'./pretrained-model.pt')
If you would rather not have the framework auto-magically intercept the hidden output of the discriminator, you can pass in the discriminator (with an extra linear layer of shape [dim x 1]) yourself, as follows.
import torch
from torch import nn
from reformer_pytorch import ReformerLM
from electra_pytorch import Electra
# (1) instantiate the generator and discriminator, making sure that the generator is roughly a quarter to a half of the size of the discriminator
generator = ReformerLM(
num_tokens = 20000,
emb_dim = 128,
dim = 256, # smaller hidden dimension
heads = 4, # fewer heads
ff_mult = 2, # smaller feed forward intermediate dimension
dim_head = 64,
depth = 12,
max_seq_len = 1024
)
discriminator = ReformerLM(
num_tokens = 20000,
emb_dim = 128,
dim = 1024,
dim_head = 64,
heads = 16,
depth = 12,
ff_mult = 4,
max_seq_len = 1024,
return_embeddings = True
)
# (2) weight tie the token and positional embeddings of generator and discriminator
generator.token_emb = discriminator.token_emb
generator.pos_emb = discriminator.pos_emb
# weight tie any other embeddings if available, token type embeddings, etc.
# (3) instantiate electra
discriminator_with_adapter = nn.Sequential(discriminator, nn.Linear(1024, 1))
trainer = Electra(
generator,
discriminator_with_adapter,
mask_token_id = 2, # the token id reserved for masking
pad_token_id = 0, # the token id for padding
mask_prob = 0.15, # masking probability for masked language modeling
mask_ignore_token_ids = [] # ids of tokens to ignore for mask modeling ex. (cls, sep)
)
# (4) train
data = torch.randint(0, 20000, (1, 1024))
results = trainer(data)
results.loss.backward()
# after much training, the discriminator should have improved
torch.save(discriminator, f'./pretrained-model.pt')
Important details for successful training
The generator should be roughly a quarter to at most one half of the discriminator's size for effective training. Any larger and the generator will be too good and the adversarial game collapses. In the paper, this was done by reducing the hidden dimension, the feed-forward hidden dimension, and the number of attention heads of the generator.
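As a quick sanity check before training, you can compare the number of trainable parameters of the two networks. The sketch below is illustrative only: parameter count is used as a rough proxy for the hidden-dimension sizing rule above, it assumes the `generator` and `discriminator` from the usage example, and `num_trainable_params` is a local helper defined here, not part of the library.

def num_trainable_params(module):
    # count all trainable parameters of a module
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

ratio = num_trainable_params(generator) / num_trainable_params(discriminator)
print(f'generator / discriminator parameter ratio: {ratio:.2f}')  # aim for the rough ballpark of 0.25 - 0.5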
Testing
$ python setup.py test
Training
- Download the OpenWebText dataset.
$ mkdir data
$ cd data
$ pip3 install gdown
$ gdown --id 1EA5V0oetDCOke7afsktL_JDQ-ETtNOvx
$ tar -xf openwebtext.tar.xz
$ wget https://storage.googleapis.com/electra-data/vocab.txt
$ cd ..
- Tokenize dataset.
$ python pretraining/openwebtext/preprocess.py
- Pre-train.
$ python pretraining/openwebtext/pretrain.py
- Download GLUE dataset.
$ python examples/glue/download.py
- Fine-tune on the MRPC sub-task of the GLUE benchmark.
$ python examples/glue/run.py --model_name_or_path output/yyyy-mm-dd-hh-mm-ss/ckpt/200000
Citations
@misc{clark2020electra,
title={ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators},
author={Kevin Clark and Minh-Thang Luong and Quoc V. Le and Christopher D. Manning},
year={2020},
eprint={2003.10555},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
.\lucidrains\electra-pytorch\setup.py
# import setuptools' setup and find_packages helpers
from setuptools import setup, find_packages
# package metadata
setup(
name = 'electra-pytorch', # package name
packages = find_packages(), # find all packages
version = '0.1.2', # version number
license='MIT', # license
description = 'Electra - Pytorch', # description
author = 'Erik Nijkamp, Phil Wang', # authors
author_email = 'erik.nijkamp@gmail.com, lucidrains@gmail.com', # author emails
url = 'https://github.com/lucidrains/electra-pytorch', # project url
keywords = [
'transformers', # keyword
'artificial intelligence', # keyword
'pretraining' # keyword
],
install_requires=[
'torch>=1.6.0', # runtime dependency
'transformers==3.0.2', # runtime dependency
'scipy', # runtime dependency
'sklearn' # runtime dependency
],
setup_requires=[
'pytest-runner' # setup dependency
],
tests_require=[
'pytest', # test dependency
'reformer-pytorch' # test dependency
],
classifiers=[
'Development Status :: 4 - Beta', # classifier
'Intended Audience :: Developers', # classifier
'Topic :: Scientific/Engineering :: Artificial Intelligence', # classifier
'License :: OSI Approved :: MIT License', # classifier
'Programming Language :: Python :: 3.7', # classifier
],
)
.\lucidrains\electra-pytorch\tests\test_electra_pytorch.py
# import the torch library
import torch
# import the nn module from torch
from torch import nn
# import the ReformerLM class from reformer_pytorch
from reformer_pytorch import ReformerLM
# import the Electra class from electra_pytorch
from electra_pytorch import Electra
# define a test for the Electra wrapper
def test_electra():
# create the generator ReformerLM
generator = ReformerLM(
num_tokens = 20000,
dim = 512,
depth = 1,
max_seq_len = 1024
)
# create the discriminator ReformerLM
discriminator = ReformerLM(
num_tokens = 20000,
dim = 512,
depth = 2,
max_seq_len = 1024
)
# tie the generator's token embeddings to the discriminator's
generator.token_emb = discriminator.token_emb
# tie the generator's positional embeddings to the discriminator's
generator.pos_emb = discriminator.pos_emb
# instantiate the Electra trainer
trainer = Electra(
generator,
discriminator,
num_tokens = 20000,
discr_dim = 512,
discr_layer = 'reformer',
pad_token_id = 1,
mask_ignore_token_ids = [2, 3]
)
# generate random data
data = torch.randint(0, 20000, (1, 1024))
# run a training forward pass
results = trainer(data)
# compute the loss and backpropagate
results.loss.backward()
# define a test for the Electra wrapper without the auto-magic layer interception
def test_electra_without_magic():
# create the generator ReformerLM
generator = ReformerLM(
num_tokens = 20000,
dim = 512,
depth = 1,
max_seq_len = 1024
)
# create the discriminator ReformerLM
discriminator = ReformerLM(
num_tokens = 20000,
dim = 512,
depth = 2,
max_seq_len = 1024,
return_embeddings = True
)
# tie the generator's token embeddings to the discriminator's
generator.token_emb = discriminator.token_emb
# tie the generator's positional embeddings to the discriminator's
generator.pos_emb = discriminator.pos_emb
# wrap the discriminator with an adapter head
discriminator_with_adapter = nn.Sequential(
discriminator,
nn.Linear(512, 1),
nn.Sigmoid()
)
# instantiate the Electra trainer
trainer = Electra(
generator,
discriminator_with_adapter,
num_tokens = 20000,
pad_token_id = 1,
mask_ignore_token_ids = [2, 3]
)
# generate random data
data = torch.randint(0, 20000, (1, 1024))
# run a training forward pass
results = trainer(data)
# compute the loss and backpropagate
results.loss.backward()
.\lucidrains\ema-pytorch\ema_pytorch\ema_pytorch.py
# 导入深拷贝函数 deepcopy 和 partial 函数
from copy import deepcopy
from functools import partial
# 导入 torch 库
import torch
# 从 torch 库中导入 nn, Tensor 模块
from torch import nn, Tensor
# 从 torch.nn 模块中导入 Module 类
from torch.nn import Module
# 导入 beartype 库
from beartype import beartype
# 从 beartype.typing 模块中导入 Set, Optional 类型
from beartype.typing import Set, Optional
# 定义函数 exists,用于检查值是否存在
def exists(val):
return val is not None
# 定义函数 get_module_device,用于获取模块的设备信息
def get_module_device(m: Module):
return next(m.parameters()).device
# 定义函数 inplace_copy,用于原地复制张量数据
def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False):
if auto_move_device:
src = src.to(tgt.device)
tgt.copy_(src)
# 定义函数 inplace_lerp,用于原地线性插值
def inplace_lerp(tgt: Tensor, src: Tensor, weight, *, auto_move_device = False):
if auto_move_device:
src = src.to(tgt.device)
tgt.lerp_(src, weight)
# 定义 EMA 类,实现模型的指数移动平均阴影
class EMA(Module):
"""
Implements exponential moving average shadowing for your model.
Utilizes an inverse decay schedule to manage longer term training runs.
By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
@crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
good values for models you plan to train for a million or more steps (reaches decay
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 2/3.
min_value (float): The minimum EMA decay rate. Default: 0.
"""
# 使用 beartype 装饰器,对初始化函数进行类型检查
@beartype
def __init__(
self,
model: Module,
ema_model: Optional[Module] = None, # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
beta = 0.9999,
update_after_step = 100,
update_every = 10,
inv_gamma = 1.0,
power = 2 / 3,
min_value = 0.0,
param_or_buffer_names_no_ema: Set[str] = set(),
ignore_names: Set[str] = set(),
ignore_startswith_names: Set[str] = set(),
include_online_model = True, # set this to False if you do not wish for the online model to be saved along with the ema model (managed externally)
allow_different_devices = False # if the EMA model is on a different device (say CPU), automatically move the tensor
):
# 调用父类的构造函数
super().__init__()
# 初始化 beta 属性
self.beta = beta
# 判断是否冻结模型
self.is_frozen = beta == 1.
# 是否在模块树中包含在线模型,以便 state_dict 也保存它
self.include_online_model = include_online_model
if include_online_model:
self.online_model = model
else:
self.online_model = [model] # hack
# EMA 模型
self.ema_model = ema_model
if not exists(self.ema_model):
try:
self.ema_model = deepcopy(model)
except Exception as e:
print(f'Error: While trying to deepcopy model: {e}')
print('Your model was not copyable. Please make sure you are not using any LazyLinear')
exit()
self.ema_model.requires_grad_(False)
# 参数和缓冲区的名称
self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}
# 张量更新函数
self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices)
self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices)
# 更新超参数
self.update_every = update_every
self.update_after_step = update_after_step
self.inv_gamma = inv_gamma
self.power = power
self.min_value = min_value
assert isinstance(param_or_buffer_names_no_ema, (set, list))
self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema # parameter or buffer
self.ignore_names = ignore_names
self.ignore_startswith_names = ignore_startswith_names
# 是否管理 EMA 模型是否保留在不同设备上
self.allow_different_devices = allow_different_devices
# 初始化和步骤状态
self.register_buffer('initted', torch.tensor(False))
self.register_buffer('step', torch.tensor(0))
@property
def model(self):
return self.online_model if self.include_online_model else self.online_model[0]
def eval(self):
return self.ema_model.eval()
def restore_ema_model_device(self):
device = self.initted.device
self.ema_model.to(device)
def get_params_iter(self, model):
for name, param in model.named_parameters():
if name not in self.parameter_names:
continue
yield name, param
def get_buffers_iter(self, model):
for name, buffer in model.named_buffers():
if name not in self.buffer_names:
continue
yield name, buffer
def copy_params_from_model_to_ema(self):
copy = self.inplace_copy
for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
copy(ma_params.data, current_params.data)
for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
copy(ma_buffers.data, current_buffers.data)
def copy_params_from_ema_to_model(self):
copy = self.inplace_copy
for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
copy(current_params.data, ma_params.data)
for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
copy(current_buffers.data, ma_buffers.data)
# 获取当前的衰减值
def get_current_decay(self):
# 计算当前的 epoch,确保不小于 0
epoch = (self.step - self.update_after_step - 1).clamp(min=0.)
# 根据公式计算衰减值
value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
# 如果 epoch 小于等于 0,则返回 0
if epoch.item() <= 0:
return 0.
# 返回计算得到的衰减值,确保在一定范围内
return value.clamp(min=self.min_value, max=self.beta).item()
# 更新操作
def update(self):
# 获取当前步数
step = self.step.item()
# 步数加一
self.step += 1
# 如果步数不是更新频率的倍数,则直接返回
if (step % self.update_every) != 0:
return
# 如果步数小于等于更新之后的步数,则将模型参数拷贝到指数移动平均模型中
if step <= self.update_after_step:
self.copy_params_from_model_to_ema()
return
# 如果模型还未初始化,则将模型参数拷贝到指数移动平均模型中,并标记为已初始化
if not self.initted.item():
self.copy_params_from_model_to_ema()
self.initted.data.copy_(torch.tensor(True))
# 更新指数移动平均模型
self.update_moving_average(self.ema_model, self.model)
# 更新指数移动平均模型
@torch.no_grad()
def update_moving_average(self, ma_model, current_model):
# 如果模型被冻结,则直接返回
if self.is_frozen:
return
# 获取拷贝和线性插值函数
copy, lerp = self.inplace_copy, self.inplace_lerp
# 获取当前的衰减值
current_decay = self.get_current_decay()
# 遍历当前模型和指数移动平均模型的参数
for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model), self.get_params_iter(ma_model)):
# 如果参数名在忽略列表中,则跳过
if name in self.ignore_names:
continue
# 如果参数名以忽略列表中的前缀开头,则跳过
if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
continue
# 如果参数名在不进行指数移动平均的列表中,则直接拷贝参数值
if name in self.param_or_buffer_names_no_ema:
copy(ma_params.data, current_params.data)
continue
# 对参数进行线性插值
lerp(ma_params.data, current_params.data, 1. - current_decay)
# 遍历当前模型和指数移动平均模型的缓冲区
for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model), self.get_buffers_iter(ma_model)):
# 如果缓冲区名在忽略列表中,则跳过
if name in self.ignore_names:
continue
# 如果缓冲区名以忽略列表中的前缀开头,则跳过
if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
continue
# 如果缓冲区名在不进行指数移动平均的列表中,则直接拷贝缓冲区值
if name in self.param_or_buffer_names_no_ema:
copy(ma_buffer.data, current_buffer.data)
continue
# 对缓冲区进行线性插值
lerp(ma_buffer.data, current_buffer.data, 1. - current_decay)
# 调用函数,返回指数移动平均模型的结果
def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs)
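To get a feel for the warmup schedule implemented in `get_current_decay` above, here is a small standalone sketch (not part of the library) that evaluates `1 - (1 + step / inv_gamma) ** -power`, clamped to `[min_value, beta]`, with the default `inv_gamma = 1.0` and `power = 2 / 3`; the `update_after_step` offset is ignored for simplicity.

def current_decay(step, beta = 0.9999, inv_gamma = 1.0, power = 2 / 3, min_value = 0.0):
    # mirrors the formula used by EMA.get_current_decay
    if step <= 0:
        return 0.
    value = 1 - (1 + step / inv_gamma) ** -power
    return min(max(value, min_value), beta)

for step in (100, 1000, 31600, 1000000):
    print(step, round(current_decay(step), 5))
# the decay ramps from ~0 toward beta, reaching ~0.999 around 31.6K steps, as noted in the docstring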
.\lucidrains\ema-pytorch\ema_pytorch\post_hoc_ema.py
# 导入必要的模块
from pathlib import Path
from copy import deepcopy
from functools import partial
import torch
from torch import nn, Tensor
from torch.nn import Module, ModuleList
import numpy as np
from beartype import beartype
from beartype.typing import Set, Tuple, Optional
# 检查值是否存在
def exists(val):
return val is not None
# 返回默认值
def default(val, d):
return val if exists(val) else d
# 返回数组的第一个元素
def first(arr):
return arr[0]
# 获取模块的设备
def get_module_device(m: Module):
return next(m.parameters()).device
# 在原地复制张量
def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False):
if auto_move_device:
src = src.to(tgt.device)
tgt.copy_(src)
# 在原地执行线性插值
def inplace_lerp(tgt: Tensor, src: Tensor, weight, *, auto_move_device = False):
if auto_move_device:
src = src.to(tgt.device)
tgt.lerp_(src, weight)
# 将相对标准差转换为 gamma
def sigma_rel_to_gamma(sigma_rel):
t = sigma_rel ** -2
return np.roots([1, 7, 16 - t, 12 - t]).real.max().item()
# EMA 模块,使用论文 https://arxiv.org/abs/2312.02696 中的超参数
class KarrasEMA(Module):
"""
exponential moving average module that uses hyperparameters from the paper https://arxiv.org/abs/2312.02696
can either use gamma or sigma_rel from paper
"""
@beartype
def __init__(
self,
model: Module,
sigma_rel: Optional[float] = None,
gamma: Optional[float] = None,
ema_model: Optional[Module] = None, # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
update_every: int = 100,
frozen: bool = False,
param_or_buffer_names_no_ema: Set[str] = set(),
ignore_names: Set[str] = set(),
ignore_startswith_names: Set[str] = set(),
allow_different_devices = False # if the EMA model is on a different device (say CPU), automatically move the tensor
):
super().__init__()
assert exists(sigma_rel) ^ exists(gamma), 'either sigma_rel or gamma is given. gamma is derived from sigma_rel as in the paper, then beta is dervied from gamma'
if exists(sigma_rel):
gamma = sigma_rel_to_gamma(sigma_rel)
self.gamma = gamma
self.frozen = frozen
self.online_model = [model]
# ema model
self.ema_model = ema_model
if not exists(self.ema_model):
try:
self.ema_model = deepcopy(model)
except Exception as e:
print(f'Error: While trying to deepcopy model: {e}')
print('Your model was not copyable. Please make sure you are not using any LazyLinear')
exit()
self.ema_model.requires_grad_(False)
# parameter and buffer names
self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}
# tensor update functions
self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices)
self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices)
# updating hyperparameters
self.update_every = update_every
assert isinstance(param_or_buffer_names_no_ema, (set, list))
self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema # parameter or buffer
self.ignore_names = ignore_names
self.ignore_startswith_names = ignore_startswith_names
# whether to manage if EMA model is kept on a different device
self.allow_different_devices = allow_different_devices
# init and step states
self.register_buffer('initted', torch.tensor(False))
self.register_buffer('step', torch.tensor(0))
@property
def model(self):
return first(self.online_model)
@property
# 计算 beta 值,用于更新移动平均模型
def beta(self):
return (1 - 1 / (self.step + 1)) ** (1 + self.gamma)
# 调用 EMA 模型的 eval 方法
def eval(self):
return self.ema_model.eval()
# 将 EMA 模型恢复到指定设备上
def restore_ema_model_device(self):
device = self.initted.device
self.ema_model.to(device)
# 获取模型的参数迭代器
def get_params_iter(self, model):
for name, param in model.named_parameters():
if name not in self.parameter_names:
continue
yield name, param
# 获取模型的缓冲区迭代器
def get_buffers_iter(self, model):
for name, buffer in model.named_buffers():
if name not in self.buffer_names:
continue
yield name, buffer
# 从原模型复制参数到 EMA 模型
def copy_params_from_model_to_ema(self):
copy = self.inplace_copy
for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
copy(ma_params.data, current_params.data)
for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
copy(ma_buffers.data, current_buffers.data)
# 从 EMA 模型复制参数到原模型
def copy_params_from_ema_to_model(self):
copy = self.inplace_copy
for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
copy(current_params.data, ma_params.data)
for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
copy(current_buffers.data, ma_buffers.data)
# 更新步数并执行移动平均更新
def update(self):
step = self.step.item()
self.step += 1
if (step % self.update_every) != 0:
return
if not self.initted.item():
self.copy_params_from_model_to_ema()
self.initted.data.copy_(torch.tensor(True))
self.update_moving_average(self.ema_model, self.model)
# 迭代所有 EMA 模型的参数和缓冲区
def iter_all_ema_params_and_buffers(self):
for name, ma_params in self.get_params_iter(self.ema_model):
if name in self.ignore_names:
continue
if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
continue
if name in self.param_or_buffer_names_no_ema:
continue
yield ma_params
for name, ma_buffer in self.get_buffers_iter(self.ema_model):
if name in self.ignore_names:
continue
if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
continue
if name in self.param_or_buffer_names_no_ema:
continue
yield ma_buffer
# 更新移动平均模型
@torch.no_grad()
def update_moving_average(self, ma_model, current_model):
if self.frozen:
return
copy, lerp = self.inplace_copy, self.inplace_lerp
current_decay = self.beta
for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model), self.get_params_iter(ma_model)):
if name in self.ignore_names:
continue
if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
continue
if name in self.param_or_buffer_names_no_ema:
copy(ma_params.data, current_params.data)
continue
lerp(ma_params.data, current_params.data, 1. - current_decay)
for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model), self.get_buffers_iter(ma_model)):
if name in self.ignore_names:
continue
if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
continue
if name in self.param_or_buffer_names_no_ema:
copy(ma_buffer.data, current_buffer.data)
continue
lerp(ma_buffer.data, current_buffer.data, 1. - current_decay)
# 定义一个特殊方法 __call__,使得对象可以像函数一样被调用
def __call__(self, *args, **kwargs):
# 调用 ema_model 对象,并传入参数
return self.ema_model(*args, **kwargs)
# 后验EMA包装器
# 解决将所有检查点组合成新合成的EMA的权重,以达到所需的gamma
# 算法3从论文中复制,用torch重新实现
# 计算两个张量的点乘
def p_dot_p(t_a, gamma_a, t_b, gamma_b):
t_ratio = t_a / t_b
t_exp = torch.where(t_a < t_b , gamma_b , -gamma_a)
t_max = torch.maximum(t_a , t_b)
num = (gamma_a + 1) * (gamma_b + 1) * t_ratio ** t_exp
den = (gamma_a + gamma_b + 1) * t_max
return num / den
# 解决权重
def solve_weights(t_i, gamma_i, t_r, gamma_r):
rv = lambda x: x.double().reshape(-1, 1)
cv = lambda x: x.double().reshape(1, -1)
A = p_dot_p(rv(t_i), rv(gamma_i), cv(t_i), cv(gamma_i))
b = p_dot_p(rv(t_i), rv(gamma_i), cv(t_r), cv(gamma_r))
return torch.linalg.solve(A, b)
# 后验EMA类
class PostHocEMA(Module):
# 初始化函数
@beartype
def __init__(
self,
model: Module,
sigma_rels: Optional[Tuple[float, ...]] = None,
gammas: Optional[Tuple[float, ...]] = None,
checkpoint_every_num_steps: int = 1000,
checkpoint_folder: str = './post-hoc-ema-checkpoints',
**kwargs
):
super().__init__()
assert exists(sigma_rels) ^ exists(gammas)
if exists(sigma_rels):
gammas = tuple(map(sigma_rel_to_gamma, sigma_rels))
assert len(gammas) > 1, 'at least 2 ema models with different gammas in order to synthesize new ema models of a different gamma'
assert len(set(gammas)) == len(gammas), 'calculated gammas must be all unique'
self.gammas = gammas
self.num_ema_models = len(gammas)
self._model = [model]
self.ema_models = ModuleList([KarrasEMA(model, gamma = gamma, **kwargs) for gamma in gammas])
self.checkpoint_folder = Path(checkpoint_folder)
self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
assert self.checkpoint_folder.is_dir()
self.checkpoint_every_num_steps = checkpoint_every_num_steps
self.ema_kwargs = kwargs
# 返回模型
@property
def model(self):
return first(self._model)
# 返回步数
@property
def step(self):
return first(self.ema_models).step
# 返回设备
@property
def device(self):
return self.step.device
# 从EMA复制参数到模型
def copy_params_from_ema_to_model(self):
for ema_model in self.ema_models:
ema_model.copy_params_from_ema_to_model()
# 更新EMA模型
def update(self):
for ema_model in self.ema_models:
ema_model.update()
if not (self.step.item() % self.checkpoint_every_num_steps):
self.checkpoint()
# 创建检查点
def checkpoint(self):
step = self.step.item()
for ind, ema_model in enumerate(self.ema_models):
filename = f'{ind}.{step}.pt'
path = self.checkpoint_folder / filename
pkg = deepcopy(ema_model).half().state_dict()
torch.save(pkg, str(path))
# 合成EMA模型
@beartype
def synthesize_ema_model(
self,
gamma: Optional[float] = None,
sigma_rel: Optional[float] = None,
step: Optional[int] = None,
) -> KarrasEMA:
# 断言 gamma 和 sigma_rel 只能存在一个
assert exists(gamma) ^ exists(sigma_rel)
# 获取设备信息
device = self.device
# 如果存在 sigma_rel,则根据 sigma_rel 转换为 gamma
if exists(sigma_rel):
gamma = sigma_rel_to_gamma(sigma_rel)
# 创建一个合成的 EMA 模型对象
synthesized_ema_model = KarrasEMA(
model = self.model,
gamma = gamma,
**self.ema_kwargs
)
synthesized_ema_model
# 获取所有检查点
gammas = []
timesteps = []
checkpoints = [*self.checkpoint_folder.glob('*.pt')]
# 遍历检查点文件,获取 gamma 和 timestep
for file in checkpoints:
gamma_ind, timestep = map(int, file.stem.split('.'))
gamma = self.gammas[gamma_ind]
gammas.append(gamma)
timesteps.append(timestep)
# 设置步数为最大 timestep
step = default(step, max(timesteps))
# 断言步数小于等于最大 timestep
assert step <= max(timesteps), f'you can only synthesize for a timestep that is less than the max timestep {max(timesteps)}'
# 与算法 3 对齐
gamma_i = Tensor(gammas, device = device)
t_i = Tensor(timesteps, device = device)
gamma_r = Tensor([gamma], device = device)
t_r = Tensor([step], device = device)
# 使用最小二乘法解出将所有检查点组合成合成检查点的权重
weights = solve_weights(t_i, gamma_i, t_r, gamma_r)
weights = weights.squeeze(-1)
# 逐个使用权重将所有检查点相加到合成模型中
tmp_ema_model = KarrasEMA(
model = self.model,
gamma = gamma,
**self.ema_kwargs
)
for ind, (checkpoint, weight) in enumerate(zip(checkpoints, weights.tolist())):
is_first = ind == 0
# 将检查点加载到临时 EMA 模型中
ckpt_state_dict = torch.load(str(checkpoint))
tmp_ema_model.load_state_dict(ckpt_state_dict)
# 将加权检查点添加到合成模型中
for ckpt_tensor, synth_tensor in zip(tmp_ema_model.iter_all_ema_params_and_buffers(), synthesized_ema_model.iter_all_ema_params_and_buffers()):
if is_first:
synth_tensor.zero_()
synth_tensor.add_(ckpt_tensor * weight)
# 返回合成模型
return synthesized_ema_model
# 调用函数,返回所有 EMA 模型的结果
def __call__(self, *args, **kwargs):
return tuple(ema_model(*args, **kwargs) for ema_model in self.ema_models)
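As a toy illustration of the helpers above (a sketch, not part of the library, with made-up step values): `sigma_rel_to_gamma` converts the paper's relative standard deviation into a gamma, and `solve_weights` solves the least-squares system of Algorithm 3 for the weights used to combine the saved checkpoints.

import torch
from ema_pytorch.post_hoc_ema import sigma_rel_to_gamma, solve_weights

# gammas for the two EMAs tracked during training
gammas = [sigma_rel_to_gamma(0.05), sigma_rel_to_gamma(0.3)]

# pretend a checkpoint was saved for each gamma at steps 10 and 20
t_i = torch.tensor([10., 10., 20., 20.])
gamma_i = torch.tensor([gammas[0], gammas[1], gammas[0], gammas[1]])

# target: synthesize an EMA with sigma_rel = 0.15 at step 20
gamma_r = torch.tensor([sigma_rel_to_gamma(0.15)])
t_r = torch.tensor([20.])

weights = solve_weights(t_i, gamma_i, t_r, gamma_r).squeeze(-1)
print(weights)  # one combination weight per saved checkpoint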
.\lucidrains\ema-pytorch\ema_pytorch\__init__.py
# import the EMA class from the ema_pytorch module
from ema_pytorch.ema_pytorch import EMA
# import the KarrasEMA and PostHocEMA classes from the post_hoc_ema module
from ema_pytorch.post_hoc_ema import (
KarrasEMA,
PostHocEMA
)
EMA - Pytorch
A simple way to keep track of an Exponential Moving Average (EMA) version of your pytorch model
Install
$ pip install ema-pytorch
Usage
import torch
from ema_pytorch import EMA
# your neural network as a pytorch module
net = torch.nn.Linear(512, 512)
# wrap your neural network, specify the decay (beta)
ema = EMA(
net,
beta = 0.9999, # exponential moving average factor
update_after_step = 100, # only after this number of .update() calls will it start updating
update_every = 10, # how often to actually update, to save on compute (updates every 10th .update() call)
)
# mutate your network, with SGD or otherwise
with torch.no_grad():
net.weight.copy_(torch.randn_like(net.weight))
net.bias.copy_(torch.randn_like(net.bias))
# you will call the update function on your moving average wrapper
ema.update()
# then, later on, you can invoke the EMA model the same way as your network
data = torch.randn(1, 512)
output = net(data)
ema_output = ema(data)
# if you want to save your ema model, it is recommended you save the entire wrapper
# as it contains the number of steps taken (there is a warmup logic in there, recommended by @crowsonkb, validated for a number of projects now)
# however, if you wish to access the copy of your model with EMA, then it will live at ema.ema_model
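For example, continuing from the snippet above, a minimal sketch of saving and restoring the wrapper through its state dict (the file path is just an example):

# save the whole wrapper, which includes the internal step counter used by the warmup logic
torch.save(ema.state_dict(), './ema.pt')

# later: rebuild the wrapper around the network and restore it
ema = EMA(net, beta = 0.9999, update_after_step = 100, update_every = 10)
ema.load_state_dict(torch.load('./ema.pt'))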
In order to use the post-hoc synthesized EMA, proposed by Karras et al. in a recent paper, follow the example below
import torch
from ema_pytorch import PostHocEMA
# your neural network as a pytorch module
net = torch.nn.Linear(512, 512)
# wrap your neural network, specify the sigma_rels or gammas
emas = PostHocEMA(
net,
sigma_rels = (0.05, 0.3), # a tuple with the hyperparameter for the multiple EMAs. you need at least 2 here to synthesize a new one
update_every = 10, # how often to actually update, to save on compute (updates every 10th .update() call)
checkpoint_every_num_steps = 10,
checkpoint_folder = './post-hoc-ema-checkpoints' # the folder of saved checkpoints for each sigma_rel (gamma) across timesteps with the hparam above, used to synthesize a new EMA model after training
)
net.train()
for _ in range(1000):
# mutate your network, with SGD or otherwise
with torch.no_grad():
net.weight.copy_(torch.randn_like(net.weight))
net.bias.copy_(torch.randn_like(net.bias))
# you will call the update function on your moving average wrapper
emas.update()
# now that you have a few checkpoints
# you can synthesize an EMA model with a different sigma_rel (say 0.15)
synthesized_ema = emas.synthesize_ema_model(sigma_rel = 0.15)
# output with synthesized EMA
data = torch.randn(1, 512)
synthesized_ema_output = synthesized_ema(data)
Citations
@article{Karras2023AnalyzingAI,
title = {Analyzing and Improving the Training Dynamics of Diffusion Models},
author = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine},
journal = {ArXiv},
year = {2023},
volume = {abs/2312.02696},
url = {https://api.semanticscholar.org/CorpusID:265659032}
}
.\lucidrains\ema-pytorch\setup.py
# import setuptools' setup and find_packages helpers
from setuptools import setup, find_packages
# package metadata
setup(
name = 'ema-pytorch', # package name
packages = find_packages(exclude=[]), # find all packages
version = '0.4.3', # version number
license='MIT', # license
description = 'Easy way to keep track of exponential moving average version of your pytorch module', # description
author = 'Phil Wang', # author
author_email = 'lucidrains@gmail.com', # author email
long_description_content_type = 'text/markdown', # long description content type
url = 'https://github.com/lucidrains/ema-pytorch', # URL
keywords = [
'artificial intelligence', # keyword
'deep learning', # keyword
'exponential moving average' # keyword
],
install_requires=[
'beartype', # runtime dependency
'torch>=1.6', # runtime dependency
],
classifiers=[
'Development Status :: 4 - Beta', # classifier
'Intended Audience :: Developers', # classifier
'Topic :: Scientific/Engineering :: Artificial Intelligence', # classifier
'License :: OSI Approved :: MIT License', # classifier
'Programming Language :: Python :: 3.6', # classifier
],
)
.\lucidrains\En-transformer\denoise.py
# import the PyTorch library
import torch
# import PyTorch's functional API
import torch.nn.functional as F
# import the nn module from torch
from torch import nn
# import the Adam optimizer from torch.optim
from torch.optim import Adam
# import the rearrange and repeat functions from einops
from einops import rearrange, repeat
# import the sidechainnet library as scn
import sidechainnet as scn
# import the EnTransformer class from the en_transformer module
from en_transformer.en_transformer import EnTransformer
# set the default tensor dtype to float64
torch.set_default_dtype(torch.float64)
# batch size of 1
BATCH_SIZE = 1
# accumulate gradients over this many batches before each optimizer step
GRADIENT_ACCUMULATE_EVERY = 16
# helper that cycles through the dataloader forever, yielding batches
def cycle(loader, len_thres = 200):
while True:
for data in loader:
# skip sequences longer than the threshold
if data.seqs.shape[1] > len_thres:
continue
# yield the batch
yield data
# instantiate the EnTransformer model
transformer = EnTransformer(
num_tokens = 21,
dim = 32,
dim_head = 64,
heads = 4,
depth = 4,
rel_pos_emb = True, # the sequence has an inherent order (backbone atoms of the amino acid chain)
neighbors = 16
)
# load the dataset
data = scn.load(
casp_version = 12,
thinning = 30,
with_pytorch = 'dataloaders',
batch_size = BATCH_SIZE,
dynamic_batching = False
)
# create the cycling dataloader
dl = cycle(data['train'])
# use the Adam optimizer for the EnTransformer's parameters
optim = Adam(transformer.parameters(), lr=1e-3)
# move the model to the GPU
transformer = transformer.cuda()
# training loop
for _ in range(10000):
for _ in range(GRADIENT_ACCUMULATE_EVERY):
# fetch a batch of data
batch = next(dl)
seqs, coords, masks = batch.seqs, batch.crds, batch.msks
# move sequences to the GPU and take the argmax over the one-hot dimension
seqs = seqs.cuda().argmax(dim = -1)
# move coordinates to the GPU and cast to float64
coords = coords.cuda().type(torch.float64)
# move masks to the GPU and cast to bool
masks = masks.cuda().bool()
# sequence length
l = seqs.shape[1]
# reshape coordinates to (batch, residues, atoms per residue, 3)
coords = rearrange(coords, 'b (l s) c -> b l s c', s = 14)
# keep only the backbone coordinates
coords = coords[:, :, 0:3, :]
coords = rearrange(coords, 'b l s c -> b (l s) c')
# repeat the sequence and mask to match the three backbone atoms per residue
seq = repeat(seqs, 'b n -> b (n c)', c = 3)
masks = repeat(masks, 'b n -> b (n c)', c = 3)
# add noise to the coordinates
noised_coords = coords + torch.randn_like(coords)
# run the transformer to get features and denoised coordinates
feats, denoised_coords = transformer(seq, noised_coords, mask = masks)
# mean squared error between denoised and original coordinates
loss = F.mse_loss(denoised_coords[masks], coords[masks])
# backpropagate the (scaled) loss
(loss / GRADIENT_ACCUMULATE_EVERY).backward()
# print the loss value
print('loss:', loss.item())
# optimizer step
optim.step()
# zero the gradients
optim.zero_grad()
.\lucidrains\En-transformer\en_transformer\en_transformer.py
# 导入 torch 库
import torch
# 导入 torch 中的函数库
import torch.nn.functional as F
# 从 torch 中导入 nn、einsum 模块
from torch import nn, einsum
# 从 torch.utils.checkpoint 中导入 checkpoint_sequential 函数
from torch.utils.checkpoint import checkpoint_sequential
# 从 einx 中导入 get_at 函数
from einx import get_at
# 从 einops 中导入 rearrange、repeat、reduce 函数,从 einops.layers.torch 中导入 Rearrange 类
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
# import the TaylorSeriesLinearAttn class from taylor_series_linear_attention
from taylor_series_linear_attention import TaylorSeriesLinearAttn
# helper functions
# 判断变量是否存在的函数
def exists(val):
return val is not None
# 返回指定数据类型的最小负值的函数
def max_neg_value(t):
return -torch.finfo(t.dtype).max
# 如果变量存在则返回该变量,否则返回默认值的函数
def default(val, d):
return val if exists(val) else d
# 对输入张量进行 L2 归一化的函数
def l2norm(t):
return F.normalize(t, dim = -1)
# 对 nn.Linear 类型的权重进行小范围初始化的函数
def small_init_(t: nn.Linear):
nn.init.normal_(t.weight, std = 0.02)
nn.init.zeros_(t.bias)
# 动态位置偏置
class DynamicPositionBias(nn.Module):
def __init__(
self,
dim,
*,
heads,
depth,
dim_head,
input_dim = 1,
norm = True
):
super().__init__()
assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
self.mlp = nn.ModuleList([])
self.mlp.append(nn.Sequential(
nn.Linear(input_dim, dim),
nn.LayerNorm(dim) if norm else nn.Identity(),
nn.SiLU()
))
for _ in range(depth - 1):
self.mlp.append(nn.Sequential(
nn.Linear(dim, dim),
nn.LayerNorm(dim) if norm else nn.Identity(),
nn.SiLU()
))
self.heads = heads
self.qk_pos_head = nn.Linear(dim, heads)
self.value_pos_head = nn.Linear(dim, dim_head * heads)
def forward(self, pos):
for layer in self.mlp:
pos = layer(pos)
qk_pos = self.qk_pos_head(pos)
value_pos = self.value_pos_head(pos)
qk_pos = rearrange(qk_pos, 'b 1 i j h -> b h i j')
value_pos = rearrange(value_pos, 'b 1 i j (h d) -> b h i j d', h = self.heads)
return qk_pos, value_pos
# 类
# 此类遵循 SE3 Transformers 中的规范化策略
# https://github.com/lucidrains/se3-transformer-pytorch/blob/main/se3_transformer_pytorch/se3_transformer_pytorch.py#L95
# 层归一化类
class LayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.ones(dim))
self.register_buffer('beta', torch.zeros(dim))
def forward(self, x):
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
# 坐标归一化类
class CoorsNorm(nn.Module):
def __init__(self, eps = 1e-8, scale_init = 1.):
super().__init__()
self.eps = eps
scale = torch.zeros(1).fill_(scale_init)
self.scale = nn.Parameter(scale)
def forward(self, coors):
norm = coors.norm(dim = -1, keepdim = True)
normed_coors = coors / norm.clamp(min = self.eps)
return normed_coors * self.scale
# 残差连接类
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, feats, coors, **kwargs):
feats_out, coors_delta = self.fn(feats, coors, **kwargs)
return feats + feats_out, coors + coors_delta
# GEGLU 激活函数类
class GEGLU(nn.Module):
def forward(self, x):
x, gates = x.chunk(2, dim = -1)
return x * F.gelu(gates)
# 前馈神经网络类
class FeedForward(nn.Module):
def __init__(
self,
*,
dim,
mult = 4,
dropout = 0.
):
super().__init__()
inner_dim = int(dim * mult * 2 / 3)
self.net = nn.Sequential(
LayerNorm(dim),
nn.Linear(dim, inner_dim * 2, bias = False),
GEGLU(),
LayerNorm(inner_dim),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim, bias = False)
)
def forward(self, feats, coors):
return self.net(feats), 0
class EquivariantAttention(nn.Module):
# constructor: set up the attention module's hyperparameters
def __init__(
self,
*,
dim, # input feature dimension
dim_head = 64, # dimension per attention head
heads = 4, # number of attention heads
edge_dim = 0, # edge feature dimension
coors_hidden_dim = 16, # hidden dimension of the coordinate MLP
neighbors = 0, # number of nearest neighbors to attend to
only_sparse_neighbors = False, # whether to only use sparse (adjacency-defined) neighbors
valid_neighbor_radius = float('inf'), # radius within which neighbors are considered valid
init_eps = 1e-3, # small value used for weight initialization
rel_pos_emb = None, # relative positional embedding
edge_mlp_mult = 2, # width multiplier of the edge MLP
norm_rel_coors = True, # whether to normalize relative coordinates
norm_coors_scale_init = 1., # initial scale for the coordinate normalization
use_cross_product = False, # whether to use cross product vectors
talking_heads = False, # whether to use talking heads
dropout = 0., # dropout probability
num_global_linear_attn_heads = 0, # number of global linear attention heads
linear_attn_dim_head = 8, # head dimension for the linear attention
gate_outputs = True, # whether to gate the outputs
gate_init_bias = 10. # initial bias of the output gate
):
# call the parent constructor
super().__init__()
# 设置缩放因子
self.scale = dim_head ** -0.5
# 对输入进行归一化
self.norm = LayerNorm(dim)
# 设置邻居节点相关参数
self.neighbors = neighbors
self.only_sparse_neighbors = only_sparse_neighbors
self.valid_neighbor_radius = valid_neighbor_radius
# 计算注意力机制内部维度
attn_inner_dim = heads * dim_head
self.heads = heads
# 判断是否有全局线性注意力机制
self.has_linear_attn = num_global_linear_attn_heads > 0
# 初始化全局线性注意力机制
self.linear_attn = TaylorSeriesLinearAttn(
dim = dim,
dim_head = linear_attn_dim_head,
heads = num_global_linear_attn_heads,
gate_value_heads = True,
combine_heads = False
)
# 线性变换,将输入转换为查询、键、值
self.to_qkv = nn.Linear(dim, attn_inner_dim * 3, bias = False)
# 线性变换,将注意力机制输出转换为模型输出
self.to_out = nn.Linear(attn_inner_dim + self.linear_attn.dim_hidden, dim)
# 是否使用门控输出
self.gate_outputs = gate_outputs
if gate_outputs:
# 初始化门控线性层
gate_linear = nn.Linear(dim, 2 * heads)
nn.init.zeros_(gate_linear.weight)
nn.init.constant_(gate_linear.bias, gate_init_bias)
# 设置输出门控
self.to_output_gates = nn.Sequential(
gate_linear,
nn.Sigmoid(),
Rearrange('b n (l h) -> l b h n 1', h = heads)
)
# 是否使用Talking Heads
self.talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else None
# 初始化边缘MLP
self.edge_mlp = None
has_edges = edge_dim > 0
if has_edges:
edge_input_dim = heads + edge_dim
edge_hidden = edge_input_dim * edge_mlp_mult
# 设置边缘MLP
self.edge_mlp = nn.Sequential(
nn.Linear(edge_input_dim, edge_hidden, bias = False),
nn.GELU(),
nn.Linear(edge_hidden, heads, bias = False)
)
# 设置坐标MLP
self.coors_mlp = nn.Sequential(
nn.GELU(),
nn.Linear(heads, heads, bias = False)
)
else:
# 设置坐标MLP
self.coors_mlp = nn.Sequential(
nn.Linear(heads, coors_hidden_dim, bias = False),
nn.GELU(),
nn.Linear(coors_hidden_dim, heads, bias = False)
)
# 设置坐标门控
self.coors_gate = nn.Linear(heads, heads)
small_init_(self.coors_gate)
# 是否使用交叉乘积
self.use_cross_product = use_cross_product
if use_cross_product:
# 设置交叉坐标MLP
self.cross_coors_mlp = nn.Sequential(
nn.Linear(heads, coors_hidden_dim, bias = False),
nn.GELU(),
nn.Linear(coors_hidden_dim, heads * 2, bias = False)
)
# 设置交叉坐标门控
self.cross_coors_gate_i = nn.Linear(heads, heads)
self.cross_coors_gate_j = nn.Linear(heads, heads)
small_init_(self.cross_coors_gate_i)
small_init_(self.cross_coors_gate_j)
# 设置坐标归一化
self.norm_rel_coors = CoorsNorm(scale_init = norm_coors_scale_init) if norm_rel_coors else nn.Identity()
# 设置坐标组合参数
num_coors_combine_heads = (2 if use_cross_product else 1) * heads
self.coors_combine = nn.Parameter(torch.randn(num_coors_combine_heads))
# 位置嵌入
# 用于序列和残基/原子之间的相对距离
self.rel_pos_emb = rel_pos_emb
# 动态位置偏置MLP
self.dynamic_pos_bias_mlp = DynamicPositionBias(
dim = dim // 2,
heads = heads,
dim_head = dim_head,
depth = 3,
input_dim = (2 if rel_pos_emb else 1)
)
# 丢弃层
self.node_dropout = nn.Dropout(dropout)
self.coor_dropout = nn.Dropout(dropout)
# 初始化
self.init_eps = init_eps
self.apply(self.init_)
# 初始化函数,设置模型参数初始化方式
def init_(self, module):
if type(module) in {nn.Linear}:
# 初始化线性层参数
nn.init.normal_(module.weight, std = self.init_eps)
# 前向传播函数
def forward(
self,
feats,
coors,
edges = None,
mask = None,
adj_mat = None
# 定义一个 Transformer 模型的 Block 类,包含注意力机制和前馈神经网络
class Block(nn.Module):
def __init__(self, attn, ff):
super().__init__()
self.attn = attn
self.ff = ff
# 前向传播函数,接收输入和坐标变化,返回处理后的特征、坐标、掩码、边缘和邻接矩阵
def forward(self, inp, coor_changes = None):
feats, coors, mask, edges, adj_mat = inp
feats, coors = self.attn(feats, coors, edges = edges, mask = mask, adj_mat = adj_mat)
feats, coors = self.ff(feats, coors)
return (feats, coors, mask, edges, adj_mat)
# 定义一个 Encoder Transformer 模型
class EnTransformer(nn.Module):
def __init__(
self,
*,
dim,
depth,
num_tokens = None,
rel_pos_emb = False,
dim_head = 64,
heads = 8,
num_edge_tokens = None,
edge_dim = 0,
coors_hidden_dim = 16,
neighbors = 0,
only_sparse_neighbors = False,
num_adj_degrees = None,
adj_dim = 0,
valid_neighbor_radius = float('inf'),
init_eps = 1e-3,
norm_rel_coors = True,
norm_coors_scale_init = 1.,
use_cross_product = False,
talking_heads = False,
checkpoint = False,
attn_dropout = 0.,
ff_dropout = 0.,
num_global_linear_attn_heads = 0,
gate_outputs = True
):
super().__init__()
# 断言维度每个头部应大于等于32,以使旋转嵌入正常工作
assert dim_head >= 32, 'your dimension per head should be greater than 32 for rotary embeddings to work well'
# 断言邻接度数大于等于1
assert not (exists(num_adj_degrees) and num_adj_degrees < 1), 'make sure adjacent degrees is greater than 1'
# 如果只有稀疏邻居,则将邻接度数设置为1
if only_sparse_neighbors:
num_adj_degrees = default(num_adj_degrees, 1)
# 初始化嵌入层
self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None
self.edge_emb = nn.Embedding(num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None
# 初始化邻接矩阵嵌入层
self.num_adj_degrees = num_adj_degrees
self.adj_emb = nn.Embedding(num_adj_degrees + 1, adj_dim) if exists(num_adj_degrees) and adj_dim > 0 else None
adj_dim = adj_dim if exists(num_adj_degrees) else 0
self.checkpoint = checkpoint
self.layers = nn.ModuleList([])
# 循环创建 Transformer 模型的 Block 层
for ind in range(depth):
self.layers.append(Block(
Residual(EquivariantAttention(
dim = dim,
dim_head = dim_head,
heads = heads,
coors_hidden_dim = coors_hidden_dim,
edge_dim = (edge_dim + adj_dim),
neighbors = neighbors,
only_sparse_neighbors = only_sparse_neighbors,
valid_neighbor_radius = valid_neighbor_radius,
init_eps = init_eps,
rel_pos_emb = rel_pos_emb,
norm_rel_coors = norm_rel_coors,
norm_coors_scale_init = norm_coors_scale_init,
use_cross_product = use_cross_product,
talking_heads = talking_heads,
dropout = attn_dropout,
num_global_linear_attn_heads = num_global_linear_attn_heads,
gate_outputs = gate_outputs
)),
Residual(FeedForward(
dim = dim,
dropout = ff_dropout
))
))
# 前向传播函数,接收特征、坐标、边缘、掩码、邻接矩阵等参数,返回处理后的结果
def forward(
self,
feats,
coors,
edges = None,
mask = None,
adj_mat = None,
return_coor_changes = False,
**kwargs
):
# 获取特征的批次大小
b = feats.shape[0]
# 如果存在 token_emb 属性,则对特征进行处理
if exists(self.token_emb):
feats = self.token_emb(feats)
# 如果存在 edge_emb 属性,则对边进行处理
if exists(self.edge_emb):
assert exists(edges), 'edges must be passed in as (batch x seq x seq) indicating edge type'
edges = self.edge_emb(edges)
# 检查是否存在邻接矩阵,并且 num_adj_degrees 大于 0
assert not (exists(adj_mat) and (not exists(self.num_adj_degrees) or self.num_adj_degrees == 0)), 'num_adj_degrees must be greater than 0 if you are passing in an adjacency matrix'
# 如果存在 num_adj_degrees 属性
if exists(self.num_adj_degrees):
assert exists(adj_mat), 'adjacency matrix must be passed in (keyword argument adj_mat)'
# 如果邻接矩阵的维度为 2,则进行扩展
if len(adj_mat.shape) == 2:
adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b)
# 克隆邻接矩阵并转换为长整型
adj_indices = adj_mat.clone().long()
# 遍历 num_adj_degrees - 1 次
for ind in range(self.num_adj_degrees - 1):
degree = ind + 2
# 计算下一阶邻接矩阵
next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool()
adj_indices.masked_fill_(next_degree_mask, degree)
adj_mat = next_degree_adj_mat.clone()
# 如果存在 adj_emb 属性,则对邻接矩阵进行处理
if exists(self.adj_emb):
adj_emb = self.adj_emb(adj_indices)
edges = torch.cat((edges, adj_emb), dim = -1) if exists(edges) else adj_emb
# 检查是否需要返回坐标变化,并且模型处于训练模式
assert not (return_coor_changes and self.training), 'you must be eval mode in order to return coordinates'
# 遍历层
coor_changes = [coors]
inp = (feats, coors, mask, edges, adj_mat)
# 如果处于训练模式且启用了检查点,则使用检查点跨块进行内存节省
if self.training and self.checkpoint:
inp = checkpoint_sequential(self.layers, len(self.layers), inp)
else:
# 遍历块
for layer in self.layers:
inp = layer(inp)
coor_changes.append(inp[1]) # 为可视化添加坐标
# 返回
feats, coors, *_ = inp
# 如果需要返回坐标变化,则返回特征、坐标和坐标变化
if return_coor_changes:
return feats, coors, coor_changes
# 否则只返回特征和坐标
return feats, coors
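The `num_adj_degrees` logic in `forward` above derives higher-order neighbors by squaring the adjacency matrix; the following standalone sketch (a toy 5-node chain, not part of the library) shows the same idea for a single extra degree.

import torch

# first-degree adjacency of a simple 5-node chain
i = torch.arange(5)
adj_mat = (i[:, None] - i[None, :]).abs() == 1

adj_indices = adj_mat.long()                         # 1 marks first-degree neighbors
second = (adj_mat.float() @ adj_mat.float()) > 0     # nodes reachable within two hops
new_mask = (second.float() - adj_mat.float()).bool()
adj_indices.masked_fill_(new_mask, 2)                # newly reached entries (including the diagonal) get degree 2

print(adj_indices)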
.\lucidrains\En-transformer\en_transformer\utils.py
# import the torch library
import torch
# import the sin, cos, atan2, acos functions from torch
from torch import sin, cos, atan2, acos
# rotation about the z axis by angle gamma
def rot_z(gamma):
# return the z-axis rotation matrix as a tensor
return torch.tensor([
[cos(gamma), -sin(gamma), 0],
[sin(gamma), cos(gamma), 0],
[0, 0, 1]
], dtype = gamma.dtype)
# rotation about the y axis by angle beta
def rot_y(beta):
# return the y-axis rotation matrix as a tensor
return torch.tensor([
[cos(beta), 0, sin(beta)],
[0, 1, 0],
[-sin(beta), 0, cos(beta)]
], dtype = beta.dtype)
# general rotation from the three Euler angles alpha, beta, gamma
def rot(alpha, beta, gamma):
# compose the z-, y- and z-axis rotation matrices
return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
.\lucidrains\En-transformer\en_transformer\__init__.py
# import the EquivariantAttention and EnTransformer classes from the en_transformer module
from en_transformer.en_transformer import EquivariantAttention, EnTransformer
E(n)-Equivariant Transformer
Implementation of E(n)-Equivariant Transformer, which extends the ideas from Welling's E(n)-Equivariant Graph Neural Network with attention mechanisms and ideas from transformer architecture.
Update: used for the design of CDR loops in antibodies!
Install
$ pip install En-transformer
Usage
import torch
from en_transformer import EnTransformer
model = EnTransformer(
dim = 512,
depth = 4, # depth
dim_head = 64, # dimension per head
heads = 8, # number of heads
edge_dim = 4, # dimension of edge feature
neighbors = 64, # only do attention between coordinates N nearest neighbors - set to 0 to turn off
talking_heads = True, # use Shazeer's talking heads https://arxiv.org/abs/2003.02436
checkpoint = True, # use checkpointing so one can increase depth at little memory cost (and increase neighbors attended to)
use_cross_product = True, # use cross product vectors (idea by @MattMcPartlon)
num_global_linear_attn_heads = 4 # if your number of neighbors above is low, you can assign a certain number of attention heads to weakly attend globally to all other nodes through linear attention (https://arxiv.org/abs/1812.01243)
)
feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
edges = torch.randn(1, 1024, 1024, 4)
mask = torch.ones(1, 1024).bool()
feats, coors = model(feats, coors, edges, mask = mask) # (1, 1024, 512), (1, 1024, 3)
Letting the network take care of both atomic and bond type embeddings
import torch
from en_transformer import EnTransformer
model = EnTransformer(
num_tokens = 10, # number of unique nodes, say atoms
rel_pos_emb = True, # set this to true if your sequence is not an unordered set. it will accelerate convergence
num_edge_tokens = 5, # number of unique edges, say bond types
dim = 128,
edge_dim = 16,
depth = 3,
heads = 4,
dim_head = 32,
neighbors = 8
)
atoms = torch.randint(0, 10, (1, 16)) # 10 different types of atoms
bonds = torch.randint(0, 5, (1, 16, 16)) # 5 different types of bonds (n x n)
coors = torch.randn(1, 16, 3) # atomic spatial coordinates
feats_out, coors_out = model(atoms, coors, edges = bonds) # (1, 16, 128), (1, 16, 3)
If you would like to only attend to sparse neighbors, as defined by an adjacency matrix (say for atoms), you have to set one more flag and then pass in the N x N adjacency matrix.
import torch
from en_transformer import EnTransformer
model = EnTransformer(
num_tokens = 10,
dim = 512,
depth = 1,
heads = 4,
dim_head = 32,
neighbors = 0,
only_sparse_neighbors = True, # must be set to true
num_adj_degrees = 2, # the number of degrees to derive from 1st degree neighbors passed in
adj_dim = 8 # whether to pass the adjacency degree information as an edge embedding
)
atoms = torch.randint(0, 10, (1, 16))
coors = torch.randn(1, 16, 3)
# naively assume a single chain of atoms
i = torch.arange(atoms.shape[1])
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))
# adjacency matrix must be passed in
feats_out, coors_out = model(atoms, coors, adj_mat = adj_mat) # (1, 16, 512), (1, 16, 3)
Edges
If you need to pass in continuous edges
import torch
from en_transformer import EnTransformer
from en_transformer.utils import rot
model = EnTransformer(
dim = 512,
depth = 1,
heads = 4,
dim_head = 32,
edge_dim = 4,
neighbors = 0,
only_sparse_neighbors = True
)
feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)
i = torch.arange(feats.shape[1])
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))
feats1, coors1 = model(feats, coors, adj_mat = adj_mat, edges = edges)
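Since `rot` is imported above, you can also use it to spot-check E(n)-equivariance, continuing from the example: rotating and translating the input coordinates should transform the output coordinates the same way while leaving the features unchanged (the tolerance below is arbitrary).

R = rot(*torch.randn(3))   # random rotation matrix built from three Euler angles
T = torch.randn(1, 1, 3)   # random translation

feats2, coors2 = model(feats, coors @ R + T, adj_mat = adj_mat, edges = edges)

assert torch.allclose(feats1, feats2, atol = 1e-3)          # features are invariant
assert torch.allclose(coors1 @ R + T, coors2, atol = 1e-3)  # coordinates are equivariant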
Example
To run a protein backbone coordinate denoising toy task, first install sidechainnet
$ pip install sidechainnet
Then
$ python denoise.py
Todo
- add arxiv.org/abs/2112.05… for researchers to stretch to even bigger molecules
Citations
@misc{satorras2021en,
title = {E(n) Equivariant Graph Neural Networks},
author = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
year = {2021},
eprint = {2102.09844},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@misc{shazeer2020talkingheads,
title = {Talking-Heads Attention},
author = {Noam Shazeer and Zhenzhong Lan and Youlong Cheng and Nan Ding and Le Hou},
year = {2020},
eprint = {2003.02436},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@misc{liu2021swin,
title = {Swin Transformer V2: Scaling Up Capacity and Resolution},
author = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
year = {2021},
eprint = {2111.09883},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@inproceedings{Kim2020TheLC,
title = {The Lipschitz Constant of Self-Attention},
author = {Hyunjik Kim and George Papamakarios and Andriy Mnih},
booktitle = {International Conference on Machine Learning},
year = {2020},
url = {https://api.semanticscholar.org/CorpusID:219530837}
}
@article {Mahajan2023.07.15.549154,
author = {Sai Pooja Mahajan and Jeffrey A. Ruffolo and Jeffrey J. Gray},
title = {Contextual protein and antibody encodings from equivariant graph transformers},
elocation-id = {2023.07.15.549154},
year = {2023},
doi = {10.1101/2023.07.15.549154},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2023/07/29/2023.07.15.549154},
eprint = {https://www.biorxiv.org/content/early/2023/07/29/2023.07.15.549154.full.pdf},
journal = {bioRxiv}
}
@article{Bondarenko2023QuantizableTR,
title = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
author = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
journal = {ArXiv},
year = {2023},
volume = {abs/2306.12929},
url = {https://api.semanticscholar.org/CorpusID:259224568}
}
@inproceedings{Arora2023ZoologyMA,
title = {Zoology: Measuring and Improving Recall in Efficient Language Models},
author = {Simran Arora and Sabri Eyuboglu and Aman Timalsina and Isys Johnson and Michael Poli and James Zou and Atri Rudra and Christopher R'e},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:266149332}
}