Lucidrains series source code walkthrough (part 36)
Data source
The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/
.\lucidrains\flash-cosine-sim-attention\flash_cosine_sim_attention\benchmark.py
# torch and CUDA-event utilities for on-stream kernel timing
import torch
# synchronize blocks until all queued CUDA work finishes; Event timestamps the stream
from torch.cuda import synchronize, Event
# decorator / partial-application helpers
from functools import wraps, partial

# factory producing CUDA events with timing enabled (needed for elapsed_time)
timer = partial(Event, enable_timing = True)
# decorator that turns a function into one returning its average CUDA runtime

def benchmark(
    fn,
    *,
    num_times = 10,       # number of timed repetitions
    warmup_iters = 10,    # untimed warmup runs
    forwards = True,      # include the forward pass in the timed window
    backwards = False     # include the backward pass in the timed window
):
    """Wrap `fn` so calling it returns the mean elapsed milliseconds.

    Timing uses CUDA events recorded on the current stream; at least one of
    `forwards` / `backwards` must be enabled.
    """
    assert forwards or backwards

    @wraps(fn)
    def inner(*args, **kwargs):
        # warmup outside the timed region, running backward too if requested
        for _ in range(warmup_iters):
            warm_out = fn(*args, **kwargs)
            if backwards:
                warm_out.sum().backward()

        total_ms = 0.

        for _ in range(num_times):
            start_event, end_event = timer(), timer()

            # forward-inclusive timing starts before the call
            if forwards:
                start_event.record()

            result = fn(*args, **kwargs)

            # forward-only timing ends right after the call
            if not backwards:
                end_event.record()

            # backward-only timing starts after the forward completes
            if not forwards:
                start_event.record()

            if backwards:
                result.sum().backward()
                end_event.record()

            synchronize()
            total_ms += start_event.elapsed_time(end_event)

        return total_ms / num_times

    return inner
.\lucidrains\flash-cosine-sim-attention\flash_cosine_sim_attention\dispatch.h
#pragma once

// Custom type/value dispatch machinery, inspired by similar macro toolkits
// macro utilities:

#include <ATen/Dispatch.h>

// strip one layer of parentheses: REMOVE_PAREN((a, b)) -> a, b
#define REMOVE_PAREN_IMPL(...) __VA_ARGS__
#define REMOVE_PAREN(args) REMOVE_PAREN_IMPL args

// recursive macro expansion — each EVALn triples the expansion depth
#define EVAL0(...) __VA_ARGS__
#define EVAL1(...) EVAL0(EVAL0(EVAL0(__VA_ARGS__)))
#define EVAL2(...) EVAL1(EVAL1(EVAL1(__VA_ARGS__)))
#define EVAL3(...) EVAL2(EVAL2(EVAL2(__VA_ARGS__)))
#define EVAL4(...) EVAL3(EVAL3(EVAL3(__VA_ARGS__)))
#define EVAL(...) EVAL4(EVAL4(EVAL4(__VA_ARGS__)))

// end-of-list sentinel handling for the MAP machinery
#define MAP_END(...)
#define MAP_OUT

// peek-based end detection: MAP_GET_END expands to the terminator chain
#define MAP_GET_END2() 0, MAP_END
#define MAP_GET_END1(...) MAP_GET_END2
#define MAP_GET_END(...) MAP_GET_END1
#define MAP_NEXT0(test, next, ...) next MAP_OUT
#define MAP_NEXT1(test, next) MAP_NEXT0(test, next, 0)
#define MAP_NEXT(test, next) MAP_NEXT1(MAP_GET_END test, next)

// apply f(TYPE_NAME, CASE_CODE, x) to every element of the variadic list
#define MAP0(f, TYPE_NAME, CASE_CODE, x, peek, ...) f(TYPE_NAME, CASE_CODE, x) MAP_NEXT(peek, MAP1)(f, TYPE_NAME, CASE_CODE, peek, __VA_ARGS__)
#define MAP1(f, TYPE_NAME, CASE_CODE, x, peek, ...) f(TYPE_NAME, CASE_CODE, x) MAP_NEXT(peek, MAP0)(f, TYPE_NAME, CASE_CODE, peek, __VA_ARGS__)
#define MAP(f, TYPE_NAME, CASE_CODE, ...) EVAL(MAP1(f, TYPE_NAME, CASE_CODE, __VA_ARGS__, ()()(), ()()(), ()()(), 0))

// type dispatch: one switch case per scalar type, binding TYPE_NAME to the
// corresponding C++ type so CASE_CODE can use it as a template parameter
#define AT_TYPE_DISPATCH_CASE(TYPE_NAME, CASE_CODE, x) \
  case x: { \
    using TYPE_NAME C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
      typename c10::impl::ScalarTypeToCPPType<x>::type; \
    REMOVE_PAREN(CASE_CODE) \
    break; \
  }

#define AT_TYPE_DISPATCH_SWITCH(TYPE, TYPE_NAME, TYPES, CASE_CODE, DEFAULT_CODE) \
  { \
    switch (TYPE) { \
      MAP(AT_TYPE_DISPATCH_CASE, TYPE_NAME, CASE_CODE, REMOVE_PAREN(TYPES)) \
      default: { \
        REMOVE_PAREN(DEFAULT_CODE) \
      } \
    } \
  }

// value dispatch: one case per compile-time constant, bound as constexpr
// VALUE_NAME so CASE_CODE can use it as a non-type template argument
#define VALUE_DISPATCH_CASE(VALUE_NAME, CASE_CODE, x) \
  case x: { \
    constexpr const auto VALUE_NAME = x; \
    REMOVE_PAREN(CASE_CODE) \
    break; \
  }

#define VALUE_DISPATCH_SWITCH(VALUE, VALUE_NAME, VALUES, CASE_CODE, DEFAULT_CODE) \
  { \
    switch (VALUE) { \
      MAP(VALUE_DISPATCH_CASE, VALUE_NAME, CASE_CODE, REMOVE_PAREN(VALUES)) \
      default: { \
        REMOVE_PAREN(DEFAULT_CODE) \
      } \
    } \
  }
.\lucidrains\flash-cosine-sim-attention\flash_cosine_sim_attention\flash_cosine_sim_attention.py
import os
import math
import importlib
from functools import partial, wraps

import torch
from torch import einsum
import torch.nn.functional as F

from torch.autograd import Function

# load version info (defines __version__ and __cuda_pkg_name__)
exec(open(os.path.dirname(os.path.abspath(__file__)) + '/version.py').read())

# try to import the compiled CUDA extension (package name is version-suffixed)
try:
    cuda_pkg = importlib.import_module(__cuda_pkg_name__)

    # bind the kernel entrypoints from the CUDA package
    forward = cuda_pkg.forward
    backward = cuda_pkg.backward
    debug = cuda_pkg.debug

except ImportError:
    # NOTE(review): only a warning is printed here — `forward`/`backward`
    # remain undefined and any later CUDA-path call will raise NameError
    print('CUDA extension for flash-cosine-sim-attention was not compiled correctly - please run `pip install flash-cosine-sim-attention --force-reinstall --no-cache-dir`')
# helper functions

def exists(val):
    """True unless `val` is None."""
    return not (val is None)
# return a fallback when the value is missing

def default(val, d):
    """Return `val` when present, else the fallback `d`.

    A callable fallback is invoked lazily, so expensive defaults are only
    computed when actually needed.
    """
    if val is None:
        return d() if callable(d) else d
    return val
# check for exact divisibility

def divisible_by(numer, denom):
    """True when `numer` divides evenly by `denom`."""
    return numer % denom == 0
# l2 normalization on cpu

def l2norm_cpu(t):
    """L2-normalize `t` along its last dimension without CUDA.

    Norms at or below a dtype-dependent epsilon are clamped to the epsilon,
    so zero vectors map to zero instead of NaN.
    """
    eps = 1e-12 if t.dtype == torch.float32 else 1e-3
    norms = t.norm(dim = -1)
    safe_norms = torch.where(norms > eps, norms, eps)
    return t / safe_norms[..., None]
# l2 normalization, dispatched on device

def l2norm(t):
    """L2-normalize along the last dim.

    CUDA tensors use the fused F.normalize; CPU tensors fall back to
    l2norm_cpu, which clamps tiny norms explicitly.
    """
    if not t.data.is_cuda:
        return l2norm_cpu(t)
    return F.normalize(t, dim = -1)
# grouped l2 normalization

