Lucidrains 系列项目源码解析(十一)
.\lucidrains\byol-pytorch\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# Package metadata for the byol-pytorch distribution
setup(
    # distribution name
    name = 'byol-pytorch',
    # include every discovered package except 'examples'
    packages = find_packages(exclude=['examples']),
    # version string
    version = '0.8.0',
    # license type
    license='MIT',
    # short description
    description = 'Self-supervised contrastive learning made simple',
    # author
    author = 'Phil Wang',
    # author e-mail
    author_email = 'lucidrains@gmail.com',
    # project homepage
    url = 'https://github.com/lucidrains/byol-pytorch',
    # content type of the long description
    long_description_content_type = 'text/markdown',
    # search keywords
    keywords = [
        'self-supervised learning',
        'artificial intelligence'
    ],
    # runtime dependencies
    install_requires=[
        'accelerate',
        'beartype',
        'torch>=1.6',
        'torchvision>=0.8'
    ],
    # trove classifiers
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\CALM-pytorch\CALM_pytorch\CALM.py
# 从 math 模块中导入 ceil 函数
from math import ceil
# 从 pathlib 模块中导入 Path 类
from pathlib import Path
# 从 functools 模块中导入 partial 函数
from functools import partial
# 从 contextlib 模块中导入 nullcontext 和 contextmanager 函数
from contextlib import nullcontext, contextmanager
# 从 dataclasses 模块中导入 dataclass 装饰器
from dataclasses import dataclass
# 导入 torch 库
import torch
# 从 torch.nn.functional 模块中导入 F 别名
import torch.nn.functional as F
# 从 torch.nn 模块中导入 Module 和 ModuleList 类
from torch.nn import Module, ModuleList
# 从 torch.utils.data 模块中导入 Dataset 和 DataLoader 类
from torch.utils.data import Dataset, DataLoader
# 从 torch.optim.lr_scheduler 模块中导入 _LRScheduler 类
from torch.optim.lr_scheduler import _LRScheduler
# 从 torch 模块中导入 nn、einsum 和 Tensor 类
from torch import nn, einsum, Tensor
# 导入 beartype 库
from beartype import beartype
from beartype.door import is_bearable
# 从 beartype.typing 模块中导入 List、Optional、Callable、Type、Tuple、Union、Literal 类型
from beartype.typing import List, Optional, Callable, Type, Tuple, Union, Literal
# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 从 x_transformers.x_transformers 模块中导入 RMSNorm、Attention 和 TransformerWrapper 类
from x_transformers.x_transformers import (
RMSNorm,
Attention,
TransformerWrapper,
)
# 导入 accelerate 库
from accelerate import Accelerator
# 从 pytorch_custom_utils 模块中导入 OptimizerWithWarmupSchedule、get_adam_optimizer 和 auto_unwrap_model 函数
from pytorch_custom_utils import (
OptimizerWithWarmupSchedule,
get_adam_optimizer,
auto_unwrap_model
)
# 从 pytorch_custom_utils.accelerate_utils 模块中导入 model_forward_contexts 函数
from pytorch_custom_utils.accelerate_utils import (
model_forward_contexts
)
# 从 CALM_pytorch.sampling_utils 模块中导入 sample、top_p 和 top_k 函数
# types
# A "sequence" here is either a tuple or a list
Sequence = Union[Tuple, List]
# Which side of a hooked module supplies the hidden state: its 'input' or its 'output'
HiddenPosition = Union[Literal['input'], Literal['output']]
# Type factory: a tuple (any length) or a list whose elements are all of type `t`
def SequenceOf(t):
    """Return the type `Union[Tuple[t, ...], List[t]]`."""
    as_tuple = Tuple[t, ...]
    as_list = List[t]
    return Union[as_tuple, as_list]
# Type factory: either a single `t`, or a tuple / list of `t`
def SingularOrMany(t):
    """Return the type `Union[t, Tuple[t, ...], List[t]]`."""
    many = Union[Tuple[t, ...], List[t]]
    return Union[t, many]
# helpers
# Tell whether a value was actually provided (i.e. is not None)
def exists(v):
    """Return True for every value except None (falsy values such as 0 still count)."""
    if v is None:
        return False
    return True
# Fall back to `d` only when `v` is None
def default(v, d):
    """Return `v` unless it is None, in which case return `d`."""
    if v is not None:
        return v
    return d
# Logical XNOR: true when both operands agree
def xnor(x, y):
    """Return True when `x ^ y` is falsy, i.e. the two operands are not exclusive."""
    exclusive = x ^ y
    return not exclusive
# Broadcast a single value into a tuple; tuples and lists pass through untouched
def cast_tuple(t, length = 1):
    """Return `t` as-is when it is already a tuple or list, else `(t,) * length`."""
    if is_bearable(t, Sequence):
        return t
    return (t,) * length
# Pick the hidden-state tensor out of the arguments a forward hook receives
def get_block_output_from_hook_outputs(
    hidden_position: HiddenPosition,
    _, inp, out
):
    """Return the tensor found at the requested side of a hooked module.

    `hidden_position` selects `out` when it equals 'output', otherwise `inp`.
    If the selected value is a tuple, its first element is used. The result
    must be a tensor (checked with an assert).
    """
    if hidden_position == 'output':
        candidate = out
    else:
        candidate = inp

    # hook inputs (and some module outputs) arrive wrapped in a tuple
    if isinstance(candidate, tuple):
        candidate = candidate[0]

    assert torch.is_tensor(candidate)
    return candidate
# freezing llms
# Switch gradient tracking on or off for every parameter of a module
@beartype
def set_module_requires_grad_(
    module: Module,
    requires_grad: bool
):
    """Set `requires_grad` on all parameters of `module`, in place."""
    for parameter in module.parameters():
        parameter.requires_grad_(requires_grad)
# Freeze a module: none of its parameters will receive gradients
def freeze_all_layers_(module):
    """Disable gradients for every parameter of `module`, in place."""
    set_module_requires_grad_(module, requires_grad = False)
# function for returning an ordered list of modules, where the output of the module is the output of that transformer block layer
# ex. for x-transformers TransformerWrapper
@beartype
def x_transformer_blocks(transformer: TransformerWrapper) -> List[Module]:
    """Return the last entry of every second layer of `transformer.attn_layers.layers`.

    The last element of each layer entry is collected, then only the entries
    at odd positions (1, 3, 5, ...) are kept.
    """
    last_of_each_layer = [layer[-1] for layer in transformer.attn_layers.layers]
    return last_of_each_layer[1::2]
# helper classes
# Forward-hook callable that stores the hidden states (and optionally the modules) it sees
class Recorder:
    """Callable meant to be registered as a forward hook.

    Each call appends the detached hidden state to `self.output`, and the
    hooked module to `self.modules` when a module list was supplied.
    """

    @beartype
    def __init__(
        self,
        outputs: Optional[List] = None,
        forward_hook_get_hidden: HiddenPosition = 'output',
        modules: Optional[List] = None,
    ):
        # reuse the caller's list when given so several recorders can share one buffer
        self.output = outputs if exists(outputs) else []
        self.modules = modules
        self.get_output_fn = partial(get_block_output_from_hook_outputs, forward_hook_get_hidden)

    def __call__(self, *args):
        # args are the hook arguments; the first one is the hooked module
        if self.modules is not None:
            self.modules.append(args[0])

        self.output.append(self.get_output_fn(*args).detach())
# Wraps a model and returns the hidden states captured by forward hooks on the given blocks
class ExtractHiddensWrapper(Module):
    """Run `model` and return what the hooks on `blocks` recorded.

    One `Recorder` is registered as a forward hook per block; all recorders
    write into the same `self.outputs` / `self.modules` lists.
    NOTE(review): the `self.modules` list attribute shadows `nn.Module.modules()`.
    """

    @beartype
    def __init__(
        self,
        model: Module,
        blocks: List[Module],
        hidden_positions: SingularOrMany(HiddenPosition) = 'output'
    ):
        super().__init__()

        # a single position is broadcast to every block
        hidden_positions = cast_tuple(hidden_positions, len(blocks))
        assert len(hidden_positions) == len(blocks)

        self.model = model

        self.outputs = []
        self.modules = []
        self.recorders = []

        for block, position in zip(blocks, hidden_positions):
            hook = Recorder(self.outputs, position, self.modules)
            self.recorders.append(hook)
            block.register_forward_hook(hook)

    def forward(self, *args, return_hooked_modules = False, **kwargs):
        """Call the wrapped model and return the recorded hiddens (plus modules on request)."""
        # the model's own return value is discarded; only the hook recordings matter
        self.model(*args, **kwargs)

        # snapshot the shared buffers, then empty them for the next call
        recorded_outputs = list(self.outputs)
        recorded_modules = list(self.modules)

        self.outputs.clear()
        self.modules.clear()

        if return_hooked_modules:
            return recorded_outputs, recorded_modules

        return recorded_outputs
# Cross attention block whose forward takes forward-hook arguments and attends to a stored context
class CrossAttentionBlock(Module):
    """Residual cross attention from a hooked hidden state to `self.context`.

    `self.context` must be assigned before `forward` is called; an optional
    context mask can be installed with `set_mask` / removed with `unset_mask`.
    """

    @beartype
    def __init__(
        self,
        dim,
        dim_context,
        linear_project_context = True,  # the paper projects the augment hidden states; unclear whether it is needed, but stay faithful first
        pre_rmsnorm = False,
        forward_hook_get_hidden: Union[
            Literal['output'],
            Literal['input']
        ] = 'output',
        **kwargs
    ):
        super().__init__()

        # optional RMS normalization applied to the query side before attention
        if pre_rmsnorm:
            self.pre_rmsnorm = RMSNorm(dim)
        else:
            self.pre_rmsnorm = nn.Identity()

        self.dim = dim
        self.dim_context = dim_context

        # when projecting, the context is mapped to `dim`, so attention sees `dim` as its context width
        self.context_proj = nn.Linear(dim_context, dim) if linear_project_context else None
        attn_dim_context = dim if linear_project_context else dim_context

        self.attn = Attention(
            dim = dim,
            dim_context = attn_dim_context,
            zero_init_output = True,
            gate_value_heads = True,
            **kwargs
        )

        self.context = None
        self.context_mask = None
        self.forward_hook_get_hidden = forward_hook_get_hidden

    def set_mask(self, mask: Tensor):
        """Install the mask passed to attention as `context_mask`."""
        self.context_mask = mask

    def unset_mask(self):
        """Remove any installed context mask."""
        self.context_mask = None

    def forward(self, *hook_args):
        """Return `attn(norm(x), context) + x` for the hidden state found in `hook_args`."""
        hidden = get_block_output_from_hook_outputs(self.forward_hook_get_hidden, *hook_args)

        context = self.context
        assert exists(context)

        # gradients are explicitly enabled only while this block is in training mode
        grad_context = nullcontext
        if self.training:
            grad_context = torch.enable_grad

        with grad_context():
            normed = self.pre_rmsnorm(hidden)

            if self.context_proj is not None:
                context = self.context_proj(context)

            attended = self.attn(normed, context, context_mask = self.context_mask)
            return attended + hidden
# main classes
# Per-augmentation-model settings handed to CALM
@dataclass
class AugmentParams:
    # the augmentation model
    model: Module
    # 'input' / 'output' (or one value per block): which side of a hooked block supplies the hidden state
    hidden_position: SingularOrMany(HiddenPosition) = 'output'
    # explicit list of blocks; how it is consumed is not visible in this excerpt
    transformer_blocks: Optional[List[Module]] = None
    # callable mapping the model to its list of blocks
    extract_blocks_fn: Optional[Callable[[Module], List[Module]]] = None
    # NOTE(review): consumer not visible here — presumably the model already returns hiddens itself; confirm
    model_return_hiddens: bool = False
    # NOTE(review): consumer not visible here — presumably the shape of a dummy input (without batch); confirm
    input_shape: Optional[Tuple[int, ...]] = None
    # pairs of integers; CALM.forward_augments reads the first of each pair as a 1-based augment layer index
    connections: Optional[Tuple[Tuple[int, int], ...]] = None
    connect_every_num_layers: int = 4 # the paper connects every 4 layers
    # name of the keyword under which CALM.forward_augments passes the prompt mask to the model
    mask_kwarg: Optional[str] = None
# CALM 类
class CALM(Module):
# 初始化函数
@beartype
def __init__(
self,
anchor_llm: Module,
augment_llms: SingularOrMany(AugmentParams),
*,
attn_kwargs: dict = dict(
linear_project_context = True,
pre_rmsnorm = True,
flash = True
),
anchor_extract_blocks_fn: Callable[[Module], List[Module]] = None,
anchor_transformer_blocks: Optional[List[Module]] = None,
anchor_hidden_position: SingularOrMany(HiddenPosition) = 'output',
pad_id: int = -1
def state_dict(self, *args, **kwargs):
    """Return only the state of the cross attention blocks.

    Positional and keyword arguments (e.g. `destination`, `prefix`,
    `keep_vars` of `nn.Module.state_dict`) are forwarded unchanged, so callers
    that use the standard signature keep working.
    """
    return self.cross_attns.state_dict(*args, **kwargs)
def load_state_dict(self, pkg, strict = False):
    """Load `pkg` into the cross attention blocks only.

    Returns whatever the underlying `load_state_dict` returns (the report of
    missing / unexpected keys), which is the only mismatch signal available
    when `strict` is False.
    """
    return self.cross_attns.load_state_dict(pkg, strict = strict)
def parameters(self, recurse = True):
    """Return an iterator over the cross attention parameters only.

    `recurse` mirrors the parameter of `nn.Module.parameters` and is forwarded,
    so generic callers that pass it do not fail.
    """
    return self.cross_attns.parameters(recurse = recurse)
def release_cross_attn_contexts(self):
    """Drop the stored context of every cross attention block."""
    every_cross_attn = (block for group in self.cross_attns for block in group)

    for block in every_cross_attn:
        block.context = None
def forward_augments(
    self,
    prompt: Tensor,
    prompt_mask: Optional[SingularOrMany(SequenceOf(Tensor))] = None
):
    """Run every augment model on its prompt and hand the hiddens to the cross attention blocks.

    Returns the tuple of prompts and the tuple of prompt masks, one entry per
    augment model.
    """
    # a single prompt is fed to every augment model
    num_augments = len(self.augment_llms)

    prompts = cast_tuple(prompt, num_augments)
    assert len(prompts) == num_augments

    # derive masks from the pad id when none are given; floating point prompts get no mask
    if not exists(prompt_mask):
        prompt_mask = tuple(
            (None if torch.is_floating_point(p) else p != self.pad_id)
            for p in prompts
        )

    prompt_masks = cast_tuple(prompt_mask, num_augments)
    assert len(prompt_masks) == num_augments

    # call the augment models without gradients and in eval mode, collecting their hiddens
    all_hiddens = []

    with torch.no_grad():
        self.augment_llms.eval()

        augment_inputs = zip(self.augment_llms, self.augment_llms_params, prompts, prompt_masks)

        for augment_llm, params, one_prompt, one_mask in augment_inputs:
            # the mask is only passed when the model's mask keyword name is known
            extra_kwargs = {params.mask_kwarg: one_mask} if exists(params.mask_kwarg) else dict()
            all_hiddens.append(augment_llm(one_prompt, **extra_kwargs))

    # give each cross attention block the hiddens of the augment layer it is connected to
    for hiddens, cross_attns, connections in zip(all_hiddens, self.cross_attns, self.connections):
        for (augment_layer_index, _), cross_attn in zip(connections, cross_attns):
            # layer indices in the connections are 1-based
            cross_attn.context = hiddens[augment_layer_index - 1]

    return prompts, prompt_masks
@contextmanager
def set_cross_attn_masks(self, masks):
    """Context manager that installs one context mask per group of cross attention blocks.

    `masks` is zipped with `self.cross_attns`; every block of a group receives
    that group's mask. On exit the masks are removed from all blocks, also
    when the body of the `with` statement raises.
    """
    # set the context masks for cross attention
    for one_cross_attn, mask in zip(self.cross_attns, masks):
        for cross_attn in one_cross_attn:
            cross_attn.set_mask(mask)

    try:
        yield
    finally:
        # always unset the masks so a failed forward does not leave stale masks behind
        for one_cross_attn in self.cross_attns:
            for cross_attn in one_cross_attn:
                cross_attn.unset_mask()
@torch.no_grad()
def generate(
    self,
    prompt: Tensor,
    seq_len: int,
    prompt_mask: Optional[SingularOrMany(SequenceOf(Tensor))] = None,
    filter_fn: Callable = top_p,
    filter_kwargs: dict = dict(
        thres = 0.9
    )
):
    """Sample from the anchor model while it cross attends to the augment hiddens.

    The augment models are run once on `prompt`, their masks are installed for
    the duration of sampling, and the stored contexts are released afterwards,
    also when sampling raises.
    """
    self.eval()

    try:
        # run forward on all augment models and collect the hidden states
        prompts, prompt_masks = self.forward_augments(prompt = prompt, prompt_mask = prompt_mask)

        with self.set_cross_attn_masks(prompt_masks):
            # sampling
            generated = sample(
                self.anchor_llm,
                prompt,
                seq_len = seq_len,
                filter_fn = filter_fn,
                filter_kwargs = filter_kwargs
            )
    finally:
        # never keep references to the augment hiddens past this call
        self.release_cross_attn_contexts()

    return generated
@beartype
def forward(
    self,
    seq: Tensor,
    *,
    prompt: SingularOrMany(Tensor),
    prompt_mask: Optional[SingularOrMany(Tensor)] = None,
    mask: Optional[Tensor] = None,
    return_loss = True,
    anchor_llm_in_train_mode = True  # unsure about this
):
    """Return the cross entropy loss (or the logits) of the anchor model on `seq`.

    `seq` is split into inputs `seq[:, :-1]` and labels `seq[:, 1:]`. When
    `return_loss` is False the logits are returned instead. Positions where
    `mask` is False and labels equal to `self.pad_id` are ignored by the loss.
    The stored cross attention contexts are released even if the anchor
    forward raises.
    """
    # the cross attention blocks are put in training mode when a loss is requested
    if return_loss:
        self.cross_attns.train()

    # the anchor model is put in train or eval mode as requested
    if anchor_llm_in_train_mode:
        self.anchor_llm.train()
    else:
        self.anchor_llm.eval()

    # shift by one: everything but the last token is input, everything but the first is label
    seq, labels = seq[:, :-1], seq[:, 1:]

    try:
        # run forward on all augment models and collect the hidden states
        prompts, prompt_masks = self.forward_augments(prompt = prompt, prompt_mask = prompt_mask)

        # install the cross attention masks for the anchor forward
        with self.set_cross_attn_masks(prompt_masks):
            # the anchor model is expected to cross attend to the augment hiddens through its hooks
            logits = self.anchor_llm(seq)
    finally:
        # release the cross attention contexts, also on failure, so the hiddens can be freed
        self.release_cross_attn_contexts()

    assert logits.ndim == 3, 'anchor llm should return logits in the shape (batch, seq, num tokens)'

    # return logits for decoding
    if not return_loss:
        return logits

    # account for the mask: masked-out label positions become the ignored pad id
    if exists(mask):
        labels = labels.masked_fill(~mask[:, 1:], self.pad_id)

    # loss for finetuning
    loss = F.cross_entropy(
        rearrange(logits, 'b n c -> b c n'),
        labels,
        ignore_index = self.pad_id
    )

    return loss
# Endless generator that keeps re-iterating a data loader and yields its batches
def cycle(dl):
    """Yield the items of `dl` forever, restarting the iteration each time it ends."""
    while True:
        yield from dl
# Class decorator imported from pytorch_custom_utils; its behavior is not visible in this file
# (the name suggests it unwraps the accelerate-wrapped model — confirm)
@auto_unwrap_model()
class FineTuner:
    # Build the accelerator, data loader, optimizer (with warmup schedule) and checkpoint bookkeeping
    @beartype
    def __init__(
        self,
        calm: CALM,
        *,
        num_train_steps: int,
        learning_rate: float,
        weight_decay: float,
        batch_size: int,
        dataset: Dataset,
        data_kwarg_names: Tuple[str, ...] = ('seq', 'mask', 'prompt'),
        accelerate_kwargs: dict = dict(),
        checkpoint_every: int = 1000,
        checkpoint_path: str = './checkpoints',
        scheduler: Optional[Type[_LRScheduler]] = None,
        scheduler_kwargs: dict = dict(),
        warmup_steps: int = 1000,
        max_grad_norm = 0.5,
        grad_accum_steps = 1
    ):
        # accelerator
        self.accelerator = Accelerator(**accelerate_kwargs)
        # data loader: shuffled, incomplete last batch dropped
        self.dl = DataLoader(dataset, batch_size = batch_size, shuffle = True, drop_last = True)
        # keyword names used to turn a non-dict batch into keyword arguments (see __call__)
        self.data_kwarg_names = data_kwarg_names
        # model
        self.model = calm
        # adam optimizer over calm.parameters()
        adam = get_adam_optimizer(
            calm.parameters(),
            lr = learning_rate,
            wd = weight_decay
        )
        # optimizer wrapper that also receives the scheduler, warmup and gradient clipping settings
        self.optimizer = OptimizerWithWarmupSchedule(
            accelerator = self.accelerator,
            optimizer = adam,
            scheduler = scheduler,
            scheduler_kwargs = scheduler_kwargs,
            warmup_steps = warmup_steps,
            max_grad_norm = max_grad_norm
        )
        self.step = 0
        self.num_train_steps = num_train_steps
        self.grad_accum_steps = grad_accum_steps
        self.checkpoint_every = checkpoint_every
        # the checkpoint directory is created up front (including parents)
        self.checkpoint_path = Path(checkpoint_path)
        self.checkpoint_path.mkdir(exist_ok = True, parents = True)

    # whether the current process is the main process
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # print through the accelerator
    def print(self, msg):
        self.accelerator.print(msg)

    # save model state, optimizer state and step counter under the checkpoint directory
    def save(self, filename: str, overwrite: bool = True):
        path = self.checkpoint_path / filename
        # refuse to replace an existing file unless overwriting is allowed
        assert overwrite or not path.exists()
        pkg = dict(
            model = self.model.state_dict(),
            optimizer = self.optimizer.state_dict(),
            step = self.step
        )
        torch.save(pkg, str(path))

    # restore model state, optimizer state and step counter from a checkpoint file
    def load(self, filename: str):
        path = self.checkpoint_path / filename
        assert path.exists()
        pkg = torch.load(str(path))
        self.model.load_state_dict(pkg['model'])
        self.optimizer.load_state_dict(pkg['optimizer'])
        self.step = pkg['step']

    # run the training loop from self.step up to num_train_steps
    def __call__(self, forward_kwargs: dict = dict()):
        # iterate over the data loader endlessly
        dl_iter = cycle(self.dl)
        self.model.train()
        for step in range(self.step, self.num_train_steps):
            # one forward / backward per context yielded by model_forward_contexts
            # (presumably one per gradient accumulation step — confirm against pytorch_custom_utils)
            for context in model_forward_contexts(
                model = self.model,
                accelerator = self.accelerator,
                grad_accum_steps = self.grad_accum_steps
            ):
                with context():
                    data = next(dl_iter)
                    # non-dict batches are paired with the configured keyword names
                    if not isinstance(data, dict):
                        data = dict(zip(self.data_kwarg_names, data))
                    loss = self.model(**data, **forward_kwargs)
                    # scale the loss so the accumulated gradient is an average
                    self.accelerator.backward(loss / self.grad_accum_steps)
            # reports the loss of the last accumulation step only
            self.print(f'{step + 1}: {loss.item():.3f}')
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.step += 1
            self.accelerator.wait_for_everyone()
            # periodic checkpoint, written by the main process only
            if self.is_main and not (self.step % self.checkpoint_every):
                num = self.step // self.checkpoint_every
                self.save(f'checkpoint.{num}.pt')
            self.accelerator.wait_for_everyone()
        self.print('training complete')
        # NOTE(review): unlike the periodic checkpoints, this final save is not guarded by is_main
        self.save('checkpoint.-1.pt')
.\lucidrains\CALM-pytorch\CALM_pytorch\sampling_utils.py
import torch # 导入 PyTorch 库
import torch.nn.functional as F # 导入 PyTorch 中的函数模块
from torch import Tensor # 导入 PyTorch 中的张量
from torch.nn import Module # 导入 PyTorch 中的神经网络模块
from torch.nn.utils.rnn import pad_sequence # 导入 PyTorch 中的序列填充函数
from beartype import beartype # 导入 beartype 库中的类型检查装饰器
from beartype.typing import Optional, Callable, List, Tuple # 导入 beartype 库中的类型注解
from einops import rearrange # 导入 einops 库中的重排函数
from tqdm import tqdm # 导入 tqdm 库中的进度条显示函数
def exists(v):
    """Return True for every value except None."""
    return not (v is None)
def default(v, d):
    """Return `v` unless it is None, in which case return the fallback `d`."""
    if v is None:
        return d
    return v
# sampling helpers
def log(t, eps = 1e-20):
    """Return the natural log of `t` after clamping it from below at `eps`."""
    floored = t.clamp(min = eps)  # keeps log away from zero / negative inputs
    return floored.log()
def gumbel_noise(t):
    """Return Gumbel noise, `-log(-log(u))` with `u` uniform in [0, 1), shaped like `t`."""
    uniform = torch.zeros_like(t).uniform_(0, 1)
    inner = log(uniform)
    return -log(-inner)
def gumbel_sample(t, temperature = 1., dim = -1, keepdim = True, eps = 1e-10):
    """Sample indices along `dim` by taking the argmax of temperature-scaled `t` plus Gumbel noise."""
    safe_temperature = max(temperature, eps)  # avoids dividing by zero when temperature is 0
    scaled = t / safe_temperature
    perturbed = scaled + gumbel_noise(t)
    return perturbed.argmax(dim = dim, keepdim = keepdim)
# nucleus
def top_p(logits, thres = 0.9):
    """Nucleus filtering: keep the most likely tokens up to cumulative probability `thres`.

    Logits outside the nucleus are set to -inf. The token that crosses the
    threshold is kept, so at least one token always survives. The result is
    scattered back along dimension 1.
    """
    ranked, order = logits.sort(dim = -1, descending = True)
    cumulative = ranked.softmax(dim = -1).cumsum(dim = -1)

    # shift the removal mask right by one so the first token exceeding the threshold stays
    over_threshold = cumulative > thres
    drop = F.pad(over_threshold, (1, -1), value = False)

    ranked[drop] = float('-inf')

    # undo the sort
    return ranked.scatter(1, order, ranked)
# topk
def top_k(logits, frac_num_tokens = 0.1, k: Optional[int] = None):
    """Top-k filtering: keep the `k` largest logits along the last dimension, set the rest to -inf.

    When `k` is None it is derived as `ceil(frac_num_tokens * num_tokens)`.
    `k` is capped at the number of tokens.
    """
    # local import: `ceil` is not imported at the top of this file
    from math import ceil

    num_tokens = logits.shape[-1]

    if k is None:
        k = ceil(frac_num_tokens * num_tokens)

    k = min(k, num_tokens)

    # topk works on the last dimension, so the scatter must target the last dimension too
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(-1, ind, val)
    return probs
# decoding
@torch.no_grad()
@beartype
def sample(
    net: Module,
    prompts,
    seq_len: int,
    temperature = 1.,
    filter_fn: Callable = top_p,
    filter_kwargs: dict = dict(),
    pad_id: int = -1,
    eos_id: Optional[int] = None,
    output_keep_prompt = False
):
    """Autoregressively sample from `net` until every sequence reaches `seq_len`.

    `prompts` is either a 2-d tensor padded with `pad_id`, or a tuple / list of
    1-d tensors that gets padded here. When `eos_id` is given, sampling stops
    once every row contains it and everything after the first eos is replaced
    by `pad_id`.

    Returns the full padded tensor when `output_keep_prompt` is True, otherwise
    a tuple with one 1-d tensor of generated (non prompt, non pad) tokens per row.
    """
    device = next(net.parameters()).device
    net.eval()

    # variable length prompts are padded into one tensor
    if isinstance(prompts, (tuple, list)):
        prompts = pad_sequence(prompts, batch_first = True, padding_value = pad_id)

    batch, prompts_tensor_len = prompts.shape

    batch_arange = torch.arange(batch, device = device)[..., None]

    prompt_lens = (prompts != pad_id).sum(dim = -1)

    # clone: indexing with None returns a view, and the in-place updates below
    # would otherwise overwrite prompt_lens, which is needed again at the end
    curr_seq_indices = prompt_lens[..., None].clone()

    out = prompts.clone()

    pbar = tqdm(
        initial = out.shape[-1],
        total = seq_len,
        desc = 'sampling'
    )

    while (curr_seq_indices < seq_len).any():
        # make room for one more token per row
        out = F.pad(out, (0, 1), value = pad_id)

        # the network never sees the pad id, it is replaced by token 0
        net_input = out.masked_fill(out == pad_id, 0)

        logits = net(net_input)

        # each row reads the logits at its own current position
        logits = logits[batch_arange, curr_seq_indices]
        logits = rearrange(logits, 'b 1 d -> b d')

        logits = filter_fn(logits, **filter_kwargs)
        sampled_tokens = gumbel_sample(logits, temperature = temperature, dim = -1)

        out[batch_arange, curr_seq_indices] = sampled_tokens

        curr_seq_indices += 1
        curr_seq_indices.clamp_(max = seq_len)

        pbar.update(1)

        if not exists(eos_id):
            continue

        # stop early once every row contains an eos token
        is_eos_mask = out == eos_id
        all_eos = is_eos_mask.any(dim = -1).all()

        if all_eos:
            break

    pbar.close()

    out = out[:, :seq_len]

    if exists(eos_id):
        # blank out everything after the first eos token of each row
        is_eos_mask = out == eos_id
        after_eos_mask = F.pad(is_eos_mask.cumsum(dim = -1) > 0, (1, -1), value = False)
        out = out.masked_fill_(after_eos_mask, pad_id)

    if output_keep_prompt:
        return out

    prompt_mask = torch.arange(out.shape[-1], device = device) < prompt_lens[..., None]

    # parentheses are required: `&` binds tighter than `!=`
    generated_seq_mask = (out != pad_id) & ~prompt_mask
    seq_lens = generated_seq_mask.sum(dim = -1).tolist()

    return out[generated_seq_mask].split(seq_lens)
.\lucidrains\CALM-pytorch\CALM_pytorch\__init__.py
# 从 CALM_pytorch.CALM 模块中导入以下类和函数
from CALM_pytorch.CALM import (
AugmentParams, # 导入 AugmentParams 类
ExtractHiddensWrapper, # 导入 ExtractHiddensWrapper 类
CALM, # 导入 CALM 类
FineTuner # 导入 FineTuner 类
)
CALM - Pytorch
Implementation of CALM from the paper LLM Augmented LLMs: Expanding Capabilities through Composition, out of Google Deepmind
Can support any number of augmentation LLMs
Install
$ pip install CALM-pytorch
Appreciation
- A16Z Open Source AI Grant Program and 🤗 Huggingface for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research
Usage
ex. with x-transformers
import torch
from x_transformers import TransformerWrapper, Decoder
augment_llm = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Decoder(
dim = 512,
depth = 12,
heads = 8
)
)
anchor_llm = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Decoder(
dim = 512,
depth = 2,
heads = 8
)
)
# import CALM wrapper
from CALM_pytorch import CALM, AugmentParams
calm = CALM(
anchor_llm,
augment_llms = AugmentParams(
model = augment_llm,
connect_every_num_layers = 4
)
)
# mock input
seq = torch.randint(0, 20000, (1, 1024))
mask = torch.ones((1, 1024)).bool()
prompt = torch.randint(0, 20000, (1, 256))
# forward for finetuning loss
loss = calm(
seq,
mask = mask,
prompt = prompt
)
loss.backward()
# after much training, prompt the composed model
generated = calm.generate(
prompt = seq[:, :1],
seq_len = 1024
)
To use a handy trainer class using 🤗 Accelerate, just import FineTuner and use as follows
trainer = FineTuner(
calm = calm,
dataset = dataset, # returns a dictionary of input kwargs to calm - dict(seq: Tensor, mask: Tensor, prompt: Tensor). it can also return a Tuple, in which case data_kwarg_names needs to be set to the correctly ordered kwarg names
batch_size = 16,
num_train_steps = 10000,
learning_rate = 3e-4,
weight_decay = 1e-2,
warmup_steps = 1000,
checkpoint_every = 1000
)
trainer()
# checkpoints of the cross attention parameters will be saved to ./checkpoints every 1000 steps
To explore multiple augmentation LLMs, simply pass in a list for augment_llms
ex.
calm = CALM(
anchor_llm = anchor_llm,
augment_llms = [AugmentParams(augment_llm1), AugmentParams(augment_llm2)] # pass in a list of AugmentParams wrapping model and other hparams specific to that transformer
)
Say you want to explore different types of connectivity between anchor and augmentation model(s), just pass in the connections as a tuple of integer pairs, each pair giving the augment layer number followed by the anchor layer number that attends to it.
calm = CALM(
anchor_llm = anchor_llm,
augment_llms = (
AugmentParams(
model = augment_llm1,
connections = (
(1, 12), # 1st layer of augment llm1 attended to by 12th layer of anchor llm
(2, 12),
(3, 12),
(4, 12),
),
),
AugmentParams(
model = augment_llm2,
connections = (
(6, 1), # 6th layer of augment llm2 attended to by 1st layer of anchor llm
(6, 2),
(12, 12),
)
)
)
)
CALM setup with 2 specialized augmentation LLMs + a vision transformer
import torch
# pip install vit-pytorch x-transformers
from vit_pytorch.vit import ViT, Attention
from x_transformers import TransformerWrapper, Encoder, Decoder
anchor_llm = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Decoder(
dim = 16,
dim_head = 2,
depth = 12,
heads = 8
)
)
augment_llm1 = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Encoder(
dim = 16,
dim_head = 2,
depth = 12,
heads = 8
)
)
augment_llm2 = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Encoder(
dim = 16,
dim_head = 2,
depth = 12,
heads = 8
)
)
vit = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 256,
depth = 6,
heads = 16,
mlp_dim = 2048
)
# calm
from CALM_pytorch import CALM, AugmentParams, FineTuner
calm = CALM(
anchor_llm = anchor_llm,
augment_llms = (
AugmentParams(
model = augment_llm1,
mask_kwarg = 'mask'
),
AugmentParams(
model = augment_llm2,
mask_kwarg = 'mask'
),
AugmentParams(
model = vit,
input_shape = (3, 256, 256),
hidden_position = 'input',
extract_blocks_fn = lambda vit: [m for m in vit.modules() if isinstance(m, Attention)]
)
),
attn_kwargs = dict(
linear_project_context = True,
pre_rmsnorm = True,
flash = True
)
)
seq = torch.randint(0, 20000, (1, 1024))
mask = torch.ones((1, 1024)).bool()
prompt = (
torch.randint(0, 20000, (1, 256)),
torch.randint(0, 20000, (1, 256)),
torch.randn(1, 3, 256, 256)
)
loss = calm(
seq,
mask = mask,
prompt = prompt
)
loss.backward()
Todo
-
figure out how to correctly mask augment llm tokens
-
auto-derive model dimensions with dummy input
-
take care of finetuning training logic
-
show example of manual definitions of custom connectivity between 2+ attention networks
-
if anchor and augment transformer block modules are directly passed in (without extraction fn), run a dummy input through both networks and order them correctly using hooks
-
fix example for x-transformers, as in x-transformers, depth is actually depth x 2, taking hiddens from after attention and ff
-
when finely specifying hidden positions, make sure to reorder it if the transformer blocks themselves were passed in and not ordered to begin with
-
extend to a list of augmentation llms
- full connectivity customization
- custom number of augmentation layers per augmentation llm
- make simple vit work
- refactor so extraction fn, mask kwarg, and other related hparams are grouped together under a dictionary of {[augment_llm_name]: {augment_llm_related_hparams}} - use dataclasses
- show example
-
take care of caching the augment hiddens when sampling. forget about anchor kv cache for now
- logic for not releasing the saved output from recorder, for inference
- managing cross attention block state for popping the saved output from the recorder
- move the augmentation forwards into one shared method, and craft out sampling method for anchor
-
able to wire up with just module names
-
show an example with giving the LLM ability to hear as well, using hubert or wav2vec wrappers
-
handle a wrapper or function that takes in the sequence and prompt length, and auto derives the inputs to CALM
-
add an option for self attention path way with memory tokens attending to hidden states of all augmentation llms, akin to what was done with Zorro
Citations
@inproceedings{Bansal2024LLMAL,
title = {LLM Augmented LLMs: Expanding Capabilities through Composition},
author = {Rachit Bansal and Bidisha Samanta and Siddharth Dalmia and Nitish Gupta and Shikhar Vashishth and Sriram Ganapathy and Abhishek Bapna and Prateek Jain and Partha Pratim Talukdar},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:266755751}
}
.\lucidrains\CALM-pytorch\setup.py
# 导入设置和查找包的函数
from setuptools import setup, find_packages
# Package metadata for the CALM-Pytorch distribution
setup(
    name = 'CALM-Pytorch', # distribution name
    packages = find_packages(exclude=[]), # include every discovered package
    version = '0.2.1', # version string
    license='MIT', # license
    description = 'CALM - Pytorch', # short description
    author = 'Phil Wang', # author
    author_email = 'lucidrains@gmail.com', # author e-mail
    long_description_content_type = 'text/markdown', # content type of the long description
    url = 'https://github.com/lucidrains/CALM-pytorch', # project homepage
    keywords = [
        'artificial intelligence', # keyword
        'deep learning', # keyword
        'composing LLMs' # keyword
    ],
    install_requires = [ # runtime dependencies
        'accelerate', # Accelerator used by the trainer
        'beartype', # runtime type checking
        'einops>=0.7.0', # tensor rearrangement
        'pytorch-custom-utils>=0.0.11', # optimizer / accelerate helpers
        'torch>=2.0', # PyTorch
        'tqdm', # progress bars
        'x-transformers>=1.27.3' # transformer building blocks
    ],
    classifiers=[ # trove classifiers
        'Development Status :: 4 - Beta', # development status
        'Intended Audience :: Developers', # intended audience
        'Topic :: Scientific/Engineering :: Artificial Intelligence', # topic
        'License :: OSI Approved :: MIT License', # license
        'Programming Language :: Python :: 3.6', # programming language
    ],
)
.\lucidrains\charformer-pytorch\charformer_pytorch\charformer_pytorch.py
# 导入 math 模块
import math
# 从 math 模块中导入 gcd 函数
from math import gcd
# 导入 functools 模块
import functools
# 导入 torch 模块
import torch
# 从 torch 模块中导入 nn, F, einsum
import torch.nn.functional as F
from torch import nn, einsum
# 从 einops 模块中导入 rearrange, reduce, repeat
from einops import rearrange, reduce, repeat
# 从 einops.layers.torch 模块中导入 Rearrange
from einops.layers.torch import Rearrange
# 辅助函数
# Tell whether a value was provided (is not None)
def exists(val):
    """Return True for every value except None."""
    if val is None:
        return False
    return True
