Lucidrains-系列项目源码解析-十一-

34 阅读21分钟

Lucidrains 系列项目源码解析(十一)

.\lucidrains\byol-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# package metadata for the byol-pytorch distribution
setup(
  # distribution name
  name = 'byol-pytorch',
  # include every package except the examples folder
  packages = find_packages(exclude=['examples']),
  # release version
  version = '0.8.0',
  # license
  license='MIT',
  # short description shown on PyPI
  description = 'Self-supervised contrastive learning made simple',
  # author
  author = 'Phil Wang',
  # author contact
  author_email = 'lucidrains@gmail.com',
  # project homepage
  url = 'https://github.com/lucidrains/byol-pytorch',
  # README rendering format
  long_description_content_type = 'text/markdown',
  # search keywords
  keywords = [
      'self-supervised learning',
      'artificial intelligence'
  ],
  # runtime dependencies
  install_requires=[
      'accelerate',
      'beartype',
      'torch>=1.6',
      'torchvision>=0.8'
  ],
  # PyPI classifiers
  classifiers=[
      'Development Status :: 4 - Beta',
      'Intended Audience :: Developers',
      'Topic :: Scientific/Engineering :: Artificial Intelligence',
      'License :: OSI Approved :: MIT License',
      'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\CALM-pytorch\CALM_pytorch\CALM.py

# 从 math 模块中导入 ceil 函数
from math import ceil
# 从 pathlib 模块中导入 Path 类
from pathlib import Path
# 从 functools 模块中导入 partial 函数
from functools import partial
# 从 contextlib 模块中导入 nullcontext 和 contextmanager 函数
from contextlib import nullcontext, contextmanager

# 从 dataclasses 模块中导入 dataclass 装饰器
from dataclasses import dataclass

# 导入 torch 库
import torch
# 从 torch.nn.functional 模块中导入 F 别名
import torch.nn.functional as F
# 从 torch.nn 模块中导入 Module 和 ModuleList 类
from torch.nn import Module, ModuleList
# 从 torch.utils.data 模块中导入 Dataset 和 DataLoader 类
from torch.utils.data import Dataset, DataLoader
# 从 torch.optim.lr_scheduler 模块中导入 _LRScheduler 类
from torch.optim.lr_scheduler import _LRScheduler
# 从 torch 模块中导入 nn、einsum 和 Tensor 类
from torch import nn, einsum, Tensor

# 导入 beartype 库
from beartype import beartype
from beartype.door import is_bearable
# 从 beartype.typing 模块中导入 List、Optional、Callable、Type、Tuple、Union、Literal 类型
from beartype.typing import List, Optional, Callable, Type, Tuple, Union, Literal

# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat

# 从 x_transformers.x_transformers 模块中导入 RMSNorm、Attention 和 TransformerWrapper 类
from x_transformers.x_transformers import (
    RMSNorm,
    Attention,
    TransformerWrapper,
)

# 导入 accelerate 库
from accelerate import Accelerator

# 从 pytorch_custom_utils 模块中导入 OptimizerWithWarmupSchedule、get_adam_optimizer 和 auto_unwrap_model 函数
from pytorch_custom_utils import (
    OptimizerWithWarmupSchedule,
    get_adam_optimizer,
    auto_unwrap_model
)

# 从 pytorch_custom_utils.accelerate_utils 模块中导入 model_forward_contexts 函数
from pytorch_custom_utils.accelerate_utils import (
    model_forward_contexts
)

# import sample, top_p, top_k from CALM_pytorch.sampling_utils
# (restored: the import line described by this comment was dropped in
# extraction; `sample` and `top_p` are referenced below in CALM.generate)
from CALM_pytorch.sampling_utils import sample, top_p, top_k

# types

# 定义 Sequence 类型为 Tuple 或 List
Sequence = Union[Tuple, List]

# 定义 HiddenPosition 类型为 'input' 或 'output'
HiddenPosition = Union[Literal['input'], Literal['output']]

# Type helper: the hint matching any homogeneous tuple or list of `t`.
def SequenceOf(t):
    variadic_tuple = Tuple[t, ...]
    return Union[variadic_tuple, List[t]]

# Type helper: accept either a bare `t` or a sequence of `t`.
def SingularOrMany(t):
    many = SequenceOf(t)
    return Union[t, many]

# helpers

# True when `v` holds a concrete value rather than None.
def exists(v):
    return v is not None

# Return `v` when it is provided, otherwise fall back to `d`.
def default(v, d):
    if v is not None:
        return v
    return d

# Logical XNOR: True exactly when both flags agree.
def xnor(x, y):
    disagreement = x ^ y
    return not disagreement

# Promote a lone value to a tuple of `length` copies; tuples and lists
# pass through untouched (same check as is_bearable(t, Union[Tuple, List])).
def cast_tuple(t, length = 1):
    if isinstance(t, (tuple, list)):
        return t
    return (t,) * length

# Pick the tensor of interest out of a forward-hook invocation.
# Hooks fire as (module, inputs, output); `hidden_position` selects whether
# the block's input or its output is the hidden state we want. Either side
# may arrive as a tuple, in which case its first element is taken.
def get_block_output_from_hook_outputs(
    hidden_position: HiddenPosition,
    _, inp, out
):
    chosen = inp if hidden_position != 'output' else out

    if isinstance(chosen, tuple):
        chosen = chosen[0]

    assert torch.is_tensor(chosen)
    return chosen

# freezing llms

# Flip `requires_grad` on every parameter of `module` in place.
@beartype
def set_module_requires_grad_(
    module: Module,
    requires_grad: bool
):
    for parameter in module.parameters():
        parameter.requires_grad = requires_grad

# Freeze `module` in place: none of its parameters receive gradients afterwards.
def freeze_all_layers_(module):
    set_module_requires_grad_(module, False)

# function for returning an ordered list of modules, where the output of the module is the output of that transformer block layer
# ex. for x-transformers TransformerWrapper

# For an x-transformers TransformerWrapper, return the modules whose forward
# output corresponds to each transformer block's hidden state.
# NOTE(review): keeps every second entry (odd indices) — assumes the layers
# alternate attention / feedforward, so this selects the post-ff modules.
@beartype
def x_transformer_blocks(transformer: TransformerWrapper) -> List[Module]:
    per_layer_modules = [layer[-1] for layer in transformer.attn_layers.layers]
    return per_layer_modules[1::2]

# helper classes

# Forward-hook callable: every time the hooked module runs, record its hidden
# state (detached) into a shared list, and optionally the module itself.
class Recorder:
    @beartype
    def __init__(
        self,
        outputs: Optional[List] = None,
        forward_hook_get_hidden: HiddenPosition = 'output',
        modules: Optional[List] = None,
    ):
        # shared lists allow several recorders to append into one collection
        self.output = default(outputs, [])
        self.modules = modules
        self.get_output_fn = partial(get_block_output_from_hook_outputs, forward_hook_get_hidden)

    def __call__(self, *hook_args):
        # hook_args = (module, inputs, output), as supplied by torch's forward hook
        hooked_module = hook_args[0]

        if exists(self.modules):
            self.modules.append(hooked_module)

        hidden = self.get_output_fn(*hook_args)
        self.output.append(hidden.detach())

# Wraps a model and collects hidden states from the given blocks via forward
# hooks. Calling the wrapper runs the model purely for its hook side effects
# and returns the recorded hiddens (the model's own return value is dropped).
class ExtractHiddensWrapper(Module):
    @beartype
    def __init__(
        self,
        model: Module,
        blocks: List[Module],
        hidden_positions: SingularOrMany(HiddenPosition) = 'output'
    ):
        super().__init__()
        positions = cast_tuple(hidden_positions, len(blocks))
        assert len(positions) == len(blocks)

        self.model = model

        # NOTE: `self.modules` shadows nn.Module.modules() on this instance
        self.outputs = []
        self.modules = []
        self.recorders = []

        for block, position in zip(blocks, positions):
            recorder = Recorder(self.outputs, position, self.modules)
            self.recorders.append(recorder)
            block.register_forward_hook(recorder)

    def forward(self, *args, return_hooked_modules = False, **kwargs):
        # run the model — the recorders fill self.outputs / self.modules
        self.model(*args, **kwargs)

        collected_outputs = self.outputs.copy()
        collected_modules = self.modules.copy()

        # reset the shared buffers so the next forward starts clean
        self.outputs.clear()
        self.modules.clear()

        if return_hooked_modules:
            return collected_outputs, collected_modules

        return collected_outputs
# Cross attention block, attached to an anchor transformer block as a forward
# hook. When invoked it attends from the anchor's hidden state to the stored
# augment-LLM context (`self.context`) and returns a residual-connected result.
class CrossAttentionBlock(Module):
    @beartype
    def __init__(
        self,
        dim,
        dim_context,
        linear_project_context = True,  # in the paper, they project the augmentation hidden states. not sure if needed, but better to be accurate first
        pre_rmsnorm = False,
        forward_hook_get_hidden: Union[
            Literal['output'],
            Literal['input']
        ] = 'output',
        **kwargs
    ):
        super().__init__()
        # optional RMS normalization of the anchor hidden state before attending
        self.pre_rmsnorm = RMSNorm(dim) if pre_rmsnorm else nn.Identity()

        self.context_proj = None

        self.dim = dim
        self.dim_context = dim_context

        # optionally project the augment hidden states into the anchor dimension
        if linear_project_context:
            self.context_proj = nn.Linear(dim_context, dim)
            dim_context = dim

        # the cross attention itself; zero_init_output means the residual path
        # dominates at the start of training
        self.attn = Attention(
            dim = dim,
            dim_context = dim_context,
            zero_init_output = True,
            gate_value_heads = True,
            **kwargs
        )

        # context (and its mask) are installed externally before the anchor forward
        self.context = None
        self.context_mask = None
        self.forward_hook_get_hidden = forward_hook_get_hidden

    # store the context mask to use on the next forward
    def set_mask(self, mask: Tensor):
        self.context_mask = mask

    # clear the context mask
    def unset_mask(self):
        self.context_mask = None

    # invoked as a forward hook: hook_args = (module, inputs, output)
    def forward(self, *hook_args):
        x = get_block_output_from_hook_outputs(self.forward_hook_get_hidden, *hook_args)

        context = self.context
        assert exists(context)

        # the hook may fire inside a no_grad sampling loop (see CALM.generate);
        # re-enable grad while training so these parameters learn
        maybe_enable_grad = torch.enable_grad if self.training else nullcontext

        with maybe_enable_grad():
            res = x
            x = self.pre_rmsnorm(x)

            if exists(self.context_proj):
                context = self.context_proj(context)

            out = self.attn(x, context, context_mask = self.context_mask) + res

        return out

# per-augment-LLM configuration passed to CALM
@dataclass
class AugmentParams:
    # the augment transformer itself
    model: Module
    # take each block's hidden state from its 'input' or 'output'
    hidden_position: SingularOrMany(HiddenPosition) = 'output'
    # explicit list of transformer blocks to hook, if already known
    transformer_blocks: Optional[List[Module]] = None
    # or a function that derives those blocks from the model
    extract_blocks_fn: Optional[Callable[[Module], List[Module]]] = None
    # set when the model already returns its hidden states directly
    model_return_hiddens: bool = False
    # shape of a single (batchless) input — presumably for a dummy forward; the
    # consuming __init__ is not visible here, so confirm before relying on it
    input_shape: Optional[Tuple[int, ...]] = None
    # explicit (augment layer, anchor layer) pairs, 1-indexed (see forward_augments)
    connections: Optional[Tuple[Tuple[int, int], ...]] = None
    connect_every_num_layers: int = 4 # in the paper, they connect every 4 layers
    # kwarg name under which the prompt mask is forwarded to the model
    mask_kwarg: Optional[str] = None

# CALM 类
class CALM(Module):
    # 初始化函数
    @beartype
    def __init__(
        self,
        anchor_llm: Module,
        augment_llms: SingularOrMany(AugmentParams),
        *,
        attn_kwargs: dict = dict(
            linear_project_context = True,
            pre_rmsnorm = True,
            flash = True
        ),
        anchor_extract_blocks_fn: Callable[[Module], List[Module]] = None,
        anchor_transformer_blocks: Optional[List[Module]] = None,
        anchor_hidden_position: SingularOrMany(HiddenPosition) = 'output',
        pad_id: int = -1
    def state_dict(self):
        """Checkpoint only the cross attention parameters (the composed LLMs are not saved)."""
        return self.cross_attns.state_dict()

    def load_state_dict(self, pkg, strict = False):
        """Restore the cross attention parameters saved by `state_dict`."""
        self.cross_attns.load_state_dict(pkg, strict = strict)

    def parameters(self, recurse: bool = True):
        """Expose only the cross attention parameters for optimization.

        Overrides ``nn.Module.parameters``; ``recurse`` is accepted and
        forwarded so the signature stays compatible with the base class
        (the original dropped it, breaking callers that pass ``recurse``).
        """
        return self.cross_attns.parameters(recurse = recurse)

    def release_cross_attn_contexts(self):
        """Drop the stored augment hidden states from every cross attention block."""
        for cross_attns_for_augment in self.cross_attns:
            for block in cross_attns_for_augment:
                block.context = None

    def forward_augments(
        self,
        prompt: Tensor,
        prompt_mask: Optional[SingularOrMany(SequenceOf(Tensor))] = None
    ):
        """Run every augment LLM on its prompt (without grad) and install the
        resulting hidden states as context on the cross attention blocks.

        Returns the per-augment tuples of prompts and prompt masks.
        """
        # if a single prompt is given with multiple augment LLMs, feed that same prompt to all of them

        num_augment_llms = len(self.augment_llms)

        prompts = cast_tuple(prompt, num_augment_llms)

        assert len(prompts) == num_augment_llms

        # prompt masks — derived from pad_id for integer token prompts; float
        # inputs (e.g. images) get no mask

        if not exists(prompt_mask):
            prompt_mask = tuple((p != self.pad_id if not torch.is_floating_point(p) else None) for p in prompts)

        prompt_mask = cast_tuple(prompt_mask, num_augment_llms)

        prompt_masks = prompt_mask # at this point, should be plural

        assert len(prompt_masks) == num_augment_llms

        # invoke the augment LLMs, gathering their hidden states via forward hooks

        augments_hiddens = []

        with torch.no_grad():

            self.augment_llms.eval()

            for augment_llm, params, prompt, prompt_mask in zip(self.augment_llms, self.augment_llms_params, prompts, prompt_masks):
                augment_llm_kwarg = dict()

                if exists(params.mask_kwarg):
                    augment_llm_kwarg = {params.mask_kwarg: prompt_mask}

                one_augment_hiddens = augment_llm(prompt, **augment_llm_kwarg)

                augments_hiddens.append(one_augment_hiddens)

        # set each cross attention block's context for the upcoming anchor forward
        # (connection layer indices are 1-indexed, hence the - 1)

        for one_augment_hiddens, one_augment_cross_attns, one_augment_connections in zip(augments_hiddens, self.cross_attns, self.connections):

            for (augment_layer_index, _), cross_attn in zip(one_augment_connections, one_augment_cross_attns):

                cross_attn.context = one_augment_hiddens[augment_layer_index - 1]

        return prompts, prompt_masks

    @contextmanager
    def set_cross_attn_masks(self, masks):
        """Context manager: install one context mask per augment LLM on all of
        its cross attention blocks, and always remove them on exit.

        The unset step now runs in a ``finally`` block, so the masks cannot
        leak if the wrapped body (e.g. the anchor forward) raises — the
        original skipped cleanup on exception.
        """
        # set the context masks for the cross attentions

        for cross_attns_one_augment, mask in zip(self.cross_attns, masks):
            for cross_attn in cross_attns_one_augment:
                cross_attn.set_mask(mask)

        try:
            yield
        finally:
            # unset the context masks, even on error

            for cross_attns_one_augment in self.cross_attns:
                for cross_attn in cross_attns_one_augment:
                    cross_attn.unset_mask()


    @torch.no_grad()
    def generate(
        self,
        prompt: Tensor,
        seq_len: int,
        prompt_mask: Optional[SingularOrMany(SequenceOf(Tensor))] = None,
        filter_fn: Callable = top_p,
        filter_kwargs: dict = dict(
            thres = 0.9
        )
    ):
        """Sample up to `seq_len` tokens from the anchor LLM, conditioned on
        the augment LLMs' hidden states for `prompt`."""
        # `batch` is unused below; device is taken from the cross attention params
        batch, device = prompt.shape[0], next(self.cross_attns.parameters()).device

        self.eval()

        # run forward on all augment models and collect their hidden states

        prompts, prompt_masks = self.forward_augments(prompt = prompt, prompt_mask = prompt_mask)

        with self.set_cross_attn_masks(prompt_masks):

            # sampling

            generated =  sample(
                self.anchor_llm,
                prompt,
                seq_len = seq_len,
                filter_fn = filter_fn,
                filter_kwargs = filter_kwargs
            )

            self.release_cross_attn_contexts()

        return generated

    @beartype
    def forward(
        self,
        seq: Tensor,
        *,
        prompt: SingularOrMany(Tensor),
        prompt_mask: Optional[SingularOrMany(Tensor)] = None,
        mask: Optional[Tensor] = None,
        return_loss = True,
        anchor_llm_in_train_mode = True  # not sure about this
        ):
        """Fine-tuning forward pass.

        Runs the augment LLMs on `prompt` (no grad), installs their hidden
        states as cross attention context, then runs the anchor LLM on `seq`.
        Returns the cross entropy loss when `return_loss` is True, otherwise
        the raw logits.
        """
        # when computing the loss, put the (trainable) cross attentions in train mode
        if return_loss:
            self.cross_attns.train()

            # the anchor LLM may be kept in train or eval mode while fine-tuning
            if anchor_llm_in_train_mode:
                self.anchor_llm.train()
            else:
                self.anchor_llm.eval()

            # next-token prediction: inputs drop the last token, labels drop the first
            seq, labels = seq[:, :-1], seq[:, 1:]

        # run forward on all augment models and collect their hidden states

        prompts, prompt_masks = self.forward_augments(prompt=prompt, prompt_mask=prompt_mask)

        # install context masks on the cross attention blocks
        with self.set_cross_attn_masks(prompt_masks):
            # anchor forward — its hooked blocks cross attend to the augment hiddens

            logits = self.anchor_llm(seq)

            # free the stored augment hiddens
            self.release_cross_attn_contexts()

            # the anchor llm must return logits of shape (batch, seq, num tokens)
            assert logits.ndim == 3, 'anchor llm should return logits in the shape (batch, seq, num tokens)'

        # return logits for decoding

        if not return_loss:
            return logits

        # take the sequence mask into account

        if exists(mask):
            # positions outside `mask` become pad_id so cross entropy ignores them
            labels = labels.masked_fill(~mask[:, 1:], self.pad_id)

        # cross entropy for fine-tuning

        loss = F.cross_entropy(
            rearrange(logits, 'b n c -> b c n'),
            labels,
            ignore_index=self.pad_id
        )

        return loss
# Infinite iterator over a dataloader: restart from the top once exhausted.
def cycle(dl):
    while True:
        yield from dl

# auto_unwrap_model: decorator from pytorch_custom_utils — presumably unwraps
# the accelerate-wrapped model for save/load; confirm against that library
@auto_unwrap_model()
class FineTuner:

    # Trainer for CALM using 🤗 Accelerate. Only what `calm.parameters()`
    # exposes (the cross attention parameters) is optimized.
    @beartype
    def __init__(
        self,
        calm: CALM,
        *,
        num_train_steps: int,
        learning_rate: float,
        weight_decay: float,
        batch_size: int,
        dataset: Dataset,
        data_kwarg_names: Tuple[str, ...] = ('seq', 'mask', 'prompt'),
        accelerate_kwargs: dict = dict(),
        checkpoint_every: int = 1000,
        checkpoint_path: str = './checkpoints',
        scheduler: Optional[Type[_LRScheduler]] = None,
        scheduler_kwargs: dict = dict(),
        warmup_steps: int = 1000,
        max_grad_norm = 0.5,
        grad_accum_steps = 1
    ):
        # accelerator handles device placement / distributed setup
        self.accelerator = Accelerator(**accelerate_kwargs)

        # dataloader; tuple batches are zipped with data_kwarg_names in __call__
        self.dl = DataLoader(dataset, batch_size = batch_size, shuffle = True, drop_last = True)
        self.data_kwarg_names = data_kwarg_names

        # the composed model being fine-tuned
        self.model = calm

        # Adam over the cross attention parameters only (see CALM.parameters)
        adam = get_adam_optimizer(
            calm.parameters(),
            lr = learning_rate,
            wd = weight_decay
        )

        # optimizer wrapped with warmup, optional scheduler, and grad clipping
        self.optimizer = OptimizerWithWarmupSchedule(
            accelerator = self.accelerator,
            optimizer = adam,
            scheduler = scheduler,
            scheduler_kwargs = scheduler_kwargs,
            warmup_steps = warmup_steps,
            max_grad_norm = max_grad_norm
        )

        self.step = 0
        self.num_train_steps = num_train_steps
        self.grad_accum_steps = grad_accum_steps

        self.checkpoint_every = checkpoint_every
        self.checkpoint_path = Path(checkpoint_path)
        self.checkpoint_path.mkdir(exist_ok = True, parents = True)

    # whether this process is the main one (only it writes checkpoints)
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # print once across processes
    def print(self, msg):
        self.accelerator.print(msg)

    # save model (cross attention) + optimizer state and the current step
    def save(self, filename: str, overwrite: bool = True):
        path = self.checkpoint_path / filename
        assert overwrite or not path.exists()

        pkg = dict(
            model = self.model.state_dict(),
            optimizer = self.optimizer.state_dict(),
            step = self.step
        )

        torch.save(pkg, str(path))

    # restore a checkpoint written by `save`
    def load(self, filename: str):
        path = self.checkpoint_path / filename
        assert path.exists()

        pkg = torch.load(str(path))

        self.model.load_state_dict(pkg['model'])
        self.optimizer.load_state_dict(pkg['optimizer'])
        self.step = pkg['step']

    # run the training loop; extra kwargs are forwarded to every calm(...) call
    def __call__(self, forward_kwargs: dict = dict()):
        dl_iter = cycle(self.dl)
        self.model.train()

        for step in range(self.step, self.num_train_steps):

            # per-microbatch contexts for gradient accumulation
            # (see pytorch_custom_utils.model_forward_contexts)
            for context in model_forward_contexts(
                model = self.model,
                accelerator = self.accelerator,
                grad_accum_steps = self.grad_accum_steps
            ):
                with context():
                    data = next(dl_iter)

                    if not isinstance(data, dict):
                        data = dict(zip(self.data_kwarg_names, data))

                    loss = self.model(**data, **forward_kwargs)

                    self.accelerator.backward(loss / self.grad_accum_steps)

            self.print(f'{step + 1}: {loss.item():.3f}')

            self.optimizer.step()
            self.optimizer.zero_grad()

            self.step += 1

            self.accelerator.wait_for_everyone()

            if self.is_main and not (self.step % self.checkpoint_every):
                num = self.step // self.checkpoint_every
                self.save(f'checkpoint.{num}.pt')

            self.accelerator.wait_for_everyone()

        self.print('training complete')
        self.save('checkpoint.-1.pt')

.\lucidrains\CALM-pytorch\CALM_pytorch\sampling_utils.py

import torch  # 导入 PyTorch 库
import torch.nn.functional as F  # 导入 PyTorch 中的函数模块
from torch import Tensor  # 导入 PyTorch 中的张量
from torch.nn import Module  # 导入 PyTorch 中的神经网络模块
from torch.nn.utils.rnn import pad_sequence  # 导入 PyTorch 中的序列填充函数

from beartype import beartype  # 导入 beartype 库中的类型检查装饰器
from beartype.typing import Optional, Callable, List, Tuple  # 导入 beartype 库中的类型注解

from einops import rearrange  # 导入 einops 库中的重排函数

from tqdm import tqdm  # 导入 tqdm 库中的进度条显示函数

# True when `v` holds a value (i.e. is not None).
def exists(v):
    return v is not None

# Return `v` when present, otherwise the fallback `d`.
def default(v, d):
    if v is None:
        return d
    return v

# 采样辅助函数

# Numerically safe log: clamp away zeros before taking the logarithm.
def log(t, eps = 1e-20):
    clamped = t.clamp(min = eps)
    return torch.log(clamped)

# Sample standard Gumbel noise with the same shape as `t`.
def gumbel_noise(t):
    uniform = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(uniform))

# Gumbel-max sampling from unnormalized logits `t` at the given temperature.
def gumbel_sample(t, temperature = 1., dim = -1, keepdim = True, eps = 1e-10):
    scaled = t / max(temperature, eps)
    perturbed = scaled + gumbel_noise(t)
    return perturbed.argmax(dim = dim, keepdim = keepdim)

# nucleus

# Nucleus (top-p) filtering: keep the smallest set of logits whose cumulative
# probability exceeds `thres`; every other logit becomes -inf.
def top_p(logits, thres = 0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending = True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)

    # shift the removal mask right by one so the token that first crosses the
    # threshold is still kept
    to_remove = F.pad(cum_probs > thres, (1, -1), value = False)

    sorted_logits[to_remove] = float('-inf')
    # scatter the filtered values back to their original column positions
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)

# topk

def top_k(logits, frac_num_tokens = 0.1, k: Optional[int] = None):
    """Keep only the k largest logits per row; all others become -inf.

    When `k` is not given it is derived as ceil(frac_num_tokens * vocab size),
    capped at the vocab size. Fixes the original's use of `ceil` without any
    import (NameError whenever `k` was None).
    """
    from math import ceil  # local import: the module never imported ceil

    num_tokens = logits.shape[-1]

    if k is None:
        k = ceil(frac_num_tokens * num_tokens)
    k = min(k, num_tokens)

    val, ind = torch.topk(logits, k)
    filtered = torch.full_like(logits, float('-inf'))
    # write the surviving values back at their original column positions
    filtered.scatter_(1, ind, val)
    return filtered

# decoding

@torch.no_grad()  # inference only — no autograd graph needed
@beartype
def sample(
    net: Module,
    prompts,
    seq_len: int,
    temperature = 1.,
    filter_fn: Callable = top_p,
    filter_kwargs: dict = dict(),
    pad_id: int = -1,
    eos_id: Optional[int] = None,
    output_keep_prompt = False
):
    """Autoregressively decode from `net` starting from `prompts`.

    `prompts` may be a (batch, len) tensor or a list of uneven 1-d tensors
    (right-padded here with `pad_id`). Sampling proceeds until every row
    reaches `seq_len` tokens, or until every row has emitted `eos_id` (when
    given). Returns the full sequences including the prompt when
    `output_keep_prompt` is True, otherwise just the generated continuations,
    split into one variable-length tensor per batch row.
    """
    device = next(net.parameters()).device
    net.eval()

    # ragged list of prompts -> right-padded tensor
    if isinstance(prompts, (tuple, list)):
        prompts = pad_sequence(prompts, batch_first = True, padding_value = pad_id)

    batch, prompts_tensor_len = prompts.shape

    batch_arange = torch.arange(batch, device = device)[..., None]

    # each row starts generating right after its own (unpadded) prompt
    prompt_lens = (prompts != pad_id).sum(dim = -1)
    curr_seq_indices = prompt_lens[..., None]

    out = prompts.clone()

    pbar = tqdm(
        initial = out.shape[-1],
        total = seq_len,
        desc = 'sampling'
    )

    while (curr_seq_indices < seq_len).any():
        # grow the buffer by one pad column, then fill each row's next slot
        out = F.pad(out, (0, 1), value = pad_id)

        # the network must not see pad_id (may be negative); zero it out
        net_input = out.masked_fill(out == pad_id, 0)

        logits = net(net_input)

        # logits at each row's current position
        logits = logits[batch_arange, curr_seq_indices]
        logits = rearrange(logits, 'b 1 d -> b d')

        logits = filter_fn(logits, **filter_kwargs)
        sampled_tokens = gumbel_sample(logits, temperature = temperature, dim = -1)

        out[batch_arange, curr_seq_indices] = sampled_tokens

        curr_seq_indices += 1
        curr_seq_indices.clamp_(max = seq_len)
        pbar.update(1)

        if not exists(eos_id):
            continue

        # early exit once every row has produced at least one eos token
        is_eos_mask = out == eos_id
        all_eos = is_eos_mask.any(dim = -1).all()

        if all_eos:
            break

    pbar.close()

    out = out[:, :seq_len]

    if exists(eos_id):
        # pad out everything after (but excluding) each row's first eos token
        is_eos_mask = out == eos_id
        after_eos_mask = F.pad(is_eos_mask.cumsum(dim = -1) > 0, (1, -1), value = False)
        out = out.masked_fill_(after_eos_mask, pad_id)

    if output_keep_prompt:
        return out

    prompt_mask = torch.arange(out.shape[-1], device = device) < prompt_lens[..., None]

    # fix: parenthesize the comparison — `&` binds tighter than `!=` in Python,
    # so the original `out != pad_id & ~prompt_mask` computed the wrong mask
    generated_seq_mask = (out != pad_id) & ~prompt_mask
    seq_lens = generated_seq_mask.sum(dim = -1).tolist()

    return out[generated_seq_mask].split(seq_lens)

.\lucidrains\CALM-pytorch\CALM_pytorch\__init__.py

# 从 CALM_pytorch.CALM 模块中导入以下类和函数
from CALM_pytorch.CALM import (
    AugmentParams,  # 导入 AugmentParams 类
    ExtractHiddensWrapper,  # 导入 ExtractHiddensWrapper 类
    CALM,  # 导入 CALM 类
    FineTuner  # 导入 FineTuner 类
)

CALM - Pytorch

Implementation of CALM from the paper LLM Augmented LLMs: Expanding Capabilities through Composition, out of Google Deepmind

Can support any number of augmentation LLMs

Install

$ pip install CALM-pytorch

Appreciation

Usage

ex. with x-transformers

import torch
from x_transformers import TransformerWrapper, Decoder

augment_llm = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
)