def grouped_l2norm(t, groups = 1):
    """L2-normalize the last dim in `groups` equal-sized chunks.

    The feature dimension is split into `groups` segments, each normalized
    independently, then the original shape is restored.
    """
    original_shape = t.shape
    feature_dim = original_shape[-1]
    grouped = t.reshape(*original_shape[:-1], groups, feature_dim // groups)
    return l2norm(grouped).reshape(original_shape)
# grouped l2 normalization applied to several tensors

def l2norm_tensors(*tensors, groups = 1):
    """Grouped-l2norm every tensor, casting results back to the input dtype."""
    assert len(tensors) > 0
    orig_dtype = tensors[0].dtype

    normalize = partial(grouped_l2norm, groups = groups)
    return tuple(normalize(t).type(orig_dtype) for t in tensors)
# plain (non-fused) cosine similarity attention

# dimension key:
# b - batch
# h - heads
# i - src sequence length
# j - target sequence length
# d - feature dimension

def plain_cosine_sim_attention(
    q,
    k,
    v,
    mask = None,
    attn_bias = None,
    scale = 8,
    groups = 1,
    causal = False,
    l2norm_qk = True,
    attn_bias_batch_dim = False
):
    """Reference cosine-sim attention in plain pytorch.

    q, k, v are (b, h, n, d), or (b, n, d) when batch and heads are merged
    (k/v may also be single-headed (b, n, d) against 4-dim q). When
    `l2norm_qk`, queries and keys are (grouped) l2-normalized, so logits are
    bounded and a fixed temperature `scale` replaces the usual 1/sqrt(d).
    """
    assert not (causal and exists(mask)), 'mask should not be supplied if causality is needed'

    # a 3-dim query means batch and heads were merged into one dimension
    is_merged_batch_heads_query = q.ndim == 3
    single_head_kv = k.ndim == 3

    if is_merged_batch_heads_query:
        assert k.ndim == 3 and v.ndim ==3, 'if batch and heads are merged for queries, keys and values must also similarly have only 3 dimensions'

        # treat bias as batch-indexed and restore a heads dimension of 1
        attn_bias_batch_dim = True
        q = q[:, None, ...]

    if l2norm_qk:
        q, k = l2norm_tensors(q, k, groups = groups)

    # single-headed keys/values broadcast across all query heads
    kv_einsum_eq = 'b j d' if single_head_kv else 'b h j d'

    sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k)
    sim = sim * scale

    if exists(attn_bias):
        # bias comes in missing either the batch or the heads dimension
        attn_bias = attn_bias.unsqueeze(1 if attn_bias_batch_dim else 0)
        sim = sim + attn_bias

    mask_value = -torch.finfo(sim.dtype).max

    if causal:
        i, j = sim.shape[-2:]
        # offset by (j - i + 1) so earlier (cached) key positions stay visible
        causal_mask = torch.ones((i, j), device = q.device, dtype = torch.bool).triu(j - i + 1)
        sim = sim.masked_fill(causal_mask, mask_value)

    if exists(mask):
        # key padding mask — assumes shape (b, j); TODO confirm against callers
        sim = sim.masked_fill(~mask[:, None, None, :], mask_value)

    attn = sim.softmax(dim = -1)
    out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)

    if is_merged_batch_heads_query:
        out = out.squeeze(1)

    return out
# tiled forward pass on CPU (inference only)

def flash_cosine_sim_attention_cpu(
    q, k, v,
    mask,
    attn_bias,
    scale,
    causal,
    attn_bias_batch_dim,
    row_tile_size = 512,
    col_tile_size = 512
):
    """Memory-efficient cosine-sim attention on CPU via row/column tiling.

    Mirrors the CUDA kernel's flash-attention style accumulation: for each
    row (query) tile, unnormalized outputs `o` and row sums `l` are
    accumulated over column (key) tiles, then normalized at the end.
    Gradients are not supported.
    """
    needs_backwards = any([exists(t) and t.requires_grad for t in (q, k, v, attn_bias)])
    assert not needs_backwards, 'cpu version does not support backwards'
    assert not (causal and exists(mask)), 'mask should not be supplied if causality is needed'

    # compute in float32, cast back to the input dtype at the end
    dtype = q.dtype
    q, k, v = q.float(), k.float(), v.float()

    is_merged_batch_heads_query = q.ndim == 3
    single_head_kv = k.ndim == 3

    shape = q.shape
    col_seq_len = k.shape[-2]
    row_seq_len = q.shape[-2]
    seq_len_diff = col_seq_len - row_seq_len

    # number of tiles along the row (query) axis
    row_tiles = math.ceil(row_seq_len / row_tile_size)
    # number of tiles along the column (key) axis
    col_tiles = math.ceil(col_seq_len / col_tile_size)

    # most negative finite value of q's dtype, used as the masked logit fill
    max_neg_value = -torch.finfo(q.dtype).max

    # merged batch-head queries require 3-dim keys and values as well
    if is_merged_batch_heads_query:
        assert k.ndim == 3 and v.ndim ==3, 'if batch and heads are merged for queries, keys and values must also similarly have only 3 dimensions'

        # treat bias as batch-indexed and restore a heads dimension of 1
        attn_bias_batch_dim = True
        q = q.unsqueeze(1)

    if exists(attn_bias):
        # bias comes in missing either the batch or the heads dimension
        attn_bias = attn_bias.unsqueeze(1 if attn_bias_batch_dim else 0)

    # single-headed keys/values broadcast across all query heads
    kv_einsum_eq = 'b j d' if single_head_kv else 'b h j d'

    # loop over row and column tiles

    # o accumulates unnormalized outputs, l the per-row softmax denominators
    o = torch.zeros_like(q)
    l = torch.zeros((*q.shape[:-1], 1))

    # prepare the key padding mask

    # without a mask, use one None per column tile
    if not exists(mask):
        mask = (None,) * col_tiles
    else:
        # broadcast to (b, 1, 1, j) then split along the key axis
        mask = mask[:, None, None, :]
        mask = mask.split(col_tile_size, dim = -1)

    # without a bias, use one None per row tile
    if not exists(attn_bias):
        attn_bias = (None,) * row_tiles
    else:
        # split the bias along the query (row) axis
        attn_bias = attn_bias.split(row_tile_size, dim = -2)

    # split q, o, l and the bias into row tiles (splits are views, so the
    # in-place adds below accumulate directly into o and l)
    row_splits = zip(
        q.split(row_tile_size, dim = -2),
        o.split(row_tile_size, dim = -2),
        l.split(row_tile_size, dim = -2),
        attn_bias
    )

    for ind, (qc, oc, lc, bc) in enumerate(row_splits):
        row_chunk_size = qc.shape[-2]
        # align row positions to key positions (handles the cross-attention
        # offset when key length exceeds query length)
        q_start_index = ind * row_tile_size + seq_len_diff

        # without a bias for this row tile, use one None per column tile
        if not exists(bc):
            bc = (None,) * col_tiles
        else:
            # split this row tile's bias along the key axis
            bc = bc.split(col_tile_size, dim = -1)

        # split k, v, mask and bias into column tiles
        col_splits = zip(
            k.split(col_tile_size, dim = -2),
            v.split(col_tile_size, dim = -2),
            mask,
            bc
        )

        for k_ind, (kc, vc, maskc, bias) in enumerate(col_splits):
            col_chunk_size = kc.shape[-2]
            k_start_index = k_ind * col_tile_size

            # causal: skip column tiles that lie entirely in the future.
            # NOTE(review): the bound uses col_tile_size rather than
            # col_chunk_size — confirm behavior for a ragged last tile
            if causal and q_start_index >= (k_start_index + col_tile_size - 1):
                continue

            # similarity logits for this tile pair
            attn_weights = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', qc, kc) * scale

            if exists(bias):
                attn_weights += bias

            if exists(maskc):
                attn_weights.masked_fill_(~maskc, max_neg_value)

            # partially-causal tile: apply the triangular mask within the tile
            if causal and q_start_index < (k_start_index + col_tile_size - 1):
                causal_mask = torch.ones((row_chunk_size, col_chunk_size), dtype = torch.bool).triu(q_start_index - k_start_index + 1)
                attn_weights.masked_fill_(causal_mask, max_neg_value)

            # cosine-sim logits are bounded above by `scale`, so subtracting
            # `scale` keeps exp() stable without tracking running row maxima
            exp_weights = torch.exp(attn_weights - scale)

            # zero out masked positions so they do not contribute to the sums
            if exists(maskc):
                exp_weights.masked_fill_(~maskc, 0.)

            exp_values = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', exp_weights, vc)

            # accumulate unnormalized output and softmax denominator in place
            oc.add_(exp_values)
            lc.add_(exp_weights.sum(dim = -1, keepdim = True))

    # normalize, then restore the original shape and dtype
    o.div_(l.clamp(min = 1e-12))
    return o.reshape(shape).type(dtype)
# main autograd class

