Lucidrains 系列项目源码解析(八)
.\lucidrains\big-sleep\big_sleep\big_sleep.py
# 导入必要的库
import os
import sys
import subprocess
import signal
import string
import re
from datetime import datetime
from pathlib import Path
import random
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import Adam
from torchvision.utils import save_image
import torchvision.transforms as T
from PIL import Image
from tqdm import tqdm, trange
from big_sleep.ema import EMA
from big_sleep.resample import resample
from big_sleep.biggan import BigGAN
from big_sleep.clip import load, tokenize
# Big Sleep runs BigGAN + CLIP on the GPU; fail fast at import when CUDA is missing.
assert torch.cuda.is_available(), 'CUDA must be available in order to use Big Sleep'
# Flag polled by Imagine.forward's loops so Ctrl-C stops training gracefully
# instead of raising KeyboardInterrupt mid-step.
terminate = False

def signal_handling(signum, frame):
    """SIGINT handler: print a notice and set the global terminate flag."""
    print('detecting keyboard interrupt, gracefully exiting')
    global terminate
    terminate = True

signal.signal(signal.SIGINT, signal_handling)
def exists(val):
    """Return True when *val* is not None."""
    return not (val is None)
def open_folder(path):
    """Best-effort: open *path* in the OS file browser; silently no-op on failure.

    If *path* is a file, its containing directory is opened instead. Unknown
    platforms and subprocess failures are ignored — opening the folder is
    purely cosmetic and must never crash a training run.
    """
    if os.path.isfile(path):
        path = os.path.dirname(path)

    if not os.path.isdir(path):
        return

    cmd_list = None
    if sys.platform == 'darwin':
        cmd_list = ['open', '--', path]
    elif sys.platform == 'linux2' or sys.platform == 'linux':
        cmd_list = ['xdg-open', path]
    elif sys.platform in ['win32', 'win64']:
        cmd_list = ['explorer', path.replace('/', '\\')]

    # fix: identity comparison with None (was `cmd_list == None`)
    if cmd_list is None:
        return

    try:
        subprocess.check_call(cmd_list)
    except (subprocess.CalledProcessError, OSError):
        pass
def create_text_path(text=None, img=None, encoding=None):
    """Build a filesystem-safe filename stem from the prompt inputs."""
    stem = ""
    if text is not None:
        stem += text
    if img is not None:
        if isinstance(img, str):
            # drop the extension, then drop any directory components
            img_name = "".join(img.split(".")[:-1])
            img_name = img_name.split("/")[-1]
        else:
            img_name = "PIL_img"
        stem += "_" + img_name
    if encoding is not None:
        # raw tensor encodings have no printable name
        stem = "your_encoding"
    sanitized = (stem
                 .replace("-", "_")
                 .replace(",", "")
                 .replace(" ", "_")
                 .replace("|", "--"))
    # trim stray separators and cap at a safe filename length
    return sanitized.strip('-_')[:255]
def differentiable_topk(x, k, temperature=1.):
    """Soft, differentiable top-k over the last dim of a 2-D tensor.

    Greedily selects k classes: each pass takes the softmax argmax (keeping
    its softmax value so gradients flow), then masks that class out with -inf
    before the next pass. Returns shape (n, dim) with k nonzero entries per row.
    """
    n, dim = x.shape
    selections = []
    for step in range(k):
        probs = (x / temperature).softmax(dim=-1)
        values, indices = probs.topk(1, dim=-1)
        selections.append(torch.zeros_like(x).scatter_(-1, indices, values))
        if step < k - 1:
            # exclude the chosen class from subsequent passes
            x = x.scatter(-1, indices, float('-inf'))

    stacked = torch.cat(selections, dim=-1)
    return stacked.reshape(n, k, dim).sum(dim=1)
def create_clip_img_transform(image_width):
    """Preprocessing pipeline producing CLIP-normalized square tensors."""
    # normalization constants published with OpenAI's CLIP
    clip_mean = [0.48145466, 0.4578275, 0.40821073]
    clip_std = [0.26862954, 0.26130258, 0.27577711]
    return T.Compose([
        T.Resize(image_width),
        T.CenterCrop((image_width, image_width)),
        T.ToTensor(),
        T.Normalize(mean=clip_mean, std=clip_std),
    ])
def rand_cutout(image, size, center_bias=False, center_focus=2):
    """Crop a random (size x size) patch from a NCHW image tensor.

    With center_bias, offsets are drawn from a gaussian around the image
    center (narrower as center_focus grows), falling back to uniform when a
    draw lands outside the valid range.
    """
    width = image.shape[-1]
    lo, hi = 0, width - size
    if center_bias:
        # gaussian sampling centered on the image middle
        center = hi / 2
        std = center / center_focus
        offset_x = int(random.gauss(mu=center, sigma=std))
        offset_y = int(random.gauss(mu=center, sigma=std))
        # resample uniformly if the gaussian draw is out of bounds
        if offset_x > hi or offset_x < lo:
            offset_x = random.randint(lo, hi)
        if offset_y > hi or offset_y < lo:
            offset_y = random.randint(lo, hi)
    else:
        offset_x = random.randint(lo, hi)
        offset_y = random.randint(lo, hi)
    return image[:, :, offset_x:offset_x + size, offset_y:offset_y + size]
class Latents(torch.nn.Module):
    """Trainable latent vectors and class logits fed into BigGAN.

    One latent/class row per conditioned BigGAN layer (plus one), so each
    layer can be steered independently during optimization.
    """
    def __init__(
        self,
        num_latents = 15,
        num_classes = 1000,
        z_dim = 128,
        max_classes = None,
        class_temperature = 2.
    ):
        super().__init__()
        # latent vectors, initialized from a standard normal
        self.normu = torch.nn.Parameter(torch.zeros(num_latents, z_dim).normal_(std = 1))
        # class logits, initialized low (mean -3.9) so sigmoid starts near zero
        self.cls = torch.nn.Parameter(torch.zeros(num_latents, num_classes).normal_(mean = -3.9, std = .3))
        # threshold used by BigSleep's latent magnitude loss; a buffer so it follows .cuda()
        self.register_buffer('thresh_lat', torch.tensor(1))

        assert not exists(max_classes) or max_classes > 0 and max_classes <= num_classes, f'max_classes must be between 0 and {num_classes}'
        self.max_classes = max_classes
        self.class_temperature = class_temperature

    def forward(self):
        # soft top-k over class logits when restricted, otherwise per-class sigmoid
        if exists(self.max_classes):
            classes = differentiable_topk(self.cls, self.max_classes, temperature = self.class_temperature)
        else:
            classes = torch.sigmoid(self.cls)

        return self.normu, classes
class Model(nn.Module):
    """Pretrained BigGAN generator driven by EMA-tracked trainable latents."""
    def __init__(
        self,
        image_size,
        max_classes = None,
        class_temperature = 2.,
        ema_decay = 0.99
    ):
        super().__init__()
        assert image_size in (128, 256, 512), 'image size must be one of 128, 256, or 512'
        self.biggan = BigGAN.from_pretrained(f'biggan-deep-{image_size}')
        self.max_classes = max_classes
        self.class_temperature = class_temperature
        self.ema_decay = ema_decay

        self.init_latents()

    def init_latents(self):
        # fresh latents sized to match the loaded BigGAN, wrapped in EMA
        latents = Latents(
            num_latents = len(self.biggan.config.layers) + 1,
            num_classes = self.biggan.config.num_classes,
            z_dim = self.biggan.config.z_dim,
            max_classes = self.max_classes,
            class_temperature = self.class_temperature
        )
        self.latents = EMA(latents, self.ema_decay)

    def forward(self):
        self.biggan.eval()
        # BigGAN outputs in [-1, 1]; rescale to [0, 1] (truncation fixed at 1)
        out = self.biggan(*self.latents(), 1)
        return (out + 1) / 2
