Lucidrains 系列项目源码解析(二十五)
.\lucidrains\denoising-diffusion-pytorch\denoising_diffusion_pytorch\elucidated_diffusion.py
from math import sqrt
from random import random

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, reduce

from tqdm import tqdm
def exists(val):
    """Return True when `val` is anything other than None."""
    return val is not None
def default(val, d):
    """Return `val` when present, otherwise the fallback `d` (invoked if callable)."""
    if val is not None:
        return val
    return d() if callable(d) else d
def log(t, eps = 1e-20):
    """Numerically stable log: clamp `t` to at least `eps` before taking the log."""
    return t.clamp(min = eps).log()
def normalize_to_neg_one_to_one(img):
    """Map image values from [0, 1] into the model's [-1, 1] range."""
    return img * 2 - 1
def unnormalize_to_zero_to_one(t):
    """Map values from the model's [-1, 1] range back to image range [0, 1]."""
    return (t + 1) * 0.5
class ElucidatedDiffusion(nn.Module):
    """Elucidated diffusion model (Karras et al., https://arxiv.org/abs/2206.00364).

    Wraps a denoising network `net` and implements the preconditioning,
    training loss, and stochastic / DPM-Solver++ samplers from the paper.
    `net` must expose `random_or_learned_sinusoidal_cond` and `self_condition`.
    """

    def __init__(
        self,
        net,
        *,
        image_size,
        channels = 3,
        num_sample_steps = 32,  # number of sampling steps
        sigma_min = 0.002,      # minimum noise level
        sigma_max = 80,         # maximum noise level
        sigma_data = 0.5,       # standard deviation of the data distribution
        rho = 7,                # controls the sampling schedule curvature
        P_mean = -1.2,          # mean of the log-normal noise distribution for training
        P_std = 1.2,            # std of the log-normal noise distribution for training
        S_churn = 80,           # stochastic sampling parameters (dataset dependent)
        S_tmin = 0.05,
        S_tmax = 50,
        S_noise = 1.003,
    ):
        super().__init__()
        # the network must embed noise levels with (random or learned) sinusoidal features
        assert net.random_or_learned_sinusoidal_cond
        self.self_condition = net.self_condition

        self.net = net

        # image dimensions
        self.channels = channels
        self.image_size = image_size

        # diffusion hyperparameters
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.rho = rho
        self.P_mean = P_mean
        self.P_std = P_std

        # sampling hyperparameters
        self.num_sample_steps = num_sample_steps
        self.S_churn = S_churn
        self.S_tmin = S_tmin
        self.S_tmax = S_tmax
        self.S_noise = S_noise

    @property
    def device(self):
        return next(self.net.parameters()).device

    # preconditioning coefficients (Table 1 of the paper)

    def c_skip(self, sigma):
        return (self.sigma_data ** 2) / (sigma ** 2 + self.sigma_data ** 2)

    def c_out(self, sigma):
        return sigma * self.sigma_data * (self.sigma_data ** 2 + sigma ** 2) ** -0.5

    def c_in(self, sigma):
        return 1 * (sigma ** 2 + self.sigma_data ** 2) ** -0.5

    def c_noise(self, sigma):
        return log(sigma) * 0.25

    def preconditioned_network_forward(self, noised_images, sigma, self_cond = None, clamp = False):
        """Run the network with input/output preconditioning; returns the denoised estimate."""
        batch, device = noised_images.shape[0], noised_images.device

        if isinstance(sigma, float):
            sigma = torch.full((batch,), sigma, device = device)

        padded_sigma = rearrange(sigma, 'b -> b 1 1 1')

        net_out = self.net(
            self.c_in(padded_sigma) * noised_images,
            self.c_noise(sigma),
            self_cond
        )

        out = self.c_skip(padded_sigma) * noised_images + self.c_out(padded_sigma) * net_out

        if clamp:
            out = out.clamp(-1., 1.)

        return out

    def sample_schedule(self, num_sample_steps = None):
        """Noise level schedule for sampling (eq. 5), padded with a trailing sigma = 0."""
        num_sample_steps = default(num_sample_steps, self.num_sample_steps)

        N = num_sample_steps
        inv_rho = 1 / self.rho

        steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32)
        sigmas = (self.sigma_max ** inv_rho + steps / (N - 1) * (self.sigma_min ** inv_rho - self.sigma_max ** inv_rho)) ** self.rho

        sigmas = F.pad(sigmas, (0, 1), value = 0.)  # sigma_N = 0
        return sigmas

    @torch.no_grad()
    def sample(self, batch_size = 16, num_sample_steps = None, clamp = True):
        """Stochastic sampler (Algorithm 2 of the paper), returns images in [0, 1]."""
        num_sample_steps = default(num_sample_steps, self.num_sample_steps)

        shape = (batch_size, self.channels, self.image_size, self.image_size)

        sigmas = self.sample_schedule(num_sample_steps)

        # churn (extra noise injection) only within [S_tmin, S_tmax]
        gammas = torch.where(
            (sigmas >= self.S_tmin) & (sigmas <= self.S_tmax),
            min(self.S_churn / num_sample_steps, sqrt(2) - 1),
            0.
        )

        # fix: the original line was missing its closing parentheses (SyntaxError)
        sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1]))

        init_sigma = sigmas[0]
        images = init_sigma * torch.randn(shape, device = self.device)

        x_start = None  # previous x0 estimate, for self-conditioning

        for sigma, sigma_next, gamma in tqdm(sigmas_and_gammas, desc = 'sampling time step'):
            sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma))

            eps = self.S_noise * torch.randn(shape, device = self.device)

            # raise the noise level slightly, then denoise from there
            sigma_hat = sigma + gamma * sigma
            images_hat = images + sqrt(sigma_hat ** 2 - sigma ** 2) * eps

            self_cond = x_start if self.self_condition else None

            model_output = self.preconditioned_network_forward(images_hat, sigma_hat, self_cond, clamp = clamp)
            denoised_over_sigma = (images_hat - model_output) / sigma_hat

            images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma

            # second order (Heun) correction, unless at the final step
            if sigma_next != 0:
                self_cond = model_output if self.self_condition else None

                model_output_next = self.preconditioned_network_forward(images_next, sigma_next, self_cond, clamp = clamp)
                denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next
                images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma)

            images = images_next
            x_start = model_output_next if sigma_next != 0 else model_output

        images = images.clamp(-1., 1.)
        return unnormalize_to_zero_to_one(images)

    @torch.no_grad()
    def sample_using_dpmpp(self, batch_size = 16, num_sample_steps = None):
        """
        DPM-Solver++ sampler (https://arxiv.org/abs/2211.01095).
        Thanks to Katherine Crowson (https://github.com/crowsonkb) for figuring it all out!
        """
        device, num_sample_steps = self.device, default(num_sample_steps, self.num_sample_steps)

        sigmas = self.sample_schedule(num_sample_steps)

        shape = (batch_size, self.channels, self.image_size, self.image_size)
        images = sigmas[0] * torch.randn(shape, device = device)

        sigma_fn = lambda t: t.neg().exp()
        t_fn = lambda sigma: sigma.log().neg()

        old_denoised = None
        for i in tqdm(range(len(sigmas) - 1)):
            denoised = self.preconditioned_network_forward(images, sigmas[i].item())
            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
            h = t_next - t

            if not exists(old_denoised) or sigmas[i + 1] == 0:
                denoised_d = denoised
            else:
                # second-order multistep update using the previous denoised estimate
                h_last = t - t_fn(sigmas[i - 1])
                r = h_last / h
                gamma = - 1 / (2 * r)
                denoised_d = (1 - gamma) * denoised + gamma * old_denoised

            images = (sigma_fn(t_next) / sigma_fn(t)) * images - (-h).expm1() * denoised_d
            old_denoised = denoised

        images = images.clamp(-1., 1.)
        return unnormalize_to_zero_to_one(images)

    def loss_weight(self, sigma):
        # lambda(sigma) loss weighting from the paper
        return (sigma ** 2 + self.sigma_data ** 2) * (sigma * self.sigma_data) ** -2

    def noise_distribution(self, batch_size):
        # log-normal distribution of training noise levels
        return (self.P_mean + self.P_std * torch.randn((batch_size,), device = self.device)).exp()

    def forward(self, images):
        """Training loss: weighted denoising MSE on a batch of images in [0, 1]."""
        batch_size, c, h, w, device, image_size, channels = *images.shape, images.device, self.image_size, self.channels

        assert h == image_size and w == image_size, f'height and width of image must be {image_size}'
        assert c == channels, 'mismatch of image channels'

        images = normalize_to_neg_one_to_one(images)

        sigmas = self.noise_distribution(batch_size)
        padded_sigmas = rearrange(sigmas, 'b -> b 1 1 1')

        noise = torch.randn_like(images)
        noised_images = images + padded_sigmas * noise

        self_cond = None

        if self.self_condition and random() < 0.5:
            # self-condition on the model's own x0 prediction, without gradients
            with torch.no_grad():
                self_cond = self.preconditioned_network_forward(noised_images, sigmas)
                self_cond.detach_()

        denoised = self.preconditioned_network_forward(noised_images, sigmas, self_cond)

        losses = F.mse_loss(denoised, images, reduction = 'none')
        losses = reduce(losses, 'b ... -> b', 'mean')

        losses = losses * self.loss_weight(sigmas)

        return losses.mean()
