Transformers Source Code Walkthrough (11)
.\models\audio_spectrogram_transformer\feature_extraction_audio_spectrogram_transformer.py
"""
Feature extractor class for Audio Spectrogram Transformer.
"""
from typing import List, Optional, Union
import numpy as np
from ...audio_utils import mel_filter_bank, spectrogram, window_function
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import TensorType, is_speech_available, is_torch_available, logging
if is_speech_available():
import torchaudio.compliance.kaldi as ta_kaldi
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class ASTFeatureExtractor(SequenceFeatureExtractor):
r"""
Constructs an Audio Spectrogram Transformer (AST) feature extractor.
This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
most of the main methods. Users should refer to this superclass for more information regarding those methods.
This class extracts mel-filter bank features from raw speech using TorchAudio if installed or using numpy
otherwise, pads/truncates them to a fixed length and normalizes them using a mean and standard deviation.
Args:
feature_size (`int`, *optional*, defaults to 1):
The feature dimension of the extracted features.
sampling_rate (`int`, *optional*, defaults to 16000):
The sampling rate at which the audio files should be digitized, expressed in hertz (Hz).
num_mel_bins (`int`, *optional*, defaults to 128):
Number of Mel-frequency bins.
max_length (`int`, *optional*, defaults to 1024):
Maximum length to which to pad/truncate the extracted features.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the log-Mel features using `mean` and `std`.
mean (`float`, *optional*, defaults to -4.2677393):
The mean value used to normalize the log-Mel features. Uses the AudioSet mean by default.
std (`float`, *optional*, defaults to 4.5689974):
The standard deviation value used to normalize the log-Mel features. Uses the AudioSet standard deviation
by default.
return_attention_mask (`bool`, *optional*, defaults to `False`):
Whether or not [`~ASTFeatureExtractor.__call__`] should return `attention_mask`.
"""
model_input_names = ["input_values", "attention_mask"]
def __init__(
self,
feature_size=1,
sampling_rate=16000,
num_mel_bins=128,
max_length=1024,
padding_value=0.0,
do_normalize=True,
mean=-4.2677393,
std=4.5689974,
return_attention_mask=False,
**kwargs,
):
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
self.num_mel_bins = num_mel_bins
self.max_length = max_length
self.do_normalize = do_normalize
self.mean = mean
self.std = std
self.return_attention_mask = return_attention_mask
if not is_speech_available():
mel_filters = mel_filter_bank(
num_frequency_bins=256,
num_mel_filters=self.num_mel_bins,
min_frequency=20,
max_frequency=sampling_rate // 2,
sampling_rate=sampling_rate,
norm=None,
mel_scale="kaldi",
triangularize_in_mel_space=True,
)
self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))
self.window = window_function(400, "hann", periodic=False)
def _extract_fbank_features(
self,
waveform: np.ndarray,
max_length: int,
) -> np.ndarray:
"""
Get mel-filter bank features using TorchAudio. Note that TorchAudio requires 16-bit signed integers as inputs
and hence the waveform should not be normalized before feature extraction.
"""
if is_speech_available():
waveform = torch.from_numpy(waveform).unsqueeze(0)
fbank = ta_kaldi.fbank(
waveform,
sample_frequency=self.sampling_rate,
window_type="hanning",
num_mel_bins=self.num_mel_bins,
)
else:
waveform = np.squeeze(waveform)
fbank = spectrogram(
waveform,
self.window,
frame_length=400,
hop_length=160,
fft_length=512,
power=2.0,
center=False,
preemphasis=0.97,
mel_filters=self.mel_filters,
log_mel="log",
mel_floor=1.192092955078125e-07,
remove_dc_offset=True,
).T
fbank = torch.from_numpy(fbank)
n_frames = fbank.shape[0]
difference = max_length - n_frames
if difference > 0:
pad_module = torch.nn.ZeroPad2d((0, 0, 0, difference))
fbank = pad_module(fbank)
elif difference < 0:
fbank = fbank[0:max_length, :]
fbank = fbank.numpy()
return fbank
def normalize(self, input_values: np.ndarray) -> np.ndarray:
return (input_values - (self.mean)) / (self.std * 2)
def __call__(
self,
raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
sampling_rate: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
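The body of `__call__` is cut off in this excerpt; per the class docstring above, it extracts the filter-bank features, pads/truncates them to `max_length` frames, optionally normalizes them, and returns a `BatchFeature`. A minimal usage sketch (the all-zeros array is just a stand-in for a real 16 kHz mono waveform):

import numpy as np
from transformers import ASTFeatureExtractor

waveform = np.zeros(16000, dtype=np.float32)  # placeholder for one second of 16 kHz audio
extractor = ASTFeatureExtractor()  # defaults: 128 mel bins, max_length=1024, AudioSet mean/std
inputs = extractor(waveform, sampling_rate=16000, return_tensors="pt")
print(inputs["input_values"].shape)  # torch.Size([1, 1024, 128])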
.\models\audio_spectrogram_transformer\modeling_audio_spectrogram_transformer.py
""" PyTorch Audio Spectrogram Transformer (AST) model."""
import math
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_audio_spectrogram_transformer import ASTConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "ASTConfig"
_CHECKPOINT_FOR_DOC = "MIT/ast-finetuned-audioset-10-10-0.4593"
_EXPECTED_OUTPUT_SHAPE = [1, 1214, 768]
_SEQ_CLASS_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593"
_SEQ_CLASS_EXPECTED_OUTPUT = "'Speech'"
_SEQ_CLASS_EXPECTED_LOSS = 0.17
AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"MIT/ast-finetuned-audioset-10-10-0.4593",
]
class ASTEmbeddings(nn.Module):
"""
Construct the CLS token, position and patch embeddings.
"""
def __init__(self, config: ASTConfig) -> None:
super().__init__()
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.patch_embeddings = ASTPatchEmbeddings(config)
frequency_out_dimension, time_out_dimension = self.get_shape(config)
num_patches = frequency_out_dimension * time_out_dimension
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.config = config
def get_shape(self, config):
frequency_out_dimension = (config.num_mel_bins - config.patch_size) // config.frequency_stride + 1
time_out_dimension = (config.max_length - config.patch_size) // config.time_stride + 1
return frequency_out_dimension, time_out_dimension
def forward(self, input_values: torch.Tensor) -> torch.Tensor:
batch_size = input_values.shape[0]
embeddings = self.patch_embeddings(input_values)
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
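Plugging in the default `ASTConfig` values (assumed here: `patch_size=16`, `frequency_stride=10`, `time_stride=10`, `num_mel_bins=128`, `max_length=1024`, `hidden_size=768`), `get_shape` yields 12 frequency patches and 101 time patches; adding the CLS and distillation tokens reproduces the `[1, 1214, 768]` output shape quoted as `_EXPECTED_OUTPUT_SHAPE` in the modeling file below:

# Worked example of get_shape with the assumed default ASTConfig values.
frequency_out_dimension = (128 - 16) // 10 + 1               # 12
time_out_dimension = (1024 - 16) // 10 + 1                   # 101
num_patches = frequency_out_dimension * time_out_dimension   # 1212
sequence_length = num_patches + 2                            # 1214 = patches + CLS + distillation token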
class ASTSelfAttention(nn.Module):
def __init__(self, config: ASTConfig) -> None:
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
f"heads {config.num_attention_heads}."
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
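A quick shape sketch of `transpose_for_scores`, assuming the default 768-dimensional hidden size split across 12 attention heads:

import torch

hidden = torch.randn(1, 1214, 768)                       # (batch, seq_len, hidden_size)
heads = hidden.view(1, 1214, 12, 64).permute(0, 2, 1, 3)  # split hidden into heads, move heads forward
print(heads.shape)                                       # torch.Size([1, 12, 1214, 64])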
class ASTSelfOutput(nn.Module):
"""
The residual connection is defined in ASTLayer instead of here (as is the case with other models), due to the
layernorm applied before each block.
"""
def __init__(self, config: ASTConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class ASTAttention(nn.Module):
def __init__(self, config: ASTConfig) -> None:
super().__init__()
self.attention = ASTSelfAttention(config)
self.output = ASTSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads: Set[int]) -> None:
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
)
self.attention.query = prune_linear_layer(self.attention.query, index)
self.attention.key = prune_linear_layer(self.attention.key, index)
self.attention.value = prune_linear_layer(self.attention.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
class ASTIntermediate(nn.Module):
def __init__(self, config: ASTConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class ASTOutput(nn.Module):
def __init__(self, config: ASTConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states + input_tensor
return hidden_states
class ASTLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config: ASTConfig) -> None:
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = ASTAttention(config)
self.intermediate = ASTIntermediate(config)
self.output = ASTOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states),
head_mask,
output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:]
hidden_states = attention_output + hidden_states
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = self.output(layer_output, hidden_states)
outputs = (layer_output,) + outputs
return outputs
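In other words, `ASTLayer` is the standard pre-LayerNorm transformer block: the first residual addition happens in `forward` and the second inside `ASTOutput`, with LayerNorm applied before each sub-module. A sketch of the data flow, not the library API:

def pre_norm_block(x, attention, mlp, ln_before, ln_after):
    # First residual: attention over the normalized input.
    x = x + attention(ln_before(x))
    # Second residual: MLP over the normalized intermediate state.
    return x + mlp(ln_after(x))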
class ASTEncoder(nn.Module):
def __init__(self, config: ASTConfig) -> None:
super().__init__()
self.config = config
self.layer = nn.ModuleList([ASTLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class ASTPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = ASTConfig
base_model_prefix = "audio_spectrogram_transformer"
main_input_name = "input_values"
supports_gradient_checkpointing = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data = nn.init.trunc_normal_(
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
).to(module.weight.dtype)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`ASTConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, max_length, num_mel_bins)`):
Float values mel features extracted from the raw audio waveform. Raw audio waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~ASTFeatureExtractor.__call__`]
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
The following class defines the bare AST (Audio Spectrogram Transformer) model, which outputs raw hidden states without any task-specific head on top.
@add_start_docstrings(
"The bare AST Model transformer outputting raw hidden-states without any specific head on top.",
AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING,
)
class ASTModel(ASTPreTrainedModel):
def __init__(self, config: ASTConfig) -> None:
super().__init__(config)
self.config = config
# Initialize AST embeddings and encoder based on provided configuration
self.embeddings = ASTEmbeddings(config)
self.encoder = ASTEncoder(config)
# Apply layer normalization across the hidden size dimension
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> ASTPatchEmbeddings:
# Retrieve the patch embeddings used for input
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
"""
Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel.
"""
for layer, heads in heads_to_prune.items():
# Prune specified attention heads in each encoder layer
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
input_values: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
"""
Perform the forward pass of the AST model.
Args:
input_values (Optional[torch.Tensor]): Input spectrogram features of shape `(batch_size, max_length, num_mel_bins)`.
head_mask (Optional[torch.Tensor]): Mask to nullify selected heads of the model.
output_attentions (Optional[bool]): Whether to return the attention weights.
output_hidden_states (Optional[bool]): Whether to return the hidden states of all layers.
return_dict (Optional[bool]): Whether to return a `ModelOutput` instead of a plain tuple.
Returns:
`BaseModelOutputWithPooling` or `tuple`: The final hidden states and the pooled output.
"""
# If output_attentions is not explicitly specified, fall back to the config default
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# If output_hidden_states is not explicitly specified, fall back to the config default
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# If return_dict is not explicitly specified, fall back to the config default
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_values is None:
# input_values is required; raise a ValueError if it is missing
raise ValueError("You have to specify input_values")
# Prepare the head mask if needed
# 1.0 in head_mask means the head is kept
# attention_probs has shape bsz x n_heads x N x N
# the input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
# Pass the input values through the embedding layer
embedding_output = self.embeddings(input_values)
# Feed the embedding output to the encoder
encoder_outputs = self.encoder(
embedding_output,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# The first element of the encoder outputs is the sequence output
sequence_output = encoder_outputs[0]
# Apply the final LayerNorm to the sequence output
sequence_output = self.layernorm(sequence_output)
# Pooled output is the average of the first two token states (CLS and distillation tokens)
pooled_output = (sequence_output[:, 0] + sequence_output[:, 1]) / 2
if not return_dict:
# If a dict is not requested, return the outputs as a plain tuple
return (sequence_output, pooled_output) + encoder_outputs[1:]
# Otherwise wrap the outputs in a BaseModelOutputWithPooling
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
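An end-to-end sketch of `ASTModel` on random features, using the checkpoint named in `_CHECKPOINT_FOR_DOC`; the pooled output is the mean of the CLS and distillation token states computed above:

import torch
from transformers import ASTModel

model = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
dummy_features = torch.randn(1, 1024, 128)          # (batch, max_length, num_mel_bins)
with torch.no_grad():
    outputs = model(input_values=dummy_features)
print(outputs.last_hidden_state.shape)              # torch.Size([1, 1214, 768]), see _EXPECTED_OUTPUT_SHAPE
print(outputs.pooler_output.shape)                  # torch.Size([1, 768])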
class ASTMLPHead(nn.Module):
def __init__(self, config: ASTConfig):
super().__init__()
# LayerNorm used to normalize the hidden state
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# If num_labels > 0, use a linear layer as the classifier; otherwise use an identity mapping
self.dense = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
def forward(self, hidden_state):
# Normalize the hidden state with LayerNorm
hidden_state = self.layernorm(hidden_state)
# Apply the linear layer (or identity) to obtain the classification output
hidden_state = self.dense(hidden_state)
return hidden_state
@add_start_docstrings(
"""
Audio Spectrogram Transformer model with an audio classification head on top (a linear layer on top of the pooled
output) e.g. for datasets like AudioSet, Speech Commands v2.
""",
AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING,
)
class ASTForAudioClassification(ASTPreTrainedModel):
def __init__(self, config: ASTConfig) -> None:
super().__init__(config)
self.num_labels = config.num_labels
# Initialize the ASTModel backbone, which encodes the audio spectrogram
self.audio_spectrogram_transformer = ASTModel(config)
# Classifier head
self.classifier = ASTMLPHead(config)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_SEQ_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
)
def forward(
self,
input_values: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the audio classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
# Use the config default for return_dict if it is not explicitly specified
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Run the audio spectrogram transformer backbone on the inputs
outputs = self.audio_spectrogram_transformer(
input_values,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Take the pooled output (pooled_output) from the backbone outputs
pooled_output = outputs[1]
# Classify the pooled output to obtain the logits
logits = self.classifier(pooled_output)
# Loss defaults to None
loss = None
# Compute the loss if labels are provided
if labels is not None:
# If the problem type is not set, infer it from the label dtype and the number of labels
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
# Pick the loss function matching the problem type
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
# If a dict is not requested, return the logits plus any extra outputs (and the loss, if computed)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
# Otherwise wrap everything in a SequenceClassifierOutput
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
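A matching sketch for the classification head, mirroring the doc-sample constants (`_SEQ_CLASS_CHECKPOINT`, expected label `'Speech'`); the zero waveform is only a placeholder for a real 16 kHz recording:

import numpy as np
import torch
from transformers import ASTFeatureExtractor, ASTForAudioClassification

checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
extractor = ASTFeatureExtractor.from_pretrained(checkpoint)
model = ASTForAudioClassification.from_pretrained(checkpoint)

waveform = np.zeros(16000, dtype=np.float32)        # replace with a real waveform
inputs = extractor(waveform, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
predicted_id = int(logits.argmax(-1))
print(model.config.id2label[predicted_id])          # e.g. 'Speech' for speech audio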
.\models\audio_spectrogram_transformer\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_audio_spectrogram_transformer": [
"AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"ASTConfig",
],
"feature_extraction_audio_spectrogram_transformer": ["ASTFeatureExtractor"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_audio_spectrogram_transformer"] = [
"AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"ASTForAudioClassification",
"ASTModel",
"ASTPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_audio_spectrogram_transformer import (
AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
ASTConfig,
)
from .feature_extraction_audio_spectrogram_transformer import ASTFeatureExtractor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_audio_spectrogram_transformer import (
AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
ASTForAudioClassification,
ASTModel,
ASTPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\auto\auto_factory.py
import copy
import importlib
import json
import os
import warnings
from collections import OrderedDict
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...utils import (
CONFIG_NAME,
cached_file,
copy_func,
extract_commit_hash,
find_adapter_config_file,
is_peft_available,
logging,
requires_backends,
)
from .configuration_auto import (
AutoConfig,
model_type_to_module_name,
replace_list_option_in_docstrings,
)
logger = logging.get_logger(__name__)
CLASS_DOCSTRING = """
This is a generic model class that will be instantiated as one of the model classes of the library when created
with the [`~BaseAutoModelClass.from_pretrained`] class method or the [`~BaseAutoModelClass.from_config`] class
method.
This class cannot be instantiated directly using `__init__()` (throws an error).
"""
FROM_CONFIG_DOCSTRING = """
Instantiates one of the model classes of the library from a configuration.
Note:
Loading a model from its configuration file does **not** load the model weights. It only affects the
model's configuration. Use [`~BaseAutoModelClass.from_pretrained`] to load the model weights.
Args:
config ([`PretrainedConfig`]):
The model class to instantiate is selected based on the configuration class:
List options
attn_implementation (`str`, *optional*):
The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
Examples:
```
>>> from transformers import AutoConfig, BaseAutoModelClass
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("checkpoint_placeholder")
>>> model = BaseAutoModelClass.from_config(config)
```
"""
"""
BaseAutoModelClass is a base class for auto models. It provides functionality to select and instantiate a model class based on a configuration.
"""
def _get_model_class(config, model_mapping):
supported_models = model_mapping[type(config)]
if not isinstance(supported_models, (list, tuple)):
return supported_models
name_to_model = {model.__name__: model for model in supported_models}
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in name_to_model:
return name_to_model[arch]
elif f"TF{arch}" in name_to_model:
return name_to_model[f"TF{arch}"]
elif f"Flax{arch}" in name_to_model:
return name_to_model[f"Flax{arch}"]
return supported_models[0]
class _BaseAutoModelClass:
_model_mapping = None
def __init__(self, *args, **kwargs):
raise EnvironmentError(
f"{self.__class__.__name__} is designed to be instantiated "
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
f"`{self.__class__.__name__}.from_config(config)` methods."
)
@classmethod
def from_config(cls, config, **kwargs):
trust_remote_code = kwargs.pop("trust_remote_code", None)
has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
has_local_code = type(config) in cls._model_mapping.keys()
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, config._name_or_path, has_local_code, has_remote_code
)
if has_remote_code and trust_remote_code:
class_ref = config.auto_map[cls.__name__]
if "--" in class_ref:
repo_id, class_ref = class_ref.split("--")
else:
repo_id = config.name_or_path
model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
if os.path.isdir(config._name_or_path):
model_class.register_for_auto_class(cls.__name__)
else:
cls.register(config.__class__, model_class, exist_ok=True)
_ = kwargs.pop("code_revision", None)
return model_class._from_config(config, **kwargs)
elif type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping)
return model_class._from_config(config, **kwargs)
raise ValueError(
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
)
@classmethod
def register(cls, config_class, model_class, exist_ok=False):
"""
Register a new model for this class.
Args:
config_class ([`PretrainedConfig`]):
The configuration corresponding to the model to register.
model_class ([`PreTrainedModel`]):
The model to register.
"""
if hasattr(model_class, "config_class") and model_class.config_class != config_class:
raise ValueError(
"The model class you are passing has a `config_class` attribute that is not consistent with the "
f"config class you passed (model has {model_class.config_class} and you passed {config_class}. Fix "
"one of those so they match!"
)
cls._model_mapping.register(config_class, model_class, exist_ok=exist_ok)
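This `register` hook is what custom models use to plug into the auto classes. A hedged sketch with hypothetical class names (`MyConfig`/`MyModel` are not part of the library):

from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel

class MyConfig(PretrainedConfig):        # hypothetical custom config
    model_type = "my-model"

class MyModel(PreTrainedModel):          # hypothetical custom model
    config_class = MyConfig

AutoConfig.register("my-model", MyConfig)
AutoModel.register(MyConfig, MyModel)    # AutoModel.from_config(MyConfig(...)) now resolves to MyModel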
class _BaseAutoBackboneClass(_BaseAutoModelClass):
_model_mapping = None
@classmethod
def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
requires_backends(cls, ["vision", "timm"])
from ...models.timm_backbone import TimmBackboneConfig
config = kwargs.pop("config", TimmBackboneConfig())
if kwargs.get("out_features", None) is not None:
raise ValueError("Cannot specify `out_features` for timm backbones")
if kwargs.get("output_loading_info", False):
raise ValueError("Cannot specify `output_loading_info=True` when loading from timm")
num_channels = kwargs.pop("num_channels", config.num_channels)
features_only = kwargs.pop("features_only", config.features_only)
use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone)
out_indices = kwargs.pop("out_indices", config.out_indices)
config = TimmBackboneConfig(
backbone=pretrained_model_name_or_path,
num_channels=num_channels,
features_only=features_only,
use_pretrained_backbone=use_pretrained_backbone,
out_indices=out_indices,
)
return super().from_config(config, **kwargs)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
use_timm_backbone = kwargs.pop("use_timm_backbone", False)
if use_timm_backbone:
return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
def insert_head_doc(docstring, head_doc=""):
if len(head_doc) > 0:
return docstring.replace(
"one of the model classes of the library ",
f"one of the model classes of the library (with a {head_doc} head) ",
)
else:
return docstring.replace(
"one of the model classes of the library ", "one of the base model classes of the library "
)
def auto_class_update(cls, checkpoint_for_example="google-bert/bert-base-cased", head_doc=""):
model_mapping = cls._model_mapping
name = cls.__name__
class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc)
cls.__doc__ = class_docstring.replace("BaseAutoModelClass", name)
from_config = copy_func(_BaseAutoModelClass.from_config)
from_config_docstring = insert_head_doc(FROM_CONFIG_DOCSTRING, head_doc=head_doc)
from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name)
from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
from_config.__doc__ = from_config_docstring
from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config)
cls.from_config = classmethod(from_config)
if name.startswith("TF"):
from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING
elif name.startswith("Flax"):
from_pretrained_docstring = FROM_PRETRAINED_FLAX_DOCSTRING
else:
from_pretrained_docstring = FROM_PRETRAINED_TORCH_DOCSTRING
from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained)
from_pretrained_docstring = insert_head_doc(from_pretrained_docstring, head_doc=head_doc)
from_pretrained_docstring = from_pretrained_docstring.replace("BaseAutoModelClass", name)
from_pretrained_docstring = from_pretrained_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
from_pretrained.__doc__ = from_pretrained_docstring
from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained)
cls.from_pretrained = classmethod(from_pretrained)
return cls
def get_values(model_mapping):
result = []
for model in model_mapping.values():
if isinstance(model, (list, tuple)):
result += list(model)
else:
result.append(model)
return result
def getattribute_from_module(module, attr):
if attr is None:
return None
if isinstance(attr, tuple):
return tuple(getattribute_from_module(module, a) for a in attr)
if hasattr(module, attr):
return getattr(module, attr)
transformers_module = importlib.import_module("transformers")
if module != transformers_module:
try:
return getattribute_from_module(transformers_module, attr)
except ValueError:
raise ValueError(f"Could not find {attr} neither in {module} nor in {transformers_module}!")
else:
raise ValueError(f"Could not find {attr} in {transformers_module}!")
class _LazyAutoMapping(OrderedDict):
"""
A mapping from config classes to objects (e.g. model or tokenizer classes) whose keys and values are loaded lazily on access.
Args:
- config_mapping: mapping from model type to config class name
- model_mapping: mapping from model type to model (or tokenizer) class name
"""
def __init__(self, config_mapping, model_mapping):
self._config_mapping = config_mapping
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
self._model_mapping = model_mapping
self._model_mapping._model_mapping = self
self._extra_content = {}
self._modules = {}
def __len__(self):
common_keys = set(self._config_mapping.keys()).intersection(self._model_mapping.keys())
return len(common_keys) + len(self._extra_content)
def __getitem__(self, key):
if key in self._extra_content:
return self._extra_content[key]
model_type = self._reverse_config_mapping[key.__name__]
if model_type in self._model_mapping:
model_name = self._model_mapping[model_type]
return self._load_attr_from_module(model_type, model_name)
model_types = [k for k, v in self._config_mapping.items() if v == key.__name__]
for mtype in model_types:
if mtype in self._model_mapping:
model_name = self._model_mapping[mtype]
return self._load_attr_from_module(mtype, model_name)
raise KeyError(key)
def _load_attr_from_module(self, model_type, attr):
module_name = model_type_to_module_name(model_type)
if module_name not in self._modules:
self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
return getattribute_from_module(self._modules[module_name], attr)
def keys(self):
mapping_keys = [
self._load_attr_from_module(key, name)
for key, name in self._config_mapping.items()
if key in self._model_mapping.keys()
]
return mapping_keys + list(self._extra_content.keys())
def get(self, key, default):
try:
return self.__getitem__(key)
except KeyError:
return default
def __bool__(self):
return bool(self.keys())
def values(self):
mapping_values = [
self._load_attr_from_module(key, name)
for key, name in self._model_mapping.items()
if key in self._config_mapping.keys()
]
return mapping_values + list(self._extra_content.values())
def items(self):
mapping_items = [
(
self._load_attr_from_module(key, self._config_mapping[key]),
self._load_attr_from_module(key, self._model_mapping[key]),
)
for key in self._model_mapping.keys()
if key in self._config_mapping.keys()
]
return mapping_items + list(self._extra_content.items())
def __iter__(self):
return iter(self.keys())
def __contains__(self, item):
if item in self._extra_content:
return True
if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
return False
model_type = self._reverse_config_mapping[item.__name__]
return model_type in self._model_mapping
def register(self, key, value, exist_ok=False):
"""
Register a new model in this mapping.
"""
if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
model_type = self._reverse_config_mapping[key.__name__]
if model_type in self._model_mapping.keys() and not exist_ok:
raise ValueError(f"'{key}' is already used by a Transformers model.")
self._extra_content[key] = value
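A sketch of how this lazy mapping behaves in practice: indexing `MODEL_MAPPING` (defined in `modeling_auto.py`, excerpted later in this post) with a config class imports the corresponding model class on demand:

from transformers import ASTConfig
from transformers.models.auto.modeling_auto import MODEL_MAPPING

# __getitem__ goes config class -> model type -> lazily imported model class.
model_class = MODEL_MAPPING[ASTConfig]
print(model_class.__name__)   # "ASTModel"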
.\models\auto\configuration_auto.py
""" Auto Config class."""
import importlib
import os
import re
import warnings
from collections import OrderedDict
from typing import List, Union
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...utils import CONFIG_NAME, logging
logger = logging.get_logger(__name__)
CONFIG_MAPPING_NAMES = OrderedDict(
[]
)
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
[]
)
MODEL_NAMES_MAPPING = OrderedDict(
[]
)
DEPRECATED_MODELS = [
"bort",
"mctct",
"mmbt",
"open_llama",
"retribert",
"tapex",
"trajectory_transformer",
"transfo_xl",
"van",
]
SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
[
("openai-gpt", "openai"),
("data2vec-audio", "data2vec"),
("data2vec-text", "data2vec"),
("data2vec-vision", "data2vec"),
("donut-swin", "donut"),
("kosmos-2", "kosmos2"),
("maskformer-swin", "maskformer"),
("xclip", "x_clip"),
("clip_vision_model", "clip"),
("siglip_vision_model", "siglip"),
("chinese_clip_vision_model", "chinese_clip"),
]
)
def model_type_to_module_name(key):
"""Converts a config key to the corresponding module."""
if key in SPECIAL_MODEL_TYPE_TO_MODULE_NAME:
return SPECIAL_MODEL_TYPE_TO_MODULE_NAME[key]
key = key.replace("-", "_")
if key in DEPRECATED_MODELS:
key = f"deprecated.{key}"
return key
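A few concrete mappings this function produces, based on the tables above (a sketch):

from transformers.models.auto.configuration_auto import model_type_to_module_name

print(model_type_to_module_name("data2vec-audio"))                 # "data2vec", special-cased shared module
print(model_type_to_module_name("audio-spectrogram-transformer"))  # "audio_spectrogram_transformer", dashes -> underscores
print(model_type_to_module_name("transfo-xl"))                     # "deprecated.transfo_xl", routed to the deprecated package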
def config_class_to_model_type(config):
"""Converts a config class name to the corresponding model type"""
for key, cls in CONFIG_MAPPING_NAMES.items():
if cls == config:
return key
for key, cls in CONFIG_MAPPING._extra_content.items():
if cls.__name__ == config:
return key
return None
class _LazyConfigMapping(OrderedDict):
"""
A dictionary that lazily load its values when they are requested.
"""
def __init__(self, mapping):
self._mapping = mapping
self._extra_content = {}
self._modules = {}
def __getitem__(self, key):
if key in self._extra_content:
return self._extra_content[key]
if key not in self._mapping:
raise KeyError(key)
value = self._mapping[key]
module_name = model_type_to_module_name(key)
if module_name not in self._modules:
self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
if hasattr(self._modules[module_name], value):
return getattr(self._modules[module_name], value)
transformers_module = importlib.import_module("transformers")
return getattr(transformers_module, value)
def keys(self):
return list(self._mapping.keys()) + list(self._extra_content.keys())
def values(self):
return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values())
def items(self):
return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())
def __iter__(self):
return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
def __contains__(self, item):
return item in self._mapping or item in self._extra_content
def register(self, key, value, exist_ok=False):
"""
Register a new configuration in this mapping.
"""
if key in self._mapping.keys() and not exist_ok:
raise ValueError(f"'{key}' is already used by a Transformers config, pick another name.")
self._extra_content[key] = value
CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
class _LazyLoadAllMappings(OrderedDict):
"""
A mapping that will load all pairs of key values at the first access (either by indexing, requesting keys, values,
etc.)
Args:
mapping: The mapping to load.
"""
def __init__(self, mapping):
self._mapping = mapping
self._initialized = False
self._data = {}
def _initialize(self):
if self._initialized:
return
warnings.warn(
"ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. "
"It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.",
FutureWarning,
)
for model_type, map_name in self._mapping.items():
module_name = model_type_to_module_name(model_type)
module = importlib.import_module(f".{module_name}", "transformers.models")
mapping = getattr(module, map_name)
self._data.update(mapping)
self._initialized = True
def __getitem__(self, key):
self._initialize()
return self._data[key]
def keys(self):
self._initialize()
return self._data.keys()
def values(self):
self._initialize()
return self._data.values()
def items(self):
self._initialize()
return self._data.items()
def __iter__(self):
self._initialize()
return iter(self._data)
def __contains__(self, item):
self._initialize()
return item in self._data
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = _LazyLoadAllMappings(CONFIG_ARCHIVE_MAP_MAPPING_NAMES)
def _get_class_name(model_class: Union[str, List[str]]):
if isinstance(model_class, (list, tuple)):
return " or ".join([f"[`{c}`]" for c in model_class if c is not None])
return f"[`{model_class}`]"
def _list_model_options(indent, config_to_class=None, use_model_types=True):
if config_to_class is None and not use_model_types:
raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
if use_model_types:
if config_to_class is None:
model_type_to_name = {model_type: f"[`{config}`]" for model_type, config in CONFIG_MAPPING_NAMES.items()}
else:
model_type_to_name = {
model_type: _get_class_name(model_class)
for model_type, model_class in config_to_class.items()
if model_type in MODEL_NAMES_MAPPING
}
lines = [
f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
for model_type in sorted(model_type_to_name.keys())
]
else:
config_to_name = {
CONFIG_MAPPING_NAMES[config]: _get_class_name(clas)
for config, clas in config_to_class.items()
if config in CONFIG_MAPPING_NAMES
}
config_to_model_name = {
config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items()
}
lines = [
f"{indent}- [`{config_name}`] configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
for config_name in sorted(config_to_name.keys())
]
return "\n".join(lines)
def replace_list_option_in_docstrings(config_to_class=None, use_model_types=True):
def docstring_decorator(fn):
docstrings = fn.__doc__
lines = docstrings.split("\n")
i = 0
while i < len(lines) and re.search(r"^(\s*)List options\s*$", lines[i]) is None:
i += 1
if i < len(lines):
indent = re.search(r"^(\s*)List options\s*$", lines[i]).groups()[0]
if use_model_types:
indent = f"{indent} "
lines[i] = _list_model_options(indent, config_to_class=config_to_class, use_model_types=use_model_types)
docstrings = "\n".join(lines)
else:
raise ValueError(
f"The function {fn} should have an empty 'List options' in its docstring as placeholder, current"
f" docstring is:\n{docstrings}"
)
fn.__doc__ = docstrings
return fn
return docstring_decorator
class AutoConfig:
r"""
This is a generic configuration class that will be instantiated as one of the configuration classes of the library
when created with the [`~AutoConfig.from_pretrained`] class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
"""
def __init__(self):
raise EnvironmentError(
"AutoConfig is designed to be instantiated "
"using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
)
@classmethod
def for_model(cls, model_type: str, *args, **kwargs):
if model_type in CONFIG_MAPPING:
config_class = CONFIG_MAPPING[model_type]
return config_class(*args, **kwargs)
raise ValueError(
f"Unrecognized model identifier: {model_type}. Should contain one of {', '.join(CONFIG_MAPPING.keys())}"
)
@staticmethod
def register(model_type, config, exist_ok=False):
"""
Register a new configuration for this class.
Args:
model_type (`str`): The model type like "bert" or "gpt".
config ([`PretrainedConfig`]): The config to register.
"""
if issubclass(config, PretrainedConfig) and config.model_type != model_type:
raise ValueError(
"The config you are passing has a `model_type` attribute that is not consistent with the model type "
f"you passed (config has {config.model_type} and you passed {model_type}. Fix one of those so they "
"match!"
)
CONFIG_MAPPING.register(model_type, config, exist_ok=exist_ok)
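A brief sketch of the two entry points above: `for_model` builds a default config from its registered model type, and `register` adds a new mapping (AST's registered model type is assumed to be "audio-spectrogram-transformer"):

from transformers import AutoConfig

config = AutoConfig.for_model("audio-spectrogram-transformer")
print(type(config).__name__)   # "ASTConfig"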
.\models\auto\feature_extraction_auto.py
""" AutoFeatureExtractor class."""
import importlib
import json
import os
import warnings
from collections import OrderedDict
from typing import Dict, Optional, Union
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...feature_extraction_utils import FeatureExtractionMixin
from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo, logging
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
CONFIG_MAPPING_NAMES,
AutoConfig,
model_type_to_module_name,
replace_list_option_in_docstrings,
)
logger = logging.get_logger(__name__)
FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
[
]
)
FEATURE_EXTRACTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FEATURE_EXTRACTOR_MAPPING_NAMES)
def feature_extractor_class_from_name(class_name: str):
"""
Looks up a feature extractor class from its class name.
Args:
class_name (str): Name of the feature extractor class.
Returns:
type or None: The matching feature extractor class if found, otherwise None.
"""
for module_name, extractors in FEATURE_EXTRACTOR_MAPPING_NAMES.items():
if class_name in extractors:
module_name = model_type_to_module_name(module_name)
module = importlib.import_module(f".{module_name}", "transformers.models")
try:
return getattr(module, class_name)
except AttributeError:
continue
for _, extractor in FEATURE_EXTRACTOR_MAPPING._extra_content.items():
if getattr(extractor, "__name__", None) == class_name:
return extractor
main_module = importlib.import_module("transformers")
if hasattr(main_module, class_name):
return getattr(main_module, class_name)
return None
def get_feature_extractor_config(
pretrained_model_name_or_path: Union[str, os.PathLike],
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
resume_download: bool = False,
proxies: Optional[Dict[str, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
**kwargs,
):
"""
从预训练模型加载特征提取器的配置信息。
Args:
pretrained_model_name_or_path (Union[str, os.PathLike]): 预训练模型的名称或路径。
cache_dir (Optional[Union[str, os.PathLike]], optional): 缓存目录路径,可选参数。默认为 None。
force_download (bool, optional): 是否强制下载,默认为 False。
resume_download (bool, optional): 是否恢复下载,默认为 False。
proxies (Optional[Dict[str, str]], optional): 代理设置,可选参数。默认为 None。
token (Optional[Union[bool, str]], optional): 访问令牌,可选参数。默认为 None。
revision (Optional[str], optional): 仓库的版本号,可选参数。默认为 None。
local_files_only (bool, optional): 是否仅使用本地文件,默认为 False。
**kwargs: 其他关键字参数。
Returns:
None
"""
pass
use_auth_token = kwargs.pop("use_auth_token", None)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
FutureWarning,
)
if token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
token = use_auth_token
resolved_config_file = get_file_from_repo(
pretrained_model_name_or_path,
FEATURE_EXTRACTOR_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
)
if resolved_config_file is None:
logger.info(
"Could not locate the feature extractor configuration file, will try to use the model config instead."
)
return {}
with open(resolved_config_file, encoding="utf-8") as reader:
return json.load(reader)
class AutoFeatureExtractor:
r"""
This is a generic feature extractor class that will be instantiated as one of the feature extractor classes of the
library when created with the [`AutoFeatureExtractor.from_pretrained`] class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
"""
def __init__(self):
raise EnvironmentError(
"AutoFeatureExtractor is designed to be instantiated "
"using the `AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)` method."
)
@staticmethod
def register(config_class, feature_extractor_class, exist_ok=False):
"""
Register a new feature extractor for this class.
Args:
config_class ([`PretrainedConfig`]):
The configuration corresponding to the model to register.
feature_extractor_class ([`FeatureExtractorMixin`]): The feature extractor to register.
"""
FEATURE_EXTRACTOR_MAPPING.register(config_class, feature_extractor_class, exist_ok=exist_ok)
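Tying back to the AST files at the top of this post, the auto class reads the checkpoint's preprocessor config and resolves the matching extractor class (a sketch):

from transformers import AutoFeatureExtractor

extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
print(type(extractor).__name__)   # "ASTFeatureExtractor"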
.\models\auto\image_processing_auto.py
import importlib
import json
import os
import warnings
from collections import OrderedDict
from typing import Dict, Optional, Union
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...image_processing_utils import ImageProcessingMixin
from ...utils import CONFIG_NAME, IMAGE_PROCESSOR_NAME, get_file_from_repo, logging
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
CONFIG_MAPPING_NAMES,
AutoConfig,
model_type_to_module_name,
replace_list_option_in_docstrings,
)
logger = logging.get_logger(__name__)
IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
)
IMAGE_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES)
def image_processor_class_from_name(class_name: str):
for module_name, extractors in IMAGE_PROCESSOR_MAPPING_NAMES.items():
if class_name in extractors:
module_name = model_type_to_module_name(module_name)
module = importlib.import_module(f".{module_name}", "transformers.models")
try:
return getattr(module, class_name)
except AttributeError:
continue
for _, extractor in IMAGE_PROCESSOR_MAPPING._extra_content.items():
if getattr(extractor, "__name__", None) == class_name:
return extractor
main_module = importlib.import_module("transformers")
if hasattr(main_module, class_name):
return getattr(main_module, class_name)
return None
def get_image_processor_config(
pretrained_model_name_or_path: Union[str, os.PathLike],
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
resume_download: bool = False,
proxies: Optional[Dict[str, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
**kwargs,
):
"""
从预训练模型的图像处理器配置中加载图像处理器配置信息。
"""
Args:
pretrained_model_name_or_path (`str` or `os.PathLike`):
This can be either:
- a string, the *model id* of a pretrained model configuration hosted inside a model repo on
huggingface.co.
- a path to a *directory* containing a configuration file saved using the
[`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
cache_dir (`str` or `os.PathLike`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force to (re-)download the configuration files and override the cached versions if they
exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the image processor configuration from local files.
<Tip>
Passing `token=True` is required when you want to use a private model.
</Tip>
Returns:
`Dict`: The configuration of the image processor.
Examples:
```
image_processor_config = get_image_processor_config("google-bert/bert-base-uncased")
image_processor_config = get_image_processor_config("FacebookAI/xlm-roberta-base")
from transformers import AutoTokenizer
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
image_processor.save_pretrained("image-processor-test")
image_processor_config = get_image_processor_config("image-processor-test")
```
"""
use_auth_token = kwargs.pop("use_auth_token", None)
# If use_auth_token is not None, warn that the argument will be removed in Transformers v5
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
FutureWarning,
)
# If both token and use_auth_token are specified, raise a ValueError
if token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
# Otherwise, fall back to the value of use_auth_token for token
token = use_auth_token
# Resolve the image processor configuration file from the given model name or path
resolved_config_file = get_file_from_repo(
pretrained_model_name_or_path,
IMAGE_PROCESSOR_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
)
# If the image processor configuration file could not be located, log a message and return an empty dict
if resolved_config_file is None:
logger.info(
"Could not locate the image processor configuration file, will try to use the model config instead."
)
return {}
# Open the configuration file with UTF-8 encoding and return its parsed JSON content
with open(resolved_config_file, encoding="utf-8") as reader:
return json.load(reader)
class AutoImageProcessor:
r"""
This is a generic image processor class that will be instantiated as one of the image processor classes of the
library when created with the [`AutoImageProcessor.from_pretrained`] class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
"""
def __init__(self):
# Raise an EnvironmentError to prevent direct instantiation of this class
raise EnvironmentError(
"AutoImageProcessor is designed to be instantiated "
"using the `AutoImageProcessor.from_pretrained(pretrained_model_name_or_path)` method."
)
@classmethod
@replace_list_option_in_docstrings(IMAGE_PROCESSOR_MAPPING_NAMES)
@staticmethod
def register(config_class, image_processor_class, exist_ok=False):
"""
Register a new image processor for this class.
Args:
config_class ([`PretrainedConfig`]):
The configuration corresponding to the model to register.
image_processor_class ([`ImageProcessingMixin`]): The image processor to register.
"""
# Register the given configuration class and image processor class in the global mapping
IMAGE_PROCESSOR_MAPPING.register(config_class, image_processor_class, exist_ok=exist_ok)
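A hedged usage sketch of `register`: pairing a custom configuration class with a custom image processor so that `AutoImageProcessor` can later resolve it. The `my-new-model` names below are hypothetical placeholders, not classes from the library:
from transformers import AutoConfig, AutoImageProcessor, PretrainedConfig
from transformers.image_processing_utils import BaseImageProcessor

class MyNewModelConfig(PretrainedConfig):             # hypothetical config class
    model_type = "my-new-model"

class MyNewModelImageProcessor(BaseImageProcessor):   # hypothetical image processor
    pass

# Register the config type, then map it to the image processor in the auto mapping.
AutoConfig.register("my-new-model", MyNewModelConfig)
AutoImageProcessor.register(MyNewModelConfig, MyNewModelImageProcessor)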
.\models\auto\modeling_auto.py
""" Auto Model class."""
import warnings
from collections import OrderedDict
from ...utils import logging
from .auto_factory import (
_BaseAutoBackboneClass,
_BaseAutoModelClass,
_LazyAutoMapping,
auto_class_update,
)
from .configuration_auto import CONFIG_MAPPING_NAMES
logger = logging.get_logger(__name__)
MODEL_MAPPING_NAMES = OrderedDict(
)
MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
)
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
)
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
)
MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
[
("beit", "BeitModel"),
("bit", "BitModel"),
("conditional_detr", "ConditionalDetrModel"),
("convnext", "ConvNextModel"),
("convnextv2", "ConvNextV2Model"),
("data2vec-vision", "Data2VecVisionModel"),
("deformable_detr", "DeformableDetrModel"),
("deit", "DeiTModel"),
("deta", "DetaModel"),
("detr", "DetrModel"),
("dinat", "DinatModel"),
("dinov2", "Dinov2Model"),
("dpt", "DPTModel"),
("efficientformer", "EfficientFormerModel"),
("efficientnet", "EfficientNetModel"),
("focalnet", "FocalNetModel"),
("glpn", "GLPNModel"),
("imagegpt", "ImageGPTModel"),
("levit", "LevitModel"),
("mobilenet_v1", "MobileNetV1Model"),
("mobilenet_v2", "MobileNetV2Model"),
("mobilevit", "MobileViTModel"),
("mobilevitv2", "MobileViTV2Model"),
("nat", "NatModel"),
("poolformer", "PoolFormerModel"),
("pvt", "PvtModel"),
("regnet", "RegNetModel"),
("resnet", "ResNetModel"),
("segformer", "SegformerModel"),
("siglip_vision_model", "SiglipVisionModel"),
("swiftformer", "SwiftFormerModel"),
("swin", "SwinModel"),
("swin2sr", "Swin2SRModel"),
("swinv2", "Swinv2Model"),
("table-transformer", "TableTransformerModel"),
("timesformer", "TimesformerModel"),
("timm_backbone", "TimmBackbone"),
("van", "VanModel"),
("videomae", "VideoMAEModel"),
("vit", "ViTModel"),
("vit_hybrid", "ViTHybridModel"),
("vit_mae", "ViTMAEModel"),
("vit_msn", "ViTMSNModel"),
("vitdet", "VitDetModel"),
("vivit", "VivitModel"),
("yolos", "YolosModel"),
]
)
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
[
("deit", "DeiTForMaskedImageModeling"),
("focalnet", "FocalNetForMaskedImageModeling"),
("swin", "SwinForMaskedImageModeling"),
("swinv2", "Swinv2ForMaskedImageModeling"),
("vit", "ViTForMaskedImageModeling"),
]
)
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
[
("imagegpt", "ImageGPTForCausalImageModeling"),
]
)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("beit", "BeitForImageClassification"),
("bit", "BitForImageClassification"),
("clip", "CLIPForImageClassification"),
("convnext", "ConvNextForImageClassification"),
("convnextv2", "ConvNextV2ForImageClassification"),
("cvt", "CvtForImageClassification"),
("data2vec-vision", "Data2VecVisionForImageClassification"),
(
"deit",
("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher"),
),
("dinat", "DinatForImageClassification"),
("dinov2", "Dinov2ForImageClassification"),
(
"efficientformer",
(
"EfficientFormerForImageClassification",
"EfficientFormerForImageClassificationWithTeacher",
),
),
("efficientnet", "EfficientNetForImageClassification"),
("focalnet", "FocalNetForImageClassification"),
("imagegpt", "ImageGPTForImageClassification"),
(
"levit",
("LevitForImageClassification", "LevitForImageClassificationWithTeacher"),
),
("mobilenet_v1", "MobileNetV1ForImageClassification"),
("mobilenet_v2", "MobileNetV2ForImageClassification"),
("mobilevit", "MobileViTForImageClassification"),
("mobilevitv2", "MobileViTV2ForImageClassification"),
("nat", "NatForImageClassification"),
(
"perceiver",
(
"PerceiverForImageClassificationLearned",
"PerceiverForImageClassificationFourier",
"PerceiverForImageClassificationConvProcessing",
),
),
("poolformer", "PoolFormerForImageClassification"),
("pvt", "PvtForImageClassification"),
("pvt_v2", "PvtV2ForImageClassification"),
("regnet", "RegNetForImageClassification"),
("resnet", "ResNetForImageClassification"),
("segformer", "SegformerForImageClassification"),
("siglip", "SiglipForImageClassification"),
("swiftformer", "SwiftFormerForImageClassification"),
("swin", "SwinForImageClassification"),
("swinv2", "Swinv2ForImageClassification"),
("van", "VanForImageClassification"),
("vit", "ViTForImageClassification"),
("vit_hybrid", "ViTHybridForImageClassification"),
("vit_msn", "ViTMSNForImageClassification"),
]
)
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
[
("detr", "DetrForSegmentation"),
]
)
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
[
("beit", "BeitForSemanticSegmentation"),
("data2vec-vision", "Data2VecVisionForSemanticSegmentation"),
("dpt", "DPTForSemanticSegmentation"),
("mobilenet_v2", "MobileNetV2ForSemanticSegmentation"),
("mobilevit", "MobileViTForSemanticSegmentation"),
("mobilevitv2", "MobileViTV2ForSemanticSegmentation"),
("segformer", "SegformerForSemanticSegmentation"),
("upernet", "UperNetForSemanticSegmentation"),
]
)
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
[
("maskformer", "MaskFormerForInstanceSegmentation"),
]
)
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = OrderedDict(
[
("detr", "DetrForSegmentation"),
("mask2former", "Mask2FormerForUniversalSegmentation"),
("maskformer", "MaskFormerForInstanceSegmentation"),
("oneformer", "OneFormerForUniversalSegmentation"),
]
)
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("timesformer", "TimesformerForVideoClassification"),
("videomae", "VideoMAEForVideoClassification"),
("vivit", "VivitForVideoClassification"),
]
)
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("blip", "BlipForConditionalGeneration"),
("blip-2", "Blip2ForConditionalGeneration"),
("git", "GitForCausalLM"),
("instructblip", "InstructBlipForConditionalGeneration"),
("kosmos-2", "Kosmos2ForConditionalGeneration"),
("llava", "LlavaForConditionalGeneration"),
("llava_next", "LlavaNextForConditionalGeneration"),
("pix2struct", "Pix2StructForConditionalGeneration"),
("vipllava", "VipLlavaForConditionalGeneration"),
("vision-encoder-decoder", "VisionEncoderDecoderModel"),
]
)
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
("albert", "AlbertForMaskedLM"),
("bart", "BartForConditionalGeneration"),
("bert", "BertForMaskedLM"),
("big_bird", "BigBirdForMaskedLM"),
("camembert", "CamembertForMaskedLM"),
("convbert", "ConvBertForMaskedLM"),
("data2vec-text", "Data2VecTextForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
("deberta-v2", "DebertaV2ForMaskedLM"),
("distilbert", "DistilBertForMaskedLM"),
("electra", "ElectraForMaskedLM"),
("ernie", "ErnieForMaskedLM"),
("esm", "EsmForMaskedLM"),
("flaubert", "FlaubertWithLMHeadModel"),
("fnet", "FNetForMaskedLM"),
("funnel", "FunnelForMaskedLM"),
("ibert", "IBertForMaskedLM"),
("layoutlm", "LayoutLMForMaskedLM"),
("longformer", "LongformerForMaskedLM"),
("luke", "LukeForMaskedLM"),
("mbart", "MBartForConditionalGeneration"),
("mega", "MegaForMaskedLM"),
("megatron-bert", "MegatronBertForMaskedLM"),
("mobilebert", "MobileBertForMaskedLM"),
("mpnet", "MPNetForMaskedLM"),
("mra", "MraForMaskedLM"),
("mvp", "MvpForConditionalGeneration"),
("nezha", "NezhaForMaskedLM"),
("nystromformer", "NystromformerForMaskedLM"),
("perceiver", "PerceiverForMaskedLM"),
("qdqbert", "QDQBertForMaskedLM"),
("reformer", "ReformerForMaskedLM"),
("rembert", "RemBertForMaskedLM"),
("roberta", "RobertaForMaskedLM"),
("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
("roc_bert", "RoCBertForMaskedLM"),
("roformer", "RoFormerForMaskedLM"),
("squeezebert", "SqueezeBertForMaskedLM"),
("tapas", "TapasForMaskedLM"),
("wav2vec2", "Wav2Vec2ForMaskedLM"),
("xlm", "XLMWithLMHeadModel"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
("xmod", "XmodForMaskedLM"),
("yoso", "YosoForMaskedLM"),
]
)
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
[
("conditional_detr", "ConditionalDetrForObjectDetection"),
("deformable_detr", "DeformableDetrForObjectDetection"),
("deta", "DetaForObjectDetection"),
("detr", "DetrForObjectDetection"),
("table-transformer", "TableTransformerForObjectDetection"),
("yolos", "YolosForObjectDetection"),
]
)
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
[
("owlv2", "Owlv2ForObjectDetection"),
("owlvit", "OwlViTForObjectDetection"),
]
)
MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
[
("depth_anything", "DepthAnythingForDepthEstimation"),
("dpt", "DPTForDepthEstimation"),
("glpn", "GLPNForDepthEstimation"),
]
)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
("bart", "BartForConditionalGeneration"),
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
("blenderbot", "BlenderbotForConditionalGeneration"),
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
("encoder-decoder", "EncoderDecoderModel"),
("fsmt", "FSMTForConditionalGeneration"),
("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
("led", "LEDForConditionalGeneration"),
("longt5", "LongT5ForConditionalGeneration"),
("m2m_100", "M2M100ForConditionalGeneration"),
("marian", "MarianMTModel"),
("mbart", "MBartForConditionalGeneration"),
("mt5", "MT5ForConditionalGeneration"),
("mvp", "MvpForConditionalGeneration"),
("nllb-moe", "NllbMoeForConditionalGeneration"),
("pegasus", "PegasusForConditionalGeneration"),
("pegasus_x", "PegasusXForConditionalGeneration"),
("plbart", "PLBartForConditionalGeneration"),
("prophetnet", "ProphetNetForConditionalGeneration"),
("seamless_m4t", "SeamlessM4TForTextToText"),
("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"),
("switch_transformers", "SwitchTransformersForConditionalGeneration"),
("t5", "T5ForConditionalGeneration"),
("umt5", "UMT5ForConditionalGeneration"),
("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
]
)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("pop2piano", "Pop2PianoForConditionalGeneration"),
("seamless_m4t", "SeamlessM4TForSpeechToText"),
("seamless_m4t_v2", "SeamlessM4Tv2ForSpeechToText"),
("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
("speech_to_text", "Speech2TextForConditionalGeneration"),
("speecht5", "SpeechT5ForSpeechToText"),
("whisper", "WhisperForConditionalGeneration"),
]
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
)
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
("tapas", "TapasForQuestionAnswering"),
]
)
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
("blip", "BlipForQuestionAnswering"),
("blip-2", "Blip2ForConditionalGeneration"),
("vilt", "ViltForQuestionAnswering"),
]
)
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
("layoutlm", "LayoutLMForQuestionAnswering"),
("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
]
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
("albert", "AlbertForMultipleChoice"),
("bert", "BertForMultipleChoice"),
("big_bird", "BigBirdForMultipleChoice"),
("camembert", "CamembertForMultipleChoice"),
("canine", "CanineForMultipleChoice"),
("convbert", "ConvBertForMultipleChoice"),
("data2vec-text", "Data2VecTextForMultipleChoice"),
("deberta-v2", "DebertaV2ForMultipleChoice"),
("distilbert", "DistilBertForMultipleChoice"),
("electra", "ElectraForMultipleChoice"),
("ernie", "ErnieForMultipleChoice"),
("ernie_m", "ErnieMForMultipleChoice"),
("flaubert", "FlaubertForMultipleChoice"),
("fnet", "FNetForMultipleChoice"),
("funnel", "FunnelForMultipleChoice"),
("ibert", "IBertForMultipleChoice"),
("longformer", "LongformerForMultipleChoice"),
("luke", "LukeForMultipleChoice"),
("mega", "MegaForMultipleChoice"),
("megatron-bert", "MegatronBertForMultipleChoice"),
("mobilebert", "MobileBertForMultipleChoice"),
("mpnet", "MPNetForMultipleChoice"),
("mra", "MraForMultipleChoice"),
("nezha", "NezhaForMultipleChoice"),
("nystromformer", "NystromformerForMultipleChoice"),
("qdqbert", "QDQBertForMultipleChoice"),
("rembert", "RemBertForMultipleChoice"),
("roberta", "RobertaForMultipleChoice"),
("roberta-prelayernorm", "RobertaPreLayerNormForMultipleChoice"),
("roc_bert", "RoCBertForMultipleChoice"),
("roformer", "RoFormerForMultipleChoice"),
("squeezebert", "SqueezeBertForMultipleChoice"),
("xlm", "XLMForMultipleChoice"),
("xlm-roberta", "XLMRobertaForMultipleChoice"),
("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"),
("xlnet", "XLNetForMultipleChoice"),
("xmod", "XmodForMultipleChoice"),
("yoso", "YosoForMultipleChoice"),
]
)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
[
("bert", "BertForNextSentencePrediction"),
("ernie", "ErnieForNextSentencePrediction"),
("fnet", "FNetForNextSentencePrediction"),
("megatron-bert", "MegatronBertForNextSentencePrediction"),
("mobilebert", "MobileBertForNextSentencePrediction"),
("nezha", "NezhaForNextSentencePrediction"),
("qdqbert", "QDQBertForNextSentencePrediction"),
]
)
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("audio-spectrogram-transformer", "ASTForAudioClassification"),
("data2vec-audio", "Data2VecAudioForSequenceClassification"),
("hubert", "HubertForSequenceClassification"),
("sew", "SEWForSequenceClassification"),
("sew-d", "SEWDForSequenceClassification"),
("unispeech", "UniSpeechForSequenceClassification"),
("unispeech-sat", "UniSpeechSatForSequenceClassification"),
("wav2vec2", "Wav2Vec2ForSequenceClassification"),
("wav2vec2-bert", "Wav2Vec2BertForSequenceClassification"),
("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"),
("wavlm", "WavLMForSequenceClassification"),
("whisper", "WhisperForAudioClassification"),
]
)
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
[
("data2vec-audio", "Data2VecAudioForCTC"),
("hubert", "HubertForCTC"),
("mctct", "MCTCTForCTC"),
("sew", "SEWForCTC"),
("sew-d", "SEWDForCTC"),
("unispeech", "UniSpeechForCTC"),
("unispeech-sat", "UniSpeechSatForCTC"),
("wav2vec2", "Wav2Vec2ForCTC"),
("wav2vec2-bert", "Wav2Vec2BertForCTC"),
("wav2vec2-conformer", "Wav2Vec2ConformerForCTC"),
("wavlm", "WavLMForCTC"),
]
)
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("data2vec-audio", "Data2VecAudioForAudioFrameClassification"),
("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
("wav2vec2-bert", "Wav2Vec2BertForAudioFrameClassification"),
("wav2vec2-conformer", "Wav2Vec2ConformerForAudioFrameClassification"),
("wavlm", "WavLMForAudioFrameClassification"),
]
)
MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
[
("data2vec-audio", "Data2VecAudioForXVector"),
("unispeech-sat", "UniSpeechSatForXVector"),
("wav2vec2", "Wav2Vec2ForXVector"),
("wav2vec2-bert", "Wav2Vec2BertForXVector"),
("wav2vec2-conformer", "Wav2Vec2ConformerForXVector"),
("wavlm", "WavLMForXVector"),
]
)
MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = OrderedDict(
[
("fastspeech2_conformer", "FastSpeech2ConformerModel"),
("speecht5", "SpeechT5ForTextToSpeech"),
]
)
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
[
("bark", "BarkModel"),
("fastspeech2_conformer", "FastSpeech2ConformerWithHifiGan"),
("musicgen", "MusicgenForConditionalGeneration"),
("musicgen_melody", "MusicgenMelodyForConditionalGeneration"),
("seamless_m4t", "SeamlessM4TForTextToSpeech"),
("seamless_m4t_v2", "SeamlessM4Tv2ForTextToSpeech"),
("vits", "VitsModel"),
]
)
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("align", "AlignModel"),
("altclip", "AltCLIPModel"),
("blip", "BlipModel"),
("chinese_clip", "ChineseCLIPModel"),
("clip", "CLIPModel"),
("clipseg", "CLIPSegModel"),
("siglip", "SiglipModel"),
]
)
MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
[
("beit", "BeitBackbone"),
("bit", "BitBackbone"),
("convnext", "ConvNextBackbone"),
("convnextv2", "ConvNextV2Backbone"),
("dinat", "DinatBackbone"),
("dinov2", "Dinov2Backbone"),
("focalnet", "FocalNetBackbone"),
("maskformer-swin", "MaskFormerSwinBackbone"),
("nat", "NatBackbone"),
("pvt_v2", "PvtV2Backbone"),
("resnet", "ResNetBackbone"),
("swin", "SwinBackbone"),
("swinv2", "Swinv2Backbone"),
("timm_backbone", "TimmBackbone"),
("vitdet", "VitDetBackbone"),
]
)
MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
[
("sam", "SamModel"),
]
)
MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES = OrderedDict(
[
("superpoint", "SuperPointForKeypointDetection"),
]
)
MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
[
("albert", "AlbertModel"),
("bert", "BertModel"),
("big_bird", "BigBirdModel"),
("data2vec-text", "Data2VecTextModel"),
("deberta", "DebertaModel"),
("deberta-v2", "DebertaV2Model"),
("distilbert", "DistilBertModel"),
("electra", "ElectraModel"),
("flaubert", "FlaubertModel"),
("ibert", "IBertModel"),
("longformer", "LongformerModel"),
("mobilebert", "MobileBertModel"),
("mt5", "MT5EncoderModel"),
("nystromformer", "NystromformerModel"),
("reformer", "ReformerModel"),
("rembert", "RemBertModel"),
("roberta", "RobertaModel"),
("roberta-prelayernorm", "RobertaPreLayerNormModel"),
("roc_bert", "RoCBertModel"),
("roformer", "RoFormerModel"),
("squeezebert", "SqueezeBertModel"),
("t5", "T5EncoderModel"),
("umt5", "UMT5EncoderModel"),
("xlm", "XLMModel"),
("xlm-roberta", "XLMRobertaModel"),
("xlm-roberta-xl", "XLMRobertaXLModel"),
]
)
MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("patchtsmixer", "PatchTSMixerForTimeSeriesClassification"),
("patchtst", "PatchTSTForClassification"),
]
)
MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES = OrderedDict(
[
("patchtsmixer", "PatchTSMixerForRegression"),
("patchtst", "PatchTSTForRegression"),
]
)
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict(
[
("swin2sr", "Swin2SRForImageSuperResolution"),
]
)
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES)
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES)
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES)
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES)
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES)
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES)
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES)
MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES)
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES)
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
MODEL_FOR_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_MAPPING_NAMES)
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES)
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES)
MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES)
MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES)
MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES
)
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES)
MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES)
MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
MODEL_FOR_KEYPOINT_DETECTION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES
)
MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)
MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES
)
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES)
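These `_LazyAutoMapping` instances behave like dictionaries keyed by config classes, importing the concrete model class only on first access. A minimal sketch, assuming transformers is installed:
from transformers import BertConfig
from transformers.models.auto.modeling_auto import MODEL_FOR_MASKED_LM_MAPPING

# Indexing by the config class triggers the lazy import of the mapped model class.
model_cls = MODEL_FOR_MASKED_LM_MAPPING[BertConfig]
print(model_cls.__name__)  # expected: "BertForMaskedLM"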
class _AutoModelWithLMHead(_BaseAutoModelClass):
_model_mapping = MODEL_WITH_LM_HEAD_MAPPING
_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling")
class AutoModelForCausalLM(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
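Each generated class follows the same pattern: `_model_mapping` selects the task mapping and `auto_class_update` fills in the docstrings and example checkpoints. A minimal usage sketch, assuming the `gpt2` checkpoint is cached or downloadable:
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_config(config)        # architecture from config, randomly initialized weights
# model = AutoModelForCausalLM.from_pretrained("gpt2")  # or load the pretrained weights instead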
class AutoModelForMaskedLM(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_MASKED_LM_MAPPING
AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling")
class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
AutoModelForSeq2SeqLM = auto_class_update(
AutoModelForSeq2SeqLM,
head_doc="sequence-to-sequence language modeling",
checkpoint_for_example="google-t5/t5-base",
)
class AutoModelForSequenceClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
AutoModelForSequenceClassification = auto_class_update(AutoModelForSequenceClassification, head_doc="sequence classification")
class AutoModelForQuestionAnswering(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering")
class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
AutoModelForTableQuestionAnswering = auto_class_update(
AutoModelForTableQuestionAnswering,
head_doc="table question answering",
checkpoint_for_example="google/tapas-base-finetuned-wtq",
)
class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING
AutoModelForVisualQuestionAnswering = auto_class_update(
AutoModelForVisualQuestionAnswering,
head_doc="visual question answering",
checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa",
)
class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING
AutoModelForDocumentQuestionAnswering = auto_class_update(
AutoModelForDocumentQuestionAnswering,
head_doc="document question answering",
checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
)
class AutoModelForTokenClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification")
class AutoModelForMultipleChoice(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING
AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice")
class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
AutoModelForNextSentencePrediction = auto_class_update(
AutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)
class AutoModelForImageClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
class AutoModelForZeroShotImageClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
AutoModelForZeroShotImageClassification = auto_class_update(
AutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
)
class AutoModelForImageSegmentation(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
AutoModelForImageSegmentation = auto_class_update(AutoModelForImageSegmentation, head_doc="image segmentation")
class AutoModelForSemanticSegmentation(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
AutoModelForSemanticSegmentation = auto_class_update(
AutoModelForSemanticSegmentation, head_doc="semantic segmentation"
)
class AutoModelForUniversalSegmentation(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING
AutoModelForUniversalSegmentation = auto_class_update(
AutoModelForUniversalSegmentation, head_doc="universal image segmentation"
)
class AutoModelForInstanceSegmentation(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING
AutoModelForInstanceSegmentation = auto_class_update(
AutoModelForInstanceSegmentation, head_doc="instance segmentation"
)
class AutoModelForObjectDetection(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING
AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")
class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
AutoModelForZeroShotObjectDetection = auto_class_update(
AutoModelForZeroShotObjectDetection, head_doc="zero-shot object detection"
)
class AutoModelForDepthEstimation(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING
AutoModelForDepthEstimation = auto_class_update(AutoModelForDepthEstimation, head_doc="depth estimation")
class AutoModelForVideoClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
AutoModelForVideoClassification = auto_class_update(AutoModelForVideoClassification, head_doc="video classification")
class AutoModelForVision2Seq(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING
AutoModelForVision2Seq = auto_class_update(AutoModelForVision2Seq, head_doc="vision-to-text modeling")
class AutoModelForAudioClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification")
class AutoModelForCTC(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_CTC_MAPPING
AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification")
class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
AutoModelForSpeechSeq2Seq = auto_class_update(
AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)
class AutoModelForAudioFrameClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING
AutoModelForAudioFrameClassification = auto_class_update(
AutoModelForAudioFrameClassification, head_doc="audio frame (token) classification"
)
class AutoModelForAudioXVector(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING
class AutoModelForTextToSpectrogram(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING
class AutoModelForTextToWaveform(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING
class AutoBackbone(_BaseAutoBackboneClass):
_model_mapping = MODEL_FOR_BACKBONE_MAPPING
AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector")
class AutoModelForMaskedImageModeling(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING
AutoModelForMaskedImageModeling = auto_class_update(AutoModelForMaskedImageModeling, head_doc="masked image modeling")
class AutoModelWithLMHead(_AutoModelWithLMHead):
@classmethod
def from_config(cls, config):
warnings.warn(
"The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
"`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
"`AutoModelForSeq2SeqLM` for encoder-decoder models.",
FutureWarning,
)
return super().from_config(config)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
warnings.warn(
"The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
"`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
"`AutoModelForSeq2SeqLM` for encoder-decoder models.",
FutureWarning,
)
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
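Both overridden methods only add the `FutureWarning` before delegating to the parent class, so the deprecated entry point keeps working. A small sketch, assuming a cached `gpt2` checkpoint:
import warnings
from transformers import AutoModelWithLMHead

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model = AutoModelWithLMHead.from_pretrained("gpt2")

# The deprecation warning is recorded alongside the normally loaded model.
print(any(issubclass(w.category, FutureWarning) for w in caught))  # expected: True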
.\models\auto\modeling_flax_auto.py
from collections import OrderedDict
from ...utils import logging
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES
logger = logging.get_logger(__name__)
FLAX_MODEL_MAPPING_NAMES = OrderedDict(
[
("albert", "FlaxAlbertModel"),
("bart", "FlaxBartModel"),
("beit", "FlaxBeitModel"),
("bert", "FlaxBertModel"),
("big_bird", "FlaxBigBirdModel"),
("blenderbot", "FlaxBlenderbotModel"),
("blenderbot-small", "FlaxBlenderbotSmallModel"),
("bloom", "FlaxBloomModel"),
("clip", "FlaxCLIPModel"),
("distilbert", "FlaxDistilBertModel"),
("electra", "FlaxElectraModel"),
("gemma", "FlaxGemmaModel"),
("gpt-sw3", "FlaxGPT2Model"),
("gpt2", "FlaxGPT2Model"),
("gpt_neo", "FlaxGPTNeoModel"),
("gptj", "FlaxGPTJModel"),
("llama", "FlaxLlamaModel"),
("longt5", "FlaxLongT5Model"),
("marian", "FlaxMarianModel"),
("mbart", "FlaxMBartModel"),
("mistral", "FlaxMistralModel"),
("mt5", "FlaxMT5Model"),
("opt", "FlaxOPTModel"),
("pegasus", "FlaxPegasusModel"),
("regnet", "FlaxRegNetModel"),
("resnet", "FlaxResNetModel"),
("roberta", "FlaxRobertaModel"),
("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"),
("roformer", "FlaxRoFormerModel"),
("t5", "FlaxT5Model"),
("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
("vit", "FlaxViTModel"),
("wav2vec2", "FlaxWav2Vec2Model"),
("whisper", "FlaxWhisperModel"),
("xglm", "FlaxXGLMModel"),
("xlm-roberta", "FlaxXLMRobertaModel"),
]
)
FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
[
("albert", "FlaxAlbertForPreTraining"),
("bart", "FlaxBartForConditionalGeneration"),
("bert", "FlaxBertForPreTraining"),
("big_bird", "FlaxBigBirdForPreTraining"),
("electra", "FlaxElectraForPreTraining"),
("longt5", "FlaxLongT5ForConditionalGeneration"),
("mbart", "FlaxMBartForConditionalGeneration"),
("mt5", "FlaxMT5ForConditionalGeneration"),
("roberta", "FlaxRobertaForMaskedLM"),
("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
("roformer", "FlaxRoFormerForMaskedLM"),
("t5", "FlaxT5ForConditionalGeneration"),
("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
("whisper", "FlaxWhisperForConditionalGeneration"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
]
)
FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
("albert", "FlaxAlbertForMaskedLM"),
("bart", "FlaxBartForConditionalGeneration"),
("bert", "FlaxBertForMaskedLM"),
("big_bird", "FlaxBigBirdForMaskedLM"),
("distilbert", "FlaxDistilBertForMaskedLM"),
("electra", "FlaxElectraForMaskedLM"),
("mbart", "FlaxMBartForConditionalGeneration"),
("roberta", "FlaxRobertaForMaskedLM"),
("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
("roformer", "FlaxRoFormerForMaskedLM"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
]
)
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
("bart", "FlaxBartForConditionalGeneration"),
("blenderbot", "FlaxBlenderbotForConditionalGeneration"),
("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"),
("encoder-decoder", "FlaxEncoderDecoderModel"),
("longt5", "FlaxLongT5ForConditionalGeneration"),
("marian", "FlaxMarianMTModel"),
("mbart", "FlaxMBartForConditionalGeneration"),
("mt5", "FlaxMT5ForConditionalGeneration"),
("pegasus", "FlaxPegasusForConditionalGeneration"),
("t5", "FlaxT5ForConditionalGeneration"),
]
)
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("beit", "FlaxBeitForImageClassification"),
("regnet", "FlaxRegNetForImageClassification"),
("resnet", "FlaxResNetForImageClassification"),
("vit", "FlaxViTForImageClassification"),
]
)
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("vision-encoder-decoder", "FlaxVisionEncoderDecoderModel"),
]
)
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
("bart", "FlaxBartForCausalLM"),
("bert", "FlaxBertForCausalLM"),
("big_bird", "FlaxBigBirdForCausalLM"),
("bloom", "FlaxBloomForCausalLM"),
("electra", "FlaxElectraForCausalLM"),
("gemma", "FlaxGemmaForCausalLM"),
("gpt-sw3", "FlaxGPT2LMHeadModel"),
("gpt2", "FlaxGPT2LMHeadModel"),
("gpt_neo", "FlaxGPTNeoForCausalLM"),
("gptj", "FlaxGPTJForCausalLM"),
("llama", "FlaxLlamaForCausalLM"),
("mistral", "FlaxMistralForCausalLM"),
("opt", "FlaxOPTForCausalLM"),
("roberta", "FlaxRobertaForCausalLM"),
("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"),
("xglm", "FlaxXGLMForCausalLM"),
("xlm-roberta", "FlaxXLMRobertaForCausalLM"),
]
)
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("albert", "FlaxAlbertForSequenceClassification"),
("bart", "FlaxBartForSequenceClassification"),
("bert", "FlaxBertForSequenceClassification"),
("big_bird", "FlaxBigBirdForSequenceClassification"),
("distilbert", "FlaxDistilBertForSequenceClassification"),
("electra", "FlaxElectraForSequenceClassification"),
("mbart", "FlaxMBartForSequenceClassification"),
("roberta", "FlaxRobertaForSequenceClassification"),
("roberta-prelayernorm", "FlaxRobertaPreLayerNormForSequenceClassification"),
("roformer", "FlaxRoFormerForSequenceClassification"),
("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
]
)
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
("albert", "FlaxAlbertForQuestionAnswering"),
("bart", "FlaxBartForQuestionAnswering"),
("bert", "FlaxBertForQuestionAnswering"),
("big_bird", "FlaxBigBirdForQuestionAnswering"),
("distilbert", "FlaxDistilBertForQuestionAnswering"),
("electra", "FlaxElectraForQuestionAnswering"),
("mbart", "FlaxMBartForQuestionAnswering"),
("roberta", "FlaxRobertaForQuestionAnswering"),
("roberta-prelayernorm", "FlaxRobertaPreLayerNormForQuestionAnswering"),
("roformer", "FlaxRoFormerForQuestionAnswering"),
("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
]
)
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("albert", "FlaxAlbertForTokenClassification"),
("bert", "FlaxBertForTokenClassification"),
("big_bird", "FlaxBigBirdForTokenClassification"),
("distilbert", "FlaxDistilBertForTokenClassification"),
("electra", "FlaxElectraForTokenClassification"),
("roberta", "FlaxRobertaForTokenClassification"),
("roberta-prelayernorm", "FlaxRobertaPreLayerNormForTokenClassification"),
("roformer", "FlaxRoFormerForTokenClassification"),
("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
]
)
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
("albert", "FlaxAlbertForMultipleChoice"),
("bert", "FlaxBertForMultipleChoice"),
("big_bird", "FlaxBigBirdForMultipleChoice"),
("distilbert", "FlaxDistilBertForMultipleChoice"),
("electra", "FlaxElectraForMultipleChoice"),
("roberta", "FlaxRobertaForMultipleChoice"),
("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMultipleChoice"),
("roformer", "FlaxRoFormerForMultipleChoice"),
("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
]
)
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
[
("bert", "FlaxBertForNextSentencePrediction"),
]
)
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"),
("whisper", "FlaxWhisperForConditionalGeneration"),
]
)
FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("whisper", "FlaxWhisperForAudioClassification"),
]
)
FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
FLAX_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES)
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
)
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
)
FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
class FlaxAutoModel(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_MAPPING
FlaxAutoModel = auto_class_update(FlaxAutoModel)
class FlaxAutoModelForPreTraining(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_PRETRAINING_MAPPING
FlaxAutoModelForPreTraining = auto_class_update(FlaxAutoModelForPreTraining, head_doc="pretraining")
class FlaxAutoModelForCausalLM(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING
FlaxAutoModelForCausalLM = auto_class_update(FlaxAutoModelForCausalLM, head_doc="causal language modeling")
class FlaxAutoModelForMaskedLM(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_MASKED_LM_MAPPING
FlaxAutoModelForMaskedLM = auto_class_update(FlaxAutoModelForMaskedLM, head_doc="masked language modeling")
class FlaxAutoModelForSeq2SeqLM(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
FlaxAutoModelForSeq2SeqLM = auto_class_update(
FlaxAutoModelForSeq2SeqLM,
head_doc="sequence-to-sequence language modeling",
checkpoint_for_example="google-t5/t5-base",
)
class FlaxAutoModelForSequenceClassification(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
FlaxAutoModelForSequenceClassification = auto_class_update(
FlaxAutoModelForSequenceClassification, head_doc="sequence classification"
)
class FlaxAutoModelForQuestionAnswering(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING
FlaxAutoModelForQuestionAnswering = auto_class_update(FlaxAutoModelForQuestionAnswering, head_doc="question answering")
class FlaxAutoModelForTokenClassification(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
FlaxAutoModelForTokenClassification = auto_class_update(
FlaxAutoModelForTokenClassification, head_doc="token classification"
)
class FlaxAutoModelForMultipleChoice(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING
FlaxAutoModelForMultipleChoice = auto_class_update(FlaxAutoModelForMultipleChoice, head_doc="multiple choice")
class FlaxAutoModelForNextSentencePrediction(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
FlaxAutoModelForNextSentencePrediction = auto_class_update(
FlaxAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)
class FlaxAutoModelForImageClassification(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
FlaxAutoModelForImageClassification = auto_class_update(
FlaxAutoModelForImageClassification, head_doc="image classification"
)
class FlaxAutoModelForVision2Seq(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING
FlaxAutoModelForVision2Seq = auto_class_update(FlaxAutoModelForVision2Seq, head_doc="vision-to-text modeling")
class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
FlaxAutoModelForSpeechSeq2Seq = auto_class_update(
FlaxAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)
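The Flax variants mirror the PyTorch auto classes one-to-one. A minimal sketch, assuming jax/flax are installed; `from_pt=True` converts the PyTorch weights if the checkpoint ships no Flax weights:
from transformers import FlaxAutoModelForSequenceClassification

model = FlaxAutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", from_pt=True  # convert from PyTorch weights when needed
)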
.\models\auto\modeling_tf_auto.py
import warnings
from collections import OrderedDict
from ...utils import logging
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES
logger = logging.get_logger(__name__)
TF_MODEL_MAPPING_NAMES = OrderedDict(
[
("albert", "TFAlbertModel"),
("bart", "TFBartModel"),
("bert", "TFBertModel"),
("blenderbot", "TFBlenderbotModel"),
("blenderbot-small", "TFBlenderbotSmallModel"),
("blip", "TFBlipModel"),
("camembert", "TFCamembertModel"),
("clip", "TFCLIPModel"),
("convbert", "TFConvBertModel"),
("convnext", "TFConvNextModel"),
("convnextv2", "TFConvNextV2Model"),
("ctrl", "TFCTRLModel"),
("cvt", "TFCvtModel"),
("data2vec-vision", "TFData2VecVisionModel"),
("deberta", "TFDebertaModel"),
("deberta-v2", "TFDebertaV2Model"),
("deit", "TFDeiTModel"),
("distilbert", "TFDistilBertModel"),
("dpr", "TFDPRQuestionEncoder"),
("efficientformer", "TFEfficientFormerModel"),
("electra", "TFElectraModel"),
("esm", "TFEsmModel"),
("flaubert", "TFFlaubertModel"),
("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
("gpt-sw3", "TFGPT2Model"),
("gpt2", "TFGPT2Model"),
("gptj", "TFGPTJModel"),
("groupvit", "TFGroupViTModel"),
("hubert", "TFHubertModel"),
("layoutlm", "TFLayoutLMModel"),
("layoutlmv3", "TFLayoutLMv3Model"),
("led", "TFLEDModel"),
("longformer", "TFLongformerModel"),
("lxmert", "TFLxmertModel"),
("marian", "TFMarianModel"),
("mbart", "TFMBartModel"),
("mobilebert", "TFMobileBertModel"),
("mobilevit", "TFMobileViTModel"),
("mpnet", "TFMPNetModel"),
("mt5", "TFMT5Model"),
("openai-gpt", "TFOpenAIGPTModel"),
("opt", "TFOPTModel"),
("pegasus", "TFPegasusModel"),
("regnet", "TFRegNetModel"),
("rembert", "TFRemBertModel"),
("resnet", "TFResNetModel"),
("roberta", "TFRobertaModel"),
("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
("roformer", "TFRoFormerModel"),
("sam", "TFSamModel"),
("segformer", "TFSegformerModel"),
("speech_to_text", "TFSpeech2TextModel"),
("swin", "TFSwinModel"),
TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
[
("albert", "TFAlbertForPreTraining"),
("bart", "TFBartForConditionalGeneration"),
("bert", "TFBertForPreTraining"),
("camembert", "TFCamembertForMaskedLM"),
("ctrl", "TFCTRLLMHeadModel"),
("distilbert", "TFDistilBertForMaskedLM"),
("electra", "TFElectraForPreTraining"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("funnel", "TFFunnelForPreTraining"),
("gpt-sw3", "TFGPT2LMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("lxmert", "TFLxmertForPreTraining"),
("mobilebert", "TFMobileBertForPreTraining"),
("mpnet", "TFMPNetForMaskedLM"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("roberta", "TFRobertaForMaskedLM"),
("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
("t5", "TFT5ForConditionalGeneration"),
("tapas", "TFTapasForMaskedLM"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("vit_mae", "TFViTMAEForPreTraining"),
("xlm", "TFXLMWithLMHeadModel"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("xlnet", "TFXLNetLMHeadModel"),
]
)
TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
[
("albert", "TFAlbertForMaskedLM"),
("bart", "TFBartForConditionalGeneration"),
("bert", "TFBertForMaskedLM"),
("camembert", "TFCamembertForMaskedLM"),
("convbert", "TFConvBertForMaskedLM"),
("ctrl", "TFCTRLLMHeadModel"),
("distilbert", "TFDistilBertForMaskedLM"),
("electra", "TFElectraForMaskedLM"),
("esm", "TFEsmForMaskedLM"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("funnel", "TFFunnelForMaskedLM"),
("gpt-sw3", "TFGPT2LMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
("gptj", "TFGPTJForCausalLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("led", "TFLEDForConditionalGeneration"),
("longformer", "TFLongformerForMaskedLM"),
("marian", "TFMarianMTModel"),
("mobilebert", "TFMobileBertForMaskedLM"),
("mpnet", "TFMPNetForMaskedLM"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("rembert", "TFRemBertForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
("roformer", "TFRoFormerForMaskedLM"),
("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
("t5", "TFT5ForConditionalGeneration"),
("tapas", "TFTapasForMaskedLM"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("whisper", "TFWhisperForConditionalGeneration"),
("xlm", "TFXLMWithLMHeadModel"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("xlnet", "TFXLNetLMHeadModel"),
]
)
TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
("bert", "TFBertLMHeadModel"),
("camembert", "TFCamembertForCausalLM"),
("ctrl", "TFCTRLLMHeadModel"),
("gpt-sw3", "TFGPT2LMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
("gptj", "TFGPTJForCausalLM"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("opt", "TFOPTForCausalLM"),
("rembert", "TFRemBertForCausalLM"),
("roberta", "TFRobertaForCausalLM"),
("roberta-prelayernorm", "TFRobertaPreLayerNormForCausalLM"),
("roformer", "TFRoFormerForCausalLM"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xglm", "TFXGLMForCausalLM"),
("xlm", "TFXLMWithLMHeadModel"),
("xlm-roberta", "TFXLMRobertaForCausalLM"),
("xlnet", "TFXLNetLMHeadModel"),
]
)
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
[
("deit", "TFDeiTForMaskedImageModeling"),
("swin", "TFSwinForMaskedImageModeling"),
]
)
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("convnext", "TFConvNextForImageClassification"),
("convnextv2", "TFConvNextV2ForImageClassification"),
("cvt", "TFCvtForImageClassification"),
("data2vec-vision", "TFData2VecVisionForImageClassification"),
("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
(
"efficientformer",
("TFEfficientFormerForImageClassification", "TFEfficientFormerForImageClassificationWithTeacher"),
),
("mobilevit", "TFMobileViTForImageClassification"),
("regnet", "TFRegNetForImageClassification"),
("resnet", "TFResNetForImageClassification"),
("segformer", "TFSegformerForImageClassification"),
("swin", "TFSwinForImageClassification"),
("vit", "TFViTForImageClassification"),
]
)
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("blip", "TFBlipModel"),
("clip", "TFCLIPModel"),
]
)
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
[
("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"),
("mobilevit", "TFMobileViTForSemanticSegmentation"),
("segformer", "TFSegformerForSemanticSegmentation"),
]
)
TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("blip", "TFBlipForConditionalGeneration"),
("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
]
)
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
("albert", "TFAlbertForMaskedLM"),
("bert", "TFBertForMaskedLM"),
("camembert", "TFCamembertForMaskedLM"),
("convbert", "TFConvBertForMaskedLM"),
("deberta", "TFDebertaForMaskedLM"),
("deberta-v2", "TFDebertaV2ForMaskedLM"),
("distilbert", "TFDistilBertForMaskedLM"),
("electra", "TFElectraForMaskedLM"),
("esm", "TFEsmForMaskedLM"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("funnel", "TFFunnelForMaskedLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("longformer", "TFLongformerForMaskedLM"),
("mobilebert", "TFMobileBertForMaskedLM"),
("mpnet", "TFMPNetForMaskedLM"),
("rembert", "TFRemBertForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
("roformer", "TFRoFormerForMaskedLM"),
("tapas", "TFTapasForMaskedLM"),
("xlm", "TFXLMWithLMHeadModel"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
]
)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
("bart", "TFBartForConditionalGeneration"),
("blenderbot", "TFBlenderbotForConditionalGeneration"),
("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
("encoder-decoder", "TFEncoderDecoderModel"),
("led", "TFLEDForConditionalGeneration"),
("marian", "TFMarianMTModel"),
("mbart", "TFMBartForConditionalGeneration"),
("mt5", "TFMT5ForConditionalGeneration"),
("pegasus", "TFPegasusForConditionalGeneration"),
("t5", "TFT5ForConditionalGeneration"),
]
)
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
("whisper", "TFWhisperForConditionalGeneration"),
]
)
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("albert", "TFAlbertForSequenceClassification"),
("bart", "TFBartForSequenceClassification"),
("bert", "TFBertForSequenceClassification"),
("camembert", "TFCamembertForSequenceClassification"),
("convbert", "TFConvBertForSequenceClassification"),
("ctrl", "TFCTRLForSequenceClassification"),
("deberta", "TFDebertaForSequenceClassification"),
("deberta-v2", "TFDebertaV2ForSequenceClassification"),
("distilbert", "TFDistilBertForSequenceClassification"),
("electra", "TFElectraForSequenceClassification"),
("esm", "TFEsmForSequenceClassification"),
("flaubert", "TFFlaubertForSequenceClassification"),
("funnel", "TFFunnelForSequenceClassification"),
("gpt-sw3", "TFGPT2ForSequenceClassification"),
("gpt2", "TFGPT2ForSequenceClassification"),
("gptj", "TFGPTJForSequenceClassification"),
("layoutlm", "TFLayoutLMForSequenceClassification"),
("layoutlmv3", "TFLayoutLMv3ForSequenceClassification"),
("longformer", "TFLongformerForSequenceClassification"),
("mobilebert", "TFMobileBertForSequenceClassification"),
("mpnet", "TFMPNetForSequenceClassification"),
("openai-gpt", "TFOpenAIGPTForSequenceClassification"),
("rembert", "TFRemBertForSequenceClassification"),
("roberta", "TFRobertaForSequenceClassification"),
("roberta-prelayernorm", "TFRobertaPreLayerNormForSequenceClassification"),
("roformer", "TFRoFormerForSequenceClassification"),
("tapas", "TFTapasForSequenceClassification"),
("transfo-xl", "TFTransfoXLForSequenceClassification"),
("xlm", "TFXLMForSequenceClassification"),
("xlm-roberta", "TFXLMRobertaForSequenceClassification"),
("xlnet", "TFXLNetForSequenceClassification"),
]
)
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
("albert", "TFAlbertForQuestionAnswering"),
("bert", "TFBertForQuestionAnswering"),
("camembert", "TFCamembertForQuestionAnswering"),
("convbert", "TFConvBertForQuestionAnswering"),
("deberta", "TFDebertaForQuestionAnswering"),
("deberta-v2", "TFDebertaV2ForQuestionAnswering"),
("distilbert", "TFDistilBertForQuestionAnswering"),
("electra", "TFElectraForQuestionAnswering"),
("flaubert", "TFFlaubertForQuestionAnsweringSimple"),
("funnel", "TFFunnelForQuestionAnswering"),
("gptj", "TFGPTJForQuestionAnswering"),
("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
("longformer", "TFLongformerForQuestionAnswering"),
("mobilebert", "TFMobileBertForQuestionAnswering"),
("mpnet", "TFMPNetForQuestionAnswering"),
("rembert", "TFRemBertForQuestionAnswering"),
("roberta", "TFRobertaForQuestionAnswering"),
("roberta-prelayernorm", "TFRobertaPreLayerNormForQuestionAnswering"),
("roformer", "TFRoFormerForQuestionAnswering"),
("xlm", "TFXLMForQuestionAnsweringSimple"),
("xlm-roberta", "TFXLMRobertaForQuestionAnswering"),
("xlnet", "TFXLNetForQuestionAnsweringSimple"),
]
)
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([("wav2vec2", "TFWav2Vec2ForSequenceClassification")])
TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
("layoutlm", "TFLayoutLMForQuestionAnswering"),
("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
]
)
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
("tapas", "TFTapasForQuestionAnswering"),
]
)
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("albert", "TFAlbertForTokenClassification"),
("bert", "TFBertForTokenClassification"),
("camembert", "TFCamembertForTokenClassification"),
("convbert", "TFConvBertForTokenClassification"),
("deberta", "TFDebertaForTokenClassification"),
("deberta-v2", "TFDebertaV2ForTokenClassification"),
("distilbert", "TFDistilBertForTokenClassification"),
("electra", "TFElectraForTokenClassification"),
("esm", "TFEsmForTokenClassification"),
("flaubert", "TFFlaubertForTokenClassification"),
("funnel", "TFFunnelForTokenClassification"),
("layoutlm", "TFLayoutLMForTokenClassification"),
("layoutlmv3", "TFLayoutLMv3ForTokenClassification"),
("longformer", "TFLongformerForTokenClassification"),
("mobilebert", "TFMobileBertForTokenClassification"),
("mpnet", "TFMPNetForTokenClassification"),
("rembert", "TFRemBertForTokenClassification"),
("roberta", "TFRobertaForTokenClassification"),
("roberta-prelayernorm", "TFRobertaPreLayerNormForTokenClassification"),
("roformer", "TFRoFormerForTokenClassification"),
("xlm", "TFXLMForTokenClassification"),
("xlm-roberta", "TFXLMRobertaForTokenClassification"),
("xlnet", "TFXLNetForTokenClassification"),
]
)
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
("albert", "TFAlbertForMultipleChoice"),
("bert", "TFBertForMultipleChoice"),
("camembert", "TFCamembertForMultipleChoice"),
("convbert", "TFConvBertForMultipleChoice"),
("deberta-v2", "TFDebertaV2ForMultipleChoice"),
("distilbert", "TFDistilBertForMultipleChoice"),
("electra", "TFElectraForMultipleChoice"),
("flaubert", "TFFlaubertForMultipleChoice"),
("funnel", "TFFunnelForMultipleChoice"),
("longformer", "TFLongformerForMultipleChoice"),
("mobilebert", "TFMobileBertForMultipleChoice"),
("mpnet", "TFMPNetForMultipleChoice"),
("rembert", "TFRemBertForMultipleChoice"),
("roberta", "TFRobertaForMultipleChoice"),
("roberta-prelayernorm", "TFRobertaPreLayerNormForMultipleChoice"),
("roformer", "TFRoFormerForMultipleChoice"),
("xlm", "TFXLMForMultipleChoice"),
("xlm-roberta", "TFXLMRobertaForMultipleChoice"),
("xlnet", "TFXLNetForMultipleChoice"),
]
)
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
[
("bert", "TFBertForNextSentencePrediction"),
("mobilebert", "TFMobileBertForNextSentencePrediction"),
]
)
TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
[
("sam", "TFSamModel"),
]
)
TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
[
("albert", "TFAlbertModel"),
("bert", "TFBertModel"),
("convbert", "TFConvBertModel"),
("deberta", "TFDebertaModel"),
("deberta-v2", "TFDebertaV2Model"),
("distilbert", "TFDistilBertModel"),
("electra", "TFElectraModel"),
("flaubert", "TFFlaubertModel"),
("longformer", "TFLongformerModel"),
("mobilebert", "TFMobileBertModel"),
("mt5", "TFMT5EncoderModel"),
("rembert", "TFRemBertModel"),
("roberta", "TFRobertaModel"),
("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
("roformer", "TFRoFormerModel"),
("t5", "TFT5EncoderModel"),
("xlm", "TFXLMModel"),
("xlm-roberta", "TFXLMRobertaModel"),
]
)
TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES)
TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
)
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
)
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
)
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
)
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
)
TF_MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES
)
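Each of the `*_MAPPING_NAMES` dictionaries above is wrapped in a `_LazyAutoMapping`, which only imports the concrete TF model class the first time a configuration class is looked up. A minimal sketch of that lookup, assuming TensorFlow and the corresponding model modules are installed:
```
# Minimal sketch of a lazy mapping lookup (assumes TensorFlow is installed).
from transformers import BertConfig
from transformers.models.auto.modeling_tf_auto import (
    TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
)

# Indexing with a config class triggers the lazy import and returns the model class.
model_class = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING[BertConfig]
print(model_class.__name__)  # expected: "TFBertForSequenceClassification"
```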
class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING
class TFAutoModelForTextEncoding(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_TEXT_ENCODING_MAPPING
class TFAutoModel(_BaseAutoModelClass):
_model_mapping = TF_MODEL_MAPPING
TFAutoModel = auto_class_update(TFAutoModel)
class TFAutoModelForAudioClassification(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
TFAutoModelForAudioClassification = auto_class_update(
TFAutoModelForAudioClassification, head_doc="audio classification"
)
class TFAutoModelForPreTraining(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING
TFAutoModelForPreTraining = auto_class_update(TFAutoModelForPreTraining, head_doc="pretraining")
class _TFAutoModelWithLMHead(_BaseAutoModelClass):
_model_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING
_TFAutoModelWithLMHead = auto_class_update(_TFAutoModelWithLMHead, head_doc="language modeling")
class TFAutoModelForCausalLM(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING
TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling")
class TFAutoModelForMaskedImageModeling(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING
TFAutoModelForMaskedImageModeling = auto_class_update(
TFAutoModelForMaskedImageModeling, head_doc="masked image modeling"
)
class TFAutoModelForImageClassification(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
TFAutoModelForImageClassification = auto_class_update(
TFAutoModelForImageClassification, head_doc="image classification"
)
class TFAutoModelForZeroShotImageClassification(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
TFAutoModelForZeroShotImageClassification = auto_class_update(
TFAutoModelForZeroShotImageClassification,
head_doc="zero-shot image classification"
)
class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
TFAutoModelForSemanticSegmentation = auto_class_update(
TFAutoModelForSemanticSegmentation, head_doc="semantic segmentation"
)
class TFAutoModelForVision2Seq(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING
TFAutoModelForVision2Seq = auto_class_update(
TFAutoModelForVision2Seq, head_doc="vision-to-text modeling"
)
class TFAutoModelForMaskedLM(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING
TFAutoModelForMaskedLM = auto_class_update(
TFAutoModelForMaskedLM, head_doc="masked language modeling"
)
class TFAutoModelForSeq2SeqLM(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
TFAutoModelForSeq2SeqLM = auto_class_update(
TFAutoModelForSeq2SeqLM,
head_doc="sequence-to-sequence language modeling",
checkpoint_for_example="google-t5/t5-base",
)
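Note that `checkpoint_for_example` only changes the checkpoint shown in the generated docstring; loading works the same as for any other auto class. A hedged usage sketch, assuming TensorFlow is installed and the `google-t5/t5-base` weights are reachable:
```
# Sketch: loading a seq2seq LM through the TF auto class.
from transformers import TFAutoModelForSeq2SeqLM

model = TFAutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
print(type(model).__name__)  # expected: "TFT5ForConditionalGeneration"
```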
class TFAutoModelForSequenceClassification(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
TFAutoModelForSequenceClassification = auto_class_update(
TFAutoModelForSequenceClassification, head_doc="sequence classification"
)
class TFAutoModelForQuestionAnswering(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
TFAutoModelForQuestionAnswering = auto_class_update(
TFAutoModelForQuestionAnswering, head_doc="question answering"
)
class TFAutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING
TFAutoModelForDocumentQuestionAnswering = auto_class_update(
TFAutoModelForDocumentQuestionAnswering,
head_doc="document question answering",
checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
)
class TFAutoModelForTableQuestionAnswering(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
TFAutoModelForTableQuestionAnswering = auto_class_update(
TFAutoModelForTableQuestionAnswering,
head_doc="table question answering",
checkpoint_for_example="google/tapas-base-finetuned-wtq",
)
class TFAutoModelForTokenClassification(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
TFAutoModelForTokenClassification = auto_class_update(
TFAutoModelForTokenClassification, head_doc="token classification"
)
class TFAutoModelForMultipleChoice(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING
TFAutoModelForMultipleChoice = auto_class_update(
TFAutoModelForMultipleChoice, head_doc="multiple choice"
)
class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
TFAutoModelForNextSentencePrediction = auto_class_update(
TFAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)
class TFAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
TFAutoModelForSpeechSeq2Seq = auto_class_update(
TFAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)
class TFAutoModelWithLMHead(_TFAutoModelWithLMHead):
@classmethod
def from_config(cls, config):
warnings.warn(
"The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
" `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
" and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
FutureWarning,
)
return super().from_config(config)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
warnings.warn(
"The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
" `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
" and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
FutureWarning,
)
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
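As the warnings indicate, `TFAutoModelWithLMHead` only emits a `FutureWarning` before delegating to the parent implementation. A migration sketch, assuming TensorFlow is installed:
```
# Sketch: replacing the deprecated TFAutoModelWithLMHead with task-specific auto classes.
from transformers import TFAutoModelForCausalLM, TFAutoModelForMaskedLM, TFAutoModelForSeq2SeqLM

causal_lm = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2")           # decoder-only models
masked_lm = TFAutoModelForMaskedLM.from_pretrained("google-bert/bert-base-uncased")   # encoder-only models
seq2seq_lm = TFAutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small")            # encoder-decoder models
```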
.\models\auto\processing_auto.py
""" AutoProcessor class."""
import importlib
import inspect
import json
import os
import warnings
from collections import OrderedDict
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...feature_extraction_utils import FeatureExtractionMixin
from ...image_processing_utils import ImageProcessingMixin
from ...processing_utils import ProcessorMixin
from ...tokenization_utils import TOKENIZER_CONFIG_FILE
from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, get_file_from_repo, logging
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
CONFIG_MAPPING_NAMES,
AutoConfig,
model_type_to_module_name,
replace_list_option_in_docstrings,
)
from .feature_extraction_auto import AutoFeatureExtractor
from .image_processing_auto import AutoImageProcessor
from .tokenization_auto import AutoTokenizer
logger = logging.get_logger(__name__)
PROCESSOR_MAPPING_NAMES = OrderedDict(
[
("align", "AlignProcessor"),
("altclip", "AltCLIPProcessor"),
("bark", "BarkProcessor"),
("blip", "BlipProcessor"),
("blip-2", "Blip2Processor"),
("bridgetower", "BridgeTowerProcessor"),
("chinese_clip", "ChineseCLIPProcessor"),
("clap", "ClapProcessor"),
("clip", "CLIPProcessor"),
("clipseg", "CLIPSegProcessor"),
("clvp", "ClvpProcessor"),
("flava", "FlavaProcessor"),
("fuyu", "FuyuProcessor"),
("git", "GitProcessor"),
("groupvit", "CLIPProcessor"),
("hubert", "Wav2Vec2Processor"),
("idefics", "IdeficsProcessor"),
("instructblip", "InstructBlipProcessor"),
("kosmos-2", "Kosmos2Processor"),
("layoutlmv2", "LayoutLMv2Processor"),
("layoutlmv3", "LayoutLMv3Processor"),
("llava", "LlavaProcessor"),
("llava_next", "LlavaNextProcessor"),
("markuplm", "MarkupLMProcessor"),
("mctct", "MCTCTProcessor"),
("mgp-str", "MgpstrProcessor"),
("oneformer", "OneFormerProcessor"),
("owlv2", "Owlv2Processor"),
("owlvit", "OwlViTProcessor"),
("pix2struct", "Pix2StructProcessor"),
("pop2piano", "Pop2PianoProcessor"),
("sam", "SamProcessor"),
("seamless_m4t", "SeamlessM4TProcessor"),
("sew", "Wav2Vec2Processor"),
("sew-d", "Wav2Vec2Processor"),
("siglip", "SiglipProcessor"),
("speech_to_text", "Speech2TextProcessor"),
("speech_to_text_2", "Speech2Text2Processor"),
("speecht5", "SpeechT5Processor"),
("trocr", "TrOCRProcessor"),
("tvlt", "TvltProcessor"),
("tvp", "TvpProcessor"),
("unispeech", "Wav2Vec2Processor"),
]
)
PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, PROCESSOR_MAPPING_NAMES)
def processor_class_from_name(class_name: str):
for module_name, processors in PROCESSOR_MAPPING_NAMES.items():
if class_name in processors:
module_name = model_type_to_module_name(module_name)
module = importlib.import_module(f".{module_name}", "transformers.models")
try:
return getattr(module, class_name)
except AttributeError:
continue
for processor in PROCESSOR_MAPPING._extra_content.values():
if getattr(processor, "__name__", None) == class_name:
return processor
main_module = importlib.import_module("transformers")
if hasattr(main_module, class_name):
return getattr(main_module, class_name)
return None
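`processor_class_from_name` first searches the name table, then any dynamically registered processors in `PROCESSOR_MAPPING._extra_content`, and finally falls back to the top-level `transformers` namespace, returning `None` if nothing matches. A small usage sketch:
```
# Sketch: resolving processor classes by name.
from transformers.models.auto.processing_auto import processor_class_from_name

cls = processor_class_from_name("CLIPProcessor")
print(cls.__name__ if cls is not None else None)   # "CLIPProcessor"
print(processor_class_from_name("NotAProcessor"))  # None
```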
class AutoProcessor:
r"""
This is a generic processor class that will be instantiated as one of the processor classes of the library when
created with the [`AutoProcessor.from_pretrained`] class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
"""
def __init__(self):
raise EnvironmentError(
"AutoProcessor is designed to be instantiated "
"using the `AutoProcessor.from_pretrained(pretrained_model_name_or_path)` method."
)
@staticmethod
def register(config_class, processor_class, exist_ok=False):
"""
Register a new processor for this class.
Args:
config_class ([`PretrainedConfig`]):
The configuration corresponding to the model to register.
processor_class ([`FeatureExtractorMixin`]): The processor to register.
"""
PROCESSOR_MAPPING.register(config_class, processor_class, exist_ok=exist_ok)
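`register` simply forwards to `PROCESSOR_MAPPING.register`, so a custom processor can be made discoverable through `AutoProcessor`. A sketch with hypothetical placeholder classes (`MyCustomConfig` and `MyCustomProcessor` are not part of the library):
```
# Sketch: registering a custom processor; MyCustomConfig / MyCustomProcessor are hypothetical.
from transformers import AutoConfig, AutoProcessor, PretrainedConfig, ProcessorMixin

class MyCustomConfig(PretrainedConfig):
    model_type = "my-custom-model"

class MyCustomProcessor(ProcessorMixin):
    pass  # a real processor would wrap a tokenizer and/or an image or feature extractor

AutoConfig.register("my-custom-model", MyCustomConfig)
AutoProcessor.register(MyCustomConfig, MyCustomProcessor)
```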
.\models\auto\tokenization_auto.py
""" Auto Tokenizer class."""
import importlib
import json
import os
import warnings
from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
from ...utils import (
cached_file,
extract_commit_hash,
is_g2p_en_available,
is_sentencepiece_available,
is_tokenizers_available,
logging,
)
from ..encoder_decoder import EncoderDecoderConfig
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
CONFIG_MAPPING_NAMES,
AutoConfig,
config_class_to_model_type,
model_type_to_module_name,
replace_list_option_in_docstrings,
)
if is_tokenizers_available():
from ...tokenization_utils_fast import PreTrainedTokenizerFast
else:
PreTrainedTokenizerFast = None
logger = logging.get_logger(__name__)
if TYPE_CHECKING:
TOKENIZER_MAPPING_NAMES: OrderedDict[str, Tuple[Optional[str], Optional[str]]] = OrderedDict()
else:
TOKENIZER_MAPPING_NAMES = OrderedDict()
TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)
CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()}
def tokenizer_class_from_name(class_name: str):
if class_name == "PreTrainedTokenizerFast":
return PreTrainedTokenizerFast
for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items():
if class_name in tokenizers:
module_name = model_type_to_module_name(module_name)
module = importlib.import_module(f".{module_name}", "transformers.models")
try:
return getattr(module, class_name)
except AttributeError:
continue
for config, tokenizers in TOKENIZER_MAPPING._extra_content.items():
for tokenizer in tokenizers:
if getattr(tokenizer, "__name__", None) == class_name:
return tokenizer
main_module = importlib.import_module("transformers")
if hasattr(main_module, class_name):
return getattr(main_module, class_name)
return None
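`tokenizer_class_from_name` mirrors `processor_class_from_name`: it special-cases `PreTrainedTokenizerFast`, scans the (slow, fast) name pairs, then the registered extras, and finally the top-level `transformers` namespace. A quick sketch, assuming the `tokenizers` package is installed:
```
# Sketch: resolving tokenizer classes by name (assumes `tokenizers` is installed).
from transformers.models.auto.tokenization_auto import tokenizer_class_from_name

print(tokenizer_class_from_name("PreTrainedTokenizerFast"))        # handled explicitly above
print(tokenizer_class_from_name("BertTokenizerFast").__name__)     # "BertTokenizerFast"
print(tokenizer_class_from_name("NotATokenizer"))                  # None
```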
def get_tokenizer_config(
pretrained_model_name_or_path: Union[str, os.PathLike],
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
resume_download: bool = False,
proxies: Optional[Dict[str, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
subfolder: str = "",
**kwargs,
):
"""
Loads the tokenizer configuration from a pretrained model tokenizer configuration.
Args:
pretrained_model_name_or_path (`str` or `os.PathLike`):
This can be either:
- a string, the *model id* of a pretrained model configuration hosted inside a model repo on
huggingface.co.
- a path to a *directory* containing a configuration file saved using the
[`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
cache_dir (`str` or `os.PathLike`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force to (re-)download the configuration files and override the cached versions if they
exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
subfolder (`str`, *optional*, defaults to `""`):
In case the tokenizer config is located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
<Tip>
Passing `token=True` is required when you want to use a private model.
</Tip>
Returns:
`Dict`: The configuration of the tokenizer.
Examples:
```
# Download the configuration from huggingface.co and cache it.
tokenizer_config = get_tokenizer_config("google-bert/bert-base-uncased")
# This model does not have a tokenizer config, so the result will be an empty dict.
tokenizer_config = get_tokenizer_config("FacebookAI/xlm-roberta-base")
# Save a pretrained tokenizer locally, then reload its configuration.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
tokenizer.save_pretrained("tokenizer-test")
tokenizer_config = get_tokenizer_config("tokenizer-test")
```
"""
# Compatibility handling for the deprecated `use_auth_token` argument.
use_auth_token = kwargs.pop("use_auth_token", None)
if use_auth_token is not None:
# Warn that `use_auth_token` will be removed in Transformers v5.
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
FutureWarning,
)
# Raise an error if both `token` and `use_auth_token` are specified.
if token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
# Use the value of `use_auth_token` as `token`.
token = use_auth_token
# Read the `_commit_hash` value from kwargs, if any.
commit_hash = kwargs.get("_commit_hash", None)
# Resolve (and cache) the tokenizer configuration file of the pretrained model.
resolved_config_file = cached_file(
pretrained_model_name_or_path,
TOKENIZER_CONFIG_FILE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
subfolder=subfolder,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
)
# If the tokenizer configuration file could not be located, log it and return an empty dict.
if resolved_config_file is None:
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
return {}
# Extract the commit hash of the resolved configuration file.
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
# Open the configuration file and load its contents into the `result` dict.
with open(resolved_config_file, encoding="utf-8") as reader:
result = json.load(reader)
# Store the extracted commit hash under the "_commit_hash" key of `result`.
result["_commit_hash"] = commit_hash
return result
class AutoTokenizer:
r"""
This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
created with the [`AutoTokenizer.from_pretrained`] class method.
This class cannot be instantiated directly using `__init__()` (throws an error).
"""
def __init__(self):
# Raise an EnvironmentError to prevent direct instantiation of this class.
raise EnvironmentError(
"AutoTokenizer is designed to be instantiated "
"using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method."
)
@staticmethod
def register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None, exist_ok=False):
"""
Register a new tokenizer in this mapping.
Args:
config_class ([`PretrainedConfig`]):
The configuration corresponding to the model to register.
slow_tokenizer_class ([`PretrainedTokenizer`], *optional*):
The slow tokenizer to register.
fast_tokenizer_class ([`PretrainedTokenizerFast`], *optional*):
The fast tokenizer to register.
"""
# Require at least one of the slow or fast tokenizer classes, otherwise raise a ValueError.
if slow_tokenizer_class is None and fast_tokenizer_class is None:
raise ValueError("You need to pass either a `slow_tokenizer_class` or a `fast_tokenizer_class`.")
# Raise a ValueError if a fast tokenizer class was passed as `slow_tokenizer_class`.
if slow_tokenizer_class is not None and issubclass(slow_tokenizer_class, PreTrainedTokenizerFast):
raise ValueError("You passed a fast tokenizer in the `slow_tokenizer_class`.")
# Raise a ValueError if a slow tokenizer class was passed as `fast_tokenizer_class`.
if fast_tokenizer_class is not None and issubclass(fast_tokenizer_class, PreTrainedTokenizer):
raise ValueError("You passed a slow tokenizer in the `fast_tokenizer_class`.")
# If both classes are provided and the fast tokenizer's `slow_tokenizer_class` attribute does not match the slow class passed in, raise a ValueError.
if (
slow_tokenizer_class is not None
and fast_tokenizer_class is not None
and issubclass(fast_tokenizer_class, PreTrainedTokenizerFast)
and fast_tokenizer_class.slow_tokenizer_class != slow_tokenizer_class
):
raise ValueError(
"The fast tokenizer class you are passing has a `slow_tokenizer_class` attribute that is not "
"consistent with the slow tokenizer class you passed (fast tokenizer has "
f"{fast_tokenizer_class.slow_tokenizer_class} and you passed {slow_tokenizer_class}. Fix one of those "
"so they match!"
)
# If config_class is already registered in TOKENIZER_MAPPING._extra_content, reuse the existing slow/fast tokenizer classes for any argument left as None.
if config_class in TOKENIZER_MAPPING._extra_content:
existing_slow, existing_fast = TOKENIZER_MAPPING[config_class]
if slow_tokenizer_class is None:
slow_tokenizer_class = existing_slow
if fast_tokenizer_class is None:
fast_tokenizer_class = existing_fast
# Register the mapping from config_class to its (slow, fast) tokenizer classes in TOKENIZER_MAPPING.
TOKENIZER_MAPPING.register(config_class, (slow_tokenizer_class, fast_tokenizer_class), exist_ok=exist_ok)
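Putting these checks together, a registration sketch with hypothetical placeholder classes (`MyConfig`, `MyTokenizer`, and `MyTokenizerFast` do not exist in the library); note that the fast class must point back at the slow class via `slow_tokenizer_class`:
```
# Sketch: registering a slow/fast tokenizer pair; the My* classes are hypothetical.
from transformers import AutoConfig, AutoTokenizer, PretrainedConfig
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast  # fast variant needs `tokenizers`

class MyConfig(PretrainedConfig):
    model_type = "my-model"

class MyTokenizer(PreTrainedTokenizer):
    pass

class MyTokenizerFast(PreTrainedTokenizerFast):
    slow_tokenizer_class = MyTokenizer  # must match the slow class passed to register()

AutoConfig.register("my-model", MyConfig)
AutoTokenizer.register(MyConfig, slow_tokenizer_class=MyTokenizer, fast_tokenizer_class=MyTokenizerFast)
```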