Lucidrains 系列项目源码解析(九十三)
.\lucidrains\soundstorm-pytorch\soundstorm_pytorch\trainer.py
# 导入必要的模块
from pathlib import Path
import re
from shutil import rmtree
# 导入 beartype 模块及相关类型
from beartype import beartype
from beartype.typing import Optional
# 导入 PyTorch 相关模块
import torch
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, random_split
# 导入自定义模块
from audiolm_pytorch.data import get_dataloader
from audiolm_pytorch.optimizer import get_optimizer
from soundstorm_pytorch.soundstorm import SoundStorm
# 导入加速器模块及分布式类型
from accelerate import Accelerator, DistributedType
# helper functions

def exists(val):
    """True when `val` carries a value, i.e. is anything but None."""
    return val is not None
def noop(*args, **kwargs):
    """Do-nothing placeholder, used where a callback is optional."""
    return None
def cycle(dl):
    """Yield items from `dl` forever, restarting each time it is exhausted."""
    while True:
        yield from dl
def cast_tuple(t):
    """Wrap a bare value into a 1-tuple; tuples and lists pass through unchanged."""
    if isinstance(t, (tuple, list)):
        return t
    return (t,)
def yes_or_no(question):
    """Prompt on stdin; True only for a 'y' / 'yes' answer (case-insensitive)."""
    reply = input(f'{question} (y/n) ')
    return reply.lower() in ('yes', 'y')
def accum_log(log, new_logs):
    """Add each value in `new_logs` onto the running totals in `log`, in place."""
    for key, value in new_logs.items():
        log[key] = log.get(key, 0.) + value
    return log
def checkpoint_num_steps(checkpoint_path):
    """Returns the number of steps trained from a checkpoint based on the filename.

    Filename format assumed to be something like "/path/to/soundstorm.20000.pt" which is
    for 20k train steps. Returns 20000 in that case.
    """
    # bug fix: the original findall call was missing its closing parenthesis (syntax error)
    results = re.findall(r'\d+', str(checkpoint_path))

    if len(results) == 0:
        return 0

    # last run of digits is the step count (earlier digits may be part of the path)
    return int(results[-1])
class SoundStormTrainer(nn.Module):
    """Training harness for SoundStorm.

    Wires up the optimizer with linear warmup + cosine annealing, splits the
    dataset into train/valid, wraps everything with HuggingFace Accelerate for
    (optionally distributed) training, and periodically validates / checkpoints.
    """

    @beartype
    def __init__(
        self,
        model: SoundStorm,
        *,
        num_train_steps,
        num_warmup_steps,
        batch_size,
        dataset: Optional[Dataset] = None,
        only_train_generator = False,
        only_train_critic = False,
        lr = 3e-4,
        initial_lr = 1e-5,
        grad_accum_every = 1,
        wd = 0.,
        max_grad_norm = 0.5,
        valid_frac = 0.05,
        random_split_seed = 42,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        accelerate_kwargs: dict = dict(),
        split_batches = False,
        drop_last = False,
        force_clear_prev_results = None
    ):
        super().__init__()

        self.accelerator = Accelerator(
            split_batches = split_batches,
            **accelerate_kwargs
        )

        self.model = model

        # kept as a buffer so the step count survives state_dict round-trips
        self.register_buffer('steps', torch.Tensor([0]))

        self.num_train_steps = num_train_steps
        self.num_warmup_steps = num_warmup_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every
        self.only_train_generator = only_train_generator
        self.only_train_critic = only_train_critic

        # optimizer
        self.optim = get_optimizer(
            model.parameters(),
            lr = lr,
            wd = wd
        )

        self.lr = lr
        self.initial_lr = initial_lr

        # cosine annealing after the warmup phase
        self.scheduler = CosineAnnealingLR(self.optim, T_max = num_train_steps)

        self.max_grad_norm = max_grad_norm

        # dataset and train/valid split
        self.ds = dataset

        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        assert len(self.ds) >= batch_size, 'dataset must have sufficient samples for training'
        assert len(self.valid_ds) >= batch_size, f'validation dataset must have sufficient number of samples (currently {len(self.valid_ds)}) for training'

        # dataloaders
        self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)
        self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)

        # hand everything to accelerate for device placement / DDP wrapping
        (
            self.model,
            self.optim,
            self.scheduler,
            self.dl,
            self.valid_dl
        ) = self.accelerator.prepare(
            self.model,
            self.optim,
            self.scheduler,
            self.dl,
            self.valid_dl
        )

        # infinite iterators over the dataloaders
        self.dl_iter = cycle(self.dl)
        self.valid_dl_iter = cycle(self.valid_dl)

        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        self.results_folder = Path(results_folder)

        # bug fix: parenthesized the condition so the interactive prompt and the
        # folder clearing only ever run on the main process; previously `and`
        # bound tighter than `or`, so the `or` branch executed on every process
        if self.is_main and (force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'))):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)

        # bug fix: "initial_learning_rate" previously logged `lr` instead of `initial_lr`
        hps = {"num_train_steps": num_train_steps, "num_warmup_steps": num_warmup_steps, "learning_rate": lr, "initial_learning_rate": initial_lr}
        self.accelerator.init_trackers("soundstorm", config=hps)

    def save(self, path):
        """Serialize model / optimizer / scheduler state to `path`."""
        pkg = dict(
            model = self.accelerator.get_state_dict(self.model),
            optim = self.optim.state_dict(),
            scheduler = self.scheduler.state_dict()
        )
        torch.save(pkg, path)

    def load(self, path, restore_optimizer = True):
        """Restore a checkpoint; optionally resume optimizer/scheduler and the
        step counter (parsed from the checkpoint filename)."""
        model = self.accelerator.unwrap_model(self.model)
        pkg = model.load(path)

        if restore_optimizer:
            self.optim.load_state_dict(pkg['optim'])
            self.scheduler.load_state_dict(pkg['scheduler'])

            # + 1 to start from the next step and avoid overwriting the last checkpoint
            self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)

    def print(self, msg):
        """Print only once across processes (delegates to accelerate)."""
        self.accelerator.print(msg)

    def generate(self, *args, **kwargs):
        """Delegate sampling to the underlying SoundStorm model."""
        return self.model.generate(*args, **kwargs)

    @property
    def device(self):
        return self.accelerator.device

    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    def warmup(self, step):
        """Linearly interpolate lr from initial_lr to lr over num_warmup_steps."""
        if step < self.num_warmup_steps:
            return self.initial_lr + (self.lr - self.initial_lr) * step / self.num_warmup_steps
        else:
            return self.lr

    def train_step(self):
        """One optimizer step (with gradient accumulation); returns the logs dict."""
        steps = int(self.steps.item())

        self.model.train()

        # lr schedule: manual linear warmup, then cosine annealing
        if steps < self.num_warmup_steps:
            lr = self.warmup(steps)
            for param_group in self.optim.param_groups:
                param_group['lr'] = lr
        else:
            self.scheduler.step()

        logs = {}

        # accumulate gradients over several micro-batches
        for _ in range(self.grad_accum_every):
            semantic_token_ids, acoustic_token_ids = next(self.dl_iter)

            loss, loss_breakdown = self.model(
                acoustic_token_ids,
                cond_semantic_token_ids = semantic_token_ids,
                only_train_generator = self.only_train_generator,
                only_train_critic = self.only_train_critic
            )

            # either component may be None when only the other is being trained
            generator_loss, critic_loss = loss_breakdown
            generator_loss = 0. if generator_loss is None else generator_loss
            critic_loss = 0. if critic_loss is None else critic_loss

            self.accelerator.backward(loss / self.grad_accum_every)

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every, 'generator_loss': generator_loss / self.grad_accum_every, 'critic_loss': critic_loss / self.grad_accum_every})

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        self.optim.step()
        self.optim.zero_grad()

        self.print(f"{steps}: loss: {logs['loss']:0.3f}, generator loss: {logs['generator_loss']:0.3f}, critic loss: {logs['critic_loss']:0.3f}")
        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # periodically validate on one held-out batch (main process only)
        self.accelerator.wait_for_everyone()

        if self.is_main and not (steps % self.save_results_every):
            semantic_token_ids, acoustic_token_ids = next(self.valid_dl_iter)

            with torch.inference_mode():
                self.model.eval()
                valid_loss, valid_loss_breakdown = self.model(acoustic_token_ids, cond_semantic_token_ids = semantic_token_ids)

                valid_generator_loss, valid_critic_loss = valid_loss_breakdown
                valid_generator_loss = 0. if valid_generator_loss is None else valid_generator_loss
                valid_critic_loss = 0. if valid_critic_loss is None else valid_critic_loss

            self.print(f'{steps}: valid loss {valid_loss:0.3f}, valid generator loss {valid_generator_loss:0.3f}, valid critic loss {valid_critic_loss:0.3f}')
            self.accelerator.log({"valid_loss": valid_loss, "valid_generator_loss": valid_generator_loss, "valid_critic_loss": valid_critic_loss}, step=steps)

        # periodically checkpoint (main process only)
        if self.is_main and not (steps % self.save_model_every):
            model_path = str(self.results_folder / f'soundstorm.{steps}.pt')
            self.save(model_path)
            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        self.steps += 1
        return logs

    def train(self, log_fn = noop):
        """Run train_step until num_train_steps is reached, invoking log_fn on each step's logs."""
        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        self.print('training complete')
