Lucidrains 系列项目源码解析(三十五)
.\lucidrains\feedback-transformer-pytorch\feedback_transformer_pytorch\__init__.py
# 从 feedback_transformer_pytorch 模块中导入 FeedbackTransformer 类
from feedback_transformer_pytorch.feedback_transformer_pytorch import FeedbackTransformer
Feedback Transformer - Pytorch
Simple implementation of Feedback Transformer in Pytorch. They improve on Transformer-XL by having each token have access to the representations of all previous layers through time. This is achieved by aggregating the outputs of all layers into a shared memory, which each token across layers can attend to at each time step.
The main drawback is longer training time, due to its non-parallel nature. But I thought I'd build it to encourage further exploration and research into this line of work.
I also took the liberty of adding various enhancements, including pre-normalization, GLU-gated feedforwards, and simplified T5 relative positional embeddings.
Install
$ pip install feedback-transformer-pytorch
Usage
import torch
from feedback_transformer_pytorch import FeedbackTransformer
model = FeedbackTransformer(
num_tokens = 20000, # number of tokens
dim = 512, # dimension
depth = 6, # depth
seq_len = 2, # the sequence length of each segment or window
mem_len = 256, # length of the memory buffer
dim_head = 64, # dimension of each head
heads = 8, # number of heads
attn_dropout = 0.1, # attention dropout
ff_dropout = 0.1 # feedforward dropout
).cuda()
x = torch.randint(0, 20000, (2, 64)).cuda()
model(x) # (2, 64, 20000)
If you would like to have fine control over the memory (when to detach, etc), you can do it with some extra keyword arguments on .forward
import torch
from feedback_transformer_pytorch import FeedbackTransformer
model = FeedbackTransformer(
num_tokens = 20000,
dim = 512,
depth = 6,
seq_len = 32,
mem_len = 256
).cuda()
x1 = torch.randint(0, 20000, (2, 32)).cuda()
x2 = torch.randint(0, 20000, (2, 32)).cuda()
x3 = torch.randint(0, 20000, (2, 32)).cuda()
out1, mem1 = model(x1, return_memory = True)
out2, mem2 = model(x2, memory = mem1, return_memory = True)
out3, mem3 = model(x3, memory = mem2, return_memory = True) # (2, 32, 20000)
Citations
@misc{fan2021addressing,
title = {Addressing Some Limitations of Transformers with Feedback Memory},
author = {Angela Fan and Thibaut Lavril and Edouard Grave and Armand Joulin and Sainbayar Sukhbaatar},
year = {2021},
eprint = {2002.09402},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
.\lucidrains\feedback-transformer-pytorch\setup.py
# 导入设置和查找包的函数
from setuptools import setup, find_packages
# Distribution metadata for feedback-transformer-pytorch, consumed by setuptools.
_KEYWORDS = [
    'attention',
    'artificial intelligence',
    'transformer',
    'deep learning',
    'memory',
]

_INSTALL_REQUIRES = [
    'torch>=1.6',
    'einops',
]

_CLASSIFIERS = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
]

setup(
    name='feedback-transformer-pytorch',
    packages=find_packages(),
    version='0.0.11',
    license='MIT',
    description='Implementation of Feedback Transformer in Pytorch',
    author='Phil Wang',
    author_email='lucidrains@gmail.com',
    url='https://github.com/lucidrains/feedback-transformer-pytorch',
    keywords=_KEYWORDS,
    install_requires=_INSTALL_REQUIRES,
    classifiers=_CLASSIFIERS,
)
.\lucidrains\flamingo-pytorch\flamingo_pytorch\flamingo_palm.py
# 导入所需的库
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum, nn
# 导入自定义模块
from flamingo_pytorch.flamingo_pytorch import GatedCrossAttentionBlock, PerceiverResampler
# Helper functions

def exists(val):
    """Return True when *val* is anything other than None (falsy values count as existing)."""
    return not (val is None)
# 控制在训练过程中冻结 flamingo 模型的函数
# 设置模块参数是否需要梯度
def set_module_requires_grad_(module, requires_grad):
for param in module.parameters():
param.requires_grad = requires_grad
# 冻结所有层
def freeze_all_layers_(module):
set_module_requires_grad_(module, False)
# 解冻所有层
def unfreeze_all_layers_(module):
set_module_requires_grad_(module, True)
# 冻结模型并设置为评估模式
def freeze_model_and_make_eval_(model):
model.eval()
freeze_all_layers_(model)
# Normalization
# PyTorch's built-in LayerNorm offers no bias-free variant here, so beta is kept
# as a fixed all-zeros buffer rather than a learnable parameter.

class LayerNorm(nn.Module):
    """Layer normalization over the last dim with a learnable scale and no learnable shift."""

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        normalized_shape = x.shape[-1:]
        return F.layer_norm(x, normalized_shape, weight=self.gamma, bias=self.beta)
# Residual connection

class Residual(nn.Module):
    """Wrap *fn* so that the block returns fn(x) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        out = self.fn(x)
        return out + x
# Rotary positional embedding
# https://arxiv.org/abs/2104.09864

class RotaryEmbedding(nn.Module):
    """Produce rotary angles: position * inverse frequency, duplicated across both halves."""

    def __init__(self, dim):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (10000 ** exponents)
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        positions = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        # outer product: one row of angles per position
        freqs = einsum("i , j -> i j", positions, self.inv_freq)
        return torch.cat([freqs, freqs], dim=-1)
# Rotary helpers

def rotate_half(x):
    """Split the last dim into halves (first, second) and return cat(-second, first)."""
    halves = rearrange(x, "... (j d) -> ... j d", j=2)
    first, second = halves.unbind(dim=-2)
    return torch.cat((-second, first), dim=-1)


def apply_rotary_pos_emb(pos, t):
    """Rotate *t* by the angles in *pos*: t * cos(pos) + rotate_half(t) * sin(pos)."""
    cos_part = t * pos.cos()
    sin_part = rotate_half(t) * pos.sin()
    return cos_part + sin_part
# SwiGLU feedforward gating (Shazeer, https://arxiv.org/abs/2002.05202), used here
# instead of the more common GEGLU.

class SwiGLU(nn.Module):
    """Split the last dim into (value, gate) and return value * silu(gate)."""

    def forward(self, x):
        value, gate = x.chunk(2, dim=-1)
        return value * F.silu(gate)
# Parallel attention + feedforward block.
# Attention and feedforward both read the same pre-normalized input and their
# outputs are summed (parallel formulation found by Wang et al. / EleutherAI's GPT-J).

class ParallelTransformerBlock(nn.Module):
    """Pre-norm causal multi-query attention and a SwiGLU feedforward, computed in parallel.

    One bias-free linear layer produces, in order: queries for all heads, a single
    key and a single value shared by all heads (multi-query attention), and the
    SwiGLU feedforward input. ``forward`` returns attn_out + ff_out only; the
    residual connection is not added here (FlamingoPaLM wraps this block in Residual).
    """

    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.norm = LayerNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        # Widths of the fused projection's output slices: queries (all heads), the
        # single key, the single value, and the feedforward input (doubled because
        # SwiGLU splits it into value and gate halves).
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        self.heads = heads
        self.scale = dim_head**-0.5
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)

        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        # Non-persistent caches for the causal mask and the rotary angles; filled
        # lazily by get_mask / get_rotary_embedding and excluded from state_dict.
        self.register_buffer("mask", None, persistent=False)
        self.register_buffer("pos_emb", None, persistent=False)

    def get_mask(self, n, device):
        """Return an (n, n) bool mask that is True strictly above the diagonal (future keys).

        Slices the cached mask when it already covers n positions; otherwise builds
        a new one and stores it as the cache.
        """
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def get_rotary_embedding(self, n, device):
        """Return rotary angles for the first n positions ((n, dim_head) for even dim_head).

        Slices the cached table when it is long enough; otherwise recomputes and caches it.
        """
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]

        pos_emb = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb

    def forward(self, x):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension

        x is (b, n, dim), as implied by the 'b n (h d)' rearrange of the queries.
        Returns the summed attention and feedforward outputs, shape (b, n, dim).
        """

        n, device, h = x.shape[1], x.device, self.heads

        # pre layernorm

        x = self.norm(x)

        # attention queries, key, value, and feedforward inner input from one projection

        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # split heads
        # Only the queries are split into heads: multi-query attention (another Noam
        # Shazeer paper) shares a single key / value across all heads. The comment in
        # the original code notes no loss in quality past a certain scale and more
        # efficient decoding.
        # https://arxiv.org/abs/1911.02150

        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # rotary embeddings

        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

        # scale

        q = q * self.scale

        # similarity: every query head against the single shared key

        sim = einsum("b h i d, b j d -> b h i j", q, k)

        # causal mask: future positions get the most negative finite value

        causal_mask = self.get_mask(n, device)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention (row max subtracted, detached, for numerical stability)

        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        # aggregate values

        out = einsum("b h i j, b j d -> b h i d", attn, v)

        # merge heads

        out = rearrange(out, "b h n d -> b n (h d)")
        return self.attn_out(out) + self.ff_out(ff)
# transformer

