Lucidrains-系列项目源码解析-六十七-

67 阅读14分钟

Lucidrains 系列项目源码解析(六十七)

.\lucidrains\nuwa-pytorch\nuwa_pytorch\optimizer.py

# import torch
import torch
# import the AdamW and Adam optimizers from torch.optim
# (restores the import line dropped from this excerpt)
from torch.optim import Adam, AdamW

# 分离可进行权重衰减的参数
# partition parameters into weight-decayable and non-decayable sets
def separate_weight_decayable_params(params):
    """Split `params` into (wd_params, no_wd_params).

    Parameters with fewer than 2 dimensions (biases, norms, embeddings-as-rows)
    are excluded from weight decay; everything else receives it.
    """
    no_wd_params = {param for param in params if param.ndim < 2}
    wd_params = set(params) - no_wd_params
    return wd_params, no_wd_params

# 获取优化器
# build an optimizer for the given parameters
def get_optimizer(
    params,
    lr = 3e-4,
    wd = 1e-1,
    filter_by_requires_grad = False
):
    """Return an Adam (wd == 0) or AdamW optimizer over `params`.

    With weight decay enabled, parameters are split into two groups so that
    sub-2-dim parameters (biases, norms) are exempted from decay.
    """
    if filter_by_requires_grad:
        # keep only trainable parameters
        params = [p for p in params if p.requires_grad]

    # no weight decay requested -> plain Adam
    if wd == 0:
        return Adam(list(params), lr = lr)

    params = set(params)
    wd_params, no_wd_params = separate_weight_decayable_params(params)

    # two groups: decayed weights, and decay-exempt low-dim parameters
    param_groups = [
        dict(params = list(wd_params)),
        dict(params = list(no_wd_params), weight_decay = 0),
    ]

    return AdamW(param_groups, lr = lr, weight_decay = wd)

.\lucidrains\nuwa-pytorch\nuwa_pytorch\reversible.py

# 导入 torch 库
import torch
# 导入 torch 中的神经网络模块
import torch.nn as nn
# 从 operator 模块中导入 itemgetter 函数
from operator import itemgetter
# 从 torch.autograd.function 模块中导入 Function 类
from torch.autograd.function import Function
# 从 torch.utils.checkpoint 模块中导入 get_device_states 和 set_device_states 函数
# NOTE(review): the actual import statement was lost from this excerpt; it should read:
# from torch.utils.checkpoint import get_device_states, set_device_states

# 用于将参数路由到可逆层函数中的函数
# route keyword arguments into the reversible layer functions
def route_args(router, args, depth):
    """Distribute kwargs to the (f, g) branches of each of `depth` layers.

    `router[key]` holds one (to_f, to_g) boolean pair per layer; a kwarg is
    copied into a layer's f_args / g_args dict wherever its route is truthy.
    Returns a list of (f_args, g_args) tuples, one per layer.
    """
    routed = [(dict(), dict()) for _ in range(depth)]

    for key, value in args.items():
        if key not in router:
            continue
        for idx, ((f_args, g_args), (f_route, g_route)) in enumerate(zip(routed, router[key])):
            if f_route:
                f_args = {**f_args, key: value}
            if g_route:
                g_args = {**g_args, key: value}
            routed[idx] = (f_args, g_args)

    return routed

# 参考示例 https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html 中的保存和设置随机数生成器
# following the example for saving and setting rng here
# https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
class Deterministic(nn.Module):
    """Wrap a module so its forward pass can be replayed with identical RNG.

    `record_rng` snapshots the CPU (and, if initialized, CUDA) RNG state;
    calling forward with `set_rng = True` restores that state inside a forked
    RNG scope, so stochastic layers (e.g. dropout) reproduce their outputs
    exactly — required for reversible-block reconstruction.
    """
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        # fix: the module-level import of get_device_states was dropped from
        # this excerpt; import locally so the method cannot raise NameError
        from torch.utils.checkpoint import get_device_states
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        # fix: same as above for set_device_states
        from torch.utils.checkpoint import set_device_states

        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        # fork RNG, restore the recorded state, run, then restore caller RNG
        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)

# 受 https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py 启发
# 一旦多 GPU 确认工作正常,重构并将 PR 发回源代码
class ReversibleBlock(nn.Module):
    """RevNet-style reversible residual block.

    The input is split channel-wise into (x1, x2) and mapped as
        y1 = x1 + f(x2)
        y2 = x2 + g(y1)
    Because this mapping is invertible, activations are not stored in the
    forward pass; backward_pass reconstructs the inputs from the outputs.
    """
    def __init__(self, f, g):
        super().__init__()
        # wrap sub-networks so their RNG state is recorded and replayed,
        # keeping dropout etc. identical between forward and reconstruction
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, f_args = {}, g_args = {}):
        # split into the two residual streams along the feature dimension
        x1, x2 = torch.chunk(x, 2, dim=2)
        y1, y2 = None, None

        # no_grad: activations are recomputed in backward_pass, not stored
        with torch.no_grad():
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)

        return torch.cat([y1, y2], dim=2)

    def backward_pass(self, y, dy, f_args = {}, g_args = {}):
        """Invert the block: given outputs `y` and their gradients `dy`,
        reconstruct the inputs and return (x, dx)."""
        y1, y2 = torch.chunk(y, 2, dim=2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=2)
        del dy

        # replay g(y1) with the recorded RNG, backprop dy2 through it
        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            # invert: x2 = y2 - g(y1)
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        # replay f(x2), backprop dx1 through it
        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)

        with torch.no_grad():
            # invert: x1 = y1 - f(x2)
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=2)
            dx = torch.cat([dx1, dx2], dim=2)

        return x, dx

class _ReversibleFunction(Function):
    """Custom autograd Function driving a stack of ReversibleBlocks.

    Only the final output is kept (detached); block inputs are reconstructed
    layer by layer during backward instead of being stored.
    """
    @staticmethod
    def forward(ctx, x, blocks, args):
        ctx.args = args
        # run each block with its pre-routed kwargs
        for block, kwarg in zip(blocks, args):
            x = block(x, **kwarg)
        # stash detached output for reconstruction during backward
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x

    @staticmethod
    def backward(ctx, dy):
        # recover the saved output and per-block kwargs
        y = ctx.y
        args = ctx.args
        # walk the blocks in reverse, inverting each and updating (y, dy)
        for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
            y, dy = block.backward_pass(y, dy, **kwargs)
        # gradient only w.r.t. x; blocks and args receive None
        return dy, None, None
# 定义一个可逆序列的神经网络模块
class ReversibleSequence(nn.Module):
    """A stack of ReversibleBlocks executed through _ReversibleFunction.

    Args:
        blocks: iterable of (f, g) module pairs, one per layer.
        args_route: routing spec mapping kwarg names to per-layer
            (to_f, to_g) boolean pairs, consumed by `route_args`.
    """
    def __init__(self, blocks, args_route = {}):
        super().__init__()
        self.args_route = args_route
        self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks])

    def forward(self, x, **kwargs):
        # duplicate features so each half can serve as one residual stream
        x = torch.cat([x, x], dim=-1)

        blocks = self.blocks
        # route kwargs to each layer's f/g functions
        args = route_args(self.args_route, kwargs, len(blocks))
        args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args))

        # fix: removed the unused `layers_and_args` local that was built and
        # never read
        out =  _ReversibleFunction.apply(x, blocks, args)
        # split the two streams back apart and sum them
        return torch.stack(out.chunk(2, dim=-1)).sum(dim=0)

.\lucidrains\nuwa-pytorch\nuwa_pytorch\reversible_video_audio.py

import torch
import torch.nn as nn
from torch.autograd.function import Function
from contextlib import contextmanager

from nuwa_pytorch.reversible import Deterministic

from einops import reduce

# helpers

# 检查值是否存在
def exists(val):
    """Return True when `val` is not None."""
    return not (val is None)

# 上下文管理器,不执行任何操作
@contextmanager
def null_context():
    """A no-op context manager: enters, yields nothing, exits."""
    yield None

# 在指定维度上按索引分割张量
def split_at_index(dim, index, t):
    """Split tensor `t` at `index` along dimension `dim`.

    Returns (t[..., :index, ...], t[..., index:, ...]) with the split
    applied on dimension `dim`.
    """
    lead = (slice(None),) * dim
    left = t[(*lead, slice(None, index))]
    right = t[(*lead, slice(index, None))]
    return left, right

# reversible self attention block

