Lucidrains-系列项目源码解析-五十八-

50 阅读4分钟

Lucidrains 系列项目源码解析(五十八)

.\lucidrains\memorizing-transformers-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# Package metadata / install configuration for memorizing-transformers-pytorch.
setup(
  name = 'memorizing-transformers-pytorch',  # distribution name
  packages = find_packages(exclude=[]),  # every package found in the source tree
  version = '0.4.1',  # release version
  license='MIT',
  description = 'Memorizing Transformer - Pytorch',  # one-line summary
  long_description_content_type = 'text/markdown',  # NOTE(review): long_description itself is not set here
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/memorizing-transformers-pytorch',  # project homepage
  keywords = [  # PyPI search keywords
    'artificial intelligence',
    'deep learning',
    'transformers',
    'memory',
    'retrieval'
  ],
  install_requires=[  # runtime dependencies
    'einops>=0.6',
    'filelock',
    'joblib',
    'faiss-gpu',  # approximate-nearest-neighbour index (GPU build)
    'numpy',
    'torch>=1.6',
  ],
  classifiers=[  # trove classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\memorizing-transformers-pytorch\train.py

# 导入所需的库
from memorizing_transformers_pytorch import MemorizingTransformer

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# Training hyperparameters
NUM_BATCHES = int(1e5)  # total number of optimizer steps
BATCH_SIZE = 16
SEQ_LEN = 512  # tokens per segment fed to the model
SEGMENTS = 5  # segments per sampled sequence; kNN/XL memories carry across them

LEARNING_RATE = 2e-4
MAX_GRAD_CLIP_NORM = 0.5  # gradient-norm clipping threshold

VALIDATE_EVERY  = 100  # run validation every N batches
GENERATE_EVERY  = 500  # NOTE(review): unused in the visible portion of this script
GENERATE_LENGTH = 512  # NOTE(review): unused in the visible portion of this script

# 辅助函数

def decode_token(token):
    """Map a byte value to a printable character (control bytes render as space)."""
    code = token if token >= 32 else 32
    return chr(code)

def decode_tokens(tokens):
    """Decode an iterable of byte values into printable text."""
    return ''.join(chr(max(32, t)) for t in tokens)

# Instantiate the GPT-like decoder model (moved to GPU; requires CUDA)
model = MemorizingTransformer(
    num_tokens = 256,  # byte-level vocabulary size
    dim = 512,
    depth = 8,
    memorizing_layers = 4,
    max_knn_memories = 512 * 15,  # kNN memory capacity
    num_retrieved_memories = 32,  # k for each kNN lookup
    xl_memory_layers = (7, 8),
    xl_max_memories = 512,
).cuda()

# Prepare the enwik8 byte-level dataset: first 90M bytes for training,
# the following 5M for validation.
with gzip.open('./data/enwik8.gz') as file:
    # BUG FIX: np.fromstring is deprecated and removed in NumPy 2.0;
    # np.frombuffer is the supported replacement. .copy() makes the array
    # writable so torch.from_numpy does not warn about a read-only buffer.
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

class TextSamplerDataset(Dataset):
    """Dataset yielding random fixed-length windows (seq_len + 1 tokens) of a 1-D tensor."""

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # Random window start; the extra trailing token serves as the shifted label.
        start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        window = self.data[start: start + self.seq_len + 1].long()
        return window.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# Datasets and infinite dataloaders.
# BUG FIX: `cycle` was referenced but never defined in the original script
# (NameError at runtime); define it here as the usual infinite-loader generator.
def cycle(loader):
    """Yield batches from `loader` forever, restarting it when exhausted."""
    while True:
        for data in loader:
            yield data

train_dataset = TextSamplerDataset(data_train, SEQ_LEN * SEGMENTS)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
valid_dataset = TextSamplerDataset(data_val, SEQ_LEN * SEGMENTS)
valid_loader = cycle(DataLoader(valid_dataset, batch_size = BATCH_SIZE, drop_last = True))

# Optimizer.  NOTE(review): the name `optim` shadows the `torch.optim` module
# imported above; left unchanged because the training loop references it.
optim = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

# Training loop: each batch is split into SEGMENTS chunks processed
# sequentially, sharing kNN memories (and XL memories) across the chunks.
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
    model.train()

    data = next(train_loader)

    train_loss = 0.
    with model.knn_memories_context(batch_size = BATCH_SIZE) as knn_memories:
        xl_memories = None
        # inputs are the sequence shifted right, labels shifted left
        seq, labels = data[:, :-1], data[:, 1:]

        for seq_segment, labels_segment in zip(seq.chunk(SEGMENTS, dim = -1), labels.chunk(SEGMENTS, dim = -1)):
            loss, xl_memories = model(
                seq_segment,
                labels = labels_segment,
                knn_memories = knn_memories,
                xl_memories = xl_memories
            )

            # average per-segment losses; scale backward to match
            train_loss += loss.item() / SEGMENTS
            (loss / SEGMENTS).backward()

    print(f'training loss: {train_loss}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_CLIP_NORM)
    optim.step()
    optim.zero_grad()

    if not (i % VALIDATE_EVERY):
        model.eval()

        valid_data = next(valid_loader)
        valid_loss = 0.

        with torch.no_grad(), model.knn_memories_context(batch_size = BATCH_SIZE) as knn_memories:
            xl_memories = None
            # BUG FIX: the original validated on the *training* batch (`data`);
            # use the freshly drawn `valid_data` instead.
            seq, labels = valid_data[:, :-1], valid_data[:, 1:]

            for seq_segment, labels_segment in zip(seq.chunk(SEGMENTS, dim = -1), labels.chunk(SEGMENTS, dim = -1)):
                loss, xl_memories = model(
                    seq_segment,
                    labels = labels_segment,
                    knn_memories = knn_memories,
                    xl_memories = xl_memories
                )

                valid_loss += loss.item() / SEGMENTS

        print(f'valid loss: {valid_loss}')

.\lucidrains\memory-compressed-attention\memory_compressed_attention\memory_compressed_attention.py

import torch
from torch import nn
import torch.nn.functional as F

class ConvCompress(nn.Module):
    """Compress a sequence of vectors along the time axis by `ratio` using a
    strided 1-D convolution (kernel size == stride == ratio, so each output
    summarizes one non-overlapping window of `ratio` inputs)."""

    def __init__(self, dim, ratio = 3, groups = 1):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, ratio, stride = ratio, groups = groups)

    def forward(self, mem):
        # Conv1d expects (batch, channels, length); swap back afterwards.
        channels_first = mem.transpose(1, 2)
        compressed = self.conv(channels_first)
        return compressed.transpose(1, 2)

# Main class: multi-head self-attention whose keys/values are compressed along
# the sequence dimension by a strided convolution.
class MemoryCompressedAttention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        causal = False,
        compression_factor = 3,
        dropout = 0.):
        super().__init__()
        assert (dim % heads) == 0, 'dimension must be divisible by number of heads'

        self.heads = heads
        self.causal = causal

        # strided-conv compressor shrinks the key/value sequence by this factor
        self.compression_factor = compression_factor
        self.compress_fn = ConvCompress(dim, compression_factor, groups = heads)

        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.to_out = nn.Linear(dim, dim)

        self.dropout = nn.Dropout(dropout)

        # learned null key/value prepended to the compressed set, so every
        # query (including the very first one in the causal case) always has
        # at least one key to attend to
        self.null_k = nn.Parameter(torch.zeros(1, 1, dim))
        self.null_v = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x, input_mask = None):
        # b: batch, t: seq len, d: model dim, h: heads, cf: compression factor
        b, t, d, h, cf, device = *x.shape, self.heads, self.compression_factor, x.device
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        # left-pad keys/values so their length divides evenly by cf
        # (padding == cf means t is already divisible, so no pad is applied)
        padding = cf - (t % cf)
        if padding < cf:
            k, v = map(lambda t: F.pad(t, (0, 0, padding, 0)), (k, v))

        # compress keys and values along the sequence dimension
        k, v = map(self.compress_fn, (k, v))

        # prepend the null key/value
        nk, nv = map(lambda t: t.expand(b, -1, -1), (self.null_k, self.null_v))
        k = torch.cat((nk, k), dim=1)
        v = torch.cat((nv, v), dim=1)

        # split heads: (b, n, d) -> (b, h, n, d/h)
        q, k, v = map(lambda t: t.reshape(*t.shape[:2], h, -1).transpose(1, 2), (q, k, v))

        # attention logits
        # NOTE(review): scaled by the full model dim `d`, not per-head dim — confirm intended
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * d ** -0.5

        mask_value = -torch.finfo(dots.dtype).max

        # causal masking: each compressed key is represented by the largest
        # source position it covers; mask keys that reach past the query
        if self.causal:
            mask_q = mask_k = torch.arange(t, device=device)

            if padding < cf:
                mask_k = F.pad(mask_k, (padding, 0))

            mask_k, _ = mask_k.reshape(-1, cf).max(dim=-1)
            mask = mask_q[:, None] < mask_k[None, :]
            # never mask out the prepended null key
            mask = F.pad(mask, (1, 0), value=False)

            dots.masked_fill_(mask[None, None, ...], mask_value)
            del mask

        # key-padding mask: a compressed key is valid if any of its source
        # positions was valid
        if input_mask is not None:
            mask_q = mask_k = input_mask
            if padding < cf:
                mask_k = F.pad(mask_k, (padding, 0), value=True)
            mask_k = mask_k.reshape(b, -1, cf).sum(dim=-1) > 0
            # NOTE(review): `<` between two boolean masks looks suspicious —
            # presumably a logical AND of query/key validity was intended; confirm upstream
            mask = mask_q[:, None, :, None] < mask_k[:, None, None, :]
            mask = F.pad(mask, (1, 0), value=True)

            dots.masked_fill_(~mask, mask_value)
            del mask

        # attention weights
        attn = dots.softmax(dim=-1)

        # post-softmax dropout
        attn = self.dropout(attn)
        out = torch.einsum('bhij,bhjd->bhid', attn, v)

        # merge heads back to (b, t, d) and project out
        out = out.transpose(1, 2).reshape(b, t, d)
        return self.to_out(out)

