Lucidrains 系列项目源码解析(十六)
.\lucidrains\cross-transformers-pytorch\setup.py
from setuptools import setup, find_packages
# Distribution metadata for the `cross-transformers-pytorch` package,
# collected in one mapping and handed to setuptools in a single call.
metadata = dict(
    name = 'cross-transformers-pytorch',
    packages = find_packages(),
    version = '0.0.2',
    license='MIT',
    description = 'Cross Transformers - Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/cross-transformers-pytorch',
    keywords = [
        'artificial intelligence',
        'attention mechanism',
        'cross attention',
        'few shot learning'
    ],
    # Runtime dependencies.
    install_requires=[
        'torch>=1.6',
        'einops>=0.3'
    ],
    # Trove classifiers shown on the package index.
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)

setup(**metadata)
.\lucidrains\DALLE-pytorch\dalle_pytorch\attention.py
def exists(val):
    """Return True when `val` is anything other than `None`."""
    return not (val is None)
def uniq(arr):
    """Return the distinct elements of `arr` as a dict keys view.

    Elements keep their first-seen order; they must be hashable.
    """
    return dict.fromkeys(arr, True).keys()
def default(val, d):
    """Return `val` unless it is `None`, otherwise fall back to `d`.

    When `d` is a plain function it is called lazily to produce the
    fallback; any other value is returned as-is.
    """
    # NOTE(review): `isfunction` is not imported in the visible source;
    # presumably `inspect.isfunction` - confirm against the real module.
    if not exists(val):
        return d() if isfunction(d) else d
    return val
def max_neg_value(t):
    """Return the negated largest finite value of `t`'s floating-point dtype.

    Used as the fill value for masked attention scores.
    """
    limits = torch.finfo(t.dtype)
    return -limits.max
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
    """Softmax over `dim` computed on down-scaled, max-shifted values.

    The input is divided by `alpha`, the (detached) maximum along `dim`
    is subtracted, and the result is scaled back up by `alpha` before
    the softmax is applied.
    """
    scaled = t / alpha
    # The shift is detached, so it takes no part in the gradient.
    peak = torch.amax(scaled, dim = dim, keepdim = True).detach()
    shifted = scaled - peak
    return torch.softmax(shifted * alpha, dim = dim)
def apply_pos_emb(pos_emb, qkv):
    """Apply rotary positional embeddings to every tensor in `qkv`.

    `pos_emb` is first cut along its second-to-last axis to the length
    found on that axis of the first tensor in `qkv`; the result is
    returned as a tuple in the same order as `qkv`.
    """
    length = qkv[0].shape[-2]
    trimmed = pos_emb[..., :length, :]
    return tuple(apply_rotary_emb(trimmed, t) for t in qkv)
class Attention(nn.Module):
    """Multi-head self-attention.

    Supports an optional key padding mask, causal masking, a fixed
    ("static") attention mask, rotary positional embeddings, and an
    incremental key/value cache for step-by-step decoding.
    """

    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False,
                 static_mask = None):
        super().__init__()
        self.heads = heads
        self.seq_len = seq_len
        self.scale = dim_head ** -0.5
        self.stable = stable
        self.causal = causal

        # Non-persistent buffer: follows the module across devices but is
        # left out of the state dict. May be None.
        self.register_buffer('static_mask', static_mask, persistent=False)

        inner_dim = heads * dim_head
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
        """Attend over `x` (batch, seq, dim) and return a tensor of the same layout.

        `mask` is a (batch, keys) boolean tensor whose False entries are
        masked out. When `cache` (a dict) is given, keys/values are stored
        under `cache_key` and `cache['offset']` is read as the number of
        positions already processed.
        """
        b, n, _ = x.shape
        heads, device = self.heads, x.device
        softmax_fn = stable_softmax if self.stable else torch.softmax
        offset = 0 if not exists(cache) else cache.get('offset', 0)

        projected = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = heads) for t in projected)

        if exists(rotary_pos_emb):
            # Skip the positions that were already consumed via the cache.
            q, k, v = apply_pos_emb(rotary_pos_emb[..., offset:, :], (q, k, v))

        q = q * self.scale

        if offset > 0:
            # Prepend the cached keys/values from earlier steps.
            cached_k, cached_v = cache[cache_key]
            k = torch.cat([cached_k, k], dim=-2)
            v = torch.cat([cached_v, v], dim=-2)

        if exists(cache):
            cache[cache_key] = k, v

        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
        fill_value = max_neg_value(dots)

        if exists(mask):
            key_mask = rearrange(mask, 'b j -> b () () j')
            dots.masked_fill_(~key_mask, fill_value)

        # The causal mask is only built on the uncached (offset == 0) pass.
        if self.causal and offset == 0:
            i, j = dots.shape[-2:]
            future = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
            dots.masked_fill_(future, fill_value)

        if exists(self.static_mask):
            allowed = self.static_mask[offset:offset + n, :offset + n]
            dots.masked_fill_(~allowed, fill_value)

        attn = softmax_fn(dots, dim=-1)
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        merged = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(merged)
class SparseConvCausalAttention(nn.Module):
    """Causal attention in which image positions attend to a local convolutional window.

    The sequence is treated as text positions followed by
    `image_size ** 2` image positions. Text attends causally to text;
    each image position attends to all (unmasked) text plus the image
    positions inside a `kernel_size` x `kernel_size` (dilated) window
    that is padded only on the top and left.
    """
    def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        super().__init__()
        assert kernel_size % 2 == 1, 'kernel size must be odd'
        inner_dim = dim_head * heads
        self.seq_len = seq_len
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.image_size = image_size
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stable = stable
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x, mask = None, rotary_pos_emb = None):
        """Return attention output for `x` (batch, n, dim), trimmed back to length n.

        `mask` is an optional boolean text mask of shape (batch, >= text_len);
        it defaults to all True.
        """
        b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device
        softmax = torch.softmax if not self.stable else stable_softmax
        img_seq_len = img_size ** 2
        # Positions that are not image positions in a sequence of length seq_len + 1.
        text_len = seq_len + 1 - img_seq_len
        # Pad the input up to length seq_len + 1 so the text/image split below is fixed.
        padding = seq_len - n + 1
        mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool())
        x = F.pad(x, (0, 0, 0, padding), value = 0)
        mask = mask[:, :text_len]
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # Fold heads into the batch axis.
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
        if exists(rotary_pos_emb):
            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
        q *= self.scale
        # Split each of q, k, v into its text part and its trailing image part.
        ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))
        # --- text-to-text causal attention ---
        dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
        mask_value = max_neg_value(dots_text)
        i, j = dots_text.shape[-2:]
        text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
        dots_text.masked_fill_(text_causal_mask, mask_value)
        attn_text = softmax(dots_text, dim = -1)
        out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)
        # --- image attention over a local window ---
        effective_kernel_size = (kernel_size - 1) * dilation + 1
        same_padding = effective_kernel_size // 2
        # Pad only the left and top sides, so each window ends at its own position.
        causal_padding = (same_padding * 2, 0, same_padding * 2, 0)
        k_img, v_img = map(lambda t: rearrange(t, 'b (h w) c -> b c h w', h = img_size), (k_img, v_img))
        k_img, v_img = map(lambda t: F.pad(t, causal_padding), (k_img, v_img))
        # Extract, for every image position, the keys/values of its window.
        k_img, v_img = map(lambda t: F.unfold(t, kernel_size, dilation = dilation), (k_img, v_img))
        k_img, v_img = map(lambda t: rearrange(t, 'b (d j) i -> b i j d', j = kernel_size ** 2), (k_img, v_img))
        dots_image = einsum('b i d, b i j d -> b i j', q_img, k_img)
        dots_image_to_text = einsum('b i d, b j d -> b i j', q_img, k_text)
        i, j = dots_image.shape[-2:]
        # Run a tensor of ones through the same pad + unfold to find which
        # window slots fall on padding (they come out as 0).
        ones = torch.ones((img_seq_len,), device = device)
        ones = rearrange(ones, '(h w) -> () () h w', h = img_size)
        ones = F.pad(ones, causal_padding, value = 0.)
        ones = F.unfold(ones, kernel_size, dilation = dilation)
        ones = rearrange(ones, 'b j i -> b i j')
        padding_mask = ones == 0.
        padding_mask = repeat(padding_mask, '() i j -> b i j', b = b * h)
        mask = repeat(mask, 'b j -> (b h) i j', i = i, h = h)
        # True marks entries to be masked out: masked text, then padded window slots.
        mask = torch.cat((~mask, padding_mask), dim = -1)
        # One softmax over the text columns followed by the window columns.
        dots = torch.cat((dots_image_to_text, dots_image), dim = -1)
        dots.masked_fill_(mask, mask_value)
        attn = softmax(dots, dim = -1)
        attn_image_to_text, attn_image = attn[..., :text_len], attn[..., text_len:]
        out_image_to_image = einsum('b i j, b i j d -> b i d', attn_image, v_img)
        out_image_to_text = einsum('b i j, b j d -> b i d', attn_image_to_text, v_text)
        out_image = out_image_to_image + out_image_to_text
        # Reassemble text + image, unfold heads, project, and drop the padding.
        out = torch.cat((out_text, out_image), dim = 1)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        out = self.to_out(out)
        return out[:, :n]