# Least common multiple of any number of integers
def lcm(*numbers):
    """Return the least common multiple of `numbers` (1 when called without arguments).

    Uses exact integer arithmetic, so results stay correct beyond the range
    floats can represent exactly. Any zero among the inputs yields 0.
    """
    def _lcm_pair(x, y):
        # lcm with zero is zero; also avoids dividing by gcd(0, 0) == 0
        if x == 0 or y == 0:
            return 0
        return x * y // gcd(x, y)

    return int(functools.reduce(_lcm_pair, numbers, 1))
# Mean over `dim` that only counts the positions where the mask is True
def masked_mean(tensor, mask, dim = -1):
    """Return the mean of `tensor` along `dim`, ignoring positions where `mask` is False.

    `mask` may have fewer dimensions than `tensor`; trailing singleton
    dimensions are appended to it. Where no position is selected the mean is 0.
    The input `tensor` is left unmodified.
    """
    diff_len = len(tensor.shape) - len(mask.shape)
    mask = mask[(..., *((None,) * diff_len))]

    # out-of-place fill, so the caller's tensor is not zeroed as a side effect
    tensor = tensor.masked_fill(~mask, 0.)

    total_el = mask.sum(dim = dim)
    # clamp avoids dividing by zero; those entries are reset to 0 right after
    mean = tensor.sum(dim = dim) / total_el.clamp(min = 1.)
    mean.masked_fill_(total_el == 0, 0.)
    return mean
