Lucidrains 系列项目源码解析(五十五)
.\lucidrains\MaMMUT-pytorch\mammut_pytorch\__init__.py
# 从 mammut_pytorch 包中导入 MaMMUT 类
from mammut_pytorch.mammut_pytorch import MaMMUT

MaMMUT - Pytorch
Implementation of MaMMUT, a simple vision-encoder text-decoder architecture for multimodal tasks from Google, in Pytorch. Blog post
This work is basically just a simplified CoCa. I copied the code from this repository and made the change in the paper, which was to simply do two passes through the text encoder, one with cross attention for the generative loss, and the other without for the contrastive loss.
This is also a good time to plug an open sourced version of CoCa from the folks at OpenCLIP!
Appreciation
- Stability and 🤗 Huggingface for their generous sponsorships to work on and open source cutting edge artificial intelligence research
Install
$ pip install mammut-pytorch
Usage
First install the vit-pytorch for the image encoder, which needs to be pretrained
$ pip install "vit-pytorch>=0.40.2"
Then
import torch
# import vision transformer
from vit_pytorch.simple_vit_with_patch_dropout import SimpleViT
from vit_pytorch.extractor import Extractor
vit = SimpleViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
patch_dropout = 0.5 # https://arxiv.org/abs/2212.00794
)
vit = Extractor(vit, return_embeddings_only = True, detach = False)
# extractor will enable it so the vision transformer returns its embeddings
# import MaMMUT and instantiate it
from mammut_pytorch.mammut_pytorch import MaMMUT
mammut = MaMMUT(
dim = 512, # model dimension
img_encoder = vit, # vision transformer - image encoder, returning image embeddings as (batch, seq, dim)
image_dim = 1024, # image embedding dimension, if not the same as model dimensions
num_tokens = 20000, # number of text tokens
depth = 6, # depth of the transformer
dim_head = 64, # dimension per attention head
heads = 8, # number of attention heads
caption_loss_weight = 1., # weight on the autoregressive caption loss
contrastive_loss_weight = 1., # weight on the contrastive loss between image and text CLS embeddings
).cuda()
# mock text and images
text = torch.randint(0, 20000, (4, 512)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# train by giving MaMMUT your text and images with `return_loss = True`
loss = mammut(
text = text,
images = images,
return_loss = True # set this to True to get the full caption + contrastive loss
)
loss.backward()
# do the above for as much text and images...
# then you can get the caption logits as so
logits = mammut(
text = text,
images = images
) # (4, 512, 20000)
# and the CLIP-like text and image embeddings as
text_embeds, image_embeds = mammut(
text = text,
images = images,
return_embeddings = True
) # (4, 512), (4, 512)
One of the main findings of the paper is that different tasks perform differently depending on the amount of cross attention. This repository will give you full control over how much cross attention you want to place in the network.
mammut = MaMMUT(
dim = 512,
img_encoder = vit,
image_dim = 1024,
num_tokens = 20000,
depth = 6,
cross_attend_every = 2, # say you want to cross attend only every 2 layers
dim_head = 64,
heads = 8,
caption_loss_weight = 1.,
contrastive_loss_weight = 1.
).cuda()
# or you can finely specify which layers to do cross attention
mammut = MaMMUT(
dim = 512,
img_encoder = vit,
image_dim = 1024,
num_tokens = 20000,
depth = 6,
cross_attend_layers = (4, 5, 6), # only last three layers have cross attention
dim_head = 64,
heads = 8,
caption_loss_weight = 1.,
contrastive_loss_weight = 1.
).cuda()
Todo
- offer masked mean pooling of text embeddings and mean pooling for images for contrastive latents
Citations
@article{Kuo2023MaMMUTAS,
title = {MaMMUT: A Simple Architecture for Joint Learning for MultiModal Tasks},
author = {Weicheng Kuo and A. J. Piergiovanni and Dahun Kim and Xiyang Luo and Benjamin Caine and W. Li and Abhijit S. Ogale and Luowei Zhou and Andrew M. Dai and Zhifeng Chen and Claire Cui and Anelia Angelova},
journal = {ArXiv},
year = {2023},
volume = {abs/2303.16839}
}
@inproceedings{Chowdhery2022PaLMSL,
title = {PaLM: Scaling Language Modeling with Pathways},
author = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
year = {2022}
}
.\lucidrains\MaMMUT-pytorch\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# declare the distribution metadata for the MaMMUT-pytorch package
setup(
name = 'MaMMUT-pytorch', # distribution name
packages = find_packages(exclude=[]), # discover every package, excluding none
version = '0.0.7', # release version
license='MIT', # license
description = 'MaMMUT - Pytorch', # short description
author = 'Phil Wang', # author
author_email = 'lucidrains@gmail.com', # author email
long_description_content_type = 'text/markdown', # content type of the long description
url = 'https://github.com/lucidrains/MaMMUT-pytorch', # project URL
keywords = [ # keyword list
'artificial intelligence',
'deep learning',
'multimodal',
'attention mechanism',
'contrastive learning'
],
install_requires=[ # install-time dependencies
'einops>=0.6.1',
'torch>=1.6',
],
classifiers=[ # trove classifiers
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\marge-pytorch\marge_pytorch\autoregressive_wrapper.py
from functools import partial
import torch
import random
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
# fall back to `default` only when `value` is None
def default(value, default):
    """Return ``value`` unless it is ``None``, in which case return ``default``."""
    if value is None:
        return default
    return value
# logarithm with a small offset so that zero entries do not produce -inf
def log(t, eps=1e-9):
    """Return the elementwise natural log of ``t + eps``."""
    shifted = t + eps
    return shifted.log()
# filter logits with the top-p (nucleus) strategy
def top_p(logits, thres = 0.9):
    """Mask low-probability logits with -inf using a cumulative-probability cutoff.

    Each row is sorted in descending order; entries are dropped once the
    cumulative softmax probability of the entries *before* them exceeds
    ``1 - thres``. The highest-scoring entry of every row is always kept.
    Indexing assumes a 2-D ``logits`` tensor.
    """
    ranked, order = logits.sort(descending=True)
    cumulative = ranked.softmax(dim=-1).cumsum(dim=-1)
    over_budget = cumulative > 1.0 - thres

    # shift the flags right by one so the entry that crosses the cutoff
    # survives, and never drop the first (best) entry
    keep_first = torch.zeros_like(over_budget[:, :1])
    drop = torch.cat((keep_first, over_budget[:, :-1]), dim=1)

    ranked = ranked.masked_fill(drop, float('-inf'))
    return ranked.scatter(1, order, ranked)
# filter logits with the top-k strategy
def top_k(logits, thres = 0.9):
    """Keep the highest-scoring logits of each row and set the rest to -inf.

    Args:
        logits: tensor of scores; values are scattered along dim 1, so a 2-D
            (row, vocabulary) layout is expected.
        thres: fraction of the vocabulary to discard. ``int((1 - thres) * vocab)``
            entries per row survive, clamped to at least one.

    Returns:
        A tensor shaped like ``logits`` holding the surviving scores in their
        original positions and ``-inf`` everywhere else.
    """
    vocab_size = logits.shape[-1]
    # Without the clamp a small vocabulary truncates k to 0, every logit
    # becomes -inf and sampling degenerates; the upper clamp keeps topk valid
    # for thresholds below zero.
    k = min(max(int((1 - thres) * vocab_size), 1), vocab_size)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
# autoregressive wrapper
class AutoregressiveWrapper(nn.Module):
    """Wrap a token-level network with a next-token loss and a sampling loop.

    The wrapped ``net`` must expose ``max_seq_len`` and return a tuple whose
    first element is the logits, indexed as (batch, position, token).
    """

    def __init__(self, net, ignore_index = None, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        # targets equal to ignore_index are skipped by the loss; defaults to the pad value
        self.ignore_index = default(ignore_index, pad_value)

        self.net = net
        self.max_seq_len = net.max_seq_len

    # sequence generation
    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        """Sample up to ``seq_len`` tokens following ``start_tokens``.

        Args:
            start_tokens: 1-D or 2-D tensor of prompt token ids.
            seq_len: maximum number of tokens to sample.
            eos_token: if given, stop once every row sampled it in the same step.
            temperature: divisor applied to the filtered logits before sampling.
            filter_logits_fn: logits filter called as ``fn(logits, thres = ...)``.
            filter_thres: threshold forwarded to ``filter_logits_fn``.
            **kwargs: forwarded to the wrapped network; ``src_mask`` is consumed
                here and extended as tokens are appended.

        Returns:
            Only the newly sampled tokens, with the batch dimension squeezed
            away again when ``start_tokens`` was 1-D.
        """
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        t = start_tokens.shape[1]

        self.net.eval()
        out = start_tokens
        input_mask = kwargs.pop('src_mask', None)

        if input_mask is None:
            input_mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)

        try:
            for _ in range(seq_len):
                # only the most recent max_seq_len positions are fed to the network
                x = out[:, -self.max_seq_len:]
                input_mask = input_mask[:, -self.max_seq_len:]

                logits, *_ = self.net(x, src_mask=input_mask, **kwargs)
                logits = logits[:, -1, :]
                filtered_logits = filter_logits_fn(logits, thres = filter_thres)

                # Gumbel-max sampling: add Gumbel noise and take the argmax
                gumbel_noise = -log(-log(torch.zeros_like(filtered_logits).uniform_(0, 1)))
                sample = ((filtered_logits / temperature) + gumbel_noise).argmax(dim=-1)

                out = torch.cat((out, sample[:, None]), dim=-1)
                # the new token is appended on the right, so the mask must grow
                # on the right as well to stay aligned with the tokens
                input_mask = F.pad(input_mask, (0, 1), value=True)

                if eos_token is not None and (sample == eos_token).all():
                    break
        finally:
            # restore the caller's train/eval mode even if sampling failed
            self.net.train(was_training)

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        return out

    # forward pass
    def forward(self, x, *args, **kwargs):
        """Return ``(loss, *rest)`` for next-token prediction on ``x``.

        The network sees ``x[:, :-1]`` and is scored against ``x[:, 1:]`` with
        cross entropy. An ``input_mask`` keyword, when present, must have the
        same leading two dimensions as ``x`` and is truncated the same way
        before being forwarded.
        """
        m = kwargs.pop('input_mask', None)
        xi, xo = x[:, :-1], x[:, 1:]

        if m is not None:
            assert m.shape == x.shape[0:2], 'input mask must be the same shape as the input of the auto-regressive wrapper to automatically handle'
            kwargs.update(input_mask = m[:, :-1])

        out, *rest = self.net(xi, *args, **kwargs)

        # cross_entropy expects the class dimension second
        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        return (loss, *rest)
.\lucidrains\marge-pytorch\marge_pytorch\marge_pytorch.py
# 导入必要的库
import faiss
import math
import numpy as np
from tqdm import tqdm
from einops import rearrange, repeat
from functools import partial
from contextlib import contextmanager
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn, einsum
import torch.nn.functional as F
from marge_pytorch.autoregressive_wrapper import AutoregressiveWrapper
# 定义一些辅助函数
# pass-through: hand back the first argument untouched
def identity(x, *_args, **_kwargs):
    """Return ``x`` unchanged, ignoring any further arguments."""
    return x
# a value "exists" when it is anything other than None
def exists(x):
    """Return True unless ``x`` is ``None``."""
    return not (x is None)
# use the value when it exists, otherwise the fallback
def default(x, d):
    """Return ``x`` when it is not ``None``; otherwise return ``d``."""
    if exists(x):
        return x
    return d
# split the index range [0, l) into consecutive slices
def chunk(chunk_size, l):
    """Yield ``slice`` objects covering ``range(l)`` in steps of ``chunk_size``.

    The final slice is shortened so that it never extends past ``l``.
    """
    yield from (
        slice(start, min(start + chunk_size, l))
        for start in range(0, l, chunk_size)
    )
# most negative finite value representable in the tensor's dtype
def max_neg_value(tensor):
    """Return the lowest finite value of ``tensor.dtype`` (used as a mask fill)."""
    info = torch.finfo(tensor.dtype)
    return info.min
# context manager around np.memmap
@contextmanager
def memmap(*args, **kwargs):
    """Open an ``np.memmap`` with the given arguments and yield it.

    The local reference to the map is dropped once the ``with`` body finishes.
    """
    mapped = np.memmap(*args, **kwargs)
    yield mapped
    # release this function's reference to the mapping
    del mapped
# attention distillation loss
def distill_attn_loss(evi_dots, doc_similarities, mask = None, eps = 1e-5):
    """KL divergence between retrieval scores and averaged cross-attention scores.

    Args:
        evi_dots: pre-softmax cross-attention scores laid out as
            ``b l h i n j``; the ``l``, ``h`` and ``i`` axes are merged and
            averaged away together with ``j``.
        doc_similarities: retrieval logits with one score per evidence (``b n``).
        mask: optional boolean mask laid out as ``b n j``; positions that are
            False are excluded from the average.
        eps: added to the mask count to avoid dividing by zero.

    Returns:
        Scalar loss, ``kl_div`` with ``reduction = 'batchmean'``. The attention
        distribution is detached, so gradients only reach ``doc_similarities``.
    """
    evi_dots = rearrange(evi_dots, 'b l h i n j -> b (l h i) n j')

    if exists(mask):
        mask = rearrange(mask, 'b n j -> b () n j')
        # out-of-place fill: the rearranged tensor can be a view of the
        # caller's tensor, which must not be modified
        evi_dots = evi_dots.masked_fill(~mask, 0.)
        denom = mask.expand_as(evi_dots).sum(dim = (1, -1))
        evi_dots_mean = evi_dots.sum(dim = (1, -1)) / (denom + eps)
    else:
        evi_dots_mean = evi_dots.mean(dim = (1, -1))

    # target distribution, treated as a constant
    normed_evi_dots = evi_dots_mean.softmax(dim = -1).detach()

    # log_softmax avoids the -inf that softmax().log() produces when a
    # probability underflows to zero
    log_doc_similarities = F.log_softmax(doc_similarities, dim = -1)

    loss = F.kl_div(log_doc_similarities, normed_evi_dots, reduction = 'batchmean')
    return loss
# helper classes

# layer normalisation applied before the wrapped function
class PreNorm(nn.Module):
    """Apply ``nn.LayerNorm(dim)`` to the input, then call ``fn`` on the result."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        # extra positional and keyword arguments go to the wrapped function untouched
        normed = self.norm(x)
        return self.fn(normed, *args, **kwargs)
# GEGLU activation
class GEGLU(nn.Module):
    """Split the last dimension in half and gate the first half with GELU of the second."""

    def forward(self, x):
        values, gates = x.chunk(2, dim = -1)
        activated_gates = F.gelu(gates)
        return activated_gates * values
# feed-forward network
class FeedForward(nn.Module):
    """Two linear layers with a GEGLU gate and dropout in between."""

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        # scale the multiplier by 2/3 to keep parameters and compute roughly
        # constant relative to the non-GLU variant
        scaled_mult = int(mult / 3 * 2)
        inner_dim = dim * scaled_mult

        layers = [
            nn.Linear(dim, inner_dim * 2),  # doubled width: values and gates
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
# self attention
class SelfAttention(nn.Module):
    """Multi-head self attention with optional padding and causal masking."""

    def __init__(self, dim, heads = 8, causal = True, dropout = 0.):
        super().__init__()
        # NOTE(review): the scale is derived from the full model dim, not the per-head dim
        self.scale = dim ** -0.5
        self.heads = heads
        self.causal = causal
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.to_out = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask = None):
        """Attend over ``x`` (``b n d``); ``mask`` (``b n``) marks positions to keep."""
        _, n, _ = x.shape

        # one projection, split into queries / keys / values, then into heads
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
            for part in self.to_qkv(x).chunk(3, dim = -1)
        )

        dots = einsum('bhid,bhjd->bhij', q, k) * self.scale
        fill_value = max_neg_value(dots)

        if exists(mask):
            # a (query, key) pair is usable only when both positions are unmasked
            pair_mask = mask[:, None, :, None] * mask[:, None, None, :]
            dots = dots.masked_fill(~pair_mask, fill_value)

        if self.causal:
            # strictly upper-triangular entries are future positions
            future = torch.ones(n, n, device = x.device).triu(1).bool()
            dots = dots.masked_fill(future, fill_value)

        attn = self.dropout(dots.softmax(dim = -1))

        merged = rearrange(einsum('bhij,bhjd->bhid', attn, v), 'b h n d -> b n (h d)')
        return self.to_out(merged)
class CrossAttention(nn.Module):
    """Multi-head cross attention over several evidence documents.

    Each document's retrieval similarity, scaled by the learned ``beta``, is
    added to the attention scores of every token of that document.
    """

    def __init__(self, dim, heads = 8, dropout = 0.):
        super().__init__()
        # NOTE(review): the scale is derived from the full model dim, not the per-head dim
        self.scale = dim ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(dim, dim, bias = False)       # queries from the decoder stream
        self.to_kv = nn.Linear(dim, dim * 2, bias = False)  # keys and values from the context
        self.beta = nn.Parameter(torch.tensor(1.), requires_grad=True)  # learned weight on the similarity bias
        self.to_out = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context, doc_similarities, mask = None, context_mask = None):
        """Return ``(output, pre_attn_dots)``.

        ``x`` is ``b n d``, ``context`` is ``b m n d`` (``m`` documents),
        ``doc_similarities`` is ``b m``, ``mask`` is ``b n`` and
        ``context_mask`` is ``b m n``. ``pre_attn_dots`` are the scaled scores
        before the similarity bias and masking are applied.
        """
        b, n, _ = x.shape
        h, device = self.heads, x.device

        q = rearrange(self.to_q(x), 'b n (h d) -> b h n d', h = h)

        # flatten the documents into one long context sequence
        context_len = context.shape[2]
        context = rearrange(context, 'b m n d -> b (m n) d')
        if exists(context_mask):
            context_mask = rearrange(context_mask, 'b m n -> b (m n)')

        # give every token of a document that document's similarity score
        bias = repeat(doc_similarities, 'b m -> b (m n)', n = context_len)
        bias = bias[:, None, None, :] * self.beta

        keys, values = self.to_kv(context).chunk(2, dim = -1)
        k = rearrange(keys, 'b n (h d) -> b h n d', h = h)
        v = rearrange(values, 'b n (h d) -> b h n d', h = h)

        pre_attn_dots = einsum('bhid,bhjd->bhij', q, k) * self.scale
        dots = pre_attn_dots + bias

        if exists(mask) or exists(context_mask):
            # a missing mask is treated as "everything visible"
            if not exists(mask):
                mask = torch.ones((b, n), dtype = torch.bool, device = device)
            if not exists(context_mask):
                context_mask = torch.ones(context.shape[:2], dtype = torch.bool, device = device)

            cross_mask = mask[:, None, :, None] * context_mask[:, None, None, :]
            dots = dots.masked_fill(~cross_mask, max_neg_value(dots))

        attn = self.dropout(dots.softmax(dim = -1))

        out = einsum('bhij,bhjd->bhid', attn, v)
        out = self.to_out(rearrange(out, 'b h n d -> b n (h d)'))
        return out, pre_attn_dots
class Encoder(nn.Module):
    """Bidirectional encoder whose first ``retrieval_depth`` layers also produce a CLS embedding.

    Args:
        dim: model dimension.
        depth: total number of layers; must exceed ``retrieval_depth``.
        retrieval_depth: number of leading layers after which the CLS token is read out.
        heads: number of attention heads.
        ff_mult: feed-forward width multiplier.
        attn_dropout: dropout applied to the attention weights.
        ff_dropout: dropout applied inside the feed-forward blocks.
    """

    def __init__(self, dim, depth, retrieval_depth = 4, heads = 8, ff_mult = 4, attn_dropout = 0., ff_dropout = 0.):
        super().__init__()
        assert depth > retrieval_depth, f'Depth must be greater than the depth set for the retrieval encoder ({retrieval_depth})'

        # `heads` and `ff_dropout` were previously accepted but never forwarded,
        # so non-default values were silently ignored
        def block():
            return nn.ModuleList([
                PreNorm(dim, SelfAttention(dim, heads = heads, causal = False, dropout = attn_dropout)),
                PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout))
            ])

        # learned CLS token, prepended to every sequence
        self.cls = nn.Parameter(torch.zeros(1, dim), requires_grad=True)

        self.encoder_head = nn.ModuleList([block() for _ in range(retrieval_depth)])
        self.encoder_tail = nn.ModuleList([block() for _ in range(depth - retrieval_depth)])

    def forward(self, x, src_mask = None, return_embed_only = False):
        """Encode ``x`` (``b n d``).

        Returns ``(cls_tokens, None)`` when ``return_embed_only`` is set,
        otherwise ``(token_encodings, cls_tokens)`` where the encodings have
        the CLS position removed again.
        """
        b, _, _ = x.shape

        # prepend the CLS token and extend the mask so it is always visible
        cls_token = repeat(self.cls, 'n d -> b n d', b=b)
        x = torch.cat((cls_token, x), dim=1)
        src_mask = F.pad(src_mask, (1, 0), value=True) if exists(src_mask) else None

        for attn, ff in self.encoder_head:
            x = attn(x, mask = src_mask) + x
            x = ff(x) + x

        # the retrieval embedding is read after the head layers only
        cls_tokens = x[:, 0]

        if return_embed_only:
            return cls_tokens, None

        for attn, ff in self.encoder_tail:
            x = attn(x, mask = src_mask) + x
            x = ff(x) + x

        return x[:, 1:], cls_tokens
class Decoder(nn.Module):
    """Causal decoder: ``head_depth`` self-attention layers, then layers with cross attention.

    Args:
        dim: model dimension.
        depth: total number of layers.
        head_depth: number of leading layers without cross attention.
        heads: number of attention heads.
        ff_mult: width multiplier of the feed-forward that follows cross attention.
        attn_dropout: dropout applied to the attention weights.
        ff_dropout: dropout applied inside the feed-forward blocks.
    """

    def __init__(self, dim, depth, head_depth = 4, heads = 8, ff_mult = 4, attn_dropout = 0., ff_dropout = 0.):
        super().__init__()
        self.decoder_head = nn.ModuleList([])
        self.decoder_tail = nn.ModuleList([])

        # `heads` and `ff_dropout` were previously accepted but never forwarded,
        # so non-default values were silently ignored
        for _ in range(head_depth):
            self.decoder_head.append(nn.ModuleList([
                PreNorm(dim, SelfAttention(dim, heads = heads, causal = True, dropout = attn_dropout)),
                PreNorm(dim, FeedForward(dim, dropout = ff_dropout))
            ]))

        for _ in range(depth - head_depth):
            self.decoder_tail.append(nn.ModuleList([
                PreNorm(dim, SelfAttention(dim, heads = heads, causal = True, dropout = attn_dropout)),
                PreNorm(dim, FeedForward(dim, dropout = ff_dropout)),
                PreNorm(dim, CrossAttention(dim, heads = heads, dropout = attn_dropout)),
                # only the feed-forward after cross attention uses ff_mult
                PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout))
            ]))

    def forward(self, x, *, context, similarities, src_mask = None, context_mask = None):
        """Return ``(x, cross_pre_attns)``.

        ``cross_pre_attns`` lists, per cross-attention layer, the attention
        scores that layer returned alongside its output.
        """
        for self_attn, self_ff in self.decoder_head:
            x = self_attn(x, mask = src_mask) + x
            x = self_ff(x) + x

        cross_pre_attns = []

        for self_attn, self_ff, cross_attn, cross_ff in self.decoder_tail:
            x = self_attn(x, mask = src_mask) + x
            x = self_ff(x) + x

            x_out, attn = cross_attn(x, context, similarities, mask = src_mask, context_mask = context_mask)
            x = x_out + x
            x = cross_ff(x) + x

            cross_pre_attns.append(attn)

        return x, cross_pre_attns
class TransformerWrapper(nn.Module):
    """Add token and positional embeddings (and optionally a logits head) around ``layers``."""

    def __init__(self, num_tokens, dim, max_seq_len, layers, return_logits = False):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)
        self.max_seq_len = max_seq_len
        self.layers = layers

        # without a logits head the hidden states are returned as they are
        if return_logits:
            self.to_logits = nn.Linear(dim, num_tokens)
        else:
            self.to_logits = identity

    def forward(self, x, *args, **kwargs):
        """Embed token ids ``x`` (``b n``), run ``layers`` and return ``(output, *extras)``."""
        _, n = x.shape
        assert n <= self.max_seq_len, f'your sequence length {n} needs to be less than or equal to the max sequence length {self.max_seq_len}'

        positions = torch.arange(n, device = x.device)
        embedded = self.token_emb(x) + self.pos_emb(positions)

        # `layers` returns the hidden states first, followed by any extra outputs
        hidden, *extras = self.layers(embedded, *args, **kwargs)
        return (self.to_logits(hidden), *extras)
class Marge(nn.Module):
    """Encoder/decoder model that reconstructs a target document from evidence documents.

    The encoder and decoder share one token embedding; the decoder is wrapped
    in ``AutoregressiveWrapper`` so that calling it yields a loss.
    """

    def __init__(
        self,
        dim,
        num_tokens = 20000,
        max_seq_len = 1024,
        enc_depth = 12,
        enc_retrieval_depth = 4,
        enc_heads = 8,
        enc_ff_mult = 4,
        enc_attn_dropout = 0.,
        enc_ff_dropout = 0.,
        dec_depth = 12,
        dec_heads = 8,
        dec_ff_mult = 16,
        dec_attn_dropout = 0.,
        dec_ff_dropout = 0.,
        distill_attn = False,
        distill_loss_coef = 1.
    ):
        super().__init__()
        self.dim = dim

        self.encoder = TransformerWrapper(num_tokens, dim, max_seq_len, Encoder(dim, depth = enc_depth, retrieval_depth = enc_retrieval_depth, heads = enc_heads, ff_mult = enc_ff_mult, attn_dropout = enc_attn_dropout, ff_dropout = enc_ff_dropout))
        self.decoder = TransformerWrapper(num_tokens, dim, max_seq_len, Decoder(dim, depth = dec_depth, heads = dec_heads, ff_mult = dec_ff_mult, attn_dropout = dec_attn_dropout, ff_dropout = dec_ff_dropout), return_logits = True)

        # share the token embedding between encoder and decoder
        self.encoder.token_emb = self.decoder.token_emb

        # wrap the decoder so it computes the autoregressive loss / can sample
        self.decoder = AutoregressiveWrapper(self.decoder)

        # experimental attention distillation settings
        self.distill_attn = distill_attn
        self.distill_loss_coef = distill_loss_coef

    def get_embeds(self, documents, batch_size = 16, masks = None):
        """Return L2-normalised CLS embeddings of ``documents``, computed in batches."""
        embeds = []

        batched_documents = documents.split(batch_size)
        batched_masks = masks.split(batch_size) if exists(masks) else ([None] * len(batched_documents))

        for docs, mask in zip(batched_documents, batched_masks):
            embed, *_ = self.encoder(docs, src_mask = mask, return_embed_only = True)
            embeds.append(embed)

        embeds = torch.cat(embeds)
        return F.normalize(embeds, dim=-1)

    @torch.no_grad()
    def generate(self, prime, seq_len, evidence, mask = None, similarities = None):
        """Sample ``seq_len`` tokens after ``prime`` conditioned on ``evidence`` (``b m n``).

        ``mask`` (``b m n``) marks visible evidence tokens; ``similarities``
        (``b m``) defaults to all ones.
        """
        b, num_evidences, *_ = evidence.shape
        evidence = rearrange(evidence, 'b m n -> (b m) n')
        enc_src_mask = rearrange(mask, 'b m n -> (b m) n') if exists(mask) else None

        encodings, evidence_embeds = self.encoder(evidence, src_mask = enc_src_mask)
        encodings = rearrange(encodings, '(b m) n d -> b m n d', m = num_evidences)

        # build the default on the evidence's device instead of assuming CUDA
        similarities = similarities if exists(similarities) else torch.ones((b, num_evidences), device = evidence.device).float()

        # the encoder strips the CLS position from its output, so the evidence
        # mask already lines up with the encodings; padding it by one would
        # make the cross-attention mask one position too long per document
        context_mask = mask if exists(mask) else None
        return self.decoder.generate(prime, seq_len, context = encodings, similarities = similarities, context_mask = context_mask)

    def forward(self, evidence, target, target_embeds, src_mask = None, tgt_mask = None):
        """Return the reconstruction loss of ``target`` given ``evidence``.

        ``evidence`` is ``b m n`` token ids, ``target_embeds`` is ``b d``;
        their dot products with the evidence CLS embeddings are the document
        similarities. When ``distill_attn`` is set, the weighted attention
        distillation loss is added.
        """
        num_evidences = evidence.shape[1]
        evidence = rearrange(evidence, 'b m n -> (b m) n')
        enc_src_mask = rearrange(src_mask, 'b m n -> (b m) n') if exists(src_mask) else None

        encodings, evidence_embeds = self.encoder(evidence, src_mask = enc_src_mask)
        encodings = rearrange(encodings, '(b m) n d -> b m n d', m = num_evidences)
        evidence_embeds = rearrange(evidence_embeds, '(b m) d -> b m d', m = num_evidences)

        similarities = einsum('bmd,bd->bm', evidence_embeds, target_embeds)

        # the wrapper feeds target[:, :-1] to the decoder, so trim the mask to match
        dec_src_mask = tgt_mask[:, :-1] if exists(tgt_mask) else None
        loss, cross_attns = self.decoder(target, context = encodings, similarities = similarities, src_mask = dec_src_mask, context_mask = src_mask)

        if self.distill_attn:
            cross_attns = torch.stack(cross_attns, dim = 1)
            cross_attns = rearrange(cross_attns, 'b l h i (n j) -> b l h i n j', n = num_evidences)
            distill_loss = distill_attn_loss(cross_attns, similarities, mask = src_mask)
            aux_loss = self.distill_loss_coef * distill_loss
            loss = loss + aux_loss

        return loss
# training related classes
# drop the target document from its own list of retrieved evidence
def remove_target_from_evidence(evidence_ids, target_ids):
    """Remove one id per row from ``evidence_ids`` and return shape ``(b, n - 1)``.

    In each row the entry equal to that row's target id is removed; rows that
    do not contain their target lose their last entry instead.
    """
    num_rows, num_cols = evidence_ids.shape

    # True where an evidence id is the row's own target
    is_target = evidence_ids == target_ids[:, None]

    # rows without a hit give up their last column
    drop_last = np.zeros((num_rows, num_cols), dtype=bool)
    drop_last[:, -1:] = ~is_target.any(axis=-1, keepdims=True)

    keep = ~(is_target | drop_last)
    return evidence_ids[keep].reshape(num_rows, num_cols - 1)
# document dataset
class DocumentDataset(Dataset):
    """Dataset of (target, target mask, evidence, evidence mask) backed by memmaps.

    Evidence for each target is looked up through a k-nearest-neighbour
    memmap that must be attached with ``set_knn_path`` before indexing.
    """

    def __init__(self, num_docs, doc_seq_len, num_evidences, documents_path, masks_path, num_targets, target_seq_len, target_path, target_masks_path):
        super().__init__()
        self.shape = (num_docs, doc_seq_len)
        self.target_shape = (num_targets, target_seq_len)
        self.knn_shape = (num_targets, num_evidences)
        self.documents = np.memmap(documents_path, dtype=np.int32, shape=self.shape)
        self.targets = np.memmap(target_path, dtype=np.int32, shape=self.target_shape)
        # np.bool_ instead of the np.bool alias, which NumPy 1.24 removed
        self.masks = np.memmap(masks_path, dtype=np.bool_, shape=self.shape) if exists(masks_path) else None
        self.target_masks = np.memmap(target_masks_path, dtype=np.bool_, shape=self.target_shape) if exists(target_masks_path) else None
        self.knn = None

    # attach the nearest-neighbour memmap
    def set_knn_path(self, path):
        """(Re)open the nearest-neighbour memmap stored at ``path``."""
        if exists(self.knn):
            del self.knn
        self.knn = np.memmap(path, dtype=np.int32, shape=self.knn_shape)

    def __len__(self):
        return self.target_shape[0]

    def __getitem__(self, ind):
        """Return target tokens, target mask, evidence tokens and evidence mask, all moved to CUDA.

        A mask memmap that was not supplied is replaced by an all-True mask.
        """
        assert exists(self.knn), 'The memmap path to the generated k nearest neighbors for evidences must be set for the dataset'

        target_data = torch.from_numpy(self.targets[ind, :]).long()
        target_masks = torch.from_numpy(self.target_masks[ind, :]) if exists(self.target_masks) else torch.ones_like(target_data).bool()

        # the stored neighbour ids select the evidence rows of the document memmap
        evidence_ids = self.knn[ind, :]
        evidence_data = torch.from_numpy(self.documents[evidence_ids, :]).long()
        evidence_masks = torch.from_numpy(self.masks[evidence_ids, :]) if exists(self.masks) else torch.ones_like(evidence_data).bool()
        return target_data.cuda(), target_masks.cuda(), evidence_data.cuda(), evidence_masks.cuda()
# approximate nearest-neighbour index built on faiss
class FaissANN:
    """IVF-PQ faiss index with an HNSW coarse quantizer, placed on all GPUs."""

    def __init__(
        self,
        dim,
        num_documents,
        num_subvectors = 16,
        hnsw_m = 32,
        nbits = 8
    ):
        super().__init__()
        # number of inverted lists: floor of the square root of the corpus size
        num_lists = math.floor(math.sqrt(num_documents))

        coarse_quantizer = faiss.IndexHNSWFlat(dim, hnsw_m)
        cpu_index = faiss.IndexIVFPQ(coarse_quantizer, dim, num_lists, num_subvectors, nbits)

        self.index = faiss.index_cpu_to_all_gpus(cpu_index)
        # number of vectors callers should sample for training the index
        self.num_training = max(num_lists * 10, 256)

    def reset(self):
        """Forward to the underlying index's ``reset``."""
        return self.index.reset()

    def train(self, x):
        """Forward to the underlying index's ``train``."""
        return self.index.train(x)

    def add(self, x):
        """Forward to the underlying index's ``add``."""
        return self.index.add(x)

    def search(self, x, topk, nprobe=8):
        """Set ``nprobe`` on the index, then search for the ``topk`` neighbours of ``x``."""
        self.index.nprobe = nprobe
        return self.index.search(x, k=topk)
# training wrapper
class TrainingWrapper(nn.Module):
    """Couple a model with memmapped documents and a nearest-neighbour index.

    On construction (and whenever ``reindex`` is called) every document is
    embedded and indexed, and the nearest evidence documents of every target
    are written to ``<documents_memmap_path>.knn``, which the dataset reads.

    Args:
        model: model exposing ``dim`` and ``get_embeds``; it is moved to CUDA.
        num_documents: number of evidence documents.
        doc_seq_len: sequence length of the evidence documents.
        documents_memmap_path: path of the evidence token memmap.
        masks_memmap_path: optional path of the evidence mask memmap.
        num_targets: number of target documents; defaults to ``num_documents``.
        target_seq_len: target sequence length; required with a separate target set.
        target_memmap_path: optional path of a separate target token memmap.
        target_masks_memmap_path: optional path of the target mask memmap.
        num_evidence: evidence documents stored per target.
        reindex_batch_size: batch size used while embedding during reindexing.
        use_faiss_ann: use the approximate ``FaissANN`` index instead of a flat L2 index.
    """

    def __init__(
        self,
        model,
        *,
        num_documents,
        doc_seq_len,
        documents_memmap_path,
        masks_memmap_path = None,
        num_targets = None,
        target_seq_len = None,
        target_memmap_path = None,
        target_masks_memmap_path = None,
        num_evidence = 4,
        reindex_batch_size = 4,
        use_faiss_ann = False
    ):
        super().__init__()
        self.dim = model.dim
        self.num_evidence = num_evidence

        self.model = model.cuda()
        self.num_docs = num_documents

        # targets default to the evidence documents themselves
        num_targets = default(num_targets, num_documents)
        self.num_targets = num_targets

        self.doc_shape = (num_documents, doc_seq_len)

        self.documents_path = documents_memmap_path
        self.separate_target_and_evidence = exists(target_memmap_path)

        if self.separate_target_and_evidence:
            assert exists(num_targets), 'number of target documents must be defined if target document set is different than evidence document set'
            assert exists(target_seq_len), 'target sequence length must be specified'
        else:
            # reuse the evidence paths and length for the targets
            target_memmap_path = default(target_memmap_path, documents_memmap_path)
            target_masks_memmap_path = default(target_masks_memmap_path, masks_memmap_path)
            target_seq_len = default(target_seq_len, doc_seq_len)

        self.target_shape = (num_targets, target_seq_len)
        self.target_path = target_memmap_path
        self.knn_path = f'{self.documents_path}.knn'

        self.use_faiss_ann = use_faiss_ann
        if use_faiss_ann:
            self.index = FaissANN(self.dim, self.num_docs)
        else:
            index = faiss.IndexFlatL2(self.dim)
            self.index = faiss.index_cpu_to_all_gpus(index)

        # the knn memmap must exist before the dataset can open it
        self.reindex_batch_size = reindex_batch_size
        self.reindex()

        self.dataset = DocumentDataset(
            num_documents,
            doc_seq_len,
            num_evidence,
            documents_memmap_path,
            masks_memmap_path,
            num_targets,
            target_seq_len,
            target_memmap_path,
            target_masks_memmap_path
        )

        self.dataset.set_knn_path(self.knn_path)

    def get_dataset(self):
        """Return the dataset created in ``__init__``."""
        return self.dataset

    @torch.no_grad()
    def reindex(self):
        """Re-embed all documents and rewrite the nearest-neighbour memmap."""
        batch_size = self.reindex_batch_size

        def get_embeds(data):
            # embeddings are handed to faiss as numpy arrays
            embeds = self.model.get_embeds(data, batch_size=batch_size)
            return embeds.detach().cpu().numpy()

        # The knn memmap holds one row per target (matching the dataset's
        # knn_shape); it used to be sized by the number of documents, which
        # is too small or too large whenever the target set differs.
        with memmap(self.documents_path, dtype=np.int32, shape=self.doc_shape) as doc_pointer, \
             memmap(self.target_path, dtype=np.int32, shape=self.target_shape) as target_pointer, \
             memmap(self.knn_path, dtype=np.int32, shape=(self.num_targets, self.num_evidence), mode='w+') as knn_writer:

            if self.use_faiss_ann:
                # train the approximate index on a random subset of the documents
                random_indices = np.random.permutation(self.num_docs)[:self.index.num_training]
                np_data = torch.from_numpy(doc_pointer[random_indices]).cuda().long()
                train_embeds = get_embeds(np_data)
                self.index.train(train_embeds)

            total_evidence_chunks = math.ceil(self.num_docs / batch_size)

            for data_slice in tqdm(chunk(batch_size, self.num_docs), total=total_evidence_chunks, desc='Adding embedding to indexes'):
                np_data = torch.from_numpy(doc_pointer[data_slice, :]).cuda().long()
                embeds = get_embeds(np_data)
                self.index.add(embeds)

            total_target_chunks = math.ceil(self.num_targets / batch_size)

            for data_slice in tqdm(chunk(batch_size, self.num_targets), total=total_target_chunks, desc='Fetching and storing nearest neighbors'):
                np_data = torch.from_numpy(target_pointer[data_slice, :]).cuda().long()

                embeds = get_embeds(np_data)
                # fetch one extra neighbour when targets are also evidence, so
                # the target itself can be removed from its own results
                fetch_num_evidences = self.num_evidence + (0 if self.separate_target_and_evidence else 1)
                _, evidence_ids = self.index.search(embeds, fetch_num_evidences)

                target_ids = np.arange(data_slice.start, data_slice.stop)

                if not self.separate_target_and_evidence:
                    evidence_ids = remove_target_from_evidence(evidence_ids, target_ids)

                knn_writer[data_slice, :] = evidence_ids

        # the index is rebuilt from scratch on the next reindex
        self.index.reset()

        print('reindexing complete')

    def forward(self, data):
        """Compute the model loss for one batch produced by the dataset."""
        targets, target_masks, evidences, evidence_masks = data

        target_embeds = self.model.get_embeds(targets, masks=target_masks)
        loss = self.model(evidences, targets, target_embeds, src_mask=evidence_masks, tgt_mask=target_masks)
        return loss
.\lucidrains\marge-pytorch\marge_pytorch\__init__.py
# 从 marge_pytorch 模块中导入 Marge 和 TrainingWrapper 类
# 从 marge_pytorch 模块中导入 AutoregressiveWrapper 类
from marge_pytorch.marge_pytorch import Marge, TrainingWrapper
from marge_pytorch.autoregressive_wrapper import AutoregressiveWrapper

Marge - Pre-training via Paraphrasing
Implementation of Marge, Pre-training via Paraphrasing, in Pytorch. It is an alternative to masked language modeling pretraining, where an encoder / decoder attention network learns to reconstruct a target document from a collection of evidence documents.
Update: Three researchers have independently reported that the repository works for them
Install
$ pip install marge-pytorch
Usage
import torch
import numpy as np
from torch.utils.data import DataLoader
from marge_pytorch import Marge, TrainingWrapper
# your documents must be tokenized and stored as memmap in the shape (num documents, seq length)
# constants
NUM_DOCS = 10000
SEQ_LEN = 1024
SHAPE = (NUM_DOCS, SEQ_LEN)
# generate mock training data
f = np.memmap('./train.dat', dtype=np.int32, mode='w+', shape=SHAPE)
f[:] = np.random.randint(0, 20000, size=SHAPE)
del f
# generate mock masking data
f = np.memmap('./train.mask.dat', dtype=np.bool_, mode='w+', shape=SHAPE)
f[:] = np.full(SHAPE, True)
del f
# instantiate model
model = Marge(
dim = 512,
num_tokens = 20000,
max_seq_len = SEQ_LEN,
enc_depth = 12,
enc_retrieval_depth = 4, # defaults to 4 as in paper (take the CLS token after the 4th layer of the encoder)
enc_heads = 8,
enc_ff_mult = 4,
dec_depth = 12,
dec_heads = 8,
dec_ff_mult = 16, # paper noted that decoder needs to have much bigger feed forward sizes
distill_attn = False, # (experimental) will add, on top of the decoder loss, an auxiliary distillation loss as defined in https://arxiv.org/abs/2012.04584
distill_loss_coef = 1. # weight of distillation auxiliary loss
)
# wrap your model and your documents
trainer = TrainingWrapper(
model,
num_documents = NUM_DOCS,
doc_seq_len = SEQ_LEN,
num_evidence = 4, # number of evidence documents to fetch per target document to construct
reindex_batch_size = 32, # batch size to use when reindexing
documents_memmap_path = './train.dat', # path to the mem-mapped documents
masks_memmap_path = './train.mask.dat', # if None is supplied, will assume all tokens are visible
use_faiss_ann = True # set this to false if you have a low number of documents, and approximate nearest neighbor is not needed
)
# instantiate dataloader
dl = DataLoader(trainer.dataset, batch_size=16)
# now you can train, and use the reindex method on the training wrapper at appropriate intervals
for ind, data in enumerate(dl):
loss = trainer(data)
loss.backward()
# optimizer step and all that
# reindex and precompute knn every 10000 steps, as in paper
if ind > 0 and ind % 10000 == 0:
trainer.reindex()
Save your model after much training
torch.save(model, f'./trained-model.pt')
Advanced
If you would like the target and evidence documents to be from different sets, you just have to pass in up to four additional keyword arguments, as shown below.
trainer = TrainingWrapper(
model,
num_documents = NUM_DOCS,
doc_seq_len = SEQ_LEN,
num_evidence = 4,
reindex_batch_size = 32,
documents_memmap_path = './evidence.dat',
masks_memmap_path = './evidence.mask.dat',
num_targets = NUM_TARGETS, # 1. number of target documents, with sequence length the same as the document (evidence)
target_seq_len = SEQ_LEN, # 2. sequence length of target documents
target_memmap_path = './target.dat', # 3. path to target memmap, same as documents (evidence)
target_masks_memmap_path = './target.mask.dat', # 4. path to target mask memmap, same as document masks (evidence)
use_faiss_ann = True
)
Sampling
You can sample from the decoder with the following instructions
# some random evidence from the dataset
# or provide your own in the dimensions (b x num_evidences x seq_len)
*_, evidence, mask = trainer.dataset[0:1]
# assume 1 is start token
prime = torch.tensor([[1.]]).long().cuda()
# supply your own document similarities array (b x num_evidences)
# if not supplied, will default to 1. for all evidence
doc_similarities = torch.ones(evidence.shape[:2]).float().cuda()
# generate sample of length 1024
samples = model.generate(prime, 1024, evidence, mask = mask, similarities = doc_similarities)
Citations
@misc{lewis2020pretraining,
title={Pre-training via Paraphrasing},
author={Mike Lewis and Marjan Ghazvininejad and Gargi Ghosh and Armen Aghajanyan and Sida Wang and Luke Zettlemoyer},
year={2020},
eprint={2006.15020},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
@misc{komatsuzaki2020current,
title={Current Limitations of Language Models: What You Need is Retrieval},
author={Aran Komatsuzaki},
year={2020},
eprint={2009.06857},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
@misc{izacard2020distilling,
title={Distilling Knowledge from Reader to Retriever for Question Answering},
author={Gautier Izacard and Edouard Grave},
year={2020},
eprint={2012.04584},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
.\lucidrains\marge-pytorch\setup.py
# Import the setuptools entry point and the package discovery helper
from setuptools import setup, find_packages
# Declare the distribution metadata for the marge-pytorch package
setup(
  name = 'marge-pytorch', # distribution name
  packages = find_packages(), # discover every package under this directory
  version = '0.2.9', # release version
  license='MIT', # license identifier
  description = 'Marge - Pytorch', # short description
  author = 'Phil Wang', # author name
  author_email = 'lucidrains@gmail.com', # author contact e-mail
  url = 'https://github.com/lucidrains/marge-pytorch', # project home page
  keywords = [ # search keywords
    'artificial intelligence',
    'attention mechanism',
    'transformers',
    'pre-training'
  ],
  install_requires=[ # runtime dependencies
    'einops>=0.3',
    'faiss-gpu',
    'numpy',
    'torch>=1.6',
    'tqdm'
  ],
  classifiers=[ # trove classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
.\lucidrains\med-seg-diff-pytorch\driver.py
import os
import argparse
from tqdm import tqdm
import torch
import numpy as np
import torchvision.transforms as transforms
from torch.optim import AdamW
from lion_pytorch import Lion
from med_seg_diff_pytorch import Unet, MedSegDiff
from med_seg_diff_pytorch.dataset import ISICDataset, GenericNpyDataset
from accelerate import Accelerator
import wandb
## Parse CLI arguments ##
def parse_args():
    """Build the command-line parser for the training driver and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding every training option defined below.
    """

    def str2bool(value):
        # ``type=bool`` would turn any non-empty string (including "False")
        # into True, so boolean values are parsed explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser()

    # logging / bookkeeping
    parser.add_argument('-slr', '--scale_lr', action='store_true', help="Whether to scale lr.")
    parser.add_argument('-rt', '--report_to', type=str, default="wandb", choices=["wandb"],
                        help="Where to log to. Currently only supports wandb")
    parser.add_argument('-ld', '--logging_dir', type=str, default="logs", help="Logging dir.")
    parser.add_argument('-od', '--output_dir', type=str, default="output", help="Output dir.")
    parser.add_argument('-mp', '--mixed_precision', type=str, default="no", choices=["no", "fp16", "bf16"],
                        help="Whether to do mixed precision")
    parser.add_argument('-ga', '--gradient_accumulation_steps', type=int, default=4,
                        help="The number of gradient accumulation steps.")

    # data location
    parser.add_argument('-img', '--img_folder', type=str, default='ISBI2016_ISIC_Part3B_Training_Data',
                        help='The image file path from data_path')
    parser.add_argument('-csv', '--csv_file', type=str, default='ISBI2016_ISIC_Part3B_Training_GroundTruth.csv',
                        help='The csv file to load in from data_path')

    # optimisation
    parser.add_argument('-sc', '--self_condition', action='store_true', help='Whether to do self condition')
    parser.add_argument('-lr', '--learning_rate', type=float, default=5e-4, help='learning rate')
    parser.add_argument('-ab1', '--adam_beta1', type=float, default=0.95,
                        help='The beta1 parameter for the Adam optimizer.')
    parser.add_argument('-ab2', '--adam_beta2', type=float, default=0.999,
                        help='The beta2 parameter for the Adam optimizer.')
    parser.add_argument('-aw', '--adam_weight_decay', type=float, default=1e-6,
                        help='Weight decay magnitude for the Adam optimizer.')
    parser.add_argument('-ae', '--adam_epsilon', type=float, default=1e-08,
                        help='Epsilon value for the Adam optimizer.')
    # ``-ul`` alone enables Lion; ``-ul True`` / ``-ul False`` are still accepted.
    parser.add_argument('-ul', '--use_lion', type=str2bool, nargs='?', const=True, default=False,
                        help='use Lion optimizer')

    # model / data shape
    parser.add_argument('-ic', '--mask_channels', type=int, default=1,
                        help='mask channels for training (default: 1)')
    parser.add_argument('-c', '--input_img_channels', type=int, default=3,
                        help='input image channels for training (default: 3)')
    parser.add_argument('-is', '--image_size', type=int, default=128, help='input image size (default: 128)')
    parser.add_argument('-dd', '--data_path', default='./data', help='directory of input image')
    parser.add_argument('-d', '--dim', type=int, default=64, help='dim (default: 64)')

    # schedule / checkpoints
    parser.add_argument('-e', '--epochs', type=int, default=10000, help='number of epochs (default: 10000)')
    parser.add_argument('-bs', '--batch_size', type=int, default=8, help='batch size to train on (default: 8)')
    parser.add_argument('--timesteps', type=int, default=1000, help='number of timesteps (default: 1000)')
    parser.add_argument('-ds', '--dataset', default='generic', help='Dataset to use')
    parser.add_argument('--save_every', type=int, default=100, help='save_every n epochs (default: 100)')
    parser.add_argument('--load_model_from', default=None, help='path to pt file to load from')

    return parser.parse_args()
def load_data(args):
    """Create the shuffled training DataLoader for the dataset named in ``args.dataset``.

    Supports 'ISIC' and 'generic'; any other name raises NotImplementedError.
    """
    if args.dataset not in ('ISIC', 'generic'):
        raise NotImplementedError(f"Your dataset {args.dataset} hasn't been implemented yet.")

    if args.dataset == 'ISIC':
        # resize to a square image, then convert to a tensor
        transform_train = transforms.Compose([
            transforms.Resize((args.image_size, args.image_size)),
            transforms.ToTensor(),
        ])
        dataset = ISICDataset(
            args.data_path,
            args.csv_file,
            args.img_folder,
            transform=transform_train,
            training=True,
            flip_p=0.5,
        )
    else:
        # npy arrays go through PIL first so that Resize can be applied
        transform_train = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(args.image_size),
            transforms.ToTensor(),
        ])
        dataset = GenericNpyDataset(args.data_path, transform=transform_train, test_flag=False)

    ## Define PyTorch data generator
    return torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
def main():
    """Train a MedSegDiff model from the command-line options.

    Builds the Unet, data loader and optimizer, hands them to ``Accelerator``,
    optionally restores a checkpoint, then runs the epoch loop. Every
    ``args.save_every`` epochs a checkpoint is written and one sampled
    prediction is logged to the wandb tracker.
    """
    args = parse_args()

    checkpoint_dir = os.path.join(args.output_dir, 'checkpoints')
    logging_dir = os.path.join(args.output_dir, args.logging_dir)
    # only the checkpoint directory is created here
    os.makedirs(checkpoint_dir, exist_ok=True)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        logging_dir=logging_dir,
    )
    if accelerator.is_main_process:
        accelerator.init_trackers("med-seg-diff", config=vars(args))

    ## DEFINE MODEL ##
    model = Unet(
        dim=args.dim,
        image_size=args.image_size,
        dim_mults=(1, 2, 4, 8),
        mask_channels=args.mask_channels,
        input_img_channels=args.input_img_channels,
        self_condition=args.self_condition
    )

    ## LOAD DATA ##
    data_loader = load_data(args)

    # optionally scale the learning rate by the effective global batch size
    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.batch_size * accelerator.num_processes
        )

    ## Initialize optimizer
    if not args.use_lion:
        optimizer = AdamW(
            model.parameters(),
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
        )
    else:
        optimizer = Lion(
            model.parameters(),
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay
        )

    ## TRAIN MODEL ##
    model, optimizer, data_loader = accelerator.prepare(
        model, optimizer, data_loader
    )
    diffusion = MedSegDiff(
        model,
        timesteps=args.timesteps
    ).to(accelerator.device)

    # optionally resume model and optimizer state from a checkpoint
    if args.load_model_from is not None:
        save_dict = torch.load(args.load_model_from)
        diffusion.model.load_state_dict(save_dict['model_state_dict'])
        optimizer.load_state_dict(save_dict['optimizer_state_dict'])
        accelerator.print(f'Loaded from {args.load_model_from}')

    ## Iterate across training loop
    for epoch in range(args.epochs):
        running_loss = 0.0
        num_samples = 0
        print('Epoch {}/{}'.format(epoch + 1, args.epochs))
        for (img, mask) in tqdm(data_loader):
            with accelerator.accumulate(model):
                loss = diffusion(mask, img)
                # accumulate the loss weighted by the batch size ...
                running_loss += loss.item() * img.size(0)
                num_samples += img.size(0)
                accelerator.log({'loss': loss})  # Log loss to wandb
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()
        # ... and normalise by the number of samples seen (not the number of
        # batches), so the value is a per-sample mean
        epoch_loss = running_loss / max(num_samples, 1)
        print('Training Loss : {:.4f}'.format(epoch_loss))

        ## INFERENCE ##
        if epoch % args.save_every == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': diffusion.model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
            }, os.path.join(checkpoint_dir, f'state_dict_epoch_{epoch}_loss_{epoch_loss}.pt'))

            # sample a prediction for the last batch of the epoch
            pred = diffusion.sample(img).cpu().detach().numpy()

            for tracker in accelerator.trackers:
                if tracker.name == "wandb":
                    # log just one image per batch
                    tracker.log(
                        {'pred-img-mask': [wandb.Image(pred[0, 0, :, :]), wandb.Image(img[0, 0, :, :]),
                                           wandb.Image(mask[0, 0, :, :])]}
                    )


if __name__ == '__main__':
    main()
.\lucidrains\med-seg-diff-pytorch\med_seg_diff_pytorch\dataset.py
import os
import numpy as np
# 设置环境变量,允许重复加载库
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import torch
from torch.utils.data import Dataset
from PIL import Image
import pandas as pd
import random
import torchvision.transforms.functional as F
# Dataset for the ISIC image / segmentation-mask pairs listed in a CSV file
class ISICDataset(Dataset):
    """Loads ``<name>.jpg`` images and ``<name>_Segmentation.png`` masks.

    The first CSV column supplies the sample names and the second the labels.
    """

    def __init__(self, data_path, csv_file, img_folder, transform=None, training=True, flip_p=0.5):
        table = pd.read_csv(os.path.join(data_path, csv_file), encoding='gbk')
        self.img_folder = img_folder
        self.name_list = table.iloc[:, 0].tolist()
        self.label_list = table.iloc[:, 1].tolist()
        self.data_path = data_path
        self.transform = transform
        self.training = training
        self.flip_p = flip_p

    def __len__(self):
        return len(self.name_list)

    def __getitem__(self, index):
        """Get the images"""
        name = self.name_list[index] + '.jpg'
        mask_name = name.split('.')[0] + '_Segmentation.png'
        folder = os.path.join(self.data_path, self.img_folder)

        img = Image.open(os.path.join(folder, name)).convert('RGB')
        mask = Image.open(os.path.join(folder, mask_name)).convert('L')

        raw_label = self.label_list[index]
        if self.training:
            label = 0 if raw_label == 'benign' else 1
        else:
            label = int(raw_label)

        if self.transform:
            # replay the same torch RNG state for the mask so that a random
            # transform is applied identically to image and mask
            rng_state = torch.get_rng_state()
            img = self.transform(img)
            torch.set_rng_state(rng_state)
            mask = self.transform(mask)
            # vertical flip of both with probability flip_p
            if random.random() < self.flip_p:
                img, mask = F.vflip(img), F.vflip(mask)

        if not self.training:
            return (img, mask, label)
        return (img, mask)
# Dataset that reads (image, mask) pairs from .npy files in a directory
class GenericNpyDataset(torch.utils.data.Dataset):
    def __init__(self, directory: str, transform, test_flag: bool = True):
        '''
        Generic dataset for loading npy files.
        The npy files store 3D arrays with the first two dimensions being the image and the third dimension being the channels.
        Channel 0 is the image and the remaining channel(s) are the label.

        directory: folder holding the .npy files ("~" is expanded).
        transform: callable applied to both image and mask, or a falsy value to skip.
        test_flag: when True, __getitem__ also returns the file name.
        '''
        super().__init__()
        self.directory = os.path.expanduser(directory)
        self.transform = transform
        self.test_flag = test_flag
        # os.listdir returns entries in arbitrary order; sort so that an index
        # always maps to the same file across runs and machines
        self.filenames = sorted(x for x in os.listdir(self.directory) if x.endswith('.npy'))

    def __getitem__(self, x: int):
        fname = self.filenames[x]
        npy_img = np.load(os.path.join(self.directory, fname))

        # channel 0 -> image, moved to channel-first layout
        image = torch.from_numpy(npy_img[:, :, :1]).permute(2, 0, 1)

        # remaining channels -> binary mask (anything > 0 becomes 1)
        mask = np.where(npy_img[:, :, 1:] > 0, 1, 0)
        mask = torch.from_numpy(mask).permute(2, 0, 1).float()

        if self.transform:
            # replay the same torch RNG state for the mask so that a random
            # transform is applied identically to image and mask
            state = torch.get_rng_state()
            image = self.transform(image)
            torch.set_rng_state(state)
            mask = self.transform(mask)

        if self.test_flag:
            return image, mask, fname
        return image, mask

    def __len__(self) -> int:
        return len(self.filenames)
.\lucidrains\med-seg-diff-pytorch\med_seg_diff_pytorch\med_seg_diff_pytorch.py
# 导入所需的库
import math
import copy
from random import random
from functools import partial
from collections import namedtuple
# 导入第三方库
from beartype import beartype
# 导入 PyTorch 库
import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from torch.fft import fft2, ifft2
# 导入 einops 库
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
# 导入 tqdm 库
from tqdm.auto import tqdm
# 定义常量
ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
# 辅助函数
# True when the value is anything other than None
def exists(x):
    return not (x is None)
# Return val unless it is None; otherwise fall back to d (called first if it is callable)
def default(val, d):
    if not exists(val):
        return d() if callable(d) else d
    return val
# Return the first argument untouched, ignoring any extra arguments
def identity(t, *_args, **_kwargs):
    return t
# 标准化函数
# Map values from the [0, 1] range to the [-1, 1] range
def normalize_to_neg_one_to_one(img):
    doubled = img * 2
    return doubled - 1
# Map values from the [-1, 1] range back to the [0, 1] range
def unnormalize_to_zero_to_one(t):
    shifted = t + 1
    return shifted * 0.5
# 小型辅助模块
# Residual wrapper
class Residual(Module):
    """Wrap a callable and add its input back onto its output."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        branch = self.fn(x, *args, **kwargs)
        return branch + x
# Upsampling: nearest-neighbour 2x upscale followed by a 3x3 convolution
def Upsample(dim, dim_out = None):
    upscale = nn.Upsample(scale_factor = 2, mode = 'nearest')
    # output channels default to the input channels
    conv = nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
    return nn.Sequential(upscale, conv)
# Downsampling: fold every 2x2 spatial patch into the channel axis, then mix with a 1x1 convolution
def Downsample(dim, dim_out = None):
    space_to_depth = Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2)
    # the fold multiplies the channel count by 4; output channels default to dim
    mix = nn.Conv2d(dim * 4, default(dim_out, dim), 1)
    return nn.Sequential(space_to_depth, mix)
# Layer normalisation over dim 1 with a learned gain and optional learned bias
class LayerNorm(Module):
    def __init__(self, dim, bias = False):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        if bias:
            self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
        else:
            self.b = None

    def forward(self, x):
        # a larger epsilon is used for anything that is not float32
        eps = 1e-3
        if x.dtype == torch.float32:
            eps = 1e-5
        mean = torch.mean(x, dim = 1, keepdim = True)
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        normed = (x - mean) * (var + eps).rsqrt()
        return normed * self.g + default(self.b, 0)
# Sinusoidal position embedding
class SinusoidalPosEmb(Module):
    """Embed a 1-D tensor of positions as concatenated sines and cosines."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        # geometric frequency ladder from 1 down to 1/10000
        log_step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device = x.device) * -log_step)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim = -1)