class SparseAxialCausalAttention(nn.Module):
    """Causal attention in which image positions attend along one axis of the image grid.

    The sequence is treated as text positions followed by
    `image_size ** 2` image positions. Text attends causally to text;
    each image position attends to all (unmasked) text plus, causally,
    the image positions that share its row or column, depending on `axis`.
    """
    def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        super().__init__()
        assert axis in {0, 1}, 'axis must be either 0 (along height) or 1 (along width)'
        self.axis = axis
        inner_dim = dim_head * heads
        self.seq_len = seq_len
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.image_size = image_size
        self.stable = stable
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x, mask = None, rotary_pos_emb = None):
        """Return attention output for `x` (batch, n, dim), trimmed back to length n.

        `mask` is an optional boolean text mask of shape (batch, >= text_len);
        it defaults to all True.
        """
        b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
        softmax = torch.softmax if not self.stable else stable_softmax
        img_seq_len = img_size ** 2
        # Positions that are not image positions in a sequence of length seq_len + 1.
        text_len = seq_len + 1 - img_seq_len
        # Pad the input up to length seq_len + 1 so the text/image split below is fixed.
        padding = seq_len - n + 1
        mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool())
        x = F.pad(x, (0, 0, 0, padding), value = 0)
        mask = mask[:, :text_len]
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # Fold heads into the batch axis.
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
        if exists(rotary_pos_emb):
            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
        q *= self.scale
        # Split each of q, k, v into its text part and its trailing image part.
        ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))
        # --- text-to-text causal attention ---
        dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
        mask_value = max_neg_value(dots_text)
        i, j = dots_text.shape[-2:]
        text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
        dots_text.masked_fill_(text_causal_mask, mask_value)
        attn_text = softmax(dots_text, dim = -1)
        out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)
        # --- image attention along the chosen axis ---
        # The split pattern groups image positions so the attended axis is the
        # inner one; the merge pattern undoes that grouping.
        split_axis_einops = 'b (h w) c -> b h w c' if axis == 0 else 'b (h w) c -> b w h c'
        merge_axis_einops = 'b x n d -> b (x n) d' if axis == 0 else 'b x n d -> b (n x) d'
        q_img, k_img, v_img = map(lambda t: rearrange(t, split_axis_einops, h = img_size), (q_img, k_img, v_img))
        dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img, k_img)
        dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text)
        # One softmax over the text columns followed by the axial image columns.
        dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1)
        bh, x, i, j = dots.shape
        # Causal mask within each axial line (True = masked).
        causal_mask = torch.ones(i, img_size, device = device).triu_(img_size - i + 1).bool()
        causal_mask = repeat(causal_mask, 'i j -> b x i j', b = bh, x = x)
        mask = repeat(mask, 'b j -> (b h) x i j', h = h, x = x, i = i)
        # True marks entries to be masked out: masked text, then future image positions.
        mask = torch.cat((~mask, causal_mask), dim = -1)
        dots.masked_fill_(mask, mask_value)
        attn = softmax(dots, dim = -1)
        attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:]
        out_image_to_image = einsum('b x i j, b x j d -> b x i d', attn_image_to_image, v_img)
        out_image_to_text = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text)
        out_image = out_image_to_image + out_image_to_text
        out_image = rearrange(out_image, merge_axis_einops, x = img_size)
        # Reassemble text + image, unfold heads, project, and drop the padding.
        out = torch.cat((out_text, out_image), dim = 1)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        out = self.to_out(out)
        return out[:, :n]
class SparseAttention(Attention):
    """`Attention` variant that delegates the attention computation to DeepSpeed's block-sparse kernel.

    The first `ceil(text_seq_len / block_size)` blocks are configured as
    global blocks; DeepSpeed is imported lazily at construction time.
    """

    def __init__(
        self,
        *args,
        block_size = 16,
        text_seq_len = 256,
        num_random_blocks = None,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        # Imported here so the rest of the module works without DeepSpeed.
        from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig

        self.block_size = block_size
        num_random_blocks = default(num_random_blocks, self.seq_len // block_size // 4)
        num_global_blocks = ceil(text_seq_len / block_size)
        global_block_indices = list(range(num_global_blocks))

        direction = 'unidirectional' if self.causal else 'bidirectional'
        sparsity_config = VariableSparsityConfig(
            num_heads = self.heads,
            block = self.block_size,
            num_random_blocks = num_random_blocks,
            global_block_indices = global_block_indices,
            attention = direction
        )
        self.attn_fn = SparseSelfAttention(
            sparsity_config = sparsity_config,
            max_seq_length = self.seq_len,
            attn_mask_mode = 'add'
        )

    def forward(self, x, mask = None, rotary_pos_emb = None):
        """Attend over `x` (batch, n, dim); the input is padded to a multiple of
        the block size and the output is trimmed back to length n.
        """
        b, n, _ = x.shape
        heads, device = self.heads, x.device

        mask = default(mask, lambda: torch.ones(b, n, device = device).bool())

        leftover = n % self.block_size
        if leftover > 0:
            # Pad the sequence (and mark the padding as masked) up to the next full block.
            pad_len = self.block_size - leftover
            x = F.pad(x, (0, 0, 0, pad_len), value = 0)
            mask = F.pad(mask, (0, pad_len), value = False)

        projected = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = heads) for t in projected)

        if exists(rotary_pos_emb):
            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))

        # True marks padded keys.
        key_pad_mask = ~mask if exists(mask) else None

        attn_mask = None
        if self.causal:
            i, j = q.shape[-2], k.shape[-2]
            future = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
            # Additive mask: 0 where attention is allowed, a large negative value elsewhere.
            attn_mask = torch.zeros(i, j, device = device).to(q)
            attn_mask.masked_fill_(future, max_neg_value(q) / 2)

        out = self.attn_fn(q, k, v, attn_mask = attn_mask, key_padding_mask = key_pad_mask)
        merged = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(merged)[:, :n]
.\lucidrains\DALLE-pytorch\dalle_pytorch\dalle_pytorch.py
from math import log2, sqrt
import torch
from torch import nn, einsum
import torch.nn.functional as F
import numpy as np
from axial_positional_embedding import AxialPositionalEmbedding
from einops import rearrange
from dalle_pytorch import distributed_utils
from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE
def exists(val):
    """Return True when `val` is anything other than `None`."""
    return not (val is None)
def default(val, d):
    """Return `val` unless it is `None`, in which case return `d` unchanged."""
    if exists(val):
        return val
    return d
class always:
    """Callable that ignores its arguments and returns a fixed value."""

    def __init__(self, val):
        # The constant handed back by every call.
        self.val = val

    def __call__(self, x, *args, **kwargs):
        """Return the stored value; at least one positional argument is required."""
        constant = self.val
        return constant
def is_empty(t):
    """Return True when tensor `t` holds no elements."""
    element_count = t.nelement()
    return element_count == 0
def masked_mean(t, mask, dim = 1):
    """Average `t` over dimension 1, counting only positions where `mask` is True.

    `t` is indexed as (batch, seq, feature) and `mask` as (batch, seq).
    NOTE: the `dim` argument is accepted but not used; the reduction is
    always over dimension 1.
    """
    keep = mask[:, :, None]
    zeroed = t.masked_fill(~keep, 0.)
    totals = zeroed.sum(dim = 1)
    counts = mask.sum(dim = 1)[..., None]
    return totals / counts
