Lucidrains 系列项目源码解析(一百零八)
.\lucidrains\vit-pytorch\vit_pytorch\mae.py
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat
from vit_pytorch.vit import Transformer
class MAE(nn.Module):
    """Masked-autoencoder pre-training wrapper around a ViT-style encoder.

    ``forward`` randomly hides ``masking_ratio`` of the patches, encodes only
    the visible ones with ``encoder.transformer``, and uses a small decoder
    to reconstruct the hidden patches; the MSE between predicted and true
    patch values is returned.

    Attributes read from ``encoder``: ``pos_embedding``, ``to_patch_embedding``
    (indexed as a sequence), ``pool`` and ``transformer``.
    """
    def __init__(
        self,
        *,
        encoder,
        decoder_dim,
        masking_ratio = 0.75,
        decoder_depth = 1,
        decoder_heads = 8,
        decoder_dim_head = 64
    ):
        super().__init__()
        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
        self.masking_ratio = masking_ratio

        self.encoder = encoder
        # last two dims of the positional table: (number of positions, encoder width)
        num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]

        # element 0 of the patch embedding produces the patches that the loss compares against;
        # the remaining elements embed them into tokens
        self.to_patch = encoder.to_patch_embedding[0]
        self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])
        # NOTE(review): assumes element 2 is the patch projection Linear, whose input width is the
        # number of values per flattened patch — confirm against the encoder's layout
        pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]

        self.decoder_dim = decoder_dim
        # project only when encoder and decoder widths differ
        self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
        # one learned token reused for every masked position
        self.mask_token = nn.Parameter(torch.randn(decoder_dim))
        self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
        self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
        self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)

    def forward(self, img):
        """Return the MSE reconstruction loss over the randomly masked patches of ``img``."""
        device = img.device

        patches = self.to_patch(img)
        batch, num_patches, *_ = patches.shape

        tokens = self.patch_to_emb(patches)
        if self.encoder.pool == "cls":
            # presumably position 0 belongs to the encoder's cls token, so it is skipped
            tokens += self.encoder.pos_embedding[:, 1:(num_patches + 1)]
        elif self.encoder.pool == "mean":
            # the whole table is added, so its length must equal num_patches here
            tokens += self.encoder.pos_embedding.to(device, dtype=tokens.dtype)
        # NOTE(review): any other `pool` value adds no positional information

        num_masked = int(self.masking_ratio * num_patches)
        # argsort of uniform noise = an independent random permutation per sample
        rand_indices = torch.rand(batch, num_patches, device=device).argsort(dim=-1)
        masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]

        # (batch, 1) column so it broadcasts against the (batch, k) index tensors
        batch_range = torch.arange(batch, device=device)[:, None]
        tokens = tokens[batch_range, unmasked_indices]

        # reconstruction targets
        masked_patches = patches[batch_range, masked_indices]

        # only the visible tokens go through the encoder
        encoded_tokens = self.encoder.transformer(tokens)

        decoder_tokens = self.enc_to_dec(encoded_tokens)

        # decoder positions are re-added using each token's original patch index
        unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)

        mask_tokens = repeat(self.mask_token, 'd -> b n d', b=batch, n=num_masked)
        mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)

        # scatter visible and mask tokens back into original patch order before decoding
        decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device=device)
        decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
        decoder_tokens[batch_range, masked_indices] = mask_tokens
        decoded_tokens = self.decoder(decoder_tokens)

        # predict pixel values only at the masked positions
        mask_tokens = decoded_tokens[batch_range, masked_indices]
        pred_pixel_values = self.to_pixels(mask_tokens)

        recon_loss = F.mse_loss(pred_pixel_values, masked_patches)
        return recon_loss
.\lucidrains\vit-pytorch\vit_pytorch\max_vit.py
from functools import partial
import torch
from torch import nn, einsum
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
def exists(val):
    """Return True when *val* is anything other than None."""
    if val is None:
        return False
    return True
def default(val, d):
    """Return *val*, or the fallback *d* when *val* is None."""
    if not exists(val):
        return d
    return val
def cast_tuple(val, length = 1):
    """Return *val* unchanged if it is a tuple; otherwise a tuple of *length* copies of it."""
    if isinstance(val, tuple):
        return val
    return (val,) * length
class Residual(nn.Module):
    """Identity skip connection around *fn*: computes ``fn(x) + x``.

    ``dim`` is accepted for call-site symmetry but is not used.
    """
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        branch = self.fn(x)
        return branch + x
class FeedForward(nn.Module):
    """Pre-norm MLP: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout.

    The hidden layer is ``int(dim * mult)`` wide; input and output widths are ``dim``.
    """
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden = int(dim * mult)
        layers = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
class SqueezeExcitation(nn.Module):
    """Rescale each channel of a ``b c h w`` map by an input-dependent sigmoid weight.

    The gate averages over ``h`` and ``w``, passes the channel vector through a
    bias-free bottleneck MLP of width ``int(dim * shrinkage_rate)`` (SiLU in
    between), and broadcasts the sigmoid output back over the spatial dims.
    """
    def __init__(self, dim, shrinkage_rate = 0.25):
        super().__init__()
        bottleneck = int(dim * shrinkage_rate)
        squeeze = Reduce('b c h w -> b c', 'mean')
        excite = [
            nn.Linear(dim, bottleneck, bias = False),
            nn.SiLU(),
            nn.Linear(bottleneck, dim, bias = False),
            nn.Sigmoid(),
        ]
        unsqueeze = Rearrange('b c -> b c 1 1')
        self.gate = nn.Sequential(squeeze, *excite, unsqueeze)

    def forward(self, x):
        weights = self.gate(x)
        return x * weights
class MBConvResidual(nn.Module):
    """Skip connection for an MBConv body, with per-sample stochastic depth applied to the branch."""
    def __init__(self, fn, dropout = 0.):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        residual = self.dropsample(self.fn(x))
        return residual + x
class Dropsample(nn.Module):
    """Stochastic depth: during training, zero the entire input of a random subset of samples.

    Each sample along dim 0 is kept with probability ``1 - prob``, and kept
    samples are scaled by ``1 / (1 - prob)`` so the expected value is
    unchanged. In eval mode, or when ``prob == 0``, the input is returned
    as-is.
    """
    def __init__(self, prob = 0):
        super().__init__()
        self.prob = prob

    def forward(self, x):
        device = x.device

        if self.prob == 0. or (not self.training):
            return x

        # One uniform draw per sample, broadcast over every remaining dim.
        # The old `torch.FloatTensor((b, 1, 1, 1), device=...)` interpreted the
        # tuple as DATA (a 4-element vector), not a shape, and the legacy
        # constructor also rejects non-CPU devices.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        keep_mask = torch.rand(mask_shape, device = device) > self.prob
        return x * keep_mask / (1 - self.prob)
def MBConv(
    dim_in,
    dim_out,
    *,
    downsample,
    expansion_rate = 4,
    shrinkage_rate = 0.25,
    dropout = 0.
):
    """Build an inverted-residual (MBConv) block.

    Layout: 1x1 expansion conv -> BN -> GELU -> 3x3 depthwise conv (stride 2
    when ``downsample``) -> BN -> GELU -> squeeze-excitation -> 1x1 projection
    -> BN. The expanded width is ``int(expansion_rate * dim_out)``. The
    block is wrapped in ``MBConvResidual`` (stochastic depth ``dropout``) only
    when the shape is preserved, i.e. ``dim_in == dim_out`` and no downsampling.
    """
    hidden_dim = int(expansion_rate * dim_out)
    depthwise_stride = 2 if downsample else 1

    expand = [
        nn.Conv2d(dim_in, hidden_dim, 1),
        nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
    ]
    depthwise = [
        nn.Conv2d(hidden_dim, hidden_dim, 3, stride = depthwise_stride, padding = 1, groups = hidden_dim),
        nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
    ]
    project = [
        SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
        nn.Conv2d(hidden_dim, dim_out, 1),
        nn.BatchNorm2d(dim_out),
    ]
    net = nn.Sequential(*expand, *depthwise, *project)

    shape_preserving = dim_in == dim_out and not downsample
    return MBConvResidual(net, dropout = dropout) if shape_preserving else net
