Lucidrains 系列项目源码解析(一百零三)
.\lucidrains\uformer-pytorch\uformer_pytorch\__init__.py
# 从uformer_pytorch.uformer_pytorch模块中导入Uformer类
from uformer_pytorch.uformer_pytorch import Uformer

UNet Stylegan2
An implementation of Stylegan2 with a UNet Discriminator. This repository works largely the same way as Stylegan2 Pytorch. Simply replace all the stylegan2_pytorch commands with unet_stylegan2 instead.


Update: Results have been very good. Will need to investigate combining this with a few other techniques, and then I will write up full instructions for use.
Install
$ pip install unet-stylegan2
Usage
$ unet_stylegan2 --data ./path/to/data
Citations
@misc{karras2019analyzing,
title={Analyzing and Improving the Image Quality of StyleGAN},
author={Tero Karras and Samuli Laine and Miika Aittala and Janne Hellsten and Jaakko Lehtinen and Timo Aila},
year={2019},
eprint={1912.04958},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
@misc{schnfeld2020unet,
title={A U-Net Based Discriminator for Generative Adversarial Networks},
author={Edgar Schönfeld and Bernt Schiele and Anna Khoreva},
year={2020},
eprint={2002.12655},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
.\lucidrains\unet-stylegan2\setup.py
# Package metadata and dependencies for unet_stylegan2.
from setuptools import setup, find_packages

setup(
  name = 'unet_stylegan2',
  packages = find_packages(),
  scripts=['bin/unet_stylegan2'],
  version = '0.5.1',
  license='GPLv3+',
  description = 'StyleGan2 with UNet Discriminator, in Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/unet-stylegan2',
  keywords = ['generative adversarial networks', 'artificial intelligence'],
  install_requires=[
    'fire',
    'numpy',
    'retry',
    'tqdm',
    'torch',
    'torchvision',
    'pillow',
    'linear_attention_transformer>=0.12.1'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    # fixed: classifier previously declared MIT, contradicting license='GPLv3+' above
    'License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)',
    'Programming Language :: Python :: 3.6',
  ],
)
.\lucidrains\unet-stylegan2\unet_stylegan2\diff_augment.py
# 导入 torch 库
import torch
# 导入 torch.nn.functional 模块
import torch.nn.functional as F
def DiffAugment(x, types=None):
    """Apply differentiable augmentations to a batch of images.

    Args:
        x: image tensor of shape (batch, channels, height, width).
        types: iterable of augmentation group names (keys of AUGMENT_FNS,
            e.g. 'color', 'translation', 'cutout'). None/empty means no-op.

    Returns:
        Augmented tensor, forced into torch.contiguous_format.
    """
    # fixed: mutable default argument ([]) replaced with a None sentinel
    if types is None:
        types = []
    for p in types:
        for f in AUGMENT_FNS[p]:
            x = f(x)
    return x.contiguous(memory_format=torch.contiguous_format)
def rand_brightness(x):
    """Shift each sample by one uniform random brightness offset in [-0.5, 0.5)."""
    offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset
def rand_saturation(x):
    """Scale each sample's deviation from its per-pixel channel mean by a random factor in [0, 2)."""
    mean = x.mean(dim=1, keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
    return (x - mean) * factor + mean
def rand_contrast(x):
    """Scale each sample's deviation from its global mean by a random factor in [0.5, 1.5)."""
    mean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return (x - mean) * factor + mean
def rand_translation(x, ratio=0.125):
    # Randomly translate each sample by up to `ratio` of its height/width,
    # filling the vacated border with zeros.
    # Maximum shift in pixels per spatial axis (rounded to nearest).
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    # Per-sample integer shifts drawn from [-shift, shift].
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    # Index grids over (batch, height, width).
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
    )
    # Shift the grids; the +1 accounts for the 1-pixel zero pad added below,
    # and clamping into [0, size + 1] maps out-of-range indices onto that
    # zero border rather than wrapping.
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    # Zero-pad one pixel on each spatial side, then gather the shifted pixels
    # (channels moved last for the fancy-indexing gather, then restored).
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous(memory_format=torch.contiguous_format)
    return x
def rand_cutout(x, ratio=0.5):
    # Zero out one random rectangular region (cutout) per sample; the box is
    # `ratio` of the image's height/width on each side.
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    # Random box-center offsets; the (1 - size % 2) term keeps the offset
    # range consistent for odd cutout sizes.
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    # Index grids over (batch, box_height, box_width).
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
    )
    # Center the box on its offset and clamp into the image, so boxes at the
    # edges are truncated rather than wrapped.
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    # Build a per-sample mask that is 0 inside the box, 1 elsewhere.
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    # Broadcast the mask across channels.
    x = x * mask.unsqueeze(1)
    return x
# Registry mapping augmentation group name -> list of augmentation callables,
# consumed by DiffAugment.
AUGMENT_FNS = dict(
    color = [rand_brightness, rand_saturation, rand_contrast],
    translation = [rand_translation],
    cutout = [rand_cutout],
)
.\lucidrains\unet-stylegan2\unet_stylegan2\unet_stylegan2.py
# 导入必要的库
import os
import sys
import math
import fire
import json
from tqdm import tqdm
from math import floor, log2
from random import random
from shutil import rmtree
from functools import partial
import multiprocessing
import numpy as np
import torch
from torch import nn
from torch.utils import data
import torch.nn.functional as F
from torch.optim import Adam
from torch.autograd import grad as torch_grad
import torchvision
from torchvision import transforms
from linear_attention_transformer import ImageLinearAttention
from PIL import Image
from pathlib import Path
# Optional apex (mixed precision) support; degrade gracefully when absent.
try:
    from apex import amp
    APEX_AVAILABLE = True
# fixed: a bare `except:` also swallowed SystemExit/KeyboardInterrupt;
# only an ImportError means apex is unavailable.
except ImportError:
    APEX_AVAILABLE = False
# Training requires a CUDA-capable GPU: tensors below are created with .cuda().
assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
# CPU core count, used as the default DataLoader worker count.
num_cores = multiprocessing.cpu_count()
# constants
# Image file extensions the Dataset will pick up.
EXTS = ['jpg', 'jpeg', 'png', 'webp']
# Small epsilon used to avoid division by zero.
EPS = 1e-8
# helper classes

class NanException(Exception):
    """Raised when a loss turns NaN so the training loop can abort/restart."""
class EMA():
    """Exponential moving average helper: avg <- beta * avg + (1 - beta) * new."""
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        # With no history yet the average is simply the new value.
        if old is None:
            return new
        return self.beta * old + (1. - self.beta) * new
class RandomApply(nn.Module):
    """Apply `fn` with probability `prob`, otherwise `fn_else` (identity by default)."""
    def __init__(self, prob, fn, fn_else = lambda x: x):
        super().__init__()
        self.prob = prob
        self.fn = fn
        self.fn_else = fn_else

    def forward(self, x):
        chosen = self.fn if random() < self.prob else self.fn_else
        return chosen(x)
class Residual(nn.Module):
    """Skip connection wrapper: forward(x) = fn(x) + x."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return x + self.fn(x)
class Flatten(nn.Module):
    """Flatten all dimensions of the input starting at `index`."""
    def __init__(self, index):
        super().__init__()
        self.index = index

    def forward(self, x):
        return torch.flatten(x, self.index)
class Rezero(nn.Module):
    """ReZero gate: scale fn's output by a learned scalar initialized to zero,
    so the wrapped layer starts as a no-op contribution."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.g * self.fn(x)
# Factory: residual, ReZero-gated linear attention followed by a residual,
# ReZero-gated 1x1-conv feedforward (2x channel expansion) for feature maps
# with `chan` channels. Starts as identity because ReZero gates are zero-init.
attn_and_ff = lambda chan: nn.Sequential(*[
    Residual(Rezero(ImageLinearAttention(chan, norm_queries = True))),
    Residual(Rezero(nn.Sequential(nn.Conv2d(chan, chan * 2, 1), leaky_relu(), nn.Conv2d(chan * 2, chan, 1))))
])
# helper functions

def default(value, d):
    """Return `value` unless it is None, in which case return the fallback `d`."""
    if value is None:
        return d
    return value
def cycle(iterable):
    """Yield items from `iterable` forever, restarting it on each exhaustion.

    Unlike itertools.cycle this re-iterates the source each pass, so a
    DataLoader gets re-shuffled every epoch.
    """
    while True:
        yield from iterable
def cast_list(el):
    """Wrap `el` in a list unless it already is one (lists pass through unchanged)."""
    if isinstance(el, list):
        return el
    return [el]
def is_empty(t):
    """True for None or a zero-element tensor; False for anything else."""
    if not isinstance(t, torch.Tensor):
        return t is None
    return t.nelement() == 0
def raise_if_nan(t):
    """Raise NanException when the scalar tensor `t` is NaN; otherwise return None."""
    if not torch.isnan(t):
        return
    raise NanException
def loss_backwards(fp16, loss, optimizer, **kwargs):
    """Backpropagate `loss`, routing through apex amp loss scaling when fp16 is on."""
    if not fp16:
        loss.backward(**kwargs)
        return
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward(**kwargs)
def gradient_penalty(images, outputs, weight = 10):
    """WGAN-GP style penalty: weight * mean((||d(outputs)/d(images)||_2 - 1)^2)."""
    batch_size = images.shape[0]
    ones = [torch.ones(o.size()).cuda() for o in outputs]
    gradients = torch_grad(
        outputs=outputs,
        inputs=images,
        grad_outputs=ones,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    flat = gradients.reshape(batch_size, -1)
    return weight * ((flat.norm(2, dim=1) - 1) ** 2).mean()
# Path-length regularization term: per-sample 2-norm of the gradient of
# sum(images * unit-variance noise) with respect to the style codes.
def calc_pl_lengths(styles, images):
    num_pixels = images.shape[2] * images.shape[3]
    # Scale noise so the summed dot product has unit variance per pixel.
    pl_noise = torch.randn(images.shape).cuda() / math.sqrt(num_pixels)
    outputs = (images * pl_noise).sum()
    pl_grads = torch_grad(outputs=outputs, inputs=styles,
        grad_outputs=torch.ones(outputs.shape).cuda(),
        create_graph=True, retain_graph=True, only_inputs=True)[0]
    # Norm over the style dim, averaged over layers -> one length per sample.
    return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt()
def noise(n, latent_dim):
    """Sample `n` standard-normal latent vectors of size `latent_dim`, on the GPU."""
    z = torch.randn(n, latent_dim)
    return z.cuda()
def noise_list(n, layers, latent_dim):
    """A one-element style definition: [(latent batch, number of layers it drives)]."""
    z = noise(n, latent_dim)
    return [(z, layers)]
def mixed_list(n, layers, latent_dim):
    """Style mixing: two independent latent batches splitting `layers` at a random point."""
    split = int(torch.rand(()).numpy() * layers)
    first = noise_list(n, split, latent_dim)
    second = noise_list(n, layers - split, latent_dim)
    return first + second
def latent_to_w(style_vectorizer, latent_descr):
    """Map each (z, num_layers) pair through the style network, keeping layer counts."""
    out = []
    for z, num_layers in latent_descr:
        out.append((style_vectorizer(z), num_layers))
    return out
# Uniform [0, 1) per-pixel noise image of shape (n, im_size, im_size, 1),
# channels-last, on the GPU; shared across all generator blocks and cropped
# per block.
def image_noise(n, im_size):
    return torch.FloatTensor(n, im_size, im_size, 1).uniform_(0., 1.).cuda()
def leaky_relu(p=0.2):
    """A LeakyReLU activation module with negative slope `p` (default 0.2)."""
    return nn.LeakyReLU(negative_slope=p)
def evaluate_in_chunks(max_batch_size, model, *args):
    """Run `model` over `args` in slices of at most `max_batch_size` along dim 0,
    concatenating the per-chunk outputs (a single chunk is returned as-is)."""
    chunks = zip(*[arg.split(max_batch_size, dim=0) for arg in args])
    outputs = [model(*chunk) for chunk in chunks]
    if len(outputs) == 1:
        return outputs[0]
    return torch.cat(outputs, dim=0)
def styles_def_to_tensor(styles_def):
    """Expand each (w, n) pair to n copies along a new layer dim and concatenate,
    yielding a (batch, total_layers, latent_dim) style tensor."""
    expanded = [w[:, None, :].expand(-1, n, -1) for w, n in styles_def]
    return torch.cat(expanded, dim=1)
def set_requires_grad(model, bool):
    """Toggle requires_grad on every parameter of `model`.

    Note: the parameter name `bool` shadows the builtin; it is kept for
    interface compatibility with existing callers.
    """
    for param in model.parameters():
        param.requires_grad = bool
def slerp(val, low, high):
    """Spherical linear interpolation between batches of vectors `low` and `high`.

    `val` is the interpolation fraction (0 -> low, 1 -> high); vectors are
    batched along dim 0 with components along dim 1.
    """
    unit_low = low / torch.norm(low, dim=1, keepdim=True)
    unit_high = high / torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((unit_low * unit_high).sum(1))
    sin_omega = torch.sin(omega)
    coeff_low = (torch.sin((1.0 - val) * omega) / sin_omega).unsqueeze(1)
    coeff_high = (torch.sin(val * omega) / sin_omega).unsqueeze(1)
    return coeff_low * low + coeff_high * high
def warmup(start, end, max_steps, current_step):
    """Linearly ramp from `start` to `end` over `max_steps` steps, clamped at `end`."""
    if current_step > max_steps:
        return end
    frac = current_step / max_steps
    return (end - start) * frac + start
def log(t, eps = 1e-6):
    """Numerically-stable elementwise log: log(t + eps)."""
    return (t + eps).log()
def cutmix_coordinates(height, width, alpha = 1.):
    """Sample a CutMix box for an image of (height, width).

    Returns ((y0, y1), (x0, x1)) pixel bounds clipped to the image, and the
    Beta(alpha, alpha) mixing coefficient lam; the box covers roughly
    (1 - lam) of the image area.
    """
    lam = np.random.beta(alpha, alpha)
    center_x = np.random.uniform(0, width)
    center_y = np.random.uniform(0, height)
    box_w = width * np.sqrt(1 - lam)
    box_h = height * np.sqrt(1 - lam)
    x0 = int(np.round(max(center_x - box_w / 2, 0)))
    x1 = int(np.round(min(center_x + box_w / 2, width)))
    y0 = int(np.round(max(center_y - box_h / 2, 0)))
    y1 = int(np.round(min(center_y + box_h / 2, height)))
    return ((y0, y1), (x0, x1)), lam
def cutmix(source, target, coors, alpha = 1.):
    """Return a copy of `source` with the box given by `coors` pasted from `target`.

    Neither input tensor is mutated.
    """
    mixed, patch = map(torch.clone, (source, target))
    ((y0, y1), (x0, x1)), _ = coors
    mixed[:, :, y0:y1, x0:x1] = patch[:, :, y0:y1, x0:x1]
    return mixed
def mask_src_tgt(source, target, mask):
    """Blend two tensors: take `source` where mask == 1 and `target` where mask == 0."""
    return mask * source + target * (1 - mask)
# dataset

def convert_rgb_to_transparent(image):
    """Ensure a PIL image has an alpha channel (RGB -> RGBA); other modes pass through."""
    if image.mode != 'RGB':
        return image
    return image.convert('RGBA')
def convert_transparent_to_rgb(image):
    """Drop the alpha channel of a PIL image (RGBA -> RGB); other modes pass through."""
    if image.mode != 'RGBA':
        return image
    return image.convert('RGB')
class expand_greyscale(object):
    """Transform expanding a single-channel (1, H, W) tensor to `num_channels` channels."""
    def __init__(self, num_channels):
        self.num_channels = num_channels

    def __call__(self, tensor):
        return tensor.expand(self.num_channels, -1, -1)
def resize_to_minimum_size(min_size, image):
    """Upscale `image` when its larger side is below `min_size`; otherwise return it unchanged."""
    if max(*image.size) >= min_size:
        return image
    return torchvision.transforms.functional.resize(image, min_size)
class Dataset(data.Dataset):
    """Image-folder dataset: recursively loads files with EXTS extensions,
    resizes/crops to image_size and yields (num_channels, H, W) float tensors."""
    def __init__(self, folder, image_size, transparent = False, aug_prob = 0.):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        # Recursively collect every file with a supported extension.
        self.paths = [p for ext in EXTS for p in Path(f'{folder}').glob(f'**/*.{ext}')]
        # transparent=True keeps/creates an alpha channel (4ch); otherwise force RGB (3ch).
        convert_image_fn = convert_transparent_to_rgb if not transparent else convert_rgb_to_transparent
        num_channels = 3 if not transparent else 4
        self.transform = transforms.Compose([
            transforms.Lambda(convert_image_fn),
            # Upscale images smaller than the target size before cropping.
            transforms.Lambda(partial(resize_to_minimum_size, image_size)),
            transforms.Resize(image_size),
            # With probability aug_prob use a random resized crop, else a center crop.
            RandomApply(aug_prob, transforms.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)), transforms.CenterCrop(image_size)),
            transforms.ToTensor(),
            # Greyscale images decode to one channel; expand to num_channels.
            transforms.Lambda(expand_greyscale(num_channels))
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # Load the image at `index` and run it through the transform pipeline.
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)
# NOTE(review): this block was labeled `GeneratorBlock` in the walkthrough, but
# its body (style-modulated 1x1 conv to 3/4 channels with demod=False, skip-add
# with prev_rgb, optional upsample) is upstream StyleGAN2's `RGBBlock`. The real
# GeneratorBlock __init__/forward follow below and instantiate `RGBBlock` by
# this name, so the correct class header is restored here.
class RGBBlock(nn.Module):
    """Project block features to an RGB(A) image and accumulate it across resolutions."""
    def __init__(self, latent_dim, input_channel, upsample, rgba=False):
        super().__init__()
        self.input_channel = input_channel
        # Style vector modulating the 1x1 projection.
        self.to_style = nn.Linear(latent_dim, input_channel)
        # 4 output channels when an alpha channel is requested, else 3.
        out_filters = 3 if not rgba else 4
        # Modulated 1x1 conv without demodulation (StyleGAN2 toRGB convention).
        self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)
        # Upsample the accumulated RGB toward the next (higher) resolution.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None

    def forward(self, x, prev_rgb, istyle):
        b, c, h, w = x.shape
        style = self.to_style(istyle)
        x = self.conv(x, style)
        # Accumulate with the RGB produced at the previous resolution, if any.
        if prev_rgb is not None:
            x = x + prev_rgb
        if self.upsample is not None:
            x = self.upsample(x)
        return x
