
Lucidrains Series Projects Source Code Analysis (80)

.\lucidrains\q-transformer\q_transformer\__init__.py

# import the QRoboticTransformer and MaxViT classes from the q_transformer.q_robotic_transformer module
from q_transformer.q_robotic_transformer import (
    QRoboticTransformer,
    MaxViT
)

# import the QLearner class from the q_transformer.q_learner module
from q_transformer.q_learner import (
    QLearner
)

# import the Agent, ReplayMemoryDataset and BaseEnvironment classes from the q_transformer.agent module
from q_transformer.agent import (
    Agent,
    ReplayMemoryDataset,
    BaseEnvironment
)

Q-transformer

Implementation of Q-Transformer, Scalable Offline Reinforcement Learning via Autoregressive Q-Functions, out of Google Deepmind

I will be keeping around the logic for Q-learning on single action just for final comparison with the proposed autoregressive Q-learning on multiple actions. Also to serve as education for myself and the public.

Install

$ pip install q-transformer

Usage

import torch

from q_transformer import (
    QRoboticTransformer,
    QLearner,
    Agent,
    ReplayMemoryDataset
)

# the attention model

model = QRoboticTransformer(
    vit = dict(
        num_classes = 1000,
        dim_conv_stem = 64,
        dim = 64,
        dim_head = 64,
        depth = (2, 2, 5, 2),
        window_size = 7,
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        dropout = 0.1
    ),
    num_actions = 8,
    action_bins = 256,
    depth = 1,
    heads = 8,
    dim_head = 64,
    cond_drop_prob = 0.2,
    dueling = True
)

# you need to supply your own environment, by overriding BaseEnvironment

from q_transformer.mocks import MockEnvironment

env = MockEnvironment(
    state_shape = (3, 6, 224, 224),
    text_embed_shape = (768,)
)

# env.init()     should return instructions and initial state: Tuple[str, Tensor[*state_shape]]
# env(actions)   should return rewards, next state, and done flag: Tuple[Tensor[()], Tensor[*state_shape], Tensor[()]]
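
# a minimal sketch of a custom environment, assuming the interface implied by the two
# comments above - the exact hooks on q_transformer.agent.BaseEnvironment may differ

from q_transformer import BaseEnvironment

class MyEnvironment(BaseEnvironment):
    def init(self):
        # return the language instruction and the initial state tensor
        instruction = 'pick up the apple'
        state = torch.randn(3, 6, 224, 224)
        return instruction, state

    def forward(self, actions):
        # return the reward, the next state, and the done flag
        reward = torch.tensor(0.)
        next_state = torch.randn(3, 6, 224, 224)
        done = torch.tensor(False)
        return reward, next_state, done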

# agent is a class that allows the q-model to interact with the environment to generate a replay memory dataset for learning

agent = Agent(
    model,
    environment = env,
    num_episodes = 1000,
    max_num_steps_per_episode = 100,
)

agent()

# Q learning on the replay memory dataset on the model

q_learner = QLearner(
    model,
    dataset = ReplayMemoryDataset(),
    num_train_steps = 10000,
    learning_rate = 3e-4,
    batch_size = 4,
    grad_accum_every = 16,
)

q_learner()

# after much learning
# your robot should be better at selecting optimal actions

video = torch.randn(2, 3, 6, 224, 224)

instructions = [
    'bring me that apple sitting on the table',
    'please pass the butter'
]

actions = model.get_optimal_actions(video, instructions)
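
# assuming get_optimal_actions returns one discrete bin index per action dimension,
# actions here would have shape (2, 8), with values in [0, 256)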

Appreciation

Todo

  • first work way towards single action support

  • offer batchnorm-less variant of maxvit, as done in SOTA weather model metnet3

  • add optional deep dueling architecture

  • add n-step Q learning

  • build the conservative regularization

  • build out main proposal in paper (autoregressive discrete actions until last action, reward given only on last)

  • improvise decoder head variant, instead of concatenating previous actions at the frames + learned tokens stage. in other words, use classic encoder - decoder

    • allow for cross attention to fine frame / learned tokens
  • redo maxvit with axial rotary embeddings + sigmoid gating for attending to nothing. enable flash attention for maxvit with this change

  • build out a simple dataset creator class, taking in the environment and model and returning a folder that can be accepted by a ReplayDataset

    • finish basic environment loop
    • store memories to memmapped files in designated folder
    • ReplayDataset that takes in folder
      • 1 time step option
      • n-time steps
  • handle multiple instructions correctly

  • show a simple end-to-end example, in the same style as all other repos

  • handle no instructions, leverage null conditioner in CFG library

  • cache kv for action decoding

  • for exploration, allow for finely randomizing a subset of actions, and not all actions at once

    • also allow for gumbel based sampling of actions, with annealing of gumbel noise
  • consult some RL experts and figure out if there are any new headways into resolving delusional bias

  • figure out if one can train with randomized orders of actions - order could be sent as a conditioning that is concatted or summed before attention layers

    • offer an improvised variant where the first action token suggests the action ordering. all actions aren't made equal, and some may need to attend to past actions more than others
  • simple beam search function for optimal actions

  • improvise cross attention to past actions and states of timestep, transformer-xl fashion (w/ structured memory dropout)

  • see if the main idea in this paper is applicable to language models here

Citations

@inproceedings{qtransformer,
    title   = {Q-Transformer: Scalable Offline Reinforcement Learning via Autoregressive Q-Functions},
    authors = {Yevgen Chebotar and Quan Vuong and Alex Irpan and Karol Hausman and Fei Xia and Yao Lu and Aviral Kumar and Tianhe Yu and Alexander Herzog and Karl Pertsch and Keerthana Gopalakrishnan and Julian Ibarz and Ofir Nachum and Sumedh Sontakke and Grecia Salazar and Huong T Tran and Jodilyn Peralta and Clayton Tan and Deeksha Manjunath and Jaspiar Singht and Brianna Zitkovich and Tomas Jackson and Kanishka Rao and Chelsea Finn and Sergey Levine},
    booktitle = {7th Annual Conference on Robot Learning},
    year   = {2023}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}

.\lucidrains\q-transformer\setup.py

# import the setup and find_packages functions
from setuptools import setup, find_packages

# set up the package metadata
setup(
  name = 'q-transformer', # package name
  packages = find_packages(exclude=[]), # find all packages
  version = '0.1.14', # version number
  license='MIT', # license
  description = 'Q-Transformer', # description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author email
  long_description_content_type = 'text/markdown', # long description content type
  url = 'https://github.com/lucidrains/q-transformer', # repository URL
  keywords = [ # keywords
    'artificial intelligence',
    'deep learning',
    'attention mechanisms',
    'transformers',
    'q-learning'
  ],
  install_requires=[ # install dependencies
    'accelerate',
    'beartype',
    'classifier-free-guidance-pytorch>=0.4.2',
    'einops>=0.7.0',
    'ema-pytorch>=0.3.1',
    'numpy',
    'torchtyping',
    'torch>=2.0'
  ],
  classifiers=[ # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\quartic-transformer\quartic_transformer\multi_stream_transformer.py

        """
        实现注意力机制的模块
        参数:
            dim - 输入特征的维度
            num_streams - 流的数量
            dim_head - 每个头的维度
            heads - 头的数量
            dropout - 丢弃率
            causal - 是否使用因果注意力
            pre_talking_heads - 是否使用预对话头
            post_talking_heads - 是否使用后对话头
            non_linear_talking_heads - 是否使用非线性对话头
        """
        super().__init__()
        dim_inner = dim_head * heads
        all_heads = num_streams * heads

        self.num_streams = num_streams

        # project the input to queries, keys, values
        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias = False),
            Rearrange('b n (qkv h d) -> qkv b h n d', h = heads, qkv = 3)
        )

        # per-head output gates
        self.to_gates = nn.Sequential(
            nn.Linear(dim, heads),
            Rearrange('b n h -> b h n 1'),
            nn.Sigmoid()
        )

        # rmsnorm
        self.rmsnorm = einn.Norm('b... [d]', mean = False, bias = False)

        self.scale = dim_head ** 0.5
        self.causal = causal
        self.dropout = nn.Dropout(dropout)

        self.pre_talking_heads = None
        self.post_talking_heads = None

        # optionally use non-linear (feedforward) talking heads
        if non_linear_talking_heads:
            self.pre_talking_heads = TalkingHeadsFeedForward(all_heads) if pre_talking_heads else None
            self.post_talking_heads = TalkingHeadsFeedForward(all_heads) if post_talking_heads else None
        else:
            # otherwise use 1x1 convolution talking heads
            self.pre_talking_heads = nn.Conv2d(all_heads, all_heads, 1, bias = False) if pre_talking_heads else None
            self.post_talking_heads = nn.Conv2d(all_heads, all_heads, 1, bias = False) if post_talking_heads else None

            # initialize the talking heads convolutions to identity
            nn.init.dirac_(self.pre_talking_heads.weight)
            nn.init.dirac_(self.post_talking_heads.weight)

        # output projection
        self.to_out = nn.Sequential(
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        mask = None
    ):
        # number of streams
        s = self.num_streams
        # rmsnorm the input
        x = self.rmsnorm(x)

        # project the input to queries, keys, values
        q, k, v = self.to_qkv(x)

        # scale the queries
        q = q * self.scale
        # compute the attention similarity matrix
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # value used for masking
        mask_value = -torch.finfo(sim.dtype).max

        # pre-softmax talking heads, mixed across streams
        if exists(self.pre_talking_heads):
            sim = rearrange(sim, '(b s) h n d -> b (s h) n d', s = s)
            sim = self.pre_talking_heads(sim)
            sim = rearrange(sim, 'b (s h) n d -> (b s) h n d', s = s)

        # key padding mask
        if exists(mask):
            sim = einx.where('b j, b ... j, ', mask, sim, mask_value)

        # causal mask
        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, mask_value)

        # attention
        attn = einx.softmax('b h i [j]', sim)

        # keep the post-softmax attention matrix around
        post_softmax_attn = attn

        # dropout on the attention matrix
        attn = self.dropout(attn)

        # post-softmax talking heads, mixed across streams
        if exists(self.post_talking_heads):
            attn = rearrange(attn, '(b s) h n d -> b (s h) n d', s = s)
            attn = self.post_talking_heads(attn)
            attn = rearrange(attn, 'b (s h) n d -> (b s) h n d', s = s)

        # aggregate the values
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # gate the output
        out = out * self.to_gates(x)
        # output projection
        out = self.to_out(out)

        # return the output and the post-softmax attention matrix
        return out, post_softmax_attn

# feedforward
def FeedForward(dim, mult = 4, dropout = 0.):
    # inner dimension
    dim_inner = int(dim * mult)
    return nn.Sequential(
        # rmsnorm on the input
        einn.Norm('b... [d]', mean = False, bias = False),
        # project up to the inner dimension
        nn.Linear(dim, dim_inner, bias = False),
        # GELU activation
        nn.GELU(),
        # dropout
        nn.Dropout(dropout),
        # project back down to the model dimension
        nn.Linear(dim_inner, dim, bias = False)
    )

# talking-heads feedforward, operating across the head (channel) dimension of the attention maps
def TalkingHeadsFeedForward(dim, mult = 2, dropout = 0.):
    # inner dimension
    dim_inner = int(dim * mult)
    net = nn.Sequential(
        # norm over the head (channel) dimension
        einn.Norm('b [c] ...', mean = False, bias = False),
        # 1x1 convolution projecting up to the inner dimension
        nn.Conv2d(dim, dim_inner, 1, bias = False),
        # GELU activation
        nn.GELU(),
        # dropout
        nn.Dropout(dropout),
        # 1x1 convolution projecting back down
        nn.Conv2d(dim_inner, dim, 1, bias = False)
    )

    # zero-initialize the last layer's weights
    nn.init.zeros_(net[-1].weight)
    # wrap in a residual connection
    return Residual(net)

# TokenAndPosEmb - token and positional embeddings shared across streams
class TokenAndPosEmb(Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        max_seq_len,
        num_streams
    ):
        super().__init__()
        # token embedding
        self.token_emb = nn.Embedding(num_tokens, dim)
        # positional embedding
        self.pos_emb = nn.Embedding(max_seq_len, dim)
        # per-stream embedding parameters
        self.stream_emb = nn.Parameter(torch.zeros(num_streams, dim))
        # initialize the stream embeddings
        nn.init.normal_(self.stream_emb, std = 0.02)

    def forward(self, x):
        # positions along the sequence
        seq_len = torch.arange(x.shape[-1], device = x.device)
        # token embeddings
        token_emb = self.token_emb(x)
        # positional embeddings
        pos_emb = self.pos_emb(seq_len)
        # sum token, positional and stream embeddings, folding the streams into the batch
        return einx.add('b n d, n d, s d -> (b s) n d', token_emb, pos_emb, self.stream_emb)

# SeparateTokenAndPosEmb - separate token and positional embeddings per stream
class SeparateTokenAndPosEmb(Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        max_seq_len,
        num_streams
    ):
        super().__init__()
        # per-stream token embedding parameters
        self.token_emb = nn.Parameter(torch.zeros(num_streams, num_tokens, dim))
        # per-stream positional embedding parameters
        self.pos_emb = nn.Parameter(torch.zeros(num_streams, max_seq_len, dim))
        # initialize token and positional embedding parameters
        nn.init.normal_(self.token_emb, std = 0.02)
        nn.init.normal_(self.pos_emb, std = 0.02)

    def forward(self, x):
        # positions along the sequence
        seq_len = torch.arange(x.shape[-1], device = x.device)
        # token embeddings, gathered per stream
        token_emb = get_at('s [e] d, b n -> b s n d', self.token_emb, x)
        # positional embeddings, gathered per stream
        pos_emb = get_at('s [e] d, n -> s n d', self.pos_emb, seq_len)
        # sum the token and positional embeddings, folding the streams into the batch
        return einx.add('b s n d, s n d -> (b s) n d', token_emb, pos_emb)

# MultiStreamTransformer - multi-stream transformer model
class MultiStreamTransformer(Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        depth,
        num_streams = 2,
        dim_head = 64,
        heads = 8,
        max_seq_len = 2048,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_mult = 4.,
        ablate_cross_stream_talking_heads = False,
        pre_talking_heads = True,
        post_talking_heads = True,
        separate_stream_emb = True,
        non_linear_talking_heads = False
    ):
        # call the parent constructor
        super().__init__()
        # choose the embedding class depending on whether stream embeddings are separate
        embed_klass = SeparateTokenAndPosEmb if separate_stream_emb else TokenAndPosEmb

        # embedding layer
        self.emb = embed_klass(
            dim = dim,
            num_tokens = num_tokens,
            num_streams = num_streams,
            max_seq_len = max_seq_len
        )

        # number of streams
        self.num_streams = num_streams
        # layers
        self.layers = ModuleList([])

        # number of streams for the talking heads, depending on whether cross-stream talking heads are ablated
        talking_heads_num_streams = 2 if not ablate_cross_stream_talking_heads else 1

        # build one attention / feedforward pair per layer of depth
        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(
                    dim = dim,
                    dim_head = dim_head,
                    heads = heads,
                    dropout = attn_dropout,
                    num_streams = talking_heads_num_streams,
                    pre_talking_heads = pre_talking_heads,
                    post_talking_heads = post_talking_heads,
                    non_linear_talking_heads = non_linear_talking_heads
                ),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            ]))

        # projection to logits, summing over the streams
        self.to_logits = nn.Sequential(
            Reduce('(b s) n d -> b n d', 'sum', s = num_streams),
            einn.Norm('b... [d]', mean = False, bias = False),
            nn.Linear(dim, num_tokens, bias = False)
        )

    def forward(
        self,
        x,
        mask = None,
        stream_attn_diversity_loss = False
    ):
        # shape and device of the input
        b, n, s, device = *x.shape, self.num_streams, x.device

        # only compute the stream attention diversity loss when there is more than one stream
        stream_attn_diversity_loss &= s > 1

        # embed the input
        x = self.emb(x)

        # collect the attention matrices from every layer
        attn_matrices = []

        # go through each attention / feedforward pair
        for attn, ff in self.layers:
            # attention output and the post-softmax attention matrix
            attn_out, post_softmax_attn = attn(x, mask = mask)

            # store the post-softmax attention matrix
            attn_matrices.append(post_softmax_attn)

            # residual updates
            x = x + attn_out
            x = ff(x) + x

        # if requested, compute the auxiliary stream attention diversity loss
        if stream_attn_diversity_loss:
            aux_loss = sum([calc_stream_loss(attn_matrix, s).mean() for attn_matrix in attn_matrices])

        # final logits
        logits = self.to_logits(x)

        # if the diversity loss is not requested, just return the logits
        if not stream_attn_diversity_loss:
            return logits

        # otherwise return the logits together with the auxiliary loss
        return logits, aux_loss
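
A minimal usage sketch for the class above, inferred only from the constructor and forward signatures shown in this excerpt (the file's imports and helpers such as calc_stream_loss are omitted above; the argument values below are illustrative):

import torch

model = MultiStreamTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 2,
    num_streams = 2
)

tokens = torch.randint(0, 256, (1, 128))

logits = model(tokens)  # (1, 128, 256)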

.\lucidrains\quartic-transformer\quartic_transformer\quartic_transformer.py

# import the torch library
import torch
# import nn and einsum from torch
from torch import nn, einsum
# import the Module and ModuleList classes from torch.nn
from torch.nn import Module, ModuleList

# import rearrange, repeat, pack, unpack from einops
from einops import rearrange, repeat, pack, unpack
# import the Rearrange layer from einops.layers.torch
from einops.layers.torch import Rearrange

# import the einx library
import einx
# import the einn module from einx.nn.torch
from einx.nn import torch as einn

# import the topk function from colt5_attention
from colt5_attention import topk

# import the TaylorSeriesLinearAttn class from taylor_series_linear_attention
from taylor_series_linear_attention import TaylorSeriesLinearAttn

# import the DynamicPositionBias class from x_transformers.x_transformers
from x_transformers.x_transformers import DynamicPositionBias

# helper functions

# check whether a variable exists
def exists(v):
    return v is not None

# return the value if it exists, otherwise the default
def default(v, d):
    return v if exists(v) else d

# pack a single tensor according to a pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# unpack a single tensor according to a pattern
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# attention

class Attention(Module):
    def __init__(
        self,
        dim,
        dim_edges = None,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        causal = False,
        incorporate_edges = True
    ):
        super().__init__()
        dim_edges = default(dim_edges, dim)
        dim_inner = dim_head * heads

        # project the input to queries, keys, values and split heads
        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias = False),
            Rearrange('b n (qkv h d) -> qkv b h n d', h = heads, qkv = 3)
        )

        # per-head output gates (linear layer followed by sigmoid)
        self.to_gates = nn.Sequential(
            nn.Linear(dim, heads),
            Rearrange('b n h -> b h n 1'),
            nn.Sigmoid()
        )

        # rmsnorm
        self.rmsnorm = einn.Norm('b... [d]', mean = False, bias = False)

        self.scale = dim_head ** 0.5
        self.causal = causal
        self.dropout = nn.Dropout(dropout)

        self.edges_to_attn_bias = None

        if incorporate_edges:
            # project the edge features to a per-head attention bias
            self.edges_to_attn_bias = nn.Sequential(
                einn.Norm('b... [d]', mean = False, bias = False),
                nn.Linear(dim_edges, heads),
                Rearrange('b i j h -> b h i j')
            )

        # pre-softmax talking heads (1x1 convolution across heads)
        self.pre_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)

        self.to_edges_out = None

        if incorporate_edges:
            # project the attention maps back out to edge features
            self.to_edges_out = nn.Sequential(
                nn.Conv2d(heads, dim_edges, 1, bias = False),
                Rearrange('b d i j -> b i j d')
            )

        # output projection
        self.to_out = nn.Sequential(
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        mask = None,
        edges = None
    ):
        x = self.rmsnorm(x)

        q, k, v = self.to_qkv(x)

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        mask_value = -torch.finfo(sim.dtype).max

        if exists(edges) and exists(self.edges_to_attn_bias):
            attn_bias = self.edges_to_attn_bias(edges)
            sim = sim + attn_bias

        sim = self.pre_talking_heads(sim)

        if exists(mask):
            sim = einx.where('b j, b ... j, ', mask, sim, mask_value)

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, mask_value)

        attn = einx.softmax('b h i [j]', sim)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = out * self.to_gates(x)
        out = self.to_out(out)

        edges_out = None
        if exists(self.to_edges_out):
            edges_out = self.to_edges_out(attn)

        if not exists(edges_out):
            return out

        return out, edges_out

# feedforward

def FeedForward(dim, mult = 4, dropout = 0.):
    dim_inner = int(dim * mult)
    return nn.Sequential(
        einn.Norm('b... [d]', mean = False, bias = False),
        nn.Linear(dim, dim_inner, bias = False),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim_inner, dim, bias = False)
    )

# edge embedding

class EdgeEmbed(Module):
    # takes the node dimension and an optional edge dimension
    def __init__(self, dim, dim_edges = None):
        super().__init__()
        # default the edge dimension to the node dimension
        dim_edges = default(dim_edges, dim)
        # linear projection to row features, no bias
        self.to_rows = nn.Linear(dim, dim_edges, bias = False)
        # linear projection to column features, no bias
        self.to_cols = nn.Linear(dim, dim_edges, bias = False)

        # linear layer followed by LayerNorm to produce the edge features
        self.to_edges = nn.Sequential(
            nn.Linear(dim_edges, dim_edges, bias = False),
            nn.LayerNorm(dim_edges)
        )

    def forward(self, x):
        # project the input to row features
        rows = self.to_rows(x)
        # project the input to column features
        cols = self.to_cols(x)
        # outer sum of rows and columns, yielding one feature vector per (i, j) pair
        outer_sum = einx.add('b i d, b j d -> b i j d', rows, cols)
        # project the outer sum to the final edge features
        return self.to_edges(outer_sum)
# AxialLinearAttention - axial linear attention over the edge grid
class AxialLinearAttention(Module):
    def __init__(
        self,
        dim,
        diagonal_attn = True,
        **attn_kwargs
    ):
        super().__init__()

        # row-wise linear attention
        self.row_attn = TaylorSeriesLinearAttn(dim = dim, gate_value_heads = True, prenorm = True, **attn_kwargs)
        # column-wise linear attention
        self.col_attn = TaylorSeriesLinearAttn(dim = dim, gate_value_heads = True, prenorm = True, **attn_kwargs)

        # optional full attention over the diagonal entries
        self.diagonal_attn = Attention(dim = dim, incorporate_edges = False, **attn_kwargs) if diagonal_attn else None

    # forward
    def forward(
        self,
        x,
        mask = None
    ):
        # shape and device of the edge grid
        b, n, device = *x.shape[:2], x.device

        # fold the rows into the batch
        x = rearrange(x, 'b i j d -> (b i) j d')

        # attend along the rows and add the residual
        x = self.row_attn(x, mask = mask) + x

        # refold so that the columns become the sequence
        x = rearrange(x, '(b i) j d -> (b j) i d', b = b)

        # attend along the columns and add the residual
        x = self.col_attn(x, mask = mask) + x

        # restore the edge grid shape
        x = rearrange(x, '(b j) i d -> b i j d', b = b)

        # if there is no diagonal attention, return as is
        if not exists(self.diagonal_attn):
            return x

        # boolean mask selecting the diagonal of the edge grid
        diagonal_mask = torch.eye(n, dtype = torch.bool, device = device)
        diagonal_mask = rearrange(diagonal_mask, 'i j -> 1 i j')

        # extract the diagonal entries into their own tensor (keeping the full grid in x)
        diagonal = rearrange(x[diagonal_mask], '(b n) d -> b n d', b = b)

        # attend over the diagonal entries
        diagonal = self.diagonal_attn(diagonal) + diagonal

        # scatter the updated diagonal entries back into the edge grid
        diagonal_mask = rearrange(diagonal_mask, '... -> ... 1')
        x = x.masked_scatter(diagonal_mask, diagonal)
        return x

# QuarticTransformer - attention over both nodes (tokens) and edges (token pairs)
class QuarticTransformer(Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        dim_edges = None,
        dim_head = 64,
        heads = 8,
        causal = False,
        linear_dim_head = 16,
        linear_heads = 16,
        ff_mult = 4,
        dropout = 0.,
        max_seq_len = 2048,
        ablate_edges = False,
        edges_diagonal_attn = True
    ):
        super().__init__()
        dim_edges = default(dim_edges, dim)

        # settings
        self.ablate_edges = ablate_edges
        self.max_seq_len = max_seq_len

        # token and positional embeddings
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # dynamic relative position bias for the edges
        self.dynamic_rel_pos_bias = DynamicPositionBias(dim, depth = 2, heads = dim_edges)

        # edge embeddings
        self.to_edge_emb = EdgeEmbed(dim, dim_edges)

        # layers
        self.layers = ModuleList([])
        for _ in range(depth):
            self.layers.append(ModuleList([
                ModuleList([
                    Attention(dim = dim, dim_edges = dim_edges, dim_head = dim_head, heads = heads, dropout = dropout, causal = causal),
                    FeedForward(dim = dim, mult = ff_mult, dropout = dropout)
                ]),
                ModuleList([
                    AxialLinearAttention(dim = dim_edges, dim_head = linear_dim_head, heads = linear_heads, causal = causal, diagonal_attn = edges_diagonal_attn),
                    FeedForward(dim = dim_edges, mult = ff_mult)
                ])
            ]))

        # projection to logits
        self.to_logits = nn.Sequential(
            einn.Norm('b... [d]', mean = False, bias = False),
            nn.Linear(dim, num_tokens, bias = False)
        )

    # forward
    def forward(
        self,
        x,
        mask = None
        ):
        # sequence length and device
        seq_len, device = x.shape[-1], x.device
        # the sequence length must not exceed the maximum
        assert seq_len <= self.max_seq_len

        # token embedding
        x = self.token_emb(x)

        # add the positional embedding
        x = x + self.pos_emb(torch.arange(seq_len, device = device))
        # edge embeddings
        edges = self.to_edge_emb(x)

        # dynamic relative position bias
        edges_rel_pos = self.dynamic_rel_pos_bias(seq_len, seq_len)
        # add the relative position bias to the edge embeddings
        edges = einx.add('b i j d, d i j -> b i j d', edges, edges_rel_pos)

        # edge mask
        edges_mask = None
        # if a mask is given, derive the edge mask from it
        if exists(mask):
            edges_mask = einx.logical_and('b i, b j -> b (i j)', mask, mask)

        # go through the layers
        for (attn, ff), (edges_linear_attn, edges_ff,) in self.layers:

            # attention and feedforward over the nodes, also producing edge updates
            nodes_out, edges_out = attn(x, mask = mask, edges = edges if not self.ablate_edges else None)

            # update the node representations
            x = x + nodes_out
            x = ff(x) + x

            # if edges are ablated, skip the edge updates
            if self.ablate_edges:
                continue

            # update the edge representations
            edges = edges + edges_out

            # axial linear attention over the edges
            edges = edges_linear_attn(edges, mask = mask) + edges

            # feedforward over the edges
            edges = edges_ff(edges) + edges

        # project to logits
        return self.to_logits(x)

.\lucidrains\quartic-transformer\quartic_transformer\__init__.py

# import the QuarticTransformer class from the quartic_transformer package
from quartic_transformer.quartic_transformer import QuarticTransformer

# import the MultiStreamTransformer class from the quartic_transformer package
from quartic_transformer.multi_stream_transformer import MultiStreamTransformer

Quartic Transformer (wip)

Exploring an idea where one forgets about efficiency and carries out attention on each edge of the nodes (tokens). You can think of it as doing attention on the attention matrix, taking the perspective of the attention matrix as all the directed edges of a fully connected graph.

The hypothesis is that there is a task out there that the (sub)quartic transformer can do that quadratic transformers cannot.

Will also contain a modified implementation of the multi-stream transformer (which is not quartic, but rather the number of streams times quadratic).
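
As a rough sketch of where the quartic cost comes from (a toy illustration, not the repository's implementation): n tokens induce n x n directed edges, and full attention over those edges compares every edge against every other edge, on the order of (n^2)^2 = n^4 pairs.

import torch

n, d = 8, 16
nodes = torch.randn(1, n, d)

# one feature vector per directed edge (i -> j), via a simple outer sum,
# analogous to the EdgeEmbed module in quartic_transformer.py below
edges = nodes.unsqueeze(2) + nodes.unsqueeze(1)   # (1, n, n, d)
edges = edges.reshape(1, n * n, d)                # (1, n^2, d)

# naive full attention over all edges: the similarity matrix alone is n^2 x n^2
sim = torch.einsum('b i d, b j d -> b i j', edges, edges)
print(sim.shape)  # torch.Size([1, 64, 64])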

Appreciation

Install

$ pip install quartic-transformer

Usage

import torch
from quartic_transformer import QuarticTransformer

model = QuarticTransformer(
    num_tokens = 256,
    depth = 2,
    dim = 512,
    dim_edges = 32
)

tokens = torch.randint(0, 256, (1, 128))

logits = model(tokens) # (1, 128, 256)

Todo

  • first add a weak taylor linear attention on top of all edges

  • use coordinate descent routing from the node attention matrix to select a subset of edges to update (and do full attention across)

  • build multi-stream transformer, but allow exchange of information at the attention matrix, either through residual attention or a small edge-wise feedforward

Citation

@inproceedings{Keles2022OnTC,
    title   = {On The Computational Complexity of Self-Attention},
    author  = {Feyza Duman Keles and Pruthuvi Maheshakya Wijewardena and Chinmay Hegde},
    booktitle = {International Conference on Algorithmic Learning Theory},
    year    = {2022},
    url     = {https://api.semanticscholar.org/CorpusID:252198880}
}
@article{Burtsev2021MultiStreamT,
    title   = {Multi-Stream Transformers},
    author  = {Mikhail S. Burtsev and Anna Rumshisky},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2107.10342},
    url     = {https://api.semanticscholar.org/CorpusID:236171087}
}
@misc{Sutton,
    title  = {The Bitter Lesson},
    url    = {http://www.incompleteideas.net/IncIdeas/BitterLesson.html},
    author = {Sutton, Rich}
}
@article{Shazeer2020TalkingHeadsA,
    title   = {Talking-Heads Attention},
    author  = {Noam M. Shazeer and Zhenzhong Lan and Youlong Cheng and Nan Ding and Le Hou},
    journal = {ArXiv},
    year    = {2020},
    volume  = {abs/2003.02436},
    url     = {https://api.semanticscholar.org/CorpusID:212414717}
}

.\lucidrains\quartic-transformer\setup.py

# import the setup and find_packages functions
from setuptools import setup, find_packages

# set up the package metadata
setup(
  name = 'quartic-transformer', # package name
  packages = find_packages(exclude=[]), # find all packages
  version = '0.0.12', # version number
  license='MIT', # license
  description = 'Quartic Transformer', # description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author email
  long_description_content_type = 'text/markdown', # long description content type
  url = 'https://github.com/lucidrains/quartic-transformer', # repository URL
  keywords = [ # keywords
    'artificial intelligence',
    'deep learning',
    'transformer',
    'attention'
  ],
  install_requires=[ # install dependencies
    'colt5-attention',
    'einops>=0.7.0',
    'einx[torch]>=0.1.3',
    'taylor-series-linear-attention',
    'torch>=2.0',
    'x-transformers'
  ],
  classifiers=[ # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Recurrent Interface Network (RIN) - Pytorch

Implementation of Recurrent Interface Network (RIN), for highly efficient generation of images and video without cascading networks, in Pytorch. The author unknowingly reinvented the induced set-attention block from the set transformers paper. They also combine this with the self-conditioning technique from the Bit Diffusion paper, specifically for the latents. The last ingredient seems to be a new noise function based around the sigmoid, which the author claims is better than the cosine schedule for larger images.

The big surprise is that the generations can reach this level of fidelity. Will need to verify this on my own machine

Additionally, we will try adding an extra linear attention on the main branch as well as self conditioning in the pixel-space.

The insight of being able to self-condition on any hidden state of the network as well as the newly proposed sigmoid noise schedule are the two main findings.

This repository also contains the ability to noise higher resolution images more, using the scale keyword argument on the GaussianDiffusion class. It also contains the simple linear gamma schedule proposed in that paper.
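
For reference, below is a stripped-down version of the sigmoid gamma schedule mentioned above, mirroring the sigmoid_schedule helper that appears later in rin_pytorch.py (gamma runs from 1 at t = 0, a clean image, down to 0 at t = 1, pure noise):

import torch

def sigmoid_schedule(t, start = -3, end = 3, tau = 1.):
    # gamma(t) in [0, 1], decreasing in t
    v_start = torch.tensor(start / tau).sigmoid()
    v_end = torch.tensor(end / tau).sigmoid()
    gamma = (v_end - ((t * (end - start) + start) / tau).sigmoid()) / (v_end - v_start)
    return gamma.clamp(min = 1e-9, max = 1.)

t = torch.linspace(0., 1., 5)
print(sigmoid_schedule(t))  # decays smoothly from ~1 to ~0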

Appreciation

  • Stability.ai for the generous sponsorship to work on cutting edge artificial intelligence research

Install

$ pip install rin-pytorch

Usage

from rin_pytorch import GaussianDiffusion, RIN, Trainer

model = RIN(
    dim = 256,                  # model dimensions
    image_size = 128,           # image size
    patch_size = 8,             # patch size
    depth = 6,                  # depth
    num_latents = 128,          # number of latents. they used 256 in the paper
    dim_latent = 512,           # can be greater than the image dimension (dim) for greater capacity
    latent_self_attn_depth = 4, # number of latent self attention blocks per recurrent step, K in the paper
).cuda()

diffusion = GaussianDiffusion(
    model,
    timesteps = 400,
    train_prob_self_cond = 0.9,  # how often to self condition on latents
    scale = 1.                   # this will be set to < 1. for more noising and leads to better convergence when training on higher resolution images (512, 1024) - input noised images will be auto variance normalized
).cuda()

trainer = Trainer(
    diffusion,
    '/path/to/your/images',
    num_samples = 16,
    train_batch_size = 4,
    gradient_accumulate_every = 4,
    train_lr = 1e-4,
    save_and_sample_every = 1000,
    train_num_steps = 700000,         # total training steps
    ema_decay = 0.995,                # exponential moving average decay
)

trainer.train()

Results will be saved periodically to the ./results folder

If you would like to experiment with the RIN and GaussianDiffusion class outside the Trainer

import torch
from rin_pytorch import RIN, GaussianDiffusion

model = RIN(
    dim = 256,                  # model dimensions
    image_size = 128,           # image size
    patch_size = 8,             # patch size
    depth = 6,                  # depth
    num_latents = 128,          # number of latents. they used 256 in the paper
    latent_self_attn_depth = 4, # number of latent self attention blocks per recurrent step, K in the paper
).cuda()

diffusion = GaussianDiffusion(
    model,
    timesteps = 1000,
    train_prob_self_cond = 0.9,
    scale = 1.
)

training_images = torch.randn(8, 3, 128, 128).cuda() # images are normalized from 0 to 1
loss = diffusion(training_images)
loss.backward()
# after a lot of training

sampled_images = diffusion.sample(batch_size = 4)
sampled_images.shape # (4, 3, 128, 128)

Todo

Citations

@misc{jabri2022scalable,
    title   = {Scalable Adaptive Computation for Iterative Generation}, 
    author  = {Allan Jabri and David Fleet and Ting Chen},
    year    = {2022},
    eprint  = {2212.11972},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@inproceedings{Chen2023OnTI,
    title   = {On the Importance of Noise Scheduling for Diffusion Models},
    author  = {Ting Chen},
    year    = {2023}
}
@article{Salimans2022ProgressiveDF,
    title   = {Progressive Distillation for Fast Sampling of Diffusion Models},
    author  = {Tim Salimans and Jonathan Ho},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2202.00512}
}
@misc{https://doi.org/10.48550/arxiv.2302.01327,
    doi     = {10.48550/ARXIV.2302.01327},
    url     = {https://arxiv.org/abs/2302.01327},
    author  = {Kumar, Manoj and Dehghani, Mostafa and Houlsby, Neil},
    title   = {Dual PatchNorm},
    publisher = {arXiv},
    year    = {2023},
    copyright = {Creative Commons Attribution 4.0 International}
}
@inproceedings{Hang2023EfficientDT,
    title   = {Efficient Diffusion Training via Min-SNR Weighting Strategy},
    author  = {Tiankai Hang and Shuyang Gu and Chen Li and Jianmin Bao and Dong Chen and Han Hu and Xin Geng and Baining Guo},
    year    = {2023}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@inproceedings{Hoogeboom2023simpleDE,
    title   = {simple diffusion: End-to-end diffusion for high resolution images},
    author  = {Emiel Hoogeboom and Jonathan Heek and Tim Salimans},
    year    = {2023}
}

.\lucidrains\recurrent-interface-network-pytorch\rin_pytorch\attend.py

# import the required modules and classes
from functools import wraps
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, reduce

# named tuple FlashAttentionConfig, storing the flash attention kernel configuration
FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helper to check whether a value exists
def exists(val):
    return val is not None

# decorator that ensures the wrapped function only runs once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# print helper that only prints once
print_once = once(print)

# main Attend class
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu
        self.cpu_config = FlashAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = FlashAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = FlashAttentionConfig(False, True, True)

    # flash attention
    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # expand the mask, if given, to a compatible shape
        if exists(mask):
            mask = mask.expand(-1, heads, q_len, -1)

        # pick the attention config depending on whether the tensors are on cuda
        config = self.cuda_config if is_cuda else self.cpu_config

        # run scaled dot product attention under torch.backends.cuda.sdp_kernel()
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.
            )

        return out

    # forward
    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        scale = q.shape[-1] ** -0.5

        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, 'b j -> b 1 1 j')

        if self.flash:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity
        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale

        # key padding mask
        if exists(mask):
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # attention
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate the values
        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)

        return out
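
A minimal usage sketch for the Attend module above, with illustrative shapes (batch 2, 4 heads, sequence length 16, head dimension 32); flash is left off so the example does not require PyTorch 2.0:

import torch

attend = Attend(dropout = 0.1, flash = False)

q = torch.randn(2, 4, 16, 32)
k = torch.randn(2, 4, 16, 32)
v = torch.randn(2, 4, 16, 32)
mask = torch.ones(2, 16).bool()   # key padding mask, (batch, source length)

out = attend(q, k, v, mask = mask)
print(out.shape)  # torch.Size([2, 4, 16, 32])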

.\lucidrains\recurrent-interface-network-pytorch\rin_pytorch\rin_pytorch.py

import math
from pathlib import Path
from random import random
from functools import partial
from multiprocessing import cpu_count

import torch
from torch import nn, einsum
from torch.special import expm1
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torch.optim import Adam
from torchvision import transforms as T, utils

from beartype import beartype

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange

from rin_pytorch.attend import Attend

from PIL import Image
from tqdm.auto import tqdm
from ema_pytorch import EMA

from accelerate import Accelerator, DistributedDataParallelKwargs

# helper functions

# check whether a variable exists
def exists(x):
    return x is not None

# return the input unchanged
def identity(x):
    return x

# return the value if it exists, otherwise the default
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# check whether a number is divisible by another
def divisible_by(numer, denom):
    return (numer % denom) == 0

# division with the denominator clamped away from zero
def safe_div(numer, denom, eps = 1e-10):
    return numer / denom.clamp(min = eps)

# endlessly cycle over a dataloader
def cycle(dl):
    while True:
        for data in dl:
            yield data

# check whether a number has an integer square root
def has_int_squareroot(num):
    num_sqrt = math.sqrt(num)
    return int(num_sqrt) == num_sqrt

# split a number into groups of a given size
def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

# convert an image to the given mode
def convert_image_to(img_type, image):
    if image.mode != img_type:
        return image.convert(img_type)
    return image

# build an nn.Sequential, skipping any Nones
def Sequential(*mods):
    return nn.Sequential(*filter(exists, mods))

# use layernorm without bias, more stable

# custom LayerNorm (without bias)
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# multi-headed RMSNorm
class MultiHeadedRMSNorm(nn.Module):
    def __init__(self, dim, heads = 1):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * self.gamma

# positional embeds

# learned sinusoidal positional embedding
class LearnedSinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        fouriered = torch.cat((x, fouriered), dim = -1)
        return fouriered

# linear attention
class LinearAttention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 32,
        norm = False,
        qk_norm = False,
        time_cond_dim = None
    ):
        super().__init__()
        hidden_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.time_cond = None

        if exists(time_cond_dim):
            self.time_cond = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim * 2),
                Rearrange('b d -> b 1 d')
            )

            nn.init.zeros_(self.time_cond[-2].weight)
            nn.init.zeros_(self.time_cond[-2].bias)

        self.norm = LayerNorm(dim) if norm else nn.Identity()

        self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias = False)

        self.qk_norm = qk_norm
        if qk_norm:
            self.q_norm = MultiHeadedRMSNorm(dim_head, heads)
            self.k_norm = MultiHeadedRMSNorm(dim_head, heads)

        self.to_out = nn.Sequential(
            nn.Linear(hidden_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(
        self,
        x,
        time = None
        ):
        # number of attention heads
        h = self.heads
        # normalize the input
        x = self.norm(x)

        # time conditioning, if present
        if exists(self.time_cond):
            # time must be given
            assert exists(time)
            # scale and shift derived from the time condition
            scale, shift = self.time_cond(time).chunk(2, dim = -1)
            x = (x * (scale + 1)) + shift

        # project the input to queries, keys, values and split the heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        # optionally normalize queries and keys
        if self.qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        # softmax over queries and keys (linear attention)
        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        # scale the queries
        q = q * self.scale

        # compute the context
        context = torch.einsum('b h n d, b h n e -> b h d e', k, v)

        # compute the output
        out = torch.einsum('b h d e, b h n d -> b h n e', context, q)
        # merge the heads
        out = rearrange(out, 'b h n d -> b n (h d)')
        # final output projection
        return self.to_out(out)
# attention
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_context = None,
        heads = 4,
        dim_head = 32,
        norm = False,
        norm_context = False,
        time_cond_dim = None,
        flash = False,
        qk_norm = False
    ):
        super().__init__()
        hidden_dim = dim_head * heads
        dim_context = default(dim_context, dim)

        self.time_cond = None

        # if a time conditioning dimension is given, build the time conditioning module
        if exists(time_cond_dim):
            self.time_cond = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim * 2),
                Rearrange('b d -> b 1 d')
            )

            nn.init.zeros_(self.time_cond[-2].weight)
            nn.init.zeros_(self.time_cond[-2].bias)

        self.scale = dim_head ** -0.5
        self.heads = heads

        # LayerNorm or nn.Identity depending on whether normalization is wanted
        self.norm = LayerNorm(dim) if norm else nn.Identity()
        self.norm_context = LayerNorm(dim_context) if norm_context else nn.Identity()

        # projections for queries, keys / values, and the output
        self.to_q = nn.Linear(dim, hidden_dim, bias = False)
        self.to_kv = nn.Linear(dim_context, hidden_dim * 2, bias = False)
        self.to_out = nn.Linear(hidden_dim, dim, bias = False)

        self.qk_norm = qk_norm
        # if queries and keys are to be normalized, create MultiHeadedRMSNorm modules
        if qk_norm:
            self.q_norm = MultiHeadedRMSNorm(dim_head, heads)
            self.k_norm = MultiHeadedRMSNorm(dim_head, heads)

        # attention (with optional flash attention)
        self.attend = Attend(flash = flash)

    def forward(
        self,
        x,
        context = None,
        time = None
    ):
        h = self.heads

        # normalize the context, if given
        if exists(context):
            context = self.norm_context(context)

        x = self.norm(x)

        context = default(context, x)

        # apply time conditioning to the input, if present
        if exists(self.time_cond):
            assert exists(time)
            scale, shift = self.time_cond(time).chunk(2, dim = -1)
            x = (x * (scale + 1)) + shift

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        if self.qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        out = self.attend(q, k, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# positional encoding generator (PEG)
class PEG(nn.Module):
    def __init__(
        self,
        dim
    ):
        super().__init__()
        # depthwise convolution
        self.ds_conv = nn.Conv2d(dim, dim, 3, padding = 1, groups = dim)

    def forward(self, x):
        b, n, d = x.shape
        hw = int(math.sqrt(n))
        x = rearrange(x, 'b (h w) d -> b d h w', h = hw)
        x = self.ds_conv(x)
        x = rearrange(x, 'b d h w -> b (h w) d')
        return x

# feedforward
class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, time_cond_dim = None):
        super().__init__()
        self.norm = LayerNorm(dim)

        self.time_cond = None

        # if a time conditioning dimension is given, build the time conditioning module
        if exists(time_cond_dim):
            self.time_cond = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim * 2),
                Rearrange('b d -> b 1 d')
            )

            nn.init.zeros_(self.time_cond[-2].weight)
            nn.init.zeros_(self.time_cond[-2].bias)

        inner_dim = int(dim * mult)
        # feedforward network
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Linear(inner_dim, dim)
        )

    def forward(self, x, time = None):
        x = self.norm(x)

        if exists(self.time_cond):
            assert exists(time)
            scale, shift = self.time_cond(time).chunk(2, dim = -1)
            x = (x * (scale + 1)) + shift

        return self.net(x)

# the RIN block
class RINBlock(nn.Module):
    def __init__(
        self,
        dim,
        latent_self_attn_depth,
        dim_latent = None,
        final_norm = True,
        patches_self_attn = True,
        **attn_kwargs
    ):
        super().__init__()
        # default the latent dimension to the patch dimension
        dim_latent = default(dim_latent, dim)

        # attention for the latents to attend to (extract information from) the patches
        self.latents_attend_to_patches = Attention(dim_latent, dim_context = dim, norm = True, norm_context = True, **attn_kwargs)
        # feedforward following the latent cross attention
        self.latents_cross_attn_ff = FeedForward(dim_latent)

        # latent self attention layers
        self.latent_self_attns = nn.ModuleList([])
        for _ in range(latent_self_attn_depth):
            self.latent_self_attns.append(nn.ModuleList([
                Attention(dim_latent, norm = True, **attn_kwargs),
                FeedForward(dim_latent)
            ]))

        # final norm for the latents
        self.latent_final_norm = LayerNorm(dim_latent) if final_norm else nn.Identity()

        # positional encoding for the patches
        self.patches_peg = PEG(dim)
        self.patches_self_attn = patches_self_attn

        # optional self attention over the patches
        if patches_self_attn:
            # patch self attention and feedforward
            self.patches_self_attn = LinearAttention(dim, norm = True, **attn_kwargs)
            self.patches_self_attn_ff = FeedForward(dim)

        # attention for the patches to attend to the latents, plus feedforward
        self.patches_attend_to_latents = Attention(dim, dim_context = dim_latent, norm = True, norm_context = True, **attn_kwargs)
        self.patches_cross_attn_ff = FeedForward(dim)

    def forward(self, patches, latents, t):
        # positional encoding for the patches
        patches = self.patches_peg(patches) + patches

        # latents extract (or cluster) information from the patches
        latents = self.latents_attend_to_patches(latents, patches, time = t) + latents

        # feedforward following the latent cross attention
        latents = self.latents_cross_attn_ff(latents, time = t) + latents

        # latent self attention
        for attn, ff in self.latent_self_attns:
            latents = attn(latents, time = t) + latents
            latents = ff(latents, time = t) + latents

        # optional self attention over the patches
        if self.patches_self_attn:
            # additional patch self attention
            patches = self.patches_self_attn(patches, time = t) + patches
            patches = self.patches_self_attn_ff(patches) + patches

        # patches attend to the latents
        patches = self.patches_attend_to_latents(patches, latents, time = t) + patches

        # feedforward following the patch cross attention
        patches = self.patches_cross_attn_ff(patches, time = t) + patches

        # final norm for the latents
        latents = self.latent_final_norm(latents)
        return patches, latents
# RIN (Recurrent Interface Network), an nn.Module
class RIN(nn.Module):
    # initialization
    def __init__(
        self,
        dim,
        image_size,
        patch_size = 16,
        channels = 3,
        depth = 6,                      # number of RIN blocks
        latent_self_attn_depth = 2,     # number of latent self-attention blocks per cross-attention from pixel space to latents
        dim_latent = None,              # dimension of the latents, defaults to the image dimension (dim)
        num_latents = 256,              # a fair number of latents (256) is still needed for good results, in line with Deepmind's Perceiver line of papers
        learned_sinusoidal_dim = 16,
        latent_token_time_cond = False, # whether to use a latent token as the time condition, or adaptive layernorm (as shown in other papers, e.g. 'Paella' by Dominic Rampas et al.)
        dual_patchnorm = True,
        patches_self_attn = True,       # the self attention in this repository does not strictly follow the design proposed in the paper; provide a way to remove it in case it is the source of instability
        **attn_kwargs
        ):
        # call the parent constructor
        super().__init__()
        # the image size must be divisible by the patch size
        assert divisible_by(image_size, patch_size)
        # default the latent dimension to the model dimension
        dim_latent = default(dim_latent, dim)

        # image size and channels (channels are doubled at the patch projection for self-conditioning)
        self.image_size = image_size
        self.channels = channels

        # number of patches and the dimension of each pixel patch
        patch_height_width = image_size // patch_size
        num_patches = patch_height_width ** 2
        pixel_patch_dim = channels * (patch_size ** 2)

        # time conditioning

        # learned sinusoidal positional embedding
        sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim)
        time_dim = dim * 4
        fourier_dim = learned_sinusoidal_dim + 1

        self.latent_token_time_cond = latent_token_time_cond
        time_output_dim = dim_latent if latent_token_time_cond else time_dim

        # time MLP
        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_output_dim)
        )

        # pixels to patches and back

        self.to_patches = Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size),
            nn.LayerNorm(pixel_patch_dim * 2) if dual_patchnorm else None,
            nn.Linear(pixel_patch_dim * 2, dim),
            nn.LayerNorm(dim) if dual_patchnorm else None,
        )

        # axial positional embeddings, parameterized by MLPs

        pos_emb_dim = dim // 2

        self.axial_pos_emb_height_mlp = nn.Sequential(
            Rearrange('... -> ... 1'),
            nn.Linear(1, pos_emb_dim),
            nn.SiLU(),
            nn.Linear(pos_emb_dim, pos_emb_dim),
            nn.SiLU(),
            nn.Linear(pos_emb_dim, dim)
        )

        self.axial_pos_emb_width_mlp = nn.Sequential(
            Rearrange('... -> ... 1'),
            nn.Linear(1, pos_emb_dim),
            nn.SiLU(),
            nn.Linear(pos_emb_dim, pos_emb_dim),
            nn.SiLU(),
            nn.Linear(pos_emb_dim, dim)
        )

        # nn.Parameter(torch.randn(2, patch_height_width, dim) * 0.02)

        self.to_pixels = nn.Sequential(
            LayerNorm(dim),
            nn.Linear(dim, pixel_patch_dim),
            Rearrange('b (h w) (c p1 p2) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size, h = patch_height_width)
        )

        # initialize the latents
        self.latents = nn.Parameter(torch.randn(num_latents, dim_latent))
        nn.init.normal_(self.latents, std = 0.02)

        self.init_self_cond_latents = nn.Sequential(
            FeedForward(dim_latent),
            LayerNorm(dim_latent)
        )

        nn.init.zeros_(self.init_self_cond_latents[-1].gamma)

        # the main RIN body parameters - another attention-is-all-you-need moment

        if not latent_token_time_cond:
            attn_kwargs = {**attn_kwargs, 'time_cond_dim': time_dim}

        # the list of RINBlock modules
        self.blocks = nn.ModuleList([RINBlock(dim, dim_latent = dim_latent, latent_self_attn_depth = latent_self_attn_depth, patches_self_attn = patches_self_attn, **attn_kwargs) for _ in range(depth)])

    @property
    def device(self):
        # return the device the model parameters live on
        return next(self.parameters()).device

    def forward(
        self,
        x,
        time,
        x_self_cond = None,
        latent_self_cond = None,
        return_latents = False
        ):
        # 获取输入张量的批量大小
        batch = x.shape[0]

        # 如果没有给定 latents 的条件,则使用全零张量作为 latents 的条件
        x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))

        # 在第二维度上连接 x_self_cond 和 x,得到新的输入张量 x
        x = torch.cat((x_self_cond, x), dim = 1)

        # prepare the time conditioning
        t = self.time_mlp(time)

        # prepare the latents
        latents = repeat(self.latents, 'n d -> b n d', b = batch)

        # initialize latents from the previous latents (latent self-conditioning), as proposed in the paper
        if exists(latent_self_cond):
            latents = latents + self.init_self_cond_latents(latent_self_cond)

        # the time conditioning is either treated as one extra latent token, or used for the scale and shift of adaptive layernorm
        if self.latent_token_time_cond:
            t = rearrange(t, 'b d -> b 1 d')
            latents = torch.cat((latents, t), dim = -2)

        # convert the input x into patches
        patches = self.to_patches(x)

        # build height and width coordinate ranges
        height_range = width_range = torch.linspace(0., 1., steps = int(math.sqrt(patches.shape[-2])), device = self.device)
        pos_emb_h, pos_emb_w = self.axial_pos_emb_height_mlp(height_range), self.axial_pos_emb_width_mlp(width_range)

        # build the axial positional embedding
        pos_emb = rearrange(pos_emb_h, 'i d -> i 1 d') + rearrange(pos_emb_w, 'j d -> 1 j d')
        patches = patches + rearrange(pos_emb, 'i j d -> (i j) d')

        # run through each block of the recurrent interface network
        for block in self.blocks:
            patches, latents = block(patches, latents, t)

        # project patches back to pixels
        pixels = self.to_pixels(patches)

        # if latents are not requested, just return the pixels
        if not return_latents:
            return pixels

        # if latent_token_time_cond is set, remove the time conditioning token
        if self.latent_token_time_cond:
            latents = latents[:, :-1]

        # return pixels and latents
        return pixels, latents
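
# a minimal usage sketch of the forward pass above: with latent self-conditioning the
# network is typically called twice - a first pass produces latents, a second pass
# consumes them. `model` is assumed to be an already constructed RIN instance.

def _example_self_conditioned_forward(model, images, times):
    # first pass: run the network only to obtain latents
    _, latents = model(images, times, return_latents = True)
    # second pass: feed the previous latents back in as latent self-conditioning
    return model(images, times, latent_self_cond = latents)
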
# normalize an image from [0, 1] to [-1, 1]
def normalize_img(x):
    return x * 2 - 1

# unnormalize an image from [-1, 1] back to [0, 1]
def unnormalize_img(x):
    return (x + 1) * 0.5

# normalize the variance of the noised image (used when scale is not 1)
def normalize_img_variance(x, eps = 1e-5):
    std = reduce(x, 'b c h w -> b 1 1 1', partial(torch.std, unbiased = False))
    return x / std.clamp(min = eps)

# natural log with clamping for numerical stability
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# right-pad t with singleton dimensions until it has as many dims as x
def right_pad_dims_to(x, t):
    padding_dims = x.ndim - t.ndim
    if padding_dims <= 0:
        return t
    return t.view(*t.shape, *((1,) * padding_dims))

# simple linear noise schedule
def simple_linear_schedule(t, clip_min = 1e-9):
    return (1 - t).clamp(min = clip_min)

# cosine noise schedule
def cosine_schedule(t, start = 0, end = 1, tau = 1, clip_min = 1e-9):
    power = 2 * tau
    v_start = math.cos(start * math.pi / 2) ** power
    v_end = math.cos(end * math.pi / 2) ** power
    output = math.cos((t * (end - start) + start) * math.pi / 2) ** power
    output = (v_end - output) / (v_end - v_start)
    return output.clamp(min = clip_min)

# sigmoid noise schedule
def sigmoid_schedule(t, start = -3, end = 3, tau = 1, clamp_min = 1e-9):
    v_start = torch.tensor(start / tau).sigmoid()
    v_end = torch.tensor(end / tau).sigmoid()
    gamma = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
    return gamma.clamp_(min = clamp_min, max = 1.)

# convert gamma to alpha and sigma
def gamma_to_alpha_sigma(gamma, scale = 1):
    return torch.sqrt(gamma) * scale, torch.sqrt(1 - gamma)

# convert gamma to log signal-to-noise ratio
def gamma_to_log_snr(gamma, scale = 1, eps = 1e-5):
    return log(gamma * (scale ** 2) / (1 - gamma), eps = eps)
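
# a minimal sketch of how the schedule helpers above compose into one forward noising
# step, assuming the sigmoid schedule; this mirrors what GaussianDiffusion.forward does below

def _example_noising_step(images, times, scale = 1.):
    # map continuous times in [0, 1] to noise levels gamma
    gamma = sigmoid_schedule(times)
    # broadcast gamma to the image shape and split it into signal / noise coefficients
    alpha, sigma = gamma_to_alpha_sigma(right_pad_dims_to(images, gamma), scale)
    # forward diffusion: interpolate between the (normalized) image and gaussian noise
    noise = torch.randn_like(images)
    return alpha * images + sigma * noise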

# gaussian diffusion wrapper class
@beartype
class GaussianDiffusion(nn.Module):
    def __init__(
        self,
        model: RIN,
        *,
        timesteps = 1000,
        use_ddim = True,
        noise_schedule = 'sigmoid',
        objective = 'v',
        schedule_kwargs: dict = dict(),
        time_difference = 0.,
        min_snr_loss_weight = True,
        min_snr_gamma = 5,
        train_prob_self_cond = 0.9,
        scale = 1.                      # this will be set to < 1. for better convergence when training on higher resolution images
    ):
        super().__init__()
        self.model = model
        self.channels = self.model.channels

        assert objective in {'x0', 'eps', 'v'}, 'objective must be one of x0 (predict start image), eps (predict noise) or v (predict velocity)'
        self.objective = objective

        self.image_size = model.image_size

        if noise_schedule == "linear":
            self.gamma_schedule = simple_linear_schedule
        elif noise_schedule == "cosine":
            self.gamma_schedule = cosine_schedule
        elif noise_schedule == "sigmoid":
            self.gamma_schedule = sigmoid_schedule
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        assert scale <= 1, 'scale must be less than or equal to 1'
        self.scale = scale
        self.maybe_normalize_img_variance = normalize_img_variance if scale < 1 else identity

        self.gamma_schedule = partial(self.gamma_schedule, **schedule_kwargs)

        self.timesteps = timesteps
        self.use_ddim = use_ddim

        self.time_difference = time_difference

        self.train_prob_self_cond = train_prob_self_cond

        self.min_snr_loss_weight = min_snr_loss_weight
        self.min_snr_gamma = min_snr_gamma

    @property
    def device(self):
        return next(self.model.parameters()).device
    # build the sampling timesteps
    def get_sampling_timesteps(self, batch, *, device):
        # create a linearly spaced sequence from 1 to 0 with self.timesteps + 1 points on the device
        times = torch.linspace(1., 0., self.timesteps + 1, device=device)
        # repeat the time sequence for every sample in the batch
        times = repeat(times, 't -> b t', b=batch)
        # split the sequence into pairs of adjacent timesteps
        times = torch.stack((times[:, :-1], times[:, 1:]), dim=0)
        times = times.unbind(dim=-1)
        return times
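
    # worked example of the pairs produced above, for timesteps = 4 and batch size b:
    # linspace gives [1.0, 0.75, 0.5, 0.25, 0.0]; the returned tuple then yields the
    # (time, time_next) pairs (1.0, 0.75), (0.75, 0.5), (0.5, 0.25), (0.25, 0.0),
    # each as a tensor of shape (2, b) that the sampling loops below unpack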

    # no gradients needed while sampling
    @torch.no_grad()
    def ddpm_sample(self, shape, time_difference=None):
        batch, device = shape[0], self.device

        # set the time difference
        time_difference = default(time_difference, self.time_difference)

        # get the sampling time pairs
        time_pairs = self.get_sampling_timesteps(batch, device=device)

        # start from pure gaussian noise
        img = torch.randn(shape, device=device)

        x_start = None
        last_latents = None

        # iterate over the time pairs
        for time, time_next in tqdm(time_pairs, desc='sampling loop time step', total=self.timesteps):

            # add the time delay
            time_next = (time_next - self.time_difference).clamp(min=0.)

            noise_cond = time

            # get the predicted x0
            maybe_normalized_img = self.maybe_normalize_img_variance(img)
            model_output, last_latents = self.model(maybe_normalized_img, noise_cond, x_start, last_latents, return_latents=True)

            # get gamma (noise level) for the current and next timestep
            gamma = self.gamma_schedule(time)
            gamma_next = self.gamma_schedule(time_next)
            gamma, gamma_next = map(partial(right_pad_dims_to, img), (gamma, gamma_next))

            # get alpha and sigma
            alpha, sigma = gamma_to_alpha_sigma(gamma)
            alpha_next, sigma_next = gamma_to_alpha_sigma(gamma_next)

            # derive x0 from the model output, depending on the objective
            if self.objective == 'x0':
                x_start = model_output
            elif self.objective == 'eps':
                x_start = safe_div(img - sigma * model_output, alpha)
            elif self.objective == 'v':
                x_start = alpha * img - sigma * model_output

            # clip x0 to [-1, 1]
            x_start.clamp_(-1., 1.)

            # derive the posterior mean and variance
            log_snr, log_snr_next = map(gamma_to_log_snr, (gamma, gamma_next))
            c = -expm1(log_snr - log_snr_next)
            mean = alpha_next * (img * (1 - c) / alpha + c * x_start)
            variance = (sigma_next ** 2) * c
            log_variance = log(variance)

            # sample noise (no noise when time_next is 0)
            noise = torch.where(
                rearrange(time_next > 0, 'b -> b 1 1 1'),
                torch.randn_like(img),
                torch.zeros_like(img)
            )

            img = mean + (0.5 * log_variance).exp() * noise

        return unnormalize_img(img)

    # no gradients needed while sampling
    @torch.no_grad()
    # DDIM sampling; batch and device are taken from the given shape
    def ddim_sample(self, shape, time_difference = None):
        batch, device = shape[0], self.device

        # set the time difference to the default or the given value
        time_difference = default(time_difference, self.time_difference)

        # get the sampling time pairs
        time_pairs = self.get_sampling_timesteps(batch, device = device)

        # start from pure gaussian noise
        img = torch.randn(shape, device = device)

        x_start = None
        last_latents = None

        # iterate over the time pairs
        for times, times_next in tqdm(time_pairs, desc = 'sampling loop time step'):

            # get the noise levels for the current and next timestep
            gamma = self.gamma_schedule(times)
            gamma_next = self.gamma_schedule(times_next)

            # pad the noise levels to the same number of dims as the image
            padded_gamma, padded_gamma_next = map(partial(right_pad_dims_to, img), (gamma, gamma_next))

            # convert the noise levels to alpha and sigma
            alpha, sigma = gamma_to_alpha_sigma(padded_gamma)
            alpha_next, sigma_next = gamma_to_alpha_sigma(padded_gamma_next)

            # add the time delay
            times_next = (times_next - time_difference).clamp(min = 0.)

            # predict x0
            maybe_normalized_img = self.maybe_normalize_img_variance(img)
            model_output, last_latents = self.model(maybe_normalized_img, times, x_start, last_latents, return_latents = True)

            # derive x0 from the model output, depending on the objective
            if self.objective == 'x0':
                x_start = model_output
            elif self.objective == 'eps':
                x_start = safe_div(img - sigma * model_output, alpha)
            elif self.objective == 'v':
                x_start = alpha * img - sigma * model_output

            # clip x0 to [-1, 1]
            x_start.clamp_(-1., 1.)

            # get the predicted noise
            pred_noise = safe_div(img - alpha * x_start, sigma)

            # compute the next image
            img = x_start * alpha_next + pred_noise * sigma_next

        # return the unnormalized image
        return unnormalize_img(img)

    # no gradients needed while sampling
    @torch.no_grad()
    # generate samples
    def sample(self, batch_size = 16):
        image_size, channels = self.image_size, self.channels
        # choose the sampling function depending on whether DDIM is used
        sample_fn = self.ddpm_sample if not self.use_ddim else self.ddim_sample
        return sample_fn((batch_size, channels, image_size, image_size))
    # forward pass, taking an image plus any extra arguments
    def forward(self, img, *args, **kwargs):
        # unpack the image shape and device info
        batch, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
        # the image height and width must equal the configured image size
        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'

        # sample random times uniformly in [0, 1)
        times = torch.zeros((batch,), device=device).float().uniform_(0, 1.)

        # normalize the image to [-1, 1]
        img = normalize_img(img)

        # sample gaussian noise
        noise = torch.randn_like(img)

        # compute the gamma value
        gamma = self.gamma_schedule(times)
        padded_gamma = right_pad_dims_to(img, gamma)
        alpha, sigma = gamma_to_alpha_sigma(padded_gamma, self.scale)

        # add noise to the image
        noised_img = alpha * img + sigma * noise

        # maybe normalize the variance of the noised image (when scale < 1)
        noised_img = self.maybe_normalize_img_variance(noised_img)

        # in the paper, they had to use a very high probability of latent self-conditioning,
        # up to 90% of the time - a slight drawback of the method
        self_cond = self_latents = None

        if random() < self.train_prob_self_cond:
            with torch.no_grad():
                model_output, self_latents = self.model(noised_img, times, return_latents=True)
                self_latents = self_latents.detach()

                if self.objective == 'x0':
                    self_cond = model_output

                elif self.objective == 'eps':
                    self_cond = safe_div(noised_img - sigma * model_output, alpha)

                elif self.objective == 'v':
                    self_cond = alpha * noised_img - sigma * model_output

                self_cond.clamp_(-1., 1.)
                self_cond = self_cond.detach()

        # predict and take a gradient step
        pred = self.model(noised_img, times, self_cond, self_latents)

        if self.objective == 'eps':
            target = noise

        elif self.objective == 'x0':
            target = img

        elif self.objective == 'v':
            target = alpha * noise - sigma * img

        # compute the loss
        loss = F.mse_loss(pred, target, reduction='none')
        loss = reduce(loss, 'b ... -> b', 'mean')

        # min-snr loss weighting
        snr = (alpha * alpha) / (sigma * sigma)
        maybe_clipped_snr = snr.clone()

        if self.min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max=self.min_snr_gamma)

        if self.objective == 'eps':
            loss_weight = maybe_clipped_snr / snr

        elif self.objective == 'x0':
            loss_weight = maybe_clipped_snr

        elif self.objective == 'v':
            loss_weight = maybe_clipped_snr / (snr + 1)

        return (loss * loss_weight).mean()
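
# a minimal end-to-end sketch of the diffusion wrapper above; the RIN constructor
# arguments are assumptions inferred from its __init__ body, while the GaussianDiffusion
# arguments follow the signature above

def _example_diffusion_usage():
    model = RIN(dim = 256, image_size = 128, patch_size = 8)
    diffusion = GaussianDiffusion(model, timesteps = 400, noise_schedule = 'sigmoid', objective = 'v')

    images = torch.rand(4, 3, 128, 128)         # training images in [0, 1]
    loss = diffusion(images)                    # forward pass returns the weighted mse loss
    loss.backward()

    sampled = diffusion.sample(batch_size = 4)  # (4, 3, 128, 128) images back in [0, 1]
    return sampled
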
# dataset classes

# Dataset class, built on torch.utils.data.Dataset
class Dataset(Dataset):
    # constructor
    def __init__(
        self,
        folder,  # path to the dataset folder
        image_size,  # target image size
        exts = ['jpg', 'jpeg', 'png', 'tiff'],  # image file extensions to include
        augment_horizontal_flip = False,  # whether to augment with random horizontal flips
        convert_image_to = None  # optional image mode conversion
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        # collect all file paths in the folder matching the given extensions
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        # partially apply the conversion function if one is given
        maybe_convert_fn = partial(convert_image_to, convert_image_to) if exists(convert_image_to) else nn.Identity()

        # image transform pipeline
        self.transform = T.Compose([
            T.Lambda(maybe_convert_fn),
            T.Resize(image_size),
            T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    # length of the dataset
    def __len__(self):
        return len(self.paths)

    # fetch the item at the given index
    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

# trainer class

# Trainer class
@beartype
class Trainer(object):
    # constructor
    def __init__(
        self,
        diffusion_model: GaussianDiffusion,  # diffusion model to train
        folder,  # path to the dataset folder
        *,
        train_batch_size = 16,  # training batch size
        gradient_accumulate_every = 1,  # number of gradient accumulation steps
        augment_horizontal_flip = True,  # whether to augment with random horizontal flips
        train_lr = 1e-4,  # training learning rate
        train_num_steps = 100000,  # total number of training steps
        max_grad_norm = 1.,  # gradient clipping threshold
        ema_update_every = 10,  # how often to update the EMA model
        ema_decay = 0.995,  # EMA decay rate
        betas = (0.9, 0.99),  # Adam optimizer betas
        save_and_sample_every = 1000,  # how often to save and sample
        num_samples = 25,  # number of samples to generate
        results_folder = './results',  # folder for saving results
        amp = False,  # whether to use mixed precision training
        mixed_precision_type = 'fp16',  # mixed precision type
        split_batches = True,  # whether to split batches across devices
        convert_image_to = None  # optional image mode conversion
    ):
        super().__init__()

        # initialize the accelerator
        self.accelerator = Accelerator(
            split_batches = split_batches,
            mixed_precision = mixed_precision_type if amp else 'no',
            kwargs_handlers = [DistributedDataParallelKwargs(find_unused_parameters=True)]
        )

        # the diffusion model
        self.model = diffusion_model

        # the number of samples must have an integer square root (for the sample grid)
        assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
        self.num_samples = num_samples
        self.save_and_sample_every = save_and_sample_every

        self.batch_size = train_batch_size
        self.gradient_accumulate_every = gradient_accumulate_every
        self.max_grad_norm = max_grad_norm

        self.train_num_steps = train_num_steps
        self.image_size = diffusion_model.image_size

        # dataset and dataloader

        # create the dataset
        self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip, convert_image_to = convert_image_to)
        # create the dataloader
        dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())

        # prepare the dataloader with the accelerator
        dl = self.accelerator.prepare(dl)
        self.dl = cycle(dl)

        # optimizer

        # create the Adam optimizer
        self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = betas)

        # folder for periodically logging results

        self.results_folder = Path(results_folder)

        if self.accelerator.is_local_main_process:
            self.results_folder.mkdir(exist_ok = True)

        if self.accelerator.is_main_process:
            self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)

        # step counter state

        self.step = 0

        # prepare the model and optimizer with the accelerator

        self.model, self.opt = self.accelerator.prepare(self.model, self.opt)

    # save the model
    def save(self, milestone):
        if not self.accelerator.is_local_main_process:
            return

        data = {
            'step': self.step + 1,
            'model': self.accelerator.get_state_dict(self.model),
            'opt': self.opt.state_dict(),
            'ema': self.ema.state_dict(),
            'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None
        }

        torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))
    # load the model data for a given milestone
    def load(self, milestone):
        # load the checkpoint from disk
        data = torch.load(str(self.results_folder / f'model-{milestone}.pt'))

        # get the unwrapped (non-accelerated) model
        model = self.accelerator.unwrap_model(self.model)
        # load the model state dict
        model.load_state_dict(data['model'])

        # restore the current training step
        self.step = data['step']
        # load the optimizer state dict
        self.opt.load_state_dict(data['opt'])

        # on the main process, load the EMA model state dict
        if self.accelerator.is_main_process:
            self.ema.load_state_dict(data['ema'])

        # load the grad scaler state dict if present in both the accelerator and the checkpoint
        if exists(self.accelerator.scaler) and exists(data['scaler']):
            self.accelerator.scaler.load_state_dict(data['scaler'])

    # train the model
    def train(self):
        # get the accelerator and device
        accelerator = self.accelerator
        device = accelerator.device

        # show training progress with tqdm
        with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar:

            # train until the total number of steps is reached
            while self.step < self.train_num_steps:

                total_loss = 0.

                # loop over the gradient accumulation steps
                for _ in range(self.gradient_accumulate_every):
                    # fetch the next batch and move it to the device
                    data = next(self.dl).to(device)

                    # compute the model loss under automatic mixed precision
                    with accelerator.autocast():
                        loss = self.model(data)
                        loss = loss / self.gradient_accumulate_every
                        total_loss += loss.item()

                    # backpropagate to accumulate gradients
                    accelerator.backward(loss)

                # show the current loss on the progress bar
                pbar.set_description(f'loss: {total_loss:.4f}')

                # wait for all processes to reach this point
                accelerator.wait_for_everyone()
                # clip the gradients of the model parameters
                accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

                # take an optimizer step
                self.opt.step()
                # zero the gradients
                self.opt.zero_grad()

                # wait for all processes to reach this point
                accelerator.wait_for_everyone()

                # save milestones on every local main process, but only sample on the global main process
                if accelerator.is_local_main_process:
                    milestone = self.step // self.save_and_sample_every
                    save_and_sample = self.step != 0 and self.step % self.save_and_sample_every == 0
                    
                    if accelerator.is_main_process:
                        # move the EMA model to the device
                        self.ema.to(device)
                        # update the EMA model
                        self.ema.update()

                        if save_and_sample:
                            # put the EMA model in eval mode
                            self.ema.ema_model.eval()

                            with torch.no_grad():
                                # split the number of samples into batches and generate sample images
                                batches = num_to_groups(self.num_samples, self.batch_size)
                                all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))

                            all_images = torch.cat(all_images_list, dim = 0)
                            # save the generated sample images
                            utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = int(math.sqrt(self.num_samples)))

                    if save_and_sample:
                        # save the model checkpoint for this milestone
                        self.save(milestone)

                # increment the step counter and update the progress bar
                self.step += 1
                pbar.update(1)

        # print the training-complete message
        accelerator.print('training complete')
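
# a minimal training sketch using the Trainer above; '/path/to/images' is a placeholder
# folder of training images, and the RIN constructor arguments are assumptions

def _example_training_run():
    model = RIN(dim = 256, image_size = 128, patch_size = 8)
    diffusion = GaussianDiffusion(model, timesteps = 400, objective = 'v')

    trainer = Trainer(
        diffusion,
        '/path/to/images',                  # placeholder dataset folder
        train_batch_size = 16,
        gradient_accumulate_every = 4,
        train_lr = 3e-4,
        train_num_steps = 70000
    )

    trainer.train()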