Lucidrains-系列项目源码解析-六-

118 阅读29分钟

Lucidrains 系列项目源码解析(六)

.\lucidrains\audiolm-pytorch\audiolm_pytorch\data.py

# 导入必要的模块
from pathlib import Path
from functools import partial, wraps

# 导入 beartype 模块及相关类型
from beartype import beartype
from beartype.typing import Tuple, Union, Optional
from beartype.door import is_bearable

# 导入 torchaudio 模块及相关函数
import torchaudio
from torchaudio.functional import resample

# 导入 torch 模块及相关函数
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

# 导入自定义工具函数
from audiolm_pytorch.utils import curtail_to_multiple

# 导入 einops 模块中的函数
from einops import rearrange, reduce

# 定义一些辅助函数

# a value "exists" when it is anything other than the None sentinel
def exists(val):
    return val is not None

# promote a bare value to a tuple of the given length; tuples pass through untouched
def cast_tuple(val, length = 1):
    if isinstance(val, tuple):
        return val
    return (val,) * length

# return True when every element of arr is distinct
# (original line was missing the closing parenthesis — a syntax error)
def is_unique(arr):
    return len(set(arr)) == len(arr)

# dataset over a folder of audio files, yielding one float waveform per target sample rate
class SoundDataset(Dataset):
    @beartype
    def __init__(
        self,
        folder,
        target_sample_hz: Union[int, Tuple[int, ...]],  # target sample rate(s); a tuple yields one output per rate
        exts = ['flac', 'wav', 'mp3', 'webm'],
        max_length: Optional[int] = None,               # with multiple target rates, max_length applies at the highest one
        seq_len_multiple_of: Optional[Union[int, Tuple[Optional[int], ...]]] = None
    ):
        super().__init__()
        path = Path(folder)
        assert path.exists(), f'folder "{str(path)}" does not exist'

        # recursively collect all files matching the given extensions
        files = [file for ext in exts for file in path.glob(f'**/*.{ext}')]
        assert len(files) > 0, 'no sound files found'

        self.files = files

        self.max_length = max_length
        self.target_sample_hz = cast_tuple(target_sample_hz)
        num_outputs = len(self.target_sample_hz)

        # with multiple target rates: resample to the highest rate first, apply
        # max_length there, then resample down to each remaining rate

        self.max_target_sample_hz = max(self.target_sample_hz)
        self.seq_len_multiple_of = cast_tuple(seq_len_multiple_of, num_outputs)

        assert len(self.target_sample_hz) == len(self.seq_len_multiple_of)

    def __len__(self):
        # number of audio files discovered in the folder
        return len(self.files)

    def __getitem__(self, idx):
        # load one file and produce a float waveform per target sample rate
        file = self.files[idx]

        data, sample_hz = torchaudio.load(file)

        assert data.numel() > 0, f'one of your audio file ({file}) is empty. please remove it from your folder'

        if data.shape[0] > 1:
            # multi-channel audio: average the channels down to mono
            data = reduce(data, 'c ... -> 1 ...', 'mean')

        # resample to the highest target rate first

        data = resample(data, sample_hz, self.max_target_sample_hz)
        sample_hz = self.max_target_sample_hz

        # randomly crop (if too long) or right-pad (if too short) to max_length

        max_length = self.max_length
        audio_length = data.size(1)

        if exists(max_length):
            if audio_length > max_length:
                max_start = audio_length - max_length
                start = torch.randint(0, max_start, (1, ))
                data = data[:, start:start + max_length]
            else:
                data = F.pad(data, (0, max_length - audio_length), 'constant')

        # drop the singleton channel dim
        data = rearrange(data, '1 ... -> ...')

        # resample down to every requested target rate

        num_outputs = len(self.target_sample_hz)
        data = cast_tuple(data, num_outputs)

        data_tuple = tuple(resample(d, sample_hz, target_sample_hz) for d, target_sample_hz in zip(data, self.target_sample_hz))

        output = []

        # per output, optionally curtail the length to the required multiple

        for data, seq_len_multiple_of in zip(data_tuple, self.seq_len_multiple_of):
            if exists(seq_len_multiple_of):
                data = curtail_to_multiple(data, seq_len_multiple_of)

            output.append(data.float())

        # convert from list to tuple

        output = tuple(output)

        # a single target rate returns the bare waveform rather than a 1-tuple

        if num_outputs == 1:
            return output[0]

        return output

# 数据加载函数

# decorator: apply the collate fn per field, supporting datasets that yield either a
# single tensor or a tuple of fields; string fields pass through as plain lists
def collate_one_or_multiple_tensors(fn):
    @wraps(fn)
    def inner(data):
        first = data[0]

        if not isinstance(first, tuple):
            # single-output dataset: collate the whole batch at once
            return (fn(data),)

        collated = []
        for field in zip(*data):
            # string fields (e.g. text / file paths) are not padded
            if is_bearable(field, Tuple[str, ...]):
                collated.append(list(field))
            else:
                collated.append(fn(field))

        return tuple(collated)

    return inner

# collate by truncating every tensor to the shortest length in the batch, then stacking.
# the original used `min(*[...])`, which unpacks to `min(int)` and raises TypeError
# for a batch of size one; `min` over a generator handles any batch size
@collate_one_or_multiple_tensors
def curtail_to_shortest_collate(data):
    min_len = min(datum.shape[0] for datum in data)
    return torch.stack([datum[:min_len] for datum in data])

# collate by right-padding every tensor to the longest length in the batch
@collate_one_or_multiple_tensors
def pad_to_longest_fn(data):
    padded = pad_sequence(data, batch_first = True)
    return padded

# build a DataLoader whose collate fn either pads to the longest or truncates to the shortest sample
def get_dataloader(ds, pad_to_longest = True, **kwargs):
    if pad_to_longest:
        collate_fn = pad_to_longest_fn
    else:
        collate_fn = curtail_to_shortest_collate

    return DataLoader(ds, collate_fn = collate_fn, **kwargs)

.\lucidrains\audiolm-pytorch\audiolm_pytorch\encodec.py

# 导入所需的库和模块
from functools import reduce
from einops import rearrange, pack, unpack
import torch
from torch import nn
from torchaudio.functional import resample
from vector_quantize_pytorch import ResidualVQ
from encodec import EncodecModel
from encodec.utils import _linear_overlap_add

# helper: treat None as "missing"; everything else counts as present
def exists(val):
    return not (val is None)

# probe the model with a dummy waveform and read off how many RVQ codebooks it emits
def get_num_quantizers(model: EncodecModel, audio_length = 512):
    encoded = model.encode(torch.randn(1, 1, audio_length))
    return encoded[0][0].shape[1]

# wrapper that adapts the pretrained 24kHz Encodec model to the SoundStream interface
class EncodecWrapper(nn.Module):
    def __init__(
        self,
        target_sample_hz = 24000,
        strides = (2, 4, 5, 8),
        num_quantizers = 8,
        bandwidth = 6.0
    ):
        super().__init__()
        # instantiate the pretrained 24kHz Encodec model
        self.model = EncodecModel.encodec_model_24khz()
        self.model.normalize = False

        # target bandwidth determines how many quantizers Encodec activates;
        # probe the model for the actual count (overrides the constructor argument)
        self.model.set_target_bandwidth(bandwidth)
        num_quantizers = get_num_quantizers(self.model)

        # fields mirroring the SoundStream interface
        self.target_sample_hz = target_sample_hz
        assert self.target_sample_hz == 24000, "haven't done anything with non-24kHz yet"
        self.codebook_dim = 128
        self.rq_groups = 1
        self.num_quantizers = num_quantizers
        self.strides = strides

        # a ResidualVQ mirror so downstream code can access the codebooks
        self.rq = ResidualVQ(
            dim = 128,
            codebook_size = 1024,
            num_quantizers = num_quantizers
        )

        # copy Encodec's pretrained codebooks into the ResidualVQ mirror
        for encodec_rq_layer, rq_layer in zip(self.model.quantizer.vq.layers, self.rq.layers):
            encodec_codebook = dict(encodec_rq_layer._codebook.named_buffers()).get('embed')
            vq_codebook = dict(rq_layer._codebook.named_buffers()).get('embed')
            encodec_codebook = rearrange(encodec_codebook, '... -> 1 ...')
            vq_codebook.copy_(encodec_codebook)

    @property
    def seq_len_multiple_of(self):
        # product of the encoder strides; input lengths must divide evenly by this
        return reduce(lambda x, y: x * y, self.strides)

    @property
    def downsample_factor(self):
        return self.seq_len_multiple_of

    def forward(
        self,
        x,
        input_sample_hz = None,
        return_encoded = False,
        **kwargs
    ):
        # returns (emb or None, codes, None) to match the SoundStream forward contract
        x, ps = pack([x], '* n')

        # resample if the caller supplies a different sample rate
        if exists(input_sample_hz):
            x = resample(x, input_sample_hz, self.target_sample_hz)

        assert not self.model.training, "Encodec is pretrained and should never be called outside eval mode."

        # NOTE(review): this einops pattern only adds a channel axis when model.channels == 1 — confirm
        wav = rearrange(x, f'b t -> b {self.model.channels} t')

        with torch.inference_mode():
            encoded_frames = self.model.encode(wav)

        # concatenate the per-frame codes along time: batch x quantizers x time
        codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)
        codes = rearrange(codes, 'b q n -> b n q')

        emb = None

        # optionally decode the indices back into continuous embeddings
        if return_encoded:
            emb = self.get_emb_from_indices(codes)
            emb, = unpack(emb, ps, '* n c')

        codes, = unpack(codes, ps, '* n q')

        return emb, codes, None

    def decode_from_codebook_indices(self, quantized_indices):
        # indices -> waveform, restoring the channel dim callers expect
        frames = self._decode_frame(quantized_indices)
        result = _linear_overlap_add(frames, self.model.segment_stride or 1)
        return rearrange(result, 'b n -> b 1 n')

    def get_emb_from_indices(self, indices):
        # indices (b, t, q) -> continuous embeddings (b, t, c) via Encodec's quantizer
        codes = rearrange(indices, 'b t q -> q b t')
        emb = self.model.quantizer.decode(codes)
        return rearrange(emb, 'b c n -> b n c')

    def decode(self, emb):
        # continuous embeddings (b, n, c) -> waveform via Encodec's decoder
        emb = rearrange(emb, 'b n c -> b c n')
        return self.model.decoder(emb)
    # decode a frame from quantized indices
    def _decode_frame(self, quantized_indices):
        # adapted from self.model._decode_frame() (Encodec version 0.1.1), assuming
        # the EncodedFrame has already been unpacked
        # input: batch x num tokens x num quantizers
        # output: batch x new_num_samples, where new_num_samples == num_frames * stride
        # (possibly slightly larger than the original sample count, since the last
        # frame may not be completely filled); num_frames == number of acoustic tokens
        codes = rearrange(quantized_indices, 'b t q -> q b t')
        # decode indices into continuous embeddings with the quantizer
        emb = self.model.quantizer.decode(codes)
        # emb shape: batch x self.model.quantizer.dimension x T (quantizer embedding dim)
        return self.model.decoder(emb)

