Lucidrains 系列项目源码解析(七十四)
.\lucidrains\performer-pytorch\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# 设置包的元数据
setup(
# 包的名称
name = 'performer-pytorch',
# 查找并包含除了'examples'之外的所有包
packages = find_packages(exclude=['examples']),
# 版本号
version = '1.1.4',
# 许可证
license='MIT',
# 描述
description = 'Performer - Pytorch',
# 作者
author = 'Phil Wang',
# 作者邮箱
author_email = 'lucidrains@gmail.com',
# 项目链接
url = 'https://github.com/lucidrains/performer-pytorch',
# 关键词列表
keywords = [
'artificial intelligence',
'attention mechanism',
'efficient attention',
'transformers'
],
# 安装依赖
install_requires=[
'einops>=0.3',
'local-attention>=1.1.1',
'torch>=1.6',
'axial-positional-embedding>=0.1.0'
],
# 分类标签
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\perfusion-pytorch\perfusion_pytorch\embedding.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn, Tensor
from torch import nn, Tensor
# 从 torch.nn 库中导入 Module
from torch.nn import Module
# 从 collections 库中导入 namedtuple
from collections import namedtuple
# 从 beartype 库中导入 beartype
from beartype import beartype
# 从 beartype.door 库中导入 is_bearable
from beartype.door import is_bearable
# 从 beartype.typing 库中导入 Optional, Tuple, Union, Callable, List
from beartype.typing import Optional, Tuple, Union, Callable, List
# 从 einops 库中导入 rearrange
from einops import rearrange
# 从 open_clip 库中导入 tokenizer
from open_clip import tokenizer
# 定义常量 EmbeddingReturn 为一个命名元组,包含 'embed_with_concept', 'embed_with_superclass', 'embed_mask', 'concept_indices' 四个字段
EmbeddingReturn = namedtuple('EmbeddingReturn', [
'embed_with_concept',
'embed_with_superclass',
'embed_mask',
'concept_indices'
])
# 定义辅助函数
# 判断值是否存在
def exists(val):
return val is not None
# 返回默认值
def default(val, d):
return val if exists(val) else d
# 判断列表中元素是否全部唯一
def is_all_unique(arr):
return len(set(arr)) == len(arr)
# 根据给定的索引过滤元组中的元素
def filter_tuple_indices(tup, indices):
return tuple(tup[i] for i in indices)
# 根据给定的 ids 创建一个 mask
@beartype
def get_mask(
x: Tensor,
ids: Tuple[int, ...]
):
masks = tuple(x == i for i in ids)
mask, *rest_masks = masks
for rest_mask in rest_masks:
mask = mask | rest_mask
return mask
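# a small illustration, not part of the source: get_mask simply ORs together one equality mask per id
# the ids below are made up for demonstration (0 = pad, 49406 / 49407 = sos / eos in the CLIP tokenizer)
example_ids = torch.tensor([49406, 1125, 49407, 0, 0])
example_mask = get_mask(example_ids, (0, 49406, 49407))
print(example_mask)                 # tensor([ True, False,  True,  True,  True])
print(example_ids[~example_mask])   # tensor([1125]) - only the content token survives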
# 嵌入包装类
class EmbeddingWrapper(Module):
# 初始化函数
@beartype
def __init__(
self,
embed: nn.Embedding,
num_concepts = 1,
superclass_embed_id: Optional[Union[int, Tuple[int, ...]]] = None,
superclass_string: Optional[str] = None,
tokenize: Callable[[List[str]], Tensor] = tokenizer.tokenize,
tokenizer_pad_id: int = 0,
tokenizer_sos_eos_id: Tuple[int, int] = (49406, 49407)
):
super().__init__()
self.embed = embed
num_embeds, dim = embed.weight.shape
self.num_embeds = num_embeds
self.num_concepts = num_concepts
self.concepts = nn.Parameter(torch.zeros(num_concepts, dim))
assert not (exists(superclass_embed_id) and exists(superclass_string)), 'either superclass embed id is given, or the superclass string'
self.pad_id = tokenizer_pad_id
self.tokenize = None
if exists(superclass_string):
self.tokenize = tokenize
ids = tokenize([superclass_string])[0]
mask_for_ids = get_mask(ids, (tokenizer_pad_id, *tokenizer_sos_eos_id))
ids = ids[~mask_for_ids]
assert ids.shape[-1] == 1, f'your superclass concept string must map exactly one token id'
superclass_embed_id = ids[0].item()
print(f'super class embed for "{superclass_string}"" set as {superclass_embed_id}')
print(f'you can now pass in a list of strings containing superclass concept, and this wrapper will return the embedding w/ concept and superclass required for finetuning')
self.superclass_embed_id = superclass_embed_id
assert not (exists(superclass_embed_id) and num_concepts > 1), 'cannot do multi concept with superclass embed id given'
if exists(superclass_embed_id):
# 作者发现将概念嵌入初始化为超类嵌入会获得更好的结果,允许这种选项
if not isinstance(superclass_embed_id, tuple):
superclass_embed_id = (superclass_embed_id,)
superclass_embed_indices = torch.tensor(list(superclass_embed_id))
superclass_embeds = embed(superclass_embed_indices)
self.concepts.data.copy_(superclass_embeds)
else:
# 否则初始化为通常用于嵌入的小初始化值
nn.init.normal_(self.concepts, std = 0.02)
self.concept_embed_ids = tuple(range(num_embeds, num_embeds + num_concepts))
# 返回参数
def parameters(self):
return [self.concepts]
# 返回设备
@property
def device(self):
return self.concepts.device
# 前向传播函数
@beartype
def forward(
self,
x: Union[Tensor, List[str]],
concept_id: Optional[Union[int, Tuple[int, ...]]] = None,
return_embed_with_superclass = True,
clip_transformer_fn: Optional[Callable[[Tensor], Tensor]] = None
):
...  # forward body not shown in this excerpt
# 一个用于 CLIP 的包装器
# 自动将令牌嵌入与新概念包装在一起
# 定义一个类 OpenClipEmbedWrapper,用于包装 CLIP 模型的嵌入层,并在前向传播中通过文本转换器和最终层归一化层传递概念嵌入和超类概念嵌入
# 同时,将 ids 和 superclass_ids 通过修改后的文本编码器传递两次(将尝试用 nn.Identity 替换 nn.Embedding)
class OpenClipEmbedWrapper(Module):
@beartype
def __init__(
self,
clip: Module,
text_transformer_path = 'transformer',
ln_final_path = 'ln_final', # 在 CLIP 中,最终的层归一化层与转换器分开
**embedding_wrapper_kwargs
):
super().__init__()
# 创建一个嵌入层包装器,用于包装 CLIP 模型的 token 嵌入
self.wrapped_embed = EmbeddingWrapper(clip.token_embedding, **embedding_wrapper_kwargs)
# 获取 CLIP 模型中各模块的路径和模块对象的字典
path_to_modules = dict([(path, mod) for path, mod in clip.named_modules()])
# 确保文本转换器路径在路径字典中
assert text_transformer_path in path_to_modules
# 获取文本转换器和最终层归一化层(如果存在)
text_transformer = path_to_modules[text_transformer_path]
ln_final = path_to_modules.get(ln_final_path, nn.Identity())
# 将文本转换器和最终层归一化层组合成一个序列
self.text_transformer = nn.Sequential(
text_transformer,
ln_final
)
# forward pass: takes the input x and keyword arguments, returns an EmbeddingReturn namedtuple
def forward(
self,
x,
**kwargs
) -> EmbeddingReturn:
# 通过嵌入层包装器获取文本嵌入、超类文本嵌入、文本掩码和概念索引
text_embeds, superclass_text_embeds, text_mask, concept_indices = self.wrapped_embed(x, **kwargs)
# 将文本嵌入传递给文本转换器
text_enc = self.text_transformer(text_embeds)
superclass_text_enc = None
# 如果超类文本嵌入存在,则将其传递给文本转换器
if exists(superclass_text_embeds):
superclass_text_enc = self.text_transformer(superclass_text_embeds)
# 返回嵌入返回对象,包括文本嵌入、超类文本嵌入、文本掩码和概念索引
return EmbeddingReturn(text_enc, superclass_text_enc, text_mask, concept_indices)
# 将多个嵌入层包装器(每个具有一个概念)合并为一个具有多个概念的合并嵌入层包装器
@beartype
def merge_embedding_wrappers(
*embeds: EmbeddingWrapper
) -> EmbeddingWrapper:
# 计算总概念数
total_concepts = sum([embed.num_concepts for embed in embeds])
# 确保所有嵌入层的权重形状相同
assert len(set([tuple(embed.embed.weight.shape) for embed in embeds])) == 1
# 获取第一个嵌入层的嵌入
embed = embeds[0].embed
# 创建一个合并的嵌入层包装器,包括总概念数
merged_concepts = EmbeddingWrapper(
embed = embed,
num_concepts = total_concepts
)
# 将合并的嵌入层包装器设置为评估模式
merged_concepts.eval()
# 将所有嵌入层的概念连接起来
concepts = torch.cat(tuple(embed.concepts.data for embed in embeds), dim = 0)
# 将连接后的概念设置为合并的嵌入层包装器的概念
merged_concepts.concepts = nn.Parameter(concepts)
# 返回合并的嵌入层包装器
return merged_concepts
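# a hedged usage sketch, not in the source: merging two separately created wrappers into one holding both concepts
# the 49408 x 512 embedding size mirrors the readme example; 'dog' / 'cat' are illustrative superclass strings
base_embed = nn.Embedding(49408, 512)
wrapper_dog = EmbeddingWrapper(base_embed, superclass_string = 'dog')
wrapper_cat = EmbeddingWrapper(base_embed, superclass_string = 'cat')
merged = merge_embedding_wrappers(wrapper_dog, wrapper_cat)
assert merged.num_concepts == 2 and merged.concepts.shape == (2, 512)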
.\lucidrains\perfusion-pytorch\perfusion_pytorch\open_clip.py
# 导入必要的库
from beartype import beartype
from beartype.typing import List, Optional
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange
import open_clip
# 定义一个函数,用于检查值是否存在
def exists(val):
return val is not None
# 定义一个函数,用于对张量进行 L2 归一化
def l2norm(t):
return F.normalize(t, dim = -1)
# 定义一个类,用于适配 OpenCLIP 模型
class OpenClipAdapter(nn.Module):
@beartype
def __init__(
self,
name = 'ViT-B/32',
pretrained = 'laion400m_e32',
tokenizer_name = 'ViT-B-32-quickgelu',
eos_id = 49407
):
super().__init__()
# 创建 OpenCLIP 模型、预处理函数和 tokenizer
clip, _, preprocess = open_clip.create_model_and_transforms(name, pretrained = pretrained)
tokenizer = open_clip.get_tokenizer(tokenizer_name)
self.clip = clip
self.tokenizer = tokenizer
self.eos_id = eos_id
# 用于获取最终文本表示的钩子
text_attention_final = self.find_layer('ln_final')
self._dim_latent = text_attention_final.weight.shape[0]
self.text_handle = text_attention_final.register_forward_hook(self._text_hook)
# 标准化函数
self.clip_normalize = preprocess.transforms[-1]
self.cleared = False
@property
def device(self):
return next(self.parameters()).device
# 查找指定层
def find_layer(self, layer):
modules = dict([*self.clip.named_modules()])
return modules.get(layer, None)
# 清除钩子
def clear(self):
if self.cleared:
return
self.text_handle()
# 文本钩子函数
def _text_hook(self, _, inputs, outputs):
self.text_encodings = outputs
@property
def dim_latent(self):
return self._dim_latent
@property
def max_text_len(self):
return self.clip.positional_embedding.shape[0]
@beartype
def embed_texts(
self,
texts: List[str]
):
# 对文本进行编码
ids = self.tokenizer(texts)
ids = ids.to(self.device)
ids = ids[..., :self.max_text_len]
is_eos_id = (ids == self.eos_id)
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
text_mask = text_mask & (ids != 0)
assert not self.cleared
# 编码文本并进行掩码
text_embed = self.clip.encode_text(ids)
text_encodings = self.text_encodings
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
return text_encodings.float(), text_mask
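# an illustrative use of the adapter, not shown in the source (downloads pretrained weights on first run);
# for the default ViT-B/32 configuration the encodings come back as (batch, 77, 512) with a (batch, 77) mask
clip_adapter = OpenClipAdapter()
encodings, mask = clip_adapter.embed_texts(['a photo of a dog', 'a cat sitting on a mat'])
print(encodings.shape, mask.shape)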
.\lucidrains\perfusion-pytorch\perfusion_pytorch\optimizer.py
# 从 torch.nn 模块中导入 Module 类
# 从 torch.optim 模块中导入 AdamW、Adam、Optimizer 类
from torch.nn import Module
from torch.optim import AdamW, Adam, Optimizer
# 从 beartype 模块中导入 beartype 装饰器
from beartype import beartype
# 从 perfusion_pytorch.embedding 模块中导入 EmbeddingWrapper 类
# 从 perfusion_pytorch.perfusion 模块中导入 Rank1EditModule 类
from perfusion_pytorch.embedding import EmbeddingWrapper
from perfusion_pytorch.perfusion import Rank1EditModule
# 定义一个函数,用于自动查找微调所需的所有参数
@beartype
def get_finetune_parameters(text_image_model: Module):
# 初始化参数列表
params = []
# 遍历 text_image_model 模块中的所有子模块
for module in text_image_model.modules():
# 如果子模块是 EmbeddingWrapper 或 Rank1EditModule 类型
if isinstance(module, (EmbeddingWrapper, Rank1EditModule)):
# 将子模块的参数添加到参数列表中
params.extend(module.parameters())
# 返回参数列表
return params
# 定义一个函数,用于获取微调优化器
@beartype
def get_finetune_optimizer(
text_image_model: Module,
lr = 1e-4,
wd = 1e-2,
betas = (0.9, 0.99),
eps = 1e-8,
**kwargs
) -> Optimizer:
# 获取微调所需的参数
params = get_finetune_parameters(text_image_model)
# 断言参数列表长度大于0,否则抛出异常
assert len(params) > 0, 'no finetuneable parameters found'
# 计算总参数数量
total_params = sum([p.numel() for p in params])
# 打印优化的参数数量
print(f'optimizing {total_params} parameters')
# 判断是否有权重衰减
has_weight_decay = wd > 0
# 根据是否有权重衰减选择 AdamW 或 Adam 类
adam_klass = AdamW if has_weight_decay else Adam
# 初始化 Adam 的参数
adam_kwargs = dict(lr = lr, betas = betas, eps = eps)
# 如果有权重衰减,则更新参数字典
if has_weight_decay:
adam_kwargs.update(weight_decay = wd)
# 返回根据参数和参数字典初始化的优化器
return adam_klass(params, **adam_kwargs, **kwargs)
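# a minimal usage sketch, not from the source: a toy module containing only a wrapped embedding
# (the name ToyTextImageModel and the superclass id 1929 are illustrative assumptions)
from torch import nn

class ToyTextImageModel(Module):
    def __init__(self):
        super().__init__()
        self.wrapped_embed = EmbeddingWrapper(nn.Embedding(49408, 512), superclass_embed_id = 1929)

opt = get_finetune_optimizer(ToyTextImageModel(), lr = 1e-4, wd = 1e-2)
# only the concept embedding (and any Rank1EditModule parameters) end up in the optimizer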
.\lucidrains\perfusion-pytorch\perfusion_pytorch\perfusion.py
# 从 math 模块中导入 ceil 函数
# 从 copy 模块中导入 deepcopy 函数
# 从 pathlib 模块中导入 Path 类
# 从 beartype 模块中导入 beartype 装饰器
# 从 beartype.typing 模块中导入 Union, List, Optional, Tuple 类型
# 从 torch 模块中导入 nn, einsum, Tensor 类
# 从 torch.nn 模块中导入 Module 类
# 从 torch.nn.functional 模块中导入 F 函数
# 从 einops 模块中导入 rearrange, reduce 函数
# 从 opt_einsum 模块中导入 contract 函数
# 从 perfusion_pytorch.open_clip 模块中导入 OpenClipAdapter 类
from math import ceil
from copy import deepcopy
from pathlib import Path
from beartype import beartype
from beartype.typing import Union, List, Optional, Tuple
import torch
from torch import nn, einsum, Tensor
from torch.nn import Module
import torch.nn.functional as F
from einops import rearrange, reduce
from opt_einsum import contract as opt_einsum
from perfusion_pytorch.open_clip import OpenClipAdapter
# 预先计算的协方差路径
# 如果论文验证通过,将为更多模型添加
CURRENT_DIR = Path(__file__).parents[0]
DATA_DIR = CURRENT_DIR / 'data'
assert DATA_DIR.is_dir()
COVARIANCE_FILENAME_BY_TEXT_IMAGE_MODEL = dict(
SD15 = DATA_DIR / 'covariance_CLIP_ViT-L-14.pt'
)
assert all([filepath.exists() for filepath in COVARIANCE_FILENAME_BY_TEXT_IMAGE_MODEL.values()])
# 辅助函数
def exists(val):
return val is not None
def is_all_unique(arr):
return len(set(arr)) == len(arr)
# 用于计算 C - 输入协方差的函数
@beartype
@torch.no_grad()
def calculate_input_covariance(
clip: OpenClipAdapter,
texts: List[str],
batch_size = 32,
**cov_kwargs
):
num_batches = ceil(len(texts) / batch_size)
all_embeds = []
length = len(texts)
for batch_ind in range(num_batches):
start_index = batch_ind * batch_size
batch_texts = texts[start_index:(start_index + batch_size)]
embeds, mask = clip.embed_texts(batch_texts)
all_embeds.append(embeds[mask])
all_embeds = torch.cat(all_embeds, dim = 0)
return einsum('n d, n e -> d e', all_embeds, all_embeds) / length
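# a shape sanity check, not from the source: the result is a d x d (uncentered) second moment of the
# unmasked token embeddings (the function above divides by the number of texts rather than tokens)
example_embeds = torch.randn(1000, 768)
example_C = einsum('n d, n e -> d e', example_embeds, example_embeds) / 1000
print(example_C.shape)   # torch.Size([768, 768])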
# 由掩码加权的损失函数
@beartype
def loss_fn_weighted_by_mask(
pred: Tensor,
target: Tensor,
mask: Tensor,
normalized_mask_min_value = 0.
):
assert mask.shape[-2:] == pred.shape[-2:] == target.shape[-2:]
assert mask.shape[0] == pred.shape[0] == target.shape[0]
assert (mask.amin() >= 0.).all(), 'mask should not have values below 0'
if mask.ndim == 4:
assert mask.shape[1] == 1
mask = rearrange(mask, 'b 1 h w -> b h w')
loss = F.mse_loss(pred, target, reduction = 'none')
loss = reduce(loss, 'b c h w -> b h w', 'mean')
# 通过最大值对掩码进行归一化
normalized_mask = mask / mask.amax(dim = -1, keepdim = True).clamp(min = 1e-5)
normalized_mask = normalized_mask.clamp(min = normalized_mask_min_value)
loss = loss * normalized_mask
return loss.mean()
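# an illustrative call, not from the source, just to show the expected shapes
example_pred   = torch.randn(2, 3, 64, 64, requires_grad = True)
example_target = torch.randn(2, 3, 64, 64)
example_mask   = torch.rand(2, 1, 64, 64)   # e.g. a soft concept mask in [0, 1]
example_loss   = loss_fn_weighted_by_mask(example_pred, example_target, example_mask)
example_loss.backward()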
# 一个模块,包装了交叉注意力的键和值投影到文本编码
class Rank1EditModule(Module):
@beartype
def __init__(
self,
key_or_values_proj: nn.Linear,
*,
num_concepts: int = 1,
C: Optional[Tensor] = None, # 输入的协方差,从 100K laion 文本中预先计算
default_model = 'SD15',
text_seq_len: int = 77,
is_key_proj: bool = False,
input_decay = 0.99,
train_beta = 0.75,
train_temperature = 0.1,
eval_beta = 0.70, # 在论文中,指定了本地键锁定的范围 (0.6 - 0.75),全局键锁定的范围 (0.4 -0.6)
eval_temperature = 0.15,
frac_gradient_concept_embed = 0.1, # 他们使用一个较慢的学习率来嵌入 - 这可以通过一个技巧来减少反向传播的梯度
multi_concepts_use_cholesky = False # 对于多个概念,使用一种不需要 Cholesky 根的近似技术
):
# 调用父类的构造函数
super().__init__()
# 断言在注意力中的键值投影不应该有偏置
assert not exists(key_or_values_proj.bias), 'key value projection in attention should not have bias'
# 初始化注意力模块的参数
self.num_concepts = num_concepts
self.multi_concepts_use_cholesky = multi_concepts_use_cholesky
# 获取键值投影的权重
self.weight = key_or_values_proj.weight
dim_output, dim_input = self.weight.shape
# 设置训练和评估时的温度和 beta 参数
self.train_beta = train_beta
self.train_temperature = train_temperature
self.eval_beta = eval_beta
self.eval_temperature = eval_temperature
# 输入的衰减参数
self.input_decay = input_decay
# 文本序列的长度
self.text_seq_len = text_seq_len
# 降低概念嵌入学习率的参数
assert 0 < frac_gradient_concept_embed <= 1.
self.frac_gradient_concept_embed = frac_gradient_concept_embed
# 初始化概念文本嵌入的指数平滑参数
self.register_buffer('initted', torch.zeros(num_concepts, 1).bool())
self.register_buffer('ema_concept_text_encs', torch.zeros(num_concepts, dim_input))
# 概念输出 - 仅优化值,而不是键
self.is_key_proj = is_key_proj # 锁定输出到超类,并关闭梯度
self.concept_outputs = nn.Parameter(torch.zeros(num_concepts, dim_output), requires_grad = not is_key_proj)
# 输入协方差 C 的逆矩阵,如果未传入协方差,则使用默认值
if not exists(C):
covariance_filepath = COVARIANCE_FILENAME_BY_TEXT_IMAGE_MODEL.get(default_model, None)
assert exists(covariance_filepath), f'{default_model} not found in the list of precomputed covariances {tuple(COVARIANCE_FILENAME_BY_TEXT_IMAGE_MODEL.keys())}'
C = torch.load(str(covariance_filepath))
print(f'precomputed covariance loaded from {str(covariance_filepath)}')
# 计算 C_inv
C_inv = torch.inverse(C)
self.register_buffer('C_inv', C_inv)
@property
def num_concepts(self):
return self._num_concepts
@num_concepts.setter
def num_concepts(self, value):
self._num_concepts = value
if value == 1 or not self.multi_concepts_use_cholesky:
return
# 对于多个概念,需要 cholesky 分解 L_t_inv
try:
L = torch.linalg.cholesky(self.C_inv)
except:
print('unable to perform cholesky. please make sure input covariance matrix is properly calculated')
exit()
L_T = L.T
L_T_inv = torch.inverse(L_T)
self.register_buffer('L_T', L_T, persistent = False)
self.register_buffer('L_T_inv', L_T_inv, persistent = False)
@property
def device(self):
return next(self.buffers()).device
# 返回参数
def parameters(self):
if not self.is_key_proj:
return []
return [self.concept_outputs]
@beartype
def forward(
self,
text_enc: Tensor,
*,
concept_indices: Optional[Tensor] = None,
text_enc_with_superclass: Optional[Tensor] = None,
concept_id: Union[int, Tuple[int, ...]] = 0
):
...  # forward body not shown in this excerpt
# 合并已训练的 Rank1EditModule(s) 的函数
@beartype
def merge_rank1_edit_modules(
*modules: Rank1EditModule, # 接受多个 Rank1EditModule 参数
use_cholesky = False # 是否使用 Cholesky 分解,默认为 False
) -> Rank1EditModule: # 返回合并后的 Rank1EditModule 对象
# 断言所有模块都已初始化并最好已训练
assert all([m.initted.all() for m in modules]), 'all modules must be initialized and ideally trained'
# 断言概念输出维度必须相同
assert len(set([m.concept_outputs.shape[-1] for m in modules])) == 1, 'concept output dimension must be the same'
# 断言所有模块必须为键或值。不能将键和值的 Rank1EditModule 合并在一起
assert len(set([m.is_key_proj for m in modules])) == 1, 'all modules must be either for keys, or values. you cannot merge rank 1 edit modules of keys and values together'
# 获取第一个模块
first_module = modules[0]
# 深拷贝第一个模块
merged_module = deepcopy(first_module)
# 设置是否使用 Cholesky 分解
merged_module.multi_concepts_use_cholesky = use_cholesky
# 计算总概念数
total_concepts = sum([m.num_concepts for m in modules])
merged_module.num_concepts = total_concepts
# 拼接所有模块的概念输出
concept_outputs = torch.cat(tuple(m.concept_outputs.data for m in modules), dim = 0)
merged_module.concept_outputs = nn.Parameter(concept_outputs, requires_grad = not first_module.is_key_proj)
# 拼接所有模块的 EMA 概念文本编码
ema_concept_text_encs = torch.cat(tuple(m.ema_concept_text_encs.data for m in modules), dim = 0)
merged_module.register_buffer('ema_concept_text_encs', ema_concept_text_encs)
# 注册初始化状态
merged_module.register_buffer('initted', torch.ones(total_concepts, 1).bool())
# 返回合并后的模块
return merged_module
# 用于连接交叉注意力的函数
@beartype
def make_key_value_proj_rank1_edit_modules_(
cross_attention: nn.Module, # 交叉注意力模块
*,
input_covariance: Tensor, # 输入协方差
key_proj_name: str, # 键投影名称
value_proj_name: str, # 值投影名称
**rank1_edit_module_kwargs # Rank1EditModule 的其他参数
):
# 获取键投影和值投影线性层
linear_key = getattr(cross_attention, key_proj_name, None)
linear_values = getattr(cross_attention, value_proj_name, None)
# 断言键投影和值投影必须是 nn.Linear 类型
assert isinstance(linear_key, nn.Linear), f'{key_proj_name} must point to where the keys projection is (ex. self.to_keys = nn.Linear(in, out, bias = False) -> key_proj_name = "to_keys")'
assert isinstance(linear_values, nn.Linear), f'{value_proj_name} must point to where the values projection is (ex. self.to_values = nn.Linear(in, out, bias = False) -> value_proj_name = "to_values")'
# 创建键和值的 Rank1EditModule
rank1_edit_module_keys = Rank1EditModule(linear_key, C = input_covariance, is_key_proj = True, **rank1_edit_module_kwargs)
rank1_edit_module_values = Rank1EditModule(linear_values, C = input_covariance, is_key_proj = False, **rank1_edit_module_kwargs)
# 将 Rank1EditModule 设置为键投影和值投影
setattr(cross_attention, key_proj_name, rank1_edit_module_keys)
setattr(cross_attention, value_proj_name, rank1_edit_module_values)
.\lucidrains\perfusion-pytorch\perfusion_pytorch\save_load.py
# 导入所需的模块
from pathlib import Path
import torch
from torch import nn
from torch.nn import Module
from beartype import beartype
from perfusion_pytorch.embedding import EmbeddingWrapper
from perfusion_pytorch.perfusion import Rank1EditModule
# 辅助函数
# 检查值是否存在
def exists(val):
return val is not None
# 保存和加载必要的额外微调参数
# 保存函数,将模型的参数保存到指定路径
@beartype
def save(
text_image_model: Module,
path: str
):
# 将路径转换为 Path 对象
path = Path(path)
# 创建路径的父目录,如果不存在则创建
path.parents[0].mkdir(exist_ok=True, parents=True)
embed_params = None
key_value_params = []
C_inv = None
# 遍历模型的所有模块
for module in text_image_model.modules():
# 如果模块是 EmbeddingWrapper 类型
if isinstance(module, EmbeddingWrapper):
# 确保只有一个包装的 EmbeddingWrapper
assert not exists(embed_params), 'there should only be one wrapped EmbeddingWrapper'
embed_params = module.concepts.data
# 如果模块是 Rank1EditModule 类型
elif isinstance(module, Rank1EditModule):
# 将模块的参数添加到列表中
key_value_params.append([
module.ema_concept_text_encs.data,
module.concept_outputs.data
])
C_inv = module.C_inv.data
# 确保 C_inv 参数存在
assert exists(C_inv), 'Rank1EditModule not found. you likely did not wire up the text to image model correctly'
# 将参数打包成字典
pkg = dict(
embed_params=embed_params,
key_value_params=key_value_params,
C_inv=C_inv
)
# 保存参数到指定路径
torch.save(pkg, f'{str(path)}')
print(f'saved to {str(path)}')
# 加载函数,从指定路径加载参数到模型
@beartype
def load(
text_image_model: Module,
path: str
):
# 将路径转换为 Path 对象
path = Path(path)
# 检查文件是否存在
assert path.exists(), f'file not found at {str(path)}'
# 加载保存的参数
pkg = torch.load(str(path))
embed_params = pkg['embed_params']
key_value_params = pkg['key_value_params']
C_inv = pkg['C_inv']
# 遍历模型的所有模块
for module in text_image_model.modules():
# 如果模块是 EmbeddingWrapper 类型
if isinstance(module, EmbeddingWrapper):
# 将加载的参数复制到模块中
module.concepts.data.copy_(embed_params)
# 如果模块是 Rank1EditModule 类型
elif isinstance(module, Rank1EditModule):
# 确保保存的参数和加载的参数匹配
assert len(key_value_params) > 0, 'mismatch between what was saved vs what is being loaded'
concept_input, concept_output = key_value_params.pop(0)
module.ema_concept_text_encs.data.copy_(concept_input)
module.concept_outputs.data.copy_(concept_output)
module.C_inv.copy_(C_inv)
module.initted.copy_(torch.tensor([True]))
print(f'loaded concept params from {str(path)}')
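# a hedged end-to-end sketch, not from the source; ToySD, the id 1929 and the checkpoint path are illustrative
# stand-ins for a properly wired up text-to-image model containing an EmbeddingWrapper and Rank1EditModule(s)
class ToySD(Module):
    def __init__(self):
        super().__init__()
        self.wrapped_embed = EmbeddingWrapper(nn.Embedding(49408, 768), superclass_embed_id = 1929)
        self.to_values = Rank1EditModule(nn.Linear(768, 320, bias = False))

model = ToySD()
save(model, './checkpoints/concept.pt')     # writes concepts, EMA text encodings, concept outputs and C_inv
load(ToySD(), './checkpoints/concept.pt')   # restores them into a freshly constructed model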
.\lucidrains\perfusion-pytorch\perfusion_pytorch\__init__.py
# 从perfusion_pytorch.perfusion模块中导入Rank1EditModule、calculate_input_covariance、loss_fn_weighted_by_mask、merge_rank1_edit_modules、make_key_value_proj_rank1_edit_modules_函数
from perfusion_pytorch.perfusion import (
Rank1EditModule,
calculate_input_covariance,
loss_fn_weighted_by_mask,
merge_rank1_edit_modules,
make_key_value_proj_rank1_edit_modules_
)
# 从perfusion_pytorch.embedding模块中导入EmbeddingWrapper、OpenClipEmbedWrapper、merge_embedding_wrappers函数
from perfusion_pytorch.embedding import (
EmbeddingWrapper,
OpenClipEmbedWrapper,
merge_embedding_wrappers
)
# 从perfusion_pytorch.save_load模块中导入save、load函数
from perfusion_pytorch.save_load import (
save,
load
)
# 从perfusion_pytorch.optimizer模块中导入get_finetune_parameters、get_finetune_optimizer函数
from perfusion_pytorch.optimizer import (
get_finetune_parameters,
get_finetune_optimizer
)

Perfusion - Pytorch
Implementation of Key-Locked Rank One Editing. Project page
The selling point of this paper is extremely low extra parameters per added concept, down to 100kb.
It seems they successfully applied the Rank-1 editing technique from a memory editing paper for LLM, with a few improvements. They also identified that the keys determine the "where" of the new concept, while the values determine the "what", and propose local / global-key locking to a superclass concept (while learning the values).
For researchers out there, if this paper checks out, the tools in this repository should work for any other text-to-<insert modality> network using cross attention conditioning. Just a thought
Appreciation
- StabilityAI for the generous sponsorship, as well as my other sponsors out there
- Yoad Tewel for the multiple code reviews and clarifying emails
- Brad Vidler for precomputing the covariance matrix for the CLIP used in Stable Diffusion 1.5!
- All the maintainers at OpenClip, for their SOTA open sourced contrastive learning text-image models
Install
$ pip install perfusion-pytorch
Usage
import torch
from torch import nn
from perfusion_pytorch import Rank1EditModule
to_keys = nn.Linear(768, 320, bias = False)
to_values = nn.Linear(768, 320, bias = False)
wrapped_to_keys = Rank1EditModule(
to_keys,
is_key_proj = True
)
wrapped_to_values = Rank1EditModule(
to_values
)
text_enc = torch.randn(4, 77, 768) # regular input
text_enc_with_superclass = torch.randn(4, 77, 768) # init_input in algorithm 1, for key-locking
concept_indices = torch.randint(0, 77, (4,)) # index where the concept or superclass concept token is in the sequence
key_pad_mask = torch.ones(4, 77).bool()
keys = wrapped_to_keys(
text_enc,
concept_indices = concept_indices,
text_enc_with_superclass = text_enc_with_superclass,
)
values = wrapped_to_values(
text_enc,
concept_indices = concept_indices,
text_enc_with_superclass = text_enc_with_superclass,
)
# after much training ...
wrapped_to_keys.eval()
wrapped_to_values.eval()
keys = wrapped_to_keys(text_enc)
values = wrapped_to_values(text_enc)
The repository also contains an EmbeddingWrapper that makes it easy to train on a new concept (and for eventual inference with multiple concepts)
import torch
from torch import nn
from perfusion_pytorch import EmbeddingWrapper
embed = nn.Embedding(49408, 512) # open clip embedding, somewhere in the module tree of stable diffusion
# wrap it, and will automatically create a new concept for learning, based on the superclass embed string
wrapped_embed = EmbeddingWrapper(
embed,
superclass_string = 'dog'
)
# now just pass in your prompts with the superclass id
embeds_with_new_concept, embeds_with_superclass, embed_mask, concept_indices = wrapped_embed([
'a portrait of dog',
'dog running through a green field',
'a man walking his dog'
]) # (3, 77, 512), (3, 77, 512), (3, 77), (3,)
# now pass both embeds through clip text transformer
# the embed_mask needs to be passed to the cross attention as key padding mask
If you can identify the CLIP instance within the stable diffusion instance, you can also pass it directly to the OpenClipEmbedWrapper to gain everything you need on forward for the cross attention layers
ex.
from perfusion_pytorch import OpenClipEmbedWrapper
texts = [
'a portrait of dog',
'dog running through a green field',
'a man walking his dog'
]
wrapped_clip_with_new_concept = OpenClipEmbedWrapper(
stable_diffusion.path.to.clip,
superclass_string = 'dog'
)
text_enc, superclass_enc, mask, indices = wrapped_clip_with_new_concept(texts)
# (3, 77, 512), (3, 77, 512), (3, 77), (3,)
Todo
- wire up with SD 1.5, starting with xiao's dreambooth-sd
- show example in readme for inference with multiple concepts
- automatically infer where the keys and values projections are, if not specified, for the make_key_value_proj_rank1_edit_modules_ function
- embedding wrapper should take care of substituting with the super class token id and return the embedding with the super class
- review multiple concepts - thanks to Yoad
- offer a function that wires up the cross attention
- handle multiple concepts in one prompt at inference - summation of the sigmoid term + outputs
  - accept multiple concept indices
- offer a way to combine separately learned concepts from multiple Rank1EditModule into one for inference
  - offer function for merging Rank1EditModules
- add the zero-shot masking of concept proposed in the paper
- take care of the function that takes in the dataset and text encoder and precomputes the covariance matrix needed for the rank-1 update
- instead of having the researcher worry about different learning rates, offer the fractional gradient trick from the other paper (to learn the concept embedding)
Citations
@article{Tewel2023KeyLockedRO,
title = {Key-Locked Rank One Editing for Text-to-Image Personalization},
author = {Yoad Tewel and Rinon Gal and Gal Chechik and Yuval Atzmon},
journal = {ACM SIGGRAPH 2023 Conference Proceedings},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:258436985}
}
@inproceedings{Meng2022LocatingAE,
title = {Locating and Editing Factual Associations in GPT},
author = {Kevin Meng and David Bau and Alex Andonian and Yonatan Belinkov},
booktitle = {Neural Information Processing Systems},
year = {2022},
url = {https://api.semanticscholar.org/CorpusID:255825985}
}
.\lucidrains\perfusion-pytorch\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# 设置包的元数据
setup(
name = 'perfusion-pytorch', # 包的名称
packages = find_packages(exclude=[]), # 查找所有包
version = '0.1.23', # 版本号
license='MIT', # 许可证
description = 'Perfusion - Pytorch', # 描述
author = 'Phil Wang', # 作者
author_email = 'lucidrains@gmail.com', # 作者邮箱
long_description_content_type = 'text/markdown', # 长描述内容类型
url = 'https://github.com/lucidrains/perfusion-pytorch', # 项目链接
keywords = [ # 关键词列表
'artificial intelligence',
'deep learning',
'memory editing',
'text-to-image'
],
install_requires=[ # 安装依赖
'beartype',
'einops>=0.6.1',
'open-clip-torch',
'opt-einsum',
'torch>=2.0'
],
include_package_data = True, # 包含数据文件
classifiers=[ # 分类器列表
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)

Phasic Policy Gradient - Pytorch
An implementation of Phasic Policy Gradient, a proposed improvement on top of Proximal Policy Optimization (PPO), in Pytorch. It will be my very first project in Reinforcement Learning.
Install
$ pip install -r requirements.txt
Use
$ python train.py --render
Citations
@misc{cobbe2020phasic,
title={Phasic Policy Gradient},
author={Karl Cobbe and Jacob Hilton and Oleg Klimov and John Schulman},
year={2020},
eprint={2009.04416},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
.\lucidrains\phasic-policy-gradient\train.py
# 导入必要的库
import os
import fire
from collections import deque, namedtuple
from tqdm import tqdm
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.distributions import Categorical
import torch.nn.functional as F
import gym
# 定义常量
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 定义命名元组
Memory = namedtuple('Memory', ['state', 'action', 'action_log_prob', 'reward', 'done', 'value'])
AuxMemory = namedtuple('Memory', ['state', 'target_value', 'old_values'])
# 定义数据集类
class ExperienceDataset(Dataset):
def __init__(self, data):
super().__init__()
self.data = data
def __len__(self):
return len(self.data[0])
def __getitem__(self, ind):
return tuple(map(lambda t: t[ind], self.data))
# 创建混洗数据加载器
def create_shuffled_dataloader(data, batch_size):
ds = ExperienceDataset(data)
return DataLoader(ds, batch_size = batch_size, shuffle = True)
# 辅助函数
# 检查值是否存在
def exists(val):
return val is not None
# 归一化函数
def normalize(t, eps = 1e-5):
return (t - t.mean()) / (t.std() + eps)
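# the clipped_value_loss helper used in learn / learn_aux below is not shown in this excerpt;
# a minimal sketch consistent with those call sites (the usual PPO-style clipped value loss) might be:
def clipped_value_loss(values, rewards, old_values, clip):
    # keep the new value estimate within `clip` of the old one, then take the worse squared error
    value_clipped = old_values + (values - old_values).clamp(-clip, clip)
    value_loss_1 = (value_clipped.flatten() - rewards) ** 2
    value_loss_2 = (values.flatten() - rewards) ** 2
    return torch.mean(torch.max(value_loss_1, value_loss_2))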
# 更新网络参数
def update_network_(loss, optimizer):
optimizer.zero_grad()
loss.mean().backward()
optimizer.step()
# 初始化网络参数
def init_(m):
if isinstance(m, nn.Linear):
gain = torch.nn.init.calculate_gain('tanh')
torch.nn.init.orthogonal_(m.weight, gain)
if m.bias is not None:
torch.nn.init.zeros_(m.bias)
# 定义 Actor 神经网络类
class Actor(nn.Module):
def __init__(self, state_dim, hidden_dim, num_actions):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh()
)
self.action_head = nn.Sequential(
nn.Linear(hidden_dim, num_actions),
nn.Softmax(dim=-1)
)
self.value_head = nn.Linear(hidden_dim, 1)
self.apply(init_)
def forward(self, x):
hidden = self.net(x)
return self.action_head(hidden), self.value_head(hidden)
# 定义 Critic 神经网络类
class Critic(nn.Module):
def __init__(self, state_dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, 1),
)
self.apply(init_)
def forward(self, x):
return self.net(x)
# 定义 PPG 代理类
class PPG:
def __init__(
self,
state_dim,
num_actions,
actor_hidden_dim,
critic_hidden_dim,
epochs,
epochs_aux,
minibatch_size,
lr,
betas,
lam,
gamma,
beta_s,
eps_clip,
value_clip
):
self.actor = Actor(state_dim, actor_hidden_dim, num_actions).to(device)
self.critic = Critic(state_dim, critic_hidden_dim).to(device)
self.opt_actor = Adam(self.actor.parameters(), lr=lr, betas=betas)
self.opt_critic = Adam(self.critic.parameters(), lr=lr, betas=betas)
self.minibatch_size = minibatch_size
self.epochs = epochs
self.epochs_aux = epochs_aux
self.lam = lam
self.gamma = gamma
self.beta_s = beta_s
self.eps_clip = eps_clip
self.value_clip = value_clip
# 保存模型参数
def save(self):
torch.save({
'actor': self.actor.state_dict(),
'critic': self.critic.state_dict()
}, f'./ppg.pt')
# 加载模型参数
def load(self):
# 检查是否存在模型参数文件
if not os.path.exists('./ppg.pt'):
return
# 从文件中加载模型参数
data = torch.load(f'./ppg.pt')
# 更新 actor 模型参数
self.actor.load_state_dict(data['actor'])
# 更新 critic 模型参数
self.critic.load_state_dict(data['critic'])
# 学习函数,用于训练模型
def learn(self, memories, aux_memories, next_state):
# 从记忆中提取并准备训练数据
states = []
actions = []
old_log_probs = []
rewards = []
masks = []
values = []
for mem in memories:
states.append(mem.state)
actions.append(torch.tensor(mem.action))
old_log_probs.append(mem.action_log_prob)
rewards.append(mem.reward)
masks.append(1 - float(mem.done))
values.append(mem.value)
# 计算广义优势估计值
next_state = torch.from_numpy(next_state).to(device)
next_value = self.critic(next_state).detach()
values = values + [next_value]
returns = []
gae = 0
for i in reversed(range(len(rewards))):
delta = rewards[i] + self.gamma * values[i + 1] * masks[i] - values[i]
gae = delta + self.gamma * self.lam * masks[i] * gae
returns.insert(0, gae + values[i])
# 将值转换为 torch 张量
to_torch_tensor = lambda t: torch.stack(t).to(device).detach()
states = to_torch_tensor(states)
actions = to_torch_tensor(actions)
old_values = to_torch_tensor(values[:-1])
old_log_probs = to_torch_tensor(old_log_probs)
rewards = torch.tensor(returns).float().to(device)
# 将状态和目标值存储到辅助内存缓冲区以供后续训练使用
aux_memory = AuxMemory(states, rewards, old_values)
aux_memories.append(aux_memory)
# 为策略阶段训练准备数据加载器
dl = create_shuffled_dataloader([states, actions, old_log_probs, rewards, old_values], self.minibatch_size)
# 策略阶段训练,类似于原始的 PPO
for _ in range(self.epochs):
for states, actions, old_log_probs, rewards, old_values in dl:
action_probs, _ = self.actor(states)
values = self.critic(states)
dist = Categorical(action_probs)
action_log_probs = dist.log_prob(actions)
entropy = dist.entropy()
# 计算剪切的替代目标,经典的 PPO 损失
ratios = (action_log_probs - old_log_probs).exp()
advantages = normalize(rewards - old_values.detach())
surr1 = ratios * advantages
surr2 = ratios.clamp(1 - self.eps_clip, 1 + self.eps_clip) * advantages
policy_loss = - torch.min(surr1, surr2) - self.beta_s * entropy
# 更新策略网络
update_network_(policy_loss, self.opt_actor)
# 计算值损失并更新值网络,与策略网络分开
value_loss = clipped_value_loss(values, rewards, old_values, self.value_clip)
update_network_(value_loss, self.opt_critic)
# 定义一个辅助学习函数,用于训练辅助记忆
def learn_aux(self, aux_memories):
# 将状态和目标值合并成一个张量
states = []
rewards = []
old_values = []
for state, reward, old_value in aux_memories:
states.append(state)
rewards.append(reward)
old_values.append(old_value)
# 将状态、奖励和旧值连接成一个张量
states = torch.cat(states)
rewards = torch.cat(rewards)
old_values = torch.cat(old_values)
# 获取用于最小化 kl 散度和剪切的旧动作预测值
old_action_probs, _ = self.actor(states)
old_action_probs.detach_()
# 为辅助阶段训练准备数据加载器
dl = create_shuffled_dataloader([states, old_action_probs, rewards, old_values], self.minibatch_size)
# 提出的辅助阶段训练
# 在将值蒸馏到策略网络的同时,确保策略网络不改变动作预测值(kl 散度损失)
for epoch in range(self.epochs_aux):
for states, old_action_probs, rewards, old_values in tqdm(dl, desc=f'auxiliary epoch {epoch}'):
action_probs, policy_values = self.actor(states)
action_logprobs = action_probs.log()
# 策略网络损失由 kl 散度损失和辅助损失组成
aux_loss = clipped_value_loss(policy_values, rewards, old_values, self.value_clip)
loss_kl = F.kl_div(action_logprobs, old_action_probs, reduction='batchmean')
policy_loss = aux_loss + loss_kl
# 更新策略网络
update_network_(policy_loss, self.opt_actor)
# 论文指出在辅助阶段额外训练值网络非常重要
values = self.critic(states)
value_loss = clipped_value_loss(values, rewards, old_values, self.value_clip)
# 更新值网络
update_network_(value_loss, self.opt_critic)
# 主函数
def main(
env_name = 'LunarLander-v2', # 环境名称,默认为'LunarLander-v2'
num_episodes = 50000, # 总的训练轮数,默认为50000
max_timesteps = 500, # 每轮最大时间步数,默认为500
actor_hidden_dim = 32, # Actor神经网络隐藏层维度,默认为32
critic_hidden_dim = 256, # Critic神经网络隐藏层维度,默认为256
minibatch_size = 64, # 每次训练的样本批量大小,默认为64
lr = 0.0005, # 学习率,默认为0.0005
betas = (0.9, 0.999), # Adam优化器的beta参数,默认为(0.9, 0.999)
lam = 0.95, # GAE的lambda参数,默认为0.95
gamma = 0.99, # 折扣因子,默认为0.99
eps_clip = 0.2, # PPO算法的epsilon clip参数,默认为0.2
value_clip = 0.4, # Critic的值函数clip参数,默认为0.4
beta_s = .01, # 熵损失的权重参数,默认为0.01
update_timesteps = 5000, # 更新模型的时间步数间隔,默认为5000
num_policy_updates_per_aux = 32, # 辅助网络更新次数,默认为32
epochs = 1, # 主网络训练轮数,默认为1
epochs_aux = 6, # 辅助网络训练轮数,默认为6
seed = None, # 随机种子,默认为None
render = False, # 是否渲染环境,默认为False
render_every_eps = 250, # 每隔多少轮渲染一次,默认为250
save_every = 1000, # 每隔多少轮保存模型,默认为1000
load = False, # 是否加载已有模型,默认为False
monitor = False # 是否监视环境,默认为False
):
env = gym.make(env_name) # 创建环境
if monitor:
env = gym.wrappers.Monitor(env, './tmp/', force=True) # 监视环境
state_dim = env.observation_space.shape[0] # 状态空间维度
num_actions = env.action_space.n # 动作空间维度
memories = deque([]) # 存储经验的队列
aux_memories = deque([]) # 存储辅助经验的队列
agent = PPG( # 创建PPO算法的代理
state_dim,
num_actions,
actor_hidden_dim,
critic_hidden_dim,
epochs,
epochs_aux,
minibatch_size,
lr,
betas,
lam,
gamma,
beta_s,
eps_clip,
value_clip
)
if load:
agent.load() # 加载模型
if exists(seed): # 如果存在随机种子
torch.manual_seed(seed) # 设置PyTorch随机种子
np.random.seed(seed) # 设置NumPy随机种子
time = 0 # 时间步数
updated = False # 是否更新模型
num_policy_updates = 0 # 策略更新次数
for eps in tqdm(range(num_episodes), desc='episodes'): # 遍历训练轮数
render_eps = render and eps % render_every_eps == 0 # 是否渲染当前轮次
state = env.reset() # 重置环境状态
for timestep in range(max_timesteps): # 遍历每个时间步
time += 1 # 时间步数加1
if updated and render_eps: # 如果已更新并需要渲染
env.render() # 渲染环境
state = torch.from_numpy(state).to(device) # 转换状态为PyTorch张量
action_probs, _ = agent.actor(state) # 获取动作概率
value = agent.critic(state) # 获取值函数
dist = Categorical(action_probs) # 创建分类分布
action = dist.sample() # 采样动作
action_log_prob = dist.log_prob(action) # 计算动作对数概率
action = action.item() # 转换动作为标量
next_state, reward, done, _ = env.step(action) # 执行动作
memory = Memory(state, action, action_log_prob, reward, done, value) # 创建经验
memories.append(memory) # 将经验添加到队列
state = next_state # 更新状态
if time % update_timesteps == 0: # 如果达到更新时间步
agent.learn(memories, aux_memories, next_state) # 更新主网络
num_policy_updates += 1 # 策略更新次数加1
memories.clear() # 清空经验队列
if num_policy_updates % num_policy_updates_per_aux == 0: # 达到辅助网络更新次数
agent.learn_aux(aux_memories) # 更新辅助网络
aux_memories.clear() # 清空辅助经验队列
updated = True # 设置为已更新
if done: # 如果环境结束
if render_eps: # 如果需要渲染
updated = False # 设置为未更新
break # 跳出循环
if render_eps: # 如果需要渲染
env.close() # 关闭环境
if eps % save_every == 0: # 每隔一定轮次保存模型
agent.save() # 保存模型
if __name__ == '__main__':
fire.Fire(main) # 使用Fire库执行主函数
.\lucidrains\phenaki-pytorch\phenaki_pytorch\attention.py
# attention module (the imports and helper functions at the top of this file are omitted from this excerpt)
class Attention(nn.Module):
# initialize the attention module
def __init__(
self,
dim,
dim_context = None,
dim_head = 64,
heads = 8,
causal = False,
num_null_kv = 0,
norm_context = True,
dropout = 0.,
scale = 8
):
# 调用父类初始化方法
super().__init__()
# 设置注意力头数
self.heads = heads
# 是否为因果注意力
self.causal = causal
# 缩放因子
self.scale = scale
# 内部维度
inner_dim = dim_head * heads
# 如果未指定上下文维度,则默认为输入维度
dim_context = default(dim_context, dim)
# 如果是因果注意力,则使用AlibiPositionalBias初始化相对位置偏置
if causal:
self.rel_pos_bias = AlibiPositionalBias(heads = heads)
# 注意力机制的dropout层
self.attn_dropout = nn.Dropout(dropout)
# 输入的LayerNorm层
self.norm = LayerNorm(dim)
# 上下文的LayerNorm层(如果需要规范化上下文)
self.context_norm = LayerNorm(dim_context) if norm_context else nn.Identity()
# 空键值对的数量
self.num_null_kv = num_null_kv
# 空键值对参数
self.null_kv = nn.Parameter(torch.randn(heads, 2 * num_null_kv, dim_head))
# 查询转换层
self.to_q = nn.Linear(dim, inner_dim, bias = False)
# 键值对转换层
self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)
# 查询缩放参数
self.q_scale = nn.Parameter(torch.ones(dim_head))
# 键缩放参数
self.k_scale = nn.Parameter(torch.ones(dim_head))
# 输出转换层
self.to_out = nn.Linear(inner_dim, dim, bias = False)
# forward pass of the attention module (parameters inferred from their use below)
def forward(
self,
x,
mask = None,
context = None,
attn_bias = None
):
# get the batch size, device and dtype of the input tensor x
batch, device, dtype = x.shape[0], x.device, x.dtype
# 如果上下文存在,则对上下文进行归一化处理
if exists(context):
context = self.context_norm(context)
# 对输入张量 x 进行归一化处理
x = self.norm(x)
# keys / values attend to the context when it is given, otherwise to x itself
kv_input = default(context, x)
# project x to queries (q) and kv_input to keys (k) and values (v)
q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)
# 将查询(q)、键(k)、值(v)张量按照指定维度重新排列
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
# 重复空键值对(null_kv)以匹配批量大小和维度
nk, nv = repeat(self.null_kv, 'h (n r) d -> b h n r d', b = batch, r = 2).unbind(dim = -2)
# 将键(k)和值(v)张量与空键值对(nk、nv)进行拼接
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
# 对查询(q)和键(k)进行 L2 归一化处理
q, k = map(l2norm, (q, k))
q = q * self.q_scale
k = k * self.k_scale
# 计算查询(q)和键(k)之间的相似度
sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
i, j = sim.shape[-2:]
# 如果存在注意力偏置(attn_bias),则对相似度矩阵进行加权
if exists(attn_bias):
attn_bias = F.pad(attn_bias, (self.num_null_kv, 0), value = 0.)
sim = sim + attn_bias
# 如果存在掩码(mask),则对掩码进行处理
if exists(mask):
mask = F.pad(mask, (self.num_null_kv, 0), value = True)
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
# 如果启用因果注意力,则对相似度矩阵进行处理
if self.causal:
sim = sim + self.rel_pos_bias(sim)
causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
# 对相似度矩阵进行 softmax 操作
attn = sim.softmax(dim = -1)
attn = self.attn_dropout(attn)
# 计算输出张量
out = einsum('b h i j, b h j d -> b h i d', attn, v)
# 重新排列输出张量的维度
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
# 定义一个名为 AlibiPositionalBias 的类,用于处理位置偏差
class AlibiPositionalBias(nn.Module):
def __init__(self, heads):
super().__init__()
self.heads = heads
# 初始化斜率参数
slopes = torch.Tensor(self._get_slopes(heads))
slopes = rearrange(slopes, 'h -> h 1 1')
# 注册斜率参数和偏差参数
self.register_buffer('slopes', slopes, persistent = False)
self.register_buffer('bias', None, persistent = False)
# 获取偏差值
def get_bias(self, i, j, device):
i_arange = torch.arange(j - i, j, device = device)
j_arange = torch.arange(j, device = device)
bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
return bias
# 获取斜率参数
@staticmethod
def _get_slopes(heads):
def get_slopes_power_of_2(n):
start = (2**(-2**-(math.log2(n)-3)))
ratio = start
return [start*ratio**i for i in range(n)]
if math.log2(heads).is_integer():
return get_slopes_power_of_2(heads)
closest_power_of_2 = 2 ** math.floor(math.log2(heads))
return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
# 前向传播函数
def forward(self, sim):
h, i, j, device = *sim.shape[-3:], sim.device
if exists(self.bias) and self.bias.shape[-1] >= j:
return self.bias[..., :i, :j]
bias = self.get_bias(i, j, device)
bias = bias * self.slopes
num_heads_unalibied = h - bias.shape[0]
bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
self.register_buffer('bias', bias, persistent = False)
return self.bias
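# a quick numeric illustration, not from the source: for 8 heads the ALiBi slopes form a geometric
# sequence 1/2, 1/4, ..., 1/256; for a non power-of-two head count they are filled in from the nearest powers of two
print(AlibiPositionalBias._get_slopes(8))        # [0.5, 0.25, 0.125, ..., 0.00390625]
print(len(AlibiPositionalBias._get_slopes(6)))   # 6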
# 定义一个名为 ContinuousPositionBias 的类,用于处理连续位置偏差
class ContinuousPositionBias(nn.Module):
""" from https://arxiv.org/abs/2111.09883 """
def __init__(
self,
*,
dim,
heads,
num_dims = 2, # 2 for images, 3 for video
layers = 2,
log_dist = True,
cache_rel_pos = False
):
super().__init__()
self.num_dims = num_dims
self.log_dist = log_dist
self.net = nn.ModuleList([])
self.net.append(nn.Sequential(nn.Linear(self.num_dims, dim), leaky_relu()))
for _ in range(layers - 1):
self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu()))
self.net.append(nn.Linear(dim, heads))
self.cache_rel_pos = cache_rel_pos
self.register_buffer('rel_pos', None, persistent = False)
# 前向传播函数
def forward(self, *dimensions, device = torch.device('cpu')):
if not exists(self.rel_pos) or not self.cache_rel_pos:
positions = [torch.arange(d, device = device) for d in dimensions]
grid = torch.stack(torch.meshgrid(*positions, indexing = 'ij'))
grid = rearrange(grid, 'c ... -> (...) c')
rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
if self.log_dist:
rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
self.register_buffer('rel_pos', rel_pos, persistent = False)
rel_pos = self.rel_pos.float()
for layer in self.net:
rel_pos = layer(rel_pos)
return rearrange(rel_pos, 'i j h -> h i j')
# 定义一个名为 Transformer 的类,用于实现 Transformer 模型
class Transformer(nn.Module):
def __init__(
self,
dim,
*,
depth,
dim_context = None,
causal = False,
dim_head = 64,
heads = 8,
ff_mult = 4,
peg = False,
peg_causal = False,
attn_num_null_kv = 2,
has_cross_attn = False,
attn_dropout = 0.,
ff_dropout = 0.
):
# 调用父类的构造函数
super().__init__()
# 初始化一个空的神经网络模块列表
self.layers = nn.ModuleList([])
# 循环depth次,向神经网络模块列表中添加不同的模块
for _ in range(depth):
self.layers.append(nn.ModuleList([
# 如果peg为真,则添加一个PEG模块,否则添加None
PEG(dim = dim, causal = peg_causal) if peg else None,
# 添加一个Attention模块
Attention(dim = dim, dim_head = dim_head, heads = heads, causal = causal, dropout = attn_dropout),
# 如果has_cross_attn为真,则添加一个带有跨注意力的Attention模块,否则添加None
Attention(dim = dim, dim_head = dim_head, dim_context = dim_context, heads = heads, causal = False, num_null_kv = attn_num_null_kv, dropout = attn_dropout) if has_cross_attn else None,
# 添加一个FeedForward模块
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
]))
# 初始化一个LayerNorm模块
self.norm_out = LayerNorm(dim)
@beartype
def forward(
self,
x,
video_shape: Tuple[int, int, int, int] = None,
attn_bias = None,
context = None,
self_attn_mask = None,
cross_attn_context_mask = None
):
# 遍历神经网络模块列表中的不同模块
for peg, self_attn, cross_attn, ff in self.layers:
# 如果存在PEG模块,则对输入进行处理并与原始输入相加
if exists(peg):
x = peg(x, shape = video_shape) + x
# 对输入进行自注意力处理并与原始输入相加
x = self_attn(x, attn_bias = attn_bias, mask = self_attn_mask) + x
# 如果存在跨注意力模块且存在上下文信息,则对输入进行处理并与原始输入相加
if exists(cross_attn) and exists(context):
x = cross_attn(x, context = context, mask = cross_attn_context_mask) + x
# 对输入进行前馈处理并与原始输入相加
x = ff(x) + x
# 对处理后的结果进行LayerNorm处理并返回
return self.norm_out(x)
.\lucidrains\phenaki-pytorch\phenaki_pytorch\cvivit.py
# 导入必要的库
from pathlib import Path
import copy
import math
from functools import wraps
import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.autograd import grad as torch_grad
import torchvision
# 导入 einops 库中的函数
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
# 导入自定义的模块
from vector_quantize_pytorch import VectorQuantize, LFQ
from phenaki_pytorch.attention import Attention, Transformer, ContinuousPositionBias
# 定义一些辅助函数
# 判断变量是否存在
def exists(val):
return val is not None
# 如果变量存在则返回其值,否则返回默认值
def default(val, d):
return val if exists(val) else d
# 判断一个数是否可以被另一个数整除
def divisible_by(numer, denom):
return (numer % denom) == 0
# 定义 leaky_relu 激活函数
def leaky_relu(p = 0.1):
return nn.LeakyReLU(p)
# 移除 vgg 属性的装饰器
def remove_vgg(fn):
@wraps(fn)
def inner(self, *args, **kwargs):
has_vgg = hasattr(self, 'vgg')
if has_vgg:
vgg = self.vgg
delattr(self, 'vgg')
out = fn(self, *args, **kwargs)
if has_vgg:
self.vgg = vgg
return out
return inner
# 将单个值转换为元组
def pair(val):
ret = (val, val) if not isinstance(val, tuple) else val
assert len(ret) == 2
return ret
# 将单个值转换为指定长度的元组
def cast_tuple(val, l = 1):
return val if isinstance(val, tuple) else (val,) * l
# 计算梯度惩罚
def gradient_penalty(images, output, weight = 10):
batch_size = images.shape[0]
gradients = torch_grad(
outputs = output,
inputs = images,
grad_outputs = torch.ones(output.size(), device = images.device),
create_graph = True,
retain_graph = True,
only_inputs = True
)[0]
gradients = rearrange(gradients, 'b ... -> b (...)')
return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
# 对张量进行 L2 归一化
def l2norm(t):
return F.normalize(t, dim = -1)
# 安全除法,避免分母为零
def safe_div(numer, denom, eps = 1e-8):
return numer / (denom + eps)
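# the log helper used by the bce losses below is not shown in this excerpt;
# a minimal sketch consistent with its call sites would be a numerically stabilized torch.log:
def log(t, eps = 1e-10):
    return torch.log(t + eps)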
# 定义 GAN 损失函数
# Hinge 损失函数(判别器)
def hinge_discr_loss(fake, real):
return (F.relu(1 + fake) + F.relu(1 - real)).mean()
# Hinge 损失函数(生成器)
def hinge_gen_loss(fake):
return -fake.mean()
# 二元交叉熵损失函数(判别器)
def bce_discr_loss(fake, real):
return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()
# 二元交叉熵损失函数(生成器)
def bce_gen_loss(fake):
return -log(torch.sigmoid(fake)).mean()
# 计算损失函数对某一层的梯度
def grad_layer_wrt_loss(loss, layer):
return torch_grad(
outputs = loss,
inputs = layer,
grad_outputs = torch.ones_like(loss),
retain_graph = True
)[0].detach()
# 定义判别器模块
class DiscriminatorBlock(nn.Module):
def __init__(
self,
input_channels,
filters,
downsample = True
):
super().__init__()
self.conv_res = nn.Conv2d(input_channels, filters, 1, stride = (2 if downsample else 1))
self.net = nn.Sequential(
nn.Conv2d(input_channels, filters, 3, padding=1),
leaky_relu(),
nn.Conv2d(filters, filters, 3, padding=1),
leaky_relu()
)
self.downsample = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),
nn.Conv2d(filters * 4, filters, 1)
) if downsample else None
def forward(self, x):
res = self.conv_res(x)
x = self.net(x)
if exists(self.downsample):
x = self.downsample(x)
x = (x + res) * (1 / math.sqrt(2))
return x
class Discriminator(nn.Module):
def __init__(
self,
*,
dim,
image_size,
channels = 3,
attn_res_layers = (16,),
max_dim = 512
# 初始化函数,继承父类的初始化方法
):
# 调用父类的初始化方法
super().__init__()
# 将图像大小转换为元组
image_size = pair(image_size)
# 计算图像的最小分辨率
min_image_resolution = min(image_size)
# 计算层数,根据最小分辨率
num_layers = int(math.log2(min_image_resolution) - 2)
# 将注意力层的分辨率转换为元组
attn_res_layers = cast_tuple(attn_res_layers, num_layers)
# 初始化块列表
blocks = []
# 计算每一层的维度
layer_dims = [channels] + [(dim * 4) * (2 ** i) for i in range(num_layers + 1)]
# 将每一层的维度限制在最大维度内
layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
# 将每一层的输入输出维度组成元组
layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))
# 初始化块列表和注意力块列表
blocks = []
attn_blocks = []
# 初始化图像分辨率
image_resolution = min_image_resolution
# 遍历每一层的输入输出维度
for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
# 计算当前层的编号
num_layer = ind + 1
# 判断是否为最后一层
is_not_last = ind != (len(layer_dims_in_out) - 1)
# 创建鉴别器块
block = DiscriminatorBlock(in_chan, out_chan, downsample = is_not_last)
blocks.append(block)
# 初始化注意力块
attn_block = None
if image_resolution in attn_res_layers:
attn_block = Attention(dim = out_chan)
attn_blocks.append(attn_block)
# 更新图像分辨率
image_resolution //= 2
# 将块列表和注意力块列表转换为模块列表
self.blocks = nn.ModuleList(blocks)
self.attn_blocks = nn.ModuleList(attn_blocks)
# 计算最后一层的维度
dim_last = layer_dims[-1]
# 计算下采样因子
downsample_factor = 2 ** num_layers
# 计算最后特征图的大小
last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))
# 计算潜在维度
latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last
# 定义输出层
self.to_logits = nn.Sequential(
nn.Conv2d(dim_last, dim_last, 3, padding = 1),
leaky_relu(),
Rearrange('b ... -> b (...)'),
nn.Linear(latent_dim, 1),
Rearrange('b 1 -> b')
)
# 前向传播函数
def forward(self, x):
# 遍历块列表和注意力块列表
for block, attn_block in zip(self.blocks, self.attn_blocks):
# 应用块
x = block(x)
# 如果存在注意力块
if exists(attn_block):
x, ps = pack([x], 'b c *')
x = rearrange(x, 'b c n -> b n c')
x = attn_block(x) + x
x = rearrange(x, 'b n c -> b c n')
x, = unpack(x, ps, 'b c *')
# 返回输出结果
return self.to_logits(x)
# 定义一个函数,用于从视频中选择指定帧的图像
def pick_video_frame(video, frame_indices):
# 获取视频的批量大小和设备信息
batch, device = video.shape[0], video.device
# 重新排列视频张量的维度,将通道维度放在第二个位置
video = rearrange(video, 'b c f ... -> b f c ...')
# 创建一个包含批量索引的张量
batch_indices = torch.arange(batch, device=device)
batch_indices = rearrange(batch_indices, 'b -> b 1')
# 从视频中选择指定帧的图像
images = video[batch_indices, frame_indices]
# 重新排列图像张量的维度,将通道维度放在第一个位置
images = rearrange(images, 'b 1 c ... -> b c ...')
return images
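# an illustrative call, not from the source, showing the expected shapes
example_video = torch.randn(2, 3, 10, 64, 64)       # (batch, channels, frames, height, width)
example_frame_indices = torch.tensor([[4], [7]])     # one frame index per video, shape (b, 1)
example_frames = pick_video_frame(example_video, example_frame_indices)
print(example_frames.shape)   # torch.Size([2, 3, 64, 64])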
# 定义一个 CViViT 类,实现3D ViT模型,具有分解的空间和时间注意力,并制作成vqgan-vae自动编码器
class CViViT(nn.Module):
def __init__(
self,
*,
dim, # 模型维度
codebook_size, # 代码簿大小
image_size, # 图像大小
patch_size, # 图像块大小
temporal_patch_size, # 时间块大小
spatial_depth, # 空间深度
temporal_depth, # 时间深度
discr_base_dim=16, # 判别器基础维度
dim_head=64, # 头部维度
heads=8, # 头部数量
channels=3, # 通道数
use_vgg_and_gan=True, # 是否使用VGG和GAN
vgg=None, # VGG模型
discr_attn_res_layers=(16,), # 判别器注意力层分辨率
use_hinge_loss=True, # 是否使用hinge损失
attn_dropout=0., # 注意力机制的dropout率
ff_dropout=0., # feed-forward层的dropout率
lookup_free_quantization=True, # 是否使用无查找表的量化
lookup_free_quantization_kwargs: dict = {} # 无查找表的量化参数
):
"""
einstein notations:
b - batch
c - channels
t - time
d - feature dimension
p1, p2, pt - image patch sizes and then temporal patch size
"""
super().__init__()
self.image_size = pair(image_size)
self.patch_size = pair(patch_size)
patch_height, patch_width = self.patch_size
self.temporal_patch_size = temporal_patch_size
self.spatial_rel_pos_bias = ContinuousPositionBias(dim = dim, heads = heads)
image_height, image_width = self.image_size
assert (image_height % patch_height) == 0 and (image_width % patch_width) == 0
self.to_patch_emb_first_frame = nn.Sequential(
Rearrange('b c 1 (h p1) (w p2) -> b 1 h w (c p1 p2)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(channels * patch_width * patch_height),
nn.Linear(channels * patch_width * patch_height, dim),
nn.LayerNorm(dim)
)
self.to_patch_emb = nn.Sequential(
Rearrange('b c (t pt) (h p1) (w p2) -> b t h w (c pt p1 p2)', p1 = patch_height, p2 = patch_width, pt = temporal_patch_size),
nn.LayerNorm(channels * patch_width * patch_height * temporal_patch_size),
nn.Linear(channels * patch_width * patch_height * temporal_patch_size, dim),
nn.LayerNorm(dim)
)
transformer_kwargs = dict(
dim = dim,
dim_head = dim_head,
heads = heads,
attn_dropout = attn_dropout,
ff_dropout = ff_dropout,
peg = True,
peg_causal = True,
)
self.enc_spatial_transformer = Transformer(depth = spatial_depth, **transformer_kwargs)
self.enc_temporal_transformer = Transformer(depth = temporal_depth, **transformer_kwargs)
# offer look up free quantization
# https://arxiv.org/abs/2310.05737
self.lookup_free_quantization = lookup_free_quantization
if lookup_free_quantization:
self.vq = LFQ(dim = dim, codebook_size = codebook_size, **lookup_free_quantization_kwargs)
else:
self.vq = VectorQuantize(dim = dim, codebook_size = codebook_size, use_cosine_sim = True)
self.dec_spatial_transformer = Transformer(depth = spatial_depth, **transformer_kwargs)
self.dec_temporal_transformer = Transformer(depth = temporal_depth, **transformer_kwargs)
self.to_pixels_first_frame = nn.Sequential(
nn.Linear(dim, channels * patch_width * patch_height),
Rearrange('b 1 h w (c p1 p2) -> b c 1 (h p1) (w p2)', p1 = patch_height, p2 = patch_width)
)
self.to_pixels = nn.Sequential(
nn.Linear(dim, channels * patch_width * patch_height * temporal_patch_size),
Rearrange('b t h w (c pt p1 p2) -> b c (t pt) (h p1) (w p2)', p1 = patch_height, p2 = patch_width, pt = temporal_patch_size),
)
# turn off GAN and perceptual loss if grayscale
self.vgg = None
self.discr = None
self.use_vgg_and_gan = use_vgg_and_gan
if not use_vgg_and_gan:
return
# preceptual loss
if exists(vgg):
self.vgg = vgg
else:
self.vgg = torchvision.models.vgg16(pretrained = True)
self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])
# gan related losses
self.discr = Discriminator(
image_size = self.image_size,
dim = discr_base_dim,
channels = channels,
attn_res_layers = discr_attn_res_layers
)
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
# 计算视频的掩码,用于生成视频的 token
def calculate_video_token_mask(self, videos, video_frame_mask):
# 解构赋值,获取视频的高度和宽度
*_, h, w = videos.shape
# 获取补丁的高度和宽度
ph, pw = self.patch_size
# 断言视频帧掩码的总和减去1必须能被时间补丁大小整除
assert torch.all(((video_frame_mask.sum(dim = -1) - 1) % self.temporal_patch_size) == 0), 'number of frames must be divisible by temporal patch size, subtracting off the first frame'
# 将第一帧掩码和其余帧掩码分开
first_frame_mask, rest_frame_mask = video_frame_mask[:, :1], video_frame_mask[:, 1:]
# 重新排列其余帧掩码,以适应时间补丁大小
rest_vq_mask = rearrange(rest_frame_mask, 'b (f p) -> b f p', p = self.temporal_patch_size)
# 合并第一帧掩码和其余帧掩码的逻辑或结果
video_mask = torch.cat((first_frame_mask, rest_vq_mask.any(dim = -1)), dim = -1)
# 重复视频掩码,以匹配视频的高度和宽度
return repeat(video_mask, 'b f -> b (f hw)', hw = (h // ph) * (w // pw))
# 获取视频补丁的形状
def get_video_patch_shape(self, num_frames, include_first_frame = True):
patch_frames = 0
if include_first_frame:
num_frames -= 1
patch_frames += 1
patch_frames += (num_frames // self.temporal_patch_size)
return (patch_frames, *self.patch_height_width)
# 返回图像 token 的数量
@property
def image_num_tokens(self):
return int(self.image_size[0] / self.patch_size[0]) * int(self.image_size[1] / self.patch_size[1])
# 根据 token 数量返回帧数
def frames_per_num_tokens(self, num_tokens):
tokens_per_frame = self.image_num_tokens
assert (num_tokens % tokens_per_frame) == 0, f'number of tokens must be divisible by number of tokens per frame {tokens_per_frame}'
assert (num_tokens > 0)
pseudo_frames = num_tokens // tokens_per_frame
return (pseudo_frames - 1) * self.temporal_patch_size + 1
# 根据帧数返回 token 数量
def num_tokens_per_frames(self, num_frames, include_first_frame = True):
image_num_tokens = self.image_num_tokens
total_tokens = 0
if include_first_frame:
num_frames -= 1
total_tokens += image_num_tokens
assert (num_frames % self.temporal_patch_size) == 0
return total_tokens + int(num_frames / self.temporal_patch_size) * image_num_tokens
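# a worked example with illustrative settings, not from the source:
# image_size = 256, patch_size = 32, temporal_patch_size = 2
tokens_per_frame = (256 // 32) * (256 // 32)                                # 64 spatial tokens per frame
tokens_for_17_frames = tokens_per_frame + (16 // 2) * tokens_per_frame     # 64 + 8 * 64 = 576
frames_recovered = (tokens_for_17_frames // tokens_per_frame - 1) * 2 + 1  # inverts back to 17
print(tokens_per_frame, tokens_for_17_frames, frames_recovered)            # 64 576 17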
# 用于评估的模型拷贝
def copy_for_eval(self):
device = next(self.parameters()).device
vae_copy = copy.deepcopy(self.cpu())
if vae_copy.use_vgg_and_gan:
del vae_copy.discr
del vae_copy.vgg
vae_copy.eval()
return vae_copy.to(device)
# 重写 state_dict 方法
@remove_vgg
def state_dict(self, *args, **kwargs):
return super().state_dict(*args, **kwargs)
# 重写 load_state_dict 方法
@remove_vgg
def load_state_dict(self, *args, **kwargs):
return super().load_state_dict(*args, **kwargs)
# 加载模型
def load(self, path):
path = Path(path)
assert path.exists()
pt = torch.load(str(path))
self.load_state_dict(pt)
# 根据 codebook 索引解码
def decode_from_codebook_indices(self, indices):
if self.lookup_free_quantization:
codes = self.vq.indices_to_codes(indices)
else:
codes = self.vq.codebook[indices]
return self.decode(codes)
# 返回补丁的高度和宽度
@property
def patch_height_width(self):
return self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1]
# 编码 tokens
def encode(
self,
tokens
):
b = tokens.shape[0]
h, w = self.patch_height_width
video_shape = tuple(tokens.shape[:-1])
tokens = rearrange(tokens, 'b t h w d -> (b t) (h w) d')
attn_bias = self.spatial_rel_pos_bias(h, w, device = tokens.device)
tokens = self.enc_spatial_transformer(tokens, attn_bias = attn_bias, video_shape = video_shape)
tokens = rearrange(tokens, '(b t) (h w) d -> b t h w d', b = b, h = h , w = w)
# encode - temporal
tokens = rearrange(tokens, 'b t h w d -> (b h w) t d')
tokens = self.enc_temporal_transformer(tokens, video_shape = video_shape)
tokens = rearrange(tokens, '(b h w) t d -> b t h w d', b = b, h = h, w = w)
return tokens
# 解码 tokens
def decode(
self,
tokens
):
# 获取 tokens 的 batch 大小
b = tokens.shape[0]
# 获取 patch 的高度和宽度
h, w = self.patch_height_width
# 如果 tokens 的维度为 3,则重新排列 tokens 的维度
if tokens.ndim == 3:
tokens = rearrange(tokens, 'b (t h w) d -> b t h w d', h = h, w = w)
# 获取视频形状的元组
video_shape = tuple(tokens.shape[:-1])
# 解码 - 时间维度
# 重新排列 tokens 的维度
tokens = rearrange(tokens, 'b t h w d -> (b h w) t d')
# 对 tokens 进行时间维度的解码
tokens = self.dec_temporal_transformer(tokens, video_shape = video_shape)
# 重新排列 tokens 的维度
tokens = rearrange(tokens, '(b h w) t d -> b t h w d', b = b, h = h, w = w)
# 解码 - 空间维度
# 重新排列 tokens 的维度
tokens = rearrange(tokens, 'b t h w d -> (b t) (h w) d')
# 获取空间相对位置偏置
attn_bias = self.spatial_rel_pos_bias(h, w, device = tokens.device)
# 对 tokens 进行空间维度的解码
tokens = self.dec_spatial_transformer(tokens, attn_bias = attn_bias, video_shape = video_shape)
# 重新排列 tokens 的维度
tokens = rearrange(tokens, '(b t) (h w) d -> b t h w d', b = b, h = h , w = w)
# 转换为像素
# 获取第一帧 token 和其余帧 tokens
first_frame_token, rest_frames_tokens = tokens[:, :1], tokens[:, 1:]
# 将第一帧转换为像素
first_frame = self.to_pixels_first_frame(first_frame_token)
# 将其余帧转换为像素
rest_frames = self.to_pixels(rest_frames_tokens)
# 拼接重构视频
recon_video = torch.cat((first_frame, rest_frames), dim = 2)
# 返回重构视频
return recon_video
def forward(
self,
video,
mask = None,
return_recons = False,
return_recons_only = False,
return_discr_loss = False,
apply_grad_penalty = True,
return_only_codebook_ids = False