Lucidrains Series Projects: Source Code Analysis (Part 6)
.\lucidrains\audiolm-pytorch\audiolm_pytorch\data.py
from pathlib import Path
from functools import partial, wraps
from beartype import beartype
from beartype.typing import Tuple, Union, Optional
from beartype.door import is_bearable
import torchaudio
from torchaudio.functional import resample
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from audiolm_pytorch.utils import curtail_to_multiple
from einops import rearrange, reduce
def exists(val):
return val is not None
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
def is_unique(arr):
    return len(set(arr)) == len(arr)
class SoundDataset(Dataset):
@beartype
def __init__(
self,
folder,
target_sample_hz: Union[int, Tuple[int, ...]],
exts = ['flac', 'wav', 'mp3', 'webm'],
max_length: Optional[int] = None,
seq_len_multiple_of: Optional[Union[int, Tuple[Optional[int], ...]]] = None
):
super().__init__()
path = Path(folder)
assert path.exists(), f'folder "{str(path)}" does not exist'
files = [file for ext in exts for file in path.glob(f'**/*.{ext}')]
assert len(files) > 0, 'no sound files found'
self.files = files
self.max_length = max_length
self.target_sample_hz = cast_tuple(target_sample_hz)
num_outputs = len(self.target_sample_hz)
self.max_target_sample_hz = max(self.target_sample_hz)
self.seq_len_multiple_of = cast_tuple(seq_len_multiple_of, num_outputs)
assert len(self.target_sample_hz) == len(self.seq_len_multiple_of)
def __len__(self):
return len(self.files)
def __getitem__(self, idx):
file = self.files[idx]
data, sample_hz = torchaudio.load(file)
        assert data.numel() > 0, f'one of your audio files ({file}) is empty. please remove it from your folder'
if data.shape[0] > 1:
data = reduce(data, 'c ... -> 1 ...', 'mean')
data = resample(data, sample_hz, self.max_target_sample_hz)
sample_hz = self.max_target_sample_hz
max_length = self.max_length
audio_length = data.size(1)
if exists(max_length):
if audio_length > max_length:
max_start = audio_length - max_length
start = torch.randint(0, max_start, (1, ))
data = data[:, start:start + max_length]
else:
data = F.pad(data, (0, max_length - audio_length), 'constant')
data = rearrange(data, '1 ... -> ...')
num_outputs = len(self.target_sample_hz)
data = cast_tuple(data, num_outputs)
data_tuple = tuple(resample(d, sample_hz, target_sample_hz) for d, target_sample_hz in zip(data, self.target_sample_hz))
output = []
for data, seq_len_multiple_of in zip(data_tuple, self.seq_len_multiple_of):
if exists(seq_len_multiple_of):
data = curtail_to_multiple(data, seq_len_multiple_of)
output.append(data.float())
output = tuple(output)
if num_outputs == 1:
return output[0]
return output
def collate_one_or_multiple_tensors(fn):
@wraps(fn)
def inner(data):
is_one_data = not isinstance(data[0], tuple)
if is_one_data:
data = fn(data)
return (data,)
outputs = []
for datum in zip(*data):
if is_bearable(datum, Tuple[str, ...]):
output = list(datum)
else:
output = fn(datum)
outputs.append(output)
return tuple(outputs)
return inner
@collate_one_or_multiple_tensors
def curtail_to_shortest_collate(data):
    min_len = min(datum.shape[0] for datum in data)
data = [datum[:min_len] for datum in data]
return torch.stack(data)
@collate_one_or_multiple_tensors
def pad_to_longest_fn(data):
return pad_sequence(data, batch_first = True)
def get_dataloader(ds, pad_to_longest = True, **kwargs):
collate_fn = pad_to_longest_fn if pad_to_longest else curtail_to_shortest_collate
return DataLoader(ds, collate_fn = collate_fn, **kwargs)
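# Usage sketch (an illustrative addition, not part of data.py): the './audio'
# folder and all hyperparameters below are assumptions. With a tuple of target
# sample rates, every clip is returned once per rate, and the chosen collate
# function pads (or curtails) each batch accordingly.
ds = SoundDataset(
    './audio',
    target_sample_hz = (16000, 24000),     # one resampled copy per rate
    max_length = 24000 * 2,                # random 2s crop at the max rate, pad if shorter
    seq_len_multiple_of = (320, None)      # curtail only the 16 kHz copy
)
dl = get_dataloader(ds, pad_to_longest = True, batch_size = 4, shuffle = True)
wave_16k, wave_24k = next(iter(dl))        # one batched tensor per sample rate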
.\lucidrains\audiolm-pytorch\audiolm_pytorch\encodec.py
from functools import reduce
from einops import rearrange, pack, unpack
import torch
from torch import nn
from torchaudio.functional import resample
from vector_quantize_pytorch import ResidualVQ
from encodec import EncodecModel
from encodec.utils import _linear_overlap_add
def exists(val):
return val is not None
def get_num_quantizers(model: EncodecModel, audio_length = 512):
out = model.encode(torch.randn(1, 1, audio_length))
return out[0][0].shape[1]
class EncodecWrapper(nn.Module):
def __init__(
self,
target_sample_hz = 24000,
strides = (2, 4, 5, 8),
num_quantizers = 8,
bandwidth = 6.0
):
super().__init__()
self.model = EncodecModel.encodec_model_24khz()
self.model.normalize = False
self.model.set_target_bandwidth(bandwidth)
num_quantizers = get_num_quantizers(self.model)
self.target_sample_hz = target_sample_hz
assert self.target_sample_hz == 24000, "haven't done anything with non-24kHz yet"
self.codebook_dim = 128
self.rq_groups = 1
self.num_quantizers = num_quantizers
self.strides = strides
self.rq = ResidualVQ(
dim = 128,
codebook_size = 1024,
num_quantizers = num_quantizers
)
for encodec_rq_layer, rq_layer in zip(self.model.quantizer.vq.layers, self.rq.layers):
encodec_codebook = dict(encodec_rq_layer._codebook.named_buffers()).get('embed')
vq_codebook = dict(rq_layer._codebook.named_buffers()).get('embed')
encodec_codebook = rearrange(encodec_codebook, '... -> 1 ...')
vq_codebook.copy_(encodec_codebook)
@property
def seq_len_multiple_of(self):
return reduce(lambda x, y: x * y, self.strides)
@property
def downsample_factor(self):
return self.seq_len_multiple_of
def forward(
self,
x,
input_sample_hz = None,
return_encoded = False,
**kwargs
):
x, ps = pack([x], '* n')
if exists(input_sample_hz):
x = resample(x, input_sample_hz, self.target_sample_hz)
assert not self.model.training, "Encodec is pretrained and should never be called outside eval mode."
wav = rearrange(x, f'b t -> b {self.model.channels} t')
with torch.inference_mode():
encoded_frames = self.model.encode(wav)
codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)
codes = rearrange(codes, 'b q n -> b n q')
emb = None
if return_encoded:
emb = self.get_emb_from_indices(codes)
emb, = unpack(emb, ps, '* n c')
codes, = unpack(codes, ps, '* n q')
return emb, codes, None
def decode_from_codebook_indices(self, quantized_indices):
frames = self._decode_frame(quantized_indices)
result = _linear_overlap_add(frames, self.model.segment_stride or 1)
return rearrange(result, 'b n -> b 1 n')
def get_emb_from_indices(self, indices):
codes = rearrange(indices, 'b t q -> q b t')
emb = self.model.quantizer.decode(codes)
return rearrange(emb, 'b c n -> b n c')
def decode(self, emb):
emb = rearrange(emb, 'b n c -> b c n')
return self.model.decoder(emb)
def _decode_frame(self, quantized_indices):
codes = rearrange(quantized_indices, 'b t q -> q b t')
emb = self.model.quantizer.decode(codes)
return self.model.decoder(emb)
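# Round-trip sketch (an illustrative addition, not part of encodec.py): the
# random waveform is a stand-in for one second of real 24 kHz audio.
codec = EncodecWrapper()
codec.eval()
wav = torch.randn(1, 24000)
emb, codes, _ = codec(wav, return_encoded = True)   # codes: (b, frames, num_quantizers)
recon = codec.decode(emb)                           # (b, 1, samples)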
.\lucidrains\audiolm-pytorch\audiolm_pytorch\hubert_kmeans.py
from pathlib import Path
import torch
from torch import nn, einsum
from torchaudio.functional import resample
from einops import rearrange, repeat, pack, unpack
from audiolm_pytorch.utils import curtail_to_multiple
def noop(*args, **kwargs):
pass
import warnings
import logging
logging.root.setLevel(logging.ERROR)
warnings.warn = noop
import joblib
import fairseq
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
class HubertWithKmeans(nn.Module):
"""
checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
or you can train your own
"""
def __init__(
self,
checkpoint_path,
kmeans_path,
target_sample_hz = 16000,
seq_len_multiple_of = None,
output_layer = 9
):
super().__init__()
self.target_sample_hz = target_sample_hz
self.seq_len_multiple_of = seq_len_multiple_of
self.output_layer = output_layer
model_path = Path(checkpoint_path)
kmeans_path = Path(kmeans_path)
assert model_path.exists(), f'path {checkpoint_path} does not exist'
assert kmeans_path.exists(), f'path {kmeans_path} does not exist'
checkpoint = torch.load(checkpoint_path)
load_model_input = {checkpoint_path: checkpoint}
model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)
self.model = model[0]
self.model.eval()
kmeans = joblib.load(kmeans_path)
self.kmeans = kmeans
self.register_buffer(
'cluster_centers',
torch.from_numpy(kmeans.cluster_centers_)
)
@property
def groups(self):
return 1
@property
def codebook_size(self):
return self.kmeans.n_clusters
@property
def downsample_factor(self):
return 320
@torch.inference_mode()
def forward(
self,
wav_input,
flatten = True,
input_sample_hz = None
):
batch, device = wav_input.shape[0], wav_input.device
if exists(input_sample_hz):
wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)
if exists(self.seq_len_multiple_of):
wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)
embed = self.model(
wav_input,
features_only = True,
mask = False,
output_layer = self.output_layer
)['x']
batched_cluster_centers = repeat(self.cluster_centers, 'c d -> b c d', b = embed.shape[0])
dists = -torch.cdist(embed, batched_cluster_centers, p = 2)
clusters = dists.argmax(dim = -1)
if flatten:
return clusters
return rearrange(clusters, 'b ... -> b (...)')
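# Usage sketch (an illustrative addition, not part of hubert_kmeans.py): the
# two paths are placeholders for the fairseq HuBERT checkpoint and k-means
# files linked in the class docstring.
wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert_base_ls960.pt',
    kmeans_path = './hubert_base_ls960_L9_km500.bin'
)
wav = torch.randn(1, 16000)       # 1 second at 16 kHz
semantic_ids = wav2vec(wav)       # (batch, frames) cluster ids, one per 320 samples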
.\lucidrains\audiolm-pytorch\audiolm_pytorch\optimizer.py
from torch.optim import AdamW, Adam
def separate_weight_decayable_params(params):
wd_params, no_wd_params = [], []
for param in params:
param_list = no_wd_params if param.ndim < 2 else wd_params
param_list.append(param)
return wd_params, no_wd_params
def get_optimizer(
params,
lr = 1e-4,
wd = 1e-2,
betas = (0.9, 0.99),
eps = 1e-8,
filter_by_requires_grad = False,
group_wd_params = True,
use_lion = False,
**kwargs
):
has_wd = wd > 0
if filter_by_requires_grad:
params = list(filter(lambda t: t.requires_grad, params))
if group_wd_params and has_wd:
wd_params, no_wd_params = separate_weight_decayable_params(params)
params = [
{'params': wd_params},
{'params': no_wd_params, 'weight_decay': 0},
]
if not has_wd:
return Adam(params, lr = lr, betas = betas, eps = eps)
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
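# Grouping sketch (an illustrative addition, not part of optimizer.py): with
# wd > 0, parameters with ndim >= 2 (weight matrices) receive weight decay
# while 1-D parameters (biases, norm scales) do not.
import torch
from torch import nn

model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32))
optim = get_optimizer(model.parameters(), lr = 1e-4, wd = 1e-2)
print([len(group['params']) for group in optim.param_groups])   # [1, 3]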
.\lucidrains\audiolm-pytorch\audiolm_pytorch\soundstream.py
import functools
from pathlib import Path
from functools import partial, wraps
from itertools import cycle, zip_longest
from typing import Optional, List
import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
from torch.autograd import grad as torch_grad
import torch.nn.functional as F
from torch.linalg import vector_norm
import torchaudio.transforms as T
from torchaudio.functional import resample
from einops import rearrange, reduce, pack, unpack
from vector_quantize_pytorch import (
GroupedResidualVQ,
GroupedResidualLFQ,
GroupedResidualFSQ
)
from local_attention import LocalMHA
from local_attention.transformer import FeedForward, DynamicPositionBias
from gateloop_transformer import SimpleGateLoopLayer as GateLoop
from audiolm_pytorch.utils import curtail_to_multiple
from audiolm_pytorch.version import __version__
from packaging import version
parsed_version = version.parse(__version__)
import pickle
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(t, l = 1):
return ((t,) * l) if not isinstance(t, tuple) else t
def filter_by_keys(fn, d):
return {k: v for k, v in d.items() if fn(k)}
def map_keys(fn, d):
return {fn(k): v for k, v in d.items()}
def log(t, eps = 1e-20):
return torch.log(t.clamp(min = eps))
def hinge_discr_loss(fake, real):
return (F.relu(1 + fake) + F.relu(1 - real)).mean()
def hinge_gen_loss(fake):
return -fake.mean()
def leaky_relu(p = 0.1):
return nn.LeakyReLU(p)
def gradient_penalty(wave, output, weight = 10):
batch_size, device = wave.shape[0], wave.device
gradients = torch_grad(
outputs = output,
inputs = wave,
grad_outputs = torch.ones_like(output),
create_graph = True,
retain_graph = True,
only_inputs = True
)[0]
gradients = rearrange(gradients, 'b ... -> b (...)')
return weight * ((vector_norm(gradients, dim = 1) - 1) ** 2).mean()
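# Toy check of gradient_penalty (an illustrative addition): a one-layer conv
# stands in for a real discriminator. The input must require grad so that
# torch.autograd.grad can differentiate the outputs with respect to it.
_wave = torch.randn(2, 1, 1024, requires_grad = True)
_output = nn.Conv1d(1, 1, 7, padding = 3)(_wave)
_gp = gradient_penalty(_wave, _output)   # weight * E[(||d_output/d_wave|| - 1)^2]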
def Sequential(*mods):
return nn.Sequential(*filter(exists, mods))
class MultiScaleDiscriminator(Module):
def __init__(
self,
channels = 16,
layers = 4,
groups = (4, 16, 64, 256),
chan_max = 1024,
input_channels = 1
):
super().__init__()
self.init_conv = nn.Conv1d(input_channels, channels, 15, padding = 7)
self.conv_layers = ModuleList([])
curr_channels = channels
for _, group in zip(range(layers), groups):
chan_out = min(curr_channels * 4, chan_max)
self.conv_layers.append(nn.Sequential(
nn.Conv1d(curr_channels, chan_out, 41, stride = 4, padding = 20, groups = group),
leaky_relu()
))
curr_channels = chan_out
self.final_conv = nn.Sequential(
nn.Conv1d(curr_channels, curr_channels, 5, padding = 2),
leaky_relu(),
nn.Conv1d(curr_channels, 1, 3, padding = 1),
)
def forward(
self,
x,
return_intermediates = False
):
x = self.init_conv(x)
intermediates = []
for layer in self.conv_layers:
x = layer(x)
intermediates.append(x)
out = self.final_conv(x)
if not return_intermediates:
return out
return out, intermediates
class SqueezeExcite(Module):
def __init__(self, dim, reduction_factor = 4, dim_minimum = 8):
super().__init__()
dim_inner = max(dim_minimum, dim // reduction_factor)
self.net = nn.Sequential(
nn.Conv1d(dim, dim_inner, 1),
nn.SiLU(),
nn.Conv1d(dim_inner, dim, 1),
nn.Sigmoid()
)
def forward(self, x):
seq, device = x.shape[-2], x.device
cum_sum = x.cumsum(dim = -2)
denom = torch.arange(1, seq + 1, device = device).float()
cum_mean = cum_sum / rearrange(denom, 'n -> n 1')
gate = self.net(cum_mean)
return x * gate
class ModReLU(Module):
"""
https://arxiv.org/abs/1705.09792
https://github.com/pytorch/pytorch/issues/47052#issuecomment-718948801
"""
def __init__(self):
super().__init__()
self.b = nn.Parameter(torch.tensor(0.))
def forward(self, x):
return F.relu(torch.abs(x) + self.b) * torch.exp(1.j * torch.angle(x))
class ComplexConv2d(Module):
def __init__(
self,
dim,
dim_out,
kernel_size,
stride = 1,
padding = 0
):
super().__init__()
conv = nn.Conv2d(dim, dim_out, kernel_size, dtype = torch.complex64)
self.weight = nn.Parameter(torch.view_as_real(conv.weight))
self.bias = nn.Parameter(torch.view_as_real(conv.bias))
self.stride = stride
self.padding = padding
def forward(self, x):
weight, bias = map(torch.view_as_complex, (self.weight, self.bias))
x = x.to(weight.dtype)
return F.conv2d(x, weight, bias, stride = self.stride, padding = self.padding)
def ComplexSTFTResidualUnit(chan_in, chan_out, strides):
kernel_sizes = tuple(map(lambda t: t + 2, strides))
paddings = tuple(map(lambda t: t // 2, kernel_sizes))
return nn.Sequential(
Residual(Sequential(
ComplexConv2d(chan_in, chan_in, 3, padding = 1),
ModReLU(),
ComplexConv2d(chan_in, chan_in, 3, padding = 1)
)),
ComplexConv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings)
)
class ComplexSTFTDiscriminator(Module):
def __init__(
self,
*,
channels = 32,
strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)),
chan_mults = (1, 2, 4, 4, 8, 8),
input_channels = 1,
n_fft = 1024,
hop_length = 256,
win_length = 1024,
stft_normalized = False,
stft_window_fn = torch.hann_window,
logits_abs = True
):
super().__init__()
self.init_conv = ComplexConv2d(input_channels, channels, 7, padding = 3)
layer_channels = tuple(map(lambda mult: mult * channels, chan_mults))
layer_channels = (channels, *layer_channels)
layer_channels_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:]))
curr_channels = channels
self.layers = ModuleList([])
for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs):
self.layers.append(ComplexSTFTResidualUnit(chan_in, chan_out, layer_stride))
self.final_conv = ComplexConv2d(layer_channels[-1], 1, (16, 1))
self.stft_normalized = stft_normalized
self.stft_window_fn = stft_window_fn
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.logits_abs = logits_abs
def forward(self, x, return_intermediates = False):
x = rearrange(x, 'b 1 n -> b n')
        '''
        Reference, from the SoundStream paper (https://arxiv.org/pdf/2107.03312.pdf):
        "The STFT-based discriminator is illustrated in Figure 4
        and operates on a single scale, computing the STFT with a
        window length of W = 1024 samples and a hop length of
        H = 256 samples"
        '''
stft_window = self.stft_window_fn(self.win_length, device = x.device)
x = torch.stft(
x,
self.n_fft,
hop_length = self.hop_length,
win_length = self.win_length,
window = stft_window,
normalized = self.stft_normalized,
return_complex = True
)
x = rearrange(x, 'b ... -> b 1 ...')
intermediates = []
x = self.init_conv(x)
intermediates.append(x)
for layer in self.layers:
x = layer(x)
intermediates.append(x)
complex_logits = self.final_conv(x)
if self.logits_abs:
complex_logits = complex_logits.abs()
else:
complex_logits = torch.view_as_real(complex_logits)
if not return_intermediates:
return complex_logits
return complex_logits, intermediates
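# Shape sketch (an illustrative addition): one second of random mono audio at
# 16 kHz through the STFT discriminator; with logits_abs = True the complex
# logits are reduced to their magnitudes.
_discr = ComplexSTFTDiscriminator()
_logits, _intermediates = _discr(torch.randn(1, 1, 16000), return_intermediates = True)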
class Residual(Module):
def __init__(self, fn: Module):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class ChannelTranspose(Module):
def __init__(self, fn: Module):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
x = rearrange(x, 'b c n -> b n c')
out = self.fn(x, **kwargs) + x
return rearrange(out, 'b n c -> b c n')
class CausalConv1d(Module):
def __init__(self, chan_in, chan_out, kernel_size, pad_mode = 'reflect', **kwargs):
super().__init__()
dilation = kwargs.get('dilation', 1)
stride = kwargs.get('stride', 1)
self.pad_mode = pad_mode
self.causal_padding = dilation * (kernel_size - 1) + (1 - stride)
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, **kwargs)
def forward(self, x):
x = F.pad(x, (self.causal_padding, 0), mode = self.pad_mode)
return self.conv(x)
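# Quick check (an illustrative addition) that CausalConv1d preserves sequence
# length at stride 1: causal_padding = dilation * (kernel_size - 1) is applied
# on the left only, so no output position sees future samples.
_conv = CausalConv1d(1, 8, kernel_size = 7, dilation = 2)
print(_conv(torch.randn(1, 1, 100)).shape)   # torch.Size([1, 8, 100])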
class CausalConvTranspose1d(Module):
def __init__(self, chan_in, chan_out, kernel_size, stride, **kwargs):
super().__init__()
self.upsample_factor = stride
self.padding = kernel_size - 1
self.conv = nn.ConvTranspose1d(chan_in, chan_out, kernel_size, stride, **kwargs)
def forward(self, x):
n = x.shape[-1]
out = self.conv(x)
out = out[..., :(n * self.upsample_factor)]
return out
def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7, squeeze_excite = False, pad_mode = 'reflect'):
return Residual(Sequential(
CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation, pad_mode = pad_mode),
nn.ELU(),
CausalConv1d(chan_out, chan_out, 1, pad_mode = pad_mode),
nn.ELU(),
SqueezeExcite(chan_out) if squeeze_excite else None
))
def EncoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False, pad_mode = 'reflect'):
it = cycle(cycle_dilations)
residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite, pad_mode = pad_mode)
return nn.Sequential(
residual_unit(chan_in, chan_in, next(it)),
residual_unit(chan_in, chan_in, next(it)),
residual_unit(chan_in, chan_in, next(it)),
CausalConv1d(chan_in, chan_out, 2 * stride, stride = stride)
)
def DecoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False, pad_mode = 'reflect'):
even_stride = (stride % 2 == 0)
padding = (stride + (0 if even_stride else 1)) // 2
output_padding = 0 if even_stride else 1
residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite, pad_mode = pad_mode)
it = cycle(cycle_dilations)
return nn.Sequential(
CausalConvTranspose1d(chan_in, chan_out, 2 * stride, stride = stride),
residual_unit(chan_out, chan_out, next(it)),
residual_unit(chan_out, chan_out, next(it)),
residual_unit(chan_out, chan_out, next(it)),
)
class LocalTransformer(Module):
def __init__(
self,
*,
dim,
depth,
heads,
window_size,
dynamic_pos_bias = False,
**kwargs
):
super().__init__()
self.window_size = window_size
self.layers = ModuleList([])
self.pos_bias = None
if dynamic_pos_bias:
self.pos_bias = DynamicPositionBias(dim = dim // 2, heads = heads)
for _ in range(depth):
self.layers.append(ModuleList([
LocalMHA(
dim = dim,
heads = heads,
qk_rmsnorm = True,
window_size = window_size,
use_rotary_pos_emb = not dynamic_pos_bias,
gate_values_per_head = True,
use_xpos = True,
**kwargs
),
FeedForward(dim = dim)
]))
def forward(self, x):
w = self.window_size
attn_bias = self.pos_bias(w, w * 2) if exists(self.pos_bias) else None
for attn, ff in self.layers:
x = attn(x, attn_bias = attn_bias) + x
x = ff(x) + x
return x
class FiLM(Module):
def __init__(self, dim, dim_cond):
super().__init__()
self.to_cond = nn.Linear(dim_cond, dim * 2)
def forward(self, x, cond):
gamma, beta = self.to_cond(cond).chunk(2, dim = -1)
return x * gamma + beta
class SoundStream(Module):
def __init__(
self,
*,
channels = 32,
strides = (2, 4, 5, 8),
channel_mults = (2, 4, 8, 16),
codebook_dim = 512,
codebook_size: Optional[int] = None,
finite_scalar_quantizer_levels: Optional[List[int]] = None,
rq_num_quantizers = 8,
rq_commitment_weight = 1.,
rq_ema_decay = 0.95,
rq_quantize_dropout_multiple_of = 1,
rq_groups = 1,
rq_stochastic_sample_codes = False,
rq_kwargs: dict = {},
use_lookup_free_quantizer = False,
use_finite_scalar_quantizer = False,
input_channels = 1,
discr_multi_scales = (1, 0.5, 0.25),
stft_normalized = False,
enc_cycle_dilations = (1, 3, 9),
dec_cycle_dilations = (1, 3, 9),
multi_spectral_window_powers_of_two = tuple(range(6, 12)),
multi_spectral_n_ffts = 512,
multi_spectral_n_mels = 64,
recon_loss_weight = 1.,
multi_spectral_recon_loss_weight = 1e-5,
adversarial_loss_weight = 1.,
feature_loss_weight = 100,
quantize_dropout_cutoff_index = 1,
target_sample_hz = 16000,
use_local_attn = True,
attn_window_size = 128,
attn_dim_head = 64,
attn_heads = 8,
attn_depth = 1,
attn_xpos_scale_base = None,
attn_dynamic_pos_bias = False,
use_gate_loop_layers = False,
squeeze_excite = False,
complex_stft_discr_logits_abs = True,
pad_mode = 'reflect',
stft_discriminator: Optional[Module] = None,
        complex_stft_discr_kwargs: dict = dict()
    ):
        ...  # __init__ body not included in this excerpt
@property
def device(self):
return next(self.parameters()).device
@property
def configs(self):
return pickle.loads(self._configs)
def decode_from_codebook_indices(self, quantized_indices):
assert quantized_indices.dtype in (torch.long, torch.int32)
if quantized_indices.ndim == 3:
quantized_indices = rearrange(quantized_indices, 'b n (g q) -> g b n q', g = self.rq_groups)
x = self.rq.get_output_from_indices(quantized_indices)
return self.decode(x)
def decode(self, x, quantize = False):
if quantize:
x, *_ = self.rq(x)
if exists(self.decoder_attn):
x = self.decoder_attn(x)
x = rearrange(x, 'b n c -> b c n')
return self.decoder(x)
def save(self, path):
path = Path(path)
pkg = dict(
model = self.state_dict(),
config = self._configs,
version = __version__
)
torch.save(pkg, str(path))
@classmethod
def init_and_load_from(cls, path, strict = True):
path = Path(path)
assert path.exists()
pkg = torch.load(str(path), map_location = 'cpu')
assert 'config' in pkg, 'model configs were not found in this saved checkpoint'
config = pickle.loads(pkg['config'])
soundstream = cls(**config)
soundstream.load(path, strict = strict)
soundstream.eval()
return soundstream
def load(self, path, strict = True):
path = Path(path)
assert path.exists()
pkg = torch.load(str(path), map_location = 'cpu')
if 'version' in pkg and version.parse(pkg['version']) < parsed_version:
print(f'soundstream model being loaded was trained on an older version of audiolm-pytorch ({pkg["version"]})')
has_ema = 'ema_model' in pkg
model_pkg = pkg['ema_model'] if has_ema else pkg['model']
if has_ema:
model_pkg = filter_by_keys(lambda k: k.startswith('ema_model.'), model_pkg)
model_pkg = map_keys(lambda k: k[len('ema_model.'):], model_pkg)
self.load_state_dict(model_pkg, strict = strict)
def load_from_trainer_saved_obj(self, path):
path = Path(path)
assert path.exists()
obj = torch.load(str(path))
self.load_state_dict(obj['model'])
def non_discr_parameters(self):
return [
*self.encoder.parameters(),
*self.decoder.parameters(),
*(self.encoder_attn.parameters() if exists(self.encoder_attn) else []),
*(self.decoder_attn.parameters() if exists(self.decoder_attn) else []),
*self.encoder_film.parameters(),
*self.decoder_film.parameters(),
*self.rq.parameters()
]
@property
def seq_len_multiple_of(self):
return functools.reduce(lambda x, y: x * y, self.strides)
@property
def downsample_factor(self):
return self.seq_len_multiple_of
def process_input(
self,
x,
input_sample_hz = None,
curtail_from_left = False
):
x, ps = pack([x], '* n')
if exists(input_sample_hz):
x = resample(x, input_sample_hz, self.target_sample_hz)
x = curtail_to_multiple(x, self.seq_len_multiple_of, from_left = curtail_from_left)
if x.ndim == 2:
x = rearrange(x, 'b n -> b 1 n')
return x, ps
@torch.no_grad()
def tokenize(self, audio):
self.eval()
return self.forward(audio, return_codes_only = True)
def forward(
self,
x,
target = None,
is_denoising = None,
return_encoded = False,
return_codes_only = False,
return_discr_loss = False,
return_discr_losses_separately = False,
return_loss_breakdown = False,
return_recons_only = False,
input_sample_hz = None,
apply_grad_penalty = False,
        curtail_from_left = False
    ):
        ...  # forward body not included in this excerpt
def AudioLMSoundStream(
strides = (2, 4, 5, 8),
target_sample_hz = 16000,
rq_num_quantizers = 12,
**kwargs
):
return SoundStream(
strides = strides,
target_sample_hz = target_sample_hz,
rq_num_quantizers = rq_num_quantizers,
**kwargs
)
def MusicLMSoundStream(
strides = (3, 4, 5, 8),
target_sample_hz = 24000,
rq_num_quantizers = 12,
**kwargs
):
return SoundStream(
strides = strides,
target_sample_hz = target_sample_hz,
rq_num_quantizers = rq_num_quantizers,
**kwargs
)
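# Tokenization sketch (an illustrative addition, not part of soundstream.py):
# weights are random here, so the codes only become meaningful after training
# or after loading a checkpoint via SoundStream.init_and_load_from.
soundstream = AudioLMSoundStream()    # 16 kHz, strides (2, 4, 5, 8)
wav = torch.randn(1, 16000)           # 1 second of audio
codes = soundstream.tokenize(wav)     # quantizer indices, one frame per 320 samples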
.\lucidrains\audiolm-pytorch\audiolm_pytorch\t5.py
import torch
import transformers
from transformers import T5Tokenizer, T5EncoderModel, T5Config
from beartype import beartype
from beartype.typing import Union, List
transformers.logging.set_verbosity_error()
def exists(val):
return val is not None
MAX_LENGTH = 256
DEFAULT_T5_NAME = 'google/t5-v1_1-base'
T5_CONFIGS = {}
def get_tokenizer(name):
tokenizer = T5Tokenizer.from_pretrained(name)
return tokenizer
def get_model(name):
model = T5EncoderModel.from_pretrained(name)
return model
def get_model_and_tokenizer(name):
global T5_CONFIGS
if name not in T5_CONFIGS:
T5_CONFIGS[name] = dict()
if "model" not in T5_CONFIGS[name]:
T5_CONFIGS[name]["model"] = get_model(name)
if "tokenizer" not in T5_CONFIGS[name]:
T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)
return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']
def get_encoded_dim(name):
if name not in T5_CONFIGS:
config = T5Config.from_pretrained(name)
T5_CONFIGS[name] = dict(config = config)
elif "config" in T5_CONFIGS[name]:
config = T5_CONFIGS[name]["config"]
elif "model" in T5_CONFIGS[name]:
config = T5_CONFIGS[name]["model"].config
else:
raise ValueError(f'unknown t5 name {name}')
return config.d_model
@beartype
def t5_encode_text(
texts: Union[str, List[str]],
name = DEFAULT_T5_NAME,
output_device = None
):
if isinstance(texts, str):
texts = [texts]
t5, tokenizer = get_model_and_tokenizer(name)
if torch.cuda.is_available():
t5 = t5.cuda()
device = next(t5.parameters()).device
encoded = tokenizer.batch_encode_plus(
texts,
return_tensors = 'pt',
padding = 'longest',
max_length = MAX_LENGTH,
truncation = True
)
input_ids = encoded.input_ids.to(device)
attn_mask = encoded.attention_mask.to(device)
t5.eval()
with torch.inference_mode():
output = t5(input_ids = input_ids, attention_mask = attn_mask)
encoded_text = output.last_hidden_state.detach()
attn_mask = attn_mask[..., None].bool()
if not exists(output_device):
encoded_text = encoded_text.masked_fill(~attn_mask, 0.)
return encoded_text
    encoded_text = encoded_text.to(output_device)
    attn_mask = attn_mask.to(output_device)
encoded_text = encoded_text.masked_fill(~attn_mask, 0.)
return encoded_text
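# Usage sketch (an illustrative addition, not part of t5.py): the first call
# downloads google/t5-v1_1-base from the HuggingFace hub, so network access is
# assumed.
embeds = t5_encode_text(['the quick brown fox'])   # (1, seq_len, 768)
print(get_encoded_dim(DEFAULT_T5_NAME))            # 768 for the base model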
.\lucidrains\audiolm-pytorch\audiolm_pytorch\trainer.py
import re
import copy
from math import sqrt
from datetime import timedelta
from random import choice
from pathlib import Path
from shutil import rmtree
from functools import partial
from collections import Counter
from contextlib import contextmanager, nullcontext
from beartype.typing import Union, List, Optional, Tuple, Type
from typing_extensions import Annotated
from beartype import beartype
from beartype.door import is_bearable
from beartype.vale import Is
import torch
import torchaudio
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, _LRScheduler
from torch.utils.data import Dataset, DataLoader, random_split
import pytorch_warmup as warmup
from einops import rearrange
from audiolm_pytorch.optimizer import get_optimizer
import wandb
from ema_pytorch import EMA
from audiolm_pytorch.soundstream import SoundStream
from audiolm_pytorch.encodec import EncodecWrapper
from audiolm_pytorch.audiolm_pytorch import (
SemanticTransformer,
SemanticTransformerWrapper,
CoarseTransformer,
CoarseTransformerWrapper,
FineTransformer,
FineTransformerWrapper,
FairseqVQWav2Vec,
HubertWithKmeans
)
from audiolm_pytorch.data import SoundDataset, get_dataloader
from audiolm_pytorch.utils import AudioConditionerBase
from audiolm_pytorch.version import __version__
from packaging import version
from accelerate import Accelerator, DistributedType
from accelerate.utils import DistributedDataParallelKwargs, InitProcessGroupKwargs
from accelerate.tracking import WandBTracker
DEFAULT_SAMPLE_RATE = 16000
ConstantLRScheduler = partial(LambdaLR, lr_lambda = lambda step: 1.)
ONE_TRAINER_INSTANTIATED = False
def check_one_trainer():
global ONE_TRAINER_INSTANTIATED
assert not ONE_TRAINER_INSTANTIATED, 'only one Trainer can be instantiated at a time for training'
ONE_TRAINER_INSTANTIATED = True
DEFAULT_DDP_KWARGS = DistributedDataParallelKwargs(find_unused_parameters = True)
DATASET_FIELD_TYPE_CONFIG = dict(
raw_wave = Annotated[
torch.Tensor,
Is[lambda t: t.dtype == torch.float and t.ndim in {2, 3}]
],
text = List[str],
text_embeds = Annotated[
torch.Tensor,
Is[lambda t: t.dtype == torch.float and t.ndim == 3]
],
)
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def noop(*args, **kwargs):
pass
def find_first(cond, arr):
for el in arr:
if cond(el):
return el
return None
def cycle(dl):
while True:
for data in dl:
yield data
def cast_tuple(t):
return t if isinstance(t, (tuple, list)) else (t,)
def yes_or_no(question):
answer = input(f'{question} (y/n) ')
return answer.lower() in ('yes', 'y')
def accum_log(log, new_logs):
for key, new_value in new_logs.items():
old_value = log.get(key, 0.)
log[key] = old_value + new_value
return log
def dict_values_to_device(d: dict, device):
out = {}
for k, v in d.items():
out[k] = v.to(device) if torch.is_tensor(v) else v
return out
def has_duplicates(tup):
counts = dict(Counter(tup))
return any(filter(lambda count: count > 1, counts.values()))
def determine_types(data, config):
output = []
for el in data:
for name, data_type in config.items():
if is_bearable(el, data_type):
output.append(name)
break
else:
raise TypeError(f'unable to determine type of {data}')
return tuple(output)
def checkpoint_num_steps(checkpoint_path):
"""Returns the number of steps trained from a checkpoint based on the filename.
    Filename format assumed to be something like "/path/to/semantic.transformer.20000.pt", which is
    for 20k train steps. Returns 20000 in that case.
"""
results = re.findall(r'\d+', str(checkpoint_path))
if len(results) == 0:
return 0
return int(results[-1])
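# Examples of the filename convention relied on above (illustrative addition):
assert checkpoint_num_steps('/path/to/semantic.transformer.20000.pt') == 20000
assert checkpoint_num_steps('untagged-model.pt') == 0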
class OptimizerWithWarmupSchedule(nn.Module):
@beartype
def __init__(
self,
accelerator: Accelerator,
optimizer: Optimizer,
scheduler: Optional[Type[_LRScheduler]] = None,
scheduler_kwargs: dict = dict(),
warmup_steps: int = 0
):
super().__init__()
self.warmup = warmup.LinearWarmup(optimizer, warmup_period = warmup_steps)
if exists(scheduler):
self.scheduler = scheduler(optimizer, **scheduler_kwargs)
else:
self.scheduler = ConstantLRScheduler(optimizer)
self.optimizer = optimizer
self.optimizer, self.scheduler = accelerator.prepare(self.optimizer, self.scheduler)
self.accelerator = accelerator
def state_dict(self):
return dict(
optimizer = self.optimizer.state_dict(),
scheduler = self.scheduler.state_dict(),
warmup = self.warmup.state_dict()
)
def load_state_dict(self, pkg):
self.optimizer.load_state_dict(pkg['optimizer'])
self.scheduler.load_state_dict(pkg['scheduler'])
self.warmup.load_state_dict(pkg['warmup'])
def zero_grad(self):
self.optimizer.zero_grad()
def step(self):
self.optimizer.step()
if not self.accelerator.optimizer_step_was_skipped:
with self.warmup.dampening():
self.scheduler.step()
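# Wiring sketch (an illustrative addition): hook the warmup schedule up to an
# Accelerator and a toy model; the learning rate and warmup length are
# assumptions.
_accelerator = Accelerator()
_model = nn.Linear(10, 10)
_optim = OptimizerWithWarmupSchedule(
    accelerator = _accelerator,
    optimizer = get_optimizer(_model.parameters(), lr = 3e-4),
    warmup_steps = 1000
)
_loss = _model(torch.randn(2, 10)).sum()
_accelerator.backward(_loss)
_optim.step()        # optimizer.step(), then scheduler.step() under warmup dampening
_optim.zero_grad()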
class SoundStreamTrainer(nn.Module):
@beartype
def __init__(
self,
soundstream: SoundStream,
*,
num_train_steps: int,
batch_size: int,
data_max_length: int = None,
data_max_length_seconds: Union[int, float] = None,
folder: str = None,
dataset: Optional[Dataset] = None,
val_dataset: Optional[Dataset] = None,
train_dataloader: Optional[DataLoader] = None,
val_dataloader: Optional[DataLoader] = None,
lr: float = 2e-4,
grad_accum_every: int = 4,
wd: float = 0.,
warmup_steps: int = 1000,
scheduler: Optional[Type[_LRScheduler]] = None,
scheduler_kwargs: dict = dict(),
discr_warmup_steps: Optional[int] = None,
discr_scheduler: Optional[Type[_LRScheduler]] = None,
discr_scheduler_kwargs: dict = dict(),
max_grad_norm: float = 0.5,
discr_max_grad_norm: float = None,
save_results_every: int = 100,
save_model_every: int = 1000,
log_losses_every: int = 1,
results_folder: str = './results',
valid_frac: float = 0.05,
random_split_seed: int = 42,
use_ema: bool = True,
ema_beta: float = 0.995,
ema_update_after_step: int = 500,
ema_update_every: int = 10,
apply_grad_penalty_every: int = 4,
dl_num_workers: int = 0,
accelerator: Optional[Accelerator] = None,
accelerate_kwargs: dict = dict(),
init_process_group_timeout_seconds = 1800,
dataloader_drop_last = True,
split_batches = False,
use_wandb_tracking = False,
        force_clear_prev_results: bool = None
    ):
        ...  # __init__ body not included in this excerpt
@property
def ema_tokenizer(self):
return self.ema_soundstream.ema_model
def tokenize(self, audio):
        return self.ema_tokenizer.tokenize(audio)
def set_model_as_ema_model_(self):
""" this will force the main 'online' model to have same parameters as the exponentially moving averaged model """
assert self.use_ema
self.ema_soundstream.ema_model.load_state_dict(self.soundstream.state_dict())
def save(self, path):
pkg = dict(
model = self.accelerator.get_state_dict(self.soundstream),
optim = self.optim.state_dict(),
config = self.unwrapped_soundstream._configs,
discr_optim = self.discr_optim.state_dict(),
version = __version__
)
if self.use_ema:
pkg['ema_model'] = self.ema_soundstream.state_dict()
for key, _ in self.multiscale_discriminator_iter():
discr_optim = getattr(self, key)
pkg[key] = discr_optim.state_dict()
torch.save(pkg, path)
@property
def unwrapped_soundstream(self):
return self.accelerator.unwrap_model(self.soundstream)
def load(self, path):
path = Path(path)
assert path.exists()
pkg = torch.load(str(path), map_location = 'cpu')
if len(pkg.keys()) > 20:
self.unwrapped_soundstream.load_state_dict(pkg)
if self.use_ema:
self.ema_soundstream.ema_model.load_state_dict(pkg)
return
if 'version' in pkg and version.parse(pkg['version']) < version.parse(__version__):
print(f'model was trained on older version {pkg["version"]} of audiolm-pytorch')
self.unwrapped_soundstream.load_state_dict(pkg['model'])
if self.use_ema:
assert 'ema_model' in pkg
self.ema_soundstream.load_state_dict(pkg['ema_model'])
self.optim.load_state_dict(pkg['optim'])
self.discr_optim.load_state_dict(pkg['discr_optim'])
for key, _ in self.multiscale_discriminator_iter():
discr_optim = getattr(self, key)
discr_optim.load_state_dict(pkg[key])
self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)
def multiscale_discriminator_iter(self):
for ind, discr in enumerate(self.unwrapped_soundstream.discriminators):
yield f'multiscale_discr_optimizer_{ind}', discr
def multiscale_discriminator_optim_iter(self):
for name, _ in self.multiscale_discriminator_iter():
yield name, getattr(self, name)
def print(self, msg):
self.accelerator.print(msg)
def log(self, **logs_as_kwargs):
self.accelerator.log(logs_as_kwargs, step = self.steps.item())
@contextmanager
def wandb_tracker(self, project, run = None, hps = None):
assert self.use_wandb_tracking, '`use_wandb_tracking` must be set to True on SoundStreamTrainer'
hps = default(hps, self.tracker_hps)
        self.accelerator.init_trackers(project, config = hps)
if exists(run):
wandb_tracker = find_first(lambda el: isinstance(el, WandBTracker), self.accelerator.trackers)
assert exists(wandb_tracker)
wandb_tracker.run.name = run
yield
self.accelerator.end_training()
@property
def device(self):
return self.accelerator.device
@property
def is_distributed(self):
return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
@property
def is_main(self):
return self.accelerator.is_main_process
@property
def is_local_main(self):
return self.accelerator.is_local_main_process
def train(self, log_fn = noop):
while self.steps < self.num_train_steps:
logs = self.train_step()
log_fn(logs)
self.print('training complete')
class SemanticTransformerTrainer(nn.Module):
@beartype
def __init__(
self,
wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],
transformer: SemanticTransformer,
*,
num_train_steps,
batch_size,
audio_conditioner: Optional[AudioConditionerBase] = None,
dataset: Optional[Dataset] = None,
valid_dataset: Optional[Dataset] = None,
data_max_length = None,
data_max_length_seconds = None,
folder = None,
lr = 3e-4,
grad_accum_every = 1,
wd = 0.,
max_grad_norm = 0.5,
valid_frac = 0.05,
random_split_seed = 42,
save_results_every = 100,
save_model_every = 1000,
results_folder = './results',
accelerate_kwargs: dict = dict(),
init_process_group_timeout_seconds = 1800,
use_wandb_tracking = False,
split_batches = False,
drop_last = False,
force_clear_prev_results = None,
        average_valid_loss_over_grad_accum_every: bool = True,
    ):
        ...  # __init__ body not included in this excerpt
def save(self, path):
pkg = dict(
model = self.accelerator.get_state_dict(self.transformer),
optim = self.optim.state_dict(),
version = __version__
)
torch.save(pkg, path)
def load(self, path):
transformer = self.accelerator.unwrap_model(self.transformer)
pkg = transformer.load(path)
self.optim.load_state_dict(pkg['optim'])
self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)
def print(self, msg):
self.accelerator.print(msg)
def generate(self, *args, **kwargs):
return self.train_wrapper.generate(*args, **kwargs)
@property
def device(self):
return self.accelerator.device
@property
def is_distributed(self):
return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
@property
def is_main(self):
return self.accelerator.is_main_process
@property
def is_local_main(self):
return self.accelerator.is_local_main_process
def data_tuple_to_kwargs(self, data):
if not exists(self.ds_fields):
self.ds_fields = determine_types(data, DATASET_FIELD_TYPE_CONFIG)
assert not has_duplicates(self.ds_fields), 'dataset fields must not have duplicate field names'
return dict(zip(self.ds_fields, data))
@contextmanager
def wandb_tracker(self, project, run = None, hps = None):
assert self.use_wandb_tracking, '`use_wandb_tracking` must be set to True on SemanticTransformerTrainer'
hps = default(hps, self.tracker_hps)
        self.accelerator.init_trackers(project, config = hps)
if exists(run):
wandb_tracker = find_first(lambda el: isinstance(el, WandBTracker), self.accelerator.trackers)
assert exists(wandb_tracker)
wandb_tracker.run.name = run
yield
self.accelerator.end_training()
def train_step(self):
device = self.device
steps = int(self.steps.item())
self.transformer.train()
logs = {}
for i in range(self.grad_accum_every):
is_last = i == (self.grad_accum_every - 1)
context = partial(self.accelerator.no_sync, self.train_wrapper) if not is_last else nullcontext
data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter))
with self.accelerator.autocast(), context():
loss = self.train_wrapper(**data_kwargs, return_loss = True)
self.accelerator.backward(loss / self.grad_accum_every)
accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.transformer.parameters(), self.max_grad_norm)
self.optim.step()
self.optim.zero_grad()
self.print(f"{steps}: loss: {logs['loss']}")
self.accelerator.log({"train_loss": logs['loss']}, step=steps)
self.accelerator.wait_for_everyone()
if self.is_main and not (steps % self.save_results_every):
valid_loss = 0
unwrapped_model = self.accelerator.unwrap_model(self.train_wrapper)
for _ in range(self.average_valid_loss_over_grad_accum_every):
data_kwargs = self.data_tuple_to_kwargs(next(self.valid_dl_iter))
data_kwargs = dict_values_to_device(data_kwargs, unwrapped_model.device)
with torch.inference_mode():
unwrapped_model.eval()
valid_loss += unwrapped_model(**data_kwargs, return_loss = True)
valid_loss = valid_loss.clone()
valid_loss /= self.average_valid_loss_over_grad_accum_every
self.print(f'{steps}: valid loss {valid_loss}')
self.accelerator.log({"valid_loss": valid_loss}, step=steps)
if self.is_main and not (steps % self.save_model_every):
model_path = str(self.results_folder / f'semantic.transformer.{steps}.pt')
self.save(model_path)
if self.use_wandb_tracking:
wandb.save(model_path)
self.print(f'{steps}: saving model to {str(self.results_folder)}')
self.accelerator.wait_for_everyone()
self.steps.add_(1)
return logs
def train(self, log_fn = noop):
while self.steps < self.num_train_steps:
logs = self.train_step()
log_fn(logs)
self.print('training complete')
class CoarseTransformerTrainer(nn.Module):
@beartype
def __init__(
self,
transformer: CoarseTransformer,
codec: Union[SoundStream, EncodecWrapper],
wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],
*,
num_train_steps,
batch_size,
audio_conditioner: Optional[AudioConditionerBase] = None,
dataset: Optional[Dataset] = None,
valid_dataset: Optional[Dataset] = None,
ds_fields: Tuple[str, ...] = ('raw_wave', 'raw_wave_for_codec', 'text'),
data_max_length = None,
data_max_length_seconds = None,
folder = None,
lr = 3e-4,
grad_accum_every = 1,
wd = 0.,
max_grad_norm = 0.5,
valid_frac = 0.05,
random_split_seed = 42,
save_results_every = 100,
save_model_every = 1000,
results_folder = './results',
accelerate_kwargs: dict = dict(),
init_process_group_timeout_seconds = 1800,
split_batches = False,
drop_last = False,
force_clear_prev_results = None,
use_wandb_tracking = False,
        average_valid_loss_over_grad_accum_every: bool = True,
    ):
        ...  # __init__ body not included in this excerpt
def save(self, path):
pkg = dict(
model = self.accelerator.get_state_dict(self.transformer),
optim = self.optim.state_dict(),
version = __version__
)
torch.save(pkg, path)
def load(self, path):
transformer = self.accelerator.unwrap_model(self.transformer)
pkg = transformer.load(path)
self.optim.load_state_dict(pkg['optim'])
self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)
def print(self, msg):
self.accelerator.print(msg)
def generate(self, *args, **kwargs):
return self.train_wrapper.generate(*args, **kwargs)
@contextmanager
def wandb_tracker(self, project, run = None, hps = None):
assert self.use_wandb_tracking, '`use_wandb_tracking` must be set to True on CoarseTransformerTrainer'
hps = default(hps, self.tracker_hps)
        self.accelerator.init_trackers(project, config = hps)
if exists(run):
wandb_tracker = find_first(lambda el: isinstance(el, WandBTracker), self.accelerator.trackers)
assert exists(wandb_tracker)
wandb_tracker.run.name = run
yield
self.accelerator.end_training()
@property
def device(self):
return self.accelerator.device
@property
def is_distributed(self):
return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
@property
def is_main(self):
return self.accelerator.is_main_process
@property
def is_local_main(self):
return self.accelerator.is_local_main_process
def train_step(self):
device = self.device
steps = int(self.steps.item())
self.transformer.train()
logs = {}
for i in range(self.grad_accum_every):
is_last = i == (self.grad_accum_every - 1)
context = partial(self.accelerator.no_sync, self.train_wrapper) if not is_last else nullcontext
data_kwargs = dict(zip(self.ds_fields, next(self.dl_iter)))
with self.accelerator.autocast(), context():
loss = self.train_wrapper(
**data_kwargs,
return_loss = True
)
self.accelerator.backward(loss / self.grad_accum_every)
accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.transformer.parameters(), self.max_grad_norm)
self.optim.step()
self.optim.zero_grad()
self.print(f"{steps}: loss: {logs['loss']}")
self.accelerator.log({"train_loss": logs['loss']}, step=steps)
self.accelerator.wait_for_everyone()
if self.is_main and not (steps % self.save_results_every):
valid_loss = 0
unwrapped_model = self.accelerator.unwrap_model(self.train_wrapper)
for i in range(self.average_valid_loss_over_grad_accum_every):
data_kwargs = dict(zip(self.ds_fields, next(self.valid_dl_iter)))
data_kwargs = dict_values_to_device(data_kwargs, unwrapped_model.device)
with torch.no_grad():
unwrapped_model.eval()
valid_loss += unwrapped_model(
**data_kwargs,
return_loss = True
)
valid_loss = valid_loss.clone()
valid_loss /= self.average_valid_loss_over_grad_accum_every
self.print(f'{steps}: valid loss {valid_loss}')
self.accelerator.log({"valid_loss": valid_loss}, step=steps)
if self.is_main and not (steps % self.save_model_every):
model_path = str(self.results_folder / f'coarse.transformer.{steps}.pt')
self.save(model_path)
if self.use_wandb_tracking:
wandb.save(model_path)
self.print(f'{steps}: saving model to {str(self.results_folder)}')
self.accelerator.wait_for_everyone()
self.steps.add_(1)
return logs
def train(self, log_fn = noop):
while self.steps < self.num_train_steps:
logs = self.train_step()
log_fn(logs)
self.print('training complete')
class FineTransformerTrainer(nn.Module):
@beartype
def __init__(
self,
transformer: FineTransformer,
codec: Union[SoundStream, EncodecWrapper],
*,
num_train_steps,
batch_size,
audio_conditioner: Optional[AudioConditionerBase] = None,
dataset: Optional[Dataset] = None,
valid_dataset: Optional[Dataset] = None,
data_max_length = None,
data_max_length_seconds = None,
dataset_normalize = False,
folder = None,
lr = 3e-4,
grad_accum_every = 1,
wd = 0.,
max_grad_norm = 0.5,
valid_frac = 0.05,
random_split_seed = 42,
save_results_every = 100,
save_model_every = 1000,
results_folder = './results',
accelerate_kwargs: dict = dict(),
init_process_group_timeout_seconds = 1800,
split_batches = False,
drop_last = False,
use_wandb_tracking = False,
force_clear_prev_results = None,
        average_valid_loss_over_grad_accum_every: bool = True,
    ):
        ...  # __init__ body not included in this excerpt
def save(self, path):
pkg = dict(
model = self.accelerator.get_state_dict(self.transformer),
optim = self.optim.state_dict(),
version = __version__
)
torch.save(pkg, path)
def load(self, path):
transformer = self.accelerator.unwrap_model(self.transformer)
pkg = transformer.load(path)
self.optim.load_state_dict(pkg['optim'])
self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)
def print(self, msg):
self.accelerator.print(msg)
def generate(self, *args, **kwargs):
return self.train_wrapper.generate(*args, **kwargs)
@contextmanager
def wandb_tracker(self, project, run = None, hps = None):
assert self.use_wandb_tracking, '`use_wandb_tracking` must be set to True on FineTransformerTrainer'
hps = default(hps, self.tracker_hps)
        self.accelerator.init_trackers(project, config = hps)
if exists(run):
wandb_tracker = find_first(lambda el: isinstance(el, WandBTracker), self.accelerator.trackers)
assert exists(wandb_tracker)
wandb_tracker.run.name = run
yield
self.accelerator.end_training()
@property
def device(self):
return self.accelerator.device
@property
def is_distributed(self):
return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
@property
def is_main(self):
return self.accelerator.is_main_process
@property
def is_local_main(self):
return self.accelerator.is_local_main_process
def data_tuple_to_kwargs(self, data):
if not exists(self.ds_fields):
self.ds_fields = determine_types(data, DATASET_FIELD_TYPE_CONFIG)
assert not has_duplicates(self.ds_fields), 'dataset fields must not have duplicate field names'
return dict(zip(self.ds_fields, data))
def train_step(self):
device = self.device
steps = int(self.steps.item())
self.transformer.train()
logs = {}
for i in range(self.grad_accum_every):
is_last = i == (self.grad_accum_every - 1)
context = partial(self.accelerator.no_sync, self.train_wrapper) if not is_last else nullcontext
data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter))
with self.accelerator.autocast(), context():
loss = self.train_wrapper(**data_kwargs, return_loss = True)
self.accelerator.backward(loss / self.grad_accum_every)
accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.transformer.parameters(), self.max_grad_norm)
self.optim.step()
self.optim.zero_grad()
self.print(f"{steps}: loss: {logs['loss']}")
self.accelerator.log({"train_loss": logs['loss']}, step=steps)
self.accelerator.wait_for_everyone()
if self.is_main and not (steps % self.save_results_every):
unwrapped_model = self.accelerator.unwrap_model(self.train_wrapper)
valid_loss = 0
for i in range(self.average_valid_loss_over_grad_accum_every):
data_kwargs = self.data_tuple_to_kwargs(next(self.valid_dl_iter))
data_kwargs = dict_values_to_device(data_kwargs, unwrapped_model.device)
with torch.inference_mode():
unwrapped_model.eval()
valid_loss += unwrapped_model(**data_kwargs, return_loss = True)
valid_loss = valid_loss.clone()
valid_loss /= self.average_valid_loss_over_grad_accum_every
self.print(f'{steps}: valid loss {valid_loss}')
self.accelerator.log({"valid_loss": valid_loss}, step=steps)
if self.is_main and not (steps % self.save_model_every):
model_path = str(self.results_folder / f'fine.transformer.{steps}.pt')
self.save(model_path)
if self.use_wandb_tracking:
wandb.save(model_path)
self.print(f'{steps}: saving model to {str(self.results_folder)}')
self.accelerator.wait_for_everyone()
self.steps.add_(1)
return logs
def train(self, log_fn = noop):
while self.steps < self.num_train_steps:
logs = self.train_step()
log_fn(logs)
self.print('training complete')
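# End-to-end sketch (an illustrative addition, following the repository
# README): train the semantic stage over an assumed './audio' folder with
# placeholder HuBERT checkpoint / k-means paths; hyperparameters illustrative.
wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert_base_ls960.pt',
    kmeans_path = './hubert_base_ls960_L9_km500.bin'
)
transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6
)
trainer = SemanticTransformerTrainer(
    transformer = transformer,
    wav2vec = wav2vec,
    folder = './audio',
    batch_size = 4,
    data_max_length_seconds = 2,
    num_train_steps = 10000
)
trainer.train()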