class Attention(nn.Module):
    """Multi-head self-attention within each window, plus a learned 2-D relative position bias.

    ``forward`` takes a ``(batch, x, y, w1, w2, dim)`` tensor, i.e. an ``x`` by
    ``y`` grid of ``w1`` by ``w2`` windows, and attends only among the
    ``w1 * w2`` tokens of each window. The bias table holds one value per
    head for every relative offset in ``[-(w-1), w-1]^2``, so ``w1`` and ``w2``
    must both equal ``window_size``.
    """
    def __init__(
        self,
        dim,
        dim_head = 32,
        dropout = 0.,
        window_size = 7
    ):
        super().__init__()
        assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'

        self.heads = dim // dim_head
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)

        self.attend = nn.Sequential(
            nn.Softmax(dim = -1),
            nn.Dropout(dropout)
        )

        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias = False),
            nn.Dropout(dropout)
        )

        span = 2 * window_size - 1
        self.rel_pos_bias = nn.Embedding(span ** 2, self.heads)

        coords = torch.arange(window_size)
        grid = torch.stack(torch.meshgrid(coords, coords, indexing = 'ij'))
        grid = rearrange(grid, 'c i j -> (i j) c')                 # (w*w, 2) token coordinates
        offsets = grid[:, None, :] - grid[None, :, :]               # pairwise offsets in [-(w-1), w-1]
        offsets = offsets + (window_size - 1)                       # shift into [0, 2w-2]
        rel_pos_indices = offsets[..., 0] * span + offsets[..., 1]  # flatten (dy, dx) into a table row
        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)

    def forward(self, x):
        _, height, width, window_height, window_width, _ = x.shape
        heads = self.heads

        x = self.norm(x)
        # every window becomes its own sequence
        x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = heads) for t in qkv)

        sim = einsum('b h i d, b h j d -> b h i j', q * self.scale, k)

        bias = self.rel_pos_bias(self.rel_pos_indices)   # (i, j, heads)
        sim = sim + rearrange(bias, 'i j h -> h i j')

        attn = self.attend(sim)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = window_height, w2 = window_width)
        out = self.to_out(out)
        return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width)
class MaxViT(nn.Module):
    """MaxViT image classifier.

    A stride-2 conv stem is followed by ``len(depth)`` stages; stage ``i`` has
    width ``dim * 2 ** i`` and ``depth[i]`` blocks. Each block is an MBConv
    (the first block of a stage downsamples), then window ("block")
    attention over contiguous ``window_size`` x ``window_size`` patches, then
    strided ("grid") attention, each attention followed by a feedforward, all
    with residual connections. The head mean-pools space, applies
    LayerNorm, and outputs ``num_classes`` logits.
    """
    def __init__(
        self,
        *,
        num_classes,
        dim,
        depth,
        dim_head = 32,
        dim_conv_stem = None,
        window_size = 7,
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        dropout = 0.1,
        channels = 3
    ):
        super().__init__()
        assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'

        stem_dim = default(dim_conv_stem, dim)
        self.conv_stem = nn.Sequential(
            nn.Conv2d(channels, stem_dim, 3, stride = 2, padding = 1),
            nn.Conv2d(stem_dim, stem_dim, 3, padding = 1)
        )

        widths = (stem_dim, *[(2 ** i) * dim for i in range(len(depth))])
        stage_io = list(zip(widths[:-1], widths[1:]))
        w = window_size

        def residual_attn_ff(width):
            # one attention + feedforward pair, each with a skip connection
            attn = Residual(width, Attention(dim = width, dim_head = dim_head, dropout = dropout, window_size = w))
            ff = Residual(width, FeedForward(dim = width, dropout = dropout))
            return attn, ff

        self.layers = nn.ModuleList([])

        for (width_in, width), num_blocks in zip(stage_io, depth):
            for block_idx in range(num_blocks):
                downsample = block_idx == 0
                conv = MBConv(
                    width_in if downsample else width,
                    width,
                    downsample = downsample,
                    expansion_rate = mbconv_expansion_rate,
                    shrinkage_rate = mbconv_shrinkage_rate
                )
                self.layers.append(nn.Sequential(
                    conv,
                    # block attention: contiguous w x w windows
                    Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w),
                    *residual_attn_ff(width),
                    Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),
                    # grid attention: w x w tokens spread with a stride across the map
                    Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w),
                    *residual_attn_ff(width),
                    Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
                ))

        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            nn.LayerNorm(widths[-1]),
            nn.Linear(widths[-1], num_classes)
        )

    def forward(self, x):
        x = self.conv_stem(x)
        for block in self.layers:
            x = block(x)
        return self.mlp_head(x)
.\lucidrains\vit-pytorch\vit_pytorch\max_vit_with_registers.py
from functools import partial
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module, ModuleList, Sequential
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce
def exists(val):
    """Report whether *val* holds a value, i.e. is not None."""
    return not (val is None)
def default(val, d):
    """Return *val* if it exists, otherwise the fallback *d*."""
    if exists(val):
        return val
    return d
def pack_one(x, pattern):
    """Pack a single tensor with einops ``pack``; returns ``(packed, packed_shapes)``."""
    packed, packed_shapes = pack([x], pattern)
    return packed, packed_shapes
def unpack_one(x, ps, pattern):
    """Undo ``pack_one``: unpack *x* using shapes *ps* and return the first (only) tensor."""
    unpacked = unpack(x, ps, pattern)
    return unpacked[0]
def cast_tuple(val, length = 1):
    """Return *val* unchanged if it is a tuple, otherwise a tuple of *length* copies of it.

    (The previous version was missing its closing parenthesis, which made the
    whole module a syntax error.)
    """
    return val if isinstance(val, tuple) else ((val,) * length)
def FeedForward(dim, mult = 4, dropout = 0.):
    """Return a pre-norm MLP as a ``Sequential`` whose hidden width is ``int(dim * mult)``.

    Layer order: LayerNorm, Linear, GELU, Dropout, Linear, Dropout.
    """
    hidden_dim = int(dim * mult)
    norm = nn.LayerNorm(dim)
    project_in = nn.Linear(dim, hidden_dim)
    project_out = nn.Linear(hidden_dim, dim)
    return Sequential(
        norm,
        project_in,
        nn.GELU(),
        nn.Dropout(dropout),
        project_out,
        nn.Dropout(dropout)
    )
class SqueezeExcitation(Module):
    """Channel attention for ``b c h w`` maps.

    The map is spatially averaged and passed through a bias-free
    ``dim -> int(dim * shrinkage_rate) -> dim`` MLP (SiLU, then sigmoid). The
    resulting per-channel weights multiply the input.
    """
    def __init__(self, dim, shrinkage_rate = 0.25):
        super().__init__()
        hidden_dim = int(dim * shrinkage_rate)
        reduce_in = nn.Linear(dim, hidden_dim, bias = False)
        expand_out = nn.Linear(hidden_dim, dim, bias = False)
        self.gate = Sequential(
            Reduce('b c h w -> b c', 'mean'),
            reduce_in,
            nn.SiLU(),
            expand_out,
            nn.Sigmoid(),
            Rearrange('b c -> b c 1 1')
        )

    def forward(self, x):
        gate = self.gate(x)
        return x * gate