# Smallest multiple of `multiple` that is not below `seqlen`
def next_divisible_length(seqlen, multiple):
    """Round `seqlen` up to the next multiple of `multiple`."""
    num_chunks = math.ceil(seqlen / multiple)
    return num_chunks * multiple
# Right-pad a tensor so that its length becomes a multiple of `multiple`
def pad_to_multiple(tensor, multiple, *, seq_dim, dim = -1, value = 0.):
    """Pad `tensor` with `value` at the end of dimension `dim`.

    The current length is read from `seq_dim`; the amount added is what that
    length lacks to the next multiple of `multiple`. When nothing is lacking
    the input tensor itself is returned. `dim` is expected to be negative.
    """
    seqlen = tensor.shape[seq_dim]
    target_len = math.ceil(seqlen / multiple) * multiple
    missing = target_len - seqlen

    if missing == 0:
        return tensor

    # F.pad takes (left, right) pairs starting from the last dimension:
    # emit a (0, 0) pair for every dimension that comes after `dim`
    skipped_dims = (0,) * (-1 - dim) * 2
    padding = (*skipped_dims, 0, missing)
    return F.pad(tensor, padding, value = value)
# 辅助类
# Padding as a layer, so it can be placed inside nn.Sequential
class Pad(nn.Module):
    """Module wrapper around `F.pad` with a fixed padding spec and fill value."""

    def __init__(self, padding, value = 0.):
        super().__init__()
        self.padding, self.value = padding, value

    def forward(self, x):
        padded = F.pad(x, self.padding, value = self.value)
        return padded