class FlashCosineSimAttention(Function):
    """Autograd wrapper around the fused CUDA forward / backward kernels."""

    @staticmethod
    def forward(
        ctx,
        q, k, v,
        mask,
        attn_bias,
        scale,
        causal,
        attn_bias_batch_dim
    ):
        """Run the CUDA forward kernel, saving context only if grads are needed."""
        o, inv_l, should_backwards = forward(
            q, k, v,
            mask,
            attn_bias,
            attn_bias_batch_dim,
            scale,
            causal
        )

        # the kernel reports whether any input requires grad; skip all
        # bookkeeping when inference-only
        if not should_backwards:
            return o

        ctx.should_backwards = should_backwards
        ctx.save_for_backward(o, inv_l, q, k, v, mask, attn_bias)
        ctx.params = (
            scale,
            causal,
            attn_bias_batch_dim
        )

        return o

    @staticmethod
    def backward(ctx, do):
        """Run the CUDA backward kernel, routing gradients to q, k, v and attn_bias."""
        assert ctx.should_backwards

        o, inv_l, q, k, v, mask, attn_bias = ctx.saved_tensors

        (
            scale,
            causal,
            attn_bias_batch_dim
        ) = ctx.params

        dq, dk, dv, db = backward(
            do, o, inv_l,
            q, k, v,
            mask,
            attn_bias,
            attn_bias_batch_dim,
            scale,
            causal
        )

        # exactly one gradient slot per forward() input: q, k, v, mask,
        # attn_bias, scale, causal, attn_bias_batch_dim. The original
        # returned 15 values, which autograd rejects for an 8-input forward.
        return dq, dk, dv, None, db, None, None, None
# CUDA-backed entrypoint; Function.apply binds the custom autograd above
flash_cosine_sim_attention_cuda = FlashCosineSimAttention.apply
# wrapper function

def flash_cosine_sim_attention(
    q,
    k,
    v,
    mask = None,
    attn_bias = None,
    scale = 8,
    groups = 1,
    causal = False,
    l2norm_qk = True,
    attn_bias_batch_dim = False
):
    """Cosine-sim attention that dispatches to the CUDA kernel or CPU fallback.

    When `l2norm_qk` is set, queries and keys are (grouped) l2-normalized
    before attention. CUDA inputs route to the fused autograd kernel; CPU
    inputs use the tiled reference implementation (inference only).
    """
    if l2norm_qk:
        q, k = l2norm_tensors(q, k, groups = groups)

    attend = flash_cosine_sim_attention_cuda if q.data.is_cuda else flash_cosine_sim_attention_cpu

    return attend(
        q, k, v,
        mask,
        attn_bias,
        scale,
        causal,
        attn_bias_batch_dim
    )
.\lucidrains\flash-cosine-sim-attention\flash_cosine_sim_attention\transformer.py
import torch
from functools import partial
from torch import nn, einsum
import torch.nn.functional as F
# einops is an optional dependency, only needed for this transformer helper
try:
    from einops import rearrange
except ImportError:
    print('pip install einops to use transformer')

from flash_cosine_sim_attention.flash_cosine_sim_attention import plain_cosine_sim_attention, flash_cosine_sim_attention
# helper function

def exists(val):
    """Check whether a value is present (i.e. not None)."""
    return not (val is None)
# Xavier-normal weight initialization, in place

def init_weight_xavier_normal_(module, beta):
    """Initialize `module.weight` with xavier normal, using gain `beta`."""
    nn.init.xavier_normal_(module.weight.data, gain = beta)
# decorator that runs a method in eval mode, restoring train state afterwards

def eval_decorator(fn):
    """Call `fn` with the model switched to eval mode, then restore its flag."""
    def inner(model, *args, **kwargs):
        prev_mode = model.training
        model.eval()
        result = fn(model, *args, **kwargs)
        model.train(prev_mode)
        return result
    return inner
# plain (non-cosine-sim) causal attention baseline

def non_cosine_sim_attn_fn(q, k, v, **kwargs):
    """Standard causal dot-product attention; extra kwargs are ignored."""
    scaled_q = q * (q.shape[-1] ** -0.5)
    sim = einsum('b h i d, b h j d -> b h i j', scaled_q, k)

    i, j = sim.shape[-2:]
    causal_mask = torch.ones((i, j), dtype = torch.bool, device = q.device).triu(j - i + 1)
    sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

    attn = sim.softmax(dim = -1)
    return einsum('b h i j, b h j d -> b h i d', attn, v)
# top-k logit filtering

def top_k(logits, thres = 0.9):
    """Keep the top (1 - thres) fraction of logits; set the rest to -inf.

    NOTE: scatter_ uses dim 1, so `logits` is assumed to be 2D (batch, vocab).
    """
    num_keep = int((1 - thres) * logits.shape[-1])
    vals, idxs = torch.topk(logits, num_keep)

    filtered = torch.full_like(logits, float('-inf'))
    filtered.scatter_(1, idxs, vals)
    return filtered
# attention and feedforward

def FeedForward(dim, mult = 4, pre_norm = False):
    """Build a two-layer GELU feedforward; LayerNorm up front when pre_norm."""
    hidden_dim = int(dim * mult)

    modules = [
        nn.LayerNorm(dim) if pre_norm else nn.Identity(),
        nn.Linear(dim, hidden_dim, bias = False),
        nn.GELU(),
        nn.Linear(hidden_dim, dim, bias = False)
    ]
    return nn.Sequential(*modules)
# attention module

class Attention(nn.Module):
    """Multi-head attention block with a selectable attention implementation."""

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        scale = 8,
        l2norm_groups = 1,
        pre_norm = False,
        use_cuda_kernel = False,
        non_cosine_sim_attn = False,
        **kwargs
    ):
        super().__init__()
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim) if pre_norm else nn.Identity()

        self.scale = scale
        self.heads = heads
        self.l2norm_groups = l2norm_groups

        # choose the attention implementation: plain dot-product baseline,
        # fused CUDA cosine-sim kernel, or the plain cosine-sim reference
        if non_cosine_sim_attn:
            self.attn_fn = non_cosine_sim_attn_fn
        elif use_cuda_kernel:
            self.attn_fn = partial(flash_cosine_sim_attention, **kwargs)
        else:
            self.attn_fn = plain_cosine_sim_attention

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        heads, scale, groups = self.heads, self.scale, self.l2norm_groups

        x = self.norm(x)

        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = heads), (q, k, v))

        out = self.attn_fn(q, k, v, causal = True, scale = scale, groups = groups)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# transformer for testing

class CosineSimCausalTransformer(nn.Module):
    """Causal language-model transformer for testing cosine-sim attention.

    NOTE(review): the extracted source contained two fused, conflicting
    `__init__` headers (one keyword-only with defaults, one positional),
    which is a syntax error. They are merged here into the keyword-only
    form, whose defaults match the `Attention` block's.
    """

    def __init__(
        self,
        *,
        num_tokens,
        dim,
        max_seq_len,
        depth,
        attn_scale = 8,
        attn_l2norm_groups = 1,
        heads = 8,
        dim_head = 64,
        use_cuda_kernel = False,
        pre_norm = False,
        non_cosine_sim_attn = False,
        **kwargs
    ):
        super().__init__()
        self.max_seq_len = max_seq_len

        # token and absolute position embeddings
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # residual scaling for the post-norm (DeepNet-style) configuration
        self.residual_scale = 1 if pre_norm else ((2 * depth) ** 0.25)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, dim_head=dim_head, heads=heads, use_cuda_kernel=use_cuda_kernel, scale=attn_scale, groups=attn_l2norm_groups, pre_norm=pre_norm, non_cosine_sim_attn=non_cosine_sim_attn, **kwargs),
                nn.LayerNorm(dim) if not pre_norm else nn.Identity(),
                FeedForward(dim, pre_norm=pre_norm),
                nn.LayerNorm(dim) if not pre_norm else nn.Identity(),
            ]))

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim) if pre_norm else nn.Identity(),
            nn.Linear(dim, num_tokens, bias=False)
        )

        # the post-norm configuration needs the scaled initialization below
        if not pre_norm:
            self.init_(depth)

    def init_(self, depth):
        """Depth-scaled initialization for the post-norm configuration."""
        nn.init.normal_(self.token_emb.weight, std=1e-5)
        nn.init.normal_(self.pos_emb.weight, std=1e-5)

        init_gain = (8 * depth) ** -0.25

        for attn, _, ff, _ in self.layers:
            init_weight_xavier_normal_(attn.to_q, 1.)
            init_weight_xavier_normal_(attn.to_k, 1.)
            init_weight_xavier_normal_(attn.to_v, init_gain)
            init_weight_xavier_normal_(attn.to_out, init_gain)
            init_weight_xavier_normal_(ff[1], init_gain)
            init_weight_xavier_normal_(ff[3], init_gain)

        init_weight_xavier_normal_(self.to_logits[-1], 1)

    @torch.no_grad()
    @eval_decorator
    def generate(self, start_tokens, seq_len, temperature=1., filter_thres=0.9, **kwargs):
        """Autoregressively sample `seq_len` new tokens after `start_tokens`."""
        b, n, device = *start_tokens.shape, start_tokens.device

        out = start_tokens

        for _ in range(seq_len):
            # only the last max_seq_len tokens fit the positional embedding
            logits = self.forward(out[:, -self.max_seq_len:], **kwargs)[:, -1, :]
            filtered_logits = top_k(logits, thres=filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)
            out = torch.cat((out, sample), dim=-1)

        # return only the newly generated tokens
        return out[:, n:]

    def forward(self, x, return_loss=False):
        """Return logits, or the autoregressive cross-entropy loss."""
        if return_loss:
            # shift targets: predict token t+1 from tokens <= t
            x, labels = x[:, :-1], x[:, 1:]

        x = self.token_emb(x)
        x = x + self.pos_emb(torch.arange(x.shape[1], device=x.device))

        for attn, attn_norm, ff, ff_norm in self.layers:
            x = attn(x) + x * self.residual_scale
            x = attn_norm(x)
            x = ff(x) + x * self.residual_scale
            x = ff_norm(x)

        logits = self.to_logits(x)

        if not return_loss:
            return logits

        # cross_entropy expects (batch, classes, seq); the pattern's letter
        # names are swapped but it is the dim-1/dim-2 transpose that matters
        loss = F.cross_entropy(rearrange(logits, 'b c n -> b n c'), labels)
        return loss