class MBConvResidual(Module):
    """Add the input back onto ``fn(x)`` after per-sample stochastic depth (``Dropsample``)."""
    def __init__(self, fn, dropout = 0.):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        return x + self.dropsample(self.fn(x))
class Dropsample(Module):
    """Stochastic depth: during training, zero the entire input of a random subset of samples.

    Each sample along dim 0 survives with probability ``1 - prob``, and
    survivors are scaled by ``1 / (1 - prob)``. The input is returned
    untouched in eval mode or when ``prob == 0``.
    """
    def __init__(self, prob = 0):
        super().__init__()
        self.prob = prob

    def forward(self, x):
        device = x.device

        if self.prob == 0. or (not self.training):
            return x

        # One uniform draw per sample, broadcast over the remaining dims.
        # The previous `torch.FloatTensor((b, 1, 1, 1), device=...)` treated the
        # tuple as data (a 4-element vector) instead of a shape, and the legacy
        # constructor fails for non-CPU devices.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        keep_mask = torch.rand(mask_shape, device = device) > self.prob
        return x * keep_mask / (1 - self.prob)
def MBConv(
    dim_in,
    dim_out,
    *,
    downsample,
    expansion_rate = 4,
    shrinkage_rate = 0.25,
    dropout = 0.
):
    """Inverted-residual conv block used by every MaxViT block.

    1x1 expand (to ``int(expansion_rate * dim_out)``) -> BN -> GELU -> 3x3
    depthwise (stride 2 when ``downsample``) -> BN -> GELU -> squeeze-excitation
    -> 1x1 project -> BN. A stochastic-depth residual is added only when the
    output shape equals the input shape.
    """
    hidden_dim = int(expansion_rate * dim_out)
    stride = 2 if downsample else 1

    layers = [nn.Conv2d(dim_in, hidden_dim, 1), nn.BatchNorm2d(hidden_dim), nn.GELU()]
    layers.extend([
        nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
        nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
    ])
    layers.append(SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate))
    layers.extend([nn.Conv2d(hidden_dim, dim_out, 1), nn.BatchNorm2d(dim_out)])
    net = Sequential(*layers)

    if dim_in != dim_out or downsample:
        return net
    return MBConvResidual(net, dropout = dropout)
class Attention(Module):
    """Multi-head self-attention over one window's tokens preceded by ``num_registers`` register tokens.

    ``forward`` takes a ``(batch, num_registers + window_size ** 2, dim)``
    sequence. Pairs of window tokens receive a learned bias indexed by their
    2-D relative offset. Every pair involving a register token shares one
    extra learned bias row per head.
    """
    def __init__(
        self,
        dim,
        dim_head = 32,
        dropout = 0.,
        window_size = 7,
        num_registers = 1
    ):
        super().__init__()
        assert num_registers > 0
        assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'

        self.heads = dim // dim_head
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)

        self.attend = nn.Sequential(
            nn.Softmax(dim = -1),
            nn.Dropout(dropout)
        )

        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias = False),
            nn.Dropout(dropout)
        )

        span = 2 * window_size - 1
        num_rel_pos_bias = span ** 2
        # the extra last row is used for every pair that involves a register token
        self.rel_pos_bias = nn.Embedding(num_rel_pos_bias + 1, self.heads)

        coords = torch.arange(window_size)
        grid = rearrange(torch.stack(torch.meshgrid(coords, coords, indexing = 'ij')), 'c i j -> (i j) c')
        offsets = grid[:, None, :] - grid[None, :, :] + (window_size - 1)
        rel_pos_indices = offsets[..., 0] * span + offsets[..., 1]

        # registers sit before the window tokens, so pad leading rows and columns
        rel_pos_indices = F.pad(rel_pos_indices, (num_registers, 0, num_registers, 0), value = num_rel_pos_bias)
        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)

    def forward(self, x):
        heads = self.heads

        x = self.norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = [rearrange(t, 'b n (h d) -> b h n d', h = heads) for t in (q, k, v)]

        sim = einsum('b h i d, b h j d -> b h i j', q * self.scale, k)
        bias = rearrange(self.rel_pos_bias(self.rel_pos_indices), 'i j h -> h i j')
        sim = sim + bias

        attn = self.attend(sim)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        return self.to_out(rearrange(out, 'b h n d -> b n (h d)'))
