Lucidrains 系列项目源码解析(五十四)
.\lucidrains\magvit2-pytorch\magvit2_pytorch\optimizer.py
# 从 torch.optim 模块中导入 AdamW 和 Adam 优化器
from torch.optim import AdamW, Adam
# 将参数分为需要权重衰减和不需要权重衰减的两个列表
def separate_weight_decayable_params(params):
    """Split parameters into (weight-decayable, non-decayable) lists.

    Vectors and scalars (ndim < 2, e.g. biases and norm gains) should not
    receive weight decay; matrices and higher-rank tensors should.
    """
    wd_params = [p for p in params if p.ndim >= 2]
    no_wd_params = [p for p in params if p.ndim < 2]
    return wd_params, no_wd_params
def get_optimizer(
    params,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.99),
    eps = 1e-8,
    filter_by_requires_grad = False,
    group_wd_params = True,
    **kwargs
):
    """Build an Adam (wd == 0) or AdamW optimizer over *params*.

    When ``group_wd_params`` is set, parameters with ndim < 2 (biases,
    norm gains) are placed in a second param group with weight decay
    disabled. Extra keyword arguments are accepted for interface
    compatibility but are not forwarded (mirrors the original behavior).
    """
    if filter_by_requires_grad:
        params = [p for p in params if p.requires_grad]

    opt_kwargs = dict(lr = lr, betas = betas, eps = eps)

    # no decay requested at all -> plain Adam
    if wd == 0:
        return Adam(params, **opt_kwargs)

    opt_kwargs = {'weight_decay': wd, **opt_kwargs}

    if not group_wd_params:
        return AdamW(params, **opt_kwargs)

    # split decay / no-decay parameters (inlined helper: ndim >= 2 gets decay)
    wd_params, no_wd_params = [], []
    for p in params:
        (wd_params if p.ndim >= 2 else no_wd_params).append(p)

    grouped = [
        {'params': wd_params},
        {'params': no_wd_params, 'weight_decay': 0},
    ]
    return AdamW(grouped, **opt_kwargs)
.\lucidrains\magvit2-pytorch\magvit2_pytorch\trainer.py
# 导入必要的库
from pathlib import Path
from functools import partial
from contextlib import contextmanager, nullcontext
import torch
from torch import nn
from torch.nn import Module
from torch.utils.data import Dataset, random_split
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
import pytorch_warmup as warmup
from beartype import beartype
from beartype.typing import Optional, Literal, Union, Type
from magvit2_pytorch.optimizer import get_optimizer
from magvit2_pytorch.magvit2_pytorch import VideoTokenizer
from magvit2_pytorch.data import (
VideoDataset,
ImageDataset,
DataLoader,
video_tensor_to_gif
)
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from einops import rearrange
from ema_pytorch import EMA
from pytorch_custom_utils import auto_unwrap_model
# constants

# the two dataset modes the trainer accepts: a folder of videos or of images
VideosOrImagesLiteral = Union[
    Literal['videos'],
    Literal['images']
]

# LambdaLR with a constant multiplier of 1. — i.e. no learning-rate schedule
ConstantLRScheduler = partial(LambdaLR, lr_lambda = lambda step: 1.)

# NOTE(review): find_unused_parameters = True presumably because some branches
# (e.g. the discriminator path) do not run on every step — confirm against trainer logic
DEFAULT_DDP_KWARGS = DistributedDataParallelKwargs(
    find_unused_parameters = True
)
# helper functions

def exists(v):
    """True when *v* holds a value (i.e. is not None)."""
    return not (v is None)
def cycle(dl):
    """Yield items from *dl* forever, restarting it whenever it is exhausted.

    Unlike itertools.cycle this re-iterates the source each pass, so a
    DataLoader can reshuffle between epochs instead of replaying a cached
    first pass.
    """
    while True:
        yield from dl