anchor_llm = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8
    )
)

# import CALM wrapper

from CALM_pytorch import CALM, AugmentParams

calm = CALM(
    anchor_llm,
    augment_llms = AugmentParams(
        model = augment_llm,
        connect_every_num_layers = 4
    )
)

# mock input

seq = torch.randint(0, 20000, (1, 1024))
mask = torch.ones((1, 1024)).bool()
prompt = torch.randint(0, 20000, (1, 256))

# forward for finetuning loss

loss = calm(
    seq,
    mask = mask,
    prompt = prompt
)

loss.backward()

# after much training, prompt the composed model

generated = calm.generate(
    prompt = seq[:, :1],
    seq_len = 1024
)

To use a handy trainer class using 🤗 Accelerate, just import FineTuner and use as follows

trainer = FineTuner(
    calm = calm,
    dataset = dataset,   # returns a dictionary of input kwargs to calm - dict(seq: Tensor, mask: Tensor, prompt: Tensor). it can also return a Tuple, in which data_kwargs needs to be set to the correct ordered value of kwarg names
    batch_size = 16,
    num_train_steps = 10000,
    learning_rate = 3e-4,
    weight_decay = 1e-2,
    warmup_steps = 1000,
    checkpoint_every = 1000
)

trainer()

# checkpoints of the cross attention parameters will be saved to ./checkpoints every 1000 steps

