Lucidrains 系列项目源码解析(四十六)
.\lucidrains\imagen-pytorch\imagen_pytorch\utils.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 functools 库中导入 reduce 函数
from functools import reduce
# 从 pathlib 库中导入 Path 类
from pathlib import Path
# 从 imagen_pytorch.configs 模块中导入 ImagenConfig 和 ElucidatedImagenConfig 类
from imagen_pytorch.configs import ImagenConfig, ElucidatedImagenConfig
# 从 ema_pytorch 模块中导入 EMA 类
from ema_pytorch import EMA
# Helper: tell whether a value has been provided (i.e. is not None).
def exists(val):
    """Return True unless *val* is None; falsy values such as 0 or '' still count as existing."""
    is_missing = val is None
    return not is_missing
# Helper: safely read a value from nested dicts using a dotted key path.
def safeget(dictionary, keys, default = None):
    """Look up a dotted key path (e.g. 'a.b.c') in nested dicts.

    Args:
        dictionary: the root object to search; anything that is not a dict
            yields ``default``.
        keys: dot-separated path of keys to follow.
        default: value returned when any step of the path is missing or
            passes through a non-dict value.

    Returns:
        The value found at the end of the path, or ``default``.
    """
    current = dictionary
    for key in keys.split('.'):
        # Stop at the first missing step. Continuing would traverse into
        # ``default`` itself when ``default`` happens to be a dict.
        if not isinstance(current, dict) or key not in current:
            return default
        current = current[key]
    return current
# Rebuild an Imagen model (and optionally its weights) from a saved checkpoint.
def load_imagen_from_checkpoint(
    checkpoint_path,
    load_weights = True,
    load_ema_if_available = False
):
    """Recreate an Imagen / ElucidatedImagen instance from a checkpoint file.

    The checkpoint must contain the 'imagen_type' and 'imagen_params' entries,
    which exist only when the model was built through ImagenConfig or
    ElucidatedImagenConfig.

    Args:
        checkpoint_path: path to the checkpoint file.
        load_weights: if False, return a freshly constructed model built from the
            saved configuration, without loading any weights.
        load_ema_if_available: if True and the checkpoint has an 'ema' entry,
            overwrite each unet's weights with its EMA weights after loading the
            regular 'model' state dict.

    Returns:
        The constructed imagen model.

    Raises:
        AssertionError: if the checkpoint file does not exist, or if it lacks the
            saved imagen type / configuration.
        ValueError: if the saved imagen type is neither 'original' nor 'elucidated'.
    """
    model_path = Path(checkpoint_path)
    full_model_path = str(model_path.resolve())
    assert model_path.exists(), f'checkpoint not found at {full_model_path}'

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    loaded = torch.load(str(model_path), map_location='cpu')

    imagen_params = safeget(loaded, 'imagen_params')
    imagen_type = safeget(loaded, 'imagen_type')

    # Check presence before dispatching on the type. Otherwise a checkpoint saved
    # without config hits the "unknown imagen type None" error below instead of
    # this more informative message.
    assert exists(imagen_params) and exists(imagen_type), 'imagen type and configuration not saved in this checkpoint'

    if imagen_type == 'original':
        imagen_klass = ImagenConfig
    elif imagen_type == 'elucidated':
        imagen_klass = ElucidatedImagenConfig
    else:
        raise ValueError(f'unknown imagen type {imagen_type} - you need to instantiate your Imagen with configurations, using classes ImagenConfig or ElucidatedImagenConfig')

    imagen = imagen_klass(**imagen_params).create()

    if not load_weights:
        return imagen

    has_ema = 'ema' in loaded
    should_load_ema = has_ema and load_ema_if_available

    imagen.load_state_dict(loaded['model'])

    if not should_load_ema:
        print('loading non-EMA version of unets')
        return imagen

    # Wrap each unet in an EMA module and load the saved 'ema' state into that list.
    # Presumably this mirrors how the trainer stored its EMA unets; verify if the
    # saving side changes.
    ema_unets = nn.ModuleList([EMA(unet) for unet in imagen.unets])
    ema_unets.load_state_dict(loaded['ema'])

    # Copy the averaged weights into the live unets.
    for unet, ema_unet in zip(imagen.unets, ema_unets):
        unet.load_state_dict(ema_unet.ema_model.state_dict())

    print('loaded EMA version of unets')
    return imagen
.\lucidrains\imagen-pytorch\imagen_pytorch\version.py
# Package version string of imagen_pytorch (re-exported by the package __init__).
__version__ = '1.26.2'
.\lucidrains\imagen-pytorch\imagen_pytorch\__init__.py
# 从 imagen_pytorch 模块中导入 Imagen 和 Unet 类
from imagen_pytorch.imagen_pytorch import Imagen, Unet
# 从 imagen_pytorch 模块中导入 NullUnet 类
from imagen_pytorch.imagen_pytorch import NullUnet
# 从 imagen_pytorch 模块中导入 BaseUnet64, SRUnet256, SRUnet1024 类
from imagen_pytorch.imagen_pytorch import BaseUnet64, SRUnet256, SRUnet1024
# 从 imagen_pytorch 模块中导入 ImagenTrainer 类
from imagen_pytorch.trainer import ImagenTrainer
# 从 imagen_pytorch 模块中导入 __version__ 变量
from imagen_pytorch.version import __version__
# 使用 Tero Karras 的新论文中阐述的 ddpm 创建 imagen
# 从 imagen_pytorch 模块中导入 ElucidatedImagen 类
from imagen_pytorch.elucidated_imagen import ElucidatedImagen
# 通过配置创建 imagen 实例
# 从 imagen_pytorch 模块中导入 UnetConfig, ImagenConfig, ElucidatedImagenConfig, ImagenTrainerConfig 类
from imagen_pytorch.configs import UnetConfig, ImagenConfig, ElucidatedImagenConfig, ImagenTrainerConfig
# 工具
# 从 imagen_pytorch 模块中导入 load_imagen_from_checkpoint 函数
from imagen_pytorch.utils import load_imagen_from_checkpoint
# 视频
# 从 imagen_pytorch 模块中导入 Unet3D 类
from imagen_pytorch.imagen_video import Unet3D