.\lucidrains\audiolm-pytorch\audiolm_pytorch\hubert_kmeans.py

# 导入必要的库
from pathlib import Path
import torch
from torch import nn, einsum
from torchaudio.functional import resample
from einops import rearrange, repeat, pack, unpack
from audiolm_pytorch.utils import curtail_to_multiple

# do-nothing sink, installed below as warnings.warn to silence warnings
def noop(*args, **kwargs):
    return None

import warnings
import logging

# 设置日志级别为 ERROR
logging.root.setLevel(logging.ERROR)

# 忽略警告
warnings.warn = noop

# 导入 fairseq 和 joblib 用于 hubert 模型
import joblib
import fairseq

# helper functions

def exists(val):
    """Return True when *val* carries a value (i.e. is not None)."""
    return val is not None

def default(val, d):
    """Return *val* if it is not None, otherwise the fallback *d*."""
    if val is None:
        return d
    return val

# Hubert feature extractor whose embeddings are discretized with a fitted k-means model
class HubertWithKmeans(nn.Module):
    """
    checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
    or you can train your own
    """

    def __init__(
        self,
        checkpoint_path,
        kmeans_path,
        target_sample_hz = 16000,
        seq_len_multiple_of = None,
        output_layer = 9
    ):
        super().__init__()

        # basic configuration
        self.target_sample_hz = target_sample_hz
        self.seq_len_multiple_of = seq_len_multiple_of
        self.output_layer = output_layer

        # resolve and validate the checkpoint / k-means paths
        model_path = Path(checkpoint_path)
        kmeans_path = Path(kmeans_path)

        assert model_path.exists(), f'path {checkpoint_path} does not exist'
        assert kmeans_path.exists(), f'path {kmeans_path} does not exist'

        # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints
        checkpoint = torch.load(checkpoint_path)
        load_model_input = {checkpoint_path: checkpoint}
        model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)

        self.model = model[0]
        self.model.eval()

        kmeans = joblib.load(kmeans_path)

        self.kmeans = kmeans

        # register cluster centers as a buffer so they follow the module's device
        self.register_buffer(
            'cluster_centers',
            torch.from_numpy(kmeans.cluster_centers_)
        )

    @property
    def groups(self):
        return 1

    @property
    def codebook_size(self):
        # one discrete "code" per k-means cluster
        return self.kmeans.n_clusters

    @property
    def downsample_factor(self):
        # todo: double check
        return 320

    @torch.inference_mode()
    def forward(
        self,
        wav_input,
        flatten = True,
        input_sample_hz = None
    ):
        # batch size and device of the incoming waveform
        batch, device = wav_input.shape[0], wav_input.device

        # resample if the caller's sample rate differs from the model's
        if exists(input_sample_hz):
            wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)

        # trim so the length is a multiple of seq_len_multiple_of, if configured
        if exists(self.seq_len_multiple_of):
            wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)

        # extract features from the configured transformer layer
        embed = self.model(
            wav_input,
            features_only = True,
            mask = False,
            output_layer = self.output_layer
        )['x']

        # broadcast cluster centers across the batch
        batched_cluster_centers = repeat(self.cluster_centers, 'c d -> b c d', b = embed.shape[0])
        # negated euclidean distances, so larger means closer
        dists = -torch.cdist(embed, batched_cluster_centers, p = 2)
        # argmax over the negated distances selects the nearest cluster
        clusters = dists.argmax(dim = -1)

        # return cluster ids; when not flattening, any extra trailing dims are merged
        if flatten:
            return clusters

        return rearrange(clusters, 'b ... -> b (...)')

.\lucidrains\audiolm-pytorch\audiolm_pytorch\optimizer.py

# 从 torch.optim 模块中导入 AdamW 和 Adam 优化器
from torch.optim import AdamW, Adam

# split parameters into (decayable, non-decayable) by dimensionality:
# biases and norm gains (ndim < 2) should not receive weight decay
def separate_weight_decayable_params(params):
    wd_params = []
    no_wd_params = []
    for p in params:
        (wd_params if p.ndim >= 2 else no_wd_params).append(p)
    return wd_params, no_wd_params

# build Adam (no weight decay) or AdamW (with decay), optionally grouping the
# parameters so that biases / norm gains are excluded from decay.
# NOTE(review): `use_lion` and `**kwargs` are accepted but never used in this function.
def get_optimizer(
    params,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.99),
    eps = 1e-8,
    filter_by_requires_grad = False,
    group_wd_params = True,
    use_lion = False,
    **kwargs
):
    # optionally keep only trainable parameters
    if filter_by_requires_grad:
        params = [p for p in params if p.requires_grad]

    decay_enabled = wd > 0

    # group the parameters so only >= 2-dim tensors receive decay
    if group_wd_params and decay_enabled:
        wd_params, no_wd_params = separate_weight_decayable_params(params)

        params = [
            {'params': wd_params},
            {'params': no_wd_params, 'weight_decay': 0},
        ]

    if not decay_enabled:
        return Adam(params, lr = lr, betas = betas, eps = eps)

    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)

.\lucidrains\audiolm-pytorch\audiolm_pytorch\soundstream.py

# 导入必要的库
import functools
from pathlib import Path
from functools import partial, wraps
from itertools import cycle, zip_longest
from typing import Optional, List

import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
from torch.autograd import grad as torch_grad
import torch.nn.functional as F
from torch.linalg import vector_norm

import torchaudio.transforms as T
from torchaudio.functional import resample

from einops import rearrange, reduce, pack, unpack

# 导入自定义模块
from vector_quantize_pytorch import (
    GroupedResidualVQ,
    GroupedResidualLFQ,
    GroupedResidualFSQ
)

from local_attention import LocalMHA
from local_attention.transformer import FeedForward, DynamicPositionBias

from gateloop_transformer import SimpleGateLoopLayer as GateLoop

from audiolm_pytorch.utils import curtail_to_multiple

from audiolm_pytorch.version import __version__
from packaging import version
parsed_version = version.parse(__version__)

import pickle

# 辅助函数

# presence check: anything but None counts as existing
def exists(val):
    return val is not None

# fall back to d when val is absent (None)
def default(val, d):
    return d if val is None else val

# wrap a non-tuple value into a tuple repeated l times; tuples pass through
def cast_tuple(t, l = 1):
    if isinstance(t, tuple):
        return t
    return (t,) * l

# keep only the items of d whose key satisfies fn
def filter_by_keys(fn, d):
    return dict((k, v) for k, v in d.items() if fn(k))

# return a copy of d with every key transformed by fn (values untouched)
def map_keys(fn, d):
    return dict(zip(map(fn, d.keys()), d.values()))

# GAN 损失函数

# numerically safe log: clamp the input from below at eps before taking the log
def log(t, eps = 1e-20):
    clamped = t.clamp(min = eps)
    return torch.log(clamped)

# discriminator hinge loss: push real logits above +1 and fake logits below -1
def hinge_discr_loss(fake, real):
    loss_fake = F.relu(1 + fake)
    loss_real = F.relu(1 - real)
    return (loss_fake + loss_real).mean()

# generator hinge loss: maximize the discriminator's logits on generated samples
def hinge_gen_loss(fake):
    return fake.mean().neg()

# LeakyReLU module with negative slope p (default 0.1)
def leaky_relu(p = 0.1):
    return nn.LeakyReLU(negative_slope = p)

# WGAN-GP style gradient penalty: weight * mean((||d output / d wave||_2 - 1)^2).
# fixes: removed the unused `device` local; replaced the einops rearrange with the
# built-in reshape (identical flattening, one less third-party call)
def gradient_penalty(wave, output, weight = 10):
    batch_size = wave.shape[0]

    (gradients,) = torch_grad(
        outputs = output,
        inputs = wave,
        grad_outputs = torch.ones_like(output),
        create_graph = True,
        retain_graph = True,
        only_inputs = True
    )

    # flatten everything but the batch dimension before taking the norm
    gradients = gradients.reshape(batch_size, -1)
    return weight * ((vector_norm(gradients, dim = 1) - 1) ** 2).mean()

# nn.Sequential that silently drops None entries, so optional layers can be passed directly

def Sequential(*mods):
    present = [m for m in mods if m is not None]
    return nn.Sequential(*present)

# 判别器

class MultiScaleDiscriminator(Module):
    """Raw-waveform discriminator: grouped strided 1D convolutions producing per-frame logits."""
    def __init__(
        self,
        channels = 16,
        layers = 4,
        groups = (4, 16, 64, 256),
        chan_max = 1024,
        input_channels = 1
    ):
        super().__init__()
        # wide initial conv over the raw waveform
        self.init_conv = nn.Conv1d(input_channels, channels, 15, padding = 7)
        self.conv_layers = ModuleList([])

        curr_channels = channels

        for _, group in zip(range(layers), groups):
            # grow channels 4x per layer, capped at chan_max
            chan_out = min(curr_channels * 4, chan_max)

            self.conv_layers.append(nn.Sequential(
                nn.Conv1d(curr_channels, chan_out, 41, stride = 4, padding = 20, groups = group),
                leaky_relu()
            ))

            curr_channels = chan_out

        # head: refine features then project to a single logit channel
        self.final_conv = nn.Sequential(
            nn.Conv1d(curr_channels, curr_channels, 5, padding = 2),
            leaky_relu(),
            nn.Conv1d(curr_channels, 1, 3, padding = 1),
        )

    def forward(
        self,
        x,
        return_intermediates = False
    ):
        # returns logits; with return_intermediates, also the per-layer feature maps
        # (used for feature-matching losses)
        x = self.init_conv(x)
        intermediates = []

        for layer in self.conv_layers:
            x = layer(x)
            intermediates.append(x)

        out = self.final_conv(x)

        if not return_intermediates:
            return out

        return out, intermediates