# Depthwise 1-d convolution followed by a pointwise (kernel size 1) projection
class DepthwiseConv1d(nn.Module):
    """Grouped `Conv1d` with `groups = dim_in`, then a 1x1 `Conv1d` over `dim_out` channels."""

    def __init__(self, dim_in, dim_out, kernel_size):
        super().__init__()
        # one group per input channel
        self.conv = nn.Conv1d(dim_in, dim_out, kernel_size, groups = dim_in)
        # kernel size 1: mixes channels only
        self.proj_out = nn.Conv1d(dim_out, dim_out, 1)

    def forward(self, x):
        depthwise = self.conv(x)
        projected = self.proj_out(depthwise)
        return projected
# main class
class GBST(nn.Module):
    """Gradient-based subword tokenization module (only the constructor is visible here).

    Exactly one of `max_block_size` and `blocks` must be given. `blocks` is a
    tuple whose entries are block sizes or `(block size, offset)` tuples;
    `max_block_size = n` stands for block sizes 1..n, all with offset 0.
    """

    def __init__(
        self,
        *,
        num_tokens,
        dim,
        max_block_size = None,
        blocks = None,
        downsample_factor = 4,
        score_consensus_attn = True
    ):
        super().__init__()
        assert exists(max_block_size) ^ exists(blocks), 'either max_block_size or blocks are given on initialization'

        self.token_emb = nn.Embedding(num_tokens, dim)

        if exists(blocks):
            assert isinstance(blocks, tuple), 'blocks must be a tuple of block sizes'

            # bare block sizes get an offset of 0
            self.blocks = tuple(
                (entry if isinstance(entry, tuple) else (entry, 0))
                for entry in blocks
            )

            assert all(offset < block_size for block_size, offset in self.blocks), 'offset must be always smaller than the block size'

            max_block_size = max(block[0] for block in self.blocks)
        else:
            self.blocks = tuple((block_size, 0) for block_size in range(1, max_block_size + 1))

        # positional convolution: the sequence is right-padded by max_block_size - 1
        # so the convolution output has the same length as its input
        self.pos_conv = nn.Sequential(
            Pad((0, 0, 0, max_block_size - 1)),
            Rearrange('b n d -> b d n'),
            DepthwiseConv1d(dim, dim, kernel_size = max_block_size),
            Rearrange('b d n -> b n d')
        )

        # one scalar score per position
        self.score_fn = nn.Sequential(
            nn.Linear(dim, 1),
            Rearrange('... () -> ...')
        )

        self.score_consensus_attn = score_consensus_attn

        assert downsample_factor <= max_block_size, 'final downsample factor should be less than the maximum block size'

        # least common multiple of all block sizes
        block_sizes = [block_size for block_size, _ in self.blocks]
        self.block_pad_multiple = lcm(*block_sizes)
        self.downsample_factor = downsample_factor