To explore multiple augmentation LLMs, simply pass in a list for `augment_llms`

ex.

calm = CALM(
    anchor_llm = anchor_llm,
    augment_llms = [AugmentParams(augment_llm1), AugmentParams(augment_llm2)] # pass in a list of AugmentParams wrapping model and other hparams specific to that transformer
)

Say you want to explore different types of connectivity between anchor and augmentation model(s), just pass in the connections as a tuple of tuple integer pairs, specifying the anchor to augment layer number.

calm = CALM(
    anchor_llm = anchor_llm,
    augment_llms = (
        AugmentParams(
            model = augment_llm1,
            connections = (
                (1, 12),  # 1st layer of augment llm1 attended to by 12th layer of anchor llm
                (2, 12),
                (3, 12),
                (4, 12),
            ),
        ),
        AugmentParams(
            model = augment_llm2,
            connections = (
                (6, 1),   # 6th layer of augment llm2 attended to by 1st layer of anchor llm
                (6, 2),
                (12, 12),
            )
        )
    )
)

CALM setup with 2 specialized augmentation LLMs + a vision transformer

import torch

# pip install vit-pytorch x-transformers

from vit_pytorch.vit import ViT, Attention
from x_transformers import TransformerWrapper, Encoder, Decoder

anchor_llm = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 16,
        dim_head = 2,
        depth = 12,
        heads = 8
    )
)