.\lucidrains\soundstorm-pytorch\soundstorm_pytorch\__init__.py
# 从soundstorm_pytorch包中导入SoundStorm、SoundStream、ConformerWrapper和Conformer类
from soundstorm_pytorch.soundstorm import (
SoundStorm,
SoundStream,
ConformerWrapper,
Conformer
)
# 从soundstorm_pytorch包中导入SoundStormTrainer类
from soundstorm_pytorch.trainer import (
SoundStormTrainer
)

Spear-TTS - Pytorch
Implementation of Spear-TTS - multi-speaker text-to-speech attention network, in Pytorch
The text-to-semantic module built here will be used for SoundStorm for conditioning.
Appreciation
-
Stability for their generous sponsorships to work on and open source cutting edge artificial intelligence research
-
Lucas Newman for completing the backtranslation portion, as well as beam search decoding!
-
Lucas Newman for completing the final text to semantic transformer training code!
Install
$ pip install spear-tts-pytorch
Usage
import torch
from audiolm_pytorch import HubertWithKmeans
from spear_tts_pytorch import (
TextToSemantic,
SemanticToTextDatasetGenerator,
GeneratedAudioTextDataset,
MockDataset
)
wav2vec = HubertWithKmeans(
checkpoint_path = './hubert_base_ls960.pt',
kmeans_path = './hubert_base_ls960_L9_km500.bin'
)
model = TextToSemantic(
wav2vec = wav2vec,
dim = 512,
num_text_token_ids = 256,
heads = 8,
target_kv_heads = 2, # grouped query attention, for memory efficient decoding
source_depth = 1,
target_depth = 1
)
ds = MockDataset(10)
dataset_generator = SemanticToTextDatasetGenerator(
model = model,
dataset = ds,
folder = './output_folder'
)
dataset_generator(max_length = 2)
generated_dataset = GeneratedAudioTextDataset(
folder = './output_folder'
)
assert len(generated_dataset) == 10
Todo
-
add eos logic + generate, and hook up end-to-end generation in soundstorm
-
add first pretraining speech-to-speech with the reconstruction of 60% deleted tokens
-
add dropouts for this project, as low-resource
-
add total flexibility of which layers of encoder / decoder to freeze during training
-
add step for training on small speech -> text corpus and generating pseudo-labelled dataset + finetuning (thanks to @lucasnewman)
-
add final step of finetuning on text -> speech + pseudolabelled dataset
-
figure out the best way to store and manage the pseudo-labelled generated dataset
-
batched beam search decoding
-
allow for using rotary positions in decoder + flash attention, give Tri another citation
-
integrate speculative decoding with some improvisation - done in same model using early exit strategy
-
add cached key / values for starter + single / grouped key values, make sure flash attention can support specialized causal mask before flash attention 2 is in pytorch core
-
polish the audio-text generation workflow
-
concatting the real audio-text dataset with the generated one -> or being able to convert real audio-text dataset to generated
Citations
@misc{kharitonov2023speak,
title = {Speak, Read and Prompt: High-Fidelity Text-to-Speech with Minimal Supervision},
author = {Eugene Kharitonov and Damien Vincent and Zalán Borsos and Raphaël Marinier and Sertan Girgin and Olivier Pietquin and Matt Sharifi and Marco Tagliasacchi and Neil Zeghidour},
year = {2023},
eprint = {2302.03540},
archivePrefix = {arXiv},
primaryClass = {cs.SD}
}
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
@misc{shi2023enhance,
title = {Enhance audio generation controllability through representation similarity regularization},
author = {Yangyang Shi and Gael Le Lan and Varun Nagaraja and Zhaoheng Ni and Xinhao Mei and Ernie Chang and Forrest Iandola and Yang Liu and Vikas Chandra},
year = {2023},
eprint = {2309.08773},
archivePrefix = {arXiv},
primaryClass = {cs.SD}
}
@article{Ainslie2023GQATG,
title = {GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints},
author = {Joshua Ainslie and James Lee-Thorp and Michiel de Jong and Yury Zemlyanskiy and Federico Lebr'on and Sumit K. Sanghai},
journal = {ArXiv},
year = {2023},
volume = {abs/2305.13245},
url = {https://api.semanticscholar.org/CorpusID:258833177}
}
@inproceedings{Leviathan2022FastIF,
title = {Fast Inference from Transformers via Speculative Decoding},
author = {Yaniv Leviathan and Matan Kalman and Y. Matias},
booktitle = {International Conference on Machine Learning},
year = {2022},
url = {https://api.semanticscholar.org/CorpusID:254096365}
}
.\lucidrains\spear-tts-pytorch\setup.py
# 导入设置工具和查找包工具
from setuptools import setup, find_packages
# package metadata for spear-tts-pytorch (setuptools build script)
setup(
    # distribution name on PyPI
    name = 'spear-tts-pytorch',
    # include every package found, excluding none
    packages = find_packages(exclude=[]),
    version = '0.4.8',
    license='MIT',
    description = 'Spear-TTS - Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    long_description_content_type = 'text/markdown',
    url = 'https://github.com/lucidrains/spear-tts-pytorch',
    keywords = [
        'artificial intelligence',
        'deep learning',
        'transformers',
        'attention mechanism',
        'text-to-speech'
    ],
    # runtime dependencies
    install_requires=[
        'audiolm-pytorch>=1.2.8',
        'beartype',
        'einops>=0.6.1',
        'rotary-embedding-torch>=0.3.0',
        'torch>=1.6',
        'tqdm',
        'x-clip>=0.12.2'
    ],
    # trove classifiers for PyPI
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\spear-tts-pytorch\spear_tts_pytorch\attend.py
import torch
from torch import nn, einsum
import torch.nn.functional as F
from collections import namedtuple
from functools import wraps
from packaging import version
from einops import rearrange, repeat
# backend toggles handed to torch.backends.cuda.sdp_kernel (flash / math / mem-efficient)
Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
# helper

def exists(val):
    """Return True unless `val` is None."""
    return val is not None
def once(fn):
    """Decorator: the wrapped single-argument function runs at most one time;
    subsequent calls are no-ops returning None."""
    has_run = False

    @wraps(fn)
    def wrapper(x):
        nonlocal has_run
        if has_run:
            return
        has_run = True
        return fn(x)

    return wrapper

# print that fires only on its first invocation
print_once = once(print)
# main class

class Attend(nn.Module):
    """Scaled-dot-product attention with optional causal masking, grouped-query
    kv-head expansion, and a pytorch >= 2.0 flash-attention fast path."""

    def __init__(
        self,
        dropout = 0.,
        causal = False,
        flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        # lazily-grown causal mask cache; not persisted in state dicts
        self.register_buffer("mask", None, persistent=False)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # sdp_kernel backend configs for cpu and cuda
        self.cpu_config = Config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = Config(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = Config(False, True, True)

    def get_mask(self, i, j, device):
        """Return an (i, j) slice of an upper-triangular causal mask, reusing
        and growing the cached buffer as needed."""
        n = max(i, j)

        if exists(self.mask) and self.mask.shape[-1] >= n:
            mask = self.mask[:n, :n]
        else:
            mask = torch.ones((n, n), device = device, dtype = torch.bool).triu(1)
            self.register_buffer("mask", mask, persistent = False)

        return mask[-i:, :]

    def flash_attn(self, q, k, v, mask = None):
        """Flash-attention path via F.scaled_dot_product_attention."""
        _, heads, q_len, _, k_len, causal, is_cuda, device = *q.shape, k.shape[-2], self.causal, q.is_cuda, q.device

        # expand a (b, j) padding mask to (b, h, i, j) for sdpa
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        config = self.cuda_config if is_cuda else self.cpu_config

        # when decoding with cached kv (q_len != k_len) `is_causal` cannot be
        # used; build the causal mask manually and fold it into the padding mask
        row_is_entirely_masked = None

        if causal and q_len != k_len:
            causal_mask = self.get_mask(q_len, k_len, device = device)

            if exists(mask):
                mask = mask & ~causal_mask
            else:
                mask = ~causal_mask

            # guard fully-masked rows, which would produce nans in softmax
            row_is_entirely_masked = ~mask.any(dim = -1)
            mask[..., 0] = mask[..., 0] | row_is_entirely_masked

            causal = False

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.,
                is_causal = causal
            )

        # zero out rows that had no valid key at all
        if exists(row_is_entirely_masked):
            out = out.masked_fill(row_is_entirely_masked[..., None], 0.)

        return out

    # bug fix: the `def forward(...)` line was missing in this block, leaving
    # the body below orphaned; restored with the conventional (q, k, v, mask)
    # signature that the rest of the code calls with
    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        n, device = q.shape[-2], q.device

        heads, kv_heads = q.shape[1], k.shape[1]

        # grouped-query attention: repeat kv heads up to the query head count
        if kv_heads < heads:
            k, v = map(lambda t: repeat(t, 'b h ... -> b (g h) ...', g = heads // kv_heads), (k, v))

        scale = q.shape[-1] ** -0.5

        if self.flash:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity
        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # key padding mask
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask
        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = self.get_mask(i, j, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention weights
        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out
.\lucidrains\spear-tts-pytorch\spear_tts_pytorch\data.py
# 导入必要的模块
from pathlib import Path
import torch
from torch.utils.data import Dataset
from beartype import beartype
class MockDataset(Dataset):
    """Stand-in dataset producing random 1024-dim vectors; used for smoke tests."""

    def __init__(self, length: int):
        # number of (synthetic) samples to report
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, ind):
        # index is ignored; every access draws a fresh random vector
        return torch.randn(1024)
class GeneratedAudioTextDataset(Dataset):
    """Reads serialized `<audio tokens> <delimiter> <text tokens>` tensors from
    a folder of .pt files, returning the (audio, text) halves per item."""

    @beartype
    def __init__(
        self,
        folder: str,
        delimiter_id: int = -1
    ):
        self.folder = Path(folder)
        assert self.folder.exists() and self.folder.is_dir()

        self.paths = list(self.folder.glob('*.pt'))
        self.delimiter_id = delimiter_id

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, ind):
        tensor = torch.load(str(self.paths[ind]))

        # locate the first occurrence of the delimiter token
        is_delimiter = tensor == self.delimiter_id
        assert is_delimiter.any(), f'delimeter (<audio> <delimeter> <text>) not found'

        split_at = (is_delimiter.cumsum(dim = -1) == 0).sum().item()

        # everything before the delimiter, everything after it
        return tensor[:split_at], tensor[(split_at + 1):]
.\lucidrains\spear-tts-pytorch\spear_tts_pytorch\distributed.py
# 导入 torch 库
import torch
# 从 torch.autograd 模块中导入 Function 类
from torch.autograd import Function
# 导入 torch.distributed 模块
import torch.distributed as distributed
# 从 einops 库中导入 rearrange 函数
from einops import rearrange
# distributed helpers

def all_gather_variable_dim(t, dim = 0, sizes = None):
    """All-gather tensors whose size along `dim` differs across ranks.

    Pads every rank's tensor up to the largest observed size, gathers the
    padded tensors, then strips the padding back out. Returns the
    concatenated tensor along `dim` plus the per-rank sizes.

    NOTE(review): relies on `exists` and `pad_dim_to`, which are not defined
    in this module as shown — presumably provided elsewhere; verify.
    """
    device, rank, world_size = t.device, distributed.get_rank(), distributed.get_world_size()

    if not exists(sizes):
        # first exchange each rank's length along `dim`
        size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
        sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)]
        distributed.all_gather(sizes, size)
        sizes = torch.stack(sizes)

    max_size = sizes.amax().item()
    # pad to a common length so all_gather sees identical shapes everywhere
    padded_t = pad_dim_to(t, max_size, dim = dim)

    gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for i in range(world_size)]
    distributed.all_gather(gathered_tensors, padded_t)

    gathered_tensor = torch.cat(gathered_tensors, dim = dim)
    seq = torch.arange(max_size, device = device)

    # mask of the positions that held real (unpadded) data, flattened over ranks
    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
    mask = rearrange(mask, 'i j -> (i j)')
    seq = torch.arange(mask.shape[-1], device = device)
    indices = seq[mask]

    # drop the padding positions
    gathered_tensor = gathered_tensor.index_select(dim, indices)

    return gathered_tensor, sizes
class AllGather(Function):
    """Autograd-aware all-gather.

    Forward gathers variable-size tensors from all ranks; backward routes each
    rank its own slice of the incoming gradient. Outside a distributed run it
    is an identity pass-through.
    """

    @staticmethod
    def forward(ctx, x, dim, sizes):
        is_dist = distributed.is_initialized() and distributed.get_world_size() > 1
        ctx.is_dist = is_dist

        # single-process: nothing to gather
        if not is_dist:
            return x, None

        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
        ctx.batch_sizes = batch_sizes.tolist()
        ctx.dim = dim
        return x, batch_sizes

    @staticmethod
    def backward(ctx, grads, _):
        if not ctx.is_dist:
            return grads, None, None

        # split the gathered gradient back into per-rank chunks and keep ours
        batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
        return grads_by_rank[rank], None, None

# functional entry point
all_gather = AllGather.apply
.\lucidrains\spear-tts-pytorch\spear_tts_pytorch\spear_tts_pytorch.py
# 导入数学库
import math
# 从路径库中导入路径类
from pathlib import Path
# 从 functools 库中导入 partial 函数
from functools import partial
# 从 random 库中导入 random 函数
from random import random
# 导入 torch 库
import torch
# 从 torch.nn.functional 中导入 F
import torch.nn.functional as F
# 从 torch.nn.utils.rnn 中导入 pad_sequence
from torch.nn.utils.rnn import pad_sequence
# 从 torch 中导入 Tensor, nn, einsum, IntTensor, LongTensor
from torch import Tensor, nn, einsum, IntTensor, LongTensor
# 从 torch.nn 中导入 Module, ModuleList
from torch.nn import Module, ModuleList
# 从 torch.utils.data 中导入 Dataset
from torch.utils.data import Dataset
# 从 einops 中导入 rearrange, repeat, pack, reduce
from einops import rearrange, repeat, pack, reduce
# 从 einops.layers.torch 中导入 Rearrange
from einops.layers.torch import Rearrange
# 从 audiolm_pytorch 中导入 FairseqVQWav2Vec, HubertWithKmeans
from audiolm_pytorch import FairseqVQWav2Vec, HubertWithKmeans
# 从 audiolm_pytorch.data 中导入 get_dataloader
from audiolm_pytorch.data import get_dataloader
# 从 rotary_embedding_torch 中导入 RotaryEmbedding
from rotary_embedding_torch import RotaryEmbedding
# 从 beartype 中导入 beartype
from beartype import beartype
# 从 beartype.door 中导入 is_bearable
from beartype.door import is_bearable
# 从 beartype.typing 中导入 Optional, Union, Callable, Literal, Tuple, List
from beartype.typing import Optional, Union, Callable, Literal, Tuple, List
# 从 x_clip.tokenizer 中导入 tokenizer
from x_clip.tokenizer import tokenizer
# 从 spear_tts_pytorch 中导入 Attend, all_gather
from spear_tts_pytorch.attend import Attend
from spear_tts_pytorch.distributed import all_gather
# 从 tqdm 中导入 tqdm
from tqdm import tqdm
# accept cpu or cuda float tensors interchangeably in type hints
FloatTensor = Union[
    torch.FloatTensor,
    torch.cuda.FloatTensor
]
# helper functions

def exists(val):
    """Value-present check: anything other than None counts as existing."""
    return val is not None
def default(val, d):
    """Return `val` when it is not None, otherwise the fallback `d`."""
    if val is not None:
        return val
    return d
def empty(t: Tensor):
    """True when the tensor holds zero elements."""
    return t.numel() == 0
def l2norm(t):
    """Unit-normalize `t` along its last dimension (L2 norm)."""
    return F.normalize(t, dim = -1)
def set_eos_id(t: Tensor, eos_id: int, pad_id: int):
    """Write `eos_id` into the first padded slot of every row.

    One pad column is appended first so that a row with no padding still gets
    its eos appended rather than overwriting a real token.
    """
    # per row: index of the first pad token (== row length when unpadded)
    eos_indices = ((t == pad_id).cumsum(dim = -1) == 0).sum(dim = -1, keepdim = True).long()

    batch_range = torch.arange(t.shape[0], device = t.device, dtype = torch.long)[:, None]

    t = F.pad(t, (0, 1), value = pad_id)
    t[batch_range, eos_indices] = eos_id
    return t
def batch_unique_consecutive(t, pad_value = 0.):
    """Collapse consecutive repeats within each batch row, right-padding rows
    back to a common length with `pad_value`."""
    deduped = [torch.unique_consecutive(row) for row in t.unbind(dim = 0)]
    return pad_sequence(deduped, batch_first = True, padding_value = pad_value)
def mask_after_eos(target, eos_id, pad_id):
    """Replace every token strictly after the first eos with pad (the eos
    token itself is kept)."""
    after_eos = (target == eos_id).cumsum(dim = -1) > 0
    # shift right by one so the eos position itself stays unmasked
    after_eos = F.pad(after_eos, (1, -1), value = False)
    return target.masked_fill(after_eos, pad_id)
def safe_div(num, den, eps = 1e-10):
    """Division with the denominator floored at `eps` to avoid divide-by-zero."""
    return num / max(den, eps)
def find_first_true_index(bool_tensor, dim = -1):
    """Index of the first True along `dim`; equals the dim size when no True exists."""
    return (bool_tensor.cumsum(dim = dim) == 0).sum(dim = dim)
# freeze / unfreeze helpers

def set_requires_grad_(module: Module, requires_grad: bool):
    """Toggle `requires_grad` on every parameter of `module`, in place."""
    for param in module.parameters():
        param.requires_grad = requires_grad
def freeze(module: Module):
    """Exclude `module` from training by disabling grads on all its parameters."""
    set_requires_grad_(module, False)
def unfreeze(module: Module):
    """Re-enable grads on all of `module`'s parameters."""
    set_requires_grad_(module, True)
# sampling helpers

def eval_decorator(fn):
    """Run the wrapped method in eval mode, restoring the module's previous
    training mode afterwards."""
    def wrapper(self, *args, **kwargs):
        was_training = self.training
        self.eval()
        result = fn(self, *args, **kwargs)
        self.train(was_training)
        return result
    return wrapper
def log(t, eps = 1e-20):
    """Numerically safe log: input is clamped below at `eps` first."""
    return torch.log(t.clamp(min = eps))
def gumbel_noise(t):
    """Sample Gumbel(0, 1) noise shaped like `t` via the inverse-CDF transform."""
    u = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(u))
def gumbel_sample(t, temperature = 1., dim = -1):
    """Sample indices from logits `t` with the Gumbel-max trick; temperature is
    floored at 1e-10 to avoid division by zero."""
    scaled = t / max(temperature, 1e-10)
    return (scaled + gumbel_noise(t)).argmax(dim = dim)
def top_p(logits, thres = 0.9):
    """Nucleus filtering: keep the smallest prefix of probability-sorted tokens
    whose cumulative mass stays within `thres`; all others become -inf."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # shift right by one so the token that crosses the threshold is still kept
    drop = F.pad(cum_probs > thres, (1, -1), value = 0)
    sorted_logits[drop] = float('-inf')

    # scatter filtered values back into their original vocabulary positions
    return sorted_logits.scatter(-1, sorted_indices, sorted_logits)
def top_k(logits, thres = 0.1, k = None):
    """Keep only the k largest logits (k defaults to ceil(thres * vocab));
    everything else is set to -inf."""
    if k is None:
        k = math.ceil(thres * logits.shape[-1])
    val, ind = torch.topk(logits, k, dim = -1)
    out = torch.full_like(logits, float('-inf'))
    out.scatter_(-1, ind, val)
    return out
# residual wrapper

class Residual(nn.Module):
    """Adds a skip connection around a wrapped callable: out = fn(x, **kwargs) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        return out + x
# rmsnorm

class RMSNorm(nn.Module):
    """RMS normalization: unit-normalizes the last dim, then rescales by
    sqrt(dim) and a learned per-channel gain."""

    def __init__(self, dim):
        super().__init__()
        # sqrt(dim) restores the expected magnitude after normalization
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        normed = F.normalize(x, dim=-1)
        return normed * self.scale * self.gamma
class GEGLU(nn.Module):
    """Gated GELU unit: the last dim is split in half, the first half being
    gated by GELU of the second."""

    def forward(self, x):
        value, gate = x.chunk(2, dim = -1)
        return F.gelu(gate) * value
def FeedForward(dim, mult = 4, dropout = 0.):
    """Build a GEGLU feedforward block: RMSNorm -> Linear (doubled inner width
    for the gate) -> GEGLU -> Dropout -> Linear back to `dim`.

    The 2/3 factor keeps the parameter count comparable to a plain mlp of
    width dim * mult.
    """
    dim_inner = int(dim * mult * 2 / 3)
    return nn.Sequential(
        RMSNorm(dim),
        nn.Linear(dim, dim_inner * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim_inner, dim)
    )
# attention

class Attention(nn.Module):
    """Multi-head attention layer with optional grouped-query kv heads, rotary
    positions, kv caching and learned null key/values (for cross-attention)."""

    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        kv_heads = None,
        causal = False,
        dim_context = None,
        dropout = 0.,
        rotary_emb: Optional[RotaryEmbedding] = None,
        flash = False,
        add_null_kv = False
    ):
        super().__init__()
        dim_context = default(dim_context, dim)

        self.heads = heads
        # grouped-query attention: fewer kv heads than query heads
        self.kv_heads = default(kv_heads, heads)
        assert (self.heads % self.kv_heads) == 0, 'number of key value heads must be divisible by query heads'

        self.scale = dim_head ** -0.5
        dim_query_inner = heads * dim_head
        dim_kv_inner = self.kv_heads * dim_head

        self.rotary_emb = rotary_emb

        self.attend = Attend(
            causal = causal,
            flash = flash,
            dropout = dropout
        )

        self.norm = RMSNorm(dim)
        # NOTE(review): attn_dropout appears unused here — dropout is applied
        # inside Attend; verify before removing
        self.attn_dropout = nn.Dropout(dropout)

        # project input to per-head queries
        self.to_q = nn.Sequential(
            nn.Linear(dim, dim_query_inner, bias = False),
            Rearrange('b n (h d) -> b h n d', h = self.heads)
        )

        # project context to per-head key/value pairs
        self.to_kv = nn.Sequential(
            nn.Linear(dim_context, dim_kv_inner * 2, bias = False),
            Rearrange('b n (kv h d) -> kv b h n d', kv = 2, h = self.kv_heads)
        )

        self.to_out = nn.Linear(dim_query_inner, dim, bias = False)

        # learned "null" kv pair prepended so attention always has a sink
        self.add_null_kv = add_null_kv
        if add_null_kv:
            self.null_kv = nn.Parameter(torch.randn(2, self.kv_heads, 1, dim_head))

    def forward(
        self,
        x,
        context = None,
        mask = None,
        cache = None,
        return_cached_key_values = False
    ):
        """Attend over `context` (defaults to self-attention over `x`).

        When `cache` is given, the cached keys/values are prepended and the
        updated cache is returned alongside the output when
        `return_cached_key_values` is set.
        """
        has_context = exists(context)
        b = x.shape[0]

        x = self.norm(x)

        context = default(context, x)

        q, k, v = (self.to_q(x), *self.to_kv(context))

        # prepend cached keys/values from previous decoding steps
        if exists(cache):
            ck, cv = cache.unbind(dim = 1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        new_cache = torch.stack((k, v), dim = 1)

        # rotary positions only make sense for self-attention
        if exists(self.rotary_emb):
            assert not has_context
            q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)

        # null kv is mutually exclusive with rotary positions
        if self.add_null_kv:
            assert not exists(self.rotary_emb)
            nk, nv = map(lambda t: repeat(t, 'h 1 d -> b h 1 d', b = b), self.null_kv)
            k = torch.cat((nk, k), dim = -2)
            v = torch.cat((nv, v), dim = -2)

            if exists(mask):
                # the prepended null position is always attendable
                mask = F.pad(mask, (1, 0), value = True)

        out = self.attend(q, k, v, mask = mask)

        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        if not return_cached_key_values:
            return out

        return out, new_cache
# transformer

class Transformer(nn.Module):
    """Stack of pre-norm attention blocks with shared rotary positions,
    optional cross-attention, key/value caching and early-exit support."""

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        kv_heads = None,
        causal = False,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        cross_attend = False,
        attn_flash = False
    ):
        super().__init__()

        # one rotary embedding shared across all self-attention layers
        rotary_emb = RotaryEmbedding(dim_head)

        self.layers = nn.ModuleList([])

        for _ in range(depth):
            # each layer: self-attention, optional cross-attention, feedforward
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, causal = causal, dim_head = dim_head, heads = heads, kv_heads = kv_heads, dropout = attn_dropout, rotary_emb = rotary_emb, flash = attn_flash),
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = attn_flash, add_null_kv = True) if cross_attend else None,
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            ]))

        self.final_norm = RMSNorm(dim)

    def forward(
        self,
        x,
        mask = None,
        context = None,
        context_mask = None,
        cache = None,
        return_cache = False,
        return_hiddens = False,
        early_exit_at_layer = None,
        seq_start_pos = None
    ):
        """Run the layer stack.

        Returns the normed output; optionally also the per-layer hidden states
        (`return_hiddens`) or the updated kv cache (`return_cache`).
        """
        has_context = exists(context)

        # derive a left-padding mask from per-sample sequence start positions
        if exists(seq_start_pos):
            assert not exists(mask)
            seq_len = x.shape[-2]
            seq_arange = torch.arange(seq_len, device = x.device, dtype = torch.long)
            mask = seq_arange >= seq_start_pos[..., None]

        # with a kv cache, only the not-yet-processed suffix is fed through
        if exists(cache):
            cached_length, seq_len = cache.shape[-2], x.shape[-2]
            assert seq_len > cached_length
            x = x[:, cached_length:]

        new_cache = []
        hiddens = []

        if exists(cache):
            iter_cache = iter(cache.unbind(dim = 1))
        else:
            iter_cache = iter([])

        for ind, (self_attn, maybe_cross_attn, ff) in enumerate(self.layers):
            layer = ind + 1  # 1-indexed to match early_exit_at_layer

            # self-attention with residual; collect updated kv cache
            residual = x
            attn_out, key_values = self_attn(x, mask = mask, cache = next(iter_cache, None), return_cached_key_values = True)
            x = attn_out + residual

            new_cache.append(key_values)

            # cross-attention (if configured) with residual
            if exists(maybe_cross_attn):
                assert has_context
                x = maybe_cross_attn(x, context = context, mask = context_mask) + x

            # feedforward with residual
            x = ff(x) + x

            hiddens.append(x)

            if exists(early_exit_at_layer) and early_exit_at_layer == layer:
                break

        # early exit returns the raw hidden state (final norm is skipped)
        if exists(early_exit_at_layer):
            if return_cache:
                return x, torch.stack(new_cache, dim = 1)

            return x

        out = self.final_norm(x)

        if return_hiddens:
            assert not return_cache
            return out, torch.stack(hiddens)

        if not return_cache:
            return out

        return out, torch.stack(new_cache, dim = 1)
