Lucidrains Series Projects: Source Code Analysis (Part 45)
.\lucidrains\imagen-pytorch\imagen_pytorch\imagen_video.py
import math
import operator
import functools
from tqdm.auto import tqdm
from functools import partial, wraps
from pathlib import Path
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
def exists(val):
return val is not None
def identity(t, *args, **kwargs):
return t
def first(arr, d = None):
if len(arr) == 0:
return d
return arr[0]
def divisible_by(numer, denom):
return (numer % denom) == 0
def maybe(fn):
@wraps(fn)
def inner(x):
if not exists(x):
return x
return fn(x)
return inner
def once(fn):
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)
return inner
print_once = once(print)
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
def cast_tuple(val, length = None):
if isinstance(val, list):
val = tuple(val)
output = val if isinstance(val, tuple) else ((val,) * default(length, 1))
if exists(length):
assert len(output) == length
return output
def cast_uint8_images_to_float(images):
if not images.dtype == torch.uint8:
return images
return images / 255
def module_device(module):
return next(module.parameters()).device
def zero_init_(m):
nn.init.zeros_(m.weight)
if exists(m.bias):
nn.init.zeros_(m.bias)
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
def pad_tuple_to_length(t, length, fillvalue = None):
remain_length = length - len(t)
if remain_length <= 0:
return t
return (*t, *((fillvalue,) * remain_length))
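Before moving on to the modules, here is a small illustrative snippet of how these helpers behave (a minimal sketch, assuming the imagen-pytorch package is installed so `imagen_pytorch.imagen_video` is importable; the values are arbitrary):

```python
# Illustrative only: exercising the helper functions defined above.
from imagen_pytorch.imagen_video import default, cast_tuple, maybe, pad_tuple_to_length

print(default(None, 3))                # 3 - falls back to the default
print(default(None, lambda: 7))        # 7 - callable defaults are evaluated lazily
print(cast_tuple(2, 4))                # (2, 2, 2, 2) - scalars are broadcast to a tuple
print(cast_tuple((1, 2), 2))           # (1, 2)       - tuples pass through, length asserted
print(pad_tuple_to_length((1,), 3))    # (1, None, None)

double = maybe(lambda t: t * 2)        # the wrapped fn is skipped entirely when given None
print(double(5), double(None))         # 10 None
```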
class Identity(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, x, *args, **kwargs):
return x
def Sequential(*modules):
return nn.Sequential(*filter(exists, modules))
def log(t, eps: float = 1e-12):
return torch.log(t.clamp(min = eps))
def l2norm(t):
return F.normalize(t, dim = -1)
def right_pad_dims_to(x, t):
padding_dims = x.ndim - t.ndim
if padding_dims <= 0:
return t
return t.view(*t.shape, *((1,) * padding_dims))
def masked_mean(t, *, dim, mask = None):
if not exists(mask):
return t.mean(dim = dim)
denom = mask.sum(dim = dim, keepdim = True)
mask = rearrange(mask, 'b n -> b n 1')
masked_t = t.masked_fill(~mask, 0.)
return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)
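masked_mean averages a token sequence while ignoring padded positions; it is what pools the T5 text embeddings later on. A small sketch of its behaviour (dimensions are arbitrary):

```python
import torch
from imagen_pytorch.imagen_video import masked_mean

t = torch.arange(12, dtype = torch.float32).reshape(2, 3, 2)   # (batch, seq, dim)
mask = torch.tensor([[True, True, False],
                     [True, False, False]])                    # which tokens are valid

print(masked_mean(t, dim = 1, mask = mask))
# tensor([[1., 2.],
#         [6., 7.]])  - only the unmasked tokens contribute to the mean
```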
def resize_video_to(
video,
target_image_size,
target_frames = None,
clamp_range = None,
mode = 'nearest'
):
orig_video_size = video.shape[-1]
frames = video.shape[2]
target_frames = default(target_frames, frames)
target_shape = (target_frames, target_image_size, target_image_size)
if tuple(video.shape[-3:]) == target_shape:
return video
out = F.interpolate(video, target_shape, mode = mode)
if exists(clamp_range):
out = out.clamp(*clamp_range)
return out
def scale_video_time(
video,
downsample_scale = 1,
mode = 'nearest'
):
if downsample_scale == 1:
return video
image_size, frames = video.shape[-1], video.shape[-3]
assert divisible_by(frames, downsample_scale), f'trying to temporally downsample a conditioning video frames of length {frames} by {downsample_scale}, however it is not neatly divisible'
target_frames = frames // downsample_scale
resized_video = resize_video_to(
video,
image_size,
target_frames = target_frames,
mode = mode
)
return resized_video
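resize_video_to rescales a video's spatial size (and optionally its frame count) with F.interpolate, and scale_video_time uses it to temporally downsample a conditioning video. A quick shape check with arbitrary sizes:

```python
import torch
from imagen_pytorch.imagen_video import resize_video_to, scale_video_time

video = torch.rand(1, 3, 8, 32, 32)   # (batch, channels, frames, height, width)

print(resize_video_to(video, 64).shape)                       # torch.Size([1, 3, 8, 64, 64])
print(scale_video_time(video, downsample_scale = 2).shape)    # torch.Size([1, 3, 4, 32, 32])
```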
def prob_mask_like(shape, prob, device):
if prob == 1:
return torch.ones(shape, device=device, dtype=torch.bool)
elif prob == 0:
return torch.zeros(shape, device=device, dtype=torch.bool)
else:
return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
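prob_mask_like draws a boolean keep/drop decision per batch element; in this codebase such masks are typically used for conditioning dropout (classifier-free guidance). A minimal sketch:

```python
import torch
from imagen_pytorch.imagen_video import prob_mask_like

# True = keep the conditioning for that sample, False = drop it
keep_mask = prob_mask_like((8,), 1. - 0.1, device = torch.device('cpu'))
print(keep_mask.shape, keep_mask.dtype)   # torch.Size([8]) torch.bool
```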
class LayerNorm(nn.Module):
def __init__(self, dim, stable=False):
super().__init__()
self.stable = stable
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x):
if self.stable:
x = x / x.amax(dim=-1, keepdim=True).detach()
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
var = torch.var(x, dim=-1, unbiased=False, keepdim=True)
mean = torch.mean(x, dim=-1, keepdim=True)
return (x - mean) * (var + eps).rsqrt() * self.g
class ChanLayerNorm(nn.Module):
def __init__(self, dim, stable=False):
super().__init__()
self.stable = stable
self.g = nn.Parameter(torch.ones(1, dim, 1, 1, 1))
def forward(self, x):
if self.stable:
x = x / x.amax(dim=1, keepdim=True).detach()
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
var = torch.var(x, dim=1, unbiased=False, keepdim=True)
mean = torch.mean(x, dim=1, keepdim=True)
return (x - mean) * (var + eps).rsqrt() * self.g
class Always():
def __init__(self, val):
self.val = val
def __call__(self, *args, **kwargs):
return self.val
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class Parallel(nn.Module):
def __init__(self, *fns):
super().__init__()
self.fns = nn.ModuleList(fns)
def forward(self, x):
outputs = [fn(x) for fn in self.fns]
return sum(outputs)
class RearrangeTimeCentric(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
x = rearrange(x, 'b c f ... -> b ... f c')
x, ps = pack([x], '* f c')
x = self.fn(x)
x, = unpack(x, ps, '* f c')
x = rearrange(x, 'b ... f c -> b c f ...')
return x
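RearrangeTimeCentric flattens the batch and spatial axes so that the wrapped function only sees (batch * space, frames, channels), which is convenient for purely temporal layers. With nn.Identity standing in for such a layer, the shape round-trips:

```python
import torch
from imagen_pytorch.imagen_video import RearrangeTimeCentric

temporal_fn = RearrangeTimeCentric(torch.nn.Identity())   # placeholder for a temporal module
x = torch.randn(2, 16, 4, 8, 8)                           # (b, c, f, h, w)
print(temporal_fn(x).shape)                               # torch.Size([2, 16, 4, 8, 8])
```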
class PerceiverAttention(nn.Module):
def __init__(
self,
*,
dim,
dim_head=64,
heads=8,
scale=8
):
super().__init__()
self.scale = scale
self.heads = heads
inner_dim = dim_head * heads
self.norm = nn.LayerNorm(dim)
self.norm_latents = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.q_scale = nn.Parameter(torch.ones(dim_head))
self.k_scale = nn.Parameter(torch.ones(dim_head))
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias=False),
nn.LayerNorm(dim)
)
def forward(self, x, latents, mask = None):
x = self.norm(x)
latents = self.norm_latents(latents)
b, h = x.shape[0], self.heads
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim = -2)
k, v = self.to_kv(kv_input).chunk(2, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
q, k = map(l2norm, (q, k))
q = q * self.q_scale
k = k * self.k_scale
sim = einsum('... i d, ... j d -> ... i j', q, k) * self.scale
if exists(mask):
max_neg_value = -torch.finfo(sim.dtype).max
mask = F.pad(mask, (0, latents.shape[-2]), value = True)
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
attn = sim.softmax(dim = -1)
out = einsum('... i j, ... j d -> ... i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)', h = h)
return self.to_out(out)
class PerceiverResampler(nn.Module):
def __init__(
self,
*,
dim,
depth,
dim_head = 64,
heads = 8,
num_latents = 64,
num_latents_mean_pooled = 4,
max_seq_len = 512,
ff_mult = 4
):
super().__init__()
self.pos_emb = nn.Embedding(max_seq_len, dim)
self.latents = nn.Parameter(torch.randn(num_latents, dim))
self.to_latents_from_mean_pooled_seq = None
if num_latents_mean_pooled > 0:
self.to_latents_from_mean_pooled_seq = nn.Sequential(
LayerNorm(dim),
nn.Linear(dim, dim * num_latents_mean_pooled),
Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
FeedForward(dim = dim, mult = ff_mult)
]))
def forward(self, x, mask = None):
n, device = x.shape[1], x.device
pos_emb = self.pos_emb(torch.arange(n, device = device))
x_with_pos = x + pos_emb
latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0])
if exists(self.to_latents_from_mean_pooled_seq):
meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool))
meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
latents = torch.cat((meanpooled_latents, latents), dim = -2)
for attn, ff in self.layers:
latents = attn(x_with_pos, latents, mask = mask) + latents
latents = ff(latents) + latents
return latents
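The PerceiverResampler pools a variable-length sequence of text embeddings into a fixed set of latents, optionally prepending a few latents derived from the mean-pooled sequence. A small usage sketch (dimensions chosen arbitrarily; the default of 4 mean-pooled latents applies):

```python
import torch
from imagen_pytorch.imagen_video import PerceiverResampler

resampler = PerceiverResampler(dim = 32, depth = 2, dim_head = 16, heads = 4, num_latents = 8)

text_embeds = torch.randn(2, 77, 32)               # (batch, tokens, dim)
mask = torch.ones(2, 77, dtype = torch.bool)

latents = resampler(text_embeds, mask = mask)
print(latents.shape)   # torch.Size([2, 12, 32]) - 8 learned latents + 4 mean-pooled latents
```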
class Conv3d(nn.Module):
def __init__(
self,
dim,
dim_out = None,
kernel_size = 3,
*,
temporal_kernel_size = None,
**kwargs
):
super().__init__()
dim_out = default(dim_out, dim)
temporal_kernel_size = default(temporal_kernel_size, kernel_size)
self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size = kernel_size, padding = kernel_size // 2)
self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size = temporal_kernel_size) if kernel_size > 1 else None
self.kernel_size = kernel_size
if exists(self.temporal_conv):
nn.init.dirac_(self.temporal_conv.weight.data)
nn.init.zeros_(self.temporal_conv.bias.data)
def forward(
self,
x,
ignore_time = False
):
b, c, *_, h, w = x.shape
is_video = x.ndim == 5
ignore_time &= is_video
if is_video:
x = rearrange(x, 'b c f h w -> (b f) c h w')
x = self.spatial_conv(x)
if is_video:
x = rearrange(x, '(b f) c h w -> b c f h w', b = b)
if ignore_time or not exists(self.temporal_conv):
return x
x = rearrange(x, 'b c f h w -> (b h w) c f')
if self.kernel_size > 1:
x = F.pad(x, (self.kernel_size - 1, 0))
x = self.temporal_conv(x)
x = rearrange(x, '(b h w) c f -> b c f h w', h = h, w = w)
return x
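Conv3d factorizes a full 3D convolution into a per-frame 2D spatial convolution followed by a causally padded 1D convolution over the frame axis (the temporal weights are dirac-initialized and the bias zeroed). A quick sketch of how it is driven:

```python
import torch
from imagen_pytorch.imagen_video import Conv3d

conv = Conv3d(8, 16, kernel_size = 3)           # 3x3 spatial conv + temporal conv over frames

video = torch.randn(1, 8, 6, 32, 32)            # (batch, channels, frames, height, width)
print(conv(video).shape)                        # torch.Size([1, 16, 6, 32, 32])
print(conv(video, ignore_time = True).shape)    # temporal conv skipped, same shape
```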
class Attention(nn.Module):
def __init__(
self,
dim,
*,
dim_head = 64,
heads = 8,
causal = False,
context_dim = None,
rel_pos_bias = False,
rel_pos_bias_mlp_depth = 2,
init_zero = False,
scale = 8
):
super().__init__()
self.scale = scale
self.causal = causal
self.rel_pos_bias = DynamicPositionBias(dim = dim, heads = heads, depth = rel_pos_bias_mlp_depth) if rel_pos_bias else None
self.heads = heads
inner_dim = dim_head * heads
self.norm = LayerNorm(dim)
self.null_attn_bias = nn.Parameter(torch.randn(heads))
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
self.q_scale = nn.Parameter(torch.ones(dim_head))
self.k_scale = nn.Parameter(torch.ones(dim_head))
self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False),
LayerNorm(dim)
)
if init_zero:
nn.init.zeros_(self.to_out[-1].g)
def forward(
self,
x,
context = None,
mask = None,
attn_bias = None
):
b, n, device = *x.shape[:2], x.device
x = self.norm(x)
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
if exists(context):
assert exists(self.to_context)
ck, cv = self.to_context(context).chunk(2, dim = -1)
k = torch.cat((ck, k), dim = -2)
v = torch.cat((cv, v), dim = -2)
q, k = map(l2norm, (q, k))
q = q * self.q_scale
k = k * self.k_scale
sim = einsum('b h i d, b j d -> b h i j', q, k) * self.scale
if not exists(attn_bias) and exists(self.rel_pos_bias):
attn_bias = self.rel_pos_bias(n, device = device, dtype = q.dtype)
if exists(attn_bias):
null_attn_bias = repeat(self.null_attn_bias, 'h -> h n 1', n = n)
attn_bias = torch.cat((null_attn_bias, attn_bias), dim = -1)
sim = sim + attn_bias
max_neg_value = -torch.finfo(sim.dtype).max
if self.causal:
i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, max_neg_value)
if exists(mask):
mask = F.pad(mask, (1, 0), value = True)
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
def Conv2d(dim_in, dim_out, kernel, stride = 1, padding = 0, **kwargs):
kernel = cast_tuple(kernel, 2)
stride = cast_tuple(stride, 2)
padding = cast_tuple(padding, 2)
if len(kernel) == 2:
kernel = (1, *kernel)
if len(stride) == 2:
stride = (1, *stride)
if len(padding) == 2:
padding = (0, *padding)
return nn.Conv3d(dim_in, dim_out, kernel, stride = stride, padding = padding, **kwargs)
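Despite its name, this Conv2d factory returns an nn.Conv3d whose kernel, stride and padding are fixed to 1/1/0 along the frame axis, so it behaves purely spatially on video tensors:

```python
import torch
from imagen_pytorch.imagen_video import Conv2d

conv = Conv2d(8, 16, 3, padding = 1)
print(isinstance(conv, torch.nn.Conv3d))    # True - the time kernel is 1, frames are untouched

video = torch.randn(2, 8, 4, 16, 16)
print(conv(video).shape)                    # torch.Size([2, 16, 4, 16, 16])
```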
class Pad(nn.Module):
def __init__(self, padding, value = 0.):
super().__init__()
self.padding = padding
self.value = value
def forward(self, x):
return F.pad(x, self.padding, value = self.value)
def Upsample(dim, dim_out = None):
dim_out = default(dim_out, dim)
return nn.Sequential(
nn.Upsample(scale_factor = 2, mode = 'nearest'),
Conv2d(dim, dim_out, 3, padding = 1)
)
class PixelShuffleUpsample(nn.Module):
def __init__(self, dim, dim_out = None):
super().__init__()
dim_out = default(dim_out, dim)
conv = Conv2d(dim, dim_out * 4, 1)
self.net = nn.Sequential(
conv,
nn.SiLU()
)
self.pixel_shuffle = nn.PixelShuffle(2)
self.init_conv_(conv)
def init_conv_(self, conv):
o, i, f, h, w = conv.weight.shape
conv_weight = torch.empty(o // 4, i, f, h, w)
nn.init.kaiming_uniform_(conv_weight)
conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def forward(self, x):
out = self.net(x)
frames = x.shape[2]
out = rearrange(out, 'b c f h w -> (b f) c h w')
out = self.pixel_shuffle(out)
return rearrange(out, '(b f) c h w -> b c f h w', f = frames)
def Downsample(dim, dim_out = None):
dim_out = default(dim_out, dim)
return nn.Sequential(
Rearrange('b c f (h p1) (w p2) -> b (c p1 p2) f h w', p1 = 2, p2 = 2),
Conv2d(dim * 4, dim_out, 1)
)
class TemporalPixelShuffleUpsample(nn.Module):
def __init__(self, dim, dim_out = None, stride = 2):
super().__init__()
self.stride = stride
dim_out = default(dim_out, dim)
conv = nn.Conv1d(dim, dim_out * stride, 1)
self.net = nn.Sequential(
conv,
nn.SiLU()
)
self.pixel_shuffle = Rearrange('b (c r) n -> b c (n r)', r = stride)
self.init_conv_(conv)
def init_conv_(self, conv):
o, i, f = conv.weight.shape
conv_weight = torch.empty(o // self.stride, i, f)
nn.init.kaiming_uniform_(conv_weight)
conv_weight = repeat(conv_weight, 'o ... -> (o r) ...', r = self.stride)
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def forward(self, x):
b, c, f, h, w = x.shape
x = rearrange(x, 'b c f h w -> (b h w) c f')
out = self.net(x)
out = self.pixel_shuffle(out)
return rearrange(out, '(b h w) c f -> b c f h w', h = h, w = w)
def TemporalDownsample(dim, dim_out = None, stride = 2):
dim_out = default(dim_out, dim)
return nn.Sequential(
Rearrange('b c (f p) h w -> b (c p) f h w', p = stride),
Conv2d(dim * stride, dim_out, 1)
)
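Downsample halves height and width by folding 2x2 spatial patches into channels and projecting back down, while TemporalDownsample applies the same trick along the frame axis. A shape check with arbitrary (even) sizes:

```python
import torch
from imagen_pytorch.imagen_video import Downsample, TemporalDownsample

x = torch.randn(1, 8, 4, 16, 16)    # (batch, channels, frames, height, width)

print(Downsample(8)(x).shape)             # torch.Size([1, 8, 4, 8, 8])   - space halved
print(TemporalDownsample(8)(x).shape)     # torch.Size([1, 8, 2, 16, 16]) - frames halved
```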
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
return torch.cat((emb.sin(), emb.cos()), dim = -1)
class LearnedSinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
assert (dim % 2) == 0
half_dim = dim // 2
self.weights = nn.Parameter(torch.randn(half_dim))
def forward(self, x):
x = rearrange(x, 'b -> b 1')
freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
fouriered = torch.cat((x, fouriered), dim = -1)
return fouriered
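Both positional embedding modules map a 1D batch of diffusion timesteps to an embedding; the learned variant additionally concatenates the raw timestep, so its output is one dimension wider. A quick sketch:

```python
import torch
from imagen_pytorch.imagen_video import SinusoidalPosEmb, LearnedSinusoidalPosEmb

times = torch.rand(4)                               # one timestep per sample

print(SinusoidalPosEmb(16)(times).shape)            # torch.Size([4, 16])
print(LearnedSinusoidalPosEmb(16)(times).shape)     # torch.Size([4, 17]) - raw time concatenated
```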
class Block(nn.Module):
def __init__(
self,
dim,
dim_out,
groups = 8,
norm = True
):
super().__init__()
self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()
self.activation = nn.SiLU()
self.project = Conv3d(dim, dim_out, 3, padding = 1)
def forward(
self,
x,
scale_shift = None,
ignore_time = False
):
x = self.groupnorm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.activation(x)
return self.project(x, ignore_time = ignore_time)
class ResnetBlock(nn.Module):
def __init__(
self,
dim,
dim_out,
*,
cond_dim = None,
time_cond_dim = None,
groups = 8,
linear_attn = False,
use_gca = False,
squeeze_excite = False,
**attn_kwargs
):
super().__init__()
self.time_mlp = None
if exists(time_cond_dim):
self.time_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_cond_dim, dim_out * 2)
)
self.cross_attn = None
if exists(cond_dim):
attn_klass = CrossAttention if not linear_attn else LinearCrossAttention
self.cross_attn = attn_klass(
dim = dim_out,
context_dim = cond_dim,
**attn_kwargs
)
self.block1 = Block(dim, dim_out, groups = groups)
self.block2 = Block(dim_out, dim_out, groups = groups)
self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)
self.res_conv = Conv2d(dim, dim_out, 1) if dim != dim_out else Identity()
def forward(
self,
x,
time_emb = None,
cond = None,
ignore_time = False
):
scale_shift = None
if exists(self.time_mlp) and exists(time_emb):
time_emb = self.time_mlp(time_emb)
time_emb = rearrange(time_emb, 'b c -> b c 1 1 1')
scale_shift = time_emb.chunk(2, dim = 1)
h = self.block1(x, ignore_time = ignore_time)
if exists(self.cross_attn):
assert exists(cond)
h = rearrange(h, 'b c ... -> b ... c')
h, ps = pack([h], 'b * c')
h = self.cross_attn(h, context = cond) + h
h, = unpack(h, ps, 'b * c')
h = rearrange(h, 'b ... c -> b c ...')
h = self.block2(h, scale_shift = scale_shift, ignore_time = ignore_time)
h = h * self.gca(h)
return h + self.res_conv(x)
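ResnetBlock stacks two GroupNorm/SiLU/Conv3d blocks, injects time conditioning as a FiLM-style scale/shift, and optionally cross-attends over text tokens between the two blocks. An illustrative forward pass (all sizes arbitrary, assuming the package is importable):

```python
import torch
from imagen_pytorch.imagen_video import ResnetBlock

block = ResnetBlock(8, 16, time_cond_dim = 32, cond_dim = 24)

x = torch.randn(1, 8, 4, 16, 16)    # video feature map (b, c, f, h, w)
time_emb = torch.randn(1, 32)       # time conditioning
cond = torch.randn(1, 6, 24)        # text token conditioning for the cross attention

print(block(x, time_emb = time_emb, cond = cond).shape)   # torch.Size([1, 16, 4, 16, 16])
```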
class CrossAttention(nn.Module):
def __init__(
self,
dim,
*,
context_dim = None,
dim_head = 64,
heads = 8,
norm_context = False,
scale = 8
):
super().__init__()
self.scale = scale
self.heads = heads
inner_dim = dim_head * heads
context_dim = default(context_dim, dim)
self.norm = LayerNorm(dim)
self.norm_context = LayerNorm(context_dim) if norm_context else Identity()
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
self.q_scale = nn.Parameter(torch.ones(dim_head))
self.k_scale = nn.Parameter(torch.ones(dim_head))
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False),
LayerNorm(dim)
)
def forward(self, x, context, mask = None):
b, n, device = *x.shape[:2], x.device
x = self.norm(x)
context = self.norm_context(context)
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
q, k = map(l2norm, (q, k))
q = q * self.q_scale
k = k * self.k_scale
sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
max_neg_value = -torch.finfo(sim.dtype).max
if exists(mask):
mask = F.pad(mask, (1, 0), value = True)
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
attn = sim.softmax(dim = -1, dtype = torch.float32)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class LinearCrossAttention(CrossAttention):
def forward(self, x, context, mask = None):
b, n, device = *x.shape[:2], x.device
x = self.norm(x)
context = self.norm_context(context)
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads), (q, k, v))
nk, nv = map(lambda t: repeat(t, 'd -> (b h) 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
max_neg_value = -torch.finfo(x.dtype).max
if exists(mask):
mask = F.pad(mask, (1, 0), value = True)
mask = rearrange(mask, 'b n -> b n 1')
k = k.masked_fill(~mask, max_neg_value)
v = v.masked_fill(~mask, 0.)
q = q.softmax(dim = -1)
k = k.softmax(dim = -2)
q = q * self.scale
context = einsum('b n d, b n e -> b d e', k, v)
out = einsum('b n d, b d e -> b n e', q, context)
out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
return self.to_out(out)
class LinearAttention(nn.Module):
def __init__(
self,
dim,
dim_head = 32,
heads = 8,
dropout = 0.05,
context_dim = None,
**kwargs
):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = dim_head * heads
self.norm = ChanLayerNorm(dim)
self.nonlin = nn.SiLU()
self.to_q = nn.Sequential(
nn.Dropout(dropout),
Conv2d(dim, inner_dim, 1, bias = False),
Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
)
self.to_k = nn.Sequential(
nn.Dropout(dropout),
Conv2d(dim, inner_dim, 1, bias = False),
Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
)
self.to_v = nn.Sequential(
nn.Dropout(dropout),
Conv2d(dim, inner_dim, 1, bias = False),
Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
)
self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None
self.to_out = nn.Sequential(
Conv2d(inner_dim, dim, 1, bias = False),
ChanLayerNorm(dim)
)
def forward(self, fmap, context = None):
h, x, y = self.heads, *fmap.shape[-2:]
fmap = self.norm(fmap)
q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v))
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))
if exists(context):
assert exists(self.to_context)
ck, cv = self.to_context(context).chunk(2, dim = -1)
ck, cv = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (ck, cv))
k = torch.cat((k, ck), dim = -2)
v = torch.cat((v, cv), dim = -2)
q = q.softmax(dim = -1)
k = k.softmax(dim = -2)
q = q * self.scale
context = einsum('b n d, b n e -> b d e', k, v)
out = einsum('b n d, b d e -> b n e', q, context)
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
out = self.nonlin(out)
return self.to_out(out)
class GlobalContext(nn.Module):
""" basically a superior form of squeeze-excitation that is attention-esque """
def __init__(
self,
*,
dim_in,
dim_out
):
super().__init__()
self.to_k = Conv2d(dim_in, 1, 1)
hidden_dim = max(3, dim_out // 2)
self.net = nn.Sequential(
Conv2d(dim_in, hidden_dim, 1),
nn.SiLU(),
Conv2d(hidden_dim, dim_out, 1),
nn.Sigmoid()
)
def forward(self, x):
context = self.to_k(x)
x, context = map(lambda t: rearrange(t, 'b n ... -> b n (...)'), (x, context))
out = einsum('b i n, b c n -> b c i', context.softmax(dim = -1), x)
out = rearrange(out, '... -> ... 1 1')
return self.net(out)
def FeedForward(dim, mult = 2):
hidden_dim = int(dim * mult)
return nn.Sequential(
LayerNorm(dim),
nn.Linear(dim, hidden_dim, bias = False),
nn.GELU(),
LayerNorm(hidden_dim),
nn.Linear(hidden_dim, dim, bias = False)
)
class TimeTokenShift(nn.Module):
def forward(self, x):
if x.ndim != 5:
return x
x, x_shift = x.chunk(2, dim = 1)
x_shift = F.pad(x_shift, (0, 0, 0, 0, 1, -1), value = 0.)
return torch.cat((x, x_shift), dim = 1)
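TimeTokenShift shifts half of the channels one frame into the future (zero-padding the first frame), a cheap way to leak temporal information into the channel feed-forward. A tiny sketch makes the shift visible:

```python
import torch
from imagen_pytorch.imagen_video import TimeTokenShift

x = torch.arange(8, dtype = torch.float32).reshape(1, 2, 4, 1, 1)   # (b, c, f, h, w)
out = TimeTokenShift()(x)

print(out[0, 0, :, 0, 0])   # tensor([0., 1., 2., 3.]) - first half of channels untouched
print(out[0, 1, :, 0, 0])   # tensor([0., 4., 5., 6.]) - second half shifted by one frame
```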
def ChanFeedForward(dim, mult = 2, time_token_shift = True):
hidden_dim = int(dim * mult)
return Sequential(
ChanLayerNorm(dim),
Conv2d(dim, hidden_dim, 1, bias = False),
nn.GELU(),
TimeTokenShift() if time_token_shift else None,
ChanLayerNorm(hidden_dim),
Conv2d(hidden_dim, dim, 1, bias = False)
)
class TransformerBlock(nn.Module):
def __init__(
self,
dim,
*,
depth = 1,
heads = 8,
dim_head = 32,
ff_mult = 2,
ff_time_token_shift = True,
context_dim = None
):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
ChanFeedForward(dim = dim, mult = ff_mult, time_token_shift = ff_time_token_shift)
]))
def forward(self, x, context = None):
for attn, ff in self.layers:
x = rearrange(x, 'b c ... -> b ... c')
x, ps = pack([x], 'b * c')
x = attn(x, context = context) + x
x, = unpack(x, ps, 'b * c')
x = rearrange(x, 'b ... c -> b c ...')
x = ff(x) + x
return x
class LinearAttentionTransformerBlock(nn.Module):
def __init__(
self,
dim,
*,
depth = 1,
heads = 8,
dim_head = 32,
ff_mult = 2,
ff_time_token_shift = True,
context_dim = None,
**kwargs
):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
ChanFeedForward(dim = dim, mult = ff_mult, time_token_shift = ff_time_token_shift)
]))
def forward(self, x, context = None):
for attn, ff in self.layers:
x = attn(x, context = context) + x
x = ff(x) + x
return x
class CrossEmbedLayer(nn.Module):
def __init__(
self,
dim_in,
kernel_sizes,
dim_out = None,
stride = 2
):
super().__init__()
assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
dim_out = default(dim_out, dim_in)
kernel_sizes = sorted(kernel_sizes)
num_scales = len(kernel_sizes)
dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
self.convs = nn.ModuleList([])
for kernel, dim_scale in zip(kernel_sizes, dim_scales):
self.convs.append(Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))
def forward(self, x):
fmaps = tuple(map(lambda conv: conv(x), self.convs))
return torch.cat(fmaps, dim = 1)
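CrossEmbedLayer runs several convolutions with different kernel sizes at the same stride and concatenates their outputs over the channel dimension (an inception-style stem). A shape sketch using the downsampling kernel sizes that appear elsewhere in this file:

```python
import torch
from imagen_pytorch.imagen_video import CrossEmbedLayer

embed = CrossEmbedLayer(8, kernel_sizes = (2, 4), dim_out = 16, stride = 2)

x = torch.randn(1, 8, 4, 16, 16)
print(embed(x).shape)   # torch.Size([1, 16, 4, 8, 8]) - multi-scale features concatenated
```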
class UpsampleCombiner(nn.Module):
def __init__(
self,
dim,
*,
enabled = False,
dim_ins = tuple(),
dim_outs = tuple()
):
super().__init__()
dim_outs = cast_tuple(dim_outs, len(dim_ins))
assert len(dim_ins) == len(dim_outs)
self.enabled = enabled
if not self.enabled:
self.dim_out = dim
return
self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)
def forward(self, x, fmaps = None):
target_size = x.shape[-1]
fmaps = default(fmaps, tuple())
if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
return x
fmaps = [resize_video_to(fmap, target_size) for fmap in fmaps]
outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
return torch.cat((x, *outs), dim = 1)
class DynamicPositionBias(nn.Module):
def __init__(
self,
dim,
*,
heads,
depth
):
super().__init__()
self.mlp = nn.ModuleList([])
self.mlp.append(nn.Sequential(
nn.Linear(1, dim),
LayerNorm(dim),
nn.SiLU()
))
for _ in range(max(depth - 1, 0)):
self.mlp.append(nn.Sequential(
nn.Linear(dim, dim),
LayerNorm(dim),
nn.SiLU()
))
self.mlp.append(nn.Linear(dim, heads))
def forward(self, n, device, dtype):
i = torch.arange(n, device = device)
j = torch.arange(n, device = device)
indices = rearrange(i, 'i -> i 1') - rearrange(j, 'j -> 1 j')
indices += (n - 1)
pos = torch.arange(-n + 1, n, device = device, dtype = dtype)
pos = rearrange(pos, '... -> ... 1')
for layer in self.mlp:
pos = layer(pos)
bias = pos[indices]
bias = rearrange(bias, 'i j h -> h i j')
return bias
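DynamicPositionBias feeds relative frame offsets through a small MLP to produce a per-head additive attention bias; Attention consumes it when rel_pos_bias = True. A shape sketch:

```python
import torch
from imagen_pytorch.imagen_video import DynamicPositionBias

rel_pos = DynamicPositionBias(dim = 32, heads = 8, depth = 2)

bias = rel_pos(6, device = torch.device('cpu'), dtype = torch.float32)
print(bias.shape)   # torch.Size([8, 6, 6]) - one bias matrix per attention head
```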
class Unet3D(nn.Module):
def __init__(
self,
*,
dim,
text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME),
num_resnet_blocks = 1,
cond_dim = None,
num_image_tokens = 4,
num_time_tokens = 2,
learned_sinu_pos_emb_dim = 16,
out_dim = None,
dim_mults = (1, 2, 4, 8),
temporal_strides = 1,
cond_images_channels = 0,
channels = 3,
channels_out = None,
attn_dim_head = 64,
attn_heads = 8,
ff_mult = 2.,
ff_time_token_shift = True,
lowres_cond = False,
layer_attns = False,
layer_attns_depth = 1,
layer_attns_add_text_cond = True,
attend_at_middle = True,
time_rel_pos_bias_depth = 2,
time_causal_attn = True,
layer_cross_attns = True,
use_linear_attn = False,
use_linear_cross_attn = False,
cond_on_text = True,
max_text_len = 256,
init_dim = None,
resnet_groups = 8,
init_conv_kernel_size = 7,
init_cross_embed = True,
init_cross_embed_kernel_sizes = (3, 7, 15),
cross_embed_downsample = False,
cross_embed_downsample_kernel_sizes = (2, 4),
attn_pool_text = True,
attn_pool_num_latents = 32,
dropout = 0.,
memory_efficient = False,
init_conv_to_final_conv_residual = False,
use_global_context_attn = True,
scale_skip_connection = True,
final_resnet_block = True,
final_conv_kernel_size = 3,
self_cond = False,
combine_upsample_fmaps = False,
pixel_shuffle_upsample = True,
resize_mode = 'nearest'
def cast_model_parameters(
self,
*,
lowres_cond,
text_embed_dim,
channels,
channels_out,
cond_on_text
):
if lowres_cond == self.lowres_cond and \
channels == self.channels and \
cond_on_text == self.cond_on_text and \
text_embed_dim == self._locals['text_embed_dim'] and \
channels_out == self.channels_out:
return self
updated_kwargs = dict(
lowres_cond = lowres_cond,
text_embed_dim = text_embed_dim,
channels = channels,
channels_out = channels_out,
cond_on_text = cond_on_text
)
return self.__class__(**{**self._locals, **updated_kwargs})
def to_config_and_state_dict(self):
return self._locals, self.state_dict()
@classmethod
def from_config_and_state_dict(klass, config, state_dict):
unet = klass(**config)
unet.load_state_dict(state_dict)
return unet
def persist_to_file(self, path):
path = Path(path)
path.parents[0].mkdir(exist_ok = True, parents = True)
config, state_dict = self.to_config_and_state_dict()
pkg = dict(config = config, state_dict = state_dict)
torch.save(pkg, str(path))
@classmethod
def hydrate_from_file(klass, path):
path = Path(path)
assert path.exists()
pkg = torch.load(str(path))
assert 'config' in pkg and 'state_dict' in pkg
config, state_dict = pkg['config'], pkg['state_dict']
return klass.from_config_and_state_dict(config, state_dict)
def forward_with_cond_scale(
self,
*args,
cond_scale = 1.,
**kwargs
):
logits = self.forward(*args, **kwargs)
if cond_scale == 1:
return logits
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale
def forward(
self,
x,
time,
*,
lowres_cond_img = None,
lowres_noise_times = None,
text_embeds = None,
text_mask = None,
cond_images = None,
cond_video_frames = None,
post_cond_video_frames = None,
self_cond = None,
cond_drop_prob = 0.,
ignore_time = False
.\lucidrains\imagen-pytorch\imagen_pytorch\t5.py
import torch
import transformers
from typing import List
from transformers import T5Tokenizer, T5EncoderModel, T5Config
from einops import rearrange
transformers.logging.set_verbosity_error()
def exists(val):
return val is not None
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
MAX_LENGTH = 256
DEFAULT_T5_NAME = 'google/t5-v1_1-base'
T5_CONFIGS = {}
def get_tokenizer(name):
tokenizer = T5Tokenizer.from_pretrained(name, model_max_length=MAX_LENGTH)
return tokenizer
def get_model(name):
model = T5EncoderModel.from_pretrained(name)
return model
def get_model_and_tokenizer(name):
global T5_CONFIGS
if name not in T5_CONFIGS:
T5_CONFIGS[name] = dict()
if "model" not in T5_CONFIGS[name]:
T5_CONFIGS[name]["model"] = get_model(name)
if "tokenizer" not in T5_CONFIGS[name]:
T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)
return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']
def get_encoded_dim(name):
if name not in T5_CONFIGS:
config = T5Config.from_pretrained(name)
T5_CONFIGS[name] = dict(config=config)
elif "config" in T5_CONFIGS[name]:
config = T5_CONFIGS[name]["config"]
elif "model" in T5_CONFIGS[name]:
config = T5_CONFIGS[name]["model"].config
else:
assert False
return config.d_model
def t5_tokenize(
texts: List[str],
name = DEFAULT_T5_NAME
):
t5, tokenizer = get_model_and_tokenizer(name)
if torch.cuda.is_available():
t5 = t5.cuda()
device = next(t5.parameters()).device
encoded = tokenizer.batch_encode_plus(
texts,
return_tensors = "pt",
padding = 'longest',
max_length = MAX_LENGTH,
truncation = True
)
input_ids = encoded.input_ids.to(device)
attn_mask = encoded.attention_mask.to(device)
return input_ids, attn_mask
def t5_encode_tokenized_text(
token_ids,
attn_mask = None,
pad_id = None,
name = DEFAULT_T5_NAME
):
assert exists(attn_mask) or exists(pad_id)
t5, _ = get_model_and_tokenizer(name)
attn_mask = default(attn_mask, lambda: (token_ids != pad_id).long())
t5.eval()
with torch.no_grad():
output = t5(input_ids = token_ids, attention_mask = attn_mask)
encoded_text = output.last_hidden_state.detach()
attn_mask = attn_mask.bool()
encoded_text = encoded_text.masked_fill(~rearrange(attn_mask, '... -> ... 1'), 0.)
return encoded_text
def t5_encode_text(
texts: List[str],
name = DEFAULT_T5_NAME,
return_attn_mask = False
):
token_ids, attn_mask = t5_tokenize(texts, name = name)
encoded_text = t5_encode_tokenized_text(token_ids, attn_mask = attn_mask, name = name)
if return_attn_mask:
attn_mask = attn_mask.bool()
return encoded_text, attn_mask
return encoded_text
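A short usage sketch for this T5 wrapper (it loads the google/t5-v1_1-base encoder on first use, so it needs the transformers weights cached locally or network access):

```python
from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

print(get_encoded_dim(DEFAULT_T5_NAME))   # 768 - hidden size of the t5-v1_1-base encoder

text_embeds, text_masks = t5_encode_text(
    ['a whale breaching at sunset', 'a corgi wearing sunglasses'],
    name = DEFAULT_T5_NAME,
    return_attn_mask = True
)

print(text_embeds.shape)   # (2, seq_len, 768), embeddings of padded tokens zeroed out
print(text_masks.shape)    # (2, seq_len) boolean attention mask
```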
.\lucidrains\imagen-pytorch\imagen_pytorch\test\test_trainer.py
from imagen_pytorch.trainer import ImagenTrainer
from imagen_pytorch.configs import ImagenConfig
from imagen_pytorch.t5 import t5_encode_text
from torch.utils.data import Dataset
import torch
def test_trainer_instantiation():
unet1 = dict(
dim = 8,
dim_mults = (1, 1, 1, 1),
num_resnet_blocks = 1,
layer_attns = False,
layer_cross_attns = False,
attn_heads = 2
)
imagen = ImagenConfig(
unets=(unet1,),
image_sizes=(64,),
).create()
trainer = ImagenTrainer(
imagen=imagen
)
def test_trainer_step():
class TestDataset(Dataset):
def __init__(self):
super().__init__()
def __len__(self):
return 16
def __getitem__(self, index):
return (torch.zeros(3, 64, 64), torch.zeros(6, 768))
unet1 = dict(
dim = 8,
dim_mults = (1, 1, 1, 1),
num_resnet_blocks = 1,
layer_attns = False,
layer_cross_attns = False,
attn_heads = 2
)
imagen = ImagenConfig(
unets=(unet1,),
image_sizes=(64,),
).create()
trainer = ImagenTrainer(
imagen=imagen
)
ds = TestDataset()
trainer.add_train_dataset(ds, batch_size=8)
trainer.train_step(unet_number = 1)
assert trainer.num_steps_taken(1) == 1
.\lucidrains\imagen-pytorch\imagen_pytorch\test\__init__.py
from imagen_pytorch.test import test_trainer
.\lucidrains\imagen-pytorch\imagen_pytorch\trainer.py
import os
from math import ceil
from contextlib import contextmanager, nullcontext
from functools import partial, wraps
from collections.abc import Iterable
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torch.cuda.amp import autocast, GradScaler
import pytorch_warmup as warmup
from imagen_pytorch.imagen_pytorch import Imagen, NullUnet
from imagen_pytorch.elucidated_imagen import ElucidatedImagen
from imagen_pytorch.data import cycle
from imagen_pytorch.version import __version__
from packaging import version
import numpy as np
from ema_pytorch import EMA
from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs
from fsspec.core import url_to_fs
from fsspec.implementations.local import LocalFileSystem
def exists(val):
return val is not None
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
def cast_tuple(val, length = 1):
if isinstance(val, list):
val = tuple(val)
return val if isinstance(val, tuple) else ((val,) * length)
def find_first(fn, arr):
for ind, el in enumerate(arr):
if fn(el):
return ind
return -1
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
def group_dict_by_key(cond, d):
return_val = [dict(),dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, str):
return str.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
def num_to_groups(num, divisor):
groups = num // divisor
remainder = num % divisor
arr = [divisor] * groups
if remainder > 0:
arr.append(remainder)
return arr
def url_to_bucket(url):
if '://' not in url:
return url
prefix, suffix = url.split('://')
if prefix in {'gs', 's3'}:
return suffix.split('/')[0]
else:
raise ValueError(f'storage type prefix "{prefix}" is not supported yet')
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
def cast_torch_tensor(fn, cast_fp16 = False):
@wraps(fn)
def inner(model, *args, **kwargs):
device = kwargs.pop('_device', model.device)
cast_device = kwargs.pop('_cast_device', True)
should_cast_fp16 = cast_fp16 and model.cast_half_at_training
kwargs_keys = kwargs.keys()
all_args = (*args, *kwargs.values())
split_kwargs_index = len(all_args) - len(kwargs_keys)
all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
if cast_device:
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
if should_cast_fp16:
all_args = tuple(map(lambda t: t.half() if exists(t) and isinstance(t, torch.Tensor) and t.dtype != torch.bool else t, all_args))
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
out = fn(model, *args, **kwargs)
return out
return inner
def split_iterable(it, split_size):
accum = []
for ind in range(ceil(len(it) / split_size)):
start_index = ind * split_size
accum.append(it[start_index: (start_index + split_size)])
return accum
def split(t, split_size = None):
if not exists(split_size):
return t
if isinstance(t, torch.Tensor):
return t.split(split_size, dim = 0)
if isinstance(t, Iterable):
return split_iterable(t, split_size)
return TypeError
def find_first(cond, arr):
for el in arr:
if cond(el):
return el
return None
def split_args_and_kwargs(*args, split_size = None, **kwargs):
all_args = (*args, *kwargs.values())
len_all_args = len(all_args)
first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
assert exists(first_tensor)
batch_size = len(first_tensor)
split_size = default(split_size, batch_size)
num_chunks = ceil(batch_size / split_size)
dict_len = len(kwargs)
dict_keys = kwargs.keys()
split_kwargs_index = len_all_args - dict_len
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
chunk_sizes = num_to_groups(batch_size, split_size)
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
chunk_size_frac = chunk_size / batch_size
yield chunk_size_frac, (chunked_args, chunked_kwargs)
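split_args_and_kwargs chunks every tensor (and sized iterable) argument into pieces of at most split_size along the batch dimension and yields each chunk together with the fraction of the batch it represents; the trainer uses it to respect max_batch_size during forward passes and sampling. A small sketch (assuming the trainer module and its dependencies are installed):

```python
import torch
from imagen_pytorch.trainer import split_args_and_kwargs

images = torch.randn(10, 3, 64, 64)
texts = ['caption'] * 10

for frac, (args, kwargs) in split_args_and_kwargs(images, texts = texts, split_size = 4):
    print(frac, args[0].shape, len(kwargs['texts']))
# 0.4 torch.Size([4, 3, 64, 64]) 4
# 0.4 torch.Size([4, 3, 64, 64]) 4
# 0.2 torch.Size([2, 3, 64, 64]) 2
```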
def imagen_sample_in_chunks(fn):
@wraps(fn)
def inner(self, *args, max_batch_size = None, **kwargs):
if not exists(max_batch_size):
return fn(self, *args, **kwargs)
if self.imagen.unconditional:
batch_size = kwargs.get('batch_size')
batch_sizes = num_to_groups(batch_size, max_batch_size)
outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
else:
outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
if isinstance(outputs[0], torch.Tensor):
return torch.cat(outputs, dim = 0)
return list(map(lambda t: torch.cat(t, dim = 0), list(zip(*outputs))))
return inner
def restore_parts(state_dict_target, state_dict_from):
for name, param in state_dict_from.items():
if name not in state_dict_target:
continue
if param.size() == state_dict_target[name].size():
state_dict_target[name].copy_(param)
else:
print(f"layer {name}({param.size()} different than target: {state_dict_target[name].size()}")
return state_dict_target
class ImagenTrainer(nn.Module):
locked = False
def __init__(
self,
imagen = None,
imagen_checkpoint_path = None,
use_ema = True,
lr = 1e-4,
eps = 1e-8,
beta1 = 0.9,
beta2 = 0.99,
max_grad_norm = None,
group_wd_params = True,
warmup_steps = None,
cosine_decay_max_steps = None,
only_train_unet_number = None,
fp16 = False,
precision = None,
split_batches = True,
dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'),
verbose = True,
split_valid_fraction = 0.025,
split_valid_from_train = False,
split_random_seed = 42,
checkpoint_path = None,
checkpoint_every = None,
checkpoint_fs = None,
fs_kwargs: dict = None,
max_checkpoints_keep = 20,
**kwargs
def prepare(self):
assert not self.prepared, 'The trainer is already prepared'
self.validate_and_set_unet_being_trained(self.only_train_unet_number)
self.prepared = True
@property
def device(self):
return self.accelerator.device
@property
def is_distributed(self):
return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
@property
def is_main(self):
return self.accelerator.is_main_process
@property
def is_local_main(self):
return self.accelerator.is_local_main_process
@property
def unwrapped_unet(self):
return self.accelerator.unwrap_model(self.unet_being_trained)
def get_lr(self, unet_number):
self.validate_unet_number(unet_number)
unet_index = unet_number - 1
optim = getattr(self, f'optim{unet_index}')
return optim.param_groups[0]['lr']
def validate_and_set_unet_being_trained(self, unet_number = None):
if exists(unet_number):
self.validate_unet_number(unet_number)
assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you cannot only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet'
self.only_train_unet_number = unet_number
self.imagen.only_train_unet_number = unet_number
if not exists(unet_number):
return
self.wrap_unet(unet_number)
def wrap_unet(self, unet_number):
if hasattr(self, 'one_unet_wrapped'):
return
unet = self.imagen.get_unet(unet_number)
unet_index = unet_number - 1
optimizer = getattr(self, f'optim{unet_index}')
scheduler = getattr(self, f'scheduler{unet_index}')
if self.train_dl:
self.unet_being_trained, self.train_dl, optimizer = self.accelerator.prepare(unet, self.train_dl, optimizer)
else:
self.unet_being_trained, optimizer = self.accelerator.prepare(unet, optimizer)
if exists(scheduler):
scheduler = self.accelerator.prepare(scheduler)
setattr(self, f'optim{unet_index}', optimizer)
setattr(self, f'scheduler{unet_index}', scheduler)
self.one_unet_wrapped = True
def set_accelerator_scaler(self, unet_number):
def patch_optimizer_step(accelerated_optimizer, method):
def patched_step(*args, **kwargs):
accelerated_optimizer._accelerate_step_called = True
return method(*args, **kwargs)
return patched_step
unet_number = self.validate_unet_number(unet_number)
scaler = getattr(self, f'scaler{unet_number - 1}')
self.accelerator.scaler = scaler
for optimizer in self.accelerator._optimizers:
optimizer.scaler = scaler
optimizer._accelerate_step_called = False
optimizer._optimizer_original_step_method = optimizer.optimizer.step
optimizer._optimizer_patched_step_method = patch_optimizer_step(optimizer, optimizer.optimizer.step)
def print(self, msg):
if not self.is_main:
return
if not self.verbose:
return
return self.accelerator.print(msg)
def validate_unet_number(self, unet_number = None):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
assert 0 < unet_number <= self.num_unets, f'unet number should be in between 1 and {self.num_unets}'
return unet_number
def num_steps_taken(self, unet_number = None):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
return self.steps[unet_number - 1].item()
def print_untrained_unets(self):
print_final_error = False
for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)):
if steps > 0 or isinstance(unet, NullUnet):
continue
self.print(f'unet {ind + 1} has not been trained')
print_final_error = True
if print_final_error:
self.print('when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets')
def add_train_dataloader(self, dl = None):
if not exists(dl):
return
assert not exists(self.train_dl), 'training dataloader was already added'
assert not self.prepared, 'You need to add the dataset before preparation'
self.train_dl = dl
def add_valid_dataloader(self, dl):
if not exists(dl):
return
assert not exists(self.valid_dl), 'validation dataloader was already added'
assert not self.prepared, 'You need to add the dataset before preparation'
self.valid_dl = dl
def add_train_dataset(self, ds = None, *, batch_size, **dl_kwargs):
if not exists(ds):
return
assert not exists(self.train_dl), 'training dataloader was already added'
valid_ds = None
if self.split_valid_from_train:
train_size = int((1 - self.split_valid_fraction) * len(ds))
valid_size = len(ds) - train_size
ds, valid_ds = random_split(ds, [train_size, valid_size], generator = torch.Generator().manual_seed(self.split_random_seed))
self.print(f'training with dataset of {len(ds)} samples and validating with randomly split {len(valid_ds)} samples')
dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
self.add_train_dataloader(dl)
if not self.split_valid_from_train:
return
self.add_valid_dataset(valid_ds, batch_size = batch_size, **dl_kwargs)
def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs):
if not exists(ds):
return
assert not exists(self.valid_dl), 'validation dataloader was already added'
dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
self.add_valid_dataloader(dl)
def create_train_iter(self):
assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet'
if exists(self.train_dl_iter):
return
self.train_dl_iter = cycle(self.train_dl)
def create_valid_iter(self):
assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet'
if exists(self.valid_dl_iter):
return
self.valid_dl_iter = cycle(self.valid_dl)
def train_step(self, *, unet_number = None, **kwargs):
if not self.prepared:
self.prepare()
self.create_train_iter()
kwargs = {'unet_number': unet_number, **kwargs}
loss = self.step_with_dl_iter(self.train_dl_iter, **kwargs)
self.update(unet_number = unet_number)
return loss
@torch.no_grad()
@eval_decorator
def valid_step(self, **kwargs):
if not self.prepared:
self.prepare()
self.create_valid_iter()
context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext
with context():
loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs)
return loss
def step_with_dl_iter(self, dl_iter, **kwargs):
dl_tuple_output = cast_tuple(next(dl_iter))
model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output)))
loss = self.forward(**{**kwargs, **model_input})
return loss
@property
def all_checkpoints_sorted(self):
glob_pattern = os.path.join(self.checkpoint_path, '*.pt')
checkpoints = self.fs.glob(glob_pattern)
sorted_checkpoints = sorted(checkpoints, key = lambda x: int(str(x).split('.')[-2]), reverse = True)
return sorted_checkpoints
def load_from_checkpoint_folder(self, last_total_steps = -1):
if last_total_steps != -1:
filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt')
self.load(filepath)
return
sorted_checkpoints = self.all_checkpoints_sorted
if len(sorted_checkpoints) == 0:
self.print(f'no checkpoints found to load from at {self.checkpoint_path}')
return
last_checkpoint = sorted_checkpoints[0]
self.load(last_checkpoint)
def save_to_checkpoint_folder(self):
self.accelerator.wait_for_everyone()
if not self.can_checkpoint:
return
total_steps = int(self.steps.sum().item())
filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt')
self.save(filepath)
if self.max_checkpoints_keep <= 0:
return
sorted_checkpoints = self.all_checkpoints_sorted
checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:]
for checkpoint in checkpoints_to_discard:
self.fs.rm(checkpoint)
def save(
self,
path,
overwrite = True,
without_optim_and_sched = False,
**kwargs
):
self.accelerator.wait_for_everyone()
if not self.can_checkpoint:
return
fs = self.fs
assert not (fs.exists(path) and not overwrite)
self.reset_ema_unets_all_one_device()
save_obj = dict(
model = self.imagen.state_dict(),
version = __version__,
steps = self.steps.cpu(),
**kwargs
)
save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple()
for ind in save_optim_and_sched_iter:
scaler_key = f'scaler{ind}'
optimizer_key = f'optim{ind}'
scheduler_key = f'scheduler{ind}'
warmup_scheduler_key = f'warmup{ind}'
scaler = getattr(self, scaler_key)
optimizer = getattr(self, optimizer_key)
scheduler = getattr(self, scheduler_key)
warmup_scheduler = getattr(self, warmup_scheduler_key)
if exists(scheduler):
save_obj = {**save_obj, scheduler_key: scheduler.state_dict()}
if exists(warmup_scheduler):
save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()}
save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
if hasattr(self.imagen, '_config'):
self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"<prompt>\""')
save_obj = {
**save_obj,
'imagen_type': 'elucidated' if self.is_elucidated else 'original',
'imagen_params': self.imagen._config
}
with fs.open(path, 'wb') as f:
torch.save(save_obj, f)
self.print(f'checkpoint saved to {path}')
def load(self, path, only_model = False, strict = True, noop_if_not_exist = False):
fs = self.fs
if noop_if_not_exist and not fs.exists(path):
self.print(f'trainer checkpoint not found at {str(path)}')
return
assert fs.exists(path), f'{path} does not exist'
self.reset_ema_unets_all_one_device()
with fs.open(path) as f:
loaded_obj = torch.load(f, map_location='cpu')
if version.parse(__version__) != version.parse(loaded_obj['version']):
self.print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}')
try:
self.imagen.load_state_dict(loaded_obj['model'], strict = strict)
except RuntimeError:
print("Failed loading state dict. Trying partial load")
self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(),
loaded_obj['model']))
if only_model:
return loaded_obj
self.steps.copy_(loaded_obj['steps'])
for ind in range(0, self.num_unets):
scaler_key = f'scaler{ind}'
optimizer_key = f'optim{ind}'
scheduler_key = f'scheduler{ind}'
warmup_scheduler_key = f'warmup{ind}'
scaler = getattr(self, scaler_key)
optimizer = getattr(self, optimizer_key)
scheduler = getattr(self, scheduler_key)
warmup_scheduler = getattr(self, warmup_scheduler_key)
if exists(scheduler) and scheduler_key in loaded_obj:
scheduler.load_state_dict(loaded_obj[scheduler_key])
if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj:
warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key])
if exists(optimizer):
try:
optimizer.load_state_dict(loaded_obj[optimizer_key])
scaler.load_state_dict(loaded_obj[scaler_key])
except:
self.print('could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. resuming with new optimizer and scalers')
if self.use_ema:
assert 'ema' in loaded_obj
try:
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
except RuntimeError:
print("Failed loading state dict. Trying partial load")
self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(),
loaded_obj['ema']))
self.print(f'checkpoint loaded from {path}')
return loaded_obj
@property
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
def get_ema_unet(self, unet_number = None):
if not self.use_ema:
return
unet_number = self.validate_unet_number(unet_number)
index = unet_number - 1
if isinstance(self.unets, nn.ModuleList):
unets_list = [unet for unet in self.ema_unets]
delattr(self, 'ema_unets')
self.ema_unets = unets_list
if index != self.ema_unet_being_trained_index:
for unet_index, unet in enumerate(self.ema_unets):
unet.to(self.device if unet_index == index else 'cpu')
self.ema_unet_being_trained_index = index
return self.ema_unets[index]
def reset_ema_unets_all_one_device(self, device = None):
if not self.use_ema:
return
device = default(device, self.device)
self.ema_unets = nn.ModuleList([*self.ema_unets])
self.ema_unets.to(device)
self.ema_unet_being_trained_index = -1
@torch.no_grad()
@contextmanager
def use_ema_unets(self):
if not self.use_ema:
output = yield
return output
self.reset_ema_unets_all_one_device()
self.imagen.reset_unets_all_one_device()
self.unets.eval()
trainable_unets = self.imagen.unets
self.imagen.unets = self.unets
output = yield
self.imagen.unets = trainable_unets
for ema in self.ema_unets:
ema.restore_ema_model_device()
return output
def print_unet_devices(self):
self.print('unet devices:')
for i, unet in enumerate(self.imagen.unets):
device = next(unet.parameters()).device
self.print(f'\tunet {i}: {device}')
if not self.use_ema:
return
self.print('\nema unet devices:')
for i, ema_unet in enumerate(self.ema_unets):
device = next(ema_unet.parameters()).device
self.print(f'\tema unet {i}: {device}')
def state_dict(self, *args, **kwargs):
self.reset_ema_unets_all_one_device()
return super().state_dict(*args, **kwargs)
def load_state_dict(self, *args, **kwargs):
self.reset_ema_unets_all_one_device()
return super().load_state_dict(*args, **kwargs)
def encode_text(self, text, **kwargs):
return self.imagen.encode_text(text, **kwargs)
def update(self, unet_number = None):
unet_number = self.validate_unet_number(unet_number)
self.validate_and_set_unet_being_trained(unet_number)
self.set_accelerator_scaler(unet_number)
index = unet_number - 1
unet = self.unet_being_trained
optimizer = getattr(self, f'optim{index}')
scaler = getattr(self, f'scaler{index}')
scheduler = getattr(self, f'scheduler{index}')
warmup_scheduler = getattr(self, f'warmup{index}')
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
optimizer.step()
optimizer.zero_grad()
if self.use_ema:
ema_unet = self.get_ema_unet(unet_number)
ema_unet.update()
maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening()
with maybe_warmup_context:
if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped:
scheduler.step()
self.steps += F.one_hot(torch.tensor(unet_number - 1, device = self.steps.device), num_classes = len(self.steps))
if not exists(self.checkpoint_path):
return
total_steps = int(self.steps.sum().item())
if total_steps % self.checkpoint_every:
return
self.save_to_checkpoint_folder()
@torch.no_grad()
@cast_torch_tensor
@imagen_sample_in_chunks
def sample(self, *args, **kwargs):
context = nullcontext if kwargs.pop('use_non_ema', False) else self.use_ema_unets
self.print_untrained_unets()
if not self.is_main:
kwargs['use_tqdm'] = False
with context():
output = self.imagen.sample(*args, device = self.device, **kwargs)
return output
@partial(cast_torch_tensor, cast_fp16 = True)
def forward(
self,
*args,
unet_number = None,
max_batch_size = None,
**kwargs
):
unet_number = self.validate_unet_number(unet_number)
self.validate_and_set_unet_being_trained(unet_number)
self.set_accelerator_scaler(unet_number)
assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}'
total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with self.accelerator.autocast():
loss = self.imagen(*chunked_args, unet = self.unet_being_trained, unet_number = unet_number, **chunked_kwargs)
loss = loss * chunk_size_frac
total_loss += loss.item()
if self.training:
self.accelerator.backward(loss)
return total_loss