Lucidrains-系列项目源码解析-七十五-

44 阅读23分钟

Lucidrains 系列项目源码解析(七十五)

.\lucidrains\phenaki-pytorch\phenaki_pytorch\cvivit_trainer.py

# 从 math 模块中导入 sqrt 函数
from math import sqrt
# 从 random 模块中导入 choice 函数
from random import choice
# 从 pathlib 模块中导入 Path 类
from pathlib import Path
# 从 shutil 模块中导入 rmtree 函数
from shutil import rmtree

# 从 beartype 模块中导入 beartype 装饰器
from beartype import beartype

# 导入 torch 模块
import torch
# 从 torch 模块中导入 nn 模块
from torch import nn
# 从 torch.utils.data 模块中导入 Dataset, DataLoader, random_split 类
from torch.utils.data import Dataset, DataLoader, random_split

# 从 torchvision.transforms 模块中导入 T 别名
import torchvision.transforms as T
# 从 torchvision.datasets 模块中导入 ImageFolder 类
from torchvision.datasets import ImageFolder
# 从 torchvision.utils 模块中导入 make_grid, save_image 函数
from torchvision.utils import make_grid, save_image

# 从 einops 模块中导入 rearrange 函数
from einops import rearrange

# 从 phenaki_pytorch.optimizer 模块中导入 get_optimizer 函数
from phenaki_pytorch.optimizer import get_optimizer

# 从 ema_pytorch 模块中导入 EMA 类
from ema_pytorch import EMA

# 从 phenaki_pytorch.cvivit 模块中导入 CViViT 类
from phenaki_pytorch.cvivit import CViViT
# 从 phenaki_pytorch.data 模块中导入 ImageDataset, VideoDataset, video_tensor_to_gif 函数
from phenaki_pytorch.data import ImageDataset, VideoDataset, video_tensor_to_gif

# 从 accelerate 模块中导入 Accelerator 类
from accelerate import Accelerator, DistributedType

# helpers

# Helper: tell whether a value was actually provided.
def exists(val):
    """Return ``True`` unless ``val`` is ``None``."""
    return not (val is None)

# Helper: placeholder callback that ignores everything it is given.
def noop(*args, **kwargs):
    """Accept any arguments, do nothing and return ``None``."""
    return None

# Helper: endless iteration over an iterable such as a data loader.
def cycle(dl):
    """Yield the items of ``dl`` forever, re-iterating it each time it runs out."""
    while True:
        yield from dl

# Helper: wrap a lone value in a 1-tuple.
def cast_tuple(t):
    """Return ``t`` untouched when it is a tuple or list, otherwise ``(t,)``."""
    if isinstance(t, (tuple, list)):
        return t
    return (t,)

# Helper: interactive yes/no prompt on the console.
def yes_or_no(question):
    """Prompt with ``question`` and return ``True`` for a ``y`` / ``yes`` reply (case-insensitive)."""
    reply = input(f'{question} (y/n) ').lower()
    return reply == 'yes' or reply == 'y'

# Helper: add freshly computed log values onto a running log.
def accum_log(log, new_logs):
    """Add every value of ``new_logs`` onto ``log`` in place and return ``log``.

    Keys missing from ``log`` start from ``0.``.
    """
    for key in new_logs:
        log[key] = log.get(key, 0.) + new_logs[key]
    return log

# main trainer class

# Main trainer class; beartype validates the annotated constructor arguments.
@beartype
class CViViTTrainer(nn.Module):
    """Training harness for a ``CViViT`` model built on ``accelerate``.

    Construction builds the dataset (images or videos), splits off a
    validation subset, creates one optimizer for the discriminator parameters
    and one for all remaining parameters, wraps model, optimizers and the
    training loader with the accelerator, and prepares the results folder.

    Args:
        vae: model to train; ``image_size``, ``discr`` and ``parameters()`` are read.
        num_train_steps: ``train`` keeps stepping while ``steps`` is below this.
        batch_size: batch size of the training and validation loaders.
        folder: dataset root passed to ``ImageDataset`` / ``VideoDataset``.
        train_on_images: use ``ImageDataset`` instead of ``VideoDataset``.
        num_frames: forwarded to ``VideoDataset`` (unused for images).
        lr: learning rate handed to ``get_optimizer`` for both optimizers.
        grad_accum_every: stored on the instance (consumer not in this excerpt).
        wd: weight decay handed to ``get_optimizer`` for both optimizers.
        max_grad_norm: stored on the instance (consumer not in this excerpt).
        discr_max_grad_norm: stored on the instance (consumer not in this excerpt).
        save_results_every: stored on the instance (consumer not in this excerpt).
        save_model_every: stored on the instance (consumer not in this excerpt).
        results_folder: output directory; created if missing.
        valid_frac: fraction held out for validation; when not positive the
            full dataset is shared between training and validation.
        random_split_seed: seed of the generator used by ``random_split``.
        use_ema: keep an ``EMA`` copy of the model on the main process.
        ema_beta: decay forwarded to ``EMA``.
        ema_update_after_step: forwarded to ``EMA``.
        ema_update_every: forwarded to ``EMA``.
        apply_grad_penalty_every: stored on the instance (consumer not in this excerpt).
        accelerate_kwargs: keyword arguments for ``Accelerator``. The shared
            default dict is only unpacked, never mutated.
    """

    def __init__(
        self,
        vae: CViViT,
        *,
        num_train_steps,
        batch_size,
        folder,
        train_on_images = False,
        num_frames = 17,
        lr = 3e-4,
        grad_accum_every = 1,
        wd = 0.,
        max_grad_norm = 0.5,
        discr_max_grad_norm = None,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        valid_frac = 0.05,
        random_split_seed = 42,
        use_ema = True,
        ema_beta = 0.995,
        ema_update_after_step = 0,
        ema_update_every = 1,
        apply_grad_penalty_every = 4,
        accelerate_kwargs: dict = dict()
    ):
        super().__init__()
        image_size = vae.image_size

        # the accelerator must exist before `is_main` / `print` are used below
        self.accelerator = Accelerator(**accelerate_kwargs)

        self.vae = vae

        # exponential moving average of the model, kept on the main process only
        self.use_ema = use_ema
        if self.is_main and use_ema:
            # `ema_beta` used to be accepted but silently dropped; forward it so the argument takes effect
            self.ema_vae = EMA(vae, beta = ema_beta, update_after_step = ema_update_after_step, update_every = ema_update_every)

        # step counter stored as a buffer
        self.register_buffer('steps', torch.Tensor([0]))

        self.num_train_steps = num_train_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        # split the parameters: the discriminator gets its own optimizer
        all_parameters = set(vae.parameters())
        discr_parameters = set(vae.discr.parameters())
        vae_parameters = all_parameters - discr_parameters

        self.vae_parameters = vae_parameters

        # optimizers
        self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
        self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)

        # gradient clipping thresholds
        self.max_grad_norm = max_grad_norm
        self.discr_max_grad_norm = discr_max_grad_norm

        # dataset
        if train_on_images:
            self.ds = ImageDataset(folder, image_size)
        else:
            self.ds = VideoDataset(folder, image_size, num_frames = num_frames)

        # hold out a validation subset
        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        # data loaders
        self.dl = DataLoader(
            self.ds,
            batch_size = batch_size,
            shuffle = True
        )

        self.valid_dl = DataLoader(
            self.valid_ds,
            batch_size = batch_size,
            shuffle = True
        )

        # let accelerate wrap model, optimizers and the training loader
        # (the validation loader is deliberately left as-is, as in the original)
        (
            self.vae,
            self.optim,
            self.discr_optim,
            self.dl
        ) = self.accelerator.prepare(
            self.vae,
            self.optim,
            self.discr_optim,
            self.dl
        )

        # endless iterators over both loaders
        self.dl_iter = cycle(self.dl)
        self.valid_dl_iter = cycle(self.valid_dl)

        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        self.apply_grad_penalty_every = apply_grad_penalty_every

        self.results_folder = Path(results_folder)

        # offer to wipe a non-empty results folder (interactive prompt)
        if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)

    def save(self, path):
        """Write model and optimizer state dicts to ``path`` (local main process only)."""
        if not self.is_local_main:
            return

        pkg = dict(
            model = self.accelerator.get_state_dict(self.vae),
            optim = self.optim.state_dict(),
            discr_optim = self.discr_optim.state_dict()
        )
        torch.save(pkg, path)

    def load(self, path):
        """Restore model and optimizer state from a checkpoint written by ``save``.

        Raises:
            AssertionError: if ``path`` does not exist.
        """
        path = Path(path)
        assert path.exists()
        # NOTE: torch.load unpickles the file - only load checkpoints you trust
        pkg = torch.load(path)

        vae = self.accelerator.unwrap_model(self.vae)
        vae.load_state_dict(pkg['model'])

        self.optim.load_state_dict(pkg['optim'])
        self.discr_optim.load_state_dict(pkg['discr_optim'])

    def print(self, msg):
        """Print ``msg`` through the accelerator."""
        self.accelerator.print(msg)

    @property
    def device(self):
        """Device selected by the accelerator."""
        return self.accelerator.device

    @property
    def is_distributed(self):
        """``False`` only for a single process with ``DistributedType.NO``."""
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    @property
    def is_main(self):
        """Whether this is the accelerator's main process."""
        return self.accelerator.is_main_process

    @property
    def is_local_main(self):
        """Whether this is the accelerator's local main process."""
        return self.accelerator.is_local_main_process

    def train(self, log_fn = noop):
        """Call ``train_step`` until ``steps`` reaches ``num_train_steps``.

        ``log_fn`` receives whatever each ``train_step`` call returns.
        NOTE(review): ``train_step`` is not defined in this excerpt of the
        class - it must be provided for ``train`` to work.
        """
        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        self.print('training complete')

.\lucidrains\phenaki-pytorch\phenaki_pytorch\data.py