class FlamingoPaLM(nn.Module):
    """PaLM-style decoder-only language model with Flamingo gated cross-attention.

    Text-only calls train the whole language model. When images, videos or
    precomputed media embeddings are passed ("flamingo mode"), every parameter is
    frozen except those of the perceiver resampler and the gated cross-attention
    blocks.
    """

    def __init__(
        self,
        *,
        dim,                               # model (feature) dimension
        num_tokens,                        # vocabulary size
        depth,                             # number of ParallelTransformerBlock layers
        dim_head=64,                       # dimension per attention head
        heads=8,                           # number of attention heads
        ff_mult=4,                         # feedforward expansion factor
        media_token_id=3,                  # token id reserved to mark media positions in the text
        cross_attn_every=3,                # add a gated cross-attention block every N layers
        img_encoder=None,                  # optional image encoder; kept frozen in eval mode
        perceiver_num_latents=64,          # perceiver latents per media item
        perceiver_depth=2,                 # perceiver resampler depth
        max_video_frames = None,           # max video frames; enables the per-frame position embedding
        only_attend_immediate_media=True   # text attends only to the most recent media item
    ):
        super().__init__()

        self.token_emb = nn.Embedding(num_tokens, dim)

        # a special token id must be reserved for media; its positions in the text
        # drive the cross-attention masking
        self.media_token_id = media_token_id

        # learned per-frame position embedding, only created when max_video_frames is given
        self.video_frame_pos_emb = nn.Parameter(torch.randn(max_video_frames, dim)) if exists(max_video_frames) else None

        # The image encoder is optional (precomputed embeds may be passed instead).
        # Only freeze it when one was supplied - calling .eval() on None used to crash.
        self.img_encoder = img_encoder
        if exists(self.img_encoder):
            freeze_model_and_make_eval_(self.img_encoder)

        self.perceiver_resampler = PerceiverResampler(
            dim=dim,
            depth=perceiver_depth,
            dim_head=dim_head,
            heads=heads,
            num_latents=perceiver_num_latents
        )

        # each layer: (residual self-attention/feedforward block, gated cross-attention or None)
        self.layers = nn.ModuleList([])
        for ind in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)),
                GatedCrossAttentionBlock(dim=dim, dim_head=dim_head, heads=heads, only_attend_immediate_media=only_attend_immediate_media) if not (ind % cross_attn_every) else None
            ]))

        self.to_logits = nn.Sequential(
            LayerNorm(dim),
            nn.Linear(dim, num_tokens, bias=False)
        )

        # tie the output projection to the token embedding weights (uncommon, but works)
        self.to_logits[-1].weight = self.token_emb.weight
        nn.init.normal_(self.token_emb.weight, std=0.02)

    def forward(
        self,
        text,
        *,
        images=None,
        videos=None,
        embeds=None
    ):
        """Return next-token logits for *text*, optionally conditioned on media.

        Args:
            text: integer token ids, presumably (batch, seq); positions equal to
                ``media_token_id`` mark where media appear.
            images: raw images laid out as (batch, num_media, ...), encoded with
                ``img_encoder`` under ``torch.no_grad()``.
            videos: raw videos laid out as (batch, num_media, frames, c, h, w);
                requires ``img_encoder`` and ``max_video_frames``.
            embeds: precomputed media embeddings fed straight to the perceiver
                resampler; mutually exclusive with images / videos.

        Returns:
            Logits from ``to_logits`` with a trailing ``num_tokens`` dimension.
        """
        batch = text.shape[0]

        flamingo_mode = any(exists(t) for t in (images, videos, embeds))

        # freeze / unfreeze automatically depending on which inputs were passed
        if flamingo_mode:
            # as in the paper: train only the perceiver and the gated cross-attention blocks
            freeze_all_layers_(self)
            unfreeze_all_layers_(self.perceiver_resampler)
            for _, cross_attn in self.layers:
                if exists(cross_attn):
                    unfreeze_all_layers_(cross_attn)
        else:
            unfreeze_all_layers_(self)

        # boolean media positions for the masked cross-attention; None in text-only
        # mode, where the cross-attention blocks are skipped anyway
        media_locations = (text == self.media_token_id) if flamingo_mode else None

        text_tokens = self.token_emb(text)

        # precomputed embeds cannot be combined with raw images / videos
        # (the previous check referenced an undefined name `video`)
        assert not (exists(embeds) and (exists(images) or exists(videos))), 'pass either precomputed embeds or raw images / videos, not both'

        # encode images or videos into embeddings with the frozen img_encoder

        if exists(images):
            assert exists(self.img_encoder), 'img_encoder must be passed in for automatic image encoding'
            images = rearrange(images, 'b t ... -> (b t) ...')

            with torch.no_grad():
                embeds = self.img_encoder(images)

            embeds = rearrange(embeds, '(b t) ... -> b t ...', b = batch)

        if exists(videos):
            assert exists(self.img_encoder), 'img_encoder must be passed in for automatic video encoding'
            assert exists(self.video_frame_pos_emb), 'max_video_frames must be set at init to pass in videos'
            batch, media, num_times, *_ = videos.shape
            videos = rearrange(videos, '... c h w -> (...) c h w')

            with torch.no_grad():
                embeds = self.img_encoder(videos)

            embeds = rearrange(embeds, '(b m t) ... -> b m t ...', b = batch, m = media, t = num_times)

            # add a learned embedding per frame so the resampler can tell frames apart
            video_time_pos_emb = repeat(self.video_frame_pos_emb[:num_times], 't d -> b m t n d', b = batch, m = media, n = embeds.shape[-2])
            embeds = embeds + video_time_pos_emb
            embeds = rearrange(embeds, 'b m t n d -> b m (t n) d')

        if exists(embeds):
            embeds = self.perceiver_resampler(embeds)

        for attn_ff, flamingo_cross_attn in self.layers:
            text_tokens = attn_ff(text_tokens)

            # cross-attend to the media only on layers that carry a gated cross-attention block
            if exists(flamingo_cross_attn) and exists(embeds):
                text_tokens = flamingo_cross_attn(
                    text_tokens,
                    embeds,
                    media_locations = media_locations
                )

        return self.to_logits(text_tokens)
.\lucidrains\flamingo-pytorch\flamingo_pytorch\flamingo_pytorch.py
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops_exts import rearrange_many, repeat_many
def exists(val):
    """Return True when *val* is anything other than None (falsy values count as existing)."""
    return not (val is None)
def FeedForward(dim, mult = 4):
    """Build a pre-norm MLP: LayerNorm -> Linear(dim, dim*mult) -> GELU -> Linear(back to dim).

    Both linear layers are bias-free; the hidden width is truncated to an int so
    fractional multipliers are allowed.
    """
    hidden = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias = False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias = False),
    ]
    return nn.Sequential(*layers)
class PerceiverAttention(nn.Module):
    """Cross-attention from perceiver latents (queries) to media features.

    Unlike the original Perceiver, keys / values are computed from the media
    features concatenated with the latents themselves, as described in Flamingo.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x, latents):
        """
        einstein notation
        b - batch
        t - time
        n - sequence
        d - dimension

        x (media) and latents are both 4-D (b, t, n, d), per the 'b t n (h d)' rearrange.
        """
        x = self.norm_media(x)
        latents = self.norm_latents(latents)

        batch, num_media, heads = *x.shape[:2], self.heads

        queries = self.to_q(latents)

        # latents also contribute keys / values alongside the media features
        kv_input = torch.cat((x, latents), dim = -2)
        keys, values = self.to_kv(kv_input).chunk(2, dim = -1)

        queries, keys, values = rearrange_many((queries, keys, values), 'b t n (h d) -> b h t n d', h = heads)
        queries = queries * self.scale

        sim = einsum('... i d, ... j d -> ... i j', queries, keys)
        # subtract the detached row max for numerical stability before softmax
        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
        weights = sim.softmax(dim = -1)

        out = einsum('... i j, ... j d -> ... i d', weights, values)
        out = rearrange(out, 'b h t n d -> b t n (h d)', h = heads)
        return self.to_out(out)
class PerceiverResampler(nn.Module):
    """Compress each media item's features into a fixed number of learned latents.

    Input is (b, n, d) for a single media item or (b, m, n, d) for m items; the
    output is (b, m, num_latents, d).
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_media_embeds = 4,
        ff_mult = 4
    ):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        # one learned embedding per media slot, broadcast over that item's tokens
        self.media_pos_emb = nn.Parameter(torch.randn(num_media_embeds, 1, dim))

        self.layers = nn.ModuleList([
            nn.ModuleList([
                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult),
            ])
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        # a single media item per sample: (b, n, d) -> (b, 1, n, d)
        if x.ndim == 3:
            x = rearrange(x, 'b n d -> b 1 n d')

        num_media = x.shape[1]
        x = x + self.media_pos_emb[:num_media]

        batch = x.shape[0]
        latents = repeat(self.latents, 'n d -> b m n d', b = batch, m = num_media)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        return self.norm(latents)
# gated cross attention

class MaskedCrossAttention(nn.Module):
    """Cross-attention from text tokens to perceiver-resampled media, masked by media position.

    With ``only_attend_immediate_media=True`` a text token attends only to the
    latents of the most recent media token at or before its position; otherwise it
    attends to all media at or before its position. Text that precedes every media
    token attends to nothing (its attention output is zero).
    """

    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        only_attend_immediate_media = True
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # whether text attends only to the immediately preceding media item, or to all preceding media
        self.only_attend_immediate_media = only_attend_immediate_media

    def forward(
        self,
        x,
        media,
        media_locations = None
    ):
        """
        Args:
            x: text features, (batch, seq, dim) per the 'b n (h d)' rearrange.
            media: resampled media, (batch, num_media, latents, dim) per 'b t n d'.
            media_locations: optional bool tensor (batch, seq), True where a media token sits.

        Returns:
            (batch, seq, dim) attention output (not yet gated or residually added).
        """
        _, t, m = media.shape[:3]
        h = self.heads

        x = self.norm(x)

        q = self.to_q(x)
        media = rearrange(media, 'b t n d -> b (t n) d')

        k, v = self.to_kv(media).chunk(2, dim = -1)
        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)

        q = q * self.scale

        sim = einsum('... i d, ... j d -> ... i j', q, k)

        if exists(media_locations):
            # text_time[i] = number of media tokens at or before position i
            text_time = media_locations.cumsum(dim = -1)
            media_time = torch.arange(t, device = x.device) + 1

            # immediate media only: text_time == media_time
            # all preceding media:  text_time >= media_time
            mask_op = torch.eq if self.only_attend_immediate_media else torch.ge

            text_to_media_mask = mask_op(rearrange(text_time, 'b i -> b 1 i 1'), repeat(media_time, 'j -> 1 1 1 (j m)', m = m))
            sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)

        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
        attn = sim.softmax(dim = -1)

        if exists(media_locations):
            # Text before the first media token has every key masked, so the softmax
            # spreads uniform weight over *all* media - including future media. That
            # attention must be zeroed in both modes; it used to be zeroed only when
            # only_attend_immediate_media=True, leaking future media otherwise.
            text_without_media_mask = text_time == 0
            text_without_media_mask = rearrange(text_without_media_mask, 'b i -> b 1 i 1')
            attn = attn.masked_fill(text_without_media_mask, 0.)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
class GatedCrossAttentionBlock(nn.Module):
    """Masked cross-attention followed by a feedforward, each added residually through a tanh gate.

    Both gates are learnable scalars initialised to 0, so tanh(gate) == 0 and the
    block is the identity at initialisation.
    """

    def __init__(
        self,
        *,
        dim,                                # input feature dimension
        dim_head = 64,                      # dimension per attention head
        heads = 8,                          # number of attention heads
        ff_mult = 4,                        # feedforward expansion factor
        only_attend_immediate_media = True  # attend only to the most recent media item
    ):
        super().__init__()
        self.attn = MaskedCrossAttention(dim = dim, dim_head = dim_head, heads = heads, only_attend_immediate_media = only_attend_immediate_media)
        self.attn_gate = nn.Parameter(torch.tensor([0.]))

        self.ff = FeedForward(dim, mult = ff_mult)
        self.ff_gate = nn.Parameter(torch.tensor([0.]))

    def forward(
        self,
        x,                      # text features
        media,                  # perceiver-resampled media - (batch, time, latents, dim)
        media_locations = None  # boolean media positions - (batch, sequence)
    ):
        attn_gate = self.attn_gate.tanh()
        x = x + self.attn(x, media, media_locations = media_locations) * attn_gate

        ff_gate = self.ff_gate.tanh()
        x = x + self.ff(x) * ff_gate
        return x