def prob_mask_like(shape, prob, device):
    """Return a boolean tensor of `shape` whose entries are True with probability `prob`."""
    draws = torch.zeros(shape, device = device).float().uniform_(0, 1)
    return draws < prob
def set_requires_grad(model, value):
    """Set `requires_grad` to `value` on every parameter of `model`."""
    for parameter in model.parameters():
        parameter.requires_grad_(value)
def eval_decorator(fn):
    """Decorate `fn(model, ...)` so that it runs with `model` in eval mode.

    The model's previous training flag is restored afterwards, even when
    `fn` raises, and the wrapper keeps `fn`'s name and docstring.
    """
    from functools import wraps

    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # Always restore the original mode so an exception during
            # generation cannot leave the model stuck in eval mode.
            model.train(was_training)
    return inner
def log(t, eps = 1e-20):
    """Return the natural log of `t`, clamping values below `eps` first.

    The clamp keeps the result finite for zero (or negative) inputs.
    """
    return torch.log(t.clamp(min = eps))
def gumbel_noise(t):
    """Return Gumbel noise shaped like `t`, computed as `-log(-log(u))` with `u` uniform on [0, 1)."""
    uniform = torch.zeros_like(t).uniform_(0, 1)
    inner = log(uniform)
    return -log(-inner)
def gumbel_sample(t, temperature = 1., dim = -1):
    """Sample indices along `dim`: scale `t` by `temperature`, add Gumbel noise, take the argmax."""
    scaled = t / temperature
    perturbed = scaled + gumbel_noise(t)
    return perturbed.argmax(dim = dim)
def top_k(logits, thres = 0.5):
    """Keep the top-scoring fraction of logits along the last dimension.

    The number kept is `max(int((1 - thres) * num_logits), 1)`; every
    other entry is replaced with `-inf`. Works for inputs of any rank
    because both the selection and the scatter use the last dimension.
    """
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k)
    filtered = torch.full_like(logits, float('-inf'))
    # `topk` selects along the last dim, so scatter along the last dim too
    # (a fixed dim of 1 only matched for 2-D input).
    filtered.scatter_(-1, ind, val)
    return filtered
class SharedEmbedding(nn.Embedding):
    """Embedding that reads its table from a row slice of a linear layer's weight.

    Rows `start_index:end_index` of `linear.weight` act as the embedding
    matrix, so the embedding and the linear layer share parameters.
    """

    def __init__(self, linear, start_index, end_index, **kwargs):
        num_embeddings = end_index - start_index
        embedding_dim = linear.weight.shape[1]
        super().__init__(num_embeddings, embedding_dim, **kwargs)
        # Drop the table allocated by nn.Embedding; lookups use the linear weight.
        del self.weight

        self.linear = linear
        self.start_index = start_index
        self.end_index = end_index

    def forward(self, input):
        """Look up `input` indices in the shared weight slice."""
        shared_table = self.linear.weight[self.start_index:self.end_index]
        return F.embedding(
            input,
            shared_table,
            padding_idx = self.padding_idx,
            max_norm = self.max_norm,
            norm_type = self.norm_type,
            scale_grad_by_freq = self.scale_grad_by_freq,
            sparse = self.sparse,
        )