# 构建块模块
# Basic block: 3x3 convolution -> group norm -> optional scale/shift -> SiLU
class Block(Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        x = self.norm(self.proj(x))

        # optional modulation: x * (scale + 1) + shift
        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        return self.act(x)
# ResNet block: two Blocks plus a residual path, optionally modulated by a time embedding
class ResnetBlock(Module):
    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        # the MLP emits scale and shift, hence dim_out * 2 outputs
        if exists(time_emb_dim):
            self.mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_emb_dim, dim_out * 2)
            )
        else:
            self.mlp = None

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        # 1x1 convolution on the residual only when the channel count changes
        if dim != dim_out:
            self.res_conv = nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, time_emb = None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            modulation = rearrange(self.mlp(time_emb), 'b c -> b c 1 1')
            scale_shift = modulation.chunk(2, dim = 1)

        # only the first block receives the scale/shift modulation
        hidden = self.block1(x, scale_shift = scale_shift)
        hidden = self.block2(hidden)
        return hidden + self.res_conv(x)
# Feed-forward network: LayerNorm -> 1x1 conv (expand) -> GELU -> 1x1 conv (project back)
def FeedForward(dim, mult = 4):
    inner_dim = int(dim * mult)
    layers = [
        LayerNorm(dim),
        nn.Conv2d(dim, inner_dim, 1),
        nn.GELU(),
        nn.Conv2d(inner_dim, dim, 1),
    ]
    return nn.Sequential(*layers)
