Lucidrains Series: Source Code Walkthrough (61)
.\lucidrains\metnet3-pytorch\metnet3_pytorch\__init__.py
# Import the MetNet3 class from the metnet3_pytorch package
from metnet3_pytorch.metnet3_pytorch import (
    MetNet3
)

MetNet-3 - Pytorch
Implementation of MetNet 3, SOTA neural weather model out of Google Deepmind, in Pytorch
The model architecture is fairly unremarkable: it is essentially a U-Net with a specific, well-performing vision transformer. The most interesting part of the paper may end up being the loss scaling described in section 4.3.2.
Appreciation
- StabilityAI, A16Z Open Source AI Grant Program, and 🤗 Huggingface for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research
Install
$ pip install metnet3-pytorch
Usage
import torch
from metnet3_pytorch import MetNet3

metnet3 = MetNet3(
    dim = 512,
    num_lead_times = 722,
    lead_time_embed_dim = 32,
    input_spatial_size = 624,
    attn_dim_head = 8,
    hrrr_channels = 617,
    input_2496_channels = 2 + 14 + 1 + 2 + 20,
    input_4996_channels = 16 + 1,
    precipitation_target_bins = dict(
        mrms_rate = 512,
        mrms_accumulation = 512,
    ),
    surface_target_bins = dict(
        omo_temperature = 256,
        omo_dew_point = 256,
        omo_wind_speed = 256,
        omo_wind_component_x = 256,
        omo_wind_component_y = 256,
        omo_wind_direction = 180
    ),
    hrrr_loss_weight = 10,
    hrrr_norm_strategy = 'sync_batchnorm',  # this uses a sync batchnorm to normalize the input hrrr and target, without having to precalculate the per-channel mean and variance of the hrrr dataset
    hrrr_norm_statistics = None             # you can also set `hrrr_norm_strategy = "precalculated"` and pass in the mean and variance as a tensor of shape `(2, 617)` through this keyword argument
)
# inputs

lead_times = torch.randint(0, 722, (2,))
hrrr_input_2496 = torch.randn((2, 617, 624, 624))
hrrr_stale_state = torch.randn((2, 1, 624, 624))
input_2496 = torch.randn((2, 39, 624, 624))
input_4996 = torch.randn((2, 17, 624, 624))

# targets

precipitation_targets = dict(
    mrms_rate = torch.randint(0, 512, (2, 512, 512)),
    mrms_accumulation = torch.randint(0, 512, (2, 512, 512)),
)

surface_targets = dict(
    omo_temperature = torch.randint(0, 256, (2, 128, 128)),
    omo_dew_point = torch.randint(0, 256, (2, 128, 128)),
    omo_wind_speed = torch.randint(0, 256, (2, 128, 128)),
    omo_wind_component_x = torch.randint(0, 256, (2, 128, 128)),
    omo_wind_component_y = torch.randint(0, 256, (2, 128, 128)),
    omo_wind_direction = torch.randint(0, 180, (2, 128, 128))
)

hrrr_target = torch.randn(2, 617, 128, 128)

total_loss, loss_breakdown = metnet3(
    lead_times = lead_times,
    hrrr_input_2496 = hrrr_input_2496,
    hrrr_stale_state = hrrr_stale_state,
    input_2496 = input_2496,
    input_4996 = input_4996,
    precipitation_targets = precipitation_targets,
    surface_targets = surface_targets,
    hrrr_target = hrrr_target
)

total_loss.backward()

# after much training from above, you can predict as follows

metnet3.eval()

surface_preds, hrrr_pred, precipitation_preds = metnet3(
    lead_times = lead_times,
    hrrr_input_2496 = hrrr_input_2496,
    hrrr_stale_state = hrrr_stale_state,
    input_2496 = input_2496,
    input_4996 = input_4996,
)

# Dict[str, Tensor], Tensor, Dict[str, Tensor]
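If the per-channel HRRR statistics have already been computed offline, the sync-batchnorm bookkeeping can be skipped. Below is a minimal sketch of the 'precalculated' strategy mentioned in the constructor comments above; it assumes the target-bin arguments keep their library defaults (the excerpt does not show them), and that row 0 holds the per-channel means and row 1 the variances:

metnet3 = MetNet3(
    dim = 512,
    num_lead_times = 722,
    lead_time_embed_dim = 32,
    input_spatial_size = 624,
    attn_dim_head = 8,
    hrrr_channels = 617,
    input_2496_channels = 2 + 14 + 1 + 2 + 20,
    input_4996_channels = 16 + 1,
    hrrr_loss_weight = 10,
    hrrr_norm_strategy = 'precalculated',
    hrrr_norm_statistics = torch.randn(2, 617)  # stand-in; normally the per-channel (mean, variance) computed over the hrrr dataset
)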
Todo

- figure out all the cross entropy and MSE losses
- auto-handle normalization across all the channels of the HRRR by tracking a running mean and variance of targets during training (using sync batchnorm as a hack)
- allow researcher to pass in their own normalization variables for HRRR
- build all the inputs to spec, also make sure hrrr input is normalized, offer option to unnormalize hrrr predictions
- make sure model can be easily saved and loaded, with different ways of handling hrrr norm
- figure out the topological embedding, consult a neural weather researcher
Citations
@article{Andrychowicz2023DeepLF,
title = {Deep Learning for Day Forecasts from Sparse Observations},
author = {Marcin Andrychowicz and Lasse Espeholt and Di Li and Samier Merchant and Alexander Merose and Fred Zyda and Shreya Agrawal and Nal Kalchbrenner},
journal = {ArXiv},
year = {2023},
volume = {abs/2306.06079},
url = {https://api.semanticscholar.org/CorpusID:259129311}
}
@inproceedings{ElNouby2021XCiTCI,
title = {XCiT: Cross-Covariance Image Transformers},
author = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou},
booktitle = {Neural Information Processing Systems},
year = {2021},
url = {https://api.semanticscholar.org/CorpusID:235458262}
}
.\lucidrains\metnet3-pytorch\setup.py
# import the setup and package-discovery helpers
from setuptools import setup, find_packages

# declare the package metadata
setup(
    name = 'metnet3-pytorch',                              # package name
    packages = find_packages(exclude=[]),                  # find and include all packages
    version = '0.0.12',                                    # version number
    license='MIT',                                         # license
    description = 'MetNet 3 - Pytorch',                    # short description
    author = 'Phil Wang',                                  # author
    author_email = 'lucidrains@gmail.com',                 # author email
    long_description_content_type = 'text/markdown',       # long description content type
    url = 'https://github.com/lucidrains/metnet3-pytorch', # project url
    keywords = [                                           # keyword list
        'artificial intelligence',
        'deep learning',
        'vision transformers',
        'unet',
        'weather forecasting'
    ],
    install_requires=[                                     # install dependencies
        'beartype',
        'einops>=0.7.0',
        'torch>=2.0',
    ],
    classifiers=[                                          # classifier list
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\mirasol-pytorch\mirasol_pytorch\distributed.py
# imports; torch.nn.functional is needed by pad_dim_to below
from functools import cache

import torch
import torch.nn.functional as F
from torch.autograd import Function

import torch.distributed as distributed

from einops import rearrange

# helper functions

# check whether a value is not None
def exists(val):
    return val is not None

# cached check for whether we are running in a multi-process distributed environment
@cache
def get_is_distributed():
    return distributed.is_initialized() and distributed.get_world_size() > 1

# pad a tensor along the given dimension until it reaches the given length
def pad_dim_to(t, length, dim = 0):
    pad_length = length - t.shape[dim]
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))

# distributed helpers

# all-gather tensors whose size may differ across ranks along the given dimension
def all_gather_variable_dim(t, dim = 0, sizes = None):
    device, rank, world_size = t.device, distributed.get_rank(), distributed.get_world_size()

    if not exists(sizes):
        size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
        sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)]
        distributed.all_gather(sizes, size)
        sizes = torch.stack(sizes)

    max_size = sizes.amax().item()
    padded_t = pad_dim_to(t, max_size, dim = dim)

    gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for i in range(world_size)]
    distributed.all_gather(gathered_tensors, padded_t)

    gathered_tensor = torch.cat(gathered_tensors, dim = dim)

    # build a mask that strips out the padding each rank contributed
    seq = torch.arange(max_size, device = device)
    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
    mask = rearrange(mask, 'i j -> (i j)')
    seq = torch.arange(mask.shape[-1], device = device)
    indices = seq[mask]

    gathered_tensor = gathered_tensor.index_select(dim, indices)

    return gathered_tensor, sizes

# custom autograd Function making the all-gather differentiable:
# the backward pass routes each rank its own slice of the incoming gradients
class AllGather(Function):
    @staticmethod
    def forward(ctx, x, dim, sizes):
        assert get_is_distributed()
        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
        ctx.batch_sizes = batch_sizes.tolist()
        ctx.dim = dim
        return x, batch_sizes

    @staticmethod
    def backward(ctx, grads, _):
        batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
        return grads_by_rank[rank], None, None

# expose the custom Function as a plain all_gather callable
all_gather = AllGather.apply
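A minimal usage sketch of the differentiable all_gather above. It assumes a torch.distributed process group is already initialized with more than one rank (e.g. the script was launched with torchrun), and that each rank holds a tensor whose first dimension differs:

import torch
import torch.distributed as distributed

from mirasol_pytorch.distributed import all_gather

rank = distributed.get_rank()
local = torch.randn(rank + 1, 512, requires_grad = True)  # a different batch size on every rank

gathered, sizes = all_gather(local, 0, None)  # (sum of all ranks' batch sizes, 512), plus the per-rank sizes
gathered.sum().backward()                     # backward hands this rank only the gradient slice for its own rows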
.\lucidrains\mirasol-pytorch\mirasol_pytorch\mirasol_pytorch.py
# import the required modules and functions
import operator
from functools import partial
from collections import namedtuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn, einsum
from torch.nn import Module, ModuleList

# beartype for runtime type checking
from beartype import beartype
from beartype.typing import Optional, Union, Tuple, Dict, Any

# einops functions and layers
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange

# x_transformers building blocks
from x_transformers import (
    Encoder,
    Decoder,
    TransformerWrapper,
    AutoregressiveWrapper
)

# RotaryEmbedding from x_transformers
from x_transformers.x_transformers import RotaryEmbedding

# distributed helpers from mirasol_pytorch
from mirasol_pytorch.distributed import all_gather, get_is_distributed

# helper functions

# check whether a value is not None
def exists(v):
    return v is not None

# return the first argument that exists
def default(*args):
    for arg in args:
        if exists(arg):
            return arg
    return None

# check whether num is divisible by den
def divisible_by(num, den):
    return (num % den) == 0

# check that exactly one of the booleans is True
def only_one_true(*bools):
    return sum(map(int, bools)) == 1

# pack a single tensor according to a pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# unpack a single tensor according to a pattern
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# tensor helpers

# l2-normalize along the feature dimension
def l2norm(t):
    return F.normalize(t, dim = -1)

# cosine-similarity loss between two sets of embeddings
def cosine_sim_loss(x, y):
    x, y = map(l2norm, (x, y))
    return 1. - einsum('b n d, b n d -> b n', x, y).mean()

# n-dimensional sine/cosine positional embedding
def posemb_sincos_nd(
    t: Tensor,
    temperature: int = 10000,
    dtype = torch.float32
):
    b, *dims, feat_dim, device = *t.shape, t.device
    seq_len = torch.tensor(dims).cumprod(dim = -1)[-1].item()

    arange = partial(torch.arange, device = device)

    num_dims = len(dims)
    two_times_num_dims = 2 * num_dims # 2 because sin and cos of same position

    rounded_feat_dim = feat_dim // num_dims * num_dims
    feat_dim_remainder = feat_dim % num_dims

    omega = arange(rounded_feat_dim // two_times_num_dims) / (rounded_feat_dim // two_times_num_dims - 1)
    omega = 1.0 / (temperature ** omega)

    meshed = torch.meshgrid(*[*map(arange, dims)], indexing = 'ij')
    pos = torch.cat(tuple(m.flatten()[..., None] for m in meshed), dim = 0)
    pos = pos * omega[None, :]
    pos = torch.cat((pos.sin(), pos.cos()))

    pos = rearrange(pos, '(n f) d -> n (f d)', n = seq_len)
    pos = pos.type(dtype)

    return F.pad(pos, (0, feat_dim_remainder))
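# illustrative shape check (an assumption of this walkthrough, not part of the original file):
#   feats = torch.randn(2, 8, 8, 512)   # (batch, height, width, feature)
#   posemb_sincos_nd(feats).shape       # -> (64, 512), one sin/cos embedding row per spatial position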
# generate a boolean mask where roughly a `prob` fraction of positions per row is masked out (False)
def mask_with_prob(
    shape: Tuple[int, ...],
    prob: float,
    device = None
) -> Tensor:
    length = shape[-1]
    num_mask = int(prob * length)

    randperm = torch.randn(shape, device = device).argsort(dim = -1)
    return randperm >= num_mask
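# illustrative example (an assumption of this walkthrough, not part of the original file):
#   mask_with_prob((2, 10), 0.15) returns a (2, 10) bool tensor with exactly
#   int(0.15 * 10) = 1 randomly placed False per row; used for the audio/video
#   key-value masking referred to in the init arguments below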
# main classes

# namedtuple holding the individual loss terms
Losses = namedtuple('Losses', [
    'text_autoregressive',
    'av_autoregressive',
    'av_recon',
    'text_av_sim_reg'
])

# Mirasol module
class Mirasol(Module):
    @beartype
    def __init__(
        self,
        *,
        dim,
        num_text_tokens,
        video_image_size,
        video_frames_per_timechunk,
        audio_freq_dim,
        audio_time_dim_per_timechunk,
        audio_patch_size: Tuple[int, int],                         # audio patch size (frequency, time)
        video_patch_size: Tuple[int, int],                         # video patch size (spatial, time)
        video_recon_patch_size: Optional[Tuple[int, int]] = None,  # video reconstruction patch size (spatial, time) - a smaller video for the reconstruction loss
        video_recon_interpolate_mode = 'nearest',
        audio_encoder: Union[Module, Dict[str, Any]],
        video_encoder: Union[Module, Dict[str, Any]],
        num_audio_video_register_tokens = 8,                       # number of audio/video register tokens, https://arxiv.org/abs/2309.16588
        audio_video_mask_prob = 0.15,                              # the paper uses masked tokens, but per the Berkeley forgetful-causal-mask paper, a simple key/value mask should suffice
        text_max_seq_len = 2048,
        text_forgetful_causal_mask_prob = 0.1,                     # https://arxiv.org/abs/2210.13432
        encoder_depth = 6,
        decoder_depth = 6,
        combiner_depth = 2,
        combiner_output_num_tokens = 3,
        video_channels = 3,
        attn_dim_head = 64,
        attn_heads = 8,
        flash_attn = True,
        attn_layers_kwargs: dict = dict(),
        combiner: Optional[Module] = None,
        combiner_kwargs: dict = dict(),
        autoregressive_wrapper_kwargs: dict = dict(
            pad_value = 0,
            ignore_index = -100
        ),
        av_autoregressive_loss_weight = 1.,
        av_reconstruction_loss_weight = 1.,
        sim_reg_loss_weight = 0.
    ):
        # ... (initialization body elided in this walkthrough)
    # device of the module's parameters
    @property
    def device(self):
        return next(self.parameters()).device

    # generation entrypoint
    @torch.no_grad()
    def generate(
        self,
        *,
        seq_len: int,
        prompt: Optional[Tensor] = None,
        **kwargs
    ):
        was_training = self.training
        self.eval()

        assert 'generate' not in kwargs
        assert 'generate_seq_len' not in kwargs

        # delegate to forward in generation mode
        out = self.forward(
            text = prompt,
            generate = True,
            generate_seq_len = seq_len,
            **kwargs
        )

        self.train(was_training)
        return out

    # forward pass: raw or pre-encoded audio/video plus text in, losses (or generated text) out
    @beartype
    def forward(
        self,
        *,
        audio: Optional[Tensor] = None,
        video: Optional[Tensor] = None,
        encoded_audio: Optional[Tensor] = None,
        encoded_video: Optional[Tensor] = None,
        text: Optional[Tensor] = None,
        text_mask: Optional[Tensor] = None,
        return_loss = True,
        return_loss_breakdown = False,
        generate = False,
        generate_seq_len = None
    ):
        # ... (forward body elided in this walkthrough)
.\lucidrains\mirasol-pytorch\mirasol_pytorch\__init__.py
# Import the Mirasol class from the mirasol_pytorch package
from mirasol_pytorch.mirasol_pytorch import Mirasol

🌻 Mirasol - Pytorch
Implementation of Mirasol, SOTA Multimodal Autoregressive model out of Google Deepmind, in Pytorch
Will simply implement the Transformer Combiner and omit the other variants.
Appreciation
- StabilityAI, A16Z Open Source AI Grant Program, and 🤗 Huggingface for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research
Install
$ pip install mirasol-pytorch
Usage
import torch
from mirasol_pytorch import Mirasol

model = Mirasol(
    dim = 512,
    num_text_tokens = 256,
    video_image_size = 128,
    video_frames_per_timechunk = 2,
    audio_freq_dim = 64,
    audio_time_dim_per_timechunk = 32,
    audio_patch_size = (32, 16),
    video_patch_size = (64, 2),
    audio_encoder = dict(
        dim = 512,
        depth = 2
    ),
    video_encoder = dict(
        dim = 512,
        depth = 2
    )
)

audio = torch.randn(1, 64, 1024)
video = torch.randn(1, 3, 12, 128, 128)
text = torch.randint(0, 256, (1, 1024))

loss = model(
    audio = audio,
    video = video,
    text = text
)

loss.backward()

# after much training

sampled_text = model.generate(
    audio = audio,
    video = video,
    seq_len = 512
)
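The forward signature shown earlier also accepts return_loss_breakdown. A hedged sketch of inspecting the individual terms; it assumes (the excerpt does not show the forward body) that the flag makes forward return the total loss together with the Losses namedtuple defined in mirasol_pytorch.py:

total_loss, (text_ar, av_ar, av_recon, sim_reg) = model(
    audio = audio,
    video = video,
    text = text,
    return_loss_breakdown = True
)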
Todo
- text generation code
- auto-handle start token for decoder
- positional embeddings for video and audio encoder
- enable register tokens for both video and audio encoder, inline with new research
- add audio and video reconstruction losses
- add similarity regularization from TTS research
Citations
@article{Piergiovanni2023Mirasol3BAM,
    title   = {Mirasol3B: A Multimodal Autoregressive model for time-aligned and contextual modalities},
    author  = {A. J. Piergiovanni and Isaac Noble and Dahun Kim and Michael S. Ryoo and Victor Gomes and Anelia Angelova},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2311.05698},
    url     = {https://api.semanticscholar.org/CorpusID:265129010}
}

@inproceedings{Liu2022TowardsBF,
    title  = {Towards Better Few-Shot and Finetuning Performance with Forgetful Causal Language Models},
    author = {Hao Liu and Xinyang Geng and Lisa Lee and Igor Mordatch and Sergey Levine and Sharan Narang and P. Abbeel},
    year   = {2022},
    url    = {https://api.semanticscholar.org/CorpusID:256416540}
}

@article{Darcet2023VisionTN,
    title   = {Vision Transformers Need Registers},
    author  = {Timoth{\'e}e Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2309.16588},
    url     = {https://api.semanticscholar.org/CorpusID:263134283}
}

@article{Bondarenko2023QuantizableTR,
    title   = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
    author  = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2306.12929},
    url     = {https://api.semanticscholar.org/CorpusID:259224568}
}

@misc{shi2023enhance,
    title         = {Enhance audio generation controllability through representation similarity regularization},
    author        = {Yangyang Shi and Gael Le Lan and Varun Nagaraja and Zhaoheng Ni and Xinhao Mei and Ernie Chang and Forrest Iandola and Yang Liu and Vikas Chandra},
    year          = {2023},
    eprint        = {2309.08773},
    archivePrefix = {arXiv},
    primaryClass  = {cs.SD}
}
.\lucidrains\mirasol-pytorch\setup.py
# import the setup and package-discovery helpers
from setuptools import setup, find_packages

# declare the package metadata
setup(
    # package name
    name = 'mirasol-pytorch',
    # find and include all packages
    packages = find_packages(exclude=[]),
    # version number
    version = '0.0.16',
    # license
    license='MIT',
    # short description
    description = 'Mirasol - Pytorch',
    # author
    author = 'Phil Wang',
    # author email
    author_email = 'lucidrains@gmail.com',
    # long description content type
    long_description_content_type = 'text/markdown',
    # project url
    url = 'https://github.com/lucidrains/mirasol-pytorch',
    # keyword list
    keywords = [
        'artificial intelligence',
        'deep learning',
        'multimodality'
    ],
    # install dependencies
    install_requires=[
        'beartype',
        'einops>=0.7.0',
        'x-transformers>=1.25.10',
        'torch>=2.0'
    ],
    # classifiers
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
Data source
The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/
.\lucidrains\mixture-of-attention\mixture_of_attention\attend.py
# imports
from collections import namedtuple
from functools import wraps

from packaging import version

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange

# config namedtuple for torch's scaled-dot-product-attention kernel selection
EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helper: check whether a value is not None
def exists(val):
    return val is not None

# decorator: ensure a function is only invoked once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# print, but only once
print_once = once(print)

# main Attend class
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        causal = False,
        flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu
        self.cpu_config = EfficientAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(False, True, True)

    # build a causal mask (True above the shifted diagonal marks positions to mask out)
    def get_mask(self, i, j, device):
        return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)

    # flash attention path
    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        config = self.cuda_config if is_cuda else self.cpu_config

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.,
                is_causal = self.causal
            )

        return out

    def forward(self, q, k, v, mask = None):
        """
        einstein notation

        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, device = q.shape[-2], q.device
        scale = q.shape[-1] ** -0.5

        if self.flash:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity
        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # key padding mask
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask
        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = self.get_mask(i, j, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out
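A quick shape check for the Attend module above (a minimal sketch; shapes follow the einstein notation documented in forward):

import torch
from mixture_of_attention.attend import Attend

attend = Attend(causal = True, flash = False)

q = torch.randn(2, 8, 1024, 64)  # (batch, heads, seq, dim_head)
k = torch.randn(2, 8, 1024, 64)
v = torch.randn(2, 8, 1024, 64)

out = attend(q, k, v)            # (2, 8, 1024, 64), causally masked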
.\lucidrains\mixture-of-attention\mixture_of_attention\autoregressive_wrapper.py
import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange

# helper functions

# check whether a value is not None
def exists(val):
    return val is not None

# decorator: run the wrapped function with the model in eval mode, then restore its training state
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# top-k filtering: keep only the top (1 - thres) fraction of logits, set the rest to -inf
def top_k(logits, thres = 0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, -torch.finfo(logits.dtype).max)
    probs.scatter_(1, ind, val)
    return probs
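# illustrative example (an assumption of this walkthrough, not part of the original file):
#   with thres = 0.9 and a 256-token vocabulary, k = int(0.1 * 256) = 25, so only the 25
#   largest logits survive; everything else becomes -inf and gets ~zero softmax probability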
# autoregressive wrapper class
class AutoregressiveWrapper(nn.Module):
    def __init__(
        self,
        net,
        pad_value = 0
    ):
        super().__init__()
        self.seq_len = net.seq_len
        self.pad_value = pad_value
        self.net = net

    # sample tokens autoregressively from the wrapped network
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        prompt,
        seq_len,
        temperature=1.0,
        filter_thres=0.9,
        **kwargs
    ):
        b, t, device = *prompt.shape, prompt.device

        out = prompt

        for _ in range(seq_len):
            # feed the trailing self.seq_len tokens through the network, keep the last position's logits
            logits = self.net(out[:, -self.seq_len:], **kwargs)[:, -1]
            # top-k filter, then sample from the tempered distribution
            filtered_logits = top_k(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim = -1)
            sample = torch.multinomial(probs, 1)
            # append the sampled token to the sequence
            out = torch.cat((out, sample), dim = -1)

        # strip the prompt, return only the generated continuation
        out = out[:, t:]
        return out

    # training forward: next-token cross entropy
    def forward(self, x, **kwargs):
        x, labels = x[:, :-1], x[:, 1:]
        logits = self.net(x, **kwargs)
        # cross_entropy expects (batch, classes, seq)
        logits = rearrange(logits, "b n c -> b c n")
        return F.cross_entropy(logits, labels)
.\lucidrains\mixture-of-attention\mixture_of_attention\mixture_of_attention.py
import math

import torch
import torch.nn.functional as F
from torch import Tensor, nn, einsum

from typing import Tuple, Optional

from einops import rearrange, repeat, reduce, pack, unpack

# custom modules
from mixture_of_attention.attend import Attend
from mixture_of_attention.rotary_emb import apply_rotary_pos_emb

from local_attention import LocalMHA

from colt5_attention import CoordinateDescentRouter

# helper functions

# check whether a value is not None
def exists(val):
    return val is not None

# return val if it exists, otherwise the default d
def default(val, d):
    return val if exists(val) else d

# pack a single tensor according to a pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# unpack a single tensor according to a pattern
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# pad a tensor along a dimension up to the nearest multiple
def pad_to_multiple(tensor, multiple, dim = -1, value = 0):
    seq_len = tensor.shape[dim]
    m = seq_len / multiple

    if m.is_integer():
        return tensor, seq_len

    remainder = math.ceil(m) * multiple - seq_len
    pad_offset = (0,) * (-1 - dim) * 2
    padded_tensor = F.pad(tensor, (*pad_offset, 0, remainder), value = value)
    return padded_tensor, seq_len
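# illustrative example (an assumption of this walkthrough, not part of the original file):
#   pad_to_multiple(torch.ones(2, 10), 4) pads the last dim from 10 up to 12 with zeros
#   and returns (padded_tensor, 10), i.e. the original length for later unpadding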
# normalization

# RMS norm over the feature dimension; note that features sit on dim -2 here,
# since the grouped-expert attention below lays tensors out as (batch, groups, dim, seq)
class RMSNorm(nn.Module):
    def __init__(self, dim, groups = 1):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(groups, dim, 1))

    def forward(self, x):
        normed = F.normalize(x, dim = -2)
        return normed * self.scale * self.gamma

# attention

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        dim_context = None,
        heads = 8,
        causal = False,
        groups = 1, # number of experts
        dropout = 0.,
        flash = False,
        prenorm = False
    ):
        super().__init__()
        self.heads = heads
        self.groups = groups

        dim_inner = dim_head * heads
        dim_context = default(dim_context, dim)

        self.norm = RMSNorm(dim, groups = groups) if prenorm else nn.Identity()
        self.context_norm = RMSNorm(dim_context, groups = groups) if prenorm else nn.Identity()

        self.attend = Attend(
            dropout = dropout,
            causal = causal,
            flash = flash
        )

        # null key / value, to protect against a row that is entirely masked out
        self.null_kv = nn.Parameter(torch.randn(2, groups, heads, 1, dim_head))

        # grouped 1x1 convolutions process all experts in parallel
        self.to_q = nn.Conv1d(dim * groups, dim_inner * groups, 1, bias = False, groups = groups)
        self.to_kv = nn.Conv1d(dim_context * groups, dim_inner * 2 * groups, 1, bias = False, groups = groups)
        self.to_out = nn.Conv1d(dim_inner * groups, dim * groups, 1, bias = False, groups = groups)

    def forward(
        self,
        x,
        context = None,
        mask = None,
        queries_scale = None,
        keys_scale = None,
        values_scale = None,
        output_scale = None,
        rotary_emb: Optional[Tuple[Tensor, Tensor]] = None
    ):
        """
        einops

        b - batch
        g - groups
        n - sequence
        d - feature dimension
        """
        b, g, h = x.shape[0], self.groups, self.heads

        # a 3-dim input means a single expert
        one_expert = x.ndim == 3

        # if there is only a single expert, expand to 4 dims
        if one_expert:
            assert g == 1
            x = rearrange(x, 'b n d -> b 1 n d')

        assert x.ndim == 4
        assert x.shape[1] == g

        # move features before the sequence dim, so groups can be folded in for the grouped conv
        x = rearrange(x, 'b g n d -> b g d n')

        # handle cross-attention context the same way
        if exists(context):
            context_one_expert = context.ndim == 3

            if context_one_expert:
                assert g == 1
                context = rearrange(context, 'b n d -> b 1 n d')

            assert context.ndim == 4
            assert context.shape[1] == g

            context = rearrange(context, 'b g n d -> b g d n')

        # default to self attention if no context is passed in
        context = default(context, x)

        # handle the mask; pad one position for the null key / value below
        if exists(mask):
            if mask.ndim == 2:
                mask = repeat(mask, 'b n -> (b g) n', g = g)
            elif mask.ndim == 3:
                mask = rearrange(mask, 'b g n -> (b g) n')

            mask = F.pad(mask, (1, 0), value = True)

        # prenorm, if applicable
        x = self.norm(x)
        context = self.context_norm(context)

        # fold groups into the feature dimension for the grouped convolutions
        x, context = map(lambda t: rearrange(t, 'b g d n -> b (g d) n'), (x, context))

        # queries, keys, values
        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = 1))

        # split out heads and groups
        q, k, v = map(lambda t: rearrange(t, 'b (g h d) n -> b g h n d', h = h, g = g), (q, k, v))

        # rotary embeddings
        if exists(rotary_emb):
            q_rotary_emb, k_rotary_emb = rotary_emb

            if q_rotary_emb.ndim > 2:
                q_rotary_emb = rearrange(q_rotary_emb, 'b g n d -> b g 1 n d')

            if k_rotary_emb.ndim > 2:
                k_rotary_emb = rearrange(k_rotary_emb, 'b g n d -> b g 1 n d')

            q = apply_rotary_pos_emb(q_rotary_emb, q)
            k = apply_rotary_pos_emb(k_rotary_emb, k)

        # scale queries if queries_scale was passed in
        if exists(queries_scale):
            q = q * queries_scale

        # scale keys if keys_scale was passed in
        if exists(keys_scale):
            k = k * keys_scale

        # scale values if values_scale was passed in
        if exists(values_scale):
            v = v * values_scale

        # fold groups into the batch
        q, k, v = map(lambda t: rearrange(t, 'b g ... -> (b g) ...'), (q, k, v))

        # concat the null key / value, so no row can end up fully masked out, saving a lot of headaches
        nk, nv = map(lambda t: repeat(t, 'g h 1 d -> (b g) h 1 d', b = b), self.null_kv)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # attention
        out = self.attend(q, k, v, mask = mask)

        # merge the head outputs
        out = rearrange(out, '(b g) h n d -> b (g h d) n', g = g)
        out = self.to_out(out)
        out = rearrange(out, 'b (g d) n -> b g n d', g = g)

        # restore 3 dims if there was only a single expert
        if one_expert:
            out = rearrange(out, 'b 1 n d -> b n d')

        # scale the output if output_scale was passed in
        if exists(output_scale):
            out = out * output_scale

        return out
# mixture-of-attention class
class MixtureOfAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_routed_queries,
        num_routed_key_values,
        dim_context = None,
        local_attn = False,
        local_attn_window_size = None,
        num_experts = 2,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        use_triton = True,
        flash_attn = True,
        prenorm = True,
        average_routed = False,
        **kwargs
    ):
        super().__init__()
        dim_context = default(dim_context, dim)

        self.num_routed_queries = num_routed_queries
        self.num_routed_key_values = num_routed_key_values

        # if not using local attention, a learned null token stands in for unrouted positions
        self.null_routed_token = nn.Parameter(torch.randn(1, 1, dim)) if not local_attn else None

        self.average_routed = average_routed

        self.local_attn = None

        # optional local multi-head attention branch
        if local_attn:
            assert exists(local_attn_window_size)
            self.local_attn = LocalMHA(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                prenorm = prenorm,
                window_size = local_attn_window_size
            )

        # router for the queries
        self.query_router = CoordinateDescentRouter(
            dim,
            num_routing_tokens = num_experts,
            use_triton = use_triton,
            **kwargs
        )

        # router for the keys / values
        self.key_value_router = CoordinateDescentRouter(
            dim_context,
            num_routing_tokens = num_experts,
            use_triton = use_triton,
            **kwargs
        )

        # grouped attention over all experts
        self.attn = Attention(
            dim = dim,
            dim_context = dim_context,
            dim_head = dim_head,
            heads = heads,
            groups = num_experts,
            dropout = dropout,
            flash = flash_attn,
            prenorm = prenorm
        )

    # device of the module's parameters
    @property
    def device(self):
        return next(self.parameters()).device
    def forward(
        self,
        x,
        context = None,
        mask = None,
        context_mask = None,
        num_routed_queries = None,
        num_routed_key_values = None,
        rotary_emb = None
    ):
        # number of routed queries / key-values: argument overrides the module default
        num_routed_queries = default(num_routed_queries, self.num_routed_queries)
        num_routed_key_values = default(num_routed_key_values, self.num_routed_key_values)

        is_cross_attn = exists(context)

        assert not (exists(self.local_attn) and is_cross_attn), 'cannot do cross attention with local attention (only for self attention)'

        if not is_cross_attn:
            # self attention: the context is the input itself
            context = x
            context_mask = mask

        # route a subset of queries per expert
        query_indices, query_scores, queries, query_mask = self.query_router(x, mask = mask, num_tokens = num_routed_queries, keep_one_route_dim = True)
        query_scores = rearrange(query_scores, 'b g n -> b g n 1')

        # route a subset of keys / values per expert
        kv_indices, key_value_scores, key_values, key_value_mask = self.key_value_router(context, mask = context_mask, num_tokens = num_routed_key_values, keep_one_route_dim = True)
        key_value_scores = rearrange(key_value_scores, 'b g n -> b g 1 n 1')

        # rotary embeddings, gathered at the routed positions
        if exists(rotary_emb):
            assert not is_cross_attn, 'rotary embedding should not be used for cross attending'
            q_rotary_emb = rotary_emb[query_indices] if exists(query_indices) else rotary_emb
            k_rotary_emb = rotary_emb[kv_indices] if exists(kv_indices) else rotary_emb
            rotary_emb = (q_rotary_emb, k_rotary_emb)

        # attend, scaling values by the kv router scores and outputs by the query router scores
        attn_out = self.attn(
            queries,
            rotary_emb = rotary_emb,
            context = key_values,
            mask = key_value_mask,
            values_scale = key_value_scores,
            output_scale = query_scores
        )

        local_out = None
        if exists(self.local_attn):
            local_out = self.local_attn(x, mask = mask)

        need_route_queries = exists(query_indices)

        # if the router returned no indices, every token went to every expert; just average the expert outputs
        if not need_route_queries:
            out = attn_out

            if exists(local_out):
                local_out = rearrange(local_out, 'b n d -> b 1 n d')
                out = torch.cat((local_out, out), dim = 1)

            out = reduce(out, 'b e n d -> b n d', 'mean')

            if exists(mask):
                out = out.masked_fill(~mask[..., None], 0.)

            return out

        # otherwise, scatter each expert's routed outputs back to their original positions
        out = torch.zeros_like(x)
        counts = torch.zeros(x.shape[:-1], device = x.device)

        query_indices = rearrange(query_indices, 'b g n -> b (g n)')
        attn_out = rearrange(attn_out, 'b g n d -> b (g n) d')

        expanded_query_indices = repeat(query_indices, 'b n -> b n d', d = x.shape[-1])

        attn_out_summed = out.scatter_add(1, expanded_query_indices, attn_out)

        # count how many experts routed each position, for averaging
        ones = torch.ones(attn_out.shape[:-1], device = self.device)

        if exists(query_mask):
            ones = ones * rearrange(query_mask, 'b g n -> b (g n)')

        counts = counts.scatter_add(1, query_indices, ones)
        counts = rearrange(counts, '... -> ... 1')

        has_unrouted = not exists(local_out)

        if not has_unrouted:
            counts = counts + 1
            attn_out_summed = attn_out_summed + local_out
        else:
            not_routed_mask = counts == 0
            attn_out_summed = attn_out_summed.masked_fill(not_routed_mask, 0.)

        out = attn_out_summed

        # average, if needed
        if self.average_routed:
            out = out / counts.clamp(min = 1e-5)

        # for unrouted positions, use the learned null token rather than just zeros
        if has_unrouted:
            out = torch.where(
                not_routed_mask,
                self.null_routed_token,
                out,
            )

        if exists(mask):
            out = out.masked_fill(~mask[..., None], 0.)

        return out
# autoregressive flavor of mixture-of-attention
class MixtureOfAutoregressiveAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_routed_queries,           # number of routed queries
        num_routed_key_values,        # number of routed key / values
        local_attn_window_size,       # local attention window size
        routed_window_size = None,    # routed window size, defaults to the local attention window size
        num_experts = 2,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        use_triton = False,
        flash_attn = True,
        prenorm = True,
        average_routed = False,
        **kwargs
    ):
        super().__init__()
        self.num_routed_queries = num_routed_queries
        self.num_routed_key_values = num_routed_key_values

        self.num_experts = num_experts
        # learned null tokens, one per expert
        self.null_tokens = nn.Parameter(torch.randn(num_experts, dim))

        routed_window_size = default(routed_window_size, local_attn_window_size)

        self.routed_window_size = routed_window_size
        self.average_routed = average_routed

        # causal local multi-head attention branch
        self.local_attn = LocalMHA(
            dim = dim,
            dim_head = dim_head,
            heads = heads,
            prenorm = prenorm,
            causal = True,
            window_size = local_attn_window_size
        )

        # router for the queries
        self.query_router = CoordinateDescentRouter(
            dim,
            num_routing_tokens = num_experts,
            use_triton = use_triton,
            **kwargs
        )

        # router for the keys / values
        self.key_value_router = CoordinateDescentRouter(
            dim,
            num_routing_tokens = num_experts,
            use_triton = use_triton,
            **kwargs
        )

        # grouped attention over all experts
        self.attn = Attention(
            dim = dim,
            dim_head = dim_head,
            heads = heads,
            groups = num_experts,
            dropout = dropout,
            flash = flash_attn,
            prenorm = prenorm
        )

    # device of the module's parameters
    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        x,
        rotary_emb = None,
        num_routed_queries = None,
        num_routed_key_values = None
    ):
        # ... (forward body elided in this walkthrough)
.\lucidrains\mixture-of-attention\mixture_of_attention\transformer.py
# imports
import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange

# the custom mixture-of-attention module
from mixture_of_attention.mixture_of_attention import MixtureOfAutoregressiveAttention

# the custom rotary embedding module
from mixture_of_attention.rotary_emb import RotaryEmbedding

# helper: check whether a value is not None
def exists(val):
    return val is not None

# classes

# RMS norm over the last (feature) dimension
class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        normed = F.normalize(x, dim = -1)
        return normed * self.scale * self.gamma

# feedforward block
def FeedForward(dim, mult = 4):
    return nn.Sequential(
        RMSNorm(dim),
        nn.Linear(dim, dim * mult),
        nn.GELU(),
        nn.Linear(dim * mult, dim)
    )

# main Transformer class
class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        depth,
        seq_len,
        local_attn_window_size,
        num_routed_queries,
        num_routed_key_values,
        num_experts,
        cosine_sim_routing = True,
        routed_window_size = None,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        use_triton = True,
        routed_rotary_emb = True
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(seq_len, dim)
        self.seq_len = seq_len

        self.rotary_emb = RotaryEmbedding(dim_head) if routed_rotary_emb else None

        self.layers = nn.ModuleList([])

        # stack of mixture-of-attention + feedforward layers
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                MixtureOfAutoregressiveAttention(
                    dim = dim,
                    local_attn_window_size = local_attn_window_size,
                    routed_window_size = routed_window_size,
                    num_routed_queries = num_routed_queries,
                    num_routed_key_values = num_routed_key_values,
                    cosine_sim_routing = cosine_sim_routing,
                    num_experts = num_experts,
                    dim_head = dim_head,
                    heads = heads,
                    use_triton = use_triton
                ),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

        # projection back to logits
        self.to_logits = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    # device of the module's parameters
    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, x):
        x = self.token_emb(x)
        x = x + self.pos_emb(torch.arange(x.shape[-2], device = self.device))

        rotary_emb = None
        if exists(self.rotary_emb):
            rotary_emb = self.rotary_emb(x.shape[1])

        # residual attention and feedforward blocks
        for attn, ff in self.layers:
            x = attn(x, rotary_emb = rotary_emb) + x
            x = ff(x) + x

        return self.to_logits(x)
.\lucidrains\mixture-of-attention\mixture_of_attention\__init__.py
# Import the MixtureOfAttention, MixtureOfAutoregressiveAttention and Attention classes from the mixture_of_attention package
from mixture_of_attention.mixture_of_attention import (
    MixtureOfAttention,
    MixtureOfAutoregressiveAttention,
    Attention
)
Mixture-of-Attention
Some personal experiments around routing tokens to different autoregressive attention, akin to mixture-of-experts
I learned from a researcher friend that this was tried unsuccessfully in Switch Transformers, but I'll give it a go anyway, bringing in some lessons from recent papers like CoLT5.
In my opinion, the CoLT5 paper basically already demonstrates mixture of attention for 2 experts. This just has to be generalized to more than 2 experts, and to the autoregressive case. The local attention branch is just a special case of one expert with fixed routing. Since attention cost is quadratic in the number of tokens, routing only half the tokens would yield a 4x savings. If I can show even ~4 experts being better than 1 attention, that should be a win.
Appreciation
- Stability and 🤗 Huggingface for their generous sponsorships to work on and open source cutting edge artificial intelligence research
- einops for making tensor manipulation fun and easy
Install
$ pip install mixture-of-attention
Usage
import torch
from mixture_of_attention import MixtureOfAttention

mixture_of_attn = MixtureOfAttention(
    dim = 512,
    dim_context = 256,
    num_routed_queries = 16,
    num_routed_key_values = 16,
    num_experts = 2,
    dim_head = 64,
    heads = 8
)

x = torch.randn(1, 1024, 512)
mask = torch.ones((1, 1024)).bool()

context = torch.randn(1, 512, 256)
context_mask = torch.ones((1, 512)).bool()

mixture_of_attn(x, context = context, mask = mask) # (1, 1024, 512)
Autoregressive flavor
import torch
from mixture_of_attention import MixtureOfAutoregressiveAttention

mixture_of_attn = MixtureOfAutoregressiveAttention(
    dim = 512,
    local_attn_window_size = 64, # local attention window size
    routed_window_size = None,   # will be set to the same as local_attn_window_size if None. ideally less than or equal to the local attention window size for a full receptive field
    num_routed_queries = 12,
    num_routed_key_values = 12,
    num_experts = 2,
    dim_head = 64,
    heads = 8
)

x = torch.randn(1, 1023, 512)

out = mixture_of_attn(x) # (1, 1023, 512)
Todo

- allow for local attention to be automatically included, either for grouped attention, or use `LocalMHA` from the `local-attention` repository in parallel, weighted properly
- make it work for autoregressive
- try dynamic routing tokens, using projection of masked mean-pooled queries
- try out arxiv.org/abs/2210.05…
Citations
@inproceedings{Ainslie2023CoLT5FL,
    title  = {CoLT5: Faster Long-Range Transformers with Conditional Computation},
    author = {Joshua Ainslie and Tao Lei and Michiel de Jong and Santiago Onta{\~n}{\'o}n and Siddhartha Brahma and Yury Zemlyanskiy and David Uthus and Mandy Guo and James Lee-Thorp and Yi Tay and Yun-Hsuan Sung and Sumit Sanghai},
    year   = {2023}
}

@inproceedings{dao2022flashattention,
    title     = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author    = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year      = {2022}
}

@article{Wright2015CoordinateDA,
    title   = {Coordinate descent algorithms},
    author  = {Stephen J. Wright},
    journal = {Mathematical Programming},
    year    = {2015},
    volume  = {151},
    pages   = {3-34}
}

@article{Schmitzer2016StabilizedSS,
    title   = {Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems},
    author  = {Bernhard Schmitzer},
    journal = {ArXiv},
    year    = {2016},
    volume  = {abs/1610.06519}
}

@inproceedings{rogozhnikov2022einops,
    title     = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
    author    = {Alex Rogozhnikov},
    booktitle = {International Conference on Learning Representations},
    year      = {2022},
    url       = {https://openreview.net/forum?id=oapKSVM2bcj}
}
.\lucidrains\mixture-of-attention\setup.py
# import the setup and package-discovery helpers
from setuptools import setup, find_packages

# declare the package metadata
setup(
    name = 'mixture-of-attention',                              # package name
    packages = find_packages(exclude=[]),                       # find and include all packages
    version = '0.0.24',                                         # version number
    license='MIT',                                              # license
    description = 'Mixture of Attention',                       # short description
    author = 'Phil Wang',                                       # author
    author_email = 'lucidrains@gmail.com',                      # author email
    long_description_content_type = 'text/markdown',            # long description content type
    url = 'https://github.com/lucidrains/mixture-of-attention', # project url
    keywords = [                                                # keyword list
        'artificial intelligence',
        'deep learning',
        'transformers',
        'attention mechanism',
        'mixture-of-experts',
        'routed attention'
    ],
    install_requires=[                                          # install dependencies
        'colt5-attention>=0.10.14',
        'einops>=0.6.1',
        'local-attention>=1.8.6',
        'torch>=1.6',
    ],
    classifiers=[                                               # classifier list
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\mixture-of-attention\train.py
# imports
import gzip
import random

import tqdm
import numpy as np

import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from mixture_of_attention.transformer import Transformer
from mixture_of_attention.autoregressive_wrapper import AutoregressiveWrapper

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
PRIME_LENGTH = 128
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 512

# helper functions

# decode a single token back to a character
def decode_token(token):
    return str(chr(max(32, token)))

# decode a sequence of tokens back to a string
def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))
# instantiate the Transformer model

model = Transformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    num_experts = 2,
    seq_len = SEQ_LEN,
    local_attn_window_size = 64,
    num_routed_queries = 32,
    num_routed_key_values = 64,
    cosine_sim_routing = True,
    use_triton = True
)

model = AutoregressiveWrapper(model).cuda()

# prepare enwik8 data

with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

# dataset that samples random fixed-length subsequences
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# training and validation dataloaders

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# optimizer

optim = Adam(model.parameters(), lr = LEARNING_RATE)

# training loop

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model.train()

    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    print(f"training loss: {loss.item()}")
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f"validation loss: {loss.item()}")

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        prime = decode_tokens(inp)
        print(f"{prime} \n\n {'*' * 100}")

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str, "\n")
.\lucidrains\mixture-of-experts\mixture_of_experts\mixture_of_experts.py
import torch
from torch import nn
import torch.nn.functional as F

import math
# import the isfunction helper from the inspect library (used by `default` below)
from inspect import isfunction

# constants

MIN_EXPERT_CAPACITY = 4

# helper functions

# return val if it is not None, otherwise the default (calling it if it is a function)
def default(val, default_val):
    default_val = default_val() if isfunction(default_val) else default_val
    return val if val is not None else default_val

# cast an element to a tuple
def cast_tuple(el):
    return el if isinstance(el, tuple) else (el,)

# tensor related helper functions

# top-1 values and indices along the last dimension
def top1(t):
    values, index = t.topk(k=1, dim=-1)
    values, index = map(lambda x: x.squeeze(dim=-1), (values, index))
    return values, index

# exclusive cumulative sum along a dimension: each position excludes its own value
def cumsum_exclusive(t, dim=-1):
    num_dims = len(t.shape)
    num_pad_dims = - dim - 1
    pre_padding = (0, 0) * num_pad_dims
    pre_slice = (slice(None),) * num_pad_dims
    padded_t = F.pad(t, (*pre_padding, 1, 0)).cumsum(dim=dim)
    return padded_t[(..., slice(None, -1), *pre_slice)]
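# illustrative example (an assumption of this walkthrough, not part of the original file):
#   cumsum_exclusive(torch.tensor([1., 2., 3.]))  # -> tensor([0., 1., 3.])
#   each position holds the sum of everything strictly before it, which is what the
#   gating below needs to assign each token a slot within its expert's capacity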
# safe one-hot that tolerates indices at or beyond max_length
def safe_one_hot(indexes, max_length):
    max_index = indexes.max() + 1
    return F.one_hot(indexes, max(max_index + 1, max_length))[..., :max_length]

# initialize a tensor with a uniform distribution scaled by its last dimension
def init_(t):
    dim = t.shape[-1]
    std = 1 / math.sqrt(dim)
    return t.uniform_(-std, std)

# activations

# tanh-approximation GELU, for older torch versions without nn.GELU
class GELU_(nn.Module):
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

# use nn.GELU if available, otherwise fall back to the custom GELU_
GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_

# expert class: a batch of two-layer feedforward networks, one per expert
class Experts(nn.Module):
    def __init__(self,
        dim,
        num_experts = 16,
        hidden_dim = None,
        activation = GELU):
        super().__init__()
        hidden_dim = default(hidden_dim, dim * 4)
        num_experts = cast_tuple(num_experts)

        w1 = torch.zeros(*num_experts, dim, hidden_dim)
        w2 = torch.zeros(*num_experts, hidden_dim, dim)

        w1 = init_(w1)
        w2 = init_(w2)

        self.w1 = nn.Parameter(w1)
        self.w2 = nn.Parameter(w2)
        self.act = activation()

    def forward(self, x):
        hidden = torch.einsum('...nd,...dh->...nh', x, self.w1)
        hidden = self.act(hidden)
        out = torch.einsum('...nh,...hd->...nd', hidden, self.w2)
        return out

# the code below is transcribed almost exactly from the official tensorflow version,
# on which the relevant paper was based
# https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/research/moe.py

# gating network
class Top2Gating(nn.Module):
    def __init__(
        self,
        dim,
        num_gates,
        eps = 1e-9,
        outer_expert_dims = tuple(),
        second_policy_train = 'random',
        second_policy_eval = 'random',
        second_threshold_train = 0.2,
        second_threshold_eval = 0.2,
        capacity_factor_train = 1.25,
        capacity_factor_eval = 2.):
        super().__init__()

        self.eps = eps
        self.num_gates = num_gates
        self.w_gating = nn.Parameter(torch.randn(*outer_expert_dims, dim, num_gates))

        self.second_policy_train = second_policy_train
        self.second_policy_eval = second_policy_eval
        self.second_threshold_train = second_threshold_train
        self.second_threshold_eval = second_threshold_eval
        self.capacity_factor_train = capacity_factor_train
        self.capacity_factor_eval = capacity_factor_eval

    # (Top2Gating's forward, the top-2 dispatch/combine logic, is elided in this walkthrough)
# plain mixture-of-experts
class MoE(nn.Module):
    def __init__(self,
        dim,
        num_experts = 16,
        hidden_dim = None,
        activation = nn.ReLU,
        second_policy_train = 'random',
        second_policy_eval = 'random',
        second_threshold_train = 0.2,
        second_threshold_eval = 0.2,
        capacity_factor_train = 1.25,
        capacity_factor_eval = 2.,
        loss_coef = 1e-2,
        experts = None):
        super().__init__()

        self.num_experts = num_experts

        # gating parameters
        gating_kwargs = {'second_policy_train': second_policy_train, 'second_policy_eval': second_policy_eval, 'second_threshold_train': second_threshold_train, 'second_threshold_eval': second_threshold_eval, 'capacity_factor_train': capacity_factor_train, 'capacity_factor_eval': capacity_factor_eval}
        # gating module
        self.gate = Top2Gating(dim, num_gates = num_experts, **gating_kwargs)
        # expert module
        self.experts = default(experts, lambda: Experts(dim, num_experts = num_experts, hidden_dim = hidden_dim, activation = activation))
        # auxiliary loss coefficient
        self.loss_coef = loss_coef

    def forward(self, inputs, **kwargs):
        b, n, d, e = *inputs.shape, self.num_experts
        # gating produces dispatch / combine tensors plus an auxiliary balancing loss
        dispatch_tensor, combine_tensor, loss = self.gate(inputs)

        # dispatch tokens to their experts
        expert_inputs = torch.einsum('bnd,bnec->ebcd', inputs, dispatch_tensor)

        # feed the dispatched tokens through the experts
        orig_shape = expert_inputs.shape
        expert_inputs = expert_inputs.reshape(e, -1, d)
        expert_outputs = self.experts(expert_inputs)
        expert_outputs = expert_outputs.reshape(*orig_shape)

        # combine the expert outputs back per token
        output = torch.einsum('ebcd,bnec->bnd', expert_outputs, combine_tensor)

        # return the output and the auxiliary loss scaled by its coefficient
        return output, loss * self.loss_coef
# two-level hierarchical mixture-of-experts
# (the class name keeps the repository's original spelling)
class HeirarchicalMoE(nn.Module):
    def __init__(self,
        dim,
        num_experts = (4, 4),            # number of (outer, inner) experts
        hidden_dim = None,
        activation = nn.ReLU,
        second_policy_train = 'random',
        second_policy_eval = 'random',
        second_threshold_train = 0.2,
        second_threshold_eval = 0.2,
        capacity_factor_train = 1.25,
        capacity_factor_eval = 2.,
        loss_coef = 1e-2,
        experts = None):
        super().__init__()

        assert len(num_experts) == 2, 'only 2 levels of heirarchy for experts allowed for now'
        num_experts_outer, num_experts_inner = num_experts
        self.num_experts_outer = num_experts_outer
        self.num_experts_inner = num_experts_inner

        gating_kwargs = {'second_policy_train': second_policy_train, 'second_policy_eval': second_policy_eval, 'second_threshold_train': second_threshold_train, 'second_threshold_eval': second_threshold_eval, 'capacity_factor_train': capacity_factor_train, 'capacity_factor_eval': capacity_factor_eval}

        # outer and inner gating modules
        self.gate_outer = Top2Gating(dim, num_gates = num_experts_outer, **gating_kwargs)
        self.gate_inner = Top2Gating(dim, num_gates = num_experts_inner, outer_expert_dims = (num_experts_outer,), **gating_kwargs)

        # experts, indexed by (outer, inner)
        self.experts = default(experts, lambda: Experts(dim, num_experts = num_experts, hidden_dim = hidden_dim, activation = activation))
        self.loss_coef = loss_coef

    def forward(self, inputs, **kwargs):
        b, n, d, eo, ei = *inputs.shape, self.num_experts_outer, self.num_experts_inner
        dispatch_tensor_outer, combine_tensor_outer, loss_outer = self.gate_outer(inputs)
        expert_inputs_outer = torch.einsum('bnd,bnec->ebcd', inputs, dispatch_tensor_outer)

        # construct an "importance" tensor for the second level of gating
        importance = combine_tensor_outer.permute(2, 0, 3, 1).sum(dim=-1)
        importance = 0.5 * ((importance > 0.5).float() + (importance > 0.).float())

        dispatch_tensor_inner, combine_tensor_inner, loss_inner = self.gate_inner(expert_inputs_outer, importance = importance)
        expert_inputs = torch.einsum('ebnd,ebnfc->efbcd', expert_inputs_outer, dispatch_tensor_inner)

        # feed the dispatched tokens through the experts
        orig_shape = expert_inputs.shape
        expert_inputs = expert_inputs.reshape(eo, ei, -1, d)
        expert_outputs = self.experts(expert_inputs)
        expert_outputs = expert_outputs.reshape(*orig_shape)

        # combine expert outputs, inner level first, then outer level
        expert_outputs_outer = torch.einsum('efbcd,ebnfc->ebnd', expert_outputs, combine_tensor_inner)
        output = torch.einsum('ebcd,bnec->bnd', expert_outputs_outer, combine_tensor_outer)
        return output, (loss_outer + loss_inner) * self.loss_coef
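To make the dispatch and combine flow above concrete, here is a minimal usage sketch of the MoE layer, mirroring the mixture-of-experts repository's README; the second return value is the auxiliary balancing loss (already scaled by loss_coef) to be added to the main training loss:

import torch
from mixture_of_experts import MoE

moe = MoE(
    dim = 512,
    num_experts = 16
)

inputs = torch.randn(4, 1024, 512)
out, aux_loss = moe(inputs)  # (4, 1024, 512), scalar auxiliary load-balancing loss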