# 定义类
@auto_unwrap_model()
class VideoTokenizerTrainer:
@beartype
def __init__(
self,
model: VideoTokenizer,
*,
batch_size: int,
num_train_steps: int,
learning_rate: float = 1e-5,
grad_accum_every: int = 1,
apply_gradient_penalty_every: int = 4,
max_grad_norm: Optional[float] = None,
dataset: Optional[Dataset] = None,
dataset_folder: Optional[str] = None,
dataset_type: VideosOrImagesLiteral = 'videos',
checkpoints_folder = './checkpoints',
results_folder = './results',
random_split_seed = 42,
valid_frac = 0.05,
validate_every_step = 100,
checkpoint_every_step = 100,
num_frames = 17,
use_wandb_tracking = False,
discr_start_after_step = 0.,
warmup_steps = 1000,
scheduler: Optional[Type[LRScheduler]] = None,
scheduler_kwargs: dict = dict(),
accelerate_kwargs: dict = dict(),
ema_kwargs: dict = dict(),
optimizer_kwargs: dict = dict(),
dataset_kwargs: dict = dict()
@contextmanager
@beartype
def trackers(
    self,
    project_name: str,
    run_name: Optional[str] = None,
    hps: Optional[dict] = None
):
    """Context manager that opens a wandb tracking session for the duration
    of the block and ends training/tracking on exit.

    Requires the trainer to have been built with use_wandb_tracking = True.
    """
    assert self.use_wandb_tracking

    accelerator = self.accelerator
    accelerator.init_trackers(project_name, config = hps)

    # optionally rename the underlying wandb run
    if run_name is not None:
        accelerator.trackers[0].run.name = run_name

    yield

    accelerator.end_training()
def log(self, **data_kwargs):
    """Record a dict of scalar metrics against the trainer's current step."""
    payload = data_kwargs
    self.accelerator.log(payload, step = self.step)
@property
def device(self):
    """Device the (possibly accelerate-wrapped) tokenizer currently lives on."""
    return self.model.device
@property
def is_main(self):
    """True only on the globally-main accelerate process."""
    return self.accelerator.is_main_process
@property
def is_local_main(self):
    """True on the main process of each node (local rank 0)."""
    return self.accelerator.is_local_main_process
def wait(self):
    """Synchronization barrier across all accelerate processes."""
    return self.accelerator.wait_for_everyone()
def print(self, msg):
    """Print *msg* once via accelerate (main process only).

    Intentionally shadows the builtin ``print`` on the instance so trainer
    code prints exactly once under multi-process launch.
    """
    return self.accelerator.print(msg)
@property
def ema_tokenizer(self):
    """The exponentially-averaged tokenizer held inside the EMA wrapper."""
    return self.ema_model.ema_model
def tokenize(self, *args, **kwargs):
    """Tokenize media using the EMA tokenizer (the copy used for inference)."""
    return self.ema_tokenizer.tokenize(*args, **kwargs)
def save(self, path, overwrite = True):
    """Serialize the full trainer state (models, optimizers, warmups,
    schedulers, step counter) to *path* with torch.save.

    When overwrite is False, refuses to clobber an existing file.
    """
    path = Path(path)
    assert overwrite or not path.exists()

    pkg = {
        'model': self.model.state_dict(),
        'ema_model': self.ema_model.state_dict(),
        'optimizer': self.optimizer.state_dict(),
        'discr_optimizer': self.discr_optimizer.state_dict(),
        'warmup': self.warmup.state_dict(),
        'scheduler': self.scheduler.state_dict(),
        'discr_warmup': self.discr_warmup.state_dict(),
        'discr_scheduler': self.discr_scheduler.state_dict(),
        'step': self.step
    }

    # each multiscale discriminator optimizer is stored under an indexed key
    for index, multiscale_opt in enumerate(self.multiscale_discr_optimizers):
        pkg[f'multiscale_discr_optimizer_{index}'] = multiscale_opt.state_dict()

    torch.save(pkg, str(path))
def load(self, path):
    """Restore trainer state previously written by ``save``.

    Raises (via assert) when *path* does not exist.
    """
    path = Path(path)
    assert path.exists()

    pkg = torch.load(str(path))

    # models and main optimizers
    self.model.load_state_dict(pkg['model'])
    self.ema_model.load_state_dict(pkg['ema_model'])
    self.optimizer.load_state_dict(pkg['optimizer'])
    self.discr_optimizer.load_state_dict(pkg['discr_optimizer'])

    # warmup + lr schedules for both generator and discriminator
    self.warmup.load_state_dict(pkg['warmup'])
    self.scheduler.load_state_dict(pkg['scheduler'])
    self.discr_warmup.load_state_dict(pkg['discr_warmup'])
    self.discr_scheduler.load_state_dict(pkg['discr_scheduler'])

    # multiscale discriminator optimizers, by index
    for index, multiscale_opt in enumerate(self.multiscale_discr_optimizers):
        multiscale_opt.load_state_dict(pkg[f'multiscale_discr_optimizer_{index}'])

    self.step = pkg['step']
# validation step
@torch.no_grad()
def valid_step(
    self,
    dl_iter,
    save_recons = True,
    num_save_recons = 1
):
    """Run validation over ``grad_accum_every`` batches: compute recon loss
    for both the live model and the EMA model, log/print the averages, and
    optionally save a side-by-side real/recon GIF to the results folder.
    """
    self.ema_model.eval()

    # accumulated (averaged) reconstruction losses
    recon_loss = 0.
    ema_recon_loss = 0.

    # collected inputs and EMA reconstructions for the saved sample
    valid_videos = []
    recon_videos = []

    # average over the same number of batches as a training step accumulates
    for _ in range(self.grad_accum_every):
        # the dataloader yields 1-tuples; unpack the single tensor
        valid_video, = next(dl_iter)
        valid_video = valid_video.to(self.device)

        # mixed precision for both forward passes
        with self.accelerator.autocast():
            loss, _ = self.model(valid_video, return_recon_loss_only = True)
            ema_loss, ema_recon_video = self.ema_model(valid_video, return_recon_loss_only = True)

        recon_loss += loss / self.grad_accum_every
        ema_recon_loss += ema_loss / self.grad_accum_every

        # promote images (b c h w) to single-frame videos so cat/rearrange below is uniform
        if valid_video.ndim == 4:
            valid_video = rearrange(valid_video, 'b c h w -> b c 1 h w')

        valid_videos.append(valid_video.cpu())
        recon_videos.append(ema_recon_video.cpu())

    self.log(
        valid_recon_loss = recon_loss.item(),
        valid_ema_recon_loss = ema_recon_loss.item()
    )

    self.print(f'validation recon loss {recon_loss:.3f}')
    self.print(f'validation EMA recon loss {ema_recon_loss:.3f}')

    # bail out unless reconstructions should be written to disk
    if not save_recons:
        return

    valid_videos = torch.cat(valid_videos)
    recon_videos = torch.cat(recon_videos)

    # clamp reconstructions into valid pixel range before GIF encoding
    recon_videos.clamp_(min = 0., max = 1.)

    # keep only the first num_save_recons samples of each
    valid_videos, recon_videos = map(lambda t: t[:num_save_recons], (valid_videos, recon_videos))

    # tile real and recon side by side: batch stacked vertically, real|recon horizontally
    real_and_recon = rearrange([valid_videos, recon_videos], 'n b c f h w -> c f (b h) (n w)')

    # name the sample by which validation interval produced it
    validate_step = self.step // self.validate_every_step

    sample_path = str(self.results_folder / f'sampled.{validate_step}.gif')

    video_tensor_to_gif(real_and_recon, str(sample_path))

    self.print(f'sample saved to {str(sample_path)}')
def train(self):
    """Main training loop.

    Runs ``train_step`` until ``num_train_steps``, validating every
    ``validate_every_step`` steps and checkpointing every
    ``checkpoint_every_step`` steps (both on the main process only).

    Fix: the source dump contained orphaned "wait" comments with no code —
    the ``self.wait()`` barrier calls (accelerate wait_for_everyone) that
    belong after the train step and after each main-process-only section
    were dropped; they are restored here so non-main processes stay in sync.
    They are harmless no-ops in single-process runs.
    """
    step = self.step

    dl_iter = cycle(self.dataloader)
    valid_dl_iter = cycle(self.valid_dataloader)

    while step < self.num_train_steps:
        self.print(f'step {step}')

        self.train_step(dl_iter)

        self.wait()

        # validate on the main process at the configured interval
        if self.is_main and not (step % self.validate_every_step):
            self.valid_step(valid_dl_iter)

        self.wait()

        # checkpoint on the main process at the configured interval
        if self.is_main and not (step % self.checkpoint_every_step):
            checkpoint_num = step // self.checkpoint_every_step
            checkpoint_path = self.checkpoints_folder / f'checkpoint.{checkpoint_num}.pt'
            self.save(str(checkpoint_path))

        self.wait()

        step += 1
.\lucidrains\magvit2-pytorch\magvit2_pytorch\version.py
# single source of truth for the package version; setup.py exec()s this file
__version__ = '0.4.0'
.\lucidrains\magvit2-pytorch\magvit2_pytorch\__init__.py
# 从 magvit2_pytorch 包中导入 MagViT2 和 VideoTokenizer 类
from magvit2_pytorch.magvit2_pytorch import (
MagViT2,
VideoTokenizer
)
# 从 magvit2_pytorch 包中导入 VideoTokenizerTrainer 类
from magvit2_pytorch.trainer import (
VideoTokenizerTrainer
)

MagViT2 - Pytorch
Implementation of MagViT2 from Language Model Beats Diffusion - Tokenizer is Key to Visual Generation in Pytorch. This currently holds SOTA for video generation / understanding.
The Lookup Free Quantizer proposed in the paper can be found in a separate repository. It should probably be explored for all other modalities, starting with audio
Please join if you are interested in replicating the tokenizer proposed in this paper out in the open
Appreciation
-
StabilityAI and 🤗 Huggingface for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.
-
Louis Serrano for sharing some early initial runs, validating that the overall architecture converges with finite scalar quantization.
-
You? If you are a talented research engineer / scientist, feel free to contribute to cutting edge open source science!
Install
$ pip install magvit2-pytorch
Usage
from magvit2_pytorch import (
VideoTokenizer,
VideoTokenizerTrainer
)
tokenizer = VideoTokenizer(
image_size = 128,
init_dim = 64,
max_dim = 512,
codebook_size = 1024,
layers = (
'residual',
'compress_space',
('consecutive_residual', 2),
'compress_space',
('consecutive_residual', 2),
'linear_attend_space',
'compress_space',
('consecutive_residual', 2),
'attend_space',
'compress_time',
('consecutive_residual', 2),
'compress_time',
('consecutive_residual', 2),
'attend_time',
)
)
trainer = VideoTokenizerTrainer(
tokenizer,
dataset_folder = '/path/to/a/lot/of/media', # folder of either videos or images, depending on setting below
dataset_type = 'videos', # 'videos' or 'images', prior papers have shown pretraining on images to be effective for video synthesis
batch_size = 4,
grad_accum_every = 8,
learning_rate = 2e-5,
num_train_steps = 1_000_000
)
trainer.train()
# after a lot of training ...
# can use the EMA of the tokenizer
ema_tokenizer = trainer.ema_tokenizer
# mock video
video = torch.randn(1, 3, 17, 128, 128)
# tokenizing video to discrete codes
codes = ema_tokenizer.tokenize(video) # (1, 9, 16, 16) <- in this example, time downsampled by 4x and space downsampled by 8x. flatten token ids for (non)-autoregressive training
# sanity check
decoded_video = ema_tokenizer.decode_from_code_indices(codes)
assert torch.allclose(
decoded_video,
ema_tokenizer(video, return_recon = True)
)
To track your experiments on Weights & Biases set use_wandb_tracking = True on VideoTokenizerTrainer, and then use the .trackers context manager
trainer = VideoTokenizerTrainer(
use_wandb_tracking = True,
...
)
with trainer.trackers(project_name = 'magvit2', run_name = 'baseline'):
trainer.train()
Todo
-
Magvit2 Tokenizer
- add adversarial loss
- implement the blurpool for antialiasing in discriminator
- LFQ should be able to pass loss breakdown (commitment and entropy), and forwarded to the return of the tokenizer
- add conditioning for encoder decoder with residual modulatable conv 3d
-
`decode_from_codebook_indices` should be able to accept flattened ids and reshape to correct feature map dimensions and decode back to video - add trainer and manage discriminator training
- add adaptive rmsnorm and conditionable transformer layers
- completely generalize to multiple discriminators at different time scales (taking inspiration of multi-resolution discriminators from soundstream)
- complete multiscale discriminator losses
- auto-manage multiscale discriminator optimizers
- helper functions for crafting multi-resolution temporal discriminators (picking random consecutive frames)
- add attention
- use axial rotary embeddings for spatial
- add an optional autoregressive loss at some penultimate layer of the decoder - check literature to see if anyone else has done this unification of transformer decoder + tokenizer in one architecture
-
Improvise a RQ Video Transformer, as residual LFQ actually makes sense now
-
MaskGit
Citations
@misc{yu2023language,
title = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
author = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
year = {2023},
eprint = {2310.05737},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
@article{Zhang2021TokenST,
title = {Token Shift Transformer for Video Classification},
author = {Hao Zhang and Y. Hao and Chong-Wah Ngo},
journal = {Proceedings of the 29th ACM International Conference on Multimedia},
year = {2021}
}
@inproceedings{Arora2023ZoologyMA,
title = {Zoology: Measuring and Improving Recall in Efficient Language Models},
author = {Simran Arora and Sabri Eyuboglu and Aman Timalsina and Isys Johnson and Michael Poli and James Zou and Atri Rudra and Christopher R'e},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:266149332}
}
.\lucidrains\magvit2-pytorch\setup.py
# packaging / package-discovery helpers
from setuptools import setup, find_packages

# execute version.py so that __version__ is defined in this namespace
exec(open('magvit2_pytorch/version.py').read())

# package metadata
setup(
  name = 'magvit2-pytorch',
  packages = find_packages(),
  version = __version__,
  license='MIT',
  description = 'MagViT2 - Pytorch',
  long_description_content_type = 'text/markdown',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/magvit2-pytorch',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformer',
    'attention mechanisms',
    'generative video model'
  ],
  # runtime dependencies
  install_requires=[
    'accelerate>=0.24.0',
    'beartype',
    'einops>=0.7.0',
    'ema-pytorch>=0.2.4',
    'pytorch-warmup',
    'gateloop-transformer>=0.2.2',
    'kornia',
    'opencv-python',
    'pillow',
    'pytorch-custom-utils>=0.0.9',
    'numpy',
    'vector-quantize-pytorch>=1.11.8',
    'taylor-series-linear-attention>=0.1.5',
    'torch',
    'torchvision',
    'x-transformers'
  ],
  # PyPI trove classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
.\lucidrains\make-a-video-pytorch\make_a_video_pytorch\attend.py
# 导入必要的库
from functools import wraps
from packaging import version
from collections import namedtuple
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange
# 定义一个命名元组,用于存储注意力机制的配置信息
AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
# helper: presence check

def exists(val):
    """True unless *val* is None."""
    return not (val is None)
# decorator ensuring the wrapped single-argument function runs at most once

def once(fn):
    """Wrap *fn* so only the first call executes; later calls return None."""
    has_run = False

    @wraps(fn)
    def inner(x):
        nonlocal has_run
        if has_run:
            return
        has_run = True
        return fn(x)

    return inner

# print at most one message per process (used for one-time device notices)
print_once = once(print)
# 主要类定义
class Attend(nn.Module):
    """Attention backend: either pytorch 2.0 scaled_dot_product_attention
    (flash path) or a manual einsum implementation — only the manual path
    supports an additive attention bias."""

    def __init__(
        self,
        dropout = 0.,
        flash = False,
        causal = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)
        self.causal = causal
        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # sdp kernel configs for cpu vs cuda; cpu allows all kernel choices
        self.cpu_config = AttentionConfig(True, True, True)
        self.cuda_config = None

        # cuda config only matters when flash is requested and cuda is present
        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            # compute capability 8.0 — treated as A100: enable only the flash kernel
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(True, False, False)
        else:
            # other GPUs: math / memory-efficient kernels instead
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(False, True, True)

    def flash_attn(self, q, k, v):
        # q, k, v are (batch, heads, seq, dim_head); only is_cuda is used below
        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

        # sdp kernels require contiguous inputs
        q, k, v = map(lambda t: t.contiguous(), (q, k, v))

        # choose the kernel config matching where the tensors live
        config = self.cuda_config if is_cuda else self.cpu_config

        # run pytorch 2.0 attention under the selected kernel config
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p = self.dropout if self.training else 0.,
                is_causal = self.causal
            )

        return out

    def forward(self, q, k, v, bias = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        if self.flash:
            # the flash path cannot take an additive attention bias
            assert not exists(bias)
            return self.flash_attn(q, k, v)

        scale = q.shape[-1] ** -0.5

        # similarity
        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale

        # optional additive attention bias (e.g. relative position bias)
        if exists(bias):
            sim = sim + bias

        # causal mask: block attention to future positions
        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention weights
        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)

        return out
.\lucidrains\make-a-video-pytorch\make_a_video_pytorch\make_a_video.py
# 导入数学库
import math
# 导入 functools 库
import functools
# 从 operator 库中导入 mul 函数
from operator import mul
# 导入 torch 库
import torch
# 从 torch.nn 中导入 functional 模块
import torch.nn.functional as F
# 从 torch 中导入 nn、einsum 模块
from torch import nn, einsum
# 从 einops 中导入 rearrange、repeat、pack、unpack 函数,以及 Rearrange 类
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
# 从 make_a_video_pytorch.attend 模块中导入 Attend 类
# 辅助函数
# 判断变量是否存在
def exists(val):
    """True when *val* is not None."""
    return not (val is None)
# 如果变量存在,则返回变量值,否则返回默认值
def default(val, d):
    """Return *val* when it is not None, otherwise the fallback *d*."""
    if val is None:
        return d
    return val
# 对元组中的元素进行乘法运算
def mul_reduce(tup):
    """Product of the elements of *tup* (left fold with operator.mul)."""
    return functools.reduce(mul, tup)
# 判断一个数是否可以被另一个数整除
def divisible_by(numer, denom):
    """True when *numer* divides evenly by *denom*."""
    return numer % denom == 0
# terse alias for nn.ModuleList, used when assembling the layers below
mlist = nn.ModuleList
# 用于时间条件
# 正弦位置编码
class SinusoidalPosEmb(nn.Module):
    """Classic sinusoidal position embedding (used for timestep conditioning).

    Maps a 1-D float tensor of positions to (len(x), dim) features: the first
    half of the channels are sines, the second half cosines, at geometrically
    spaced frequencies controlled by *theta*.
    """

    def __init__(self, dim, theta = 10000):
        super().__init__()
        self.theta = theta
        self.dim = dim

    def forward(self, x):
        dtype, device = x.dtype, x.device
        assert dtype == torch.float, 'input to sinusoidal pos emb must be a float type'

        half_dim = self.dim // 2
        freq_scale = math.log(self.theta) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -freq_scale)

        # outer product positions x frequencies -> (len(x), half_dim)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim = -1).type(dtype)
# 3D 归一化
# RMS 归一化
class RMSNorm(nn.Module):
    """RMS normalization over a configurable channel axis, with a learned
    per-channel gain (initialized to ones)."""

    def __init__(self, chan, dim = 1):
        super().__init__()
        self.dim = dim
        self.gamma = nn.Parameter(torch.ones(chan))

    def forward(self, x):
        dim = self.dim
        # how many trailing singleton axes gamma needs to broadcast over
        trailing = (dim + 1) if dim < 0 else (x.ndim - 1 - dim)
        gamma = self.gamma.reshape(-1, *((1,) * trailing))
        scale = x.shape[dim] ** 0.5
        return F.normalize(x, dim = dim) * scale * gamma
# 前馈网络
# 移位令牌
def shift_token(t):
    """Token shift along the frame axis of a (b, c, f, h, w) tensor.

    The first half of the channels passes through unchanged; the second half
    is shifted forward one frame (the first frame becomes zeros).
    """
    keep, shifted = t.chunk(2, dim = 1)
    # pad one zero frame at the front, drop the last frame
    shifted = F.pad(shifted, (0, 0, 0, 0, 1, -1), value = 0.)
    return torch.cat((keep, shifted), dim = 1)
# GEGLU 激活函数
class GEGLU(nn.Module):
    """Gated GELU along the channel axis: splits channels in half and
    multiplies one half by gelu(other half)."""

    def forward(self, x):
        out, gate = x.chunk(2, dim = 1)
        return out * F.gelu(gate)
# 前馈网络
class FeedForward(nn.Module):
    """1x1x1-conv feedforward with GEGLU gating and optional token-shift
    over the time axis.

    Accepts images (b c h w) or videos (b c f h w); images are routed
    through as single-frame videos.
    """

    def __init__(self, dim, mult = 4):
        super().__init__()
        # 2/3 scaling keeps parameter count comparable to a non-gated FF
        inner_dim = int(dim * mult * 2 / 3)

        self.proj_in = nn.Sequential(
            nn.Conv3d(dim, inner_dim * 2, 1, bias = False),
            GEGLU()
        )

        self.proj_out = nn.Sequential(
            RMSNorm(inner_dim),
            nn.Conv3d(inner_dim, dim, 1, bias = False)
        )

    def forward(self, x, enable_time = True):
        is_video = x.ndim == 5
        enable_time &= is_video

        # promote images to single-frame videos for the conv3d stack
        if not is_video:
            x = rearrange(x, 'b c h w -> b c 1 h w')

        hidden = self.proj_in(x)

        # token-shift only makes sense with a real time axis
        if enable_time:
            hidden = shift_token(hidden)

        out = self.proj_out(hidden)

        if not is_video:
            out = rearrange(out, 'b c 1 h w -> b c h w')

        return out
# 最佳相对位置编码
# 连续位置偏置
class ContinuousPositionBias(nn.Module):
    """ from https://arxiv.org/abs/2111.09883

    MLP-produced relative position bias over an arbitrary number of axes:
    an MLP maps each relative offset to a per-head bias, which is then
    gathered into an (heads, i, j) attention bias matrix.
    """

    def __init__(
        self,
        *,
        dim,
        heads,
        num_dims = 1,
        layers = 2
    ):
        super().__init__()
        self.num_dims = num_dims

        self.net = nn.ModuleList([])
        self.net.append(nn.Sequential(nn.Linear(self.num_dims, dim), nn.SiLU()))

        for _ in range(layers - 1):
            self.net.append(nn.Sequential(nn.Linear(dim, dim), nn.SiLU()))

        # FIX: the source was missing the closing parenthesis on this append
        self.net.append(nn.Linear(dim, heads))

    @property
    def device(self):
        # device of the module, taken from its first parameter
        return next(self.parameters()).device

    def forward(self, *dimensions):
        """Return an (heads, prod(dimensions), prod(dimensions)) bias tensor
        for the grid of positions described by *dimensions*."""
        device = self.device

        shape = torch.tensor(dimensions, device = device)
        rel_pos_shape = 2 * shape - 1

        # strides for flattening a multi-dimensional relative offset into one index
        strides = torch.flip(rel_pos_shape, (0,)).cumprod(dim = -1)
        strides = torch.flip(F.pad(strides, (1, -1), value = 1), (0,))

        # all absolute positions and pairwise relative distances
        positions = [torch.arange(d, device = device) for d in dimensions]
        grid = torch.stack(torch.meshgrid(*positions, indexing = 'ij'), dim = -1)
        grid = rearrange(grid, '... c -> (...) c')
        rel_dist = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')

        # every possible relative offset along each dimension
        rel_positions = [torch.arange(-d + 1, d, device = device) for d in dimensions]
        rel_pos_grid = torch.stack(torch.meshgrid(*rel_positions, indexing = 'ij'), dim = -1)
        rel_pos_grid = rearrange(rel_pos_grid, '... c -> (...) c')

        # MLP input: each unique relative offset, as floats
        bias = rel_pos_grid.float()
        for layer in self.net:
            bias = layer(bias)

        # convert relative distances into flat indices into `bias`
        rel_dist += (shape - 1)  # shift offsets so they are non-negative
        rel_dist *= strides
        rel_dist_indices = rel_dist.sum(dim = -1)

        # gather the bias for each (i, j) pair of positions
        bias = bias[rel_dist_indices]
        return rearrange(bias, 'i j h -> h i j')
# 定义注意力机制类
class Attention(nn.Module):
    """Multi-head attention over (b, n, c) tokens, with pre-RMSNorm and an
    optional additive relative position bias."""

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        flash = False,
        causal = False
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.attend = Attend(flash = flash, causal = causal)

        self.norm = RMSNorm(dim, dim = -1)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # zero-init the output projection so a residual branch starts as identity
        nn.init.zeros_(self.to_out.weight.data)

    def forward(
        self,
        x,
        rel_pos_bias = None
    ):
        normed = self.norm(x)

        q = self.to_q(normed)
        k, v = self.to_kv(normed).chunk(2, dim = -1)

        # split heads: (b, n, h*d) -> (b, h, n, d)
        q, k, v = tuple(rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in (q, k, v))

        attended = self.attend(q, k, v, bias = rel_pos_bias)

        # merge heads back and project out
        attended = rearrange(attended, 'b h n d -> b n (h d)')
        return self.to_out(attended)
# 定义主要贡献 - 伪 3D 卷积类
class PseudoConv3d(nn.Module):
    """Factorized '2+1D' convolution: a 2D spatial conv followed by a 1D
    temporal conv initialized to identity (so it starts as a no-op).

    Accepts images (b c h w) — spatial conv only — or videos (b c f h w).
    """

    def __init__(
        self,
        dim,
        dim_out = None,
        kernel_size = 3,
        *,
        temporal_kernel_size = None,
        **kwargs
    ):
        super().__init__()
        dim_out = dim if dim_out is None else dim_out
        temporal_kernel_size = kernel_size if temporal_kernel_size is None else temporal_kernel_size

        self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size = kernel_size, padding = kernel_size // 2)

        # NOTE(review): gate is on kernel_size (not temporal_kernel_size), mirroring the original
        self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size = temporal_kernel_size, padding = temporal_kernel_size // 2) if kernel_size > 1 else None

        # identity init: dirac weights + zero bias make the temporal conv a pass-through
        if self.temporal_conv is not None:
            nn.init.dirac_(self.temporal_conv.weight.data)
            nn.init.zeros_(self.temporal_conv.bias.data)

    def forward(
        self,
        x,
        enable_time = True
    ):
        b, c, *_, h, w = x.shape

        is_video = x.ndim == 5
        enable_time &= is_video

        # fold frames into batch for the shared 2D spatial conv
        if is_video:
            x = rearrange(x, 'b c f h w -> (b f) c h w')

        x = self.spatial_conv(x)

        if is_video:
            x = rearrange(x, '(b f) c h w -> b c f h w', b = b)

        if not enable_time or self.temporal_conv is None:
            return x

        # fold space into batch for the 1D temporal conv
        x = rearrange(x, 'b c f h w -> (b h w) c f')
        x = self.temporal_conv(x)
        x = rearrange(x, '(b h w) c f -> b c f h w', h = h, w = w)

        return x
# 定义分解的时空注意力类
class SpatioTemporalAttention(nn.Module):
    """Factorized attention: spatial attention over each frame, then (for
    videos) temporal attention over each spatial location, then an optional
    feedforward — each with a residual connection."""

    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        add_feed_forward = True,
        ff_mult = 4,
        pos_bias = True,
        flash = False,
        causal_time_attn = False
    ):
        super().__init__()
        assert not (flash and pos_bias), 'learned positional attention bias is not compatible with flash attention'

        # spatial attention + 2D continuous relative position bias
        self.spatial_attn = Attention(dim = dim, dim_head = dim_head, heads = heads, flash = flash)
        self.spatial_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 2) if pos_bias else None

        # temporal attention (optionally causal) + 1D relative position bias
        self.temporal_attn = Attention(dim = dim, dim_head = dim_head, heads = heads, flash = flash, causal = causal_time_attn)
        self.temporal_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 1) if pos_bias else None

        self.has_feed_forward = add_feed_forward
        if not add_feed_forward:
            return

        self.ff = FeedForward(dim = dim, mult = ff_mult)

    def forward(
        self,
        x,
        enable_time = True
    ):
        # x is (b, c, h, w) for images or (b, c, f, h, w) for videos
        b, c, *_, h, w = x.shape
        is_video = x.ndim == 5
        enable_time &= is_video

        # flatten spatial positions into a token axis; frames fold into batch
        if is_video:
            x = rearrange(x, 'b c f h w -> (b f) (h w) c')
        else:
            x = rearrange(x, 'b c h w -> b (h w) c')

        space_rel_pos_bias = self.spatial_rel_pos_bias(h, w) if exists(self.spatial_rel_pos_bias) else None

        # spatial attention with residual
        x = self.spatial_attn(x, rel_pos_bias = space_rel_pos_bias) + x

        # restore the original layout
        if is_video:
            x = rearrange(x, '(b f) (h w) c -> b c f h w', b = b, h = h, w = w)
        else:
            x = rearrange(x, 'b (h w) c -> b c h w', h = h, w = w)

        if enable_time:
            # frames become the token axis; spatial positions fold into batch
            x = rearrange(x, 'b c f h w -> (b h w) f c')

            time_rel_pos_bias = self.temporal_rel_pos_bias(x.shape[1]) if exists(self.temporal_rel_pos_bias) else None

            # temporal attention with residual
            x = self.temporal_attn(x, rel_pos_bias = time_rel_pos_bias) + x

            x = rearrange(x, '(b h w) f c -> b c f h w', w = w, h = h)

        # feedforward with residual
        if self.has_feed_forward:
            x = self.ff(x, enable_time = enable_time) + x

        return x
