Lucidrains 系列项目源码解析(四十一)
.\lucidrains\gigagan-pytorch\gigagan_pytorch\distributed.py
import torch
import torch.nn.functional as F
from torch.autograd import Function
import torch.distributed as dist
from einops import rearrange
def exists(val):
    """Return True for any value except None."""
    return not (val is None)
def pad_dim_to(t, length, dim = 0):
    """Right-pad tensor *t* with zeros along *dim* until that dimension reaches *length*."""
    amount = length - t.shape[dim]

    # F.pad consumes (left, right) pairs starting from the LAST dimension,
    # so emit enough (0, 0) pairs to reach the target dim
    trailing = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    pad_spec = (0, 0) * trailing + (0, amount)
    return F.pad(t, pad_spec)
def all_gather_variable_dim(t, dim = 0, sizes = None):
    """All-gather tensors whose size along *dim* may differ across ranks.

    Each rank's tensor is zero-padded up to the max size, gathered, and the
    padding entries are then dropped. Returns (gathered_tensor, sizes) where
    *sizes* holds each rank's original length along *dim* (needed by the
    backward pass to split gradients).
    """
    device, world_size = t.device, dist.get_world_size()

    if not exists(sizes):
        # first collective: exchange each rank's length along *dim*
        size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
        sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)]
        dist.all_gather(sizes, size)
        sizes = torch.stack(sizes)

    max_size = sizes.amax().item()
    padded_t = pad_dim_to(t, max_size, dim = dim)

    # second collective: gather the now uniformly-sized tensors
    gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for i in range(world_size)]
    dist.all_gather(gathered_tensors, padded_t)
    gathered_tensor = torch.cat(gathered_tensors, dim = dim)

    # flat (rank * position) mask selecting only each rank's first sizes[i] entries
    seq = torch.arange(max_size, device = device)
    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
    mask = rearrange(mask, 'i j -> (i j)')
    seq = torch.arange(mask.shape[-1], device = device)
    indices = seq[mask]

    # NOTE(review): the flattened mask lines up with the concatenation only
    # along the gathered axis — callers in this file always use dim = 0
    gathered_tensor = gathered_tensor.index_select(dim, indices)
    return gathered_tensor, sizes
class AllGather(Function):
    """Autograd-aware all-gather of variable-length tensors.

    forward gathers every rank's tensor along *dim*; backward routes back to
    each rank only the gradient slice corresponding to its own contribution.
    """
    @staticmethod
    def forward(ctx, x, dim, sizes):
        # pass through untouched when not actually running distributed
        is_dist = dist.is_initialized() and dist.get_world_size() > 1
        ctx.is_dist = is_dist

        if not is_dist:
            return x, None

        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
        # stash per-rank sizes so backward can split the gradient
        ctx.batch_sizes = batch_sizes.tolist()
        ctx.dim = dim
        return x, batch_sizes

    @staticmethod
    def backward(ctx, grads, _):
        if not ctx.is_dist:
            return grads, None, None

        # each rank keeps only the gradient for the chunk it contributed
        batch_sizes, rank = ctx.batch_sizes, dist.get_rank()
        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
        return grads_by_rank[rank], None, None

# functional alias used throughout the codebase
all_gather = AllGather.apply
.\lucidrains\gigagan-pytorch\gigagan_pytorch\gigagan_pytorch.py
from collections import namedtuple
from pathlib import Path
from math import log2, sqrt
from random import random
from functools import partial
from torchvision import utils
import torch
import torch.nn.functional as F
from torch import nn, einsum, Tensor
from torch.autograd import grad as torch_grad
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler
from beartype import beartype
from beartype.typing import List, Optional, Tuple, Dict, Union, Iterable
from einops import rearrange, pack, unpack, repeat, reduce
from einops.layers.torch import Rearrange, Reduce
from kornia.filters import filter2d
from ema_pytorch import EMA
from gigagan_pytorch.version import __version__
from gigagan_pytorch.open_clip import OpenClipAdapter
from gigagan_pytorch.optimizer import get_optimizer
from gigagan_pytorch.distributed import all_gather
from tqdm import tqdm
from numerize import numerize
from accelerate import Accelerator, DistributedType
from accelerate.utils import DistributedDataParallelKwargs
def exists(val):
    """True when *val* is anything other than None."""
    return not (val is None)
@beartype
def is_empty(arr: Iterable):
    """True when the sized iterable *arr* holds no elements."""
    return not len(arr)
def default(*vals):
    """Return the first argument that is not None, or None when all are."""
    return next((v for v in vals if v is not None), None)
def cast_tuple(t, length = 1):
    """Return *t* unchanged if it is a tuple, else a tuple of *length* copies."""
    if isinstance(t, tuple):
        return t
    return (t,) * length
def is_power_of_two(n):
    """True when *n* is an exact power of two.

    Fix: guards non-positive input, where the bare ``log2(n)`` call would
    raise a math domain error instead of answering the question.
    """
    if n <= 0:
        return False
    return log2(n).is_integer()
def safe_unshift(arr):
    """Pop and return the first element of *arr*, or None when it is empty."""
    return arr.pop(0) if arr else None
def divisible_by(numer, denom):
    """True when *numer* divides evenly by *denom*."""
    return not numer % denom
def group_by_num_consecutive(arr, num):
    """Yield consecutive chunks of *num* elements from *arr*; the last chunk may be shorter."""
    chunk = []
    for ind, el in enumerate(arr):
        # a full chunk is flushed just before the element starting the next one
        if ind > 0 and (ind % num) == 0:
            yield chunk
            chunk = []
        chunk.append(el)

    if len(chunk) > 0:
        yield chunk
def is_unique(arr):
    """True when *arr* contains no duplicate (hashable) elements."""
    deduped = set(arr)
    return len(deduped) == len(arr)
def cycle(dl):
    """Yield from the iterable *dl* forever, restarting it whenever it is exhausted."""
    while True:
        yield from dl
def num_to_groups(num, divisor):
    """Split *num* into full groups of size *divisor* plus one remainder group."""
    full, rem = divmod(num, divisor)
    sizes = [divisor] * full
    if rem:
        sizes.append(rem)
    return sizes
def mkdir_if_not_exists(path):
    """Create the directory *path* (with parents); no-op when it already exists."""
    path.mkdir(parents = True, exist_ok = True)
@beartype
def set_requires_grad_(
    m: nn.Module,
    requires_grad: bool
):
    """In-place toggle of requires_grad for every parameter of *m*."""
    for param in m.parameters():
        param.requires_grad = requires_grad
def leaky_relu(neg_slope = 0.2):
    """LeakyReLU activation module with the GAN-conventional 0.2 negative slope."""
    return nn.LeakyReLU(negative_slope = neg_slope)
def conv2d_3x3(dim_in, dim_out):
    """3x3 'same' convolution (stride 1, padding 1)."""
    return nn.Conv2d(dim_in, dim_out, kernel_size = 3, padding = 1)
def log(t, eps = 1e-20):
    """Numerically safe natural log: clamp *t* from below at *eps* first."""
    return torch.log(t.clamp(min = eps))
def gradient_penalty(
    images,
    outputs,
    grad_output_weights = None,
    weight = 10,
    scaler: Optional[GradScaler] = None,
    eps = 1e-4
):
    """Gradient penalty: pushes the norm of d(outputs)/d(images) toward 1.

    Supports multiple discriminator outputs (each optionally weighted via
    *grad_output_weights*) and undoes AMP loss scaling when a GradScaler
    is supplied.
    """
    if not isinstance(outputs, (list, tuple)):
        outputs = [outputs]

    if exists(scaler):
        # scale outputs so the grad computation matches AMP training
        outputs = [*map(scaler.scale, outputs)]

    if not exists(grad_output_weights):
        grad_output_weights = (1,) * len(outputs)

    # NOTE: `weight` inside the comprehension is the per-output weight from
    # grad_output_weights, not the outer penalty weight
    maybe_scaled_gradients, *_ = torch_grad(
        outputs = outputs,
        inputs = images,
        grad_outputs = [(torch.ones_like(output) * weight) for output, weight in zip(outputs, grad_output_weights)],
        create_graph = True,
        retain_graph = True,
        only_inputs = True
    )

    gradients = maybe_scaled_gradients

    if exists(scaler):
        # invert the AMP scale (guard against a near-zero scale with eps)
        scale = scaler.get_scale()
        inv_scale = 1. / max(scale, eps)
        gradients = maybe_scaled_gradients * inv_scale

    # flatten everything but batch, penalize deviation of the 2-norm from 1
    gradients = rearrange(gradients, 'b ... -> b (...)')
    return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