# autoregressive (causal) squeeze-excite
# https://arxiv.org/abs/1709.01507

class SqueezeExcite(Module):
    # gating network: 1x1 conv bottleneck -> SiLU -> 1x1 conv -> sigmoid
    def __init__(self, dim, reduction_factor = 4, dim_minimum = 8):
        super().__init__()
        dim_inner = max(dim_minimum, dim // reduction_factor)
        self.net = nn.Sequential(
            nn.Conv1d(dim, dim_inner, 1),
            nn.SiLU(),
            nn.Conv1d(dim_inner, dim, 1),
            nn.Sigmoid()
        )
    # forward pass over input x
    def forward(self, x):
        # length of the dim being averaged over, and the device
        seq, device = x.shape[-2], x.device

        # cumulative mean - keeps the statistic autoregressive (each position
        # only sees itself and earlier positions along dim -2)
        # NOTE(review): the running mean is over dim -2, which for (b, c, n)
        # conv input is the channel dim — confirm the intended layout

        cum_sum = x.cumsum(dim = -2)
        # counts 1..seq for the running average
        denom = torch.arange(1, seq + 1, device = device).float()
        # running average = running sum / position index
        cum_mean = cum_sum / rearrange(denom, 'n -> n 1')

        # glu gate

        # gate values in (0, 1) from the gating network
        gate = self.net(cum_mean)

        # modulate the input by the learned gate
        return x * gate
# complex STFT discriminator (component modules follow)

class ModReLU(Module):
    """
    Complex-valued ReLU variant: thresholds the magnitude with a learned bias
    while leaving the phase untouched.

    https://arxiv.org/abs/1705.09792
    https://github.com/pytorch/pytorch/issues/47052#issuecomment-718948801
    """
    def __init__(self):
        super().__init__()
        self.b = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        magnitude = F.relu(torch.abs(x) + self.b)
        phase = torch.exp(1.j * torch.angle(x))
        return magnitude * phase

class ComplexConv2d(Module):
    """2D convolution over complex tensors; parameters are stored as real views
    (trailing dim of 2) so plain-float optimizers can update them."""
    def __init__(
        self,
        dim,
        dim_out,
        kernel_size,
        stride = 1,
        padding = 0
    ):
        super().__init__()
        # build a throwaway complex conv purely to borrow its initialized parameters
        reference = nn.Conv2d(dim, dim_out, kernel_size, dtype = torch.complex64)

        self.weight = nn.Parameter(torch.view_as_real(reference.weight))
        self.bias = nn.Parameter(torch.view_as_real(reference.bias))

        self.stride = stride
        self.padding = padding

    def forward(self, x):
        # reinterpret the stored real views as complex tensors for the conv
        weight = torch.view_as_complex(self.weight)
        bias = torch.view_as_complex(self.bias)

        x = x.to(weight.dtype)
        return F.conv2d(x, weight, bias, stride = self.stride, padding = self.padding)

def ComplexSTFTResidualUnit(chan_in, chan_out, strides):
    """Residual pair of 3x3 complex convs followed by a strided complex conv that
    changes channels / downsamples the spectrogram."""
    kernel_sizes = tuple(s + 2 for s in strides)
    paddings = tuple(k // 2 for k in kernel_sizes)

    residual_branch = Residual(Sequential(
        ComplexConv2d(chan_in, chan_in, 3, padding = 1),
        ModReLU(),
        ComplexConv2d(chan_in, chan_in, 3, padding = 1)
    ))

    downsample = ComplexConv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings)

    return nn.Sequential(residual_branch, downsample)

class ComplexSTFTDiscriminator(Module):
    """Discriminator operating on the complex STFT of the waveform (SoundStream paper, Fig. 4)."""
    def __init__(
        self,
        *,
        channels = 32,
        strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)),
        chan_mults = (1, 2, 4, 4, 8, 8),
        input_channels = 1,
        n_fft = 1024,
        hop_length = 256,
        win_length = 1024,
        stft_normalized = False,
        stft_window_fn = torch.hann_window,
        logits_abs = True
    ):
        super().__init__()
        # initial complex conv over the spectrogram
        self.init_conv = ComplexConv2d(input_channels, channels, 7, padding = 3)

        # channel schedule: base channels scaled by the per-stage multipliers
        layer_channels = tuple(map(lambda mult: mult * channels, chan_mults))
        layer_channels = (channels, *layer_channels)
        layer_channels_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:]))

        curr_channels = channels

        self.layers = ModuleList([])

        for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs):
            # one complex residual unit + strided downsampling per stage
            self.layers.append(ComplexSTFTResidualUnit(chan_in, chan_out, layer_stride))

        # final conv collapsing the frequency axis into a single logit channel
        self.final_conv = ComplexConv2d(layer_channels[-1], 1, (16, 1)) # todo: remove hardcoded 16

        # stft settings

        self.stft_normalized = stft_normalized
        self.stft_window_fn = stft_window_fn

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length

        # how to map the complex logits into real space

        self.logits_abs = logits_abs
    # forward pass; optionally returns intermediate feature maps for feature matching
    def forward(self, x, return_intermediates = False):
        # drop the channel dim: 'b 1 n' -> 'b n'
        x = rearrange(x, 'b 1 n -> b n')

        '''
        reference: The content of the paper( https://arxiv.org/pdf/2107.03312.pdf)is as follows:
        The STFT-based discriminator is illustrated in Figure 4
        and operates on a single scale, computing the STFT with a
        window length of W = 1024 samples and a hop length of
        H = 256 samples
        '''
        
        # build the STFT window on the input's device
        stft_window = self.stft_window_fn(self.win_length, device = x.device)

        # complex STFT of the waveform
        x = torch.stft(
            x,
            self.n_fft,
            hop_length = self.hop_length,
            win_length = self.win_length,
            window = stft_window,
            normalized = self.stft_normalized,
            return_complex = True
        )

        # add back a channel dim: 'b ...' -> 'b 1 ...'
        x = rearrange(x, 'b ... -> b 1 ...')

        intermediates = []

        # initial complex convolution
        x = self.init_conv(x)

        intermediates.append(x)

        # run all residual / downsampling stages
        for layer in self.layers:
            x = layer(x)
            intermediates.append(x)

        # complex-valued logits from the final conv
        complex_logits = self.final_conv(x)

        # map complex logits to real space: magnitude, or stacked (real, imag)
        if self.logits_abs:
            complex_logits = complex_logits.abs()
        else:
            complex_logits = torch.view_as_real(complex_logits)

        if not return_intermediates:
            return complex_logits

        return complex_logits, intermediates
# additive skip connection around an arbitrary module: forward(x) = fn(x) + x
class Residual(Module):
    def __init__(self, fn: Module):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        return out + x

# run fn in channels-last layout with a residual add, for modules that expect
# (batch, seq, channel) while the surrounding network is channel-first
class ChannelTranspose(Module):
    def __init__(self, fn: Module):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # (b, c, n) -> (b, n, c)
        transposed = x.transpose(1, 2)
        out = self.fn(transposed, **kwargs) + transposed
        # back to (b, c, n)
        return out.transpose(1, 2)

# 1D convolution that pads only on the left, so no output looks at future timesteps
class CausalConv1d(Module):
    def __init__(self, chan_in, chan_out, kernel_size, pad_mode = 'reflect', **kwargs):
        super().__init__()
        dilation = kwargs.get('dilation', 1)
        stride = kwargs.get('stride', 1)

        self.pad_mode = pad_mode
        # amount of left padding required for causality
        self.causal_padding = dilation * (kernel_size - 1) + (1 - stride)

        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, **kwargs)

    def forward(self, x):
        padded = F.pad(x, (self.causal_padding, 0), mode = self.pad_mode)
        return self.conv(padded)

# transposed 1D conv for causal upsampling; the output is trimmed to exactly
# stride * input_length, discarding the transposed-conv tail
class CausalConvTranspose1d(Module):
    def __init__(self, chan_in, chan_out, kernel_size, stride, **kwargs):
        super().__init__()
        self.upsample_factor = stride
        self.padding = kernel_size - 1
        self.conv = nn.ConvTranspose1d(chan_in, chan_out, kernel_size, stride, **kwargs)

    def forward(self, x):
        seq_len = x.shape[-1]

        upsampled = self.conv(x)
        return upsampled[..., :seq_len * self.upsample_factor]

# dilated causal conv -> ELU -> 1x1 conv -> ELU (optional SqueezeExcite), in a residual skip
def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7, squeeze_excite = False, pad_mode = 'reflect'):
    layers = [
        CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation, pad_mode = pad_mode),
        nn.ELU(),
        CausalConv1d(chan_out, chan_out, 1, pad_mode = pad_mode),
        nn.ELU(),
        SqueezeExcite(chan_out) if squeeze_excite else None
    ]
    return Residual(Sequential(*layers))

# three dilated residual units followed by a strided causal conv that downsamples
# and changes channel count
def EncoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False, pad_mode = 'reflect'):
    dilations = cycle(cycle_dilations)
    make_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite, pad_mode = pad_mode)

    units = [make_unit(chan_in, chan_in, next(dilations)) for _ in range(3)]

    return nn.Sequential(
        *units,
        CausalConv1d(chan_in, chan_out, 2 * stride, stride = stride)
    )

# causal transposed conv that upsamples by `stride`, followed by three dilated
# residual units.
# fix: the original computed `even_stride` / `padding` / `output_padding` but never
# used them (CausalConvTranspose1d handles its own trimming) — dead locals removed
def DecoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False, pad_mode = 'reflect'):
    residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite, pad_mode = pad_mode)

    it = cycle(cycle_dilations)
    return nn.Sequential(
        CausalConvTranspose1d(chan_in, chan_out, 2 * stride, stride = stride),
        residual_unit(chan_out, chan_out, next(it)),
        residual_unit(chan_out, chan_out, next(it)),
        residual_unit(chan_out, chan_out, next(it)),
    )