# GeneratorBlock.__init__ (the class header is lost in this excerpt): one
# synthesis block — two style-modulated 3x3 convs with per-pixel noise
# injection, plus a toRGB head.
def __init__(self, latent_dim, input_channels, filters, upsample = True, upsample_rgb = True, rgba = False):
    super().__init__()
    # Upsample incoming features 2x; disabled only for the first (4x4) block.
    self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None
    # Style projection for the first modulated convolution.
    self.to_style1 = nn.Linear(latent_dim, input_channels)
    # Per-pixel noise is scalar; learn a per-channel scale for it.
    self.to_noise1 = nn.Linear(1, filters)
    self.conv1 = Conv2DMod(input_channels, filters, 3)
    # Style/noise projections for the second modulated convolution.
    self.to_style2 = nn.Linear(latent_dim, filters)
    self.to_noise2 = nn.Linear(1, filters)
    self.conv2 = Conv2DMod(filters, filters, 3)
    self.activation = leaky_relu()
    # toRGB head; upsample_rgb is False only at the final resolution.
    self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)
# GeneratorBlock.forward: synthesize features at this resolution and update
# the running RGB image.
def forward(self, x, prev_rgb, istyle, inoise):
    if self.upsample is not None:
        x = self.upsample(x)
    # Crop the shared channels-last (B, H, W, 1) noise image to this block's size.
    inoise = inoise[:, :x.shape[2], :x.shape[3], :]
    # Project noise to per-channel maps and move channels first.
    # NOTE(review): permute (0, 3, 2, 1) also swaps H and W — presumably
    # harmless for the square crops used here; confirm against upstream.
    noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1))
    noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1))
    # Two style-modulated convs, each followed by noise injection + activation.
    style1 = self.to_style1(istyle)
    x = self.conv1(x, style1)
    x = self.activation(x + noise1)
    style2 = self.to_style2(istyle)
    x = self.conv2(x, style2)
    x = self.activation(x + noise2)
    # Accumulate this resolution's RGB contribution.
    rgb = self.to_rgb(x, prev_rgb, istyle)
    return x, rgb
def double_conv(chan_in, chan_out):
    """Two 3x3 same-padding convolutions, each followed by a LeakyReLU."""
    layers = [
        nn.Conv2d(chan_in, chan_out, 3, padding=1),
        leaky_relu(),
        nn.Conv2d(chan_out, chan_out, 3, padding=1),
        leaky_relu(),
    ]
    return nn.Sequential(*layers)