# 导入所需的库
from pathlib import Path
import cv2
from PIL import Image
from functools import partial
from typing import Tuple, List
from beartype.door import is_bearable
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader as PytorchDataLoader
from torchvision import transforms as T, utils
from einops import rearrange

# 辅助函数

# Check whether a value is present.
def exists(val):
    """Return ``False`` for ``None`` and ``True`` for anything else."""
    if val is None:
        return False
    return True

# Pass-through function usable wherever a transform is expected.
def identity(t, *args, **kwargs):
    """Return ``t`` unchanged; any further arguments are ignored."""
    unchanged = t
    return unchanged

# Turn a single value into a 2-tuple.
def pair(val):
    """Return ``val`` as-is when it is a tuple, otherwise ``(val, val)``."""
    if isinstance(val, tuple):
        return val
    return val, val

# Force the size of dimension 1 to a fixed number of frames.
def cast_num_frames(t, *, frames):
    """Truncate or zero-pad ``t`` so that it holds ``frames`` frames.

    The current count is read from ``t.shape[1]``. Extra frames are cut off
    the end; missing ones are appended by ``F.pad`` on the third-to-last
    dimension (dimension 1 of a 4-D tensor).
    """
    missing = frames - t.shape[1]

    if missing > 0:
        return F.pad(t, (0, 0, 0, 0, 0, missing))

    if missing < 0:
        return t[:, :frames]

    return t

# Convert an image to the requested mode when necessary.
def convert_image_to_fn(img_type, image):
    """Return ``image`` in mode ``img_type``, converting only if its mode differs."""
    return image if image.mode == img_type else image.convert(img_type)

# 图像相关的辅助函数和数据集