.\lucidrains\denoising-diffusion-pytorch\denoising_diffusion_pytorch\fid_evaluation.py
import math
import os
import numpy as np
import torch
from einops import rearrange, repeat
from pytorch_fid.fid_score import calculate_frechet_distance
from pytorch_fid.inception import InceptionV3
from torch.nn.functional import adaptive_avg_pool2d
from tqdm.auto import tqdm
def num_to_groups(num, divisor):
    """Split `num` into full chunks of size `divisor` plus a final remainder chunk."""
    full, rem = divmod(num, divisor)
    groups = [divisor] * full
    if rem:
        groups.append(rem)
    return groups
class FIDEvaluation:
    """Computes FID (Frechet Inception Distance) between generated and real images.

    Real-dataset Inception statistics are computed once and cached to disk
    under `stats_dir`; generated-sample statistics are recomputed per call.
    """

    def __init__(
        self,
        batch_size,
        dl,
        sampler,
        channels=3,
        accelerator=None,
        stats_dir="./results",
        device="cuda",
        num_fid_samples=50000,
        inception_block_idx=2048,
    ):
        self.batch_size = batch_size
        self.n_samples = num_fid_samples
        self.device = device
        self.channels = channels
        self.dl = dl            # iterator over real samples (consumed via next())
        self.sampler = sampler  # model exposing .sample(batch_size=...) and .eval()
        self.stats_dir = stats_dir
        # route prints through the accelerator when running distributed
        self.print_fn = print if accelerator is None else accelerator.print
        assert inception_block_idx in InceptionV3.BLOCK_INDEX_BY_DIM
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[inception_block_idx]
        self.inception_v3 = InceptionV3([block_idx]).to(device)
        self.dataset_stats_loaded = False

    def calculate_inception_features(self, samples):
        """Return flattened InceptionV3 features for a batch of images."""
        # InceptionV3 expects 3 channels; tile grayscale inputs
        if self.channels == 1:
            samples = repeat(samples, "b 1 ... -> b c ...", c=3)

        self.inception_v3.eval()
        features = self.inception_v3(samples)[0]

        # pool any remaining spatial extent down to 1x1 before flattening
        if features.size(2) != 1 or features.size(3) != 1:
            features = adaptive_avg_pool2d(features, output_size=(1, 1))
        features = rearrange(features, "... 1 1 -> ...")
        return features

    def load_or_precalc_dataset_stats(self):
        """Load cached real-dataset statistics from disk, or compute and cache them."""
        path = os.path.join(self.stats_dir, "dataset_stats")
        try:
            ckpt = np.load(path + ".npz")
            self.m2, self.s2 = ckpt["m2"], ckpt["s2"]
            self.print_fn("Dataset stats loaded from disk.")
            ckpt.close()
        except OSError:
            # cache miss: accumulate Inception features over the real dataset
            num_batches = int(math.ceil(self.n_samples / self.batch_size))
            stacked_real_features = []
            self.print_fn(
                f"Stacking Inception features for {self.n_samples} samples from the real dataset."
            )
            for _ in tqdm(range(num_batches)):
                try:
                    real_samples = next(self.dl)
                except StopIteration:
                    break
                real_samples = real_samples.to(self.device)
                real_features = self.calculate_inception_features(real_samples)
                stacked_real_features.append(real_features)
            stacked_real_features = (
                torch.cat(stacked_real_features, dim=0).cpu().numpy()
            )
            # mean and covariance of the real features define the reference Gaussian
            m2 = np.mean(stacked_real_features, axis=0)
            s2 = np.cov(stacked_real_features, rowvar=False)
            np.savez_compressed(path, m2=m2, s2=s2)
            self.print_fn(f"Dataset stats cached to {path}.npz for future use.")
            self.m2, self.s2 = m2, s2
        self.dataset_stats_loaded = True

    @torch.inference_mode()
    def fid_score(self):
        """Sample from the model and return FID against the cached dataset stats."""
        if not self.dataset_stats_loaded:
            self.load_or_precalc_dataset_stats()
        self.sampler.eval()
        batches = num_to_groups(self.n_samples, self.batch_size)
        stacked_fake_features = []
        self.print_fn(
            f"Stacking Inception features for {self.n_samples} generated samples."
        )
        for batch in tqdm(batches):
            fake_samples = self.sampler.sample(batch_size=batch)
            fake_features = self.calculate_inception_features(fake_samples)
            stacked_fake_features.append(fake_features)
        stacked_fake_features = torch.cat(stacked_fake_features, dim=0).cpu().numpy()
        m1 = np.mean(stacked_fake_features, axis=0)
        s1 = np.cov(stacked_fake_features, rowvar=False)
        return calculate_frechet_distance(m1, s1, self.m2, self.s2)
.\lucidrains\denoising-diffusion-pytorch\denoising_diffusion_pytorch\guided_diffusion.py
import math
import copy
from pathlib import Path
from random import random
from functools import partial
from collections import namedtuple
from multiprocessing import cpu_count
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torchvision import transforms as T, utils
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from PIL import Image
from tqdm.auto import tqdm
from ema_pytorch import EMA
from accelerate import Accelerator
ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
def exists(x):
    """True when `x` is anything other than None."""
    return x is not None
def default(val, d):
    """Return `val` when present, otherwise the fallback `d` (invoked if callable)."""
    if val is not None:
        return val
    return d() if callable(d) else d
