
Lucidrains Series Projects Source Code Analysis (Part 93)

.\lucidrains\soundstorm-pytorch\soundstorm_pytorch\trainer.py

# 导入必要的模块
from pathlib import Path
import re
from shutil import rmtree

# 导入 beartype 模块及相关类型
from beartype import beartype
from beartype.typing import Optional

# 导入 PyTorch 相关模块
import torch
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, random_split

# 导入自定义模块
from audiolm_pytorch.data import get_dataloader
from audiolm_pytorch.optimizer import get_optimizer

from soundstorm_pytorch.soundstorm import SoundStorm

# 导入加速器模块及分布式类型
from accelerate import Accelerator, DistributedType

# 定义一些辅助函数

# 判断值是否存在
def exists(val):
    return val is not None

# 空操作函数
def noop(*args, **kwargs):
    pass

# 生成数据循环
def cycle(dl):
    while True:
        for data in dl:
            yield data

# 将输入转换为元组
def cast_tuple(t):
    return t if isinstance(t, (tuple, list)) else (t,)

# 询问用户是或否
def yes_or_no(question):
    answer = input(f'{question} (y/n) ')
    return answer.lower() in ('yes', 'y')

# 累积日志信息
def accum_log(log, new_logs):
    for key, new_value in new_logs.items():
        old_value = log.get(key, 0.)
        log[key] = old_value + new_value
    return log

# 从检查点文件名中获取训练步数
def checkpoint_num_steps(checkpoint_path):
    """Returns the number of steps trained from a checkpoint based on the filename.

    Filename format assumed to be something like "/path/to/soundstorm.20000.pt" which is
    for 20k train steps. Returns 20000 in that case.
    """
    results = re.findall(r'\d+', str(checkpoint_path))

    if len(results) == 0:
        return 0

    return int(results[-1])
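As a quick illustration of the parsing above (the file names are hypothetical):

# hypothetical checkpoint names, showing what checkpoint_num_steps returns
assert checkpoint_num_steps('/path/to/soundstorm.20000.pt') == 20000
assert checkpoint_num_steps('soundstorm.pt') == 0   # no digits in the name -> treated as step 0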

# 定义 SoundStormTrainer 类
class SoundStormTrainer(nn.Module):
    @beartype
    def __init__(
        self,
        model: SoundStorm,
        *,
        num_train_steps,
        num_warmup_steps,
        batch_size,
        dataset: Optional[Dataset] = None,
        only_train_generator = False,
        only_train_critic = False,
        lr = 3e-4,
        initial_lr = 1e-5,
        grad_accum_every = 1,
        wd = 0.,
        max_grad_norm = 0.5,
        valid_frac = 0.05,
        random_split_seed = 42,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        accelerate_kwargs: dict = dict(),
        split_batches = False,
        drop_last = False,
        force_clear_prev_results = None
    ):
        # call the parent class constructor
        super().__init__()

        # 初始化加速器对象
        self.accelerator = Accelerator(
            split_batches = split_batches,
            **accelerate_kwargs
        )

        # 设置模型
        self.model = model

        # 注册缓冲区,存储训练步数
        self.register_buffer('steps', torch.Tensor([0]))

        # 设置训练步数、预热步数、批量大小、梯度累积步数等参数
        self.num_train_steps = num_train_steps
        self.num_warmup_steps = num_warmup_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every
        
        self.only_train_generator = only_train_generator
        self.only_train_critic = only_train_critic

        # 初始化优化器
        self.optim = get_optimizer(
            model.parameters(),
            lr = lr,
            wd = wd
        )

        self.lr = lr
        self.initial_lr = initial_lr
        # 设置学习率调度器为余弦退火调度器
        self.scheduler = CosineAnnealingLR(self.optim, T_max = num_train_steps)

        # 设置梯度裁剪阈值
        self.max_grad_norm = max_grad_norm

        # 创建数据集
        self.ds = dataset

        # 划分验证集
        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        # 断言确保数据集和验证集的样本数足够
        assert len(self.ds) >= batch_size, 'dataset must have sufficient samples for training'
        assert len(self.valid_ds) >= batch_size, f'validation dataset must have sufficient number of samples (currently {len(self.valid_ds)}) for training'

        # 创建数据加载器
        self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)
        self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)

        # 使用加速器准备模型、优化器、调度器、数据加载器
        (
            self.model,
            self.optim,
            self.scheduler,
            self.dl,
            self.valid_dl
        ) = self.accelerator.prepare(
            self.model,
            self.optim,
            self.scheduler,
            self.dl,
            self.valid_dl
        )

        # 创建数据加载器迭代器
        self.dl_iter = cycle(self.dl)
        self.valid_dl_iter = cycle(self.valid_dl)

        # 设置保存模型和结果的频率
        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        # 设置结果文件夹路径
        self.results_folder = Path(results_folder)

        # 如果是主进程且需要清除之前的结果,则清除结果文件夹
        if self.is_main and force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
            rmtree(str(self.results_folder))

        # 创建结果文件夹
        self.results_folder.mkdir(parents = True, exist_ok = True)
        
        # 初始化超参数追踪器
        hps = {"num_train_steps": num_train_steps, "num_warmup_steps": num_warmup_steps, "learning_rate": lr, "initial_learning_rate": lr}
        self.accelerator.init_trackers("soundstorm", config=hps)

    # 保存模型方法
    def save(self, path):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.model),
            optim = self.optim.state_dict(),
            scheduler = self.scheduler.state_dict()
        )
        torch.save(pkg, path)

    # 加载模型方法
    def load(self, path, restore_optimizer = True):
        model = self.accelerator.unwrap_model(self.model)
        pkg = model.load(path)

        # 如果需要恢复优化器状态,则加载优化器和调度器状态
        if restore_optimizer:
            self.optim.load_state_dict(pkg['optim'])
            self.scheduler.load_state_dict(pkg['scheduler'])

            # + 1 to start from the next step and avoid overwriting the last checkpoint
            self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)
    # 打印消息,调用加速器对象的打印方法
    def print(self, msg):
        self.accelerator.print(msg)

    # 生成结果,调用模型对象的生成方法
    def generate(self, *args, **kwargs):
        return self.model.generate(*args, **kwargs)

    # 返回设备信息,调用加速器对象的设备属性
    @property
    def device(self):
        return self.accelerator.device

    # 返回是否分布式训练,判断加速器对象的分布式类型和进程数是否为1
    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    # 返回是否为主进程,判断加速器对象是否为主进程
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # 返回是否为本地主进程,判断加速器对象是否为本地主进程
    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    # 预热方法,根据步数计算学习率
    def warmup(self, step):
        if step < self.num_warmup_steps:
            return self.initial_lr + (self.lr - self.initial_lr) * step / self.num_warmup_steps
        else:
            return self.lr
    # 定义训练步骤函数
    def train_step(self):
        # 获取当前步数
        steps = int(self.steps.item())

        # 将模型设置为训练模式
        self.model.train()
        
        # 根据训练步数调整学习率
        if steps < self.num_warmup_steps:
            # 如果步数小于预热步数,应用预热学习率
            lr = self.warmup(steps)
            for param_group in self.optim.param_groups:
                param_group['lr'] = lr
        else:
            # 预热期后,开始应用余弦退火学习率
            self.scheduler.step()

        # 初始化日志
        logs = {}

        # 更新生成器
        for _ in range(self.grad_accum_every):
            # 获取下一个数据批次
            semantic_token_ids, acoustic_token_ids = next(self.dl_iter)

            # 计算损失和损失细分
            loss, loss_breakdown = self.model(
                acoustic_token_ids,
                cond_semantic_token_ids = semantic_token_ids,
                only_train_generator = self.only_train_generator,
                only_train_critic = self.only_train_critic
            )

            generator_loss, critic_loss = loss_breakdown
            generator_loss = 0. if generator_loss is None else generator_loss
            critic_loss = 0. if critic_loss is None else critic_loss
            
            # 反向传播
            self.accelerator.backward(loss / self.grad_accum_every)

            # 累积日志
            accum_log(logs, {'loss': loss.item() / self.grad_accum_every, 'generator_loss': generator_loss / self.grad_accum_every, 'critic_loss': critic_loss / self.grad_accum_every})

        # 如果存在最大梯度范数,则进行梯度裁剪
        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        # 更新优化器
        self.optim.step()
        self.optim.zero_grad()

        # 记录日志
        self.print(f"{steps}: loss: {logs['loss']:0.3f}, generator loss: {logs['generator_loss']:0.3f}, critic loss: {logs['critic_loss']:0.3f}")
        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # 定期采样结果
        self.accelerator.wait_for_everyone()

        if self.is_main and not (steps % self.save_results_every):
            # 获取验证数据批次
            semantic_token_ids, acoustic_token_ids = next(self.valid_dl_iter)

            with torch.inference_mode():
                self.model.eval()
                # 计算验证损失和损失细分
                valid_loss, valid_loss_breakdown = self.model(acoustic_token_ids, cond_semantic_token_ids = semantic_token_ids)
                
                valid_generator_loss, valid_critic_loss = valid_loss_breakdown
                valid_generator_loss = 0. if valid_generator_loss is None else valid_generator_loss
                valid_critic_loss = 0. if valid_critic_loss is None else valid_critic_loss

            # 记录验证日志
            self.print(f'{steps}: valid loss {valid_loss:0.3f}, valid generator loss {valid_generator_loss:0.3f}, valid critic loss {valid_critic_loss:0.3f}')
            self.accelerator.log({"valid_loss": valid_loss, "valid_generator_loss": valid_generator_loss, "valid_critic_loss": valid_critic_loss}, step=steps)

        # 定期保存模型
        if self.is_main and not (steps % self.save_model_every):
            model_path = str(self.results_folder / f'soundstorm.{steps}.pt')
            self.save(model_path)

            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        # 更新步数并返回日志
        self.steps += 1
        return logs

    # 训练函数
    def train(self, log_fn = noop):
        # 循环直到达到训练步数上限
        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        self.print('training complete')
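A minimal usage sketch for the trainer above. The names below are placeholders: model is assumed to be an already-configured SoundStorm instance, and token_pair_dataset is assumed to yield (semantic_token_ids, acoustic_token_ids) pairs, matching what train_step unpacks.

# hypothetical usage sketch -- `model` and `token_pair_dataset` are assumptions, not part of this file
trainer = SoundStormTrainer(
    model = model,                      # a configured SoundStorm instance
    dataset = token_pair_dataset,       # yields (semantic_token_ids, acoustic_token_ids)
    num_train_steps = 10_000,
    num_warmup_steps = 500,
    batch_size = 4,
    grad_accum_every = 2
)

trainer.train()

# checkpoints are written as ./results/soundstorm.{steps}.pt and can be restored with
# trainer.load('./results/soundstorm.1000.pt')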

.\lucidrains\soundstorm-pytorch\soundstorm_pytorch\__init__.py

# 从soundstorm_pytorch包中导入SoundStorm、SoundStream、ConformerWrapper和Conformer类
from soundstorm_pytorch.soundstorm import (
    SoundStorm,
    SoundStream,
    ConformerWrapper,
    Conformer
)
# 从soundstorm_pytorch包中导入SoundStormTrainer类
from soundstorm_pytorch.trainer import (
    SoundStormTrainer
)

Spear-TTS - Pytorch

Implementation of Spear-TTS - multi-speaker text-to-speech attention network, in Pytorch

The text-to-semantic module built here will be used for SoundStorm for conditioning.

Appreciation

  • Stability for their generous sponsorships to work on and open source cutting edge artificial intelligence research

  • Lucas Newman for completing the backtranslation portion, as well as beam search decoding!

  • Lucas Newman for completing the final text to semantic transformer training code!

Install

$ pip install spear-tts-pytorch

Usage

import torch

from audiolm_pytorch import HubertWithKmeans

from spear_tts_pytorch import (
    TextToSemantic,
    SemanticToTextDatasetGenerator,
    GeneratedAudioTextDataset,
    MockDataset
)

wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert_base_ls960.pt',
    kmeans_path = './hubert_base_ls960_L9_km500.bin'
)

model = TextToSemantic(
    wav2vec = wav2vec,
    dim = 512,
    num_text_token_ids = 256,
    heads = 8,
    target_kv_heads = 2, # grouped query attention, for memory efficient decoding
    source_depth = 1,
    target_depth = 1
)

ds = MockDataset(10)

dataset_generator = SemanticToTextDatasetGenerator(
    model = model,
    dataset = ds,
    folder = './output_folder'
)

dataset_generator(max_length = 2)

generated_dataset = GeneratedAudioTextDataset(
    folder = './output_folder'
)

assert len(generated_dataset) == 10

Todo

  • add eos logic + generate, and hook up end-to-end generation in soundstorm

  • add first pretraining speech-to-speech with the reconstruction of 60% deleted tokens

  • add dropouts for this project, as low-resource

  • add total flexibility of which layers of encoder / decoder to freeze during training

  • add step for training on small speech -> text corpus and generating pseudo-labelled dataset + finetuning (thanks to @lucasnewman)

  • add final step of finetuning on text -> speech + pseudolabelled dataset

  • figure out the best way to store and manage the pseudo-labelled generated dataset

  • batched beam search decoding

  • allow for using rotary positions in decoder + flash attention, give Tri another citation

  • integrate speculative decoding with some improvisation - done in same model using early exit strategy

  • add cached key / values for starter + single / grouped key values, make sure flash attention can support specialized causal mask before flash attention 2 is in pytorch core

  • polish the audio-text generation workflow

  • concatting the real audio-text dataset with the generated one -> or being able to convert real audio-text dataset to generated

Citations

@misc{kharitonov2023speak,
    title   = {Speak, Read and Prompt: High-Fidelity Text-to-Speech with Minimal Supervision}, 
    author  = {Eugene Kharitonov and Damien Vincent and Zalán Borsos and Raphaël Marinier and Sertan Girgin and Olivier Pietquin and Matt Sharifi and Marco Tagliasacchi and Neil Zeghidour},
    year    = {2023},
    eprint  = {2302.03540},
    archivePrefix = {arXiv},
    primaryClass = {cs.SD}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@misc{shi2023enhance,
    title   = {Enhance audio generation controllability through representation similarity regularization}, 
    author  = {Yangyang Shi and Gael Le Lan and Varun Nagaraja and Zhaoheng Ni and Xinhao Mei and Ernie Chang and Forrest Iandola and Yang Liu and Vikas Chandra},
    year    = {2023},
    eprint  = {2309.08773},
    archivePrefix = {arXiv},
    primaryClass = {cs.SD}
}
@article{Ainslie2023GQATG,
    title   = {GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints},
    author  = {Joshua Ainslie and James Lee-Thorp and Michiel de Jong and Yury Zemlyanskiy and Federico Lebr'on and Sumit K. Sanghai},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2305.13245},
    url     = {https://api.semanticscholar.org/CorpusID:258833177}
}
@inproceedings{Leviathan2022FastIF,
    title   = {Fast Inference from Transformers via Speculative Decoding},
    author  = {Yaniv Leviathan and Matan Kalman and Y. Matias},
    booktitle = {International Conference on Machine Learning},
    year    = {2022},
    url     = {https://api.semanticscholar.org/CorpusID:254096365}
}

.\lucidrains\spear-tts-pytorch\setup.py

# 导入设置工具和查找包工具
from setuptools import setup, find_packages

# 设置包的信息
setup(
  # 包名
  name = 'spear-tts-pytorch',
  # 查找包,不排除任何包
  packages = find_packages(exclude=[]),
  # 版本号
  version = '0.4.8',
  # 许可证
  license='MIT',
  # 描述
  description = 'Spear-TTS - Pytorch',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 长描述内容类型
  long_description_content_type = 'text/markdown',
  # 项目链接
  url = 'https://github.com/lucidrains/spear-tts-pytorch',
  # 关键词
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'text-to-speech'
  ],
  # 安装依赖
  install_requires=[
    'audiolm-pytorch>=1.2.8',
    'beartype',
    'einops>=0.6.1',
    'rotary-embedding-torch>=0.3.0',
    'torch>=1.6',
    'tqdm',
    'x-clip>=0.12.2'
  ],
  # 分类
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\spear-tts-pytorch\spear_tts_pytorch\attend.py

import torch
from torch import nn, einsum
import torch.nn.functional as F

from collections import namedtuple
from functools import wraps
from packaging import version

from einops import rearrange, repeat

# 定义一个命名元组 Config,包含三个布尔类型的参数
Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# 定义一个辅助函数,用于检查变量是否存在
def exists(val):
    return val is not None

# 定义一个装饰器函数,用于确保被装饰的函数只执行一次
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# 用装饰器 once 包装 print 函数,确保只打印一次
print_once = once(print)
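A tiny illustration of the once decorator: the wrapped function runs only on the first call.

print_once('using flash attention')   # prints
print_once('using flash attention')   # no-op on every subsequent call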

# 主要类 Attend
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        causal = False,
        flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        self.register_buffer("mask", None, persistent=False)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # 确定用于 cuda 和 cpu 的高效注意力配置

        self.cpu_config = Config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = Config(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = Config(False, True, True)

    # 获取掩码
    def get_mask(self, i, j, device):
        n = max(i, j)

        if exists(self.mask) and self.mask.shape[-1] >= n:
            mask = self.mask[:n, :n]
        else:
            mask = torch.ones((n, n), device = device, dtype = torch.bool).triu(1)
            self.register_buffer("mask", mask, persistent = False)

        return mask[-i:, :]

    # Flash Attention 函数
    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, causal, is_cuda, device = *q.shape, k.shape[-2], self.causal, q.is_cuda, q.device

        # 检查掩码是否存在并扩展到兼容的形状
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        # 检查是否有兼容的设备用于 Flash Attention
        config = self.cuda_config if is_cuda else self.cpu_config

        # 如果 q 和 k 的长度不同(缓存键/值),并且是因果的,手动构造因果注意力掩码作为浮点数,因为不支持(Flash Attention 2 最终会支持这一点)
        row_is_entirely_masked = None
        if causal and q_len != k_len:
            causal_mask = self.get_mask(q_len, k_len, device = device)

            if exists(mask):
                mask = mask & ~causal_mask
            else:
                mask = ~causal_mask

            row_is_entirely_masked = ~mask.any(dim = -1)
            mask[..., 0] = mask[..., 0] | row_is_entirely_masked

            causal = False

        # 使用 torch.backends.cuda.sdp_kernel 函数应用 PyTorch 2.0 Flash Attention
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0., 
                is_causal = causal
            )

        if exists(row_is_entirely_masked):
            out = out.masked_fill(row_is_entirely_masked[..., None], 0.)

        return out
    # forward pass, taking queries (q), keys (k), values (v) and an optional mask
    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # sequence length of the queries and their device
        n, device = q.shape[-2], q.device
        # number of query heads and key / value heads
        heads, kv_heads = q.shape[1], k.shape[1]

        # if there are fewer key / value heads than query heads, repeat k and v to match (grouped query attention)
        if kv_heads < heads:
            k, v = map(lambda t: repeat(t, 'b h ... -> b (g h) ...', g = heads // kv_heads), (k, v))

        # scaling factor
        scale = q.shape[-1] ** -0.5

        # if flash attention is enabled, dispatch to flash_attn
        if self.flash:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity

        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # key padding mask

        # if a mask is given, broadcast it over heads and fill masked positions with a large negative value
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask

        # if causal, build the causal mask and fill future positions with a large negative value
        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = self.get_mask(i, j, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention

        # softmax over the similarity matrix to get attention weights, followed by dropout
        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate values

        # weighted sum of the values using the attention weights
        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out
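A shape-level sketch of using Attend with grouped key / value heads (all tensor sizes below are arbitrary):

# hypothetical shapes: batch 2, 8 query heads, 2 kv heads, sequence length 16, head dim 64
attend = Attend(causal = True, flash = False)

q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 2, 16, 64)   # fewer kv heads; repeated internally to match the 8 query heads
v = torch.randn(2, 2, 16, 64)

out = attend(q, k, v)           # -> (2, 8, 16, 64)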

.\lucidrains\spear-tts-pytorch\spear_tts_pytorch\data.py

# 导入必要的模块
from pathlib import Path
import torch
from torch.utils.data import Dataset
from beartype import beartype

# 模拟数据集类
class MockDataset(Dataset):
    # 初始化方法,接受数据集长度参数
    def __init__(self, length: int):
        self.length = length

    # 返回数据集长度
    def __len__(self):
        return self.length

    # 获取数据集中指定索引的数据
    def __getitem__(self, ind):
        return torch.randn(1024)

# 生成音频文本数据集类
class GeneratedAudioTextDataset(Dataset):
    # 初始化方法,接受文件夹路径和分隔符ID参数
    @beartype
    def __init__(
        self,
        folder: str,
        delimiter_id: int = -1
    ):
        # 将文件夹路径转换为Path对象
        self.folder = Path(folder)
        # 断言文件夹存在且是一个目录
        assert self.folder.exists() and self.folder.is_dir()
        # 获取文件夹中所有以'.pt'结尾的文件路径列表
        self.paths = list(self.folder.glob('*.pt'))
        # 设置分隔符ID
        self.delimiter_id = delimiter_id

    # 返回数据集的长度
    def __len__(self):
        return len(self.paths)

    # 获取数据集中指定索引的数据
    def __getitem__(self, ind):
        # 获取指定索引的文件路径
        path = self.paths[ind]
        # 加载文件中的数据为张量
        tensor = torch.load(str(path))

        # 创建一个布尔张量,标记分隔符ID的位置
        delimiter_mask = tensor == self.delimiter_id
        # 断言至少存在一个分隔符,否则抛出异常
        assert delimiter_mask.any(), f'delimeter (<audio> <delimeter> <text>) not found'

        # 找到第一个分隔符的位置
        ind = (delimiter_mask.cumsum(dim=-1) == 0).sum().item()

        # 返回分隔符之前的部分和分隔符之后的部分作为数据
        return tensor[:ind], tensor[(ind + 1):]
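Concretely, for a saved row laid out as <audio ids> <delimiter> <text ids>, the indexing above splits at the first delimiter. A toy example with delimiter_id = -1:

# illustrative only
row = torch.tensor([11, 12, 13, -1, 7, 8])
delimiter_mask = row == -1
ind = (delimiter_mask.cumsum(dim = -1) == 0).sum().item()   # -> 3
audio_ids, text_ids = row[:ind], row[(ind + 1):]            # tensor([11, 12, 13]), tensor([7, 8])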

.\lucidrains\spear-tts-pytorch\spear_tts_pytorch\distributed.py

# 导入 torch 库
import torch
# 从 torch.autograd 模块中导入 Function 类
from torch.autograd import Function
# 导入 torch.distributed 模块
import torch.distributed as distributed
# 从 einops 库中导入 rearrange 函数

from einops import rearrange
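This excerpt uses exists and pad_dim_to, which are defined elsewhere in the upstream file but not reproduced here. A minimal sketch of those helpers (the pad_dim_to body below is an assumption, consistent with how it is used in all_gather_variable_dim):

import torch.nn.functional as F

def exists(val):
    return val is not None

def pad_dim_to(t, length, dim = 0):
    # right-pad tensor t with zeros along `dim` until it reaches the given length
    pad_length = length - t.shape[dim]
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))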

# distributed helpers

# 定义一个函数用于在所有进程中收集具有可变维度的张量
def all_gather_variable_dim(t, dim = 0, sizes = None):
    # 获取当前设备、进程的排名和总进程数
    device, rank, world_size = t.device, distributed.get_rank(), distributed.get_world_size()

    # 如果 sizes 不存在
    if not exists(sizes):
        # 创建一个张量表示 t 在指定维度上的大小
        size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
        # 创建一个列表,用于存储各个进程的大小信息
        sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)]
        # 在所有进程中收集各个进程的大小信息
        distributed.all_gather(sizes, size)
        # 将收集到的大小信息堆叠成一个张量
        sizes = torch.stack(sizes)

    # 获取所有进程中最大的大小
    max_size = sizes.amax().item()
    # 将 t 在指定维度上填充到最大大小
    padded_t = pad_dim_to(t, max_size, dim = dim)

    # 创建一个列表,用于存储各个进程收集到的张量
    gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for i in range(world_size)]
    # 在所有进程中收集填充后的张量
    distributed.all_gather(gathered_tensors, padded_t)

    # 将所有进程收集到的张量在指定维度上拼接
    gathered_tensor = torch.cat(gathered_tensors, dim = dim)
    # 创建一个序列张量
    seq = torch.arange(max_size, device = device)

    # 创建一个掩码,用于选择有效的数据
    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
    mask = rearrange(mask, 'i j -> (i j)')
    seq = torch.arange(mask.shape[-1], device = device)
    indices = seq[mask]

    # 根据掩码选择有效的数据
    gathered_tensor = gathered_tensor.index_select(dim, indices)

    return gathered_tensor, sizes

# 定义一个继承自 Function 的类 AllGather
class AllGather(Function):
    @staticmethod
    def forward(ctx, x, dim, sizes):
        # 检查是否处于分布式环境中且进程数大于 1
        is_dist = distributed.is_initialized() and distributed.get_world_size() > 1
        ctx.is_dist = is_dist

        # 如果不处于分布式环境中,直接返回输入张量和空值
        if not is_dist:
            return x, None

        # 在所有进程中收集具有可变维度的张量
        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
        ctx.batch_sizes = batch_sizes.tolist()
        ctx.dim = dim
        return x, batch_sizes

    @staticmethod
    def backward(ctx, grads, _):
        # 如果不处于分布式环境中,直接返回梯度和空值
        if not ctx.is_dist:
            return grads, None, None

        # 获取各个进程的大小信息和当前进程的排名
        batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
        # 根据各个进程的大小信息拆分梯度
        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
        return grads_by_rank[rank], None, None

# 将 AllGather 类应用为一个函数
all_gather = AllGather.apply
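A usage sketch: under an initialized torch.distributed process group, each rank contributes a tensor whose size along the gather dimension may differ, and the backward pass routes gradients back to each rank's own slice (variable names are illustrative):

# hypothetical: requires an initialized process group (e.g. via accelerate or torchrun)
local_feats = torch.randn(local_batch_size, 512, requires_grad = True)

all_feats, sizes = all_gather(local_feats, 0, None)   # gather along dim 0, per-rank sizes inferred
# all_feats: every rank's features concatenated along dim 0
# sizes: per-rank sizes, used in backward to split the gradient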

.\lucidrains\spear-tts-pytorch\spear_tts_pytorch\spear_tts_pytorch.py

# 导入数学库
import math
# 从路径库中导入路径类
from pathlib import Path
# 从 functools 库中导入 partial 函数
from functools import partial
# 从 random 库中导入 random 函数
from random import random

# 导入 torch 库
import torch
# 从 torch.nn.functional 中导入 F
import torch.nn.functional as F
# 从 torch.nn.utils.rnn 中导入 pad_sequence
from torch.nn.utils.rnn import pad_sequence
# 从 torch 中导入 Tensor, nn, einsum, IntTensor, LongTensor
from torch import Tensor, nn, einsum, IntTensor, LongTensor

# 从 torch.nn 中导入 Module, ModuleList
from torch.nn import Module, ModuleList

# 从 torch.utils.data 中导入 Dataset
from torch.utils.data import Dataset

# 从 einops 中导入 rearrange, repeat, pack, reduce
from einops import rearrange, repeat, pack, reduce
# 从 einops.layers.torch 中导入 Rearrange
from einops.layers.torch import Rearrange

# 从 audiolm_pytorch 中导入 FairseqVQWav2Vec, HubertWithKmeans
from audiolm_pytorch import FairseqVQWav2Vec, HubertWithKmeans
# 从 audiolm_pytorch.data 中导入 get_dataloader
from audiolm_pytorch.data import get_dataloader

# 从 rotary_embedding_torch 中导入 RotaryEmbedding
from rotary_embedding_torch import RotaryEmbedding

# 从 beartype 中导入 beartype
from beartype import beartype
# 从 beartype.door 中导入 is_bearable
from beartype.door import is_bearable
# 从 beartype.typing 中导入 Optional, Union, Callable, Literal, Tuple, List
from beartype.typing import Optional, Union, Callable, Literal, Tuple, List

# 从 x_clip.tokenizer 中导入 tokenizer
from x_clip.tokenizer import tokenizer

# 从 spear_tts_pytorch 中导入 Attend, all_gather
from spear_tts_pytorch.attend import Attend
from spear_tts_pytorch.distributed import all_gather

# 从 tqdm 中导入 tqdm
from tqdm import tqdm

# 定义 FloatTensor 类型为 Union 类型,包含 torch.FloatTensor 和 torch.cuda.FloatTensor
FloatTensor = Union[
    torch.FloatTensor,
    torch.cuda.FloatTensor
]

# 辅助函数

# 判断值是否存在
def exists(val):
    return val is not None

# 如果值存在则返回该值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 判断张量是否为空
def empty(t: Tensor):
    return t.numel() == 0

# 对张量进行 L2 归一化
def l2norm(t):
    return F.normalize(t, dim = -1)

# 设置 EOS 标识符的位置
def set_eos_id(t: Tensor, eos_id: int, pad_id: int):
    eos_indices = ((t == pad_id).cumsum(dim = -1) == 0).sum(dim = -1, keepdim = True).long()

    batch_range = torch.arange(t.shape[0], device = t.device, dtype = torch.long)
    batch_range = rearrange(batch_range, '... -> ... 1')

    t = F.pad(t, (0, 1), value = pad_id)
    t[batch_range, eos_indices] = eos_id
    return t

# 对批次中的唯一连续值进行填充
def batch_unique_consecutive(t, pad_value = 0.):
    unique_arr = [torch.unique_consecutive(el) for el in t.unbind(dim = 0)]
    return pad_sequence(unique_arr, batch_first = True, padding_value = pad_value)

# 在 EOS 之后进行掩码处理
def mask_after_eos(target, eos_id, pad_id):
    mask = (target == eos_id).cumsum(dim = -1) > 0
    mask = F.pad(mask, (1, -1), value = False)
    return target.masked_fill(mask, pad_id)
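A small worked example of the two EOS helpers above (token ids are made up; pad_id = 0, eos_id = 2):

seq = torch.tensor([[5, 6, 7, 0, 0]])            # right-padded sequence
set_eos_id(seq, eos_id = 2, pad_id = 0)          # -> tensor([[5, 6, 7, 2, 0, 0]]), eos placed at the first pad

target = torch.tensor([[5, 6, 2, 9, 9]])         # stray tokens after the eos
mask_after_eos(target, eos_id = 2, pad_id = 0)   # -> tensor([[5, 6, 2, 0, 0]]), everything after eos padded out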

# 安全除法
def safe_div(num, den, eps = 1e-10):
    return num / max(den, eps)

# 查找第一个为真的索引
def find_first_true_index(bool_tensor, dim = -1):
    return (bool_tensor.cumsum(dim = dim) == 0).sum(dim = dim)

# 冻结和解冻辅助函数

# 设置模块参数是否需要梯度
def set_requires_grad_(module: Module, requires_grad: bool):
    for p in module.parameters():
        p.requires_grad = requires_grad

# 冻结模块参数
def freeze(module: Module):
    set_requires_grad_(module, False)

# 解冻模块参数
def unfreeze(module: Module):
    set_requires_grad_(module, True)

# 采样辅助函数

# 评估装饰器
def eval_decorator(fn):
    def inner(self, *args, **kwargs):
        was_training = self.training
        self.eval()
        out = fn(self, *args, **kwargs)
        self.train(was_training)
        return out
    return inner

# 对数函数
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# 生成 Gumbel 噪声
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# Gumbel 采样
def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)

# Top-p 采样
def top_p(logits, thres = 0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = F.pad(cum_probs > thres, (1, -1), value = 0)
    sorted_logits[sorted_indices_to_remove] = float('-inf')
    sorted_logits = sorted_logits.scatter(-1, sorted_indices, sorted_logits)
    return sorted_logits

# Top-k 采样
def top_k(logits, thres = 0.1, k = None):
    if not exists(k):
        k = math.ceil(thres * logits.shape[-1])
    val, ind = torch.topk(logits, k, dim = -1)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(-1, ind, val)
    return probs
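These sampling helpers are used together at decode time: logits are first filtered (top-k or top-p), then gumbel_sample draws a token. A toy example:

logits = torch.randn(1, 500)                                    # fake vocabulary of 500 tokens
filtered = top_k(logits, thres = 0.1)                           # keep the top 10% of logits, rest set to -inf
token = gumbel_sample(filtered, temperature = 0.9, dim = -1)    # sampled token id, shape (1,)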

# 残差包装器

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

# RMSNorm

class RMSNorm(nn.Module):
    # 初始化函数,接受一个维度参数
    def __init__(self, dim):
        # 调用父类的初始化函数
        super().__init__()
        # 计算缩放因子为维度的平方根
        self.scale = dim ** 0.5
        # 创建一个可学习的参数 gamma,维度为输入维度
        self.gamma = nn.Parameter(torch.ones(dim))
    
    # 前向传播函数,接受输入 x
    def forward(self, x):
        # 对输入 x 进行归一化操作,dim=-1 表示对最后一个维度进行归一化
        return F.normalize(x, dim=-1) * self.scale * self.gamma
# 定义 GEGLU 类,用于实现 GEGLU 激活函数
class GEGLU(nn.Module):
    # GEGLU 类的前向传播函数
    def forward(self, x):
        # 将输入张量 x 按照最后一个维度分成两部分
        x, gate = x.chunk(2, dim = -1)
        # 对 gate 部分应用 GELU 激活函数,并与 x 相乘
        return F.gelu(gate) * x

# 定义 FeedForward 函数,用于创建前馈神经网络层
def FeedForward(dim, mult = 4, dropout = 0.):
    # 计算内部维度
    dim_inner = int(dim * mult * 2 / 3)
    # 返回一个包含多个层的神经网络模型
    return nn.Sequential(
        RMSNorm(dim),  # 使用 RMSNorm 进行归一化
        nn.Linear(dim, dim_inner * 2),  # 线性变换层
        GEGLU(),  # 使用 GEGLU 激活函数
        nn.Dropout(dropout),  # Dropout 层
        nn.Linear(dim_inner, dim)  # 线性变换层
    )
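The 2/3 factor keeps the gated (GEGLU) feedforward roughly parameter-matched to a standard 4x MLP. A quick shape check:

ff = FeedForward(dim = 512, mult = 4)
x = torch.randn(2, 16, 512)
assert ff(x).shape == (2, 16, 512)   # shape-preserving; hidden width is int(512 * 4 * 2 / 3) = 1365 per gated half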

# 定义 Attention 类,用于实现注意力机制
class Attention(nn.Module):
    # Attention 类的初始化函数
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        kv_heads = None,
        causal = False,
        dim_context = None,
        dropout = 0.,
        rotary_emb: Optional[RotaryEmbedding] = None,
        flash = False,
        add_null_kv = False
    ):
        super().__init__()
        dim_context = default(dim_context, dim)

        self.heads = heads
        self.kv_heads = default(kv_heads, heads)
        assert (self.heads % self.kv_heads) == 0, 'number of key value heads must be divisible by query heads'

        self.scale = dim_head ** -0.5
        dim_query_inner = heads * dim_head
        dim_kv_inner = self.kv_heads * dim_head

        self.rotary_emb = rotary_emb

        self.attend = Attend(
            causal = causal,
            flash = flash,
            dropout = dropout
        )

        self.norm = RMSNorm(dim)
        self.attn_dropout = nn.Dropout(dropout)

        # 将输入转换为查询向量
        self.to_q = nn.Sequential(
            nn.Linear(dim, dim_query_inner, bias = False),
            Rearrange('b n (h d) -> b h n d', h = self.heads)
        )

        # 将上下文转换为键值对
        self.to_kv = nn.Sequential(
            nn.Linear(dim_context, dim_kv_inner * 2, bias = False),
            Rearrange('b n (kv h d) -> kv b h n d', kv = 2, h = self.kv_heads)
        )

        # 将输出转换为指定维度
        self.to_out = nn.Linear(dim_query_inner, dim, bias = False)

        self.add_null_kv = add_null_kv
        if add_null_kv:
            self.null_kv = nn.Parameter(torch.randn(2, self.kv_heads, 1, dim_head))

    # Attention 类的前向传播函数
    def forward(
        self,
        x,
        context = None,
        mask = None,
        cache = None,
        return_cached_key_values = False
    ):
        has_context = exists(context)
        b = x.shape[0]

        x = self.norm(x)

        context = default(context, x)

        q, k, v = (self.to_q(x), *self.to_kv(context))

        if exists(cache):
            ck, cv = cache.unbind(dim = 1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        new_cache = torch.stack((k, v), dim = 1)

        if exists(self.rotary_emb):
            assert not has_context
            q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)

        if self.add_null_kv:
            assert not exists(self.rotary_emb)
            nk, nv = map(lambda t: repeat(t, 'h 1 d -> b h 1 d', b = b), self.null_kv)
            k = torch.cat((nk, k), dim = -2)
            v = torch.cat((nv, v), dim = -2)

            if exists(mask):
                mask = F.pad(mask, (1, 0), value = True)

        out = self.attend(q, k, v, mask = mask)

        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)

        if not return_cached_key_values:
            return out

        return out, new_cache

# 定义 Transformer 类,用于实现 Transformer 模型
class Transformer(nn.Module):
    # Transformer 类的初始化函数
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        kv_heads = None,
        causal = False,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        cross_attend = False,
        attn_flash = False
    ):
        # 调用父类的构造函数
        super().__init__()

        # 创建旋转嵌入对象
        rotary_emb = RotaryEmbedding(dim_head)

        # 初始化神经网络层列表
        self.layers = nn.ModuleList([])

        # 循环创建指定数量的层
        for _ in range(depth):
            # 每一层包含注意力机制、交叉注意力机制(可选)、前馈神经网络
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, causal = causal, dim_head = dim_head, heads = heads, kv_heads = kv_heads, dropout = attn_dropout, rotary_emb = rotary_emb, flash = attn_flash),
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = attn_flash, add_null_kv = True) if cross_attend else None,
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            ]))

        # 创建最终的归一化层
        self.final_norm = RMSNorm(dim)

    def forward(
        self,
        x,
        mask = None,
        context = None,
        context_mask = None,
        cache = None,
        return_cache = False,
        return_hiddens = False,
        early_exit_at_layer = None,
        seq_start_pos = None
    ):
        # 检查是否存在上下文信息
        has_context = exists(context)

        # 如果存在序列起始位置信息,则生成对应的掩码
        if exists(seq_start_pos):
            assert not exists(mask)
            seq_len = x.shape[-2]
            seq_arange = torch.arange(seq_len, device = x.device, dtype = torch.long)
            mask = seq_arange >= seq_start_pos[..., None]

        # 如果存在缓存信息,则截取输入序列
        if exists(cache):
            cached_length, seq_len = cache.shape[-2], x.shape[-2]
            assert seq_len > cached_length
            x = x[:, cached_length:]

        # 初始化新的缓存列表和隐藏层列表
        new_cache = []
        hiddens = []

        # 如果存在缓存信息,则创建迭代器
        if exists(cache):
            iter_cache = iter(cache.unbind(dim = 1))
        else:
            iter_cache = iter([])

        # 遍历每一层
        for ind, (self_attn, maybe_cross_attn, ff) in enumerate(self.layers):
            layer = ind + 1

            # 计算自注意力机制输出,并更新缓存
            residual = x
            attn_out, key_values = self_attn(x, mask = mask, cache = next(iter_cache, None), return_cached_key_values = True)
            x = attn_out + residual
            new_cache.append(key_values)

            # if a cross attention layer exists, apply it
            if exists(maybe_cross_attn):
                assert has_context
                x = maybe_cross_attn(x, context = context, mask = context_mask) + x

            # 应用前馈神经网络
            x = ff(x) + x
            hiddens.append(x)

            # 如果设置了提前退出层,则在该层结束循环
            if exists(early_exit_at_layer) and early_exit_at_layer == layer:
                break

        # 如果设置了提前退出层,则返回结果或缓存
        if exists(early_exit_at_layer):
            if return_cache:
                return x, torch.stack(new_cache, dim = 1)
            return x

        # 对最终输出进行归一化
        out = self.final_norm(x)

        # 如果需要返回隐藏层信息,则返回结果和隐藏层列表
        if return_hiddens:
            assert not return_cache
            return out, torch.stack(hiddens)

        # 如果不需要返回缓存信息,则返回结果
        if not return_cache:
            return out

        # 返回结果和缓存信息
        return out, torch.stack(new_cache, dim = 1)
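A sketch of incremental decoding with the returned key / value cache: the first call processes the whole prefix, later calls pass the cache together with the extended sequence and only the new positions go through attention (tensor contents are arbitrary):

decoder = Transformer(dim = 512, depth = 2, causal = True)

prefix = torch.randn(1, 10, 512)
out, cache = decoder(prefix, return_cache = True)                         # full pass over 10 positions

extended = torch.cat((prefix, torch.randn(1, 1, 512)), dim = 1)
out_new, cache = decoder(extended, cache = cache, return_cache = True)    # only the 1 new position is computed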
# 定义 SpeechOrTextLiteral 类型,可以是'speech'或'text'中的一个
SpeechOrTextLiteral = Union[
    Literal['speech'],
    Literal['text']
]

# 定义 SemanticModelType 类型,可以是 FairseqVQWav2Vec 或 HubertWithKmeans 中的一个
SemanticModelType = Union[
    FairseqVQWav2Vec,
    HubertWithKmeans
]

# 定义 TextToSemantic 类,继承自 Module 类
class TextToSemantic(Module):
    # 初始化函数
    @beartype
    def __init__(
        self,
        dim,
        *,
        source_depth,
        target_depth,
        num_text_token_ids = None,
        tokenizer_encode: Optional[Callable] = None,
        use_openai_tokenizer = False,
        wav2vec: Optional[SemanticModelType] = None,
        num_semantic_token_ids = None,
        dim_head = 64,
        heads = 8,
        target_kv_heads = None,  # for grouped query attention, saving memory on decoder inference
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        semantic_pad_id = -1,
        text_pad_id = 0,
        autoset_semantic_eos_id = True,
        autoset_text_eos_id = True,
        attn_flash = False,
        cond_drop_prob = 0.,
        target_early_exit_layer = None,
        detach_early_exit_embed = False,
        align_reg_loss_weight = 0.1,
        align_reg_use_logsumexp_pool = True,
        align_reg_logsumexp_pool_temp = 0.1
    ):
        ...  # (initializer body omitted in this excerpt)

    @property
    def device(self):
        # 返回第一个参数的设备
        return next(self.parameters()).device

    # 加载函数
    def load(self, path, strict = True):
        # 返回 pkg,以便如果此函数从 Trainer 函数调用中调用,则 Trainer 也可以访问从检查点加载的包
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path), map_location = 'cpu')
        self.load_state_dict(pkg['model'], strict = strict)
        return pkg

    # 一组冻结/解冻工具
    # 然后依赖 get_optimizer 来过滤不需要梯度的参数,使其暴露给优化器

    # 解冻所有参数
    def unfreeze_all(self):
        unfreeze(self)

    # 冻结编码器
    def freeze_encoder(self):
        freeze(self.source_transformer)

    # 冻结编码器到某一层
    def freeze_encoder_below_layer(self, layer: int):
        """
        用于在伪标记数据集上对文本到语义的最终训练
        他们将编码器部分冻结到某一层
        """
        unfreeze(self.source_transformer)

        for ind, module in enumerate(self.source_transformer.layers):
            current_layer = ind + 1

            if current_layer <= layer:
                freeze(module)

    # 冻结解码器
    def freeze_decoder(self):
        freeze(self.target_transformer)

    # 冻结语音嵌入
    def freeze_speech_emb(self):
        freeze(self.token_emb['speech'])
        self.start_token['speech'].requires_grad = False

    # 冻结文本嵌入
    def freeze_text_emb(self):
        freeze(self.token_emb['text'])
        self.start_token['text'].requires_grad = False

    # 采样函数

    @torch.no_grad()
    @eval_decorator
    @beartype
    def generate(
        self,
        source: Union[List[str], Tensor],
        *,
        source_type: SpeechOrTextLiteral,
        target_type: SpeechOrTextLiteral,
        temperature = 1.,
        filter_logits_fn = top_k,
        filter_fn_kwargs: dict = dict(),
        source_mask: Optional[Tensor] = None,
        max_length = 2048,
        beam_search_decode = False,
        spec_decode = False,
        spec_decode_gamma = 5,
        spec_decode_lenience = 1.,
        beam_size = 4,
        return_source = False,
        return_target_mask = False,
        cond_scale = 1.
    ):
        ...  # (generation body omitted in this excerpt)

    @beartype
    def forward(
        self,
        source: Union[List[str], Tensor],
        target: Union[List[str], Tensor],
        *,
        source_type: SpeechOrTextLiteral,
        target_type: SpeechOrTextLiteral,
        source_mask: Optional[Tensor] = None,
        target_mask: Optional[Tensor] = None,
        return_loss = False,
        return_logits = False,
        cond_drop_prob: Optional[float] = None,
        should_sim_regularize = True,
        return_early_exit_loss = False
    ):
        ...  # (forward body omitted in this excerpt)

# pretraining modules

# 获取掩码子集概率函数
def get_mask_subset_prob(mask, prob, min_mask = 0):
    batch, seq, device = *mask.shape, mask.device
    # 计算每个位置需要mask的数量,根据mask的和与概率相乘,并限制最小值为min_mask
    num_to_mask = (mask.sum(dim=-1, keepdim=True) * prob).clamp(min=min_mask)
    # 生成一个指定大小的随机张量,用于存储logits
    logits = torch.rand((batch, seq), device=device)
    # 根据mask将logits中的非mask位置填充为-1
    logits = logits.masked_fill(~mask, -1)

    # 对logits进行排序,返回排序后的索引
    randperm = logits.argsort(dim=-1).float()

    # 计算每个样本中需要填充的数量
    num_padding = (~mask).sum(dim=-1, keepdim=True)
    # 将randperm中的索引减去需要填充的数量,以保证填充的位置不会被选中
    randperm -= num_padding

    # 生成一个布尔张量,表示哪些位置需要被选中
    subset_mask = randperm < num_to_mask
    # 将subset_mask中非mask位置填充为False
    subset_mask.masked_fill_(~mask, False)
    # 返回subset_mask
    return subset_mask
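A quick illustration of get_mask_subset_prob: given a boolean validity mask, it selects a random subset covering prob of the valid positions in each row:

valid = torch.ones(2, 10, dtype = torch.bool)        # all 10 positions valid in both rows
subset = get_mask_subset_prob(valid, prob = 0.6)
subset.sum(dim = -1)                                 # -> tensor([6, 6]): 60% of the valid positions selected per row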
# 定义一个包装器类,用于语音到语义预训练任务
class SpeechSpeechPretrainWrapper(nn.Module):
    # 初始化方法
    @beartype
    def __init__(
        self,
        model: TextToSemantic,  # 语义模型
        wav2vec: Optional[SemanticModelType] = None,  # 可选的语音模型
        deletion_prob: float = 0.6,  # 删除概率
        reconstruct_seq: bool = False,  # 是否重构序列
        mask_id = None  # 掩码 ID
    ):
        super().__init__()

        self.model = model  # 保存语义模型
        self.wav2vec = default(wav2vec, model.wav2vec)  # 保存语音模型,默认为语义模型的 wav2vec

        self.deletion_prob = deletion_prob  # 保存删除概率
        self.reconstruct_seq = reconstruct_seq  # 是否重构序列
        self.mask_id = mask_id  # 掩码 ID

    # 前向传播方法
    def forward(
        self,
        x,  # 输入数据
        return_early_exit_loss = False  # 是否返回早期退出损失
    ):
        is_raw_audio = x.dtype == torch.float  # 判断输入数据是否为原始音频

        if is_raw_audio:
            assert exists(self.wav2vec)  # 断言语音模型存在
            
            with torch.no_grad():
                self.wav2vec.eval()  # 设置语音模型为评估模式
                x = self.wav2vec(x, flatten = False)  # 对输入数据进行处理

        batch = x.shape[0]  # 获取批次大小

        mask = torch.ones_like(x, dtype = torch.bool, device = self.model.device)  # 创建与输入数据相同形状的掩码

        if exists(self.mask_id):
            assert self.reconstruct_seq, 'reconstruct_seq must be true if mask id is provided'  # 如果提供了掩码 ID,则重构序列必须为真
            
            mask = mask.masked_fill(x == self.model.semantic_pad_id, False)  # 根据语义填充 ID 进行掩码
            delete_mask = get_mask_subset_prob(mask, self.deletion_prob)  # 获取删除掩码

            source = x.masked_fill(delete_mask, self.mask_id)  # 根据删除掩码和掩码 ID 生成源数据
        else:
            delete_mask = get_mask_subset_prob(mask, self.deletion_prob)  # 获取删除掩码

            source = rearrange(x[~delete_mask], '(b n) -> b n', b = batch)  # 重新排列数据

        if self.reconstruct_seq:
            target = x  # 目标数据为输入数据
        else:
            target = rearrange(x[delete_mask], '(b n) -> b n', b = batch)  # 目标数据为删除后的数据

        loss, logits = self.model(
            source, target,  # 输入源数据和目标数据
            source_type = 'speech',  # 源数据类型为语音
            target_type = 'speech',  # 目标数据类型为语音
            return_loss = True,  # 返回损失
            return_logits = True,  # 返回 logits
            return_early_exit_loss = return_early_exit_loss,  # 是否返回早期退出损失
        )

        return loss, logits

# 包装器类,用于反向翻译任务
class SemanticToTextWrapper(nn.Module):
    # 初始化方法
    @beartype
    def __init__(
        self,
        model: TextToSemantic  # 语义模型
    ):
        super().__init__()

        self.model = model  # 保存语义模型

    # 前向传播方法
    def forward(
        self,
        semantic_token_ids,  # 语义标记 ID
        grapheme_token_ids,  # 字形标记 ID
    ):
        source = semantic_token_ids  # 源数据为语义标记 ID
        target = grapheme_token_ids  # 目标数据为字形标记 ID

        loss, logits = self.model(
            source, target,  # 输入源数据和目标数据
            source_type = 'speech',  # 源数据类型为语音
            target_type = 'text',  # 目标数据类型为文本
            return_loss = True,  # 返回损失
            return_logits = True  # 返回 logits
        )

        return loss, logits

# 包装器类,用于文本到语义任务
class TextToSemanticWrapper(nn.Module):
    # 初始化方法
    @beartype
    def __init__(
        self,
        model: TextToSemantic  # 语义模型
    ):
        super().__init__()

        self.model = model  # 保存语义模型

    # 前向传播方法
    def forward(
        self,
        grapheme_token_ids,  # 字形标记 ID
        semantic_token_ids,  # 语义标记 ID
        return_early_exit_loss = True  # 是否返回早期退出损失
    ):
        source = grapheme_token_ids  # 源数据为字形标记 ID
        target = semantic_token_ids  # 目标数据为语义标记 ID

        loss, logits = self.model(
            source, target,  # 输入源数据和目标数据
            source_type = 'text',  # 源数据类型为文本
            target_type = 'speech',  # 目标数据类型为语音
            return_loss = True,  # 返回损失
            return_logits = True,  # 返回 logits
            return_early_exit_loss = return_early_exit_loss  # 是否返回早期退出损失
        )

        return loss, logits

# 包装器类,用于生成伪标记的音频到文本数据集
class SemanticToTextDatasetGenerator(nn.Module):
    # 初始化方法
    @beartype
    def __init__(
        self,
        model,  # 模型
        *,
        dataset: Dataset,  # 数据集
        folder = './generated-audio-text-pairs',  # 文件夹路径
        batch_size = 4,  # 批次大小
        delimiter_id: int = -1,  # 分隔符 ID
        audio_pad_id = None,  # 音频填充 ID
        text_pad_id = 0            # text padding ID
    ):
        # call the parent class constructor
        super().__init__()
        # 设置模型
        self.model = model

        # 设置数据集
        self.dataset = dataset
        # 根据数据集和批量大小创建数据加载器
        self.dl = get_dataloader(dataset, batch_size=batch_size)
        # 设置分隔符的 ID
        self.delimiter_id = delimiter_id

        # 设置音频填充符的 ID
        self.audio_pad_id = audio_pad_id
        # 设置文本填充符的 ID
        self.text_pad_id = text_pad_id

        # 将文件夹路径转换为 Path 对象,并创建文件夹(如果不存在)
        self.folder = Path(folder)
        self.folder.mkdir(exist_ok=True, parents=True)

    # 前向传播函数,生成文本数据
    def forward(
        self,
        max_length=2048,
        beam_search_decode=True,
        **generate_kwargs
    ):
        # 创建包含分隔符 ID 的张量
        delimiter = torch.tensor([self.delimiter_id], device=self.model.device)

        # 计数器,用于生成文件名
        counter = 0

        # 遍历数据加载器中的音频数据
        for audio, in self.dl:
            # 生成音频语义 ID 和文本 ID
            audio_semantic_ids, text_ids = self.model.generate(
                source=audio,
                source_type='speech',
                target_type='text',
                return_source=True,
                max_length=max_length,
                beam_search_decode=beam_search_decode,
                **generate_kwargs
            )

            # 遍历音频语义 ID 和文本 ID
            for audio_semantic_id, text_id in zip(audio_semantic_ids, text_ids):

                # 如果音频填充符存在,则创建音频填充掩码并去除填充符
                if exists(self.audio_pad_id):
                    audio_pad_mask = audio_semantic_id == self.audio_pad_id
                    audio_semantic_id = audio_semantic_id[~audio_pad_mask]

                # 如果文本填充符存在,则创建文本填充掩码并去除填充符
                if exists(self.text_pad_id):
                    text_pad_mask = text_id == self.text_pad_id
                    text_id = text_id[~text_pad_mask]

                # 将音频语义 ID、分隔符和文本 ID 打包成一行数据
                row, _ = pack([audio_semantic_id, delimiter, text_id], '*')
                # 构建保存路径
                path = str(self.folder / f'{counter}.pt')

                # 保存数据到指定路径
                torch.save(row, path)
                # 更新计数器
                counter += 1

.\lucidrains\spear-tts-pytorch\spear_tts_pytorch\trainer.py

# 导入必要的库
import re
from pathlib import Path
from shutil import rmtree

# 导入 beartype 库中的函数和类型
from beartype import beartype
from beartype.door import is_bearable
from beartype.typing import Union, Optional, Tuple

# 导入 PyTorch 库
import torch
from torch import nn, LongTensor, IntTensor
from torch.utils.data import ConcatDataset
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, random_split

# 导入 audiolm_pytorch 库中的模型和函数
from audiolm_pytorch import FairseqVQWav2Vec, HubertWithKmeans
from audiolm_pytorch.data import get_dataloader
from audiolm_pytorch.optimizer import get_optimizer

# 导入 spear_tts_pytorch 库中的模型和数据集
from spear_tts_pytorch.spear_tts_pytorch import SpeechSpeechPretrainWrapper, TextToSemantic, SemanticToTextWrapper, TextToSemanticWrapper
from spear_tts_pytorch.data import GeneratedAudioTextDataset

# 导入 accelerate 库中的加速器和分布式类型
from accelerate import Accelerator, DistributedType

# 定义类型别名
IndicesTensor = Union[LongTensor, IntTensor]

# 确保只有一个 Trainer 实例化
ONE_TRAINER_INSTANTIATED = False

def check_one_trainer():
    global ONE_TRAINER_INSTANTIATED
    assert not ONE_TRAINER_INSTANTIATED, 'only one Trainer can be instantiated at a time for training'
    ONE_TRAINER_INSTANTIATED = True

# 辅助函数

# 检查值是否存在
def exists(val):
    return val is not None

# 空操作函数
def noop(*args, **kwargs):
    pass

# 无限循环生成数据集
def cycle(dl):
    while True:
        for data in dl:
            yield data

# 将输入转换为元组
def cast_tuple(t):
    return t if isinstance(t, (tuple, list)) else (t,)

# 询问用户是或否
def yes_or_no(question):
    answer = input(f'{question} (y/n) ')
    return answer.lower() in ('yes', 'y')

# 累积日志信息
def accum_log(log, new_logs):
    for key, new_value in new_logs.items():
        old_value = log.get(key, 0.)
        log[key] = old_value + new_value
    return log

# 从检查点文件名中获取训练步数
def checkpoint_num_steps(checkpoint_path):
    """Returns the number of steps trained from a checkpoint based on the filename.

    Filename format assumed to be something like "/path/to/speech.speech.20000.pt" which is
    for 20k train steps. Returns 20000 in that case.
    """
    results = re.findall(r'\d+', str(checkpoint_path))

    if len(results) == 0:
        return 0

    return int(results[-1])

# 定义 SpeechSpeechPretrainer 类
class SpeechSpeechPretrainer(nn.Module):
    @beartype
    def __init__(
        self,
        model: TextToSemantic,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],
        *,
        num_train_steps,
        num_warmup_steps,
        batch_size,
        dataset: Optional[Dataset] = None,
        deletion_prob: float = 0.6,
        reconstruct_seq: bool = False,
        mask_id = None,
        lr = 3e-4,
        initial_lr = 1e-5,
        grad_accum_every = 1,
        wd = 0.,
        max_grad_norm = 0.5,
        valid_frac = 0.05,
        random_split_seed = 42,
        log_every = 10,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        accelerate_kwargs: dict = dict(),
        split_batches = False,
        drop_last = False,
        force_clear_prev_results = None
        ):
        # 调用父类的构造函数
        super().__init__()
        # 检查是否只有一个训练器
        check_one_trainer()

        # 初始化加速器
        self.accelerator = Accelerator(
            split_batches = split_batches,
            **accelerate_kwargs
        )

        # 设置模型和wav2vec
        self.model = model
        self.wav2vec = wav2vec

        # 初始化训练包装器
        self.train_wrapper = SpeechSpeechPretrainWrapper(
            model = model,
            wav2vec = wav2vec,
            deletion_prob = deletion_prob,
            reconstruct_seq = reconstruct_seq,
            mask_id = mask_id
        )

        # 注册缓冲区
        self.register_buffer('steps', torch.Tensor([0]))

        # 设置训练步数、热身步数、批量大小、梯度累积频率
        self.num_train_steps = num_train_steps
        self.num_warmup_steps = num_warmup_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        # 优化器
        self.lr = lr
        self.initial_lr = initial_lr
        self.optim = get_optimizer(model.parameters(), lr = lr, wd = wd)
        self.scheduler = CosineAnnealingLR(self.optim, T_max = num_train_steps)

        # 最大梯度范数
        self.max_grad_norm = max_grad_norm

        # 创建数据集
        self.ds = dataset

        # 划分验证集
        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        # 断言确保数据集和验证集的样本数足够
        assert len(self.ds) >= batch_size, 'dataset must have sufficient samples for training'
        assert len(self.valid_ds) >= batch_size, f'validation dataset must have sufficient number of samples (currently {len(self.valid_ds)}) for training'

        # 数据加载器
        self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)
        self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)

        # 使用加速器准备训练所需的对象
        (
            self.train_wrapper,
            self.optim,
            self.scheduler,
            self.dl,
            self.valid_dl
        ) = self.accelerator.prepare(
            self.train_wrapper,
            self.optim,
            self.scheduler,
            self.dl,
            self.valid_dl
        )

        # 数据加载器迭代器
        self.dl_iter = cycle(self.dl)
        self.valid_dl_iter = cycle(self.valid_dl)

        # 设置日志、保存模型和保存结果的频率
        self.log_every = log_every
        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        # results folder
        self.results_folder = Path(results_folder)

        # if on the main process, clear the results folder when forced, or when prior results exist and the user confirms
        if self.is_main and force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
            rmtree(str(self.results_folder))

        # create the results folder
        self.results_folder.mkdir(parents = True, exist_ok = True)
        
        # hyperparameter tracking
        hps = {"num_train_steps": num_train_steps, "num_warmup_steps": num_warmup_steps, "learning_rate": lr, "initial_learning_rate": initial_lr}
        self.accelerator.init_trackers("speechspeech", config=hps)

    # save the model
    def save(self, path):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.model),
            optim = self.optim.state_dict(),
            scheduler = self.scheduler.state_dict()
        )
        torch.save(pkg, path)
    # load model weights and optimizer state
    def load(self, path):
        # get the unwrapped model
        model = self.accelerator.unwrap_model(self.model)
        # load the model weights
        pkg = model.load(path)

        # load the optimizer state
        self.optim.load_state_dict(pkg['optim'])
        # load the scheduler state
        self.scheduler.load_state_dict(pkg['scheduler'])

        # + 1 to start from the next step and avoid overwriting the last checkpoint
        self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)
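
        # e.g. loading './results/speech.speech.20000.pt' resumes at step 20001, since
        # checkpoint_num_steps parses the trailing step count out of the checkpoint filename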

    # print via the accelerator
    def print(self, msg):
        self.accelerator.print(msg)

    # generate
    def generate(self, *args, **kwargs):
        return self.train_wrapper.generate(*args, **kwargs)

    # device
    @property
    def device(self):
        return self.accelerator.device

    # whether training is distributed
    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    # whether this is the main process
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # whether this is the local main process
    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    # learning rate warmup
    def warmup(self, step):
        if step < self.num_warmup_steps:
            return self.initial_lr + (self.lr - self.initial_lr) * step / self.num_warmup_steps
        else:
            return self.lr
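
    # as a worked example: with initial_lr = 1e-5, lr = 3e-4 and num_warmup_steps = 1000,
    # step 500 gives 1e-5 + (3e-4 - 1e-5) * 500 / 1000 = 1.55e-4; once warmup ends,
    # the cosine annealing scheduler takes over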
    
    # one training step
    def train_step(self):
        steps = int(self.steps.item())

        self.model.train()
        
        # adjust the learning rate according to the schedule
        
        if steps < self.num_warmup_steps:
            # apply linear warmup
            lr = self.warmup(steps)
            for param_group in self.optim.param_groups:
                param_group['lr'] = lr
        else:
            # after warmup, step the cosine annealing scheduler
            self.scheduler.step()

        # logs

        logs = {}

        # forward and backward passes, with gradient accumulation

        for _ in range(self.grad_accum_every):
            x, = next(self.dl_iter)

            loss, _ = self.train_wrapper(x)

            self.accelerator.backward(loss / self.grad_accum_every)

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        self.optim.step()
        self.optim.zero_grad()
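
        # note that each loss is scaled by 1 / grad_accum_every before backward, so the
        # accumulated gradients average over an effective batch of batch_size * grad_accum_every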

        # log every so often

        if not (steps % self.log_every):
            self.print(f"{steps}: loss: {logs['loss']:0.3f}")

        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # evaluate on the validation set every so often

        self.accelerator.wait_for_everyone()

        if self.is_main and not (steps % self.save_results_every):
            x, = next(self.valid_dl_iter)

            with torch.inference_mode():
                self.train_wrapper.eval()
                valid_loss, _ = self.train_wrapper(x)

            self.print(f'{steps}: valid loss {valid_loss:0.3f}')
            self.accelerator.log({"valid_loss": valid_loss}, step=steps)

        # save the model every so often

        if self.is_main and not (steps % self.save_model_every):
            model_path = str(self.results_folder / f'speech.speech.{steps}.pt')
            self.save(model_path)

            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        self.steps += 1
        return logs

    # training loop
    def train(self, log_fn = noop):
        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        self.print('training complete')
# trainer for translating semantic tokens back to text (backtranslation)
class SemanticToTextTrainer(nn.Module):
    # constructor, taking the model plus training hyperparameters
    @beartype
    def __init__(
        self,
        model: TextToSemantic,              # the TextToSemantic model, trained here in the semantic-to-text direction
        *,
        num_train_steps,                    # total number of training steps
        num_warmup_steps,                   # number of learning rate warmup steps
        batch_size,                         # batch size
        dataset: Optional[Dataset] = None,  # training dataset
        lr = 3e-4,                          # peak learning rate
        initial_lr = 1e-5,                  # initial learning rate during warmup
        grad_accum_every = 1,               # gradient accumulation frequency
        wd = 0.,                            # weight decay
        max_grad_norm = 0.5,                # maximum gradient norm for clipping
        valid_frac = 0.05,                  # fraction of the dataset used for validation
        random_split_seed = 42,             # seed for the random train / validation split
        log_every = 10,                     # log every this many steps
        save_results_every = 100,           # validate and log results every this many steps
        save_model_every = 1000,            # save the model every this many steps
        results_folder = './results',       # folder for checkpoints and results
        accelerate_kwargs: dict = dict(),   # extra kwargs passed to Accelerator
        split_batches = False,              # whether the accelerator splits batches across processes
        drop_last = False,                  # whether to drop the last incomplete batch
        force_clear_prev_results = None     # whether to clear previous results without asking
        ):
        # invoke the parent nn.Module constructor
        super().__init__()
        # check that only one trainer is active
        check_one_trainer()

        # set up the accelerator
        self.accelerator = Accelerator(
            split_batches = split_batches,
            **accelerate_kwargs
        )

        # the model
        self.model = model

        # training wrapper
        self.train_wrapper = SemanticToTextWrapper(model = model)

        # register a step counter buffer
        self.register_buffer('steps', torch.Tensor([0]))

        # number of training steps, warmup steps, batch size and gradient accumulation frequency
        self.num_train_steps = num_train_steps
        self.num_warmup_steps = num_warmup_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        # when doing backtranslation, freeze the speech embeddings and the encoder
        model.unfreeze_all()
        model.freeze_speech_emb()
        model.freeze_encoder()
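        # with the speech embeddings and encoder frozen, only the remaining parameters
        # receive gradients, which presumably keeps the pretrained speech representations
        # intact while the model learns the semantic-to-text direction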

        # optimizer
        # get_optimizer should filter out frozen parameters (those with requires_grad set to False)
        self.optim = get_optimizer(
            model.parameters(),
            lr = lr,
            wd = wd,
            filter_by_requires_grad = True
        )
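        # conceptually, filter_by_requires_grad amounts to handing the optimizer only
        # [p for p in model.parameters() if p.requires_grad], so the frozen speech
        # embedding / encoder weights are never updated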

        self.lr = lr
        self.initial_lr = initial_lr
        self.scheduler = CosineAnnealingLR(self.optim, T_max = num_train_steps)

        # maximum gradient norm for clipping
        self.max_grad_norm = max_grad_norm

        # dataset
        self.ds = dataset

        # split out a validation set
        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        assert len(self.ds) >= batch_size, 'dataset must have sufficient samples for training'
        assert len(self.valid_ds) >= batch_size, f'validation dataset must have sufficient number of samples (currently {len(self.valid_ds)}) for training'

        # dataloaders
        self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)

        self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)

        # prepare with the accelerator
        (
            self.train_wrapper,
            self.optim,
            self.scheduler,
            self.dl,
            self.valid_dl
        ) = self.accelerator.prepare(
            self.train_wrapper,
            self.optim,
            self.scheduler,
            self.dl,
            self.valid_dl
        )

        # dataloader iterators
        self.dl_iter = cycle(self.dl)
        self.valid_dl_iter = cycle(self.valid_dl)

        self.log_every = log_every
        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        self.results_folder = Path(results_folder)

        # if on the main process, clear previous results when forced, or when force_clear_prev_results is unset, the results folder is non-empty and the user confirms
        if self.is_main and force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
            rmtree(str(self.results_folder))

        # create the results folder
        self.results_folder.mkdir(parents = True, exist_ok = True)
        
        # hyperparameter tracking
        hps = {"num_train_steps": num_train_steps, "num_warmup_steps": num_warmup_steps, "learning_rate": lr, "initial_learning_rate": initial_lr}
        self.accelerator.init_trackers("semantictext", config=hps)

    # save the model
    def save(self, path):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.model),
            optim = self.optim.state_dict(),
            scheduler = self.scheduler.state_dict()
        )
        torch.save(pkg, path)
    # load model weights and optionally the optimizer state
    def load(self, path, restore_optimizer = True):
        # get the unwrapped model
        model = self.accelerator.unwrap_model(self.model)
        # load the model weights
        pkg = model.load(path)

        # optionally restore the optimizer state
        if restore_optimizer:
            # load the optimizer state
            self.optim.load_state_dict(pkg['optim'])
            # load the scheduler state
            self.scheduler.load_state_dict(pkg['scheduler'])

            # + 1 to start from the next step and avoid overwriting the last checkpoint
            self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)
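
        # passing restore_optimizer = False presumably loads only the model weights,
        # leaving the optimizer, scheduler and step counter untouched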

    # print via the accelerator
    def print(self, msg):
        self.accelerator.print(msg)

    # generate
    def generate(self, *args, **kwargs):
        return self.train_wrapper.generate(*args, **kwargs)

    # device
    @property
    def device(self):
        return self.accelerator.device

    # whether training is distributed
    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    # whether this is the main process
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # whether this is the local main process
    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    # learning rate warmup
    def warmup(self, step):
        if step < self.num_warmup_steps:
            return self.initial_lr + (self.lr - self.initial_lr) * step / self.num_warmup_steps
        else:
            return self.lr
    
    # one training step
    def train_step(self):
        steps = int(self.steps.item())

        # set the model to training mode
        self.model.train()
        
        # adjust the learning rate according to the schedule

        if steps < self.num_warmup_steps:
            # apply linear warmup
            lr = self.warmup(steps)
            for param_group in self.optim.param_groups:
                param_group['lr'] = lr
        else:
            # after warmup, step the cosine annealing scheduler
            self.scheduler.step()

        # logs

        logs = {}

        # forward and backward passes, with gradient accumulation

        for _ in range(self.grad_accum_every):
            semantic_token_ids, grapheme_token_ids = next(self.dl_iter)

            loss, _ = self.train_wrapper(semantic_token_ids = semantic_token_ids, grapheme_token_ids = grapheme_token_ids)

            self.accelerator.backward(loss / self.grad_accum_every)

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        # clip gradients if a max norm is set
        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        self.optim.step()
        self.optim.zero_grad()

        # log every so often

        if not (steps % self.log_every):
            self.print(f"{steps}: loss: {logs['loss']:0.3f}")
        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # evaluate on the validation set every so often

        self.accelerator.wait_for_everyone()

        if self.is_main and not (steps % self.save_results_every):
            semantic_token_ids, grapheme_token_ids = next(self.valid_dl_iter)

            with torch.inference_mode():
                self.train_wrapper.eval()
                valid_loss, _ = self.train_wrapper(semantic_token_ids = semantic_token_ids, grapheme_token_ids = grapheme_token_ids)

            self.print(f'{steps}: valid loss {valid_loss:0.3f}')
            self.accelerator.log({"valid_loss": valid_loss}, step=steps)

        # save the model every so often

        if self.is_main and not (steps % self.save_model_every):
            model_path = str(self.results_folder / f'semantic.text.{steps}.pt')
            self.save(model_path)

            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        self.steps += 1
        return logs

    # training loop
    def train(self, log_fn = noop):
        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        self.print('training complete')
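
# a minimal usage sketch for the trainer above; `paired_token_dataset` is a
# hypothetical Dataset yielding (semantic_token_ids, grapheme_token_ids) pairs,
# and the TextToSemantic constructor arguments are omitted here
#
#   model = TextToSemantic(...)
#   trainer = SemanticToTextTrainer(
#       model = model,
#       num_train_steps = 10_000,
#       num_warmup_steps = 1_000,
#       batch_size = 16,
#       dataset = paired_token_dataset
#   )
#   trainer.train()
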
# trainer for the text-to-semantic model
class TextToSemanticTrainer(nn.Module):
    # constructor, taking the model, number of training steps, warmup steps and other hyperparameters
    @beartype
    def __init__(
        self,
        model: TextToSemantic,
        *,
        num_train_steps,
        num_warmup_steps,
        batch_size,
        dataset: Optional[Dataset] = None,
        generated_audio_text_dataset_folder = None,
        dataset_delimiter_id = -1,
        lr = 3e-4,
        initial_lr = 1e-5,
        grad_accum_every = 1,
        wd = 0.,
        max_grad_norm = 0.5,
        valid_frac = 0.05,
        random_split_seed = 42,
        log_every = 10,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        accelerate_kwargs: dict = dict(),
        split_batches = False,
        drop_last = False,
        force_clear_prev_results = None,
        freeze_encoder_layers_below = 2,
        should_train_early_exit_layer_if_available = True
    # save model parameters to the given path
    def save(self, path):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.model),
            optim = self.optim.state_dict(),
            scheduler = self.scheduler.state_dict()
        )
        torch.save(pkg, path)

    # load model parameters from the given path, optionally restoring the optimizer state
    def load(self, path, restore_optimizer = True):
        model = self.accelerator.unwrap_model(self.model)
        pkg = model.load(path)

        if restore_optimizer:
            self.optim.load_state_dict(pkg['optim'])
            self.scheduler.load_state_dict(pkg['scheduler'])

            # + 1 to start from the next step and avoid overwriting the last checkpoint
            self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)

    # print via the accelerator
    def print(self, msg):
        self.accelerator.print(msg)

    # generate
    def generate(self, *args, **kwargs):
        return self.train_wrapper.generate(*args, **kwargs)

    # device
    @property
    def device(self):
        return self.accelerator.device

    # whether training is distributed
    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    # whether this is the main process
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # whether this is the local main process
    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    # compute the learning rate for the current step (linear warmup)
    def warmup(self, step):
        if step < self.num_warmup_steps:
            return self.initial_lr + (self.lr - self.initial_lr) * step / self.num_warmup_steps
        else:
            return self.lr
    # one training step
    def train_step(self):
        # current step count
        steps = int(self.steps.item())

        # set the model to training mode
        self.model.train()
        
        # adjust the learning rate according to the schedule
        
        if steps < self.num_warmup_steps:
            # while below the warmup step count, apply linear warmup
            lr = self.warmup(steps)
            for param_group in self.optim.param_groups:
                param_group['lr'] = lr
        else:
            # after warmup, step the cosine annealing scheduler
            self.scheduler.step()

        # logs

        logs = {}

        # forward and backward passes, with gradient accumulation

        for _ in range(self.grad_accum_every):
            semantic_token_ids, grapheme_token_ids = next(self.dl_iter)

            # compute the loss for this batch
            loss, _ = self.train_wrapper(semantic_token_ids=semantic_token_ids, grapheme_token_ids=grapheme_token_ids, return_early_exit_loss=self.train_early_exit)
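            # self.train_early_exit is presumably set in the constructor from
            # should_train_early_exit_layer_if_available, adding an auxiliary loss
            # from the model's early-exit layer when one is available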

            self.accelerator.backward(loss / self.grad_accum_every)

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        # clip gradients if a max norm is set
        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        # take an optimizer step and reset gradients
        self.optim.step()
        self.optim.zero_grad()

        # log every so often

        if not (steps % self.log_every):
            self.print(f"{steps}: loss: {logs['loss']:0.3f}")
        
        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # evaluate on the validation set every so often

        self.accelerator.wait_for_everyone()

        if self.is_main and not (steps % self.save_results_every):
            semantic_token_ids, grapheme_token_ids = next(self.valid_dl_iter)

            with torch.inference_mode():
                self.train_wrapper.eval()
                valid_loss, _ = self.train_wrapper(semantic_token_ids=semantic_token_ids, grapheme_token_ids=grapheme_token_ids, return_early_exit_loss=self.train_early_exit)

            self.print(f'{steps}: valid loss {valid_loss:0.3f}')
            self.accelerator.log({"valid_loss": valid_loss}, step=steps)

        # save the model every so often

        if self.is_main and not (steps % self.save_model_every):
            model_path = str(self.results_folder / f'text.semantic.{steps}.pt')
            self.save(model_path)

            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        # increment the step counter and return the logs
        self.steps += 1
        return logs

    # training loop
    def train(self, log_fn=noop):
        # run training steps until num_train_steps is reached
        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        self.print('training complete')