.\lucidrains\memory-compressed-attention\memory_compressed_attention\__init__.py

# 从 memory_compressed_attention.memory_compressed_attention 模块中导入 MemoryCompressedAttention 类
from memory_compressed_attention.memory_compressed_attention import MemoryCompressedAttention

Memory Compressed Attention

Implementation of the Self-Attention layer of the proposed Memory-Compressed Attention, in Pytorch. This repository offers both the causal and non-causal variants, and will take care of the padding if the sequence length is not divisible by the compression ratio.

The code also resolves an edge case where, in the auto-regressive scenario, the very first query has no keys to attend to. The solution is to use null key/values, appended to the final compressed set, so that there is always at least one key for every query to attend to.

Install

$ pip install memory_compressed_attention

Usage

import torch
from memory_compressed_attention import MemoryCompressedAttention

attn = MemoryCompressedAttention(
    dim = 512,
    heads = 8,                 # number of heads
    causal = False,            # auto-regressive or not
    compression_factor = 3,    # compression ratio
    dropout = 0.1              # dropout post-attention
)

x = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()

attn(x, input_mask = mask) # (1, 1024, 512)

Citations

@misc{liu2018generating,
    title={Generating Wikipedia by Summarizing Long Sequences},
    author={Peter J. Liu and Mohammad Saleh and Etienne Pot and Ben Goodrich and Ryan Sepassi and Lukasz Kaiser and Noam Shazeer},
    year={2018},
    eprint={1801.10198},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}