# 定义 SpeechOrTextLiteral 类型,可以是'speech'或'text'中的一个
# modality tag: either 'speech' or 'text'
SpeechOrTextLiteral = Union[
    Literal['speech'],
    Literal['text']
]

# acceptable semantic feature extractors (audio -> semantic token ids)
SemanticModelType = Union[
    FairseqVQWav2Vec,
    HubertWithKmeans
]
# 定义 TextToSemantic 类,继承自 Module 类
class TextToSemantic(Module):
# 初始化函数
@beartype
def __init__(
self,
dim,
*,
source_depth,
target_depth,
num_text_token_ids = None,
tokenizer_encode: Optional[Callable] = None,
use_openai_tokenizer = False,
wav2vec: Optional[SemanticModelType] = None,
num_semantic_token_ids = None,
dim_head = 64,
heads = 8,
target_kv_heads = None, # for grouped query attention, saving memory on decoder inference
attn_dropout = 0.,
ff_mult = 4,
ff_dropout = 0.,
semantic_pad_id = -1,
text_pad_id = 0,
autoset_semantic_eos_id = True,
autoset_text_eos_id = True,
attn_flash = False,
cond_drop_prob = 0.,
target_early_exit_layer = None,
detach_early_exit_embed = False,
align_reg_loss_weight = 0.1,
align_reg_use_logsumexp_pool = True,
align_reg_logsumexp_pool_temp = 0.1
@property
def device(self):
    """Device this module lives on, taken from its first parameter."""
    first_param = next(self.parameters())
    return first_param.device
# 加载函数
def load(self, path, strict = True):
    """Load a checkpoint into this model.

    Returns the raw checkpoint package, so a Trainer calling this can also
    restore optimizer / scheduler state from the same file.
    """
    checkpoint = Path(path)
    assert checkpoint.exists()

    pkg = torch.load(str(checkpoint), map_location = 'cpu')
    self.load_state_dict(pkg['model'], strict = strict)
    return pkg
# 一组冻结/解冻工具
# 然后依赖 get_optimizer 来过滤不需要梯度的参数,使其暴露给优化器
# 解冻所有参数
# unfreeze every parameter of the model (get_optimizer later filters by requires_grad)
def unfreeze_all(self):
    unfreeze(self)
# 冻结编码器
# freeze the entire encoder (source transformer)
def freeze_encoder(self):
    freeze(self.source_transformer)
# 冻结编码器到某一层
def freeze_encoder_below_layer(self, layer: int):
    """
    for the final training of text-to-semantic on the pseudo-labelled dataset,
    the encoder is frozen only part way, up to (and including) the given layer
    """
    unfreeze(self.source_transformer)

    for ind, module in enumerate(self.source_transformer.layers):
        current_layer = ind + 1

        if current_layer <= layer:
            freeze(module)
# 冻结解码器
# freeze the entire decoder (target transformer)
def freeze_decoder(self):
    freeze(self.target_transformer)
# 冻结语音嵌入
# freeze the speech token embedding and its start token
def freeze_speech_emb(self):
    freeze(self.token_emb['speech'])
    self.start_token['speech'].requires_grad = False
# 冻结文本嵌入
# freeze the text token embedding and its start token
def freeze_text_emb(self):
    freeze(self.token_emb['text'])
    self.start_token['text'].requires_grad = False
# 采样函数
@torch.no_grad()
@eval_decorator
@beartype
def generate(
self,
source: Union[List[str], Tensor],
*,
source_type: SpeechOrTextLiteral,
target_type: SpeechOrTextLiteral,
temperature = 1.,
filter_logits_fn = top_k,
filter_fn_kwargs: dict = dict(),
source_mask: Optional[Tensor] = None,
max_length = 2048,
beam_search_decode = False,
spec_decode = False,
spec_decode_gamma = 5,
spec_decode_lenience = 1.,
beam_size = 4,
return_source = False,
return_target_mask = False,
cond_scale = 1.
@beartype
def forward(
self,
source: Union[List[str], Tensor],
target: Union[List[str], Tensor],
*,
source_type: SpeechOrTextLiteral,
target_type: SpeechOrTextLiteral,
source_mask: Optional[Tensor] = None,
target_mask: Optional[Tensor] = None,
return_loss = False,
return_logits = False,
cond_drop_prob: Optional[float] = None,
should_sim_regularize = True,
return_early_exit_loss = False
# 预训练模块
# 获取掩码子集概率函数
def get_mask_subset_prob(mask, prob, min_mask = 0):
    """Randomly select a fraction `prob` of the True positions in `mask`.

    Returns a boolean tensor of the same shape that is True on the chosen
    subset and False everywhere else (never True where `mask` is False).
    At least `min_mask` positions are targeted per row.
    """
    batch, seq = mask.shape
    device = mask.device

    # how many positions to pick per row, at least min_mask
    target_count = (mask.sum(dim=-1, keepdim=True) * prob).clamp(min=min_mask)

    # random scores; disallowed positions are pushed below every valid score
    scores = torch.rand((batch, seq), device=device)
    scores = scores.masked_fill(~mask, -1)

    # argsort of random scores yields a random permutation; disallowed
    # positions (score -1) sort to the front
    rand_order = scores.argsort(dim=-1).float()

    # shift so valid positions occupy the range [0, num_valid)
    pad_count = (~mask).sum(dim=-1, keepdim=True)
    rand_order = rand_order - pad_count

    chosen = rand_order < target_count
    chosen.masked_fill_(~mask, False)
    return chosen
# 定义一个包装器类,用于语音到语义预训练任务
# wrapper for the speech-to-speech pretraining objective (deletion denoising)
class SpeechSpeechPretrainWrapper(nn.Module):
    @beartype
    def __init__(
        self,
        model: TextToSemantic,
        wav2vec: Optional[SemanticModelType] = None,
        deletion_prob: float = 0.6,
        reconstruct_seq: bool = False,
        mask_id = None
    ):
        super().__init__()

        self.model = model
        # fall back to the semantic model already attached to `model`
        self.wav2vec = default(wav2vec, model.wav2vec)

        # fraction of tokens to corrupt per sequence
        self.deletion_prob = deletion_prob
        # whether the target is the full sequence or only the deleted tokens
        self.reconstruct_seq = reconstruct_seq
        # when set, corrupted positions are replaced with this id instead of removed
        self.mask_id = mask_id

    def forward(
        self,
        x,
        return_early_exit_loss = False
    ):
        """Corrupt semantic token ids and train the model to recover them.

        Float input is treated as raw audio and first converted to semantic
        token ids with the (frozen) wav2vec model.
        """
        is_raw_audio = x.dtype == torch.float

        if is_raw_audio:
            assert exists(self.wav2vec)

            with torch.no_grad():
                self.wav2vec.eval()
                x = self.wav2vec(x, flatten = False)

        batch = x.shape[0]

        mask = torch.ones_like(x, dtype = torch.bool, device = self.model.device)

        if exists(self.mask_id):
            # mask-token variant: corrupted positions are replaced by mask_id;
            # padding positions are excluded from corruption
            assert self.reconstruct_seq, 'reconstruct_seq must be true if mask id is provided'
            mask = mask.masked_fill(x == self.model.semantic_pad_id, False)
            delete_mask = get_mask_subset_prob(mask, self.deletion_prob)

            source = x.masked_fill(delete_mask, self.mask_id)
        else:
            # deletion variant: corrupted positions are removed entirely
            # NOTE(review): the rearrange assumes every sample in the batch keeps
            # the same number of tokens after deletion - confirm get_mask_subset_prob
            # guarantees an equal per-row count here
            delete_mask = get_mask_subset_prob(mask, self.deletion_prob)

            source = rearrange(x[~delete_mask], '(b n) -> b n', b = batch)

        if self.reconstruct_seq:
            target = x
        else:
            target = rearrange(x[delete_mask], '(b n) -> b n', b = batch)

        loss, logits = self.model(
            source, target,
            source_type = 'speech',
            target_type = 'speech',
            return_loss = True,
            return_logits = True,
            return_early_exit_loss = return_early_exit_loss,
        )

        return loss, logits
# 包装器类,用于反向翻译任务
class SemanticToTextWrapper(nn.Module):
    """Wrapper for the backtranslation task: speech semantic tokens -> text graphemes."""

    @beartype
    def __init__(
        self,
        model: TextToSemantic
    ):
        super().__init__()
        self.model = model

    def forward(
        self,
        semantic_token_ids,
        grapheme_token_ids,
    ):
        # translate from the speech side to the text side, training with loss
        loss, logits = self.model(
            semantic_token_ids, grapheme_token_ids,
            source_type = 'speech',
            target_type = 'text',
            return_loss = True,
            return_logits = True
        )

        return loss, logits
# 包装器类,用于文本到语义任务
class TextToSemanticWrapper(nn.Module):
    """Wrapper for the text-to-semantic task: graphemes -> speech semantic tokens."""

    @beartype
    def __init__(
        self,
        model: TextToSemantic
    ):
        super().__init__()
        self.model = model

    def forward(
        self,
        grapheme_token_ids,
        semantic_token_ids,
        return_early_exit_loss = True
    ):
        # translate from the text side to the speech side, training with loss
        loss, logits = self.model(
            grapheme_token_ids, semantic_token_ids,
            source_type = 'text',
            target_type = 'speech',
            return_loss = True,
            return_logits = True,
            return_early_exit_loss = return_early_exit_loss
        )

        return loss, logits
# 包装器类,用于生成伪标记的音频到文本数据集
class SemanticToTextDatasetGenerator(nn.Module):
# 初始化方法
@beartype
def __init__(
self,
model, # 模型
*,
dataset: Dataset, # 数据集
folder = './generated-audio-text-pairs', # 文件夹路径
batch_size = 4, # 批次大小
delimiter_id: int = -1, # 分隔符 ID
audio_pad_id = None, # 音频填充 ID
text_pad_id = 0 # 文本填充 ID
# 初始化函数,设置模型、数据集、数据加载器等参数
def __init__(
    self,
    model,
    dataset,
    batch_size,
    delimiter_id,
    audio_pad_id,
    text_pad_id,
    folder
):
    """Store the model, build a dataloader over `dataset`, and prepare the output folder."""
    super().__init__()
    self.model = model

    self.dataset = dataset
    self.dl = get_dataloader(dataset, batch_size=batch_size)

    # token id used to separate the audio semantic ids from the text ids in each saved row
    self.delimiter_id = delimiter_id

    # pad ids to strip from generated sequences before saving (None disables stripping)
    self.audio_pad_id = audio_pad_id
    self.text_pad_id = text_pad_id

    self.folder = Path(folder)
    self.folder.mkdir(exist_ok=True, parents=True)
# 前向传播函数,生成文本数据
def forward(
    self,
    max_length=2048,
    beam_search_decode=True,
    **generate_kwargs
):
    """Generate pseudo-labelled (audio semantic ids, text ids) pairs and save
    each pair to `<folder>/<counter>.pt` as a single packed row."""
    delimiter = torch.tensor([self.delimiter_id], device=self.model.device)

    # running counter used to name the saved files
    counter = 0

    for audio, in self.dl:
        # transcribe the audio batch into text, also returning the source semantic ids
        audio_semantic_ids, text_ids = self.model.generate(
            source=audio,
            source_type='speech',
            target_type='text',
            return_source=True,
            max_length=max_length,
            beam_search_decode=beam_search_decode,
            **generate_kwargs
        )

        for audio_semantic_id, text_id in zip(audio_semantic_ids, text_ids):
            # strip audio padding tokens, if a pad id was configured
            if exists(self.audio_pad_id):
                audio_pad_mask = audio_semantic_id == self.audio_pad_id
                audio_semantic_id = audio_semantic_id[~audio_pad_mask]

            # strip text padding tokens, if a pad id was configured
            if exists(self.text_pad_id):
                text_pad_mask = text_id == self.text_pad_id
                text_id = text_id[~text_pad_mask]

            # pack <audio ids> <delimiter> <text ids> into one flat tensor row
            row, _ = pack([audio_semantic_id, delimiter, text_id], '*')

            path = str(self.folder / f'{counter}.pt')
            torch.save(row, path)

            counter += 1
.\lucidrains\spear-tts-pytorch\spear_tts_pytorch\trainer.py
# 导入必要的库
import re
from pathlib import Path
from shutil import rmtree
# 导入 beartype 库中的函数和类型
from beartype import beartype
from beartype.door import is_bearable
from beartype.typing import Union, Optional, Tuple
# 导入 PyTorch 库
import torch
from torch import nn, LongTensor, IntTensor
from torch.utils.data import ConcatDataset
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, random_split
# 导入 audiolm_pytorch 库中的模型和函数
from audiolm_pytorch import FairseqVQWav2Vec, HubertWithKmeans
from audiolm_pytorch.data import get_dataloader
from audiolm_pytorch.optimizer import get_optimizer
# 导入 spear_tts_pytorch 库中的模型和数据集
from spear_tts_pytorch.spear_tts_pytorch import SpeechSpeechPretrainWrapper, TextToSemantic, SemanticToTextWrapper, TextToSemanticWrapper
from spear_tts_pytorch.data import GeneratedAudioTextDataset
# 导入 accelerate 库中的加速器和分布式类型
from accelerate import Accelerator, DistributedType
# 定义类型别名
# tensors of integer token indices
IndicesTensor = Union[LongTensor, IntTensor]

# guard flag: only a single Trainer may be instantiated per process
ONE_TRAINER_INSTANTIATED = False

def check_one_trainer():
    # called from every Trainer __init__; raises on the second instantiation
    global ONE_TRAINER_INSTANTIATED
    assert not ONE_TRAINER_INSTANTIATED, 'only one Trainer can be instantiated at a time for training'
    ONE_TRAINER_INSTANTIATED = True
# 辅助函数
# 检查值是否存在
def exists(val):
    """True when `val` is anything other than None."""
    return not (val is None)
# 空操作函数
def noop(*args, **kwargs):
    """Do nothing; accepts and ignores any arguments (default log_fn)."""
    return None
# 无限循环生成数据集
def cycle(dl):
    """Yield items from `dl` forever, restarting each time it is exhausted."""
    while True:
        yield from dl
# 将输入转换为元组
def cast_tuple(t):
    """Wrap `t` in a 1-tuple unless it is already a tuple or list."""
    if isinstance(t, (tuple, list)):
        return t
    return (t,)
# 询问用户是或否
def yes_or_no(question):
    """Prompt the user with a yes/no question; True for 'y' or 'yes' (any case)."""
    reply = input(f'{question} (y/n) ').lower()
    return reply in ('yes', 'y')
# 累积日志信息
def accum_log(log, new_logs):
    """Accumulate `new_logs` into `log` in place, summing per key, and return `log`."""
    for key, new_value in new_logs.items():
        log[key] = log.get(key, 0.) + new_value
    return log
# 从检查点文件名中获取训练步数
def checkpoint_num_steps(checkpoint_path):
    """Returns the number of steps trained from a checkpoint based on the filename.

    Filename format assumed to be something like "/path/to/speech.speech.20000.pt" which is
    for 20k train steps. Returns 20000 in that case. Returns 0 when the path
    contains no digits at all.
    """
    # fix: original was missing the closing parenthesis (SyntaxError)
    results = re.findall(r'\d+', str(checkpoint_path))

    if len(results) == 0:
        return 0

    # the step count is the last run of digits in the filename
    return int(results[-1])
# 定义 SpeechSpeechPretrainer 类
class SpeechSpeechPretrainer(nn.Module):
    """Trainer for the speech-to-speech (deletion denoising) pretraining stage.

    Wraps a TextToSemantic model in a SpeechSpeechPretrainWrapper, manages the
    optimizer with linear warmup followed by cosine annealing, the train/valid
    split, periodic validation, logging, and checkpointing via accelerate.
    """

    @beartype
    def __init__(
        self,
        model: TextToSemantic,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],
        *,
        num_train_steps,
        num_warmup_steps,
        batch_size,
        dataset: Optional[Dataset] = None,
        deletion_prob: float = 0.6,
        reconstruct_seq: bool = False,
        mask_id = None,
        lr = 3e-4,
        initial_lr = 1e-5,
        grad_accum_every = 1,
        wd = 0.,
        max_grad_norm = 0.5,
        valid_frac = 0.05,
        random_split_seed = 42,
        log_every = 10,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        accelerate_kwargs: dict = dict(),
        split_batches = False,
        drop_last = False,
        force_clear_prev_results = None
    ):
        super().__init__()

        # only one Trainer may exist per process
        check_one_trainer()

        self.accelerator = Accelerator(
            split_batches = split_batches,
            **accelerate_kwargs
        )

        self.model = model
        self.wav2vec = wav2vec

        # the wrapper implements the corruption + reconstruction objective
        self.train_wrapper = SpeechSpeechPretrainWrapper(
            model = model,
            wav2vec = wav2vec,
            deletion_prob = deletion_prob,
            reconstruct_seq = reconstruct_seq,
            mask_id = mask_id
        )

        # training step counter, kept as a buffer so it moves with the module
        self.register_buffer('steps', torch.Tensor([0]))

        self.num_train_steps = num_train_steps
        self.num_warmup_steps = num_warmup_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        # optimizer and cosine schedule (warmup is handled manually in train_step)
        self.lr = lr
        self.initial_lr = initial_lr
        self.optim = get_optimizer(model.parameters(), lr = lr, wd = wd)
        self.scheduler = CosineAnnealingLR(self.optim, T_max = num_train_steps)

        self.max_grad_norm = max_grad_norm

        # dataset and train/valid split
        self.ds = dataset

        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        assert len(self.ds) >= batch_size, 'dataset must have sufficient samples for training'
        assert len(self.valid_ds) >= batch_size, f'validation dataset must have sufficient number of samples (currently {len(self.valid_ds)}) for training'

        # dataloaders
        self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)
        self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)

        # let accelerate wrap everything for (possibly distributed) training
        (
            self.train_wrapper,
            self.optim,
            self.scheduler,
            self.dl,
            self.valid_dl
        ) = self.accelerator.prepare(
            self.train_wrapper,
            self.optim,
            self.scheduler,
            self.dl,
            self.valid_dl
        )

        # infinite iterators over the dataloaders
        self.dl_iter = cycle(self.dl)
        self.valid_dl_iter = cycle(self.valid_dl)

        self.log_every = log_every
        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        self.results_folder = Path(results_folder)

        # fix: parenthesized so ONLY the main process can prompt for / clear previous
        # results (the original `A and B or C` bound as `(A and B) or C`, letting
        # non-main processes run rmtree)
        if self.is_main and (force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'))):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)

        # fix: report the actual initial learning rate (was logging `lr` twice)
        hps = {"num_train_steps": num_train_steps, "num_warmup_steps": num_warmup_steps, "learning_rate": lr, "initial_learning_rate": initial_lr}
        self.accelerator.init_trackers("speechspeech", config=hps)

    def save(self, path):
        """Save model, optimizer and scheduler state to `path`."""
        pkg = dict(
            model = self.accelerator.get_state_dict(self.model),
            optim = self.optim.state_dict(),
            scheduler = self.scheduler.state_dict()
        )
        torch.save(pkg, path)

    def load(self, path):
        """Load model weights plus optimizer/scheduler state from `path`."""
        model = self.accelerator.unwrap_model(self.model)
        pkg = model.load(path)

        self.optim.load_state_dict(pkg['optim'])
        self.scheduler.load_state_dict(pkg['scheduler'])

        # + 1 to start from the next step and avoid overwriting the last checkpoint
        self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)

    def print(self, msg):
        # print only once across processes
        self.accelerator.print(msg)

    def generate(self, *args, **kwargs):
        # NOTE(review): delegates to the wrapper, which does not define generate
        # itself - confirm this resolves to the underlying model as intended
        return self.train_wrapper.generate(*args, **kwargs)

    @property
    def device(self):
        return self.accelerator.device

    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    def warmup(self, step):
        """Linear warmup from initial_lr to lr over num_warmup_steps."""
        if step < self.num_warmup_steps:
            return self.initial_lr + (self.lr - self.initial_lr) * step / self.num_warmup_steps
        else:
            return self.lr

    def train_step(self):
        """One optimization step with gradient accumulation, plus periodic
        validation, logging and checkpointing. Returns the step's log dict."""
        steps = int(self.steps.item())

        self.model.train()

        # learning rate schedule: manual linear warmup, then cosine annealing
        if steps < self.num_warmup_steps:
            lr = self.warmup(steps)
            for param_group in self.optim.param_groups:
                param_group['lr'] = lr
        else:
            self.scheduler.step()

        logs = {}

        # accumulate gradients over grad_accum_every batches
        for _ in range(self.grad_accum_every):
            x, = next(self.dl_iter)

            loss, _ = self.train_wrapper(x)

            self.accelerator.backward(loss / self.grad_accum_every)

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        self.optim.step()
        self.optim.zero_grad()

        # logging
        if not (steps % self.log_every):
            self.print(f"{steps}: loss: {logs['loss']:0.3f}")
        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # periodically validate on a single batch (main process only)
        self.accelerator.wait_for_everyone()

        if self.is_main and not (steps % self.save_results_every):
            x, = next(self.valid_dl_iter)

            with torch.inference_mode():
                self.train_wrapper.eval()
                valid_loss, _ = self.train_wrapper(x)

            self.print(f'{steps}: valid loss {valid_loss:0.3f}')
            self.accelerator.log({"valid_loss": valid_loss}, step=steps)

        # periodically checkpoint
        if self.is_main and not (steps % self.save_model_every):
            model_path = str(self.results_folder / f'speech.speech.{steps}.pt')
            self.save(model_path)

            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        self.steps += 1
        return logs

    def train(self, log_fn = noop):
        """Run train_step until num_train_steps is reached."""
        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        self.print('training complete')
# 定义一个用于将语义转换为文本的训练器类
class SemanticToTextTrainer(nn.Module):
    """Trainer for the backtranslation stage: semantic speech tokens -> text.

    Freezes the encoder and speech embeddings (only the decoder side is
    trained), then runs warmup + cosine annealing training with periodic
    validation, logging and checkpointing via accelerate.
    """

    @beartype
    def __init__(
        self,
        model: TextToSemantic,
        *,
        num_train_steps,
        num_warmup_steps,
        batch_size,
        dataset: Optional[Dataset] = None,
        lr = 3e-4,
        initial_lr = 1e-5,
        grad_accum_every = 1,
        wd = 0.,
        max_grad_norm = 0.5,
        valid_frac = 0.05,
        random_split_seed = 42,
        log_every = 10,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        accelerate_kwargs: dict = dict(),
        split_batches = False,
        drop_last = False,
        force_clear_prev_results = None
    ):
        super().__init__()

        # only one Trainer may exist per process
        check_one_trainer()

        self.accelerator = Accelerator(
            split_batches = split_batches,
            **accelerate_kwargs
        )

        self.model = model

        self.train_wrapper = SemanticToTextWrapper(model = model)

        # training step counter, kept as a buffer so it moves with the module
        self.register_buffer('steps', torch.Tensor([0]))

        self.num_train_steps = num_train_steps
        self.num_warmup_steps = num_warmup_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        # when doing backtranslation, freeze the encoder and speech embeddings
        model.unfreeze_all()
        model.freeze_speech_emb()
        model.freeze_encoder()

        # optimizer
        # get_optimizer should filter out frozen parameters (requires_grad False)
        self.optim = get_optimizer(
            model.parameters(),
            lr = lr,
            wd = wd,
            filter_by_requires_grad = True
        )

        self.lr = lr
        self.initial_lr = initial_lr
        self.scheduler = CosineAnnealingLR(self.optim, T_max = num_train_steps)

        self.max_grad_norm = max_grad_norm

        # dataset and train/valid split
        self.ds = dataset

        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        assert len(self.ds) >= batch_size, 'dataset must have sufficient samples for training'
        assert len(self.valid_ds) >= batch_size, f'validation dataset must have sufficient number of samples (currently {len(self.valid_ds)}) for training'

        # dataloaders
        self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)
        self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)

        # let accelerate wrap everything for (possibly distributed) training
        (
            self.train_wrapper,
            self.optim,
            self.scheduler,
            self.dl,
            self.valid_dl
        ) = self.accelerator.prepare(
            self.train_wrapper,
            self.optim,
            self.scheduler,
            self.dl,
            self.valid_dl
        )

        # infinite iterators over the dataloaders
        self.dl_iter = cycle(self.dl)
        self.valid_dl_iter = cycle(self.valid_dl)

        self.log_every = log_every
        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        self.results_folder = Path(results_folder)

        # fix: parenthesized so ONLY the main process can prompt for / clear previous
        # results (the original `A and B or C` bound as `(A and B) or C`, letting
        # non-main processes run rmtree)
        if self.is_main and (force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'))):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)

        # fix: report the actual initial learning rate (was logging `lr` twice)
        hps = {"num_train_steps": num_train_steps, "num_warmup_steps": num_warmup_steps, "learning_rate": lr, "initial_learning_rate": initial_lr}
        self.accelerator.init_trackers("semantictext", config=hps)

    def save(self, path):
        """Save model, optimizer and scheduler state to `path`."""
        pkg = dict(
            model = self.accelerator.get_state_dict(self.model),
            optim = self.optim.state_dict(),
            scheduler = self.scheduler.state_dict()
        )
        torch.save(pkg, path)

    def load(self, path, restore_optimizer = True):
        """Load model weights (and optionally optimizer/scheduler state) from `path`."""
        model = self.accelerator.unwrap_model(self.model)
        pkg = model.load(path)

        if restore_optimizer:
            self.optim.load_state_dict(pkg['optim'])
            self.scheduler.load_state_dict(pkg['scheduler'])

            # + 1 to start from the next step and avoid overwriting the last checkpoint
            self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)

    def print(self, msg):
        # print only once across processes
        self.accelerator.print(msg)

    def generate(self, *args, **kwargs):
        # NOTE(review): delegates to the wrapper, which does not define generate
        # itself - confirm this resolves to the underlying model as intended
        return self.train_wrapper.generate(*args, **kwargs)

    @property
    def device(self):
        return self.accelerator.device

    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    def warmup(self, step):
        """Linear warmup from initial_lr to lr over num_warmup_steps."""
        if step < self.num_warmup_steps:
            return self.initial_lr + (self.lr - self.initial_lr) * step / self.num_warmup_steps
        else:
            return self.lr

    def train_step(self):
        """One optimization step with gradient accumulation, plus periodic
        validation, logging and checkpointing. Returns the step's log dict."""
        steps = int(self.steps.item())

        self.model.train()

        # learning rate schedule: manual linear warmup, then cosine annealing
        if steps < self.num_warmup_steps:
            lr = self.warmup(steps)
            for param_group in self.optim.param_groups:
                param_group['lr'] = lr
        else:
            self.scheduler.step()

        logs = {}

        # accumulate gradients over grad_accum_every batches
        for _ in range(self.grad_accum_every):
            semantic_token_ids, grapheme_token_ids = next(self.dl_iter)

            loss, _ = self.train_wrapper(semantic_token_ids = semantic_token_ids, grapheme_token_ids = grapheme_token_ids)

            self.accelerator.backward(loss / self.grad_accum_every)

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        self.optim.step()
        self.optim.zero_grad()

        # logging
        if not (steps % self.log_every):
            self.print(f"{steps}: loss: {logs['loss']:0.3f}")
        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # periodically validate on a single batch (main process only)
        self.accelerator.wait_for_everyone()

        if self.is_main and not (steps % self.save_results_every):
            semantic_token_ids, grapheme_token_ids = next(self.valid_dl_iter)

            with torch.inference_mode():
                self.train_wrapper.eval()
                valid_loss, _ = self.train_wrapper(semantic_token_ids = semantic_token_ids, grapheme_token_ids = grapheme_token_ids)

            self.print(f'{steps}: valid loss {valid_loss:0.3f}')
            self.accelerator.log({"valid_loss": valid_loss}, step=steps)

        # periodically checkpoint
        if self.is_main and not (steps % self.save_model_every):
            model_path = str(self.results_folder / f'semantic.text.{steps}.pt')
            self.save(model_path)

            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        self.steps += 1
        return logs

    def train(self, log_fn = noop):
        """Run train_step until num_train_steps is reached."""
        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        self.print('training complete')