class MaxViT(Module):
    """MaxViT classifier in which every attention window also attends to learned register tokens.

    Stage layout follows plain MaxViT: a stride-2 conv stem, then per stage
    ``depth[i]`` blocks of MBConv -> block (window) attention -> grid
    attention. In addition, each block owns a ``(num_register_tokens, width)``
    parameter whose tokens are prepended to the sequence of every window.
    """
    def __init__(
        self,
        *,
        num_classes,
        dim,
        depth,
        dim_head = 32,
        dim_conv_stem = None,
        window_size = 7,
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        dropout = 0.1,
        channels = 3,
        num_register_tokens = 4
    ):
        super().__init__()
        assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
        assert num_register_tokens > 0

        # convolutional stem; the first conv halves the resolution
        dim_conv_stem = default(dim_conv_stem, dim)

        self.conv_stem = Sequential(
            nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
            nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
        )

        # stage i has width dim * 2 ** i; pair each stage's input width with its output width
        num_stages = len(depth)

        dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
        dims = (dim_conv_stem, *dims)
        dim_pairs = tuple(zip(dims[:-1], dims[1:]))

        self.layers = nn.ModuleList([])

        # side length used for both block windows and the attention grid
        self.window_size = window_size

        # one register-token parameter per block, appended in the same loop as self.layers
        self.register_tokens = nn.ParameterList([])

        for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
            for stage_ind in range(layer_depth):
                # only the first block of a stage changes width and downsamples
                is_first = stage_ind == 0
                stage_dim_in = layer_dim_in if is_first else layer_dim

                conv = MBConv(
                    stage_dim_in,
                    layer_dim,
                    downsample = is_first,
                    expansion_rate = mbconv_expansion_rate,
                    shrinkage_rate = mbconv_shrinkage_rate
                )

                block_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
                block_ff = FeedForward(dim = layer_dim, dropout = dropout)

                grid_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
                grid_ff = FeedForward(dim = layer_dim, dropout = dropout)

                register_tokens = nn.Parameter(torch.randn(num_register_tokens, layer_dim))

                self.layers.append(ModuleList([
                    conv,
                    ModuleList([block_attn, block_ff]),
                    ModuleList([grid_attn, grid_ff])
                ]))

                self.register_tokens.append(register_tokens)

        # classifier head: spatial mean -> LayerNorm -> Linear
        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            nn.LayerNorm(dims[-1]),
            nn.Linear(dims[-1], num_classes)
        )

    def forward(self, x):
        """Return class logits for an image batch ``x`` (``b c h w``).

        NOTE(review): after each MBConv, height and width must be divisible by
        ``window_size`` for the rearranges below — confirm input resolutions.
        """
        b, w = x.shape[0], self.window_size

        x = self.conv_stem(x)

        for (conv, (block_attn, block_ff), (grid_attn, grid_ff)), register_tokens in zip(self.layers, self.register_tokens):
            x = conv(x)

            # block attention: non-overlapping contiguous w x w windows
            x = rearrange(x, 'b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w)

            # every window gets its own copy of this block's register tokens
            r = repeat(register_tokens, 'n d -> b x y n d', b = b, x = x.shape[1],y = x.shape[2])
            r, register_batch_ps = pack_one(r, '* n d')

            # flatten each window into a sequence and fold all windows into the batch dim
            x, window_ps = pack_one(x, 'b x y * d')
            x, batch_ps = pack_one(x, '* n d')
            # registers first, then the window's tokens
            x, register_ps = pack([r, x], 'b * d')

            x = block_attn(x) + x
            x = block_ff(x) + x

            # split the registers off and restore the (b, d, h, w) feature map
            r, x = unpack(x, register_ps, 'b * d')

            x = unpack_one(x, batch_ps, '* n d')
            x = unpack_one(x, window_ps, 'b x y * d')
            x = rearrange(x, 'b x y w1 w2 d -> b d (x w1) (y w2)')

            r = unpack_one(r, register_batch_ps, '* n d')

            # grid attention: w x w tokens taken with a stride across the whole map
            x = rearrange(x, 'b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w)

            # average the updated registers over all block windows, then broadcast to every grid window
            r = reduce(r, 'b x y n d -> b n d', 'mean')
            r = repeat(r, 'b n d -> b x y n d', x = x.shape[1], y = x.shape[2])
            r, register_batch_ps = pack_one(r, '* n d')

            x, window_ps = pack_one(x, 'b x y * d')
            x, batch_ps = pack_one(x, '* n d')
            x, register_ps = pack([r, x], 'b * d')

            x = grid_attn(x) + x

            # unlike the block stage, the feedforward here runs on image tokens only;
            # the registers split off below are not used again
            r, x = unpack(x, register_ps, 'b * d')

            x = grid_ff(x) + x

            x = unpack_one(x, batch_ps, '* n d')
            x = unpack_one(x, window_ps, 'b x y * d')
            x = rearrange(x, 'b x y w1 w2 d -> b d (w1 x) (w2 y)')

        return self.mlp_head(x)
.\lucidrains\vit-pytorch\vit_pytorch\mobile_vit.py
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Reduce
def conv_1x1_bn(inp, oup):
    """Pointwise (1x1) convolution without bias, followed by BatchNorm and SiLU."""
    layers = (
        nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU(),
    )
    return nn.Sequential(*layers)
def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
    """``kernel_size`` x ``kernel_size`` convolution (no bias) followed by BatchNorm and SiLU.

    Padding is ``kernel_size // 2``. For odd kernels this preserves the spatial
    size at stride 1 and gives ``ceil(size / 2)`` at stride 2. The former
    fixed padding of 1 was only "same" padding for 3x3 kernels. With e.g. a
    5x5 kernel the output shrank, and ``MobileViTBlock``'s channel concat with
    its unchanged input then failed.
    """
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size, stride, kernel_size // 2, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )
class FeedForward(nn.Module):
    """Pre-norm two-layer MLP with SiLU: ``dim -> hidden_dim -> dim``.

    Dropout is applied after the activation and after the output projection.
    """
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        contract = nn.Linear(hidden_dim, dim)
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            expand,
            nn.SiLU(),
            nn.Dropout(dropout),
            contract,
            nn.Dropout(dropout)
        )

    def forward(self, x):
        out = self.net(x)
        return out
class Attention(nn.Module):
    """Pre-norm multi-head self-attention over dim 2 of a ``(b, p, n, dim)`` tensor.

    Each index ``p`` along dim 1 attends independently among its ``n`` tokens.
    """
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x):
        normed = self.norm(x)

        def split_heads(t):
            return rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads)

        q, k, v = [split_heads(t) for t in self.to_qkv(normed).chunk(3, dim=-1)]

        scores = (q @ k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        merged = rearrange(weights @ v, 'b p h n d -> b p n (h d)')
        return self.to_out(merged)
class Transformer(nn.Module):
    """Transformer block described in ViT.
    Paper: https://arxiv.org/abs/2010.11929
    Based on: https://github.com/lucidrains/vit-pytorch

    ``depth`` layers, each an attention sublayer then a feedforward sublayer,
    both added back onto their input.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim, heads, dim_head, dropout),
                FeedForward(dim, mlp_dim, dropout)
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        for attention, feedforward in self.layers:
            x = x + attention(x)
            x = x + feedforward(x)
        return x
class MV2Block(nn.Module):
    """MV2 block described in MobileNetV2.
    Paper: https://arxiv.org/pdf/1801.04381
    Based on: https://github.com/tonylins/pytorch-mobilenet-v2

    Optional 1x1 expansion to ``int(inp * expansion)`` channels (skipped when
    ``expansion == 1``), a 3x3 depthwise conv with the given ``stride``, and a
    linear 1x1 projection to ``oup``. The input is added back when
    ``stride == 1`` and ``inp == oup``.
    """

    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expansion)
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expansion != 1:
            # pointwise expansion (with expansion == 1, hidden_dim already equals inp)
            layers += [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
            ]
        layers += [
            # depthwise
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU(),
            # pointwise-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return out + x if self.use_res_connect else out
class MobileViTBlock(nn.Module):
    """MobileViT block: local convs, a transformer over unfolded patches, then fusion with the input.

    The feature map (``channel`` channels) goes through an n x n conv and a
    1x1 conv to ``dim``. It is then unfolded so that every position inside a
    ``patch_size`` patch forms its own sequence across patches, processed by a
    transformer (4 heads of width 8), folded back, and projected to
    ``channel``. The result is concatenated with the block input and fused by
    a final n x n conv. Height and width must be divisible by ``patch_size``.
    """
    def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size

        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)

        self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)

        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

    def forward(self, x):
        shortcut = x.clone()
        ph, pw = self.ph, self.pw

        # local representation
        feats = self.conv2(self.conv1(x))

        # global representation
        _, _, height, width = feats.shape
        feats = rearrange(feats, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=ph, pw=pw)
        feats = self.transformer(feats)
        feats = rearrange(feats, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
                          h=height // ph, w=width // pw, ph=ph, pw=pw)

        # fusion
        feats = self.conv3(feats)
        return self.conv4(torch.cat((feats, shortcut), 1))
class MobileViT(nn.Module):
    """MobileViT.
    Paper: https://arxiv.org/abs/2110.02178
    Based on: https://github.com/chinhsuanwu/mobilevit-pytorch

    Args:
        image_size: ``(height, width)``; each must be divisible by ``patch_size``.
        dims: three transformer widths, one per MobileViT stage.
        channels: per-layer channel widths. Indices 0-9 and the last entry are
            used. Each MobileViT block keeps its width, so ``channels[4] ==
            channels[5]``, ``channels[6] == channels[7]`` and ``channels[8] ==
            channels[9]`` are required. The head projects ``channels[-2]`` to
            ``channels[-1]``.
            NOTE(review): channels[-2] is only the trunk's output width when
            len(channels) == 11 — confirm configs.
        num_classes: number of output logits.
        expansion: MV2Block expansion factor.
        kernel_size: kernel size of the MobileViT blocks' n x n convs.
        patch_size: ``(ph, pw)`` patch size for the transformer unfold.
        depths: transformer depth of each of the three MobileViT blocks.
    """

    def __init__(
        self,
        image_size,
        dims,
        channels,
        num_classes,
        expansion=4,
        kernel_size=3,
        patch_size=(2, 2),
        depths=(2, 4, 3)
    ):
        super().__init__()
        assert len(dims) == 3, 'dims must be a tuple of 3'
        assert len(depths) == 3, 'depths must be a tuple of 3'

        ih, iw = image_size
        ph, pw = patch_size
        assert ih % ph == 0 and iw % pw == 0

        init_dim, *_, last_dim = channels

        self.conv1 = conv_nxn_bn(3, init_dim, stride=2)

        self.stem = nn.ModuleList([])
        self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
        self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
        self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
        # This block receives the previous block's output (channels[3] wide).
        # It used to declare channels[2] as its input width, which only worked
        # when channels[2] == channels[3].
        self.stem.append(MV2Block(channels[3], channels[3], 1, expansion))

        self.trunk = nn.ModuleList([])
        self.trunk.append(nn.ModuleList([
            MV2Block(channels[3], channels[4], 2, expansion),
            MobileViTBlock(dims[0], depths[0], channels[5],
                           kernel_size, patch_size, int(dims[0] * 2))
        ]))

        self.trunk.append(nn.ModuleList([
            MV2Block(channels[5], channels[6], 2, expansion),
            MobileViTBlock(dims[1], depths[1], channels[7],
                           kernel_size, patch_size, int(dims[1] * 4))
        ]))

        self.trunk.append(nn.ModuleList([
            MV2Block(channels[7], channels[8], 2, expansion),
            MobileViTBlock(dims[2], depths[2], channels[9],
                           kernel_size, patch_size, int(dims[2] * 4))
        ]))

        self.to_logits = nn.Sequential(
            conv_1x1_bn(channels[-2], last_dim),
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(channels[-1], num_classes, bias=False)
        )

    def forward(self, x):
        x = self.conv1(x)

        for conv in self.stem:
            x = conv(x)

        for conv, attn in self.trunk:
            x = conv(x)
            x = attn(x)

        return self.to_logits(x)
.\lucidrains\vit-pytorch\vit_pytorch\mp3.py
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
def exists(val):
    """True unless *val* is None."""
    is_missing = val is None
    return not is_missing
def default(val, d):
    """Fall back to *d* when *val* is None."""
    return d if not exists(val) else val
def pair(t):
    """Return *t* if it is already a tuple, else the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)
