Lucidrains 系列项目源码解析(五十三)
.\lucidrains\lumiere-pytorch\lumiere_pytorch\mp_lumiere.py
# 导入所需的库
from math import sqrt
from functools import partial
import torch
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F
# 导入 beartype 库,用于类型注解
from beartype import beartype
from beartype.typing import List, Tuple, Optional
# 导入 einops 库,用于操作张量
from einops import rearrange, pack, unpack, repeat
# 导入 lumiere 库中的函数
from lumiere_pytorch.lumiere import (
image_or_video_to_time,
handle_maybe_channel_last,
Lumiere
)
# 定义一些辅助函数
# 判断变量是否存在
def exists(v):
return v is not None
# 如果变量存在则返回变量,否则返回默认值
def default(v, d):
return v if exists(v) else d
# 将张量打包成指定模式的形状
def pack_one(t, pattern):
return pack([t], pattern)
# 将打包后的张量解包成指定模式的形状
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
# 压缩字典中值不存在的键值对
def compact_values(d: dict):
return {k: v for k, v in d.items() if exists(v)}
# 计算 L2 范数
def l2norm(t, dim = -1, eps = 1e-12):
return F.normalize(t, dim = dim, eps = eps)
# 对权重进行归一化处理
def normalize_weight(weight, eps = 1e-4):
weight, ps = pack_one(weight, 'o *')
normed_weight = l2norm(weight, eps = eps)
normed_weight = normed_weight * sqrt(weight.numel() / weight.shape[0])
return unpack_one(normed_weight, ps, 'o *')
# 在一维上进行插值
def interpolate_1d(x, length, mode = 'bilinear'):
x = rearrange(x, 'b c t -> b c t 1')
x = F.interpolate(x, (length, 1), mode = mode)
return rearrange(x, 'b c t 1 -> b c t')
# MP 激活函数
class MPSiLU(Module):
def forward(self, x):
return F.silu(x) / 0.596
# 增益 - 层缩放
class Gain(Module):
def __init__(self):
super().__init__()
self.gain = nn.Parameter(torch.tensor(0.))
def forward(self, x):
return x * self.gain
# MP 线性层
class Linear(Module):
def __init__(self, dim_in, dim_out, eps = 1e-4):
super().__init__()
weight = torch.randn(dim_out, dim_in)
self.weight = nn.Parameter(weight)
self.eps = eps
self.fan_in = dim_in
def forward(self, x):
if self.training:
with torch.no_grad():
normed_weight = normalize_weight(self.weight, eps = self.eps)
self.weight.copy_(normed_weight)
weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)
return F.linear(x, weight)
# 强制权重归一化的卷积层和线性层
class Conv2d(Module):
def __init__(
self,
dim_in,
dim_out,
kernel_size,
eps = 1e-4
):
super().__init__()
weight = torch.randn(dim_out, dim_in, kernel_size, kernel_size)
self.weight = nn.Parameter(weight)
self.eps = eps
self.fan_in = dim_in * kernel_size ** 2
def forward(self, x):
if self.training:
with torch.no_grad():
normed_weight = normalize_weight(self.weight, eps = self.eps)
self.weight.copy_(normed_weight)
weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)
return F.conv2d(x, weight, padding = 'same')
class Conv1d(Module):
def __init__(
self,
dim_in,
dim_out,
kernel_size,
eps = 1e-4,
init_dirac = False
):
super().__init__()
weight = torch.randn(dim_out, dim_in, kernel_size)
self.weight = nn.Parameter(weight)
if init_dirac:
nn.init.dirac_(self.weight)
self.eps = eps
self.fan_in = dim_in * kernel_size
def forward(self, x):
if self.training:
with torch.no_grad():
normed_weight = normalize_weight(self.weight, eps = self.eps)
self.weight.copy_(normed_weight)
weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)
return F.conv1d(x, weight, padding = 'same')
# 像素归一化层
class PixelNorm(Module):
# 初始化函数,设置维度和epsilon值
def __init__(self, dim, eps = 1e-4):
# 调用父类的初始化函数
super().__init__()
# 设置像素规范化的高epsilon值
self.dim = dim
self.eps = eps
# 前向传播函数
def forward(self, x):
# 获取维度
dim = self.dim
# 返回经过L2范数规范化后的结果乘以维度的平方根
return l2norm(x, dim = dim, eps = self.eps) * sqrt(x.shape[dim])
# 定义一个类,实现magnitude preserving sum的功能
# t的值根据经验设定为0.3,用于encoder/decoder/attention residuals和embedding
class MPAdd(Module):
def __init__(self, t):
super().__init__()
self.t = t
# 实现前向传播功能
def forward(self, x, res):
a, b, t = x, res, self.t
num = a * (1. - t) + b * t
den = sqrt((1 - t) ** 2 + t ** 2)
return num / den
# 定义一个类,实现mp attention的功能
class MPAttention(Module):
def __init__(
self,
dim,
heads = 4,
dim_head = 64,
num_mem_kv = 4,
mp_add_t = 0.3,
dropout = 0.
):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.scale = dim_head ** -0.5
self.pixel_norm = PixelNorm(dim = -1)
self.dropout = nn.Dropout(dropout)
self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head))
self.to_qkv = Linear(dim, hidden_dim * 3)
self.to_out = Linear(hidden_dim, dim)
self.mp_add = MPAdd(t = mp_add_t)
# 实现前向传播功能
def forward(self, x):
res, b = x, x.shape[0]
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), self.mem_kv)
k, v = map(partial(torch.cat, dim = -2), ((mk, k), (mv, v)))
q, k, v = map(self.pixel_norm, (q, k, v))
q = q * self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return self.mp_add(out, res)
# 定义一个类,实现时间维度的下采样
class MPTemporalDownsample(Module):
def __init__(
self,
dim,
channel_last = False,
time_dim = None
):
super().__init__()
self.time_dim = time_dim
self.channel_last = channel_last
self.conv = Conv1d(dim, dim, 3, init_dirac = True)
# 实现前向传播功能
@handle_maybe_channel_last
@image_or_video_to_time
def forward(
self,
x
):
t = x.shape[-1]
assert t > 1, 'time dimension must be greater than 1 to be compressed'
x = interpolate_1d(x, t // 2)
return self.conv(x)
# 定义一个类,实现时间维度的上采样
class MPTemporalUpsample(Module):
def __init__(
self,
dim,
channel_last = False,
time_dim = None
):
super().__init__()
self.time_dim = time_dim
self.channel_last = channel_last
self.conv = Conv1d(dim, dim, 3, init_dirac = True)
# 实现前向传播功能
@handle_maybe_channel_last
@image_or_video_to_time
def forward(
self,
x
):
t = x.shape[-1]
x = interpolate_1d(x, t * 2)
return self.conv(x)
# 定义一个类,实现MP卷积膨胀块的功能
class MPConvolutionInflationBlock(Module):
def __init__(
self,
*,
dim,
conv2d_kernel_size = 3,
conv1d_kernel_size = 3,
channel_last = False,
time_dim = None,
mp_add_t = 0.3,
dropout = 0.
):
super().__init__()
self.time_dim = time_dim
self.channel_last = channel_last
self.spatial_conv = nn.Sequential(
Conv2d(dim, dim, conv2d_kernel_size, 3),
MPSiLU()
)
self.temporal_conv = nn.Sequential(
Conv1d(dim, dim, conv1d_kernel_size, 3),
MPSiLU(),
nn.Dropout(dropout)
)
self.proj_out = nn.Sequential(
Conv1d(dim, dim, 1),
Gain()
)
self.residual_mp_add = MPAdd(t = mp_add_t)
# 实现前向传播功能
@handle_maybe_channel_last
def forward(
self,
x,
batch_size = None
):
# 将输入赋值给残差变量
residual = x
# 判断输入是否为视频,判断输入的维度是否为5
is_video = x.ndim == 5
# 如果是视频
if is_video:
# 获取批量大小
batch_size = x.shape[0]
# 重新排列输入数据的维度
x = rearrange(x, 'b c t h w -> (b t) c h w')
# 对输入进行空间卷积
x = self.spatial_conv(x)
# 重新排列参数
rearrange_kwargs = compact_values(dict(b = batch_size, t = self.time_dim))
# 断言重新排列参数的长度大于0
assert len(rearrange_kwargs) > 0, 'either batch_size is passed in on forward, or time_dim is set on init'
# 重新排列输入数据的维度
x = rearrange(x, '(b t) c h w -> b h w c t', **rearrange_kwargs)
# 打包输入数据
x, ps = pack_one(x, '* c t')
# 对输入进行时间卷积
x = self.temporal_conv(x)
# 对输入进行投影输出
x = self.proj_out(x)
# 解包输入数据
x = unpack_one(x, ps, '* c t')
# 如果是视频
if is_video:
# 重新排列输入数据的维度
x = rearrange(x, 'b h w c t -> b c t h w')
else:
# 重新排列输入数据的维度
x = rearrange(x, 'b h w c t -> (b t) c h w')
# 返回残差模块添加后的结果
return self.residual_mp_add(x, residual)
# 定义一个多头注意力膨胀块类,继承自 Module 类
class MPAttentionInflationBlock(Module):
# 初始化函数
def __init__(
self,
*,
dim, # 维度
depth = 1, # 层数,默认为1
time_dim = None, # 时间维度,默认为None
channel_last = False, # 是否通道在最后,默认为False
mp_add_t = 0.3, # MP 添加时间,默认为0.3
dropout = 0., # 丢弃率,默认为0
**attn_kwargs # 其他注意力参数
):
super().__init__()
self.time_dim = time_dim # 初始化时间维度
self.channel_last = channel_last # 初始化通道在最后
self.temporal_attns = ModuleList([]) # 初始化时间注意力模块列表
# 循环创建指定层数的多头注意力模块
for _ in range(depth):
attn = MPAttention(
dim = dim,
dropout = dropout,
**attn_kwargs
)
self.temporal_attns.append(attn) # 将创建的多头注意力模块添加到列表中
# 定义输出投影层
self.proj_out = nn.Sequential(
Linear(dim, dim), # 线性层
Gain() # 增益层
)
# 定义残差 MP 添加层
self.residual_mp_add = MPAdd(t = mp_add_t)
# 前向传播函数
@handle_maybe_channel_last
def forward(
self,
x, # 输入张量
batch_size = None # 批量大小,默认为None
):
is_video = x.ndim == 5 # 判断是否为视频数据
assert is_video ^ (exists(batch_size) or exists(self.time_dim)), 'either a tensor of shape (batch, channels, time, height, width) is passed in, or (batch * time, channels, height, width) along with `batch_size`'
if self.channel_last:
x = rearrange(x, 'b ... c -> b c ...') # 重新排列张量维度
if is_video:
batch_size = x.shape[0] # 获取批量大小
x = rearrange(x, 'b c t h w -> b h w t c') # 重新排列张量维度
else:
assert exists(batch_size) or exists(self.time_dim) # 断言批量大小或时间维度存在
rearrange_kwargs = dict(b = batch_size, t = self.time_dim)
x = rearrange(x, '(b t) c h w -> b h w t c', **compact_values(rearrange_kwargs)) # 重新排列张量维度
x, ps = pack_one(x, '* t c') # 打包张量
residual = x # 保存残差
# 遍历时间注意力模块列表
for attn in self.temporal_attns:
x = attn(x) # 多头注意��操作
x = self.proj_out(x) # 投影输出
x = self.residual_mp_add(x, residual) # 残差 MP 添加
x = unpack_one(x, ps, '* t c') # 解包张量
if is_video:
x = rearrange(x, 'b h w t c -> b c t h w') # 重新排列张量维度
else:
x = rearrange(x, 'b h w t c -> (b t) c h w') # 重新排列张量维度
if self.channel_last:
x = rearrange(x, 'b c ... -> b ... c') # 重新排列张量维度
return x # 返回结果张量
# MPLumiere 是 Lumiere 的一个部分,包含四个 MP 时间模块
MPLumiere = partial(
Lumiere,
conv_klass = MPConvolutionInflationBlock, # 卷积类
attn_klass = MPAttentionInflationBlock, # 注意力类
downsample_klass = MPTemporalDownsample, # 下采样类
upsample_klass = MPTemporalUpsample # 上采样类
)
.\lucidrains\lumiere-pytorch\lumiere_pytorch\__init__.py
# 从lumiere_pytorch.lumiere模块中导入ConvolutionInflationBlock、AttentionInflationBlock、TemporalDownsample、TemporalUpsample、set_time_dim_函数
from lumiere_pytorch.lumiere import (
ConvolutionInflationBlock,
AttentionInflationBlock,
TemporalDownsample,
TemporalUpsample,
set_time_dim_
)
# 从lumiere_pytorch.lumiere模块中导入Lumiere类
from lumiere_pytorch.lumiere import Lumiere
# 从lumiere_pytorch.mp_lumiere模块中导入MPLumiere、MPConvolutionInflationBlock、MPAttentionInflationBlock、MPTemporalDownsample、MPTemporalUpsample类
from lumiere_pytorch.mp_lumiere import (
MPLumiere,
MPConvolutionInflationBlock,
MPAttentionInflationBlock,
MPTemporalDownsample,
MPTemporalUpsample,
)

Lumiere - Pytorch
Implementation of Lumiere, SOTA text-to-video generation from Google Deepmind, in Pytorch
Since this paper is mostly just a few key ideas on top of text-to-image model, will take it a step further and extend the new Karras U-net to video within this repository.
Appreciation
- A16Z Open Source AI Grant Program and 🤗 Huggingface for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research
Install
$ pip install lumiere-pytorch
Usage
import torch
from lumiere_pytorch import MPLumiere
from denoising_diffusion_pytorch import KarrasUnet
karras_unet = KarrasUnet(
image_size = 256,
dim = 8,
channels = 3,
dim_max = 768
)
lumiere = MPLumiere(
karras_unet,
image_size = 256,
unet_time_kwarg = 'time',
conv_module_names = [
'downs.1',
'ups.1'
],
attn_module_names = [
'mids.0'
],
upsample_module_names = [
'ups.1'
],
downsample_module_names = [
'downs.1'
]
)
noised_video = torch.randn(2, 3, 8, 256, 256)
time = torch.ones(2,)
denoised_video = lumiere(noised_video, time = time)
assert noised_video.shape == denoised_video.shape
Todo
-
add all temporal layers
- researcher must pass in all layers for
- conv inflation modules (stages)
- attn inflation modules (middle)
- temporal downsample
- temporal upsamples
- validate time dimension is 2 ** downsample layers
- validate number of downsamples == upsamples
- at init, do a dry run with a mock tensor and assert output is the same
- researcher must pass in all layers for
-
expose only temporal parameters for learning, freeze everything else
-
figure out the best way to deal with the time conditioning after temporal downsampling - instead of pytree transform at the beginning, probably will need to hook into all the modules and inspect the batch sizes
-
handle middle modules that may have output shape as
(batch, seq, dim) -
following the conclusions of Tero Karras, improvise a variant of the 4 modules with magnitude preservation
-
test out on imagen-pytorch
-
look into multi-diffusion and see if it can turned into some simple wrapper
Citations
@inproceedings{BarTal2024LumiereAS,
title = {Lumiere: A Space-Time Diffusion Model for Video Generation},
author = {Omer Bar-Tal and Hila Chefer and Omer Tov and Charles Herrmann and Roni Paiss and Shiran Zada and Ariel Ephrat and Junhwa Hur and Yuanzhen Li and Tomer Michaeli and Oliver Wang and Deqing Sun and Tali Dekel and Inbar Mosseri},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:267095113}
}
@article{Karras2023AnalyzingAI,
title = {Analyzing and Improving the Training Dynamics of Diffusion Models},
author = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine},
journal = {ArXiv},
year = {2023},
volume = {abs/2312.02696},
url = {https://api.semanticscholar.org/CorpusID:265659032}
}
.\lucidrains\lumiere-pytorch\setup.py
# 导入设置安装和查找包的函数
from setuptools import setup, find_packages
# 设置包的信息
setup(
# 包名
name = 'lumiere-pytorch',
# 查找所有包,不排除任何包
packages = find_packages(exclude=[]),
# 版本号
version = '0.0.20',
# 许可证
license='MIT',
# 描述
description = 'Lumiere',
# 作者
author = 'Phil Wang',
# 作者邮箱
author_email = 'lucidrains@gmail.com',
# 长描述内容类型
long_description_content_type = 'text/markdown',
# 项目链接
url = 'https://github.com/lucidrains/lumiere-pytorch',
# 关键词
keywords = [
'artificial intelligence',
'deep learning',
'text-to-video'
],
# 安装依赖
install_requires=[
'beartype',
'einops>=0.7.0',
'optree',
'torch>=2.0',
'x-transformers'
],
# 分类
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\magvit2-pytorch\magvit2_pytorch\attend.py
# 导入所需模块和库
from functools import partial
from typing import Optional, Tuple
import torch
from torch import nn, einsum, Tensor
import torch.nn.functional as F
from collections import namedtuple
from functools import wraps
from packaging import version
from dataclasses import dataclass
from einops import rearrange, repeat
# 定义一个命名元组EfficientAttentionConfig,包含三个布尔类型的参数
EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
# 辅助函数
# 判断变量是否存在
def exists(val):
return val is not None
# 如果变量存在则返回其值,否则返回默认值
def default(val, d):
return val if exists(val) else d
# 过滤掉列表中的空值
def compact(arr):
return [*filter(exists, arr)]
# 保证函数只执行一次
def once(fn):
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)
return inner
# 打印函数的输出,确保只打印一次
print_once = once(print)
# 用于创建因果掩码的函数
# 针对onnx cpu需要特殊处理(不支持.triu)
# 创建因果掩码
def create_causal_mask(i, j, device):
return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
# 针对onnx创建因果掩码
def onnx_create_causal_mask(i, j, device):
r = torch.arange(i, device = device)
causal_mask = rearrange(r, 'i -> i 1') < rearrange(r, 'j -> 1 j')
causal_mask = F.pad(causal_mask, (j - i, 0), value = False)
return causal_mask
# 主类
class Attend(nn.Module):
def __init__(
self,
*,
dropout = 0.,
causal = False,
heads = None,
scale = None,
flash = False,
onnxable = False,
sdp_kwargs: dict = dict(
enable_flash = True,
enable_math = True,
enable_mem_efficient = True
)
):
super().__init__()
self.scale = scale
self.causal = causal
self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask
self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)
# flash attention
# 检查是否支持flash attention
self.flash = flash and torch.cuda.is_available()
assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
self.sdp_kwargs = sdp_kwargs
def flash_attn(
self,
q, k, v,
mask = None,
attn_bias = None
):
# 解包输入张量的形状和其他属性
batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
# 使输入张量连续
q, k, v = map(lambda t: t.contiguous(), (q, k, v))
# 处理缩放,因为在 sdp 中缩放不可定制,对其进行处理
if exists(self.scale):
q = q * self.scale / (q.shape[-1] ** -0.5)
# 检查是否存在 mask 并扩展到兼容的形状
causal = self.causal
# 如果 q_len == 1 且 causal 为真,则将 causal 设置为 False
if q_len == 1 and causal:
causal = False
# 扩展键填充 mask
if exists(mask):
assert mask.ndim == 4
mask = mask.expand(batch, heads, q_len, k_len)
# 处理 kv 缓存
if k_len > q_len and causal:
causal_mask = self.create_causal_mask(q_len, k_len, device=device)
if not exists(mask):
mask = ~causal_mask
else:
mask = mask & ~causal_mask
causal = False
# 手动处理 causal mask,如果给定了另一个 mask
row_is_entirely_masked = None
if exists(mask) and causal:
causal_mask = self.create_causal_mask(q_len, k_len, device=device)
mask = mask & ~causal_mask
# 防止整行被屏蔽
row_is_entirely_masked = ~mask.any(dim=-1)
mask[..., 0] = mask[..., 0] | row_is_entirely_masked
causal = False
# 处理 alibi 位置偏差,将 bool 转换为 float
if exists(attn_bias):
attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)
mask_value = -torch.finfo(q.dtype).max
if exists(mask):
attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
elif causal:
causal_mask = self.create_causal_mask(q_len, k_len, device=device)
attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
causal = False
mask = attn_bias
# 使用 scaled_dot_product_attention 处理注意力
with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask=mask,
dropout_p=self.dropout if self.training else 0.,
is_causal=causal
)
# 对于整行被完全屏蔽的情况,将输出的该行标记为 0
if exists(row_is_entirely_masked):
out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
return out
# 前向传播函数
def forward(
self,
q, k, v,
mask=None,
attn_bias=None,
prev_attn=None
):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (base sequence length, source, target)
d - feature dimension
"""
# 获取输入张量的形状信息
n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device
# 计算缩放因子
scale = default(self.scale, q.shape[-1] ** -0.5)
# 获取是否为因果注意力的标志
causal = self.causal
# 处理缓存的键值对解码
if n == 1 and causal:
causal = False
# 处理零键值对,允许网络关注空内容
if self.flash:
assert not exists(prev_attn), 'residual attention not compatible with flash attention'
return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
# 计算点积注意力得分
dots = einsum(f'b h i d, b h j d -> b h i j', q, k) * scale
# 如果存在先前的注意力,加上先前的注意力得分
if exists(prev_attn):
dots = dots + prev_attn
# 如果存在注意力偏置,加上注意力偏置
if exists(attn_bias):
dots = dots + attn_bias
# 获取点积张量的形状信息和数据类型
i, j, dtype = *dots.shape[-2:], dots.dtype
# 定义掩码值
mask_value = -torch.finfo(dots.dtype).max
# 如果存在掩码,用掩码值填充不需要关注的位置
if exists(mask):
dots = dots.masked_fill(~mask, mask_value)
# 如果是因果注意力,创建因果掩码并用掩码值填充
if causal:
causal_mask = self.create_causal_mask(i, j, device = device)
dots = dots.masked_fill(causal_mask, mask_value)
# 计算注意力权重
attn = dots.softmax(dim = -1)
# 对注意力权重进行dropout
attn = self.attn_dropout(attn)
# 计算输出
out = einsum(f'b h i j, b h j d -> b h i d', attn, v)
return out
.\lucidrains\magvit2-pytorch\magvit2_pytorch\data.py
# 导入必要的库
from pathlib import Path
from functools import partial
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader as PytorchDataLoader
import cv2
from PIL import Image
from torchvision import transforms as T, utils
from beartype import beartype
from beartype.typing import Tuple, List
from beartype.door import is_bearable
import numpy as np
from einops import rearrange
# 辅助函数
# 检查值是否存在
def exists(val):
return val is not None
# 返回输入值
def identity(t, *args, **kwargs):
return t
# 将输入值转换为元组
def pair(val):
return val if isinstance(val, tuple) else (val, val)
# 在指定维度上填充张量
def pad_at_dim(t, pad, dim = -1, value = 0.):
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = ((0, 0) * dims_from_right)
return F.pad(t, (*zeros, *pad), value = value)
# 调整张量的帧数
def cast_num_frames(t, *, frames):
f = t.shape[-3]
if f == frames:
return t
if f > frames:
return t[..., :frames, :, :]
return pad_at_dim(t, (0, frames - f), dim = -3)
# 将图像转换为指定格式
def convert_image_to_fn(img_type, image):
if not exists(img_type) or image.mode == img_type:
return image
return image.convert(img_type)
# 如果路径没有后缀,则添加后缀
def append_if_no_suffix(path: str, suffix: str):
path = Path(path)
if path.suffix == '':
path = path.parent / (path.name + suffix)
assert path.suffix == suffix, f'{str(path)} needs to have suffix {suffix}'
return str(path)
# 通道到图像模式的映射
CHANNEL_TO_MODE = {
1: 'L',
3: 'RGB',
4: 'RGBA'
}
# 图像相关的辅助函数和数据集
# 图像数据集类
class ImageDataset(Dataset):
def __init__(
self,
folder,
image_size,
channels = 3,
convert_image_to = None,
exts = ['jpg', 'jpeg', 'png']
):
super().__init__()
folder = Path(folder)
assert folder.is_dir(), f'{str(folder)} must be a folder containing images'
self.folder = folder
self.image_size = image_size
exts = exts + [ext.upper() for ext in exts]
self.paths = [p for ext in exts for p in folder.glob(f'**/*.{ext}')]
print(f'{len(self.paths)} training samples found at {folder}')
if exists(channels) and not exists(convert_image_to):
convert_image_to = CHANNEL_TO_MODE.get(channels)
self.transform = T.Compose([
T.Lambda(partial(convert_image_to_fn, convert_image_to)),
T.Resize(image_size, antialias = True),
T.RandomHorizontalFlip(),
T.CenterCrop(image_size),
T.ToTensor()
])
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
path = self.paths[index]
img = Image.open(path)
return self.transform(img)
# 张量的形状 (channels, frames, height, width) -> gif
# 处理读取和写入 gif
# 逐帧读取图像
def seek_all_images(img: Tensor, channels = 3):
mode = CHANNEL_TO_MODE.get(channels)
assert exists(mode), f'channels {channels} invalid'
i = 0
while True:
try:
img.seek(i)
yield img.convert(mode)
except EOFError:
break
i += 1
# 张量的形状 (channels, frames, height, width) -> gif
# 将视频张量转换为 gif
@beartype
def video_tensor_to_gif(
tensor: Tensor,
path: str,
duration = 120,
loop = 0,
optimize = True
):
path = append_if_no_suffix(path, '.gif')
images = map(T.ToPILImage(), tensor.unbind(dim = 1))
first_img, *rest_imgs = images
first_img.save(str(path), save_all = True, append_images = rest_imgs, duration = duration, loop = loop, optimize = optimize)
return images
# gif -> 张量 (channels, frame, height, width)
# 将 gif 转换为张量
def gif_to_tensor(
path: str,
channels = 3,
transform = T.ToTensor()
):
img = Image.open(path)
tensors = tuple(map(transform, seek_all_images(img, channels = channels)))
return torch.stack(tensors, dim = 1)
# 处理读取和写入 mp4
# 将视频转换为张量
def video_to_tensor(
path: str, # 视频文件的路径,需要导入的视频
num_frames = -1, # 要存储在输出张量中的帧数,默认为-1表示存储所有帧
crop_size = None # 裁剪尺寸,默认为None表示不进行裁剪
# 定义一个函数,将视频文件转换为张量
def video_to_tensor(path: str) -> Tensor: # 返回形状为 (1, 通道数, 帧数, 高度, 宽度) 的张量
# 使用 OpenCV 打开视频文件
video = cv2.VideoCapture(path)
frames = [] # 存储视频帧的列表
check = True
# 循环读取视频帧
while check:
check, frame = video.read()
if not check:
continue
# 如果存在裁剪尺寸,则对帧进行中心裁剪
if exists(crop_size):
frame = crop_center(frame, *pair(crop_size))
# 将帧重新排列为 (1, ...) 的形状并添加到 frames 列表中
frames.append(rearrange(frame, '... -> 1 ...'))
# 将帧列表转换为 numpy 数组,然后合并帧并转换为 numpy 数组
frames = np.array(np.concatenate(frames[:-1], axis=0))
frames = rearrange(frames, 'f h w c -> c f h w')
# 将 numpy 数组转换为 PyTorch 张量并转换为浮点数类型
frames_torch = torch.tensor(frames).float()
# 将张量值归一化到 [0, 1] 范围
frames_torch /= 255.
# 将张量沿着第一个维度翻转,从 BGR 格式转换为 RGB 格式
frames_torch = frames_torch.flip(dims=(0,))
# 返回指定数量的帧数
return frames_torch[:, :num_frames, :, :]
# 定义一个函数,将张量转换为视频文件
@beartype
def tensor_to_video(
tensor: Tensor, # PyTorch 视频张量
path: str, # 要保存的视频路径
fps=25, # 保存视频的帧率
video_format='MP4V' # 视频格式,默认为 MP4
):
# 如果路径没有后缀,则添加 .mp4 后缀
path = append_if_no_suffix(path, '.mp4')
# 将张量移动到 CPU
tensor = tensor.cpu()
# 获取张量的帧数、高度和宽度
num_frames, height, width = tensor.shape[-3:]
# 使用指定的视频格式创建 VideoWriter 对象
fourcc = cv2.VideoWriter_fourcc(*video_format)
video = cv2.VideoWriter(str(path), fourcc, fps, (width, height))
frames = [] # 存储视频帧的列表
# 遍历每一帧,将张量转换为 numpy 数组并写入视频
for idx in range(num_frames):
numpy_frame = tensor[:, idx, :, :].numpy()
numpy_frame = np.uint8(rearrange(numpy_frame, 'c h w -> h w c'))
video.write(numpy_frame)
# 释放 VideoWriter 对象
video.release()
# 关闭所有 OpenCV 窗口
cv2.destroyAllWindows()
return video
# 定义一个函数,对图像进行中心裁剪
def crop_center(
img: Tensor, # 输入图像张���
cropx: int, # 最终图像在 x 方向上的长度
cropy: int # 最终图像在 y 方向上的长度
) -> Tensor: # 返回裁剪后的图像张量
y, x, c = img.shape
startx = x // 2 - cropx // 2
starty = y // 2 - cropy // 2
return img[starty:(starty + cropy), startx:(startx + cropx), :]
# 视频数据集类
class VideoDataset(Dataset):
def __init__(
self,
folder, # 视频文件夹路径
image_size, # 图像尺寸
channels=3, # 通道数,默认为 3
num_frames=17, # 帧数,默认为 17
force_num_frames=True, # 是否强制指定帧数,默认为 True
exts=['gif', 'mp4'] # 视频文件扩展名列表,默认为 ['gif', 'mp4']
):
super().__init__()
folder = Path(folder)
assert folder.is_dir(), f'{str(folder)} must be a folder containing videos'
self.folder = folder
self.image_size = image_size
self.channels = channels
self.paths = [p for ext in exts for p in folder.glob(f'**/*.{ext}')]
print(f'{len(self.paths)} training samples found at {folder}')
# 定义图像转换操作
self.transform = T.Compose([
T.Resize(image_size, antialias=True),
T.CenterCrop(image_size)
])
# 定义将视频路径转换为张量的函数
self.gif_to_tensor = partial(gif_to_tensor, channels=self.channels, transform=self.transform)
self.mp4_to_tensor = partial(video_to_tensor, crop_size=self.image_size)
# 定义将帧数转换为指定数量的函数
self.cast_num_frames_fn = partial(cast_num_frames, frames=num_frames) if force_num_frames else identity
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
path = self.paths[index]
ext = path.suffix
path_str = str(path)
if ext == '.gif':
tensor = self.gif_to_tensor(path_str)
elif ext == '.mp4':
tensor = self.mp4_to_tensor(path_str)
frames = tensor.unbind(dim=1)
tensor = torch.stack([*map(self.transform, frames)], dim=1)
else:
raise ValueError(f'unknown extension {ext}')
return self.cast_num_frames_fn(tensor)
# 重写数据加载器以能够整理张量和字符串
def collate_tensors_and_strings(data):
if is_bearable(data, List[Tensor]):
return (torch.stack(data),)
data = zip(*data)
output = []
# 遍历数据列表中的每个元素
for datum in data:
# 检查数据是否为可接受的类型(元组中包含张量)
if is_bearable(datum, Tuple[Tensor, ...]):
# 如果是,则将张量堆叠成一个张量
datum = torch.stack(datum)
# 检查数据是否为可接受的类型(元组中包含字符串)
elif is_bearable(datum, Tuple[str, ...]):
# 如果是,则将元组转换为列表
datum = list(datum)
else:
# 如果数据类型不符合要求,则引发值错误异常
raise ValueError('detected invalid type being passed from dataset')
# 将处理后的数据添加到输出列表中
output.append(datum)
# 将输出列表转换为元组并返回
return tuple(output)
# 定义一个函数DataLoader,接受任意数量的位置参数和关键字参数
def DataLoader(*args, **kwargs):
# 返回PytorchDataLoader对象,使用指定的参数和自定义的collate函数
return PytorchDataLoader(*args, collate_fn = collate_tensors_and_strings, **kwargs)
.\lucidrains\magvit2-pytorch\magvit2_pytorch\magvit2_pytorch.py
# 导入必要的库
import copy
from pathlib import Path
from math import log2, ceil, sqrt
from functools import wraps, partial
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
from torch.autograd import grad as torch_grad
import torchvision
from torchvision.models import VGG16_Weights
from collections import namedtuple
# 导入自定义模块
from vector_quantize_pytorch import LFQ, FSQ
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
from beartype import beartype
from beartype.typing import Union, Tuple, Optional, List
from magvit2_pytorch.attend import Attend
from magvit2_pytorch.version import __version__
from gateloop_transformer import SimpleGateLoopLayer
from taylor_series_linear_attention import TaylorSeriesLinearAttn
from kornia.filters import filter3d
import pickle
# helper
# 检查变量是否存在
def exists(v):
return v is not None
# 返回默认值
def default(v, d):
return v if exists(v) else d
# 安全获取列表中的元素
def safe_get_index(it, ind, default = None):
if ind < len(it):
return it[ind]
return default
# 将输入转换为元组
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# 返回输入本身
def identity(t, *args, **kwargs):
return t
# 检查一个数是否可以被另一个数整除
def divisible_by(num, den):
return (num % den) == 0
# 将输入打包成指定模式
def pack_one(t, pattern):
return pack([t], pattern)
# 将输入解包成指定模式
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
# 在张量的末尾添加指定维度
def append_dims(t, ndims: int):
return t.reshape(*t.shape, *((1,) * ndims))
# 检查一个数是否为奇数
def is_odd(n):
return not divisible_by(n, 2)
# 删除对象的属性
def maybe_del_attr_(o, attr):
if hasattr(o, attr):
delattr(o, attr)
# 将输入转换为元组
def cast_tuple(t, length = 1):
return t if isinstance(t, tuple) else ((t,) * length)
# tensor helpers
# 对张量进行 L2 归一化
def l2norm(t):
return F.normalize(t, dim = -1, p = 2)
# 在指定维度上对张量进行填充
def pad_at_dim(t, pad, dim = -1, value = 0.):
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = ((0, 0) * dims_from_right)
return F.pad(t, (*zeros, *pad), value = value)
# 从视频中选择指定帧
def pick_video_frame(video, frame_indices):
batch, device = video.shape[0], video.device
video = rearrange(video, 'b c f ... -> b f c ...')
batch_indices = torch.arange(batch, device = device)
batch_indices = rearrange(batch_indices, 'b -> b 1')
images = video[batch_indices, frame_indices]
images = rearrange(images, 'b 1 c ... -> b c ...')
return images
# gan related
# 计算梯度惩罚
def gradient_penalty(images, output):
batch_size = images.shape[0]
gradients = torch_grad(
outputs = output,
inputs = images,
grad_outputs = torch.ones(output.size(), device = images.device),
create_graph = True,
retain_graph = True,
only_inputs = True
)[0]
gradients = rearrange(gradients, 'b ... -> b (...)')
return ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
# Leaky ReLU 激活函数
def leaky_relu(p = 0.1):
return nn.LeakyReLU(p)
# Hinge 损失函数(判别器)
def hinge_discr_loss(fake, real):
return (F.relu(1 + fake) + F.relu(1 - real)).mean()
# Hinge 损失函数(生成器)
def hinge_gen_loss(fake):
return -fake.mean()
# 计算损失对层的梯度
@autocast(enabled = False)
@beartype
def grad_layer_wrt_loss(
loss: Tensor,
layer: nn.Parameter
):
return torch_grad(
outputs = loss,
inputs = layer,
grad_outputs = torch.ones_like(loss),
retain_graph = True
)[0].detach()
# helper decorators
# 移除 VGG 属性
def remove_vgg(fn):
@wraps(fn)
def inner(self, *args, **kwargs):
has_vgg = hasattr(self, 'vgg')
if has_vgg:
vgg = self.vgg
delattr(self, 'vgg')
out = fn(self, *args, **kwargs)
if has_vgg:
self.vgg = vgg
return out
return inner
# helper classes
# 顺序模块
def Sequential(*modules):
modules = [*filter(exists, modules)]
if len(modules) == 0:
return nn.Identity()
return nn.Sequential(*modules)
# 残差模块
class Residual(Module):
@beartype
def __init__(self, fn: Module):
super().__init__()
self.fn = fn
# 定义一个前向传播函数,接受输入 x 和其他关键字参数
def forward(self, x, **kwargs):
# 调用函数 fn 对输入 x 进行处理,并将结果与输入 x 相加后返回
return self.fn(x, **kwargs) + x
# 一系列张量操作,将张量转换为 (batch, time, feature dimension) 格式,然后再转回来
class ToTimeSequence(Module):
@beartype
def __init__(self, fn: Module):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
# 重新排列张量的维度,将其转换为 (batch, ..., feature, channel) 格式
x = rearrange(x, 'b c f ... -> b ... f c')
# 打包张量,将其转换为 (batch, ..., feature, channel) 格式
x, ps = pack_one(x, '* n c')
# 使用给定的函数对张量进行操作
o = self.fn(x, **kwargs)
# 解包张量,将其转换回原始格式
o = unpack_one(o, ps, '* n c')
# 重新排列张量的维度,将其转换回原始格式
return rearrange(o, 'b ... f c -> b c f ...')
class SqueezeExcite(Module):
# 全局上下文网络 - 基于注意力机制的 Squeeze-Excite 变种 (https://arxiv.org/abs/2012.13375)
def __init__(
self,
dim,
*,
dim_out = None,
dim_hidden_min = 16,
init_bias = -10
):
super().__init__()
dim_out = default(dim_out, dim)
# 创建卷积层,用于计算注意力权重
self.to_k = nn.Conv2d(dim, 1, 1)
dim_hidden = max(dim_hidden_min, dim_out // 2)
# 创建包含卷积层和激活函数的网络结构
self.net = nn.Sequential(
nn.Conv2d(dim, dim_hidden, 1),
nn.LeakyReLU(0.1),
nn.Conv2d(dim_hidden, dim_out, 1),
nn.Sigmoid()
)
# 初始化网络参数
nn.init.zeros_(self.net[-2].weight)
nn.init.constant_(self.net[-2].bias, init_bias)
def forward(self, x):
orig_input, batch = x, x.shape[0]
is_video = x.ndim == 5
if is_video:
# 重新排列视频张量的维度
x = rearrange(x, 'b c f h w -> (b f) c h w')
# 计算上下文信息
context = self.to_k(x)
# 计算注意力权重
context = rearrange(context, 'b c h w -> b c (h w)').softmax(dim = -1)
spatial_flattened_input = rearrange(x, 'b c h w -> b c (h w)')
# 使用注意力权重对输入进行加权求和
out = einsum('b i n, b c n -> b c i', context, spatial_flattened_input)
out = rearrange(out, '... -> ... 1')
gates = self.net(out)
if is_video:
# 将结果转换回视频张量的格式
gates = rearrange(gates, '(b f) c h w -> b c f h w', b = batch)
return gates * orig_input
# token shifting
class TokenShift(Module):
@beartype
def __init__(self, fn: Module):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
# 将输入张量分成两部分
x, x_shift = x.chunk(2, dim = 1)
# 在时间维度上进行填充,实现时间维度的位移
x_shift = pad_at_dim(x_shift, (1, -1), dim = 2)
# 将两部分张量连接起来
x = torch.cat((x, x_shift), dim = 1)
return self.fn(x, **kwargs)
# rmsnorm
class RMSNorm(Module):
def __init__(
self,
dim,
channel_first = False,
images = False,
bias = False
):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
def forward(self, x):
# 对输入张量进行 RMS 归一化
return F.normalize(x, dim = (1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class AdaptiveRMSNorm(Module):
def __init__(
self,
dim,
*,
dim_cond,
channel_first = False,
images = False,
bias = False
):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.dim_cond = dim_cond
self.channel_first = channel_first
self.scale = dim ** 0.5
# 创建线性层,用于生成 gamma 和 bias
self.to_gamma = nn.Linear(dim_cond, dim)
self.to_bias = nn.Linear(dim_cond, dim) if bias else None
# 初始化线性层参数
nn.init.zeros_(self.to_gamma.weight)
nn.init.ones_(self.to_gamma.bias)
if bias:
nn.init.zeros_(self.to_bias.weight)
nn.init.zeros_(self.to_bias.bias)
@beartype
# 定义一个前向传播函数,接受输入张量 x 和条件张量 cond
def forward(self, x: Tensor, *, cond: Tensor):
# 获取批量大小
batch = x.shape[0]
# 断言条件张量的形状为 (batch, self.dim_cond)
assert cond.shape == (batch, self.dim_cond)
# 根据条件张量生成 gamma
gamma = self.to_gamma(cond)
# 初始化偏置为 0
bias = 0.
# 如果存在偏置生成函数
if exists(self.to_bias):
# 根据条件张量生成偏置
bias = self.to_bias(cond)
# 如果通道在前
if self.channel_first:
# 在 gamma 的维度前面添加维度,使其与输入张量 x 的维度相同
gamma = append_dims(gamma, x.ndim - 2)
# 如果存在偏置生成函数
if exists(self.to_bias):
# 在偏置的维度前面添加维度,使其与输入张量 x 的维度相同
bias = append_dims(bias, x.ndim - 2)
# 对输入张量 x 进行归一化,根据通道顺序选择归一化的维度,然后乘以缩放因子 scale 和 gamma,最后加上偏置 bias
return F.normalize(x, dim = (1 if self.channel_first else -1)) * self.scale * gamma + bias
# 定义一个名为 Attention 的类,继承自 Module 类
class Attention(Module):
# 初始化函数,接受多个参数
@beartype
def __init__(
self,
*,
dim,
dim_cond: Optional[int] = None,
causal = False,
dim_head = 32,
heads = 8,
flash = False,
dropout = 0.,
num_memory_kv = 4
):
# 调用父类的初始化函数
super().__init__()
# 计算内部维度
dim_inner = dim_head * heads
# 检查是否需要条件
self.need_cond = exists(dim_cond)
# 根据是否需要条件选择不同的归一化方法
if self.need_cond:
self.norm = AdaptiveRMSNorm(dim, dim_cond = dim_cond)
else:
self.norm = RMSNorm(dim)
# 构建 QKV 网络
self.to_qkv = nn.Sequential(
nn.Linear(dim, dim_inner * 3, bias = False),
Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
)
# 断言内存键值对数量大于 0
assert num_memory_kv > 0
# 初始化内存键值对
self.mem_kv = nn.Parameter(torch.randn(2, heads, num_memory_kv, dim_head))
# 构建 Attend 层
self.attend = Attend(
causal = causal,
dropout = dropout,
flash = flash
)
# 构建输出层
self.to_out = nn.Sequential(
Rearrange('b h n d -> b n (h d)'),
nn.Linear(dim_inner, dim, bias = False)
)
# 前向传播函数
@beartype
def forward(
self,
x,
mask: Optional[Tensor ] = None,
cond: Optional[Tensor] = None
):
# 根据是否需要条件选择不同的参数
maybe_cond_kwargs = dict(cond = cond) if self.need_cond else dict()
# 对输入进行归一化
x = self.norm(x, **maybe_cond_kwargs)
# 获取 QKV
q, k, v = self.to_qkv(x)
# 重复内存键值对
mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = q.shape[0]), self.mem_kv)
k = torch.cat((mk, k), dim = -2)
v = torch.cat((mv, v), dim = -2)
# 进行注意力计算
out = self.attend(q, k, v, mask = mask)
return self.to_out(out)
# 定义一个名为 LinearAttention 的类,继承自 Module 类
class LinearAttention(Module):
"""
using the specific linear attention proposed in https://arxiv.org/abs/2106.09681
"""
# 初始化函数,接受多个参数
@beartype
def __init__(
self,
*,
dim,
dim_cond: Optional[int] = None,
dim_head = 8,
heads = 8,
dropout = 0.
):
# 调用父类的初始化函数
super().__init__()
# 计算内部维度
dim_inner = dim_head * heads
# 检查是否需要条件
self.need_cond = exists(dim_cond)
# 根据是否需要条件选择不同的归一化方法
if self.need_cond:
self.norm = AdaptiveRMSNorm(dim, dim_cond = dim_cond)
else:
self.norm = RMSNorm(dim)
# 构建 TaylorSeriesLinearAttn 层
self.attn = TaylorSeriesLinearAttn(
dim = dim,
dim_head = dim_head,
heads = heads
)
# 前向传播函数
def forward(
self,
x,
cond: Optional[Tensor] = None
):
# 根据是否需要条件选择不同的参数
maybe_cond_kwargs = dict(cond = cond) if self.need_cond else dict()
# 对输入进行归一化
x = self.norm(x, **maybe_cond_kwargs)
return self.attn(x)
# 定义一个名为 LinearSpaceAttention 的类,继承自 LinearAttention 类
class LinearSpaceAttention(LinearAttention):
# 重写前向传播函数
def forward(self, x, *args, **kwargs):
# 重新排列输入数据
x = rearrange(x, 'b c ... h w -> b ... h w c')
x, batch_ps = pack_one(x, '* h w c')
x, seq_ps = pack_one(x, 'b * c')
# 调用父类的前向传播函数
x = super().forward(x, *args, **kwargs)
x = unpack_one(x, seq_ps, 'b * c')
x = unpack_one(x, batch_ps, '* h w c')
return rearrange(x, 'b ... h w c -> b c ... h w')
# 定义一个名为 SpaceAttention 的类,继承自 Attention 类
class SpaceAttention(Attention):
# 重写前向传播函数
def forward(self, x, *args, **kwargs):
# 重新排列输入数据
x = rearrange(x, 'b c t h w -> b t h w c')
x, batch_ps = pack_one(x, '* h w c')
x, seq_ps = pack_one(x, 'b * c')
# 调用父类的前向传播函数
x = super().forward(x, *args, **kwargs)
x = unpack_one(x, seq_ps, 'b * c')
x = unpack_one(x, batch_ps, '* h w c')
return rearrange(x, 'b t h w c -> b c t h w')
# 定义一个名为 TimeAttention 的类,继承自 Attention 类
class TimeAttention(Attention):
# 重写前向传播函数
def forward(self, x, *args, **kwargs):
# 重新排列输入数据
x = rearrange(x, 'b c t h w -> b h w t c')
x, batch_ps = pack_one(x, '* t c')
# 调用父类的前向传播函数
x = super().forward(x, *args, **kwargs)
x = unpack_one(x, batch_ps, '* t c')
return rearrange(x, 'b h w t c -> b c t h w')
# 定义一个名为 GEGLU 的类,继承自 Module 类
class GEGLU(Module):
# 前向传播函数
def forward(self, x):
# 将输入数据分成两部分
x, gate = x.chunk(2, dim = 1)
return F.gelu(gate) * x
# 定义一个名为 FeedForward 的类,继承自 Module 类
class FeedForward(Module):
@beartype
# 初始化函数,设置神经网络的参数
def __init__(
self,
dim, # 输入数据的维度
*,
dim_cond: Optional[int] = None, # 条件维度,默认为None
mult = 4, # 倍数,默认为4
images = False # 是否为图像数据,默认为False
):
super().__init__() # 调用父类的初始化函数
# 根据是否为图像数据选择不同的卷积层类
conv_klass = nn.Conv2d if images else nn.Conv3d
# 根据条件维度是否存在选择不同的归一化层类
rmsnorm_klass = RMSNorm if not exists(dim_cond) else partial(AdaptiveRMSNorm, dim_cond = dim_cond)
# 创建可能的自适应归一化层类
maybe_adaptive_norm_klass = partial(rmsnorm_klass, channel_first = True, images = images)
# 计算内部维度
dim_inner = int(dim * mult * 2 / 3)
# 初始化归一化层
self.norm = maybe_adaptive_norm_klass(dim)
# 初始化神经网络结构
self.net = Sequential(
conv_klass(dim, dim_inner * 2, 1), # 卷积层
GEGLU(), # 激活函数
conv_klass(dim_inner, dim, 1) # 卷积层
)
# 前向传播函数
@beartype
def forward(
self,
x: Tensor, # 输入数据张量
*,
cond: Optional[Tensor] = None # 条件张量,默认为None
):
# 根据条件张量是否存在选择不同的参数
maybe_cond_kwargs = dict(cond = cond) if exists(cond) else dict()
# 对输入数据进行归一化处理
x = self.norm(x, **maybe_cond_kwargs)
return self.net(x) # 返回神经网络处理后的结果
# 定义一个带有反锯齿下采样的鉴别器(模糊池 Zhang 等人)
class Blur(Module):
def __init__(self):
super().__init__()
# 定义一个张量 f
f = torch.Tensor([1, 2, 1])
# 将张量 f 注册为缓冲区
self.register_buffer('f', f)
def forward(
self,
x,
space_only = False,
time_only = False
):
# 断言空间和时间只能选择一个
assert not (space_only and time_only)
# 获取张量 f
f = self.f
if space_only:
# 对 f 进行乘法操作
f = einsum('i, j -> i j', f, f)
# 重新排列张量 f
f = rearrange(f, '... -> 1 1 ...')
elif time_only:
# 重新排列张量 f
f = rearrange(f, 'f -> 1 f 1 1')
else:
# 对 f 进行乘法操作
f = einsum('i, j, k -> i j k', f, f, f)
# 重新排列张量 f
f = rearrange(f, '... -> 1 ...')
# 判断输入 x 是否为图像
is_images = x.ndim == 4
if is_images:
# 重新排列输入 x
x = rearrange(x, 'b c h w -> b c 1 h w')
# 对输入 x 进行 3D 滤波
out = filter3d(x, f, normalized = True)
if is_images:
# 重新排列输出 out
out = rearrange(out, 'b c 1 h w -> b c h w')
return out
class DiscriminatorBlock(Module):
def __init__(
self,
input_channels,
filters,
downsample = True,
antialiased_downsample = True
):
super().__init__()
# 定义卷积层 conv_res
self.conv_res = nn.Conv2d(input_channels, filters, 1, stride = (2 if downsample else 1))
# 定义神经网络结构 net
self.net = nn.Sequential(
nn.Conv2d(input_channels, filters, 3, padding = 1),
leaky_relu(),
nn.Conv2d(filters, filters, 3, padding = 1),
leaky_relu()
)
# 如果需要反锯齿下采样,则定义模糊层 maybe_blur
self.maybe_blur = Blur() if antialiased_downsample else None
# 如果需要下采样,则定义下采样层 downsample
self.downsample = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),
nn.Conv2d(filters * 4, filters, 1)
) if downsample else None
def forward(self, x):
# 对输入 x 进行卷积操作,得到 res
res = self.conv_res(x)
# 对输入 x 进行神经网络结构操作
x = self.net(x)
if exists(self.downsample):
if exists(self.maybe_blur):
# 如果存在模糊层,则对 x 进行模糊操作
x = self.maybe_blur(x, space_only = True)
# 对 x 进行下采样操作
x = self.downsample(x)
# 对 x 进行加权求和并缩放操作
x = (x + res) * (2 ** -0.5)
return x
class Discriminator(Module):
@beartype
def __init__(
self,
*,
dim,
image_size,
channels = 3,
max_dim = 512,
attn_heads = 8,
attn_dim_head = 32,
linear_attn_dim_head = 8,
linear_attn_heads = 16,
ff_mult = 4,
antialiased_downsample = False
):
# 调用父类的构造函数
super().__init__()
# 将图像大小转换为元组
image_size = pair(image_size)
# 计算图像分辨率的最小值
min_image_resolution = min(image_size)
# 计算层数
num_layers = int(log2(min_image_resolution) - 2)
blocks = []
# 计算每一层的维度
layer_dims = [channels] + [(dim * 4) * (2 ** i) for i in range(num_layers + 1)]
# 将每一层的维度限制在最大维度内
layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
# 将每一层的输入输出维度组成元组
layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))
blocks = []
attn_blocks = []
image_resolution = min_image_resolution
for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
num_layer = ind + 1
is_not_last = ind != (len(layer_dims_in_out) - 1)
# 创建判别器块
block = DiscriminatorBlock(
in_chan,
out_chan,
downsample = is_not_last,
antialiased_downsample = antialiased_downsample
)
# 创建注意力块
attn_block = Sequential(
Residual(LinearSpaceAttention(
dim = out_chan,
heads = linear_attn_heads,
dim_head = linear_attn_dim_head
)),
Residual(FeedForward(
dim = out_chan,
mult = ff_mult,
images = True
))
)
blocks.append(ModuleList([
block,
attn_block
]))
image_resolution //= 2
self.blocks = ModuleList(blocks)
dim_last = layer_dims[-1]
downsample_factor = 2 ** num_layers
last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))
latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last
# 定义输出层
self.to_logits = Sequential(
nn.Conv2d(dim_last, dim_last, 3, padding = 1),
leaky_relu(),
Rearrange('b ... -> b (...)'),
nn.Linear(latent_dim, 1),
Rearrange('b 1 -> b')
)
def forward(self, x):
# 遍历每个块和注意力块
for block, attn_block in self.blocks:
x = block(x)
x = attn_block(x)
return self.to_logits(x)
# 定义一个继承自 Module 的类 Conv3DMod,用于实现可调制的卷积,用于在潜变量上进行条件化
class Conv3DMod(Module):
# 初始化函数
@beartype
def __init__(
self,
dim,
*,
spatial_kernel,
time_kernel,
causal = True,
dim_out = None,
demod = True,
eps = 1e-8,
pad_mode = 'zeros'
):
super().__init__()
dim_out = default(dim_out, dim)
self.eps = eps
# 断言空间和时间卷积核为奇数
assert is_odd(spatial_kernel) and is_odd(time_kernel)
self.spatial_kernel = spatial_kernel
self.time_kernel = time_kernel
# 根据是否因果,设置时间填充
time_padding = (time_kernel - 1, 0) if causal else ((time_kernel // 2,) * 2)
self.pad_mode = pad_mode
self.padding = (*((spatial_kernel // 2,) * 4), *time_padding)
self.weights = nn.Parameter(torch.randn((dim_out, dim, time_kernel, spatial_kernel, spatial_kernel)))
self.demod = demod
# 初始化权重
nn.init.kaiming_normal_(self.weights, a = 0, mode = 'fan_in', nonlinearity = 'selu')
# 前向传播函数
@beartype
def forward(
self,
fmap,
cond: Tensor
):
"""
notation
b - batch
n - convs
o - output
i - input
k - kernel
"""
b = fmap.shape[0]
# 准备用于调制的权重
weights = self.weights
# 进行调制和解调制,类似 stylegan2 中的操作
cond = rearrange(cond, 'b i -> b 1 i 1 1 1')
weights = weights * (cond + 1)
if self.demod:
inv_norm = reduce(weights ** 2, 'b o i k0 k1 k2 -> b o 1 1 1 1', 'sum').clamp(min = self.eps).rsqrt()
weights = weights * inv_norm
fmap = rearrange(fmap, 'b c t h w -> 1 (b c) t h w')
weights = rearrange(weights, 'b o ... -> (b o) ...')
fmap = F.pad(fmap, self.padding, mode = self.pad_mode)
fmap = F.conv3d(fmap, weights, groups = b)
return rearrange(fmap, '1 (b o) ... -> b o ...', b = b)
# 定义一个继承自 Module 的类 SpatialDownsample2x,用于进行空间下采样
class SpatialDownsample2x(Module):
def __init__(
self,
dim,
dim_out = None,
kernel_size = 3,
antialias = False
):
super().__init__()
dim_out = default(dim_out, dim)
self.maybe_blur = Blur() if antialias else identity
self.conv = nn.Conv2d(dim, dim_out, kernel_size, stride = 2, padding = kernel_size // 2)
# 前向传播函数
def forward(self, x):
x = self.maybe_blur(x, space_only = True)
x = rearrange(x, 'b c t h w -> b t c h w')
x, ps = pack_one(x, '* c h w')
out = self.conv(x)
out = unpack_one(out, ps, '* c h w')
out = rearrange(out, 'b t c h w -> b c t h w')
return out
# 定义一个继承自 Module 的类 TimeDownsample2x,用于进行时间下采样
class TimeDownsample2x(Module):
def __init__(
self,
dim,
dim_out = None,
kernel_size = 3,
antialias = False
):
super().__init__()
dim_out = default(dim_out, dim)
self.maybe_blur = Blur() if antialias else identity
self.time_causal_padding = (kernel_size - 1, 0)
self.conv = nn.Conv1d(dim, dim_out, kernel_size, stride = 2)
# 前向传播函数
def forward(self, x):
x = self.maybe_blur(x, time_only = True)
x = rearrange(x, 'b c t h w -> b h w c t')
x, ps = pack_one(x, '* c t')
x = F.pad(x, self.time_causal_padding)
out = self.conv(x)
out = unpack_one(out, ps, '* c t')
out = rearrange(out, 'b h w c t -> b c t h w')
return out
# 定义一个继承自 Module 的类 SpatialUpsample2x,用于进行空间上采样
class SpatialUpsample2x(Module):
def __init__(
self,
dim,
dim_out = None
):
super().__init__()
dim_out = default(dim_out, dim)
conv = nn.Conv2d(dim, dim_out * 4, 1)
self.net = nn.Sequential(
conv,
nn.SiLU(),
Rearrange('b (c p1 p2) h w -> b c (h p1) (w p2)', p1 = 2, p2 = 2)
)
self.init_conv_(conv)
# 初始化卷积层的权重和偏置
def init_conv_(self, conv):
# 获取卷积层的输出通道数、输入通道数、高度和宽度
o, i, h, w = conv.weight.shape
# 创建一个与卷积层权重相同形状的张量
conv_weight = torch.empty(o // 4, i, h, w)
# 使用 Kaiming 初始化方法初始化权重
nn.init.kaiming_uniform_(conv_weight)
# 将权重张量重复4次,扩展为4倍的输出通道数
conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
# 将初始化好的权重复制给卷积层的权重
conv.weight.data.copy_(conv_weight)
# 初始化卷积层的偏置为零
nn.init.zeros_(conv.bias.data)
# 前向传播函数
def forward(self, x):
# 重新排列输入张量的维度,将通道维度放到第二个位置
x = rearrange(x, 'b c t h w -> b t c h w')
# 将输入张量打包成一个元组,每个元素为一个通道的数据
x, ps = pack_one(x, '* c h w')
# 将打包后的输入张量传入网络进行前向传播
out = self.net(x)
# 将网络输出解包,恢复为原始形状
out = unpack_one(out, ps, '* c h w')
# 重新排列输出张量的维度,将通道维度放回最后一个位置
out = rearrange(out, 'b t c h w -> b c t h w')
# 返回前向传播结果
return out
# 定义一个类 TimeUpsample2x,继承自 Module 类
class TimeUpsample2x(Module):
# 初始化函数
def __init__(
self,
dim,
dim_out = None
):
# 调用父类的初始化函数
super().__init__()
# 如果未指定输出维度,则默认与输入维度相同
dim_out = default(dim_out, dim)
# 创建一个 1 维卷积层,输入维度为 dim,输出维度为 dim_out * 2,卷积核大小为 1
conv = nn.Conv1d(dim, dim_out * 2, 1)
# 使用 nn.Sequential 定义网络结构
self.net = nn.Sequential(
conv,
nn.SiLU(), # 使用 SiLU 激活函数
Rearrange('b (c p) t -> b c (t p)', p = 2) # 重新排列张量维度
)
# 初始化卷积层的权重
self.init_conv_(conv)
# 初始化卷积层的权重
def init_conv_(self, conv):
o, i, t = conv.weight.shape
# 创建一个与卷积层权重相同形状的张量
conv_weight = torch.empty(o // 2, i, t)
# 使用 kaiming_uniform_ 方法初始化权重
nn.init.kaiming_uniform_(conv_weight)
# 将权重张量重复一次
conv_weight = repeat(conv_weight, 'o ... -> (o 2) ...')
# 将初始化后的权重赋值给卷积层
conv.weight.data.copy_(conv_weight)
# 将偏置项初始化为零
nn.init.zeros_(conv.bias.data)
# 前向传播函数
def forward(self, x):
# 重新排列输入张量的维度
x = rearrange(x, 'b c t h w -> b h w c t')
# 打包输入张量
x, ps = pack_one(x, '* c t')
# 网络前向传播
out = self.net(x)
# 解包输出张量
out = unpack_one(out, ps, '* c t')
# 重新排列输出张量的维度
out = rearrange(out, 'b h w c t -> b c t h w')
return out
# 定义一个函数 SameConv2d,用于创建相同维度的二维卷积层
def SameConv2d(dim_in, dim_out, kernel_size):
kernel_size = cast_tuple(kernel_size, 2)
padding = [k // 2 for k in kernel_size]
return nn.Conv2d(dim_in, dim_out, kernel_size = kernel_size, padding = padding)
# 定义一个类 CausalConv3d,继承自 Module 类
class CausalConv3d(Module):
# 初始化函数
@beartype
def __init__(
self,
chan_in,
chan_out,
kernel_size: Union[int, Tuple[int, int, int]],
pad_mode = 'constant',
**kwargs
):
# 调用父类的初始化函数
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
# 确保高度和宽度的卷积核大小为奇数
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
dilation = kwargs.pop('dilation', 1)
stride = kwargs.pop('stride', 1)
# 设置时间维度的填充大小
time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
height_pad = height_kernel_size // 2
width_pad = width_kernel_size // 2
self.pad_mode = pad_mode
self.time_pad = time_pad
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
stride = (stride, 1, 1)
dilation = (dilation, 1, 1)
# 创建一个三维卷积层
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride = stride, dilation = dilation, **kwargs)
# 前向传播函数
def forward(self, x):
# 根据填充模式选择填充方式
pad_mode = self.pad_mode if self.time_pad < x.shape[2] else 'constant'
# 对输入张量进行填充
x = F.pad(x, self.time_causal_padding, mode = pad_mode)
return self.conv(x)
# 定义一个函数 ResidualUnit,用于创建残差单元
@beartype
def ResidualUnit(
dim,
kernel_size: Union[int, Tuple[int, int, int]],
pad_mode: str = 'constant'
):
# 构建残差单元网络结构
net = Sequential(
CausalConv3d(dim, dim, kernel_size, pad_mode = pad_mode),
nn.ELU(), # 使用 ELU 激活函数
nn.Conv3d(dim, dim, 1),
nn.ELU(),
SqueezeExcite(dim)
)
return Residual(net)
# 定义一个类 ResidualUnitMod,继承自 Module 类
@beartype
class ResidualUnitMod(Module):
# 初始化函数
def __init__(
self,
dim,
kernel_size: Union[int, Tuple[int, int, int]],
*,
dim_cond,
pad_mode: str = 'constant',
demod = True
):
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
assert height_kernel_size == width_kernel_size
# 线性层,用于将条件信息转换为相同维度
self.to_cond = nn.Linear(dim_cond, dim)
# 创建一个 Conv3DMod 层
self.conv = Conv3DMod(
dim = dim,
spatial_kernel = height_kernel_size,
time_kernel = time_kernel_size,
causal = True,
demod = demod,
pad_mode = pad_mode
)
# 创建一个 1x1x1 三维卷积层
self.conv_out = nn.Conv3d(dim, dim, 1)
# 前向传播函数
@beartype
def forward(
self,
x,
cond: Tensor,
):
res = x
cond = self.to_cond(cond)
# 进行卷积操作
x = self.conv(x, cond = cond)
x = F.elu(x)
x = self.conv_out(x)
x = F.elu(x)
return x + res
# 定义一个类 CausalConvTranspose3d,继承自 Module 类
# 初始化函数,定义了一个卷积转置层
def __init__(
self,
chan_in,
chan_out,
kernel_size: Union[int, Tuple[int, int, int]],
*,
time_stride,
**kwargs
):
# 调用父类的初始化函数
super().__init__()
# 将 kernel_size 转换为三元组
kernel_size = cast_tuple(kernel_size, 3)
# 分别获取时间、高度和宽度的卷积核大小
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
# 断言高度卷积核大小和宽度卷积核大小为奇数
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
# 设置上采样因子为时间步长
self.upsample_factor = time_stride
# 计算高度和宽度的填充值
height_pad = height_kernel_size // 2
width_pad = width_kernel_size // 2
# 设置步长和填充值
stride = (time_stride, 1, 1)
padding = (0, height_pad, width_pad)
# 创建一个三维卷积转置层
self.conv = nn.ConvTranspose3d(chan_in, chan_out, kernel_size, stride, padding = padding, **kwargs)
# 前向传播函数
def forward(self, x):
# 断言输入张量 x 的维度为 5
assert x.ndim == 5
# 获取时间维度的大小
t = x.shape[2]
# 对输入张量进行卷积转置操作
out = self.conv(x)
# 裁剪输出张量的时间维度,保留 t * 上采样因子 个时间步
out = out[..., :(t * self.upsample_factor), :, :]
# 返回处理后的输出张量
return out
# 定义了 LossBreakdown 命名元组,包含了不同损失的分解信息
LossBreakdown = namedtuple('LossBreakdown', [
'recon_loss',
'lfq_aux_loss',
'quantizer_loss_breakdown',
'perceptual_loss',
'adversarial_gen_loss',
'adaptive_adversarial_weight',
'multiscale_gen_losses',
'multiscale_gen_adaptive_weights'
])
# 定义了 DiscrLossBreakdown 命名元组,包含了鉴别器损失的分解信息
DiscrLossBreakdown = namedtuple('DiscrLossBreakdown', [
'discr_loss',
'multiscale_discr_losses',
'gradient_penalty'
])
# 定义了 VideoTokenizer 类,继承自 Module 类
class VideoTokenizer(Module):
# 初始化方法
@beartype
def __init__(
self,
*,
image_size,
layers: Tuple[Union[str, Tuple[str, int]], ...] = (
'residual',
'residual',
'residual'
),
residual_conv_kernel_size = 3,
num_codebooks = 1,
codebook_size: Optional[int] = None,
channels = 3,
init_dim = 64,
max_dim = float('inf'),
dim_cond = None,
dim_cond_expansion_factor = 4.,
input_conv_kernel_size: Tuple[int, int, int] = (7, 7, 7),
output_conv_kernel_size: Tuple[int, int, int] = (3, 3, 3),
pad_mode: str = 'constant',
lfq_entropy_loss_weight = 0.1,
lfq_commitment_loss_weight = 1.,
lfq_diversity_gamma = 2.5,
quantizer_aux_loss_weight = 1.,
lfq_activation = nn.Identity(),
use_fsq = False,
fsq_levels: Optional[List[int]] = None,
attn_dim_head = 32,
attn_heads = 8,
attn_dropout = 0.,
linear_attn_dim_head = 8,
linear_attn_heads = 16,
vgg: Optional[Module] = None,
vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
perceptual_loss_weight = 1e-1,
discr_kwargs: Optional[dict] = None,
multiscale_discrs: Tuple[Module, ...] = tuple(),
use_gan = True,
adversarial_loss_weight = 1.,
grad_penalty_loss_weight = 10.,
multiscale_adversarial_loss_weight = 1.,
flash_attn = True,
separate_first_frame_encoding = False
# 返回属性 device,返回 zero 属性的设备信息
@property
def device(self):
return self.zero.device
# 类方法,初始化并从路径加载模型
@classmethod
def init_and_load_from(cls, path, strict = True):
path = Path(path)
assert path.exists()
pkg = torch.load(str(path), map_location = 'cpu')
assert 'config' in pkg, 'model configs were not found in this saved checkpoint'
config = pickle.loads(pkg['config'])
tokenizer = cls(**config)
tokenizer.load(path, strict = strict)
return tokenizer
# 返回模型参数
def parameters(self):
return [
*self.conv_in.parameters(),
*self.conv_in_first_frame.parameters(),
*self.conv_out_first_frame.parameters(),
*self.conv_out.parameters(),
*self.encoder_layers.parameters(),
*self.decoder_layers.parameters(),
*self.encoder_cond_in.parameters(),
*self.decoder_cond_in.parameters(),
*self.quantizers.parameters()
]
# 返回鉴别器参数
def discr_parameters(self):
return self.discr.parameters()
# 复制模型用于评估
def copy_for_eval(self):
device = self.device
vae_copy = copy.deepcopy(self.cpu())
maybe_del_attr_(vae_copy, 'discr')
maybe_del_attr_(vae_copy, 'vgg')
maybe_del_attr_(vae_copy, 'multiscale_discrs')
vae_copy.eval()
return vae_copy.to(device)
# 返回模型状态字典
@remove_vgg
def state_dict(self, *args, **kwargs):
return super().state_dict(*args, **kwargs)
# 加载模型状态字典
@remove_vgg
def load_state_dict(self, *args, **kwargs):
return super().load_state_dict(*args, **kwargs)
# 保存模型
def save(self, path, overwrite = True):
path = Path(path)
assert overwrite or not path.exists(), f'{str(path)} already exists'
pkg = dict(
model_state_dict = self.state_dict(),
version = __version__,
config = self._configs
)
torch.save(pkg, str(path))
# 加载模型参数
def load(self, path, strict = True):
# 将路径转换为 Path 对象
path = Path(path)
# 断言路径存在
assert path.exists()
# 加载模型参数
pkg = torch.load(str(path))
state_dict = pkg.get('model_state_dict')
version = pkg.get('version')
# 断言模型参数存在
assert exists(state_dict)
# 如果版本信息存在,则打印加载的 tokenizer 版本信息
if exists(version):
print(f'loading checkpointed tokenizer from version {version}')
# 加载模型参数到当前模型
self.load_state_dict(state_dict, strict = strict)
# 编码视频
@beartype
def encode(
self,
video: Tensor,
quantize = False,
cond: Optional[Tensor] = None,
video_contains_first_frame = True
):
# 是否单独编码第一帧
encode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
# 是否填充视频
if video_contains_first_frame:
video_len = video.shape[2]
video = pad_at_dim(video, (self.time_padding, 0), value = 0., dim = 2)
video_packed_shape = [torch.Size([self.time_padding]), torch.Size([]), torch.Size([video_len - 1])]
# 条件编码
assert (not self.has_cond) or exists(cond), '`cond` must be passed into tokenizer forward method since conditionable layers were specified'
if exists(cond):
assert cond.shape == (video.shape[0], self.dim_cond)
cond = self.encoder_cond_in(cond)
cond_kwargs = dict(cond = cond)
# 初始卷积
if encode_first_frame_separately:
pad, first_frame, video = unpack(video, video_packed_shape, 'b c * h w')
first_frame = self.conv_in_first_frame(first_frame)
video = self.conv_in(video)
if encode_first_frame_separately:
video, _ = pack([first_frame, video], 'b c * h w')
video = pad_at_dim(video, (self.time_padding, 0), dim = 2)
# 编码器层
for fn, has_cond in zip(self.encoder_layers, self.has_cond_across_layers):
layer_kwargs = dict()
if has_cond:
layer_kwargs = cond_kwargs
video = fn(video, **layer_kwargs)
maybe_quantize = identity if not quantize else self.quantizers
return maybe_quantize(video)
# 从编码索引解码
@beartype
def decode_from_code_indices(
self,
codes: Tensor,
cond: Optional[Tensor] = None,
video_contains_first_frame = True
):
assert codes.dtype in (torch.long, torch.int32)
if codes.ndim == 2:
video_code_len = codes.shape[-1]
assert divisible_by(video_code_len, self.fmap_size ** 2), f'flattened video ids must have a length ({video_code_len}) that is divisible by the fmap size ({self.fmap_size}) squared ({self.fmap_size ** 2})'
codes = rearrange(codes, 'b (f h w) -> b f h w', h = self.fmap_size, w = self.fmap_size)
quantized = self.quantizers.indices_to_codes(codes)
return self.decode(quantized, cond = cond, video_contains_first_frame = video_contains_first_frame)
# 解码
@beartype
def decode(
self,
quantized: Tensor,
cond: Optional[Tensor] = None,
video_contains_first_frame = True
):
# 检查是否需要单独解码第一帧
decode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
# 获取批量大小
batch = quantized.shape[0]
# 条件输入,如果需要的话
assert (not self.has_cond) or exists(cond), '`cond` must be passed into tokenizer forward method since conditionable layers were specified'
if exists(cond):
assert cond.shape == (batch, self.dim_cond)
# 将条件输入传入条件编码器
cond = self.decoder_cond_in(cond)
cond_kwargs = dict(cond = cond)
# 解码器层
x = quantized
for fn, has_cond in zip(self.decoder_layers, reversed(self.has_cond_across_layers)):
layer_kwargs = dict()
if has_cond:
layer_kwargs = cond_kwargs
# 逐层解码
x = fn(x, **layer_kwargs)
# 转换为像素
if decode_first_frame_separately:
left_pad, xff, x = x[:, :, :self.time_padding], x[:, :, self.time_padding], x[:, :, (self.time_padding + 1):]
# 对输出进行卷积
out = self.conv_out(x)
outff = self.conv_out_first_frame(xff)
# 将第一帧和其余帧打包
video, _ = pack([outff, out], 'b c * h w')
else:
# 对输出进行卷积
video = self.conv_out(x)
# 如果视频包含第一帧,则移除填充
if video_contains_first_frame:
video = video[:, :, self.time_padding:]
return video
@torch.no_grad()
def tokenize(self, video):
# 设置为评估模式
self.eval()
return self.forward(video, return_codes = True)
@beartype
def forward(
self,
video_or_images: Tensor,
cond: Optional[Tensor] = None,
return_loss = False,
return_codes = False,
return_recon = False,
return_discr_loss = False,
return_recon_loss_only = False,
apply_gradient_penalty = True,
video_contains_first_frame = True,
adversarial_loss_weight = None,
multiscale_adversarial_loss_weight = None
# 主要类定义
class MagViT2(Module):
# 初始化方法
def __init__(self):
# 调用父类的初始化方法
super().__init__()
# 前向传播方法
def forward(self, x):
# 返回输入数据 x,即不做任何处理
return x