# 定义 ResNet 块
class Block(nn.Module):
    """conv -> groupnorm -> optional FiLM scale/shift -> SiLU."""

    def __init__(
        self,
        dim,
        dim_out,
        kernel_size = 3,
        temporal_kernel_size = None,
        groups = 8
    ):
        super().__init__()
        # NOTE(review): kernel size is hard-coded to 3 here (the kernel_size
        # argument is unused), mirroring the original implementation
        self.project = PseudoConv3d(dim, dim_out, 3)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(
        self,
        x,
        scale_shift = None,
        enable_time = False
    ):
        x = self.project(x, enable_time = enable_time)
        x = self.norm(x)

        # FiLM-style conditioning derived from the timestep embedding
        if scale_shift is not None:
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        return self.act(x)
# 定义 ResNet 块
class ResnetBlock(nn.Module):
    """Two conv Blocks with optional timestep conditioning and a residual
    connection (1x1x1 conv when the channel count changes)."""

    def __init__(
        self,
        dim,
        dim_out,
        *,
        timestep_cond_dim = None,
        groups = 8
    ):
        super().__init__()
        self.timestep_mlp = None

        # projects the timestep embedding to per-channel (scale, shift)
        if timestep_cond_dim is not None:
            self.timestep_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(timestep_cond_dim, dim_out * 2)
            )

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        self.res_conv = PseudoConv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(
        self,
        x,
        timestep_emb = None,
        enable_time = True
    ):
        # conditioning must be supplied iff the block was built with it
        assert not ((timestep_emb is not None) ^ (self.timestep_mlp is not None))

        scale_shift = None

        if self.timestep_mlp is not None and timestep_emb is not None:
            time_emb = self.timestep_mlp(timestep_emb)
            # broadcast over (f,) h, w depending on image vs video input
            to_einsum_eq = 'b c 1 1 1' if x.ndim == 5 else 'b c 1 1'
            time_emb = rearrange(time_emb, f'b c -> {to_einsum_eq}')
            scale_shift = time_emb.chunk(2, dim = 1)

        # first block receives the conditioning; second runs unconditioned
        h = self.block1(x, scale_shift = scale_shift, enable_time = enable_time)
        h = self.block2(h, enable_time = enable_time)

        return h + self.res_conv(x)