.\lucidrains\flamingo-pytorch\flamingo_pytorch\__init__.py
# 从 flamingo_pytorch.flamingo_pytorch 模块中导入 PerceiverResampler 和 GatedCrossAttentionBlock 类
from flamingo_pytorch.flamingo_pytorch import PerceiverResampler, GatedCrossAttentionBlock
# 从 flamingo_pytorch.flamingo_palm 模块中导入 FlamingoPaLM 类
from flamingo_pytorch.flamingo_palm import FlamingoPaLM

🦩 Flamingo - Pytorch
Implementation of Flamingo, a state-of-the-art few-shot visual question answering attention net, in Pytorch. It will include the perceiver resampler (including the scheme where the learned queries contribute keys / values to be attended to, in addition to media embeddings), the specialized masked cross attention blocks, and finally the tanh gating at the ends of the cross attention + corresponding feedforward blocks
Install
$ pip install flamingo-pytorch
Usage
import torch
from flamingo_pytorch import PerceiverResampler
perceive = PerceiverResampler(
dim = 1024,
depth = 2,
dim_head = 64,
heads = 8,
num_latents = 64, # the number of latents to shrink your media sequence to, perceiver style
num_media_embeds = 4 # say you have 4 images maximum in your dialogue
)
medias = torch.randn(1, 2, 256, 1024) # (batch, time, sequence length, dimension)
perceived = perceive(medias) # (1, 2, 64, 1024) - (batch, time, num latents, dimension)
Then you insert the GatedCrossAttentionBlock at different intervals in your giant language model. Your text would then attend to the perceived media from above
The recommended way to derive the media_locations boolean tensor would be to allocate a special token id to the media, and then, at the start of your large language model, do media_locations = text_id == media_token_id
import torch
from flamingo_pytorch import GatedCrossAttentionBlock
cross_attn = GatedCrossAttentionBlock(
dim = 1024,
dim_head = 64,
heads = 8
)
text = torch.randn(1, 512, 1024)
perceived = torch.randn(1, 2, 64, 1024)
media_locations = torch.randint(0, 2, (1, 512)).bool()
text = cross_attn(
text,
perceived,
media_locations = media_locations
)
That's it!
Attention is all you need.
Full working example with Flamingo + PaLM 🌴🦩🌴
Integration with PaLM
First install vit-pytorch for the vision encoder
$ pip install vit-pytorch
Then
from vit_pytorch.vit import ViT
from vit_pytorch.extractor import Extractor
vit = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
vit = Extractor(vit, return_embeddings_only = True)
# first take your trained image encoder and wrap it in an adapter that returns the image embeddings
# here we use the ViT from the vit-pytorch library
import torch
from flamingo_pytorch import FlamingoPaLM
# a PaLM language model, the 540 billion parameter model from google that shows signs of general intelligence
flamingo_palm = FlamingoPaLM(
num_tokens = 20000, # number of tokens
dim = 1024, # dimensions
depth = 12, # depth
heads = 8, # attention heads
dim_head = 64, # dimension per attention head
img_encoder = vit, # plugin your image encoder (this can be optional if you pass in the image embeddings separately, but probably want to train end to end given the perceiver resampler)
media_token_id = 3, # the token id representing the [media] or [image]
cross_attn_every = 3, # how often to cross attend
perceiver_num_latents = 64, # perceiver number of latents, should be smaller than the sequence length of the image tokens
perceiver_depth = 2 # perceiver resampler depth
)
# train your PaLM as usual
text = torch.randint(0, 20000, (2, 512))
palm_logits = flamingo_palm(text)
# after much training off the regular PaLM logits
# now you are ready to train Flamingo + PaLM
# by passing in images, it automatically freezes everything but the perceiver and cross attention blocks, as in the paper
dialogue = torch.randint(0, 20000, (4, 512))
images = torch.randn(4, 2, 3, 256, 256)
flamingo_logits = flamingo_palm(dialogue, images = images)
# do your usual cross entropy loss
It is quite evident where this is all headed if you think beyond just images.
Inception
For factual correctness, just imagine where this system would stand if one were to use a state of the art retrieval language model as the base.
Citations
@article{Alayrac2022Flamingo,
title = {Flamingo: a Visual Language Model for Few-Shot Learning},
author = {Jean-Baptiste Alayrac et al},
year = {2022}
}
@inproceedings{Chowdhery2022PaLMSL,
title = {PaLM: Scaling Language Modeling with Pathways},
author = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
year = {2022}
}
.\lucidrains\flamingo-pytorch\setup.py
# 导入设置工具和查找包工具
from setuptools import setup, find_packages
# Distribution metadata for flamingo-pytorch, consumed by setuptools.
_KEYWORDS = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'visual question answering',
]

_INSTALL_REQUIRES = [
    'einops>=0.4',
    'einops-exts',
    'torch>=1.6',
]

_CLASSIFIERS = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
]

setup(
    name='flamingo-pytorch',
    packages=find_packages(exclude=[]),
    version='0.1.2',
    license='MIT',
    description='Flamingo - Pytorch',
    author='Phil Wang',
    author_email='lucidrains@gmail.com',
    url='https://github.com/lucidrains/flamingo-pytorch',
    long_description_content_type='text/markdown',
    keywords=_KEYWORDS,
    install_requires=_INSTALL_REQUIRES,
    classifiers=_CLASSIFIERS,
)
.\lucidrains\flash-attention-jax\flash_attention_jax\attention.py
# 导入需要的库
import jax
from jax import nn
from jax import jit, numpy as jnp
from jax.numpy import einsum
# 导入重塑函数
from einops import rearrange
# Constants
EPSILON = 1e-10        # added to norms to avoid division by zero
MASK_VALUE = -1e10     # logit assigned to masked-out positions before softmax
COSINE_SIM_SCALE = 10  # fixed temperature applied to cosine similarities


@jit
def attention(q, k, v, key_mask):
    """Reference scaled dot-product attention with a key padding mask.

    Args:
        q: queries (..., q_len, dim); the mask broadcast below implies (batch, heads, q_len, dim).
        k: keys (..., k_len, dim).
        v: values (..., k_len, v_dim).
        key_mask: boolean (batch, k_len); False marks keys to ignore.

    Returns:
        (..., q_len, v_dim) attention output.
    """
    dim = q.shape[-1]
    scale = 1 / jnp.sqrt(dim)

    q = q * scale
    sim = einsum('... i d, ... j d -> ... i j', q, k)

    key_mask = rearrange(key_mask, 'b j -> b 1 1 j')
    sim = jnp.where(key_mask, sim, MASK_VALUE)

    attn = nn.softmax(sim, axis = -1)
    return attn @ v


@jit
def causal_attention(q, k, v):
    """Reference causal scaled dot-product attention.

    Queries are aligned to the end of the keys: query i may attend to keys
    0 .. i + (k_len - q_len).

    Args:
        q: (..., q_len, dim); k: (..., k_len, dim); v: (..., k_len, v_dim).

    Returns:
        (..., q_len, v_dim) attention output.
    """
    q_len, dim, k_len = *q.shape[-2:], k.shape[-2]
    scale = 1 / jnp.sqrt(dim)

    q = q * scale
    sim = einsum('... i d, ... j d -> ... i j', q, k)

    # nonzero strictly above the shifted diagonal -> future keys
    causal_mask = jnp.triu(jnp.ones((q_len, k_len)), k_len - q_len + 1)
    sim = jnp.where(causal_mask, MASK_VALUE, sim)

    attn = nn.softmax(sim, axis = -1)
    return einsum('... i j, ... j d -> ... i d', attn, v)


# cosine similarity attention helpers

@jit
def l2norm(t):
    """L2-normalize every vector along the last axis.

    The norm must be taken per vector (axis=-1). Dividing by the norm of the whole
    tensor does not produce unit vectors, so the dot products in cosine-similarity
    attention would not be cosine similarities.
    """
    return t / (jnp.linalg.norm(t, axis = -1, keepdims = True) + EPSILON)
@jit
def cosine_sim_attention(q, k, v, key_mask):
    """Attention over L2-normalized queries and keys, scaled by COSINE_SIM_SCALE.

    key_mask is boolean (batch, k_len); False marks keys to ignore.
    """
    dim, k_len = q.shape[-1], k.shape[-2]  # not used below; kept from the original

    q_normed = l2norm(q)
    k_normed = l2norm(k)

    sim = einsum('... i d, ... j d -> ... i j', q_normed, k_normed)
    sim = sim * COSINE_SIM_SCALE

    mask = rearrange(key_mask, 'b j -> b 1 1 j')
    sim = jnp.where(mask, sim, MASK_VALUE)

    weights = nn.softmax(sim, axis = -1)
    return einsum('... i j, ... j d -> ... i d', weights, v)
.\lucidrains\flash-attention-jax\flash_attention_jax\causal_flash_attention.py
# 导入所需的库
import math
import jax
from functools import partial
from jax import nn
from jax import custom_vjp
from jax import numpy as jnp, lax, jit
from jax.numpy import einsum
from einops import rearrange
# Constants
EPSILON = 1e-10       # added to softmax row sums to avoid division by zero
MASK_VALUE = -1e10    # logit written into causally masked positions
Q_CHUNK_SIZE = 1024   # query rows processed per outer scan step
K_CHUNK_SIZE = 1024   # key / value rows processed per inner scan step
# Forward pass for one chunk of queries