class LinearAttention(Module):
    """Linear attention over a 2-D feature map: keys and values are first
    combined into a per-head context matrix, which is then applied to the queries."""

    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.prenorm = LayerNorm(dim)
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, 1),
            LayerNorm(dim)
        )

    def forward(self, x):
        b, c, h, w = x.shape

        normed = self.prenorm(x)
        q, k, v = (
            rearrange(part, 'b (h c) x y -> b h c (x y)', h = self.heads)
            for part in self.to_qkv(normed).chunk(3, dim = 1)
        )

        # queries are normalised over the per-head channel axis, keys over positions
        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)
        q = q * self.scale

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
        return self.to_out(out)
class Attention(Module):
    """Full softmax self-attention across all spatial positions of a feature map."""

    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.prenorm = LayerNorm(dim)
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape

        normed = self.prenorm(x)
        q, k, v = (
            rearrange(part, 'b (h c) x y -> b h c (x y)', h = self.heads)
            for part in self.to_qkv(normed).chunk(3, dim = 1)
        )
        q = q * self.scale

        # similarity between every pair of positions, softmax over the key axis
        sim = einsum('b h d i, b h d j -> b h i j', q, k)
        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        return self.to_out(out)
class Transformer(Module):
    """Stack of ``depth`` layers, each a residual Attention followed by a residual FeedForward."""

    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 4,
        depth = 1
    ):
        super().__init__()
        self.layers = ModuleList([
            ModuleList([
                Residual(Attention(dim, dim_head = dim_head, heads = heads)),
                Residual(FeedForward(dim))
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        for attn, ff in self.layers:
            x = ff(attn(x))
        return x
# vision transformer for dynamic ff-parser
class ViT(Module):
    """Patch-based transformer mapping an image-like tensor to one of the same
    spatial size with ``channels_out`` channels.

    Args:
        dim: token dimension inside the transformer.
        image_size: side length of the square input; must be divisible by ``patch_size``.
        patch_size: side length of each square patch.
        channels: input channels.
        channels_out: output channels (defaults to ``channels``).
        dim_head, heads, depth: forwarded to the inner ``Transformer``.
    """

    def __init__(
        self,
        dim,
        *,
        image_size,
        patch_size,
        channels = 3,
        channels_out = None,
        dim_head = 32,
        heads = 4,
        depth = 4,
    ):
        super().__init__()
        assert exists(image_size)
        assert (image_size % patch_size) == 0

        num_patches_height_width = image_size // patch_size

        # learned positional embedding, one vector per patch position, starting at zero
        self.pos_emb = nn.Parameter(torch.zeros(dim, num_patches_height_width, num_patches_height_width))

        channels_out = default(channels_out, channels)

        patch_dim = channels * (patch_size ** 2)
        output_patch_dim = channels_out * (patch_size ** 2)

        # fold each patch into the channel axis, then project to dim
        self.to_tokens = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = patch_size, p2 = patch_size),
            nn.Conv2d(patch_dim, dim, 1),
            LayerNorm(dim)
        )

        # `heads` was previously accepted but never forwarded, so the inner
        # transformer always used its own default head count
        self.transformer = Transformer(
            dim = dim,
            dim_head = dim_head,
            heads = heads,
            depth = depth
        )

        # project tokens back to patches and unfold them into the spatial axes
        self.to_patches = nn.Sequential(
            LayerNorm(dim),
            nn.Conv2d(dim, output_patch_dim, 1),
            Rearrange('b (c p1 p2) h w -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
        )

        # zero-init the output projection so the module initially outputs zeros
        nn.init.zeros_(self.to_patches[-2].weight)
        nn.init.zeros_(self.to_patches[-2].bias)

    def forward(self, x):
        x = self.to_tokens(x)
        x = x + self.pos_emb
        x = self.transformer(x)
        return self.to_patches(x)
# Conditioning module: frequency-domain filtering of x, then used to modulate the condition features c
class Conditioning(Module):
    """Filter ``x`` in the Fourier domain with a learned (optionally dynamic)
    map, then use the result to modulate the conditioning features ``c``.

    Args:
        fmap_size: spatial side length of the learned frequency map.
        dim: channel count of both ``x`` and ``c``.
        dynamic: when True, a ViT predicts an additive, input-dependent update to the map.
        image_size, dim_head, heads, depth, patch_size: forwarded to that ViT.
    """

    def __init__(
        self,
        fmap_size,
        dim,
        dynamic = True,
        image_size = None,
        dim_head = 32,
        heads = 4,
        depth = 4,
        patch_size = 16
    ):
        super().__init__()
        # learned frequency-domain map, initialised to ones
        self.ff_parser_attn_map = nn.Parameter(torch.ones(dim, fmap_size, fmap_size))

        self.dynamic = dynamic

        if dynamic:
            # `depth` was previously accepted but never forwarded to the ViT
            self.to_dynamic_ff_parser_attn_map = ViT(
                dim = dim,
                channels = dim * 2 * 2, # both input and condition, and account for complex (real and imag components)
                channels_out = dim,
                image_size = image_size,
                patch_size = patch_size,
                heads = heads,
                dim_head = dim_head,
                depth = depth
            )

        self.norm_input = LayerNorm(dim, bias = True)
        self.norm_condition = LayerNorm(dim, bias = True)

        self.block = ResnetBlock(dim, dim)

    def forward(self, x, c):
        ff_parser_attn_map = self.ff_parser_attn_map

        # remember the input dtype so the result can be cast back after the FFT round trip
        dtype = x.dtype
        x = fft2(x)

        if self.dynamic:
            c_complex = fft2(c)
            # expose real / imaginary parts as a trailing axis and merge it into the channels
            x_as_real, c_as_real = map(torch.view_as_real, (x, c_complex))
            x_as_real, c_as_real = map(lambda t: rearrange(t, 'b d h w ri -> b (d ri) h w'), (x_as_real, c_as_real))

            to_dynamic_input = torch.cat((x_as_real, c_as_real), dim = 1)

            dynamic_ff_parser_attn_map = self.to_dynamic_ff_parser_attn_map(to_dynamic_input)

            ff_parser_attn_map = ff_parser_attn_map + dynamic_ff_parser_attn_map

        # modulate the spectrum, then return to the spatial domain keeping the real part
        x = x * ff_parser_attn_map

        x = ifft2(x).real
        x = x.type(dtype)

        # eq 3 in paper: normalise x and c, multiply them, then multiply by c

        normed_x = self.norm_input(x)
        normed_c = self.norm_condition(c)
        c = (normed_x * normed_c) * c

        # add an extra block to allow for more integration of information
        # there is a downsample right after the Condition block (but maybe theres a better place to condition than right before the downsample)

        return self.block(c)
# U-Net denoiser with a parallel conditioning encoder
@beartype
class Unet(Module):
    """U-Net that denoises a mask ``x`` while a copy of the encoder processes
    the conditioning image ``cond``; the two streams are combined by a
    conditioner at every resolution and summed at the bottleneck.
    """
    # Build the mask encoder, the conditioning encoder, the bottleneck and the decoder
    def __init__(
        self,
        dim,
        image_size,
        mask_channels = 1,
        input_img_channels = 3,
        init_dim = None,
        out_dim = None,
        dim_mults: tuple = (1, 2, 4, 8),
        full_self_attn: tuple = (False, False, False, True),
        attn_dim_head = 32,
        attn_heads = 4,
        mid_transformer_depth = 1,
        self_condition = False,
        resnet_block_groups = 8,
        conditioning_klass = Conditioning,
        skip_connect_condition_fmaps = False, # whether to concatenate the conditioning fmaps in the later decoder upsampling portion
        dynamic_ff_parser_attn_map = False, # allow for ff-parser to be dynamic based on the input. will exclude condition for now
        conditioning_kwargs: dict = dict(
            dim_head = 32,
            heads = 4,
            depth = 4,
            patch_size = 16
        )
    ):
        # NOTE(review): out_dim is accepted but never read in this constructor
        super().__init__()
        # image size, also used as the first feature-map size for the conditioners
        self.image_size = image_size
        # determine dimensions
        # channels of the conditioning image
        self.input_img_channels = input_img_channels
        # channels of the mask
        self.mask_channels = mask_channels
        # whether a previous prediction is fed back in as extra input
        self.self_condition = self_condition
        # the network predicts as many channels as the mask has
        output_channels = mask_channels
        # with self conditioning the input carries the mask twice (previous prediction + current)
        mask_channels = mask_channels * (2 if self_condition else 1)
        # width of the stem, defaulting to dim
        init_dim = default(init_dim, dim)
        # stem for the mask: 7x7 convolution, padding 3
        self.init_conv = nn.Conv2d(mask_channels, init_dim, 7, padding = 3)
        # stem for the conditioning image: 7x7 convolution, padding 3
        self.cond_init_conv = nn.Conv2d(input_img_channels, init_dim, 7, padding = 3)
        # channel widths per resolution: the stem width followed by dim * multiplier
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        # consecutive (in, out) width pairs
        in_out = list(zip(dims[:-1], dims[1:]))
        # ResnetBlock with the group count pre-bound
        block_klass = partial(ResnetBlock, groups = resnet_block_groups)
        # time embedding width
        time_dim = dim * 4
        # time embedding MLP: sinusoidal embedding -> Linear -> GELU -> Linear
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )
        # attention related params
        attn_kwargs = dict(
            dim_head = attn_dim_head,
            heads = attn_heads
        )
        # conditioner settings: the default class gets its extra options pre-bound
        if conditioning_klass == Conditioning:
            conditioning_klass = partial(
                Conditioning,
                dynamic = dynamic_ff_parser_attn_map,
                **conditioning_kwargs
            )
        # layers
        num_resolutions = len(in_out)
        assert len(full_self_attn) == num_resolutions
        # one conditioner per resolution
        self.conditioners = ModuleList([])
        # whether conditioning feature maps are also pushed onto the skip stack
        self.skip_connect_condition_fmaps = skip_connect_condition_fmaps
        # downsampling encoding blocks
        self.downs = ModuleList([])
        curr_fmap_size = image_size
        for ind, ((dim_in, dim_out), full_attn) in enumerate(zip(in_out, full_self_attn)):
            is_last = ind >= (num_resolutions - 1)
            attn_klass = Attention if full_attn else LinearAttention
            self.conditioners.append(conditioning_klass(curr_fmap_size, dim_in, image_size = curr_fmap_size))
            self.downs.append(ModuleList([
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                Residual(attn_klass(dim_in, **attn_kwargs)),
                Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
            ]))
            # the last stage keeps its resolution, every other stage halves it
            if not is_last:
                curr_fmap_size //= 2
        # middle blocks
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
        self.mid_transformer = Transformer(mid_dim, depth = mid_transformer_depth, **attn_kwargs)
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
        # condition encoding path will be the same as the main encoding path (deep copies, separate weights)
        self.cond_downs = copy.deepcopy(self.downs)
        self.cond_mid_block1 = copy.deepcopy(self.mid_block1)
        # upsampling decoding blocks
        self.ups = ModuleList([])
        for ind, ((dim_in, dim_out), full_attn) in enumerate(zip(reversed(in_out), reversed(full_self_attn))):
            is_last = ind == (len(in_out) - 1)
            attn_klass = Attention if full_attn else LinearAttention
            # each skip entry carries dim_in channels, doubled when the conditioning fmaps are included
            skip_connect_dim = dim_in * (2 if self.skip_connect_condition_fmaps else 1)
            self.ups.append(ModuleList([
                block_klass(dim_out + skip_connect_dim, dim_out, time_emb_dim = time_dim),
                block_klass(dim_out + skip_connect_dim, dim_out, time_emb_dim = time_dim),
                Residual(attn_klass(dim_out, **attn_kwargs)),
                Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
            ]))
        # projection out to predictions
        # NOTE(review): the input here is cat(x, r) with init_dim channels each, so
        # dim * 2 only matches when init_dim == dim - confirm before passing a custom init_dim
        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
        self.final_conv = nn.Conv2d(dim, output_channels, 1)
    # Forward pass taking the noisy mask x, the timestep tensor, the conditioning image
    # and an optional self-conditioning tensor
    def forward(
        self,
        x,
        time,
        cond,
        x_self_cond = None
    ):
        # NOTE(review): dtype is captured here but not used further down in this method
        dtype, skip_connect_c = x.dtype, self.skip_connect_condition_fmaps
        # with self conditioning, prepend the previous prediction (zeros when none is given)
        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x_self_cond, x), dim = 1)
        # stem for the mask
        x = self.init_conv(x)
        # keep a copy of the stem output for the final skip connection
        r = x.clone()
        # stem for the conditioning image
        c = self.cond_init_conv(cond)
        # time embedding
        t = self.time_mlp(time)
        # stack of skip connections
        h = []
        # walk the mask encoder, the conditioning encoder and the conditioners in lockstep
        for (block1, block2, attn, downsample), (cond_block1, cond_block2, cond_attn, cond_downsample), conditioner in zip(self.downs, self.cond_downs, self.conditioners):
            # first resnet block on both streams
            x = block1(x, t)
            c = cond_block1(c, t)
            # first skip connection of this stage
            h.append([x, c] if skip_connect_c else [x])
            # second resnet block on both streams
            x = block2(x, t)
            c = cond_block2(c, t)
            # attention on both streams
            x = attn(x)
            c = cond_attn(c)
            # condition using modulation of fourier frequencies with attentive map
            # the conditioner takes both streams and replaces the conditioning stream
            c = conditioner(x, c)
            # second skip connection of this stage
            h.append([x, c] if skip_connect_c else [x])
            # downsample both streams
            x = downsample(x)
            c = cond_downsample(c)
        # bottleneck
        x = self.mid_block1(x, t)
        c = self.cond_mid_block1(c, t)
        # merge the conditioning stream into the mask stream by summation
        x = x + c
        x = self.mid_transformer(x)
        x = self.mid_block2(x, t)
        # decoder: each stage consumes two skip entries, most recent first
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, *h.pop()), dim = 1)
            x = block1(x, t)
            x = torch.cat((x, *h.pop()), dim = 1)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)
        # final skip connection from the stem output
        x = torch.cat((x, r), dim = 1)
        x = self.final_res_block(x, t)
        # project to the output channels
        return self.final_conv(x)
