Transformers 源码解析(一)
.\activations.py
import math
from collections import OrderedDict
import torch
from packaging import version
from torch import Tensor, nn
from .utils import logging
logger = logging.get_logger(__name__)
class PytorchGELUTanh(nn.Module):
"""
A fast C implementation of the tanh approximation of the GeLU activation function. See
https://arxiv.org/abs/1606.08415.
This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
match due to rounding errors.
"""
def __init__(self):
super().__init__()
if version.parse(torch.__version__) < version.parse("1.12.0"):
raise ImportError(
f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
"PytorchGELUTanh. Please upgrade torch."
)
def forward(self, input: Tensor) -> Tensor:
return nn.functional.gelu(input, approximate="tanh")
class NewGELUActivation(nn.Module):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
class GELUActivation(nn.Module):
"""
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def __init__(self, use_gelu_python: bool = False):
super().__init__()
if use_gelu_python:
self.act = self._gelu_python
else:
self.act = nn.functional.gelu
def _gelu_python(self, input: Tensor) -> Tensor:
return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
def forward(self, input: Tensor) -> Tensor:
return self.act(input)
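As a quick illustration (not part of the original file), the exact erf-based GELUActivation and the tanh-based NewGELUActivation can be compared numerically; the two agree closely, which is why the approximation is so widely used:
```python
# Illustrative sketch only: compare the exact GELU with its tanh approximation.
import torch
from transformers.activations import GELUActivation, NewGELUActivation

x = torch.linspace(-4.0, 4.0, steps=101)
exact = GELUActivation()(x)           # erf-based, delegates to nn.functional.gelu
approx = NewGELUActivation()(x)       # tanh-based approximation
print((exact - approx).abs().max())   # small discrepancy, well below 1e-2
```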
class FastGELUActivation(nn.Module):
"""
Placeholder for a fast GELU activation function. Actual implementation is not provided here.
"""
"""
# 前向传播函数,接收一个张量作为输入,返回处理后的张量
def forward(self, input: Tensor) -> Tensor:
# 使用 GELU 近似函数计算
return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
class QuickGELUActivation(nn.Module):
"""
Applies a fast but approximate version of GELU activation.
Reference: https://github.com/hendrycks/GELUs
"""
def forward(self, input: Tensor) -> Tensor:
# Implementing GELU approximation using a sigmoid function
return input * torch.sigmoid(1.702 * input)
class ClippedGELUActivation(nn.Module):
"""
Applies GELU activation with output clipped to a specified range [min, max].
This is useful for quantization purposes to handle negative values in the GELU spectrum.
References:
- https://arxiv.org/abs/2004.09602
"""
def __init__(self, min: float, max: float):
if min > max:
raise ValueError(f"min should be < max (got min: {min}, max: {max})")
super().__init__()
self.min = min
self.max = max
def forward(self, x: Tensor) -> Tensor:
# Applying GELU activation and clipping the output
return torch.clip(gelu(x), self.min, self.max)
class AccurateGELUActivation(nn.Module):
"""
Applies a more accurate version of GELU activation compared to QuickGELU.
Reference: https://github.com/hendrycks/GELUs
Implemented in the context of MEGA (Moving Average Equipped Gated Attention).
"""
def __init__(self):
super().__init__()
self.precomputed_constant = math.sqrt(2 / math.pi)
def forward(self, input: Tensor) -> Tensor:
# Implementing the accurate GELU activation formula
return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))
class MishActivation(nn.Module):
"""
Applies the Mish activation function, a self-regularized non-monotonic activation.
Reference: https://arxiv.org/abs/1908.08681
"""
def __init__(self):
super().__init__()
if version.parse(torch.__version__) < version.parse("1.9.0"):
self.act = self._mish_python
else:
self.act = nn.functional.mish
def _mish_python(self, input: Tensor) -> Tensor:
# Implementing Mish activation using Python function
return input * torch.tanh(nn.functional.softplus(input))
def forward(self, input: Tensor) -> Tensor:
# Applying Mish activation function
return self.act(input)
class LinearActivation(nn.Module):
"""
Applies the linear activation function, i.e., forwarding input directly to output.
"""
def forward(self, input: Tensor) -> Tensor:
# Identity function; returns input unchanged
return input
class LaplaceActivation(nn.Module):
"""
Applies an elementwise activation based on the Laplace function, introduced in MEGA for attention.
This activation is inspired by squared ReLU but offers a bounded range and gradient for improved stability.
Reference: https://arxiv.org/abs/2209.10655
"""
"""
此方法用于计算正向传播过程中的操作,对输入进行标准化处理后,应用误差函数。
:param input: 输入张量
:param mu: 均值参数,默认为0.707107
:param sigma: 标准差参数,默认为0.282095
:return: 处理后的张量
将输入张量标准化,减去均值 mu 并除以标准差乘以 sqrt(2.0)
input = (input - mu).div(sigma * math.sqrt(2.0))
应用误差函数,计算误差函数的正向传播结果,返回结果
return 0.5 * (1.0 + torch.erf(input))
"""
# ReLUSquaredActivation: a custom activation module, subclassing nn.Module
class ReLUSquaredActivation(nn.Module):
"""
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
"""
# Forward pass, taking the input tensor
def forward(self, input):
# Apply the ReLU activation to the input
relu_applied = nn.functional.relu(input)
# Square the ReLU output elementwise
squared = torch.square(relu_applied)
# Return the squared result as the output
return squared
# ClassInstantier: an OrderedDict subclass that instantiates its stored classes on lookup
class ClassInstantier(OrderedDict):
# Override __getitem__ so that indexing by key returns an instance rather than the stored class
def __getitem__(self, key):
# Fetch the stored value via OrderedDict.__getitem__
content = super().__getitem__(key)
# If the value is a tuple, unpack it into (cls, kwargs); otherwise use it as cls with empty kwargs
cls, kwargs = content if isinstance(content, tuple) else (content, {})
# Return a fresh instance built from cls and kwargs
return cls(**kwargs)
# ACT2CLS maps activation-name strings to activation classes, or to (class, kwargs) tuples
ACT2CLS = {
"gelu": GELUActivation,
"gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
"gelu_fast": FastGELUActivation,
"gelu_new": NewGELUActivation,
"gelu_python": (GELUActivation, {"use_gelu_python": True}),
"gelu_pytorch_tanh": PytorchGELUTanh,
"gelu_accurate": AccurateGELUActivation,
"laplace": LaplaceActivation,
"leaky_relu": nn.LeakyReLU,
"linear": LinearActivation,
"mish": MishActivation,
"quick_gelu": QuickGELUActivation,
"relu": nn.ReLU,
"relu2": ReLUSquaredActivation, # 引用了之前定义的 ReLUSquaredActivation 激活函数类
"relu6": nn.ReLU6,
"sigmoid": nn.Sigmoid,
"silu": nn.SiLU, # SiLU 激活函数类,也称作 Swish
"swish": nn.SiLU, # 同上,SiLU 激活函数
"tanh": nn.Tanh,
}
# Wrap ACT2CLS in a ClassInstantier so that ACT2FN maps strings to freshly instantiated activation modules
ACT2FN = ClassInstantier(ACT2CLS)
# get_activation looks up an activation function by its string name
def get_activation(activation_string):
# Return the corresponding activation instance if the name is known
if activation_string in ACT2FN:
return ACT2FN[activation_string]
else:
# Otherwise raise a KeyError listing the available activation names
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
# Module-level shortcuts for frequently used activation instances
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")
.\activations_tf.py
import math
import tensorflow as tf
from packaging.version import parse
try:
import tf_keras as keras
except (ModuleNotFoundError, ImportError):
import keras
if parse(keras.__version__).major > 2:
raise ValueError(
"Your currently installed version of Keras is Keras 3, but this is not yet supported in "
"Transformers. Please install the backwards-compatible tf-keras package with "
"`pip install tf-keras`."
)
def _gelu(x):
"""
Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see
https://arxiv.org/abs/1606.08415
"""
x = tf.convert_to_tensor(x)
cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))
return x * cdf
def _gelu_new(x):
"""
Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper: https://arxiv.org/abs/1606.08415
Args:
x: float Tensor to perform activation
Returns:
`x` with the GELU activation applied.
"""
x = tf.convert_to_tensor(x)
pi = tf.cast(math.pi, x.dtype)
coeff = tf.cast(0.044715, x.dtype)
cdf = 0.5 * (1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))))
return x * cdf
def mish(x):
x = tf.convert_to_tensor(x)
return x * tf.tanh(tf.math.softplus(x))
def gelu_fast(x):
x = tf.convert_to_tensor(x)
coeff1 = tf.cast(0.044715, x.dtype)
coeff2 = tf.cast(0.7978845608, x.dtype)
return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x)))
def quick_gelu(x):
x = tf.convert_to_tensor(x)
coeff = tf.cast(1.702, x.dtype)
return x * tf.math.sigmoid(coeff * x)
def gelu_10(x):
"""
Clip the range of possible GeLU outputs between [-10, 10]. This is especially useful for quantization purposes, as
it allows mapping two negative values in the GeLU spectrum. For more information on this trick, please refer to
https://arxiv.org/abs/2004.09602.
Args:
x: float Tensor to perform activation
Returns:
`x` with the GELU activation applied (via the `_gelu` helper defined above), clipped to the range [-10, 10].
"""
return tf.clip_by_value(_gelu(x), -10, 10)
def glu(x, axis=-1):
"""
Gated Linear Unit. Implementation as defined in the original paper (see https://arxiv.org/abs/1612.08083), where
the input `x` is split in two halves across a dimension (`axis`), A and B, returning A * sigmoid(B).
Args:
`x`: float Tensor to perform activation
`axis`: dimension across which `x` will be split in half
Returns:
`x` with the GLU activation applied (with its size halved across the dimension `axis`).
"""
a, b = tf.split(x, 2, axis=axis)
return a * tf.math.sigmoid(b)
if parse(tf.version.VERSION) >= parse("2.4"):
def approximate_gelu_wrap(x):
return keras.activations.gelu(x, approximate=True)
gelu = keras.activations.gelu
gelu_new = approximate_gelu_wrap
else:
gelu = _gelu
gelu_new = _gelu_new
ACT2FN = {
"gelu": gelu,
"gelu_10": gelu_10,
"gelu_fast": gelu_fast,
"gelu_new": gelu_new,
"glu": glu,
"mish": mish,
"quick_gelu": quick_gelu,
"relu": keras.activations.relu,
"sigmoid": keras.activations.sigmoid,
"silu": keras.activations.swish,
"swish": keras.activations.swish,
"tanh": keras.activations.tanh,
}
def get_tf_activation(activation_string):
if activation_string in ACT2FN:
return ACT2FN[activation_string]
else:
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
.\audio_utils.py
"""
Audio processing functions to extract features from audio waveforms. This code is pure numpy to support all frameworks
and remove unnecessary dependencies.
"""
import warnings
from typing import Optional, Tuple, Union
import numpy as np
def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
"""
Convert frequency from hertz to mels.
Args:
freq (`float` or `np.ndarray`):
The frequency, or multiple frequencies, in hertz (Hz).
mel_scale (`str`, *optional*, defaults to `"htk"`):
The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
Returns:
`float` or `np.ndarray`: The frequencies on the mel scale.
"""
if mel_scale not in ["slaney", "htk", "kaldi"]:
raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
if mel_scale == "htk":
return 2595.0 * np.log10(1.0 + (freq / 700.0))
elif mel_scale == "kaldi":
return 1127.0 * np.log(1.0 + (freq / 700.0))
min_log_hertz = 1000.0
min_log_mel = 15.0
logstep = 27.0 / np.log(6.4)
mels = 3.0 * freq / 200.0
if isinstance(freq, np.ndarray):
log_region = freq >= min_log_hertz
mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep
elif freq >= min_log_hertz:
mels = min_log_mel + np.log(freq / min_log_hertz) * logstep
return mels
def mel_to_hertz(mels: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
"""
Convert frequency from mels to hertz.
Args:
mels (`float` or `np.ndarray`):
The frequency, or multiple frequencies, in mels.
mel_scale (`str`, *optional*, defaults to `"htk"`):
The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
Returns:
`float` or `np.ndarray`: The frequencies in hertz.
"""
if mel_scale not in ["slaney", "htk", "kaldi"]:
raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
if mel_scale == "htk":
return 700.0 * (np.power(10, mels / 2595.0) - 1.0)
elif mel_scale == "kaldi":
return 700.0 * (np.exp(mels / 1127.0) - 1.0)
min_log_hertz = 1000.0
min_log_mel = 15.0
logstep = np.log(6.4) / 27.0
freq = 200.0 * mels / 3.0
if isinstance(mels, np.ndarray):
log_region = mels >= min_log_mel
freq[log_region] = min_log_hertz * np.exp(logstep * (mels[log_region] - min_log_mel))
elif mels >= min_log_mel:
freq = min_log_hertz * np.exp(logstep * (mels - min_log_mel))
return freq
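A small round-trip check (illustrative only) confirms that `mel_to_hertz` inverts `hertz_to_mel` for each supported scale:
```python
import numpy as np
from transformers.audio_utils import hertz_to_mel, mel_to_hertz

freqs = np.array([100.0, 440.0, 1000.0, 4000.0])
for scale in ("htk", "kaldi", "slaney"):
    mels = hertz_to_mel(freqs, mel_scale=scale)
    back = mel_to_hertz(mels, mel_scale=scale)
    assert np.allclose(back, freqs), scale
```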
def hertz_to_octave(
freq: Union[float, np.ndarray], tuning: Optional[float] = 0.0, bins_per_octave: Optional[int] = 12
):
"""
Convert frequency from hertz to fractional octave numbers.
Adapted from *librosa*.
Args:
freq (`float` or `np.ndarray`):
The frequency, or multiple frequencies, in hertz (Hz).
tuning (`float`, defaults to `0.`):
Tuning deviation from the Stuttgart pitch (A440) in (fractional) bins per octave.
bins_per_octave (`int`, defaults to `12`):
Number of bins per octave.
Returns:
`float` or `np.ndarray`: The frequencies on the octave scale.
"""
stuttgart_pitch = 440.0 * 2.0 ** (tuning / bins_per_octave)
octave = np.log2(freq / (float(stuttgart_pitch) / 16))
return octave
def _create_triangular_filter_bank(fft_freqs: np.ndarray, filter_freqs: np.ndarray) -> np.ndarray:
"""
Creates a triangular filter bank.
Adapted from *torchaudio* and *librosa*.
Args:
fft_freqs (`np.ndarray` of shape `(num_frequency_bins,)`):
Discrete frequencies of the FFT bins in Hz.
filter_freqs (`np.ndarray` of shape `(num_mel_filters,)`):
Center frequencies of the triangular filters to create, in Hz.
Returns:
`np.ndarray` of shape `(num_frequency_bins, num_mel_filters)`
"""
filter_diff = np.diff(filter_freqs)
slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
down_slopes = -slopes[:, :-2] / filter_diff[:-1]
up_slopes = slopes[:, 2:] / filter_diff[1:]
return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes))
def chroma_filter_bank(
num_frequency_bins: int,
num_chroma: int,
sampling_rate: int,
tuning: float = 0.0,
power: Optional[float] = 2.0,
weighting_parameters: Optional[Tuple[float]] = (5.0, 2),
start_at_c_chroma: Optional[bool] = True,
):
"""
Creates a chroma filter bank, i.e. a linear transformation to project spectrogram bins onto chroma bins.
Adapted from *librosa*.
"""
frequencies = np.linspace(0, sampling_rate, num_frequency_bins, endpoint=False)[1:]
freq_bins = num_chroma * hertz_to_octave(frequencies, tuning=tuning, bins_per_octave=num_chroma)
freq_bins = np.concatenate(([freq_bins[0] - 1.5 * num_chroma], freq_bins))
bins_width = np.concatenate((np.maximum(freq_bins[1:] - freq_bins[:-1], 1.0), [1]))
chroma_filters = np.subtract.outer(freq_bins, np.arange(0, num_chroma, dtype="d")).T
num_chroma2 = np.round(float(num_chroma) / 2)
chroma_filters = np.remainder(chroma_filters + num_chroma2 + 10 * num_chroma, num_chroma) - num_chroma2
chroma_filters = np.exp(-0.5 * (2 * chroma_filters / np.tile(bins_width, (num_chroma, 1))) ** 2)
if power is not None:
chroma_filters = chroma_filters / np.sum(chroma_filters**power, axis=0, keepdims=True) ** (1.0 / power)
if weighting_parameters is not None:
center, half_width = weighting_parameters
chroma_filters *= np.tile(
np.exp(-0.5 * (((freq_bins / num_chroma - center) / half_width) ** 2)),
(num_chroma, 1),
)
if start_at_c_chroma:
chroma_filters = np.roll(chroma_filters, -3 * (num_chroma // 12), axis=0)
return np.ascontiguousarray(chroma_filters[:, : int(1 + num_frequency_bins / 2)])
def mel_filter_bank(
num_frequency_bins: int,
num_mel_filters: int,
min_frequency: float,
max_frequency: float,
sampling_rate: int,
norm: Optional[str] = None,
mel_scale: str = "htk",
triangularize_in_mel_space: bool = False,
) -> np.ndarray:
"""
Creates a frequency bin conversion matrix used to obtain a mel spectrogram, known as a mel filter bank. Numerous
implementations exist, which differ in the number of filters, the shape of the filters, the way the filters are spaced,
the bandwidth of the filters, and the manner in which the spectrum is warped. These features are meant to approximate
the non-linear human perception of frequency changes.
Different mel filter bank variants have been introduced in the literature. The following variants are supported:
- MFCC FB-20: introduced by Davis and Mermelstein in 1980, assumes a sampling frequency of 10 kHz and a speech
bandwidth of `[0, 4600]` Hz.
- MFCC FB-24 HTK: from the Cambridge HMM Toolkit (HTK, 1995), uses a filter bank of 24 filters for a speech bandwidth
of `[0, 8000]` Hz. Assumes a sampling rate ≥ 16 kHz.
- MFCC FB-40: from Slaney's Auditory Toolbox for MATLAB (1998), assumes a sampling rate of 16 kHz and a speech
bandwidth of `[133, 6854]` Hz. This version also includes area normalization.
- HFCC-E FB-29 (Human Factor Cepstral Coefficients): introduced by Skowronski and Harris in 2004, assumes a sampling
rate of 12.5 kHz and a speech bandwidth of `[0, 6250]` Hz.
This code is adapted from *torchaudio* and *librosa*. Note that the default parameters of torchaudio's
`melscale_fbanks` implement the `"htk"` filters while librosa uses the `"slaney"` implementation.
Args:
num_frequency_bins (`int`):
Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
num_mel_filters (`int`):
Number of mel filters to generate.
min_frequency (`float`):
Lowest frequency of interest in Hz.
max_frequency (`float`):
Highest frequency of interest in Hz. This should not exceed `sampling_rate / 2`.
sampling_rate (`int`):
Sample rate of the audio waveform.
norm (`str`, *optional*):
If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization).
mel_scale (`str`, *optional*, defaults to `"htk"`):
The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
triangularize_in_mel_space (`bool`, *optional*, defaults to `False`):
If this option is enabled, the triangular filter is applied in mel space rather than in frequency space. Set this
to `True` in order to obtain the same results as `torchaudio` when computing mel filters.
Returns:
`np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): the triangular filter bank matrix used to project
a spectrogram onto mel bins.
"""
if norm is not None and norm != "slaney":
raise ValueError('norm must be one of None or "slaney"')
mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2)
filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale)
if triangularize_in_mel_space:
fft_bin_width = sampling_rate / (num_frequency_bins * 2)
fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale)
filter_freqs = mel_freqs
else:
fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)
mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)
if norm is not None and norm == "slaney":
enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters])
mel_filters *= np.expand_dims(enorm, 0)
if (mel_filters.max(axis=0) == 0.0).any():
warnings.warn(
"At least one mel filter has all zero values. "
f"The value for `num_mel_filters` ({num_mel_filters}) may be set too high. "
f"Or, the value for `num_frequency_bins` ({num_frequency_bins}) may be set too low."
)
return mel_filters
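For example, a bank of 80 Slaney-normalized filters for 16 kHz audio and a 512-point FFT could be built as follows (a sketch with assumed parameter values):
```python
from transformers.audio_utils import mel_filter_bank

mel_filters = mel_filter_bank(
    num_frequency_bins=257,   # (512 // 2) + 1 frequency bins for a 512-point FFT
    num_mel_filters=80,
    min_frequency=0.0,
    max_frequency=8000.0,     # Nyquist frequency for 16 kHz audio
    sampling_rate=16000,
    norm="slaney",
    mel_scale="slaney",
)
print(mel_filters.shape)      # (257, 80): one column per triangular mel filter
```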
def optimal_fft_length(window_length: int) -> int:
"""
Finds the best FFT input size for a given `window_length`. This function takes a given window length and, if not
already a power of two, rounds it up to the next power of two.
The FFT algorithm works fastest when the length of the input is a power of two, which may be larger than the size
of the window or analysis frame. For example, if the window is 400 samples, using an FFT input size of 512 samples
is more optimal than an FFT size of 400 samples. Using a larger FFT size does not affect the detected frequencies,
it simply gives a higher frequency resolution (i.e. the frequency bins are smaller).
"""
return 2 ** int(np.ceil(np.log2(window_length)))
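For instance, a 400-sample analysis window rounds up to a 512-point FFT:
```python
from transformers.audio_utils import optimal_fft_length

print(optimal_fft_length(400))   # 512: the next power of two
print(optimal_fft_length(512))   # 512: already a power of two
```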
def window_function(
window_length: int,
name: str = "hann",
periodic: bool = True,
frame_length: Optional[int] = None,
center: bool = True,
) -> np.ndarray:
"""
Returns an array containing the specified window. This window is intended to be used with `stft`.
The following window types are supported:
- `"boxcar"`: a rectangular window
- `"hamming"`: the Hamming window
- `"hann"`: the Hann window
- `"povey"`: the Povey window
Args:
window_length (`int`):
The length of the window in samples.
name (`str`, *optional*, defaults to `"hann"`):
The name of the window function.
periodic (`bool`, *optional*, defaults to `True`):
Whether the window is periodic or symmetric.
frame_length (`int`, *optional*):
The length of the analysis frames in samples. Provide a value for `frame_length` if the window is smaller
than the frame length, so that it will be zero-padded.
center (`bool`, *optional*, defaults to `True`):
Whether to center the window inside the FFT buffer. Only used when `frame_length` is provided.
Returns:
`np.ndarray` of shape `(window_length,)` or `(frame_length,)` containing the window.
"""
length = window_length + 1 if periodic else window_length
if name == "boxcar":
window = np.ones(length)
elif name in ["hamming", "hamming_window"]:
window = np.hamming(length)
elif name in ["hann", "hann_window"]:
window = np.hanning(length)
elif name in ["povey"]:
window = np.power(np.hanning(length), 0.85)
else:
raise ValueError(f"Unknown window function '{name}'")
if periodic:
window = window[:-1]
if frame_length is None:
return window
if window_length > frame_length:
raise ValueError(
f"Length of the window ({window_length}) may not be larger than frame_length ({frame_length})"
)
padded_window = np.zeros(frame_length)
offset = (frame_length - window_length) // 2 if center else 0
padded_window[offset : offset + window_length] = window
return padded_window
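As an illustration (assumed values), a 400-sample periodic Hann window can be zero-padded into a 512-sample frame:
```python
from transformers.audio_utils import window_function

window = window_function(window_length=400, name="hann", frame_length=512, center=True)
print(window.shape)        # (512,): the 400-sample Hann window centered in a 512-sample buffer
print(window[:56].max())   # 0.0: the zero-padded region before the window starts
```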
def spectrogram(
waveform: np.ndarray,
window: np.ndarray,
frame_length: int,
hop_length: int,
fft_length: Optional[int] = None,
power: Optional[float] = 1.0,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
preemphasis: Optional[float] = None,
mel_filters: Optional[np.ndarray] = None,
mel_floor: float = 1e-10,
log_mel: Optional[str] = None,
reference: float = 1.0,
min_value: float = 1e-10,
db_range: Optional[float] = None,
remove_dc_offset: Optional[bool] = None,
dtype: np.dtype = np.float32,
) -> np.ndarray:
"""
Calculates a spectrogram over one waveform using the Short-Time Fourier Transform.
This function can create the following kinds of spectrograms:
- amplitude spectrogram (`power = 1.0`)
- power spectrogram (`power = 2.0`)
- complex-valued spectrogram (`power = None`)
- log spectrogram (use `log_mel` argument)
- mel spectrogram (provide `mel_filters`)
- log-mel spectrogram (provide `mel_filters` and `log_mel`)
How this works:
1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `hop_length` samples.
2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
3. The DFT is taken of each windowed frame.
4. The results are stacked into a spectrogram.
We make a distinction between the following "blocks" of sample data, each of which may have different lengths:
- The analysis frame. This is the size of the time slices that the input waveform is split into.
- The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
- The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.
In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
typically the next power of two.
Note: This function is not optimized for speed yet. It should be mostly compatible with `librosa.stft` and
`torchaudio.functional.transforms.Spectrogram`, although it is more flexible due to the different ways spectrograms
can be constructed.
Returns:
`np.ndarray` containing a spectrogram of shape `(num_frequency_bins, length)` for a regular spectrogram or shape
`(num_mel_filters, length)` for a mel spectrogram.
"""
window_length = len(window)
if fft_length is None:
fft_length = frame_length
if frame_length > fft_length:
raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")
if window_length != frame_length:
raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")
if hop_length <= 0:
raise ValueError("hop_length must be greater than zero")
if waveform.ndim != 1:
raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")
if np.iscomplexobj(waveform):
raise ValueError("Complex-valued input waveforms are not currently supported")
if power is None and mel_filters is not None:
raise ValueError(
"You have provided `mel_filters` but `power` is `None`. Mel spectrogram computation is not yet supported for complex-valued spectrogram."
"Specify `power` to fix this issue."
)
if center:
padding = [(int(frame_length // 2), int(frame_length // 2))]
waveform = np.pad(waveform, padding, mode=pad_mode)
waveform = waveform.astype(np.float64)
window = window.astype(np.float64)
num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))
num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)
fft_func = np.fft.rfft if onesided else np.fft.fft
buffer = np.zeros(fft_length)
timestep = 0
for frame_idx in range(num_frames):
buffer[:frame_length] = waveform[timestep : timestep + frame_length]
if remove_dc_offset:
buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()
if preemphasis is not None:
buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]
buffer[0] *= 1 - preemphasis
buffer[:frame_length] *= window
spectrogram[frame_idx] = fft_func(buffer)
timestep += hop_length
if power is not None:
spectrogram = np.abs(spectrogram, dtype=np.float64) ** power
spectrogram = spectrogram.T
if mel_filters is not None:
spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))
if power is not None and log_mel is not None:
if log_mel == "log":
spectrogram = np.log(spectrogram)
elif log_mel == "log10":
spectrogram = np.log10(spectrogram)
elif log_mel == "dB":
if power == 1.0:
spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)
elif power == 2.0:
spectrogram = power_to_db(spectrogram, reference, min_value, db_range)
else:
raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
else:
raise ValueError(f"Unknown log_mel option: {log_mel}")
spectrogram = np.asarray(spectrogram, dtype)
return spectrogram
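Putting the pieces together, a log-mel spectrogram similar to what many feature extractors compute can be obtained as follows (an illustrative sketch; the parameter values are assumptions and vary per model):
```python
import numpy as np
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function

sampling_rate = 16000
waveform = np.random.randn(sampling_rate).astype(np.float32)   # one second of dummy audio

frame_length, hop_length, fft_length = 400, 160, 512
window = window_function(frame_length, "hann")
mel_filters = mel_filter_bank(
    num_frequency_bins=fft_length // 2 + 1,
    num_mel_filters=80,
    min_frequency=0.0,
    max_frequency=sampling_rate / 2,
    sampling_rate=sampling_rate,
    norm="slaney",
    mel_scale="slaney",
)

log_mel = spectrogram(
    waveform,
    window,
    frame_length=frame_length,
    hop_length=hop_length,
    fft_length=fft_length,
    power=2.0,             # power spectrogram
    mel_filters=mel_filters,
    log_mel="dB",          # converted with power_to_db since power == 2.0
)
print(log_mel.shape)       # (80, num_frames)
```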
def power_to_db(
spectrogram: np.ndarray,
reference: float = 1.0,
min_value: float = 1e-10,
db_range: Optional[float] = None,
) -> np.ndarray:
"""
Converts a power spectrogram to the decibel scale. This computes `10 * log10(spectrogram / reference)`, using basic
logarithm properties for numerical stability.
The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
This means that large variations in energy may not sound all that different if the sound is loud to begin with.
This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
Based on the implementation of `librosa.power_to_db`.
Args:
spectrogram (`np.ndarray`):
The input power (mel) spectrogram. Note that a power spectrogram has the amplitudes squared!
reference (`float`, *optional*, defaults to 1.0):
Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
the loudest part to 0 dB. Must be greater than zero.
min_value (`float`, *optional*, defaults to `1e-10`):
The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
`log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero.
db_range (`float`, *optional*):
Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
Returns:
`np.ndarray`: the spectrogram in decibels
"""
if reference <= 0.0:
raise ValueError("reference must be greater than zero")
if min_value <= 0.0:
raise ValueError("min_value must be greater than zero")
reference = max(min_value, reference)
spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference))
if db_range is not None:
if db_range <= 0.0:
raise ValueError("db_range must be greater than zero")
spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
return spectrogram
def amplitude_to_db(
spectrogram: np.ndarray,
reference: float = 1.0,
min_value: float = 1e-5,
db_range: Optional[float] = None,
) -> np.ndarray:
"""
Converts an amplitude spectrogram to the decibel scale. This computes `20 * log10(spectrogram / reference)`, using
basic logarithm properties for numerical stability.
"""
if reference <= 0.0:
raise ValueError("reference must be greater than zero")
if min_value <= 0.0:
raise ValueError("min_value must be greater than zero")
reference = max(min_value, reference)
spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference))
if db_range is not None:
if db_range <= 0.0:
raise ValueError("db_range must be greater than zero")
spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
return spectrogram
def get_mel_filter_banks(
nb_frequency_bins: int,
nb_mel_filters: int,
frequency_min: float,
frequency_max: float,
sample_rate: int,
norm: Optional[str] = None,
mel_scale: str = "htk",
) -> np.array:
warnings.warn(
"The function `get_mel_filter_banks` is deprecated and will be removed in version 4.31.0 of Transformers",
FutureWarning,
)
return mel_filter_bank(
num_frequency_bins=nb_frequency_bins,
num_mel_filters=nb_mel_filters,
min_frequency=frequency_min,
max_frequency=frequency_max,
sampling_rate=sample_rate,
norm=norm,
mel_scale=mel_scale,
)
def fram_wave(waveform: np.array, hop_length: int = 160, fft_window_size: int = 400, center: bool = True):
"""
In order to compute the Short-Time Fourier Transform, the waveform needs to be split into overlapping windowed
segments called `frames`.
Args:
waveform (`np.array` of shape `(sample_length,)`):
The raw waveform which will be split into smaller chunks.
hop_length (`int`, *optional*, defaults to 160):
Step between each window of the waveform.
fft_window_size (`int`, *optional*, defaults to 400):
Defines the size of the window.
center (`bool`, defaults to `True`):
Whether or not to center each frame around the middle of the frame. Centering is done by reflecting the
waveform on the left and on the right.
Return:
framed_waveform (`np.array` of shape `(waveform.shape // hop_length , fft_window_size)`):
The framed waveform that can be fed to `np.fft`.
"""
warnings.warn(
"The function `fram_wave` is deprecated and will be removed in version 4.31.0 of Transformers",
FutureWarning,
)
frames = []
for i in range(0, waveform.shape[0] + 1, hop_length):
if center:
half_window = (fft_window_size - 1) // 2 + 1
start = i - half_window if i > half_window else 0
end = i + half_window if i < waveform.shape[0] - half_window else waveform.shape[0]
frame = waveform[start:end]
if start == 0:
padd_width = (-i + half_window, 0)
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
elif end == waveform.shape[0]:
padd_width = (0, (i - waveform.shape[0] + half_window))
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
else:
frame = waveform[i : i + fft_window_size]
frame_width = frame.shape[0]
if frame_width < waveform.shape[0]:
frame = np.lib.pad(
frame, pad_width=(0, fft_window_size - frame_width), mode="constant", constant_values=0
)
frames.append(frame)
frames = np.stack(frames, 0)
return frames
def stft(frames: np.array, windowing_function: np.array, fft_window_size: int = None):
"""
Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same results
as `torch.stft`.
Args:
frames (`np.array` of dimension `(num_frames, fft_window_size)`):
A framed audio signal obtained using `audio_utils.fram_wav`.
windowing_function (`np.array` of dimension `(nb_frequency_bins, nb_mel_filters)`):
An array representing the function used to reduce amplitude discontinuities at frame boundaries when computing STFT.
Each frame is multiplied by this windowing function. For details on these discontinuities (Spectral leakage),
refer to [this tutorial](https://download.ni.com/evaluation/pxi/Understanding%20FFTs%20and%20Windowing.pdf).
fft_window_size (`int`, *optional*):
Size of the window on which the Fourier transform is applied, controlling frequency resolution of the spectrogram.
Default is `None`, where it defaults to `frame_size`. Increasing `fft_window_size` slows computation but improves resolution.
Example:
```
>>> from transformers.audio_utils import stft, fram_wave
>>> import numpy as np
>>> audio = np.random.rand(50)
>>> fft_window_size = 10
>>> hop_length = 2
>>> framed_audio = fram_wave(audio, hop_length, fft_window_size)
>>> spectrogram = stft(framed_audio, np.hanning(fft_window_size + 1))
```
Returns:
spectrogram (`np.ndarray`):
Spectrogram of shape `(num_frames, nb_frequency_bins)` obtained using STFT algorithm
"""
warnings.warn(
"The function `stft` is deprecated and will be removed in version 4.31.0 of Transformers",
FutureWarning,
)
frame_size = frames.shape[1]
if fft_window_size is None:
fft_window_size = frame_size
if fft_window_size < frame_size:
raise ValueError("FFT size must be greater or equal to the frame size")
nb_frequency_bins = (fft_window_size >> 1) + 1
spectrogram = np.empty((len(frames), nb_frequency_bins), dtype=np.complex64)
fft_signal = np.zeros(fft_window_size)
for f, frame in enumerate(frames):
if windowing_function is not None:
np.multiply(frame, windowing_function, out=fft_signal[:frame_size])
else:
fft_signal[:frame_size] = frame
spectrogram[f] = np.fft.fft(fft_signal, axis=0)[:nb_frequency_bins]
return spectrogram.T
.\benchmark\benchmark.py
"""
Benchmarking the library on inference and training in PyTorch.
"""
import timeit
from typing import Callable, Optional
from ..configuration_utils import PretrainedConfig
from ..models.auto.modeling_auto import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING
from ..utils import is_py3nvml_available, is_torch_available, logging
from .benchmark_utils import (
Benchmark,
Memory,
MemorySummary,
measure_peak_memory_cpu,
start_memory_tracing,
stop_memory_tracing,
)
if is_torch_available():
import torch
from .benchmark_args import PyTorchBenchmarkArguments
if is_py3nvml_available():
import py3nvml.py3nvml as nvml
logger = logging.get_logger(__name__)
class PyTorchBenchmark(Benchmark):
args: PyTorchBenchmarkArguments
configs: PretrainedConfig
framework: str = "PyTorch"
@property
def framework_version(self):
return torch.__version__
def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
_inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
return self._measure_speed(_inference)
def _inference_memory(
self, model_name: str, batch_size: int, sequence_length: int
) -> [Memory, Optional[MemorySummary]]:
_inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
return self._measure_memory(_inference)
def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
_train = self._prepare_train_func(model_name, batch_size, sequence_length)
return self._measure_speed(_train)
def _train_memory(
self, model_name: str, batch_size: int, sequence_length: int
) -> [Memory, Optional[MemorySummary]]:
_train = self._prepare_train_func(model_name, batch_size, sequence_length)
return self._measure_memory(_train)
def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
config = self.config_dict[model_name]
if self.args.torchscript:
config.torchscript = True
has_model_class_in_config = (
hasattr(config, "architectures")
and isinstance(config.architectures, list)
and len(config.architectures) > 0
)
if not self.args.only_pretrain_model and has_model_class_in_config:
try:
model_class = config.architectures[0]
transformers_module = __import__("transformers", fromlist=[model_class])
model_cls = getattr(transformers_module, model_class)
model = model_cls(config)
except ImportError:
raise ImportError(
f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
" set `--only_pretrain_model` or `args.only_pretrain_model=True`."
)
else:
model = MODEL_MAPPING[config.__class__](config)
model.eval()
model.to(self.args.device)
vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
input_ids = torch.randint(vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device)
if self.args.fp16:
logger.info("Running training in Mixed Precision...")
if not self.args.is_gpu:
raise ValueError("Mixed precision is possible only for GPU.")
model.half()
if self.args.torchscript:
with torch.no_grad():
inference_model = torch.jit.trace(model, input_ids)
else:
inference_model = model
def encoder_decoder_forward():
with torch.no_grad():
outputs = inference_model(input_ids, decoder_input_ids=input_ids)
return outputs
def encoder_forward():
with torch.no_grad():
outputs = inference_model(input_ids)
return outputs
_forward = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward
return _forward
def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
config = self.config_dict[model_name]
has_model_class_in_config = (
hasattr(config, "architectures")
and isinstance(config.architectures, list)
and len(config.architectures) > 0
)
if not self.args.only_pretrain_model and has_model_class_in_config:
try:
model_class = config.architectures[0]
transformers_module = __import__("transformers", fromlist=[model_class])
model_cls = getattr(transformers_module, model_class)
model = model_cls(config)
except ImportError:
raise ImportError(
f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
" set `--only_pretrain_model` or `args.only_pretrain_model=True`."
)
else:
model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
if self.args.torchscript:
raise NotImplementedError("Training for torchscript is currently not implemented")
else:
train_model = model
model.train()
model.to(self.args.device)
vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
input_ids = torch.randint(vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device)
if self.args.fp16:
logger.info("Running training in Mixed Precision...")
if not self.args.is_gpu:
raise ValueError("Mixed precision is possible only for GPU.")
model.half()
def compute_loss_and_backprob_encoder():
loss = train_model(input_ids, labels=input_ids)[0]
loss.backward()
return loss
def compute_loss_and_backprob_encoder_decoder():
loss = train_model(input_ids, decoder_input_ids=input_ids, labels=input_ids)[0]
loss.backward()
return loss
_train = (
compute_loss_and_backprob_encoder_decoder
if config.is_encoder_decoder
else compute_loss_and_backprob_encoder
)
return _train
def _measure_speed(self, func) -> float:
try:
logger.info("Do inference on TPU or torchscript. Running model 5 times to stabilize compilation")
timeit.repeat(
func,
repeat=1,
number=5,
)
runtimes = timeit.repeat(
func,
repeat=self.args.repeat,
number=10,
)
if self.args.is_tpu and self.args.torch_xla_tpu_print_metrics:
import torch_xla.debug.metrics as met
self.print_fn(met.metrics_report())
return min(runtimes) / 10.0
except RuntimeError as e:
self.print_fn(f"Doesn't fit on GPU. {e}")
return "N/A"
def _measure_memory(self, func: Callable[[], None]) -> [Memory, MemorySummary]:
try:
if self.args.trace_memory_line_by_line:
trace = start_memory_tracing("transformers")
if self.args.is_tpu:
raise NotImplementedError(
"Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with"
" `--no-memory` or `args.memory=False`"
)
elif self.args.is_gpu:
if not is_py3nvml_available():
logger.warning(
"py3nvml not installed, we won't log GPU memory usage. "
"Install py3nvml (pip install py3nvml) to log information about GPU."
)
memory = "N/A"
else:
logger.info(
"Measuring total GPU usage on GPU device. Make sure to not have additional processes running"
" on the same GPU."
)
nvml.nvmlInit()
func()
handle = nvml.nvmlDeviceGetHandleByIndex(self.args.device_idx)
meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
max_bytes_in_use = meminfo.used
memory = Memory(max_bytes_in_use)
nvml.nvmlShutdown()
else:
memory_bytes = measure_peak_memory_cpu(func)
memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes
if self.args.trace_memory_line_by_line:
summary = stop_memory_tracing(trace)
else:
summary = None
return memory, summary
except RuntimeError as e:
self.print_fn(f"Doesn't fit on GPU. {e}")
return "N/A", None
.\benchmark\benchmark_args.py
from dataclasses import dataclass, field
from typing import Tuple
from ..utils import (
cached_property,
is_torch_available,
is_torch_xla_available,
is_torch_xpu_available,
logging,
requires_backends,
)
from .benchmark_args_utils import BenchmarkArguments
if is_torch_available():
import torch
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
logger = logging.get_logger(__name__)
@dataclass
class PyTorchBenchmarkArguments(BenchmarkArguments):
deprecated_args = [
"no_inference",
"no_cuda",
"no_tpu",
"no_speed",
"no_memory",
"no_env_print",
"no_multi_process",
]
def __init__(self, **kwargs):
"""
This __init__ method exists for backward compatibility. Once the deprecated arguments are fully removed, this class can simply be deleted.
"""
for deprecated_arg in self.deprecated_args:
if deprecated_arg in kwargs:
positive_arg = deprecated_arg[3:]
setattr(self, positive_arg, not kwargs.pop(deprecated_arg))
logger.warning(
f"{deprecated_arg} is depreciated. Please use --no_{positive_arg} or"
f" {positive_arg}={kwargs[positive_arg]}"
)
self.torchscript = kwargs.pop("torchscript", self.torchscript)
self.torch_xla_tpu_print_metrics = kwargs.pop("torch_xla_tpu_print_metrics", self.torch_xla_tpu_print_metrics)
self.fp16_opt_level = kwargs.pop("fp16_opt_level", self.fp16_opt_level)
super().__init__(**kwargs)
@cached_property
def _setup_devices(self) -> Tuple["torch.device", int]:
requires_backends(self, ["torch"])
logger.info("PyTorch: setting up devices")
if not self.cuda:
device = torch.device("cpu")
n_gpu = 0
elif is_torch_xla_available():
device = xm.xla_device()
n_gpu = 0
elif is_torch_xpu_available():
device = torch.device("xpu")
n_gpu = torch.xpu.device_count()
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
return device, n_gpu
@property
def is_tpu(self):
return is_torch_xla_available() and self.tpu
@property
def device_idx(self) -> int:
requires_backends(self, ["torch"])
return torch.cuda.current_device()
@property
def device(self) -> "torch.device":
requires_backends(self, ["torch"])
return self._setup_devices[0]
@property
def n_gpu(self):
requires_backends(self, ["torch"])
return self._setup_devices[1]
@property
def is_gpu(self):
return self.n_gpu > 0
.\benchmark\benchmark_args_tf.py
from dataclasses import dataclass, field
from typing import Tuple
from ..utils import cached_property, is_tf_available, logging, requires_backends
from .benchmark_args_utils import BenchmarkArguments
if is_tf_available():
import tensorflow as tf
logger = logging.get_logger(__name__)
@dataclass
class TensorFlowBenchmarkArguments(BenchmarkArguments):
deprecated_args = [
"no_inference",
"no_cuda",
"no_tpu",
"no_speed",
"no_memory",
"no_env_print",
"no_multi_process",
]
def __init__(self, **kwargs):
"""
This __init__ handles the deprecated arguments. Once the deprecated arguments are fully removed, this method and the corresponding code can be deleted.
"""
for deprecated_arg in self.deprecated_args:
if deprecated_arg in kwargs:
positive_arg = deprecated_arg[3:]
kwargs[positive_arg] = not kwargs.pop(deprecated_arg)
logger.warning(
f"{deprecated_arg} is deprecated. Please use --no-{positive_arg} or "
f"{positive_arg}={kwargs[positive_arg]}"
)
self.tpu_name = kwargs.pop("tpu_name", self.tpu_name)
self.device_idx = kwargs.pop("device_idx", self.device_idx)
self.eager_mode = kwargs.pop("eager_mode", self.eager_mode)
self.use_xla = kwargs.pop("use_xla", self.use_xla)
super().__init__(**kwargs)
tpu_name: str = field(
default=None,
metadata={"help": "Name of TPU"},
)
device_idx: int = field(
default=0,
metadata={"help": "CPU / GPU device index. Defaults to 0."},
)
eager_mode: bool = field(default=False, metadata={"help": "Benchmark models in eager mode."})
use_xla: bool = field(
default=False,
metadata={
"help": "Benchmark models using XLA JIT compilation. Note that `eager_model` has to be set to `False`."
},
)
@cached_property
def _setup_tpu(self) -> Tuple["tf.distribute.cluster_resolver.TPUClusterResolver"]:
requires_backends(self, ["tf"])
tpu = None
if self.tpu:
try:
if self.tpu_name:
tpu = tf.distribute.cluster_resolver.TPUClusterResolver(self.tpu_name)
else:
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
tpu = None
return tpu
@cached_property
def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.cluster_resolver.TPUClusterResolver"]:
requires_backends(self, ["tf"])
if self.is_tpu:
tf.config.experimental_connect_to_cluster(self._setup_tpu)
tf.tpu.experimental.initialize_tpu_system(self._setup_tpu)
strategy = tf.distribute.TPUStrategy(self._setup_tpu)
else:
if self.is_gpu:
tf.config.set_visible_devices(self.gpu_list[self.device_idx], "GPU")
strategy = tf.distribute.OneDeviceStrategy(device=f"/gpu:{self.device_idx}")
else:
tf.config.set_visible_devices([], "GPU")
strategy = tf.distribute.OneDeviceStrategy(device=f"/cpu:{self.device_idx}")
return strategy
@property
def is_tpu(self) -> bool:
requires_backends(self, ["tf"])
return self._setup_tpu is not None
@property
def strategy(self) -> "tf.distribute.Strategy":
requires_backends(self, ["tf"])
return self._setup_strategy
@property
def gpu_list(self):
requires_backends(self, ["tf"])
return tf.config.list_physical_devices("GPU")
@property
def n_gpu(self) -> int:
requires_backends(self, ["tf"])
if self.cuda:
return len(self.gpu_list)
return 0
@property
def is_gpu(self) -> bool:
return self.n_gpu > 0
.\benchmark\benchmark_args_utils.py
import dataclasses
import json
import warnings
from dataclasses import dataclass, field
from time import time
from typing import List
from ..utils import logging
logger = logging.get_logger(__name__)
def list_field(default=None, metadata=None):
return field(default_factory=lambda: default, metadata=metadata)
@dataclass
class BenchmarkArguments:
"""
BenchmarkArguments are arguments we use in our benchmark scripts **which relate to the training loop itself**.
Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify them on the command
line.
"""
models: List[str] = list_field(
default=[],
metadata={
"help": (
"Model checkpoints to be provided to the AutoModel classes. Leave blank to benchmark the base version"
" of all available models"
)
},
)
batch_sizes: List[int] = list_field(
default=[8], metadata={"help": "List of batch sizes for which memory and time performance will be evaluated"}
)
sequence_lengths: List[int] = list_field(
default=[8, 32, 128, 512],
metadata={"help": "List of sequence lengths for which memory and time performance will be evaluated"},
)
inference: bool = field(
default=True,
metadata={"help": "Whether to benchmark inference of model. Inference can be disabled via --no-inference."},
)
cuda: bool = field(
default=True,
metadata={"help": "Whether to run on available cuda devices. Cuda can be disabled via --no-cuda."},
)
tpu: bool = field(
default=True, metadata={"help": "Whether to run on available tpu devices. TPU can be disabled via --no-tpu."}
)
fp16: bool = field(default=False, metadata={"help": "Use FP16 to accelerate inference."})
training: bool = field(default=False, metadata={"help": "Benchmark training of model"})
verbose: bool = field(default=False, metadata={"help": "Verbose memory tracing"})
speed: bool = field(
default=True,
metadata={"help": "Whether to perform speed measurements. Speed measurements can be disabled via --no-speed."},
)
memory: bool = field(
default=True,
metadata={
"help": "Whether to perform memory measurements. Memory measurements can be disabled via --no-memory"
},
)
trace_memory_line_by_line: bool = field(default=False, metadata={"help": "Trace memory line by line"})
save_to_csv: bool = field(default=False, metadata={"help": "Save result to a CSV file"})
log_print: bool = field(default=False, metadata={"help": "Save all print statements in a log file"})
env_print: bool = field(default=False, metadata={"help": "Whether to print environment information"})
multi_process: bool = field(
default=True,
metadata={
"help": (
"Whether to use multiprocessing for memory and speed measurement. It is highly recommended to use"
" multiprocessing for accurate CPU and GPU memory measurements. This option should only be disabled"
" for debugging / testing and on TPU."
)
},
)
inference_time_csv_file: str = field(
default=f"inference_time_{round(time())}.csv",
metadata={"help": "CSV filename used if saving time results to csv."},
)
inference_memory_csv_file: str = field(
default=f"inference_memory_{round(time())}.csv",
metadata={"help": "CSV filename used if saving memory results to csv."},
)
train_time_csv_file: str = field(
default=f"train_time_{round(time())}.csv",
metadata={"help": "CSV filename used if saving time results to csv for training."},
)
train_memory_csv_file: str = field(
default=f"train_memory_{round(time())}.csv",
metadata={"help": "CSV filename used if saving memory results to csv for training."},
)
env_info_csv_file: str = field(
default=f"env_info_{round(time())}.csv",
metadata={"help": "CSV filename used if saving environment information."},
)
log_filename: str = field(
default=f"log_{round(time())}.csv",
metadata={"help": "Log filename used if print statements are saved in log."},
)
repeat: int = field(default=3, metadata={"help": "Times an experiment will be run."})
only_pretrain_model: bool = field(
default=False,
metadata={
"help": (
"Instead of loading the model as defined in `config.architectures` if exists, just load the pretrain"
" model weights."
)
},
)
def __post_init__(self):
warnings.warn(
f"The class {self.__class__} is deprecated. Hugging Face Benchmarking utils"
" are deprecated in general and it is advised to use external Benchmarking libraries "
" to benchmark Transformer models.",
FutureWarning,
)
def to_json_string(self):
"""
Serializes this instance to a JSON string.
"""
return json.dumps(dataclasses.asdict(self), indent=2)
def model_names(self) -> List[str]:
if len(self.models) <= 0:
raise ValueError(
"Please make sure you provide at least one model name / model identifier, *e.g.* `--models"
" google-bert/bert-base-cased` or `args.models = ['google-bert/bert-base-cased']."
)
return self.models
@property
def do_multi_processing(self):
if not self.multi_process:
return False
elif self.is_tpu:
logger.info("Multiprocessing is currently not possible on TPU.")
return False
else:
return True
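Because this is a dataclass, it can be exposed on the command line via `HfArgumentParser`, as the class docstring notes; a minimal sketch:
```python
from transformers import HfArgumentParser, PyTorchBenchmarkArguments

parser = HfArgumentParser(PyTorchBenchmarkArguments)
# e.g. invoked as: python run_benchmark.py --models google-bert/bert-base-cased --batch_sizes 1 --sequence_lengths 8
benchmark_args = parser.parse_args_into_dataclasses()[0]
print(benchmark_args.to_json_string())
```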
.\benchmark\benchmark_tf.py
"""
Benchmarking the library on inference and training in TensorFlow.
"""
import random
import timeit
from functools import wraps
from typing import Callable, Optional
from ..configuration_utils import PretrainedConfig
from ..models.auto.modeling_tf_auto import TF_MODEL_MAPPING, TF_MODEL_WITH_LM_HEAD_MAPPING
from ..utils import is_py3nvml_available, is_tf_available, logging
from .benchmark_utils import (
Benchmark,
Memory,
MemorySummary,
measure_peak_memory_cpu,
start_memory_tracing,
stop_memory_tracing,
)
if is_tf_available():
import tensorflow as tf
from tensorflow.python.framework.errors_impl import ResourceExhaustedError
from .benchmark_args_tf import TensorFlowBenchmarkArguments
if is_py3nvml_available():
import py3nvml.py3nvml as nvml
logger = logging.get_logger(__name__)
def run_with_tf_optimizations(do_eager_mode: bool, use_xla: bool):
"""
Returns a decorator that runs the wrapped TensorFlow function either eagerly or inside a compiled graph, depending on the arguments.
Args:
do_eager_mode (bool): whether to run with eager execution
use_xla (bool): whether to compile the graph with XLA
Returns:
Callable: a decorator that runs the given function in eager or graph mode
"""
def run_func(func):
@wraps(func)
def run_in_eager_mode(*args, **kwargs):
return func(*args, **kwargs)
@wraps(func)
@tf.function(experimental_compile=use_xla)
def run_in_graph_mode(*args, **kwargs):
return func(*args, **kwargs)
if do_eager_mode is True:
if use_xla is not False:
raise ValueError(
"Cannot run model in XLA, if `args.eager_mode` is set to `True`. Please set `args.eager_mode=False`."
)
return run_in_eager_mode
else:
return run_in_graph_mode
return run_func
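A self-contained sketch of how this decorator is meant to be used (toy function, assumes TensorFlow is installed):
```python
import tensorflow as tf

@run_with_tf_optimizations(do_eager_mode=False, use_xla=False)
def _toy_forward(x):
    return tf.reduce_sum(tf.square(x))

# The call below executes inside a tf.function (graph mode), as selected by the decorator above
print(_toy_forward(tf.constant([1.0, 2.0, 3.0])))   # tf.Tensor(14.0, ...)
```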
def random_input_ids(batch_size: int, sequence_length: int, vocab_size: int) -> ["tf.Tensor"]:
"""
Generates a tensor of random integer IDs of the given shape and vocabulary range, to be used as input IDs.
Args:
batch_size (int): batch size
sequence_length (int): sequence length
vocab_size (int): vocabulary size
Returns:
tf.Tensor: a random integer tensor of shape (batch_size, sequence_length)
"""
rng = random.Random()
values = [rng.randint(0, vocab_size - 1) for i in range(batch_size * sequence_length)]
return tf.constant(values, shape=(batch_size, sequence_length), dtype=tf.int32)
class TensorFlowBenchmark(Benchmark):
"""
Benchmark class for TensorFlow, inheriting from the Benchmark base class.
"""
args: TensorFlowBenchmarkArguments
configs: PretrainedConfig
framework: str = "TensorFlow"
@property
def framework_version(self):
"""
Returns the version of TensorFlow currently in use.
Returns:
str: the TensorFlow version string
"""
return tf.__version__
def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
strategy = self.args.strategy
if strategy is None:
raise ValueError("A device strategy has to be initialized before using TensorFlow.")
_inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
return self._measure_speed(_inference)
def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
strategy = self.args.strategy
if strategy is None:
raise ValueError("A device strategy has to be initialized before using TensorFlow.")
_train = self._prepare_train_func(model_name, batch_size, sequence_length)
return self._measure_speed(_train)
def _inference_memory(
self, model_name: str, batch_size: int, sequence_length: int
) -> [Memory, Optional[MemorySummary]]:
if self.args.is_gpu:
tf.config.experimental.set_memory_growth(self.args.gpu_list[self.args.device_idx], True)
strategy = self.args.strategy
if strategy is None:
raise ValueError("A device strategy has to be initialized before using TensorFlow.")
_inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
return self._measure_memory(_inference)
def _train_memory(
self, model_name: str, batch_size: int, sequence_length: int
) -> [Memory, Optional[MemorySummary]]:
if self.args.is_gpu:
tf.config.experimental.set_memory_growth(self.args.gpu_list[self.args.device_idx], True)
strategy = self.args.strategy
if strategy is None:
raise ValueError("A device strategy has to be initialized before using TensorFlow.")
_train = self._prepare_train_func(model_name, batch_size, sequence_length)
return self._measure_memory(_train)
def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
config = self.config_dict[model_name]
if self.args.fp16:
raise NotImplementedError("Mixed precision is currently not supported.")
has_model_class_in_config = (
hasattr(config, "architectures")
and isinstance(config.architectures, list)
and len(config.architectures) > 0
)
if not self.args.only_pretrain_model and has_model_class_in_config:
try:
model_class = "TF" + config.architectures[0]
transformers_module = __import__("transformers", fromlist=[model_class])
model_cls = getattr(transformers_module, model_class)
model = model_cls(config)
except ImportError:
raise ImportError(
f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
" set `--only_pretrain_model` or `args.only_pretrain_model=True`."
)
else:
model = TF_MODEL_MAPPING[config.__class__](config)
vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
input_ids = random_input_ids(batch_size, sequence_length, vocab_size)
@run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)
def encoder_decoder_forward():
return model(input_ids, decoder_input_ids=input_ids, training=False)
@run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)
def encoder_forward():
return model(input_ids, training=False)
_inference = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward
return _inference
def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
config = self.config_dict[model_name]
if self.args.eager_mode is not False:
raise ValueError("Training cannot be done in eager mode. Please make sure that `args.eager_mode = False`.")
if self.args.fp16:
raise NotImplementedError("Mixed precision is currently not supported.")
has_model_class_in_config = (
hasattr(config, "architectures")
and isinstance(config.architectures, list)
and len(config.architectures) > 0
)
if not self.args.only_pretrain_model and has_model_class_in_config:
try:
model_class = "TF" + config.architectures[0]
transformers_module = __import__("transformers", fromlist=[model_class])
model_cls = getattr(transformers_module, model_class)
model = model_cls(config)
except ImportError:
raise ImportError(
f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
" set `--only_pretrain_model` or `args.only_pretrain_model=True`."
)
else:
model = TF_MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
input_ids = random_input_ids(batch_size, sequence_length, vocab_size)
@run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)
def encoder_decoder_train():
loss = model(input_ids, decoder_input_ids=input_ids, labels=input_ids, training=True)[0]
gradients = tf.gradients(loss, model.trainable_variables)
return gradients
@run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)
def encoder_train():
loss = model(input_ids, labels=input_ids, training=True)[0]
gradients = tf.gradients(loss, model.trainable_variables)
return gradients
_train = encoder_decoder_train if config.is_encoder_decoder else encoder_train
return _train
def _measure_speed(self, func) -> float:
with self.args.strategy.scope():
try:
if self.args.is_tpu or self.args.use_xla:
logger.info("Do inference on TPU. Running model 5 times to stabilize compilation")
timeit.repeat(func, repeat=1, number=5)
runtimes = timeit.repeat(
func,
repeat=self.args.repeat,
number=10,
)
return min(runtimes) / 10.0
except ResourceExhaustedError as e:
self.print_fn(f"Doesn't fit on GPU. {e}")
.\benchmark\benchmark_utils.py
import copy
import csv
import linecache
import os
import platform
import sys
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict, namedtuple
from datetime import datetime
from multiprocessing import Pipe, Process, Queue
from multiprocessing.connection import Connection
from typing import Callable, Iterable, List, NamedTuple, Optional, Union
from .. import AutoConfig, PretrainedConfig
from .. import __version__ as version
from ..utils import (
is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available, logging
)
from .benchmark_args_utils import BenchmarkArguments
if is_torch_available():
from torch.cuda import empty_cache as torch_empty_cache
if is_tf_available():
from tensorflow.python.eager import context as tf_context
if is_psutil_available():
import psutil
if is_py3nvml_available():
import py3nvml.py3nvml as nvml
if platform.system() == "Windows":
from signal import CTRL_C_EVENT as SIGKILL
else:
from signal import SIGKILL
logger = logging.get_logger(__name__)
_is_memory_tracing_enabled = False
BenchmarkOutput = namedtuple(
"BenchmarkOutput",
[
"time_inference_result",
"memory_inference_result",
"time_train_result",
"memory_train_result",
"inference_summary",
"train_summary",
],
)
def separate_process_wrapper_fn(func: Callable[[], None], do_multi_processing: bool) -> Callable[[], None]:
"""
This function wraps another function into its own separated process. In order to ensure accurate memory
measurements it is important that the function is executed in a separate process
Args:
- `func`: (`callable`): function() -> ... generic function which will be executed in its own separate process
- `do_multi_processing`: (`bool`) Whether to run function on separate process or not
"""
def multi_process_func(*args, **kwargs):
def wrapper_func(queue: Queue, *args):
try:
result = func(*args)
except Exception as e:
logger.error(e)
print(e)
result = "N/A"
queue.put(result)
queue = Queue()
p = Process(target=wrapper_func, args=[queue] + list(args))
p.start()
result = queue.get()
p.join()
return result
if do_multi_processing:
logger.info(f"Function {func} is executed in its own process...")
return multi_process_func
else:
return func
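As a rough usage sketch (the workload function below is illustrative, not from the library), wrapping a measurement with `separate_process_wrapper_fn` keeps its allocations from polluting the parent process:

def allocate_and_count():
    data = [0] * 10_000_000      # allocated in the child process only
    return len(data)

wrapped = separate_process_wrapper_fn(allocate_and_count, do_multi_processing=True)
print(wrapped())                 # 10000000; the parent's RSS stays flat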
def is_memory_tracing_enabled():
global _is_memory_tracing_enabled
return _is_memory_tracing_enabled
class Frame(NamedTuple):
"""
`Frame` is a NamedTuple used to gather the current frame state. `Frame` has the following fields:
- 'filename' (string): Name of the file currently executed
- 'module' (string): Name of the module currently executed
- 'line_number' (int): Number of the line currently executed
- 'event' (string): Event that triggered the tracing (default will be "line")
- 'line_text' (string): Text of the line in the python script
"""
filename: str
module: str
line_number: int
event: str
line_text: str
class UsedMemoryState(NamedTuple):
"""
`UsedMemoryState` is a named tuple with the following fields:
- 'frame': a `Frame` namedtuple storing information on the current tracing frame (current file, location in current file)
- 'cpu_memory': CPU RSS memory state *before* executing the line
- 'gpu_memory': GPU used memory *before* executing the line (sum over all GPUs, or only over the GPUs in `gpus_to_trace` if provided)
"""
frame: Frame
cpu_memory: int
gpu_memory: int
class Memory(NamedTuple):
"""
`Memory` is a NamedTuple with a single field `bytes`; a human-readable string of the number of mega bytes is returned by `__repr__`.
- `bytes` (integer): number of bytes
"""
bytes: int
def __repr__(self) -> str:
return str(bytes_to_mega_bytes(self.bytes))
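Since `bytes_to_mega_bytes` (defined later in this file) is a right shift by 20 bits, a `Memory` holding 3 MiB prints as `3`:

m = Memory(3 * 1024 ** 2)
print(m.bytes)   # 3145728
print(m)         # 3 (mega bytes, via __repr__)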
class MemoryState(NamedTuple):
"""
`MemoryState` is a named tuple listing frame + CPU/GPU memory with the following fields:
- `frame` (`Frame`): the current frame (see above)
- `cpu`: CPU memory consumed during the current frame, as a `Memory` named tuple
- `gpu`: GPU memory consumed during the current frame, as a `Memory` named tuple
- `cpu_gpu`: CPU + GPU memory consumed during the current frame, as a `Memory` named tuple
"""
frame: Frame
cpu: Memory
gpu: Memory
cpu_gpu: Memory
class MemorySummary(NamedTuple):
"""
`MemorySummary` is a named tuple with the following fields:
- `sequential`: a list of `MemoryState` namedtuples computed from the memory trace by subtracting the memory measured before each line from the memory measured after it.
- `cumulative`: a list of `MemoryState` namedtuples with the cumulative memory increase per line, summed over repeated executions and sorted from largest to smallest consumption.
- `current`: a list of `MemoryState` namedtuples with the current memory state at each traced line.
- `total`: the total memory increase during the full tracing, as a `Memory` named tuple.
"""
sequential: List[MemoryState]
cumulative: List[MemoryState]
current: List[MemoryState]
total: Memory
MemoryTrace = List[UsedMemoryState]
def measure_peak_memory_cpu(function: Callable[[], None], interval=0.5, device_idx=None) -> int:
"""
Measures the peak CPU memory consumption of a given `function`, running it for at least `interval` seconds and at most 20 * `interval` seconds.
This function is heavily inspired by the `memory_usage` function of the `memory_profiler` package:
https://github.com/pythonprofilers/memory_profiler/blob/895c4ac7a08020d66ae001e24067da6dcea42451/memory_profiler.py#L239
Args:
- `function`: (`callable`): function without any arguments whose memory consumption is measured
- `interval`: (`float`, `optional`, defaults to `0.5`): interval (in seconds) at which memory usage is measured
- `device_idx`: (`int`, `optional`, defaults to `None`): device id for which to measure GPU usage
Returns:
- `max_memory`: (`int`) peak memory consumption in bytes
"""
def get_cpu_memory(process_id: int) -> int:
"""
Measures the current CPU memory usage of a given `process_id`
Args:
- `process_id`: (`int`) process id for which memory is measured
Returns:
- `memory`: (`int`) consumed memory in bytes
"""
process = psutil.Process(process_id)
try:
meminfo_attr = "memory_info" if hasattr(process, "memory_info") else "get_memory_info"
memory = getattr(process, meminfo_attr)()[0]
except psutil.AccessDenied:
raise ValueError("Psutil 访问错误.")
return memory
if not is_psutil_available():
logger.warning(
"未安装 Psutil,无法记录 CPU 内存使用情况。"
"安装 Psutil (pip install psutil) 以使用 CPU 内存跟踪。"
)
max_memory = "N/A"
else:
class MemoryMeasureProcess(Process):
"""
`MemoryMeasureProcess` inherits from `Process` and overwrites its `run()` method. Used to measure the
memory usage of a process
"""
def __init__(self, process_id: int, child_connection: Connection, interval: float):
super().__init__()
self.process_id = process_id
self.interval = interval
self.connection = child_connection
self.num_measurements = 1
self.mem_usage = get_cpu_memory(self.process_id)
def run(self):
self.connection.send(0)
stop = False
while True:
self.mem_usage = max(self.mem_usage, get_cpu_memory(self.process_id))
self.num_measurements += 1
if stop:
break
stop = self.connection.poll(self.interval)
self.connection.send(self.mem_usage)
self.connection.send(self.num_measurements)
while True:
child_connection, parent_connection = Pipe()
mem_process = MemoryMeasureProcess(os.getpid(), child_connection, interval)
mem_process.start()
parent_connection.recv()
try:
function()
parent_connection.send(0)
max_memory = parent_connection.recv()
num_measurements = parent_connection.recv()
except Exception:
parent = psutil.Process(os.getpid())
for child in parent.children(recursive=True):
os.kill(child.pid, SIGKILL)
mem_process.join(0)
raise RuntimeError("Process killed. Error in Process")
mem_process.join(20 * interval)
if (num_measurements > 4) or (interval < 1e-6):
break
interval /= 10
return max_memory
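A usage sketch, assuming psutil is installed; the workload below is illustrative:

def workload():
    buf = bytearray(200 * 1024 ** 2)   # transient 200 MiB allocation
    del buf

peak_bytes = measure_peak_memory_cpu(workload, interval=0.1)
print(bytes_to_mega_bytes(peak_bytes), "MB peak RSS")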
def start_memory_tracing(
modules_to_trace: Optional[Union[str, Iterable[str]]] = None,
modules_not_to_trace: Optional[Union[str, Iterable[str]]] = None,
events_to_trace: str = "line",
gpus_to_trace: Optional[List[int]] = None,
) -> MemoryTrace:
"""
Setup line-by-line tracing to record the RAM usage of each line of a module or sub-module. See `./benchmark.py` for a usage example.
Args:
- `modules_to_trace`: (None, string, list/tuple of string) if None, all events are recorded; if a string or list of strings, only events from the listed module(s)/sub-module(s) are recorded (e.g. 'fairseq' or 'transformers.models.gpt2.modeling_gpt2')
- `modules_not_to_trace`: (None, string, list/tuple of string) if None, no module is excluded; if a string or list of strings, events from the listed module(s)/sub-module(s) are not recorded (e.g. 'torch')
- `events_to_trace`: string or list of strings of events to be recorded (see the official python documentation for `sys.settrace` for the list of events), defaults to "line"
- `gpus_to_trace`: (optional list, default None) list of GPUs to trace. Defaults to tracing all GPUs
Return:
- `memory_trace`: a list of `UsedMemoryState` for each event (default: one per line of the traced script).
`UsedMemoryState` is a named tuple with the following fields:
- 'frame': a `Frame` namedtuple (see below) storing information on the current tracing frame (current file, location in current file)
- 'cpu_memory': CPU RSS memory state *before* executing the line
- 'gpu_memory': GPU used memory *before* executing the line (sum over all GPUs, or only over the GPUs in `gpus_to_trace` if provided)
`Frame` is a namedtuple used by `UsedMemoryState` to list the current frame state. `Frame` has the following fields:
- 'filename' (string): Name of the file currently executed
- 'module' (string): Name of the module currently executed
- 'line_number' (int): Number of the line currently executed
- 'event' (string): Event that triggered the tracing (default will be "line")
- 'line_text' (string): Text of the line in the python script
"""
if is_psutil_available():
process = psutil.Process(os.getpid())
else:
logger.warning(
"Psutil not installed, we won't log CPU memory usage. "
"Install psutil (pip install psutil) to use CPU memory tracing."
)
process = None
if is_py3nvml_available():
try:
nvml.nvmlInit()
devices = list(range(nvml.nvmlDeviceGetCount())) if gpus_to_trace is None else gpus_to_trace
nvml.nvmlShutdown()
except (OSError, nvml.NVMLError):
logger.warning("Error while initializing communication with GPU. We won't perform GPU memory tracing.")
log_gpu = False
else:
log_gpu = is_torch_available() or is_tf_available()
else:
logger.warning(
"py3nvml not installed, we won't log GPU memory usage. "
"Install py3nvml (pip install py3nvml) to use GPU memory tracing."
)
log_gpu = False
memory_trace = []
def traceit(frame, event, args):
"""
Tracing method executed before running each line of a module or sub-module. Records the memory allocated in a list with debugging information.
"""
global _is_memory_tracing_enabled
if not _is_memory_tracing_enabled:
return traceit
if events_to_trace is not None:
if isinstance(events_to_trace, str) and event != events_to_trace:
return traceit
elif isinstance(events_to_trace, (list, tuple)) and event not in events_to_trace:
return traceit
if "__name__" not in frame.f_globals:
return traceit
name = frame.f_globals["__name__"]
if not isinstance(name, str):
return traceit
else:
if modules_to_trace is not None:
if isinstance(modules_to_trace, str) and modules_to_trace not in name:
return traceit
elif isinstance(modules_to_trace, (list, tuple)) and all(m not in name for m in modules_to_trace):
return traceit
if modules_not_to_trace is not None:
if isinstance(modules_not_to_trace, str) and modules_not_to_trace in name:
return traceit
elif isinstance(modules_not_to_trace, (list, tuple)) and any(m in name for m in modules_not_to_trace):
return traceit
lineno = frame.f_lineno
filename = frame.f_globals["__file__"]
if filename.endswith(".pyc") or filename.endswith(".pyo"):
filename = filename[:-1]
line = linecache.getline(filename, lineno).rstrip()
traced_state = Frame(filename, name, lineno, event, line)
cpu_mem = 0
if process is not None:
mem = process.memory_info()
cpu_mem = mem.rss
gpu_mem = 0
if log_gpu:
if is_torch_available():
torch_empty_cache()
if is_tf_available():
tf_context.context()._clear_caches()
nvml.nvmlInit()
for i in devices:
handle = nvml.nvmlDeviceGetHandleByIndex(i)
meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
gpu_mem += meminfo.used
nvml.nvmlShutdown()
mem_state = UsedMemoryState(traced_state, cpu_mem, gpu_mem)
memory_trace.append(mem_state)
return traceit
sys.settrace(traceit)
global _is_memory_tracing_enabled
_is_memory_tracing_enabled = True
return memory_trace
def stop_memory_tracing(
memory_trace: Optional[MemoryTrace] = None, ignore_released_memory: bool = True
) -> Optional[MemorySummary]:
"""
Stop memory tracing cleanly and return a summary of the memory trace if a trace was provided.
Args:
`memory_trace` (optional output of start_memory_tracing, default: None):
memory trace to convert into a summary
`ignore_released_memory` (boolean, default: True):
if True we only sum memory increases to compute the total memory
Return:
- None if `memory_trace` is None
- a `MemorySummary` namedtuple otherwise, with the fields:
- `sequential`: a list of `MemoryState` namedtuples computed from the provided `memory_trace` by subtracting the memory measured before executing each line from the memory measured after it.
- `cumulative`: a list of `MemoryState` namedtuples with the cumulative memory increase of each line, summing repeated increases if a line is executed several times. The list is sorted from the largest to the smallest memory consumption (which can be negative if memory is released); lines that release memory are ignored if `ignore_released_memory` is True (default).
- `current`: a list of `MemoryState` namedtuples with the current memory state after each traced line, sorted by consumption.
- `total`: the total memory increase during the full tracing, as a `Memory` named tuple.
`Memory` is a named tuple with the following fields:
- `bytes` (integer): number of bytes
- a human-readable string in mega bytes is returned by `__repr__`
`Frame` is a named tuple used to list the current frame state, with the following fields:
- 'filename' (string): Name of the file currently executed
- 'module' (string): Name of the module currently executed
- 'line_number' (int): Number of the line currently executed
- 'event' (string): Event that triggered the tracing (default will be "line")
- 'line_text' (string): Text of the line in the python script
`MemoryState` is a named tuple listing frame + CPU/GPU memory, with the following fields:
- `frame` (`Frame`): the current frame (see above)
- `cpu`: CPU memory consumed during the current frame, as a `Memory` named tuple
- `gpu`: GPU memory consumed during the current frame, as a `Memory` named tuple
- `cpu_gpu`: CPU + GPU memory consumed during the current frame, as a `Memory` named tuple
"""
global _is_memory_tracing_enabled
_is_memory_tracing_enabled = False
if memory_trace is not None and len(memory_trace) > 1:
memory_diff_trace = []
memory_curr_trace = []
cumulative_memory_dict = defaultdict(lambda: [0, 0, 0])
for (
(frame, cpu_mem, gpu_mem),
(next_frame, next_cpu_mem, next_gpu_mem),
) in zip(memory_trace[:-1], memory_trace[1:]):
cpu_mem_inc = next_cpu_mem - cpu_mem
gpu_mem_inc = next_gpu_mem - gpu_mem
cpu_gpu_mem_inc = cpu_mem_inc + gpu_mem_inc
memory_diff_trace.append(
MemoryState(
frame=frame,
cpu=Memory(cpu_mem_inc),
gpu=Memory(gpu_mem_inc),
cpu_gpu=Memory(cpu_gpu_mem_inc),
)
)
memory_curr_trace.append(
MemoryState(
frame=frame,
cpu=Memory(next_cpu_mem),
gpu=Memory(next_gpu_mem),
cpu_gpu=Memory(next_cpu_mem + next_gpu_mem),
)
)
cumulative_memory_dict[frame][0] += cpu_mem_inc
cumulative_memory_dict[frame][1] += gpu_mem_inc
cumulative_memory_dict[frame][2] += cpu_gpu_mem_inc
cumulative_memory = sorted(
cumulative_memory_dict.items(), key=lambda x: x[1][2], reverse=True
)
cumulative_memory = [
MemoryState(
frame=frame,
cpu=Memory(cpu_mem_inc),
gpu=Memory(gpu_mem_inc),
cpu_gpu=Memory(cpu_gpu_mem_inc),
)
for frame, (cpu_mem_inc, gpu_mem_inc, cpu_gpu_mem_inc) in cumulative_memory
]
memory_curr_trace = sorted(memory_curr_trace, key=lambda x: x.cpu_gpu.bytes, reverse=True)
if ignore_released_memory:
total_memory = sum(max(0, step_trace.cpu_gpu.bytes) for step_trace in memory_diff_trace)
else:
total_memory = sum(step_trace.cpu_gpu.bytes for step_trace in memory_diff_trace)
total_memory = Memory(total_memory)
return MemorySummary(
sequential=memory_diff_trace,
cumulative=cumulative_memory,
current=memory_curr_trace,
total=total_memory,
)
return None
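The two functions are meant to be used as a pair; a minimal sketch (the traced module name and the workload are illustrative):

trace = start_memory_tracing("transformers")
# ... run the code to profile here ...
summary = stop_memory_tracing(trace)
if summary is not None:
    print("Total memory increase:", summary.total)
    for state in summary.cumulative[:3]:
        print(state.frame.filename, state.frame.line_number, state.cpu_gpu)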
def bytes_to_mega_bytes(memory_amount: int) -> int:
"""Utility to convert a number of bytes (int) into a number of mega bytes (int)"""
return memory_amount >> 20
class Benchmark(ABC):
"""
Benchmarks is a simple but feature-complete benchmarking script to compare memory and time performance of models in
Transformers.
"""
args: BenchmarkArguments
configs: PretrainedConfig
framework: str
def __init__(self, args: BenchmarkArguments = None, configs: PretrainedConfig = None):
self.args = args
if configs is None:
self.config_dict = {
model_name: AutoConfig.from_pretrained(model_name) for model_name in self.args.model_names
}
else:
self.config_dict = dict(zip(self.args.model_names, configs))
warnings.warn(
f"The class {self.__class__} is deprecated. Hugging Face Benchmarking utils"
" are deprecated in general and it is advised to use external Benchmarking libraries "
" to benchmark Transformer models.",
FutureWarning,
)
if self.args.memory and os.getenv("TRANSFORMERS_USE_MULTIPROCESSING") == 0:
logger.warning(
"Memory consumption will not be measured accurately if `args.multi_process` is set to `False.` The"
" flag 'TRANSFORMERS_USE_MULTIPROCESSING' should only be disabled for debugging / testing."
)
self._print_fn = None
self._framework_version = None
self._environment_info = None
@property
def print_fn(self):
if self._print_fn is None:
if self.args.log_print:
def print_and_log(*args):
with open(self.args.log_filename, "a") as log_file:
log_file.write("".join(args) + "\n")
print(*args)
self._print_fn = print_and_log
else:
self._print_fn = print
return self._print_fn
@property
@abstractmethod
def framework_version(self):
pass
@abstractmethod
def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
pass
@abstractmethod
def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
pass
@abstractmethod
def _inference_memory(
self, model_name: str, batch_size: int, sequence_length: int
) -> [Memory, Optional[MemorySummary]]:
pass
@abstractmethod
def _train_memory(
self, model_name: str, batch_size: int, sequence_length: int
) -> [Memory, Optional[MemorySummary]]:
pass
def inference_speed(self, *args, **kwargs) -> float:
return separate_process_wrapper_fn(self._inference_speed, self.args.do_multi_processing)(*args, **kwargs)
def train_speed(self, *args, **kwargs) -> float:
return separate_process_wrapper_fn(self._train_speed, self.args.do_multi_processing)(*args, **kwargs)
def inference_memory(self, *args, **kwargs) -> [Memory, Optional[MemorySummary]]:
return separate_process_wrapper_fn(self._inference_memory, self.args.do_multi_processing)(*args, **kwargs)
def train_memory(self, *args, **kwargs) -> [Memory, Optional[MemorySummary]]:
return separate_process_wrapper_fn(self._train_memory, self.args.do_multi_processing)(*args, **kwargs)
@property
def environment_info(self):
if self._environment_info is None:
info = {}
info["transformers_version"] = version
info["framework"] = self.framework
if self.framework == "PyTorch":
info["use_torchscript"] = self.args.torchscript
if self.framework == "TensorFlow":
info["eager_mode"] = self.args.eager_mode
info["use_xla"] = self.args.use_xla
info["framework_version"] = self.framework_version
info["python_version"] = platform.python_version()
info["system"] = platform.system()
info["cpu"] = platform.processor()
info["architecture"] = platform.architecture()[0]
info["date"] = datetime.date(datetime.now())
info["time"] = datetime.time(datetime.now())
info["fp16"] = self.args.fp16
info["use_multiprocessing"] = self.args.do_multi_processing
info["only_pretrain_model"] = self.args.only_pretrain_model
if is_psutil_available():
info["cpu_ram_mb"] = bytes_to_mega_bytes(psutil.virtual_memory().total)
else:
logger.warning(
"Psutil not installed, we won't log available CPU memory. "
"Install psutil (pip install psutil) to log available CPU memory."
)
info["cpu_ram_mb"] = "N/A"
info["use_gpu"] = self.args.is_gpu
if self.args.is_gpu:
info["num_gpus"] = 1
if is_py3nvml_available():
nvml.nvmlInit()
handle = nvml.nvmlDeviceGetHandleByIndex(self.args.device_idx)
info["gpu"] = nvml.nvmlDeviceGetName(handle)
info["gpu_ram_mb"] = bytes_to_mega_bytes(nvml.nvmlDeviceGetMemoryInfo(handle).total)
info["gpu_power_watts"] = nvml.nvmlDeviceGetPowerManagementLimit(handle) / 1000
info["gpu_performance_state"] = nvml.nvmlDeviceGetPerformanceState(handle)
nvml.nvmlShutdown()
else:
logger.warning(
"py3nvml not installed, we won't log GPU memory usage. "
"Install py3nvml (pip install py3nvml) to log information about GPU."
)
info["gpu"] = "N/A"
info["gpu_ram_mb"] = "N/A"
info["gpu_power_watts"] = "N/A"
info["gpu_performance_state"] = "N/A"
info["use_tpu"] = self.args.is_tpu
self._environment_info = info
return self._environment_info
def print_results(self, result_dict, type_label):
self.print_fn(80 * "-")
self.print_fn(
"Model Name".center(30) + "Batch Size".center(15) + "Seq Length".center(15) + type_label.center(15)
)
self.print_fn(80 * "-")
for model_name in self.args.model_names:
for batch_size in result_dict[model_name]["bs"]:
for sequence_length in result_dict[model_name]["ss"]:
result = result_dict[model_name]["result"][batch_size][sequence_length]
if isinstance(result, float):
result = round(1000 * result) / 1000
result = "< 0.001" if result == 0.0 else str(result)
else:
result = str(result)
self.print_fn(
model_name[:30].center(30) + str(batch_size).center(15),
str(sequence_length).center(15),
result.center(15),
)
self.print_fn(80 * "-")
def print_memory_trace_statistics(self, summary: MemorySummary):
self.print_fn(
"\nLine by line memory consumption:\n"
+ "\n".join(
f"{state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
for state in summary.sequential
)
)
self.print_fn(
"\nLines with top memory consumption:\n"
+ "\n".join(
f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
for state in summary.cumulative[:6]
)
)
self.print_fn(
"\nLines with lowest memory consumption:\n"
+ "\n".join(
f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
for state in summary.cumulative[-6:]
)
)
self.print_fn(f"\nTotal memory increase: {summary.total}")
def save_to_csv(self, result_dict, filename):
if not self.args.save_to_csv:
return
self.print_fn("Saving results to csv.")
with open(filename, mode="w") as csv_file:
if len(self.args.model_names) <= 0:
raise ValueError(f"At least 1 model should be defined, but got {self.model_names}")
fieldnames = ["model", "batch_size", "sequence_length"]
writer = csv.DictWriter(csv_file, fieldnames=fieldnames + ["result"])
writer.writeheader()
for model_name in self.args.model_names:
result_dict_model = result_dict[model_name]["result"]
for bs in result_dict_model:
for ss in result_dict_model[bs]:
result_model = result_dict_model[bs][ss]
writer.writerow(
{
"model": model_name,
"batch_size": bs,
"sequence_length": ss,
"result": ("{}" if not isinstance(result_model, float) else "{:.4f}").format(
result_model
),
}
)
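The abstract `Benchmark` above is what the framework-specific benchmarks (such as the TensorFlow one shown earlier) subclass. A hedged sketch of the public entry points via the PyTorch subclass shipped with the library; the model name and sizes are illustrative, and note that these utilities are deprecated:

from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments

args = PyTorchBenchmarkArguments(
    models=["bert-base-uncased"], batch_sizes=[8], sequence_lengths=[128], memory=False
)
benchmark = PyTorchBenchmark(args)
seconds = benchmark.inference_speed("bert-base-uncased", batch_size=8, sequence_length=128)
print(f"{seconds:.4f} s per forward pass")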
.\benchmark\__init__.py
.\cache_utils.py
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import torch
from .configuration_utils import PretrainedConfig
from .utils import logging
logger = logging.get_logger(__name__)
@dataclass
class Cache:
"""
Base, abstract class for all caches. The actual data structure is specific to each subclass.
"""
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
cache to be created.
Return:
A tuple containing the updated key and value states.
"""
raise NotImplementedError("Make sure to implement `update` in a subclass.")
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states, if there is any."""
raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
max_length = self.get_max_length()
previous_seq_length = self.get_seq_length(layer_idx)
if max_length is not None and previous_seq_length + new_seq_length > max_length:
return max_length - new_seq_length
return previous_seq_length
@property
def seen_tokens(self):
logger.warning_once(
"The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
"model input instead."
)
if hasattr(self, "_seen_tokens"):
return self._seen_tokens
else:
return None
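To make the `get_usable_length` arithmetic concrete, here is a toy subclass (purely illustrative, not part of the library) with 30 cached tokens and a 32-token maximum:

class ToyCache(Cache):
    # illustrative subclass: fixed "seen" length and a hard maximum
    def __init__(self, seen: int, max_len: int):
        self._seen, self._max = seen, max_len

    def get_seq_length(self, layer_idx=0):
        return self._seen

    def get_max_length(self):
        return self._max

cache = ToyCache(seen=30, max_len=32)
# 30 cached + 4 new tokens would exceed 32, so only 32 - 4 = 28 cached positions stay usable
print(cache.get_usable_length(new_seq_length=4))  # 28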
class DynamicCache(Cache):
"""
Concrete subclass of Cache representing a dynamic cache.
"""
"""
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.
"""
def __init__(self) -> None:
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
self._seen_tokens = 0
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
"""
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
sequence length.
"""
if layer_idx < len(self):
return (self.key_cache[layer_idx], self.value_cache[layer_idx])
else:
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
def __iter__(self):
"""
Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
keys and values
"""
for layer_idx in range(len(self)):
yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
def __len__(self):
"""
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
to the number of layers in the model.
"""
return len(self.key_cache)
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
Return:
A tuple containing the updated key and value states.
"""
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
if len(self.key_cache) <= layer_idx:
self.key_cache.append(key_states)
self.value_cache.append(value_states)
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
return self.key_cache[layer_idx], self.value_cache[layer_idx]
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
if len(self.key_cache) <= layer_idx:
return 0
return self.key_cache[layer_idx].shape[-2]
def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
return None
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
legacy_cache = ()
for layer_idx in range(len(self)):
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
return legacy_cache
@classmethod
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
cache = cls()
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
return cache
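A quick sketch of how `DynamicCache` grows along the sequence dimension; the shapes are illustrative:

import torch

cache = DynamicCache()
k = torch.randn(1, 8, 4, 64)   # [batch_size, num_heads, seq_len, head_dim]
v = torch.randn(1, 8, 4, 64)
cache.update(k, v, layer_idx=0)                                         # prefill with 4 tokens
cache.update(torch.randn(1, 8, 1, 64), torch.randn(1, 8, 1, 64), 0)     # one decoding step
print(cache.get_seq_length())   # 5
print(cache[0][0].shape)        # torch.Size([1, 8, 5, 64])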
class SinkCache(Cache):
"""
A cache, as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453), that allows the model to
generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.
Parameters:
window_length (`int`):
The length of the context window.
num_sink_tokens (`int`):
The number of sink tokens. See the original paper for more information.
"""
def __init__(self, window_length: int, num_sink_tokens: int) -> None:
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
self.window_length = window_length
self.num_sink_tokens = num_sink_tokens
self.cos_sin_cache = {}
self._seen_tokens = 0
@staticmethod
def _rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def _apply_key_rotary_pos_emb(
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
return rotated_key_states
def _get_rerotation_cos_sin(
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
if key_states.shape[-2] not in self.cos_sin_cache:
cos = cos.to(torch.float32)
sin = sin.to(torch.float32)
original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
self.cos_sin_cache[key_states.shape[-2]] = (
rerotation_cos.to(key_states.dtype).unsqueeze(0),
rerotation_sin.to(key_states.dtype).unsqueeze(0),
)
return self.cos_sin_cache[key_states.shape[-2]]
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
if len(self.key_cache) <= layer_idx:
return 0
return self.key_cache[layer_idx].shape[-2]
def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states."""
return self.window_length
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
):
"""Updates the cache with new key and value states for a specific layer."""
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
"""
Static Cache class to be used with `torch.compile(model)`.
Parameters:
config (`PretrainedConfig`):
The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device`):
The device on which the cache should be initialized. Should be the same as the layer.
dtype (*optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
"""
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
super().__init__()
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)
self.dtype = dtype if dtype is not None else torch.float32
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for. Kept for backward compatibility
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. The `StaticCache` just needs the `q_len`
to know how much of the cache it should overwrite.
Return:
A tuple containing the updated key and value states.
"""
new_cache_positions = cache_kwargs.get("cache_position")
k_out = self.key_cache
v_out = self.value_cache
k_out[:, :, new_cache_positions] = key_states
v_out[:, :, new_cache_positions] = value_states
return k_out, v_out
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""
Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC
"""
return (self.key_cache[0, 0].any(dim=-1)).sum()
def get_max_length(self) -> Optional[int]:
"""
Returns the maximum sequence length of the cached states. For `StaticCache` this is the fixed `max_cache_len`.
"""
return self.max_cache_len
def reorder_cache(self, beam_idx: torch.LongTensor):
"""
Reorders the cache for beam search, given the selected beam indices.
"""
device = self.key_cache.device
self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
device = self.value_cache.device
self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))
def to_legacy_cache(self):
"""
Dummy function for BC. We have to keep it because otherwise the call in the forward of models would break.
"""
return None
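A hedged sketch of pre-allocating a `StaticCache` and writing one prefill step into it; the config values below are made up for illustration:

import torch
from transformers import PretrainedConfig

config = PretrainedConfig(
    max_position_embeddings=128, hidden_size=256, num_attention_heads=8, num_key_value_heads=8
)
static_cache = StaticCache(config, max_batch_size=1, max_cache_len=64, device="cpu", dtype=torch.float32)
positions = torch.arange(4)                       # cache slots being written, passed as `cache_position`
k = torch.randn(1, 8, 4, 32)                      # head_dim = 256 / 8 = 32
v = torch.randn(1, 8, 4, 32)
static_cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": positions})
print(static_cache.get_seq_length())              # tensor(4): four slots are now non-zero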
.\commands\add_new_model.py
import json
import os
import shutil
import warnings
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import List
from ..utils import logging
from . import BaseTransformersCLICommand
try:
from cookiecutter.main import cookiecutter
_has_cookiecutter = True
except ImportError:
_has_cookiecutter = False
logger = logging.get_logger(__name__)
def add_new_model_command_factory(args: Namespace):
return AddNewModelCommand(args.testing, args.testing_file, path=args.path)
class AddNewModelCommand(BaseTransformersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
add_new_model_parser = parser.add_parser("add-new-model")
add_new_model_parser.add_argument("--testing", action="store_true", help="If in testing mode.")
add_new_model_parser.add_argument("--testing_file", type=str, help="Configuration file on which to run.")
add_new_model_parser.add_argument(
"--path", type=str, help="Path to cookiecutter. Should only be used for testing purposes."
)
add_new_model_parser.set_defaults(func=add_new_model_command_factory)
def __init__(self, testing: bool, testing_file: str, path=None, *args):
self._testing = testing
self._testing_file = testing_file
self._path = path
.\commands\add_new_model_like.py
import difflib
import json
import os
import re
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
from datetime import date
from itertools import chain
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Union
import yaml
from ..models import auto as auto_module
from ..models.auto.configuration_auto import model_type_to_module_name
from ..utils import is_flax_available, is_tf_available, is_torch_available, logging
from . import BaseTransformersCLICommand
logger = logging.get_logger(__name__)
CURRENT_YEAR = date.today().year
TRANSFORMERS_PATH = Path(__file__).parent.parent
REPO_PATH = TRANSFORMERS_PATH.parent.parent
@dataclass
class ModelPatterns:
"""
Holds the basic information about a new model for the add-new-model-like command.
"""
Args:
model_name (`str`): 模型名称。
checkpoint (`str`): 用于文档示例的检查点。
model_type (`str`, *optional*):
模型类型,内部库中使用的标识符,如 `bert` 或 `xlm-roberta`。默认为 `model_name` 的小写形式,空格用短横线(-)替换。
model_lower_cased (`str`, *optional*):
模型名称的小写形式,用于模块名称或函数名称。默认为 `model_name` 的小写形式,空格和短横线都替换为下划线。
model_camel_cased (`str`, *optional*):
模型名称的驼峰式命名形式,用于类名。默认为 `model_name` 的驼峰式命名(考虑空格和短横线都作为单词分隔符)。
model_upper_cased (`str`, *optional*):
模型名称的大写形式,用于常量名称。默认为 `model_name` 的大写形式,空格和短横线都替换为下划线。
config_class (`str`, *optional*):
与此模型关联的配置类。默认为 `"{model_camel_cased}Config"`。
tokenizer_class (`str`, *optional*):
与此模型关联的分词器类(对于不使用分词器的模型,请将其保留为 `None`)。
image_processor_class (`str`, *optional*):
与此模型关联的图像处理器类(对于不使用图像处理器的模型,请将其保留为 `None`)。
feature_extractor_class (`str`, *optional*):
与此模型关联的特征提取器类(对于不使用特征提取器的模型,请将其保留为 `None`)。
processor_class (`str`, *optional*):
与此模型关联的处理器类(对于不使用处理器的模型,请将其保留为 `None`)。
def __post_init__(self):
if self.model_type is None:
self.model_type = self.model_name.lower().replace(" ", "-")
if self.model_lower_cased is None:
self.model_lower_cased = self.model_name.lower().replace(" ", "_").replace("-", "_")
if self.model_camel_cased is None:
words = self.model_name.split(" ")
words = list(chain(*[w.split("-") for w in words]))
words = [w[0].upper() + w[1:] for w in words]
self.model_camel_cased = "".join(words)
if self.model_upper_cased is None:
self.model_upper_cased = self.model_name.upper().replace(" ", "_").replace("-", "_")
if self.config_class is None:
self.config_class = f"{self.model_camel_cased}Config"
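For example, given only a name and a checkpoint, `__post_init__` derives the remaining attributes as follows (assuming the dataclass fields omitted from this excerpt are declared as in the library):

patterns = ModelPatterns(model_name="GPT-New New", checkpoint="huggingface/gpt-new-base")
print(patterns.model_type)          # "gpt-new-new"
print(patterns.model_lower_cased)   # "gpt_new_new"
print(patterns.model_camel_cased)   # "GPTNewNew"
print(patterns.model_upper_cased)   # "GPT_NEW_NEW"
print(patterns.config_class)        # "GPTNewNewConfig"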
ATTRIBUTE_TO_PLACEHOLDER = {
"config_class": "[CONFIG_CLASS]",
"tokenizer_class": "[TOKENIZER_CLASS]",
"image_processor_class": "[IMAGE_PROCESSOR_CLASS]",
"feature_extractor_class": "[FEATURE_EXTRACTOR_CLASS]",
"processor_class": "[PROCESSOR_CLASS]",
"checkpoint": "[CHECKPOINT]",
"model_type": "[MODEL_TYPE]",
"model_upper_cased": "[MODEL_UPPER_CASED]",
"model_camel_cased": "[MODEL_CAMELCASED]",
"model_lower_cased": "[MODEL_LOWER_CASED]",
"model_name": "[MODEL_NAME]",
}
def is_empty_line(line: str) -> bool:
"""
Determines whether a line is empty or not.
"""
return len(line) == 0 or line.isspace()
def find_indent(line: str) -> int:
"""
Returns the number of spaces that start a line indent.
"""
search = re.search(r"^(\s*)(?:\S|$)", line)
if search is None:
return 0
return len(search.groups()[0])
def parse_module_content(content: str) -> List[str]:
"""
Parse the content of a module in the list of objects it defines.
Args:
content (`str`): The content to parse
Returns:
`List[str]`: The list of objects defined in the module.
"""
objects = []
current_object = []
lines = content.split("\n")
end_markers = [")", "]", "}", '"""']
for line in lines:
is_valid_object = len(current_object) > 0
if is_valid_object and len(current_object) == 1:
is_valid_object = not current_object[0].startswith("# Copied from")
if not is_empty_line(line) and find_indent(line) == 0 and is_valid_object:
if line in end_markers:
current_object.append(line)
objects.append("\n".join(current_object))
current_object = []
else:
objects.append("\n".join(current_object))
current_object = [line]
else:
current_object.append(line)
if len(current_object) > 0:
objects.append("\n".join(current_object))
return objects
def extract_block(content: str, indent_level: int = 0) -> str:
"""
Return the first block in `content` with the indent level `indent_level`.
The first line in `content` should be indented at `indent_level` level, otherwise an error will be thrown.
This method will immediately stop the search when a (non-empty) line with indent level less than `indent_level` is
encountered.
Args:
content (`str`): The content to parse
indent_level (`int`, *optional*, default to 0): The indent level of the blocks to search for
Returns:
`str`: The first block in `content` with the indent level `indent_level`.
Raises:
ValueError: If `indent_level > 0` and the first line of `content` does not have the specified indent level.
"""
current_object = []
lines = content.split("\n")
end_markers = [")", "]", "}", '"""']
for idx, line in enumerate(lines):
if idx == 0 and indent_level > 0 and not is_empty_line(line) and find_indent(line) != indent_level:
raise ValueError(
f"When `indent_level > 0`, the first line in `content` should have indent level {indent_level}. Got "
f"{find_indent(line)} instead."
)
if find_indent(line) < indent_level and not is_empty_line(line):
break
is_valid_object = len(current_object) > 0
if (
not is_empty_line(line)
and not line.endswith(":")
and find_indent(line) == indent_level
and is_valid_object
):
if line.lstrip() in end_markers:
current_object.append(line)
return "\n".join(current_object)
else:
current_object.append(line)
if len(current_object) > 0:
return "\n".join(current_object)
def add_content_to_text(
text: str,
content: str,
add_after: Optional[Union[str, Pattern]] = None,
add_before: Optional[Union[str, Pattern]] = None,
exact_match: bool = False,
) -> str:
"""
A utility to add some content inside a given text.
Args:
text (`str`): The text in which we want to insert some content.
content (`str`): The content to add.
add_after (`str` or `Pattern`):
The pattern to test on a line of `text`, the new content is added after the first instance matching it.
add_before (`str` or `Pattern`):
The pattern to test on a line of `text`, the new content is added before the first instance matching it.
exact_match (`bool`, *optional*, defaults to `False`):
A line is considered a match with `add_after` or `add_before` if it matches exactly when `exact_match=True`,
otherwise, if `add_after`/`add_before` is present in the line.
<Tip warning={true}>
The arguments `add_after` and `add_before` are mutually exclusive, and one exactly needs to be provided.
</Tip>
Returns:
`str`: The text with the new content added if a match was found.
"""
if add_after is None and add_before is None:
raise ValueError("You need to pass either `add_after` or `add_before`")
if add_after is not None and add_before is not None:
raise ValueError("You can't pass both `add_after` or `add_before`")
pattern = add_after if add_before is None else add_before
def this_is_the_line(line):
if isinstance(pattern, Pattern):
return pattern.search(line) is not None
elif exact_match:
return pattern == line
else:
return pattern in line
new_lines = []
for line in text.split("\n"):
if this_is_the_line(line):
if add_before is not None:
new_lines.append(content)
new_lines.append(line)
if add_after is not None:
new_lines.append(content)
else:
new_lines.append(line)
return "\n".join(new_lines)
def add_content_to_file(
file_name: Union[str, os.PathLike],
content: str,
add_after: Optional[Union[str, Pattern]] = None,
add_before: Optional[Union[str, Pattern]] = None,
exact_match: bool = False,
):
"""
A utility to add some content inside a given file.
<Tip warning={true}>
The arguments `add_after` and `add_before` are mutually exclusive, and one exactly needs to be provided.
</Tip>
"""
with open(file_name, "r", encoding="utf-8") as f:
old_content = f.read()
new_content = add_content_to_text(
old_content, content, add_after=add_after, add_before=add_before, exact_match=exact_match
)
with open(file_name, "w", encoding="utf-8") as f:
f.write(new_content)
def replace_model_patterns(
text: str, old_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns
) -> Tuple[str, str]:
"""
Replace all patterns present in a given text.
Args:
text (`str`): The text to treat.
old_model_patterns (`ModelPatterns`): The patterns for the old model.
new_model_patterns (`ModelPatterns`): The patterns for the new model.
Returns:
`Tuple(str, str)`: A tuple of with the treated text and the replacement actually done in it.
"""
attributes_to_check = ["config_class"]
for attr in ["tokenizer_class", "image_processor_class", "feature_extractor_class", "processor_class"]:
if getattr(old_model_patterns, attr) is not None and getattr(new_model_patterns, attr) is not None:
attributes_to_check.append(attr)
if old_model_patterns.checkpoint not in [old_model_patterns.model_type, old_model_patterns.model_lower_cased]:
attributes_to_check.append("checkpoint")
if old_model_patterns.model_type != old_model_patterns.model_lower_cased:
attributes_to_check.append("model_type")
else:
text = re.sub(
rf'(\s*)model_type = "{old_model_patterns.model_type}"',
r'\1model_type = "[MODEL_TYPE]"',
text,
)
if old_model_patterns.model_upper_cased == old_model_patterns.model_camel_cased:
old_model_value = old_model_patterns.model_upper_cased
if re.search(rf"{old_model_value}_[A-Z_]*[^A-Z_]", text) is not None:
text = re.sub(rf"{old_model_value}([A-Z_]*)([^a-zA-Z_])", r"[MODEL_UPPER_CASED]\1\2", text)
else:
attributes_to_check.append("model_upper_cased")
attributes_to_check.extend(["model_camel_cased", "model_lower_cased", "model_name"])
for attr in attributes_to_check:
text = text.replace(getattr(old_model_patterns, attr), ATTRIBUTE_TO_PLACEHOLDER[attr])
replacements = []
for attr, placeholder in ATTRIBUTE_TO_PLACEHOLDER.items():
if placeholder in text:
replacements.append((getattr(old_model_patterns, attr), getattr(new_model_patterns, attr)))
text = text.replace(placeholder, getattr(new_model_patterns, attr))
old_replacement_values = [old for old, new in replacements]
if len(set(old_replacement_values)) != len(old_replacement_values):
return text, ""
replacements = simplify_replacements(replacements)
replacements = [f"{old}->{new}" for old, new in replacements]
return text, ",".join(replacements)
def simplify_replacements(replacements):
if len(replacements) <= 1:
return replacements
replacements.sort(key=lambda x: len(x[0]))
idx = 0
while idx < len(replacements):
old, new = replacements[idx]
j = idx + 1
while j < len(replacements):
old_2, new_2 = replacements[j]
if old_2.replace(old, new) == new_2:
replacements.pop(j)
else:
j += 1
idx += 1
return replacements
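For instance, a replacement that is already implied by a shorter one gets dropped:

pairs = [("BertConfig", "NewBertConfig"), ("Bert", "NewBert"), ("bert", "new_bert")]
print(simplify_replacements(pairs))
# [('Bert', 'NewBert'), ('bert', 'new_bert')] -- "BertConfig"->"NewBertConfig" follows from "Bert"->"NewBert"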
def get_module_from_file(module_file: Union[str, os.PathLike]) -> str:
full_module_path = Path(module_file).absolute()
module_parts = full_module_path.with_suffix("").parts
idx = len(module_parts) - 1
while idx >= 0 and module_parts[idx] != "transformers":
idx -= 1
if idx < 0:
raise ValueError(f"{module_file} is not a transformers module.")
return ".".join(module_parts[idx:])
SPECIAL_PATTERNS = {
"_CHECKPOINT_FOR_DOC =": "checkpoint",
"_CONFIG_FOR_DOC =": "config_class",
"_TOKENIZER_FOR_DOC =": "tokenizer_class",
"_IMAGE_PROCESSOR_FOR_DOC =": "image_processor_class",
"_FEAT_EXTRACTOR_FOR_DOC =": "feature_extractor_class",
"_PROCESSOR_FOR_DOC =": "processor_class",
}
_re_class_func = re.compile(r"^(?:class|def)\s+([^\s:\(]+)\s*(?:\(|\:)", flags=re.MULTILINE)
def remove_attributes(obj, target_attr):
lines = obj.split(os.linesep)
target_idx = None
for idx, line in enumerate(lines):
if line.lstrip().startswith(f"{target_attr} = "):
target_idx = idx
break
elif line.lstrip().startswith(f"def {target_attr}("):
target_idx = idx
break
if target_idx is None:
return obj
line = lines[target_idx]
indent_level = find_indent(line)
parsed = extract_block("\n".join(lines[target_idx:]), indent_level)
num_lines = len(parsed.split("\n"))
for idx in range(num_lines):
lines[target_idx + idx] = None
for idx in range(target_idx - 1, -1, -1):
line = lines[idx]
if (line.lstrip().startswith("#") or line.lstrip().startswith("@")) and find_indent(line) == indent_level:
lines[idx] = None
else:
break
new_obj = os.linesep.join([x for x in lines if x is not None])
return new_obj
"""
Create a new module from an existing one and adapting all function and classes names from old patterns to new ones.
Args:
module_file (`str` or `os.PathLike`): Path to the module to duplicate.
old_model_patterns (`ModelPatterns`): The patterns for the old model.
new_model_patterns (`ModelPatterns`): The patterns for the new model.
dest_file (`str` or `os.PathLike`, *optional*): Path to the new module.
add_copied_from (`bool`, *optional*, defaults to `True`):
Whether or not to add `# Copied from` statements in the duplicated module.
"""
if dest_file is None:
dest_file = str(module_file).replace(
old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased
)
with open(module_file, "r", encoding="utf-8") as f:
content = f.read()
content = re.sub(r"# Copyright (\d+)\s", f"# Copyright {CURRENT_YEAR} ", content)
objects = parse_module_content(content)
new_objects = []
for obj in objects:
if "PRETRAINED_CONFIG_ARCHIVE_MAP = {" in obj:
obj = (
f"{new_model_patterns.model_upper_cased}_PRETRAINED_CONFIG_ARCHIVE_MAP = "
+ "{"
+ f"""
"{new_model_patterns.checkpoint}": "https://huggingface.co/{new_model_patterns.checkpoint}/resolve/main/config.json",
"""
+ "}\n"
)
new_objects.append(obj)
continue
elif "PRETRAINED_MODEL_ARCHIVE_LIST = [" in obj:
if obj.startswith("TF_"):
prefix = "TF_"
elif obj.startswith("FLAX_"):
prefix = "FLAX_"
else:
prefix = ""
obj = f"""{prefix}{new_model_patterns.model_upper_cased}_PRETRAINED_MODEL_ARCHIVE_LIST = [
"{new_model_patterns.checkpoint}",
# See all {new_model_patterns.model_name} models at https://huggingface.co/models?filter={new_model_patterns.model_type}
]
"""
new_objects.append(obj)
def filter_framework_files(
files: List[Union[str, os.PathLike]], frameworks: Optional[List[str]] = None
) -> List[Union[str, os.PathLike]]:
"""
Filter a list of files to only keep the ones corresponding to a list of frameworks.
Args:
files (`List[Union[str, os.PathLike]]`): The list of files to filter.
frameworks (`List[str]`, *optional*): The list of allowed frameworks.
Returns:
`List[Union[str, os.PathLike]]`: The list of filtered files.
"""
if frameworks is None:
frameworks = get_default_frameworks()
framework_to_file = {}
others = []
for f in files:
parts = Path(f).name.split("_")
if "modeling" not in parts:
others.append(f)
continue
if "tf" in parts:
framework_to_file["tf"] = f
elif "flax" in parts:
framework_to_file["flax"] = f
else:
framework_to_file["pt"] = f
return [framework_to_file[f] for f in frameworks if f in framework_to_file] + others
def get_model_files(model_type: str, frameworks: Optional[List[str]] = None) -> Dict[str, Union[Path, List[Path]]]:
"""
Retrieves all the files associated to a model.
Args:
model_type (`str`): A valid model type (like "bert" or "gpt2")
frameworks (`List[str]`, *optional*):
If passed, will only keep the model files corresponding to the passed frameworks.
Returns:
`Dict[str, Union[Path, List[Path]]]`: A dictionary with the following keys:
- **doc_file** -- The documentation file for the model.
- **model_files** -- All the files in the model module.
- **module_name** -- The name of the module corresponding to the model type.
- **test_files** -- The test files for the model.
"""
module_name = model_type_to_module_name(model_type)
model_module = TRANSFORMERS_PATH / "models" / module_name
model_files = list(model_module.glob("*.py"))
model_files = filter_framework_files(model_files, frameworks=frameworks)
doc_file = REPO_PATH / "docs" / "source" / "en" / "model_doc" / f"{model_type}.md"
test_files = [
f"test_modeling_{module_name}.py",
f"test_modeling_tf_{module_name}.py",
f"test_modeling_flax_{module_name}.py",
f"test_tokenization_{module_name}.py",
f"test_image_processing_{module_name}.py",
f"test_feature_extraction_{module_name}.py",
f"test_processor_{module_name}.py",
]
test_files = filter_framework_files(test_files, frameworks=frameworks)
test_files = [REPO_PATH / "tests" / "models" / module_name / f for f in test_files]
test_files = [f for f in test_files if f.exists()]
return {"doc_file": doc_file, "model_files": model_files, "module_name": module_name, "test_files": test_files}
_re_checkpoint_for_doc = re.compile(r"^_CHECKPOINT_FOR_DOC\s+=\s+(\S*)\s*$", flags=re.MULTILINE)
def find_base_model_checkpoint(
model_type: str, model_files: Optional[Dict[str, Union[Path, List[Path]]]] = None
) -> str:
"""
Finds the model checkpoint used in the docstrings for a given model.
Args:
model_type (`str`): A valid model type (like "bert" or "gpt2")
model_files (`Dict[str, Union[Path, List[Path]]`, *optional*):
The files associated to `model_type`. Can be passed to speed up the function, otherwise will be computed.
Returns:
`str`: The checkpoint used.
"""
if model_files is None:
model_files = get_model_files(model_type)
module_files = model_files["model_files"]
for fname in module_files:
if "modeling" not in str(fname):
continue
with open(fname, "r", encoding="utf-8") as f:
content = f.read()
if _re_checkpoint_for_doc.search(content) is not None:
checkpoint = _re_checkpoint_for_doc.search(content).groups()[0]
checkpoint = checkpoint.replace('"', "")
checkpoint = checkpoint.replace("'", "")
return checkpoint
return ""
def get_default_frameworks():
"""
Returns the list of frameworks (PyTorch, TensorFlow, Flax) that are installed in the environment.
"""
frameworks = []
if is_torch_available():
frameworks.append("pt")
if is_tf_available():
frameworks.append("tf")
if is_flax_available():
frameworks.append("flax")
return frameworks
_re_model_mapping = re.compile("MODEL_([A-Z_]*)MAPPING_NAMES")
def retrieve_model_classes(model_type: str, frameworks: Optional[List[str]] = None) -> Dict[str, List[str]]:
"""
Retrieve the model classes associated to a given model.
Args:
model_type (`str`): A valid model type (like "bert" or "gpt2")
frameworks (`List[str]`, *optional*):
The frameworks to look for. Will default to `["pt", "tf", "flax"]`, passing a smaller list will restrict
the classes returned.
Returns:
`Dict[str, List[str]]`: A dictionary with one key per framework and the list of model classes associated to
that framework as values.
"""
if frameworks is None:
frameworks = get_default_frameworks()
modules = {
"pt": auto_module.modeling_auto if is_torch_available() else None,
"tf": auto_module.modeling_tf_auto if is_tf_available() else None,
"flax": auto_module.modeling_flax_auto if is_flax_available() else None,
}
model_classes = {}
for framework in frameworks:
new_model_classes = []
if modules[framework] is None:
raise ValueError(f"You selected {framework} in the frameworks, but it is not installed.")
model_mappings = [attr for attr in dir(modules[framework]) if _re_model_mapping.search(attr) is not None]
for model_mapping_name in model_mappings:
model_mapping = getattr(modules[framework], model_mapping_name)
if model_type in model_mapping:
new_model_classes.append(model_mapping[model_type])
if len(new_model_classes) > 0:
model_classes[framework] = list(set(new_model_classes))
return model_classes
"""
Retrieves all the information from a given model_type.
Args:
model_type (`str`): A valid model type (like "bert" or "gpt2")
frameworks (`List[str]`, *optional*):
If passed, will only keep the info corresponding to the passed frameworks.
Returns:
`Dict`: A dictionary with the following keys:
- **frameworks** (`List[str]`): The list of frameworks that back this model type.
- **model_classes** (`Dict[str, List[str]]`): The model classes implemented for that model type.
- **model_files** (`Dict[str, Union[Path, List[Path]]]`): The files associated with that model type.
- **model_patterns** (`ModelPatterns`): The various patterns for the model.
"""
if model_type not in auto_module.MODEL_NAMES_MAPPING:
raise ValueError(f"{model_type} is not a valid model type.")
model_name = auto_module.MODEL_NAMES_MAPPING[model_type]
config_class = auto_module.configuration_auto.CONFIG_MAPPING_NAMES[model_type]
archive_map = auto_module.configuration_auto.CONFIG_ARCHIVE_MAP_MAPPING_NAMES.get(model_type, None)
if model_type in auto_module.tokenization_auto.TOKENIZER_MAPPING_NAMES:
tokenizer_classes = auto_module.tokenization_auto.TOKENIZER_MAPPING_NAMES[model_type]
tokenizer_class = tokenizer_classes[0] if tokenizer_classes[0] is not None else tokenizer_classes[1]
else:
tokenizer_class = None
image_processor_class = auto_module.image_processing_auto.IMAGE_PROCESSOR_MAPPING_NAMES.get(model_type, None)
feature_extractor_class = auto_module.feature_extraction_auto.FEATURE_EXTRACTOR_MAPPING_NAMES.get(model_type, None)
processor_class = auto_module.processing_auto.PROCESSOR_MAPPING_NAMES.get(model_type, None)
model_files = get_model_files(model_type, frameworks=frameworks)
model_camel_cased = config_class.replace("Config", "")
available_frameworks = []
for fname in model_files["model_files"]:
if "modeling_tf" in str(fname):
available_frameworks.append("tf")
elif "modeling_flax" in str(fname):
available_frameworks.append("flax")
elif "modeling" in str(fname):
available_frameworks.append("pt")
if frameworks is None:
frameworks = get_default_frameworks()
frameworks = [f for f in frameworks if f in available_frameworks]
model_classes = retrieve_model_classes(model_type, frameworks=frameworks)
if archive_map is None:
model_upper_cased = model_camel_cased.upper()
else:
parts = archive_map.split("_")
idx = 0
while idx < len(parts) and parts[idx] != "PRETRAINED":
idx += 1
if idx < len(parts):
model_upper_cased = "_".join(parts[:idx])
else:
model_upper_cased = model_camel_cased.upper()
model_patterns = ModelPatterns(
model_name,
checkpoint=find_base_model_checkpoint(model_type, model_files=model_files),
model_type=model_type,
model_camel_cased=model_camel_cased,
model_lower_cased=model_files["module_name"],
model_upper_cased=model_upper_cased,
config_class=config_class,
tokenizer_class=tokenizer_class,
image_processor_class=image_processor_class,
feature_extractor_class=feature_extractor_class,
processor_class=processor_class,
)
return {
"frameworks": frameworks,
"model_classes": model_classes,
"model_files": model_files,
"model_patterns": model_patterns,
}
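A sketch of how the returned dictionary is typically consumed, assuming PyTorch is installed and using "gpt2" purely as an example model type:
from transformers.commands.add_new_model_like import retrieve_info_for_model

info = retrieve_info_for_model("gpt2", frameworks=["pt"])
print(info["frameworks"])                        # e.g. ["pt"]
print(info["model_patterns"].model_camel_cased)  # e.g. "GPT2"
print(len(info["model_files"]["model_files"]))   # number of source files found for the model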
with open(init_file, "r", encoding="utf-8") as f:
content = f.read()
lines = content.split("\n")
new_lines = []
idx = 0
while idx < len(lines):
if (re_conditional_imports.search(lines[idx]) is not None) and (re_try.search(lines[idx - 1]) is not None):
new_lines.pop()
idx += 1
while is_empty_line(lines[idx]) or re_else.search(lines[idx]) is None:
idx += 1
idx += 1
indent = find_indent(lines[idx])
while find_indent(lines[idx]) >= indent or is_empty_line(lines[idx]):
idx += 1
elif re_is_xxx_available.search(lines[idx]) is not None:
line = lines[idx]
for framework in to_remove:
line = line.replace(f", is_{framework}_available", "")
line = line.replace(f"is_{framework}_available, ", "")
line = line.replace(f"is_{framework}_available,", "")
line = line.replace(f"is_{framework}_available", "")
if len(line.strip()) > 0:
new_lines.append(line)
idx += 1
elif keep_processing or (
re.search(r'^\s*"(tokenization|processing|feature_extraction|image_processing)', lines[idx]) is None
and re.search(r"^\s*from .(tokenization|processing|feature_extraction|image_processing)", lines[idx])
is None
):
new_lines.append(lines[idx])
idx += 1
else:
idx += 1
with open(init_file, "w", encoding="utf-8") as f:
f.write("\n".join(new_lines))
with open(TRANSFORMERS_PATH / "__init__.py", "r", encoding="utf-8") as f:
content = f.read()
lines = content.split("\n")
idx = 0
new_lines = []
framework = None
while idx < len(lines):
new_framework = False
if not is_empty_line(lines[idx]) and find_indent(lines[idx]) == 0:
framework = None
elif lines[idx].lstrip().startswith("if not is_torch_available"):
framework = "pt"
new_framework = True
elif lines[idx].lstrip().startswith("if not is_tf_available"):
framework = "tf"
new_framework = True
elif lines[idx].lstrip().startswith("if not is_flax_available"):
framework = "flax"
new_framework = True
if new_framework:
while lines[idx].strip() != "else:":
new_lines.append(lines[idx])
idx += 1
if framework is not None and frameworks is not None and framework not in frameworks:
new_lines.append(lines[idx])
idx += 1
elif re.search(rf'models.{old_model_patterns.model_lower_cased}( |")', lines[idx]) is not None:
block = [lines[idx]]
indent = find_indent(lines[idx])
idx += 1
while find_indent(lines[idx]) > indent:
block.append(lines[idx])
idx += 1
if lines[idx].strip() in [")", "]", "],"]:
block.append(lines[idx])
idx += 1
block = "\n".join(block)
new_lines.append(block)
add_block = True
if not with_processing:
processing_classes = [
old_model_patterns.tokenizer_class,
old_model_patterns.image_processor_class,
old_model_patterns.feature_extractor_class,
old_model_patterns.processor_class,
]
processing_classes = [c for c in processing_classes if c is not None]
for processing_class in processing_classes:
block = block.replace(f' "{processing_class}",', "")
block = block.replace(f', "{processing_class}"', "")
block = block.replace(f" {processing_class},", "")
block = block.replace(f", {processing_class}", "")
if processing_class in block:
add_block = False
if add_block:
new_lines.append(replace_model_patterns(block, old_model_patterns, new_model_patterns)[0])
else:
new_lines.append(lines[idx])
idx += 1
with open(TRANSFORMERS_PATH / "__init__.py", "w", encoding="utf-8") as f:
f.write("\n".join(new_lines))
def insert_tokenizer_in_auto_module(old_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns):
"""
Add a tokenizer to the relevant mappings in the auto module.
Args:
old_model_patterns (`ModelPatterns`): The patterns for the old model.
new_model_patterns (`ModelPatterns`): The patterns for the new model.
"""
if old_model_patterns.tokenizer_class is None or new_model_patterns.tokenizer_class is None:
return
with open(TRANSFORMERS_PATH / "models" / "auto" / "tokenization_auto.py", "r", encoding="utf-8") as f:
content = f.read()
lines = content.split("\n")
idx = 0
while not lines[idx].startswith(" TOKENIZER_MAPPING_NAMES = OrderedDict("):
idx += 1
idx += 1
while not lines[idx].startswith("TOKENIZER_MAPPING = _LazyAutoMapping"):
if lines[idx].endswith(","):
block = lines[idx]
else:
block = []
while not lines[idx].startswith(" ),"):
block.append(lines[idx])
idx += 1
block = "\n".join(block)
idx += 1
if f'"{old_model_patterns.model_type}"' in block and old_model_patterns.tokenizer_class in block:
break
new_block = block.replace(old_model_patterns.model_type, new_model_patterns.model_type)
new_block = new_block.replace(old_model_patterns.tokenizer_class, new_model_patterns.tokenizer_class)
new_lines = lines[:idx] + [new_block] + lines[idx:]
with open(TRANSFORMERS_PATH / "models" / "auto" / "tokenization_auto.py", "w", encoding="utf-8") as f:
f.write("\n".join(new_lines))
AUTO_CLASSES_PATTERNS = {
"configuration_auto.py": [
' ("{model_type}", "{model_name}"),',
' ("{model_type}", "{config_class}"),',
' ("{model_type}", "{pretrained_archive_map}"),',
],
"feature_extraction_auto.py": [' ("{model_type}", "{feature_extractor_class}"),'],
"image_processing_auto.py": [' ("{model_type}", "{image_processor_class}"),'],
"modeling_auto.py": [' ("{model_type}", "{any_pt_class}"),'],
"modeling_tf_auto.py": [' ("{model_type}", "{any_tf_class}"),'],
"modeling_flax_auto.py": [' ("{model_type}", "{any_flax_class}"),'],
"processing_auto.py": [' ("{model_type}", "{processor_class}"),'],
}
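Each entry above is a plain `str.format` template; filling it in produces the exact line that gets spliced into the corresponding auto file. A sketch with hypothetical names:
template = ' ("{model_type}", "{config_class}"),'
line = template.format(model_type="my-new-model", config_class="MyNewModelConfig")
print(line)  # ("my-new-model", "MyNewModelConfig"),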
def add_model_to_auto_classes(
old_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns, model_classes: Dict[str, List[str]]
):
"""
Add a model to the relevant mappings in the auto module.
Args:
old_model_patterns (`ModelPatterns`): The patterns for the old model.
new_model_patterns (`ModelPatterns`): The patterns for the new model.
model_classes (`Dict[str, List[str]]`): A dictionary mapping each framework to the list of model classes implemented for it.
"""
# Insert the tokenizer of the old model patterns into the auto module for the new model patterns
insert_tokenizer_in_auto_module(old_model_patterns, new_model_patterns)
# Template docstring used to generate the overview section of the new model's documentation
DOC_OVERVIEW_TEMPLATE = """
The {model_name} model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
<INSERT SHORT SUMMARY HERE>
The abstract from the paper is the following:
*<INSERT PAPER ABSTRACT HERE>*
Tips:
<INSERT TIPS ABOUT MODEL HERE>
This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
"""
def duplicate_doc_file(
doc_file: Union[str, os.PathLike],
old_model_patterns: ModelPatterns,
new_model_patterns: ModelPatterns,
dest_file: Optional[Union[str, os.PathLike]] = None,
frameworks: Optional[List[str]] = None,
):
"""
Duplicates a documentation file and adapts it for a new model.
Args:
doc_file (`str` or `os.PathLike`): Path to the doc file to duplicate.
old_model_patterns (`ModelPatterns`): The patterns for the old model.
new_model_patterns (`ModelPatterns`): The patterns for the new model.
dest_file (`str` or `os.PathLike`, *optional*): Path to the new doc file.
Will default to a file named `{new_model_patterns.model_type}.md` in the same folder as `doc_file`.
frameworks (`List[str]`, *optional*):
If passed, will only keep the model classes corresponding to this list of frameworks in the new doc file.
"""
# Read the content of the original doc file
with open(doc_file, "r", encoding="utf-8") as f:
content = f.read()
# Update the copyright notice to the current year
content = re.sub(r"<!--\s*Copyright (\d+)\s", f"<!--Copyright {CURRENT_YEAR} ", content)
# If no list of frameworks was provided, fall back to the default frameworks
if frameworks is None:
frameworks = get_default_frameworks()
# If no destination file was provided, default to a file named after the new model type in the same folder as the original doc
if dest_file is None:
dest_file = Path(doc_file).parent / f"{new_model_patterns.model_type}.md"
# Parse the doc content into blocks; each block corresponds to a section/header
lines = content.split("\n")
blocks = []
current_block = []
for line in lines:
if line.startswith("#"):
blocks.append("\n".join(current_block))
current_block = [line]
else:
current_block.append(line)
blocks.append("\n".join(current_block))
new_blocks = []
in_classes = False
# Go through the list of blocks
for block in blocks:
# Keep blocks that do not start with a header (the copyright comment, for instance) as they are
if not block.startswith("#"):
new_blocks.append(block)
# If this is the main title, replace it with a title bearing the new model name
elif re.search(r"^#\s+\S+", block) is not None:
new_blocks.append(f"# {new_model_patterns.model_name}\n")
# Detect the start of the class documentation, using the old model's config class in the header as the marker
elif not in_classes and old_model_patterns.config_class in block.split("\n")[0]:
# Mark that we entered the class section, then add the overview template and the adapted config block
in_classes = True
new_blocks.append(DOC_OVERVIEW_TEMPLATE.format(model_name=new_model_patterns.model_name))
new_block, _ = replace_model_patterns(block, old_model_patterns, new_model_patterns)
new_blocks.append(new_block)
# Handle blocks inside the class section
elif in_classes:
in_classes = True
# Grab the title of the current block and extract the class name
block_title = block.split("\n")[0]
block_class = re.search(r"^#+\s+(\S.*)$", block_title).groups()[0]
new_block, _ = replace_model_patterns(block, old_model_patterns, new_model_patterns)
# Conditionally add the new block depending on the class name
if "Tokenizer" in block_class:
# Only add the tokenizer class if the new model does not reuse the old one
if old_model_patterns.tokenizer_class != new_model_patterns.tokenizer_class:
new_blocks.append(new_block)
elif "ImageProcessor" in block_class:
# Only add the image processor class if the new model does not reuse the old one
if old_model_patterns.image_processor_class != new_model_patterns.image_processor_class:
new_blocks.append(new_block)
elif "FeatureExtractor" in block_class:
# Only add the feature extractor class if the new model does not reuse the old one
if old_model_patterns.feature_extractor_class != new_model_patterns.feature_extractor_class:
new_blocks.append(new_block)
elif "Processor" in block_class:
# Only add the processor class if the new model does not reuse the old one
if old_model_patterns.processor_class != new_model_patterns.processor_class:
new_blocks.append(new_block)
elif block_class.startswith("Flax"):
# Only add Flax model classes if "flax" is among the selected frameworks
if "flax" in frameworks:
new_blocks.append(new_block)
elif block_class.startswith("TF"):
# Only add TF model classes if "tf" is among the selected frameworks
if "tf" in frameworks:
new_blocks.append(new_block)
elif len(block_class.split(" ")) == 1:
# Only add PyTorch model classes if "pt" is among the selected frameworks
if "pt" in frameworks:
new_blocks.append(new_block)
else:
new_blocks.append(new_block)
# Write the new list of blocks to the destination file
with open(dest_file, "w", encoding="utf-8") as f:
f.write("\n".join(new_blocks))
# Insert an entry for the new model in the doc table of contents, in the same section as the old model.
def insert_model_in_doc_toc(old_model_patterns, new_model_patterns):
"""
Insert the new model in the doc TOC, in the same section as the old model.
Args:
old_model_patterns (`ModelPatterns`): The patterns for the old model.
new_model_patterns (`ModelPatterns`): The patterns for the new model.
"""
# Path of the doc table of contents file
toc_file = REPO_PATH / "docs" / "source" / "en" / "_toctree.yml"
# Open the TOC file and load its YAML content
with open(toc_file, "r", encoding="utf8") as f:
content = yaml.safe_load(f)
# Locate the index of the API section
api_idx = 0
while content[api_idx]["title"] != "API":
api_idx += 1
# Get the sections under the API part
api_doc = content[api_idx]["sections"]
# Locate the index of the Models section
model_idx = 0
while api_doc[model_idx]["title"] != "Models":
model_idx += 1
# Get the subsections under the Models section
model_doc = api_doc[model_idx]["sections"]
# Find where the base (old) model lives in the TOC
old_model_type = old_model_patterns.model_type
section_idx = 0
while section_idx < len(model_doc):
# List the local TOC entries of the current subsection
sections = [entry["local"] for entry in model_doc[section_idx]["sections"]]
# Stop if the old model's entry is in this subsection
if f"model_doc/{old_model_type}" in sections:
break
section_idx += 1
# If the old model's entry was not found, print a warning and return
if section_idx == len(model_doc):
old_model = old_model_patterns.model_name
new_model = new_model_patterns.model_name
print(f"Did not find {old_model} in the table of content, so you will need to add {new_model} manually.")
return
# Prepare the TOC entry for the new model
toc_entry = {"local": f"model_doc/{new_model_patterns.model_type}", "title": new_model_patterns.model_name}
# Add the new model's entry to the subsection that contains the old model
model_doc[section_idx]["sections"].append(toc_entry)
# Sort the entries of that subsection by title
model_doc[section_idx]["sections"] = sorted(model_doc[section_idx]["sections"], key=lambda s: s["title"].lower())
# Update the Models section of the API doc
api_doc[model_idx]["sections"] = model_doc
# Update the API section in the overall content
content[api_idx]["sections"] = api_doc
# Write the updated content back to the TOC file
with open(toc_file, "w", encoding="utf-8") as f:
f.write(yaml.dump(content, allow_unicode=True))
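For reference, the nested structure the function walks looks roughly like the object below once `yaml.safe_load` has run (heavily abbreviated; section titles and entries are examples, not the actual file content). This is why the lookups go `title == "API"` → `sections` → `title == "Models"` → `sections`.
# Abbreviated sketch of yaml.safe_load(_toctree.yml); titles and entries are examples only
content = [
    {"title": "Get started", "sections": ["..."]},
    {
        "title": "API",
        "sections": [
            {
                "title": "Models",
                "sections": [
                    {
                        "title": "Text models",
                        "sections": [
                            {"local": "model_doc/albert", "title": "ALBERT"},
                            {"local": "model_doc/bert", "title": "BERT"},
                        ],
                    },
                ],
            },
        ],
    },
]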
# Retrieve all the relevant information for the given model type (model files, model patterns, ...)
model_info = retrieve_info_for_model(model_type, frameworks=frameworks)
# Pull the model files and the old model patterns out of the model info
model_files = model_info["model_files"]
old_model_patterns = model_info["model_patterns"]
# If an old checkpoint was provided, use it for the old model patterns
if old_checkpoint is not None:
old_model_patterns.checkpoint = old_checkpoint
# Raise a ValueError if no checkpoint could be recovered for the old model
if len(old_model_patterns.checkpoint) == 0:
raise ValueError(
"The old model checkpoint could not be recovered from the model type. Please pass it to the "
"`old_checkpoint` argument."
)
# Assume the old processing classes are kept until proven otherwise
keep_old_processing = True
# If any processing attribute (image processor, feature extractor, processor, tokenizer) differs between the old and new patterns, stop keeping the old processing
for processing_attr in ["image_processor_class", "feature_extractor_class", "processor_class", "tokenizer_class"]:
if getattr(old_model_patterns, processing_attr) != getattr(new_model_patterns, processing_attr):
keep_old_processing = False
# Get the model classes from the model info
model_classes = model_info["model_classes"]
# 1. Create the module for the new model
old_module_name = model_files["module_name"]
module_folder = TRANSFORMERS_PATH / "models" / new_model_patterns.model_lower_cased
# Make sure the module folder exists, creating it if necessary
os.makedirs(module_folder, exist_ok=True)
# Filter the files to adapt depending on whether the old processing classes are kept
files_to_adapt = model_files["model_files"]
if keep_old_processing:
files_to_adapt = [
f
for f in files_to_adapt
if "tokenization" not in str(f)
and "processing" not in str(f)
and "feature_extraction" not in str(f)
and "image_processing" not in str(f)
]
# Make sure (again) that the module folder exists
os.makedirs(module_folder, exist_ok=True)
# For each file to adapt, build the new module file name and duplicate it to the destination
for module_file in files_to_adapt:
new_module_name = module_file.name.replace(
old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased
)
dest_file = module_folder / new_module_name
duplicate_module(
module_file,
old_model_patterns,
new_model_patterns,
dest_file=dest_file,
add_copied_from=add_copied_from and "modeling" in new_module_name,
)
# Clean up the module's __init__ file, keeping or dropping the processing classes as appropriate
clean_frameworks_in_init(
module_folder / "__init__.py", frameworks=frameworks, keep_processing=not keep_old_processing
)
# 2. Add the new model to the models package __init__ and to the main __init__
add_content_to_file(
TRANSFORMERS_PATH / "models" / "__init__.py",
f" {new_model_patterns.model_lower_cased},",
add_after=f" {old_module_name},",
exact_match=True,
)
add_model_to_main_init(
old_model_patterns, new_model_patterns, frameworks=frameworks, with_processing=not keep_old_processing
)
# 3. Add the test files
files_to_adapt = model_files["test_files"]
if keep_old_processing:
files_to_adapt = [
f
for f in files_to_adapt
if "tokenization" not in str(f)
and "processor" not in str(f)
and "feature_extraction" not in str(f)
and "image_processing" not in str(f)
]
# Helper that disables the torch.fx tests in a given test file
def disable_fx_test(filename: Path) -> bool:
# Read the file content
with open(filename) as fp:
content = fp.read()
# Switch fx_compatible = True to fx_compatible = False
new_content = re.sub(r"fx_compatible\s*=\s*True", "fx_compatible = False", content)
# Write the modified content back to the file
with open(filename, "w") as fp:
fp.write(new_content)
# Return whether anything was actually changed
return content != new_content
# Track whether any fx test was disabled
disabled_fx_test = False
# Create the tests folder if it does not exist
tests_folder = REPO_PATH / "tests" / "models" / new_model_patterns.model_lower_cased
os.makedirs(tests_folder, exist_ok=True)
# Create an empty __init__.py file
with open(tests_folder / "__init__.py", "w"):
pass
# Go through the test files to adapt
for test_file in files_to_adapt:
# Replace the old model name with the new one in the file name
new_test_file_name = test_file.name.replace(
old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased
)
# Build the destination file path
dest_file = test_file.parent.parent / new_model_patterns.model_lower_cased / new_test_file_name
# Duplicate the test file to the destination; the fx tests are disabled right after
duplicate_module(
test_file,
old_model_patterns,
new_model_patterns,
dest_file=dest_file,
add_copied_from=False,
attrs_to_remove=["pipeline_model_mapping", "is_pipeline_test_to_skip"],
)
# Update the fx-test-disabled flag
disabled_fx_test = disabled_fx_test | disable_fx_test(dest_file)
# If any fx test was disabled, let the user know
if disabled_fx_test:
print(
"The tests for symbolic tracing with torch.fx were disabled, you can add those once symbolic tracing works"
" for your new model."
)
# Add the new model to the auto classes
add_model_to_auto_classes(old_model_patterns, new_model_patterns, model_classes)
# Add the doc file
doc_file = REPO_PATH / "docs" / "source" / "en" / "model_doc" / f"{old_model_patterns.model_type}.md"
duplicate_doc_file(doc_file, old_model_patterns, new_model_patterns, frameworks=frameworks)
# Insert the new model in the doc table of contents
insert_model_in_doc_toc(old_model_patterns, new_model_patterns)
# Warn if the old model type is identical to its checkpoint name
if old_model_patterns.model_type == old_model_patterns.checkpoint:
print(
"The model you picked has the same name for the model type and the checkpoint name "
f"({old_model_patterns.model_type}). As a result, it's possible some places where the new checkpoint "
f"should be, you have {new_model_patterns.model_type} instead. You should search for all instances of "
f"{new_model_patterns.model_type} in the new files and check they're not badly used as checkpoints."
)
# Warn if the lowercased old model name is identical to its checkpoint name
elif old_model_patterns.model_lower_cased == old_model_patterns.checkpoint:
print(
"The model you picked has the same name for the model type and the checkpoint name "
f"({old_model_patterns.model_lower_cased}). As a result, it's possible some places where the new "
f"checkpoint should be, you have {new_model_patterns.model_lower_cased} instead. You should search for "
f"all instances of {new_model_patterns.model_lower_cased} in the new files and check they're not badly "
"used as checkpoints."
)
# Warn if the old model type equals its lowercased model name while the new ones differ
if (
old_model_patterns.model_type == old_model_patterns.model_lower_cased
and new_model_patterns.model_type != new_model_patterns.model_lower_cased
):
# The old model type and the lowercased model name were identical, so the new model type may appear where the lowercased name was expected
print(
"The model you picked has the same name for the model type and the lowercased model name "
f"({old_model_patterns.model_lower_cased}). As a result, it's possible some places where the new "
f"model type should be, you have {new_model_patterns.model_lower_cased} instead. You should search for "
f"all instances of {new_model_patterns.model_lower_cased} in the new files and check they're not badly "
"used as the model type."
)
# If the old processing classes are not kept and the old model has a tokenizer class
if not keep_old_processing and old_model_patterns.tokenizer_class is not None:
# Remind the user to fix the constants at the top of the new tokenizer file manually and, for a fast tokenizer, to add its converter to SLOW_TO_FAST_CONVERTERS in convert_slow_tokenizer.py
print(
"The constants at the start of the new tokenizer file created needs to be manually fixed. If your new "
"model has a tokenizer fast, you will also need to manually add the converter in the "
"`SLOW_TO_FAST_CONVERTERS` constant of `convert_slow_tokenizer.py`."
)
def add_new_model_like_command_factory(args: Namespace):
# Create and return an AddNewModelLikeCommand from the config file and repo path in the parsed args
return AddNewModelLikeCommand(config_file=args.config_file, path_to_repo=args.path_to_repo)
class AddNewModelLikeCommand(BaseTransformersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
# Register the "add-new-model-like" subcommand on the given ArgumentParser
add_new_model_like_parser = parser.add_parser("add-new-model-like")
add_new_model_like_parser.add_argument(
"--config_file", type=str, help="A file with all the information for this model creation."
)
add_new_model_like_parser.add_argument(
"--path_to_repo", type=str, help="When not using an editable install, the path to the Transformers repo."
)
# Use add_new_model_like_command_factory as the default handler for this subcommand
add_new_model_like_parser.set_defaults(func=add_new_model_like_command_factory)
def __init__(self, config_file=None, path_to_repo=None, *args):
if config_file is not None:
# If a config file was given, load the configuration from it
with open(config_file, "r", encoding="utf-8") as f:
config = json.load(f)
# Initialize the command's attributes
self.old_model_type = config["old_model_type"]
self.model_patterns = ModelPatterns(**config["new_model_patterns"])
self.add_copied_from = config.get("add_copied_from", True)
self.frameworks = config.get("frameworks", get_default_frameworks())
self.old_checkpoint = config.get("old_checkpoint", None)
else:
# Without a config file, ask the user for all the required values via get_user_input()
(
self.old_model_type,
self.model_patterns,
self.add_copied_from,
self.frameworks,
self.old_checkpoint,
) = get_user_input()
self.path_to_repo = path_to_repo
def run(self):
if self.path_to_repo is not None:
# If a repo path was given, point the global TRANSFORMERS_PATH and REPO_PATH variables at it
global TRANSFORMERS_PATH
global REPO_PATH
REPO_PATH = Path(self.path_to_repo)
TRANSFORMERS_PATH = REPO_PATH / "src" / "transformers"
# Create the new model by calling create_new_model_like
create_new_model_like(
model_type=self.old_model_type,
new_model_patterns=self.model_patterns,
add_copied_from=self.add_copied_from,
frameworks=self.frameworks,
old_checkpoint=self.old_checkpoint,
)
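For reference, a `--config_file` only needs the keys read in `__init__` above. A minimal sketch, where the model name, checkpoint and file name are hypothetical:
import json

example_config = {
    "old_model_type": "roberta",
    "new_model_patterns": {"model_name": "MyNewModel", "checkpoint": "author/my-new-model"},
    "add_copied_from": True,
    "frameworks": ["pt"],
    "old_checkpoint": None,
}
with open("my_new_model.json", "w", encoding="utf-8") as f:
    json.dump(example_config, f, indent=2)
# Then: transformers-cli add-new-model-like --config_file my_new_model.json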
def get_user_field(
question: str,
default_value: Optional[str] = None,
is_valid_answer: Optional[Callable] = None,
convert_to: Optional[Callable] = None,
fallback_message: Optional[str] = None,
) -> Any:
"""
A utility function that asks a question to the user to get an answer, potentially looping until it gets a valid
answer.
"""
# Simple input helper with optional validation and conversion of the answer
# Make sure the question ends with a space
if not question.endswith(" "):
question = question + " "
# If a default value was provided, show it at the end of the question
if default_value is not None:
question = f"{question} [{default_value}] "
# Loop until we get a valid answer
valid_answer = False
while not valid_answer:
# Ask the question and read the user's answer
answer = input(question)
# Use the default value if one was provided and the user entered nothing
if default_value is not None and len(answer) == 0:
answer = default_value
# If a custom validation function is_valid_answer was provided, use it
if is_valid_answer is not None:
valid_answer = is_valid_answer(answer)
# Otherwise, if a conversion function convert_to was provided
elif convert_to is not None:
try:
# Try to convert the answer to the requested type
answer = convert_to(answer)
valid_answer = True
except Exception:
# Conversion failed: mark the answer as invalid and loop again
valid_answer = False
else:
# With neither is_valid_answer nor convert_to, any answer is considered valid
valid_answer = True
# If the answer is not valid, print the fallback message
if not valid_answer:
print(fallback_message)
# Return the validated (and possibly converted) answer
return answer
# Convert a string to a boolean
def convert_to_bool(x: str) -> bool:
"""
Converts a string to a bool.
"""
# Accepted truthy values
if x.lower() in ["1", "y", "yes", "true"]:
return True
# Accepted falsy values
if x.lower() in ["0", "n", "no", "false"]:
return False
# Anything else cannot be converted to a bool, so raise a ValueError
raise ValueError(f"{x} is not a value that can be converted to a bool.")
# Gather the user input needed to add the new model
def get_user_input():
"""
Ask the user for the necessary inputs to add the new model.
"""
# List of valid model types
model_types = list(auto_module.configuration_auto.MODEL_NAMES_MAPPING.keys())
# Ask for the old model type
valid_model_type = False
while not valid_model_type:
# Ask which model should be duplicated
old_model_type = input(
"What is the model you would like to duplicate? Please provide the lowercase `model_type` (e.g. roberta): "
)
# Check that the input is a valid model type
if old_model_type in model_types:
valid_model_type = True
else:
# Otherwise, tell the user and suggest close matches
print(f"{old_model_type} is not a valid model type.")
near_choices = difflib.get_close_matches(old_model_type, model_types)
if len(near_choices) >= 1:
if len(near_choices) > 1:
near_choices = " or ".join(near_choices)
print(f"Did you mean {near_choices}?")
# Retrieve the details of the old model
old_model_info = retrieve_info_for_model(old_model_type)
old_tokenizer_class = old_model_info["model_patterns"].tokenizer_class
old_image_processor_class = old_model_info["model_patterns"].image_processor_class
old_feature_extractor_class = old_model_info["model_patterns"].feature_extractor_class
old_processor_class = old_model_info["model_patterns"].processor_class
old_frameworks = old_model_info["frameworks"]
# If the old model has no checkpoint information, ask the user for the name of the base checkpoint
old_checkpoint = None
if len(old_model_info["model_patterns"].checkpoint) == 0:
old_checkpoint = get_user_field(
"We couldn't find the name of the base checkpoint for that model, please enter it here."
)
# Ask for the name of the new model
model_name = get_user_field(
"What is the name (with no special casing) for your new model in the paper (e.g. RoBERTa)? "
)
# Build default patterns from the model name
default_patterns = ModelPatterns(model_name, model_name)
# Ask for the model_type identifier
model_type = get_user_field(
"What identifier would you like to use for the `model_type` of this model? ",
default_value=default_patterns.model_type,
)
# Ask for the lowercase module (folder) name
model_lower_cased = get_user_field(
"What lowercase name would you like to use for the module (folder) of this model? ",
default_value=default_patterns.model_lower_cased,
)
# Ask for the camel-cased prefix of the model classes
model_camel_cased = get_user_field(
"What prefix (camel-cased) would you like to use for the model classes of this model (e.g. Roberta)? ",
default_value=default_patterns.model_camel_cased,
)
# Ask for the upper-cased prefix of the model constants
model_upper_cased = get_user_field(
"What prefix (upper-cased) would you like to use for the constants relative to this model? ",
default_value=default_patterns.model_upper_cased,
)
# Ask for the name of the config class
config_class = get_user_field(
"What will be the name of the config class for this model? ", default_value=f"{model_camel_cased}Config"
)
# Ask for the checkpoint identifier of the new model
checkpoint = get_user_field(
"Please give a checkpoint identifier (on the model Hub) for this new model (e.g. facebook/FacebookAI/roberta-base): "
)
# Build the list of the old processing classes, keeping only the non-None ones
old_processing_classes = [
c
for c in [old_image_processor_class, old_feature_extractor_class, old_tokenizer_class, old_processor_class]
if c is not None
]
# Turn the list into a comma-separated string
old_processing_classes = ", ".join(old_processing_classes)
# Ask whether the new model reuses the same processing classes as the old one
keep_processing = get_user_field(
f"Will your new model use the same processing class as {old_model_type} ({old_processing_classes}) (yes/no)? ",
convert_to=convert_to_bool,
fallback_message="Please answer yes/no, y/n, true/false or 1/0. ",
)
# Pick the processing classes of the new model according to that answer
if keep_processing:
image_processor_class = old_image_processor_class
feature_extractor_class = old_feature_extractor_class
processor_class = old_processor_class
tokenizer_class = old_tokenizer_class
else:
# Otherwise, ask for a new name for each processing class that exists on the old model
if old_tokenizer_class is not None:
tokenizer_class = get_user_field(
"What will be the name of the tokenizer class for this model? ",
default_value=f"{model_camel_cased}Tokenizer",
)
else:
tokenizer_class = None
if old_image_processor_class is not None:
image_processor_class = get_user_field(
"What will be the name of the image processor class for this model? ",
default_value=f"{model_camel_cased}ImageProcessor",
)
else:
image_processor_class = None
if old_feature_extractor_class is not None:
feature_extractor_class = get_user_field(
"What will be the name of the feature extractor class for this model? ",
default_value=f"{model_camel_cased}FeatureExtractor",
)
else:
feature_extractor_class = None
if old_processor_class is not None:
processor_class = get_user_field(
"What will be the name of the processor class for this model? ",
default_value=f"{model_camel_cased}Processor",
)
else:
processor_class = None
# Build the ModelPatterns object holding all the attributes of the new model
model_patterns = ModelPatterns(
model_name,
checkpoint,
model_type=model_type,
model_lower_cased=model_lower_cased,
model_camel_cased=model_camel_cased,
model_upper_cased=model_upper_cased,
config_class=config_class,
tokenizer_class=tokenizer_class,
image_processor_class=image_processor_class,
feature_extractor_class=feature_extractor_class,
processor_class=processor_class,
)
# Ask whether # Copied from statements should be added when creating the new modeling file
add_copied_from = get_user_field(
"Should we add # Copied from statements when creating the new modeling file (yes/no)? ",
convert_to=convert_to_bool,
default_value="yes",
fallback_message="Please answer yes/no, y/n, true/false or 1/0.",
)
# Ask whether a version of the new model should be added in all the frameworks implemented by the old model type
all_frameworks = get_user_field(
"Should we add a version of your new model in all the frameworks implemented by"
f" {old_model_type} ({old_frameworks}) (yes/no)? ",
convert_to=convert_to_bool,  # converts the user's answer to a bool
default_value="yes",  # defaults to "yes"
fallback_message="Please answer yes/no, y/n, true/false or 1/0.",  # shown when the answer is not valid
)
# If the user wants the new model in all frameworks
if all_frameworks:
frameworks = None  # leave the frameworks list as None
else:
# Otherwise, ask for the list of frameworks to use
frameworks = get_user_field(
"Please enter the list of framworks you want (pt, tf, flax) separated by spaces",
# every entry must be one of "pt", "tf" or "flax"
is_valid_answer=lambda x: all(p in ["pt", "tf", "flax"] for p in x.split(" ")),
)
frameworks = list(set(frameworks.split(" ")))  # deduplicate the entered frameworks
# Return the old model type, the new model patterns, the Copied from flag, the frameworks and the old checkpoint
return (old_model_type, model_patterns, add_copied_from, frameworks, old_checkpoint)
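Putting the pieces together, the interactive path of the command (no `--config_file`) boils down to the following sketch, mirroring `AddNewModelLikeCommand.__init__` and `run` above:
old_model_type, model_patterns, add_copied_from, frameworks, old_checkpoint = get_user_input()
create_new_model_like(
    model_type=old_model_type,
    new_model_patterns=model_patterns,
    add_copied_from=add_copied_from,
    frameworks=frameworks,
    old_checkpoint=old_checkpoint,
)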