# pixel-shuffle style up/downsampling where the time axis is configurable

class Downsample(nn.Module):
    """2x spatial and/or temporal downsampling via space/frame-to-depth
    rearranges followed by 1x1 convs."""

    def __init__(
        self,
        dim,
        downsample_space = True,
        downsample_time = False,
        nonlin = False
    ):
        super().__init__()
        assert downsample_space or downsample_time

        # spatial: fold each 2x2 patch into channels, then project back to `dim`
        self.down_space = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),
            nn.Conv2d(dim * 4, dim, 1, bias = False),
            nn.SiLU() if nonlin else nn.Identity()
        ) if downsample_space else None

        # temporal: fold pairs of frames into channels, then project back
        self.down_time = nn.Sequential(
            Rearrange('b c (f p) h w -> b (c p) f h w', p = 2),
            nn.Conv3d(dim * 2, dim, 1, bias = False),
            nn.SiLU() if nonlin else nn.Identity()
        ) if downsample_time else None

    def forward(
        self,
        x,
        enable_time = True
    ):
        # videos are (b, c, f, h, w); images are (b, c, h, w)
        is_video = x.ndim == 5

        # fold frames into batch so the 2D spatial path is shared with images
        if is_video:
            x = rearrange(x, 'b c f h w -> b f c h w')
            x, ps = pack([x], '* c h w')

        if exists(self.down_space):
            x = self.down_space(x)

        if is_video:
            x, = unpack(x, ps, '* c h w')
            x = rearrange(x, 'b f c h w -> b c f h w')

        # temporal downsampling only applies to videos, when present and enabled
        if not is_video or not exists(self.down_time) or not enable_time:
            return x

        x = self.down_time(x)
        return x