class BigSleep(nn.Module):
    """CLIP-guided BigGAN generation.

    forward() renders the current BigGAN output, takes many random cutouts,
    embeds them with CLIP, and returns the image together with
    (latent regularization, class sparsity, text similarity) losses.
    """
    def __init__(
        self,
        num_cutouts = 128,
        loss_coef = 100,
        image_size = 512,
        bilinear = False,
        max_classes = None,
        class_temperature = 2.,
        experimental_resample = False,
        ema_decay = 0.99,
        center_bias = False,
        larger_clip = False
    ):
        super().__init__()
        self.loss_coef = loss_coef
        self.image_size = image_size
        self.num_cutouts = num_cutouts
        self.experimental_resample = experimental_resample
        self.center_bias = center_bias

        # how cutouts get resized to CLIP's 224px input
        self.interpolation_settings = {'mode': 'bilinear', 'align_corners': False} if bilinear else {'mode': 'nearest'}

        model_name = 'ViT-B/32' if not larger_clip else 'ViT-L/14'
        # CLIP perceptor and its matching image normalization transform
        self.perceptor, self.normalize_image = load(model_name, jit = False)

        self.model = Model(
            image_size = image_size,
            max_classes = max_classes,
            class_temperature = class_temperature,
            ema_decay = ema_decay
        )

    def reset(self):
        # re-initialize the trainable latents (keeps the loaded BigGAN/CLIP)
        self.model.init_latents()

    def sim_txt_to_img(self, text_embed, img_embed, text_type="max"):
        # "max" prompts are maximized (negative loss); "min" prompts penalized
        sign = -1
        if text_type == "min":
            sign = 1
        return sign * self.loss_coef * torch.cosine_similarity(text_embed, img_embed, dim = -1).mean()

    # NOTE(review): text_min_embeds=[] is a mutable default; harmless here
    # since it is only iterated, never mutated.
    def forward(self, text_embeds, text_min_embeds=[], return_loss = True):
        """Return the generated image, plus losses unless return_loss is False."""
        width, num_cutouts = self.image_size, self.num_cutouts

        out = self.model()

        if not return_loss:
            return out

        pieces = []
        for ch in range(num_cutouts):
            # random cutout size, clipped to roughly [0.5, 0.95] of image width
            size = int(width * torch.zeros(1,).normal_(mean=.8, std=.3).clip(.5, .95))
            apper = rand_cutout(out, size, center_bias=self.center_bias)
            if (self.experimental_resample):
                apper = resample(apper, (224, 224))
            else:
                apper = F.interpolate(apper, (224, 224), **self.interpolation_settings)
            pieces.append(apper)

        into = torch.cat(pieces)
        into = self.normalize_image(into)

        image_embed = self.perceptor.encode_image(into)

        latents, soft_one_hot_classes = self.model.latents()
        num_latents = latents.shape[0]
        latent_thres = self.model.latents.model.thresh_lat

        # keep each latent vector close to a standard normal (unit std, zero
        # mean) and bound its overall magnitude via thresh_lat
        lat_loss = torch.abs(1 - torch.std(latents, dim=1)).mean() + \
                   torch.abs(torch.mean(latents, dim = 1)).mean() + \
                   4 * torch.max(torch.square(latents).mean(), latent_thres)

        # additionally penalize skew and excess kurtosis of each latent row
        for array in latents:
            mean = torch.mean(array)
            diffs = array - mean
            var = torch.mean(torch.pow(diffs, 2.0))
            std = torch.pow(var, 0.5)
            zscores = diffs / std
            skews = torch.mean(torch.pow(zscores, 3.0))
            kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0

            lat_loss = lat_loss + torch.abs(kurtoses) / num_latents + torch.abs(skews) / num_latents

        # push down the 999 weakest of the 1000 class activations (sparsity)
        cls_loss = ((50 * torch.topk(soft_one_hot_classes, largest = False, dim = 1, k = 999)[0]) ** 2).mean()

        results = []
        # similarity to prompts we want to match...
        for txt_embed in text_embeds:
            results.append(self.sim_txt_to_img(txt_embed, image_embed))
        # ...and to prompts we want to avoid
        for txt_min_embed in text_min_embeds:
            results.append(self.sim_txt_to_img(txt_min_embed, image_embed, "min"))
        sim_loss = sum(results).mean()

        return out, (lat_loss, cls_loss, sim_loss)