.\lucidrains\flash-cosine-sim-attention\flash_cosine_sim_attention\version.py
# version of this module
__version__ = '0.1.40'
# CUDA extension package name derived from the version (acts as a cache-buster
# so a rebuilt kernel is never shadowed by a stale compiled package)
__cuda_pkg_name__ = f'flash_cosine_sim_attention_cuda_{__version__.replace(".", "_")}'
.\lucidrains\flash-cosine-sim-attention\flash_cosine_sim_attention\__init__.py
# re-export the public API from flash_cosine_sim_attention.flash_cosine_sim_attention: flash_cosine_sim_attention, plain_cosine_sim_attention, l2norm_tensors and debug
from flash_cosine_sim_attention.flash_cosine_sim_attention import flash_cosine_sim_attention, plain_cosine_sim_attention, l2norm_tensors, debug

Dive into Deep Learning, redone by Quanta Magazine
Flash Cosine Similarity Attention
Implementation of fused cosine similarity attention in the same style as Flash Attention. The observation is that by adopting l2 normalized queries and keys, you no longer need to keep track of the row maximums for numerical stability. This greatly simplifies the flash attention algorithm, assuming cosine similarity attention comes at no generalization cost.
In other words, stable, fast, memory efficient, and longer context attention with no downsides.
Update: Unfortunately, Robin's experiments showed much worse evaluation FID scores not reflected in the loss. Pending more experiments. Use this library with caution.
Update 2: The only saving grace would be to use grouped l2norm, which could potentially allow for more expressivity. If anyone can evaluate this technique on their generative work and obtain some FID scores, would be much appreciated.
Update 3: An approach similar to cosine sim attention has been proven at scale, with a 22B parameter vision model from Brain.
Status (wip)
At the moment, autoregressive and variable lengthed sequences should be faster across all architectures. For sequences longer than 2048, it will also be memory efficient where regular attention would not.
However, for non-autoregressive without masking, the architecture is still slower on A100 for F16. The aim is to get it to perform faster on A100 forwards and backwards for both F32 and F16, as shared memory is not fully exploited yet.
For older graphics cards without enough shared memory, one will have to gauge the tradeoff between memory efficiency and speed depending on the sequence length being trained at.
Appreciation
-
Arthur Hennequin for coaching me through my first CUDA kernel, and for coding up a simple reference implementation, which helped me to bootstrap the first kernel that comes within reasonable performance to baseline. This work would not have been possible without his expertise.
-
Boris Dayma and Robin Rombach for running experiments on the simplified cosine sim attention with fixed scaling on some significant text-to-image models and verifying that it indeed performs just as well as regular attention.
-
Markus Rabe for penning the paper that showed attention does not require O(n²) memory, and Tri Dao for putting it all together in a CUDA kernel implementation for regular attention, demonstrating superiority in speed using the tiled approach minimizing HBM accesses (and for figuring out
dO * O == dP * P for the backwards pass). Would not have been able to complete my pilgrimage looking for the ultimate attention formulation without their discoveries. -
Stability.ai for the generous sponsorship to work on cutting edge artificial intelligence research
Install
$ pip install flash-cosine-sim-attention
Usage
Self Attention
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch.randn(1, 8, 1024, 64).cuda()
k = torch.randn(1, 8, 1024, 64).cuda()
v = torch.randn(1, 8, 1024, 64).cuda()
out = flash_cosine_sim_attention(q, k, v) # (1, 8, 1024, 64)
Cross attention
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch.randn(1, 8, 1024, 64).cuda()
k = torch.randn(1, 8, 2048, 64).cuda()
v = torch.randn(1, 8, 2048, 64).cuda()
out = flash_cosine_sim_attention(q, k, v) # (1, 8, 1024, 64)
With key / value masking
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch.randn(1, 8, 1024, 64).cuda()
k = torch.randn(1, 8, 2048, 64).cuda()
v = torch.randn(1, 8, 2048, 64).cuda()
mask = torch.ones(1, 2048).bool().cuda()
out = flash_cosine_sim_attention(q, k, v, mask = mask) # (1, 8, 1024, 64)
Autoregressive
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch.randn(4, 8, 1024, 64).cuda()
k = torch.randn(4, 8, 1024, 64).cuda()
v = torch.randn(4, 8, 1024, 64).cuda()
out = flash_cosine_sim_attention(q, k, v, causal = True) # (4, 8, 1024, 64)
Miscellaneous
Single-headed key / values (Shazeer et al & used in PaLM)
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch.randn(4, 8, 1024, 64).cuda()
k = torch.randn(4, 1024, 64).cuda()
v = torch.randn(4, 1024, 64).cuda()
out = flash_cosine_sim_attention(q, k, v, causal = True) # (4, 8, 1024, 64)
If you need to do operations on the queries and keys in between the l2norm and the actual attention step, just set l2norm_qk = False
ex.
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention, l2norm_tensors
q = torch.randn(4, 8, 1024, 64).cuda()
k = torch.randn(4, 1024, 64).cuda()
v = torch.randn(4, 1024, 64).cuda()
q, k = l2norm_tensors(q, k)
# do your rotation of queries and keys
# say with https://github.com/lucidrains/rotary-embedding-torch
out = flash_cosine_sim_attention(q, k, v, l2norm_qk = False) # (4, 8, 1024, 64)
Cross attention with causal works as expected - (caching of keys and values in autoregressive during inference, or transformer-xl like training)
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch.randn(1, 8, 1024, 64).cuda()
k = torch.randn(1, 8, 2048, 64).cuda()
v = torch.randn(1, 8, 2048, 64).cuda()
out = flash_cosine_sim_attention(q, k, v, causal = True) # (1, 8, 1024, 64)
If you have batch and head dimensions merged, that is ok
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch.randn(32, 1024, 64).cuda()
k = torch.randn(32, 2048, 64).cuda()
v = torch.randn(32, 2048, 64).cuda()
out = flash_cosine_sim_attention(q, k, v, causal = True) # (32, 1024, 64)
Supported head dimensions
-
16 - f32
-
32
-
64
-
96
-
128
-
16 - f16
-
80 - in progress
Todo
-
bfloat16 support, use sfinae as recommended by Arthur
-
stream from qk_mma to shared memory in chunks to calculate out mma, see if freed smem can be used for caching more
-
support O(n) 1d dynamic positional bias
-
figure out why smem fragment caching would lead to performance degrade, it does not make sense
-
think about use of logsumexp - works but extra log lead to degraded perf
-
prepare a smem fragment caching mechanism, to allow for as much caching as allowed on A100 (or f16)
-
make attention tile size processing customizable for backwards pass
-
move atomic add to overloaded function inside mma
-
flexible which type is used for accumulation
-
test out 64x96 tiles on f16
-
bring in a CPU memory efficient version (only for inference, as training does not make sense) using just plain pytorch code
-
figure out how to dispatch differently for architectures (say A100), in case backwards can make use of the increase in shared memory differently
-
decouple row and column sizes for attention tiles
-
dk and dv are now in f16 when it can be (non single headed kv)
-
support more standard head dimensions (wip)
-
debug and fix bias backwards gradients yet again for head size of 32
-
fix attention bias gradients
-
allow for single-headed key / values, as in PaLM
-
fix atomic add for f16
-
attention bias should be able to accept dimensions of an extra batch dimension, for Alphafold2 like attention biasing
-
automate cache-busting of kernel using version as suffix to package name
-
resolve f16 causal numerical issues
-
adopt all learnings from forward kernel to backwards kernel and make sure it outperforms at least on A100
Description
So far cosine similarity attention is not widely used in industry. The only large model that has been trained with it so far is SwinV2. If anyone can invalidate the approach, please open an issue or send me an email. You can run experiments against regular attention using the x-transformers repository.
Update: Boris Dayma has graciously kicked off an experiment (blue with red as baseline) to validate cosine similarity attention with a fixed scale of 10 in a real-world model setting. 🙏
Update 2: Cosine similarity attention has been proven out in a real-world text-to-image attention network, using a constant scale of 10. No worse than regular attention. Credit goes to Boris Dayma for investing the time to run the experiment and removing doubts surrounding the technique.
Update 3: Robin Rombach has tested out the kernel in this repository with head size of 64 and fixed scale of 10 in a text-to-image model, observing no difference from regular attention. More evaluations pending.
Update 4: The improvement in performance seen in Boris' experiments are likely due to the fact that cosine-sim attention allows for one to switch from pre layernorm to post layernorm configuration in the transformers (as the l2norm effectively takes the place of the pre-layernorm). Cosine sim attention will likely yield results the same as regular attention, without any other changes to the transformer.
Testing
For testing output and gradients are equal for non-autoregressive and autoregressive scenarios
$ python setup.py test
Benchmarking
Make sure to first install the CUDA kernel
$ python setup.py install
Then
$ python benchmark.py
For only benchmarking forwards or backwards, append either --only-forwards or --only-backwards flag to the above. To benchmark autoregressive, append --causal
Benchmarks - wip
GTX 2080 Ti
Forward
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.05x kernel: 0.24ms baseline: 0.23ms
seq_len: 256 slower: 1.27x kernel: 0.38ms baseline: 0.30ms
seq_len: 512 slower: 1.28x kernel: 0.87ms baseline: 0.68ms
seq_len: 1024 slower: 1.15x kernel: 2.63ms baseline: 2.28ms
seq_len: 2048 slower: 0.99x kernel: 7.99ms baseline: 8.10ms
seq_len: 4096 slower: 0.88x kernel: 30.82ms baseline: 34.84ms
seq_len: 8192 slower: 0.00x kernel: 121.96ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.85x kernel: 0.20ms baseline: 0.24ms
seq_len: 256 slower: 0.97x kernel: 0.24ms baseline: 0.25ms
seq_len: 512 slower: 1.22x kernel: 0.43ms baseline: 0.35ms
seq_len: 1024 slower: 0.95x kernel: 0.93ms baseline: 0.98ms
seq_len: 2048 slower: 0.90x kernel: 3.16ms baseline: 3.50ms
seq_len: 4096 slower: 0.85x kernel: 11.06ms baseline: 13.07ms
seq_len: 8192 slower: 0.00x kernel: 42.61ms baseline: oom
Backwards - still needs work
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.07x kernel: 0.61ms baseline: 0.57ms
seq_len: 256 slower: 1.40x kernel: 0.91ms baseline: 0.65ms
seq_len: 512 slower: 1.70x kernel: 2.34ms baseline: 1.38ms
seq_len: 1024 slower: 1.26x kernel: 5.67ms baseline: 4.50ms
seq_len: 2048 slower: 1.29x kernel: 20.60ms baseline: 15.91ms
seq_len: 4096 slower: 1.30x kernel: 78.93ms baseline: 60.81ms
seq_len: 8192 slower: 0.00x kernel: 314.51ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.91x kernel: 0.50ms baseline: 0.55ms
seq_len: 256 slower: 1.06x kernel: 0.58ms baseline: 0.55ms
seq_len: 512 slower: 1.13x kernel: 0.81ms baseline: 0.72ms
seq_len: 1024 slower: 0.97x kernel: 2.09ms baseline: 2.16ms
seq_len: 2048 slower: 0.96x kernel: 7.06ms baseline: 7.35ms
seq_len: 4096 slower: 0.97x kernel: 26.08ms baseline: 26.84ms
seq_len: 8192 slower: 0.00x kernel: 101.02ms baseline: oom
Forward & Backwards - F32 is definitely slower
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.05x kernel: 0.83ms baseline: 0.79ms
seq_len: 256 slower: 1.34x kernel: 1.26ms baseline: 0.95ms
seq_len: 512 slower: 1.44x kernel: 3.14ms baseline: 2.18ms
seq_len: 1024 slower: 1.15x kernel: 7.83ms baseline: 6.81ms
seq_len: 2048 slower: 1.20x kernel: 28.83ms baseline: 24.03ms
seq_len: 4096 slower: 1.20x kernel: 111.13ms baseline: 92.51ms
seq_len: 8192 slower: 0.00x kernel: 441.70ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.89x kernel: 0.68ms baseline: 0.77ms
seq_len: 256 slower: 1.03x kernel: 0.80ms baseline: 0.77ms
seq_len: 512 slower: 1.06x kernel: 1.16ms baseline: 1.10ms
seq_len: 1024 slower: 0.93x kernel: 2.94ms baseline: 3.16ms
seq_len: 2048 slower: 0.93x kernel: 10.06ms baseline: 10.87ms
seq_len: 4096 slower: 0.93x kernel: 37.09ms baseline: 39.96ms
seq_len: 8192 slower: 0.00x kernel: 143.13ms baseline: oom
For autoregressive, a clear win python benchmark.py --causal
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.97x kernel: 0.81ms baseline: 0.84ms
seq_len: 256 slower: 1.07x kernel: 1.12ms baseline: 1.05ms
seq_len: 512 slower: 0.83x kernel: 2.23ms baseline: 2.68ms
seq_len: 1024 slower: 0.55x kernel: 4.83ms baseline: 8.82ms
seq_len: 2048 slower: 0.49x kernel: 15.89ms baseline: 32.68ms
seq_len: 4096 slower: 0.46x kernel: 57.50ms baseline: 126.00ms
seq_len: 8192 slower: 0.00x kernel: 224.76ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.69ms baseline: 0.84ms
seq_len: 256 slower: 0.95x kernel: 0.79ms baseline: 0.83ms
seq_len: 512 slower: 0.78x kernel: 1.06ms baseline: 1.37ms
seq_len: 1024 slower: 0.50x kernel: 2.10ms baseline: 4.24ms
seq_len: 2048 slower: 0.37x kernel: 5.85ms baseline: 15.92ms
seq_len: 4096 slower: 0.31x kernel: 19.80ms baseline: 64.42ms
seq_len: 8192 slower: 0.00x kernel: 75.25ms baseline: oom
For variable length sequences with masking, also a clear win (assuming on average 25% of tokens masked out). Run: `python benchmark.py --mask-prob 0.25`
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.95x kernel: 0.84ms baseline: 0.89ms
seq_len: 256 slower: 1.19x kernel: 1.28ms baseline: 1.08ms
seq_len: 512 slower: 1.23x kernel: 3.19ms baseline: 2.59ms
seq_len: 1024 slower: 0.92x kernel: 8.19ms baseline: 8.88ms
seq_len: 2048 slower: 0.92x kernel: 30.08ms baseline: 32.57ms
seq_len: 4096 slower: 0.94x kernel: 123.20ms baseline: 131.22ms
seq_len: 8192 slower: 0.00x kernel: 461.77ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.85x kernel: 0.77ms baseline: 0.90ms
seq_len: 256 slower: 0.93x kernel: 0.86ms baseline: 0.93ms
seq_len: 512 slower: 0.93x kernel: 1.31ms baseline: 1.40ms
seq_len: 1024 slower: 0.76x kernel: 3.31ms baseline: 4.35ms
seq_len: 2048 slower: 0.71x kernel: 11.19ms baseline: 15.65ms
seq_len: 4096 slower: 0.70x kernel: 41.27ms baseline: 59.01ms
seq_len: 8192 slower: 0.00x kernel: 158.60ms baseline: oom
A100 40GB (wip)
Thanks goes out to Stability for providing access to A100s for testing. Thanks to Enrico for taking the time to run some benchmarks when I had no access yet.
A100 is still a work in progress. Shared memory is not fully exploited yet. Strangely enough, F32 seems to be doing better than F16
Forwards
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.98x kernel: 0.29ms baseline: 0.30ms
seq_len: 256 slower: 1.19x kernel: 0.35ms baseline: 0.29ms
seq_len: 512 slower: 0.94x kernel: 0.52ms baseline: 0.55ms
seq_len: 1024 slower: 0.75x kernel: 1.23ms baseline: 1.65ms
seq_len: 2048 slower: 0.88x kernel: 4.17ms baseline: 4.73ms
seq_len: 4096 slower: 0.79x kernel: 14.53ms baseline: 18.36ms
seq_len: 8192 slower: 0.64x kernel: 55.01ms baseline: 85.93ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.84x kernel: 0.24ms baseline: 0.29ms
seq_len: 256 slower: 1.02x kernel: 0.29ms baseline: 0.29ms
seq_len: 512 slower: 1.24x kernel: 0.36ms baseline: 0.29ms
seq_len: 1024 slower: 1.48x kernel: 0.79ms baseline: 0.54ms
seq_len: 2048 slower: 1.31x kernel: 2.08ms baseline: 1.59ms
seq_len: 4096 slower: 1.21x kernel: 6.89ms baseline: 5.70ms
seq_len: 8192 slower: 1.07x kernel: 24.80ms baseline: 23.15ms
Backwards
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.94x kernel: 0.57ms baseline: 0.60ms
seq_len: 256 slower: 1.29x kernel: 0.75ms baseline: 0.58ms
seq_len: 512 slower: 1.16x kernel: 1.30ms baseline: 1.12ms
seq_len: 1024 slower: 0.98x kernel: 3.14ms baseline: 3.19ms
seq_len: 2048 slower: 1.05x kernel: 11.13ms baseline: 10.63ms
seq_len: 4096 slower: 0.98x kernel: 40.11ms baseline: 40.79ms
seq_len: 8192 slower: 0.97x kernel: 154.96ms baseline: 159.70ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.91x kernel: 0.55ms baseline: 0.60ms
seq_len: 256 slower: 1.03x kernel: 0.62ms baseline: 0.60ms
seq_len: 512 slower: 1.36x kernel: 0.82ms baseline: 0.60ms
seq_len: 1024 slower: 1.52x kernel: 1.52ms baseline: 1.01ms
seq_len: 2048 slower: 1.37x kernel: 4.14ms baseline: 3.03ms
seq_len: 4096 slower: 1.33x kernel: 14.23ms baseline: 10.71ms
seq_len: 8192 slower: 1.34x kernel: 53.90ms baseline: 40.28ms
Forwards & Backwards
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.92x kernel: 0.80ms baseline: 0.87ms
seq_len: 256 slower: 1.23x kernel: 1.07ms baseline: 0.87ms
seq_len: 512 slower: 1.08x kernel: 1.80ms baseline: 1.66ms
seq_len: 1024 slower: 0.94x kernel: 4.33ms baseline: 4.62ms
seq_len: 2048 slower: 0.99x kernel: 15.26ms baseline: 15.44ms
seq_len: 4096 slower: 0.93x kernel: 54.78ms baseline: 59.21ms
seq_len: 8192 slower: 0.91x kernel: 210.38ms baseline: 230.97ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.90x kernel: 0.78ms baseline: 0.86ms
seq_len: 256 slower: 1.00x kernel: 0.87ms baseline: 0.87ms
seq_len: 512 slower: 1.36x kernel: 1.18ms baseline: 0.86ms
seq_len: 1024 slower: 1.49x kernel: 2.31ms baseline: 1.55ms
seq_len: 2048 slower: 1.33x kernel: 6.17ms baseline: 4.63ms
seq_len: 4096 slower: 1.28x kernel: 21.08ms baseline: 16.44ms
seq_len: 8192 slower: 1.24x kernel: 78.75ms baseline: 63.45ms
Autoregressive
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.82ms baseline: 1.01ms
seq_len: 256 slower: 1.02x kernel: 1.00ms baseline: 0.98ms
seq_len: 512 slower: 0.82x kernel: 1.55ms baseline: 1.89ms
seq_len: 1024 slower: 0.51x kernel: 2.79ms baseline: 5.44ms
seq_len: 2048 slower: 0.45x kernel: 8.37ms baseline: 18.67ms
seq_len: 4096 slower: 0.40x kernel: 29.16ms baseline: 72.97ms
seq_len: 8192 slower: 0.38x kernel: 108.68ms baseline: 285.47ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.81ms baseline: 0.98ms
seq_len: 256 slower: 0.90x kernel: 0.88ms baseline: 0.98ms
seq_len: 512 slower: 1.16x kernel: 1.13ms baseline: 0.97ms
seq_len: 1024 slower: 0.80x kernel: 1.68ms baseline: 2.10ms
seq_len: 2048 slower: 0.54x kernel: 3.66ms baseline: 6.81ms
seq_len: 4096 slower: 0.45x kernel: 11.43ms baseline: 25.32ms
seq_len: 8192 slower: 0.41x kernel: 40.58ms baseline: 99.14ms
Variable lengthed sequences (up to 25% tokens masked out)
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.80x kernel: 0.85ms baseline: 1.07ms
seq_len: 256 slower: 1.07x kernel: 1.15ms baseline: 1.08ms
seq_len: 512 slower: 1.00x kernel: 1.94ms baseline: 1.94ms
seq_len: 1024 slower: 0.84x kernel: 4.64ms baseline: 5.55ms
seq_len: 2048 slower: 0.84x kernel: 15.86ms baseline: 18.86ms
seq_len: 4096 slower: 0.76x kernel: 55.19ms baseline: 72.47ms
seq_len: 8192 slower: 0.75x kernel: 212.48ms baseline: 282.71ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.80x kernel: 0.83ms baseline: 1.04ms
seq_len: 256 slower: 0.90x kernel: 0.93ms baseline: 1.03ms
seq_len: 512 slower: 1.18x kernel: 1.22ms baseline: 1.04ms
seq_len: 1024 slower: 1.10x kernel: 2.40ms baseline: 2.17ms
seq_len: 2048 slower: 0.89x kernel: 6.27ms baseline: 7.06ms
seq_len: 4096 slower: 0.82x kernel: 21.19ms baseline: 25.95ms
seq_len: 8192 slower: 0.78x kernel: 79.45ms baseline: 101.83ms
Training a small GPT on Enwik8
$ make train
Try 8192 sequence length. It'll be slow but will work (normal attention will break at > 2048, you'll see this if you remove the --use-cuda-kernel flag)
$ python train.py --seq-len 8192 --use-cuda-kernel
Citations
@article{Dao2022FlashAttentionFA,
title = {FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
author = {Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R'e},
journal = {ArXiv},
year = {2022},
volume = {abs/2205.14135}
}
@misc{rabe2021selfattention,
title = {Self-attention Does Not Need $O(n^2)$ Memory},
author = {Markus N. Rabe and Charles Staats},
year = {2021},
eprint = {2112.05682},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@inproceedings{Henry2020QueryKeyNF,
title = {Query-Key Normalization for Transformers},
author = {Alex Henry and Prudhvi Raj Dachapally and Shubham Vivek Pawar and Yuxuan Chen},
booktitle = {FINDINGS},
year = {2020}
}
@article{Wang2022DeepNetST,
title = {DeepNet: Scaling Transformers to 1, 000 Layers},
author = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei},
journal = {ArXiv},
year = {2022},
volume = {abs/2203.00555}
}
.\lucidrains\flash-cosine-sim-attention\setup.py
# 导入必要的库
import sys
from functools import lru_cache
from subprocess import DEVNULL, call
from setuptools import setup, find_packages
import torch
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
# Load version metadata by executing version.py in this namespace
# (defines __version__, and presumably __cuda_pkg_name__ used below — TODO confirm)
exec(open('flash_cosine_sim_attention/version.py').read())
# Probe for the CUDA toolkit by attempting to invoke nvcc.
@lru_cache(None)
def cuda_toolkit_available():
    # A missing nvcc binary raises FileNotFoundError; any other outcome
    # (including a nonzero exit code) counts as "toolkit present".
    try:
        call(["nvcc"], stdout = DEVNULL, stderr = DEVNULL)
    except FileNotFoundError:
        return False
    return True
# Compiler flags for building the extension.
def compile_args():
    # macOS clang requires -Xpreprocessor ahead of the OpenMP flag.
    flags = ["-fopenmp", "-ffast-math"]
    if sys.platform != "darwin":
        return flags
    return ["-Xpreprocessor", *flags]
# Build the list of CUDA extension modules (empty when nvcc is unavailable).
def ext_modules():
    # Nothing to compile without the CUDA toolkit.
    if not cuda_toolkit_available():
        return []
    extension = CUDAExtension(
        __cuda_pkg_name__,
        sources = ["flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu"]
    )
    return [extension]
# Package metadata and build configuration
setup(
    name = 'flash-cosine-sim-attention',
    packages = find_packages(exclude=[]),
    version = __version__,  # populated by the exec of version.py above
    license='MIT',
    description = 'Flash Cosine Similarity Attention',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    long_description_content_type = 'text/markdown',
    url = 'https://github.com/lucidrains/flash-cosine-sim-attention',
    keywords = [
        'artificial intelligence',
        'deep learning',
        'attention mechanism'
    ],
    install_requires=[
        'torch>=1.10'
    ],
    setup_requires=[
        'pytest-runner',
    ],
    tests_require=[
        'pytest'
    ],
    ext_modules = ext_modules(),  # empty list when the CUDA toolkit is absent
    cmdclass = {"build_ext": BuildExtension},
    include_package_data = True,
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\flash-cosine-sim-attention\tests\test.py
import torch
import pytest
from flash_cosine_sim_attention import plain_cosine_sim_attention, flash_cosine_sim_attention
# Every test in this module requires a CUDA device
assert torch.cuda.is_available(), 'cuda must be available'
# helper predicates

# True when every element of the tensor is finite (no NaN, no +/- inf).
def not_nan_or_infs(t):
    has_nan = torch.any(torch.isnan(t))
    has_inf = torch.any(torch.isinf(t))
    return not (has_nan or has_inf)
# Max-absolute-difference comparison; prints the diff when it exceeds atol.
def allclose(a, b, atol = 1e-4):
    max_diff = (a - b).abs().amax()
    exceeded = torch.any(max_diff > atol)
    if exceeded:
        print(f'diff: {max_diff}')
    return max_diff <= atol
# True when a value (e.g. an optional tensor) was actually supplied.
def exists(t):
    return not (t is None)
# Move a tensor to CPU, passing None straight through.
def maybe_cpu(t):
    return None if t is None else t.cpu()
# tests

# Parametrized grid: causal/mask combos, bias, odd lengths, head dims, dtypes
@pytest.mark.parametrize('causal,mask', [(True, False), (False, True), (False, False)])
@pytest.mark.parametrize('attn_bias', [True, False])
@pytest.mark.parametrize('seq_len', [63, 127])
@pytest.mark.parametrize('dim_head', [32, 64, 96, 128])
@pytest.mark.parametrize('float16', [False, True])
@pytest.mark.parametrize('attn_bias_batch_dim', [False, True])
@pytest.mark.parametrize('single_head_kv', [False, True])
def test_output_equal(
    causal,
    mask,
    attn_bias,
    seq_len,
    dim_head,
    float16,
    attn_bias_batch_dim,
    single_head_kv
):
    # The fused kernel's forward output must match the plain implementation.
    batch, heads = 4, 8
    # float16 needs a much looser tolerance
    dtype, atol = (torch.float16, 1e-1) if float16 else (torch.float32, 1e-4)
    # single_head_kv drops the heads dimension from k/v (shared across heads)
    kv_shape = (batch, heads, seq_len, dim_head) if not single_head_kv else (batch, seq_len, dim_head)
    q = torch.randn(batch, heads, seq_len, dim_head, dtype = dtype).cuda()
    k = torch.randn(kv_shape, dtype = dtype).cuda()
    v = torch.randn(kv_shape, dtype = dtype).cuda()
    # boolean key-padding mask over sequence positions
    attn_mask = torch.randint(0, 2, (batch, seq_len), dtype = torch.bool).cuda() if mask else None
    # attention bias, leading dim is batch or heads depending on the flag
    bias = torch.randn(batch if attn_bias_batch_dim else heads, seq_len, seq_len, dtype = dtype).cuda() if attn_bias else None
    plain_output = plain_cosine_sim_attention(q, k, v, causal = causal, mask = attn_mask, attn_bias = bias, attn_bias_batch_dim = attn_bias_batch_dim)
    flash_output = flash_cosine_sim_attention(q, k, v, causal = causal, mask = attn_mask, attn_bias = bias, attn_bias_batch_dim = attn_bias_batch_dim)
    # kernel output must be numerically sane
    assert not_nan_or_infs(flash_output)
    # and match the reference within tolerance
    assert allclose(plain_output, flash_output, atol = atol)
# Parametrized grid mirrors test_output_equal
@pytest.mark.parametrize('causal,mask', [(True, False), (False, True), (False, False)])
@pytest.mark.parametrize('attn_bias', [True, False])
@pytest.mark.parametrize('seq_len', [63, 127])
@pytest.mark.parametrize('dim_head', [32, 64, 96, 128])
@pytest.mark.parametrize('float16', [False, True])
@pytest.mark.parametrize('attn_bias_batch_dim', [False, True])
@pytest.mark.parametrize('single_head_kv', [False, True])
def test_grad_equal(
    causal,
    mask,
    attn_bias,
    seq_len,
    dim_head,
    float16,
    attn_bias_batch_dim,
    single_head_kv
):
    # Gradients through the fused kernel must match the plain implementation.
    batch, heads = 4, 8
    dtype, atol = (torch.float16, 1e-1) if float16 else (torch.float32, 1e-4)
    # NOTE(review): unlike test_output_equal, kv_shape ignores single_head_kv
    # here, so single-headed k/v gradients are never exercised — confirm intent
    kv_shape = (batch, heads, seq_len, dim_head)
    q = torch.randn(batch, heads, seq_len, dim_head, dtype = dtype).cuda().requires_grad_()
    k = torch.randn(kv_shape, dtype = dtype).cuda().requires_grad_()
    v = torch.randn(kv_shape, dtype = dtype).cuda().requires_grad_()
    attn_mask = torch.randint(0, 2, (batch, seq_len), dtype = torch.bool).cuda() if mask else None
    bias = torch.randn(batch if attn_bias_batch_dim else heads, seq_len, seq_len, dtype = dtype).cuda().requires_grad_() if attn_bias else None
    # reference gradients from the plain implementation
    plain_output = plain_cosine_sim_attention(q, k, v, causal = causal, mask = attn_mask, attn_bias = bias, attn_bias_batch_dim = attn_bias_batch_dim)
    plain_output.sum().backward()
    dq, dk, dv = q.grad, k.grad, v.grad
    db = bias.grad if attn_bias else None
    # clear grads before running the kernel path
    q.grad, k.grad, v.grad = None, None, None
    if attn_bias:
        bias.grad = None
    flash_output = flash_cosine_sim_attention(q, k, v, causal = causal, mask = attn_mask, attn_bias = bias, attn_bias_batch_dim = attn_bias_batch_dim)
    flash_output.sum().backward()
    fdq, fdk, fdv = q.grad, k.grad, v.grad
    fdb = bias.grad if attn_bias else None
    # kernel gradients must be numerically sane
    assert not_nan_or_infs(fdv)
    assert not_nan_or_infs(fdk)
    assert not_nan_or_infs(fdq)
    # and match the reference gradients within tolerance
    assert allclose(dv, fdv, atol=atol)
    assert allclose(dk, fdk, atol=atol)
    assert allclose(dq, fdq, atol=atol)
    # bias gradients, when a bias was supplied
    if attn_bias:
        assert not_nan_or_infs(fdb)
        assert allclose(db, fdb, atol=atol)
# CPU fallback consistency

# Parametrized grid mirrors test_output_equal
@pytest.mark.parametrize('causal,mask', [(True, False), (False, True), (False, False)])
@pytest.mark.parametrize('attn_bias', [True, False])
@pytest.mark.parametrize('seq_len', [63, 127])
@pytest.mark.parametrize('dim_head', [32, 64, 96, 128])
@pytest.mark.parametrize('float16', [False, True])
@pytest.mark.parametrize('attn_bias_batch_dim', [False, True])
@pytest.mark.parametrize('single_head_kv', [False, True])
def test_output_equal_cuda_and_cpu_forward(
    causal,
    mask,
    attn_bias,
    seq_len,
    dim_head,
    float16,
    attn_bias_batch_dim,
    single_head_kv
):
    # The CPU path must produce the same forward output as the CUDA path.
    batch, heads = 4, 8
    # float16 needs a much looser tolerance
    dtype, atol = (torch.float16, 1e-1) if float16 else (torch.float32, 1e-4)
    # single_head_kv drops the heads dimension from k/v
    kv_shape = (batch, heads, seq_len, dim_head) if not single_head_kv else (batch, seq_len, dim_head)
    q = torch.randn(batch, heads, seq_len, dim_head, dtype = dtype).cuda()
    k = torch.randn(kv_shape, dtype = dtype).cuda()
    v = torch.randn(kv_shape, dtype = dtype).cuda()
    attn_mask = torch.randint(0, 2, (batch, seq_len), dtype = torch.bool).cuda() if mask else None
    bias = torch.randn(batch if attn_bias_batch_dim else heads, seq_len, seq_len, dtype = dtype).cuda() if attn_bias else None
    # run on GPU
    flash_output = flash_cosine_sim_attention(q, k, v, causal = causal, mask = attn_mask, attn_bias = bias, attn_bias_batch_dim = attn_bias_batch_dim)
    # run the same inputs on CPU
    flash_output_cpu = flash_cosine_sim_attention(q.cpu(), k.cpu(), v.cpu(), causal = causal, mask = maybe_cpu(attn_mask), attn_bias = maybe_cpu(bias), attn_bias_batch_dim = attn_bias_batch_dim)
    # both devices must agree within tolerance
    assert allclose(flash_output.cpu(), flash_output_cpu, atol = atol)
.\lucidrains\flash-cosine-sim-attention\tests\__init__.py
def calculate_area(length, width):
    """Return the area of a rectangle given its side lengths."""
    return length * width
.\lucidrains\flash-cosine-sim-attention\train.py
# 导入所需的模块和类
from flash_cosine_sim_attention.transformer import CosineSimCausalTransformer
import argparse
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
# Parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--use-cuda-kernel', default = False, action = 'store_true')
parser.add_argument('--use-float32', default = False, action = 'store_true')
parser.add_argument('--seq-len', default = 1024, type = int)
args = parser.parse_args()
# Training hyperparameters
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4  # effective batch size = BATCH_SIZE * 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = args.seq_len
USE_AMP = not args.use_float32  # mixed precision unless float32 was requested
print(f'\ntraining at sequence length {args.seq_len} with {"float32" if args.use_float32 else "float16"}\n')
# helpers

# Turn a finite loader into an endless stream of batches.
def cycle(loader):
    while True:
        yield from loader
# Map a byte-level token id to a printable character (control chars -> space).
def decode_token(token):
    return chr(token) if token >= 32 else chr(32)
# Decode a sequence of token ids into a string.
def decode_tokens(tokens):
    return ''.join(decode_token(token) for token in tokens)
# Instantiate a GPT-like decoder model
model = CosineSimCausalTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    attn_scale = 1,
    attn_l2norm_groups = 8,
    dim_head = 64,
    pre_norm = True,
    non_cosine_sim_attn = False,
    max_seq_len = SEQ_LEN,
    use_cuda_kernel = args.use_cuda_kernel
)
model.cuda()
# Prepare enwik8: first 90M bytes for training, remaining 5M for validation
with gzip.open('./data/enwik8.gz') as file:
    x = np.array(np.frombuffer(file.read(int(95e6)), dtype = np.uint8))
    train_x, valid_x = np.split(x, [int(90e6)])
    data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
# Samples random fixed-length (seq_len + 1) token windows from a 1-D tensor.
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # index is ignored — every access draws a fresh random window
        start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        window = self.data[start: start + self.seq_len + 1].long()
        return window.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
# Optimizer — NOTE(review): this rebinds the name `optim`, shadowing the
# `torch.optim as optim` import above; harmless here but confusing
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scaler = GradScaler(enabled = USE_AMP)  # no-op scaler when AMP is disabled
# Training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()
    optim.zero_grad()
    # gradient accumulation: average loss over several micro-batches
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        with autocast(enabled = USE_AMP):
            loss = model(next(train_loader), return_loss = True)
        scaler.scale(loss / GRADIENT_ACCUMULATE_EVERY).backward()
    print(f'training loss: {loss.item()}')
    # unscale before clipping so the threshold applies to the true gradients
    scaler.unscale_(optim)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    scaler.step(optim)
    scaler.update()
    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader), return_loss = True)
            print(f'validation loss: {loss.item()}')
    if i % GENERATE_EVERY == 0:
        model.eval()
        # random.choice works on the dataset via its __len__/__getitem__
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f"\n\n {prime} \n\n {'-' * 80} \n")
        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str + "\n\n")