augment_llm1 = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 16,
        dim_head = 2,
        depth = 12,
        heads = 8
    )
)

augment_llm2 = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 16,
        dim_head = 2,
        depth = 12,
        heads = 8
    )
)

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 256,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# calm

from CALM_pytorch import CALM, AugmentParams, FineTuner

calm = CALM(
    anchor_llm = anchor_llm,
    augment_llms = (
        AugmentParams(
            model = augment_llm1,
            mask_kwarg = 'mask'
        ),
        AugmentParams(
            model = augment_llm2,
            mask_kwarg = 'mask'
        ),
        AugmentParams(
            model = vit,
            input_shape = (3, 256, 256),
            hidden_position = 'input',
            extract_blocks_fn = lambda vit: [m for m in vit.modules() if isinstance(m, Attention)]
        )
    ),
    attn_kwargs = dict(
        linear_project_context = True,
        pre_rmsnorm = True,
        flash = True
    )
)

seq = torch.randint(0, 20000, (1, 1024))
mask = torch.ones((1, 1024)).bool()

prompt = (
    torch.randint(0, 20000, (1, 256)),
    torch.randint(0, 20000, (1, 256)),
    torch.randn(1, 3, 256, 256)
)

loss = calm(
    seq,
    mask = mask,
    prompt = prompt
)