.\lucidrains\memory-compressed-attention\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# Package metadata / install configuration for memory_compressed_attention.
setup(
  name = 'memory_compressed_attention',  # distribution name
  packages = find_packages(),  # every package found in the source tree
  version = '0.0.7',  # release version
  license='MIT',
  description = 'Memory-Compressed Self Attention',  # one-line summary
  long_description_content_type = 'text/markdown',  # NOTE(review): long_description itself is not set here
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/memory-compressed-attention',  # project homepage
  keywords = ['transformers', 'artificial intelligence', 'attention mechanism'],  # PyPI search keywords
  install_requires=[
    'torch'  # sole runtime dependency
  ],
  classifiers=[  # trove classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

My explorations into editing the knowledge and memories of an attention network.

Citations

@article{meng2022memit,
  title   = {Mass Editing Memory in a Transformer},
  author  = {Kevin Meng and Sen Sharma, Arnab and Alex Andonian and Yonatan Belinkov and David Bau},
  journal = {arXiv preprint arXiv:2210.07229},
  year    = {2022}
}
@inproceedings{Burns2022DiscoveringLK,
  title  = {Discovering Latent Knowledge in Language Models Without Supervision},
  author = {Collin Burns and Hao-Tong Ye and Dan Klein and Jacob Steinhardt},
  year   = {2022}
}

Data source

The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/

.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\autoregressive_wrapper.py

import torch
from torch import nn
import torch.nn.functional as F

# helper function

def exists(val):
    """Return True iff `val` is not None (falsy values like 0/'' still count)."""
    return not (val is None)

def eval_decorator(fn):
    """Decorator: run `fn` with the model in eval mode, then restore the
    model's previous train/eval state before returning."""
    def inner(model, *args, **kwargs):
        prev_mode = model.training
        model.eval()
        result = fn(model, *args, **kwargs)
        model.train(prev_mode)
        return result
    return inner

# top k filtering

def top_k(logits, thres = 0.9):
    """Keep only the top (1 - thres) fraction of logits; set the rest to -inf.

    Assumes `logits` is 2-D (batch, vocab) — the scatter is along dim 1.
    """
    num_keep = int((1 - thres) * logits.shape[-1])
    topk_vals, topk_idx = logits.topk(num_keep)
    filtered = torch.full_like(logits, float('-inf'))
    filtered.scatter_(1, topk_idx, topk_vals)
    return filtered

# Autoregressive wrapper class
class AutoregressiveWrapper(nn.Module):
    """Wraps a decoder `net` (which exposes `max_seq_len`) with autoregressive
    sampling (`generate`) and a shift-by-one teacher-forced `forward`."""

    def __init__(self, net, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value  # fill value for positions after EOS
        self.net = net
        self.max_seq_len = net.max_seq_len

    @torch.no_grad()
    @eval_decorator
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_thres = 0.9, **kwargs):
        """Sample up to `seq_len` tokens after `start_tokens`; returns only the
        newly generated tokens, shape (batch, <= seq_len)."""
        b, t, device = *start_tokens.shape, start_tokens.device

        out = start_tokens

        for _ in range(seq_len):
            # condition on at most the trailing max_seq_len tokens
            x = out[:, -self.max_seq_len:]

            # logits for the next token only
            logits = self.net(x, **kwargs)[:, -1, :]

            # top-k filter, then temperature-scaled softmax
            filtered_logits = top_k(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)

            # sample one token per sequence
            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)

            if exists(eos_token):
                is_eos_token = (out == eos_token)

                # stop once every sequence in the batch has produced an EOS
                if is_eos_token.any(dim = -1).all():
                    # BUG FIX: the original referenced the undefined name
                    # `is_eos_tokens` here (NameError on this path).
                    shifted_is_eos_tokens = F.pad(is_eos_token, (1, -1))
                    # mask everything strictly after the first EOS with pad_value
                    mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
                    out = out.masked_fill(mask, self.pad_value)
                    break

        # strip the prompt, return only generated tokens
        out = out[:, t:]
        return out

    def forward(self, x, **kwargs):
        """Teacher-forced forward: inputs are x[:, :-1], labels are x[:, 1:]."""
        x_inp, x_labels = x[:, :-1], x[:, 1:]
        return self.net(x_inp, labels = x_labels, **kwargs)

.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\cosine_sim_flash_attention.py

# 导入所需的库
import math
import torch
from functools import partial
from torch import nn, einsum
import torch.nn.functional as F
from torch.autograd.function import Function

from einops import rearrange

# 定义常量
EPSILON = 1e-6

# 辅助函数

def exists(val):
    """Return True iff `val` is not None."""
    return not (val is None)

def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if val is None:
        return d
    return val

def l2norm(t):
    """L2-normalize `t` along its last dimension."""
    return F.normalize(t, dim = -1)

# Custom autograd Function: tiled (flash-style) attention specialized for
# cosine-similarity attention.  Because logits are bounded by the fixed
# `scale`, no running max is needed — exp(w - scale) <= 1 always.
class FlashAttentionFunction(Function):
    # Forward pass: accumulate un-normalized outputs and row sums tile by tile.
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, scale, causal, q_bucket_size, k_bucket_size):
        device = q.device
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        k_len = k.shape[-2] # in cosine-sim attention, row sums are bounded by the key/value sequence length

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), device = device)

        # normalize the optional key mask into one chunk per row tile
        # NOTE(review): the mask is split with q_bucket_size along the key dim
        # and reused for every column tile of a row — confirm upstream intent
        if not exists(mask):
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = rearrange(mask, 'b n -> b 1 1 n')
            mask = mask.split(q_bucket_size, dim = -1)

        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            mask,
            all_row_sums.split(q_bucket_size, dim = -2),
        )

        # outer loop over query (row) tiles
        for ind, (qc, oc, row_mask, row_sums) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
            )

            # inner loop over key/value (column) tiles
            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                # cosine-sim logits for this tile
                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                if exists(row_mask):
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                # causal masking where the tile straddles the diagonal
                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                # fixed shift by `scale` keeps exp() <= 1; no running max needed
                attn_weights -= scale
                exp_weights = torch.exp(attn_weights)

                if exists(row_mask):
                    exp_weights.masked_fill_(~row_mask, 0.)

                block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)

                exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)

                # accumulate un-normalized output (divided by k_len to bound magnitude)
                oc.add_(exp_values / k_len)
                row_sums.add_(block_row_sums)

        # stash what backward needs (mask is the per-row-tile chunk tuple)
        ctx.args = (scale, causal, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums)

        # final normalization: undo the /k_len and divide by the true row sums
        o.mul_(k_len / all_row_sums)

        return o

    @staticmethod
    @torch.no_grad()
    # Backward pass: recompute tile probabilities from the saved row sums and
    # accumulate dq/dk/dv tile by tile.
    def backward(ctx, do):
        scale, causal, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, l = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        # gradient accumulators
        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        # same row tiling as forward (`l` holds the saved row sums)
        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            do.split(q_bucket_size, dim = -2),
            mask,
            l.split(q_bucket_size, dim = -2),
            dq.split(q_bucket_size, dim = -2)
        )

        for ind, (qc, oc, doc, row_mask, lc, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
                dk.split(k_bucket_size, dim = -2),
                dv.split(k_bucket_size, dim = -2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                # recompute tile logits
                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                # re-apply the causal mask exactly as in forward
                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                # same fixed shift as forward
                exp_attn_weights = torch.exp(attn_weights - scale)

                if exists(row_mask):
                    exp_attn_weights.masked_fill_(~row_mask, 0.)

                # probabilities = exponentiated weights over saved row sums
                p = exp_attn_weights / lc

                dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
                # dp: gradient of logits through the value product
                dp = einsum('... i d, ... j d -> ... i j', doc, vc)

                # per-row correction term D_i = sum_j dO_ij * O_ij
                D = (doc * oc).sum(dim = -1, keepdims = True)
                ds = p * scale * (dp - D)

                dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
                dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)

                # accumulate gradients in place
                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        # one gradient slot per forward input; non-tensor args get None
        return dq, dk, dv, None, None, None, None, None
# Main module: flash attention for cosine-similarity attention.  Relatively
# simple since the bounded logits remove softmax-stability concerns.
class FlashAttention(nn.Module):
    """Multi-head cosine-similarity attention computed with the tiled
    FlashAttentionFunction defined above."""

    def __init__(
        self,
        *,
        dim,
        scale = 16,
        heads = 8,
        dim_head = 64,
        causal = False,
        q_bucket_size = 512,
        k_bucket_size = 1024
    ):
        super().__init__()
        self.heads = heads

        self.scale = scale
        self.causal = causal

        inner_dim = heads * dim_head

        # projections in/out of the multi-head space
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # default tile sizes; callers may override per forward pass
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

    def forward(
        self,
        x,
        context = None,
        mask = None,
        q_bucket_size = None,
        k_bucket_size = None,
    ):
        q_bucket_size = default(q_bucket_size, self.q_bucket_size)
        k_bucket_size = default(k_bucket_size, self.k_bucket_size)

        num_heads = self.heads
        kv_input = default(context, x)  # self-attention when no context given

        queries = self.to_q(x)
        keys, values = self.to_kv(kv_input).chunk(2, dim = -1)

        split_heads = lambda t: rearrange(t, 'b n (h d) -> b h n d', h = num_heads)
        queries, keys, values = map(split_heads, (queries, keys, values))

        # cosine-sim attention: unit-normalize queries and keys
        queries, keys = map(l2norm, (queries, keys))

        attended = FlashAttentionFunction.apply(queries, keys, values, mask, self.scale, self.causal, q_bucket_size, k_bucket_size)

        merged = rearrange(attended, 'b h n d -> b n (h d)')
        return self.to_out(merged)

.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\flash_attention.py

# 导入数学库和 PyTorch 库
import math
import torch
# 导入 partial 函数
from functools import partial
# 从 torch 模块中导入 nn 和 einsum 函数
from torch import nn, einsum
# 从 torch.autograd.function 模块中导入 Function 类
from torch.autograd.function import Function
# 从 einops 库中导入 rearrange 函数

from einops import rearrange

# 定义常量 EPSILON
EPSILON = 1e-10

# 定义辅助函数

def exists(val):
    """Return True iff `val` is not None."""
    return not (val is None)

def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if val is None:
        return d
    return val

# flash attention 前向和后向

# flash attention v1 - https://arxiv.org/abs/2205.14135
# flash attention v2 - https://tridao.me/publications/flash2/flash2.pdf

# Custom autograd Function implementing flash attention (v1/v2-style tiling
# with an online softmax) in pure PyTorch.
class FlashAttentionFunction(Function):
    # Forward pass over query-row and key-column tiles.
    @staticmethod
    @torch.no_grad()
    # inputs: q, k, v tensors, optional boolean mask, causal flag, tile sizes
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """ Algorithm 1 in the v2 paper """

        device = q.device
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        # running output, row sums and row maxes for the online softmax
        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), device=device)
        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device=device)

        scale = (q.shape[-1] ** -0.5)

        num_row_tiles = math.ceil(q.shape[-2] / q_bucket_size)
        num_col_tiles = math.ceil(k.shape[-2] / k_bucket_size)

        # normalize `mask` into a (num_row_tiles x num_col_tiles) grid of
        # chunks, broadcasting singleton dimensions where possible
        if exists(mask) and mask.ndim == 2:
            mask = rearrange(mask, 'b n -> b 1 1 n')

        if not exists(mask):
            col_masks = (None,) * num_col_tiles
            mask = (col_masks,) * num_row_tiles 
        else:
            mask = ((mask,) * num_row_tiles) if mask.shape[-2] == 1 else mask.split(q_bucket_size, dim=-2)
            mask = tuple(((row_mask,) * num_col_tiles) if row_mask.shape[-1] == 1 else row_mask.split(k_bucket_size, dim=-1) for row_mask in mask)

        # outer loop over query (row) tiles
        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            mask,
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            # inner loop over key/value (column) tiles
            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                row_mask
            )

            for k_ind, (kc, vc, col_mask) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                # attention logits for this tile
                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                if exists(col_mask):
                    attn_weights.masked_fill_(~col_mask, max_neg_value)

                # mask future positions where the tile straddles the diagonal
                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                # online softmax: fold this tile's max into the running max
                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_weights = torch.exp(attn_weights - new_row_maxes)

                if exists(col_mask):
                    exp_weights.masked_fill_(~col_mask, 0.)

                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)

                exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)

                # rescale factor for previously accumulated partial results
                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)

                new_row_sums = exp_row_max_diff * row_sums + block_row_sums

                # rescale prior output, add this tile's contribution
                oc.mul_(exp_row_max_diff).add_(exp_values)

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

            # v2: single normalization per row tile after all column tiles
            oc.div_(row_sums)

        # per-row log-sum-exp, needed to recompute probabilities in backward
        lse = all_row_sums.log() + all_row_maxes

        # stash arguments and tensors for backward, return the output
        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, lse)

        return o

    # Backward pass over the same tiling, recomputing probabilities from lse.
    @staticmethod
    @torch.no_grad()
    # gradients flow from `do` (grad of the output) back to dq, dk, dv
    def backward(ctx, do):
        """ Algorithm 2 in the v2 paper """

        # unpack saved context
        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, lse = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        # gradient accumulators
        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        # same row tiling as the forward pass
        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            do.split(q_bucket_size, dim = -2),
            mask,
            lse.split(q_bucket_size, dim = -2),
            dq.split(q_bucket_size, dim = -2)
        )

        for ind, (qc, oc, doc, row_mask, lsec, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            # same column tiling as the forward pass
            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
                dk.split(k_bucket_size, dim = -2),
                dv.split(k_bucket_size, dim = -2),
                row_mask
            )

            for k_ind, (kc, vc, dkc, dvc, col_mask) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                # recompute tile logits
                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                # re-apply the causal mask exactly as in forward
                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                # probabilities recovered from the stored log-sum-exp
                p = torch.exp(attn_weights - lsec)

                # zero out masked key positions
                if exists(col_mask):
                    p.masked_fill_(~col_mask, 0.)

                # value gradient and logit gradient through the value product
                dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
                dp = einsum('... i d, ... j d -> ... i j', doc, vc)

                # per-row correction term D_i = sum_j dO_ij * O_ij
                D = (doc * oc).sum(dim = -1, keepdims = True)
                ds = p * scale * (dp - D)

                # query/key gradients for this tile
                dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
                dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)

                # accumulate in place
                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        # one gradient slot per forward input; non-tensor args get None
        return dq, dk, dv, None, None, None, None