# Dataset of still images found recursively below a folder.
class ImageDataset(Dataset):
    """Load every image below ``folder`` whose extension is listed in ``exts``.

    Each item is converted to RGB if needed, resized, randomly flipped
    horizontally, center-cropped to ``image_size`` and turned into a tensor.
    """

    def __init__(
        self,
        folder,
        image_size,
        exts = ['jpg', 'jpeg', 'png']
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size

        # gather matching files, one extension after the other
        root = Path(f'{folder}')
        self.paths = []
        for ext in exts:
            self.paths.extend(root.glob(f'**/*.{ext}'))

        print(f'{len(self.paths)} training samples found at {folder}')

        steps = [
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ]
        self.transform = T.Compose(steps)

    def __len__(self):
        """Number of image files found."""
        return len(self.paths)

    def __getitem__(self, index):
        """Open the image at ``index`` and return its transformed version."""
        return self.transform(Image.open(self.paths[index]))

# 处理读取和写入 GIF

# Image mode to convert to for each supported channel count.
CHANNELS_TO_MODE = {1: 'L', 3: 'RGB', 4: 'RGBA'}

# Iterate over every frame of a multi-frame image such as a GIF.
def seek_all_images(img, channels = 3):
    """Yield each frame of ``img`` converted to the mode matching ``channels``.

    Frames are visited with ``img.seek`` starting at 0; iteration ends at the
    first ``EOFError``.

    Raises:
        AssertionError: if ``channels`` is not a key of ``CHANNELS_TO_MODE``.
    """
    assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid'
    mode = CHANNELS_TO_MODE[channels]

    index = 0
    while True:
        try:
            img.seek(index)
            frame = img.convert(mode)
        except EOFError:
            return

        yield frame
        index += 1

# Save a video tensor as an animated GIF.
def video_tensor_to_gif(
    tensor,
    path,
    duration = 120,
    loop = 0,
    optimize = True
):
    """Write the frames of ``tensor`` to ``path`` as a GIF and return them.

    Frames are taken along dimension 1 of ``tensor`` and converted with
    ``T.ToPILImage``. ``duration``, ``loop`` and ``optimize`` are passed to
    the image's ``save`` call.

    Returns:
        list of the PIL images that were written. The original returned the
        ``map`` iterator after it had already been consumed, so callers always
        received an empty iterator.

    Raises:
        ValueError: if ``tensor`` has no frames along dimension 1.
    """
    to_pil = T.ToPILImage()
    # materialise the frames so they can be both saved and returned
    images = [to_pil(frame) for frame in tensor.unbind(dim = 1)]

    first_img, *rest_imgs = images
    first_img.save(path, save_all = True, append_images = rest_imgs, duration = duration, loop = loop, optimize = optimize)
    return images

# Read a GIF file into a single tensor.
def gif_to_tensor(
    path,
    channels = 3,
    transform = T.ToTensor()
):
    """Open the image at ``path`` and stack its transformed frames along dimension 1.

    Every frame produced by ``seek_all_images`` is passed through ``transform``.
    """
    img = Image.open(path)
    frames = [transform(frame) for frame in seek_all_images(img, channels = channels)]
    return torch.stack(frames, dim = 1)

# 处理读取和写入 MP4

# Read a video file into a tensor.
def video_to_tensor(
    path: str,              # path of the video to read
    num_frames = -1,        # number of frames to keep; negative keeps all of them
    crop_size = None        # optional center-crop size (int or tuple), applied per frame
) -> torch.Tensor:          # shape (channels, frames, height, width)
    """Decode the video at ``path`` with OpenCV into a float tensor.

    Fixes over the original implementation:
      * the capture handle is now always released;
      * the last decoded frame is no longer dropped (``frames[:-1]``);
      * a negative ``num_frames`` (the default) keeps every frame instead of
        slicing one more frame off the end.

    NOTE(review): channel order is whatever ``cv2.VideoCapture.read`` returns
    (BGR per the OpenCV docs); no colour conversion happens here.

    Raises:
        ValueError: if no frame could be read from ``path``.
    """
    video = cv2.VideoCapture(path)
    frames = []

    try:
        while True:
            check, frame = video.read()

            if not check:
                break

            if exists(crop_size):
                frame = crop_center(frame, *pair(crop_size))

            frames.append(frame)
    finally:
        # release the decoder even when reading or cropping fails
        video.release()

    if not frames:
        raise ValueError(f'no frames could be read from {path}')

    frames = np.stack(frames, axis = 0)                 # (f, h, w, c)
    frames = rearrange(frames, 'f h w c -> c f h w')

    frames_torch = torch.tensor(frames).float()

    if num_frames < 0:
        return frames_torch

    return frames_torch[:, :num_frames, :, :]

# Write a video tensor to a video file.
def tensor_to_video(
    tensor,                # PyTorch video tensor, indexed as (channels, frames, height, width)
    path: str,             # destination path of the video file
    fps = 25,              # frame rate of the saved video
    video_format = 'MP4V'  # FourCC code; change it to write a different format
):
    """Encode ``tensor`` frame by frame with ``cv2.VideoWriter``.

    The signature in the source was left unterminated and followed by a stray
    ``def read_zip(fname):`` line, which made the block a syntax error; the
    closing ``):`` is restored here and the unused ``frames`` list removed.

    Frame values are cast to ``uint8`` without rescaling.

    Returns:
        the (already released) ``cv2.VideoWriter`` object.
    """
    # the numpy conversion below requires a CPU tensor
    tensor = tensor.cpu()

    num_frames, height, width = tensor.shape[-3:]

    fourcc = cv2.VideoWriter_fourcc(*video_format)
    video = cv2.VideoWriter(path, fourcc, fps, (width, height))

    for idx in range(num_frames):
        numpy_frame = tensor[:, idx, :, :].numpy()
        numpy_frame = np.uint8(rearrange(numpy_frame, 'c h w -> h w c'))
        video.write(numpy_frame)

    video.release()

    cv2.destroyAllWindows()

    return video

# Center-crop a (height, width, channels) image.
def crop_center(
    img,        # array indexed as (height, width, channels)
    cropx,      # width of the crop
    cropy       # height of the crop
) -> torch.Tensor:
    """Return the ``cropy`` x ``cropx`` window around the center of ``img``.

    NOTE(review): the annotation says ``torch.Tensor``, but the result is a
    slice of whatever ``img`` is.
    """
    height, width, _ = img.shape
    left = width // 2 - cropx // 2
    top = height // 2 - cropy // 2
    return img[top:(top + cropy), left:(left + cropx), :]

# Dataset of GIF / MP4 videos found recursively below a folder.
class VideoDataset(Dataset):
    """Load every file below ``folder`` whose extension is listed in ``exts``.

    GIFs go through ``gif_to_tensor`` with the resize / optional flip /
    center-crop / to-tensor transform; MP4s go through ``video_to_tensor``
    with ``crop_size = image_size``. Unless ``force_num_frames`` is false, the
    result is cast to exactly ``num_frames`` frames.
    """

    def __init__(
        self,
        folder,
        image_size,
        channels = 3,
        num_frames = 17,
        horizontal_flip = False,
        force_num_frames = True,
        exts = ['gif', 'mp4']
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.channels = channels

        # gather matching files, one extension after the other
        root = Path(f'{folder}')
        self.paths = []
        for ext in exts:
            self.paths.extend(root.glob(f'**/*.{ext}'))

        # per-frame transform used for GIFs
        flip = T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity)
        self.transform = T.Compose([
            T.Resize(image_size),
            flip,
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

        # loaders for the two supported formats
        self.gif_to_tensor = partial(gif_to_tensor, channels = self.channels, transform = self.transform)
        self.mp4_to_tensor = partial(video_to_tensor, crop_size = self.image_size)

        # optional frame-count normalisation
        if force_num_frames:
            self.cast_num_frames_fn = partial(cast_num_frames, frames = num_frames)
        else:
            self.cast_num_frames_fn = identity

    def __len__(self):
        """Number of video files found."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load the video at ``index`` by extension and cast its frame count.

        Raises:
            ValueError: if the file suffix is neither ``.gif`` nor ``.mp4``.
        """
        path = self.paths[index]
        suffix = path.suffix

        if suffix not in ('.gif', '.mp4'):
            raise ValueError(f'unknown extension {suffix}')

        if suffix == '.gif':
            video = self.gif_to_tensor(path)
        else:
            video = self.mp4_to_tensor(str(path))

        return self.cast_num_frames_fn(video)

# Collate function that can batch strings alongside tensors.
def collate_tensors_and_strings(data):
    """Collate a list of samples into a tuple of batched fields.

    A plain list of tensors becomes a 1-tuple holding the stacked tensor.
    Otherwise the samples are transposed field by field: tensor fields are
    stacked along dimension 0 and string fields become lists.

    Raises:
        ValueError: if a field is neither all tensors nor all strings.
    """
    if is_bearable(data, List[torch.Tensor]):
        return (torch.stack(data, dim = 0),)

    collated = []

    for field in zip(*data):
        if is_bearable(field, Tuple[torch.Tensor, ...]):
            collated.append(torch.stack(field, dim = 0))
        elif is_bearable(field, Tuple[str, ...]):
            collated.append(list(field))
        else:
            raise ValueError('detected invalid type being passed from dataset')

    return tuple(collated)

# Data loader factory with the tensor/string collate function built in.
def DataLoader(*args, **kwargs):
    """Build a PyTorch ``DataLoader`` that collates with ``collate_tensors_and_strings``."""
    loader = PytorchDataLoader(*args, collate_fn = collate_tensors_and_strings, **kwargs)
    return loader

.\lucidrains\phenaki-pytorch\phenaki_pytorch\optimizer.py

# 从 torch.optim 模块中导入 AdamW 和 Adam 优化器
from torch.optim import AdamW, Adam

# Split parameters into those that get weight decay and those that do not.
def separate_weight_decayable_params(params):
    """Return ``(wd_params, no_wd_params)``.

    Parameters with fewer than two dimensions go to ``no_wd_params``; all
    others go to ``wd_params``. Input order is preserved within each list.
    """
    wd_params, no_wd_params = [], []

    for param in params:
        if param.ndim >= 2:
            wd_params.append(param)
        else:
            no_wd_params.append(param)

    return wd_params, no_wd_params

# Build the optimizer for a collection of parameters.
def get_optimizer(
    params,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.99),
    eps = 1e-8,
    filter_by_requires_grad = False,
    group_wd_params = True,
    **kwargs
):
    """Return ``Adam`` when ``wd == 0``, otherwise ``AdamW``.

    With ``filter_by_requires_grad`` only parameters that require gradients
    are kept. With ``group_wd_params`` (and non-zero ``wd``) the parameters
    are split by ``separate_weight_decayable_params`` and the second group is
    given ``weight_decay = 0``. Extra ``kwargs`` are accepted but unused.
    """
    if filter_by_requires_grad:
        params = [p for p in params if p.requires_grad]

    shared = dict(lr = lr, betas = betas, eps = eps)

    # no weight decay requested: plain Adam
    if wd == 0:
        return Adam(params, **shared)

    if group_wd_params:
        decayed, undecayed = separate_weight_decayable_params(params)

        params = [
            {'params': decayed},
            {'params': undecayed, 'weight_decay': 0},
        ]

    return AdamW(params, weight_decay = wd, **shared)

.\lucidrains\phenaki-pytorch\phenaki_pytorch\phenaki_pytorch.py

# 导入数学库
import math
# 导入 functools 库
import functools
# 从 contextlib 库中导入 nullcontext
from contextlib import nullcontext
# 从 functools 库中导入 partial 和 wraps
from functools import partial, wraps

# 从 typing 模块中导入 Optional, List, Union
from typing import Optional, List, Union
# 从 beartype 库中导入 beartype
from beartype import beartype

# 导入 torch 库
import torch
# 从 torch.nn.functional 中导入 F
import torch.nn.functional as F
# 从 torch 中导入 nn, einsum
from torch import nn, einsum

# 从 einops 库中导入 rearrange, repeat, pack, unpack
from einops import rearrange, repeat, pack, unpack
# 从 einops.layers.torch 中导入 Rearrange
from einops.layers.torch import Rearrange

# 从 phenaki_pytorch.t5 中导入 t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
from phenaki_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

# 从 phenaki_pytorch.cvivit 中导入 CViViT
from phenaki_pytorch.cvivit import CViViT
# 从 phenaki_pytorch.attention 中导入 Attention, Transformer, ContinuousPositionBias

# helpers

# Helper: check that a value is not None.
def exists(val):
    """Return ``True`` when ``val`` is anything other than ``None``."""
    return False if val is None else True

# Helper: fall back to a default for missing values.
def default(val, d):
    """Return ``val`` when it exists (is not ``None``), otherwise ``d``."""
    if exists(val):
        return val
    return d

# Helper: broadcast a single value to a tuple of a given length.
def cast_tuple(val, length = 1):
    """Return ``val`` unchanged if it is a tuple, else ``val`` repeated ``length`` times."""
    if isinstance(val, tuple):
        return val
    return (val,) * length

# Helper: product of all elements of a sequence.
def reduce_mult(arr):
    """Multiply the elements of ``arr`` together, left to right."""
    def multiply(acc, item):
        return acc * item

    return functools.reduce(multiply, arr)

# Helper: divisibility test.
def divisible_by(numer, denom):
    """Return whether ``numer`` leaves no remainder when divided by ``denom``."""
    remainder = numer % denom
    return remainder == 0

# tensor helpers

# Pick a random subset of positions, sized relative to the tokens in `mask`.
def get_mask_subset_with_prob(mask, prob):
    """Return a boolean ``(batch, seq_len)`` tensor selecting random positions.

    Per row, ``round(prob * number_of_true_entries)`` (at least 1) positions
    are targeted. A random permutation of ``0 .. seq_len - 1`` is shifted down
    by the row's count of false entries; ranks that become negative are set to
    ``seq_len`` so they can never be chosen, and positions whose remaining
    rank is below the target count are selected.
    """
    batch, seq_len = mask.shape
    device = mask.device

    token_counts = mask.sum(dim = -1)
    pad_counts = seq_len - token_counts
    masked_counts = (prob * token_counts).round().clamp(min = 1)

    ranks = torch.rand((batch, seq_len), device = device).argsort(dim = -1)
    ranks -= rearrange(pad_counts, 'b -> b 1')
    # push the shifted-out ranks out of bounds so they are never picked
    ranks.masked_fill_(ranks < 0, seq_len)

    return ranks < rearrange(masked_counts, 'b -> b 1')

# decorators

# Decorator that runs a model method in eval mode.
def eval_decorator(fn):
    """Wrap ``fn(model, ...)`` so that it runs with ``model`` in eval mode.

    The model's previous ``training`` flag is restored afterwards, now also
    when ``fn`` raises (the original left the model stuck in eval mode in that
    case). ``functools.wraps`` keeps the wrapped function's name and docstring.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # always put the model back into the mode it was in
            model.train(was_training)
    return inner

# classifier free guidance functions

# Tensor of uniform random numbers.
def uniform(shape, device):
    """Return a float tensor of ``shape`` on ``device`` filled from ``uniform_(0, 1)``."""
    noise = torch.zeros(shape, device = device).float()
    return noise.uniform_(0, 1)

# Random boolean mask with a given probability of True.
def prob_mask_like(shape, prob, device):
    """Return a boolean tensor of ``shape`` whose entries are true with probability ``prob``.

    ``prob == 1`` and ``prob == 0`` return all-true / all-false masks without
    drawing random numbers.
    """
    if prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)

    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)

    draws = torch.zeros(shape, device = device).float().uniform_(0, 1)
    return draws < prob

# tensor helper functions

# Logarithm with a small offset added to the input.
def log(t, eps = 1e-10):
    """Return ``torch.log(t + eps)``."""
    shifted = t + eps
    return torch.log(shifted)

# sampling helpers

# Gumbel noise shaped like a given tensor.
def gumbel_noise(t):
    """Return ``-log(-log(u))`` for ``u`` drawn from ``uniform_(0, 1)`` with the shape of ``t``."""
    u = torch.zeros_like(t).uniform_(0, 1)
    inner = -log(u)
    return -log(inner)

# Sample indices with the Gumbel-max trick.
def gumbel_sample(t, temperature = 1., dim = -1):
    """Return the argmax along ``dim`` of ``t / temperature`` plus Gumbel noise.

    ``temperature`` is floored at ``1e-10`` to avoid dividing by zero.
    """
    scaled = t / max(temperature, 1e-10)
    perturbed = scaled + gumbel_noise(t)
    return perturbed.argmax(dim = dim)

# Keep only the largest logits along the last dimension.
def top_k(logits, thres = 0.5):
    """Return ``logits`` with everything outside the top-k set to ``-inf``.

    ``k`` is ``int((1 - thres) * num_logits)`` with a minimum of 1, where
    ``num_logits`` is the size of the last dimension.

    The original scattered along a hard-coded dimension 1 although ``topk``
    selects along the last one, so it was only correct for 2-D input. Both
    now use the last dimension, which is identical for 2-D logits and also
    handles 1-D and higher-rank logits.
    """
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k, dim = -1)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(-1, ind, val)
    return probs

# mask git

# MaskGit: transformer over video token ids that predicts token logits.
class MaskGit(nn.Module):
    """Embed video token ids, run them through a ``Transformer`` and project to logits.

    The id ``num_tokens`` is reserved as the mask token (``mask_id``), which
    is why the token embedding has ``num_tokens + 1`` entries. Extra keyword
    arguments are forwarded to ``Transformer``.
    """
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        max_seq_len,
        gradient_shrink_alpha = 0.1,
        heads = 8,
        dim_head = 64,
        unconditional = False,
        attn_dropout = 0.,
        ff_dropout = 0.,
        **kwargs
    # end of the constructor arguments; the body below stores them and builds the sub-modules
    ):
        super().__init__()
        self.dim = dim

        self.mask_id = num_tokens
        self.unconditional = unconditional

        # token embedding with num_tokens + 1 entries; the last id is used as mask_id
        self.token_emb = nn.Embedding(num_tokens + 1, dim)

        self.max_seq_len = max_seq_len
        # learned position embedding, one entry per sequence position
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # factor used in forward() to shrink the gradient reaching the embeddings
        self.gradient_shrink_alpha = gradient_shrink_alpha

        # continuous position bias over 3 dimensions, called with the video patch shape in forward()
        self.continuous_pos_bias = ContinuousPositionBias(dim = dim_head, heads = heads, num_dims = 3)

        # transformer; cross attention is enabled only for the conditional model
        self.transformer = Transformer(
            dim = dim,
            attn_num_null_kv = 2,
            has_cross_attn = not self.unconditional,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            peg = True,
            **kwargs
        )

        # output projection from dim to num_tokens logits
        self.to_logits = nn.Linear(dim, num_tokens)

    # forward pass with conditioning scale (classifier-free-guidance style mixing)
    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 3,
        **kwargs
    ):
        """Return ``null + (cond - null) * cond_scale``.

        ``cond`` is ``forward`` with ``cond_drop_prob = 0`` and ``null`` is
        ``forward`` with ``cond_drop_prob = 1``. With ``cond_scale == 1`` only
        the conditional pass is run.
        """
        # conditional pass: conditioning is never dropped
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)

        if cond_scale == 1:
            return logits

        # null pass: conditioning is always dropped
        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    # forward pass
    def forward(
        self,
        x,
        cond_drop_prob = 0.,
        text_mask = None,
        video_mask = None,
        video_patch_shape = None,
        return_embeds = False,
        **kwargs
    ):
        """Compute logits (or embeddings) for the video token ids ``x``.

        ``x`` is either ``(batch, seq)`` - then ``video_patch_shape`` must be
        given - or ``(batch, frame, height, width)``, in which case the patch
        shape is taken from ``x`` and ``x`` is flattened. Remaining keyword
        arguments are forwarded to the transformer.

        Returns the transformer output when ``return_embeds`` is true,
        otherwise the result of ``to_logits``.
        """
        assert x.ndim in {2, 4}, 'video token ids must be of shape (batch, seq) or (batch, frame, height, width)'

        if x.ndim == 4:
            video_patch_shape = x.shape[1:]
            x = rearrange(x, 'b ... -> b (...)')

        b, n, device = *x.shape, x.device

        # default text mask: all True
        # NOTE(review): it is sized (b, n) with n the video token count, not a text length - confirm against Transformer
        if not exists(text_mask):
            text_mask = torch.ones((b, n), device = device, dtype = torch.bool)

        assert exists(video_patch_shape), 'video patch shape must be given'

        # relative position bias for the given patch shape
        rel_pos_bias = self.continuous_pos_bias(*video_patch_shape, device = device)

        # with probability cond_drop_prob per batch row, blank out that row's whole text mask
        if cond_drop_prob > 0:
            keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device)
            text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask

        video_shape = (b, *video_patch_shape)

        # token embedding
        x = self.token_emb(x)

        # the sequence must fit into the position embedding table
        assert n <= self.max_seq_len, f'the video token sequence length you are passing in ({n}) is greater than the `max_seq_len` ({self.max_seq_len}) set on your `MaskGit`'
        x = self.pos_emb(torch.arange(n, device = device)) + x

        # gradient shrinking: the forward value is unchanged, the gradient is scaled by gradient_shrink_alpha
        x = x * self.gradient_shrink_alpha + x.detach() * (1 - self.gradient_shrink_alpha)

        # transformer forward pass
        x = self.transformer(
            x,
            video_shape = video_shape,
            attn_bias = rel_pos_bias,
            self_attn_mask = video_mask,
            cross_attn_context_mask = text_mask,
            **kwargs
        )

        # return the embeddings directly when requested
        if return_embeds:
            return x

        return self.to_logits(x)
# TokenCritic: transformer that scores every video token with a single logit.
class TokenCritic(nn.Module):
    """Embed video token ids, run a ``Transformer`` and emit one logit per token.

    The id ``num_tokens`` is reserved as the mask token (``mask_id``). Extra
    keyword arguments are forwarded to ``Transformer``.
    """

    def __init__(
        self,
        *,
        dim,  # model dimension
        num_tokens,  # number of regular token ids
        max_seq_len,  # size of the position embedding table
        has_cross_attn = False,  # whether the transformer uses cross attention
        attn_dropout = 0.,  # attention dropout
        ff_dropout = 0.,  # feed-forward dropout
        **kwargs
    ):
        super().__init__()
        self.has_cross_attn = has_cross_attn

        self.mask_id = num_tokens

        # token embedding with num_tokens + 1 entries; the last id is used as mask_id
        self.token_emb = nn.Embedding(num_tokens + 1, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.transformer = Transformer(
            dim = dim,
            peg = True,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            has_cross_attn = has_cross_attn,
            **kwargs
        )

        # project to one value per token and drop the trailing singleton axis
        self.to_logits = nn.Sequential(
            nn.Linear(dim, 1),
            Rearrange('... 1 -> ...')
        )

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 3,
        **kwargs
    ):
        """Return ``null + (cond - null) * cond_scale``.

        ``cond`` is ``forward`` with ``cond_drop_prob = 0`` and ``null`` is
        ``forward`` with ``cond_drop_prob = 1``. With ``cond_scale == 1`` only
        the conditional pass is run.
        """
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)

        if cond_scale == 1:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    def forward(
        self,
        x,
        text_mask = None,
        cond_drop_prob = None,
        context = None,
        video_mask = None,
        video_patch_shape = None,
        **kwargs
    ):
        """Score the video token ids ``x`` and return one logit per token.

        ``cond_drop_prob = None`` (the default) now means "never drop the
        conditioning"; the original compared ``None > 0`` and raised
        ``TypeError`` whenever a ``context`` was passed without an explicit
        probability.
        """
        cond_drop_prob = default(cond_drop_prob, 0.)

        # the video shape is taken from x before it is flattened
        if exists(video_patch_shape):
            video_shape = (x.shape[0], *video_patch_shape)
        else:
            video_shape = x.shape

        x = rearrange(x, 'b ... -> b (...)')
        b, n, device = *x.shape, x.device

        # default text mask: all True
        # NOTE(review): it is sized (b, n) with n the video token count, not a text length - confirm against Transformer
        if not exists(text_mask):
            text_mask = torch.ones((b, n), device = device, dtype = torch.bool)

        # with probability cond_drop_prob per batch row, blank out that row's whole text mask
        if exists(context) and cond_drop_prob > 0:
            keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device)
            text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask

        x = self.token_emb(x)
        x = self.pos_emb(torch.arange(n, device = device)) + x

        x = self.transformer(
            x,
            video_shape = video_shape,
            context = context,
            self_attn_mask = video_mask,
            cross_attn_context_mask = text_mask,
            **kwargs
        )

        return self.to_logits(x)

# SelfCritic: critic that reuses the MaskGit itself as its backbone (inspired by Nijkamp et al.).
@beartype
class SelfCritic(nn.Module):
    """Predict one value per token from the embeddings of a wrapped ``MaskGit``."""

    def __init__(
        self,
        maskgit: MaskGit
    ):
        super().__init__()
        self.maskgit = maskgit

        # linear head on top of the MaskGit embeddings, squeezing the last axis
        head = nn.Linear(maskgit.dim, 1)
        squeeze_last = Rearrange('... 1 -> ...')
        self.to_pred = nn.Sequential(head, squeeze_last)

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 3,
        **kwargs
    ):
        """Mix conditional and null predictions as ``null + (cond - null) * cond_scale``.

        With ``cond_scale == 1`` only the conditional pass is run.
        """
        cond_logits = self.forward(*args, cond_drop_prob = 0., **kwargs)

        if cond_scale == 1:
            return cond_logits

        uncond_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        guidance = (cond_logits - uncond_logits) * cond_scale
        return uncond_logits + guidance

    def forward(self, x, *args, **kwargs):
        """Run the MaskGit with ``return_embeds = True`` and apply the prediction head."""
        hidden = self.maskgit(x, *args, return_embeds = True, **kwargs)
        return self.to_pred(hidden)

# Phenaki: ties together a MaskGit, a CViViT and an optional critic.
@beartype
class Phenaki(nn.Module):
    """Text-conditioned video model built from a ``MaskGit`` and a ``CViViT``.

    The constructor signature in the source was truncated and followed by a
    spurious ``def __init__(self):`` line, so the body referenced undefined
    names; the closing ``):`` is restored here.
    """

    def __init__(
        self,
        *,
        maskgit: MaskGit,
        cvivit: CViViT,
        critic: Optional[Union[TokenCritic, SelfCritic]] = None,  # optional external critic
        steps = 18,  # stored as self.steps
        t5_name = DEFAULT_T5_NAME,  # name of the T5 model used for text encoding
        sample_temperature = 0.,  # stored as self.sample_temperature
        text_embed_dim = None,  # defaults to the encoded dim of the T5 model
        cond_drop_prob = 0.25,  # conditioning drop probability, must be > 0
        max_text_len = 128,  # stored as self.max_text_len
        self_token_critic = False,  # build a SelfCritic on top of maskgit
        critic_loss_weight = 1.,  # stored as self.critic_loss_weight
        critic_noise_anneal_schedule = 'decay',  # stored as self.critic_noise_anneal_schedule
        critic_train_sample_temperature = 1.  # stored as self.critic_train_sample_temperature
    ):
        super().__init__()

        # keep a copy of the cvivit made for evaluation
        self.cvivit = cvivit.copy_for_eval()

        self.maskgit = maskgit
        self.unconditional = maskgit.unconditional

        self.mask_id = maskgit.mask_id

        # a self critic and an external critic are mutually exclusive
        assert not (self_token_critic and exists(critic))

        if self_token_critic:
            critic = SelfCritic(maskgit)

        # the critic, if any, is put into eval mode
        if exists(critic):
            critic = critic.eval()

        # an external critic must use cross attention exactly when the maskgit is conditional
        assert not exists(critic) or self_token_critic or (not maskgit.unconditional) == critic.has_cross_attn

        self.critic = critic
        self.critic_noise_anneal_schedule = critic_noise_anneal_schedule
        self.critic_loss_weight = critic_loss_weight
        self.critic_train_sample_temperature = critic_train_sample_temperature

        self.steps = steps
        self.sample_temperature = sample_temperature

        # text conditioning
        text_embed_dim = default(text_embed_dim, get_encoded_dim(t5_name))
        self.encode_texts = partial(t5_encode_text, name = t5_name)
        self.text_embed_dim = text_embed_dim
        self.max_text_len = max_text_len

        assert cond_drop_prob > 0.
        self.cond_drop_prob = cond_drop_prob # classifier free guidance for transformers - @crowsonkb

    # Sample still images by generating single-frame videos.
    def sample_images(
        self,
        *,
        texts: Union[List[str], str] = None,
        batch_size = 1,
        cond_scale = 3.,
        starting_temperature = 0.9,
        noise_K = 1.
    ):
        """Call ``sample`` with ``num_frames = 1`` and drop the frame axis.

        Two defects are fixed: the einops pattern had no ``->`` output side,
        and ``batch_size`` was accepted but never passed on to ``sample``.
        """
        single_framed_video = self.sample(
            texts = texts,
            num_frames = 1,
            batch_size = batch_size,
            cond_scale = cond_scale,
            starting_temperature = starting_temperature,
            noise_K = noise_K
        )

        # remove the singleton frame dimension
        return rearrange(single_framed_video, '... c 1 h w -> ... c h w')

    # 采样函数
    @eval_decorator
    @torch.no_grad()
    def sample(
        self,
        *,
        num_frames,
        texts: Union[List[str], str] = None,
        prime_frames = None,
        batch_size = 1,
        cond_scale = 3.,
        starting_temperature = 0.9,
        noise_K = 1. # 用于token-critic论文第3.2节中critic分数的噪声超参数,需要找到正确的值
    def forward(
        self,
        videos = None,
        *,
        texts: Optional[List[str]] = None,
        video_codebook_ids = None,
        video_frame_mask = None,
        text_embeds = None,
        cond_drop_prob = None,
        only_train_generator = False,
        only_train_critic = False
# 定义一个名为 make_video 的函数,用于生成视频

@beartype
def make_video(
    phenaki: Phenaki,  # model used to sample every scene
    texts: List[str],  # one prompt per scene
    num_frames,  # frames per scene: a single value or a tuple with one entry per scene
    prime_lengths  # trailing frames of a scene used to prime the next one: a single value or a tuple
):
    """Sample one scene per text, priming each scene with the end of the previous one.

    Returns:
        ``(video, scenes)``: all scenes concatenated along dimension 2, and
        the list of individual scene tensors.

    A prime length of 0 now means "no priming": the original sliced
    ``video[:, :, -0:]``, which selects every frame of the previous scene.
    The unused ``entire_video`` list is removed.
    """
    num_scenes = len(texts)
    num_frames = cast_tuple(num_frames, num_scenes)

    prime_lengths = cast_tuple(prime_lengths, num_scenes - 1)
    # nothing follows the last scene, so its prime length is 0
    prime_lengths = (*prime_lengths, 0)

    video_prime = None
    scenes = []

    for text, scene_num_frames, next_scene_prime_length in zip(texts, num_frames, prime_lengths):
        video = phenaki.sample(texts = text, prime_frames = video_prime, num_frames = scene_num_frames)
        scenes.append(video)

        # keep the last frames of this scene as the prime for the next one
        video_prime = video[:, :, -next_scene_prime_length:] if next_scene_prime_length > 0 else None

    return torch.cat(scenes, dim = 2), scenes

.\lucidrains\phenaki-pytorch\phenaki_pytorch\phenaki_trainer.py

# 导入数学库
import math
# 导入复制库
import copy
# 导入路径库
from pathlib import Path
# 导入随机库
from random import random, choices
# 导入偏函数库
from functools import partial
# 导入命名元组库
from collections import namedtuple
# 导入 CPU 核心数库
from multiprocessing import cpu_count

# 导入 beartype 库
from beartype import beartype
# 导入 beartype.door 库
from beartype.door import is_bearable
# 导入 beartype.vale 库
from beartype.vale import Is
# 导入类型提示库
from typing import Optional, List, Iterable, Tuple
# 导入类型扩展库
from typing_extensions import Annotated

# 导入 PyTorch 库
import torch
# 从 PyTorch 中导入神经网络库和张量乘法库
from torch import nn, einsum
# 从 PyTorch 中导入函数库
import torch.nn.functional as F
# 从 PyTorch 中导入数据集库
from torch.utils.data import Dataset
# 从 PyTorch 中导入优化器库
from torch.optim import Adam

# 从 torchvision 中导入变换库
from torchvision import transforms as T
# 从 torchvision 中导入图像处理库
from torchvision.utils import make_grid, save_image

# 从 einops 中导入重排库和减少库
from einops import rearrange, reduce
# 从 einops.layers.torch 中导入重排层
from einops.layers.torch import Rearrange

# 从 PIL 中导入图像库
from PIL import Image
# 从 tqdm.auto 中导入进度条库
from tqdm.auto import tqdm

# 从 phenaki_pytorch.optimizer 中导入获取优化器函数
from phenaki_pytorch.optimizer import get_optimizer
# 从 accelerate 中导入加速器库
from accelerate import Accelerator

# 从 phenaki_pytorch.phenaki_pytorch 中导入 Phenaki 类
from phenaki_pytorch.phenaki_pytorch import Phenaki

# 从 phenaki_pytorch.data 中导入图像数据集、视频数据集、视频张量转 GIF、数据加载器
from phenaki_pytorch.data import ImageDataset, VideoDataset, video_tensor_to_gif, DataLoader

# 常量

# maps every accepted dataset field name to the beartype annotation used to
# recognise a value of that field
DATASET_FIELD_TYPE_CONFIG = {
    'videos': Annotated[
        torch.Tensor,
        Is[lambda tensor: tensor.dtype == torch.float and tensor.ndim in {4, 5}]
    ],
    'texts': List[str],
    'video_codebook_ids': Annotated[
        torch.Tensor,
        Is[lambda tensor: tensor.dtype == torch.long]
    ],
    'video_frame_mask': Annotated[
        torch.Tensor,
        Is[lambda tensor: tensor.dtype == torch.bool]
    ],
    'text_embeds': Annotated[
        torch.Tensor,
        Is[lambda tensor: tensor.dtype == torch.float and tensor.ndim == 3]
    ],
}

# 辅助函数

def exists(x):
    """Return True unless `x` is None."""
    return not (x is None)

def default(val, d):
    """Return `val` when it is not None, otherwise the fallback `d`.

    A callable `d` is invoked (with no arguments) and its result is returned.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val

def identity(t, *args, **kwargs):
    """Return `t` unchanged; any extra arguments are accepted and ignored."""
    return t

def cycle(dl):
    """Yield the items of `dl` forever, re-iterating it each time it is exhausted."""
    while True:
        yield from dl

def has_int_squareroot(num):
    """Return True when `num` is a perfect square.

    Integers are checked with exact integer arithmetic; the float
    `sqrt(num) ** 2 == num` round trip loses precision for large values and
    rejects genuine perfect squares. Non-integer input keeps the float check.

    Raises:
        ValueError: if `num` is negative.
    """
    if isinstance(num, int):
        return math.isqrt(num) ** 2 == num

    return (math.sqrt(num) ** 2) == num

def num_to_groups(num, divisor):
    """Split `num` into group sizes of `divisor`, plus one smaller trailing group for any remainder."""
    full_groups, leftover = divmod(num, divisor)
    sizes = [divisor] * full_groups
    if leftover > 0:
        sizes.append(leftover)
    return sizes

def elements_to_device_if_tensor(arr, device):
    """Return a list copy of `arr` in which every tensor has been moved to `device`.

    Elements that are not `torch.Tensor` instances are passed through untouched.
    """
    return [
        item.to(device) if isinstance(item, torch.Tensor) else item
        for item in arr
    ]

def split_iterable(it, split_size):
    """Slice the sequence `it` into consecutive chunks of at most `split_size` items."""
    num_chunks = math.ceil(len(it) / split_size)
    starts = (index * split_size for index in range(num_chunks))
    return [it[start: start + split_size] for start in starts]

def split(t, split_size = None):
    """Split `t` into chunks of at most `split_size` along its first dimension.

    Args:
        t: a tensor (split along dim 0) or any other iterable (sliced with
            `split_iterable`).
        split_size: chunk size; when None, `t` is returned as is.

    Returns:
        `t` itself when `split_size` is None, otherwise its chunks.

    Raises:
        TypeError: if `t` is neither a tensor nor an iterable.
    """
    if split_size is None:
        return t

    if isinstance(t, torch.Tensor):
        return t.split(split_size, dim = 0)

    if isinstance(t, Iterable):
        return split_iterable(t, split_size)

    # the exception has to be raised; returning the class would hand callers
    # a `TypeError` object where they expect a sequence of chunks
    raise TypeError(f'cannot split object of type {type(t).__name__}')

def find_first(cond, arr):
    """Return the first element of `arr` for which `cond` is truthy, or None when there is none."""
    return next((item for item in arr if cond(item)), None)

def split_args_and_kwargs(*args, batch_size = None, split_size = None, **kwargs):
    """Yield the given call arguments in aligned chunks of at most `split_size`.

    Tensors and other iterables are chunked with `split`; every other value
    (including None) is repeated once per chunk. When `batch_size` is not
    given it is taken from the length of the first tensor among the arguments.

    Yields:
        `(chunk_fraction, (chunk_args, chunk_kwargs))`, where `chunk_fraction`
        is the size of the chunk divided by `batch_size`.
    """
    values = (*args, *kwargs.values())
    kwarg_names = kwargs.keys()
    # positional values come first in `values`, keyword values after them
    num_positional = len(values) - len(kwargs)

    if not exists(batch_size):
        first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), values)
        assert exists(first_tensor)
        batch_size = len(first_tensor)

    split_size = default(split_size, batch_size)
    num_chunks = math.ceil(batch_size / split_size)

    def chunk(value):
        if exists(value) and isinstance(value, (torch.Tensor, Iterable)):
            return split(value, split_size = split_size)
        return (value,) * num_chunks

    per_value_chunks = [chunk(value) for value in values]
    # chunk sizes are read off the first argument's chunks
    chunk_sizes = tuple(map(len, per_value_chunks[0]))

    for chunk_size, *pieces in tuple(zip(chunk_sizes, *per_value_chunks)):
        chunk_args = pieces[:num_positional]
        chunk_kwargs = dict(tuple(zip(kwarg_names, pieces[num_positional:])))
        yield chunk_size / batch_size, (chunk_args, chunk_kwargs)