Imagen - Pytorch
Implementation of Imagen, Google's Text-to-Image Neural Network that beats DALL-E2, in Pytorch. It is the new SOTA for text-to-image synthesis.
Architecturally, it is actually much simpler than DALL-E2. It consists of a cascading DDPM conditioned on text embeddings from a large pretrained T5 model (attention network). It also contains dynamic clipping for improved classifier free guidance, noise level conditioning, and a memory efficient unet design.
It appears neither CLIP nor prior network is needed after all. And so research continues.
AI Coffee Break with Letitia | Assembly AI | Yannic Kilcher
Please join if you are interested in helping out with the replication with the LAION community
Shoutouts
-
StabilityAI for the generous sponsorship, as well as my other sponsors out there
-
🤗 Huggingface for their amazing transformers library. The text encoder portion is pretty much taken care of because of them
-
Jonathan Ho for bringing about a revolution in generative artificial intelligence through his seminal paper
-
Sylvain and Zachary for the Accelerate library, which this repository uses for distributed training
-
Jorge Gomes for helping out with the T5 loading code and advice on the correct T5 version
-
Katherine Crowson, for her beautiful code, which helped me understand the continuous time version of gaussian diffusion
-
Marunine and Netruk44, for reviewing code, sharing experimental results, and help with debugging
-
Marunine for providing a potential solution for a color shifting issue in the memory efficient u-nets. Thanks to Jacob for sharing experimental comparisons between the base and memory-efficient unets
-
Marunine for finding numerous bugs, resolving an issue with resize right, and for sharing his experimental configurations and results
-
MalumaDev for proposing the use of a pixel shuffle upsampler to fix checkerboard artifacts
-
Valentin for pointing out insufficient skip connections in the unet, as well as the specific method of attention conditioning in the base-unet in the appendix
-
BIGJUN for catching a big bug with continuous time gaussian diffusion noise level conditioning at inference time
-
Bingbing for identifying a bug with sampling and order of normalizing and noising with low resolution conditioning image
-
Kay for contributing one line command training of Imagen!
-
Hadrien Reynaud for testing out text-to-video on a medical dataset, sharing his results, and identifying issues!
Install
$ pip install imagen-pytorch
Usage
import torch
from imagen_pytorch import Unet, Imagen
# unet for imagen
unet1 = Unet(
dim = 32,
cond_dim = 512,
dim_mults = (1, 2, 4, 8),
num_resnet_blocks = 3,
layer_attns = (False, True, True, True),
layer_cross_attns = (False, True, True, True)
)
unet2 = Unet(
dim = 32,
cond_dim = 512,
dim_mults = (1, 2, 4, 8),
num_resnet_blocks = (2, 4, 8, 8),
layer_attns = (False, False, False, True),
layer_cross_attns = (False, False, False, True)
)
# imagen, which contains the unets above (base unet and super resoluting ones)
imagen = Imagen(
unets = (unet1, unet2),
image_sizes = (64, 256),
timesteps = 1000,
cond_drop_prob = 0.1
).cuda()
# mock images (get a lot of this) and text encodings from large T5
text_embeds = torch.randn(4, 256, 768).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# feed images into imagen, training each unet in the cascade
for i in (1, 2):
loss = imagen(images, text_embeds = text_embeds, unet_number = i)
loss.backward()
# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm
images = imagen.sample(texts = [
'a whale breaching from afar',
'young girl blowing out candles on her birthday cake',
'fireworks with blue and green sparkles'
], cond_scale = 3.)
images.shape # (3, 3, 256, 256)
For simpler training, you can directly supply text strings instead of precomputing text encodings. (Although for scaling purposes, you will definitely want to precompute the textual embeddings + mask)
The number of textual captions must match the batch size of the images if you go this route.
# mock images and text (get a lot of this)
texts = [
'a child screaming at finding a worm within a half-eaten apple',
'lizard running across the desert on two feet',
'waking up to a psychedelic landscape',
'seashells sparkling in the shallow waters'
]
images = torch.randn(4, 3, 256, 256).cuda()
# feed images into imagen, training each unet in the cascade
for i in (1, 2):
loss = imagen(images, texts = texts, unet_number = i)
loss.backward()
With the ImagenTrainer wrapper class, the exponential moving averages for all of the U-nets in the cascading DDPM will be automatically taken care of when calling update
import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer
# unet for imagen
unet1 = Unet(
dim = 32,
cond_dim = 512,
dim_mults = (1, 2, 4, 8),
num_resnet_blocks = 3,
layer_attns = (False, True, True, True),
)
unet2 = Unet(
dim = 32,
cond_dim = 512,
dim_mults = (1, 2, 4, 8),
num_resnet_blocks = (2, 4, 8, 8),
layer_attns = (False, False, False, True),
layer_cross_attns = (False, False, False, True)
)
# imagen, which contains the unets above (base unet and super resoluting ones)
imagen = Imagen(
unets = (unet1, unet2),
text_encoder_name = 't5-large',
image_sizes = (64, 256),
timesteps = 1000,
cond_drop_prob = 0.1
).cuda()
# wrap imagen with the trainer class
trainer = ImagenTrainer(imagen)
# mock images (get a lot of this) and text encodings from large T5
text_embeds = torch.randn(64, 256, 1024).cuda()
images = torch.randn(64, 3, 256, 256).cuda()
# feed images into imagen, training each unet in the cascade
loss = trainer(
images,
text_embeds = text_embeds,
unet_number = 1, # training on unet number 1 in this example, but you will have to also save checkpoints and then reload and continue training on unet number 2
max_batch_size = 4 # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
)
trainer.update(unet_number = 1)
# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm
images = trainer.sample(texts = [
'a puppy looking anxiously at a giant donut on the table',
'the milky way galaxy in the style of monet'
], cond_scale = 3.)
images.shape # (2, 3, 256, 256)
You can also train Imagen without text (unconditional image generation) as follows
import torch
from imagen_pytorch import Unet, Imagen, SRUnet256, ImagenTrainer
# unets for unconditional imagen
unet1 = Unet(
dim = 32,
dim_mults = (1, 2, 4),
num_resnet_blocks = 3,
layer_attns = (False, True, True),
layer_cross_attns = False,
use_linear_attn = True
)
unet2 = SRUnet256(
dim = 32,
dim_mults = (1, 2, 4),
num_resnet_blocks = (2, 4, 8),
layer_attns = (False, False, True),
layer_cross_attns = False
)
# imagen, which contains the unets above (base unet and super resoluting ones)
imagen = Imagen(
condition_on_text = False, # this must be set to False for unconditional Imagen
unets = (unet1, unet2),
image_sizes = (64, 128),
timesteps = 1000
)
trainer = ImagenTrainer(imagen).cuda()
# now get a ton of images and feed it through the Imagen trainer
training_images = torch.randn(4, 3, 256, 256).cuda()
# train each unet separately
# in this example, only training on unet number 1
loss = trainer(training_images, unet_number = 1)
trainer.update(unet_number = 1)
# do the above for many many many many steps
# now you can sample images unconditionally from the cascading unet(s)
images = trainer.sample(batch_size = 16) # (16, 3, 128, 128)
Or train only super-resoluting unets
import torch
from imagen_pytorch import Unet, NullUnet, Imagen
# unet for imagen
unet1 = NullUnet() # add a placeholder "null" unet for the base unet
unet2 = Unet(
dim = 32,
cond_dim = 512,
dim_mults = (1, 2, 4, 8),
num_resnet_blocks = (2, 4, 8, 8),
layer_attns = (False, False, False, True),
layer_cross_attns = (False, False, False, True)
)
# imagen, which contains the unets above (base unet and super resoluting ones)
imagen = Imagen(
unets = (unet1, unet2),
image_sizes = (64, 256),
timesteps = 250,
cond_drop_prob = 0.1
).cuda()
# mock images (get a lot of this) and text encodings from large T5
text_embeds = torch.randn(4, 256, 768).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# feed images into imagen, training each unet in the cascade
loss = imagen(images, text_embeds = text_embeds, unet_number = 2)
loss.backward()
# do the above for many many many many steps
# now you can sample an image based on the text embeddings as well as low resolution images
lowres_images = torch.randn(3, 3, 64, 64).cuda() # starting un-resoluted images
images = imagen.sample(
texts = [
'a whale breaching from afar',
'young girl blowing out candles on her birthday cake',
'fireworks with blue and green sparkles'
],
start_at_unet_number = 2, # start at unet number 2
start_image_or_video = lowres_images, # pass in low resolution images to be resoluted
cond_scale = 3.)
images.shape # (3, 3, 256, 256)
At any time you can save and load the trainer and all associated states with the save and load methods. It is recommended you use these methods instead of manually saving with a state_dict call, as there is some device memory management being done under the hood within the trainer.
ex.
trainer.save('./path/to/checkpoint.pt')
trainer.load('./path/to/checkpoint.pt')
trainer.steps # (2,) step number for each of the unets, in this case 2
Dataloader
You can also rely on the ImagenTrainer to automatically train off DataLoader instances. You simply have to craft your DataLoader to return either images (for the unconditional case), or tuples of ('images', 'text_embeds') for text-guided generation.
ex. unconditional training
from imagen_pytorch import Unet, Imagen, ImagenTrainer
from imagen_pytorch.data import Dataset
# unets for unconditional imagen
unet = Unet(
dim = 32,
dim_mults = (1, 2, 4, 8),
num_resnet_blocks = 1,
layer_attns = (False, False, False, True),
layer_cross_attns = False
)
# imagen, which contains the unet above
imagen = Imagen(
condition_on_text = False, # this must be set to False for unconditional Imagen
unets = unet,
image_sizes = 128,
timesteps = 1000
)
trainer = ImagenTrainer(
imagen = imagen,
split_valid_from_train = True # whether to split the validation dataset from the training
).cuda()
# instantiate your dataloader, which returns the necessary inputs to the DDPM as tuple in the order of images, text embeddings, then text masks. in this case, only images is returned as it is unconditional training
dataset = Dataset('/path/to/training/images', image_size = 128)
trainer.add_train_dataset(dataset, batch_size = 16)
# working training loop
for i in range(200000):
loss = trainer.train_step(unet_number = 1, max_batch_size = 4)
print(f'loss: {loss}')
if not (i % 50):
valid_loss = trainer.valid_step(unet_number = 1, max_batch_size = 4)
print(f'valid loss: {valid_loss}')
if not (i % 100) and trainer.is_main: # is_main makes sure this can run in distributed
images = trainer.sample(batch_size = 1, return_pil_images = True) # returns List[Image]
images[0].save(f'./sample-{i // 100}.png')
Multi GPU
Thanks to 🤗 Accelerate, you can do multi GPU training easily with two steps.
First you need to invoke accelerate config in the same directory as your training script (say it is named train.py)
$ accelerate config
Next, instead of calling python train.py as you would for a single GPU, use the accelerate CLI as follows
$ accelerate launch train.py
That's it!
Command-line
Imagen can also be used via CLI directly.
Configuration
ex.
$ imagen config
or
$ imagen config --path ./configs/config.json
In the config you are able to change settings for the trainer, dataset and the imagen config.
The Imagen config parameters can be found here
The Elucidated Imagen config parameters can be found here
The Imagen Trainer config parameters can be found here
For the dataset parameters all dataloader parameters can be used.
Training
This command allows you to train or resume training your model
ex.
$ imagen train
or
$ imagen train --unet 2 --epoches 10
You can pass following arguments to the training command.
--config: specify the config file to use for training [default: ./imagen_config.json]
--unet: the index of the unet to train [default: 1]
--epoches: how many epochs to train for [default: 50]
Sampling
Be aware that, when sampling, your checkpoint should have trained all unets to get a usable result.
ex.
$ imagen sample --model ./path/to/model/checkpoint.pt "a squirrel raiding the birdfeeder"
# image is saved to ./a_squirrel_raiding_the_birdfeeder.png
You can pass following arguments to the sample command.
--model: specify the model file to use for sampling
--cond_scale: conditioning scale (classifier free guidance) in decoder
--load_ema: load EMA version of unets if available
In order to use a saved checkpoint with this feature, you must either instantiate your Imagen instance using one of the config classes, ImagenConfig or ElucidatedImagenConfig, or create a checkpoint via the CLI directly.
For proper training, you'll likely want to set up config-driven training anyway.
ex.
import torch
from imagen_pytorch import ImagenConfig, ElucidatedImagenConfig, ImagenTrainer
# in this example, using elucidated imagen
imagen = ElucidatedImagenConfig(
unets = [
dict(dim = 32, dim_mults = (1, 2, 4, 8)),
dict(dim = 32, dim_mults = (1, 2, 4, 8))
],
image_sizes = (64, 128),
cond_drop_prob = 0.5,
num_sample_steps = 32
).create()
trainer = ImagenTrainer(imagen)
# do your training ...
# then save it
trainer.save('./checkpoint.pt')
# you should see a message informing you that ./checkpoint.pt is commandable from the terminal
It really should be as simple as that
You can also pass this checkpoint file around, and anyone can continue fine-tuning it on their own data
from imagen_pytorch import load_imagen_from_checkpoint, ImagenTrainer
imagen = load_imagen_from_checkpoint('./checkpoint.pt')
trainer = ImagenTrainer(imagen)
# continue training / fine-tuning
Inpainting
Inpainting follows the formulation laid out by the recent Repaint paper. Simply pass in inpaint_images and inpaint_masks to the sample function on either Imagen or ElucidatedImagen
inpaint_images = torch.randn(4, 3, 512, 512).cuda() # (batch, channels, height, width)
inpaint_masks = torch.ones((4, 512, 512)).bool().cuda() # (batch, height, width)
inpainted_images = trainer.sample(texts = [
'a whale breaching from afar',
'young girl blowing out candles on her birthday cake',
'fireworks with blue and green sparkles',
'dust motes swirling in the morning sunshine on the windowsill'
], inpaint_images = inpaint_images, inpaint_masks = inpaint_masks, cond_scale = 5.)
inpainted_images # (4, 3, 512, 512)
For video, similarly pass in your videos to inpaint_videos keyword on .sample. Inpainting mask can either be the same across all frames (batch, height, width) or different (batch, frames, height, width)
inpaint_videos = torch.randn(4, 3, 8, 512, 512).cuda() # (batch, channels, frames, height, width)
inpaint_masks = torch.ones((4, 8, 512, 512)).bool().cuda() # (batch, frames, height, width)
inpainted_videos = trainer.sample(texts = [
'a whale breaching from afar',
'young girl blowing out candles on her birthday cake',
'fireworks with blue and green sparkles',
'dust motes swirling in the morning sunshine on the windowsill'
], inpaint_videos = inpaint_videos, inpaint_masks = inpaint_masks, cond_scale = 5.)
inpainted_videos # (4, 3, 8, 512, 512)
Experimental
Tero Karras of StyleGAN fame has written a new paper with results that have been corroborated by a number of independent researchers as well as on my own machine. I have decided to create a version of Imagen, the ElucidatedImagen, so that one can use the new elucidated DDPM for text-guided cascading generation.
Simply import ElucidatedImagen, and then instantiate the instance as you did before. The hyperparameters are different than the usual ones for discrete and continuous time gaussian diffusion, and can be individualized for each unet in the cascade.
Ex.
from imagen_pytorch import ElucidatedImagen
# instantiate your unets ...
imagen = ElucidatedImagen(
unets = (unet1, unet2),
image_sizes = (64, 128),
cond_drop_prob = 0.1,
num_sample_steps = (64, 32), # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are)
sigma_min = 0.002, # min noise level
sigma_max = (80, 160), # max noise level, @crowsonkb recommends double the max noise level for upsampler
sigma_data = 0.5, # standard deviation of data distribution
rho = 7, # controls the sampling schedule
P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training
P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training
S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper
S_tmin = 0.05,
S_tmax = 50,
S_noise = 1.003,
).cuda()
# rest is the same as above
Text to Video
This repository will also start accumulating new research around text guided video synthesis. For starters it will adopt the 3d unet architecture described by Jonathan Ho in Video Diffusion Models
Update: verified working by Hadrien Reynaud!
Ex.
import torch
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer
unet1 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()
unet2 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()
# elucidated imagen, which contains the unets above (base unet and super resoluting ones)
imagen = ElucidatedImagen(
unets = (unet1, unet2),
image_sizes = (16, 32),
random_crop_sizes = (None, 16),
temporal_downsample_factor = (2, 1), # in this example, the first unet would receive the video temporally downsampled by 2x
num_sample_steps = 10,
cond_drop_prob = 0.1,
sigma_min = 0.002, # min noise level
sigma_max = (80, 160), # max noise level, double the max noise level for upsampler
sigma_data = 0.5, # standard deviation of data distribution
rho = 7, # controls the sampling schedule
P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training
P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training
S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper
S_tmin = 0.05,
S_tmax = 50,
S_noise = 1.003,
).cuda()
# mock videos (get a lot of this) and text encodings from large T5
texts = [
'a whale breaching from afar',
'young girl blowing out candles on her birthday cake',
'fireworks with blue and green sparkles',
'dust motes swirling in the morning sunshine on the windowsill'
]
videos = torch.randn(4, 3, 10, 32, 32).cuda() # (batch, channels, time / video frames, height, width)
# feed images into imagen, training each unet in the cascade
# for this example, only training unet 1
trainer = ImagenTrainer(imagen)
# you can also ignore time when training on video initially, shown to improve results in video-ddpm paper. eventually will make the 3d unet trainable with either images or video. research shows it is essential (with current data regimes) to train first on text-to-image. probably won't be true in another decade. all big data becomes small data
trainer(videos, texts = texts, unet_number = 1, ignore_time = False)
trainer.update(unet_number = 1)
videos = trainer.sample(texts = texts, video_frames = 20) # extrapolating to 20 frames from training on 10 frames
videos.shape # (4, 3, 20, 32, 32)
You can also train on text - image pairs first. The Unet3D will automatically convert it to single framed videos and learn without the temporal components (by automatically setting ignore_time = True), whether it be 1d convolutions or causal attention across time.
This is the current approach taken by all the big artificial intelligence labs (Brain, MetaAI, Bytedance)
FAQ
- Why are my generated images not aligning well with the text?
Imagen uses an algorithm called Classifier Free Guidance. When sampling, you apply a scale to the conditioning (text in this case) of greater than 1.0.
Researcher Netruk44 has reported 5-10 to be optimal, and anything greater than 10 to break.
trainer.sample(texts = [
'a cloud in the shape of a roman gladiator'
], cond_scale = 5.) # <-- cond_scale is the conditioning scale, needs to be greater than 1.0 to be better than average
- Are there any pretrained models yet?
Not at the moment but one will likely be trained and open sourced within the year, if not sooner. If you would like to participate, you can join the community of artificial neural network trainers at Laion (discord link is in the Readme above) and start collaborating.
- Will this technology take my job?
More the reason why you should start training your own model, starting today! The last thing we need is this technology being in the hands of an elite few. Hopefully this repository reduces the work to just finding the necessary compute, and augmenting with your own curated dataset.
- What am I allowed to do with this repository?
Anything! It is MIT licensed. In other words, you can freely copy / paste it for your own research, or remix it for whatever modality you can think of. Go train amazing models for profit, for science, or simply to satiate your own personal pleasure at witnessing something divine unravel in front of you.
Cool Applications!
Related Works
Todo
-
use huggingface transformers for T5-small text embeddings
-
add dynamic thresholding
-
add dynamic thresholding DALLE2 and video-diffusion repository as well
-
allow for one to set T5-large (and perhaps small factory method to take in any huggingface transformer)
-
add the lowres noise level with the pseudocode in appendix, and figure out what is this sweep they do at inference time
-
port over some training code from DALLE2
-
need to be able to use a different noise schedule per unet (cosine was used for base, but linear for SR)
-
just make one master-configurable unet
-
complete resnet block (biggan inspired? but with groupnorm) - complete self attention
-
complete conditioning embedding block (and make it completely configurable, whether it be attention, film etc)
-
consider using perceiver-resampler from github.com/lucidrains/… in place of attention pooling
-
add attention pooling option, in addition to cross attention and film
-
add optional cosine decay schedule with warmup, for each unet, to trainer
-
switch to continuous timesteps instead of discretized, as it seems that is what they used for all stages - first figure out the linear noise schedule case from the variational ddpm paper openreview.net/forum?id=2L…
-
figure out log(snr) for alpha cosine noise schedule.
-
suppress the transformers warning because only T5encoder is used
-
allow setting for using linear attention on layers where full attention cannot be used
-
force unets in continuous time case to use non-fouriered conditions (just pass the log(snr) through an MLP with optional layernorms), as that is what i have working locally
-
removed learned variance
-
add p2 loss weighting for continuous time
-
make sure cascading ddpm can be trained without text condition, and make sure both continuous and discrete time gaussian diffusion works
-
use primer's depthwise convs on the qkv projections in linear attention (or use token shifting before projections) - also use new dropout proposed by bayesformer, as it seems to work well with linear attention
-
explore skip layer excitation in unet decoder
-
accelerate integration
-
build out CLI tool and one-line generation of image
-
knock out any issues that arose from accelerate
-
add inpainting ability using resampler from repaint paper arxiv.org/abs/2201.09…
-
build a simple checkpointing system, backed by a folder
-
add skip connection from outputs of all upsample blocks, used in unet squared paper and some previous unet works
-
add fsspec, recommended by Romain @rom1504, for cloud / local file system agnostic persistence of checkpoints
-
test out persistence in gcs with github.com/fsspec/gcsf…
-
extend to video generation, using axial time attention as in Ho's video ddpm paper
-
allow elucidated imagen to generalize to any shape
-
allow for imagen to generalize to any shape
-
add dynamic positional bias for the best type of length extrapolation across video time
-
move video frames to sample function, as we will be attempting time extrapolation
-
attention bias to null key / values should be a learned scalar of head dimension
-
add self-conditioning from bit diffusion paper, already coded up at ddpm-pytorch
-
add v-parameterization (arxiv.org/abs/2202.00…) from imagen video paper, the only thing new
-
incorporate all learnings from make-a-video (makeavideo.studio/)
-
build out CLI tool for training, resuming training off config file
-
allow for temporal interpolation at specific stages
-
make sure temporal interpolation works with inpainting
-
make sure one can customize all interpolation modes (some researchers are finding better results with trilinear)
-
imagen-video : allow for conditioning on preceding (and possibly future) frames of videos. ignore time should not be allowed in that scenario
-
make sure to automatically take care of temporal down/upsampling for conditioning video frames, but allow for an option to turn it off
-
make sure inpainting works with video
-
make sure the inpainting mask for video can be customized per frame
-
add flash attention
-
reread cogvideo and figure out how frame rate conditioning could be used
-
bring in attention expertise for self attention layers in unet3d
-
consider bringing in NUWA's 3d convolutional attention
-
consider transformer-xl memories in the temporal attention blocks
-
consider perceiver-ar approach to attending to past time
-
frame dropouts during attention for achieving both regularizing effect as well as shortened training time
-
investigate frank wood's claims github.com/lucidrains/… and either add the hierarchical sampling technique, or let people know about its deficiencies
-
offer challenging moving mnist (with distractor objects) as a one-line trainable baseline for researchers to branch off of for text to video
-
preencoding of text to memmapped embeddings
-
be able to create dataloader iterators based on the old epoch style, also configure shuffling etc
-
be able to also pass in arguments (instead of requiring forward to be all keyword args on model)
-
bring in reversible blocks from revnets for 3d unet, to lessen memory burden
-
add ability to only train super-resolution network
-
read dpm-solver see if it is applicable to continuous time gaussian diffusion
-
allow for conditioning video frames with arbitrary absolute times (calculate RPE during temporal attention)
-
accommodate dream booth fine tuning
-
add textual inversion
-
cleanup self conditioning to be extracted at imagen instantiation
-
make sure eventual dreambooth works with imagen-video
-
add framerate conditioning for video diffusion
-
make sure one can simultaneously condition on video frames as a prompt, as well as some conditioning image across all frames
-
test and add distillation technique from consistency models
Citations
@inproceedings{Saharia2022PhotorealisticTD,
title = {Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding},
author = {Chitwan Saharia and William Chan and Saurabh Saxena and Lala Li and Jay Whang and Emily L. Denton and Seyed Kamyar Seyed Ghasemipour and Burcu Karagol Ayan and Seyedeh Sara Mahdavi and Raphael Gontijo Lopes and Tim Salimans and Jonathan Ho and David Fleet and Mohammad Norouzi},
year = {2022}
}
@article{Alayrac2022Flamingo,
title = {Flamingo: a Visual Language Model for Few-Shot Learning},
author = {Jean-Baptiste Alayrac et al},
year = {2022}
}
@inproceedings{Sankararaman2022BayesFormerTW,
title = {BayesFormer: Transformer with Uncertainty Estimation},
author = {Karthik Abinav Sankararaman and Sinong Wang and Han Fang},
year = {2022}
}
@article{So2021PrimerSF,
title = {Primer: Searching for Efficient Transformers for Language Modeling},
author = {David R. So and Wojciech Ma'nke and Hanxiao Liu and Zihang Dai and Noam M. Shazeer and Quoc V. Le},
journal = {ArXiv},
year = {2021},
volume = {abs/2109.08668}
}
@misc{cao2020global,
title = {Global Context Networks},
author = {Yue Cao and Jiarui Xu and Stephen Lin and Fangyun Wei and Han Hu},
year = {2020},
eprint = {2012.13375},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@article{Karras2022ElucidatingTD,
title = {Elucidating the Design Space of Diffusion-Based Generative Models},
author = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine},
journal = {ArXiv},
year = {2022},
volume = {abs/2206.00364}
}
@inproceedings{NEURIPS2020_4c5bcfec,
author = {Ho, Jonathan and Jain, Ajay and Abbeel, Pieter},
booktitle = {Advances in Neural Information Processing Systems},
editor = {H. Larochelle and M. Ranzato and R. Hadsell and M.F. Balcan and H. Lin},
pages = {6840--6851},
publisher = {Curran Associates, Inc.},
title = {Denoising Diffusion Probabilistic Models},
url = {https://proceedings.neurips.cc/paper/2020/file/4c5bcfec8584af0d967f1ab10179ca4b-Paper.pdf},
volume = {33},
year = {2020}
}
@article{Lugmayr2022RePaintIU,
title = {RePaint: Inpainting using Denoising Diffusion Probabilistic Models},
author = {Andreas Lugmayr and Martin Danelljan and Andr{\'e}s Romero and Fisher Yu and Radu Timofte and Luc Van Gool},
journal = {ArXiv},
year = {2022},
volume = {abs/2201.09865}
}
@misc{ho2022video,
title = {Video Diffusion Models},
author = {Jonathan Ho and Tim Salimans and Alexey Gritsenko and William Chan and Mohammad Norouzi and David J. Fleet},
year = {2022},
eprint = {2204.03458},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@inproceedings{rogozhnikov2022einops,
title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
author = {Alex Rogozhnikov},
booktitle = {International Conference on Learning Representations},
year = {2022},
url = {https://openreview.net/forum?id=oapKSVM2bcj}
}
@misc{chen2022analog,
title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
year = {2022},
eprint = {2208.04202},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@misc{Singer2022,
author = {Uriel Singer},
url = {https://makeavideo.studio/Make-A-Video.pdf}
}
@article{Sunkara2022NoMS,
title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
author = {Raja Sunkara and Tie Luo},
journal = {ArXiv},
year = {2022},
volume = {abs/2208.03641}
}
@article{Salimans2022ProgressiveDF,
title = {Progressive Distillation for Fast Sampling of Diffusion Models},
author = {Tim Salimans and Jonathan Ho},
journal = {ArXiv},
year = {2022},
volume = {abs/2202.00512}
}
@article{Ho2022ImagenVH,
title = {Imagen Video: High Definition Video Generation with Diffusion Models},
author = {Jonathan Ho and William Chan and Chitwan Saharia and Jay Whang and Ruiqi Gao and Alexey A. Gritsenko and Diederik P. Kingma and Ben Poole and Mohammad Norouzi and David J. Fleet and Tim Salimans},
journal = {ArXiv},
year = {2022},
volume = {abs/2210.02303}
}
@misc{gilmer2023intriguing,
title = {Intriguing Properties of Transformer Training Instabilities},
author = {Justin Gilmer and Andrea Schioppa and Jeremy Cohen},
year = {2023},
status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
}
@inproceedings{Hang2023EfficientDT,
title = {Efficient Diffusion Training via Min-SNR Weighting Strategy},
author = {Tiankai Hang and Shuyang Gu and Chen Li and Jianmin Bao and Dong Chen and Han Hu and Xin Geng and Baining Guo},
year = {2023}
}
@article{Zhang2021TokenST,
title = {Token Shift Transformer for Video Classification},
author = {Hao Zhang and Y. Hao and Chong-Wah Ngo},
journal = {Proceedings of the 29th ACM International Conference on Multimedia},
year = {2021}
}
@inproceedings{anonymous2022normformer,
title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
author = {Anonymous},
booktitle = {Submitted to The Tenth International Conference on Learning Representations },
year = {2022},
url = {https://openreview.net/forum?id=GMYWzWztDx5},
note = {under review}
}
.\lucidrains\imagen-pytorch\setup.py
# 导入设置工具和查找包工具
from setuptools import setup, find_packages
# Execute imagen_pytorch/version.py so that `__version__` is defined in this module's namespace.
# The file is read inside a `with` block so the handle is closed deterministically
# (the previous `exec(open(...).read())` never closed it explicitly).
with open('imagen_pytorch/version.py', encoding = 'utf-8') as version_file:
    exec(version_file.read())

# Package metadata for imagen-pytorch
setup(
    # distribution name
    name = 'imagen-pytorch',
    # include every package found under the project root (nothing excluded)
    packages = find_packages(exclude=[]),
    # include package data files
    include_package_data = True,
    # console scripts installed with the package
    entry_points={
        'console_scripts': [
            'imagen_pytorch = imagen_pytorch.cli:main',
            'imagen = imagen_pytorch.cli:imagen'
        ],
    },
    # `__version__` is defined by the exec of version.py above
    version = __version__,
    license='MIT',
    description = 'Imagen - unprecedented photorealism × deep level of language understanding',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    long_description_content_type = 'text/markdown',
    url = 'https://github.com/lucidrains/imagen-pytorch',
    keywords = [
        'artificial intelligence',
        'deep learning',
        'transformers',
        'text-to-image',
        'denoising-diffusion'
    ],
    # runtime dependencies
    install_requires=[
        'accelerate>=0.23.0',
        'beartype',
        'click',
        'datasets',
        'einops>=0.7.0',
        'ema-pytorch>=0.0.3',
        'fsspec',
        'kornia',
        'numpy',
        'packaging',
        'pillow',
        'pydantic>=2',
        'pytorch-warmup',
        'sentencepiece',
        'torch>=1.6',
        'torchvision',
        'transformers',
        'tqdm'
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
Insertion Deletion Denoising Diffusion Probabilistic Models (wip)
Implementation of Insertion Deletion Denoising Diffusion Probabilistic Models. This scheme basically allows for DDPM to work beyond just in-place corruption along the sequence. They try to apply this to text generation with lukewarm results. I think it holds promise for protein design, as it would be able to infill certain regions without being constrained to a fixed number of amino acids.
Citations
@article{Johnson2021BeyondIC,
title = {Beyond In-Place Corruption: Insertion and Deletion In Denoising Probabilistic Models},
author = {Daniel D. Johnson and Jacob Austin and Rianne van den Berg and Daniel Tarlow},
journal = {ArXiv},
year = {2021},
volume = {abs/2107.07675}
}
.\lucidrains\invariant-point-attention\denoise.py
# 导入所需的库
import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.optim import Adam
# 导入 einops 库中的函数
from einops import rearrange, repeat
# 导入 sidechainnet 库
import sidechainnet as scn
# 导入自定义的模块 invariant_point_attention 中的 IPATransformer 类
from invariant_point_attention import IPATransformer
# Batch size handed to the sidechainnet dataloaders.
BATCH_SIZE = 1
# Number of micro-batches whose gradients are accumulated before each optimizer step.
GRADIENT_ACCUMULATE_EVERY = 16

def cycle(loader, len_thres = 200):
    """Yield batches from `loader` forever, dropping any whose sequence length exceeds `len_thres`.

    The sequence length is read from `batch.seqs.shape[1]`.
    """
    while True:
        yield from (batch for batch in loader if batch.seqs.shape[1] <= len_thres)
# Build the IPA transformer; with predict_points = True it returns per-residue 3D points.
net = IPATransformer(
    dim = 16,
    num_tokens = 21,
    depth = 5,
    require_pairwise_repr = False,
    predict_points = True
).cuda()

# Load CASP12 (thinning 30) from sidechainnet as PyTorch dataloaders.
data = scn.load(
    casp_version = 12,
    thinning = 30,
    with_pytorch = 'dataloaders',
    batch_size = BATCH_SIZE,
    dynamic_batching = False
)

# Endless stream of training batches (`cycle` skips overly long sequences).
dl = cycle(data['train'])

optim = Adam(net.parameters(), lr=1e-3)

for _ in range(10000):
    # Accumulate gradients over several micro-batches before each optimizer step.
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        batch = next(dl)
        seqs, coords, masks = batch.seqs, batch.crds, batch.msks

        # argmax over the last dim turns the (presumably one-hot) sequence encoding into token ids
        seqs = seqs.cuda().argmax(dim = -1)
        coords = coords.cuda()
        masks = masks.cuda().bool()

        # coordinates are stored as 14 atom slots per residue; keep only slot 1 (the C-alpha atom)
        coords = rearrange(coords, 'b (l s) c -> b l s c', s = 14)
        coords = coords[:, :, 1, :]

        # corrupt with unit Gaussian noise; the network must recover the clean coordinates
        noised_coords = coords + torch.randn_like(coords)

        denoised_coords = net(
            seqs,
            translations = noised_coords,
            mask = masks
        )

        # score only positions selected by the mask
        loss = F.mse_loss(denoised_coords[masks], coords[masks])
        # divide so the accumulated gradient is the mean over micro-batches
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    # report the loss of the last micro-batch, then take one optimizer step
    print('loss:', loss.item())
    optim.step()
    optim.zero_grad()
.\lucidrains\invariant-point-attention\invariant_point_attention\invariant_point_attention.py
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from contextlib import contextmanager
from torch import nn, einsum
from einops.layers.torch import Rearrange
from einops import rearrange, repeat
# helpers
def exists(val):
    """Return True when `val` is anything other than None."""
    return not (val is None)
def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if val is None:
        return d
    return val
def max_neg_value(t):
    """Return the most negative finite value representable in the floating dtype of tensor `t`."""
    dtype_info = torch.finfo(t.dtype)
    return -dtype_info.max
@contextmanager
def disable_tf32():
    """Temporarily disable TF32 for CUDA matmuls, restoring the previous setting on exit.

    The original value is restored in a ``finally`` clause, so it is reinstated
    even when the body of the ``with`` statement raises.
    """
    orig_value = torch.backends.cuda.matmul.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = False
    try:
        yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = orig_value
# classes
class InvariantPointAttention(nn.Module):
    """Invariant Point Attention (IPA), as used in the AlphaFold2 structure module.

    The constructor builds three groups of projections from the single
    representation:
    - scalar queries/keys/values;
    - 3D point queries/keys/values;
    - only when `require_pairwise_repr` is set, a per-head attention bias
      projected from the pairwise representation.
    """
    def __init__(
        self,
        *,
        dim,                            # width of the single representation (input and output of the layer)
        heads = 8,                      # number of attention heads
        scalar_key_dim = 16,            # per-head scalar query/key size
        scalar_value_dim = 16,          # per-head scalar value size
        point_key_dim = 4,              # per-head number of query/key points (3 coordinates each)
        point_value_dim = 4,            # per-head number of value points (3 coordinates each)
        pairwise_repr_dim = None,       # width of the pairwise representation; falls back to `dim`
        require_pairwise_repr = True,   # when False, no pairwise bias projection is built
        eps = 1e-8                      # stored as self.eps; its use is in forward (not visible here)
    ):
        super().__init__()
        self.eps = eps
        self.heads = heads
        self.require_pairwise_repr = require_pairwise_repr

        # num attention contributions: scalar + point terms, plus the pairwise bias when used.
        # Every logit scale below includes this count, presumably so the combined logits keep unit variance.
        num_attn_logits = 3 if require_pairwise_repr else 2

        # qkv projection for scalar attention (normal)
        # scale = 1 / sqrt(num_attn_logits * scalar_key_dim)
        self.scalar_attn_logits_scale = (num_attn_logits * scalar_key_dim) ** -0.5
        self.to_scalar_q = nn.Linear(dim, scalar_key_dim * heads, bias = False)
        self.to_scalar_k = nn.Linear(dim, scalar_key_dim * heads, bias = False)
        self.to_scalar_v = nn.Linear(dim, scalar_value_dim * heads, bias = False)

        # qkv projection for point attention (coordinate and orientation aware)
        # One learnable weight per head, initialised to log(e^1 - 1), the inverse of softplus at 1.
        # NOTE(review): presumably forward applies softplus so each head starts with weight 1; confirm.
        point_weight_init_value = torch.log(torch.exp(torch.full((heads,), 1.)) - 1.)
        self.point_weights = nn.Parameter(point_weight_init_value)

        # The 9/2 factor presumably mirrors the point-term weighting of the AlphaFold2 IPA algorithm; verify.
        self.point_attn_logits_scale = ((num_attn_logits * point_key_dim) * (9 / 2)) ** -0.5
        # each point projection emits 3 coordinates per point per head
        self.to_point_q = nn.Linear(dim, point_key_dim * heads * 3, bias = False)
        self.to_point_k = nn.Linear(dim, point_key_dim * heads * 3, bias = False)
        self.to_point_v = nn.Linear(dim, point_value_dim * heads * 3, bias = False)

        # pairwise representation projection to attention bias
        # With no pairwise repr, its width is treated as 0 so it adds nothing to to_out's input size.
        pairwise_repr_dim = default(pairwise_repr_dim, dim) if require_pairwise_repr else 0

        if require_pairwise_repr:
            self.pairwise_attn_logits_scale = num_attn_logits ** -0.5
            # one bias value per head for each pair, then heads are folded into the batch dim
            self.to_pairwise_attn_bias = nn.Sequential(
                nn.Linear(pairwise_repr_dim, heads),
                Rearrange('b ... h -> (b h) ...')
            )

        # combine out - scalar dim + pairwise dim + point dim * (3 for coordinates in R3 and then 1 for norm)
        self.to_out = nn.Linear(heads * (scalar_value_dim + pairwise_repr_dim + point_value_dim * (3 + 1)), dim)

    def forward(
        self,
        single_repr,
        pairwise_repr = None,
        *,
        rotations,
        translations,
        mask = None
    ):
        # NOTE(review): in this copy the body is only `pass`, so calling the module returns None.
        # The README example expects an output of shape (batch, seq, dim); the real
        # implementation appears to be missing here and should be restored.
        pass
# feedforward (transition) stack used inside each IPA block
def FeedForward(dim, mult = 1., num_layers = 2, act = nn.ReLU):
    """Build an MLP of `num_layers` linear layers mapping `dim` -> `dim`.

    Args:
        dim: input and output feature width.
        mult: hidden width multiplier; the hidden width is ``int(dim * mult)``.
            The cast matters because the default ``mult = 1.`` is a float and
            ``nn.Linear`` only accepts integer feature sizes.
        num_layers: number of ``nn.Linear`` layers. With 1, a single dim -> dim
            linear layer is built.
        act: activation class, instantiated between consecutive linear layers
            (not after the last one).

    Returns:
        ``nn.Sequential`` of the linear and activation layers.
    """
    layers = []
    dim_hidden = int(dim * mult)

    for ind in range(num_layers):
        is_first = ind == 0
        is_last = ind == (num_layers - 1)

        dim_in = dim if is_first else dim_hidden
        dim_out = dim if is_last else dim_hidden

        layers.append(nn.Linear(dim_in, dim_out))

        # no activation after the final projection
        if is_last:
            continue

        layers.append(act())

    return nn.Sequential(*layers)
class IPABlock(nn.Module):
    """One IPA-based transformer block: invariant point attention followed by a feedforward.

    Each sub-layer has a residual connection and dropout. With ``post_norm = True``
    (the default, as in the paper), LayerNorm is applied after the residual add.
    Otherwise LayerNorm is applied to the sub-layer input (pre-norm).

    Any extra keyword arguments are forwarded to ``InvariantPointAttention``.
    """
    def __init__(
        self,
        *,
        dim,
        ff_mult = 1,
        ff_num_layers = 3,     # in the paper, they used 3 layer transition (feedforward) block
        post_norm = True,      # in the paper, they used post-layernorm - offering pre-norm as well
        post_attn_dropout = 0.,
        post_ff_dropout = 0.,
        **kwargs
    ):
        # A previous copy had two __init__ methods. The second (taking only `post_norm`)
        # overrode this signature while still referencing `dim`, `kwargs`, etc., so every
        # construction failed. They are merged here into a single constructor.
        super().__init__()
        self.post_norm = post_norm

        # attention sub-layer
        self.attn_norm = nn.LayerNorm(dim)
        self.attn = InvariantPointAttention(dim = dim, **kwargs)
        self.post_attn_dropout = nn.Dropout(post_attn_dropout)

        # feedforward (transition) sub-layer
        self.ff_norm = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, mult = ff_mult, num_layers = ff_num_layers)
        self.post_ff_dropout = nn.Dropout(post_ff_dropout)

    def forward(self, x, **kwargs):
        """Apply attention then feedforward to `x`.

        `kwargs` (e.g. pairwise_repr, rotations, translations, mask) are passed to the attention.
        """
        post_norm = self.post_norm

        # pre-norm normalises the attention input; post-norm feeds x through unchanged
        attn_input = x if post_norm else self.attn_norm(x)
        x = self.attn(attn_input, **kwargs) + x
        # dropout is applied to the residual sum
        x = self.post_attn_dropout(x)
        x = self.attn_norm(x) if post_norm else x

        ff_input = x if post_norm else self.ff_norm(x)
        x = self.ff(ff_input) + x
        x = self.post_ff_dropout(x)
        x = self.ff_norm(x) if post_norm else x
        return x
class IPATransformer(nn.Module):
    """Stack of IPA blocks that iteratively refines per-residue rotations and translations.

    After each ``IPABlock``, a linear layer produces a 6-d update per residue:
    - 3 values that, after prepending a real part of 1 and normalising, form a
      unit quaternion composed onto the current rotation;
    - 3 values for a translation update that is rotated by the current frame.

    This is not an exact reproduction of AlphaFold2, which applies an auxiliary
    FAPE loss at every layer and stops gradients through the rotations. It is an
    attempt to see whether this can evolve into something more generally usable.
    """
    def __init__(
        self,
        *,
        dim,
        depth,
        num_tokens = None,
        predict_points = False,
        detach_rotations = True,
        **kwargs
    ):
        super().__init__()

        # quaternion helpers come from pytorch3d
        try:
            from pytorch3d.transforms import quaternion_multiply, quaternion_to_matrix
        except ImportError:
            # ModuleNotFoundError is a subclass of ImportError, so it is covered here
            print('unable to import pytorch3d - please install with `conda install pytorch3d -c pytorch3d`')
            raise

        self.quaternion_to_matrix = quaternion_to_matrix
        self.quaternion_multiply = quaternion_multiply

        # optional token embedding (when inputs are integer token ids)
        self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None

        # each layer: an IPA block plus a linear head emitting the 6-d frame update
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                IPABlock(dim = dim, **kwargs),
                nn.Linear(dim, 6)
            ]))

        # whether to detach rotations, for training stability
        self.detach_rotations = detach_rotations

        # output head
        self.predict_points = predict_points

        if predict_points:
            self.to_points = nn.Linear(dim, 3)

    def forward(
        self,
        single_repr,
        *,
        translations = None,
        quaternions = None,
        pairwise_repr = None,
        mask = None
    ):
        """Run all layers, updating the per-residue frames after each one.

        Args:
            single_repr: token ids of shape (b, n) when `num_tokens` was given,
                otherwise features whose first two dims are (b, n)
                (presumably (b, n, dim)).
            translations: initial translations (b, n, 3); zeros if omitted.
            quaternions: initial rotations as quaternions (b, n, 4); identity
                (1, 0, 0, 0) if omitted.
            pairwise_repr: optional pairwise representation, passed to every block.
            mask: optional mask, passed to every block's attention. Previously it
                was accepted but silently dropped.

        Returns:
            ``(x, translations, quaternions)`` when ``predict_points`` is False,
            otherwise the predicted global point coordinates of shape (b, n, 3).
        """
        x, device, quaternion_multiply, quaternion_to_matrix = single_repr, single_repr.device, self.quaternion_multiply, self.quaternion_to_matrix
        b, n, *_ = x.shape

        if exists(self.token_emb):
            x = self.token_emb(x)

        # if no initial quaternions passed in, start from identity
        if not exists(quaternions):
            quaternions = torch.tensor([1., 0., 0., 0.], device = device) # initial rotations
            quaternions = repeat(quaternions, 'd -> b n d', b = b, n = n)

        # if no translations passed in, start with none
        if not exists(translations):
            translations = torch.zeros((b, n, 3), device = device)

        # go through the layers and apply invariant point attention and feedforward
        for block, to_update in self.layers:
            rotations = quaternion_to_matrix(quaternions)

            if self.detach_rotations:
                rotations = rotations.detach()

            x = block(
                x,
                pairwise_repr = pairwise_repr,
                rotations = rotations,
                translations = translations,
                mask = mask
            )

            # update quaternion and translation
            quaternion_update, translation_update = to_update(x).chunk(2, dim = -1)
            # prepend a real part of 1, then normalise to a unit quaternion
            quaternion_update = F.pad(quaternion_update, (1, 0), value = 1.)
            quaternion_update = quaternion_update / torch.linalg.norm(quaternion_update, dim=-1, keepdim=True)
            quaternions = quaternion_multiply(quaternions, quaternion_update)

            translations = translations + einsum('b n c, b n c r -> b n r', translation_update, rotations)

        if not self.predict_points:
            return x, translations, quaternions

        points_local = self.to_points(x)
        rotations = quaternion_to_matrix(quaternions)
        points_global = einsum('b n c, b n c d -> b n d', points_local, rotations) + translations
        return points_global
.\lucidrains\invariant-point-attention\invariant_point_attention\utils.py
# 导入 torch 库
import torch
# 从 torch 库中导入 sin, cos, atan2, acos 函数
from torch import sin, cos, atan2, acos
# 从 functools 库中导入 wraps 装饰器
from functools import wraps
def cast_torch_tensor(fn):
    """Decorator for single-argument functions: non-tensor inputs become tensors of the default dtype."""
    @wraps(fn)
    def inner(t):
        as_tensor = t if torch.is_tensor(t) else torch.tensor(t, dtype=torch.get_default_dtype())
        return fn(as_tensor)
    return inner

@cast_torch_tensor
def rot_z(gamma):
    """3x3 matrix rotating by `gamma` radians about the z axis, in the dtype of `gamma`."""
    c, s = cos(gamma), sin(gamma)
    matrix = [
        [c, -s, 0],
        [s,  c, 0],
        [0,  0, 1],
    ]
    return torch.tensor(matrix, dtype=gamma.dtype)

@cast_torch_tensor
def rot_y(beta):
    """3x3 matrix rotating by `beta` radians about the y axis, in the dtype of `beta`."""
    c, s = cos(beta), sin(beta)
    matrix = [
        [ c, 0, s],
        [ 0, 1, 0],
        [-s, 0, c],
    ]
    return torch.tensor(matrix, dtype=beta.dtype)

def rot(alpha, beta, gamma):
    """Z-Y-Z Euler composition: rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)."""
    first, middle, last = rot_z(alpha), rot_y(beta), rot_z(gamma)
    return first @ middle @ last
.\lucidrains\invariant-point-attention\invariant_point_attention\__init__.py
# 从 invariant_point_attention 模块中导入 InvariantPointAttention, IPABlock, IPATransformer 类
from invariant_point_attention.invariant_point_attention import InvariantPointAttention, IPABlock, IPATransformer

Invariant Point Attention - Pytorch
Implementation of Invariant Point Attention as a standalone module, which was used in the structure module of Alphafold2 for coordinate refinement.
- enforce float32 for certain operations
Install
$ pip install invariant-point-attention
Usage
import torch
from einops import repeat
from invariant_point_attention import InvariantPointAttention
attn = InvariantPointAttention(
dim = 64, # single (and pairwise) representation dimension
heads = 8, # number of attention heads
scalar_key_dim = 16, # scalar query-key dimension
scalar_value_dim = 16, # scalar value dimension
point_key_dim = 4, # point query-key dimension
point_value_dim = 4 # point value dimension
)
single_repr = torch.randn(1, 256, 64) # (batch x seq x dim)
pairwise_repr = torch.randn(1, 256, 256, 64) # (batch x seq x seq x dim)
mask = torch.ones(1, 256).bool() # (batch x seq)
rotations = repeat(torch.eye(3), '... -> b n ...', b = 1, n = 256) # (batch x seq x rot1 x rot2) - example is identity
translations = torch.zeros(1, 256, 3) # translation, also identity for example
attn_out = attn(
single_repr,
pairwise_repr,
rotations = rotations,
translations = translations,
mask = mask
)
attn_out.shape # (1, 256, 64)
You can also use this module without the pairwise representations, which are very specific to the AlphaFold2 architecture.
import torch
from einops import repeat
from invariant_point_attention import InvariantPointAttention
attn = InvariantPointAttention(
dim = 64,
heads = 8,
require_pairwise_repr = False # set this to False to use the module without pairwise representations
)
seq = torch.randn(1, 256, 64)
mask = torch.ones(1, 256).bool()
rotations = repeat(torch.eye(3), '... -> b n ...', b = 1, n = 256)
translations = torch.randn(1, 256, 3)
attn_out = attn(
seq,
rotations = rotations,
translations = translations,
mask = mask
)
attn_out.shape # (1, 256, 64)
You can also use one IPA-based transformer block, which is an IPA followed by a feedforward. By default it will use post-layernorm as done in the official code, but you can also try pre-layernorm by setting post_norm = False
import torch
from torch import nn
from einops import repeat
from invariant_point_attention import IPABlock
block = IPABlock(
dim = 64,
heads = 8,
scalar_key_dim = 16,
scalar_value_dim = 16,
point_key_dim = 4,
point_value_dim = 4
)
seq = torch.randn(1, 256, 64)
pairwise_repr = torch.randn(1, 256, 256, 64)
mask = torch.ones(1, 256).bool()
rotations = repeat(torch.eye(3), 'r1 r2 -> b n r1 r2', b = 1, n = 256)
translations = torch.randn(1, 256, 3)
block_out = block(
seq,
pairwise_repr = pairwise_repr,
rotations = rotations,
translations = translations,
mask = mask
)
updates = nn.Linear(64, 6)(block_out)
quaternion_update, translation_update = updates.chunk(2, dim = -1) # (1, 256, 3), (1, 256, 3)
# apply updates to rotations and translations for the next iteration
Toy Example
To run IPA on a toy task for denoising protein backbone coordinates, first install pytorch3d by running
$ conda install pytorch3d -c pytorch3d
Then you need to install sidechainnet with
$ pip install sidechainnet
Finally
$ python denoise.py
Citations
@Article{AlphaFold2021,
author = {Jumper, John and Evans, Richard and Pritzel, Alexander and Green, Tim and Figurnov, Michael and Ronneberger, Olaf and Tunyasuvunakool, Kathryn and Bates, Russ and {\v{Z}}{\'\i}dek, Augustin and Potapenko, Anna and Bridgland, Alex and Meyer, Clemens and Kohl, Simon A A and Ballard, Andrew J and Cowie, Andrew and Romera-Paredes, Bernardino and Nikolov, Stanislav and Jain, Rishub and Adler, Jonas and Back, Trevor and Petersen, Stig and Reiman, David and Clancy, Ellen and Zielinski, Michal and Steinegger, Martin and Pacholska, Michalina and Berghammer, Tamas and Bodenstein, Sebastian and Silver, David and Vinyals, Oriol and Senior, Andrew W and Kavukcuoglu, Koray and Kohli, Pushmeet and Hassabis, Demis},
journal = {Nature},
title = {Highly accurate protein structure prediction with {AlphaFold}},
year = {2021},
doi = {10.1038/s41586-021-03819-2},
note = {(Accelerated article preview)},
}
.\lucidrains\invariant-point-attention\setup.py
# 导入设置和查找包的函数
from setuptools import setup, find_packages
# Package metadata for invariant-point-attention, gathered into one mapping
# and handed to setuptools in a single call.
_package_metadata = dict(
    name = 'invariant-point-attention',
    packages = find_packages(),
    version = '0.2.2',
    license = 'MIT',
    description = 'Invariant Point Attention',
    long_description_content_type = 'text/markdown',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/invariant-point-attention',
    keywords = [
        'artificial intelligence',
        'deep learning',
        'protein folding',
    ],
    # runtime dependencies
    install_requires = [
        'einops>=0.3',
        'torch>=1.7',
    ],
    # build-time and test tooling
    setup_requires = ['pytest-runner'],
    tests_require = ['pytest'],
    classifiers = [
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)

setup(**_package_metadata)
.\lucidrains\invariant-point-attention\tests\invariance.py
# 导入所需的库
import torch
from torch import nn
from einops import repeat
from invariant_point_attention import InvariantPointAttention, IPABlock
from invariant_point_attention.utils import rot
def test_ipa_invariance():
    """InvariantPointAttention output must not change under a global rotation of every frame."""
    attn = InvariantPointAttention(
        dim = 64,
        heads = 8,
        scalar_key_dim = 16,
        scalar_value_dim = 16,
        point_key_dim = 4,
        point_value_dim = 4
    )

    # random single representation, pairwise representation and an all-True mask
    seq = torch.randn(1, 256, 64)
    pairwise_repr = torch.randn(1, 256, 256, 64)
    mask = torch.ones(1, 256).bool()

    # random per-residue frames (the same rotation repeated for every residue)
    rotations = repeat(rot(*torch.randn(3)), 'r1 r2 -> b n r1 r2', b = 1, n = 256)
    translations = torch.randn(1, 256, 3)

    # an additional global rotation applied to all frames
    random_rotation = rot(*torch.randn(3))

    attn_out = attn(
        seq,
        pairwise_repr = pairwise_repr,
        rotations = rotations,
        translations = translations,
        mask = mask
    )

    rotated_attn_out = attn(
        seq,
        pairwise_repr = pairwise_repr,
        rotations = rotations @ random_rotation,
        translations = translations @ random_rotation,
        mask = mask
    )

    # Compare magnitudes: a signed `.max()` would let arbitrarily large negative
    # deviations pass the check.
    diff = (attn_out - rotated_attn_out).abs().max()
    assert diff <= 1e-6, 'must be invariant to global rotation'
def test_ipa_block_invariance():
    """IPABlock output must not change under a global rotation of every frame."""
    attn = IPABlock(
        dim = 64,
        heads = 8,
        scalar_key_dim = 16,
        scalar_value_dim = 16,
        point_key_dim = 4,
        point_value_dim = 4
    )

    # random single representation, pairwise representation and an all-True mask
    seq = torch.randn(1, 256, 64)
    pairwise_repr = torch.randn(1, 256, 256, 64)
    mask = torch.ones(1, 256).bool()

    # random per-residue frames (the same rotation repeated for every residue)
    rotations = repeat(rot(*torch.randn(3)), 'r1 r2 -> b n r1 r2', b = 1, n = 256)
    translations = torch.randn(1, 256, 3)

    # an additional global rotation applied to all frames
    random_rotation = rot(*torch.randn(3))

    attn_out = attn(
        seq,
        pairwise_repr = pairwise_repr,
        rotations = rotations,
        translations = translations,
        mask = mask
    )

    rotated_attn_out = attn(
        seq,
        pairwise_repr = pairwise_repr,
        rotations = rotations @ random_rotation,
        translations = translations @ random_rotation,
        mask = mask
    )

    # Compare magnitudes: a signed `.max()` would let arbitrarily large negative
    # deviations pass the check.
    diff = (attn_out - rotated_attn_out).abs().max()
    assert diff <= 1e-6, 'must be invariant to global rotation'
.\lucidrains\isab-pytorch\isab_pytorch\isab_pytorch.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch.nn.functional 中导入 F 模块
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
def exists(val):
    """Report whether an optional value was supplied (i.e. is not None)."""
    if val is None:
        return False
    return True
class Attention(nn.Module):
    """Multi-head attention from queries `x` to keys/values computed from `context`.

    With ``and_self_attend = True``, `x` is prepended to `context`, so the queries
    attend over themselves as well as the context.
    """
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        and_self_attend = False
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.and_self_attend = and_self_attend

        inner_dim = dim_head * heads

        # query projection for x; fused key/value projection for the context; output projection
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(
        self,
        x,
        context,
        mask = None
    ):
        """Attend from `x` to `context`; `mask` (b, n_context) marks valid context positions."""
        num_heads = self.heads
        has_mask = mask is not None

        if self.and_self_attend:
            # keys/values also cover the queries themselves
            context = torch.cat((x, context), dim = -2)
            if has_mask:
                # the prepended query positions are always valid: pad with True on the left
                mask = F.pad(mask, (x.shape[-2], 0), value = True)

        queries = self.to_q(x)
        keys, values = self.to_kv(context).chunk(2, dim = -1)

        split_heads = lambda t: rearrange(t, 'b n (h d) -> b h n d', h = num_heads)
        q, k, v = split_heads(queries), split_heads(keys), split_heads(values)

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if has_mask:
            # masked-out keys get the most negative finite logit
            sim.masked_fill_(~rearrange(mask, 'b n -> b 1 1 n'), -torch.finfo(sim.dtype).max)

        weights = sim.softmax(dim = -1)
        out = einsum('b h i j, b h j d -> b h i d', weights, v)

        # merge heads back into the feature dimension
        merged = rearrange(out, 'b h n d -> b n (h d)', h = num_heads)
        return self.to_out(merged)
class ISAB(nn.Module):
    """Induced set attention block: the input set is routed through a set of latents.

    The latents are either learned inside the module (when `num_latents` is given)
    or passed to `forward` externally; exactly one of the two must be used.
    """
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        num_latents = None,
        latent_self_attend = False
    ):
        super().__init__()

        # learned latents, only when a latent count is configured
        if exists(num_latents):
            self.latents = nn.Parameter(torch.randn(num_latents, dim))
        else:
            self.latents = None

        # latents -> input (optionally latents also attend to themselves)
        self.attn1 = Attention(dim, heads, and_self_attend = latent_self_attend)
        # input -> latents
        self.attn2 = Attention(dim, heads)

    def forward(self, x, latents = None, mask = None):
        """Return ``(out, latents)``: the updated input set and the latents after attending to it."""
        batch, *_ = x.shape

        assert exists(latents) ^ exists(self.latents), 'you can only either learn the latents within the module, or pass it in externally'

        if not exists(latents):
            latents = self.latents

        # shared (n, d) latents are broadcast across the batch
        if latents.ndim == 2:
            latents = repeat(latents, 'n d -> b n d', b = batch)

        # step 1: latents summarise the (optionally masked) input set
        latents = self.attn1(latents, x, mask = mask)
        # step 2: every input element reads back from the summarised latents
        out = self.attn2(x, latents)
        return out, latents