Lucidrains Series Projects: Source Code Analysis (67)
.\lucidrains\nuwa-pytorch\nuwa_pytorch\optimizer.py
import torch
from torch.optim import Adam, AdamW
def separate_weight_decayable_params(params):
no_wd_params = set([param for param in params if param.ndim < 2])
wd_params = set(params) - no_wd_params
return wd_params, no_wd_params
def get_optimizer(
params,
lr = 3e-4,
wd = 1e-1,
filter_by_requires_grad = False
):
if filter_by_requires_grad:
params = list(filter(lambda t: t.requires_grad, params))
if wd == 0:
return Adam(list(params), lr = lr)
params = set(params)
wd_params, no_wd_params = separate_weight_decayable_params(params)
param_groups = [
{'params': list(wd_params)},
{'params': list(no_wd_params), 'weight_decay': 0},
]
return AdamW(param_groups, lr = lr, weight_decay = wd)
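A quick usage sketch (the toy model is ours, not part of the repo): any parameter with fewer than two dimensions, e.g. biases and norm gains, lands in the no-weight-decay group, and everything else is decayed.

import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32)) # hypothetical toy model
optim = get_optimizer(model.parameters(), lr = 3e-4, wd = 1e-1)
# param group 0 carries weight decay; group 1 (biases, LayerNorm params) has weight_decay = 0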
.\lucidrains\nuwa-pytorch\nuwa_pytorch\reversible.py
import torch
import torch.nn as nn
from operator import itemgetter
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states
def route_args(router, args, depth):
routed_args = [(dict(), dict()) for _ in range(depth)]
matched_keys = [key for key in args.keys() if key in router]
for key in matched_keys:
val = args[key]
for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
return routed_args
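To see what `route_args` produces, a small self-contained example (the router and value are made up): each key is copied into the f and/or g argument dict of every layer according to its per-layer boolean routes.

router = {'mask': ((True, False), (True, False))} # send `mask` to f but not g, for 2 layers
routed = route_args(router, {'mask': 'MASK'}, depth = 2)
assert routed == [({'mask': 'MASK'}, {}), ({'mask': 'MASK'}, {})]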
class Deterministic(nn.Module):
def __init__(self, net):
super().__init__()
self.net = net
self.cpu_state = None
self.cuda_in_fwd = None
self.gpu_devices = None
self.gpu_states = None
def record_rng(self, *args):
self.cpu_state = torch.get_rng_state()
if torch.cuda._initialized:
self.cuda_in_fwd = True
self.gpu_devices, self.gpu_states = get_device_states(*args)
def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
if record_rng:
self.record_rng(*args)
if not set_rng:
return self.net(*args, **kwargs)
rng_devices = []
if self.cuda_in_fwd:
rng_devices = self.gpu_devices
with torch.random.fork_rng(devices=rng_devices, enabled=True):
torch.set_rng_state(self.cpu_state)
if self.cuda_in_fwd:
set_device_states(self.gpu_devices, self.gpu_states)
return self.net(*args, **kwargs)
class ReversibleBlock(nn.Module):
def __init__(self, f, g):
super().__init__()
self.f = Deterministic(f)
self.g = Deterministic(g)
def forward(self, x, f_args = {}, g_args = {}):
x1, x2 = torch.chunk(x, 2, dim=2)
y1, y2 = None, None
with torch.no_grad():
y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
return torch.cat([y1, y2], dim=2)
def backward_pass(self, y, dy, f_args = {}, g_args = {}):
y1, y2 = torch.chunk(y, 2, dim=2)
del y
dy1, dy2 = torch.chunk(dy, 2, dim=2)
del dy
with torch.enable_grad():
y1.requires_grad = True
gy1 = self.g(y1, set_rng=True, **g_args)
torch.autograd.backward(gy1, dy2)
with torch.no_grad():
x2 = y2 - gy1
del y2, gy1
dx1 = dy1 + y1.grad
del dy1
y1.grad = None
with torch.enable_grad():
x2.requires_grad = True
fx2 = self.f(x2, set_rng=True, **f_args)
torch.autograd.backward(fx2, dx1, retain_graph=True)
with torch.no_grad():
x1 = y1 - fx2
del y1, fx2
dx2 = dy2 + x2.grad
del dy2
x2.grad = None
x = torch.cat([x1, x2.detach()], dim=2)
dx = torch.cat([dx1, dx2], dim=2)
return x, dx
class _ReversibleFunction(Function):
@staticmethod
def forward(ctx, x, blocks, args):
ctx.args = args
for block, kwarg in zip(blocks, args):
x = block(x, **kwarg)
ctx.y = x.detach()
ctx.blocks = blocks
return x
@staticmethod
def backward(ctx, dy):
y = ctx.y
args = ctx.args
for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
y, dy = block.backward_pass(y, dy, **kwargs)
return dy, None, None
class ReversibleSequence(nn.Module):
def __init__(self, blocks, args_route = {}):
super().__init__()
self.args_route = args_route
self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks])
def forward(self, x, **kwargs):
x = torch.cat([x, x], dim=-1)
blocks = self.blocks
args = route_args(self.args_route, kwargs, len(blocks))
args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args))
layers_and_args = list(zip(blocks, args))
out = _ReversibleFunction.apply(x, blocks, args)
return torch.stack(out.chunk(2, dim=-1)).sum(dim=0)
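Everything above leans on the reversible-residual identity of RevNets / Reformer: given y1 = x1 + f(x2) and y2 = x2 + g(y1), both inputs are recoverable from the outputs, so activations need not be stored for backprop. A tiny sanity check with toy stand-ins for f and g:

import torch

f = g = lambda t: t * 3 # deterministic stand-ins for the attention / feedforward blocks
x1, x2 = torch.randn(2, 4), torch.randn(2, 4)
y1 = x1 + f(x2)
y2 = x2 + g(y1)
x2_rec = y2 - g(y1) # invert the second residual
x1_rec = y1 - f(x2_rec) # then the first, exactly as backward_pass does
assert torch.allclose(x1, x1_rec) and torch.allclose(x2, x2_rec)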
.\lucidrains\nuwa-pytorch\nuwa_pytorch\reversible_video_audio.py
import torch
import torch.nn as nn
from torch.autograd.function import Function
from contextlib import contextmanager
from nuwa_pytorch.reversible import Deterministic
from einops import reduce
def exists(val):
return val is not None
@contextmanager
def null_context():
yield
def split_at_index(dim, index, t):
pre_slices = (slice(None),) * dim
l = (*pre_slices, slice(None, index))
r = (*pre_slices, slice(index, None))
return t[l], t[r]
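For instance, splitting a (batch, seq, dim) tensor along the sequence axis (shapes illustrative):

t = torch.randn(2, 10, 8)
left, right = split_at_index(1, 6, t) # shapes (2, 6, 8) and (2, 4, 8)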
class ReversibleSelfAttnBlock(nn.Module):
def __init__(self, f, g, j, k):
super().__init__()
self.f = Deterministic(f)
self.g = Deterministic(g)
self.j = Deterministic(j)
self.k = Deterministic(k)
def forward(self, x, m, _reverse = True, **kwargs):
x1, x2 = torch.chunk(x, 2, dim = 2)
m1, m2 = torch.chunk(m, 2, dim = 2)
y1, y2, n1, n2 = None, None, None, None
fn_context = torch.no_grad if _reverse else null_context
record_rng = self.training and _reverse
with fn_context():
y1 = x1 + self.f(x2, record_rng = record_rng)
y2 = x2 + self.g(y1, record_rng = record_rng)
n1 = m1 + self.j(m2, record_rng = record_rng)
n2 = m2 + self.k(n1, record_rng = record_rng)
return torch.cat((y1, y2), dim = 2), torch.cat((n1, n2), dim = 2)
def backward_pass(self, y, n, dy, dn, **kwargs):
y1, y2 = torch.chunk(y, 2, dim = 2)
del y
dy1, dy2 = torch.chunk(dy, 2, dim = 2)
del dy
with torch.enable_grad():
y1.requires_grad = True
gy1 = self.g(y1, set_rng = True)
torch.autograd.backward(gy1, dy2)
with torch.no_grad():
x2 = y2 - gy1
del y2, gy1
dx1 = dy1 + y1.grad
del dy1
y1.grad = None
with torch.enable_grad():
x2.requires_grad = True
fx2 = self.f(x2, set_rng = True)
torch.autograd.backward(fx2, dx1, retain_graph = True)
with torch.no_grad():
x1 = y1 - fx2
del y1, fx2
dx2 = dy2 + x2.grad
del dy2
x2.grad = None
x = torch.cat([x1, x2.detach()], dim = 2)
dx = torch.cat([dx1, dx2], dim = 2)
n1, n2 = torch.chunk(n, 2, dim = 2)
del n
dn1, dn2 = torch.chunk(dn, 2, dim = 2)
del dn
with torch.enable_grad():
n1.requires_grad = True
gn1 = self.k(n1, set_rng = True)
torch.autograd.backward(gn1, dn2)
with torch.no_grad():
m2 = n2 - gn1
del n2, gn1
dm1 = dn1 + n1.grad
del dn1
n1.grad = None
with torch.enable_grad():
m2.requires_grad = True
fm2 = self.j(m2, set_rng = True)
torch.autograd.backward(fm2, dm1, retain_graph=True)
with torch.no_grad():
m1 = n1 - fm2
del n1, fm2
dm2 = dn2 + m2.grad
del dn2
m2.grad = None
m = torch.cat([m1, m2.detach()], dim = 2)
dm = torch.cat([dm1, dm2], dim = 2)
return x, m, dx, dm
class ReversibleCrossAttnBlock(nn.Module):
def __init__(self, f, g, j, k):
super().__init__()
self.f = Deterministic(f)
self.g = Deterministic(g)
self.j = Deterministic(j)
self.k = Deterministic(k)
def forward(self, x, m, *, context, context_mask, video_mask = None, audio_mask = None, _reverse = True, **kwargs):
x1, x2 = torch.chunk(x, 2, dim = 2)
m1, m2 = torch.chunk(m, 2, dim = 2)
y1, y2, n1, n2 = None, None, None, None
fn_context = torch.no_grad if _reverse else null_context
record_rng = self.training and _reverse
with fn_context():
y1 = x1 + self.f(x2, context = context, context_mask = context_mask, mask = video_mask, record_rng = record_rng)
y2 = x2 + self.g(y1, record_rng = record_rng)
n1 = m1 + self.j(m2, context = context, context_mask = context_mask, mask = audio_mask, record_rng = record_rng)
n2 = m2 + self.k(n1, record_rng = record_rng)
return torch.cat((y1, y2), dim = 2), torch.cat((n1, n2), dim = 2)
def backward_pass(self, y, n, dy, dn, *, context, context_mask, video_mask = None, audio_mask = None, **kwargs):
y1, y2 = torch.chunk(y, 2, dim = 2)
del y
dy1, dy2 = torch.chunk(dy, 2, dim = 2)
del dy
with torch.enable_grad():
y1.requires_grad = True
gy1 = self.g(y1, set_rng = True)
torch.autograd.backward(gy1, dy2)
with torch.no_grad():
x2 = y2 - gy1
del y2, gy1
dx1 = dy1 + y1.grad
del dy1
y1.grad = None
with torch.enable_grad():
x2.requires_grad = True
fx2 = self.f(x2, set_rng = True, context = context, context_mask = context_mask, mask = video_mask)
torch.autograd.backward(fx2, dx1, retain_graph = True)
with torch.no_grad():
x1 = y1 - fx2
del y1, fx2
dx2 = dy2 + x2.grad
del dy2
x2.grad = None
x = torch.cat([x1, x2.detach()], dim = 2)
dx = torch.cat([dx1, dx2], dim = 2)
n1, n2 = torch.chunk(n, 2, dim = 2)
del n
dn1, dn2 = torch.chunk(dn, 2, dim = 2)
del dn
with torch.enable_grad():
n1.requires_grad = True
gn1 = self.k(n1, set_rng = True)
torch.autograd.backward(gn1, dn2)
with torch.no_grad():
m2 = n2 - gn1
del n2, gn1
dm1 = dn1 + n1.grad
del dn1
n1.grad = None
with torch.enable_grad():
m2.requires_grad = True
fm2 = self.j(m2, set_rng = True, context = context, context_mask = context_mask, mask = audio_mask)
torch.autograd.backward(fm2, dm1, retain_graph=True)
with torch.no_grad():
m1 = n1 - fm2
del n1, fm2
dm2 = dn2 + m2.grad
del dn2
m2.grad = None
m = torch.cat([m1, m2.detach()], dim = 2)
dm = torch.cat([dm1, dm2], dim = 2)
return x, m, dx, dm
class ReversibleCrossModalityAttnBlock(nn.Module):
def __init__(self, f, g, j, k):
super().__init__()
self.f = Deterministic(f)
self.g = Deterministic(g)
self.j = Deterministic(j)
self.k = Deterministic(k)
def forward(self, x, m, *, video_mask = None, audio_mask = None, _reverse = True, **kwargs):
x1, x2 = torch.chunk(x, 2, dim = 2)
m1, m2 = torch.chunk(m, 2, dim = 2)
y1, y2, n1, n2 = None, None, None, None
fn_context = torch.no_grad if _reverse else null_context
record_rng = self.training and _reverse
with fn_context():
y1 = x1 + self.f(x2, m2, record_rng = record_rng, mask = video_mask, context_mask = audio_mask)
y2 = x2 + self.k(y1, record_rng = record_rng)
n1 = m1 + self.j(m2, y2, record_rng = record_rng, mask = audio_mask, context_mask = video_mask)
n2 = m2 + self.g(n1, record_rng = record_rng)
return torch.cat((y1, y2), dim = 2), torch.cat((n1, n2), dim = 2)
def backward_pass(self, y, n, dy, dn, video_mask = None, audio_mask = None, **kwargs):
n1, n2 = torch.chunk(n, 2, dim = 2)
del n
dn1, dn2 = torch.chunk(dn, 2, dim = 2)
del dn
y1, y2 = torch.chunk(y, 2, dim = 2)
del y
dy1, dy2 = torch.chunk(dy, 2, dim = 2)
del dy
with torch.enable_grad():
n1.requires_grad = True
gn1 = self.g(n1, set_rng = True)
torch.autograd.backward(gn1, dn2)
with torch.no_grad():
m2 = n2 - gn1
del n2, gn1
dm1 = dn1 + n1.grad
del dn1
n1.grad = None
with torch.enable_grad():
m2.requires_grad = True
y2.requires_grad = True
fm2 = self.j(m2, y2, set_rng=True, mask = audio_mask, context_mask = video_mask)
torch.autograd.backward(fm2, dm1)
with torch.no_grad():
m1 = n1 - fm2
del n1, fm2
dm2 = dn2 + m2.grad
dx2 = dy2 + y2.grad
del dn2
del dy2
m2.grad = None
y2.grad = None
with torch.enable_grad():
y1.requires_grad = True
gy1 = self.k(y1, set_rng = True)
torch.autograd.backward(gy1, dx2)
with torch.no_grad():
x2 = y2 - gy1
del y2, gy1
dx1 = dy1 + y1.grad
del dy1
y1.grad = None
with torch.enable_grad():
x2.requires_grad = True
m2.requires_grad = True
fx2 = self.f(x2, m2, set_rng = True, mask = video_mask, context_mask = audio_mask)
torch.autograd.backward(fx2, dx1)
with torch.no_grad():
x1 = y1 - fx2
del y1, fx2
dx2 = dx2 + x2.grad
dm2 = dm2 + m2.grad
x2.grad = None
m2.grad = None
with torch.no_grad():
m = torch.cat([m1, m2.detach()], dim = 2)
dm = torch.cat([dm1, dm2], dim = 2)
x = torch.cat([x1, x2.detach()], dim = 2)
dx = torch.cat([dx1, dx2], dim = 2)
return x, m, dx, dm
class ReversibleFunction(Function):
@staticmethod
def forward(ctx, inp, ind, blocks, kwargs):
x, m = split_at_index(1, ind, inp)
for block in blocks:
x, m = block(x, m, _reverse = True, **kwargs)
ctx.blocks = blocks
ctx.kwargs = kwargs
ctx.ind = ind
ctx.save_for_backward(x.detach(), m.detach())
return torch.cat((x, m), dim = 1)
@staticmethod
def backward(ctx, d):
ind = ctx.ind
blocks = ctx.blocks
kwargs = ctx.kwargs
dy, dn = split_at_index(1, ind, d)
y, n = ctx.saved_tensors
for block in blocks[::-1]:
y, n, dy, dn = block.backward_pass(y, n, dy, dn, **kwargs)
d = torch.cat((dy, dn), dim=1)
return d, None, None, None
reversible_apply = ReversibleFunction.apply
def irreversible_apply(inputs, ind, blocks, kwargs):
x, m = split_at_index(1, ind, inputs)
for block in blocks:
x, m = block(x, m, _reverse = False, **kwargs)
return torch.cat((x, m), dim = 1)
class DualModalityReversibleSequence(nn.Module):
def __init__(self, input_blocks, block_types):
super().__init__()
self.block_types = block_types
blocks = nn.ModuleList([])
for block, block_type in zip(input_blocks, block_types):
if block_type == 'intra_modality_self_attn':
reversible_klass = ReversibleSelfAttnBlock
elif block_type == 'intra_modality_cross_attn':
reversible_klass = ReversibleCrossAttnBlock
elif block_type == 'inter_modality_cross_attn':
reversible_klass = ReversibleCrossModalityAttnBlock
else:
raise ValueError(f'unknown layer type {block_type}')
blocks.append(reversible_klass(*block))
self.blocks = blocks
def forward(
self,
video,
audio,
*,
context,
context_mask = None,
video_mask = None,
audio_mask = None,
reverse = True
):
blocks = self.blocks
video, audio = list(map(lambda t: torch.cat((t, t), dim = -1), (video, audio)))
kwargs = {'context': context, 'context_mask': context_mask, 'video_mask': video_mask, 'audio_mask': audio_mask}
fn = reversible_apply if reverse else irreversible_apply
ind = video.shape[1]
inp = torch.cat((video, audio), dim = 1)
out = fn(inp, ind, blocks, kwargs)
video, audio = split_at_index(1, ind, out)
return list(map(lambda t: reduce(t, 'b n (c d) -> b n d', 'mean', c = 2), (video, audio)))
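A forward-only usage sketch; the `nn.Linear` stand-ins and all shapes are assumptions, since the real f/g/j/k sub-blocks are attention and feedforward modules defined elsewhere in the package:

import torch
import torch.nn as nn

dim = 8
mk = lambda: nn.Linear(dim, dim) # stand-in for each of the f, g, j, k sub-blocks
seq = DualModalityReversibleSequence(
    input_blocks = [(mk(), mk(), mk(), mk())],
    block_types = ('intra_modality_self_attn',)
)
video, audio = torch.randn(1, 4, dim), torch.randn(1, 6, dim)
video_out, audio_out = seq(video, audio, context = None) # shapes are preserved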
.\lucidrains\nuwa-pytorch\nuwa_pytorch\train_nuwa.py
from random import choice, randrange
from pathlib import Path
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchvision.utils import save_image
import torchvision.transforms as T
from einops import rearrange
from tqdm import tqdm
import numpy as np
from shutil import rmtree
from nuwa_pytorch.nuwa_pytorch import NUWA # NUWA model, covered in an earlier part of this series
from nuwa_pytorch.tokenizer import tokenizer
from nuwa_pytorch.optimizer import get_optimizer
from nuwa_pytorch.image_utils import gif_to_tensor
def exists(val):
return val is not None
def noop(*args, **kwargs):
pass
def cycle(dl):
while True:
for data in dl:
yield data
def cast_tuple(t):
return t if isinstance(t, (tuple, list)) else (t,)
def yes_or_no(question):
answer = input(f'{question} (y/n) ')
return answer.lower() in ('yes', 'y')
def accum_log(log, new_logs):
for key, new_value in new_logs.items():
old_value = log.get(key, 0.)
log[key] = old_value + new_value
return log
def pad_collate_fn(batch):
texts, videos = zip(*batch)
return pad_sequence(texts, batch_first = True), torch.stack(videos)
def convert_video_tensor_dataset_to_indices(
*,
vae,
raw_video_dataset,
num_frames,
path,
):
vae_device = next(vae.parameters()).device
num_videos = len(raw_video_dataset)
assert num_videos > 0, 'there must be at least 1 video'
    fmap_size = vae.image_size // (2 ** vae.num_layers) # each VAE layer halves the feature map
shape = (num_videos, num_frames * fmap_size * fmap_size)
video_indices_memmap = np.memmap(path, mode = 'w+', dtype = np.int64, shape = shape)
for ind in tqdm(range(num_videos)):
_, video = raw_video_dataset[ind]
video = rearrange(video, '... -> 1 ...')
video = video.to(vae_device)
indices = vae.get_video_indices(video)
indices = rearrange(indices, '1 f h w -> (f h w)')
video_indices_memmap[ind] = indices.cpu().numpy()
print(f'completed conversion of {num_videos} videos to indices at {path}')
class MnistDataset(Dataset):
def __init__(
self,
num_videos,
videos_memmap_path,
text_memmap_path,
num_digits = 2,
num_frames = 10,
image_size = 64,
channels = 1,
random_rotate = False
):
super().__init__()
self.num_videos = num_videos
self.videos_memmap = np.memmap(videos_memmap_path, mode = 'r', dtype = np.uint8, shape = (num_videos, num_frames, channels, image_size, image_size))
self.text_memmap = np.memmap(text_memmap_path, mode = 'r', dtype = np.uint8, shape = (num_videos, num_digits))
self.random_rotate = random_rotate
def __len__(self):
return self.num_videos
def __getitem__(self, idx):
video = torch.from_numpy(self.videos_memmap[idx].copy()).float()
label = torch.from_numpy(self.text_memmap[idx].copy())
video /= 255
video = video.to(torch.float32)
text = tokenizer.encode(' '.join(map(str, label.tolist())))
text = torch.Tensor(text).long()
if self.random_rotate:
video = T.functional.rotate(video, choice([0, 90, 180, 270]))
return text, video
class VideoIndicesDataset(Dataset):
def __init__(
self,
*,
videos_memmap_path,
text_memmap_path,
vae,
num_videos,
num_frames,
num_digits = 2,
):
self.num_videos = num_videos
        fmap_size = vae.image_size // (2 ** vae.num_layers) # each VAE layer halves the feature map
self.videos_memmap = np.memmap(videos_memmap_path, mode = 'r', dtype = np.int64, shape = (num_videos, num_frames * (fmap_size ** 2)))
self.text_memmap = np.memmap(text_memmap_path, mode = 'r', dtype = np.uint8, shape = (num_videos, num_digits))
def __len__(self):
return self.num_videos
def __getitem__(self, idx):
video = torch.from_numpy(self.videos_memmap[idx].copy())
text = torch.from_numpy(self.text_memmap[idx].copy())
text = tokenizer.encode(' '.join(map(str, text.tolist())))
text = torch.Tensor(text).long()
video = video.long()
return text, video
class GifVideoDataset(Dataset):
def __init__(
self,
*,
folder,
channels = 1
):
folder = Path(folder)
gifs = folder.glob('**/*.gif')
txts = folder.glob('**/*.txt')
gif_path_stems = set(map(lambda t: str(t.with_suffix('')), gifs))
txt_path_stems = set(map(lambda t: str(t.with_suffix('')), txts))
self.path_stems = list(gif_path_stems.intersection(txt_path_stems))
self.channels = channels
print(f'{len(self.path_stems)} video / text pairs found')
def __len__(self):
return len(self.path_stems)
def __getitem__(self, idx):
path_stem = self.path_stems[idx]
txt_path = Path(f'{path_stem}.txt')
txt_str = txt_path.read_text()
text_tensor = torch.Tensor(tokenizer.encode(txt_str)).long()
video_tensor = gif_to_tensor(f'{path_stem}.gif', channels = self.channels)
return text_tensor, video_tensor
class NUWATrainer(nn.Module):
def __init__(
self,
*,
nuwa,
dataset,
num_train_steps,
lr = 3e-4,
wd = 0.01,
batch_size = 4,
grad_accum_every = 8,
max_grad_norm = 0.5,
save_model_every = 2500,
save_results_every = 1000,
results_folder = './results-nuwa',
num_sampled_frames = float('inf')
):
super().__init__()
assert isinstance(nuwa, NUWA), 'nuwa must be an instance of NUWA'
self.nuwa = nuwa
self.steps = 0
self.num_train_steps = num_train_steps
self.batch_size = batch_size
self.grad_accum_every = grad_accum_every
self.max_grad_norm = max_grad_norm
self.optim = get_optimizer(nuwa.parameters(), lr = lr, wd = wd)
self.ds = dataset
self.dl = cycle(DataLoader(
self.ds,
batch_size = batch_size,
collate_fn = pad_collate_fn,
shuffle = True
))
self.save_model_every = save_model_every
self.save_results_every = save_results_every
self.num_sampled_frames = num_sampled_frames
self.results_folder = Path(results_folder)
if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
rmtree(str(self.results_folder))
self.results_folder.mkdir(parents = True, exist_ok = True)
def train_step(self):
device = next(self.nuwa.parameters()).device
self.nuwa.train()
logs = {}
for _ in range(self.grad_accum_every):
text, video = next(self.dl)
text, video = map(lambda t: t.to(device), (text, video))
loss = self.nuwa(
text = text,
video = video,
return_loss = True
)
accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
(loss / self.grad_accum_every).backward()
print(f'{self.steps} loss: {logs["loss"]}')
torch.nn.utils.clip_grad_norm_(self.nuwa.parameters(), self.max_grad_norm)
self.optim.step()
self.optim.zero_grad()
if not (self.steps % self.save_results_every):
self.nuwa.eval()
print(f'{self.steps} sampling')
            text, video = next(self.dl)
text = text.to(device)
video = self.nuwa.generate(text = text[:1], num_frames = min(video.shape[1], self.num_sampled_frames))
one_video = video[0].cpu().clamp(0., 1.)
text_str = tokenizer.decode(text[0])
logs['sampled_text'] = text_str
logs['sampled_video'] = one_video.numpy()
image = rearrange(one_video, 'f c h w -> c (f h) w')
save_image(image, str(self.results_folder / f'{self.steps}.png'))
print(f'{self.steps}: saving to {str(self.results_folder)}')
if not (self.steps % self.save_model_every):
state_dict = self.nuwa.state_dict()
model_path = str(self.results_folder / f'nuwa.{self.steps}.pt')
torch.save(state_dict, model_path)
print(f'{self.steps}: saving model to {str(self.results_folder)}')
self.steps += 1
return logs
def train(self, log_fn = noop):
while self.steps < self.num_train_steps:
logs = self.train_step()
log_fn(logs)
print('training complete')
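A hedged sketch of how the trainer is driven (the NUWA constructor and data folder are placeholders; the model itself is covered in an earlier part of this series):

# nuwa = NUWA(...) # construct the model first
# trainer = NUWATrainer(
#     nuwa = nuwa,
#     dataset = GifVideoDataset(folder = './video-text-pairs'), # hypothetical folder
#     num_train_steps = 100000,
#     batch_size = 4
# )
# trainer.train(log_fn = print)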
.\lucidrains\nuwa-pytorch\nuwa_pytorch\train_vqgan_vae.py
from math import sqrt
import copy
from random import choice
from pathlib import Path
from shutil import rmtree
import torch
from torch import nn
import numpy as np
from PIL import Image
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader, random_split
from einops import rearrange
from nuwa_pytorch.vqgan_vae import VQGanVAE
from nuwa_pytorch.optimizer import get_optimizer
def exists(val):
return val is not None
def noop(*args, **kwargs):
pass
def cycle(dl):
while True:
for data in dl:
yield data
def cast_tuple(t):
return t if isinstance(t, (tuple, list)) else (t,)
def yes_or_no(question):
answer = input(f'{question} (y/n) ')
return answer.lower() in ('yes', 'y')
def accum_log(log, new_logs):
for key, new_value in new_logs.items():
old_value = log.get(key, 0.)
log[key] = old_value + new_value
return log
class MemmappedImageDataset(Dataset):
def __init__(
self,
*,
path,
shape,
random_rotate = True
):
super().__init__()
path = Path(path)
assert path.exists(), f'path {path} must exist'
self.memmap = np.memmap(str(path), mode = 'r', dtype = np.uint8, shape = shape)
self.random_rotate = random_rotate
image_size = shape[-1]
self.transform = T.Compose([
T.Resize(image_size),
T.CenterCrop(image_size),
T.ToTensor()
])
def __len__(self):
return self.memmap.shape[0]
def __getitem__(self, index):
arr = self.memmap[index]
if arr.shape[0] == 1:
arr = rearrange(arr, '1 ... -> ...')
img = Image.fromarray(arr)
img = self.transform(img)
if self.random_rotate:
img = T.functional.rotate(img, choice([0, 90, 180, 270]))
return img
class ImageDataset(Dataset):
def __init__(
self,
folder,
image_size,
exts = ['jpg', 'jpeg', 'png']
):
super().__init__()
self.folder = folder
self.image_size = image_size
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
print(f'{len(self.paths)} training samples found at {folder}')
self.transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize(image_size),
T.RandomHorizontalFlip(),
T.CenterCrop(image_size),
T.ToTensor()
])
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
path = self.paths[index]
img = Image.open(path)
return self.transform(img)
class EMA(nn.Module):
def __init__(
self,
model,
beta = 0.99,
ema_update_after_step = 1000,
ema_update_every = 10,
):
super().__init__()
self.beta = beta
self.online_model = model
self.ema_model = copy.deepcopy(model)
self.ema_update_after_step = ema_update_after_step
self.ema_update_every = ema_update_every
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.]))
def update(self):
self.step += 1
if self.step <= self.ema_update_after_step or (self.step % self.ema_update_every) != 0:
return
if not self.initted:
            self.ema_model.load_state_dict(self.online_model.state_dict())
self.initted.data.copy_(torch.Tensor([True]))
self.update_moving_average(self.ema_model, self.online_model)
def update_moving_average(self, ma_model, current_model):
def calculate_ema(beta, old, new):
if not exists(old):
return new
return old * beta + (1 - beta) * new
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
ma_buffer.copy_(new_buffer_value)
def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs)
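The update rule applied to every parameter and buffer is `ema = ema * beta + online * (1 - beta)`, with a warmup period before the first copy. A minimal sketch with a toy module of our own:

import torch.nn as nn

online = nn.Linear(4, 4)
ema = EMA(online, beta = 0.99, ema_update_after_step = 0, ema_update_every = 1)
ema.update() # first effective step copies the online weights, subsequent steps decay them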
class VQGanVAETrainer(nn.Module):
def __init__(
self,
vae,
*,
num_train_steps,
lr,
batch_size,
grad_accum_every,
wd = 0.,
images_memmap_path = None,
images_memmap_shape = None,
folder = None,
save_results_every = 100,
save_model_every = 1000,
results_folder = './results',
valid_frac = 0.05,
random_split_seed = 42,
ema_beta = 0.995,
ema_update_after_step = 2000,
ema_update_every = 10,
apply_grad_penalty_every = 4,
):
super().__init__()
assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
image_size = vae.image_size
self.vae = vae
self.ema_vae = EMA(vae, ema_update_after_step = ema_update_after_step, ema_update_every = ema_update_every)
self.register_buffer('steps', torch.Tensor([0]))
self.num_train_steps = num_train_steps
self.batch_size = batch_size
self.grad_accum_every = grad_accum_every
all_parameters = set(vae.parameters())
discr_parameters = set(vae.discr.parameters())
vae_parameters = all_parameters - discr_parameters
self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
assert exists(folder) ^ exists(images_memmap_path), 'either folder or memmap path to images must be supplied'
if exists(images_memmap_path):
assert exists(images_memmap_shape), 'shape of memmapped images must be supplied'
if exists(folder):
self.ds = ImageDataset(folder, image_size = image_size)
elif exists(images_memmap_path):
self.ds = MemmappedImageDataset(path = images_memmap_path, shape = images_memmap_shape)
if valid_frac > 0:
train_size = int((1 - valid_frac) * len(self.ds))
valid_size = len(self.ds) - train_size
self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            print(f'training with dataset of {len(self.ds)} samples and validating with randomly split {len(self.valid_ds)} samples')
else:
self.valid_ds = self.ds
print(f'training with shared training and valid dataset of {len(self.ds)} samples')
self.dl = cycle(DataLoader(
self.ds,
batch_size = batch_size,
shuffle = True
))
self.valid_dl = cycle(DataLoader(
self.valid_ds,
batch_size = batch_size,
shuffle = True
))
self.save_model_every = save_model_every
self.save_results_every = save_results_every
self.apply_grad_penalty_every = apply_grad_penalty_every
self.results_folder = Path(results_folder)
if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
rmtree(str(self.results_folder))
self.results_folder.mkdir(parents = True, exist_ok = True)
def train_step(self):
device = next(self.vae.parameters()).device
steps = int(self.steps.item())
apply_grad_penalty = not (steps % self.apply_grad_penalty_every)
self.vae.train()
logs = {}
for _ in range(self.grad_accum_every):
img = next(self.dl)
img = img.to(device)
loss = self.vae(
img,
return_loss = True,
apply_grad_penalty = apply_grad_penalty
)
accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
(loss / self.grad_accum_every).backward()
self.optim.step()
self.optim.zero_grad()
if exists(self.vae.discr):
self.discr_optim.zero_grad()
discr_loss = 0
for _ in range(self.grad_accum_every):
img = next(self.dl)
img = img.to(device)
loss = self.vae(img, return_discr_loss = True)
accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
(loss / self.grad_accum_every).backward()
self.discr_optim.step()
print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}")
self.ema_vae.update()
if not (steps % self.save_results_every):
for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))):
model.eval()
imgs = next(self.dl)
imgs = imgs.to(device)
recons = model(imgs)
nrows = int(sqrt(self.batch_size))
imgs_and_recons = torch.stack((imgs, recons), dim = 0)
imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')
imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))
logs['reconstructions'] = grid
save_image(grid, str(self.results_folder / f'{filename}.png'))
print(f'{steps}: saving to {str(self.results_folder)}')
if not (steps % self.save_model_every):
state_dict = self.vae.state_dict()
model_path = str(self.results_folder / f'vae.{steps}.pt')
torch.save(state_dict, model_path)
ema_state_dict = self.ema_vae.state_dict()
model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
torch.save(ema_state_dict, model_path)
print(f'{steps}: saving model to {str(self.results_folder)}')
self.steps += 1
return logs
def train(self, log_fn = noop):
device = next(self.vae.parameters()).device
while self.steps < self.num_train_steps:
logs = self.train_step()
log_fn(logs)
print('training complete')
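Usage sketch (all paths and hyperparameters below are hypothetical):

# vae = VQGanVAE(dim = 64, image_size = 256)
# trainer = VQGanVAETrainer(
#     vae,
#     folder = './images',
#     num_train_steps = 50000,
#     lr = 3e-4,
#     batch_size = 8,
#     grad_accum_every = 4
# )
# trainer.train()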
.\lucidrains\nuwa-pytorch\nuwa_pytorch\vqgan_vae.py
import copy
import math
from functools import partial, wraps
from math import sqrt
from vector_quantize_pytorch import VectorQuantize as VQ
import torchvision
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.autograd import grad as torch_grad
from einops import rearrange, reduce, repeat
MList = nn.ModuleList
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
def remove_vgg(fn):
@wraps(fn)
def inner(self, *args, **kwargs):
has_vgg = hasattr(self, 'vgg')
if has_vgg:
vgg = self.vgg
delattr(self, 'vgg')
out = fn(self, *args, **kwargs)
if has_vgg:
self.vgg = vgg
return out
return inner
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
def group_dict_by_key(cond, d):
return_val = [dict(),dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, str):
return str.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
def gradient_penalty(images, output, weight = 10):
batch_size = images.shape[0]
gradients = torch_grad(outputs = output, inputs = images,
grad_outputs = torch.ones(output.size(), device = images.device),
create_graph = True, retain_graph = True, only_inputs = True)[0]
gradients = rearrange(gradients, 'b ... -> b (...)')
return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
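`gradient_penalty` is the standard GAN regularizer: it differentiates the discriminator output with respect to the input images and penalizes `weight * (||grad||_2 - 1) ** 2` averaged over the batch; the trainer in the previous file applies it every `apply_grad_penalty_every` steps.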
def l2norm(t):
return F.normalize(t, dim = -1)
def leaky_relu(p = 0.1):
    return nn.LeakyReLU(p)
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
t = t / alpha
t = t - torch.amax(t, dim = dim, keepdim = True).detach()
return (t * alpha).softmax(dim = dim)
def safe_div(numer, denom, eps = 1e-6):
return numer / (denom + eps)
def hinge_discr_loss(fake, real):
return (F.relu(1 + fake) + F.relu(1 - real)).mean()
def hinge_gen_loss(fake):
return -fake.mean()
def bce_discr_loss(fake, real):
    return (-torch.log(1 - torch.sigmoid(fake)) - torch.log(torch.sigmoid(real))).mean()
def bce_gen_loss(fake):
    return -torch.log(torch.sigmoid(fake)).mean()
def grad_layer_wrt_loss(loss, layer):
return torch_grad(
outputs = loss,
inputs = layer,
grad_outputs = torch.ones_like(loss),
retain_graph = True
)[0].detach()
class LayerNormChan(nn.Module):
def __init__(
self,
dim,
eps = 1e-5
):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class Discriminator(nn.Module):
def __init__(
self,
dims,
channels = 3,
groups = 16,
init_kernel_size = 5
):
super().__init__()
dim_pairs = zip(dims[:-1], dims[1:])
self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])
for dim_in, dim_out in dim_pairs:
self.layers.append(nn.Sequential(
nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
nn.GroupNorm(groups, dim_out),
leaky_relu()
))
dim = dims[-1]
self.to_logits = nn.Sequential(
nn.Conv2d(dim, dim, 1),
leaky_relu(),
nn.Conv2d(dim, 1, 4)
)
def forward(self, x):
for net in self.layers:
x = net(x)
return self.to_logits(x)
class ContinuousPositionBias(nn.Module):
""" 定义一个连续位置偏置的类,参考 https://arxiv.org/abs/2111.09883 """
def __init__(self, *, dim, heads, layers = 2):
super().__init__()
self.net = MList([])
self.net.append(nn.Sequential(nn.Linear(2, dim), leaky_relu()))
for _ in range(layers - 1):
self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu()))
        self.net.append(nn.Linear(dim, heads))
self.register_buffer('rel_pos', None, persistent = False)
def forward(self, x):
n, device = x.shape[-1], x.device
fmap_size = int(sqrt(n))
if not exists(self.rel_pos):
pos = torch.arange(fmap_size, device = device)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
grid = rearrange(grid, 'c i j -> (i j) c')
rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
self.register_buffer('rel_pos', rel_pos, persistent = False)
rel_pos = self.rel_pos.float()
for layer in self.net:
rel_pos = layer(rel_pos)
bias = rearrange(rel_pos, 'i j h -> h i j')
return x + bias
class GLUResBlock(nn.Module):
""" 定义一个 GLUResBlock 类 """
def __init__(self, chan, groups = 16):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(chan, chan * 2, 3, padding = 1),
nn.GLU(dim = 1),
nn.GroupNorm(groups, chan),
nn.Conv2d(chan, chan * 2, 3, padding = 1),
nn.GLU(dim = 1),
nn.GroupNorm(groups, chan),
nn.Conv2d(chan, chan, 1)
)
def forward(self, x):
return self.net(x) + x
class ResBlock(nn.Module):
""" 定义一个 ResBlock 类 """
def __init__(self, chan, groups = 16):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(chan, chan, 3, padding = 1),
nn.GroupNorm(groups, chan),
leaky_relu(),
nn.Conv2d(chan, chan, 3, padding = 1),
nn.GroupNorm(groups, chan),
leaky_relu(),
nn.Conv2d(chan, chan, 1)
)
def forward(self, x):
return self.net(x) + x
class VQGanAttention(nn.Module):
""" 定义一个 VQGanAttention 类 """
def __init__(
self,
*,
dim,
dim_head = 64,
heads = 8,
dropout = 0.
):
super().__init__()
self.heads = heads
self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * math.log(0.01))
inner_dim = heads * dim_head
self.dropout = nn.Dropout(dropout)
self.post_norm = LayerNormChan(dim)
self.cpb = ContinuousPositionBias(dim = dim // 4, heads = heads)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1)
def forward(self, x):
h = self.heads
height, width, residual = *x.shape[-2:], x.clone()
q, k, v = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = h), (q, k, v))
q, k = map(l2norm, (q, k))
sim = einsum('b h c i, b h c j -> b h i j', q, k) * self.scale.exp()
sim = self.cpb(sim)
attn = stable_softmax(sim, dim = -1)
attn = self.dropout(attn)
out = einsum('b h i j, b h c j -> b h c i', attn, v)
out = rearrange(out, 'b h c (x y) -> b (h c) x y', x = height, y = width)
out = self.to_out(out)
return self.post_norm(out) + residual
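`VQGanAttention` follows the Swin Transformer V2 tricks referenced in `ContinuousPositionBias`: queries and keys are L2-normalized (cosine-similarity attention), the per-head temperature `self.scale` is learned in log space and exponentiated, and the position bias is produced by a small MLP over log-spaced relative coordinates.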
class VQGanVAE(nn.Module):
""" 定义一个 VQGanVAE 类 """
def __init__(
self,
*,
dim,
image_size,
channels = 3,
num_layers = 4,
layer_mults = None,
l2_recon_loss = False,
use_hinge_loss = True,
num_resnet_blocks = 1,
vgg = None,
vq_codebook_dim = 256,
vq_codebook_size = 512,
vq_decay = 0.8,
vq_commitment_weight = 1.,
vq_kmeans_init = True,
vq_use_cosine_sim = True,
use_attn = True,
attn_dim_head = 64,
attn_heads = 8,
resnet_groups = 16,
attn_dropout = 0.,
first_conv_kernel_size = 5,
use_vgg_and_gan = True,
**kwargs
):
super().__init__()
assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
self.image_size = image_size
self.channels = channels
self.num_layers = num_layers
        self.fmap_size = image_size // (2 ** num_layers) # each encoder layer downsamples by 2
self.codebook_size = vq_codebook_size
self.encoders = MList([])
self.decoders = MList([])
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(num_layers))))
assert len(layer_mults) == num_layers, 'layer multipliers must be equal to designated number of layers'
layer_dims = [dim * mult for mult in layer_mults]
dims = (dim, *layer_dims)
codebook_dim = layer_dims[-1]
dim_pairs = zip(dims[:-1], dims[1:])
append = lambda arr, t: arr.append(t)
prepend = lambda arr, t: arr.insert(0, t)
if not isinstance(num_resnet_blocks, tuple):
num_resnet_blocks = (*((0,) * (num_layers - 1)), num_resnet_blocks)
if not isinstance(use_attn, tuple):
use_attn = (*((False,) * (num_layers - 1)), use_attn)
assert len(num_resnet_blocks) == num_layers, 'number of resnet blocks config must be equal to number of layers'
assert len(use_attn) == num_layers
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(num_layers), dim_pairs, num_resnet_blocks, use_attn):
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
if layer_use_attn:
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
for _ in range(layer_num_resnet_blocks):
append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
if layer_use_attn:
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
append(self.decoders, nn.Conv2d(dim, channels, 1))
self.vq = VQ(
dim = layer_dims[-1],
codebook_dim = vq_codebook_dim,
codebook_size = vq_codebook_size,
decay = vq_decay,
commitment_weight = vq_commitment_weight,
accept_image_fmap = True,
kmeans_init = vq_kmeans_init,
use_cosine_sim = vq_use_cosine_sim,
**vq_kwargs
)
self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss
self.vgg = None
self.discr = None
self.use_vgg_and_gan = use_vgg_and_gan
if not use_vgg_and_gan:
return
if exists(vgg):
self.vgg = vgg
else:
self.vgg = torchvision.models.vgg16(pretrained = True)
self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])
self.discr = Discriminator(dims = dims, channels = channels)
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
def copy_for_eval(self):
device = next(self.parameters()).device
vae_copy = copy.deepcopy(self.cpu())
if vae_copy.use_vgg_and_gan:
del vae_copy.discr
del vae_copy.vgg
vae_copy.eval()
return vae_copy.to(device)
@remove_vgg
def state_dict(self, *args, **kwargs):
return super().state_dict(*args, **kwargs)
@remove_vgg
def load_state_dict(self, *args, **kwargs):
return super().load_state_dict(*args, **kwargs)
@property
def codebook(self):
return self.vq.codebook
def encode(self, fmap):
for enc in self.encoders:
fmap = enc(fmap)
return self.vq(fmap)
def decode(self, fmap):
for dec in self.decoders:
fmap = dec(fmap)
return fmap
@torch.no_grad()
@eval_decorator
def codebook_indices_to_video(self, indices):
b = indices.shape[0]
codes = self.codebook[indices]
codes = rearrange(codes, 'b (f h w) d -> (b f) d h w', h = self.fmap_size, w = self.fmap_size)
video = self.decode(codes)
return rearrange(video, '(b f) ... -> b f ...', b = b)
@torch.no_grad()
@eval_decorator
def get_video_indices(self, video):
b, f, _, h, w = video.shape
images = rearrange(video, 'b f ... -> (b f) ...')
_, indices, _ = self.encode(images)
return rearrange(indices, '(b f) ... -> b f ...', b = b)
def forward(
self,
img,
return_loss = False,
return_discr_loss = False,
return_recons = False,
apply_grad_penalty = False
):
batch, channels, height, width, device = *img.shape, img.device
        assert height == self.image_size and width == self.image_size, f'height and width of input image must be equal to {self.image_size}'
assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
fmap, indices, commit_loss = self.encode(img)
fmap = self.decode(fmap)
if not return_loss and not return_discr_loss:
return fmap
assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'
if return_discr_loss:
assert exists(self.discr), 'discriminator must exist to train it'
fmap.detach_()
img.requires_grad_()
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
if apply_grad_penalty:
gp = gradient_penalty(img, img_discr_logits)
loss = loss + gp
if return_recons:
return loss, fmap
return loss
recon_loss = self.recon_loss_fn(fmap, img)
if not self.use_vgg_and_gan:
if return_recons:
return recon_loss, fmap
return recon_loss
img_vgg_input = img
fmap_vgg_input = fmap
if img.shape[1] == 1:
img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))
img_vgg_feats = self.vgg(img_vgg_input)
recon_vgg_feats = self.vgg(fmap_vgg_input)
perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)
gen_loss = self.gen_loss(self.discr(fmap))
last_dec_layer = self.decoders[-1].weight
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
adaptive_weight.clamp_(max = 1e4)
loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss
if return_recons:
return loss, fmap
return loss
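The closing lines implement the VQGAN loss-balancing trick: the adversarial term is scaled by an adaptive weight, the ratio of the gradient norms of the perceptual and generator losses at the last decoder layer (clamped at 1e4), so the GAN objective cannot overwhelm reconstruction early in training.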