# Discriminator encoder block: a double conv with a 1x1 residual shortcut and
# optional stride-2 downsampling.
class DownBlock(nn.Module):
    def __init__(self, input_channels, filters, downsample=True):
        super().__init__()
        # 1x1 shortcut, strided when this block downsamples.
        stride = 2 if downsample else 1
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=stride)
        self.net = double_conv(input_channels, filters)
        self.down = nn.Conv2d(filters, filters, 3, padding=1, stride=2) if downsample else None

    def forward(self, x):
        shortcut = self.conv_res(x)
        x = self.net(x)
        # Pre-downsample features feed the U-Net decoder skip connections.
        unet_res = x
        if self.down is not None:
            x = self.down(x)
        return x + shortcut, unet_res
# Discriminator decoder block: bilinear upsample + skip concat + double conv,
# with a transposed-conv residual shortcut.
class UpBlock(nn.Module):
    def __init__(self, input_channels, filters):
        super().__init__()
        # 1x1 transposed conv doubling resolution; operates on the non-skip
        # half of the input channels as a residual shortcut.
        self.conv_res = nn.ConvTranspose2d(input_channels // 2, filters, 1, stride=2)
        self.net = double_conv(input_channels, filters)
        # Bilinear 2x upsampling, corners not aligned.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x, res):
        *_, h, w = x.shape
        # output_size pins the transposed conv's ambiguous output shape to exactly 2x.
        conv_res = self.conv_res(x, output_size=(h * 2, w * 2))
        x = self.up(x)
        # Concatenate the encoder skip connection along the channel dim.
        x = torch.cat((x, res), dim=1)
        x = self.net(x)
        x = x + conv_res
        return x
# StyleGAN2 synthesis network.
class Generator(nn.Module):
    def __init__(self, image_size, latent_dim, network_capacity=16, transparent=False, no_const=False, fmap_max=512):
        super().__init__()
        self.image_size = image_size
        self.latent_dim = latent_dim
        # One synthesis block per doubling from 4x4 up to image_size.
        self.num_layers = int(log2(image_size) - 1)
        # Channel widths per stage, widest first, capped at fmap_max.
        filters = [network_capacity * (2 ** (i + 1)) for i in range(self.num_layers)][::-1]
        set_fmap_max = partial(min, fmap_max)
        filters = list(map(set_fmap_max, filters))
        init_channels = filters[0]
        filters = [init_channels, *filters]
        in_out_pairs = zip(filters[:-1], filters[1:])
        self.no_const = no_const
        if no_const:
            # Derive the initial 4x4 feature block from the (averaged) style vector.
            self.to_initial_block = nn.ConvTranspose2d(latent_dim, init_channels, 4, 1, 0, bias=False)
        else:
            # Learned constant 4x4 input, shared across samples.
            self.initial_block = nn.Parameter(torch.randn((1, init_channels, 4, 4)))
        self.initial_conv = nn.Conv2d(filters[0], filters[0], 3, padding=1)
        self.blocks = nn.ModuleList([])
        self.attns = nn.ModuleList([])
        for ind, (in_chan, out_chan) in enumerate(in_out_pairs):
            not_first = ind != 0
            not_last = ind != (self.num_layers - 1)
            num_layer = self.num_layers - ind
            # Linear-attention + feedforward applied before each synthesis block.
            attn_fn = attn_and_ff(in_chan)
            self.attns.append(attn_fn)
            block = GeneratorBlock(
                latent_dim,
                in_chan,
                out_chan,
                upsample=not_first,
                upsample_rgb=not_last,
                rgba=transparent
            )
            self.blocks.append(block)

    def forward(self, styles, input_noise):
        # styles: (batch, num_layers, latent_dim); input_noise: shared pixel noise.
        batch_size = styles.shape[0]
        image_size = self.image_size
        if self.no_const:
            avg_style = styles.mean(dim=1)[:, :, None, None]
            x = self.to_initial_block(avg_style)
        else:
            x = self.initial_block.expand(batch_size, -1, -1, -1)
        x = self.initial_conv(x)
        # Iterate per layer: after transpose, styles is (num_layers, batch, latent_dim).
        styles = styles.transpose(0, 1)
        rgb = None
        for style, block, attn in zip(styles, self.blocks, self.attns):
            if attn is not None:
                x = attn(x)
            x, rgb = block(x, rgb, style, input_noise)
        # The accumulated RGB image is the generator output.
        return rgb
class Discriminator(nn.Module):
    """U-Net discriminator: an encoder yielding one real/fake logit per image
    plus a decoder yielding a per-pixel real/fake map."""
    def __init__(self, image_size, network_capacity = 16, transparent = False, fmap_max = 512):
        super().__init__()
        # Number of downsampling stages.
        num_layers = int(log2(image_size) - 3)
        # 4 input channels with alpha, else 3.
        num_init_filters = 3 if not transparent else 4
        blocks = []
        # Channel widths per stage, capped at fmap_max; last stage repeats the
        # previous width instead of doubling.
        filters = [num_init_filters] + [(network_capacity) * (2 ** i) for i in range(num_layers + 1)]
        set_fmap_max = partial(min, fmap_max)
        filters = list(map(set_fmap_max, filters))
        filters[-1] = filters[-2]
        chan_in_out = list(zip(filters[:-1], filters[1:]))
        chan_in_out = list(map(list, chan_in_out))
        down_blocks = []
        attn_blocks = []
        # Encoder: a DownBlock (+ attention) per stage; last stage keeps resolution.
        for ind, (in_chan, out_chan) in enumerate(chan_in_out):
            num_layer = ind + 1
            is_not_last = ind != (len(chan_in_out) - 1)
            block = DownBlock(in_chan, out_chan, downsample = is_not_last)
            down_blocks.append(block)
            attn_fn = attn_and_ff(out_chan)
            attn_blocks.append(attn_fn)
        self.down_blocks = nn.ModuleList(down_blocks)
        self.attn_blocks = nn.ModuleList(attn_blocks)
        last_chan = filters[-1]
        # Encoder head: global average pool -> single real/fake logit.
        self.to_logit = nn.Sequential(
            leaky_relu(),
            nn.AvgPool2d(image_size // (2 ** num_layers)),
            Flatten(1),
            nn.Linear(last_chan, 1)
        )
        self.conv = double_conv(last_chan, last_chan)
        # Decoder mirrors the encoder (excluding the last stage), reversed;
        # each UpBlock consumes decoder features concatenated with a skip.
        dec_chan_in_out = chan_in_out[:-1][::-1]
        self.up_blocks = nn.ModuleList(list(map(lambda c: UpBlock(c[1] * 2, c[0]), dec_chan_in_out)))
        # 1x1 conv to the single-channel per-pixel decision map.
        # NOTE(review): the in-channels are hard-coded to 3, matching the
        # non-transparent input width — presumably untested for transparent=True.
        self.conv_out = nn.Conv2d(3, 1, 1)

    def forward(self, x):
        b, *_ = x.shape
        residuals = []
        # Encoder pass, collecting pre-downsample features as U-Net skips.
        for (down_block, attn_block) in zip(self.down_blocks, self.attn_blocks):
            x, unet_res = down_block(x)
            residuals.append(unet_res)
            if attn_block is not None:
                x = attn_block(x)
        # Bottleneck residual conv, then the scalar logit.
        x = self.conv(x) + x
        enc_out = self.to_logit(x)
        # Decoder pass with skip connections (the deepest skip is unused).
        for (up_block, res) in zip(self.up_blocks, residuals[:-1][::-1]):
            x = up_block(x, res)
        dec_out = self.conv_out(x)
        return enc_out.squeeze(), dec_out
class StyleGAN2(nn.Module):
    """Container wiring together the style network (S), generator (G), U-Net
    discriminator (D), their EMA copies (SE/GE), the augmentation wrapper, and
    both optimizers."""
    def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8, network_capacity = 16, transparent = False, fp16 = False, steps = 1, lr = 1e-4, ttur_mult = 2, no_const = False, lr_mul = 0.1, aug_types = ['translation', 'cutout']):
        super().__init__()
        self.lr = lr
        self.steps = steps
        self.ema_updater = EMA(0.995)
        self.S = StyleVectorizer(latent_dim, style_depth, lr_mul = lr_mul)
        self.G = Generator(image_size, latent_dim, network_capacity, transparent = transparent, no_const = no_const, fmap_max = fmap_max)
        self.D = Discriminator(image_size, network_capacity, transparent = transparent, fmap_max = fmap_max)
        # Exponential-moving-average copies used at evaluation time.
        self.SE = StyleVectorizer(latent_dim, style_depth, lr_mul = lr_mul)
        self.GE = Generator(image_size, latent_dim, network_capacity, transparent = transparent, no_const = no_const)
        # Wrapper that augments every image fed to the discriminator.
        self.D_aug = AugWrapper(self.D, image_size, aug_types)
        # The EMA copies are never trained directly.
        set_requires_grad(self.SE, False)
        set_requires_grad(self.GE, False)
        generator_params = list(self.G.parameters()) + list(self.S.parameters())
        self.G_opt = Adam(generator_params, lr = self.lr, betas=(0.5, 0.9))
        # TTUR: the discriminator trains with a larger learning rate.
        self.D_opt = Adam(self.D.parameters(), lr = self.lr * ttur_mult, betas=(0.5, 0.9))
        self._init_weights()
        self.reset_parameter_averaging()
        self.cuda()
        self.fp16 = fp16
        if fp16:
            # Initialize apex mixed-precision for all submodules and both optimizers.
            (self.S, self.G, self.D, self.SE, self.GE), (self.G_opt, self.D_opt) = amp.initialize([self.S, self.G, self.D, self.SE, self.GE], [self.G_opt, self.D_opt], opt_level='O1')

    def _init_weights(self):
        # Kaiming-normal init for convs/linears; generator noise projections
        # start at exactly zero so noise injection begins as a no-op.
        for m in self.modules():
            if type(m) in {nn.Conv2d, nn.Linear}:
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
        for block in self.G.blocks:
            nn.init.zeros_(block.to_noise1.weight)
            nn.init.zeros_(block.to_noise2.weight)
            nn.init.zeros_(block.to_noise1.bias)
            nn.init.zeros_(block.to_noise2.bias)

    def EMA(self):
        # Blend the current S/G weights into the SE/GE moving averages.
        def update_moving_average(ma_model, current_model):
            for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
                old_weight, up_weight = ma_params.data, current_params.data
                ma_params.data = self.ema_updater.update_average(old_weight, up_weight)
        update_moving_average(self.SE, self.S)
        update_moving_average(self.GE, self.G)

    def reset_parameter_averaging(self):
        # Start the EMA copies from the current weights.
        self.SE.load_state_dict(self.S.state_dict())
        self.GE.load_state_dict(self.G.state_dict())

    def forward(self, x):
        # Identity: the training logic lives in the Trainer, not in forward().
        return x
class Trainer():
    """Orchestrates training: owns the GAN, the data loader, checkpointing,
    evaluation output and logging."""
    def __init__(self, name, results_dir, models_dir, image_size, network_capacity, transparent = False, batch_size = 4, mixed_prob = 0.9, gradient_accumulate_every=1, lr = 2e-4, ttur_mult = 2, num_workers = None, save_every = 1000, trunc_psi = 0.6, fp16 = False, no_const = False, aug_prob = 0., dataset_aug_prob = 0., cr_weight = 0.2, apply_pl_reg = False, lr_mul = 0.1, *args, **kwargs):
        # Extra positional/keyword arguments are forwarded to StyleGAN2 in init_GAN.
        self.GAN_params = [args, kwargs]
        self.GAN = None
        # Run name plus output locations; config lives next to the models.
        self.name = name
        self.results_dir = Path(results_dir)
        self.models_dir = Path(models_dir)
        self.config_path = self.models_dir / name / '.config.json'
        assert log2(image_size).is_integer(), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
        self.image_size = image_size
        self.network_capacity = network_capacity
        self.transparent = transparent
        self.no_const = no_const
        self.aug_prob = aug_prob
        self.lr = lr
        self.ttur_mult = ttur_mult
        self.lr_mul = lr_mul
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Probability of using two mixed latents (style mixing) per step.
        self.mixed_prob = mixed_prob
        self.save_every = save_every
        self.steps = 0
        # Cached mean style vector for the truncation trick (computed lazily).
        self.av = None
        self.trunc_psi = trunc_psi
        self.apply_pl_reg = apply_pl_reg
        # Path-length EMA mean; None until first measured.
        self.pl_mean = None
        self.gradient_accumulate_every = gradient_accumulate_every
        assert not fp16 or fp16 and APEX_AVAILABLE, 'Apex is not available for you to use mixed precision training'
        self.fp16 = fp16
        # Most recent loss values, for print_log.
        self.d_loss = 0
        self.g_loss = 0
        self.last_gp_loss = 0
        self.last_cr_loss = 0
        self.pl_length_ma = EMA(0.99)
        self.init_folders()
        self.loader = None
        self.dataset_aug_prob = dataset_aug_prob
        # Weight of the CutMix consistency-regularization loss.
        self.cr_weight = cr_weight
# Build the StyleGAN2 instance from the stored hyperparameters.
def init_GAN(self):
    args, kwargs = self.GAN_params
    self.GAN = StyleGAN2(lr = self.lr, ttur_mult = self.ttur_mult, lr_mul = self.lr_mul, image_size = self.image_size, network_capacity = self.network_capacity, transparent = self.transparent, fp16 = self.fp16, no_const = self.no_const, *args, **kwargs)
# Serialize the current configuration dict to the run's .config.json.
def write_config(self):
    self.config_path.write_text(json.dumps(self.config()))
# Restore settings from .config.json (if present) and rebuild the GAN to match.
def load_config(self):
    config = self.config() if not self.config_path.exists() else json.loads(self.config_path.read_text())
    self.image_size = config['image_size']
    self.network_capacity = config['network_capacity']
    self.transparent = config['transparent']
    # Older saved configs may lack 'no_const'; default to False.
    self.no_const = config.pop('no_const', False)
    del self.GAN
    self.init_GAN()
# The settings that must survive a restart to rebuild an identical GAN.
def config(self):
    return dict(
        image_size = self.image_size,
        network_capacity = self.network_capacity,
        transparent = self.transparent,
        no_const = self.no_const,
    )
# Point training at an image folder and build an endless, shuffled loader.
def set_data_src(self, folder):
    self.dataset = Dataset(folder, self.image_size, transparent = self.transparent, aug_prob = self.dataset_aug_prob)
    self.loader = cycle(data.DataLoader(self.dataset, num_workers = default(self.num_workers, num_cores), batch_size = self.batch_size, drop_last = True, shuffle=True, pin_memory=True))
@torch.no_grad()
def evaluate(self, num = 0, num_image_tiles = 8, trunc = 1.0):
    """Save three sample grids named by `num`: current weights, EMA weights,
    and a style-mixing grid."""
    self.GAN.eval()
    # PNG keeps the alpha channel when training transparent images.
    ext = 'jpg' if not self.transparent else 'png'
    num_rows = num_image_tiles
    latent_dim = self.GAN.G.latent_dim
    image_size = self.GAN.G.image_size
    num_layers = self.GAN.G.num_layers
    # latents and noise
    latents = noise_list(num_rows ** 2, num_layers, latent_dim)
    n = image_noise(num_rows ** 2, image_size)
    # regular: grid rendered with the live S/G weights.
    generated_images = self.generate_truncated(self.GAN.S, self.GAN.G, latents, n, trunc_psi = self.trunc_psi)
    torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}.{ext}'), nrow=num_rows)
    # moving averages: same latents through the EMA copies SE/GE.
    generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, latents, n, trunc_psi = self.trunc_psi)
    torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-ema.{ext}'), nrow=num_rows)
    # mixing regularities
    def tile(a, dim, n_tile):
        # Repeat each row of `a` n_tile times along `dim`, interleaved so that
        # row i's copies stay adjacent.
        init_dim = a.size(dim)
        repeat_idx = [1] * a.dim()
        repeat_idx[dim] = n_tile
        a = a.repeat(*(repeat_idx))
        order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).cuda()
        return torch.index_select(a, dim, order_index)
    # NOTE: local `nn` shadows torch.nn for the rest of this method.
    nn = noise(num_rows, latent_dim)
    tmp1 = tile(nn, 0, num_rows)
    tmp2 = nn.repeat(num_rows, 1)
    # First half of the layers uses one latent, second half the other.
    tt = int(num_layers / 2)
    mixed_latents = [(tmp1, tt), (tmp2, num_layers - tt)]
    generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, mixed_latents, n, trunc_psi = self.trunc_psi)
    torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-mr.{ext}'), nrow=num_rows)
