Lucidrains Series Source Code Analysis (Part 26)
.\lucidrains\denoising-diffusion-pytorch\denoising_diffusion_pytorch\karras_unet_1d.py
"""
the magnitude-preserving unet proposed in https://arxiv.org/abs/2312.02696 by Karras et al.
"""
import math
from math import sqrt, ceil
from functools import partial
import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F
from einops import rearrange, repeat, pack, unpack
from denoising_diffusion_pytorch.attend import Attend
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
def xnor(x, y):
return not (x ^ y)
def append(arr, el):
arr.append(el)
def prepend(arr, el):
arr.insert(0, el)
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
def cast_tuple(t, length = 1):
if isinstance(t, tuple):
return t
return ((t,) * length)
def divisible_by(numer, denom):
return (numer % denom) == 0
def l2norm(t, dim = -1, eps = 1e-12):
return F.normalize(t, dim = dim, eps = eps)
def interpolate_1d(x, length, mode = 'bilinear'):
x = rearrange(x, 'b c t -> b c t 1')
x = F.interpolate(x, (length, 1), mode = mode)
return rearrange(x, 'b c t 1 -> b c t')
class MPSiLU(Module):
    def forward(self, x):
        # 0.596 is the std of silu(x) for x ~ N(0, 1), so dividing restores unit variance
        return F.silu(x) / 0.596
class Gain(Module):
def __init__(self):
super().__init__()
self.gain = nn.Parameter(torch.tensor(0.))
def forward(self, x):
return x * self.gain
class MPCat(Module):
def __init__(self, t = 0.5, dim = -1):
super().__init__()
self.t = t
self.dim = dim
def forward(self, a, b):
dim, t = self.dim, self.t
Na, Nb = a.shape[dim], b.shape[dim]
C = sqrt((Na + Nb) / ((1. - t) ** 2 + t ** 2))
a = a * (1. - t) / sqrt(Na)
b = b * t / sqrt(Nb)
return C * torch.cat((a, b), dim = dim)
class MPAdd(Module):
def __init__(self, t):
super().__init__()
self.t = t
    def forward(self, x, res):
        a, b, t = x, res, self.t
        # magnitude preserving sum: lerp by t, then divide by the std of the
        # blend so two independent unit-variance inputs stay unit variance
        num = a * (1. - t) + b * t
        den = sqrt((1 - t) ** 2 + t ** 2)
        return num / den
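# a minimal sanity sketch (the _demo_* helpers in this file are illustrative,
# not part of the original source): for independent, roughly unit-variance
# inputs, the magnitude preserving add should keep the output std near 1

def _demo_mp_add():
    a, b = torch.randn(4, 64, 128), torch.randn(4, 64, 128)
    out = MPAdd(t = 0.3)(a, b)
    print(out.std())  # expected to hover around 1.0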
class PixelNorm(Module):
def __init__(self, dim, eps = 1e-4):
super().__init__()
self.dim = dim
self.eps = eps
def forward(self, x):
dim = self.dim
return l2norm(x, dim = dim, eps = self.eps) * sqrt(x.shape[dim])
def normalize_weight(weight, eps = 1e-4):
    # rescale each output unit's flattened weight vector to norm sqrt(fan_in),
    # so dividing by sqrt(fan_in) at the call site yields unit-norm rows
    weight, ps = pack_one(weight, 'o *')
    normed_weight = l2norm(weight, eps = eps)
    normed_weight = normed_weight * sqrt(weight.numel() / weight.shape[0])
    return unpack_one(normed_weight, ps, 'o *')
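# sanity sketch (illustrative helper): after normalize_weight, every output
# unit's flattened weight vector has norm sqrt(fan_in)

def _demo_normalize_weight():
    w = torch.randn(32, 16, 3)                     # (dim_out, dim_in, kernel)
    row_norms = normalize_weight(w).flatten(1).norm(dim = -1)
    print(row_norms)                               # all ~ sqrt(16 * 3) ~ 6.93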
class Conv1d(Module):
def __init__(
self,
dim_in,
dim_out,
kernel_size,
eps = 1e-4,
init_dirac = False,
concat_ones_to_input = False
):
super().__init__()
weight = torch.randn(dim_out, dim_in + int(concat_ones_to_input), kernel_size)
self.weight = nn.Parameter(weight)
if init_dirac:
nn.init.dirac_(self.weight)
self.eps = eps
self.fan_in = dim_in * kernel_size
self.concat_ones_to_input = concat_ones_to_input
    def forward(self, x):
        if self.training:
            # forced weight normalization - project the raw weights back onto
            # the unit hypersphere at every training step
            with torch.no_grad():
                normed_weight = normalize_weight(self.weight, eps = self.eps)
                self.weight.copy_(normed_weight)
        weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)
        if self.concat_ones_to_input:
            # pad a constant ones channel onto the input, standing in for a bias
            x = F.pad(x, (0, 0, 1, 0), value = 1.)
        return F.conv1d(x, weight, padding = 'same')
class Linear(Module):
def __init__(self, dim_in, dim_out, eps = 1e-4):
super().__init__()
weight = torch.randn(dim_out, dim_in)
self.weight = nn.Parameter(weight)
self.eps = eps
self.fan_in = dim_in
def forward(self, x):
if self.training:
with torch.no_grad():
normed_weight = normalize_weight(self.weight, eps = self.eps)
self.weight.copy_(normed_weight)
weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)
return F.linear(x, weight)
class MPFourierEmbedding(Module):
def __init__(self, dim):
super().__init__()
assert divisible_by(dim, 2)
half_dim = dim // 2
self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = False)
def forward(self, x):
x = rearrange(x, 'b -> b 1')
freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
return torch.cat((freqs.sin(), freqs.cos()), dim = -1) * sqrt(2)
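# sanity sketch (illustrative helper): the fixed random fourier features are
# scaled by sqrt(2) so the embedding is roughly unit variance for spread-out inputs

def _demo_mp_fourier():
    emb = MPFourierEmbedding(64)
    t = torch.rand(1024) * 1000.
    print(emb(t).std())  # should sit near 1.0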
class Encoder(Module):
def __init__(
self,
dim,
dim_out = None,
*,
emb_dim = None,
dropout = 0.1,
mp_add_t = 0.3,
has_attn = False,
attn_dim_head = 64,
attn_res_mp_add_t = 0.3,
attn_flash = False,
downsample = False
):
super().__init__()
dim_out = default(dim_out, dim)
self.downsample = downsample
self.downsample_conv = None
curr_dim = dim
if downsample:
self.downsample_conv = Conv1d(curr_dim, dim_out, 1)
curr_dim = dim_out
self.pixel_norm = PixelNorm(dim = 1)
self.to_emb = None
if exists(emb_dim):
self.to_emb = nn.Sequential(
Linear(emb_dim, dim_out),
Gain()
)
self.block1 = nn.Sequential(
MPSiLU(),
Conv1d(curr_dim, dim_out, 3)
)
self.block2 = nn.Sequential(
MPSiLU(),
nn.Dropout(dropout),
Conv1d(dim_out, dim_out, 3)
)
self.res_mp_add = MPAdd(t = mp_add_t)
self.attn = None
if has_attn:
self.attn = Attention(
dim = dim_out,
heads = max(ceil(dim_out / attn_dim_head), 2),
dim_head = attn_dim_head,
mp_add_t = attn_res_mp_add_t,
flash = attn_flash
)
def forward(
self,
x,
emb = None
):
if self.downsample:
x = interpolate_1d(x, x.shape[-1] // 2, mode = 'bilinear')
x = self.downsample_conv(x)
x = self.pixel_norm(x)
res = x.clone()
x = self.block1(x)
if exists(emb):
scale = self.to_emb(emb) + 1
x = x * rearrange(scale, 'b c -> b c 1')
x = self.block2(x)
x = self.res_mp_add(x, res)
if exists(self.attn):
x = self.attn(x)
return x
class Decoder(Module):
def __init__(
self,
dim,
dim_out = None,
*,
emb_dim = None,
dropout = 0.1,
mp_add_t = 0.3,
has_attn = False,
attn_dim_head = 64,
attn_res_mp_add_t = 0.3,
attn_flash = False,
upsample = False
):
super().__init__()
dim_out = default(dim_out, dim)
self.upsample = upsample
self.needs_skip = not upsample
self.to_emb = None
if exists(emb_dim):
self.to_emb = nn.Sequential(
Linear(emb_dim, dim_out),
Gain()
)
self.block1 = nn.Sequential(
MPSiLU(),
Conv1d(dim, dim_out, 3)
)
self.block2 = nn.Sequential(
MPSiLU(),
nn.Dropout(dropout),
Conv1d(dim_out, dim_out, 3)
)
self.res_conv = Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
self.res_mp_add = MPAdd(t = mp_add_t)
self.attn = None
if has_attn:
self.attn = Attention(
dim = dim_out,
heads = max(ceil(dim_out / attn_dim_head), 2),
dim_head = attn_dim_head,
mp_add_t = attn_res_mp_add_t,
flash = attn_flash
)
def forward(
self,
x,
emb = None
):
if self.upsample:
x = interpolate_1d(x, x.shape[-1] * 2, mode = 'bilinear')
res = self.res_conv(x)
x = self.block1(x)
if exists(emb):
scale = self.to_emb(emb) + 1
x = x * rearrange(scale, 'b c -> b c 1')
x = self.block2(x)
x = self.res_mp_add(x, res)
if exists(self.attn):
x = self.attn(x)
return x
class Attention(Module):
def __init__(
self,
dim,
heads = 4,
dim_head = 64,
num_mem_kv = 4,
flash = False,
mp_add_t = 0.3
):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.pixel_norm = PixelNorm(dim = -1)
self.attend = Attend(flash = flash)
self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head))
self.to_qkv = Conv1d(dim, hidden_dim * 3, 1)
self.to_out = Conv1d(hidden_dim, dim, 1)
self.mp_add = MPAdd(t = mp_add_t)
def forward(self, x):
res, b, c, n = x, *x.shape
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) n -> b h n c', h = self.heads), qkv)
mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), self.mem_kv)
k, v = map(partial(torch.cat, dim = -2), ((mk, k), (mv, v)))
q, k, v = map(self.pixel_norm, (q, k, v))
out = self.attend(q, k, v)
out = rearrange(out, 'b h n d -> b (h d) n')
out = self.to_out(out)
return self.mp_add(out, res)
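# shape sketch (illustrative helper): attention here is residual and
# sequence-preserving - queries, keys and values are pixel-normed (cosine
# attention) and a few learned memory key / values are prepended to k and v

def _demo_attention():
    attn = Attention(dim = 64, heads = 4, dim_head = 16)
    x = torch.randn(2, 64, 32)
    print(attn(x).shape)  # (2, 64, 32), same as the input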
class KarrasUnet1D(Module):
"""
going by figure 21. config G
"""
def __init__(
self,
*,
seq_len,
dim = 192,
dim_max = 768,
num_classes = None,
channels = 4,
num_downsamples = 3,
num_blocks_per_stage = 4,
attn_res = (16, 8),
fourier_dim = 16,
attn_dim_head = 64,
attn_flash = False,
mp_cat_t = 0.5,
mp_add_emb_t = 0.5,
attn_res_mp_add_t = 0.3,
resnet_mp_add_t = 0.3,
dropout = 0.1,
self_condition = False
):
super().__init__()
self.self_condition = self_condition
self.channels = channels
self.seq_len = seq_len
input_channels = channels * (2 if self_condition else 1)
self.input_block = Conv1d(input_channels, dim, 3, concat_ones_to_input = True)
self.output_block = nn.Sequential(
Conv1d(dim, channels, 3),
Gain()
)
emb_dim = dim * 4
self.to_time_emb = nn.Sequential(
MPFourierEmbedding(fourier_dim),
Linear(fourier_dim, emb_dim)
)
self.needs_class_labels = exists(num_classes)
self.num_classes = num_classes
if self.needs_class_labels:
self.to_class_emb = Linear(num_classes, 4 * dim)
self.add_class_emb = MPAdd(t = mp_add_emb_t)
self.emb_activation = MPSiLU()
self.num_downsamples = num_downsamples
attn_res = set(cast_tuple(attn_res))
block_kwargs = dict(
dropout = dropout,
emb_dim = emb_dim,
attn_dim_head = attn_dim_head,
attn_res_mp_add_t = attn_res_mp_add_t,
attn_flash = attn_flash
)
self.downs = ModuleList([])
self.ups = ModuleList([])
curr_dim = dim
curr_res = seq_len
self.skip_mp_cat = MPCat(t = mp_cat_t, dim = 1)
prepend(self.ups, Decoder(dim * 2, dim, **block_kwargs))
assert num_blocks_per_stage >= 1
for _ in range(num_blocks_per_stage):
enc = Encoder(curr_dim, curr_dim, **block_kwargs)
dec = Decoder(curr_dim * 2, curr_dim, **block_kwargs)
append(self.downs, enc)
prepend(self.ups, dec)
for _ in range(self.num_downsamples):
dim_out = min(dim_max, curr_dim * 2)
upsample = Decoder(dim_out, curr_dim, has_attn = curr_res in attn_res, upsample = True, **block_kwargs)
curr_res //= 2
has_attn = curr_res in attn_res
downsample = Encoder(curr_dim, dim_out, downsample = True, has_attn = has_attn, **block_kwargs)
append(self.downs, downsample)
prepend(self.ups, upsample)
prepend(self.ups, Decoder(dim_out * 2, dim_out, has_attn = has_attn, **block_kwargs))
for _ in range(num_blocks_per_stage):
enc = Encoder(dim_out, dim_out, has_attn = has_attn, **block_kwargs)
dec = Decoder(dim_out * 2, dim_out, has_attn = has_attn, **block_kwargs)
append(self.downs, enc)
prepend(self.ups, dec)
curr_dim = dim_out
mid_has_attn = curr_res in attn_res
self.mids = ModuleList([
Decoder(curr_dim, curr_dim, has_attn = mid_has_attn, **block_kwargs),
Decoder(curr_dim, curr_dim, has_attn = mid_has_attn, **block_kwargs),
])
self.out_dim = channels
@property
def downsample_factor(self):
return 2 ** self.num_downsamples
def forward(
self,
x,
time,
self_cond = None,
class_labels = None
):
assert x.shape[1:] == (self.channels, self.seq_len)
if self.self_condition:
self_cond = default(self_cond, lambda: torch.zeros_like(x))
x = torch.cat((self_cond, x), dim = 1)
else:
assert not exists(self_cond)
time_emb = self.to_time_emb(time)
assert xnor(exists(class_labels), self.needs_class_labels)
if self.needs_class_labels:
if class_labels.dtype in (torch.int, torch.long):
class_labels = F.one_hot(class_labels, self.num_classes)
assert class_labels.shape[-1] == self.num_classes
class_labels = class_labels.float() * sqrt(self.num_classes)
class_emb = self.to_class_emb(class_labels)
time_emb = self.add_class_emb(time_emb, class_emb)
emb = self.emb_activation(time_emb)
skips = []
x = self.input_block(x)
skips.append(x)
for encoder in self.downs:
x = encoder(x, emb = emb)
skips.append(x)
for decoder in self.mids:
x = decoder(x, emb = emb)
for decoder in self.ups:
if decoder.needs_skip:
skip = skips.pop()
x = self.skip_mp_cat(x, skip)
x = decoder(x, emb = emb)
return self.output_block(x)
class MPFeedForward(Module):
def __init__(
self,
*,
dim,
mult = 4,
mp_add_t = 0.3
):
super().__init__()
dim_inner = int(dim * mult)
        self.net = nn.Sequential(
            PixelNorm(dim = 1),
            Conv1d(dim, dim_inner, 1),
            MPSiLU(),
            Conv1d(dim_inner, dim, 1)
        )
self.mp_add = MPAdd(t = mp_add_t)
def forward(self, x):
res = x
out = self.net(x)
return self.mp_add(out, res)
class MPImageTransformer(Module):
def __init__(
self,
*,
dim,
depth,
dim_head = 64,
heads = 8,
num_mem_kv = 4,
ff_mult = 4,
attn_flash = False,
residual_mp_add_t = 0.3
):
super().__init__()
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim = dim, heads = heads, dim_head = dim_head, num_mem_kv = num_mem_kv, flash = attn_flash, mp_add_t = residual_mp_add_t),
MPFeedForward(dim = dim, mult = ff_mult, mp_add_t = residual_mp_add_t)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x)
x = ff(x)
return x
if __name__ == '__main__':
unet = KarrasUnet1D(
seq_len = 64,
dim = 192,
dim_max = 768,
num_classes = 1000,
)
images = torch.randn(2, 4, 64)
denoised_images = unet(
images,
time = torch.ones(2,),
class_labels = torch.randint(0, 1000, (2,))
)
assert denoised_images.shape == images.shape
.\lucidrains\denoising-diffusion-pytorch\denoising_diffusion_pytorch\karras_unet_3d.py
"""
the magnitude-preserving unet proposed in https://arxiv.org/abs/2312.02696 by Karras et al.
"""
import math
from math import sqrt, ceil
from functools import partial
from typing import Optional, Union, Tuple
import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F
from einops import rearrange, repeat, pack, unpack
from denoising_diffusion_pytorch.attend import Attend
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
def xnor(x, y):
return not (x ^ y)
def append(arr, el):
arr.append(el)
def prepend(arr, el):
arr.insert(0, el)
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
def cast_tuple(t, length = 1):
if isinstance(t, tuple):
return t
return ((t,) * length)
def divisible_by(numer, denom):
return (numer % denom) == 0
def l2norm(t, dim = -1, eps = 1e-12):
return F.normalize(t, dim = dim, eps = eps)
class MPSiLU(Module):
def forward(self, x):
return F.silu(x) / 0.596
class Gain(Module):
def __init__(self):
super().__init__()
self.gain = nn.Parameter(torch.tensor(0.))
def forward(self, x):
return x * self.gain
class MPCat(Module):
def __init__(self, t = 0.5, dim = -1):
super().__init__()
self.t = t
self.dim = dim
def forward(self, a, b):
dim, t = self.dim, self.t
Na, Nb = a.shape[dim], b.shape[dim]
C = sqrt((Na + Nb) / ((1. - t) ** 2 + t ** 2))
a = a * (1. - t) / sqrt(Na)
b = b * t / sqrt(Nb)
return C * torch.cat((a, b), dim = dim)
class MPAdd(Module):
def __init__(self, t):
super().__init__()
self.t = t
def forward(self, x, res):
a, b, t = x, res, self.t
num = a * (1. - t) + b * t
den = sqrt((1 - t) ** 2 + t ** 2)
return num / den
class PixelNorm(Module):
def __init__(self, dim, eps = 1e-4):
super().__init__()
self.dim = dim
self.eps = eps
def forward(self, x):
dim = self.dim
return l2norm(x, dim = dim, eps = self.eps) * sqrt(x.shape[dim])
def normalize_weight(weight, eps = 1e-4):
weight, ps = pack_one(weight, 'o *')
normed_weight = l2norm(weight, eps = eps)
normed_weight = normed_weight * sqrt(weight.numel() / weight.shape[0])
return unpack_one(normed_weight, ps, 'o *')
class Conv3d(Module):
def __init__(
self,
dim_in,
dim_out,
kernel_size,
eps = 1e-4,
concat_ones_to_input = False
):
super().__init__()
weight = torch.randn(dim_out, dim_in + int(concat_ones_to_input), kernel_size, kernel_size, kernel_size)
self.weight = nn.Parameter(weight)
self.eps = eps
self.fan_in = dim_in * kernel_size ** 3
self.concat_ones_to_input = concat_ones_to_input
def forward(self, x):
if self.training:
with torch.no_grad():
normed_weight = normalize_weight(self.weight, eps = self.eps)
self.weight.copy_(normed_weight)
weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)
if self.concat_ones_to_input:
x = F.pad(x, (0, 0, 0, 0, 0, 0, 1, 0), value = 1.)
return F.conv3d(x, weight, padding='same')
class Linear(Module):
def __init__(self, dim_in, dim_out, eps = 1e-4):
super().__init__()
weight = torch.randn(dim_out, dim_in)
self.weight = nn.Parameter(weight)
self.eps = eps
self.fan_in = dim_in
def forward(self, x):
if self.training:
with torch.no_grad():
normed_weight = normalize_weight(self.weight, eps = self.eps)
self.weight.copy_(normed_weight)
weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)
return F.linear(x, weight)
class MPFourierEmbedding(Module):
def __init__(self, dim):
super().__init__()
assert divisible_by(dim, 2)
half_dim = dim // 2
self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = False)
def forward(self, x):
x = rearrange(x, 'b -> b 1')
freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
return torch.cat((freqs.sin(), freqs.cos()), dim = -1) * sqrt(2)
class Encoder(Module):
def __init__(
self,
dim,
dim_out = None,
*,
emb_dim = None,
dropout = 0.1,
mp_add_t = 0.3,
has_attn = False,
attn_dim_head = 64,
attn_res_mp_add_t = 0.3,
attn_flash = False,
factorize_space_time_attn = False,
downsample = False,
downsample_config: Tuple[bool, bool, bool] = (True, True, True)
):
super().__init__()
dim_out = default(dim_out, dim)
self.downsample = downsample
self.downsample_config = downsample_config
self.downsample_conv = None
curr_dim = dim
if downsample:
self.downsample_conv = Conv3d(curr_dim, dim_out, 1)
curr_dim = dim_out
self.pixel_norm = PixelNorm(dim = 1)
self.to_emb = None
if exists(emb_dim):
self.to_emb = nn.Sequential(
Linear(emb_dim, dim_out),
Gain()
)
self.block1 = nn.Sequential(
MPSiLU(),
Conv3d(curr_dim, dim_out, 3)
)
self.block2 = nn.Sequential(
MPSiLU(),
nn.Dropout(dropout),
Conv3d(dim_out, dim_out, 3)
)
self.res_mp_add = MPAdd(t = mp_add_t)
self.attn = None
self.factorized_attn = factorize_space_time_attn
if has_attn:
attn_kwargs = dict(
dim = dim_out,
heads = max(ceil(dim_out / attn_dim_head), 2),
dim_head = attn_dim_head,
mp_add_t = attn_res_mp_add_t,
flash = attn_flash
)
if factorize_space_time_attn:
self.attn = nn.ModuleList([
Attention(**attn_kwargs, only_space = True),
Attention(**attn_kwargs, only_time = True),
])
else:
self.attn = Attention(**attn_kwargs)
def forward(
self,
x,
emb = None
):
if self.downsample:
t, h, w = x.shape[-3:]
resize_factors = tuple((2 if downsample else 1) for downsample in self.downsample_config)
interpolate_shape = tuple(shape // factor for shape, factor in zip((t, h, w), resize_factors))
x = F.interpolate(x, interpolate_shape, mode='trilinear')
x = self.downsample_conv(x)
x = self.pixel_norm(x)
res = x.clone()
x = self.block1(x)
if exists(emb):
scale = self.to_emb(emb) + 1
x = x * rearrange(scale, 'b c -> b c 1 1 1')
x = self.block2(x)
x = self.res_mp_add(x, res)
if exists(self.attn):
if self.factorized_attn:
attn_space, attn_time = self.attn
x = attn_space(x)
x = attn_time(x)
else:
x = self.attn(x)
return x
class Decoder(Module):
def __init__(
self,
dim,
dim_out = None,
*,
emb_dim = None,
dropout = 0.1,
mp_add_t = 0.3,
has_attn = False,
attn_dim_head = 64,
attn_res_mp_add_t = 0.3,
attn_flash = False,
factorize_space_time_attn = False,
upsample = False,
upsample_config: Tuple[bool, bool, bool] = (True, True, True)
):
super().__init__()
dim_out = default(dim_out, dim)
self.upsample = upsample
self.upsample_config = upsample_config
self.needs_skip = not upsample
self.to_emb = None
if exists(emb_dim):
self.to_emb = nn.Sequential(
Linear(emb_dim, dim_out),
Gain()
)
self.block1 = nn.Sequential(
MPSiLU(),
Conv3d(dim, dim_out, 3)
)
self.block2 = nn.Sequential(
MPSiLU(),
nn.Dropout(dropout),
Conv3d(dim_out, dim_out, 3)
)
self.res_conv = Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
self.res_mp_add = MPAdd(t = mp_add_t)
self.attn = None
self.factorized_attn = factorize_space_time_attn
if has_attn:
attn_kwargs = dict(
dim = dim_out,
heads = max(ceil(dim_out / attn_dim_head), 2),
dim_head = attn_dim_head,
mp_add_t = attn_res_mp_add_t,
flash = attn_flash
)
if factorize_space_time_attn:
self.attn = nn.ModuleList([
Attention(**attn_kwargs, only_space = True),
Attention(**attn_kwargs, only_time = True),
])
else:
self.attn = Attention(**attn_kwargs)
def forward(
self,
x,
emb = None
):
if self.upsample:
t, h, w = x.shape[-3:]
resize_factors = tuple((2 if upsample else 1) for upsample in self.upsample_config)
interpolate_shape = tuple(shape * factor for shape, factor in zip((t, h, w), resize_factors))
x = F.interpolate(x, interpolate_shape, mode = 'trilinear')
res = self.res_conv(x)
x = self.block1(x)
if exists(emb):
scale = self.to_emb(emb) + 1
x = x * rearrange(scale, 'b c -> b c 1 1 1')
x = self.block2(x)
x = self.res_mp_add(x, res)
if exists(self.attn):
if self.factorized_attn:
attn_space, attn_time = self.attn
x = attn_space(x)
x = attn_time(x)
else:
x = self.attn(x)
return x
class Attention(Module):
def __init__(
self,
dim,
heads = 4,
dim_head = 64,
num_mem_kv = 4,
flash = False,
mp_add_t = 0.3,
only_space = False,
only_time = False
):
super().__init__()
assert (int(only_space) + int(only_time)) <= 1
self.heads = heads
hidden_dim = dim_head * heads
self.pixel_norm = PixelNorm(dim = -1)
self.attend = Attend(flash = flash)
self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head))
self.to_qkv = Conv3d(dim, hidden_dim * 3, 1)
self.to_out = Conv3d(hidden_dim, dim, 1)
self.mp_add = MPAdd(t = mp_add_t)
self.only_space = only_space
self.only_time = only_time
def forward(self, x):
res, orig_shape = x, x.shape
b, c, t, h, w = orig_shape
qkv = self.to_qkv(x)
if self.only_space:
qkv = rearrange(qkv, 'b c t x y -> (b t) c x y')
elif self.only_time:
qkv = rearrange(qkv, 'b c t x y -> (b x y) c t')
qkv = qkv.chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) ... -> b h (...) c', h = self.heads), qkv)
mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = k.shape[0]), self.mem_kv)
k, v = map(partial(torch.cat, dim = -2), ((mk, k), (mv, v)))
q, k, v = map(self.pixel_norm, (q, k, v))
out = self.attend(q, k, v)
out = rearrange(out, 'b h n d -> b (h d) n')
if self.only_space:
out = rearrange(out, '(b t) c n -> b c (t n)', t = t)
elif self.only_time:
out = rearrange(out, '(b x y) c n -> b c (n x y)', x = h, y = w)
out = out.reshape(orig_shape)
out = self.to_out(out)
return self.mp_add(out, res)
class KarrasUnet3D(Module):
"""
    going by figure 21. config G
"""
def __init__(
self,
*,
image_size,
frames,
dim = 192,
dim_max = 768,
num_classes = None,
channels = 4,
num_downsamples = 3,
num_blocks_per_stage: Union[int, Tuple[int, ...]] = 4,
downsample_types: Optional[Tuple[str, ...]] = None,
attn_res = (16, 8),
fourier_dim = 16,
attn_dim_head = 64,
attn_flash = False,
mp_cat_t = 0.5,
mp_add_emb_t = 0.5,
attn_res_mp_add_t = 0.3,
resnet_mp_add_t = 0.3,
dropout = 0.1,
self_condition = False,
        factorize_space_time_attn = False
    ):
        super().__init__()
        # NOTE: the remainder of the constructor is truncated in this excerpt.
        # in the full source it mirrors KarrasUnet1D above - input / output
        # blocks, time and class embeddings, and the downs / ups / mids stages -
        # but with Conv3d and per-stage `downsample_types` choosing whether to
        # halve the spatial ('image') or temporal ('frame') dimension
@property
def downsample_factor(self):
return 2 ** self.num_downsamples
def forward(
self,
x,
time,
self_cond = None,
class_labels = None
):
assert x.shape[1:] == (self.channels, self.frames, self.image_size, self.image_size)
if self.self_condition:
self_cond = default(self_cond, lambda: torch.zeros_like(x))
x = torch.cat((self_cond, x), dim = 1)
else:
assert not exists(self_cond)
time_emb = self.to_time_emb(time)
assert xnor(exists(class_labels), self.needs_class_labels)
if self.needs_class_labels:
if class_labels.dtype in (torch.int, torch.long):
class_labels = F.one_hot(class_labels, self.num_classes)
assert class_labels.shape[-1] == self.num_classes
class_labels = class_labels.float() * sqrt(self.num_classes)
class_emb = self.to_class_emb(class_labels)
time_emb = self.add_class_emb(time_emb, class_emb)
emb = self.emb_activation(time_emb)
skips = []
x = self.input_block(x)
skips.append(x)
for encoder in self.downs:
x = encoder(x, emb = emb)
skips.append(x)
for decoder in self.mids:
x = decoder(x, emb = emb)
for decoder in self.ups:
if decoder.needs_skip:
skip = skips.pop()
x = self.skip_mp_cat(x, skip)
x = decoder(x, emb = emb)
return self.output_block(x)
class MPFeedForward(Module):
def __init__(
self,
*,
dim,
mult = 4,
mp_add_t = 0.3
):
super().__init__()
dim_inner = int(dim * mult)
self.net = nn.Sequential(
PixelNorm(dim = 1),
Conv3d(dim, dim_inner, 1),
MPSiLU(),
Conv3d(dim_inner, dim, 1)
)
self.mp_add = MPAdd(t = mp_add_t)
def forward(self, x):
res = x
out = self.net(x)
return self.mp_add(out, res)
class MPImageTransformer(Module):
def __init__(
self,
*,
dim,
depth,
dim_head = 64,
heads = 8,
num_mem_kv = 4,
ff_mult = 4,
attn_flash = False,
residual_mp_add_t = 0.3
):
super().__init__()
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim = dim, heads = heads, dim_head = dim_head, num_mem_kv = num_mem_kv, flash = attn_flash, mp_add_t = residual_mp_add_t),
MPFeedForward(dim = dim, mult = ff_mult, mp_add_t = residual_mp_add_t)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x)
x = ff(x)
return x
if __name__ == '__main__':
unet = KarrasUnet3D(
frames = 32,
image_size = 64,
dim = 8,
dim_max = 768,
num_downsamples = 6,
num_blocks_per_stage = (4, 3, 2, 2, 2, 2),
downsample_types = (
'image',
'frame',
'image',
'frame',
'image',
'frame',
),
attn_dim_head = 8,
num_classes = 1000,
factorize_space_time_attn = True
)
video = torch.randn(2, 4, 32, 64, 64)
denoised_video = unet(
video,
time = torch.ones(2,),
class_labels = torch.randint(0, 1000, (2,))
)
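    # mirroring the shape check from the 1d demo above
    assert denoised_video.shape == video.shape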
.\lucidrains\denoising-diffusion-pytorch\denoising_diffusion_pytorch\learned_gaussian_diffusion.py
import torch
from collections import namedtuple
from math import pi, sqrt, log as ln
from inspect import isfunction
from functools import partial
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange
from denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion, extract, unnormalize_to_zero_to_one, identity
NAT = 1. / ln(2)
ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start', 'pred_variance'])
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def log(t, eps = 1e-15):
return torch.log(t.clamp(min = eps))
def meanflat(x):
    return x.mean(dim = tuple(range(1, len(x.shape))))
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
KL divergence between normal distributions parameterized by mean and log-variance.
"""
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))
def approx_standard_normal_cdf(x):
    # tanh-based approximation of the standard normal cdf
    return 0.5 * (1.0 + torch.tanh(sqrt(2.0 / pi) * (x + 0.044715 * (x ** 3))))
def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
assert x.shape == means.shape == log_scales.shape
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1. / 255.)
cdf_plus = approx_standard_normal_cdf(plus_in)
min_in = inv_stdv * (centered_x - 1. / 255.)
cdf_min = approx_standard_normal_cdf(min_in)
log_cdf_plus = log(cdf_plus)
log_one_minus_cdf_min = log(1. - cdf_min)
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(x < -thres,
log_cdf_plus,
torch.where(x > thres,
log_one_minus_cdf_min,
log(cdf_delta)))
return log_probs
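# usage sketch (illustrative helper): per-element log probability that a
# gaussian with the given mean / log scale lands in the 2/255-wide
# discretization bucket around the observed pixel value in [-1, 1]

def _demo_discretized_ll():
    x = torch.rand(2, 3, 8, 8) * 2 - 1
    ll = discretized_gaussian_log_likelihood(x, means = x, log_scales = torch.full_like(x, -4.))
    print(ll.shape, ll.mean())  # (2, 3, 8, 8); closer to 0 the better the fit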
class LearnedGaussianDiffusion(GaussianDiffusion):
def __init__(
self,
model,
vb_loss_weight = 0.001,
*args,
**kwargs
):
super().__init__(model, *args, **kwargs)
assert model.out_dim == (model.channels * 2), 'dimension out of unet must be twice the number of channels for learned variance - you can also set the `learned_variance` keyword argument on the Unet to be `True`'
assert not model.self_condition, 'not supported yet'
self.vb_loss_weight = vb_loss_weight
def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):
model_output = self.model(x, t)
model_output, pred_variance = model_output.chunk(2, dim = 1)
maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
if self.objective == 'pred_noise':
pred_noise = model_output
x_start = self.predict_start_from_noise(x, t, model_output)
elif self.objective == 'pred_x0':
pred_noise = self.predict_noise_from_start(x, t, model_output)
x_start = model_output
x_start = maybe_clip(x_start)
return ModelPrediction(pred_noise, x_start, pred_variance)
def p_mean_variance(self, *, x, t, clip_denoised, model_output = None, **kwargs):
model_output = default(model_output, lambda: self.model(x, t))
pred_noise, var_interp_frac_unnormalized = model_output.chunk(2, dim = 1)
min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
max_log = extract(torch.log(self.betas), t, x.shape)
var_interp_frac = unnormalize_to_zero_to_one(var_interp_frac_unnormalized)
model_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
model_variance = model_log_variance.exp()
x_start = self.predict_start_from_noise(x, t, pred_noise)
if clip_denoised:
x_start.clamp_(-1., 1.)
model_mean, _, _ = self.q_posterior(x_start, x, t)
return model_mean, model_variance, model_log_variance, x_start
def p_losses(self, x_start, t, noise = None, clip_denoised = False):
noise = default(noise, lambda: torch.randn_like(x_start))
x_t = self.q_sample(x_start = x_start, t = t, noise = noise)
model_output = self.model(x_t, t)
true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_t, t = t)
model_mean, _, model_log_variance, _ = self.p_mean_variance(x = x_t, t = t, clip_denoised = clip_denoised, model_output = model_output)
detached_model_mean = model_mean.detach()
kl = normal_kl(true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance)
kl = meanflat(kl) * NAT
decoder_nll = -discretized_gaussian_log_likelihood(x_start, means = detached_model_mean, log_scales = 0.5 * model_log_variance)
decoder_nll = meanflat(decoder_nll) * NAT
vb_losses = torch.where(t == 0, decoder_nll, kl)
pred_noise, _ = model_output.chunk(2, dim = 1)
simple_losses = F.mse_loss(pred_noise, noise)
return simple_losses + vb_losses.mean() * self.vb_loss_weight
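# a minimal usage sketch (illustrative, assuming the repo's Unet): the wrapped
# unet must output twice the channel count, which the `learned_variance`
# keyword on Unet provides, per the assert in __init__ above

def _demo_learned_gaussian():
    from denoising_diffusion_pytorch import Unet
    model = Unet(dim = 64, learned_variance = True)
    diffusion = LearnedGaussianDiffusion(model, image_size = 128, timesteps = 1000)
    loss = diffusion(torch.rand(1, 3, 128, 128))
    loss.backward()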
.\lucidrains\denoising-diffusion-pytorch\denoising_diffusion_pytorch\simple_diffusion.py
import math
from functools import partial, wraps
import torch
from torch import sqrt
from torch import nn, einsum
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.special import expm1
from tqdm import tqdm
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
def exists(val):
return val is not None
def identity(t):
return t
def is_lambda(f):
return callable(f) and f.__name__ == "<lambda>"
def default(val, d):
if exists(val):
return val
return d() if is_lambda(d) else d
def cast_tuple(t, l = 1):
return ((t,) * l) if not isinstance(t, tuple) else t
def append_dims(t, dims):
shape = t.shape
return t.reshape(*shape, *((1,) * dims))
def l2norm(t):
return F.normalize(t, dim = -1)
class Upsample(nn.Module):
def __init__(
self,
dim,
dim_out = None,
factor = 2
):
super().__init__()
self.factor = factor
self.factor_squared = factor ** 2
dim_out = default(dim_out, dim)
conv = nn.Conv2d(dim, dim_out * self.factor_squared, 1)
self.net = nn.Sequential(
conv,
nn.SiLU(),
nn.PixelShuffle(factor)
)
self.init_conv_(conv)
def init_conv_(self, conv):
o, i, h, w = conv.weight.shape
conv_weight = torch.empty(o // self.factor_squared, i, h, w)
nn.init.kaiming_uniform_(conv_weight)
conv_weight = repeat(conv_weight, 'o ... -> (o r) ...', r = self.factor_squared)
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def forward(self, x):
return self.net(x)
def Downsample(
dim,
dim_out = None,
factor = 2
):
return nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = factor, p2 = factor),
nn.Conv2d(dim * (factor ** 2), default(dim_out, dim), 1)
)
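# shape sketch (illustrative helper): Downsample trades spatial extent for
# channels via a space-to-depth rearrange, and Upsample inverts it with a
# pixel shuffle whose conv is initialized repeated-kaiming to avoid
# checkerboard artifacts

def _demo_updown():
    x = torch.randn(1, 32, 16, 16)
    down, up = Downsample(32, 64), Upsample(64, 32)
    y = down(x)
    print(y.shape, up(y).shape)  # (1, 64, 8, 8) then back to (1, 32, 16, 16)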
class RMSNorm(nn.Module):
def __init__(self, dim, scale = True, normalize_dim = 2):
super().__init__()
self.g = nn.Parameter(torch.ones(dim)) if scale else 1
self.scale = scale
self.normalize_dim = normalize_dim
def forward(self, x):
normalize_dim = self.normalize_dim
scale = append_dims(self.g, x.ndim - self.normalize_dim - 1) if self.scale else 1
return F.normalize(x, dim = normalize_dim) * scale * (x.shape[normalize_dim] ** 0.5)
class LearnedSinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
assert (dim % 2) == 0
half_dim = dim // 2
self.weights = nn.Parameter(torch.randn(half_dim))
def forward(self, x):
x = rearrange(x, 'b -> b 1')
freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
fouriered = torch.cat((x, fouriered), dim = -1)
return fouriered
class Block(nn.Module):
def __init__(self, dim, dim_out, groups = 8):
super().__init__()
self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
def forward(self, x, scale_shift = None):
x = self.proj(x)
x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.act(x)
return x
class ResnetBlock(nn.Module):
def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
super().__init__()
self.mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_emb_dim, dim_out * 2)
) if exists(time_emb_dim) else None
self.block1 = Block(dim, dim_out, groups = groups)
self.block2 = Block(dim_out, dim_out, groups = groups)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, time_emb = None):
scale_shift = None
if exists(self.mlp) and exists(time_emb):
time_emb = self.mlp(time_emb)
time_emb = rearrange(time_emb, 'b c -> b c 1 1')
scale_shift = time_emb.chunk(2, dim = 1)
h = self.block1(x, scale_shift = scale_shift)
h = self.block2(h)
return h + self.res_conv(x)
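# shape sketch (illustrative helper): the time embedding conditions the block
# through a per-channel scale and shift applied after the first group norm

def _demo_resnet_block():
    block = ResnetBlock(32, 64, time_emb_dim = 128)
    x, t_emb = torch.randn(2, 32, 16, 16), torch.randn(2, 128)
    print(block(x, t_emb).shape)  # (2, 64, 16, 16)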
class LinearAttention(nn.Module):
def __init__(self, dim, heads = 4, dim_head = 32):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.norm = RMSNorm(dim, normalize_dim = 1)
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(hidden_dim, dim, 1),
RMSNorm(dim, normalize_dim = 1)
)
def forward(self, x):
residual = x
b, c, h, w = x.shape
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
q = q.softmax(dim = -2)
k = k.softmax(dim = -1)
q = q * self.scale
context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
return self.to_out(out) + residual
class Attention(nn.Module):
def __init__(self, dim, heads = 4, dim_head = 32, scale = 8, dropout = 0.):
super().__init__()
self.scale = scale
self.heads = heads
hidden_dim = dim_head * heads
self.norm = RMSNorm(dim)
self.attn_dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias = False)
self.q_scale = nn.Parameter(torch.ones(dim_head))
self.k_scale = nn.Parameter(torch.ones(dim_head))
self.to_out = nn.Linear(hidden_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
q, k = map(l2norm, (q, k))
q = q * self.q_scale
k = k * self.k_scale
sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = sim.softmax(dim = -1)
attn = self.attn_dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class FeedForward(nn.Module):
def __init__(
self,
dim,
cond_dim,
mult = 4,
dropout = 0.
):
super().__init__()
self.norm = RMSNorm(dim, scale = False)
dim_hidden = dim * mult
self.to_scale_shift = nn.Sequential(
nn.SiLU(),
nn.Linear(cond_dim, dim_hidden * 2),
Rearrange('b d -> b 1 d')
)
to_scale_shift_linear = self.to_scale_shift[-2]
nn.init.zeros_(to_scale_shift_linear.weight)
nn.init.zeros_(to_scale_shift_linear.bias)
self.proj_in = nn.Sequential(
nn.Linear(dim, dim_hidden, bias = False),
nn.SiLU()
)
self.proj_out = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(dim_hidden, dim, bias = False)
)
def forward(self, x, t):
x = self.norm(x)
x = self.proj_in(x)
scale, shift = self.to_scale_shift(t).chunk(2, dim = -1)
x = x * (scale + 1) + shift
return self.proj_out(x)
class Transformer(nn.Module):
def __init__(
self,
dim,
time_cond_dim,
depth,
dim_head = 32,
heads = 4,
ff_mult = 4,
dropout = 0.,
):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
FeedForward(dim = dim, mult = ff_mult, cond_dim = time_cond_dim, dropout = dropout)
]))
def forward(self, x, t):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x, t) + x
return x
class UViT(nn.Module):
def __init__(
self,
dim,
init_dim = None,
out_dim = None,
dim_mults = (1, 2, 4, 8),
downsample_factor = 2,
channels = 3,
vit_depth = 6,
vit_dropout = 0.2,
attn_dim_head = 32,
attn_heads = 4,
ff_mult = 4,
resnet_block_groups = 8,
learned_sinusoidal_dim = 16,
init_img_transform: callable = None,
final_img_itransform: callable = None,
patch_size = 1,
dual_patchnorm = False
):
super().__init__()
if exists(init_img_transform) and exists(final_img_itransform):
            init_shape = torch.Size((1, 1, 32, 32))
mock_tensor = torch.randn(init_shape)
assert final_img_itransform(init_img_transform(mock_tensor)).shape == init_shape
self.init_img_transform = default(init_img_transform, identity)
self.final_img_itransform = default(final_img_itransform, identity)
input_channels = channels
init_dim = default(init_dim, dim)
self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)
self.unpatchify = identity
input_channels = channels * (patch_size ** 2)
needs_patch = patch_size > 1
if needs_patch:
if not dual_patchnorm:
self.init_conv = nn.Conv2d(channels, init_dim, patch_size, stride = patch_size)
else:
self.init_conv = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b h w (c p1 p2)', p1 = patch_size, p2 = patch_size),
nn.LayerNorm(input_channels),
nn.Linear(input_channels, init_dim),
nn.LayerNorm(init_dim),
Rearrange('b h w c -> b c h w')
)
self.unpatchify = nn.ConvTranspose2d(input_channels, channels, patch_size, stride = patch_size)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
resnet_block = partial(ResnetBlock, groups = resnet_block_groups)
time_dim = dim * 4
sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim)
fourier_dim = learned_sinusoidal_dim + 1
self.time_mlp = nn.Sequential(
sinu_pos_emb,
nn.Linear(fourier_dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim)
)
        downsample_factor = cast_tuple(downsample_factor, len(dim_mults))
assert len(downsample_factor) == len(dim_mults)
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
for ind, ((dim_in, dim_out), factor) in enumerate(zip(in_out, downsample_factor)):
is_last = ind >= (num_resolutions - 1)
self.downs.append(nn.ModuleList([
resnet_block(dim_in, dim_in, time_emb_dim = time_dim),
resnet_block(dim_in, dim_in, time_emb_dim = time_dim),
LinearAttention(dim_in),
Downsample(dim_in, dim_out, factor = factor)
]))
mid_dim = dims[-1]
self.vit = Transformer(
dim = mid_dim,
time_cond_dim = time_dim,
depth = vit_depth,
dim_head = attn_dim_head,
heads = attn_heads,
ff_mult = ff_mult,
dropout = vit_dropout
)
for ind, ((dim_in, dim_out), factor) in enumerate(zip(reversed(in_out), reversed(downsample_factor))):
is_last = ind == (len(in_out) - 1)
self.ups.append(nn.ModuleList([
Upsample(dim_out, dim_in, factor = factor),
resnet_block(dim_in * 2, dim_in, time_emb_dim = time_dim),
resnet_block(dim_in * 2, dim_in, time_emb_dim = time_dim),
LinearAttention(dim_in),
]))
default_out_dim = input_channels
self.out_dim = default(out_dim, default_out_dim)
self.final_res_block = resnet_block(dim * 2, dim, time_emb_dim = time_dim)
self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
def forward(self, x, time):
x = self.init_img_transform(x)
x = self.init_conv(x)
r = x.clone()
t = self.time_mlp(time)
h = []
for block1, block2, attn, downsample in self.downs:
x = block1(x, t)
h.append(x)
x = block2(x, t)
x = attn(x)
h.append(x)
x = downsample(x)
x = rearrange(x, 'b c h w -> b h w c')
x, ps = pack([x], 'b * c')
x = self.vit(x, t)
x, = unpack(x, ps, 'b * c')
x = rearrange(x, 'b h w c -> b c h w')
for upsample, block1, block2, attn in self.ups:
x = upsample(x)
x = torch.cat((x, h.pop()), dim = 1)
x = block1(x, t)
x = torch.cat((x, h.pop()), dim = 1)
x = block2(x, t)
x = attn(x)
x = torch.cat((x, r), dim = 1)
x = self.final_res_block(x, t)
x = self.final_conv(x)
x = self.unpatchify(x)
return self.final_img_itransform(x)
def normalize_to_neg_one_to_one(img):
return img * 2 - 1
def unnormalize_to_zero_to_one(t):
return (t + 1) * 0.5
def right_pad_dims_to(x, t):
padding_dims = x.ndim - t.ndim
if padding_dims <= 0:
return t
return t.view(*t.shape, *((1,) * padding_dims))
def log(t, eps = 1e-20):
return torch.log(t.clamp(min = eps))
def logsnr_schedule_cosine(t, logsnr_min = -15, logsnr_max = 15):
t_min = math.atan(math.exp(-0.5 * logsnr_max))
t_max = math.atan(math.exp(-0.5 * logsnr_min))
return -2 * log(torch.tan(t_min + t * (t_max - t_min)))
def logsnr_schedule_shifted(fn, image_d, noise_d):
shift = 2 * math.log(noise_d / image_d)
@wraps(fn)
def inner(*args, **kwargs):
nonlocal shift
return fn(*args, **kwargs) + shift
return inner
def logsnr_schedule_interpolated(fn, image_d, noise_d_low, noise_d_high):
logsnr_low_fn = logsnr_schedule_shifted(fn, image_d, noise_d_low)
logsnr_high_fn = logsnr_schedule_shifted(fn, image_d, noise_d_high)
@wraps(fn)
def inner(t, *args, **kwargs):
nonlocal logsnr_low_fn
nonlocal logsnr_high_fn
return t * logsnr_low_fn(t, *args, **kwargs) + (1 - t) * logsnr_high_fn(t, *args, **kwargs)
return inner
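# numeric sketch (illustrative helper): shifting adds the constant
# 2 * log(noise_d / image_d) to the log snr, and the interpolated variant
# blends the low- and high-noise shifted schedules by t

def _demo_logsnr_schedules():
    t = torch.linspace(0., 1., 5)
    base = logsnr_schedule_cosine(t)
    shifted = logsnr_schedule_shifted(logsnr_schedule_cosine, 128, 64)(t)
    print(shifted - base)  # constant 2 * log(64 / 128) ~ -1.386
    interp = logsnr_schedule_interpolated(logsnr_schedule_cosine, 128, 32, 64)(t)
    print(interp)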
class GaussianDiffusion(nn.Module):
def __init__(
self,
model: UViT,
*,
image_size,
channels = 3,
pred_objective = 'v',
noise_schedule = logsnr_schedule_cosine,
noise_d = None,
noise_d_low = None,
noise_d_high = None,
num_sample_steps = 500,
clip_sample_denoised = True,
min_snr_loss_weight = True,
min_snr_gamma = 5
):
super().__init__()
assert pred_objective in {'v', 'eps'}, 'whether to predict v-space (progressive distillation paper) or noise'
self.model = model
self.channels = channels
self.image_size = image_size
self.pred_objective = pred_objective
assert not all([*map(exists, (noise_d, noise_d_low, noise_d_high))]), 'you must either set noise_d for shifted schedule, or noise_d_low and noise_d_high for shifted and interpolated schedule'
self.log_snr = noise_schedule
if exists(noise_d):
self.log_snr = logsnr_schedule_shifted(self.log_snr, image_size, noise_d)
if exists(noise_d_low) or exists(noise_d_high):
assert exists(noise_d_low) and exists(noise_d_high), 'both noise_d_low and noise_d_high must be set'
self.log_snr = logsnr_schedule_interpolated(self.log_snr, image_size, noise_d_low, noise_d_high)
self.num_sample_steps = num_sample_steps
self.clip_sample_denoised = clip_sample_denoised
self.min_snr_loss_weight = min_snr_loss_weight
self.min_snr_gamma = min_snr_gamma
@property
def device(self):
return next(self.model.parameters()).device
def p_mean_variance(self, x, time, time_next):
log_snr = self.log_snr(time)
log_snr_next = self.log_snr(time_next)
c = -expm1(log_snr - log_snr_next)
squared_alpha, squared_alpha_next = log_snr.sigmoid(), log_snr_next.sigmoid()
squared_sigma, squared_sigma_next = (-log_snr).sigmoid(), (-log_snr_next).sigmoid()
alpha, sigma, alpha_next = map(sqrt, (squared_alpha, squared_sigma, squared_alpha_next))
batch_log_snr = repeat(log_snr, ' -> b', b = x.shape[0])
pred = self.model(x, batch_log_snr)
if self.pred_objective == 'v':
x_start = alpha * x - sigma * pred
elif self.pred_objective == 'eps':
x_start = (x - sigma * pred) / alpha
x_start.clamp_(-1., 1.)
model_mean = alpha_next * (x * (1 - c) / alpha + c * x_start)
posterior_variance = squared_sigma_next * c
return model_mean, posterior_variance
@torch.no_grad()
def p_sample(self, x, time, time_next):
batch, *_, device = *x.shape, x.device
model_mean, model_variance = self.p_mean_variance(x = x, time = time, time_next = time_next)
if time_next == 0:
return model_mean
noise = torch.randn_like(x)
return model_mean + sqrt(model_variance) * noise
@torch.no_grad()
def p_sample_loop(self, shape):
batch = shape[0]
img = torch.randn(shape, device = self.device)
steps = torch.linspace(1., 0., self.num_sample_steps + 1, device = self.device)
for i in tqdm(range(self.num_sample_steps), desc = 'sampling loop time step', total = self.num_sample_steps):
times = steps[i]
times_next = steps[i + 1]
img = self.p_sample(img, times, times_next)
img.clamp_(-1., 1.)
img = unnormalize_to_zero_to_one(img)
return img
@torch.no_grad()
def sample(self, batch_size = 16):
return self.p_sample_loop((batch_size, self.channels, self.image_size, self.image_size))
@autocast(enabled = False)
def q_sample(self, x_start, times, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
log_snr = self.log_snr(times)
log_snr_padded = right_pad_dims_to(x_start, log_snr)
alpha, sigma = sqrt(log_snr_padded.sigmoid()), sqrt((-log_snr_padded).sigmoid())
x_noised = x_start * alpha + noise * sigma
return x_noised, log_snr
def p_losses(self, x_start, times, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
x, log_snr = self.q_sample(x_start = x_start, times = times, noise = noise)
model_out = self.model(x, log_snr)
if self.pred_objective == 'v':
padded_log_snr = right_pad_dims_to(x, log_snr)
alpha, sigma = padded_log_snr.sigmoid().sqrt(), (-padded_log_snr).sigmoid().sqrt()
target = alpha * noise - sigma * x_start
elif self.pred_objective == 'eps':
target = noise
loss = F.mse_loss(model_out, target, reduction = 'none')
loss = reduce(loss, 'b ... -> b', 'mean')
snr = log_snr.exp()
maybe_clip_snr = snr.clone()
if self.min_snr_loss_weight:
maybe_clip_snr.clamp_(max = self.min_snr_gamma)
if self.pred_objective == 'v':
loss_weight = maybe_clip_snr / (snr + 1)
elif self.pred_objective == 'eps':
loss_weight = maybe_clip_snr / snr
return (loss * loss_weight).mean()
def forward(self, img, *args, **kwargs):
b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
img = normalize_to_neg_one_to_one(img)
times = torch.zeros((img.shape[0],), device = self.device).float().uniform_(0, 1)
return self.p_losses(img, times, *args, **kwargs)
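# a minimal end-to-end sketch (illustrative helper) for this file's UViT plus
# continuous-time GaussianDiffusion pair

def _demo_simple_diffusion():
    unet = UViT(dim = 64)
    diffusion = GaussianDiffusion(unet, image_size = 128, num_sample_steps = 10)
    loss = diffusion(torch.rand(2, 3, 128, 128))
    loss.backward()
    sampled = diffusion.sample(batch_size = 2)  # (2, 3, 128, 128) in [0, 1]
    print(sampled.shape)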