Lucidrains Series Projects Source Code Analysis (90)
.\lucidrains\self-rewarding-lm-pytorch\self_rewarding_lm_pytorch\spin.py
from pathlib import Path                  # Path, for handling filesystem paths

from beartype import beartype
from beartype.typing import Optional, Callable, Union
# beartype, for runtime checking of type annotations

from torchtyping import TensorType        # TensorType, for tensor shape annotations

import torch                              # core torch library
from torch.nn import Module, Dropout      # Module and Dropout from torch.nn
import torch.nn.functional as F           # neural network functional ops

from torch.cuda.amp import autocast       # autocast, for mixed precision training
from torch.utils.data import Dataset, DataLoader  # datasets and data loading
from torch.nn.utils.rnn import pad_sequence       # padding of variable length sequences

from accelerate import Accelerator        # Accelerator, for distributed / accelerated training

from einops import rearrange              # rearrange, for reshaping tensor dimensions
from einx import get_at                   # get_at, for gathering values at specific tensor positions

from pytorch_custom_utils.utils import (
    masked_mean,
    maybe_and_mask
)
# masked_mean and maybe_and_mask from pytorch_custom_utils.utils

from pytorch_custom_utils.accelerate_utils import (
    model_forward_contexts
)
# model_forward_contexts from pytorch_custom_utils.accelerate_utils

from self_rewarding_lm_pytorch.dpo import (
    adam_optimizer_with_linear_decay
)
# adam_optimizer_with_linear_decay from self_rewarding_lm_pytorch.dpo

from self_rewarding_lm_pytorch.sampling_utils import (
    sample,
    top_p,
    top_k
)
# sample, top_p and top_k from self_rewarding_lm_pytorch.sampling_utils

from tqdm import tqdm                     # tqdm, for progress bars
from ema_pytorch import EMA               # EMA, for exponential moving averages
# helper functions

def exists(v):
    # returns True if the value is not None
    return v is not None

def cycle(dl):
    # endlessly iterate over the batches of a dataloader
    while True:
        for batch in dl:
            yield batch

def log_prob_from_model_and_seq(model, seq):
    # log probability the model assigns to each token of the given sequence
    logits = model(seq)
    log_probs = logits.log_softmax(dim = -1)
    return get_at('b n [c], b n -> b n', log_probs, seq)
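# note: the einx gather above is equivalent to the plain torch expression
#   log_probs.gather(-1, seq.unsqueeze(-1)).squeeze(-1)
# i.e. picking out each token's own log probability from the vocab dimension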
def prompt_mask_from_len(lengths, seq):
    # build a boolean mask marking the prompt positions, given the prompt lengths
    seq_len, device = seq.shape[-1], seq.device
    return torch.arange(seq_len, device = device) < rearrange(lengths, '... -> ... 1')
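# illustrative example: prompt_mask_from_len(torch.tensor([2]), torch.zeros(1, 5))
# returns tensor([[True, True, False, False, False]]) - True marks prompt positions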
def set_dropout_(model: Module, prob: float):
    # set the probability of every Dropout layer in the model
    for module in model.modules():
        if isinstance(module, Dropout):
            module.p = prob
# main class

class SPIN(Module):
    def __init__(
        self,
        model: Module,
        *,
        λ = 0.1,
        pad_id: Optional[int] = None,
        ref_model_ema_decay = 1.,
        ema_kwargs: dict = dict()
    ):
        super().__init__()
        self.policy_model = model

        # the reference model is an exponential moving average of the policy model
        self.ref_model = EMA(
            model,
            beta = ref_model_ema_decay,
            **ema_kwargs
        )

        self.λ = λ
        self.pad_id = pad_id

    def update_reference_model_with_policy(self):
        # copy the policy weights into the reference model
        self.ref_model.copy_params_from_model_to_ema()

    def update_ema(self):
        # update the exponential moving average
        self.ref_model.update()

    def parameters(self):
        # only the policy model is trained
        return self.policy_model.parameters()

    @property
    def device(self):
        # the device the model lives on
        return next(self.parameters()).device
    @autocast(enabled = False)
    def forward(
        self,
        generated_seq: TensorType['b', 'n', int],
        real_seq: TensorType['b', 'n', int],
        prompt_len: TensorType['b', int],
        generated_seq_mask: Optional[TensorType['b', 'n', bool]] = None,
        real_seq_mask: Optional[TensorType['b', 'n', bool]] = None
    ):
        """
        b - batch
        n - sequence length
        """

        # put the policy model into training mode
        self.policy_model.train()

        # derive the prompt masks for the real and generated sequences from the prompt lengths
        real_prompt_mask = prompt_mask_from_len(prompt_len, real_seq)
        generated_prompt_mask = prompt_mask_from_len(prompt_len, generated_seq)

        """
        Equation 4.7 in https://arxiv.org/abs/2401.01335v1
        """
        # if a pad id is set, derive the sequence masks from it
        if exists(self.pad_id):
            # masks must not also be passed in explicitly
            assert not exists(generated_seq_mask)
            assert not exists(real_seq_mask)

            # build the generated sequence mask and zero out the padding
            generated_seq_mask = generated_seq != self.pad_id
            generated_seq.masked_fill_(~generated_seq_mask, 0)

            # build the real sequence mask and zero out the padding
            real_seq_mask = real_seq != self.pad_id
            real_seq.masked_fill_(~real_seq_mask, 0)

        # reference model log probabilities, computed without gradients
        with torch.no_grad():
            self.ref_model.eval()
            ref_generated_logprob = log_prob_from_model_and_seq(self.ref_model, generated_seq)
            ref_real_logprob = log_prob_from_model_and_seq(self.ref_model, real_seq)

        # policy model log probabilities for the generated and real sequences
        policy_generated_logprob = log_prob_from_model_and_seq(self.policy_model, generated_seq)
        policy_real_logprob = log_prob_from_model_and_seq(self.policy_model, real_seq)

        # masked mean over the variable length sequences, excluding prompt (and padding) positions
        policy_generated_logprob, ref_generated_logprob = [masked_mean(seq, maybe_and_mask(generated_seq_mask, ~generated_prompt_mask)) for seq in (policy_generated_logprob, ref_generated_logprob)]
        policy_real_logprob, ref_real_logprob = [masked_mean(seq, maybe_and_mask(real_seq_mask, ~real_prompt_mask)) for seq in (policy_real_logprob, ref_real_logprob)]

        # the SPIN loss
        losses = -F.logsigmoid(self.λ * ((policy_real_logprob - ref_real_logprob) - (policy_generated_logprob - ref_generated_logprob)))

        return losses.mean()
class SPINTrainer(Module):
    def __init__(
        self,
        model: Union[Module, SPIN],
        *,
        train_sft_dataset: Dataset,
        max_seq_len: int,
        valid_sft_dataset: Optional[Dataset] = None,
        valid_every = 100,
        accelerator: Optional[Accelerator] = None,
        accelerate_kwargs: dict = dict(),
        batch_size = 16,
        grad_accum_steps = 2,
        epochs = 2,
        start_learning_rate = 1e-6,
        end_learning_rate = 1e-7,
        learning_rate_num_decay_steps = 1000,
        dropout = 0.,
        weight_decay = 0.,
        adam_kwargs: dict = dict(),
        temperature = 0.7,
        filter_fn = top_p,
        filter_kwargs = dict(thres = 0.9),
        pad_id: int = -1,
        ref_model_ema_decay = 1.,
        checkpoint_every = None,
        checkpoint_folder = './spin-checkpoints',
        spin_kwargs: dict = dict(
            λ = 0.1,
        )
    ):
        super().__init__()

        # use the given accelerator, otherwise construct one from accelerate_kwargs
        self.accelerator = accelerator
        if not exists(self.accelerator):
            self.accelerator = Accelerator(**accelerate_kwargs)

        # wrap a plain model into SPIN if needed
        if not isinstance(model, SPIN):
            model = SPIN(
                model,
                pad_id = pad_id,
                ref_model_ema_decay = ref_model_ema_decay,
                **spin_kwargs
            )

        self.model = model
        self.dropout = dropout
        self.train_dataloader = DataLoader(train_sft_dataset, batch_size = batch_size, shuffle = True, drop_last = True)

        # gradient accumulation steps and total number of training steps
        self.grad_accum_steps = grad_accum_steps
        self.num_train_steps = len(self.train_dataloader) // self.grad_accum_steps * epochs

        # Adam optimizer with a linearly decaying learning rate
        self.optimizer = adam_optimizer_with_linear_decay(
            model,
            start_learning_rate,
            end_learning_rate,
            num_decay_steps = learning_rate_num_decay_steps,
            accelerator = self.accelerator,
            weight_decay = weight_decay,
            adam_kwargs = adam_kwargs
        )

        # prepare the model and training dataloader with accelerate
        (
            self.model,
            self.train_dataloader
        ) = self.accelerator.prepare(
            self.model,
            self.train_dataloader
        )

        self.max_seq_len = max_seq_len
        self.pad_id = pad_id

        # sampling hyperparameters
        self.temperature = temperature
        self.filter_fn = filter_fn
        self.filter_kwargs = filter_kwargs

        # validation
        self.valid_dataloader = None
        self.valid_every = valid_every

        if exists(valid_sft_dataset):
            self.valid_dataloader = DataLoader(valid_sft_dataset, batch_size = batch_size)

        # checkpointing
        self.should_checkpoint = exists(checkpoint_every)
        self.checkpoint_every = checkpoint_every

        if self.should_checkpoint:
            self.checkpoint_folder = Path(checkpoint_folder)
            self.checkpoint_folder.mkdir(exist_ok = True, parents = True)

        self.steps = 0

    @property
    def is_main(self):
        # whether this is the main process
        return self.accelerator.is_main_process

    @property
    def unwrapped_model(self):
        # the model with the accelerate wrapper removed
        return self.accelerator.unwrap_model(self.model)

    def print(self, *msg):
        self.accelerator.print(*msg)

    def log(self, **data):
        self.accelerator.log(data, step = self.steps)

    def wait(self):
        # wait for all processes
        return self.accelerator.wait_for_everyone()

    def save(self, path: str, overwrite: bool = False):
        # save a checkpoint from the main process
        self.wait()

        if self.is_main:
            path = self.checkpoint_folder / path

            assert not path.exists() or overwrite, 'file already exists'

            pkg = dict(
                model = self.unwrapped_model.state_dict()
            )

            torch.save(pkg, str(path))

    def calc_spin_loss(
        self,
        real_seq: TensorType['b', 'n', int],
        prompt_len: TensorType['b', int]
    ):
        # derive the prompt mask from the prompt lengths, then split out the prompts
        prompt_mask = prompt_mask_from_len(prompt_len, real_seq)
        prompts = real_seq[prompt_mask].split(prompt_len.tolist())

        # sample continuations from the policy model
        generated_seqs = sample(
            self.unwrapped_model.policy_model,
            prompts = prompts,
            seq_len = self.max_seq_len,
            temperature = self.temperature,
            filter_fn = self.filter_fn,
            filter_kwargs = self.filter_kwargs,
            output_keep_prompt = True
        )

        # compute the SPIN loss on real vs generated sequences
        spin_loss = self.model(
            real_seq = real_seq,
            generated_seq = generated_seqs,
            prompt_len = prompt_len
        )

        return spin_loss

    def forward(self, overwrite_checkpoints: bool = True):
        """
        Algorithm 1 - https://arxiv.org/abs/2401.01335v1
        """

        # copy the policy weights into the reference model
        self.model.update_reference_model_with_policy()

        self.steps = 0

        # set the model dropout
        set_dropout_(self.model, self.dropout)

        train_dataloader_iter = cycle(self.train_dataloader)

        # self-play fine-tuning loop
        for _ in tqdm(range(self.num_train_steps), desc = 'spin fine-tuning'):
            self.model.train()

            # the forward contexts handle gradient syncing across the accumulation steps
            for forward_context in model_forward_contexts(self.accelerator, self.model, self.grad_accum_steps):
                with forward_context():
                    # fetch the real sequences and prompt lengths from the train dataloader
                    real_seq, prompt_len = next(train_dataloader_iter)

                    # compute the SPIN loss and backpropagate
                    train_loss = self.calc_spin_loss(real_seq, prompt_len)

                    self.accelerator.backward(train_loss / self.grad_accum_steps)

            self.print(f'train spin loss: {train_loss.item():.3f}')
            self.log(loss = train_loss.item())

            # step the optimizer
            self.optimizer.step()
            self.optimizer.zero_grad()

            self.steps += 1

            self.wait()

            # update the exponential moving average of the reference model
            self.unwrapped_model.update_ema()

            # periodically validate
            if exists(self.valid_dataloader) and not (self.steps % self.valid_every):
                self.wait()

                if self.is_main:
                    total_loss = 0.
                    total_batches = 0.

                    with torch.no_grad():
                        self.model.eval()

                        for valid_seq, prompt_len in tqdm(self.valid_dataloader, desc = 'valid spin'):
                            batch = valid_seq.shape[0]
                            valid_spin_loss = self.calc_spin_loss(valid_seq, prompt_len)

                            total_batches += batch
                            total_loss += valid_spin_loss * batch

                        valid_loss = total_loss / total_batches

                        self.print(f'valid spin loss: {valid_loss.item():.3f}')
                        self.log(valid_spin_loss = valid_loss.item())

            # periodically checkpoint
            if self.should_checkpoint and not (self.steps % self.checkpoint_every):
                checkpoint_num = self.steps // self.checkpoint_every
                self.save(f'spin.ckpt.{checkpoint_num}.pt', overwrite = overwrite_checkpoints)

        self.print('self-play training complete')
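A minimal end-to-end sketch of driving SPINTrainer (illustrative only: ToyLM and SFTPairs below are hypothetical stand-ins; any decoder-only model mapping token ids to per-position logits, and any dataset yielding (sequence, prompt_length) pairs as consumed by forward() above, should fit - whether a model this tiny samples sensibly is beside the point):

import torch
from torch import nn
from torch.utils.data import Dataset
from self_rewarding_lm_pytorch.spin import SPINTrainer

class ToyLM(nn.Module):
    # hypothetical toy language model: token ids in, logits out
    def __init__(self, num_tokens = 256, dim = 64):
        super().__init__()
        self.emb = nn.Embedding(num_tokens, dim)
        self.to_logits = nn.Linear(dim, num_tokens)

    def forward(self, x):
        return self.to_logits(self.emb(x))

class SFTPairs(Dataset):
    # hypothetical SFT dataset yielding (token ids, prompt length) pairs
    def __init__(self, num = 64, seq_len = 16):
        self.seqs = torch.randint(0, 256, (num, seq_len))
        self.prompt_lens = torch.randint(1, 8, (num,))

    def __len__(self):
        return self.seqs.shape[0]

    def __getitem__(self, idx):
        return self.seqs[idx], self.prompt_lens[idx]

trainer = SPINTrainer(
    ToyLM(),
    train_sft_dataset = SFTPairs(),
    max_seq_len = 16,
    batch_size = 4,
    epochs = 1
)
trainer()  # runs Algorithm 1: sample from the policy, compute the SPIN loss, step the optimizer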
.\lucidrains\self-rewarding-lm-pytorch\self_rewarding_lm_pytorch\__init__.py
# import the self-rewarding language model trainer and reward config
from self_rewarding_lm_pytorch.self_rewarding_lm_pytorch import (
    SelfRewardingTrainer,
    RewardConfig
)

# import the SPIN module and SPIN trainer
from self_rewarding_lm_pytorch.spin import (
    SPIN,
    SPINTrainer,
)

# import the DPO module and DPO trainer
from self_rewarding_lm_pytorch.dpo import (
    DPO,
    DPOTrainer,
)

# import the mock dataset helper
from self_rewarding_lm_pytorch.mocks import create_mock_dataset

# import the fine-tuning configs
from self_rewarding_lm_pytorch.self_rewarding_lm_pytorch import (
    SFTConfig,
    SelfRewardDPOConfig,
    ExternalRewardDPOConfig,
    SelfPlayConfig,
    create_default_paper_config
)
.\lucidrains\self-rewarding-lm-pytorch\setup.py
# import the setup and package discovery helpers
from setuptools import setup, find_packages

# package metadata
setup(
    name = 'self-rewarding-lm-pytorch',        # package name
    packages = find_packages(exclude=[]),      # include all packages
    version = '0.2.8',                         # version
    license='MIT',                             # license
    description = 'Self Rewarding LM - Pytorch',  # description
    author = 'Phil Wang',                      # author
    author_email = 'lucidrains@gmail.com',     # author email
    long_description_content_type = 'text/markdown',  # long description content type
    url = 'https://github.com/lucidrains/self-rewarding-lm-pytorch',  # project url
    keywords = [                               # keywords
        'artificial intelligence',
        'deep learning',
        'self rewarding',
        'direct preference optimization'
    ],
    install_requires=[                         # dependencies
        'accelerate',
        'beartype',
        'einops>=0.7.0',
        'einx[torch]>=0.1.3',
        'ema-pytorch>=0.3.3',
        'Jinja2',
        'numpy',
        'pytorch-custom-utils>=0.0.17',
        'torch>=2.0',
        'torchtyping',
        'tqdm'
    ],
    classifiers=[                              # classifiers
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
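Given the metadata above, installation would be (assuming the release is published to PyPI):

$ pip install self-rewarding-lm-pytorch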
Data source
The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/
Simple Hierarchical Transformer
Experiments around a simple idea for inducing multiple hierarchical predictive coding models within a GPT. It is so simple, it may not work. But then again, deep learning progress is built on the bedrocks of simple ideas. Worth a shot.
So far, the idea has passed the litmus test from a research friend. Will bring it to completion in the next week or so. If it does not work out, I'll leave the negative experimental results as well as the repository around, and maybe some PhD student can build upon it.
Update: I think it is working 🤞
Appreciation

- StabilityAI for the sponsorship to carry out this independent research
- 🤗 Huggingface for their accelerate library
Install
$ pip install simple-hierarchical-transformer
Usage
Three hierarchies, all servicing predicting the next token
import torch
from simple_hierarchical_transformer import HierarchicalTransformer
model = HierarchicalTransformer(
    num_tokens = 20000,            # number of tokens
    dim = 512,                     # model dimensions
    depth = 6,                     # depth
    dim_head = 64,                 # dimension per attention head
    heads = 8,                     # attention heads
    seq_len = 2048,                # sequence lengths
    hierarchies = (1, 2, 8),       # hierarchies - here we have 1x (like in a regular transformer), then 2x and 8x compressed hierarchical tokens that undergo their own transformer blocks. information is pooled into one hierarchy at each layer
    window_sizes = (32, 64, None)  # local attention window sizes - the idea is that the higher hierarchies can pass distant information to the local one. None stands for full receptive field. Setting 0 would turn off attention at this hierarchy altogether (while token shift will still be in effect in each layer)
)
ids = torch.randint(0, 20000, (1, 2048))
loss, _ = model(ids, return_loss = True)
loss.backward()
# after much training
logits = model(ids)
By not specifying hierarchies and window_sizes, you basically default to a regular autoregressive transformer with attention across full sequence length.
# non-hierarchical transformer
model = HierarchicalTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 8,
    dim_head = 64,
    heads = 8,
    seq_len = 2048,
    hierarchies = 1,        # implied 1 if not set
    window_sizes = None     # implied None (full sequence length) if not set
)
Now something more complex. Experiments show that as you compress up the hierarchies, you need greater model dimensions for appropriate capacity.
model = HierarchicalTransformer(
    num_tokens = 256,
    dim = (128, 256, 512, 1024),
    depth = 8,
    seq_len = 1024,
    use_flash_attn = True,
    ff_mult = (2, 2, 4, 4),
    dim_head = (16, 32, 64, 64),
    heads = (2, 4, 8, 8),
    hierarchies = (1, 2, 4, 16),
    hierarchical_stride = (1, 1, 1, 8),  # this would determine the stride when compressing, and when concatting the hierarchical tokens to the fine tokens, the past tokens will be repeated this amount of time. causality is not violated as using the trick from hourglass transformers where sequence is shifted by compression factor - 1. recommend sticking with 1 except for highly compressed hierarchies, as it becomes very uncompetitive with baseline and generations look off
    window_sizes = (16, 32, 64, None)
).cuda()

# hierarchies
# 1x  - dim 128  - attention (2 heads, 16 dim, receptive field 16)
# 2x  - dim 256  - attention (4 heads, 32 dim, receptive field 32)
# 4x  - dim 512  - attention (8 heads, 64 dim, receptive field 64)
# 16x - dim 1024 - attention (8 heads, 64 dim, receptive field of all)
Todo

- branch out to two parallel paths, one for hierarchical tokens, other for plain fine tokens
- show that local attention in fine + hierarchical tokens can come close to full attention baseline
- simple dsconv seems enough to merge for 1 hierarchy
- auto-set window size to be half of max sequence length for fine and all hierarchies
- figure out effects of just pooling all fine + hierarchical tokens before cross entropy loss - not much of a difference
- complete ability to add any number of hierarchies, and designate which hierarchy will pool the information from the others for prediction
- fully customizable dimensions across hierarchies, as higher hierarchies require greater model dimensions
- add prophet losses for hierarchical branches
- allow for repeating hierarchy tokens for fine tokens in the future, as position may matter less as one goes up the hierarchy. but not a priority, get things working first - implemented as hierarchical_stride
- allow for some layers to only rely on token shift, no attention
- random projections + vq, as was done in universal speech model paper from brain - for hierarchical predictive coding
- allow for specifying which hierarchy receives information from the others during merging, maybe design a specialized attention with masking, but need to account for different model dimensions across hierarchies
- build out simple local attention block, for use across all hierarchies
- add flash attention to local attention library
- figure out if attention can be shared across hierarchies
- do a clean wandb report showing 2x compression without much loss for character level enwik8
- try a self attention based compressor for hierarchies 4 or above
- build a small autoencoder using the token embeddings as input, at the very beginning of the network, and then use intermediate feature maps for each parallel hierarchical network
Citations
Closest idea would be hourglass transformers.
And my renewed interest in hierarchical approaches came from reading this.
@article{Nawrot2021HierarchicalTA,
title = {Hierarchical Transformers Are More Efficient Language Models},
author = {Piotr Nawrot and Szymon Tworkowski and Michal Tyrolski and Lukasz Kaiser and Yuhuai Wu and Christian Szegedy and Henryk Michalewski},
journal = {ArXiv},
year = {2021},
volume = {abs/2110.13711}
}
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
@misc{su2021roformer,
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
year = {2021},
eprint = {2104.09864},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
@inproceedings{Sun2022ALT,
title = {A Length-Extrapolatable Transformer},
author = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei},
year = {2022}
}
@software{peng_bo_2021_5196578,
author = {PENG Bo},
title = {BlinkDL/RWKV-LM: 0.01},
month = {aug},
year = {2021},
publisher = {Zenodo},
version = {0.01},
doi = {10.5281/zenodo.5196578},
url = {https://doi.org/10.5281/zenodo.5196578}
}
@article{Piergiovanni2023Mirasol3BAM,
title = {Mirasol3B: A Multimodal Autoregressive model for time-aligned and contextual modalities},
author = {A. J. Piergiovanni and Isaac Noble and Dahun Kim and Michael S. Ryoo and Victor Gomes and Anelia Angelova},
journal = {ArXiv},
year = {2023},
volume = {abs/2311.05698},
url = {https://api.semanticscholar.org/CorpusID:265129010}
}
.\lucidrains\simple-hierarchical-transformer\setup.py
# import the setup and package discovery helpers
from setuptools import setup, find_packages

# package metadata
setup(
    name = 'simple-hierarchical-transformer',  # package name
    packages = find_packages(exclude=[]),      # include all packages
    version = '0.2.0',                         # version
    license='MIT',                             # license
    description = 'Simple Hierarchical Transformer',  # description
    author = 'Phil Wang',                      # author
    author_email = 'lucidrains@gmail.com',     # author email
    long_description_content_type = 'text/markdown',  # long description content type
    url = 'https://github.com/lucidrains/simple-hierarchical-transformer',  # project url
    keywords = [                               # keywords
        'artificial intelligence',
        'deep learning',
        'transformers',
        'attention mechanism',
        'hierarchical'
    ],
    install_requires=[                         # dependencies
        'accelerate',
        'einops>=0.7.0',
        'local-attention',
        'torch>=2.0'
    ],
    classifiers=[                              # classifiers
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\simple-hierarchical-transformer\simple_hierarchical_transformer\attention.py
import torch
from torch import nn, einsum
from torch.nn import Module
import torch.nn.functional as F

from collections import namedtuple
from functools import wraps
from packaging import version

from einops import rearrange

# constants

# named tuple describing which scaled dot product attention kernels to enable
Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helper functions

def exists(val):
    # returns True if the value is not None
    return val is not None

def once(fn):
    # decorator that lets the wrapped function run only once
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# print a given message only once
print_once = once(print)

# main class

class Attend(Module):
    def __init__(
        self,
        causal = False,
        use_flash_attn = False
    ):
        super().__init__()
        # whether attention is causal
        self.causal = causal
        # register a (lazily built) causal mask buffer
        self.register_buffer("mask", None, persistent=False)

        # whether to use flash attention
        self.use_flash_attn = use_flash_attn
        assert not (use_flash_attn and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu
        self.cpu_config = Config(True, True, True)
        self.cuda_config = None

        # if cuda is unavailable or flash attention is off, nothing more to configure
        if not torch.cuda.is_available() or not use_flash_attn:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            # on an A100, prefer the flash kernel
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = Config(True, False, False)
        else:
            # otherwise fall back to the math or memory efficient kernels
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = Config(False, True, True)

    def get_mask(self, n, device):
        # reuse the cached causal mask if it is large enough
        if exists(self.mask) and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        # otherwise build an upper triangular causal mask and cache it
        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

        # expand the key padding mask, if given, to a shape compatible with sdpa
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        # pick the kernel config for the current device
        config = self.cuda_config if is_cuda else self.cpu_config

        # run pytorch 2.0 flash attention through the sdp kernel context
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                is_causal = self.causal
            )

        return out

    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        n, device = q.shape[-2], q.device
        scale = q.shape[-1] ** -0.5

        if self.use_flash_attn:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity
        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # key padding mask
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask
        if self.causal:
            causal_mask = self.get_mask(n, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention
        attn = sim.softmax(dim=-1)

        # aggregate values
        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out
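A small sketch of driving Attend directly, with the (batch, heads, seq, dim_head) layout assumed by the forward above:

import torch
from simple_hierarchical_transformer.attention import Attend

attend = Attend(causal = True, use_flash_attn = False)

q = torch.randn(1, 8, 1024, 64)   # (batch, heads, seq len, head dim)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

out = attend(q, k, v)             # (1, 8, 1024, 64), causally masked attention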
.\lucidrains\simple-hierarchical-transformer\simple_hierarchical_transformer\simple_hierarchical_transformer.py
from math import log2, ceil
from functools import partial
from itertools import zip_longest

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from simple_hierarchical_transformer.attention import Attend

from typing import Tuple

from local_attention import LocalMHA

# constants

# Linear without bias
Linear = partial(nn.Linear, bias = False)

# LocalMHA specialized to causal, prenorm usage
LocalMHA = partial(LocalMHA, causal = True, prenorm = True)

# helper functions

def exists(val):
    # returns True if the value is not None
    return val is not None

def is_power_of_two(n):
    # whether n is a power of two
    return log2(n).is_integer()

def all_unique(arr):
    # whether all elements of the list are unique
    return len(set(arr)) == len(arr)

def apply_fns(fns, tensors):
    # apply each function to the corresponding tensor
    return [fn(tensor) for fn, tensor in zip(fns, tensors)]

def cast_tuple(t, length = 1):
    # cast a value to a tuple of the given length
    return t if isinstance(t, tuple) else ((t,) * length)

def default(*vals):
    # return the first value that is not None
    for val in vals:
        if exists(val):
            return val
    return None

def eval_decorator(fn):
    # temporarily switch the model to eval mode for the wrapped call
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# tensor helpers

def l2norm(t):
    # l2 normalize along the feature dimension
    return F.normalize(t, dim = -1)

def cosine_sim_loss(x, y):
    # one minus the mean cosine similarity between x and y
    x, y = map(l2norm, (x, y))
    return 1. - einsum('b n d, b n d -> b n', x, y).mean()

# sampling helpers

def log(t, eps = 1e-20):
    # numerically safe log
    return t.clamp(min = eps).log()

def gumbel_noise(t):
    # gumbel noise of the same shape as t
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

def gumbel_sample(t, temperature = 1., dim = -1):
    # sample from logits by adding gumbel noise at the given temperature
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)

def top_k(logits, thres = 0.9):
    # keep only the top (1 - thres) fraction of logits, setting the rest to -max
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, -torch.finfo(logits.dtype).max)
    probs.scatter_(1, ind, val)
    return probs
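# illustrative: with thres = 0.9 and a vocab of 256, k = int(0.1 * 256) = 25,
# so all but the top 25 logits are suppressed before gumbel_sample picks a token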
# rotary positional embedding (with xpos scaling)

class RotaryEmbedding(Module):
    def __init__(
        self,
        dim,
        scale_base = 512,
        use_xpos = True
    ):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        self.use_xpos = use_xpos
        self.scale_base = scale_base
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.register_buffer('scale', scale)

    @property
    def device(self):
        # device the buffers live on
        return next(self.buffers()).device

    @autocast(enabled = False)
    def forward(self, seq_len):
        device = self.device
        t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)

        if not self.use_xpos:
            return freqs, torch.ones(1, device = device)

        power = (t - (seq_len // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim = -1)

        return freqs, scale

def rotate_half(x):
    # rotate the two halves of the last dimension: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(pos, t, scale = 1.):
    # apply the rotary positional embedding to the tensor t
    seq_len = t.shape[-2]
    pos = pos[..., -seq_len:, :]
    if not isinstance(scale, (int, float)):
        scale = scale[..., -seq_len:, :]
    return (t * pos.cos() * scale) + (rotate_half(t) * pos.sin() * scale)

@autocast(enabled = False)
def apply_rotary_pos_emb_qk(rotary_emb, q, k):
    # apply rotary embeddings to queries and keys
    freqs, scale = rotary_emb
    q = apply_rotary_pos_emb(freqs, q, scale)
    k = apply_rotary_pos_emb(freqs, k, scale ** -1)
    return q, k
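# note: with use_xpos, queries are scaled by `scale` and keys by `scale ** -1`,
# so their dot product decays with relative distance - the trick from the
# length-extrapolatable transformer (Sun et al., cited in the readme above)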
# token shift

def token_shift(t):
    t, t_shift = t.chunk(2, dim = -1)
    t_shift = F.pad(t_shift, (0, 0, 1, -1))
    return torch.cat((t, t_shift), dim = -1)
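# illustrative: half the feature channels are shifted forward one position, so
# each token also sees the previous token's shifted features
#   x = torch.randn(1, 4, 8)
#   y = token_shift(x)   # y[:, 1:, 4:] == x[:, :-1, 4:], and y[:, 0, 4:] == 0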
# hierarchy related classes

def pad_seq_to_multiple(t, mult):
    seq_len = t.shape[-2]
    # round the sequence length up to the next multiple
    next_seq_len_mult = ceil(seq_len / mult) * mult
    remainder = next_seq_len_mult - seq_len

    # nothing to pad
    if remainder == 0:
        return t, seq_len

    # pad the sequence on the right, returning the original length as well
    t = F.pad(t, (0, 0, 0, remainder), value = 0.)
    return t, seq_len

def curtail_seq_to_multiple(t, mult):
    seq_len = t.shape[-2]
    # round the sequence length down to the previous multiple
    prev_seq_len_mult = (seq_len // mult) * mult
    remainder = seq_len - prev_seq_len_mult

    # nothing to truncate
    if remainder == 0:
        return t

    # truncate the sequence
    t = t[..., :prev_seq_len_mult, :]
    return t

def hierarchical_cat(tokens, strides: Tuple[int, ...]):
    # concatenate tokens from all hierarchies along the feature dimension,
    # repeating coarser tokens according to their strides so lengths align
    assert len(tokens) == len(strides)

    # with all strides of 1, a plain concatenation suffices
    if all([s == 1 for s in strides]):
        return torch.cat(tokens, dim = -1)

    # repeat each hierarchy's tokens to match its stride
    tokens = [repeat(t, 'b n d -> b (n s) d', s = s) for t, s in zip(tokens, strides)]

    # truncate all hierarchies to the shortest length, then concatenate
    min_seq_len = min([t.shape[-2] for t in tokens])
    tokens = [t[..., :min_seq_len, :] for t in tokens]
    return torch.cat(tokens, dim = -1)
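# illustrative: with strides (1, 2), a 2x-compressed hierarchy of length n/2 has
# each of its tokens repeated twice to realign with the fine tokens:
#   fine:   f0 f1 f2 f3        (stride 1)
#   coarse: c0 c0 c1 c1        (stride 2, each token repeated)
# both are then truncated to the shortest length and concatenated on features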
class CausalConv(Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        kernel_size,
        stride = 1
    ):
        super().__init__()
        # pad only on the left so the convolution stays causal
        self.causal_padding = kernel_size - 1
        self.conv = nn.Conv1d(dim_in, dim_out, kernel_size, stride = stride)

    def forward(self, x):
        x = F.pad(x, (self.causal_padding, 0))
        return self.conv(x)

class Compress(Module):
    def __init__(
        self,
        *,
        dim,
        dim_out,
        num_tokens = None,
        stride = 1,
        compress_factor = 1,
        expansion_factor = 4,
        dim_head = 64,
        heads = 8,
        ignore_index = 0,
        should_recon = False
    ):
        super().__init__()
        # the compression factor must be a positive power of two
        assert compress_factor > 0 and is_power_of_two(compress_factor)

        self.stride = stride
        self.no_compress = compress_factor == 1
        self.compress_factor = compress_factor

        self.should_recon = should_recon

        # with no compression, just project (or pass through) the embeddings
        if self.no_compress:
            self.compress_fn = Linear(dim, dim_out) if dim != dim_out else nn.Identity()
            return

        dim_inner = int(dim * expansion_factor)

        # compress with a causal convolution followed by a pointwise projection
        self.compress_fn = nn.Sequential(
            Rearrange('b n d -> b d n'),
            CausalConv(dim, dim_inner, compress_factor, stride = stride),
            nn.SiLU(),
            nn.Conv1d(dim_inner, dim_out, 1),
            Rearrange('b d n -> b n d')
        )

        # optional head for reconstructing the compressed-away tokens
        if should_recon:
            assert exists(num_tokens)
            self.to_recon = Linear(dim_out, compress_factor * num_tokens)

        self.ignore_index = ignore_index

    def recon(self, h, ids):
        # auxiliary reconstruction loss: predict the original tokens back from h
        assert self.should_recon

        if self.no_compress:
            return torch.zeros((), device = h.device).requires_grad_()

        c = self.compress_factor
        seq_len = ids.shape[-1]

        recon_logits = self.to_recon(h)
        recon_logits = rearrange(recon_logits, 'b n (c d) -> (b c) d n', c = c)

        recon_ids = F.pad(ids, (c - 1, 0), value = self.ignore_index)
        recon_ids = tuple(recon_ids[:, i:(seq_len + i)] for i in range(c))
        recon_ids = torch.stack(recon_ids, dim = 1)
        recon_ids = rearrange(recon_ids, 'b c n -> (b c) n')

        if self.stride > 1:
            recon_ids = recon_ids[..., ::self.stride]

        recon_loss = F.cross_entropy(recon_logits, recon_ids, ignore_index = self.ignore_index)
        return recon_loss

    def forward(self, x):
        return self.compress_fn(x)

class HierarchicalMerge(Module):
    def __init__(
        self,
        dims: Tuple[int, ...],
        dim_out,
        h_strides = 1
    ):
        super().__init__()
        dim = sum(dims)

        strides = cast_tuple(h_strides, len(dims))
        assert len(strides) == len(dims)
        self.strides = strides

        # fuse the concatenated hierarchy embeddings down to dim_out
        self.net = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, dim_out * 2),
            nn.SiLU(),
            nn.Linear(dim_out * 2, dim_out)
        )

    def forward(self, tokens):
        # align and concatenate the hierarchies, then project
        x = hierarchical_cat(tokens, self.strides)
        return self.net(x)

class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        # scaling factor
        self.scale = dim ** 0.5
        # learned gain
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # l2 normalize the features, then rescale by sqrt(dim) and the learned gamma
        return F.normalize(x, dim=-1) * self.scale * self.gamma
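# i.e. RMSNorm(x) = gamma * sqrt(d) * x / ||x||_2, which equals the usual
# gamma * x / rms(x), since ||x||_2 = sqrt(d) * rms(x)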
class FeedForward(Module):
    def __init__(self, dim, mult = 4):
        super().__init__()
        dim_inner = int(dim * mult)

        self.net = nn.Sequential(
            RMSNorm(dim),
            Linear(dim, dim_inner),
            nn.GELU(),
            Linear(dim_inner, dim)
        )

    def forward(self, x):
        return self.net(x)

class Attention(Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        use_flash_attn = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        dim_inner = dim_head * heads

        self.norm = RMSNorm(dim)
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.attend = Attend(causal = True, use_flash_attn = use_flash_attn)

        # projections for queries / keys / values and output
        self.to_qkv = Linear(dim, dim_inner * 3)
        self.to_out = Linear(dim_inner, dim)

    def forward(self, x):
        n = x.shape[-2]
        x = self.norm(x)

        # project to q, k, v and split out the heads
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        # apply rotary positional embeddings to queries and keys
        rotary_emb = self.rotary_emb(n)
        q, k = apply_rotary_pos_emb_qk(rotary_emb, q, k)

        out = self.attend(q, k, v)

        # merge the heads and project out
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class HierarchicalBlock(Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        window_size = None,
        compress_factor = 1,
        stride = 1,
        ff_mult = 4
    ):
        super().__init__()
        self.stride = stride

        assert is_power_of_two(compress_factor)
        self.compress_factor = compress_factor
        self.no_compress = compress_factor == 1

        assert not exists(window_size) or window_size >= 0

        # a window size of 0 disables attention at this hierarchy entirely;
        # a set window size selects local attention, otherwise full attention
        self.has_attn = window_size != 0

        self.attn = None

        if self.has_attn:
            attn_klass = Attention
            if exists(window_size):
                attn_klass = partial(LocalMHA, window_size = window_size)

            self.attn = attn_klass(dim = dim, dim_head = dim_head, heads = heads)

        self.ff = FeedForward(dim = dim, mult = ff_mult)

    def forward(self, x):
        c = self.compress_factor
        axial_dim = c // self.stride

        # pad the sequence out to a multiple of the (strided) compression factor
        x, orig_seq_len = pad_seq_to_multiple(x, axial_dim)

        # for compressed hierarchies, fold the compression axis into the batch
        if not self.no_compress:
            x = rearrange(x, 'b (n c) d -> (b c) n d', c = axial_dim)

        # attention, if enabled at this hierarchy
        if exists(self.attn):
            x = self.attn(token_shift(x)) + x

        # feedforward
        x = self.ff(token_shift(x)) + x

        # unfold back out of the batch dimension
        if not self.no_compress:
            x = rearrange(x, '(b c) n d -> b (n c) d', c = axial_dim)

        # trim back to the original sequence length
        return x[:, :orig_seq_len]

class HierarchicalTransformer(Module):
    def __init__(
        self,
        *,
        num_tokens,                           # vocabulary size
        dim,                                  # model dimension(s)
        depth,                                # number of layers
        seq_len = 2048,                       # maximum sequence length
        dim_head = 64,                        # dimension per attention head
        heads = 8,                            # number of attention heads
        ff_mult = 4,                          # feedforward expansion factor
        hierarchies = 1,                      # compression factor per hierarchy
        window_sizes = None,                  # local attention window sizes
        hierarchical_stride = 1,              # stride used when compressing
        hierarchy_merge_all = False,          # whether to pass the pooled hierarchical information back to all hierarchies, or only the predicting one
        predict_hierarchy = None,             # which hierarchy makes the final prediction
        predict_use_all_hierarchy = False,    # whether to predict from all hierarchies concatenated
        recon_loss_weight = 0.1,              # weight of the reconstruction loss
        hierarchical_ar_loss_weight = 0.25,   # weight of the hierarchical autoregressive loss
        ignore_index = 0,                     # index ignored by the losses
        use_flash_attn = False,               # whether to use flash attention
    ):
        # (initialization body omitted in this excerpt)
        ...

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        prompt,
        seq_len,
        temperature = 1.0,
        filter_thres = 0.9,
        **kwargs
    ):
        b, t, device = *prompt.shape, prompt.device

        out = prompt

        # autoregressively sample one token at a time
        for _ in range(seq_len):
            logits = self.forward(out[:, -self.seq_len:], **kwargs)[:, -1]
            filtered_logits = top_k(logits, thres = filter_thres)
            sample = gumbel_sample(filtered_logits, temperature = temperature)
            sample = rearrange(sample, 'b -> b 1')
            out = torch.cat((out, sample), dim = -1)

        # return only the newly generated tokens
        return out[:, t:]

    @property
    def device(self):
        # device the model parameters live on
        return next(self.parameters()).device

    def forward(
        self,
        ids,
        return_loss = False,
        return_hierarchical_token_embeds = False,
        return_hierarchical_embeds = False,
        ablate_hierarchical_merge = False
    ):
        """
        einops notation:

        b - batch
        n - sequence length
        c - compression factor
        d - dimension
        """

        # if training, predict the next token in the sequence
        if return_loss:
            ids, labels = ids[:, :-1], ids[:, 1:]

        # assert the sequence length
        assert ids.shape[-1] <= self.seq_len

        # token embeddings
        x = self.token_emb(ids)

        # for each hierarchy, appropriately compress the token embeddings into hierarchical embeddings
        tokens = []
        for compress in self.compressors:
            tokens.append(compress(x))

        # post-embedding norms
        tokens = apply_fns(self.post_token_emb_norms, tokens)

        # if one wants all the compressed token embeddings, just for research
        if return_hierarchical_token_embeds:
            return tokens

        # layers
        for layer, merge in zip_longest(self.layers, self.hierarchical_merges):
            tokens = apply_fns(layer, tokens)

            # pool information across all hierarchies, then update the tokens
            # that will be used for the final autoregressive prediction
            if not self.need_hierarchical_merge or ablate_hierarchical_merge:
                continue

            pooled = merge(tokens)

            if self.hierarchy_merge_all:
                tokens = [(t + p[..., ::s, :]) for t, p, s in zip(tokens, pooled.split(self.dims, dim = -1), self.h_strides)]
            else:
                predict_tokens = tokens[self.predict_hierarchy_index]
                predict_tokens = predict_tokens + pooled
                tokens[self.predict_hierarchy_index] = predict_tokens

        # final norms on the embeddings
        embeds = apply_fns(self.norms, tokens)

        # if one wants all the normalized hierarchical embeds
        if return_hierarchical_embeds:
            return embeds

        # select the hierarchical embeddings that will do the predicting
        if self.predict_use_all_hierarchy:
            predict_embed = hierarchical_cat(embeds, self.h_strides)
        else:
            predict_embed = embeds[self.predict_hierarchy_index]

        # logits for predicting the next token
        logits = self.to_logits(predict_embed)

        if not return_loss:
            return logits

        # autoregressive (predictive coding) loss
        logits = rearrange(logits, 'b n c -> b c n')
        ce_loss = F.cross_entropy(logits, labels, ignore_index = self.ignore_index)

        # reconstruction losses for the hierarchy tokens
        recon_losses = self.zeros.requires_grad_()

        if self.should_recon:
            for compress, t in zip(self.compressors, embeds):
                recon_loss = compress.recon(t, ids)
                recon_losses = recon_losses + recon_loss

        # hierarchical autoregressive losses
        hierarchical_ar_losses = self.zeros.requires_grad_()

        for h_embed, maybe_h_pred_linear in zip(embeds, self.to_hierarchical_preds):
            if not exists(maybe_h_pred_linear):
                continue

            h_pred = maybe_h_pred_linear(h_embed)
            h_ar_loss = cosine_sim_loss(h_pred[:, :-1], h_embed[:, 1:])

            hierarchical_ar_losses = hierarchical_ar_losses + h_ar_loss

        # total loss
        total_loss = ce_loss + \
                     recon_losses * self.recon_loss_weight + \
                     hierarchical_ar_losses * self.hierarchical_ar_loss_weight

        return total_loss, (ce_loss, recon_losses, hierarchical_ar_losses)
.\lucidrains\simple-hierarchical-transformer\simple_hierarchical_transformer\__init__.py
# import the HierarchicalTransformer class
from simple_hierarchical_transformer.simple_hierarchical_transformer import HierarchicalTransformer
.\lucidrains\simple-hierarchical-transformer\train.py
# imports
import gzip
import random
import tqdm
import numpy as np
import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# the simple hierarchical transformer model
from simple_hierarchical_transformer import HierarchicalTransformer

# accelerator
from accelerate import Accelerator

accelerator = Accelerator()

# device and print helper from the accelerator
device = accelerator.device
acc_print = accelerator.print

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 2
GRADIENT_ACCUMULATE_EVERY = 8
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
PRIME_LENGTH = 128
GENERATE_EVERY = 500
SEQ_LEN = 2048
GENERATE_LENGTH = 1024

# helpers

def cycle(loader):
    # endlessly cycle the batches of a dataloader
    while True:
        for data in loader:
            yield data

def decode_token(token):
    # decode a byte token back to a printable character
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))

# instantiate the transformer

model = HierarchicalTransformer(
    num_tokens = 256,
    dim = 1024,
    depth = 8,
    seq_len = SEQ_LEN,
    hierarchies = (1, 2),
    window_sizes = (32, 64),
    use_flash_attn = True
).to(device)

# prepare enwik8 data

with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

# dataset class

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # sample a random contiguous window of seq_len + 1 tokens
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.to(device)

    def __len__(self):
        return self.data.size(0) // self.seq_len

# training and validation sets

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# optimizer

optim = Adam(model.parameters(), lr = LEARNING_RATE)

# prepare model, optimizer and dataloaders

model, optim, train_loader, val_loader = accelerator.prepare(
    model, optim, train_loader, val_loader
)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model.train()

    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss, (ce_loss, recon_loss, prophet_loss) = model(next(train_loader), return_loss = True)
        accelerator.backward(loss / GRADIENT_ACCUMULATE_EVERY)

    acc_print(f"training loss: {ce_loss.item()}")
    accelerator.clip_grad_norm_(model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            _, (ce_loss, *_) = model(next(val_loader), return_loss = True)
            acc_print(f"validation loss: {ce_loss.item()}")

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        prime = decode_tokens(inp)
        acc_print(f"{prime} \n\n {'*' * 100}")

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        acc_print(output_str, "\n")
Data source
The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/
Sinkhorn Transformer with Deepspeed for Enwik8

Deepspeed is the framework Microsoft used to train the world's largest attention model (17B parameters) to date. They have open sourced it, and it works with Sinkhorn Transformers!

- First install Deepspeed following instructions from their official repository github.com/microsoft/D…
- Run the following command in this folder

$ deepspeed train.py --deepspeed --deepspeed_config ds_config.json
.\lucidrains\sinkhorn-transformer\examples\enwik8_deepspeed\train.py
import deepspeed

from sinkhorn_transformer import SinkhornTransformerLM
from sinkhorn_transformer.autoregressive_wrapper import AutoregressiveWrapper

import argparse
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

def add_argument():
    # command line arguments for the deepspeed launcher
    parser = argparse.ArgumentParser(description='enwik8')

    parser.add_argument('--with_cuda', default=False, action='store_true',
                        help='use CPU in case there\'s no GPU support')
    parser.add_argument('--use_ema', default=False, action='store_true',
                        help='whether use exponential moving average')
    parser.add_argument('-b', '--batch_size', default=32, type=int,
                        help='mini-batch size (default: 32)')
    parser.add_argument('-e', '--epochs', default=30, type=int,
                        help='number of total epochs (default: 30)')
    parser.add_argument('--local_rank', type=int, default=-1,
                        help='local rank passed from distributed launcher')

    # let deepspeed add its own config arguments
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()
    return args

# constants

VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 1024
SEQ_LEN = 4096

# helpers

def decode_token(token):
    # map a byte value back to a printable character (values below 32 become spaces)
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# instantiate model

model = SinkhornTransformerLM(
    num_tokens = 256,
    emb_dim = 128,
    dim = 512,
    depth = 8,
    max_seq_len = SEQ_LEN,
    heads = 8,
    bucket_size = 128,
    ff_chunks = 10,
    causal = True,
    reversible = True,
    attn_dropout = 0.1,
    n_local_attn_heads = 4
)

# wrap for autoregressive training / generation, then move to gpu
model = AutoregressiveWrapper(model)
model.cuda()

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # sample a random contiguous window of seq_len + 1 tokens
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq

    def __len__(self):
        return self.data.size(0) // self.seq_len

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)

# setup deepspeed

cmd_args = add_argument()
model_engine, optimizer, trainloader, _ = deepspeed.initialize(args=cmd_args, model=model, model_parameters=model.parameters(), training_data=train_dataset)

# training

for i, data in enumerate(trainloader):
    model_engine.train()
    data = data.to(model_engine.local_rank)

    loss = model_engine(data, return_loss = True)
    model_engine.backward(loss)
    model_engine.step()
    print(loss.item() * 4)

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            inp = random.choice(val_dataset)
            loss = model(inp[None, :].cuda(), return_loss = True)
            print(f'validation loss: {loss.item()}')

    if model_engine.local_rank == 0 and i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f"{prime} \n\n {'*' * 100}")

        sample = model.generate(inp.cuda(), GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)
Data source
The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/
.\lucidrains\sinkhorn-transformer\examples\enwik8_simple\train.py
# imports
from sinkhorn_transformer import SinkhornTransformerLM
from sinkhorn_transformer.autoregressive_wrapper import AutoregressiveWrapper

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 4096

# helpers

def cycle(loader):
    # endlessly cycle the batches of a dataloader
    while True:
        for data in loader:
            yield data

def decode_token(token):
    # decode a byte token back to a character
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    # decode a sequence of tokens back to a string
    return ''.join(list(map(decode_token, tokens)))

# instantiate model

model = SinkhornTransformerLM(
    num_tokens = 256,
    emb_dim = 128,
    dim = 512,
    depth = 8,
    max_seq_len = SEQ_LEN,
    heads = 8,
    bucket_size = 128,
    ff_chunks = 2,
    causal = True,
    reversible = True,
    attn_dropout = 0.1,
    n_local_attn_heads = 4
)

model = AutoregressiveWrapper(model)
model.cuda()

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

# dataset class

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# training and validation dataloaders

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss = True)
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader), return_loss = True)
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f"{prime} \n\n {'*' * 100}")

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)