# windowed local-attention transformer used inside SoundStream
class LocalTransformer(Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        heads,
        window_size,
        dynamic_pos_bias = False,
        **kwargs
        ):
        super().__init__()
        # attention window size (tokens attend within local windows)
        self.window_size = window_size
        # stack of (attention, feedforward) layer pairs
        self.layers = ModuleList([])

        # optional dynamic relative position bias, shared across layers
        self.pos_bias = None
        if dynamic_pos_bias:
            self.pos_bias = DynamicPositionBias(dim = dim // 2, heads = heads)

        # build `depth` pairs of local multi-head attention + feedforward
        for _ in range(depth):
            self.layers.append(ModuleList([
                LocalMHA(
                    dim = dim,
                    heads = heads,
                    qk_rmsnorm = True,
                    window_size = window_size,
                    use_rotary_pos_emb = not dynamic_pos_bias,
                    gate_values_per_head = True,
                    use_xpos = True,
                    **kwargs
                ),
                FeedForward(dim = dim)
            ]))

    def forward(self, x):
        w = self.window_size

        # precompute the attention bias once when dynamic position bias is enabled
        attn_bias = self.pos_bias(w, w * 2) if exists(self.pos_bias) else None

        # residual blocks: attention then feedforward
        for attn, ff in self.layers:
            x = attn(x, attn_bias = attn_bias) + x
            x = ff(x) + x

        return x
class FiLM(Module):
    """Feature-wise linear modulation: project the conditioning vector to a
    per-feature (gamma, beta) pair and apply x * gamma + beta."""
    def __init__(self, dim, dim_cond):
        super().__init__()
        self.to_cond = nn.Linear(dim_cond, dim * 2)

    def forward(self, x, cond):
        scale_shift = self.to_cond(cond)
        gamma, beta = scale_shift.chunk(2, dim = -1)
        return x * gamma + beta

class SoundStream(Module):
    # 定义 SoundStream 类,继承自 Module 类
    # constructor (signature only in this excerpt — the closing paren and body are not shown)
    def __init__(
        self,
        *,
        channels = 32,                                  # base channel width of the encoder / decoder
        strides = (2, 4, 5, 8),                         # per-stage temporal downsampling factors
        channel_mults = (2, 4, 8, 16),                  # channel multiplier per stage
        codebook_dim = 512,
        codebook_size: Optional[int] = None,
        finite_scalar_quantizer_levels: Optional[List[int]] = None,
        rq_num_quantizers = 8,                          # number of residual quantizers
        rq_commitment_weight = 1.,
        rq_ema_decay = 0.95,
        rq_quantize_dropout_multiple_of = 1,
        rq_groups = 1,
        rq_stochastic_sample_codes = False,
        rq_kwargs: dict = {},
        use_lookup_free_quantizer = False,              # NOTE(review): presumably switches to lookup-free quantization — confirm against full source
        use_finite_scalar_quantizer = False,            # NOTE(review): presumably switches to finite scalar quantization — confirm against full source
        input_channels = 1,
        discr_multi_scales = (1, 0.5, 0.25),            # scales for the multi-scale discriminators
        stft_normalized = False,
        enc_cycle_dilations = (1, 3, 9),
        dec_cycle_dilations = (1, 3, 9),
        multi_spectral_window_powers_of_two = tuple(range(6, 12)),
        multi_spectral_n_ffts = 512,
        multi_spectral_n_mels = 64,
        recon_loss_weight = 1.,                         # loss-term weights follow
        multi_spectral_recon_loss_weight = 1e-5,
        adversarial_loss_weight = 1.,
        feature_loss_weight = 100,
        quantize_dropout_cutoff_index = 1,
        target_sample_hz = 16000,                       # sample rate the model operates at
        use_local_attn = True,
        attn_window_size = 128,
        attn_dim_head = 64,
        attn_heads = 8,
        attn_depth = 1,
        attn_xpos_scale_base = None,
        attn_dynamic_pos_bias = False,
        use_gate_loop_layers = False,
        squeeze_excite = False,
        complex_stft_discr_logits_abs = True,
        pad_mode = 'reflect',
        stft_discriminator: Optional[Module] = None,    # optionally supply your own STFT discriminator
        complex_stft_discr_kwargs: dict = dict()
    @property
    def device(self):
        # device the model's parameters currently live on
        first_param = next(self.parameters())
        return first_param.device

    @property
    def configs(self):
        # constructor kwargs this model was built with, recovered from the pickled copy
        pickled = self._configs
        return pickle.loads(pickled)

    def decode_from_codebook_indices(self, quantized_indices):
        """Reconstruct a waveform from residual-VQ codebook indices."""
        assert quantized_indices.dtype in (torch.long, torch.int32)

        # fold grouped quantizer indices out of the last dimension when given flat
        if quantized_indices.ndim == 3:
            quantized_indices = rearrange(quantized_indices, 'b n (g q) -> g b n q', g = self.rq_groups)

        latents = self.rq.get_output_from_indices(quantized_indices)
        return self.decode(latents)

    def decode(self, x, quantize = False):
        """Decode latents `x` back to audio, optionally quantizing them first."""
        latents = x

        if quantize:
            latents, *_ = self.rq(latents)

        # optional local-attention refinement before the convolutional decoder
        if exists(self.decoder_attn):
            latents = self.decoder_attn(latents)

        latents = rearrange(latents, 'b n c -> b c n')
        return self.decoder(latents)

    def save(self, path):
        """Serialize model weights, pickled config and package version to `path`."""
        pkg = {
            'model': self.state_dict(),
            'config': self._configs,
            'version': __version__,
        }
        torch.save(pkg, str(Path(path)))

    @classmethod
    def init_and_load_from(cls, path, strict = True):
        """Construct a model from a checkpoint's stored config, then load its weights."""
        checkpoint_path = Path(path)
        assert checkpoint_path.exists()
        pkg = torch.load(str(checkpoint_path), map_location = 'cpu')

        assert 'config' in pkg, 'model configs were not found in this saved checkpoint'

        # rebuild with the exact kwargs the checkpointed model was constructed with
        model = cls(**pickle.loads(pkg['config']))
        model.load(path, strict = strict)
        model.eval()
        return model
    # Load model weights from a checkpoint, preferring EMA weights when present
    # and warning when the checkpoint came from an older package version.
    def load(self, path, strict = True):
        # normalize to a Path object
        path = Path(path)
        # fail fast if the checkpoint file is missing
        assert path.exists()
        # always load onto CPU regardless of where the checkpoint was saved
        pkg = torch.load(str(path), map_location = 'cpu')

        # version check

        # warn when the checkpoint was produced by an older audiolm-pytorch release
        if 'version' in pkg and version.parse(pkg['version']) < parsed_version:
            print(f'soundstream model being loaded was trained on an older version of audiolm-pytorch ({pkg["version"]})')

        # prefer the exponentially-moving-averaged weights when available
        has_ema = 'ema_model' in pkg
        model_pkg = pkg['ema_model'] if has_ema else pkg['model']

        # EMA state dicts nest the weights under an 'ema_model.' prefix; strip it
        if has_ema:
            # keep only keys that belong to the EMA copy of the model
            model_pkg = filter_by_keys(lambda k: k.startswith('ema_model.'), model_pkg)
            # drop the 'ema_model.' prefix so keys match this module's state dict
            model_pkg = map_keys(lambda k: k[len('ema_model.'):], model_pkg)

        # load the (possibly remapped) weights
        self.load_state_dict(model_pkg, strict = strict)

    # load weights from a checkpoint saved by the trainer (expects a 'model' key)
    def load_from_trainer_saved_obj(self, path):
        obj_path = Path(path)
        assert obj_path.exists()
        saved = torch.load(str(obj_path))
        self.load_state_dict(saved['model'])

    # all generator-side parameters (everything except the discriminators)
    def non_discr_parameters(self):
        params = []
        params.extend(self.encoder.parameters())
        params.extend(self.decoder.parameters())
        if exists(self.encoder_attn):
            params.extend(self.encoder_attn.parameters())
        if exists(self.decoder_attn):
            params.extend(self.decoder_attn.parameters())
        params.extend(self.encoder_film.parameters())
        params.extend(self.decoder_film.parameters())
        params.extend(self.rq.parameters())
        return params

    # total temporal downsampling: product of all encoder strides
    @property
    def seq_len_multiple_of(self):
        total = 1
        for stride in self.strides:
            total *= stride
        return total

    # alias: number of raw audio samples represented by one latent frame
    @property
    def downsample_factor(self):
        factor = self.seq_len_multiple_of
        return factor

    # Resample, trim and reshape raw audio into the 'b c n' layout the encoder expects.
    # Returns the processed audio plus the einops pack info needed to later
    # restore the original (possibly un-batched) shape.
    def process_input(
        self,
        x,
        input_sample_hz = None,
        curtail_from_left = False
    ):
        # ensure a leading batch dimension, remembering the original packing
        x, ps = pack([x], '* n')

        # bring the audio to the model's native sample rate when the source rate is known
        if exists(input_sample_hz):
            x = resample(x, input_sample_hz, self.target_sample_hz)

        # trim so the length divides evenly by the total downsampling factor
        x = curtail_to_multiple(x, self.seq_len_multiple_of, from_left = curtail_from_left)

        # add the singleton channel dimension when missing
        if x.ndim == 2:
            x = rearrange(x, 'b n -> b 1 n')

        return x, ps

    # encode audio into codebook indices (inference only; switches to eval mode)
    @torch.no_grad()
    def tokenize(self, audio):
        self.eval()
        codes = self.forward(audio, return_codes_only = True)
        return codes

    # forward pass (signature only in this excerpt — the closing paren and body are not shown)
    def forward(
        self,
        x,
        target = None,
        is_denoising = None, # to teach SoundStream denoising via the FiLM conditioner - the target must be passed in above
        return_encoded = False,
        return_codes_only = False,
        return_discr_loss = False,
        return_discr_losses_separately = False,
        return_loss_breakdown = False,
        return_recons_only = False,
        input_sample_hz = None,
        apply_grad_penalty = False,
        curtail_from_left = False
# factory: SoundStream preset for AudioLM (16 kHz audio, 12 residual quantizers)
def AudioLMSoundStream(
    strides = (2, 4, 5, 8),
    target_sample_hz = 16000,
    rq_num_quantizers = 12,
    **kwargs
):
    preset = dict(
        strides = strides,
        target_sample_hz = target_sample_hz,
        rq_num_quantizers = rq_num_quantizers,
    )
    return SoundStream(**preset, **kwargs)

# factory: SoundStream preset for MusicLM (24 kHz audio, 12 residual quantizers)
def MusicLMSoundStream(
    strides = (3, 4, 5, 8),
    target_sample_hz = 24000,
    rq_num_quantizers = 12,
    **kwargs
):
    preset = dict(
        strides = strides,
        target_sample_hz = target_sample_hz,
        rq_num_quantizers = rq_num_quantizers,
    )
    return SoundStream(**preset, **kwargs)

.\lucidrains\audiolm-pytorch\audiolm_pytorch\t5.py

# 导入 torch 库
import torch
# 导入 transformers 库
import transformers
# 从 transformers 库中导入 T5Tokenizer, T5EncoderModel, T5Config
from transformers import T5Tokenizer, T5EncoderModel, T5Config
# 从 beartype 库中导入 beartype, Union, List
from beartype import beartype
from beartype.typing import Union, List

# silence transformers' warning chatter (e.g. about unused weights) during model loads
transformers.logging.set_verbosity_error()

# helper: True when a value is not None
def exists(val):
    return not (val is None)

# configuration constants

# maximum token length used when tokenizing text for T5
MAX_LENGTH = 256
# default pretrained T5 checkpoint
DEFAULT_T5_NAME = 'google/t5-v1_1-base'
# lazily-populated cache: checkpoint name -> dict of model / tokenizer / config singletons
T5_CONFIGS = {}

# global singletons live in T5_CONFIGS above

# fetch a pretrained T5 tokenizer by checkpoint name
def get_tokenizer(name):
    return T5Tokenizer.from_pretrained(name)

# fetch a pretrained T5 encoder model by checkpoint name
def get_model(name):
    return T5EncoderModel.from_pretrained(name)

# fetch (and memoize) the model and tokenizer singletons for a checkpoint name
def get_model_and_tokenizer(name):
    global T5_CONFIGS

    # mutating `entry` mutates the cached dict in place
    entry = T5_CONFIGS.setdefault(name, dict())

    if "model" not in entry:
        entry["model"] = get_model(name)

    if "tokenizer" not in entry:
        entry["tokenizer"] = get_tokenizer(name)

    return entry['model'], entry['tokenizer']

# Look up the hidden dimension (d_model) of the named T5 without loading full
# model weights when a lightweight config suffices.
def get_encoded_dim(name):
    if name not in T5_CONFIGS:
        # avoid loading the model just to read the hidden dimension
        config = T5Config.from_pretrained(name)
        T5_CONFIGS[name] = dict(config = config)

    elif "config" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["config"]

    elif "model" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["model"].config

    else:
        # bug fix: a cache entry can exist without a config or model (e.g. created
        # but only partially populated); the original raised a misleading
        # "unknown t5 name" error here even though the name is valid — fetch and
        # cache the config instead
        config = T5Config.from_pretrained(name)
        T5_CONFIGS[name]["config"] = config

    return config.d_model

# Encode text with a pretrained T5 encoder; padded positions are zeroed out.
@beartype
def t5_encode_text(
    texts: Union[str, List[str]],
    name = DEFAULT_T5_NAME,
    output_device = None
):
    """Encode `texts` with T5 and return the final hidden states.

    Embeddings at padded token positions are zeroed. When `output_device`
    is given, the result is moved to that device before being returned.
    """
    # accept a single string for convenience
    if isinstance(texts, str):
        texts = [texts]

    t5, tokenizer = get_model_and_tokenizer(name)

    # run the encoder on GPU when available
    if torch.cuda.is_available():
        t5 = t5.cuda()

    device = next(t5.parameters()).device

    # pad to the longest sequence in the batch, truncating at MAX_LENGTH
    encoded = tokenizer.batch_encode_plus(
        texts,
        return_tensors = 'pt',
        padding = 'longest',
        max_length = MAX_LENGTH,
        truncation = True
    )

    input_ids = encoded.input_ids.to(device)
    attn_mask = encoded.attention_mask.to(device)

    t5.eval()

    with torch.inference_mode():
        output = t5(input_ids = input_ids, attention_mask = attn_mask)
        encoded_text = output.last_hidden_state.detach()

    # broadcastable boolean mask of shape (batch, seq, 1)
    attn_mask = attn_mask[..., None].bool()

    if exists(output_device):
        # bug fix: Tensor.to is NOT in-place — the original called
        # `encoded_text.to(output_device)` and discarded the result, so the
        # output never actually moved to `output_device`
        encoded_text = encoded_text.to(output_device)
        attn_mask = attn_mask.to(output_device)

    # zero out embeddings at padded positions
    encoded_text = encoded_text.masked_fill(~attn_mask, 0.)
    return encoded_text

.\lucidrains\audiolm-pytorch\audiolm_pytorch\trainer.py

# 导入所需的库
import re
import copy
from math import sqrt
from datetime import timedelta
from random import choice
from pathlib import Path
from shutil import rmtree
from functools import partial
from collections import Counter
from contextlib import contextmanager, nullcontext

# 导入类型提示相关的库
from beartype.typing import Union, List, Optional, Tuple, Type
from typing_extensions import Annotated

# 导入 beartype 相关的库
from beartype import beartype
from beartype.door import is_bearable
from beartype.vale import Is

# 导入 PyTorch 相关的库
import torch
import torchaudio
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, _LRScheduler
from torch.utils.data import Dataset, DataLoader, random_split

# 导入 pytorch_warmup 库
import pytorch_warmup as warmup

# 导入 einops 库
from einops import rearrange

# 导入 audiolm_pytorch 相关的库
from audiolm_pytorch.optimizer import get_optimizer
import wandb
from ema_pytorch import EMA
from audiolm_pytorch.soundstream import SoundStream
from audiolm_pytorch.encodec import EncodecWrapper
from audiolm_pytorch.audiolm_pytorch import (
    SemanticTransformer,
    SemanticTransformerWrapper,
    CoarseTransformer,
    CoarseTransformerWrapper,
    FineTransformer,
    FineTransformerWrapper,
    FairseqVQWav2Vec,
    HubertWithKmeans
)

# 导入 audiolm_pytorch 中的数据处理相关的库
from audiolm_pytorch.data import SoundDataset, get_dataloader
from audiolm_pytorch.utils import AudioConditionerBase

# 导入 audiolm_pytorch 版本相关的库
from audiolm_pytorch.version import __version__
from packaging import version

# 导入 accelerate 相关的库
from accelerate import Accelerator, DistributedType
from accelerate.utils import DistributedDataParallelKwargs, InitProcessGroupKwargs
from accelerate.tracking import WandBTracker

# constants

# fallback sample rate assumed for saved result audio
DEFAULT_SAMPLE_RATE = 16000

# scheduler that keeps the learning rate constant (multiplier of 1 at every step)
ConstantLRScheduler = partial(LambdaLR, lr_lambda = lambda step: 1.)

# make sure only one trainer is ever instantiated per process

ONE_TRAINER_INSTANTIATED = False

def check_one_trainer():
    # guard against constructing a second trainer in the same process
    global ONE_TRAINER_INSTANTIATED
    assert not ONE_TRAINER_INSTANTIATED, 'only one Trainer can be instantiated at a time for training'
    ONE_TRAINER_INSTANTIATED = True

# find_unused_parameters: not every submodule participates in every training step
DEFAULT_DDP_KWARGS = DistributedDataParallelKwargs(find_unused_parameters = True)

# keywords used to automatically route dataset outputs into the transformer wrappers,
# matched by runtime type (see determine_types below)

DATASET_FIELD_TYPE_CONFIG = dict(
    raw_wave = Annotated[
        torch.Tensor,
        Is[lambda t: t.dtype == torch.float and t.ndim in {2, 3}]
    ],
    text = List[str],
    text_embeds = Annotated[
        torch.Tensor,
        Is[lambda t: t.dtype == torch.float and t.ndim == 3]
    ],
)

# helper functions

# True when the value is not None
def exists(val):
    return not (val is None)

# fall back to `d` only when `val` is None (inlines the `exists` check)
def default(val, d):
    return val if val is not None else d

# intentionally does nothing; used as a default callback
def noop(*args, **kwargs):
    return None

# first element of `arr` satisfying `cond`, or None when there is none
def find_first(cond, arr):
    return next((el for el in arr if cond(el)), None)

# endlessly re-iterate a dataloader (restarts it each pass, unlike itertools.cycle
# which would cache the first epoch's batches)
def cycle(dl):
    while True:
        yield from dl

# wrap a scalar in a 1-tuple; tuples and lists pass through unchanged
def cast_tuple(t):
    if isinstance(t, (tuple, list)):
        return t
    return (t,)

# interactive y/n prompt; True on 'y' or 'yes' (case-insensitive)
def yes_or_no(question):
    reply = input(f'{question} (y/n) ')
    return reply.lower() in ('yes', 'y')

# add each new value onto the running total in `log` (mutates and returns `log`)
def accum_log(log, new_logs):
    for key, new_value in new_logs.items():
        log[key] = log.get(key, 0.) + new_value
    return log

# move every tensor value in the dict to `device`; non-tensor values pass through
def dict_values_to_device(d: dict, device):
    return {k: v.to(device) if torch.is_tensor(v) else v for k, v in d.items()}

# functions for automatically routing data into module keyword arguments

# True when any element of the tuple appears more than once
def has_duplicates(tup):
    counts = Counter(tup)
    return any(count > 1 for count in counts.values())

# map each element of `data` to the first field name in `config` whose type it matches
def determine_types(data, config):
    names = []
    for el in data:
        matched = None
        for name, data_type in config.items():
            if is_bearable(el, data_type):
                matched = name
                break
        # config keys are field-name strings, so None reliably means "no match"
        if matched is None:
            raise TypeError(f'unable to determine type of {data}')
        names.append(matched)

    return tuple(names)

def checkpoint_num_steps(checkpoint_path):
    """Returns the number of steps trained from a checkpoint based on the filename.

    Assumes filenames like "/path/to/semantic.transformer.20000.pt", meaning
    20k training steps — returns 20000 in that case, and 0 when the filename
    contains no digits.
    """
    matches = re.findall(r'\d+', str(checkpoint_path))

    if not matches:
        return 0

    # the trailing number in the filename is the step count
    return int(matches[-1])
# optimizer bundled with linear warmup and an (optional) LR scheduler
class OptimizerWithWarmupSchedule(nn.Module):
    # Wraps an optimizer with pytorch_warmup's linear warmup plus an optional
    # scheduler, then prepares both through the given Accelerator.
    @beartype
    def __init__(
        self,
        accelerator: Accelerator,
        optimizer: Optimizer,
        scheduler: Optional[Type[_LRScheduler]] = None,
        scheduler_kwargs: dict = dict(),
        warmup_steps: int = 0
    ):
        super().__init__()
        # linear warmup over `warmup_steps` optimizer steps
        self.warmup = warmup.LinearWarmup(optimizer, warmup_period = warmup_steps)

        # use the supplied scheduler class, else hold the learning rate constant
        if exists(scheduler):
            self.scheduler = scheduler(optimizer, **scheduler_kwargs)
        else:
            self.scheduler = ConstantLRScheduler(optimizer)

        self.optimizer = optimizer

        # NOTE: scheduler is deliberately constructed before accelerator.prepare —
        # keep this ordering
        self.optimizer, self.scheduler = accelerator.prepare(self.optimizer, self.scheduler)
        self.accelerator = accelerator

    # checkpointable state: optimizer, scheduler and warmup
    def state_dict(self):
        return dict(
            optimizer = self.optimizer.state_dict(),
            scheduler = self.scheduler.state_dict(),
            warmup = self.warmup.state_dict()
        )

    # restore optimizer, scheduler and warmup state
    def load_state_dict(self, pkg):
        self.optimizer.load_state_dict(pkg['optimizer'])
        self.scheduler.load_state_dict(pkg['scheduler'])
        self.warmup.load_state_dict(pkg['warmup'])

    # clear accumulated gradients
    def zero_grad(self):
        self.optimizer.zero_grad()

    # optimizer step, then a scheduler step — unless the accelerator skipped the
    # optimizer step (e.g. grad-scaler overflow in mixed precision)
    def step(self):
        self.optimizer.step()

        if not self.accelerator.optimizer_step_was_skipped:
            # dampening applies the current warmup factor to the scheduled LR
            with self.warmup.dampening():
                self.scheduler.step()

# 主训练器类
class SoundStreamTrainer(nn.Module):
    # constructor (signature only in this excerpt — the closing paren and body are not shown)
    @beartype
    def __init__(
        self,
        soundstream: SoundStream,
        *,
        num_train_steps: int,                                  # total optimizer steps to train for
        batch_size: int,
        data_max_length: int = None,                           # crop length in samples
        data_max_length_seconds: Union[int, float] = None,     # crop length in seconds (alternative to the above)
        folder: str = None,                                    # audio folder, used when no dataset is given
        dataset: Optional[Dataset] = None,
        val_dataset: Optional[Dataset] = None,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloader: Optional[DataLoader] = None,
        lr: float = 2e-4,
        grad_accum_every: int = 4,
        wd: float = 0.,
        warmup_steps: int = 1000,
        scheduler: Optional[Type[_LRScheduler]] = None,
        scheduler_kwargs: dict = dict(),
        discr_warmup_steps: Optional[int] = None,              # discriminator-specific schedule settings
        discr_scheduler: Optional[Type[_LRScheduler]] = None,
        discr_scheduler_kwargs: dict = dict(),
        max_grad_norm: float = 0.5,
        discr_max_grad_norm: float = None,
        save_results_every: int = 100,
        save_model_every: int = 1000,
        log_losses_every: int = 1,
        results_folder: str = './results',
        valid_frac: float = 0.05,
        random_split_seed: int = 42,
        use_ema: bool = True,                                  # exponential moving average of the model
        ema_beta: float = 0.995,
        ema_update_after_step: int = 500,
        ema_update_every: int = 10,
        apply_grad_penalty_every: int = 4,
        dl_num_workers: int = 0,
        accelerator: Optional[Accelerator] = None,
        accelerate_kwargs: dict = dict(),
        init_process_group_timeout_seconds = 1800,
        dataloader_drop_last = True,
        split_batches = False,
        use_wandb_tracking = False,
        force_clear_prev_results: bool = None  # set to True | False to skip the prompt
    @property
    def ema_tokenizer(self):
        # the EMA copy of soundstream, used for tokenizing audio
        ema = self.ema_soundstream
        return ema.ema_model

    # tokenize audio using the EMA soundstream
    def tokenize(self, audio):
        # bug fix: `ema_tokenizer` is a property on self — the original used the
        # bare name, which raised NameError at runtime
        return self.ema_tokenizer.tokenize(audio)

    # sync EMA weights with the online model
    def set_model_as_ema_model_(self):
        """Copy the online soundstream's parameters into the EMA model (in-place)."""
        assert self.use_ema
        online_state = self.soundstream.state_dict()
        self.ema_soundstream.ema_model.load_state_dict(online_state)
    # save model parameters, optimizer state and config to the given path
    def save(self, path):
        # assemble the checkpoint: model weights, optimizer states, config, version
        pkg = dict(
            model = self.accelerator.get_state_dict(self.soundstream),
            optim = self.optim.state_dict(),
            config = self.unwrapped_soundstream._configs,
            discr_optim = self.discr_optim.state_dict(),
            version = __version__
        )

        # include the exponentially-moving-averaged model when enabled
        if self.use_ema:
            pkg['ema_model'] = self.ema_soundstream.state_dict()

        # include each multi-scale discriminator optimizer's state
        for key, _ in self.multiscale_discriminator_iter():
            discr_optim = getattr(self, key)
            pkg[key] = discr_optim.state_dict()

        # write the whole checkpoint dict to disk
        torch.save(pkg, path)

    # the raw SoundStream module, with accelerator wrappers (e.g. DDP) stripped
    @property
    def unwrapped_soundstream(self):
        wrapped = self.soundstream
        return self.accelerator.unwrap_model(wrapped)

    # load model parameters and trainer state from a checkpoint
    def load(self, path):
        path = Path(path)
        assert path.exists()
        # load the checkpoint dict onto CPU
        pkg = torch.load(str(path), map_location = 'cpu')

        # legacy path: a raw state dict (many keys) saved by very old versions

        if len(pkg.keys()) > 20:
            self.unwrapped_soundstream.load_state_dict(pkg)

            if self.use_ema:
                self.ema_soundstream.ema_model.load_state_dict(pkg)
            return

        # version check

        if 'version' in pkg and version.parse(pkg['version']) < version.parse(__version__):
            print(f'model was trained on older version {pkg["version"]} of audiolm-pytorch')

        # otherwise load the structured checkpoint normally

        self.unwrapped_soundstream.load_state_dict(pkg['model'])

        if self.use_ema:
            assert 'ema_model' in pkg
            self.ema_soundstream.load_state_dict(pkg['ema_model'])

        self.optim.load_state_dict(pkg['optim'])
        self.discr_optim.load_state_dict(pkg['discr_optim'])

        # restore each multi-scale discriminator optimizer
        for key, _ in self.multiscale_discriminator_iter():
            discr_optim = getattr(self, key)
            discr_optim.load_state_dict(pkg[key])

        # + 1 to resume from the next step and avoid overwriting the last checkpoint

        self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)

    # yields (optimizer attribute name, discriminator) pairs, one per scale
    def multiscale_discriminator_iter(self):
        discrs = self.unwrapped_soundstream.discriminators
        for ind, discr in enumerate(discrs):
            yield f'multiscale_discr_optimizer_{ind}', discr

    # yields (name, optimizer) for every multi-scale discriminator optimizer
    def multiscale_discriminator_optim_iter(self):
        yield from ((name, getattr(self, name)) for name, _ in self.multiscale_discriminator_iter())

    # print only once across processes, via the accelerator
    def print(self, msg):
        accel = self.accelerator
        accel.print(msg)

    # log metrics at the current step through the accelerator's trackers
    def log(self, **logs_as_kwargs):
        step = self.steps.item()
        self.accelerator.log(logs_as_kwargs, step = step)

    # context manager enabling wandb tracking for the duration of training
    @contextmanager
    def wandb_tracker(self, project, run = None, hps = None):
        """Initialize wandb tracking for `project` (optionally naming the run and
        recording hyperparameters), and end training on exit."""
        assert self.use_wandb_tracking, '`use_wandb_tracking` must be set to True on SoundStreamTrainer'

        hps = default(hps, self.tracker_hps)

        # bug fix: `hps` was computed but never used — config was hard-coded to None,
        # so hyperparameters were never recorded
        self.accelerator.init_trackers(project, config = hps)

        if exists(run):
            wandb_tracker = find_first(lambda el: isinstance(el, WandBTracker), self.accelerator.trackers)
            assert exists(wandb_tracker)

            wandb_tracker.run.name = run

        yield

        self.accelerator.end_training()

    # device the accelerator placed the trainer on
    @property
    def device(self):
        accel = self.accelerator
        return accel.device

    # True when training across multiple processes or with a distributed backend
    @property
    def is_distributed(self):
        accel = self.accelerator
        return (accel.distributed_type != DistributedType.NO) or (accel.num_processes != 1)

    # True only on the global main process
    @property
    def is_main(self):
        accel = self.accelerator
        return accel.is_main_process

    # True only on this machine's main process
    @property
    def is_local_main(self):
        accel = self.accelerator
        return accel.is_local_main_process

    # run training steps until num_train_steps is reached, reporting logs to log_fn
    def train(self, log_fn = noop):

        while self.steps < self.num_train_steps:
            step_logs = self.train_step()
            log_fn(step_logs)

        self.print('training complete')
# 语义转换器训练器

class SemanticTransformerTrainer(nn.Module):
    # constructor (signature only in this excerpt — the closing paren and body are not shown)
    @beartype
    def __init__(
        self,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],   # semantic tokenizer
        transformer: SemanticTransformer,
        *,
        num_train_steps,
        batch_size,
        audio_conditioner: Optional[AudioConditionerBase] = None,
        dataset: Optional[Dataset] = None,
        valid_dataset: Optional[Dataset] = None,
        data_max_length = None,
        data_max_length_seconds = None,
        folder = None,
        lr = 3e-4,
        grad_accum_every = 1,
        wd = 0.,
        max_grad_norm = 0.5,
        valid_frac = 0.05,
        random_split_seed = 42,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        accelerate_kwargs: dict = dict(),
        init_process_group_timeout_seconds = 1800,
        use_wandb_tracking = False,
        split_batches = False,
        drop_last = False,
        force_clear_prev_results = None,
        average_valid_loss_over_grad_accum_every: bool = True, # if False, valid loss on a single batch
    # checkpoint the transformer weights, optimizer state and package version
    def save(self, path):
        state = dict(
            model = self.accelerator.get_state_dict(self.transformer),
            optim = self.optim.state_dict(),
            version = __version__
        )
        torch.save(state, path)

    # restore transformer and optimizer state from a checkpoint
    def load(self, path):
        unwrapped = self.accelerator.unwrap_model(self.transformer)
        pkg = unwrapped.load(path)

        # trainer-specific state
        self.optim.load_state_dict(pkg['optim'])

        # resume from the step after the checkpoint to avoid overwriting it
        self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)


    # print only once across processes, via the accelerator
    def print(self, msg):
        accel = self.accelerator
        accel.print(msg)

    # delegate sampling / generation to the training wrapper
    def generate(self, *args, **kwargs):
        wrapper = self.train_wrapper
        return wrapper.generate(*args, **kwargs)

    @property
    def device(self):
        # device the accelerator placed the trainer on
        accel = self.accelerator
        return accel.device

    @property
    def is_distributed(self):
        # True when training across multiple processes or with a distributed backend
        accel = self.accelerator
        return (accel.distributed_type != DistributedType.NO) or (accel.num_processes != 1)

    @property
    def is_main(self):
        # True only on the global main process
        accel = self.accelerator
        return accel.is_main_process

    @property
    def is_local_main(self):
        # True only on this machine's main process
        accel = self.accelerator
        return accel.is_local_main_process

    # zip a dataset tuple into keyword arguments, inferring the field names once
    def data_tuple_to_kwargs(self, data):
        if not exists(self.ds_fields):
            # first call: infer the field names from the element types and cache them
            self.ds_fields = determine_types(data, DATASET_FIELD_TYPE_CONFIG)
            assert not has_duplicates(self.ds_fields), 'dataset fields must not have duplicate field names'

        return {field: value for field, value in zip(self.ds_fields, data)}

    @contextmanager
    def wandb_tracker(self, project, run = None, hps = None):
        """Context manager enabling wandb tracking for the duration of training."""
        assert self.use_wandb_tracking, '`use_wandb_tracking` must be set to True on SemanticTransformerTrainer'

        hps = default(hps, self.tracker_hps)

        # bug fix: `hps` was computed but never used — config was hard-coded to None,
        # so hyperparameters were never recorded
        self.accelerator.init_trackers(project, config = hps)

        if exists(run):
            wandb_tracker = find_first(lambda el: isinstance(el, WandBTracker), self.accelerator.trackers)
            assert exists(wandb_tracker)

            wandb_tracker.run.name = run

        yield

        self.accelerator.end_training()
    # one full training step: gradient accumulation, clipping, optimizer update,
    # periodic validation and checkpointing
    def train_step(self):
        device = self.device

        # current global step
        steps = int(self.steps.item())

        # put the transformer in train mode
        self.transformer.train()

        # accumulated logs for this step
        logs = {}

        # update the transformer over `grad_accum_every` micro-batches
        for i in range(self.grad_accum_every):
            is_last = i == (self.grad_accum_every - 1)
            # skip DDP gradient sync on all but the last micro-batch
            context = partial(self.accelerator.no_sync, self.train_wrapper) if not is_last else nullcontext

            # route the next batch's tensors to keyword arguments by type
            data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter))

            # mixed-precision forward + scaled backward
            with self.accelerator.autocast(), context():
                loss = self.train_wrapper(**data_kwargs, return_loss = True)

                self.accelerator.backward(loss / self.grad_accum_every)

            # accumulate the (already averaged) loss into the logs
            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        # clip gradients when a max norm is configured
        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.transformer.parameters(), self.max_grad_norm)

        # optimizer update
        self.optim.step()
        self.optim.zero_grad()

        # report training loss
        self.print(f"{steps}: loss: {logs['loss']}")
        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # periodically sample / validate results

        self.accelerator.wait_for_everyone()

        # validation runs on the main process only, every `save_results_every` steps
        if self.is_main and not (steps % self.save_results_every):
            valid_loss = 0
            unwrapped_model = self.accelerator.unwrap_model(self.train_wrapper)

            # average validation loss over several batches
            for _ in range(self.average_valid_loss_over_grad_accum_every):
                data_kwargs = self.data_tuple_to_kwargs(next(self.valid_dl_iter))
                data_kwargs = dict_values_to_device(data_kwargs, unwrapped_model.device)

                with torch.inference_mode():
                    unwrapped_model.eval()
                    valid_loss += unwrapped_model(**data_kwargs, return_loss = True)

            valid_loss = valid_loss.clone() # avoid inference-mode to non-inference-mode errors
            valid_loss /= self.average_valid_loss_over_grad_accum_every

            # report validation loss
            self.print(f'{steps}: valid loss {valid_loss}')
            self.accelerator.log({"valid_loss": valid_loss}, step=steps)

        # periodically checkpoint the model
        if self.is_main and not (steps % self.save_model_every):
            model_path = str(self.results_folder / f'semantic.transformer.{steps}.pt')
            self.save(model_path)
            if self.use_wandb_tracking:
                wandb.save(model_path)
            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        self.accelerator.wait_for_everyone()

        # advance the step counter
        self.steps.add_(1)
        return logs

    # training loop: step until num_train_steps, reporting logs to log_fn
    def train(self, log_fn = noop):

        while self.steps < self.num_train_steps:
            step_logs = self.train_step()
            log_fn(step_logs)

        self.print('training complete')