class ReversibleSelfAttnBlock(nn.Module):
    """Reversible block processing two modalities independently.

    Stream `x` (e.g. video) goes through the f/g residual pair and stream
    `m` (e.g. audio) through j/k. With `_reverse = True` the forward runs
    under no_grad and activations are reconstructed in backward_pass.
    """
    def __init__(self, f, g, j, k):
        super().__init__()
        # wrap all four sub-networks so RNG can be recorded and replayed
        self.f = Deterministic(f)
        self.g = Deterministic(g)
        self.j = Deterministic(j)
        self.k = Deterministic(k)

    def forward(self, x, m, _reverse = True, **kwargs):
        # split both modalities into residual-stream halves
        x1, x2 = torch.chunk(x, 2, dim = 2)
        m1, m2 = torch.chunk(m, 2, dim = 2)
        y1, y2, n1, n2 = None, None, None, None

        # reversible mode: no activation storage; record RNG for replay
        fn_context = torch.no_grad if _reverse else null_context
        record_rng = self.training and _reverse

        with fn_context():
            y1 = x1 + self.f(x2, record_rng = record_rng)
            y2 = x2 + self.g(y1, record_rng = record_rng)
            n1 = m1 + self.j(m2, record_rng = record_rng)
            n2 = m2 + self.k(n1, record_rng = record_rng)

        return torch.cat((y1, y2), dim = 2), torch.cat((n1, n2), dim = 2)

    def backward_pass(self, y, n, dy, dn, **kwargs):
        """Reconstruct inputs (x, m) from outputs (y, n) and propagate
        gradients; returns (x, m, dx, dm)."""
        y1, y2 = torch.chunk(y, 2, dim = 2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim = 2)
        del dy

        # replay g(y1) with recorded RNG, backprop dy2 through it
        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng = True)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            # invert: x2 = y2 - g(y1)
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        # replay f(x2), backprop dx1 through it
        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng = True)
            torch.autograd.backward(fx2, dx1, retain_graph = True)

        with torch.no_grad():
            # invert: x1 = y1 - f(x2)
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim = 2)
            dx = torch.cat([dx1, dx2], dim = 2)

        # identical procedure for the second modality stream
        n1, n2 = torch.chunk(n, 2, dim = 2)
        del n

        dn1, dn2 = torch.chunk(dn, 2, dim = 2)
        del dn

        with torch.enable_grad():
            n1.requires_grad = True
            gn1 = self.k(n1, set_rng = True)
            torch.autograd.backward(gn1, dn2)

        with torch.no_grad():
            # invert: m2 = n2 - k(n1)
            m2 = n2 - gn1
            del n2, gn1

            dm1 = dn1 + n1.grad
            del dn1
            n1.grad = None

        with torch.enable_grad():
            m2.requires_grad = True
            fm2 = self.j(m2, set_rng = True)
            torch.autograd.backward(fm2, dm1, retain_graph=True)

        with torch.no_grad():
            # invert: m1 = n1 - j(m2)
            m1 = n1 - fm2
            del n1, fm2

            dm2 = dn2 + m2.grad
            del dn2
            m2.grad = None

            m = torch.cat([m1, m2.detach()], dim = 2)
            dm = torch.cat([dm1, dm2], dim = 2)

        return x, m, dx, dm

class ReversibleCrossAttnBlock(nn.Module):
    """Reversible block where each modality cross-attends to a shared text
    context: f (video stream) and j (audio stream) receive `context`, while
    g and k are their follow-up residual layers.
    """
    def __init__(self, f, g, j, k):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)
        self.j = Deterministic(j)
        self.k = Deterministic(k)
    # forward pass: residual-couple both modalities with the text context
    def forward(self, x, m, *, context, context_mask, video_mask = None, audio_mask = None, _reverse = True, **kwargs):
        # split both modalities into residual-stream halves
        x1, x2 = torch.chunk(x, 2, dim = 2)
        m1, m2 = torch.chunk(m, 2, dim = 2)
        y1, y2, n1, n2 = None, None, None, None

        # reversible mode: no activation storage; record RNG for replay
        fn_context = torch.no_grad if _reverse else null_context
        record_rng = self.training and _reverse

        with fn_context():
            # video stream: cross-attend to context, then follow-up layer
            y1 = x1 + self.f(x2, context = context, context_mask = context_mask, mask = video_mask, record_rng = record_rng)
            y2 = x2 + self.g(y1, record_rng = record_rng)
            # audio stream: cross-attend to context, then follow-up layer
            n1 = m1 + self.j(m2, context = context, context_mask = context_mask, mask = audio_mask, record_rng = record_rng)
            n2 = m2 + self.k(n1, record_rng = record_rng)

        # concatenate each stream's halves back together
        return torch.cat((y1, y2), dim = 2), torch.cat((n1, n2), dim = 2)

    # reconstruct inputs from outputs and propagate gradients
    def backward_pass(self, y, n, dy, dn, *, context, context_mask, video_mask = None, audio_mask = None, **kwargs):
        # split the video-stream output and its gradient
        y1, y2 = torch.chunk(y, 2, dim = 2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim = 2)
        del dy

        # replay g(y1) with recorded RNG, backprop dy2 through it
        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng = True)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            # invert: x2 = y2 - g(y1)
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        # replay the cross attention f(x2), backprop dx1 through it
        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng = True, context = context, context_mask = context_mask, mask = video_mask)
            torch.autograd.backward(fx2, dx1, retain_graph = True)

        with torch.no_grad():
            # invert: x1 = y1 - f(x2)
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim = 2)
            dx = torch.cat([dx1, dx2], dim = 2)

        # identical procedure for the audio stream
        n1, n2 = torch.chunk(n, 2, dim = 2)
        del n

        dn1, dn2 = torch.chunk(dn, 2, dim = 2)
        del dn

        # replay k(n1), backprop dn2 through it
        with torch.enable_grad():
            n1.requires_grad = True
            gn1 = self.k(n1, set_rng = True)
            torch.autograd.backward(gn1, dn2)

        with torch.no_grad():
            # invert: m2 = n2 - k(n1)
            m2 = n2 - gn1
            del n2, gn1

            dm1 = dn1 + n1.grad
            del dn1
            n1.grad = None

        # replay the cross attention j(m2), backprop dm1 through it
        with torch.enable_grad():
            m2.requires_grad = True
            fm2 = self.j(m2, set_rng = True, context = context, context_mask = context_mask, mask = audio_mask)
            torch.autograd.backward(fm2, dm1, retain_graph=True)

        with torch.no_grad():
            # invert: m1 = n1 - j(m2)
            m1 = n1 - fm2
            del n1, fm2

            dm2 = dn2 + m2.grad
            del dn2
            m2.grad = None

            m = torch.cat([m1, m2.detach()], dim = 2)
            dm = torch.cat([dm1, dm2], dim = 2)

        # reconstructed inputs and their gradients
        return x, m, dx, dm
# 可逆交叉模态注意力块

class ReversibleCrossModalityAttnBlock(nn.Module):
    """Reversible block where the two modalities attend to each other.

    f: video stream attends to audio; j: audio stream attends to the
    *updated* video stream (y2); g and k are follow-up residual layers.
    Because n1 depends on y2, backward_pass must unwind the audio stream
    first, in the exact reverse of the forward order.
    """
    def __init__(self, f, g, j, k):
        super().__init__()
        self.f = Deterministic(f)  # video -> audio cross attention
        self.g = Deterministic(g)  # audio follow-up layer
        self.j = Deterministic(j)  # audio -> video cross attention
        self.k = Deterministic(k)  # video follow-up layer

    def forward(self, x, m, *, video_mask = None, audio_mask = None, _reverse = True, **kwargs):
        x1, x2 = torch.chunk(x, 2, dim = 2)  # split video stream halves
        m1, m2 = torch.chunk(m, 2, dim = 2)  # split audio stream halves
        y1, y2, n1, n2 = None, None, None, None

        fn_context = torch.no_grad if _reverse else null_context  # reversible mode skips activation storage
        record_rng = self.training and _reverse

        with fn_context():
            y1 = x1 + self.f(x2, m2, record_rng = record_rng, mask = video_mask, context_mask = audio_mask)  # video attends to audio
            y2 = x2 + self.k(y1, record_rng = record_rng)  # video follow-up
            n1 = m1 + self.j(m2, y2, record_rng = record_rng, mask = audio_mask, context_mask = video_mask)  # audio attends to updated video
            n2 = m2 + self.g(n1, record_rng = record_rng)  # audio follow-up

        return torch.cat((y1, y2), dim = 2), torch.cat((n1, n2), dim = 2)  # rejoin each stream's halves

    def backward_pass(self, y, n, dy, dn, video_mask = None, audio_mask = None, **kwargs):
        """Invert the block (audio stream first, since it was computed last)
        and accumulate gradients; returns (x, m, dx, dm)."""
        n1, n2 = torch.chunk(n, 2, dim = 2)  # split audio output halves
        del n

        dn1, dn2 = torch.chunk(dn, 2, dim = 2)  # split audio gradient halves
        del dn

        y1, y2 = torch.chunk(y, 2, dim = 2)  # split video output halves
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim = 2)  # split video gradient halves
        del dy

        # replay g(n1), backprop dn2 through it
        with torch.enable_grad():
            n1.requires_grad = True
            gn1 = self.g(n1, set_rng = True)
            torch.autograd.backward(gn1, dn2)

        with torch.no_grad():
            m2 = n2 - gn1  # invert: m2 = n2 - g(n1)
            del n2, gn1

            dm1 = dn1 + n1.grad
            del dn1
            n1.grad = None

        # replay j(m2, y2); y2 also receives gradient, since n1 depended on it
        with torch.enable_grad():
            m2.requires_grad = True
            y2.requires_grad = True
            fm2 = self.j(m2, y2, set_rng=True, mask = audio_mask, context_mask = video_mask)
            torch.autograd.backward(fm2, dm1)

        with torch.no_grad():
            m1 = n1 - fm2  # invert: m1 = n1 - j(m2, y2)
            del n1, fm2

            dm2 = dn2 + m2.grad
            dx2 = dy2 + y2.grad  # gradient flowing into the video stream through y2
            del dn2
            del dy2
            m2.grad = None
            y2.grad = None

        # replay k(y1), backprop dx2 through it
        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.k(y1, set_rng = True)
            torch.autograd.backward(gy1, dx2)

        with torch.no_grad():
            x2 = y2 - gy1  # invert: x2 = y2 - k(y1)
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        # replay f(x2, m2); m2 receives additional gradient here
        with torch.enable_grad():
            x2.requires_grad = True
            m2.requires_grad = True
            fx2 = self.f(x2, m2, set_rng = True, mask = video_mask, context_mask = audio_mask)
            torch.autograd.backward(fx2, dx1)

        with torch.no_grad():
            x1 = y1 - fx2  # invert: x1 = y1 - f(x2, m2)
            del y1, fx2

            dx2 = dx2 + x2.grad  # accumulate extra gradient into dx2
            dm2 = dm2 + m2.grad  # accumulate extra gradient into dm2
            x2.grad = None
            m2.grad = None

        with torch.no_grad():
            m = torch.cat([m1, m2.detach()], dim = 2)  # reconstructed audio input
            dm = torch.cat([dm1, dm2], dim = 2)

            x = torch.cat([x1, x2.detach()], dim = 2)  # reconstructed video input
            dx = torch.cat([dx1, dx2], dim = 2)

        return x, m, dx, dm