def _query_chunk_flash_attention(q_range_chunk, k_range, q, k, v):
    """Causal attention of one query chunk against all keys, streamed over key chunks.

    Uses the online-softmax recurrence: a running row max and row sum are carried
    across key chunks, so the full attention matrix is never materialized.

    Args:
        q_range_chunk: (q_chunk, 1) absolute positions of these queries.
        k_range: (1, k_len) absolute positions of the keys.
        q: (q_chunk, bh, dim) queries.
        k: (k_len, bh, dim) keys.
        v: (k_len, bh, v_dim) values.

    Returns:
        out (q_chunk, bh, v_dim), row_sum (q_chunk, bh), row_max (q_chunk, bh).

    NOTE(review): lax.dynamic_slice clamps start indices so the slice stays in
    bounds (per the JAX docs). When k_len > K_CHUNK_SIZE is not a multiple of
    K_CHUNK_SIZE, the last key chunk therefore overlaps the previous one and those
    keys are counted twice. Keep k_len <= K_CHUNK_SIZE or a multiple of it.
    """
    q_len, k_len, bh, dim, v_dim = q.shape[0], *k.shape, v.shape[-1]

    scale = 1 / jnp.sqrt(dim)
    q_scaled = q * scale

    def chunk_scanner(carries, _):
        key_chunk_idx, out, row_sum, row_max = carries
        k_chunk_sizes = min(K_CHUNK_SIZE, k_len)

        k_chunk = lax.dynamic_slice(k, (key_chunk_idx, 0, 0), slice_sizes=(k_chunk_sizes, bh, dim))
        v_chunk = lax.dynamic_slice(v, (key_chunk_idx, 0, 0), slice_sizes=(k_chunk_sizes, bh, v_dim))
        k_range_chunk = lax.dynamic_slice(k_range, (0, key_chunk_idx), slice_sizes=(1, k_chunk_sizes))

        # True where the key lies in the future of the query
        causal_mask = q_range_chunk < k_range_chunk

        attn_weights = einsum('i ... d, j ... d -> i ... j', q_scaled, k_chunk)

        causal_mask = rearrange(causal_mask, 'i j -> i 1 j')
        attn_weights = jnp.where(causal_mask, MASK_VALUE, attn_weights)

        block_row_max = jnp.max(attn_weights, axis = -1, keepdims = True)

        exp_weights = jnp.exp(attn_weights - block_row_max)
        exp_weights = jnp.where(causal_mask, 0., exp_weights)

        block_row_sum = jnp.sum(exp_weights, axis = -1, keepdims = True) + EPSILON

        exp_values = einsum('i ... j, j ... d -> i ... d', exp_weights, v_chunk)

        # merge this block's softmax statistics into the running ones (online softmax)
        new_row_max = jnp.maximum(block_row_max, row_max)

        exp_row_max_diff = jnp.exp(row_max - new_row_max)
        exp_block_row_max_diff = jnp.exp(block_row_max - new_row_max)

        new_row_sum = exp_row_max_diff * row_sum + exp_block_row_max_diff * block_row_sum

        out = (row_sum / new_row_sum) * exp_row_max_diff * out + \
              (exp_block_row_max_diff / new_row_sum) * exp_values

        return (key_chunk_idx + k_chunk_sizes, out, new_row_sum, new_row_max), None

    # The accumulator has the *value* width. It was previously (q_len, bh, dim),
    # which made the scan carry shape inconsistent whenever v_dim != dim.
    out = jnp.zeros((q_len, bh, v_dim))
    row_sum = jnp.zeros((q_len, bh, 1))
    row_max = jnp.ones((q_len, bh, 1)) * -1e6

    (_, out, row_sum, row_max), _ = lax.scan(chunk_scanner, init = (0, out, row_sum, row_max), xs = None, length = math.ceil(k_len / K_CHUNK_SIZE))

    out = out.reshape(q_len, bh, v_dim)
    row_sum = row_sum.reshape(q_len, bh)
    row_max = row_max.reshape(q_len, bh)

    return out, row_sum, row_max
# Causal flash attention over the full sequence
# (called by both causal_flash_attention and flash_attention_forward)

def _causal_flash_attention(q, k, v):
    """Chunked causal attention; returns the output and the softmax statistics.

    Inputs are (batch, heads, seq, dim) per the 'b h n d' rearrange below.
    Queries are aligned to the end of the keys (query i sits at key position
    i + k_len - q_len). Queries are processed Q_CHUNK_SIZE rows at a time via
    lax.scan, each chunk delegating to _query_chunk_flash_attention.

    Returns:
        out: (batch, heads, q_len, v_dim)
        (row_sum, row_max): each (q_len, batch * heads), kept for the backward pass.

    NOTE(review): the stacked per-chunk outputs are reshaped back to q_len rows,
    which only works when q_len <= Q_CHUNK_SIZE or q_len is a multiple of it.
    """
    batch, heads, q_len, dim, k_len, v_dim = *q.shape, *v.shape[-2:]

    bh = batch * heads

    # fold batch and heads together and put the sequence axis first
    q, k, v = map(lambda t: rearrange(t, 'b h n d -> n (b h) d'), (q, k, v))

    # absolute positions: queries are offset so the last query aligns with the last key
    q_range = jnp.arange(q_len).reshape(q_len, 1) + (k_len - q_len)
    k_range = jnp.arange(k_len).reshape(1, k_len)

    def chunk_scanner(chunk_idx, _):
        # carry: start row of the current query chunk; output: that chunk's results
        chunk_sizes = min(Q_CHUNK_SIZE, q_len)

        q_chunk = lax.dynamic_slice(q, (chunk_idx, 0, 0), slice_sizes = (chunk_sizes, bh, dim))
        q_range_chunk = lax.dynamic_slice(q_range, (chunk_idx, 0), slice_sizes = (chunk_sizes, 1))

        return (chunk_idx + chunk_sizes, _query_chunk_flash_attention(q_range_chunk, k_range, q_chunk, k, v))

    _, (out, row_sum, row_max) = lax.scan(chunk_scanner, init = 0, xs = None, length = math.ceil(q_len / Q_CHUNK_SIZE))

    # collapse the leading (num_chunks, chunk) axes of the stacked scan outputs
    out = out.reshape(q_len, bh, v_dim)
    row_sum = row_sum.reshape(q_len, bh)
    row_max = row_max.reshape(q_len, bh)

    out = rearrange(out, 'n (b h) d -> b h n d', b = batch)
    return out, (row_sum, row_max)
# Public entry point: JIT-compiled forward wrapped in custom_vjp
# (the VJP rules are presumably registered after this chunk - not visible here).

@custom_vjp
@jit
def causal_flash_attention(q, k, v):
    """Causal flash attention on (batch, heads, seq, dim) inputs; returns only the output."""
    output, _softmax_stats = _causal_flash_attention(q, k, v)
    return output
# JIT-compiled forward rule: returns the output plus the residuals that
# flash_attention_backward unpacks as (q, k, v, o, l, m).

@jit
def flash_attention_forward(q, k, v):
    """Run causal flash attention and keep inputs, output and softmax stats for backward."""
    out, stats = _causal_flash_attention(q, k, v)
    row_sum, row_max = stats
    residuals = (q, k, v, out, row_sum, row_max)
    return out, residuals
# Backward pass for one chunk of queries

def _query_chunk_flash_attention_backward(query_range_chunk, key_range, q, k, v, o, do, l, m):
    """Gradients of causal attention for one query chunk, streamed over key chunks.

    Attention probabilities are recomputed from the saved softmax statistics
    (row sum ``l`` and row max ``m``) instead of being stored.

    Args:
        query_range_chunk: (q_chunk, 1) absolute query positions.
        key_range: (1, k_len) absolute key positions.
        q: (q_chunk, bh, dim) unscaled queries.
        k: (k_len, bh, dim) keys; v: (k_len, bh, v_dim) values.
        o, do: (q_chunk, bh, v_dim) forward output and its incoming gradient.
        l, m: (q_chunk, bh, 1) softmax row sums and row maxes from the forward pass.

    Returns:
        dq (q_chunk, bh, dim) for this chunk, plus this chunk's contributions to
        dk (k_len, bh, dim) and dv (k_len, bh, v_dim).
    """
    q_len, bh, dim, k_len, _, v_dim = *q.shape, *v.shape

    scale = 1 / jnp.sqrt(dim)
    q_scaled = q * scale

    # D_i = sum_d do_id * o_id does not depend on the key chunk, so compute it once
    D = jnp.sum(do * o, axis = -1, keepdims = True)

    def chunk_scanner(carries, _):
        key_chunk_idx, dq = carries
        k_chunk_sizes = min(K_CHUNK_SIZE, k_len)

        k_chunk = lax.dynamic_slice(k, (key_chunk_idx, 0, 0), slice_sizes=(k_chunk_sizes, bh, dim))
        v_chunk = lax.dynamic_slice(v, (key_chunk_idx, 0, 0), slice_sizes=(k_chunk_sizes, bh, v_dim))
        key_range_chunk = lax.dynamic_slice(key_range, (0, key_chunk_idx), slice_sizes=(1, k_chunk_sizes))

        # True where the key lies in the future of the query
        causal_mask = query_range_chunk < key_range_chunk

        attn_weights = einsum('i ... d, j ... d -> i ... j', q_scaled, k_chunk)

        causal_mask = rearrange(causal_mask, 'i j -> i 1 j')
        attn_weights = jnp.where(causal_mask, MASK_VALUE, attn_weights)

        # recompute softmax probabilities from the saved statistics
        exp_attn_weights = jnp.exp(attn_weights - m)
        exp_attn_weights = jnp.where(causal_mask, 0., exp_attn_weights)
        p = exp_attn_weights / l

        dv_chunk = einsum('i ... j, i ... d -> j ... d', p, do)
        dp = einsum('i ... d, j ... d -> i ... j', do, v_chunk)

        # ds already carries the softmax scale, so dk below uses the unscaled q
        ds = p * scale * (dp - D)

        dq_chunk = einsum('i ... j, j ... d -> i ... d', ds, k_chunk)
        dk_chunk = einsum('i ... j, i ... d -> j ... d', ds, q)

        return (key_chunk_idx + k_chunk_sizes, dq + dq_chunk), (dk_chunk, dv_chunk)

    dq = jnp.zeros_like(q)

    (_, dq), (dk, dv) = lax.scan(chunk_scanner, init = (0, dq), xs = None, length = math.ceil(k_len / K_CHUNK_SIZE))

    dq = dq.reshape(q_len, bh, dim)
    # dk has the key (= query) feature width; it was previously reshaped to v_dim,
    # which fails or scrambles the gradient whenever v_dim != dim
    dk = dk.reshape(k_len, bh, dim)
    dv = dv.reshape(k_len, bh, v_dim)

    return dq, dk, dv
# JIT-compiled backward rule: gradients of causal flash attention w.r.t. q, k, v.

@jit
def flash_attention_backward(res, do):
    """Compute (dq, dk, dv) from the saved forward residuals and the output gradient.

    Args:
        res: residuals in the order produced by flash_attention_forward:
            (q, k, v, o, l, m) where l / m are the softmax row sums / row maxes,
            each (q_len, batch * heads).
        do: gradient of the output, (batch, heads, q_len, v_dim).

    Returns:
        dq, dk, dv in (batch, heads, seq, feature) layout.

    Queries are processed Q_CHUNK_SIZE rows at a time; dk and dv are accumulated
    across query chunks in the scan carry, dq is stacked per chunk.
    """
    q, k, v, o, l, m = res

    batch, heads, q_len, dim, k_len, v_dim = *q.shape, *v.shape[-2:]

    bh = batch * heads

    # give the softmax statistics a trailing axis so they broadcast over keys
    m = m.reshape(q_len, bh, 1)
    l = l.reshape(q_len, bh, 1)

    # fold batch and heads together and put the sequence axis first
    q, k, v, o, do = map(lambda t: rearrange(t, 'b h n d -> n (b h) d'), (q, k, v, o, do))

    dk = jnp.zeros_like(k)
    dv = jnp.zeros_like(v)

    # absolute positions: queries are offset so the last query aligns with the last key
    q_range = jnp.arange(q_len).reshape(q_len, 1) + (k_len - q_len)
    k_range = jnp.arange(k_len).reshape(1, k_len)

    def chunk_scanner(carries, _):
        # carry: (start row of this query chunk, running dk, running dv)
        chunk_idx, dk, dv = carries

        chunk_sizes = min(Q_CHUNK_SIZE, q_len)

        q_chunk = lax.dynamic_slice(q, (chunk_idx, 0, 0), slice_sizes = (chunk_sizes, bh, q.shape[-1]))
        q_range_chunk = lax.dynamic_slice(q_range, (chunk_idx, 0), slice_sizes = (chunk_sizes, 1))

        m_chunk = lax.dynamic_slice(m, (chunk_idx, 0, 0), slice_sizes = (chunk_sizes, bh, 1))
        l_chunk = lax.dynamic_slice(l, (chunk_idx, 0, 0), slice_sizes = (chunk_sizes, bh, 1))
        o_chunk = lax.dynamic_slice(o, (chunk_idx, 0, 0), slice_sizes = (chunk_sizes, bh, o.shape[-1]))
        do_chunk = lax.dynamic_slice(do, (chunk_idx, 0, 0), slice_sizes = (chunk_sizes, bh, do.shape[-1]))

        dq_chunk, dk_chunk, dv_chunk = _query_chunk_flash_attention_backward(q_range_chunk, k_range, q_chunk, k, v, o_chunk, do_chunk, l_chunk, m_chunk)
        return (chunk_idx + chunk_sizes, dk + dk_chunk, dv + dv_chunk), dq_chunk

    (_, dk, dv), dq = lax.scan(chunk_scanner, init = (0, dk, dv), xs = None, length = math.ceil(q_len / Q_CHUNK_SIZE))

    # collapse the stacked (num_chunks, chunk) leading axes of dq
    dq = dq.reshape(q_len, bh, dim)

    # restore the (batch, heads, seq, feature) layout
    dq, dk, dv = map(lambda t: rearrange(t, 'n (b h) d -> b h n d', b = batch), (dq, dk, dv))

    return dq, dk, dv