def generator_hinge_loss(fake):
    """Hinge-GAN generator loss: the mean discriminator score on generated samples."""
    return torch.mean(fake)
def discriminator_hinge_loss(real, fake):
    """Hinge loss for the discriminator (sign convention matches generator_hinge_loss in this file)."""
    real_term = F.relu(1 + real)
    fake_term = F.relu(1 - fake)
    return torch.mean(real_term + fake_term)
def aux_matching_loss(real, fake):
    """Softplus-style negative log likelihood on both branches.

    In this framework the discriminator is ~0 for real data and high for
    generated data; GANs may swap this convention arbitrarily, as long as
    generator and discriminator remain adversarial.
    """
    real_nll = log(1 + (-real).exp())
    fake_nll = log(1 + (-fake).exp())
    return (real_nll + fake_nll).mean()
@beartype
def aux_clip_loss(
    clip: OpenClipAdapter,
    images: Tensor,
    texts: Optional[List[str]] = None,
    text_embeds: Optional[Tensor] = None
):
    """CLIP contrastive loss over all ranks' images; pass texts XOR text_embeds."""
    assert exists(texts) ^ exists(text_embeds)

    # gather images across the process group; reuse the same batch sizes for the text embeds
    images, batch_sizes = all_gather(images, 0, None)

    if exists(texts):
        text_embeds, _ = clip.embed_texts(texts)
        text_embeds, _ = all_gather(text_embeds, 0, batch_sizes)

    return clip.contrastive_loss(images = images, text_embeds = text_embeds)
class DiffAugment(nn.Module):
    """Differentiable augmentation applied jointly to images and their multiscale rgbs.

    Args:
        prob: probability that any augmentation is applied at all.
        horizontal_flip: whether horizontal flipping is enabled.
        horizontal_flip_prob: probability of flipping, given augmentation fires.
    """
    def __init__(
        self,
        *,
        prob,
        horizontal_flip,
        horizontal_flip_prob = 0.5
    ):
        super().__init__()
        self.prob = prob
        assert 0 <= prob <= 1.

        self.horizontal_flip = horizontal_flip
        self.horizontal_flip_prob = horizontal_flip_prob

    def forward(
        self,
        images,
        rgbs: List[Tensor]
    ):
        if random() >= self.prob:
            return images, rgbs

        # fix: the `horizontal_flip` switch was stored but never consulted,
        # so flips happened even when horizontal_flip = False
        if self.horizontal_flip and random() < self.horizontal_flip_prob:
            images = torch.flip(images, (-1,))
            rgbs = [torch.flip(rgb, (-1,)) for rgb in rgbs]

        return images, rgbs
class ChannelRMSNorm(nn.Module):
    """RMS normalization over the channel axis (dim 1) of a feature map, with a learned per-channel gain."""

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim, 1, 1))

    def forward(self, x):
        # unit-normalize channels, then restore magnitude via sqrt(dim) and gamma
        return F.normalize(x, dim = 1) * self.scale * self.gamma
class RMSNorm(nn.Module):
    """RMS normalization over the last axis, with a learned gain vector."""

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # unit-normalize the feature axis, then rescale by sqrt(dim) * gamma
        return F.normalize(x, dim = -1) * self.scale * self.gamma
class Blur(nn.Module):
    """Anti-aliasing blur using a fixed separable [1, 2, 1] binomial kernel (normalized inside filter2d)."""

    def __init__(self):
        super().__init__()
        self.register_buffer('f', torch.Tensor([1, 2, 1]))

    def forward(self, x):
        # outer product builds the 2D kernel; leading dims for filter2d's batch format
        kernel = self.f[None, None, :] * self.f[None, :, None]
        return filter2d(x, kernel, normalized = True)
def Upsample(*args):
    """2x bilinear upsampling followed by an anti-aliasing Blur; extra args are accepted and ignored."""
    upsampler = nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False)
    return nn.Sequential(upsampler, Blur())
