Lucidrains Series Projects: Source Code Walkthrough (Part 45)

.\lucidrains\imagen-pytorch\imagen_pytorch\imagen_video.py

# import math, operator, functools and other standard library modules
import math
import operator
import functools
from tqdm.auto import tqdm
from functools import partial, wraps
from pathlib import Path

# import PyTorch modules
import torch
import torch.nn.functional as F
from torch import nn, einsum

# import einops helpers
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

# import project-local modules
from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# identity function: returns the input unchanged
def identity(t, *args, **kwargs):
    return t

# return the first element of an array, or a default value if the array is empty
def first(arr, d = None):
    if len(arr) == 0:
        return d
    return arr[0]

# check whether one number is evenly divisible by another
def divisible_by(numer, denom):
    return (numer % denom) == 0

# apply fn only when the input exists, otherwise pass the input through
def maybe(fn):
    @wraps(fn)
    def inner(x):
        if not exists(x):
            return x
        return fn(x)
    return inner

# run a function only once (used for one-time logging)
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# print a message only once
print_once = once(print)

# return val if it exists, otherwise a default (calling it if it is callable)
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# cast a value to a tuple, optionally validating/broadcasting to a given length
def cast_tuple(val, length = None):
    if isinstance(val, list):
        val = tuple(val)

    output = val if isinstance(val, tuple) else ((val,) * default(length, 1))

    if exists(length):
        assert len(output) == length

    return output
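
A quick illustration of cast_tuple (a sketch, assuming the definition above is in scope): scalars are broadcast to the requested length, while tuples and lists are validated against it.

# hypothetical usage sketch of cast_tuple
assert cast_tuple(3) == (3,)
assert cast_tuple(3, 4) == (3, 3, 3, 3)
assert cast_tuple((1, 2), 2) == (1, 2)
assert cast_tuple([1, 2]) == (1, 2)  # lists are converted to tuples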

# cast uint8 images to floats in [0, 1]
def cast_uint8_images_to_float(images):
    if not images.dtype == torch.uint8:
        return images
    return images / 255

# get the device a module's parameters live on
def module_device(module):
    return next(module.parameters()).device

# initialize a layer's weight (and bias, if present) to zero
def zero_init_(m):
    nn.init.zeros_(m.weight)
    if exists(m.bias):
        nn.init.zeros_(m.bias)

# decorator that runs a function with the model in eval mode, then restores the previous training mode
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# pad a tuple to the given length with a fill value
def pad_tuple_to_length(t, length, fillvalue = None):
    remain_length = length - len(t)
    if remain_length <= 0:
        return t
    return (*t, *((fillvalue,) * remain_length))

# helper classes

# module that returns its input unchanged
class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

# build an nn.Sequential, filtering out any Nones
def Sequential(*modules):
    return nn.Sequential(*filter(exists, modules))

# tensor helper functions

# numerically safe log (clamps the input to eps)
def log(t, eps: float = 1e-12):
    return torch.log(t.clamp(min = eps))

# L2-normalize along the last dimension
def l2norm(t):
    return F.normalize(t, dim = -1)

# append singleton dimensions to t until it has as many dimensions as x
def right_pad_dims_to(x, t):
    padding_dims = x.ndim - t.ndim
    if padding_dims <= 0:
        return t
    return t.view(*t.shape, *((1,) * padding_dims))

# mean over a dimension that ignores masked-out positions
def masked_mean(t, *, dim, mask = None):
    if not exists(mask):
        return t.mean(dim = dim)

    denom = mask.sum(dim = dim, keepdim = True)
    mask = rearrange(mask, 'b n -> b n 1')
    masked_t = t.masked_fill(~mask, 0.)

    return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)
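
A worked sketch of masked_mean (assuming the definition above): masked positions are zeroed out before summation, and the denominator counts only the valid tokens.

# hypothetical sketch of masked_mean
t = torch.tensor([[[1.], [3.], [100.]]])     # (b = 1, n = 3, d = 1)
mask = torch.tensor([[True, True, False]])   # the last token is padding
out = masked_mean(t, dim = 1, mask = mask)
# out == tensor([[2.]]): (1 + 3) / 2, the padded 100 is ignored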

# resize a video to a target spatial size (and optionally a target frame count)
def resize_video_to(
    video,
    target_image_size,
    target_frames = None,
    clamp_range = None,
    mode = 'nearest'
):
    orig_video_size = video.shape[-1]

    frames = video.shape[2]
    target_frames = default(target_frames, frames)

    target_shape = (target_frames, target_image_size, target_image_size)

    if tuple(video.shape[-3:]) == target_shape:
        return video

    out = F.interpolate(video, target_shape, mode = mode)

    if exists(clamp_range):
        out = out.clamp(*clamp_range)

    return out

# temporally downsample a video by an integer factor
def scale_video_time(
    video,
    downsample_scale = 1,
    mode = 'nearest'
):
    if downsample_scale == 1:
        return video

    image_size, frames = video.shape[-1], video.shape[-3]
    assert divisible_by(frames, downsample_scale), f'trying to temporally downsample a conditioning video frames of length {frames} by {downsample_scale}, however it is not neatly divisible'

    target_frames = frames // downsample_scale
    # reuse resize_video_to, keeping the spatial size fixed and shrinking only the frame count
    resized_video = resize_video_to(
        video,                          # original video
        image_size,                     # target spatial size (unchanged)
        target_frames = target_frames,  # target number of frames
        mode = mode                     # interpolation mode
    )

    return resized_video

# classifier free guidance functions

# create a boolean mask of the given shape where each entry is True with probability prob
def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
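
This mask is what drives classifier-free guidance training: for a random fraction of the batch, the text conditioning is dropped and replaced with null embeddings. A minimal sketch (variable names here are illustrative, assuming prob_mask_like above):

# hypothetical sketch: keep conditioning for ~90% of the batch
batch = 8
keep_mask = prob_mask_like((batch,), 1. - 0.1, device = torch.device('cpu'))
# where keep_mask is False, null embeddings would be used in place of the text embeddings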

# norms and residuals

# layer normalization module (with an optional numerically stable variant)
class LayerNorm(nn.Module):
    def __init__(self, dim, stable=False):
        super().__init__()
        self.stable = stable
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        if self.stable:
            x = x / x.amax(dim=-1, keepdim=True).detach()

        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim=-1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=-1, keepdim=True)
        return (x - mean) * (var + eps).rsqrt() * self.g

# channel-wise layer normalization for video feature maps of shape (b, c, f, h, w)
class ChanLayerNorm(nn.Module):
    def __init__(self, dim, stable=False):
        super().__init__()
        self.stable = stable
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1, 1))

    def forward(self, x):
        if self.stable:
            x = x / x.amax(dim=1, keepdim=True).detach()

        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) * (var + eps).rsqrt() * self.g

# callable that always returns the same value
class Always():
    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        return self.val

# residual connection wrapper
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

# run several modules on the same input and sum their outputs
class Parallel(nn.Module):
    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        outputs = [fn(x) for fn in self.fns]
        return sum(outputs)

# rearranging

# rearrange so the wrapped fn operates along the frame (time) axis, treating all other axes as batch
class RearrangeTimeCentric(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        x = rearrange(x, 'b c f ... -> b ... f c')
        x, ps = pack([x], '* f c')

        x = self.fn(x)

        x, = unpack(x, ps, '* f c')
        x = rearrange(x, 'b ... f c -> b c f ...')
        return x

# attention pooling

# PerceiverAttention module: latents attend to the inputs (and to themselves)
class PerceiverAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        scale=8
    ):
        super().__init__()
        self.scale = scale

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)

        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias=False),
            nn.LayerNorm(dim)
        )
    # forward pass over inputs x, latent vectors latents, and an optional mask
    def forward(self, x, latents, mask = None):
        # normalize the inputs
        x = self.norm(x)
        # normalize the latents
        latents = self.norm_latents(latents)

        # batch size and number of heads
        b, h = x.shape[0], self.heads

        # queries come from the latents
        q = self.to_q(latents)

        # keys/values are derived from the inputs concatenated with the latents
        kv_input = torch.cat((x, latents), dim = -2)
        # project the concatenated input to keys and values
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)

        # split queries, keys and values into heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # L2-normalize queries and keys (cosine-sim attention)
        q, k = map(l2norm, (q, k))
        # apply learned per-dimension scales
        q = q * self.q_scale
        k = k * self.k_scale

        # similarity matrix
        sim = einsum('... i d, ... j d  -> ... i j', q, k) * self.scale

        # if a mask exists, pad it to cover the latents and mask out invalid positions
        if exists(mask):
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = F.pad(mask, (0, latents.shape[-2]), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention weights
        attn = sim.softmax(dim = -1)

        # aggregate values and merge heads
        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        # project back out
        return self.to_out(out)

# PerceiverResampler: pools a variable-length text sequence into a fixed set of latent tokens
class PerceiverResampler(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_latents_mean_pooled = 4, # number of latents derived from a mean-pooled representation of the sequence
        max_seq_len = 512,
        ff_mult = 4
    ):
        super().__init__()
        # positional embedding over the input sequence
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # learned latent tokens
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        self.to_latents_from_mean_pooled_seq = None

        # if requested, derive extra latents from the mean-pooled sequence
        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
            )

        # stack of attention + feedforward layers
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x, mask = None):
        n, device = x.shape[1], x.device
        pos_emb = self.pos_emb(torch.arange(n, device = device))

        x_with_pos = x + pos_emb

        latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0])

        # if enabled, prepend latents derived from the mean-pooled sequence
        if exists(self.to_latents_from_mean_pooled_seq):
            meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim = -2)

        # alternate attention and feedforward, with residual connections
        for attn, ff in self.layers:
            latents = attn(x_with_pos, latents, mask = mask) + latents
            latents = ff(latents) + latents

        return latents
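
A shape-level sketch of the resampler in use (hypothetical values, assuming the class above and the rest of this module in scope): a batch of text token embeddings of any length is pooled down to num_latents_mean_pooled + num_latents tokens.

# hypothetical usage sketch of PerceiverResampler
resampler = PerceiverResampler(dim = 512, depth = 2, num_latents = 64, num_latents_mean_pooled = 4)
text_tokens = torch.randn(2, 77, 512)   # (batch, seq_len, dim)
pooled = resampler(text_tokens)         # (2, 4 + 64, 512), a fixed-size output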

# Conv3d factorized into a spatial 2d convolution followed by a causal temporal 1d convolution
class Conv3d(nn.Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        kernel_size = 3,
        *,
        temporal_kernel_size = None,
        **kwargs
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        temporal_kernel_size = default(temporal_kernel_size, kernel_size)

        # spatial convolution applied per frame
        self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size = kernel_size, padding = kernel_size // 2)
        # temporal convolution (only if kernel_size is greater than 1)
        self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size = temporal_kernel_size) if kernel_size > 1 else None
        self.kernel_size = kernel_size

        # initialize the temporal convolution to the identity
        if exists(self.temporal_conv):
            nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity
            nn.init.zeros_(self.temporal_conv.bias.data)

    def forward(
        self,
        x,
        ignore_time = False
    ):
        b, c, *_, h, w = x.shape

        is_video = x.ndim == 5
        ignore_time &= is_video

        if is_video:
            x = rearrange(x, 'b c f h w -> (b f) c h w')

        x = self.spatial_conv(x)

        if is_video:
            x = rearrange(x, '(b f) c h w -> b c f h w', b = b)

        if ignore_time or not exists(self.temporal_conv):
            return x

        x = rearrange(x, 'b c f h w -> (b h w) c f')

        # causal temporal convolution - time is causal in imagen-video

        if self.kernel_size > 1:
            x = F.pad(x, (self.kernel_size - 1, 0))

        x = self.temporal_conv(x)

        x = rearrange(x, '(b h w) c f -> b c f h w', h = h, w = w)

        return x
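
The factorization above (a 2d convolution applied per frame, then a 1d convolution along time with left-only padding) is what keeps the temporal convolution causal. A minimal sketch, assuming the class above:

# hypothetical usage sketch of the factorized Conv3d
conv = Conv3d(8, 16, kernel_size = 3)
video = torch.randn(2, 8, 5, 32, 32)   # (batch, channels, frames, height, width)
out = conv(video)                      # spatial conv per frame, then causal temporal conv
assert out.shape == (2, 16, 5, 32, 32)
# left-padding by (kernel_size - 1) means frame t only sees frames <= t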

# full attention with a null key/value (for classifier-free guidance) and optional relative position bias
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        causal = False,
        context_dim = None,
        rel_pos_bias = False,
        rel_pos_bias_mlp_depth = 2,
        init_zero = False,
        scale = 8
    ):
        super().__init__()
        # attention scale and causality flag
        self.scale = scale
        self.causal = causal

        # if enabled, create the dynamic (MLP-based) relative position bias
        self.rel_pos_bias = DynamicPositionBias(dim = dim, heads = heads, depth = rel_pos_bias_mlp_depth) if rel_pos_bias else None

        # number of heads and inner dimension
        self.heads = heads
        inner_dim = dim_head * heads

        # pre-norm
        self.norm = LayerNorm(dim)

        # null attention bias and null key/value (attended to even when all real tokens are masked)
        self.null_attn_bias = nn.Parameter(torch.randn(heads))
        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        # learned per-dimension scales for queries and keys
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        # if a context dimension is given, set up the projection of context to extra keys/values
        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None

        # output projection
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

        # optionally zero-initialize the final norm gain
        if init_zero:
            nn.init.zeros_(self.to_out[-1].g)

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_bias = None
    ):
        # batch size, sequence length and device
        b, n, device = *x.shape[:2], x.device

        # pre-norm
        x = self.norm(x)
        # project to queries, keys and values (keys/values are shared across heads here)
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        # split queries into heads
        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        # add null key / value for classifier-free guidance in the prior net
        nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # add text conditioning, if present
        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        # L2-normalize queries and keys, then apply learned scales
        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # query / key similarity
        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.scale

        # relative positional encoding (T5 style)
        if not exists(attn_bias) and exists(self.rel_pos_bias):
            attn_bias = self.rel_pos_bias(n, device = device, dtype = q.dtype)

        if exists(attn_bias):
            null_attn_bias = repeat(self.null_attn_bias, 'h -> h n 1', n = n)
            attn_bias = torch.cat((null_attn_bias, attn_bias), dim = -1)
            sim = sim + attn_bias

        # masking
        max_neg_value = -torch.finfo(sim.dtype).max

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, max_neg_value)

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention
        attn = sim.softmax(dim = -1)

        # aggregate values
        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# pseudo Conv2d: an nn.Conv3d whose kernel, stride and padding are 1/1/0 along the frame dimension
def Conv2d(dim_in, dim_out, kernel, stride = 1, padding = 0, **kwargs):
    # normalize kernel to a tuple
    kernel = cast_tuple(kernel, 2)
    # normalize stride to a tuple
    stride = cast_tuple(stride, 2)
    # normalize padding to a tuple
    padding = cast_tuple(padding, 2)

    # prepend a frame kernel size of 1
    if len(kernel) == 2:
        kernel = (1, *kernel)

    # prepend a frame stride of 1
    if len(stride) == 2:
        stride = (1, *stride)

    # prepend a frame padding of 0
    if len(padding) == 2:
        padding = (0, *padding)

    # return an nn.Conv3d that acts purely spatially
    return nn.Conv3d(dim_in, dim_out, kernel, stride = stride, padding = padding, **kwargs)
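
In other words, every "Conv2d" in this file actually operates on 5d video tensors, convolving each frame independently. A minimal sketch, assuming the factory above:

# hypothetical usage sketch of the pseudo Conv2d
conv = Conv2d(8, 16, 3, padding = 1)   # becomes an nn.Conv3d with kernel (1, 3, 3)
video = torch.randn(2, 8, 5, 32, 32)
assert conv(video).shape == (2, 16, 5, 32, 32)  # frames untouched, space convolved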

# pad a tensor with a constant value
class Pad(nn.Module):
    def __init__(self, padding, value = 0.):
        super().__init__()
        self.padding = padding
        self.value = value

    def forward(self, x):
        return F.pad(x, self.padding, value = self.value)

# nearest-neighbor spatial upsample followed by a convolution
def Upsample(dim, dim_out = None):
    dim_out = default(dim_out, dim)

    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        Conv2d(dim, dim_out, 3, padding = 1)
    )

# pixel-shuffle spatial upsample
class PixelShuffleUpsample(nn.Module):
    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)
        conv = Conv2d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU()
        )

        self.pixel_shuffle = nn.PixelShuffle(2)

        self.init_conv_(conv)

    # initialize the conv so that all four shuffled sub-pixels start out identical
    def init_conv_(self, conv):
        o, i, f, h, w = conv.weight.shape
        conv_weight = torch.empty(o // 4, i, f, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        out = self.net(x)
        frames = x.shape[2]
        out = rearrange(out, 'b c f h w -> (b f) c h w')
        out = self.pixel_shuffle(out)
        return rearrange(out, '(b f) c h w -> b c f h w', f = frames)

# spatial downsample via pixel unshuffle (space-to-channel) followed by a 1x1 convolution
def Downsample(dim, dim_out = None):
    dim_out = default(dim_out, dim)
    return nn.Sequential(
        Rearrange('b c f (h p1) (w p2) -> b (c p1 p2) f h w', p1 = 2, p2 = 2),
        Conv2d(dim * 4, dim_out, 1)
    )
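
The Rearrange above is a pixel unshuffle: each 2x2 spatial patch is folded into the channel dimension, so no information is discarded before the 1x1 convolution mixes it. A shape sketch, assuming the factory above:

# hypothetical shape sketch of the pixel-unshuffle Downsample
down = Downsample(8, 16)
video = torch.randn(2, 8, 5, 32, 32)
assert down(video).shape == (2, 16, 5, 16, 16)  # 2x2 patches folded into channels, then mixed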

# temporal (frame-axis) pixel-shuffle upsample
class TemporalPixelShuffleUpsample(nn.Module):
    def __init__(self, dim, dim_out = None, stride = 2):
        super().__init__()
        self.stride = stride
        dim_out = default(dim_out, dim)
        conv = nn.Conv1d(dim, dim_out * stride, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU()
        )

        self.pixel_shuffle = Rearrange('b (c r) n -> b c (n r)', r = stride)

        self.init_conv_(conv)

    # initialize the conv so that all shuffled sub-frames start out identical
    def init_conv_(self, conv):
        o, i, f = conv.weight.shape
        conv_weight = torch.empty(o // self.stride, i, f)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o r) ...', r = self.stride)

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        b, c, f, h, w = x.shape
        x = rearrange(x, 'b c f h w -> (b h w) c f')
        out = self.net(x)
        out = self.pixel_shuffle(out)
        return rearrange(out, '(b h w) c f -> b c f h w', h = h, w = w)

# temporal downsample via frame-to-channel folding followed by a 1x1 convolution
def TemporalDownsample(dim, dim_out = None, stride = 2):
    dim_out = default(dim_out, dim)
    return nn.Sequential(
        Rearrange('b c (f p) h w -> b (c p) f h w', p = stride),
        Conv2d(dim * stride, dim_out, 1)
    )

# sinusoidal positional embedding for diffusion timesteps
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
        return torch.cat((emb.sin(), emb.cos()), dim = -1)
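
For a batch of timesteps t, this computes sin(t * w_j) and cos(t * w_j) over geometrically spaced frequencies w_j = 10000^(-j / (half_dim - 1)). A shape sketch, assuming the class above:

# hypothetical shape sketch of SinusoidalPosEmb
pos_emb = SinusoidalPosEmb(dim = 16)
t = torch.tensor([0., 1., 250.])   # diffusion timesteps
emb = pos_emb(t)                   # 8 sine and 8 cosine features per timestep
assert emb.shape == (3, 16)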

# learned sinusoidal positional embedding
class LearnedSinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # the dimension must be even
        assert (dim % 2) == 0
        half_dim = dim // 2
        # learned frequencies, initialized from a standard normal
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x):
        # append a singleton dimension: (b,) -> (b, 1)
        x = rearrange(x, 'b -> b 1')
        # frequencies: timestep times learned weights times 2 pi
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        # concatenate sine and cosine features
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        # prepend the raw timestep
        fouriered = torch.cat((x, fouriered), dim = -1)
        return fouriered

class Block(nn.Module):
    # basic block: group norm, optional scale-shift conditioning, activation, then convolution
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        norm = True
    ):
        super().__init__()
        # group normalization (or identity if norm is disabled)
        self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()
        # SiLU activation
        self.activation = nn.SiLU()
        # factorized 3x3 convolution, padding 1
        self.project = Conv3d(dim, dim_out, 3, padding = 1)

    # normalize, apply the conditioning scale-shift if given, activate, then convolve
    def forward(
        self,
        x,
        scale_shift = None,
        ignore_time = False
    ):
        x = self.groupnorm(x)

        # FiLM-style conditioning: x * (scale + 1) + shift
        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.activation(x)
        return self.project(x, ignore_time = ignore_time)

class ResnetBlock(nn.Module):
    # ResNet block with optional time conditioning, cross attention and global context attention
    def __init__(
        self,
        dim,
        dim_out,
        *,
        cond_dim = None,
        time_cond_dim = None,
        groups = 8,
        linear_attn = False,
        use_gca = False,
        squeeze_excite = False,
        **attn_kwargs
    ):
        super().__init__()

        self.time_mlp = None

        # if a time conditioning dimension is given, build the time MLP (produces scale and shift)
        if exists(time_cond_dim):
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim_out * 2)
            )

        self.cross_attn = None

        # if a conditioning dimension is given, build the cross attention module
        if exists(cond_dim):
            attn_klass = CrossAttention if not linear_attn else LinearCrossAttention

            self.cross_attn = attn_klass(
                dim = dim_out,
                context_dim = cond_dim,
                **attn_kwargs
            )

        # the two inner blocks
        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)

        # global context attention (a squeeze-excite-like gate), or a constant 1 if disabled
        self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)

        # 1x1 convolution on the residual path if the channel count changes
        self.res_conv = Conv2d(dim, dim_out, 1) if dim != dim_out else Identity()


    # forward pass: time conditioning, cross attention, the two blocks, then the global context gate
    def forward(
        self,
        x,
        time_emb = None,
        cond = None,
        ignore_time = False
    ):

        scale_shift = None
        # derive scale and shift from the time embedding, if available
        if exists(self.time_mlp) and exists(time_emb):
            time_emb = self.time_mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)

        # first block
        h = self.block1(x, ignore_time = ignore_time)

        # cross-attend to the conditioning tokens, if configured
        if exists(self.cross_attn):
            assert exists(cond)
            h = rearrange(h, 'b c ... -> b ... c')
            h, ps = pack([h], 'b * c')

            h = self.cross_attn(h, context = cond) + h

            h, = unpack(h, ps, 'b * c')
            h = rearrange(h, 'b ... c -> b c ...')

        # second block, with the time scale-shift applied
        h = self.block2(h, scale_shift = scale_shift, ignore_time = ignore_time)

        # gate by global context attention
        h = h * self.gca(h)

        # residual connection
        return h + self.res_conv(x)

class CrossAttention(nn.Module):
    # cross attention: queries from x, keys/values from the context, plus a null key/value
    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        norm_context = False,
        scale = 8
    ):
        super().__init__()
        self.scale = scale

        self.heads = heads
        inner_dim = dim_head * heads

        context_dim = default(context_dim, dim)

        # pre-norms for the queries and (optionally) the context
        self.norm = LayerNorm(dim)
        self.norm_context = LayerNorm(context_dim) if norm_context else Identity()

        # null key/value and the query / key-value projections
        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        # output projection
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    # forward pass over input x, context, and an optional mask
    def forward(self, x, context, mask = None):
        # batch size, sequence length and device
        b, n, device = *x.shape[:2], x.device

        # normalize the input and the context
        x = self.norm(x)
        context = self.norm_context(context)

        # queries from x, keys and values from the context
        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        # split into attention heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        # add null key / value for classifier free guidance in the prior net
        nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads,  b = b), self.null_kv.unbind(dim = -2))
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # L2-normalize queries and keys, then apply learned scales
        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # similarity matrix
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # masking
        max_neg_value = -torch.finfo(sim.dtype).max
        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # softmax over keys (in float32 for stability) gives the attention weights
        attn = sim.softmax(dim = -1, dtype = torch.float32)

        # aggregate values and project out
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class LinearCrossAttention(CrossAttention):
    # linear-complexity variant of CrossAttention (softmax over features instead of sequence positions)
    def forward(self, x, context, mask = None):
        b, n, device = *x.shape[:2], x.device

        # normalize the input and the context
        x = self.norm(x)
        context = self.norm_context(context)

        # queries from x, keys and values from the context
        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        # fold the heads into the batch dimension
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads), (q, k, v))

        # add null key / value for classifier free guidance in prior net

        nk, nv = map(lambda t: repeat(t, 'd -> (b h) 1 d', h = self.heads,  b = b), self.null_kv.unbind(dim = -2))

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # masking

        max_neg_value = -torch.finfo(x.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b n -> b n 1')
            k = k.masked_fill(~mask, max_neg_value)
            v = v.masked_fill(~mask, 0.)

        # linear attention

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
        return self.to_out(out)

class LinearAttention(nn.Module):
    # linear self-attention over feature maps, with depthwise convs for local context
    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 8,
        dropout = 0.05,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = ChanLayerNorm(dim)

        self.nonlin = nn.SiLU()

        self.to_q = nn.Sequential(
            nn.Dropout(dropout),
            Conv2d(dim, inner_dim, 1, bias = False),
            Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_k = nn.Sequential(
            nn.Dropout(dropout),
            Conv2d(dim, inner_dim, 1, bias = False),
            Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_v = nn.Sequential(
            nn.Dropout(dropout),
            Conv2d(dim, inner_dim, 1, bias = False),
            Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            Conv2d(inner_dim, dim, 1, bias = False),
            ChanLayerNorm(dim)
        )

    # forward pass over a feature map fmap and optional context tokens
    def forward(self, fmap, context = None):
        h, x, y = self.heads, *fmap.shape[-2:]

        fmap = self.norm(fmap)
        q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v))
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))

        # append context-derived keys and values, if provided
        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            ck, cv = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (ck, cv))
            k = torch.cat((k, ck), dim = -2)
            v = torch.cat((v, cv), dim = -2)

        # linear attention: softmax queries over features, keys over positions
        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        out = self.nonlin(out)
        return self.to_out(out)

class GlobalContext(nn.Module):
    """ basically a superior form of squeeze-excitation that is attention-esque """

    def __init__(
        self,
        *,
        dim_in,
        dim_out
    ):
        super().__init__()
        # 1x1 convolution producing a single-channel attention map
        self.to_k = Conv2d(dim_in, 1, 1)
        # hidden dimension: at least 3, otherwise half the output dimension
        hidden_dim = max(3, dim_out // 2)

        # bottleneck network (1x1 convs) ending in a sigmoid gate
        self.net = nn.Sequential(
            Conv2d(dim_in, hidden_dim, 1),
            nn.SiLU(),
            Conv2d(hidden_dim, dim_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # single-channel attention logits over all positions
        context = self.to_k(x)
        # flatten the spatial (and temporal) dimensions
        x, context = map(lambda t: rearrange(t, 'b n ... -> b n (...)'), (x, context))
        # attention-weighted global pooling of the input
        out = einsum('b i n, b c n -> b c i', context.softmax(dim = -1), x)
        # restore singleton spatial dimensions
        out = rearrange(out, '... -> ... 1 1')
        # pass the pooled descriptor through the gating network
        return self.net(out)

# feedforward network: layer norm, linear, GELU, norm, linear
def FeedForward(dim, mult = 2):
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        LayerNorm(dim),                            # layer norm
        nn.Linear(dim, hidden_dim, bias = False),  # expand
        nn.GELU(),                                 # GELU activation
        LayerNorm(hidden_dim),                     # norm the hidden layer
        nn.Linear(hidden_dim, dim, bias = False)   # project back
    )

# token shift along the time (frame) axis: half the channels see the previous frame's features
class TimeTokenShift(nn.Module):
    def forward(self, x):
        if x.ndim != 5:
            return x

        x, x_shift = x.chunk(2, dim = 1)                           # split the channels in half
        x_shift = F.pad(x_shift, (0, 0, 0, 0, 1, -1), value = 0.)  # shift one frame forward in time
        return torch.cat((x, x_shift), dim = 1)                    # re-concatenate the halves
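
The pad tuple (0, 0, 0, 0, 1, -1) pads one frame at the start of the frame axis and crops one at the end, so frame t of the shifted half holds the features of frame t - 1 (zeros at t = 0). A small sketch, assuming the module above:

# hypothetical sketch of TimeTokenShift on a (b, c, f, h, w) tensor
shift = TimeTokenShift()
x = torch.arange(8, dtype = torch.float32).reshape(1, 2, 4, 1, 1)  # 2 channels, 4 frames
out = shift(x)
# channel 0 is unchanged; channel 1 becomes [0, 4, 5, 6] along time (was [4, 5, 6, 7])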

# channel feedforward for feature maps, with an optional token shift along time in the hidden layer
def ChanFeedForward(dim, mult = 2, time_token_shift = True):
    hidden_dim = int(dim * mult)
    return Sequential(
        ChanLayerNorm(dim),                              # channel layer norm
        Conv2d(dim, hidden_dim, 1, bias = False),        # 1x1 conv (expand)
        nn.GELU(),                                       # GELU activation
        TimeTokenShift() if time_token_shift else None,  # optional time token shift
        ChanLayerNorm(hidden_dim),                       # channel layer norm
        Conv2d(hidden_dim, dim, 1, bias = False)         # 1x1 conv (project back)
    )

# transformer block: full attention plus channel feedforward, repeated depth times
class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        ff_time_token_shift = True,
        context_dim = None
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),  # full attention
                ChanFeedForward(dim = dim, mult = ff_mult, time_token_shift = ff_time_token_shift)    # channel feedforward
            ]))

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = rearrange(x, 'b c ... -> b ... c')  # move channels last
            x, ps = pack([x], 'b * c')              # flatten space and time into one token axis

            x = attn(x, context = context) + x      # attention with residual

            x, = unpack(x, ps, 'b * c')             # restore the packed dimensions
            x = rearrange(x, 'b ... c -> b c ...')  # move channels back to the front

            x = ff(x) + x                           # feedforward with residual
        return x

# transformer block using linear attention instead of full attention
class LinearAttentionTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        ff_time_token_shift = True,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),  # linear attention
                ChanFeedForward(dim = dim, mult = ff_mult, time_token_shift = ff_time_token_shift)          # channel feedforward
            ]))

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = attn(x, context = context) + x  # linear attention with residual
            x = ff(x) + x                       # feedforward with residual
        return x

# cross-embed layer: parallel convolutions at several kernel sizes, concatenated along channels
class CrossEmbedLayer(nn.Module):
    def __init__(
        self,
        dim_in,
        kernel_sizes,
        dim_out = None,
        stride = 2
    ):
        super().__init__()
        assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
        dim_out = default(dim_out, dim_in)

        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # split the output channels across scales: each successive scale gets half the channels
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]

        self.convs = nn.ModuleList([])
        for kernel, dim_scale in zip(kernel_sizes, dim_scales):
            self.convs.append(Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))

    def forward(self, x):
        fmaps = tuple(map(lambda conv: conv(x), self.convs))  # run every scale's convolution
        return torch.cat(fmaps, dim = 1)                      # concatenate along channels

# upsample combiner: optionally resizes intermediate feature maps and concatenates them at the final resolution
class UpsampleCombiner(nn.Module):
    def __init__(
        self,
        dim,
        *,
        enabled = False,
        dim_ins = tuple(),
        dim_outs = tuple()
    ):
        super().__init__()
        # broadcast dim_outs to the same length as dim_ins
        dim_outs = cast_tuple(dim_outs, len(dim_ins))
        assert len(dim_ins) == len(dim_outs)

        self.enabled = enabled

        # if disabled, the output dimension is unchanged
        if not self.enabled:
            self.dim_out = dim
            return

        # one Block per incoming feature map
        self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
        # the final output dimension after concatenation
        self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)

    def forward(self, x, fmaps = None):
        # the target spatial size comes from the input
        target_size = x.shape[-1]

        fmaps = default(fmaps, tuple())

        # nothing to combine: pass through
        if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
            return x

        # resize every feature map to the target size
        fmaps = [resize_video_to(fmap, target_size) for fmap in fmaps]
        # apply the corresponding convolutional block to each
        outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
        # concatenate the input with all processed feature maps along channels
        return torch.cat((x, *outs), dim = 1)

# dynamic position bias: a small MLP maps relative distance to a per-head attention bias
class DynamicPositionBias(nn.Module):
    def __init__(
        self,
        dim,
        *,
        heads,
        depth
    ):
        super().__init__()
        self.mlp = nn.ModuleList([])

        # first layer: scalar relative distance -> dim features
        self.mlp.append(nn.Sequential(
            nn.Linear(1, dim),
            LayerNorm(dim),
            nn.SiLU()
        ))

        # intermediate layers, depending on depth
        for _ in range(max(depth - 1, 0)):
            self.mlp.append(nn.Sequential(
                nn.Linear(dim, dim),
                LayerNorm(dim),
                nn.SiLU()
            ))

        # final projection to one bias value per head
        self.mlp.append(nn.Linear(dim, heads))

    def forward(self, n, device, dtype):
        # position indices
        i = torch.arange(n, device = device)
        j = torch.arange(n, device = device)

        # pairwise relative distances, shifted to be non-negative
        indices = rearrange(i, 'i -> i 1') - rearrange(j, 'j -> 1 j')
        indices += (n - 1)

        # all possible relative distances, from -(n-1) to (n-1)
        pos = torch.arange(-n + 1, n, device = device, dtype = dtype)
        pos = rearrange(pos, '... -> ... 1')

        # run the MLP over the distances
        for layer in self.mlp:
            pos = layer(pos)

        # gather the bias for every (i, j) pair and move heads to the front
        bias = pos[indices]
        bias = rearrange(bias, 'i j h -> h i j')
        return bias
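
Only 2n - 1 distinct relative distances exist for a length-n sequence, so the MLP runs once per distance and the result is gathered into the full n-by-n bias. A shape sketch, assuming the class above:

# hypothetical shape sketch of DynamicPositionBias
bias_mlp = DynamicPositionBias(dim = 32, heads = 8, depth = 2)
bias = bias_mlp(16, device = torch.device('cpu'), dtype = torch.float32)
assert bias.shape == (8, 16, 16)   # one (i, j) bias matrix per head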

# the 3d unet used by imagen-video
class Unet3D(nn.Module):
    def __init__(
        self,
        *,
        dim,
        text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME),
        num_resnet_blocks = 1,
        cond_dim = None,
        num_image_tokens = 4,
        num_time_tokens = 2,
        learned_sinu_pos_emb_dim = 16,
        out_dim = None,
        dim_mults = (1, 2, 4, 8),
        temporal_strides = 1,
        cond_images_channels = 0,
        channels = 3,
        channels_out = None,
        attn_dim_head = 64,
        attn_heads = 8,
        ff_mult = 2.,
        ff_time_token_shift = True,         # token shift along the time axis in the hidden layer of the feedforwards
        lowres_cond = False,                # for cascading diffusion
        layer_attns = False,
        layer_attns_depth = 1,
        layer_attns_add_text_cond = True,   # whether to condition the self-attention blocks with the text embeddings
        attend_at_middle = True,            # whether to have a layer of attention at the bottleneck
        time_rel_pos_bias_depth = 2,
        time_causal_attn = True,
        layer_cross_attns = True,
        use_linear_attn = False,
        use_linear_cross_attn = False,
        cond_on_text = True,
        max_text_len = 256,
        init_dim = None,
        resnet_groups = 8,
        init_conv_kernel_size = 7,          # kernel size of the initial convolution
        init_cross_embed = True,
        init_cross_embed_kernel_sizes = (3, 7, 15),
        cross_embed_downsample = False,
        cross_embed_downsample_kernel_sizes = (2, 4),
        attn_pool_text = True,
        attn_pool_num_latents = 32,
        dropout = 0.,
        memory_efficient = False,
        init_conv_to_final_conv_residual = False,
        use_global_context_attn = True,
        scale_skip_connection = True,
        final_resnet_block = True,
        final_conv_kernel_size = 3,
        self_cond = False,
        combine_upsample_fmaps = False,      # whether to combine the feature maps from all upsample blocks
        pixel_shuffle_upsample = True,       # may address the checkerboard artifacts
        resize_mode = 'nearest'
    ):
        ...  # the body of __init__ is elided in this excerpt

    # re-instantiate the unet if the requested settings differ from the current ones
    def cast_model_parameters(
        self,
        *,
        lowres_cond,
        text_embed_dim,
        channels,
        channels_out,
        cond_on_text
    ):
        # if the current attributes already match the requested parameters, return self unchanged
        if lowres_cond == self.lowres_cond and \
            channels == self.channels and \
            cond_on_text == self.cond_on_text and \
            text_embed_dim == self._locals['text_embed_dim'] and \
            channels_out == self.channels_out:
            return self

        # build the updated keyword arguments
        updated_kwargs = dict(
            lowres_cond = lowres_cond,
            text_embed_dim = text_embed_dim,
            channels = channels,
            channels_out = channels_out,
            cond_on_text = cond_on_text
        )

        # return a new instance built from the stored init arguments plus the updates
        return self.__class__(**{**self._locals, **updated_kwargs})

    # return the full unet config along with its parameter state dict

    def to_config_and_state_dict(self):
        return self._locals, self.state_dict()

    # classmethod to recreate the unet from a config and state dict

    @classmethod
    def from_config_and_state_dict(klass, config, state_dict):
        unet = klass(**config)
        unet.load_state_dict(state_dict)
        return unet

    # persist the unet to disk

    def persist_to_file(self, path):
        path = Path(path)
        path.parents[0].mkdir(exist_ok = True, parents = True)

        config, state_dict = self.to_config_and_state_dict()
        pkg = dict(config = config, state_dict = state_dict)
        torch.save(pkg, str(path))

    # classmethod to recreate the unet from a file saved with `persist_to_file`

    @classmethod
    def hydrate_from_file(klass, path):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path))

        assert 'config' in pkg and 'state_dict' in pkg
        config, state_dict = pkg['config'], pkg['state_dict']

        return Unet.from_config_and_state_dict(config, state_dict)

    # forward pass with classifier-free guidance

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 1.,
        **kwargs
    ):
        logits = self.forward(*args, **kwargs)

        if cond_scale == 1:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale
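
The last line is the standard classifier-free guidance combination: given a conditional prediction e_c and an unconditional prediction e_u, it returns e_u + s * (e_c - e_u), which equals e_c at cond_scale = 1 and extrapolates beyond it for larger scales. A numerical sketch:

# hypothetical numerical sketch of the guidance combination
null_logits = torch.tensor([0.])   # unconditional prediction
logits = torch.tensor([1.])        # conditional prediction
cond_scale = 3.
guided = null_logits + (logits - null_logits) * cond_scale
# guided == tensor([3.]): the conditional direction is amplified 3x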

    def forward(
        self,
        x,
        time,
        *,
        lowres_cond_img = None,
        lowres_noise_times = None,
        text_embeds = None,
        text_mask = None,
        cond_images = None,
        cond_video_frames = None,
        post_cond_video_frames = None,
        self_cond = None,
        cond_drop_prob = 0.,
        ignore_time = False
    ):
        ...  # the body of forward is elided in this excerpt

.\lucidrains\imagen-pytorch\imagen_pytorch\t5.py

# imports
import torch
import transformers
from typing import List
from transformers import T5Tokenizer, T5EncoderModel, T5Config
from einops import rearrange

# set the transformers logging level to error
transformers.logging.set_verbosity_error()

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return val if it exists, otherwise a default (calling it if callable)
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# config

# maximum token length
MAX_LENGTH = 256

# default T5 model name
DEFAULT_T5_NAME = 'google/t5-v1_1-base'

# cache of per-model configs, models and tokenizers
T5_CONFIGS = {}

# global singletons

# get the tokenizer for the given model name
def get_tokenizer(name):
    tokenizer = T5Tokenizer.from_pretrained(name, model_max_length=MAX_LENGTH)
    return tokenizer

# get the encoder model for the given name
def get_model(name):
    model = T5EncoderModel.from_pretrained(name)
    return model

# get (and cache) both the model and tokenizer for the given name
def get_model_and_tokenizer(name):
    global T5_CONFIGS

    if name not in T5_CONFIGS:
        T5_CONFIGS[name] = dict()
    if "model" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["model"] = get_model(name)
    if "tokenizer" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)

    return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']

# get the encoder's output dimension for the given model name
def get_encoded_dim(name):
    if name not in T5_CONFIGS:
        # avoid loading the full model when only the dimension is needed
        config = T5Config.from_pretrained(name)
        T5_CONFIGS[name] = dict(config=config)
    elif "config" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["config"]
    elif "model" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["model"].config
    else:
        assert False
    return config.d_model

# encoding text

# tokenize a batch of texts
def t5_tokenize(
    texts: List[str],
    name = DEFAULT_T5_NAME
):
    t5, tokenizer = get_model_and_tokenizer(name)

    if torch.cuda.is_available():
        t5 = t5.cuda()

    device = next(t5.parameters()).device

    encoded = tokenizer.batch_encode_plus(
        texts,
        return_tensors = "pt",
        padding = 'longest',
        max_length = MAX_LENGTH,
        truncation = True
    )

    input_ids = encoded.input_ids.to(device)
    attn_mask = encoded.attention_mask.to(device)
    return input_ids, attn_mask

# encode already-tokenized text
def t5_encode_tokenized_text(
    token_ids,
    attn_mask = None,
    pad_id = None,
    name = DEFAULT_T5_NAME
):
    assert exists(attn_mask) or exists(pad_id)
    t5, _ = get_model_and_tokenizer(name)

    attn_mask = default(attn_mask, lambda: (token_ids != pad_id).long())

    t5.eval()

    with torch.no_grad():
        output = t5(input_ids = token_ids, attention_mask = attn_mask)
        encoded_text = output.last_hidden_state.detach()

    attn_mask = attn_mask.bool()

    encoded_text = encoded_text.masked_fill(~rearrange(attn_mask, '... -> ... 1'), 0.) # force all padded embeddings to 0
    return encoded_text

# encode a batch of texts end to end
def t5_encode_text(
    texts: List[str],
    name = DEFAULT_T5_NAME,
    return_attn_mask = False
):
    token_ids, attn_mask = t5_tokenize(texts, name = name)
    encoded_text = t5_encode_tokenized_text(token_ids, attn_mask = attn_mask, name = name)

    if return_attn_mask:
        attn_mask = attn_mask.bool()
        return encoded_text, attn_mask

    return encoded_text
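
A usage sketch (note this downloads the pretrained T5 weights on first call; shapes assume the defaults above, where google/t5-v1_1-base has d_model = 768):

# hypothetical usage sketch of t5_encode_text
embeds, mask = t5_encode_text(
    ['a cat playing a piano', 'a dog'],
    return_attn_mask = True
)
# embeds: (2, seq_len, 768), with padded positions zeroed out
# mask:   (2, seq_len) boolean, False at padded positions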

.\lucidrains\imagen-pytorch\imagen_pytorch\test\test_trainer.py

# import ImagenTrainer, ImagenConfig and t5_encode_text from imagen_pytorch, plus Dataset and torch
from imagen_pytorch.trainer import ImagenTrainer
from imagen_pytorch.configs import ImagenConfig
from imagen_pytorch.t5 import t5_encode_text
from torch.utils.data import Dataset
import torch

# test that an ImagenTrainer can be instantiated
def test_trainer_instantiation():
    # unet1: a minimal unet configuration
    unet1 = dict(
        dim = 8,
        dim_mults = (1, 1, 1, 1),
        num_resnet_blocks = 1,
        layer_attns = False,
        layer_cross_attns = False,
        attn_heads = 2
    )

    # build an Imagen instance from the config
    imagen = ImagenConfig(
        unets=(unet1,),
        image_sizes=(64,),
    ).create()

    # instantiate the trainer around it
    trainer = ImagenTrainer(
        imagen=imagen
    )

# test a single training step
def test_trainer_step():
    # a toy dataset producing (image, text embedding) pairs of zeros
    class TestDataset(Dataset):
        def __init__(self):
            super().__init__()
        def __len__(self):
            return 16
        def __getitem__(self, index):
            return (torch.zeros(3, 64, 64), torch.zeros(6, 768))

    # unet1: a minimal unet configuration
    unet1 = dict(
        dim = 8,
        dim_mults = (1, 1, 1, 1),
        num_resnet_blocks = 1,
        layer_attns = False,
        layer_cross_attns = False,
        attn_heads = 2
    )

    # build an Imagen instance from the config
    imagen = ImagenConfig(
        unets=(unet1,),
        image_sizes=(64,),
    ).create()

    # instantiate the trainer
    trainer = ImagenTrainer(
        imagen=imagen
    )

    # add the dataset with batch size 8 and take one training step
    ds = TestDataset()
    trainer.add_train_dataset(ds, batch_size=8)
    trainer.train_step(1)
    # exactly one step should have been taken for unet 1
    assert trainer.num_steps_taken(1) == 1

.\lucidrains\imagen-pytorch\imagen_pytorch\test\__init__.py

# import the test_trainer module from imagen_pytorch.test
from imagen_pytorch.test import test_trainer

.\lucidrains\imagen-pytorch\imagen_pytorch\trainer.py

# imports
import os
from math import ceil
from contextlib import contextmanager, nullcontext
from functools import partial, wraps
from collections.abc import Iterable

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torch.cuda.amp import autocast, GradScaler

import pytorch_warmup as warmup

from imagen_pytorch.imagen_pytorch import Imagen, NullUnet
from imagen_pytorch.elucidated_imagen import ElucidatedImagen
from imagen_pytorch.data import cycle

from imagen_pytorch.version import __version__
from packaging import version

import numpy as np

from ema_pytorch import EMA

from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs

from fsspec.core import url_to_fs
from fsspec.implementations.local import LocalFileSystem

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return val if it exists, otherwise a default (calling it if callable)
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# cast a value to a tuple of the given length
def cast_tuple(val, length = 1):
    if isinstance(val, list):
        val = tuple(val)

    return val if isinstance(val, tuple) else ((val,) * length)

# find the index of the first element satisfying a predicate
def find_first(fn, arr):
    for ind, el in enumerate(arr):
        if fn(el):
            return ind
    return -1

# pop the given keys from a dict and return them as a new dict
def pick_and_pop(keys, d):
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))

# partition a dict into two dicts by a predicate on the keys
def group_dict_by_key(cond, d):
    return_val = [dict(),dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

# check whether a string begins with a prefix
def string_begins_with(prefix, str):
    return str.startswith(prefix)

# partition a dict by key prefix
def group_by_key_prefix(prefix, d):
    return group_dict_by_key(partial(string_begins_with, prefix), d)

# partition a dict by key prefix, stripping the prefix from the matching keys
def groupby_prefix_and_trim(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs

# split a number into groups of at most `divisor`
def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr
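
num_to_groups is what lets the trainer split an oversized request into sub-batches. For example (assuming the helper above):

# hypothetical sketch of num_to_groups
assert num_to_groups(10, 4) == [4, 4, 2]   # 10 samples in sub-batches of at most 4
assert num_to_groups(8, 4) == [4, 4]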

# url to filesystem, bucket, path - for saving checkpoints to cloud storage

def url_to_bucket(url):
    if '://' not in url:
        return url

    prefix, suffix = url.split('://')

    if prefix in {'gs', 's3'}:
        return suffix.split('/')[0]
    else:
        raise ValueError(f'storage type prefix "{prefix}" is not supported yet')

# decorators

# run a function with the model in eval mode, then restore the previous training mode
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# decorator that converts numpy arguments to torch tensors, moves them to the model device, and optionally casts floats to fp16
def cast_torch_tensor(fn, cast_fp16 = False):
    @wraps(fn)
    def inner(model, *args, **kwargs):
        device = kwargs.pop('_device', model.device)
        cast_device = kwargs.pop('_cast_device', True)

        should_cast_fp16 = cast_fp16 and model.cast_half_at_training

        kwargs_keys = kwargs.keys()
        all_args = (*args, *kwargs.values())
        split_kwargs_index = len(all_args) - len(kwargs_keys)
        all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))

        if cast_device:
            all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))

        if should_cast_fp16:
            all_args = tuple(map(lambda t: t.half() if exists(t) and isinstance(t, torch.Tensor) and t.dtype != torch.bool else t, all_args))

        args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
        kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))

        out = fn(model, *args, **kwargs)
        return out
    return inner
# split an iterable into sublists of the given size
def split_iterable(it, split_size):
    accum = []
    # step through the iterable in strides of split_size
    for ind in range(ceil(len(it) / split_size)):
        start_index = ind * split_size
        accum.append(it[start_index: (start_index + split_size)])
    return accum

# split a tensor or iterable into chunks of the given size
def split(t, split_size = None):
    # no split size given: return the input unchanged
    if not exists(split_size):
        return t

    # tensors are split along the batch (first) dimension
    if isinstance(t, torch.Tensor):
        return t.split(split_size, dim = 0)

    # other iterables are split with split_iterable
    if isinstance(t, Iterable):
        return split_iterable(t, split_size)

    # unsupported type (note: this returns the TypeError class rather than raising it)
    return TypeError

# find the first element satisfying a predicate (this shadows the index-returning find_first defined above)
def find_first(cond, arr):
    for el in arr:
        if cond(el):
            return el
    return None

# split positional and keyword arguments into chunks of at most split_size along the batch dimension
def split_args_and_kwargs(*args, split_size = None, **kwargs):
    # merge all positional and keyword argument values into one tuple
    all_args = (*args, *kwargs.values())
    len_all_args = len(all_args)
    # find the first torch.Tensor argument; its length defines the batch size
    first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
    assert exists(first_tensor)

    batch_size = len(first_tensor)
    split_size = default(split_size, batch_size)
    num_chunks = ceil(batch_size / split_size)

    dict_len = len(kwargs)
    dict_keys = kwargs.keys()
    split_kwargs_index = len_all_args - dict_len

    # split every splittable argument; broadcast everything else across the chunks
    split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
    chunk_sizes = num_to_groups(batch_size, split_size)

    # yield each chunk's size fraction along with its (args, kwargs)
    for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
        chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
        chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
        chunk_size_frac = chunk_size / batch_size
        yield chunk_size_frac, (chunked_args, chunked_kwargs)
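
A sketch of what the generator yields (assuming the helpers above): tensors are chunked along the batch dimension, other iterables are chunked to match, and each chunk reports its fraction of the full batch.

# hypothetical sketch of split_args_and_kwargs
images = torch.randn(10, 3, 64, 64)
for frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(images, split_size = 4, texts = ['a'] * 10):
    print(frac, chunked_args[0].shape, len(chunked_kwargs['texts']))
# 0.4 torch.Size([4, 3, 64, 64]) 4
# 0.4 torch.Size([4, 3, 64, 64]) 4
# 0.2 torch.Size([2, 3, 64, 64]) 2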

# decorator that samples in chunks of at most max_batch_size, then concatenates the outputs
def imagen_sample_in_chunks(fn):
    @wraps(fn)
    def inner(self, *args, max_batch_size = None, **kwargs):
        # no maximum given: call the function directly
        if not exists(max_batch_size):
            return fn(self, *args, **kwargs)

        # unconditional sampling: split the requested batch size into sub-batches
        if self.imagen.unconditional:
            batch_size = kwargs.get('batch_size')
            batch_sizes = num_to_groups(batch_size, max_batch_size)
            outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
        else:
            # conditional sampling: chunk the conditioning args/kwargs themselves
            outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]

        # tensor outputs are concatenated along the batch dimension
        if isinstance(outputs[0], torch.Tensor):
            return torch.cat(outputs, dim = 0)

        # otherwise concatenate each position of the tuple outputs
        return list(map(lambda t: torch.cat(t, dim = 0), list(zip(*outputs))))

    return inner

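# Usage sketch (illustrative, not part of the source): any decorated sampling
# method accepts a `max_batch_size` keyword. For an unconditional imagen, e.g.
#
#   images = trainer.sample(batch_size = 64, max_batch_size = 8)
#
# runs eight chunks of eight through the unets and concatenates the outputs;
# for conditional sampling, the tensor arguments themselves are chunked via
# split_args_and_kwargs.
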
# restore into state_dict_target any parameters from state_dict_from whose names and shapes match
def restore_parts(state_dict_target, state_dict_from):
    for name, param in state_dict_from.items():

        if name not in state_dict_target:
            continue

        if param.size() == state_dict_target[name].size():
            state_dict_target[name].copy_(param)
        else:
            print(f"layer {name} ({param.size()}) is different than target: {state_dict_target[name].size()}")

    return state_dict_target

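# A small illustrative example of the partial restore above (hypothetical
# layers; only parameters whose names and shapes match are copied over):

source_sd = nn.Linear(8, 4).state_dict()
target_sd = nn.Linear(4, 4).state_dict()
merged_sd = restore_parts(target_sd, source_sd)  # 'bias' (4,) is copied; 'weight' shapes differ and are reported
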
# trainer class for image generation
class ImagenTrainer(nn.Module):
    locked = False

    def __init__(
        self,
        imagen = None,
        imagen_checkpoint_path = None,
        use_ema = True,
        lr = 1e-4,
        eps = 1e-8,
        beta1 = 0.9,
        beta2 = 0.99,
        max_grad_norm = None,
        group_wd_params = True,
        warmup_steps = None,
        cosine_decay_max_steps = None,
        only_train_unet_number = None,
        fp16 = False,
        precision = None,
        split_batches = True,
        dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'),
        verbose = True,
        split_valid_fraction = 0.025,
        split_valid_from_train = False,
        split_random_seed = 42,
        checkpoint_path = None,
        checkpoint_every = None,
        checkpoint_fs = None,
        fs_kwargs: dict = None,
        max_checkpoints_keep = 20,
        **kwargs
    ):
        ...  # (initializer body omitted from this excerpt)

    # prepare the trainer: assert it has not already been prepared,
    # lock in the unet number being trained, and mark it as prepared
    def prepare(self):
        assert not self.prepared, 'the trainer is already prepared'
        self.validate_and_set_unet_being_trained(self.only_train_unet_number)
        self.prepared = True

    # computed properties

    @property
    def device(self):
        return self.accelerator.device

    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    @property
    def unwrapped_unet(self):
        return self.accelerator.unwrap_model(self.unet_being_trained)

    # optimizer helper functions

    def get_lr(self, unet_number):
        self.validate_unet_number(unet_number)
        unet_index = unet_number - 1

        optim = getattr(self, f'optim{unet_index}')

        return optim.param_groups[0]['lr']

    # function for allowing only one unet to be trained at a time

    def validate_and_set_unet_being_trained(self, unet_number = None):
        if exists(unet_number):
            self.validate_unet_number(unet_number)

        assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you cannot only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet'

        self.only_train_unet_number = unet_number
        self.imagen.only_train_unet_number = unet_number

        if not exists(unet_number):
            return

        self.wrap_unet(unet_number)

    def wrap_unet(self, unet_number):
        if hasattr(self, 'one_unet_wrapped'):
            return

        unet = self.imagen.get_unet(unet_number)
        unet_index = unet_number - 1

        optimizer = getattr(self, f'optim{unet_index}')
        scheduler = getattr(self, f'scheduler{unet_index}')

        if self.train_dl:
            self.unet_being_trained, self.train_dl, optimizer = self.accelerator.prepare(unet, self.train_dl, optimizer)
        else:
            self.unet_being_trained, optimizer = self.accelerator.prepare(unet, optimizer)

        if exists(scheduler):
            scheduler = self.accelerator.prepare(scheduler)

        setattr(self, f'optim{unet_index}', optimizer)
        setattr(self, f'scheduler{unet_index}', scheduler)

        self.one_unet_wrapped = True

    # hack on the accelerator, since there is no gradient scaler per optimizer

    def set_accelerator_scaler(self, unet_number):
        def patch_optimizer_step(accelerated_optimizer, method):
            def patched_step(*args, **kwargs):
                accelerated_optimizer._accelerate_step_called = True
                return method(*args, **kwargs)
            return patched_step

        unet_number = self.validate_unet_number(unet_number)
        scaler = getattr(self, f'scaler{unet_number - 1}')

        self.accelerator.scaler = scaler
        for optimizer in self.accelerator._optimizers:
            optimizer.scaler = scaler
            optimizer._accelerate_step_called = False
            optimizer._optimizer_original_step_method = optimizer.optimizer.step
            optimizer._optimizer_patched_step_method = patch_optimizer_step(optimizer, optimizer.optimizer.step)

    # helper print function

    def print(self, msg):
        if not self.is_main:
            return

        if not self.verbose:
            return

        return self.accelerator.print(msg)

    # validate the unet number

    def validate_unet_number(self, unet_number = None):
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)

        assert 0 < unet_number <= self.num_unets, f'unet number should be in between 1 and {self.num_unets}'
        return unet_number

    # number of training steps taken

    # return the number of training steps taken for the given unet number
    def num_steps_taken(self, unet_number = None):
        # if there is only one unet, default to unet number 1
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)

        # return the step count for the specified unet
        return self.steps[unet_number - 1].item()

    # print which unets have not been trained yet
    def print_untrained_unets(self):
        print_final_error = False

        # walk the step counts and unets, checking for any that were never trained
        for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)):
            if steps > 0 or isinstance(unet, NullUnet):
                continue

            # report the untrained unet number
            self.print(f'unet {ind + 1} has not been trained')
            print_final_error = True

        # if any unet is untrained, print a hint
        if print_final_error:
            self.print('when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets')

    # data related functions

    # add the training dataloader
    def add_train_dataloader(self, dl = None):
        if not exists(dl):
            return

        # make sure a training dataloader has not already been added
        assert not exists(self.train_dl), 'training dataloader was already added'
        assert not self.prepared, 'you need to add the dataset before preparation'
        self.train_dl = dl

    # add the validation dataloader
    def add_valid_dataloader(self, dl):
        if not exists(dl):
            return

        # make sure a validation dataloader has not already been added
        assert not exists(self.valid_dl), 'validation dataloader was already added'
        assert not self.prepared, 'you need to add the dataset before preparation'
        self.valid_dl = dl

    # add the training dataset
    def add_train_dataset(self, ds = None, *, batch_size, **dl_kwargs):
        if not exists(ds):
            return

        # make sure a training dataloader has not already been added
        assert not exists(self.train_dl), 'training dataloader was already added'

        # optionally carve a validation set out of the training dataset
        valid_ds = None
        if self.split_valid_from_train:
            # compute the sizes of the training and validation splits
            train_size = int((1 - self.split_valid_fraction) * len(ds))
            valid_size = len(ds) - train_size

            # randomly split the dataset
            ds, valid_ds = random_split(ds, [train_size, valid_size], generator = torch.Generator().manual_seed(self.split_random_seed))
            self.print(f'training with dataset of {len(ds)} samples and validating with randomly split {len(valid_ds)} samples')

        # create the dataloader and register it as the training dataloader
        dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
        self.add_train_dataloader(dl)

        # if no validation split was requested, we are done
        if not self.split_valid_from_train:
            return

        # register the validation dataset
        self.add_valid_dataset(valid_ds, batch_size = batch_size, **dl_kwargs)

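    # Usage sketch (illustrative, not part of the source): with
    # `split_valid_from_train = True` on the trainer, a single dataset call
    # registers both loaders, e.g.
    #
    #   trainer = ImagenTrainer(imagen, split_valid_from_train = True)
    #   trainer.add_train_dataset(my_dataset, batch_size = 16, shuffle = True)
    #
    # `my_dataset` is any torch Dataset; extra keyword arguments such as
    # `shuffle` are forwarded straight to the DataLoader.
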
    # add the validation dataset
    def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs):
        if not exists(ds):
            return

        # make sure a validation dataloader has not already been added
        assert not exists(self.valid_dl), 'validation dataloader was already added'

        # create the dataloader and register it as the validation dataloader
        dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
        self.add_valid_dataloader(dl)

    # create the training data iterator
    def create_train_iter(self):
        assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet'

        if exists(self.train_dl_iter):
            return

        self.train_dl_iter = cycle(self.train_dl)

    # create the validation data iterator
    def create_valid_iter(self):
        assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet'

        if exists(self.valid_dl_iter):
            return

        self.valid_dl_iter = cycle(self.valid_dl)

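    # `cycle` is defined elsewhere in this file; a minimal sketch consistent
    # with its use here would be an infinite generator over the dataloader:
    #
    #   def cycle(dl):
    #       while True:
    #           for batch in dl:
    #               yield batch
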
    # training step
    def train_step(self, *, unet_number = None, **kwargs):
        if not self.prepared:
            self.prepare()
        self.create_train_iter()

        kwargs = {'unet_number': unet_number, **kwargs}
        loss = self.step_with_dl_iter(self.train_dl_iter, **kwargs)
        self.update(unet_number = unet_number)
        return loss

    # validation step
    @torch.no_grad()
    @eval_decorator
    def valid_step(self, **kwargs):
        if not self.prepared:
            self.prepare()
        self.create_valid_iter()
        context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext
        with context():
            loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs)
        return loss
    # pull the next tuple from the dataloader iterator and compute the loss
    def step_with_dl_iter(self, dl_iter, **kwargs):
        dl_tuple_output = cast_tuple(next(dl_iter))
        # zip the tuple into a keyword-argument dict for the model
        model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output)))
        # call forward to compute the loss
        loss = self.forward(**{**kwargs, **model_input})
        return loss

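    # Usage sketch (illustrative, not part of the source): a typical loop
    # alternates chunked training steps with occasional EMA validation, e.g.
    #
    #   for step in range(100000):
    #       loss = trainer.train_step(unet_number = 1)
    #       if step % 500 == 0:
    #           valid_loss = trainer.valid_step(unet_number = 1, use_ema_unets = True)
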
    # checkpointing functions

    # get all checkpoint files, sorted by total steps, newest first
    @property
    def all_checkpoints_sorted(self):
        glob_pattern = os.path.join(self.checkpoint_path, '*.pt')
        checkpoints = self.fs.glob(glob_pattern)
        sorted_checkpoints = sorted(checkpoints, key = lambda x: int(str(x).split('.')[-2]), reverse = True)
        return sorted_checkpoints

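    # The sort key in `all_checkpoints_sorted` above parses the total step
    # count embedded in the filename, so the newest checkpoint comes first:
    #
    #   int('checkpoint.12000.pt'.split('.')[-2])   # -> 12000
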
    # load the model from the checkpoint folder
    def load_from_checkpoint_folder(self, last_total_steps = -1):
        if last_total_steps != -1:
            filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt')
            self.load(filepath)
            return

        sorted_checkpoints = self.all_checkpoints_sorted

        if len(sorted_checkpoints) == 0:
            self.print(f'no checkpoints found to load from at {self.checkpoint_path}')
            return

        last_checkpoint = sorted_checkpoints[0]
        self.load(last_checkpoint)

    # save to the checkpoint folder
    def save_to_checkpoint_folder(self):
        self.accelerator.wait_for_everyone()

        if not self.can_checkpoint:
            return

        total_steps = int(self.steps.sum().item())
        filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt')

        self.save(filepath)

        if self.max_checkpoints_keep <= 0:
            return

        sorted_checkpoints = self.all_checkpoints_sorted
        checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:]

        for checkpoint in checkpoints_to_discard:
            self.fs.rm(checkpoint)

    # saving and loading functions

    # save the model to the given path
    def save(
        self,
        path,
        overwrite = True,
        without_optim_and_sched = False,
        **kwargs
    ):
        self.accelerator.wait_for_everyone()

        if not self.can_checkpoint:
            return

        fs = self.fs

        assert not (fs.exists(path) and not overwrite)

        self.reset_ema_unets_all_one_device()

        # build the object to be saved
        save_obj = dict(
            model = self.imagen.state_dict(),
            version = __version__,
            steps = self.steps.cpu(),
            **kwargs
        )

        save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple()

        # save the optimizer and scheduler states
        for ind in save_optim_and_sched_iter:
            scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}'
            scheduler_key = f'scheduler{ind}'
            warmup_scheduler_key = f'warmup{ind}'

            scaler = getattr(self, scaler_key)
            optimizer = getattr(self, optimizer_key)
            scheduler = getattr(self, scheduler_key)
            warmup_scheduler = getattr(self, warmup_scheduler_key)

            if exists(scheduler):
                save_obj = {**save_obj, scheduler_key: scheduler.state_dict()}

            if exists(warmup_scheduler):
                save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()}

            save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}

        if self.use_ema:
            save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}

        # determine whether the imagen config exists
        if hasattr(self.imagen, '_config'):
            self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"<prompt>"')
            save_obj = {
                **save_obj,
                'imagen_type': 'elucidated' if self.is_elucidated else 'original',
                'imagen_params': self.imagen._config
            }

        # write to the given path
        with fs.open(path, 'wb') as f:
            torch.save(save_obj, f)

        self.print(f'checkpoint saved to {path}')

    # load model parameters and optimizer state
    def load(self, path, only_model = False, strict = True, noop_if_not_exist = False):
        # get the filesystem object
        fs = self.fs

        # if the file does not exist and a no-op was requested, print a message and return
        if noop_if_not_exist and not fs.exists(path):
            self.print(f'trainer checkpoint not found at {str(path)}')
            return

        # assert that the file exists
        assert fs.exists(path), f'{path} does not exist'

        # reset all EMA models onto one device
        self.reset_ema_unets_all_one_device()

        # to avoid extra GPU memory usage in the main process when using Accelerate
        with fs.open(path) as f:
            # load the saved parameters and optimizer state onto the CPU
            loaded_obj = torch.load(f, map_location='cpu')

        # check whether the saved version matches the current package version
        if version.parse(__version__) != version.parse(loaded_obj['version']):
            self.print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}')

        try:
            # load the model parameters
            self.imagen.load_state_dict(loaded_obj['model'], strict = strict)
        except RuntimeError:
            print("Failed loading state dict. Trying partial load")
            # fall back to a partial load of the model parameters
            self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(),
                                                      loaded_obj['model']))

        # if only the model is wanted, return the loaded object here
        if only_model:
            return loaded_obj

        # copy over the loaded step counts
        self.steps.copy_(loaded_obj['steps'])

        # iterate over all unets
        for ind in range(0, self.num_unets):
            scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}'
            scheduler_key = f'scheduler{ind}'
            warmup_scheduler_key = f'warmup{ind}'

            # fetch the corresponding scaler, optimizer, scheduler and warmup scheduler
            scaler = getattr(self, scaler_key)
            optimizer = getattr(self, optimizer_key)
            scheduler = getattr(self, scheduler_key)
            warmup_scheduler = getattr(self, warmup_scheduler_key)

            # if the scheduler exists and has a saved state, load it
            if exists(scheduler) and scheduler_key in loaded_obj:
                scheduler.load_state_dict(loaded_obj[scheduler_key])

            # if the warmup scheduler exists and has a saved state, load it
            if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj:
                warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key])

            # if the optimizer exists, try to load its state along with the scaler's
            if exists(optimizer):
                try:
                    optimizer.load_state_dict(loaded_obj[optimizer_key])
                    scaler.load_state_dict(loaded_obj[scaler_key])
                except Exception:
                    self.print('could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. resuming with new optimizer and scalers')

        # if EMA is used, load the EMA model parameters
        if self.use_ema:
            assert 'ema' in loaded_obj
            try:
                self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
            except RuntimeError:
                print("Failed loading state dict. Trying partial load")
                self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(),
                                                             loaded_obj['ema']))

        # print a success message and return the loaded object
        self.print(f'checkpoint loaded from {path}')
        return loaded_obj

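    # Usage sketch (illustrative, not part of the source):
    #
    #   trainer.save('./checkpoints/imagen-10000.pt')
    #   ...
    #   trainer.load('./checkpoints/imagen-10000.pt')        # resumes steps, optimizers, EMA
    #   trainer.load('./pretrained.pt', only_model = True)   # weights only, keep fresh optimizers
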
    # expose all of the EMA models as a module list
    @property
    def unets(self):
        return nn.ModuleList([ema.ema_model for ema in self.ema_unets])

    # fetch the EMA model for the given unet number
    def get_ema_unet(self, unet_number = None):
        # if EMA is not used, return early
        if not self.use_ema:
            return

        # validate and resolve the unet number
        unet_number = self.validate_unet_number(unet_number)
        index = unet_number - 1

        # if the unets are an nn.ModuleList, unwrap ema_unets into a plain list
        # so each EMA unet can be moved between devices independently
        if isinstance(self.unets, nn.ModuleList):
            unets_list = [unet for unet in self.ema_unets]
            delattr(self, 'ema_unets')
            self.ema_unets = unets_list

        # move the EMA unet being trained onto the device, the rest onto cpu
        if index != self.ema_unet_being_trained_index:
            for unet_index, unet in enumerate(self.ema_unets):
                unet.to(self.device if unet_index == index else 'cpu')

        # record the index of the EMA unet being trained and return it
        self.ema_unet_being_trained_index = index
        return self.ema_unets[index]

    # reset all EMA models onto the given device
    def reset_ema_unets_all_one_device(self, device = None):
        # if EMA is not used, return early
        if not self.use_ema:
            return

        # fall back to the default device
        device = default(device, self.device)
        # move all EMA models onto the one device
        self.ema_unets = nn.ModuleList([*self.ema_unets])
        self.ema_unets.to(device)

        # reset the index of the EMA unet being trained
        self.ema_unet_being_trained_index = -1

    # disable gradient computation
    @torch.no_grad()
    # context manager that temporarily swaps the EMA unets in for sampling
    @contextmanager
    def use_ema_unets(self):
        # if EMA models are not used, just yield and return
        if not self.use_ema:
            output = yield
            return output

        # reset all unets and EMA unets onto one device
        self.reset_ema_unets_all_one_device()
        self.imagen.reset_unets_all_one_device()

        # put the EMA unets into evaluation mode
        self.unets.eval()

        # stash the trainable unets, then substitute the EMA unets for sampling
        trainable_unets = self.imagen.unets
        self.imagen.unets = self.unets

        output = yield

        # restore the original trainable unets
        self.imagen.unets = trainable_unets

        # move the EMA unets back to their original devices
        for ema in self.ema_unets:
            ema.restore_ema_model_device()

        return output

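    # Usage sketch (illustrative, not part of the source): sample directly from
    # the EMA weights without going through trainer.sample, e.g.
    #
    #   with trainer.use_ema_unets():
    #       images = trainer.imagen.sample(texts = ['a red balloon'])
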
    # print the devices of the unets
    def print_unet_devices(self):
        self.print('unet devices:')
        for i, unet in enumerate(self.imagen.unets):
            device = next(unet.parameters()).device
            self.print(f'\tunet {i}: {device}')

        # if EMA models are not used, return early
        if not self.use_ema:
            return

        self.print('\nema unet devices:')
        for i, ema_unet in enumerate(self.ema_unets):
            device = next(ema_unet.parameters()).device
            self.print(f'\tema unet {i}: {device}')

    # overriding state dict functions

    def state_dict(self, *args, **kwargs):
        # gather all EMA unets onto one device before saving
        self.reset_ema_unets_all_one_device()
        return super().state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        # gather all EMA unets onto one device before loading
        self.reset_ema_unets_all_one_device()
        return super().load_state_dict(*args, **kwargs)

    # encoding text function

    def encode_text(self, text, **kwargs):
        return self.imagen.encode_text(text, **kwargs)

    # forward function and gradient update step

    def update(self, unet_number = None):
        unet_number = self.validate_unet_number(unet_number)
        self.validate_and_set_unet_being_trained(unet_number)
        self.set_accelerator_scaler(unet_number)

        index = unet_number - 1
        unet = self.unet_being_trained

        optimizer = getattr(self, f'optim{index}')
        scaler = getattr(self, f'scaler{index}')
        scheduler = getattr(self, f'scheduler{index}')
        warmup_scheduler = getattr(self, f'warmup{index}')

        # set the grad scaler on the accelerator, since we are managing one per unet

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm)

        optimizer.step()
        optimizer.zero_grad()

        if self.use_ema:
            ema_unet = self.get_ema_unet(unet_number)
            ema_unet.update()

        # scheduler, if needed

        maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening()

        with maybe_warmup_context:
            if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped: # recommended in the docs
                scheduler.step()

        self.steps += F.one_hot(torch.tensor(unet_number - 1, device = self.steps.device), num_classes = len(self.steps))

        if not exists(self.checkpoint_path):
            return

        total_steps = int(self.steps.sum().item())

        if total_steps % self.checkpoint_every:
            return

        self.save_to_checkpoint_folder()

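    # The per-unet step counters above are bumped with a one-hot vector, e.g.
    # with three unets and unet_number = 2:
    #
    #   steps = torch.tensor([10, 4, 0])
    #   steps + F.one_hot(torch.tensor(1), num_classes = 3)   # -> tensor([10, 5, 0])
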
    @torch.no_grad()
    @cast_torch_tensor
    @imagen_sample_in_chunks
    def sample(self, *args, **kwargs):
        context = nullcontext if kwargs.pop('use_non_ema', False) else self.use_ema_unets

        self.print_untrained_unets()

        if not self.is_main:
            kwargs['use_tqdm'] = False

        with context():
            output = self.imagen.sample(*args, device = self.device, **kwargs)

        return output

    @partial(cast_torch_tensor, cast_fp16 = True)
    def forward(
        self,
        *args,
        unet_number = None,
        max_batch_size = None,
        **kwargs
    ):
        # validate and resolve the unet number
        unet_number = self.validate_unet_number(unet_number)
        # validate and set the unet being trained
        self.validate_and_set_unet_being_trained(unet_number)
        # set the accelerator's gradient scaler for this unet
        self.set_accelerator_scaler(unet_number)

        # assert that either no unet is locked in for training, or it matches this one
        assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}'

        # initialize the total loss
        total_loss = 0.

        # split the arguments and keyword arguments by the maximum batch size
        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
            # use the accelerator's autocast
            with self.accelerator.autocast():
                # compute the loss
                loss = self.imagen(*chunked_args, unet = self.unet_being_trained, unet_number = unet_number, **chunked_kwargs)
                # weight the loss by the chunk's fraction of the batch
                loss = loss * chunk_size_frac

            # accumulate the total loss
            total_loss += loss.item()

            # if in training mode, backpropagate
            if self.training:
                self.accelerator.backward(loss)

        # return the total loss
        return total_loss
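
# Usage sketch (illustrative, not part of the source; mirrors the pattern in
# the project README): a full optimization step pairs forward, which
# accumulates scaled gradients chunked by max_batch_size, with update, which
# steps the optimizer, EMA and schedulers:
#
#   loss = trainer(images, texts = texts, unet_number = 1, max_batch_size = 4)
#   trainer.update(unet_number = 1)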