Lucidrains 系列项目源码解析(六十五)
.\lucidrains\musiclm-pytorch\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# Package metadata and install configuration for musiclm-pytorch.
setup(
    name = 'musiclm-pytorch',                     # distribution name
    packages = find_packages(exclude=[]),         # include every discovered package
    version = '0.2.8',                            # release version
    license='MIT',                                # license
    description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',  # short description
    author = 'Phil Wang',                         # author
    author_email = 'lucidrains@gmail.com',        # author email
    long_description_content_type = 'text/markdown',  # README is rendered as markdown
    url = 'https://github.com/lucidrains/musiclm-pytorch',  # project homepage
    keywords = [                                  # package-index search keywords
        'artificial intelligence',
        'deep learning',
        'transformers',
        'attention mechanism',
        'text to music',
        'contrastive learning'
    ],
    install_requires=[                            # runtime dependencies
        'accelerate',
        'audiolm-pytorch>=0.17.0',
        'beartype',
        'einops>=0.6',
        'lion-pytorch',
        'vector-quantize-pytorch>=1.0.0',
        'x-clip',
        'torch>=1.12',
        'torchaudio'
    ],
    classifiers=[                                 # PyPI trove classifiers
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\n-grammer-pytorch\n_grammer_pytorch\n_grammer_pytorch.py
# 基于 jax 代码的实现
# https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ngrammer.py
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
import sympy
# helper functions

def exists(val):
    """Return True when `val` is not None."""
    return not (val is None)
def sum_squares(t, dim = -1):
    """Sum of squared elements of tensor `t` along dimension `dim`."""
    return t.pow(2).sum(dim = dim)
# bigram-related functions

def multi_way_hash_ids(x, a, b, prime, buckets):
    """Universal hash of ids: ((x * a + b) mod prime) mod buckets."""
    hashed = (x * a + b) % prime
    return hashed % buckets
def get_bigram_ids(ids, vocab_size, segment_pos = None):
    """Combine each token id with its predecessor into a bigram id.

    `ids` has shape (batch, seq, heads); the output has the same shape and
    encodes (current_id + previous_id * vocab_size) per position. When
    `segment_pos` is given, the previous-id contribution is zeroed wherever
    the (padded) position mask marks a segment start (pos == 0).
    """
    ids = ids.long()

    # current ids with a trailing zero row; previous ids shifted right by one
    curr = F.pad(ids, (0, 0, 0, 1))
    prev = F.pad(ids, (0, 0, 1, 0))

    if segment_pos is not None:
        # keep only positions that do not start a new segment
        keep = (segment_pos.unsqueeze(-1) != 0).long()
        keep = F.pad(keep, (0, 0, 0, 1))
        prev = prev * keep

    bigrams = curr + prev * vocab_size
    return bigrams[:, :-1]
# optimizer-related functions

def get_ngrammer_parameters(module):
    """Split a module's parameters into (Ngrammer params, all remaining params)."""
    ngrammer_params = set()
    for submodule in module.modules():
        if isinstance(submodule, Ngrammer):
            ngrammer_params.update(submodule.parameters())
    other_params = set(module.parameters()) - ngrammer_params
    return list(ngrammer_params), list(other_params)
def get_ngrammer_param_groups(module, ngrammer_learning_rate = 1e-2):
    """Build optimizer param groups, giving Ngrammer parameters their own (higher) learning rate."""
    ngrammer_params, other_params = get_ngrammer_parameters(module)
    return [
        {'params': other_params},
        {'params': ngrammer_params, 'lr': ngrammer_learning_rate}
    ]
# layernorm

class MultiheadLayerNorm(nn.Module):
    """LayerNorm over the last dimension with a separate affine (gain, bias) per head.

    Expects input whose trailing two dimensions are (heads, dim).
    """
    def __init__(self, dim, heads = 1, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(heads, dim))
        self.b = nn.Parameter(torch.zeros(heads, dim))

    def forward(self, x):
        mean = x.mean(dim = -1, keepdim = True)
        # biased standard deviation, matching var(unbiased=False).sqrt()
        std = x.var(dim = -1, unbiased = False, keepdim = True).sqrt()
        normed = (x - mean) / (std + self.eps)
        return normed * self.g + self.b
# classes

class VectorQuantization(nn.Module):
    """Multi-head nearest-centroid quantizer with EMA-updated codebook means.

    Input of shape (batch, seq, num_heads * dim_per_head) is split per head,
    and each head vector is assigned the id of its closest cluster mean.
    During training, the means are updated in place with an exponential
    moving average of the batch cluster means.
    """
    def __init__(
        self,
        *,
        num_clusters,
        num_heads,
        dim_per_head,
        decay = 0.999,
        epsilon = 1e-6
    ):
        super().__init__()
        self.decay = decay
        self.epsilon = epsilon
        self.num_heads = num_heads
        self.dim_per_head = dim_per_head
        self.num_clusters = num_clusters
        # codebook: one set of cluster means per head
        self.register_buffer('means', torch.randn(num_heads, num_clusters, dim_per_head))

    def forward(
        self,
        x,
        mask = None
    ):
        heads, dim_head = self.num_heads, self.dim_per_head
        means = self.means

        assert x.shape[-1] == (heads * dim_head), f'input embedding feature dimension must be {heads * dim_head}'

        # split the feature dim into per-head vectors: (b, n, h, d)
        x = x.reshape(*x.shape[:-1], heads, dim_head)

        # squared euclidean distance to every mean, via |x|^2 - 2 x.m + |m|^2
        x_sq = x.pow(2).sum(dim = -1, keepdim = True)
        means_sq = means.pow(2).sum(dim = -1)
        dists = x_sq - 2 * einsum('b n h d, h k d -> b n h k', x, means) + means_sq

        # nearest cluster id per head
        cluster_ids = dists.argmin(dim = -1)

        if self.training:
            # assignment counts and summed inputs per cluster
            one_hot = F.one_hot(cluster_ids, num_classes = self.num_clusters)
            counts = one_hot.sum(dim = (0, 1))
            summed = einsum('b n h k, b n h d -> h k d', one_hot.float(), x)

            # batch cluster means, then in-place EMA update of the codebook
            batch_means = summed / (self.epsilon + counts.unsqueeze(-1))
            self.means.data.copy_((1. - self.decay) * batch_means + self.decay * means)

        return cluster_ids
class Ngrammer(nn.Module):
    """Augment per-head unigram embeddings with hashed bigram ("n-gram") embeddings.

    From per-position cluster ids, bigram ids are formed, hashed per head into
    a fixed ngram vocabulary, embedded, layer-normed, and then either
    concatenated onto or added to the layer-normed unigram embeddings.
    """
    def __init__(
        self,
        *,
        unigram_vocab_size,            # size of the unigram (cluster id) vocabulary
        dim_per_head,                  # embedding dimension per head
        num_heads = 1,                 # number of heads
        ngram_emb_dim = 8,             # dimension of each ngram embedding
        ngram_vocab_size = 768 * 256,  # size of the hashed ngram vocabulary
        concat_ngrams = True           # concat ngram embeds to unigram embeds (vs add)
    ):
        super().__init__()
        # when concatenating, the ngram embedding replaces the last ngram_emb_dim
        # channels of each head, so the head dim must be strictly larger
        assert not (concat_ngrams and dim_per_head <= ngram_emb_dim), 'unigram head dimension cannot be smaller than ngram embedding dimension when concatting'
        # when adding instead, both embeddings must share the same dimension
        assert not (not concat_ngrams and dim_per_head != ngram_emb_dim), 'unigram head dimension must be equal to ngram embedding dimension if not concatting'

        self.num_heads = num_heads
        self.ngram_vocab_size = ngram_vocab_size
        self.unigram_vocab_size = unigram_vocab_size
        self.concat_ngrams = concat_ngrams

        # NOTE(review): appears unused
        self.embeddings = nn.ModuleList([])

        # per-head layernorms for ngram and unigram embeddings
        self.ngram_layernorm = MultiheadLayerNorm(ngram_emb_dim, heads = num_heads)
        self.embeds_layernorm = MultiheadLayerNorm(dim_per_head, heads = num_heads)

        # one embedding table shared by all heads, with per-head vocab offsets
        self.ngram_embeds = nn.Embedding(ngram_vocab_size * num_heads, ngram_emb_dim)

        # one prime per head (each > ngram_vocab_size) for the multi-way hash
        primes = list(sympy.primerange(ngram_vocab_size + 1, 2 * ngram_vocab_size))[:num_heads]
        self.register_buffer('primes', torch.tensor(primes), persistent = False)

    def forward(
        self,
        embeds,             # (batch, seq, num_heads * dim_per_head) unigram embeddings
        cluster_ids,        # (batch, seq) or (batch, seq, heads) cluster ids
        mask = None,        # optional (batch, seq) padding mask
        segment_pos = None  # optional (batch, seq) within-segment positions for bigram boundaries
    ):
        num_heads, vocab_size, unigram_vocab_size, device = self.num_heads, self.ngram_vocab_size, self.unigram_vocab_size, embeds.device

        # broadcast shared cluster ids across heads when given 2d
        if cluster_ids.ndim == 2:
            cluster_ids = repeat(cluster_ids, '... -> ... h', h = num_heads)

        ngram_cluster_ids = get_bigram_ids(cluster_ids, unigram_vocab_size, segment_pos)

        # per-head hash coefficients (head index + 1 acts as both multiplier and offset)
        head_range = torch.arange(num_heads, device = device)
        head_range = rearrange(head_range, 'h -> 1 1 h')
        primes = rearrange(self.primes, 'h -> 1 1 h')

        # hash each head's bigram ids into the ngram vocab in parallel
        ngram_ids = multi_way_hash_ids(ngram_cluster_ids, head_range + 1, head_range + 1, primes, vocab_size)

        # shift each head into its own slice of the shared embedding table
        ngram_ids = ngram_ids + (vocab_size * head_range)

        # look up all ngram embeddings at once and layer-norm them per head
        ngram_embeds = self.ngram_embeds(ngram_ids)
        normed_ngram_embeds = self.ngram_layernorm(ngram_embeds)

        # layer-norm the unigram embeddings per head
        embeds = rearrange(embeds, 'b n (h d) -> b n h d', h = num_heads)
        normed_embeds = self.embeds_layernorm(embeds)

        # combine unigram and bigram representations
        if self.concat_ngrams:
            # replace the last ngram_emb_dim channels of each head with the ngram embeds
            input_sliced_dim = normed_embeds.shape[-1] - normed_ngram_embeds.shape[-1]
            out = torch.cat((
                normed_embeds[..., :input_sliced_dim],
                normed_ngram_embeds
            ), dim = -1)
        else:
            out = normed_embeds + normed_ngram_embeds

        # flatten heads back into the feature dimension
        out = rearrange(out, 'b n ... -> b n (...)')

        # zero out padded positions when a mask is supplied
        if exists(mask):
            out = out * rearrange(mask, 'b n -> b n 1').float()

        return out
# main class

class VQNgrammer(nn.Module):
    """Vector-quantize the input, then augment it with hashed bigram embeddings."""

    def __init__(
        self,
        *,
        num_clusters,                  # number of VQ cluster centers
        num_heads,                     # number of heads
        dim_per_head,                  # dimension per head
        ngram_vocab_size = 768 * 256,  # hashed ngram vocabulary size
        ngram_emb_dim = 8,             # ngram embedding dimension
        concat_ngrams = True,          # concat (vs add) ngram embeddings
        decay = 0.999,                 # EMA decay for the VQ codebook
        epsilon = 1e-6                 # small value guarding division by zero
    ):
        super().__init__()
        # every bigram id lives in [0, num_clusters^2), so the hashed vocab must be smaller
        assert ngram_vocab_size < (num_clusters ** 2), 'the ngram vocab size should be less than the number of clusters squared'

        self.vq = VectorQuantization(
            num_clusters = num_clusters,
            num_heads = num_heads,
            dim_per_head = dim_per_head,
            decay = decay,
            epsilon = epsilon
        )

        self.ngram = Ngrammer(
            unigram_vocab_size = num_clusters,
            ngram_vocab_size = ngram_vocab_size,
            ngram_emb_dim = ngram_emb_dim,
            concat_ngrams = concat_ngrams,
            num_heads = num_heads,
            dim_per_head = dim_per_head
        )

    def forward(
        self,
        x,
        mask = None,
        segment_pos = None
    ):
        # cluster the input, then feed it with its cluster ids to the Ngrammer
        cluster_ids = self.vq(x, mask = mask)
        return self.ngram(
            x,
            cluster_ids = cluster_ids,
            mask = mask,
            segment_pos = segment_pos
        )
.\lucidrains\n-grammer-pytorch\n_grammer_pytorch\__init__.py
# 从 n_grammer_pytorch.n_grammer_pytorch 模块中导入 VQNgrammer, Ngrammer, get_ngrammer_parameters, get_ngrammer_param_groups 类/函数
from n_grammer_pytorch.n_grammer_pytorch import VQNgrammer, Ngrammer, get_ngrammer_parameters, get_ngrammer_param_groups

N-Grammer - Pytorch
Implementation of N-Grammer, augmenting Transformers with latent n-grams, in Pytorch
Install
$ pip install n-grammer-pytorch
Usage
import torch
from n_grammer_pytorch import VQNgrammer
vq_ngram = VQNgrammer(
num_clusters = 1024, # number of clusters
dim_per_head = 32, # dimension per head
num_heads = 16, # number of heads
ngram_vocab_size = 768 * 256, # ngram vocab size
ngram_emb_dim = 16, # ngram embedding dimension
decay = 0.999 # exponential moving decay value
)
x = torch.randn(1, 1024, 32 * 16)
vq_ngram(x) # (1, 1024, 32 * 16)
Learning Rates
Like product key memories, Ngrammer parameters need to have a higher learning rate (1e-2 was recommended in the paper). The repository offers an easy way to generate the parameter groups.
from torch.optim import Adam
from n_grammer_pytorch import get_ngrammer_parameters
# this helper function, for your root model, finds all the Ngrammer modules and collects their parameters, separating them from the rest
ngrammer_parameters, other_parameters = get_ngrammer_parameters(transformer)
optim = Adam([
{'params': other_parameters},
{'params': ngrammer_parameters, 'lr': 1e-2}
], lr = 3e-4)
Or, even more simply
from torch.optim import Adam
from n_grammer_pytorch import get_ngrammer_param_groups
param_groups = get_ngrammer_param_groups(model) # automatically creates array of parameter settings with learning rate set at 1e-2 for ngrammer parameter values
optim = Adam(param_groups, lr = 3e-4)
Citations
@inproceedings{thai2020using,
title = {N-grammer: Augmenting Transformers with latent n-grams},
author = {Anonymous},
year = {2021},
url = {https://openreview.net/forum?id=GxjCYmQAody}
}
.\lucidrains\n-grammer-pytorch\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# Package metadata and install configuration for n-grammer-pytorch.
setup(
    name = 'n-grammer-pytorch',               # distribution name
    packages = find_packages(exclude=[]),     # include every discovered package
    version = '0.0.14',                       # release version
    license='MIT',                            # license
    description = 'N-Grammer - Pytorch',      # short description
    long_description_content_type = 'text/markdown',  # README is rendered as markdown
    author = 'Phil Wang',                     # author
    author_email = 'lucidrains@gmail.com',    # author email
    url = 'https://github.com/lucidrains/n-grammer-pytorch',  # project homepage
    keywords = [                              # package-index search keywords
        'artificial intelligence',
        'attention mechanism',
        'transformers',
        'n-grams',
        'memory'
    ],
    install_requires=[                        # runtime dependencies
        'einops>=0.3',
        'sympy',
        'torch>=1.6'
    ],
    classifiers=[                             # PyPI trove classifiers
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\naturalspeech2-pytorch\naturalspeech2_pytorch\aligner.py
from typing import Tuple
import numpy as np
import torch
from torch import nn, Tensor
from torch.nn import Module
import torch.nn.functional as F
from einops import rearrange, repeat
from beartype import beartype
from beartype.typing import Optional
# 检查变量是否存在
def exists(val):
return val is not None
# alignment model
class AlignerNet(Module):
    """alignment model https://arxiv.org/pdf/2108.10447.pdf

    Projects queries (mels) and keys (text features) into a shared channel
    space and forms attention from their pairwise distances.
    """
    def __init__(
        self,
        dim_in=80,          # query (mel) channels
        dim_hidden=512,     # key (text) channels
        attn_channels=80,   # shared projection channels used for the distance
        temperature=0.0005, # NOTE(review): stored but never used in forward — confirm against reference impl
    ):
        super().__init__()
        self.temperature = temperature

        # key projection: conv -> relu -> conv
        self.key_layers = nn.ModuleList([
            nn.Conv1d(
                dim_hidden,
                dim_hidden * 2,
                kernel_size=3,
                padding=1,
                bias=True,
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(dim_hidden * 2, attn_channels, kernel_size=1, padding=0, bias=True)
        ])

        # query projection: deeper conv stack
        self.query_layers = nn.ModuleList([
            nn.Conv1d(
                dim_in,
                dim_in * 2,
                kernel_size=3,
                padding=1,
                bias=True,
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(dim_in * 2, dim_in, kernel_size=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv1d(dim_in, attn_channels, kernel_size=1, padding=0, bias=True)
        ])

    @beartype
    def forward(
        self,
        queries: Tensor,
        keys: Tensor,
        mask: Optional[Tensor] = None
    ):
        key_out = keys
        for layer in self.key_layers:
            key_out = layer(key_out)

        query_out = queries
        for layer in self.query_layers:
            query_out = layer(query_out)

        # to (batch, time, channels) for pairwise distance computation
        key_out = rearrange(key_out, 'b c t -> b t c')
        query_out = rearrange(query_out, 'b c t -> b t c')

        # NOTE(review): raw (non-negated) distances go through softmax below,
        # so the farthest key receives the largest weight — confirm the sign
        attn_logp = torch.cdist(query_out, key_out)
        attn_logp = rearrange(attn_logp, 'b ... -> b 1 ...')

        # mask out invalid key positions before normalizing
        if exists(mask):
            mask = rearrange(mask.bool(), '... c -> ... 1 c')
            attn_logp.data.masked_fill_(~mask, -torch.finfo(attn_logp.dtype).max)

        attn = attn_logp.softmax(dim = -1)
        return attn, attn_logp
# tensor padding helper

def pad_tensor(input, pad, value=0):
    """Pad `input` with `value`, given per-dimension (before, after) pairs.

    `pad` is a sequence of (before, after) pairs ordered from the first to
    the last dimension; it is flattened into the reversed-order flat list
    that F.pad expects.
    """
    flat_pad = []
    for pair in reversed(pad):
        flat_pad.extend(pair)
    assert len(flat_pad) // 2 == len(input.shape), 'Padding dimensions do not match input dimensions'
    return F.pad(input, flat_pad, mode='constant', value=value)
# maximum path (monotonic alignment search)
def maximum_path(value, mask, const=None):
    """Find the best monotonic alignment path through `value`.

    value: (batch, t_x, t_y) scores; mask: same-shape 0/1 validity mask.
    Forward pass: for each y step, each x keeps the better of "stay at x"
    vs "advance from x-1". Backward pass: backtrack from the last valid x.
    Returns a 0/1 tensor of the same shape marking the chosen path.
    """
    device = value.device
    dtype = value.dtype
    if not exists(const):
        const = torch.tensor(float('-inf')).to(device) # Patch for Sphinx complaint
    value = value * mask

    b, t_x, t_y = value.shape
    # direction[b, x, j] = 1 means "stay on x" was best at step j, 0 means "came from x-1"
    direction = torch.zeros(value.shape, dtype=torch.int64, device=device)
    # v holds the best cumulative score per x at the current y step
    v = torch.zeros((b, t_x), dtype=torch.float32, device=device)
    x_range = torch.arange(t_x, dtype=torch.float32, device=device).view(1, -1)

    # forward dynamic program over the y axis
    for j in range(t_y):
        # v0: best score ending one x earlier (shifted right); v1: same x
        v0 = pad_tensor(v, ((0, 0), (1, 0)), value = const)[:, :-1]
        v1 = v
        max_mask = v1 >= v0
        v_max = torch.where(max_mask, v1, v0)
        direction[:, :, j] = max_mask
        # only x indices <= j are reachable after j monotonic steps
        index_mask = x_range <= j
        v = torch.where(index_mask.view(1,-1), v_max + value[:, :, j], const)

    # outside the mask, force direction to "stay" (1)
    direction = torch.where(mask.bool(), direction, 1)

    # backtrack starting at the last valid x index of each batch element
    path = torch.zeros(value.shape, dtype=torch.float32, device=device)
    index = mask[:, :, 0].sum(1).long() - 1
    index_range = torch.arange(b, device=device)
    for j in reversed(range(t_y)):
        path[index_range, index, j] = 1
        # direction 1 keeps the same x, direction 0 steps back to x-1
        index = index + direction[index_range, index, j] - 1
    path = path * mask.float()
    path = path.to(dtype=dtype)
    return path
# forward-sum (CTC) alignment loss
class ForwardSumLoss(Module):
    """CTC forward-sum alignment loss.

    Prepends a blank column to the attention log-probabilities and scores the
    strictly increasing key sequence 1..max_key_len with CTC, pushing
    probability mass onto monotonic alignments.
    """
    def __init__(
        self,
        blank_logprob = -1  # log-probability assigned to the inserted blank column
    ):
        super().__init__()
        self.blank_logprob = blank_logprob

        # blank id 0 corresponds to the column padded in below
        self.ctc_loss = torch.nn.CTCLoss(
            blank = 0,
            zero_infinity = True  # convert infinite losses (impossible alignments) to zero
        )

    def forward(self, attn_logprob, key_lens, query_lens):
        """attn_logprob: (batch, 1, query_len, key_len) attention scores."""
        device, blank_logprob = attn_logprob.device, self.blank_logprob
        max_key_len = attn_logprob.size(-1)

        # rearrange input to [query_len, batch, key_len] — CTC expects time-first
        attn_logprob = rearrange(attn_logprob, 'b 1 c t -> c b t')

        # add the blank label column at key index 0
        attn_logprob = F.pad(attn_logprob, (1, 0, 0, 0, 0, 0), value = blank_logprob)

        # mask out keys beyond each sample's key length, then normalize to log-probs
        mask_value = -torch.finfo(attn_logprob.dtype).max
        attn_logprob.masked_fill_(torch.arange(max_key_len + 1, device=device, dtype=torch.long).view(1, 1, -1) > key_lens.view(1, -1, 1), mask_value)
        attn_logprob = attn_logprob.log_softmax(dim = -1)

        # target sequences: every key index once, in order
        target_seqs = torch.arange(1, max_key_len + 1, device=device, dtype=torch.long)
        target_seqs = repeat(target_seqs, 'n -> b n', b = key_lens.numel())

        # CTC loss over the alignment lattice
        cost = self.ctc_loss(attn_logprob, target_seqs, query_lens, key_lens)
        return cost
class BinLoss(Module):
    """Batch-averaged sum of soft attention log-probabilities at the positions
    selected by the hard alignment (binarization objective)."""

    def forward(self, attn_hard, attn_logprob, key_lens):
        batch = attn_logprob.shape[0]
        device = attn_logprob.device
        max_key_len = attn_logprob.size(-1)

        # bring both tensors to the same [query_len, batch, key_len] layout
        attn_logprob = rearrange(attn_logprob, 'b 1 c t -> c b t')
        attn_hard = rearrange(attn_hard, 'b t c -> c b t')

        # mask out keys beyond each sample's key length before normalizing
        key_range = torch.arange(max_key_len, device = device, dtype = torch.long)
        out_of_range = key_range.view(1, 1, -1) > key_lens.view(1, -1, 1)
        attn_logprob.masked_fill_(out_of_range, -torch.finfo(attn_logprob.dtype).max)

        log_probs = attn_logprob.log_softmax(dim = -1)

        # weight the log-probs by the hard alignment and average over the batch
        return (attn_hard * log_probs).sum() / batch
class Aligner(Module):
    """Wraps AlignerNet and converts its soft attention into hard durations
    via monotonic maximum-path search."""

    def __init__(
        self,
        dim_in,
        dim_hidden,
        attn_channels=80,
        temperature=0.0005
    ):
        super().__init__()
        self.dim_in = dim_in
        self.dim_hidden = dim_hidden
        self.attn_channels = attn_channels
        self.temperature = temperature

        self.aligner = AlignerNet(
            dim_in = self.dim_in,
            dim_hidden = self.dim_hidden,
            attn_channels = self.attn_channels,
            temperature = self.temperature
        )

    def forward(
        self,
        x,
        x_mask,
        y,
        y_mask
    ):
        # soft alignment between y (queries) and x transposed to (b, t, d) (keys)
        alignment_soft, alignment_logprob = self.aligner(y, rearrange(x, 'b d t -> b t d'), x_mask)

        # outer product of the two masks gives the valid (i, j) attention region
        x_mask = rearrange(x_mask, '... i -> ... i 1')
        y_mask = rearrange(y_mask, '... j -> ... 1 j')
        attn_mask = rearrange(x_mask * y_mask, 'b 1 i j -> b i j')

        # hard monotonic alignment from the soft attention
        alignment_soft = rearrange(alignment_soft, 'b 1 c t -> b t c')
        alignment_mask = maximum_path(alignment_soft, attn_mask)

        # durations: number of frames aligned to each text position
        alignment_hard = torch.sum(alignment_mask, -1).int()
        return alignment_hard, alignment_soft, alignment_logprob, alignment_mask
if __name__ == '__main__':
    # smoke test: run the Aligner on random inputs
    batch_size = 10
    seq_len_y = 200   # length of sequence y (mel frames)
    seq_len_x = 35    # length of sequence x (text)
    feature_dim = 80  # mel feature dimension

    x = torch.randn(batch_size, 512, seq_len_x)
    y = torch.randn(batch_size, seq_len_y, feature_dim).transpose(1, 2)  # dim 1 is the channels for conv

    # all-ones masks (no padding)
    x_mask = torch.ones(batch_size, 1, seq_len_x)
    y_mask = torch.ones(batch_size, 1, seq_len_y)

    align = Aligner(dim_in = 80, dim_hidden = 512, attn_channels = 80)
    alignment_hard, alignment_soft, alignment_logprob, alignment_mask = align(x, x_mask, y, y_mask)
.\lucidrains\naturalspeech2-pytorch\naturalspeech2_pytorch\attend.py
# 导入必要的库
from collections import namedtuple
from functools import wraps
from packaging import version
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange
# configuration flags handed to torch's scaled-dot-product-attention kernel selector
Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

def exists(val):
    """Return True when `val` is not None."""
    return not (val is None)
def once(fn):
    """Decorator: run `fn` only on the first call; later calls return None."""
    called = False

    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)

    return inner

# print helper that only ever prints once
print_once = once(print)
# main Attend class
class Attend(nn.Module):
    """Scaled dot-product attention with optional causal masking, optionally
    dispatching to torch 2.0 flash / memory-efficient kernels."""
    def __init__(
        self,
        dropout = 0.,
        causal = False,
        use_flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        # lazily-built causal mask cache
        self.register_buffer("mask", None, persistent=False)

        self.use_flash = use_flash
        assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # pick efficient-attention backend configs for cpu and cuda
        self.cpu_config = Config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not use_flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = Config(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = Config(False, True, True)

    def get_mask(self, n, device):
        """Return (and cache) an upper-triangular causal mask of size n x n."""
        if exists(self.mask) and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def flash_attn(self, q, k, v, mask = None):
        """Attention via torch's scaled_dot_product_attention kernels."""
        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

        # single-headed key / values: broadcast them across the query heads
        if k.ndim == 3:
            k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)

        if v.ndim == 3:
            v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)

        # expand the padding mask to the (b, h, q_len, k_len) shape sdp expects
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        # choose the backend config for the tensor's device
        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 scaled-dot-product attention under the chosen backends
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.,
                is_causal = self.causal
            )

        return out

    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        n, device = q.shape[-2], q.device

        scale = q.shape[-1] ** -0.5

        # dispatch to the flash path when enabled
        if self.use_flash:
            return self.flash_attn(q, k, v, mask = mask)

        # keys / values may be shared across heads (3d) or per head (4d)
        kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

        # similarity
        sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale

        # key padding mask
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask
        if self.causal:
            causal_mask = self.get_mask(n, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention weights
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
        return out
.\lucidrains\naturalspeech2-pytorch\naturalspeech2_pytorch\naturalspeech2_pytorch.py
# 导入所需的库
import math
import copy
from multiprocessing import cpu_count
from pathlib import Path
from random import random
from functools import partial
from collections import namedtuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, einsum, Tensor
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
import torchaudio
import torchaudio.transforms as T
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from audiolm_pytorch import SoundStream, EncodecWrapper
from audiolm_pytorch.data import SoundDataset, get_dataloader
from beartype import beartype
from beartype.typing import Tuple, Union, Optional, List
from beartype.door import is_bearable
from naturalspeech2_pytorch.attend import Attend
from naturalspeech2_pytorch.aligner import Aligner, ForwardSumLoss, BinLoss
from naturalspeech2_pytorch.utils.tokenizer import Tokenizer, ESpeak
from naturalspeech2_pytorch.utils.utils import average_over_durations, create_mask
from naturalspeech2_pytorch.version import __version__
from accelerate import Accelerator
from ema_pytorch import EMA
from tqdm.auto import tqdm
import pyworld as pw
# constants / aliases

mlist = nn.ModuleList

def Sequential(*mods):
    """nn.Sequential that silently drops any None entries."""
    kept = [mod for mod in mods if exists(mod)]
    return nn.Sequential(*kept)
# helper functions

def exists(x):
    """Return True when `x` is not None."""
    return not (x is None)
def default(val, d):
    """Return `val` when present, otherwise the fallback `d` (invoked if callable)."""
    if val is not None:
        return val
    return d() if callable(d) else d
def divisible_by(num, den):
    """True when `num` is an exact multiple of `den`."""
    return (num % den) == 0

def identity(t, *args, **kwargs):
    """Return `t` unchanged, ignoring any extra arguments."""
    return t

def has_int_squareroot(num):
    """True when `num` is a perfect square (float sqrt round-trip check)."""
    root = math.sqrt(num)
    return (root ** 2) == num
# tensor helper functions

def pad_or_curtail_to_length(t, length):
    """Right-pad with zeros or truncate `t` along its last dim to exactly `length`."""
    current = t.shape[-1]
    if current == length:
        return t
    if current > length:
        return t[..., :length]
    return F.pad(t, (0, length - current))
def prob_mask_like(shape, prob, device):
    """Boolean mask of `shape` where each element is True with probability `prob`."""
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    # sample uniforms and threshold
    return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
def generate_mask_from_repeats(repeats):
    """Build a (batch, num_items, max_len) mask expanding each item to its repeat count.

    Row i is True exactly over the output slots [cumsum_exclusive[i], cumsum[i]).
    """
    repeats = repeats.int()
    device = repeats.device

    lengths = repeats.sum(dim=-1)
    max_length = lengths.amax().item()
    cumsum = repeats.cumsum(dim=-1)
    # cumulative sum shifted right by one: start offset of each item
    cumsum_exclusive = F.pad(cumsum, (1, -1), value=0.)

    seq = torch.arange(max_length, device=device)
    seq = repeat(seq, '... j -> ... i j', i=repeats.shape[-1])

    cumsum = rearrange(cumsum, '... i -> ... i 1')
    cumsum_exclusive = rearrange(cumsum_exclusive, '... i -> ... i 1')
    lengths = rearrange(lengths, 'b -> b 1 1')

    # slot j belongs to item i iff start_i <= j < end_i (and j within total length)
    mask = (seq < cumsum) & (seq >= cumsum_exclusive) & (seq < lengths)
    return mask
# sinusoidal positional embedding

class LearnedSinusoidalPosEmb(nn.Module):
    """Sinusoidal position embedding with learned (random-init) frequencies.

    Maps a (batch,) vector of positions to (batch, dim + 1): the raw position
    followed by sin/cos features at the learned frequencies.
    """
    def __init__(self, dim):
        super().__init__()
        assert divisible_by(dim, 2)
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x):
        positions = x.unsqueeze(-1)
        freqs = positions * self.weights.unsqueeze(0) * 2 * math.pi
        features = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        return torch.cat((positions, features), dim=-1)
# pitch computation

def compute_pitch_pytorch(wav, sample_rate):
    """Pitch track via torchaudio's kaldi pitch feature (the nfcc channel is discarded)."""
    pitch_and_nfcc = torchaudio.functional.compute_kaldi_pitch(wav, sample_rate)
    pitch, _nfcc = pitch_and_nfcc.unbind(dim=-1)
    return pitch
# pitch computed with pyworld, as in the paper
def compute_pitch_pyworld(wav, sample_rate, hop_length, pitch_fmax=640.0):
    """Extract per-frame f0 with pyworld (dio + stonemask refinement).

    Accepts a numpy array or tensor of audio samples; returns one f0 track
    per sample, as a tensor on the original device if the input was a tensor.
    """
    is_tensor_input = torch.is_tensor(wav)

    if is_tensor_input:
        device = wav.device
        wav = wav.contiguous().cpu().numpy()

    # pad with half a hop (reflect) when the length divides evenly
    # NOTE(review): for batched 2d input, len(wav) is the batch dim and
    # np.pad with (0, n) pads every axis — confirm this is intended
    if divisible_by(len(wav), hop_length):
        wav = np.pad(wav, (0, hop_length // 2), mode="reflect")

    # pyworld requires float64 input
    wav = wav.astype(np.double)

    outs = []

    for sample in wav:
        # coarse f0 estimate and its frame times
        f0, t = pw.dio(
            sample,
            fs = sample_rate,
            f0_ceil = pitch_fmax,
            frame_period = 1000 * hop_length / sample_rate,
        )
        # refine the coarse estimate
        f0 = pw.stonemask(sample, f0, t, sample_rate)
        outs.append(f0)

    outs = np.stack(outs)

    # restore tensor type / device when the input was a tensor
    if is_tensor_input:
        outs = torch.from_numpy(outs).to(device)

    return outs
def f0_to_coarse(f0, f0_bin = 256, f0_max = 1100.0, f0_min = 50.0):
    """Quantize f0 values (Hz) into integer bins [1, f0_bin - 1] on the mel scale."""
    # mel-scale endpoints of the quantization range
    mel_max = 1127 * torch.log(1 + torch.tensor(f0_max) / 700)
    mel_min = 1127 * torch.log(1 + torch.tensor(f0_min) / 700)

    # mel value of each input frequency
    mel = 1127 * (1 + f0 / 700).log()

    # linearly map positive mel values into [1, f0_bin - 1], clamping the ends
    positive = mel > 0
    mel[positive] = (mel[positive] - mel_min) * (f0_bin - 2) / (mel_max - mel_min) + 1
    mel[mel <= 1] = 1
    mel[mel > f0_bin - 1] = f0_bin - 1

    # round to the nearest bin
    coarse = (mel + 0.5).int()
    assert coarse.max() <= 255 and coarse.min() >= 1, (coarse.max(), coarse.min())
    return coarse
# peripheral models

# audio to mel
class AudioToMel(nn.Module):
    """Compute a (log-)mel spectrogram from raw audio using torchaudio transforms."""
    def __init__(
        self,
        *,
        n_mels = 100,           # number of mel bins
        sampling_rate = 24000,  # expected input sample rate
        f_max = 8000,           # top mel frequency
        n_fft = 1024,           # FFT size
        win_length = 640,       # analysis window length
        hop_length = 160,       # hop between frames
        log = True              # convert amplitudes to decibels
    ):
        super().__init__()
        self.log = log
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.f_max = f_max
        self.win_length = win_length
        self.hop_length = hop_length
        self.sampling_rate = sampling_rate

    def forward(self, audio):
        # NOTE(review): both transforms are constructed on every forward call;
        # hoisting them to __init__ would avoid rebuilding filterbanks each time
        stft_transform = T.Spectrogram(
            n_fft = self.n_fft,
            win_length = self.win_length,
            hop_length = self.hop_length,
            window_fn = torch.hann_window
        )

        # linear-frequency spectrogram
        spectrogram = stft_transform(audio)

        mel_transform = T.MelScale(
            n_mels = self.n_mels,
            sample_rate = self.sampling_rate,
            n_stft = self.n_fft // 2 + 1,
            f_max = self.f_max
        )

        # project onto the mel scale
        mel = mel_transform(spectrogram)

        # optional amplitude -> decibel conversion
        if self.log:
            mel = T.AmplitudeToDB()(mel)

        return mel
# phoneme - pitch - speech prompt - duration predictors

class PhonemeEncoder(nn.Module):
    """Embed phoneme token ids (or raw strings via a tokenizer) and encode them
    with a causal conv followed by a transformer."""
    def __init__(
        self,
        *,
        tokenizer: Optional[Tokenizer] = None,  # optional tokenizer enabling raw-string input
        num_tokens = None,                      # vocab size; falls back to tokenizer.vocab_size
        dim = 512,
        dim_hidden = 512,
        kernel_size = 9,
        depth = 6,
        dim_head = 64,
        heads = 8,
        conv_dropout = 0.2,
        attn_dropout = 0.,
        use_flash = False
    ):
        super().__init__()
        self.tokenizer = tokenizer
        num_tokens = default(num_tokens, tokenizer.vocab_size if exists(tokenizer) else None)

        # one extra embedding row serves as the padding token
        self.token_emb = nn.Embedding(num_tokens + 1, dim) if exists(num_tokens) else nn.Identity()
        self.pad_id = num_tokens

        same_padding = (kernel_size - 1) // 2  # NOTE(review): appears unused

        # causal conv front-end (channel-first internally)
        self.conv = nn.Sequential(
            Rearrange('b n c -> b c n'),
            CausalConv1d(dim, dim_hidden, kernel_size),
            nn.SiLU(),
            nn.Dropout(conv_dropout),
            Rearrange('b c n -> b n c'),
        )

        self.transformer = Transformer(
            dim = dim_hidden,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            dropout = attn_dropout,
            use_flash = use_flash
        )

    @beartype
    def forward(
        self,
        x: Union[Tensor, List[str]],
        mask = None
    ):
        # tokenize raw strings when a tokenizer is available
        if is_bearable(x, List[str]):
            assert exists(self.tokenizer)
            x = self.tokenizer.texts_to_tensor_ids(x)

        # negative ids mark padding; remap them to the dedicated pad embedding
        is_padding = x < 0
        x = x.masked_fill(is_padding, self.pad_id)

        x = self.token_emb(x)
        x = self.conv(x)
        x = self.transformer(x, mask = mask)
        return x
class SpeechPromptEncoder(nn.Module):
    """Encode speech-prompt codebook features through a conv stack and a transformer."""

    @beartype
    def __init__(
        self,
        dim_codebook,  # feature dim of the incoming codec codebook
        dims: Tuple[int] = (256, 2048, 2048, 2048, 2048, 512, 512, 512),
        *,
        depth = 6,
        heads = 8,
        dim_head = 64,
        dropout = 0.2,
        kernel_size = 9,
        padding = 4,
        use_flash_attn = True
    ):
        super().__init__()

        # prepend the codebook dim so the conv stack starts from the input width
        dims = [dim_codebook, *dims]

        self.dim, self.dim_out = dims[0], dims[-1]

        # consecutive (in, out) dimension pairs for the conv stack
        dim_pairs = zip(dims[:-1], dims[1:])

        modules = []
        # a Conv1d + SiLU stage for every dimension pair
        for dim_in, dim_out in dim_pairs:
            modules.extend([
                nn.Conv1d(dim_in, dim_out, kernel_size, padding = padding),
                nn.SiLU()
            ])

        # channel-first conv stack wrapped with transposes
        self.conv = nn.Sequential(
            Rearrange('b n c -> b c n'),
            *modules,
            Rearrange('b c n -> b n c')
        )

        self.transformer = Transformer(
            dim = dims[-1],
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            dropout = dropout,
            use_flash = use_flash_attn
        )

    def forward(self, x):
        # input feature dim must match the codebook dim the conv stack expects
        assert x.shape[-1] == self.dim
        x = self.conv(x)
        x = self.transformer(x)
        return x
class Block(nn.Module):
    """Conv1d -> GroupNorm -> SiLU -> Dropout, preserving the sequence length.

    Operates on channel-first input of shape (batch, dim, n).
    """
    def __init__(
        self,
        dim,
        dim_out,
        kernel = 3,
        groups = 8,
        dropout = 0.
    ):
        super().__init__()
        # same-padding conv mapping dim -> dim_out
        self.proj = nn.Conv1d(dim, dim_out, kernel, padding = kernel // 2)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # (batch, dim, n) -> (batch, dim_out, n)
        return self.dropout(self.act(self.norm(self.proj(x))))
class ResnetBlock(nn.Module):
    """Stack of `num_convs` Blocks with a residual 1x1-conv (or identity) shortcut.

    Accepts channel-last input (batch, n, channels), transposing internally.
    """
    def __init__(
        self,
        dim,
        dim_out,
        kernel,
        *,
        dropout = 0.,
        groups = 8,
        num_convs = 2
    ):
        super().__init__()
        layers = []
        for index in range(num_convs):
            # first Block maps dim -> dim_out, the rest stay at dim_out
            in_dim = dim if index == 0 else dim_out
            layers.append(Block(
                in_dim,
                dim_out,
                kernel,
                groups = groups,
                dropout = dropout
            ))
        self.blocks = nn.Sequential(*layers)

        # match channels on the shortcut when in/out dims differ
        self.res_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x):
        x = rearrange(x, 'b n c -> b c n')
        out = self.blocks(x) + self.res_conv(x)
        return rearrange(out, 'b c n -> b n c')
def ConvBlock(dim, dim_out, kernel, dropout = 0.):
    """Channel-last conv block: transpose, same-padding Conv1d, SiLU, dropout, transpose back."""
    layers = [
        Rearrange('b n c -> b c n'),
        nn.Conv1d(dim, dim_out, kernel, padding = kernel // 2),
        nn.SiLU(),
        nn.Dropout(dropout),
        Rearrange('b c n -> b n c'),
    ]
    return nn.Sequential(*layers)
class DurationPitchPredictorTrunk(nn.Module):
    """Shared trunk for duration / pitch prediction: conv blocks interleaved with
    prompt cross-attention, ending in a per-position non-negative scalar."""
    def __init__(
        self,
        dim = 512,
        depth = 10,
        kernel_size = 3,
        dim_context = None,               # dim of the encoded prompts for cross-attention
        heads = 8,
        dim_head = 64,
        dropout = 0.2,
        use_resnet_block = True,          # ResnetBlock vs plain ConvBlock stages
        num_convs_per_resnet_block = 2,
        num_convolutions_per_block = 3,
        use_flash_attn = False,
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        # conv stage constructor: plain ConvBlock or a ResnetBlock
        conv_klass = ConvBlock if not use_resnet_block else partial(ResnetBlock, num_convs = num_convs_per_resnet_block)

        for _ in range(depth):
            # each layer: conv stack, pre-attention norm, prompt cross-attention
            layer = nn.ModuleList([
                nn.Sequential(*[
                    conv_klass(dim, dim, kernel_size) for _ in range(num_convolutions_per_block)
                ]),
                RMSNorm(dim),
                Attention(
                    dim,
                    dim_context = dim_context,
                    heads = heads,
                    dim_head = dim_head,
                    dropout = dropout,
                    use_flash = use_flash_attn,
                    cross_attn_include_queries = True
                )
            ])
            self.layers.append(layer)

        # project each position to a single non-negative scalar
        self.to_pred = nn.Sequential(
            nn.Linear(dim, 1),
            Rearrange('... 1 -> ...'),
            nn.ReLU()
        )

    def forward(
        self,
        x,
        encoded_prompts,   # prompt features attended to by every layer
        prompt_mask = None,
    ):
        for conv, norm, attn in self.layers:
            x = conv(x)
            # pre-norm cross-attention to the prompts, with residual
            x = attn(norm(x), encoded_prompts, mask = prompt_mask) + x
        return self.to_pred(x)
# duration and pitch predictor
class DurationPitchPredictor(nn.Module):
    """Predict per-phoneme duration and pitch from phoneme ids (or raw strings),
    conditioned on encoded speech prompts.

    Fix: the extracted source contained a garbled, duplicated constructor
    (a stray `# 略`, a second `):` and a second `super().__init__()`); the
    signature and body are merged back into a single `__init__` here.
    """
    def __init__(
        self,
        *,
        dim,
        num_phoneme_tokens = None,
        tokenizer: Optional[Tokenizer] = None,
        dim_encoded_prompts = None,
        num_convolutions_per_block = 3,
        use_resnet_block = True,
        num_convs_per_resnet_block = 2,
        depth = 10,
        kernel_size = 3,
        heads = 8,
        dim_head = 64,
        dim_hidden = 512,
        dropout = 0.2,
        use_flash_attn = False
    ):
        super().__init__()
        self.tokenizer = tokenizer

        # infer the phoneme vocabulary size from the tokenizer when not given
        num_phoneme_tokens = default(num_phoneme_tokens, tokenizer.vocab_size if exists(tokenizer) else None)

        dim_encoded_prompts = default(dim_encoded_prompts, dim)

        # identity passthrough when inputs are already embedded
        self.phoneme_token_emb = nn.Embedding(num_phoneme_tokens, dim) if exists(num_phoneme_tokens) else nn.Identity()

        self.to_pitch_pred = DurationPitchPredictorTrunk(
            dim = dim_hidden,
            depth = depth,
            kernel_size = kernel_size,
            dim_context = dim_encoded_prompts,
            heads = heads,
            dim_head = dim_head,
            dropout = dropout,
            use_resnet_block = use_resnet_block,
            num_convs_per_resnet_block = num_convs_per_resnet_block,
            num_convolutions_per_block = num_convolutions_per_block,
            use_flash_attn = use_flash_attn,
        )

        # the duration predictor is an identical, independently-trained copy
        self.to_duration_pred = copy.deepcopy(self.to_pitch_pred)

    @beartype
    def forward(
        self,
        x: Union[Tensor, List[str]],
        encoded_prompts,
        prompt_mask = None
    ):
        """Return (duration_pred, pitch_pred).

        x: phoneme id tensor, or a list of raw strings (requires a tokenizer).
        """
        if is_bearable(x, List[str]):
            assert exists(self.tokenizer)
            x = self.tokenizer.texts_to_tensor_ids(x)

        x = self.phoneme_token_emb(x)

        duration_pred, pitch_pred = map(
            lambda fn: fn(x, encoded_prompts = encoded_prompts, prompt_mask = prompt_mask),
            (self.to_duration_pred, self.to_pitch_pred)
        )

        return duration_pred, pitch_pred
# perceiver resampler from the flamingo paper — in place of "q-k-v" attention,
# m learned queries (latents) attend into the conditioning network's outputs
class PerceiverResampler(nn.Module):
    """Compress variable-length context into `num_latents` summary tokens."""
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_context = None,
        num_latents = 64, # m in the paper
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        use_flash_attn = False
    ):
        super().__init__()
        dim_context = default(dim_context, dim)

        # only project the context when its dim differs from the latent dim
        self.proj_context = nn.Identity() if dim_context == dim else nn.Linear(dim_context, dim)

        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        nn.init.normal_(self.latents, std = 0.02)

        blocks = []
        for _ in range(depth):
            blocks.append(nn.ModuleList([
                Attention(
                    dim = dim,
                    dim_head = dim_head,
                    heads = heads,
                    use_flash = use_flash_attn,
                    cross_attn_include_queries = True
                ),
                FeedForward(dim = dim, mult = ff_mult)
            ]))
        self.layers = nn.ModuleList(blocks)

        self.norm = RMSNorm(dim)

    def forward(self, x, mask = None):
        """Cross-attend latents into `x`, returning (batch, num_latents, dim)."""
        context = self.proj_context(x)
        latents = repeat(self.latents, 'n d -> b n d', b = context.shape[0])

        for attn, ff in self.layers:
            latents = latents + attn(latents, context, mask = mask)
            latents = latents + ff(latents)

        return self.norm(latents)
# model, i.e. wavenet + transformer

class CausalConv1d(nn.Conv1d):
    """Conv1d that left-pads with zeros so outputs never depend on future timesteps.

    Only stride 1 is supported (asserted), since striding would break causality.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        (kernel_size,) = self.kernel_size
        (dilation,) = self.dilation
        (stride,) = self.stride

        assert stride == 1

        # left padding that exactly compensates the conv's receptive field
        self.causal_padding = dilation * (kernel_size - 1)

    def forward(self, x):
        padded = F.pad(x, (self.causal_padding, 0), value = 0.)
        return super().forward(padded)
class WavenetResBlock(nn.Module):
    """Dilated causal conv block with a gated activation, a residual path,
    an optional skip output and optional FiLM-style time conditioning."""
    def __init__(
        self,
        dim,
        *,
        dilation,
        kernel_size = 3,
        skip_conv = False,
        dim_cond_mult = None
    ):
        super().__init__()

        self.cond = exists(dim_cond_mult)
        # projects the time embedding into a (scale, shift) pair
        self.to_time_cond = nn.Linear(dim * dim_cond_mult, dim * 2) if self.cond else None

        self.conv = CausalConv1d(dim, dim, kernel_size, dilation = dilation)
        self.res_conv = CausalConv1d(dim, dim, 1)
        self.skip_conv = CausalConv1d(dim, dim, 1) if skip_conv else None

    def forward(self, x, t = None):
        t_gamma = t_beta = None

        if self.cond:
            assert exists(t)
            cond = rearrange(self.to_time_cond(t), 'b c -> b c 1')
            t_gamma, t_beta = cond.chunk(2, dim = -2)

        residual = self.res_conv(x)
        x = self.conv(x)

        if self.cond:
            x = x * t_gamma + t_beta

        # gated activation followed by the residual add
        x = x.tanh() * x.sigmoid() + residual

        skip = self.skip_conv(x) if exists(self.skip_conv) else None
        return x, skip
class WavenetStack(nn.Module):
    """A stack of wavenet res-blocks with exponentially growing dilations (1, 2, 4, ...)."""
    def __init__(
        self,
        dim,
        *,
        layers,
        kernel_size = 3,
        has_skip = False,
        dim_cond_mult = None
    ):
        super().__init__()
        self.has_skip = has_skip
        self.blocks = mlist([])

        for dilation in (2 ** torch.arange(layers)).tolist():
            self.blocks.append(WavenetResBlock(
                dim = dim,
                kernel_size = kernel_size,
                dilation = dilation,
                skip_conv = has_skip,
                dim_cond_mult = dim_cond_mult
            ))

    def forward(self, x, t):
        """Run every block; `x` may be one tensor (broadcast to all blocks)
        or a list of per-block inputs from a previous stack."""
        # broadcast a single tensor to every block
        if isinstance(x, Tensor):
            x = (x,) * len(self.blocks)

        residuals, skips = [], []
        for block_input, block in zip(x, self.blocks):
            residual, skip = block(block_input, t)
            residuals.append(residual)
            skips.append(skip)

        # the skip-collecting (final) stack returns stacked skips,
        # intermediate stacks pass per-block residuals through
        return torch.stack(skips) if self.has_skip else residuals
class Wavenet(nn.Module):
    """Initial causal conv -> several WavenetStacks (only the last collects skips)
    -> sum of skips -> final 1x1 causal conv."""
    def __init__(
        self,
        dim,
        *,
        stacks,
        layers,
        init_conv_kernel = 3,
        dim_cond_mult = None
    ):
        super().__init__()
        self.init_conv = CausalConv1d(dim, dim, init_conv_kernel)
        self.stacks = mlist([])

        for ind in range(stacks):
            # only the last stack produces skip connections
            self.stacks.append(WavenetStack(
                dim,
                layers = layers,
                dim_cond_mult = dim_cond_mult,
                has_skip = ind == (stacks - 1)
            ))

        self.final_conv = CausalConv1d(dim, dim, 1)

    def forward(self, x, t = None):
        x = self.init_conv(x)

        for stack in self.stacks:
            x = stack(x, t)

        # x is the stacked skip tensor from the last stack; sum over the stack dim
        return self.final_conv(x.sum(dim = 0))
class RMSNorm(nn.Module):
    """RMS normalization with an optional learned scale and an optional
    adaptive (condition-driven) scale/shift."""
    def __init__(self, dim, scale = True, dim_cond = None):
        super().__init__()
        self.cond = exists(dim_cond)
        # projects the condition into per-feature (gamma, beta)
        self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None

        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim)) if scale else None

    def forward(self, x, cond = None):
        normed = F.normalize(x, dim = -1) * self.scale * default(self.gamma, 1)

        if not self.cond:
            return normed

        assert exists(cond)
        gamma, beta = self.to_gamma_beta(cond).chunk(2, dim = -1)
        gamma = rearrange(gamma, 'b d -> b 1 d')
        beta = rearrange(beta, 'b d -> b 1 d')
        return normed * gamma + beta
class ConditionableTransformer(nn.Module):
    """Pre-norm transformer whose RMSNorms can be adaptively conditioned on a
    time embedding, with optional cross-attention into a context."""
    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        ff_causal_conv = False,
        dim_cond_mult = None,
        cross_attn = False,
        use_flash = False
    ):
        super().__init__()
        self.dim = dim
        self.layers = mlist([])

        cond = exists(dim_cond_mult)

        # when conditioned, the norms take adaptive scale/shift instead of a learned gamma
        norm_kwargs = dict(scale = not cond, dim_cond = dim * dim_cond_mult) if cond else dict()
        make_norm = partial(RMSNorm, **norm_kwargs)

        for _ in range(depth):
            self.layers.append(mlist([
                make_norm(dim),
                Attention(dim = dim, dim_head = dim_head, heads = heads, use_flash = use_flash),
                make_norm(dim) if cross_attn else None,
                Attention(dim = dim, dim_head = dim_head, heads = heads, use_flash = use_flash) if cross_attn else None,
                make_norm(dim),
                FeedForward(dim = dim, mult = ff_mult, causal_conv = ff_causal_conv)
            ]))

        self.to_pred = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, dim, bias = False)
        )

    def forward(
        self,
        x,
        times = None,
        context = None
    ):
        t = times

        for attn_norm, attn, cross_attn_norm, cross_attn, ff_norm, ff in self.layers:
            x = attn(attn_norm(x, cond = t)) + x

            if exists(cross_attn):
                assert exists(context)
                x = cross_attn(cross_attn_norm(x, cond = t), context = context) + x

            x = ff(ff_norm(x, cond = t)) + x

        return self.to_pred(x)
class Model(nn.Module):
    """Denoising network for the diffusion: a WaveNet front-end followed by a
    transformer, optionally conditioned on a speech prompt and an aligned
    phoneme/pitch condition, with learned null conditions for
    classifier-free guidance.
    """

    @beartype
    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        wavenet_layers = 8,
        wavenet_stacks = 4,
        dim_cond_mult = 4,
        use_flash_attn = True,
        dim_prompt = None,
        num_latents_m = 32, # number of latents to be perceiver resampled ('q-k-v' with 'm' queries in the paper)
        resampler_depth = 2,
        cond_drop_prob = 0.,
        condition_on_prompt= False
    ):
        super().__init__()
        self.dim = dim

        # time condition

        dim_time = dim * dim_cond_mult

        # learned sinusoidal embedding -> linear -> SiLU produces the time conditioning vector
        self.to_time_cond = Sequential(
            LearnedSinusoidalPosEmb(dim),
            nn.Linear(dim + 1, dim_time),
            nn.SiLU()
        )

        # prompt condition

        self.cond_drop_prob = cond_drop_prob # for classifier free guidance
        self.condition_on_prompt = condition_on_prompt
        self.to_prompt_cond = None

        if self.condition_on_prompt:
            # learned null embeddings substituted in when the prompt is dropped
            self.null_prompt_cond = nn.Parameter(torch.randn(dim_time))
            self.null_prompt_tokens = nn.Parameter(torch.randn(num_latents_m, dim))
            nn.init.normal_(self.null_prompt_cond, std = 0.02)
            nn.init.normal_(self.null_prompt_tokens, std = 0.02)

            # mean-pool the prompt into a single conditioning vector
            self.to_prompt_cond = Sequential(
                Reduce('b n d -> b d', 'mean'),
                nn.Linear(dim_prompt, dim_time),
                nn.SiLU()
            )

            # resample the prompt into a fixed number of tokens for cross-attention
            self.perceiver_resampler = PerceiverResampler(
                dim = dim,
                dim_context = dim_prompt,
                num_latents = num_latents_m,
                depth = resampler_depth,
                dim_head = dim_head,
                heads = heads,
                use_flash_attn = use_flash_attn
            )

        # aligned conditioning from the aligner and duration module

        self.null_cond = None
        self.cond_to_model_dim = None

        if self.condition_on_prompt:
            # 1x1 conv maps the aligned condition into the model dimension
            self.cond_to_model_dim = nn.Conv1d(dim_prompt, dim, 1)
            self.null_cond = nn.Parameter(torch.zeros(dim, 1))

        # conditioning includes time and optionally prompt

        dim_cond_mult = dim_cond_mult * (2 if condition_on_prompt else 1)

        # wavenet

        self.wavenet = Wavenet(
            dim = dim,
            stacks = wavenet_stacks,
            layers = wavenet_layers,
            dim_cond_mult = dim_cond_mult
        )

        # transformer

        self.transformer = ConditionableTransformer(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            ff_mult = ff_mult,
            ff_causal_conv = True,
            dim_cond_mult = dim_cond_mult,
            use_flash = use_flash_attn,
            cross_attn = condition_on_prompt
        )

    @property
    def device(self):
        # device the model's parameters live on
        return next(self.parameters()).device

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 1.,
        **kwargs
    ):
        """Classifier-free guidance: blend conditional and unconditional outputs."""
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)

        if cond_scale == 1.:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    def forward(
        self,
        x,
        times,
        prompt = None,
        prompt_mask = None,
        cond = None,
        cond_drop_prob = None
    ):
        """Denoise latents `x` at diffusion `times`, optionally conditioned on
        `prompt` (encoded speech prompt) and `cond` (aligned phoneme/pitch)."""
        b = x.shape[0]

        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

        # prepare time conditioning
        # (original note: the drop probability should eventually be moved out of forward)

        t = self.to_time_cond(times)
        c = None

        if exists(self.to_prompt_cond):
            assert exists(prompt)

            # per-sample mask deciding which prompts get dropped (classifier-free guidance)
            prompt_cond_drop_mask = prob_mask_like((b,), cond_drop_prob, self.device)

            prompt_cond = self.to_prompt_cond(prompt)

            # swap in the learned null prompt condition where dropped
            prompt_cond = torch.where(
                rearrange(prompt_cond_drop_mask, 'b -> b 1'),
                self.null_prompt_cond,
                prompt_cond,
            )

            # concat time and prompt conditions
            t = torch.cat((t, prompt_cond), dim = -1)

            # resample the prompt into tokens for cross-attention
            resampled_prompt_tokens = self.perceiver_resampler(prompt, mask = prompt_mask)

            # swap in null prompt tokens where dropped
            c = torch.where(
                rearrange(prompt_cond_drop_mask, 'b -> b 1 1'),
                self.null_prompt_tokens,
                resampled_prompt_tokens
            )

        # rearrange to channel-first

        x = rearrange(x, 'b n d -> b d n')

        # sum the aligned condition onto the input sequence

        if exists(self.cond_to_model_dim):
            assert exists(cond)
            # map the condition into the model dimension
            cond = self.cond_to_model_dim(cond)

            # per-sample drop mask for the aligned condition
            cond_drop_mask = prob_mask_like((b,), cond_drop_prob, self.device)

            cond = torch.where(
                rearrange(cond_drop_mask, 'b -> b 1 1'),
                self.null_cond,
                cond
            )

            # for now, conform the condition to the length of the latent features

            cond = pad_or_curtail_to_length(cond, x.shape[-1])

            x = x + cond

        # main wavenet body

        x = self.wavenet(x, t)
        x = rearrange(x, 'b d n -> b n d')

        # transformer (cross-attends to resampled prompt tokens when conditioned)

        x = self.transformer(x, t, context = c)
        return x
# feedforward

class GEGLU(nn.Module):
    """Gated GELU: split the last dim in half, gate one half with GELU of the other."""
    def forward(self, x):
        out, gate = x.chunk(2, dim = -1)
        return out * F.gelu(gate)
# feedforward constructor
def FeedForward(dim, mult = 4, causal_conv = False):
    """GEGLU feedforward; the 2/3 factor keeps parameters comparable to a plain `mult`-wide MLP.

    Optionally inserts a depth-preserving causal conv between gate and output projection.
    Relies on the project `Sequential` helper, which drops None entries.
    """
    dim_inner = int(dim * mult * 2 / 3)

    maybe_conv = None
    if causal_conv:
        maybe_conv = nn.Sequential(
            Rearrange('b n d -> b d n'),
            CausalConv1d(dim_inner, dim_inner, 3),
            Rearrange('b d n -> b n d'),
        )

    layers = [
        nn.Linear(dim, dim_inner * 2),
        GEGLU(),
        maybe_conv,
        nn.Linear(dim_inner, dim)
    ]
    return Sequential(*layers)
# attention

class Attention(nn.Module):
    """Multi-head attention; self-attention when no context is given, cross-attention
    otherwise. With `cross_attn_include_queries`, the queries are prepended to the
    context so keys/values span both."""
    def __init__(
        self,
        dim,
        *,
        dim_context = None,
        causal = False,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        use_flash = False,
        cross_attn_include_queries = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.cross_attn_include_queries = cross_attn_include_queries

        dim_inner = dim_head * heads
        dim_context = default(dim_context, dim)

        self.attend = Attend(causal = causal, dropout = dropout, use_flash = use_flash)
        self.to_q = nn.Linear(dim, dim_inner, bias = False)
        self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias = False)
        self.to_out = nn.Linear(dim_inner, dim, bias = False)

    def forward(self, x, context = None, mask = None):
        has_context = exists(context)

        if not has_context:
            context = x
        elif self.cross_attn_include_queries:
            context = torch.cat((x, context), dim = -2)

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in (q, k, v))

        out = self.attend(q, k, v, mask = mask)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# transformer encoder

class Transformer(nn.Module):
    """Plain pre-norm transformer encoder: (norm -> attn -> residual, norm -> ff -> residual) x depth."""
    def __init__(
        self,
        dim,
        *,
        depth,
        causal = False,
        dim_head = 64,
        heads = 8,
        use_flash = False,
        dropout = 0.,
        ff_mult = 4,
        final_norm = False
    ):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                RMSNorm(dim),
                Attention(
                    dim,
                    causal = causal,
                    dim_head = dim_head,
                    heads = heads,
                    dropout = dropout,
                    use_flash = use_flash
                ),
                RMSNorm(dim),
                FeedForward(
                    dim,
                    mult = ff_mult
                )
            ])
            for _ in range(depth)
        ])

        self.norm = RMSNorm(dim) if final_norm else nn.Identity()

    def forward(self, x, mask = None):
        for attn_norm, attn, ff_norm, ff in self.layers:
            x = x + attn(attn_norm(x), mask = mask)
            x = x + ff(ff_norm(x))
        return self.norm(x)
# tensor helper functions

def log(t, eps = 1e-20):
    """Numerically safe log: clamp to `eps` before taking the log."""
    clamped = torch.clamp(t, min = eps)
    return torch.log(clamped)
# safe division (denominator clamped away from zero)
def safe_div(numer, denom):
    """Divide with the denominator clamped to at least 1e-10."""
    return numer / torch.clamp(denom, min = 1e-10)
# right-pad `t` with singleton dims until it matches `x`'s rank (for broadcasting)
def right_pad_dims_to(x, t):
    """Return `t` reshaped with trailing 1-dims so it has as many dims as `x`."""
    missing = x.ndim - t.ndim
    if missing <= 0:
        return t
    return t.view(*t.shape, *((1,) * missing))
# noise schedules

def simple_linear_schedule(t, clip_min = 1e-9):
    """Linear gamma schedule: gamma = 1 - t, clipped away from zero."""
    gamma = 1 - t
    return gamma.clamp(min = clip_min)
# 余弦调度函数
def cosine_schedule(t, start = 0, end = 1, tau = 1, clip_min = 1e-9):
power = 2 * tau
v_start = math.cos(start * math.pi / 2) ** power
v_end = math.cos(end * math.pi / 2) ** power
output = math.cos((t * (end - start) + start) * math.pi / 2) ** power
output = (v_end - output) / (v_end - v_start)
return output.clamp(min = clip_min)
# sigmoid gamma schedule
def sigmoid_schedule(t, start = -3, end = 3, tau = 1, clamp_min = 1e-9):
    """Sigmoid gamma schedule, rescaled so gamma(0) is near 1 and gamma(1) near 0."""
    # sigmoid values at the schedule endpoints, used for rescaling
    v_start = torch.sigmoid(torch.tensor(start / tau))
    v_end = torch.sigmoid(torch.tensor(end / tau))
    shifted = (t * (end - start) + start) / tau
    gamma = (v_end - torch.sigmoid(shifted)) / (v_end - v_start)
    # keep gamma in (clamp_min, 1]
    return gamma.clamp(min = clamp_min, max = 1.)
# convert gamma to the diffusion process's (alpha, sigma)
def gamma_to_alpha_sigma(gamma, scale = 1):
    """Return (sqrt(gamma) * scale, sqrt(1 - gamma))."""
    alpha = gamma.sqrt() * scale
    sigma = (1 - gamma).sqrt()
    return alpha, sigma
def gamma_to_log_snr(gamma, scale = 1, eps = 1e-5):
    """Compute log-SNR = log(gamma * scale^2 / (1 - gamma)), clamped by eps before the log."""
    snr = gamma * (scale ** 2) / (1 - gamma)
    return torch.log(snr.clamp(min = eps))
# gaussian diffusion

class NaturalSpeech2(nn.Module):
    """Continuous-latent diffusion TTS wrapper tying the denoising `Model` to a
    neural codec, phoneme/prompt encoders, duration/pitch prediction and an aligner.

    Fix: the extracted source had the `__init__` signature severed from its body
    by a garbled second `def __init__(self):` (and duplicated `super().__init__()`);
    they are merged back into a single constructor here.
    """

    @beartype
    def __init__(
        self,
        model: Model,
        codec: Optional[Union[SoundStream, EncodecWrapper]] = None,
        *,
        tokenizer: Optional[Tokenizer] = None,
        target_sample_hz = None,
        timesteps = 1000,
        use_ddim = True,
        noise_schedule = 'sigmoid',
        objective = 'v',
        schedule_kwargs: dict = dict(),
        time_difference = 0.,
        min_snr_loss_weight = True,
        min_snr_gamma = 5,
        train_prob_self_cond = 0.9,
        rvq_cross_entropy_loss_weight = 0., # default off until proven whether it helps
        dim_codebook: int = 128,
        duration_pitch_dim: int = 512,
        aligner_dim_in: int = 80,
        aligner_dim_hidden: int = 512,
        aligner_attn_channels: int = 80,
        num_phoneme_tokens: int = 150,
        pitch_emb_dim: int = 256,
        pitch_emb_pp_hidden_dim: int = 512,
        calc_pitch_with_pyworld = True, # compute pitch with pyworld, or kaldi via torchaudio
        mel_hop_length = 160,
        audio_to_mel_kwargs: dict = dict(),
        scale = 1., # set < 1. for better convergence when training on higher-resolution audio
        duration_loss_weight = 1.,
        pitch_loss_weight = 1.,
        aligner_loss_weight = 1.,
        aligner_bin_loss_weight = 0.
    ):
        super().__init__()

        # conditional generation is driven by the denoiser's configuration
        self.conditional = model.condition_on_prompt

        # model and codec
        self.model = model
        self.codec = codec

        # either the codec supplies the sample rate, or it must be given explicitly
        assert exists(codec) or exists(target_sample_hz)

        self.target_sample_hz = target_sample_hz
        self.seq_len_multiple_of = None

        if exists(codec):
            self.target_sample_hz = codec.target_sample_hz
            self.seq_len_multiple_of = codec.seq_len_multiple_of

        # preparation for conditioning (prompt / text path)

        if self.conditional:
            if exists(self.target_sample_hz):
                audio_to_mel_kwargs.update(sampling_rate = self.target_sample_hz)

            self.mel_hop_length = mel_hop_length

            self.audio_to_mel = AudioToMel(
                n_mels = aligner_dim_in,
                hop_length = mel_hop_length,
                **audio_to_mel_kwargs
            )

            self.calc_pitch_with_pyworld = calc_pitch_with_pyworld

            # phoneme encoder, speech prompt encoder, duration/pitch predictor,
            # aligner and pitch embedding
            self.phoneme_enc = PhonemeEncoder(tokenizer=tokenizer, num_tokens=num_phoneme_tokens)
            self.prompt_enc = SpeechPromptEncoder(dim_codebook=dim_codebook)
            self.duration_pitch = DurationPitchPredictor(dim=duration_pitch_dim)
            self.aligner = Aligner(dim_in=aligner_dim_in, dim_hidden=aligner_dim_hidden, attn_channels=aligner_attn_channels)
            self.pitch_emb = nn.Embedding(pitch_emb_dim, pitch_emb_pp_hidden_dim)

            # aligner losses (forward-sum plus optional binarization loss)
            self.aligner_loss = ForwardSumLoss()
            self.bin_loss = BinLoss()
            self.aligner_bin_loss_weight = aligner_bin_loss_weight

        # rest of the ddpm

        assert not exists(codec) or model.dim == codec.codebook_dim, f'transformer model dimension {model.dim} must be equal to codec dimension {codec.codebook_dim}'

        self.dim = codec.codebook_dim if exists(codec) else model.dim

        assert objective in {'x0', 'eps', 'v'}, 'objective must be either predict x0 or noise'
        self.objective = objective

        # pick the gamma schedule
        if noise_schedule == "linear":
            self.gamma_schedule = simple_linear_schedule
        elif noise_schedule == "cosine":
            self.gamma_schedule = cosine_schedule
        elif noise_schedule == "sigmoid":
            self.gamma_schedule = sigmoid_schedule
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        assert scale <= 1, 'scale must be less than or equal to 1'
        self.scale = scale

        # bind schedule hyperparameters
        self.gamma_schedule = partial(self.gamma_schedule, **schedule_kwargs)

        self.timesteps = timesteps
        self.use_ddim = use_ddim

        # proposed: add a time difference to the next timestep, to fix
        # insufficient self-conditioning and lower FID at < 400 sampling steps
        self.time_difference = time_difference

        # probability of self-conditioning during training
        self.train_prob_self_cond = train_prob_self_cond

        # min-SNR loss weighting
        self.min_snr_loss_weight = min_snr_loss_weight
        self.min_snr_gamma = min_snr_gamma

        # auxiliary loss weights
        self.duration_loss_weight = duration_loss_weight
        self.pitch_loss_weight = pitch_loss_weight
        self.aligner_loss_weight = aligner_loss_weight
# device property
@property
def device(self):
    """Device the denoising model's parameters live on."""
    return next(self.model.parameters()).device
# print helper
def print(self, s):
    # NOTE(review): nothing in this class sets `self.accelerator` (Trainer has
    # one, this module does not) — this likely raises AttributeError unless an
    # accelerator is attached externally; confirm before relying on it.
    return self.accelerator.print(s)
# sampling timesteps
def get_sampling_timesteps(self, batch, *, device):
    """Build consecutive (time, time_next) pairs going from 1 down to 0
    over `self.timesteps` steps, replicated for each batch element."""
    steps = torch.linspace(1., 0., self.timesteps + 1, device=device)
    steps = repeat(steps, 't -> b t', b=batch)
    pairs = torch.stack((steps[:, :-1], steps[:, 1:]), dim=0)
    return pairs.unbind(dim=-1)
# ancestral (DDPM) sampling
@torch.no_grad()
def ddpm_sample(self, shape, prompt=None, time_difference=None, cond_scale=1., cond=None):
    """Generate latents of `shape` = (batch, seq, dim) via the DDPM ancestral sampler.

    Fixes:
    - the `time_difference` argument was resolved via `default(...)` but then
      ignored in favor of `self.time_difference`; it is now honored (matching
      `ddim_sample`).
    - the fresh-noise mask used `'b -> b 1 1 1'`, a rank-4 mask against the
      rank-3 (b, n, d) audio tensor, which mis-broadcasts; it is now
      `'b -> b 1 1'`, consistent with `right_pad_dims_to` producing (b, 1, 1).
    """
    batch, device = shape[0], self.device

    time_difference = default(time_difference, self.time_difference)

    time_pairs = self.get_sampling_timesteps(batch, device=device)

    # start from pure noise
    audio = torch.randn(shape, device=device)

    x_start = None
    last_latents = None

    for time, time_next in tqdm(time_pairs, desc='sampling loop time step', total=self.timesteps):

        # add the proposed time delay (use the resolved argument, not the attribute)
        time_next = (time_next - time_difference).clamp(min=0.)

        noise_cond = time

        # predict (with classifier-free guidance)
        model_output = self.model.forward_with_cond_scale(audio, noise_cond, prompt=prompt, cond_scale=cond_scale, cond=cond)

        # noise levels for the current and next steps
        gamma = self.gamma_schedule(time)
        gamma_next = self.gamma_schedule(time_next)
        gamma, gamma_next = map(partial(right_pad_dims_to, audio), (gamma, gamma_next))

        alpha, sigma = gamma_to_alpha_sigma(gamma, self.scale)
        alpha_next, sigma_next = gamma_to_alpha_sigma(gamma_next, self.scale)

        # recover x0 from the model output according to the training objective
        if self.objective == 'x0':
            x_start = model_output
        elif self.objective == 'eps':
            x_start = safe_div(audio - sigma * model_output, alpha)
        elif self.objective == 'v':
            x_start = alpha * audio - sigma * model_output

        # derive the posterior mean and variance
        log_snr, log_snr_next = map(gamma_to_log_snr, (gamma, gamma_next))

        c = -expm1(log_snr - log_snr_next)

        mean = alpha_next * (audio * (1 - c) / alpha + c * x_start)
        variance = (sigma_next ** 2) * c
        log_variance = log(variance)

        # fresh noise everywhere except the final (time_next == 0) step;
        # mask rank matches the 3-dim audio tensor
        noise = torch.where(
            rearrange(time_next > 0, 'b -> b 1 1'),
            torch.randn_like(audio),
            torch.zeros_like(audio)
        )

        audio = mean + (0.5 * log_variance).exp() * noise

    return audio
@torch.no_grad()
# deterministic (DDIM) sampling with optional time delay, guidance scale and condition
def ddim_sample(self, shape, prompt = None, time_difference = None, cond_scale = 1., cond = None):
    """Generate latents of `shape` = (batch, seq, dim) via the DDIM sampler."""
    # batch size and device
    batch, device = shape[0], self.device

    # resolve the time difference (argument overrides the configured default)
    time_difference = default(time_difference, self.time_difference)

    # consecutive (times, times_next) pairs from 1 down to 0
    time_pairs = self.get_sampling_timesteps(batch, device = device)

    # start from pure noise
    audio = torch.randn(shape, device = device)

    x_start = None
    last_latents = None

    for times, times_next in tqdm(time_pairs, desc = 'sampling loop time step'):

        # noise levels at the current and next steps
        gamma = self.gamma_schedule(times)
        gamma_next = self.gamma_schedule(times_next)

        # pad gammas to broadcast against the audio tensor
        padded_gamma, padded_gamma_next = map(partial(right_pad_dims_to, audio), (gamma, gamma_next))

        # convert noise levels to alpha / sigma
        alpha, sigma = gamma_to_alpha_sigma(padded_gamma, self.scale)
        alpha_next, sigma_next = gamma_to_alpha_sigma(padded_gamma_next, self.scale)

        # add the proposed time delay
        times_next = (times_next - time_difference).clamp(min = 0.)

        # predict x0 (with classifier-free guidance)
        model_output = self.model.forward_with_cond_scale(audio, times, prompt = prompt, cond_scale = cond_scale, cond = cond)

        # recover x0 according to the training objective
        if self.objective == 'x0':
            x_start = model_output
        elif self.objective == 'eps':
            x_start = safe_div(audio - sigma * model_output, alpha)
        elif self.objective == 'v':
            x_start = alpha * audio - sigma * model_output

        # implied noise given x0
        pred_noise = safe_div(audio - alpha * x_start, sigma)

        # step to the next noise level deterministically
        audio = x_start * alpha_next + pred_noise * sigma_next

    return audio
# prompt preprocessing
def process_prompt(self, prompt = None):
    """Pass encoded prompts through unchanged; encode raw-audio prompts with the codec."""
    if not exists(prompt):
        return None

    assert self.model.condition_on_prompt

    # a 2-dim prompt is raw waveform (batch, time) and needs the codec
    if prompt.ndim == 2:
        assert exists(self.codec), 'codec must be passed in if one were to train on raw prompt'
        with torch.no_grad():
            self.codec.eval()
            prompt, _, _ = self.codec(prompt, curtail_from_left = True, return_encoded = True)

    return prompt
# expand phoneme-rate encodings to frame rate via the alignment attention matrix
def expand_encodings(self, phoneme_enc, attn, pitch):
    """Return frame-rate conditioning = expanded phoneme encodings + expanded pitch embeddings."""
    expanded_dur = einsum('k l m n, k j m -> k j n', attn, phoneme_enc)

    # quantize pitch to coarse bins, embed, then move channels first
    coarse = rearrange(f0_to_coarse(pitch), 'b 1 t -> b t')
    pitch_emb = rearrange(self.pitch_emb(coarse), 'b t d -> b d t')
    expanded_pitch = einsum('k l m n, k j m -> k j n', attn, pitch_emb)

    return expanded_dur + expanded_pitch
# sampling entry point
@torch.no_grad()
def sample(
    self,
    *,
    length,
    prompt = None,
    batch_size = 1,
    cond_scale = 1.,
    text = None,
    text_lens = None,
):
    """Generate audio of `length` latent frames; decodes to waveform when a codec exists.

    Conditional models require both `prompt` and `text`.
    """
    # choose DDPM or DDIM sampling
    sample_fn = self.ddpm_sample if not self.use_ddim else self.ddim_sample

    prompt_enc = cond = None

    if self.conditional:
        # conditional generation requires both a prompt and text
        assert exists(prompt) and exists(text)
        # encode raw prompt through the codec when needed
        prompt = self.process_prompt(prompt)
        prompt_enc = self.prompt_enc(prompt)
        # encode the text into phoneme representations
        phoneme_enc = self.phoneme_enc(text)

        # predict per-phoneme duration and pitch
        duration, pitch = self.duration_pitch(phoneme_enc, prompt_enc)
        pitch = rearrange(pitch, 'b n -> b 1 n')

        # alignment mask from the predicted durations
        aln_mask = generate_mask_from_repeats(duration).float()

        # expand encodings to frame rate for conditioning the denoiser
        cond = self.expand_encodings(rearrange(phoneme_enc, 'b n d -> b d n'), rearrange(aln_mask, 'b n c -> b 1 n c'), pitch)

    # batch size follows the prompt when one is given
    if exists(prompt):
        batch_size = prompt.shape[0]

    # run the diffusion sampler
    audio = sample_fn(
        (batch_size, length, self.dim),
        prompt = prompt_enc,
        cond = cond,
        cond_scale = cond_scale
    )

    if exists(self.codec):
        # decode latents to waveform
        audio = self.codec.decode(audio)

        if audio.ndim == 3:
            audio = rearrange(audio, 'b 1 n -> b n')

    return audio
def forward(
self,
audio,
text = None,
text_lens = None,
mel = None,
mel_lens = None,
codes = None,
prompt = None,
pitch = None,
*args,
**kwargs
# trainer

def cycle(dl):
    """Yield items from the iterable forever, restarting when exhausted."""
    while True:
        yield from dl
# Trainer class, used to train the model
class Trainer(object):
    """Training loop for `NaturalSpeech2` using HuggingFace Accelerate:
    handles the dataloader, optimizer, EMA, checkpointing and periodic sampling."""
    def __init__(
        self,
        diffusion_model: NaturalSpeech2,
        *,
        dataset: Optional[Dataset] = None,
        folder = None,
        train_batch_size = 16,
        gradient_accumulate_every = 1,
        train_lr = 1e-4,
        train_num_steps = 100000,
        ema_update_every = 10,
        ema_decay = 0.995,
        adam_betas = (0.9, 0.99),
        save_and_sample_every = 1000,
        num_samples = 1,
        results_folder = './results',
        amp = False,
        mixed_precision_type = 'fp16',
        use_ema = True,
        split_batches = True,
        dataloader = None,
        data_max_length = None,
        data_max_length_seconds = 2,
        sample_length = None
    ):
        super().__init__()

        # accelerator

        self.accelerator = Accelerator(
            split_batches = split_batches,
            mixed_precision = mixed_precision_type if amp else 'no'
        )

        # model

        self.model = diffusion_model
        assert exists(diffusion_model.codec)

        self.dim = diffusion_model.dim

        # training hyperparameters

        self.batch_size = train_batch_size
        self.gradient_accumulate_every = gradient_accumulate_every
        self.train_num_steps = train_num_steps

        # dataset and dataloader

        dl = dataloader

        if not exists(dl):
            assert exists(dataset) or exists(folder)

            if exists(dataset):
                self.ds = dataset
            elif exists(folder):
                # create dataset; max length in samples may be derived from seconds
                if exists(data_max_length_seconds):
                    assert not exists(data_max_length)
                    data_max_length = int(data_max_length_seconds * diffusion_model.target_sample_hz)
                else:
                    assert exists(data_max_length)

                self.ds = SoundDataset(
                    folder,
                    max_length = data_max_length,
                    target_sample_hz = diffusion_model.target_sample_hz,
                    seq_len_multiple_of = diffusion_model.seq_len_multiple_of
                )

            dl = DataLoader(
                self.ds,
                batch_size = train_batch_size,
                shuffle = True,
                pin_memory = True,
                num_workers = cpu_count()
            )

        dl = self.accelerator.prepare(dl)
        # infinite iterator over batches
        self.dl = cycle(dl)

        # optimizer

        self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)

        # for logging results in a folder periodically

        self.use_ema = use_ema
        self.ema = None

        if self.accelerator.is_main_process and use_ema:
            # make sure codec is not part of the EMA
            # encodec seems to be not deepcopyable, so this is a necessary hack
            codec = diffusion_model.codec
            diffusion_model.codec = None

            self.ema = EMA(
                diffusion_model,
                beta = ema_decay,
                update_every = ema_update_every,
                ignore_startswith_names = set(['codec.'])
            ).to(self.device)

            diffusion_model.codec = codec
            self.ema.ema_model.codec = codec

        # sampling hyperparameters

        self.sample_length = default(sample_length, data_max_length)
        self.num_samples = num_samples
        self.save_and_sample_every = save_and_sample_every

        # results folder

        self.results_folder = Path(results_folder)
        self.results_folder.mkdir(exist_ok = True)

        # step counter state

        self.step = 0

        # prepare model, dataloader, optimizer with accelerator

        self.model, self.opt = self.accelerator.prepare(self.model, self.opt)

    def print(self, msg):
        # print only from the main process
        return self.accelerator.print(msg)

    @property
    def unwrapped_model(self):
        # model with the accelerate wrappers stripped off
        return self.accelerator.unwrap_model(self.model)

    @property
    def device(self):
        # device the accelerator placed things on
        return self.accelerator.device

    def save(self, milestone):
        """Save model / optimizer / EMA / scaler state for the given milestone."""
        # only the local main process writes checkpoints
        if not self.accelerator.is_local_main_process:
            return

        data = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.model),
            'opt': self.opt.state_dict(),
            'ema': self.ema.state_dict(),
            'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None,
            'version': __version__
        }

        torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))

    def load(self, milestone):
        """Restore model / optimizer / EMA / scaler state from the given milestone."""
        accelerator = self.accelerator
        device = accelerator.device

        data = torch.load(str(self.results_folder / f'model-{milestone}.pt'), map_location=device)

        # load into the unwrapped model
        model = self.accelerator.unwrap_model(self.model)
        model.load_state_dict(data['model'])

        self.step = data['step']
        self.opt.load_state_dict(data['opt'])

        # EMA only lives on the main process
        if self.accelerator.is_main_process:
            self.ema.load_state_dict(data["ema"])

        if 'version' in data:
            print(f"loading from version {data['version']}")

        # restore the grad scaler if both sides have one
        if exists(self.accelerator.scaler) and exists(data['scaler']):
            self.accelerator.scaler.load_state_dict(data['scaler'])

    def train(self):
        """Main training loop: gradient accumulation, clipping, EMA updates and
        periodic checkpoint + sample generation."""
        accelerator = self.accelerator
        device = accelerator.device

        with tqdm(initial=self.step, total=self.train_num_steps, disable=not accelerator.is_main_process) as pbar:
            while self.step < self.train_num_steps:
                total_loss = 0.

                # accumulate gradients over several micro-batches
                for _ in range(self.gradient_accumulate_every):
                    data = next(self.dl).to(device)

                    with self.accelerator.autocast():
                        loss = self.model(data)
                        loss = loss / self.gradient_accumulate_every
                        total_loss += loss.item()

                    self.accelerator.backward(loss)

                accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
                pbar.set_description(f'loss: {total_loss:.4f}')

                accelerator.wait_for_everyone()

                self.opt.step()
                self.opt.zero_grad()

                accelerator.wait_for_everyone()

                self.step += 1

                # main process: EMA update, periodic checkpoint + sampling
                if accelerator.is_main_process:
                    self.ema.update()

                    if divisible_by(self.step, self.save_and_sample_every):
                        milestone = self.step // self.save_and_sample_every

                        models = [(self.unwrapped_model, str(self.step))]
                        if self.use_ema:
                            models.append((self.ema.ema_model, f'{self.step}.ema'))

                        for model, label in models:
                            model.eval()

                            with torch.no_grad():
                                generated = model.sample(
                                    batch_size=self.num_samples,
                                    length=self.sample_length
                                )

                            # NOTE(review): `ind` is unused in the filename, so with
                            # num_samples > 1 each sample overwrites the previous file
                            for ind, t in enumerate(generated):
                                filename = str(self.results_folder / f'sample_{label}.flac')
                                t = rearrange(t, 'n -> 1 n')
                                torchaudio.save(filename, t.cpu().detach(), self.unwrapped_model.target_sample_hz)

                        self.print(f'{self.step}: saving to {str(self.results_folder)}')
                        self.save(milestone)

                pbar.update(1)

        self.print('training complete')