Lucidrains Series Projects: Source Code Analysis (58)
.\lucidrains\memorizing-transformers-pytorch\setup.py
# import setuptools' setup and package discovery helpers
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'memorizing-transformers-pytorch',           # package name
  packages = find_packages(exclude=[]),               # find all packages
  version = '0.4.1',                                  # version number
  license='MIT',                                      # license
  description = 'Memorizing Transformer - Pytorch',   # description
  long_description_content_type = 'text/markdown',    # long description content type
  author = 'Phil Wang',                               # author
  author_email = 'lucidrains@gmail.com',              # author email
  url = 'https://github.com/lucidrains/memorizing-transformers-pytorch',  # project URL
  keywords = [                                        # keywords
    'artificial intelligence',
    'deep learning',
    'transformers',
    'memory',
    'retrieval'
  ],
  install_requires=[                                  # dependencies
    'einops>=0.6',
    'filelock',
    'joblib',
    'faiss-gpu',
    'numpy',
    'torch>=1.6',
  ],
  classifiers=[                                       # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
.\lucidrains\memorizing-transformers-pytorch\train.py
# import required libraries
from memorizing_transformers_pytorch import MemorizingTransformer

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 16
SEQ_LEN = 512
SEGMENTS = 5

LEARNING_RATE = 2e-4
MAX_GRAD_CLIP_NORM = 0.5

VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 512

# helper functions

# cycle a dataloader into an infinite stream of batches
def cycle(loader):
    while True:
        for data in loader:
            yield data

# decode a single token back into a character
def decode_token(token):
    return str(chr(max(32, token)))

# decode a sequence of tokens back into a string
def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# instantiate a GPT-like decoder model

model = MemorizingTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    memorizing_layers = 4,
    max_knn_memories = 512 * 15,
    num_retrieved_memories = 32,
    xl_memory_layers = (7, 8),
    xl_max_memories = 512,
).cuda()

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    # np.frombuffer replaces the deprecated np.fromstring; copy to get a writable array
    X = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

# text sampler dataset

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # sample a random window of (seq_len + 1) tokens, the extra token supplying the labels
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# datasets and dataloaders

train_dataset = TextSamplerDataset(data_train, SEQ_LEN * SEGMENTS)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))

valid_dataset = TextSamplerDataset(data_val, SEQ_LEN * SEGMENTS)
valid_loader = cycle(DataLoader(valid_dataset, batch_size = BATCH_SIZE, drop_last = True))

# optimizer

optim = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
    model.train()

    data = next(train_loader)

    train_loss = 0.
    # knn memories are created per training step and cleared when the context exits
    with model.knn_memories_context(batch_size = BATCH_SIZE) as knn_memories:
        xl_memories = None

        seq, labels = data[:, :-1], data[:, 1:]

        # process the long sequence segment by segment, threading the xl memories through
        for seq_segment, labels_segment in zip(seq.chunk(SEGMENTS, dim = -1), labels.chunk(SEGMENTS, dim = -1)):
            loss, xl_memories = model(
                seq_segment,
                labels = labels_segment,
                knn_memories = knn_memories,
                xl_memories = xl_memories
            )

            train_loss += loss.item() / SEGMENTS
            (loss / SEGMENTS).backward()

    print(f'training loss: {train_loss}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_CLIP_NORM)

    optim.step()
    optim.zero_grad()

    if not (i % VALIDATE_EVERY):
        model.eval()

        valid_data = next(valid_loader)

        valid_loss = 0.
        with torch.no_grad(), model.knn_memories_context(batch_size = BATCH_SIZE) as knn_memories:
            xl_memories = None

            seq, labels = valid_data[:, :-1], valid_data[:, 1:]

            for seq_segment, labels_segment in zip(seq.chunk(SEGMENTS, dim = -1), labels.chunk(SEGMENTS, dim = -1)):
                loss, xl_memories = model(
                    seq_segment,
                    labels = labels_segment,
                    knn_memories = knn_memories,
                    xl_memories = xl_memories
                )

                valid_loss += loss.item() / SEGMENTS

        print(f'valid loss: {valid_loss}')
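The GENERATE_EVERY and GENERATE_LENGTH constants are defined above but no sampling step appears in this excerpt of train.py. A minimal sketch of what such a step could look like follows; it assumes that calling the model without labels yields next-token logits (possibly alongside the xl memories), which is an assumption about the forward signature rather than something shown in this excerpt.

# hypothetical sampling step, mirroring the validation block above
# assumes model(seq, knn_memories = ...) yields logits of shape (b, n, num_tokens) when labels is omitted
@torch.no_grad()
def generate_sample(prime, length = GENERATE_LENGTH, temperature = 1.):
    model.eval()
    out = prime  # (1, n) tensor of token ids on the same device as the model
    with model.knn_memories_context(batch_size = 1) as knn_memories:
        for _ in range(length):
            ret = model(out[:, -SEQ_LEN:], knn_memories = knn_memories)
            logits = ret[0] if isinstance(ret, tuple) else ret  # tolerate (logits, xl_memories) returns
            probs = F.softmax(logits[:, -1] / temperature, dim = -1)
            sample = torch.multinomial(probs, 1)
            out = torch.cat((out, sample), dim = -1)
    return decode_tokens(out[0].tolist())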
.\lucidrains\memory-compressed-attention\memory_compressed_attention\memory_compressed_attention.py
import torch
from torch import nn
import torch.nn.functional as F

# convolutional compression of the key / value sequence

class ConvCompress(nn.Module):
    def __init__(self, dim, ratio = 3, groups = 1):
        super().__init__()
        # strided 1d convolution folds every `ratio` timesteps into one
        self.conv = nn.Conv1d(dim, dim, ratio, stride = ratio, groups = groups)

    def forward(self, mem):
        mem = mem.transpose(1, 2)
        compressed_mem = self.conv(mem)
        return compressed_mem.transpose(1, 2)

# main class

class MemoryCompressedAttention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        causal = False,
        compression_factor = 3,
        dropout = 0.
    ):
        super().__init__()
        assert (dim % heads) == 0, 'dimension must be divisible by number of heads'
        self.heads = heads
        self.causal = causal
        self.compression_factor = compression_factor

        self.compress_fn = ConvCompress(dim, compression_factor, groups = heads)

        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.to_out = nn.Linear(dim, dim)

        self.dropout = nn.Dropout(dropout)

        # learned null key / value, so every query always has at least one key to attend to
        self.null_k = nn.Parameter(torch.zeros(1, 1, dim))
        self.null_v = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x, input_mask = None):
        b, t, d, h, cf, device = *x.shape, self.heads, self.compression_factor, x.device

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # left-pad keys and values so the sequence length is divisible by the compression factor
        padding = cf - (t % cf)
        if padding < cf:
            k, v = map(lambda t: F.pad(t, (0, 0, padding, 0)), (k, v))

        # compress keys and values
        k, v = map(self.compress_fn, (k, v))

        # attach the null key / value, for the edge case where the first query has no key to attend to
        nk, nv = map(lambda t: t.expand(b, -1, -1), (self.null_k, self.null_v))
        k = torch.cat((nk, k), dim = 1)
        v = torch.cat((nv, v), dim = 1)

        # split out heads
        q, k, v = map(lambda t: t.reshape(*t.shape[:2], h, -1).transpose(1, 2), (q, k, v))

        # attention
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * d ** -0.5

        mask_value = -torch.finfo(dots.dtype).max

        # causal masking, if needed
        if self.causal:
            mask_q = mask_k = torch.arange(t, device = device)

            if padding < cf:
                mask_k = F.pad(mask_k, (padding, 0))

            # a compressed key is assigned the largest position it covers
            mask_k, _ = mask_k.reshape(-1, cf).max(dim = -1)
            mask = mask_q[:, None] < mask_k[None, :]
            # the prepended null key stays visible to every query
            mask = F.pad(mask, (1, 0), value = False)

            dots.masked_fill_(mask[None, None, ...], mask_value)
            del mask

        # input masking
        if input_mask is not None:
            mask_q = mask_k = input_mask

            if padding < cf:
                mask_k = F.pad(mask_k, (padding, 0), value = True)

            # a compressed key is kept if any position it covers is unmasked
            mask_k = mask_k.reshape(b, -1, cf).sum(dim = -1) > 0
            mask = mask_q[:, None, :, None] * mask_k[:, None, None, :]
            mask = F.pad(mask, (1, 0), value = True)
            dots.masked_fill_(~mask, mask_value)
            del mask

        # attention weights
        attn = dots.softmax(dim = -1)

        # dropout
        attn = self.dropout(attn)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)

        # merge heads back together
        out = out.transpose(1, 2).reshape(b, t, d)
        return self.to_out(out)
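To make the compression arithmetic concrete, here is the key-length bookkeeping for the defaults above, with an illustrative sequence length of 1024 and compression factor 3:

# left-padded so the length divides evenly, then compressed, then one null key prepended
t, cf = 1024, 3
padding = cf - (t % cf)                             # 3 - (1024 % 3) = 2
padded_len = t + (padding if padding < cf else 0)   # 1026
compressed_len = padded_len // cf                   # 342 compressed keys / values
total_keys = compressed_len + 1                     # plus the null key -> 343
print(total_keys)                                   # 343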
.\lucidrains\memory-compressed-attention\memory_compressed_attention\__init__.py
# import the MemoryCompressedAttention class from the memory_compressed_attention.memory_compressed_attention module
from memory_compressed_attention.memory_compressed_attention import MemoryCompressedAttention

Memory Compressed Attention
Implementation of the Self-Attention layer of the proposed Memory-Compressed Attention, in Pytorch. This repository offers both the causal and non-causal variants, and will take care of the padding if the sequence length is not divisible by the compression ratio.
The code also resolves an edge case where the very first query has no keys to attend to in the auto-regressive scenario. The solution is to use null key/values, appended to the final compressed set, so that there is always at least one key for all queries to attend to.
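A minimal sketch of the null key/value trick in isolation (shapes chosen for illustration): a learned key/value pair is expanded across the batch and prepended, so the softmax always has at least one entry to attend to.

import torch
from torch import nn

b, j, d = 2, 341, 512                        # batch, compressed keys, model dimension
null_k = nn.Parameter(torch.zeros(1, 1, d))  # learned null key
null_v = nn.Parameter(torch.zeros(1, 1, d))  # learned null value

k, v = torch.randn(b, j, d), torch.randn(b, j, d)

# expand across the batch and prepend, guaranteeing a target for every query
k = torch.cat((null_k.expand(b, -1, -1), k), dim = 1)  # (2, 342, 512)
v = torch.cat((null_v.expand(b, -1, -1), v), dim = 1)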
Install
$ pip install memory_compressed_attention
Usage
import torch
from memory_compressed_attention import MemoryCompressedAttention
attn = MemoryCompressedAttention(
    dim = 512,
    heads = 8,                 # number of heads
    causal = False,            # auto-regressive or not
    compression_factor = 3,    # compression ratio
    dropout = 0.1              # dropout post-attention
)
x = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()
attn(x, input_mask = mask) # (1, 1024, 512)
Citations
@misc{liu2018generating,
    title   = {Generating Wikipedia by Summarizing Long Sequences},
    author  = {Peter J. Liu and Mohammad Saleh and Etienne Pot and Ben Goodrich and Ryan Sepassi and Lukasz Kaiser and Noam Shazeer},
    year    = {2018},
    eprint  = {1801.10198},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
.\lucidrains\memory-compressed-attention\setup.py
# import setuptools' setup and package discovery helpers
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'memory_compressed_attention',               # package name
  packages = find_packages(),                         # find all packages
  version = '0.0.7',                                  # version number
  license='MIT',                                      # license
  description = 'Memory-Compressed Self Attention',   # description
  long_description_content_type = 'text/markdown',    # long description content type
  author = 'Phil Wang',                               # author
  author_email = 'lucidrains@gmail.com',              # author email
  url = 'https://github.com/lucidrains/memory-compressed-attention',  # project URL
  keywords = ['transformers', 'artificial intelligence', 'attention mechanism'],  # keywords
  install_requires=[
    'torch'                                           # required dependency
  ],
  classifiers=[
    'Development Status :: 4 - Beta',                 # development status
    'Intended Audience :: Developers',                # intended audience
    'Topic :: Scientific/Engineering :: Artificial Intelligence',  # topic
    'License :: OSI Approved :: MIT License',         # license
    'Programming Language :: Python :: 3.6',          # programming language
  ],
)
My explorations into editing the knowledge and memories of an attention network.
Citations
@article{meng2022memit,
    title   = {Mass Editing Memory in a Transformer},
    author  = {Kevin Meng and Sen Sharma, Arnab and Alex Andonian and Yonatan Belinkov and David Bau},
    journal = {arXiv preprint arXiv:2210.07229},
    year    = {2022}
}
@inproceedings{Burns2022DiscoveringLK,
    title   = {Discovering Latent Knowledge in Language Models Without Supervision},
    author  = {Collin Burns and Hao-Tong Ye and Dan Klein and Jacob Steinhardt},
    year    = {2022}
}
Data source
The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/
.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\autoregressive_wrapper.py
import torch
from torch import nn
import torch.nn.functional as F

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# decorator that runs a function with the model in eval mode,
# restoring the original training state afterwards
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# top-k filtering of the logits

def top_k(logits, thres = 0.9):
    # number of logits to keep
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    # fill everything with -inf, then scatter the top-k values back in
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# autoregressive wrapper class

class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.net = net
        self.max_seq_len = net.max_seq_len

    # sequence generation
    @torch.no_grad()
    @eval_decorator
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_thres = 0.9, **kwargs):
        b, t, device = *start_tokens.shape, start_tokens.device

        out = start_tokens

        for _ in range(seq_len):
            # condition on at most the last max_seq_len tokens
            x = out[:, -self.max_seq_len:]

            # logits for the next token
            logits = self.net(x, **kwargs)[:, -1, :]

            # top-k filter, then sample with temperature
            filtered_logits = top_k(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim = -1)
            sample = torch.multinomial(probs, 1)

            # append the sampled token to the output sequence
            out = torch.cat((out, sample), dim = -1)

            if exists(eos_token):
                is_eos_tokens = (out == eos_token)

                # stop once every sequence contains an eos token
                if is_eos_tokens.any(dim = -1).all():
                    # build an eos mask shifted one position to the right
                    shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
                    # mark every position after the first eos token
                    mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
                    # fill the marked positions with the pad value
                    out = out.masked_fill(mask, self.pad_value)
                    break

        # strip the starting tokens, returning only what was generated
        out = out[:, t:]
        return out

    def forward(self, x, **kwargs):
        # split into inputs and next-token labels
        x_inp, x_labels = x[:, :-1], x[:, 1:]
        return self.net(x_inp, labels = x_labels, **kwargs)
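A minimal usage sketch, pairing the wrapper with the Transformer class defined later in this package (token counts and lengths are illustrative):

import torch
from memory_efficient_attention_pytorch.transformer import Transformer
from memory_efficient_attention_pytorch.autoregressive_wrapper import AutoregressiveWrapper

model = AutoregressiveWrapper(Transformer(
    num_tokens = 256,
    max_seq_len = 1024,
    dim = 512,
    depth = 4,
    causal = True
))

x = torch.randint(0, 256, (1, 1024))
loss = model(x)                         # next-token cross entropy
loss.backward()

prime = torch.randint(0, 256, (1, 32))
sampled = model.generate(prime, 128)    # (1, 128) newly generated tokens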
.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\cosine_sim_flash_attention.py
# import required libraries
import math
import torch
from functools import partial
from torch import nn, einsum
import torch.nn.functional as F
from torch.autograd.function import Function

from einops import rearrange

# constants

EPSILON = 1e-6

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return the value if it exists, otherwise a default
def default(val, d):
    return val if exists(val) else d

# l2 normalize along the feature dimension
def l2norm(t):
    return F.normalize(t, dim = -1)

# flash attention forwards and backwards, specialized for cosine-sim attention

class FlashAttentionFunction(Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, scale, causal, q_bucket_size, k_bucket_size):
        device = q.device
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        k_len = k.shape[-2]  # in cosine-sim attention, the row sums are bounded by the key / value sequence length

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), device = device)

        # prepare the input mask, chunked along the buckets
        if not exists(mask):
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = rearrange(mask, 'b n -> b 1 1 n')
            mask = mask.split(q_bucket_size, dim = -1)

        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            mask,
            all_row_sums.split(q_bucket_size, dim = -2),
        )

        # iterate over the row (query) buckets
        for ind, (qc, oc, row_mask, row_sums) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
            )

            # iterate over the column (key / value) buckets
            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                # attention weights for this tile
                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                if exists(row_mask):
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                # causal masking, for tiles straddling the diagonal
                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                # cosine-sim attention weights are bounded above by `scale`,
                # so the usual running row max can be replaced by this constant
                attn_weights -= scale
                exp_weights = torch.exp(attn_weights)

                if exists(row_mask):
                    exp_weights.masked_fill_(~row_mask, 0.)

                block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)

                exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)

                oc.add_(exp_values / k_len)
                row_sums.add_(block_row_sums)

        # save parameters and intermediates for the backwards pass
        ctx.args = (scale, causal, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums)

        # rescale the output
        o.mul_(k_len / all_row_sums)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        # unpack the saved arguments and tensors
        scale, causal, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, l = ctx.saved_tensors

        device = q.device
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        # gradient accumulators
        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        # split along the query buckets
        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            do.split(q_bucket_size, dim = -2),
            mask,
            l.split(q_bucket_size, dim = -2),
            dq.split(q_bucket_size, dim = -2)
        )

        for ind, (qc, oc, doc, row_mask, lc, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            # split along the key / value buckets
            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
                dk.split(k_bucket_size, dim = -2),
                dv.split(k_bucket_size, dim = -2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                # recompute the attention weights for this tile
                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                exp_attn_weights = torch.exp(attn_weights - scale)

                if exists(row_mask):
                    exp_attn_weights.masked_fill_(~row_mask, 0.)

                # normalized attention probabilities
                p = exp_attn_weights / lc

                dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
                dp = einsum('... i d, ... j d -> ... i j', doc, vc)

                # D term from the flash attention backwards derivation
                D = (doc * oc).sum(dim = -1, keepdims = True)
                ds = p * scale * (dp - D)

                dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
                dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)

                # accumulate the gradients
                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None, None

# main class

# flash attention for cosine-sim attention
# considerably simpler, since softmax numerical stability is no longer a concern and the row sums are bounded

class FlashAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        scale = 16,
        heads = 8,
        dim_head = 64,
        causal = False,
        q_bucket_size = 512,
        k_bucket_size = 1024
    ):
        super().__init__()
        self.heads = heads
        self.scale = scale
        self.causal = causal

        inner_dim = heads * dim_head

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # memory-efficient attention related parameters
        # can be overridden on forward
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

    def forward(
        self,
        x,
        context = None,
        mask = None,
        q_bucket_size = None,
        k_bucket_size = None,
    ):
        q_bucket_size = default(q_bucket_size, self.q_bucket_size)
        k_bucket_size = default(k_bucket_size, self.k_bucket_size)

        h = self.heads
        context = default(context, x)

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # l2 normalize queries and keys, then scale by a fixed constant
        q, k = map(l2norm, (q, k))

        out = FlashAttentionFunction.apply(q, k, v, mask, self.scale, self.causal, q_bucket_size, k_bucket_size)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
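A minimal usage sketch for the module above (shapes illustrative); note that the fixed scale plays the role that 1/sqrt(dim_head) plays in regular attention:

import torch
from memory_efficient_attention_pytorch.cosine_sim_flash_attention import FlashAttention

attn = FlashAttention(
    dim = 512,
    scale = 16,         # fixed temperature on the cosine similarities
    heads = 8,
    causal = True,
    q_bucket_size = 512,
    k_bucket_size = 1024
)

x = torch.randn(1, 4096, 512)
out = attn(x)           # (1, 4096, 512)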
.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\flash_attention.py
import math
import torch
from functools import partial
from torch import nn, einsum
from torch.autograd.function import Function

from einops import rearrange

# constants

EPSILON = 1e-10

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return the value if it exists, otherwise a default
def default(val, d):
    return val if exists(val) else d

# flash attention forwards and backwards
# flash attention v1 - https://arxiv.org/abs/2205.14135
# flash attention v2 - https://tridao.me/publications/flash2/flash2.pdf

class FlashAttentionFunction(Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """ Algorithm 1 in the v2 paper """

        device = q.device
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        # accumulators: output, running row sums and running row maxes
        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), device = device)
        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device = device)

        # scaling factor
        scale = (q.shape[-1] ** -0.5)

        # number of row and column tiles
        num_row_tiles = math.ceil(q.shape[-2] / q_bucket_size)
        num_col_tiles = math.ceil(k.shape[-2] / k_bucket_size)

        # normalize the mask into a (row tiles, col tiles) nested structure
        if exists(mask) and mask.ndim == 2:
            mask = rearrange(mask, 'b n -> b 1 1 n')

        if not exists(mask):
            col_masks = (None,) * num_col_tiles
            mask = (col_masks,) * num_row_tiles
        else:
            mask = ((mask,) * num_row_tiles) if mask.shape[-2] == 1 else mask.split(q_bucket_size, dim = -2)
            mask = tuple(((row_mask,) * num_col_tiles) if row_mask.shape[-1] == 1 else row_mask.split(k_bucket_size, dim = -1) for row_mask in mask)

        # split along the rows
        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            mask,
            all_row_sums.split(q_bucket_size, dim = -2),
            all_row_maxes.split(q_bucket_size, dim = -2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            # split along the columns
            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
                row_mask
            )

            for k_ind, (kc, vc, col_mask) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                # attention weights for this tile
                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                if exists(col_mask):
                    attn_weights.masked_fill_(~col_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                # online softmax: update the running max, renormalizing previous accumulations
                block_row_maxes = attn_weights.amax(dim = -1, keepdims = True)
                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_weights = torch.exp(attn_weights - new_row_maxes)

                if exists(col_mask):
                    exp_weights.masked_fill_(~col_mask, 0.)

                block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)

                exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)

                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)

                new_row_sums = exp_row_max_diff * row_sums + block_row_sums

                oc.mul_(exp_row_max_diff).add_(exp_values)

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

            # final per-row normalization, once per row tile (the v2 improvement)
            oc.div_(row_sums)

        # logsumexp, saved for the backwards pass
        lse = all_row_sums.log() + all_row_maxes

        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, lse)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """ Algorithm 2 in the v2 paper """

        # unpack the saved arguments and tensors
        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, lse = ctx.saved_tensors

        device = q.device
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        # gradient accumulators
        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        # split q, o, do, mask, lse, dq along the query buckets
        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            do.split(q_bucket_size, dim = -2),
            mask,
            lse.split(q_bucket_size, dim = -2),
            dq.split(q_bucket_size, dim = -2)
        )

        for ind, (qc, oc, doc, row_mask, lsec, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            # split k, v, dk, dv, row_mask along the key / value buckets
            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
                dk.split(k_bucket_size, dim = -2),
                dv.split(k_bucket_size, dim = -2),
                row_mask
            )

            for k_ind, (kc, vc, dkc, dvc, col_mask) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                # recompute the attention weights for this tile
                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                # recover the attention probabilities from the saved logsumexp
                p = torch.exp(attn_weights - lsec)

                if exists(col_mask):
                    p.masked_fill_(~col_mask, 0.)

                dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
                dp = einsum('... i d, ... j d -> ... i j', doc, vc)

                # D and ds terms from the flash attention backwards derivation
                D = (doc * oc).sum(dim = -1, keepdims = True)
                ds = p * scale * (dp - D)

                dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
                dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)

                # accumulate the gradients
                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None

# main class

# flash attention in plain pytorch
# it will be much slower than a CUDA implementation
# for tinkering and educational purposes

class FlashAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,                   # input dimension
        heads = 8,             # number of heads
        dim_head = 64,         # dimension per head
        causal = False,        # whether to use causal attention
        q_bucket_size = 512,   # bucket size along the queries
        k_bucket_size = 1024   # bucket size along the keys / values
    ):
        super().__init__()
        self.heads = heads
        self.causal = causal

        inner_dim = heads * dim_head

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # memory-efficient attention related parameters
        # can be overridden on forward
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

    def forward(
        self,
        x,
        context = None,
        mask = None,
        q_bucket_size = None,
        k_bucket_size = None,
    ):
        q_bucket_size = default(q_bucket_size, self.q_bucket_size)
        k_bucket_size = default(k_bucket_size, self.k_bucket_size)

        h = self.heads
        context = default(context, x)  # self attention if no context is given

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        # split out heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        out = FlashAttentionFunction.apply(q, k, v, mask, self.causal, q_bucket_size, k_bucket_size)

        # merge heads back together
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
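As a quick numerical sanity check, the tiled function can be compared against a reference softmax attention (shapes and tolerance chosen for illustration):

import torch
from memory_efficient_attention_pytorch.flash_attention import FlashAttentionFunction

b, h, n, d = 1, 2, 256, 64
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

# reference: plain softmax attention
ref = ((q @ k.transpose(-1, -2)) * d ** -0.5).softmax(dim = -1) @ v

# tiled flash attention forward (mask = None, causal = False), small buckets to exercise the tiling
out = FlashAttentionFunction.apply(q, k, v, None, False, 64, 128)

assert torch.allclose(ref, out, atol = 1e-5)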
.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\memory_efficient_attention.py
import torch
from functools import partial
from torch import nn, einsum
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F

from einops import rearrange

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return the value if it exists, otherwise a default
def default(val, d):
    return val if exists(val) else d

# regular attention

def attention(
    q, k, v,
    mask = None,
    causal = False,
    attn_bias = None,
    **kwargs
):
    # scale the queries
    scale = q.shape[-1] ** -0.5
    q = q * scale

    # attention scores
    sim = einsum('b h i d, b h j d -> b h i j', q, k)

    # additive attention bias
    if exists(attn_bias):
        sim = sim + attn_bias

    mask_value = -torch.finfo(sim.dtype).max

    # key padding mask
    if exists(mask):
        if mask.ndim == 2:
            mask = rearrange(mask, 'b j -> b 1 1 j')
        sim = sim.masked_fill(~mask, mask_value)

    # causal mask
    if causal:
        i, j = sim.shape[-2:]
        mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
        sim = sim.masked_fill(mask, mask_value)

    # numerically stable softmax
    sim = sim - sim.amax(dim = -1, keepdim = True).detach()
    attn = sim.softmax(dim = -1)

    # aggregate the values
    out = einsum('b h i j, b h j d -> b h i d', attn, v)
    return out

# memory efficient attention

def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices, dropout):
    q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device

    weight = einsum('b h i d, b h j d -> b h i j', q, k)

    # additive attention bias
    if exists(attn_bias_chunk):
        weight = weight + attn_bias_chunk

    mask_value = -torch.finfo(weight.dtype).max

    # key padding mask
    if exists(mask):
        mask = rearrange(mask, 'b j -> b 1 1 j')
        weight = weight.masked_fill(~mask, mask_value)

    # causal mask, for chunks straddling the diagonal
    if causal and q_start_index < (k_start_index + k_chunk_size - 1):
        causal_mask = torch.ones((q_chunk_size, k_chunk_size), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
        weight = weight.masked_fill(causal_mask, mask_value)

    # subtract the per-chunk max for stability; returned so chunks can be renormalized globally
    weight_max = weight.amax(dim = -1, keepdim = True).detach()
    weight = weight - weight_max

    exp_weight = weight.exp()

    exp_weight = F.dropout(exp_weight, p = dropout)

    weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)

    return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...')

# checkpointed version, so intermediate attention matrices are recomputed rather than stored
checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk)

def memory_efficient_attention(
    q, k, v,
    mask = None,
    causal = False,
    attn_bias = None,
    q_bucket_size = 512,
    k_bucket_size = 1024,
    eps = 1e-8,
    dropout = 0.,
    training = False
):
    # scale the queries
    scale = q.shape[-1] ** -0.5
    q = q * scale

    # only checkpoint if a backwards pass will be needed
    needs_backwards = q.requires_grad or k.requires_grad or v.requires_grad
    summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk

    # chunk all the inputs
    q_chunks = q.split(q_bucket_size, dim = -2)
    k_chunks = k.split(k_bucket_size, dim = -2)
    v_chunks = v.split(k_bucket_size, dim = -2)
    mask_chunks = mask.split(k_bucket_size, dim = -1) if exists(mask) else ((None,) * len(k_chunks))

    if exists(attn_bias):
        i, j = attn_bias.shape[-2:]
        attn_bias_chunks = attn_bias.split(q_bucket_size, dim = -2)
        attn_bias_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), attn_bias_chunks))

    # loop through all chunks and accumulate
    out = []

    for q_index, q_chunk in enumerate(q_chunks):
        exp_weights = []
        weighted_values = []
        weight_maxes = []

        for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)):
            q_start_index = q_index * q_bucket_size
            k_start_index = k_index * k_bucket_size

            # with causal attention, skip key chunks entirely in the future of this query chunk
            if causal and k_start_index > (q_start_index + q_chunk.shape[-2] - 1):
                continue

            attn_bias_chunk = attn_bias_chunks[q_index][k_index] if exists(attn_bias) else None

            # summarize this (query chunk, key chunk) tile
            exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn(
                q_chunk,
                k_chunk,
                v_chunk,
                mask_chunk,
                attn_bias_chunk,
                causal,
                (q_start_index, k_start_index),
                dropout if training else 0.
            )

            exp_weights.append(exp_weight_chunk)
            weighted_values.append(weighted_value_chunk)
            weight_maxes.append(weight_max_chunk)

        weight_maxes = torch.stack(weight_maxes, dim = -1)
        weighted_values = torch.stack(weighted_values, dim = -1)
        exp_weights = torch.stack(exp_weights, dim = -1)

        # renormalize each chunk against the global max over chunks
        global_max = weight_maxes.amax(dim = -1, keepdim = True)
        renorm_factor = (weight_maxes - global_max).exp().detach()

        exp_weights = exp_weights * renorm_factor
        weighted_values = weighted_values * rearrange(renorm_factor, '... c -> ... 1 c')

        # sum the weighted values and the weights over chunks, then normalize
        all_values = weighted_values.sum(dim = -1)
        all_weights = exp_weights.sum(dim = -1)

        normalized_values = all_values / (rearrange(all_weights, '... -> ... 1') + eps)
        out.append(normalized_values)

    return torch.cat(out, dim = -2)

# main attention class

class Attention(nn.Module):
    def __init__(
        self,
        *,
        dim,                        # input dimension
        heads = 8,                  # number of heads
        dim_head = 64,              # dimension per head
        dropout = 0.,               # dropout probability
        causal = False,             # whether to use causal attention
        memory_efficient = False,   # whether to use the memory-efficient attention path
        q_bucket_size = 512,        # bucket size along the queries
        k_bucket_size = 1024        # bucket size along the keys / values
    ):
        super().__init__()
        self.heads = heads
        self.causal = causal
        self.dropout = dropout

        inner_dim = heads * dim_head

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # memory-efficient attention related parameters
        # can be overridden on forward
        self.memory_efficient = memory_efficient
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_bias = None,
        memory_efficient = None,
        q_bucket_size = None,
        k_bucket_size = None,
    ):
        memory_efficient = default(memory_efficient, self.memory_efficient)
        q_bucket_size = default(q_bucket_size, self.q_bucket_size)
        k_bucket_size = default(k_bucket_size, self.k_bucket_size)

        h = self.heads
        context = default(context, x)  # self attention if no context is given

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        # split out heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # pick the regular or the memory-efficient attention function
        attn_fn = attention if not memory_efficient else memory_efficient_attention

        out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size,
                      k_bucket_size = k_bucket_size, dropout = self.dropout, training = self.training)

        # merge heads back together
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
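The chunked function can be checked against the regular one defined in the same file (illustrative shapes; dropout left at zero so both paths are deterministic):

import torch
from memory_efficient_attention_pytorch.memory_efficient_attention import attention, memory_efficient_attention

q, k, v = (torch.randn(1, 8, 2048, 64) for _ in range(3))

ref = attention(q, k, v, causal = True)
out = memory_efficient_attention(q, k, v, causal = True, q_bucket_size = 512, k_bucket_size = 1024)

assert torch.allclose(ref, out, atol = 1e-5)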
.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\memory_efficient_cosine_sim_attention.py
import math
import torch
import torch.nn.functional as F
from functools import partial
from torch import nn, einsum
from torch.utils.checkpoint import checkpoint

from einops import rearrange

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return the value if it exists, otherwise a default
def default(val, d):
    return val if exists(val) else d

# l2 normalize along the feature dimension
def l2norm(t):
    return F.normalize(t, dim = -1)

# regular attention

def attention(
    q, k, v,
    mask = None,
    causal = False,
    attn_bias = None,
    **kwargs
):
    # similarity between queries and keys
    sim = einsum('b h i d, b h j d -> b h i j', q, k)

    # additive attention bias
    if exists(attn_bias):
        sim = sim + attn_bias

    mask_value = -torch.finfo(sim.dtype).max

    # key padding mask
    if exists(mask):
        mask = rearrange(mask, 'b j -> b 1 1 j')
        sim = sim.masked_fill(~mask, mask_value)

    # causal mask
    if causal:
        i, j = sim.shape[-2:]
        mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
        sim = sim.masked_fill(mask, mask_value)

    # attention weights
    attn = sim.softmax(dim = -1)

    # aggregate the values
    out = einsum('b h i j, b h j d -> b h i d', attn, v)
    return out

# memory efficient attention

def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices):
    q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device

    weight = einsum('b h i d, b h j d -> b h i j', q, k)

    if exists(attn_bias_chunk):
        weight = weight + attn_bias_chunk

    mask_value = -torch.finfo(weight.dtype).max

    if exists(mask):
        mask = rearrange(mask, 'b j -> b 1 1 j')
        weight = weight.masked_fill(~mask, mask_value)

    if causal and q_start_index < (k_start_index + k_chunk_size - 1):
        causal_mask = torch.ones((q_chunk_size, k_chunk_size), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
        weight = weight.masked_fill(causal_mask, mask_value)

    # cosine-sim attention weights are bounded, so no per-chunk max subtraction is done here
    exp_weight = weight.exp()
    weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)

    return exp_weight.sum(dim = -1), weighted_value

# checkpointed version of the chunk summarizer
checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk)

# memory-efficient attention without the running-max renormalization, hence "numerically unstable"
def numerically_unstable_memory_efficient_attention(
    q, k, v,
    mask = None,
    causal = False,
    attn_bias = None,
    q_bucket_size = 512,
    k_bucket_size = 1024,
    eps = 1e-8
):
    needs_backwards = q.requires_grad or k.requires_grad or v.requires_grad
    summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk

    # chunk all the inputs
    q_chunks = q.split(q_bucket_size, dim = -2)
    k_chunks = k.split(k_bucket_size, dim = -2)
    v_chunks = v.split(k_bucket_size, dim = -2)
    mask_chunks = mask.split(k_bucket_size, dim = -1) if exists(mask) else ((None,) * len(k_chunks))

    if exists(attn_bias):
        i, j = attn_bias.shape[-2:]
        attn_bias_chunks = attn_bias.split(q_bucket_size, dim = -2)
        attn_bias_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), attn_bias_chunks))

    # loop through all chunks and accumulate
    out = []

    for q_index, q_chunk in enumerate(q_chunks):
        q_start_index = q_index * q_bucket_size
        exp_weights = []
        weighted_values = []

        for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)):
            k_start_index = k_index * k_bucket_size

            # with causal attention, skip key chunks entirely in the future of this query chunk
            if causal and k_start_index > (q_start_index + q_chunk.shape[-2] - 1):
                continue

            attn_bias_chunk = attn_bias_chunks[q_index][k_index] if exists(attn_bias) else None

            # summarize this (query chunk, key chunk) tile
            exp_weight_chunk, weighted_value_chunk = summarize_qkv_fn(
                q_chunk,
                k_chunk,
                v_chunk,
                mask_chunk,
                attn_bias_chunk,
                causal,
                (q_start_index, k_start_index)
            )

            exp_weights.append(exp_weight_chunk)
            weighted_values.append(weighted_value_chunk)

        # sum the weighted values and weights over chunks, then normalize
        all_values = sum(weighted_values)
        all_weights = sum(exp_weights)

        normalized_values = all_values / (rearrange(all_weights, '... -> ... 1') + eps)
        out.append(normalized_values)

    return torch.cat(out, dim = -2)

# main class

class CosineSimAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        seq_len,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        causal = False,
        memory_efficient = False,
        q_bucket_size = 512,
        k_bucket_size = 1024
    ):
        super().__init__()
        self.heads = heads
        self.causal = causal

        inner_dim = heads * dim_head

        # initialize the learned per-head scale (exponentiated on forward)
        scale_init_value = -math.log(math.log2(seq_len ** 2 - seq_len))
        self.scale = nn.Parameter(torch.full((1, heads, 1, 1), scale_init_value))

        # linear projections from the input dimension to the inner dimension
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # memory-efficient attention related parameters
        # can be overridden on forward
        self.memory_efficient = memory_efficient
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_bias = None,
        memory_efficient = None,
        q_bucket_size = None,
        k_bucket_size = None,
    ):
        memory_efficient = default(memory_efficient, self.memory_efficient)
        q_bucket_size = default(q_bucket_size, self.q_bucket_size)
        k_bucket_size = default(k_bucket_size, self.k_bucket_size)

        h = self.heads
        context = default(context, x)

        # project the input to queries, keys and values
        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        # split out heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # l2 normalize queries and keys
        q, k = map(l2norm, (q, k))

        # scale queries by the exponentiated learned scale
        q = q * self.scale.exp()

        # pick the regular or the memory-efficient attention function
        attn_fn = attention if not memory_efficient else numerically_unstable_memory_efficient_attention

        out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size)

        # merge heads back together
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
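A minimal usage sketch for the module above; note that seq_len is required only to initialize the learned scale:

import torch
from memory_efficient_attention_pytorch import CosineSimAttention

attn = CosineSimAttention(
    dim = 512,
    seq_len = 4096,
    heads = 8,
    causal = True,
    memory_efficient = True,
    q_bucket_size = 512,
    k_bucket_size = 1024
)

x = torch.randn(1, 4096, 512)
out = attn(x)  # (1, 4096, 512)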
.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\reversible.py
import torch
import torch.nn as nn
from operator import itemgetter
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states

# route keyword arguments into the f / g functions of the reversible layers

def route_args(router, args, depth):
    routed_args = [(dict(), dict()) for _ in range(depth)]
    matched_keys = [key for key in args.keys() if key in router]

    for key in matched_keys:
        val = args[key]
        for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
            # route the argument into the f and / or g keyword arguments
            new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
            routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
    return routed_args

# following the example for saving and setting rng at https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html

class Deterministic(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        # snapshot the cpu (and, if active, gpu) rng states
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        # replay the recorded rng states so the recomputation is bit-identical
        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        with torch.random.fork_rng(devices = rng_devices, enabled = True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)

# heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# once multi-GPU is confirmed working, refactor and send a PR back to the source

class ReversibleBlock(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, f_args = {}, g_args = {}):
        x1, x2 = torch.chunk(x, 2, dim = 2)
        y1, y2 = None, None

        # additive coupling: y1 = x1 + f(x2), y2 = x2 + g(y1)
        with torch.no_grad():
            y1 = x1 + self.f(x2, record_rng = self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng = self.training, **g_args)

        return torch.cat([y1, y2], dim = 2)

    def backward_pass(self, y, dy, f_args = {}, g_args = {}):
        y1, y2 = torch.chunk(y, 2, dim = 2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim = 2)
        del dy

        # recompute g(y1) with the replayed rng, backprop into y1
        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng = True, **g_args)
            torch.autograd.backward(gy1, dy2)

        # invert the coupling to recover x2
        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        # recompute f(x2), backprop into x2
        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng = True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph = True)

        # invert the coupling to recover x1
        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim = 2)
            dx = torch.cat([dx1, dx2], dim = 2)

        return x, dx

class _ReversibleFunction(Function):
    @staticmethod
    def forward(ctx, x, blocks, args):
        ctx.args = args
        for block, kwarg in zip(blocks, args):
            x = block(x, **kwarg)
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x

    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        args = ctx.args
        # walk the blocks in reverse, reconstructing activations instead of reading stored ones
        for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
            y, dy = block.backward_pass(y, dy, **kwargs)
        return dy, None, None

# a sequence of reversible blocks

class ReversibleSequence(nn.Module):
    def __init__(self, blocks, args_route = {}):
        super().__init__()
        self.args_route = args_route
        # build a module list of reversible blocks
        self.blocks = nn.ModuleList([ReversibleBlock(f = f, g = g) for f, g in blocks])

    def forward(self, x, **kwargs):
        # duplicate the input along the feature dimension for the two streams
        x = torch.cat([x, x], dim = -1)

        blocks = self.blocks
        args = route_args(self.args_route, kwargs, len(blocks))
        args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args))

        layers_and_args = list(zip(blocks, args))

        out = _ReversibleFunction.apply(x, blocks, args)

        # split the two streams back apart and sum them
        return torch.stack(out.chunk(2, dim = -1)).sum(dim = 0)
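The additive coupling at the heart of ReversibleBlock can be checked in isolation: given only y1 and y2, the inputs are recovered exactly by inverting the two updates in reverse order, which is why no activations need to be stored. A minimal sketch with arbitrary deterministic f and g:

import torch
from torch import nn

f, g = nn.Linear(64, 64), nn.Linear(64, 64)

x1, x2 = torch.randn(1, 10, 64), torch.randn(1, 10, 64)

# forward coupling: y1 = x1 + f(x2), y2 = x2 + g(y1)
y1 = x1 + f(x2)
y2 = x2 + g(y1)

# inversion, using only y1 and y2
x2_rec = y2 - g(y1)
x1_rec = y1 - f(x2_rec)

assert torch.allclose(x1, x1_rec, atol = 1e-6)
assert torch.allclose(x2, x2_rec, atol = 1e-6)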
.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\transformer.py
import torch
from torch import nn, einsum
import torch.nn.functional as F
from functools import partial

from einops import rearrange

from memory_efficient_attention_pytorch import FlashAttention, Attention
from memory_efficient_attention_pytorch.reversible import ReversibleSequence

# check whether a value exists (is not None)
def exists(val):
    return val is not None

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        # layernorm before the wrapped function
        x = self.norm(x)
        return self.fn(x, **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, chunks = 1):
        super().__init__()
        self.chunks = chunks

        # two linear layers with a GELU in between
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        if self.chunks <= 1:
            return self.net(x)

        # process the sequence in chunks to save memory
        chunks = x.chunk(self.chunks, dim = 1)
        out = [self.net(chunk) for chunk in chunks]
        return torch.cat(out, dim = 1)

class Transformer(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        max_seq_len,
        dim,
        depth,
        causal = False,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        ff_chunks = 1,
        use_flash_attn = True,
        **kwargs
    ):
        super().__init__()
        self.max_seq_len = max_seq_len

        # token and learned absolute positional embeddings
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # pick flash attention or the checkpointed memory-efficient attention
        attn_klass = FlashAttention if use_flash_attn else partial(Attention, memory_efficient = True)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            # each layer pairs an attention block with a feedforward block
            self.layers.append(nn.ModuleList([
                PreNorm(dim, attn_klass(dim = dim, dim_head = dim_head, heads = heads, causal = causal, **kwargs)),
                PreNorm(dim, FeedForward(dim = dim, mult = ff_mult, chunks = ff_chunks)),
            ]))

        # wrap the attention / feedforward pairs in a reversible sequence
        self.net = ReversibleSequence(self.layers)

        # project the final hidden states to token logits
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    def forward(self, x, labels = None):
        device = x.device

        # token embedding plus positional embedding
        x = self.token_emb(x)
        pos_emb = self.pos_emb(torch.arange(x.shape[-2], device = device))
        x = x + pos_emb

        # run through the reversible layers
        x = self.net(x)

        logits = self.to_logits(x)

        # without labels, return the logits directly
        if not exists(labels):
            return logits

        # cross entropy expects (batch, classes, seq)
        return F.cross_entropy(rearrange(logits, 'b n d -> b d n'), labels)
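A minimal sketch exercising both return paths of the forward method above (shapes illustrative):

import torch
from memory_efficient_attention_pytorch.transformer import Transformer

model = Transformer(
    num_tokens = 256,
    max_seq_len = 1024,
    dim = 512,
    depth = 2,
    causal = True,
    use_flash_attn = True
)

ids = torch.randint(0, 256, (1, 1024))

logits = model(ids)                             # (1, 1024, 256)
loss = model(ids[:, :-1], labels = ids[:, 1:])  # scalar cross entropy
loss.backward()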
.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\__init__.py
# import the Attention class and the memory_efficient_attention function
from memory_efficient_attention_pytorch.memory_efficient_attention import Attention, memory_efficient_attention
# import the CosineSimAttention class and the numerically_unstable_memory_efficient_attention function
from memory_efficient_attention_pytorch.memory_efficient_cosine_sim_attention import CosineSimAttention, numerically_unstable_memory_efficient_attention
# import the FlashAttention class
from memory_efficient_attention_pytorch.flash_attention import FlashAttention
Memory Efficient Attention Pytorch (obsolete)
Implementation of a memory efficient multi-head attention as proposed in the paper, Self-attention Does Not Need O(n²) Memory. In addition, the module will take care of masking, causal masking, as well as cross attention.
This repository also contains a naive non-CUDA implementation of the improvements made by Tri Dao with his Flash Attention 2 paper, for educational purposes. It is a game changer for attention and building long-context transformers.
Update: from now on, you should just be using the F.scaled_dot_product_attention function in Pytorch 2.0 for built-in Flash Attention v1 support, or use Flash Attention v2 at the official repository.
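For reference, a minimal call to the built-in kernel mentioned above (PyTorch >= 2.0):

import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 8, 2048, 64) for _ in range(3))  # (batch, heads, seq, dim_head)

# dispatches to a fused flash attention kernel when one is available for the inputs
out = F.scaled_dot_product_attention(q, k, v, is_causal = True)  # (1, 8, 2048, 64)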
Install
$ pip install memory-efficient-attention-pytorch
Usage
For autoregressive language model
import torch
from memory_efficient_attention_pytorch import Attention
attn = Attention(
    dim = 512,
    dim_head = 64,              # dimension per head
    heads = 8,                  # number of attention heads
    causal = True,              # autoregressive or not
    memory_efficient = True,    # whether to use memory efficient attention (can be turned off to test against normal attention)
    q_bucket_size = 1024,       # bucket size along queries dimension
    k_bucket_size = 2048        # bucket size along key / values dimension
).cuda()
x = torch.randn(1, 65536, 512).cuda()
out = attn(x) # (1, 65536, 512)
Cross attention
import torch
from memory_efficient_attention_pytorch import Attention
cross_attn = Attention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    memory_efficient = True,
    q_bucket_size = 1024,
    k_bucket_size = 2048
).cuda()
x = torch.randn(1, 65536, 512).cuda()
context = torch.randn(1, 65536, 512).cuda()
mask = torch.ones(1, 65536).bool().cuda()
out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512)
Citations
@misc{rabe2021selfattention,
    title   = {Self-attention Does Not Need $O(n^2)$ Memory},
    author  = {Markus N. Rabe and Charles Staats},
    year    = {2021},
    eprint  = {2112.05682},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{Dao2022FlashAttentionFA,
    title   = {FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
    author  = {Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R'e},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.14135}
}
@article{dao2023flashattention2,
    title   = {Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
    author  = {Dao, Tri},
    year    = {2023}
}
.\lucidrains\memory-efficient-attention-pytorch\setup.py
# import setuptools' setup and package discovery helpers
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'memory-efficient-attention-pytorch',            # package name
  packages = find_packages(exclude=[]),                   # find all packages
  version = '0.1.6',                                      # version number
  license='MIT',                                          # license
  description = 'Memory Efficient Attention - Pytorch',   # description
  long_description_content_type = 'text/markdown',        # long description content type
  author = 'Phil Wang',                                   # author
  author_email = 'lucidrains@gmail.com',                  # author email
  url = 'https://github.com/lucidrains/memory-efficient-attention-pytorch',  # project URL
  keywords = [                                            # keywords
    'artificial intelligence',
    'deep learning',
    'attention-mechanism'
  ],
  install_requires=[                                      # dependencies
    'einops>=0.4.1',
    'torch>=1.6'
  ],
  setup_requires=[                                        # setup dependencies
    'pytest-runner',
  ],
  tests_require=[                                         # test dependencies
    'pytest'
  ],
  classifiers=[                                           # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.8',
  ],
)