# 反向和非反向函数

class ReversibleFunction(Function):
    """Custom autograd Function driving the dual-modality reversible stack.

    Only the final (x, m) are saved (detached); block inputs are
    reconstructed layer by layer during backward.
    """
    @staticmethod
    def forward(ctx, inp, ind, blocks, kwargs):
        # split the concatenated input into the two modalities at index `ind`
        x, m = split_at_index(1, ind, inp)

        # run every block in reversible mode
        for block in blocks:
            x, m = block(x, m, _reverse = True, **kwargs)

        ctx.blocks = blocks
        ctx.kwargs = kwargs
        ctx.ind = ind
        # keep detached outputs only; backward reconstructs activations
        ctx.save_for_backward(x.detach(), m.detach())
        return torch.cat((x, m), dim = 1)

    @staticmethod
    def backward(ctx, d):
        # recover saved state
        ind = ctx.ind
        blocks = ctx.blocks
        kwargs = ctx.kwargs
        # split the incoming gradient into the two modalities
        dy, dn = split_at_index(1, ind, d)
        y, n = ctx.saved_tensors

        # walk the blocks in reverse, inverting each and accumulating grads
        for block in blocks[::-1]:
            y, n, dy, dn = block.backward_pass(y, n, dy, dn, **kwargs)

        # re-concatenate the gradients of both modalities
        d = torch.cat((dy, dn), dim=1)
        # gradient only w.r.t. `inp`; ind / blocks / kwargs receive None
        return d, None, None, None
# convenience alias for invoking the reversible stack via the custom Function
reversible_apply = ReversibleFunction.apply

# 定义不可逆应用函数,接受输入、索引、块和关键字参数
def irreversible_apply(inputs, ind, blocks, kwargs):
    """Non-reversible counterpart of reversible_apply: runs the same blocks
    with normal autograd (activations stored), used when reverse = False."""
    x, m = split_at_index(1, ind, inputs)
    for blk in blocks:
        x, m = blk(x, m, _reverse = False, **kwargs)
    return torch.cat((x, m), dim = 1)

# 主要的可逆序列类
class DualModalityReversibleSequence(nn.Module):
    """Reversible transformer trunk over two modalities (video + audio).

    Each entry of `input_blocks` is a tuple of four sub-networks; the
    matching entry of `block_types` selects which reversible wrapper to use.
    """
    def __init__(self, input_blocks, block_types):
        super().__init__()
        self.block_types = block_types

        # map layer-type names to their reversible wrapper classes
        klass_by_type = {
            'intra_modality_self_attn': ReversibleSelfAttnBlock,
            'intra_modality_cross_attn': ReversibleCrossAttnBlock,
            'inter_modality_cross_attn': ReversibleCrossModalityAttnBlock,
        }

        blocks = nn.ModuleList([])
        for block, block_type in zip(input_blocks, block_types):
            if block_type not in klass_by_type:
                raise ValueError(f'unknown layer type {block_type}')
            blocks.append(klass_by_type[block_type](*block))

        self.blocks = blocks

    def forward(
        self,
        video,
        audio,
        *,
        context,
        context_mask = None,
        video_mask = None,
        audio_mask = None,
        reverse = True
    ):
        """Run the stack; returns [video, audio] with original feature size."""
        blocks = self.blocks
        # duplicate features along the last dim to form the residual streams
        video = torch.cat((video, video), dim = -1)
        audio = torch.cat((audio, audio), dim = -1)
        kwargs = dict(context = context, context_mask = context_mask, video_mask = video_mask, audio_mask = audio_mask)

        # pick reversible (memory-saving) or plain execution
        apply_fn = reversible_apply if reverse else irreversible_apply
        ind = video.shape[1]
        joint = torch.cat((video, audio), dim = 1)
        out = apply_fn(joint, ind, blocks, kwargs)
        video, audio = split_at_index(1, ind, out)
        # average the duplicated halves back down to the original width
        return [reduce(t, 'b n (c d) -> b n d', 'mean', c = 2) for t in (video, audio)]

.\lucidrains\nuwa-pytorch\nuwa_pytorch\train_nuwa.py

# 从 random 模块中导入 randrange 函数
from random import randrange
# 从 pathlib 模块中导入 Path 类
from pathlib import Path

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 torch.utils.data 模块中导入 Dataset 和 DataLoader 类
from torch.utils.data import Dataset, DataLoader
# 从 torch.nn.utils.rnn 模块中导入 pad_sequence 函数
from torch.nn.utils.rnn import pad_sequence
# 从 einops 库中导入 rearrange 函数
from einops import rearrange

# 从 tqdm 模块中导入 tqdm 函数
from tqdm import tqdm
# 导入 numpy 库
import numpy as np
# 从 shutil 模块中导入 rmtree 函数
from shutil import rmtree

# 导入 nuwa_pytorch 库中的 tokenizer 模块和 optimizer 模块
from nuwa_pytorch.tokenizer import tokenizer
from nuwa_pytorch.optimizer import get_optimizer
# 导入 nuwa_pytorch 库中的 image_utils 模块
from nuwa_pytorch.image_utils import gif_to_tensor
# 从 nuwa_pytorch 模块中导入 NUWA 类
# NOTE(review): the import statement itself is missing from this excerpt; it should read:
# from nuwa_pytorch import NUWA

# 从 torchvision.transforms 模块中导入 T 别名
import torchvision.transforms as T
# 从 torchvision.utils 模块中导入 make_grid 和 save_image 函数
# NOTE(review): the import statement itself is missing from this excerpt (train_step calls save_image); it should read:
# from torchvision.utils import make_grid, save_image

# 辅助函数

# 判断值是否存在的函数
def exists(val):
    """True when `val` carries a value (i.e. is not None)."""
    return False if val is None else True

# 空操作函数
def noop(*args, **kwargs):
    """Accept any arguments, do nothing, return None (placeholder callback)."""
    return None

# 生成循环迭代器的函数
def cycle(dl):
    """Yield items from `dl` forever, restarting iteration when exhausted."""
    while True:
        yield from dl

# 将输入转换为元组的函数
def cast_tuple(t):
    """Wrap `t` in a 1-tuple unless it is already a tuple or list."""
    if isinstance(t, (tuple, list)):
        return t
    return (t,)

# 询问用户是否为是或否的函数
def yes_or_no(question):
    """Prompt the user with `question`; True for a 'y' / 'yes' reply."""
    reply = input(f'{question} (y/n) ')
    return reply.lower() in ('yes', 'y')

# 累积日志的函数
def accum_log(log, new_logs):
    """Add each value in `new_logs` into `log` (starting missing keys at 0.);
    mutates and returns `log`."""
    for key, value in new_logs.items():
        log[key] = log.get(key, 0.) + value
    return log

# 数据加载器辅助函数

# 数据填充函数
def pad_collate_fn(batch):
    """Collate (text, video) pairs: right-pad the variable-length text
    tensors and stack the equal-shape video tensors."""
    texts, videos = zip(*batch)
    padded_texts = pad_sequence(texts, batch_first = True)
    return padded_texts, torch.stack(videos)

# 数据处理流水线函数