# 高斯扩散训练器类
# Gather the entries of a at indices t and reshape them to broadcast against a tensor of shape x_shape
def extract(a, t, x_shape):
    batch, *_ = t.shape
    picked = a.gather(-1, t)
    # one leading batch axis followed by singleton axes for every remaining dim of x_shape
    singleton_axes = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch, *singleton_axes)
# Linear beta schedule; the end points are rescaled by 1000 / timesteps
def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    first, last = scale * 0.0001, scale * 0.02
    return torch.linspace(first, last, timesteps, dtype=torch.float64)
# Cosine beta schedule derived from a squared-cosine cumulative alpha curve
def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    grid = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    phase = ((grid / timesteps) + s) / (1 + s) * math.pi * 0.5
    alphas_cumprod = torch.cos(phase) ** 2
    # normalise so the curve starts at exactly 1
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    step_ratio = alphas_cumprod[1:] / alphas_cumprod[:-1]
    return torch.clip(1 - step_ratio, 0, 0.999)
# 医学分割扩散模块类,继承自 Module 类
class MedSegDiff(Module):
def __init__(
    self,
    model,
    *,
    timesteps=1000,
    sampling_timesteps=None,
    objective='pred_noise',
    beta_schedule='cosine',
    ddim_sampling_eta=1.
):
    """Wrap a Unet and precompute the diffusion schedule buffers.

    Args:
        model: a ``Unet``, or an object exposing the Unet as ``.module``.
        timesteps: number of diffusion steps used to build the beta schedule.
        sampling_timesteps: steps used when sampling (defaults to ``timesteps``);
            a smaller value switches on DDIM sampling.
        objective: 'pred_noise', 'pred_x0' or 'pred_v'.
        beta_schedule: 'linear' or 'cosine'.
        ddim_sampling_eta: stored as-is for DDIM sampling.

    Raises:
        ValueError: for an unknown ``beta_schedule``.
    """
    super().__init__()

    # anything that is not a Unet is assumed to hold the Unet under `.module`
    if isinstance(model, Unet):
        self.model = model
    else:
        self.model = model.module

    # mirror the attributes of the wrapped Unet
    self.input_img_channels = self.model.input_img_channels
    self.mask_channels = self.model.mask_channels
    self.self_condition = self.model.self_condition
    self.image_size = self.model.image_size

    self.objective = objective

    assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'

    if beta_schedule == 'linear':
        betas = linear_beta_schedule(timesteps)
    elif beta_schedule == 'cosine':
        betas = cosine_beta_schedule(timesteps)
    else:
        raise ValueError(f'unknown beta schedule {beta_schedule}')

    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    # cumulative product shifted by one step, with 1 in front
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

    (timesteps,) = betas.shape
    self.num_timesteps = int(timesteps)

    # sampling related parameters

    self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training

    assert self.sampling_timesteps <= timesteps
    self.is_ddim_sampling = self.sampling_timesteps < timesteps
    self.ddim_sampling_eta = ddim_sampling_eta

    # helper to register every buffer as float32 (the schedules are computed in float64)
    def register_buffer(name, val):
        self.register_buffer(name, val.to(torch.float32))

    register_buffer('betas', betas)
    register_buffer('alphas_cumprod', alphas_cumprod)
    register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

    # calculations for diffusion q(x_t | x_{t-1}) and others

    one_minus_cumprod = 1. - alphas_cumprod
    register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
    register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(one_minus_cumprod))
    register_buffer('log_one_minus_alphas_cumprod', torch.log(one_minus_cumprod))
    register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
    register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

    # calculations for posterior q(x_{t-1} | x_t, x_0)

    posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

    register_buffer('posterior_variance', posterior_variance)

    # the variance is clamped from below before the log is taken
    register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
    register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
    register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