.\lucidrains\charformer-pytorch\charformer_pytorch\__init__.py
# 从 charformer_pytorch.charformer_pytorch 模块中导入 GBST 类
from charformer_pytorch.charformer_pytorch import GBST

Charformer - Pytorch
Implementation of the GBST (gradient-based subword tokenization) module from the Charformer paper, in Pytorch. The paper proposes a module that automatically learns subword representations, obviating the need for tokenizers in the encoder setting.
AI Coffee Break with Letitia video
Install
$ pip install charformer-pytorch
Usage
import torch
from charformer_pytorch import GBST
tokenizer = GBST(
num_tokens = 257, # number of tokens, should be 256 for byte encoding (+ 1 special token for padding in this example)
dim = 512, # dimension of token and intra-block positional embedding
max_block_size = 4, # maximum block size
downsample_factor = 4, # the final downsample factor by which the sequence length will decrease by
score_consensus_attn = True # whether to do the cheap score consensus (aka attention) as in eq. 5 in the paper
)
tokens = torch.randint(0, 257, (1, 1023)) # uneven number of tokens (1023)
mask = torch.ones(1, 1023).bool()
# both tokens and mask will be appropriately downsampled
tokens, mask = tokenizer(tokens, mask = mask) # (1, 256, 512), (1, 256)
# now pass this on to your transformer
Deviating from the paper, you can also specify block size(s) with different offsets. This is to cover a potential use-case for genomics pre-training, where the tokenizer should be able to learn the correct frame. Simply omit the max_block_size, and pass in blocks as a tuple of tuples, each inner tuple with the format (block size, offset). Offsets must be less than the block size
import torch
from charformer_pytorch import GBST
tokenizer = GBST(
num_tokens = 4 + 1,
dim = 512,
blocks = ((3, 0), (3, 1), (3, 2)), # block size of 3, with offsets of 0, 1, 2
downsample_factor = 3,
score_consensus_attn = True
).cuda()
basepairs = torch.randint(0, 4, (1, 1023)).cuda()
mask = torch.ones(1, 1023).bool().cuda()
# both basepairs and mask will be appropriately downsampled
basepairs, mask = tokenizer(basepairs, mask = mask)
Citations
@misc{tay2021charformer,
title = {Charformer: Fast Character Transformers via Gradient-based Subword Tokenization},
author = {Yi Tay and Vinh Q. Tran and Sebastian Ruder and Jai Gupta and Hyung Won Chung and Dara Bahri and Zhen Qin and Simon Baumgartner and Cong Yu and Donald Metzler},
year = {2021},
eprint = {2106.12672},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
.\lucidrains\charformer-pytorch\setup.py
# Import the setup entry point and the package-discovery helper
from setuptools import setup, find_packages
# Declare the distribution metadata for the charformer-pytorch package
setup(
name = 'charformer-pytorch', # distribution name
packages = find_packages(), # include every package that find_packages() discovers
version = '0.0.4', # release version
license='MIT', # license
description = 'Charformer - Pytorch', # short description
author = 'Phil Wang', # author
author_email = 'lucidrains@gmail.com', # author e-mail
url = 'https://github.com/lucidrains/charformer-pytorch', # project URL
keywords = [
'artificial intelligence', # search keyword
'deep learning', # search keyword
'learned tokenization' # search keyword
],
install_requires=[
'einops>=0.3', # runtime dependency
'torch>=1.6' # runtime dependency
],
classifiers=[
'Development Status :: 4 - Beta', # trove classifier
'Intended Audience :: Developers', # trove classifier
'Topic :: Scientific/Engineering :: Artificial Intelligence', # trove classifier
'License :: OSI Approved :: MIT License', # trove classifier
'Programming Language :: Python :: 3.6', # trove classifier
],
)
.\lucidrains\chroma-pytorch\chroma_pytorch\chroma_pytorch.py
import torch # 导入 PyTorch 库
from torch import nn, einsum # 从 PyTorch 库中导入 nn 模块和 einsum 函数
from einops import rearrange, repeat # 从 einops 库中导入 rearrange 和 repeat 函数
import math # 导入 math 库
from pathlib import Path # 从 pathlib 库中导入 Path 类
from random import random # 从 random 库中导入 random 函数
from functools import partial # 从 functools 库中导入 partial 函数
from multiprocessing import cpu_count # 从 multiprocessing 库中导入 cpu_count 函数
import torch # 重新导入 PyTorch 库
from torch import nn, einsum # 从 PyTorch 库中重新导入 nn 模块和 einsum 函数
from torch.special import expm1 # 从 PyTorch 库中导入 expm1 函数
import torch.nn.functional as F # 从 PyTorch 库中导入 F 模块
from torch.utils.data import Dataset, DataLoader # 从 PyTorch 库中导入 Dataset 和 DataLoader 类
from torch.optim import Adam # 从 PyTorch 库中导入 Adam 优化器
from torchvision import transforms as T, utils # 从 torchvision 库中导入 transforms 模块和 utils 模块
from einops import rearrange, reduce, repeat # 从 einops 库中重新导入 rearrange、reduce 和 repeat 函数
from einops.layers.torch import Rearrange # 从 einops 库中导入 Rearrange 类
from tqdm.auto import tqdm # 从 tqdm 库中导入 tqdm 函数
from ema_pytorch import EMA # 从 ema_pytorch 库中导入 EMA 类
from accelerate import Accelerator # 从 accelerate 库中导入 Accelerator 类
# helper functions
def exists(x):
    """Return True for every value except ``None`` (falsy values such as 0 count as existing)."""
    return not (x is None)
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    When the fallback ``d`` is callable it is invoked with no arguments and
    its result is returned, so expensive defaults are only built on demand.
    """
    if val is not None:
        return val
    if callable(d):
        return d()
    return d
def cycle(dl):
    """Yield the items of ``dl`` forever, re-iterating it each time it is exhausted."""
    while True:
        yield from dl
def has_int_squareroot(num):
    """Return True when ``num`` is a perfect square (has an integer square root).

    The check is done with exact integer arithmetic (``math.isqrt``) instead of
    a floating-point ``sqrt`` round trip, which loses precision for large
    values and accepts non-integer roots such as 1.5 for 2.25.

    Args:
        num: a non-negative integer, or a float holding an integral value.

    Returns:
        bool: whether some integer ``r`` satisfies ``r * r == num``.

    Raises:
        ValueError: if ``num`` is negative.
    """
    if isinstance(num, float):
        # keep accepting floats such as 16.0, but never fractional values
        if not num.is_integer():
            return False
        num = int(num)
    return math.isqrt(num) ** 2 == num
def num_to_groups(num, divisor):
    """Split ``num`` into a list of chunk sizes of at most ``divisor``.

    The list holds ``num // divisor`` entries equal to ``divisor`` followed,
    when there is a positive remainder, by one final entry with that remainder.
    """
    full_groups, leftover = divmod(num, divisor)
    sizes = [divisor] * full_groups
    if leftover > 0:
        sizes.append(leftover)
    return sizes
def convert_image_to(img_type, image):
    """Return ``image`` in mode ``img_type``.

    The image is returned untouched when its ``mode`` already equals
    ``img_type``; otherwise the result of ``image.convert(img_type)`` is returned.
    """
    already_matches = image.mode == img_type
    return image if already_matches else image.convert(img_type)
# small helper modules
class Residual(nn.Module):
    """Wrap a callable so that its input is added back onto its output."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        # extra positional/keyword arguments are forwarded to the wrapped callable only
        out = self.fn(x, *args, **kwargs)
        return out + x
def Upsample(dim, dim_out = None):
    """Build a 2x nearest-neighbour upsampling stage followed by a 3x3 convolution.

    Args:
        dim: number of input channels.
        dim_out: number of output channels; ``dim`` is used when it is None.
    """
    out_channels = dim if dim_out is None else dim_out
    resize = nn.Upsample(scale_factor = 2, mode = 'nearest')
    conv = nn.Conv2d(dim, out_channels, 3, padding = 1)
    return nn.Sequential(resize, conv)
def Downsample(dim, dim_out = None):
    """Build a strided convolution (kernel 4, stride 2, padding 1) used for downsampling.

    Args:
        dim: number of input channels.
        dim_out: number of output channels; ``dim`` is used when it is None.
    """
    out_channels = dim if dim_out is None else dim_out
    return nn.Conv2d(dim, out_channels, 4, 2, 1)
class LayerNorm(nn.Module):
    """Normalise over dimension 1 with a learned gain and no bias.

    The gain ``g`` has shape ``(1, dim, 1, 1)``, so it broadcasts against
    4-D inputs whose second dimension has size ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        # a larger epsilon is used for every dtype other than float32
        eps = 1e-3
        if x.dtype == torch.float32:
            eps = 1e-5
        variance = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        centre = torch.mean(x, dim = 1, keepdim = True)
        inv_std = (variance + eps).rsqrt()
        return (x - centre) * inv_std * self.g
class PreNorm(nn.Module):
    """Apply ``LayerNorm(dim)`` to the input before handing it to ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))