def identity(t, *args, **kwargs):
    """Return `t` unchanged, ignoring any extra positional or keyword arguments."""
    return t
def cycle(dl):
    """Yield items from `dl` forever, restarting iteration whenever it is exhausted."""
    while True:
        yield from dl
def has_int_squareroot(num):
    """True when `num` is a perfect square.

    Fix: the previous `(math.sqrt(num) ** 2) == num` check goes through
    floating point and loses precision for large integers; `math.isqrt`
    is exact for ints. Non-integer inputs keep the original float check.
    """
    if num < 0:
        return False
    if isinstance(num, int):
        return math.isqrt(num) ** 2 == num
    return (math.sqrt(num) ** 2) == num
def num_to_groups(num, divisor):
    """Split `num` into full chunks of size `divisor` plus a final remainder chunk."""
    full, rem = divmod(num, divisor)
    groups = [divisor] * full
    if rem:
        groups.append(rem)
    return groups
def convert_image_to_fn(img_type, image):
    """Convert a PIL image to mode `img_type`, unless it is already in that mode."""
    if image.mode == img_type:
        return image
    return image.convert(img_type)
def normalize_to_neg_one_to_one(img):
    """Map image values from [0, 1] into the model's [-1, 1] range."""
    return img * 2 - 1
def unnormalize_to_zero_to_one(t):
    """Map values from the model's [-1, 1] range back to image range [0, 1]."""
    return (t + 1) * 0.5
class Residual(nn.Module):
    """Wraps `fn` with a skip connection: forward(x) = fn(x) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        out = self.fn(x, *args, **kwargs)
        return out + x
def Upsample(dim, dim_out = None):
    """2x nearest-neighbor upsample followed by a 3x3 conv (dim -> dim_out, default dim)."""
    out_channels = dim if dim_out is None else dim_out
    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim, out_channels, 3, padding = 1)
    )
def Downsample(dim, dim_out = None):
    # space-to-depth downsample: fold each 2x2 patch into channels
    # (c, h, w) -> (c * 4, h / 2, w / 2), then a 1x1 conv projects
    # dim * 4 -> dim_out (defaults to dim)
    return nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),
        nn.Conv2d(dim * 4, default(dim_out, dim), 1)
    )
class RMSNorm(nn.Module):
    """RMS normalization over the channel dimension of (b, c, h, w) feature maps."""

    def __init__(self, dim):
        super().__init__()
        # learned per-channel gain, broadcast over batch and spatial dims
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        # normalize each spatial position's channel vector to unit norm, then
        # rescale by sqrt(num_channels) so activations keep unit RMS.
        # fix: previously scaled by x.shape[-1] (the spatial width), which tied
        # the output magnitude to image size instead of the normalized channel dim
        return F.normalize(x, dim = 1) * self.g * (x.shape[1] ** 0.5)
class PreNorm(nn.Module):
    """Applies RMSNorm to the input before running the wrapped `fn`."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = RMSNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))