# Register the custom VJP rules for causal_flash_attention:
# flash_attention_forward supplies the output plus saved residuals, and
# flash_attention_backward maps (residuals, output cotangent) to (dq, dk, dv).
causal_flash_attention.defvjp(flash_attention_forward, flash_attention_backward)
.\lucidrains\flash-attention-jax\flash_attention_jax\cosine_sim_flash_attention.py
# 导入数学库和 JAX 库,以及部分函数
import math
import jax
from functools import partial
from jax import nn
from jax import custom_vjp
from jax import numpy as jnp, lax, jit
# Constants
# Added to denominators (row sums, vector norms) to avoid division by zero.
EPSILON = 1e-10
# Logit assigned to masked-out keys before exponentiation.
MASK_VALUE = -1e10
# Number of queries / keys handled per scan step.
Q_CHUNK_SIZE = 1024
K_CHUNK_SIZE = 1024
# Fixed scale applied to the cosine similarities. This may need to be a function of
# log(sequence length); the author notes 16 was sufficient for 2048 and 4096 in
# their tests. NOTE(review): the value actually used is 10, not 16 — confirm intent.
COSINE_SIM_SCALE = 10
# Cosine-similarity flash attention: forward pass for one chunk of queries.
def _query_chunk_flash_attention(chunk_idx, q, k, v, key_mask):
    """Attend one query chunk to all keys, scanning over key chunks.

    No running max is tracked: exponentials of the shifted logits are
    accumulated directly and the output is renormalised once by the total
    row sum at the end.

    Args:
        chunk_idx: unused (the inner scan tracks its own key offset); kept for
            interface compatibility.
        q: query chunk, (q_chunk, dim).
        k: keys, (k_len, dim).
        v: values, (k_len, v_dim).
        key_mask: boolean mask over keys, (k_len,); False keys are ignored.

    Returns:
        (out, row_sum): output of shape (q_chunk, v_dim) and the per-query sum
        of exponentiated weights, shape (q_chunk,).
    """
    q_len, k_len, dim, v_dim = q.shape[-2], *k.shape, v.shape[-1]
    k_chunk_sizes = min(K_CHUNK_SIZE, k_len)

    def chunk_scanner(carries, _):
        key_chunk_idx, out, row_sum = carries
        k_chunk = lax.dynamic_slice(k, (key_chunk_idx, 0), slice_sizes=(k_chunk_sizes, dim))
        v_chunk = lax.dynamic_slice(v, (key_chunk_idx, 0), slice_sizes=(k_chunk_sizes, v_dim))
        key_mask_chunk = lax.dynamic_slice(key_mask, (key_chunk_idx,), slice_sizes=(k_chunk_sizes,))

        # Shift by -COSINE_SIM_SCALE so that, for unit-norm q/k rows, logits lie in
        # [-2 * scale, 0] and row sums are bounded by the key length.
        attn_weights = (q @ k_chunk.transpose() * COSINE_SIM_SCALE) - COSINE_SIM_SCALE
        attn_weights = jnp.where(key_mask_chunk, attn_weights, MASK_VALUE)

        exp_weights = jnp.exp(attn_weights)
        exp_weights = jnp.where(key_mask_chunk, exp_weights, 0.)
        block_row_sum = jnp.sum(exp_weights, axis = -1, keepdims = True)

        # Accumulated with a 1 / k_len factor that the final renormalisation undoes.
        exp_values = exp_weights @ v_chunk
        chunk_out = exp_values / k_len
        return (key_chunk_idx + k_chunk_sizes, out + chunk_out, row_sum + block_row_sum), None

    # The accumulator holds weighted values, so its width is v_dim (not dim).
    out = jnp.zeros((q_len, v_dim))
    row_sum = jnp.zeros((q_len, 1))
    (_, out, row_sum), _ = lax.scan(chunk_scanner, init = (0, out, row_sum), xs = None, length = math.ceil(k_len / K_CHUNK_SIZE))

    # Renormalise once all row sums are known.
    out = out * (k_len / (row_sum + EPSILON))
    out = out.reshape(q_len, v_dim)
    row_sum = row_sum.reshape(q_len)
    return out, row_sum
@jit
def l2norm(t):
    """Scale every vector along the last axis of ``t`` to unit L2 norm.

    The norm is taken per row (``axis=-1``), not over the whole array, so
    that ``q @ k.T`` on normalised inputs is a true cosine similarity in
    [-1, 1]. The attention kernels' fixed ``COSINE_SIM_SCALE`` shift relies on
    that range. EPSILON avoids division by zero for all-zero rows.
    """
    return t / (jnp.linalg.norm(t, axis = -1, keepdims = True) + EPSILON)
@jit
def cosine_sim_flash_attention(q, k, v, key_mask):
    """Cosine-similarity flash attention on 2-D q, k, v.

    q and k are normalised with ``l2norm`` and then handed to the
    custom-VJP kernel ``cosine_sim_flash_attention_after_l2norm``.
    """
    q_unit = l2norm(q)
    k_unit = l2norm(k)
    return cosine_sim_flash_attention_after_l2norm(q_unit, k_unit, v, key_mask)
def _cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask):
    """Run the cosine-sim forward pass over query chunks of ``Q_CHUNK_SIZE``.

    Returns:
        The output, shape (q_len, v_dim), and a 1-tuple holding the per-query
        row sums, shape (q_len,), needed by the backward pass.
    """
    q_len, dim = q.shape
    v_dim = v.shape[-1]
    step = min(Q_CHUNK_SIZE, q_len)

    def process_query_chunk(start, _):
        q_block = lax.dynamic_slice(q, (start, 0), slice_sizes = (step, dim))
        block_result = _query_chunk_flash_attention(start, q_block, k, v, key_mask)
        return start + step, block_result

    num_chunks = math.ceil(q_len / Q_CHUNK_SIZE)
    _, (out, row_sum) = lax.scan(process_query_chunk, init = 0, xs = None, length = num_chunks)
    return out.reshape(q_len, v_dim), (row_sum.reshape(q_len),)
@custom_vjp
def cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask):
    """Custom-VJP cosine-sim attention on already-normalised q/k; returns only the output."""
    return _cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)[0]
@jit
def flash_attention_forward(q, k, v, key_mask):
    """Forward rule for the custom VJP: the output plus residuals for backward."""
    out, extras = _cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)
    (row_sum,) = extras
    residuals = (q, k, v, key_mask, out, row_sum)
    return out, residuals
def _query_chunk_flash_attention_backward(q, k, v, key_mask, o, do, l):
    """Gradients of cosine-sim attention for one query chunk.

    Args:
        q: query chunk, (q_chunk, dim).
        k: keys, (k_len, dim).
        v: values, (k_len, v_dim).
        key_mask: boolean mask over keys, (k_len,).
        o: forward output rows for this chunk, (q_chunk, v_dim).
        do: cotangent of ``o``.
        l: forward row sums for this chunk, (q_chunk, 1).

    Returns:
        dq (q_chunk, dim), dk (k_len, dim), dv (k_len, v_dim).
    """
    q_len, dim, k_len, v_dim = *q.shape, *v.shape
    k_chunk_sizes = min(K_CHUNK_SIZE, k_len)

    # D_i = sum_d dO_id * O_id is independent of the key chunk; compute it once.
    D = jnp.sum(do * o, axis=-1, keepdims=True)

    def chunk_scanner(carries, _):
        chunk_idx, dq = carries
        k_chunk = lax.dynamic_slice(k, (chunk_idx, 0), slice_sizes=(k_chunk_sizes, dim))
        v_chunk = lax.dynamic_slice(v, (chunk_idx, 0), slice_sizes=(k_chunk_sizes, v_dim))
        key_mask_chunk = lax.dynamic_slice(key_mask, (chunk_idx,), slice_sizes=(k_chunk_sizes,))

        # Same shifted, scaled logits as the forward pass.
        attn_weights = q @ k_chunk.transpose() * COSINE_SIM_SCALE - COSINE_SIM_SCALE
        exp_attn_weights = jnp.exp(attn_weights)
        exp_attn_weights = jnp.where(key_mask_chunk, exp_attn_weights, 0.)
        # Normalise with the saved row sums (EPSILON matches the forward pass).
        p = exp_attn_weights / (l + EPSILON)

        dv_chunk = p.transpose() @ do
        dp = do @ v_chunk.transpose()
        # d(logits)/d(q.k) = COSINE_SIM_SCALE, folded into ds.
        ds = p * COSINE_SIM_SCALE * (dp - D)

        dq_chunk = ds @ k_chunk
        dk_chunk = ds.transpose() @ q
        return (chunk_idx + k_chunk_sizes, dq + dq_chunk), (dk_chunk, dv_chunk)

    dq = jnp.zeros_like(q)
    (_, dq), (dk, dv) = lax.scan(chunk_scanner, init=(0, dq), xs=None, length=math.ceil(k_len / K_CHUNK_SIZE))

    dq = dq.reshape(q_len, dim)
    # dk has the key feature size (dim); only dv has the value size (v_dim).
    dk = dk.reshape(k_len, dim)
    dv = dv.reshape(k_len, v_dim)
    return dq, dk, dv