# positional embeds
class LearnedSinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D batch of scalars with learned frequencies.

    ``dim`` must be even. The output has ``dim + 1`` features per sample: the
    raw input value followed by ``dim // 2`` sines and ``dim // 2`` cosines.
    """

    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        weights = rearrange(self.weights, 'd -> 1 d')
        freqs = x * weights * 2 * math.pi
        # layout along the last axis: [x, sin(freqs), cos(freqs)]
        return torch.cat((x, freqs.sin(), freqs.cos()), dim = -1)
# building block modules
class Block(nn.Module):
    """3x3 convolution -> GroupNorm -> optional scale/shift modulation -> SiLU."""

    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        normed = self.norm(self.proj(x))

        if scale_shift is not None:
            # modulation is applied after normalisation and before the activation
            scale, shift = scale_shift
            normed = normed * (scale + 1) + shift

        return self.act(normed)
class ResnetBlock(nn.Module):
    """Two ``Block`` layers with a residual connection and optional conditioning.

    When ``time_emb_dim`` is given, an MLP maps the conditioning embedding to
    ``dim_out * 2`` values that are split into a (scale, shift) pair for the
    first block. The residual path uses a 1x1 convolution only when the
    channel count changes.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()

        if exists(time_emb_dim):
            self.mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_emb_dim, dim_out * 2)
            )
        else:
            self.mlp = None

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)

        if dim != dim_out:
            self.res_conv = nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, time_emb = None):
        scale_shift = None

        # conditioning is used only when both the MLP and an embedding are present
        if exists(self.mlp) and exists(time_emb):
            cond = self.mlp(time_emb)
            # add two trailing singleton axes so the values broadcast over feature maps
            cond = rearrange(cond, 'b c -> b c 1 1')
            scale_shift = cond.chunk(2, dim = 1)

        hidden = self.block1(x, scale_shift = scale_shift)
        hidden = self.block2(hidden)

        return hidden + self.res_conv(x)