@torch.no_grad()
def generate_truncated(self, S, G, style, noi, trunc_psi = 0.75, num_image_tiles = 8):
    """Map latents through S, apply the truncation trick toward the mean style
    vector, and render images with G (clamped to [0, 1])."""
    latent_dim = G.latent_dim
    if self.av is None:
        # Estimate the mean style vector from 2000 random latents (cached on self).
        z = noise(2000, latent_dim)
        samples = evaluate_in_chunks(self.batch_size, S, z).cpu().numpy()
        self.av = np.mean(samples, axis = 0)
        self.av = np.expand_dims(self.av, axis = 0)
    w_space = []
    for tensor, num_layers in style:
        tmp = S(tensor)
        av_torch = torch.from_numpy(self.av).cuda()
        # Truncation trick: pull w toward the average style vector by trunc_psi.
        tmp = trunc_psi * (tmp - av_torch) + av_torch
        w_space.append((tmp, num_layers))
    w_styles = styles_def_to_tensor(w_space)
    generated_images = evaluate_in_chunks(self.batch_size, G, w_styles, noi)
    return generated_images.clamp_(0., 1.)
@torch.no_grad()
def generate_interpolation(self, num = 0, num_image_tiles = 8, trunc = 1.0, save_frames = False):
    """Render a looping GIF that slerps between two random latent batches.

    Args:
        num: index used to name the output files.
        num_image_tiles: output grid is num_image_tiles x num_image_tiles.
        trunc: unused here; truncation uses self.trunc_psi.
        save_frames: when True, additionally save every frame as an image.
    """
    self.GAN.eval()
    ext = 'jpg' if not self.transparent else 'png'
    num_rows = num_image_tiles
    latent_dim = self.GAN.G.latent_dim
    image_size = self.GAN.G.image_size
    num_layers = self.GAN.G.num_layers
    # Interpolation endpoints plus shared per-pixel noise.
    latents_low = noise(num_rows ** 2, latent_dim)
    latents_high = noise(num_rows ** 2, latent_dim)
    n = image_noise(num_rows ** 2, image_size)
    ratios = torch.linspace(0., 8., 100)
    frames = []
    for ratio in tqdm(ratios):
        # Spherical interpolation between the two latent batches.
        interp_latents = slerp(ratio, latents_low, latents_high)
        latents = [(interp_latents, num_layers)]
        generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, latents, n, trunc_psi = self.trunc_psi)
        images_grid = torchvision.utils.make_grid(generated_images, nrow = num_rows)
        pil_image = transforms.ToPILImage()(images_grid.cpu())
        frames.append(pil_image)
    # Assemble all frames into one looping GIF.
    frames[0].save(str(self.results_dir / self.name / f'{str(num)}.gif'), save_all=True, append_images=frames[1:], duration=80, loop=0, optimize=True)
    if save_frames:
        folder_path = (self.results_dir / self.name / f'{str(num)}')
        folder_path.mkdir(parents=True, exist_ok=True)
        for ind, frame in enumerate(frames):
            # fixed: this call was missing its closing parenthesis (syntax error).
            frame.save(str(folder_path / f'{str(ind)}.{ext}'))
# Print a one-line summary of the most recent loss values.
def print_log(self):
    pl = default(self.pl_mean, 0)
    parts = [
        f'G: {self.g_loss:.2f}',
        f'D: {self.d_loss:.2f}',
        f'GP: {self.last_gp_loss:.2f}',
        f'PL: {pl:.2f}',
        f'CR: {self.last_cr_loss:.2f}',
    ]
    print(' | '.join(parts))
# Path of the checkpoint file for checkpoint index `num`.
def model_name(self, num):
    checkpoint = self.models_dir / self.name / f'model_{num}.pt'
    return str(checkpoint)
# Create this run's results and models directories if they don't exist yet.
def init_folders(self):
    for base in (self.results_dir, self.models_dir):
        (base / self.name).mkdir(parents=True, exist_ok=True)
# Delete this run's models, results and config, then recreate empty folders.
def clear(self):
    # fixed: previously removed hard-coded './models' and './results' paths,
    # silently ignoring the configured models_dir/results_dir, and called
    # rmtree on the config *file*; use the configured directories and unlink
    # the config file instead.
    rmtree(str(self.models_dir / self.name), True)
    rmtree(str(self.results_dir / self.name), True)
    if self.config_path.exists():
        self.config_path.unlink()
    self.init_folders()