# trainer class for the coarse transformer
class CoarseTransformerTrainer(nn.Module):
    # constructor — NOTE(review): only the signature is present in this excerpt;
    # the `__init__` body (and closing paren) appears to have been truncated upstream
    @beartype
    def __init__(
        self,
        transformer: CoarseTransformer,  # coarse transformer to be trained
        codec: Union[SoundStream, EncodecWrapper],  # codec producing acoustic tokens
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],  # optional semantic token extractor
        *,
        num_train_steps,  # total number of training steps
        batch_size,  # batch size
        audio_conditioner: Optional[AudioConditionerBase] = None,  # optional audio conditioner
        dataset: Optional[Dataset] = None,  # optional training dataset
        valid_dataset: Optional[Dataset] = None,  # optional validation dataset
        ds_fields: Tuple[str, ...] = ('raw_wave', 'raw_wave_for_codec', 'text'),  # names of the dataset tuple fields
        data_max_length = None,  # max data length (samples)
        data_max_length_seconds = None,  # max data length (seconds)
        folder = None,  # folder of sound files
        lr = 3e-4,  # learning rate
        grad_accum_every = 1,  # gradient accumulation frequency
        wd = 0.,  # weight decay
        max_grad_norm = 0.5,  # max gradient norm for clipping
        valid_frac = 0.05,  # fraction of data held out for validation
        random_split_seed = 42,  # seed for the train/valid split
        save_results_every = 100,  # save results every this many steps
        save_model_every = 1000,  # save a model checkpoint every this many steps
        results_folder = './results',  # folder for results
        accelerate_kwargs: dict = dict(),  # kwargs forwarded to Accelerator
        init_process_group_timeout_seconds = 1800,  # process-group init timeout (seconds)
        split_batches = False,  # whether accelerate splits batches across processes
        drop_last = False,  # whether the dataloader drops the last partial batch
        force_clear_prev_results = None,  # force-clear previous results folder
        use_wandb_tracking = False,  # whether to track with wandb
        average_valid_loss_over_grad_accum_every: bool = True,  # average validation loss over this many batches
    # checkpointing
    def save(self, path):
        """Serialize model weights, optimizer state and package version to `path`."""
        checkpoint = {
            'model': self.accelerator.get_state_dict(self.transformer),
            'optim': self.optim.state_dict(),
            'version': __version__
        }
        torch.save(checkpoint, path)

    # checkpoint restoration
    def load(self, path):
        """Restore transformer and optimizer state from the checkpoint at `path`."""
        unwrapped = self.accelerator.unwrap_model(self.transformer)
        pkg = unwrapped.load(path)

        # trainer-specific state
        self.optim.load_state_dict(pkg['optim'])

        # resume from the following step so the loaded checkpoint is not overwritten
        self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)

    def print(self, msg):
        """Print `msg` once across all processes via the accelerator."""
        self.accelerator.print(msg)

    def generate(self, *args, **kwargs):
        """Delegate sampling to the underlying train wrapper."""
        return self.train_wrapper.generate(*args, **kwargs)

    # context manager scoping a wandb tracking session around training
    @contextmanager
    def wandb_tracker(self, project, run = None, hps = None):
        """Initialize wandb tracking for `project`, optionally renaming the run,
        and end the tracking session on exit.

        hps falls back to `self.tracker_hps` when not given.
        """
        assert self.use_wandb_tracking, '`use_wandb_tracking` must be set to True on CoarseTransformerTrainer'

        hps = default(hps, self.tracker_hps)

        # fix: the resolved `hps` was previously computed but discarded (config = None),
        # so hyperparameters never reached the wandb run config
        self.accelerator.init_trackers(project, config = hps)

        if exists(run):
            wandb_tracker = find_first(lambda el: isinstance(el, WandBTracker), self.accelerator.trackers)
            assert exists(wandb_tracker)

            wandb_tracker.run.name = run

        try:
            yield
        finally:
            # always close out the tracker, even if training raises
            self.accelerator.end_training()

    @property
    def device(self):
        """Device the accelerator has placed training on."""
        return self.accelerator.device

    @property
    def is_distributed(self):
        """True when running under a distributed setup or with more than one process."""
        accel = self.accelerator
        return accel.distributed_type != DistributedType.NO or accel.num_processes != 1

    @property
    def is_main(self):
        """True only on the global main process."""
        return self.accelerator.is_main_process

    @property
    def is_local_main(self):
        """True only on the main process of the local node."""
        return self.accelerator.is_local_main_process
    # one full training step
    def train_step(self):
        """Run one optimizer step (with gradient accumulation), plus periodic
        validation-loss logging and model checkpointing on the main process.

        Returns the dict of accumulated logs for this step.
        """
        steps = int(self.steps.item())

        self.transformer.train()

        # accumulated logs for this step
        logs = {}

        # update transformer, accumulating gradients over `grad_accum_every` micro-batches
        for i in range(self.grad_accum_every):
            is_last = i == (self.grad_accum_every - 1)
            # skip gradient sync on all but the last accumulation micro-batch
            context = partial(self.accelerator.no_sync, self.train_wrapper) if not is_last else nullcontext

            # assemble keyword arguments from the next training batch
            data_kwargs = dict(zip(self.ds_fields, next(self.dl_iter)))

            with self.accelerator.autocast(), context():
                loss = self.train_wrapper(
                    **data_kwargs,
                    return_loss = True
                )

                self.accelerator.backward(loss / self.grad_accum_every)

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        # clip gradients if a max norm is configured
        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.transformer.parameters(), self.max_grad_norm)

        # step the optimizer, then reset accumulated gradients
        self.optim.step()
        self.optim.zero_grad()

        # log the training loss
        self.print(f"{steps}: loss: {logs['loss']}")
        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # sync all processes before any main-process-only work
        self.accelerator.wait_for_everyone()

        # on the main process, periodically compute a validation loss
        if self.is_main and not (steps % self.save_results_every):
            valid_loss = 0
            unwrapped_model = self.accelerator.unwrap_model(self.train_wrapper)

            # average the validation loss over several batches
            for _ in range(self.average_valid_loss_over_grad_accum_every):
                data_kwargs = dict(zip(self.ds_fields, next(self.valid_dl_iter)))
                data_kwargs = dict_values_to_device(data_kwargs, unwrapped_model.device)

                # fix: use inference_mode for consistency with the semantic / fine
                # trainers (the `.clone()` guard below exists for exactly this mode)
                with torch.inference_mode():
                    unwrapped_model.eval()

                    valid_loss += unwrapped_model(
                        **data_kwargs,
                        return_loss = True
                    )

            valid_loss = valid_loss.clone() # avoid inference-mode to non-inference-mode tensor error
            valid_loss /= self.average_valid_loss_over_grad_accum_every

            # log the validation loss
            self.print(f'{steps}: valid loss {valid_loss}')
            self.accelerator.log({"valid_loss": valid_loss}, step=steps)

        # periodically checkpoint the model (main process only)
        if self.is_main and not (steps % self.save_model_every):
            model_path = str(self.results_folder / f'coarse.transformer.{steps}.pt')
            self.save(model_path)
            if self.use_wandb_tracking:
                wandb.save(model_path)
            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        self.accelerator.wait_for_everyone()

        # advance the step counter and return this step's logs
        self.steps.add_(1)
        return logs

    # training loop entrypoint
    def train(self, log_fn = noop):
        """Repeatedly run `train_step` until `num_train_steps` is reached,
        forwarding each step's log dict to `log_fn`."""

        while self.steps < self.num_train_steps:
            log_fn(self.train_step())

        self.print('training complete')
