Lucidrains Series: Annotated Source Code Walkthrough (Part 21)
.\lucidrains\DALLE2-pytorch\setup.py
# Import the required setup helpers
from setuptools import setup, find_packages
# Execute the version file so that __version__ is defined in the current namespace
exec(open('dalle2_pytorch/version.py').read())
# Configure the package metadata
setup(
# Package name
name = 'dalle2-pytorch',
# Discover all packages, excluding none
packages = find_packages(exclude=[]),
# Include package data files
include_package_data = True,
# Command-line entry points
entry_points={
'console_scripts': [
'dalle2_pytorch = dalle2_pytorch.cli:main',
'dream = dalle2_pytorch.cli:dream'
],
},
# Version number
version = __version__,
# License
license='MIT',
# Short description
description = 'DALL-E 2',
# Author
author = 'Phil Wang',
# Author email
author_email = 'lucidrains@gmail.com',
# Content type of the long description
long_description_content_type = 'text/markdown',
# Project URL
url = 'https://github.com/lucidrains/dalle2-pytorch',
# Keywords
keywords = [
'artificial intelligence',
'deep learning',
'text to image'
],
# Installation dependencies
install_requires=[
'accelerate',
'click',
'open-clip-torch>=2.0.0,<3.0.0',
'clip-anytorch>=2.5.2',
'coca-pytorch>=0.0.5',
'ema-pytorch>=0.0.7',
'einops>=0.7.0',
'embedding-reader',
'kornia>=0.5.4',
'numpy',
'packaging',
'pillow',
'pydantic>=2',
'pytorch-warmup',
'resize-right>=0.0.2',
'rotary-embedding-torch',
'torch>=1.10',
'torchvision',
'tqdm',
'vector-quantize-pytorch',
'x-clip>=0.4.4',
'webdataset>=0.2.5',
'fsspec>=2022.1.0',
'torchmetrics[image]>=0.8.0'
],
# Classifiers
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\DALLE2-pytorch\train_decoder.py
# Import required modules
from pathlib import Path
from typing import List
from datetime import timedelta
# Import project-specific modules
from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import Tracker
from dalle2_pytorch.train_configs import DecoderConfig, TrainDecoderConfig
from dalle2_pytorch.utils import Timer, print_ribbon
from dalle2_pytorch.dalle2_pytorch import Decoder, resize_image_to
from clip import tokenize
# Import third-party modules
import torchvision
import torch
from torch import nn
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs
from accelerate.utils import dataclasses as accelerate_dataclasses
import webdataset as wds
import click
# Constants
TRAIN_CALC_LOSS_EVERY_ITERS = 10
VALID_CALC_LOSS_EVERY_ITERS = 10
# Helper functions
def exists(val):
return val is not None
# Main functions
def create_dataloaders(
available_shards,
webdataset_base_url,
img_embeddings_url=None,
text_embeddings_url=None,
shard_width=6,
num_workers=4,
batch_size=32,
n_sample_images=6,
shuffle_train=True,
resample_train=False,
img_preproc = None,
index_width=4,
train_prop = 0.75,
val_prop = 0.15,
test_prop = 0.10,
seed = 0,
**kwargs
):
"""
随机将可用的数据分片分为训练、验证和测试集,并为每个集合返回一个数据加载器
"""
# The train, validation, and test proportions must sum to 1
assert train_prop + test_prop + val_prop == 1
# Compute the number of shards in each split
num_train = round(train_prop*len(available_shards))
num_test = round(test_prop*len(available_shards))
num_val = len(available_shards) - num_train - num_test
# Sanity-check that the split sizes add up
assert num_train + num_test + num_val == len(available_shards), f"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}"
# Split the shards randomly with a manually seeded generator so the split is reproducible
train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(seed))
# Format the shard URLs for each split, zero-padding shard numbers to shard_width
train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split]
test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]
# Lambda that builds an image-embedding dataloader for a list of tar URLs
create_dataloader = lambda tar_urls, shuffle=False, resample=False, for_sampling=False: create_image_embedding_dataloader(
tar_url=tar_urls,
num_workers=num_workers,
batch_size=batch_size if not for_sampling else n_sample_images,
img_embeddings_url=img_embeddings_url,
text_embeddings_url=text_embeddings_url,
index_width=index_width,
shuffle_num = None,
extra_keys= ["txt"],
shuffle_shards = shuffle,
resample_shards = resample,
img_preproc=img_preproc,
handler=wds.handlers.warn_and_continue
)
# Build the train, validation, and test dataloaders
train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
val_dataloader = create_dataloader(val_urls, shuffle=False)
test_dataloader = create_dataloader(test_urls, shuffle=False)
test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
# Return the dataloaders as a dictionary
return {
"train": train_dataloader,
"train_sampling": train_sampling_dataloader,
"val": val_dataloader,
"test": test_dataloader,
"test_sampling": test_sampling_dataloader
}
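To make the shard handling concrete, here is a small hypothetical call (the bucket URL, embedding folder, and shard range below are made up for illustration). The `{}` in `webdataset_base_url` is replaced by each zero-padded shard number before the URL is handed to webdataset:

```python
dataloaders = create_dataloaders(
    available_shards = list(range(0, 100)),                 # shards 000000 .. 000099 (hypothetical)
    webdataset_base_url = "s3://my-bucket/images/{}.tar",   # hypothetical shard location
    img_embeddings_url = "s3://my-bucket/img-embeddings/",  # hypothetical precomputed image embeddings
    shard_width = 6,        # shard 7 becomes "000007"
    batch_size = 32,
    num_workers = 4,
    seed = 0,
)
train_dl = dataloaders["train"]            # used for optimization steps
sample_dl = dataloaders["train_sampling"]  # small batches used only for generating example images
```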
def get_dataset_keys(dataloader):
"""
# 有时需要获取数据加载器返回的键。由于数据集被嵌入在数据加载器中,我们需要进行一些处理来恢复它。
"""
# 如果数据加载器实际上是一个WebLoader,则需要提取真正的数据加载器
if isinstance(dataloader, wds.WebLoader):
dataloader = dataloader.pipeline[0]
# 返回数据加载器的数据集键映射
return dataloader.dataset.key_map
# Pull example data from a dataloader and return it as a list of samples
def get_example_data(dataloader, device, n=5):
# Initialize empty accumulators
images = []
img_embeddings = []
text_embeddings = []
captions = []
# Iterate over the dataloader
for img, emb, txt in dataloader:
# Get the image and text embeddings
img_emb, text_emb = emb.get('img'), emb.get('text')
# If image embeddings are present
if img_emb is not None:
# Move the image embeddings to the target device
img_emb = img_emb.to(device=device, dtype=torch.float)
img_embeddings.extend(list(img_emb))
else:
# Otherwise append one None per image in the batch
img_embeddings.extend([None]*img.shape[0])
# If text embeddings are present
if text_emb is not None:
# Move the text embeddings to the target device
text_emb = text_emb.to(device=device, dtype=torch.float)
text_embeddings.extend(list(text_emb))
else:
# Otherwise append one None per image in the batch
text_embeddings.extend([None]*img.shape[0])
# Move the images to the target device
img = img.to(device=device, dtype=torch.float)
images.extend(list(img))
captions.extend(list(txt))
# Stop once enough examples have been collected
if len(images) >= n:
break
# Return the examples zipped together, truncated to n
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
# Generate images from the example embeddings
def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
# Unzip the example data
real_images, img_embeddings, text_embeddings, txts = zip(*example_data)
sample_params = {}
# If no precomputed image embeddings are available
if img_embeddings[0] is None:
# Generate image embeddings from the real images
imgs_tensor = torch.stack(real_images)
assert clip is not None, "clip is None, but img_embeddings is None"
imgs_tensor.to(device=device)
img_embeddings, img_encoding = clip.embed_image(imgs_tensor)
sample_params["image_embed"] = img_embeddings
else:
# Use the precomputed image embeddings
img_embeddings = torch.stack(img_embeddings)
sample_params["image_embed"] = img_embeddings
# If conditioning on text encodings
if condition_on_text_encodings:
# If no precomputed text embeddings are available
if text_embeddings[0] is None:
# Generate text embeddings from the raw captions
assert clip is not None, "clip is None, but text_embeddings is None"
tokenized_texts = tokenize(txts, truncate=True).to(device=device)
text_embed, text_encodings = clip.embed_text(tokenized_texts)
sample_params["text_encodings"] = text_encodings
else:
# Use the precomputed text embeddings
text_embeddings = torch.stack(text_embeddings)
sample_params["text_encodings"] = text_embeddings
sample_params["start_at_unet_number"] = start_unet
sample_params["stop_at_unet_number"] = end_unet
# If only training the upsamplers (start_unet > 1), the ground-truth image must be passed in
if start_unet > 1:
sample_params["image"] = torch.stack(real_images)
if device is not None:
sample_params["_device"] = device
# Generate the samples
samples = trainer.sample(**sample_params, _cast_deepspeed_precision=False) # casting to fp16 is not needed for sampling
generated_images = list(samples)
captions = [text_prepend + txt for txt in txts]
# If matching the real images to the generated image size
if match_image_size:
generated_image_size = generated_images[0].shape[-1]
real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
# Return the real images, generated images, and captions
return real_images, generated_images, captions
# Generate grid samples
def generate_grid_samples(trainer, examples, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
# Generate samples and use torchvision to put them in a side-by-side grid for easy viewing
real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
# Combine each (original, generated) pair into a single grid image with torchvision.utils.make_grid
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
# Return the list of grid images and their captions
return grid_images, captions
def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=None, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
"""
Computes evaluation metrics for the decoder
"""
metrics = {}
# Prepare the data
examples = get_example_data(dataloader, device, n_evaluation_samples)
if len(examples) == 0:
print("No data to evaluate. Check that your dataloader has shards.")
return metrics
real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
# Convert pixel values from [0, 1] to [0, 255] and the dtype from torch.float to torch.uint8
int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
def null_sync(t, *args, **kwargs):
return [t]
if exists(FID):
fid = FrechetInceptionDistance(**FID, dist_sync_fn=null_sync)
fid.to(device=device)
fid.update(int_real_images, real=True)
fid.update(int_generated_images, real=False)
metrics["FID"] = fid.compute().item()
if exists(IS):
inception = InceptionScore(**IS, dist_sync_fn=null_sync)
inception.to(device=device)
inception.update(int_real_images)
is_mean, is_std = inception.compute()
metrics["IS_mean"] = is_mean.item()
metrics["IS_std"] = is_std.item()
if exists(KID):
kernel_inception = KernelInceptionDistance(**KID, dist_sync_fn=null_sync)
kernel_inception.to(device=device)
kernel_inception.update(int_real_images, real=True)
kernel_inception.update(int_generated_images, real=False)
kid_mean, kid_std = kernel_inception.compute()
metrics["KID_mean"] = kid_mean.item()
metrics["KID_std"] = kid_std.item()
if exists(LPIPS):
# Convert pixel values from [0, 1] to [-1, 1]
renorm_real_images = real_images.mul(2).sub(1).clamp(-1,1)
renorm_generated_images = generated_images.mul(2).sub(1).clamp(-1,1)
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
lpips.to(device=device)
lpips.update(renorm_real_images, renorm_generated_images)
metrics["LPIPS"] = lpips.compute().item()
if trainer.accelerator.num_processes > 1:
# Synchronize the metrics across processes
metrics_order = sorted(metrics.keys())
metrics_tensor = torch.zeros(1, len(metrics), device=device, dtype=torch.float)
for i, metric_name in enumerate(metrics_order):
metrics_tensor[0, i] = metrics[metric_name]
metrics_tensor = trainer.accelerator.gather(metrics_tensor)
metrics_tensor = metrics_tensor.mean(dim=0)
for i, metric_name in enumerate(metrics_order):
metrics[metric_name] = metrics_tensor[i].item()
return metrics
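The `FID`, `IS`, `KID`, and `LPIPS` arguments above are dictionaries of keyword arguments that are splatted directly into the corresponding torchmetrics constructors. A minimal single-process sketch of what the FID branch does, assuming `torchmetrics[image]` (with torch-fidelity) is installed; random uint8 images stand in for the real and generated batches:

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

fid = FrechetInceptionDistance(feature=64)   # feature=64 keeps the Inception feature map small for a quick check
real = torch.randint(0, 256, (16, 3, 64, 64), dtype=torch.uint8)
fake = torch.randint(0, 256, (16, 3, 64, 64), dtype=torch.uint8)
fid.update(real, real=True)                  # accumulate statistics for the real distribution
fid.update(fake, real=False)                 # accumulate statistics for the generated distribution
print(fid.compute().item())                  # single scalar FID value
```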
def save_trainer(tracker: Tracker, trainer: DecoderTrainer, epoch: int, sample: int, next_task: str, validation_losses: List[float], samples_seen: int, is_latest=True, is_best=False):
"""
Logs the model with an appropriate method depending on the tracker
"""
tracker.save(trainer, is_best=is_best, is_latest=is_latest, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses, samples_seen=samples_seen)
def recall_trainer(tracker: Tracker, trainer: DecoderTrainer):
"""
Loads the model with an appropriate method depending on the tracker
"""
trainer.accelerator.print(print_ribbon(f"Loading model from {type(tracker.loader).__name__}"))
state_dict = tracker.recall()
trainer.load_state_dict(state_dict, only_model=False, strict=True)
# Return epoch, validation_losses, next_task, sample, and samples_seen from the state dict,
# falling back to 0, [], 'train', 0, and 0 respectively when a key is missing
return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0), state_dict.get("samples_seen", 0)
# Training function for the decoder
def train(
dataloaders, # dataloaders dictionary
decoder: Decoder, # decoder model
accelerator: Accelerator, # accelerator
tracker: Tracker, # tracker
inference_device, # device used for inference and sampling
clip=None, # CLIP model, used to compute embeddings when precomputed ones are absent
evaluate_config=None, # evaluation configuration
epoch_samples = None, # samples per epoch
validation_samples = None, # number of validation samples
save_immediately=False, # whether to save immediately
epochs = 20, # number of training epochs
n_sample_images = 5, # number of sample images
save_every_n_samples = 100000, # save every N samples
unet_training_mask=None, # mask selecting which unets to train
condition_on_text_encodings=False, # whether to condition on text encodings
cond_scale=1.0, # conditioning scale
**kwargs # remaining trainer kwargs
):
"""
Trains a decoder on a dataset.
"""
is_master = accelerator.process_index == 0
if not exists(unet_training_mask):
# If no unet training mask is given, default to training every unet
unet_training_mask = [True] * len(decoder.unets)
assert len(unet_training_mask) == len(decoder.unets), f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
trainable_unet_numbers = [i+1 for i, trainable in enumerate(unet_training_mask) if trainable]
first_trainable_unet = trainable_unet_numbers[0]
last_trainable_unet = trainable_unet_numbers[-1]
def move_unets(unet_training_mask):
for i in range(len(decoder.unets)):
if not unet_training_mask[i]:
# Replace non-trainable unets with nn.Identity(). This training script never uses the untrained unets, so this is fine.
decoder.unets[i] = nn.Identity().to(inference_device)
# Swap out the unets that are not being trained
move_unets(unet_training_mask)
trainer = DecoderTrainer(
decoder=decoder,
accelerator=accelerator,
dataloaders=dataloaders,
**kwargs
)
# Set up the starting model and training state, possibly recalled from the tracker
start_epoch = 0
validation_losses = []
next_task = 'train'
sample = 0
samples_seen = 0
val_sample = 0
step = lambda: int(trainer.num_steps_taken(unet_number=first_trainable_unet))
if tracker.can_recall:
start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
if next_task == 'train':
sample = recalled_sample
if next_task == 'val':
val_sample = recalled_sample
accelerator.print(f"Loaded model from {type(tracker.loader).__name__} on epoch {start_epoch} having seen {samples_seen} samples with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
trainer.to(device=inference_device)
accelerator.print(print_ribbon("Generating Example Data", repeat=40))
accelerator.print("This can take a while to load the shard lists...")
if is_master:
train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
accelerator.print("Generated training examples")
test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)
accelerator.print("Generated testing examples")
send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)
unet_losses_tensor = torch.zeros(TRAIN_CALC_LOSS_EVERY_ITERS, trainer.num_unets, dtype=torch.float, device=inference_device)
# Wait for all nodes to reach this point so they do not try to auto-resume the current run at different times, which makes no sense and causes errors
accelerator.wait_for_everyone()
# Create the tracker object from the given config (this code belongs to the create_tracker helper used in initialize_training below)
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
# Save the config as 'decoder_config.json' under the configured path
tracker.save_config(config_path, config_name='decoder_config.json')
# Attach the dumped model config as save metadata under the 'config' key
tracker.add_save_metadata(state_dict_key='config', metadata=config.model_dump())
# Return the tracker object
return tracker
def initialize_training(config: TrainDecoderConfig, config_path):
# Make sure distributed models are initialized to the same values when not loading from a checkpoint
torch.manual_seed(config.seed)
# Set up the accelerator for configurable distributed training
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters, static_graph=config.train.static_graph)
init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])
if accelerator.num_processes > 1:
# Distributed training is in use; make sure all processes can connect right away
accelerator.print("Waiting for all processes to connect...")
accelerator.wait_for_everyone()
accelerator.print("All processes online and connected")
# In DeepSpeed fp16 mode, learned variance must be turned off
if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
raise ValueError("DeepSpeed fp16 mode does not support learned variance")
# Set up the data
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
world_size = accelerator.num_processes
rank = accelerator.process_index
shards_per_process = len(all_shards) // world_size
assert shards_per_process > 0, "Not enough shards to split evenly"
my_shards = all_shards[rank * shards_per_process: (rank + 1) * shards_per_process]
dataloaders = create_dataloaders (
available_shards=my_shards,
img_preproc = config.data.img_preproc,
train_prop = config.data.splits.train,
val_prop = config.data.splits.val,
test_prop = config.data.splits.test,
n_sample_images=config.train.n_sample_images,
**config.data.model_dump(),
rank = rank,
seed = config.seed,
)
# If the decoder config contains clip, it has to be removed for compatibility with deepspeed
clip = None
if config.decoder.clip is not None:
clip = config.decoder.clip.create() # of course we keep it around for use during training, it just cannot stay inside the decoder
config.decoder.clip = None
# Create the decoder model and print basic information
decoder = config.decoder.create()
get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
# Create and initialize the tracker (a dummy tracker on non-master ranks)
tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)
has_img_embeddings = config.data.img_embeddings_url is not None
has_text_embeddings = config.data.text_embeddings_url is not None
conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])
has_clip_model = clip is not None
data_source_string = ""
if has_img_embeddings:
data_source_string += "precomputed image embeddings"
elif has_clip_model:
data_source_string += "clip image embeddings generation"
else:
raise ValueError("No image embeddings source specified")
if conditioning_on_text:
if has_text_embeddings:
data_source_string += " and precomputed text embeddings"
elif has_clip_model:
data_source_string += " and clip text encoding generation"
else:
raise ValueError("No text embeddings source specified")
accelerator.print(print_ribbon("Loaded Config", repeat=40))
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
# Print the decoder's parameter counts: total and trainable
accelerator.print(f"Number of parameters: {get_num_parameters(decoder)} total; {get_num_parameters(decoder, only_training=True)} training")
# Iterate over each unet in the decoder
for i, unet in enumerate(decoder.unets):
# Print each unet's parameter counts: total and trainable
accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")
# Call the training function with the dataloaders, decoder, accelerator, and remaining config
train(dataloaders, decoder, accelerator,
clip=clip,
tracker=tracker,
inference_device=accelerator.device,
evaluate_config=config.evaluate,
condition_on_text_encodings=conditioning_on_text,
**config.train.model_dump(),
)
# A simple click command-line interface that loads the config and starts training
@click.command()
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
def main(config_file):
# Convert the config file path into a Path object
config_file_path = Path(config_file)
# Load the training config from the JSON file
config = TrainDecoderConfig.from_json_path(str(config_file_path))
# Initialize training with the config and the config file path
initialize_training(config, config_path=config_file_path)
if __name__ == "__main__":
# Call main when run as a script
main()
.\lucidrains\DALLE2-pytorch\train_diffusion_prior.py
import click # click, used for the command-line interface
import torch # PyTorch
from torch import nn # nn module from PyTorch
from typing import List # List type hint
from accelerate import Accelerator # Accelerator class from accelerate
from accelerate.utils import set_seed # set_seed helper from accelerate
from torch.utils.data import DataLoader # DataLoader class from PyTorch
from embedding_reader import EmbeddingReader # EmbeddingReader, for reading precomputed embeddings
from accelerate.utils import dataclasses as accelerate_dataclasses # accelerate dataclasses (for DistributedType)
from dalle2_pytorch.utils import Timer # Timer utility from dalle2_pytorch.utils
from dalle2_pytorch.trackers import Tracker # Tracker class from dalle2_pytorch.trackers
from dalle2_pytorch import DiffusionPriorTrainer # DiffusionPriorTrainer
from dalle2_pytorch.dataloaders import get_reader, make_splits # reader and split helpers from dalle2_pytorch.dataloaders
from dalle2_pytorch.train_configs import ( # training configs for the diffusion prior
DiffusionPriorConfig,
DiffusionPriorTrainConfig,
TrainDiffusionPriorConfig,
)
# helpers
cos = nn.CosineSimilarity(dim=1, eps=1e-6) # cosine-similarity helper
def exists(val):
return val is not None # True when the value is not None
def all_between(values: list, lower_bound, upper_bound):
for value in values:
if value < lower_bound or value > upper_bound:
return False
return True
def make_model(
prior_config: DiffusionPriorConfig,
train_config: DiffusionPriorTrainConfig,
device: str = None,
accelerator: Accelerator = None,
):
# Create the diffusion prior model from the config
diffusion_prior = prior_config.create()
# Instantiate the trainer
trainer = DiffusionPriorTrainer(
diffusion_prior=diffusion_prior,
lr=train_config.lr,
wd=train_config.wd,
max_grad_norm=train_config.max_grad_norm,
amp=train_config.amp,
use_ema=train_config.use_ema,
device=device,
accelerator=accelerator,
warmup_steps=train_config.warmup_steps,
)
return trainer
def create_tracker(
accelerator: Accelerator,
config: TrainDiffusionPriorConfig,
config_path: str,
dummy: bool = False,
) -> Tracker:
tracker_config = config.tracker
accelerator_config = {
"Distributed": accelerator.distributed_type
!= accelerate_dataclasses.DistributedType.NO,
"DistributedType": accelerator.distributed_type,
"NumProcesses": accelerator.num_processes,
"MixedPrecision": accelerator.mixed_precision,
}
tracker: Tracker = tracker_config.create(
config, accelerator_config, dummy_mode=dummy
)
tracker.save_config(config_path, config_name="prior_config.json")
return tracker
def pad_gather_reduce(trainer: DiffusionPriorTrainer, x, method="mean"):
"""
pad a value or tensor across all processes and gather
params:
- trainer: a trainer that carries an accelerator object
- x: a number or torch tensor to reduce
- method: "mean", "sum", "max", "min"
return:
- the average tensor after masking out 0's
- None if the gather resulted in an empty tensor
"""
assert method in [
"mean",
"sum",
"max",
"min",
], "This function has limited capabilities [sum, mean, max, min]"
assert x is not None, "Cannot reduce a None type object"
# Wait for all processes to reach this point before gathering
if type(x) is not torch.Tensor:
x = torch.tensor([x])
# Make sure the tensor is on the correct device
x = x.to(trainer.device)
# Pad across processes
padded_x = trainer.accelerator.pad_across_processes(x, dim=0)
# Gather from all processes
gathered_x = trainer.accelerator.gather(padded_x)
# Mask out the zero padding
masked_x = gathered_x[gathered_x != 0]
# If the tensor is empty after masking, warn and return None
if len(masked_x) == 0:
click.secho(
f"The call to this method resulted in an empty tensor after masking out zeros. The gathered tensor was this: {gathered_x} and the original value passed was: {x}.",
fg="red",
)
return None
if method == "mean":
return torch.mean(masked_x)
elif method == "sum":
return torch.sum(masked_x)
elif method == "max":
return torch.max(masked_x)
elif method == "min":
return torch.min(masked_x)
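The core of `pad_gather_reduce` is: pad to a common length across processes, gather, drop the zero padding, then reduce. The masking step can be illustrated without any distributed setup (note that it also drops genuine zeros, which is why an all-zero gather triggers the warning above):

```python
import torch

gathered_x = torch.tensor([0.5, 0.0, 1.5, 0.0])  # pretend result of gathering a padded tensor from two processes
masked_x = gathered_x[gathered_x != 0]           # drop the zero padding (and any genuine zeros)
print(masked_x.mean().item())                     # 1.0 -- the "mean" reduction over the non-zero entries
```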
def save_trainer(
tracker: Tracker, # tracker used to save the checkpoint
trainer: DiffusionPriorTrainer, # the DiffusionPriorTrainer being saved
is_latest: bool, # whether this checkpoint is the latest one
is_best: bool, # whether this checkpoint is the best so far
epoch: int, # current training epoch
samples_seen: int, # number of samples processed so far
best_validation_loss: float, # best validation loss observed
):
"""
Logs the model with an appropriate method depending on the tracker
"""
# Wait for all processes to finish
trainer.accelerator.wait_for_everyone()
# If this is the main process
if trainer.accelerator.is_main_process:
# Print which model is being saved (best / latest)
click.secho(
f"RANK:{trainer.accelerator.process_index} | Saving Model | Best={is_best} | Latest={is_latest}",
fg="magenta",
)
# Save the model
tracker.save(
trainer=trainer,
is_best=is_best,
is_latest=is_latest,
epoch=int(epoch),
samples_seen=int(samples_seen),
best_validation_loss=best_validation_loss,
)
# Recall (restore) the trainer state from the tracker
def recall_trainer(tracker: Tracker, trainer: DiffusionPriorTrainer):
# If this is the main process
if trainer.accelerator.is_main_process:
# Print where the model is being loaded from
click.secho(f"Loading model from {type(tracker.loader).__name__}", fg="yellow")
# Recall the state dict from the tracker
state_dict = tracker.recall()
# Load the state dict into the trainer
trainer.load(state_dict, strict=True)
return (
int(state_dict.get("epoch", 0)),
state_dict.get("best_validation_loss", 0),
int(state_dict.get("samples_seen", 0)),
)
# Evaluation functions
# Report the loss on a validation split
def report_validation_loss(trainer: DiffusionPriorTrainer, dataloader: DataLoader, text_conditioned: bool, use_ema: bool, tracker: Tracker, split: str, tracker_folder: str, loss_type: str):
# If this is the main process
if trainer.accelerator.is_main_process:
# Announce the performance measurement
click.secho(
f"Measuring performance on {use_ema}-{split} split",
fg="green",
blink=True,
)
# Initialize the running loss
total_loss = torch.zeros(1, dtype=torch.float, device=trainer.device)
# Iterate over the dataloader
for image_embeddings, text_data in dataloader:
image_embeddings = image_embeddings.to(trainer.device)
text_data = text_data.to(trainer.device)
input_args = dict(image_embed=image_embeddings)
if text_conditioned:
input_args = dict(**input_args, text=text_data)
else:
input_args = dict(**input_args, text_embed=text_data)
if use_ema:
loss = trainer.ema_diffusion_prior(**input_args)
else:
loss = trainer(**input_args)
total_loss += loss
# Average the loss across all processes
avg_loss = pad_gather_reduce(trainer, total_loss, method="mean")
stats = {f"{tracker_folder}/{loss_type}-loss": avg_loss}
# Log the result (from the main process)
tracker.log(stats, step=trainer.step.item() + 1)
return avg_loss
# Report cosine similarities
def report_cosine_sims(trainer: DiffusionPriorTrainer, dataloader: DataLoader, text_conditioned: bool, tracker: Tracker, split: str, timesteps: int, tracker_folder: str):
# Put the trainer in eval mode
trainer.eval()
# If this is the main process
if trainer.accelerator.is_main_process:
# Announce the cosine-similarity measurement
click.secho(
f"Measuring Cosine-Similarity on {split} split with {timesteps} timesteps",
fg="green",
blink=True,
)
# Iterate over the dataloader, fetching test image embeddings and text data
for test_image_embeddings, text_data in dataloader:
# Move the test image embeddings and text data to the trainer's device
test_image_embeddings = test_image_embeddings.to(trainer.device)
text_data = text_data.to(trainer.device)
# If text-conditioned, produce embeddings from the tokenized text
if text_conditioned:
text_embedding, text_encodings = trainer.embed_text(text_data)
text_cond = dict(text_embed=text_embedding, text_encodings=text_encodings)
else:
text_embedding = text_data
text_cond = dict(text_embed=text_embedding)
# Copy the text embeddings so they can be shuffled
text_embed_shuffled = text_embedding.clone()
# Roll the batch to simulate "unrelated" captions
rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
text_embed_shuffled = text_embed_shuffled[rolled_idx]
text_embed_shuffled = text_embed_shuffled / text_embed_shuffled.norm(
dim=1, keepdim=True
)
if text_conditioned:
text_encodings_shuffled = text_encodings[rolled_idx]
else:
text_encodings_shuffled = None
text_cond_shuffled = dict(
text_embed=text_embed_shuffled, text_encodings=text_encodings_shuffled
)
# Prepare (normalize) the text embeddings
text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
# Prepare (normalize) the image embeddings
test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(
dim=1, keepdim=True
)
# Predict image embeddings from the unshuffled text
predicted_image_embeddings = trainer.p_sample_loop(
test_image_embeddings.shape,
text_cond,
timesteps=timesteps,
)
predicted_image_embeddings = (
predicted_image_embeddings
/ predicted_image_embeddings.norm(dim=1, keepdim=True)
)
# Predict image embeddings from the shuffled text
predicted_unrelated_embeddings = trainer.p_sample_loop(
test_image_embeddings.shape,
text_cond_shuffled,
timesteps=timesteps,
)
predicted_unrelated_embeddings = (
predicted_unrelated_embeddings
/ predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
)
# Compute the similarities
orig_sim = pad_gather_reduce(
trainer, cos(text_embed, test_image_embeddings), method="mean"
)
pred_sim = pad_gather_reduce(
trainer, cos(text_embed, predicted_image_embeddings), method="mean"
)
unrel_sim = pad_gather_reduce(
trainer, cos(text_embed, predicted_unrelated_embeddings), method="mean"
)
pred_img_sim = pad_gather_reduce(
trainer,
cos(test_image_embeddings, predicted_image_embeddings),
method="mean",
)
# Collect the statistics
stats = {
f"{tracker_folder}/baseline similarity [steps={timesteps}]": orig_sim,
f"{tracker_folder}/similarity with text [steps={timesteps}]": pred_sim,
f"{tracker_folder}/similarity with original image [steps={timesteps}]": pred_img_sim,
f"{tracker_folder}/similarity with unrelated caption [steps={timesteps}]": unrel_sim,
f"{tracker_folder}/difference from baseline similarity [steps={timesteps}]": pred_sim
- orig_sim,
}
# Log the statistics
tracker.log(stats, step=trainer.step.item() + 1)
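The "unrelated caption" baseline above is produced by rolling the batch by one position, so every image embedding gets paired with some other sample's caption embedding. A tiny standalone sketch of the rolling trick:

```python
import torch

text_embedding = torch.arange(4 * 2, dtype=torch.float).reshape(4, 2)  # 4 caption embeddings of dim 2
rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)       # tensor([3, 0, 1, 2])
print(text_embedding[rolled_idx])  # each row is now the embedding of a *different* sample's caption
```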
# Run evaluation on a model, track the metrics, and return the loss if requested
def eval_model(
trainer: DiffusionPriorTrainer, # trainer object
dataloader: DataLoader, # dataloader object
text_conditioned: bool, # whether the prior is text-conditioned
split: str, # dataset split name
tracker: Tracker, # tracker object
use_ema: bool, # whether to use the EMA model
report_cosine: bool, # whether to report cosine similarity
report_loss: bool, # whether to report the loss
timesteps: List[int], # list of timesteps to evaluate at
loss_type: str = None, # loss type, defaults to None
):
"""
Run evaluation on a model and track metrics
returns: loss if requested
"""
# Put the model in eval mode
trainer.eval()
# Record whether the EMA or the online model is being evaluated
use_ema = "ema" if use_ema else "online"
# Folder under which the metrics are tracked
tracker_folder = f"metrics/{use_ema}-{split}"
# Validate the requested timesteps
min_timesteps = trainer.accelerator.unwrap_model(
trainer.diffusion_prior
).sample_timesteps
max_timesteps = trainer.accelerator.unwrap_model(
trainer.diffusion_prior
).noise_scheduler.num_timesteps
assert all_between(
timesteps, lower_bound=min_timesteps, upper_bound=max_timesteps
), f"all timesteps values must be between {min_timesteps} and {max_timesteps}: got {timesteps}"
# If requested, measure cosine-similarity metrics at the different timestep settings
if report_cosine:
for timestep in timesteps:
report_cosine_sims(
trainer,
dataloader=dataloader,
text_conditioned=text_conditioned,
tracker=tracker,
split=split,
timesteps=timestep,
tracker_folder=tracker_folder,
)
# If requested, measure the loss on a separate split of the data
if report_loss:
# Report the loss on the given split
loss = report_validation_loss(
trainer=trainer,
dataloader=dataloader,
text_conditioned=text_conditioned,
use_ema=use_ema,
tracker=tracker,
split=split,
tracker_folder=tracker_folder,
loss_type=loss_type,
)
return loss
# Training script
# Training function
def train(
trainer: DiffusionPriorTrainer, # trainer object
tracker: Tracker, # tracker object
train_loader: DataLoader, # training dataloader
eval_loader: DataLoader, # evaluation dataloader
test_loader: DataLoader, # test dataloader
config: DiffusionPriorTrainConfig, # training config object
):
# Initialize timers
save_timer = Timer() # checkpoint-save timer
samples_timer = Timer() # sample-rate timer
validation_profiler = Timer() # validation-time profiler
validation_countdown = Timer() # validation countdown timer
# Track the best validation loss
best_validation_loss = config.train.best_validation_loss
samples_seen = config.train.num_samples_seen
# Starting epoch for training
start_epoch = config.train.current_epoch
# Evaluate on the test data
if trainer.accelerator.is_main_process:
click.secho(f"Starting Test", fg="red")
# Save the latest model one last time before starting the test evaluation
save_trainer(
tracker=tracker,
trainer=trainer,
is_best=False,
is_latest=True,
samples_seen=samples_seen,
epoch=epoch,
best_validation_loss=best_validation_loss,
)
# Evaluate the model on the test data
test_loss = eval_model(
trainer=trainer,
dataloader=test_loader,
text_conditioned=config.prior.condition_on_text_encodings,
split="test",
tracker=tracker,
use_ema=True,
report_cosine=False,
report_loss=True,
timesteps=config.train.eval_timesteps,
loss_type=config.prior.loss_type,
)
# If the test loss beats the best validation loss, update it and save the model
if test_loss < best_validation_loss:
best_validation_loss = test_loss
# Save the best model
save_trainer(
trainer=trainer,
tracker=tracker,
is_best=True,
is_latest=False,
samples_seen=samples_seen,
epoch=epoch,
best_validation_loss=test_loss,
)
# Initialize training
def initialize_training(config_file, accelerator):
"""
Parse the configuration file, and prepare everything necessary for training
"""
# Load the configuration file
if accelerator.is_main_process:
click.secho(f"Loading configuration from {config_file}", fg="green")
# Load the training config from the JSON path
config = TrainDiffusionPriorConfig.from_json_path(config_file)
# Set the random seed
set_seed(config.train.random_seed)
# Get the device
device = accelerator.device
# Create the trainer (it will automatically distribute if possible and configured)
trainer: DiffusionPriorTrainer = make_model(
config.prior, config.train, device, accelerator
).to(device)
# Create a tracker
tracker = create_tracker(
accelerator, config, config_file, dummy=accelerator.process_index != 0
)
# Reload from a checkpoint if one can be recalled
if tracker.can_recall:
current_epoch, best_validation_loss, samples_seen = recall_trainer(
tracker=tracker, trainer=trainer
)
# Display the recalled values
if trainer.accelerator.is_main_process:
click.secho(f"Current Epoch: {current_epoch} | Best Val Loss: {best_validation_loss} | Samples Seen: {samples_seen}", fg="yellow")
# Update the config to reflect the recalled values
config.train.num_samples_seen = samples_seen
config.train.current_epoch = current_epoch
config.train.best_validation_loss = best_validation_loss
# Fetch and prepare the data
if trainer.accelerator.is_main_process:
click.secho("Grabbing data...", fg="blue", blink=True)
trainer.accelerator.wait_for_everyone()
img_reader = get_reader(
text_conditioned=trainer.text_conditioned,
img_url=config.data.image_url,
meta_url=config.data.meta_url,
)
# Compute the starting point within the epoch
trainer.accelerator.wait_for_everyone()
train_loader, eval_loader, test_loader = make_splits(
text_conditioned=trainer.text_conditioned,
batch_size=config.data.batch_size,
num_data_points=config.data.num_data_points,
train_split=config.data.splits.train,
eval_split=config.data.splits.val,
image_reader=img_reader,
rank=accelerator.state.process_index,
world_size=accelerator.state.num_processes,
start=0,
)
# Update the start point so a resumed run finishes its epoch
if tracker.can_recall:
samples_seen = config.train.num_samples_seen
length = (
config.data.num_data_points
if samples_seen <= img_reader.count
else img_reader.count
)
scaled_samples = length * config.train.current_epoch
start_point = (
scaled_samples - samples_seen if scaled_samples > samples_seen else samples_seen
)
if trainer.accelerator.is_main_process:
click.secho(f"Resuming at sample: {start_point}", fg="yellow")
train_loader.dataset.set_start(start_point)
# Start training
if trainer.accelerator.is_main_process:
click.secho(
f"Beginning Prior Training : Distributed={accelerator.state.distributed_type != accelerate_dataclasses.DistributedType.NO}",
fg="yellow",
)
train(
trainer=trainer,
tracker=tracker,
train_loader=train_loader,
eval_loader=eval_loader,
test_loader=test_loader,
config=config,
)
# Create a command-line interface
@click.command()
# Command-line option for the config file, defaulting to "configs/train_prior_config.example.json"
@click.option("--config_file", default="configs/train_prior_config.example.json")
def main(config_file):
# Initialize the accelerator
accelerator = Accelerator()
# Set up the training environment
initialize_training(config_file, accelerator)
# Call main when run as a script
if __name__ == "__main__":
main()
.\lucidrains\ddpm-ipa-protein-generation\ddpm_ipa_protein_generation\ddpm_ipa_protein_generation.py
import torch
from torch import nn
# gaussian diffusion with continuous time helper functions and classes
# large part of this was thanks to @crowsonkb at https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py
# beta_linear_log_snr, compiled with torch.jit.script
@torch.jit.script
def beta_linear_log_snr(t):
return -torch.log(expm1(1e-4 + 10 * (t ** 2)))
# alpha_cosine_log_snr, compiled with torch.jit.script
@torch.jit.script
def alpha_cosine_log_snr(t, s: float = 0.008):
return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) # not sure if this accounts for beta being clipped to 0.999 in discrete version
# Convert log(SNR) into alpha and sigma
def log_snr_to_alpha_sigma(log_snr):
return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr))
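A quick numeric check of this parameterization (assuming the function just defined is in scope): for any log-SNR value, `alpha ** 2 + sigma ** 2 == 1`, because `sigmoid(x) + sigmoid(-x) = 1`. The log-SNR expression below mirrors `beta_linear_log_snr`, using `torch.special.expm1` directly:

```python
import torch

t = torch.linspace(0., 1., 5)
log_snr = -torch.log(torch.special.expm1(1e-4 + 10 * t ** 2))   # same expression as beta_linear_log_snr
alpha, sigma = log_snr_to_alpha_sigma(log_snr)
print(alpha ** 2 + sigma ** 2)   # all ones, up to floating point error
```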
# Continuous-time Gaussian diffusion module
class Diffusion(nn.Module):
def __init__(self, *, noise_schedule, timesteps = 1000):
super().__init__()
# Choose the log(SNR) function based on the noise schedule
if noise_schedule == "linear":
self.log_snr = beta_linear_log_snr
elif noise_schedule == "cosine":
self.log_snr = alpha_cosine_log_snr
else:
raise ValueError(f'invalid noise schedule {noise_schedule}')
self.num_timesteps = timesteps
# Get a batch of identical times at a given noise level
def get_times(self, batch_size, noise_level, *, device):
return torch.full((batch_size,), noise_level, device = device, dtype = torch.float32)
# Sample random times
def sample_random_times(self, batch_size, max_thres = 0.999, *, device):
return torch.zeros((batch_size,), device = device).float().uniform_(0, max_thres)
# Get the conditioning signal (log SNR) for the given times
def get_condition(self, times):
return maybe(self.log_snr)(times)
# Get the (time, next time) pairs used during sampling
def get_sampling_timesteps(self, batch, *, device):
times = torch.linspace(1., 0., self.num_timesteps + 1, device = device)
times = repeat(times, 't -> b t', b = batch)
times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
times = times.unbind(dim = -1)
return times
# Compute the posterior q(x_{t-1} | x_t, x_0)
def q_posterior(self, x_start, x_t, t, *, t_next = None):
t_next = default(t_next, lambda: (t - 1. / self.num_timesteps).clamp(min = 0.))
""" https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material """
log_snr = self.log_snr(t)
log_snr_next = self.log_snr(t_next)
log_snr, log_snr_next = map(partial(right_pad_dims_to, x_t), (log_snr, log_snr_next))
alpha, sigma = log_snr_to_alpha_sigma(log_snr)
alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)
# c - as defined near eq 33
c = -expm1(log_snr - log_snr_next)
posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)
# following (eq. 33)
posterior_variance = (sigma_next ** 2) * c
posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
# Forward diffusion: sample q(x_t | x_0)
def q_sample(self, x_start, t, noise = None):
if isinstance(t, float):
batch = x_start.shape[0]
t = torch.full((batch,), t, device = x_start.device, dtype = x_start.dtype)
noise = default(noise, lambda: torch.randn_like(x_start))
log_snr = self.log_snr(t)
log_snr_padded_dim = right_pad_dims_to(x_start, log_snr)
alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim)
return alpha * x_start + sigma * noise, log_snr
# Diffuse a sample x_from from time from_t forward to time to_t
def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
# Shape, device, and dtype of x_from
shape, device, dtype = x_from.shape, x_from.device, x_from.dtype
batch = shape[0]
# If from_t is a float, expand it to a batch-sized tensor
if isinstance(from_t, float):
from_t = torch.full((batch,), from_t, device = device, dtype = dtype)
# If to_t is a float, expand it to a batch-sized tensor
if isinstance(to_t, float):
to_t = torch.full((batch,), to_t, device = device, dtype = dtype)
# If no noise is given, sample Gaussian noise shaped like x_from
noise = default(noise, lambda: torch.randn_like(x_from))
# log(SNR) at the source time from_t, right-padded to x_from's dimensionality
log_snr = self.log_snr(from_t)
log_snr_padded_dim = right_pad_dims_to(x_from, log_snr)
alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim)
# log(SNR) at the target time to_t, right-padded to x_from's dimensionality
log_snr_to = self.log_snr(to_t)
log_snr_padded_dim_to = right_pad_dims_to(x_from, log_snr_to)
alpha_to, sigma_to = log_snr_to_alpha_sigma(log_snr_padded_dim_to)
# Re-noise from time from_t to time to_t and return
return x_from * (alpha_to / alpha) + noise * (sigma_to * alpha - sigma * alpha_to) / alpha
# Predict x_0 from x_t and the predicted noise
def predict_start_from_noise(self, x_t, t, noise):
# log(SNR) at time t, right-padded to x_t's dimensionality
log_snr = self.log_snr(t)
log_snr = right_pad_dims_to(x_t, log_snr)
alpha, sigma = log_snr_to_alpha_sigma(log_snr)
# Invert the forward process to recover the predicted x_0
return (x_t - sigma * noise) / alpha.clamp(min = 1e-8)
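A self-contained sketch (not using the class above, which depends on helpers defined elsewhere in the file) of the round trip between `q_sample` and `predict_start_from_noise`: diffusing `x_start` with a known noise and then inverting with the same alpha and sigma recovers `x_start` exactly:

```python
import torch

def log_snr_to_alpha_sigma(log_snr):
    return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr))

x_start = torch.randn(2, 8)
noise = torch.randn_like(x_start)
alpha, sigma = log_snr_to_alpha_sigma(torch.tensor(1.5))   # an arbitrary log-SNR level

x_t = alpha * x_start + sigma * noise                       # what q_sample computes
recovered = (x_t - sigma * noise) / alpha                    # what predict_start_from_noise computes, given the true noise
print(torch.allclose(recovered, x_start, atol=1e-5))         # True
```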
.\lucidrains\ddpm-ipa-protein-generation\ddpm_ipa_protein_generation\__init__.py
# Define a function named calculate_area that computes the area of a rectangle
def calculate_area(length, width):
# Compute the rectangle's area
area = length * width
# Return the computed area
return area
DDPM + IPA for Protein Structure and Sequence Generation (wip)
Implementation of the DDPM + IPA (invariant point attention) for protein generation, as outlined in the paper Protein Structure and Sequence Generation with Equivariant Denoising Diffusion Probabilistic Models. They basically combined the invariant point attention module from Alphafold2 (used for coordinate refinement) with a standard DDPM, and demonstrate very cool infilling capabilities for protein structure generation.
I will also equip this with ability to condition on encoded text, identical to Imagen. Eventually, I will also try to offer a version using Insertion-deletion DDPM (but I have yet to replicate this work and open source it)
Citations
@misc{https://doi.org/10.48550/arxiv.2205.15019,
doi = {10.48550/ARXIV.2205.15019},
url = {https://arxiv.org/abs/2205.15019},
author = {Anand, Namrata and Achim, Tudor},
keywords = {Quantitative Methods (q-bio.QM), Artificial Intelligence (cs.AI), FOS: Biological sciences, FOS: Biological sciences, FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {Protein Structure and Sequence Generation with Equivariant Denoising Diffusion Probabilistic Models},
publisher = {arXiv},
year = {2022},
copyright = {arXiv.org perpetual, non-exclusive license}
}
.\lucidrains\ddpm-ipa-protein-generation\setup.py
# Import setup and package-discovery helpers
from setuptools import setup, find_packages
# Configure the package metadata
setup(
name = 'ddpm-ipa-protein-generation', # package name
packages = find_packages(exclude=[]), # find all packages
version = '0.0.1', # version number
license='MIT', # license
description = 'DDPM + Invariant Point Attention - Protein Generation', # description
author = 'Phil Wang', # author
author_email = 'lucidrains@gmail.com', # author email
long_description_content_type = 'text/markdown', # content type of the long description
url = 'https://github.com/lucidrains/ddpm-ipa-protein-generation', # URL
keywords = [ # keyword list
'artificial intelligence',
'deep learning',
'attention mechanism',
'geometric deep learning',
'denoising diffusion probabilistic models'
],
install_requires=[ # installation dependencies
'invariant-point-attention>=0.2.1',
'einops>=0.4',
'torch>=1.6',
],
classifiers=[ # classifiers
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\ddpm-proteins\cache.py
# Import the necessary libraries
from tqdm import tqdm
import sidechainnet as scn
from ddpm_proteins.utils import get_msa_attention_embeddings, get_msa_transformer
# Load the sidechainnet dataset
data = scn.load(
casp_version = 12,
thinning = 30,
with_pytorch = 'dataloaders',
batch_size = 1,
dynamic_batching = False
)
# Constants
LENGTH_THRES = 256
# Function for fetching MSAs; fill this in to match your setup
def fetch_msas_fn(aa_str):
"""
给定一个氨基酸序列作为字符串
填写一个返回 MSAs(作为字符串列表)的函数
(默认情况下,它将返回空列表,只有主要序列被馈送到 MSA Transformer)
"""
return []
# Caching loop
# Get the MSA Transformer model and batch converter
model, batch_converter = get_msa_transformer()
# Iterate over batches in the training set
for batch in tqdm(data['train']):
# Skip batches whose sequences exceed the length threshold
if batch.seqs.shape[1] > LENGTH_THRES:
continue
# Get the protein ids and sequences for the batch
pids = batch.pids
seqs = batch.seqs.argmax(dim = -1)
# Compute (and thereby cache) the MSA attention embeddings
_ = get_msa_attention_embeddings(
model,
batch_converter,
seqs,
batch.pids,
fetch_msas_fn
)
# Report that caching is complete
print('caching complete')
.\lucidrains\ddpm-proteins\ddpm_proteins\ddpm_proteins.py
import math
from math import log, pi
import copy
import torch
from torch import nn, einsum
import torch.nn.functional as F
from inspect import isfunction
from functools import partial
from torch.utils import data
from pathlib import Path
from torch.optim import Adam
from torchvision import transforms, utils
from PIL import Image
import numpy as np
from tqdm import tqdm
from einops import rearrange, repeat
from ddpm_proteins.utils import broadcat
try:
from apex import amp
APEX_AVAILABLE = True
except:
APEX_AVAILABLE = False
# constants
SAVE_AND_SAMPLE_EVERY = 1000
UPDATE_EMA_EVERY = 10
EXTS = ['jpg', 'jpeg', 'png']
RESULTS_FOLDER = Path('./results')
RESULTS_FOLDER.mkdir(exist_ok = True)
# helpers functions
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def cycle(dl):
while True:
for data in dl:
yield data
def num_to_groups(num, divisor):
groups = num // divisor
remainder = num % divisor
arr = [divisor] * groups
if remainder > 0:
arr.append(remainder)
return arr
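A quick illustration of `num_to_groups` (assuming the function above is in scope): it splits a count into as many full groups of `divisor` as fit, plus the remainder, which is handy for turning a number of requested samples into batch sizes.

```python
print(num_to_groups(10, 4))   # [4, 4, 2]
print(num_to_groups(8, 4))    # [4, 4]
```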
def loss_backwards(fp16, loss, optimizer, **kwargs):
if fp16:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward(**kwargs)
else:
loss.backward(**kwargs)
# small helper modules
class EMA():
def __init__(self, beta):
super().__init__()
self.beta = beta
def update_model_average(self, ma_model, current_model):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = self.update_average(old_weight, up_weight)
def update_average(self, old, new):
if old is None:
return new
return old * self.beta + (1 - self.beta) * new
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, *args, **kwargs):
return self.fn(x, *args, **kwargs) + x
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
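A quick shape check (assuming the class above is in scope): for a batch of scalar timesteps, the embedding has shape `(batch, dim)`, with half sine and half cosine features.

```python
import torch

pos_emb = SinusoidalPosEmb(dim=16)
t = torch.tensor([0., 1., 2.])   # three timesteps
print(pos_emb(t).shape)          # torch.Size([3, 16])
```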
class Mish(nn.Module):
def forward(self, x):
return x * torch.tanh(F.softplus(x))
class Upsample(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.ConvTranspose2d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
class Downsample(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.Conv2d(dim, dim, 3, 2, 1)
def forward(self, x):
return self.conv(x)
# building block modules
class ResnetBlock(nn.Module):
def __init__(
self,
dim,
dim_out,
*,
time_emb_dim,
hybrid_dim_conv = False,
groups = 8
# end of the ResnetBlock constructor arguments
):
# Call the parent class constructor
super().__init__()
# Kernel sizes and paddings for the conv blocks
kernels = ((3, 3),)
paddings = ((1, 1),)
# If using hybrid-dimension convolutions
if hybrid_dim_conv:
# Add extra (9, 1) and (1, 9) kernels with matching paddings
kernels = (*kernels, (9, 1), (1, 9))
paddings = (*paddings, (4, 0), (0, 4))
# MLP (Mish activation + linear layer) applied to the time embedding
self.mlp = nn.Sequential(
Mish(),
nn.Linear(time_emb_dim, dim_out)
)
# Module lists for the input and output conv blocks
self.blocks_in = nn.ModuleList([])
self.blocks_out = nn.ModuleList([])
# Build one input block and one output block per kernel/padding pair
for kernel, padding in zip(kernels, paddings):
self.blocks_in.append(nn.Sequential(
nn.Conv2d(dim, dim_out, kernel, padding = padding),
nn.GroupNorm(groups, dim_out),
Mish()
))
self.blocks_out.append(nn.Sequential(
nn.Conv2d(dim_out, dim_out, kernel, padding = padding),
nn.GroupNorm(groups, dim_out),
Mish()
))
# 1x1 conv for the residual connection when dims differ, otherwise identity
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
# Forward pass
def forward(self, x, time_emb):
# Run the input through each input block, collecting hidden states
hiddens = [fn(x) for fn in self.blocks_in]
# Project the time embedding through the MLP
time_emb = self.mlp(time_emb)
# Reshape the time embedding so it broadcasts over the spatial dimensions
time_emb = rearrange(time_emb, 'b c -> b c () ()')
# Add the time embedding to each hidden state
hiddens = [h + time_emb for h in hiddens]
# Run each hidden state through its output block
hiddens = [fn(h) for fn, h in zip(self.blocks_out, hiddens)]
# Sum the hidden states and add the residual connection
return sum(hiddens) + self.res_conv(x)
# Apply rotary embeddings to queries and keys
def apply_rotary_emb(q, k, pos_emb):
# Split the positional embedding into its sin and cos parts
sin, cos = pos_emb
dim_rotary = sin.shape[-1]
# Split q and k into the rotary portion and the pass-through portion
(q, q_pass), (k, k_pass) = map(lambda t: (t[..., :dim_rotary], t[..., dim_rotary:]), (q, k))
# Apply the rotary embedding to the queries and keys
q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
# Concatenate the rotated portion back with the pass-through portion
q, k = map(lambda t: torch.cat(t, dim = -1), ((q, q_pass), (k, k_pass)))
return q, k
# Rotate every adjacent pair of elements
def rotate_every_two(x):
# Group the last dimension into pairs
x = rearrange(x, '... (d j) -> ... d j', j = 2)
x1, x2 = x.unbind(dim = -1)
# Rotate each pair: (x1, x2) -> (-x2, x1)
x = torch.stack((-x2, x1), dim = -1)
return rearrange(x, '... d j -> ... (d j)')
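A tiny check of the pair rotation (assuming `rotate_every_two` above is in scope): every adjacent pair `(x1, x2)` in the last dimension becomes `(-x2, x1)`, which is what the rotary embedding multiplies by `sin`.

```python
import torch

x = torch.tensor([[1., 2., 3., 4.]])
print(rotate_every_two(x))   # tensor([[-2., 1., -4., 3.]])
```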
# Axial rotary embedding over the 2D feature map
class AxialRotaryEmbedding(nn.Module):
def __init__(self, dim, max_freq = 10):
super().__init__()
self.dim = dim
# Log-spaced frequency scales up to max_freq / 2
scales = torch.logspace(0., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
self.cached_pos_emb = None
self.register_buffer('scales', scales)
def forward(self, x):
device, dtype, h, w = x.device, x.dtype, *x.shape[-2:]
if exists(self.cached_pos_emb):
return self.cached_pos_emb
# Build sin/cos positional encodings over the 2D grid
seq_x = torch.linspace(-1., 1., steps = h, device = device)
seq_x = seq_x.unsqueeze(-1)
seq_y = torch.linspace(-1., 1., steps = w, device = device)
seq_y = seq_y.unsqueeze(-1)
scales = self.scales[(*((None,) * (len(seq_x.shape) - 1)), Ellipsis)]
scales = scales.to(x)
scales = self.scales[(*((None,) * (len(seq_y.shape) - 1)), Ellipsis)]
scales = scales.to(x)
seq_x = seq_x * scales * pi
seq_y = seq_y * scales * pi
x_sinu = repeat(seq_x, 'i d -> i j d', j = w)
y_sinu = repeat(seq_y, 'j d -> i j d', i = h)
sin = torch.cat((x_sinu.sin(), y_sinu.sin()), dim = -1)
cos = torch.cat((x_sinu.cos(), y_sinu.cos()), dim = -1)
sin, cos = map(lambda t: rearrange(t, 'i j d -> i j d'), (sin, cos))
sin, cos = map(lambda t: repeat(t, 'i j d -> () (i j) (d r)', r = 2), (sin, cos))
self.cached_pos_emb = (sin, cos)
return sin, cos
# Kernel (feature map) used for linear attention
def linear_attn_kernel(t):
return F.elu(t) + 1
# Linear attention
def linear_attention(q, k, v):
k_sum = k.sum(dim = -2)
D_inv = 1. / torch.einsum('...nd,...d->...n', q, k_sum.type_as(q))
context = torch.einsum('...nd,...ne->...de', k, v)
out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
return out
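A shape check for the linear attention above (assuming `linear_attn_kernel` and `linear_attention` are in scope). Because the context `k^T v` is computed first, the cost is linear in the sequence length, and the output has the same shape as `v`:

```python
import torch

# (batch, heads, sequence length, head dim); the kernel keeps q and k strictly positive
q = linear_attn_kernel(torch.randn(2, 4, 1024, 32))
k = linear_attn_kernel(torch.randn(2, 4, 1024, 32))
v = torch.randn(2, 4, 1024, 32)
print(linear_attention(q, k, v).shape)   # torch.Size([2, 4, 1024, 32])
```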
# Layer norm over the channel dimension
class LayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (std + self.eps) * self.g + self.b
# Linear attention module
class LinearAttention(nn.Module):
def __init__(self, dim, heads = 4, dim_head = 32):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.pos_emb = AxialRotaryEmbedding(dim = dim_head)
self.norm = LayerNorm(dim)
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w, heads = *x.shape, self.heads
x = self.norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h (x y) c', h = heads), (q, k, v))
sin, cos = self.pos_emb(x)
q, k = apply_rotary_emb(q, k, (sin, cos))
q = linear_attn_kernel(q)
k = linear_attn_kernel(k)
q = q * self.scale
out = linear_attention(q, k, v)
out = rearrange(out, 'b h (x y) c -> b (h c) x y', h = heads, x = h, y = w)
return self.to_out(out)
# The U-Net model
class Unet(nn.Module):
# Constructor: set up the model parameters
def __init__(
self,
dim,
out_dim = None,
dim_mults=(1, 2, 4, 8),
groups = 8,
channels = 3,
condition_dim = 0,
hybrid_dim_conv = False,
):
# Call the parent constructor
super().__init__()
self.channels = channels
self.condition_dim = condition_dim
# Input channels, accounting for the conditioning channels (for MSA Transformer features)
input_channels = channels + condition_dim
# Channel dimensions at each resolution
dims = [input_channels, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
# Time positional embedding and its MLP
self.time_pos_emb = SinusoidalPosEmb(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, dim * 4),
Mish(),
nn.Linear(dim * 4, dim)
)
# Module lists for the downsampling and upsampling paths
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
# ResnetBlock constructor with the time-embedding dim and hybrid conv flag fixed
get_resnet_block = partial(ResnetBlock, time_emb_dim = dim, hybrid_dim_conv = hybrid_dim_conv)
# Build the downsampling path over the resolutions
for ind, (dim_in, dim_out) in enumerate(in_out):
is_first = ind == 0
is_last = ind >= (num_resolutions - 1)
# Append a downsampling stage
self.downs.append(nn.ModuleList([
get_resnet_block(dim_in, dim_out),
get_resnet_block(dim_out, dim_out),
Residual(LinearAttention(dim_out)),
Downsample(dim_out) if not is_last else nn.Identity()
]))
mid_dim = dims[-1]
# Middle blocks
self.mid_block1 = get_resnet_block(mid_dim, mid_dim)
self.mid_attn = Residual(LinearAttention(mid_dim))
self.mid_block2 = get_resnet_block(mid_dim, mid_dim)
# Build the upsampling path over the resolutions (in reverse)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
# Append an upsampling stage
self.ups.append(nn.ModuleList([
get_resnet_block(dim_out * 2, dim_in),
get_resnet_block(dim_in, dim_in),
Residual(LinearAttention(dim_in)),
Upsample(dim_in) if not is_last else nn.Identity()
]))
# Output dimension
out_dim = default(out_dim, channels)
# Final convolution
self.final_conv = nn.Sequential(
nn.Sequential(
nn.Conv2d(dim, dim, 3, padding = 1),
nn.GroupNorm(groups, dim),
Mish()
),
nn.Conv2d(dim, out_dim, 1)
)
# Forward pass
def forward(self, x, time):
t = self.time_pos_emb(time)
t = self.mlp(t)
h = []
# Downsampling path
for resnet, resnet2, attn, downsample in self.downs:
x = resnet(x, t)
x = resnet2(x, t)
x = attn(x)
h.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
# Upsampling path
for resnet, resnet2, attn, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, t)
x = resnet2(x, t)
x = attn(x)
x = upsample(x)
return self.final_conv(x)
# Gaussian diffusion trainer class
# Gather the values of a at the indices t, reshaped to broadcast against a tensor of shape x_shape
def extract(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
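A small example of the gather-and-reshape behavior (assuming `extract` above is in scope): it picks the schedule value for each sample's timestep and reshapes it so it broadcasts over an image-shaped tensor.

```python
import torch

betas = torch.linspace(1e-4, 2e-2, 10)   # a toy schedule buffer of length T = 10
t = torch.tensor([0, 4, 9])               # one timestep per batch element
x = torch.zeros(3, 3, 8, 8)               # (batch, channels, height, width)
out = extract(betas, t, x.shape)
print(out.shape)                           # torch.Size([3, 1, 1, 1])
print(out.flatten())                       # tensor([betas[0], betas[4], betas[9]])
```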
# Sample noise of the given shape, optionally repeating the same noise across the batch
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
# Cosine beta schedule for the diffusion process
def cosine_beta_schedule(timesteps, s = 0.008):
"""
Cosine schedule
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = np.linspace(0, timesteps, steps)
alphas_cumprod = np.cos(((x / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return np.clip(betas, a_min = 0, a_max = 0.999)
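A quick sanity check of the schedule (assuming the function and the numpy import above are in scope): the betas start tiny and grow until they hit the 0.999 clip near the final step, so the cumulative alpha decays from about 1 to about 0 (pure noise).

```python
betas = cosine_beta_schedule(1000)
print(float(betas[0]), float(betas[-1]))   # first beta is tiny, last beta is clipped at 0.999
print(float(np.prod(1. - betas)))          # cumulative alpha over all steps is ~0
```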
# Gaussian diffusion model
class GaussianDiffusion(nn.Module):
def __init__(
self,
denoise_model,
*,
image_size,
timesteps = 1000,
loss_type = 'l1',
betas = None
):
super().__init__()
self.channels = denoise_model.channels
self.condition_dim = denoise_model.condition_dim
self.image_size = image_size
self.denoise_model = denoise_model
# If betas are provided, convert them to a numpy array; otherwise build a cosine schedule
if exists(betas):
betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas
else:
betas = cosine_beta_schedule(timesteps)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.loss_type = loss_type
to_torch = partial(torch.tensor, dtype=torch.float32)
# Register betas, alphas_cumprod, and alphas_cumprod_prev as buffers
self.register_buffer('betas', to_torch(betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
# Calculations for the diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
# Variance of the posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
self.register_buffer('posterior_variance', to_torch(posterior_variance))
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
self.register_buffer('posterior_mean_coef1', to_torch(
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
self.register_buffer('posterior_mean_coef2', to_torch(
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
# Mean, variance, and log-variance of q(x_t | x_0)
def q_mean_variance(self, x_start, t):
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
# recover the predicted x_0 from x_t and the predicted noise
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
# mean, variance and clipped log-variance of the posterior q(x_{t-1} | x_t, x_0)
def q_posterior(self, x_start, x_t, t):
# posterior mean
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
# posterior variance
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
# clipped posterior log-variance
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
# compute the model's mean, variance and log-variance for p(x_{t-1} | x_t)
def p_mean_variance(self, x, t, clip_denoised=True, condition_tensor=None):
# input to the denoise model
denoise_model_input = x
# if a conditioning tensor is given, concatenate it with the input along the channel dimension
if exists(condition_tensor):
denoise_model_input = broadcat((condition_tensor, x), dim=1)
# predict the noise with the denoise model
denoise_model_output = self.denoise_model(denoise_model_input, t)
# reconstruct x_0 from the predicted noise
x_recon = self.predict_start_from_noise(x, t=t, noise=denoise_model_output)
# optionally clamp the reconstruction to the valid pixel range
if clip_denoised:
x_recon.clamp_(0., 1.)
# plug the reconstruction into the posterior to get mean, variance and log-variance
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
# sample x_{t-1} from p(x_{t-1} | x_t)
@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False, condition_tensor=None):
# batch size and device of the input
b, *_, device = *x.shape, x.device
# model mean and log-variance
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised, condition_tensor=condition_tensor)
# sample noise
noise = noise_like(x.shape, device, repeat_noise)
# no noise is added at t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
# full reverse-diffusion sampling loop
@torch.no_grad()
def p_sample_loop(self, shape, condition_tensor=None):
# device
device = self.betas.device
# start from pure gaussian noise
b = shape[0]
img = torch.randn(shape, device=device)
# iterate over the timesteps in reverse
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), condition_tensor=condition_tensor)
return img
# generate a batch of samples
@torch.no_grad()
def sample(self, batch_size=16, condition_tensor=None):
# if the model is conditional, a conditioning tensor must be provided
assert not (self.condition_dim > 0 and not exists(condition_tensor)), 'the conditioning tensor needs to be passed'
# run the sampling loop at the configured image size and channel count
image_size = self.image_size
channels = self.channels
return self.p_sample_loop((batch_size, channels, image_size, image_size), condition_tensor=condition_tensor)
# interpolate between two images in noise space, then denoise back
@torch.no_grad()
def interpolate(self, x1, x2, t=None, lam=0.5):
# batch size and device
b, *_, device = *x1.shape, x1.device
# default to the last timestep if none is given
t = default(t, self.num_timesteps - 1)
# both inputs must have the same shape
assert x1.shape == x2.shape
# broadcast the timestep across the batch
t_batched = torch.stack([torch.tensor(t, device=device)] * b)
xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))
# linearly interpolate the two noised images
img = (1 - lam) * xt1 + lam * xt2
for i in tqdm(reversed(range(0, t)), desc='interpolation sample time step', total=t):
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))
return img
# forward diffusion: sample x_t from q(x_t | x_0)
def q_sample(self, x_start, t, noise=None):
# draw random noise if none is provided
noise = default(noise, lambda: torch.randn_like(x_start))
# combine x_0 and the noise using the precomputed cumulative-product coefficients
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
# training loss on the noise prediction
def p_losses(self, x_start, t, noise = None, condition_tensor = None):
# shape of the input
b, c, h, w = x_start.shape
# draw random noise of the same shape as the input if none is provided
noise = default(noise, lambda: torch.randn_like(x_start))
# produce the noisy image x_t via the forward diffusion
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
# if a conditioning tensor is given, concatenate it with the noisy image
if exists(condition_tensor):
x_noisy = broadcat((condition_tensor, x_noisy), dim = 1)
# predict the noise from the noisy image
x_recon = self.denoise_model(x_noisy, t)
# compute the loss between the true and predicted noise according to the loss type
if self.loss_type == 'l1':
loss = (noise - x_recon).abs().mean()
elif self.loss_type == 'l2':
loss = F.mse_loss(noise, x_recon)
else:
raise NotImplementedError()
# return the loss
return loss
# forward pass: pick random timesteps and return the training loss
def forward(self, x, *args, **kwargs):
# shape, device and configured image size
b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
# the input must be square and match the configured image size
assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
# sample a random timestep for each image in the batch
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
# compute and return the diffusion loss
return self.p_losses(x, t, *args, **kwargs)
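The class only needs a denoise model that exposes `channels` and `condition_dim` and maps (noisy image, timestep) to a noise prediction. The sketch below is not part of the original file: `DummyDenoiser` is a made-up stand-in for the Unet (it ignores the timestep) and serves only to show how the training loss and the sampling entry points are called.
import torch
from torch import nn

class DummyDenoiser(nn.Module):
    def __init__(self, channels = 3):
        super().__init__()
        self.channels = channels
        self.condition_dim = 0                    # unconditional, so no condition tensor is required
        self.net = nn.Conv2d(channels, channels, 3, padding = 1)

    def forward(self, x, t):
        return self.net(x)                        # a real Unet would also embed the timestep t

diffusion = GaussianDiffusion(DummyDenoiser(), image_size = 32, timesteps = 10, loss_type = 'l1')
loss = diffusion(torch.rand(2, 3, 32, 32))        # training objective: predict the added noise
samples = diffusion.sample(batch_size = 2)        # reverse diffusion starting from pure noise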
# dataset classes
class Dataset(data.Dataset):
# image-folder dataset with a fixed target image size
def __init__(self, folder, image_size):
super().__init__()
self.folder = folder
self.image_size = image_size
# collect every file under the folder matching the allowed extensions
self.paths = [p for ext in EXTS for p in Path(f'{folder}').glob(f'**/*.{ext}')]
# resize, random horizontal flip, center crop and convert to a tensor
self.transform = transforms.Compose([
transforms.Resize(image_size),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(image_size),
transforms.ToTensor()
])
# number of images
def __len__(self):
return len(self.paths)
# load and transform the image at the given index
def __getitem__(self, index):
path = self.paths[index]
img = Image.open(path)
return self.transform(img)
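Usage is straightforward (a sketch only; 'path/to/images' is a placeholder folder): the dataset globs images under a folder and yields resized, center-cropped tensors in [0, 1].
from torch.utils.data import DataLoader

ds = Dataset('path/to/images', image_size = 128)
dl = DataLoader(ds, batch_size = 16, shuffle = True)
# each batch is a float tensor of shape (batch, channels, 128, 128) with values in [0, 1]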
# trainer class
class Trainer(object):
# set up the model, its EMA copy, the dataset, the dataloader and the optimizer
def __init__(
self,
diffusion_model,
folder,
*,
ema_decay = 0.995,
image_size = 128,
train_batch_size = 32,
train_lr = 2e-5,
train_num_steps = 100000,
gradient_accumulate_every = 2,
fp16 = False,
step_start_ema = 2000
):
super().__init__()
self.model = diffusion_model
self.ema = EMA(ema_decay)
self.ema_model = copy.deepcopy(self.model)
self.step_start_ema = step_start_ema
self.batch_size = train_batch_size
self.image_size = diffusion_model.image_size
self.gradient_accumulate_every = gradient_accumulate_every
self.train_num_steps = train_num_steps
# dataset and (cycled) dataloader
self.ds = Dataset(folder, image_size)
self.dl = cycle(data.DataLoader(self.ds, batch_size = train_batch_size, shuffle=True, pin_memory=True))
self.opt = Adam(diffusion_model.parameters(), lr=train_lr)
self.step = 0
# mixed precision training requires apex
assert not fp16 or fp16 and APEX_AVAILABLE, 'Apex must be installed in order for mixed precision training to be turned on'
self.fp16 = fp16
if fp16:
# wrap the model, the EMA model and the optimizer for apex mixed precision
(self.model, self.ema_model), self.opt = amp.initialize([self.model, self.ema_model], self.opt, opt_level='O1')
# sync the EMA model with the current weights
self.reset_parameters()
# copy the current model weights into the EMA model
def reset_parameters(self):
self.ema_model.load_state_dict(self.model.state_dict())
# update the exponential moving average of the model weights
def step_ema(self):
if self.step < self.step_start_ema:
self.reset_parameters()
return
self.ema.update_model_average(self.ema_model, self.model)
# save a checkpoint
def save(self, milestone):
data = {
'step': self.step,
'model': self.model.state_dict(),
'ema': self.ema_model.state_dict()
}
torch.save(data, str(RESULTS_FOLDER / f'model-{milestone}.pt'))
# load a checkpoint
def load(self, milestone):
data = torch.load(str(RESULTS_FOLDER / f'model-{milestone}.pt'))
self.step = data['step']
self.model.load_state_dict(data['model'])
self.ema_model.load_state_dict(data['ema'])
# training loop
def train(self):
backwards = partial(loss_backwards, self.fp16)
while self.step < self.train_num_steps:
for i in range(self.gradient_accumulate_every):
data = next(self.dl).cuda()
loss = self.model(data)
print(f'{self.step}: {loss.item()}')
backwards(loss / self.gradient_accumulate_every, self.opt)
self.opt.step()
self.opt.zero_grad()
if self.step % UPDATE_EMA_EVERY == 0:
self.step_ema()
if self.step != 0 and self.step % SAVE_AND_SAMPLE_EVERY == 0:
milestone = self.step // SAVE_AND_SAMPLE_EVERY
batches = num_to_groups(36, self.batch_size)
all_images_list = list(map(lambda n: self.ema_model.sample(batch_size=n), batches))
all_images = torch.cat(all_images_list, dim=0)
utils.save_image(all_images, str(RESULTS_FOLDER / f'sample-{milestone}.png'), nrow=6)
self.save(milestone)
self.step += 1
print('training completed')
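Putting the pieces together, a minimal (hypothetical) invocation would look like the following; the folder path and hyperparameters are placeholders, and `diffusion` is a GaussianDiffusion instance like the one sketched earlier. Note that `train()` moves batches to CUDA, so a GPU is assumed.
trainer = Trainer(
    diffusion,                  # a GaussianDiffusion instance as constructed earlier
    'path/to/images',           # placeholder folder of training images
    image_size = 128,
    train_batch_size = 32,
    train_lr = 2e-5,
    train_num_steps = 100000
)
trainer.train()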