@property
def device(self):
    """Device of the first parameter yielded by ``self.parameters()``."""
    first_param = next(self.parameters())
    return first_param.device
def predict_start_from_noise(self, x_t, t, noise):
    """Recover the x_start estimate from x_t and the (predicted) noise."""
    shape = x_t.shape
    recip = extract(self.sqrt_recip_alphas_cumprod, t, shape)
    recipm1 = extract(self.sqrt_recipm1_alphas_cumprod, t, shape)
    return recip * x_t - recipm1 * noise
def predict_noise_from_start(self, x_t, t, x0):
    """Recover the noise estimate from x_t and the (predicted) x0."""
    shape = x_t.shape
    recip = extract(self.sqrt_recip_alphas_cumprod, t, shape)
    recipm1 = extract(self.sqrt_recipm1_alphas_cumprod, t, shape)
    return (recip * x_t - x0) / recipm1
def predict_v(self, x_start, t, noise):
    """Compute v = sqrt(alphas_cumprod) * noise - sqrt(1 - alphas_cumprod) * x_start at step t."""
    shape = x_start.shape
    sqrt_acp = extract(self.sqrt_alphas_cumprod, t, shape)
    sqrt_one_minus_acp = extract(self.sqrt_one_minus_alphas_cumprod, t, shape)
    return sqrt_acp * noise - sqrt_one_minus_acp * x_start