# Persist GAN weights (and apex amp state when fp16) under models_dir, then
# write the config so a later load can rebuild the same architecture.
def save(self, num):
    save_data = {'GAN': self.GAN.state_dict()}
    if self.GAN.fp16:
        save_data['amp'] = amp.state_dict()
    torch.save(save_data, self.model_name(num))
    self.write_config()
# Load a checkpoint; num == -1 picks the most recent save, doing nothing if
# no checkpoint exists yet.
def load(self, num = -1):
    self.load_config()
    name = num
    if num == -1:
        # Find the highest-numbered model_*.pt in this run's folder.
        file_paths = [p for p in Path(self.models_dir / self.name).glob('model_*.pt')]
        saved_nums = sorted(map(lambda x: int(x.stem.split('_')[1]), file_paths))
        if len(saved_nums) == 0:
            return
        name = saved_nums[-1]
        print(f'continuing from previous epoch - {name}')
    # Resume the global step counter from the checkpoint index.
    self.steps = name * self.save_every
    load_data = torch.load(self.model_name(name))
    self.GAN.load_state_dict(load_data['GAN'])
    if self.GAN.fp16 and 'amp' in load_data:
        amp.load_state_dict(load_data['amp'])
.\lucidrains\unet-stylegan2\unet_stylegan2\__init__.py
# 从 unet_stylegan2 模块中导入 Trainer, StyleGAN2 和 NanException 类
from unet_stylegan2.unet_stylegan2 import Trainer, StyleGAN2, NanException

