
Lucidrains Series Projects: Source Code Walkthrough (Part 44)

.\lucidrains\hourglass-transformer-pytorch\hourglass_transformer_pytorch\__init__.py

# import the HourglassTransformerLM and HourglassTransformer classes from the hourglass_transformer_pytorch.hourglass_transformer_pytorch module
from hourglass_transformer_pytorch.hourglass_transformer_pytorch import HourglassTransformerLM, HourglassTransformer

Hourglass Transformer - Pytorch

Implementation of Hourglass Transformer, in Pytorch.

Install

$ pip install hourglass-transformer-pytorch

Usage

import torch
from hourglass_transformer_pytorch import HourglassTransformerLM

model = HourglassTransformerLM(
    num_tokens = 256,               # number of tokens
    dim = 512,                      # feature dimension
    max_seq_len = 1024,             # maximum sequence length
    heads = 8,                      # attention heads
    dim_head = 64,                  # dimension per attention head
    shorten_factor = 2,             # shortening factor
    depth = (4, 2, 4),              # tuple of 3, standing for pre-transformer-layers, valley-transformer-layers (after downsample), post-transformer-layers (after upsample) - the valley transformer layers can be yet another nested tuple, in which case it will shorten again recursively
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x) # (1, 1024, 256)

For something more sophisticated, two hourglasses, with one nested within the other

import torch
from hourglass_transformer_pytorch import HourglassTransformerLM

model = HourglassTransformerLM(
    num_tokens = 256,
    dim = 512,
    max_seq_len = 1024,
    shorten_factor = (2, 4),     # 2x for first hour glass, 4x for second
    depth = (4, (2, 1, 2), 3),   # 4@1 -> 2@2 -> 1@4 -> 2@2 -> 3@1
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x)

Funnel Transformer would be approximately

import torch
from hourglass_transformer_pytorch import HourglassTransformerLM

model = HourglassTransformerLM(
    num_tokens = 20000,
    dim = 512,
    max_seq_len = 1024,
    causal = False,
    attn_resampling = False,
    shorten_factor = 2,
    depth = (2, (2, (2, 2, 2), 2), 2)
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)

For images, instead of average pool and repeat for the down and upsampling functions, they found that linear projections worked a lot better. You can use this by setting updown_sample_type = 'linear'

import torch
from hourglass_transformer_pytorch import HourglassTransformer

model = HourglassTransformer(
    dim = 512,
    shorten_factor = 2,
    depth = (4, 2, 4),
    updown_sample_type = 'linear'
)

img_tokens = torch.randn(1, 1024, 512)
model(img_tokens) # (1, 1024, 512)

Although results were not presented in the paper, you can also use the Hourglass Transformer in this repository non-autoregressively.

import torch
from hourglass_transformer_pytorch import HourglassTransformerLM

model = HourglassTransformerLM(
    num_tokens = 20000,
    dim = 512,
    max_seq_len = 1024,
    shorten_factor = 2,
    depth = (4, 2, 4),
    causal = False          # set this to False
)

x = torch.randint(0, 256, (1, 1024))
mask = torch.ones((1, 1024)).bool()

logits = model(x, mask = mask) # (1, 1024, 20000)

Enwik8 autoregressive example

$ python train.py

Todo

  • work with non-autoregressive, accounting for masking
  • account for masking for attention resampling
  • account for shift padding when naive downsampling

Citations

@misc{nawrot2021hierarchical,
    title   = {Hierarchical Transformers Are More Efficient Language Models}, 
    author  = {Piotr Nawrot and Szymon Tworkowski and Michał Tyrolski and Łukasz Kaiser and Yuhuai Wu and Christian Szegedy and Henryk Michalewski},
    year    = {2021},
    eprint  = {2110.13711},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

.\lucidrains\hourglass-transformer-pytorch\setup.py

# import setuptools' setup and find_packages helpers
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'hourglass-transformer-pytorch', # package name
  packages = find_packages(), # find all packages
  version = '0.0.6', # version number
  license='MIT', # license
  description = 'Hourglass Transformer', # description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author email
  url = 'https://github.com/lucidrains/hourglass-transformer-pytorch', # project url
  keywords = [ # keywords
    'artificial intelligence',
    'attention mechanism',
    'transformers'
  ],
  install_requires=[ # install dependencies
    'einops',
    'torch>=1.6'
  ],
  classifiers=[ # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\hourglass-transformer-pytorch\train.py

# import the required modules and classes
from hourglass_transformer_pytorch import HourglassTransformerLM
from hourglass_transformer_pytorch.autoregressive_wrapper import AutoregressiveWrapper

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 512
SEQ_LEN = 512

# helper functions

# cycle endlessly through a dataloader (used by the train / val loaders below)
def cycle(loader):
    while True:
        for data in loader:
            yield data

# decode a single token (a byte value) back into a character
def decode_token(token):
    return str(chr(max(32, token)))

# decode a sequence of tokens back into a string
def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))
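
As a quick illustration (not part of the original script), these decoders simply map enwik8 byte values back to text:

decode_tokens([72, 101, 108, 108, 111])   # 'Hello'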

# instantiate the GPT-like decoder model

model = HourglassTransformerLM(
    num_tokens = 256,
    dim = 512,
    max_seq_len = SEQ_LEN,
    depth = (4, 2, 4),
    shorten_factor = 2,
    heads = 8
)

model = AutoregressiveWrapper(model)
model.cuda()

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

# dataset class
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# create dataloaders for the training and validation sets
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str)

.\lucidrains\HTM-pytorch\htm_pytorch\htm_pytorch.py

# import ceil from the math module
from math import ceil
# import torch
import torch
# import nn and einsum from torch
from torch import nn, einsum
# import torch.nn.functional as F
import torch.nn.functional as F
# import rearrange and repeat from einops
from einops import rearrange, repeat

# helpers

# exists - check whether a value is not None
def exists(val):
    return val is not None

# default - return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# pad_to_multiple - left-pad a tensor along a given dimension so its length becomes a multiple of `multiple`
def pad_to_multiple(t, multiple, dim = -2, value = 0.):
    seq_len = t.shape[dim]
    pad_to_len = ceil(seq_len / multiple) * multiple
    remainder = pad_to_len - seq_len

    if remainder == 0:
        return t

    zeroes = (0, 0) * (-dim - 1)
    padded_t = F.pad(t, (*zeroes, remainder, 0), value = value)
    return padded_t
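
A minimal sketch of what this helper does (illustrative values, not from the source): a 10-step memory sequence padded up to a chunk size of 32 gains 22 zero rows on the left.

t = torch.randn(2, 10, 512)                  # (batch, seq, dim)
padded = pad_to_multiple(t, 32, dim = -2)    # left-pads the sequence dimension with zeros
print(padded.shape)                          # torch.Size([2, 32, 512])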

# positional encoding

# SinusoidalPosition - generates sinusoidal positional encodings
class SinusoidalPosition(nn.Module):
    def __init__(
        self,
        dim,
        min_timescale = 2.,
        max_timescale = 1e4
    ):
        super().__init__()
        freqs = torch.arange(0, dim, min_timescale)
        inv_freqs = max_timescale ** (-freqs / dim)
        self.register_buffer('inv_freqs', inv_freqs)

    def forward(self, x):
        seq_len = x.shape[-2]
        seq = torch.arange(seq_len - 1, -1, -1.)
        sinusoidal_inp = rearrange(seq, 'n -> n ()') * rearrange(self.inv_freqs, 'd -> () d')
        pos_emb = torch.cat((sinusoidal_inp.sin(), sinusoidal_inp.cos()), dim = -1)
        return pos_emb
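
For reference (shapes assumed for illustration), the module only uses the sequence length of its input and returns one embedding per position, counted backwards from the end of the sequence:

pos = SinusoidalPosition(dim = 512)
emb = pos(torch.randn(1, 1024, 512))
print(emb.shape)   # torch.Size([1024, 512])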

# multi-head attention

# Attention - multi-head cross-attention (queries attend to memories)
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(
        self,
        x,
        mems,
        mask = None
    ):
        h = self.heads
        q, k, v = self.to_q(x), *self.to_kv(mems).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b ... (h d) -> (b h) ... d', h = h), (q, k, v))
        q = q * self.scale

        sim = einsum('b m i d, b m i j d -> b m i j', q, k)

        if exists(mask):
            mask = repeat(mask, 'b ... -> (b h) ...', h = h)
            mask_value = -torch.finfo(sim.dtype).max
            sim = sim.masked_fill(~mask, mask_value)

        attn = sim.softmax(dim = -1)

        out = einsum('... i j, ... i j d -> ... i d', attn, v)
        out = rearrange(out, '(b h) ... d -> b ... (h d)', h = h)
        return self.to_out(out)

# main class

# HTMAttention - the main hierarchical transformer memory attention module
class HTMAttention(nn.Module):
    def __init__(
        self,
        dim,
        heads,
        topk_mems = 2,
        mem_chunk_size = 32,
        dim_head = 64,
        add_pos_enc = True,
        eps = 1e-5
    ):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.scale = dim ** -0.5

        self.to_summary_queries = nn.Linear(dim, dim)
        self.to_summary_keys = nn.Linear(dim, dim)

        self.attn = Attention(dim = dim, heads = heads, dim_head = dim_head)

        self.topk_mems = topk_mems
        self.mem_chunk_size = mem_chunk_size
        self.pos_emb = SinusoidalPosition(dim = dim) if add_pos_enc else None

    def forward(
        self,
        queries,
        memories,
        mask = None,
        chunk_attn_mask = None
    ):
        # unpack parameters
        dim, query_len, mem_chunk_size, topk_mems, scale, eps = self.dim, queries.shape[1], self.mem_chunk_size, self.topk_mems, self.scale, self.eps

        # pad the memories (and the memory mask, if given) so they can be split into chunks

        memories = pad_to_multiple(memories, mem_chunk_size, dim = -2, value = 0.)
        memories = rearrange(memories, 'b (n c) d -> b n c d', c = mem_chunk_size)

        if exists(mask):
            mask = pad_to_multiple(mask, mem_chunk_size, dim = -1, value = False)
            mask = rearrange(mask, 'b (n c) -> b n c', c = mem_chunk_size)

        # summarize the memories via mean pooling, taking the mask into account

        if exists(mask):
            mean_mask = rearrange(mask, '... -> ... ()')
            memories = memories.masked_fill(~mean_mask, 0.)
            numer = memories.sum(dim = 2)
            denom = mean_mask.sum(dim = 2)
            summarized_memories = numer / (denom + eps)
        else:
            summarized_memories = memories.mean(dim = 2)

        # derive the summary queries and the keys of the summarized memories

        summary_queries = self.to_summary_queries(queries)
        summary_keys = self.to_summary_keys(summarized_memories.detach())

        # single-headed attention over the summarized keys

        sim = einsum('b i d, b j d -> b i j', summary_queries, summary_keys) * scale
        mask_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            chunk_mask = mask.any(dim = 2)
            chunk_mask = rearrange(chunk_mask, 'b j -> b () j')
            sim = sim.masked_fill(~chunk_mask, mask_value)

        if exists(chunk_attn_mask):
            sim = sim.masked_fill(~chunk_attn_mask, mask_value)

        topk_logits, topk_indices = sim.topk(k = topk_mems, dim = -1)
        weights = topk_logits.softmax(dim = -1)

        # prepare the queries for within-memory attention

        queries = repeat(queries, 'b n d -> b k n d', k = topk_mems)

        # select the top-k memory chunks

        memories = repeat(memories, 'b m j d -> b m i j d', i = query_len)
        mem_topk_indices = repeat(topk_indices, 'b i m -> b m i j d', j = mem_chunk_size, d = dim)
        selected_memories = memories.gather(1, mem_topk_indices)

        # positional encoding

        if exists(self.pos_emb):
            pos_emb = self.pos_emb(memories)
            selected_memories = selected_memories + rearrange(pos_emb, 'n d -> () () () n d')

        # gather the mask for the selected memories

        selected_mask = None
        if exists(mask):
            mask = repeat(mask, 'b m j -> b m i j', i = query_len)
            mask_topk_indices = repeat(topk_indices, 'b i m -> b m i j', j = mem_chunk_size)
            selected_mask = mask.gather(1, mask_topk_indices)

        # now attend within the selected memories

        within_mem_output = self.attn(
            queries,
            selected_memories.detach(),
            mask = selected_mask
        )

        # weight the within-memory attention outputs by the chunk attention weights

        weighted_output = within_mem_output * rearrange(weights, 'b i m -> b m i ()')
        output = weighted_output.sum(dim = 1)
        return output
# HTMBlock - HTMAttention wrapped with a pre-layernorm and a residual connection
class HTMBlock(nn.Module):
    # takes the feature dimension plus any HTMAttention keyword arguments
    def __init__(self, dim, **kwargs):
        super().__init__()
        # layernorm applied to the queries
        self.norm = nn.LayerNorm(dim)
        # the HTM attention layer
        self.attn = HTMAttention(dim=dim, **kwargs)

    # forward pass takes the queries and the memories, plus any extra keyword arguments
    def forward(
        self,
        queries,
        memories,
        **kwargs
    ):
        # normalize the queries
        queries = self.norm(queries)
        # attend over the memories, then add the residual
        out = self.attn(queries, memories, **kwargs) + queries
        # return the result
        return out

.\lucidrains\HTM-pytorch\htm_pytorch\__init__.py

# import the HTMAttention and HTMBlock classes from the htm_pytorch package
from htm_pytorch.htm_pytorch import HTMAttention, HTMBlock

Hierarchical Transformer Memory (HTM) - Pytorch

Implementation of Hierarchical Transformer Memory (HTM) for Pytorch. This Deepmind paper proposes a simple method to allow transformers to attend to memories of the past efficiently. Original Jax repository

Install

$ pip install htm-pytorch

Usage

import torch
from htm_pytorch import HTMAttention

attn = HTMAttention(
    dim = 512,
    heads = 8,               # number of heads for within-memory attention
    dim_head = 64,           # dimension per head for within-memory attention
    topk_mems = 8,           # how many memory chunks to select for
    mem_chunk_size = 32,     # number of tokens in each memory chunk
    add_pos_enc = True       # whether to add positional encoding to the memories
)

queries = torch.randn(1, 128, 512)     # queries
memories = torch.randn(1, 20000, 512)  # memories, of any size
mask = torch.ones(1, 20000).bool()     # memory mask

attended = attn(queries, memories, mask = mask) # (1, 128, 512)

If you want the entire HTM Block (which contains the layernorm for the input followed by a skip connection), just import HTMBlock instead

import torch
from htm_pytorch import HTMBlock

block = HTMBlock(
    dim = 512,
    topk_mems = 8,
    mem_chunk_size = 32
)

queries = torch.randn(1, 128, 512)
memories = torch.randn(1, 20000, 512)
mask = torch.ones(1, 20000).bool()

out = block(queries, memories, mask = mask) # (1, 128, 512)

Citations

@misc{lampinen2021mental,
    title   = {Towards mental time travel: a hierarchical memory for reinforcement learning agents}, 
    author  = {Andrew Kyle Lampinen and Stephanie C. Y. Chan and Andrea Banino and Felix Hill},
    year    = {2021},
    eprint  = {2105.14039},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

.\lucidrains\HTM-pytorch\setup.py

# import setuptools' setup and find_packages helpers
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'htm-pytorch',  # package name
  packages = find_packages(),  # find all packages
  version = '0.0.4',  # version number
  license='MIT',  # license
  description = 'Hierarchical Transformer Memory - Pytorch',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/htm-pytorch',  # project url
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'attention-mechanism',
    'memory'
  ],
  install_requires=[  # install dependencies
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\imagen-pytorch\imagen_pytorch\cli.py

import click
import torch
from pathlib import Path
import pkgutil

from imagen_pytorch import load_imagen_from_checkpoint
from imagen_pytorch.version import __version__
from imagen_pytorch.data import Collator
from imagen_pytorch.utils import safeget
from imagen_pytorch import ImagenTrainer, ElucidatedImagenConfig, ImagenConfig
from datasets import load_dataset, concatenate_datasets
from tqdm import tqdm
import json

# check whether a value exists
def exists(val):
    return val is not None

# simple slugify - replace special characters with underscores and truncate to a maximum length
def simple_slugify(text: str, max_length = 255):
    return text.replace('-', '_').replace(',', '').replace(' ', '_').replace('|', '--').strip('-_./\\')[:max_length]
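
For example (hypothetical prompt), the slugifier turns a sampling prompt into a safe filename stem:

simple_slugify('a red apple, hanging from a tree')   # 'a_red_apple_hanging_from_a_tree'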

# main function
def main():
    pass

# create a click command group
@click.group()
def imagen():
    pass

# command: sample from an Imagen model checkpoint
@imagen.command(help = 'Sample from the Imagen model checkpoint')
@click.option('--model', default = './imagen.pt', help = 'path to trained Imagen model')
@click.option('--cond_scale', default = 5, help = 'conditioning scale (classifier free guidance) in decoder')
@click.option('--load_ema', default = True, help = 'load EMA version of unets if available')
@click.argument('text')
def sample(
    model,
    cond_scale,
    load_ema,
    text
):
    model_path = Path(model)
    full_model_path = str(model_path.resolve())
    assert model_path.exists(), f'model not found at {full_model_path}'
    loaded = torch.load(str(model_path))

    # get the version the checkpoint was saved with
    version = safeget(loaded, 'version')
    print(f'loading Imagen from {full_model_path}, saved at version {version} - current package version is {__version__}')

    # load the Imagen model (with its parameters and type) from the checkpoint
    imagen = load_imagen_from_checkpoint(str(model_path), load_ema_if_available = load_ema)
    imagen.cuda()

    # generate the image
    pil_image = imagen.sample([text], cond_scale = cond_scale, return_pil_images = True)

    image_path = f'./{simple_slugify(text)}.png'
    pil_image[0].save(image_path)

    print(f'image saved to {str(image_path)}')
    return

# command: generate a default config for the Imagen model
@imagen.command(help = 'Generate a config for the Imagen model')
@click.option('--path', default = './imagen_config.json', help = 'Path to the Imagen model config')
def config(
    path
):
    data = pkgutil.get_data(__name__, 'default_config.json').decode("utf-8") 
    with open(path, 'w') as f:
        f.write(data)

# command: train the Imagen model
@imagen.command(help = 'Train the Imagen model')
@click.option('--config', default = './imagen_config.json', help = 'Path to the Imagen model config')
@click.option('--unet', default = 1, help = 'Unet to train', type = click.IntRange(1, 3, False, True, True))
@click.option('--epoches', default = 50, help = 'Amount of epoches to train for')
def train(
    config,
    unet,
    epoches,
):
    # check the config file path
    config_path = Path(config)
    full_config_path = str(config_path.resolve())
    assert config_path.exists(), f'config not found at {full_config_path}'
    
    with open(config_path, 'r') as f:
        config_data = json.loads(f.read())

    assert 'checkpoint_path' in config_data, 'checkpoint path not found in config'
    
    model_path = Path(config_data['checkpoint_path'])
    full_model_path = str(model_path.resolve())
    
    # build the Imagen model from the config
    imagen_config_klass = ElucidatedImagenConfig if config_data['type'] == 'elucidated' else ImagenConfig
    imagen = imagen_config_klass(**config_data['imagen']).create()

    trainer = ImagenTrainer(
        imagen = imagen,
        **config_data['trainer']
    )

    # load a saved checkpoint if one exists
    if model_path.exists():
        loaded = torch.load(str(model_path))
        version = safeget(loaded, 'version')
        print(f'loading Imagen from {full_model_path}, saved at version {version} - current package version is {__version__}')
        trainer.load(model_path)
        
    if torch.cuda.is_available():
        trainer = trainer.cuda()

    size = config_data['imagen']['image_sizes'][unet-1]

    max_batch_size = config_data['max_batch_size'] if 'max_batch_size' in config_data else 1

    channels = 'RGB'
    # check whether the config specifies the number of channels
    if 'channels' in config_data['imagen']:
        # the channel count must be between 1 and 4
        assert config_data['imagen']['channels'] > 0 and config_data['imagen']['channels'] < 5, 'Imagen only support 1 to 4 channels L, LA, RGB, RGBA'
        # map the channel count to a PIL image mode
        if config_data['imagen']['channels'] == 4:
            channels = 'RGBA' # Color with alpha
        elif config_data['imagen']['channels'] == 2:
            channels = 'LA' # Luminance (Greyscale) with alpha
        elif config_data['imagen']['channels'] == 1:
            channels = 'L' # Luminance (Greyscale)

    # a batch_size must be provided in the dataset section of the config
    assert 'batch_size' in config_data['dataset'], 'A batch_size is required in the config file'
    
    # load the dataset (train and validation splits)
    ds = load_dataset(config_data['dataset_name'])
    
    train_ds = None
    
    # if both train and valid splits exist, concatenate them and let the trainer handle the split
    if 'train' in ds and 'valid' in ds:
        train_ds = concatenate_datasets([ds['train'], ds['valid']])
    elif 'train' in ds:
        train_ds = ds['train']
    elif 'valid' in ds:
        train_ds = ds['valid']
    else:
        train_ds = ds
        
    # make sure a training dataset was found
    assert train_ds is not None, 'No train dataset could be fetched from the dataset name provided'
    
    # register the training dataset with the trainer
    trainer.add_train_dataset(
        ds = train_ds,
        collate_fn = Collator(
            image_size = size,
            image_label = config_data['image_label'],
            text_label = config_data['text_label'],
            url_label = config_data['url_label'],
            name = imagen.text_encoder_name,
            channels = channels
        ),
        **config_data['dataset']
    )
    
    # determine whether validation, sampling and saving are enabled
    should_validate = trainer.split_valid_from_train and 'validate_at_every' in config_data
    should_sample = 'sample_texts' in config_data and 'sample_at_every' in config_data
    should_save = 'save_at_every' in config_data
    
    # read the validation / sampling / saving frequencies from the config
    valid_at_every = config_data['validate_at_every'] if should_validate else 0
    assert isinstance(valid_at_every, int), 'validate_at_every must be an integer'
    sample_at_every = config_data['sample_at_every'] if should_sample else 0
    assert isinstance(sample_at_every, int), 'sample_at_every must be an integer'
    save_at_every = config_data['save_at_every'] if should_save else 0
    assert isinstance(save_at_every, int), 'save_at_every must be an integer'
    sample_texts = config_data['sample_texts'] if should_sample else []
    assert isinstance(sample_texts, list), 'sample_texts must be a list'
    
    # when sampling is enabled, sample_texts must not be empty
    assert not should_sample or len(sample_texts) > 0, 'sample_texts must not be empty when sample_at_every is set'
    
    # training loop
    for i in range(epoches):
        for _ in tqdm(range(len(trainer.train_dl))):
            # run one training step and report the loss
            loss = trainer.train_step(unet_number = unet, max_batch_size = max_batch_size)
            print(f'loss: {loss}')

        # validate at the configured frequency
        if not (i % valid_at_every) and i > 0 and trainer.is_main and should_validate:
            valid_loss = trainer.valid_step(unet_number = unet, max_batch_size = max_batch_size)
            print(f'valid loss: {valid_loss}')

        # sample images at the configured frequency and save them
        if not (i % save_at_every) and i > 0 and trainer.is_main and should_sample:
            images = trainer.sample(texts = [sample_texts], batch_size = 1, return_pil_images = True, stop_at_unet_number = unet)
            images[0].save(f'./sample-{i // 100}.png')
            
        # save the model at the configured frequency
        if not (i % save_at_every) and i > 0 and trainer.is_main and should_save:
            trainer.save(model_path)

    # final save
    trainer.save(model_path)

.\lucidrains\imagen-pytorch\imagen_pytorch\configs.py

# import the required modules and classes
from pydantic import BaseModel, model_validator
from typing import List, Optional, Union, Tuple
from enum import Enum

# import classes and functions from the package's own modules
from imagen_pytorch.imagen_pytorch import Imagen, Unet, Unet3D, NullUnet
from imagen_pytorch.trainer import ImagenTrainer
from imagen_pytorch.elucidated_imagen import ElucidatedImagen
from imagen_pytorch.t5 import DEFAULT_T5_NAME, get_encoded_dim

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# a pydantic type meaning 'a list or tuple of the inner type'
def ListOrTuple(inner_type):
    return Union[List[inner_type], Tuple[inner_type]]

# a pydantic type meaning 'a single value, or a list/tuple of the inner type'
def SingleOrList(inner_type):
    return Union[inner_type, ListOrTuple(inner_type)]

# noise schedules

# enum of supported noise schedule types
class NoiseSchedule(Enum):
    cosine = 'cosine'
    linear = 'linear'

# base model that allows extra fields
class AllowExtraBaseModel(BaseModel):
    class Config:
        extra = "allow"
        use_enum_values = True

# imagen pydantic classes

# null Unet config
class NullUnetConfig(BaseModel):
    is_null:            bool

    def create(self):
        return NullUnet()

# Unet config
class UnetConfig(AllowExtraBaseModel):
    dim:                int
    dim_mults:          ListOrTuple(int)
    text_embed_dim:     int = get_encoded_dim(DEFAULT_T5_NAME)
    cond_dim:           Optional[int] = None
    channels:           int = 3
    attn_dim_head:      int = 32
    attn_heads:         int = 16

    def create(self):
        return Unet(**self.dict())

# Unet3D config
class Unet3DConfig(AllowExtraBaseModel):
    dim:                int
    dim_mults:          ListOrTuple(int)
    text_embed_dim:     int = get_encoded_dim(DEFAULT_T5_NAME)
    cond_dim:           Optional[int] = None
    channels:           int = 3
    attn_dim_head:      int = 32
    attn_heads:         int = 16

    def create(self):
        return Unet3D(**self.dict())

# Imagen config
class ImagenConfig(AllowExtraBaseModel):
    unets:                  ListOrTuple(Union[UnetConfig, Unet3DConfig, NullUnetConfig])
    image_sizes:            ListOrTuple(int)
    video:                  bool = False
    timesteps:              SingleOrList(int) = 1000
    noise_schedules:        SingleOrList(NoiseSchedule) = 'cosine'
    text_encoder_name:      str = DEFAULT_T5_NAME
    channels:               int = 3
    loss_type:              str = 'l2'
    cond_drop_prob:         float = 0.5

    @model_validator(mode="after")
    def check_image_sizes(self):
        if len(self.image_sizes) != len(self.unets):
            raise ValueError(f'image sizes length {len(self.image_sizes)} must be equivalent to the number of unets {len(self.unets)}')
        return self

    def create(self):
        decoder_kwargs = self.dict()
        unets_kwargs = decoder_kwargs.pop('unets')
        is_video = decoder_kwargs.pop('video', False)

        unets = []

        for unet, unet_kwargs in zip(self.unets, unets_kwargs):
            if isinstance(unet, NullUnetConfig):
                unet_klass = NullUnet
            elif is_video:
                unet_klass = Unet3D
            else:
                unet_klass = Unet

            unets.append(unet_klass(**unet_kwargs))

        imagen = Imagen(unets, **decoder_kwargs)

        imagen._config = self.dict().copy()
        return imagen

# ElucidatedImagen config
class ElucidatedImagenConfig(AllowExtraBaseModel):
    unets:                  ListOrTuple(Union[UnetConfig, Unet3DConfig, NullUnetConfig])
    image_sizes:            ListOrTuple(int)
    video:                  bool = False
    text_encoder_name:      str = DEFAULT_T5_NAME
    channels:               int = 3
    cond_drop_prob:         float = 0.5
    num_sample_steps:       SingleOrList(int) = 32
    sigma_min:              SingleOrList(float) = 0.002
    sigma_max:              SingleOrList(int) = 80
    sigma_data:             SingleOrList(float) = 0.5
    rho:                    SingleOrList(int) = 7
    P_mean:                 SingleOrList(float) = -1.2
    P_std:                  SingleOrList(float) = 1.2
    S_churn:                SingleOrList(int) = 80
    S_tmin:                 SingleOrList(float) = 0.05
    S_tmax:                 SingleOrList(int) = 50       # int or list of ints, default 50
    S_noise:                SingleOrList(float) = 1.003  # float or list of floats, default 1.003

    # validator run after model construction
    @model_validator(mode="after")
    def check_image_sizes(self):
        # the number of image sizes must match the number of unets
        if len(self.image_sizes) != len(self.unets):
            raise ValueError(f'image sizes length {len(self.image_sizes)} must be equivalent to the number of unets {len(self.unets)}')
        return self

    def create(self):
        # dump the config to a dict
        decoder_kwargs = self.dict()
        # pop the per-unet configs
        unets_kwargs = decoder_kwargs.pop('unets')
        # pop the video flag, defaulting to False
        is_video = decoder_kwargs.pop('video', False)

        # choose Unet3D or Unet depending on whether this is a video model
        unet_klass = Unet3D if is_video else Unet

        unets = []

        # iterate over the unet configs, choose the right Unet class and instantiate it
        for unet, unet_kwargs in zip(self.unets, unets_kwargs):
            if isinstance(unet, NullUnetConfig):
                unet_klass = NullUnet
            elif is_video:
                unet_klass = Unet3D
            else:
                unet_klass = Unet

            unets.append(unet_klass(**unet_kwargs))

        # build the ElucidatedImagen instance from the unets and the remaining kwargs
        imagen = ElucidatedImagen(unets, **decoder_kwargs)

        # stash a copy of the config on the instance and return it
        imagen._config = self.dict().copy()
        return imagen
# ImagenTrainerConfig, inheriting from AllowExtraBaseModel
class ImagenTrainerConfig(AllowExtraBaseModel):
    # trainer configuration fields with their defaults
    imagen:                 dict
    elucidated:             bool = False
    video:                  bool = False
    use_ema:                bool = True
    lr:                     SingleOrList(float) = 1e-4
    eps:                    SingleOrList(float) = 1e-8
    beta1:                  float = 0.9
    beta2:                  float = 0.99
    max_grad_norm:          Optional[float] = None
    group_wd_params:        bool = True
    warmup_steps:           SingleOrList(Optional[int]) = None
    cosine_decay_max_steps: SingleOrList(Optional[int]) = None

    # create an ImagenTrainer from this config
    def create(self):
        # dump the config to a dict of trainer kwargs
        trainer_kwargs = self.dict()

        # pop the imagen config, the elucidated flag and the video flag
        imagen_config = trainer_kwargs.pop('imagen')
        elucidated = trainer_kwargs.pop('elucidated')
        video = trainer_kwargs.pop('video')

        # choose the config class based on the elucidated flag
        imagen_config_klass = ElucidatedImagenConfig if elucidated else ImagenConfig
        # build the imagen model, forwarding the video flag
        imagen = imagen_config_klass(**{**imagen_config, 'video': video}).create()

        # return the assembled ImagenTrainer
        return ImagenTrainer(imagen, **trainer_kwargs)

.\lucidrains\imagen-pytorch\imagen_pytorch\data.py

# import the required libraries
from pathlib import Path
from functools import partial

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from imagen_pytorch import t5
from torch.nn.utils.rnn import pad_sequence

from PIL import Image

# import file utility helpers from the datasets library
from datasets.utils.file_utils import get_datasets_user_agent
import io
import urllib

# set the user agent
USER_AGENT = get_datasets_user_agent()

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# cycle endlessly through a dataloader
def cycle(dl):
    while True:
        for data in dl:
            yield data

# convert an image to the given mode
def convert_image_to(img_type, image):
    if image.mode != img_type:
        return image.convert(img_type)
    return image

# dataset, dataloader, collator

# collator class
class Collator:
    def __init__(self, image_size, url_label, text_label, image_label, name, channels):
        self.url_label = url_label
        self.text_label = text_label
        self.image_label = image_label
        self.download = url_label is not None
        self.name = name
        self.channels = channels
        self.transform = T.Compose([
            T.Resize(image_size),
            T.CenterCrop(image_size),
            T.ToTensor(),
        ])
    def __call__(self, batch):

        texts = []
        images = []
        for item in batch:
            try:
                if self.download:
                    image = self.fetch_single_image(item[self.url_label])
                else:
                    image = item[self.image_label]
                image = self.transform(image.convert(self.channels))
            except:
                continue

            text = t5.t5_encode_text([item[self.text_label]], name=self.name)
            texts.append(torch.squeeze(text))
            images.append(image)

        if len(texts) == 0:
            return None
        
        texts = pad_sequence(texts, True)

        newbatch = []
        for i in range(len(texts)):
            newbatch.append((images[i], texts[i]))

        return torch.utils.data.dataloader.default_collate(newbatch)

    def fetch_single_image(self, image_url, timeout=1):
        try:
            request = urllib.request.Request(
                image_url,
                data=None,
                headers={"user-agent": USER_AGENT},
            )
            with urllib.request.urlopen(request, timeout=timeout) as req:
                image = Image.open(io.BytesIO(req.read())).convert('RGB')
        except Exception:
            image = None
        return image

# dataset class
class Dataset(Dataset):
    def __init__(
        self,
        folder,
        image_size,
        exts = ['jpg', 'jpeg', 'png', 'tiff'],
        convert_image_to_type = None
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        convert_fn = partial(convert_image_to, convert_image_to_type) if exists(convert_image_to_type) else nn.Identity()

        self.transform = T.Compose([
            T.Lambda(convert_fn),
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

# get an image dataloader
def get_images_dataloader(
    folder,
    *,
    batch_size,
    image_size,
    shuffle = True,
    cycle_dl = False,
    pin_memory = True
):
    ds = Dataset(folder, image_size)
    dl = DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory)

    if cycle_dl:
        dl = cycle(dl)
    return dl
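
A minimal usage sketch (the folder path here is hypothetical, and the shape assumes RGB images):

dl = get_images_dataloader(
    './path/to/images',    # folder of jpg / jpeg / png / tiff files
    batch_size = 16,
    image_size = 128
)
images = next(iter(dl))    # (16, 3, 128, 128) float tensor in [0, 1]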

.\lucidrains\imagen-pytorch\imagen_pytorch\elucidated_imagen.py

# import sqrt from the math module
from math import sqrt
# import random from the random module
from random import random
# import partial from functools
from functools import partial
# import contextmanager and nullcontext from contextlib
from contextlib import contextmanager, nullcontext
# import List and Union from typing
from typing import List, Union
# import namedtuple from collections
from collections import namedtuple
# import tqdm from tqdm.auto
from tqdm.auto import tqdm

# import torch
import torch
# import torch.nn.functional as F
import torch.nn.functional as F
# import nn from torch
from torch import nn
# import autocast from torch.cuda.amp
from torch.cuda.amp import autocast
# import DistributedDataParallel from torch.nn.parallel
from torch.nn.parallel import DistributedDataParallel
# import torchvision.transforms as T
import torchvision.transforms as T

# import kornia.augmentation as K
import kornia.augmentation as K

# import rearrange, repeat and reduce from einops
from einops import rearrange, repeat, reduce

# import various helper functions and classes from imagen_pytorch.imagen_pytorch
from imagen_pytorch.imagen_pytorch import (
    GaussianDiffusionContinuousTimes,
    Unet,
    NullUnet,
    first,
    exists,
    identity,
    maybe,
    default,
    cast_tuple,
    cast_uint8_images_to_float,
    eval_decorator,
    pad_tuple_to_length,
    resize_image_to,
    calc_all_frame_dims,
    safe_get_tuple_index,
    right_pad_dims_to,
    module_device,
    normalize_neg_one_to_one,
    unnormalize_zero_to_one,
    compact,
    maybe_transform_dict_key
)

# import Unet3D, resize_video_to and scale_video_time from imagen_pytorch.imagen_video
from imagen_pytorch.imagen_video import (
    Unet3D,
    resize_video_to,
    scale_video_time
)

# import t5_encode_text, get_encoded_dim and the DEFAULT_T5_NAME constant from imagen_pytorch.t5
from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

# names of the elucidated diffusion hyperparameters
Hparams_fields = [
    'num_sample_steps',
    'sigma_min',
    'sigma_max',
    'sigma_data',
    'rho',
    'P_mean',
    'P_std',
    'S_churn',
    'S_tmin',
    'S_tmax',
    'S_noise'
]

# namedtuple holding those hyperparameters
Hparams = namedtuple('Hparams', Hparams_fields)

# helper: numerically safe log
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# main class: ElucidatedImagen
class ElucidatedImagen(nn.Module):
    # initialization
    def __init__(
        self,
        unets,
        *,
        image_sizes,                                # image sizes for the cascading ddpm
        text_encoder_name = DEFAULT_T5_NAME,
        text_embed_dim = None,
        channels = 3,
        cond_drop_prob = 0.1,
        random_crop_sizes = None,
        resize_mode = 'nearest',
        temporal_downsample_factor = 1,
        resize_cond_video_frames = True,
        lowres_sample_noise_level = 0.2,            # noise level applied to the low resolution conditioning image during sampling
        per_sample_random_aug_noise_level = False,  # whether each batch element gets its own random augmentation noise level
        condition_on_text = True,
        auto_normalize_img = True,                  # whether to automatically normalize images
        dynamic_thresholding = True,
        dynamic_thresholding_percentile = 0.95,     # dynamic thresholding percentile
        only_train_unet_number = None,
        lowres_noise_schedule = 'linear',
        num_sample_steps = 32,                      # number of sampling steps
        sigma_min = 0.002,                          # minimum noise level
        sigma_max = 80,                             # maximum noise level
        sigma_data = 0.5,                           # standard deviation of the data distribution
        rho = 7,                                    # controls the sampling schedule
        P_mean = -1.2,                              # mean of the log-normal distribution noise is drawn from during training
        P_std = 1.2,                                # standard deviation of that log-normal distribution
        S_churn = 80,                               # stochastic sampling parameters
        S_tmin = 0.05,
        S_tmax = 50,
        S_noise = 1.003,
    # force the model to be unconditional
    def force_unconditional_(self):
        self.condition_on_text = False
        self.unconditional = True

        for unet in self.unets:
            unet.cond_on_text = False
    # device property
    @property
    def device(self):
        return self._temp.device

    # get the unet with the given (1-indexed) number
    def get_unet(self, unet_number):
        # make sure unet_number is within range
        assert 0 < unet_number <= len(self.unets)
        index = unet_number - 1

        # if self.unets is an nn.ModuleList, convert it to a plain python list
        if isinstance(self.unets, nn.ModuleList):
            unets_list = [unet for unet in self.unets]
            # delete the 'unets' attribute before re-assigning it
            delattr(self, 'unets')
            self.unets = unets_list

        # if a different unet is now being trained, move the selected unet to the device and the others to cpu
        if index != self.unet_being_trained_index:
            for unet_index, unet in enumerate(self.unets):
                unet.to(self.device if unet_index == index else 'cpu')

        self.unet_being_trained_index = index
        return self.unets[index]

    # move all unets back onto a single device
    def reset_unets_all_one_device(self, device = None):
        device = default(device, self.device)
        self.unets = nn.ModuleList([*self.unets])
        self.unets.to(device)

        self.unet_being_trained_index = -1

    # context manager that temporarily moves one unet onto the GPU
    @contextmanager
    def one_unet_in_gpu(self, unet_number = None, unet = None):
        assert exists(unet_number) ^ exists(unet)

        if exists(unet_number):
            unet = self.unets[unet_number - 1]

        cpu = torch.device('cpu')

        devices = [module_device(unet) for unet in self.unets]

        self.unets.to(cpu)
        unet.to(self.device)

        yield

        for unet, device in zip(self.unets, devices):
            unet.to(device)

    # override state_dict
    def state_dict(self, *args, **kwargs):
        self.reset_unets_all_one_device()
        return super().state_dict(*args, **kwargs)

    # override load_state_dict
    def load_state_dict(self, *args, **kwargs):
        self.reset_unets_all_one_device()
        return super().load_state_dict(*args, **kwargs)

    # dynamic thresholding
    def threshold_x_start(self, x_start, dynamic_threshold = True):
        if not dynamic_threshold:
            return x_start.clamp(-1., 1.)

        s = torch.quantile(
            rearrange(x_start, 'b ... -> b (...)').abs(),
            self.dynamic_thresholding_percentile,
            dim = -1
        )

        s.clamp_(min = 1.)
        s = right_pad_dims_to(x_start, s)
        return x_start.clamp(-s, s) / s

    # derived preconditioning parameters - Table 1
    def c_skip(self, sigma_data, sigma):
        return (sigma_data ** 2) / (sigma ** 2 + sigma_data ** 2)

    def c_out(self, sigma_data, sigma):
        return sigma * sigma_data * (sigma_data ** 2 + sigma ** 2) ** -0.5

    def c_in(self, sigma_data, sigma):
        return 1 * (sigma ** 2 + sigma_data ** 2) ** -0.5

    def c_noise(self, sigma):
        return log(sigma) * 0.25

    # preconditioned network output
    def preconditioned_network_forward(
        self,
        unet_forward,
        noised_images,
        sigma,
        *,
        sigma_data,
        clamp = False,
        dynamic_threshold = True,
        **kwargs
    ):
        batch, device = noised_images.shape[0], noised_images.device

        if isinstance(sigma, float):
            sigma = torch.full((batch,), sigma, device = device)

        padded_sigma = self.right_pad_dims_to_datatype(sigma)

        net_out = unet_forward(
            self.c_in(sigma_data, padded_sigma) * noised_images,
            self.c_noise(sigma),
            **kwargs
        )

        out = self.c_skip(sigma_data, padded_sigma) * noised_images +  self.c_out(sigma_data, padded_sigma) * net_out

        if not clamp:
            return out

        return self.threshold_x_start(out, dynamic_threshold)

    # sampling
    # sampling schedule
    def sample_schedule(
        self,
        num_sample_steps,
        rho,
        sigma_min,
        sigma_max
    ):
        N = num_sample_steps
        inv_rho = 1 / rho

        # tensor of num_sample_steps step indices on self.device, as float32
        steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32)
        # compute the sigma value for every step
        sigmas = (sigma_max ** inv_rho + steps / (N - 1) * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho

        # append a final sigma of 0, so the last step denoises completely
        sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0.
        return sigmas

    @torch.no_grad()
    def one_unet_sample(
        self,
        unet,
        shape,
        *,
        unet_number,
        clamp = True,
        dynamic_threshold = True,
        cond_scale = 1.,
        use_tqdm = True,
        inpaint_videos = None,
        inpaint_images = None,
        inpaint_masks = None,
        inpaint_resample_times = 5,
        init_images = None,
        skip_steps = None,
        sigma_min = None,
        sigma_max = None,
        **kwargs
    @torch.no_grad()
    @eval_decorator
    def sample(
        self,
        texts: List[str] = None,
        text_masks = None,
        text_embeds = None,
        cond_images = None,
        cond_video_frames = None,
        post_cond_video_frames = None,
        inpaint_videos = None,
        inpaint_images = None,
        inpaint_masks = None,
        inpaint_resample_times = 5,
        init_images = None,
        skip_steps = None,
        sigma_min = None,
        sigma_max = None,
        video_frames = None,
        batch_size = 1,
        cond_scale = 1.,
        lowres_sample_noise_level = None,
        start_at_unet_number = 1,
        start_image_or_video = None,
        stop_at_unet_number = None,
        return_all_unet_outputs = False,
        return_pil_images = False,
        use_tqdm = True,
        use_one_unet_in_gpu = True,
        device = None,
    # training

    # loss weighting
    def loss_weight(self, sigma_data, sigma):
        return (sigma ** 2 + sigma_data ** 2) * (sigma * sigma_data) ** -2

    # sample noise levels from a log-normal distribution with the given mean and std
    def noise_distribution(self, P_mean, P_std, batch_size):
        return (P_mean + P_std * torch.randn((batch_size,), device = self.device)).exp()

    def forward(
        self,
        images, # rename to images or video
        unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel] = None,
        texts: List[str] = None,
        text_embeds = None,
        text_masks = None,
        unet_number = None,
        cond_images = None,
        **kwargs

.\lucidrains\imagen-pytorch\imagen_pytorch\imagen_pytorch.py

# import the math library
import math
# import random from the random module
from random import random
# import List and Union types from beartype
from beartype.typing import List, Union
# import the beartype decorator
from beartype import beartype
# import tqdm from tqdm.auto
from tqdm.auto import tqdm
# import partial and wraps from functools
from functools import partial, wraps
# import contextmanager and nullcontext from contextlib
from contextlib import contextmanager, nullcontext
# import Path from pathlib
from pathlib import Path

# import torch
import torch
# import torch.nn.functional as F
import torch.nn.functional as F
# import DistributedDataParallel from torch.nn.parallel
from torch.nn.parallel import DistributedDataParallel
# import nn and einsum from torch
from torch import nn, einsum
# import autocast from torch.cuda.amp
from torch.cuda.amp import autocast
# import expm1 from torch.special
from torch.special import expm1
# import torchvision.transforms as T
import torchvision.transforms as T

# import kornia.augmentation as K
import kornia.augmentation as K

# import rearrange, repeat, reduce, pack, unpack from einops
from einops import rearrange, repeat, reduce, pack, unpack
# import the Rearrange layer from einops.layers.torch
from einops.layers.torch import Rearrange

# import t5_encode_text, get_encoded_dim and DEFAULT_T5_NAME from imagen_pytorch.t5
from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

# import Unet3D, resize_video_to and scale_video_time from imagen_pytorch.imagen_video
from imagen_pytorch.imagen_video import Unet3D, resize_video_to, scale_video_time

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# identity - return the input unchanged
def identity(t, *args, **kwargs):
    return t

# check whether numer is divisible by denom
def divisible_by(numer, denom):
    return (numer % denom) == 0

# return the first element of a sequence, or a default if it is empty
def first(arr, d = None):
    if len(arr) == 0:
        return d
    return arr[0]

# 'maybe' decorator - only apply fn when the input exists
def maybe(fn):
    @wraps(fn)
    def inner(x):
        if not exists(x):
            return x
        return fn(x)
    return inner

# decorator that lets a function run only once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# print only once
print_once = once(print)

# return val if it exists, otherwise the default (which may be a callable)
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# cast a value to a tuple of the given length
def cast_tuple(val, length = None):
    if isinstance(val, list):
        val = tuple(val)

    output = val if isinstance(val, tuple) else ((val,) * default(length, 1))

    if exists(length):
        assert len(output) == length

    return output
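
A few illustrative calls (not from the source) showing how this broadcasts per-unet settings:

cast_tuple(0.1, length = 3)        # (0.1, 0.1, 0.1)
cast_tuple((0.1, 0.2, 0.3), 3)     # (0.1, 0.2, 0.3) - already a tuple, only the length is checked
cast_tuple([64, 128])              # (64, 128) - lists are converted to tuples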

# compact a dict by dropping keys whose value is None
def compact(input_dict):
    return {key: value for key, value in input_dict.items() if exists(value)}

# apply fn to the value at a given key, if that key is present
def maybe_transform_dict_key(input_dict, key, fn):
    if key not in input_dict:
        return input_dict

    copied_dict = input_dict.copy()
    copied_dict[key] = fn(copied_dict[key])
    return copied_dict

# cast uint8 images to floats in [0, 1]
def cast_uint8_images_to_float(images):
    if not images.dtype == torch.uint8:
        return images
    return images / 255

# get the device a module's parameters live on
def module_device(module):
    return next(module.parameters()).device

# zero-initialize a layer's weight (and bias, if present)
def zero_init_(m):
    nn.init.zeros_(m.weight)
    if exists(m.bias):
        nn.init.zeros_(m.bias)

# decorator that runs a method in eval mode, then restores the previous training mode
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# pad a tuple to the given length with a fill value
def pad_tuple_to_length(t, length, fillvalue = None):
    remain_length = length - len(t)
    if remain_length <= 0:
        return t
    return (*t, *((fillvalue,) * remain_length))

# helper classes

# no-op module
class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

# tensor helpers

# log with clamping, for numerical stability
def log(t, eps: float = 1e-12):
    return torch.log(t.clamp(min = eps))

# l2 normalize along the last dimension
def l2norm(t):
    return F.normalize(t, dim = -1)

# append singleton dimensions to t until it has as many dimensions as x
def right_pad_dims_to(x, t):
    padding_dims = x.ndim - t.ndim
    if padding_dims <= 0:
        return t
    return t.view(*t.shape, *((1,) * padding_dims))
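
For instance (illustrative shapes), this is how a per-sample scalar such as a log-SNR gets broadcast over an image tensor:

x = torch.randn(8, 3, 64, 64)
t = torch.randn(8)
right_pad_dims_to(x, t).shape   # torch.Size([8, 1, 1, 1])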

# mean over a dimension, respecting an optional mask
def masked_mean(t, *, dim, mask = None):
    if not exists(mask):
        return t.mean(dim = dim)

    denom = mask.sum(dim = dim, keepdim = True)
    mask = rearrange(mask, 'b n -> b n 1')
    masked_t = t.masked_fill(~mask, 0.)

    return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)

# resize an image to the target size
def resize_image_to(
    image,
    target_image_size,
    clamp_range = None,
    mode = 'nearest'
):
    orig_image_size = image.shape[-1]

    if orig_image_size == target_image_size:
        return image

    out = F.interpolate(image, target_image_size, mode = mode)

    if exists(clamp_range):
        out = out.clamp(*clamp_range)

    return out

# compute the frame dimensions for each downsample factor
def calc_all_frame_dims(
    downsample_factors: List[int],
    frames
):
    # if frames is not given, return an empty tuple per downsample factor
    if not exists(frames):
        return (tuple(),) * len(downsample_factors)

    # collect the frame dimensions
    all_frame_dims = []

    # for each downsample factor, frames must divide evenly; store frames // divisor
    for divisor in downsample_factors:
        assert divisible_by(frames, divisor)
        all_frame_dims.append((frames // divisor,))

    # return the frame dimensions
    return all_frame_dims
# safely get a tuple element by index, returning a default if out of range
def safe_get_tuple_index(tup, index, default = None):
    if len(tup) <= index:
        return default
    return tup[index]

# image normalization functions
# ddpms expect images to be in the range of -1 to 1

def normalize_neg_one_to_one(img):
    return img * 2 - 1

def unnormalize_zero_to_one(normed_img):
    return (normed_img + 1) * 0.5

# classifier free guidance functions

def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    else:
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
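
A sketch of how this is typically used for classifier-free guidance (names here are illustrative): each batch element keeps its text conditioning with probability 1 - cond_drop_prob.

batch, cond_drop_prob = 4, 0.1
keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = torch.device('cpu'))
# keep_mask is a boolean tensor of shape (4,); conditioning is dropped where it is False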

# continuous time gaussian diffusion helper functions and classes
# a large part of this is thanks to @crowsonkb at https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py

@torch.jit.script
def beta_linear_log_snr(t):
    return -torch.log(expm1(1e-4 + 10 * (t ** 2)))

@torch.jit.script
def alpha_cosine_log_snr(t, s: float = 0.008):
    return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) # not sure if this accounts for beta being clipped to 0.999 in the discrete version

def log_snr_to_alpha_sigma(log_snr):
    return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr))

class GaussianDiffusionContinuousTimes(nn.Module):
    def __init__(self, *, noise_schedule, timesteps = 1000):
        super().__init__()

        if noise_schedule == "linear":
            self.log_snr = beta_linear_log_snr
        elif noise_schedule == "cosine":
            self.log_snr = alpha_cosine_log_snr
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        self.num_timesteps = timesteps

    def get_times(self, batch_size, noise_level, *, device):
        return torch.full((batch_size,), noise_level, device = device, dtype = torch.float32)

    def sample_random_times(self, batch_size, *, device):
        return torch.zeros((batch_size,), device = device).float().uniform_(0, 1)

    def get_condition(self, times):
        return maybe(self.log_snr)(times)

    def get_sampling_timesteps(self, batch, *, device):
        times = torch.linspace(1., 0., self.num_timesteps + 1, device = device)
        times = repeat(times, 't -> b t', b = batch)
        times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
        times = times.unbind(dim = -1)
        return times

    def q_posterior(self, x_start, x_t, t, *, t_next = None):
        t_next = default(t_next, lambda: (t - 1. / self.num_timesteps).clamp(min = 0.))

        """ https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material """
        log_snr = self.log_snr(t)
        log_snr_next = self.log_snr(t_next)
        log_snr, log_snr_next = map(partial(right_pad_dims_to, x_t), (log_snr, log_snr_next))

        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

        # c - as defined near eq 33
        c = -expm1(log_snr - log_snr_next)
        posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)

        # following (eq. 33)
        posterior_variance = (sigma_next ** 2) * c
        posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def q_sample(self, x_start, t, noise = None):
        dtype = x_start.dtype

        if isinstance(t, float):
            batch = x_start.shape[0]
            t = torch.full((batch,), t, device = x_start.device, dtype = dtype)

        noise = default(noise, lambda: torch.randn_like(x_start))
        log_snr = self.log_snr(t).type(dtype)
        log_snr_padded_dim = right_pad_dims_to(x_start, log_snr)
        alpha, sigma =  log_snr_to_alpha_sigma(log_snr_padded_dim)

        return alpha * x_start + sigma * noise, log_snr, alpha, sigma
    # take a sample x_from at time from_t and move it to time to_t, adding the appropriate noise
    def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
        # shape, device and dtype of the input
        shape, device, dtype = x_from.shape, x_from.device, x_from.dtype
        batch = shape[0]

        # if from_t is a float, broadcast it to a tensor of batch size
        if isinstance(from_t, float):
            from_t = torch.full((batch,), from_t, device = device, dtype = dtype)

        # if to_t is a float, broadcast it to a tensor of batch size
        if isinstance(to_t, float):
            to_t = torch.full((batch,), to_t, device = device, dtype = dtype)

        # if no noise is given, sample gaussian noise shaped like x_from
        noise = default(noise, lambda: torch.randn_like(x_from))

        # log_snr at from_t, padded to align with x_from
        log_snr = self.log_snr(from_t)
        log_snr_padded_dim = right_pad_dims_to(x_from, log_snr)
        # alpha and sigma at from_t
        alpha, sigma =  log_snr_to_alpha_sigma(log_snr_padded_dim)

        # log_snr at to_t, padded to align with x_from
        log_snr_to = self.log_snr(to_t)
        log_snr_padded_dim_to = right_pad_dims_to(x_from, log_snr_to)
        # alpha and sigma at to_t
        alpha_to, sigma_to =  log_snr_to_alpha_sigma(log_snr_padded_dim_to)

        # combine them according to the diffusion forward process
        return x_from * (alpha_to / alpha) + noise * (sigma_to * alpha - sigma * alpha_to) / alpha

    # 根据给定的 x_t、t 和速度 v 预测起始值
    def predict_start_from_v(self, x_t, t, v):
        # 计算 t 对应的 log_snr,并将其维度与 x_t 对齐
        log_snr = self.log_snr(t)
        log_snr = right_pad_dims_to(x_t, log_snr)
        # 根据 log_snr 计算 alpha 和 sigma
        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        # 返回根据公式计算得到的结果
        return alpha * x_t - sigma * v

    # 根据给定的 x_t、t 和噪声 noise 预测起始值
    def predict_start_from_noise(self, x_t, t, noise):
        # 计算 t 对应的 log_snr,并将其维度与 x_t 对齐
        log_snr = self.log_snr(t)
        log_snr = right_pad_dims_to(x_t, log_snr)
        # 根据 log_snr 计算 alpha 和 sigma
        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        # 返回根据公式计算得到的结果
        return (x_t - sigma * noise) / alpha.clamp(min = 1e-8)
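
# illustrative sketch, not part of the original source: continuous-time forward diffusion
# with the cosine schedule above, x_t = alpha(t) * x_0 + sigma(t) * noise
example_x0  = torch.randn(2, 3, 8, 8)                         # clean images
example_t   = torch.full((2,), 0.5)                           # halfway through the schedule
example_snr = alpha_cosine_log_snr(example_t).view(-1, 1, 1, 1)
example_alpha, example_sigma = log_snr_to_alpha_sigma(example_snr)
example_xt  = example_alpha * example_x0 + example_sigma * torch.randn_like(example_x0)
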
# layer normalization with a learnable gain, over an arbitrary dimension
class LayerNorm(nn.Module):
    def __init__(self, feats, stable = False, dim = -1):
        super().__init__()
        self.stable = stable
        self.dim = dim

        # learnable gain, shaped so it broadcasts over the normalized dimension
        self.g = nn.Parameter(torch.ones(feats, *((1,) * (-dim - 1))))

    def forward(self, x):
        dtype, dim = x.dtype, self.dim

        # optionally divide by the max first, for numerical stability
        if self.stable:
            x = x / x.amax(dim = dim, keepdim = True).detach()

        # epsilon depends on the dtype
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim = dim, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = dim, keepdim = True)

        return (x - mean) * (var + eps).rsqrt().type(dtype) * self.g.type(dtype)

# LayerNorm over the channel dimension of an image tensor (dim = -3)
ChanLayerNorm = partial(LayerNorm, dim = -3)

# Always - callable that returns a fixed value regardless of its arguments
class Always():
    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        return self.val

# Residual - wraps a module with a residual connection
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

# Parallel - applies several modules to the same input and sums their outputs
class Parallel(nn.Module):
    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        outputs = [fn(x) for fn in self.fns]
        return sum(outputs)

# PerceiverAttention - attention from a set of latents to the input sequence (plus the latents themselves)
class PerceiverAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        scale = 8
    ):
        super().__init__()
        self.scale = scale

        self.heads = heads
        inner_dim = dim_head * heads

        # layer norms for the sequence and the latents, plus the query / key-value projections
        self.norm = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        # learned per-dimension scales for the l2-normalized queries and keys
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        # output projection
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.LayerNorm(dim)
        )

    # forward pass
    def forward(self, x, latents, mask = None):
        x = self.norm(x)
        latents = self.norm_latents(latents)

        b, h = x.shape[0], self.heads

        q = self.to_q(latents)

        # keys / values come from the concatenation of the sequence and the latents
        kv_input = torch.cat((x, latents), dim = -2)
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # l2 normalize queries and keys (cosine-sim attention), then apply the learned scales
        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # similarities, then masking
        sim = einsum('... i d, ... j d  -> ... i j', q, k) * self.scale

        if exists(mask):
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = F.pad(mask, (0, latents.shape[-2]), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention
        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)

# PerceiverResampler - resamples a variable-length sequence into a fixed number of latents
class PerceiverResampler(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
        max_seq_len = 512,
        ff_mult = 4
    ):
        super().__init__()
        # positional embedding for the attended sequence
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # learnable latents that will query the sequence
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        # optional projection from the mean-pooled sequence to additional latents
        self.to_latents_from_mean_pooled_seq = None

        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
            )

        # stack of perceiver attention + feedforward layers
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x, mask = None):
        n, device = x.shape[1], x.device
        # positional embeddings for the input sequence
        pos_emb = self.pos_emb(torch.arange(n, device = device))

        x_with_pos = x + pos_emb

        # repeat the latents across the batch
        latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0])

        # optionally prepend latents derived from the mean-pooled sequence
        if exists(self.to_latents_from_mean_pooled_seq):
            meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim = -2)

        # the latents repeatedly attend to the sequence, each attention followed by a feedforward
        for attn, ff in self.layers:
            latents = attn(x_with_pos, latents, mask = mask) + latents
            latents = ff(latents) + latents

        return latents
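
# illustrative usage sketch, not part of the original source: the resampler compresses a
# variable-length sequence of text embeddings into a fixed set of latents for cross attention
#
#   resampler = PerceiverResampler(dim = 512, depth = 2)   # 64 latents + 4 mean-pooled latents by default
#   text_embeds = torch.randn(2, 77, 512)
#   latents = resampler(text_embeds)                       # (2, 68, 512)
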
# Attention - multi-head attention with single-headed key / values, a learned null key / value, and optional text context
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        context_dim = None,
        scale = 8
    ):
        super().__init__()
        self.scale = scale

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context = None, mask = None, attn_bias = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)

        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        # add null key / value for classifier free guidance in prior net

        nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # add text conditioning, if present

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        # qk rmsnorm

        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # calculate query / key similarities

        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.scale

        # relative positional encoding (T5 style)

        if exists(attn_bias):
            sim = sim + attn_bias

        # masking

        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention

        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        # aggregate values

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
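
# illustrative sketch, not part of the original source: the q / k "rmsnorm" above replaces the
# usual 1 / sqrt(d) scaling with l2-normalized queries and keys, per-dimension learned scales
# and a fixed temperature, which keeps the attention logits in a bounded range
example_q = F.normalize(torch.randn(1, 8, 10, 64), dim = -1)                  # equivalent to l2norm(q)
example_k = F.normalize(torch.randn(1, 10, 64), dim = -1)
example_sim = einsum('b h i d, b j d -> b h i j', example_q, example_k) * 8   # fixed scale of 8
example_attn = example_sim.softmax(dim = -1)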

# upsampling: nearest-neighbor upsample followed by a 3x3 convolution
def Upsample(dim, dim_out = None):
    dim_out = default(dim_out, dim)

    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim, dim_out, 3, padding = 1)
    )

# pixel-shuffle based upsampling, to avoid checkerboard artifacts
class PixelShuffleUpsample(nn.Module):
    """
    code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
    https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
    """
    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)
        conv = nn.Conv2d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU(),
            nn.PixelShuffle(2)
        )

        self.init_conv_(conv)

    def init_conv_(self, conv):
        o, i, h, w = conv.weight.shape
        conv_weight = torch.empty(o // 4, i, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        return self.net(x)

# downsampling function
def Downsample(dim, dim_out = None):
    # https://arxiv.org/abs/2208.03641 shows this is the most optimal way to downsample
    # named SP-conv in the paper, but basically a pixel unshuffle
    dim_out = default(dim_out, dim)
    # pixel unshuffle (halve the spatial resolution, 4x the channels), then a 1x1 conv from dim * 4 down to dim_out
    return nn.Sequential(
        Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
        nn.Conv2d(dim * 4, dim_out, 1)
    )
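
# illustrative sketch, not part of the original source: the rearrange above is a pixel
# unshuffle - spatial resolution is halved while the channel count is multiplied by 4,
# and the 1x1 conv then projects those 4 * dim channels down to dim_out
example_fmap = torch.randn(1, 32, 16, 16)
example_down = Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2)(example_fmap)
assert example_down.shape == (1, 128, 8, 8)
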
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)  # log-spaced frequency step
        emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)  # geometric ladder of frequencies
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')  # outer product of times and frequencies
        return torch.cat((emb.sin(), emb.cos()), dim = -1)  # concatenate sines and cosines

class LearnedSinusoidalPosEmb(nn.Module):
    """ following @crowsonkb 's lead with learned sinusoidal pos emb """
    """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """

    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))  # learned frequencies

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi  # times scaled by the learned frequencies
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)  # sines and cosines
        fouriered = torch.cat((x, fouriered), dim = -1)  # prepend the raw time
        return fouriered
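
# illustrative sketch, not part of the original source: for an embedding dimension d the module
# returns d + 1 features per timestep - the raw time plus d learned fourier features
example_time_feats = LearnedSinusoidalPosEmb(16)(torch.rand(4))   # times in [0, 1)
assert example_time_feats.shape == (4, 17)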

class Block(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        norm = True
    ):
        super().__init__()
        self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()  # optional group normalization
        self.activation = nn.SiLU()
        self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)

    def forward(self, x, scale_shift = None):
        x = self.groupnorm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift  # FiLM-style scale and shift conditioning

        x = self.activation(x)
        return self.project(x)

class ResnetBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        *,
        cond_dim = None,
        time_cond_dim = None,
        groups = 8,
        linear_attn = False,
        use_gca = False,
        squeeze_excite = False,
        **attn_kwargs
    ):
        super().__init__()

        self.time_mlp = None

        if exists(time_cond_dim):
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim_out * 2)
            )  # MLP projecting the time conditioning to a scale and shift

        self.cross_attn = None

        if exists(cond_dim):
            attn_klass = CrossAttention if not linear_attn else LinearCrossAttention

            self.cross_attn = attn_klass(
                dim = dim_out,
                context_dim = cond_dim,
                **attn_kwargs
            )  # cross attention over the conditioning tokens

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)

        self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)  # global context attention

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else Identity()  # residual projection when dims differ

    def forward(self, x, time_emb = None, cond = None):

        scale_shift = None
        if exists(self.time_mlp) and exists(time_emb):
            time_emb = self.time_mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)  # split the time embedding into scale and shift

        h = self.block1(x)  # first block

        if exists(self.cross_attn):
            assert exists(cond)
            h = rearrange(h, 'b c h w -> b h w c')
            h, ps = pack([h], 'b * c')
            h = self.cross_attn(h, context = cond) + h  # cross attention to the conditioning
            h, = unpack(h, ps, 'b * c')
            h = rearrange(h, 'b h w c -> b c h w')

        h = self.block2(h, scale_shift = scale_shift)  # second block, modulated by the time embedding

        h = h * self.gca(h)  # gate by the global context

        return h + self.res_conv(x)  # residual connection
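
# illustrative sketch, not part of the original source: the time embedding is projected to
# 2 * dim_out channels and split into a (scale, shift) pair that modulates the second block,
# i.e. h = h * (scale + 1) + shift (FiLM-style conditioning)
example_time_emb = torch.randn(2, 128).reshape(2, 128, 1, 1)   # hypothetical time MLP output, dim_out = 64
example_scale, example_shift = example_time_emb.chunk(2, dim = 1)
example_h = torch.randn(2, 64, 16, 16)
example_h = example_h * (example_scale + 1) + example_shift    # broadcast over the spatial dimensions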

class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        norm_context = False,
        scale = 8
    ):
        super().__init__()
        self.scale = scale

        self.heads = heads
        inner_dim = dim_head * heads

        # the context dimension defaults to the query dimension
        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.norm_context = LayerNorm(context_dim) if norm_context else Identity()

        # learned null key / value pair
        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        # learned per-dimension scales for the l2-normalized queries and keys
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        # output projection
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context, mask = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        context = self.norm_context(context)

        # queries from x, keys / values from the context
        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        # add null key / value for classifier free guidance in prior net
        nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads,  b = b), self.null_kv.unbind(dim = -2))
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # cosine-sim attention: l2 normalize, then apply the learned scales
        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # similarities
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # masking (padded to account for the null key / value)
        max_neg_value = -torch.finfo(sim.dtype).max
        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention
        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        # aggregate values
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
class LinearCrossAttention(CrossAttention):
    # linear (kernelized) variant of CrossAttention
    def forward(self, x, context, mask = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        context = self.norm_context(context)

        # queries from x, keys / values from the context
        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        # fold the heads into the batch dimension
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads), (q, k, v))

        # add null key / value for classifier free guidance in prior net

        nk, nv = map(lambda t: repeat(t, 'd -> (b h) 1 d', h = self.heads,  b = b), self.null_kv.unbind(dim = -2))

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # masking

        max_neg_value = -torch.finfo(x.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b n -> b n 1')
            k = k.masked_fill(~mask, max_neg_value)
            v = v.masked_fill(~mask, 0.)

        # linear attention

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
        return self.to_out(out)

class LinearAttention(nn.Module):
    # linear attention over image feature maps
    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 8,
        dropout = 0.05,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = ChanLayerNorm(dim)

        self.nonlin = nn.SiLU()

        self.to_q = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv2d(dim, inner_dim, 1, bias = False),
            nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_k = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv2d(dim, inner_dim, 1, bias = False),
            nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_v = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv2d(dim, inner_dim, 1, bias = False),
            nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1, bias = False),
            ChanLayerNorm(dim)
        )

    def forward(self, fmap, context = None):
        # fmap is an image feature map; context is an optional sequence of conditioning tokens
        h, x, y = self.heads, *fmap.shape[-2:]

        fmap = self.norm(fmap)
        q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v))
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            ck, cv = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (ck, cv))
            k = torch.cat((k, ck), dim = -2)
            v = torch.cat((v, cv), dim = -2)

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        out = self.nonlin(out)
        return self.to_out(out)
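
# illustrative sketch, not part of the original source: linear attention first aggregates keys
# and values into a d x e context matrix, so the cost grows linearly with sequence length n
# instead of materializing the quadratic n x n attention matrix
example_q = torch.randn(2, 256, 32).softmax(dim = -1)
example_k = torch.randn(2, 256, 32).softmax(dim = -2)
example_v = torch.randn(2, 256, 32)
example_context = einsum('b n d, b n e -> b d e', example_k, example_v)   # (2, 32, 32), independent of n
example_out = einsum('b n d, b d e -> b n e', example_q, example_context) # (2, 256, 32)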

class GlobalContext(nn.Module):
    # global context block
    """ basically a superior form of squeeze-excitation that is attention-esque """

    def __init__(
        self,
        *,
        dim_in,
        dim_out
    ):
        super().__init__()
        # 1x1 conv producing a single-channel context map
        self.to_k = nn.Conv2d(dim_in, 1, 1)
        # hidden dimension: at least 3, otherwise half of dim_out
        hidden_dim = max(3, dim_out // 2)

        self.net = nn.Sequential(
            nn.Conv2d(dim_in, hidden_dim, 1),
            nn.SiLU(),
            nn.Conv2d(hidden_dim, dim_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # derive the context map from the input
        context = self.to_k(x)
        # flatten the spatial dimensions of both input and context
        x, context = map(lambda t: rearrange(t, 'b n ... -> b n (...)'), (x, context))
        # attention-weighted pooling of the input, using the softmaxed context as weights
        out = einsum('b i n, b c n -> b c i', context.softmax(dim = -1), x)
        # restore a trailing singleton spatial dimension
        out = rearrange(out, '... -> ... 1')
        return self.net(out)
# feedforward module: layer norm, linear, GELU, layer norm, linear
def FeedForward(dim, mult = 2):
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias = False),
        nn.GELU(),
        LayerNorm(hidden_dim),
        nn.Linear(hidden_dim, dim, bias = False)
    )

# channel-wise feedforward for feature maps, using 1x1 convolutions
def ChanFeedForward(dim, mult = 2):  # in paper, it seems for self attention layers they did feedforwards with twice channel width
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        ChanLayerNorm(dim),
        nn.Conv2d(dim, hidden_dim, 1, bias = False),
        nn.GELU(),
        ChanLayerNorm(hidden_dim),
        nn.Conv2d(hidden_dim, dim, 1, bias = False)
    )

# transformer block: a stack of self attention and feedforward layers applied to a feature map
class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),  # self attention
                FeedForward(dim = dim, mult = ff_mult)  # feedforward
            ]))

    def forward(self, x, context = None):
        x = rearrange(x, 'b c h w -> b h w c')
        x, ps = pack([x], 'b * c')

        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x

        x, = unpack(x, ps, 'b * c')
        x = rearrange(x, 'b h w c -> b c h w')
        return x

# transformer block using linear attention and channel-wise feedforwards
class LinearAttentionTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),  # linear attention
                ChanFeedForward(dim = dim, mult = ff_mult)  # channel-wise feedforward
            ]))

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x

# cross embedding layer: parallel convolutions with different kernel sizes, concatenated along channels
class CrossEmbedLayer(nn.Module):
    def __init__(
        self,
        dim_in,
        kernel_sizes,
        dim_out = None,
        stride = 2
    ):
        super().__init__()
        assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
        dim_out = default(dim_out, dim_in)

        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # split the output channels across the kernel sizes, halving at each scale
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]

        self.convs = nn.ModuleList([])
        for kernel, dim_scale in zip(kernel_sizes, dim_scales):
            self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))

    def forward(self, x):
        fmaps = tuple(map(lambda conv: conv(x), self.convs))
        return torch.cat(fmaps, dim = 1)
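
# illustrative sketch, not part of the original source: the output channels are split across
# the kernel sizes in halving proportions so the concatenated feature map has exactly dim_out channels
example_dim_out = 64
example_kernel_sizes = sorted((3, 7, 15))
example_dim_scales = [int(example_dim_out / (2 ** i)) for i in range(1, len(example_kernel_sizes))]
example_dim_scales = [*example_dim_scales, example_dim_out - sum(example_dim_scales)]
assert example_dim_scales == [32, 16, 16] and sum(example_dim_scales) == example_dim_out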

# combines resized feature maps from earlier upsampling stages with the current feature map
class UpsampleCombiner(nn.Module):
    def __init__(
        self,
        dim,
        *,
        enabled = False,
        dim_ins = tuple(),
        dim_outs = tuple()
    ):
        super().__init__()
        dim_outs = cast_tuple(dim_outs, len(dim_ins))
        assert len(dim_ins) == len(dim_outs)

        self.enabled = enabled

        if not self.enabled:
            self.dim_out = dim
            return

        self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
        self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)
    def forward(self, x, fmaps = None):
        # the current feature map's spatial size is the target size
        target_size = x.shape[-1]

        fmaps = default(fmaps, tuple())

        # nothing to combine - return the input unchanged
        if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
            return x

        # resize every feature map to the target size, project each one, then concatenate along channels
        fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps]
        outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
        return torch.cat((x, *outs), dim = 1)
# the Unet
class Unet(nn.Module):
    # constructor - all the hyperparameters of the unet
    def __init__(
        self,
        *,
        dim,
        text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME),  # default text embedding dimension
        num_resnet_blocks = 1,  # number of resnet blocks per stage
        cond_dim = None,  # conditioning dimension
        num_image_tokens = 4,  # number of image tokens
        num_time_tokens = 2,  # number of time tokens
        learned_sinu_pos_emb_dim = 16,  # learned sinusoidal positional embedding dimension
        out_dim = None,  # output dimension
        dim_mults=(1, 2, 4, 8),  # dimension multipliers per stage
        cond_images_channels = 0,  # channels of the conditioning images
        channels = 3,  # input channels
        channels_out = None,  # output channels
        attn_dim_head = 64,  # attention head dimension
        attn_heads = 8,  # number of attention heads
        ff_mult = 2.,  # feedforward expansion factor
        lowres_cond = False,  # condition on a low resolution image (for cascading ddpm)
        layer_attns = True,  # self attention per layer
        layer_attns_depth = 1,  # depth of the per-layer self attention
        layer_mid_attns_depth = 1,  # depth of the attention at the middle
        layer_attns_add_text_cond = True,  # whether to condition the self-attention blocks with the text embeddings
        attend_at_middle = True,  # whether to attend at the bottleneck
        layer_cross_attns = True,  # cross attention per layer
        use_linear_attn = False,  # use linear attention
        use_linear_cross_attn = False,  # use linear cross attention
        cond_on_text = True,  # condition on text
        max_text_len = 256,  # maximum text length
        init_dim = None,  # initial dimension
        resnet_groups = 8,  # group norm groups in the resnet blocks
        init_conv_kernel_size = 7,  # kernel size of the initial convolution
        init_cross_embed = True,  # use a cross embed layer for the initial convolution
        init_cross_embed_kernel_sizes = (3, 7, 15),  # kernel sizes of the initial cross embed layer
        cross_embed_downsample = False,  # use cross embed layers for downsampling
        cross_embed_downsample_kernel_sizes = (2, 4),  # kernel sizes for cross embed downsampling
        attn_pool_text = True,  # attention-pool the text tokens
        attn_pool_num_latents = 32,  # number of latents for the attention pooling
        dropout = 0.,  # dropout rate
        memory_efficient = False,  # memory efficient variant
        init_conv_to_final_conv_residual = False,  # residual from the initial conv to the final conv
        use_global_context_attn = True,  # use global context attention
        scale_skip_connection = True,  # scale the skip connections
        final_resnet_block = True,  # use a final resnet block
        final_conv_kernel_size = 3,  # kernel size of the final convolution
        self_cond = False,  # self conditioning
        resize_mode = 'nearest',  # resize mode
        combine_upsample_fmaps = False,  # combine the feature maps from all upsample blocks
        pixel_shuffle_upsample = True,  # pixel shuffle upsampling
    # if the current settings of this unet are not correct for what it is being used for, reinitialize it with the correct settings
    def cast_model_parameters(
        self,
        *,
        lowres_cond,
        text_embed_dim,
        channels,
        channels_out,
        cond_on_text
    ):
        # if the settings already match this unet, return it unchanged
        if lowres_cond == self.lowres_cond and \
            channels == self.channels and \
            cond_on_text == self.cond_on_text and \
            text_embed_dim == self._locals['text_embed_dim'] and \
            channels_out == self.channels_out:
            return self

        # otherwise rebuild the unet with the updated settings
        updated_kwargs = dict(
            lowres_cond = lowres_cond,
            text_embed_dim = text_embed_dim,
            channels = channels,
            channels_out = channels_out,
            cond_on_text = cond_on_text
        )

        return self.__class__(**{**self._locals, **updated_kwargs})

    # returns the full unet config together with its state dict
    def to_config_and_state_dict(self):
        return self._locals, self.state_dict()

    # classmethod that recreates a unet from a config and state dict
    @classmethod
    def from_config_and_state_dict(klass, config, state_dict):
        unet = klass(**config)
        unet.load_state_dict(state_dict)
        return unet

    # persist the unet (config + weights) to disk
    def persist_to_file(self, path):
        path = Path(path)
        path.parents[0].mkdir(exist_ok = True, parents = True)

        config, state_dict = self.to_config_and_state_dict()
        pkg = dict(config = config, state_dict = state_dict)
        torch.save(pkg, str(path))

    # classmethod that recreates a unet from a file saved with `persist_to_file`
    @classmethod
    def hydrate_from_file(klass, path):
        path = Path(path)
        assert path.exists()
        # load the saved package from disk
        pkg = torch.load(str(path))

        # the package must contain both the config and the state dict
        assert 'config' in pkg and 'state_dict' in pkg
        config, state_dict = pkg['config'], pkg['state_dict']

        # instantiate the unet from the config and load its weights
        return Unet.from_config_and_state_dict(config, state_dict)
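
    # illustrative usage sketch, not part of the original source (path and arguments are hypothetical):
    #
    #   unet = Unet(dim = 128, dim_mults = (1, 2, 4), cond_on_text = False)
    #   unet.persist_to_file('./checkpoints/unet.pt')
    #   restored = Unet.hydrate_from_file('./checkpoints/unet.pt')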

    # forward with classifier free guidance

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 1.,
        **kwargs
    ):
        # conditional prediction
        logits = self.forward(*args, **kwargs)

        # a cond_scale of 1 is plain conditional sampling, no guidance
        if cond_scale == 1:
            return logits

        # unconditional prediction (all conditioning dropped), then extrapolate by cond_scale
        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale
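
    # illustrative sketch, not part of the original source: classifier free guidance blends the
    # conditional and unconditional predictions,
    #   guided = null_logits + (logits - null_logits) * cond_scale
    # e.g. with null_logits = 0.2, logits = 0.5 and cond_scale = 3: guided = 0.2 + 0.3 * 3 = 1.1,
    # pushing the output further away from the unconditional prediction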

    # regular forward pass

    def forward(
        self,
        x,
        time,
        *,
        lowres_cond_img = None,
        lowres_noise_times = None,
        text_embeds = None,
        text_mask = None,
        cond_images = None,
        self_cond = None,
        cond_drop_prob = 0.
# a null / no-op Unet, used as a placeholder
class NullUnet(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.lowres_cond = False
        self.dummy_parameter = nn.Parameter(torch.tensor([0.]))

    # casting parameters is a no-op, just return self
    def cast_model_parameters(self, *args, **kwargs):
        return self

    # the forward pass returns the input unchanged
    def forward(self, x, *args, **kwargs):
        return x

# predefined Unets, with hyperparameters matching the appendix of the paper
class BaseUnet64(Unet):
    def __init__(self, *args, **kwargs):
        default_kwargs = dict(
            dim = 512,
            dim_mults = (1, 2, 3, 4),
            num_resnet_blocks = 3,
            layer_attns = (False, True, True, True),
            layer_cross_attns = (False, True, True, True),
            attn_heads = 8,
            ff_mult = 2.,
            memory_efficient = False
        )
        super().__init__(*args, **{**default_kwargs, **kwargs})

class SRUnet256(Unet):
    def __init__(self, *args, **kwargs):
        default_kwargs = dict(
            dim = 128,
            dim_mults = (1, 2, 4, 8),
            num_resnet_blocks = (2, 4, 8, 8),
            layer_attns = (False, False, False, True),
            layer_cross_attns = (False, False, False, True),
            attn_heads = 8,
            ff_mult = 2.,
            memory_efficient = True
        )
        super().__init__(*args, **{**default_kwargs, **kwargs})

class SRUnet1024(Unet):
    def __init__(self, *args, **kwargs):
        default_kwargs = dict(
            dim = 128,
            dim_mults = (1, 2, 4, 8),
            num_resnet_blocks = (2, 4, 8, 8),
            layer_attns = False,
            layer_cross_attns = (False, False, False, True),
            attn_heads = 8,
            ff_mult = 2.,
            memory_efficient = True
        )
        super().__init__(*args, **{**default_kwargs, **kwargs})
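
# illustrative usage sketch, not part of the original source: the presets above only bundle
# default hyperparameters, any of them can still be overridden at construction time
#
#   base_unet = BaseUnet64(cond_on_text = True)    # 64px base unet
#   sr_unet   = SRUnet256(lowres_cond = True)      # 64 -> 256px super-resolution unet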

# main Imagen class, the cascading DDPM from Ho et al.
class Imagen(nn.Module):
    def __init__(
        self,
        unets,
        *,
        image_sizes,                                # for cascading ddpm, image size at each stage
        text_encoder_name = DEFAULT_T5_NAME,
        text_embed_dim = None,
        channels = 3,
        timesteps = 1000,
        cond_drop_prob = 0.1,
        loss_type = 'l2',
        noise_schedules = 'cosine',
        pred_objectives = 'noise',
        random_crop_sizes = None,
        lowres_noise_schedule = 'linear',
        lowres_sample_noise_level = 0.2,            # a new trick mentioned in the paper: noise the low resolution conditioning image and, at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also designed to be conditioned on this noise level
        per_sample_random_aug_noise_level = False,  # unclear whether each batch element should receive a random augmentation noise value when conditioning on it - turned off due to @marunine's findings
        condition_on_text = True,
        auto_normalize_img = True,                  # whether to automatically normalize images from [0, 1] to [-1, 1] and back - turn this off if you pass [-1, 1] ranged images from your dataloader yourself
        dynamic_thresholding = True,
        dynamic_thresholding_percentile = 0.95,     # unsure what this was based on, from perusing the paper
        only_train_unet_number = None,
        temporal_downsample_factor = 1,
        resize_cond_video_frames = True,
        resize_mode = 'nearest',
        min_snr_loss_weight = True,                 # https://arxiv.org/abs/2303.09556
        min_snr_gamma = 5
    def force_unconditional_(self):
        self.condition_on_text = False
        self.unconditional = True

        for unet in self.unets:
            unet.cond_on_text = False

    @property
    def device(self):
        return self._temp.device
    # get the unet for the given (1-indexed) unet number
    def get_unet(self, unet_number):
        assert 0 < unet_number <= len(self.unets)
        index = unet_number - 1

        # replace the nn.ModuleList with a plain list, so moving unets between devices is not tracked as submodules
        if isinstance(self.unets, nn.ModuleList):
            unets_list = [unet for unet in self.unets]
            delattr(self, 'unets')
            self.unets = unets_list

        # move the requested unet onto the main device and all others onto cpu
        if index != self.unet_being_trained_index:
            for unet_index, unet in enumerate(self.unets):
                unet.to(self.device if unet_index == index else 'cpu')

        # remember which unet is currently being trained and return it
        self.unet_being_trained_index = index
        return self.unets[index]

    # move all unets back onto a single device, wrapped in an nn.ModuleList
    def reset_unets_all_one_device(self, device = None):
        device = default(device, self.device)
        self.unets = nn.ModuleList([*self.unets])
        self.unets.to(device)

        # no unet is singled out for training anymore
        self.unet_being_trained_index = -1

    # context manager that temporarily moves one unet onto the gpu
    @contextmanager
    def one_unet_in_gpu(self, unet_number = None, unet = None):
        # exactly one of unet_number / unet must be given
        assert exists(unet_number) ^ exists(unet)

        if exists(unet_number):
            unet = self.unets[unet_number - 1]

        cpu = torch.device('cpu')

        # remember which device each unet was on
        devices = [module_device(unet) for unet in self.unets]

        # move everything to cpu, then the chosen unet onto the main device
        self.unets.to(cpu)
        unet.to(self.device)

        yield

        # restore every unet to its original device
        for unet, device in zip(self.unets, devices):
            unet.to(device)
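
    # illustrative usage sketch, not part of the original source: keep only one unet of the
    # cascade on the accelerator while working with it
    #
    #   with imagen.one_unet_in_gpu(unet_number = 1):
    #       ...   # run sampling / training steps for unet 1 here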

    # override state_dict so all unets are gathered onto one device first
    def state_dict(self, *args, **kwargs):
        self.reset_unets_all_one_device()
        return super().state_dict(*args, **kwargs)

    # override load_state_dict for the same reason
    def load_state_dict(self, *args, **kwargs):
        self.reset_unets_all_one_device()
        return super().load_state_dict(*args, **kwargs)

    # gaussian diffusion methods

    def p_mean_variance(
        self,
        unet,
        x,
        t,
        *,
        noise_scheduler,
        text_embeds = None,
        text_mask = None,
        cond_images = None,
        cond_video_frames = None,
        post_cond_video_frames = None,
        lowres_cond_img = None,
        self_cond = None,
        lowres_noise_times = None,
        cond_scale = 1.,
        model_output = None,
        t_next = None,
        pred_objective = 'noise',
        dynamic_threshold = True
    ):
        # classifier free guidance is only usable if the model was trained with conditional dropout
        assert not (cond_scale != 1. and not self.can_classifier_guidance), 'imagen was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

        # extra conditioning kwargs for video
        video_kwargs = dict()
        if self.is_video:
            video_kwargs = dict(
                cond_video_frames = cond_video_frames,
                post_cond_video_frames = post_cond_video_frames,
            )

        # the model prediction, defaulting to a guided forward pass through the unet
        pred = default(model_output, lambda: unet.forward_with_cond_scale(
            x,
            noise_scheduler.get_condition(t),
            text_embeds = text_embeds,
            text_mask = text_mask,
            cond_images = cond_images,
            cond_scale = cond_scale,
            lowres_cond_img = lowres_cond_img,
            self_cond = self_cond,
            lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_noise_times),
            **video_kwargs
        ))

        # derive x_start from the prediction, depending on the parameterization
        if pred_objective == 'noise':
            x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
        elif pred_objective == 'x_start':
            x_start = pred
        elif pred_objective == 'v':
            x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)
        else:
            raise ValueError(f'unknown objective {pred_objective}')

        if dynamic_threshold:
            # the threshold s is a per-sample percentile of the absolute values of the reconstructed sample
            s = torch.quantile(
                rearrange(x_start, 'b ... -> b (...)').abs(),
                self.dynamic_thresholding_percentile,
                dim = -1
            )

            s.clamp_(min = 1.)
            s = right_pad_dims_to(x_start, s)
            x_start = x_start.clamp(-s, s) / s
        else:
            x_start.clamp_(-1., 1.)

        # posterior mean and variance of q(x_{t_next} | x_t, x_start)
        mean_and_variance = noise_scheduler.q_posterior(x_start = x_start, x_t = x, t = t, t_next = t_next)
        return mean_and_variance, x_start
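
    # illustrative sketch, not part of the original source: dynamic thresholding picks, per
    # sample, the given percentile s of |x_start| (at least 1), clamps x_start to [-s, s] and
    # rescales by s, so occasional large values do not saturate the [-1, 1] pixel range
    #
    #   s = torch.quantile(x_start.flatten(1).abs(), 0.95, dim = -1).clamp(min = 1.)
    #   x_start = x_start.clamp(-s[:, None, None, None], s[:, None, None, None]) / s[:, None, None, None]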

    # single denoising step (no gradients)
    @torch.no_grad()
    def p_sample(
        self,
        unet,
        x,
        t,
        *,
        noise_scheduler,
        t_next = None,
        text_embeds = None,
        text_mask = None,
        cond_images = None,
        cond_video_frames = None,
        post_cond_video_frames = None,
        cond_scale = 1.,
        self_cond = None,
        lowres_cond_img = None,
        lowres_noise_times = None,
        pred_objective = 'noise',
        dynamic_threshold = True
    ):
        b, *_, device = *x.shape, x.device

        # extra conditioning kwargs for video
        video_kwargs = dict()
        if self.is_video:
            video_kwargs = dict(
                cond_video_frames = cond_video_frames,
                post_cond_video_frames = post_cond_video_frames,
            )

        # posterior mean / variance and the predicted x_start
        (model_mean, _, model_log_variance), x_start = self.p_mean_variance(
            unet,
            x = x,
            t = t,
            t_next = t_next,
            noise_scheduler = noise_scheduler,
            text_embeds = text_embeds,
            text_mask = text_mask,
            cond_images = cond_images,
            cond_scale = cond_scale,
            lowres_cond_img = lowres_cond_img,
            self_cond = self_cond,
            lowres_noise_times = lowres_noise_times,
            pred_objective = pred_objective,
            dynamic_threshold = dynamic_threshold,
            **video_kwargs
        )

        noise = torch.randn_like(x)
        # no noise is added on the final sampling step
        is_last_sampling_timestep = (t_next == 0) if isinstance(noise_scheduler, GaussianDiffusionContinuousTimes) else (t == 0)
        nonzero_mask = (1 - is_last_sampling_timestep.float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        # sample from the posterior
        pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        return pred, x_start

    # the full sampling loop (no gradients)
    @torch.no_grad()
    def p_sample_loop(
        self,
        unet,
        shape,
        *,
        noise_scheduler,
        lowres_cond_img = None,
        lowres_noise_times = None,
        text_embeds = None,
        text_mask = None,
        cond_images = None,
        cond_video_frames = None,
        post_cond_video_frames = None,
        inpaint_images = None,
        inpaint_videos = None,
        inpaint_masks = None,
        inpaint_resample_times = 5,
        init_images = None,
        skip_steps = None,
        cond_scale = 1,
        pred_objective = 'noise',
        dynamic_threshold = True,
        use_tqdm = True
    ):
        device = self.device

        batch = shape[0]
        # start from pure gaussian noise
        img = torch.randn(shape, device = device)

        # video

        # a 5-dimensional shape means video (batch, channels, frames, height, width)
        is_video = len(shape) == 5
        frames = shape[-3] if is_video else None
        resize_kwargs = dict(target_frames = frames) if exists(frames) else dict()

        # for initialization with an image or video

        if exists(init_images):
            img += init_images

        # keep track of x0, for self conditioning

        x_start = None

        # prepare inpainting

        # inpaint_videos takes precedence over inpaint_images when both are given
        inpaint_images = default(inpaint_videos, inpaint_images)

        has_inpainting = exists(inpaint_images) and exists(inpaint_masks)
        # when inpainting, each timestep is resampled several times
        resample_times = inpaint_resample_times if has_inpainting else 1

        if has_inpainting:
            # normalize and resize the inpainting image and mask to the target shape
            inpaint_images = self.normalize_img(inpaint_images)
            inpaint_images = self.resize_to(inpaint_images, shape[-1], **resize_kwargs)
            inpaint_masks = self.resize_to(rearrange(inpaint_masks, 'b ... -> b 1 ...').float(), shape[-1], **resize_kwargs).bool()

        # time

        timesteps = noise_scheduler.get_sampling_timesteps(batch, device = device)

        # whether to skip any steps

        skip_steps = default(skip_steps, 0)
        timesteps = timesteps[skip_steps:]

        # video conditioning kwargs

        video_kwargs = dict()
        if self.is_video:
            video_kwargs = dict(
                cond_video_frames = cond_video_frames,
                post_cond_video_frames = post_cond_video_frames,
            )

        # iterate over the sampling timesteps
        for times, times_next in tqdm(timesteps, desc = 'sampling loop time step', total = len(timesteps), disable = not use_tqdm):
            is_last_timestep = times_next == 0

            # when inpainting, each timestep is resampled a number of times
            for r in reversed(range(resample_times)):
                is_last_resample_step = r == 0

                if has_inpainting:
                    # noise the inpainting image to the current timestep and paste it in via the mask
                    noised_inpaint_images, *_ = noise_scheduler.q_sample(inpaint_images, t = times)
                    img = img * ~inpaint_masks + noised_inpaint_images * inpaint_masks

                # self conditioning uses the previous x_start prediction, if the unet supports it
                self_cond = x_start if unet.self_cond else None

                # one denoising step
                img, x_start = self.p_sample(
                    unet,
                    img,
                    times,
                    t_next = times_next,
                    text_embeds = text_embeds,
                    text_mask = text_mask,
                    cond_images = cond_images,
                    cond_scale = cond_scale,
                    self_cond = self_cond,
                    lowres_cond_img = lowres_cond_img,
                    lowres_noise_times = lowres_noise_times,
                    noise_scheduler = noise_scheduler,
                    pred_objective = pred_objective,
                    dynamic_threshold = dynamic_threshold,
                    **video_kwargs
                )

                # unless this is the final resample step (or the final timestep), re-noise back to the current time
                if has_inpainting and not (is_last_resample_step or torch.all(is_last_timestep)):
                    renoised_img = noise_scheduler.q_sample_from_to(img, times_next, times)

                    img = torch.where(
                        self.right_pad_dims_to_datatype(is_last_timestep),
                        img,
                        renoised_img
                    )

        # clamp pixel values to [-1, 1]
        img.clamp_(-1., 1.)

        # final inpainting

        if has_inpainting:
            img = img * ~inpaint_masks + inpaint_images * inpaint_masks

        # unnormalize back to [0, 1] and return
        unnormalize_img = self.unnormalize_img(img)
        return unnormalize_img

    # sampling entrypoint: no gradients, model put into eval mode, arguments type-checked
    @torch.no_grad()
    @eval_decorator
    @beartype
    # generate samples from the cascade
    def sample(
        self,
        texts: List[str] = None,  # list of text prompts
        text_masks = None,  # text masks
        text_embeds = None,  # precomputed text embeddings
        video_frames = None,  # video frames
        cond_images = None,  # conditioning images
        cond_video_frames = None,  # conditioning video frames
        post_cond_video_frames = None,  # trailing conditioning video frames
        inpaint_videos = None,  # videos to inpaint
        inpaint_images = None,  # images to inpaint
        inpaint_masks = None,  # inpainting masks
        inpaint_resample_times = 5,  # number of inpainting resample steps
        init_images = None,  # initial images
        skip_steps = None,  # number of sampling steps to skip
        batch_size = 1,  # batch size
        cond_scale = 1.,  # classifier free guidance scale
        lowres_sample_noise_level = None,  # low resolution conditioning noise level
        start_at_unet_number = 1,  # unet number to start the cascade at
        start_image_or_video = None,  # starting image or video
        stop_at_unet_number = None,  # unet number to stop the cascade at
        return_all_unet_outputs = False,  # return the outputs of every unet in the cascade
        return_pil_images = False,  # return PIL images
        device = None,  # device
        use_tqdm = True,  # show a tqdm progress bar
        use_one_unet_in_gpu = True  # keep only one unet on the gpu at a time
    # compute the diffusion training losses
    @beartype
    def p_losses(
        self,
        unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel],  # the unet being trained
        x_start,  # clean starting image
        times,  # diffusion times
        *,
        noise_scheduler,  # noise scheduler
        lowres_cond_img = None,  # low resolution conditioning image
        lowres_aug_times = None,  # augmentation noise times for the low resolution conditioning
        text_embeds = None,  # text embeddings
        text_mask = None,  # text mask
        cond_images = None,  # conditioning images
        noise = None,  # noise
        times_next = None,  # next diffusion times
        pred_objective = 'noise',  # prediction objective
        min_snr_gamma = None,  # min snr gamma
        random_crop_size = None,  # random crop size
        **kwargs  # other keyword arguments
    # training forward pass
    @beartype
    def forward(
        self,
        images,  # images or video
        unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel] = None,  # the unet to train
        texts: List[str] = None,  # list of text prompts
        text_embeds = None,  # text embeddings
        text_masks = None,  # text masks
        unet_number = None,  # which unet of the cascade to train
        cond_images = None,  # conditioning images
        **kwargs  # other keyword arguments