def simple_slugify(text, max_length = 255):
    """Turn a caption into a filename-friendly slug.

    Hyphens and spaces become underscores, commas are dropped, `|` becomes
    `--`; leading/trailing `-` and `_` are stripped before the result is
    truncated to `max_length` characters.
    """
    substitutions = str.maketrans({'-': '_', ',': None, ' ': '_', '|': '--'})
    slug = text.translate(substitutions).strip('-_')
    return slug[:max_length]

def has_duplicates(tup):
    """Return True when any element occurs more than once in `tup`."""
    tally = {}
    for item in tup:
        tally[item] = tally.get(item, 0) + 1
    return any(occurrences > 1 for occurrences in tally.values())

def determine_types(data, config):
    """Name each element of `data` after the first `config` entry whose type it satisfies.

    Args:
        data: the elements to classify.
        config: mapping of field name to type annotation, tried in order.

    Returns:
        A tuple with one field name per element of `data`.

    Raises:
        TypeError: if an element matches none of the configured types.
    """
    unmatched = object()
    names = []

    for item in data:
        candidates = (
            name for name, expected_type in config.items()
            if is_bearable(item, expected_type)
        )
        found = next(candidates, unmatched)
        if found is unmatched:
            raise TypeError(f'unable to determine type of {data}')
        names.append(found)

    return tuple(names)