# Main FlashAttention module.  A pure-PyTorch reference implementation —
# much slower than CUDA kernels; intended for debugging and education.
class FlashAttention(nn.Module):
    """Multi-head attention computed via the tiled FlashAttentionFunction."""

    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        causal = False,
        q_bucket_size = 512,
        k_bucket_size = 1024
    ):
        super().__init__()
        self.heads = heads
        self.causal = causal

        inner_dim = heads * dim_head

        # projections in/out of the multi-head space
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # default tile sizes; may be overridden per forward call
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

    def forward(
        self,
        x,
        context = None,
        mask = None,
        q_bucket_size = None,
        k_bucket_size = None,
    ):
        q_bucket_size = default(q_bucket_size, self.q_bucket_size)
        k_bucket_size = default(k_bucket_size, self.k_bucket_size)

        num_heads = self.heads
        kv_input = default(context, x)  # self-attention when no context given

        queries = self.to_q(x)
        keys, values = self.to_kv(kv_input).chunk(2, dim = -1)

        split_heads = lambda t: rearrange(t, 'b n (h d) -> b h n d', h = num_heads)
        queries, keys, values = map(split_heads, (queries, keys, values))

        attended = FlashAttentionFunction.apply(queries, keys, values, mask, self.causal, q_bucket_size, k_bucket_size)

        merged = rearrange(attended, 'b h n d -> b n (h d)')
        return self.to_out(merged)

.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\memory_efficient_attention.py

import torch
from functools import partial
from torch import nn, einsum
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F

from einops import rearrange

# 导入所需的库

def exists(val):
    """Return True iff `val` is not None."""
    return not (val is None)

def default(val, d):
    """Return `val` if it is not None, otherwise the fallback `d`."""
    if val is None:
        return d
    return val

# regular attention

def attention(
    q, k, v,
    mask = None,
    causal = False,
    attn_bias = None,
    **kwargs
):
    """Standard (quadratic-memory) scaled dot-product attention.

    q, k, v are (batch, heads, seq, dim_head). `mask` is a boolean key mask,
    `attn_bias` an additive bias on the logits, `causal` applies an
    autoregressive mask. Extra kwargs are ignored so this is call-compatible
    with `memory_efficient_attention`.
    """
    # scale queries by 1/sqrt(dim_head)
    q = q * (q.shape[-1] ** -0.5)

    # attention logits
    logits = einsum('b h i d, b h j d -> b h i j', q, k)

    # optional additive bias (e.g. relative positions)
    if attn_bias is not None:
        logits = logits + attn_bias

    # large negative stand-in for -inf at this dtype
    neg_value = -torch.finfo(logits.dtype).max

    # key padding mask, broadcast over heads and query positions
    if mask is not None:
        if mask.ndim == 2:
            mask = rearrange(mask, 'b j -> b 1 1 j')
        logits = logits.masked_fill(~mask, neg_value)

    # autoregressive mask: query i may only attend to keys j <= i
    if causal:
        i, j = logits.shape[-2:]
        causal_mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
        logits = logits.masked_fill(causal_mask, neg_value)

    # numerically stable softmax (max-shift does not change the result)
    logits = logits - logits.amax(dim = -1, keepdim = True).detach()
    attn = logits.softmax(dim = -1)

    # aggregate values
    return einsum('b h i j, b h j d -> b h i d', attn, v)

# memory efficient attention

def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices, dropout):
    # Process one (query-chunk x key-chunk) tile of the attention matrix.
    # Returns three partial results the caller combines across key chunks via
    # a numerically-stable log-sum-exp: row-sums of exp(logits), the
    # exp-weighted values, and the per-row logit max.
    # `qk_start_indices` gives the global (query, key) offsets of this tile.
    q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device

    weight = einsum('b h i d, b h j d -> b h i j', q, k)

    # raw attention logits for this tile (queries are pre-scaled by the caller)

    if exists(attn_bias_chunk):
        weight = weight + attn_bias_chunk

    # additive attention bias, pre-chunked by the caller

    mask_value = -torch.finfo(weight.dtype).max

    # large negative fill value standing in for -inf

    if exists(mask):
        mask = rearrange(mask, 'b j -> b 1 1 j')
        weight = weight.masked_fill(~mask, mask_value)

    # key padding mask, broadcast over heads and query positions

    if causal and q_start_index < (k_start_index + k_chunk_size - 1):
        causal_mask = torch.ones((q_chunk_size, k_chunk_size), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
        weight = weight.masked_fill(causal_mask, mask_value)

    # causal mask offset by the tile's global position; tiles entirely below
    # the diagonal skip this branch

    weight_max = weight.amax(dim = -1, keepdim = True).detach()
    weight = weight - weight_max

    # subtract the per-row max so exp() cannot overflow; the max is returned
    # so the caller can renormalize across tiles

    exp_weight = weight.exp()

    exp_weight = F.dropout(exp_weight, p = dropout)

    # NOTE(review): dropout acts on the un-normalized exp weights here, not on
    # normalized attention probabilities — presumably intentional for this
    # chunked formulation; verify against the non-chunked reference

    weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)

    return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...')

checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk)

# checkpointed variant: recomputes the tile during backward to save memory

def memory_efficient_attention(
    q, k, v,
    mask = None,
    causal = False,
    attn_bias = None,
    q_bucket_size = 512,
    k_bucket_size = 1024,
    eps = 1e-8,
    dropout = 0.,
    training = False
):
    # Chunked attention (Rabe & Staats, "Self-attention Does Not Need O(n^2)
    # Memory"): iterate over (query-bucket x key-bucket) tiles, summarize each,
    # then renormalize across key buckets with a stable log-sum-exp combine.
    scale = q.shape[-1] ** -0.5
    q = q * scale

    # scale queries by 1/sqrt(dim_head)

    needs_backwards = q.requires_grad or k.requires_grad or v.requires_grad
    summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk

    # only pay the checkpoint-recompute cost when gradients are required

    q_chunks = q.split(q_bucket_size, dim = -2)
    k_chunks = k.split(k_bucket_size, dim = -2)
    v_chunks = v.split(k_bucket_size, dim = -2)
    mask_chunks = mask.split(k_bucket_size, dim = -1) if exists(mask) else ((None,) * len(k_chunks))

    if exists(attn_bias):
        i, j = attn_bias.shape[-2:]
        attn_bias_chunks = attn_bias.split(q_bucket_size, dim = -2)
        attn_bias_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), attn_bias_chunks))

    # split all inputs into buckets; the bias is chunked along both axes

    out = []

    # per-query-bucket output rows
    # iterate over query buckets
    for q_index, q_chunk in enumerate(q_chunks):
        # partial results accumulated across key buckets for this query bucket
        exp_weights = []
        weighted_values = []
        weight_maxes = []

        # iterate over key/value (and mask) buckets
        for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)):
            # global offsets of this tile
            q_start_index = q_index * q_bucket_size
            k_start_index = k_index * k_bucket_size

            # under causal masking, tiles entirely in the future are all-masked: skip them
            if causal and k_start_index > (q_start_index + q_chunk.shape[-2] - 1):
                continue

            # bias sub-tile for this (query, key) bucket pair, if any
            attn_bias_chunk = attn_bias_chunks[q_index][k_index] if exists(attn_bias) else None

            # summarize the tile: exp row-sums, exp-weighted values, per-row max
            exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn(
                q_chunk,
                k_chunk,
                v_chunk,
                mask_chunk,
                attn_bias_chunk,
                causal,
                (q_start_index, k_start_index),
                dropout if training else 0.
            )

            # collect partials for the cross-bucket combine below
            exp_weights.append(exp_weight_chunk)
            weighted_values.append(weighted_value_chunk)
            weight_maxes.append(weight_max_chunk)

        # stack partials along a new trailing "key-bucket" axis
        weight_maxes = torch.stack(weight_maxes, dim=-1)

        weighted_values = torch.stack(weighted_values, dim=-1)
        exp_weights = torch.stack(exp_weights, dim=-1)

        # stable log-sum-exp combine: rescale every bucket's partials to the
        # global per-row max before summing
        global_max = weight_maxes.amax(dim=-1, keepdim=True)
        renorm_factor = (weight_maxes - global_max).exp().detach()

        exp_weights = exp_weights * renorm_factor
        weighted_values = weighted_values * rearrange(renorm_factor, '... c -> ... 1 c')

        # softmax numerator and denominator, summed over key buckets
        all_values = weighted_values.sum(dim=-1)
        all_weights = exp_weights.sum(dim=-1)

        # normalize; eps guards against a fully-masked row dividing by zero
        normalized_values = all_values / (rearrange(all_weights, '... -> ... 1') + eps)
        out.append(normalized_values)

    # concatenate query buckets back into the full sequence
    return torch.cat(out, dim=-2)
# 主要的注意力机制类

class Attention(nn.Module):
    """Multi-head (self or cross) attention with an optional O(n)-memory path.

    When `memory_efficient` is enabled (at construction or per call) the
    chunked attention of Rabe & Staats is used in place of the standard
    quadratic-memory implementation; both compute the same result.
    """

    def __init__(
        self,
        *,
        dim,                       # model dimension of inputs / outputs
        heads = 8,                 # number of attention heads
        dim_head = 64,             # dimension per head
        dropout = 0.,              # attention dropout probability
        causal = False,            # autoregressive masking
        memory_efficient = False,  # default to the chunked implementation
        q_bucket_size = 512,       # chunk size along the query axis
        k_bucket_size = 1024       # chunk size along the key / value axis
    ):
        super().__init__()
        self.heads = heads
        self.causal = causal
        self.dropout = dropout

        inner = heads * dim_head

        # projections for queries, fused keys/values, and the output
        self.to_q = nn.Linear(dim, inner, bias = False)
        self.to_kv = nn.Linear(dim, inner * 2, bias = False)
        self.to_out = nn.Linear(inner, dim, bias = False)

        # defaults for the memory-efficient path; overridable per forward call
        self.memory_efficient = memory_efficient
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

    def forward(
        self,
        x,                       # input, (batch, seq, dim)
        context = None,          # optional cross-attention context
        mask = None,             # optional boolean key mask
        attn_bias = None,        # optional additive attention bias
        memory_efficient = None, # per-call override
        q_bucket_size = None,    # per-call override
        k_bucket_size = None,    # per-call override
    ):
        """Attend over `x` (or `context`), returning a (batch, seq, dim) tensor."""
        # per-call overrides fall back to the module-level settings
        memory_efficient = self.memory_efficient if memory_efficient is None else memory_efficient
        q_bucket_size = self.q_bucket_size if q_bucket_size is None else q_bucket_size
        k_bucket_size = self.k_bucket_size if k_bucket_size is None else k_bucket_size

        # self-attention when no explicit context is given
        kv_input = x if context is None else context

        queries = self.to_q(x)
        keys, values = self.to_kv(kv_input).chunk(2, dim = -1)

        # split out heads: (b, n, h*d) -> (b, h, n, d)
        queries, keys, values = [rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in (queries, keys, values)]

        # pick the implementation for this call
        if memory_efficient:
            attn_fn = memory_efficient_attention
        else:
            attn_fn = attention

        out = attn_fn(
            queries, keys, values,
            mask = mask,
            attn_bias = attn_bias,
            causal = self.causal,
            q_bucket_size = q_bucket_size,
            k_bucket_size = k_bucket_size,
            dropout = self.dropout,
            training = self.training
        )

        # merge heads and project back to the model dimension
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\memory_efficient_cosine_sim_attention.py

import math
import torch
import torch.nn.functional as F
from functools import partial
from torch import nn, einsum
from torch.utils.checkpoint import checkpoint

from einops import rearrange

# helper functions

def exists(val):
    """True when `val` is not None."""
    return not (val is None)

def default(val, d):
    """`val` if it is not None, otherwise the fallback `d`."""
    if exists(val):
        return val
    return d

def l2norm(t):
    """L2-normalize `t` along its last dimension."""
    return F.normalize(t, dim = -1)

# regular attention

