Lucidrains Series Source Code Walkthrough (Part 92)
.\lucidrains\siren-pytorch\siren_pytorch\siren_pytorch.py
# import math and PyTorch
import math
import torch
# import the neural network module from torch
from torch import nn
# import the functional API as F
import torch.nn.functional as F
# import the rearrange function from einops
from einops import rearrange

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# cast a value to a tuple, repeating it if it is not already one
def cast_tuple(val, repeat = 1):
    return val if isinstance(val, tuple) else ((val,) * repeat)

# sine activation, scaled by w0
class Sine(nn.Module):
    def __init__(self, w0 = 1.):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        return torch.sin(self.w0 * x)
# a single Siren layer
class Siren(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        w0 = 1.,
        c = 6.,
        is_first = False,
        use_bias = True,
        activation = None,
        dropout = 0.
    ):
        super().__init__()
        self.dim_in = dim_in
        self.is_first = is_first

        weight = torch.zeros(dim_out, dim_in)
        bias = torch.zeros(dim_out) if use_bias else None
        self.init_(weight, bias, c = c, w0 = w0)

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias) if use_bias else None
        self.activation = Sine(w0) if activation is None else activation
        self.dropout = nn.Dropout(dropout)

    # SIREN initialization: uniform in (-1/dim, 1/dim) for the first layer,
    # otherwise (-sqrt(c / dim) / w0, sqrt(c / dim) / w0)
    def init_(self, weight, bias, c, w0):
        dim = self.dim_in

        w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
        weight.uniform_(-w_std, w_std)

        if exists(bias):
            bias.uniform_(-w_std, w_std)

    def forward(self, x):
        out = F.linear(x, self.weight, self.bias)
        out = self.activation(out)
        out = self.dropout(out)
        return out
# the full Siren network
class SirenNet(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_hidden,
        dim_out,
        num_layers,
        w0 = 1.,
        w0_initial = 30.,
        use_bias = True,
        final_activation = None,
        dropout = 0.
    ):
        super().__init__()
        self.num_layers = num_layers
        self.dim_hidden = dim_hidden

        self.layers = nn.ModuleList([])
        for ind in range(num_layers):
            is_first = ind == 0
            # the first layer uses w0_initial (30. by default), all others use w0
            layer_w0 = w0_initial if is_first else w0
            layer_dim_in = dim_in if is_first else dim_hidden

            layer = Siren(
                dim_in = layer_dim_in,
                dim_out = dim_hidden,
                w0 = layer_w0,
                use_bias = use_bias,
                is_first = is_first,
                dropout = dropout
            )

            self.layers.append(layer)

        final_activation = nn.Identity() if not exists(final_activation) else final_activation
        self.last_layer = Siren(dim_in = dim_hidden, dim_out = dim_out, w0 = w0, use_bias = use_bias, activation = final_activation)

    def forward(self, x, mods = None):
        mods = cast_tuple(mods, self.num_layers)

        # apply each layer, optionally scaling the hidden features by a modulation vector
        for layer, mod in zip(self.layers, mods):
            x = layer(x)

            if exists(mod):
                x *= rearrange(mod, 'd -> () d')

        return self.last_layer(x)
# modulation feedforward - maps a latent to per-layer modulation vectors
# (the latent is expected to be unbatched, per the 'd -> () d' rearrange above)
class Modulator(nn.Module):
    def __init__(self, dim_in, dim_hidden, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([])

        for ind in range(num_layers):
            is_first = ind == 0
            # every layer past the first also receives the original latent, concatenated
            dim = dim_in if is_first else (dim_hidden + dim_in)

            self.layers.append(nn.Sequential(
                nn.Linear(dim, dim_hidden),
                nn.ReLU()
            ))

    def forward(self, z):
        x = z
        hiddens = []

        for layer in self.layers:
            x = layer(x)
            hiddens.append(x)
            x = torch.cat((x, z))

        return tuple(hiddens)
# wrapper that fits a Siren network to an image
class SirenWrapper(nn.Module):
    # takes a SirenNet, the image width / height, and an optional latent dimension
    def __init__(self, net, image_width, image_height, latent_dim = None):
        super().__init__()

        assert isinstance(net, SirenNet), 'SirenWrapper must receive a Siren network'

        self.net = net
        self.image_width = image_width
        self.image_height = image_height

        # if a latent dimension is given, create a Modulator for latent-conditioned output
        self.modulator = None
        if exists(latent_dim):
            self.modulator = Modulator(
                dim_in = latent_dim,
                dim_hidden = net.dim_hidden,
                num_layers = net.num_layers
            )

        # build the ((h w), 2) grid of normalized pixel coordinates in [-1, 1]
        tensors = [torch.linspace(-1, 1, steps = image_height), torch.linspace(-1, 1, steps = image_width)]
        mgrid = torch.stack(torch.meshgrid(*tensors, indexing = 'ij'), dim=-1)
        mgrid = rearrange(mgrid, 'h w c -> (h w) c')

        # register the coordinate grid as a buffer
        self.register_buffer('grid', mgrid)

    # forward pass takes an optional target image and/or a latent vector
    def forward(self, img = None, *, latent = None):
        modulate = exists(self.modulator)
        assert not (modulate ^ exists(latent)), 'latent vector must be only supplied if `latent_dim` was passed in on instantiation'

        # compute the per-layer modulations if a latent was given
        mods = self.modulator(latent) if modulate else None

        # clone the coordinate grid and enable gradients on it
        coords = self.grid.clone().detach().requires_grad_()

        # run the coordinates through the network and reshape back into an image
        out = self.net(coords, mods)
        out = rearrange(out, '(h w) c -> () c h w', h = self.image_height, w = self.image_width)

        # if a target image is provided, return the mean squared error loss
        if exists(img):
            return F.mse_loss(img, out)

        return out
.\lucidrains\siren-pytorch\siren_pytorch\__init__.py
# export the Sine, Siren, SirenNet and SirenWrapper classes
from siren_pytorch.siren_pytorch import Sine, Siren, SirenNet, SirenWrapper
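To make the interfaces above concrete, here is a minimal usage sketch assembled from the classes in this file (the shapes and hyperparameters are illustrative, not taken from the repository's README):

import torch
from siren_pytorch import SirenNet, SirenWrapper

net = SirenNet(
    dim_in = 2,           # input coordinates (x, y)
    dim_hidden = 256,
    dim_out = 3,          # rgb value per coordinate
    num_layers = 5,
    w0_initial = 30.      # higher first-layer frequency, per the SIREN paper
)

wrapper = SirenWrapper(net, image_width = 256, image_height = 256)

img = torch.randn(1, 3, 256, 256)

loss = wrapper(img)       # mse loss against the target image
loss.backward()

pred_img = wrapper()      # (1, 3, 256, 256) - image decoded from the fitted network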

Slot Attention
Implementation of Slot Attention from the paper 'Object-Centric Learning with Slot Attention' in Pytorch. Here is a video that describes what this network can do.
Update: The official repository has been released here
Install
$ pip install slot_attention
Usage
import torch
from slot_attention import SlotAttention
slot_attn = SlotAttention(
    num_slots = 5,
    dim = 512,
    iters = 3   # iterations of attention, defaults to 3
)

inputs = torch.randn(2, 1024, 512)
slot_attn(inputs) # (2, 5, 512)

After training, the network is reported to be able to generalize to a slightly different number of slots (clusters). You can override the number of slots used with the num_slots keyword in forward.

slot_attn(inputs, num_slots = 8) # (2, 8, 512)
Citation
@misc{locatello2020objectcentric,
    title   = {Object-Centric Learning with Slot Attention},
    author  = {Francesco Locatello and Dirk Weissenborn and Thomas Unterthiner and Aravindh Mahendran and Georg Heigold and Jakob Uszkoreit and Alexey Dosovitskiy and Thomas Kipf},
    year    = {2020},
    eprint  = {2006.15055},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
.\lucidrains\slot-attention\setup.py
# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
    name = 'slot_attention',                                      # package name
    packages = find_packages(),                                   # find all packages
    version = '1.1.2',                                            # version
    license='MIT',                                                # license
    description = 'Implementation of Slot Attention in Pytorch',  # description
    long_description_content_type = 'text/markdown',              # long description content type
    author = 'Phil Wang',                                         # author
    author_email = 'lucidrains@gmail.com',                        # author email
    url = 'https://github.com/lucidrains/slot-attention',         # project url
    keywords = ['attention', 'artificial intelligence'],          # keywords
    install_requires=[
        'torch'                                                   # dependencies
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\slot-attention\slot_attention\slot_attention.py
import torch
from torch import nn
from torch.nn import init
class SlotAttention(nn.Module):
    # num_slots: number of slots, dim: feature dimension, iters: attention iterations,
    # eps: small value for numerical stability, hidden_dim: mlp hidden dimension
    def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128):
        super().__init__()
        self.num_slots = num_slots
        self.iters = iters
        self.eps = eps
        self.scale = dim ** -0.5  # attention scaling factor

        # learned mean and log standard deviation used to sample the initial slots
        self.slots_mu = nn.Parameter(torch.randn(1, 1, dim))
        self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, dim))
        init.xavier_uniform_(self.slots_logsigma)

        # linear projections to queries (from slots), keys and values (from inputs)
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)

        # GRU cell used to update the slot states
        self.gru = nn.GRUCell(dim, dim)

        hidden_dim = max(dim, hidden_dim)

        # mlp applied to the slots after each GRU update
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(inplace = True),
            nn.Linear(hidden_dim, dim)
        )

        # layer norms for the inputs, the slots, and before the feedforward
        self.norm_input = nn.LayerNorm(dim)
        self.norm_slots = nn.LayerNorm(dim)
        self.norm_pre_ff = nn.LayerNorm(dim)

    # forward pass takes the inputs and an optional override for the number of slots
    def forward(self, inputs, num_slots = None):
        b, n, d, device, dtype = *inputs.shape, inputs.device, inputs.dtype
        n_s = num_slots if num_slots is not None else self.num_slots

        # sample the initial slots from the learned gaussian
        mu = self.slots_mu.expand(b, n_s, -1)
        sigma = self.slots_logsigma.exp().expand(b, n_s, -1)
        slots = mu + sigma * torch.randn(mu.shape, device = device, dtype = dtype)

        inputs = self.norm_input(inputs)
        k, v = self.to_k(inputs), self.to_v(inputs)

        # iteratively refine the slots
        for _ in range(self.iters):
            slots_prev = slots

            slots = self.norm_slots(slots)
            q = self.to_q(slots)

            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
            # softmax over the slot dimension, so slots compete for input tokens
            attn = dots.softmax(dim=1) + self.eps
            # then normalize over the inputs to form a weighted mean
            attn = attn / attn.sum(dim=-1, keepdim=True)

            updates = torch.einsum('bjd,bij->bid', v, attn)

            # update the slot states with the GRU
            slots = self.gru(
                updates.reshape(-1, d),
                slots_prev.reshape(-1, d)
            )

            slots = slots.reshape(b, -1, d)
            slots = slots + self.mlp(self.norm_pre_ff(slots))

        # return the refined slots
        return slots
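Note the softmax over dim=1 (the slot dimension): slots compete with one another for input tokens, which is the inverted attention at the heart of the paper. A minimal sketch of the common pattern of applying this to flattened CNN feature maps, as in the paper (shapes illustrative):

import torch
from einops import rearrange
from slot_attention import SlotAttention

slot_attn = SlotAttention(num_slots = 7, dim = 64)

feats = torch.randn(4, 64, 32, 32)                  # (batch, channels, height, width) from a CNN encoder
tokens = rearrange(feats, 'b d h w -> b (h w) d')   # flatten spatial positions into a set of tokens

slots = slot_attn(tokens)                           # (4, 7, 64) - one vector per object slot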
.\lucidrains\slot-attention\slot_attention\slot_attention_experimental.py
import torch
from torch import nn
from torch.nn import init
class WeightedAttention(nn.Module):
    def __init__(self, dim, eps = 1e-8, softmax_dim = 1, weighted_mean_dim = 2):
        super().__init__()
        self.norm_input = nn.LayerNorm(dim)     # normalize the inputs
        self.norm_context = nn.LayerNorm(dim)   # normalize the context

        self.to_q = nn.Linear(dim, dim)         # project inputs to queries
        self.to_k = nn.Linear(dim, dim)         # project context to keys
        self.to_v = nn.Linear(dim, dim)         # project context to values

        self.eps = eps                          # small value to stabilize the softmax
        self.scale = dim ** -0.5                # attention scaling factor
        self.softmax_dim = softmax_dim          # dimension over which to softmax
        self.weighted_mean_dim = weighted_mean_dim  # dimension over which to take the weighted mean

    def forward(self, inputs, context):
        inputs = self.norm_input(inputs)
        context = self.norm_context(context)

        q = self.to_q(inputs)
        k = self.to_k(context)
        v = self.to_v(context)

        dots = torch.einsum('bid,bjd->bij', q, k) * self.scale  # dot products
        attn = dots.softmax(dim = self.softmax_dim) + self.eps  # attention weights
        attn = attn / attn.sum(dim = self.weighted_mean_dim, keepdim=True)  # weighted mean

        updates = torch.einsum('bjd,bij->bid', v, attn)  # aggregate the updates
        return updates

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return x + self.fn(x)

class GatedResidual(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.gru = nn.GRUCell(dim, dim)  # GRU cell acting as a learned gate
        self.fn = fn

    def forward(self, *args):
        inputs = args[0]
        b, _, d = inputs.shape
        updates = self.fn(*args)

        inputs = self.gru(
            updates.reshape(-1, d),
            inputs.reshape(-1, d)
        )

        return inputs.reshape(b, -1, d)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        hidden_dim = max(dim, hidden_dim)

        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(inplace = True),
            nn.Linear(hidden_dim, dim)
        )
        self.norm = nn.LayerNorm(dim)  # pre-normalization

    def forward(self, x):
        x = self.norm(x)
        return self.net(x)

class SlotAttentionExperimental(nn.Module):
    def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128):
        super().__init__()
        scale = dim ** -0.5
        self.num_slots = num_slots
        self.iters = iters

        self.norm_inputs = nn.LayerNorm(dim)

        # learned mean and log standard deviation for sampling the initial slots
        self.slots_mu = nn.Parameter(torch.randn(1, 1, dim))
        self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, dim))
        init.xavier_uniform_(self.slots_logsigma)

        self.slots_to_inputs_attn = GatedResidual(dim, WeightedAttention(dim, eps = eps))  # slots attend to inputs
        self.slots_ff = GatedResidual(dim, FeedForward(dim, hidden_dim))                   # slot feedforward

        self.inputs_to_slots_attn = GatedResidual(dim, WeightedAttention(dim, eps = eps, softmax_dim = 2, weighted_mean_dim = 1))  # inputs attend to slots
        self.inputs_ff = GatedResidual(dim, FeedForward(dim, hidden_dim))                  # input feedforward

    def forward(self, inputs, num_slots = None):
        b, n, d, device, dtype = *inputs.shape, inputs.device, inputs.dtype
        n_s = num_slots if num_slots is not None else self.num_slots

        mu = self.slots_mu.expand(b, n_s, -1)                    # expand the slot means
        sigma = self.slots_logsigma.exp().expand(b, n_s, -1)     # expand the slot standard deviations
        slots = mu + sigma * torch.randn(mu.shape, device = device, dtype = dtype)  # sample the slots

        inputs = self.norm_inputs(inputs)

        for _ in range(self.iters):
            slots = self.slots_to_inputs_attn(slots, inputs)  # slots attend to inputs
            slots = self.slots_ff(slots)                      # slot feedforward

            inputs = self.inputs_to_slots_attn(inputs, slots) # inputs attend to slots
            inputs = self.inputs_ff(inputs)                   # input feedforward

        return slots, inputs  # return both the slots and the updated inputs
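Unlike SlotAttention, the experimental variant also updates the inputs each iteration and returns them alongside the slots. A minimal interface sketch (shapes illustrative):

import torch
from slot_attention import SlotAttentionExperimental

slot_attn = SlotAttentionExperimental(num_slots = 5, dim = 512)

inputs = torch.randn(2, 1024, 512)
slots, updated_inputs = slot_attn(inputs)  # (2, 5, 512), (2, 1024, 512)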
.\lucidrains\slot-attention\slot_attention\__init__.py
# import the SlotAttention class from the slot_attention module
from slot_attention.slot_attention import SlotAttention
# import the SlotAttentionExperimental class from the slot_attention_experimental module
from slot_attention.slot_attention_experimental import SlotAttentionExperimental
.\lucidrains\soft-moe-pytorch\assert.py
# imports
import os
from copy import deepcopy

import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from soft_moe_pytorch.soft_moe import Experts, FeedForward as Expert
from soft_moe_pytorch.distributed import all_gather_variable_dim

# initialize the distributed process group
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

# destroy the process group
def cleanup():
    dist.destroy_process_group()

# entry point for each distributed worker
def start(
    rank,
    world_size,
    batch_size,
    batch_size_var_len,
    num_experts,
    tokens_per_expert,
    dim,
):
    setup(rank, world_size)

    # create the expert networks
    net = Experts([Expert(dim) for _ in range(num_experts)])

    # optionally vary the batch size per rank
    if batch_size_var_len:
        batch_size = batch_size + rank

    # random input sequence
    seq = torch.randn(batch_size, num_experts, tokens_per_expert, dim)

    # distributed path: wrap the model in DistributedDataParallel
    model = DDP(net)
    out = model(seq)
    out.mean().backward()

    # gather the outputs from all ranks
    ddp_all_out, _ = all_gather_variable_dim(out)

    # single-device path: gather all inputs and run a copied network locally
    all_inputs, _ = all_gather_variable_dim(seq)
    copied_net = deepcopy(net)

    single_out = copied_net(
        all_inputs,
        is_distributed = False
    )

    single_out.mean().backward()

    if rank == 0:
        # verify that the outputs are the same whether run on a single machine or many
        assert torch.allclose(single_out, ddp_all_out), 'output is not the same'

        # verify that the gradients match as well
        get_first_expert_grad = lambda t: t.experts[0][0].weight.grad

        assert torch.allclose(
            get_first_expert_grad(net),
            get_first_expert_grad(copied_net),
            atol = 1e-2
        ), 'grad is not the same'

        print('✅')

    cleanup()

if __name__ == '__main__':
    # test parameters
    world_size = 9
    num_experts = 8
    batch_size = 2
    batch_size_var_len = False
    seq_len = 32
    dim = 8

    # spawn one process per rank
    mp.spawn(
        start,
        args = (
            world_size,
            batch_size,
            batch_size_var_len,
            num_experts,
            seq_len,
            dim
        ),
        nprocs = world_size,
        join = True
    )


Soft MoE - Pytorch
Implementation of Soft MoE (Mixture of Experts), proposed by Brain's Vision team, in Pytorch.
This MoE has only been made to work with a non-autoregressive encoder. However, some recent text-to-image models have started using MoE with great results, so it may be a fit there.
If anyone has any ideas for how to make it work for autoregressive decoding, let me know (through email or discussions). I meditated on it but can't think of a good way. The other issue with the slot scheme is that the routing incurs quadratic cost as the sequence length increases (much like attention).
Appreciation
- StabilityAI for the generous sponsorship, as well as my other sponsors out there
- Einops for making my life easy
Install
$ pip install soft-moe-pytorch
Usage
import torch
from soft_moe_pytorch import SoftMoE
moe = SoftMoE(
    dim = 512,         # model dimensions
    seq_len = 1024,    # max sequence length (will automatically calculate number of slots as seq_len // num_experts) - you can also set num_slots directly
    num_experts = 4    # number of experts - (they suggest number of experts should be high enough that each of them get only 1 slot. wonder if that is the weakness of the paper?)
)

x = torch.randn(1, 1024, 512)
out = moe(x) + x # (1, 1024, 512) - add in a transformer in place of a feedforward at a certain layer (here showing the residual too)
For an improvised variant that uses dynamic slots, so that the number of slots is approximately equal to the sequence length, just import DynamicSlotsSoftMoE instead
import torch
from soft_moe_pytorch import DynamicSlotsSoftMoE
# sequence length or number of slots need not be specified
moe = DynamicSlotsSoftMoE(
    dim = 512,        # model dimensions
    num_experts = 4,  # number of experts
    geglu = True
)
x = torch.randn(1, 1023, 512)
out = moe(x) + x # (1, 1023, 512)
Todo
- address the limitation of the number of slots being fixed. think about a way to make the number of slots dynamic based on sequence length
- once variable sequence length is handled in distributed, add to dynamic soft moe
- the dispatch and combine tensors can also be split and moved into the Experts class to better distribute work
Citations
@misc{puigcerver2023sparse,
    title   = {From Sparse to Soft Mixtures of Experts},
    author  = {Joan Puigcerver and Carlos Riquelme and Basil Mustafa and Neil Houlsby},
    year    = {2023},
    eprint  = {2308.00951},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
.\lucidrains\soft-moe-pytorch\setup.py
# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
    name = 'soft-moe-pytorch',                        # package name
    packages = find_packages(exclude=[]),             # find all packages
    version = '0.1.7',                                # version
    license='MIT',                                    # license
    description = 'Soft MoE - Pytorch',               # description
    author = 'Phil Wang',                             # author
    author_email = 'lucidrains@gmail.com',            # author email
    long_description_content_type = 'text/markdown',  # long description content type
    url = 'https://github.com/lucidrains/soft-moe-pytorch',  # project url
    keywords = [
        'artificial intelligence',
        'deep learning',
        'mixture of experts'
    ],
    install_requires=[
        'einops>=0.6.1',
        'torch>=2.0'
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\soft-moe-pytorch\soft_moe_pytorch\distributed.py
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Function

import torch.distributed as dist

from einops import rearrange, pack, unpack

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def divisible_by(num, den):
    return (num % den) == 0

# pad a tensor along a given dimension up to a target length
# (note: the original transcription was missing the closing parenthesis on the return)
def pad_dim_to(t, length, dim = 0):
    pad_length = length - t.shape[dim]
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))

# all-gather for tensors that share the same shape on every rank
def all_gather_same_dim(t):
    world_size = dist.get_world_size()
    t = t.contiguous()
    gathered_tensors = [torch.empty_like(t, device = t.device, dtype = t.dtype) for i in range(world_size)]
    dist.all_gather(gathered_tensors, t)
    return gathered_tensors

# gather the sizes of a given dimension across all ranks
def gather_sizes(t, *, dim):
    size = torch.tensor(t.shape[dim], device = t.device, dtype = torch.long)
    sizes = all_gather_same_dim(size)
    return torch.stack(sizes)

# check whether all values in a tensor are equal
def has_only_one_value(t):
    return (t == t[0]).all()

# all-gather for tensors whose given dimension may differ across ranks
def all_gather_variable_dim(t, dim = 0, sizes = None):
    device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()

    if not exists(sizes):
        sizes = gather_sizes(t, dim = dim)

    # fast path: all ranks have the same size along the dimension
    if has_only_one_value(sizes):
        gathered_tensors = all_gather_same_dim(t)
        gathered_tensors = torch.cat(gathered_tensors, dim = dim)
        return gathered_tensors, sizes

    # otherwise pad each rank's tensor to the maximum size, gather, then strip the padding
    max_size = sizes.amax().item()

    padded_t = pad_dim_to(t, max_size, dim = dim)
    gathered_tensors = all_gather_same_dim(padded_t)

    gathered_tensors = torch.cat(gathered_tensors, dim = dim)
    seq = torch.arange(max_size, device = device)

    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
    mask = rearrange(mask, 'i j -> (i j)')
    seq = torch.arange(mask.shape[-1], device = device)
    indices = seq[mask]

    gathered_tensors = gathered_tensors.index_select(dim, indices)

    return gathered_tensors, sizes

# autograd function wrapping the variable-dim all-gather
class AllGatherFunction(Function):
    @staticmethod
    def forward(ctx, x, dim, sizes):
        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
        ctx.batch_sizes = batch_sizes.tolist()
        ctx.dim = dim
        return x, batch_sizes

    @staticmethod
    def backward(ctx, grads, _):
        # route each rank's slice of the gradient back to that rank
        batch_sizes, rank = ctx.batch_sizes, dist.get_rank()
        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
        return grads_by_rank[rank], None, None

# module wrapper around AllGatherFunction
class AllGather(nn.Module):
    def __init__(self, *, dim = 0):
        super().__init__()
        self.dim = dim

    def forward(self, x, sizes = None):
        return AllGatherFunction.apply(x, self.dim, sizes)

# select this rank's element from a gathered sequence
def split_by_rank(x):
    rank = dist.get_rank()
    out = x[rank]
    return out
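The gather functions above require an initialized process group, but the padding helper can be checked standalone. A small single-process sketch re-stating pad_dim_to (values worked out by hand under the definition as written):

import torch
import torch.nn.functional as F

# same logic as pad_dim_to above
def pad_dim_to(t, length, dim = 0):
    pad_length = length - t.shape[dim]
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))

t = torch.ones(2, 3, 4)
print(pad_dim_to(t, 5, dim = 1).shape)  # torch.Size([2, 5, 4]) - zero padded at the end of dim 1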
.\lucidrains\soft-moe-pytorch\soft_moe_pytorch\soft_moe.py
import torch
from torch.nn import Module
import torch.nn.functional as F
import torch.distributed as dist
from torch import nn, einsum, Tensor

from einops import rearrange, pack, unpack

from soft_moe_pytorch.distributed import (
    AllGather,
    split_by_rank,
    gather_sizes,
    has_only_one_value
)

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def divisible_by(num, den):
    return (num % den) == 0

# split a number into `chunks` near-equal parts, e.g. chunk_num(10, 3) -> [4, 3, 3]
def chunk_num(num, chunks):
    num_per_chunk, remainder = divmod(num, chunks)

    out = []
    for i in range(chunks):
        n = num_per_chunk
        out.append(n + int(i < remainder))

    return out

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# l2 normalization along the last dimension
def l2norm(t):
    return F.normalize(t, dim = -1)

# exclusive cumulative sum (shifts right by one before summing)
def cumsum_exclusive(t, dim = -3):
    assert dim < 0
    num_pad_dims = -dim - 1
    pre_padding = (0, 0) * num_pad_dims
    return F.pad(t, (*pre_padding, 1, -1)).cumsum(dim = dim)

# numerically safe log
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# gumbel noise for stochastic routing
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))
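A quick check of the least obvious helper above, cumsum_exclusive, which pads on the left and trims on the right before the cumulative sum (re-stated standalone; values worked out by hand):

import torch
import torch.nn.functional as F

def cumsum_exclusive(t, dim = -3):
    assert dim < 0
    num_pad_dims = -dim - 1
    pre_padding = (0, 0) * num_pad_dims
    return F.pad(t, (*pre_padding, 1, -1)).cumsum(dim = dim)

t = torch.tensor([3., 1., 4.])
print(cumsum_exclusive(t, dim = -1))  # tensor([0., 3., 4.]) - sum of all elements strictly before each position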
# norm

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return l2norm(x) * self.scale * self.gamma

# expert

# a standard feedforward expert
def FeedForward(
    dim,
    mult = 4,
    dropout = 0.
):
    dim_hidden = int(dim * mult)
    return nn.Sequential(
        nn.Linear(dim, dim_hidden),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim_hidden, dim)
    )

class GEGLU(Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return x * F.gelu(gate)

# a GLU-variant feedforward expert (https://arxiv.org/abs/2002.05202)
def GLUFeedForward(
    dim,
    mult = 4,
    dropout = 0.
):
    dim_hidden = int(dim * mult * 2 / 3)

    return nn.Sequential(
        nn.Linear(dim, dim_hidden * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim_hidden, dim)
    )
# experts

class Experts(nn.Module):
    def __init__(
        self,
        experts,
        is_distributed = None,
        offload_unused_experts_to_cpu = True
    ):
        super().__init__()
        self.num_experts = len(experts)
        self.experts = nn.ModuleList(experts)

        self.is_distributed = is_distributed
        if not exists(self.is_distributed):
            self.is_distributed = dist.is_initialized() and dist.get_world_size() > 1

        # whether to offload unused experts to cpu - the optimizer must handle
        # the transition of gradients to the correct device
        self.offload_unused_experts_to_cpu = offload_unused_experts_to_cpu

        self.all_gather = AllGather()
        self.register_buffer('dummy', torch.ones(1), persistent = False)

    @property
    def device(self):
        return self.dummy.device

    # move all experts to cpu except the given selection
    def all_experts_to_cpu_besides(self, selection):
        if not self.offload_unused_experts_to_cpu:
            return

        if isinstance(selection, int):
            experts = [self.experts[selection]]
        elif isinstance(selection, slice):
            experts = self.experts[selection]
        else:
            experts = selection

        experts_set = set(experts)

        for expert in self.experts:
            device = self.device if expert in experts_set else 'cpu'
            expert.to(device)

    def forward(
        self,
        x,
        is_distributed = None
    ):
        """
        einops notation:
        b - batch
        r - rank (device / machines)
        e - experts
        n - sequence (number of tokens per expert)
        d - feature dimension
        """

        # default to the setting detected at construction time
        is_distributed = default(is_distributed, self.is_distributed)
        shape, num_experts = x.shape, self.num_experts

        # for now, naively all-gather on the batch dimension if distributed - optimize later
        if is_distributed:
            # gather the per-expert sequence sizes
            seq_sizes = gather_sizes(x, dim = -2)
            assert has_only_one_value(seq_sizes), 'number of tokens per expert must be the same'

            # all-gather on the batch dimension
            x, batch_sizes = self.all_gather(x)
            total_batch_size = x.shape[0]

            world_size = dist.get_world_size()
            rank = dist.get_rank()
        else:
            world_size = 1
            rank = 0

        # determine which experts this rank is responsible for
        if is_distributed:
            if world_size <= num_experts:
                # more experts than ranks: split the experts across the ranks
                num_experts_across_ranks = chunk_num(num_experts, world_size)
                start_indices = cumsum_exclusive(torch.tensor(num_experts_across_ranks), dim = -1)

                num_experts_per_rank = num_experts_across_ranks[rank]
                num_experts_batches_across_ranks = tuple(i * total_batch_size for i in num_experts_across_ranks)

                expert_start_index = start_indices[rank].item()
            else:
                # more ranks than experts: split the batch across groups of ranks
                num_batch_chunks = world_size // num_experts
                total_ranks_in_use = num_batch_chunks * num_experts

                expert_start_index = rank // num_batch_chunks

                batch_splits = chunk_num(total_batch_size, num_batch_chunks)
                num_experts_batches_across_ranks = batch_splits * num_experts

                # for now, the remaining ranks do not process anything
                remain_ranks = world_size % num_experts
                num_experts_batches_across_ranks += (0,) * remain_ranks

                num_experts_per_rank = int(rank < total_ranks_in_use)

            assert len(num_experts_batches_across_ranks) == world_size

            expert_slice = slice(expert_start_index, expert_start_index + num_experts_per_rank)
        else:
            num_experts_per_rank = num_experts
            expert_slice = slice(0, num_experts)

        # if distributed, each machine only handles its subset of experts and batch

        x = rearrange(x, 'b e n d -> e b n d')

        if is_distributed:
            # pack, split across ranks, and take this rank's share
            x, expert_batch_packed_shape = pack_one(x, '* n d')
            x = x.split(num_experts_batches_across_ranks, dim = 0)
            x = split_by_rank(x)

            if num_experts_per_rank > 0:
                x = rearrange(x, '(e b) n d -> e b n d', e = num_experts_per_rank)
            else:
                x = x.reshape(num_experts, *x.shape)

        # offload the unused experts and fetch the ones in use
        self.all_experts_to_cpu_besides(expert_slice)
        experts = self.experts[expert_slice]

        # route the tokens to the appropriate experts
        outs = []
        for expert, expert_input in zip(experts, x):
            out = expert(expert_input)
            outs.append(out)

        if len(outs) > 0:
            outs = torch.stack(outs)
        else:
            outs = torch.empty_like(x).requires_grad_()

        # all-gather across the merged expert-batch dimension,
        # then split the batch dimension back out
        if is_distributed:
            outs = rearrange(outs, 'e b n d -> (e b) n d')
            outs, _ = self.all_gather(outs)
            outs = unpack_one(outs, expert_batch_packed_shape, '* n d')

        outs = rearrange(outs, 'e b n d -> b e n d')

        if is_distributed:
            outs = outs.split(batch_sizes.tolist())
            outs = split_by_rank(outs)

        assert outs.shape == shape
        return outs
# main class

class SoftMoE(Module):
    def __init__(
        self,
        dim,
        *,
        seq_len = None,
        num_experts = 4,
        num_slots = None,
        expert_mult = 4,
        dropout = 0.,
        geglu = False,
        is_distributed = None,
        offload_unused_experts_to_cpu = True,
        use_layernorm = False
    ):
        super().__init__()
        assert exists(seq_len) ^ exists(num_slots), 'either seq_len, or num_slots must be passed into SoftMoE'

        # default the number of slots to seq_len // num_experts
        num_slots = default(num_slots, seq_len // num_experts)

        # choose the normalization class based on use_layernorm
        norm_klass = LayerNorm if use_layernorm else RMSNorm
        self.norm = norm_klass(dim)
        self.slot_norm = norm_klass(dim)

        # learned slot embeddings, one group of slots per expert
        self.slot_embeds = nn.Parameter(torch.randn(num_experts, num_slots, dim))

        # choose the feedforward class based on geglu
        expert_klass = GLUFeedForward if geglu else FeedForward

        self.experts = Experts(
            experts = [expert_klass(dim = dim, mult = expert_mult, dropout = dropout) for _ in range(num_experts)],
            is_distributed = is_distributed,
            offload_unused_experts_to_cpu = offload_unused_experts_to_cpu
        )

    def forward(self, x, mask = None, add_noise = False, noise_mult = 1.):
        """
        einstein notation
        b - batch
        n - sequence length
        e - number of experts
        s - number of slots per expert
        d - feature dimension
        """

        is_single_token = x.ndim == 2
        is_image = x.ndim == 4

        # image feature maps are flattened into a token sequence
        if is_image:
            x = rearrange(x, 'b d h w -> b h w d')
            x, ps = pack([x], 'b * d')
        # single tokens gain a sequence dimension
        elif is_single_token:
            x = rearrange(x, 'b d -> b 1 d')

        x = self.norm(x)
        slot_embeds = self.slot_norm(self.slot_embeds)

        # similarity between every token and every slot
        logits = einsum('b n d, e s d -> b n e s', x, slot_embeds)

        # noised dispatch and combine gate logits, with annealing if needed
        if add_noise:
            noise = gumbel_noise(logits) * noise_mult
            logits = logits + noise

        # account for the key padding mask
        if exists(mask):
            mask = rearrange(mask, 'b n -> b n 1 1')
            logits = logits.masked_fill(~mask, -torch.finfo(logits.dtype).max)

        # get the dispatch and combine weights (softmax across the right dimensions)
        dispatch_weights = logits.softmax(dim = 1)

        combine_weights = rearrange(logits, 'b n e s -> b n (e s)')
        combine_weights = combine_weights.softmax(dim = -1)

        # derive the slots by weighted-averaging the input tokens using the dispatch weights
        slots = einsum('b n d, b n e s -> b e s d', x, dispatch_weights)

        # route each group of slots to its expert
        out = self.experts(slots)

        # combine the expert outputs back into the token sequence
        out = rearrange(out, 'b e s d -> b (e s) d')
        out = einsum('b s d, b n s -> b n d', out, combine_weights)

        # restore the original shape
        if is_image:
            out, = unpack(out, ps, 'b * d')
            out = rearrange(out, 'b h w d -> b d h w')
        elif is_single_token:
            out = rearrange(out, 'b 1 d -> b d')

        return out
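As the forward pass above shows, SoftMoE also accepts 4-D image feature maps ('b d h w') and single tokens ('b d'), flattening and restoring the shape internally. A small sketch of the image path (shapes illustrative):

import torch
from soft_moe_pytorch import SoftMoE

moe = SoftMoE(dim = 64, num_slots = 64, num_experts = 4)

feats = torch.randn(2, 64, 16, 16)  # (batch, dim, height, width)
out = moe(feats)                    # (2, 64, 16, 16) - flattened to 256 tokens internally, then restored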
.\lucidrains\soft-moe-pytorch\soft_moe_pytorch\soft_moe_with_dynamic_slots.py
import math

import torch
from torch.nn import Module
import torch.nn.functional as F
from torch import nn, einsum, Tensor

from einops import rearrange, reduce, pack, unpack
from einops.layers.torch import Rearrange

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# l2 normalization along the last dimension
def l2norm(t):
    return F.normalize(t, dim = -1)

# pad a tensor so the given dimension becomes a multiple of `multiple`
# returns (whether padding was needed, padded tensor)
def pad_to_multiple(
    tensor,
    multiple,
    dim = -1,
    value = 0
):
    seqlen = tensor.shape[dim]
    m = seqlen / multiple

    if m.is_integer():
        return False, tensor

    remainder = math.ceil(m) * multiple - seqlen
    pad_offset = (0,) * (-1 - dim) * 2
    return True, F.pad(tensor, (*pad_offset, 0, remainder), value = value)

# norm

class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return l2norm(x) * self.scale * self.gamma

# expert

# a standard feedforward expert
def FeedForward(
    dim,
    mult = 4,
    dropout = 0.
):
    dim_hidden = int(dim * mult)
    return nn.Sequential(
        nn.Linear(dim, dim_hidden),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim_hidden, dim)
    )

# GEGLU activation
class GEGLU(Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return x * F.gelu(gate)

# a GLU-variant feedforward expert
def GLUFeedForward(
    dim,
    mult = 4,
    dropout = 0.
):
    dim_hidden = int(dim * mult * 2 / 3)

    return nn.Sequential(
        nn.Linear(dim, dim_hidden * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim_hidden, dim)
    )
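A quick demonstration of pad_to_multiple, since it drives the dynamic-slot segmentation below (re-stated standalone; values worked out by hand):

import math
import torch
import torch.nn.functional as F

# same logic as pad_to_multiple above
def pad_to_multiple(tensor, multiple, dim = -1, value = 0):
    seqlen = tensor.shape[dim]
    m = seqlen / multiple
    if m.is_integer():
        return False, tensor
    remainder = math.ceil(m) * multiple - seqlen
    pad_offset = (0,) * (-1 - dim) * 2
    return True, F.pad(tensor, (*pad_offset, 0, remainder), value = value)

seq = torch.randn(1, 10, 8)                      # sequence length 10, not a multiple of 4
needed_pad, padded = pad_to_multiple(seq, 4, dim = -2)
print(needed_pad, padded.shape)                  # True torch.Size([1, 12, 8])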
# main class

class DynamicSlotsSoftMoE(Module):
    def __init__(
        self,
        dim,
        *,
        num_experts = 4,
        expert_mult = 4,
        dropout = 0.,
        geglu = False
    ):
        super().__init__()
        self.norm = RMSNorm(dim)

        self.num_experts = num_experts

        # projects the (pooled) inputs to slot embeddings, one group of slots per expert
        self.to_slot_embeds = nn.Sequential(
            nn.Linear(dim, dim * num_experts, bias = False),
            Rearrange('b n (e d) -> b e n d', e = num_experts),
            RMSNorm(dim)
        )

        # choose the expert class based on geglu
        expert_klass = GLUFeedForward if geglu else FeedForward

        # create the expert modules
        self.experts = nn.ModuleList([
            expert_klass(dim = dim, mult = expert_mult, dropout = dropout) for _ in range(num_experts)
        ])
    # forward pass takes the input x and an optional key padding mask
    def forward(self, x, mask = None):
        """
        einstein notation
        b - batch
        n - sequence length
        e - number of experts
        s - number of slots per expert
        d - feature dimension
        """
        seq_len, is_image, num_experts = x.shape[-2], x.ndim == 4, self.num_experts

        # image feature maps are flattened into a token sequence
        if is_image:
            x = rearrange(x, 'b d h w -> b h w d')
            x, ps = pack([x], 'b * d')

        x = self.norm(x)

        # dynamic slot embeddings
        # first average consecutive tokens, then project each position to the
        # appropriate number of slot tokens per expert
        # the number of slots should be ~ the sequence length, as in the usual MoE with 1 expert

        is_padded, x = pad_to_multiple(x, num_experts, dim = -2)

        # if padding was needed and no mask was provided, create an all-True mask first
        if is_padded:
            if not exists(mask):
                mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool)

            _, mask = pad_to_multiple(mask, num_experts, dim = -1, value = False)

        # group the sequence into segments of `num_experts` consecutive tokens
        x_segmented = rearrange(x, 'b (n e) d -> b n e d', e = num_experts)

        # zero out masked positions within each segment
        if exists(mask):
            segmented_mask = rearrange(mask, 'b (n e) -> b n e', e = num_experts)
            x_segmented = x_segmented.masked_fill(~rearrange(segmented_mask, '... -> ... 1'), 0.)

        # perform a masked mean over each segment
        if exists(mask):
            num = reduce(x_segmented, 'b n e d -> b n d', 'sum')
            den = reduce(segmented_mask.float(), 'b n e -> b n 1', 'sum').clamp(min = 1e-5)
            x_consecutive_mean = num / den
            slots_mask = segmented_mask.any(dim = -1)
        else:
            x_consecutive_mean = reduce(x_segmented, 'b n e d -> b n d', 'mean')

        # project to get the dynamic slot embeddings
        slot_embeds = self.to_slot_embeds(x_consecutive_mean)

        logits = einsum('b n d, b e s d -> b n e s', x, slot_embeds)

        # account for the key padding mask
        if exists(mask):
            mask = rearrange(mask, 'b n -> b n 1 1')
            slots_mask = rearrange(slots_mask, 'b s -> b 1 1 s')

            logits = logits.masked_fill(~mask, -torch.finfo(logits.dtype).max)
            logits = logits.masked_fill(~slots_mask, -torch.finfo(logits.dtype).max)

        # get the dispatch and combine weights (softmax across the right dimensions)
        dispatch_weights = logits.softmax(dim = 1)

        combine_weights = rearrange(logits, 'b n e s -> b n (e s)')
        combine_weights = combine_weights.softmax(dim = -1)

        # derive the slots by weighted-averaging the input tokens using the dispatch weights
        slots = einsum('b n d, b n e s -> e b s d', x, dispatch_weights)

        # route each expert's slots through that expert
        out = []
        for slots_per_expert, expert in zip(slots, self.experts):
            out.append(expert(slots_per_expert))

        out = torch.stack(out)

        # combine the expert outputs back into the token sequence
        out = rearrange(out, 'e b s d -> b (e s) d')
        out = einsum('b s d, b n s -> b n d', out, combine_weights)

        # restore the original shape
        if is_image:
            out, = unpack(out, ps, 'b * d')
            out = rearrange(out, 'b h w d -> b d h w')

        return out[:, :seq_len]
.\lucidrains\soft-moe-pytorch\soft_moe_pytorch\__init__.py
# export the SoftMoE and DynamicSlotsSoftMoE classes
from soft_moe_pytorch.soft_moe import SoftMoE
from soft_moe_pytorch.soft_moe_with_dynamic_slots import DynamicSlotsSoftMoE

Soundstorm - Pytorch
Implementation of SoundStorm, Efficient Parallel Audio Generation from Google Deepmind, in Pytorch.
They basically applied MaskGiT to the residual vector quantized codes from Soundstream. The transformer architecture they chose to use is one that fits well with the audio domain, named Conformer
Appreciation
- Stability and 🤗 Huggingface for their generous sponsorships to work on and open source cutting edge artificial intelligence research
- Lucas Newman for numerous contributions, including the initial training code, acoustic prompting logic, per-level quantizer decoding!
- 🤗 Accelerate for providing a simple and powerful solution for training
- Einops for the indispensable abstraction that makes building neural networks fun, easy, and uplifting
- Steven Hillis for submitting the correct masking strategy and for verifying that the repository works! 🙏
- Lucas Newman for basically training a small working Soundstorm with models across multiple repositories, showing it all works end-to-end. Models include SoundStream, Text-to-Semantic T5, and finally the SoundStorm transformer here.
- @Jiang-Stan for identifying a critical bug in the iterative demasking!
Install
$ pip install soundstorm-pytorch
Usage
import torch
from soundstorm_pytorch import SoundStorm, ConformerWrapper
conformer = ConformerWrapper(
    codebook_size = 1024,
    num_quantizers = 12,
    conformer = dict(
        dim = 512,
        depth = 2
    ),
)

model = SoundStorm(
    conformer,
    steps = 18,          # 18 steps, as in original maskgit paper
    schedule = 'cosine'  # currently the best schedule is cosine
)

# get your pre-encoded codebook ids from the soundstream from a lot of raw audio

codes = torch.randint(0, 1024, (2, 1024, 12)) # (batch, seq, num residual VQ)

# do the below in a loop for a ton of data

loss, _ = model(codes)
loss.backward()

# model can now generate in 18 steps. ~2 seconds sounds reasonable

generated = model.generate(1024, batch_size = 2) # (2, 1024)
To directly train on raw audio, you need to pass in your pretrained SoundStream into SoundStorm. You can train your own SoundStream at audiolm-pytorch.
import torch
from soundstorm_pytorch import SoundStorm, ConformerWrapper, Conformer, SoundStream
conformer = ConformerWrapper(
    codebook_size = 1024,
    num_quantizers = 12,
    conformer = dict(
        dim = 512,
        depth = 2
    ),
)

soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 12,
    attn_window_size = 128,
    attn_depth = 2
)

model = SoundStorm(
    conformer,
    soundstream = soundstream # pass in the soundstream
)

# find as much audio you'd like the model to learn

audio = torch.randn(2, 10080)

# course it through the model and take a gazillion tiny steps

loss, _ = model(audio)
loss.backward()

# and now you can generate state-of-the-art speech

generated_audio = model.generate(seconds = 30, batch_size = 2) # generate 30 seconds of audio (it will calculate the length in seconds based off the sampling frequency and cumulative downsamples in the soundstream passed in above)
Complete text-to-speech will rely on a trained TextToSemantic encoder / decoder transformer. You will then load the weights and pass it into the SoundStorm as spear_tts_text_to_semantic
This is a work-in-progress, as spear-tts-pytorch only has the model architecture complete, and not the pretraining + pseudo-labeling + backtranslation logic.
from spear_tts_pytorch import TextToSemantic
text_to_semantic = TextToSemantic(
    dim = 512,
    source_depth = 12,
    target_depth = 12,
    num_text_token_ids = 50000,
    num_semantic_token_ids = 20000,
    use_openai_tokenizer = True
)

# load the trained text-to-semantic transformer

text_to_semantic.load('/path/to/trained/model.pt')

# pass it into the soundstorm

model = SoundStorm(
    conformer,
    soundstream = soundstream,
    spear_tts_text_to_semantic = text_to_semantic
).cuda()

# and now you can generate state-of-the-art speech

generated_speech = model.generate(
    texts = [
        'the rain in spain stays mainly in the plain',
        'the quick brown fox jumps over the lazy dog'
    ]
) # (2, n) - raw waveform decoded from soundstream
Todo
- integrate soundstream
- when generating, allow length to be defined in seconds (taking into account sampling freq etc)
- make sure grouped rvq is supported. concat embeddings rather than sum across group dimension
- just copy conformer over and redo shaw's relative positional embedding with rotary embedding. nobody uses shaw anymore.
- default flash attention to true
- remove batchnorm, and just use layernorm, but after the swish (as in normformer paper)
- trainer with accelerate - thanks to @lucasnewman
- allow for variable lengthed sequence training and generation, by passing in mask at forward and generate
- option to return list of audio files when generating
- turn it into a command line tool
- add cross attention and adaptive layernorm conditioning
Citations
@misc{borsos2023soundstorm,
    title   = {SoundStorm: Efficient Parallel Audio Generation},
    author  = {Zalán Borsos and Matt Sharifi and Damien Vincent and Eugene Kharitonov and Neil Zeghidour and Marco Tagliasacchi},
    year    = {2023},
    eprint  = {2305.09636},
    archivePrefix = {arXiv},
    primaryClass = {cs.SD}
}
@inproceedings{dao2022flashattention,
    title     = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author    = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year      = {2022}
}
@article{Chang2022MaskGITMG,
    title   = {MaskGIT: Masked Generative Image Transformer},
    author  = {Huiwen Chang and Han Zhang and Lu Jiang and Ce Liu and William T. Freeman},
    journal = {2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    year    = {2022},
    pages   = {11305-11315}
}
@article{Lezama2022ImprovedMI,
    title   = {Improved Masked Image Generation with Token-Critic},
    author  = {Jos{\'e} Lezama and Huiwen Chang and Lu Jiang and Irfan Essa},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2209.04439}
}
@inproceedings{Nijkamp2021SCRIPTSP,
    title     = {SCRIPT: Self-Critic PreTraining of Transformers},
    author    = {Erik Nijkamp and Bo Pang and Ying Nian Wu and Caiming Xiong},
    booktitle = {North American Chapter of the Association for Computational Linguistics},
    year      = {2021}
}
@inproceedings{rogozhnikov2022einops,
    title     = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
    author    = {Alex Rogozhnikov},
    booktitle = {International Conference on Learning Representations},
    year      = {2022},
    url       = {https://openreview.net/forum?id=oapKSVM2bcj}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
.\lucidrains\soundstorm-pytorch\setup.py
# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
    name = 'soundstorm-pytorch',           # package name
    packages = find_packages(exclude=[]),  # find all packages
    version = '0.4.2',                     # version
    license='MIT',                         # license
    description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    long_description_content_type = 'text/markdown',
    url = 'https://github.com/lucidrains/soundstorm-pytorch',
    keywords = [
        'artificial intelligence',
        'deep learning',
        'transformers',
        'attention mechanism',
        'audio generation'
    ],
    install_requires=[
        'accelerate',
        'audiolm-pytorch>=1.2.8',
        'beartype',
        'classifier-free-guidance-pytorch>=0.1.5',
        'gateloop-transformer>=0.1.1',
        'einops>=0.6.1',
        'spear-tts-pytorch>=0.4.0',
        'torch>=1.6',
    ],
    classifiers = [
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\soundstorm-pytorch\soundstorm_pytorch\attend.py
# imports
from collections import namedtuple
from functools import wraps
from packaging import version

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange

# named tuple holding the three backend flags for scaled dot product attention
EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helpers

def exists(val):
    return val is not None

# decorator that ensures a function only runs once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# print that only fires once
print_once = once(print)

# main class

class Attend(nn.Module):
    def __init__(
        self,
        causal = False,
        dropout = 0.,
        flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine the efficient attention configs for cuda and cpu
        self.cpu_config = EfficientAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(False, True, True)

    # build a causal mask
    def get_mask(self, i, j, device):
        return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)

    # flash attention path
    def flash_attn(self, q, k, v, mask = None, attn_bias = None):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # single-headed keys / values
        if k.ndim == 3:
            k = rearrange(k, 'b n d -> b 1 n d')

        if v.ndim == 3:
            v = rearrange(v, 'b n d -> b 1 n d')

        # check if the mask exists and expand it to a compatible shape
        # the mask is B L, so it has to be expanded to B H N L
        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        # check if there is a compatible device config for flash attention
        config = self.cuda_config if is_cuda else self.cpu_config

        causal = self.causal

        # handle an attention bias by folding it into the mask
        if exists(attn_bias):
            mask_value = -torch.finfo(q.dtype).max // 2
            causal_mask = self.get_mask(q_len, k_len, device)
            attn_bias = attn_bias.masked_fill(causal_mask, mask_value)

            if exists(mask):
                attn_bias = attn_bias.masked_fill(~mask, mask_value)

            mask = attn_bias
            causal = False

        # dispatch to pytorch's scaled_dot_product_attention with the chosen backend config
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.,
                is_causal = causal
            )

        return out

    # forward pass takes queries (q), keys (k), values (v), an optional mask and attention bias
    def forward(self, q, k, v, mask = None, attn_bias = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        scale = q.shape[-1] ** -0.5

        # einsum equation depends on whether keys / values are multi-headed
        kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

        # if flash is enabled, use the flash attention path
        if self.flash:
            assert not exists(attn_bias)
            return self.flash_attn(q, k, v, mask = mask)

        # similarity
        sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale

        # attention bias
        if exists(attn_bias):
            sim = sim + attn_bias

        # causal mask
        if self.causal:
            causal_mask = self.get_mask(q_len, k_len, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # key padding mask
        if exists(mask):
            if mask.ndim != 4:
                mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # attention
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate the values
        out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)

        return out
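A minimal sketch of the Attend module's interface (shapes follow the einstein notation in the docstring above; values illustrative):

import torch
from soundstorm_pytorch.attend import Attend

attend = Attend(causal = False, dropout = 0., flash = False)

q = torch.randn(2, 8, 1024, 64)                  # (batch, heads, seq, dim_head)
k = torch.randn(2, 8, 1024, 64)
v = torch.randn(2, 8, 1024, 64)

mask = torch.ones(2, 1024, dtype = torch.bool)   # key padding mask, (batch, seq)

out = attend(q, k, v, mask = mask)               # (2, 8, 1024, 64)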
.\lucidrains\soundstorm-pytorch\soundstorm_pytorch\soundstorm.py
import math
from random import random, randrange
from functools import wraps
from contextlib import nullcontext
from collections import namedtuple
from pathlib import Path

import torch
from torch.cuda.amp import autocast
from torch import Tensor, nn, einsum
import torch.nn.functional as F

from einops import rearrange, reduce, repeat, unpack, pack
from einops.layers.torch import Rearrange, EinMix

from beartype import beartype
from beartype.door import is_bearable
from beartype.typing import Union, Dict, Optional, List

from soundstorm_pytorch.attend import Attend
from spear_tts_pytorch import TextToSemantic

from audiolm_pytorch import SoundStream
from audiolm_pytorch import HubertWithKmeans, FairseqVQWav2Vec

from gateloop_transformer import SimpleGateLoopLayer as GateLoop

from tqdm import tqdm

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def divisible_by(numer, denom):
    return (numer % denom) == 0

# padding for 'same' output length in 1d convolutions
def calc_same_padding(kernel_size):
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)

# decorator that switches the model to eval mode for the call, then restores training mode
def eval_decorator(fn):
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# sampling helpers

# keep only the top-k logits, filling the rest with -inf
def top_k(logits, thres = 0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = logits.topk(k, dim = -1)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(2, ind, val)
    return probs

# numerically safe log
def log(t, eps = 1e-10):
    return torch.log(t + eps)

# gumbel noise
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# sample from logits using gumbel noise (equivalent to sampling from the softmax distribution)
def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)

# prob helpers

def sample_prob(prob):
    return random() < prob

def coin_flip():
    return sample_prob(0.5)
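gumbel_sample above relies on the Gumbel-max trick: adding Gumbel noise to logits and taking the argmax draws a sample from the corresponding softmax distribution. A small empirical check of that claim (re-stating the noise helper standalone):

import torch

def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -torch.log(-torch.log(noise + 1e-10) + 1e-10)

logits = torch.tensor([1.0, 2.0, 0.5])

samples = torch.stack([(logits + gumbel_noise(logits)).argmax() for _ in range(10000)])
freqs = torch.bincount(samples, minlength = 3).float() / 10000

print(freqs)                      # empirically close to the softmax probabilities
print(logits.softmax(dim = -1))   # tensor([0.2312, 0.6285, 0.1402])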
# tensor helpers

# sample a subset of a boolean mask with a given probability, per batch element
@beartype
def get_mask_subset_prob(
    mask: Tensor,
    prob: Union[float, Tensor],
    min_mask: int = 0
):
    batch, seq, device = *mask.shape, mask.device

    if isinstance(prob, Tensor):
        prob = rearrange(prob, 'b -> b 1')

    # number of positions to mask per batch element
    num_to_mask = (mask.sum(dim = -1, keepdim = True) * prob).clamp(min = min_mask)
    logits = torch.rand((batch, seq), device = device)
    logits = logits.masked_fill(~mask, -1)

    # random permutation rank of each position within the valid region
    randperm = logits.argsort(dim = -1).argsort(dim = -1).float()

    num_padding = (~mask).sum(dim = -1, keepdim = True)
    randperm -= num_padding

    subset_mask = randperm < num_to_mask
    subset_mask.masked_fill_(~mask, False)
    return subset_mask

# schedules

def linear_schedule(t):
    return 1 - t

def cosine_schedule(t):
    """ https://arxiv.org/abs/2202.04200 """
    return torch.cos(t * math.pi / 2)
# rotary embedding

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, theta = 10000):
        super().__init__()
        # inverse frequencies, as in the RoFormer paper
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent = False)

    @property
    def device(self):
        return next(self.buffers()).device

    @autocast(enabled = False)
    def forward(self, seq_len):
        t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)
        return freqs

# rotate half of the feature dimensions
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# apply the rotary positional embedding to a tensor of queries or keys
@autocast(enabled = False)
def apply_rotary_pos_emb(pos, t):
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())
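A small sketch of how these rotary helpers are consumed by the attention layers further down in this file (shapes illustrative; the import path assumes the installed package):

import torch
from soundstorm_pytorch.soundstorm import RotaryEmbedding, apply_rotary_pos_emb

rotary = RotaryEmbedding(dim = 64)   # dim matches the per-head dimension of the attention

q = torch.randn(2, 8, 1024, 64)      # (batch, heads, seq, dim_head)
k = torch.randn(2, 8, 1024, 64)

freqs = rotary(1024)                 # (1024, 64)
q = apply_rotary_pos_emb(freqs, q)   # rotate position information into queries and keys
k = apply_rotary_pos_emb(freqs, k)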
# t5 relative positional bias

class T5RelativePositionBias(nn.Module):
    def __init__(
        self,
        scale = 1.,
        num_buckets = 32,
        max_distance = 128,
        heads = 8
    ):
        super().__init__()
        self.scale = scale
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        # embedding from relative position bucket to per-head bias
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(
        relative_position,
        num_buckets = 32,
        max_distance = 128
    ):
        ret = 0
        n = -relative_position

        # half the buckets encode the sign of the relative position
        num_buckets //= 2
        ret += (n < 0).long() * num_buckets
        n = torch.abs(n)

        # half of the remaining buckets cover exact small distances
        max_exact = num_buckets // 2
        is_small = n < max_exact

        # the rest cover larger distances logarithmically
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()

        # clamp the large-distance buckets to the bucket range
        val_if_large = torch.min(
            val_if_large,
            torch.full_like(val_if_large, num_buckets - 1)
        )

        ret += torch.where(is_small, n, val_if_large)
        return ret

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, n):
        pos = torch.arange(n, device = self.device).long()
        # pairwise relative positions
        rel_pos = rearrange(pos, 'j -> 1 j') - rearrange(pos, 'i -> i 1')

        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
        values = self.relative_attention_bias(rp_bucket)

        bias = rearrange(values, 'i j h -> h i j')
        return bias * self.scale
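The returned bias is shaped (heads, i, j), so it broadcasts directly onto attention logits of shape (batch, heads, i, j). A small sketch (values illustrative):

import torch
from soundstorm_pytorch.soundstorm import T5RelativePositionBias

rel_pos_bias = T5RelativePositionBias(scale = 8., heads = 8)

sim = torch.randn(2, 8, 1024, 1024)  # attention logits
sim = sim + rel_pos_bias(1024)       # (8, 1024, 1024) broadcast over the batch dimension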
# swish activation
class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()

# gated linear unit
class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # split the input in two along the given dimension, gating one half with the other
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()

# depthwise 1d convolution
class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)

    def forward(self, x, mask=None):
        # zero out masked positions before convolving
        if exists(mask):
            mask = rearrange(mask, 'b n -> b 1 n')
            x = x.masked_fill(~mask, 0.)

        x = F.pad(x, self.padding)
        out = self.conv(x)

        # zero out masked positions again after convolving
        if exists(mask):
            out = out.masked_fill(~mask, 0.)

        return out

# scale the output of a wrapped function
class Scale(nn.Module):
    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

# layernorm over the channel dimension of (b, c, n) tensors
class ChanLayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1, dim, 1))

    def forward(self, x):
        eps = 1e-6 if x.dtype == torch.float32 else 1e-4
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) * var.clamp(min=eps).rsqrt() * self.gamma

# pre-normalization wrapper
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)
# attention module
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        dropout=0.,
        flash=True
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = Attend(
            flash=flash,
            dropout=dropout
        )

        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(
        self,
        x,
        context=None,
        mask=None,
        rotary_emb=None,
        attn_bias=None
    ):
        n, device, h, has_context = x.shape[-2], x.device, self.heads, exists(context)
        context = default(context, x)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        # apply rotary positions to queries and keys if provided
        if exists(rotary_emb):
            q = apply_rotary_pos_emb(rotary_emb, q)
            k = apply_rotary_pos_emb(rotary_emb, k)

        out = self.attend(q, k, v, mask=mask, attn_bias=attn_bias)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# feedforward module
class FeedForward(nn.Module):
    def __init__(
        self,
        dim,
        mult=4,
        dropout=0.
    ):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
# ConformerConvModule
class ConformerConvModule(nn.Module):
    def __init__(
        self,
        dim,
        causal=False,
        expansion_factor=2,
        kernel_size=31,
        dropout=0.
    ):
        super().__init__()

        inner_dim = dim * expansion_factor
        # 'same' padding in the non-causal case, left-only padding of
        # (kernel_size - 1, 0) in the causal case
        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)

        # pointwise conv up to twice the inner dimension, gated down by GLU
        self.net1 = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b n c -> b c n'),
            nn.Conv1d(dim, inner_dim * 2, 1),
            GLU(dim=1)
        )

        # depthwise convolution over the sequence
        self.ds_conv = DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding)

        # Swish, channel layer norm, pointwise conv back down, dropout
        self.net2 = nn.Sequential(
            Swish(),
            ChanLayerNorm(inner_dim),
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange('b c n -> b n c'),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None):
        x = self.net1(x)
        x = self.ds_conv(x, mask = mask)
        return self.net2(x)
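# calc_same_padding comes from elsewhere in the package; a hypothetical
# reimplementation sketch of what it must return for the depthwise conv to
# preserve sequence length, plus a shape check of the whole module:
def calc_same_padding_sketch(kernel_size):
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)   # (15, 15) for kernel_size = 31

conv_module = ConformerConvModule(dim=512, kernel_size=31)
x = torch.randn(1, 100, 512)
out = conv_module(x)           # (1, 100, 512): sequence length preserved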
# Conformer Block

class ConformerBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,                          # model dimension
        dim_head = 64,                # dimension per attention head
        heads = 8,                    # number of attention heads
        ff_mult = 4,                  # feedforward expansion multiple
        conv_expansion_factor = 2,    # conv module expansion factor
        conv_kernel_size = 31,        # conv kernel size
        attn_dropout = 0.,            # attention dropout
        attn_flash = True,            # whether to use flash attention
        ff_dropout = 0.,              # feedforward dropout
        conv_dropout = 0.,            # conv module dropout
        conv_causal = False,          # whether the convolution is causal
        use_gateloop_layers = False   # whether to use gateloop layers
    ):
        super().__init__()
        # first half-step feedforward
        self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)

        # optional gateloop layer
        self.gateloop = GateLoop(dim) if use_gateloop_layers else None

        self.attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = attn_flash)
        self.conv = ConformerConvModule(dim = dim, causal = conv_causal, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout)

        # second half-step feedforward
        self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)

        # pre-norm the attention, and pre-norm plus halve both feedforwards
        # (the macaron structure of the Conformer paper)
        self.attn = PreNorm(dim, self.attn)
        self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
        self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))

        self.post_norm = nn.LayerNorm(dim)

    def forward(
        self,
        x,
        mask = None,
        rotary_emb = None,
        attn_bias = None
    ):
        # first half-step feedforward, residual
        x = self.ff1(x) + x

        # optional gateloop layer, residual
        if exists(self.gateloop):
            x = self.gateloop(x) + x

        # attention, convolution, and second half-step feedforward, all residual
        x = self.attn(x, mask = mask, rotary_emb = rotary_emb, attn_bias = attn_bias) + x
        x = self.conv(x, mask = mask) + x
        x = self.ff2(x) + x

        # final post-norm
        x = self.post_norm(x)
        return x
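# Usage sketch for a single block (flash=False so it runs without a flash
# attention backend; gateloop layers stay disabled by default):
block = ConformerBlock(dim=512, attn_flash=False)
x = torch.randn(1, 100, 512)
mask = torch.ones(1, 100, dtype=torch.bool)
out = block(x, mask=mask)      # (1, 100, 512)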
# Conformer

class Conformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,                        # number of ConformerBlocks
        dim_head = 64,                # dimension per attention head
        heads = 8,                    # number of attention heads
        ff_mult = 4,                  # feedforward expansion multiple
        conv_expansion_factor = 2,    # conv module expansion factor
        conv_kernel_size = 31,        # conv kernel size
        attn_dropout = 0.,            # attention dropout
        ff_dropout = 0.,              # feedforward dropout
        conv_dropout = 0.,            # conv module dropout
        conv_causal = False,          # whether the convolution is causal
        attn_flash = True,            # whether to use flash attention
        t5_rel_pos_bias = False,      # whether to use the T5 relative position bias
        use_gateloop_layers = True    # whether to use gateloop layers
    ):
        super().__init__()

        # flash attention kernels cannot consume a learned bias
        assert not (t5_rel_pos_bias and attn_flash), 'flash attention is not compatible with learned bias'

        self.dim = dim
        self.layers = nn.ModuleList([])

        # use rotary embeddings unless the T5 relative position bias is requested
        self.rotary_emb = RotaryEmbedding(dim_head) if not t5_rel_pos_bias else None
        self.rel_pos_bias = T5RelativePositionBias(dim_head ** 0.5, heads = heads) if t5_rel_pos_bias else None

        # stack depth ConformerBlocks
        for _ in range(depth):
            self.layers.append(ConformerBlock(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                ff_mult = ff_mult,
                conv_expansion_factor = conv_expansion_factor,
                conv_kernel_size = conv_kernel_size,
                attn_dropout = attn_dropout,
                ff_dropout = ff_dropout,
                conv_dropout = conv_dropout,
                conv_causal = conv_causal,
                attn_flash = attn_flash,
                use_gateloop_layers = use_gateloop_layers
            ))

    def forward(self, x, mask = None):
        seq_len = x.shape[-2]

        # positional information, shared across all blocks
        rotary_emb = self.rotary_emb(seq_len) if exists(self.rotary_emb) else None
        attn_bias = self.rel_pos_bias(seq_len) if exists(self.rel_pos_bias) else None

        for block in self.layers:
            x = block(
                x,
                mask = mask,
                rotary_emb = rotary_emb,
                attn_bias = attn_bias
            )

        return x
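# Usage sketch for the full stack (gateloop layers turned off so only the
# Conformer paper's components are exercised):
model = Conformer(512, depth=6, attn_flash=False, use_gateloop_layers=False)
x = torch.randn(1, 100, 512)
out = model(x)                 # (1, 100, 512)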
# conformer with sum reduction across quantized tokens at the beginning, along with heads

class ConformerWrapper(nn.Module):
    @beartype
    def __init__(
        self,
        *,
        codebook_size,            # size of each codebook
        num_quantizers,           # number of quantizers
        conformer: Union[Conformer, Dict[str, any]],   # Conformer instance, or its constructor kwargs
        grouped_quantizers = 1    # number of grouped quantizers
    ):
        super().__init__()
        self.conformer = conformer

        # allow the conformer to be passed in as a dict of constructor kwargs
        if isinstance(conformer, dict):
            self.conformer = Conformer(**self.conformer)

        dim = self.conformer.dim

        # project concatenated group embeddings back to the model dimension;
        # only needed when there is more than one group of quantizers
        self.embedding_proj = nn.Sequential(
            nn.Linear(dim * grouped_quantizers, dim),
            nn.LayerNorm(dim)
        ) if grouped_quantizers > 1 else nn.Identity()

        # each quantizer gets codebook_size codes plus one mask token
        num_codes_with_mask = codebook_size + 1
        num_effective_quantizers = num_quantizers * grouped_quantizers

        self.code_embeds = nn.Embedding(num_codes_with_mask * num_effective_quantizers, dim)

        # per-quantizer offsets into the shared embedding table,
        # and each quantizer's mask token id
        self.register_buffer('quantizer_offsets', torch.arange(num_effective_quantizers) * num_codes_with_mask, persistent=False)
        self.register_buffer('mask_tokens', self.quantizer_offsets + num_codes_with_mask, persistent=False)

        self.dim = dim
        self.codebook_size = codebook_size
        self.num_codes_with_mask = num_codes_with_mask
        self.num_quantizers = num_quantizers
        self.grouped_quantizers = grouped_quantizers

        # one head per effective quantizer, expanded along the sequence
        self.heads = nn.Sequential(
            nn.Linear(dim, dim * num_effective_quantizers),
            Rearrange('b n (h d) -> b (n h) d', h=num_effective_quantizers)
        )

        # each quantizer codebook requires its own logits weight and bias matrices
        # accomplished with EinMix and einops
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b (n gq) d -> b n gq d', gq=num_effective_quantizers),
            EinMix(
                'b n gq d -> b n gq l',
                weight_shape='gq d l',
                bias_shape='gq l',
                gq=num_effective_quantizers,
                l=codebook_size,
                d=dim
            ),
            Rearrange('b ... d -> b (...) d')
        )
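# The EinMix above learns a separate (dim -> codebook_size) projection per
# effective quantizer. A minimal einsum sketch of the same computation
# (a hypothetical standalone equivalent, not the wrapper's own code):
gq, d, l = 4, 512, 1024
W = torch.randn(gq, d, l)                         # one weight matrix per quantizer
bias = torch.randn(gq, l)                         # one bias vector per quantizer
x = torch.randn(2, 100, gq, d)
logits = torch.einsum('bngd,gdl->bngl', x, W) + bias   # (2, 100, gq, l)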
    def forward(
        self,
        x,
        *,
        mask=None,
        cond=None,
        sum_embeds=None,
        return_embeddings=False,
        return_logits_and_embeddings=False
    ):
        """
        einops notation:
        b - batch
        n - sequence
        g - groups
        q - quantizers
        d - feature dimension
        """
        n, q, g = x.shape[-1], self.num_quantizers, self.grouped_quantizers
        assert divisible_by(n, g * q), 'sequence must be divisible by number of quantizers'

        # fold the interleaved quantizer tokens into their own axis, then shift
        # each quantizer's codes by its offset into the shared embedding table
        x = rearrange(x, 'b (n gq) -> b n gq', gq=g * q)
        x = x + self.quantizer_offsets

        x = self.code_embeds(x)

        # sum the embeddings across quantizers within each group
        x = reduce(x, 'b n (g q) d -> b n (g d)', 'sum', g=g)

        # project grouped embeddings back down to the model dimension
        x = self.embedding_proj(x)

        # add any externally computed embeddings (e.g. self conditioning)
        if exists(sum_embeds):
            x = x + sum_embeds

        # add the conditioning signal, broadcast over the sequence if needed
        if exists(cond):
            if cond.ndim == 2:
                cond = rearrange(cond, 'b d -> b 1 d')
            x = x + cond

        x = self.conformer(x, mask=mask)
        embeds = self.heads(x)

        # return embeddings if requested, or if there is no to_logits head
        if return_embeddings or not exists(self.to_logits):
            return embeds

        logits = self.to_logits(embeds)

        if return_logits_and_embeddings:
            return logits, embeds

        return logits
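# Shape sketch for the wrapper (assumes the surrounding file's imports): token
# ids are interleaved per timestep, so the sequence axis has length n * q.
wrapper = ConformerWrapper(
    codebook_size=1024,
    num_quantizers=4,
    conformer=dict(dim=512, depth=2, attn_flash=False)
)
ids = torch.randint(0, 1024, (2, 100 * 4))     # (batch, seq * quantizers)
logits = wrapper(ids)                          # (2, 100 * 4, 1024)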
# LogitHead - a small head over ConformerWrapper embeddings, used for the main
# logits as well as the self token critic
class LogitHead(nn.Module):
    def __init__(
        self,
        net: ConformerWrapper,
        logit_dim
    ):
        super().__init__()
        self.net = net
        dim = net.dim
        self.to_logits = nn.Linear(dim, logit_dim)

    def forward(self, x):
        # run the wrapped network, keeping embeddings rather than logits
        embed = self.net(x, return_embeddings = True)
        return self.to_logits(embed)
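# The self token critic uses a LogitHead with a single output dimension to
# score each token; a minimal sketch, reusing the wrapper and ids from the
# sketch above:
critic_head = LogitHead(wrapper, logit_dim=1)
scores = critic_head(ids)      # (2, 400, 1): one critic logit per quantized token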
# LossBreakdown - named tuple holding the generator and critic losses
LossBreakdown = namedtuple('LossBreakdown', ['generator_loss', 'critic_loss'])
# SoundStorm - the masked generative model over quantized sound tokens

class SoundStorm(nn.Module):
    @beartype
    def __init__(
        self,
        net: ConformerWrapper,
        *,
        soundstream: Optional[SoundStream] = None,
        spear_tts_text_to_semantic: Optional[TextToSemantic] = None,
        wav2vec: Optional[Union[HubertWithKmeans, FairseqVQWav2Vec]] = None,
        steps = 18,
        self_cond = False,
        self_cond_train_prob = 0.75,
        no_replace_prob = 0.15,            # percentage of masked tokens kept unchanged, as in the original MLM paper
        random_token_prob = 0.1,           # percentage of masked tokens replaced with a random token, as in the original MLM paper
        schedule = 'linear',
        can_mask_prev_unmasked = False,    # when unmasking, whether previously unmasked tokens may be remasked
        self_token_critic = False,         # whether to use a self token critic
        critic_loss_weight = 1.,
        num_semantic_token_ids = None,
        semantic_pad_id = -1,
        pad_id = None,
        wav2vec_target_sample_hz = None,
        wav2vec_downsample_factor = None,
        codec_target_sample_hz = None,
        codec_downsample_factor = None,
    @property
    def device(self):
        return next(self.net.parameters()).device

    def load(self, path, strict = True):
        # load model parameters from a checkpoint; return pkg so that if this
        # is called from within a Trainer function call, the Trainer can also
        # access the package loaded from the checkpoint
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path), map_location = 'cpu')
        self.load_state_dict(pkg['model'], strict = strict)
        return pkg
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        num_latents = None,
        *,
        mask = None,
        texts: Optional[Union[List[str], Tensor]] = None,
        cond_semantic_token_ids = None,
        prompt_acoustic_token_ids = None,
        seconds = None,
        batch_size = None,
        start_temperature = 1.,
        filter_thres = 0.7,
        noise_level_scale = 1.,
        num_full_sampling_levels = 1,
        text_to_semantic_generate_kwargs: dict = {},
        spec_decode = False,
        spec_decode_gamma = 5,
        **kwargs
    # fetch and prepare the conditioning semantic tokens, if any
    def maybe_get_condition(self, token_ids = None, length = None):
        # text conditioning must be on if and only if semantic token ids were passed in
        assert not (exists(token_ids) ^ self.should_condition), 'you either have text-conditioning turned on and have not passed in any conditioning semantic token ids, or vice versa'

        if not exists(token_ids):
            return None

        # if the semantic tokens come from a frozen text-to-semantic model,
        # run without gradients
        context = torch.no_grad if exists(self.text_to_semantic) else nullcontext

        with context():
            # mask out the semantic padding tokens
            mask = token_ids != self.semantic_pad_id

            # also mask out the eos semantic token id, if one is auto-appended
            if exists(self.text_to_semantic) and self.text_to_semantic.autoset_eos_id['speech']:
                mask &= token_ids != self.num_semantic_token_ids

            token_ids = token_ids.masked_fill(~mask, 0)

            # embed the semantic tokens and project them to the model dimension
            semantic_tokens = self.semantic_token_emb(token_ids)
            cond_tokens = self.semantic_cond_to_model_dim(semantic_tokens)

            # zero out the padded positions and let the network learn to handle them
            cond_tokens = cond_tokens.masked_fill(~rearrange(mask, '... -> ... 1'), 0.)

        # the conditioning tokens need to be interpolated so the semantic and
        # vector quantized tokens are aligned in time
        cond_length = cond_tokens.shape[-2]

        target_cond_length = math.ceil(cond_length * (self.wav2vec_downsample_factor / self.wav2vec_target_sample_hz) / (self.codec_downsample_factor / self.codec_target_sample_hz))

        # pytorch does not support 1d interpolation, so go through 2d
        if cond_length != target_cond_length:
            cond_tokens = rearrange(cond_tokens, 'b n d -> b d n 1')
            cond_tokens = F.interpolate(cond_tokens, (target_cond_length, 1), mode = 'bilinear')
            cond_tokens = rearrange(cond_tokens, 'b d n 1 -> b n d')

        # depending on the requested length, either curtail or pad the conditioning tokens
        cond_length = cond_tokens.shape[-2]

        if exists(length):
            if cond_length < length:
                cond_tokens = F.pad(cond_tokens, (0, 0, 0, length - cond_length), value = 0.)
            elif cond_length > length:
                cond_tokens = cond_tokens[:, :length]

        return cond_tokens
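# A worked example of the time alignment above, under assumed rates: if the
# semantic tokens come at 50 Hz (wav2vec downsample factor 320 at 16 kHz) and
# the acoustic codec runs at 75 Hz (downsample factor 320 at 24 kHz), then
# 100 semantic tokens map to 150 acoustic frames:
import math
cond_length = 100
wav2vec_downsample_factor, wav2vec_target_sample_hz = 320, 16000
codec_downsample_factor, codec_target_sample_hz = 320, 24000
target = math.ceil(cond_length * (wav2vec_downsample_factor / wav2vec_target_sample_hz) / (codec_downsample_factor / codec_target_sample_hz))
assert target == 150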
    # training forward pass
    def forward(
        self,
        x,
        *,
        mask = None,
        cond_semantic_token_ids = None,
        only_train_generator = False,
        only_train_critic = False,
        generator_sample_temperature = None,
        **kwargs