# upsampling module (pixel-shuffle style)

class Upsample(nn.Module):
    """2x spatial and/or temporal upsampling: 1x1 conv expands channels,
    then a rearrange shuffles the extra channels into space/time."""

    def __init__(
        self,
        dim,
        upsample_space = True,
        upsample_time = False,
        nonlin = False
    ):
        super().__init__()
        assert upsample_space or upsample_time

        # spatial: expand channels 4x, shuffle into a 2x2 spatial grid
        self.up_space = nn.Sequential(
            nn.Conv2d(dim, dim * 4, 1),
            nn.SiLU() if nonlin else nn.Identity(),
            Rearrange('b (c p1 p2) h w -> b c (h p1) (w p2)', p1 = 2, p2 = 2)
        ) if upsample_space else None

        # temporal: expand channels 2x, shuffle into pairs of frames
        self.up_time = nn.Sequential(
            nn.Conv3d(dim, dim * 2, 1),
            nn.SiLU() if nonlin else nn.Identity(),
            Rearrange('b (c p) f h w -> b c (f p) h w', p = 2)
        ) if upsample_time else None

        self.init_()

    def init_(self):
        # initialize the expanding convs (see init_conv_ below)
        if exists(self.up_space):
            self.init_conv_(self.up_space[0], 4)

        if exists(self.up_time):
            self.init_conv_(self.up_time[0], 2)

    def init_conv_(self, conv, factor):
        # draw a kaiming-initialized kernel at 1/factor of the output channels,
        # then repeat it `factor` times so all shuffled outputs start identical
        # (ICNR-style init — presumably to avoid checkerboard artifacts)
        o, *remain_dims = conv.weight.shape
        conv_weight = torch.empty(o // factor, *remain_dims)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o r) ...', r = factor)

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(
        self,
        x,
        enable_time = True
    ):
        is_video = x.ndim == 5

        # fold frames into batch so the 2D spatial path is shared with images
        if is_video:
            x = rearrange(x, 'b c f h w -> b f c h w')
            x, ps = pack([x], '* c h w')

        if exists(self.up_space):
            x = self.up_space(x)

        if is_video:
            x, = unpack(x, ps, '* c h w')
            x = rearrange(x, 'b f c h w -> b c f h w')

        # temporal upsampling only applies to videos, when present and enabled
        if not is_video or not exists(self.up_time) or not enable_time:
            return x

        x = self.up_time(x)
        return x
# space time factorized 3d unet
class SpaceTimeUnet(nn.Module):
    def __init__(
        self,
        *,
        dim,                                                # base channel dimension
        channels = 3,                                       # input / output channels
        dim_mult = (1, 2, 4, 8),                            # per-stage channel multipliers
        self_attns = (False, False, False, True),           # per-stage spatio-temporal self attention
        temporal_compression = (False, True, True, True),   # per-stage temporal downsampling
        resnet_block_depths = (2, 2, 2, 2),                 # resnet blocks per stage
        attn_dim_head = 64,                                 # attention head dimension
        attn_heads = 8,                                     # number of attention heads
        condition_on_timestep = True,                       # enable DDPM timestep conditioning
        attn_pos_bias = True,
        flash_attn = False,
        causal_time_attn = False
    ):
        super().__init__()
        assert len(dim_mult) == len(self_attns) == len(temporal_compression) == len(resnet_block_depths)
        num_layers = len(dim_mult)

        dims = [dim, *map(lambda mult: mult * dim, dim_mult)]
        dim_in_out = zip(dims[:-1], dims[1:])

        # determine the valid multiples of the image size and frames of the video
        self.frame_multiple = 2 ** sum(tuple(map(int, temporal_compression)))
        self.image_size_multiple = 2 ** num_layers

        # timestep conditioning for DDPM, not to be confused with the time dimension of the video
        self.to_timestep_cond = None
        timestep_cond_dim = (dim * 4) if condition_on_timestep else None

        if condition_on_timestep:
            self.to_timestep_cond = nn.Sequential(
                SinusoidalPosEmb(dim),
                nn.Linear(dim, timestep_cond_dim),
                nn.SiLU()
            )

        # down / up stages
        self.downs = mlist([])
        self.ups = mlist([])

        attn_kwargs = dict(
            dim_head = attn_dim_head,
            heads = attn_heads,
            pos_bias = attn_pos_bias,
            flash = flash_attn,
            causal_time_attn = causal_time_attn
        )

        mid_dim = dims[-1]

        # bottleneck: resnet -> spatio-temporal attention -> resnet
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim = timestep_cond_dim)
        self.mid_attn = SpatioTemporalAttention(dim = mid_dim, **attn_kwargs)
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim = timestep_cond_dim)

        for _, self_attend, (dim_in, dim_out), compress_time, resnet_block_depth in zip(range(num_layers), self_attns, dim_in_out, temporal_compression, resnet_block_depths):
            assert resnet_block_depth >= 1

            self.downs.append(mlist([
                ResnetBlock(dim_in, dim_out, timestep_cond_dim = timestep_cond_dim),
                mlist([ResnetBlock(dim_out, dim_out) for _ in range(resnet_block_depth)]),
                SpatioTemporalAttention(dim = dim_out, **attn_kwargs) if self_attend else None,
                Downsample(dim_out, downsample_time = compress_time)
            ]))

            self.ups.append(mlist([
                # dim_out * 2 / dim_in + dim_out input widths account for the
                # skip connections concatenated in forward below
                ResnetBlock(dim_out * 2, dim_in, timestep_cond_dim = timestep_cond_dim),
                mlist([ResnetBlock(dim_in + (dim_out if ind == 0 else 0), dim_in) for ind in range(resnet_block_depth)]),
                SpatioTemporalAttention(dim = dim_in, **attn_kwargs) if self_attend else None,
                Upsample(dim_out, upsample_time = compress_time)
            ]))

        self.skip_scale = 2 ** -0.5 # paper shows faster convergence

        self.conv_in = PseudoConv3d(dim = channels, dim_out = dim, kernel_size = 7, temporal_kernel_size = 3)
        self.conv_out = PseudoConv3d(dim = dim, dim_out = channels, kernel_size = 3, temporal_kernel_size = 3)

    def forward(
        self,
        x,
        timestep = None,
        enable_time = True
    ):
        # some asserts

        # timestep must be given iff the model was built with timestep conditioning
        assert not (exists(self.to_timestep_cond) ^ exists(timestep))
        is_video = x.ndim == 5

        if enable_time and is_video:
            frames = x.shape[2]
            assert divisible_by(frames, self.frame_multiple), f'number of frames on the video ({frames}) must be divisible by the frame multiple ({self.frame_multiple})'

        height, width = x.shape[-2:]
        assert divisible_by(height, self.image_size_multiple) and divisible_by(width, self.image_size_multiple), f'height and width of the image or video must be a multiple of {self.image_size_multiple}'

        # main logic

        # flatten timestep to a 1d batch before embedding
        t = self.to_timestep_cond(rearrange(timestep, '... -> (...)')) if exists(timestep) else None

        x = self.conv_in(x, enable_time = enable_time)

        # stack of skip activations; popped in reverse order on the up path
        hiddens = []

        for init_block, blocks, maybe_attention, downsample in self.downs:
            x = init_block(x, t, enable_time = enable_time)
            hiddens.append(x.clone())

            for block in blocks:
                x = block(x, enable_time = enable_time)

            if exists(maybe_attention):
                x = maybe_attention(x, enable_time = enable_time)

            hiddens.append(x.clone())

            x = downsample(x, enable_time = enable_time)

        x = self.mid_block1(x, t, enable_time = enable_time)
        x = self.mid_attn(x, enable_time = enable_time)
        x = self.mid_block2(x, t, enable_time = enable_time)

        for init_block, blocks, maybe_attention, upsample in reversed(self.ups):
            x = upsample(x, enable_time = enable_time)

            # concat the matching down-path activation, scaled by skip_scale
            x = torch.cat((hiddens.pop() * self.skip_scale, x), dim = 1)

            x = init_block(x, t, enable_time = enable_time)

            # second skip connection feeds the first resnet block of this stage
            x = torch.cat((hiddens.pop() * self.skip_scale, x), dim = 1)

            for block in blocks:
                x = block(x, enable_time = enable_time)

            if exists(maybe_attention):
                x = maybe_attention(x, enable_time = enable_time)

        x = self.conv_out(x, enable_time = enable_time)
        return x