def posemb_sincos_2d(patches, temperature = 10000, dtype = None):
    """Fixed 2-D sine/cosine positional embedding for a ``(batch, h, w, dim)`` patch grid.

    Returns an ``(h * w, dim)`` tensor whose rows follow row-major ``(h, w)``
    order. Each row is ``[sin(x*f), cos(x*f), sin(y*f), cos(y*f)]`` over
    ``dim // 4`` frequencies ``f = temperature ** -(k / (dim // 4 - 1))``.

    Args:
        patches: rank-4 tensor; only its shape, device and dtype are read.
        temperature: base of the geometric frequency progression.
        dtype: dtype of the result; ``None`` means ``patches.dtype``. The old
            default ``torch.float32`` was always overwritten by
            ``patches.dtype``, so ``None`` keeps existing callers' behaviour
            while an explicit dtype is now honoured.
    """
    _, h, w, dim = patches.shape
    device = patches.device
    dtype = patches.dtype if dtype is None else dtype

    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'

    num_freqs = dim // 4
    # with a single frequency (dim == 4) the old denominator was 0, giving 0/0 = NaN
    omega = torch.arange(num_freqs, device = device) / max(num_freqs - 1, 1)
    omega = 1. / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    return pe.type(dtype)
class FeedForward(nn.Module):
    """Pre-norm GELU MLP mapping ``dim -> hidden_dim -> dim`` with dropout."""
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        layers = [nn.LayerNorm(dim)]
        layers += [nn.Linear(dim, hidden_dim), nn.GELU(), nn.Dropout(dropout)]
        layers += [nn.Linear(hidden_dim, dim), nn.Dropout(dropout)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class Attention(nn.Module):
    """Multi-head attention from ``x`` to an optional ``context`` (self-attention when omitted).

    ``x`` must be rank 3 ``(batch, seq, dim)``. The same LayerNorm is applied
    to ``x`` and, when given, to ``context``.
    """
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x, context = None):
        _, _, _ = x.shape   # unpacking enforces a 3-D input
        heads = self.heads

        x = self.norm(x)
        if exists(context):
            context = self.norm(context)
        else:
            context = x

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = heads) for t in (q, k, v))

        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        weights = self.dropout(self.attend(scores))

        out = einsum('b h i j, b h j d -> b h i d', weights, v)
        return self.to_out(rearrange(out, 'b h n d -> b n (h d)'))
class Transformer(nn.Module):
    """Stack of ``depth`` residual (attention, feedforward) layers.

    An optional ``context`` is passed to every attention layer as its key/value source.
    """
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            attention = Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)
            feedforward = FeedForward(dim, mlp_dim, dropout = dropout)
            self.layers.append(nn.ModuleList([attention, feedforward]))

    def forward(self, x, context = None):
        for attention, feedforward in self.layers:
            x = x + attention(x, context = context)
            x = x + feedforward(x)
        return x
class ViT(nn.Module):
    """Vision transformer classifier with fixed 2-D sin-cos positions and mean pooling.

    Images are cut into ``patch_size`` patches and embedded to a ``(b, h, w, dim)``
    grid. A sin-cos positional table is added, then the transformer runs, the
    tokens are averaged, and a LayerNorm + Linear head produces
    ``num_classes`` logits.
    """
    def __init__(self, *, num_classes, image_size, patch_size, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width

        # exposed for wrappers such as MP3, which size their heads from these
        self.dim = dim
        self.num_patches = num_patches

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.to_latent = nn.Identity()
        self.linear_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        # (the unused `h`, `w`, `dtype` locals unpacked from img.shape were removed)
        x = self.to_patch_embedding(img)
        # positional table of shape (h * w, dim), broadcast over the batch
        pe = posemb_sincos_2d(x)
        x = rearrange(x, 'b ... d -> b (...) d') + pe

        x = self.transformer(x)
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        return self.linear_head(x)
class MP3(nn.Module):
    """Masked Position Prediction (MP3) pre-training wrapper around a ``ViT``.

    Patch tokens are embedded without any positional information. Every
    token attends (as query) to a random visible subset of the tokens
    (keep ratio ``1 - masking_ratio``), and a linear head predicts each
    token's patch index. The cross-entropy of that prediction is returned.
    """
    def __init__(self, vit: ViT, masking_ratio):
        super().__init__()
        self.vit = vit

        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
        self.masking_ratio = masking_ratio

        dim = vit.dim
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, vit.num_patches)
        )

    def forward(self, img):
        device = img.device

        patch_tokens = rearrange(self.vit.to_patch_embedding(img), 'b ... d -> b (...) d')
        batch, num_patches, *_ = patch_tokens.shape

        num_masked = int(self.masking_ratio * num_patches)
        permutation = torch.rand(batch, num_patches, device=device).argsort(dim=-1)
        keep_indices = permutation[:, num_masked:]

        rows = torch.arange(batch, device=device)[:, None]
        visible = patch_tokens[rows, keep_indices]

        # all tokens query; only the visible subset serves as keys/values
        attended = self.vit.transformer(patch_tokens, visible)

        logits = rearrange(self.mlp_head(attended), 'b n d -> (b n) d')
        targets = repeat(torch.arange(num_patches, device=device), 'n -> (b n)', b=batch)
        return F.cross_entropy(logits, targets)
.\lucidrains\vit-pytorch\vit_pytorch\mpp.py
import math
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat, reduce
def exists(val):
    """Whether *val* is set (not None)."""
    return None is not val
def prob_mask_like(t, prob):
    """Random ``(batch, seq_length)`` boolean mask, each entry True with probability *prob*.

    Only the first two sizes of the rank-3 tensor *t* are read. The mask is
    created on the default device, not on ``t.device``; callers move it.
    """
    batch, seq_length, _ = t.shape
    noise = torch.zeros((batch, seq_length)).float()
    noise.uniform_(0, 1)
    return noise < prob
def get_mask_subset_with_prob(patched_input, prob):
    """Select exactly ``ceil(prob * seq_len)`` random positions per row of a ``(batch, seq_len, _)`` input.

    Returns a ``(batch, seq_len)`` boolean mask on the input's device.
    """
    batch, seq_len, _ = patched_input.shape
    device = patched_input.device

    num_to_mask = math.ceil(prob * seq_len)
    scores = torch.rand((batch, seq_len), device=device)
    chosen = scores.topk(num_to_mask, dim=-1).indices

    mask = torch.zeros((batch, seq_len), device=device)
    mask.scatter_(1, chosen, 1)
    return mask.bool()