class SinusoidalPosEmb(nn.Module):
    """Classic transformer sinusoidal embedding for scalar timesteps: (b,) -> (b, dim)."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        # geometric frequency ladder spanning 1 down to 1/10000
        scale = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device = x.device) * -scale)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim = -1)
class RandomOrLearnedSinusoidalPosEmb(nn.Module):
    """Sinusoidal timestep embedding with random (frozen) or learned frequencies.

    Output shape is (b, dim + 1): the raw timestep is concatenated alongside
    the sin/cos Fourier features.
    """

    def __init__(self, dim, is_random = False):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        # frequencies are frozen when is_random is True (random Fourier features)
        self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)

    def forward(self, x):
        t = x[:, None]                                   # (b,) -> (b, 1)
        freqs = t * self.weights[None, :] * 2 * math.pi  # (b, half_dim)
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        return torch.cat((t, fouriered), dim = -1)
class Block(nn.Module):
    """conv 3x3 -> group norm -> optional FiLM-style (scale, shift) -> SiLU."""

    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        h = self.norm(self.proj(x))
        if scale_shift is not None:
            scale, shift = scale_shift
            h = h * (scale + 1) + shift
        return self.act(h)
class ResnetBlock(nn.Module):
    """Two Block units with optional time-embedding FiLM modulation and a residual connection."""
    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        # projects the time embedding to per-channel scale and shift (hence dim_out * 2)
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if exists(time_emb_dim) else None
        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        # 1x1 conv matches channels for the residual path when dim != dim_out
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
    def forward(self, x, time_emb = None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)  # split into (scale, shift)
        h = self.block1(x, scale_shift = scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)
class LinearAttention(nn.Module):
    """Linear-complexity attention: softmaxes q over the feature axis and k over
    the sequence axis, aggregating a global context matrix instead of pairwise scores."""
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, 1),
            RMSNorm(dim)
        )
    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
        # the linear attention trick: normalize q and k over different axes
        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)
        q = q * self.scale
        # global context (d x e), then per-query readout — O(n) in spatial positions
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
        return self.to_out(out)
class Attention(nn.Module):
    """Standard full softmax attention over all spatial positions (quadratic in h*w)."""
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
        q = q * self.scale
        # pairwise similarity between all positions
        sim = einsum('b h d i, b h d j -> b h i j', q, k)
        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        return self.to_out(out)
class Unet(nn.Module):
    """U-Net backbone with time conditioning, linear attention at every resolution,
    full attention at the bottleneck, and optional self-conditioning."""
    def __init__(
        self,
        dim,
        init_dim = None,
        out_dim = None,
        dim_mults=(1, 2, 4, 8),
        channels = 3,
        self_condition = False,
        resnet_block_groups = 8,
        learned_variance = False,
        learned_sinusoidal_cond = False,
        random_fourier_features = False,
        learned_sinusoidal_dim = 16
    ):
        super().__init__()
        # determine dimensions
        self.channels = channels
        self.self_condition = self_condition
        # self-conditioning concatenates the previous x0 estimate, doubling input channels
        input_channels = channels * (2 if self_condition else 1)
        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        block_klass = partial(ResnetBlock, groups = resnet_block_groups)
        # time embedding MLP
        time_dim = dim * 4
        self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features
        if self.random_or_learned_sinusoidal_cond:
            sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
            fourier_dim = learned_sinusoidal_dim + 1  # +1 for the raw timestep channel
        else:
            sinu_pos_emb = SinusoidalPosEmb(dim)
            fourier_dim = dim
        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )
        # down / up stages
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            # the last stage keeps resolution (3x3 conv) instead of downsampling
            self.downs.append(nn.ModuleList([
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
            ]))
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)
            # input channels are doubled by the skip-connection concatenations
            self.ups.append(nn.ModuleList([
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
            ]))
        # two output channels per color channel when the variance is learned
        default_out_dim = channels * (1 if not learned_variance else 2)
        self.out_dim = default(out_dim, default_out_dim)
        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
    def forward(self, x, time, x_self_cond = None):
        """Predict the diffusion objective (noise / x0 / v) for `x` at timestep `time`."""
        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x_self_cond, x), dim = 1)
        x = self.init_conv(x)
        r = x.clone()  # kept for the final long skip connection
        t = self.time_mlp(time)
        h = []
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim = 1)  # skip connection from the down path
            x = block1(x, t)
            x = torch.cat((x, h.pop()), dim = 1)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)
        x = torch.cat((x, r), dim = 1)
        x = self.final_res_block(x, t)
        return self.final_conv(x)
def extract(a, t, x_shape):
    """Gather per-timestep values from `a` at indices `t`, reshaped so they
    broadcast against a tensor of shape `x_shape`."""
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    return gathered.reshape(batch, *((1,) * (len(x_shape) - 1)))
def linear_beta_schedule(timesteps):
    """Linear beta schedule from the DDPM paper, rescaled so the endpoints are
    referenced to a 1000-step run regardless of `timesteps`."""
    scale = 1000 / timesteps
    return torch.linspace(scale * 0.0001, scale * 0.02, timesteps, dtype=torch.float64)
def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine noise schedule (Improved DDPM, https://arxiv.org/abs/2102.09672)."""
    x = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64) / timesteps
    # squared-cosine alpha-bar curve, shifted by the small offset s
    alphas_bar = torch.cos((x + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_bar = alphas_bar / alphas_bar[0]
    betas = 1 - alphas_bar[1:] / alphas_bar[:-1]
    return torch.clip(betas, 0, 0.999)
def sigmoid_beta_schedule(timesteps, start=-3, end=3, tau=1, clamp_min=1e-5):
    """Sigmoid noise schedule.

    NOTE(review): `clamp_min` is accepted but currently unused — the final
    clip uses a fixed lower bound of 0; confirm whether it should be applied.
    """
    x = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64) / timesteps
    v_start = torch.tensor(start / tau).sigmoid()
    v_end = torch.tensor(end / tau).sigmoid()
    # sigmoid-shaped alpha-bar curve between v_start and v_end
    alphas_bar = (-((x * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
    alphas_bar = alphas_bar / alphas_bar[0]
    betas = 1 - alphas_bar[1:] / alphas_bar[:-1]
    return torch.clip(betas, 0, 0.999)
class GaussianDiffusion(nn.Module):
def __init__(
self,
model,
*,
image_size,
timesteps=1000,
sampling_timesteps=None,
objective='pred_noise',
beta_schedule='sigmoid',
schedule_fn_kwargs=dict(),
ddim_sampling_eta=0.,
auto_normalize=True,
min_snr_loss_weight=False,
min_snr_gamma=5
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def predict_noise_from_start(self, x_t, t, x0):
return (
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
)
def predict_v(self, x_start, t, noise):
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
)
def predict_start_from_v(self, x_t, t, v):
return (
extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
    def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False):
        """Run the model and normalize its raw output into (pred_noise, pred_x_start),
        whatever objective the model was trained with."""
        model_output = self.model(x, t, x_self_cond)
        maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
        if self.objective == 'pred_noise':
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, pred_noise)
            x_start = maybe_clip(x_start)
        elif self.objective == 'pred_x0':
            x_start = model_output
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)
        elif self.objective == 'pred_v':
            v = model_output
            x_start = self.predict_start_from_v(x, t, v)
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)
        return ModelPrediction(pred_noise, x_start)
def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
preds = self.model_predictions(x, t, x_self_cond)
x_start = preds.pred_x_start
if clip_denoised:
x_start.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
return model_mean, posterior_variance, posterior_log_variance, x_start
def condition_mean(self, cond_fn, mean, variance, x, t, guidance_kwargs=None):
"""
Compute the mean for the previous step, given a function cond_fn that
computes the gradient of a conditional log probability with respect to
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
condition on y.
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
"""
gradient = cond_fn(mean, t, **guidance_kwargs)
new_mean = (
mean.float() + variance * gradient.float()
)
print("gradient: ",(variance * gradient.float()).mean())
return new_mean
    def p_sample(self, x, t: int, x_self_cond = None, cond_fn=None, guidance_kwargs=None):
        """Sample x_{t-1} from the model at scalar timestep t (optionally classifier-guided)."""
        b, *_, device = *x.shape, x.device
        batched_times = torch.full((b,), t, device = x.device, dtype = torch.long)
        model_mean, variance, model_log_variance, x_start = self.p_mean_variance(
            x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True
        )
        if exists(cond_fn) and exists(guidance_kwargs):
            # classifier guidance: shift the mean along the conditional log-prob gradient
            model_mean = self.condition_mean(cond_fn, model_mean, variance, x, batched_times, guidance_kwargs)
        noise = torch.randn_like(x) if t > 0 else 0.  # no noise on the final step
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start
def p_sample_loop(self, shape, return_all_timesteps = False, cond_fn=None, guidance_kwargs=None):
batch, device = shape[0], self.betas.device
img = torch.randn(shape, device = device)
imgs = [img]
x_start = None
for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
self_cond = x_start if self.self_condition else None
img, x_start = self.p_sample(img, t, self_cond, cond_fn, guidance_kwargs)
imgs.append(img)
ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
ret = self.unnormalize(ret)
return ret
    @torch.no_grad()
    def ddim_sample(self, shape, return_all_timesteps = False, cond_fn=None, guidance_kwargs=None):
        """DDIM sampling (https://arxiv.org/abs/2010.02502) over `sampling_timesteps` steps."""
        batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
        # times run [-1, 0, ..., T-1]; pairs (time, time_next), time_next == -1 marks the final step
        times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))
        img = torch.randn(shape, device = device)
        imgs = [img]
        x_start = None
        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
            self_cond = x_start if self.self_condition else None
            pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True)
            imgs.append(img)
            if time_next < 0:
                # last step: the prediction is the final image
                img = x_start
                continue
            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]
            # eta == 0 gives deterministic DDIM; larger eta adds DDPM-like noise
            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()
            noise = torch.randn_like(img)
            img = x_start * alpha_next.sqrt() + \
                  c * pred_noise + \
                  sigma * noise
        ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
        ret = self.unnormalize(ret)
        return ret
@torch.no_grad()
def sample(self, batch_size = 16, return_all_timesteps = False, cond_fn=None, guidance_kwargs=None):
image_size, channels = self.image_size, self.channels
sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps, cond_fn=cond_fn, guidance_kwargs=guidance_kwargs)
    @torch.no_grad()
    def interpolate(self, x1, x2, t = None, lam = 0.5):
        """Interpolate two images: noise both to timestep `t`, blend with weight
        `lam`, then denoise the mixture back to an image."""
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)
        assert x1.shape == x2.shape
        t_batched = torch.full((b,), t, device = device)
        xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2))
        img = (1 - lam) * xt1 + lam * xt2
        x_start = None
        for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t):
            self_cond = x_start if self.self_condition else None
            img, x_start = self.p_sample(img, i, self_cond)
        return img
@autocast(enabled = False)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
    def p_losses(self, x_start, t, noise = None):
        """Training loss at timesteps `t`: noise x_start, predict, and return the
        weighted MSE against the configured objective's target."""
        b, c, h, w = x_start.shape
        noise = default(noise, lambda: torch.randn_like(x_start))
        # noise sample
        x = self.q_sample(x_start = x_start, t = t, noise = noise)
        # with self-conditioning, half the time condition on the model's own
        # x0 prediction; gradients do not flow through it
        x_self_cond = None
        if self.self_condition and random() < 0.5:
            with torch.no_grad():
                x_self_cond = self.model_predictions(x, t).pred_x_start
                x_self_cond.detach_()
        model_out = self.model(x, t, x_self_cond)
        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        elif self.objective == 'pred_v':
            v = self.predict_v(x_start, t, noise)
            target = v
        else:
            raise ValueError(f'unknown objective {self.objective}')
        loss = F.mse_loss(model_out, target, reduction = 'none')
        loss = reduce(loss, 'b ... -> b', 'mean')
        # per-timestep loss weighting (e.g. min-SNR)
        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()