class Imagine(nn.Module):
    """High-level training loop: optimizes BigSleep latents to match text /
    image prompts, periodically saving the generated image to disk."""
    def __init__(
        self,
        *,
        text=None,
        img=None,
        encoding=None,
        text_min = "",
        lr = .07,
        image_size = 512,
        gradient_accumulate_every = 1,
        save_every = 50,
        epochs = 20,
        iterations = 1050,
        save_progress = False,
        bilinear = False,
        open_folder = True,
        seed = None,
        append_seed = False,
        torch_deterministic = False,
        max_classes = None,
        class_temperature = 2.,
        save_date_time = False,
        save_best = False,
        experimental_resample = False,
        ema_decay = 0.99,
        num_cutouts = 128,
        center_bias = False,
        larger_clip = False
    ):
        super().__init__()

        if torch_deterministic:
            assert not bilinear, 'the deterministic (seeded) operation does not work with interpolation (PyTorch 1.7.1)'
            # NOTE(review): torch.set_deterministic was removed in later
            # PyTorch; newer versions use torch.use_deterministic_algorithms
            torch.set_deterministic(True)

        self.seed = seed
        self.append_seed = append_seed

        if exists(seed):
            print(f'setting seed of {seed}')
            if seed == 0:
                print('you can override this with --seed argument in the command line, or --random for a randomly chosen one')
            torch.manual_seed(seed)

        self.epochs = epochs
        self.iterations = iterations

        model = BigSleep(
            image_size = image_size,
            bilinear = bilinear,
            max_classes = max_classes,
            class_temperature = class_temperature,
            experimental_resample = experimental_resample,
            ema_decay = ema_decay,
            num_cutouts = num_cutouts,
            center_bias = center_bias,
            larger_clip = larger_clip
        ).cuda()

        self.model = model
        self.lr = lr
        # only the latents (not BigGAN / CLIP weights) are optimized
        self.optimizer = Adam(model.model.latents.model.parameters(), lr)
        self.gradient_accumulate_every = gradient_accumulate_every
        self.save_every = save_every
        self.save_progress = save_progress
        self.save_date_time = save_date_time
        self.save_best = save_best
        self.current_best_score = 0
        self.open_folder = open_folder
        self.total_image_updates = (self.epochs * self.iterations) / self.save_every
        # CLIP embeddings of prompts to maximize ("max") and penalize ("min")
        self.encoded_texts = {
            "max": [],
            "min": []
        }
        self.clip_transform = create_clip_img_transform(224)
        self.set_clip_encoding(text=text, img=img, encoding=encoding, text_min=text_min)

    @property
    def seed_suffix(self):
        # filename suffix carrying the seed, when requested
        return f'.{self.seed}' if self.append_seed and exists(self.seed) else ''

    def set_text(self, text):
        """Convenience wrapper: switch to a new text prompt."""
        self.set_clip_encoding(text = text)

    def create_clip_encoding(self, text=None, img=None, encoding=None):
        """Return a CLIP embedding for a text, an image, or their average."""
        self.text = text
        self.img = img
        if encoding is not None:
            encoding = encoding.cuda()
        elif text is not None and img is not None:
            encoding = (self.create_text_encoding(text) + self.create_img_encoding(img)) / 2
        elif text is not None:
            encoding = self.create_text_encoding(text)
        elif img is not None:
            encoding = self.create_img_encoding(img)
        return encoding

    def create_text_encoding(self, text):
        """Tokenize *text* and encode it with CLIP (no gradients)."""
        tokenized_text = tokenize(text).cuda()
        with torch.no_grad():
            text_encoding = self.model.perceptor.encode_text(tokenized_text).detach()
        return text_encoding

    def create_img_encoding(self, img):
        """Load (if given a path), preprocess and CLIP-encode an image (no gradients)."""
        if isinstance(img, str):
            img = Image.open(img)
        normed_img = self.clip_transform(img).unsqueeze(0).cuda()
        with torch.no_grad():
            img_encoding = self.model.perceptor.encode_image(normed_img).detach()
        return img_encoding

    def encode_multiple_phrases(self, text, img=None, encoding=None, text_type="max"):
        # "|" delimits multiple phrases, each encoded separately
        if text is not None and "|" in text:
            self.encoded_texts[text_type] = [self.create_clip_encoding(text=prompt_min, img=img, encoding=encoding) for prompt_min in text.split("|")]
        else:
            self.encoded_texts[text_type] = [self.create_clip_encoding(text=text, img=img, encoding=encoding)]

    def encode_max_and_min(self, text, img=None, encoding=None, text_min=""):
        # encode prompts to maximize, then (optionally) prompts to penalize
        self.encode_multiple_phrases(text, img=img, encoding=encoding)
        if text_min is not None and text_min != "":
            self.encode_multiple_phrases(text_min, img=img, encoding=encoding, text_type="min")

    def set_clip_encoding(self, text=None, img=None, encoding=None, text_min=""):
        """Set new prompts: reset the best score, derive the output filename
        and encode all phrases."""
        self.current_best_score = 0
        self.text = text
        self.text_min = text_min

        if len(text_min) > 0:
            # fold the penalized prompt into the filename stem
            text = text + "_wout_" + text_min[:255] if text is not None else "wout_" + text_min[:255]
        text_path = create_text_path(text=text, img=img, encoding=encoding)
        if self.save_date_time:
            text_path = datetime.now().strftime("%y%m%d-%H%M%S-") + text_path

        self.text_path = text_path
        self.filename = Path(f'./{text_path}{self.seed_suffix}.png')
        self.encode_max_and_min(text, img=img, encoding=encoding, text_min=text_min) # Tokenize and encode each prompt

    def reset(self):
        """Re-initialize the latents and re-create the optimizer."""
        self.model.reset()
        self.model = self.model.cuda()
        self.optimizer = Adam(self.model.model.latents.parameters(), self.lr)

    def train_step(self, epoch, i, pbar=None):
        """Run one optimization step (with gradient accumulation); every
        save_every iterations, save the current image (and optionally
        progress/best snapshots)."""
        total_loss = 0

        for _ in range(self.gradient_accumulate_every):
            out, losses = self.model(self.encoded_texts["max"], self.encoded_texts["min"])
            loss = sum(losses) / self.gradient_accumulate_every
            total_loss += loss
            loss.backward()

        self.optimizer.step()
        # advance the EMA copy of the latents after each optimizer step
        self.model.model.latents.update()
        self.optimizer.zero_grad()

        if (i + 1) % self.save_every == 0:
            with torch.no_grad():
                # eval mode makes the EMA wrapper serve the averaged latents
                self.model.model.latents.eval()
                out, losses = self.model(self.encoded_texts["max"], self.encoded_texts["min"])
                # lower similarity loss is better
                top_score, best = torch.topk(losses[2], k=1, largest=False)
                image = self.model.model()[best].cpu()
                self.model.model.latents.train()

                save_image(image, str(self.filename))
                if pbar is not None:
                    pbar.update(1)
                else:
                    print(f'image updated at "./{str(self.filename)}"')

                if self.save_progress:
                    total_iterations = epoch * self.iterations + i
                    num = total_iterations // self.save_every
                    save_image(image, Path(f'./{self.text_path}.{num}{self.seed_suffix}.png'))

                # NOTE(review): current_best_score starts at 0 and uses `<`,
                # so only negative (i.e. actually similar) scores get saved
                if self.save_best and top_score.item() < self.current_best_score:
                    self.current_best_score = top_score.item()
                    save_image(image, Path(f'./{self.text_path}{self.seed_suffix}.best.png'))

        return out, total_loss

    def forward(self):
        """Run the full training loop (epochs x iterations), honoring the
        module-level terminate flag for graceful Ctrl-C shutdown."""
        penalizing = ""
        if len(self.text_min) > 0:
            penalizing = f'penalizing "{self.text_min}"'
        print(f'Imagining "{self.text_path}" {penalizing}...')

        # one throwaway forward pass; works around a CLIP + CUDA init issue
        with torch.no_grad():
            self.model(self.encoded_texts["max"][0])

        if self.open_folder:
            open_folder('./')
            self.open_folder = False

        image_pbar = tqdm(total=self.total_image_updates, desc='image update', position=2, leave=True)
        epoch_pbar = trange(self.epochs, desc = ' epochs', position=0, leave=True)
        for epoch in (ep for ep in epoch_pbar if not terminate):
            pbar = trange(self.iterations, desc=' iteration', position=1, leave=True)
            image_pbar.update(0)
            for i in (it for it in pbar if not terminate):
                out, loss = self.train_step(epoch, i, image_pbar)
                pbar.set_description(f'loss: {loss.item():04.2f}')