class MPPLoss(nn.Module):
    """Cross-entropy loss for masked patch prediction.

    Each patch's target is its per-channel mean value (after optional
    un-normalisation with ``mean``/``std`` and clamping at ``max_pixel_val``).
    That value is quantised into ``2 ** output_channel_bits`` bins per channel,
    and the bins are combined across channels into one class id in
    ``[0, (2 ** bits) ** channels)``. Only positions where ``mask`` is True
    contribute.
    """
    def __init__(
        self,
        patch_size,
        channels,
        output_channel_bits,
        max_pixel_val,
        mean,
        std
    ):
        super().__init__()
        self.patch_size = patch_size
        self.channels = channels
        self.output_channel_bits = output_channel_bits
        self.max_pixel_val = max_pixel_val

        # Explicit None checks. `if mean` raises for multi-element tensors/arrays,
        # and for mean == 0 it silently dropped a non-trivial std as well.
        self.mean = torch.tensor(mean).view(-1, 1, 1) if exists(mean) else None
        self.std = torch.tensor(std).view(-1, 1, 1) if exists(std) else None

    def forward(self, predicted_patches, target, mask):
        """Return the loss of ``predicted_patches`` (logits per patch) against image ``target`` at ``mask``."""
        p, c, mpv, bits, device = self.patch_size, self.channels, self.max_pixel_val, self.output_channel_bits, target.device
        bin_size = mpv / (2 ** bits)

        # un-normalize input
        if exists(self.mean) and exists(self.std):
            # mean/std are plain attributes, not buffers, so Module.to() never moves them;
            # bring them to the target's device to avoid a CPU/GPU mismatch
            target = target * self.std.to(device) + self.mean.to(device)

        # average each channel over every patch
        target = target.clamp(max=mpv)
        avg_target = reduce(target, 'b c (h p1) (w p2) -> b (h w) c', 'mean', p1=p, p2=p).contiguous()

        # per-channel bin ids in [0, 2 ** bits)
        channel_bins = torch.arange(bin_size, mpv, bin_size, device=device)
        discretized_target = torch.bucketize(avg_target, channel_bins)

        # combine the per-channel ids as digits of a base-(2 ** bits) number
        bin_mask = (2 ** bits) ** torch.arange(0, c, device=device).long()
        bin_mask = rearrange(bin_mask, 'c -> () () c')

        target_label = torch.sum(bin_mask * discretized_target, dim=-1)

        loss = F.cross_entropy(predicted_patches[mask], target_label[mask])
        return loss
class MPP(nn.Module):
    """Masked Patch Prediction (MPP) pre-training wrapper around a ViT-style ``transformer``.

    ``forward`` picks a subset of patches to predict. Some of them are
    replaced by other patches from the same image, and some by a learned mask
    token. The corrupted sequence runs through the wrapped model, and
    ``MPPLoss`` on the picked positions is returned.

    Attributes read from ``transformer``: ``to_patch_embedding`` (sliced from
    index 1), ``cls_token``, ``pos_embedding``, ``dropout`` and ``transformer``.
    """
    def __init__(
        self,
        transformer,
        patch_size,
        dim,
        output_channel_bits=3,
        channels=3,
        max_pixel_val=1.0,
        mask_prob=0.15,
        replace_prob=0.5,
        random_patch_prob=0.5,
        mean=None,
        std=None
    ):
        super().__init__()
        self.transformer = transformer
        self.loss = MPPLoss(patch_size, channels, output_channel_bits,
                            max_pixel_val, mean, std)

        # NOTE(review): the slice is passed as a single child (not unpacked with `*` as MAE does),
        # so it is nested one level deeper — check checkpoint key layout before changing this
        self.patch_to_emb = nn.Sequential(transformer.to_patch_embedding[1:])

        # one logit per joint colour class: 2 ** (bits * channels) classes
        self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))

        self.patch_size = patch_size

        self.mask_prob = mask_prob
        self.replace_prob = replace_prob
        self.random_patch_prob = random_patch_prob

        # learned replacement for an entire flattened patch (channels * patch_size ** 2 values)
        self.mask_token = nn.Parameter(torch.randn(1, 1, channels * patch_size ** 2))

    def forward(self, input, **kwargs):
        """Return the MPP loss for an image batch ``input`` (``b c h w``).

        Extra keyword arguments are forwarded to ``self.transformer.transformer``.
        """
        transformer = self.transformer
        # untouched copy of the images, used as the loss target
        img = input.clone().detach()

        # flatten the image into a sequence of patches
        p = self.patch_size
        input = rearrange(input,
                          'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                          p1=p,
                          p2=p)

        # exactly ceil(mask_prob * num_patches) positions per image are predicted
        mask = get_mask_subset_with_prob(input, self.mask_prob)

        masked_input = input.clone().detach()

        if self.random_patch_prob > 0:
            # NOTE(review): divides by zero when replace_prob == 1
            random_patch_sampling_prob = self.random_patch_prob / (
                1 - self.replace_prob)
            random_patch_prob = prob_mask_like(input,
                                               random_patch_sampling_prob).to(mask.device)

            # selected positions that get swapped with another patch
            bool_random_patch_prob = mask * (random_patch_prob == True)
            # for each position, a random patch index from the same image
            random_patches = torch.randint(0,
                                           input.shape[1],
                                           (input.shape[0], input.shape[1]),
                                           device=input.device)
            # NOTE(review): this row index is built on the default device while random_patches
            # is on input.device — confirm on GPU
            randomized_input = masked_input[
                torch.arange(masked_input.shape[0]).unsqueeze(-1),
                random_patches]
            masked_input[bool_random_patch_prob] = randomized_input[
                bool_random_patch_prob]

        # replace selected positions with the mask token; applied after the random swap,
        # so it overwrites any swapped patch it overlaps
        replace_prob = prob_mask_like(input, self.replace_prob).to(mask.device)
        bool_mask_replace = (mask * replace_prob) == True
        masked_input[bool_mask_replace] = self.mask_token

        masked_input = self.patch_to_emb(masked_input)

        # prepend the cls token
        b, n, _ = masked_input.shape
        cls_tokens = repeat(transformer.cls_token, '() n d -> b n d', b=b)
        masked_input = torch.cat((cls_tokens, masked_input), dim=1)

        masked_input += transformer.pos_embedding[:, :(n + 1)]
        masked_input = transformer.dropout(masked_input)

        masked_input = transformer.transformer(masked_input, **kwargs)
        cls_logits = self.to_bits(masked_input)
        # drop the cls position so logits line up with patches
        logits = cls_logits[:, 1:, :]

        mpp_loss = self.loss(logits, img, mask)

        return mpp_loss
.\lucidrains\vit-pytorch\vit_pytorch\na_vit.py
from functools import partial
from typing import List, Union
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
def exists(val):
    """Return whether *val* is not None."""
    if val is not None:
        return True
    return False
def default(val, d):
    """Use *val* when it exists, otherwise *d*."""
    fallback_needed = not exists(val)
    return d if fallback_needed else val
def always(val):
    """Return a function that ignores its positional arguments and returns *val*."""
    def constant(*args):
        return val
    return constant
def pair(t):
    """Broadcast a scalar to ``(t, t)``; tuples pass through unchanged."""
    return t if isinstance(t, tuple) else (t,) * 2
def divisible_by(numer, denom):
    """Return whether *numer* is an exact multiple of *denom*."""
    remainder = numer % denom
    return remainder == 0