def forward(self, img, *args, **kwargs):
b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
img = self.normalize(img)
return self.p_losses(img, t, *args, **kwargs)
class Dataset(Dataset):
    """Folder-of-images dataset: recursively collects files with the given
    extensions and returns resized, center-cropped tensors in [0, 1]."""
    def __init__(
        self,
        folder,
        image_size,
        exts = ['jpg', 'jpeg', 'png', 'tiff'],
        augment_horizontal_flip = False,
        convert_image_to = None
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
        # optional PIL mode conversion (e.g. 'RGB') before the tensor transforms
        maybe_convert_fn = partial(convert_image_to_fn, convert_image_to) if exists(convert_image_to) else nn.Identity()
        self.transform = T.Compose([
            T.Lambda(maybe_convert_fn),
            T.Resize(image_size),
            T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)
class Trainer(object):
    """Training loop for a diffusion model: data loading, EMA weights, mixed
    precision via accelerate, and periodic sampling/checkpointing."""
    def __init__(
        self,
        diffusion_model,
        folder,
        *,
        train_batch_size = 16,
        gradient_accumulate_every = 1,
        augment_horizontal_flip = True,
        train_lr = 1e-4,
        train_num_steps = 100000,
        ema_update_every = 10,
        ema_decay = 0.995,
        adam_betas = (0.9, 0.99),
        save_and_sample_every = 1000,
        num_samples = 25,
        results_folder = './results',
        amp = False,
        fp16 = False,
        split_batches = True,
        convert_image_to = None
    ):
        super().__init__()
        # accelerator handles device placement and (optional) mixed precision
        self.accelerator = Accelerator(
            split_batches = split_batches,
            mixed_precision = 'fp16' if fp16 else 'no'
        )
        self.accelerator.native_amp = amp
        self.model = diffusion_model
        # samples are saved in a square grid, so the count must be a perfect square
        assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
        self.num_samples = num_samples
        self.save_and_sample_every = save_and_sample_every
        self.batch_size = train_batch_size
        self.gradient_accumulate_every = gradient_accumulate_every
        self.train_num_steps = train_num_steps
        self.image_size = diffusion_model.image_size
        # dataset and endless dataloader
        self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip, convert_image_to = convert_image_to)
        dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())
        dl = self.accelerator.prepare(dl)
        self.dl = cycle(dl)
        # optimizer
        self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)
        # EMA copy of the model, maintained on the main process only
        if self.accelerator.is_main_process:
            self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)
        self.results_folder = Path(results_folder)
        self.results_folder.mkdir(exist_ok = True)
        # global step counter
        self.step = 0
        # prepare model and optimizer with the accelerator
        self.model, self.opt = self.accelerator.prepare(self.model, self.opt)
    def save(self, milestone):
        """Checkpoint model/optimizer/EMA/scaler state to the results folder."""
        if not self.accelerator.is_local_main_process:
            return
        data = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.model),
            'opt': self.opt.state_dict(),
            'ema': self.ema.state_dict(),
            'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None,
            # NOTE(review): __version__ is not defined or imported anywhere in
            # this file as shown — confirm the version import exists upstream
            'version': __version__
        }
        torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))
    def load(self, milestone):
        """Restore training state from a checkpoint milestone."""
        accelerator = self.accelerator
        device = accelerator.device
        data = torch.load(str(self.results_folder / f'model-{milestone}.pt'), map_location=device)
        model = self.accelerator.unwrap_model(self.model)
        model.load_state_dict(data['model'])
        self.step = data['step']
        self.opt.load_state_dict(data['opt'])
        self.ema.load_state_dict(data['ema'])
        if 'version' in data:
            print(f"loading from version {data['version']}")
        if exists(self.accelerator.scaler) and exists(data['scaler']):
            self.accelerator.scaler.load_state_dict(data['scaler'])
    def train(self):
        """Main optimization loop with gradient accumulation and periodic EMA sampling."""
        accelerator = self.accelerator
        device = accelerator.device
        with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar:
            while self.step < self.train_num_steps:
                total_loss = 0.
                for _ in range(self.gradient_accumulate_every):
                    data = next(self.dl).to(device)
                    with self.accelerator.autocast():
                        loss = self.model(data)
                        loss = loss / self.gradient_accumulate_every  # average over accumulation steps
                        total_loss += loss.item()
                    self.accelerator.backward(loss)
                accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
                pbar.set_description(f'loss: {total_loss:.4f}')
                accelerator.wait_for_everyone()
                self.opt.step()
                self.opt.zero_grad()
                accelerator.wait_for_everyone()
                self.step += 1
                if accelerator.is_main_process:
                    self.ema.to(device)
                    self.ema.update()
                    # periodically sample from the EMA model and checkpoint
                    if self.step != 0 and self.step % self.save_and_sample_every == 0:
                        self.ema.ema_model.eval()
                        with torch.no_grad():
                            milestone = self.step // self.save_and_sample_every
                            batches = num_to_groups(self.num_samples, self.batch_size)
                            all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))
                        all_images = torch.cat(all_images_list, dim = 0)
                        utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = int(math.sqrt(self.num_samples)))
                        self.save(milestone)
                pbar.update(1)
        accelerator.print('training complete')
if __name__ == '__main__':
    # demo: classifier-guided sampling with a toy classifier

    class Classifier(nn.Module):
        """Toy classifier over (image, timestep) pairs, used to demo classifier guidance."""
        def __init__(self, image_size, num_classes, t_dim=1) -> None:
            super().__init__()
            self.linear_t = nn.Linear(t_dim, num_classes)
            self.linear_img = nn.Linear(image_size * image_size * 3, num_classes)

        def forward(self, x, t):
            """
            Args:
                x (_type_): [B, 3, N, N]
                t (_type_): [B,]
            Returns:
                logits [B, num_classes]
            """
            B = x.shape[0]
            t = t.view(B, 1)
            logits = self.linear_t(t.float()) + self.linear_img(x.view(x.shape[0], -1))
            return logits

    def classifier_cond_fn(x, t, classifier, y, classifier_scale=1):
        """
        Return the gradient of the classifier outputting y w.r.t. x,
        formally expressed as d_log(classifier(x, t)) / dx.
        """
        assert y is not None
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            logits = classifier(x_in, t)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(logits)), y.view(-1)]
            grad = torch.autograd.grad(selected.sum(), x_in)[0] * classifier_scale
        return grad

    model = Unet(
        dim = 64,
        dim_mults = (1, 2, 4, 8)
    )

    image_size = 128
    diffusion = GaussianDiffusion(
        model,
        image_size = image_size,
        timesteps = 1000
    )

    classifier = Classifier(image_size=image_size, num_classes=1000, t_dim=1)

    batch_size = 4
    sampled_images = diffusion.sample(
        batch_size = batch_size,
        cond_fn=classifier_cond_fn,
        guidance_kwargs={
            "classifier": classifier,
            # fix: build the constant label tensor idiomatically with torch.full
            # (torch.fill over a fresh zeros tensor allocated twice for the same result)
            "y": torch.full((batch_size,), 1, dtype = torch.long),
            "classifier_scale": 1,
        }
    )
    sampled_images.shape