Uniformer - Pytorch
Implementation of Uniformer, a simple attention and 3d convolutional net that achieved SOTA in a number of video classification tasks
Install
$ pip install uniformer-pytorch
Usage
Uniformer-S
import torch
from uniformer_pytorch import Uniformer
model = Uniformer(
num_classes = 1000, # number of output classes
dims = (64, 128, 256, 512), # feature dimensions per stage (4 stages)
depths = (3, 4, 8, 3), # depth at each stage
mhsa_types = ('l', 'l', 'g', 'g') # aggregation type at each stage, 'l' stands for local, 'g' stands for global
)
video = torch.randn(1, 3, 8, 224, 224) # (batch, channels, time, height, width)
logits = model(video) # (1, 1000)
Uniformer-B
import torch
from uniformer_pytorch import Uniformer
model = Uniformer(
num_classes = 1000,
depths = (5, 8, 20, 7)
)
Citations
@inproceedings{anonymous2022uniformer,
title = {UniFormer: Unified Transformer for Efficient Spatial-Temporal Representation Learning},
author = {Anonymous},
booktitle = {Submitted to The Tenth International Conference on Learning Representations },
year = {2022},
url = {https://openreview.net/forum?id=nBU_u6DLvoK},
note = {under review}
}
.\lucidrains\uniformer-pytorch\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# 设置包的元数据
setup(
name = 'uniformer-pytorch', # 包的名称
packages = find_packages(), # 查找所有包
version = '0.0.4', # 版本号
license='MIT', # 许可证
description = 'Uniformer - Pytorch', # 描述
author = 'Phil Wang', # 作者
author_email = 'lucidrains@gmail.com', # 作者邮箱
url = 'https://github.com/lucidrains/uniformer-pytorch', # 项目链接
keywords = [ # 关键词列表
'artificial intelligence',
'attention mechanism',
'video classification'
],
install_requires=[ # 安装依赖
'einops>=0.3',
'torch>=1.6'
],
classifiers=[ # 分类器列表
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\uniformer-pytorch\uniformer_pytorch\uniformer_pytorch.py
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Reduce
# helpers
# 检查值是否存在的辅助函数
def exists(val):
    """True when `val` is provided (i.e. not None); falsy values like 0 still count."""
    return val is not None
# classes
# LayerNorm 类
class LayerNorm(nn.Module):
    """Channel-wise layer norm for (batch, channel, time, height, width) tensors,
    with learned per-channel scale `g` and shift `b`."""

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1, 1))

    def forward(self, x):
        mean = x.mean(dim = 1, keepdim = True)
        # population (unbiased=False) standard deviation over channels;
        # eps is added to std (not variance) for numerical stability
        std = x.var(dim = 1, unbiased = False, keepdim = True).sqrt()
        normed = (x - mean) / (std + self.eps)
        return normed * self.g + self.b
# FeedForward 函数
def FeedForward(dim, mult = 4, dropout = 0.):
    """Pre-norm pointwise feedforward: LayerNorm -> 1x1x1 conv expand -> GELU ->
    dropout -> 1x1x1 conv project back to `dim`."""
    hidden_dim = dim * mult
    layers = [
        LayerNorm(dim),
        nn.Conv3d(dim, hidden_dim, 1),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Conv3d(hidden_dim, dim, 1),
    ]
    return nn.Sequential(*layers)
# MHRAs (multi-head relation aggregators)
# LocalMHRA 类
class LocalMHRA(nn.Module):
    """Local multi-head relation aggregator.

    Instead of a computed attention matrix, each head aggregates values over a
    small local window with a learned depthwise 3D convolution, so the mixing
    weights depend only on relative position (UniFormer's local MHRA, used in
    the early stages).
    """

    def __init__(
        self,
        dim,
        heads,
        dim_head = 64,
        local_aggr_kernel = 5
    ):
        super().__init__()
        self.heads = heads
        inner_dim = dim_head * heads
        # BatchNorm3d is used in place of LayerNorm for the local stages
        self.norm = nn.BatchNorm3d(dim)
        # values only -- the "attention matrix" is replaced by the conv below
        self.to_v = nn.Conv3d(dim, inner_dim, 1, bias = False)
        # depthwise (groups = heads) conv: per-head aggregation by relative position
        self.rel_pos = nn.Conv3d(heads, heads, local_aggr_kernel, padding = local_aggr_kernel // 2, groups = heads)
        # merge all heads back to `dim`
        self.to_out = nn.Conv3d(inner_dim, dim, 1)

    def forward(self, x):
        x = self.norm(x)
        b, c, *_, h = *x.shape, self.heads
        # project to values
        v = self.to_v(x)
        # fold the per-head feature dim into batch so rel_pos sees `heads` channels
        v = rearrange(v, 'b (c h) ... -> (b c) h ...', h = h)
        # aggregate via relative position (depthwise conv)
        out = self.rel_pos(v)
        # restore (batch, inner_dim, t, h, w)
        out = rearrange(out, '(b c) h ... -> b (c h) ...', b = b)
        return self.to_out(out)
# GlobalMHRA 类
class GlobalMHRA(nn.Module):
    """Global multi-head relation aggregator: standard softmax self-attention
    across every spatio-temporal position, computed on flattened tokens."""

    def __init__(
        self,
        dim,
        heads,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads
        self.norm = LayerNorm(dim)
        self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv1d(inner_dim, dim, 1)

    def forward(self, x):
        x = self.norm(x)
        shape, h = x.shape, self.heads

        # flatten all spatio-temporal axes into a single token axis
        x = rearrange(x, 'b c ... -> b c (...)')

        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h d) n -> b h n d', h = h), qkv)

        # scaled dot-product attention
        sim = einsum('b h i d, b h j d -> b h i j', q * self.scale, k)
        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b (h d) n', h = h)

        # project out, then restore the original (b, c, t, h, w) shape
        return self.to_out(out).view(*shape)
# Transformer 类
class Transformer(nn.Module):
    """One UniFormer stage: `depth` blocks of
    (3x3x3 conv position encoding) -> (local or global MHRA) -> (feedforward),
    each with a residual connection."""

    def __init__(
        self,
        *,
        dim,
        depth,
        heads,
        mhsa_type = 'g',
        local_aggr_kernel = 5,
        dim_head = 64,
        ff_mult = 4,
        ff_dropout = 0.,
        attn_dropout = 0.
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            # choose the aggregation flavor for this stage
            if mhsa_type == 'l':
                aggregator = LocalMHRA(dim, heads = heads, dim_head = dim_head, local_aggr_kernel = local_aggr_kernel)
            elif mhsa_type == 'g':
                aggregator = GlobalMHRA(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout)
            else:
                raise ValueError('unknown mhsa_type')

            self.layers.append(nn.ModuleList([
                nn.Conv3d(dim, dim, 3, padding = 1),  # dynamic position encoding
                aggregator,
                FeedForward(dim, mult = ff_mult, dropout = ff_dropout),
            ]))

    def forward(self, x):
        # every sub-module is applied residually
        for dpe, attn, ff in self.layers:
            x = x + dpe(x)
            x = x + attn(x)
            x = x + ff(x)
        return x
# 主类定义
class Uniformer(nn.Module):
    """UniFormer: a 4-stage video backbone mixing local (conv-based) and global
    (attention-based) multi-head relation aggregators, followed by a
    mean-pooled classification head.

    Fix: `dim_head` and `local_aggr_kernel` are now forwarded to each stage's
    Transformer. Previously they were accepted but silently dropped, so
    non-default values had no effect on the blocks (and `heads`, computed from
    `dim_head`, could disagree with the attention's internal default of 64).
    Defaults are unchanged, so existing callers are unaffected.
    """

    def __init__(
        self,
        *,
        num_classes,                        # number of output classes
        dims = (64, 128, 256, 512),         # feature dimension per stage
        depths = (3, 4, 8, 3),              # number of blocks per stage
        mhsa_types = ('l', 'l', 'g', 'g'),  # per-stage aggregation: 'l' local, 'g' global
        local_aggr_kernel = 5,              # kernel size of the local aggregation conv
        channels = 3,                       # input video channels
        ff_mult = 4,                        # feedforward expansion factor
        dim_head = 64,                      # per-head dimension
        ff_dropout = 0.,                    # feedforward dropout
        attn_dropout = 0.                   # attention dropout
    ):
        super().__init__()
        init_dim, *_, last_dim = dims

        # stem: patchify the video into tokens (time stride 2, spatial stride 4)
        self.to_tokens = nn.Conv3d(channels, init_dim, (3, 4, 4), stride = (2, 4, 4), padding = (1, 0, 0))

        mhsa_types = tuple(map(lambda t: t.lower(), mhsa_types))

        self.stages = nn.ModuleList([])

        for ind, (depth, mhsa_type) in enumerate(zip(depths, mhsa_types)):
            is_last = ind == len(depths) - 1
            stage_dim = dims[ind]
            heads = stage_dim // dim_head

            self.stages.append(nn.ModuleList([
                Transformer(
                    dim = stage_dim,
                    depth = depth,
                    heads = heads,
                    mhsa_type = mhsa_type,
                    local_aggr_kernel = local_aggr_kernel,  # was previously dropped
                    dim_head = dim_head,                    # was previously dropped
                    ff_mult = ff_mult,
                    ff_dropout = ff_dropout,
                    attn_dropout = attn_dropout
                ),
                # 2x spatial downsample between stages (time dimension preserved)
                nn.Sequential(
                    nn.Conv3d(stage_dim, dims[ind + 1], (1, 2, 2), stride = (1, 2, 2)),
                    LayerNorm(dims[ind + 1]),
                ) if not is_last else None
            ]))

        # head: global average pool over (t, h, w), then norm and project to logits
        self.to_logits = nn.Sequential(
            Reduce('b c t h w -> b c', 'mean'),
            nn.LayerNorm(last_dim),
            nn.Linear(last_dim, num_classes)
        )

    def forward(self, video):
        """video: (batch, channels, time, height, width) -> logits (batch, num_classes)"""
        x = self.to_tokens(video)

        for transformer, conv in self.stages:
            x = transformer(x)

            # downsample between stages; conv is None after the final stage
            if exists(conv):
                x = conv(x)

        return self.to_logits(x)
.\lucidrains\uniformer-pytorch\uniformer_pytorch\__init__.py
# 从 uniformer_pytorch 包中导入 Uniformer 类
from uniformer_pytorch.uniformer_pytorch import Uniformer
.\lucidrains\vector-quantize-pytorch\examples\autoencoder.py
# 导入所需的库
from tqdm.auto import trange
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from vector_quantize_pytorch import VectorQuantize
# hyperparameters
lr = 3e-4
train_iter = 1000
num_codes = 256
seed = 1234
# use GPU when available
device = "cuda" if torch.cuda.is_available() else "cpu"
# 定义简单的 VQ 自编码器模型
class SimpleVQAutoEncoder(nn.Module):
    """Small conv autoencoder for 28x28 grayscale images with a VQ bottleneck.

    Forward returns (reconstruction clamped to [-1, 1], code indices,
    commitment loss); the extra outputs come from the VectorQuantize layer.
    """

    def __init__(self, **vq_kwargs):
        super().__init__()
        self.layers = nn.ModuleList([
            # encoder: 28x28 -> 14x14 -> 7x7
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # quantization bottleneck
            VectorQuantize(dim=32, accept_image_fmap=True, **vq_kwargs),
            # decoder: 7x7 -> 14x14 -> 28x28
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
        ])

    def forward(self, x):
        for layer in self.layers:
            if isinstance(layer, VectorQuantize):
                # the VQ layer yields extra outputs alongside the features
                x, indices, commit_loss = layer(x)
            else:
                x = layer(x)
        return x.clamp(-1, 1), indices, commit_loss
# 训练函数
def train(model, train_loader, train_iterations=1000, alpha=10):
    """Train the VQ autoencoder for `train_iterations` steps.

    Uses the module-level optimizer `opt`, device `device`, and codebook size
    `num_codes`. `alpha` weights the commitment loss against the L1
    reconstruction loss.

    Fix: the infinite batch generator is created once, before the loop. The
    original called `next(iterate_dataset(train_loader))` inside the loop,
    building a fresh generator (and DataLoader iterator, re-spawning workers)
    every step — so the StopIteration recovery was dead code.
    """
    def iterate_dataset(data_loader):
        # cycle through the dataloader forever
        data_iter = iter(data_loader)
        while True:
            try:
                x, y = next(data_iter)
            except StopIteration:
                # epoch exhausted -- restart from the beginning
                data_iter = iter(data_loader)
                x, y = next(data_iter)
            yield x.to(device), y.to(device)

    batches = iterate_dataset(train_loader)

    for _ in (pbar := trange(train_iterations)):
        opt.zero_grad()
        x, _ = next(batches)
        out, indices, cmt_loss = model(x)
        rec_loss = (out - x).abs().mean()
        (rec_loss + alpha * cmt_loss).backward()

        opt.step()
        pbar.set_description(
            f"rec loss: {rec_loss.item():.3f} | "
            + f"cmt loss: {cmt_loss.item():.3f} | "
            + f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
        )
# data preprocessing: scale pixels to [-1, 1]
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
train_dataset = DataLoader(
    datasets.FashionMNIST(
        root="~/data/fashion_mnist", train=True, download=True, transform=transform
    ),
    batch_size=256,
    shuffle=True,
)
# build the model and optimizer, then train
print("baseline")
torch.random.manual_seed(seed)
model = SimpleVQAutoEncoder(codebook_size=num_codes).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)
.\lucidrains\vector-quantize-pytorch\examples\autoencoder_fsq.py
# 导入所需的库
from tqdm.auto import trange
import math
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from vector_quantize_pytorch import FSQ
# hyperparameters
lr = 3e-4
train_iter = 1000
levels = [8, 6, 5]  # target size 2^8; actual codebook size is 8*6*5 = 240
num_codes = math.prod(levels)  # total number of codes
seed = 1234
# use GPU when available
device = "cuda" if torch.cuda.is_available() else "cpu"
# 定义简单的自动编码器类
class SimpleFSQAutoEncoder(nn.Module):
    """Small conv autoencoder for 28x28 grayscale images with a finite scalar
    quantization (FSQ) bottleneck.

    Forward returns (reconstruction clamped to [-1, 1], code indices).
    """

    def __init__(self, levels: list[int]):
        super().__init__()
        self.layers = nn.ModuleList([
            # encoder: 28x28 -> 14x14 -> 7x7
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # project down to one channel per FSQ level, then quantize
            nn.Conv2d(32, len(levels), kernel_size=1),
            FSQ(levels),
            # decoder: 7x7 -> 14x14 -> 28x28
            nn.Conv2d(len(levels), 32, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
        ])

    def forward(self, x):
        for layer in self.layers:
            if isinstance(layer, FSQ):
                # the FSQ layer yields quantized features plus code indices
                x, indices = layer(x)
            else:
                x = layer(x)
        return x.clamp(-1, 1), indices
# 训练函数
def train(model, train_loader, train_iterations=1000):
    """Train the FSQ autoencoder for `train_iterations` steps, using the
    module-level optimizer `opt`, device `device`, and code count `num_codes`.

    Fix: the infinite batch generator is created once, before the loop. The
    original called `next(iterate_dataset(train_loader))` inside the loop,
    building a fresh generator (and DataLoader iterator, re-spawning workers)
    every step — so the StopIteration recovery was dead code.
    """
    def iterate_dataset(data_loader):
        # cycle through the dataloader forever
        data_iter = iter(data_loader)
        while True:
            try:
                x, y = next(data_iter)
            except StopIteration:
                # epoch exhausted -- restart from the beginning
                data_iter = iter(data_loader)
                x, y = next(data_iter)
            yield x.to(device), y.to(device)

    batches = iterate_dataset(train_loader)

    for _ in (pbar := trange(train_iterations)):
        opt.zero_grad()
        x, _ = next(batches)
        out, indices = model(x)
        rec_loss = (out - x).abs().mean()
        rec_loss.backward()
        opt.step()
        pbar.set_description(
            f"rec loss: {rec_loss.item():.3f} | "
            + f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
        )
# data preprocessing: scale pixels to [-1, 1]
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
train_dataset = DataLoader(
    datasets.FashionMNIST(
        root="~/data/fashion_mnist", train=True, download=True, transform=transform
    ),
    batch_size=256,
    shuffle=True,
)
# build the model and optimizer, then train
print("baseline")
torch.random.manual_seed(seed)
model = SimpleFSQAutoEncoder(levels).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)
.\lucidrains\vector-quantize-pytorch\examples\autoencoder_lfq.py
# 导入所需的库
from tqdm.auto import trange
from math import log2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 导入自定义的 LFQ 模块
from vector_quantize_pytorch import LFQ
# training hyperparameters
lr = 3e-4
train_iter = 1000
seed = 1234
codebook_size = 2 ** 8
entropy_loss_weight = 0.02
diversity_gamma = 1.
# use GPU when available
device = "cuda" if torch.cuda.is_available() else "cpu"
# 定义 LFQAutoEncoder 类,继承自 nn.Module
class LFQAutoEncoder(nn.Module):
    """Conv autoencoder for 28x28 grayscale images with a lookup-free
    quantization (LFQ) bottleneck.

    The bottleneck width is log2(codebook_size) binary latents, so
    `codebook_size` must be a power of two. Forward returns
    (reconstruction clamped to [-1, 1], code indices, entropy auxiliary loss).
    """

    def __init__(
        self,
        codebook_size,
        **vq_kwargs
    ):
        super().__init__()
        assert log2(codebook_size).is_integer()
        quantize_dim = int(log2(codebook_size))

        # encoder: 28x28 -> 14x14 -> 7x7, projected down to quantize_dim channels
        self.encode = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # parameter-free normalization before the quantization projection
            nn.GroupNorm(4, 32, affine=False),
            nn.Conv2d(32, quantize_dim, kernel_size=1),
        )

        self.quantize = LFQ(dim=quantize_dim, **vq_kwargs)

        # decoder: 7x7 -> 14x14 -> 28x28
        self.decode = nn.Sequential(
            nn.Conv2d(quantize_dim, 32, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        z = self.encode(x)
        z, indices, entropy_aux_loss = self.quantize(z)
        recon = self.decode(z)
        return recon.clamp(-1, 1), indices, entropy_aux_loss
# 训练函数
def train(model, train_loader, train_iterations=1000):
    """Train the LFQ autoencoder for `train_iterations` steps, using the
    module-level optimizer `opt`, device `device`, and `codebook_size`.

    Fix: the infinite batch generator is created once, before the loop. The
    original called `next(iterate_dataset(train_loader))` inside the loop,
    building a fresh generator (and DataLoader iterator, re-spawning workers)
    every step — so the StopIteration recovery was dead code.
    """
    def iterate_dataset(data_loader):
        # cycle through the dataloader forever
        data_iter = iter(data_loader)
        while True:
            try:
                x, y = next(data_iter)
            except StopIteration:
                # epoch exhausted -- restart from the beginning
                data_iter = iter(data_loader)
                x, y = next(data_iter)
            yield x.to(device), y.to(device)

    batches = iterate_dataset(train_loader)

    for _ in (pbar := trange(train_iterations)):
        opt.zero_grad()
        x, _ = next(batches)
        out, indices, entropy_aux_loss = model(x)
        rec_loss = F.l1_loss(out, x)
        (rec_loss + entropy_aux_loss).backward()
        opt.step()
        pbar.set_description(
            f"rec loss: {rec_loss.item():.3f} | "
            + f"entropy aux loss: {entropy_aux_loss.item():.3f} | "
            + f"active %: {indices.unique().numel() / codebook_size * 100:.3f}"
        )
# data preprocessing: scale pixels to [-1, 1]
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
# FashionMNIST training data
train_dataset = DataLoader(
    datasets.FashionMNIST(
        root="~/data/fashion_mnist", train=True, download=True, transform=transform
    ),
    batch_size=256,
    shuffle=True,
)
print("baseline")
# seed for reproducibility
torch.random.manual_seed(seed)
# build the LFQ autoencoder
model = LFQAutoEncoder(
    codebook_size = codebook_size,
    entropy_loss_weight = entropy_loss_weight,
    diversity_gamma = diversity_gamma
).to(device)
# optimizer
opt = torch.optim.AdamW(model.parameters(), lr=lr)
# train the model
train(model, train_dataset, train_iterations=train_iter)

Vector Quantization - Pytorch
A vector quantization library originally transcribed from Deepmind's tensorflow implementation, made conveniently into a package. It uses exponential moving averages to update the dictionary.
VQ has been successfully used by Deepmind and OpenAI for high quality generation of images (VQ-VAE-2) and music (Jukebox).
Install
$ pip install vector-quantize-pytorch
Usage
import torch
from vector_quantize_pytorch import VectorQuantize
vq = VectorQuantize(
dim = 256,
codebook_size = 512, # codebook size
decay = 0.8, # the exponential moving average decay, lower means the dictionary will change faster
commitment_weight = 1. # the weight on the commitment loss
)
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x) # (1, 1024, 256), (1, 1024), (1)
Residual VQ
This paper proposes to use multiple vector quantizers to recursively quantize the residuals of the waveform. You can use this with the ResidualVQ class and one extra initialization parameter.
import torch
from vector_quantize_pytorch import ResidualVQ
residual_vq = ResidualVQ(
dim = 256,
num_quantizers = 8, # specify number of quantizers
codebook_size = 1024, # codebook size
)
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)
# (1, 1024, 256), (1, 1024, 8), (1, 8)
# (batch, seq, dim), (batch, seq, quantizer), (batch, quantizer)
# if you need all the codes across the quantization layers, just pass return_all_codes = True
quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True)
# *_, (8, 1, 1024, 256)
# all_codes - (quantizer, batch, seq, dim)
Furthermore, this paper uses Residual-VQ to construct the RQ-VAE, for generating high resolution images with more compressed codes.
They make two modifications. The first is to share the codebook across all quantizers. The second is to stochastically sample the codes rather than always taking the closest match. You can use both of these features with two extra keyword arguments.
import torch
from vector_quantize_pytorch import ResidualVQ
residual_vq = ResidualVQ(
dim = 256,
num_quantizers = 8,
codebook_size = 1024,
stochastic_sample_codes = True,
sample_codebook_temp = 0.1, # temperature for stochastically sampling codes, 0 would be equivalent to non-stochastic
shared_codebook = True # whether to share the codebooks for all quantizers or not
)
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)
# (1, 1024, 256), (8, 1, 1024), (8, 1)
# (batch, seq, dim), (quantizer, batch, seq), (quantizer, batch)
A recent paper further proposes to do residual VQ on groups of the feature dimension, showing equivalent results to Encodec while using far fewer codebooks. You can use it by importing GroupedResidualVQ
import torch
from vector_quantize_pytorch import GroupedResidualVQ
residual_vq = GroupedResidualVQ(
dim = 256,
num_quantizers = 8, # specify number of quantizers
groups = 2,
codebook_size = 1024, # codebook size
)
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)
# (1, 1024, 256), (2, 1, 1024, 8), (2, 1, 8)
# (batch, seq, dim), (groups, batch, seq, quantizer), (groups, batch, quantizer)
Initialization
The SoundStream paper proposes that the codebook should be initialized by the kmeans centroids of the first batch. You can easily turn on this feature with one flag kmeans_init = True, for either VectorQuantize or ResidualVQ class
import torch
from vector_quantize_pytorch import ResidualVQ
residual_vq = ResidualVQ(
dim = 256,
codebook_size = 256,
num_quantizers = 4,
kmeans_init = True, # set to True
kmeans_iters = 10 # number of kmeans iterations to calculate the centroids for the codebook on init
)
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)
Increasing codebook usage
This repository will contain a few techniques from various papers to combat "dead" codebook entries, which is a common problem when using vector quantizers.
Lower codebook dimension
The Improved VQGAN paper proposes to have the codebook kept in a lower dimension. The encoder values are projected down before being projected back to high dimensional after quantization. You can set this with the codebook_dim hyperparameter.
import torch
from vector_quantize_pytorch import VectorQuantize
vq = VectorQuantize(
dim = 256,
codebook_size = 256,
codebook_dim = 16 # paper proposes setting this to 32 or as low as 8 to increase codebook usage
)
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
Cosine similarity
The Improved VQGAN paper also proposes to l2 normalize the codes and the encoded vectors, which boils down to using cosine similarity for the distance. They claim enforcing the vectors on a sphere leads to improvements in code usage and downstream reconstruction. You can turn this on by setting use_cosine_sim = True
import torch
from vector_quantize_pytorch import VectorQuantize
vq = VectorQuantize(
dim = 256,
codebook_size = 256,
use_cosine_sim = True # set this to True
)
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
Expiring stale codes
Finally, the SoundStream paper has a scheme where they replace codes that have hits below a certain threshold with randomly selected vector from the current batch. You can set this threshold with threshold_ema_dead_code keyword.
import torch
from vector_quantize_pytorch import VectorQuantize
vq = VectorQuantize(
dim = 256,
codebook_size = 512,
threshold_ema_dead_code = 2 # should actively replace any codes that have an exponential moving average cluster size less than 2
)
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
Orthogonal regularization loss
VQ-VAE / VQ-GAN is quickly gaining popularity. A recent paper proposes that when using vector quantization on images, enforcing the codebook to be orthogonal leads to translation equivariance of the discretized codes, leading to large improvements in downstream text to image generation tasks.
You can use this feature by simply setting the orthogonal_reg_weight to be greater than 0, in which case the orthogonal regularization will be added to the auxiliary loss outputted by the module.
import torch
from vector_quantize_pytorch import VectorQuantize
vq = VectorQuantize(
dim = 256,
codebook_size = 256,
accept_image_fmap = True, # set this true to be able to pass in an image feature map
orthogonal_reg_weight = 10, # in paper, they recommended a value of 10
orthogonal_reg_max_codes = 128, # this would randomly sample from the codebook for the orthogonal regularization loss, for limiting memory usage
orthogonal_reg_active_codes_only = False # set this to True if you have a very large codebook, and would only like to enforce the loss on the activated codes per batch
)
img_fmap = torch.randn(1, 256, 32, 32)
quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32), (1,)
# loss now contains the orthogonal regularization loss with the weight as assigned
Multi-headed VQ
There has been a number of papers that proposes variants of discrete latent representations with a multi-headed approach (multiple codes per feature). I have decided to offer one variant where the same codebook is used to vector quantize across the input dimension head times.
You can also use a more proven approach (memcodes) from NWT paper
import torch
from vector_quantize_pytorch import VectorQuantize
vq = VectorQuantize(
dim = 256,
codebook_dim = 32, # a number of papers have shown smaller codebook dimension to be acceptable
heads = 8, # number of heads to vector quantize, codebook shared across all heads
separate_codebook_per_head = True, # whether to have a separate codebook per head. False would mean 1 shared codebook
codebook_size = 8196,
accept_image_fmap = True
)
img_fmap = torch.randn(1, 256, 32, 32)
quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32, 8), (1,)
# indices shape - (batch, height, width, heads)
Random Projection Quantizer
This paper first proposed to use a random projection quantizer for masked speech modeling, where signals are projected with a randomly initialized matrix and then matched with a random initialized codebook. One therefore does not need to learn the quantizer. This technique was used by Google's Universal Speech Model to achieve SOTA for speech-to-text modeling.
USM further proposes to use multiple codebook, and the masked speech modeling with a multi-softmax objective. You can do this easily by setting num_codebooks to be greater than 1
import torch
from vector_quantize_pytorch import RandomProjectionQuantizer
quantizer = RandomProjectionQuantizer(
dim = 512, # input dimensions
num_codebooks = 16, # in USM, they used up to 16 for 5% gain
codebook_dim = 256, # codebook dimension
codebook_size = 1024 # codebook size
)
x = torch.randn(1, 1024, 512)
indices = quantizer(x) # (1, 1024, 16) - (batch, seq, num_codebooks)
This repository should also automatically synchronize the codebooks in a multi-process setting. If somehow it isn't, please open an issue. You can override whether to synchronize codebooks or not by setting sync_codebook = True | False
Finite Scalar Quantization