.\lucidrains\big-sleep\big_sleep\cli.py
# 导入 fire 模块,用于命令行接口
import fire
# 导入 random 模块并重命名为 rnd
import random as rnd
# 从 big_sleep 模块中导入 Imagine 类和 version 变量
from big_sleep import Imagine, version
# 从 pathlib 模块中导入 Path 类
from pathlib import Path
# 从当前目录下的 version 模块中导入 __version__ 变量
from .version import __version__;
def train(
    text=None,
    img=None,
    text_min="",
    lr = .07,
    image_size = 512,
    gradient_accumulate_every = 1,
    epochs = 20,
    iterations = 1050,
    save_every = 50,
    overwrite = False,
    save_progress = False,
    save_date_time = False,
    bilinear = False,
    open_folder = True,
    seed = 0,
    append_seed = False,
    random = False,
    torch_deterministic = False,
    max_classes = None,
    class_temperature = 2.,
    save_best = False,
    experimental_resample = False,
    ema_decay = 0.5,
    num_cutouts = 128,
    center_bias = False,
    larger_model = False
):
    """CLI entry point: configure an Imagine session from command-line flags
    (via python-fire) and run it.

    Prompts for confirmation before overwriting an existing output image
    unless --overwrite is passed.
    """
    print(f'Starting up... v{__version__}')

    if random:
        # bug fix: rnd.randint requires integer bounds — `1e6` is a float
        # and raises on Python >= 3.10
        seed = rnd.randint(0, 1_000_000)

    imagine = Imagine(
        text=text,
        img=img,
        text_min=text_min,
        lr = lr,
        image_size = image_size,
        gradient_accumulate_every = gradient_accumulate_every,
        epochs = epochs,
        iterations = iterations,
        save_every = save_every,
        save_progress = save_progress,
        bilinear = bilinear,
        seed = seed,
        append_seed = append_seed,
        torch_deterministic = torch_deterministic,
        open_folder = open_folder,
        max_classes = max_classes,
        class_temperature = class_temperature,
        save_date_time = save_date_time,
        save_best = save_best,
        experimental_resample = experimental_resample,
        ema_decay = ema_decay,
        num_cutouts = num_cutouts,
        center_bias = center_bias,
        larger_clip = larger_model
    )

    if not overwrite and imagine.filename.exists():
        answer = input('Imagined image already exists, do you want to overwrite? (y/n) ').lower()
        if answer not in ('yes', 'y'):
            exit()

    imagine()
def main():
    """Console entry point: expose `train` as the `dream` CLI via python-fire."""
    fire.Fire(train)
.\lucidrains\big-sleep\big_sleep\ema.py
# 导入必要的库
from copy import deepcopy
import torch
from torch import nn
class EMA(nn.Module):
    """Exponential moving average of a model's parameters, with bias correction.

    Keeps a `_biased` running average and a bias-corrected `average` copy
    (the same debiasing scheme Adam uses). In training mode forward() uses
    the live model; in eval mode it uses the averaged weights.
    """
    def __init__(self, model, decay):
        super().__init__()
        self.model = model
        self.decay = decay
        # running product of decay factors, used for bias correction
        self.register_buffer('accum', torch.tensor(1.))

        self._biased = deepcopy(self.model)
        self.average = deepcopy(self.model)
        # both copies start at zero; debiasing compensates for the cold start
        for param in self._biased.parameters():
            param.detach_().zero_()
        for param in self.average.parameters():
            param.detach_().zero_()
        self.update()

    @torch.no_grad()
    def update(self):
        """Fold the model's current parameters into the moving averages."""
        assert self.training, 'Update should only be called during training'

        self.accum *= self.decay

        model_params = dict(self.model.named_parameters())
        biased_params = dict(self._biased.named_parameters())
        average_params = dict(self.average.named_parameters())
        assert model_params.keys() == biased_params.keys() == average_params.keys(), f'Model parameter keys incompatible with EMA stored parameter keys'

        for name, param in model_params.items():
            # biased EMA step, then debias: average = biased / (1 - decay^t)
            biased_params[name].mul_(self.decay)
            biased_params[name].add_((1 - self.decay) * param)
            average_params[name].copy_(biased_params[name])
            average_params[name].div_(1 - self.accum)

        model_buffers = dict(self.model.named_buffers())
        biased_buffers = dict(self._biased.named_buffers())
        average_buffers = dict(self.average.named_buffers())
        assert model_buffers.keys() == biased_buffers.keys() == average_buffers.keys()

        # buffers are copied verbatim, not averaged
        for name, buffer in model_buffers.items():
            biased_buffers[name].copy_(buffer)
            average_buffers[name].copy_(buffer)

    def forward(self, *args, **kwargs):
        # training: live model; eval: bias-corrected average
        if self.training:
            return self.model(*args, **kwargs)
        return self.average(*args, **kwargs)
.\lucidrains\big-sleep\big_sleep\resample.py
"""Good differentiable image resampling for PyTorch."""
# 导入所需的库
from functools import update_wrapper
import math
import torch
from torch.nn import functional as F
def sinc(x):
    """Normalized sinc: sin(pi*x)/(pi*x), with the removable singularity sinc(0) = 1."""
    at_zero = x.new_ones([])
    return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), at_zero)
def lanczos(x, a):
    """Lanczos window of order *a* evaluated at positions x, normalized to sum to 1."""
    inside = torch.logical_and(-a < x, x < a)
    window = torch.where(inside, sinc(x) * sinc(x / a), x.new_zeros([]))
    # normalize so the filter preserves overall brightness
    return window / window.sum()
# 定义 ramp 函数
def ramp(ratio, width):
n = math.ceil(width / ratio + 1)
out = torch.empty([n])
cur = 0
for i in range(out.shape[0]):
out[i] = cur
cur += ratio
return torch.cat([-out[1:].flip([0]), out])[1:-1]
def odd(fn):
    """Extend *fn* (intended for x >= 0) to an odd function: f(-x) = -f(x)."""
    def oddified(x):
        return torch.sign(x) * fn(abs(x))
    # preserve fn's metadata on the wrapper
    return update_wrapper(oddified, fn)
# 定义将输入转换为线性 sRGB 的函数
def _to_linear_srgb(input):
cond = input <= 0.04045
a = input / 12.92
b = ((input + 0.055) / 1.055)**2.4
return torch.where(cond, a, b)
# 定义将输入转换为非线性 sRGB 的函数
def _to_nonlinear_srgb(input):
cond = input <= 0.0031308
a = 12.92 * input
b = 1.055 * input**(1/2.4) - 0.055
return torch.where(cond, a, b)
# Odd extensions so negative (out-of-gamut) values round-trip with their sign preserved.
to_linear_srgb = odd(_to_linear_srgb)
to_nonlinear_srgb = odd(_to_nonlinear_srgb)
def resample(input, size, align_corners=True, is_srgb=False):
    """Differentiable image resampling: Lanczos-3 anti-alias prefilter (only
    when downsampling) followed by bicubic interpolation.

    input: (n, c, h, w) tensor. size: (dh, dw) target. When is_srgb is True,
    filtering is done in linear light and converted back afterwards.
    """
    n, c, h, w = input.shape
    dh, dw = size

    if is_srgb:
        input = to_linear_srgb(input)

    # fold channels into the batch so conv2d filters each channel independently
    input = input.view([n * c, 1, h, w])

    if dh < h:
        # vertical anti-alias pass; kernel length is odd, so reflect padding
        # of (k-1)//2 on each side keeps the height unchanged after conv
        kernel_h = lanczos(ramp(dh / h, 3), 3).to(input.device, input.dtype)
        pad_h = (kernel_h.shape[0] - 1) // 2
        input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
        input = F.conv2d(input, kernel_h[None, None, :, None])

    if dw < w:
        # horizontal anti-alias pass, same scheme as above
        kernel_w = lanczos(ramp(dw / w, 3), 3).to(input.device, input.dtype)
        pad_w = (kernel_w.shape[0] - 1) // 2
        input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
        input = F.conv2d(input, kernel_w[None, None, None, :])

    input = input.view([n, c, h, w])
    input = F.interpolate(input, size, mode='bicubic', align_corners=align_corners)

    if is_srgb:
        input = to_nonlinear_srgb(input)

    return input