.\lucidrains\denoising-diffusion-pytorch\denoising_diffusion_pytorch\karras_unet.py
"""
the magnitude-preserving unet proposed in https://arxiv.org/abs/2312.02696 by Karras et al.
"""
import math
from math import sqrt, ceil
from functools import partial
import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F
from einops import rearrange, repeat, pack, unpack
from denoising_diffusion_pytorch.attend import Attend
def exists(x):
    # treat None as "absent"; any other value (including falsy ones) counts as present
    return not (x is None)
def default(val, d):
    """Return `val` when it is not None; otherwise the fallback `d` (called first if callable)."""
    if val is not None:
        return val
    return d() if callable(d) else d
def xnor(x, y):
    # logical XNOR: true exactly when x and y agree
    differ = x ^ y
    return not differ
def append(arr, el):
    """Push `el` onto the end of `arr` in place (works for lists and nn.ModuleList)."""
    arr.append(el)
def prepend(arr, el):
    """Insert `el` at the front of `arr` in place, shifting existing entries right."""
    arr.insert(0, el)
def pack_one(t, pattern):
    # pack a single tensor with einops.pack; returns (packed tensor, packed-shapes) for later unpacking
    return pack([t], pattern)
def unpack_one(t, ps, pattern):
    # inverse of pack_one: unpack with the saved shapes and return the single tensor
    return unpack(t, ps, pattern)[0]
def cast_tuple(t, length = 1):
    """Return `t` unchanged if it is already a tuple, else a tuple repeating `t` `length` times."""
    return t if isinstance(t, tuple) else (t,) * length
def divisible_by(numer, denom):
    """True when `numer` is an exact multiple of `denom`."""
    remainder = numer % denom
    return remainder == 0
def l2norm(t, dim = -1, eps = 1e-12):
    """Unit L2-normalize `t` along `dim`; `eps` guards against division by a near-zero norm."""
    return F.normalize(t, p = 2.0, dim = dim, eps = eps)
class MPSiLU(Module):
    """SiLU activation rescaled by 1 / 0.596 (the magnitude-preserving variant from Karras et al.)."""

    def forward(self, x):
        activated = F.silu(x)
        return activated / 0.596
class Gain(Module):
    """Learnable scalar gain initialized at zero, so the layer starts out as the zero map."""

    def __init__(self):
        super().__init__()
        self.gain = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        return self.gain * x
class MPCat(Module):
    """Magnitude-preserving concatenation: blend two tensors along `dim` with weight `t`,
    rescaled so the concatenated result keeps the inputs' overall magnitude."""

    def __init__(self, t = 0.5, dim = -1):
        super().__init__()
        self.t = t
        self.dim = dim

    def forward(self, a, b):
        t, dim = self.t, self.dim
        na, nb = a.shape[dim], b.shape[dim]
        # global rescale compensating for both the blend weights and the width change
        scale = sqrt((na + nb) / ((1. - t) ** 2 + t ** 2))
        a = a * (1. - t) / sqrt(na)
        b = b * t / sqrt(nb)
        return scale * torch.cat((a, b), dim = dim)
class MPAdd(Module):
    """Magnitude-preserving residual add: lerp between `x` and `res` by `t`,
    divided by the norm of the blend weights."""

    def __init__(self, t):
        super().__init__()
        self.t = t

    def forward(self, x, res):
        t = self.t
        blended = x * (1. - t) + res * t
        return blended / sqrt((1 - t) ** 2 + t ** 2)
class PixelNorm(Module):
    """Normalize each feature vector along `dim` to norm sqrt(num_features),
    so per-feature magnitude stays near 1 (eps guards zero vectors)."""

    def __init__(self, dim, eps = 1e-4):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        d = self.dim
        # unit-normalize, then scale back up by sqrt(width) to preserve magnitude
        return F.normalize(x, dim = d, eps = self.eps) * sqrt(x.shape[d])
def normalize_weight(weight, eps = 1e-4):
    """Forced weight normalization: scale each output unit's weight vector to norm
    sqrt(fan_in), preserving the tensor's original shape."""
    out_dim = weight.shape[0]
    flat = weight.reshape(out_dim, -1)
    flat = F.normalize(flat, dim = -1, eps = eps)
    # rescale so each row has norm sqrt(elements per row)
    flat = flat * sqrt(weight.numel() / out_dim)
    return flat.reshape(weight.shape)
