Lucidrains Series Projects Source Code Analysis (Part 26)

.\lucidrains\denoising-diffusion-pytorch\denoising_diffusion_pytorch\karras_unet_1d.py

"""
the magnitude-preserving unet proposed in https://arxiv.org/abs/2312.02696 by Karras et al.
"""

import math
from math import sqrt, ceil
from functools import partial

import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F

from einops import rearrange, repeat, pack, unpack

from denoising_diffusion_pytorch.attend import Attend

# helpers functions

# check whether a value exists (is not None)
def exists(x):
    return x is not None

# return the value if it exists, otherwise a default (which may be a callable)
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# logical XNOR - true when both or neither condition holds
def xnor(x, y):
    return not (x ^ y)

# append an element to the end of a list
def append(arr, el):
    arr.append(el)

# insert an element at the beginning of a list
def prepend(arr, el):
    arr.insert(0, el)

# pack a single tensor according to the given pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# unpack a packed tensor back to its original shape
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# cast a value to a tuple of the given length
def cast_tuple(t, length = 1):
    if isinstance(t, tuple):
        return t
    return ((t,) * length)

# check whether one number divides another exactly
def divisible_by(numer, denom):
    return (numer % denom) == 0

# l2 normalize along a dimension
def l2norm(t, dim = -1, eps = 1e-12):
    return F.normalize(t, dim = dim, eps = eps)

# interpolate along the sequence dimension by adding and removing a dummy spatial axis
def interpolate_1d(x, length, mode = 'bilinear'):
    x = rearrange(x, 'b c t -> b c t 1')
    x = F.interpolate(x, (length, 1), mode = mode)
    return rearrange(x, 'b c t 1 -> b c t')

# mp activations
# section 2.5

# magnitude preserving SiLU activation
class MPSiLU(Module):
    def forward(self, x):
        return F.silu(x) / 0.596
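
A quick numerical sanity check of the 0.596 constant (illustrative, not part of the original file): per section 2.5 of the paper, silu of a standard normal input has a root-mean-square magnitude of roughly 0.596, so the division restores unit magnitude.

x = torch.randn(1_000_000)
print(F.silu(x).square().mean().sqrt())    # ≈ 0.596
print(MPSiLU()(x).square().mean().sqrt())  # ≈ 1.0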

# gain - layer scaling

# learned gain (layer scaling), initialized at zero
class Gain(Module):
    def __init__(self):
        super().__init__()
        self.gain = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        return x * self.gain

# magnitude preserving concat
# equation (103) - default to 0.5, which they recommended

# magnitude preserving concatenation
class MPCat(Module):
    def __init__(self, t = 0.5, dim = -1):
        super().__init__()
        self.t = t
        self.dim = dim

    def forward(self, a, b):
        dim, t = self.dim, self.t
        Na, Nb = a.shape[dim], b.shape[dim]

        C = sqrt((Na + Nb) / ((1. - t) ** 2 + t ** 2))

        a = a * (1. - t) / sqrt(Na)
        b = b * t / sqrt(Nb)

        return C * torch.cat((a, b), dim = dim)
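
Worked example (illustrative): because each branch is rescaled by its own width, the concatenated result keeps unit magnitude for unit-magnitude inputs, regardless of how many channels each side contributes.

a = torch.randn(2, 256, 64)            # 256 channels
b = torch.randn(2, 64, 64)             # 64 channels
out = MPCat(t = 0.5, dim = 1)(a, b)
print(out.shape)                       # torch.Size([2, 320, 64])
print(out.square().mean().sqrt())      # ≈ 1.0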

# magnitude preserving sum
# equation (88)
# empirically, they found t=0.3 for encoder / decoder / attention residuals
# and for embedding, t=0.5

# magnitude preserving sum
class MPAdd(Module):
    def __init__(self, t):
        super().__init__()
        self.t = t

    def forward(self, x, res):
        a, b, t = x, res, self.t
        num = a * (1. - t) + b * t
        den = sqrt((1 - t) ** 2 + t ** 2)
        return num / den
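
The reasoning behind the denominator: for independent unit-variance inputs a and b, Var((1 - t) * a + t * b) = (1 - t)^2 + t^2, so dividing by the square root of that restores unit variance; t controls how strongly the residual branch is mixed in. A small check (illustrative):

x, res = torch.randn(2, 65536), torch.randn(2, 65536)
print(MPAdd(t = 0.3)(x, res).std())    # ≈ 1.0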

# pixelnorm
# equation (30)

# pixel norm layer
class PixelNorm(Module):
    def __init__(self, dim, eps = 1e-4):
        super().__init__()
        # high epsilon for the pixel norm in the paper
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        dim = self.dim
        return l2norm(x, dim = dim, eps = self.eps) * sqrt(x.shape[dim])
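
Illustrative check: PixelNorm rescales each position so its mean square along the chosen dimension is one, since the sqrt(dim) factor cancels the 1/sqrt(dim) that unit-normalizing introduces.

pn = PixelNorm(dim = 1)
x = torch.randn(2, 192, 64)
print(pn(x).square().mean(dim = 1))    # all entries ≈ 1.0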

# forced weight normed conv2d and linear
# algorithm 1 in paper

# normalize weights (each output unit rescaled to norm sqrt(fan_in))
def normalize_weight(weight, eps = 1e-4):
    weight, ps = pack_one(weight, 'o *')
    normed_weight = l2norm(weight, eps = eps)
    normed_weight = normed_weight * sqrt(weight.numel() / weight.shape[0])
    return unpack_one(normed_weight, ps, 'o *')
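
Illustrative check: after normalization, every output unit's flattened weight vector has norm sqrt(fan_in), so the `/ sqrt(fan_in)` in the layers below leaves effectively unit-norm rows at use time.

w = torch.randn(32, 16, 3)                 # fan_in = 16 * 3 = 48
nw = normalize_weight(w)
print(nw.flatten(1).norm(dim = -1))        # all ≈ sqrt(48) ≈ 6.93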

# forced weight-normalized 1d convolution
class Conv1d(Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        kernel_size,
        eps = 1e-4,
        init_dirac = False,
        concat_ones_to_input = False   # they use this in the input block to protect against loss of expressivity due to removal of all biases, even though they claim they observed none
    ):
        super().__init__()
        weight = torch.randn(dim_out, dim_in + int(concat_ones_to_input), kernel_size)
        self.weight = nn.Parameter(weight)

        if init_dirac:
            nn.init.dirac_(self.weight)

        self.eps = eps
        self.fan_in = dim_in * kernel_size
        self.concat_ones_to_input = concat_ones_to_input
    def forward(self, x):
        # during training, re-normalize the stored weight in place (no gradient)
        if self.training:
            with torch.no_grad():
                normed_weight = normalize_weight(self.weight, eps = self.eps)
                self.weight.copy_(normed_weight)

        # normalize the weight again for use, scaled by 1 / sqrt(fan_in)
        weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)

        # optionally concat a channel of ones onto the input (replaces the removed bias)
        if self.concat_ones_to_input:
            x = F.pad(x, (0, 0, 1, 0), value = 1.)

        # 'same'-padded 1d convolution with the normalized weight
        return F.conv1d(x, weight, padding = 'same')
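
Note the two-part scheme from algorithm 1 of the paper: in training mode, each forward pass first re-normalizes the stored parameter in place (so optimizer updates cannot grow its magnitude), and the weight actually used is normalized again and scaled by 1/sqrt(fan_in). A small illustration:

conv = Conv1d(16, 32, 3).train()
_ = conv(torch.randn(1, 16, 64))
print(conv.weight.flatten(1).norm(dim = -1))   # all ≈ sqrt(48) after one forward pass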
# forced weight-normalized linear layer
class Linear(Module):
    def __init__(self, dim_in, dim_out, eps = 1e-4):
        super().__init__()
        weight = torch.randn(dim_out, dim_in)
        self.weight = nn.Parameter(weight)
        self.eps = eps
        self.fan_in = dim_in

    def forward(self, x):
        # during training, re-normalize the stored weight in place (no gradient)
        if self.training:
            with torch.no_grad():
                normed_weight = normalize_weight(self.weight, eps = self.eps)
                self.weight.copy_(normed_weight)

        # normalize the weight again for use, scaled by 1 / sqrt(fan_in)
        weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)
        return F.linear(x, weight)

# mp fourier embedding

class MPFourierEmbedding(Module):
    def __init__(self, dim):
        super().__init__()
        assert divisible_by(dim, 2)
        half_dim = dim // 2
        # fixed random frequencies, not trained
        self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = False)

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        # sqrt(2) compensates for E[sin^2] = E[cos^2] = 1/2, keeping unit magnitude
        return torch.cat((freqs.sin(), freqs.cos()), dim = -1) * sqrt(2)
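
Illustrative usage: each sin/cos pair satisfies sin² + cos² = 1, so after the sqrt(2) factor every embedding row has mean square exactly one.

emb = MPFourierEmbedding(16)
times = torch.rand(8)
out = emb(times)                           # shape (8, 16)
print(out.square().mean(dim = -1))         # 1.0 for every row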

# building block modules

class Encoder(Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        *,
        emb_dim = None,
        dropout = 0.1,
        mp_add_t = 0.3,
        has_attn = False,
        attn_dim_head = 64,
        attn_res_mp_add_t = 0.3,
        attn_flash = False,
        downsample = False
    ):
        super().__init__()
        dim_out = default(dim_out, dim)

        # whether this block downsamples
        self.downsample = downsample
        self.downsample_conv = None

        curr_dim = dim
        if downsample:
            # 1x1 conv to project channels after the resolution is halved
            self.downsample_conv = Conv1d(curr_dim, dim_out, 1)
            curr_dim = dim_out

        # pixel norm
        self.pixel_norm = PixelNorm(dim = 1)

        self.to_emb = None
        # project the conditioning embedding to a per-channel scale
        if exists(emb_dim):
            self.to_emb = nn.Sequential(
                Linear(emb_dim, dim_out),
                Gain()
            )

        # first block
        self.block1 = nn.Sequential(
            MPSiLU(),
            Conv1d(curr_dim, dim_out, 3)
        )

        # second block
        self.block2 = nn.Sequential(
            MPSiLU(),
            nn.Dropout(dropout),
            Conv1d(dim_out, dim_out, 3)
        )

        # magnitude preserving residual sum
        self.res_mp_add = MPAdd(t = mp_add_t)

        self.attn = None
        # optional attention
        if has_attn:
            self.attn = Attention(
                dim = dim_out,
                heads = max(ceil(dim_out / attn_dim_head), 2),
                dim_head = attn_dim_head,
                mp_add_t = attn_res_mp_add_t,
                flash = attn_flash
            )

    def forward(
        self,
        x,
        emb = None
    ):
        if self.downsample:
            # halve the sequence length by interpolation, then project channels
            x = interpolate_1d(x, x.shape[-1] // 2, mode = 'bilinear')
            x = self.downsample_conv(x)

        # pixel norm
        x = self.pixel_norm(x)

        # keep a copy for the residual
        res = x.clone()

        # first block
        x = self.block1(x)

        # scale by the conditioning embedding, if given
        if exists(emb):
            scale = self.to_emb(emb) + 1
            x = x * rearrange(scale, 'b c -> b c 1')

        # second block
        x = self.block2(x)

        # magnitude preserving residual sum
        x = self.res_mp_add(x, res)

        # optional attention
        if exists(self.attn):
            x = self.attn(x)

        return x

# decoder block

class Decoder(Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        *,
        emb_dim = None,
        dropout = 0.1,
        mp_add_t = 0.3,
        has_attn = False,
        attn_dim_head = 64,
        attn_res_mp_add_t = 0.3,
        attn_flash = False,
        upsample = False
    ):
        super().__init__()
        dim_out = default(dim_out, dim)

        # whether this block upsamples; only non-upsampling decoders consume a skip connection
        self.upsample = upsample
        self.needs_skip = not upsample

        self.to_emb = None
        # project the conditioning embedding to a per-channel scale
        if exists(emb_dim):
            self.to_emb = nn.Sequential(
                Linear(emb_dim, dim_out),
                Gain()
            )

        # first block
        self.block1 = nn.Sequential(
            MPSiLU(),
            Conv1d(dim, dim_out, 3)
        )

        # second block
        self.block2 = nn.Sequential(
            MPSiLU(),
            nn.Dropout(dropout),
            Conv1d(dim_out, dim_out, 3)
        )

        # 1x1 conv on the residual branch when channel counts differ
        self.res_conv = Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

        # magnitude preserving residual sum
        self.res_mp_add = MPAdd(t = mp_add_t)

        # optional attention
        self.attn = None
        if has_attn:
            self.attn = Attention(
                dim = dim_out,
                heads = max(ceil(dim_out / attn_dim_head), 2),
                dim_head = attn_dim_head,
                mp_add_t = attn_res_mp_add_t,
                flash = attn_flash
            )

    def forward(
        self,
        x,
        emb = None
    ):
        if self.upsample:
            # double the sequence length by interpolation
            x = interpolate_1d(x, x.shape[-1] * 2, mode = 'bilinear')

        # residual branch
        res = self.res_conv(x)

        # first block
        x = self.block1(x)

        # scale by the conditioning embedding, if given
        if exists(emb):
            scale = self.to_emb(emb) + 1
            x = x * rearrange(scale, 'b c -> b c 1')

        # second block
        x = self.block2(x)

        # magnitude preserving residual sum
        x = self.res_mp_add(x, res)

        # optional attention
        if exists(self.attn):
            x = self.attn(x)

        return x
# attention

class Attention(Module):
    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 64,
        num_mem_kv = 4,
        flash = False,
        mp_add_t = 0.3
    ):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads

        # pixel norm applied to queries, keys and values
        self.pixel_norm = PixelNorm(dim = -1)

        # attention implementation (optionally flash)
        self.attend = Attend(flash = flash)

        # learned memory key / value tokens
        self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head))
        self.to_qkv = Conv1d(dim, hidden_dim * 3, 1)
        self.to_out = Conv1d(hidden_dim, dim, 1)

        # magnitude preserving residual sum
        self.mp_add = MPAdd(t = mp_add_t)

    def forward(self, x):
        res, b, c, n = x, *x.shape

        # project the input to queries, keys, values
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) n -> b h n c', h = self.heads), qkv)

        # expand the memory key / values across the batch and prepend them
        mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), self.mem_kv)
        k, v = map(partial(torch.cat, dim = -2), ((mk, k), (mv, v)))

        # pixel norm the queries, keys, values
        q, k, v = map(self.pixel_norm, (q, k, v))

        # attention
        out = self.attend(q, k, v)

        out = rearrange(out, 'b h n d -> b (h d) n')
        out = self.to_out(out)

        return self.mp_add(out, res)

# the 1d version of the unet proposed by Karras et al.
class KarrasUnet1D(Module):
    """
    going by figure 21. config G
    """

    def __init__(
        self,
        *,
        seq_len,
        dim = 192,
        dim_max = 768,            # channels double on each downsample, capped at this value
        num_classes = None,       # in the paper, 1000 classes for a popular benchmark
        channels = 4,             # 4 channels in the paper, presumably including an alpha channel
        num_downsamples = 3,
        num_blocks_per_stage = 4,
        attn_res = (16, 8),
        fourier_dim = 16,
        attn_dim_head = 64,
        attn_flash = False,
        mp_cat_t = 0.5,
        mp_add_emb_t = 0.5,
        attn_res_mp_add_t = 0.3,
        resnet_mp_add_t = 0.3,
        dropout = 0.1,
        self_condition = False
    ):
        super().__init__()

        # self conditioning
        self.self_condition = self_condition

        # determine dimensions

        self.channels = channels
        self.seq_len = seq_len
        # input channels double when self-conditioning
        input_channels = channels * (2 if self_condition else 1)

        # input and output blocks

        self.input_block = Conv1d(input_channels, dim, 3, concat_ones_to_input = True)

        self.output_block = nn.Sequential(
            Conv1d(dim, channels, 3),
            Gain()
        )

        # time embedding

        emb_dim = dim * 4

        self.to_time_emb = nn.Sequential(
            MPFourierEmbedding(fourier_dim),
            Linear(fourier_dim, emb_dim)
        )

        # class embedding

        self.needs_class_labels = exists(num_classes)
        self.num_classes = num_classes

        if self.needs_class_labels:
            self.to_class_emb = Linear(num_classes, 4 * dim)
            self.add_class_emb = MPAdd(t = mp_add_emb_t)

        # final activation applied to the embedding

        self.emb_activation = MPSiLU()

        # number of downsamples

        self.num_downsamples = num_downsamples

        # attention resolutions

        attn_res = set(cast_tuple(attn_res))

        # resnet block kwargs shared across stages

        block_kwargs = dict(
            dropout = dropout,
            emb_dim = emb_dim,
            attn_dim_head = attn_dim_head,
            attn_res_mp_add_t = attn_res_mp_add_t,
            attn_flash = attn_flash
        )

        # unet encoders and decoders

        self.downs = ModuleList([])
        self.ups = ModuleList([])

        curr_dim = dim
        curr_res = seq_len

        self.skip_mp_cat = MPCat(t = mp_cat_t, dim = 1)

        # take care of the skip connection for the initial input block and the first stage of encoders

        prepend(self.ups, Decoder(dim * 2, dim, **block_kwargs))

        assert num_blocks_per_stage >= 1

        for _ in range(num_blocks_per_stage):
            enc = Encoder(curr_dim, curr_dim, **block_kwargs)
            dec = Decoder(curr_dim * 2, curr_dim, **block_kwargs)

            append(self.downs, enc)
            prepend(self.ups, dec)

        # stages

        for _ in range(self.num_downsamples):
            dim_out = min(dim_max, curr_dim * 2)
            upsample = Decoder(dim_out, curr_dim, has_attn = curr_res in attn_res, upsample = True, **block_kwargs)

            curr_res //= 2
            has_attn = curr_res in attn_res

            downsample = Encoder(curr_dim, dim_out, downsample = True, has_attn = has_attn, **block_kwargs)

            append(self.downs, downsample)
            prepend(self.ups, upsample)
            prepend(self.ups, Decoder(dim_out * 2, dim_out, has_attn = has_attn, **block_kwargs))

            for _ in range(num_blocks_per_stage):
                enc = Encoder(dim_out, dim_out, has_attn = has_attn, **block_kwargs)
                dec = Decoder(dim_out * 2, dim_out, has_attn = has_attn, **block_kwargs)

                append(self.downs, enc)
                prepend(self.ups, dec)

            curr_dim = dim_out

        # the two middle decoders

        mid_has_attn = curr_res in attn_res

        self.mids = ModuleList([
            Decoder(curr_dim, curr_dim, has_attn = mid_has_attn, **block_kwargs),
            Decoder(curr_dim, curr_dim, has_attn = mid_has_attn, **block_kwargs),
        ])

        self.out_dim = channels

    @property
    def downsample_factor(self):
        return 2 ** self.num_downsamples

    def forward(
        self,
        x,
        time,
        self_cond = None,
        class_labels = None
    ):
        # validate the input shape

        assert x.shape[1:] == (self.channels, self.seq_len)

        # self conditioning

        if self.self_condition:
            self_cond = default(self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((self_cond, x), dim = 1)
        else:
            assert not exists(self_cond)

        # time conditioning

        time_emb = self.to_time_emb(time)

        # class conditioning

        assert xnor(exists(class_labels), self.needs_class_labels)

        if self.needs_class_labels:
            if class_labels.dtype in (torch.int, torch.long):
                class_labels = F.one_hot(class_labels, self.num_classes)

            assert class_labels.shape[-1] == self.num_classes
            class_labels = class_labels.float() * sqrt(self.num_classes)

            class_emb = self.to_class_emb(class_labels)

            time_emb = self.add_class_emb(time_emb, class_emb)

        # final mp-silu for the embedding

        emb = self.emb_activation(time_emb)

        # skip connections

        skips = []

        # input block

        x = self.input_block(x)

        skips.append(x)

        # downsamples

        for encoder in self.downs:
            x = encoder(x, emb = emb)
            skips.append(x)

        # middle

        for decoder in self.mids:
            x = decoder(x, emb = emb)

        # upsamples

        for decoder in self.ups:
            if decoder.needs_skip:
                skip = skips.pop()
                x = self.skip_mp_cat(x, skip)

            x = decoder(x, emb = emb)

        # output block

        return self.output_block(x)
# magnitude preserving feedforward
class MPFeedForward(Module):
    def __init__(
        self,
        *,
        dim,
        mult = 4,
        mp_add_t = 0.3
    ):
        super().__init__()
        dim_inner = int(dim * mult)
        # pixel norm, then two 1x1 weight-normalized convolutions around an MPSiLU
        self.net = nn.Sequential(
            PixelNorm(dim = 1),
            Conv1d(dim, dim_inner, 1),
            MPSiLU(),
            Conv1d(dim_inner, dim, 1)
        )

        # magnitude preserving residual sum
        self.mp_add = MPAdd(t = mp_add_t)

    def forward(self, x):
        res = x
        out = self.net(x)
        return self.mp_add(out, res)

# magnitude preserving transformer
class MPImageTransformer(Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_mem_kv = 4,
        ff_mult = 4,
        attn_flash = False,
        residual_mp_add_t = 0.3
    ):
        super().__init__()
        self.layers = ModuleList([])

        # stack `depth` pairs of attention and feedforward layers
        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim = dim, heads = heads, dim_head = dim_head, num_mem_kv = num_mem_kv, flash = attn_flash, mp_add_t = residual_mp_add_t),
                MPFeedForward(dim = dim, mult = ff_mult, mp_add_t = residual_mp_add_t)
            ]))

    def forward(self, x):
        # residual sums are handled inside each block via MPAdd
        for attn, ff in self.layers:
            x = attn(x)
            x = ff(x)

        return x

# example

if __name__ == '__main__':
    unet = KarrasUnet1D(
        seq_len = 64,
        dim = 192,
        dim_max = 768,
        num_classes = 1000,
    )

    # random input sequences of shape (batch, channels, seq_len)
    images = torch.randn(2, 4, 64)

    # denoise with the unet
    denoised_images = unet(
        images,
        time = torch.ones(2,),
        class_labels = torch.randint(0, 1000, (2,))
    )

    # the denoised output has the same shape as the input
    assert denoised_images.shape == images.shape

.\lucidrains\denoising-diffusion-pytorch\denoising_diffusion_pytorch\karras_unet_3d.py

"""
the magnitude-preserving unet proposed in https://arxiv.org/abs/2312.02696 by Karras et al.
"""

import math
from math import sqrt, ceil
from functools import partial
from typing import Optional, Union, Tuple

import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F

from einops import rearrange, repeat, pack, unpack

from denoising_diffusion_pytorch.attend import Attend

# helpers functions

# check whether a value exists (is not None)
def exists(x):
    return x is not None

# return the value if it exists, otherwise a default (which may be a callable)
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# logical XNOR - true when both or neither condition holds
def xnor(x, y):
    return not (x ^ y)

# append an element to the end of a list
def append(arr, el):
    arr.append(el)

# insert an element at the beginning of a list
def prepend(arr, el):
    arr.insert(0, el)

# pack a single tensor according to the given pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# unpack a packed tensor back to its original shape
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# cast a value to a tuple of the given length
def cast_tuple(t, length = 1):
    if isinstance(t, tuple):
        return t
    return ((t,) * length)

# check whether one number divides another exactly
def divisible_by(numer, denom):
    return (numer % denom) == 0

# in the paper, they use an eps of 1e-4 for the pixel norm

# l2 normalize along a dimension
def l2norm(t, dim = -1, eps = 1e-12):
    return F.normalize(t, dim = dim, eps = eps)

# mp activations
# section 2.5

# magnitude preserving SiLU activation
class MPSiLU(Module):
    def forward(self, x):
        return F.silu(x) / 0.596

# gain - layer scaling

# learned gain (layer scaling), initialized at zero
class Gain(Module):
    def __init__(self):
        super().__init__()
        self.gain = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        return x * self.gain

# magnitude preserving concat
# equation (103) - default to 0.5, which they recommended

# magnitude preserving concatenation
class MPCat(Module):
    def __init__(self, t = 0.5, dim = -1):
        super().__init__()
        self.t = t
        self.dim = dim

    def forward(self, a, b):
        dim, t = self.dim, self.t
        Na, Nb = a.shape[dim], b.shape[dim]

        C = sqrt((Na + Nb) / ((1. - t) ** 2 + t ** 2))

        a = a * (1. - t) / sqrt(Na)
        b = b * t / sqrt(Nb)

        return C * torch.cat((a, b), dim = dim)

# magnitude preserving sum
# equation (88)
# empirically, they found t=0.3 for encoder / decoder / attention residuals
# and for embedding, t=0.5

# magnitude preserving sum
class MPAdd(Module):
    def __init__(self, t):
        super().__init__()
        self.t = t

    def forward(self, x, res):
        a, b, t = x, res, self.t
        num = a * (1. - t) + b * t
        den = sqrt((1 - t) ** 2 + t ** 2)
        return num / den

# pixelnorm
# equation (30)

# pixel norm layer
class PixelNorm(Module):
    def __init__(self, dim, eps = 1e-4):
        super().__init__()
        # high epsilon for the pixel norm in the paper
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        dim = self.dim
        return l2norm(x, dim = dim, eps = self.eps) * sqrt(x.shape[dim])

# forced weight normed conv3d and linear
# algorithm 1 in paper

# normalize weights (each output unit rescaled to norm sqrt(fan_in))
def normalize_weight(weight, eps = 1e-4):
    weight, ps = pack_one(weight, 'o *')
    normed_weight = l2norm(weight, eps = eps)
    normed_weight = normed_weight * sqrt(weight.numel() / weight.shape[0])
    return unpack_one(normed_weight, ps, 'o *')

# forced weight-normalized 3d convolution
class Conv3d(Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        kernel_size,
        eps = 1e-4,
        concat_ones_to_input = False   # they use this in the input block to protect against loss of expressivity due to removal of all biases, even though they claim they observed none
    ):
        super().__init__()
        weight = torch.randn(dim_out, dim_in + int(concat_ones_to_input), kernel_size, kernel_size, kernel_size)
        self.weight = nn.Parameter(weight)

        self.eps = eps
        self.fan_in = dim_in * kernel_size ** 3
        self.concat_ones_to_input = concat_ones_to_input
    def forward(self, x):

        # during training, re-normalize the stored weight in place (no gradient)
        if self.training:
            with torch.no_grad():
                normed_weight = normalize_weight(self.weight, eps = self.eps)
                self.weight.copy_(normed_weight)

        # normalize the weight again for use, scaled by 1 / sqrt(fan_in)
        weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)

        # optionally concat a channel of ones onto the input (replaces the removed bias)
        if self.concat_ones_to_input:
            x = F.pad(x, (0, 0, 0, 0, 0, 0, 1, 0), value = 1.)

        # 'same'-padded 3d convolution with the normalized weight
        return F.conv3d(x, weight, padding = 'same')
# forced weight-normalized linear layer
class Linear(Module):
    def __init__(self, dim_in, dim_out, eps = 1e-4):
        super().__init__()
        weight = torch.randn(dim_out, dim_in)
        self.weight = nn.Parameter(weight)
        self.eps = eps
        self.fan_in = dim_in

    def forward(self, x):
        # during training, re-normalize the stored weight in place (no gradient)
        if self.training:
            with torch.no_grad():
                normed_weight = normalize_weight(self.weight, eps = self.eps)
                self.weight.copy_(normed_weight)

        # normalize the weight again for use, scaled by 1 / sqrt(fan_in)
        weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)
        return F.linear(x, weight)

# mp fourier embedding

class MPFourierEmbedding(Module):
    def __init__(self, dim):
        super().__init__()
        assert divisible_by(dim, 2)
        half_dim = dim // 2
        # fixed random frequencies, not trained
        self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = False)

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        # sqrt(2) compensates for E[sin^2] = E[cos^2] = 1/2, keeping unit magnitude
        return torch.cat((freqs.sin(), freqs.cos()), dim = -1) * sqrt(2)

# building block modules

class Encoder(Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        *,
        emb_dim = None,
        dropout = 0.1,
        mp_add_t = 0.3,
        has_attn = False,
        attn_dim_head = 64,
        attn_res_mp_add_t = 0.3,
        attn_flash = False,
        factorize_space_time_attn = False,
        downsample = False,
        downsample_config: Tuple[bool, bool, bool] = (True, True, True)
    ):
        super().__init__()
        dim_out = default(dim_out, dim)

        self.downsample = downsample
        self.downsample_config = downsample_config

        self.downsample_conv = None

        curr_dim = dim
        # when downsampling, project channels with a 1x1 conv after resizing
        if downsample:
            self.downsample_conv = Conv3d(curr_dim, dim_out, 1)
            curr_dim = dim_out

        # pixel norm
        self.pixel_norm = PixelNorm(dim = 1)

        self.to_emb = None
        # project the conditioning embedding to a per-channel scale
        if exists(emb_dim):
            self.to_emb = nn.Sequential(
                Linear(emb_dim, dim_out),
                Gain()
            )

        # first block
        self.block1 = nn.Sequential(
            MPSiLU(),
            Conv3d(curr_dim, dim_out, 3)
        )

        # second block
        self.block2 = nn.Sequential(
            MPSiLU(),
            nn.Dropout(dropout),
            Conv3d(dim_out, dim_out, 3)
        )

        # magnitude preserving residual sum
        self.res_mp_add = MPAdd(t = mp_add_t)

        self.attn = None
        self.factorized_attn = factorize_space_time_attn

        # optional attention
        if has_attn:
            attn_kwargs = dict(
                dim = dim_out,
                heads = max(ceil(dim_out / attn_dim_head), 2),
                dim_head = attn_dim_head,
                mp_add_t = attn_res_mp_add_t,
                flash = attn_flash
            )

            # optionally factorize attention into separate spatial and temporal passes
            if factorize_space_time_attn:
                self.attn = nn.ModuleList([
                    Attention(**attn_kwargs, only_space = True),
                    Attention(**attn_kwargs, only_time = True),
                ])
            else:
                self.attn = Attention(**attn_kwargs)

    def forward(
        self,
        x,
        emb = None
    ):
        if self.downsample:
            # halve only the dimensions selected by downsample_config (time, height, width)
            t, h, w = x.shape[-3:]
            resize_factors = tuple((2 if downsample else 1) for downsample in self.downsample_config)
            interpolate_shape = tuple(shape // factor for shape, factor in zip((t, h, w), resize_factors))

            x = F.interpolate(x, interpolate_shape, mode = 'trilinear')
            x = self.downsample_conv(x)

        # pixel norm
        x = self.pixel_norm(x)

        # keep a copy for the residual
        res = x.clone()

        # first block
        x = self.block1(x)

        # scale by the conditioning embedding, if given
        if exists(emb):
            scale = self.to_emb(emb) + 1
            x = x * rearrange(scale, 'b c -> b c 1 1 1')

        # second block
        x = self.block2(x)

        # magnitude preserving residual sum
        x = self.res_mp_add(x, res)

        # optional attention, either factorized (space then time) or full
        if exists(self.attn):
            if self.factorized_attn:
                attn_space, attn_time = self.attn
                x = attn_space(x)
                x = attn_time(x)

            else:
                x = self.attn(x)

        return x
# decoder block

class Decoder(Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        *,
        emb_dim = None,
        dropout = 0.1,
        mp_add_t = 0.3,
        has_attn = False,
        attn_dim_head = 64,
        attn_res_mp_add_t = 0.3,
        attn_flash = False,
        factorize_space_time_attn = False,
        upsample = False,
        upsample_config: Tuple[bool, bool, bool] = (True, True, True)
    ):
        super().__init__()
        dim_out = default(dim_out, dim)

        # whether this block upsamples, and along which of (time, height, width)
        self.upsample = upsample
        self.upsample_config = upsample_config

        # only non-upsampling decoders consume a skip connection
        self.needs_skip = not upsample

        # project the conditioning embedding to a per-channel scale
        self.to_emb = None
        if exists(emb_dim):
            self.to_emb = nn.Sequential(
                Linear(emb_dim, dim_out),
                Gain()
            )

        # first block
        self.block1 = nn.Sequential(
            MPSiLU(),
            Conv3d(dim, dim_out, 3)
        )

        # second block
        self.block2 = nn.Sequential(
            MPSiLU(),
            nn.Dropout(dropout),
            Conv3d(dim_out, dim_out, 3)
        )

        # 1x1 conv on the residual branch when channel counts differ
        self.res_conv = Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

        # magnitude preserving residual sum
        self.res_mp_add = MPAdd(t = mp_add_t)

        self.attn = None
        self.factorized_attn = factorize_space_time_attn

        # optional attention
        if has_attn:
            attn_kwargs = dict(
                dim = dim_out,
                heads = max(ceil(dim_out / attn_dim_head), 2),
                dim_head = attn_dim_head,
                mp_add_t = attn_res_mp_add_t,
                flash = attn_flash
            )

            # optionally factorize attention into separate spatial and temporal passes
            if factorize_space_time_attn:
                self.attn = nn.ModuleList([
                    Attention(**attn_kwargs, only_space = True),
                    Attention(**attn_kwargs, only_time = True),
                ])
            else:
                self.attn = Attention(**attn_kwargs)

    def forward(
        self,
        x,
        emb = None
    ):
        if self.upsample:
            # double only the dimensions selected by upsample_config (time, height, width)
            t, h, w = x.shape[-3:]
            resize_factors = tuple((2 if upsample else 1) for upsample in self.upsample_config)
            interpolate_shape = tuple(shape * factor for shape, factor in zip((t, h, w), resize_factors))

            x = F.interpolate(x, interpolate_shape, mode = 'trilinear')

        # residual branch
        res = self.res_conv(x)

        # first block
        x = self.block1(x)

        # scale by the conditioning embedding, if given
        if exists(emb):
            scale = self.to_emb(emb) + 1
            x = x * rearrange(scale, 'b c -> b c 1 1 1')

        # second block
        x = self.block2(x)

        # magnitude preserving residual sum
        x = self.res_mp_add(x, res)

        # optional attention, either factorized (space then time) or full
        if exists(self.attn):
            if self.factorized_attn:
                attn_space, attn_time = self.attn
                x = attn_space(x)
                x = attn_time(x)

            else:
                x = self.attn(x)

        return x

# attention, optionally restricted to space only or time only

class Attention(Module):
    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 64,
        num_mem_kv = 4,
        flash = False,
        mp_add_t = 0.3,
        only_space = False,
        only_time = False
    ):
        super().__init__()
        # at most one of only_space / only_time may be set
        assert (int(only_space) + int(only_time)) <= 1

        self.heads = heads
        hidden_dim = dim_head * heads

        # pixel norm applied to queries, keys and values
        self.pixel_norm = PixelNorm(dim = -1)

        # attention implementation (optionally flash)
        self.attend = Attend(flash = flash)

        # learned memory key / value tokens
        self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head))
        self.to_qkv = Conv3d(dim, hidden_dim * 3, 1)
        self.to_out = Conv3d(hidden_dim, dim, 1)

        # magnitude preserving residual sum
        self.mp_add = MPAdd(t = mp_add_t)

        # whether attention runs only over space or only over time
        self.only_space = only_space
        self.only_time = only_time
    def forward(self, x):
        # keep the residual and the original shape
        res, orig_shape = x, x.shape
        b, c, t, h, w = orig_shape

        qkv = self.to_qkv(x)

        # fold time into the batch for space-only attention, or space into the batch for time-only attention
        if self.only_space:
            qkv = rearrange(qkv, 'b c t x y -> (b t) c x y')
        elif self.only_time:
            qkv = rearrange(qkv, 'b c t x y -> (b x y) c t')

        # split into queries, keys, values
        qkv = qkv.chunk(3, dim = 1)

        q, k, v = map(lambda t: rearrange(t, 'b (h c) ... -> b h (...) c', h = self.heads), qkv)

        # expand the memory key / values across the (folded) batch and prepend them
        mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = k.shape[0]), self.mem_kv)

        k, v = map(partial(torch.cat, dim = -2), ((mk, k), (mv, v)))

        # pixel norm the queries, keys, values
        q, k, v = map(self.pixel_norm, (q, k, v))

        # attention
        out = self.attend(q, k, v)

        out = rearrange(out, 'b h n d -> b (h d) n')

        # unfold the batch dimension back out
        if self.only_space:
            out = rearrange(out, '(b t) c n -> b c (t n)', t = t)
        elif self.only_time:
            out = rearrange(out, '(b x y) c n -> b c (n x y)', x = h, y = w)

        # restore the original (b, c, t, h, w) shape
        out = out.reshape(orig_shape)

        out = self.to_out(out)

        # magnitude preserving residual sum
        return self.mp_add(out, res)
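
A shape walk-through of the factorized variant (illustrative; note that heads * dim_head must equal dim for the reshape back to the original shape to work, which the Encoder / Decoder constructors above guarantee): space-only attention folds time into the batch and attends over the h·w spatial tokens of each frame, while time-only attention folds space into the batch and attends over the t frames at each spatial location.

feats = torch.randn(1, 64, 8, 16, 16)    # (b, c, t, h, w)
attn_space = Attention(dim = 64, heads = 2, dim_head = 32, only_space = True)
attn_time  = Attention(dim = 64, heads = 2, dim_head = 32, only_time = True)
out = attn_time(attn_space(feats))
print(out.shape)                         # torch.Size([1, 64, 8, 16, 16])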
# KarrasUnet3D - the 3d version of the unet proposed by Karras et al.
# no biases, no group norms, magnitude preserving operations throughout

class KarrasUnet3D(Module):
    """
    going by figure 21, config G
    """

    def __init__(
        self,
        *,
        image_size,              # image size
        frames,                  # number of frames
        dim = 192,               # base dimension
        dim_max = 768,           # channels double on each downsample, capped at this value
        num_classes = None,      # in the paper, 1000 classes for a popular benchmark
        channels = 4,            # 4 channels in the paper, presumably including an alpha channel
        num_downsamples = 3,
        num_blocks_per_stage: Union[int, Tuple[int, ...]] = 4,     # blocks per stage
        downsample_types: Optional[Tuple[str, ...]] = None,        # 'image' or 'frame' per stage
        attn_res = (16, 8),      # resolutions at which attention is used
        fourier_dim = 16,
        attn_dim_head = 64,      # dimension per attention head
        attn_flash = False,
        mp_cat_t = 0.5,
        mp_add_emb_t = 0.5,
        attn_res_mp_add_t = 0.3,
        resnet_mp_add_t = 0.3,
        dropout = 0.1,
        self_condition = False,
        factorize_space_time_attn = False   # whether to attend to space and time separately
    ):
        ...
    @property
    def downsample_factor(self):
        return 2 ** self.num_downsamples

    def forward(
        self,
        x,
        time,
        self_cond = None,
        class_labels = None
    ):
        # validate the input shape

        assert x.shape[1:] == (self.channels, self.frames, self.image_size, self.image_size)

        # self conditioning

        if self.self_condition:
            self_cond = default(self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((self_cond, x), dim = 1)
        else:
            assert not exists(self_cond)

        # time conditioning

        time_emb = self.to_time_emb(time)

        # class conditioning

        assert xnor(exists(class_labels), self.needs_class_labels)

        if self.needs_class_labels:
            if class_labels.dtype in (torch.int, torch.long):
                class_labels = F.one_hot(class_labels, self.num_classes)

            assert class_labels.shape[-1] == self.num_classes
            class_labels = class_labels.float() * sqrt(self.num_classes)

            class_emb = self.to_class_emb(class_labels)

            time_emb = self.add_class_emb(time_emb, class_emb)

        # final mp-silu for the embedding

        emb = self.emb_activation(time_emb)

        # skip connections

        skips = []

        # input block

        x = self.input_block(x)

        skips.append(x)

        # downsamples

        for encoder in self.downs:
            x = encoder(x, emb = emb)
            skips.append(x)

        # middle

        for decoder in self.mids:
            x = decoder(x, emb = emb)

        # upsamples

        for decoder in self.ups:
            if decoder.needs_skip:
                skip = skips.pop()
                x = self.skip_mp_cat(x, skip)

            x = decoder(x, emb = emb)

        # output block

        return self.output_block(x)

# improvised mp transformer

class MPFeedForward(Module):
    def __init__(
        self,
        *,
        dim,
        mult = 4,
        mp_add_t = 0.3
    ):
        super().__init__()
        dim_inner = int(dim * mult)
        self.net = nn.Sequential(
            PixelNorm(dim = 1),
            Conv3d(dim, dim_inner, 1),
            MPSiLU(),
            Conv3d(dim_inner, dim, 1)
        )

        self.mp_add = MPAdd(t = mp_add_t)

    def forward(self, x):
        res = x
        out = self.net(x)
        return self.mp_add(out, res)

class MPImageTransformer(Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_mem_kv = 4,
        ff_mult = 4,
        attn_flash = False,
        residual_mp_add_t = 0.3
    ):
        super().__init__()
        self.layers = ModuleList([])

        # stack `depth` pairs of attention and feedforward layers
        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim = dim, heads = heads, dim_head = dim_head, num_mem_kv = num_mem_kv, flash = attn_flash, mp_add_t = residual_mp_add_t),
                MPFeedForward(dim = dim, mult = ff_mult, mp_add_t = residual_mp_add_t)
            ]))

    def forward(self, x):

        # residual sums are handled inside each block via MPAdd
        for attn, ff in self.layers:
            x = attn(x)
            x = ff(x)

        return x
# example

if __name__ == '__main__':

    unet = KarrasUnet3D(
        frames = 32,                      # number of video frames
        image_size = 64,                  # image size
        dim = 8,                          # base dimension
        dim_max = 768,                    # maximum dimension
        num_downsamples = 6,              # number of downsamples
        num_blocks_per_stage = (4, 3, 2, 2, 2, 2),   # blocks per stage
        downsample_types = (
            'image',                      # downsample along height / width
            'frame',                      # downsample along time
            'image',
            'frame',
            'image',
            'frame',
        ),
        attn_dim_head = 8,                # dimension per attention head
        num_classes = 1000,               # number of classes
        factorize_space_time_attn = True  # attend to space and time separately
    )

    # random input video of shape (batch, channels, frames, height, width)
    video = torch.randn(2, 4, 32, 64, 64)

    # denoise the video with the unet
    denoised_video = unet(
        video,
        time = torch.ones(2,),
        class_labels = torch.randint(0, 1000, (2,))
    )

.\lucidrains\denoising-diffusion-pytorch\denoising_diffusion_pytorch\learned_gaussian_diffusion.py

import torch
from collections import namedtuple
from functools import partial
from math import pi, sqrt, log as ln
from inspect import isfunction
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange

from denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion, extract, unnormalize_to_zero_to_one, identity

# constants

NAT = 1. / ln(2)

# named tuple for the model's outputs

ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start', 'pred_variance'])

# helper functions

# check whether a value exists (is not None)
def exists(x):
    return x is not None

# return the value if it exists, otherwise a default (which may be a function)
def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

# tensor helpers

# clamped log
def log(t, eps = 1e-15):
    return torch.log(t.clamp(min = eps))

# mean over all non-batch dimensions
def meanflat(x):
    return x.mean(dim = tuple(range(1, len(x.shape))))

# KL divergence between two gaussians parameterized by mean and log-variance
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    KL divergence between normal distributions parameterized by mean and log-variance.
    """
    return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))
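
Sanity check (illustrative): the KL between two identical Gaussians is zero, since logvar2 - logvar1 = 0, exp(logvar1 - logvar2) = 1, and the mean term vanishes, giving 0.5 * (-1 + 0 + 1 + 0) = 0.

m, lv = torch.zeros(4), torch.zeros(4)
print(normal_kl(m, lv, m, lv))     # tensor([0., 0., 0., 0.])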

# tanh-based approximation to the standard normal CDF
def approx_standard_normal_cdf(x):
    return 0.5 * (1.0 + torch.tanh(sqrt(2.0 / pi) * (x + 0.044715 * (x ** 3))))

# log-likelihood of data discretized to 256 levels under a gaussian
def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
    assert x.shape == means.shape == log_scales.shape

    centered_x = x - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1. / 255.)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1. / 255.)
    cdf_min = approx_standard_normal_cdf(min_in)
    log_cdf_plus = log(cdf_plus)
    log_one_minus_cdf_min = log(1. - cdf_min)
    cdf_delta = cdf_plus - cdf_min

    log_probs = torch.where(x < -thres,
        log_cdf_plus,
        torch.where(x > thres,
            log_one_minus_cdf_min,
            log(cdf_delta)))

    return log_probs

# https://arxiv.org/abs/2102.09672

# i thought the results were questionable, if one were to focus only on FID
# but may as well get this in here for others to try, as GLIDE is using it (and DALL-E2 first stage of cascade)
# gaussian diffusion for learned variance + hybrid eps simple + vb loss

# gaussian diffusion with learned variance, subclassing GaussianDiffusion
class LearnedGaussianDiffusion(GaussianDiffusion):
    def __init__(
        self,
        model,
        vb_loss_weight = 0.001,  # lambda was 0.001 in the paper
        *args,
        **kwargs
    ):
        super().__init__(model, *args, **kwargs)
        assert model.out_dim == (model.channels * 2), 'dimension out of unet must be twice the number of channels for learned variance - you can also set the `learned_variance` keyword argument on the Unet to be `True`'
        assert not model.self_condition, 'not supported yet'

        self.vb_loss_weight = vb_loss_weight

    # model predictions (noise, x_start, and the learned variance)
    def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):
        model_output = self.model(x, t)
        model_output, pred_variance = model_output.chunk(2, dim = 1)

        maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity

        if self.objective == 'pred_noise':
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, model_output)

        elif self.objective == 'pred_x0':
            pred_noise = self.predict_noise_from_start(x, t, model_output)
            x_start = model_output

        x_start = maybe_clip(x_start)

        return ModelPrediction(pred_noise, x_start, pred_variance)
    # compute the predicted mean, variance and log-variance for input x at time t
    def p_mean_variance(self, *, x, t, clip_denoised, model_output = None, **kwargs):
        model_output = default(model_output, lambda: self.model(x, t))
        # split the model output into predicted noise and the unnormalized variance interpolation fraction
        pred_noise, var_interp_frac_unnormalized = model_output.chunk(2, dim = 1)

        # the learned variance interpolates between the clipped posterior log-variance (min) and log-beta (max)
        min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
        max_log = extract(torch.log(self.betas), t, x.shape)
        var_interp_frac = unnormalize_to_zero_to_one(var_interp_frac_unnormalized)

        model_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
        model_variance = model_log_variance.exp()

        # predict x_start from the noise prediction
        x_start = self.predict_start_from_noise(x, t, pred_noise)

        # optionally clip the denoised prediction to [-1, 1]
        if clip_denoised:
            x_start.clamp_(-1., 1.)

        model_mean, _, _ = self.q_posterior(x_start, x, t)

        return model_mean, model_variance, model_log_variance, x_start

    # hybrid loss: simple mse on the noise plus the variational bound for the learned variance
    def p_losses(self, x_start, t, noise = None, clip_denoised = False):
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_t = self.q_sample(x_start = x_start, t = t, noise = noise)

        # model output
        model_output = self.model(x_t, t)

        # KL for the learned (interpolated) variance
        true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_t, t = t)
        model_mean, _, model_log_variance, _ = self.p_mean_variance(x = x_t, t = t, clip_denoised = clip_denoised, model_output = model_output)

        # detach the predicted mean so the vb term only trains the variance, for stability
        detached_model_mean = model_mean.detach()

        kl = normal_kl(true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance)
        kl = meanflat(kl) * NAT

        # decoder negative log-likelihood
        decoder_nll = -discretized_gaussian_log_likelihood(x_start, means = detached_model_mean, log_scales = 0.5 * model_log_variance)
        decoder_nll = meanflat(decoder_nll) * NAT

        # at the first timestep return the decoder NLL, otherwise return the KL
        vb_losses = torch.where(t == 0, decoder_nll, kl)

        # simple loss - predicting the noise
        pred_noise, _ = model_output.chunk(2, dim = 1)
        simple_losses = F.mse_loss(pred_noise, noise)

        return simple_losses + vb_losses.mean() * self.vb_loss_weight
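
A minimal usage sketch (assumptions: the `Unet` from this repo, whose `learned_variance = True` flag doubles the output channels as the assert in `__init__` requires, and the `pred_noise` objective this subclass implicitly assumes; the hyperparameters here are arbitrary):

from denoising_diffusion_pytorch import Unet

model = Unet(dim = 64, channels = 3, learned_variance = True)
diffusion = LearnedGaussianDiffusion(model, image_size = 32, timesteps = 1000, objective = 'pred_noise')

images = torch.rand(4, 3, 32, 32)   # images normalized to [0, 1]
loss = diffusion(images)            # hybrid simple + vb loss
loss.backward()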

.\lucidrains\denoising-diffusion-pytorch\denoising_diffusion_pytorch\simple_diffusion.py

import math
from functools import partial, wraps

import torch
from torch import sqrt
from torch import nn, einsum
import torch.nn.functional as F
from torch.special import expm1
from torch.cuda.amp import autocast

from tqdm import tqdm
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange

# helpers

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return the input unchanged
def identity(t):
    return t

# check whether a callable is a lambda
def is_lambda(f):
    return callable(f) and f.__name__ == "<lambda>"

# return the value if it exists, otherwise a default (lambdas are called)
def default(val, d):
    if exists(val):
        return val
    return d() if is_lambda(d) else d

# cast a value to a tuple of the given length
def cast_tuple(t, l = 1):
    return ((t,) * l) if not isinstance(t, tuple) else t

# append trailing singleton dimensions to a tensor
def append_dims(t, dims):
    shape = t.shape
    return t.reshape(*shape, *((1,) * dims))

# l2 normalize along the last dimension
def l2norm(t):
    return F.normalize(t, dim = -1)

# u-vit related functions and modules

# upsample module (1x1 conv + pixel shuffle)
class Upsample(nn.Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        factor = 2
    ):
        super().__init__()
        self.factor = factor
        self.factor_squared = factor ** 2

        dim_out = default(dim_out, dim)
        conv = nn.Conv2d(dim, dim_out * self.factor_squared, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU(),
            nn.PixelShuffle(factor)
        )

        self.init_conv_(conv)

    # initialize the conv so that pixel shuffle starts out as nearest-neighbor upsampling
    def init_conv_(self, conv):
        o, i, h, w = conv.weight.shape
        conv_weight = torch.empty(o // self.factor_squared, i, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o r) ...', r = self.factor_squared)

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        return self.net(x)
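
The point of `init_conv_` is that a single kaiming-initialized kernel is repeated across all factor² pixel-shuffle groups, so right after initialization the whole module reduces to a 1×1 conv + SiLU followed by nearest-neighbor duplication (essentially the ICNR-style trick against checkerboard artifacts). An illustrative check:

up = Upsample(8)                                          # factor 2, so 4 repeated groups
x = torch.randn(1, 8, 4, 4)
pre = up.net[1](up.net[0](x))                             # conv + silu, before pixel shuffle
nearest = F.interpolate(pre[:, ::4], scale_factor = 2.)   # first copy of each group, duplicated
print(torch.allclose(up(x), nearest))                     # True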

# downsample module (space-to-depth + 1x1 conv)
def Downsample(
    dim,
    dim_out = None,
    factor = 2
):
    return nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = factor, p2 = factor),
        nn.Conv2d(dim * (factor ** 2), default(dim_out, dim), 1)
    )
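
Illustrative usage: the Rearrange packs each 2×2 spatial patch into the channel dimension (space-to-depth), and the 1×1 conv mixes those channels back down, so the resolution is halved without discarding any pixels before the projection.

down = Downsample(16)
x = torch.randn(1, 16, 32, 32)
print(down(x).shape)                # torch.Size([1, 16, 16, 16])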

# RMS normalization
class RMSNorm(nn.Module):
    def __init__(self, dim, scale = True, normalize_dim = 2):
        super().__init__()
        self.g = nn.Parameter(torch.ones(dim)) if scale else 1

        self.scale = scale
        self.normalize_dim = normalize_dim

    def forward(self, x):
        normalize_dim = self.normalize_dim
        scale = append_dims(self.g, x.ndim - self.normalize_dim - 1) if self.scale else 1
        return F.normalize(x, dim = normalize_dim) * scale * (x.shape[normalize_dim] ** 0.5)

# learned sinusoidal positional embedding
class LearnedSinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        fouriered = torch.cat((x, fouriered), dim = -1)
        return fouriered

# basic conv block
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x

class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        # if a time embedding dimension is given, project it to a per-channel scale and shift
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if exists(time_emb_dim) else None

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        # 1x1 conv on the residual branch when channel counts differ
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):

        scale_shift = None
        # derive the scale and shift from the time embedding, if given
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x, scale_shift = scale_shift)

        h = self.block2(h)

        # residual connection
        return h + self.res_conv(x)
class LinearAttention(nn.Module):
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        # normalization
        self.norm = RMSNorm(dim, normalize_dim = 1)
        # project the input to queries, keys, values
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)

        # output projection
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, 1),
            RMSNorm(dim, normalize_dim = 1)
        )
    def forward(self, x):
        residual = x

        b, c, h, w = x.shape

        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)

        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)

        q = q * self.scale

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)

        return self.to_out(out) + residual
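
Why this is "linear": the softmax over the keys' sequence dimension and the queries' feature dimension lets the first einsum build a small d×e context matrix from k and v, which the second einsum then reads out with q, so the cost scales as O(n · d · e) in the number of tokens n rather than the O(n²) of full attention. Illustrative:

attn = LinearAttention(dim = 32, heads = 4, dim_head = 8)
x = torch.randn(2, 32, 16, 16)      # n = 256 spatial tokens
print(attn(x).shape)                # torch.Size([2, 32, 16, 16])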

class Attention(nn.Module):
    def __init__(self, dim, heads = 4, dim_head = 32, scale = 8, dropout = 0.):
        super().__init__()
        self.scale = scale
        self.heads = heads
        hidden_dim = dim_head * heads

        # normalization
        self.norm = RMSNorm(dim)

        self.attn_dropout = nn.Dropout(dropout)
        # project the input to queries, keys, values
        self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias = False)

        # learned per-dimension scales for the l2-normalized queries and keys
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        # output projection
        self.to_out = nn.Linear(hidden_dim, dim, bias = False)
    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        q, k = map(l2norm, (q, k))

        q = q * self.q_scale
        k = k * self.k_scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class FeedForward(nn.Module):
    # 初始化前馈神经网络模块
    def __init__(
        self,
        dim,
        cond_dim,
        mult = 4,
        dropout = 0.
    ):
        super().__init__()
        # 归一化层
        self.norm = RMSNorm(dim, scale = False)
        dim_hidden = dim * mult

        # 缩放和偏移层
        self.to_scale_shift = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_dim, dim_hidden * 2),
            Rearrange('b d -> b 1 d')
        )

        to_scale_shift_linear = self.to_scale_shift[-2]
        nn.init.zeros_(to_scale_shift_linear.weight)
        nn.init.zeros_(to_scale_shift_linear.bias)

        # 输入投影层
        self.proj_in = nn.Sequential(
            nn.Linear(dim, dim_hidden, bias = False),
            nn.SiLU()
        )

        # output projection
        self.proj_out = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(dim_hidden, dim, bias = False)
        )

    # forward pass
    def forward(self, x, t):
        x = self.norm(x)
        x = self.proj_in(x)

        scale, shift = self.to_scale_shift(t).chunk(2, dim = -1)
        x = x * (scale + 1) + shift

        return self.proj_out(x)
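
A usage sketch (illustrative): the feedforward is modulated by a conditioning vector, here the time embedding. Because the scale/shift projection is zero-initialized, the modulation starts out as the identity. Note the residual is added by the caller (the Transformer below).

import torch

ff = FeedForward(dim = 256, cond_dim = 1024)
tokens = torch.randn(2, 128, 256)   # (batch, sequence length, dim)
t_emb = torch.randn(2, 1024)        # conditioning, e.g. the time embedding
out = ff(tokens, t_emb)             # (2, 128, 256)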

# vit

class Transformer(nn.Module):
    # initialize the transformer
    def __init__(
        self,
        dim,
        time_cond_dim,
        depth,
        dim_head = 32,
        heads = 4,
        ff_mult = 4,
        dropout = 0.,
    ):
        super().__init__()

        self.layers = nn.ModuleList([])
        # stack the transformer layers
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
                FeedForward(dim = dim, mult = ff_mult, cond_dim = time_cond_dim, dropout = dropout)
            ]))

    # forward pass
    def forward(self, x, t):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x, t) + x

        return x
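
A usage sketch (illustrative): each layer applies pre-norm attention and a time-conditioned feedforward, both with residual connections.

import torch

vit = Transformer(dim = 512, time_cond_dim = 256, depth = 6)
x = torch.randn(2, 64, 512)   # tokens from a flattened feature map
t = torch.randn(2, 256)       # time conditioning
out = vit(x, t)               # (2, 64, 512)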
# the UViT module: a unet whose bottleneck is a vision transformer
class UViT(nn.Module):
    def __init__(
        self,
        dim,                                   # base feature dimension
        init_dim = None,                       # dimension after the initial conv, defaults to dim
        out_dim = None,                        # output channels, defaults to the input channels
        dim_mults = (1, 2, 4, 8),              # per-stage multipliers on dim
        downsample_factor = 2,                 # downsample factor per stage
        channels = 3,                          # input image channels
        vit_depth = 6,                         # depth of the bottleneck vision transformer
        vit_dropout = 0.2,                     # dropout inside the vit
        attn_dim_head = 32,                    # attention head dimension
        attn_heads = 4,                        # number of attention heads
        ff_mult = 4,                           # feedforward expansion factor
        resnet_block_groups = 8,               # number of groups in the resnet blocks' group norm
        learned_sinusoidal_dim = 16,           # dimension of the learned sinusoidal time embedding
        init_img_transform: callable = None,   # optional initial image transform
        final_img_itransform: callable = None, # optional final inverse image transform
        patch_size = 1,                        # patch size; 1 means no patching
        dual_patchnorm = False                 # whether to use dual patchnorm when patching
    ):
        super().__init__()

        # for an initial dwt transform (or any other transform a researcher may want to try)

        if exists(init_img_transform) and exists(final_img_itransform):
            # sanity check on a mock (1, 1, 32, 32) tensor
            init_shape = torch.Size((1, 1, 32, 32))
            mock_tensor = torch.randn(init_shape)
            # the transform followed by its inverse must preserve the shape
            assert final_img_itransform(init_img_transform(mock_tensor)).shape == init_shape

        # set the initial image transform and the final inverse transform

        self.init_img_transform = default(init_img_transform, identity)
        self.final_img_itransform = default(final_img_itransform, identity)

        input_channels = channels

        init_dim = default(init_dim, dim)
        # initial convolution: input_channels -> init_dim, 7x7 kernel, padding 3
        self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)

        # whether to do initial patching, as an alternative to dwt

        self.unpatchify = identity

        input_channels = channels * (patch_size ** 2)
        needs_patch = patch_size > 1

        if needs_patch:
            if not dual_patchnorm:
                # plain strided-conv patching
                self.init_conv = nn.Conv2d(channels, init_dim, patch_size, stride = patch_size)
            else:
                # dual patchnorm: layernorm before and after the patch projection
                self.init_conv = nn.Sequential(
                    Rearrange('b c (h p1) (w p2) -> b h w (c p1 p2)', p1 = patch_size, p2 = patch_size),
                    nn.LayerNorm(input_channels),
                    nn.Linear(input_channels, init_dim),
                    nn.LayerNorm(init_dim),
                    Rearrange('b h w c -> b c h w')
                )

            # transposed conv that folds the patches back into an image
            self.unpatchify = nn.ConvTranspose2d(input_channels, channels, patch_size, stride = patch_size)

        # determine the dimensions at each stage
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # partially applied resnet block
        resnet_block = partial(ResnetBlock, groups = resnet_block_groups)

        # time embedding
        time_dim = dim * 4

        sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim)
        fourier_dim = learned_sinusoidal_dim + 1

        # time mlp
        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # one downsample factor per stage
        downsample_factor = cast_tuple(downsample_factor, len(dim_mults))
        assert len(downsample_factor) == len(dim_mults)

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, ((dim_in, dim_out), factor) in enumerate(zip(in_out, downsample_factor)):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                resnet_block(dim_in, dim_in, time_emb_dim = time_dim),
                resnet_block(dim_in, dim_in, time_emb_dim = time_dim),
                LinearAttention(dim_in),
                Downsample(dim_in, dim_out, factor = factor)
            ]))

        mid_dim = dims[-1]

        # vit at the bottleneck
        self.vit = Transformer(
            dim = mid_dim,
            time_cond_dim = time_dim,
            depth = vit_depth,
            dim_head = attn_dim_head,
            heads = attn_heads,
            ff_mult = ff_mult,
            dropout = vit_dropout
        )

        for ind, ((dim_in, dim_out), factor) in enumerate(zip(reversed(in_out), reversed(downsample_factor))):
            is_last = ind == (len(in_out) - 1)

            self.ups.append(nn.ModuleList([
                Upsample(dim_out, dim_in, factor = factor),
                resnet_block(dim_in * 2, dim_in, time_emb_dim = time_dim),
                resnet_block(dim_in * 2, dim_in, time_emb_dim = time_dim),
                LinearAttention(dim_in),
            ]))

        default_out_dim = input_channels
        self.out_dim = default(out_dim, default_out_dim)

        # final resnet block and output conv
        self.final_res_block = resnet_block(dim * 2, dim, time_emb_dim = time_dim)
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)

    # forward pass: takes the input image x and the time conditioning
    def forward(self, x, time):
        # apply the initial image transform
        x = self.init_img_transform(x)

        # initial convolution
        x = self.init_conv(x)
        # keep a copy for the final residual concat
        r = x.clone()

        # time conditioning through the mlp
        t = self.time_mlp(time)

        # skip connections
        h = []

        # downsampling stages
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        # flatten the spatial dims into a token sequence for the transformer
        x = rearrange(x, 'b c h w -> b h w c')
        x, ps = pack([x], 'b * c')

        # bottleneck vision transformer
        x = self.vit(x, t)

        # unpack back into a 2d feature map
        x, = unpack(x, ps, 'b * c')
        x = rearrange(x, 'b h w c -> b c h w')

        # upsampling stages
        for upsample, block1, block2, attn in self.ups:
            x = upsample(x)

            # concat the skip connection, then process
            x = torch.cat((x, h.pop()), dim = 1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim = 1)
            x = block2(x, t)
            x = attn(x)

        # concat the initial feature map
        x = torch.cat((x, r), dim = 1)

        # final resnet block and conv
        x = self.final_res_block(x, t)
        x = self.final_conv(x)

        # fold patches back into an image, if patching was used
        x = self.unpatchify(x)
        return self.final_img_itransform(x)
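
A usage sketch for the full UViT (illustrative; it assumes the ResnetBlock, Downsample, Upsample and LearnedSinusoidalPosEmb defined earlier in this file). With the default four stages of factor-2 downsampling, the image side length must be divisible by 16.

import torch

unet = UViT(dim = 64)
imgs = torch.randn(2, 3, 64, 64)
times = torch.rand(2)            # the diffusion wrapper below passes log snr values here
pred = unet(imgs, times)         # (2, 3, 64, 64)
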
# normalization functions

# normalize images from [0, 1] to [-1, 1]
def normalize_to_neg_one_to_one(img):
    return img * 2 - 1

# unnormalize from [-1, 1] back to [0, 1]
def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5

# diffusion helpers

# right-pad t with singleton dims until it matches the ndim of x
def right_pad_dims_to(x, t):
    padding_dims = x.ndim - t.ndim
    if padding_dims <= 0:
        return t
    return t.view(*t.shape, *((1,) * padding_dims))
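
A quick illustration of the broadcasting helper:

import torch

x = torch.randn(4, 3, 32, 32)
t = torch.rand(4)
right_pad_dims_to(x, t).shape   # torch.Size([4, 1, 1, 1]), ready to broadcast against x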

# logsnr schedules and shifting / interpolating decorators
# only cosine for now

# log that clamps its argument to at least eps, avoiding log(0)
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# cosine log snr schedule
def logsnr_schedule_cosine(t, logsnr_min = -15, logsnr_max = 15):
    t_min = math.atan(math.exp(-0.5 * logsnr_max))
    t_max = math.atan(math.exp(-0.5 * logsnr_min))
    return -2 * log(torch.tan(t_min + t * (t_max - t_min)))
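
With the default clamps, the schedule decreases monotonically from logsnr_max to logsnr_min as t goes from 0 to 1, for example:

import torch

t = torch.tensor([0.0, 0.5, 1.0])
logsnr_schedule_cosine(t)   # approximately tensor([ 15.,   0., -15.])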

# shift a log snr schedule by a constant determined by the noise / image resolution ratio
def logsnr_schedule_shifted(fn, image_d, noise_d):
    shift = 2 * math.log(noise_d / image_d)
    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal shift
        return fn(*args, **kwargs) + shift
    return inner

# interpolate between two shifted log snr schedules across t
def logsnr_schedule_interpolated(fn, image_d, noise_d_low, noise_d_high):
    logsnr_low_fn = logsnr_schedule_shifted(fn, image_d, noise_d_low)
    logsnr_high_fn = logsnr_schedule_shifted(fn, image_d, noise_d_high)

    @wraps(fn)
    def inner(t, *args, **kwargs):
        nonlocal logsnr_low_fn
        nonlocal logsnr_high_fn
        return t * logsnr_low_fn(t, *args, **kwargs) + (1 - t) * logsnr_high_fn(t, *args, **kwargs)

    return inner
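
A usage sketch (the resolutions are illustrative): shifting adapts the reference schedule to the training resolution, while interpolating blends a low- and a high-resolution shifted schedule, weighted by t.

import torch

shifted = logsnr_schedule_shifted(logsnr_schedule_cosine, image_d = 256, noise_d = 64)
interp = logsnr_schedule_interpolated(logsnr_schedule_cosine, image_d = 256, noise_d_low = 32, noise_d_high = 256)

t = torch.linspace(0., 1., 5)
shifted(t)   # the cosine schedule plus the constant 2 * log(64 / 256)
interp(t)    # t-weighted blend of the two shifted schedules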

# main gaussian diffusion class

class GaussianDiffusion(nn.Module):
    def __init__(
        self,
        model: UViT,
        *,
        image_size,
        channels = 3,
        pred_objective = 'v',
        noise_schedule = logsnr_schedule_cosine,
        noise_d = None,
        noise_d_low = None,
        noise_d_high = None,
        num_sample_steps = 500,
        clip_sample_denoised = True,
        min_snr_loss_weight = True,
        min_snr_gamma = 5
    ):
        super().__init__()
        assert pred_objective in {'v', 'eps'}, 'whether to predict v-space (progressive distillation paper) or noise'

        self.model = model

        # image dimensions

        self.channels = channels
        self.image_size = image_size

        # training objective

        self.pred_objective = pred_objective

        # noise schedule

        assert not all([*map(exists, (noise_d, noise_d_low, noise_d_high))]), 'you must either set noise_d for shifted schedule, or noise_d_low and noise_d_high for shifted and interpolated schedule'

        # determine shifting or interpolated schedules

        self.log_snr = noise_schedule

        if exists(noise_d):
            self.log_snr = logsnr_schedule_shifted(self.log_snr, image_size, noise_d)

        if exists(noise_d_low) or exists(noise_d_high):
            assert exists(noise_d_low) and exists(noise_d_high), 'both noise_d_low and noise_d_high must be set'

            self.log_snr = logsnr_schedule_interpolated(self.log_snr, image_size, noise_d_low, noise_d_high)

        # sampling

        self.num_sample_steps = num_sample_steps
        self.clip_sample_denoised = clip_sample_denoised

        # loss weight

        self.min_snr_loss_weight = min_snr_loss_weight
        self.min_snr_gamma = min_snr_gamma

    @property
    def device(self):
        return next(self.model.parameters()).device

    # compute the posterior mean and variance for a single denoising step
    def p_mean_variance(self, x, time, time_next):
        
        # log snr at the current and the next timestep
        log_snr = self.log_snr(time)
        log_snr_next = self.log_snr(time_next)
        # c = 1 - exp(log_snr - log_snr_next)
        c = -expm1(log_snr - log_snr_next)

        # derive alpha and sigma from the log snr
        squared_alpha, squared_alpha_next = log_snr.sigmoid(), log_snr_next.sigmoid()
        squared_sigma, squared_sigma_next = (-log_snr).sigmoid(), (-log_snr_next).sigmoid()
        alpha, sigma, alpha_next = map(sqrt, (squared_alpha, squared_sigma, squared_alpha_next))

        # broadcast log_snr over the batch and condition the model on it
        batch_log_snr = repeat(log_snr, ' -> b', b = x.shape[0])
        pred = self.model(x, batch_log_snr)

        # recover x_start according to the prediction objective
        if self.pred_objective == 'v':
            x_start = alpha * x - sigma * pred
        elif self.pred_objective == 'eps':
            x_start = (x - sigma * pred) / alpha

        # clip x_start to [-1, 1]
        x_start.clamp_(-1., 1.)

        # posterior mean and variance
        model_mean = alpha_next * (x * (1 - c) / alpha + c * x_start)
        posterior_variance = squared_sigma_next * c

        return model_mean, posterior_variance
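
    # a small sanity check on the algebra above (illustrative, not part of the
    # original file): since sigmoid(x) + sigmoid(-x) = 1 for any x, the alphas
    # and sigmas derived from a log snr always satisfy alpha^2 + sigma^2 = 1,
    # i.e. the forward process is variance preserving, e.g.
    #   log_snr = torch.tensor(2.3)
    #   assert torch.allclose(log_snr.sigmoid() + (-log_snr).sigmoid(), torch.tensor(1.))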

    # sampling related functions

    @torch.no_grad()
    def p_sample(self, x, time, time_next):
        batch, *_, device = *x.shape, x.device

        # compute the model mean and variance
        model_mean, model_variance = self.p_mean_variance(x = x, time = time, time_next = time_next)

        # at the final step, return the mean without adding noise
        if time_next == 0:
            return model_mean

        # otherwise add gaussian noise scaled by the posterior std
        noise = torch.randn_like(x)
        return model_mean + sqrt(model_variance) * noise

    @torch.no_grad()
    def p_sample_loop(self, shape):
        batch = shape[0]

        # start from pure gaussian noise
        img = torch.randn(shape, device = self.device)
        steps = torch.linspace(1., 0., self.num_sample_steps + 1, device = self.device)

        # iterate over the sampling timesteps
        for i in tqdm(range(self.num_sample_steps), desc = 'sampling loop time step', total = self.num_sample_steps):
            times = steps[i]
            times_next = steps[i + 1]
            img = self.p_sample(img, times, times_next)

        # clamp to [-1, 1] and unnormalize back to [0, 1]
        img.clamp_(-1., 1.)
        img = unnormalize_to_zero_to_one(img)
        return img

    @torch.no_grad()
    def sample(self, batch_size = 16):
        return self.p_sample_loop((batch_size, self.channels, self.image_size, self.image_size))

    # training related functions - noise prediction

    @autocast(enabled = False)
    def q_sample(self, x_start, times, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        # compute alpha and sigma, then produce the noised image
        log_snr = self.log_snr(times)
        log_snr_padded = right_pad_dims_to(x_start, log_snr)
        alpha, sigma = sqrt(log_snr_padded.sigmoid()), sqrt((-log_snr_padded).sigmoid())
        x_noised =  x_start * alpha + noise * sigma

        return x_noised, log_snr
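
    # a usage sketch (illustrative; `diffusion` stands for a GaussianDiffusion
    # instance such as the one built in the example at the end of this file):
    #   x_start = torch.randn(2, 3, 64, 64).clamp_(-1., 1.)  # images already in [-1, 1]
    #   times = torch.rand(2)                                # one time in [0, 1] per image
    #   x_noised, log_snr = diffusion.q_sample(x_start = x_start, times = times)
    #   # x_noised = alpha * x_start + sigma * noise, log_snr has shape (2,)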

    # compute the training loss
    def p_losses(self, x_start, times, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        # noise the image and run the model on it
        x, log_snr = self.q_sample(x_start = x_start, times = times, noise = noise)
        model_out = self.model(x, log_snr)

        # build the regression target according to the prediction objective
        if self.pred_objective == 'v':
            padded_log_snr = right_pad_dims_to(x, log_snr)
            alpha, sigma = padded_log_snr.sigmoid().sqrt(), (-padded_log_snr).sigmoid().sqrt()
            target = alpha * noise - sigma * x_start
        elif self.pred_objective == 'eps':
            target = noise

        # mean squared error, reduced to one loss per sample
        loss = F.mse_loss(model_out, target, reduction = 'none')
        loss = reduce(loss, 'b ... -> b', 'mean')

        snr = log_snr.exp()

        maybe_clip_snr = snr.clone()
        if self.min_snr_loss_weight:
            maybe_clip_snr.clamp_(max = self.min_snr_gamma)

        # min-snr loss weighting, depending on the prediction objective
        if self.pred_objective == 'v':
            loss_weight = maybe_clip_snr / (snr + 1)
        elif self.pred_objective == 'eps':
            loss_weight = maybe_clip_snr / snr

        return (loss * loss_weight).mean()

    # forward pass for training: takes a batch of images in [0, 1]
    def forward(self, img, *args, **kwargs):
        # unpack the image shape and device
        b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
        # the image height and width must equal the configured image size
        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'

        # normalize images to [-1, 1]
        img = normalize_to_neg_one_to_one(img)
        # sample one uniform random time in [0, 1] per image
        times = torch.zeros((img.shape[0],), device = self.device).float().uniform_(0, 1)

        # compute and return the loss
        return self.p_losses(img, times, *args, **kwargs)
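
An end-to-end usage sketch (illustrative, mirroring how the other diffusion wrappers in this repository are used): train on images scaled to [0, 1], then sample.

import torch

model = UViT(dim = 64)

diffusion = GaussianDiffusion(
    model,
    image_size = 64,
    num_sample_steps = 250
)

imgs = torch.rand(4, 3, 64, 64)   # training images in [0, 1]
loss = diffusion(imgs)
loss.backward()

sampled = diffusion.sample(batch_size = 4)   # (4, 3, 64, 64), values in [0, 1]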