class PixelShuffleUpsample(nn.Module):
    """2x upsample via a 1x1 conv to 4*dim channels followed by PixelShuffle.

    The conv weight is initialized by tiling one kaiming-initialized kernel
    over the four shuffle groups, so all sub-pixel positions start identical.
    """

    def __init__(self, dim):
        super().__init__()
        conv = nn.Conv2d(dim, dim * 4, 1)
        self.init_conv_(conv)

        self.net = nn.Sequential(
            conv,
            nn.SiLU(),
            nn.PixelShuffle(2)
        )

    def init_conv_(self, conv):
        o, i, h, w = conv.weight.shape
        # draw one kernel at quarter the output channels, then tile it x4
        base = torch.empty(o // 4, i, h, w)
        nn.init.kaiming_uniform_(base)
        tiled = repeat(base, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(tiled)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        return self.net(x)
def Downsample(dim):
    """2x downsample: space-to-depth (2x2 patches into channels) then a 1x1 conv back to *dim*."""
    to_depth = Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2)
    return nn.Sequential(to_depth, nn.Conv2d(dim * 4, dim, 1))
def SqueezeExcite(dim, dim_out, reduction = 4, dim_min = 32):
    """Squeeze-and-excitation gate: global average pool -> bottleneck MLP -> sigmoid, broadcast over H and W."""
    dim_hidden = max(dim_out // reduction, dim_min)

    layers = [
        Reduce('b c h w -> b c', 'mean'),
        nn.Linear(dim, dim_hidden),
        nn.SiLU(),
        nn.Linear(dim_hidden, dim_out),
        nn.Sigmoid(),
        Rearrange('b c -> b c 1 1')
    ]
    return nn.Sequential(*layers)
def get_same_padding(size, kernel, dilation, stride):
    """Padding that keeps the conv output size equal to the input size."""
    effective_kernel_extent = dilation * (kernel - 1)
    return ((size - 1) * (stride - 1) + effective_kernel_extent) // 2
class AdaptiveConv2DMod(nn.Module):
    """StyleGAN-style modulated conv2d, optionally blending a bank of kernels.

    When num_conv_kernels > 1 ('adaptive'), a softmax over *kernel_mod* picks
    a per-sample mixture of kernels; *mod* then modulates input channels,
    with optional demodulation for weight normalization.
    """
    def __init__(
        self,
        dim,
        dim_out,
        kernel,
        *,
        demod = True,
        stride = 1,
        dilation = 1,
        eps = 1e-8,
        num_conv_kernels = 1
    ):
        super().__init__()
        self.eps = eps

        self.dim_out = dim_out

        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        self.adaptive = num_conv_kernels > 1

        # shape: (num_kernels, out, in, k, k)
        self.weights = nn.Parameter(torch.randn((num_conv_kernels, dim_out, dim, kernel, kernel)))

        self.demod = demod

        nn.init.kaiming_normal_(self.weights, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu')

    def forward(
        self,
        fmap,
        mod: Optional[Tensor] = None,
        kernel_mod: Optional[Tensor] = None
    ):
        """
        notation

        b - batch
        n - convs
        o - output
        i - input
        k - kernel
        """
        b, h = fmap.shape[0], fmap.shape[-2]

        # NOTE(review): `mod` is annotated Optional but indexed unconditionally
        # below — callers in this file always pass it; confirm before relying on None
        # repeat mod when the feature map batch is an integer multiple of mod's batch
        if mod.shape[0] != b:
            mod = repeat(mod, 'b ... -> (s b) ...', s = b // mod.shape[0])

        if exists(kernel_mod):
            kernel_mod_has_el = kernel_mod.numel() > 0

            # a non-empty kernel_mod is only legal in adaptive mode
            assert self.adaptive or not kernel_mod_has_el

            if kernel_mod_has_el and kernel_mod.shape[0] != b:
                kernel_mod = repeat(kernel_mod, 'b ... -> (s b) ...', s = b // kernel_mod.shape[0])

        weights = self.weights

        if self.adaptive:
            weights = repeat(weights, '... -> b ...', b = b)

            # per-sample soft selection over the kernel bank
            assert exists(kernel_mod) and kernel_mod.numel() > 0

            kernel_attn = kernel_mod.softmax(dim = -1)
            kernel_attn = rearrange(kernel_attn, 'b n -> b n 1 1 1 1')

            weights = reduce(weights * kernel_attn, 'b n ... -> b ...', 'sum')

        # modulate input channels; the +1 keeps identity at mod == 0
        mod = rearrange(mod, 'b i -> b 1 i 1 1')

        weights = weights * (mod + 1)

        if self.demod:
            # demodulation: normalize each output filter's L2 norm
            inv_norm = reduce(weights ** 2, 'b o i k1 k2 -> b o 1 1 1', 'sum').clamp(min = self.eps).rsqrt()
            weights = weights * inv_norm

        # grouped-conv trick: fold batch into groups so each sample uses its own kernel
        fmap = rearrange(fmap, 'b c h w -> 1 (b c) h w')

        weights = rearrange(weights, 'b o ... -> (b o) ...')

        padding = get_same_padding(h, self.kernel, self.dilation, self.stride)
        fmap = F.conv2d(fmap, weights, padding = padding, groups = b)

        return rearrange(fmap, '1 (b o) ... -> b o ...', b = b)
class SelfAttention(nn.Module):
    """Self-attention over a channel-first feature map, with a learned null key/value.

    With dot_product = False, similarity is negative squared L2 distance
    instead of a dot product, and the key projection is skipped (queries
    double as keys).
    """
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        dot_product = False
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        dim_inner = dim_head * heads

        self.dot_product = dot_product

        self.norm = ChannelRMSNorm(dim)

        self.to_q = nn.Conv2d(dim, dim_inner, 1, bias = False)
        # key projection only exists for dot-product attention
        self.to_k = nn.Conv2d(dim, dim_inner, 1, bias = False) if dot_product else None
        self.to_v = nn.Conv2d(dim, dim_inner, 1, bias = False)

        # learned null key/value gives every query something to attend to
        self.null_kv = nn.Parameter(torch.randn(2, heads, dim_head))

        self.to_out = nn.Conv2d(dim_inner, dim, 1, bias = False)

    def forward(self, fmap):
        """
        einstein notation

        b - batch
        h - heads
        x - height
        y - width
        d - dimension
        i - source seq (attend from)
        j - target seq (attend to)
        """
        batch = fmap.shape[0]

        fmap = self.norm(fmap)

        x, y = fmap.shape[-2:]

        h = self.heads

        q, v = self.to_q(fmap), self.to_v(fmap)
        k = self.to_k(fmap) if exists(self.to_k) else q

        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = self.heads), (q, k, v))

        # prepend the null key/value per head
        nk, nv = map(lambda t: repeat(t, 'h d -> (b h) 1 d', b = batch), self.null_kv)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        if self.dot_product:
            sim = einsum('b i d, b j d -> b i j', q, k)
        else:
            # sim = -||q - k||^2, expanded to avoid materializing pairwise differences
            q_squared = (q * q).sum(dim = -1)
            k_squared = (k * k).sum(dim = -1)
            l2dist_squared = rearrange(q_squared, 'b i -> b i 1') + rearrange(k_squared, 'b j -> b 1 j') - 2 * einsum('b i d, b j d -> b i j', q, k)
            sim = -l2dist_squared

        sim = sim * self.scale

        attn = sim.softmax(dim = -1)

        out = einsum('b i j, b j d -> b i d', attn, v)

        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', x = x, y = y, h = h)

        return self.to_out(out)
class CrossAttention(nn.Module):
    """Cross-attention from a channel-first feature map (queries) to text tokens (keys/values)."""

    def __init__(
        self,
        dim,
        dim_context,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        dim_inner = dim_head * heads
        kv_input_dim = default(dim_context, dim)

        self.norm = ChannelRMSNorm(dim)
        self.norm_context = RMSNorm(kv_input_dim)

        self.to_q = nn.Conv2d(dim, dim_inner, 1, bias = False)
        self.to_kv = nn.Linear(kv_input_dim, dim_inner * 2, bias = False)
        self.to_out = nn.Conv2d(dim_inner, dim, 1, bias = False)

    def forward(self, fmap, context, mask = None):
        """
        einstein notation

        b - batch
        h - heads
        x - height
        y - width
        d - dimension
        i - source seq (attend from)
        j - target seq (attend to)
        """
        fmap = self.norm(fmap)
        context = self.norm_context(context)

        x, y = fmap.shape[-2:]

        h = self.heads

        # queries from the feature map; keys/values from the text context
        q, k, v = (self.to_q(fmap), *self.to_kv(context).chunk(2, dim = -1))

        k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (k, v))

        q = rearrange(q, 'b (h d) x y -> (b h) (x y) d', h = self.heads)

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            # mask out padded context tokens before softmax
            mask = repeat(mask, 'b j -> (b h) 1 j', h = self.heads)
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1)

        out = einsum('b i j, b j d -> b i d', attn, v)

        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', x = x, y = y, h = h)

        return self.to_out(out)
class TextAttention(nn.Module):
    """Self-attention over text token sequences, with a learned null key/value."""

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        dim_inner = dim_head * heads

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)

        # learned null key/value so fully-masked rows still attend to something
        self.null_kv = nn.Parameter(torch.randn(2, heads, dim_head))

        self.to_out = nn.Linear(dim_inner, dim, bias = False)

    def forward(self, encodings, mask = None):
        """
        einstein notation

        b - batch
        h - heads
        d - dimension
        i - source seq (attend from)
        j - target seq (attend to)
        """
        batch = encodings.shape[0]

        encodings = self.norm(encodings)

        h = self.heads

        q, k, v = self.to_qkv(encodings).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads), (q, k, v))

        nk, nv = map(lambda t: repeat(t, 'h d -> (b h) 1 d', b = batch), self.null_kv)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            # extend mask with True for the prepended null token
            mask = F.pad(mask, (1, 0), value = True)
            mask = repeat(mask, 'b n -> (b h) 1 n', h = h)
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1)

        out = einsum('b i j, b j d -> b i d', attn, v)

        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)

        return self.to_out(out)
def FeedForward(
    dim,
    mult = 4,
    channel_first = False
):
    """Pre-norm 2-layer GELU MLP; channel_first selects 1x1-conv (feature map) vs linear (token) projections."""
    dim_hidden = int(dim * mult)

    if channel_first:
        norm_klass = ChannelRMSNorm
        proj = partial(nn.Conv2d, kernel_size = 1)
    else:
        norm_klass = RMSNorm
        proj = nn.Linear

    return nn.Sequential(
        norm_klass(dim),
        proj(dim, dim_hidden),
        nn.GELU(),
        proj(dim_hidden, dim)
    )
class SelfAttentionBlock(nn.Module):
    """Self-attention + feedforward over a channel-first feature map, each with a residual connection."""

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        dot_product = False
    ):
        super().__init__()
        self.attn = SelfAttention(dim = dim, dim_head = dim_head, heads = heads, dot_product = dot_product)
        self.ff = FeedForward(dim = dim, mult = ff_mult, channel_first = True)

    def forward(self, x):
        x = x + self.attn(x)
        return x + self.ff(x)