class LinearAttention(nn.Module):
    """Multi-head linear attention over the flattened spatial positions of a 4-D input.

    Queries are softmax-normalised over the per-head feature axis and keys over
    the spatial axis. A (feature x feature) context is built from keys and
    values first, so no position-by-position attention matrix is formed.
    """

    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.heads = heads

        # a single 1x1 convolution produces queries, keys and values at once
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)

        # project back to `dim` channels, then normalise
        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            LayerNorm(dim)
        )

    def _split_heads(self, t):
        # separate the head axis and flatten the two spatial axes into one
        return rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads)

    def forward(self, x):
        _, _, height, width = x.shape

        q, k, v = (self._split_heads(t) for t in self.to_qkv(x).chunk(3, dim = 1))

        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)

        q = q * self.scale
        v = v / (height * width)

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        attended = torch.einsum('b h d e, b h d n -> b h e n', context, q)

        attended = rearrange(attended, 'b h c (x y) -> b (h c) x y', h = self.heads, x = height, y = width)
        return self.to_out(attended)
class Attention(nn.Module):
    """Multi-head softmax attention over the flattened spatial positions of a 4-D input."""

    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.heads = heads

        # a single 1x1 convolution produces queries, keys and values at once
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def _split_heads(self, t):
        # separate the head axis and flatten the two spatial axes into one
        return rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads)

    def forward(self, x):
        _, _, height, width = x.shape

        q, k, v = (self._split_heads(t) for t in self.to_qkv(x).chunk(3, dim = 1))
        q = q * self.scale

        # similarity between every query position i and key position j
        scores = einsum('b h d i, b h d j -> b h i j', q, k)
        weights = scores.softmax(dim = -1)

        mixed = einsum('b h i j, b h d j -> b h i d', weights, v)
        mixed = rearrange(mixed, 'b h (x y) d -> b (h d) x y', x = height, y = width)
        return self.to_out(mixed)
# model
class Unet(nn.Module):
    """U-Net over 4-D feature maps, conditioned on one scalar per sample.

    The network expects the input concatenated with a self-conditioning tensor
    of the same shape along dimension 1 (zeros when none is given), so the
    first convolution takes ``channels * 2`` input channels.

    Args:
        dim: base channel width; stage widths are ``dim * m`` for each ``m``
            in ``dim_mults``.
        init_dim: channel width produced by the first convolution; defaults
            to ``dim``.
        dim_mults: width multiplier of each resolution stage.
        channels: number of channels of the data (and of the output).
        resnet_block_groups: number of groups for the GroupNorm inside each
            ``ResnetBlock``.
        learned_sinusoidal_dim: size passed to ``LearnedSinusoidalPosEmb``
            (must be even).
    """

    def __init__(
        self,
        dim,
        init_dim = None,
        dim_mults=(1, 2, 4, 8),
        channels = 3,
        resnet_block_groups = 8,
        learned_sinusoidal_dim = 16
    ):
        super().__init__()

        # determine dimensions

        self.channels = channels
        # forward() concatenates the self-conditioning tensor with the input
        input_channels = channels * 2

        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)

        # channel widths of consecutive stages, paired as (in, out)
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups = resnet_block_groups)

        # conditioning embedding

        time_dim = dim * 4

        sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim)
        # the embedding prepends the raw scalar, hence one extra feature
        fourier_dim = learned_sinusoidal_dim + 1

        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # layers

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            # the last stage changes width without reducing resolution
            self.downs.append(nn.ModuleList([
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
            ]))

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            # each block consumes the current features plus one skip tensor of width dim_in
            self.ups.append(nn.ModuleList([
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
            ]))

        # The final block receives the decoder output (init_dim channels)
        # concatenated with the saved init_conv output (init_dim channels).
        # Sizing it with `dim * 2` only worked when init_dim == dim.
        self.final_res_block = block_klass(init_dim * 2, dim, time_emb_dim = time_dim)
        self.final_conv = nn.Conv2d(dim, channels, 1)

    def forward(self, x, time, x_self_cond = None):
        """Run the network.

        Args:
            x: input tensor with ``channels`` channels on dimension 1.
            time: 1-D tensor with one conditioning scalar per sample.
            x_self_cond: optional tensor shaped like ``x``; zeros are used
                when it is None.

        Returns:
            Tensor with ``channels`` channels on dimension 1.
        """
        x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
        x = torch.cat((x_self_cond, x), dim = 1)

        x = self.init_conv(x)
        # kept for the long skip connection into the final block
        r = x.clone()

        t = self.time_mlp(time)

        # stack of skip tensors: two are pushed per down stage, two popped per up stage
        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim = 1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim = 1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim = 1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)