# 普通的注意力机制
def attention(
    q, k, v,
    mask = None,
    causal = False,
    attn_bias = None,
    **kwargs
):
    """Plain attention for cosine-sim (pre-scaled) queries and keys.

    No 1/sqrt(d) scaling and no max-subtraction here: the caller feeds
    l2-normalized q/k times a learned bounded scale, so the logits are
    bounded and the raw softmax is already stable.
    """
    # attention logits
    logits = einsum('b h i d, b h j d -> b h i j', q, k)

    # optional additive bias
    if attn_bias is not None:
        logits = logits + attn_bias

    # large negative stand-in for -inf at this dtype
    neg_value = -torch.finfo(logits.dtype).max

    # key padding mask, broadcast over heads and query positions
    if mask is not None:
        logits = logits.masked_fill(~rearrange(mask, 'b j -> b 1 1 j'), neg_value)

    # autoregressive mask
    if causal:
        i, j = logits.shape[-2:]
        causal_mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
        logits = logits.masked_fill(causal_mask, neg_value)

    # softmax over keys, then aggregate values
    probs = logits.softmax(dim = -1)
    return einsum('b h i j, b h j d -> b h i d', probs, v)

# memory efficient attention

# 汇总查询、键、值的函数
def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices):
    # One (query-chunk x key-chunk) tile of chunked attention. Returns the
    # un-normalized exp row-sums and the exp-weighted values. Unlike the
    # stabilized variant, no per-row max is subtracted or returned — this
    # relies on the caller's logits being bounded (cosine-sim attention).
    q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device

    weight = einsum('b h i d, b h j d -> b h i j', q, k)

    # additive attention bias, pre-chunked by the caller
    if exists(attn_bias_chunk):
        weight = weight + attn_bias_chunk

    # large negative fill value standing in for -inf
    mask_value = -torch.finfo(weight.dtype).max

    # key padding mask, broadcast over heads and query positions
    if exists(mask):
        mask = rearrange(mask, 'b j -> b 1 1 j')
        weight = weight.masked_fill(~mask, mask_value)

    # causal mask offset by the tile's global position; fully-visible tiles skip this
    if causal and q_start_index < (k_start_index + k_chunk_size - 1):
        causal_mask = torch.ones((q_chunk_size, k_chunk_size), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
        weight = weight.masked_fill(causal_mask, mask_value)

    # direct exp of the logits — no max-shift, hence "numerically unstable"
    exp_weight = weight.exp()
    weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)

    return exp_weight.sum(dim = -1), weighted_value

# checkpointed variant: recomputes the tile during backward to save memory
checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk)

# 数值不稳定的内存高效注意力机制
def numerically_unstable_memory_efficient_attention(
    q, k, v,
    mask = None,
    causal = False,
    attn_bias = None,
    q_bucket_size = 512,
    k_bucket_size = 1024,
    eps = 1e-8
):
    # Chunked attention WITHOUT the log-sum-exp renormalization across key
    # buckets. Only safe when the logits are bounded — here, cosine-sim
    # attention with l2-normalized q/k and a learned bounded temperature.
    needs_backwards = q.requires_grad or k.requires_grad or v.requires_grad
    # only pay the checkpoint-recompute cost when gradients are required
    summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk

    # split all inputs into buckets

    q_chunks = q.split(q_bucket_size, dim = -2)
    k_chunks = k.split(k_bucket_size, dim = -2)
    v_chunks = v.split(k_bucket_size, dim = -2)
    mask_chunks = mask.split(k_bucket_size, dim = -1) if exists(mask) else ((None,) * len(k_chunks))

    if exists(attn_bias):
        i, j = attn_bias.shape[-2:]
        attn_bias_chunks = attn_bias.split(q_bucket_size, dim = -2)
        attn_bias_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), attn_bias_chunks))

    # loop over all tiles and accumulate

    out = []
    # iterate over query buckets
    for q_index, q_chunk in enumerate(q_chunks):
        # global query offset of this bucket
        q_start_index = q_index * q_bucket_size
        # partial numerators / denominators accumulated across key buckets
        exp_weights = []
        weighted_values = []

        # iterate over key/value (and mask) buckets
        for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)):
            # global key offset of this bucket
            k_start_index = k_index * k_bucket_size

            # under causal masking, tiles entirely in the future are all-masked: skip
            if causal and k_start_index > (q_start_index + q_chunk.shape[-2] - 1):
                continue

            # bias sub-tile for this (query, key) bucket pair, if any
            attn_bias_chunk = attn_bias_chunks[q_index][k_index] if exists(attn_bias) else None

            # summarize the tile: exp row-sums and exp-weighted values
            exp_weight_chunk, weighted_value_chunk = summarize_qkv_fn(
                q_chunk,
                k_chunk,
                v_chunk,
                mask_chunk,
                attn_bias_chunk,
                causal,
                (q_start_index, k_start_index)
            )

            # collect partials
            exp_weights.append(exp_weight_chunk)
            weighted_values.append(weighted_value_chunk)

        # combine partials by direct summation (no max-shift renormalization)
        all_values = sum(weighted_values)
        all_weights = sum(exp_weights)

        # normalize; eps guards a fully-masked row against division by zero
        normalized_values = all_values / (rearrange(all_weights, '... -> ... 1') + eps)
        out.append(normalized_values)

    # concatenate query buckets back into the full sequence
    return torch.cat(out, dim=-2)
# 主要类定义

class CosineSimAttention(nn.Module):
    """Multi-head attention using cosine similarity with a learned temperature.

    Queries and keys are l2-normalized, so the logits are bounded; this lets
    the memory-efficient path skip the max-shift stabilization entirely.
    The temperature init follows the Swin v2 scheme for `seq_len`.
    """

    def __init__(
        self,
        *,
        dim,                       # model dimension of inputs / outputs
        seq_len,                   # sequence length used to initialize the temperature
        heads = 8,                 # number of attention heads
        dim_head = 64,             # dimension per head
        dropout = 0.,              # kept for interface parity; unused here
        causal = False,            # autoregressive masking
        memory_efficient = False,  # default to the chunked implementation
        q_bucket_size = 512,       # chunk size along the query axis
        k_bucket_size = 1024       # chunk size along the key / value axis
    ):
        super().__init__()
        self.heads = heads
        self.causal = causal

        inner = heads * dim_head

        # learned per-head temperature, initialized so that
        # exp(scale) starts at 1 / log2(seq_len^2 - seq_len)
        init_scale = -math.log(math.log2(seq_len ** 2 - seq_len))
        self.scale = nn.Parameter(torch.full((1, heads, 1, 1), init_scale))

        # projections for queries, fused keys/values, and the output
        self.to_q = nn.Linear(dim, inner, bias = False)
        self.to_kv = nn.Linear(dim, inner * 2, bias = False)
        self.to_out = nn.Linear(inner, dim, bias = False)

        # defaults for the memory-efficient path; overridable per forward call
        self.memory_efficient = memory_efficient
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

    def forward(
        self,
        x,                       # input, (batch, seq, dim)
        context = None,          # optional cross-attention context
        mask = None,             # optional boolean key mask
        attn_bias = None,        # optional additive attention bias
        memory_efficient = None, # per-call override
        q_bucket_size = None,    # per-call override
        k_bucket_size = None,    # per-call override
    ):
        """Attend over `x` (or `context`), returning a (batch, seq, dim) tensor."""
        # per-call overrides fall back to the module-level settings
        memory_efficient = self.memory_efficient if memory_efficient is None else memory_efficient
        q_bucket_size = self.q_bucket_size if q_bucket_size is None else q_bucket_size
        k_bucket_size = self.k_bucket_size if k_bucket_size is None else k_bucket_size

        # self-attention when no explicit context is given
        kv_input = x if context is None else context

        queries = self.to_q(x)
        keys, values = self.to_kv(kv_input).chunk(2, dim = -1)

        # split out heads: (b, n, h*d) -> (b, h, n, d)
        queries, keys, values = [rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in (queries, keys, values)]

        # cosine-sim attention: unit-norm queries/keys, learned temperature
        queries = l2norm(queries) * self.scale.exp()
        keys = l2norm(keys)

        # pick the implementation for this call
        if memory_efficient:
            attn_fn = numerically_unstable_memory_efficient_attention
        else:
            attn_fn = attention

        out = attn_fn(queries, keys, values, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size)

        # merge heads and project back to the model dimension
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\reversible.py

