Lucidrains Series Projects: Source Code Walkthrough (72)
.\lucidrains\parti-pytorch\setup.py
# import setup utilities and package discovery
from setuptools import setup, find_packages
# execute the code in the version file so that __version__ is defined in the current namespace
exec(open('parti_pytorch/version.py').read())
# package metadata
setup(
name = 'parti-pytorch', # 包名
packages = find_packages(exclude=[]), # 查找包
version = __version__, # 版本号
license='MIT', # 许可证
description = 'Parti - Pathways Autoregressive Text-to-Image Model - Pytorch', # 描述
author = 'Phil Wang', # 作者
author_email = 'lucidrains@gmail.com', # 作者邮箱
long_description_content_type = 'text/markdown', # 长描述内容类型
url = 'https://github.com/lucidrains/parti-pytorch', # URL
keywords = [ # 关键词
'artificial intelligence',
'deep learning',
'transformers',
'attention mechanism',
'text-to-image'
],
install_requires=[ # 安装依赖
'einops>=0.7',
'einops-exts',
'ema-pytorch',
'torch>=1.6',
'torchvision',
'transformers',
'vector-quantize-pytorch>=1.11.8'
],
classifiers=[ # 分类
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
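For context, the exec(...) line above only works because parti_pytorch/version.py presumably contains nothing but a single __version__ assignment. A minimal sketch of the pattern (the version string and namespace variable are hypothetical, not taken from the repository):

# hypothetical contents of parti_pytorch/version.py
__version__ = '0.0.1'

# what setup.py effectively does: run that file and pick up __version__
namespace = {}
exec(open('parti_pytorch/version.py').read(), namespace)
print(namespace['__version__'])   # -> '0.0.1'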
.\lucidrains\pause-transformer\pause_transformer\pause_transformer.py
import torch
import torch.nn.functional as F
from torch import nn, Tensor, einsum
from torch.nn import Module, ModuleList, Sequential
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
# functions
# check whether a value exists
def exists(v):
return v is not None
# tensor functions
# log clamped away from zero to avoid -inf
def log(t, eps = 1e-20):
return t.clamp(min = eps).log()
# entropy of the softmax distribution over t (as written, this returns the sum of p * log p)
def entropy(t, dim = -1):
prob = t.softmax(dim = dim)
return (prob * log(prob)).sum(dim = dim)
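A quick sanity check of the helper above (an illustrative snippet, not from the repository). Note that, as written, entropy returns the sum of p * log p, i.e. the negative of the Shannon entropy:

import torch

logits = torch.zeros(4)                                    # uniform distribution over 4 outcomes
prob = logits.softmax(dim = -1)                            # every entry is 0.25
val = (prob * prob.clamp(min = 1e-20).log()).sum(dim = -1)
print(val)                                                 # ~ -1.3863 == -ln(4)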
# norm
# RMS normalization
class RMSNorm(Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
return F.normalize(x, dim = -1) * self.scale * self.gamma
# cheap relative positions
# from Peng Bo's RWKV
# token shift module
class ShiftTokens(Module):
def forward(self, x):
x, x_shift = x.chunk(2, dim = -1)
x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
return torch.cat((x, x_shift), dim = -1)
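A small worked example (hypothetical values) of what ShiftTokens does: half of the feature channels are shifted forward by one position, so every token also sees part of the previous token's features, RWKV-style:

import torch
import torch.nn.functional as F

x = torch.arange(24.).reshape(1, 3, 8)           # (batch, seq, dim), dim must be even
keep, shift = x.chunk(2, dim = -1)               # split channels into two halves
shift = F.pad(shift, (0, 0, 1, -1), value = 0.)  # prepend a zero step, drop the last step
out = torch.cat((keep, shift), dim = -1)

print(out[0, 0, 4:])                             # zeros: the first token has no predecessor
print(out[0, 1, 4:], x[0, 0, 4:])                # token 1 now carries token 0's shifted half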
# feedforward
def FeedForward(dim, mult = 4):
dim_inner = int(dim * mult)
return Sequential(
ShiftTokens(),
RMSNorm(dim),
nn.Linear(dim, dim_inner),
nn.GELU(),
nn.Linear(dim_inner, dim)
)
# CausalAttention
class CausalAttention(Module):
def __init__(
self,
dim,
*,
dim_head = 64,
heads = 8
):
super().__init__()
self.scale = dim ** -0.5
dim_inner = dim_head * heads
self.norm = RMSNorm(dim)
self.to_qkv = Sequential(
nn.Linear(dim, dim_inner * 3, bias = False),
Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
)
self.to_out = Sequential(
Rearrange('b h n d -> b n (h d)'),
nn.Linear(dim_inner, dim, bias = False)
)
def forward(self, x):
x = self.norm(x)
q, k, v = self.to_qkv(x)
q = q * self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)
i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), device = x.device, dtype = torch.bool).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
return self.to_out(out), torch.stack((k, v))
# integrate previous pause / thinking information
class IntegratePreviousThought(Module):
def __init__(self, dim):
super().__init__()
self.net = Sequential(
RMSNorm(dim),
Rearrange('b n p d -> b n (p d)'),
nn.Linear(dim * 2, dim)
)
def forward(
self,
x,
pause_tokens,
pause_lengths = None
):
if not exists(pause_lengths):
p = pause_tokens[:, :, -1]
else:
batch, seq_len = x.shape[:2]
batch_arange = torch.arange(batch, device = x.device)[:, None, None]
seq_arange = torch.arange(seq_len, device = x.device)[:, None]
pause_lengths = pause_lengths[:, :, None]
p = pause_tokens[batch_arange, seq_arange, pause_lengths]
p = rearrange(p, '... 1 d -> ... d')
p = F.pad(p, (0, 0, 1, -1), value = 0.)
x = torch.stack((x, p), dim = -2)
out = self.net(x)
return out
# class
# pause transformer
class PauseTransformer(Module):
def __init__(
self,
*,
num_tokens,
dim,
depth,
max_pause_length = 2,
dim_head = 64,
heads = 8,
ff_mult = 4
):
# 调用父类的构造函数
super().__init__()
# 创建一个嵌入层,用于将输入的 token 映射为指定维度的向量
self.token_emb = nn.Embedding(num_tokens, dim)
# 设置最大暂停长度
self.max_pause_length = max_pause_length
# 创建一个可学习的参数,表示暂停的 token
self.pause_tokens = nn.Parameter(torch.randn(max_pause_length, dim))
# 创建一个用于整合前一个暂停的模块
self.integrate_prev_pause = IntegratePreviousThought(dim)
# 创建一个空的模块列表,用于存储多个层
self.layers = ModuleList([])
# 根据指定的深度循环创建多个层
for _ in range(depth):
# 每个层包含一个自注意力机制和一个前馈神经网络
self.layers.append(ModuleList([
CausalAttention(dim = dim, dim_head = dim_head, heads = heads),
FeedForward(dim = dim, mult = ff_mult)
]))
# 创建一个用于输出 logits 的序列模块
self.to_logits = Sequential(
RMSNorm(dim),
nn.Linear(dim, num_tokens, bias = False)
)
def forward(
self,
x,
return_loss = False,
return_logit_entropy = False,
arrest_pausing = False,
no_prev_pause_integration = False,
pause_lengths = None,
rand_uniform_pausing = False # this would do random pausing uniform from [0, max_pause_length]
):
"""
einstein notation:
b - batch
n - main sequence length
p - thinking sequence length (pause)
d - feature dimension
"""
# 如果需要返回损失,则提取输入序列和标签序列
if return_loss:
x, labels = x[:, :-1], x[:, 1:]
# 如果不需要阻止暂停
if not arrest_pausing:
# 如果需要随机暂停且暂停长度未指定,则随机生成暂停长度
if rand_uniform_pausing and not exists(pause_lengths):
pause_lengths = torch.randint(0, self.max_pause_length, x.shape)
# 获取输入张量的批量大小和序列长度
batch, seq_len = x.shape
# 将输入 token 映射为向量
x = self.token_emb(x)
# 重复暂停 token,以便与输入张量形状匹配
p = repeat(self.pause_tokens, 'p d -> b n p d', b = batch, n = seq_len)
# 如果暂停长度已指定
if exists(pause_lengths):
max_pause = int(pause_lengths.amax().item())
p = p[:, :, :(max_pause + 1)]
# 如果最大暂停长度为 0,则阻止暂停
arrest_pausing = max_pause == 0
# 遍历每个层的自注意力机制和前馈神经网络
for attn, ff in self.layers:
attn_out, cached_kvs = attn(x)
x = x + attn_out
x = ff(x) + x
# 如果阻止暂停,则跳过暂停处理
if arrest_pausing:
continue
# 处理思考 token
x, ps = pack([x, p], 'b n * d')
x = rearrange(x, '... p d -> (...) p d')
attn_out, _ = attn(x)
x = x + attn_out
x = ff(x) + x
x = rearrange(x, '(b n) p d -> b n p d', b = batch)
x, p = unpack(x, ps, 'b n * d')
# 在训练过程中,允许每个 token 独立思考,不受前一个 token 思考的影响
if no_prev_pause_integration:
continue
# 整合前一个暂停的最后一个 token
x = x + self.integrate_prev_pause(x, p, pause_lengths)
# 如果不阻止暂停,则重新打包输入张量和暂停张量
if not arrest_pausing:
x, _ = pack([x, p], 'b n * d')
# 计算 logits
logits = self.to_logits(x)
# 如果需要返回 logits 的熵
if return_logit_entropy:
return entropy(logits)
# 如果不需要返回损失,则返回 logits
if not return_loss:
return logits
# 如果阻止暂停,则重新排列 logits 的形状
if arrest_pausing:
logits = rearrange(logits, 'b n d -> b d n')
else:
labels = repeat(labels, 'b n -> (b p) n', p = self.max_pause_length + 1)
logits = rearrange(logits, 'b n p d -> (b p) d n')
# 计算交叉熵损失
loss = F.cross_entropy(logits, labels)
return loss
.\lucidrains\pause-transformer\pause_transformer\__init__.py
# import the PauseTransformer class from the pause_transformer.pause_transformer module
from pause_transformer.pause_transformer import PauseTransformer

Pause Transformer (wip)
Yet another random morning idea to be quickly tried and architecture shared if it works; to allow the transformer to pause for any amount of time on any token.
Again, the idea relies on axial attention; one axis attends along the sequence length as in the usual transformer, the other along a thinking or pause dimension.
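A minimal usage sketch (hyperparameters are illustrative, assuming the package layout shown above):

import torch
from pause_transformer import PauseTransformer

model = PauseTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 6,
    max_pause_length = 2       # how many extra "thinking" slots each token may use
)

x = torch.randint(0, 256, (1, 1024))

loss = model(x, return_loss = True, rand_uniform_pausing = True)
loss.backward()

logits = model(x)              # (1, 1024, 3, 256): one set of logits per pause slot (pause dim = max_pause_length + 1)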
Todo
- allow for custom pause distributions across token
- see if one can do a two pass, using the logit entropy as a way to decide how to shape the pause mask
- run experiments on enwik8, but if do not see anything, move onwards to something harder, say arithmetic
Citations
@inproceedings{Goyal2023ThinkBY,
title = {Think before you speak: Training Language Models With Pause Tokens},
author = {Sachin Goyal and Ziwei Ji and Ankit Singh Rawat and Aditya Krishna Menon and Sanjiv Kumar and Vaishnavh Nagarajan},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:263608983}
}
.\lucidrains\pause-transformer\setup.py
# import the setup and find_packages functions from setuptools
from setuptools import setup, find_packages
# package metadata
setup(
# 包的名称
name = 'pause-transformer',
# 查找所有包,不排除任何包
packages = find_packages(exclude=[]),
# 版本号
version = '0.0.7',
# 许可证类型
license='MIT',
# 描述信息
description = 'Pause Transformer',
# 作者
author = 'Phil Wang',
# 作者邮箱
author_email = 'lucidrains@gmail.com',
# 长描述内容类型
long_description_content_type = 'text/markdown',
# 项目链接
url = 'https://github.com/lucidrains/pause-transformer',
# 关键词列表
keywords = [
'artificial intelligence',
'deep learning',
'adaptive computation'
],
# 安装依赖
install_requires=[
'einops>=0.7.0',
'torch>=2.0'
],
# 分类标签
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
Data source
The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/
.\lucidrains\perceiver-ar-pytorch\perceiver_ar_pytorch\autoregressive_wrapper.py
# import the torch library
import torch
# import torch functional utilities
import torch.nn.functional as F
# import the rearrange function from einops
from einops import rearrange
# import the nn module from torch
from torch import nn
# helper function to check whether a value exists
def exists(val):
return val is not None
# decorator that switches the model to eval mode for the duration of the call, then restores its previous state
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
# 定义一个函数用于对 logits 进行 top k 过滤
def top_k(logits, thres=0.9):
k = int((1 - thres) * logits.shape[-1])
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float("-inf"))
probs.scatter_(1, ind, val)
return probs
# 定义一个自回归封装器类
class AutoregressiveWrapper(nn.Module):
def __init__(self, net, pad_value=0):
super().__init__()
self.max_seq_len = net.max_seq_len
self.pad_value = pad_value
self.net = net
# 生成函数,用于生成序列
@torch.no_grad()
@eval_decorator
def generate(
self,
start_tokens,
seq_len,
eos_token=None,
temperature=1.0,
filter_thres=0.9,
**kwargs
):
b, n, device = *start_tokens.shape, start_tokens.device
out = start_tokens
for _ in range(seq_len):
logits = self.net(
out[:, -self.max_seq_len:],
**kwargs
)[:, -1]
filtered_logits = top_k(logits, thres=filter_thres)
probs = F.softmax(filtered_logits / temperature, dim=-1)
sample = torch.multinomial(probs, 1)
out = torch.cat((out, sample), dim=-1)
if exists(eos_token):
is_eos_token = out == eos_token
if is_eos_token.any(dim=-1).all():
# mask out everything after the eos tokens
shifted_is_eos_tokens = F.pad(is_eos_token, (1, -1))
mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
out = out.masked_fill(mask, self.pad_value)
break
out = out[:, n:]
return out
# 前向传播函数,用于模型训练
def forward(self, x, **kwargs):
x_inp, x_labels = x[:, :-1], x[:, 1:]
return self.net(x_inp, labels=x_labels, **kwargs)
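A minimal sketch of how this wrapper is typically used (mirroring train.py further below; the sizes here are illustrative):

import torch
from perceiver_ar_pytorch import PerceiverAR
from perceiver_ar_pytorch.autoregressive_wrapper import AutoregressiveWrapper

model = AutoregressiveWrapper(PerceiverAR(
    num_tokens = 256,
    dim = 512,
    depth = 2,
    max_seq_len = 1024,
    cross_attn_seq_len = 768
))

seq = torch.randint(0, 256, (1, 1025))        # one extra token, since forward() shifts inputs vs. labels
loss = model(seq)                             # cross entropy over the positions after the prefix
loss.backward()

sampled = model.generate(seq[:, :1024], 64)   # (1, 64) newly sampled tokens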
.\lucidrains\perceiver-ar-pytorch\perceiver_ar_pytorch\perceiver_ar_pytorch.py
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
# helper functions
# helper to check whether a value exists
def exists(val):
return val is not None
# feedforward
# feedforward layer constructor
def FeedForward(dim, mult = 4, dropout = 0.):
hidden_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim), # 对输入进行 Layer Normalization
nn.Linear(dim, hidden_dim, bias = False), # 线性变换
nn.GELU(), # GELU 激活函数
nn.Dropout(dropout), # Dropout 正则化
nn.Linear(hidden_dim, dim, bias = False) # 线性变换
)
# rotary positional embedding
# https://arxiv.org/abs/2104.09864
# 旋转位置嵌入类
class RotaryEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
def forward(self, max_seq_len, *, device):
seq = torch.arange(max_seq_len, device = device, dtype = self.inv_freq.dtype)
freqs = einsum("i , j -> i j", seq, self.inv_freq)
return torch.cat((freqs, freqs), dim = -1)
# 旋转半个张量
def rotate_half(x):
x = rearrange(x, "... (j d) -> ... j d", j = 2)
x1, x2 = x.unbind(dim = -2)
return torch.cat((-x2, x1), dim = -1)
# 应用旋转位置嵌入
def apply_rotary_pos_emb(pos, t):
seq_len, rotate_dim = t.shape[-2], pos.shape[-1]
pos = pos[..., -seq_len:, :]
t, t_pass = t[..., :rotate_dim], t[..., rotate_dim:]
t = (t * pos.cos()) + (rotate_half(t) * pos.sin())
return torch.cat((t, t_pass), dim = -1)
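A short shape check for the rotary helpers above (an illustrative snippet; PerceiverAR itself builds RotaryEmbedding(dim = max(32, dim_head // 2))):

import torch
from perceiver_ar_pytorch.perceiver_ar_pytorch import RotaryEmbedding, apply_rotary_pos_emb

rotary = RotaryEmbedding(dim = 32)
pos = rotary(8, device = torch.device('cpu'))   # (8, 32) table of rotation angles

q = torch.randn(1, 8, 8, 64)                    # (batch, heads, seq, dim_head)
q_rot = apply_rotary_pos_emb(pos, q)            # first 32 dims rotated, remaining 32 passed through
print(q_rot.shape)                              # torch.Size([1, 8, 8, 64])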
# attention
# 因果注意力机制类
class CausalAttention(nn.Module):
def __init__(
self,
*,
dim,
dim_head = 64,
heads = 8,
dropout = 0.
):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = heads * dim_head
self.norm = nn.LayerNorm(dim)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x, rotary_pos_emb = None):
x = self.norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
q = q * self.scale
if exists(rotary_pos_emb):
q = apply_rotary_pos_emb(rotary_pos_emb, q)
k = apply_rotary_pos_emb(rotary_pos_emb, k)
sim = einsum('b h i d, b h j d -> b h i j', q, k)
i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), device = x.device, dtype = torch.bool).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
# 因果前缀注意力机制类
class CausalPrefixAttention(nn.Module):
def __init__(
self,
*,
dim,
dim_head = 64,
heads = 8,
max_heads_process = 2,
dropout = 0.,
cross_attn_dropout = 0.
):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
self.max_heads_process = max_heads_process
inner_dim = heads * dim_head
self.norm = nn.LayerNorm(dim)
self.context_norm = nn.LayerNorm(dim)
self.dropout = nn.Dropout(dropout)
self.cross_attn_dropout = cross_attn_dropout # they drop out a percentage of the prefix during training, shown to help prevent overfitting
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim)
# 定义前向传播函数,接受输入 x、上下文 context、上下文掩码 context_mask 和旋转位置嵌入 rotary_pos_emb
def forward(self, x, context, context_mask = None, rotary_pos_emb = None):
# 获取输入 x 的批量大小、上下文长度和设备信息
batch, context_len, device = x.shape[0], context.shape[-2], x.device
# 复制旋转位置嵌入作为查询和键的旋转位置嵌入
q_rotary_pos_emb = rotary_pos_emb
k_rotary_pos_emb = rotary_pos_emb
# 处理交叉注意力的 dropout
if self.training and self.cross_attn_dropout > 0.:
# 生成随机数用于 dropout
rand = torch.zeros((batch, context_len), device = device).uniform_()
keep_context_len = context_len - int(context_len * self.cross_attn_dropout)
keep_indices = rand.topk(keep_context_len, dim = -1).indices
keep_mask = torch.zeros_like(rand).scatter_(1, keep_indices, 1).bool()
# 根据掩码保留一部分上下文信息
context = rearrange(context[keep_mask], '(b n) d -> b n d', b = batch)
if exists(context_mask):
context_mask = rearrange(context_mask[keep_mask], '(b n) -> b n', b = batch)
# 对键的旋转位置嵌入进行操作
k_rotary_pos_emb = repeat(k_rotary_pos_emb, '... -> b ...', b = batch)
k_rotary_pos_emb_context, k_rotary_pos_emb_seq = k_rotary_pos_emb[:, :context_len], k_rotary_pos_emb[:, context_len:]
k_rotary_pos_emb_context = rearrange(k_rotary_pos_emb_context[keep_mask], '(b n) d -> b n d', b = batch)
k_rotary_pos_emb = torch.cat((k_rotary_pos_emb_context, k_rotary_pos_emb_seq), dim = 1)
k_rotary_pos_emb = rearrange(k_rotary_pos_emb, 'b n d -> b 1 n d')
# 归一化处理
x = self.norm(x)
context = self.context_norm(context)
# 获取查询、键、值
q = self.to_q(x)
k_input, v_input = self.to_kv(x).chunk(2, dim = -1)
k_context, v_context = self.to_kv(context).chunk(2, dim = -1)
k = torch.cat((k_context, k_input), dim = 1)
v = torch.cat((v_context, v_input), dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
q = q * self.scale
# 使用旋转位置嵌入旋转查询和键
if exists(rotary_pos_emb):
q = apply_rotary_pos_emb(q_rotary_pos_emb, q)
k = apply_rotary_pos_emb(k_rotary_pos_emb, k)
# 处理掩码
i, j = q.shape[-2], k.shape[-2]
mask_value = -torch.finfo(q.dtype).max
if exists(context_mask):
mask_len = context_mask.shape[-1]
context_mask = F.pad(context_mask, (0, max(j - mask_len, 0)), value = True)
context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
causal_mask = torch.ones((i, j), device = x.device, dtype = torch.bool).triu(j - i + 1)
# process the attention heads in chunks to reduce peak memory
out = []
max_heads = self.max_heads_process
for q_chunk, k_chunk, v_chunk in zip(q.split(max_heads, dim = 1), k.split(max_heads, dim = 1), v.split(max_heads, dim = 1)):
sim = einsum('b h i d, b h j d -> b h i j', q_chunk, k_chunk)
if exists(context_mask):
sim = sim.masked_fill(~context_mask, mask_value)
sim = sim.masked_fill(causal_mask, mask_value)
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
out_chunk = einsum('b h i j, b h j d -> b h i d', attn, v_chunk)
out.append(out_chunk)
# 拼接所有头部
out = torch.cat(out, dim = 1)
# 合并头部并与线性层结合
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class PerceiverAR(nn.Module):
# 定义 PerceiverAR 类,继承自 nn.Module
def __init__(
self,
*,
num_tokens,
dim,
depth,
max_seq_len,
cross_attn_seq_len,
dim_head = 64,
heads = 8,
dropout = 0.,
cross_attn_dropout = 0.,
ff_mult = 4,
perceive_depth = 1,
perceive_max_heads_process = 2 # processes the heads in the perceiver layer in chunks to lower peak memory, in the case the prefix is really long
):
# 初始化函数,接受多个参数
super().__init__()
# 调用父类的初始化函数
assert max_seq_len > cross_attn_seq_len, 'max_seq_len must be greater than cross_attn_seq_len, the length of the sequence for which to cross attend to "perceiver" style'
# 断言,确保 max_seq_len 大于 cross_attn_seq_len
self.max_seq_len = max_seq_len
self.cross_attn_seq_len = cross_attn_seq_len
self.token_emb = nn.Embedding(num_tokens, dim)
# 创建 token embedding 层
self.pos_emb = nn.Embedding(max_seq_len, dim)
# 创建位置 embedding 层
self.rotary_pos_emb = RotaryEmbedding(dim = max(32, dim_head // 2))
# 创建旋转位置 embedding 层
self.perceive_layers = nn.ModuleList([])
# 创建感知层的 ModuleList
for _ in range(perceive_depth):
# 循环感知深度次数
self.perceive_layers.append(nn.ModuleList([
CausalPrefixAttention(dim = dim, dim_head = dim_head, heads = heads, max_heads_process = perceive_max_heads_process, dropout = dropout, cross_attn_dropout = cross_attn_dropout),
FeedForward(dim, mult = ff_mult, dropout = dropout)
]))
# 将 CausalPrefixAttention 和 FeedForward 添加到感知层中
self.layers = nn.ModuleList([])
# 创建层的 ModuleList
for _ in range(depth):
# 循环深度次数
self.layers.append(nn.ModuleList([
CausalAttention(dim = dim, dim_head = dim_head, heads = heads),
FeedForward(dim, mult = ff_mult, dropout = dropout),
]))
# 将 CausalAttention 和 FeedForward 添加到层中
self.to_logits = nn.Linear(dim, num_tokens, bias = False)
# 创建线性层,用于输出 logits
def forward(
self,
x,
prefix_mask = None,
labels = None
):
# 前向传播函数,接受输入 x,前缀掩码和标签
seq_len, device = x.shape[1], x.device
# 获取序列长度和设备信息
assert self.cross_attn_seq_len < seq_len <= self.max_seq_len
# 断言,确保交叉注意力序列长度小于序列长度且小于等于最大序列长度
x = self.token_emb(x)
# 对输入进行 token embedding
x = x + self.pos_emb(torch.arange(seq_len, device = device))
# 添加位置 embedding
# rotary positional embedding
rotary_pos_emb = self.rotary_pos_emb(seq_len, device = device)
# 获取旋转位置 embedding
# divide into prefix to cross attend to and sequence to self attend to
prefix, x = x[:, :self.cross_attn_seq_len], x[:, self.cross_attn_seq_len:]
# 将输入分为前缀和序列部分
# initial perceiver attention and feedforward (one cross attention)
for cross_attn, ff in self.perceive_layers:
# 遍历感知层
x = cross_attn(x, prefix, context_mask = prefix_mask, rotary_pos_emb = rotary_pos_emb) + x
# 进行交叉注意力操作
x = ff(x) + x
# 进行前馈操作
# layers
for attn, ff in self.layers:
# 遍历层
x = attn(x, rotary_pos_emb = rotary_pos_emb) + x
# 进行自注意力操作
x = ff(x) + x
# 进行前馈操作
# to logits
logits = self.to_logits(x)
# 计算 logits
# take care of cross entropy loss if labels are provided
if not exists(labels):
return logits
# 如果提供了标签,则处理交叉熵损失
labels = labels[:, self.cross_attn_seq_len:]
# 获取标签的序列部分
return F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels, ignore_index = 0)
# 计算交叉熵损失
.\lucidrains\perceiver-ar-pytorch\perceiver_ar_pytorch\__init__.py
# import the PerceiverAR class from the perceiver_ar_pytorch.perceiver_ar_pytorch module
from perceiver_ar_pytorch.perceiver_ar_pytorch import PerceiverAR

Perceiver AR - Pytorch
Implementation of Perceiver AR, Deepmind's new long-context attention network based on Perceiver architecture, in Pytorch.
I am building this out of popular demand, not because I believe in the architecture. As someone else puts it succinctly, this is equivalent to an encoder / decoder transformer architecture where the encoder has 0 layers (and the decoder cross attention is restricted to 1 layer)
However, the experimental results they provided are still worthwhile and I'll build it out so students and researchers alike can explore along this avenue.
Update: seems to be performing decently well on enwik8 with 4096 context length. maybe I was wrong to be pessimistic
Install
$ pip install perceiver-ar-pytorch
Usage
import torch
from perceiver_ar_pytorch import PerceiverAR
model = PerceiverAR(
num_tokens = 20000, # number of tokens
dim = 512, # model dimensions
depth = 8, # model depth
dim_head = 64, # attention head dimension
heads = 8, # attention heads
max_seq_len = 4096, # total max sequence length
cross_attn_seq_len = 3072, # the sequence length in which to attend to, but does not undergo self attention (must be less than max_seq_len)
cross_attn_dropout = 0.5, # what percentage of the prefix to dropout during training, in paper they had extensive experimentation to show up to 50% dropout helped prevent overfitting
)
x = torch.randint(0, 20000, (1, 4096))
logits = model(x) # (1, 1024, 20000) - (4096 [seq len] - 3072 [perceived prefix] == 1024)
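When labels are passed to the forward call, the model instead returns a cross entropy loss computed only over the self-attended suffix (no predictions are made for the perceived prefix). A hypothetical training step continuing the example above:

loss = model(x[:, :-1], labels = x[:, 1:])   # loss over the 1023 positions after the 3072-token prefix
loss.backward()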
Test
Enwik8 at 4096
$ python train.py
Citations
@article{Hawthorne2022GeneralpurposeLA,
title = {General-purpose, long-context autoregressive modeling with Perceiver AR},
author = {Curtis Hawthorne and Andrew Jaegle and Cătălina Cangea and Sebastian Borgeaud and Charlie Nash and Mateusz Malinowski and Sander Dieleman and Oriol Vinyals and Matthew M. Botvinick and Ian Simon and Hannah R. Sheahan and Neil Zeghidour and Jean-Baptiste Alayrac and Jo{\~a}o Carreira and Jesse Engel},
journal = {ArXiv},
year = {2022},
volume = {abs/2202.07765}
}
.\lucidrains\perceiver-ar-pytorch\setup.py
# import the setup and find_packages functions from setuptools
from setuptools import setup, find_packages
# package metadata
setup(
# 包的名称
name = 'perceiver-ar-pytorch',
# 查找所有包,不排除任何包
packages = find_packages(exclude=[]),
# 版本号
version = '0.0.10',
# 许可证类型
license='MIT',
# 描述
description = 'Perceiver AR',
# 作者
author = 'Phil Wang',
# 作者邮箱
author_email = 'lucidrains@gmail.com',
# 长描述内容类型
long_description_content_type = 'text/markdown',
# 项目链接
url = 'https://github.com/lucidrains/perceiver-ar-pytorch',
# 关键词列表
keywords = [
'artificial intelligence',
'deep learning',
'transformers',
'long context',
'attention'
],
# 安装依赖
install_requires=[
'einops>=0.4',
'torch>=1.6',
],
# 分类标签
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\perceiver-ar-pytorch\train.py
# 导入所需的库
import gzip
import random
import numpy as np
import torch
import torch.optim as optim
import tqdm
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# 导入自定义的模型和包装器
from perceiver_ar_pytorch import PerceiverAR
from perceiver_ar_pytorch.autoregressive_wrapper import AutoregressiveWrapper
# 定义常量
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 4096
PREFIX_SEQ_LEN = 3584
# 定义循环函数
def cycle(loader):
while True:
for data in loader:
yield data
# 解码单个 token
def decode_token(token):
return str(chr(max(32, token)))
# 解码一组 tokens
def decode_tokens(tokens):
return "".join(list(map(decode_token, tokens)))
# 创建 PerceiverAR 模型
model = PerceiverAR(
num_tokens = 256,
dim = 512,
depth = 8,
heads = 8,
dim_head = 64,
cross_attn_dropout = 0.5,
max_seq_len = SEQ_LEN,
cross_attn_seq_len = PREFIX_SEQ_LEN
)
# 使用 AutoregressiveWrapper 包装模型
model = AutoregressiveWrapper(model)
# 将模型移动到 GPU
model.cuda()
# 准备 enwik8 数据
with gzip.open("./data/enwik8.gz") as file:
X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
trX, vaX = np.split(X, [int(90e6)])
data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
# 定义自定义数据集类
class TextSamplerDataset(Dataset):
def __init__(self, data, seq_len):
super().__init__()
self.data = data
self.seq_len = seq_len
def __getitem__(self, index):
rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
return full_seq.cuda()
def __len__(self):
return self.data.size(0) // self.seq_len
# 创建训练集和验证集的数据加载器
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))
# 定义优化器
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# 训练模型
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
model.train()
for __ in range(GRADIENT_ACCUMULATE_EVERY):
loss = model(next(train_loader))
loss.backward()
print(f"training loss: {loss.item()}")
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optim.step()
optim.zero_grad()
if i % VALIDATE_EVERY == 0:
model.eval()
with torch.no_grad():
loss = model(next(val_loader))
print(f"validation loss: {loss.item()}")
if i % GENERATE_EVERY == 0:
model.eval()
inp = random.choice(val_dataset)[:-1]
prime = decode_tokens(inp)
print(f"%s \n\n %s", (prime, "*" * 100))
sample = model.generate(inp[None, ...], GENERATE_LENGTH)
output_str = decode_tokens(sample[0])
print(output_str)
.\lucidrains\perceiver-pytorch\perceiver_pytorch\experimental.py
# import the torch library
import torch
# import the nn module and einsum function from torch
from torch import nn, einsum
# import torch.nn.functional under the alias F
import torch.nn.functional as F
# import the rearrange and repeat functions from einops
from einops import rearrange, repeat
# import exists, default, cache_fn, fourier_encode, PreNorm, FeedForward and Attention from perceiver_pytorch.perceiver_pytorch
from perceiver_pytorch.perceiver_pytorch import exists, default, cache_fn, fourier_encode, PreNorm, FeedForward, Attention
# linear attention class
class LinearAttention(nn.Module):
def __init__(
self,
dim,
*,
heads = 4,
dim_head = 64,
dropout = 0.
):
super().__init__()
inner_dim = heads * dim_head
self.heads = heads
self.scale = dim_head ** -0.5
# 定义线性变换层,将输入维度转换为内部维度的三倍
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
# 定义输出层,包含线性变换和 dropout 操作
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
# 前向传播函数
def forward(self, x, mask = None):
h = self.heads
# 将输入 x 经过线性变换层得到查询、键、值
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
# 重排查询、键、值的维度
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
# 缩放查询
q = q * self.scale
# 对查询和键进行 softmax 操作
q, k = q.softmax(dim = -1), k.softmax(dim = -2)
# 如果存在 mask,则对键进行填充
if exists(mask):
k.masked_fill_(mask, 0.)
# 计算上下文信息
context = einsum('b n d, b n e -> b d e', q, k)
# 计算输出
out = einsum('b d e, b n d -> b n e', context, v)
# 重排输出的维度
out = rearrange(out, ' (b h) n d -> b n (h d)', h = h)
return self.to_out(out)
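For intuition, the linear attention above never materializes an n x n attention map; it contracts queries and keys into a small (dim_head x dim_head) context matrix first, so cost grows linearly with sequence length. A standalone shape sketch (illustrative only):

import torch

q = torch.randn(8, 1024, 64).softmax(dim = -1)            # (batch * heads, seq, dim_head)
k = torch.randn(8, 1024, 64).softmax(dim = -2)
v = torch.randn(8, 1024, 64)

context = torch.einsum('b n d, b n e -> b d e', q, k)     # (8, 64, 64), independent of sequence length
out = torch.einsum('b d e, b n d -> b n e', context, v)   # (8, 1024, 64)
print(out.shape)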
# 主类 Perceiver
class Perceiver(nn.Module):
def __init__(
self,
*,
num_freq_bands,
depth,
max_freq,
input_channels = 3,
input_axis = 2,
num_latents = 512,
latent_dim = 512,
cross_heads = 1,
latent_heads = 8,
cross_dim_head = 64,
latent_dim_head = 64,
num_classes = 1000,
attn_dropout = 0.,
ff_dropout = 0.,
weight_tie_layers = False,
fourier_encode_data = True
):
# 调用父类的构造函数
super().__init__()
# 设置输入数据的轴数
self.input_axis = input_axis
# 设置最大频率
self.max_freq = max_freq
# 设置频率带数量
self.num_freq_bands = num_freq_bands
# 是否对数据进行傅立叶编码
self.fourier_encode_data = fourier_encode_data
# 计算输入维度
input_dim = input_channels
# 如果需要对数据进行傅立叶编码
if fourier_encode_data:
# 更新输入维度
input_dim += input_axis * ((num_freq_bands * 2) + 1) + input_channels
# 初始化潜在变量
self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
# 数据投影层
self.data_proj = nn.Linear(input_dim, input_dim)
# 定义获取交叉注意力的函数
get_cross_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, input_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout), context_dim = input_dim)
# 定义获取交叉前馈网络的函数
get_cross_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
# 定义获取输入注意力的函数
get_input_attn = lambda: PreNorm(input_dim, LinearAttention(input_dim, dropout = attn_dropout))
# 定义获取反向交叉注意力的函数
get_rev_cross_attn = lambda: PreNorm(input_dim, Attention(input_dim, latent_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout), context_dim = latent_dim)
# 定义获取反向交叉前馈网络的函数
get_rev_cross_ff = lambda: PreNorm(input_dim, FeedForward(input_dim, dropout = ff_dropout))
# 定义获取潜在注意力的函数
get_latent_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head, dropout = attn_dropout))
# 定义获取潜在前馈网络的函数
get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
# 使用缓存函数对获取函数进行缓存
get_cross_attn, get_cross_ff, get_rev_cross_attn, get_rev_cross_ff, get_input_attn, get_latent_attn, get_latent_ff = map(cache_fn, (get_cross_attn, get_cross_ff, get_rev_cross_attn, get_rev_cross_ff, get_input_attn, get_latent_attn, get_latent_ff))
# 初始化网络层
self.layers = nn.ModuleList([])
for i in range(depth):
should_cache = i > 0 and weight_tie_layers
cache_args = {'_cache': should_cache}
self.layers.append(nn.ModuleList([
get_cross_attn(**cache_args),
get_cross_ff(**cache_args),
get_rev_cross_attn(**cache_args),
get_rev_cross_ff(**cache_args),
get_input_attn(**cache_args),
get_latent_attn(**cache_args),
get_latent_ff(**cache_args)
]))
# 输出层
self.to_logits = nn.Sequential(
nn.LayerNorm(latent_dim),
nn.Linear(latent_dim, num_classes)
)
def forward(self, data, mask = None):
# 获取数据的维度信息
b, *axis, _, device = *data.shape, data.device
# 断言数据维度与输入轴数相符
assert len(axis) == self.input_axis, 'input data must have the right number of axis'
# 如果需要对数据进行傅立叶编码
if self.fourier_encode_data:
# 计算在[-1, 1]范围内的傅立叶编码位置,对所有轴
axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis))
pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1)
enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
enc_pos = repeat(enc_pos, '... -> b ...', b = b)
# 将编码位置与数据的通道连接并展平轴
data = torch.cat((data, enc_pos), dim = -1)
data = rearrange(data, 'b ... d -> b (...) d')
# 数据投影
data = self.data_proj(data)
# 重复潜在变量
x = repeat(self.latents, 'n d -> b n d', b = b)
# 遍历网络层
for i, (cross_attn, cross_ff, rev_cross_attn, rev_cross_ff, input_attn, latent_attn, latent_ff) in enumerate(self.layers):
is_last = i == (len(self.layers) - 1)
x = cross_attn(x, context = data, mask = mask) + x
x = cross_ff(x) + x
if not is_last:
data = input_attn(data, mask = mask) + data
data = rev_cross_attn(data, context = x) + data
data = rev_cross_ff(data) + data
x = latent_attn(x) + x
x = latent_ff(x) + x
# 对最后的输出进行平均处理
x = x.mean(dim = -2)
return self.to_logits(x)
.\lucidrains\perceiver-pytorch\perceiver_pytorch\gated.py
# import the torch library
import torch
# import the nn module and einsum function from torch
from torch import nn, einsum
# import torch.nn.functional under the alias F
import torch.nn.functional as F
# import the rearrange and repeat functions from einops
from einops import rearrange, repeat
# import exists, default, cache_fn, fourier_encode, PreNorm, FeedForward and Attention from perceiver_pytorch.perceiver_pytorch
from perceiver_pytorch.perceiver_pytorch import exists, default, cache_fn, fourier_encode, PreNorm, FeedForward, Attention
# helpers
# 定义 Residual 类,继承 nn.Module 类
class Residual(nn.Module):
# 初始化函数
def __init__(self, fn):
super().__init__()
self.fn = fn
# 前向传播函数
def forward(self, x, **kwargs):
return x + self.fn(x, **kwargs)
# 定义 GRUGating 类,继承 nn.Module 类
class GRUGating(nn.Module):
# 初始化函数
def __init__(self, dim, fn):
super().__init__()
self.dim = dim
self.fn = fn
self.gru = nn.GRUCell(dim, dim)
# 前向传播函数
def forward(self, x, **kwargs):
b, dim = x.shape[0], self.dim
y = self.fn(x, **kwargs)
gated_output = self.gru(
rearrange(y, '... d -> (...) d'),
rearrange(x, '... d -> (...) d')
)
gated_output = rearrange(gated_output, '(b n) d -> b n d', b = b)
return gated_output
# main class
# 定义 Perceiver 类,继承 nn.Module 类
class Perceiver(nn.Module):
# 初始化函数
def __init__(
self,
*,
num_freq_bands,
depth,
max_freq,
input_channels = 3,
input_axis = 2,
num_latents = 512,
latent_dim = 512,
cross_heads = 1,
latent_heads = 8,
cross_dim_head = 64,
latent_dim_head = 64,
num_classes = 1000,
attn_dropout = 0.,
ff_dropout = 0.,
weight_tie_layers = False
):
super().__init__()
self.input_axis = input_axis
self.max_freq = max_freq
self.num_freq_bands = num_freq_bands
input_dim = input_axis * ((num_freq_bands * 2) + 1) + input_channels
self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
get_cross_attn = lambda: GRUGating(latent_dim, PreNorm(latent_dim, Attention(latent_dim, input_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout), context_dim = input_dim))
get_latent_attn = lambda: GRUGating(latent_dim, PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head, dropout = attn_dropout)))
get_cross_ff = lambda: Residual(PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout)))
get_latent_ff = lambda: Residual(PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout)))
get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff))
self.layers = nn.ModuleList([])
for i in range(depth):
should_cache = i > 0 and weight_tie_layers
cache_args = {'_cache': should_cache}
self.layers.append(nn.ModuleList([
get_cross_attn(**cache_args),
get_cross_ff(**cache_args),
get_latent_attn(**cache_args),
get_latent_ff(**cache_args)
]))
self.to_logits = nn.Sequential(
nn.LayerNorm(latent_dim),
nn.Linear(latent_dim, num_classes)
)
# 前向传播函数,接受数据和掩码作为输入
def forward(self, data, mask = None):
# 获取数据的形状和设备信息
b, *axis, _, device = *data.shape, data.device
# 断言数据的轴数与输入轴数相同
assert len(axis) == self.input_axis, 'input data must have the right number of axis'
# 计算傅立叶编码的位置,范围为[-1, 1],对所有轴
# 生成每个轴上的位置信息
axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis))
# 生成位置的网格
pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1)
# 对位置信息进行傅立叶编码
enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
# 重新排列编码后的位置信息
enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
# 复制编码后的位置信息,使其与数据维度相匹配
enc_pos = repeat(enc_pos, '... -> b ...', b = b)
# 将编码后的位置信息连接到数据的通道上,并展平轴
data = torch.cat((data, enc_pos), dim = -1)
data = rearrange(data, 'b ... d -> b (...) d')
# 复制潜在变量,使其与数据维度相匹配
x = repeat(self.latents, 'n d -> b n d', b = b)
# 遍历每个层,进行交叉注意力、交叉前馈、潜在注意力和潜在前馈操作
for cross_attn, cross_ff, latent_attn, latent_ff in self.layers:
x = cross_attn(x, context = data, mask = mask)
x = cross_ff(x)
x = latent_attn(x)
x = latent_ff(x)
# 对最终结果进行平均处理,并返回logits
x = x.mean(dim = -2)
return self.to_logits(x)
.\lucidrains\perceiver-pytorch\perceiver_pytorch\mixed_latents.py
# import required libraries
import torch
from torch import nn, einsum
import torch.nn.functional as F
# import additional libraries
from einops import rearrange, repeat
# import shared building blocks from the main perceiver module
from perceiver_pytorch.perceiver_pytorch import exists, default, cache_fn, fourier_encode, PreNorm, FeedForward, Attention
# latent mixer function
def Mixer(seq_len, mult = 4, dropout = 0.):
return nn.Sequential(
nn.Conv1d(seq_len, seq_len * mult, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv1d(seq_len * mult, seq_len, 1)
)
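A quick shape check for the Mixer above (illustrative): it mixes information across the latent (sequence) axis with 1x1 convolutions while leaving the feature axis untouched.

import torch
from perceiver_pytorch.mixed_latents import Mixer

mixer = Mixer(seq_len = 256, mult = 4)
x = torch.randn(1, 256, 512)              # (batch, num_latents, latent_dim)
print(mixer(x).shape)                     # torch.Size([1, 256, 512])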
# 定义主要的 Perceiver 类
class Perceiver(nn.Module):
def __init__(
self,
*,
num_freq_bands,
depth,
max_freq,
input_channels = 3,
input_axis = 2,
num_latents = 512,
latent_dim = 512,
cross_heads = 1,
latent_heads = 8,
cross_dim_head = 64,
latent_dim_head = 64,
num_classes = 1000,
attn_dropout = 0.,
ff_dropout = 0.,
weight_tie_layers = False,
**kwargs
):
super().__init__()
self.input_axis = input_axis
self.max_freq = max_freq
self.num_freq_bands = num_freq_bands
# 计算输入维度
input_dim = input_axis * ((num_freq_bands * 2) + 1) + input_channels
# 初始化可学习参数
self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
# 定义获取不同类型注意力和前馈网络的函数
get_cross_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, input_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout), context_dim = input_dim)
get_latent_attn = lambda: PreNorm(latent_dim, Mixer(num_latents, dropout = ff_dropout))
get_cross_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
# 缓存函数的结果
get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff))
# 初始化层列表
self.layers = nn.ModuleList([])
for i in range(depth):
should_cache = i > 0 and weight_tie_layers
cache_args = {'_cache': should_cache}
self.layers.append(nn.ModuleList([
get_cross_attn(**cache_args),
get_cross_ff(**cache_args),
get_latent_attn(**cache_args),
get_latent_ff(**cache_args)
]))
# 定义输出层
self.to_logits = nn.Sequential(
nn.LayerNorm(latent_dim),
nn.Linear(latent_dim, num_classes)
)
def forward(self, data, mask = None):
# 获取数据的形状和设备信息
b, *axis, _, device = *data.shape, data.device
assert len(axis) == self.input_axis, 'input data must have the right number of axis'
# 计算傅立叶编码的位置信息
axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis))
pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1)
enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
enc_pos = repeat(enc_pos, '... -> b ...', b = b)
# 将位置信息拼接到数据中并展平轴
data = torch.cat((data, enc_pos), dim = -1)
data = rearrange(data, 'b ... d -> b (...) d')
# 复制 latent 参数到每个样本
x = repeat(self.latents, 'n d -> b n d', b = b)
# 循环处理每一层
for cross_attn, cross_ff, latent_attn, latent_ff in self.layers:
x = cross_attn(x, context = data, mask = mask) + x
x = cross_ff(x) + x
x = latent_attn(x) + x
x = latent_ff(x) + x
# 对最后的输出进行平均处理并返回
x = x.mean(dim = -2)
return self.to_logits(x)
.\lucidrains\perceiver-pytorch\perceiver_pytorch\perceiver_io.py
# import pi and log from the math module
# import wraps from the functools module
# import torch together with its nn, einsum and functional submodules
# import rearrange and repeat from einops
from math import pi, log
from functools import wraps
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
# 定义辅助函数
# 判断值是否存在
def exists(val):
return val is not None
# 如果值存在则返回该值,否则返回默认值
def default(val, d):
return val if exists(val) else d
# 缓存函数的结果
def cache_fn(f):
cache = None
@wraps(f)
def cached_fn(*args, _cache = True, **kwargs):
if not _cache:
return f(*args, **kwargs)
nonlocal cache
if cache is not None:
return cache
cache = f(*args, **kwargs)
return cache
return cached_fn
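An illustration (hypothetical) of how cache_fn enables weight tying: with _cache = True the layer constructor runs once and every later call returns the very same module.

from perceiver_pytorch.perceiver_io import cache_fn, FeedForward

get_ff = cache_fn(lambda: FeedForward(512))

ff_a = get_ff(_cache = False)   # always builds a fresh module, nothing is cached
ff_b = get_ff(_cache = True)    # builds and caches
ff_c = get_ff(_cache = True)    # returns the cached module
assert ff_b is ff_c and ff_a is not ff_b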
# structured dropout, more effective than traditional attention dropouts
# drop out whole positions of a sequence
def dropout_seq(seq, mask, dropout):
b, n, *_, device = *seq.shape, seq.device
logits = torch.randn(b, n, device = device)
if exists(mask):
logits = logits.masked_fill(~mask, -torch.finfo(logits.dtype).max)
keep_prob = 1. - dropout
num_keep = max(1, int(keep_prob * n))
keep_indices = logits.topk(num_keep, dim = 1).indices
batch_indices = torch.arange(b, device = device)
batch_indices = rearrange(batch_indices, 'b -> b 1')
seq = seq[batch_indices, keep_indices]
if exists(mask):
seq_counts = mask.sum(dim = -1)
seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')
mask = mask[batch_indices, keep_indices] & keep_mask
return seq, mask
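A small usage sketch of dropout_seq (illustrative shapes):

import torch
from perceiver_pytorch.perceiver_io import dropout_seq

seq  = torch.randn(2, 10, 512)
mask = torch.ones(2, 10, dtype = torch.bool)

seq, mask = dropout_seq(seq, mask, 0.5)   # keeps roughly half of the positions per sample
print(seq.shape, mask.shape)              # torch.Size([2, 5, 512]) torch.Size([2, 5])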
# 辅助类
# 预层归一化
class PreNorm(nn.Module):
def __init__(self, dim, fn, context_dim = None):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None
def forward(self, x, **kwargs):
x = self.norm(x)
if exists(self.norm_context):
context = kwargs['context']
normed_context = self.norm_context(context)
kwargs.update(context = normed_context)
return self.fn(x, **kwargs)
# GEGLU 激活函数
class GEGLU(nn.Module):
def forward(self, x):
x, gates = x.chunk(2, dim = -1)
return x * F.gelu(gates)
# 前馈神经网络
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim * mult * 2),
GEGLU(),
nn.Linear(dim * mult, dim)
)
def forward(self, x):
return self.net(x)
# 注意力机制
class Attention(nn.Module):
def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, query_dim)
def forward(self, x, context = None, mask = None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k, v = self.to_kv(context).chunk(2, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h = h)
sim.masked_fill_(~mask, max_neg_value)
# 注意力机制,我们无法获得足够的
attn = sim.softmax(dim = -1)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
return self.to_out(out)
# 主类
class PerceiverIO(nn.Module):
# 初始化函数,设置模型参数
def __init__(
self,
*,
depth,
dim,
queries_dim,
logits_dim = None,
num_latents = 512,
latent_dim = 512,
cross_heads = 1,
latent_heads = 8,
cross_dim_head = 64,
latent_dim_head = 64,
weight_tie_layers = False,
decoder_ff = False,
seq_dropout_prob = 0.
):
# 调用父类初始化函数
super().__init__()
# 设置序列的dropout概率
self.seq_dropout_prob = seq_dropout_prob
# 初始化模型中的可学习参数
self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
# 创建交叉注意力块和前馈网络块
self.cross_attend_blocks = nn.ModuleList([
PreNorm(latent_dim, Attention(latent_dim, dim, heads = cross_heads, dim_head = cross_dim_head), context_dim = dim),
PreNorm(latent_dim, FeedForward(latent_dim))
])
# 定义获取潜在注意力和前馈网络的函数
get_latent_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head))
get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim))
# 使用缓存函数对获取潜在注意力和前馈网络的函数进行缓存
get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))
# 初始化模型的层
self.layers = nn.ModuleList([])
cache_args = {'_cache': weight_tie_layers}
# 循环创建多个层
for i in range(depth):
self.layers.append(nn.ModuleList([
get_latent_attn(**cache_args),
get_latent_ff(**cache_args)
]))
# 创建解码器的交叉注意力块和前馈网络块
self.decoder_cross_attn = PreNorm(queries_dim, Attention(queries_dim, latent_dim, heads = cross_heads, dim_head = cross_dim_head), context_dim = latent_dim)
self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None
# 创建输出层
self.to_logits = nn.Linear(queries_dim, logits_dim) if exists(logits_dim) else nn.Identity()
# 前向传播函数
def forward(
self,
data,
mask = None,
queries = None
):
# 获取数据的维度和设备信息
b, *_, device = *data.shape, data.device
# 将潜在向量重复扩展到与数据相同的维度
x = repeat(self.latents, 'n d -> b n d', b = b)
# get the cross attention and feedforward blocks
cross_attn, cross_ff = self.cross_attend_blocks
# 结构化的dropout操作
if self.training and self.seq_dropout_prob > 0.:
data, mask = dropout_seq(data, mask, self.seq_dropout_prob)
# 执行交叉注意力操作
x = cross_attn(x, context = data, mask = mask) + x
x = cross_ff(x) + x
# 多层自注意力和前馈网络操作
for self_attn, self_ff in self.layers:
x = self_attn(x) + x
x = self_ff(x) + x
# 如果没有查询数据,则直接返回结果
if not exists(queries):
return x
# 确保查询数据包含批处理维度
if queries.ndim == 2:
queries = repeat(queries, 'n d -> b n d', b = b)
# 从解码器查询到潜在向量的交叉注意力操作
latents = self.decoder_cross_attn(queries, context = x)
# 可选的解码器前馈网络操作
if exists(self.decoder_ff):
latents = latents + self.decoder_ff(latents)
# 最终的线性输出
return self.to_logits(latents)
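A minimal usage sketch for PerceiverIO (dimensions are illustrative): encode a long input array into a fixed set of latents, then decode with an arbitrary number of queries.

import torch
from perceiver_pytorch.perceiver_io import PerceiverIO

model = PerceiverIO(
    dim = 32,             # dimension of the input sequence
    queries_dim = 32,     # dimension of the decoder queries
    logits_dim = 100,     # dimension of the final logits
    depth = 6,
    num_latents = 256,
    latent_dim = 512
)

seq     = torch.randn(1, 16384, 32)
queries = torch.randn(128, 32)

logits = model(seq, queries = queries)   # (1, 128, 100)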
# Perceiver LM 示例
class PerceiverLM(nn.Module):
def __init__(
self,
*,
dim, # 定义维度
num_tokens, # 定义标记数量
max_seq_len, # 定义最大序列长度
**kwargs # 其他参数
):
super().__init__()
self.token_emb = nn.Embedding(num_tokens, dim) # 创建标记嵌入层
self.pos_emb = nn.Embedding(max_seq_len, dim) # 创建位置嵌入层
self.perceiver_io = PerceiverIO( # 创建 PerceiverIO 模块
dim = dim,
queries_dim = dim,
logits_dim = num_tokens,
**kwargs
)
def forward(
self,
x, # 输入张量
mask = None # 掩码,默认为空
):
n, device = x.shape[1], x.device # 获取输入张量的维度和设备信息
x = self.token_emb(x) # 对输入张量进行标记嵌入
pos_emb = self.pos_emb(torch.arange(n, device = device)) # 根据序列长度创建位置嵌入
pos_emb = rearrange(pos_emb, 'n d -> () n d') # 重新排列位置嵌入的维度
x = x + pos_emb # 将标记嵌入和位置嵌入相加
logits = self.perceiver_io(x, mask = mask, queries = x) # 使用 PerceiverIO 模块进行前向传播
return logits # 返回输出结果
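A usage sketch for PerceiverLM (illustrative hyperparameters):

import torch
from perceiver_pytorch.perceiver_io import PerceiverLM

model = PerceiverLM(
    num_tokens = 20000,
    dim = 32,
    depth = 6,
    max_seq_len = 2048,
    num_latents = 256,
    latent_dim = 512
)

seq  = torch.randint(0, 20000, (1, 512))
mask = torch.ones(1, 512).bool()

logits = model(seq, mask = mask)   # (1, 512, 20000)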
.\lucidrains\perceiver-pytorch\perceiver_pytorch\perceiver_pytorch.py
# import pi and log from the math module
# import the wraps decorator from functools
# import torch and related modules
# import nn and einsum from torch
# import torch.nn.functional under the alias F
# import the rearrange and repeat functions from einops
# import the Reduce layer from einops.layers.torch
from math import pi, log
from functools import wraps
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Reduce
# 定义一些辅助函数
# 判断变量是否存在的函数
def exists(val):
return val is not None
# 如果变量存在则返回该变量,否则返回默认值的函数
def default(val, d):
return val if exists(val) else d
# 缓存函数结果的装饰器
def cache_fn(f):
cache = dict()
@wraps(f)
def cached_fn(*args, _cache = True, key = None, **kwargs):
if not _cache:
return f(*args, **kwargs)
nonlocal cache
if key in cache:
return cache[key]
result = f(*args, **kwargs)
cache[key] = result
return result
return cached_fn
# 对输入进行傅立叶编码的函数
def fourier_encode(x, max_freq, num_bands = 4):
x = x.unsqueeze(-1)
device, dtype, orig_x = x.device, x.dtype, x
scales = torch.linspace(1., max_freq / 2, num_bands, device = device, dtype = dtype)
scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
x = x * scales * pi
x = torch.cat([x.sin(), x.cos()], dim = -1)
x = torch.cat((x, orig_x), dim = -1)
return x
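A quick shape check for fourier_encode (illustrative): each scalar position becomes 2 * num_bands sine/cosine features plus the original value.

import torch
from perceiver_pytorch.perceiver_pytorch import fourier_encode

pos = torch.linspace(-1., 1., steps = 4)                 # positions along one axis
enc = fourier_encode(pos, max_freq = 10., num_bands = 6)
print(enc.shape)                                         # torch.Size([4, 13]) == (4, 2 * 6 + 1)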
# 定义一些辅助类
# 实现预层归一化的类
class PreNorm(nn.Module):
def __init__(self, dim, fn, context_dim = None):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None
def forward(self, x, **kwargs):
x = self.norm(x)
if exists(self.norm_context):
context = kwargs['context']
normed_context = self.norm_context(context)
kwargs.update(context = normed_context)
return self.fn(x, **kwargs)
# 实现GEGLU激活函数的类
class GEGLU(nn.Module):
def forward(self, x):
x, gates = x.chunk(2, dim = -1)
return x * F.gelu(gates)
# 实现前馈神经网络的类
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim * mult * 2),
GEGLU(),
nn.Linear(dim * mult, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
# 实现注意力机制的类
class Attention(nn.Module):
def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
self.dropout = nn.Dropout(dropout)
self.to_out = nn.Linear(inner_dim, query_dim)
def forward(self, x, context = None, mask = None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k, v = self.to_kv(context).chunk(2, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h = h)
sim.masked_fill_(~mask, max_neg_value)
# 注意力机制,获取重要信息
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
return self.to_out(out)
# 主类
class Perceiver(nn.Module):
# 初始化函数,设置Transformer模型的参数
def __init__(
self,
*,
num_freq_bands, # 频率带数量
depth, # Transformer的深度
max_freq, # 最大频率
input_channels = 3, # 输入通道数,默认为3
input_axis = 2, # 输入轴,默认为2
num_latents = 512, # 潜在变量数量,默认为512
latent_dim = 512, # 潜在维度,默认为512
cross_heads = 1, # 交叉头数,默认为1
latent_heads = 8, # 潜在头数,默认为8
cross_dim_head = 64, # 交叉维度头数,默认为64
latent_dim_head = 64, # 潜在维度头数,默认为64
num_classes = 1000, # 类别数量,默认为1000
attn_dropout = 0., # 注意力机制的dropout,默认为0
ff_dropout = 0., # 前馈网络的dropout,默认为0
weight_tie_layers = False, # 是否权重绑定层,默认为False
fourier_encode_data = True, # 是否对数据进行傅立叶编码,默认为True
self_per_cross_attn = 1, # 自注意力与交叉注意力的比例,默认为1
final_classifier_head = True # whether to mean pool and project embeddings to num_classes at the end
):
"""The shape of the final attention mechanism will be:
depth * (cross attention -> self_per_cross_attn * self attention)
Args:
num_freq_bands: Number of freq bands, with original value (2 * K + 1)
depth: Depth of net.
max_freq: Maximum frequency, hyperparameter depending on how
fine the data is.
freq_base: Base for the frequency
input_channels: Number of channels for each token of the input.
input_axis: Number of axes for input data (2 for images, 3 for video)
num_latents: Number of latents, or induced set points, or centroids.
Different papers giving it different names.
latent_dim: Latent dimension.
cross_heads: Number of heads for cross attention. Paper said 1.
latent_heads: Number of heads for latent self attention, 8.
cross_dim_head: Number of dimensions per cross attention head.
latent_dim_head: Number of dimensions per latent self attention head.
num_classes: Output number of classes.
attn_dropout: Attention dropout
ff_dropout: Feedforward dropout
weight_tie_layers: Whether to weight tie layers (optional).
fourier_encode_data: Whether to auto-fourier encode the data, using
the input_axis given. defaults to True, but can be turned off
if you are fourier encoding the data yourself.
self_per_cross_attn: Number of self attention blocks per cross attn.
final_classifier_head: mean pool and project embeddings to number of classes (num_classes) at the end
"""
super().__init__()
self.input_axis = input_axis
self.max_freq = max_freq
self.num_freq_bands = num_freq_bands
self.fourier_encode_data = fourier_encode_data
fourier_channels = (input_axis * ((num_freq_bands * 2) + 1)) if fourier_encode_data else 0
input_dim = fourier_channels + input_channels
self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
get_cross_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, input_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout), context_dim = input_dim)
get_cross_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
get_latent_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head, dropout = attn_dropout))
get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff))
self.layers = nn.ModuleList([])
for i in range(depth):
should_cache = i > 0 and weight_tie_layers
cache_args = {'_cache': should_cache}
self_attns = nn.ModuleList([])
for block_ind in range(self_per_cross_attn):
self_attns.append(nn.ModuleList([
get_latent_attn(**cache_args, key = block_ind),
get_latent_ff(**cache_args, key = block_ind)
]))
self.layers.append(nn.ModuleList([
get_cross_attn(**cache_args),
get_cross_ff(**cache_args),
self_attns
]))
self.to_logits = nn.Sequential(
Reduce('b n d -> b d', 'mean'),
nn.LayerNorm(latent_dim),
nn.Linear(latent_dim, num_classes)
) if final_classifier_head else nn.Identity()
def forward(
self,
data,
mask = None,
return_embeddings = False
):
# 解构 data 的 shape,获取除了最后两个元素外的所有元素,分别赋值给 b 和 axis
b, *axis, _, device, dtype = *data.shape, data.device, data.dtype
# 断言 axis 的长度等于 self.input_axis,确保输入数据具有正确数量的轴
assert len(axis) == self.input_axis, 'input data must have the right number of axis'
if self.fourier_encode_data:
# 如果需要对数据进行傅立叶编码
# 计算每个轴上范围为[-1, 1]的傅立叶编码位置
# 为每个轴生成均匀分布的位置
axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps=size, device=device, dtype=dtype), axis))
# 将每个轴的位置组合成多维网格
pos = torch.stack(torch.meshgrid(*axis_pos, indexing='ij'), dim=-1)
# 对位置进行傅立叶编码
enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
# 重新排列编码后的位置
enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
# 将编码后的位置重复 b 次
enc_pos = repeat(enc_pos, '... -> b ...', b=b)
# 将编码后的位置拼接到数据的通道中
data = torch.cat((data, enc_pos), dim=-1)
# 将数据拼接到通道并展平轴
data = rearrange(data, 'b ... d -> b (...) d')
# 将 latents 重复 b 次
x = repeat(self.latents, 'n d -> b n d', b=b)
# 循环处理每一层
for cross_attn, cross_ff, self_attns in self.layers:
# 跨通道注意力和前馈网络
x = cross_attn(x, context=data, mask=mask) + x
x = cross_ff(x) + x
# 处理每个自注意力和前馈网络
for self_attn, self_ff in self_attns:
x = self_attn(x) + x
x = self_ff(x) + x
# 如果需要返回嵌入向量
if return_embeddings:
return x
# 转换为 logits
return self.to_logits(x)
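Finally, a usage sketch for the Perceiver classifier above (illustrative hyperparameters; channels-last image input with one token per pixel):

import torch
from perceiver_pytorch.perceiver_pytorch import Perceiver

model = Perceiver(
    input_channels = 3,      # channels per input token (RGB pixels here)
    input_axis = 2,          # 2 for images, 3 for video
    num_freq_bands = 6,
    max_freq = 10.,
    depth = 6,
    num_latents = 256,
    latent_dim = 512,
    num_classes = 1000
)

img = torch.randn(1, 224, 224, 3)   # (batch, height, width, channels)
logits = model(img)                 # (1, 1000)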