class CrossAttentionBlock(nn.Module):
    """Cross-attention into a text context + feedforward, each with a residual connection."""

    def __init__(
        self,
        dim,
        dim_context,
        dim_head = 64,
        heads = 8,
        ff_mult = 4
    ):
        super().__init__()
        self.attn = CrossAttention(dim = dim, dim_context = dim_context, dim_head = dim_head, heads = heads)
        self.ff = FeedForward(dim = dim, mult = ff_mult, channel_first = True)

    def forward(self, x, context, mask = None):
        x = x + self.attn(x, context = context, mask = mask)
        return x + self.ff(x)
class Transformer(nn.Module):
    """Stack of text self-attention + feedforward layers with residuals and a final RMSNorm."""

    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4
    ):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                TextAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ])
            for _ in range(depth)
        ])

        self.norm = RMSNorm(dim)

    def forward(self, x, mask = None):
        for attn, ff in self.layers:
            x = x + attn(x, mask = mask)
            x = x + ff(x)

        return self.norm(x)
class TextEncoder(nn.Module):
    """Frozen CLIP text tower plus a small trainable transformer on top.

    Returns (global_tokens, fine_text_tokens, mask): a learned global summary
    token, per-token encodings, and the padding mask.
    """
    @beartype
    def __init__(
        self,
        *,
        dim,
        depth,
        clip: Optional[OpenClipAdapter] = None,
        dim_head = 64,
        heads = 8,
    ):
        super().__init__()
        self.dim = dim

        if not exists(clip):
            clip = OpenClipAdapter()

        self.clip = clip
        # clip stays frozen; only the transformer on top trains
        set_requires_grad_(clip, False)

        self.learned_global_token = nn.Parameter(torch.randn(dim))

        self.project_in = nn.Linear(clip.dim_latent, dim) if clip.dim_latent != dim else nn.Identity()

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads
        )

    @beartype
    def forward(
        self,
        texts: Optional[List[str]] = None,
        text_encodings: Optional[Tensor] = None
    ):
        # exactly one of raw texts / precomputed encodings
        assert exists(texts) ^ exists(text_encodings)

        if not exists(text_encodings):
            with torch.no_grad():
                self.clip.eval()
                _, text_encodings = self.clip.embed_texts(texts)

        # all-zero encoding rows are treated as padding
        mask = (text_encodings != 0.).any(dim = -1)

        text_encodings = self.project_in(text_encodings)

        # extend the mask by one position for the prepended global token
        mask_with_global = F.pad(mask, (1, 0), value = True)

        batch = text_encodings.shape[0]
        global_tokens = repeat(self.learned_global_token, 'd -> b d', b = batch)

        text_encodings, ps = pack([global_tokens, text_encodings], 'b * d')

        text_encodings = self.transformer(text_encodings, mask = mask_with_global)

        global_tokens, text_encodings = unpack(text_encodings, ps, 'b * d')

        # note: the returned mask excludes the global token position
        return global_tokens, text_encodings, mask
class EqualLinear(nn.Module):
    """Linear layer with equalized learning rate: weight (and bias) scaled by lr_mul at runtime.

    Fix: the original forward unconditionally accessed ``self.bias``, which
    raised AttributeError when the layer was constructed with bias = False.
    """

    def __init__(
        self,
        dim,
        dim_out,
        lr_mul = 1,
        bias = True
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim_out, dim))
        # keep the attribute present (as None) when bias is disabled so
        # forward can branch safely
        self.bias = nn.Parameter(torch.zeros(dim_out)) if bias else None
        self.lr_mul = lr_mul

    def forward(self, input):
        bias = self.bias * self.lr_mul if self.bias is not None else None
        return F.linear(input, self.weight * self.lr_mul, bias = bias)
class StyleNetwork(nn.Module):
    """Maps a noise latent (optionally concatenated with a text latent) to a style vector via an equalized-lr MLP."""

    def __init__(
        self,
        dim,
        depth,
        lr_mul = 0.1,
        dim_text_latent = 0
    ):
        super().__init__()
        self.dim = dim
        self.dim_text_latent = dim_text_latent

        blocks = []
        for ind in range(depth):
            # only the first layer sees the concatenated text latent
            dim_in = dim + dim_text_latent if ind == 0 else dim
            blocks.append(EqualLinear(dim_in, dim, lr_mul))
            blocks.append(leaky_relu())

        self.net = nn.Sequential(*blocks)

    def forward(
        self,
        x,
        text_latent = None
    ):
        # normalize the raw latent before the MLP (StyleGAN convention)
        x = F.normalize(x, dim = 1)

        if self.dim_text_latent > 0:
            assert exists(text_latent)
            x = torch.cat((x, text_latent), dim = -1)

        return self.net(x)
class Noise(nn.Module):
    """Adds per-pixel noise scaled by a learned per-channel weight (initialized to zero, so it starts as identity)."""

    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(dim, 1, 1))

    def forward(
        self,
        x,
        noise = None
    ):
        b, _, h, w = x.shape

        if noise is None:
            noise = torch.randn(b, 1, h, w, device = x.device)

        return x + self.weight * noise
class BaseGenerator(nn.Module):
    """Marker base class so GigaGAN can accept any generator implementation."""
    pass