# JIT-compiled backward rule for cosine_sim_flash_attention_after_l2norm.
@jit
def flash_attention_backward(res, do):
    """Map (residuals, output cotangent) to input gradients.

    Processes queries in chunks of ``Q_CHUNK_SIZE``; dk/dv are summed across
    chunks and dq chunks are stacked. ``key_mask`` gets no gradient (None).
    """
    q, k, v, key_mask, o, l = res
    q_len, dim = q.shape
    step = min(Q_CHUNK_SIZE, q_len)
    row_sums = l.reshape(q_len, 1)

    def process_query_chunk(carry, _):
        start, dk_acc, dv_acc = carry

        def rows(t):
            return lax.dynamic_slice(t, (start, 0), slice_sizes=(step, t.shape[-1]))

        dq_part, dk_part, dv_part = _query_chunk_flash_attention_backward(
            rows(q), k, v, key_mask, rows(o), rows(do), rows(row_sums)
        )
        return (start + step, dk_acc + dk_part, dv_acc + dv_part), dq_part

    init = (0, jnp.zeros_like(k), jnp.zeros_like(v))
    num_chunks = math.ceil(q_len / Q_CHUNK_SIZE)
    (_, dk, dv), dq = lax.scan(process_query_chunk, init=init, xs=None, length=num_chunks)
    return dq.reshape(q_len, dim), dk, dv, None
# Register the custom VJP rules for cosine_sim_flash_attention_after_l2norm:
# flash_attention_forward returns (output, residuals) and flash_attention_backward
# maps (residuals, output cotangent) to (dq, dk, dv, None); None means key_mask
# receives no gradient.
cosine_sim_flash_attention_after_l2norm.defvjp(flash_attention_forward, flash_attention_backward)
.\lucidrains\flash-attention-jax\flash_attention_jax\flash_attention.py
# 导入数学库和 JAX 库
import math
import jax
# 导入 partial 函数
from functools import partial
# 从 JAX 库中导入 nn、custom_vjp、numpy、lax、jit 模块
from jax import nn
from jax import custom_vjp
from jax import numpy as jnp, lax, jit
# 从 JAX 的 numpy 模块中导入 einsum 函数
from jax.numpy import einsum
# 从 einops 库中导入 rearrange 函数
from einops import rearrange
# Constants
# Added to each key chunk's softmax row sum in _query_chunk_flash_attention.
EPSILON = 1e-10
# Logit assigned to masked-out keys before the softmax.
MASK_VALUE = -1e10
# Number of queries per outer scan step.
Q_CHUNK_SIZE = 1024
# Number of keys per inner scan step.
K_CHUNK_SIZE = 1024
# Flash attention
# Forward pass for one chunk of queries against all keys.
def _query_chunk_flash_attention(chunk_idx, q, k, v, key_mask):
    """Online-softmax attention of one query chunk over all key chunks.

    Args:
        chunk_idx: unused (the inner scan tracks its own key offset); kept for
            interface compatibility.
        q: query chunk, (q_chunk, batch, heads, dim).
        k: keys, (k_len, batch, heads, dim).
        v: values, (k_len, batch, heads, v_dim).
        key_mask: boolean key mask, (k_len, batch); False keys are ignored.

    Returns:
        (out, lse): output of shape (q_chunk, batch, heads, v_dim) and the
        per-query log-sum-exp of the scaled, masked logits,
        shape (q_chunk, batch, heads).
    """
    q_len, batch, heads, dim, k_len, v_dim = *q.shape, k.shape[0], v.shape[-1]
    scale = 1 / jnp.sqrt(dim)
    q_scaled = q * scale
    k_chunk_sizes = min(K_CHUNK_SIZE, k_len)

    def chunk_scanner(carries, _):
        key_chunk_idx, out, row_sum, row_max = carries
        k_chunk = lax.dynamic_slice(k, (key_chunk_idx, 0, 0, 0), slice_sizes=(k_chunk_sizes, batch, heads, dim))
        v_chunk = lax.dynamic_slice(v, (key_chunk_idx, 0, 0, 0), slice_sizes=(k_chunk_sizes, batch, heads, v_dim))
        key_mask_chunk = lax.dynamic_slice(key_mask, (key_chunk_idx, 0), slice_sizes=(k_chunk_sizes, batch))

        attn_weights = einsum('i ... d, j ... d -> i ... j', q_scaled, k_chunk)
        key_mask_chunk = rearrange(key_mask_chunk, 'j b -> 1 b 1 j')
        attn_weights = jnp.where(key_mask_chunk, attn_weights, MASK_VALUE)

        # Online softmax: rescale the running sum and output to the new row max.
        block_row_max = jnp.max(attn_weights, axis=-1, keepdims=True)
        new_row_max = jnp.maximum(block_row_max, row_max)
        exp_weights = jnp.exp(attn_weights - new_row_max)
        exp_weights = jnp.where(key_mask_chunk, exp_weights, 0.)
        block_row_sum = jnp.sum(exp_weights, axis=-1, keepdims=True) + EPSILON

        exp_values = einsum('i ... j, j ... d -> i ... d', exp_weights, v_chunk)
        exp_row_max_diff = jnp.exp(row_max - new_row_max)
        new_row_sum = exp_row_max_diff * row_sum + block_row_sum

        out = (row_sum / new_row_sum) * exp_row_max_diff * out + \
              (1. / new_row_sum) * exp_values
        return (key_chunk_idx + k_chunk_sizes, out, new_row_sum, new_row_max), None

    # out accumulates weighted values, so its last axis is v_dim (not dim).
    out = jnp.zeros((q_len, batch, heads, v_dim))
    row_sum = jnp.zeros((q_len, batch, heads, 1))
    row_max = jnp.ones((q_len, batch, heads, 1)) * -1e6

    (_, out, row_sum, row_max), _ = lax.scan(chunk_scanner, init=(0, out, row_sum, row_max), xs=None, length=math.ceil(k_len / K_CHUNK_SIZE))

    row_sum = rearrange(row_sum, 'n ... 1 -> n ...')
    row_max = rearrange(row_max, 'n ... 1 -> n ...')
    lse = jnp.log(row_sum) + row_max
    return out, lse
# Flash attention forward over full (batch, heads, seq, dim) tensors.
def _flash_attention(q, k, v, key_mask):
    """Chunk the queries and run ``_query_chunk_flash_attention`` on each chunk.

    Inputs are rearranged to sequence-first layout for slicing; ``key_mask``
    is (batch, k_len).

    Returns:
        out in (batch, heads, q_len, v_dim) layout and lse in
        (batch, heads, q_len) layout.
    """
    batch, heads, q_len, dim = q.shape

    def seq_first(t):
        return rearrange(t, 'b h n d -> n b h d')

    q_seq, k_seq, v_seq = seq_first(q), seq_first(k), seq_first(v)
    mask_seq = rearrange(key_mask, 'b j -> j b')
    step = min(Q_CHUNK_SIZE, q_len)

    def process_query_chunk(start, _):
        q_block = lax.dynamic_slice(q_seq, (start, 0, 0, 0), slice_sizes=(step, batch, heads, dim))
        return start + step, _query_chunk_flash_attention(start, q_block, k_seq, v_seq, mask_seq)

    num_chunks = math.ceil(q_len / Q_CHUNK_SIZE)
    _, (out_chunks, lse_chunks) = lax.scan(process_query_chunk, init=0, xs=None, length=num_chunks)
    return (
        rearrange(out_chunks, 'c n b h d -> b h (c n) d'),
        rearrange(lse_chunks, 'c n b h -> b h (c n)'),
    )
# Public flash attention entry point with a custom VJP.
@custom_vjp
@jit
def flash_attention(q, k, v, key_mask):
    """Flash attention over (batch, heads, seq, dim) tensors; returns only the output."""
    return _flash_attention(q, k, v, key_mask)[0]
# Forward rule for flash_attention's custom VJP.
@jit
def flash_attention_forward(q, k, v, key_mask):
    """Return the output together with the residuals needed by the backward rule."""
    out, lse = _flash_attention(q, k, v, key_mask)
    residuals = (q, k, v, key_mask, out, lse)
    return out, residuals
# Backward pass for one chunk of queries against all keys.
def _query_chunk_flash_attention_backward(q, k, v, key_mask, o, do, lse):
    """Gradients of flash attention for one query chunk, scanning key chunks.

    Args:
        q: query chunk, (q_chunk, batch, heads, dim).
        k: keys, (k_len, batch, heads, dim).
        v: values, (k_len, batch, heads, v_dim).
        key_mask: boolean key mask, (k_len, batch).
        o: forward output rows for this chunk.
        do: cotangent of ``o``.
        lse: forward log-sum-exp for this chunk, (q_chunk, batch, heads, 1).

    Returns:
        dq shaped like ``q``, and dk / dv with k_len rows.
    """
    q_len, batch, heads, dim, k_len, v_dim = *q.shape, v.shape[0], v.shape[-1]
    scale = 1 / jnp.sqrt(dim)
    q_scaled = q * scale
    k_chunk_sizes = min(K_CHUNK_SIZE, k_len)

    # D_i = sum_d dO_id * O_id is independent of the key chunk; compute it once.
    D = jnp.sum(do * o, axis = -1, keepdims = True)

    def chunk_scanner(carries, _):
        chunk_idx, dq = carries
        # Only the sequence axis is chunked; batch and heads are taken whole,
        # so their start index is 0.
        k_chunk = lax.dynamic_slice(k, (chunk_idx, 0, 0, 0), slice_sizes=(k_chunk_sizes, batch, heads, dim))
        v_chunk = lax.dynamic_slice(v, (chunk_idx, 0, 0, 0), slice_sizes=(k_chunk_sizes, batch, heads, v_dim))
        key_mask_chunk = lax.dynamic_slice(key_mask, (chunk_idx, 0), slice_sizes=(k_chunk_sizes, batch))

        # Recompute softmax probabilities from the saved log-sum-exp.
        attn_weights = einsum('i ... d, j ... d -> i ... j', q_scaled, k_chunk)
        p = jnp.exp(attn_weights - lse)
        key_mask_chunk = rearrange(key_mask_chunk, 'j b -> 1 b 1 j')
        p = jnp.where(key_mask_chunk, p, 0.)

        dv_chunk = einsum('i ... j, i ... d -> j ... d', p, do)
        dp = einsum('i ... d, j ... d -> i ... j', do, v_chunk)
        ds = p * scale * (dp - D)

        dq_chunk = einsum('i ... j, j ... d -> i ... d', ds, k_chunk)
        dk_chunk = einsum('i ... j, i ... d -> j ... d', ds, q)
        return (chunk_idx + k_chunk_sizes, dq + dq_chunk), (dk_chunk, dv_chunk)

    dq = jnp.zeros_like(q)
    (_, dq), (dk, dv) = lax.scan(chunk_scanner, init = (0, dq), xs = None, length = math.ceil(k_len / K_CHUNK_SIZE))

    # Merge the stacked per-chunk results back into the key sequence axis.
    dk = rearrange(dk, 'c n ... -> (c n) ...')
    dv = rearrange(dv, 'c n ... -> (c n) ...')
    return dq, dk, dv