# 将视频张量数据集转换为索引的函数
def convert_video_tensor_dataset_to_indices(
    *,
    vae,
    raw_video_dataset,
    num_frames,
    path,
):
    """Encode every video in `raw_video_dataset` to discrete VAE codebook
    indices and write them into an int64 memmap at `path`."""
    vae_device = next(vae.parameters()).device
    total = len(raw_video_dataset)
    assert total > 0, 'there must be at least 1 video'

    # NOTE(review): assumes feature-map side = image_size // num_layers ** 2 —
    # confirm against the VQGanVAE downsampling factor
    fmap_size = vae.image_size // (vae.num_layers ** 2)
    shape = (total, num_frames * fmap_size * fmap_size)

    indices_memmap = np.memmap(path, mode = 'w+', dtype = np.int64, shape = shape)

    for ind in tqdm(range(total)):
        _, video = raw_video_dataset[ind]
        # add a batch dimension, move to the VAE's device
        video = rearrange(video, '... -> 1 ...')
        video = video.to(vae_device)
        codes = vae.get_video_indices(video)
        # flatten (frames, height, width) into one token sequence
        codes = rearrange(codes, '1 f h w -> (f h w)')
        indices_memmap[ind] = codes.cpu().numpy()

    print(f'completed conversion of {total} videos to indices at {path}')

# 数据集类

# Mnist 数据集类
class MnistDataset(Dataset):
    """Moving-MNIST-style dataset backed by memmapped video and label arrays.

    Yields (text, video) pairs where `text` is the tokenized digit string and
    `video` is a float tensor scaled to [0, 1].
    """
    def __init__(
        self,
        num_videos,
        videos_memmap_path,
        text_memmap_path,
        num_digits = 2,
        num_frames = 10,
        image_size = 64,
        channels = 1,
        random_rotate = False
    ):
        super().__init__()
        self.num_videos = num_videos
        # uint8 videos: (num_videos, frames, channels, height, width)
        self.videos_memmap = np.memmap(videos_memmap_path, mode = 'r', dtype = np.uint8, shape = (num_videos, num_frames, channels, image_size, image_size))
        # uint8 digit labels: (num_videos, num_digits)
        self.text_memmap = np.memmap(text_memmap_path, mode = 'r', dtype = np.uint8, shape = (num_videos, num_digits))
        self.random_rotate = random_rotate

    def __len__(self):
        return self.num_videos

    def __getitem__(self, idx):
        video = torch.from_numpy(self.videos_memmap[idx].copy()).float()
        label = torch.from_numpy(self.text_memmap[idx].copy())

        # scale uint8 pixel values into [0, 1]
        video /= 255
        video = video.to(torch.float32)

        # digit labels -> space-separated string -> token ids
        text = tokenizer.encode(' '.join(map(str, label.tolist())))
        text = torch.Tensor(text).long()

        if self.random_rotate:
            # fix: the original called `choice([0, 90, 180, 270])` but `choice`
            # was never imported (only `randrange` is) -> NameError at runtime;
            # pick one of the four right angles with randrange instead
            video = T.functional.rotate(video, randrange(4) * 90)

        return text, video

# 视频索引数据集类
class VideoIndicesDataset(Dataset):
    """Dataset of pre-encoded video token indices plus digit labels,
    both stored as memmaps; yields (text, video_codes) pairs."""
    def __init__(
        self,
        *,
        videos_memmap_path,
        text_memmap_path,
        vae,
        num_videos,
        num_frames,
        num_digits = 2,
    ):
        self.num_videos = num_videos
        # NOTE(review): assumes feature-map side = image_size // num_layers ** 2 —
        # confirm against the VQGanVAE downsampling factor
        fmap_size = vae.image_size // (vae.num_layers ** 2)
        # int64 codebook indices: (num_videos, frames * fmap_size^2)
        self.videos_memmap = np.memmap(videos_memmap_path, mode = 'r', dtype = np.int64, shape = (num_videos, num_frames * (fmap_size ** 2)))
        # uint8 digit labels: (num_videos, num_digits)
        self.text_memmap = np.memmap(text_memmap_path, mode = 'r', dtype = np.uint8, shape = (num_videos, num_digits))

    def __len__(self):
        return self.num_videos

    def __getitem__(self, idx):
        # copy out of the memmap so tensors own their storage
        codes = torch.from_numpy(self.videos_memmap[idx].copy())
        labels = torch.from_numpy(self.text_memmap[idx].copy())

        # digit labels -> space-separated string -> token ids
        token_ids = tokenizer.encode(' '.join(map(str, labels.tolist())))
        text = torch.Tensor(token_ids).long()

        return text, codes.long()
# 从视频文件夹中创建用于训练的数据集类
class GifVideoDataset(Dataset):
    """(caption, video) pairs loaded from a folder of matching .txt/.gif files.

    Only path stems that have BOTH a .gif and a .txt file are kept.
    """
    def __init__(
        self,
        *,
        folder,
        channels = 1
    ):
        folder = Path(folder)
        # collect the suffix-stripped stems of all gif and txt files
        gif_stems = set(str(p.with_suffix('')) for p in folder.glob('**/*.gif'))
        txt_stems = set(str(p.with_suffix('')) for p in folder.glob('**/*.txt'))
        # keep stems present in both sets
        self.path_stems = list(gif_stems & txt_stems)

        self.channels = channels
        print(f'{len(self.path_stems)} video / text pairs found')

    def __len__(self):
        return len(self.path_stems)

    def __getitem__(self, idx):
        stem = self.path_stems[idx]

        caption = Path(f'{stem}.txt').read_text()
        text_tensor = torch.Tensor(tokenizer.encode(caption)).long()

        video_tensor = gif_to_tensor(f'{stem}.gif', channels = self.channels)
        return text_tensor, video_tensor

# 训练类
class NUWATrainer(nn.Module):
    """Training harness for a NUWA text-to-video model.

    Handles gradient-accumulated optimization, gradient clipping, periodic
    sampling of generated videos, and periodic checkpointing.
    """
    def __init__(
        self,
        *,
        nuwa,  # NUWA model instance
        dataset,  # dataset yielding (text, video) pairs
        num_train_steps,  # total number of optimizer steps
        lr = 3e-4,  # learning rate
        wd = 0.01,  # weight decay
        batch_size = 4,  # batch size
        grad_accum_every = 8,  # micro-batches accumulated per optimizer step
        max_grad_norm = 0.5,  # gradient clipping threshold
        save_model_every = 2500,  # checkpoint every N steps
        save_results_every = 1000,  # sample and save results every N steps
        results_folder = './results-nuwa',  # output directory
        num_sampled_frames = float('inf')  # cap on frames generated when sampling
    ):
        super().__init__()
        # NOTE(review): the `from nuwa_pytorch import NUWA` import is missing from this excerpt
        assert isinstance(nuwa, NUWA), 'nuwa must be an instance of NUWA'
        self.nuwa = nuwa

        self.steps = 0  # global optimizer-step counter
        self.num_train_steps = num_train_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every
        self.max_grad_norm = max_grad_norm

        self.optim = get_optimizer(nuwa.parameters(), lr = lr, wd = wd)

        # dataset
        self.ds = dataset

        # infinite (cycled) dataloader; pads variable-length text in a batch
        self.dl = cycle(DataLoader(
            self.ds,
            batch_size = batch_size,
            collate_fn = pad_collate_fn,
            shuffle = True
        ))

        self.save_model_every = save_model_every
        self.save_results_every = save_results_every
        self.num_sampled_frames = num_sampled_frames

        self.results_folder = Path(results_folder)

        # optionally clear checkpoints / results from a previous run
        if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)

    def train_step(self):
        """Run one optimizer step (with gradient accumulation); returns a log dict."""
        device = next(self.nuwa.parameters()).device
        self.nuwa.train()

        logs = {}

        # accumulate gradients over several micro-batches
        for _ in range(self.grad_accum_every):
            text, video = next(self.dl)
            text, video = map(lambda t: t.to(device), (text, video))

            # model returns its training loss directly
            loss = self.nuwa(
                text = text,
                video = video,
                return_loss = True
            )
            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

            # scale so accumulated gradients average over micro-batches
            (loss / self.grad_accum_every).backward()

        print(f'{self.steps} loss: {logs["loss"]}')

        # clip, step, and reset gradients
        torch.nn.utils.clip_grad_norm_(self.nuwa.parameters(), self.max_grad_norm)
        self.optim.step()
        self.optim.zero_grad()

        # periodically sample a generated video and save it as an image strip
        if not (self.steps % self.save_results_every):
            self.nuwa.eval()
            print(f'{self.steps} sampling')

            rand_idx = randrange(0, len(self.ds))

            text, video = self.ds[rand_idx]
            # NOTE(review): the randomly indexed sample above is immediately
            # overwritten by the dataloader draw below, leaving `rand_idx`
            # unused — confirm which data source was intended
            text, video = next(self.dl)
            text = text.to(device)

            # generate a video for the first text in the batch
            video = self.nuwa.generate(text = text[:1], num_frames = min(video.shape[1], self.num_sampled_frames))
            one_video = video[0].cpu().clamp(0., 1.)

            # decode the prompt back to a string for logging
            text_str = tokenizer.decode(text[0])

            logs['sampled_text'] = text_str
            logs['sampled_video'] = one_video.numpy()

            # stack frames vertically into a single image and save it
            # NOTE(review): `save_image` (torchvision.utils) import is missing from this excerpt
            image = rearrange(one_video, 'f c h w -> c (f h) w')
            save_image(image, str(self.results_folder / f'{self.steps}.png'))

            print(f'{self.steps}: saving to {str(self.results_folder)}')

        # periodically checkpoint the model
        if not (self.steps % self.save_model_every):
            state_dict = self.nuwa.state_dict()
            model_path = str(self.results_folder / f'nuwa.{self.steps}.pt')
            # save the model parameters
            torch.save(state_dict, model_path)

            print(f'{self.steps}: saving model to {str(self.results_folder)}')

        # advance the step counter
        self.steps += 1
        return logs

    def train(self, log_fn = noop):
        """Run train_step until num_train_steps, passing each log dict to log_fn."""
        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        print('training complete')