loss.backward()

Todo

  • figure out how to correctly mask augment llm tokens

  • auto-derive model dimensions with dummy input

  • take care of finetuning training logic

  • show example of manual definitions of custom connectivity between 2+ attention networks

  • if anchor and augment transformer block modules are directly passed in (without extraction fn), run a dummy input through both networks and order them correctly using hooks

  • fix example for x-transformers, as in x-transformers, depth is actually depth x 2, taking hiddens from after attention and ff

  • when finely specifying hidden positions, make sure to reorder it if the transformer blocks themselves were passed in and not ordered to begin with

  • extend to a list of augmentation llms

    • full connectivity customization
    • custom number of augmentation layers per augmentation llm
    • make simple vit work
      • refactor so extraction fn, mask kwarg, and other related hparams are grouped together under a dictionary of {[augment_llm_name]: {augment_llm_related_hparams}} - use dataclasses
      • show example
  • take care of caching the augment hiddens when sampling. forget about anchor kv cache for now

    • logic for not releasing the saved output from recorder, for inference
    • managing cross attention block state for popping the saved output from the recorder
    • move the augmentation forwards into one shared method, and craft out sampling method for anchor
  • able to wire up with just module names

  • show an example with giving the LLM ability to hear as well, using hubert or wav2vec wrappers

  • handle a wrapper or function that takes in the sequence and prompt length, and auto derives the inputs to CALM

  • add an option for self attention path way with memory tokens attending to hidden states of all augmentation llms, akin to what was done with Zorro

Citations