# chroma class
class Chroma(nn.Module):
    """Continuous-time diffusion wrapper around a denoising model.

    The wrapped ``model`` is called as ``model(noisy_img, log_snr, self_cond)``
    and its output is treated as a prediction of the clean image.

    Relies on module-level helpers that are defined outside this class
    (``beta_linear_log_snr``, ``alpha_cosine_log_snr``, ``right_pad_dims_to``,
    ``log_snr_to_alpha_sigma`` and ``log``).

    Args:
        model: denoising network exposing a ``channels`` attribute.
        image_size: height and width that training images must have.
        timesteps: number of sampling steps.
        use_ddim: sample with ``ddim_sample`` instead of ``ddpm_sample``.
        noise_schedule: ``'linear'`` or ``'cosine'``.
        time_difference: default amount subtracted from the "next" time at
            each sampling step (clamped at 0).

    Raises:
        ValueError: if ``noise_schedule`` is not a supported name.
    """

    def __init__(
        self,
        model,
        *,
        image_size,
        timesteps = 1000,
        use_ddim = False,
        noise_schedule = 'cosine',
        time_difference = 0.
    ):
        super().__init__()
        self.model = model
        self.channels = self.model.channels

        self.image_size = image_size

        # pick the function that maps a time in [0, 1] to a log(SNR)
        if noise_schedule == "linear":
            self.log_snr = beta_linear_log_snr
        elif noise_schedule == "cosine":
            self.log_snr = alpha_cosine_log_snr
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        # sampling

        self.timesteps = timesteps
        self.use_ddim = use_ddim

        self.time_difference = time_difference

    @property
    def device(self):
        """Device of the first parameter of the wrapped model."""
        return next(self.model.parameters()).device

    def get_sampling_timesteps(self, batch, *, device):
        """Return the (time, next time) pairs visited during sampling.

        Times run from 1 down to 0 in ``self.timesteps`` equal steps. The
        result is a tuple with one tensor of shape ``(2, batch)`` per step,
        holding the current time in row 0 and the next time in row 1.
        """
        times = torch.linspace(1., 0., self.timesteps + 1, device = device)
        times = repeat(times, 't -> b t', b = batch)
        times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
        times = times.unbind(dim = -1)
        return times

    @torch.no_grad()
    def ddpm_sample(self, shape, time_difference = None):
        """Generate samples of the given ``shape`` with ancestral (DDPM-style) steps.

        Args:
            shape: shape of the tensor to generate; ``shape[0]`` is the batch size.
            time_difference: overrides ``self.time_difference`` when not None.
        """
        batch, device = shape[0], self.device

        time_difference = default(time_difference, self.time_difference)

        time_pairs = self.get_sampling_timesteps(batch, device = device)

        img = torch.randn(shape, device=device)

        x_start = None

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', total = self.timesteps):

            # add the time delay (honours the per-call override, not only the attribute)

            time_next = (time_next - time_difference).clamp(min = 0.)

            # the noise level is both the model conditioning and the input to the update

            log_snr = self.log_snr(time)
            log_snr_next = self.log_snr(time_next)

            # get predicted x0, feeding back the previous prediction as self-conditioning

            x_start = self.model(img, log_snr, x_start)

            # clip x0

            x_start.clamp_(-1., 1.)

            log_snr, log_snr_next = map(partial(right_pad_dims_to, img), (log_snr, log_snr_next))

            # get alpha and sigma of the current and the next time

            alpha, sigma = log_snr_to_alpha_sigma(log_snr)
            alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

            # derive posterior mean and variance

            c = -expm1(log_snr - log_snr_next)

            mean = alpha_next * (img * (1 - c) / alpha + c * x_start)
            variance = (sigma_next ** 2) * c
            log_variance = log(variance)

            # no noise is added for samples whose next time is already 0

            noise = torch.where(
                rearrange(time_next > 0, 'b -> b 1 1 1'),
                torch.randn_like(img),
                torch.zeros_like(img)
            )

            img = mean + (0.5 * log_variance).exp() * noise

        return img

    @torch.no_grad()
    def ddim_sample(self, shape, time_difference = None):
        """Generate samples of the given ``shape`` with deterministic (DDIM-style) steps.

        Args:
            shape: shape of the tensor to generate; ``shape[0]`` is the batch size.
            time_difference: overrides ``self.time_difference`` when not None.
        """
        batch, device = shape[0], self.device

        time_difference = default(time_difference, self.time_difference)

        time_pairs = self.get_sampling_timesteps(batch, device = device)

        img = torch.randn(shape, device = device)

        x_start = None

        for times, times_next in tqdm(time_pairs, desc = 'sampling loop time step'):

            # Add the time delay BEFORE deriving the next noise level. The shift
            # used to be applied after log_snr_next was computed, so it had no effect.

            times_next = (times_next - time_difference).clamp(min = 0.)

            # get times and noise levels

            log_snr = self.log_snr(times)
            log_snr_next = self.log_snr(times_next)

            padded_log_snr, padded_log_snr_next = map(partial(right_pad_dims_to, img), (log_snr, log_snr_next))

            alpha, sigma = log_snr_to_alpha_sigma(padded_log_snr)
            alpha_next, sigma_next = log_snr_to_alpha_sigma(padded_log_snr_next)

            # predict x0, feeding back the previous prediction as self-conditioning

            x_start = self.model(img, log_snr, x_start)

            # clip x0

            x_start.clamp_(-1., 1.)

            # get predicted noise (sigma is clamped to avoid division by zero)

            pred_noise = (img - alpha * x_start) / sigma.clamp(min = 1e-8)

            # calculate x next

            img = x_start * alpha_next + pred_noise * sigma_next

        return img

    @torch.no_grad()
    def sample(self, batch_size = 16):
        """Generate ``batch_size`` square samples with the configured sampler."""
        image_size, channels = self.image_size, self.channels
        sample_fn = self.ddpm_sample if not self.use_ddim else self.ddim_sample
        return sample_fn((batch_size, channels, image_size, image_size))

    def forward(self, img, *args, **kwargs):
        """Return the training loss (MSE between predicted and true image) for a batch.

        Args:
            img: 4-D tensor whose last two dimensions both equal ``self.image_size``.
            *args, **kwargs: accepted but not used.
        """
        batch, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'

        # sample random times, one per batch element, uniformly from [0, 1)

        times = torch.zeros((batch,), device = device).float().uniform_(0, 1.)

        # noise sample

        noise = torch.randn_like(img)

        noise_level = self.log_snr(times)
        padded_noise_level = right_pad_dims_to(img, noise_level)
        alpha, sigma = log_snr_to_alpha_sigma(padded_noise_level)

        noised_img = alpha * img + sigma * noise

        # Self-conditioning: with probability 0.5, first predict x_start without
        # gradients and feed that prediction back to the model as conditioning.

        self_cond = None
        if random() < 0.5:
            with torch.no_grad():
                self_cond = self.model(noised_img, noise_level).detach_()

        # predict and take gradient step

        pred = self.model(noised_img, noise_level, self_cond)

        return F.mse_loss(pred, img)
# trainer class
class Trainer(object):
    """Training loop for a diffusion model, driven through ``accelerate``.

    Handles data loading, gradient accumulation, an EMA copy of the model
    (main process only), periodic sampling and checkpointing.

    Args:
        diffusion_model: module returning a scalar loss when called on a batch,
            and exposing ``image_size`` and ``sample(batch_size=...)``.
        folder: passed to the ``Dataset`` constructor as the data location.
        train_batch_size: batch size of the data loader.
        gradient_accumulate_every: number of batches accumulated per optimizer step.
        augment_horizontal_flip: forwarded to the dataset.
        train_lr: Adam learning rate.
        train_num_steps: total number of optimizer steps.
        ema_update_every: forwarded to ``EMA`` as ``update_every``.
        ema_decay: forwarded to ``EMA`` as ``beta``.
        adam_betas: Adam beta coefficients.
        save_and_sample_every: step interval between sample/checkpoint dumps.
        num_samples: number of images per dump; must be a perfect square.
        results_folder: directory for samples and checkpoints.
        amp: value assigned to ``accelerator.native_amp``.
        fp16: enable fp16 mixed precision in the accelerator.
        split_batches: forwarded to ``Accelerator``.
        convert_image_to: forwarded to the dataset.
    """

    def __init__(
        self,
        diffusion_model,
        folder,
        *,
        train_batch_size = 16,
        gradient_accumulate_every = 1,
        augment_horizontal_flip = True,
        train_lr = 1e-4,
        train_num_steps = 100000,
        ema_update_every = 10,
        ema_decay = 0.995,
        adam_betas = (0.9, 0.99),
        save_and_sample_every = 1000,
        num_samples = 25,
        results_folder = './results',
        amp = False,
        fp16 = False,
        split_batches = True,
        convert_image_to = None
    ):
        super().__init__()

        self.accelerator = Accelerator(
            split_batches = split_batches,
            mixed_precision = 'fp16' if fp16 else 'no'
        )

        self.accelerator.native_amp = amp

        self.model = diffusion_model

        # samples are saved as a square grid, hence the perfect-square requirement
        assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
        self.num_samples = num_samples
        self.save_and_sample_every = save_and_sample_every

        self.batch_size = train_batch_size
        self.gradient_accumulate_every = gradient_accumulate_every

        self.train_num_steps = train_num_steps
        self.image_size = diffusion_model.image_size

        # dataset and dataloader

        self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip, convert_image_to = convert_image_to)
        dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())

        dl = self.accelerator.prepare(dl)
        # endless iterator so the loop is bounded by step count, not by epochs
        self.dl = cycle(dl)

        # optimizer

        self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)

        # the EMA copy lives on the main process only

        if self.accelerator.is_main_process:
            self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)

        # Known on every process because load() reads checkpoints from it.
        # parents/exist_ok make nested or pre-existing folders safe.
        self.results_folder = Path(results_folder)
        self.results_folder.mkdir(parents = True, exist_ok = True)

        # step counter state

        self.step = 0

        # prepare model and optimizer with the accelerator

        self.model, self.opt = self.accelerator.prepare(self.model, self.opt)

    def save(self, milestone):
        """Write checkpoint ``model-{milestone}.pt`` into the results folder.

        Only the main process writes, because it is the only one holding the
        EMA model; other processes return without doing anything.
        """
        if not self.accelerator.is_main_process:
            return

        data = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.model),
            'opt': self.opt.state_dict(),
            'ema': self.ema.state_dict(),
            'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None
        }

        torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))

    def load(self, milestone):
        """Restore training state from checkpoint ``model-{milestone}.pt``."""
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        # map_location places tensors on this process's device rather than the
        # device they were saved from.
        data = torch.load(
            str(self.results_folder / f'model-{milestone}.pt'),
            map_location = self.accelerator.device
        )

        model = self.accelerator.unwrap_model(self.model)
        model.load_state_dict(data['model'])

        self.step = data['step']
        self.opt.load_state_dict(data['opt'])

        # the EMA model exists on the main process only
        if self.accelerator.is_main_process:
            self.ema.load_state_dict(data['ema'])

        if exists(self.accelerator.scaler) and exists(data['scaler']):
            self.accelerator.scaler.load_state_dict(data['scaler'])

    def train(self):
        """Run optimizer steps until ``self.step`` reaches ``train_num_steps``."""
        accelerator = self.accelerator
        device = accelerator.device

        with tqdm(initial=self.step, total=self.train_num_steps, disable=not accelerator.is_main_process) as pbar:

            while self.step < self.train_num_steps:

                total_loss = 0.

                for _ in range(self.gradient_accumulate_every):
                    data = next(self.dl).to(device)

                    with self.accelerator.autocast():
                        loss = self.model(data)
                        # scale so the accumulated gradient is an average over the sub-batches
                        loss = loss / self.gradient_accumulate_every
                        total_loss += loss.item()

                    self.accelerator.backward(loss)

                pbar.set_description(f'loss: {total_loss:.4f}')

                accelerator.wait_for_everyone()

                self.opt.step()
                self.opt.zero_grad()

                accelerator.wait_for_everyone()

                if accelerator.is_main_process:
                    self.ema.to(device)
                    self.ema.update()

                    if self.step != 0 and self.step % self.save_and_sample_every == 0:
                        self.ema.ema_model.eval()

                        with torch.no_grad():
                            milestone = self.step // self.save_and_sample_every
                            # sample in chunks of at most batch_size images
                            batches = num_to_groups(self.num_samples, self.batch_size)
                            all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))

                        all_images = torch.cat(all_images_list, dim=0)
                        utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow=int(math.sqrt(self.num_samples)))
                        self.save(milestone)

                self.step += 1
                pbar.update(1)

        accelerator.print('training complete')