@beartype
class PhenakiTrainer(object):
    """Training loop for a `Phenaki` model, driven through `accelerate`.

    The trainer builds the dataloader and an optimizer over the MaskGit
    parameters, runs gradient-accumulated optimisation steps and, every
    `save_and_sample_every` steps, samples from the model, writes the samples
    under `results_folder` and saves a checkpoint.
    """

    def __init__(
        self,
        phenaki: Phenaki,
        *,
        folder = None,
        train_on_images = False,
        batch_size = 16,
        grad_accum_every = 1,
        num_frames = 17,
        sample_num_frames = None,
        train_lr = 1e-4,
        train_num_steps = 100000,
        max_grad_norm = None,
        ema_update_every = 10,
        ema_decay = 0.995,
        adam_betas = (0.9, 0.99),
        wd = 0,
        save_and_sample_every = 1000,
        num_samples = 25,
        results_folder = './results',
        amp = False,
        fp16 = False,
        split_batches = True,
        convert_image_to = None,
        sample_texts_file_path = None,  # path to a text file with video captions, delimited by newline
        sample_texts: Optional[List[str]] = None,
        dataset: Optional[Dataset] = None,
        dataset_fields: Optional[Tuple[str, ...]] = None
    ):
        """Set up the accelerator, data pipeline, optimizer and results folder.

        Args:
            phenaki: model to train; it must carry a `cvivit`.
            folder: data folder, required when `dataset` is not given.
            train_on_images: train on images instead of videos.
            num_frames: frames per training video (folder dataset only).
            sample_num_frames: frames per sampled video, defaults to `num_frames`.
            num_samples: samples drawn per milestone; must be a perfect square.
            sample_texts_file_path: newline-delimited captions used for sampling;
                takes precedence over `sample_texts`.
            sample_texts: captions used for sampling.
            dataset: ready-made dataset; takes precedence over `folder`.
            dataset_fields: names of the fields each dataset item yields, in
                order; inferred from the first batch when omitted.

        `ema_update_every`, `ema_decay` and `convert_image_to` are accepted but
        not used by this class.
        """
        super().__init__()
        maskgit = phenaki.maskgit
        cvivit = phenaki.cvivit

        assert exists(cvivit), 'cvivit must be present on phenaki'

        # accelerator

        self.accelerator = Accelerator(
            split_batches = split_batches,
            mixed_precision = 'fp16' if fp16 else 'no'
        )

        self.accelerator.native_amp = amp

        # model

        self.model = phenaki

        # the image grid saved at each milestone is square
        assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
        self.unconditional = maskgit.unconditional

        # training related variables

        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every
        self.max_grad_norm = max_grad_norm
        self.train_num_steps = train_num_steps
        self.image_size = cvivit.image_size

        # sampling related variables

        self.num_samples = num_samples
        self.sample_texts = None

        if exists(sample_texts_file_path):
            sample_texts_file_path = Path(sample_texts_file_path)
            assert sample_texts_file_path.exists()
            captions = sample_texts_file_path.read_text().split('\n')
            self.sample_texts = list(filter(len, captions))  # drop empty lines

        # check the argument, not `self.sample_texts`: the attribute was just
        # reset to None above, so testing it would always ignore `sample_texts`
        elif exists(sample_texts):
            self.sample_texts = sample_texts

        assert maskgit.unconditional or exists(self.sample_texts), 'if maskgit is to be trained text conditioned, `sample_texts` List[str] or `sample_texts_file_path` must be given'

        self.save_and_sample_every = save_and_sample_every

        # dataset and dataloader

        self.sample_num_frames = default(sample_num_frames, num_frames)
        self.train_on_images = train_on_images

        # `exists` rather than truthiness, which would go through the dataset's `__len__`
        if exists(dataset):
            self.ds = dataset
        elif train_on_images:
            assert exists(folder)
            self.ds = ImageDataset(folder, self.image_size)
        else:
            assert exists(folder)
            self.ds = VideoDataset(folder, self.image_size, num_frames = num_frames)

        dl = DataLoader(self.ds, batch_size = batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())
        dl = self.accelerator.prepare(dl)
        self.dl = cycle(dl)

        if exists(dataset_fields):
            assert not has_duplicates(dataset_fields), 'dataset fields must not have duplicate field names'
            valid_dataset_fields = set(DATASET_FIELD_TYPE_CONFIG.keys())
            assert len(set(dataset_fields) - valid_dataset_fields) == 0, f'dataset fields must be one of {valid_dataset_fields}'

        self.dataset_fields = dataset_fields

        # optimizer - only the maskgit parameters are optimised

        self.opt = get_optimizer(maskgit.parameters(), lr = train_lr, wd = wd, betas = adam_betas)

        # step counter

        self.step = 0

        # prepare model and optimizer with the accelerator

        self.model, self.opt = self.accelerator.prepare(self.model, self.opt)

        # results folder

        self.results_folder = Path(results_folder)
        self.results_folder.mkdir(parents = True, exist_ok = True)

    def data_tuple_to_kwargs(self, data):
        """Map a batch tuple onto keyword arguments for the model.

        The field names are inferred from the first batch (and cached on
        `self.dataset_fields`) when they were not given at construction time.
        """
        if not exists(self.dataset_fields):
            self.dataset_fields = determine_types(data, DATASET_FIELD_TYPE_CONFIG)
            assert not has_duplicates(self.dataset_fields), 'dataset fields must not have duplicate field names'

        return dict(zip(self.dataset_fields, data))

    def print(self, msg):
        """Print `msg` through the accelerator."""
        self.accelerator.print(msg)

    @property
    def device(self):
        """Device the accelerator runs on."""
        return self.accelerator.device

    @property
    def is_distributed(self):
        """True unless running as a single, non-distributed process."""
        # imported here because the module-level imports only bring in `Accelerator`
        from accelerate import DistributedType
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    @property
    def is_main(self):
        """True on the accelerator's main process."""
        return self.accelerator.is_main_process

    @property
    def is_local_main(self):
        """True on the accelerator's local main process."""
        return self.accelerator.is_local_main_process

    def save(self, milestone):
        """Write a checkpoint to `<results_folder>/model-<milestone>.pt` (local main process only)."""
        if not self.accelerator.is_local_main_process:
            return

        data = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.model),
            'opt': self.opt.state_dict(),
            'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None
        }

        torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))

    def load(self, milestone):
        """Restore model, optimizer, step and (if present) grad scaler from a checkpoint."""
        accelerator = self.accelerator
        device = accelerator.device

        data = torch.load(str(self.results_folder / f'model-{milestone}.pt'), map_location=device)

        model = self.accelerator.unwrap_model(self.model)
        model.load_state_dict(data['model'])

        self.step = data['step']
        self.opt.load_state_dict(data['opt'])

        if exists(self.accelerator.scaler) and exists(data['scaler']):
            self.accelerator.scaler.load_state_dict(data['scaler'])

    def train_step(
        self,
        only_train_generator = False,
        only_train_critic = False
    ):
        """Run one optimisation step and, at milestones, sample and checkpoint.

        Args:
            only_train_generator: forwarded to the model.
            only_train_critic: forwarded to the model.

        Returns:
            The loss of this step, summed over the accumulation rounds (each
            round contributes its loss divided by `grad_accum_every`).
        """
        accelerator = self.accelerator
        device = self.device

        total_loss = 0.

        for _ in range(self.grad_accum_every):
            data = next(self.dl)
            data = elements_to_device_if_tensor(data, device)
            data_kwargs = self.data_tuple_to_kwargs(data)

            assert not (self.train_on_images and data_kwargs['videos'].ndim != 4), 'you have it set to train on images, but the dataset is not returning tensors of 4 dimensions (batch, channels, height, width)'

            with self.accelerator.autocast():
                loss = self.model(**{
                    **data_kwargs,
                    'only_train_generator': only_train_generator,
                    'only_train_critic': only_train_critic
                })

                loss = loss / self.grad_accum_every
                total_loss += loss.item()

            self.accelerator.backward(loss)

        if exists(self.max_grad_norm):
            accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        accelerator.wait_for_everyone()

        self.opt.step()
        self.opt.zero_grad()

        accelerator.wait_for_everyone()

        if self.is_main and self.step % self.save_and_sample_every == 0:
            # remember every submodule's mode so that sampling does not leave
            # the model in eval mode for the rest of training
            training_flags = [(module, module.training) for module in self.model.modules()]
            self.model.eval()
            milestone = self.step // self.save_and_sample_every

            # whether to pass in texts or not

            if not self.unconditional:
                texts = choices(self.sample_texts, k = self.num_samples)
            else:
                texts = (None,) * self.num_samples

            sample_kwargs = {'texts': texts}

            # method to call - `num_frames` is bound for video sampling only;
            # NOTE(review): `sample_images` is assumed not to take `num_frames` - confirm

            if self.train_on_images:
                sample_method = self.model.sample_images
            else:
                sample_method = partial(self.model.sample, num_frames = self.sample_num_frames)

            # evaluate in groups, splitting the kwargs to match

            with torch.no_grad():
                groups = num_to_groups(self.num_samples, self.batch_size)
                args_kwargs_iter = split_args_and_kwargs(batch_size = self.num_samples, split_size = self.batch_size, **sample_kwargs)

                all_sampled = []
                for group_batch_size, (_, (_, kwargs)) in zip(groups, args_kwargs_iter):
                    _kwargs = kwargs if not self.unconditional else dict()
                    sampled = sample_method(batch_size = group_batch_size, **_kwargs)
                    all_sampled.append(sampled)

            # concatenated once, up front, so that both branches below can use it
            all_samples = torch.cat(all_sampled, dim = 0)

            # save videos as gifs, images as one grid

            if not self.train_on_images:
                milestone_folder = self.results_folder / f'videos.{milestone}'
                milestone_folder.mkdir(parents = True, exist_ok = True)

                for ind, (video_tensor, video_caption) in enumerate(zip(all_samples.unbind(dim = 0), texts)):
                    slugged_video_caption = simple_slugify(video_caption) if exists(video_caption) else str(ind)
                    video_tensor_to_gif(video_tensor, str(milestone_folder / f'{slugged_video_caption}.gif'))
            else:
                nrows = int(math.sqrt(self.num_samples))

                sampled_images = all_samples.detach().cpu().float().clamp(0., 1.)
                grid = make_grid(sampled_images, nrow = nrows, normalize = True, value_range = (0, 1))

                save_image(grid, str(self.results_folder / f'{milestone}.png'))

            # save checkpoint

            self.save(milestone)

            # restore the train / eval mode every submodule had before sampling
            for module, was_training in training_flags:
                module.training = was_training

        self.step += 1
        return total_loss

    def train(
        self,
        only_train_generator = False,
        only_train_critic = False
    ):
        """Call `train_step` until `train_num_steps` is reached, showing a progress bar on the main process."""
        with tqdm(
            initial = self.step,
            total = self.train_num_steps,
            disable = not self.is_main
        ) as pbar:
            while self.step < self.train_num_steps:
                loss = self.train_step(
                    only_train_generator = only_train_generator,
                    only_train_critic = only_train_critic
                )
                pbar.set_description(f'loss: {loss:.4f}')
                pbar.update(1)

        self.print('training complete')