# trainer class for the fine transformer
class FineTransformerTrainer(nn.Module):
    # constructor — NOTE(review): only the signature is present in this excerpt;
    # the `__init__` body (and closing paren) appears to have been truncated upstream
    @beartype
    def __init__(
        self,
        transformer: FineTransformer,  # fine transformer to be trained
        codec: Union[SoundStream, EncodecWrapper],  # codec producing acoustic tokens
        *,
        num_train_steps,  # total number of training steps
        batch_size,  # batch size
        audio_conditioner: Optional[AudioConditionerBase] = None,  # optional audio conditioner
        dataset: Optional[Dataset] = None,  # optional training dataset
        valid_dataset: Optional[Dataset] = None,  # optional validation dataset
        data_max_length = None,  # max data length (samples)
        data_max_length_seconds = None,  # max data length (seconds)
        dataset_normalize = False,  # whether to normalize the dataset audio
        folder = None,  # folder of sound files
        lr = 3e-4,  # learning rate
        grad_accum_every = 1,  # gradient accumulation frequency
        wd = 0.,  # weight decay
        max_grad_norm = 0.5,  # max gradient norm for clipping
        valid_frac = 0.05,  # fraction of data held out for validation
        random_split_seed = 42,  # seed for the train/valid split
        save_results_every = 100,  # save results every this many steps
        save_model_every = 1000,  # save a model checkpoint every this many steps
        results_folder = './results',  # folder for results
        accelerate_kwargs: dict = dict(),  # kwargs forwarded to Accelerator
        init_process_group_timeout_seconds = 1800,  # process-group init timeout (seconds)
        split_batches = False,  # whether accelerate splits batches across processes
        drop_last = False,  # whether the dataloader drops the last partial batch
        use_wandb_tracking = False,  # whether to track with wandb
        force_clear_prev_results = None,  # force-clear previous results folder
        average_valid_loss_over_grad_accum_every: bool = True,  # average validation loss over this many batches
    # checkpointing
    def save(self, path):
        """Serialize model weights, optimizer state and package version to `path`."""
        checkpoint = {
            'model': self.accelerator.get_state_dict(self.transformer),
            'optim': self.optim.state_dict(),
            'version': __version__
        }
        torch.save(checkpoint, path)

    # checkpoint restoration
    def load(self, path):
        """Restore transformer and optimizer state from the checkpoint at `path`."""
        unwrapped = self.accelerator.unwrap_model(self.transformer)
        pkg = unwrapped.load(path)

        # trainer-specific state
        self.optim.load_state_dict(pkg['optim'])

        # resume from the following step so the loaded checkpoint is not overwritten
        self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)

    def print(self, msg):
        """Print `msg` once across all processes via the accelerator."""
        self.accelerator.print(msg)

    def generate(self, *args, **kwargs):
        """Delegate sampling to the underlying train wrapper."""
        return self.train_wrapper.generate(*args, **kwargs)

    # context manager scoping a wandb tracking session around training
    @contextmanager
    def wandb_tracker(self, project, run = None, hps = None):
        """Initialize wandb tracking for `project`, optionally renaming the run,
        and end the tracking session on exit.

        hps falls back to `self.tracker_hps` when not given.
        """
        assert self.use_wandb_tracking, '`use_wandb_tracking` must be set to True on FineTransformerTrainer'

        hps = default(hps, self.tracker_hps)

        # fix: the resolved `hps` was previously computed but discarded (config = None),
        # so hyperparameters never reached the wandb run config
        self.accelerator.init_trackers(project, config = hps)

        if exists(run):
            wandb_tracker = find_first(lambda el: isinstance(el, WandBTracker), self.accelerator.trackers)
            assert exists(wandb_tracker)

            wandb_tracker.run.name = run

        try:
            yield
        finally:
            # always close out the tracker, even if training raises
            self.accelerator.end_training()

    @property
    def device(self):
        """Device the accelerator has placed training on."""
        return self.accelerator.device

    @property
    def is_distributed(self):
        """True when running under a distributed setup or with more than one process."""
        accel = self.accelerator
        return accel.distributed_type != DistributedType.NO or accel.num_processes != 1

    @property
    def is_main(self):
        """True only on the global main process."""
        return self.accelerator.is_main_process

    @property
    def is_local_main(self):
        """True only on the main process of the local node."""
        return self.accelerator.is_local_main_process

    # batch tuple -> keyword arguments
    def data_tuple_to_kwargs(self, data):
        """Map a dataset tuple onto keyword arguments, lazily inferring field
        names from the first batch's value types."""
        fields = self.ds_fields

        if not exists(fields):
            fields = determine_types(data, DATASET_FIELD_TYPE_CONFIG)
            assert not has_duplicates(fields), 'dataset fields must not have duplicate field names'
            self.ds_fields = fields

        return dict(zip(fields, data))
    # one full training step
    def train_step(self):
        """Run one optimizer step (with gradient accumulation), plus periodic
        validation-loss logging and model checkpointing on the main process.
        Returns the dict of accumulated logs for this step."""
        # NOTE(review): `device` is never used in this method
        device = self.device

        # current step count
        steps = int(self.steps.item())

        # put transformer in training mode
        self.transformer.train()

        # accumulated logs for this step
        logs = {}

        # update transformer, accumulating gradients over `grad_accum_every` micro-batches
        for i in range(self.grad_accum_every):
            # whether this is the final accumulation micro-batch
            is_last = i == (self.grad_accum_every - 1)
            # skip gradient sync on all but the last micro-batch
            context = partial(self.accelerator.no_sync, self.train_wrapper) if not is_last else nullcontext

            # assemble keyword arguments from the next training batch
            data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter))

            # forward under autocast and the chosen sync context
            with self.accelerator.autocast(), context():
                # compute loss
                loss = self.train_wrapper(**data_kwargs, return_loss = True)

                # backprop the scaled loss
                self.accelerator.backward(loss / self.grad_accum_every)

            # accumulate the scaled loss into the logs
            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        # clip gradients if a max norm is configured
        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.transformer.parameters(), self.max_grad_norm)

        # step the optimizer, then reset accumulated gradients
        self.optim.step()
        self.optim.zero_grad()

        # log the training loss
        self.print(f"{steps}: loss: {logs['loss']}")
        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # sync all processes before any main-process-only work
        self.accelerator.wait_for_everyone()

        # on the main process, periodically compute a validation loss
        if self.is_main and not (steps % self.save_results_every):
            # unwrap the (possibly DDP-wrapped) model for direct invocation
            unwrapped_model = self.accelerator.unwrap_model(self.train_wrapper)
            valid_loss = 0

            # average the validation loss over several batches
            for i in range(self.average_valid_loss_over_grad_accum_every):
                data_kwargs = self.data_tuple_to_kwargs(next(self.valid_dl_iter))
                data_kwargs = dict_values_to_device(data_kwargs, unwrapped_model.device)

                with torch.inference_mode():
                    unwrapped_model.eval()
                    valid_loss += unwrapped_model(**data_kwargs, return_loss = True)

            valid_loss = valid_loss.clone() # avoid inference-mode to non-inference-mode tensor error
            valid_loss /= self.average_valid_loss_over_grad_accum_every

            # log the validation loss
            self.print(f'{steps}: valid loss {valid_loss}')
            self.accelerator.log({"valid_loss": valid_loss}, step=steps)

        # periodically checkpoint the model (main process only)
        if self.is_main and not (steps % self.save_model_every):
            model_path = str(self.results_folder / f'fine.transformer.{steps}.pt')
            self.save(model_path)
            if self.use_wandb_tracking:
                wandb.save(model_path)
            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        self.accelerator.wait_for_everyone()

        # advance the step counter and return this step's logs
        self.steps.add_(1)
        return logs

    # training loop entrypoint
    def train(self, log_fn = noop):
        """Repeatedly run `train_step` until `num_train_steps` is reached,
        forwarding each step's log dict to `log_fn`."""

        while self.steps < self.num_train_steps:
            log_fn(self.train_step())

        self.print('training complete')