.\lucidrains\nuwa-pytorch\nuwa_pytorch\train_vqgan_vae.py

# 从 math 模块中导入 sqrt 函数
from math import sqrt
# 从 copy 模块中导入 copy 函数
import copy
# 从 random 模块中导入 choice 函数
from random import choice
# 从 pathlib 模块中导入 Path 类
from pathlib import Path
# 从 shutil 模块中导入 rmtree 函数
from shutil import rmtree

# 导入 torch 模块
import torch
# 从 torch 模块中导入 nn 模块
from torch import nn
# 导入 numpy 模块
import numpy as np

# 从 PIL 模块中导入 Image 类
from PIL import Image
# 从 torchvision.datasets 模块中导入 ImageFolder 类
from torchvision.datasets import ImageFolder
# 从 torchvision.transforms 模块中导入 T 别名
import torchvision.transforms as T
# 从 torch.utils.data 模块中导入 Dataset, DataLoader, random_split 类
from torch.utils.data import Dataset, DataLoader, random_split
# 从 torchvision.utils 模块中导入 make_grid, save_image 函数
from torchvision.utils import make_grid, save_image

# 从 einops 模块中导入 rearrange 函数
from einops import rearrange
# 从 nuwa_pytorch.vqgan_vae 模块中导入 VQGanVAE 类
from nuwa_pytorch.vqgan_vae import VQGanVAE
# 从 nuwa_pytorch.optimizer 模块中导入 get_optimizer 函数
from nuwa_pytorch.optimizer import get_optimizer

# helpers

# 定义 exists 函数,判断值是否存在
def exists(val):
    """Return True when `val` holds a value (i.e. it is not None)."""
    return not (val is None)

# 定义 noop 函数,空函数
def noop(*args, **kwargs):
    """Do-nothing placeholder callback; accepts and ignores any arguments."""
    return None

# 定义 cycle 函数,循环生成数据
def cycle(dl):
    """Yield items from `dl` endlessly, restarting iteration once exhausted."""
    while True:
        yield from dl

# 定义 cast_tuple 函数,将输入转换为元组
def cast_tuple(t):
    """Wrap `t` in a 1-tuple unless it is already a tuple or list."""
    if isinstance(t, (tuple, list)):
        return t
    return (t,)

# 定义 yes_or_no 函数,询问用户是否为是或否
def yes_or_no(question):
    """Prompt the user on stdin with `question`; return True for 'y'/'yes'."""
    reply = input(f'{question} (y/n) ')
    return reply.lower() in ('yes', 'y')

# 定义 accum_log 函数,累积日志
def accum_log(log, new_logs):
    """Accumulate each value of `new_logs` onto the matching key of `log`
    (missing keys start at 0.) and return the mutated `log`."""
    for key, value in new_logs.items():
        log[key] = log.get(key, 0.) + value
    return log

# classes

# 定义 MemmappedImageDataset 类,继承自 Dataset 类
class MemmappedImageDataset(Dataset):
    """Image dataset backed by a read-only uint8 numpy memmap.

    `shape` describes the full memmap, e.g. (N, H, W) or (N, 1, H, W);
    the last dimension is taken as the (square) image size. Items are
    converted through PIL, resized/center-cropped, turned into tensors,
    and optionally rotated by a random multiple of 90 degrees.
    """

    def __init__(
        self,
        *,
        path,
        shape,
        random_rotate = True
    ):
        super().__init__()
        path = Path(path)
        assert path.exists(), f'path {path} must exist'

        # zero-copy, read-only view over the raw image bytes on disk
        self.memmap = np.memmap(str(path), mode = 'r', dtype = np.uint8, shape = shape)
        self.random_rotate = random_rotate

        image_size = shape[-1]
        self.transform = T.Compose([
            T.Resize(image_size),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        # one sample per leading memmap index
        return self.memmap.shape[0]

    def __getitem__(self, index):
        arr = self.memmap[index]

        # drop a singleton channel dimension so PIL sees (H, W)
        if arr.shape[0] == 1:
            arr = rearrange(arr, '1 ... -> ...')

        img = self.transform(Image.fromarray(arr))

        if self.random_rotate:
            img = T.functional.rotate(img, choice([0, 90, 180, 270]))
        return img

# 定义 ImageDataset 类,继承自 Dataset 类
class ImageDataset(Dataset):
    """Dataset of images found recursively under `folder`.

    Images are converted to RGB, resized, randomly flipped, center-cropped
    to `image_size`, and returned as tensors.

    Fix: the `exts` default was a mutable list (shared across calls); it is
    now an equivalent immutable tuple.
    """

    def __init__(
        self,
        folder,
        image_size,
        exts = ('jpg', 'jpeg', 'png')
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size

        # recursively gather every file with a matching extension
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        print(f'{len(self.paths)} training samples found at {folder}')

        self.transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

# exponential moving average wrapper

# 定义 EMA 类,继承自 nn.Module 类
class EMA(nn.Module):
    """Exponential-moving-average shadow of a model.

    Keeps a deep copy of `model` whose parameters and buffers track the
    online model with decay `beta`. Updates start only after
    `ema_update_after_step` calls to `update()`, and then run once every
    `ema_update_every` calls.
    """

    def __init__(
        self,
        model,
        beta = 0.99,
        ema_update_after_step = 1000,
        ema_update_every = 10,
    ):
        super().__init__()
        self.beta = beta
        self.online_model = model
        self.ema_model = copy.deepcopy(model)

        self.ema_update_after_step = ema_update_after_step # only start EMA after this step number, starting at 0
        self.ema_update_every = ema_update_every

        # persisted as buffers so the EMA schedule survives checkpointing
        self.register_buffer('initted', torch.Tensor([False]))
        self.register_buffer('step', torch.tensor([0.]))

    def update(self):
        """Advance the step counter and, when scheduled, pull the EMA model
        toward the online model."""
        self.step += 1

        if self.step <= self.ema_update_after_step or (self.step % self.ema_update_every) != 0:
            return

        if not self.initted:
            # bug fix: this previously called `self.ema_model.state_dict(...)`,
            # which does NOT load anything - copy the online weights in instead
            self.ema_model.load_state_dict(self.online_model.state_dict())
            self.initted.data.copy_(torch.Tensor([True]))

        self.update_moving_average(self.ema_model, self.online_model)

    def update_moving_average(self, ma_model, current_model):
        """In-place EMA update of `ma_model`'s parameters and buffers toward
        `current_model`."""
        def calculate_ema(beta, old, new):
            # first observation: adopt the new value directly
            if old is None:
                return new
            return old * beta + (1 - beta) * new

        # parameters
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = calculate_ema(self.beta, old_weight, up_weight)

        # buffers (e.g. batch-norm running statistics)
        for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
            new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
            ma_buffer.copy_(new_buffer_value)

    def __call__(self, *args, **kwargs):
        # inference always runs through the EMA weights
        return self.ema_model(*args, **kwargs)
# 主要的训练器类

class VQGanVAETrainer(nn.Module):
    """Training harness for a VQGanVAE.

    Alternates generator (VAE) and discriminator updates with gradient
    accumulation, maintains an EMA copy of the VAE, and periodically saves
    reconstruction grids and checkpoints to `results_folder`.

    Fixes vs. previous version:
    - saved result images now use the computed per-model filename (it was
      previously computed but never used, so both models overwrote one file)
    - `ema_beta` is now actually forwarded to the EMA wrapper (it was
      silently ignored before)
    - removed unused locals (`discr_loss`, `nrows`, `device` in `train`)
    """

    def __init__(
        self,
        vae,
        *,
        num_train_steps,
        lr,
        batch_size,
        grad_accum_every,
        wd = 0.,
        images_memmap_path = None,
        images_memmap_shape = None,
        folder = None,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        valid_frac = 0.05,
        random_split_seed = 42,
        ema_beta = 0.995,
        ema_update_after_step = 2000,
        ema_update_every = 10,
        apply_grad_penalty_every = 4,
    ):
        super().__init__()
        assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
        image_size = vae.image_size

        self.vae = vae
        # fix: forward ema_beta so the configured decay takes effect
        self.ema_vae = EMA(vae, beta = ema_beta, ema_update_after_step = ema_update_after_step, ema_update_every = ema_update_every)

        # buffer so the step counter is checkpointed with the module
        self.register_buffer('steps', torch.Tensor([0]))

        self.num_train_steps = num_train_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        # split parameters: the discriminator gets its own optimizer
        all_parameters = set(vae.parameters())
        discr_parameters = set(vae.discr.parameters())
        vae_parameters = all_parameters - discr_parameters

        self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
        self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)

        # build the dataset - exactly one image source must be supplied

        assert exists(folder) ^ exists(images_memmap_path), 'either folder or memmap path to images must be supplied'

        if exists(images_memmap_path):
            assert exists(images_memmap_shape), 'shape of memmapped images must be supplied'

        if exists(folder):
            self.ds = ImageDataset(folder, image_size = image_size)
        elif exists(images_memmap_path):
            self.ds = MemmappedImageDataset(path = images_memmap_path, shape = images_memmap_shape)

        # optionally carve out a validation split

        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        # infinite dataloaders

        self.dl = cycle(DataLoader(
            self.ds,
            batch_size = batch_size,
            shuffle = True
        ))

        self.valid_dl = cycle(DataLoader(
            self.valid_ds,
            batch_size = batch_size,
            shuffle = True
        ))

        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        self.apply_grad_penalty_every = apply_grad_penalty_every

        self.results_folder = Path(results_folder)

        # offer to wipe stale checkpoints/results from a previous run
        if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)

    def train_step(self):
        """Run one full optimization step (generator + discriminator, with
        gradient accumulation) and return a dict of logged scalars/artifacts."""
        device = next(self.vae.parameters()).device
        steps = int(self.steps.item())
        # gradient penalty is only applied every `apply_grad_penalty_every` steps
        apply_grad_penalty = not (steps % self.apply_grad_penalty_every)

        self.vae.train()

        logs = {}

        # update the VAE (generator)

        for _ in range(self.grad_accum_every):
            img = next(self.dl)
            img = img.to(device)

            loss = self.vae(
                img,
                return_loss = True,
                apply_grad_penalty = apply_grad_penalty
            )

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

            (loss / self.grad_accum_every).backward()

        self.optim.step()
        self.optim.zero_grad()

        # update the discriminator

        if exists(self.vae.discr):
            self.discr_optim.zero_grad()

            for _ in range(self.grad_accum_every):
                img = next(self.dl)
                img = img.to(device)

                loss = self.vae(img, return_discr_loss = True)
                accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})

                (loss / self.grad_accum_every).backward()

            self.discr_optim.step()

            print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}")

        # update the EMA copy of the generator
        self.ema_vae.update()

        # periodically save reconstruction grids (EMA model and raw model)

        if not (steps % self.save_results_every):
            for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))):
                model.eval()

                imgs = next(self.dl)
                imgs = imgs.to(device)

                recons = model(imgs)

                # interleave originals and reconstructions row by row
                imgs_and_recons = torch.stack((imgs, recons), dim = 0)
                imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')

                imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
                grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))

                logs['reconstructions'] = grid

                # bug fix: `filename` was previously ignored here
                save_image(grid, str(self.results_folder / f'{filename}.png'))

            print(f'{steps}: saving to {str(self.results_folder)}')

        # periodically checkpoint both the raw and EMA weights

        if not (steps % self.save_model_every):
            state_dict = self.vae.state_dict()
            model_path = str(self.results_folder / f'vae.{steps}.pt')
            torch.save(state_dict, model_path)

            ema_state_dict = self.ema_vae.state_dict()
            model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
            torch.save(ema_state_dict, model_path)

            print(f'{steps}: saving model to {str(self.results_folder)}')

        self.steps += 1
        return logs

    def train(self, log_fn = noop):
        """Run `train_step` until `num_train_steps`, passing logs to `log_fn`."""
        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        print('training complete')