.\lucidrains\phenaki-pytorch\phenaki_pytorch\t5.py

# 导入 torch 库
import torch
# 导入 transformers 库
import transformers
# 从 transformers 库中导入 T5Tokenizer, T5EncoderModel, T5Config
from transformers import T5Tokenizer, T5EncoderModel, T5Config

# silence transformers warnings (only errors are logged); only the T5 encoder is used here
transformers.logging.set_verbosity_error()

# helpers
def exists(val):
    """Return True unless `val` is None."""
    return not (val is None)

# configuration
MAX_LENGTH = 256  # `max_length` handed to the tokenizer in t5_encode_text (with truncation on)
DEFAULT_T5_NAME = 'google/t5-v1_1-base'  # default `name` argument of t5_encode_text
T5_CONFIGS = {}  # per-name cache; entries hold "model", "tokenizer" and / or "config"

# 全局单例
def get_tokenizer(name):
    """Load the pretrained T5 tokenizer registered under `name`."""
    return T5Tokenizer.from_pretrained(name)

def get_model(name):
    """Load the pretrained T5 encoder model registered under `name`."""
    return T5EncoderModel.from_pretrained(name)

def get_model_and_tokenizer(name):
    """Return the `(model, tokenizer)` pair for `name`, loading each at most once.

    Both objects are cached in the module-level `T5_CONFIGS` dict.
    """
    entry = T5_CONFIGS.setdefault(name, dict())

    if "model" not in entry:
        entry["model"] = get_model(name)

    if "tokenizer" not in entry:
        entry["tokenizer"] = get_tokenizer(name)

    return entry['model'], entry['tokenizer']