| VQ | FSQ | |
|---|---|---|
| Quantization | argmin_c || z-c || | round(f(z)) |
| Gradients | Straight Through Estimation (STE) | STE |
| Auxiliary Losses | Commitment, codebook, entropy loss, ... | N/A |
| Tricks | EMA on codebook, codebook splitting, projections, ... | N/A |
| Parameters | Codebook | N/A |
This work out of Google Deepmind aims to vastly simplify the way vector quantization is done for generative modeling, removing the need for commitment losses, EMA updating of the codebook, as well as tackle the issues with codebook collapse or insufficient utilization. They simply round each scalar into discrete levels with straight through gradients; the codes become uniform points in a hypercube.
Thanks goes out to @sekstini for porting over this implementation in record time!
import torch
from vector_quantize_pytorch import FSQ
levels = [8,5,5,5] # see 4.1 and A.4.1 in the paper
quantizer = FSQ(levels)
x = torch.randn(1, 1024, 4) # 4 since there are 4 levels
xhat, indices = quantizer(x)
print(xhat.shape) # (1, 1024, 4) - (batch, seq, dim)
print(indices.shape) # (1, 1024) - (batch, seq)
assert xhat.shape == x.shape
assert torch.all(xhat == quantizer.indices_to_codes(indices))
An improvised Residual FSQ, for an attempt to improve audio encoding.
Credit goes to @sekstini for originally incepting the idea here
import torch
from vector_quantize_pytorch import ResidualFSQ
residual_fsq = ResidualFSQ(
dim = 256,
levels = [8, 5, 5, 3],
num_quantizers = 8
)
x = torch.randn(1, 1024, 256)
residual_fsq.eval()
quantized, indices = residual_fsq(x)
# (1, 1024, 256), (1, 1024, 8)
# (batch, seq, dim), (batch, seq, quantizers)
quantized_out = residual_fsq.get_output_from_indices(indices)
# (1, 1024, 256)
# (batch, seq, dim) - reconstructed from the indices, matches `quantized`
assert torch.all(quantized == quantized_out)
Lookup Free Quantization