.\lucidrains\big-sleep\big_sleep\version.py
# Package version, read by setup.py and printed by the CLI banner.
__version__ = '0.9.1'
.\lucidrains\big-sleep\big_sleep\__init__.py
# 从 big_sleep.big_sleep 模块中导入 BigSleep 和 Imagine 类
from big_sleep.big_sleep import BigSleep, Imagine

artificial intelligence

cosmic love and attention

fire in the sky

a pyramid made of ice

a lonely house in the woods

marriage in the mountains

lantern dangling from a tree in a foggy graveyard

a vivid dream

balloons over the ruins of a city

the death of the lonesome astronomer - by moirage

the tragic intimacy of the eternal conversation with oneself - by moirage

demon fire - by WiseNat
Big Sleep
Ryan Murdock has done it again, combining OpenAI's CLIP and the generator from a BigGAN! This repository wraps up his work so it is easily accessible to anyone who owns a GPU.
You will be able to have the GAN dream up images using natural language with a one-line command in the terminal.
User-made notebook with bugfixes and added features, like google drive integration
Install
$ pip install big-sleep
Usage
$ dream "a pyramid made of ice"
Images will be saved to wherever the command is invoked
Advanced
You can invoke this in code with
from big_sleep import Imagine
dream = Imagine(
text = "fire in the sky",
lr = 5e-2,
save_every = 25,
save_progress = True
)
dream()
You can now train more than one phrase using the delimiter "|"
Train on Multiple Phrases
In this example we train on three phrases:
an armchair in the form of pikachu | an armchair imitating pikachu | abstract
from big_sleep import Imagine
dream = Imagine(
text = "an armchair in the form of pikachu|an armchair imitating pikachu|abstract",
lr = 5e-2,
save_every = 25,
save_progress = True
)
dream()
Penalize certain prompts as well!
In this example we train on the three phrases from before,
and penalize the phrases:
blur | zoom
from big_sleep import Imagine
dream = Imagine(
text = "an armchair in the form of pikachu|an armchair imitating pikachu|abstract",
text_min = "blur|zoom",
)
dream()
You can also set a new text by using the .set_text(<str>) command
dream.set_text("a quiet pond underneath the midnight moon")
And reset the latents with .reset()
dream.reset()
To save the progression of images during training, you simply have to supply the --save-progress flag
$ dream "a bowl of apples next to the fireplace" --save-progress --save-every 100
Due to the class conditioned nature of the GAN, Big Sleep often steers off the manifold into noise. You can use a flag to save the best high scoring image (per CLIP critic) to {filepath}.best.png in your folder.
$ dream "a room with a view of the ocean" --save-best
Larger model
If you have enough memory, you can also try using a bigger vision model released by OpenAI for improved generations.
$ dream "storm clouds rolling in over a white barnyard" --larger-model
Experimentation
You can set the number of classes that you wish to restrict Big Sleep to use for the Big GAN with the --max-classes flag as follows (ex. 15 classes). This may lead to extra stability during training, at the cost of lost expressivity.
$ dream 'a single flower in a withered field' --max-classes 15
Alternatives
Deep Daze - CLIP and a deep SIREN network
Citations
@misc{unpublished2021clip,
title = {CLIP: Connecting Text and Images},
author = {Alec Radford, Ilya Sutskever, Jong Wook Kim, Gretchen Krueger, Sandhini Agarwal},
year = {2021}
}
@misc{brock2019large,
title = {Large Scale GAN Training for High Fidelity Natural Image Synthesis},
author = {Andrew Brock and Jeff Donahue and Karen Simonyan},
year = {2019},
eprint = {1809.11096},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
.\lucidrains\big-sleep\setup.py
# 导入 sys 模块
import sys
# 从 setuptools 模块中导入 setup 和 find_packages 函数
from setuptools import setup, find_packages
# Prepend the package directory so version.py can be imported directly,
# without importing big_sleep itself (which pulls in heavy dependencies).
sys.path[0:0] = ['big_sleep']
from version import __version__

setup(
  name = 'big-sleep',
  packages = find_packages(),
  include_package_data = True,
  # console script: the `dream` command maps to the CLI entry point
  entry_points={
    'console_scripts': [
      'dream = big_sleep.cli:main',
    ],
  },
  version = __version__,
  license='MIT',
  description = 'Big Sleep',
  author = 'Ryan Murdock, Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/big-sleep',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'text to image',
    'generative adversarial networks'
  ],
  install_requires=[
    'torch>=1.7.1',
    'einops>=0.3',
    'fire',
    'ftfy',
    'pytorch-pretrained-biggan',
    'regex',
    'torchvision>=0.8.2',
    'tqdm'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
.\lucidrains\big-sleep\test\multi_prompt_minmax.py
# 导入所需的库
import time
import shutil
import torch
from big_sleep import Imagine
# Graceful-exit flag mirroring big_sleep's own SIGINT handling.
terminate = False

def signal_handling(signum, frame):
    """Set the terminate flag on SIGINT.

    NOTE(review): this handler is never registered via signal.signal in this
    script; kept unregistered to preserve the original behavior.
    """
    global terminate
    terminate = True

# Run several independent generation attempts, saving each best image.
num_attempts = 4

for attempt in range(num_attempts):
    dream = Imagine(
        text = "an armchair in the form of pikachu\\an armchair imitating pikachu\\abstract",
        text_min = "blur\\zoom",
        lr = 7e-2,
        image_size = 512,
        gradient_accumulate_every = 1,
        save_every = 50,
        epochs = 5,
        iterations = 50,
        save_progress = False,
        bilinear = False,
        open_folder = False,
        seed = None,
        torch_deterministic = False,
        max_classes = 20,
        class_temperature = 2.,
        save_date_time = False,
        save_best = True,
        experimental_resample = True,
        ema_decay = 0.99
    )
    dream()
    # bug fix: Imagine stores the filename stem as `text_path`, not `textpath`
    shutil.copy(dream.text_path + ".best.png", f"{attempt}.png")
    try:
        time.sleep(2)
        del dream
        time.sleep(2)
        # free GPU memory between attempts
        torch.cuda.empty_cache()
    except Exception:
        torch.cuda.empty_cache()
.\lucidrains\bit-diffusion\bit_diffusion\bit_diffusion.py
# 导入所需的库
import math
from pathlib import Path
from functools import partial
from multiprocessing import cpu_count
import torch
from torch import nn, einsum
from torch.special import expm1
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torchvision import transforms as T, utils
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange
from PIL import Image
from tqdm.auto import tqdm
from ema_pytorch import EMA
from accelerate import Accelerator

# Number of bits used to represent each channel value (8 bits -> 256 levels).
BITS = 8
def exists(x):
    """Return True unless *x* is None."""
    return not (x is None)
def default(val, d):
    """Return *val* when set; otherwise the fallback *d* (called if callable)."""
    if val is not None:
        return val
    if callable(d):
        return d()
    return d
def cycle(dl):
    """Yield items from *dl* forever, restarting iteration when exhausted."""
    while True:
        yield from dl
def has_int_squareroot(num):
    """True when *num* is a perfect square (within float sqrt precision)."""
    root = math.sqrt(num)
    return root * root == num
def num_to_groups(num, divisor):
    """Split *num* into full groups of *divisor* plus one remainder group."""
    full, remainder = divmod(num, divisor)
    groups = [divisor] * full
    if remainder > 0:
        groups.append(remainder)
    return groups
def convert_image_to(pil_img_type, image):
    """Convert a PIL image to the given mode, unless it already matches."""
    if image.mode == pil_img_type:
        return image
    return image.convert(pil_img_type)
class Residual(nn.Module):
    """Wraps *fn* with a skip connection: returns fn(x, ...) + x."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        out = self.fn(x, *args, **kwargs)
        return out + x
def Upsample(dim, dim_out = None):
    """2x nearest-neighbor upsample followed by a 3x3 conv (channels dim -> dim_out or dim)."""
    out_channels = dim if dim_out is None else dim_out
    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim, out_channels, 3, padding = 1),
    )
def Downsample(dim, dim_out = None):
    """Strided 4x4 conv halving spatial resolution (channels dim -> dim_out or dim)."""
    out_channels = dim if dim_out is None else dim_out
    return nn.Conv2d(dim, out_channels, 4, 2, 1)
class LayerNorm(nn.Module):
    """Channel-wise layer norm over dim 1 with a learned per-channel gain (no bias)."""
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        # looser epsilon under reduced precision
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        mu = torch.mean(x, dim = 1, keepdim = True)
        sigma2 = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        return (x - mu) * (sigma2 + eps).rsqrt() * self.g
class PreNorm(nn.Module):
    """Applies LayerNorm before *fn* (pre-normalization residual style)."""
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.norm(x))
class LearnedSinusoidalPosEmb(nn.Module):
    """Learned sinusoidal positional embedding, following @crowsonkb's lead.

    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8
    Maps a (b,) tensor of positions to (b, dim + 1): the raw position
    concatenated with sin/cos of learned frequencies.
    """
    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x):
        pos = x.unsqueeze(-1)                                   # (b,) -> (b, 1)
        freqs = pos * self.weights.unsqueeze(0) * 2 * math.pi   # (b, half_dim)
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        return torch.cat((pos, fouriered), dim = -1)
class Block(nn.Module):
    """conv -> group norm -> optional FiLM (scale, shift) -> SiLU."""
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        h = self.norm(self.proj(x))
        if scale_shift is not None:
            # FiLM conditioning: scale is an offset from identity (scale + 1)
            scale, shift = scale_shift
            h = h * (scale + 1) + shift
        return self.act(h)
class ResnetBlock(nn.Module):
    """Two conv Blocks with optional FiLM time conditioning and a residual skip."""
    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        # projects the time embedding to per-channel (scale, shift) pairs
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if exists(time_emb_dim) else None

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        # 1x1 conv matches channel counts on the skip path when needed
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            # broadcast over spatial dims: (b, c) -> (b, c, 1, 1)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            # split into the FiLM scale and shift halves
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x, scale_shift = scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)
class LinearAttention(nn.Module):
    """Linear-complexity attention: softmax over different axes for q and k,
    aggregating one global context instead of a full pairwise matrix."""
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        # a single 1x1 conv produces q, k, v stacked along channels
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, 1),
            LayerNorm(dim)
        )

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        # (b, heads*d, h, w) -> (b, heads, d, h*w)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)

        # softmax q over the feature dim and k over the sequence dim
        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)

        q = q * self.scale
        v = v / (h * w)

        # global context: (b, heads, d, e)
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
        return self.to_out(out)
class Attention(nn.Module):
    """Standard full (quadratic) self-attention over 2d feature maps."""

    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        # single 1x1 conv produces queries, keys and values at once
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x):
        """Attend over all pairs of spatial positions."""
        b, c, h, w = x.shape

        q, k, v = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = (rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads) for t in (q, k, v))

        # scaled dot-product similarity between all position pairs
        sim = einsum('b h d i, b h d j -> b h i j', q * self.scale, k)
        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        return self.to_out(out)
# unet for bit diffusion

class Unet(nn.Module):
    """Unet denoiser operating on bit-expanded images.

    Each image channel is expanded to `bits` analog-bit planes, and the input
    is concatenated with a same-shaped self-conditioning tensor, so the stem
    conv sees `channels * bits * 2` input channels.
    """

    def __init__(
        self,
        dim,
        init_dim = None,
        dim_mults=(1, 2, 4, 8),
        channels = 3,
        bits = BITS,
        resnet_block_groups = 8,
        learned_sinusoidal_dim = 16
    ):
        super().__init__()

        # determine dimensions: each image channel becomes `bits` bit planes
        channels *= bits
        self.channels = channels
        # doubled for the self-conditioning tensor concatenated in forward()
        input_channels = channels * 2

        init_dim = default(init_dim, dim)
        # large 7x7 stem conv, padding keeps the spatial size
        self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)

        # feature dims per resolution, e.g. [init_dim, dim, 2*dim, 4*dim, 8*dim]
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups = resnet_block_groups)

        # time embeddings
        time_dim = dim * 4

        sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim)
        # +1 because the raw time value is concatenated to the Fourier features
        # (assumes LearnedSinusoidalPosEmb outputs learned_sinusoidal_dim + 1 — TODO confirm)
        fourier_dim = learned_sinusoidal_dim + 1

        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        # downsampling path
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                # last stage keeps resolution with a plain conv instead of downsampling
                Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
            ]))

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
        # full (quadratic) attention only at the cheapest, bottleneck resolution
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)

        # upsampling path, mirroring the downsampling path
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            self.ups.append(nn.ModuleList([
                # `+ dim_in` accounts for the skip connection concatenated in forward()
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
            ]))

        # `dim * 2` because the stem output `r` is concatenated before this block
        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
        self.final_conv = nn.Conv2d(dim, channels, 1)

    def forward(self, x, time, x_self_cond = None):
        """Denoise `x` conditioned on `time` and optional self-conditioning.

        Args:
            x: bit-expanded input, shape (b, channels, h, w).
            time: per-sample noise-level conditioning fed to the time MLP.
            x_self_cond: previous x0 prediction with the same shape as `x`;
                defaults to zeros when absent.
        """
        # default self-conditioning to zeros, then concatenate on channels
        x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
        x = torch.cat((x_self_cond, x), dim = 1)

        x = self.init_conv(x)
        # keep a copy of the stem output for the final residual concatenation
        r = x.clone()

        t = self.time_mlp(time)

        # stack of intermediate activations used as skip connections
        h = []

        # downsampling: two resnet blocks + linear attention per resolution
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # upsampling: pop the skips in reverse order, concatenating on channels
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim = 1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim = 1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        # final residual with the stem output, then project to bit channels
        x = torch.cat((x, r), dim = 1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)
# convert images to bit representations and back

def decimal_to_bits(x, bits = BITS):
    """Convert an image tensor in [0, 1] to a bit tensor with values in {-1, 1}."""
    device = x.device

    # quantize to 8-bit integer pixel values
    quantized = (x * 255).int().clamp(0, 255)

    # one power-of-two mask per bit, broadcastable over height and width
    mask = 2 ** torch.arange(bits - 1, -1, -1, device = device)
    mask = rearrange(mask, 'd -> d 1 1')
    quantized = rearrange(quantized, 'b c h w -> b c 1 h w')

    # test each bit, fold the bit dim into channels, then map {0, 1} -> {-1, 1}
    bit_planes = ((quantized & mask) != 0).float()
    bit_planes = rearrange(bit_planes, 'b c d h w -> b (c d) h w')
    return bit_planes * 2 - 1
def bits_to_decimal(x, bits = BITS):
    """Convert analog bits in {-1, 1} back to an image tensor in [0, 1]."""
    device = x.device

    # threshold the analog bits to hard {0, 1} values
    hard_bits = (x > 0).int()
    mask = 2 ** torch.arange(bits - 1, -1, -1, device = device, dtype = torch.int32)
    mask = rearrange(mask, 'd -> d 1 1')

    # unfold the bit dim from channels and sum the weighted bit planes
    hard_bits = rearrange(hard_bits, 'b (c d) h w -> b c d h w', d = bits)
    dec = reduce(hard_bits * mask, 'b c d h w -> b c h w', 'sum')
    return (dec / 255).clamp(0., 1.)
# bit diffusion class

def log(t, eps = 1e-20):
    """Logarithm with the input clamped away from zero for numerical safety."""
    return t.clamp(min = eps).log()
def right_pad_dims_to(x, t):
    """Append trailing singleton dims to `t` until it matches `x.ndim` (for broadcasting)."""
    missing = x.ndim - t.ndim
    if missing <= 0:
        # t already has at least as many dims as x; return it untouched
        return t
    return t.view(*t.shape, *((1,) * missing))
def beta_linear_log_snr(t):
    """log(SNR) at continuous time `t` under the linear beta noise schedule."""
    snr_reciprocal = expm1(1e-4 + 10 * (t ** 2))
    return -torch.log(snr_reciprocal)
def alpha_cosine_log_snr(t, s: float = 0.008):
    """log(SNR) at continuous time `t` under the cosine noise schedule."""
    # not sure if this accounts for beta being clipped to 0.999 in the discrete version
    cos_val = torch.cos((t + s) / (1 + s) * math.pi * 0.5)
    return -log(cos_val ** -2 - 1, eps = 1e-5)
def log_snr_to_alpha_sigma(log_snr):
    """Map log(SNR) to the diffusion (alpha, sigma) coefficient pair."""
    alpha = torch.sigmoid(log_snr).sqrt()
    sigma = torch.sigmoid(-log_snr).sqrt()
    return alpha, sigma
class BitDiffusion(nn.Module):
    """Bit Diffusion: continuous-time diffusion over the analog-bit
    representation of images (values in {-1, 1} scaled by `bit_scale`).

    Args:
        model: denoising network exposing `.channels` and callable as
            model(noised_bits, noise_level, optional self-conditioning).
        image_size: height/width of the (square) images.
        timesteps: number of sampling steps.
        use_ddim: if True, sample with DDIM instead of ancestral DDPM.
        noise_schedule: 'linear' or 'cosine' log-SNR schedule.
        time_difference: asymmetric time offset subtracted from `time_next`
            during sampling (see paper).
        bit_scale: scale applied to the analog bits.
    """

    def __init__(
        self,
        model,
        *,
        image_size,
        timesteps = 1000,
        use_ddim = False,
        noise_schedule = 'cosine',
        time_difference = 0.,
        bit_scale = 1.
    ):
        super().__init__()
        self.model = model
        self.channels = self.model.channels

        self.image_size = image_size

        if noise_schedule == "linear":
            self.log_snr = beta_linear_log_snr
        elif noise_schedule == "cosine":
            self.log_snr = alpha_cosine_log_snr
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        self.bit_scale = bit_scale

        self.timesteps = timesteps
        self.use_ddim = use_ddim

        # proposed in the paper as a fix for insufficient self-conditioning and
        # to lower FID when sampling with fewer than ~400 timesteps; subtracted
        # from time_next during sampling
        self.time_difference = time_difference

    @property
    def device(self):
        return next(self.model.parameters()).device

    def get_sampling_timesteps(self, batch, *, device):
        """Return (time, time_next) pairs walking from t=1 down to t=0."""
        times = torch.linspace(1., 0., self.timesteps + 1, device = device)
        times = repeat(times, 't -> b t', b = batch)
        times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
        times = times.unbind(dim = -1)
        return times

    @torch.no_grad()
    def ddpm_sample(self, shape, time_difference = None):
        """Ancestral (DDPM) sampling; returns decoded images in [0, 1]."""
        batch, device = shape[0], self.device

        time_difference = default(time_difference, self.time_difference)

        time_pairs = self.get_sampling_timesteps(batch, device = device)

        img = torch.randn(shape, device = device)

        x_start = None

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', total = self.timesteps):

            # add the asymmetric time delay
            # bug fix: use the resolved local `time_difference` (honoring the
            # method argument) instead of always reading self.time_difference
            time_next = (time_next - time_difference).clamp(min = 0.)

            # log-SNR at the current time; also used (unpadded) to condition the model
            log_snr = self.log_snr(time)

            # predict x0, self-conditioned on the previous prediction
            x_start = self.model(img, log_snr, x_start)

            # clamp the analog bits to the valid range
            x_start.clamp_(-self.bit_scale, self.bit_scale)

            log_snr_next = self.log_snr(time_next)
            padded_log_snr, padded_log_snr_next = map(partial(right_pad_dims_to, img), (log_snr, log_snr_next))

            # alpha/sigma at the current and next times
            alpha, sigma = log_snr_to_alpha_sigma(padded_log_snr)
            alpha_next, sigma_next = log_snr_to_alpha_sigma(padded_log_snr_next)

            # derive the posterior mean and variance
            c = -expm1(padded_log_snr - padded_log_snr_next)

            mean = alpha_next * (img * (1 - c) / alpha + c * x_start)
            variance = (sigma_next ** 2) * c
            log_variance = log(variance)

            # add noise only while further steps remain (none on the final step)
            noise = torch.where(
                rearrange(time_next > 0, 'b -> b 1 1 1'),
                torch.randn_like(img),
                torch.zeros_like(img)
            )

            img = mean + (0.5 * log_variance).exp() * noise

        return bits_to_decimal(img)

    @torch.no_grad()
    def ddim_sample(self, shape, time_difference = None):
        """Deterministic (DDIM) sampling; returns decoded images in [0, 1]."""
        batch, device = shape[0], self.device

        time_difference = default(time_difference, self.time_difference)

        time_pairs = self.get_sampling_timesteps(batch, device = device)

        img = torch.randn(shape, device = device)

        x_start = None

        for times, times_next in tqdm(time_pairs, desc = 'sampling loop time step'):

            # add the asymmetric time delay
            times_next = (times_next - time_difference).clamp(min = 0.)

            # log-SNR at the current and next times
            log_snr = self.log_snr(times)
            log_snr_next = self.log_snr(times_next)

            padded_log_snr, padded_log_snr_next = map(partial(right_pad_dims_to, img), (log_snr, log_snr_next))

            alpha, sigma = log_snr_to_alpha_sigma(padded_log_snr)
            alpha_next, sigma_next = log_snr_to_alpha_sigma(padded_log_snr_next)

            # predict x0, self-conditioned on the previous prediction
            x_start = self.model(img, log_snr, x_start)

            # clamp the analog bits to the valid range
            x_start.clamp_(-self.bit_scale, self.bit_scale)

            # recover the predicted noise, guarding against division by ~0 sigma
            pred_noise = (img - alpha * x_start) / sigma.clamp(min = 1e-8)

            # deterministic DDIM step to the next time
            img = x_start * alpha_next + pred_noise * sigma_next

        return bits_to_decimal(img)

    @torch.no_grad()
    def sample(self, batch_size = 16):
        """Sample a batch of images via DDPM or DDIM depending on `self.use_ddim`."""
        image_size, channels = self.image_size, self.channels
        sample_fn = self.ddpm_sample if not self.use_ddim else self.ddim_sample
        return sample_fn((batch_size, channels, image_size, image_size))

    def forward(self, img, *args, **kwargs):
        """Training loss: MSE between the model's prediction and the clean bits."""
        batch, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'

        # sample random times uniformly in [0, 1)
        times = torch.zeros((batch,), device = device).float().uniform_(0, 1.)

        # convert the image to scaled analog bits
        img = decimal_to_bits(img) * self.bit_scale

        # noise the bits according to the schedule at the sampled times
        noise = torch.randn_like(img)

        noise_level = self.log_snr(times)
        padded_noise_level = right_pad_dims_to(img, noise_level)
        alpha, sigma = log_snr_to_alpha_sigma(padded_noise_level)

        noised_img = alpha * img + sigma * noise

        # self-conditioning: with 50% probability, first predict x_start from the
        # noised input and condition the real prediction on it (detached).
        # slows training by ~25% but seems to lower FID significantly
        self_cond = None
        if torch.rand((1)) < 0.5:
            with torch.no_grad():
                self_cond = self.model(noised_img, noise_level).detach_()

        # main prediction and gradient step target
        pred = self.model(noised_img, noise_level, self_cond)

        return F.mse_loss(pred, img)
# dataset classes

class Dataset(Dataset):
    """Folder-of-images dataset: recursively collects files with the given
    extensions and resizes/center-crops them to `image_size` tensors."""

    def __init__(
        self,
        folder,
        image_size,
        exts = ['jpg', 'jpeg', 'png', 'tiff'],
        augment_horizontal_flip = False,
        pil_img_type = None
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        # collect every matching file under `folder`, recursively
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        # optionally convert PIL images to a target mode before the pipeline
        maybe_convert_fn = partial(convert_image_to, pil_img_type) if exists(pil_img_type) else nn.Identity()

        self.transform = T.Compose([
            T.Lambda(maybe_convert_fn),
            T.Resize(image_size),
            T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        """Number of images found."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load and transform the image at `index`.

        Bug fix: PIL's Image.open keeps the underlying file handle open
        lazily; using it as a context manager closes the handle instead of
        leaking one per sample.
        """
        path = self.paths[index]
        with Image.open(path) as img:
            return self.transform(img)
# trainer class

class Trainer(object):
    """Training harness for a diffusion model using HuggingFace Accelerate.

    Handles data loading, gradient accumulation, mixed precision, EMA
    weight tracking, and periodic sampling/checkpointing.
    """

    def __init__(
        self,
        diffusion_model,
        folder,
        *,
        train_batch_size = 16,
        gradient_accumulate_every = 1,
        augment_horizontal_flip = True,
        train_lr = 1e-4,
        train_num_steps = 100000,
        ema_update_every = 10,
        ema_decay = 0.995,
        adam_betas = (0.9, 0.99),
        save_and_sample_every = 1000,
        num_samples = 25,
        results_folder = './results',
        amp = False,
        mixed_precision_type = 'fp16',
        split_batches = True,
        pil_img_type = None
    ):
        super().__init__()

        # accelerator handles device placement, DDP, and mixed precision
        self.accelerator = Accelerator(
            split_batches = split_batches,
            mixed_precision = mixed_precision_type if amp else 'no'
        )

        self.model = diffusion_model

        # samples are saved as a square grid, hence the square-root requirement
        assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
        self.num_samples = num_samples
        self.save_and_sample_every = save_and_sample_every

        self.batch_size = train_batch_size
        self.gradient_accumulate_every = gradient_accumulate_every

        self.train_num_steps = train_num_steps
        self.image_size = diffusion_model.image_size

        # dataset and dataloader
        self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip, pil_img_type = pil_img_type)
        dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())

        dl = self.accelerator.prepare(dl)
        # infinite iterator over the prepared dataloader
        self.dl = cycle(dl)

        # optimizer
        self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)

        # for logging results in a folder periodically
        # NOTE(review): self.ema and self.results_folder exist only on the main
        # process, but save() gates on is_local_main_process — on a multi-node
        # setup a local-main (non-global-main) process would hit a missing
        # attribute; presumably single-node usage is assumed — confirm
        if self.accelerator.is_main_process:
            self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)

            self.results_folder = Path(results_folder)
            self.results_folder.mkdir(exist_ok = True)

        # step counter state
        self.step = 0

        # prepare model, dataloader, optimizer with accelerator
        self.model, self.opt = self.accelerator.prepare(self.model, self.opt)

    def save(self, milestone):
        """Checkpoint step counter, model, optimizer, EMA and AMP scaler state."""
        if not self.accelerator.is_local_main_process:
            return

        data = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.model),
            'opt': self.opt.state_dict(),
            'ema': self.ema.state_dict(),
            'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None
        }

        torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))

    def load(self, milestone):
        """Restore trainer state from the checkpoint saved at `milestone`."""
        data = torch.load(str(self.results_folder / f'model-{milestone}.pt'))

        # unwrap any accelerator (DDP) wrapper before loading weights
        model = self.accelerator.unwrap_model(self.model)
        model.load_state_dict(data['model'])

        self.step = data['step']
        self.opt.load_state_dict(data['opt'])
        self.ema.load_state_dict(data['ema'])

        # restore AMP grad-scaler state only when both sides have one
        if exists(self.accelerator.scaler) and exists(data['scaler']):
            self.accelerator.scaler.load_state_dict(data['scaler'])

    def train(self):
        """Main optimization loop: accumulate gradients, step, update the EMA,
        and periodically sample and checkpoint on the main process."""
        accelerator = self.accelerator
        device = accelerator.device

        with tqdm(initial=self.step, total=self.train_num_steps, disable=not accelerator.is_main_process) as pbar:

            while self.step < self.train_num_steps:

                total_loss = 0.

                # gradient accumulation over several micro-batches
                for _ in range(self.gradient_accumulate_every):
                    data = next(self.dl).to(device)

                    with self.accelerator.autocast():
                        loss = self.model(data)
                        # normalize so the accumulated gradient matches one full batch
                        loss = loss / self.gradient_accumulate_every
                        total_loss += loss.item()

                    self.accelerator.backward(loss)

                pbar.set_description(f'loss: {total_loss:.4f}')

                # synchronize before and after the optimizer step
                accelerator.wait_for_everyone()

                self.opt.step()
                self.opt.zero_grad()

                accelerator.wait_for_everyone()

                if accelerator.is_main_process:
                    # keep the EMA copy on-device and update it
                    self.ema.to(device)
                    self.ema.update()

                    # periodically sample from the EMA model and checkpoint
                    if self.step != 0 and self.step % self.save_and_sample_every == 0:
                        self.ema.ema_model.eval()

                        with torch.no_grad():
                            milestone = self.step // self.save_and_sample_every
                            # split num_samples into batches no larger than batch_size
                            batches = num_to_groups(self.num_samples, self.batch_size)
                            all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))

                        # concatenate all sampled batches and save as one grid
                        all_images = torch.cat(all_images_list, dim=0)
                        utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow=int(math.sqrt(self.num_samples)))
                        self.save(milestone)

                self.step += 1
                pbar.update(1)

        accelerator.print('training complete')