
Lucidrains Series Projects: Source Code Analysis (94)

.\lucidrains\spear-tts-pytorch\spear_tts_pytorch\__init__.py

# import the TextToSemantic, SpeechSpeechPretrainWrapper, SemanticToTextWrapper, TextToSemanticWrapper and SemanticToTextDatasetGenerator classes from the spear_tts_pytorch package
# import the SpeechSpeechPretrainer, SemanticToTextTrainer and TextToSemanticTrainer classes from the trainer module
# import the GeneratedAudioTextDataset and MockDataset classes from the data module
from spear_tts_pytorch.spear_tts_pytorch import (
    TextToSemantic,
    SpeechSpeechPretrainWrapper,
    SemanticToTextWrapper,
    TextToSemanticWrapper,
    SemanticToTextDatasetGenerator
)

from spear_tts_pytorch.trainer import (
    SpeechSpeechPretrainer,
    SemanticToTextTrainer,
    TextToSemanticTrainer
)

from spear_tts_pytorch.data import (
    GeneratedAudioTextDataset,
    MockDataset
)

Data source

The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/

Speculative Decoding

Explorations into some recent techniques surrounding speculative decoding

Also have a few ideas of my own that I will try and share in this repository, if they work. The goal is to initially use it to speed up the text-to-semantic decoder in Spear-TTS

Appreciation

  • StabilityAI and 🤗 Huggingface for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source current artificial intelligence techniques.

Todo

  • in early exit scheme, cache the hidden layer during spec decoding, as small and large models share the same first few layers

  • for early exit, allow an extra transformer block head (separate from main transformer stem)

  • figure out batched spec decoding - different rows may advance at different rates

  • further optimize batched spec decoding, as losing some performance from all the indexing - seems like it will take some work for this technique to be actually usable

  • make batched spec decoding work with early exit strategy

  • complete speculative sampling with prophet transformer idea - seems to work well! 🙌

  • get some wandb charts and see how prophet compares with early exit strategy, share on repository

  • also run experiments to see if prophet transformer brings any benefit to main model loss. original prophet paper only did a simple linear projection

  • for early exit strategy, try randomly summing last cached embedding back to the same model (a la alphafold2 recycling), randomly cropped along sequence length, and train early exit loss this way. see if one can improve the gamma this way

  • dedicate a morning to microoptimizations

Citations

@inproceedings{Leviathan2022FastIF,
    title   = {Fast Inference from Transformers via Speculative Decoding},
    author  = {Yaniv Leviathan and Matan Kalman and Y. Matias},
    booktitle = {International Conference on Machine Learning},
    year    = {2022},
    url     = {https://api.semanticscholar.org/CorpusID:254096365}
}
@inproceedings{sun2023spectr,
    title     = {SpecTr: Fast Speculative Decoding via Optimal Transport},
    author    = {Ziteng Sun and Ananda Theertha Suresh and Jae Hun Ro and Ahmad Beirami and Himanshu Jain and Felix Yu and Michael Riley and Sanjiv Kumar},
    booktitle = {Workshop on Efficient Systems for Foundation Models @ ICML2023},
    year      = {2023},
    url       = {https://openreview.net/forum?id=d0mGsaheuT}
}
@article{Chen2023AcceleratingLL,
    title     = {Accelerating Large Language Model Decoding with Speculative Sampling},
    author    = {Charlie Chen and Sebastian Borgeaud and Geoffrey Irving and Jean-Baptiste Lespiau and L. Sifre and John M. Jumper},
    journal   = {ArXiv},
    year      = {2023},
    volume    = {abs/2302.01318},
    url       = {https://api.semanticscholar.org/CorpusID:256503945}
}
@article{Yan2020ProphetNetPF,
    title   = {ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training},
    author  = {Yu Yan and Weizhen Qi and Yeyun Gong and Dayiheng Liu and Nan Duan and Jiusheng Chen and Ruofei Zhang and Ming Zhou},
    journal = {ArXiv},
    year    = {2020},
    volume  = {abs/2001.04063},
    url     = {https://api.semanticscholar.org/CorpusID:210164665}
}
@article{Zhang2023DraftV,
    title     = {Draft \& Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding},
    author    = {Jinchao Zhang and Jue Wang and Huan Li and Lidan Shou and Ke Chen and Gang Chen and Sharad Mehrotra},
    journal   = {ArXiv},
    year      = {2023},
    volume    = {abs/2309.08168},
    url       = {https://api.semanticscholar.org/CorpusID:262013673}
}
@misc{medusa,
    author     = {Tianle Cai and Yuhong Li and Zhengyang Geng and Hongwu Peng and Tri Dao},
    title      = {Medusa: Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads},
    year       = {2023},
    publisher  = {GitHub},
    journal    = {GitHub repository},
    howpublished = {\url{https://github.com/FasterDecoding/Medusa}},
}

.\lucidrains\speculative-decoding\setup.py

# import the setup and package-discovery helpers
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'speculative-decoding', # package name
  packages = find_packages(exclude=[]), # find all packages, excluding none
  version = '0.1.2', # version number
  license='MIT', # license
  description = 'Speculative Decoding', # description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author email
  long_description_content_type = 'text/markdown', # long description content type
  url = 'https://github.com/lucidrains/speculative-decoding', # project URL
  keywords = [ # keyword list
    'artificial intelligence',
    'deep learning',
    'transformers',
    'efficient decoding'
  ],
  install_requires=[ # install dependencies
    'beartype',
    'einops>=0.6.1',
    'torch>=1.12',
  ],
  classifiers=[ # classifier list
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\speculative-decoding\speculative_decoding\speculative_decoding.py

import math
# import the math library

import torch
# import PyTorch
from torch.nn import Module, ModuleList
# import Module and ModuleList from PyTorch
from torch import nn, einsum, Tensor
# import nn, einsum and Tensor from PyTorch
import torch.nn.functional as F
# import torch.nn.functional, aliased as F

from rotary_embedding_torch import RotaryEmbedding
# import RotaryEmbedding from the rotary-embedding-torch library (shadowed by the local class defined below)
from beartype import beartype
# import beartype for runtime type checking

from collections import namedtuple
# import namedtuple

from einops import rearrange
# import rearrange from einops

# constants

Cache = namedtuple('Cache', ['cached_kvs', 'embeds'])
# named tuple holding the cached key/values and the embeddings

# helper functions

# checks whether a value is present (not None)
def exists(val):
    return val is not None

# returns val if present, otherwise the default d
def default(val, d):
    return val if exists(val) else d

# sampling helpers

# numerically safe log, clamping away from zero
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# draws Gumbel noise with the same shape as t
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# temperature-controlled sampling via the Gumbel-max trick
def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)

# keeps the top-k logits and sets the rest to -inf
def top_k(logits, thres = 0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(-1, ind, val)
    return probs
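The `top_k` helper keeps only a `1 - thres` fraction of the vocabulary and masks the rest to negative infinity. A pure-Python sketch of the same filtering, on plain lists instead of tensors (the name `top_k_filter` is illustrative, not from the repository):

```python
import math

def top_k_filter(logits, thres = 0.9):
    # keep the ceil((1 - thres) * vocab) largest logits, mask the rest to -inf
    k = math.ceil((1 - thres) * len(logits))
    kept = sorted(logits, reverse = True)[:k]
    cutoff = kept[-1]
    return [l if l >= cutoff else float('-inf') for l in logits]

filtered = top_k_filter([1.0, 5.0, 3.0, 2.0, 4.0, 0.5, 0.1, 2.5, 3.5, 1.5], thres = 0.8)
# with thres = 0.8 and a 10-entry vocabulary, k = 2, so only the two largest logits survive
```

Note that with the default `thres = 0.9`, one tenth of the vocabulary is kept, which is why the parameter reads like a nucleus-style threshold even though the filtering is plain top-k.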

# rotary embeddings

# rotary positional embedding: precomputes the inverse frequencies
class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    # generates the rotary frequencies for a sequence of the given length
    def forward(self, seq_len):
        t = torch.arange(seq_len, device = self.inv_freq.device).type_as(self.inv_freq)
        freqs = einsum('i, j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)
        return freqs

# rotates the second half of the last dimension into the first
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# applies the rotary positional embedding to a tensor
def apply_rotary_pos_emb(pos, t):
    seq_len = t.shape[-2]
    pos = pos[-seq_len:, :]
    return t * pos.cos() + rotate_half(t) * pos.sin()
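`rotate_half` together with the cos/sin mixing in `apply_rotary_pos_emb` implements the standard rotary (RoPE) rotation. A minimal pure-Python check of `rotate_half` on a single 4-dimensional vector (list-based re-rendering, for illustration only):

```python
def rotate_half(x):
    # split the vector in two halves and rotate: (x1, x2) -> (-x2, x1)
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    return [-v for v in x2] + x1

rotated = rotate_half([1.0, 2.0, 3.0, 4.0])
# → [-3.0, -4.0, 1.0, 2.0]
```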

# different decoding strategies

@torch.no_grad()
def base_decoding(
    net: Module,
    prompt: Tensor,
    seq_len: int,
    temperature = 1.,
    filter_thres = 0.9,
):
    prompt_seq_len, out = prompt.shape[-1], prompt.clone()
    sample_num_times = max(0, seq_len - prompt_seq_len)

    cache = None

    for _ in range(sample_num_times):
        logits, cache = net(out, cache = cache, return_cache = True)
        logits = logits[:, -1]

        logits = top_k(logits, thres = filter_thres)
        sample = gumbel_sample(logits, temperature = temperature, dim = -1)

        out = torch.cat((out, sample[..., None]), dim = -1)

    return out[..., prompt_seq_len:]
# base_decoding: plain autoregressive sampling loop (top-k filter + Gumbel sample)

# speculative decoding functions

# division that guards against a zero denominator
def safe_div(num, den, eps = 1e-10):
    return num / max(den, eps)

# returns the index of the first True along dim (or the length if none)
def find_first_true_index(bool_tensor, dim = -1):
    return (bool_tensor.cumsum(dim = dim) == 0).sum(dim = dim)
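`find_first_true_index` relies on a cumulative-sum trick: every position before the first `True` has a running sum of zero, so counting those positions yields the index. The same logic in pure Python (the list-based function here is illustrative):

```python
def find_first_true_index(bools):
    # count positions whose running count of True values is still zero
    running, count = 0, 0
    for b in bools:
        running += int(b)
        if running == 0:
            count += 1
    return count

idx = find_first_true_index([False, False, True, False, True])
# → 2; when nothing is True the list length is returned
```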

@torch.no_grad()
def speculative_decoding(
    net: Module,
    small_net: Module,
    prompt: Tensor,
    seq_len: int,
    gamma: int = 5,
    temperature = 1.,
    filter_thres = 0.9,
    lenience = 1.,
    pad_id = 0
):
    """
    eq. algorithm 1 in paper https://arxiv.org/abs/2211.17192
    """
    # 假设性解码函数,参考论文中的算法1

    batch, prompt_seq_len, out, device = *prompt.shape, prompt.clone(), prompt.device
    sample_num_times = max(0, seq_len - prompt_seq_len)

    cache = None
    small_cache = None

    num_steps = 0
    total_accepted = 0

    batch_range = torch.arange(batch, device = device, dtype = torch.long)[..., None]
    seq_lens = torch.full((batch,), prompt_seq_len, device = device, dtype = torch.long)

    # now left align

    num_pad_left = out.shape[-1] - seq_lens
    max_pad_left = num_pad_left.amax()
    out = F.pad(out, (0, max_pad_left), value = pad_id)

    seq_len_range = torch.arange(seq_len, device = device, dtype = torch.long)
    out = out[batch_range, seq_len_range + num_pad_left[..., None]]

    return out[..., prompt_seq_len:], total_accepted / num_steps
# speculative_decoding: draft with the small network, verify with the large one
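The core of algorithm 1 is a per-token accept/reject test: a draft token `x` sampled from the small model's distribution `q` is accepted with probability `min(1, p(x) / q(x))`, and on rejection a replacement is drawn from the residual distribution `normalize(relu(p - q))`. A pure-Python sketch of just that test (function names are illustrative; a fixed `r` is passed in to make the example deterministic):

```python
import random

def residual_distribution(p, q):
    # on rejection, resample from normalize(max(0, p - q))
    diff = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(diff)
    return [d / total for d in diff]

def accept_draft(p_x, q_x, r = None):
    # accept the draft token with probability min(1, p(x) / q(x))
    r = random.random() if r is None else r
    return r < min(1.0, p_x / q_x)

p = [0.5, 0.3, 0.2]   # large (verifier) model distribution
q = [0.2, 0.5, 0.3]   # small (draft) model distribution
resid = residual_distribution(p, q)
# only token 0 has p > q here, so the residual puts all mass on it: [1.0, 0.0, 0.0]
```

Sampling from the residual on rejection is what makes the overall procedure exactly equivalent to sampling from the large model alone.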

@torch.no_grad()
def speculative_decoding_with_same_model(
    net: Module,
    prompt: Tensor,
    seq_len: int,
    gamma: int = 5,
    temperature = 1.,
    filter_thres = 0.9,
    lenience = 1.,
    pad_id = 0
):
    """
    eq. algorithm 1 in paper https://arxiv.org/abs/2211.17192
    """
    # speculative decoding where the same model serves as its own draft network (early exit)
    # unpack batch size and prompt length; clone the prompt as the running output
    batch, prompt_seq_len, out, device = *prompt.shape, prompt.clone(), prompt.device
    # number of tokens left to sample
    sample_num_times = max(0, seq_len - prompt_seq_len)

    # initialize the kv caches
    cache = None
    small_cache = None

    # track the step count and how many draft tokens were accepted
    num_steps = 0
    total_accepted = 0

    # batch indices and per-row sequence lengths
    batch_range = torch.arange(batch, device=device, dtype=torch.long)[..., None]
    seq_lens = torch.full((batch,), prompt_seq_len, device=device, dtype=torch.long)

    # left-align the output by padding
    num_pad_left = out.shape[-1] - seq_lens
    max_pad_left = num_pad_left.amax()
    out = F.pad(out, (0, max_pad_left), value=pad_id)

    # gather the left-aligned output
    seq_len_range = torch.arange(seq_len, device=device, dtype=torch.long)
    out = out[batch_range, seq_len_range + num_pad_left[..., None]]

    # return the generated tokens and the acceptance rate
    return out[..., prompt_seq_len:], total_accepted / num_steps

# RMS normalization module
class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * self.gamma
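`F.normalize(x, dim = -1) * dim ** 0.5` is equivalent to dividing `x` by its root mean square, since the L2 norm equals `rms(x) * sqrt(dim)`. A pure-Python check of that equivalence on one vector, ignoring the learned `gamma` (both helper names are illustrative):

```python
import math

def rmsnorm(x):
    # divide by the root mean square of the vector (unit gamma, no bias)
    rms = math.sqrt(sum(v * v for v in x) / len(x))
    return [v / rms for v in x]

def l2_normalize_scaled(x):
    # the formulation used in the module: l2-normalize, then scale by sqrt(dim)
    norm = math.sqrt(sum(v * v for v in x))
    return [v / norm * math.sqrt(len(x)) for v in x]

a = rmsnorm([3.0, 4.0])
b = l2_normalize_scaled([3.0, 4.0])
# the two agree, e.g. a[0] ≈ 0.8485
```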

# causal self-attention module
class CausalAttention(Module):
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        dim_inner = dim_head * heads

        self.norm = RMSNorm(dim)

        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
        self.to_out = nn.Linear(dim_inner, dim, bias = False)

    def forward(
        self,
        x,
        cache = None,
        context_mask = None,
        rotary_emb = None
    ):
        h, device = self.heads, x.device

        x = self.norm(x)

        q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)

        if exists(cache):
            ck, cv = cache.unbind(dim = 1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        cached_kv = torch.stack((k, v), dim = 1)

        if exists(rotary_emb):
            q = apply_rotary_pos_emb(rotary_emb, q)
            k = apply_rotary_pos_emb(rotary_emb, k)

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        i, j = sim.shape[-2:]
        causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)

        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        if exists(context_mask):
            context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~context_mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        return out, cached_kv

# feedforward module
def FeedForward(dim, mult = 4):
    dim_inner = dim * mult
    return nn.Sequential(
        RMSNorm(dim),
        nn.Linear(dim, dim_inner),
        nn.GELU(),
        nn.Linear(dim_inner, dim)
    )

# main decoder class
class Decoder(Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        ignore_index = -1,
        early_exit_layer = None,
        early_exit_extra_transformer_blocks = 0,
        detach_early_exit_hiddens = False
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)

        self.layers = ModuleList([])

        self.rotary_emb = RotaryEmbedding(dim = dim_head)

        # stack the decoder layers, each with causal attention and a feedforward block
        for _ in range(depth):
            self.layers.append(ModuleList([
                CausalAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

        # output head projecting decoder states to token logits
        self.to_logits = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, num_tokens, bias = False)
        )

        self.detach_early_exit_hiddens = detach_early_exit_hiddens
        self.early_exit_layer = early_exit_layer
        self.to_early_exit_logits = None
        self.early_exit_transformer_blocks = ModuleList([])

        # if an early-exit layer is set, build the extra transformer blocks for it
        if exists(early_exit_layer):
            for _ in range(early_exit_extra_transformer_blocks):
                self.early_exit_transformer_blocks.append(ModuleList([
                    CausalAttention(dim = dim, dim_head = dim_head, heads = heads),
                    FeedForward(dim = dim, mult = ff_mult)
                ]))

            # output head for the early-exit branch
            self.to_early_exit_logits = nn.Sequential(
                RMSNorm(dim),
                nn.Linear(dim, num_tokens, bias = False)
            )

        self.ignore_index = ignore_index
    # forward pass
    def forward(
        self,
        x,
        return_loss = False,            # whether to return the loss
        return_cache = False,           # whether to return the kv cache
        seq_start_pos = None,           # per-row sequence start positions
        cache = None,                   # kv cache from a previous step
        early_exit_cache = None,        # cache for the early-exit branch
        return_early_exit_only = False, # return only the early-exit logits
        start_from_early_exit_hiddens = False  # resume from cached early-exit hiddens
.\lucidrains\speculative-decoding\speculative_decoding\speculative_decoding_with_prophet.py

import math
import torch
from torch.nn import Module, ModuleList
from torch import nn, einsum, Tensor
import torch.nn.functional as F
from rotary_embedding_torch import RotaryEmbedding
from beartype import beartype
from collections import namedtuple
from einops import rearrange

# named tuple holding the cached key/values and the embeddings
Cache = namedtuple('Cache', ['cached_kvs', 'embeds'])

# helper functions

# checks whether a value is present (not None)
def exists(val):
    return val is not None

# returns val if present, otherwise the default d
def default(val, d):
    return val if exists(val) else d

# sampling helpers

# numerically safe log, clamping away from zero
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# draws Gumbel noise with the same shape as t
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# temperature-controlled sampling via the Gumbel-max trick
def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)

# keeps the top-k logits and sets the rest to -inf
def top_k(logits, thres = 0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(-1, ind, val)
    return probs

# rotary embeddings

# rotary positional embedding class
class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len):
        t = torch.arange(seq_len, device = self.inv_freq.device).type_as(self.inv_freq)
        freqs = einsum('i, j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)
        return freqs

# rotates the second half of the last dimension into the first
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# applies the rotary positional embedding
def apply_rotary_pos_emb(pos, t):
    seq_len = t.shape[-2]
    pos = pos[-seq_len:, :]
    return t * pos.cos() + rotate_half(t) * pos.sin()

# different decoding strategies

# base decoding: plain autoregressive sampling loop
@torch.no_grad()
def base_decoding(
    net: Module,
    prompt: Tensor,
    seq_len: int,
    temperature = 1.,
    filter_thres = 0.9,
):
    prompt_seq_len, out = prompt.shape[-1], prompt.clone()
    sample_num_times = max(0, seq_len - prompt_seq_len)

    cache = None

    for _ in range(sample_num_times):
        logits, cache = net(out, cache = cache, return_cache = True)
        logits = logits[:, -1]

        logits = top_k(logits, thres = filter_thres)
        sample = gumbel_sample(logits, temperature = temperature, dim = -1)

        out = torch.cat((out, sample[..., None]), dim = -1)

    return out[..., prompt_seq_len:]

# norm

# root-mean-square normalization
class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * self.gamma

# attention and feedforward

# causal attention
class CausalAttention(Module):
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        dim_inner = dim_head * heads

        self.norm = RMSNorm(dim)

        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
        self.to_out = nn.Linear(dim_inner, dim, bias = False)

    def forward(
        self,
        x,
        cache = None,
        context_mask = None,
        rotary_emb = None
        ):
        # number of heads and device of the input
        h, device = self.heads, x.device

        # pre-norm
        x = self.norm(x)

        # project to queries, keys, values and split out the heads
        q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)

        # if a cache exists, prepend the cached keys and values
        if exists(cache):
            ck, cv = cache.unbind(dim = 1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        # stack the keys and values for the returned cache
        cached_kv = torch.stack((k, v), dim = 1)

        # apply rotary positional embeddings to queries and keys
        if exists(rotary_emb):
            q = apply_rotary_pos_emb(rotary_emb, q)
            k = apply_rotary_pos_emb(rotary_emb, k)

        # attention similarity matrix
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        i, j = sim.shape[-2:]
        # causal mask
        causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)

        # mask out future positions
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # apply the context mask, if given
        if exists(context_mask):
            context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~context_mask, -torch.finfo(sim.dtype).max)

        # attention weights
        attn = sim.softmax(dim = -1)

        # aggregate the values
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge the heads
        out = rearrange(out, 'b h n d -> b n (h d)')
        # final output projection
        out = self.to_out(out)

        # return the output and the cached key/values
        return out, cached_kv
# feedforward module: RMSNorm, linear projection up, GELU, linear projection down
def FeedForward(dim, mult = 4):
    # inner dimension
    dim_inner = dim * mult
    return nn.Sequential(
        RMSNorm(dim),               # normalize the input
        nn.Linear(dim, dim_inner),  # project up to the inner dimension
        nn.GELU(),                  # GELU activation
        nn.Linear(dim_inner, dim)   # project back down
    )

# main class

class Decoder(Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        ignore_index = -1
    ):
        super().__init__()
        self.dim = dim
        self.token_emb = nn.Embedding(num_tokens, dim)  # token embedding table

        self.layers = ModuleList([])  # container for the decoder layers

        self.rotary_emb = RotaryEmbedding(dim = dim_head)  # rotary embedding for relative positions

        for _ in range(depth):
            self.layers.append(ModuleList([
                CausalAttention(dim = dim, dim_head = dim_head, heads = heads),  # causal attention layer
                FeedForward(dim = dim, mult = ff_mult)                           # feedforward block
            ]))

        self.to_logits = nn.Sequential(
            RMSNorm(dim),                             # normalize before the head
            nn.Linear(dim, num_tokens, bias = False)  # project to token logits, no bias
        )

        self.ignore_index = ignore_index  # label index ignored by the loss

    def forward(
        self,
        x,
        start_tokens = None,
        return_loss = False,
        return_cache = False,
        seq_start_pos = None,
        cache = None
    ):
        has_start_tokens = exists(start_tokens)  # whether start tokens were passed in

        start_token_len = 0
        if exists(start_tokens):
            if start_tokens.ndim == 2:
                start_tokens = rearrange(start_tokens, 'b d -> b 1 d')  # add a sequence dimension

            start_token_len = start_tokens.shape[-2]  # number of start tokens

        if return_loss:
            x, labels = x[:, start_token_len:-1], x[:, 1:]  # shift input and labels for the loss

        x = self.token_emb(x)  # embed the input tokens

        if exists(start_tokens):
            x = torch.cat((start_tokens, x), dim = 1)  # prepend the start token embeddings

        # handle a per-row sequence start offset

        self_attn_kv_mask = None
        if exists(seq_start_pos):
            batch, seq_len = x.shape[:2]
            seq_range = torch.arange(seq_len, device = x.device, dtype = torch.long)
            self_attn_kv_mask = seq_range >= seq_start_pos[..., None]  # mask out positions before the start

        # relative positions

        rotary_emb = self.rotary_emb(x.shape[-2])  # rotary positional embeddings

        # set up the cache

        new_cached_kvs = []  # new key/value cache, one entry per layer

        cache_kvs = cache_embeds = None

        if exists(cache):
            cache_kvs, cache_embeds = cache  # unpack cached key/values and embeddings

        if exists(cache_kvs):
            iter_cache_kvs = iter(cache_kvs.unbind(dim = 1))  # iterator over per-layer caches
        else:
            iter_cache_kvs = iter([])  # empty iterator when there is no cache

        # if a cache was passed in, only process the new tokens

        if exists(cache):
            num_tokens_keep = x.shape[-2] - cache_kvs.shape[-2]  # number of uncached tokens
            x = x[:, -num_tokens_keep:]

        # main transformer body

        for ind, (attn, ff) in enumerate(self.layers):
            layer = ind + 1  # 1-indexed layer number

            residual = x  # save for the residual connection
            attn_out, cached_kv = attn(x, rotary_emb = rotary_emb, cache = next(iter_cache_kvs, None))
            x = residual + attn_out  # attention with residual

            new_cached_kvs.append(cached_kv)  # collect this layer's cache

            x = ff(x) + x  # feedforward with residual (missing in the excerpt; the layer is unused otherwise)

        new_cached_kvs = torch.stack(new_cached_kvs, dim = 1)  # stack the caches across layers

        logits = self.to_logits(x)  # output logits

        if not return_loss:
            if not return_cache:
                return logits  # just the logits

            return logits, Cache(new_cached_kvs, x)  # logits plus the cache

        loss = F.cross_entropy(
            rearrange(logits, 'b n c -> b c n'),  # cross entropy expects (batch, classes, seq)
            labels,
            ignore_index = self.ignore_index
        )

        return loss, Cache(new_cached_kvs, x)  # loss plus the cache

class ModelWithProphetWrapper(Module):
    def __init__(
        self,
        model: Decoder,
        prophet: Decoder,
        prophet_train_length = 8,  # prophet training length; should exceed the main model's decoding gamma, since the main model's cached embeddings lag by one step
        detach_model_embed_for_prophet = False,
        num_leading_start_tokens = 1
    ):
        super().__init__()
        # store the main model and the prophet
        self.model = model
        self.prophet = prophet

        # 判断模型和prophet的维度是否相同
        model_prophet_same_dim = model.dim == prophet.dim
        # 如果维度相同,则使用nn.Identity(),否则使用nn.Linear()进行维度转换
        self.to_prophet_start_token = nn.Identity() if model_prophet_same_dim else nn.Linear(model.dim, prophet.dim, bias = False)

        # 确保num_leading_start_tokens大于等于1
        assert num_leading_start_tokens >= 1
        self.num_leading_start_tokens = num_leading_start_tokens

        # 设置prophet的训练长度和是否在模型嵌入中分离prophet
        self.prophet_train_length = prophet_train_length
        self.detach_model_embed_for_prophet = detach_model_embed_for_prophet

    # forward pass
    def forward(self, x):
        # unpack the leading start token count, batch, sequence length, device
        num_start_tokens = self.num_leading_start_tokens
        batch, seq_len, device = *x.shape, x.device
        prophet_seq_len = self.prophet_train_length
        # the sequence must be at least as long as the prophet training length
        assert seq_len >= prophet_seq_len

        total_loss = 0.

        # run the main model, getting its loss plus the cached key/values and embeddings
        main_loss, (cached_kvs, embeds) = self.model(x, return_loss = True)

        total_loss = total_loss + main_loss

        # optionally stop prophet gradients from flowing into the main model
        if self.detach_model_embed_for_prophet:
            embeds = embeds.detach()

        # project the model embeddings into prophet start tokens
        prophet_start_tokens = self.to_prophet_start_token(embeds)

        # index helpers over the batch and the prophet window
        batch_arange = torch.arange(batch, device = device, dtype = torch.long)
        prophet_seq_arange = torch.arange(prophet_seq_len, device = device, dtype = torch.long)

        # number of sliding windows the prophet trains on
        num_seq_train_prophet = seq_len - prophet_seq_len - (num_start_tokens - 1)

        # window offsets
        offsets = torch.arange(num_seq_train_prophet, device = device, dtype = torch.long)

        # gather overlapping windows of the input for the prophet
        prophet_input = x[
            batch_arange[:, None, None],
            offsets[..., None] + prophet_seq_arange
        ]

        # flatten the windows into the batch dimension
        prophet_input = rearrange(prophet_input, '... n -> (...) n')

        # index helper over the leading start tokens
        start_tokens_arange = torch.arange(num_start_tokens, device = device, dtype = torch.long)

        # gather the start tokens for each window
        prophet_start_tokens = prophet_start_tokens[
            batch_arange[:, None, None],
            offsets[..., None] + start_tokens_arange
        ]

        # flatten the start tokens to match the flattened windows
        prophet_start_tokens = rearrange(prophet_start_tokens[:, :num_seq_train_prophet], 'b n l d -> (b n) l d')

        # run the prophet and get its loss
        prophet_loss, _ = self.prophet(prophet_input, start_tokens = prophet_start_tokens, return_loss = True)

        total_loss = total_loss + prophet_loss

        # return the total loss along with the individual losses
        return total_loss, (main_loss, prophet_loss)
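The offsets-based indexing above carves the training sequence into overlapping windows of length `prophet_train_length`, one per valid starting position, which are then flattened into the batch dimension. The same windowing in pure Python (the `sliding_windows` helper is illustrative):

```python
def sliding_windows(seq, window_len, num_start_tokens = 1):
    # number of windows, matching num_seq_train_prophet above
    num_windows = len(seq) - window_len - (num_start_tokens - 1)
    return [seq[i : i + window_len] for i in range(num_windows)]

windows = sliding_windows([0, 1, 2, 3, 4, 5, 6, 7], window_len = 3)
# → [[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6]]
```

Each window gets its own start-token slice from the main model's embeddings, so the prophet learns to predict several steps ahead from any position in the sequence.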
# division that guards against a zero denominator
def safe_div(num, den, eps = 1e-10):
    return num / max(den, eps)

# returns the index of the first True along dim (or the length if none)
def find_first_true_index(bool_tensor, dim = -1):
    return (bool_tensor.cumsum(dim = dim) == 0).sum(dim = dim)

# speculative decoding using the prophet model as the draft network
@torch.no_grad()
def speculative_decoding_with_prophet_model(
    net: ModelWithProphetWrapper,
    prompt: Tensor,
    seq_len: int,
    gamma: int = 5,
    temperature = 1.,
    filter_thres = 0.9,
    lenience = 1.,
    pad_id = 0
):
    """
    eq. algorithm 1 in paper https://arxiv.org/abs/2211.17192
    """

    # extract the main model, the prophet, and the projection between them (used when their dims differ)

    model = net.model
    to_prophet_start_token = net.to_prophet_start_token
    prophet = net.prophet
    num_start_tokens = net.num_leading_start_tokens

    batch, prompt_seq_len, out, device = *prompt.shape, prompt.clone(), prompt.device

    if (seq_len - prompt_seq_len) <= 0:
        return prompt, None

    cache = None
    small_cache = None

    num_steps = 0
    total_accepted = 0

    batch_range = torch.arange(batch, device = device, dtype = torch.long)[..., None]
    seq_lens = torch.full((batch,), prompt_seq_len, device = device, dtype = torch.long)

    # sample the leading token(s) from the main model

    for _ in range(max(1, num_start_tokens - prompt_seq_len)):
        logits, cache = model(out, cache = cache, return_cache = True)
        logits = logits[:, -1:]
        logits = top_k(logits, thres = filter_thres)
        sample = gumbel_sample(logits, temperature = temperature, dim = -1)
        out = torch.cat((out, sample), dim = -1)
        seq_lens += 1

    # we now have the first cached embeddings, used as the prophet's start tokens for speculative sampling

    _, embeds = cache
    next_prophet_start_tokens = to_prophet_start_token(embeds[:, -num_start_tokens:])
    # loop until every row has reached the target sequence length
    while (seq_lens < seq_len).any():

        # predict with the smaller (prophet) network

        # collect the prophet's logits and sampled tokens
        all_small_logits = []
        q_sampled_out = []

        small_cache = None
        num_tokens = 2  # the main model's embeddings lag the main sequence by one step

        # draft gamma tokens
        for _ in range(gamma):
            # run the prophet on the latest tokens
            small_logits, small_cache = prophet(
                out[..., -num_tokens:],
                start_tokens = next_prophet_start_tokens,
                cache = small_cache,
                return_cache = True
            )

            small_logits = small_logits[:, -1:]

            # top-k filter the logits
            small_logits = top_k(small_logits, thres = filter_thres)
            all_small_logits.append(small_logits)

            # sample via the Gumbel-max trick
            sample = gumbel_sample(small_logits, temperature = temperature, dim = -1)
            out = torch.cat((out, sample), dim = -1)

            seq_lens += 1
            num_tokens += 1

            q_sampled_out.append(rearrange(sample, '... -> ... 1'))

        q_sampled_out = torch.cat(q_sampled_out, dim = -2)
        small_logits = torch.cat(all_small_logits, dim = -2)

        # verify with the larger network

        logits, cache = model(
            out,
            cache = cache,
            return_cache = True,
            seq_start_pos = out.shape[-1] - seq_lens
        )

        logits = logits[..., -(gamma + 1):, :]
        logits = top_k(logits, thres = filter_thres)

        # probabilities of the larger and smaller networks (p(x) and q(x) in algorithm 1)

        prob = safe_div(logits, temperature).softmax(dim = -1)
        small_prob = safe_div(small_logits, temperature).softmax(dim = -1)

        p, prob_next = prob[:, :-1], prob[:, -1]

        p = p.gather(-1, q_sampled_out)
        q = small_prob.gather(-1, q_sampled_out) * lenience

        p, q = [rearrange(t, 'b n 1 -> b n') for t in (p, q)]

        r = random_uniform = torch.zeros_like(q).float().uniform_(0, 1)

        accepted = find_first_true_index(r > (p / q))

        total_accepted += accepted.float().mean()
        num_steps += 1

        num_rejected = gamma - accepted
        has_rejected = num_rejected > 0

        accepted = rearrange(accepted, 'b -> b 1')
        accepted.clamp_(max = gamma - 1)

        adjusted_prob = F.relu(prob[batch_range, accepted] - small_prob[batch_range, accepted])
        adjusted_prob = adjusted_prob / adjusted_prob.sum(dim = -1, keepdim = True)
        adjusted_prob = rearrange(adjusted_prob, 'b 1 d -> b d')

        prob_next = torch.where(
            rearrange(has_rejected, '... -> ... 1'),
            adjusted_prob,
            prob_next
        )

        # do a bunch of slicing to align everything to the right, including the kv cache

        max_num_rejected = num_rejected.amax()
        seq_arange = torch.arange(out.shape[-1], device = device, dtype = torch.long)
        seq_offset_indices = seq_arange + (max_num_rejected - num_rejected)[..., None]

        seq_lens -= num_rejected
        max_seq_len = seq_lens.amax()

        if batch > 1:
            out = F.pad(out, (0, max_num_rejected), value = pad_id)
            out = out[batch_range, seq_offset_indices]

            cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in cache)
            cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in cache)
            cache = tuple(t[batch_range, seq_offset_indices] for t in cache)
            cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in cache)

            if out.shape[-1] > max_seq_len:
                left_index = out.shape[-1] - max_seq_len
                out = out[:, left_index:]
                cache = tuple(t[..., left_index:, :] for t in cache)

        # sample one extra token, a trick from the paper for better bounding the worst case

        next_token = torch.multinomial(prob_next, 1)

        out = torch.cat((out, next_token), dim = -1)
        seq_lens += 1

        _, embeds = cache
        next_prophet_start_tokens = to_prophet_start_token(embeds[:, -num_start_tokens:])
    # now left align the output

    # number of tokens of left padding needed per row
    num_pad_left = out.shape[-1] - seq_lens
    # maximum amount of left padding across the batch
    max_pad_left = num_pad_left.amax()
    # pad the output on the right by that maximum, with pad_id
    out = F.pad(out, (0, max_pad_left), value=pad_id)

    # arange over the target sequence length
    seq_len_range = torch.arange(seq_len, device=device, dtype=torch.long)
    # gather the valid tokens per row, offset by each row's left padding
    out = out[batch_range, seq_len_range + num_pad_left[..., None]]

    # return the output without the prompt, along with the average acceptance rate
    return out[..., prompt_seq_len:], total_accepted / num_steps
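To recap the accept/reject rule the function above implements (Algorithm 1 of the speculative sampling paper): a draft token sampled from the small distribution q is kept when a uniform draw r satisfies r ≤ p/q under the large model's distribution p, and at the first rejection a replacement is drawn from the normalized residual max(p − q, 0). A minimal pure-Python sketch over toy distributions (my own illustration, not the repository's tensor implementation):

```python
import random

def speculative_accept(p, q, drafted, rng):
    # p, q: per-position vocab probabilities of the large / small model
    # drafted: tokens sampled from q, one per draft position
    # returns the accepted prefix, plus a corrective token on the first rejection
    out = []
    for pos, token in enumerate(drafted):
        if rng.random() <= min(1.0, p[pos][token] / q[pos][token]):
            out.append(token)
            continue
        # first rejection: sample from the normalized residual max(p - q, 0)
        # (a rejection implies p[token] < q[token], so the residual has mass)
        residual = [max(pi - qi, 0.0) for pi, qi in zip(p[pos], q[pos])]
        total = sum(residual)
        weights = [r / total for r in residual]
        out.append(rng.choices(range(len(weights)), weights = weights)[0])
        break
    return out

rng = random.Random(0)
p = [[0.9, 0.1], [0.2, 0.8]]
q = [[0.5, 0.5], [0.5, 0.5]]
print(speculative_accept(p, q, [0, 0], rng))
```

Here the first draft token is accepted (the large model agrees strongly), while the second is rejected and resampled from the residual.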

.\lucidrains\speculative-decoding\speculative_decoding\__init__.py

# import the Decoder class and the base_decoding, speculative_decoding and speculative_decoding_with_same_model functions from the speculative_decoding.speculative_decoding module
from speculative_decoding.speculative_decoding import (
    Decoder,
    base_decoding,
    speculative_decoding,
    speculative_decoding_with_same_model
)

.\lucidrains\speculative-decoding\train.py

# import the required libraries
import gzip
import random
import tqdm
import numpy as np
import time
from functools import wraps, partial
import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.cuda import synchronize, Event
from torch.utils.data import DataLoader, Dataset
from speculative_decoding import (
    Decoder,
    base_decoding,
    speculative_decoding
)

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRAD_ACCUM_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
PRIME_LENGTH = 128
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 512
GAMMA = 5
DEVICE_STR = 'cuda' if torch.cuda.is_available() else 'cpu'

# helper functions
def cycle(loader):
    while True:
        for data in loader:
            yield data

def decode_token(token):
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))
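As an aside, the two decoders above simply map enwik8 byte values back to characters, with control codes below 32 clamped to a space. Restated standalone for illustration:

```python
# standalone restatement of the two helpers above, for illustration
def decode_token(token):
    # bytes below 32 (control characters) render as a space
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return "".join(map(decode_token, tokens))

print(decode_tokens([72, 105, 9, 33]))  # → "Hi !" (the tab byte is clamped to a space)
```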

# CUDA event based timer (needed by the benchmark decorator below)
timer = partial(Event, enable_timing = True)

# benchmarking decorator
def benchmark(fn):
    @wraps(fn)
    def inner(*args, **kwargs):
        start_event = timer()
        end_event = timer()
        start_event.record()

        out = fn(*args, **kwargs)

        end_event.record()
        torch.cuda.synchronize()
        elapsed_time_ms = start_event.elapsed_time(end_event)
        return out, elapsed_time_ms
    return inner
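Note that `benchmark` relies on CUDA events and will fail on CPU. A wall-clock fallback (my own sketch, not part of the repository) could look like:

```python
import time
from functools import wraps

def benchmark_wall_clock(fn):
    # hypothetical CPU fallback: wall-clock timing in milliseconds
    @wraps(fn)
    def inner(*args, **kwargs):
        start = time.perf_counter()
        out = fn(*args, **kwargs)
        elapsed_ms = (time.perf_counter() - start) * 1000
        return out, elapsed_ms
    return inner

@benchmark_wall_clock
def slow_square(x):
    time.sleep(0.01)
    return x * x

result, elapsed_ms = slow_square(7)  # result == 49, elapsed_ms >= ~10
```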

# instantiate the transformer
device = torch.device(DEVICE_STR)
model = Decoder(
    num_tokens = 256,
    dim = 512,
    depth = 10
).to(device)

# instantiate the smaller transformer
small_model = Decoder(
    num_tokens = 256,
    dim = 512,
    depth = 2
).to(device)

# prepare enwik8 data
with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

# dataset class
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.to(device)

    def __len__(self):
        return self.data.size(0) // self.seq_len

# create the train and validation datasets
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# optimizers
optim = Adam(model.parameters(), lr = LEARNING_RATE)
small_optim = Adam(small_model.parameters(), lr = LEARNING_RATE)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model.train()
    small_model.train()

    for _ in range(GRAD_ACCUM_EVERY):
        data = next(train_loader)

        loss = model(data, return_loss = True)
        small_loss = small_model(data, return_loss = True)

        (loss / GRAD_ACCUM_EVERY).backward()
        (small_loss / GRAD_ACCUM_EVERY).backward()

    print(f"training loss: {loss.item():.3f}")
    print(f"training small loss: {small_loss.item():.3f}")

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    torch.nn.utils.clip_grad_norm_(small_model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    small_optim.step()
    small_optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            valid_data = next(val_loader)

            loss = model(valid_data, return_loss = True)
            print(f"validation loss: {loss.item():.3f}")

            small_loss = small_model(valid_data, return_loss = True)
            print(f"validation small loss: {small_loss.item():.3f}")
    # periodically generate samples
    if i % GENERATE_EVERY == 0:
        # put both models in eval mode
        model.eval()
        small_model.eval()

        # randomly pick a prime from the validation dataset
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        # decode the prime back to text
        prime = decode_tokens(inp)
        # print the prime followed by a separator
        print(f"{prime} \n\n {'*' * 100}")

        # add a batch dimension to the prompt
        prompt = inp[None, ...]

        # generate with base decoding, timing it
        sampled, base_decode_elapsed = benchmark(base_decoding)(model, prompt, GENERATE_LENGTH)

        # generate with speculative decoding, timing it and recording the number of accepted draft tokens
        (spec_decode_sampled, num_accepted), spec_decode_elapsed = benchmark(speculative_decoding)(model, small_model, prompt, GENERATE_LENGTH, GAMMA)

        # decode both outputs back to text
        base_decode_output = decode_tokens(sampled[0])
        spec_decode_output = decode_tokens(spec_decode_sampled[0])

        # print both outputs
        print("\nbase decoding:\n\n", base_decode_output, "\n")
        print("\nspec decoding:\n\n", spec_decode_output, "\n")

        # print the elapsed times and the average number of accepted draft tokens
        print(f'base decoding in: {base_decode_elapsed:.3f}ms\n')
        print(f'spec decoding in: {spec_decode_elapsed:.3f}ms\n')
        print(f'average num accepted: {num_accepted:.1f} / {GAMMA}\n')

.\lucidrains\speculative-decoding\train_early_exit.py

# import the required libraries
import gzip
import random
import tqdm
import numpy as np
import time
from functools import wraps, partial

import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.cuda import synchronize, Event
from torch.utils.data import DataLoader, Dataset

# CUDA event based timer
timer = partial(Event, enable_timing = True)

# import from the speculative_decoding package
from speculative_decoding import (
    Decoder,
    base_decoding,
    speculative_decoding_with_same_model
)

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRAD_ACCUM_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
PRIME_LENGTH = 128
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 512
GAMMA = 5
EARLY_EXIT_LOSS_WEIGHT = 1.

DEVICE_STR = 'cuda' if torch.cuda.is_available() else 'cpu'

# infinitely cycle a dataloader
def cycle(loader):
    while True:
        for data in loader:
            yield data

# decode a single token back to a character
def decode_token(token):
    return str(chr(max(32, token)))

# decode a sequence of tokens back to text
def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))

# benchmarking decorator
def benchmark(fn):
    @wraps(fn)
    def inner(*args, **kwargs):
        start_event = timer()
        end_event = timer()
        start_event.record()

        out = fn(*args, **kwargs)

        end_event.record()
        torch.cuda.synchronize()
        elapsed_time_ms = start_event.elapsed_time(end_event)
        return out, elapsed_time_ms
    return inner

# instantiate the transformer
device = torch.device(DEVICE_STR)

model = Decoder(
    num_tokens = 256,
    dim = 512,
    depth = 10,
    early_exit_layer = 2   # use the same model as the small approximate model; consider caching the early layer hidden states later
).to(device)

# prepare enwik8 data
with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

# dataset class
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.to(device)

    def __len__(self):
        return self.data.size(0) // self.seq_len

# create the train and validation datasets
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# optimizer
optim = Adam(model.parameters(), lr = LEARNING_RATE)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model.train()

    for _ in range(GRAD_ACCUM_EVERY):
        data = next(train_loader)

        loss, small_loss = model(data, return_loss = True)

        ((loss + small_loss * EARLY_EXIT_LOSS_WEIGHT) / GRAD_ACCUM_EVERY).backward()

    print(f"training loss: {loss.item():.3f}")
    print(f"training small loss: {small_loss.item():.3f}")

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            valid_data = next(val_loader)

            loss, small_loss = model(valid_data, return_loss = True)
            print(f"validation loss: {loss.item():.3f}")
            print(f"validation small loss: {small_loss.item():.3f}")
    # periodically generate samples
    if i % GENERATE_EVERY == 0:
        # put the model in eval mode
        model.eval()

        # randomly pick a prime from the validation dataset
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        # decode the prime back to text
        prime = decode_tokens(inp)
        # print the prime followed by a separator
        print(f"{prime} \n\n {'*' * 100}")

        # add a batch dimension to the prompt
        prompt = inp[None, ...]

        # generate with base decoding, timing it
        sampled, base_decode_elapsed = benchmark(base_decoding)(model, prompt, GENERATE_LENGTH)

        # generate with speculative decoding using the same model's early exit head, timing it
        (spec_decode_sampled, num_accepted), spec_decode_elapsed = benchmark(speculative_decoding_with_same_model)(model, prompt, GENERATE_LENGTH, GAMMA)

        # decode both outputs back to text
        base_decode_output = decode_tokens(sampled[0])
        spec_decode_output = decode_tokens(spec_decode_sampled[0])

        # print both outputs
        print("\nbase decoding:\n\n", base_decode_output, "\n")
        print("\nspec decoding:\n\n", spec_decode_output, "\n")

        # print the elapsed times and the average number of accepted draft tokens
        print(f'base decoding in: {base_decode_elapsed:.3f}ms\n')
        print(f'spec decoding in: {spec_decode_elapsed:.3f}ms\n')
        print(f'average num accepted: {num_accepted:.1f} / {GAMMA}\n')

.\lucidrains\speculative-decoding\train_prophet.py

# import the required libraries
import gzip
import random
import tqdm
import numpy as np
import time
from functools import wraps, partial
import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.cuda import synchronize, Event
from torch.utils.data import DataLoader, Dataset

# CUDA event based timer
timer = partial(Event, enable_timing = True)

# import from the speculative_decoding package
from speculative_decoding.speculative_decoding_with_prophet import (
    Decoder,
    ModelWithProphetWrapper,
    base_decoding,
    speculative_decoding_with_prophet_model
)

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRAD_ACCUM_EVERY = 4
LEARNING_RATE = 1e-4
PRIME_LENGTH = 128
GENERATE_EVERY = 100
GENERATE_LENGTH = 512
SEQ_LEN = 512
GAMMA = 5
TRAIN_PROPHET = True

DEVICE_STR = 'cuda' if torch.cuda.is_available() else 'cpu'

# helper functions

# infinitely cycle a dataloader
def cycle(loader):
    while True:
        for data in loader:
            yield data

# decode a single token back to a character
def decode_token(token):
    return str(chr(max(32, token)))

# decode a sequence of tokens back to text
def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))

# benchmarking decorator
def benchmark(fn):
    @wraps(fn)
    def inner(*args, **kwargs):
        start_event = timer()
        end_event = timer()
        start_event.record()

        out = fn(*args, **kwargs)

        end_event.record()
        torch.cuda.synchronize()
        elapsed_time_ms = start_event.elapsed_time(end_event)
        return out, elapsed_time_ms
    return inner

# instantiate the transformer and the prophet

device = torch.device(DEVICE_STR)

model = Decoder(
    num_tokens = 256,
    dim = 512,
    depth = 10
)

prophet = Decoder(
    num_tokens = 256,
    dim = 512,
    depth = 2
)

model_and_prophet = ModelWithProphetWrapper(
    model,
    prophet,
    prophet_train_length = GAMMA + 2,
    num_leading_start_tokens = 2,
    detach_model_embed_for_prophet = False
).to(device)

# prepare enwik8 data

with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

# dataset class
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.to(device)

    def __len__(self):
        return self.data.size(0) // self.seq_len

# create the train and validation datasets
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))

# choose which parameters to optimize
params = model_and_prophet.parameters() if TRAIN_PROPHET else model.parameters()

# optimizer
optim = Adam(params, lr = LEARNING_RATE)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model_and_prophet.train()

    for _ in range(GRAD_ACCUM_EVERY):
        data = next(train_loader)

        total_loss, (loss, prophet_loss) = model_and_prophet(data)

        (total_loss / GRAD_ACCUM_EVERY).backward()

    print(f"training loss: {loss.item():.3f}")
    print(f"training prophet loss: {prophet_loss.item():.3f}")

    torch.nn.utils.clip_grad_norm_(model_and_prophet.parameters(), 0.5)

    optim.step()
    optim.zero_grad()
    # periodically generate samples
    if i % GENERATE_EVERY == 0:
        # put the model and prophet in eval mode
        model_and_prophet.eval()

        # randomly pick a prime from the validation dataset
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        # decode the prime back to text
        prime = decode_tokens(inp)
        # print the prime followed by a separator
        print(f"{prime} \n\n {'*' * 100}")

        # add a batch dimension to the prompt
        prompt = inp[None, ...]

        # generate with base decoding, timing it
        sampled, base_decode_elapsed = benchmark(base_decoding)(model, prompt, GENERATE_LENGTH)

        # generate with speculative decoding using the prophet model, timing it
        (spec_decode_sampled, num_accepted), spec_decode_elapsed = benchmark(speculative_decoding_with_prophet_model)(model_and_prophet, prompt, GENERATE_LENGTH, GAMMA)

        # decode both outputs back to text
        base_decode_output = decode_tokens(sampled[0])
        spec_decode_output = decode_tokens(spec_decode_sampled[0])

        # print both outputs
        print("\nbase decoding:\n\n", base_decode_output, "\n")
        print("\nspec decoding:\n\n", spec_decode_output, "\n")

        # print the elapsed times and the average number of accepted draft tokens
        print(f'base decoding in: {base_decode_elapsed:.3f}ms\n')
        print(f'spec decoding in: {spec_decode_elapsed:.3f}ms\n')
        print(f'average num accepted: {num_accepted:.1f} / {GAMMA}\n')

.\lucidrains\st-moe-pytorch\assert.py

# import the required libraries
import os
from copy import deepcopy
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from st_moe_pytorch.st_moe_pytorch import Experts, Expert
from st_moe_pytorch.distributed import all_gather_variable_dim

# initialize the distributed training environment
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

# destroy the process group on cleanup
def cleanup():
    dist.destroy_process_group()

# worker run on each process
def start(
    rank,
    world_size,
    batch_size,
    batch_size_var_len,
    num_experts,
    tokens_per_expert,
    dim,
    use_cuda
):
    # initialize the distributed environment
    setup(rank, world_size)

    # create the experts
    net = Experts([Expert(dim) for _ in range(num_experts)])

    # optionally vary the batch size per rank
    if batch_size_var_len:
        batch_size = batch_size + rank

    # random input sequence
    seq = torch.randn(batch_size, num_experts, tokens_per_expert, dim)

    # local compute

    # deep copy the experts for the local reference run
    local_net = deepcopy(net)

    # gather the inputs from all ranks
    local_inputs, _ = all_gather_variable_dim(seq)

    # forward through the local network
    local_out = local_net(
        local_inputs,
        is_distributed=False
    )

    # mean of the local output, then backward
    local_out.mean().backward()

    # distributed compute

    # wrap the experts in distributed data parallel
    model = DDP(net)
    ddp_inputs = seq

    # if using cuda, move the model and inputs to the corresponding device
    if use_cuda:
        model.cuda(rank)
        ddp_inputs = seq.cuda(rank)

    # forward through the distributed model, then backward
    out = model(ddp_inputs)
    out.mean().backward()

    # gather the outputs from all ranks
    ddp_all_out, _ = all_gather_variable_dim(out)

    if rank == 0:
        # validate that local and distributed outputs are the same

        # move the model and gathered outputs back to cpu
        model.cpu()
        ddp_all_out = ddp_all_out.cpu()

        assert torch.allclose(local_out, ddp_all_out, atol=1e-3), 'output is not the same'

        # validate that the gradients of the first expert are the same between local and distributed

        get_first_expert_grad = lambda t: t.experts[0].net[0].weight.grad

        assert torch.allclose(
            get_first_expert_grad(net).cpu(),
            get_first_expert_grad(local_net),
            atol=1e-2
        ), 'grad is not the same'

        print('✅ outputs and gradients are same between local and ddp')

    # tear down the distributed environment
    cleanup()

# main entrypoint
if __name__ == '__main__':
    # experiment settings
    world_size = 8
    num_experts = 3
    batch_size = 2
    batch_size_var_len = True
    use_cuda = False

    # if using cuda, the device count must not exceed the world size
    assert not use_cuda or torch.cuda.device_count() <= world_size

    seq_len = 32
    dim = 8

    # launch distributed training with multiprocessing
    mp.spawn(
        start,
        args=(
            world_size,
            batch_size,
            batch_size_var_len,
            num_experts,
            seq_len,
            dim,
            use_cuda
        ),
        nprocs=world_size,
        join=True
    )

ST-MoE - Pytorch

Implementation of ST-MoE, the latest incarnation of mixture of experts after years of research at Brain, in Pytorch. Will be largely a transcription of the official Mesh Tensorflow implementation. If you have any papers you think should be added, while I have my attention on mixture of experts, please open an issue.

This should be SOTA for mixture-of-experts for autoregressive transformers. It is rumored that GPT4 is using 16 experts with top2 gating.

For non-autoregressive, would recommend going with the simpler and better Soft MoE.

Install

$ pip install st-moe-pytorch

Appreciation

  • StabilityAI for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.

  • Aran Komatsuzaki for consultation on mixture-of-experts, for removal of 2-level MoE and simplifications to code

Usage

import torch
from st_moe_pytorch import MoE

moe = MoE(
    dim = 512,
    num_experts = 16,               # increase the experts (# parameters) of your model without increasing computation
    gating_top_n = 2,               # default to top 2 gating, but can also be more (3 was tested in the paper with a lower threshold)
    threshold_train = 0.2,          # at what threshold to accept a token to be routed to second expert and beyond - 0.2 was optimal for 2 expert routing, and apparently should be lower for 3
    threshold_eval = 0.2,
    capacity_factor_train = 1.25,   # experts have fixed capacity per batch. we need some extra capacity in case gating is not perfectly balanced.
    capacity_factor_eval = 2.,      # capacity_factor_* should be set to a value >=1
    balance_loss_coef = 1e-2,       # multiplier on the auxiliary expert balancing auxiliary loss
    router_z_loss_coef = 1e-3,      # loss weight for router z-loss
)

inputs = torch.randn(4, 1024, 512)
out, total_aux_loss, balance_loss, router_z_loss = moe(inputs) # (4, 1024, 512), (1,), (1,), (1,)

# for the entire mixture of experts block, in context of transformer

from st_moe_pytorch import SparseMoEBlock

moe_block = SparseMoEBlock(
    moe,
    add_ff_before = True,
    add_ff_after = True
)

out, total_aux_loss, balance_loss, router_z_loss = moe_block(inputs) # (4, 1024, 512), (1,) (1,), (1,)

# the total auxiliary loss will need to be summed and then added to the main loss

# the other two losses are the unweighted breakdown for logging purposes
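For intuition on what top-n gating with a threshold does, the following is a pure-Python illustration of the concept (my own sketch, not the library's implementation): a token's router logits are softmaxed over the experts, the token is always routed to its top expert, and it reaches its second expert only when that expert's gate probability clears the threshold.

```python
import math

def top2_gate(logits, threshold = 0.2):
    # softmax over one token's expert logits
    exps = [math.exp(l - max(logits)) for l in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # rank the experts by gate probability
    ranked = sorted(range(len(probs)), key = lambda i: probs[i], reverse = True)
    top1, top2 = ranked[0], ranked[1]

    # always route to the top expert; to the second only above the threshold
    routes = [(top1, probs[top1])]
    if probs[top2] >= threshold:
        routes.append((top2, probs[top2]))
    return routes

# expert 1's gate probability clears the 0.2 threshold, so the token is routed twice
routes = top2_gate([2.0, 1.5, -1.0, 0.0])
```

Raising the threshold (or extending `ranked` to top-3 with a lower threshold, as tested in the paper) trades compute for routing coverage.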

Todo

  • add the router z-loss proposed in paper

  • add the geglu expert with multiplicative gating

  • add an entire sparse moe block, complete with rmsnorm + residual as well as the ability to specify a feedforward before or after for stability

  • double check equation for router z-loss for experts inner in hierarchical moe

  • redo all the transcribed code from google with einops, as it is not very clear

  • consult some MoE experts in the open source community; question why hierarchical MoE is needed, in light of results from soft-MoE

  • offer top-n gating generalization, as it seems top3 (with smaller threshold) can work even better

  • figure out if there was an error in a previous transcription - no there was not an error

  • allow for different thresholds for second vs third routed expert

  • add coordinate descent based routing

  • make first naive non-optimized attempt at distributed code for mixture of experts

  • distributed

    • handle any world size less than number of experts
    • handle any world size greater than number of experts - for now, just have remainder machines do nothing
    • support variable batch sizes
    • support variable seq lengths
    • figure out how to move assert.py to pytests
    • simplify the variable sequence length test code from another folder and move in so other researchers gain confidence
    • optimize
    • figure out what is faster, all gather, or broadcast with async followed by barrier
    • make all distributed code pluggable, for different strategies
    • figure out why there is tiny error in gradients
  • improvise a Top2GatingWithCoordinateDescent for MoE without importance

Citations

@inproceedings{Zoph2022STMoEDS,
    title   = {ST-MoE: Designing Stable and Transferable Sparse Expert Models},
    author  = {Barret Zoph and Irwan Bello and Sameer Kumar and Nan Du and Yanping Huang and Jeff Dean and Noam M. Shazeer and William Fedus},
    year    = {2022}
}

.\lucidrains\st-moe-pytorch\setup.py

# import the setup and find_packages functions
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'st-moe-pytorch',                               # package name
  packages = find_packages(exclude=[]),                  # find all packages
  version = '0.1.7',                                     # version
  license='MIT',                                         # license
  description = 'ST - Mixture of Experts - Pytorch',     # description
  author = 'Phil Wang',                                  # author
  author_email = 'lucidrains@gmail.com',                 # author email
  long_description_content_type = 'text/markdown',       # long description content type
  url = 'https://github.com/lucidrains/st-moe-pytorch',  # repository url
  keywords = [
    'artificial intelligence',
    'deep learning',
    'mixture of experts'
  ],
  install_requires=[
    'beartype',
    'CoLT5-attention>=0.10.15',
    'einops>=0.6',
    'torch>=2.0',
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\st-moe-pytorch\st_moe_pytorch\distributed.py

# imports

import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Function

import torch.distributed as dist

from einops import rearrange, pack, unpack

# check whether a value exists
def exists(val):
    return val is not None

# return a default if the value does not exist
def default(val, d):
    return val if exists(val) else d

# check whether num is divisible by den
def divisible_by(num, den):
    return (num % den) == 0

# pad a tensor along the given dimension up to the target length
def pad_dim_to(t, length, dim = 0):
    pad_length = length - t.shape[dim]
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))

# all gather tensors that have the same shape on every rank
def all_gather_same_dim(t):
    t = t.contiguous()
    world_size = dist.get_world_size()
    gathered_tensors = [torch.empty_like(t, device = t.device, dtype = t.dtype) for i in range(world_size)]
    dist.all_gather(gathered_tensors, t)
    return gathered_tensors

# gather the size of the given dimension from every rank
def gather_sizes(t, *, dim):
    size = torch.tensor(t.shape[dim], device = t.device, dtype = torch.long)
    sizes = all_gather_same_dim(size)
    return torch.stack(sizes)

# check whether a tensor holds only a single unique value
def has_only_one_value(t):
    return (t == t[0]).all()

# all gather tensors whose gathered dimension may vary across ranks
def all_gather_variable_dim(t, dim = 0, sizes = None):
    device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()

    if not exists(sizes):
        sizes = gather_sizes(t, dim = dim)

    if has_only_one_value(sizes):
        gathered_tensors = all_gather_same_dim(t)
        gathered_tensors = torch.cat(gathered_tensors, dim = dim)
        return gathered_tensors, sizes

    max_size = sizes.amax().item()

    padded_t = pad_dim_to(t, max_size, dim = dim)
    gathered_tensors = all_gather_same_dim(padded_t)

    gathered_tensors = torch.cat(gathered_tensors, dim = dim)
    seq = torch.arange(max_size, device = device)

    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
    mask = rearrange(mask, 'i j -> (i j)')
    seq = torch.arange(mask.shape[-1], device = device)
    indices = seq[mask]

    gathered_tensors = gathered_tensors.index_select(dim, indices)

    return gathered_tensors, sizes
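The pad-gather-unpad trick in `all_gather_variable_dim` can be illustrated without `torch.distributed` (a pure-Python sketch with hypothetical names): every rank pads its tensor to the longest size, the padded tensors are concatenated, and a mask built from the true sizes strips the padding back out.

```python
def gather_variable(per_rank, pad_value = 0):
    # per_rank: one variable-length sequence per "rank"
    sizes = [len(seq) for seq in per_rank]
    max_size = max(sizes)

    # pad every sequence to the longest size, since all_gather
    # requires identical shapes on every rank
    padded = [list(seq) + [pad_value] * (max_size - len(seq)) for seq in per_rank]

    # concatenate, then keep only positions within each rank's true size
    flat = [x for row in padded for x in row]
    mask = [j < size for size in sizes for j in range(max_size)]
    return [x for x, keep in zip(flat, mask) if keep], sizes

gathered, sizes = gather_variable([[1, 2, 3], [4], [5, 6]])
# gathered == [1, 2, 3, 4, 5, 6], sizes == [3, 1, 2]
```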

# AllGatherFunction, an autograd Function wrapping the variable-dim all gather
class AllGatherFunction(Function):
    @staticmethod
    def forward(ctx, x, dim, sizes):
        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
        ctx.batch_sizes = batch_sizes.tolist()
        ctx.dim = dim
        return x, batch_sizes

    @staticmethod
    def backward(ctx, grads, _):
        batch_sizes, rank = ctx.batch_sizes, dist.get_rank()
        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
        return grads_by_rank[rank], None, None

# AllGather module
class AllGather(nn.Module):
    def __init__(self, *, dim = 0):
        super().__init__()
        self.dim = dim

    def forward(self, x, sizes = None):
        return AllGatherFunction.apply(x, self.dim, sizes)

# split a gathered tensor back out by rank
def split_by_rank(x):
    rank = dist.get_rank()
    out = x[rank]

    if isinstance(x, tuple):
        sizes = tuple(map(lambda t: t.shape[0], x))
    else:
        sizes = (x.shape[1],) * x.shape[0]

    sizes = torch.tensor(sizes, device = out.device, dtype = torch.long)
    return out, sizes