class Conv2d(Module):
    """Magnitude-preserving 2d convolution with forced weight normalization.

    During training, the stored weight is re-normalized in place (without grad)
    on every forward; the convolution then uses a freshly normalized (with grad)
    weight scaled by 1/sqrt(fan_in). With `concat_ones_to_input`, a constant
    ones channel is prepended to the input (used by the unet's input block).
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        kernel_size,
        eps = 1e-4,
        concat_ones_to_input = False
    ):
        super().__init__()
        # one extra input channel is allocated for the ones channel when requested
        weight = torch.randn(dim_out, dim_in + int(concat_ones_to_input), kernel_size, kernel_size)
        self.weight = nn.Parameter(weight)
        self.eps = eps
        # NOTE(review): fan_in excludes the optional ones channel — presumably intentional, confirm against the paper
        self.fan_in = dim_in * kernel_size ** 2
        self.concat_ones_to_input = concat_ones_to_input

    def forward(self, x):
        if self.training:
            # forced weight normalization: copy the normalized weights back into the parameter (no grad)
            with torch.no_grad():
                normed_weight = normalize_weight(self.weight, eps = self.eps)
                self.weight.copy_(normed_weight)
        # normalize again (with grad this time) and apply the 1/sqrt(fan_in) scaling
        weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)
        if self.concat_ones_to_input:
            # pad a constant ones channel at the front of the channel dimension
            x = F.pad(x, (0, 0, 0, 0, 1, 0), value = 1.)
        return F.conv2d(x, weight, padding='same')
class Linear(Module):
    """Magnitude-preserving linear layer: forced weight normalization during training,
    and a 1/sqrt(fan_in)-scaled normalized weight at every forward."""

    def __init__(self, dim_in, dim_out, eps = 1e-4):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim_out, dim_in))
        self.eps = eps
        self.fan_in = dim_in

    def forward(self, x):
        if self.training:
            # write the normalized weights back into the parameter, outside of autograd
            with torch.no_grad():
                self.weight.copy_(normalize_weight(self.weight, eps = self.eps))
        # normalize with grad and rescale for unit-variance outputs
        w = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)
        return F.linear(x, w)
class MPFourierEmbedding(Module):
    """Fourier-feature time embedding with fixed (non-trainable) random frequencies,
    scaled by sqrt(2) to preserve magnitude."""

    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        self.weights = nn.Parameter(torch.randn(dim // 2), requires_grad = False)

    def forward(self, x):
        # outer product of times and frequencies -> (batch, dim / 2) angles
        freqs = x[:, None] * self.weights[None, :] * 2 * math.pi
        return torch.cat((freqs.sin(), freqs.cos()), dim = -1) * sqrt(2)
class Encoder(Module):
    """Magnitude-preserving encoder block for the Karras U-Net.

    A residual block of two MPSiLU + 3x3 Conv stages with optional per-channel
    embedding modulation, optional 2x bilinear downsampling at the input, and
    optional attention after the residual add.
    """

    def __init__(
        self,
        dim,
        dim_out = None,
        *,
        emb_dim = None,
        dropout = 0.1,
        mp_add_t = 0.3,
        has_attn = False,
        attn_dim_head = 64,
        attn_res_mp_add_t = 0.3,
        attn_flash = False,
        downsample = False
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.downsample = downsample
        self.downsample_conv = None
        curr_dim = dim
        if downsample:
            # 1x1 conv changes channel count right after the spatial downsample
            self.downsample_conv = Conv2d(curr_dim, dim_out, 1)
            curr_dim = dim_out
        # pixel norm over the channel dimension, applied to the block input
        self.pixel_norm = PixelNorm(dim = 1)
        self.to_emb = None
        if exists(emb_dim):
            # embedding -> per-channel scale; Gain is zero-init so the scale starts at 1 in forward
            self.to_emb = nn.Sequential(
                Linear(emb_dim, dim_out),
                Gain()
            )
        self.block1 = nn.Sequential(
            MPSiLU(),
            Conv2d(curr_dim, dim_out, 3)
        )
        self.block2 = nn.Sequential(
            MPSiLU(),
            nn.Dropout(dropout),
            Conv2d(dim_out, dim_out, 3)
        )
        # NOTE(review): no residual projection — assumes curr_dim == dim_out on the residual path
        self.res_mp_add = MPAdd(t = mp_add_t)
        self.attn = None
        if has_attn:
            self.attn = Attention(
                dim = dim_out,
                # at least 2 heads, otherwise enough heads to cover dim_out at attn_dim_head per head
                heads = max(ceil(dim_out / attn_dim_head), 2),
                dim_head = attn_dim_head,
                mp_add_t = attn_res_mp_add_t,
                flash = attn_flash
            )

    def forward(
        self,
        x,
        emb = None
    ):
        if self.downsample:
            h, w = x.shape[-2:]
            x = F.interpolate(x, (h // 2, w // 2), mode = 'bilinear')
            x = self.downsample_conv(x)
        x = self.pixel_norm(x)
        # residual is taken after pixel norm (and after any downsampling)
        res = x.clone()
        x = self.block1(x)
        if exists(emb):
            # shift-free modulation: per-channel scale = 1 + gain(emb)
            scale = self.to_emb(emb) + 1
            x = x * rearrange(scale, 'b c -> b c 1 1')
        x = self.block2(x)
        # magnitude-preserving residual sum
        x = self.res_mp_add(x, res)
        if exists(self.attn):
            x = self.attn(x)
        return x
class Decoder(Module):
    """Magnitude-preserving decoder block for the Karras U-Net.

    A residual block of two MPSiLU + 3x3 Conv stages with optional per-channel
    embedding modulation, optional 2x bilinear upsampling at the input, and
    optional attention after the residual add. Upsampling decoders do not
    consume a skip connection (see `needs_skip`, checked in KarrasUnet.forward).
    """

    def __init__(
        self,
        dim,
        dim_out = None,
        *,
        emb_dim = None,
        dropout = 0.1,
        mp_add_t = 0.3,
        has_attn = False,
        attn_dim_head = 64,
        attn_res_mp_add_t = 0.3,
        attn_flash = False,
        upsample = False
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.upsample = upsample
        # only non-upsampling decoders pop a skip connection from the stack
        self.needs_skip = not upsample
        self.to_emb = None
        if exists(emb_dim):
            # embedding -> per-channel scale; Gain is zero-init so the scale starts at 1 in forward
            self.to_emb = nn.Sequential(
                Linear(emb_dim, dim_out),
                Gain()
            )
        self.block1 = nn.Sequential(
            MPSiLU(),
            Conv2d(dim, dim_out, 3)
        )
        self.block2 = nn.Sequential(
            MPSiLU(),
            nn.Dropout(dropout),
            Conv2d(dim_out, dim_out, 3)
        )
        # project the residual path with a 1x1 conv only when channel counts differ
        self.res_conv = Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
        self.res_mp_add = MPAdd(t = mp_add_t)
        self.attn = None
        if has_attn:
            self.attn = Attention(
                dim = dim_out,
                heads = max(ceil(dim_out / attn_dim_head), 2),
                dim_head = attn_dim_head,
                mp_add_t = attn_res_mp_add_t,
                flash = attn_flash
            )

    def forward(
        self,
        x,
        emb = None
    ):
        if self.upsample:
            h, w = x.shape[-2:]
            x = F.interpolate(x, (h * 2, w * 2), mode = 'bilinear')
        res = self.res_conv(x)
        x = self.block1(x)
        if exists(emb):
            # shift-free modulation: per-channel scale = 1 + gain(emb)
            scale = self.to_emb(emb) + 1
            x = x * rearrange(scale, 'b c -> b c 1 1')
        x = self.block2(x)
        # magnitude-preserving residual sum
        x = self.res_mp_add(x, res)
        if exists(self.attn):
            x = self.attn(x)
        return x
class Attention(Module):
    """Spatial self-attention with pixel-normalized q/k/v and learned memory
    key/values, wrapped in a magnitude-preserving residual add."""

    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 64,
        num_mem_kv = 4,
        flash = False,
        mp_add_t = 0.3
    ):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        # normalize along each head's feature dimension (cosine-attention style)
        self.pixel_norm = PixelNorm(dim = -1)
        self.attend = Attend(flash = flash)
        # learned "memory" key/value tokens, prepended to every sequence
        self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head))
        self.to_qkv = Conv2d(dim, hidden_dim * 3, 1)
        self.to_out = Conv2d(hidden_dim, dim, 1)
        self.mp_add = MPAdd(t = mp_add_t)

    def forward(self, x):
        res, b, c, h, w = x, *x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        # (b, heads*d, h, w) -> (b, heads, h*w, d)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h (x y) c', h = self.heads), qkv)
        # expand memory kv across the batch and prepend to keys / values
        mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), self.mem_kv)
        k, v = map(partial(torch.cat, dim = -2), ((mk, k), (mv, v)))
        q, k, v = map(self.pixel_norm, (q, k, v))
        out = self.attend(q, k, v)
        # fold heads back into channels and restore the spatial layout
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        out = self.to_out(out)
        # magnitude-preserving residual connection
        return self.mp_add(out, res)
class KarrasUnet(Module):
    """
    Magnitude-preserving U-Net — configuration G from figure 21 of
    Karras et al., https://arxiv.org/abs/2312.02696
    """

    def __init__(
        self,
        *,
        image_size,
        dim = 192,
        dim_max = 768,            # channels double at every downsample, capped at this value
        num_classes = None,       # when set, the unet becomes class-conditional
        channels = 4,
        num_downsamples = 3,
        num_blocks_per_stage = 4,
        attn_res = (16, 8),       # resolutions at which attention blocks are inserted
        fourier_dim = 16,
        attn_dim_head = 64,
        attn_flash = False,
        mp_cat_t = 0.5,
        mp_add_emb_t = 0.5,
        attn_res_mp_add_t = 0.3,
        resnet_mp_add_t = 0.3,
        dropout = 0.1,
        self_condition = False
    ):
        super().__init__()
        self.self_condition = self_condition
        self.channels = channels
        self.image_size = image_size
        # self conditioning doubles the input channels (previous prediction is concatenated)
        input_channels = channels * (2 if self_condition else 1)
        # input conv concatenates a constant ones channel; output is zero-init via Gain
        self.input_block = Conv2d(input_channels, dim, 3, concat_ones_to_input = True)
        self.output_block = nn.Sequential(
            Conv2d(dim, channels, 3),
            Gain()
        )
        emb_dim = dim * 4
        self.to_time_emb = nn.Sequential(
            MPFourierEmbedding(fourier_dim),
            Linear(fourier_dim, emb_dim)
        )
        self.needs_class_labels = exists(num_classes)
        self.num_classes = num_classes
        if self.needs_class_labels:
            self.to_class_emb = Linear(num_classes, 4 * dim)
            # class embedding is merged into the time embedding with a magnitude-preserving add
            self.add_class_emb = MPAdd(t = mp_add_emb_t)
        self.emb_activation = MPSiLU()
        self.num_downsamples = num_downsamples
        attn_res = set(cast_tuple(attn_res))
        block_kwargs = dict(
            dropout = dropout,
            emb_dim = emb_dim,
            attn_dim_head = attn_dim_head,
            attn_res_mp_add_t = attn_res_mp_add_t,
            attn_flash = attn_flash
        )
        self.downs = ModuleList([])
        self.ups = ModuleList([])
        curr_dim = dim
        curr_res = image_size
        # skip connections are merged by magnitude-preserving concat over channels
        self.skip_mp_cat = MPCat(t = mp_cat_t, dim = 1)
        # final decoder consumes the input block's skip (hence dim * 2 input channels)
        prepend(self.ups, Decoder(dim * 2, dim, **block_kwargs))
        assert num_blocks_per_stage >= 1
        # initial full-resolution stage (no attention)
        for _ in range(num_blocks_per_stage):
            enc = Encoder(curr_dim, curr_dim, **block_kwargs)
            dec = Decoder(curr_dim * 2, curr_dim, **block_kwargs)
            append(self.downs, enc)
            prepend(self.ups, dec)
        # downsampling stages; ups are prepended so the decoder path mirrors the encoder path
        for _ in range(self.num_downsamples):
            dim_out = min(dim_max, curr_dim * 2)
            upsample = Decoder(dim_out, curr_dim, has_attn = curr_res in attn_res, upsample = True, **block_kwargs)
            curr_res //= 2
            has_attn = curr_res in attn_res
            downsample = Encoder(curr_dim, dim_out, downsample = True, has_attn = has_attn, **block_kwargs)
            append(self.downs, downsample)
            prepend(self.ups, upsample)
            prepend(self.ups, Decoder(dim_out * 2, dim_out, has_attn = has_attn, **block_kwargs))
            for _ in range(num_blocks_per_stage):
                enc = Encoder(dim_out, dim_out, has_attn = has_attn, **block_kwargs)
                dec = Decoder(dim_out * 2, dim_out, has_attn = has_attn, **block_kwargs)
                append(self.downs, enc)
                prepend(self.ups, dec)
            curr_dim = dim_out
        # two middle decoders at the bottleneck resolution
        mid_has_attn = curr_res in attn_res
        self.mids = ModuleList([
            Decoder(curr_dim, curr_dim, has_attn = mid_has_attn, **block_kwargs),
            Decoder(curr_dim, curr_dim, has_attn = mid_has_attn, **block_kwargs),
        ])
        self.out_dim = channels

    @property
    def downsample_factor(self):
        # total spatial reduction from input to the bottleneck
        return 2 ** self.num_downsamples

    def forward(
        self,
        x,
        time,
        self_cond = None,
        class_labels = None
    ):
        # validate input shape
        assert x.shape[1:] == (self.channels, self.image_size, self.image_size)
        if self.self_condition:
            # concatenate the previous prediction (or zeros) along channels
            self_cond = default(self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((self_cond, x), dim = 1)
        else:
            assert not exists(self_cond)
        time_emb = self.to_time_emb(time)
        # class labels must be supplied exactly when the model was built with num_classes
        assert xnor(exists(class_labels), self.needs_class_labels)
        if self.needs_class_labels:
            if class_labels.dtype in (torch.int, torch.long):
                class_labels = F.one_hot(class_labels, self.num_classes)
            assert class_labels.shape[-1] == self.num_classes
            # scale one-hot vectors by sqrt(num_classes) (magnitude preservation)
            class_labels = class_labels.float() * sqrt(self.num_classes)
            class_emb = self.to_class_emb(class_labels)
            time_emb = self.add_class_emb(time_emb, class_emb)
        emb = self.emb_activation(time_emb)
        # every encoder output (plus the input block output) is pushed onto the skip stack
        skips = []
        x = self.input_block(x)
        skips.append(x)
        for encoder in self.downs:
            x = encoder(x, emb = emb)
            skips.append(x)
        for decoder in self.mids:
            x = decoder(x, emb = emb)
        for decoder in self.ups:
            # non-upsampling decoders pop a skip and merge it by magnitude-preserving concat
            if decoder.needs_skip:
                skip = skips.pop()
                x = self.skip_mp_cat(x, skip)
            x = decoder(x, emb = emb)
        return self.output_block(x)
class MPFeedForward(Module):
    """Magnitude-preserving feedforward block: pixel norm, 1x1 expand, MPSiLU,
    1x1 contract, joined to the input by a magnitude-preserving residual add."""

    def __init__(
        self,
        *,
        dim,
        mult = 4,
        mp_add_t = 0.3
    ):
        super().__init__()
        hidden = int(dim * mult)
        self.net = nn.Sequential(
            PixelNorm(dim = 1),
            Conv2d(dim, hidden, 1),
            MPSiLU(),
            Conv2d(hidden, dim, 1)
        )
        self.mp_add = MPAdd(t = mp_add_t)

    def forward(self, x):
        return self.mp_add(self.net(x), x)
class MPImageTransformer(Module):
    """A stack of `depth` magnitude-preserving attention + feedforward pairs
    operating on image feature maps."""

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_mem_kv = 4,
        ff_mult = 4,
        attn_flash = False,
        residual_mp_add_t = 0.3
    ):
        super().__init__()
        self.layers = ModuleList([
            ModuleList([
                Attention(dim = dim, heads = heads, dim_head = dim_head, num_mem_kv = num_mem_kv, flash = attn_flash, mp_add_t = residual_mp_add_t),
                MPFeedForward(dim = dim, mult = ff_mult, mp_add_t = residual_mp_add_t)
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        for attn, ff in self.layers:
            x = ff(attn(x))
        return x
def InvSqrtDecayLRSched(
    optimizer,
    t_ref = 70000,
    sigma_ref = 0.01
):
    """Inverse-square-root learning-rate decay (equation 67 and Table 1 of the paper):
    constant at sigma_ref until step t_ref, then decaying as 1 / sqrt(t / t_ref)."""
    def schedule(step: int):
        return sigma_ref / sqrt(max(step / t_ref, 1.))
    return LambdaLR(optimizer, lr_lambda = schedule)
if __name__ == '__main__':
    # smoke test: class-conditional denoising pass at 64x64 with 4 channels
    unet = KarrasUnet(
        image_size = 64,
        dim = 192,
        dim_max = 768,
        num_classes = 1000,
    )
    images = torch.randn(2, 4, 64, 64)
    denoised_images = unet(
        images,
        time = torch.ones(2,),
        class_labels = torch.randint(0, 1000, (2,))
    )
    # the unet predicts in image space: output shape must match input shape
    assert denoised_images.shape == images.shape