Lucidrains Series Projects Source Code Analysis (61)

.\lucidrains\metnet3-pytorch\metnet3_pytorch\__init__.py

# import the MetNet3 class from the metnet3_pytorch package
from metnet3_pytorch.metnet3_pytorch import (
    MetNet3
)

MetNet-3 - Pytorch

Implementation of MetNet 3, SOTA neural weather model out of Google Deepmind, in Pytorch

The model architecture is pretty unremarkable. It is basically a U-net with a specific well-performing vision transformer. The most interesting thing about the paper may end up being the loss scaling in section 4.3.2

Appreciation

Install

$ pip install metnet3-pytorch

Usage

import torch
from metnet3_pytorch import MetNet3

metnet3 = MetNet3(
    dim = 512,
    num_lead_times = 722,
    lead_time_embed_dim = 32,
    input_spatial_size = 624,
    attn_dim_head = 8,
    hrrr_channels = 617,
    input_2496_channels = 2 + 14 + 1 + 2 + 20,
    input_4996_channels = 16 + 1,
    precipitation_target_bins = dict(
        mrms_rate = 512,
        mrms_accumulation = 512,
    ),
    surface_target_bins = dict(
        omo_temperature = 256,
        omo_dew_point = 256,
        omo_wind_speed = 256,
        omo_wind_component_x = 256,
        omo_wind_component_y = 256,
        omo_wind_direction = 180
    ),
    hrrr_loss_weight = 10,
    hrrr_norm_strategy = 'sync_batchnorm',  # this would use a sync batchnorm to normalize the input hrrr and target, without having to precalculate the mean and variance of the hrrr dataset per channel
    hrrr_norm_statistics = None             # you can also set `hrrr_norm_strategy = "precalculated"` and pass in the per-channel mean and variance as a tensor of shape `(2, 617)` through this keyword argument
)

# inputs

lead_times = torch.randint(0, 722, (2,))
hrrr_input_2496 = torch.randn((2, 617, 624, 624))
hrrr_stale_state = torch.randn((2, 1, 624, 624))
input_2496 = torch.randn((2, 39, 624, 624))
input_4996 = torch.randn((2, 17, 624, 624))

# targets

precipitation_targets = dict(
    mrms_rate = torch.randint(0, 512, (2, 512, 512)),
    mrms_accumulation = torch.randint(0, 512, (2, 512, 512)),
)

surface_targets = dict(
    omo_temperature = torch.randint(0, 256, (2, 128, 128)),
    omo_dew_point = torch.randint(0, 256, (2, 128, 128)),
    omo_wind_speed = torch.randint(0, 256, (2, 128, 128)),
    omo_wind_component_x = torch.randint(0, 256, (2, 128, 128)),
    omo_wind_component_y = torch.randint(0, 256, (2, 128, 128)),
    omo_wind_direction = torch.randint(0, 180, (2, 128, 128))
)

hrrr_target = torch.randn(2, 617, 128, 128)

total_loss, loss_breakdown = metnet3(
    lead_times = lead_times,
    hrrr_input_2496 = hrrr_input_2496,
    hrrr_stale_state = hrrr_stale_state,
    input_2496 = input_2496,
    input_4996 = input_4996,
    precipitation_targets = precipitation_targets,
    surface_targets = surface_targets,
    hrrr_target = hrrr_target
)

total_loss.backward()

# after much training from above, you can predict as follows

metnet3.eval()

surface_preds, hrrr_pred, precipitation_preds = metnet3(
    lead_times = lead_times,
    hrrr_input_2496 = hrrr_input_2496,
    hrrr_stale_state = hrrr_stale_state,
    input_2496 = input_2496,
    input_4996 = input_4996,
)


# Dict[str, Tensor], Tensor, Dict[str, Tensor]
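If you go with `hrrr_norm_strategy = 'precalculated'`, you first need the per-channel mean and variance of the HRRR data. Below is a minimal sketch of precomputing them, assuming a hypothetical `hrrr_dataset` iterable that yields tensors of shape `(617, 624, 624)` (the helper name is illustrative, not part of the library):

import torch

def compute_hrrr_norm_statistics(hrrr_dataset):
    # accumulate per-channel sums over all samples and spatial positions
    total = torch.zeros(617)
    total_sq = torch.zeros(617)
    count = 0

    for hrrr in hrrr_dataset:
        total += hrrr.sum(dim = (-1, -2))
        total_sq += (hrrr ** 2).sum(dim = (-1, -2))
        count += hrrr.shape[-1] * hrrr.shape[-2]

    mean = total / count
    var = total_sq / count - mean ** 2
    return torch.stack((mean, var))  # (2, 617), to pass as `hrrr_norm_statistics`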

Todo

  • figure out all the cross entropy and MSE losses

  • auto-handle normalization across all the channels of the HRRR by tracking a running mean and variance of targets during training (using sync batchnorm as a hack)

  • allow researcher to pass in their own normalization variables for HRRR

  • build all the inputs to spec, also make sure hrrr input is normalized, offer option to unnormalize hrrr predictions

  • make sure model can be easily saved and loaded, with different ways of handling hrrr norm

  • figure out the topological embedding, consult a neural weather researcher

Citations

@article{Andrychowicz2023DeepLF,
    title   = {Deep Learning for Day Forecasts from Sparse Observations},
    author  = {Marcin Andrychowicz and Lasse Espeholt and Di Li and Samier Merchant and Alexander Merose and Fred Zyda and Shreya Agrawal and Nal Kalchbrenner},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2306.06079},
    url     = {https://api.semanticscholar.org/CorpusID:259129311}
}
@inproceedings{ElNouby2021XCiTCI,
    title   = {XCiT: Cross-Covariance Image Transformers},
    author  = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou},
    booktitle = {Neural Information Processing Systems},
    year    = {2021},
    url     = {https://api.semanticscholar.org/CorpusID:235458262}
}

.\lucidrains\metnet3-pytorch\setup.py

# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'metnet3-pytorch',  # package name
  packages = find_packages(exclude=[]),  # find and include all packages
  version = '0.0.12',  # version number
  license='MIT',  # license
  description = 'MetNet 3 - Pytorch',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  long_description_content_type = 'text/markdown',  # long description content type
  url = 'https://github.com/lucidrains/metnet3-pytorch',  # project url
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'vision transformers',
    'unet',
    'weather forecasting'
  ],
  install_requires=[  # dependencies
    'beartype',
    'einops>=0.7.0',
    'torch>=2.0',
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\mirasol-pytorch\mirasol_pytorch\distributed.py

# import necessary libraries
from functools import cache

import torch
import torch.nn.functional as F
from torch.autograd import Function
import torch.distributed as distributed

from einops import rearrange

# helper functions

# check whether a value exists
def exists(v):
    return v is not None

# cached check for whether we are running in a distributed environment
@cache
def get_is_distributed():
    return distributed.is_initialized() and distributed.get_world_size() > 1

# pad a tensor along the given dimension up to the given length
def pad_dim_to(t, length, dim = 0):
    pad_length = length - t.shape[dim]
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))

# distributed helpers

# all-gather tensors whose size may vary along one dimension across processes, given the dimension and optional per-rank sizes
def all_gather_variable_dim(t, dim = 0, sizes = None):
    device, rank, world_size = t.device, distributed.get_rank(), distributed.get_world_size()

    if not exists(sizes):
        size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
        sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)]
        distributed.all_gather(sizes, size)
        sizes = torch.stack(sizes)

    max_size = sizes.amax().item()
    padded_t = pad_dim_to(t, max_size, dim = dim)

    gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for i in range(world_size)]
    distributed.all_gather(gathered_tensors, padded_t)

    gathered_tensor = torch.cat(gathered_tensors, dim = dim)
    seq = torch.arange(max_size, device = device)

    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
    mask = rearrange(mask, 'i j -> (i j)')
    seq = torch.arange(mask.shape[-1], device = device)
    indices = seq[mask]

    gathered_tensor = gathered_tensor.index_select(dim, indices)

    return gathered_tensor, sizes
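# illustrative sketch (assumed, not part of the original file) of exercising
# all_gather_variable_dim under e.g. `torchrun --nproc_per_node=2`, where the
# ranks hold different numbers of rows:
#
#     distributed.init_process_group('gloo')
#     rank = distributed.get_rank()
#     t = torch.randn(rank + 1, 8)                        # rank 0 holds 1 row, rank 1 holds 2
#     gathered, sizes = all_gather_variable_dim(t, dim = 0)
#     gathered.shape                                      # (3, 8) on every rank
#     sizes                                               # tensor([1, 2])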

# custom autograd Function implementing a differentiable all_gather
class AllGather(Function):
    @staticmethod
    def forward(ctx, x, dim, sizes):
        assert get_is_distributed()
        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
        ctx.batch_sizes = batch_sizes.tolist()
        ctx.dim = dim
        return x, batch_sizes

    @staticmethod
    def backward(ctx, grads, _):
        batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
        return grads_by_rank[rank], None, None

# expose the custom Function as an all_gather function
all_gather = AllGather.apply

.\lucidrains\mirasol-pytorch\mirasol_pytorch\mirasol_pytorch.py

# import required modules and functions
import operator
from functools import partial
from collections import namedtuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn, einsum
from torch.nn import Module, ModuleList

# import beartype and the types it needs
from beartype import beartype
from beartype.typing import Optional, Union, Tuple, Dict, Any

# import einops functions and layers
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange

# import x_transformers modules and classes
from x_transformers import (
    Encoder,
    Decoder,
    TransformerWrapper,
    AutoregressiveWrapper
)

# import the RotaryEmbedding class from x_transformers
from x_transformers.x_transformers import RotaryEmbedding

# import distributed helpers from mirasol_pytorch
from mirasol_pytorch.distributed import all_gather, get_is_distributed

# helper functions

# check whether a value exists
def exists(v):
    return v is not None

# return the first argument that exists
def default(*args):
    for arg in args:
        if exists(arg):
            return arg
    return None

# check whether a number is divisible by another
def divisible_by(num, den):
    return (num % den) == 0

# check that exactly one of the booleans is True
def only_one_true(*bools):
    return sum(map(int, bools)) == 1

# pack a tensor according to the given pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# unpack a packed tensor according to the given pattern
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# tensor helpers

# l2 normalize along the last dimension
def l2norm(t):
    return F.normalize(t, dim = -1)

# cosine similarity loss between two sets of embeddings
def cosine_sim_loss(x, y):
    x, y = map(l2norm, (x, y))
    return 1. - einsum('b n d, b n d -> b n', x, y).mean()

# n-dimensional sinusoidal positional embeddings
def posemb_sincos_nd(
    t: Tensor,
    temperature: int = 10000,
    dtype = torch.float32
):
    b, *dims, feat_dim, device = *t.shape, t.device
    seq_len = torch.tensor(dims).cumprod(dim = -1)[-1].item()

    arange = partial(torch.arange, device = device)

    num_dims = len(dims)
    two_times_num_dims = 2 * num_dims # 2 because sin and cos of same position

    rounded_feat_dim = feat_dim // num_dims * num_dims
    feat_dim_remainder = feat_dim % num_dims

    omega = arange(rounded_feat_dim // two_times_num_dims) / (rounded_feat_dim // two_times_num_dims - 1)
    omega = 1.0 / (temperature ** omega)
    meshed = torch.meshgrid(*[*map(arange, dims)], indexing = 'ij')

    pos = torch.cat(tuple(m.flatten()[..., None] for m in meshed), dim = 0)
    pos = pos * omega[None, :]

    pos = torch.cat((pos.sin(), pos.cos()))

    pos = rearrange(pos, '(n f) d -> n (f d)', n = seq_len)
    pos = pos.type(dtype)

    return F.pad(pos, (0, feat_dim_remainder))
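# illustrative example (not part of the original file): a 2d feature map of
# shape (batch, 8, 8, 512) yields one sin/cos embedding per spatial position
#
#     >>> feats = torch.randn(2, 8, 8, 512)
#     >>> posemb_sincos_nd(feats).shape
#     torch.Size([64, 512])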

# generate a boolean mask where a given fraction of positions is masked out (False)
def mask_with_prob(
    shape: Tuple[int, ...],
    prob: float,
    device = None
) -> Tensor:
    length = shape[-1]
    num_mask = int(prob * length)
    randperm = torch.randn(shape, device = device).argsort(dim = -1)
    return randperm >= num_mask
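# illustrative example (not part of the original file): prob = 0.15 over a
# length-20 sequence masks out int(0.15 * 20) = 3 positions (False) per row
#
#     >>> mask = mask_with_prob((2, 20), prob = 0.15)
#     >>> mask.sum(dim = -1)
#     tensor([17, 17])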

# main class

# Losses namedtuple holding the different loss terms
Losses = namedtuple('Losses', [
    'text_autoregressive',
    'av_autoregressive',
    'av_recon',
    'text_av_sim_reg'
])

# Mirasol module
class Mirasol(Module):

    @beartype
    # initialization - all model hyperparameters
    def __init__(
        self,
        *,
        dim,
        num_text_tokens,
        video_image_size,
        video_frames_per_timechunk,
        audio_freq_dim,
        audio_time_dim_per_timechunk,
        audio_patch_size: Tuple[int, int],                          # audio patch size (freq, time)
        video_patch_size: Tuple[int, int],                          # video patch size (spatial, time)
        video_recon_patch_size: Optional[Tuple[int, int]] = None,   # video patch size (spatial, time) of the smaller video used for the reconstruction loss
        video_recon_interpolate_mode = 'nearest',
        audio_encoder: Union[Module, Dict[str, Any]],
        video_encoder: Union[Module, Dict[str, Any]],
        num_audio_video_register_tokens = 8,                        # number of audio / video register tokens https://arxiv.org/abs/2309.16588
        audio_video_mask_prob = 0.15,                               # in the paper they used masked tokens, but per the Berkeley forgetful-causal-masking paper, a simple key / value mask should suffice
        text_max_seq_len = 2048,
        text_forgetful_causal_mask_prob = 0.1,                      # https://arxiv.org/abs/2210.13432
        encoder_depth = 6,
        decoder_depth = 6,
        combiner_depth = 2,
        combiner_output_num_tokens = 3,
        video_channels = 3,
        attn_dim_head = 64,
        attn_heads = 8,
        flash_attn = True,
        attn_layers_kwargs: dict = dict(),
        combiner: Optional[Module] = None,
        combiner_kwargs: dict = dict(),
        autoregressive_wrapper_kwargs: dict = dict(
            pad_value = 0,
            ignore_index = -100
        ),
        av_autoregressive_loss_weight = 1.,
        av_reconstruction_loss_weight = 1.,
        sim_reg_loss_weight = 0.
    ):
        ...  # __init__ body omitted in this excerpt
    # device the parameters live on
    @property
    def device(self):
        return next(self.parameters()).device

    # generation - autoregressively sample a text sequence
    @torch.no_grad()
    def generate(
        self,
        *,
        seq_len: int,
        prompt: Optional[Tensor] = None,
        **kwargs
    ):
        was_training = self.training
        self.eval()

        assert 'generate' not in kwargs
        assert 'generate_seq_len' not in kwargs

        # call forward in generation mode
        out = self.forward(
            text = prompt,
            generate = True,
            generate_seq_len = seq_len,
            **kwargs
        )

        self.train(was_training)
        return out

    # forward pass
    @beartype
    def forward(
        self,
        *,
        audio: Optional[Tensor] = None,
        video: Optional[Tensor] = None,
        encoded_audio: Optional[Tensor] = None,
        encoded_video: Optional[Tensor] = None,
        text: Optional[Tensor] = None,
        text_mask: Optional[Tensor] = None,
        return_loss = True,
        return_loss_breakdown = False,
        generate = False,
        generate_seq_len = None

.\lucidrains\mirasol-pytorch\mirasol_pytorch\__init__.py

# import the Mirasol class from the mirasol_pytorch package
from mirasol_pytorch.mirasol_pytorch import Mirasol

🌻 Mirasol - Pytorch

Implementation of Mirasol, SOTA Multimodal Autoregressive model out of Google Deepmind, in Pytorch

Will simply implement the Transformer Combiner and omit the other variants.

Appreciation

Install

$ pip install mirasol-pytorch

Usage

import torch
from mirasol_pytorch import Mirasol

model = Mirasol(
    dim = 512,
    num_text_tokens = 256,
    video_image_size = 128,
    video_frames_per_timechunk = 2,
    audio_freq_dim = 64,
    audio_time_dim_per_timechunk = 32,
    audio_patch_size = (32, 16),
    video_patch_size = (64, 2),
    audio_encoder = dict(
        dim = 512,
        depth = 2
    ),
    video_encoder = dict(
        dim = 512,
        depth = 2
    )
)

audio = torch.randn(1, 64, 1024)
video = torch.randn(1, 3, 12, 128, 128)

text = torch.randint(0, 256, (1, 1024))

loss = model(
    audio = audio,
    video = video,
    text = text
)

loss.backward()

# after much training

sampled_text = model.generate(
    audio = audio,
    video = video,
    seq_len = 512
)

Todo

  • text generation code
  • auto-handle start token for decoder
  • positional embeddings for video and audio encoder
  • enable register tokens for both video and audio encoder, inline with new research
  • add audio and video reconstruction losses
  • add similarity regularization from TTS research

Citations

@article{Piergiovanni2023Mirasol3BAM,
    title   = {Mirasol3B: A Multimodal Autoregressive model for time-aligned and contextual modalities},
    author  = {A. J. Piergiovanni and Isaac Noble and Dahun Kim and Michael S. Ryoo and Victor Gomes and Anelia Angelova},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2311.05698},
    url     = {https://api.semanticscholar.org/CorpusID:265129010}
}
@inproceedings{Liu2022TowardsBF,
    title   = {Towards Better Few-Shot and Finetuning Performance with Forgetful Causal Language Models},
    author  = {Hao Liu and Xinyang Geng and Lisa Lee and Igor Mordatch and Sergey Levine and Sharan Narang and P. Abbeel},
    year    = {2022},
    url     = {https://api.semanticscholar.org/CorpusID:256416540}
}
@article{Darcet2023VisionTN,
    title   = {Vision Transformers Need Registers},
    author  = {Timoth'ee Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2309.16588},
    url     = {https://api.semanticscholar.org/CorpusID:263134283}
}
@article{Bondarenko2023QuantizableTR,
    title   = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
    author  = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2306.12929},
    url     = {https://api.semanticscholar.org/CorpusID:259224568}
}
@misc{shi2023enhance,
    title   = {Enhance audio generation controllability through representation similarity regularization}, 
    author  = {Yangyang Shi and Gael Le Lan and Varun Nagaraja and Zhaoheng Ni and Xinhao Mei and Ernie Chang and Forrest Iandola and Yang Liu and Vikas Chandra},
    year    = {2023},
    eprint  = {2309.08773},
    archivePrefix = {arXiv},
    primaryClass = {cs.SD}
}

.\lucidrains\mirasol-pytorch\setup.py

# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
  # package name
  name = 'mirasol-pytorch',
  # find all packages, excluding none
  packages = find_packages(exclude=[]),
  # version number
  version = '0.0.16',
  # license type
  license='MIT',
  # description
  description = 'Mirasol - Pytorch',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # long description content type
  long_description_content_type = 'text/markdown',
  # project url
  url = 'https://github.com/lucidrains/mirasol-pytorch',
  # keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'multimodality'
  ],
  # dependencies
  install_requires=[
    'beartype',
    'einops>=0.7.0',
    'x-transformers>=1.25.10',
    'torch>=2.0'
  ],
  # classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/

.\lucidrains\mixture-of-attention\mixture_of_attention\attend.py

# import necessary libraries
from collections import namedtuple
from functools import wraps
from packaging import version

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange

# EfficientAttentionConfig namedtuple holding the sdp kernel configuration
EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helper: check whether a value exists
def exists(val):
    return val is not None

# decorator: ensure a function is only called once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# wrap print so a message is only printed once
print_once = once(print)

# main Attend class
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        causal = False,
        flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu
        self.cpu_config = EfficientAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(False, True, True)

    # build the causal mask for an (i, j) grid of attention logits
    def get_mask(self, i, j, device):
        return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)

    # Flash Attention
    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        config = self.cuda_config if is_cuda else self.cpu_config

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0., 
                is_causal = self.causal
            )

        return out

    # forward pass
    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, device = q.shape[-2], q.device

        scale = q.shape[-1] ** -0.5

        if self.flash:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity
        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # key padding mask
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask
        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = self.get_mask(i, j, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out

.\lucidrains\mixture-of-attention\mixture_of_attention\autoregressive_wrapper.py

# import torch
import torch
# import the nn module
from torch import nn
# import nn.functional as F
import torch.nn.functional as F

# import rearrange from einops
from einops import rearrange

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# decorator that runs a function with the model in eval mode, then restores training mode
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        # remember whether the model was training
        was_training = model.training
        # switch to eval mode
        model.eval()
        # call the wrapped function
        out = fn(model, *args, **kwargs)
        # restore the previous training mode
        model.train(was_training)
        return out
    return inner

# top-k filtering

# keep only the top k logits, masking the rest to the most negative value
def top_k(logits, thres = 0.9):
    # number of logits to keep
    k = int((1 - thres) * logits.shape[-1])
    # top k values and their indices
    val, ind = torch.topk(logits, k)
    # tensor of the same shape filled with the most negative finite value
    probs = torch.full_like(logits, -torch.finfo(logits.dtype).max)
    # scatter the kept values back into place
    probs.scatter_(1, ind, val)
    return probs
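# illustrative example (not part of the original file): with a 256-token
# vocabulary and the default thres = 0.9, int((1 - 0.9) * 256) = 25 logits
# survive per row; the rest are masked to the most negative finite value
#
#     >>> logits = torch.randn(2, 256)
#     >>> filtered = top_k(logits)
#     >>> (filtered > -torch.finfo(filtered.dtype).max).sum(dim = -1)
#     tensor([25, 25])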

# autoregressive wrapper class
class AutoregressiveWrapper(nn.Module):
    def __init__(
        self,
        net,
        pad_value = 0
    ):
        super().__init__()
        # store attributes
        self.seq_len = net.seq_len
        self.pad_value = pad_value
        self.net = net

    # sequence generation
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        prompt,
        seq_len,
        temperature=1.0,
        filter_thres=0.9,
        **kwargs
    ):
        # prompt shape and device
        b, t, device = *prompt.shape, prompt.device

        out = prompt

        # generate tokens one at a time
        for _ in range(seq_len):
            # feed the last self.seq_len tokens through the network, take logits at the final position
            logits = self.net(out[:, -self.seq_len:], **kwargs)[:, -1]

            # top-k filter the logits
            filtered_logits = top_k(logits, thres = filter_thres)
            # temperature-scaled softmax to get probabilities
            probs = F.softmax(filtered_logits / temperature, dim = -1)

            # sample one token from the distribution
            sample = torch.multinomial(probs, 1)
            # append the sample to the running sequence
            out = torch.cat((out, sample), dim = -1)

        # strip the prompt, return only the generated tokens
        out = out[:, t:]
        return out

    # forward pass - next token prediction loss
    def forward(self, x, **kwargs):
        # inputs are all tokens but the last, labels are all tokens but the first
        x, labels = x[:, :-1], x[:, 1:]
        # run the network to get logits
        logits = self.net(x, **kwargs)
        # move the class dimension to position 1, as F.cross_entropy expects (batch, classes, seq)
        logits = rearrange(logits, "b n c -> b c n")
        # cross entropy loss
        return F.cross_entropy(logits, labels)

.\lucidrains\mixture-of-attention\mixture_of_attention\mixture_of_attention.py

# import math
import math

# import pytorch
import torch
import torch.nn.functional as F
from torch import Tensor, nn, einsum

# typing imports
from typing import Tuple, Optional

# einops functions
from einops import rearrange, repeat, reduce, pack, unpack

# local modules
from mixture_of_attention.attend import Attend
from mixture_of_attention.rotary_emb import apply_rotary_pos_emb

from local_attention import LocalMHA

from colt5_attention import CoordinateDescentRouter

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# pack a tensor according to the given pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# unpack a packed tensor back to its original shape
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# pad a tensor so its length along a dimension is a multiple of the given value
def pad_to_multiple(tensor, multiple, dim = -1, value = 0):
    seq_len = tensor.shape[dim]
    m = seq_len / multiple
    if m.is_integer():
        return tensor, seq_len

    remainder = math.ceil(m) * multiple - seq_len
    pad_offset = (0,) * (-1 - dim) * 2
    padded_tensor = F.pad(tensor, (*pad_offset, 0, remainder), value = value)
    return padded_tensor, seq_len
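# illustrative example (not part of the original file): a length-10 sequence
# is padded up to the next multiple of 4, and the original length is returned
#
#     >>> padded, seq_len = pad_to_multiple(torch.randn(1, 10), 4)
#     >>> padded.shape, seq_len
#     (torch.Size([1, 12]), 10)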

# normalization

# RMS norm module (normalizes over the channel dimension, dim = -2, since tensors here are channels-first)
class RMSNorm(nn.Module):
    def __init__(self, dim, groups = 1):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(groups, dim, 1))

    def forward(self, x):
        normed = F.normalize(x, dim = -2)
        return normed * self.scale * self.gamma

# attention

# attention module
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        dim_context = None,
        heads = 8,
        causal = False,
        groups = 1, # number of experts
        dropout = 0.,
        flash = False,
        prenorm = False
    ):
        super().__init__()
        self.heads = heads
        self.groups = groups

        dim_inner = dim_head * heads
        dim_context = default(dim_context, dim)

        self.norm = RMSNorm(dim, groups = groups) if prenorm else nn.Identity()
        self.context_norm = RMSNorm(dim_context, groups = groups) if prenorm else nn.Identity()

        self.attend = Attend(
            dropout = dropout,
            causal = causal,
            flash = flash
        )

        # null key / value, to protect against a row being entirely masked out

        self.null_kv = nn.Parameter(torch.randn(2, groups, heads, 1, dim_head))

        # take advantage of convolutional groups to process the experts in parallel

        self.to_q = nn.Conv1d(dim * groups, dim_inner * groups, 1, bias = False, groups = groups)
        self.to_kv = nn.Conv1d(dim_context * groups, dim_inner * 2 * groups, 1, bias = False, groups = groups)
        self.to_out = nn.Conv1d(dim_inner * groups, dim * groups, 1, bias = False, groups = groups)

    def forward(
        self,
        x,
        context = None,
        mask = None,
        queries_scale = None,
        keys_scale = None,
        values_scale = None,
        output_scale = None,
        rotary_emb: Optional[Tuple[Tensor, Tensor]] = None
    ):
        """
        einops
        b - batch
        g - groups
        n - sequence
        d - feature dimension
        """
        # shape information
        b, g, h = x.shape[0], self.groups, self.heads

        # whether there is only a single expert
        one_expert = x.ndim == 3

        # if single expert, expand to 4 dims
        if one_expert:
            assert g == 1
            x = rearrange(x, 'b n d -> b 1 n d')

        # input must be 4-dimensional
        assert x.ndim == 4
        # second dimension must equal the number of groups
        assert x.shape[1] == g

        # move features to channels-first, to be processed in one go by grouped convolutions
        x = rearrange(x, 'b g n d -> b g d n')

        # handle context for cross attention
        if exists(context):
            context_one_expert = context.ndim == 3

            if context_one_expert:
                assert g == 1
                context = rearrange(context, 'b n d -> b 1 n d')

            assert context.ndim == 4
            assert context.shape[1] == g

            context = rearrange(context, 'b g n d -> b g d n')

        # fall back to self attention if no context was passed in
        context = default(context, x)

        # handle the mask
        if exists(mask):
            if mask.ndim == 2:
                mask = repeat(mask, 'b n -> (b g) n', g = g)
            elif mask.ndim == 3:
                mask = rearrange(mask, 'b g n -> (b g) n')

            mask = F.pad(mask, (1, 0), value = True)

        # prenorm, if applicable
        x = self.norm(x)
        context = self.context_norm(context)

        # fold the groups into the feature dimension for the grouped convolutions
        x, context = map(lambda t: rearrange(t, 'b g d n -> b (g d) n'), (x, context))

        # queries, keys, values
        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = 1))

        # split out the heads and the groups
        q, k, v = map(lambda t: rearrange(t, 'b (g h d) n -> b g h n d', h = h, g = g), (q, k, v))

        # rotary embeddings
        if exists(rotary_emb):
            q_rotary_emb, k_rotary_emb = rotary_emb

            if q_rotary_emb.ndim > 2:
                q_rotary_emb = rearrange(q_rotary_emb, 'b g n d -> b g 1 n d')

            if k_rotary_emb.ndim > 2:
                k_rotary_emb = rearrange(k_rotary_emb, 'b g n d -> b g 1 n d')

            q = apply_rotary_pos_emb(q_rotary_emb, q)
            k = apply_rotary_pos_emb(k_rotary_emb, k)

        # scale the queries, if passed in
        if exists(queries_scale):
            q = q * queries_scale

        # scale the keys, if passed in
        if exists(keys_scale):
            k = k * keys_scale

        # scale the values, if passed in
        if exists(values_scale):
            v = v * values_scale

        # merge the groups into the batch
        q, k, v = map(lambda t: rearrange(t, 'b g ... -> (b g) ...'), (q, k, v))

        # concat the null key / values, to protect against a row that is entirely masked out
        nk, nv = map(lambda t: repeat(t, 'g h 1 d -> (b g) h 1 d', b = b), self.null_kv)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # attention
        out = self.attend(q, k, v, mask = mask)

        # merge the head outputs
        out = rearrange(out, '(b g) h n d -> b (g h d) n', g = g)

        out = self.to_out(out)

        out = rearrange(out, 'b (g d) n -> b g n d', g = g)

        # if single expert, collapse back down to 3 dims
        if one_expert:
            out = rearrange(out, 'b 1 n d -> b n d')

        # scale the output, if passed in
        if exists(output_scale):
            out = out * output_scale

        return out
# mixture of attention class
class MixtureOfAttention(nn.Module):
    # initialization
    def __init__(
        self,
        dim,
        *,
        num_routed_queries,
        num_routed_key_values,
        dim_context = None,
        local_attn = False,
        local_attn_window_size = None,
        num_experts = 2,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        use_triton = True,
        flash_attn = True,
        prenorm = True,
        average_routed = False,
        **kwargs
    ):
        super().__init__()
        dim_context = default(dim_context, dim)
        self.num_routed_queries = num_routed_queries
        self.num_routed_key_values = num_routed_key_values

        # a learned null routed token, for positions that receive no route (not needed when there is a local attention branch)
        self.null_routed_token = nn.Parameter(torch.randn(1, 1, dim)) if not local_attn else None

        self.average_routed = average_routed

        self.local_attn = None

        # if using local attention, create the local multi-head attention module
        if local_attn:
            assert exists(local_attn_window_size)
            self.local_attn = LocalMHA(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                prenorm = prenorm,
                window_size = local_attn_window_size
            )

        # router for the queries
        self.query_router = CoordinateDescentRouter(
            dim,
            num_routing_tokens = num_experts,
            use_triton = use_triton,
            **kwargs
        )

        # router for the keys / values
        self.key_value_router = CoordinateDescentRouter(
            dim_context,
            num_routing_tokens = num_experts,
            use_triton = use_triton,
            **kwargs
        )

        # the grouped attention module
        self.attn = Attention(
            dim = dim,
            dim_context = dim_context,
            dim_head = dim_head,
            heads = heads,
            groups = num_experts,
            dropout = dropout,
            flash = flash_attn,
            prenorm = prenorm
        )

    # device the parameters live on
    @property
    def device(self):
        return next(self.parameters()).device

    # forward pass
    def forward(
        self,
        x,
        context = None,
        mask = None,
        context_mask = None,
        num_routed_queries = None,
        num_routed_key_values = None,
        rotary_emb = None
    ):
        # number of routed queries, defaulting to the value set at init
        num_routed_queries = default(num_routed_queries, self.num_routed_queries)
        # number of routed key / values, defaulting to the value set at init
        num_routed_key_values = default(num_routed_key_values, self.num_routed_key_values)

        # whether this is cross attention
        is_cross_attn = exists(context)

        # local attention and cross attention cannot be combined
        assert not (exists(self.local_attn) and is_cross_attn), 'cannot do cross attention with local attention (only for self attention)'

        if not is_cross_attn:
            # self attention - the context is the input itself
            context = x
            context_mask = mask

        # route the queries: indices, scores, routed queries, and query mask
        query_indices, query_scores, queries, query_mask = self.query_router(x, mask = mask, num_tokens = num_routed_queries, keep_one_route_dim = True)
        query_scores = rearrange(query_scores, 'b g n -> b g n 1')

        # route the key / values: indices, scores, routed key / values, and mask
        kv_indices, key_value_scores, key_values, key_value_mask = self.key_value_router(context, mask = context_mask, num_tokens = num_routed_key_values, keep_one_route_dim = True)
        key_value_scores = rearrange(key_value_scores, 'b g n -> b g 1 n 1')

        # rotary embeddings

        if exists(rotary_emb):
            assert not is_cross_attn, 'rotary embedding should not be used for cross attending'
            q_rotary_emb = rotary_emb[query_indices] if exists(query_indices) else rotary_emb
            k_rotary_emb = rotary_emb[kv_indices] if exists(kv_indices) else rotary_emb
            rotary_emb = (q_rotary_emb, k_rotary_emb)

        # attend

        attn_out = self.attn(
            queries,
            rotary_emb = rotary_emb,
            context = key_values,
            mask = key_value_mask,
            values_scale = key_value_scores,
            output_scale = query_scores
        )

        local_out = None
        if exists(self.local_attn):
            local_out = self.local_attn(x, mask = mask)

        need_route_queries = exists(query_indices)

        if not need_route_queries:
            out = attn_out

            if exists(local_out):
                local_out = rearrange(local_out, 'b n d -> b 1 n d')
                out = torch.cat((local_out, out), dim = 1)

            out = reduce(out, 'b e n d -> b n d', 'mean')

            if exists(mask):
                out = out.masked_fill(~mask[..., None], 0.)

            return out

        out = torch.zeros_like(x)
        counts = torch.zeros(x.shape[:-1], device = x.device)

        query_indices = rearrange(query_indices, 'b g n -> b (g n)')
        attn_out = rearrange(attn_out, 'b g n d -> b (g n) d')

        expanded_query_indices = repeat(query_indices, 'b n -> b n d', d = x.shape[-1])

        # scatter the routed attention outputs back to their original positions, summing any overlaps
        attn_out_summed = out.scatter_add(1, expanded_query_indices, attn_out)

        ones = torch.ones(attn_out.shape[:-1], device = self.device)

        if exists(query_mask):
            ones = ones * rearrange(query_mask, 'b g n -> b (g n)')

        # count how many experts routed to each position, for averaging
        counts = counts.scatter_add(1, query_indices, ones)
        counts = rearrange(counts, '... -> ... 1')

        has_unrouted = not exists(local_out)

        if not has_unrouted:
            counts = counts + 1
            attn_out_summed = attn_out_summed + local_out
        else:
            not_routed_mask = counts == 0
            attn_out_summed = attn_out_summed.masked_fill(not_routed_mask, 0.)

        out = attn_out_summed

        # average, if needed

        if self.average_routed:
            out = out / counts.clamp(min = 1e-5)

        # for positions that were not routed at all, use the learned null routed token rather than just zeros

        if has_unrouted:
            out = torch.where(
                not_routed_mask,
                self.null_routed_token,
                out,
            )

        if exists(mask):
            out = out.masked_fill(~mask[..., None], 0.)

        return out
# mixture of autoregressive attention class
class MixtureOfAutoregressiveAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_routed_queries,                 # number of routed queries
        num_routed_key_values,              # number of routed key / values
        local_attn_window_size,             # local attention window size
        routed_window_size = None,          # routed window size, defaults to the local attention window size
        num_experts = 2,                    # number of experts
        dim_head = 64,                      # dimension per head
        heads = 8,                          # number of heads
        dropout = 0.,                       # dropout rate
        use_triton = False,                 # whether to use triton
        flash_attn = True,                  # whether to use flash attention
        prenorm = True,                     # whether to use prenorm
        average_routed = False,             # whether to average the routed outputs
        **kwargs
    ):
        super().__init__()
        self.num_routed_queries = num_routed_queries
        self.num_routed_key_values = num_routed_key_values

        self.num_experts = num_experts
        # learned null tokens, one per expert
        self.null_tokens = nn.Parameter(torch.randn(num_experts, dim))

        # routed window size defaults to the local attention window size
        routed_window_size = default(routed_window_size, local_attn_window_size)

        self.routed_window_size = routed_window_size
        self.average_routed = average_routed

        # causal local multi-head self attention module
        self.local_attn = LocalMHA(
            dim = dim,
            dim_head = dim_head,
            heads = heads,
            prenorm = prenorm,
            causal = True,
            window_size = local_attn_window_size
        )

        # router for the queries
        self.query_router = CoordinateDescentRouter(
            dim,
            num_routing_tokens = num_experts,
            use_triton = use_triton,
            **kwargs
        )

        # router for the keys / values
        self.key_value_router = CoordinateDescentRouter(
            dim,
            num_routing_tokens = num_experts,
            use_triton = use_triton,
            **kwargs
        )

        # the grouped attention module
        self.attn = Attention(
            dim = dim,
            dim_head = dim_head,
            heads = heads,
            groups = num_experts,
            dropout = dropout,
            flash = flash_attn,
            prenorm = prenorm
        )

    # device the parameters live on
    @property
    def device(self):
        return next(self.parameters()).device

    # forward pass
    def forward(
        self,
        x,
        rotary_emb = None,
        num_routed_queries = None,
        num_routed_key_values = None
    ):
        ...  # forward body omitted in this excerpt
.\lucidrains\mixture-of-attention\mixture_of_attention\transformer.py

# import required libraries
import torch
import torch.nn.functional as F
from torch import nn, einsum

# einops for tensor rearranging
from einops import rearrange

# the mixture of attention module
from mixture_of_attention.mixture_of_attention import MixtureOfAutoregressiveAttention

# rotary embeddings
from mixture_of_attention.rotary_emb import RotaryEmbedding

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# classes

# RMS normalization
class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        normed = F.normalize(x, dim = -1)
        return normed * self.scale * self.gamma

# feedforward network
def FeedForward(dim, mult = 4):
    return nn.Sequential(
        RMSNorm(dim),
        nn.Linear(dim, dim * mult),
        nn.GELU(),
        nn.Linear(dim * mult, dim)
    )

# main class

# Transformer model
class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        depth,
        seq_len,
        local_attn_window_size,
        num_routed_queries,
        num_routed_key_values,
        num_experts,
        cosine_sim_routing = True,
        routed_window_size = None,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        use_triton = True,
        routed_rotary_emb = True
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(seq_len, dim)
        self.seq_len = seq_len

        self.rotary_emb = RotaryEmbedding(dim_head) if routed_rotary_emb else None

        self.layers = nn.ModuleList([])

        # stack the transformer layers
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                MixtureOfAutoregressiveAttention(
                    dim = dim,
                    local_attn_window_size = local_attn_window_size,
                    routed_window_size = routed_window_size,
                    num_routed_queries = num_routed_queries,
                    num_routed_key_values = num_routed_key_values,
                    cosine_sim_routing = cosine_sim_routing,
                    num_experts = num_experts,
                    dim_head = dim_head,
                    heads = heads,
                    use_triton = use_triton
                ),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

        # output projection to logits
        self.to_logits = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    # device the parameters live on
    @property
    def device(self):
        return next(self.parameters()).device

    # forward pass
    def forward(self, x):
        x = self.token_emb(x)
        x = x + self.pos_emb(torch.arange(x.shape[-2], device = self.device))

        rotary_emb = None
        if exists(self.rotary_emb):
            rotary_emb = self.rotary_emb(x.shape[1])

        # run through the attention / feedforward layers
        for attn, ff in self.layers:
            x = attn(x, rotary_emb = rotary_emb) + x

            x = ff(x) + x

        return self.to_logits(x)

.\lucidrains\mixture-of-attention\mixture_of_attention\__init__.py

# import the MixtureOfAttention, MixtureOfAutoregressiveAttention and Attention classes from the mixture_of_attention package
from mixture_of_attention.mixture_of_attention import (
    MixtureOfAttention,
    MixtureOfAutoregressiveAttention,
    Attention
)

Mixture-of-Attention

Some personal experiments around routing tokens to different autoregressive attention, akin to mixture-of-experts

Learned from a researcher friend that this has been tried in Switch Transformers unsuccessfully, but I'll give it a go, bringing in some learning points from recent papers like CoLT5.

In my opinion, the CoLT5 paper basically demonstrates mixture of attention already for 2 experts. This just has to be generalized to more than 2 experts, and to the autoregressive case. The local attention branch would just be a special case of one expert with fixed routing. If I route only half the tokens, that would lead to a savings of 4x, since attention cost scales quadratically with sequence length. If I can show even ~4 experts being better than 1 attention, that should be a win.

Appreciation

  • Stability and 🤗 Huggingface for their generous sponsorships to work on and open source cutting edge artificial intelligence research

  • einops for making tensor manipulation fun and easy

Install

$ pip install mixture-of-attention

Usage

import torch
from mixture_of_attention import MixtureOfAttention

mixture_of_attn = MixtureOfAttention(
    dim = 512,
    dim_context = 256,
    num_routed_queries = 16,
    num_routed_key_values = 16,
    num_experts = 2,
    dim_head = 64,
    heads = 8
)

x = torch.randn(1, 1024, 512)
mask = torch.ones((1, 1024)).bool()

context = torch.randn(1, 512, 256)
context_mask = torch.ones((1, 512)).bool()

mixture_of_attn(x, context = context, mask = mask) # (1, 1024, 512)

Autoregressive flavor

import torch
from mixture_of_attention import MixtureOfAutoregressiveAttention

mixture_of_attn = MixtureOfAutoregressiveAttention(
    dim = 512,
    local_attn_window_size = 64,       # local attention window size
    routed_window_size = None,         # will be set to the same as local_attn_window_size if None. ideally less than or equal to local attention window size for full receptive field
    num_routed_queries = 12,
    num_routed_key_values = 12,
    num_experts = 2,
    dim_head = 64,
    heads = 8
)

x = torch.randn(1, 1023, 512)

out = mixture_of_attn(x) # (1, 1023, 512)

Todo

  • allow for local attention to be automatically included, either for grouped attention, or use LocalMHA from local-attention repository in parallel, weighted properly

  • make it work for autoregressive

  • try dynamic routing tokens, using projection of masked mean-pooled queries

  • try out arxiv.org/abs/2210.05…

Citations

@inproceedings{Ainslie2023CoLT5FL,
    title   = {CoLT5: Faster Long-Range Transformers with Conditional Computation},
    author  = {Joshua Ainslie and Tao Lei and Michiel de Jong and Santiago Ontan'on and Siddhartha Brahma and Yury Zemlyanskiy and David Uthus and Mandy Guo and James Lee-Thorp and Yi Tay and Yun-Hsuan Sung and Sumit Sanghai},
    year    = {2023}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@article{Wright2015CoordinateDA,
    title   = {Coordinate descent algorithms},
    author  = {Stephen J. Wright},
    journal = {Mathematical Programming},
    year    = {2015},
    volume  = {151},
    pages   = {3-34}
}
@article{Schmitzer2016StabilizedSS,
    title   = {Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems},
    author  = {Bernhard Schmitzer},
    journal = {ArXiv},
    year    = {2016},
    volume  = {abs/1610.06519}
}
@inproceedings{rogozhnikov2022einops,
    title   = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
    author  = {Alex Rogozhnikov},
    booktitle = {International Conference on Learning Representations},
    year    = {2022},
    url     = {https://openreview.net/forum?id=oapKSVM2bcj}
}

.\lucidrains\mixture-of-attention\setup.py

# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'mixture-of-attention', # package name
  packages = find_packages(exclude=[]), # find all packages
  version = '0.0.24', # version number
  license='MIT', # license
  description = 'Mixture of Attention', # description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author email
  long_description_content_type = 'text/markdown', # long description content type
  url = 'https://github.com/lucidrains/mixture-of-attention', # URL
  keywords = [ # keywords
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'mixture-of-experts',
    'routed attention'
  ],
  install_requires=[ # dependencies
    'colt5-attention>=0.10.14',
    'einops>=0.6.1',
    'local-attention>=1.8.6',
    'torch>=1.6',
  ],
  classifiers=[ # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\mixture-of-attention\train.py

# import the necessary libraries
import gzip
import random
import tqdm
import numpy as np

import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from mixture_of_attention.transformer import Transformer
from mixture_of_attention.autoregressive_wrapper import AutoregressiveWrapper

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
PRIME_LENGTH = 128
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 512

# helper functions

# decode a token into a character
def decode_token(token):
    return str(chr(max(32, token)))

# decode a sequence of tokens into a string
def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))

# endlessly cycle through a dataloader
def cycle(loader):
    while True:
        for data in loader:
            yield data

# instantiate the Transformer model
model = Transformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    num_experts = 2,
    seq_len = SEQ_LEN,
    local_attn_window_size = 64,
    num_routed_queries = 32,
    num_routed_key_values = 64,
    cosine_sim_routing = True,
    use_triton = True
)

model = AutoregressiveWrapper(model).cuda()

# prepare the enwik8 data

with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

# custom dataset that samples random crops of the data
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# dataloaders for training and validation
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# optimizer
optim = Adam(model.parameters(), lr = LEARNING_RATE)

# train the model
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model.train()

    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    print(f"training loss: {loss.item()}")
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f"validation loss: {loss.item()}")

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        prime = decode_tokens(inp)
        print(f"{prime} \n\n {'*' * 100}")

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str, "\n")

.\lucidrains\mixture-of-experts\mixture_of_experts\mixture_of_experts.py

# import torch
import torch
# import the nn module
from torch import nn
# import nn.functional as F
import torch.nn.functional as F

# import math
import math
# import isfunction from inspect
from inspect import isfunction

# constants
MIN_EXPERT_CAPACITY = 4

# helper functions

# return val if it is not None, otherwise the default (calling it if it is a function)
def default(val, default_val):
    default_val = default_val() if isfunction(default_val) else default_val
    return val if val is not None else default_val

# cast an element to a tuple
def cast_tuple(el):
    return el if isinstance(el, tuple) else (el,)

# tensor-related helper functions

# top-1 values and indices along the last dimension
def top1(t):
    values, index = t.topk(k=1, dim=-1)
    values, index = map(lambda x: x.squeeze(dim=-1), (values, index))
    return values, index

# exclusive cumulative sum along a dimension (each position sums everything strictly before it)
def cumsum_exclusive(t, dim=-1):
    num_dims = len(t.shape)
    num_pad_dims = - dim - 1
    pre_padding = (0, 0) * num_pad_dims
    pre_slice   = (slice(None),) * num_pad_dims
    padded_t = F.pad(t, (*pre_padding, 1, 0)).cumsum(dim=dim)
    return padded_t[(..., slice(None, -1), *pre_slice)]
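# illustrative example (not part of the original file):
#
#     >>> cumsum_exclusive(torch.tensor([1., 2., 3.]))
#     tensor([0., 1., 3.])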

# safe one-hot encoding that avoids out-of-range index errors
def safe_one_hot(indexes, max_length):
    max_index = indexes.max() + 1
    return F.one_hot(indexes, max(max_index + 1, max_length))[..., :max_length]
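# illustrative example (not part of the original file): an index beyond
# max_length simply becomes an all-zero row instead of raising an error
#
#     >>> safe_one_hot(torch.tensor([1, 5]), 4)
#     tensor([[0, 1, 0, 0],
#             [0, 0, 0, 0]])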

# initialize a tensor with a uniform distribution scaled by its dimension
def init_(t):
    dim = t.shape[-1]
    std = 1 / math.sqrt(dim)
    return t.uniform_(-std, std)

# activations

# GELU activation (tanh approximation)
class GELU_(nn.Module):
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

# use nn.GELU if available, otherwise fall back to the manual implementation
GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_

# experts class

class Experts(nn.Module):
    def __init__(self,
        dim,
        num_experts = 16,
        hidden_dim = None,
        activation = GELU):
        super().__init__()

        hidden_dim = default(hidden_dim, dim * 4)
        num_experts = cast_tuple(num_experts)

        w1 = torch.zeros(*num_experts, dim, hidden_dim)
        w2 = torch.zeros(*num_experts, hidden_dim, dim)

        w1 = init_(w1)
        w2 = init_(w2)

        self.w1 = nn.Parameter(w1)
        self.w2 = nn.Parameter(w2)
        self.act = activation()

    def forward(self, x):
        hidden = torch.einsum('...nd,...dh->...nh', x, self.w1)
        hidden = self.act(hidden)
        out    = torch.einsum('...nh,...hd->...nd', hidden, self.w2)
        return out
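# illustrative usage sketch (not part of the original file): each expert is an
# independent two-layer feedforward applied to its own slice of tokens
#
#     >>> experts = Experts(512, num_experts = 16)
#     >>> x = torch.randn(16, 1024, 512)   # (experts, tokens per expert, dim)
#     >>> experts(x).shape
#     torch.Size([16, 1024, 512])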

# the code below is transcribed almost wholesale from the official tensorflow version, from which the paper was also written
# https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/research/moe.py

# gating network

class Top2Gating(nn.Module):
    def __init__(
        self,
        dim,
        num_gates,
        eps = 1e-9,
        outer_expert_dims = tuple(),
        second_policy_train = 'random',
        second_policy_eval = 'random',
        second_threshold_train = 0.2,
        second_threshold_eval = 0.2,
        capacity_factor_train = 1.25,
        capacity_factor_eval = 2.):
        super().__init__()

        self.eps = eps
        self.num_gates = num_gates
        self.w_gating = nn.Parameter(torch.randn(*outer_expert_dims, dim, num_gates))

        self.second_policy_train = second_policy_train
        self.second_policy_eval = second_policy_eval
        self.second_threshold_train = second_threshold_train
        self.second_threshold_eval = second_threshold_eval
        self.capacity_factor_train = capacity_factor_train
        self.capacity_factor_eval = capacity_factor_eval

# plain mixture of experts

class MoE(nn.Module):
    # initialization - model parameters and submodules
    def __init__(self,
        dim,
        num_experts = 16,
        hidden_dim = None,
        activation = nn.ReLU,
        second_policy_train = 'random',
        second_policy_eval = 'random',
        second_threshold_train = 0.2,
        second_threshold_eval = 0.2,
        capacity_factor_train = 1.25,
        capacity_factor_eval = 2.,
        loss_coef = 1e-2,
        experts = None):
        # call the parent initializer
        super().__init__()

        # number of experts
        self.num_experts = num_experts

        # gating hyperparameters
        gating_kwargs = {'second_policy_train': second_policy_train, 'second_policy_eval': second_policy_eval, 'second_threshold_train': second_threshold_train, 'second_threshold_eval': second_threshold_eval, 'capacity_factor_train': capacity_factor_train, 'capacity_factor_eval': capacity_factor_eval}
        # gating network
        self.gate = Top2Gating(dim, num_gates = num_experts, **gating_kwargs)
        # expert networks
        self.experts = default(experts, lambda: Experts(dim, num_experts = num_experts, hidden_dim = hidden_dim, activation = activation))
        # auxiliary loss coefficient
        self.loss_coef = loss_coef

    # forward pass
    def forward(self, inputs, **kwargs):
        # input shape
        b, n, d, e = *inputs.shape, self.num_experts
        # gating outputs and auxiliary load-balancing loss
        dispatch_tensor, combine_tensor, loss = self.gate(inputs)
        # dispatch the inputs to the experts
        expert_inputs = torch.einsum('bnd,bnec->ebcd', inputs, dispatch_tensor)

        # feed the dispatched inputs through the experts
        orig_shape = expert_inputs.shape
        expert_inputs = expert_inputs.reshape(e, -1, d)
        expert_outputs = self.experts(expert_inputs)
        expert_outputs = expert_outputs.reshape(*orig_shape)

        # combine the expert outputs
        output = torch.einsum('ebcd,bnec->bnd', expert_outputs, combine_tensor)
        # return the output and the auxiliary loss scaled by its coefficient
        return output, loss * self.loss_coef
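# illustrative usage sketch (assumed, following the repository readme): the
# module returns the transformed tokens plus a scalar auxiliary balancing loss
#
#     >>> moe = MoE(dim = 512, num_experts = 16)
#     >>> inputs = torch.randn(4, 1024, 512)
#     >>> out, aux_loss = moe(inputs)
#     >>> out.shape
#     torch.Size([4, 1024, 512])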
# two-level hierarchical mixture of experts
class HeirarchicalMoE(nn.Module):
    def __init__(self,
        dim,
        num_experts = (4, 4),                 # (outer, inner) expert counts
        hidden_dim = None,                    # hidden dimension, defaults to dim * 4
        activation = nn.ReLU,                 # activation function
        second_policy_train = 'random',       # second-expert policy during training
        second_policy_eval = 'random',        # second-expert policy during evaluation
        second_threshold_train = 0.2,         # second-expert threshold during training
        second_threshold_eval = 0.2,          # second-expert threshold during evaluation
        capacity_factor_train = 1.25,         # capacity factor during training
        capacity_factor_eval = 2.,            # capacity factor during evaluation
        loss_coef = 1e-2,                     # auxiliary loss coefficient
        experts = None):                      # optional custom expert module
        super().__init__()

        assert len(num_experts) == 2, 'only 2 levels of heirarchy for experts allowed for now'  # only two levels of hierarchy for now
        num_experts_outer, num_experts_inner = num_experts
        self.num_experts_outer = num_experts_outer
        self.num_experts_inner = num_experts_inner

        gating_kwargs = {'second_policy_train': second_policy_train, 'second_policy_eval': second_policy_eval, 'second_threshold_train': second_threshold_train, 'second_threshold_eval': second_threshold_eval, 'capacity_factor_train': capacity_factor_train, 'capacity_factor_eval': capacity_factor_eval}

        # outer and inner gating modules
        self.gate_outer = Top2Gating(dim, num_gates = num_experts_outer, **gating_kwargs)
        self.gate_inner = Top2Gating(dim, num_gates = num_experts_inner, outer_expert_dims = (num_experts_outer,), **gating_kwargs)

        # expert networks
        self.experts = default(experts, lambda: Experts(dim, num_experts = num_experts, hidden_dim = hidden_dim, activation = activation))
        self.loss_coef = loss_coef

    def forward(self, inputs, **kwargs):
        b, n, d, eo, ei = *inputs.shape, self.num_experts_outer, self.num_experts_inner
        dispatch_tensor_outer, combine_tensor_outer, loss_outer = self.gate_outer(inputs)
        expert_inputs_outer = torch.einsum('bnd,bnec->ebcd', inputs, dispatch_tensor_outer)

        # construct an "importance" tensor for the second-level gating
        importance = combine_tensor_outer.permute(2, 0, 3, 1).sum(dim=-1)
        importance = 0.5 * ((importance > 0.5).float() + (importance > 0.).float())

        dispatch_tensor_inner, combine_tensor_inner, loss_inner = self.gate_inner(expert_inputs_outer, importance = importance)
        expert_inputs = torch.einsum('ebnd,ebnfc->efbcd', expert_inputs_outer, dispatch_tensor_inner)

        # feed the dispatched inputs through the experts
        orig_shape = expert_inputs.shape
        expert_inputs = expert_inputs.reshape(eo, ei, -1, d)
        expert_outputs = self.experts(expert_inputs)
        expert_outputs = expert_outputs.reshape(*orig_shape)

        # combine the expert outputs back, inner level first, then outer
        expert_outputs_outer = torch.einsum('efbcd,ebnfc->ebnd', expert_outputs, combine_tensor_inner)
        output = torch.einsum('ebcd,bnec->bnd', expert_outputs_outer, combine_tensor_outer)
        return output, (loss_outer + loss_inner) * self.loss_coef
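# illustrative usage sketch (assumed): 4 outer experts, each gating over 4
# inner experts, for 16 experts in total
#
#     >>> moe = HeirarchicalMoE(dim = 512, num_experts = (4, 4))
#     >>> out, aux_loss = moe(torch.randn(4, 1024, 512))
#     >>> out.shape
#     torch.Size([4, 1024, 512])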