.\lucidrains\nuwa-pytorch\nuwa_pytorch\vqgan_vae.py

# 导入必要的库
import copy
import math
from functools import partial, wraps
from math import sqrt

# 导入自定义模块
from vector_quantize_pytorch import VectorQuantize as VQ

# 导入 PyTorch 相关库
import torchvision
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.autograd import grad as torch_grad

# 导入 einops 库
from einops import rearrange, reduce, repeat

# shorthand alias for nn.ModuleList, used throughout this module
MList = nn.ModuleList

# 辅助函数

# 判断变量是否存在
def exists(val):
    """Return True when `val` holds a value (i.e. it is not None)."""
    return not (val is None)

# 如果变量存在则返回其值,否则返回默认值
def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if val is None:
        return d
    return val

# 装饰器

# 模型评估装饰器
def eval_decorator(fn):
    """Wrap `fn(model, ...)` so the model runs in eval mode, then restore the
    model's previous training/eval flag afterwards."""
    def inner(model, *args, **kwargs):
        prev_mode = model.training
        model.eval()
        result = fn(model, *args, **kwargs)
        model.train(prev_mode)
        return result
    return inner

# 移除 VGG 模型装饰器
def remove_vgg(fn):
    """Decorator that temporarily detaches `self.vgg` while `fn` runs (so e.g.
    state dicts exclude the frozen VGG), restoring it afterwards."""
    @wraps(fn)
    def inner(self, *args, **kwargs):
        vgg = None
        had_vgg = hasattr(self, 'vgg')
        if had_vgg:
            vgg = self.vgg
            delattr(self, 'vgg')

        out = fn(self, *args, **kwargs)

        if had_vgg:
            self.vgg = vgg

        return out
    return inner

# 关键字参数辅助函数

# 从字典中选择指定键的值并弹出这些键
def pick_and_pop(keys, d):
    """Remove `keys` from dict `d` (raising KeyError if absent) and return
    them as a new dict."""
    return {key: d.pop(key) for key in keys}

# 根据条件将字典分组
def group_dict_by_key(cond, d):
    """Split `d` into a (matching, non-matching) pair of dicts according to
    predicate `cond` applied to each key."""
    matching, rest = dict(), dict()
    for key in d.keys():
        if bool(cond(key)):
            matching[key] = d[key]
        else:
            rest[key] = d[key]
    return (matching, rest)

# 判断字符串是否以指定前缀开头
def string_begins_with(prefix, str):
    """Return True if `str` starts with `prefix`.

    NOTE: the parameter name `str` shadows the builtin; kept for
    backward compatibility with keyword callers.
    """
    begins = str.startswith(prefix)
    return begins

# 根据前缀将字典分组
def group_by_key_prefix(prefix, d):
    """Split dict `d` into (keys starting with `prefix`, all other keys)."""
    return group_dict_by_key(lambda key: string_begins_with(prefix, key), d)

# 根据前缀将字典分组并去除前缀
def groupby_prefix_and_trim(prefix, d):
    """Split `d` by key prefix and strip the prefix from the matching keys.

    Returns (prefixed_kwargs_with_prefix_removed, remaining_kwargs).

    Fix: the previous version was missing a closing parenthesis on the
    `dict(map(...))` expression, making the module unimportable (SyntaxError).
    """
    kwargs_with_prefix = {key: value for key, value in d.items() if key.startswith(prefix)}
    remaining = {key: value for key, value in d.items() if not key.startswith(prefix)}
    trimmed = {key[len(prefix):]: value for key, value in kwargs_with_prefix.items()}
    return trimmed, remaining

# 张量辅助函数

# 计算梯度惩罚
def gradient_penalty(images, output, weight = 10):
    """Gradient penalty pushing the discriminator's per-sample input-gradient
    L2 norm toward 1 (WGAN-GP style), scaled by `weight`."""
    gradients = torch_grad(
        outputs = output,
        inputs = images,
        grad_outputs = torch.ones(output.size(), device = images.device),
        create_graph = True,
        retain_graph = True,
        only_inputs = True
    )[0]

    # flatten everything but the batch dimension before taking the norm
    flat = rearrange(gradients, 'b ... -> b (...)')
    return weight * ((flat.norm(2, dim = 1) - 1) ** 2).mean()

# 计算 L2 范数
def l2norm(t):
    """L2-normalize `t` along its final dimension (zero-safe via F.normalize's eps)."""
    normed = F.normalize(t, dim = -1)
    return normed

# Leaky ReLU 激活函数
def leaky_relu(p = 0.1):
    """Factory for a LeakyReLU module with negative slope `p`.

    Fix: the slope was previously hard-coded to 0.1 and `p` was silently
    ignored. Default behavior is unchanged.
    """
    return nn.LeakyReLU(p)

# 稳定的 Softmax 函数
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
    """Numerically-stable softmax: scale down by `alpha`, subtract the
    (detached) max, then rescale and softmax. Mathematically equal to a
    plain softmax of `t` (softmax is shift-invariant)."""
    scaled = t / alpha
    scaled = scaled - torch.amax(scaled, dim = dim, keepdim = True).detach()
    return (scaled * alpha).softmax(dim = dim)