def get_encoded_dim(name):
    """Return `d_model` of the T5 model called `name`.

    The config is read from the `T5_CONFIGS` cache when available (either a
    cached config or the config of a cached model); for a name that has no
    cache entry yet it is loaded with `T5Config.from_pretrained` and cached.

    Raises:
        ValueError: if a cache entry exists but holds neither a config nor a model.
    """
    if name not in T5_CONFIGS:
        config = T5Config.from_pretrained(name)
        T5_CONFIGS[name] = dict(config = config)
        return config.d_model

    entry = T5_CONFIGS[name]

    if "config" in entry:
        return entry["config"].d_model

    if "model" in entry:
        return entry["model"].config.d_model

    raise ValueError(f'unknown t5 name {name}')

def t5_encode_text(
    texts,
    name = DEFAULT_T5_NAME,
    output_device = None
):
    """Encode `texts` with the T5 encoder and zero out the padded positions.

    Args:
        texts: batch of strings handed to the tokenizer.
        name: T5 checkpoint name.
        output_device: device to move the result to; when None the result
            stays on the device the encoder ran on.

    Returns:
        The encoder's last hidden state, with every position that the
        attention mask marks as padding filled with 0.
    """
    t5, tokenizer = get_model_and_tokenizer(name)

    # run on the GPU whenever one is available
    if torch.cuda.is_available():
        t5 = t5.cuda()

    device = next(t5.parameters()).device

    encoded = tokenizer.batch_encode_plus(
        texts,
        return_tensors = 'pt',
        padding = 'longest',
        max_length = MAX_LENGTH,
        truncation = True
    )

    input_ids = encoded.input_ids.to(device)
    attn_mask = encoded.attention_mask.to(device)

    t5.eval()

    with torch.no_grad():
        output = t5(input_ids = input_ids, attention_mask = attn_mask)
        encoded_text = output.last_hidden_state.detach()

    # trailing singleton dim so the mask broadcasts over the feature dimension
    attn_mask = attn_mask[..., None].bool()

    if exists(output_device):
        encoded_text = encoded_text.to(output_device)
        attn_mask = attn_mask.to(output_device)

    return encoded_text.masked_fill(~attn_mask, 0.)

.\lucidrains\phenaki-pytorch\phenaki_pytorch\__init__.py

# 从 phenaki_pytorch 模块中导入 Phenaki, CViViT, MaskGit, TokenCritic, make_video 函数
from phenaki_pytorch.phenaki_pytorch import Phenaki, CViViT, MaskGit, TokenCritic, make_video

# 从 phenaki_pytorch 模块中导入 CViViTTrainer 类
from phenaki_pytorch.cvivit_trainer import CViViTTrainer

# 从 phenaki_pytorch 模块中导入 PhenakiTrainer 类
from phenaki_pytorch.phenaki_trainer import PhenakiTrainer

Phenaki - Pytorch

Implementation of Phenaki Video, which uses Mask GIT to produce text guided videos of up to 2 minutes in length, in Pytorch. It will also combine another technique involving a token critic for potentially even better generations

Please join us on Discord if you are interested in replicating this work in the open

AI Coffeebreak explanation

Appreciation

  • Stability.ai for the generous sponsorship to work on cutting edge artificial intelligence research

  • 🤗 Huggingface for their amazing transformers and accelerate library

  • Guillem for his ongoing contributions

  • You? If you are a great machine learning engineer and / or researcher, feel free to contribute to the frontier of open source generative AI

Install

$ pip install phenaki-pytorch

Usage

C-ViViT

import torch
from phenaki_pytorch import CViViT, CViViTTrainer

cvivit = CViViT(
    dim = 512,
    codebook_size = 65536,
    image_size = 256,
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8
).cuda()

trainer = CViViTTrainer(
    cvivit,
    folder = '/path/to/images/or/videos',
    batch_size = 4,
    grad_accum_every = 4,
    train_on_images = False,  # you can train on images first, before fine tuning on video, for sample efficiency
    use_ema = False,          # recommended to be turned on (keeps exponential moving averaged cvivit) unless you don't have enough resources
    num_train_steps = 10000
)

trainer.train()               # reconstructions and checkpoints will be saved periodically to ./results

Phenaki

import torch
from phenaki_pytorch import CViViT, MaskGit, Phenaki

cvivit = CViViT(
    dim = 512,
    codebook_size = 65536,
    image_size = (256, 128),  # video with rectangular screen allowed
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8
)

cvivit.load('/path/to/trained/cvivit.pt')

maskgit = MaskGit(
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
)

phenaki = Phenaki(
    cvivit = cvivit,
    maskgit = maskgit
).cuda()