# Recover the x_start estimate from x_t and v at step t
def predict_start_from_v(self, x_t, t, v):
    shape = x_t.shape
    # sqrt(alphas_cumprod) weights x_t, sqrt(1 - alphas_cumprod) weights v
    sqrt_acp = extract(self.sqrt_alphas_cumprod, t, shape)
    sqrt_one_minus_acp = extract(self.sqrt_one_minus_alphas_cumprod, t, shape)
    return sqrt_acp * x_t - sqrt_one_minus_acp * v
# Mean, variance and clipped log-variance of the posterior q(x_{t-1} | x_t, x_0)
def q_posterior(self, x_start, x_t, t):
    shape = x_t.shape
    # the mean is a weighted sum of x_start and x_t using the precomputed coefficients
    coef_start = extract(self.posterior_mean_coef1, t, shape)
    coef_t = extract(self.posterior_mean_coef2, t, shape)
    posterior_mean = coef_start * x_start + coef_t * x_t

    posterior_variance = extract(self.posterior_variance, t, shape)
    posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, shape)
    return posterior_mean, posterior_variance, posterior_log_variance_clipped
# Run the model and convert its output into a (pred_noise, pred_x_start) pair according to the objective
def model_predictions(self, x, t, c, x_self_cond = None, clip_x_start = False):
    model_output = self.model(x, t, c, x_self_cond)

    # optional clamp of the x_start estimate to [-1, 1]
    if clip_x_start:
        maybe_clip = partial(torch.clamp, min = -1., max = 1.)
    else:
        maybe_clip = identity

    if self.objective == 'pred_noise':
        # the model output is the noise; x_start is derived from it
        pred_noise = model_output
        x_start = maybe_clip(self.predict_start_from_noise(x, t, pred_noise))

    elif self.objective == 'pred_x0':
        # the model output is x_start; the noise is derived from it
        x_start = maybe_clip(model_output)
        pred_noise = self.predict_noise_from_start(x, t, x_start)

    elif self.objective == 'pred_v':
        # the model output is v; x_start and then the noise are derived from it
        x_start = maybe_clip(self.predict_start_from_v(x, t, model_output))
        pred_noise = self.predict_noise_from_start(x, t, x_start)

    return ModelPrediction(pred_noise, x_start)