# 安全除法
def safe_div(numer, denom, eps = 1e-6):
    """Divide `numer` by `denom`, adding a small `eps` to the denominator to
    avoid division by zero."""
    denom = denom + eps
    return numer / denom

# GAN 损失函数

# Hinge 判别器损失
def hinge_discr_loss(fake, real):
    """Hinge loss for the discriminator: real logits pushed above +1, fake
    logits pushed below -1."""
    loss_fake = F.relu(1 + fake)
    loss_real = F.relu(1 - real)
    return (loss_fake + loss_real).mean()

# Hinge 生成器损失
def hinge_gen_loss(fake):
    """Generator hinge loss: maximize the discriminator's logits on fakes."""
    return fake.mean().neg()

# 二元交叉熵判别器损失
def bce_discr_loss(fake, real):
    """Binary-cross-entropy discriminator loss on raw logits.

    Fix: `log` and `sigmoid` were undefined in this module (NameError when
    called); use the explicit torch equivalents.
    """
    return (-torch.log(1 - torch.sigmoid(fake)) - torch.log(torch.sigmoid(real))).mean()

# 二元交叉熵生成器损失
def bce_gen_loss(fake):
    """Binary-cross-entropy generator loss on raw logits.

    Fix: `log` and `sigmoid` were undefined in this module (NameError when
    called); use the explicit torch equivalents.
    """
    return -torch.log(torch.sigmoid(fake)).mean()

# 计算损失对层的梯度
def grad_layer_wrt_loss(loss, layer):
    """Gradient of `loss` w.r.t. `layer`, returned detached; the autograd
    graph is retained so a later backward pass still works."""
    grads = torch_grad(
        outputs = loss,
        inputs = layer,
        grad_outputs = torch.ones_like(loss),
        retain_graph = True
    )[0]
    return grads.detach()

# VQGAN VAE

# 通道层归一化
class LayerNormChan(nn.Module):
    """Layer normalization over the channel dimension (dim 1) of
    (b, c, h, w) feature maps, with learned per-channel scale and shift.

    Fix: the `self.b` parameter line was missing its closing parenthesis,
    making the module unimportable (SyntaxError).
    """

    def __init__(
        self,
        dim,
        eps = 1e-5
    ):
        super().__init__()
        self.eps = eps
        # learned per-channel scale (g) and shift (b)
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        # normalize each spatial position across the channel dimension
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

# 判别器模型
class Discriminator(nn.Module):
    """PatchGAN-style convolutional discriminator.

    `dims` lists channel widths; each consecutive pair becomes one
    stride-2 downsampling stage, so the output is a small spatial grid
    of real/fake logits rather than a single scalar.
    """

    def __init__(
        self,
        dims,
        channels = 3,
        groups = 16,
        init_kernel_size = 5
    ):
        super().__init__()

        # stem: large-kernel conv from image channels to the first width
        stem = nn.Sequential(
            nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2),
            leaky_relu()
        )
        self.layers = MList([stem])

        # one downsampling stage per consecutive pair of widths
        for dim_in, dim_out in zip(dims[:-1], dims[1:]):
            stage = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
                nn.GroupNorm(groups, dim_out),
                leaky_relu()
            )
            self.layers.append(stage)

        final_dim = dims[-1]

        # head producing a grid of logits, for PatchGAN-esque training
        self.to_logits = nn.Sequential(
            nn.Conv2d(final_dim, final_dim, 1),
            leaky_relu(),
            nn.Conv2d(final_dim, 1, 4)
        )

    def forward(self, x):
        """Return a (b, 1, h', w') grid of logits for input images `x`."""
        for stage in self.layers:
            x = stage(x)
        return self.to_logits(x)
class ContinuousPositionBias(nn.Module):
    """Continuous relative position bias, per https://arxiv.org/abs/2111.09883.

    A small MLP maps log-scaled 2d relative offsets to one bias value per
    attention head; the resulting (heads, n, n) grid is added to the
    attention similarity map.

    Fixes: the final `nn.Linear` append was missing a closing parenthesis
    (SyntaxError). The einops/file-level helper calls were also replaced
    with equivalent native tensor ops (permute/reshape/unsqueeze).
    """

    def __init__(self, *, dim, heads, layers = 2):
        super().__init__()
        self.net = nn.ModuleList([])
        self.net.append(nn.Sequential(nn.Linear(2, dim), nn.LeakyReLU(0.1)))

        for _ in range(layers - 1):
            self.net.append(nn.Sequential(nn.Linear(dim, dim), nn.LeakyReLU(0.1)))

        # project hidden features to one bias value per head
        self.net.append(nn.Linear(dim, heads))

        # lazily-built cache of relative positions; excluded from state dict
        self.register_buffer('rel_pos', None, persistent = False)

    def forward(self, x):
        # x: (..., n, n) attention similarity where n == fmap_size ** 2
        n, device = x.shape[-1], x.device
        fmap_size = int(sqrt(n))

        if self.rel_pos is None:
            pos = torch.arange(fmap_size, device = device)
            grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
            # (2, f, f) -> (f*f, 2): one (row, col) coordinate per position
            grid = grid.permute(1, 2, 0).reshape(-1, 2)
            # all pairwise coordinate differences: (n, n, 2)
            rel_pos = grid.unsqueeze(1) - grid.unsqueeze(0)
            # signed log scaling keeps large offsets in a compact range
            rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
            self.register_buffer('rel_pos', rel_pos, persistent = False)

        rel_pos = self.rel_pos.float()

        for layer in self.net:
            rel_pos = layer(rel_pos)

        # (n, n, heads) -> (heads, n, n), broadcast-added to the logits
        bias = rel_pos.permute(2, 0, 1)
        return x + bias

class GLUResBlock(nn.Module):
    """Residual block of two GLU-gated 3x3 convs with group norm, closed by
    a 1x1 conv; the input is added back as a residual."""

    def __init__(self, chan, groups = 16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(chan, chan * 2, 3, padding = 1),
            nn.GLU(dim = 1),
            nn.GroupNorm(groups, chan),
            nn.Conv2d(chan, chan * 2, 3, padding = 1),
            nn.GLU(dim = 1),
            nn.GroupNorm(groups, chan),
            nn.Conv2d(chan, chan, 1)
        )

    def forward(self, x):
        # residual connection around the gated conv stack
        out = self.net(x)
        return out + x

class ResBlock(nn.Module):
    """Standard residual block: two 3x3 convs with group norm and leaky
    ReLU, closed by a 1x1 conv; the input is added back as a residual."""

    def __init__(self, chan, groups = 16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.GroupNorm(groups, chan),
            leaky_relu(),
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.GroupNorm(groups, chan),
            leaky_relu(),
            nn.Conv2d(chan, chan, 1)
        )

    def forward(self, x):
        # residual connection around the conv stack
        out = self.net(x)
        return out + x