# JIT-compiled backward rule for flash_attention's custom VJP.
@jit
def flash_attention_backward(res, do):
    """Map (residuals, output cotangent) to (dq, dk, dv, None).

    Queries are processed in chunks of ``Q_CHUNK_SIZE``; dk/dv are summed
    across chunks and dq chunks are stacked. ``key_mask`` gets no gradient.
    """
    q, k, v, key_mask, o, lse = res
    batch, heads, q_len, dim = q.shape

    # Sequence-first layout so chunks can be sliced along axis 0.
    lse = rearrange(lse, 'b h n -> n b h 1')
    q, k, v, o, do = map(lambda t: rearrange(t, 'b h n d -> n b h d'), (q, k, v, o, do))
    key_mask = rearrange(key_mask, 'b j -> j b')

    dk = jnp.zeros_like(k)
    dv = jnp.zeros_like(v)
    chunk_sizes = min(Q_CHUNK_SIZE, q_len)

    def chunk_scanner(carries, _):
        chunk_idx, dk, dv = carries
        # Slice only the sequence axis; batch and heads are taken whole, so they
        # start at index 0. (Non-zero starts only worked via JAX's clamping.)
        q_chunk = lax.dynamic_slice(q, (chunk_idx, 0, 0, 0), slice_sizes = (chunk_sizes, batch, heads, q.shape[-1]))
        lse_chunk = lax.dynamic_slice(lse, (chunk_idx, 0, 0, 0), slice_sizes = (chunk_sizes, batch, heads, 1))
        o_chunk = lax.dynamic_slice(o, (chunk_idx, 0, 0, 0), slice_sizes = (chunk_sizes, batch, heads, o.shape[-1]))
        do_chunk = lax.dynamic_slice(do, (chunk_idx, 0, 0, 0), slice_sizes = (chunk_sizes, batch, heads, do.shape[-1]))

        dq_chunk, dk_chunk, dv_chunk = _query_chunk_flash_attention_backward(q_chunk, k, v, key_mask, o_chunk, do_chunk, lse_chunk)
        return (chunk_idx + chunk_sizes, dk + dk_chunk, dv + dv_chunk), dq_chunk

    (_, dk, dv), dq = lax.scan(chunk_scanner, init = (0, dk, dv), xs = None, length = math.ceil(q_len / Q_CHUNK_SIZE))

    dq = rearrange(dq, 'c n b h d -> b h (c n) d')
    dk, dv = map(lambda t: rearrange(t, 'n b h d -> b h n d'), (dk, dv))
    return dq, dk, dv, None
# Register the custom VJP rules for flash_attention:
# flash_attention_forward returns (out, (q, k, v, key_mask, out, lse)) and
# flash_attention_backward maps (residuals, output cotangent) to
# (dq, dk, dv, None); None means key_mask receives no gradient.
flash_attention.defvjp(flash_attention_forward, flash_attention_backward)
.\lucidrains\flash-attention-jax\flash_attention_jax\rabe_attention.py
# 导入数学库和部分函数
import math
from functools import partial
# 导入 JAX 库
import jax
from jax import lax, numpy as jnp, jit
# Constants
# Highest available matmul precision setting for JAX contractions.
HIGHEST_PRECISION = jax.lax.Precision.HIGHEST
# jnp.einsum with precision pinned to HIGHEST_PRECISION; used by the
# attention code below.
einsum = partial(jnp.einsum, precision = HIGHEST_PRECISION)
# Memory-efficient attention of one query chunk over all keys (Rabe & Staats).
def _query_chunk_attention(q, k, v, k_chunk_size = 4096):
    """Attend a query chunk to all keys, processing keys in chunks.

    Each key chunk is summarised independently: its weighted values, its
    weight sum and its max logit. The summaries are then rescaled to a common
    global max and combined, giving an exact softmax.

    If ``k_len`` is not a multiple of ``k_chunk_size``, the final chunk is
    shifted back so it fits inside ``k``. ``lax.dynamic_slice`` would clamp
    it the same way. Keys it shares with the previous chunk are masked out,
    so every key contributes exactly once.

    Args:
        q: query chunk, (q_chunk, dim).
        k: keys, (k_len, dim).
        v: values, (k_len, v_dim).
        k_chunk_size: maximum number of keys per chunk.

    Returns:
        Attention output of shape (q_chunk, v_dim).
    """
    q_len, k_len, dim, v_dim = q.shape[-2], *k.shape, v.shape[-1]
    k_chunk_size = min(k_chunk_size, k_len)
    q = q / jnp.sqrt(dim)

    # Rematerialise each chunk's activations in the backward pass instead of storing them.
    @partial(jax.checkpoint, prevent_cse = False)
    def summarize_chunk(q, k, v, key_valid):
        attn_weights = einsum('qd, kd -> qk', q, k)
        # Duplicate keys get -inf, i.e. exactly 0 weight after exp; every chunk
        # keeps at least one valid key, so the row max below stays finite.
        attn_weights = jnp.where(key_valid, attn_weights, -jnp.inf)
        max_score = jnp.max(attn_weights, axis = -1, keepdims = True)
        max_score = jax.lax.stop_gradient(max_score)
        exp_weights = jnp.exp(attn_weights - max_score)
        exp_values = einsum('vf, qv -> qf', v, exp_weights)
        return (exp_values, exp_weights.sum(axis = -1), max_score.reshape((q_len,)))

    def chunk_scanner(chunk_idx):
        # Keep the slice in bounds (only the last chunk can move) ...
        start = jnp.minimum(chunk_idx, k_len - k_chunk_size)
        k_chunk = lax.dynamic_slice(k, (start, 0), slice_sizes=(k_chunk_size, dim))
        v_chunk = lax.dynamic_slice(v, (start, 0), slice_sizes=(k_chunk_size, v_dim))
        # ... and ignore keys before chunk_idx, which the previous chunk already covered.
        key_valid = (start + jnp.arange(k_chunk_size)) >= chunk_idx
        return summarize_chunk(q, k_chunk, v_chunk, key_valid)

    chunk_values, chunk_weights, chunk_max = jax.lax.map(chunk_scanner, xs = jnp.arange(0, k_len, k_chunk_size))

    # Rescale every chunk's summary to the global max before combining.
    global_max = jnp.max(chunk_max, axis = 0, keepdims = True)
    max_diffs = jnp.exp(chunk_max - global_max)
    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(axis = 0)
    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis = 0)
    return all_values / all_weights
# Chunked memory-efficient attention (Rabe & Staats), JIT-compiled.
@jit
def rabe_attention(q, k, v, q_chunk_size = 1024, k_chunk_size = 4096):
    """Memory-efficient attention over 2-D q (q_len, dim), k and v.

    Queries are processed in chunks of ``q_chunk_size`` via
    ``_query_chunk_attention``, which in turn walks the keys in chunks of
    ``k_chunk_size``.

    ``q_len`` need not be a multiple of ``q_chunk_size``. The final query
    chunk is clamped by ``lax.dynamic_slice`` so it ends at ``q_len``, and only
    its rows not already produced by earlier chunks are kept. Previously the
    final reshape failed in that case.

    Returns:
        Attention output of shape (q_len, v_dim).
    """
    q_len, dim, v_dim = *q.shape, v.shape[-1]
    chunk_len = min(q_chunk_size, q_len)
    num_chunks = math.ceil(q_len / q_chunk_size)

    def chunk_scanner(chunk_idx, _):
        q_chunk = lax.dynamic_slice(q, (chunk_idx, 0), slice_sizes = (chunk_len, dim))
        return (chunk_idx + q_chunk_size, _query_chunk_attention(q_chunk, k, v, k_chunk_size = k_chunk_size))

    _, res = jax.lax.scan(chunk_scanner, init = 0, xs = None, length = num_chunks)

    # res: (num_chunks, chunk_len, v_dim). All but the last chunk are disjoint and
    # in order; the last one was clamped to end at q_len, so keep only its rows
    # that follow the earlier chunks. In the divisible case this is the whole chunk.
    head = res[:-1].reshape(-1, v_dim)
    tail = res[-1, num_chunks * chunk_len - q_len:]
    return jnp.concatenate([head, tail], axis = 0)
.\lucidrains\flash-attention-jax\flash_attention_jax\utils.py
# 导入 JAX 库
import jax
# 导入 partial 函数
from functools import partial
# 导入 JAX 中的 numpy 模块
import jax.numpy as jnp
# 从 JAX 中导入 random 模块
from jax import random
# 从 JAX 中导入 value_and_grad 函数
# Wrap a function so it returns the sum of its output together with the gradient.
def value_and_grad_wrapper(fn, **kwargs):
    """Return a function computing ``(sum(fn(*args)), grads)``.

    ``fn``'s output is reduced with ``jnp.sum`` so that ``jax.value_and_grad``,
    which requires a scalar output, can be applied to array-valued functions.

    Args:
        fn: function to differentiate.
        **kwargs: forwarded to ``jax.value_and_grad`` (e.g. ``argnums``).
    """
    # Imported locally: the module-level imports do not bring value_and_grad
    # into scope, so the decorator below raised NameError.
    from jax import value_and_grad

    @partial(value_and_grad, **kwargs)
    def inner(*args, **fn_kwargs):
        return jnp.sum(fn(*args, **fn_kwargs))

    return inner
# Largest element-wise deviation between two arrays.
def diff(t1, t2):
    """Return max(|t1 - t2|) over all elements."""
    gap = jnp.abs(t1 - t2)
    return gap.max()
# Infinite stream of fresh PRNG subkeys.
def PRNGKeyGenerator(seed = 42):
    """Yield an endless sequence of subkeys derived from ``seed``.

    Each step splits the current key in two, yields the first half and keeps
    the second half as the next key to split.
    """
    current = random.PRNGKey(seed)
    while True:
        halves = random.split(current)
        current = halves[1]
        yield halves[0]
# Compare outputs and gradients of two attention implementations.
def value_and_grad_difference(
    fn1,
    fn2,
    seed = 42,
    batch = 2,
    heads = 4,
    q_seq_len = 4096,
    k_seq_len = 8192,
    add_key_mask = True,
    dim = 512
):
    """Run ``fn1`` and ``fn2`` on identical random inputs and measure the gap.

    q is drawn with shape (batch, heads, q_seq_len, dim), and k and v with
    shape (batch, heads, k_seq_len, dim), all standard normal. A random
    boolean key mask of shape (batch, k_seq_len) is also drawn and passed as a
    fourth argument only when ``add_key_mask`` is true. Each function's summed
    output is differentiated with respect to q, k and v.

    Returns:
        (max |out1 - out2|, [max |dq1 - dq2|, max |dk1 - dk2|, max |dv1 - dv2|])
    """
    keys = PRNGKeyGenerator(seed)
    q_shape = (batch, heads, q_seq_len, dim)
    kv_shape = (batch, heads, k_seq_len, dim)

    q = random.normal(next(keys), q_shape)
    k = random.normal(next(keys), kv_shape)
    v = random.normal(next(keys), kv_shape)
    key_mask = random.randint(next(keys), (batch, k_seq_len), 0, 2) == 1

    inputs = (q, k, v, key_mask) if add_key_mask else (q, k, v)

    wrap = partial(value_and_grad_wrapper, argnums = (0, 1, 2))
    first, second = wrap(fn1), wrap(fn2)
    out1, grads1 = first(*inputs)
    out2, grads2 = second(*inputs)

    return diff(out1, out2), [diff(g1, g2) for g1, g2 in zip(grads1, grads2)]