.\lucidrains\flash-genomics-model\flash_genomics_model\attend.py
# 导入必要的库
from collections import namedtuple
from functools import wraps
from packaging import version
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange
# Flags selecting which scaled-dot-product backends to enable (flash / math / mem-efficient)
EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
# helper functions

# True when the optional value was actually supplied.
def exists(val):
    return not (val is None)
# Decorator: run the wrapped single-argument function at most once;
# subsequent calls return None.
def once(fn):
    state = {'called': False}
    @wraps(fn)
    def inner(x):
        if state['called']:
            return
        state['called'] = True
        return fn(x)
    return inner
# module-level print that fires only on the first call
print_once = once(print)
# main class

class Attend(nn.Module):
    """Attention dispatcher.

    Routes to torch 2.0's `F.scaled_dot_product_attention` (optionally with
    flash kernels) when `flash = True`, otherwise falls back to a plain
    einsum implementation.

    Einstein notation:
        b - batch, h - heads, n / i / j - sequence lengths, d - feature dim
    """

    def __init__(
        self,
        causal = False,
        dropout = 0.,
        flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu
        self.cpu_config = EfficientAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(False, True, True)

    def get_mask(self, i, j, device):
        # boolean (i, j) mask, True strictly above the causal diagonal
        return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)

    def flash_attn(self, q, k, v, mask = None, attn_bias = None):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # single headed key / values get an explicit head dimension
        if k.ndim == 3:
            k = rearrange(k, 'b n d -> b 1 n d')

        if v.ndim == 3:
            v = rearrange(v, 'b n d -> b 1 n d')

        # expand a (b, j) padding mask to (b, h, n, j) for sdpa
        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        # pick the backend config for the device the tensors live on
        config = self.cuda_config if is_cuda else self.cpu_config

        causal = self.causal

        # sdpa accepts either an additive attn_mask or is_causal, not both, so
        # when a bias is supplied, fold causal and padding masks into the bias
        if exists(attn_bias):
            mask_value = -torch.finfo(q.dtype).max // 2

            # fix: only fold in the causal mask when attention is causal —
            # previously future positions were masked unconditionally,
            # corrupting non-causal attention whenever a bias was supplied
            if causal:
                causal_mask = self.get_mask(q_len, k_len, device)
                attn_bias = attn_bias.masked_fill(causal_mask, mask_value)
                causal = False

            if exists(mask):
                attn_bias = attn_bias.masked_fill(~mask, mask_value)

            mask = attn_bias

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.,
                is_causal = causal
            )

        return out

    def forward(self, q, k, v, mask = None, attn_bias = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        scale = q.shape[-1] ** -0.5

        # k / v may be single-headed (3-dim), broadcast across query heads
        kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

        if self.flash:
            return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)

        # similarity
        sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale

        # additive attention bias
        if exists(attn_bias):
            sim = sim + attn_bias

        # fix: apply the key padding mask on the non-flash path as well —
        # previously `mask` was accepted but silently ignored here
        if exists(mask):
            sim = sim.masked_fill(~rearrange(mask, 'b j -> b 1 1 j'), -torch.finfo(sim.dtype).max)

        # causal masking
        if self.causal:
            causal_mask = self.get_mask(q_len, k_len, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention weights
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)

        return out