videos = torch.randn(3, 3, 17, 256, 128).cuda() # (batch, channels, frames, height, width)
mask = torch.ones((3, 17)).bool().cuda() # [optional] (batch, frames) - allows for co-training videos of different lengths as well as video and images in the same batch

texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
]

loss = phenaki(videos, texts = texts, video_frame_mask = mask)
loss.backward()

# do the above for many steps, then ...

video = phenaki.sample(texts = 'a squirrel examines an acorn', num_frames = 17, cond_scale = 5.) # (1, 3, 17, 256, 128)

# so in the paper, they do not really achieve 2 minutes of coherent video
# at each new scene with new text conditioning, they condition on the previous K frames
# you can easily achieve this with this framework as so

video_prime = video[:, :, -3:] # (1, 3, 3, 256, 128) # say K = 3

video_next = phenaki.sample(texts = 'a cat watches the squirrel from afar', prime_frames = video_prime, num_frames = 14) # (1, 3, 14, 256, 128)

# the total video

entire_video = torch.cat((video, video_next), dim = 2) # (1, 3, 17 + 14, 256, 128)

# and so on...

Or just import the make_video function

# ... above code

from phenaki_pytorch import make_video

entire_video, scenes = make_video(phenaki, texts = [
    'a squirrel examines an acorn buried in the snow',
    'a cat watches the squirrel from a frosted window sill',
    'zoom out to show the entire living room, with the cat residing by the window sill'
], num_frames = (17, 14, 14), prime_lengths = (5, 5))

entire_video.shape # (1, 3, 17 + 14 + 14 = 45, 256, 128)

# scenes - List[Tensor[3]] - video segment of each scene

That's it!

Token Critic

A new paper suggests that instead of relying on the predicted probabilities of each token as a measure of confidence, one can train an extra critic to decide what to iteratively mask during sampling. You can optionally train this critic for potentially better generations as shown below

import torch
from phenaki_pytorch import CViViT, MaskGit, TokenCritic, Phenaki

cvivit = CViViT(
    dim = 512,
    codebook_size = 65536,
    image_size = (256, 128),
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8
)

maskgit = MaskGit(
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
)

# (1) define the critic

critic = TokenCritic(
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
    has_cross_attn = True
)

trainer = Phenaki(
    maskgit = maskgit,
    cvivit = cvivit,
    critic = critic    # and then (2) pass it into Phenaki
).cuda()

texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
]

videos = torch.randn(3, 3, 3, 256, 128).cuda() # (batch, channels, frames, height, width)

loss = trainer(videos = videos, texts = texts)
loss.backward()

Or even simpler, just reuse MaskGit itself as a Self Critic (Nijkamp et al), by setting self_token_critic = True on the initialization of Phenaki

phenaki = Phenaki(
    ...,
    self_token_critic= True  # set this to True
)

Now your generations should be greatly improved!

Phenaki Trainer

This repository will also endeavor to allow the researcher to train on text-to-image and then text-to-video. Similarly, for unconditional training, the researcher should be able to first train on images and then fine tune on video. Below is an example for text-to-video

import torch
from torch.utils.data import Dataset
from phenaki_pytorch import CViViT, MaskGit, Phenaki, PhenakiTrainer

cvivit = CViViT(
    dim = 512,
    codebook_size = 65536,
    image_size = 256,
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8
)

cvivit.load('/path/to/trained/cvivit.pt')

maskgit = MaskGit(
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
    unconditional = False
)

phenaki = Phenaki(
    cvivit = cvivit,
    maskgit = maskgit
).cuda()

# mock text video dataset
# you will have to extend your own, and return the (<video tensor>, <caption>) tuple

class MockTextVideoDataset(Dataset):
    """Stand-in dataset that yields random videos paired with a fixed caption."""

    def __init__(
        self,
        length = 100,
        image_size = 256,
        num_frames = 17
    ):
        super().__init__()
        self.len = length
        self.image_size = image_size
        self.num_frames = num_frames

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        # (channels, frames, height, width) - a fresh random video on every access
        video_shape = (3, self.num_frames, self.image_size, self.image_size)
        return torch.randn(*video_shape), 'video caption'

dataset = MockTextVideoDataset()

# pass in the dataset

trainer = PhenakiTrainer(
    phenaki = phenaki,
    batch_size = 4,
    grad_accum_every = 4,
    train_on_images = False, # if your mock dataset above returns (images, caption) pairs, set this to True
    dataset = dataset,       # pass in your dataset here
    sample_texts_file_path = '/path/to/captions.txt' # each caption should be on a new line, during sampling, will be randomly drawn
)

trainer.train()

Unconditional is as follows

ex. unconditional images and video training

import torch
from phenaki_pytorch import CViViT, MaskGit, Phenaki, PhenakiTrainer

cvivit = CViViT(
    dim = 512,
    codebook_size = 65536,
    image_size = 256,
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8
)

cvivit.load('/path/to/trained/cvivit.pt')

maskgit = MaskGit(
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
    unconditional = True
)

phenaki = Phenaki(
    cvivit = cvivit,
    maskgit = maskgit
).cuda()

# pass in the folder to images or video

trainer = PhenakiTrainer(
    phenaki = phenaki,
    batch_size = 4,
    grad_accum_every = 4,
    train_on_images = True,                # for sake of example, bottom is folder of images
    folder = '/path/to/images/or/video'
)

trainer.train()

Todo

  • pass mask probability into maskgit and auto-mask and get cross entropy loss

  • cross attention + get t5 embeddings code from imagen-pytorch and get classifier free guidance wired up

  • wire up full vqgan-vae for c-vivit, just take what is in parti-pytorch already, but make sure to use a stylegan discriminator as said in paper

  • complete token critic training code

  • complete first pass of maskgit scheduled sampling + token critic (optionally without if researcher does not want to do extra training)

  • inference code that allows for sliding time + conditioning on K past frames

  • alibi pos bias for temporal attention

  • give spatial attention the most powerful positional bias

  • make sure to use stylegan-esque discriminator

  • 3d relative positional bias for maskgit

  • make sure maskgit can also support training of images, and make sure it works on local machine

  • also build option for token critic to be conditioned with the text

  • should be able to train for text to image generation first

  • make sure critic trainer can take in cvivit and automatically pass in video patch shape for relative positional bias - make sure critic also gets optimal relative positional bias

  • training code for cvivit

  • move cvivit into own file

  • unconditional generations (both video and images)

  • wire up accelerate for multi-gpu training for both c-vivit and maskgit

  • add depthwise-convs to cvivit for position generating

  • some basic video manipulation code, allow for sampled tensor to be saved as gif

  • basic critic training code

  • add position generating dsconv to maskgit too

  • outfit customizable self attention blocks to stylegan discriminator

  • add all top of the line research for stabilizing transformers training

  • get some basic critic sampling code, show comparison of with and without critic

  • bring in concatenative token shift (temporal dimension)

  • add a DDPM upsampler, either port from imagen-pytorch or just rewrite a simple version here

  • take care of masking in maskgit

  • test maskgit + critic alone on oxford flowers dataset

  • support rectangular sized videos

  • add flash attention as an option for all transformers and cite @tridao

Citations

@article{Villegas2022PhenakiVL,
    title   = {Phenaki: Variable Length Video Generation From Open Domain Textual Description},
    author  = {Ruben Villegas and Mohammad Babaeizadeh and Pieter-Jan Kindermans and Hernan Moraldo and Han Zhang and Mohammad Taghi Saffar and Santiago Castro and Julius Kunze and D. Erhan},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2210.02399}
}
@article{Chang2022MaskGITMG,
    title   = {MaskGIT: Masked Generative Image Transformer},
    author  = {Huiwen Chang and Han Zhang and Lu Jiang and Ce Liu and William T. Freeman},
    journal = {2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    year    = {2022},
    pages   = {11305-11315}
}
@article{Lezama2022ImprovedMI,
    title   = {Improved Masked Image Generation with Token-Critic},
    author  = {Jos{\'e} Lezama and Huiwen Chang and Lu Jiang and Irfan Essa},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2209.04439}
}
@misc{ding2021cogview,
    title   = {CogView: Mastering Text-to-Image Generation via Transformers},
    author  = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
    year    = {2021},
    eprint  = {2105.13290},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{press2021ALiBi,
    title   = {Train Short, Test Long: Attention with Linear Biases Enable Input Length Extrapolation},
    author  = {Ofir Press and Noah A. Smith and Mike Lewis},
    year    = {2021},
    url     = {https://ofir.io/train_short_test_long.pdf}
}
@article{Liu2022SwinTV,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    journal = {2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    year    = {2022},
    pages   = {11999-12009}
}
@inproceedings{Nijkamp2021SCRIPTSP,
    title   = {SCRIPT: Self-Critic PreTraining of Transformers},
    author  = {Erik Nijkamp and Bo Pang and Ying Nian Wu and Caiming Xiong},
    booktitle = {North American Chapter of the Association for Computational Linguistics},
    year    = {2021}
}
@misc{https://doi.org/10.48550/arxiv.2302.01327,
    doi     = {10.48550/ARXIV.2302.01327},
    url     = {https://arxiv.org/abs/2302.01327},
    author  = {Kumar, Manoj and Dehghani, Mostafa and Houlsby, Neil},
    title   = {Dual PatchNorm},
    publisher = {arXiv},
    year    = {2023},
    copyright = {Creative Commons Attribution 4.0 International}
}
@misc{gilmer2023intriguing,
    title  = {Intriguing Properties of Transformer Training Instabilities},
    author = {Justin Gilmer and Andrea Schioppa and Jeremy Cohen},
    year   = {2023},
    status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
}
@misc{mentzer2023finite,
    title   = {Finite Scalar Quantization: VQ-VAE Made Simple},
    author  = {Fabian Mentzer and David Minnen and Eirikur Agustsson and Michael Tschannen},
    year    = {2023},
    eprint  = {2309.15505},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{yu2023language,
    title   = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
    author  = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
    year    = {2023},
    eprint  = {2310.05737},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}