.\lucidrains\flash-attention-jax\flash_attention_jax\__init__.py
# 从 flash_attention_jax 模块中导入 flash_attention 函数
from flash_attention_jax.flash_attention import flash_attention
# 从 flash_attention_jax 模块中导入 cosine_sim_flash_attention 函数
from flash_attention_jax.cosine_sim_flash_attention import cosine_sim_flash_attention
# 从 flash_attention_jax 模块中导入 causal_flash_attention 函数
from flash_attention_jax.causal_flash_attention import causal_flash_attention
# 从 flash_attention_jax 模块中导入 rabe_attention 函数
from flash_attention_jax.rabe_attention import rabe_attention
# 从 flash_attention_jax 模块中导入 attention, causal_attention, cosine_sim_attention 函数
from flash_attention_jax.attention import attention, causal_attention, cosine_sim_attention
# 从 flash_attention_jax.utils 模块中导入 value_and_grad_difference, PRNGKeyGenerator 函数
from flash_attention_jax.utils import value_and_grad_difference, PRNGKeyGenerator
# Alias: expose the plain (non-flash) `attention` function as `plain_attention`,
# the name used in the README's sanity check against flash_attention.
plain_attention = attention

Flash Attention - Jax
Implementation of Flash Attention in Jax. It will likely not be as performant as the official CUDA version, given the lack of fine-grained memory management. It is mainly for educational purposes, as well as to see how clever the XLA compiler is (or is not).
Install
$ pip install flash-attention-jax
Usage
from jax import random
from flash_attention_jax import flash_attention
rng_key = random.PRNGKey(42)
q = random.normal(rng_key, (1, 2, 131072, 512)) # (batch, heads, seq, dim)
k = random.normal(rng_key, (1, 2, 131072, 512))
v = random.normal(rng_key, (1, 2, 131072, 512))
mask = random.randint(rng_key, (1, 131072,), 0, 2) # (batch, seq)
out = flash_attention(q, k, v, mask)
out.shape # (1, 2, 131072, 512) - (batch, heads, seq, dim)
Quick sanity check
from flash_attention_jax import plain_attention, flash_attention, value_and_grad_difference
diff, (dq_diff, dk_diff, dv_diff) = value_and_grad_difference(
plain_attention,
flash_attention,
seed = 42
)
print('shows differences between normal and flash attention for output, dq, dk, dv')
print(f'o: {diff}') # < 1e-4
print(f'dq: {dq_diff}') # < 1e-6
print(f'dk: {dk_diff}') # < 1e-6
print(f'dv: {dv_diff}') # < 1e-6
Autoregressive Flash Attention - GPT-like decoder attention
from jax import random
from flash_attention_jax import causal_flash_attention
rng_key = random.PRNGKey(42)
q = random.normal(rng_key, (131072, 512))
k = random.normal(rng_key, (131072, 512))
v = random.normal(rng_key, (131072, 512))
out, _ = causal_flash_attention(q, k, v)
out.shape # (131072, 512)
Todo
- leading dimensions for causal flash attention variant
- figure out issue with jit and static argnums
- comment with references to paper algorithms and explanations
- make sure it can work with single-headed keys / values, as in PaLM
Citations
@article{Dao2022FlashAttentionFA,
title = {FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
author = {Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R{\'e}},
journal = {ArXiv},
year = {2022},
volume = {abs/2205.14135}
}
@article{Rabe2021SelfattentionDN,
title = {Self-attention Does Not Need O(n2) Memory},
author = {Markus N. Rabe and Charles Staats},
journal = {ArXiv},
year = {2021},
volume = {abs/2112.05682}
}
.\lucidrains\flash-attention-jax\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# Distribution metadata for flash-attention-jax.
setup(
    # identity
    name = 'flash-attention-jax',
    version = '0.3.1',
    description = 'Flash Attention - in Jax',
    long_description_content_type = 'text/markdown',
    url = 'https://github.com/lucidrains/flash-attention-jax',
    license='MIT',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    # contents and runtime requirements
    packages = find_packages(exclude=[]),
    install_requires=[
        'einops',
        'jax>=0.2.20'
    ],
    # discovery metadata
    keywords = [
        'artificial intelligence',
        'deep learning',
        'transformers',
        'attention mechanism',
        'jax'
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\flash-cosine-sim-attention\benchmark.py
# 导入必要的库
import argparse
from itertools import product
import torch
from torch import einsum
# The benchmark times CUDA kernels, so refuse to run without a GPU. An explicit
# raise (unlike `assert`) still fires when Python runs with -O.
if not torch.cuda.is_available():
    raise RuntimeError('cuda must be available to run benchmark')
# 导入自定义模块
from flash_cosine_sim_attention.benchmark import benchmark
from flash_cosine_sim_attention import flash_cosine_sim_attention, l2norm_tensors
# Helper functions

def exists(t):
    """Report whether ``t`` holds a value, i.e. is not None."""
    if t is None:
        return False
    return True
def cast_tuple(t):
    """Return ``t`` unchanged if it is a tuple, otherwise wrap it in a 1-tuple."""
    if isinstance(t, tuple):
        return t
    return (t,)
# Command-line options
parser = argparse.ArgumentParser()
parser.add_argument('--causal', default = False, action = 'store_true')
parser.add_argument('--mask-prob', type = float, default = 0.)
parser.add_argument('--only-forwards', default = False, action = 'store_true')
parser.add_argument('--only-backwards', default = False, action = 'store_true')
parser.add_argument('--num-times', default = 20, type = int)
args = parser.parse_args()

# Benchmark constants
BATCH_SIZES = 4
HEADS = 8
DIM = 64
CAUSAL = args.causal
SHOULD_MASK = args.mask_prob > 0.

# Validate option combinations with parser.error, which exits with a usage
# message, instead of `assert`, which is stripped when Python runs with -O.
if not (0. <= args.mask_prob < 1.):
    parser.error('--mask-prob must be in the range [0, 1)')
if args.only_forwards and args.only_backwards:
    parser.error('--only-forwards and --only-backwards are mutually exclusive')
if CAUSAL and SHOULD_MASK:
    parser.error('--causal cannot be combined with --mask-prob > 0')

TEST_SEQUENCE_LENGTHS = (128, 256, 512, 1024, 2048, 4096, 8192)
TEST_FORWARDS = not args.only_backwards
TEST_BACKWARDS = not args.only_forwards
# Simplified cosine-similarity attention used as the benchmark baseline.
def simplified_cosine_sim_attention(
    q,
    k,
    v,
    scale = 10,
    l2norm_qk = True,
    causal_mask = None,
    mask = None
):
    """Non-fused cosine-similarity attention.

    q and k are optionally l2-normalised, and their dot products are scaled by
    ``scale``. Masked positions are filled with the most negative finite value
    of the dtype, the result is softmaxed over keys, and the attention-weighted
    sum of ``v`` is returned.

    Args:
        mask: optional boolean key mask, indexed as ``mask[:, None, None, :]``;
            False entries are excluded.
        causal_mask: optional boolean mask broadcast against the similarity
            matrix; True entries are excluded.
    """
    if l2norm_qk:
        q, k = l2norm_tensors(q, k)

    logits = einsum('b h i d, b h j d -> b h i j', q, k) * scale
    fill_value = -torch.finfo(logits.dtype).max

    if exists(mask):
        logits = logits.masked_fill(~mask[:, None, None, :], fill_value)

    if exists(causal_mask):
        logits = logits.masked_fill(causal_mask, fill_value)

    weights = logits.softmax(dim = -1)
    return einsum('b h i j, b h j d -> b h i d', weights, v)
# Wrap the fused kernel and the baseline in the timing harness with identical settings.
fused_attention_fn, attention_fn = (
    benchmark(fn, forwards = TEST_FORWARDS, backwards = TEST_BACKWARDS, num_times = args.num_times)
    for fn in (flash_cosine_sim_attention, simplified_cosine_sim_attention)
)
# Every combination of the benchmark settings (each may be a scalar or a tuple).
params = {
    'batch size': BATCH_SIZES,
    'heads': HEADS,
    'feature dimension': DIM,
}
permutations = list(product(*(cast_tuple(value) for value in params.values())))
# Time the fused kernel against the baseline for every dtype, setting and length.
for name, dtype in (('float32', torch.float32), ('float16', torch.float16)):
    for batch, heads, dim in permutations:
        print('-' * 60)
        print(f'{name}\t\tbatch: {batch}\theads: {heads}\tdim {dim}')
        print('-' * 60)

        for seq in TEST_SEQUENCE_LENGTHS:
            # Random inputs on the GPU that track gradients.
            q = torch.randn(batch, heads, seq, dim, dtype=dtype).cuda().requires_grad_()
            k = torch.randn(batch, heads, seq, dim, dtype=dtype).cuda().requires_grad_()
            v = torch.randn(batch, heads, seq, dim, dtype=dtype).cuda().requires_grad_()

            fused_args = dict(causal=CAUSAL)
            baseline_args = dict()

            if CAUSAL:
                # Only allocate the (seq, seq) mask when it is actually used.
                causal_mask = torch.ones((seq, seq), dtype=torch.bool).cuda().triu(1)
                baseline_args = {**baseline_args, 'causal_mask': causal_mask}

            if SHOULD_MASK:
                mask = torch.zeros((batch, seq)).float().cuda().uniform_(0, 1) > args.mask_prob
                fused_args = {**fused_args, 'mask': mask}
                baseline_args = {**baseline_args, 'mask': mask}

            fused_time = fused_attention_fn(q, k, v, **fused_args)

            # The baseline materialises the full attention matrix and may run out
            # of GPU memory at long sequence lengths. CUDA OOM is raised as a
            # RuntimeError (subclass). A bare `except:` would also swallow Ctrl-C
            # and report unrelated failures as "oom".
            try:
                baseline_time = attention_fn(q, k, v, **baseline_args)
            except RuntimeError:
                torch.cuda.empty_cache()
                baseline_time = -1

            times_slower = (fused_time / baseline_time) if baseline_time != -1 else 0.
            baseline_time_str = 'oom' if baseline_time == -1 else f"{baseline_time:.2f}ms"

            print(f'seq_len: {seq}\tslower: {times_slower:.2f}x\tkernel: {fused_time:.2f}ms\tbaseline: {baseline_time_str}')