# 计算均值和方差,可选择是否对去噪后的值进行裁剪
def p_mean_variance(self, x, t, c, x_self_cond = None, clip_denoised = True):
    """Compute the posterior parameters for one reverse step from the model's prediction.

    Args:
        x: the current noisy sample.
        t: timestep indices for ``x``.
        c: conditioning input forwarded to the model.
        x_self_cond: optional self-conditioning input forwarded to the model.
        clip_denoised: when true, clamp the predicted clean sample to [-1, 1]
            (in place) before it is used for the posterior.

    Returns:
        Tuple ``(model_mean, posterior_variance, posterior_log_variance, x_start)``.
    """
    denoised = self.model_predictions(x, t, c, x_self_cond).pred_x_start

    if clip_denoised:
        # In-place clamp: the tensor returned below is the clamped one.
        denoised.clamp_(-1., 1.)

    mean, variance, log_variance = self.q_posterior(x_start = denoised, x_t = x, t = t)
    return mean, variance, log_variance, denoised
# 生成样本,根据输入 x, t, c 生成预测图像
@torch.no_grad()
def p_sample(self, x, t, c, x_self_cond = None, clip_denoised = True):
    """Take one reverse-diffusion step from timestep ``t``.

    Args:
        x: the current noisy sample; its first dimension sets how many
            timestep entries are created.
        t: the scalar timestep, broadcast to one entry per ``x.shape[0]``.
        c: conditioning input forwarded to the model.
        x_self_cond: optional self-conditioning input forwarded to the model.
        clip_denoised: forwarded to ``p_mean_variance``.

    Returns:
        Tuple ``(pred_img, x_start)``: the sample for the next step and the
        clean-sample estimate produced along the way.
    """
    batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)

    model_mean, _, model_log_variance, x_start = self.p_mean_variance(
        x = x,
        t = batched_times,
        c = c,
        x_self_cond = x_self_cond,
        clip_denoised = clip_denoised
    )

    # No noise is added on the final step (t == 0).
    if t > 0:
        noise = torch.randn_like(x)
    else:
        noise = 0.

    std = (0.5 * model_log_variance).exp()
    return model_mean + std * noise, x_start