# 导入 torch 库
import torch
# 导入 torch 中的神经网络模块
import torch.nn as nn
# 从 operator 模块中导入 itemgetter 函数
from operator import itemgetter
# 从 torch.autograd.function 模块中导入 Function 类
from torch.autograd.function import Function
# 从 torch.utils.checkpoint 模块中导入 get_device_states 和 set_device_states 函数
from torch.utils.checkpoint import get_device_states, set_device_states

# 用于将参数路由到可逆层函数中的函数
def route_args(router, args, depth):
    """Distribute keyword arguments to the f- and g-branches of each layer.

    `router` maps an argument name to a per-layer sequence of
    (route_to_f, route_to_g) booleans; `args` holds the actual values.
    Returns a list of `depth` (f_kwargs, g_kwargs) tuples. Arguments whose
    name is absent from the router are dropped.
    """
    routed_args = [(dict(), dict()) for _ in range(depth)]

    for key in args:
        if key not in router:
            continue
        val = args[key]
        # note: loop variable renamed so the `depth` parameter is not shadowed
        for layer_index, ((f_kwargs, g_kwargs), (route_f, route_g)) in enumerate(zip(routed_args, router[key])):
            if route_f:
                f_kwargs = {**f_kwargs, key: val}
            if route_g:
                g_kwargs = {**g_kwargs, key: val}
            routed_args[layer_index] = (f_kwargs, g_kwargs)
    return routed_args

# 参考示例 https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html 中的保存和设置随机数生成器
class Deterministic(nn.Module):
    # Wraps a module so its forward pass can be replayed with the exact same
    # RNG state — required so stochastic ops (e.g. dropout) match between the
    # memory-saving forward (run under no_grad) and the recomputation in the
    # custom backward. Pattern follows torch.utils.checkpoint's RNG handling.
    def __init__(self, net):
        super().__init__()
        self.net = net            # the wrapped module
        self.cpu_state = None     # saved CPU RNG state
        self.cuda_in_fwd = None   # whether CUDA RNG states were captured
        self.gpu_devices = None   # devices whose RNG states were captured
        self.gpu_states = None    # the per-device CUDA RNG states

    def record_rng(self, *args):
        # Snapshot the CPU RNG state and, if CUDA has been initialized, the
        # RNG state of every device touched by `args`.
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            # NOTE(review): get_device_states / set_device_states come from
            # torch.utils.checkpoint; the import is missing from this file's
            # header in this transcription — verify against upstream
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        # record_rng: capture RNG before running (used during the forward pass)
        # set_rng: restore the captured RNG and rerun (used in backward recompute)
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        # replay inside a forked RNG context so the surrounding RNG stream is untouched
        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)

# 受 https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py 启发
# 一旦多 GPU 确认工作正常,重构并将 PR 发回源代码
class ReversibleBlock(nn.Module):
    # One reversible layer: y1 = x1 + f(x2), y2 = x2 + g(y1).
    # Activations need not be stored for backprop — backward_pass reconstructs
    # (x1, x2) from (y1, y2), trading recomputation for memory.
    # Inspired by RevTorch (RobinBruegger/RevTorch).
    def __init__(self, f, g):
        super().__init__()
        # wrapped in Deterministic so stochastic ops replay identically in backward
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, f_args = {}, g_args = {}):
        # Split channels into the two residual streams. Run under no_grad:
        # gradients are produced manually by backward_pass, so no graph is kept.
        x1, x2 = torch.chunk(x, 2, dim=2)
        y1, y2 = None, None

        with torch.no_grad():
            # record RNG only while training, so dropout can be replayed exactly
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)

        return torch.cat([y1, y2], dim=2)

    def backward_pass(self, y, dy, f_args = {}, g_args = {}):
        # Invert the block and accumulate gradients: given the outputs `y` and
        # upstream grads `dy`, recompute the inputs and return (x, dx).
        # Follows the standard RevNet backward derivation.
        y1, y2 = torch.chunk(y, 2, dim=2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=2)
        del dy

        with torch.enable_grad():
            y1.requires_grad = True
            # recompute g(y1) with the recorded RNG so it matches the forward pass
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            # invert the second residual: x2 = y2 - g(y1)
            x2 = y2 - gy1
            del y2, gy1

            # gradient w.r.t. x1: direct path plus the path through g(y1)
            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            # recompute f(x2) with the recorded RNG so it matches the forward pass
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)

        with torch.no_grad():
            # invert the first residual: x1 = y1 - f(x2)
            x1 = y1 - fx2
            del y1, fx2

            # gradient w.r.t. x2: direct path plus the path through f(x2)
            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=2)
            dx = torch.cat([dx1, dx2], dim=2)

        return x, dx

class _ReversibleFunction(Function):
    # Custom autograd Function that runs a stack of reversible blocks without
    # retaining intermediate activations; backward reconstructs them blockwise
    # from the final output.
    @staticmethod
    def forward(ctx, x, blocks, args):
        ctx.args = args
        # apply each reversible block in sequence with its routed kwargs
        for block, kwarg in zip(blocks, args):
            x = block(x, **kwarg)
        # keep only the final activation; earlier ones are recomputed in backward
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x

    @staticmethod
    # walk the blocks in reverse, inverting each one and accumulating grads
    def backward(ctx, dy):
        y = ctx.y
        args = ctx.args
        for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
            # each block reconstructs its input and the gradient w.r.t. it
            y, dy = block.backward_pass(y, dy, **kwargs)
        # gradient only for the input tensor; `blocks` and `args` get None
        return dy, None, None
# 定义一个可逆序列的神经网络模块
class ReversibleSequence(nn.Module):
    """A stack of reversible (f, g) pairs executed through _ReversibleFunction.

    `blocks` is an iterable of (f, g) module pairs; `args_route` maps keyword
    argument names to per-layer (to_f, to_g) routing booleans.
    """

    def __init__(self, blocks, args_route = {}):
        super().__init__()
        self.args_route = args_route
        # wrap each (f, g) pair into a reversible block
        wrapped = []
        for f, g in blocks:
            wrapped.append(ReversibleBlock(f=f, g=g))
        self.blocks = nn.ModuleList(wrapped)

    def forward(self, x, **kwargs):
        # duplicate features along the channel axis — the two halves become
        # the two streams of the reversible network
        x = torch.cat((x, x), dim = -1)

        # route keyword arguments to the f / g branch of every layer
        routed = route_args(self.args_route, kwargs, len(self.blocks))
        routed = [{'f_args': f_args, 'g_args': g_args} for f_args, g_args in routed]

        out = _ReversibleFunction.apply(x, self.blocks, routed)

        # collapse the two streams back to one by summing the halves
        first_half, second_half = out.chunk(2, dim = -1)
        return first_half + second_half

.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\transformer.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch 库中导入 nn.functional 模块,并重命名为 F
import torch.nn.functional as F
# 从 functools 库中导入 partial 函数
from functools import partial
# 从 einops 库中导入 rearrange 函数
from einops import rearrange
# 从 memory_efficient_attention_pytorch 库中导入 FlashAttention 和 Attention 类
from memory_efficient_attention_pytorch import FlashAttention, Attention
# 从 memory_efficient_attention_pytorch.reversible 库中导入 ReversibleSequence 类
from memory_efficient_attention_pytorch.reversible import ReversibleSequence

# 定义一个函数,用于检查变量是否存在
def exists(val):
    """Return whether `val` holds a value, i.e. is anything other than None."""
    return not (val is None)