class VQGanAttention(nn.Module):
    """Self-attention over 2d feature maps.

    Uses cosine-similarity attention (l2-normalized queries/keys with a
    learned per-head log-temperature), a continuous relative position bias
    added to the logits, and a post-norm residual connection.

    Input and output are (b, dim, h, w) feature maps.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        # learned per-head temperature stored in log space, initialized at log(0.01)
        self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * math.log(0.01))
        inner_dim = heads * dim_head

        self.dropout = nn.Dropout(dropout)
        self.post_norm = LayerNormChan(dim)

        # continuous relative position bias added to the similarity logits
        self.cpb = ContinuousPositionBias(dim = dim // 4, heads = heads)
        # single 1x1 conv projects to queries, keys and values at once
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x):
        h = self.heads
        # keep the input around for the residual connection
        height, width, residual = *x.shape[-2:], x.clone()

        # project and split into q / k / v along channels
        q, k, v = self.to_qkv(x).chunk(3, dim = 1)

        # (b, h*c, x, y) -> (b, h, c, x*y): flatten spatial dims per head
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = h), (q, k, v))

        # cosine-similarity attention: l2-normalize queries and keys
        q, k = map(l2norm, (q, k))

        # similarity logits scaled by the learned temperature exp(self.scale)
        sim = einsum('b h c i, b h c j -> b h i j', q, k) * self.scale.exp()

        # add relative position bias to the logits
        sim = self.cpb(sim)

        attn = stable_softmax(sim, dim = -1)
        attn = self.dropout(attn)

        # aggregate values, then restore the (b, c, h, w) layout
        out = einsum('b h i j, b h c j -> b h c i', attn, v)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', x = height, y = width)
        out = self.to_out(out)

        # post-norm plus residual
        return self.post_norm(out) + residual

class VQGanVAE(nn.Module):
    """VQGAN-style VAE: conv encoder -> vector quantization -> conv decoder,
    trained with reconstruction, perceptual (VGG), commitment and
    adversarial losses."""

    def __init__(
        self,
        *,
        dim,                        # base channel width
        image_size,                 # inputs are (channels, image_size, image_size)
        channels = 3,
        num_layers = 4,             # number of stride-2 down/upsampling stages
        layer_mults = None,         # per-layer width multipliers (default 1, 2, 4, ...)
        l2_recon_loss = False,      # MSE instead of L1 reconstruction loss
        use_hinge_loss = True,      # hinge vs BCE GAN losses
        num_resnet_blocks = 1,      # int applies to the innermost layer only
        vgg = None,                 # optional pre-built VGG for perceptual loss
        vq_codebook_dim = 256,
        vq_codebook_size = 512,
        vq_decay = 0.8,
        vq_commitment_weight = 1.,
        vq_kmeans_init = True,
        vq_use_cosine_sim = True,
        use_attn = True,            # bool applies to the innermost layer only
        attn_dim_head = 64,
        attn_heads = 8,
        resnet_groups = 16,
        attn_dropout = 0.,
        first_conv_kernel_size = 5,
        use_vgg_and_gan = True,     # disable e.g. for grayscale pretraining
        **kwargs                    # extra `vq_`-prefixed kwargs go to VectorQuantize
    ):
        super().__init__()
        assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'

        # split out the kwargs destined for the vector quantizer
        vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)

        self.image_size = image_size
        self.channels = channels
        self.num_layers = num_layers
        # each layer downsamples spatially by 2, so the bottleneck feature
        # map side is image_size / 2^num_layers
        # (bug fix: was `num_layers ** 2`, which only coincides with the
        # actual downsampling factor at the default num_layers = 4)
        self.fmap_size = image_size // (2 ** num_layers)
        self.codebook_size = vq_codebook_size

        self.encoders = MList([])
        self.decoders = MList([])

        # default multipliers double the width at every layer
        layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(num_layers))))
        assert len(layer_mults) == num_layers, 'layer multipliers must be equal to designated number of layers'

        layer_dims = [dim * mult for mult in layer_mults]
        dims = (dim, *layer_dims)

        dim_pairs = zip(dims[:-1], dims[1:])

        append = lambda arr, t: arr.append(t)
        prepend = lambda arr, t: arr.insert(0, t)

        # an int config applies resnet blocks / attention only to the innermost layer
        if not isinstance(num_resnet_blocks, tuple):
            num_resnet_blocks = (*((0,) * (num_layers - 1)), num_resnet_blocks)

        if not isinstance(use_attn, tuple):
            use_attn = (*((False,) * (num_layers - 1)), use_attn)

        assert len(num_resnet_blocks) == num_layers, 'number of resnet blocks config must be equal to number of layers'
        assert len(use_attn) == num_layers

        # build mirrored encoder / decoder stacks (decoder assembled in reverse)
        for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(num_layers), dim_pairs, num_resnet_blocks, use_attn):
            append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
            prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))

            if layer_use_attn:
                prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))

            for _ in range(layer_num_resnet_blocks):
                append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
                prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))

            if layer_use_attn:
                append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))

        prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
        append(self.decoders, nn.Conv2d(dim, channels, 1))

        # vector quantization bottleneck
        self.vq = VQ(
            dim = layer_dims[-1],
            codebook_dim = vq_codebook_dim,
            codebook_size = vq_codebook_size,
            decay = vq_decay,
            commitment_weight = vq_commitment_weight,
            accept_image_fmap = True,
            kmeans_init = vq_kmeans_init,
            use_cosine_sim = vq_use_cosine_sim,
            **vq_kwargs
        )

        # reconstruction loss
        self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss

        # perceptual + adversarial components are optional
        self.vgg = None
        self.discr = None
        self.use_vgg_and_gan = use_vgg_and_gan

        if not use_vgg_and_gan:
            return

        # perceptual loss: ImageNet-pretrained VGG16 as a fixed feature extractor
        if exists(vgg):
            self.vgg = vgg
        else:
            self.vgg = torchvision.models.vgg16(pretrained = True)
            self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])

        # GAN components
        self.discr = Discriminator(dims = dims, channels = channels)

        self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
        self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
    # 创建一个模型的副本用于评估,确保在同一设备上
    def copy_for_eval(self):
        """Deep-copy this VAE for inference: eval mode, stripped of the
        training-only discriminator and VGG, placed back on the device the
        parameters were originally on.

        NOTE: `self.cpu()` moves this (the online) module to CPU as a side
        effect, matching the previous behavior.
        """
        device = next(self.parameters()).device
        eval_copy = copy.deepcopy(self.cpu())

        # the frozen VGG and the discriminator are only needed for training
        if eval_copy.use_vgg_and_gan:
            del eval_copy.discr
            del eval_copy.vgg

        eval_copy.eval()
        return eval_copy.to(device)

    # override so checkpoints never include the frozen (pretrained) VGG weights
    @remove_vgg
    def state_dict(self, *args, **kwargs):
        return super().state_dict(*args, **kwargs)

    # override so loading ignores the (absent) VGG weights in checkpoints
    @remove_vgg
    def load_state_dict(self, *args, **kwargs):
        return super().load_state_dict(*args, **kwargs)

    # convenience accessor for the vector quantizer's codebook tensor
    @property
    def codebook(self):
        return self.vq.codebook

    # 对输入进行编码操作,通过多个编码器层
    def encode(self, fmap):
        """Run images through every encoder stage, then quantize.

        Returns the vector quantizer's output tuple
        (quantized fmap, codebook indices, commitment loss).
        """
        for stage in self.encoders:
            fmap = stage(fmap)

        return self.vq(fmap)

    # 对输入进行解码操作,通过多个解码器层
    def decode(self, fmap):
        """Run a (quantized) feature map through every decoder stage back to
        image space."""
        for stage in self.decoders:
            fmap = stage(fmap)

        return fmap

    # 将 codebook 索引转换为视频数据
    @torch.no_grad()
    @eval_decorator
    def codebook_indices_to_video(self, indices):
        """Decode flat codebook indices (b, f * h * w) into a video tensor
        (b, f, c, h, w), where h = w = self.fmap_size in code space."""
        b = indices.shape[0]
        frames = self.decode(rearrange(self.codebook[indices], 'b (f h w) d -> (b f) d h w', h = self.fmap_size, w = self.fmap_size))
        return rearrange(frames, '(b f) ... -> b f ...', b = b)

    # 从视频数据中获取 codebook 索引
    @torch.no_grad()
    @eval_decorator
    def get_video_indices(self, video):
        """Quantize a video (b, f, c, h, w) and return the per-frame codebook
        indices re-grouped as (b, f, ...)."""
        b, f, _, h, w = video.shape
        frames = rearrange(video, 'b f ... -> (b f) ...')
        _, indices, _ = self.encode(frames)
        return rearrange(indices, '(b f) ... -> b f ...', b = b)

    def forward(
        self,
        img,
        return_loss = False,           # return the generator-side training loss
        return_discr_loss = False,     # return the discriminator training loss
        return_recons = False,         # additionally return the reconstruction
        apply_grad_penalty = False     # add gradient penalty to the discr loss
    ):
        """Reconstruct `img`; optionally compute generator or discriminator
        training losses (mutually exclusive)."""
        batch, channels, height, width, device = *img.shape, img.device
        # bug fix: this assert message was missing the f-string prefix, so the
        # configured size was never interpolated into the error text
        assert height == self.image_size and width == self.image_size, f'height and width of input image must be equal to {self.image_size}'
        assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'

        fmap, indices, commit_loss = self.encode(img)

        fmap = self.decode(fmap)

        # plain autoencoding path
        if not return_loss and not return_discr_loss:
            return fmap

        assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'

        # discriminator branch
        if return_discr_loss:
            assert exists(self.discr), 'discriminator must exist to train it'

            # the generator output is a constant for the discriminator step;
            # the input needs grad for the gradient penalty
            fmap.detach_()
            img.requires_grad_()

            fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))

            loss = self.discr_loss(fmap_discr_logits, img_discr_logits)

            if apply_grad_penalty:
                gp = gradient_penalty(img, img_discr_logits)
                loss = loss + gp

            if return_recons:
                return loss, fmap

            return loss

        # reconstruction loss
        recon_loss = self.recon_loss_fn(fmap, img)

        if not self.use_vgg_and_gan:
            if return_recons:
                return recon_loss, fmap

            return recon_loss

        # perceptual loss via VGG features; grayscale is tiled to 3 channels
        img_vgg_input = img
        fmap_vgg_input = fmap

        if img.shape[1] == 1:
            img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))

        img_vgg_feats = self.vgg(img_vgg_input)
        recon_vgg_feats = self.vgg(fmap_vgg_input)
        perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)

        # generator (adversarial) loss
        gen_loss = self.gen_loss(self.discr(fmap))

        # balance the GAN term against the perceptual term using the gradient
        # norms at the last decoder layer (adaptive weight, as in VQGAN)
        last_dec_layer = self.decoders[-1].weight

        norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
        norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)

        adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
        adaptive_weight.clamp_(max = 1e4)

        # total generator-side loss
        loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss

        if return_recons:
            return loss, fmap

        return loss