Lucidrains Project Series Source Code Analysis (Part 104)
.\lucidrains\vector-quantize-pytorch\setup.py
from setuptools import setup, find_packages
setup(
name = 'vector_quantize_pytorch',
packages = find_packages(),
version = '1.14.5',
license='MIT',
description = 'Vector Quantization - Pytorch',
long_description_content_type = 'text/markdown',
author = 'Phil Wang',
author_email = 'lucidrains@gmail.com',
url = 'https://github.com/lucidrains/vector-quantizer-pytorch',
keywords = [
'artificial intelligence',
'deep learning',
'pytorch',
'quantization'
],
install_requires=[
'einops>=0.7.0',
'einx[torch]>=0.1.3',
'torch'
],
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
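The metadata above means the library installs from PyPI as vector-quantize-pytorch and depends only on torch, einops, and einx[torch]. A minimal sketch of installing it and importing quantizers via the module paths shown in this article (the exact set of top-level re-exports is not part of this excerpt):

# pip install vector-quantize-pytorch
from vector_quantize_pytorch.finite_scalar_quantization import FSQ
from vector_quantize_pytorch.residual_vq import ResidualVQ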
.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\finite_scalar_quantization.py
"""
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
Code adapted from Jax version in Appendix A.1
"""
from typing import List, Tuple, Optional
import torch
import torch.nn as nn
from torch.nn import Module
from torch import Tensor, int32
from torch.cuda.amp import autocast
from einops import rearrange, pack, unpack
def exists(v):
return v is not None
def default(*args):
for arg in args:
if exists(arg):
return arg
return None
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
def round_ste(z: Tensor) -> Tensor:
"""Round with straight through gradients."""
zhat = z.round()
return z + (zhat - z).detach()
class FSQ(Module):
def __init__(
self,
levels: List[int],
dim: Optional[int] = None,
num_codebooks = 1,
keep_num_codebooks_dim: Optional[bool] = None,
scale: Optional[float] = None,
allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64)
):
super().__init__()
_levels = torch.tensor(levels, dtype=int32)
self.register_buffer("_levels", _levels, persistent = False)
_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
self.register_buffer("_basis", _basis, persistent = False)
self.scale = scale
codebook_dim = len(levels)
self.codebook_dim = codebook_dim
effective_codebook_dim = codebook_dim * num_codebooks
self.num_codebooks = num_codebooks
self.effective_codebook_dim = effective_codebook_dim
keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
self.keep_num_codebooks_dim = keep_num_codebooks_dim
self.dim = default(dim, len(_levels) * num_codebooks)
has_projections = self.dim != effective_codebook_dim
self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
self.has_projections = has_projections
self.codebook_size = self._levels.prod().item()
implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out = False)
self.register_buffer("implicit_codebook", implicit_codebook, persistent = False)
self.allowed_dtypes = allowed_dtypes
def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
"""Bound `z`, an array of shape (..., d)."""
half_l = (self._levels - 1) * (1 + eps) / 2
offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
shift = (offset / half_l).atanh()
return (z + shift).tanh() * half_l - offset
def quantize(self, z: Tensor) -> Tensor:
"""Quantizes z, returns quantized zhat, same shape as z."""
quantized = round_ste(self.bound(z))
half_width = self._levels // 2
return quantized / half_width
def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
half_width = self._levels // 2
return (zhat_normalized * half_width) + half_width
def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
half_width = self._levels // 2
return (zhat - half_width) / half_width
def codes_to_indices(self, zhat: Tensor) -> Tensor:
"""Converts a `code` to an index in the codebook."""
assert zhat.shape[-1] == self.codebook_dim
zhat = self._scale_and_shift(zhat)
return (zhat * self._basis).sum(dim=-1).to(int32)
def indices_to_codes(
self,
indices: Tensor,
project_out = True
) -> Tensor:
"""Inverse of `codes_to_indices`."""
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
indices = rearrange(indices, '... -> ... 1')
codes_non_centered = (indices // self._basis) % self._levels
codes = self._scale_and_shift_inverse(codes_non_centered)
if self.keep_num_codebooks_dim:
codes = rearrange(codes, '... c d -> ... (c d)')
if project_out:
codes = self.project_out(codes)
if is_img_or_video:
codes = rearrange(codes, 'b ... d -> b d ...')
return codes
@autocast(enabled = False)
def forward(self, z: Tensor) -> Tensor:
"""
einstein notation
b - batch
n - sequence (or flattened spatial dimensions)
d - feature dimension
c - number of codebook dim
"""
orig_dtype = z.dtype
is_img_or_video = z.ndim >= 4
if z.dtype not in self.allowed_dtypes:
z = z.float()
if is_img_or_video:
z = rearrange(z, 'b d ... -> b ... d')
z, ps = pack_one(z, 'b * d')
assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'
z = self.project_in(z)
z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)
codes = self.quantize(z)
indices = self.codes_to_indices(codes)
codes = rearrange(codes, 'b n c d -> b n (c d)')
out = self.project_out(codes)
if is_img_or_video:
out = unpack_one(out, ps, 'b * d')
out = rearrange(out, 'b ... d -> b d ...')
indices = unpack_one(indices, ps, 'b * c')
if not self.keep_num_codebooks_dim:
indices = rearrange(indices, '... 1 -> ...')
if out.dtype != orig_dtype:
out = out.type(orig_dtype)
return out, indices
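A quick usage sketch of FSQ based on the forward pass above. The levels follow the values suggested in the FSQ paper; everything else comes straight from the code in this file.

import torch
from vector_quantize_pytorch.finite_scalar_quantization import FSQ

levels = [8, 5, 5, 5]             # one entry per latent dimension
quantizer = FSQ(levels)

x = torch.randn(1, 1024, 4)       # feature dim defaults to len(levels) * num_codebooks = 4
xhat, indices = quantizer(x)      # xhat: (1, 1024, 4), indices: (1, 1024)

# quantization snaps to a fixed grid, so codes and indices round-trip exactly
assert torch.all(xhat == quantizer.indices_to_codes(indices))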
.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\latent_quantization.py
"""
Disentanglement via Latent Quantization
- https://arxiv.org/abs/2305.18378
Code adapted from Jax version in https://github.com/kylehkhsu/latent_quantization
"""
from typing import List, Optional, Union, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module
from torch import Tensor, int32
from torch.optim import Optimizer
from einops import rearrange, pack, unpack
def exists(v):
return v is not None
def default(*args):
for arg in args:
if exists(arg):
return arg
return None
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
class LatentQuantize(Module):
def quantization_loss(self, z: Tensor, zhat: Tensor, reduce="mean") -> Tensor:
"""Computes the quantization loss."""
return F.mse_loss(zhat.detach(), z, reduction=reduce)
def commitment_loss(self, z: Tensor, zhat: Tensor, reduce="mean") -> Tensor:
"""Computes the commitment loss."""
return F.mse_loss(z.detach(), zhat, reduction=reduce)
def quantize(self, z: Tensor) -> Tensor:
"""Quantizes z, returns quantized zhat, same shape as z.
The quantization is done by measuring the distance between the input and the codebook values per latent dimension
and returning the index of the closest codebook value.
"""
def distance(x, y):
return torch.abs(x - y)
if self._equal_levels:
index = torch.argmin(distance(z[..., None], self.values_per_latent), dim=-1)
quantize = self.values_per_latent[torch.arange(self.dim), index]
else:
index = torch.stack([torch.argmin(distance(z[..., i, None], self.values_per_latent[i]), dim=-1) for i in range(self.codebook_dim)], dim=-1)
quantize = torch.stack([self.values_per_latent[i][index[..., i]] for i in range(self.codebook_dim)], dim=-1)
quantize = z + (quantize - z).detach()
return quantize
def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
""" scale and shift zhat from [-0.5, 0.5] to [0, level_per_dim]"""
half_width = self._levels // 2
return (zhat_normalized * 2 * half_width) + half_width
def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
"""normalize zhat to [-0.5, 0.5]"""
half_width = self._levels // 2
return (zhat - half_width) / half_width / 2
def codes_to_indices(self, zhat: Tensor) -> Tensor:
"""Converts a `code` which contains the number per latent to an index in the codebook."""
assert zhat.shape[-1] == self.codebook_dim
zhat = self._scale_and_shift(zhat)
return (zhat * self._basis).sum(dim=-1).to(int32)
def indices_to_codes(
self,
indices: Tensor,
project_out = True
) -> Tensor:
"""Inverse of `codes_to_indices`."""
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
indices = rearrange(indices, '... -> ... 1')
codes_non_centered = (indices // self._basis) % self._levels
codes = self._scale_and_shift_inverse(codes_non_centered)
if self.keep_num_codebooks_dim:
codes = rearrange(codes, '... c d -> ... (c d)')
if project_out:
codes = self.project_out(codes)
if is_img_or_video:
codes = rearrange(codes, 'b ... d -> b d ...')
return codes
def quantize_and_project(self, z: Tensor, is_img_or_video, ps) -> Tensor:
codes = self.quantize(z)
indices = self.codes_to_indices(codes)
codes = rearrange(codes, 'b n c d -> b n (c d)')
out = self.project_out(codes)
if is_img_or_video:
out = unpack_one(out, ps, 'b * d')
out = rearrange(out, 'b ... d -> b d ...')
indices = unpack_one(indices, ps, 'b * c')
if not self.keep_num_codebooks_dim:
indices = rearrange(indices, '... 1 -> ...')
return codes, out, indices
def forward(self,
z: Tensor) -> Tensor:
"""
einstein notation
b - batch
n - sequence (or flattened spatial dimensions)
d - feature dimension
c - number of codebook dim
"""
is_img_or_video = z.ndim >= 4
original_input = z
should_inplace_optimize = exists(self.in_place_codebook_optimizer)
if is_img_or_video:
z = rearrange(z, 'b d ... -> b ... d')
z, ps = pack_one(z, 'b * d')
assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'
z = self.project_in(z)
z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)
codes = self.quantize(z)
indices = self.codes_to_indices(codes)
codes = rearrange(codes, 'b n c d -> b n (c d)')
out = self.project_out(codes)
if is_img_or_video:
out = unpack_one(out, ps, 'b * d')
out = rearrange(out, 'b ... d -> b d ...')
indices = unpack_one(indices, ps, 'b * c')
if not self.keep_num_codebooks_dim:
indices = rearrange(indices, '... 1 -> ...')
if should_inplace_optimize and self.training and not self.optimize_values:
loss = self.commitment_loss(z, out) if self.commitment_loss_weight!=0 else torch.tensor(0.)
loss+= self.quantization_loss(z, out) if self.quantization_loss_weight!=0 else torch.tensor(0.)
loss.backward()
self.in_place_codebook_optimizer.step()
self.in_place_codebook_optimizer.zero_grad()
codes = self.quantize(z)
indices = self.codes_to_indices(codes)
codes = rearrange(codes, 'b n c d -> b n (c d)')
out = self.project_out(codes)
if is_img_or_video:
out = unpack_one(out, ps, 'b * d')
out = rearrange(out, 'b ... d -> b d ...')
indices = unpack_one(indices, ps, 'b * c')
if not self.keep_num_codebooks_dim:
indices = rearrange(indices, '... 1 -> ...')
commitment_loss = self.commitment_loss(original_input, out) if self.training and self.commitment_loss_weight!=0 else torch.tensor(0.)
quantization_loss = self.quantization_loss(original_input, out) if self.training and self.quantization_loss_weight!=0 else torch.tensor(0.)
loss = self.commitment_loss_weight * commitment_loss + self.quantization_loss_weight * quantization_loss
return out, indices, loss
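The excerpt above omits the LatentQuantize constructor, so the following sketch takes its argument names (levels, dim and the two loss weights) from the repository README; treat them as assumptions rather than something shown here.

import torch
from vector_quantize_pytorch.latent_quantization import LatentQuantize

quantizer = LatentQuantize(
    levels = [5, 5, 8],                 # number of discrete values per latent dimension (assumed kwargs)
    dim = 16,
    commitment_loss_weight = 0.1,
    quantization_loss_weight = 0.1,
)

image_feats = torch.randn(1, 16, 32, 32)           # channel-first feature map
quantized, indices, loss = quantizer(image_feats)  # return signature matches the forward pass above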
.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\lookup_free_quantization.py
"""
Lookup Free Quantization
Proposed in https://arxiv.org/abs/2310.05737
In the simplest setup, each dimension is quantized into {-1, 1}.
An entropy penalty is used to encourage utilization.
"""
from math import log2, ceil
from collections import namedtuple
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module
from torch.cuda.amp import autocast
from einops import rearrange, reduce, pack, unpack
Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])
LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])
def exists(v):
return v is not None
def default(*args):
for arg in args:
if exists(arg):
return arg() if callable(arg) else arg
return None
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
def log(t, eps = 1e-5):
return t.clamp(min = eps).log()
def entropy(prob):
return (-prob * log(prob)).sum(dim=-1)
class LFQ(Module):
def __init__(
self,
*,
dim = None,
codebook_size = None,
entropy_loss_weight = 0.1,
commitment_loss_weight = 0.25,
diversity_gamma = 1.,
straight_through_activation = nn.Identity(),
num_codebooks = 1,
keep_num_codebooks_dim = None,
codebook_scale = 1.,
frac_per_sample_entropy = 1.
):
super().__init__()
assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'
codebook_size = default(codebook_size, lambda: 2 ** dim)
codebook_dim = int(log2(codebook_size))
codebook_dims = codebook_dim * num_codebooks
dim = default(dim, codebook_dims)
has_projections = dim != codebook_dims
self.project_in = nn.Linear(dim, codebook_dims) if has_projections else nn.Identity()
self.project_out = nn.Linear(codebook_dims, dim) if has_projections else nn.Identity()
self.has_projections = has_projections
self.dim = dim
self.codebook_dim = codebook_dim
self.num_codebooks = num_codebooks
keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
self.keep_num_codebooks_dim = keep_num_codebooks_dim
self.activation = straight_through_activation
assert 0 < frac_per_sample_entropy <= 1.
self.frac_per_sample_entropy = frac_per_sample_entropy
self.diversity_gamma = diversity_gamma
self.entropy_loss_weight = entropy_loss_weight
self.codebook_scale = codebook_scale
self.commitment_loss_weight = commitment_loss_weight
self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
self.register_buffer('zero', torch.tensor(0.), persistent = False)
all_codes = torch.arange(codebook_size)
bits = ((all_codes[..., None].int() & self.mask) != 0).float()
codebook = self.bits_to_codes(bits)
self.register_buffer('codebook', codebook, persistent = False)
def bits_to_codes(self, bits):
return bits * self.codebook_scale * 2 - self.codebook_scale
@property
def dtype(self):
return self.codebook.dtype
def indices_to_codes(
self,
indices,
project_out = True
):
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
if not self.keep_num_codebooks_dim:
indices = rearrange(indices, '... -> ... 1')
bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)
codes = self.bits_to_codes(bits)
codes = rearrange(codes, '... c d -> ... (c d)')
if project_out:
codes = self.project_out(codes)
if is_img_or_video:
codes = rearrange(codes, 'b ... d -> b d ...')
return codes
@autocast(enabled = False)
def forward(
self,
x,
inv_temperature = 100.,
return_loss_breakdown = False,
mask = None,
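The excerpt breaks off within the LFQ forward signature, but the Return namedtuple at the top of the file fixes the output as (quantized, indices, entropy_aux_loss). A hedged usage sketch, with hyperparameters taken from the repository README:

import torch
from vector_quantize_pytorch.lookup_free_quantization import LFQ

quantizer = LFQ(
    codebook_size = 65536,     # must be a power of 2; codebook_dim = log2(codebook_size) = 16
    dim = 16,                  # equals codebook_dim here, so no projection layers are created
    entropy_loss_weight = 0.1,
    diversity_gamma = 1.
)

image_feats = torch.randn(1, 16, 32, 32)
quantized, indices, entropy_aux_loss = quantizer(image_feats, inv_temperature = 100.)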
.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\random_projection_quantizer.py
import torch
from torch import nn, einsum
import torch.nn.functional as F
from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
from einops import rearrange, repeat, pack, unpack
def exists(val):
return val is not None
class RandomProjectionQuantizer(nn.Module):
""" https://arxiv.org/abs/2202.01855 """
def __init__(
self,
*,
dim,
codebook_size,
codebook_dim,
num_codebooks = 1,
norm = True,
**kwargs
):
super().__init__()
self.num_codebooks = num_codebooks
rand_projs = torch.empty(num_codebooks, dim, codebook_dim)
nn.init.xavier_normal_(rand_projs)
self.register_buffer('rand_projs', rand_projs)
self.norm = nn.LayerNorm(dim, elementwise_affine = False) if norm else nn.Identity()
self.vq = VectorQuantize(
dim = codebook_dim * num_codebooks,
heads = num_codebooks,
codebook_size = codebook_size,
use_cosine_sim = True,
separate_codebook_per_head = True,
**kwargs
)
def forward(
self,
x,
indices = None
):
return_loss = exists(indices)
x = self.norm(x)
x = einsum('b n d, h d e -> b n h e', x, self.rand_projs)
x, ps = pack([x], 'b n *')
self.vq.eval()
out = self.vq(x, indices = indices)
if return_loss:
_, ce_loss = out
return ce_loss
_, indices, _ = out
return indices
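Usage sketch for the random projection quantizer above: inputs are projected through frozen random matrices and quantized by an internal cosine-similarity VectorQuantize kept in eval mode, so without target indices the module only returns code indices. Hyperparameters follow the repository README.

import torch
from vector_quantize_pytorch.random_projection_quantizer import RandomProjectionQuantizer

quantizer = RandomProjectionQuantizer(
    dim = 512,               # input feature dimension
    num_codebooks = 16,      # one random projection (and one codebook head) per codebook
    codebook_dim = 256,
    codebook_size = 1024
)

x = torch.randn(1, 1024, 512)
indices = quantizer(x)       # (1, 1024, 16), one index per codebook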
.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\residual_fsq.py
import random
from math import ceil, log2
from functools import partial
from typing import List
import torch
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from torch.cuda.amp import autocast
from vector_quantize_pytorch.finite_scalar_quantization import FSQ
from einops import rearrange, repeat, reduce, pack, unpack
from einx import get_at
def exists(val):
return val is not None
def first(l):
return l[0]
def default(val, d):
return val if exists(val) else d
def round_up_multiple(num, mult):
return ceil(num / mult) * mult
class ResidualFSQ(Module):
""" Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """
def __init__(
self,
*,
dim,
levels: List[int],
num_quantizers,
quantize_dropout = False,
quantize_dropout_cutoff_index = 0,
quantize_dropout_multiple_of = 1,
**kwargs
):
super().__init__()
codebook_dim = len(levels)
requires_projection = codebook_dim != dim
self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
self.has_projections = requires_projection
self.num_quantizers = num_quantizers
self.levels = levels
self.layers = nn.ModuleList([])
levels_tensor = torch.Tensor(levels)
scales = []
for ind in range(num_quantizers):
scales.append((levels_tensor - 1) ** -ind)
fsq = FSQ(
levels = levels,
dim = codebook_dim,
**kwargs
)
self.layers.append(fsq)
assert all([not fsq.has_projections for fsq in self.layers])
self.codebook_size = self.layers[0].codebook_size
self.register_buffer('scales', torch.stack(scales), persistent = False)
self.quantize_dropout = quantize_dropout and num_quantizers > 1
assert quantize_dropout_cutoff_index >= 0
self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
self.quantize_dropout_multiple_of = quantize_dropout_multiple_of
@property
def codebooks(self):
codebooks = [layer.implicit_codebook for layer in self.layers]
codebooks = torch.stack(codebooks, dim = 0)
return codebooks
def get_codes_from_indices(self, indices):
batch, quantize_dim = indices.shape[0], indices.shape[-1]
indices, ps = pack([indices], 'b * q')
if quantize_dim < self.num_quantizers:
assert self.quantize_dropout > 0., 'quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations'
indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1)
mask = indices == -1
indices = indices.masked_fill(mask, 0)
all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)
all_codes = all_codes.masked_fill(rearrange(mask, 'b n q -> q b n 1'), 0.)
scales = rearrange(self.scales, 'q d -> q 1 1 d')
all_codes = all_codes * scales
all_codes, = unpack(all_codes, ps, 'q b * d')
return all_codes
def get_output_from_indices(self, indices):
codes = self.get_codes_from_indices(indices)
codes_summed = reduce(codes, 'q ... -> ...', 'sum')
return self.project_out(codes_summed)
def forward(
self,
x,
return_all_codes = False,
rand_quantize_dropout_fixed_seed = None
):
num_quant, quant_dropout_multiple_of, device = self.num_quantizers, self.quantize_dropout_multiple_of, x.device
x = self.project_in(x)
quantized_out = 0.
residual = first(self.layers).bound(x)
all_indices = []
should_quantize_dropout = self.training and self.quantize_dropout
if should_quantize_dropout:
rand = random.Random(rand_quantize_dropout_fixed_seed) if exists(rand_quantize_dropout_fixed_seed) else random
rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)
if quant_dropout_multiple_of != 1:
rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1
null_indices = torch.full(x.shape[:2], -1., device = device, dtype = torch.long)
with autocast(enabled = False):
for quantizer_index, (layer, scale) in enumerate(zip(self.layers, self.scales)):
if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
all_indices.append(null_indices)
continue
quantized, indices = layer(residual / scale)
quantized = quantized * scale
residual = residual - quantized.detach()
quantized_out = quantized_out + quantized
all_indices.append(indices)
quantized_out = self.project_out(quantized_out)
all_indices = torch.stack(all_indices, dim = -1)
ret = (quantized_out, all_indices)
if not return_all_codes:
return ret
all_codes = self.get_codes_from_indices(all_indices)
return (*ret, all_codes)
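A usage sketch for ResidualFSQ following the residual quantization scheme above: each FSQ layer quantizes the residual left by the previous one at a progressively finer scale. Hyperparameters follow the repository README.

import torch
from vector_quantize_pytorch.residual_fsq import ResidualFSQ

residual_fsq = ResidualFSQ(
    dim = 256,
    levels = [8, 5, 5, 3],
    num_quantizers = 8
)

x = torch.randn(1, 1024, 256)
residual_fsq.eval()                                  # quantizer dropout only applies in training
quantized, indices = residual_fsq(x)                 # (1, 1024, 256), (1, 1024, 8)

# reconstructing from indices alone should match the quantized output
assert torch.allclose(quantized, residual_fsq.get_output_from_indices(indices))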
class GroupedResidualFSQ(Module):
def __init__(
self,
*,
dim,
groups = 1,
accept_image_fmap = False,
**kwargs
):
super().__init__()
self.dim = dim
self.groups = groups
assert (dim % groups) == 0
dim_per_group = dim // groups
self.accept_image_fmap = accept_image_fmap
self.rvqs = nn.ModuleList([])
for _ in range(groups):
self.rvqs.append(ResidualFSQ(
dim = dim_per_group,
**kwargs
))
self.codebook_size = self.rvqs[0].codebook_size
@property
def codebooks(self):
return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
@property
def split_dim(self):
return 1 if self.accept_image_fmap else -1
def get_codes_from_indices(self, indices):
codes = tuple(rvq.get_codes_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
return torch.stack(codes)
def get_output_from_indices(self, indices):
outputs = tuple(rvq.get_output_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
return torch.cat(outputs, dim = self.split_dim)
def forward(
self,
x,
return_all_codes = False
):
shape, split_dim = x.shape, self.split_dim
assert shape[split_dim] == self.dim
x = x.chunk(self.groups, dim = split_dim)
forward_kwargs = dict(
return_all_codes = return_all_codes,
rand_quantize_dropout_fixed_seed = random.randint(0, int(1e7))
)
out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
out = tuple(zip(*out))
quantized, all_indices, *maybe_all_codes = out
quantized = torch.cat(quantized, dim = split_dim)
all_indices = torch.stack(all_indices)
ret = (quantized, all_indices, *maybe_all_codes)
return ret
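GroupedResidualFSQ simply splits the feature dimension into groups, runs an independent ResidualFSQ per group, and concatenates the results; a quick sketch of the shapes involved:

import torch
from vector_quantize_pytorch.residual_fsq import GroupedResidualFSQ

grouped = GroupedResidualFSQ(
    dim = 256,
    groups = 2,                # each group gets its own ResidualFSQ over 128 dims
    levels = [8, 5, 5, 3],
    num_quantizers = 8
)

x = torch.randn(1, 1024, 256)
quantized, all_indices = grouped(x)
# quantized: (1, 1024, 256); all_indices: (2, 1, 1024, 8), stacked per group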
.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\residual_lfq.py
import random
from math import ceil, log2
from functools import partial
import torch
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from torch.cuda.amp import autocast
from vector_quantize_pytorch.lookup_free_quantization import LFQ
from einops import rearrange, repeat, reduce, pack, unpack
from einx import get_at
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def round_up_multiple(num, mult):
return ceil(num / mult) * mult
class ResidualLFQ(Module):
""" Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """
def __init__(
self,
*,
dim,
num_quantizers,
codebook_size,
quantize_dropout = False,
quantize_dropout_cutoff_index = 0,
quantize_dropout_multiple_of = 1,
**kwargs
):
super().__init__()
codebook_dim = int(log2(codebook_size))
requires_projection = codebook_dim != dim
self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
self.has_projections = requires_projection
self.num_quantizers = num_quantizers
self.layers = nn.ModuleList([])
for ind in range(num_quantizers):
codebook_scale = 2 ** -ind
lfq = LFQ(
dim = codebook_dim,
codebook_scale = codebook_scale,
**kwargs
)
self.layers.append(lfq)
assert all([not lfq.has_projections for lfq in self.layers])
self.quantize_dropout = quantize_dropout and num_quantizers > 1
assert quantize_dropout_cutoff_index >= 0
self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
self.quantize_dropout_multiple_of = quantize_dropout_multiple_of
@property
def codebooks(self):
codebooks = [layer.codebook for layer in self.layers]
codebooks = torch.stack(codebooks, dim = 0)
return codebooks
def get_codes_from_indices(self, indices):
batch, quantize_dim = indices.shape[0], indices.shape[-1]
indices, ps = pack([indices], 'b * q')
if quantize_dim < self.num_quantizers:
assert self.quantize_dropout > 0., 'quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations'
indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1)
mask = indices == -1.
indices = indices.masked_fill(mask, 0)
all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)
all_codes = all_codes.masked_fill(rearrange(mask, 'b n q -> q b n 1'), 0.)
all_codes, = unpack(all_codes, ps, 'q b * d')
return all_codes
def get_output_from_indices(self, indices):
codes = self.get_codes_from_indices(indices)
codes_summed = reduce(codes, 'q ... -> ...', 'sum')
return self.project_out(codes_summed)
def forward(
self,
x,
mask = None,
return_all_codes = False,
rand_quantize_dropout_fixed_seed = None
):
num_quant, quant_dropout_multiple_of, device = self.num_quantizers, self.quantize_dropout_multiple_of, x.device
x = self.project_in(x)
quantized_out = 0.
residual = x
all_losses = []
all_indices = []
should_quantize_dropout = self.training and self.quantize_dropout
if should_quantize_dropout:
rand = random.Random(rand_quantize_dropout_fixed_seed) if exists(rand_quantize_dropout_fixed_seed) else random
rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)
if quant_dropout_multiple_of != 1:
rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1
null_indices = torch.full(x.shape[:2], -1., device=device, dtype=torch.long)
null_loss = torch.tensor(0., device=device, dtype=x.dtype)
with autocast(enabled=False):
for quantizer_index, layer in enumerate(self.layers):
if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
all_indices.append(null_indices)
all_losses.append(null_loss)
continue
quantized, indices, loss = layer(residual, mask=mask)
residual = residual - quantized.detach()
quantized_out = quantized_out + quantized
all_indices.append(indices)
all_losses.append(loss)
quantized_out = self.project_out(quantized_out)
all_losses, all_indices = map(partial(torch.stack, dim=-1), (all_losses, all_indices))
ret = (quantized_out, all_indices, all_losses)
if not return_all_codes:
return ret
all_codes = self.get_codes_from_indices(all_indices)
return (*ret, all_codes)
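ResidualLFQ mirrors ResidualFSQ but stacks LFQ layers whose codebook_scale halves at each depth. A hedged sketch; since the LFQ forward itself is truncated earlier in this article, the three-tuple return below is inferred from its Return namedtuple.

import torch
from vector_quantize_pytorch.residual_lfq import ResidualLFQ

residual_lfq = ResidualLFQ(
    dim = 256,
    codebook_size = 256,        # codebook_dim = log2(256) = 8, so inputs are projected 256 -> 8
    num_quantizers = 8
)

x = torch.randn(1, 1024, 256)
residual_lfq.eval()
quantized, indices, aux_losses = residual_lfq(x)
# quantized: (1, 1024, 256); indices: (1, 1024, 8); aux_losses: one auxiliary loss per quantizer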
class GroupedResidualLFQ(Module):
def __init__(
self,
*,
dim,
groups = 1,
accept_image_fmap = False,
**kwargs
):
super().__init__()
self.dim = dim
self.groups = groups
assert (dim % groups) == 0
dim_per_group = dim // groups
self.accept_image_fmap = accept_image_fmap
self.rvqs = nn.ModuleList([])
for _ in range(groups):
self.rvqs.append(ResidualLFQ(
dim = dim_per_group,
**kwargs
))
@property
def codebooks(self):
return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
@property
def split_dim(self):
return 1 if self.accept_image_fmap else -1
def get_codes_from_indices(self, indices):
codes = tuple(rvq.get_codes_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
return torch.stack(codes)
def get_output_from_indices(self, indices):
outputs = tuple(rvq.get_output_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
return torch.cat(outputs, dim = self.split_dim)
def forward(
self,
x,
mask = None,
return_all_codes = False
):
shape, split_dim = x.shape, self.split_dim
assert shape[split_dim] == self.dim
x = x.chunk(self.groups, dim = split_dim)
forward_kwargs = dict(
mask = mask,
return_all_codes = return_all_codes,
rand_quantize_dropout_fixed_seed = random.randint(0, int(1e7))
)
out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
out = tuple(zip(*out))
quantized, all_indices, commit_losses, *maybe_all_codes = out
quantized = torch.cat(quantized, dim = split_dim)
all_indices = torch.stack(all_indices)
commit_losses = torch.stack(commit_losses)
ret = (quantized, all_indices, commit_losses, *maybe_all_codes)
return ret
.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\residual_vq.py
import random
from math import ceil
from functools import partial
from itertools import zip_longest
import torch
from torch import nn
import torch.nn.functional as F
from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
from einops import rearrange, repeat, reduce, pack, unpack
from einx import get_at
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def round_up_multiple(num, mult):
return ceil(num / mult) * mult
class ResidualVQ(nn.Module):
""" Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """
def __init__(
self,
*,
dim,
num_quantizers,
codebook_dim = None,
shared_codebook = False,
heads = 1,
quantize_dropout = False,
quantize_dropout_cutoff_index = 0,
quantize_dropout_multiple_of = 1,
accept_image_fmap = False,
**kwargs
):
super().__init__()
assert heads == 1, 'residual vq is not compatible with multi-headed codes'
codebook_dim = default(codebook_dim, dim)
codebook_input_dim = codebook_dim * heads
requires_projection = codebook_input_dim != dim
self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
self.has_projections = requires_projection
self.num_quantizers = num_quantizers
self.accept_image_fmap = accept_image_fmap
self.layers = nn.ModuleList([VectorQuantize(dim = codebook_dim, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **kwargs) for _ in range(num_quantizers)])
assert all([not vq.has_projections for vq in self.layers])
self.quantize_dropout = quantize_dropout and num_quantizers > 1
assert quantize_dropout_cutoff_index >= 0
self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
self.quantize_dropout_multiple_of = quantize_dropout_multiple_of
if not shared_codebook:
return
first_vq, *rest_vq = self.layers
codebook = first_vq._codebook
for vq in rest_vq:
vq._codebook = codebook
@property
def codebooks(self):
codebooks = [layer._codebook.embed for layer in self.layers]
codebooks = torch.stack(codebooks, dim = 0)
codebooks = rearrange(codebooks, 'q 1 c d -> q c d')
return codebooks
def get_codes_from_indices(self, indices):
batch, quantize_dim = indices.shape[0], indices.shape[-1]
indices, ps = pack([indices], 'b * q')
if quantize_dim < self.num_quantizers:
assert self.quantize_dropout > 0., 'quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations'
indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1)
mask = indices == -1.
indices = indices.masked_fill(mask, 0)
all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)
all_codes = all_codes.masked_fill(rearrange(mask, 'b n q -> q b n 1'), 0.)
all_codes, = unpack(all_codes, ps, 'q b * d')
return all_codes
def get_output_from_indices(self, indices):
codes = self.get_codes_from_indices(indices)
codes_summed = reduce(codes, 'q ... -> ...', 'sum')
return self.project_out(codes_summed)
def forward(
self,
x,
mask = None,
indices = None,
return_all_codes = False,
sample_codebook_temp = None,
freeze_codebook = False,
rand_quantize_dropout_fixed_seed = None
):
num_quant, quant_dropout_multiple_of, return_loss, device = self.num_quantizers, self.quantize_dropout_multiple_of, exists(indices), x.device
x = self.project_in(x)
assert not (self.accept_image_fmap and exists(indices))
quantized_out = 0.
residual = x
all_losses = []
all_indices = []
if return_loss:
assert not torch.any(indices == -1), 'some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss'
ce_losses = []
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
if should_quantize_dropout:
rand = random.Random(rand_quantize_dropout_fixed_seed) if exists(rand_quantize_dropout_fixed_seed) else random
rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)
if quant_dropout_multiple_of != 1:
rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1
null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long)
null_loss = torch.full((1,), 0., device = device, dtype = x.dtype)
for quantizer_index, layer in enumerate(self.layers):
if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
all_indices.append(null_indices)
all_losses.append(null_loss)
continue
layer_indices = None
if return_loss:
layer_indices = indices[..., quantizer_index]
quantized, *rest = layer(
residual,
mask = mask,
indices = layer_indices,
sample_codebook_temp = sample_codebook_temp,
freeze_codebook = freeze_codebook
)
residual = residual - quantized.detach()
quantized_out = quantized_out + quantized
if return_loss:
ce_loss = rest[0]
ce_losses.append(ce_loss)
continue
embed_indices, loss = rest
all_indices.append(embed_indices)
all_losses.append(loss)
quantized_out = self.project_out(quantized_out)
if return_loss:
return quantized_out, sum(ce_losses)
all_losses, all_indices = map(partial(torch.stack, dim = -1), (all_losses, all_indices))
ret = (quantized_out, all_indices, all_losses)
if return_all_codes:
all_codes = self.get_codes_from_indices(all_indices)
ret = (*ret, all_codes)
return ret
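The canonical residual VQ usage, following the SoundStream-style Algorithm 1 cited above; hyperparameters follow the repository README.

import torch
from vector_quantize_pytorch.residual_vq import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,
    num_quantizers = 8,       # one VectorQuantize layer per residual stage
    codebook_size = 1024
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_losses = residual_vq(x)
# quantized: (1, 1024, 256); indices: (1, 1024, 8); commit_losses: one commitment loss per quantizer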
class GroupedResidualVQ(nn.Module):
def __init__(
self,
*,
dim,
groups = 1,
accept_image_fmap = False,
**kwargs
):
super().__init__()
self.dim = dim
self.groups = groups
assert (dim % groups) == 0
dim_per_group = dim // groups
self.accept_image_fmap = accept_image_fmap
self.rvqs = nn.ModuleList([])
for _ in range(groups):
self.rvqs.append(ResidualVQ(
dim = dim_per_group,
accept_image_fmap = accept_image_fmap,
**kwargs
))
@property
def codebooks(self):
return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
@property
def split_dim(self):
return 1 if self.accept_image_fmap else -1
def get_codes_from_indices(self, indices):
codes = tuple(rvq.get_codes_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
return torch.stack(codes)
def get_output_from_indices(self, indices):
outputs = tuple(rvq.get_output_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
return torch.cat(outputs, dim = self.split_dim)
def forward(
self,
x,
indices = None,
return_all_codes = False,
sample_codebook_temp = None,
freeze_codebook = False,
mask = None,
):
shape, split_dim = x.shape, self.split_dim
assert shape[split_dim] == self.dim
x = x.chunk(self.groups, dim = split_dim)
indices = default(indices, tuple())
return_ce_loss = len(indices) > 0
assert len(indices) == 0 or len(indices) == self.groups
forward_kwargs = dict(
return_all_codes = return_all_codes,
sample_codebook_temp = sample_codebook_temp,
mask = mask,
freeze_codebook = freeze_codebook,
rand_quantize_dropout_fixed_seed = random.randint(0, int(1e7))
)
out = tuple(rvq(chunk, indices = chunk_indices, **forward_kwargs) for rvq, chunk, chunk_indices in zip_longest(self.rvqs, x, indices))
out = tuple(zip(*out))
if return_ce_loss:
quantized, ce_losses = out
return torch.cat(quantized, dim = split_dim), sum(ce_losses)
quantized, all_indices, commit_losses, *maybe_all_codes = out
quantized = torch.cat(quantized, dim = split_dim)
all_indices = torch.stack(all_indices)
commit_losses = torch.stack(commit_losses)
ret = (quantized, all_indices, commit_losses, *maybe_all_codes)
return ret
.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\vector_quantize_pytorch.py
from functools import partial
import torch
from torch import nn, einsum
import torch.nn.functional as F
import torch.distributed as distributed
from torch.optim import Optimizer
from torch.cuda.amp import autocast
from einops import rearrange, repeat, reduce, pack, unpack
from typing import Callable
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def noop(*args, **kwargs):
pass
def identity(t):
return t
def l2norm(t):
return F.normalize(t, p = 2, dim = -1)
def cdist(x, y):
x2 = reduce(x ** 2, 'b n d -> b n', 'sum')
y2 = reduce(y ** 2, 'b n d -> b n', 'sum')
xy = einsum('b i d, b j d -> b i j', x, y) * -2
return (rearrange(x2, 'b i -> b i 1') + rearrange(y2, 'b j -> b 1 j') + xy).clamp(min = 0).sqrt()
def log(t, eps = 1e-20):
return torch.log(t.clamp(min = eps))
def ema_inplace(old, new, decay):
is_mps = str(old.device).startswith('mps:')
if not is_mps:
old.lerp_(new, 1 - decay)
else:
old.mul_(decay).add_(new * (1 - decay))
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
def uniform_init(*shape):
t = torch.empty(shape)
nn.init.kaiming_uniform_(t)
return t
def gumbel_noise(t):
noise = torch.zeros_like(t).uniform_(0, 1)
return -log(-log(noise))
def gumbel_sample(
logits,
temperature = 1.,
stochastic = False,
straight_through = False,
reinmax = False,
dim = -1,
training = True
):
dtype, size = logits.dtype, logits.shape[dim]
if training and stochastic and temperature > 0:
sampling_logits = (logits / temperature) + gumbel_noise(logits)
else:
sampling_logits = logits
ind = sampling_logits.argmax(dim = dim)
one_hot = F.one_hot(ind, size).type(dtype)
assert not (reinmax and not straight_through), 'reinmax can only be turned on if using straight through gumbel softmax'
if not straight_through or temperature <= 0. or not training:
return ind, one_hot
if reinmax:
π0 = logits.softmax(dim = dim)
π1 = (one_hot + (logits / temperature).softmax(dim = dim)) / 2
π1 = ((log(π1) - logits).detach() + logits).softmax(dim = 1)
π2 = 2 * π1 - 0.5 * π0
one_hot = π2 - π2.detach() + one_hot
else:
π1 = (logits / temperature).softmax(dim = dim)
one_hot = one_hot + π1 - π1.detach()
return ind, one_hot
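A small sketch exercising the straight-through Gumbel sampler above: hard indices in the forward pass, soft gradients through the relaxed one-hot (assumes the helpers defined in this file are in scope).

logits = torch.randn(2, 1024, 512, requires_grad = True)   # (heads, tokens, codebook size)

ind, one_hot = gumbel_sample(
    logits,
    temperature = 1.,
    stochastic = True,          # add gumbel noise before the argmax
    straight_through = True,    # one_hot carries gradients from the softmax relaxation
    training = True
)
# ind: (2, 1024) hard code indices; one_hot: (2, 1024, 512), differentiable w.r.t. logits
one_hot.sum().backward()
assert logits.grad is not None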
def laplace_smoothing(x, n_categories, eps = 1e-5, dim = -1):
denom = x.sum(dim = dim, keepdim = True)
return (x + eps) / (denom + n_categories * eps)
def sample_vectors(samples, num):
num_samples, device = samples.shape[0], samples.device
if num_samples >= num:
indices = torch.randperm(num_samples, device = device)[:num]
else:
indices = torch.randint(0, num_samples, (num,), device = device)
return samples[indices]
def batched_sample_vectors(samples, num):
return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim = 0)], dim = 0)
def pad_shape(shape, size, dim = 0):
return [size if i == dim else s for i, s in enumerate(shape)]
def sample_multinomial(total_count, probs):
device = probs.device
probs = probs.cpu()
total_count = probs.new_full((), total_count)
remainder = probs.new_ones(())
sample = torch.empty_like(probs, dtype = torch.long)
for i, p in enumerate(probs):
s = torch.binomial(total_count, p / remainder)
sample[i] = s
total_count -= s
remainder -= p
return sample.to(device)
def all_gather_sizes(x, dim):
size = torch.tensor(x.shape[dim], dtype = torch.long, device = x.device)
all_sizes = [torch.empty_like(size) for _ in range(distributed.get_world_size())]
distributed.all_gather(all_sizes, size)
return torch.stack(all_sizes)
def all_gather_variably_sized(x, sizes, dim = 0):
rank = distributed.get_rank()
all_x = []
for i, size in enumerate(sizes):
t = x if i == rank else x.new_empty(pad_shape(x.shape, size, dim))
distributed.broadcast(t, src = i, async_op = True)
all_x.append(t)
distributed.barrier()
return all_x
def sample_vectors_distributed(local_samples, num):
local_samples = rearrange(local_samples, '1 ... -> ...')
rank = distributed.get_rank()
all_num_samples = all_gather_sizes(local_samples, dim = 0)
if rank == 0:
samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum())
else:
samples_per_rank = torch.empty_like(all_num_samples)
distributed.broadcast(samples_per_rank, src = 0)
samples_per_rank = samples_per_rank.tolist()
local_samples = sample_vectors(local_samples, samples_per_rank[rank])
all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim = 0)
out = torch.cat(all_samples, dim = 0)
return rearrange(out, '... -> 1 ...')
def batched_bincount(x, *, minlength):
batch, dtype, device = x.shape[0], x.dtype, x.device
target = torch.zeros(batch, minlength, dtype = dtype, device = device)
values = torch.ones_like(x)
target.scatter_add_(-1, x, values)
return target
def kmeans(
samples,
num_clusters,
num_iters = 10,
use_cosine_sim = False,
sample_fn = batched_sample_vectors,
all_reduce_fn = noop
):
num_codebooks, dim, dtype, device = samples.shape[0], samples.shape[-1], samples.dtype, samples.device
means = sample_fn(samples, num_clusters)
for _ in range(num_iters):
if use_cosine_sim:
dists = samples @ rearrange(means, 'h n d -> h d n')
else:
dists = -cdist(samples, means)
buckets = torch.argmax(dists, dim = -1)
bins = batched_bincount(buckets, minlength = num_clusters)
all_reduce_fn(bins)
zero_mask = bins == 0
bins_min_clamped = bins.masked_fill(zero_mask, 1)
new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype = dtype)
new_means.scatter_add_(1, repeat(buckets, 'h n -> h n d', d = dim), samples)
new_means = new_means / rearrange(bins_min_clamped, '... -> ... 1')
all_reduce_fn(new_means)
if use_cosine_sim:
new_means = l2norm(new_means)
means = torch.where(
rearrange(zero_mask, '... -> ... 1'),
means,
new_means
)
return means, bins
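The k-means helper operates on a batch of codebooks at once, with samples of shape (num_codebooks, num_vectors, dim); a quick sanity check, assuming the helpers above are in scope:

samples = torch.randn(1, 2048, 64)                   # one codebook, 2048 candidate vectors of dim 64
means, bins = kmeans(samples, num_clusters = 16, num_iters = 10)
# means: (1, 16, 64) cluster centroids; bins: (1, 16) how many vectors landed in each cluster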
def batched_embedding(indices, embeds):
batch, dim = indices.shape[1], embeds.shape[-1]
indices = repeat(indices, 'h b n -> h b n d', d = dim)
embeds = repeat(embeds, 'h c d -> h b c d', b = batch)
return embeds.gather(2, indices)
def orthogonal_loss_fn(t):
h, n = t.shape[:2]
normed_codes = l2norm(t)
cosine_sim = einsum('h i d, h j d -> h i j', normed_codes, normed_codes)
return (cosine_sim ** 2).sum() / (h * n ** 2) - (1 / n)
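The orthogonal regularizer above penalizes pairwise cosine similarity between codes and evaluates to zero for a perfectly orthogonal codebook, for example:

codes = torch.eye(8).unsqueeze(0)      # (1 codebook, 8 codes, dim 8), mutually orthogonal
print(orthogonal_loss_fn(codes))       # tensor(0.)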
class EuclideanCodebook(nn.Module):
def __init__(
self,
dim,
codebook_size,
num_codebooks = 1,
kmeans_init = False,
kmeans_iters = 10,
sync_kmeans = True,
decay = 0.8,
eps = 1e-5,
threshold_ema_dead_code = 2,
reset_cluster_size = None,
use_ddp = False,
learnable_codebook = False,
gumbel_sample = gumbel_sample,
sample_codebook_temp = 1.,
ema_update = True,
affine_param = False,
sync_affine_param = False,
affine_param_batch_decay = 0.99,
affine_param_codebook_decay = 0.9
):
super().__init__()
self.transform_input = identity
self.decay = decay
self.ema_update = ema_update
init_fn = uniform_init if not kmeans_init else torch.zeros
embed = init_fn(num_codebooks, codebook_size, dim)
self.codebook_size = codebook_size
self.num_codebooks = num_codebooks
self.kmeans_iters = kmeans_iters
self.eps = eps
self.threshold_ema_dead_code = threshold_ema_dead_code
self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)
assert callable(gumbel_sample)
self.gumbel_sample = gumbel_sample
self.sample_codebook_temp = sample_codebook_temp
assert not (use_ddp and num_codebooks > 1 and kmeans_init), 'kmeans init is not compatible with multiple codebooks in distributed environment for now'
self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
self.register_buffer('initted', torch.Tensor([not kmeans_init]))
self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size))
self.register_buffer('embed_avg', embed.clone())
self.learnable_codebook = learnable_codebook
if learnable_codebook:
self.embed = nn.Parameter(embed)
else:
self.register_buffer('embed', embed)
self.affine_param = affine_param
self.sync_affine_param = sync_affine_param
if not affine_param:
return
self.affine_param_batch_decay = affine_param_batch_decay
self.affine_param_codebook_decay = affine_param_codebook_decay
self.register_buffer('batch_mean', None)
self.register_buffer('batch_variance', None)
self.register_buffer('codebook_mean_needs_init', torch.Tensor([True]))
self.register_buffer('codebook_mean', torch.empty(num_codebooks, 1, dim))
self.register_buffer('codebook_variance_needs_init', torch.Tensor([True]))
self.register_buffer('codebook_variance', torch.empty(num_codebooks, 1, dim))
@torch.jit.ignore
def init_embed_(self, data, mask = None):
if self.initted:
return
if exists(mask):
c = data.shape[0]
data = rearrange(data[mask], '(c n) d -> c n d', c = c)
embed, cluster_size = kmeans(
data,
self.codebook_size,
self.kmeans_iters,
sample_fn = self.sample_fn,
all_reduce_fn = self.kmeans_all_reduce_fn
)
embed_sum = embed * rearrange(cluster_size, '... -> ... 1')
self.embed.data.copy_(embed)
self.embed_avg.data.copy_(embed_sum)
self.cluster_size.data.copy_(cluster_size)
self.initted.data.copy_(torch.Tensor([True]))
@torch.jit.ignore
def update_with_decay(self, buffer_name, new_value, decay):
old_value = getattr(self, buffer_name)
needs_init = getattr(self, buffer_name + "_needs_init", False)
if needs_init:
self.register_buffer(buffer_name + "_needs_init", torch.Tensor([False]))
if not exists(old_value) or needs_init:
self.register_buffer(buffer_name, new_value.detach())
return
value = old_value * decay + new_value.detach() * (1 - decay)
self.register_buffer(buffer_name, value)
@torch.jit.ignore
def update_affine(self, data, embed, mask = None):
assert self.affine_param
var_fn = partial(torch.var, unbiased = False)
embed = rearrange(embed, 'h ... d -> h (...) d')
if self.training:
self.update_with_decay('codebook_mean', reduce(embed, 'h n d -> h 1 d', 'mean'), self.affine_param_codebook_decay)
self.update_with_decay('codebook_variance', reduce(embed, 'h n d -> h 1 d', var_fn), self.affine_param_codebook_decay)
data = rearrange(data, 'h ... d -> h (...) d')
if exists(mask):
c = data.shape[0]
data = rearrange(data[mask], '(c n) d -> c n d', c = c)
if not self.sync_affine_param:
self.update_with_decay('batch_mean', reduce(data, 'h n d -> h 1 d', 'mean'), self.affine_param_batch_decay)
self.update_with_decay('batch_variance', reduce(data, 'h n d -> h 1 d', var_fn), self.affine_param_batch_decay)
return
num_vectors, device, dtype = data.shape[-2], data.device, data.dtype
num_vectors = torch.tensor([num_vectors], device = device, dtype = dtype)
distributed.all_reduce(num_vectors)
batch_sum = reduce(data, 'h n d -> h 1 d', 'sum')
distributed.all_reduce(batch_sum)
batch_mean = batch_sum / num_vectors
self.update_with_decay('batch_mean', batch_mean, self.affine_param_batch_decay)
variance_numer = reduce((data - batch_mean) ** 2, 'h n d -> h 1 d', 'sum')
distributed.all_reduce(variance_numer)
batch_variance = variance_numer / num_vectors
self.update_with_decay('batch_variance', batch_variance, self.affine_param_batch_decay)
def replace(self, batch_samples, batch_mask):
for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))):
if not torch.any(mask):
continue
sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
sampled = rearrange(sampled, '1 ... -> ...')
self.embed.data[ind][mask] = sampled
self.cluster_size.data[ind][mask] = self.reset_cluster_size
self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size
def expire_codes_(self, batch_samples):
if self.threshold_ema_dead_code == 0:
return
expired_codes = self.cluster_size < self.threshold_ema_dead_code
if not torch.any(expired_codes):
return
batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
self.replace(batch_samples, batch_mask = expired_codes)
@autocast(enabled = False)
def forward(
self,
x,
sample_codebook_temp = None,
mask = None,
freeze_codebook = False
):
needs_codebook_dim = x.ndim < 4
sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
x = x.float()
if needs_codebook_dim:
x = rearrange(x, '... -> 1 ...')
dtype = x.dtype
flatten, ps = pack_one(x, 'h * d')
if exists(mask):
mask = repeat(mask, 'b n -> c (b h n)', c = flatten.shape[0], h = flatten.shape[-2] // (mask.shape[0] * mask.shape[1]))
self.init_embed_(flatten, mask = mask)
if self.affine_param:
self.update_affine(flatten, self.embed, mask = mask)
embed = self.embed if self.learnable_codebook else self.embed.detach()
if self.affine_param:
codebook_std = self.codebook_variance.clamp(min = 1e-5).sqrt()
batch_std = self.batch_variance.clamp(min = 1e-5).sqrt()
embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean
dist = -cdist(flatten, embed)
embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training)
embed_ind = unpack_one(embed_ind, ps, 'h *')
if self.training:
unpacked_onehot = unpack_one(embed_onehot, ps, 'h * c')
quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed)
else:
quantize = batched_embedding(embed_ind, embed)
if self.training and self.ema_update and not freeze_codebook:
if self.affine_param:
flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean
if exists(mask):
embed_onehot[~mask] = 0.
cluster_size = embed_onehot.sum(dim = 1)
self.all_reduce_fn(cluster_size)
ema_inplace(self.cluster_size.data, cluster_size, self.decay)
embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
embed_sum = embed_sum.contiguous()
self.all_reduce_fn(embed_sum)
ema_inplace(self.embed_avg.data, embed_sum, self.decay)
cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
self.embed.data.copy_(embed_normalized)
self.expire_codes_(x)
if needs_codebook_dim:
quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
dist = unpack_one(dist, ps, 'h * d')
return quantize, embed_ind, dist
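EuclideanCodebook is normally driven by the VectorQuantize wrapper further below, but it can be exercised directly; a minimal sketch assuming the class and its helpers above are in scope:

codebook = EuclideanCodebook(dim = 64, codebook_size = 256)

x = torch.randn(1, 1024, 64)                  # (batch, tokens, dim); a leading codebook axis is added internally
quantize, embed_ind, dist = codebook(x)
# quantize: (1, 1024, 64) nearest codes; embed_ind: (1, 1024) indices; dist: (1, 1, 1024, 256) negated distances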
class CosineSimCodebook(nn.Module):
def __init__(
self,
dim,
codebook_size,
num_codebooks = 1,
kmeans_init = False,
kmeans_iters = 10,
sync_kmeans = True,
decay = 0.8,
eps = 1e-5,
threshold_ema_dead_code = 2,
reset_cluster_size = None,
use_ddp = False,
learnable_codebook = False,
gumbel_sample = gumbel_sample,
sample_codebook_temp = 1.,
ema_update = True
):
super().__init__()
self.transform_input = l2norm
self.ema_update = ema_update
self.decay = decay
if not kmeans_init:
embed = l2norm(uniform_init(num_codebooks, codebook_size, dim))
else:
embed = torch.zeros(num_codebooks, codebook_size, dim)
self.codebook_size = codebook_size
self.num_codebooks = num_codebooks
self.kmeans_iters = kmeans_iters
self.eps = eps
self.threshold_ema_dead_code = threshold_ema_dead_code
self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)
assert callable(gumbel_sample)
self.gumbel_sample = gumbel_sample
self.sample_codebook_temp = sample_codebook_temp
self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
self.register_buffer('initted', torch.Tensor([not kmeans_init]))
self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size))
self.register_buffer('embed_avg', embed.clone())
self.learnable_codebook = learnable_codebook
if learnable_codebook:
self.embed = nn.Parameter(embed)
else:
self.register_buffer('embed', embed)
@torch.jit.ignore
def init_embed_(self, data, mask = None):
if self.initted:
return
if exists(mask):
c = data.shape[0]
data = rearrange(data[mask], '(c n) d -> c n d', c = c)
embed, cluster_size = kmeans(
data,
self.codebook_size,
self.kmeans_iters,
use_cosine_sim = True,
sample_fn = self.sample_fn,
all_reduce_fn = self.kmeans_all_reduce_fn
)
embed_sum = embed * rearrange(cluster_size, '... -> ... 1')
self.embed.data.copy_(embed)
self.embed_avg.data.copy_(embed_sum)
self.cluster_size.data.copy_(cluster_size)
self.initted.data.copy_(torch.Tensor([True]))
def replace(self, batch_samples, batch_mask):
batch_samples = l2norm(batch_samples)
for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))):
if not torch.any(mask):
continue
sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
sampled = rearrange(sampled, '1 ... -> ...')
self.embed.data[ind][mask] = sampled
self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size
self.cluster_size.data[ind][mask] = self.reset_cluster_size
def expire_codes_(self, batch_samples):
if self.threshold_ema_dead_code == 0:
return
expired_codes = self.cluster_size < self.threshold_ema_dead_code
if not torch.any(expired_codes):
return
batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
self.replace(batch_samples, batch_mask = expired_codes)
@autocast(enabled = False)
def forward(
self,
x,
sample_codebook_temp = None,
mask = None,
freeze_codebook = False
):
needs_codebook_dim = x.ndim < 4
sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
x = x.float()
if needs_codebook_dim:
x = rearrange(x, '... -> 1 ...')
dtype = x.dtype
flatten, ps = pack_one(x, 'h * d')
if exists(mask):
mask = repeat(mask, 'b n -> c (b h n)', c = flatten.shape[0], h = flatten.shape[-2] // (mask.shape[0] * mask.shape[1]))
self.init_embed_(flatten, mask = mask)
embed = self.embed if self.learnable_codebook else self.embed.detach()
dist = einsum('h n d, h c d -> h n c', flatten, embed)
embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training)
embed_ind = unpack_one(embed_ind, ps, 'h *')
if self.training:
unpacked_onehot = unpack_one(embed_onehot, ps, 'h * c')
quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed)
else:
quantize = batched_embedding(embed_ind, embed)
if self.training and self.ema_update and not freeze_codebook:
if exists(mask):
embed_onehot[~mask] = 0.
bins = embed_onehot.sum(dim = 1)
self.all_reduce_fn(bins)
ema_inplace(self.cluster_size.data, bins, self.decay)
embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
embed_sum = embed_sum.contiguous()
self.all_reduce_fn(embed_sum)
ema_inplace(self.embed_avg.data, embed_sum, self.decay)
cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
embed_normalized = l2norm(embed_normalized)
self.embed.data.copy_(l2norm(embed_normalized))
self.expire_codes_(x)
if needs_codebook_dim:
quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
dist = unpack_one(dist, ps, 'h * d')
return quantize, embed_ind, dist
class VectorQuantize(nn.Module):
def __init__(
self,
dim,
codebook_size,
codebook_dim = None,
heads = 1,
separate_codebook_per_head = False,
decay = 0.8,
eps = 1e-5,
freeze_codebook = False,
kmeans_init = False,
kmeans_iters = 10,
sync_kmeans = True,
use_cosine_sim = False,
threshold_ema_dead_code = 0,
channel_last = True,
accept_image_fmap = False,
commitment_weight = 1.,
commitment_use_cross_entropy_loss = False,
orthogonal_reg_weight = 0.,
orthogonal_reg_active_codes_only = False,
orthogonal_reg_max_codes = None,
stochastic_sample_codes = False,
sample_codebook_temp = 1.,
straight_through = False,
reinmax = False,
sync_codebook = None,
sync_affine_param = False,
ema_update = True,
learnable_codebook = False,
in_place_codebook_optimizer: Callable[..., Optimizer] = None,
affine_param = False,
affine_param_batch_decay = 0.99,
affine_param_codebook_decay = 0.9,
sync_update_v = 0.
):
super().__init__()
self.dim = dim
self.heads = heads
self.separate_codebook_per_head = separate_codebook_per_head
codebook_dim = default(codebook_dim, dim)
codebook_input_dim = codebook_dim * heads
requires_projection = codebook_input_dim != dim
self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
self.has_projections = requires_projection
self.eps = eps
self.commitment_weight = commitment_weight
self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss
self.learnable_codebook = learnable_codebook
has_codebook_orthogonal_loss = orthogonal_reg_weight > 0
self.has_codebook_orthogonal_loss = has_codebook_orthogonal_loss
self.orthogonal_reg_weight = orthogonal_reg_weight
self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
assert not (ema_update and learnable_codebook), 'learnable codebook not compatible with EMA update'
assert 0 <= sync_update_v <= 1.
assert not (sync_update_v > 0. and not learnable_codebook), 'learnable codebook must be turned on'
self.sync_update_v = sync_update_v
codebook_class = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook
gumbel_sample_fn = partial(
gumbel_sample,
stochastic = stochastic_sample_codes,
reinmax = reinmax,
straight_through = straight_through
)
if not exists(sync_codebook):
sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1
codebook_kwargs = dict(
dim = codebook_dim,
num_codebooks = heads if separate_codebook_per_head else 1,
codebook_size = codebook_size,
kmeans_init = kmeans_init,
kmeans_iters = kmeans_iters,
sync_kmeans = sync_kmeans,
decay = decay,
eps = eps,
threshold_ema_dead_code = threshold_ema_dead_code,
use_ddp = sync_codebook,
learnable_codebook = has_codebook_orthogonal_loss or learnable_codebook,
sample_codebook_temp = sample_codebook_temp,
gumbel_sample = gumbel_sample_fn,
ema_update = ema_update
)
if affine_param:
assert not use_cosine_sim, 'affine param is only compatible with euclidean codebook'
codebook_kwargs = dict(
**codebook_kwargs,
affine_param = True,
sync_affine_param = sync_affine_param,
affine_param_batch_decay = affine_param_batch_decay,
affine_param_codebook_decay = affine_param_codebook_decay,
)
self._codebook = codebook_class(**codebook_kwargs)
self.in_place_codebook_optimizer = in_place_codebook_optimizer(self._codebook.parameters()) if exists(in_place_codebook_optimizer) else None
self.codebook_size = codebook_size
self.accept_image_fmap = accept_image_fmap
self.channel_last = channel_last
@property
def codebook(self):
codebook = self._codebook.embed
if self.separate_codebook_per_head:
return codebook
return rearrange(codebook, '1 ... -> ...')
@codebook.setter
def codebook(self, codes):
if not self.separate_codebook_per_head:
codes = rearrange(codes, '... -> 1 ...')
self._codebook.embed.copy_(codes)
def get_codes_from_indices(self, indices):
codebook = self.codebook
is_multiheaded = codebook.ndim > 2
if not is_multiheaded:
codes = codebook[indices]
else:
indices, ps = pack_one(indices, 'b * h')
indices = rearrange(indices, 'b n h -> b h n')
indices = repeat(indices, 'b h n -> b h n d', d = codebook.shape[-1])
codebook = repeat(codebook, 'h n d -> b h n d', b = indices.shape[0])
codes = codebook.gather(2, indices)
codes = rearrange(codes, 'b h n d -> b n (h d)')
codes = unpack_one(codes, ps, 'b * d')
if not self.channel_last:
codes = rearrange(codes, 'b ... d -> b d ...')
return codes
def get_output_from_indices(self, indices):
codes = self.get_codes_from_indices(indices)
return self.project_out(codes)
def forward(
self,
x,
indices = None,
mask = None,
sample_codebook_temp = None,
freeze_codebook = False