# 定义一个继承自 nn.Module 的类 PreNorm
class PreNorm(nn.Module):
    """Pre-normalization wrapper: apply LayerNorm, then the wrapped module."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn                   # the wrapped module / callable
        self.norm = nn.LayerNorm(dim)  # normalization over the last dimension

    def forward(self, x, **kwargs):
        # normalize first, then delegate (kwargs pass straight through to fn)
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

# 定义一个继承自 nn.Module 的类 FeedForward
class FeedForward(nn.Module):
    """Two-layer MLP with GELU; optionally processes the sequence in chunks.

    Chunking along the sequence axis trades peak activation memory for a few
    extra kernel launches; the result is identical either way.
    """

    def __init__(self, dim, mult = 4, chunks = 1):
        super().__init__()
        self.chunks = chunks  # number of sequence chunks to process independently

        # dim -> dim * mult -> dim projection with GELU in between
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        # fast path: no chunking requested
        if self.chunks <= 1:
            return self.net(x)

        # split along the sequence axis, run each piece, and reassemble
        pieces = x.chunk(self.chunks, dim = 1)
        processed = [self.net(piece) for piece in pieces]
        return torch.cat(processed, dim = 1)

# 定义一个继承自 nn.Module 的类 Transformer
class Transformer(nn.Module):
    """Reversible transformer language model.

    Uses either FlashAttention or the memory-efficient Attention variant,
    wrapped in pre-norm and executed through a ReversibleSequence so that
    activation memory stays constant in depth. Extra kwargs are forwarded to
    the attention class.
    """

    def __init__(
        self,
        *,
        num_tokens,            # vocabulary size
        max_seq_len,           # maximum sequence length (for position embeddings)
        dim,                   # model dimension
        depth,                 # number of (attention, feedforward) layers
        causal = False,        # autoregressive masking
        dim_head = 64,         # dimension per attention head
        heads = 8,             # number of attention heads
        ff_mult = 4,           # feedforward expansion factor
        ff_chunks = 1,         # feedforward sequence chunking
        use_flash_attn = True, # FlashAttention vs memory-efficient Attention
        **kwargs
    ):
        super().__init__()
        self.max_seq_len = max_seq_len

        # token and learned absolute position embeddings
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # choose the attention implementation
        if use_flash_attn:
            attn_klass = FlashAttention
        else:
            attn_klass = partial(Attention, memory_efficient = True)

        # build depth x (attention, feedforward) pre-norm pairs
        layers = nn.ModuleList([])
        for _ in range(depth):
            attn = PreNorm(dim, attn_klass(dim = dim, dim_head = dim_head, heads = heads, causal = causal, **kwargs))
            ff = PreNorm(dim, FeedForward(dim = dim, mult = ff_mult, chunks = ff_chunks))
            layers.append(nn.ModuleList([attn, ff]))
        self.layers = layers

        # reversible trunk over the layer pairs
        self.net = ReversibleSequence(self.layers)

        # final norm + projection to vocabulary logits
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    def forward(self, x, labels = None):
        """Return logits for token ids `x`, or the cross-entropy loss if `labels` is given."""
        # embed tokens and add positions
        tokens = self.token_emb(x)
        positions = torch.arange(tokens.shape[-2], device = x.device)
        tokens = tokens + self.pos_emb(positions)

        # reversible transformer trunk
        hidden = self.net(tokens)

        logits = self.to_logits(hidden)

        # inference: raw logits
        if labels is None:
            return logits

        # training: cross entropy expects (batch, classes, seq)
        return F.cross_entropy(rearrange(logits, 'b n d -> b d n'), labels)

.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\__init__.py

# 从 memory_efficient_attention_pytorch.memory_efficient_attention 模块中导入 Attention 类和 memory_efficient_attention 函数
from memory_efficient_attention_pytorch.memory_efficient_attention import Attention, memory_efficient_attention
# 从 memory_efficient_attention_pytorch.memory_efficient_cosine_sim_attention 模块中导入 CosineSimAttention 类和 numerically_unstable_memory_efficient_attention 函数
from memory_efficient_attention_pytorch.memory_efficient_cosine_sim_attention import CosineSimAttention, numerically_unstable_memory_efficient_attention
# 从 memory_efficient_attention_pytorch.flash_attention 模块中导入 FlashAttention 类
from memory_efficient_attention_pytorch.flash_attention import FlashAttention

Memory Efficient Attention Pytorch (obsolete)

Implementation of a memory efficient multi-head attention as proposed in the paper, Self-attention Does Not Need O(n²) Memory. In addition, the module will take care of masking, causal masking, as well as cross attention.

This repository also contains a naive non-CUDA implementation of the improvements made by Tri Dao with his Flash Attention 2 paper, for educational purposes. It is a game changer for attention and building long-context transformers.

Update: from now on, you should just be using the F.scaled_dot_product_attention function in Pytorch 2.0 for built-in Flash Attention v1 support - or use Flash Attention v2 at the official repository

Install

$ pip install memory-efficient-attention-pytorch

Usage

For autoregressive language model

import torch
from memory_efficient_attention_pytorch import Attention

attn = Attention(
    dim = 512,
    dim_head = 64,                # dimension per head
    heads = 8,                    # number of attention heads
    causal = True,                # autoregressive or not
    memory_efficient = True,      # whether to use memory efficient attention (can be turned off to test against normal attention)
    q_bucket_size = 1024,         # bucket size along queries dimension
    k_bucket_size = 2048          # bucket size along key / values dimension
).cuda()

x = torch.randn(1, 65536, 512).cuda()
out = attn(x) # (1, 65536, 512)

Cross attention

import torch
from memory_efficient_attention_pytorch import Attention

cross_attn = Attention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    memory_efficient = True,
    q_bucket_size = 1024,
    k_bucket_size = 2048
).cuda()

x = torch.randn(1, 65536, 512).cuda()
context = torch.randn(1, 65536, 512).cuda()
mask = torch.ones(1, 65536).bool().cuda()

out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512)

Citations

@misc{rabe2021selfattention,
    title   = {Self-attention Does Not Need $O(n^2)$ Memory}, 
    author  = {Markus N. Rabe and Charles Staats},
    year    = {2021},
    eprint  = {2112.05682},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{Dao2022FlashAttentionFA,
    title   = {FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
    author  = {Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R'e},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.14135}
}
@article{dao2023flashattention2,
  title     = {Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
  author    = {Dao, Tri},
  year      = {2023}
}

.\lucidrains\memory-efficient-attention-pytorch\setup.py

# Import packaging helpers: setup() drives the build, find_packages discovers packages
from setuptools import setup, find_packages

# Distribution metadata and build configuration for PyPI
setup(
  name = 'memory-efficient-attention-pytorch',  # distribution name
  packages = find_packages(exclude=[]),  # include every package in the repository
  version = '0.1.6',  # release version
  license='MIT',  # license identifier
  description = 'Memory Efficient Attention - Pytorch',  # short description
  long_description_content_type = 'text/markdown',  # long description rendered as markdown
  author = 'Phil Wang',  # author name
  author_email = 'lucidrains@gmail.com',  # author contact
  url = 'https://github.com/lucidrains/memory-efficient-attention-pytorch',  # project homepage
  keywords = [
    'artificial intelligence',  # PyPI search keywords
    'deep learning',
    'attention-mechanism'
  ],
  install_requires=[
    'einops>=0.4.1',  # runtime dependencies
    'torch>=1.6'
  ],
  setup_requires=[
    'pytest-runner',  # build-time (test runner) dependency
  ],
  tests_require=[
    'pytest'  # test-time dependency
  ],
  classifiers=[
    'Development Status :: 4 - Beta',  # trove classifiers for PyPI
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.8',
  ],
)