@inproceedings{Bansal2024LLMAL,
  title   = {LLM Augmented LLMs: Expanding Capabilities through Composition},
  author  = {Rachit Bansal and Bidisha Samanta and Siddharth Dalmia and Nitish Gupta and Shikhar Vashishth and Sriram Ganapathy and Abhishek Bapna and Prateek Jain and Partha Pratim Talukdar},
  year    = {2024},
  url     = {https://api.semanticscholar.org/CorpusID:266755751}
}

.\lucidrains\CALM-pytorch\setup.py

# setuptools helpers for building and discovering packages
from setuptools import setup, find_packages

# package metadata, dependencies and PyPI classifiers for CALM-Pytorch
setup(
  name = 'CALM-Pytorch',  # distribution name
  packages = find_packages(exclude=[]),  # include every discovered package
  version = '0.2.1',  # release version
  license='MIT',  # license
  description = 'CALM - Pytorch',  # short description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author contact
  long_description_content_type = 'text/markdown',  # README format
  url = 'https://github.com/lucidrains/CALM-pytorch',  # project homepage
  keywords = [
    'artificial intelligence',
    'deep learning',
    'composing LLMs'
  ],
  install_requires = [  # runtime dependencies
    'accelerate',
    'beartype',
    'einops>=0.7.0',
    'pytorch-custom-utils>=0.0.11',
    'torch>=2.0',
    'tqdm',
    'x-transformers>=1.27.3'
  ],
  classifiers=[  # PyPI trove classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\charformer-pytorch\charformer_pytorch\charformer_pytorch.py

# 导入 math 模块
import math
# 从 math 模块中导入 gcd 函数
from math import gcd
# 导入 functools 模块
import functools
# 导入 torch 模块
import torch
# 从 torch 模块中导入 nn, F, einsum
import torch.nn.functional as F
from torch import nn, einsum
# 从 einops 模块中导入 rearrange, reduce, repeat
from einops import rearrange, reduce, repeat
# 从 einops.layers.torch 模块中导入 Rearrange
from einops.layers.torch import Rearrange

# 辅助函数

# predicate: is a value actually present (i.e. not None)?
def exists(val):
    """Return True when `val` is not None."""
    if val is None:
        return False
    return True

# least common multiple of any number of integers
def lcm(*numbers):
    """Return the least common multiple of the given integers (1 when called with none).

    Uses exact integer arithmetic (floor division by the gcd). The previous
    implementation used float division, which silently loses precision once
    the intermediate product exceeds 2**53.
    """
    return int(functools.reduce(lambda x, y: x * y // gcd(x, y), numbers, 1))

# mean over `dim`, counting only positions where `mask` is True
def masked_mean(tensor, mask, dim = -1):
    """Return the mean of `tensor` along `dim` over positions where `mask` is True.

    Positions with an all-False mask yield 0. Unlike the previous version,
    this does not mutate the caller's `tensor` in place (the old in-place
    `masked_fill_` surprisingly zeroed the input argument).
    """
    # broadcast the mask up to the tensor's rank by appending singleton dims
    diff_len = len(tensor.shape) - len(mask.shape)
    mask = mask[(..., *((None,) * diff_len))]
    # zero out masked-away entries without touching the caller's tensor
    tensor = tensor.masked_fill(~mask, 0.)

    total_el = mask.sum(dim = dim)
    # clamp avoids division by zero for fully-masked rows
    mean = tensor.sum(dim = dim) / total_el.clamp(min = 1.)
    mean = mean.masked_fill(total_el == 0, 0.)
    return mean

# sequence-length rounding helpers

def next_divisible_length(seqlen, multiple):
    """Return the smallest multiple of `multiple` that is >= `seqlen`."""
    remainder = seqlen % multiple
    return seqlen if remainder == 0 else seqlen + multiple - remainder

def pad_to_multiple(tensor, multiple, *, seq_dim, dim = -1, value = 0.):
    """Right-pad `tensor` along `dim` until the length at `seq_dim` divides `multiple`.

    Returns the tensor unchanged when the length is already divisible.
    """
    current = tensor.shape[seq_dim]
    target = next_divisible_length(current, multiple)
    if target == current:
        return tensor
    shortfall = target - current
    # F.pad consumes (left, right) pairs starting from the last dimension
    leading = (0, 0) * (-1 - dim)
    return F.pad(tensor, (*leading, 0, shortfall), value = value)

# 辅助类

# fixed-padding layer
class Pad(nn.Module):
    """Module wrapper around F.pad with a fixed padding spec and fill value."""

    def __init__(self, padding, value = 0.):
        super().__init__()
        self.padding = padding
        self.value = value

    def forward(self, x):
        padded = F.pad(x, self.padding, value = self.value)
        return padded

# depthwise-separable 1d convolution
class DepthwiseConv1d(nn.Module):
    """Grouped (depthwise) 1d convolution followed by a pointwise 1x1 projection."""

    def __init__(self, dim_in, dim_out, kernel_size):
        super().__init__()
        self.conv = nn.Conv1d(dim_in, dim_out, kernel_size, groups = dim_in)
        self.proj_out = nn.Conv1d(dim_out, dim_out, 1)

    def forward(self, x):
        depthwise = self.conv(x)
        pointwise = self.proj_out(depthwise)
        return pointwise

# 主类

class GBST(nn.Module):
    """Gradient-Based Subword Tokenization module (Charformer).

    Embeds raw byte/character tokens, prepares the candidate block sizes
    (with optional offsets) whose scored representations downsample the
    sequence by `downsample_factor`. Only the constructor is visible in
    this excerpt.
    """
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        max_block_size = None,
        blocks = None,
        downsample_factor = 4,
        score_consensus_attn = True
    ):
        super().__init__()
        # exactly one of max_block_size / blocks must be supplied
        assert exists(max_block_size) ^ exists(blocks), 'either max_block_size or blocks are given on initialization'
        self.token_emb = nn.Embedding(num_tokens, dim)

        if exists(blocks):
            assert isinstance(blocks, tuple), 'blocks must be a tuple of block sizes'
            # normalize entries to (block_size, offset) pairs, defaulting offset to 0
            self.blocks = tuple(map(lambda el: el if isinstance(el, tuple) else (el, 0), blocks))
            assert all([(offset < block_size) for block_size, offset in self.blocks]), 'offset must be always smaller than the block size'

            max_block_size = max(list(map(lambda t: t[0], self.blocks)))
        else:
            # default: every block size from 1 up to max_block_size, no offsets
            self.blocks = tuple(map(lambda el: (el, 0), range(1, max_block_size + 1)))

        # depthwise conv supplying intra-block positional information
        self.pos_conv = nn.Sequential(
            Pad((0, 0, 0, max_block_size - 1)),
            Rearrange('b n d -> b d n'),
            DepthwiseConv1d(dim, dim, kernel_size = max_block_size),
            Rearrange('b d n -> b n d')
        )

        # produces one scalar score per candidate block representation
        self.score_fn = nn.Sequential(
            nn.Linear(dim, 1),
            Rearrange('... () -> ...')
        )

        self.score_consensus_attn = score_consensus_attn

        assert downsample_factor <= max_block_size, 'final downsample factor should be less than the maximum block size'

        # sequences get padded to a length divisible by every block size at once
        self.block_pad_multiple = lcm(*[block_size for block_size, _ in self.blocks])
        self.downsample_factor = downsample_factor

.\lucidrains\charformer-pytorch\charformer_pytorch\__init__.py

# 从 charformer_pytorch.charformer_pytorch 模块中导入 GBST 类
from charformer_pytorch.charformer_pytorch import GBST

Charformer - Pytorch

Implementation of the GBST (gradient-based subword tokenization) module from the Charformer paper, in Pytorch. The paper proposes a module that automatically learns subword representations, obviating the need for tokenizers in the encoder setting.

AI Coffee Break with Letitia video

Install

$ pip install charformer-pytorch

Usage

import torch
from charformer_pytorch import GBST

tokenizer = GBST(
    num_tokens = 257,             # number of tokens, should be 256 for byte encoding (+ 1 special token for padding in this example)
    dim = 512,                    # dimension of token and intra-block positional embedding
    max_block_size = 4,           # maximum block size
    downsample_factor = 4,        # the final downsample factor by which the sequence length will decrease by
    score_consensus_attn = True   # whether to do the cheap score consensus (aka attention) as in eq. 5 in the paper
)

tokens = torch.randint(0, 257, (1, 1023)) # uneven number of tokens (1023)
mask   = torch.ones(1, 1023).bool()

# both tokens and mask will be appropriately downsampled

tokens, mask = tokenizer(tokens, mask = mask) # (1, 256, 512), (1, 256)

# now pass this on to your transformer

Deviating from the paper, you can also specify block size(s) with different offsets. This is to cover a potential use-case for genomics pre-training, where the tokenizer should be able to learn the correct frame. Simply omit the max_block_size, and pass in blocks as a tuple of (block size, offset) tuples. Offsets must be less than the block size

import torch
from charformer_pytorch import GBST

tokenizer = GBST(
    num_tokens = 4 + 1,
    dim = 512,
    blocks = ((3, 0), (3, 1), (3, 2)),  # block size of 3, with offsets of 0, 1, 2
    downsample_factor = 3,
    score_consensus_attn = True
).cuda()

basepairs = torch.randint(0, 4, (1, 1023)).cuda()
mask      = torch.ones(1, 1023).bool().cuda()

# both basepairs and mask will be appropriately downsampled

basepairs, mask = tokenizer(basepairs, mask = mask)

Citations

@misc{tay2021charformer,
    title   = {Charformer: Fast Character Transformers via Gradient-based Subword Tokenization}, 
    author  = {Yi Tay and Vinh Q. Tran and Sebastian Ruder and Jai Gupta and Hyung Won Chung and Dara Bahri and Zhen Qin and Simon Baumgartner and Cong Yu and Donald Metzler},
    year    = {2021},
    eprint  = {2106.12672},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

.\lucidrains\charformer-pytorch\setup.py

# setuptools helpers for building and discovering packages
from setuptools import setup, find_packages

# package metadata, dependencies and PyPI classifiers for charformer-pytorch
setup(
  name = 'charformer-pytorch',  # distribution name
  packages = find_packages(),  # include every discovered package
  version = '0.0.4',  # release version
  license='MIT',  # license
  description = 'Charformer - Pytorch',  # short description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author contact
  url = 'https://github.com/lucidrains/charformer-pytorch',  # project homepage
  keywords = [
    'artificial intelligence',
    'deep learning',
    'learned tokenization'
  ],
  install_requires=[  # runtime dependencies
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[  # PyPI trove classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\chroma-pytorch\chroma_pytorch\chroma_pytorch.py

import torch  # 导入 PyTorch 库
from torch import nn, einsum  # 从 PyTorch 库中导入 nn 模块和 einsum 函数

from einops import rearrange, repeat  # 从 einops 库中导入 rearrange 和 repeat 函数

import math  # 导入 math 库
from pathlib import Path  # 从 pathlib 库中导入 Path 类
from random import random  # 从 random 库中导入 random 函数
from functools import partial  # 从 functools 库中导入 partial 函数
from multiprocessing import cpu_count  # 从 multiprocessing 库中导入 cpu_count 函数

import torch  # 重新导入 PyTorch 库
from torch import nn, einsum  # 从 PyTorch 库中重新导入 nn 模块和 einsum 函数
from torch.special import expm1  # 从 PyTorch 库中导入 expm1 函数
import torch.nn.functional as F  # 从 PyTorch 库中导入 F 模块
from torch.utils.data import Dataset, DataLoader  # 从 PyTorch 库中导入 Dataset 和 DataLoader 类

from torch.optim import Adam  # 从 PyTorch 库中导入 Adam 优化器
from torchvision import transforms as T, utils  # 从 torchvision 库中导入 transforms 模块和 utils 模块

from einops import rearrange, reduce, repeat  # 从 einops 库中重新导入 rearrange、reduce 和 repeat 函数
from einops.layers.torch import Rearrange  # 从 einops 库中导入 Rearrange 类

from tqdm.auto import tqdm  # 从 tqdm 库中导入 tqdm 函数
from ema_pytorch import EMA  # 从 ema_pytorch 库中导入 EMA 类

from accelerate import Accelerator  # 从 accelerate 库中导入 Accelerator 类

# helper functions

def exists(x):
    """Return True when `x` is not None."""
    return x is not None

def default(val, d):
    """Return `val` when present, otherwise the fallback `d` (called if callable)."""
    if exists(val):
        return val
    return d() if callable(d) else d

def cycle(dl):
    """Yield items from iterable `dl` forever, restarting after each pass."""
    while True:
        for data in dl:
            yield data

def has_int_squareroot(num):
    """Return True when `num` is a perfect square.

    Uses math.isqrt for exact integer arithmetic; the previous float-based
    check (math.sqrt(num) ** 2 == num) gives wrong answers once `num`
    exceeds 2**53.
    """
    return math.isqrt(num) ** 2 == num

def num_to_groups(num, divisor):
    """Split `num` into full chunks of `divisor` plus a smaller remainder chunk."""
    groups, remainder = divmod(num, divisor)
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

def convert_image_to(img_type, image):
    """Convert a PIL image to mode `img_type`, returning it unchanged when already so."""
    if image.mode != img_type:
        return image.convert(img_type)
    return image

# small helper modules

class Residual(nn.Module):
    """Wrap `fn` so the module computes fn(x, ...) + x (a residual skip connection)."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        out = self.fn(x, *args, **kwargs)
        return out + x

def Upsample(dim, dim_out = None):
    """Nearest-neighbor 2x upsample followed by a 3x3 conv (dim -> dim_out, defaulting to dim)."""
    out_channels = default(dim_out, dim)
    upsampler = nn.Upsample(scale_factor = 2, mode = 'nearest')
    conv = nn.Conv2d(dim, out_channels, 3, padding = 1)
    return nn.Sequential(upsampler, conv)

def Downsample(dim, dim_out = None):
    """Strided 4x4 conv halving spatial resolution (dim -> dim_out, defaulting to dim)."""
    out_channels = default(dim_out, dim)
    return nn.Conv2d(dim, out_channels, 4, 2, 1)

class LayerNorm(nn.Module):
    """Channel-wise layer norm for NCHW tensors with a learned per-channel gain, no bias."""

    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        # looser epsilon when running below float32 precision
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        mean = torch.mean(x, dim = 1, keepdim = True)
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        normed = (x - mean) * (var + eps).rsqrt()
        return normed * self.g

class PreNorm(nn.Module):
    """Apply LayerNorm to the input before running the wrapped `fn`."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        normed = self.norm(x)
        return self.fn(normed)

# positional embeds

class LearnedSinusoidalPosEmb(nn.Module):
    """Sinusoidal positional embedding whose frequencies are learned parameters.

    Maps a (batch,) tensor of scalar positions to (batch, dim + 1) features:
    the raw position concatenated with sin and cos of position times the
    learned frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x):
        # (b,) -> (b, 1)
        x = x[:, None]
        # (b, 1) * (1, half_dim) -> (b, half_dim) angular frequencies
        freqs = x * self.weights[None, :] * 2 * math.pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        fouriered = torch.cat((x, fouriered), dim = -1)
        return fouriered

# building block modules

class Block(nn.Module):
    """conv3x3 -> GroupNorm -> optional per-channel (scale, shift) -> SiLU."""

    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        h = self.proj(x)
        h = self.norm(h)

        if scale_shift is not None:
            scale, shift = scale_shift
            h = h * (scale + 1) + shift

        return self.act(h)

class ResnetBlock(nn.Module):
    """Two Blocks with a 1x1-conv residual shortcut, optionally conditioned on a time embedding."""

    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        # project the time embedding into a per-channel (scale, shift) pair
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if time_emb_dim is not None else None

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):
        scale_shift = None
        if self.mlp is not None and time_emb is not None:
            cond = self.mlp(time_emb)
            # (b, c) -> (b, c, 1, 1) so it broadcasts over spatial dims
            cond = cond[:, :, None, None]
            scale_shift = cond.chunk(2, dim = 1)

        h = self.block1(x, scale_shift = scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)
class LinearAttention(nn.Module):
    """Linear attention over NCHW feature maps.

    Avoids materializing the (hw x hw) attention matrix by softmaxing q
    over the feature axis and k over the token axis, then contracting k
    with v before involving q.
    """

    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        # one 1x1 conv produces q, k and v stacked along the channel dim
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)

        self.to_out = nn.Sequential(
            # project the concatenated heads back to the model dimension
            nn.Conv2d(hidden_dim, dim, 1),
            # normalize the output
            LayerNorm(dim)
        )

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        # (b, heads * d, h, w) -> (b, heads, d, h * w)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)

        # q normalized over features, k over spatial positions
        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)

        q = q * self.scale
        v = v / (h * w)

        # aggregate values against keys first: (d, n) with (e, n) -> (d, e)
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
        return self.to_out(out)

class Attention(nn.Module):
    """Full softmax self-attention across all spatial positions of an NCHW map."""

    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        # one 1x1 conv produces q, k and v stacked along the channel dim
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        # (b, heads * d, h, w) -> (b, heads, d, h * w)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)

        q = q * self.scale

        # pairwise similarities between every pair of spatial positions
        sim = einsum('b h d i, b h d j -> b h i j', q, k)
        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        return self.to_out(out)

# model

class Unet(nn.Module):
    """U-Net backbone for the diffusion model.

    The input is the noised image concatenated with a self-conditioning
    image (zeros when absent), hence `channels * 2` input channels. Every
    resnet block is conditioned on a learned-sinusoidal time embedding.
    """

    def __init__(
        self,
        dim,
        init_dim = None,
        dim_mults=(1, 2, 4, 8),
        channels = 3,
        resnet_block_groups = 8,
        learned_sinusoidal_dim = 16
    ):
        super().__init__()

        # determine channel counts
        self.channels = channels
        input_channels = channels * 2
        init_dim = default(init_dim, dim)
        # initial 7x7 conv over the (self-conditioning, image) stack
        self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)

        # feature dimension at each resolution, paired as (in, out)
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # resnet-block constructor with the group-norm group count baked in
        block_klass = partial(ResnetBlock, groups = resnet_block_groups)

        # time embedding MLP
        time_dim = dim * 4
        sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim)
        fourier_dim = learned_sinusoidal_dim + 1

        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # encoder / decoder stacks
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        # encoder: two resnet blocks + linear attention + downsample per level
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
            ]))

        mid_dim = dims[-1]
        # bottleneck: resnet + full attention + resnet
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)

        # decoder: mirrors the encoder, consuming the skip connections
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            self.ups.append(nn.ModuleList([
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Upsample(dim_out, dim_in) if not is_last else  nn.Conv2d(dim_out, dim_in, 3, padding = 1)
            ]))

        # final residual block over (decoder output, initial features)
        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
        self.final_conv = nn.Conv2d(dim, channels, 1)

    def forward(self, x, time, x_self_cond = None):
        """Run the U-Net on image `x` at timestep `time`.

        `x_self_cond` is a previous x0 estimate used for self-conditioning;
        zeros are substituted when it is not provided.
        """
        x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
        x = torch.cat((x_self_cond, x), dim = 1)

        x = self.init_conv(x)
        r = x.clone()

        t = self.time_mlp(time)

        h = []

        # encoder path, saving skip activations
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # decoder path, consuming the skips in reverse order
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim = 1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim = 1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim = 1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)
# continuous-time Gaussian diffusion wrapper
class Chroma(nn.Module):
    """Diffusion wrapper around a denoising `model` (DDPM / DDIM sampling, MSE training loss).

    NOTE(review): relies on helpers not visible in this excerpt
    (beta_linear_log_snr, alpha_cosine_log_snr, log_snr_to_alpha_sigma,
    right_pad_dims_to, log) -- presumably defined earlier in the original
    file; confirm before reuse.
    """

    def __init__(
        self,
        model,
        *,
        image_size,
        timesteps = 1000,
        use_ddim = False,
        noise_schedule = 'cosine',
        time_difference = 0.
    ):
        super().__init__()
        # the denoising network and its channel count
        self.model = model
        self.channels = self.model.channels

        self.image_size = image_size

        # pick the log-SNR noise schedule
        if noise_schedule == "linear":
            self.log_snr = beta_linear_log_snr
        elif noise_schedule == "cosine":
            self.log_snr = alpha_cosine_log_snr
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        # number of sampling steps and sampler choice
        self.timesteps = timesteps
        self.use_ddim = use_ddim

        # offset subtracted from the "next" timestep during sampling
        self.time_difference = time_difference

    @property
    def device(self):
        # device of the wrapped model's parameters
        return next(self.model.parameters()).device

    def get_sampling_timesteps(self, batch, *, device):
        """Return (time, time_next) pairs stepping from 1 down to 0, one entry per step."""
        times = torch.linspace(1., 0., self.timesteps + 1, device = device)
        times = repeat(times, 't -> b t', b = batch)
        times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
        times = times.unbind(dim = -1)
        return times

    @torch.no_grad()
    def ddpm_sample(self, shape, time_difference = None):
        """Ancestral (DDPM) sampling from pure noise down to an image of `shape`."""
        batch, device = shape[0], self.device

        time_difference = default(time_difference, self.time_difference)

        time_pairs = self.get_sampling_timesteps(batch, device = device)

        # start from pure Gaussian noise
        img = torch.randn(shape, device=device)

        x_start = None

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', total = self.timesteps):

            # apply the time-delay trick to the next timestep
            time_next = (time_next - self.time_difference).clamp(min = 0.)

            noise_cond = self.log_snr(time)

            # model predicts x0, conditioned on its previous estimate
            x_start = self.model(img, noise_cond, x_start)

            # keep the prediction inside the data range
            x_start.clamp_(-1., 1.)

            log_snr = self.log_snr(time)
            log_snr_next = self.log_snr(time_next)
            log_snr, log_snr_next = map(partial(right_pad_dims_to, img), (log_snr, log_snr_next))

            alpha, sigma = log_snr_to_alpha_sigma(log_snr)
            alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

            # posterior mean and variance of the previous step
            c = -expm1(log_snr - log_snr_next)
            mean = alpha_next * (img * (1 - c) / alpha + c * x_start)
            variance = (sigma_next ** 2) * c
            log_variance = log(variance)

            # no noise is added on the final (time_next == 0) step
            noise = torch.where(
                rearrange(time_next > 0, 'b -> b 1 1 1'),
                torch.randn_like(img),
                torch.zeros_like(img)
            )

            img = mean + (0.5 * log_variance).exp() * noise

        return img

    @torch.no_grad()
    # deterministic DDIM sampling
    def ddim_sample(self, shape, time_difference = None):
        batch, device = shape[0], self.device

        time_difference = default(time_difference, self.time_difference)

        time_pairs = self.get_sampling_timesteps(batch, device = device)

        # start from pure Gaussian noise
        img = torch.randn(shape, device = device)

        x_start = None

        for times, times_next in tqdm(time_pairs, desc = 'sampling loop time step'):

            # noise levels for the current and next timestep
            log_snr = self.log_snr(times)
            log_snr_next = self.log_snr(times_next)

            # pad noise levels up to the image's rank
            padded_log_snr, padded_log_snr_next = map(partial(right_pad_dims_to, img), (log_snr, log_snr_next))

            alpha, sigma = log_snr_to_alpha_sigma(padded_log_snr)
            alpha_next, sigma_next = log_snr_to_alpha_sigma(padded_log_snr_next)

            # apply the time-delay trick
            times_next = (times_next - time_difference).clamp(min = 0.)

            # model predicts x0
            x_start = self.model(img, log_snr, x_start)

            # keep the prediction inside the data range
            x_start.clamp_(-1., 1.)

            # recover the implied noise, then take the deterministic step
            pred_noise = (img - alpha * x_start) / sigma.clamp(min = 1e-8)

            img = x_start * alpha_next + pred_noise * sigma_next

        return img

    @torch.no_grad()
    def sample(self, batch_size = 16):
        """Generate a batch of images with DDPM or DDIM sampling per `use_ddim`."""
        image_size, channels = self.image_size, self.channels
        sample_fn = self.ddpm_sample if not self.use_ddim else self.ddim_sample
        return sample_fn((batch_size, channels, image_size, image_size))

    def forward(self, img, *args, **kwargs):
        """Training loss: noise `img` at a random time and regress the model output to `img` (MSE)."""
        batch, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'

        # sample one random time per batch element
        times = torch.zeros((batch,), device = device).float().uniform_(0, 1.)

        noise = torch.randn_like(img)

        # noise level at the sampled times, padded to the image's rank
        noise_level = self.log_snr(times)
        padded_noise_level = right_pad_dims_to(img, noise_level)
        alpha, sigma =  log_snr_to_alpha_sigma(padded_noise_level)

        noised_img = alpha * img + sigma * noise

        # self-conditioning: half the time, condition the model on its own
        # x0 prediction at this noise level -- ~25% slower per step but
        # noted to significantly lower FID
        self_cond = None
        if random() < 0.5:
            with torch.no_grad():
                self_cond = self.model(noised_img, noise_level).detach_()

        # predict x0 and regress against the clean image
        pred = self.model(noised_img, noise_level, self_cond)

        return F.mse_loss(pred, img)
# trainer 类
class Trainer(object):
    # 初始化方法
    def __init__(
        self,
        diffusion_model,
        folder,
        *,
        train_batch_size = 16,
        gradient_accumulate_every = 1,
        augment_horizontal_flip = True,
        train_lr = 1e-4,
        train_num_steps = 100000,
        ema_update_every = 10,
        ema_decay = 0.995,
        adam_betas = (0.9, 0.99),
        save_and_sample_every = 1000,
        num_samples = 25,
        results_folder = './results',
        amp = False,
        fp16 = False,
        split_batches = True,
        convert_image_to = None
    ):
        """Set up accelerator, dataset/dataloader, optimizer and EMA for training.

        `folder` is the image dataset directory; checkpoints and samples are
        written to `results_folder` (main process only).
        """
        super().__init__()

        # accelerator handles device placement and optional fp16 mixed precision
        self.accelerator = Accelerator(
            split_batches = split_batches,
            mixed_precision = 'fp16' if fp16 else 'no'
        )

        # whether to use torch's native automatic mixed precision
        self.accelerator.native_amp = amp

        self.model = diffusion_model

        # samples are saved in a square grid, so the count must be a perfect square
        assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
        self.num_samples = num_samples
        self.save_and_sample_every = save_and_sample_every

        self.batch_size = train_batch_size
        self.gradient_accumulate_every = gradient_accumulate_every

        self.train_num_steps = train_num_steps
        self.image_size = diffusion_model.image_size

        # dataset and dataloader
        # NOTE(review): in this excerpt `Dataset` resolves to the abstract
        # torch.utils.data.Dataset; the original module presumably defines a
        # folder-backed Dataset class with this signature -- confirm.
        self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip, convert_image_to = convert_image_to)
        dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())

        # let the accelerator shard/move the dataloader, then loop it forever
        dl = self.accelerator.prepare(dl)
        self.dl = cycle(dl)

        # optimizer
        self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)

        # EMA copy and results folder exist on the main process only
        if self.accelerator.is_main_process:
            self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)

            self.results_folder = Path(results_folder)
            self.results_folder.mkdir(exist_ok = True)

        # global step counter
        self.step = 0

        # wrap model and optimizer with the accelerator
        self.model, self.opt = self.accelerator.prepare(self.model, self.opt)

    # save a training checkpoint
    def save(self, milestone):
        """Write a checkpoint (step, model, optimizer, EMA, grad-scaler state) to results_folder.

        NOTE(review): gates on is_local_main_process, but self.ema and
        self.results_folder are only created when is_main_process -- on a
        multi-node run, a non-main node's local main process would hit
        AttributeError here; confirm single-node usage.
        """
        if not self.accelerator.is_local_main_process:
            return

        data = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.model),
            'opt': self.opt.state_dict(),
            'ema': self.ema.state_dict(),
            'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None
        }

        torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))

    # 加载模型
    def load(self, milestone):
        """Restore training state from results_folder / f'model-{milestone}.pt'."""
        data = torch.load(str(self.results_folder / f'model-{milestone}.pt'))

        # load weights into the unwrapped (non-accelerator) model
        model = self.accelerator.unwrap_model(self.model)
        model.load_state_dict(data['model'])

        self.step = data['step']
        self.opt.load_state_dict(data['opt'])
        self.ema.load_state_dict(data['ema'])

        # restore the grad scaler only when both sides have one
        if exists(self.accelerator.scaler) and exists(data['scaler']):
            self.accelerator.scaler.load_state_dict(data['scaler'])
    # 定义训练方法
    def train(self):
        """Main training loop: gradient accumulation, EMA updates, periodic sampling and checkpointing."""
        accelerator = self.accelerator
        device = accelerator.device

        # progress bar shown on the main process only
        with tqdm(initial=self.step, total=self.train_num_steps, disable=not accelerator.is_main_process) as pbar:

            while self.step < self.train_num_steps:

                total_loss = 0.

                # accumulate gradients over several micro-batches
                for _ in range(self.gradient_accumulate_every):
                    data = next(self.dl).to(device)

                    with self.accelerator.autocast():
                        loss = self.model(data)
                        # scale the loss so accumulated grads average out
                        loss = loss / self.gradient_accumulate_every
                        total_loss += loss.item()

                    self.accelerator.backward(loss)

                pbar.set_description(f'loss: {total_loss:.4f}')

                # sync all processes around the optimizer step
                accelerator.wait_for_everyone()

                self.opt.step()
                self.opt.zero_grad()

                accelerator.wait_for_everyone()

                if accelerator.is_main_process:
                    # keep the EMA weights up to date on the main process
                    self.ema.to(device)
                    self.ema.update()

                    # periodically sample images and checkpoint
                    if self.step != 0 and self.step % self.save_and_sample_every == 0:
                        self.ema.ema_model.eval()

                        with torch.no_grad():
                            milestone = self.step // self.save_and_sample_every
                            # split num_samples into batch-sized groups
                            batches = num_to_groups(self.num_samples, self.batch_size)
                            all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))

                        # save the sampled images as one square grid
                        all_images = torch.cat(all_images_list, dim=0)
                        utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow=int(math.sqrt(self.num_samples)))
                        self.save(milestone)

                self.step += 1
                pbar.update(1)

        accelerator.print('training complete')