.\lucidrains\flash-genomics-model\flash_genomics_model\flash_genomics_model.py
# 导入 torch 库,包括神经网络模块和函数模块
import torch
# 导入 torch 中的函数模块
import torch.nn.functional as F
# 从 torch 中导入 nn、einsum、Tensor 模块
from torch import nn, einsum, Tensor
# 从 einops 库中导入 rearrange、reduce 函数
from einops import rearrange, reduce
# 从 flash_genomics_model.attend 模块中导入 Attend 类
from flash_genomics_model.attend import Attend
# functions

# attention

# Multi-head attention layer delegating the attention computation to Attend.
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        flash = True
    ):
        super().__init__()
        self.heads = heads
        dim_inner = heads * dim_head

        # core attention (optionally flash-accelerated)
        self.attend = Attend(flash = flash)

        # fused projection to queries, keys and values
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
        # projection back to the model dimension
        self.to_out = nn.Linear(dim_inner, dim, bias = False)

    def forward(
        self,
        x,
        mask = None
    ):
        heads = self.heads

        # project and split into queries, keys, values
        qkv = self.to_qkv(x).chunk(3, dim = -1)

        # separate out the head dimension: (b, n, h*d) -> (b, h, n, d)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = heads) for t in qkv)

        attended = self.attend(q, k, v, mask = mask)

        # merge heads back and project to the output dimension
        merged = rearrange(attended, 'b h n d -> b n (h d)')
        return self.to_out(merged)
# main class

# Placeholder model: currently the identity mapping.
class FlashGenomicsModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # identity for now
        return x