def group_images_by_max_seq_len(
    images: List[Tensor],
    patch_size: int,
    calc_token_dropout = None,
    max_seq_len = 2048
) -> List[List[Tensor]]:
    """Greedily pack images, in order, into groups whose total token count fits *max_seq_len*.

    An image's token count is ``(H // patch_size) * (W // patch_size)`` scaled
    by ``1 - calc_token_dropout(H, W)`` and truncated to int.
    ``calc_token_dropout`` may be None (no dropout), a constant number, or a
    callable of ``(H, W)``. A new group starts whenever the next image would
    overflow the current one. An image that alone exceeds *max_seq_len* fails
    an assertion.
    """
    calc_token_dropout = default(calc_token_dropout, always(0.))
    if isinstance(calc_token_dropout, (float, int)):
        calc_token_dropout = always(calc_token_dropout)

    groups = []
    current, current_len = [], 0

    for image in images:
        assert isinstance(image, Tensor)

        image_dims = image.shape[-2:]
        num_patches_h, num_patches_w = (size // patch_size for size in image_dims)
        tokens = int(num_patches_h * num_patches_w * (1 - calc_token_dropout(*image_dims)))

        assert tokens <= max_seq_len, f'image with dimensions {image_dims} exceeds maximum sequence length'

        if current_len + tokens > max_seq_len:
            groups.append(current)
            current, current_len = [], 0

        current.append(image)
        current_len += tokens

    if current:
        groups.append(current)

    return groups
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable gain and a fixed zero bias.

    The bias is registered as a buffer rather than a parameter, so it is saved
    in ``state_dict`` and follows ``.to()`` but is not returned by ``parameters()``.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        normalized_shape = x.shape[-1:]
        return F.layer_norm(x, normalized_shape, weight = self.gamma, bias = self.beta)
class RMSNorm(nn.Module):
    """Per-head normalisation used on queries and keys.

    Vectors are L2-normalised along the last dimension, rescaled by
    ``sqrt(dim)`` and multiplied by a learnable gain of shape ``(heads, 1, dim)``.
    """

    def __init__(self, heads, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))

    def forward(self, x):
        # NOTE(review): gamma broadcasts as (heads, 1, dim); presumably x is
        # (batch, heads, seq, dim_head) -- confirm at call sites
        unit = F.normalize(x, dim = -1)
        return unit * self.scale * self.gamma
def FeedForward(dim, hidden_dim, dropout = 0.):
    """Build a pre-norm MLP branch.

    Layout: LayerNorm -> Linear(dim, hidden_dim) -> GELU -> Dropout
    -> Linear(hidden_dim, dim) -> Dropout. No residual is added here.
    """
    layers = [
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, dim),
        nn.Dropout(dropout),
    ]
    return nn.Sequential(*layers)
class Attention(nn.Module):
    """Pre-norm multi-head attention with per-head query/key normalisation.

    Queries are projected from ``x``; keys and values from ``context`` when it
    is given, otherwise from ``x``. Only ``x`` goes through the LayerNorm.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads

        self.norm = LayerNorm(dim)
        self.q_norm = RMSNorm(heads, dim_head)
        self.k_norm = RMSNorm(heads, dim_head)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_mask = None
    ):
        """Attend from ``x`` over ``context`` (or ``x``).

        Args:
            x: query tokens, split into heads with 'b n (h d) -> b h n d'.
            context: optional key/value tokens; defaults to the normalised ``x``.
            mask: optional boolean mask of shape (b, j) over keys; False entries are excluded.
            attn_mask: optional boolean mask broadcastable to (b, h, i, j); False entries are excluded.
        """
        x = self.norm(x)
        source = x if context is None else context

        q = self.to_q(x)
        k, v = self.to_kv(source).chunk(2, dim = -1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in (q, k, v))

        q = self.q_norm(q)
        k = self.k_norm(k)

        # NOTE(review): no explicit 1/sqrt(d) factor; presumably the q/k
        # normalisation with learned gains stands in for it
        sim = torch.matmul(q, k.transpose(-1, -2))

        mask_value = -torch.finfo(sim.dtype).max
        if mask is not None:
            # (b, j) -> (b, 1, 1, j) so it broadcasts over heads and queries
            sim = sim.masked_fill(~rearrange(mask, 'b j -> b 1 1 j'), mask_value)
        if attn_mask is not None:
            sim = sim.masked_fill(~attn_mask, mask_value)

        attn = self.dropout(self.attend(sim))

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
class Transformer(nn.Module):
    """Stack of ``depth`` residual (Attention, FeedForward) pairs followed by a final LayerNorm.

    Both sub-blocks normalise their own input (pre-norm), and each output is
    added back onto the running activations.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout),
            ])
            for _ in range(depth)
        ])
        self.norm = LayerNorm(dim)

    def forward(
        self,
        x,
        mask = None,
        attn_mask = None
    ):
        """Run every layer; ``mask`` / ``attn_mask`` are forwarded to each attention block."""
        for attn_block, ff_block in self.layers:
            x = x + attn_block(x, mask = mask, attn_mask = attn_mask)
            x = x + ff_block(x)
        return self.norm(x)
class NaViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., token_dropout_prob = None):
    """Build the NaViT patch embedding, factorised positional embeddings,
    transformer, attention-pooling head and classifier.

    Args:
        image_size: int or (height, width); sizes the positional-embedding tables.
        patch_size: side length of a square patch; must divide both image dims.
        token_dropout_prob: ``None`` (no token dropout), a callable
            ``(height, width) -> prob``, or a number strictly between 0 and 1.
    """
    super().__init__()

    image_height, image_width = pair(image_size)

    # normalise token dropout into either None or a (height, width) -> prob callable
    if callable(token_dropout_prob):
        self.calc_token_dropout = token_dropout_prob
    elif isinstance(token_dropout_prob, (float, int)):
        assert 0. < token_dropout_prob < 1.
        fixed_prob = float(token_dropout_prob)
        self.calc_token_dropout = lambda height, width: fixed_prob
    else:
        self.calc_token_dropout = None

    assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'

    num_patches_height = image_height // patch_size
    num_patches_width = image_width // patch_size
    patch_dim = channels * patch_size ** 2

    self.channels = channels
    self.patch_size = patch_size

    # construction order below is kept fixed so parameter init consumes the RNG identically
    self.to_patch_embedding = nn.Sequential(
        LayerNorm(patch_dim),
        nn.Linear(patch_dim, dim),
        LayerNorm(dim),
    )

    # separate row / column position tables, one vector per patch index
    self.pos_embed_height = nn.Parameter(torch.randn(num_patches_height, dim))
    self.pos_embed_width = nn.Parameter(torch.randn(num_patches_width, dim))

    self.dropout = nn.Dropout(emb_dropout)
    self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

    # learned query vector and attention block used for pooling
    self.attn_pool_queries = nn.Parameter(torch.randn(dim))
    self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)

    self.to_latent = nn.Identity()

    self.mlp_head = nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, num_classes, bias = False)
    )
@property
def device(self):
    """Device of the first parameter yielded by ``self.parameters()``."""
    first_param = next(self.parameters())
    return first_param.device
