Lucidrains Series Source Code Walkthrough (Part 16)



.\lucidrains\cross-transformers-pytorch\setup.py

# import setuptools' setup and package discovery utilities
from setuptools import setup, find_packages

# define the package metadata
setup(
  name = 'cross-transformers-pytorch',  # package name
  packages = find_packages(),  # automatically discover all packages
  version = '0.0.2',  # version number
  license='MIT',  # license
  description = 'Cross Transformers - Pytorch',  # short description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/cross-transformers-pytorch',  # project URL
  keywords = [  # keywords
    'artificial intelligence',
    'attention mechanism',
    'cross attention',
    'few shot learning'
  ],
  install_requires=[  # dependencies
    'torch>=1.6',
    'einops>=0.3'
  ],
  classifiers=[  # PyPI classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
)

.\lucidrains\DALLE-pytorch\dalle_pytorch\attention.py

from inspect import isfunction
from math import ceil

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat

from rotary_embedding_torch import apply_rotary_emb

def exists(val):
    # check whether a value is not None
    return val is not None

def uniq(arr):
    # return the unique elements of an array, preserving order
    return {el: True for el in arr}.keys()

def default(val, d):
    # return val if it exists, otherwise the default (calling it if it is a function)
    if exists(val):
        return val
    return d() if isfunction(d) else d

def max_neg_value(t):
    # return the most negative value representable in the tensor's dtype
    return -torch.finfo(t.dtype).max

def stable_softmax(t, dim = -1, alpha = 32 ** 2):
    # numerically stable softmax: scale down by alpha, subtract the max, then scale back up
    t = t / alpha
    t = t - torch.amax(t, dim = dim, keepdim = True).detach()
    return (t * alpha).softmax(dim = dim)
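
A quick numerical check (my own sketch, not part of the original file): dividing by alpha, subtracting the max, and multiplying alpha back is algebraically the same as subtracting `alpha * max`, so the result should match `torch.softmax` up to floating point error.

# sketch: verify stable_softmax against torch.softmax (assumes the helper above is in scope)
import torch

t = torch.randn(2, 8) * 100  # deliberately large logits
assert torch.allclose(t.softmax(dim = -1), stable_softmax(t), atol = 1e-4)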

def apply_pos_emb(pos_emb, qkv):
    # apply rotary positional embeddings to the query / key / value tensors
    n = qkv[0].shape[-2]
    pos_emb = pos_emb[..., :n, :]
    return tuple(map(lambda t: apply_rotary_emb(pos_emb, t), qkv))

# standard (full) attention
class Attention(nn.Module):
    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False,
                 static_mask = None):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.seq_len = seq_len
        self.scale = dim_head ** -0.5

        self.stable = stable
        self.causal = causal
        self.register_buffer('static_mask', static_mask, persistent=False)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
        # forward pass
        b, n, _, h, device = *x.shape, self.heads, x.device
        softmax = torch.softmax if not self.stable else stable_softmax
        offset = cache.get('offset', 0) if exists(cache) else 0

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        if exists(rotary_pos_emb):
            q, k, v = apply_pos_emb(rotary_pos_emb[..., offset:, :], (q, k, v))

        q = q * self.scale

        if offset > 0:
            k_top, v_top = cache[cache_key]
            k = torch.cat([k_top, k], dim=-2)
            v = torch.cat([v_top, v], dim=-2)
        if exists(cache):
            cache[cache_key] = k, v

        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
        mask_value = max_neg_value(dots)

        if exists(mask):
            mask = rearrange(mask, 'b j -> b () () j')
            dots.masked_fill_(~mask, mask_value)
            del mask

        if self.causal and offset == 0:  # causality is naturally enforced for the cached inference
            i, j = dots.shape[-2:]
            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
            dots.masked_fill_(mask, mask_value)

        if exists(self.static_mask):
            dots.masked_fill_(~self.static_mask[offset:offset + n, :offset + n], mask_value)

        attn = softmax(dots, dim=-1)

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out
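
A minimal usage sketch (the sizes are illustrative assumptions, not from the original file):

# assumes the Attention class above is in scope
import torch

attn = Attention(dim = 512, seq_len = 1024, causal = True, heads = 8, dim_head = 64)
x = torch.randn(1, 1024, 512)          # (batch, seq_len, dim)
mask = torch.ones(1, 1024).bool()      # True = token may be attended to
out = attn(x, mask = mask)             # (1, 1024, 512)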

# sparse causal attention with a convolutional attention pattern
class SparseConvCausalAttention(nn.Module):
    def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
        super().__init__()
        # the kernel must be odd so it can be centered on each position
        assert kernel_size % 2 == 1, 'kernel size must be odd'

        inner_dim = dim_head * heads
        self.seq_len = seq_len
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.image_size = image_size
        self.kernel_size = kernel_size
        self.dilation = dilation

        self.stable = stable

        # project the input to queries, keys, and values
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        # output projection followed by dropout
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None, rotary_pos_emb = None):
        b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device
        # choose the numerically stable softmax if requested
        softmax = torch.softmax if not self.stable else stable_softmax

        # image token count, and text length (one slot is reserved for <bos>)
        img_seq_len = img_size ** 2
        text_len = seq_len + 1 - img_seq_len

        # padding

        padding = seq_len - n + 1
        # default to an all-True text mask
        mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool())

        # pad the input up to the full sequence length
        x = F.pad(x, (0, 0, 0, padding), value = 0)
        mask = mask[:, :text_len]

        # derive queries / keys / values

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

        # apply rotary positional embeddings if provided
        if exists(rotary_pos_emb):
            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))

        # scale the queries
        q *= self.scale

        # split out text and image queries, keys, and values
        ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))

        # text attention

        dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
        mask_value = max_neg_value(dots_text)

        i, j = dots_text.shape[-2:]
        # causal mask for the text tokens
        text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
        dots_text.masked_fill_(text_causal_mask, mask_value)

        attn_text = softmax(dots_text, dim = -1)
        out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)

        # image attention

        # effective receptive field of the dilated kernel; pad only on the top / left for causality
        effective_kernel_size = (kernel_size - 1) * dilation + 1
        same_padding = effective_kernel_size // 2
        causal_padding = (same_padding * 2, 0, same_padding * 2, 0)

        # reshape image keys / values into 2d feature maps, pad, then unfold local neighborhoods
        k_img, v_img = map(lambda t: rearrange(t, 'b (h w) c -> b c h w', h = img_size), (k_img, v_img))
        k_img, v_img = map(lambda t: F.pad(t, causal_padding), (k_img, v_img))
        k_img, v_img = map(lambda t: F.unfold(t, kernel_size, dilation = dilation), (k_img, v_img))
        k_img, v_img = map(lambda t: rearrange(t, 'b (d j) i -> b i j d', j = kernel_size ** 2), (k_img, v_img))

        # each image token attends to its local neighborhood and to all text tokens

        dots_image = einsum('b i d, b i j d -> b i j', q_img, k_img)
        dots_image_to_text = einsum('b i d, b j d -> b i j', q_img, k_text)

        # build the padding mask by unfolding a map of ones the same way as the keys
        i, j = dots_image.shape[-2:]
        ones = torch.ones((img_seq_len,), device = device)
        ones = rearrange(ones, '(h w) -> () () h w', h = img_size)
        ones = F.pad(ones, causal_padding, value = 0.)
        ones = F.unfold(ones, kernel_size, dilation = dilation)
        ones = rearrange(ones, 'b j i -> b i j')

        # positions that fell into the padding are masked out
        padding_mask = ones == 0.

        # concatenate the (inverted) text mask with the image padding mask
        padding_mask = repeat(padding_mask, '() i j -> b i j', b = b * h)
        mask = repeat(mask, 'b j -> (b h) i j', i = i, h = h)
        mask = torch.cat((~mask, padding_mask), dim = -1)

        # image attends to all text

        dots = torch.cat((dots_image_to_text, dots_image), dim = -1)
        dots.masked_fill_(mask, mask_value)

        attn = softmax(dots, dim = -1)

        # aggregate

        attn_image_to_text, attn_image = attn[..., :text_len], attn[..., text_len:]

        out_image_to_image = einsum('b i j, b i j d -> b i d', attn_image, v_img)
        out_image_to_text = einsum('b i j, b j d -> b i d', attn_image_to_text, v_text)

        out_image = out_image_to_image + out_image_to_text

        # combine the text and image attention outputs

        out = torch.cat((out_text, out_image), dim = 1)

        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        out = self.to_out(out)
        return out[:, :n]
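
A usage sketch under the module's sequence convention, text_len = seq_len + 1 - image_size ** 2 (the sizes below are illustrative assumptions):

import torch

text_len, image_size = 128, 32
seq_len = text_len + image_size ** 2 - 1   # 1151

attn = SparseConvCausalAttention(dim = 512, seq_len = seq_len, image_size = image_size, kernel_size = 5)
x = torch.randn(1, seq_len, 512)
out = attn(x)                              # (1, seq_len, 512)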

# sparse axial causal attention

class SparseAxialCausalAttention(nn.Module):
    def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
        super().__init__()
        # attention runs along a single image axis: 0 for height, 1 for width
        assert axis in {0, 1}, 'axis must be either 0 (along height) or 1 (along width)'
        self.axis = axis

        inner_dim = dim_head * heads
        self.seq_len = seq_len
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.image_size = image_size

        self.stable = stable

        # project the input to queries, keys, and values
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        # output projection followed by dropout
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None, rotary_pos_emb = None):
        b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
        # choose the numerically stable softmax if requested
        softmax = torch.softmax if not self.stable else stable_softmax

        # image token count, and text length (one slot is reserved for <bos>)
        img_seq_len = img_size ** 2
        text_len = seq_len + 1 - img_seq_len

        # padding

        padding = seq_len - n + 1
        # default to an all-True text mask
        mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool())

        # pad the input up to the full sequence length
        x = F.pad(x, (0, 0, 0, padding), value = 0)
        mask = mask[:, :text_len]

        # derive queries / keys / values

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

        # apply rotary positional embeddings if provided
        if exists(rotary_pos_emb):
            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))

        # scale the queries
        q *= self.scale

        # split out text and image queries, keys, and values
        ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))

        # text attention

        dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
        mask_value = max_neg_value(dots_text)

        i, j = dots_text.shape[-2:]
        # causal mask for the text tokens
        text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
        dots_text.masked_fill_(text_causal_mask, mask_value)

        attn_text = softmax(dots_text, dim = -1)
        out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)

        # image attention

        # the rearrangement patterns depend on whether attention runs along height or width
        split_axis_einops = 'b (h w) c -> b h w c' if axis == 0 else 'b (h w) c -> b w h c'
        merge_axis_einops = 'b x n d -> b (x n) d' if axis == 0 else 'b x n d -> b (n x) d'

        # split out the chosen axis

        q_img, k_img, v_img = map(lambda t: rearrange(t, split_axis_einops, h = img_size), (q_img, k_img, v_img))

        # similarity

        dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img, k_img)
        dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text)

        dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1)

        # mask so that the image has full attention to the text, but is causal along the axis

        bh, x, i, j = dots.shape
        causal_mask = torch.ones(i, img_size, device = device).triu_(img_size - i + 1).bool()
        causal_mask = repeat(causal_mask, 'i j -> b x i j', b = bh, x = x)

        mask = repeat(mask, 'b j -> (b h) x i j', h = h, x = x, i = i)
        mask = torch.cat((~mask, causal_mask), dim = -1)

        dots.masked_fill_(mask, mask_value)

        # attention

        attn = softmax(dots, dim = -1)

        # aggregate

        attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:]

        out_image_to_image = einsum('b x i j, b x j d -> b x i d', attn_image_to_image, v_img)
        out_image_to_text = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text)

        out_image = out_image_to_image + out_image_to_text

        # merge the axis back

        out_image = rearrange(out_image, merge_axis_einops, x = img_size)

        # combine the text and image attention outputs

        out = torch.cat((out_text, out_image), dim = 1)

        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        out = self.to_out(out)
        return out[:, :n]
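
The axial variant follows the same sequence convention; a typical pattern is to alternate the two axes across layers. A sketch with illustrative sizes:

import torch

text_len, image_size = 128, 32
seq_len = text_len + image_size ** 2 - 1

row_attn = SparseAxialCausalAttention(dim = 512, seq_len = seq_len, image_size = image_size, axis = 0)
col_attn = SparseAxialCausalAttention(dim = 512, seq_len = seq_len, image_size = image_size, axis = 1)

x = torch.randn(1, seq_len, 512)
out = col_attn(row_attn(x))                # attend along rows, then columns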

# sparse attention via DeepSpeed's sparse attention kernels
class SparseAttention(Attention):
    def __init__(
        self,
        *args,
        block_size = 16,            # sparsity block size
        text_seq_len = 256,         # length of the text portion of the sequence
        num_random_blocks = None,   # number of random blocks; derived from seq_len if None
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig  # deferred import, since deepspeed is optional
        self.block_size = block_size

        num_random_blocks = default(num_random_blocks, self.seq_len // block_size // 4)
        # the text blocks are global, i.e. attended to by every position
        global_block_indices = list(range(ceil(text_seq_len / block_size)))

        # configure DeepSpeed's sparse self-attention
        self.attn_fn = SparseSelfAttention(
            sparsity_config = VariableSparsityConfig(
                num_heads = self.heads,
                block = self.block_size,
                num_random_blocks = num_random_blocks,
                global_block_indices = global_block_indices,
                attention = 'unidirectional' if self.causal else 'bidirectional'
            ),
            max_seq_length = self.seq_len,
            attn_mask_mode = 'add'
        )

    def forward(self, x, mask = None, rotary_pos_emb = None):
        b, n, _, h, device = *x.shape, self.heads, x.device
        remainder = n % self.block_size
        mask = default(mask, lambda: torch.ones(b, n, device = device).bool())

        # pad the sequence length up to a multiple of the block size
        if remainder > 0:
            padding = self.block_size - remainder
            x = F.pad(x, (0, 0, 0, padding), value = 0)
            mask = F.pad(mask, (0, padding), value = False)

        # derive queries / keys / values
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        # apply rotary positional embeddings if provided
        if exists(rotary_pos_emb):
            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))

        # key padding mask (True = masked out)
        key_pad_mask = None
        if exists(mask):
            key_pad_mask = ~mask

        # additive causal attention mask
        attn_mask = None
        if self.causal:
            i, j = q.shape[-2], k.shape[-2]
            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
            attn_mask = torch.zeros(i, j, device = device).to(q)
            mask_value = max_neg_value(q) / 2
            attn_mask.masked_fill_(mask, mask_value)

        # run DeepSpeed's sparse self-attention
        out = self.attn_fn(q, k, v, attn_mask = attn_mask, key_padding_mask = key_pad_mask)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out[:, :n]
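
This class depends on DeepSpeed's sparse attention ops (GPU only), so the sketch below only runs where deepspeed is installed with those kernels; the sizes are illustrative assumptions:

import torch

attn = SparseAttention(dim = 512, seq_len = 1024, heads = 8, block_size = 16, text_seq_len = 256).cuda()
x = torch.randn(1, 1024, 512).cuda()
out = attn(x)                              # (1, 1024, 512)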

.\lucidrains\DALLE-pytorch\dalle_pytorch\dalle_pytorch.py

from math import log2, sqrt
import torch
from torch import nn, einsum
import torch.nn.functional as F
import numpy as np

from axial_positional_embedding import AxialPositionalEmbedding
from einops import rearrange

from dalle_pytorch import distributed_utils
from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE
from dalle_pytorch.transformer import Transformer, DivideMax
# helpers

def exists(val):
    # check whether a value is not None
    return val is not None

def default(val, d):
    # return val if it exists, otherwise the default
    return val if exists(val) else d

# callable that always returns the same value
class always():
    def __init__(self, val):
        self.val = val
    def __call__(self, x, *args, **kwargs):
        return self.val

def is_empty(t):
    # check whether a tensor has no elements
    return t.nelement() == 0

def masked_mean(t, mask, dim = 1):
    # mean over the sequence dimension, counting only positions where mask is True
    t = t.masked_fill(~mask[:, :, None], 0.)
    return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]

def prob_mask_like(shape, prob, device):
    # boolean mask of the given shape, True with probability `prob`
    return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

def set_requires_grad(model, value):
    # enable or disable gradients for all parameters of a model
    for param in model.parameters():
        param.requires_grad = value

def eval_decorator(fn):
    # run `fn` with the model in eval mode, then restore the previous training mode
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# sampling helpers

def log(t, eps = 1e-20):
    # log with clamping to avoid log(0)
    return torch.log(t.clamp(min = eps))

def gumbel_noise(t):
    # sample Gumbel noise with the same shape as t
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

def gumbel_sample(t, temperature = 1., dim = -1):
    # sample from logits via the Gumbel-max trick
    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)

def top_k(logits, thres = 0.5):
    # keep only the top (1 - thres) fraction of logits, setting the rest to -inf
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
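
A small demo of how these helpers combine during decoding (my sketch, assuming the functions above are in scope):

import torch

logits = torch.randn(2, 10)
filtered = top_k(logits, thres = 0.5)      # keep the top 5 of 10 logits, the rest become -inf
token = gumbel_sample(filtered, temperature = 1.)
print(token.shape)                         # torch.Size([2])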

# shared input / output embedding layer
class SharedEmbedding(nn.Embedding):
    def __init__(self, linear, start_index, end_index, **kwargs):
        super().__init__(end_index - start_index, linear.weight.shape[1], **kwargs)
        del self.weight

        self.linear = linear
        self.start_index = start_index
        self.end_index = end_index

    def forward(self, input):
        return F.embedding(
            input, self.linear.weight[self.start_index:self.end_index], self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse)

# discrete VAE

# ResNet block
class ResBlock(nn.Module):
    def __init__(self, chan):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.ReLU(),
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.ReLU(),
            nn.Conv2d(chan, chan, 1)
        )

    def forward(self, x):
        return self.net(x) + x

# discrete VAE class
class DiscreteVAE(nn.Module):
    def __init__(
        self,
        image_size = 256,
        num_tokens = 512,
        codebook_dim = 512,
        num_layers = 3,
        num_resnet_blocks = 0,
        hidden_dim = 64,
        channels = 3,
        smooth_l1_loss = False,
        temperature = 0.9,
        straight_through = False,
        reinmax = False,
        kl_div_loss_weight = 0.,
        normalization = ((*((0.5,) * 3), 0), (*((0.5,) * 3), 1))
    ):
        super().__init__()
        # the image size must be a power of 2 so it can be halved `num_layers` times
        assert log2(image_size).is_integer(), 'image size must be a power of 2'
        assert num_layers >= 1, 'number of layers must be greater than or equal to 1'
        has_resblocks = num_resnet_blocks > 0

        self.channels = channels
        self.image_size = image_size
        self.num_tokens = num_tokens
        self.num_layers = num_layers
        self.temperature = temperature
        self.straight_through = straight_through
        self.reinmax = reinmax

        # the codebook of discrete latents
        self.codebook = nn.Embedding(num_tokens, codebook_dim)

        hdim = hidden_dim

        # channel counts for the encoder and decoder stages
        enc_chans = [hidden_dim] * num_layers
        dec_chans = list(reversed(enc_chans))

        enc_chans = [channels, *enc_chans]

        dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
        dec_chans = [dec_init_chan, *dec_chans]

        enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))

        enc_layers = []
        dec_layers = []

        # strided convolutions downsample in the encoder; transposed convolutions upsample in the decoder
        for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
            enc_layers.append(nn.Sequential(nn.Conv2d(enc_in, enc_out, 4, stride = 2, padding = 1), nn.ReLU()))
            dec_layers.append(nn.Sequential(nn.ConvTranspose2d(dec_in, dec_out, 4, stride = 2, padding = 1), nn.ReLU()))

        # optional residual blocks at the bottleneck
        for _ in range(num_resnet_blocks):
            dec_layers.insert(0, ResBlock(dec_chans[1]))
            enc_layers.append(ResBlock(enc_chans[-1]))

        if num_resnet_blocks > 0:
            dec_layers.insert(0, nn.Conv2d(codebook_dim, dec_chans[1], 1))

        enc_layers.append(nn.Conv2d(enc_chans[-1], num_tokens, 1))
        dec_layers.append(nn.Conv2d(dec_chans[-1], channels, 1))

        self.encoder = nn.Sequential(*enc_layers)
        self.decoder = nn.Sequential(*dec_layers)

        # reconstruction loss and KL divergence weight
        self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
        self.kl_div_loss_weight = kl_div_loss_weight

        # per-channel normalization statistics, handled within the class
        self.normalization = tuple(map(lambda t: t[:channels], normalization))

        self._register_external_parameters()

    def _register_external_parameters(self):
        """Register external parameters for DeepSpeed partitioning."""
        if (
                not distributed_utils.is_distributed
                or not distributed_utils.using_backend(
                    distributed_utils.DeepSpeedBackend)
        ):
            return

        deepspeed = distributed_utils.backend.backend_module
        deepspeed.zero.register_external_parameter(self, self.codebook.weight)

    def norm(self, images):
        if not exists(self.normalization):
            return images

        means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
        means, stds = map(lambda t: rearrange(t, 'c -> () c () ()'), (means, stds))
        images = images.clone()
        images.sub_(means).div_(stds)
        return images

    @torch.no_grad()
    @eval_decorator
    def get_codebook_indices(self, images):
        logits = self(images, return_logits = True)
        codebook_indices = logits.argmax(dim = 1).flatten(1)
        return codebook_indices

    def decode(
        self,
        img_seq
    ):
        image_embeds = self.codebook(img_seq)
        b, n, d = image_embeds.shape
        h = w = int(sqrt(n))

        image_embeds = rearrange(image_embeds, 'b (h w) d -> b d h w', h = h, w = w)
        images = self.decoder(image_embeds)
        return images

    def forward(
        self,
        img,
        return_loss = False,
        return_recons = False,
        return_logits = False,
        temp = None
        ):
        device, num_tokens, image_size, kl_div_loss_weight = img.device, self.num_tokens, self.image_size, self.kl_div_loss_weight
        assert img.shape[-1] == image_size and img.shape[-2] == image_size, f'input must have the correct image size {image_size}'

        # normalize the input images
        img = self.norm(img)

        # encode to logits over the codebook
        logits = self.encoder(img)

        # return logits directly if requested - used in DALL-E training to get the hard image token indices
        if return_logits:
            return logits

        # temperature for Gumbel softmax, defaulting to self.temperature
        temp = default(temp, self.temperature)

        # sample (soft or hard) one-hot codes with Gumbel softmax
        one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=self.straight_through)

        if self.straight_through and self.reinmax:
            # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
            # algorithm 2
            one_hot = one_hot.detach()
            π0 = logits.softmax(dim=1)
            π1 = (one_hot + (logits / temp).softmax(dim=1)) / 2
            π1 = ((log(π1) - logits).detach() + logits).softmax(dim=1)
            π2 = 2 * π1 - 0.5 * π0
            one_hot = π2 - π2.detach() + one_hot

        # project the one-hot codes through the codebook and decode
        sampled = einsum('b n h w, n d -> b d h w', one_hot, self.codebook.weight)
        out = self.decoder(sampled)

        if not return_loss:
            return out

        # reconstruction loss
        recon_loss = self.loss_fn(img, out)

        # KL divergence against a uniform prior over the codes
        logits = rearrange(logits, 'b n h w -> b (h w) n')
        log_qy = F.log_softmax(logits, dim=-1)
        log_uniform = torch.log(torch.tensor([1. / num_tokens], device=device))
        kl_div = F.kl_div(log_uniform, log_qy, None, None, 'batchmean', log_target=True)

        loss = recon_loss + (kl_div * kl_div_loss_weight)

        if not return_recons:
            return loss

        # return both the loss and the reconstructed images
        return loss, out
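
A usage sketch for the two stages of the discrete VAE (the sizes are illustrative assumptions):

import torch

vae = DiscreteVAE(image_size = 256, num_tokens = 512, codebook_dim = 512, num_layers = 3, hidden_dim = 64)
images = torch.randn(4, 3, 256, 256)

loss = vae(images, return_loss = True)     # reconstruction (+ optionally weighted KL) loss
loss.backward()

indices = vae.get_codebook_indices(images) # (4, 1024) discrete tokens for the DALL-E stage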

# main CLIP class
class CLIP(nn.Module):
    def __init__(
        self,
        *,
        dim_text = 512,            # text embedding dimension
        dim_image = 512,           # image embedding dimension
        dim_latent = 512,          # shared latent dimension
        num_text_tokens = 10000,   # text vocabulary size
        text_enc_depth = 6,        # text encoder depth
        text_seq_len = 256,        # text sequence length
        text_heads = 8,            # text attention heads
        num_visual_tokens = 512,   # visual vocabulary size
        visual_enc_depth = 6,      # visual encoder depth
        visual_heads = 8,          # visual attention heads
        visual_image_size = 256,   # input image size
        visual_patch_size = 32,    # patch size
        channels = 3               # image channels
    ):
        super().__init__()
        # text token embeddings, positional embeddings, and the text transformer
        self.text_emb = nn.Embedding(num_text_tokens, dim_text)
        self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
        self.text_transformer = Transformer(causal = False, seq_len = text_seq_len, dim = dim_text, depth = text_enc_depth, heads = text_heads, rotary_emb = False)
        self.to_text_latent = nn.Linear(dim_text, dim_latent, bias = False)

        # the image must divide evenly into patches
        assert visual_image_size % visual_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (visual_image_size // visual_patch_size) ** 2
        patch_dim = channels * visual_patch_size ** 2

        self.visual_patch_size = visual_patch_size
        # patch embedding, positional embeddings, and the visual transformer
        self.to_visual_embedding = nn.Linear(patch_dim, dim_image)
        self.visual_pos_emb = nn.Embedding(num_patches, dim_image)
        self.visual_transformer = Transformer(causal = False, seq_len = num_patches, dim = dim_image, depth = visual_enc_depth, heads = visual_heads, rotary_emb = False)
        self.to_visual_latent = nn.Linear(dim_image, dim_latent, bias = False)

        # learned temperature for the contrastive logits
        self.temperature = nn.Parameter(torch.tensor(1.))

    # forward pass
    def forward(
        self,
        text,
        image,
        text_mask = None,
        return_loss = False
    ):
        b, device, p = text.shape[0], text.device, self.visual_patch_size

        # embed the text tokens and add positional embeddings
        text_emb = self.text_emb(text)
        text_emb += self.text_pos_emb(torch.arange(text.shape[1], device = device))

        # extract image patches and embed them
        image_patches = rearrange(image, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        image_emb = self.to_visual_embedding(image_patches)
        image_emb += self.visual_pos_emb(torch.arange(image_emb.shape[1], device = device))

        # encode text and image
        enc_text = self.text_transformer(text_emb, mask = text_mask)
        enc_image = self.visual_transformer(image_emb)

        # pool the text encodings into a single latent, respecting the mask if given
        if exists(text_mask):
            text_latents = masked_mean(enc_text, text_mask, dim = 1)
        else:
            text_latents = enc_text.mean(dim = 1)

        # pool the image encodings
        image_latents = enc_image.mean(dim = 1)

        # project into the shared latent space
        text_latents = self.to_text_latent(text_latents)
        image_latents = self.to_visual_latent(image_latents)

        # l2-normalize both latents
        text_latents, image_latents = map(lambda t: F.normalize(t, p = 2, dim = -1), (text_latents, image_latents))

        temp = self.temperature.exp()

        # without a loss, return the per-pair similarities
        if not return_loss:
            sim = einsum('n d, n d -> n', text_latents, image_latents) * temp
            return sim

        # symmetric contrastive loss over the full similarity matrix
        sim = einsum('i d, j d -> i j', text_latents, image_latents) * temp
        labels = torch.arange(b, device = device)
        loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
        return loss
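
A usage sketch (token ids and sizes are illustrative assumptions; this CLIP is what DALLE.generate_images uses to rank samples):

import torch

clip = CLIP(dim_text = 512, dim_image = 512, dim_latent = 512, num_text_tokens = 10000, text_seq_len = 256, visual_image_size = 256, visual_patch_size = 32)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones(4, 256).bool()

loss = clip(text, images, text_mask = mask, return_loss = True)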

# main DALL-E class
class DALLE(nn.Module):
    def __init__(
        self,
        *,
        dim,
        vae,
        num_text_tokens = 10000,
        text_seq_len = 256,
        depth,
        heads = 8,
        dim_head = 64,
        reversible = False,
        attn_dropout = 0.,
        ff_dropout = 0,
        sparse_attn = False,
        attn_types = None,
        loss_img_weight = 7,
        stable = False,
        sandwich_norm = False,
        shift_tokens = True,
        rotary_emb = True,
        shared_attn_ids = None,
        shared_ff_ids = None,
        share_input_output_emb = False,
        optimize_for_inference = False,
    ):
        super().__init__()
        assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE)), 'vae must be an instance of DiscreteVAE, OpenAIDiscreteVAE, or VQGanVAE'

        # derive the image sequence geometry from the VAE
        image_size = vae.image_size
        num_image_tokens = vae.num_tokens
        image_fmap_size = (vae.image_size // (2 ** vae.num_layers))
        image_seq_len = image_fmap_size ** 2

        # reserve a unique padding token per position (text seq len)
        num_text_tokens = num_text_tokens + text_seq_len
        # text and image positional embeddings (no-ops when rotary embeddings are used)
        self.text_pos_emb = nn.Embedding(text_seq_len + 1, dim) if not rotary_emb else always(0) # +1 for <bos>
        self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (image_fmap_size, image_fmap_size)) if not rotary_emb else always(0)

        self.num_text_tokens = num_text_tokens
        self.num_image_tokens = num_image_tokens

        self.text_seq_len = text_seq_len
        self.image_seq_len = image_seq_len

        # total sequence length and total vocabulary size
        seq_len = text_seq_len + image_seq_len
        total_tokens = num_text_tokens + num_image_tokens
        self.total_tokens = total_tokens
        self.total_seq_len = seq_len

        # freeze the VAE; it is not trained with the transformer
        self.vae = vae
        set_requires_grad(self.vae, False)

        # the main transformer
        self.transformer = Transformer(
            dim = dim,
            causal = True,
            seq_len = seq_len,
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            reversible = reversible,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            attn_types = attn_types,
            image_fmap_size = image_fmap_size,
            sparse_attn = sparse_attn,
            stable = stable,
            sandwich_norm = sandwich_norm,
            shift_tokens = shift_tokens,
            rotary_emb = rotary_emb,
            shared_attn_ids = shared_attn_ids,
            shared_ff_ids = shared_ff_ids,
            optimize_for_inference = optimize_for_inference,
        )

        self.stable = stable

        # when stable, normalize activations by their max
        if stable:
            self.norm_by_max = DivideMax(dim = -1)

        # projection to logits
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, self.total_tokens),
        )

        # optionally share the input embeddings with the output projection
        if share_input_output_emb:
            self.text_emb = SharedEmbedding(self.to_logits[1], 0, num_text_tokens)
            self.image_emb = SharedEmbedding(self.to_logits[1], num_text_tokens, total_tokens)
        else:
            self.text_emb = nn.Embedding(num_text_tokens, dim)
            self.image_emb = nn.Embedding(num_image_tokens, dim)

        seq_range = torch.arange(seq_len)
        logits_range = torch.arange(total_tokens)

        seq_range = rearrange(seq_range, 'n -> () n ()')
        logits_range = rearrange(logits_range, 'd -> () () d')

        # mask that forbids text positions from predicting image tokens, and vice versa
        logits_mask = (
            ((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
            ((seq_range < text_seq_len) & (logits_range >= num_text_tokens))
        )

        self.register_buffer('logits_mask', logits_mask, persistent=False)
        self.loss_img_weight = loss_img_weight


    @torch.no_grad()
    @eval_decorator
    def generate_texts(
        self,
        tokenizer,
        text = None,
        *,
        filter_thres = 0.5,
        temperature = 1.
        ):
        text_seq_len = self.text_seq_len
        # start from a single padding token when no prompt text is given
        if text is None or text == "":
            text_tokens = torch.tensor([[0]]).cuda()
        else:
            # encode the prompt into tokens and move to GPU
            text_tokens = torch.tensor(tokenizer.tokenizer.encode(text)).cuda().unsqueeze(0)

        # autoregressively sample until the text reaches text_seq_len tokens
        for _ in range(text_tokens.shape[1], text_seq_len):
            device = text_tokens.device

            # embed the tokens and add positional embeddings
            tokens = self.text_emb(text_tokens)
            tokens += self.text_pos_emb(torch.arange(text_tokens.shape[1], device=device))

            seq_len = tokens.shape[1]

            # run the transformer
            output_transf = self.transformer(tokens)

            # when stable, normalize the output by its max
            if self.stable:
                output_transf = self.norm_by_max(output_transf)

            logits = self.to_logits(output_transf)

            # mask the logits so text only predicts text; keep only the last position
            logits_mask = self.logits_mask[:, :seq_len]
            max_neg_value = -torch.finfo(logits.dtype).max
            logits.masked_fill_(logits_mask, max_neg_value)
            logits = logits[:, -1, :]

            # filter to the top-k logits, then Gumbel-sample
            filtered_logits = top_k(logits, thres=filter_thres)
            sample = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)

            # append the sampled token
            text_tokens = torch.cat((text_tokens, sample[:, None]), dim=-1)

        # the per-position padding tokens, stripped when decoding
        padding_tokens = set(np.arange(self.text_seq_len) + (self.num_text_tokens - self.text_seq_len))
        texts = [tokenizer.tokenizer.decode(text_token, pad_tokens=padding_tokens) for text_token in text_tokens]
        return text_tokens, texts

    @torch.no_grad()
    @eval_decorator
    def generate_images(
        self,
        text,
        *,
        clip=None,
        filter_thres=0.5,
        temperature=1.,
        img=None,
        num_init_img_tokens=None,
        cond_scale=1.,
        use_cache=False,
    ):
        vae, text_seq_len, image_seq_len, num_text_tokens = self.vae, self.text_seq_len, self.image_seq_len, self.num_text_tokens
        total_len = text_seq_len + image_seq_len

        # truncate the text to the allowed length
        text = text[:, :text_seq_len]
        out = text

        # optionally prime the generation with part of an existing image
        if exists(img):
            image_size = vae.image_size
            assert img.shape[1] == 3 and img.shape[2] == image_size and img.shape[3] == image_size, f'input image must have the correct image size {image_size}'

            # get the image's codebook indices
            indices = vae.get_codebook_indices(img)
            num_img_tokens = default(num_init_img_tokens, int(0.4375 * image_seq_len))  # OpenAI used 14 * 32 initial tokens to prime
            assert num_img_tokens < image_seq_len, 'number of initial image tokens for priming must be less than the total image token sequence length'

            indices = indices[:, :num_img_tokens]
            out = torch.cat((out, indices), dim=-1)

        prev_cache = None
        cache = {} if use_cache else None
        # autoregressively sample image tokens until the sequence is complete
        for cur_len in range(out.shape[1], total_len):
            is_image = cur_len >= text_seq_len

            text, image = out[:, :text_seq_len], out[:, text_seq_len:]

            # forward pass with classifier-free guidance
            logits = self.forward_with_cond_scale(text, image, cond_scale=cond_scale, cache=cache)
            logits = logits[:, -1, :]

            # filter to the top-k logits, then Gumbel-sample
            filtered_logits = top_k(logits, thres=filter_thres)
            sample = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)

            # image tokens are stored without the text-vocabulary offset
            sample -= (num_text_tokens if is_image else 0)
            out = torch.cat((out, sample[:, None]), dim=-1)

        text_seq = out[:, :text_seq_len]
        img_seq = out[:, -image_seq_len:]
        # decode the image token sequence back to pixels
        images = vae.decode(img_seq)

        # optionally rank the generations with CLIP
        if exists(clip):
            scores = clip(text_seq, images, return_loss=False)
            return images, scores

        return images
    # forward pass with classifier-free guidance
    def forward_with_cond_scale(self, *args, cond_scale = 1, cache = None, **kwargs):
        # with a scale of 1 this reduces to the ordinary forward pass
        if cond_scale == 1:
            return self(*args, **kwargs)

        prev_cache = cache.copy() if exists(cache) else None
        # conditioned logits
        logits = self(*args, cache = cache, **kwargs)

        # discovery by Katherine Crowson
        # https://twitter.com/RiversHaveWings/status/1478093658716966912
        # run a second pass with the text condition dropped (null_cond_prob = 1),
        # then extrapolate: null_logits + (logits - null_logits) * cond_scale
        null_cond_logits = self(*args, null_cond_prob = 1., cache = prev_cache, **kwargs)
        return null_cond_logits + (logits - null_cond_logits) * cond_scale

    def forward(
        self,
        text,
        image = None,
        return_loss = False,
        null_cond_prob = 0.,
        cache = None,
    ):
        assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})'
        batch, device, total_seq_len = text.shape[0], text.device, self.total_seq_len

        # randomly remove the text condition with probability <null_cond_prob> (for classifier-free guidance)

        if null_cond_prob > 0:
            null_mask = prob_mask_like((batch,), null_cond_prob, device=device)
            text *= rearrange(~null_mask, 'b -> b 1')

        # make sure padding in the text tokens gets a unique padding token id

        text_range = torch.arange(self.text_seq_len, device=device) + (self.num_text_tokens - self.text_seq_len)
        text = torch.where(text == 0, text_range, text)

        # add the <bos> token

        text = F.pad(text, (1, 0), value=0)

        # embed the text and add positional embeddings
        tokens = self.text_emb(text)
        tokens += self.text_pos_emb(torch.arange(text.shape[1], device=device))

        seq_len = tokens.shape[1]

        # append image tokens if an image (raw or tokenized) was given
        if exists(image) and not is_empty(image):
            is_raw_image = len(image.shape) == 4

            if is_raw_image:
                # tokenize raw images into codebook indices with the frozen VAE
                image_size = self.vae.image_size
                channels = self.vae.channels
                assert tuple(image.shape[1:]) == (channels, image_size, image_size), f'invalid image of dimensions {image.shape} passed in during training'

                image = self.vae.get_codebook_indices(image)

            image_len = image.shape[1]
            image_emb = self.image_emb(image)

            # add image positional embeddings
            image_emb += self.image_pos_emb(image_emb)

            # concatenate the text and image embeddings
            tokens = torch.cat((tokens, image_emb), dim=1)

            seq_len += image_len

        # when training, if the length exceeds the total text + image length,
        # remove the last token, since it does not need to be trained

        if tokens.shape[1] > total_seq_len:
            seq_len -= 1
            tokens = tokens[:, :-1]

        # when stable, let only a small fraction of the gradient flow through the tokens
        if self.stable:
            alpha = 0.1
            tokens = tokens * alpha + tokens.detach() * (1 - alpha)

        # with a warm cache, only the newest token needs to be processed
        if exists(cache) and cache.get('offset'):
            tokens = tokens[:, -1:]
        out = self.transformer(tokens, cache=cache)

        if self.stable:
            # normalize the output by its max
            out = self.norm_by_max(out)

        logits = self.to_logits(out)

        # mask the logits to make sure text predicts text (except for the last token), and image predicts image

        logits_mask = self.logits_mask[:, :seq_len]
        if exists(cache) and cache.get('offset'):
            logits_mask = logits_mask[:, -1:]
        max_neg_value = -torch.finfo(logits.dtype).max
        logits.masked_fill_(logits_mask, max_neg_value)

        # advance the cache offset
        if exists(cache):
            cache['offset'] = cache.get('offset', 0) + logits.shape[1]

        if not return_loss:
            return logits

        assert exists(image), 'when training, image must be supplied'

        # offset the image tokens into the joint vocabulary
        offsetted_image = image + self.num_text_tokens
        # the targets are the inputs shifted by one position
        labels = torch.cat((text[:, 1:], offsetted_image), dim=1)

        logits = rearrange(logits, 'b n c -> b c n')

        # cross entropy on the text and image portions separately
        loss_text = F.cross_entropy(logits[:, :, :self.text_seq_len], labels[:, :self.text_seq_len])
        loss_img = F.cross_entropy(logits[:, :, self.text_seq_len:], labels[:, self.text_seq_len:])

        # weighted total loss
        loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1)
        return loss
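
An end-to-end sketch tying the pieces together (sizes are illustrative assumptions; in practice the VAE is pretrained first and stays frozen here):

import torch

vae = DiscreteVAE(image_size = 256, num_tokens = 8192, num_layers = 3)

dalle = DALLE(dim = 512, vae = vae, num_text_tokens = 10000, text_seq_len = 256, depth = 2, heads = 8)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)

loss = dalle(text, images, return_loss = True)   # trains the transformer only
loss.backward()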

.\lucidrains\DALLE-pytorch\dalle_pytorch\distributed_backends\deepspeed_backend.py

import json
import os

import torch

from .distributed_backend import DistributedBackend


class DeepSpeedBackend(DistributedBackend):
    """使用 DeepSpeed 引擎的分布式后端。"""

    BACKEND_MODULE_NAME = 'deepspeed'
    BACKEND_NAME = 'DeepSpeed'

    def wrap_arg_parser(self, parser):
        if not self.has_backend():
            parser.add_argument(
                '--deepspeed',
                type=lambda _: False,
                help=(
                    'whether to use DeepSpeed '
                    '(ignored because it is not available)'
                ),
            )
        else:
            parser = self.backend_module.add_config_arguments(parser)

        parser.add_argument(
            '--local_rank',
            type=int,
            default=-1,
            help='local rank passed from the distributed launcher',
        )
        return parser

    def _initialize(self):
        self.backend_module.init_distributed()
        if torch.cuda.is_available():
            torch.cuda.set_device(self._get_local_rank())

    @staticmethod
    def _require_torch_distributed_init():
        """Raise an error when `torch.distributed` has not been initialized yet."""
        assert torch.distributed.is_initialized(), \
            ('`torch.distributed` is not initialized; please call '
             '`DeepSpeedBackend.initialize` at the start of your script')

    def _get_world_size(self):
        self._require_torch_distributed_init()
        return torch.distributed.get_world_size()

    def _get_rank(self):
        self._require_torch_distributed_init()
        return torch.distributed.get_rank()

    def _get_local_rank(self):
        self._require_torch_distributed_init()
        return int(os.environ['LOCAL_RANK'])

    def _local_barrier(self):
        self._require_torch_distributed_init()
        torch.distributed.barrier()

    def _check_args(self, args, optimizer, lr_scheduler, kwargs):
        """Return an appropriate optimizer and learning rate scheduler
        after checking the values passed to `distribute`.
        """
        self._check_argvs(args, optimizer, lr_scheduler, kwargs)
        (optimizer, lr_scheduler) = self._check_config(
            args, optimizer, lr_scheduler, kwargs)
        return (optimizer, lr_scheduler)

    def _check_argvs(self, args, optimizer, lr_scheduler, kwargs):
        """Apply several sanity checks to the given command line arguments."""
        has_json_config = (hasattr(args, 'deepspeed_config')
                           and args.deepspeed_config is not None)
        has_dict_config = 'config_params' in kwargs
        if (
                # no config was given
                (not has_json_config and not has_dict_config)
                # JSON config file does not exist
                or (not has_dict_config
                    and not os.path.isfile(args.deepspeed_config))
        ):
            # let DeepSpeed handle these argument errors
            return

        if not args.deepspeed:
            print(
                'WARNING: DeepSpeed backend was selected; setting `args.deepspeed = True`'
            )
            args.deepspeed = True

        if has_json_config and has_dict_config:
            print(
                'WARNING: DeepSpeed config was given both as a JSON file and as a Python dictionary. The Python dictionary takes precedence.'
            )
    def _check_config(self, args, optimizer, lr_scheduler, kwargs):
        """Return an appropriate optimizer and learning rate scheduler
        for the DeepSpeed configuration.
        """
        # load the config from the keyword arguments or from the JSON file
        if 'config_params' in kwargs:
            config = kwargs['config_params']
        else:
            with open(args.deepspeed_config, 'r') as json_config_file:
                config = json.load(json_config_file)

        if 'optimizer' in config and optimizer is not None:
            print(
                'WARNING: Optimizer encountered in both DeepSpeed config and '
                'keyword arguments. Optimizer in DeepSpeed config '
                'takes precedence.'
            )
            optimizer = None

        if 'scheduler' in config and lr_scheduler is not None:
            print(
                'WARNING: Learning rate scheduler encountered in both '
                'DeepSpeed config and keyword arguments. Learning rate '
                'scheduler in DeepSpeed config takes precedence.'
            )
            # For the LR scheduler, the JSON config already takes precedence.
            # We do this for forward compatibility.
            lr_scheduler = None

        return (optimizer, lr_scheduler)

    def _distribute(
            self,
            args=None,
            model=None,
            optimizer=None,
            model_parameters=None,
            training_data=None,
            lr_scheduler=None,
            **kwargs,
    ):
        """Return a distributed model engine, optimizer, dataloader, and
        learning rate scheduler. These are obtained by wrapping the
        given values with the backend.

        For the other possible arguments,
        see `deepspeed.initialize`.
        """
        (optimizer, lr_scheduler) = self._check_args(
            args, optimizer, lr_scheduler, kwargs)

        return self.backend_module.initialize(
            args=args,
            model=model,
            optimizer=optimizer,
            model_parameters=model_parameters,
            training_data=training_data,
            lr_scheduler=lr_scheduler,
            **kwargs,
        )

    def _average_all(self, tensor):
        self._require_torch_distributed_init()
        # We copy because modification happens in-place
        averaged = tensor.detach().clone()
        # We use `all_reduce` because it is better supported than `reduce`
        torch.distributed.all_reduce(averaged, torch.distributed.ReduceOp.SUM)
        return averaged / self.get_world_size()

.\lucidrains\DALLE-pytorch\dalle_pytorch\distributed_backends\distributed_backend.py

"""
An abstract backend for distributed deep learning.

Provides several standard utility methods under a common API.
Please check the documentation of the class `DistributedBackend` for
details to implement a new backend.
"""

from importlib import import_module


class DistributedBackend:
    """An abstract backend class for distributed deep learning.

    Provides several standard utility methods under a common API.
    Variables that must be overridden:
    - BACKEND_MODULE_NAME
    - BACKEND_NAME
    Methods that must be overridden:
    - wrap_arg_parser
    - _initialize
    - _get_world_size
    - _get_rank
    - _get_local_rank
    - _local_barrier
    - _distribute
    - _average_all
    """

    BACKEND_MODULE_NAME = None
    """Name of the module to import for the backend."""
    BACKEND_NAME = None
    """Name of the backend for printing."""

    ROOT_RANK = 0

    backend_module = None
    """The module to access the backend."""
    is_initialized = False
    """Whether the backend is initialized."""

    def __init__(self):
        if self.BACKEND_MODULE_NAME is None:
            raise NotImplementedError('BACKEND_MODULE_NAME is not set')
        if self.BACKEND_NAME is None:
            raise NotImplementedError('BACKEND_NAME is not set')

    def has_backend(self):
        """Return whether the backend module is now imported."""
        try:
            self.backend_module = import_module(self.BACKEND_MODULE_NAME)
        except ModuleNotFoundError:
            return False
        return True

    def check_batch_size(self, batch_size):
        """Check whether the batch size makes sense for distribution."""
        assert batch_size >= self.get_world_size(), \
            (f"batch size can't be smaller than number of processes "
             f'({batch_size} < {self.get_world_size()})')

    def wrap_arg_parser(self, parser):
        """Add arguments to support optional distributed backend usage."""
        raise NotImplementedError

    def initialize(self):
        """Initialize the distributed backend."""
        self._initialize()
        self.is_initialized = True

    def _initialize(self):
        """Initialize the distributed backend."""
        raise NotImplementedError

    def require_init(self):
        """Raise an error when the backend has not been initialized yet."""
        assert self.is_initialized, \
            (f'{self.BACKEND_NAME} backend has not been initialized; please call '
             f'`distributed_utils.initialize` at the start of your script to '
             f'allow optional distributed usage')

    def get_world_size(self):
        """Return the amount of distributed processes."""
        self.require_init()
        return self._get_world_size()

    def _get_world_size(self):
        """Return the amount of distributed processes."""
        raise NotImplementedError

    def get_rank(self):
        """Return the global rank of the calling worker process."""
        self.require_init()
        return self._get_rank()

    def _get_rank(self):
        """Return the global rank of the calling worker process."""
        raise NotImplementedError

    def get_local_rank(self):
        """Return the local rank of the calling worker process.
        The local rank is the rank based on a single node's processes.
        """
        self.require_init()
        return self._get_local_rank()

    def _get_local_rank(self):
        """Return the local rank of the calling worker process.
        The local rank is the rank based on a single node's processes.
        """
        raise NotImplementedError

    def is_root_worker(self):
        """Return whether the calling worker has the root rank."""
        return self.get_rank() == self.ROOT_RANK

    def is_local_root_worker(self):
        """Return whether the calling worker has the root rank on this node."""
        return self.get_local_rank() == self.ROOT_RANK
    def local_barrier(self):
        """Wait until all processes on this node have called this function."""
        # make sure the backend is initialized
        self.require_init()
        # delegate to the backend-specific barrier
        self._local_barrier()

    def _local_barrier(self):
        """Wait until all processes on this node have called this function."""
        raise NotImplementedError

    def distribute(
            self,
            args=None,
            model=None,
            optimizer=None,
            model_parameters=None,
            training_data=None,
            lr_scheduler=None,
            **kwargs,
    ):
        """Return a distributed model engine, optimizer, dataloader, and
        learning rate scheduler. These are obtained by wrapping the
        given values with the backend.
        """
        # make sure the backend is initialized
        self.require_init()
        # delegate to the backend-specific implementation
        return self._distribute(
            args,
            model,
            optimizer,
            model_parameters,
            training_data,
            lr_scheduler,
            **kwargs,
        )

    def _distribute(
            self,
            args=None,
            model=None,
            optimizer=None,
            model_parameters=None,
            training_data=None,
            lr_scheduler=None,
            **kwargs,
    ):
        """Return a distributed model engine, optimizer, dataloader, and
        learning rate scheduler. These are obtained by wrapping the
        given values with the backend.
        """
        raise NotImplementedError

    def average_all(self, tensor):
        """Return the average of `tensor` over all workers."""
        # make sure the backend is initialized
        self.require_init()
        # return the tensor averaged over all worker processes
        return self._average_all(tensor)

    def _average_all(self, tensor):
        """Return the average of `tensor` over all workers."""
        raise NotImplementedError

.\lucidrains\DALLE-pytorch\dalle_pytorch\distributed_backends\dummy_backend.py

from .distributed_backend import DistributedBackend

# a stand-in backend that runs everything in a single process
class DummyBackend(DistributedBackend):
    """Acts like a distributed backend.

    Used as a stand-in replacement to obtain a non-distributed program.
    """

    BACKEND_MODULE_NAME = 'NO MODULE'
    BACKEND_NAME = 'Dummy'

    # the dummy backend is always available
    def has_backend(self):
        return True

    # return the argument parser unchanged
    def wrap_arg_parser(self, parser):
        return parser

    # initialization is a no-op
    def _initialize(self):
        pass

    # a single process, so the world size is 1
    def _get_world_size(self):
        return 1

    # the only process is the root
    def _get_rank(self):
        return self.ROOT_RANK

    def _get_local_rank(self):
        return self.ROOT_RANK

    # no other processes to wait for
    def _local_barrier(self):
        pass

    def _distribute(
            self,
            _args=None,
            model=None,
            optimizer=None,
            _model_parameters=None,
            training_data=None,
            lr_scheduler=None,
            **_kwargs,
    ):
        """Return the model, optimizer, dataloader, and learning rate scheduler
        as is.
        """
        return (model, optimizer, training_data, lr_scheduler)

    # averaging over a single worker is the identity
    def _average_all(self, tensor):
        return tensor
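
The common API can be exercised with the dummy backend (a sketch; in real scripts the backend is chosen via `distributed_utils.set_backend_from_args`):

backend = DummyBackend()
backend.initialize()
assert backend.has_backend()
assert backend.get_world_size() == 1
assert backend.is_root_worker()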

.\lucidrains\DALLE-pytorch\dalle_pytorch\distributed_backends\horovod_backend.py

import torch

from .distributed_backend import DistributedBackend

class HorovodBackend(DistributedBackend):
    """Distributed backend using Horovod."""

    # Module to import and the user-facing backend name
    BACKEND_MODULE_NAME = 'horovod.torch'
    BACKEND_NAME = 'Horovod'

    # Return the argument parser unchanged
    def wrap_arg_parser(self, parser):
        return parser

    def check_batch_size(self, batch_size):
        # Horovod uses the local batch size to determine the effective batch size
        pass

    def _initialize(self):
        # Initialize the Horovod module
        self.backend_module.init()
        # If CUDA is available, bind this process to the GPU matching its local rank
        if torch.cuda.is_available():
            torch.cuda.set_device(self._get_local_rank())

    # Number of workers across all nodes
    def _get_world_size(self):
        return self.backend_module.size()

    # Global rank of this worker
    def _get_rank(self):
        return self.backend_module.rank()

    # Rank of this worker on its own node
    def _get_local_rank(self):
        return self.backend_module.local_rank()

    def _local_barrier(self):
        # Actually a global barrier, but that works for our purposes
        self.backend_module.join()

    def _distribute(
            self,
            _args=None,
            model=None,
            optimizer=None,
            _model_parameters=None,
            training_data=None,
            lr_scheduler=None,
            **_kwargs,
    ):
        # Wrap the optimizer so gradients are all-reduced across workers
        optimizer = self.backend_module.DistributedOptimizer(optimizer)
        # Broadcast the initial model parameters and optimizer state from the
        # root so every worker starts from an identical state
        self.backend_module.broadcast_parameters(
            model.state_dict(), root_rank=self.ROOT_RANK)
        self.backend_module.broadcast_optimizer_state(
            optimizer, root_rank=self.ROOT_RANK)
        return (model, optimizer, training_data, lr_scheduler)

    def _average_all(self, tensor):
        # Horovod's allreduce averages by default
        averaged = self.backend_module.allreduce(tensor)
        return averaged
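
In a training loop, `_average_all` is what makes a logged metric consistent across workers. Stripped of the backend wrapper, the underlying Horovod calls look like this (a sketch assuming Horovod is installed and the script is launched with, e.g., `horovodrun -np 4 python script.py`):

import torch
import horovod.torch as hvd

hvd.init()
# Each worker computes a different local value...
local_loss = torch.tensor(float(hvd.rank()))
# ...and allreduce, which averages by default, hands every worker the same mean
mean_loss = hvd.allreduce(local_loss)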

.\lucidrains\DALLE-pytorch\dalle_pytorch\distributed_backends\__init__.py

# Re-export every available backend implementation
from .deepspeed_backend import DeepSpeedBackend
from .distributed_backend import DistributedBackend
from .dummy_backend import DummyBackend
from .horovod_backend import HorovodBackend

.\lucidrains\DALLE-pytorch\dalle_pytorch\distributed_utils.py

"""
Utility functions for optional distributed execution.

To use,
1. set the `BACKENDS` to the ones you want to make available,
2. in the script, wrap the argument parser with `wrap_arg_parser`,
3. in the script, set and use the backend by calling
   `set_backend_from_args`.

You can check whether a backend is in use with the `using_backend`
function.
"""

from dalle_pytorch.distributed_backends import \
    DeepSpeedBackend, \
    DummyBackend, \
    HorovodBackend

_DEFAULT_BACKEND = DummyBackend()
"""Which backend to use by default. Assumed to be _not_ distributed."""

BACKENDS = [
    _DEFAULT_BACKEND,
    DeepSpeedBackend(),
    HorovodBackend(),
]

is_distributed = None
"""Whether we are distributed."""
backend = None
"""Backend in usage."""


def wrap_arg_parser(parser):
    """Add arguments to support optional distributed backend usage."""
    parser.add_argument(
        '--distributed_backend',
        '--distr_backend',
        type=str,
        default=None,
        help='which distributed backend to use. Do not distribute by default',
    )
    for distr_backend in BACKENDS:
        parser = distr_backend.wrap_arg_parser(parser)
    return parser


def set_backend_from_args(args):
    """Set and return the backend based on the given `args`."""
    global is_distributed, backend

    # Handle this specially for backwards compatibility.
    if args.deepspeed:
        args.distributed_backend = DeepSpeedBackend.BACKEND_NAME

    if not args.distributed_backend:
        is_distributed = False
        backend = _DEFAULT_BACKEND
        return backend

    backend_name = args.distributed_backend.lower()
    for distr_backend in BACKENDS:
        if distr_backend.BACKEND_NAME.lower() == backend_name:
            backend = distr_backend
            if not backend.has_backend():
                raise ModuleNotFoundError(
                    f'{backend.BACKEND_NAME} backend selected but '
                    'module not available'
                )

            print(f'Using {backend.BACKEND_NAME} for distributed execution')
            is_distributed = True
            return backend

    raise ValueError(
        'unknown backend; please check `distributed_utils.BACKENDS`')


def require_set_backend():
    """Raise an `AssertionError` when the backend has not been set."""
    assert backend is not None, (
        'distributed backend is not set. Please call '
        '`distributed_utils.set_backend_from_args` at the start of your script'
    )


def using_backend(test_backend):
    """Return whether the backend is set to `test_backend`.

    `test_backend` may be a string of the name of the backend or
    its class.
    """
    require_set_backend()
    if isinstance(test_backend, str):
        return backend.BACKEND_NAME == test_backend
    return isinstance(backend, test_backend)
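
Putting the three steps from the module docstring together, a training script would wire things up roughly as follows (a sketch; `initialize()` is assumed to be the public wrapper around `_initialize` defined in `distributed_backend.py`):

import argparse

from dalle_pytorch import distributed_utils

parser = argparse.ArgumentParser()
parser = distributed_utils.wrap_arg_parser(parser)  # step 2: add backend flags
args = parser.parse_args()

# Step 3: select and initialize the backend based on the parsed flags
distr_backend = distributed_utils.set_backend_from_args(args)
distr_backend.initialize()

if distributed_utils.using_backend('Horovod'):
    print('running under Horovod')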

.\lucidrains\DALLE-pytorch\dalle_pytorch\loader.py

from pathlib import Path
from random import randint, choice
import PIL
from torch.utils.data import Dataset
from torchvision import transforms as T

class TextImageDataset(Dataset):
    def __init__(self,
                 folder,
                 text_len=256,
                 image_size=128,
                 truncate_captions=False,
                 resize_ratio=0.75,
                 transparent=False,
                 tokenizer=None,
                 shuffle=False
                 ):
        """
        @param folder: 包含图像和文本文件的文件夹,它们通过其路径的相应“stem”匹配
        @param truncate_captions: 如果标题太长,将截断标题而不是抛出异常
        """
        super().__init__()
        self.shuffle = shuffle
        path = Path(folder)

        # Collect the paths of all text and image files
        text_files = [*path.glob('**/*.txt')]
        image_files = [
            *path.glob('**/*.png'), *path.glob('**/*.jpg'),
            *path.glob('**/*.jpeg'), *path.glob('**/*.bmp')
        ]

        # Index text and image files by their path stems
        text_files = {text_file.stem: text_file for text_file in text_files}
        image_files = {image_file.stem: image_file for image_file in image_files}

        # Keep only stems that have both a text file and an image file
        keys = (image_files.keys() & text_files.keys())

        self.keys = list(keys)
        self.text_files = {k: v for k, v in text_files.items() if k in keys}
        self.image_files = {k: v for k, v in image_files.items() if k in keys}
        self.text_len = text_len
        self.truncate_captions = truncate_captions
        self.resize_ratio = resize_ratio
        self.tokenizer = tokenizer

        # Choose the image mode depending on transparency support
        image_mode = 'RGBA' if transparent else 'RGB'

        # Image preprocessing: normalize the mode, take a square random crop
        # (ratio=(1., 1.)) covering at least resize_ratio of the image, to tensor
        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert(image_mode)
            if img.mode != image_mode else img),
            T.RandomResizedCrop(image_size,
                                scale=(self.resize_ratio, 1.),
                                ratio=(1., 1.)),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.keys)

    def random_sample(self):
        return self.__getitem__(randint(0, self.__len__() - 1))

    def sequential_sample(self, ind):
        if ind >= self.__len__() - 1:
            return self.__getitem__(0)
        return self.__getitem__(ind + 1)

    def skip_sample(self, ind):
        if self.shuffle:
            return self.random_sample()
        return self.sequential_sample(ind=ind)

    def __getitem__(self, ind):
        key = self.keys[ind]

        text_file = self.text_files[key]
        image_file = self.image_files[key]

        # Read the caption file and split it into individual captions
        descriptions = text_file.read_text().split('\n')
        descriptions = list(filter(lambda t: len(t) > 0, descriptions))
        try:
            description = choice(descriptions)
        except IndexError as zero_captions_in_file_ex:
            print(f"An exception occurred trying to load file {text_file}.")
            print(f"Skipping index {ind}")
            return self.skip_sample(ind)

        # Tokenize the chosen caption
        tokenized_text = self.tokenizer.tokenize(
            description,
            self.text_len,
            truncate_text=self.truncate_captions
        ).squeeze(0)
        try:
            image_tensor = self.image_transform(PIL.Image.open(image_file))
        except (PIL.UnidentifiedImageError, OSError) as corrupt_image_exceptions:
            print(f"An exception occurred trying to load file {image_file}.")
            print(f"Skipping index {ind}")
            return self.skip_sample(ind)

        # Success: return the tokenized text and the image tensor
        return tokenized_text, image_tensor
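
Putting the dataset to work (a sketch; it assumes a folder of caption/image pairs with matching stems, and that `dalle_pytorch.tokenizer` exposes a default `tokenizer` instance, as the training script uses):

from torch.utils.data import DataLoader

from dalle_pytorch.tokenizer import tokenizer  # assumed default BPE tokenizer

ds = TextImageDataset(
    './data',            # hypothetical folder of matching .txt/.jpg pairs
    text_len=256,
    image_size=128,
    truncate_captions=True,
    tokenizer=tokenizer,
    shuffle=True,
)
dl = DataLoader(ds, batch_size=4, shuffle=True)

texts, images = next(iter(dl))  # (4, 256) token ids and (4, 3, 128, 128) pixels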