Lucidrains Series Source Code Analysis (Part 19)
.\lucidrains\DALLE2-pytorch\dalle2_pytorch\dataloaders\decoder_loader.py
import os
import webdataset as wds
import torch
from torch.utils.data import DataLoader
import numpy as np
import fsspec
import shutil
def get_shard(filename):
"""
Filenames with shards in them have a consistent structure that we can take advantage of
Standard structure: path/to/file/prefix_string_00001.ext
"""
try:
return filename.split("_")[-1].split(".")[0]
except ValueError:
raise RuntimeError(f"Could not find shard for filename {filename}")
def get_example_file(fs, path, file_format):
"""
Given a file system, a path, and a file format, return an example file
"""
return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
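# For instance, with the standard structure above, get_shard extracts the digits
# between the last "_" and the extension (a toy illustration; the path is hypothetical):
assert get_shard("img_emb/img_emb_0001.npy") == "0001"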
def embedding_inserter(samples, embeddings_url, index_width, sample_key='npy', handler=wds.handlers.reraise_exception):
"""Given a datum of {"__key__": str, "__url__": str, ...} adds the corresponding embedding and yields"""
previous_tar_url = None
current_embeddings = None
# Get a reference to an abstract file system where the embeddings are stored
embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
example_embedding_file = get_example_file(embeddings_fs, embeddings_path, "npy")
example_embedding_shard = get_shard(example_embedding_file)
emb_shard_width = len(example_embedding_shard)
# Easier to get the basename without the shard once than search through for the correct file every time
embedding_file_basename = '_'.join(example_embedding_file.split("_")[:-1]) + "_"
def load_corresponding_embeds(tar_url):
"""Finds and reads the npy files that contains embeddings for the given webdataset tar"""
shard = int(tar_url.split("/")[-1].split(".")[0])
embedding_url = embedding_file_basename + str(shard).zfill(emb_shard_width) + '.npy'
with embeddings_fs.open(embedding_url) as f:
data = np.load(f)
return torch.from_numpy(data)
for sample in samples:
try:
tar_url = sample["__url__"]
key = sample["__key__"]
if tar_url != previous_tar_url:
# If the tar changed, we need to download new embeddings
# This means if we shuffle before inserting it will load many more files than we expect and be very inefficient.
previous_tar_url = tar_url
current_embeddings = load_corresponding_embeds(tar_url)
embedding_index = int(key[-index_width:])
embedding = current_embeddings[embedding_index]
# If the embedding is all zeros, it is a placeholder for a missing image and is not valid, so we raise and let the handler decide whether to continue
if torch.count_nonzero(embedding) == 0:
raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
sample[sample_key] = embedding
yield sample
except Exception as exn: # From wds implementation
if handler(exn):
continue
else:
break
insert_embedding = wds.filters.pipelinefilter(embedding_inserter)
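# A rough sketch of composing insert_embedding into a webdataset pipeline like any other
# wds filter (hypothetical paths; the index_width and sample_key values are assumptions):
_example_pipeline = wds.DataPipeline(
    wds.SimpleShardList("/data/webdataset/{0000..0009}.tar"),
    wds.tarfile_to_samples(),
    wds.decode("torchrgb"),
    insert_embedding(embeddings_url = "/data/img_emb/", index_width = 4, sample_key = "img_emb"),
)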
def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.reraise_exception):
"""Finds if there is a corresponding embedding for the tarfile at { url: [URL] }"""
embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
embedding_files = embeddings_fs.ls(embeddings_path)
get_embedding_shard = lambda embedding_file: int(embedding_file.split("_")[-1].split(".")[0])
embedding_shards = set([get_embedding_shard(filename) for filename in embedding_files]) # Sets have O(1) check for member
get_tar_shard = lambda tar_file: int(tar_file.split("/")[-1].split(".")[0])
# Iterate over each tarfile in the tarfiles list
for tarfile in tarfiles:
try:
# Get the webdataset shard number for this tarfile
webdataset_shard = get_tar_shard(tarfile["url"])
# If this shard has an associated embedding file, yield the tarfile
# Otherwise keep iterating until one with an associated embedding file is found
if webdataset_shard in embedding_shards:
yield tarfile
except Exception as exn: # From wds implementation
# If the handler handles the exception, continue the loop
if handler(exn):
continue
# If the handler does not handle the exception, break out of the loop
else:
break
# Create a pipeline filter that skips shards without associated embeddings
skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)
# Merges the img_emb and text_emb keys of a sample into one key "emb": { "text": text_emb, "img": img_emb }
# If one or both of text_emb and img_emb are missing from the sample, only the ones present are added
def join_embeddings(samples, handler=wds.handlers.reraise_exception):
"""
Takes the img_emb and text_emb keys and turns them into one key "emb": { "text": text_emb, "img": img_emb }
either or both of text_emb and img_emb may not be in the sample so we only add the ones that exist
"""
for sample in samples:
try:
sample['emb'] = {}
if 'text_emb' in sample:
sample['emb']['text'] = sample['text_emb']
if 'img_emb' in sample:
sample['emb']['img'] = sample['img_emb']
yield sample
except Exception as exn: # From wds implementation
if handler(exn):
continue
else:
break
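# Toy illustration of the per-sample rewrite performed by join_embeddings (hypothetical tensors):
_sample = {"__key__": "0001042", "img_emb": torch.randn(512), "text_emb": torch.randn(512)}
assert set(next(join_embeddings([_sample]))["emb"].keys()) == {"img", "text"}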
# Verify that the required keys exist in each sample; raise if any is missing
def verify_keys(samples, required_keys, handler=wds.handlers.reraise_exception):
"""
Requires that both the image and embedding are present in the sample
This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
"""
for sample in samples:
try:
for key in required_keys:
assert key in sample, f"Sample {sample['__key__']} missing {key}. Has keys {sample.keys()}"
yield sample
except Exception as exn: # From wds implementation
if handler(exn):
continue
else:
break
# Create a pipeline filter that verifies the required keys are present
key_verifier = wds.filters.pipelinefilter(verify_keys)
# ImageEmbeddingDataset is a fluid-interface wrapper for DataPipeline that returns image-embedding pairs
# Embeddings are read as npy files from the webdataset if they exist; if embedding_folder_url is set, they are inserted from that alternate source
class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
"""
A fluid interface wrapper for DataPipeline that returns image embedding pairs
Reads embeddings as npy files from the webdataset if they exist. If embedding_folder_url is set, they will be inserted in from the alternate source.
"""
def __init__(
self,
urls,
img_embedding_folder_url=None,
text_embedding_folder_url=None,
index_width=None,
img_preproc=None,
extra_keys=[],
handler=wds.handlers.reraise_exception,
resample=False,
shuffle_shards=True
def preproc(self, sample):
"""Applies the preprocessing for images"""
if self.img_preproc is not None:
sample["jpg"] = self.img_preproc(sample["jpg"])
return sample
# Convenience function that builds an image-embedding dataloader
def create_image_embedding_dataloader(
tar_url,
num_workers,
batch_size,
img_embeddings_url=None,
text_embeddings_url=None,
index_width=None,
shuffle_num = None,
shuffle_shards = True,
resample_shards = False,
img_preproc=None,
extra_keys=[],
handler=wds.handlers.reraise_exception  # or wds.handlers.warn_and_continue
):
"""
Convenience function to create an image embedding dataset and dataloader in one line
:param tar_url: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
:param num_workers: The number of workers to use for the dataloader
:param batch_size: The batch size to use for the dataloader
:param img_embeddings_url: Required if the webdataset does not contain image embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
:param text_embeddings_url: Same as img_embeddings_url, but for text embeddings.
:param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
For example, if a file in webdataset shard 3 is named 0003039.jpg, the shard width is 4 digits and the index is the last 3 digits, so index_width is 3.
:param shuffle_num: If not None, shuffle the dataset with this size buffer after sampling.
:param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true.
:param resample_shards: If true, sample webdataset shards with replacement. You need to set your own epoch size if this is true, since it will resample infinitely.
:param handler: The webdataset handler.
"""
# Build the ImageEmbeddingDataset
ds = ImageEmbeddingDataset(
tar_url,
img_embedding_folder_url=img_embeddings_url,
text_embedding_folder_url=text_embeddings_url,
index_width=index_width,
shuffle_shards=shuffle_shards,
resample=resample_shards,
extra_keys=extra_keys,
img_preproc=img_preproc,
handler=handler
)
# If shuffle_num is set and positive, shuffle the dataset with a buffer of that size
if shuffle_num is not None and shuffle_num > 0:
ds.shuffle(shuffle_num)
# Wrap the dataset in a DataLoader
return DataLoader(
ds,
num_workers=num_workers,
batch_size=batch_size,
prefetch_factor=2, # It may be a good idea to keep this high so the next npy file is prefetched in time
pin_memory=True,
shuffle=False
)
.\lucidrains\DALLE2-pytorch\dalle2_pytorch\dataloaders\prior_loader.py
# Import ceil from the math module
from math import ceil
# Import tokenize from the clip module
from clip import tokenize
# Import the EmbeddingReader class from the embedding_reader module
from embedding_reader import EmbeddingReader
# Import from_numpy and DataLoader from torch
from torch import from_numpy
from torch.utils.data import IterableDataset, DataLoader
# PriorEmbeddingDataset subclasses IterableDataset
class PriorEmbeddingDataset(IterableDataset):
"""
PriorEmbeddingDataset is a wrapper of EmbeddingReader.
It enables one to simplify the logic necessary to yield samples from
the different EmbeddingReader configurations available.
"""
# Constructor
def __init__(
self,
text_conditioned: bool,
batch_size: int,
start: int,
stop: int,
image_reader,
text_reader: EmbeddingReader = None,
) -> None:
# Call the parent class initializer
super().__init__()
# Store configuration
self.text_conditioned = text_conditioned
# If not text-conditioned, keep a text reader as well
if not self.text_conditioned:
self.text_reader = text_reader
# Remaining attributes
self.image_reader = image_reader
self.start = start
self.stop = stop
self.batch_size = batch_size
# Return the length of the dataset
def __len__(self):
return self.stop - self.start
# Iterator protocol
def __iter__(self):
# Arguments shared by both readers
loader_args = dict(
batch_size=self.batch_size,
start=self.start,
end=self.stop,
show_progress=False,
)
# If the data requested is text-conditioned, only load the image reader; captions come from its metadata
if self.text_conditioned:
self.loader = self.image_reader(**loader_args)
# Otherwise, include the text embeddings and bypass the metadata
else:
self.loader = zip(
self.image_reader(**loader_args), self.text_reader(**loader_args)
)
# Return self as the iterator
return self
# Fetch the next sample
def __next__(self):
try:
return self.get_sample()
except StopIteration:
raise StopIteration
# String representation of the dataset
def __str__(self):
return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"
# Set the starting point
def set_start(self, start):
"""
Adjust the starting point within the reader, useful for resuming an epoch
"""
self.start = start
# Get the starting point
def get_start(self):
return self.start
# Fetch and format a sample
def get_sample(self):
"""
pre-process data from either reader into a common format
"""
if self.text_conditioned:
image_embedding, caption = next(self.loader)
image_embedding = from_numpy(image_embedding)
tokenized_caption = tokenize(caption["caption"].to_list(), truncate=True)
return image_embedding, tokenized_caption
else:
(image_embedding, _), (text_embedding, _) = next(self.loader)
image_embedding = from_numpy(image_embedding)
text_embedding = from_numpy(text_embedding)
return image_embedding, text_embedding
# Helper functions
# Distribute data to each rank
def distribute_to_rank(start, stop, rank, world_size):
"""
Distribute data to each rank given the world size.
Return:
- New start and stop points for this rank.
"""
num_samples = int(stop - start)
per_rank = int(ceil((num_samples) / float(world_size)))
assert (
per_rank > 0
), f"Number of samples per rank must be larger than 0, (found: {per_rank})"
rank_start = start + rank * per_rank
rank_stop = min(rank_start + per_rank, stop)
new_length = rank_stop - rank_start
assert (
new_length > 0
), "Calculated start and stop points result in a length of zero for this rank."
return rank_start, rank_stop
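# Quick worked example: 100 samples across a world size of 3 gives per_rank = ceil(100 / 3) = 34,
# so the ranks receive (0, 34), (34, 68), and (68, 100) respectively:
_example_splits = [distribute_to_rank(0, 100, rank, 3) for rank in range(3)]
assert _example_splits == [(0, 34), (34, 68), (68, 100)]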
# Build the reader object(s)
def get_reader(
text_conditioned: bool, img_url: str, meta_url: str = None, txt_url: str = None
):
"""
Create an EmbeddingReader object from the specified URLs
get_reader() will always expect a url to image embeddings.
If text-conditioned, it will also expect a meta_url for the captions.
Otherwise, it will need txt_url for the matching text embeddings.
Returns an image_reader object if text-conditioned.
Otherwise it returns both an image_reader and a text_reader
"""
# An image url must always be supplied
assert img_url is not None, "Must supply an image url"
# If text-conditioned, a metadata url is also required
if text_conditioned:
assert meta_url is not None, "Must supply meta url if text-conditioned"
# Create an EmbeddingReader for the image embeddings
image_reader = EmbeddingReader(
embeddings_folder=img_url,
file_format="parquet_npy",
# Assumes the caption column exists and is the only column requested
meta_columns=["caption"],
metadata_folder=meta_url,
)
# Return only the image reader
return image_reader
# Otherwise text embeddings are needed, so return two readers
assert (
txt_url is not None
), "Must supply text embedding url if not text-conditioning"
# Create an EmbeddingReader for the image embeddings
image_reader = EmbeddingReader(img_url, file_format="npy")
# Create an EmbeddingReader for the text embeddings
text_reader = EmbeddingReader(txt_url, file_format="npy")
# Return both readers
return image_reader, text_reader
def make_splits(
text_conditioned: bool,
batch_size: int,
num_data_points: int,
train_split: float,
eval_split: float,
image_reader: EmbeddingReader,
text_reader: EmbeddingReader = None,
start=0,
rank=0,
world_size=1,
):
"""
Split an embedding reader object as needed.
NOTE: make_splits() will infer the test set size from your train and eval.
Input:
- text_conditioned: whether to prepare text-conditioned training data
- batch_size: the batch size for a single gpu
- num_data_points: the total number of data points you wish to train on
- train_split: the percentage of data you wish to train on
- eval_split: the percentage of data you wish to validate on
- image_reader: the image_reader you wish to split
- text_reader: the text_reader you want to split (if !text_conditioned)
- start: the starting point within your dataset
- rank: the rank of your worker
- world_size: the total world size of your distributed training run
Returns:
- PyTorch Dataloaders that yield tuples of (img, txt) data.
"""
assert start < image_reader.count, "start position cannot exceed reader count."
# verify that the num_data_points does not exceed the max points
if num_data_points > (image_reader.count - start):
print(
"Specified count is larger than what's available...defaulting to reader's count."
)
num_data_points = image_reader.count
# compute split points
train_set_size = int(train_split * num_data_points)
eval_set_size = int(eval_split * num_data_points)
eval_start = train_set_size
eval_stop = int(eval_start + eval_set_size)
assert (
train_split + eval_split
) < 1.0, "Specified train and eval split is too large to infer a test split."
# distribute to rank
rank_train_start, rank_train_stop = distribute_to_rank(
start, train_set_size, rank, world_size
)
rank_eval_start, rank_eval_stop = distribute_to_rank(
train_set_size, eval_stop, rank, world_size
)
rank_test_start, rank_test_stop = distribute_to_rank(
eval_stop, num_data_points, rank, world_size
)
# wrap up splits into a dict
train_split_args = dict(
start=rank_train_start, stop=rank_train_stop, batch_size=batch_size
)
eval_split_args = dict(
start=rank_eval_start, stop=rank_eval_stop, batch_size=batch_size
)
test_split_args = dict(
start=rank_test_start, stop=rank_test_stop, batch_size=batch_size
)
if text_conditioned:
# add the text-conditioned args to a unified dict
reader_args = dict(
text_conditioned=text_conditioned,
image_reader=image_reader,
)
train_split_args = dict(**reader_args, **train_split_args)
eval_split_args = dict(**reader_args, **eval_split_args)
test_split_args = dict(**reader_args, **test_split_args)
train = PriorEmbeddingDataset(**train_split_args)
val = PriorEmbeddingDataset(**eval_split_args)
test = PriorEmbeddingDataset(**test_split_args)
else:
# add the non-conditioned args to a unified dict
reader_args = dict(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
)
train_split_args = dict(**reader_args, **train_split_args)
eval_split_args = dict(**reader_args, **eval_split_args)
test_split_args = dict(**reader_args, **test_split_args)
train = PriorEmbeddingDataset(**train_split_args)
val = PriorEmbeddingDataset(**eval_split_args)
test = PriorEmbeddingDataset(**test_split_args)
# true batch size is specified in the PriorEmbeddingDataset
train_loader = DataLoader(train, batch_size=None)
eval_loader = DataLoader(val, batch_size=None)
# batch_size=None because batching is already handled inside PriorEmbeddingDataset
test_loader = DataLoader(test, batch_size=None)
# Return the train, eval, and test dataloaders
return train_loader, eval_loader, test_loader
Dataloaders
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
Decoder: Image Embedding Dataset
When training the decoder (and upsamplers, if training them together) in isolation, you will need to load images and their corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a webdataset whose .tars contain .jpg and .npy files holding the images and the associated image embeddings respectively. Alternatively, you can specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain .npy files with the same shard numbers as the webdataset, and there should be a correspondence between the filename of the .jpg and the index of the embedding in the .npy. So, for example, 0001.tar from the webdataset containing image 00010509.jpg (the first 4 digits are the shard number and the last 4 are the index) should be paralleled by an img_emb_0001.npy containing a NumPy array with the embedding at index 509.
Generating a dataset of this type:
- Use img2dataset to generate a webdataset.
- Use clip-retrieval to convert the images to embeddings.
- Use embedding-dataset-reordering to reorder the embeddings into the expected format.
Usage:
from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader
# Create a dataloader directly.
dataloader = create_image_embedding_dataloader(
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
num_workers=4,
batch_size=32,
index_width=3, # If a file in webdataset shard 3 is named 0003039.jpg, the shard width is 4 digits and the index is the last 3, so index_width=3
shuffle_num=200, # Does a shuffle of the data with a buffer size of 200
shuffle_shards=True, # Shuffle the order the shards are read in
resample_shards=False, # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
)
for img, emb in dataloader:
print(img.shape) # torch.Size([32, 3, 256, 256])
print(emb.shape) # torch.Size([32, 512])
# Train decoder only as shown above
# Or create a dataset without a loader so you can configure it manually
dataset = ImageEmbeddingDataset(
urls="/path/or/url/to/webdataset/{0000..9999}.tar",
embedding_folder_url="path/or/url/to/embeddings/folder",
shard_width=4,
shuffle_shards=True,
resample=False
)
Diffusion Prior: Prior Embedding Dataset
When training the prior it is much more efficient to work with pre-computed embeddings. The PriorEmbeddingDataset class enables you to leverage the same script (with minimal modification) for both embedding-only and text-conditioned prior training. This saves you from having to worry about a lot of the boilerplate code.
To utilize the PriorEmbeddingDataset, all you need to do is make a single call to get_reader() which will create EmbeddingReader object(s) for you. Afterwards, you can utilize make_splits() to cleanly create DataLoader objects for your training run.
If you are training in a distributed manner, make_splits() accepts rank and world_size arguments to properly distribute to each process. The defaults for these values are rank=0 and world_size=1, so single-process training can safely ignore these parameters.
Usage:
from dalle2_pytorch.dataloaders import get_reader, make_splits
# grab embeddings from some specified location
IMG_URL = "data/img_emb/"
META_URL = "data/meta/"
reader = get_reader(text_conditioned=True, img_url=IMG_URL, meta_url=META_URL)
# some config for training
TRAIN_ARGS = {
"world_size": 3,
"text_conditioned": True,
"start": 0,
"num_data_points": 10000,
"batch_size": 2,
"train_split": 0.5,
"eval_split": 0.25,
"image_reader": reader,
}
# specifying a rank will handle allocation internally
rank0_train, rank0_eval, rank0_test = make_splits(rank=0, **TRAIN_ARGS)
rank1_train, rank1_eval, rank1_test = make_splits(rank=1, **TRAIN_ARGS)
rank2_train, rank2_eval, rank2_test = make_splits(rank=2, **TRAIN_ARGS)
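Iterating any of the resulting loaders yields tensor batches directly; in the text-conditioned case each batch is an (image_embedding, tokenized_caption) pair. A sketch, assuming the setup above (the embedding dimension depends on the CLIP model used):
for image_embedding, tokenized_caption in rank0_train:
    print(image_embedding.shape)    # e.g. torch.Size([2, 768])
    print(tokenized_caption.shape)  # torch.Size([2, 77]) from clip.tokenize
    break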
.\lucidrains\DALLE2-pytorch\dalle2_pytorch\dataloaders\simple_image_only_dataloader.py
# Import required libraries
from pathlib import Path
import torch
from torch.utils import data
from torchvision import transforms, utils
from PIL import Image
# Generator that loops over a dataloader forever
def cycle(dl):
while True:
for data in dl:
yield data
# Dataset class
class Dataset(data.Dataset):
def __init__(
self,
folder,
image_size,
exts = ['jpg', 'jpeg', 'png']
):
super().__init__()
self.folder = folder
self.image_size = image_size
# Gather the paths of all files with the given extensions under the folder
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
# Image preprocessing transforms
self.transform = transforms.Compose([
transforms.Resize(image_size),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(image_size),
transforms.ToTensor()
])
# Number of images in the dataset
def __len__(self):
return len(self.paths)
# Load and transform the image at the given index
def __getitem__(self, index):
path = self.paths[index]
img = Image.open(path)
return self.transform(img)
# Build a dataloader over a folder of images
def get_images_dataloader(
folder,
*,
batch_size,
image_size,
shuffle = True,
cycle_dl = True,
pin_memory = True
):
# Create the dataset
ds = Dataset(folder, image_size)
# Create the dataloader
dl = data.DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory)
# Optionally wrap the dataloader in the infinite cycle generator
if cycle_dl:
dl = cycle(dl)
return dl
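# Minimal usage sketch (hypothetical folder); with cycle_dl=True the returned object
# is an infinite generator, so batches are fetched with next():
#
#   dl = get_images_dataloader('/path/to/images', batch_size = 16, image_size = 256)
#   images = next(dl)  # torch.Size([16, 3, 256, 256]) for RGB inputs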
.\lucidrains\DALLE2-pytorch\dalle2_pytorch\dataloaders\__init__.py
# Re-export ImageEmbeddingDataset and create_image_embedding_dataloader from dalle2_pytorch.dataloaders.decoder_loader
from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
# Re-export make_splits, get_reader, and PriorEmbeddingDataset from dalle2_pytorch.dataloaders.prior_loader
from dalle2_pytorch.dataloaders.prior_loader import make_splits, get_reader, PriorEmbeddingDataset
.\lucidrains\DALLE2-pytorch\dalle2_pytorch\optimizer.py
# Import the AdamW and Adam optimizers from torch.optim
from torch.optim import AdamW, Adam
# Separate parameters into those that should and should not receive weight decay
def separate_weight_decayable_params(params):
wd_params, no_wd_params = [], []
for param in params:
# Parameters with fewer than 2 dimensions (biases, norm scales) get no weight decay
param_list = no_wd_params if param.ndim < 2 else wd_params
param_list.append(param)
return wd_params, no_wd_params
# Build an optimizer
def get_optimizer(
params,
lr = 1e-4,
wd = 1e-2,
betas = (0.9, 0.99),
eps = 1e-8,
filter_by_requires_grad = False,
group_wd_params = True,
**kwargs
):
# Optionally keep only parameters that require gradients
if filter_by_requires_grad:
params = list(filter(lambda t: t.requires_grad, params))
# If weight decay is 0, plain Adam suffices
if wd == 0:
return Adam(params, lr = lr, betas = betas, eps = eps)
# Group parameters for weight decay if requested
if group_wd_params:
wd_params, no_wd_params = separate_weight_decayable_params(params)
# Two groups: with weight decay and without
params = [
{'params': wd_params},
{'params': no_wd_params, 'weight_decay': 0},
]
# Use AdamW with the given learning rate, weight decay, betas, and epsilon
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
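# For example, applied to a small module, matrix-shaped weights land in the decayed
# group while biases and norm parameters do not (a minimal sketch):
import torch.nn as nn

_model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
_opt = get_optimizer(_model.parameters(), lr = 1e-4, wd = 1e-2)
assert [len(g['params']) for g in _opt.param_groups] == [1, 3]  # the Linear weight decays; the three 1-d params do not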
.\lucidrains\DALLE2-pytorch\dalle2_pytorch\trackers.py
# Import required libraries
import urllib.request
import os
import json
from pathlib import Path
import shutil
from itertools import zip_longest
from typing import Any, Optional, List, Union
from pydantic import BaseModel
import torch
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.utils import import_or_print_error
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.version import __version__
from packaging import version
# Constants
DEFAULT_DATA_PATH = './.tracker-data'
# Helper functions
def exists(val):
return val is not None
# Base logger class
class BaseLogger:
"""
An abstract class representing an object that can log data.
Parameters:
data_path (str): A file path for storing temporary data.
verbose (bool): Whether or not to always print logs to the console.
"""
def __init__(self, data_path: str, resume: bool = False, auto_resume: bool = False, verbose: bool = False, **kwargs):
self.data_path = Path(data_path)
self.resume = resume
self.auto_resume = auto_resume
self.verbose = verbose
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
"""
Initializes the logger.
Errors if the logger is invalid.
full_config is the config file dict while extra_config is anything else from the script that is not defined in the config file.
"""
raise NotImplementedError
def log(self, log, **kwargs) -> None:
raise NotImplementedError
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
raise NotImplementedError
def log_file(self, file_path, **kwargs) -> None:
raise NotImplementedError
def log_error(self, error_string, **kwargs) -> None:
raise NotImplementedError
def get_resume_data(self, **kwargs) -> dict:
"""
Returns the tracker attributes that, along with { "resume": True }, will be used to resume training.
It is assumed that after init is called this data will be complete.
If the logger does not have any resume functionality, it should return an empty dict.
"""
raise NotImplementedError
# Console logger
class ConsoleLogger(BaseLogger):
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
print("Logging to console")
def log(self, log, **kwargs) -> None:
print(log)
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
pass
def log_file(self, file_path, **kwargs) -> None:
pass
def log_error(self, error_string, **kwargs) -> None:
print(error_string)
def get_resume_data(self, **kwargs) -> dict:
return {}
# Wandb logger
class WandbLogger(BaseLogger):
"""
Logs to a wandb run.
Parameters:
data_path (str): A file path for storing temporary data.
wandb_entity (str): The wandb entity to log to.
wandb_project (str): The wandb project to log to.
wandb_run_id (str): The wandb run id to resume.
wandb_run_name (str): The wandb run name to use.
"""
def __init__(self,
data_path: str,
wandb_entity: str,
wandb_project: str,
wandb_run_id: Optional[str] = None,
wandb_run_name: Optional[str] = None,
**kwargs
):
super().__init__(data_path, **kwargs)
self.entity = wandb_entity
self.project = wandb_project
self.run_id = wandb_run_id
self.run_name = wandb_run_name
# Initialize the wandb run from the full config plus any extra config
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
# wandb_entity must be specified to use the wandb logger
assert self.entity is not None, "wandb_entity must be specified for wandb logger"
# wandb_project must be specified to use the wandb logger
assert self.project is not None, "wandb_project must be specified for wandb logger"
# Import wandb or print an install hint
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
# Silence wandb console output
os.environ["WANDB_SILENT"] = "true"
# Arguments for wandb.init
init_object = {
"entity": self.entity,
"project": self.project,
"config": {**full_config.dict(), **extra_config}
}
# Use the run name if one was given
if self.run_name is not None:
init_object['name'] = self.run_name
# If resuming, a run id is required
if self.resume:
assert self.run_id is not None, '`wandb_run_id` must be provided if `wandb_resume` is True'
if self.run_name is not None:
print("You are renaming a run. I hope that is what you intended.")
init_object['resume'] = 'must'
init_object['id'] = self.run_id
# Start the wandb run
self.wandb.init(**init_object)
print(f"Logging to wandb run {self.wandb.run.path}-{self.wandb.run.name}")
# Log a dict of values to wandb
def log(self, log, **kwargs) -> None:
# Print the log if verbose
if self.verbose:
print(log)
# Send the log to wandb
self.wandb.log(log, **kwargs)
# Log images with optional captions
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
"""
Takes a tensor of images and a list of captions and logs them to wandb.
"""
# Build wandb Image objects, pairing images with captions
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
# Log the images under the given section
self.wandb.log({ image_section: wandb_images }, **kwargs)
# Log a file
def log_file(self, file_path, base_path: Optional[str] = None, **kwargs) -> None:
# Default the base path to the file's parent directory
if base_path is None:
base_path = Path(file_path).parent
# Save the file to wandb
self.wandb.save(str(file_path), base_path = str(base_path))
# Log an error string
def log_error(self, error_string, step=None, **kwargs) -> None:
# Print the error if verbose
if self.verbose:
print(error_string)
# Send the error to wandb
self.wandb.log({"error": error_string, **kwargs}, step=step)
# Return the data needed to resume this run later
def get_resume_data(self, **kwargs) -> dict:
# Resuming requires the wandb_entity, wandb_project, and wandb_run_id
return {
"entity": self.entity,
"project": self.project,
"run_id": self.wandb.run.id
}
# Map logger type names to logger classes
logger_type_map = {
'console': ConsoleLogger,
'wandb': WandbLogger,
}
# Instantiate a logger of the given type
def create_logger(logger_type: str, data_path: str, **kwargs) -> BaseLogger:
# Custom loggers are not supported yet
if logger_type == 'custom':
raise NotImplementedError('Custom loggers are not supported yet. Please use a different logger type.')
try:
# Look up the logger class for the given type
logger_class = logger_type_map[logger_type]
except KeyError:
# Unknown logger type
raise ValueError(f'Unknown logger type: {logger_type}. Must be one of {list(logger_type_map.keys())}')
# Return an instance of the logger class
return logger_class(data_path, **kwargs)
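# Usage is a one-liner; for example, a console logger (a minimal sketch using the default data path):
_logger = create_logger('console', DEFAULT_DATA_PATH)
# 'wandb' would additionally require wandb_entity and wandb_project kwargs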
# Abstract base class for objects that can load a model checkpoint
class BaseLoader:
"""
An abstract class representing an object that can load a model checkpoint.
Parameters:
data_path (str): A file path for storing temporary data.
"""
def __init__(self, data_path: str, only_auto_resume: bool = False, **kwargs):
self.data_path = Path(data_path)
self.only_auto_resume = only_auto_resume
def init(self, logger: BaseLogger, **kwargs) -> None:
raise NotImplementedError
def recall(self) -> dict:
raise NotImplementedError
# Loader that downloads a checkpoint file from a url
class UrlLoader(BaseLoader):
"""
A loader that downloads the file from a url and loads it
Parameters:
data_path (str): A file path for storing temporary data.
url (str): The url to download the file from.
"""
def __init__(self, data_path: str, url: str, **kwargs):
super().__init__(data_path, **kwargs)
self.url = url
def init(self, logger: BaseLogger, **kwargs) -> None:
# Make sure the file to download exists
pass # TODO: Actually implement that
def recall(self) -> dict:
# Download the file
save_path = self.data_path / 'loaded_checkpoint.pth'
urllib.request.urlretrieve(self.url, str(save_path))
# Load the downloaded checkpoint
return torch.load(str(save_path), map_location='cpu')
# Loader that loads a checkpoint from a local path
class LocalLoader(BaseLoader):
"""
A loader that loads a file from a local path
Parameters:
data_path (str): A file path for storing temporary data.
file_path (str): The path to the file to load.
"""
def __init__(self, data_path: str, file_path: str, **kwargs):
super().__init__(data_path, **kwargs)
self.file_path = Path(file_path)
def init(self, logger: BaseLogger, **kwargs) -> None:
# Make sure the file to load exists
if not self.file_path.exists() and not self.only_auto_resume:
raise FileNotFoundError(f'Model not found at {self.file_path}')
def recall(self) -> dict:
# Load the file
return torch.load(str(self.file_path), map_location='cpu')
# Loader that restores a model file from an existing wandb run
class WandbLoader(BaseLoader):
"""
A loader that loads a model from an existing wandb run
"""
def __init__(self, data_path: str, wandb_file_path: str, wandb_run_path: Optional[str] = None, **kwargs):
super().__init__(data_path, **kwargs)
self.run_path = wandb_run_path
self.file_path = wandb_file_path
def init(self, logger: BaseLogger, **kwargs) -> None:
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
# Make sure the file can be downloaded
if self.wandb.run is not None and self.run_path is None:
self.run_path = self.wandb.run.path
assert self.run_path is not None, 'wandb run was not found to load from. If not using the wandb logger must specify the `wandb_run_path`.'
assert self.run_path is not None, '`wandb_run_path` must be provided for the wandb loader'
assert self.file_path is not None, '`wandb_file_path` must be provided for the wandb loader'
os.environ["WANDB_SILENT"] = "true"
pass # TODO: Actually implement that
def recall(self) -> dict:
file_reference = self.wandb.restore(self.file_path, run_path=self.run_path)
return torch.load(file_reference.name, map_location='cpu')
# Map loader type names to loader classes
loader_type_map = {
'url': UrlLoader,
'local': LocalLoader,
'wandb': WandbLoader,
}
# Build a loader of the given type for the given data path
def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
# Custom loaders are not supported yet
if loader_type == 'custom':
raise NotImplementedError('Custom loaders are not supported yet. Please use a different loader type.')
# Look up the loader class for the given type
try:
loader_class = loader_type_map[loader_type]
except KeyError:
# Unknown loader type
raise ValueError(f'Unknown loader type: {loader_type}. Must be one of {list(loader_type_map.keys())}')
# Return a loader initialized with the data path and remaining kwargs
return loader_class(data_path, **kwargs)
# Base saver class
class BaseSaver:
# Constructor
def __init__(self,
data_path: str,
save_latest_to: Optional[Union[str, bool]] = None,
save_best_to: Optional[Union[str, bool]] = None,
save_meta_to: Optional[str] = None,
save_type: str = 'checkpoint',
**kwargs
):
# Saver configuration
self.data_path = Path(data_path)
self.save_latest_to = save_latest_to
self.saving_latest = save_latest_to is not None and save_latest_to is not False
self.save_best_to = save_best_to
self.saving_best = save_best_to is not None and save_best_to is not False
self.save_meta_to = save_meta_to
self.saving_meta = save_meta_to is not None
self.save_type = save_type
# save_type must be either 'checkpoint' or 'model'
assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
# At least one saving option must be specified
assert self.saving_latest or self.saving_best or self.saving_meta, 'At least one saving option must be specified'
# Subclasses must implement init
def init(self, logger: BaseLogger, **kwargs) -> None:
raise NotImplementedError
# Subclasses must implement save_file
def save_file(self, local_path: Path, save_path: str, is_best=False, is_latest=False, **kwargs) -> None:
"""
Save a general file under save_meta_to
"""
raise NotImplementedError
# Saver that writes files to the local filesystem
class LocalSaver(BaseSaver):
# Constructor
def __init__(self,
data_path: str,
**kwargs
):
# Delegate to the base saver
super().__init__(data_path, **kwargs)
# Make sure the directory we are saving to exists
def init(self, logger: BaseLogger, **kwargs) -> None:
print(f"Saving {self.save_type} locally")
# Create the data path if it does not exist
if not self.data_path.exists():
self.data_path.mkdir(parents=True)
# Copy the file to the given save path
def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
# File name of the save path
save_path_file_name = Path(save_path).name
# Make sure the parent directory exists
save_path_parent = Path(save_path).parent
if not save_path_parent.exists():
save_path_parent.mkdir(parents=True)
print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
# Copy the file to the save path
shutil.copy(local_path, save_path)
# Saver that uploads files to a wandb run
class WandbSaver(BaseSaver):
# Constructor
def __init__(self, data_path: str, wandb_run_path: Optional[str] = None, **kwargs):
super().__init__(data_path, **kwargs)
self.run_path = wandb_run_path
# Initialize wandb and make sure the user can upload to this run
def init(self, logger: BaseLogger, **kwargs) -> None:
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
os.environ["WANDB_SILENT"] = "true"
# Make sure the user can upload to this run
if self.run_path is not None:
entity, project, run_id = self.run_path.split("/")
self.run = self.wandb.init(entity=entity, project=project, id=run_id)
else:
assert self.wandb.run is not None, 'You must be using the wandb logger if you are saving to wandb and have not set `wandb_run_path`'
self.run = self.wandb.run
# TODO: Actually check that the upload is possible
print(f"Saving to wandb run {self.run.path}-{self.run.name}")
# Save the file to the given path and mirror the same file structure in wandb
def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
# File name of the save path
save_path_file_name = Path(save_path).name
# Print what is being saved: the file name, save type, and the wandb run's path and name
print(f"Saving {save_path_file_name} {self.save_type} to wandb run {self.run.path}-{self.run.name}")
# The save path is the data path joined with the requested save path
save_path = Path(self.data_path) / save_path
# Create the save path's parent directory if it does not exist
save_path.parent.mkdir(parents=True, exist_ok=True)
# Copy the local file to the save path
shutil.copy(local_path, save_path)
# Save the file in wandb with the data path as the base path and an immediate save policy
self.run.save(str(save_path), base_path = str(self.data_path), policy='now')
# Saver that uploads checkpoints or models to a Huggingface repo
class HuggingfaceSaver(BaseSaver):
# Takes the data path, a Huggingface repo, and optionally a path to a token file
def __init__(self, data_path: str, huggingface_repo: str, token_path: Optional[str] = None, **kwargs):
super().__init__(data_path, **kwargs)
self.huggingface_repo = huggingface_repo
self.token_path = token_path
def init(self, logger: BaseLogger, **kwargs):
# Make sure the user can upload to the repo
self.hub = import_or_print_error('huggingface_hub', '`pip install huggingface_hub` to use the huggingface saver')
try:
identity = self.hub.whoami() # Errors if not logged in
# Then we are logged in
except:
# If not logged in, set the token using token_path
if not os.path.exists(self.token_path):
raise Exception("Not logged in to huggingface and no token_path specified. Please login with `huggingface-cli login` or if that does not work set the token_path.")
with open(self.token_path, "r") as f:
token = f.read().strip()
self.hub.HfApi.set_access_token(token)
identity = self.hub.whoami()
print(f"Saving to huggingface repo {self.huggingface_repo}")
def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
# Saving to Huggingface is easy: upload the file under the right name in the repo
save_path_file_name = Path(save_path).name
print(f"Saving {save_path_file_name} {self.save_type} to huggingface repo {self.huggingface_repo}")
# Upload the file to the Huggingface repo
self.hub.upload_file(
path_or_fileobj=str(local_path),
path_in_repo=str(save_path),
repo_id=self.huggingface_repo
)
saver_type_map = {
'local': LocalSaver,
'wandb': WandbSaver,
'huggingface': HuggingfaceSaver
}
# Build a saver of the given type from the map above
def create_saver(saver_type: str, data_path: str, **kwargs) -> BaseSaver:
# Custom savers are not supported yet
if saver_type == 'custom':
raise NotImplementedError('Custom savers are not supported yet. Please use a different saver type.')
# Look up the saver class for the given type
try:
saver_class = saver_type_map[saver_type]
except KeyError:
raise ValueError(f'Unknown saver type: {saver_type}. Must be one of {list(saver_type_map.keys())}')
# Return the created saver
return saver_class(data_path, **kwargs)
class Tracker:
# Tracker ties together a logger, an optional loader, and a list of savers
def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
# Store the data path
self.data_path = Path(data_path)
if not dummy_mode:
# Unless overwriting is allowed, the data path must not already exist
if not overwrite_data_path:
assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
# Create the data path if it does not exist
if not self.data_path.exists():
self.data_path.mkdir(parents=True)
# The logger, loader, and savers are attached later via add_logger/add_loader/add_saver
self.logger: BaseLogger = None
self.loader: Optional[BaseLoader] = None
self.savers: List[BaseSaver]= []
self.dummy_mode = dummy_mode
def _load_auto_resume(self) -> bool:
# Load the auto-resume data
# If the file does not exist, return False. If auto_resume is enabled, print a warning so the user knows this is the first run.
if not self.auto_resume_path.exists():
if self.logger.auto_resume:
print("Auto_resume is enabled but no auto_resume.json file exists. Assuming this is the first run.")
return False
# Now we know the auto-resume file exists, but if we are not auto-resuming we should remove it so it is not accidentally loaded next time
if not self.logger.auto_resume:
print(f'Removing auto_resume.json because auto_resume is not enabled in the config')
self.auto_resume_path.unlink()
return False
# Otherwise, read the JSON into a dict that will override parts of logger.__dict__
with open(self.auto_resume_path, 'r') as f:
auto_resume_dict = json.load(f)
# Check that the logger is the same type the auto-resume data was saved with
if auto_resume_dict["logger_type"] != self.logger.__class__.__name__:
raise Exception(f'The logger type in the auto_resume file is {auto_resume_dict["logger_type"]} but the current logger is {self.logger.__class__.__name__}. Either use the original logger type, set `auto_resume` to `False`, or delete your existing tracker-data folder.')
# Now we are ready to overwrite the logger with the auto-resume data
self.logger.__dict__["resume"] = True
print(f"Updating {self.logger.__dict__} with {auto_resume_dict}")
self.logger.__dict__.update(auto_resume_dict)
return True
def _save_auto_resume(self):
# Get the auto-resume dict from the logger, add "logger_type" to it, then save it to the auto_resume file
auto_resume_dict = self.logger.get_resume_data()
auto_resume_dict['logger_type'] = self.logger.__class__.__name__
with open(self.auto_resume_path, 'w') as f:
json.dump(auto_resume_dict, f)
def init(self, full_config: BaseModel, extra_config: dict):
self.auto_resume_path = self.data_path / 'auto_resume.json'
# Check whether this run is being auto-resumed
self.did_auto_resume = self._load_auto_resume()
if self.did_auto_resume:
print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
print(f"New logger config: {self.logger.__dict__}")
self.save_metadata = dict(
version = version.parse(__version__)
) # Data that will be saved alongside the checkpoint or model
self.blacklisted_checkpoint_metadata_keys = ['scaler', 'optimizer', 'model', 'version', 'step', 'steps'] # These keys would cause errors if we tried to save them as metadata
assert self.logger is not None, '`logger` must be set before `init` is called'
if self.dummy_mode:
# The only thing we need is a loader
if self.loader is not None:
self.loader.init(self.logger)
return
assert len(self.savers) > 0, '`savers` must be set before `init` is called'
self.logger.init(full_config, extra_config)
if self.loader is not None:
self.loader.init(self.logger)
for saver in self.savers:
saver.init(self.logger)
if self.logger.auto_resume:
# Then we need to save the auto-resume file. It is assumed that the logger is ready to save once logger.init has been called.
self._save_auto_resume()
def add_logger(self, logger: BaseLogger):
self.logger = logger
def add_loader(self, loader: BaseLoader):
self.loader = loader
def add_saver(self, saver: BaseSaver):
self.savers.append(saver)
# Log values; a no-op in dummy mode
def log(self, *args, **kwargs):
if self.dummy_mode:
return
# Delegate to the logger's log method
self.logger.log(*args, **kwargs)
# Log images; a no-op in dummy mode
def log_images(self, *args, **kwargs):
if self.dummy_mode:
return
# Delegate to the logger's log_images method
self.logger.log_images(*args, **kwargs)
# Log a file; a no-op in dummy mode
def log_file(self, *args, **kwargs):
if self.dummy_mode:
return
# Delegate to the logger's log_file method
self.logger.log_file(*args, **kwargs)
# Save the config file; a no-op in dummy mode
def save_config(self, current_config_path: str, config_name = 'config.json'):
if self.dummy_mode:
return
# Copy the current config file to config_name in the root of data_path
shutil.copy(current_config_path, self.data_path / config_name)
# For every saver that saves metadata, also save the current config file to its meta path
for saver in self.savers:
if saver.saving_meta:
remote_path = Path(saver.save_meta_to) / config_name
saver.save_file(current_config_path, str(remote_path))
# Add metadata that will be saved along with the model or decoder
def add_save_metadata(self, state_dict_key: str, metadata: Any):
"""
Adds a new piece of metadata that will be saved along with the model or decoder.
"""
# Add the metadata to the save_metadata dict
self.save_metadata[state_dict_key] = metadata
# Build the state dict for the given save type and write it to file_path
def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:
"""
Gets the state dict to be saved and writes it to file_path.
If save_type is 'checkpoint', we save the entire trainer state dict.
If save_type is 'model', we save only the model state dict.
"""
assert save_type in ['checkpoint', 'model']
if save_type == 'checkpoint':
# Create a metadata dict without the blacklisted keys so we do not error when creating the state dict
metadata = {k: v for k, v in self.save_metadata.items() if k not in self.blacklisted_checkpoint_metadata_keys}
# Save the entire trainer state dict
trainer.save(file_path, overwrite=True, **kwargs, **metadata)
elif save_type == 'model':
if isinstance(trainer, DiffusionPriorTrainer):
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
prior: DiffusionPrior = trainer.accelerator.unwrap_model(prior)
# If CLIP is attached to the model, detach it before building the state dict
original_clip = prior.clip
prior.clip = None
model_state_dict = prior.state_dict()
prior.clip = original_clip
elif isinstance(trainer, DecoderTrainer):
decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
# If CLIP is attached to the model, detach it before building the state dict
original_clip = decoder.clip
decoder.clip = None
if trainer.use_ema:
trainable_unets = decoder.unets
decoder.unets = trainer.unets # Swap in the EMA unets
model_state_dict = decoder.state_dict()
decoder.unets = trainable_unets # Restore the original unets
else:
model_state_dict = decoder.state_dict()
decoder.clip = original_clip
else:
raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
# Build the state dict from save_metadata plus the model's state_dict
state_dict = {
**self.save_metadata,
'model': model_state_dict
}
# Write the state dict to file_path
torch.save(state_dict, file_path)
return Path(file_path)
# Save the trainer state and model
def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):
# A no-op in dummy mode
if self.dummy_mode:
return
# If this is neither the best nor the latest model, there is nothing to save
if not is_best and not is_latest:
# Nothing to do
return
# Cache the checkpoint and model under data_path
checkpoint_path = self.data_path / 'checkpoint.pth'
self._save_state_dict(trainer, 'checkpoint', checkpoint_path, **kwargs)
model_path = self.data_path / 'model.pth'
self._save_state_dict(trainer, 'model', model_path, **kwargs)
print("Saved cached models")
# Hand the cached files to each saver
for saver in self.savers:
local_path = checkpoint_path if saver.save_type == 'checkpoint' else model_path
# If this saver tracks the latest model and this is the latest, save it
if saver.saving_latest and is_latest:
latest_checkpoint_path = saver.save_latest_to.format(**kwargs)
try:
saver.save_file(local_path, latest_checkpoint_path, is_latest=True, **kwargs)
except Exception as e:
self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
print(f'Error saving checkpoint: {e}')
# If this saver tracks the best model and this is the best, save it
if saver.saving_best and is_best:
best_checkpoint_path = saver.save_best_to.format(**kwargs)
try:
saver.save_file(local_path, best_checkpoint_path, is_best=True, **kwargs)
except Exception as e:
self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
print(f'Error saving checkpoint: {e}')
@property
# Whether a checkpoint can be recalled
def can_recall(self):
return self.loader is not None and (not self.loader.only_auto_resume or self.did_auto_resume)
# Recall (load) the checkpoint through the loader
def recall(self):
if self.can_recall:
return self.loader.recall()
else:
raise ValueError('Tried to recall, but no loader was set or auto-resume was not performed.')
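# Putting the pieces together: a Tracker is assembled from a logger, an optional loader,
# and one or more savers before init is called. A minimal local-only sketch (hypothetical
# paths; an empty pydantic BaseModel stands in for the real training config):
#
#   tracker = Tracker(data_path='./.tracker-data', overwrite_data_path=True)
#   tracker.add_logger(create_logger('console', str(tracker.data_path)))
#   tracker.add_saver(create_saver('local', str(tracker.data_path), save_latest_to='checkpoints/latest.pth'))
#   tracker.init(full_config=BaseModel(), extra_config={})
#   # inside the training loop:
#   # tracker.save(trainer, is_best=False, is_latest=True)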
.\lucidrains\DALLE2-pytorch\dalle2_pytorch\trainer.py
# Import required libraries
import time
import copy
from pathlib import Path
from math import ceil
from functools import partial, wraps
from contextlib import nullcontext
from collections.abc import Iterable
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
from torch.cuda.amp import autocast, GradScaler
# Project modules
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.version import __version__
from packaging import version
# Third-party libraries
import pytorch_warmup as warmup
from ema_pytorch import EMA
from accelerate import Accelerator, DistributedType
import numpy as np
# Helper functions
# Check whether a value exists
def exists(val):
return val is not None
# Return the value if it exists, otherwise a default (calling it if callable)
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
# Cast a value to a tuple of the given length
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
# Pop the given keys from a dict and return them as a new dict
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
# Split a dict in two based on a predicate over its keys
def group_dict_by_key(cond, d):
return_val = [dict(),dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
# Check whether a string begins with the given prefix
def string_begins_with(prefix, str):
return str.startswith(prefix)
# Group dict entries by whether their key begins with the prefix
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
# Group dict entries by prefix and strip the prefix from the matching keys
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# Split a number into groups of at most `divisor`, plus a remainder group
def num_to_groups(num, divisor):
groups = num // divisor
remainder = num % divisor
arr = [divisor] * groups
if remainder > 0:
arr.append(remainder)
return arr
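# For instance, an unconditional sample request of 25 with a max batch size of 10
# becomes three sub-batches:
assert num_to_groups(25, 10) == [10, 10, 5]
assert num_to_groups(20, 10) == [10, 10]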
# Decorators
# Cast numpy array arguments to torch tensors, moving them to the model's device (and deepspeed precision if applicable)
def cast_torch_tensor(fn):
@wraps(fn)
def inner(model, *args, **kwargs):
device = kwargs.pop('_device', next(model.parameters()).device)
cast_device = kwargs.pop('_cast_device', True)
cast_deepspeed_precision = kwargs.pop('_cast_deepspeed_precision', True)
kwargs_keys = kwargs.keys()
all_args = (*args, *kwargs.values())
split_kwargs_index = len(all_args) - len(kwargs_keys)
all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
if cast_device:
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
if cast_deepspeed_precision:
try:
accelerator = model.accelerator
if accelerator is not None and accelerator.distributed_type == DistributedType.DEEPSPEED:
cast_type_map = {
"fp16": torch.half,
"bf16": torch.bfloat16,
"no": torch.float
}
precision_type = cast_type_map[accelerator.mixed_precision]
all_args = tuple(map(lambda t: t.to(precision_type) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
except AttributeError:
# Then this model doesn't have an accelerator
pass
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
out = fn(model, *args, **kwargs)
return out
return inner
# Gradient accumulation helpers
# Split an iterable into chunks of the given size
def split_iterable(it, split_size):
accum = []
for ind in range(ceil(len(it) / split_size)):
start_index = ind * split_size
accum.append(it[start_index: (start_index + split_size)])
return accum
# If no split size is given, return the input unchanged
def split(t, split_size = None):
if not exists(split_size):
return t
# If the input is a torch.Tensor, split it along the batch dimension
if isinstance(t, torch.Tensor):
return t.split(split_size, dim=0)
# If the input is an iterable, split it with split_iterable()
if isinstance(t, Iterable):
return split_iterable(t, split_size)
# Anything else is an unsupported type
raise TypeError(f'unsupported type {type(t)} for split')
# Return the first element of arr that satisfies cond, or None
def find_first(cond, arr):
for el in arr:
if cond(el):
return el
return None
# Split positional and keyword arguments into chunks along the batch dimension
def split_args_and_kwargs(*args, split_size = None, **kwargs):
# Combine all argument values into one tuple
all_args = (*args, *kwargs.values())
len_all_args = len(all_args)
# Find the first argument that is a torch.Tensor
first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
# There must be at least one tensor argument
assert exists(first_tensor)
# Its length gives the batch size
batch_size = len(first_tensor)
# Default the split size to the batch size
split_size = default(split_size, batch_size)
# Number of chunks after splitting
num_chunks = ceil(batch_size / split_size)
# Record the number and names of the keyword arguments
dict_len = len(kwargs)
dict_keys = kwargs.keys()
# Index where the keyword argument values start within the combined tuple
split_kwargs_index = len_all_args - dict_len
# Split each argument if it is a tensor or iterable, otherwise replicate it across chunks
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
# Size of each chunk
chunk_sizes = tuple(map(len, split_all_args[0]))
# For each chunk, rebuild the args and kwargs and yield them along with the chunk's fraction of the batch
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
chunk_size_frac = chunk_size / batch_size
yield chunk_size_frac, (chunked_args, chunked_kwargs)
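# Small illustration of the chunking behavior (hypothetical tensor): a batch of 10
# with split_size 4 yields chunks of 4, 4, and 2 rows with fractions 0.4, 0.4, 0.2:
_x = torch.randn(10, 4)
assert [args[0].shape[0] for _, (args, _kw) in split_args_and_kwargs(_x, split_size = 4)] == [4, 4, 2]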
# Diffusion prior trainer
# Decorator that runs the wrapped sampling function in chunks of max_batch_size
def prior_sample_in_chunks(fn):
@wraps(fn)
def inner(self, *args, max_batch_size = None, **kwargs):
# If no max batch size is given, call the function directly
if not exists(max_batch_size):
return fn(self, *args, **kwargs)
# Split the arguments, call the function per chunk, and concatenate the results
outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
return torch.cat(outputs, dim = 0)
return inner
# Diffusion prior trainer class
class DiffusionPriorTrainer(nn.Module):
def __init__(
self,
diffusion_prior,
accelerator = None,
use_ema = True,
lr = 3e-4,
wd = 1e-2,
eps = 1e-6,
max_grad_norm = None,
group_wd_params = True,
warmup_steps = None,
cosine_decay_max_steps = None,
**kwargs
# Constructor: set up the model, optimizer, schedulers, and EMA
):
# Call the parent class initializer
super().__init__()
# diffusion_prior must be a DiffusionPrior instance
assert isinstance(diffusion_prior, DiffusionPrior)
# Split off the 'ema_'-prefixed kwargs, stripping the prefix
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
# Split off the 'accelerator_'-prefixed kwargs, stripping the prefix
accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs)
# Create an Accelerator if one was not passed in
if not exists(accelerator):
accelerator = Accelerator(**accelerator_kwargs)
# Store some useful members
self.accelerator = accelerator
self.text_conditioned = diffusion_prior.condition_on_text_encodings
# Device setup
self.device = accelerator.device
diffusion_prior.to(self.device)
# Keep a reference to the model
self.diffusion_prior = diffusion_prior
# Mixed precision checks
if (
exists(self.accelerator)
and self.accelerator.distributed_type == DistributedType.DEEPSPEED
and self.diffusion_prior.clip is not None
):
# Make sure clip is using the correct precision, or else deepspeed will error
cast_type_map = {
"fp16": torch.half,
"bf16": torch.bfloat16,
"no": torch.float
}
precision_type = cast_type_map[accelerator.mixed_precision]
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
self.diffusion_prior.clip.to(precision_type)
# Optimizer settings
self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
# Build the optimizer
self.optimizer = get_optimizer(
self.diffusion_prior.parameters(),
**self.optim_kwargs,
**kwargs
)
# Use a CosineAnnealingLR scheduler if cosine_decay_max_steps is given, otherwise a constant LambdaLR
if exists(cosine_decay_max_steps):
self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps)
else:
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
# Use a LinearWarmup scheduler if warmup_steps is given
self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
# Distribute the model, optimizer, and scheduler with HuggingFace Accelerate
self.diffusion_prior, self.optimizer, self.scheduler = self.accelerator.prepare(self.diffusion_prior, self.optimizer, self.scheduler)
# Exponential moving average setup
self.use_ema = use_ema
if self.use_ema:
self.ema_diffusion_prior = EMA(self.accelerator.unwrap_model(self.diffusion_prior), **ema_kwargs)
# Optional gradient clipping
self.max_grad_norm = max_grad_norm
# Track the step count internally
self.register_buffer('step', torch.tensor([0], device = self.device))
# Utility functions
def save(self, path, overwrite = True, **kwargs):
# Save only on the main process
if self.accelerator.is_main_process:
print(f"Saving checkpoint at step: {self.step.item()}")
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
# FIXME: LambdaLR cannot be saved due to pickling issues
save_obj = dict(
optimizer = self.optimizer.state_dict(),
scheduler = self.scheduler.state_dict(),
warmup_scheduler = self.warmup_scheduler,
model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
version = version.parse(__version__),
step = self.step,
**kwargs
)
# If using EMA, save its state as well
if self.use_ema:
save_obj = {
**save_obj,
'ema': self.ema_diffusion_prior.state_dict(),
'ema_model': self.ema_diffusion_prior.ema_model.state_dict() # save the ema model separately for convenience
}
# Write the checkpoint to disk
torch.save(save_obj, str(path))
def load(self, path_or_state, overwrite_lr = True, strict = True):
"""
Load a checkpoint of a diffusion prior trainer.
Will load the entire trainer, including the optimizer and EMA.
Params:
- path_or_state (str | torch): a path to the DiffusionPriorTrainer checkpoint file
- overwrite_lr (bool): whether or not to overwrite the stored LR with the LR specified in the new trainer
- strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
Returns:
loaded_obj (dict): The loaded checkpoint dictionary
"""
# all processes need to load checkpoint. no restriction here
if isinstance(path_or_state, str):
path = Path(path_or_state)
assert path.exists()
loaded_obj = torch.load(str(path), map_location=self.device)
elif isinstance(path_or_state, dict):
loaded_obj = path_or_state
if version.parse(__version__) != loaded_obj['version']:
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
# unwrap the model when loading from checkpoint
self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
self.optimizer.load_state_dict(loaded_obj['optimizer'])
self.scheduler.load_state_dict(loaded_obj['scheduler'])
# set warmup step
if exists(self.warmup_scheduler):
self.warmup_scheduler.last_step = self.step.item()
# ensure new lr is used if different from old one
if overwrite_lr:
new_lr = self.optim_kwargs["lr"]
for group in self.optimizer.param_groups:
group["lr"] = new_lr if group["lr"] > 0.0 else 0.0
if self.use_ema:
assert 'ema' in loaded_obj
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
# below might not be necessary, but I had a suspicion that this wasn't being loaded correctly
self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
return loaded_obj
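# Hedged checkpointing sketch (the path is made up): save and load round-trip
# the model weights, optimizer, scheduler, EMA state and step counter.
#
#   trainer.save('./checkpoints/prior.pt')
#   loaded = trainer.load('./checkpoints/prior.pt', overwrite_lr = False)
#   assert loaded['step'].item() == trainer.step.item()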
# model functionality
def update(self):
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
self.optimizer.step()
self.optimizer.zero_grad()
# the accelerator will occasionally skip optimizer steps under its dynamic loss scaling strategy
if not self.accelerator.optimizer_step_was_skipped:
sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
with sched_context():
self.scheduler.step()
if self.use_ema:
self.ema_diffusion_prior.update()
self.step += 1
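# Hedged training-loop sketch (`dataloader` and the embedding tensors are
# assumptions): forward accumulates the chunked loss and backpropagates, then
# update clips gradients, steps the optimizer / schedulers and refreshes the EMA.
#
#   for text_embed, image_embed in dataloader:
#       loss = trainer(text_embed = text_embed, image_embed = image_embed, max_batch_size = 256)
#       trainer.update()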
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def p_sample_loop(self, *args, **kwargs):
model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
return model.p_sample_loop(*args, **kwargs)
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def sample(self, *args, **kwargs):
model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
return model.sample(*args, **kwargs)
@torch.no_grad()
def sample_batch_size(self, *args, **kwargs):
model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
return model.sample_batch_size(*args, **kwargs)
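# Hedged sampling sketch (`text` is an assumed batch of tokenized text):
# prior_sample_in_chunks splits the call into sub-batches of at most
# max_batch_size, and the EMA model is used whenever use_ema is enabled.
#
#   image_embeds = trainer.sample(text, max_batch_size = 128)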
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
# unwrap the diffusion prior from the accelerator, then call embed_text on its clip model
def embed_text(self, *args, **kwargs):
return self.accelerator.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
# the decorator below casts the incoming arguments to torch tensors
@cast_torch_tensor
def forward(
self,
*args,
max_batch_size = None,
**kwargs
):
# start the total loss at zero
total_loss = 0.
# split the args and kwargs into chunks of at most max_batch_size and iterate over them
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
# run under the accelerator's autocast context for mixed precision
with self.accelerator.autocast():
# compute the loss on this chunk with the diffusion prior
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
# weight the chunk loss by the fraction of the batch it covers
loss = loss * chunk_size_frac
# accumulate into the total loss
total_loss += loss.item()
# if training, backpropagate through the accelerator
if self.training:
self.accelerator.backward(loss)
# return the accumulated loss
return total_loss
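# Worked example of the chunk weighting above (numbers are illustrative): with a
# batch of 96 and max_batch_size = 64, split_args_and_kwargs yields fractions
# 64/96 and 32/96, so
#
#   total_loss = loss_chunk_1 * 64/96 + loss_chunk_2 * 32/96
#
# which recovers the mean loss over the full batch, since each chunk loss is itself a mean.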
# decoder trainer
# decorator that splits sampling calls into sub-batches of at most max_batch_size
def decoder_sample_in_chunks(fn):
@wraps(fn)
def inner(self, *args, max_batch_size = None, **kwargs):
# if no max batch size is given, just call the wrapped function
if not exists(max_batch_size):
return fn(self, *args, **kwargs)
# for an unconditional decoder, split the requested batch_size into sub-batches
if self.decoder.unconditional:
batch_size = kwargs.get('batch_size')
batch_sizes = num_to_groups(batch_size, max_batch_size)
outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
else:
# for a conditional decoder, chunk the input tensors instead
outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
# concatenate the outputs of all sub-batches / chunks
return torch.cat(outputs, dim = 0)
return inner
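# Illustrative note on the two splitting modes above: the unconditional branch
# splits the requested sample count with num_to_groups, e.g.
#
#   num_to_groups(10, 4)  # -> [4, 4, 2], three sub-batches totalling 10 samples
#
# while the conditional branch chunks the input tensors themselves via split_args_and_kwargs.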
# the decoder trainer class
class DecoderTrainer(nn.Module):
def __init__(
self,
decoder,
accelerator = None,
dataloaders = None,
use_ema = True,
lr = 1e-4,
wd = 1e-2,
eps = 1e-8,
warmup_steps = None,
cosine_decay_max_steps = None,
max_grad_norm = 0.5,
amp = False,
group_wd_params = True,
**kwargs
):
# initialize the parent nn.Module
super().__init__()
# decoder must be a Decoder instance
assert isinstance(decoder, Decoder)
# split off kwargs prefixed with 'ema_' and strip the prefix, yielding two dicts
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
# set the accelerator, defaulting to a freshly constructed Accelerator
self.accelerator = default(accelerator, Accelerator)
# number of unets contained in the decoder
self.num_unets = len(decoder.unets)
# whether to keep exponential moving averages
self.use_ema = use_ema
# start with an empty ModuleList for the EMA unets
self.ema_unets = nn.ModuleList([])
# whether to use automatic mixed precision
self.amp = amp
# learning rate, weight decay, etc. can be customized per unet
# broadcast lr, wd, eps, warmup_steps and cosine_decay_max_steps into tuples of length num_unets
lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))
# sanity check that no unet learning rate exceeds 1e-2
assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
# collect one optimizer, scheduler and warmup scheduler per unet
optimizers = []
schedulers = []
warmup_schedulers = []
# iterate over the decoder's unets together with their per-unet hyperparameters
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
# a placeholder nn.Identity unet gets no optimizer or schedulers
if isinstance(unet, nn.Identity):
optimizers.append(None)
schedulers.append(None)
warmup_schedulers.append(None)
else:
# otherwise build an optimizer over the unet's parameters
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
wd = unet_wd,
eps = unet_eps,
group_wd_params = group_wd_params,
**kwargs
)
optimizers.append(optimizer)
# build the scheduler and optional warmup scheduler
if exists(unet_cosine_decay_max_steps):
scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
else:
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
warmup_schedulers.append(warmup_scheduler)
schedulers.append(scheduler)
# if using EMA, track this unet with an EMA wrapper
if self.use_ema:
self.ema_unets.append(EMA(unet, **ema_kwargs))
# gradient clipping, if needed
self.max_grad_norm = max_grad_norm
# register a 'steps' buffer holding one step counter per unet
self.register_buffer('steps', torch.tensor([0] * self.num_unets))
# if distributed with DeepSpeed and the decoder carries a clip model
if self.accelerator.distributed_type == DistributedType.DEEPSPEED and decoder.clip is not None:
# make sure clip is using the correct precision, or else deepspeed will error out
cast_type_map = {
"fp16": torch.half,
"bf16": torch.bfloat16,
"no": torch.float
}
precision_type = cast_type_map[self.accelerator.mixed_precision]
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
clip = decoder.clip
clip.to(precision_type)
# prepare the decoder and optimizers with the accelerator
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
self.decoder = decoder
# prepare the dataloaders
train_loader = val_loader = None
if exists(dataloaders):
train_loader, val_loader = self.accelerator.prepare(dataloaders["train"], dataloaders["val"])
self.train_loader = train_loader
self.val_loader = val_loader
# store the optimizers
for opt_ind, optimizer in enumerate(optimizers):
setattr(self, f'optim{opt_ind}', optimizer)
# store the schedulers
for sched_ind, scheduler in enumerate(schedulers):
setattr(self, f'sched{sched_ind}', scheduler)
# store the warmup schedulers
self.warmup_schedulers = warmup_schedulers
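# Hedged construction sketch (a `decoder` with two unets is an assumption):
# scalar hyperparameters are broadcast by cast_tuple, or can be given per unet.
#
#   trainer = DecoderTrainer(
#       decoder,
#       lr = (1e-4, 1e-4),          # one learning rate per unet
#       warmup_steps = (500, 500),
#       use_ema = True
#   )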
# validate the unet number, defaulting when there is only one unet
def validate_and_return_unet_number(self, unet_number = None):
# with a single unet, default unet_number to 1
if self.num_unets == 1:
unet_number = default(unet_number, 1)
# unet_number must exist and lie in [1, num_unets]
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
return unet_number
# return the number of steps taken for the given unet
def num_steps_taken(self, unet_number = None):
unet_number = self.validate_and_return_unet_number(unet_number)
return self.steps[unet_number - 1].item()
# save the trainer state to the given path
def save(self, path, overwrite = True, **kwargs):
# coerce to a Path object
path = Path(path)
# refuse to overwrite an existing checkpoint unless overwrite is set
assert not (path.exists() and not overwrite)
# create parent directories as needed
path.parent.mkdir(parents = True, exist_ok = True)
# build the dict of objects to save
save_obj = dict(
model = self.accelerator.unwrap_model(self.decoder).state_dict(),
version = __version__,
steps = self.steps.cpu(),
**kwargs
)
# collect optimizer and scheduler state for each unet
for ind in range(0, self.num_unets):
optimizer_key = f'optim{ind}'
scheduler_key = f'sched{ind}'
optimizer = getattr(self, optimizer_key)
scheduler = getattr(self, scheduler_key)
optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
# merge into the save dict
save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}
# if using EMA, include its state
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
# write the checkpoint to disk
self.accelerator.save(save_obj, str(path))
# load trainer state from an already loaded checkpoint dict
def load_state_dict(self, loaded_obj, only_model = False, strict = True):
# warn when the checkpoint version differs from the current package version
if version.parse(__version__) != version.parse(loaded_obj['version']):
self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
# load the model weights
self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
self.steps.copy_(loaded_obj['steps'])
# if only the model weights are wanted, return early
if only_model:
return loaded_obj
# restore optimizer, scheduler and warmup state for each unet
for ind, last_step in zip(range(0, self.num_unets), self.steps.tolist()):
optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
scheduler_key = f'sched{ind}'
scheduler = getattr(self, scheduler_key)
warmup_scheduler = self.warmup_schedulers[ind]
if exists(optimizer):
optimizer.load_state_dict(loaded_obj[optimizer_key])
if exists(scheduler):
scheduler.load_state_dict(loaded_obj[scheduler_key])
if exists(warmup_scheduler):
warmup_scheduler.last_step = last_step
# if using EMA, restore its state
if self.use_ema:
assert 'ema' in loaded_obj
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
# load trainer state from a checkpoint file on disk
def load(self, path, only_model = False, strict = True):
# coerce to a Path object
path = Path(path)
# the checkpoint must exist
assert path.exists()
# load the checkpoint onto cpu
loaded_obj = torch.load(str(path), map_location = 'cpu')
# hand off to load_state_dict
self.load_state_dict(loaded_obj, only_model = only_model, strict = strict)
return loaded_obj
# expose the EMA unets as a ModuleList
@property
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
# increment the step counter for a single unet
def increment_step(self, unet_number):
# unet_number must be within range
assert 1 <= unet_number <= self.num_unets
# index tensor for the one-hot increment
unet_index_tensor = torch.tensor(unet_number - 1, device = self.steps.device)
# add one to the chosen unet's counter via a one-hot vector
self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))
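# Worked example of the one-hot increment above: with three unets and
# steps = [5, 2, 7], increment_step(2) builds F.one_hot(tensor(1), 3) = [0, 1, 0],
# updating steps to [5, 3, 7].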
# take an optimizer step for one unet
def update(self, unet_number = None):
# validate and resolve the unet number
unet_number = self.validate_and_return_unet_number(unet_number)
index = unet_number - 1
# fetch the matching optimizer and scheduler
optimizer = getattr(self, f'optim{index}')
scheduler = getattr(self, f'sched{index}')
# clip gradients over the decoder parameters if max_grad_norm is set
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients
# step the optimizer and clear the gradients
optimizer.step()
optimizer.zero_grad()
# dampen the lr scheduler while the warmup scheduler is active, if one exists
warmup_scheduler = self.warmup_schedulers[index]
scheduler_context = warmup_scheduler.dampening if exists(warmup_scheduler) else nullcontext
# step the lr scheduler inside that context
with scheduler_context():
scheduler.step()
# if using EMA, update the averaged unet
if self.use_ema:
ema_unet = self.ema_unets[index]
ema_unet.update()
# finally, increment the step counter
self.increment_step(unet_number)
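# Hedged per-unet training sketch (the batch tensor names are assumptions):
# each unet is trained and stepped independently by its 1-indexed number.
#
#   for images, image_embed in train_dataloader:
#       loss = trainer(images, image_embed = image_embed, unet_number = 1, max_batch_size = 16)
#       trainer.update(unet_number = 1)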
# sampling
@torch.no_grad()
@cast_torch_tensor
@decoder_sample_in_chunks
def sample(self, *args, **kwargs):
distributed = self.accelerator.num_processes > 1
base_decoder = self.accelerator.unwrap_model(self.decoder)
was_training = base_decoder.training
base_decoder.eval()
# sample with the base (non-EMA) decoder when requested, or when EMA is disabled
if kwargs.pop('use_non_ema', False) or not self.use_ema:
out = base_decoder.sample(*args, **kwargs, distributed = distributed)
base_decoder.train(was_training)
return out
# otherwise switch to the exponential moving averaged unets for sampling
trainable_unets = self.accelerator.unwrap_model(self.decoder).unets
base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
output = base_decoder.sample(*args, **kwargs, distributed = distributed)
base_decoder.unets = trainable_unets # restore original training unets
# move the EMA unets back to their original device
for ema in self.ema_unets:
ema.restore_ema_model_device()
base_decoder.train(was_training)
return output
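# Hedged sampling sketch (`image_embed` is an assumed tensor of CLIP image
# embeddings): the EMA unets are swapped in by default; pass use_non_ema = True
# to sample with the live training weights instead.
#
#   images = trainer.sample(image_embed = image_embed, max_batch_size = 4)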
# embed text
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def embed_text(self, *args, **kwargs):
return self.accelerator.unwrap_model(self.decoder).clip.embed_text(*args, **kwargs)
# embed images
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def embed_image(self, *args, **kwargs):
return self.accelerator.unwrap_model(self.decoder).clip.embed_image(*args, **kwargs)
# forward pass
@cast_torch_tensor
def forward(
self,
*args,
unet_number = None,
max_batch_size = None,
return_lowres_cond_image=False,
**kwargs
):
# validate and resolve the unet number
unet_number = self.validate_and_return_unet_number(unet_number)
total_loss = 0.
cond_images = []
# split the args and kwargs into chunks of at most max_batch_size
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with self.accelerator.autocast():
# run the decoder on this chunk to compute the loss
loss_obj = self.decoder(*chunked_args, unet_number = unet_number, return_lowres_cond_image=return_lowres_cond_image, **chunked_kwargs)
# unpack the low resolution conditioning image if requested
if return_lowres_cond_image:
loss, cond_image = loss_obj
else:
loss = loss_obj
cond_image = None
loss = loss * chunk_size_frac
if cond_image is not None:
cond_images.append(cond_image)
total_loss += loss.item()
# if training, backpropagate through the accelerator
if self.training:
self.accelerator.backward(loss)
# return the total loss, plus the stacked conditioning images if requested
if return_lowres_cond_image:
return total_loss, torch.stack(cond_images)
else:
return total_loss
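# Hedged end-to-end sketch (all tensor names are assumptions): one training step
# that also inspects the low resolution conditioning images fed to unet 2.
#
#   loss, lowres_cond_images = trainer(
#       images,
#       image_embed = image_embed,
#       unet_number = 2,
#       return_lowres_cond_image = True,
#       max_batch_size = 8
#   )
#   trainer.update(unet_number = 2)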