class ResBlock(nn.Module):
    """Residual block: two 3x3 convolutions with ReLU, a 1x1 convolution, plus a skip connection.

    The channel count and spatial size are unchanged.
    """

    def __init__(self, chan):
        super().__init__()
        layers = [
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.ReLU(),
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.ReLU(),
            nn.Conv2d(chan, chan, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return `net(x) + x`."""
        residual = self.net(x)
        return residual + x
class DiscreteVAE(nn.Module):
    """Convolutional discrete VAE with a learned codebook.

    The encoder halves the spatial size `num_layers` times and emits
    per-position logits over `num_tokens` codes; codes are sampled with
    Gumbel-softmax and decoded back to an image.

    Args:
        image_size: input height/width; must be a power of 2.
        num_tokens: codebook size.
        codebook_dim: dimension of each codebook vector.
        num_layers: number of stride-2 encoder/decoder layers (>= 1).
        num_resnet_blocks: residual blocks added to encoder and decoder.
        hidden_dim: channel width of the conv layers.
        channels: image channels.
        smooth_l1_loss: use smooth L1 instead of MSE for reconstruction.
        temperature: default Gumbel-softmax temperature.
        straight_through: use hard (straight-through) Gumbel-softmax samples.
        reinmax: with `straight_through`, apply the ReinMax gradient correction.
        kl_div_loss_weight: weight of the KL term against a uniform prior.
        normalization: `(means, stds)` per channel, or `None` to skip
            input normalization.
    """

    def __init__(
        self,
        image_size = 256,
        num_tokens = 512,
        codebook_dim = 512,
        num_layers = 3,
        num_resnet_blocks = 0,
        hidden_dim = 64,
        channels = 3,
        smooth_l1_loss = False,
        temperature = 0.9,
        straight_through = False,
        reinmax = False,
        kl_div_loss_weight = 0.,
        normalization = ((*((0.5,) * 3), 0), (*((0.5,) * 3), 1))
    ):
        super().__init__()
        assert log2(image_size).is_integer(), 'image size must be a power of 2'
        assert num_layers >= 1, 'number of layers must be greater than or equal to 1'
        has_resblocks = num_resnet_blocks > 0

        self.channels = channels
        self.image_size = image_size
        self.num_tokens = num_tokens
        self.num_layers = num_layers
        self.temperature = temperature
        self.straight_through = straight_through
        self.reinmax = reinmax
        self.codebook = nn.Embedding(num_tokens, codebook_dim)

        # Channel plan: encoder goes channels -> hidden..., decoder mirrors it.
        enc_chans = [hidden_dim] * num_layers
        dec_chans = list(reversed(enc_chans))

        enc_chans = [channels, *enc_chans]

        # With residual blocks a 1x1 conv first maps codebook_dim to the hidden width.
        dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
        dec_chans = [dec_init_chan, *dec_chans]

        enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))

        enc_layers = []
        dec_layers = []

        for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
            enc_layers.append(nn.Sequential(nn.Conv2d(enc_in, enc_out, 4, stride = 2, padding = 1), nn.ReLU()))
            dec_layers.append(nn.Sequential(nn.ConvTranspose2d(dec_in, dec_out, 4, stride = 2, padding = 1), nn.ReLU()))

        for _ in range(num_resnet_blocks):
            dec_layers.insert(0, ResBlock(dec_chans[1]))
            enc_layers.append(ResBlock(enc_chans[-1]))

        if has_resblocks:
            dec_layers.insert(0, nn.Conv2d(codebook_dim, dec_chans[1], 1))

        # Final 1x1 convs: encoder -> code logits, decoder -> image channels.
        enc_layers.append(nn.Conv2d(enc_chans[-1], num_tokens, 1))
        dec_layers.append(nn.Conv2d(dec_chans[-1], channels, 1))

        self.encoder = nn.Sequential(*enc_layers)
        self.decoder = nn.Sequential(*dec_layers)

        self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
        self.kl_div_loss_weight = kl_div_loss_weight

        # Keep only the statistics for the channels in use. `None` disables
        # normalization (norm() checks for it), so it must not be mapped over.
        self.normalization = (
            tuple(map(lambda t: t[:channels], normalization))
            if exists(normalization) else None
        )

        self._register_external_parameters()

    def _register_external_parameters(self):
        """Register external parameters for DeepSpeed partitioning."""
        if (
            not distributed_utils.is_distributed
            or not distributed_utils.using_backend(
                distributed_utils.DeepSpeedBackend)
        ):
            return

        deepspeed = distributed_utils.backend.backend_module
        # The codebook weight is read directly in forward(), outside the
        # embedding module's own forward.
        deepspeed.zero.register_external_parameter(self, self.codebook.weight)

    def norm(self, images):
        """Return a normalized copy of `images`, or `images` itself when normalization is disabled."""
        if not exists(self.normalization):
            return images

        means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
        means, stds = map(lambda t: rearrange(t, 'c -> () c () ()'), (means, stds))
        # Clone first so the caller's tensor is not modified by the in-place ops.
        images = images.clone()
        images.sub_(means).div_(stds)
        return images

    @torch.no_grad()
    @eval_decorator
    def get_codebook_indices(self, images):
        """Return the argmax code index for every spatial position, flattened to (batch, h * w)."""
        logits = self(images, return_logits = True)
        codebook_indices = logits.argmax(dim = 1).flatten(1)
        return codebook_indices

    def decode(
        self,
        img_seq
    ):
        """Decode a (batch, n) sequence of code indices into images; n is treated as a square grid."""
        image_embeds = self.codebook(img_seq)
        b, n, d = image_embeds.shape
        h = w = int(sqrt(n))

        image_embeds = rearrange(image_embeds, 'b (h w) d -> b d h w', h = h, w = w)
        images = self.decoder(image_embeds)
        return images

    def forward(
        self,
        img,
        return_loss = False,
        return_recons = False,
        return_logits = False,
        temp = None
    ):
        """Encode, sample codes, and decode `img`.

        Returns the encoder logits when `return_logits` is set, the
        reconstruction when `return_loss` is False, otherwise the loss
        (and the reconstruction too when `return_recons` is set).
        """
        device, num_tokens, image_size, kl_div_loss_weight = img.device, self.num_tokens, self.image_size, self.kl_div_loss_weight
        assert img.shape[-1] == image_size and img.shape[-2] == image_size, f'input must have the correct image size {image_size}'

        img = self.norm(img)

        logits = self.encoder(img)

        if return_logits:
            return logits

        temp = default(temp, self.temperature)

        one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=self.straight_through)

        if self.straight_through and self.reinmax:
            # ReinMax: replace the straight-through gradient with a
            # corrected estimate while keeping the hard sample as the value.
            one_hot = one_hot.detach()
            pi_0 = logits.softmax(dim=1)
            pi_1 = (one_hot + (logits / temp).softmax(dim=1)) / 2
            pi_1 = ((log(pi_1) - logits).detach() + logits).softmax(dim=1)
            pi_2 = 2 * pi_1 - 0.5 * pi_0
            one_hot = pi_2 - pi_2.detach() + one_hot

        # Mix codebook vectors according to the (soft or hard) one-hot sample.
        sampled = einsum('b n h w, n d -> b d h w', one_hot, self.codebook.weight)
        out = self.decoder(sampled)

        if not return_loss:
            return out

        # Reconstruction loss is measured against the normalized input.
        recon_loss = self.loss_fn(img, out)

        # KL divergence between the code distribution and a uniform prior.
        logits = rearrange(logits, 'b n h w -> b (h w) n')
        log_qy = F.log_softmax(logits, dim=-1)
        log_uniform = torch.log(torch.tensor([1. / num_tokens], device=device))
        kl_div = F.kl_div(log_uniform, log_qy, None, None, 'batchmean', log_target=True)

        loss = recon_loss + (kl_div * kl_div_loss_weight)

        if not return_recons:
            return loss

        return loss, out
class CLIP(nn.Module):
    """Contrastive text/image model.

    Text tokens and image patches are each encoded by a non-causal
    transformer, mean-pooled, projected into a shared latent space and
    L2-normalized; similarity is the dot product scaled by a learned
    temperature.
    """

    def __init__(
        self,
        *,
        dim_text = 512,
        dim_image = 512,
        dim_latent = 512,
        num_text_tokens = 10000,
        text_enc_depth = 6,
        text_seq_len = 256,
        text_heads = 8,
        num_visual_tokens = 512,
        visual_enc_depth = 6,
        visual_heads = 8,
        visual_image_size = 256,
        visual_patch_size = 32,
        channels = 3
    ):
        super().__init__()

        # Text tower.
        self.text_emb = nn.Embedding(num_text_tokens, dim_text)
        self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
        self.text_transformer = Transformer(
            causal = False,
            seq_len = text_seq_len,
            dim = dim_text,
            depth = text_enc_depth,
            heads = text_heads,
            rotary_emb = False
        )
        self.to_text_latent = nn.Linear(dim_text, dim_latent, bias = False)

        # Image tower operating on flattened patches.
        assert visual_image_size % visual_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        patches_per_side = visual_image_size // visual_patch_size
        num_patches = patches_per_side ** 2
        patch_dim = channels * visual_patch_size ** 2

        self.visual_patch_size = visual_patch_size
        self.to_visual_embedding = nn.Linear(patch_dim, dim_image)
        self.visual_pos_emb = nn.Embedding(num_patches, dim_image)
        self.visual_transformer = Transformer(
            causal = False,
            seq_len = num_patches,
            dim = dim_image,
            depth = visual_enc_depth,
            heads = visual_heads,
            rotary_emb = False
        )
        self.to_visual_latent = nn.Linear(dim_image, dim_latent, bias = False)

        # Learned log-temperature; exponentiated in forward().
        self.temperature = nn.Parameter(torch.tensor(1.))

    def forward(
        self,
        text,
        image,
        text_mask = None,
        return_loss = False
    ):
        """Return per-pair similarities, or the symmetric contrastive loss when `return_loss` is set."""
        batch, device, patch = text.shape[0], text.device, self.visual_patch_size

        text_emb = self.text_emb(text)
        text_positions = torch.arange(text.shape[1], device = device)
        text_emb += self.text_pos_emb(text_positions)

        image_patches = rearrange(image, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch, p2 = patch)
        image_emb = self.to_visual_embedding(image_patches)
        patch_positions = torch.arange(image_emb.shape[1], device = device)
        image_emb += self.visual_pos_emb(patch_positions)

        enc_text = self.text_transformer(text_emb, mask = text_mask)
        enc_image = self.visual_transformer(image_emb)

        # Pool over the sequence; with a text mask only unmasked positions count.
        if not exists(text_mask):
            text_latents = enc_text.mean(dim = 1)
        else:
            text_latents = masked_mean(enc_text, text_mask, dim = 1)
        image_latents = enc_image.mean(dim = 1)

        text_latents = F.normalize(self.to_text_latent(text_latents), p = 2, dim = -1)
        image_latents = F.normalize(self.to_visual_latent(image_latents), p = 2, dim = -1)

        temp = self.temperature.exp()

        if return_loss:
            # All-pairs similarity; matching pairs sit on the diagonal.
            sim = einsum('i d, j d -> i j', text_latents, image_latents) * temp
            labels = torch.arange(batch, device = device)
            text_to_image = F.cross_entropy(sim, labels)
            image_to_text = F.cross_entropy(sim.t(), labels)
            return (text_to_image + image_to_text) / 2

        # Similarity of each text with its own image only.
        return einsum('n d, n d -> n', text_latents, image_latents) * temp
class DALLE(nn.Module):
    """Autoregressive transformer over a joint sequence of text tokens and VAE image tokens.

    Text and image tokens share one logit space: text ids come first,
    followed by image ids offset by `num_text_tokens`. The VAE is frozen
    and used only to tokenize and decode images.
    """

    def __init__(
        self,
        *,
        dim,
        vae,
        num_text_tokens = 10000,
        text_seq_len = 256,
        depth,
        heads = 8,
        dim_head = 64,
        reversible = False,
        attn_dropout = 0.,
        ff_dropout = 0,
        sparse_attn = False,
        attn_types = None,
        loss_img_weight = 7,
        stable = False,
        sandwich_norm = False,
        shift_tokens = True,
        rotary_emb = True,
        shared_attn_ids = None,
        shared_ff_ids = None,
        share_input_output_emb = False,
        optimize_for_inference = False,
    ):
        super().__init__()
        assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE)), 'vae must be an instance of DiscreteVAE'

        image_size = vae.image_size
        num_image_tokens = vae.num_tokens
        image_fmap_size = (vae.image_size // (2 ** vae.num_layers))
        image_seq_len = image_fmap_size ** 2

        # Reserve one extra token id per text position; forward() replaces
        # padding (id 0) with these position-specific ids.
        num_text_tokens = num_text_tokens + text_seq_len

        # With rotary embeddings the learned positional embeddings become a constant 0.
        self.text_pos_emb = nn.Embedding(text_seq_len + 1, dim) if not rotary_emb else always(0)
        self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (image_fmap_size, image_fmap_size)) if not rotary_emb else always(0)

        self.num_text_tokens = num_text_tokens
        self.num_image_tokens = num_image_tokens

        self.text_seq_len = text_seq_len
        self.image_seq_len = image_seq_len

        seq_len = text_seq_len + image_seq_len
        total_tokens = num_text_tokens + num_image_tokens
        self.total_tokens = total_tokens
        self.total_seq_len = seq_len

        self.vae = vae
        set_requires_grad(self.vae, False)

        self.transformer = Transformer(
            dim = dim,
            causal = True,
            seq_len = seq_len,
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            reversible = reversible,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            attn_types = attn_types,
            image_fmap_size = image_fmap_size,
            sparse_attn = sparse_attn,
            stable = stable,
            sandwich_norm = sandwich_norm,
            shift_tokens = shift_tokens,
            rotary_emb = rotary_emb,
            shared_attn_ids = shared_attn_ids,
            shared_ff_ids = shared_ff_ids,
            optimize_for_inference = optimize_for_inference,
        )

        self.stable = stable

        if stable:
            self.norm_by_max = DivideMax(dim = -1)

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, self.total_tokens),
        )

        if share_input_output_emb:
            # Tie the input embeddings to the rows of the output projection.
            self.text_emb = SharedEmbedding(self.to_logits[1], 0, num_text_tokens)
            self.image_emb = SharedEmbedding(self.to_logits[1], num_text_tokens, total_tokens)
        else:
            self.text_emb = nn.Embedding(num_text_tokens, dim)
            self.image_emb = nn.Embedding(num_image_tokens, dim)

        # True where a logit must be suppressed: text ids at image
        # positions and image ids at text positions.
        seq_range = torch.arange(seq_len)
        logits_range = torch.arange(total_tokens)

        seq_range = rearrange(seq_range, 'n -> () n ()')
        logits_range = rearrange(logits_range, 'd -> () () d')

        logits_mask = (
            ((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
            ((seq_range < text_seq_len) & (logits_range >= num_text_tokens))
        )

        self.register_buffer('logits_mask', logits_mask, persistent=False)
        self.loss_img_weight = loss_img_weight

    @torch.no_grad()
    @eval_decorator
    def generate_texts(
        self,
        tokenizer,
        text = None,
        *,
        filter_thres = 0.5,
        temperature = 1.
    ):
        """Autoregressively complete `text` (or start from token 0) up to `text_seq_len` tokens.

        Returns `(text_tokens, texts)`: the token tensor and the decoded strings.
        """
        text_seq_len = self.text_seq_len
        # Build the prompt on the model's own device (taken from a
        # registered buffer) rather than assuming a default CUDA device.
        device = self.logits_mask.device

        if text is None or text == "":
            text_tokens = torch.tensor([[0]], device = device)
        else:
            text_tokens = torch.tensor(tokenizer.tokenizer.encode(text), device = device).unsqueeze(0)

        for _ in range(text_tokens.shape[1], text_seq_len):
            tokens = self.text_emb(text_tokens)
            tokens += self.text_pos_emb(torch.arange(text_tokens.shape[1], device=device))

            seq_len = tokens.shape[1]

            output_transf = self.transformer(tokens)

            if self.stable:
                output_transf = self.norm_by_max(output_transf)

            logits = self.to_logits(output_transf)

            # Suppress image-token logits at text positions.
            logits_mask = self.logits_mask[:, :seq_len]
            max_neg_value = -torch.finfo(logits.dtype).max
            logits.masked_fill_(logits_mask, max_neg_value)
            logits = logits[:, -1, :]

            filtered_logits = top_k(logits, thres=filter_thres)
            sample = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)

            text_tokens = torch.cat((text_tokens, sample[:, None]), dim=-1)

        # Ids of the per-position padding tokens, handed to the tokenizer's decode.
        padding_tokens = set(np.arange(self.text_seq_len) + (self.num_text_tokens - self.text_seq_len))
        texts = [tokenizer.tokenizer.decode(text_token, pad_tokens=padding_tokens) for text_token in text_tokens]
        return text_tokens, texts

    @torch.no_grad()
    @eval_decorator
    def generate_images(
        self,
        text,
        *,
        clip=None,
        filter_thres=0.5,
        temperature=1.,
        img=None,
        num_init_img_tokens=None,
        cond_scale=1.,
        use_cache=False,
    ):
        """Sample image tokens conditioned on `text` and decode them with the VAE.

        `img` optionally primes generation with its first
        `num_init_img_tokens` codes. Returns the images, plus CLIP scores
        when `clip` is given.
        """
        vae, text_seq_len, image_seq_len, num_text_tokens = self.vae, self.text_seq_len, self.image_seq_len, self.num_text_tokens
        total_len = text_seq_len + image_seq_len

        text = text[:, :text_seq_len]
        out = text

        if exists(img):
            image_size = vae.image_size
            assert img.shape[1] == 3 and img.shape[2] == image_size and img.shape[3] == image_size, f'input image must have the correct image size {image_size}'

            indices = vae.get_codebook_indices(img)
            num_img_tokens = default(num_init_img_tokens, int(0.4375 * image_seq_len))
            assert num_img_tokens < image_seq_len, 'number of initial image tokens for priming must be less than the total image token sequence length'

            indices = indices[:, :num_img_tokens]
            out = torch.cat((out, indices), dim=-1)

        cache = {} if use_cache else None

        for cur_len in range(out.shape[1], total_len):
            is_image = cur_len >= text_seq_len

            text, image = out[:, :text_seq_len], out[:, text_seq_len:]

            logits = self.forward_with_cond_scale(text, image, cond_scale=cond_scale, cache=cache)
            logits = logits[:, -1, :]

            filtered_logits = top_k(logits, thres=filter_thres)
            sample = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)

            # Image ids live after the text ids in logit space; shift back to VAE code ids.
            sample -= (num_text_tokens if is_image else 0)
            out = torch.cat((out, sample[:, None]), dim=-1)

        text_seq = out[:, :text_seq_len]

        img_seq = out[:, -image_seq_len:]
        images = vae.decode(img_seq)

        if exists(clip):
            scores = clip(text_seq, images, return_loss=False)
            return images, scores

        return images

    def forward_with_cond_scale(self, *args, cond_scale = 1, cache = None, **kwargs):
        """Run forward with classifier-free guidance.

        With `cond_scale == 1` this is a plain forward pass. Otherwise the
        conditional and null-conditioned logits are combined as
        `null + (cond - null) * cond_scale`.
        """
        if cond_scale == 1:
            # Pass the cache through so `use_cache` is honoured without guidance too.
            return self(*args, cache = cache, **kwargs)

        # The null pass needs its own copy of the cache as it was before this step.
        prev_cache = cache.copy() if exists(cache) else None
        logits = self(*args, cache = cache, **kwargs)

        null_cond_logits = self(*args, null_cond_prob = 1., cache = prev_cache, **kwargs)
        return null_cond_logits + (logits - null_cond_logits) * cond_scale

    def forward(
        self,
        text,
        image = None,
        return_loss = False,
        null_cond_prob = 0.,
        cache = None,
    ):
        """Return logits over the joint vocabulary, or the weighted text/image loss.

        `image` may be raw pixels (4-D, tokenized via the VAE) or code
        indices. `null_cond_prob` is the per-sample probability of
        blanking the text. `cache` is a dict used for incremental decoding.
        """
        assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})'
        batch, device, total_seq_len = text.shape[0], text.device, self.total_seq_len

        if null_cond_prob > 0:
            null_mask = prob_mask_like((batch,), null_cond_prob, device=device)
            # Out-of-place on purpose: `text` may be a view of the caller's
            # tensor, and zeroing it in place would destroy the caller's prompt.
            text = text * rearrange(~null_mask, 'b -> b 1')

        # Replace padding (id 0) with a unique padding id per position.
        text_range = torch.arange(self.text_seq_len, device=device) + (self.num_text_tokens - self.text_seq_len)
        text = torch.where(text == 0, text_range, text)

        # Prepend token 0 as the start of the sequence.
        text = F.pad(text, (1, 0), value=0)

        tokens = self.text_emb(text)
        tokens += self.text_pos_emb(torch.arange(text.shape[1], device=device))

        seq_len = tokens.shape[1]

        if exists(image) and not is_empty(image):
            is_raw_image = len(image.shape) == 4

            if is_raw_image:
                image_size = self.vae.image_size
                channels = self.vae.channels
                assert tuple(image.shape[1:]) == (channels, image_size, image_size), f'invalid image of dimensions {image.shape} passed in during training'

                image = self.vae.get_codebook_indices(image)

            image_len = image.shape[1]
            image_emb = self.image_emb(image)

            image_emb += self.image_pos_emb(image_emb)

            tokens = torch.cat((tokens, image_emb), dim=1)

            seq_len += image_len

        # When the full sequence is present (training), drop the last token:
        # it has nothing left to predict.
        if tokens.shape[1] > total_seq_len:
            seq_len -= 1
            tokens = tokens[:, :-1]

        if self.stable:
            # Scale down the gradient reaching the embeddings; values are unchanged.
            alpha = 0.1
            tokens = tokens * alpha + tokens.detach() * (1 - alpha)

        if exists(cache) and cache.get('offset'):
            # Earlier positions are already cached; only feed the newest token.
            tokens = tokens[:, -1:]
        out = self.transformer(tokens, cache=cache)

        if self.stable:
            out = self.norm_by_max(out)

        logits = self.to_logits(out)

        # Make sure text positions only predict text ids and image positions image ids.
        logits_mask = self.logits_mask[:, :seq_len]
        if exists(cache) and cache.get('offset'):
            logits_mask = logits_mask[:, -1:]
        max_neg_value = -torch.finfo(logits.dtype).max
        logits.masked_fill_(logits_mask, max_neg_value)

        if exists(cache):
            cache['offset'] = cache.get('offset', 0) + logits.shape[1]

        if not return_loss:
            return logits

        assert exists(image), 'when training, image must be supplied'

        offsetted_image = image + self.num_text_tokens
        labels = torch.cat((text[:, 1:], offsetted_image), dim=1)

        logits = rearrange(logits, 'b n c -> b c n')

        loss_text = F.cross_entropy(logits[:, :, :self.text_seq_len], labels[:, :self.text_seq_len])
        loss_img = F.cross_entropy(logits[:, :, self.text_seq_len:], labels[:, self.text_seq_len:])

        loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1)
        return loss
.\lucidrains\DALLE-pytorch\dalle_pytorch\distributed_backends\deepspeed_backend.py
import json
import os
import torch
from .distributed_backend import DistributedBackend
class DeepSpeedBackend(DistributedBackend):
"""使用 DeepSpeed 引擎的分布式后端。"""
BACKEND_MODULE_NAME = 'deepspeed'
BACKEND_NAME = 'DeepSpeed'
def wrap_arg_parser(self, parser):
if not self.has_backend():
parser.add_argument(
'--deepspeed',
type=lambda _: False,
help=(
'是否使用 DeepSpeed '
"(由于不可用,此选项被忽略)"
),
)
else:
parser = self.backend_module.add_config_arguments(parser)
parser.add_argument(
'--local_rank',
type=int,
default=-1,
help='从分布式启动器传递的本地排名',
)
return parser
def _initialize(self):
self.backend_module.init_distributed()
if torch.cuda.is_available():
torch.cuda.set_device(self._get_local_rank())
@staticmethod
def _require_torch_distributed_init():
"""当 `torch.distributed` 尚未初始化时引发错误。"""
assert torch.distributed.is_initialized(), \
('`torch.distributed` 未初始化;请在脚本开头调用 '
'`DeepSpeedBackend.initialize`')
def _get_world_size(self):
self._require_torch_distributed_init()
return torch.distributed.get_world_size()
def _get_rank(self):
self._require_torch_distributed_init()
return torch.distributed.get_rank()
def _get_local_rank(self):
self._require_torch_distributed_init()
return int(os.environ['LOCAL_RANK'])
def _local_barrier(self):
self._require_torch_distributed_init()
torch.distributed.barrier()
def _check_args(self, args, optimizer, lr_scheduler, kwargs):
"""在检查传递给 `distribute` 的值后,返回适当的优化器和学习率调度器。"""
self._check_argvs(args, optimizer, lr_scheduler, kwargs)
(optimizer, lr_scheduler) = self._check_config(
args, optimizer, lr_scheduler, kwargs)
return (optimizer, lr_scheduler)
def _check_argvs(self, args, optimizer, lr_scheduler, kwargs):
"""对给定的命令行参数应用几个合理性检查。"""
has_json_config = (hasattr(args, 'deepspeed_config')
and args.deepspeed_config is not None)
has_dict_config = 'config_params' in kwargs
if (
(not has_json_config and not has_dict_config)
or (not has_dict_config
and not os.path.isfile(args.deepspeed_config))
):
return
if not args.deepspeed:
print(
'警告:已选择 DeepSpeed 后端;设置 `args.deepspeed = True`'
)
args.deepspeed = True
if has_json_config and has_dict_config:
print(
'警告:DeepSpeed 配置同时以 JSON 文件和 Python 字典形式给出。Python 字典优先。'
)
def _check_config(self, args, optimizer, lr_scheduler, kwargs):
"""Return an appropriate optimizer and learning rate scheduler
for the DeepSpeed configuration.
"""
if 'config_params' in kwargs:
config = kwargs['config_params']
else:
with open(args.deepspeed_config, 'r') as json_config_file:
config = json.load(json_config_file)
if 'optimizer' in config and optimizer is not None:
print(
'WARNING: Optimizer encountered in both DeepSpeed config and '
'keyword arguments. Optimizer in DeepSpeed config '
'takes precedence.'
)
optimizer = None
if 'scheduler' in config and lr_scheduler is not None:
print(
'WARNING: Learning rate scheduler encountered in both '
'DeepSpeed config and keyword arguments. Learning rate '
'scheduler in DeepSpeed config takes precedence.'
)
lr_scheduler = None
return (optimizer, lr_scheduler)
def _distribute(
self,
args=None,
model=None,
optimizer=None,
model_parameters=None,
training_data=None,
lr_scheduler=None,
**kwargs,
):
"""Return a distributed model engine, optimizer, dataloader, and
learning rate scheduler. These are obtained by wrapping the
given values with the backend.
For the other or other possible arguments,
see `deepspeed.initialize`.
"""
(optimizer, lr_scheduler) = self._check_args(
args, optimizer, lr_scheduler, kwargs)
return self.backend_module.initialize(
args=args,
model=model,
optimizer=optimizer,
model_parameters=model_parameters,
training_data=training_data,
lr_scheduler=lr_scheduler,
**kwargs,
)
def _average_all(self, tensor):
    """Return `tensor` summed over all workers and divided by the world size.

    The input is left untouched: a detached clone is reduced in place.
    Asserts beforehand that `torch.distributed` is initialized.
    """
    self._require_torch_distributed_init()
    reduced = tensor.detach().clone()
    torch.distributed.all_reduce(reduced, torch.distributed.ReduceOp.SUM)
    world_size = self.get_world_size()
    return reduced / world_size
.\lucidrains\DALLE-pytorch\dalle_pytorch\distributed_backends\distributed_backend.py
"""
An abstract backend for distributed deep learning.
Provides several standard utility methods under a common API.
Please check the documentation of the class `DistributedBackend` for
details to implement a new backend.
"""
from importlib import import_module
class DistributedBackend:
    """An abstract backend class for distributed deep learning.

    Provides several standard utility methods under a common API.
    The public methods check that the backend has been initialized and
    then delegate to the underscore-prefixed methods, which subclasses
    implement.

    Variables that must be overridden:
    - BACKEND_MODULE_NAME
    - BACKEND_NAME

    Methods that must be overridden:
    - wrap_arg_parser
    - _initialize
    - _get_world_size
    - _get_rank
    - _get_local_rank
    - _local_barrier
    - _distribute
    - _average_all
    """

    BACKEND_MODULE_NAME = None
    """Name of the module to import for the backend."""
    BACKEND_NAME = None
    """Name of the backend for printing."""

    ROOT_RANK = 0

    backend_module = None
    """The module to access the backend."""

    is_initialized = False
    """Whether the backend is initialized."""

    def __init__(self):
        """Check that the subclass has set the required class variables.

        Raises:
            NotImplementedError: if `BACKEND_MODULE_NAME` or
                `BACKEND_NAME` is still `None`.
        """
        if self.BACKEND_MODULE_NAME is None:
            raise NotImplementedError('BACKEND_MODULE_NAME is not set')
        if self.BACKEND_NAME is None:
            raise NotImplementedError('BACKEND_NAME is not set')

    def has_backend(self):
        """Return whether the backend module can be imported.

        On success the imported module is stored in
        `self.backend_module`.
        """
        try:
            self.backend_module = import_module(self.BACKEND_MODULE_NAME)
        except ModuleNotFoundError:
            return False
        return True

    def check_batch_size(self, batch_size):
        """Check whether the batch size makes sense for distribution."""
        assert batch_size >= self.get_world_size(), \
            (f"batch size can't be smaller than number of processes "
             f'({batch_size} < {self.get_world_size()})')

    def wrap_arg_parser(self, parser):
        """Add arguments to support optional distributed backend usage."""
        raise NotImplementedError

    def initialize(self):
        """Initialize the distributed backend."""
        self._initialize()
        self.is_initialized = True

    def _initialize(self):
        """Initialize the distributed backend."""
        raise NotImplementedError

    def require_init(self):
        """Raise an error when the backend has not been initialized yet.

        Raises:
            AssertionError: if `initialize` has not been called.
        """
        # The name must be read from the instance: a bare `BACKEND_NAME`
        # is not defined in this scope and would raise `NameError`
        # while building the message.
        assert self.is_initialized, \
            (f'{self.BACKEND_NAME} backend has not been initialized; '
             f'please call '
             f'`distributed_utils.initialize` at the start of your script to '
             f'allow optional distributed usage')

    def get_world_size(self):
        """Return the amount of distributed processes."""
        self.require_init()
        return self._get_world_size()

    def _get_world_size(self):
        """Return the amount of distributed processes."""
        raise NotImplementedError

    def get_rank(self):
        """Return the global rank of the calling worker process."""
        self.require_init()
        return self._get_rank()

    def _get_rank(self):
        """Return the global rank of the calling worker process."""
        raise NotImplementedError

    def get_local_rank(self):
        """Return the local rank of the calling worker process.
        The local rank is the rank based on a single node's processes.
        """
        self.require_init()
        return self._get_local_rank()

    def _get_local_rank(self):
        """Return the local rank of the calling worker process.
        The local rank is the rank based on a single node's processes.
        """
        raise NotImplementedError

    def is_root_worker(self):
        """Return whether the calling worker has the root rank."""
        return self.get_rank() == self.ROOT_RANK

    def is_local_root_worker(self):
        """Return whether the calling worker has the root rank on this node."""
        return self.get_local_rank() == self.ROOT_RANK

    def local_barrier(self):
        """Wait until all processes on this node have called this function."""
        self.require_init()
        self._local_barrier()

    def _local_barrier(self):
        """Wait until all processes on this node have called this function."""
        raise NotImplementedError

    def distribute(
            self,
            args=None,
            model=None,
            optimizer=None,
            model_parameters=None,
            training_data=None,
            lr_scheduler=None,
            **kwargs,
    ):
        """Return a distributed model engine, optimizer, dataloader, and
        learning rate scheduler. These are obtained by wrapping the
        given values with the backend.
        """
        self.require_init()
        # Forwarded positionally, so subclasses may name their
        # `_distribute` parameters differently.
        return self._distribute(
            args,
            model,
            optimizer,
            model_parameters,
            training_data,
            lr_scheduler,
            **kwargs,
        )

    def _distribute(
            self,
            args=None,
            model=None,
            optimizer=None,
            model_parameters=None,
            training_data=None,
            lr_scheduler=None,
            **kwargs,
    ):
        """Return a distributed model engine, optimizer, dataloader, and
        learning rate scheduler. These are obtained by wrapping the
        given values with the backend.
        """
        raise NotImplementedError

    def average_all(self, tensor):
        """Return the average of `tensor` over all workers."""
        self.require_init()
        return self._average_all(tensor)

    def _average_all(self, tensor):
        """Return the average of `tensor` over all workers."""
        raise NotImplementedError
.\lucidrains\DALLE-pytorch\dalle_pytorch\distributed_backends\dummy_backend.py
from .distributed_backend import DistributedBackend
class DummyBackend(DistributedBackend):
    """Stand-in backend for running without any distribution.

    Implements the backend interface for a single process so that the
    same program can run non-distributed.
    """

    # Placeholder only: `has_backend` is overridden and never imports it.
    BACKEND_MODULE_NAME = 'NO MODULE'
    BACKEND_NAME = 'Dummy'

    def has_backend(self):
        """Report the backend as always available; nothing is imported."""
        return True

    def wrap_arg_parser(self, parser):
        """Return `parser` unchanged; no arguments are added."""
        return parser

    def _initialize(self):
        """Do nothing; there is no backend to set up."""
        return None

    def _get_world_size(self):
        """Return 1, the world size of a single process."""
        return 1

    def _get_rank(self):
        """Return the root rank as the global rank."""
        return self.ROOT_RANK

    def _get_local_rank(self):
        """Return the root rank as the local rank."""
        return self.ROOT_RANK

    def _local_barrier(self):
        """Do nothing; there is nothing to wait for."""
        return None

    def _distribute(
            self,
            _args=None,
            model=None,
            optimizer=None,
            _model_parameters=None,
            training_data=None,
            lr_scheduler=None,
            **_kwargs,
    ):
        """Return the model, optimizer, dataloader, and learning rate
        scheduler exactly as given; the other arguments are ignored.
        """
        unchanged = (model, optimizer, training_data, lr_scheduler)
        return unchanged

    def _average_all(self, tensor):
        """Return `tensor` itself as the average."""
        return tensor
.\lucidrains\DALLE-pytorch\dalle_pytorch\distributed_backends\horovod_backend.py
import torch
from .distributed_backend import DistributedBackend
class HorovodBackend(DistributedBackend):
    """Distributed backend using Horovod."""

    BACKEND_MODULE_NAME = 'horovod.torch'
    BACKEND_NAME = 'Horovod'

    def wrap_arg_parser(self, parser):
        """Return `parser` unchanged; no arguments are added."""
        return parser

    def check_batch_size(self, batch_size):
        """Accept any batch size; the base class check is skipped."""
        return None

    def _initialize(self):
        """Initialize the backend module and select the CUDA device.

        When CUDA is available, the device whose index is this
        process's local rank becomes the current device.
        """
        self.backend_module.init()
        if not torch.cuda.is_available():
            return
        local_rank = self._get_local_rank()
        torch.cuda.set_device(local_rank)

    def _get_world_size(self):
        """Return the backend module's `size()`."""
        return self.backend_module.size()

    def _get_rank(self):
        """Return the backend module's `rank()`."""
        return self.backend_module.rank()

    def _get_local_rank(self):
        """Return the backend module's `local_rank()`."""
        return self.backend_module.local_rank()

    def _local_barrier(self):
        """Synchronize by calling the backend module's `join()`."""
        self.backend_module.join()

    def _distribute(
            self,
            _args=None,
            model=None,
            optimizer=None,
            _model_parameters=None,
            training_data=None,
            lr_scheduler=None,
            **_kwargs,
    ):
        """Wrap the optimizer and broadcast state from the root rank.

        The optimizer is wrapped in the backend module's
        `DistributedOptimizer`; the model's state dict and the wrapped
        optimizer's state are then broadcast from `ROOT_RANK`.

        Returns:
            `(model, optimizer, training_data, lr_scheduler)` where only
            the optimizer has been replaced by its wrapped version.
        """
        hvd = self.backend_module
        root = self.ROOT_RANK

        wrapped_optimizer = hvd.DistributedOptimizer(optimizer)
        hvd.broadcast_parameters(model.state_dict(), root_rank=root)
        hvd.broadcast_optimizer_state(wrapped_optimizer, root_rank=root)
        return (model, wrapped_optimizer, training_data, lr_scheduler)

    def _average_all(self, tensor):
        """Return the backend module's `allreduce` of `tensor`.

        NOTE(review): relies on `allreduce` averaging by default;
        confirm for the Horovod version in use.
        """
        return self.backend_module.allreduce(tensor)
.\lucidrains\DALLE-pytorch\dalle_pytorch\distributed_backends\__init__.py
from .deepspeed_backend import DeepSpeedBackend
from .distributed_backend import DistributedBackend
from .dummy_backend import DummyBackend
from .horovod_backend import HorovodBackend
.\lucidrains\DALLE-pytorch\dalle_pytorch\distributed_utils.py
"""
Utility functions for optional distributed execution.
To use,
1. set the `BACKENDS` to the ones you want to make available,
2. in the script, wrap the argument parser with `wrap_arg_parser`,
3. in the script, set and use the backend by calling
`set_backend_from_args`.
You can check whether a backend is in use with the `using_backend`
function.
"""
from dalle_pytorch.distributed_backends import \
DeepSpeedBackend, \
DummyBackend, \
HorovodBackend
# Backend that `set_backend_from_args` falls back to when no distributed
# backend is requested.
_DEFAULT_BACKEND = DummyBackend()
"""Which backend to use by default. Assumed to be _not_ distributed."""
# Backends offered to the user. `wrap_arg_parser` lets each one extend the
# argument parser, and `set_backend_from_args` picks one by comparing
# `BACKEND_NAME` case-insensitively.
BACKENDS = [
_DEFAULT_BACKEND,
DeepSpeedBackend(),
HorovodBackend(),
]
# Both of the following stay `None` until `set_backend_from_args` assigns
# them.
is_distributed = None
"""Whether we are distributed."""
backend = None
"""Backend in usage."""
def wrap_arg_parser(parser):
    """Add arguments to support optional distributed backend usage.

    Registers `--distributed_backend` (alias `--distr_backend`) and then
    lets every backend in `BACKENDS` add its own arguments in turn.

    Returns:
        The parser returned by the last backend.
    """
    parser.add_argument(
        '--distributed_backend',
        '--distr_backend',
        type=str,
        default=None,
        help='which distributed backend to use. Do not distribute by default',
    )

    wrapped = parser
    for candidate in BACKENDS:
        # Each backend may return a different parser object.
        wrapped = candidate.wrap_arg_parser(wrapped)
    return wrapped
def set_backend_from_args(args):
    """Set and return the backend based on the given `args`.

    A truthy `args.deepspeed` overrides `args.distributed_backend` with
    the DeepSpeed backend name. Without a requested backend the default
    backend is selected and `is_distributed` becomes `False`; otherwise
    the backend whose `BACKEND_NAME` matches case-insensitively is
    selected and `is_distributed` becomes `True`.

    Raises:
        ModuleNotFoundError: if the selected backend's module cannot be
            imported.
        ValueError: if no backend in `BACKENDS` has the requested name.
    """
    global is_distributed, backend

    # Handle DeepSpeed's CLI argument being set.
    if args.deepspeed:
        args.distributed_backend = DeepSpeedBackend.BACKEND_NAME

    if not args.distributed_backend:
        is_distributed = False
        backend = _DEFAULT_BACKEND
        return backend

    requested_name = args.distributed_backend.lower()
    for candidate in BACKENDS:
        if candidate.BACKEND_NAME.lower() != requested_name:
            continue

        backend = candidate
        if not backend.has_backend():
            raise ModuleNotFoundError(
                f'{backend.BACKEND_NAME} backend selected but '
                'module not available'
            )

        print(f'Using {backend.BACKEND_NAME} for distributed execution')
        is_distributed = True
        return backend

    raise ValueError(
        'unknown backend; please check `distributed_utils.BACKENDS`')
def require_set_backend():
    """Raise an `AssertionError` when the backend has not been set."""
    backend_is_set = backend is not None
    assert backend_is_set, (
        'distributed backend is not set. Please call '
        '`distributed_utils.set_backend_from_args` at the start of your script'
    )
def using_backend(test_backend):
    """Return whether the backend is set to `test_backend`.

    `test_backend` may be a string of the name of the backend, which is
    compared exactly against `BACKEND_NAME`, or its class, which is
    checked with `isinstance`.

    Raises:
        AssertionError: if no backend has been set yet.
    """
    require_set_backend()
    if not isinstance(test_backend, str):
        return isinstance(backend, test_backend)
    return backend.BACKEND_NAME == test_backend
.\lucidrains\DALLE-pytorch\dalle_pytorch\loader.py
from pathlib import Path
from random import randint, choice
import PIL
from torch.utils.data import Dataset
from torchvision import transforms as T
class TextImageDataset(Dataset):
    """Dataset of (tokenized caption, image tensor) pairs read from a folder.

    Text files (`.txt`) and image files (`.png`, `.jpg`, `.jpeg`,
    `.bmp`) are collected recursively and paired by the stem of their
    file name; only stems that have both a text and an image file are
    kept. Samples that cannot be loaded are skipped in favour of
    another sample.
    """

    def __init__(self,
                 folder,
                 text_len=256,
                 image_size=128,
                 truncate_captions=False,
                 resize_ratio=0.75,
                 transparent=False,
                 tokenizer=None,
                 shuffle=False
                 ):
        """
        @param folder: Folder containing images and text files matched by
            their paths' respective "stem"
        @param text_len: Length passed to the tokenizer for each caption
        @param image_size: Output size of the random resized crop
        @param truncate_captions: Rather than throw an exception, captions
            which are too long will be truncated.
        @param resize_ratio: Lower bound of the random crop's area scale
        @param transparent: Convert images to RGBA instead of RGB
        @param tokenizer: Object whose `tokenize` method encodes a caption
        @param shuffle: Pick a random replacement when a sample is skipped,
            instead of the next index
        """
        super().__init__()
        self.shuffle = shuffle
        path = Path(folder)

        text_files = [*path.glob('**/*.txt')]
        image_files = [
            *path.glob('**/*.png'), *path.glob('**/*.jpg'),
            *path.glob('**/*.jpeg'), *path.glob('**/*.bmp')
        ]

        text_files = {text_file.stem: text_file for text_file in text_files}
        image_files = {image_file.stem: image_file for image_file in image_files}

        keys = (image_files.keys() & text_files.keys())

        # Sort so that an index refers to the same sample in every process
        # and run. Iterating the raw set of strings depends on per-process
        # hash randomization, which breaks index-based sharding across
        # distributed workers.
        self.keys = sorted(keys)
        self.text_files = {key: text_files[key] for key in self.keys}
        self.image_files = {key: image_files[key] for key in self.keys}
        self.text_len = text_len
        self.truncate_captions = truncate_captions
        self.resize_ratio = resize_ratio
        self.tokenizer = tokenizer

        image_mode = 'RGBA' if transparent else 'RGB'

        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert(image_mode)
                     if img.mode != image_mode else img),
            T.RandomResizedCrop(image_size,
                                scale=(self.resize_ratio, 1.),
                                ratio=(1., 1.)),
            T.ToTensor()
        ])

    def __len__(self):
        """Return the number of text/image pairs."""
        return len(self.keys)

    def random_sample(self):
        """Return the sample at a uniformly random index."""
        return self.__getitem__(randint(0, self.__len__() - 1))

    def sequential_sample(self, ind):
        """Return the sample after `ind`, wrapping around to index 0."""
        if ind >= self.__len__() - 1:
            return self.__getitem__(0)
        return self.__getitem__(ind + 1)

    def skip_sample(self, ind):
        """Return a replacement for the unusable sample at `ind`.

        A random sample when `shuffle` is set, the next one otherwise.
        """
        if self.shuffle:
            return self.random_sample()
        return self.sequential_sample(ind=ind)

    def __getitem__(self, ind):
        """Return `(tokenized_text, image_tensor)` for index `ind`.

        One non-empty line of the text file is chosen at random as the
        caption. If the text file has no non-empty line, or the image
        cannot be opened or decoded, a message is printed and another
        sample is returned instead (see `skip_sample`).
        """
        key = self.keys[ind]

        text_file = self.text_files[key]
        image_file = self.image_files[key]

        descriptions = [
            line for line in text_file.read_text().split('\n') if line
        ]
        try:
            description = choice(descriptions)
        except IndexError:
            # `choice` raises IndexError for an empty list: no captions.
            print(f"An exception occurred trying to load file {text_file}.")
            print(f"Skipping index {ind}")
            return self.skip_sample(ind)

        tokenized_text = self.tokenizer.tokenize(
            description,
            self.text_len,
            truncate_text=self.truncate_captions
        ).squeeze(0)

        try:
            # Close the file handle as soon as the tensor has been built;
            # the transform ends in `ToTensor`, so nothing refers to the
            # open image afterwards.
            with PIL.Image.open(image_file) as image:
                image_tensor = self.image_transform(image)
        except (PIL.UnidentifiedImageError, OSError):
            print(f"An exception occurred trying to load file {image_file}.")
            print(f"Skipping index {ind}")
            return self.skip_sample(ind)

        # Success
        return tokenized_text, image_tensor