Lucidrains Series Source Code Walkthrough (Part 84)
.\lucidrains\RETRO-pytorch\retro_pytorch\utils.py
# imports: os, numpy, Path, rmtree and the contextmanager decorator
import os
import numpy as np
from pathlib import Path
from shutil import rmtree
from contextlib import contextmanager

# check whether an environment variable flag is set to a truthy value
def is_true_env_flag(env_flag):
    return os.getenv(env_flag, 'false').lower() in ('true', '1', 't')

# reset a folder: delete it and its contents (ignoring errors if it does not exist), then recreate it
def reset_folder_(p):
    path = Path(p)
    rmtree(path, ignore_errors = True)
    path.mkdir(exist_ok = True, parents = True)

# context manager around a numpy memmap: yields the memmap to the caller and deletes it on exit
@contextmanager
def memmap(*args, **kwargs):
    pointer = np.memmap(*args, **kwargs)
    yield pointer
    del pointer
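A minimal usage sketch of the memmap helper above; the file name and array shape are illustrative, not taken from RETRO-pytorch:

# illustrative only: write one row into a freshly created float32 memmap
with memmap('./embeddings.dat', dtype = np.float32, mode = 'w+', shape = (1024, 768)) as ptr:
    ptr[0] = np.random.randn(768)
# on leaving the block the memmap object is deleted, flushing any pending changes to disk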
.\lucidrains\RETRO-pytorch\retro_pytorch\__init__.py
# expose RETRO, RETRODataset and TrainingWrapper at the package root
from retro_pytorch.retro_pytorch import RETRO
from retro_pytorch.data import RETRODataset
from retro_pytorch.training import TrainingWrapper
.\lucidrains\RETRO-pytorch\setup.py
# setuptools-based package metadata
from setuptools import setup, find_packages

setup(
    name = 'retro-pytorch',                             # package name
    packages = find_packages(exclude=[]),               # discover all packages
    version = '0.3.9',                                  # version
    license='MIT',                                      # license
    description = 'RETRO - Retrieval Enhanced Transformer - Pytorch',
    long_description_content_type = 'text/markdown',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/RETRO-pytorch',
    keywords = [                                        # keywords
        'artificial intelligence',
        'deep learning',
        'transformers',
        'attention-mechanism',
        'retrieval',
    ],
    install_requires=[                                  # runtime dependencies
        'autofaiss',
        'einops>=0.3',
        'numpy',
        'sentencepiece',
        'torch>=1.6',
        'tqdm'
    ],
    classifiers=[                                       # PyPI classifiers
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\ring-attention-pytorch\assert.py
# imports for the distributed equivalence test
import os
from math import ceil
from copy import deepcopy

import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from ring_attention_pytorch.ring_attention import RingTransformer
from ring_attention_pytorch.distributed import all_gather_variable_dim

# set up the distributed process group
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

# tear down the distributed process group
def cleanup():
    dist.destroy_process_group()

# per-process entry point
def start(
    rank,
    world_size,
    batch_size,
    batch_size_var_len,
    seq_len,
    num_sharded_batches,
    causal,
    striped_ring_attn,
    dim,
    use_cuda
):
    setup(rank, world_size)

    # derive the per-machine ring sequence size and the flash attention bucket size
    ring_seq_size = ceil(seq_len / world_size) * num_sharded_batches
    bucket_size = ring_seq_size // 2

    # build a ring attention transformer and a regular (flash) attention transformer
    ring_attention_net = RingTransformer(
        num_tokens=256,
        dim=dim,
        causal=causal,
        depth=2,
        dim_head=8,
        ring_attn=True,
        striped_ring_attn=striped_ring_attn,
        ring_seq_size=ring_seq_size,
        bucket_size=bucket_size
    )

    flash_attention_net = RingTransformer(
        num_tokens=256,
        dim=dim,
        causal=causal,
        depth=2,
        dim_head=8,
        ring_attn=False,
        bucket_size=bucket_size
    )

    # copy the ring attention weights into the regular attention network so both start identical
    flash_attention_net.load_state_dict(ring_attention_net.state_dict())

    # optionally vary the batch size per rank
    if batch_size_var_len:
        batch_size = batch_size + rank

    # random token sequence
    seq = torch.randint(0, 256, (batch_size, seq_len))

    # wrap both networks in DistributedDataParallel
    ddp_ring_attention_net = DDP(ring_attention_net)
    ddp_flash_attention_net = DDP(flash_attention_net)

    # if using CUDA, move the data and models onto this rank's GPU
    if use_cuda:
        seq = seq.cuda(rank)
        flash_attention_net.cuda(rank)
        ring_attention_net.cuda(rank)

    # forward and backward through the regular attention network
    flash_out = ddp_flash_attention_net(seq)
    flash_out.mean().backward()

    # forward and backward through the ring attention network
    ring_out = ddp_ring_attention_net(seq)
    ring_out.mean().backward()

    # validate that the outputs are the same whether or not the sequence is sharded across machines
    if rank == 0:
        ring_attention_net = ring_attention_net.cpu()
        flash_attention_net = flash_attention_net.cpu()
        ring_out = ring_out.cpu()
        flash_out = flash_out.cpu()

        assert torch.allclose(ring_out, flash_out, atol=1e-6), 'output is not the same'

        # validate that the token embedding gradients match between ring and non-ring attention
        get_embed_grad = lambda model: model.token_emb.weight.grad
        ring_embed_grad = get_embed_grad(ring_attention_net)
        flash_embed_grad = get_embed_grad(flash_attention_net)

        assert torch.allclose(
            ring_embed_grad,
            flash_embed_grad,
            atol=1e-2
        ), 'grad is not the same'

        print('✅ outputs and gradients are same between ring attention and non-ring attention')

    cleanup()

# main entry point
if __name__ == '__main__':
    # test configuration
    world_size = 8
    batch_size = 2
    num_sharded_batches = 1
    batch_size_var_len = False
    use_cuda = False
    causal = True
    striped_ring_attn = True

    # when using CUDA, assert that the number of GPUs does not exceed the world size
    assert not use_cuda or torch.cuda.device_count() <= world_size

    seq_len = 31
    dim = 8

    # spawn one process per rank
    mp.spawn(
        start,
        args=(
            world_size,
            batch_size,
            batch_size_var_len,
            seq_len,
            num_sharded_batches,
            causal,
            striped_ring_attn,
            dim,
            use_cuda
        ),
        nprocs=world_size,
        join=True
    )
.\lucidrains\ring-attention-pytorch\assert_flash.py
# single-process equivalence test between regular attention and ring flash attention
import torch

from ring_attention_pytorch import (
    default_attention,
    ring_flash_attn
)

# test configuration
causal = True
seq_len = 62
bucket_size = 4

# base qkv tensors of shape (batch, seq, heads, dim_head)
q = torch.randn(2, seq_len, 2, 16)
k = torch.randn(2, seq_len, 2, 16)
v = torch.randn(2, seq_len, 2, 16)

# separate leaf copies for the flash path and the regular path, with gradients enabled
fq = q.clone().requires_grad_()
fk = k.clone().requires_grad_()
fv = v.clone().requires_grad_()

rq = q.clone().requires_grad_()
rk = k.clone().requires_grad_()
rv = v.clone().requires_grad_()

# forward passes
o = default_attention(rq, rk, rv, causal=causal)
fo = ring_flash_attn(fq, fk, fv, bucket_size=bucket_size, causal=causal)

# the outputs must agree within tolerance
assert torch.allclose(o, fo, atol=1e-6)

# backward passes
o.sum().backward()
fo.sum().backward()

# the gradients must agree within tolerance
assert torch.allclose(rq.grad, fq.grad, atol=1e-6)
assert torch.allclose(rk.grad, fk.grad, atol=1e-6)
assert torch.allclose(rv.grad, fv.grad, atol=1e-6)

Ring Attention - Pytorch
Explorations into Ring Attention, from Liu et al. at Berkeley AI.
It basically splits the data across the sequence dimension (instead of batch) and applies ring reduce to the processing of the tiles of the attention matrix, flash attention style.
I believe this is being used for the 1-10 million tokens for the latest Gemini. At least some form of it; the other possibility would be unpublished improvements on top of RMT.
In addition, the repository also contains the logic for Striped Attention, a follow up paper that permutes the sequence for better workload balancing for autoregressive transformers.
Appreciation
- A16Z Open Source AI Grant Program for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research
Install
$ pip install ring-attention-pytorch
Usage
import torch
from ring_attention_pytorch import RingAttention
attn = RingAttention(
dim = 512,
dim_head = 64,
heads = 8,
causal = True,
auto_shard_seq = True,
ring_attn = True,
ring_seq_size = 512
)
tokens = torch.randn(1, 1024, 512)
attended = attn(tokens)
assert attended.shape == tokens.shape
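The full RingTransformer can also be exercised on a single process without any ring reduce; a minimal sketch (mirroring the non-ring network constructed in assert.py above, with illustrative sizes):

import torch
from ring_attention_pytorch.ring_attention import RingTransformer

# same construction as the flash_attention_net in assert.py above
model = RingTransformer(
    num_tokens = 256,
    dim = 8,
    depth = 2,
    dim_head = 8,
    causal = True,
    ring_attn = False,
    bucket_size = 2
)

seq = torch.randint(0, 256, (2, 31))

logits = model(seq)                      # (2, 31, 256)
loss = model(seq, return_loss = True)    # autoregressive cross entropy on shifted tokens
loss.backward()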
Test
$ python assert.py
Todo
- make it work with derived causal mask based on rank and chunk sizes
- modify flash attention to output intermediates and figure out backwards with recompute and ring passes
- functions for splitting the sequence evenly among ranks, either within attention function, or in the external ring transformer wrapper
- basic test case with two processes and check for equivalent output and gradients
- testing
  - make sure key padding mask works
  - make sure causal mask works
  - rotary embeddings, with proper key/value offset depending on ring rank
- striped attention
  - add the permutating logic before and after transformer
  - add causal masking logic - account for sub bucketing by flash attention
- fix issue with ring attention when flash buckets > 1
- move flash attention back to key / value column traversal on outer loop and save on ring communication
  - backwards
  - forwards
- fix rotary positions for striped ring attention when flash buckets > 1
- allow for variable ring passes per layer, for local -> global attention in ring transformer as one goes up the layers
- when doing ring passes, alternate between designated send and receive buffers
- instead of max ring passes, able to specify lookback in terms of sequence length, and derive number of flash attention bucket + ring passes from that
- ability to have ring size < world size, sharding the batch and sequence, and doing ring reduce with the correct set of ranks
- add flash attention kernel version in the presence of cuda
  - for forwards, use modified Triton flash attention forwards that outputs row sums, maxes, and exponentiated weighted sum
  - for backwards, use Tri's flash attention kernels, accumulate dq, dk, dv across rings
  - refactor to have naive ring+flash attention work with (batch, seq, head, dim)
  - handle key padding mask for forwards by translating mask to bias
  - figure out how Tri handles key padding mask for backwards
  - scale output of flash attention forwards on the last ring pass reduce
  - verify backwards working in a100 runpod
  - dk, dv needs to be float32, while kv needs to be float16. see if both can be cast to int before stacked and ring passed all in one go, then reinterpret back to float32 and float16
  - prevent an unnecessary tl.load on the first ring pass
  - cuda backwards pass must have same dq, dk, dv as naive
- fix naive flash attention backwards
- validate cuda causal and striped ring attention works
- find a machine with 8 GPUs and test with a quarter million tokens first
- think about how to craft a special Dataset that shards across sequence length (take into account labels for cross entropy loss) for ring transformer training
- add ring attention to Tri's flash attention implementation. find some cuda ring reduce impl
- batch_isend_irecv in the presence of key padding mask needing ring exchange, but not a big priority
- figure out how to pytest distributed pytorch
- use sdp context manager to validate when it is possible to use ring_flash_attn_cuda, otherwise assert out
Citations
@article{Liu2023RingAW,
title = {Ring Attention with Blockwise Transformers for Near-Infinite Context},
author = {Hao Liu and Matei Zaharia and Pieter Abbeel},
journal = {ArXiv},
year = {2023},
volume = {abs/2310.01889},
url = {https://api.semanticscholar.org/CorpusID:263608461}
}
@article{Brandon2023StripedAF,
title = {Striped Attention: Faster Ring Attention for Causal Transformers},
author = {William Brandon and Aniruddha Nrusimha and Kevin Qian and Zachary Ankner and Tian Jin and Zhiye Song and Jonathan Ragan-Kelley},
journal = {ArXiv},
year = {2023},
volume = {abs/2311.09431},
url = {https://api.semanticscholar.org/CorpusID:265220849}
}
@article{Dao2022FlashAttentionFA,
title = {FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
author = {Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R'e},
journal = {ArXiv},
year = {2022},
volume = {abs/2205.14135}
}
@article{dao2023flashattention2,
title = {Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
author = {Dao, Tri},
year = {2023}
}
@article{Tillet2019TritonAI,
title = {Triton: an intermediate language and compiler for tiled neural network computations},
author = {Philippe Tillet and H. Kung and D. Cox},
journal = {Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages},
year = {2019}
}
.\lucidrains\ring-attention-pytorch\ring_attention_pytorch\distributed.py
# torch, einx and distributed imports
import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torch.autograd import Function

import torch.distributed as dist

import einx

# helper: check whether a value exists
def exists(val):
    return val is not None

# helper: return the value if it exists, otherwise a default
def default(val, d):
    return val if exists(val) else d

# helper: check divisibility
def divisible_by(num, den):
    return (num % den) == 0

# pad a tensor along a given dimension up to a target length
def pad_dim_to(t, length, dim = 0):
    pad_length = length - t.shape[dim]
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))

# all-gather tensors that have the same shape on every rank
def all_gather_same_dim(t):
    t = t.contiguous()
    world_size = dist.get_world_size()
    gathered_tensors = [torch.empty_like(t, device = t.device, dtype = t.dtype) for i in range(world_size)]
    dist.all_gather(gathered_tensors, t)
    return gathered_tensors

# gather the size of a given dimension from every rank
def gather_sizes(t, *, dim):
    size = torch.tensor(t.shape[dim], device = t.device, dtype = torch.long)
    sizes = all_gather_same_dim(size)
    return torch.stack(sizes)

# check whether every element of a tensor equals the first element
def has_only_one_value(t):
    return (t == t[0]).all()

# all-gather along a dimension whose size may differ across ranks
def all_gather_variable_dim(t, dim = 0, sizes = None):
    device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()

    if not exists(sizes):
        sizes = gather_sizes(t, dim = dim)

    # fast path: every rank has the same size along `dim`
    if has_only_one_value(sizes):
        gathered_tensors = all_gather_same_dim(t)
        gathered_tensors = torch.cat(gathered_tensors, dim = dim)
        return gathered_tensors, sizes

    # otherwise pad to the maximum size, gather, then index out only the valid entries
    max_size = sizes.amax().item()

    padded_t = pad_dim_to(t, max_size, dim = dim)
    gathered_tensors = all_gather_same_dim(padded_t)
    gathered_tensors = torch.cat(gathered_tensors, dim = dim)

    seq = torch.arange(max_size, device = device)

    mask = einx.less('j, i -> (i j)', seq, sizes)
    seq = torch.arange(mask.shape[-1], device = device)
    indices = seq[mask]

    gathered_tensors = gathered_tensors.index_select(dim, indices)

    return gathered_tensors, sizes

# autograd Function wrapping the variable-dim all-gather, routing gradients back to the owning rank
class AllGatherFunction(Function):
    @staticmethod
    def forward(ctx, x, dim, sizes):
        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
        ctx.batch_sizes = batch_sizes.tolist()
        ctx.dim = dim
        return x, batch_sizes

    @staticmethod
    def backward(ctx, grads, _):
        batch_sizes, rank = ctx.batch_sizes, dist.get_rank()
        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
        return grads_by_rank[rank], None, None

# Module wrapper around the all-gather Function
class AllGather(Module):
    def __init__(self, *, dim = 0):
        super().__init__()
        self.dim = dim

    def forward(self, x, sizes = None):
        return AllGatherFunction.apply(x, self.dim, sizes)

# split a gathered tensor (or tuple of tensors) and return the piece belonging to this rank
def split_by_rank(x):
    rank = dist.get_rank()
    out = x[rank]

    if isinstance(x, tuple):
        sizes = tuple(map(lambda t: t.shape[0], x))
    else:
        sizes = (x.shape[1],) * x.shape[0]

    sizes = torch.tensor(sizes, device = out.device, dtype = torch.long)
    return out, sizes
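A minimal sketch of these helpers in action, assuming a single-node gloo setup like the one in assert.py above (the port number is arbitrary): each rank contributes a different batch size, AllGather gathers everything, and split_by_rank recovers the local shard.

import os
import torch
import torch.multiprocessing as mp
import torch.distributed as dist

from ring_attention_pytorch.distributed import AllGather, split_by_rank

def run(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12356'   # illustrative port
    dist.init_process_group('gloo', rank = rank, world_size = world_size)

    x = torch.randn(rank + 1, 4)              # each rank has a different number of rows
    gathered, sizes = AllGather(dim = 0)(x)   # every rank now holds sum(sizes) rows

    chunks = gathered.split(sizes.tolist(), dim = 0)
    own, _ = split_by_rank(chunks)            # recover this rank's original rows
    assert torch.allclose(own, x)

    dist.destroy_process_group()

if __name__ == '__main__':
    mp.spawn(run, args = (2,), nprocs = 2, join = True)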
.\lucidrains\ring-attention-pytorch\ring_attention_pytorch\ring.py
# imports
from typing import Optional
from functools import lru_cache, partial, wraps
from collections import namedtuple

import torch
from torch import nn, Tensor
from torch.nn import Module, ModuleList
from torch.autograd import Function

import torch.distributed as dist

# helper functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

# cast a value to a tuple, repeating it if it is not already a tuple
def cast_tuple(t, length = 1):
    return t if isinstance(t, tuple) else ((t,) * length)

# decorator for caching function results
cache = partial(lru_cache, maxsize = None)

# distributed globals

@cache()
def get_rank():
    return dist.get_rank() if dist.is_initialized() else 0

@cache()
def get_world_size():
    return dist.get_world_size() if dist.is_initialized() else 1

@cache()
def is_distributed():
    return dist.is_initialized() and dist.get_world_size() > 1

# ring functions

# index one (or `num`) positions to the left on the ring, wrapping around
def circular_index_left(pos, ring_size, num = 1):
    return ((pos - num) + ring_size) % ring_size

# index one (or `num`) positions to the right on the ring, wrapping around
def circular_index_right(pos, ring_size, num = 1):
    return (pos + num) % ring_size
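For example, on a ring of size 4 the indices wrap around as expected (a quick sanity check, not part of the repository):

assert circular_index_left(0, 4) == 3
assert circular_index_right(3, 4) == 0
assert circular_index_left(2, 4, num = 2) == 0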
# distributed ring

# rank one position to the left within this rank's ring
def circular_rank_left(rank = None, ring_size = None, num = 1):
    rank = default(rank, get_rank())
    ring_size = default(ring_size, get_world_size())
    ring_set_num = rank // ring_size
    offset = ring_set_num * ring_size
    return circular_index_left(rank, ring_size, num) + offset

# rank one position to the right within this rank's ring
def circular_rank_right(rank = None, ring_size = None, num = 1):
    rank = default(rank, get_rank())
    ring_size = default(ring_size, get_world_size())
    ring_set_num = rank // ring_size
    offset = ring_set_num * ring_size
    return circular_index_right(rank, ring_size, num) + offset

# one ring pass

# send a tensor to the next rank and receive one from the previous rank
def send_and_receive_(x, receive_buffer, send_to_rank, receive_from_rank):
    send_op = dist.P2POp(dist.isend, x, send_to_rank)
    recv_op = dist.P2POp(dist.irecv, receive_buffer, receive_from_rank)

    reqs = dist.batch_isend_irecv([send_op, recv_op])

    for req in reqs:
        req.wait()

    dist.barrier()

# pass a tensor once around the ring, returning the received buffer and the tensor that was sent
def ring_pass(
    num_ring_passes: int,
    x: Tensor,
    receive_buffer: Optional[Tensor] = None,
    ring_size: Optional[int] = None
):
    ring_size = default(ring_size, get_world_size())
    x = x.contiguous()

    if not exists(receive_buffer):
        receive_buffer = torch.zeros_like(x)
    else:
        receive_buffer = receive_buffer.contiguous()

    send_and_receive_(x, receive_buffer, circular_rank_right(ring_size = ring_size), circular_rank_left(ring_size = ring_size))
    return receive_buffer, x

one_ring_pass = partial(ring_pass, 1)

# iterators over all ring passes of all tensors

# named tuple carrying the current ring rank plus (is_first, is_last) iteration flags
RingInfo = namedtuple('RingInfo', ['ring_rank', 'iter_info'])

# no-op ring pass, used when not distributed: yields the tensors exactly once
def null_ring_pass(*tensors, max_iters = None, receive_buffers = None, ring_size = None):
    yield RingInfo(0, (True, True)), (tensors, receive_buffers)

# iterate over all ring passes, rotating the tensors one rank per iteration
def all_ring_pass(*tensors, max_iters = None, receive_buffers = None, ring_size = None):
    ring_size = default(ring_size, get_world_size())
    max_iters = default(max_iters, ring_size)

    receive_buffers = cast_tuple(receive_buffers, len(tensors))

    # make sure the number of iterations is between 1 and the ring size
    total_iters = max(1, min(ring_size, max_iters))

    curr_ring_pos = get_rank()

    for ind in range(total_iters):
        is_first = ind == 0
        is_last = ind == (total_iters - 1)

        yield RingInfo(curr_ring_pos, (is_first, is_last)), (tensors, receive_buffers)

        curr_ring_pos = circular_index_left(curr_ring_pos, ring_size)

        if is_last:
            continue

        new_tensors = []
        new_receive_buffers = []

        for tensor, receive_buffer in zip(tensors, receive_buffers):
            if exists(tensor):
                new_tensor, new_receive_buffer = one_ring_pass(tensor, receive_buffer, ring_size)
            else:
                new_tensor, new_receive_buffer = None, None

            new_tensors.append(new_tensor)
            new_receive_buffers.append(new_receive_buffer)

        tensors = new_tensors
        receive_buffers = new_receive_buffers
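In the non-distributed case, null_ring_pass is the stand-in for all_ring_pass: it yields the tensors once, flagged as both the first and the last iteration (a quick check, not part of the repository):

x = torch.randn(2, 4)
for (ring_rank, (is_first, is_last)), ((t,), _) in null_ring_pass(x):
    assert ring_rank == 0 and is_first and is_last and t is x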
.\lucidrains\ring-attention-pytorch\ring_attention_pytorch\ring_attention.py
# imports
from typing import Optional, Tuple, Union

import torch
from torch import nn, einsum, Tensor
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.nn import Module, ModuleList

import einx
from einx import rearrange

from beartype import beartype

# ring and distributed helpers from this package
from ring_attention_pytorch.ring import (
    all_ring_pass,
    is_distributed,
    get_rank,
    get_world_size
)

from ring_attention_pytorch.ring_flash_attention import (
    ring_flash_attn
)

from ring_attention_pytorch.distributed import (
    split_by_rank,
    AllGather
)

# helper functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def cast_tuple(t, length = 1):
    return t if isinstance(t, tuple) else ((t,) * length)

def divisible_by(num, den):
    return (num % den) == 0

# regular (non-flash, non-ring) attention, used as reference and fallback
def default_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Optional[Tensor] = None,
    causal: bool = False
):
    q = q * (q.shape[-1] ** -0.5)

    mask_value = -torch.finfo(q.dtype).max

    # similarity
    sim = einsum('b i h d, b j h d -> b h i j', q, k)

    # masking
    if causal:
        i, j = sim.shape[-2:]
        causal_mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)
        sim = torch.where(causal_mask, mask_value, sim)

    elif exists(mask):
        sim = einx.where('b j, b h i j, -> b h i j', mask, sim, mask_value)

    # attention
    attn = einx.softmax('b h i [j]', sim)

    # aggregate values
    out = einsum('b h i j, b j h d -> b i h d', attn, v)

    return out
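A quick shape check for default_attention (not part of the repository), using the same (batch, seq, heads, dim_head) layout as assert_flash.py above:

q = torch.randn(2, 62, 2, 16)
k = torch.randn(2, 62, 2, 16)
v = torch.randn(2, 62, 2, 16)
out = default_attention(q, k, v, causal = True)
assert out.shape == (2, 62, 2, 16)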
# rotary embeddings, modified to support striped attention

class RingRotaryEmbedding(Module):
    def __init__(
        self,
        dim,
        ring: bool = False,
        striped: bool = False,
        buckets: int = 1,   # for striped attention with flash buckets > 1, the number of buckets per machine needs to be specified
        theta = 10000
    ):
        super().__init__()
        self.ring = ring
        self.striped = striped
        self.buckets = buckets

        inv_freq = theta ** -(torch.arange(0, dim, 2).float() / dim)
        self.register_buffer('inv_freq', inv_freq)

    @property
    def device(self):
        return self.inv_freq.device

    @autocast(enabled = False)
    def forward(
        self,
        seq_len: int,
        offset = 0
    ):
        device = self.device
        pos = None

        if self.ring:
            if self.striped:
                # striped layout: positions are interleaved across ranks and buckets
                buckets = self.buckets
                ring_stride = get_world_size() * buckets
                ring_offset = buckets

                pos = torch.arange(seq_len // buckets, device = device)
                pos = rearrange('n -> n b', pos, b = buckets)

                pos = pos * ring_stride
                pos += torch.arange(buckets, device = device) + (get_rank() * buckets)
                pos = rearrange('n b -> (b n)', pos)

            else:
                # plain ring layout: each rank owns a contiguous chunk of positions
                pos = torch.arange(seq_len, device = device)
                pos += seq_len * get_rank()
        else:
            pos = torch.arange(seq_len, device = device)

        pos = pos.type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', pos, self.inv_freq)
        return torch.cat((freqs, freqs), dim = -1)

# rotate half the feature dimension
def rotate_half(x):
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim=-1)

@autocast(enabled = False)
def apply_rotary_pos_emb(pos, t):
    pos = rearrange('n d -> n 1 d', pos)
    return t * pos.cos() + rotate_half(t) * pos.sin()
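A small sketch of the rotary helpers in the non-ring setting (not part of the repository): positions are a plain arange, and the resulting frequencies rotate a (batch, seq, heads, dim_head) tensor.

rotary = RingRotaryEmbedding(dim = 16)   # ring = False, so no rank-dependent offsets
freqs = rotary(seq_len = 8)              # (8, 16) after concatenating freqs with itself
t = torch.randn(2, 8, 4, 16)
rotated = apply_rotary_pos_emb(freqs, t)
assert rotated.shape == t.shape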
# batch-to-sequence sharding and back

# pad a tensor's sequence (last) dimension up to a multiple of `length`
def pad_to_multiple(
    x: Tensor,
    length: int,
    pad_value = 0
):
    seq_len = x.shape[-1]
    remainder = seq_len % length

    if remainder == 0:
        return x, 0

    pad_length = length - remainder
    return F.pad(x, (0, pad_length), value = pad_value), pad_length
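For example (a quick check, not part of the repository), a length-31 sequence of token ids padded up to a multiple of 8 gains one position:

ids = torch.randint(0, 256, (2, 31))
padded, pad_length = pad_to_multiple(ids, 8)
assert padded.shape[-1] == 32 and pad_length == 1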
# pad the sequence and mask if needed, since the ring passes assume equal shapes on every rank
def maybe_pad_seq_and_mask(
    x: Tensor,
    mask: Optional[Tensor],
    seq_size: int
):
    orig_x, seq_len = x, x.shape[-1]

    # pad the sequence up to a multiple of seq_size
    x, pad_length = pad_to_multiple(x, seq_size)

    # if no padding was needed, return the sequence and mask unchanged
    if pad_length == 0:
        return x, mask

    # if no mask was given, create an all-True mask over the original positions
    if not exists(mask):
        mask = torch.ones_like(orig_x).bool()

    # pad the mask with False over the newly added positions
    mask, _ = pad_to_multiple(mask, seq_size, pad_value = False)

    return x, mask

def sharded_batch_to_sharded_seq(
    x: Tensor,
    mask: Optional[Tensor],
    seq_size: int
):
    assert is_distributed()

    # all-gather across the batch dimension
    all_gather = AllGather(dim = 0)

    x, sizes = all_gather(x)

    if exists(mask):
        mask, _ = all_gather(mask)

    # the world size must be divisible by the number of sequence splits
    world_size = get_world_size()
    total_split_seq = x.shape[-1] // seq_size
    assert divisible_by(world_size, total_split_seq)

    num_sharded_batches = world_size // total_split_seq

    # fold the gathered batches into the sequence dimension, then shard across the sequence
    x = rearrange('(b s) n -> b (s n)', x, s = num_sharded_batches)

    x = x.split(seq_size, dim = -1)

    # take the chunk belonging to this rank
    x, _ = split_by_rank(x)

    if exists(mask):
        # treat the mask the same way
        mask = rearrange('(b s) n -> b (s n)', mask, s = num_sharded_batches)
        mask = mask.split(seq_size, dim = -1)
        mask, _ = split_by_rank(mask)

    return (x, mask), sizes, num_sharded_batches

def sharded_seq_to_sharded_batch(
    logits: Tensor,
    sizes,
    num_sharded_batches = 1
):
    all_gather = AllGather(dim = -2) # all-gather across the sequence dimension

    logits, _ = all_gather(logits)

    # unfold the sequence back into the batch dimension, then split by rank
    logits = rearrange('b (s n) c -> (b s) n c', logits, s = num_sharded_batches)

    logits = logits.split(sizes.tolist(), dim = 0)

    logits, _ = split_by_rank(logits)

    return logits
# main RingAttention class

class RingAttention(Module):
    @beartype
    def __init__(
        self,
        dim: int,
        *,
        dim_head: int = 64,
        heads: int = 8,
        causal: bool = False,
        eps: float = 1e-10,
        bucket_size: int = 512,
        ring_attn: bool = False,
        ring_seq_size: int = 512,
        max_lookback_seq_len: Optional[int] = None,
        striped_ring_attn: bool = False,
        auto_shard_seq: Optional[bool] = None,
        prenorm: bool = True,
        force_regular_attn: bool = False,
        rotary_embed: bool = False,
        rotary_embed_theta: int = 10000,
        use_cuda_kernel: bool = None
    ):
        super().__init__()
        self.eps = eps
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.causal = causal

        assert divisible_by(ring_seq_size, bucket_size)

        self.ring_attn = ring_attn
        self.max_lookback_seq_len = max_lookback_seq_len
        self.striped_ring_attn = striped_ring_attn

        self.force_regular_attn = force_regular_attn
        self.auto_shard_seq = default(auto_shard_seq, ring_attn) # this should be done at the transformer level on the token ids, but kept here for testing purposes

        assert not (not self.ring_attn and self.auto_shard_seq)

        self.ring_seq_size = ring_seq_size
        self.bucket_size = bucket_size

        # rotary embeddings
        self.rotary_embed = None
        if rotary_embed:
            self.rotary_embed = RingRotaryEmbedding(
                dim = dim_head,
                ring = ring_attn,
                striped = striped_ring_attn,
                theta = rotary_embed_theta,
                buckets = ring_seq_size // bucket_size
            )

        # projections
        dim_inner = dim_head * heads
        self.to_qkv = nn.Sequential(
            RMSNorm(dim) if prenorm else nn.Identity(),
            nn.Linear(dim, dim_inner * 3, bias = False)
        )

        self.to_out = nn.Linear(dim_inner, dim, bias = False)

        # whether to use the flash attention cuda kernel
        self.use_cuda_kernel = default(use_cuda_kernel, torch.cuda.is_available())
        assert not (use_cuda_kernel and not torch.cuda.is_available())
    def forward(
        self,
        x,
        mask = None,
        rotary_emb = None,
        force_ring_reduce_off = False,
        ring_size = None,
    ):
        """
        einstein notation

        b - batch
        h - heads
        d - feature dimension
        n, i, j - sequence
        """

        # default the ring size to the world size
        ring_size = default(ring_size, get_world_size())

        # ring attention and auto sequence sharding are only active when running distributed
        ring_attn = self.ring_attn & is_distributed()
        auto_shard_seq = self.auto_shard_seq & is_distributed()

        seq_len = x.shape[-1]

        if auto_shard_seq:
            # pad the sequence and mask up to the ring sequence size
            x, mask = maybe_pad_seq_and_mask(x, mask, self.ring_seq_size)

            # permute the sequence for striped ring attention (workload balancing)
            if self.striped_ring_attn:
                x = rearrange('b (i j) d -> b (j i) d', x, i = self.bucket_size)

                if exists(mask):
                    mask = rearrange('b (i j) -> b (j i)', mask, i = self.bucket_size)

            # gather across the batch and shard across the sequence
            (x, mask), batch_sizes = sharded_batch_to_sharded_seq(x, mask, self.ring_seq_size)

        device = x.device

        # project to queries, keys, values
        qkv = self.to_qkv(x)
        q, k, v = rearrange('b n (qkv h d) -> qkv b n h d', qkv, qkv = 3, h = self.heads)

        # rotary relative positions

        if not exists(rotary_emb) and exists(self.rotary_embed):
            rotary_emb = self.rotary_embed(q.shape[-2])

        if exists(rotary_emb):
            q = apply_rotary_pos_emb(rotary_emb, q)
            k = apply_rotary_pos_emb(rotary_emb, k)

        # regular attention vs flash attention (with or without kv ring reduce)

        any_cuda_inputs = any([t.is_cuda for t in (q, k, v)])

        if self.force_regular_attn:
            # fall back to the reference attention
            out = default_attention(q, k, v, mask = mask, causal = self.causal)

        elif any_cuda_inputs and self.use_cuda_kernel:
            # cuda flash attention kernel
            from ring_attention_pytorch.ring_flash_attention_cuda import ring_flash_attn_cuda

            out = ring_flash_attn_cuda(
                q, k, v,
                mask,
                self.causal,
                self.bucket_size,
                ring_attn and not force_ring_reduce_off,
                self.striped_ring_attn and not force_ring_reduce_off,
                self.max_lookback_seq_len,
                ring_size
            )

        else:
            # pure python flash attention
            out = ring_flash_attn(
                q, k, v,
                mask,
                self.causal,
                self.bucket_size,
                ring_attn and not force_ring_reduce_off,
                self.striped_ring_attn and not force_ring_reduce_off,
                self.max_lookback_seq_len,
                ring_size
            )

        # merge heads and project out
        out = rearrange('b n h d -> b n (h d)', out)
        out = self.to_out(out)

        if auto_shard_seq:
            # gather the sequence shards back into the batch and trim to the original length
            out, _ = sharded_seq_to_sharded_batch(out, batch_sizes)
            out = out[:, :seq_len]

        return out
# simple transformer for end-to-end testing

class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        # scale factor and learned gamma
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # rms-normalize, then rescale by the scale factor and gamma
        return F.normalize(x, dim = -1) * self.scale * self.gamma

# feedforward block
def FeedForward(dim, mult = 4):
    dim_inner = int(dim * mult)
    return nn.Sequential(
        RMSNorm(dim),
        nn.Linear(dim, dim_inner),
        nn.GELU(),
        nn.Linear(dim_inner, dim)
    )

# ring transformer

class RingTransformer(Module):
    @beartype
    def __init__(
        self,
        *,
        num_tokens: int,
        dim: int,
        depth: int,
        causal: bool = False,
        dim_head: int = 64,
        heads: int = 8,
        ff_mult: int = 4,
        bucket_size: int = 512,
        ring_attn: bool = False,
        striped_ring_attn: bool = False,
        ring_seq_size: int = 512,
        auto_shard_seq: Optional[bool] = None,
        max_lookback_seq_len: Optional[Union[Tuple[int, ...], int]] = None,
        rotary_embed_theta: int = 10000,    # needs to be changed for context lengths in the millions of tokens
        ignore_index: int = -1
    ):
        super().__init__()
        # ring attention related settings
        self.ring_attn = ring_attn
        self.striped_ring_attn = striped_ring_attn

        self.ring_seq_size = ring_seq_size
        self.bucket_size = bucket_size
        assert divisible_by(ring_seq_size, bucket_size)

        self.auto_shard_seq = default(auto_shard_seq, ring_attn) # if ring attention is turned on, auto-shard across the sequence dimension; this can also be turned off and done manually elsewhere in the data loading

        assert not (not self.ring_attn and self.auto_shard_seq)
        assert not (not self.ring_attn and self.striped_ring_attn)
        assert not (self.striped_ring_attn and not causal), 'striped ring attention only applies to autoregressive models'

        # token embedding
        self.token_emb = nn.Embedding(num_tokens, dim)

        # rotary embedding, aware of the ring / striped layout
        self.rotary_emb = RingRotaryEmbedding(
            dim = dim_head,
            ring = ring_attn,
            striped = striped_ring_attn,
            theta = rotary_embed_theta,
            buckets = ring_seq_size // bucket_size
        )

        # transformer layers
        self.layers = ModuleList([])

        max_lookback_seq_len = cast_tuple(max_lookback_seq_len, depth)
        assert len(max_lookback_seq_len) == depth

        for layer_max_lookback_seq_len in max_lookback_seq_len:
            self.layers.append(ModuleList([
                RingAttention(
                    dim = dim,
                    causal = causal,
                    dim_head = dim_head,
                    heads = heads,
                    bucket_size = bucket_size,
                    ring_attn = ring_attn,
                    ring_seq_size = ring_seq_size,
                    max_lookback_seq_len = layer_max_lookback_seq_len,
                    striped_ring_attn = striped_ring_attn,
                    auto_shard_seq = False,
                ),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

        # output head
        self.to_logits = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, num_tokens, bias = False)
        )

        # training related
        self.ignore_index = ignore_index

    def forward(
        self,
        x,
        mask = None,
        labels = None,
        return_loss = False,
        force_ring_reduce_off = False,
        ring_size = None
    ):
        seq_len, device = x.shape[-1], x.device

        # auto-shard the sequence only when ring reduce is not forced off, auto sharding is enabled, and running distributed
        auto_shard_seq = not force_ring_reduce_off and self.auto_shard_seq and is_distributed()

        # if labels are passed in, a loss is returned
        return_loss |= exists(labels)

        # if returning a loss without explicit labels, derive them by shifting the input tokens
        if return_loss and not exists(labels):
            x, labels = x[:, :-1], x[:, 1:]

        # handle padding so the sequence can be split across machines

        ring_size = default(ring_size, get_world_size())

        if auto_shard_seq:
            # first pad on the right up to a multiple of the ring sequence size
            x, mask = maybe_pad_seq_and_mask(x, mask, self.ring_seq_size)

            # handle the labels the same way, masking out padded positions with the ignore index
            if exists(labels):
                labels, label_mask = maybe_pad_seq_and_mask(labels, mask[:, 1:], self.ring_seq_size)
                labels.masked_fill_(~label_mask, self.ignore_index)

            # account for striped attention, for workload balancing
            if self.striped_ring_attn:
                x = rearrange('b (i j) -> b (j i)', x, i = self.bucket_size)

                if exists(labels):
                    labels = rearrange('b (i j) -> b (j i)', labels, i = self.bucket_size)

                if exists(mask):
                    mask = rearrange('b (i j) -> b (j i)', mask, i = self.bucket_size)

            # gather across the batch and divide the sequence across the world
            (x, mask), batch_sizes, num_sharded_batches = sharded_batch_to_sharded_seq(x, mask, self.ring_seq_size)

            if exists(labels):
                (labels, _), *_ = sharded_batch_to_sharded_seq(labels, None, self.ring_seq_size)

            # derive the ring size from the number of sharded batches
            ring_size = get_world_size() // num_sharded_batches

        # rotary positions, taking the ring and striping into account
        rotary_emb = self.rotary_emb(x.shape[-1])

        # main transformer logic
        x = self.token_emb(x)

        for attn, ff in self.layers:
            x = attn(
                x,
                mask = mask,
                rotary_emb = rotary_emb,
                force_ring_reduce_off = force_ring_reduce_off,
                ring_size = ring_size
            ) + x

            x = ff(x) + x

        logits = self.to_logits(x)

        # handle returning the cross entropy loss
        if return_loss:
            logits = rearrange('b n c -> b c n', logits)

            ce_loss = F.cross_entropy(
                logits,
                labels,
                ignore_index = self.ignore_index
            )

            return ce_loss

        # otherwise gather the sequence chunks across machines for the logits and re-shard the batch dimension
        if not auto_shard_seq:
            return logits

        logits = sharded_seq_to_sharded_batch(logits, batch_sizes, num_sharded_batches)

        if self.striped_ring_attn:
            logits = rearrange('b (i j) d -> b (j i) d', logits, j = self.bucket_size)

        return logits[:, :seq_len]
.\lucidrains\ring-attention-pytorch\ring_attention_pytorch\ring_flash_attention.py
# imports
import math
from functools import partial
from typing import Optional

import torch
from torch import nn, einsum, Tensor
from torch.autograd.function import Function

import einx
from einx import rearrange

from ring_attention_pytorch.ring import (
    ring_pass,
    all_ring_pass,
    null_ring_pass,
    one_ring_pass,
    get_rank,
    get_world_size
)

from beartype import beartype

# constants

EPSILON = 1e-10

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def divisible_by(num, den):
    return (num % den) == 0

# iterator that yields None forever
def none_iterator():
    while True:
        yield None

# split a tensor into chunks, or yield None forever if the tensor does not exist
def maybe_split(t, size, dim = -2):
    if not exists(t):
        return none_iterator()

    return t.split(size, dim = dim)

# ring + (flash) attention forwards and backwards

# flash attention v1 - https://arxiv.org/abs/2205.14135
# flash attention v2 - https://tridao.me/publications/flash2/flash2.pdf
# ring attention - https://arxiv.org/abs/2310.01889

class RingFlashAttentionFunction(Function):

    @staticmethod
    @torch.no_grad()
    def forward(
        ctx,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        mask: Optional[Tensor],
        causal: bool,
        bucket_size: int,
        ring_reduce_col: bool,
        striped_ring_attn: bool,
        max_lookback_seq_len: Optional[int],
        ring_size: Optional[int]
    ):
        # (forward body omitted from this excerpt)
        ...

    @staticmethod
    @torch.no_grad()
    # (the backward static method follows here; its body is omitted from this excerpt)

ring_flash_attn_ = RingFlashAttentionFunction.apply

# functional wrapper around the autograd Function
@beartype
def ring_flash_attn(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Optional[Tensor] = None,
    causal: bool = False,
    bucket_size: int = 1024,
    ring_reduce_col: bool = False,
    striped_ring_attn: bool = False,
    max_lookback_seq_len: Optional[int] = None,
    ring_size: Optional[int] = None
):
    return ring_flash_attn_(q, k, v, mask, causal, bucket_size, ring_reduce_col, striped_ring_attn, max_lookback_seq_len, ring_size)
.\lucidrains\ring-attention-pytorch\ring_attention_pytorch\ring_flash_attention_cuda.py
# imports
import math
from functools import partial
from typing import Optional, Tuple

import packaging.version as pkg_version

import torch
from torch import nn, einsum, Tensor
import torch.nn.functional as F
from torch.autograd.function import Function

from ring_attention_pytorch.ring import (
    ring_pass,
    all_ring_pass,
    null_ring_pass,
    one_ring_pass,
    get_rank,
    get_world_size
)

from beartype import beartype

from einops import repeat, rearrange

# helper: check whether a value exists
def exists(v):
    return v is not None

# pad a tensor at a given dimension with a (left, right) amount
def pad_at_dim(t, pad: Tuple[int, int], *, dim = -1, value = 0.):
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)
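For example (a quick check, not part of the repository), padding the last dimension on the right versus padding an inner dimension on both sides:

t = torch.randn(1, 4, 2, 16)
assert pad_at_dim(t, (0, 16), dim = -1).shape == (1, 4, 2, 32)
assert pad_at_dim(t, (1, 1), dim = -2).shape == (1, 4, 4, 16)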
# check whether a tensor is contiguous in its last dimension
def is_contiguous(x):
    return x.stride(-1) == 1

# make sure flash-attn is installed, for the backwards pass

import importlib
from importlib.metadata import version

assert exists(importlib.util.find_spec('flash_attn')), 'flash-attn must be installed. `pip install flash-attn --no-build-isolation` first'

flash_attn_version = version('flash_attn')
assert pkg_version.parse(flash_attn_version) >= pkg_version.parse('2.5.1')

from flash_attn.flash_attn_interface import (
    _flash_attn_varlen_backward,
    _flash_attn_backward
)

# make sure triton is installed, for the forwards pass

assert exists(importlib.util.find_spec('triton')), 'latest triton must be installed. `pip install triton -U` first'

triton_version = version('triton')
assert pkg_version.parse(triton_version) >= pkg_version.parse('2.1')

import triton
import triton.language as tl

# the flash attention forwards below is taken from Tri's flash_attn repository, modified to return the
# unnormalized accumulated output, the row maxes and the row lse - reduced over the ring passes
@triton.heuristics(
{
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}
)
@triton.jit
def _fwd_kernel(
Q,
K,
V,
Bias,
Out,
M,
Lse,
softmax_scale,
stride_qb,
stride_qh,
stride_qm,
stride_kb,
stride_kh,
stride_kn,
stride_vb,
stride_vh,
stride_vn,
stride_bb,
stride_bh,
stride_bm,
stride_ob,
stride_oh,
stride_om,
nheads,
seqlen_q,
seqlen_k,
seqlen_q_rounded,
headdim,
CACHE_KEY_SEQLEN_Q,
CACHE_KEY_SEQLEN_K,
HAS_BIAS: tl.constexpr,
IS_CAUSAL: tl.constexpr,
CAUSAL_MASK_DIAGONAL: tl.constexpr,
LOAD_ACCUMULATED: tl.constexpr,
RETURN_NORMALIZED_OUTPUT: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr,
EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hb = tl.program_id(1)
off_b = off_hb // nheads
off_h = off_hb % nheads
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_HEADDIM)
q_ptrs = (
Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
)
k_ptrs = (
K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
)
v_ptrs = (
V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
)
if HAS_BIAS:
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
# running row maxes
m_ptrs = M + off_hb * seqlen_q_rounded + offs_m
if LOAD_ACCUMULATED:
m_i = tl.load(m_ptrs)
else:
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
# load the lse (log-sum-exp)
lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
if LOAD_ACCUMULATED:
lse_i = tl.load(lse_ptrs)
else:
# otherwise initialize a [BLOCK_M] float32 tensor filled with negative infinity
lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
# offsets for loading the accumulated output
offs_d = tl.arange(0, BLOCK_HEADDIM)
# compute the output pointer locations
out_ptrs = (
Out
+ off_b * stride_ob
+ off_h * stride_oh
+ (offs_m[:, None] * stride_om + offs_d[None, :])
)
# load the previously accumulated output if requested
if LOAD_ACCUMULATED:
# no masking needed along the query rows (seqlen_q is a multiple of BLOCK_M)
if EVEN_M:
# no masking needed along the head dimension (headdim equals BLOCK_HEADDIM)
if EVEN_HEADDIM:
acc_o = tl.load(out_ptrs)
else:
acc_o = tl.load(out_ptrs, mask=offs_d[None, :] < headdim)
else:
# query rows need masking; additionally mask the head dimension if it is not full width
if EVEN_HEADDIM:
acc_o = tl.load(out_ptrs, mask=offs_m[:, None] < seqlen_q)
else:
acc_o = tl.load(
out_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
)
acc_o = acc_o.to(tl.float32)
else:
# start from a zero [BLOCK_M, BLOCK_HEADDIM] float32 accumulator
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
# load the query block (keys and values are loaded inside the loop below)
if EVEN_M & EVEN_N:
# both seqlen_q and seqlen_k divide evenly into the block sizes, so no row/column masking
if EVEN_HEADDIM:
q = tl.load(q_ptrs)
else:
q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
else:
# otherwise mask out-of-range query rows (and head dims if needed)
if EVEN_HEADDIM:
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
else:
q = tl.load(
q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
)
# compute the last key column to visit, limited by the causal constraint if applicable
end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
# iterate over key/value column blocks of size BLOCK_N
for start_n in range(0, end_n, BLOCK_N):
# hint to the compiler that start_n is a multiple of BLOCK_N
start_n = tl.multiple_of(start_n, BLOCK_N)
# load the key block, masking if needed
if EVEN_N & EVEN_M:
# no column masking needed
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn)
else:
k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0,
)
else:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0,
)
# initialize the qk logits block
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# qk = q @ k^T
qk += tl.dot(q, tl.trans(k))
# mask out columns beyond seqlen_k
if not EVEN_N:
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
# apply the causal mask
if IS_CAUSAL:
if CAUSAL_MASK_DIAGONAL:
# masking out the diagonal as well is needed for striped attention
qk += tl.where(offs_m[:, None] > (start_n + offs_n)[None, :], 0, float("-inf"))
else:
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
# add the bias (used to carry the key padding mask), then update the running softmax statistics
if HAS_BIAS:
if EVEN_N:
bias = tl.load(b_ptrs + start_n)
else:
bias = tl.load(
b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0
)
bias = bias[None, :]
bias = bias.to(tl.float32)
qk = qk * softmax_scale + bias
m_ij = tl.maximum(tl.max(qk, 1), lse_i)
p = tl.exp(qk - m_ij[:, None])
else:
m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
p = tl.exp(qk * softmax_scale - m_ij[:, None])
# row sums of the exponentiated logits
l_ij = tl.sum(p, 1)
# rescale the existing accumulator by the change in the row max
acc_o_scale = tl.exp(m_i - m_ij)
acc_o = acc_o * acc_o_scale[:, None]
# load the value block, masking if needed
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
v = tl.load(v_ptrs + start_n * stride_vn)
else:
v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0,
)
else:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0,
)
# cast p to v's dtype before the matmul
p = p.to(v.dtype)
acc_o += tl.dot(p, v)
# -- update the running statistics
m_i = m_ij
l_i_new = tl.exp(lse_i - m_ij) + l_ij
lse_i = m_ij + tl.log(l_i_new)
# normalize the output if requested (done on the last ring pass)
if RETURN_NORMALIZED_OUTPUT:
acc_o_scale = tl.exp(m_i - lse_i)
acc_o = acc_o * acc_o_scale[:, None]
# recompute the offsets for m and lse
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# write back lse and m
tl.store(lse_ptrs, lse_i)
if not RETURN_NORMALIZED_OUTPUT:
tl.store(m_ptrs, m_i)
# write the output
if EVEN_M:
if EVEN_HEADDIM:
tl.store(out_ptrs, acc_o)
else:
tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
else:
if EVEN_HEADDIM:
tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
else:
tl.store(
out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
)
# flash attention forward pass: host-side wrapper around the triton kernel above
def flash_attn_forward(
q,
k,
v,
bias = None,
causal = False,
o = None,
m = None,
lse = None,
softmax_scale = None,
causal_mask_diagonal = False,
return_normalized_output = False,
load_accumulated = True
):
# make q, k, v contiguous if they are not already
q, k, v = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v)]
# shapes
batch, seqlen_q, nheads, d = q.shape
_, seqlen_k, _, _ = k.shape
# shape, head dimension, dtype and device checks
assert k.shape == (batch, seqlen_k, nheads, d)
assert v.shape == (batch, seqlen_k, nheads, d)
assert d <= 128, "FlashAttention only support head dimensions up to 128"
assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
assert q.is_cuda and k.is_cuda and v.is_cuda
# softmax scale defaults to 1 / sqrt(head dimension)
softmax_scale = default(softmax_scale, d ** -0.5)
# handle the optional bias (used to carry the key padding mask)
has_bias = exists(bias)
if has_bias:
assert bias.dtype in [q.dtype, torch.float]
assert bias.is_cuda
# expand a 2d (batch, seqlen_k) bias across heads and query positions
if bias.ndim == 2:
bias = repeat(bias, 'b j -> b h i j', h = nheads, i = seqlen_q)
# make the bias contiguous if needed
if not is_contiguous(bias):
bias = bias.contiguous()
assert bias.shape[-2:] == (1, seqlen_k)
bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
# bias strides (zeros when there is no bias)
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
# round the query sequence length up to a multiple of 128
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
# initialize the lse tensor (filled with a large negative value if accumulation is to be loaded)
if not exists(lse):
max_neg_value = -torch.finfo(torch.float32).max
init_fn = partial(torch.full, fill_value = max_neg_value) if load_accumulated else torch.empty
lse = init_fn((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
# initialize the row-max tensor m in the same way
if not exists(m):
max_neg_value = -torch.finfo(torch.float32).max
init_fn = partial(torch.full, fill_value = max_neg_value) if load_accumulated else torch.empty
m = init_fn((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
# initialize the output tensor o
if not exists(o):
init_fn = torch.zeros_like if load_accumulated else torch.empty_like
o = init_fn(q)
# kernel launch configuration: BLOCK_HEADDIM, block size, warps and grid
BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
BLOCK = 128
num_warps = 4 if d <= 64 else 8
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
# launch the forward kernel
_fwd_kernel[grid](
q,
k,
v,
bias,
o,
m,
lse,
softmax_scale,
q.stride(0),
q.stride(2),
q.stride(1),
k.stride(0),
k.stride(2),
k.stride(1),
v.stride(0),
v.stride(2),
v.stride(1),
*bias_strides,
o.stride(0),
o.stride(2),
o.stride(1),
nheads,
seqlen_q,
seqlen_k,
seqlen_q_rounded,
d,
seqlen_q // 32,
seqlen_k // 32,
has_bias,
causal,
causal_mask_diagonal,
load_accumulated,
return_normalized_output,
BLOCK_HEADDIM,
BLOCK_M = BLOCK,
BLOCK_N = BLOCK,
num_warps = num_warps,
num_stages = 1,
)
# return the (possibly unnormalized) output, the row maxes, and the lse
return o, m, lse
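For reference, a hedged sketch of calling flash_attn_forward directly (requires a CUDA device, triton >= 2.1, and half-precision tensors in (batch, seq, heads, dim) layout; the shapes here are illustrative):

# illustrative only: single-pass call with no accumulation to load
if torch.cuda.is_available():
    q = torch.randn(1, 256, 4, 64, device = 'cuda', dtype = torch.float16)
    k = torch.randn(1, 256, 4, 64, device = 'cuda', dtype = torch.float16)
    v = torch.randn(1, 256, 4, 64, device = 'cuda', dtype = torch.float16)

    o, m, lse = flash_attn_forward(
        q, k, v,
        causal = True,
        load_accumulated = False,
        return_normalized_output = True
    )
    assert o.shape == q.shape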
# helper functions (duplicated here from the non-cuda module)

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def divisible_by(num, den):
    return (num % den) == 0
# ring + (flash) attention forwards and backwards
# flash attention v1 - https://arxiv.org/abs/2205.14135
# flash attention v2 - https://tridao.me/publications/flash2/flash2.pdf
# ring attention - https://arxiv.org/abs/2310.01889
class RingFlashAttentionCUDAFunction(Function):

    @staticmethod
    @torch.no_grad()
    def forward(
        ctx,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        mask: Optional[Tensor],
        causal: bool,
        bucket_size: int,
        ring_reduce_col: bool,
        striped_ring_attn: bool,
        max_lookback_seq_len: Optional[int],
        ring_size: Optional[int]
    ):
        # (forward body omitted from this excerpt)
        ...

    @staticmethod
    @torch.no_grad()
    # (the backward static method follows here; its body is omitted from this excerpt)

ring_flash_attn_cuda_ = RingFlashAttentionCUDAFunction.apply

# functional wrapper around the cuda ring flash attention Function
@beartype
def ring_flash_attn_cuda(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Optional[Tensor] = None,
    causal: bool = False,
    bucket_size: int = 1024,
    ring_reduce_col: bool = False,
    striped_ring_attn: bool = False,
    max_lookback_seq_len: Optional[int] = None,
    ring_size: Optional[int] = None
):
    return ring_flash_attn_cuda_(q, k, v, mask, causal, bucket_size, ring_reduce_col, striped_ring_attn, max_lookback_seq_len, ring_size)