# 定义一个用于训练文本到语义模型的类
class TextToSemanticTrainer(nn.Module):
# 初始化函数,接受模型、训练步数、预热步数等参数
@beartype
def __init__(
self,
model: TextToSemantic,
*,
num_train_steps,
num_warmup_steps,
batch_size,
dataset: Optional[Dataset] = None,
generated_audio_text_dataset_folder = None,
dataset_delimiter_id = -1,
lr = 3e-4,
initial_lr = 1e-5,
grad_accum_every = 1,
wd = 0.,
max_grad_norm = 0.5,
valid_frac = 0.05,
random_split_seed = 42,
log_every = 10,
save_results_every = 100,
save_model_every = 1000,
results_folder = './results',
accelerate_kwargs: dict = dict(),
split_batches = False,
drop_last = False,
force_clear_prev_results = None,
freeze_encoder_layers_below = 2,
should_train_early_exit_layer_if_available = True
# 保存模型参数到指定路径
# save model / optimizer / scheduler state to `path`
def save(self, path):
    pkg = dict(
        model = self.accelerator.get_state_dict(self.model),
        optim = self.optim.state_dict(),
        scheduler = self.scheduler.state_dict()
    )
    torch.save(pkg, path)
# 从指定路径加载模型参数,可选择是否还原优化器状态
def load(self, path, restore_optimizer = True):
    """Load model weights (and optionally optimizer/scheduler state) from `path`."""
    model = self.accelerator.unwrap_model(self.model)
    pkg = model.load(path)

    if restore_optimizer:
        self.optim.load_state_dict(pkg['optim'])
        self.scheduler.load_state_dict(pkg['scheduler'])

        # + 1 to start from the next step and avoid overwriting the last checkpoint
        self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)