.\lucidrains\make-a-video-pytorch\make_a_video_pytorch\__init__.py
# 从 make_a_video_pytorch.make_a_video 模块中导入 PseudoConv3d, SpatioTemporalAttention 类
from make_a_video_pytorch.make_a_video import PseudoConv3d, SpatioTemporalAttention
# 从 make_a_video_pytorch.make_a_video 模块中导入 ResnetBlock, Downsample, Upsample 类
from make_a_video_pytorch.make_a_video import ResnetBlock, Downsample, Upsample
# 从 make_a_video_pytorch.make_a_video 模块中导入 SpaceTimeUnet 类
from make_a_video_pytorch.make_a_video import SpaceTimeUnet

Make-A-Video - Pytorch (wip)
Implementation of Make-A-Video, new SOTA text to video generator from Meta AI, in Pytorch. They combine pseudo-3d convolutions (axial convolutions) and temporal attention and show much better temporal fusion.
Pseudo-3d convolutions aren't a new concept. They have been explored before in other contexts, for example for protein contact prediction as "dimensional hybrid residual networks".
The gist of the paper comes down to, take a SOTA text-to-image model (here they use DALL-E2, but the same learning points would easily apply to Imagen), make a few minor modifications for attention across time and other ways to skimp on the compute cost, do frame interpolation correctly, get a great video model out.
Appreciation
-
Stability.ai for the generous sponsorship to work on cutting edge artificial intelligence research
-
Jonathan Ho for bringing about a revolution in generative artificial intelligence through his seminal paper
-
Alex for einops, an abstraction that is simply genius. No other word for it.
Install
$ pip install make-a-video-pytorch
Usage
Passing in video features
import torch
from make_a_video_pytorch import PseudoConv3d, SpatioTemporalAttention
conv = PseudoConv3d(
dim = 256,
kernel_size = 3
)
attn = SpatioTemporalAttention(
dim = 256,
dim_head = 64,
heads = 8
)
video = torch.randn(1, 256, 8, 16, 16) # (batch, features, frames, height, width)
conv_out = conv(video) # (1, 256, 8, 16, 16)
attn_out = attn(video) # (1, 256, 8, 16, 16)
When passing in images (if one were to pretrain on images first), both temporal convolution and attention will be automatically skipped. In other words, you can use this straightforwardly in your 2d Unet and then port it over to a 3d Unet once that phase of the training is done. The temporal modules are initialized to output identity, as the paper had done.
import torch
from make_a_video_pytorch import PseudoConv3d, SpatioTemporalAttention
conv = PseudoConv3d(
dim = 256,
kernel_size = 3
)
attn = SpatioTemporalAttention(
dim = 256,
dim_head = 64,
heads = 8
)
images = torch.randn(1, 256, 16, 16) # (batch, features, height, width)
conv_out = conv(images) # (1, 256, 16, 16)
attn_out = attn(images) # (1, 256, 16, 16)
You can also control the two modules so that when fed 3-dimensional features, it only does training spatially
import torch
from make_a_video_pytorch import PseudoConv3d, SpatioTemporalAttention
conv = PseudoConv3d(
dim = 256,
kernel_size = 3
)
attn = SpatioTemporalAttention(
dim = 256,
dim_head = 64,
heads = 8
)
video = torch.randn(1, 256, 8, 16, 16) # (batch, features, frames, height, width)
# below it will not train across time
conv_out = conv(video, enable_time = False) # (1, 256, 8, 16, 16)
attn_out = attn(video, enable_time = False) # (1, 256, 8, 16, 16)
Full SpaceTimeUnet that is agnostic to images or video training, and where even if video is passed in, time can be ignored
import torch
from make_a_video_pytorch import SpaceTimeUnet
unet = SpaceTimeUnet(
dim = 64,
channels = 3,
dim_mult = (1, 2, 4, 8),
resnet_block_depths = (1, 1, 1, 2),
temporal_compression = (False, False, False, True),
self_attns = (False, False, False, True),
condition_on_timestep = False,
attn_pos_bias = False,
flash_attn = True
).cuda()
# train on images
images = torch.randn(1, 3, 128, 128).cuda()
images_out = unet(images)
assert images.shape == images_out.shape
# then train on videos
video = torch.randn(1, 3, 16, 128, 128).cuda()
video_out = unet(video)
assert video_out.shape == video.shape
# or even treat your videos as images
video_as_images_out = unet(video, enable_time = False)
Todo
-
give attention the best positional embeddings research has to offer
-
soup up the attention
-
add flash attention
-
make sure dalle2-pytorch can accept
`SpaceTimeUnet` for training
Citations
@misc{Singer2022,
author = {Uriel Singer},
url = {https://makeavideo.studio/Make-A-Video.pdf}
}
@inproceedings{rogozhnikov2022einops,
title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
author = {Alex Rogozhnikov},
booktitle = {International Conference on Learning Representations},
year = {2022},
url = {https://openreview.net/forum?id=oapKSVM2bcj}
}
@article{Dong2021AttentionIN,
title = {Attention is Not All You Need: Pure Attention Loses Rank Doubly Exponentially with Depth},
author = {Yihe Dong and Jean-Baptiste Cordonnier and Andreas Loukas},
journal = {ArXiv},
year = {2021},
volume = {abs/2103.03404}
}
@article{Zhang2021TokenST,
title = {Token Shift Transformer for Video Classification},
author = {Hao Zhang and Y. Hao and Chong-Wah Ngo},
journal = {Proceedings of the 29th ACM International Conference on Multimedia},
year = {2021}
}
@inproceedings{shleifer2022normformer,
title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
author = {Sam Shleifer and Myle Ott},
booktitle = {Submitted to The Tenth International Conference on Learning Representations },
year = {2022},
url = {https://openreview.net/forum?id=GMYWzWztDx5},
}
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
.\lucidrains\make-a-video-pytorch\setup.py
# 导入设置和查找包的函数
from setuptools import setup, find_packages
# Package metadata / build configuration for the make-a-video-pytorch distribution.
setup(
    name = 'make-a-video-pytorch', # distribution name on PyPI
    packages = find_packages(exclude=[]), # include every package in the repo
    version = '0.3.1',
    license='MIT',
    description = 'Make-A-Video - Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    long_description_content_type = 'text/markdown', # README is markdown
    url = 'https://github.com/lucidrains/make-a-video-pytorch',
    keywords = [ # PyPI search keywords
        'artificial intelligence',
        'deep learning',
        'attention mechanism',
        'text-to-video',
        'axial convolutions'
    ],
    install_requires=[ # runtime dependencies
        'classifier-free-guidance-pytorch',
        'einops>=0.6',
        'torch>=1.6',
    ],
    classifiers=[ # trove classifiers shown on PyPI
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\MaMMUT-pytorch\mammut_pytorch\mammut_pytorch.py
# 导入 torch 库
import torch
# 从 torch 库中导入 einsum, nn 模块
from torch import einsum, nn
# 从 torch 库中导入 F 模块
import torch.nn.functional as F
# 从 torch 库中导入 distributed 模块
import torch.distributed as dist
# 从 torch 库中导入 Function 模块
from torch.autograd import Function
# 从 einops 库中导入 rearrange, repeat 函数
from einops import rearrange, repeat
# helper functions

# a value "exists" unless it is the None sentinel
def exists(val):
    return not (val is None)
# return `val` when present, otherwise fall back to the default `d`
def default(val, d):
    if val is None:
        return d
    return val
# whether `numer` is an exact multiple of `denom`
def divisible_by(numer, denom):
    return numer % denom == 0
# distributed helpers

# pad a tensor along `dim` (appending zeros at the end) so its size there is `length`
def pad_dim_to(t, length, dim = 0):
    """Zero-pad tensor `t` at the trailing side of dimension `dim` up to `length`.

    F.pad expects (before, after) pairs starting from the LAST dimension, so we
    emit a (0, 0) pair for every dimension after `dim`, then pad `dim` itself.
    """
    pad_length = length - t.shape[dim]
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    # bugfix: the original return was missing its closing parenthesis (SyntaxError)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))
# gather tensors whose leading (batch) dimension may differ across ranks
def all_gather_variable_batch(t):
    """Variable-batch all_gather: pad each rank's tensor to the max batch size,
    gather, then strip the padding. Returns (gathered_tensor, per_rank_sizes).

    Requires torch.distributed to be initialized.
    """
    device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()

    # exchange every rank's batch size
    size = torch.tensor(t.shape[0], device = device, dtype = torch.long)
    sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)]
    dist.all_gather(sizes, size)

    sizes = torch.stack(sizes)
    max_size = sizes.amax().item()

    # pad to a common size, since all_gather requires equal shapes on all ranks
    padded_t = pad_dim_to(t, max_size, dim = 0)
    gathered_tensors = [torch.empty_like(padded_t, device = device, dtype = padded_t.dtype) for i in range(world_size)]
    dist.all_gather(gathered_tensors, padded_t)

    gathered_tensor = torch.cat(gathered_tensors)
    seq = torch.arange(max_size, device = device)

    # boolean mask selecting only the real (unpadded) rows from each rank
    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
    mask = rearrange(mask, 'i j -> (i j)')

    gathered_tensor = gathered_tensor[mask]
    sizes = sizes.tolist()

    return gathered_tensor, sizes
# autograd-aware all_gather: forward gathers across all ranks, backward routes
# each rank only its own slice of the gradient
class AllGather(Function):
    @staticmethod
    def forward(ctx, x):
        # only meaningful in a multi-process distributed run
        assert dist.is_initialized() and dist.get_world_size() > 1
        x, batch_sizes = all_gather_variable_batch(x)
        # remember per-rank batch sizes so backward can split the gradient
        ctx.batch_sizes = batch_sizes
        return x

    @staticmethod
    def backward(ctx, grads):
        batch_sizes, rank = ctx.batch_sizes, dist.get_rank()
        # split the incoming gradient by batch size and keep this rank's piece
        grads_by_rank = grads.split(batch_sizes, dim = 0)
        return grads_by_rank[rank]

# functional entry point for the custom AllGather
all_gather = AllGather.apply
# normalization

# layernorm with a learnable scale but a frozen zero bias,
# which pytorch's nn.LayerNorm does not offer directly
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # learnable per-feature scale; the shift stays fixed at zero (buffer)
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        normalized_shape = x.shape[-1:]
        return F.layer_norm(x, normalized_shape, weight=self.gamma, bias=self.beta)
# residual connection

# wrap `fn` so its output is added back onto its input (skip connection)
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        out = self.fn(x, *args, **kwargs)
        return out + x
# embeddings -> latents

# project embeddings into the shared latent space and l2-normalize them,
# so dot products between latents are cosine similarities
class EmbedToLatents(nn.Module):
    def __init__(self, dim, dim_latents):
        super().__init__()
        self.to_latents = nn.Linear(dim, dim_latents, bias=False)

    def forward(self, x):
        return F.normalize(self.to_latents(x), dim=-1)
# rotary positional embedding
# https://arxiv.org/abs/2104.09864

# precompute rotary position angles for a sequence of a given length
class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # geometric progression of inverse frequencies over even channel indices
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        positions = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        # outer product: one angle per (position, frequency) pair
        freqs = torch.outer(positions, self.inv_freq)
        # duplicate along the feature axis so both rotated halves share angles
        return torch.cat((freqs, freqs), dim=-1)
# rotate the last dimension in halves: (x1, x2) -> (-x2, x1)
def rotate_half(x):
    # split the last dim into two equal halves (same as the einops
    # '... (j d) -> ... j d' pattern with j = 2), then swap and negate
    first_half, second_half = x.unflatten(-1, (2, -1)).unbind(dim=-2)
    return torch.cat((-second_half, first_half), dim=-1)
# apply rotary position embedding angles `pos` to tensor `t`
def apply_rotary_pos_emb(pos, t):
    # inline rotate_half: split the last dim in halves and rotate (x1, x2) -> (-x2, x1)
    t1, t2 = t.unflatten(-1, (2, -1)).unbind(dim=-2)
    rotated = torch.cat((-t2, t1), dim=-1)
    return (t * pos.cos()) + (rotated * pos.sin())
# classic Noam Shazeer paper; SwiGLU used here instead of the more popular GEGLU
# for the gated feedforward
# https://arxiv.org/abs/2002.05202

# gated activation: silu(gate) * value, where value/gate are halves of the input
class SwiGLU(nn.Module):
    def forward(self, x):
        value, gate = x.chunk(2, dim=-1)
        return value * F.silu(gate)
# parallel attention and feedforward from one fused input projection;
# the residual is added by the Residual wrapper applied at the call site.
# attention is multi-query: multi-headed queries, one shared head of keys/values.
class ParallelTransformerBlock(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.norm = LayerNorm(dim)  # pre-layernorm

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        # fused projection widths: queries (all heads), key (1 head),
        # value (1 head), feedforward inner (x2 for the SwiGLU gate)
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        self.heads = heads
        self.scale = dim_head**-0.5
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)

        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        # caches for the causal mask and rotary embeddings
        # (plain attributes, deliberately not registered buffers)
        self.mask = None
        self.pos_emb = None

    def get_mask(self, n, device):
        # reuse the cached mask when it is large enough for this sequence
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n].to(device)

        # True above the diagonal = positions to mask out (causal)
        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.mask = mask
        return mask

    def get_rotary_embedding(self, n, device):
        # reuse cached rotary angles when long enough
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n].to(device)

        pos_emb = self.rotary_emb(n, device=device)
        self.pos_emb = pos_emb
        return pos_emb

    def forward(self, x, attn_mask=None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        n, device, h = x.shape[1], x.device, self.heads

        # pre layernorm
        x = self.norm(x)

        # attention queries, keys, values, and feedforward inner
        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # split heads (only queries are multi-headed; k/v are shared across heads)
        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # rotary embeddings applied to queries and keys
        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

        # scale
        q = q * self.scale

        # similarity (note: single-headed k, hence 'b j d')
        sim = einsum("b h i d, b j d -> b h i j", q, k)

        # causal mask
        causal_mask = self.get_mask(n, device)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # extra attention mask - for masking out attention from text CLS token to padding
        if exists(attn_mask):
            attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')
            sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)

        # attention
        attn = sim.softmax(dim=-1)

        # aggregate values
        out = einsum("b h i j, b j d -> b h i d", attn, v)

        # merge heads
        out = rearrange(out, "b h n d -> b n (h d)")
        # attention branch and feedforward branch summed (parallel formulation)
        return self.attn_out(out) + self.ff_out(ff)
# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward
class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        context_dim=None,   # dimension of the context sequence (defaults to dim)
        dim_head=64,
        heads=8,
        parallel_ff=False,  # add a feedforward computed in parallel with attention
        ff_mult=4,
        norm_context=False  # also layernorm the context before attending
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        context_dim = default(context_dim, dim)

        # pre-layernorm for queries; optional layernorm for the context
        self.norm = LayerNorm(dim)
        self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()

        # multi-headed queries, single shared head of key/value (multi-query)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # whether to have a parallel feedforward branch
        ff_inner_dim = ff_mult * dim

        self.ff = nn.Sequential(
            nn.Linear(dim, ff_inner_dim * 2, bias=False),
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        ) if parallel_ff else None

    def forward(self, x, context):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # pre-layernorm, for queries and context
        x = self.norm(x)
        context = self.context_norm(context)

        # get queries
        q = self.to_q(x)
        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        # scale
        q = q * self.scale

        # get key / values (single head, shared across query heads)
        k, v = self.to_kv(context).chunk(2, dim=-1)

        # query / key similarity
        sim = einsum('b h i d, b j d -> b h i j', q, k)

        # attention (no masking: every query may attend to all context tokens)
        attn = sim.softmax(dim=-1)

        # aggregate
        out = einsum('b h i j, b j d -> b h i d', attn, v)

        # merge and combine heads
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        # add parallel feedforward (for multimodal layers)
        if exists(self.ff):
            out = out + self.ff(x)

        return out
# MaMMUT: a single text decoder reused for both the contrastive and the
# captioning objective, cross-attending to attention-pooled image tokens
class MaMMUT(nn.Module):
    def __init__(
        self,
        *,
        dim,                        # transformer width
        num_tokens,                 # text vocabulary size
        depth,                      # number of transformer layers
        cross_attend_every=1,       # insert cross attention every n-th layer...
        cross_attend_layers=None,   # ...or at an explicit tuple of 1-indexed layers
        dim_latents=None,           # contrastive latent dimension (defaults to dim)
        image_dim=None,             # dimension of the incoming image tokens
        num_img_queries=256,        # learned queries for attention-pooling image tokens
        dim_head=64,
        heads=8,
        ff_mult=4,
        img_encoder=None,           # optional module mapping raw images -> image tokens
        caption_loss_weight=1.,
        contrastive_loss_weight=1.,
        pad_id=0                    # padding token id, ignored by the caption loss
    ):
        super().__init__()
        self.dim = dim

        self.pad_id = pad_id
        self.caption_loss_weight = caption_loss_weight
        self.contrastive_loss_weight = contrastive_loss_weight

        # token embeddings
        self.token_emb = nn.Embedding(num_tokens, dim)
        # learned CLS token, appended to the text sequence to produce the
        # contrastive text embedding
        self.text_cls_token = nn.Parameter(torch.randn(dim))

        # image encoder
        self.img_encoder = img_encoder

        # attention pooling for image tokens
        self.img_queries = nn.Parameter(torch.randn(num_img_queries + 1, dim)) # num image queries for multimodal, but 1 extra CLS for contrastive learning
        self.img_attn_pool = CrossAttention(dim=dim, context_dim=image_dim, dim_head=dim_head, heads=heads, norm_context=True)

        self.img_attn_pool_norm = LayerNorm(dim)
        self.text_cls_norm = LayerNorm(dim)

        # to latents
        dim_latents = default(dim_latents, dim)
        self.img_to_latents = EmbedToLatents(dim, dim_latents)
        self.text_to_latents = EmbedToLatents(dim, dim_latents)

        # contrastive learning temperature (learned, stored as log-ish scale;
        # used via temperature.exp() in forward)
        self.temperature = nn.Parameter(torch.Tensor([1.]))

        # layers
        self.layers = nn.ModuleList([])

        for ind in range(depth):
            layer = ind + 1

            # cross-attend either on a fixed cadence, or only at the
            # explicitly listed layers when cross_attend_layers is given
            has_cross_attn = divisible_by(layer, cross_attend_every)

            if exists(cross_attend_layers):
                assert isinstance(cross_attend_layers, tuple)
                has_cross_attn = layer in cross_attend_layers

            self.layers.append(nn.ModuleList([
                Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)),
                Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)) if has_cross_attn else None
            ]))

        # to logits
        self.to_logits = nn.Sequential(
            LayerNorm(dim),
            nn.Linear(dim, num_tokens, bias=False)
        )

        # they used embedding weight tied projection out to logits, not common, but works
        self.to_logits[-1].weight = self.token_emb.weight
        nn.init.normal_(self.token_emb.weight, std=0.02)

        # is data parallel
        self.is_data_parallel = dist.is_initialized() and dist.get_world_size() > 1

    def embed_text(self, text):
        """Run text through the decoder WITHOUT cross attention.

        Returns (text_embeds, text_tokens): the normalized CLS embedding used
        for contrastive learning, and the per-token hidden states.
        """
        batch, device = text.shape[0], text.device

        seq = text.shape[1]

        text_tokens = self.token_emb(text)

        # append text cls tokens
        text_cls_tokens = repeat(self.text_cls_token, 'd -> b 1 d', b=batch)
        text_tokens = torch.cat((text_tokens, text_cls_tokens), dim=-2)

        # create specific mask for text cls token at the end
        # to prevent it from attending to padding
        cls_mask = rearrange(text!=self.pad_id, 'b j -> b 1 j')
        attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)

        # go through layers, but do not cross attend
        for attn_ff, _ in self.layers:
            text_tokens = attn_ff(text_tokens, attn_mask=attn_mask)

        # get text cls token (last position) separated from the token states
        text_tokens, text_cls_tokens = text_tokens[:, :-1], text_tokens[:, -1]
        text_embeds = self.text_cls_norm(text_cls_tokens)
        return text_embeds, text_tokens

    def embed_image(self, images=None, image_tokens=None):
        # encode images into embeddings
        # with the img_encoder passed in at init
        # it can also accept precomputed image tokens

        assert not (exists(images) and exists(image_tokens))

        if exists(images):
            assert exists(self.img_encoder), 'img_encoder must be passed in for automatic image encoding'
            image_tokens = self.img_encoder(images)

        # attention-pool the image tokens down to the learned queries;
        # query 0 becomes the contrastive image embedding, the rest feed
        # the multimodal cross attention
        img_queries = repeat(self.img_queries, 'n d -> b n d', b=image_tokens.shape[0])
        img_queries = self.img_attn_pool(img_queries, image_tokens)
        img_queries = self.img_attn_pool_norm(img_queries)

        return img_queries[:, 0], img_queries[:, 1:]

    def forward(
        self,
        text,
        text_mask = None,  # NOTE(review): currently unused in this forward — confirm intent
        images=None,
        image_tokens=None,
        labels=None,
        return_loss=False,
        return_embeddings=False
    ):
        """Return logits, or (text, image) embeddings, or — when return_loss —
        the weighted sum of caption and contrastive losses."""
        batch, device = text.shape[0], text.device

        if return_loss and not exists(labels):
            # shift tokens for next-token prediction when labels are not given
            text, labels = text[:, :-1], text[:, 1:]

        text_embeds, _ = self.embed_text(text)

        image_embeds, image_tokens = self.embed_image(images=images, image_tokens=image_tokens)

        # return embeddings if the researcher only wants those
        if return_embeddings:
            return text_embeds, image_embeds

        # go through layers again, this time cross attending to image tokens
        text_tokens = self.token_emb(text)

        for attn_ff, cross_attn in self.layers:
            text_tokens = attn_ff(text_tokens)

            if exists(cross_attn):
                text_tokens = cross_attn(text_tokens, image_tokens)

        logits = self.to_logits(text_tokens)

        if not return_loss:
            return logits

        # shorthand
        ce = F.cross_entropy

        # calculate caption loss (cross entropy over next-token predictions)
        logits = rearrange(logits, 'b n c -> b c n')
        caption_loss = ce(logits, labels, ignore_index=self.pad_id)
        caption_loss = caption_loss * self.caption_loss_weight

        # embed into the shared latent space
        text_latents = self.text_to_latents(text_embeds)
        image_latents = self.img_to_latents(image_embeds)

        # if using data parallel, gather latents from all machines so the
        # contrastive loss sees the full effective batch
        if self.is_data_parallel:
            latents = torch.stack((text_latents, image_latents), dim = 1)
            latents = all_gather(latents)
            text_latents, image_latents = latents.unbind(dim = 1)

        # calculate contrastive loss (symmetric, text->image and image->text)
        sim = einsum('i d, j d -> i j', text_latents, image_latents)
        sim = sim * self.temperature.exp()
        contrastive_labels = torch.arange(batch, device=device)

        contrastive_loss = (ce(sim, contrastive_labels) + ce(sim.t(), contrastive_labels)) * 0.5
        contrastive_loss = contrastive_loss * self.contrastive_loss_weight

        return caption_loss + contrastive_loss