class Generator(BaseGenerator):
@beartype
def __init__(
self,
*,
image_size,
dim_capacity = 16,
dim_max = 2048,
channels = 3,
style_network: Optional[Union[StyleNetwork, Dict]] = None,
style_network_dim = None,
text_encoder: Optional[Union[TextEncoder, Dict]] = None,
dim_latent = 512,
self_attn_resolutions: Tuple[int, ...] = (32, 16),
self_attn_dim_head = 64,
self_attn_heads = 8,
self_attn_dot_product = True,
self_attn_ff_mult = 4,
cross_attn_resolutions: Tuple[int, ...] = (32, 16),
cross_attn_dim_head = 64,
cross_attn_heads = 8,
cross_attn_ff_mult = 4,
num_conv_kernels = 2,
num_skip_layers_excite = 0,
unconditional = False,
pixel_shuffle_upsample = False
def init_(self, m):
    """Kaiming-init conv/linear weights, matching the leaky-relu activations used throughout."""
    if type(m) in {nn.Conv2d, nn.Linear}:
        nn.init.kaiming_normal_(m.weight, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu')
@property
def total_params(self):
    """Number of trainable parameters."""
    return sum(p.numel() for p in self.parameters() if p.requires_grad)
@property
def device(self):
    """Device of the module, inferred from its first parameter."""
    return next(self.parameters()).device
@beartype
def forward(
    self,
    styles = None,
    noise = None,
    texts: Optional[List[str]] = None,
    text_encodings: Optional[Tensor] = None,
    global_text_tokens = None,
    fine_text_tokens = None,
    text_mask = None,
    batch_size = 1,
    return_all_rgbs = False
):
    """Synthesize images from styles/noise, optionally conditioned on text.

    Returns the final rgb, plus every intermediate rgb when return_all_rgbs.
    """
    if not self.unconditional:
        if exists(texts) or exists(text_encodings):
            assert exists(texts) ^ exists(text_encodings), '要么传入原始文本作为 List[str],要么传入文本编码(来自 clip)作为 Tensor,但不能同时传入'
            assert exists(self.text_encoder)

            if exists(texts):
                text_encoder_kwargs = dict(texts = texts)
            elif exists(text_encodings):
                text_encoder_kwargs = dict(text_encodings = text_encodings)

            global_text_tokens, fine_text_tokens, text_mask = self.text_encoder(**text_encoder_kwargs)
        else:
            # fix: the original line had mismatched brackets inside all([...]) — SyntaxError
            assert all([*map(exists, (global_text_tokens, fine_text_tokens, text_mask))]), '未传入原始文本或文本嵌入以进行条件训练'
    else:
        assert not any([*map(exists, (texts, global_text_tokens, fine_text_tokens))])

    # derive styles through the style network when not given directly
    if not exists(styles):
        assert exists(self.style_network)

        if not exists(noise):
            noise = torch.randn((batch_size, self.style_network_dim), device = self.device)

        styles = self.style_network(noise, global_text_tokens)

    # project styles to per-layer conv modulations, consumed in order via the iterator
    conv_mods = self.style_to_conv_modulations(styles)
    conv_mods = conv_mods.split(self.style_embed_split_dims, dim = -1)
    conv_mods = iter(conv_mods)

    batch_size = styles.shape[0]

    x = repeat(self.init_block, 'c h w -> b c h w', b = batch_size)
    x = self.init_conv(x, mod = next(conv_mods), kernel_mod = next(conv_mods))

    rgb = torch.zeros((batch_size, self.channels, 4, 4), device = self.device, dtype = x.dtype)

    # queue of skip-layer excitations (squeeze-excite gates from earlier resolutions)
    excitations = [None] * self.num_skip_layers_excite

    rgbs = []

    for squeeze_excite, (resnet_conv1, noise1, act1, resnet_conv2, noise2, act2), to_rgb_conv, self_attn, cross_attn, upsample, upsample_rgb in self.layers:
        if exists(upsample):
            x = upsample(x)

        if exists(squeeze_excite):
            skip_excite = squeeze_excite(x)
            excitations.append(skip_excite)

        excite = safe_unshift(excitations)
        if exists(excite):
            x = x * excite

        x = resnet_conv1(x, mod = next(conv_mods), kernel_mod = next(conv_mods))
        x = noise1(x)
        x = act1(x)
        x = resnet_conv2(x, mod = next(conv_mods), kernel_mod = next(conv_mods))
        x = noise2(x)
        x = act2(x)

        if exists(self_attn):
            x = self_attn(x)

        if exists(cross_attn):
            x = cross_attn(x, context = fine_text_tokens, mask = text_mask)

        layer_rgb = to_rgb_conv(x, mod = next(conv_mods), kernel_mod = next(conv_mods))

        rgb = rgb + layer_rgb
        rgbs.append(rgb)

        if exists(upsample_rgb):
            rgb = upsample_rgb(rgb)

    # every modulation chunk must have been consumed exactly once
    assert is_empty([*conv_mods]), '卷积错误调制'

    if return_all_rgbs:
        return rgb, rgbs

    return rgb
@beartype
class SimpleDecoder(nn.Module):
    """Lightweight decoder reconstructing (patches of) the original image from a feature map.

    Used for an auxiliary reconstruction loss on discriminator features; when
    frac_patches < 1 only a random subset of image patches is reconstructed.
    """
    def __init__(
        self,
        dim,
        *,
        dims: Tuple[int, ...],
        patch_dim: int = 1,
        frac_patches: float = 1.,
        dropout: float = 0.5
    ):
        super().__init__()
        assert 0 < frac_patches <= 1.

        self.patch_dim = patch_dim
        self.frac_patches = frac_patches

        self.dropout = nn.Dropout(dropout)

        dims = [dim, *dims]

        layers = [conv2d_3x3(dim, dim)]

        for dim_in, dim_out in zip(dims[:-1], dims[1:]):
            layers.append(nn.Sequential(
                Upsample(dim_in),
                conv2d_3x3(dim_in, dim_out),
                leaky_relu()
            ))

        self.net = nn.Sequential(*layers)

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        fmap,
        orig_image
    ):
        """MSE between the decoded feature map and the original image (or matched random patches of both)."""
        fmap = self.dropout(fmap)

        if self.frac_patches < 1.:
            batch, patch_dim = fmap.shape[0], self.patch_dim
            fmap_size, img_size = fmap.shape[-1], orig_image.shape[-1]

            assert divisible_by(fmap_size, patch_dim), f'feature map dimensions are {fmap_size}, but the patch dim was designated to be {patch_dim}'
            assert divisible_by(img_size, patch_dim), f'image size is {img_size} but the patch dim was specified to be {patch_dim}'

            # split both tensors into a (patch_dim x patch_dim) grid of patches
            fmap, orig_image = map(lambda t: rearrange(t, 'b c (p1 h) (p2 w) -> b (p1 p2) c h w', p1 = patch_dim, p2 = patch_dim), (fmap, orig_image))

            total_patches = patch_dim ** 2
            num_patches_recon = max(int(self.frac_patches * total_patches), 1)

            batch_arange = torch.arange(batch, device = self.device)[..., None]
            # fix: create the random permutation on the module's device —
            # indexing a CUDA tensor with CPU index tensors fails
            batch_randperm = torch.randn((batch, total_patches), device = self.device).sort(dim = -1).indices
            patch_indices = batch_randperm[..., :num_patches_recon]

            # select the same random patches from both, then fold into the batch
            fmap, orig_image = map(lambda t: t[batch_arange, patch_indices], (fmap, orig_image))
            fmap, orig_image = map(lambda t: rearrange(t, 'b p ... -> (b p) ...'), (fmap, orig_image))

        recon = self.net(fmap)
        return F.mse_loss(recon, orig_image)
class RandomFixedProjection(nn.Module):
    """Projects features through a frozen random matrix (kaiming fan-out init, stored as a buffer)."""

    def __init__(
        self,
        dim,
        dim_out,
        channel_first = True
    ):
        super().__init__()
        weights = torch.randn(dim, dim_out)
        nn.init.kaiming_normal_(weights, mode = 'fan_out', nonlinearity = 'linear')

        self.channel_first = channel_first
        # buffer, not parameter: the projection is intentionally not trained
        self.register_buffer('fixed_weights', weights)

    def forward(self, x):
        if self.channel_first:
            # project the channel axis of a (b, c, ...) feature map
            return einsum('b c ..., c d -> b d ...', x, self.fixed_weights)

        return x @ self.fixed_weights
class VisionAidedDiscriminator(nn.Module):
    """ the vision-aided gan loss """

    @beartype
    def __init__(
        self,
        *,
        depth = 2,
        dim_head = 64,
        heads = 8,
        clip: Optional[OpenClipAdapter] = None,
        layer_indices = (-1, -2, -3),
        conv_dim = None,
        text_dim = None,
        unconditional = False,
        num_conv_kernels = 2
    ):
        super().__init__()

        if not exists(clip):
            clip = OpenClipAdapter()

        self.clip = clip
        dim = clip._dim_image_latent

        self.unconditional = unconditional
        text_dim = default(text_dim, dim)
        conv_dim = default(conv_dim, dim)

        # fix: forward iterates self.layer_indices but it was never stored,
        # which made forward raise AttributeError
        self.layer_indices = layer_indices

        self.layer_discriminators = nn.ModuleList([])

        conv_klass = partial(AdaptiveConv2DMod, kernel = 3, num_conv_kernels = num_conv_kernels) if not unconditional else conv2d_3x3

        # one small discriminator head per tapped clip layer
        for _ in layer_indices:
            self.layer_discriminators.append(nn.ModuleList([
                RandomFixedProjection(dim, conv_dim),
                conv_klass(conv_dim, conv_dim),
                nn.Linear(text_dim, conv_dim) if not unconditional else None,
                nn.Linear(text_dim, num_conv_kernels) if not unconditional else None,
                nn.Sequential(
                    conv2d_3x3(conv_dim, 1),
                    Rearrange('b 1 ... -> b ...')
                )
            ]))

    def parameters(self):
        # only the per-layer heads are trainable; clip stays frozen
        return self.layer_discriminators.parameters()

    @property
    def total_params(self):
        return sum([p.numel() for p in self.parameters()])

    @beartype
    def forward(
        self,
        images,
        texts: Optional[List[str]] = None,
        text_embeds: Optional[Tensor] = None,
        return_clip_encodings = False
    ):
        """Score CLIP intermediate features of *images*; conditional heads are modulated by text embeds."""
        assert self.unconditional or (exists(text_embeds) ^ exists(texts))

        with torch.no_grad():
            if not self.unconditional and exists(texts):
                self.clip.eval()
                # fix: the method was assigned instead of called; embed_texts
                # returns (embeds, encodings) — keep only the embeds
                text_embeds, _ = self.clip.embed_texts(texts)

            _, image_encodings = self.clip.embed_images(images)

        logits = []

        for layer_index, (rand_proj, conv, to_conv_mod, to_conv_kernel_mod, to_logits) in zip(self.layer_indices, self.layer_discriminators):
            image_encoding = image_encodings[layer_index]

            # split the cls token from the spatial tokens, fold spatial back into a fmap
            cls_token, rest_tokens = image_encoding[:, :1], image_encoding[:, 1:]
            height_width = int(sqrt(rest_tokens.shape[-2]))

            img_fmap = rearrange(rest_tokens, 'b (h w) d -> b d h w', h = height_width)
            img_fmap = img_fmap + rearrange(cls_token, 'b 1 d -> b d 1 1 ')

            img_fmap = rand_proj(img_fmap)

            if self.unconditional:
                img_fmap = conv(img_fmap)
            else:
                assert exists(text_embeds)

                img_fmap = conv(
                    img_fmap,
                    mod = to_conv_mod(text_embeds),
                    kernel_mod = to_conv_kernel_mod(text_embeds)
                )

            layer_logits = to_logits(img_fmap)

            logits.append(layer_logits)

        if not return_clip_encodings:
            return logits

        return logits, image_encodings
class Predictor(nn.Module):
    """Per-scale logit head: residual (optionally text-modulated) conv blocks followed by a 1x1 conv to logits."""

    def __init__(
        self,
        dim,
        depth = 4,
        num_conv_kernels = 2,
        unconditional = False
    ):
        super().__init__()
        self.unconditional = unconditional
        self.residual_fn = nn.Conv2d(dim, dim, 1)
        self.residual_scale = 2 ** -0.5

        if unconditional:
            klass = nn.Conv2d
            klass_kwargs = dict(padding = 1)
        else:
            klass = partial(AdaptiveConv2DMod, num_conv_kernels = num_conv_kernels)
            klass_kwargs = dict()

        self.layers = nn.ModuleList([
            nn.ModuleList([
                klass(dim, dim, 3, **klass_kwargs),
                leaky_relu(),
                klass(dim, dim, 3, **klass_kwargs),
                leaky_relu()
            ])
            for _ in range(depth)
        ])

        self.to_logits = nn.Conv2d(dim, 1, 1)

    def forward(
        self,
        x,
        mod = None,
        kernel_mod = None
    ):
        residual = self.residual_fn(x)

        # adaptive convs take modulation kwargs; plain convs take none
        kwargs = dict()
        if not self.unconditional:
            kwargs = dict(mod = mod, kernel_mod = kernel_mod)

        for conv1, act1, conv2, act2 in self.layers:
            inner_residual = x

            x = conv1(x, **kwargs)
            x = act1(x)
            x = conv2(x, **kwargs)
            x = act2(x)

            x = (x + inner_residual) * self.residual_scale

        x = x + residual
        return self.to_logits(x)
class Discriminator(nn.Module):
@beartype
def __init__(
self,
*,
dim_capacity = 16,
image_size,
dim_max = 2048,
channels = 3,
attn_resolutions: Tuple[int, ...] = (32, 16),
attn_dim_head = 64,
attn_heads = 8,
self_attn_dot_product = False,
ff_mult = 4,
text_encoder: Optional[Union[TextEncoder, Dict]] = None,
text_dim = None,
filter_input_resolutions: bool = True,
multiscale_input_resolutions: Tuple[int, ...] = (64, 32, 16, 8),
multiscale_output_skip_stages: int = 1,
aux_recon_resolutions: Tuple[int, ...] = (8,),
aux_recon_patch_dims: Tuple[int, ...] = (2,),
aux_recon_frac_patches: Tuple[float, ...] = (0.25,),
aux_recon_fmap_dropout: float = 0.5,
resize_mode = 'bilinear',
num_conv_kernels = 2,
num_skip_layers_excite = 0,
unconditional = False,
predictor_depth = 2
def init_(self, m):
    """Kaiming-init conv/linear weights for leaky-relu activations."""
    if type(m) in {nn.Conv2d, nn.Linear}:
        nn.init.kaiming_normal_(m.weight, a = 0, mode = 'fan_in', nonlinearity = 'leaky_relu')
def resize_image_to(self, images, resolution):
    """Resize *images* to *resolution* using the configured interpolation mode."""
    return F.interpolate(images, resolution, mode = self.resize_mode)
def real_images_to_rgbs(self, images):
    """Resize real images to every multiscale input resolution the discriminator consumes."""
    return [self.resize_image_to(images, resolution) for resolution in self.multiscale_input_resolutions]
@property
def total_params(self):
    """Total parameter count (trainable or not)."""
    return sum([p.numel() for p in self.parameters()])
@property
def device(self):
    """Device of the module, inferred from its first parameter."""
    return next(self.parameters()).device
@beartype
def forward(
self,
images,
rgbs: List[Tensor],
texts: Optional[List[str]] = None,
text_encodings: Optional[Tensor] = None,
text_embeds = None,
real_images = None,
return_multiscale_outputs = True,
calc_aux_loss = True
# loss breakdown returned from one discriminator training step
TrainDiscrLosses = namedtuple('TrainDiscrLosses', 'divergence multiscale_divergence vision_aided_divergence total_matching_aware_loss gradient_penalty aux_reconstruction')
# loss breakdown returned from one generator training step
TrainGenLosses = namedtuple('TrainGenLosses', 'divergence multiscale_divergence total_vd_divergence contrastive_loss')
class GigaGAN(nn.Module):
@beartype
def __init__(
self,
*,
generator: Union[BaseGenerator, Dict],
discriminator: Union[Discriminator, Dict],
vision_aided_discriminator: Optional[Union[VisionAidedDiscriminator, Dict]] = None,
diff_augment: Optional[Union[DiffAugment, Dict]] = None,
learning_rate = 2e-4,
betas = (0.5, 0.9),
weight_decay = 0.,
discr_aux_recon_loss_weight = 1.,
multiscale_divergence_loss_weight = 0.1,
vision_aided_divergence_loss_weight = 0.5,
generator_contrastive_loss_weight = 0.1,
matching_awareness_loss_weight = 0.1,
calc_multiscale_loss_every = 1,
apply_gradient_penalty_every = 4,
resize_image_mode = 'bilinear',
train_upsampler = False,
log_steps_every = 20,
create_ema_generator_at_init = True,
save_and_sample_every = 1000,
early_save_thres_steps = 2500,
early_save_and_sample_every = 100,
num_samples = 25,
model_folder = './gigagan-models',
results_folder = './gigagan-results',
sample_upsampler_dl: Optional[DataLoader] = None,
accelerator: Optional[Accelerator] = None,
accelerate_kwargs: dict = {},
find_unused_parameters = True,
amp = False,
mixed_precision_type = 'fp16'
def save(self, path, overwrite = True):
    """Serialize G/D (+ optional vision-aided D, EMA generator, AMP scalers) and the step count to *path*."""
    path = Path(path)
    mkdir_if_not_exists(path.parents[0])

    assert overwrite or not path.exists()

    pkg = dict(
        G = self.unwrapped_G.state_dict(),
        D = self.unwrapped_D.state_dict(),
        G_opt = self.G_opt.state_dict(),
        D_opt = self.D_opt.state_dict(),
        steps = self.steps.item(),
        version = __version__
    )

    # AMP grad-scaler states, present only when mixed precision is active
    if exists(self.G_opt.scaler):
        pkg['G_scaler'] = self.G_opt.scaler.state_dict()

    if exists(self.D_opt.scaler):
        pkg['D_scaler'] = self.D_opt.scaler.state_dict()

    # optional vision-aided discriminator and its optimizer/scaler
    if exists(self.VD):
        pkg['VD'] = self.unwrapped_VD.state_dict()
        pkg['VD_opt'] = self.VD_opt.state_dict()

        if exists(self.VD_opt.scaler):
            pkg['VD_scaler'] = self.VD_opt.scaler.state_dict()

    if self.has_ema_generator:
        pkg['G_ema'] = self.G_ema.state_dict()

    torch.save(pkg, str(path))
def load(self, path, strict = False):
    """Load a checkpoint previously written by `.save`.

    Restores generator / discriminator weights (and, when configured, the
    vision-aided discriminator and EMA generator), the step counter, and -
    best effort - optimizer and grad-scaler states. Optimizer restoration
    failures are logged and ignored so training can resume with fresh
    optimizer state.

    Args:
        path: checkpoint file path; must exist.
        strict: forwarded to `load_state_dict` for the main models.
    """
    path = Path(path)
    assert path.exists()

    pkg = torch.load(str(path))

    if 'version' in pkg and pkg['version'] != __version__:
        print(f"trying to load from version {pkg['version']}")

    self.unwrapped_G.load_state_dict(pkg['G'], strict = strict)
    self.unwrapped_D.load_state_dict(pkg['D'], strict = strict)

    if exists(self.VD):
        self.unwrapped_VD.load_state_dict(pkg['VD'], strict = strict)

    if self.has_ema_generator:
        self.G_ema.load_state_dict(pkg['G_ema'])

    if 'steps' in pkg:
        self.steps.copy_(torch.tensor([pkg['steps']]))

    # older checkpoints may lack optimizer states entirely
    if 'G_opt' not in pkg or 'D_opt' not in pkg:
        return

    try:
        self.G_opt.load_state_dict(pkg['G_opt'])
        self.D_opt.load_state_dict(pkg['D_opt'])

        if exists(self.VD):
            self.VD_opt.load_state_dict(pkg['VD_opt'])

        if 'G_scaler' in pkg and exists(self.G_opt.scaler):
            self.G_opt.scaler.load_state_dict(pkg['G_scaler'])

        if 'D_scaler' in pkg and exists(self.D_opt.scaler):
            self.D_opt.scaler.load_state_dict(pkg['D_scaler'])

        if 'VD_scaler' in pkg and exists(self.VD_opt.scaler):
            self.VD_opt.scaler.load_state_dict(pkg['VD_scaler'])
    except Exception as e:
        # bug fix: the original formatted `e.msg`, but generic exceptions have
        # no `.msg` attribute, so the handler itself raised AttributeError;
        # format the exception object instead
        self.print(f'unable to load optimizers {e} - optimizer states will be reset')
@property
def device(self):
    """Device the accelerator placed models/tensors on."""
    return self.accelerator.device

@property
def unwrapped_G(self):
    """Generator with any accelerate (e.g. DDP) wrapper removed."""
    return self.accelerator.unwrap_model(self.G)

@property
def unwrapped_D(self):
    """Discriminator with any accelerate wrapper removed."""
    return self.accelerator.unwrap_model(self.D)

@property
def unwrapped_VD(self):
    """Vision-aided discriminator with any accelerate wrapper removed."""
    return self.accelerator.unwrap_model(self.VD)

@property
def need_vision_aided_discriminator(self):
    # only pay for VD forward passes when it exists and its loss is weighted
    return exists(self.VD) and self.vision_aided_divergence_loss_weight > 0.

@property
def need_contrastive_loss(self):
    # contrastive loss requires text conditioning, hence not unconditional
    return self.generator_contrastive_loss_weight > 0. and not self.unconditional

def print(self, msg):
    # accelerator.print only emits on the main process
    self.accelerator.print(msg)

@property
def is_distributed(self):
    # True unless running a single process with no distributed backend
    return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

@property
def is_main(self):
    """Whether this is the globally-main process."""
    return self.accelerator.is_main_process

@property
def is_local_main(self):
    """Whether this is the main process on this node."""
    return self.accelerator.is_local_main_process
def resize_image_to(self, images, resolution):
    """Resize a batch of images to `resolution` with the configured mode."""
    interpolate_mode = self.resize_image_mode
    return F.interpolate(images, size = resolution, mode = interpolate_mode)
@beartype
def set_dataloader(self, dl: DataLoader):
    """Register the training dataloader (once) and let accelerate prepare it."""
    assert not exists(self.train_dl), 'training dataloader has already been set'

    # remember the raw batch size before accelerate wraps the loader
    self.train_dl_batch_size = dl.batch_size
    self.train_dl = self.accelerator.prepare(dl)
@torch.inference_mode()
def generate(self, *args, **kwargs):
    """Sample from the generator without gradient tracking.

    Prefers the EMA generator when one has been created; puts the chosen
    model into eval mode before the forward pass.
    """
    if self.has_ema_generator:
        model = self.G_ema
    else:
        model = self.G

    model.eval()
    return model(*args, **kwargs)
def create_ema_generator(
    self,
    update_every = 10,
    update_after_step = 100,
    decay = 0.995
):
    """Create the EMA shadow copy of the generator (main process only).

    Args:
        update_every: EMA update cadence, in optimizer steps.
        update_after_step: steps to wait before EMA updates begin.
        decay: EMA decay rate (passed to EMA as `beta`).
    """
    # non-main processes never sample, so they skip the EMA copy entirely
    if not self.is_main:
        return

    assert not self.has_ema_generator, 'EMA generator has already been created'

    self.G_ema = EMA(
        self.unwrapped_G,
        update_every = update_every,
        update_after_step = update_after_step,
        beta = decay
    )
    self.has_ema_generator = True
def generate_kwargs(self, dl_iter, batch_size):
    """Assemble keyword arguments for one generator forward pass.

    Returns:
        (G_kwargs, maybe_text_kwargs) - generator kwargs containing latent
        `noise` plus either `lowres_image` (upsampler training) or
        `batch_size`, and a dict that holds `texts` for text-conditioned
        training (empty when unconditional).
    """
    maybe_text_kwargs = dict()

    # real images are needed either as upsampler input or to pair with texts
    if self.train_upsampler or not self.unconditional:
        assert exists(dl_iter)

        if self.unconditional:
            real_images = next(dl_iter)
        else:
            result = next(dl_iter)
            assert isinstance(result, tuple), 'dataset should return a tuple of two items for text conditioned training, (images: Tensor, texts: List[str])'
            real_images, texts = result
            # trim texts so they line up with the generated batch
            maybe_text_kwargs['texts'] = texts[:batch_size]

        real_images = real_images.to(self.device)

    if self.train_upsampler:
        # downsample reals to the upsampler's expected input resolution
        size = self.unwrapped_G.input_image_size
        lowres_real_images = F.interpolate(real_images, (size, size))

        G_kwargs = dict(lowres_image = lowres_real_images)
    else:
        assert exists(batch_size)

        G_kwargs = dict(batch_size = batch_size)

    # latent noise for the style network
    noise = torch.randn(batch_size, self.unwrapped_G.style_network.dim, device = self.device)

    G_kwargs.update(noise = noise)

    return G_kwargs, maybe_text_kwargs
@beartype
def train_discriminator_step(
self,
dl_iter: Iterable,
grad_accum_every = 1,
apply_gradient_penalty = False,
calc_multiscale_loss = True
def train_generator_step(
    self,
    batch_size = None,
    dl_iter: Optional[Iterable] = None,
    grad_accum_every = 1,
    calc_multiscale_loss = True
):
    """Run one generator optimization step with gradient accumulation.

    Accumulates the hinge generator loss (plus optional multiscale and
    vision-aided divergences) over `grad_accum_every` micro-batches, then
    optionally adds a CLIP contrastive loss over all accumulated images
    before stepping the generator optimizer.

    Returns:
        TrainGenLosses(divergence, multiscale_divergence, total_vd_divergence,
        contrastive_loss) - scalar floats for logging.
    """
    total_divergence = 0.
    total_multiscale_divergence = 0. if calc_multiscale_loss else None
    total_vd_divergence = 0.
    contrastive_loss = 0.

    self.G.train()
    self.D.train()
    self.D_opt.zero_grad()
    self.G_opt.zero_grad()

    # images/texts are collected across accumulation steps for the
    # contrastive loss, which needs the full batch at once
    all_images = []
    all_texts = []

    for _ in range(grad_accum_every):
        G_kwargs, maybe_text_kwargs = self.generate_kwargs(dl_iter, batch_size)

        with self.accelerator.autocast():
            images, rgbs = self.G(
                **G_kwargs,
                **maybe_text_kwargs,
                return_all_rgbs = True
            )

            if exists(self.diff_augment):
                images, rgbs = self.diff_augment(images, rgbs)

            if self.need_contrastive_loss:
                all_images.append(images)
                all_texts.extend(maybe_text_kwargs['texts'])

            logits, multiscale_logits, _ = self.D(
                images,
                rgbs,
                **maybe_text_kwargs,
                return_multiscale_outputs = calc_multiscale_loss,
                calc_aux_loss = False
            )

            # main adversarial (hinge) loss for the generator
            divergence = generator_hinge_loss(logits)
            total_divergence += (divergence.item() / grad_accum_every)

            total_loss = divergence

            if self.multiscale_divergence_loss_weight > 0. and len(multiscale_logits) > 0:
                multiscale_divergence = 0.

                for multiscale_logit in multiscale_logits:
                    multiscale_divergence = multiscale_divergence + generator_hinge_loss(multiscale_logit)

                total_multiscale_divergence += (multiscale_divergence.item() / grad_accum_every)

                total_loss = total_loss + multiscale_divergence * self.multiscale_divergence_loss_weight

            if self.need_vision_aided_discriminator:
                vd_loss = 0.

                logits = self.VD(images, **maybe_text_kwargs)

                for logit in logits:
                    vd_loss = vd_loss + generator_hinge_loss(logit)

                total_vd_divergence += (vd_loss.item() / grad_accum_every)

                total_loss = total_loss + vd_loss * self.vision_aided_divergence_loss_weight

        # retain the graph so the contrastive backward below can reuse it
        self.accelerator.backward(total_loss / grad_accum_every, retain_graph = self.need_contrastive_loss)

    if self.need_contrastive_loss:
        all_images = torch.cat(all_images, dim = 0)

        contrastive_loss = aux_clip_loss(
            clip = self.G.text_encoder.clip,
            texts = all_texts,
            images = all_images
        )

        self.accelerator.backward(contrastive_loss * self.generator_contrastive_loss_weight)

    self.G_opt.step()

    # sync before touching the EMA copy, which only exists on the main process
    self.accelerator.wait_for_everyone()

    if self.is_main and self.has_ema_generator:
        self.G_ema.update()

    return TrainGenLosses(
        total_divergence,
        total_multiscale_divergence,
        total_vd_divergence,
        contrastive_loss
    )
def sample(self, model, dl_iter, batch_size):
    """Generate one batch of samples from `model`.

    For upsampler training, the low-resolution conditioning image is
    resized to the output resolution and concatenated in front of the
    generated batch so the saved grid shows input/output pairs.
    """
    G_kwargs, maybe_text_kwargs = self.generate_kwargs(dl_iter, batch_size)

    with self.accelerator.autocast():
        generator_output = model(**G_kwargs, **maybe_text_kwargs)

    if not self.train_upsampler:
        return generator_output

    target_resolution = generator_output.shape[-1]
    upsized_lowres = F.interpolate(G_kwargs['lowres_image'], (target_resolution, target_resolution))
    return torch.cat([upsized_lowres, generator_output])
@torch.inference_mode()
def save_sample(
    self,
    batch_size,
    dl_iter = None
):
    """Sample image grids from G (and the EMA G, if any) and checkpoint.

    Saves `sample-{milestone}.png` (and `ema-sample-{milestone}.png`) to
    the results folder and `model-{milestone}.ckpt` to the model folder.
    """
    milestone = self.steps.item() // self.save_and_sample_every
    # upsampler grids interleave lowres inputs with outputs, doubling rows
    nrow_mult = 2 if self.train_upsampler else 1
    batches = num_to_groups(self.num_samples, batch_size)

    if self.train_upsampler:
        dl_iter = default(self.sample_upsampler_dl_iter, dl_iter)
        assert exists(dl_iter)

    sample_models_and_output_file_name = [(self.unwrapped_G, f'sample-{milestone}.png')]

    if self.has_ema_generator:
        sample_models_and_output_file_name.append((self.G_ema, f'ema-sample-{milestone}.png'))

    for model, filename in sample_models_and_output_file_name:
        model.eval()

        all_images_list = list(map(lambda n: self.sample(model, dl_iter, n), batches))
        all_images = torch.cat(all_images_list, dim = 0)

        all_images.clamp_(0., 1.)

        utils.save_image(
            all_images,
            str(self.results_folder / filename),
            nrow = int(sqrt(self.num_samples)) * nrow_mult
        )

    self.save(str(self.model_folder / f'model-{milestone}.ckpt'))
@beartype
def forward(
    self,
    *,
    steps,
    grad_accum_every = 1
):
    """Main training loop: alternate discriminator and generator steps.

    Args:
        steps: number of training iterations to run.
        grad_accum_every: gradient accumulation micro-batches per step.
    """
    assert exists(self.train_dl), 'you need to set the dataloader by running .set_dataloader(dl: Dataloader)'
    batch_size = self.train_dl_batch_size

    dl_iter = cycle(self.train_dl)

    # last-seen values for losses that are only computed on some steps
    last_gp_loss = 0.
    last_multiscale_d_loss = 0.
    last_multiscale_g_loss = 0.

    for _ in tqdm(range(steps), initial = self.steps.item()):
        # NOTE: rebinds the `steps` argument to the current global step count
        steps = self.steps.item()
        is_first_step = steps == 1

        apply_gradient_penalty = self.apply_gradient_penalty_every > 0 and divisible_by(steps, self.apply_gradient_penalty_every)
        calc_multiscale_loss = self.calc_multiscale_loss_every > 0 and divisible_by(steps, self.calc_multiscale_loss_every)

        (
            d_loss,
            multiscale_d_loss,
            vision_aided_d_loss,
            matching_aware_loss,
            gp_loss,
            recon_loss
        ) = self.train_discriminator_step(
            dl_iter = dl_iter,
            grad_accum_every = grad_accum_every,
            apply_gradient_penalty = apply_gradient_penalty,
            calc_multiscale_loss = calc_multiscale_loss
        )

        self.accelerator.wait_for_everyone()

        (
            g_loss,
            multiscale_g_loss,
            vision_aided_g_loss,
            contrastive_loss
        ) = self.train_generator_step(
            dl_iter = dl_iter,
            batch_size = batch_size,
            grad_accum_every = grad_accum_every,
            calc_multiscale_loss = calc_multiscale_loss
        )

        # carry forward intermittently-computed losses for logging
        if exists(gp_loss):
            last_gp_loss = gp_loss
        if exists(multiscale_d_loss):
            last_multiscale_d_loss = multiscale_d_loss
        if exists(multiscale_g_loss):
            last_multiscale_g_loss = multiscale_g_loss

        if is_first_step or divisible_by(steps, self.log_steps_every):
            losses = (
                ('G', g_loss),
                ('MSG', last_multiscale_g_loss),
                ('VG', vision_aided_g_loss),
                ('D', d_loss),
                ('MSD', last_multiscale_d_loss),
                ('VD', vision_aided_d_loss),
                ('GP', last_gp_loss),
                ('SSL', recon_loss),
                ('CL', contrastive_loss),
                ('MAL', matching_aware_loss)
            )

            losses_str = ' | '.join([f'{loss_name}: {loss:.2f}' for loss_name, loss in losses])

            self.print(losses_str)

        self.accelerator.wait_for_everyone()

        # sampling / checkpointing happens only on the main process; extra
        # early saves are taken while under early_save_thres_steps
        if self.is_main and (is_first_step or divisible_by(steps, self.save_and_sample_every) or (steps <= self.early_save_thres_steps and divisible_by(steps, self.early_save_and_sample_every))):
            self.save_sample(batch_size, dl_iter)

        self.steps += 1

    self.print(f'complete {steps} training steps')