# 打印消息
# print only once across processes (accelerate-aware)
def print(self, msg):
    self.accelerator.print(msg)
# 生成结果
# delegate generation to the training wrapper
def generate(self, *args, **kwargs):
    return self.train_wrapper.generate(*args, **kwargs)
# 返回设备信息
# device managed by accelerate
@property
def device(self):
    return self.accelerator.device
# 判断是否为分布式训练
# True when running with more than one process / a distributed backend
@property
def is_distributed(self):
    return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
# 判断是否为主进程
# True only on the global main process
@property
def is_main(self):
    return self.accelerator.is_main_process
# 判断是否为本地主进程
# True only on the main process of the local machine
@property
def is_local_main(self):
    return self.accelerator.is_local_main_process
# 根据当前步数计算学习率
def warmup(self, step):
    """Linearly ramp the learning rate from initial_lr to lr over num_warmup_steps."""
    if step >= self.num_warmup_steps:
        return self.lr

    progress = step / self.num_warmup_steps
    return self.initial_lr + (self.lr - self.initial_lr) * progress
# 定义训练步骤函数
def train_step(self):
    """One optimization step (with gradient accumulation), plus periodic
    validation, logging and checkpointing. Returns the step's log dict."""
    steps = int(self.steps.item())

    self.model.train()

    # learning rate schedule: manual linear warmup, then cosine annealing
    if steps < self.num_warmup_steps:
        lr = self.warmup(steps)
        for param_group in self.optim.param_groups:
            param_group['lr'] = lr
    else:
        self.scheduler.step()

    logs = {}

    # accumulate gradients over grad_accum_every batches
    for _ in range(self.grad_accum_every):
        semantic_token_ids, grapheme_token_ids = next(self.dl_iter)

        loss, _ = self.train_wrapper(semantic_token_ids=semantic_token_ids, grapheme_token_ids=grapheme_token_ids, return_early_exit_loss=self.train_early_exit)

        self.accelerator.backward(loss / self.grad_accum_every)

        accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

    # gradient clipping, then optimizer step
    if exists(self.max_grad_norm):
        self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

    self.optim.step()
    self.optim.zero_grad()

    # logging
    if not (steps % self.log_every):
        self.print(f"{steps}: loss: {logs['loss']:0.3f}")
    self.accelerator.log({"train_loss": logs['loss']}, step=steps)

    # periodically validate on a single batch (main process only)
    self.accelerator.wait_for_everyone()

    if self.is_main and not (steps % self.save_results_every):
        semantic_token_ids, grapheme_token_ids = next(self.valid_dl_iter)

        with torch.inference_mode():
            self.train_wrapper.eval()
            valid_loss, _ = self.train_wrapper(semantic_token_ids=semantic_token_ids, grapheme_token_ids=grapheme_token_ids, return_early_exit_loss=self.train_early_exit)

        self.print(f'{steps}: valid loss {valid_loss:0.3f}')
        self.accelerator.log({"valid_loss": valid_loss}, step=steps)

    # periodically checkpoint
    if self.is_main and not (steps % self.save_model_every):
        model_path = str(self.results_folder / f'text.semantic.{steps}.pt')
        self.save(model_path)

        self.print(f'{steps}: saving model to {str(self.results_folder)}')

    self.steps += 1
    return logs
# 训练函数
def train(self, log_fn=noop):
    """Run train_step until num_train_steps is reached, passing each step's logs to log_fn."""
    while self.steps < self.num_train_steps:
        logs = self.train_step()
        log_fn(logs)

    self.print('training complete')