The research team behind MagViT has released new SOTA results for generative video modeling. A core change between v1 and v2 include a new type of quantization, look-up free quantization (LFQ), which eliminates the codebook and embedding lookup entirely.
This paper presents a simple LFQ quantizer of using independent binary latents. Other implementations of LFQ exist. However, the team shows that MAGVIT-v2 with LFQ significantly improves on the ImageNet benchmark. The differences between LFQ and 2-level FSQ includes entropy regularizations as well as maintained commitment loss.
Developing a more advanced method of LFQ quantization without codebook-lookup could revolutionize generative modeling.
You can use it simply as follows. Will be dogfooded at MagViT2 pytorch port
import torch
from vector_quantize_pytorch import LFQ
# you can specify either dim or codebook_size
# if both specified, will be validated against each other
quantizer = LFQ(
codebook_size = 65536, # codebook size, must be a power of 2
dim = 16, # this is the input feature dimension, defaults to log2(codebook_size) if not defined
entropy_loss_weight = 0.1, # how much weight to place on entropy loss
diversity_gamma = 1. # within entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894
)
image_feats = torch.randn(1, 16, 32, 32)
quantized, indices, entropy_aux_loss = quantizer(image_feats, inv_temperature=100.) # you may want to experiment with temperature
# (1, 16, 32, 32), (1, 32, 32), (1,)
assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()
You can also pass in video features as (batch, feat, time, height, width) or sequences as (batch, seq, feat)
seq = torch.randn(1, 32, 16)
quantized, *_ = quantizer(seq)
assert seq.shape == quantized.shape
video_feats = torch.randn(1, 16, 10, 32, 32)
quantized, *_ = quantizer(video_feats)
assert video_feats.shape == quantized.shape
Or support multiple codebooks
import torch
from vector_quantize_pytorch import LFQ
quantizer = LFQ(
codebook_size = 4096,
dim = 16,
num_codebooks = 4 # 4 codebooks, total codebook dimension is log2(4096) * 4
)
image_feats = torch.randn(1, 16, 32, 32)
quantized, indices, entropy_aux_loss = quantizer(image_feats)
# (1, 16, 32, 32), (1, 32, 32, 4), (1,)
assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()
An improvised Residual LFQ, to see if it can lead to an improvement for audio compression.
import torch
from vector_quantize_pytorch import ResidualLFQ
residual_lfq = ResidualLFQ(
dim = 256,
codebook_size = 256,
num_quantizers = 8
)
x = torch.randn(1, 1024, 256)
residual_lfq.eval()
quantized, indices, commit_loss = residual_lfq(x)
# (1, 1024, 256), (1, 1024, 8), (8)
# (batch, seq, dim), (batch, seq, quantizers), (quantizers)
quantized_out = residual_lfq.get_output_from_indices(indices)
# (1, 1024, 256)
# (batch, seq, dim) - reconstructed from the indices, matches `quantized`
assert torch.all(quantized == quantized_out)
Latent Quantization
Disentanglement is essential for representation learning as it promotes interpretability, generalization, improved learning, and robustness. It aligns with the goal of capturing meaningful and independent features of the data, facilitating more effective use of learned representations across various applications. For better disentanglement, the challenge is to disentangle underlying variations in a dataset without explicit ground truth information. This work introduces a key inductive bias aimed at encoding and decoding within an organized latent space. The strategy incorporated encompasses discretizing the latent space by assigning discrete code vectors through the utilization of an individual learnable scalar codebook for each dimension. This methodology enables their models to surpass robust prior methods effectively.
Be aware they had to use a very high weight decay for the results in this paper.
import torch
from vector_quantize_pytorch import LatentQuantize
# you can specify either dim or codebook_size
# if both specified, will be validated against each other
quantizer = LatentQuantize(
levels = [5, 5, 8], # number of levels per codebook dimension
dim = 16, # input dim
commitment_loss_weight=0.1,
quantization_loss_weight=0.1,
)
image_feats = torch.randn(1, 16, 32, 32)
quantized, indices, loss = quantizer(image_feats)
# (1, 16, 32, 32), (1, 32, 32), (1,)
assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()
You can also pass in video features as (batch, feat, time, height, width) or sequences as (batch, seq, feat)
seq = torch.randn(1, 32, 16)
quantized, *_ = quantizer(seq)
assert seq.shape == quantized.shape
video_feats = torch.randn(1, 16, 10, 32, 32)
quantized, *_ = quantizer(video_feats)
assert video_feats.shape == quantized.shape
Or support multiple codebooks
import torch
from vector_quantize_pytorch import LatentQuantize
levels = [4, 8, 16]
dim = 9
num_codebooks = 3
model = LatentQuantize(levels, dim, num_codebooks=num_codebooks)
input_tensor = torch.randn(2, 3, dim)
output_tensor, indices, loss = model(input_tensor)
assert output_tensor.shape == input_tensor.shape
assert indices.shape == (2, 3, num_codebooks)
assert loss.item() >= 0
Citations
@misc{oord2018neural,
title = {Neural Discrete Representation Learning},
author = {Aaron van den Oord and Oriol Vinyals and Koray Kavukcuoglu},
year = {2018},
eprint = {1711.00937},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@misc{zeghidour2021soundstream,
title = {SoundStream: An End-to-End Neural Audio Codec},
author = {Neil Zeghidour and Alejandro Luebs and Ahmed Omran and Jan Skoglund and Marco Tagliasacchi},
year = {2021},
eprint = {2107.03312},
archivePrefix = {arXiv},
primaryClass = {cs.SD}
}
@inproceedings{anonymous2022vectorquantized,
title = {Vector-quantized Image Modeling with Improved {VQGAN}},
author = {Anonymous},
booktitle = {Submitted to The Tenth International Conference on Learning Representations},
year = {2022},
url = {https://openreview.net/forum?id=pfNyExj7z2},
note = {under review}
}
@misc{lee2022autoregressive,
author = {Lee, Doyup and Kim, Chiheon and Kim, Saehoon and Cho, Minsu and Han, Wook-Shin},
year = {2022},
month = {03},
title = {Autoregressive Image Generation using Residual Quantization}
}
@article{Defossez2022HighFN,
title = {High Fidelity Neural Audio Compression},
author = {Alexandre Défossez and Jade Copet and Gabriel Synnaeve and Yossi Adi},
journal = {ArXiv},
year = {2022},
volume = {abs/2210.13438}
}
@inproceedings{Chiu2022SelfsupervisedLW,
title = {Self-supervised Learning with Random-projection Quantizer for Speech Recognition},
author = {Chung-Cheng Chiu and James Qin and Yu Zhang and Jiahui Yu and Yonghui Wu},
booktitle = {International Conference on Machine Learning},
year = {2022}
}
@inproceedings{Zhang2023GoogleUS,
title = {Google USM: Scaling Automatic Speech Recognition Beyond 100 Languages},
author = {Yu Zhang and Wei Han and James Qin and Yongqiang Wang and Ankur Bapna and Zhehuai Chen and Nanxin Chen and Bo Li and Vera Axelrod and Gary Wang and Zhong Meng and Ke Hu and Andrew Rosenberg and Rohit Prabhavalkar and Daniel S. Park and Parisa Haghani and Jason Riesa and Ginger Perng and Hagen Soltau and Trevor Strohman and Bhuvana Ramabhadran and Tara N. Sainath and Pedro J. Moreno and Chung-Cheng Chiu and Johan Schalkwyk and Françoise Beaufays and Yonghui Wu},
year = {2023}
}
@inproceedings{Shen2023NaturalSpeech2L,
title = {NaturalSpeech 2: Latent Diffusion Models are Natural and Zero-Shot Speech and Singing Synthesizers},
author = {Kai Shen and Zeqian Ju and Xu Tan and Yanqing Liu and Yichong Leng and Lei He and Tao Qin and Sheng Zhao and Jiang Bian},
year = {2023}
}
@inproceedings{Yang2023HiFiCodecGV,
title = {HiFi-Codec: Group-residual Vector quantization for High Fidelity Audio Codec},
author = {Dongchao Yang and Songxiang Liu and Rongjie Huang and Jinchuan Tian and Chao Weng and Yuexian Zou},
year = {2023}
}
@article{Liu2023BridgingDA,
title = {Bridging Discrete and Backpropagation: Straight-Through and Beyond},
author = {Liyuan Liu and Chengyu Dong and Xiaodong Liu and Bin Yu and Jianfeng Gao},
journal = {ArXiv},
year = {2023},
volume = {abs/2304.08612}
}
@inproceedings{huh2023improvedvqste,
title = {Straightening Out the Straight-Through Estimator: Overcoming Optimization Challenges in Vector Quantized Networks},
author = {Huh, Minyoung and Cheung, Brian and Agrawal, Pulkit and Isola, Phillip},
booktitle = {International Conference on Machine Learning},
year = {2023},
organization = {PMLR}
}
@inproceedings{rogozhnikov2022einops,
title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
author = {Alex Rogozhnikov},
booktitle = {International Conference on Learning Representations},
year = {2022},
url = {https://openreview.net/forum?id=oapKSVM2bcj}
}
@misc{shin2021translationequivariant,
title = {Translation-equivariant Image Quantizer for Bi-directional Image-Text Generation},
author = {Woncheol Shin and Gyubok Lee and Jiyoung Lee and Joonseok Lee and Edward Choi},
year = {2021},
eprint = {2112.00384},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@misc{mentzer2023finite,
title = {Finite Scalar Quantization: VQ-VAE Made Simple},
author = {Fabian Mentzer and David Minnen and Eirikur Agustsson and Michael Tschannen},
year = {2023},
eprint = {2309.15505},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@misc{yu2023language,
title = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
author = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
year = {2023},
eprint = {2310.05737},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@misc{hsu2023disentanglement,
title = {Disentanglement via Latent Quantization},
author = {Kyle Hsu and Will Dorrell and James C. R. Whittington and Jiajun Wu and Chelsea Finn},
year = {2023},
eprint = {2305.18378},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}