def forward(
self,
batched_images: Union[List[Tensor], List[List[Tensor]]],
group_images = False,
group_max_seq_len = 2048
.\lucidrains\vit-pytorch\vit_pytorch\nest.py
from functools import partial
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
def cast_tuple(val, depth):
    """Repeat ``val`` into a ``depth``-length tuple unless it already is a tuple."""
    if isinstance(val, tuple):
        return val
    return tuple(val for _ in range(depth))
class LayerNorm(nn.Module):
    """Layer normalisation across the channel dimension (dim 1) of a channels-first tensor.

    Mean and biased variance are computed over dim 1 independently at every
    other index; the learnable gain ``g`` and bias ``b`` have shape
    ``(1, dim, 1, 1)``, so they assume 4-D input.
    """

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        mean = x.mean(dim = 1, keepdim = True)
        var = x.var(dim = 1, unbiased = False, keepdim = True)
        std = (var + self.eps).sqrt()
        return (x - mean) / std * self.g + self.b
class FeedForward(nn.Module):
    """Pre-norm pointwise MLP over channels built from 1x1 convolutions.

    Channels are expanded by ``mlp_mult`` with GELU in between and projected
    back to ``dim``; no residual is added here.
    """

    def __init__(self, dim, mlp_mult = 4, dropout = 0.):
        super().__init__()
        hidden = dim * mlp_mult
        self.net = nn.Sequential(
            LayerNorm(dim),
            nn.Conv2d(dim, hidden, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden, dim, 1),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        out = self.net(x)
        return out
class Attention(nn.Module):
    """Pre-norm multi-head self-attention over all spatial positions of a feature map.

    q/k/v come from one bias-free 1x1 convolution; the per-head width is
    ``dim // heads`` and scores are scaled by ``head_dim ** -0.5``.
    """

    def __init__(self, dim, heads = 8, dropout = 0.):
        super().__init__()
        head_dim = dim // heads
        inner_dim = head_dim * heads

        self.heads = heads
        self.scale = head_dim ** -0.5

        self.norm = LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Sequential(nn.Conv2d(inner_dim, dim, 1), nn.Dropout(dropout))

    def forward(self, x):
        # unpacking four dims also rejects non-4-D input, as before
        _, _, height, width = x.shape
        x = self.norm(x)

        # flatten the spatial grid into a token axis per head
        q, k, v = (
            rearrange(t, 'b (h d) x y -> b h (x y) d', h = self.heads)
            for t in self.to_qkv(x).chunk(3, dim = 1)
        )

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = self.dropout(self.attend(sim))

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = height, y = width)
        return self.to_out(out)
def Aggregate(dim, dim_out):
    """Transition between hierarchy levels.

    A 3x3 convolution maps ``dim`` -> ``dim_out`` channels, followed by channel
    LayerNorm and a 3x3 stride-2 max-pool with padding 1, which turns each
    spatial size ``s`` into ``ceil(s / 2)``.
    """
    conv = nn.Conv2d(dim, dim_out, 3, padding = 1)
    norm = LayerNorm(dim_out)
    pool = nn.MaxPool2d(3, stride = 2, padding = 1)
    return nn.Sequential(conv, norm, pool)
class Transformer(nn.Module):
    """Local transformer applied to one block of a feature map.

    A learned scalar per spatial position (``seq_len`` of them) is added to
    every channel, then ``depth`` residual (Attention, FeedForward) pairs run.
    There is no final norm here.
    """

    def __init__(self, dim, seq_len, depth, heads, mlp_mult, dropout = 0.):
        super().__init__()
        # created before the layers so parameter init consumes the RNG in the same order
        self.pos_emb = nn.Parameter(torch.randn(seq_len))
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim, heads = heads, dropout = dropout),
                FeedForward(dim, mlp_mult, dropout = dropout),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        height, width = x.shape[-2:]

        # take the first height*width scalars and broadcast them over batch and channels
        pos_bias = rearrange(self.pos_emb[:(height * width)], '(h w) -> () () h w', h = height, w = width)
        x = x + pos_bias

        for attn_block, ff_block in self.layers:
            x = x + attn_block(x)
            x = x + ff_block(x)
        return x
class NesT(nn.Module):
    """Nested hierarchical Transformer (NesT) image classifier.

    The patch-embedded feature map passes through ``num_hierarchies`` levels.
    At level ``L`` (counting down to 0) the map is cut into ``2**L x 2**L``
    non-overlapping blocks, each block is processed by a local Transformer,
    and, except at the final level, an ``Aggregate`` step doubles the channels
    and halves the resolution. The last level's output is channel-normalised,
    mean-pooled over space and projected to ``num_classes`` logits.

    Args:
        image_size: integer side length of the input (images are treated as square).
        patch_size: side length of a square patch; must divide ``image_size``.
        num_classes: number of output logits.
        dim: channel width at the first level; doubled at every following level.
        heads: attention heads at the first level; doubled at every following level.
        num_hierarchies: number of hierarchy levels.
        block_repeats: transformer depth per level, either one int for all
            levels or a tuple with at least one entry per level.
        mlp_mult: hidden-width multiplier of the feed-forward blocks.
        channels: number of input image channels.
        dim_head: accepted for API compatibility but unused here; the
            attention blocks derive their head width from ``dim // heads``.
        dropout: dropout rate inside every transformer block.

    Raises:
        AssertionError: if ``patch_size`` does not divide ``image_size`` or
            ``block_repeats`` has fewer entries than ``num_hierarchies``.
    """

    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        heads,
        num_hierarchies,
        block_repeats,
        mlp_mult = 4,
        channels = 3,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        assert (image_size % patch_size) == 0, 'Image dimensions must be divisible by the patch size.'
        patch_dim = channels * patch_size ** 2
        fmap_size = image_size // patch_size

        # The finest level splits the map into 2**(num_hierarchies - 1) blocks per
        # side; each later level halves both the map and the block count, so every
        # level attends over blocks of the same size and one seq_len serves all.
        blocks = 2 ** (num_hierarchies - 1)
        seq_len = (fmap_size // blocks) ** 2

        hierarchies = list(reversed(range(num_hierarchies)))   # e.g. [2, 1, 0]
        mults = [2 ** i for i in reversed(hierarchies)]         # e.g. [1, 2, 4]
        layer_heads = [mult * heads for mult in mults]
        layer_dims = [mult * dim for mult in mults]

        last_dim = layer_dims[-1]
        # repeat the last width so the final, non-aggregating level gets a (dim, dim) pair
        layer_dims = [*layer_dims, layer_dims[-1]]
        dim_pairs = zip(layer_dims[:-1], layer_dims[1:])

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = patch_size, p2 = patch_size),
            LayerNorm(patch_dim),
            nn.Conv2d(patch_dim, layer_dims[0], 1),
            LayerNorm(layer_dims[0])
        )

        block_repeats = cast_tuple(block_repeats, num_hierarchies)
        # zip() below silently stops at the shortest input: a short tuple would build
        # too few levels and only fail later inside forward with an opaque shape error
        assert len(block_repeats) >= num_hierarchies, f'block_repeats must provide a depth for each of the {num_hierarchies} hierarchies, got {len(block_repeats)}'

        self.layers = nn.ModuleList([])

        for level, layer_head, (dim_in, dim_out), depth in zip(hierarchies, layer_heads, dim_pairs, block_repeats):
            is_last = level == 0

            self.layers.append(nn.ModuleList([
                Transformer(dim_in, seq_len, depth, layer_head, mlp_mult, dropout),
                Aggregate(dim_in, dim_out) if not is_last else nn.Identity()
            ]))

        self.mlp_head = nn.Sequential(
            LayerNorm(last_dim),
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(last_dim, num_classes)
        )

    def forward(self, img):
        """Return class logits for a batch of images ``img`` (b, c, H, W)."""
        x = self.to_patch_embedding(img)
        num_hierarchies = len(self.layers)

        for level, (transformer, aggregate) in zip(reversed(range(num_hierarchies)), self.layers):
            block_size = 2 ** level
            # fold the blocks into the batch axis so attention stays within each block
            x = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size)
            x = transformer(x)
            x = rearrange(x, '(b b1 b2) c h w -> b c (b1 h) (b2 w)', b1 = block_size, b2 = block_size)
            x = aggregate(x)

        return self.mlp_head(x)