# 循环生成样本,根据给定的形状和条件
@torch.no_grad()
def p_sample_loop(self, shape, cond):
    """Sample by running every reverse step from ``num_timesteps - 1`` down to 0.

    Args:
        shape: shape of the tensor to generate; the starting point is
            Gaussian noise of this shape on the device of ``self.betas``.
        cond: conditioning input forwarded to every ``p_sample`` call.

    Returns:
        The final sample after ``unnormalize_to_zero_to_one``.
    """
    img = torch.randn(shape, device = self.betas.device)
    x_start = None

    timesteps = reversed(range(0, self.num_timesteps))
    for t in tqdm(timesteps, desc = 'sampling loop time step', total = self.num_timesteps):
        # Feed the previous clean-sample estimate back in when self-conditioning is on.
        self_cond = x_start if self.self_condition else None
        img, x_start = self.p_sample(img, t, cond, self_cond)

    return unnormalize_to_zero_to_one(img)
# 禁用梯度计算
@torch.no_grad()
def ddim_sample(self, shape, cond_img, clip_denoised = True):
    """Sample with the DDIM update over a sub-sampled timestep schedule.

    Args:
        shape: shape of the tensor to generate; ``shape[0]`` is used as the
            batch size for the per-step timestep tensor.
        cond_img: conditioning input forwarded to ``model_predictions``.
        clip_denoised: forwarded as ``clip_x_start`` to ``model_predictions``.

    Returns:
        The final sample after ``unnormalize_to_zero_to_one``.
    """
    batch = shape[0]
    device = self.betas.device
    eta = self.ddim_sampling_eta

    # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == num_timesteps
    schedule = torch.linspace(-1, self.num_timesteps - 1, steps = self.sampling_timesteps + 1)
    schedule = schedule.int().tolist()[::-1]
    # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
    time_pairs = list(zip(schedule[:-1], schedule[1:]))

    img = torch.randn(shape, device = device)
    x_start = None

    for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
        time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
        self_cond = x_start if self.self_condition else None
        pred_noise, x_start, *_ = self.model_predictions(
            img, time_cond, cond_img, self_cond, clip_x_start = clip_denoised
        )

        if time_next >= 0:
            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(img)

            img = x_start * alpha_next.sqrt() + \
                  c * pred_noise + \
                  sigma * noise
        else:
            # The last pair ends at -1: use the clean-sample estimate directly.
            img = x_start

    # Map the result back through unnormalize_to_zero_to_one before returning.
    return unnormalize_to_zero_to_one(img)
# 生成采样结果
@torch.no_grad()
def sample(self, cond_img):
    """Generate one sample per conditioning image.

    Moves ``cond_img`` to ``self.device`` and dispatches to ``ddim_sample``
    when ``self.is_ddim_sampling`` is set, otherwise to ``p_sample_loop``.

    Args:
        cond_img: conditioning images; the first dimension is the batch size.

    Returns:
        Whatever the selected sampler returns for a target shape of
        ``(batch, mask_channels, image_size, image_size)``.
    """
    cond_img = cond_img.to(self.device)
    target_shape = (cond_img.shape[0], self.mask_channels, self.image_size, self.image_size)

    if self.is_ddim_sampling:
        return self.ddim_sample(target_shape, cond_img)
    return self.p_sample_loop(target_shape, cond_img)
# 生成 Q 采样结果
def q_sample(self, x_start, t, noise=None):
    """Produce the noisy sample at timestep ``t`` from a clean sample.

    Evaluates ``sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise``.

    Args:
        x_start: the clean sample.
        t: timestep indices forwarded to ``extract``.
        noise: optional noise; the ``torch.randn_like(x_start)`` fallback
            is supplied through ``default``.

    Returns:
        The noised version of ``x_start``.
    """
    noise = default(noise, lambda: torch.randn_like(x_start))

    signal_scale = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
    noise_scale = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)

    return signal_scale * x_start + noise_scale * noise
# 计算 P 损失
def p_losses(self, x_start, t, cond, noise = None):
    """Compute the MSE training loss for one batch at the given timesteps.

    Args:
        x_start: the clean target; must unpack into four dimensions.
        t: timestep indices used for noising and passed to the model.
        cond: conditioning input forwarded to the model.
        noise: optional noise; the ``torch.randn_like(x_start)`` fallback
            is supplied through ``default``.

    Returns:
        ``F.mse_loss`` between the model output and the target selected by
        ``self.objective``.

    Raises:
        ValueError: if ``self.objective`` is not a known objective.
    """
    # Unpacking doubles as a rank check: anything but a 4-D input fails here.
    _batch, _channels, _height, _width = x_start.shape

    noise = default(noise, lambda: torch.randn_like(x_start))

    # Noised input for the model.
    x = self.q_sample(x_start = x_start, t = t, noise = noise)

    # With self-conditioning enabled, half of the time first predict x_start
    # from the current step (without gradients) and feed it back to the model.
    # Upstream notes this slows training by ~25% but appears to lower FID noticeably.
    x_self_cond = None
    if self.self_condition and random() < 0.5:
        with torch.no_grad():
            x_self_cond = self.model_predictions(x, t, cond).pred_x_start
            x_self_cond.detach_()

    # Forward pass whose output is compared against the target below.
    model_out = self.model(x, t, cond, x_self_cond)

    # Targets are built lazily so predict_v only runs for the 'pred_v' objective.
    target_builders = {
        'pred_noise': lambda: noise,
        'pred_x0': lambda: x_start,
        'pred_v': lambda: self.predict_v(x_start, t, noise),
    }
    if self.objective not in target_builders:
        raise ValueError(f'unknown objective {self.objective}')
    target = target_builders[self.objective]()

    return F.mse_loss(model_out, target)
# 定义一个前向传播函数,接受输入图像、条件图像以及其他参数
def forward(self, img, cond_img, *args, **kwargs):
    """Validate a training batch, draw random timesteps and return the loss.

    Args:
        img: the segmentation target; a 3-D input is treated as
            ``(b, h, w)`` and given a channel dimension of size 1.
        cond_img: the conditioning image; handled the same way when 3-D.
        *args, **kwargs: forwarded to ``p_losses``.

    Returns:
        The value returned by ``p_losses``.

    Raises:
        AssertionError: if the spatial size or channel counts do not match
            ``image_size``, ``input_img_channels`` or ``mask_channels``.
    """
    # Add the missing channel dimension to 3-D inputs.
    if img.ndim == 3:
        img = rearrange(img, 'b h w -> b 1 h w')
    if cond_img.ndim == 3:
        cond_img = rearrange(cond_img, 'b h w -> b 1 h w')

    # Move both inputs to self.device.
    target_device = self.device
    img = img.to(target_device)
    cond_img = cond_img.to(target_device)

    batch, _, height, width = img.shape
    img_size = self.image_size
    img_channels = self.input_img_channels
    mask_channels = self.mask_channels

    assert height == img_size and width == img_size, f'height and width of image must be {img_size}'
    assert cond_img.shape[1] == img_channels, f'your input medical must have {img_channels} channels'
    assert img.shape[1] == mask_channels, f'the segmented image must have {mask_channels} channels'

    # One timestep per batch element, drawn from [0, num_timesteps).
    times = torch.randint(0, self.num_timesteps, (batch,), device = img.device).long()

    # Pass the target through normalize_to_neg_one_to_one before computing the loss.
    img = normalize_to_neg_one_to_one(img)
    return self.p_losses(img, times, cond_img, *args, **kwargs)