Transformers 源码解析(一百三十九)
.\utils\fx.py
import builtins
import collections
import functools
import inspect
import math
import operator
import os
import random
import warnings
from typing import Any, Callable, Dict, List, Optional, Type, Union
import torch
from torch import nn
from torch.fx import Graph, GraphModule, Proxy, Tracer
from torch.fx._compatibility import compatibility
from torch.fx.proxy import ParameterProxy
from .. import PretrainedConfig, PreTrainedModel, logging
from ..models.auto import get_values
from ..models.auto.modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_BACKBONE_MAPPING_NAMES,
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_CTC_MAPPING_NAMES,
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_IMAGE_MAPPING_NAMES,
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
MODEL_FOR_PRETRAINING_MAPPING_NAMES,
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_MAPPING_NAMES,
)
from ..pytorch_utils import is_torch_greater_or_equal_than_2_0
from ..utils import (
ENV_VARS_TRUE_VALUES,
TORCH_FX_REQUIRED_VERSION,
get_torch_version,
is_peft_available,
is_torch_fx_available,
)
if is_peft_available():
from peft import PeftModel
# Module-level logger named after this module.
logger = logging.get_logger(__name__)
# True when the FX_DEBUG_MODE environment variable, upper-cased, is one of ENV_VARS_TRUE_VALUES.
_IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", "").upper() in ENV_VARS_TRUE_VALUES
def _generate_supported_model_class_names(
    model_name: Type[PretrainedConfig],
    supported_tasks: Optional[Union[str, List[str]]] = None,
) -> List[str]:
    """
    Collect the model class names registered for `model_name` across task mappings.

    Args:
        model_name: Key looked up in each task's `*_MAPPING_NAMES` mapping.
        supported_tasks: A task name, a list of task names, or `None` to consider every known task.

    Returns:
        The class names found, in task order; tasks with no (or a falsy) entry are skipped.

    Raises:
        KeyError: If a requested task is not a key of the internal task mapping.
    """
    task_mapping = {
        "default": MODEL_MAPPING_NAMES,
        "pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES,
        "next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
        "masked-lm": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
        "causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
        "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
        "speech-seq2seq": MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
        "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
        "document-question-answering": MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
        "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
        "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
        "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
        "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
        "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
        "zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
        "ctc": MODEL_FOR_CTC_MAPPING_NAMES,
        "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
        "semantic-segmentation": MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
        "backbone": MODEL_FOR_BACKBONE_MAPPING_NAMES,
        "image-feature-extraction": MODEL_FOR_IMAGE_MAPPING_NAMES,
    }

    # Normalise the task selection to an iterable of task names.
    if supported_tasks is None:
        tasks = task_mapping.keys()
    elif isinstance(supported_tasks, str):
        tasks = [supported_tasks]
    else:
        tasks = supported_tasks

    candidates = (task_mapping[task].get(model_name, None) for task in tasks)
    return [class_name for class_name in candidates if class_name]
# Model-type keys whose classes are collected below. An entry may be a plain model name (all tasks), or a dict of
# keyword arguments for `_generate_supported_model_class_names` (see the loop below).
_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
    "altclip",
    "albert",
    "bart",
    "bert",
    "blenderbot",
    "blenderbot-small",
    "bloom",
    "clip",
    "convnext",
    "deberta",
    "deberta-v2",
    "dinov2",
    "distilbert",
    "donut-swin",
    "electra",
    "gpt2",
    "gpt_neo",
    "gptj",
    "hubert",
    "layoutlm",
    "llama",
    "cohere",
    "lxmert",
    "m2m_100",
    "marian",
    "mbart",
    "megatron-bert",
    "mobilebert",
    "mt5",
    "nezha",
    "opt",
    "pegasus",
    "plbart",
    "resnet",
    "roberta",
    "segformer",
    "speech_to_text",
    "speech_to_text_2",
    "swin",
    "t5",
    "trocr",
    "vit",
    "xglm",
    "wav2vec2",
]

# NOTE(review): presumably the model types that can be traced with a key/value cache; not used in the visible code.
_FX_SUPPORTED_MODELS_WITH_KV_CACHE = ["llama", "opt"]

# Expand every entry above into the concrete model class names registered for it.
_REGULAR_SUPPORTED_MODELS = []
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
    if isinstance(item, dict):
        _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(**item))
    else:
        _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(item))

# Class names listed explicitly rather than derived from the task mappings.
_SPECIAL_SUPPORTED_MODELS = [
    "CLIPTextModel",
    "CLIPTextModelWithProjection",
    "CLIPVisionModel",
    "CLIPVisionModelWithProjection",
    "AltCLIPTextModel",
    "AltCLIPVisionModel",
    "GitVisionModel",
    "GPT2DoubleHeadsModel",
    "Speech2Text2Decoder",
    "TrOCRDecoder",
    "PeftModelForCausalLM",
    "PeftModelForSeq2SeqLM",
]

# De-duplicated, alphabetically sorted, immutable union of both lists.
_SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)))
def torch_nn_embedding(self, input):
    """Meta override for `nn.Embedding`: empty meta tensor of shape `input.shape + (embedding_dim,)`.

    The embedding dimension and dtype are read from `self.weight`.
    """
    embedding_dim = self.weight.shape[-1]
    out_shape = (*input.shape, embedding_dim)
    return torch.empty(out_shape, device="meta", dtype=self.weight.dtype)
def torch_nn_functional_embedding(
    input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
):
    """Meta override for `F.embedding`: empty meta tensor of shape `input.shape + (weight.shape[-1],)`.

    Only `input` and `weight` influence the result; the remaining arguments are accepted and ignored.
    """
    out_shape = (*input.shape, weight.shape[-1])
    return torch.empty(out_shape, device="meta", dtype=weight.dtype)
def torch_nn_layernorm(self, input):
    """Meta override registered for `torch.nn.LayerNorm`: returns `input` unchanged."""
    return input
def torch_nn_groupnorm(self, input):
    """Meta override registered for `torch.nn.GroupNorm`: returns `input` unchanged."""
    return input
def torch_nn_linear(self, input):
    """Meta override for `nn.Linear`: same leading dims as `input`, last dim replaced by `self.out_features`."""
    leading_dims = input.shape[:-1]
    out_shape = (*leading_dims, self.out_features)
    return torch.empty(out_shape, device="meta")
def torch_relu(x):
    """Meta override registered for `torch.relu`: returns `x` unchanged."""
    return x
def torch_nn_relu(self, x):
    """Meta override registered for `torch.nn.ReLU`: returns `x` unchanged."""
    return x
def torch_nn_functional_relu(x, inplace=False):
    """Meta override for `F.relu`: returns `x` unchanged.

    Raises:
        ValueError: If `inplace=True` is requested, which is not supported for meta tensor analysis.
    """
    # The guard was inverted: it rejected the default out-of-place call and let the in-place one through.
    if inplace:
        raise ValueError("Don't support in-place functional.relu for MetaTensor analysis")
    return x
def torch_where(condition, x, y):
    """Meta override for `torch.where`: moves the three operands to the meta device and adds them.

    The sum is only used for the shape it produces through broadcasting.
    """
    meta_condition = condition.to(device="meta")
    meta_x = x.to(device="meta")
    meta_y = y.to(device="meta")
    return meta_condition + meta_x + meta_y
def torch_abs(input, *, out=None):
    """Meta override for `torch.abs`: returns `input` unchanged.

    Raises:
        ValueError: If an `out` tensor is supplied.
    """
    if out is None:
        return input
    raise ValueError("Don't support in-place abs for MetaTensor analysis")
def torch_arange(*args, **kwargs):
    """Meta override for `torch.arange`: empty meta tensor with the length the real call would produce.

    Accepts `(end)`, `(start, end)` or `(start, end, step)` positionally; `start`, `end`, `step` and `dtype`
    may also be given as keywords (keywords take precedence, as `step` already did).

    Returns:
        A 1-D meta tensor of `ceil((end - start) / step)` elements with the requested `dtype`.
    """
    n = len(args)
    start, end, step = 0, None, 1
    if n == 1:
        end = args[0]
    elif n == 2:
        start, end = args
    elif n != 0:
        start, end, step = args
    start = kwargs.get("start", start)
    end = kwargs.get("end", end)
    step = kwargs.get("step", step)
    dtype = kwargs.get("dtype")
    # arange yields ceil((end - start) / step) values. Flooring (or truncating float bounds to int first)
    # drops the last element whenever the span is not a multiple of the step.
    length = math.ceil((end - start) / step)
    return torch.empty(length, dtype=dtype, device="meta")
def torch_full(*args, **kwargs):
    """Override for `torch.full` that runs the real op with any `device` keyword removed.

    A fill value that is a meta tensor carries no data, so it is replaced by the constant `1`. This is
    done whether the fill value is passed positionally or as `fill_value=`.
    """
    meta_device = torch.device("meta")

    def _is_meta_tensor(value):
        return isinstance(value, torch.Tensor) and value.device == meta_device

    args = list(args)
    # Guard the positional access: `fill_value` may have been supplied by keyword instead.
    if len(args) > 1 and _is_meta_tensor(args[1]):
        args[1] = 1
    kwargs_without_device = dict(kwargs)
    kwargs_without_device.pop("device", None)
    if _is_meta_tensor(kwargs_without_device.get("fill_value")):
        kwargs_without_device["fill_value"] = 1
    return torch.full(*args, **kwargs_without_device)
def torch_cat(tensors, dim=None, axis=None, *, out=None):
    """Meta override for `torch.cat`: first tensor's shape with the sizes along `dim` summed.

    `axis` is used when `dim` is not given; with neither, dimension 0 is used. `out` is ignored.
    """
    if dim is None:
        dim = 0 if axis is None else axis
    reference = tensors[0]
    if dim < 0:
        dim += reference.dim()
    base_shape = list(reference.shape)
    total = sum(t.shape[dim] for t in tensors)
    out_shape = base_shape[:dim] + [total] + base_shape[dim + 1 :]
    return torch.empty(out_shape, device="meta")
def torch_stack(tensors, dim=None, axis=None, *, out=None):
    """Meta override for `torch.stack`: first tensor's shape with `len(tensors)` inserted at `dim`.

    `axis` is used when `dim` is not given; with neither, dimension 0 is used. `out` is ignored.
    """
    if dim is None:
        dim = 0 if axis is None else axis
    reference = tensors[0]
    if dim < 0:
        # Negative dims index into the output, which has one more dimension than the inputs.
        dim += reference.dim() + 1
    out_shape = list(reference.shape)
    out_shape.insert(dim, len(tensors))
    return torch.empty(out_shape, device="meta")
def torch_add(input, other, *, alpha=1, out=None):
    """Meta override for `torch.add`: empty meta tensor with the broadcast shape of the operands.

    If one operand is not a tensor, the result has the shape of the other. `alpha` and `out` are ignored.
    """
    if not isinstance(input, torch.Tensor):
        return torch.empty_like(other, device="meta")
    if not isinstance(other, torch.Tensor):
        return torch.empty_like(input, device="meta")
    max_length = max(input.dim(), other.dim())
    # Broadcasting aligns trailing dimensions, so the shorter shape is padded with 1s on the LEFT.
    input_shape = [1] * (max_length - input.dim()) + list(input.shape)
    other_shape = [1] * (max_length - other.dim()) + list(other.shape)
    shape = [max(a, b) for a, b in zip(input_shape, other_shape)]
    return torch.empty(shape, device="meta")
def torch_mul(input, other, *, out=None):
    """Meta override for `torch.mul`: delegates to `torch_add`, which computes the result shape."""
    return torch_add(input, other, out=out)
def torch_tensor_mul(self, other):
    """Meta override for `torch.Tensor.mul`: method form of `torch_mul`."""
    return torch_mul(self, other)
def torch_matmul(input, other, *, out=None):
    """Meta override for `torch.matmul`: empty meta tensor with the shape of the product.

    Handles vector·vector (0-dim result), matrix·matrix, vector·matrix, matrix·vector and the batched
    case in which leading dimensions are broadcast. `out` is ignored.
    """
    d1 = input.dim()
    d2 = other.dim()
    shape = None
    if d1 == 1 and d2 == 1:
        shape = None
    elif d1 == 2 and d2 == 2:
        shape = (input.size(0), other.size(1))
    elif d1 == 1 and d2 == 2:
        shape = (other.size(1),)
    elif d1 == 2 and d2 == 1:
        # This branch used to test `d1 == 1` a second time and could never be taken.
        shape = (input.size(0),)
    else:
        max_length = max(d1, d2)
        # Left-pad with -1 so that `max` below always keeps the size from the longer operand.
        shape1 = [-1] * (max_length - d1) + list(input.shape)
        shape2 = [-1] * (max_length - d2) + list(other.shape)
        shape = [max(a, b) for a, b in zip(shape1, shape2)]
        shape[-2] = shape1[-2]
        shape[-1] = shape2[-1]
        # A 1-D operand contributes no dimension to the result.
        if d1 == 1:
            shape.pop(-2)
        if d2 == 1:
            shape.pop(-1)
    if shape is None:
        return torch.tensor(0.0, device="meta")
    return torch.empty(*shape, device="meta")
def torch_bmm(input, mat2, *, out=None):
    """Meta override for `torch.bmm`: `(b, n, m) x (b, m, p)` gives an empty meta tensor of shape `(b, n, p)`.

    Raises:
        ValueError: If an `out` tensor is supplied, or if an operand is not 3-dimensional (unpacking fails).
    """
    if out is not None:
        raise ValueError("Don't support in-place bmm for MetaTensor analysis")
    batch, rows, _inner = input.shape
    _batch, _inner2, cols = mat2.shape
    return torch.empty(batch, rows, cols, device="meta")
def torch_baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None):
    """Meta override for `torch.baddbmm`: the result shape is that of `torch_bmm(batch1, batch2)`.

    `input`, `beta` and `alpha` do not affect the returned shape. Raises `ValueError` if `out` is given.
    """
    if out is not None:
        raise ValueError("Don't support in-place baddbmm for MetaTensor analysis")
    return torch_bmm(batch1, batch2)
def torch_tensor_baddbmm(self, batch1, batch2, *, beta=1, alpha=1, out=None):
    """Meta override for `torch.Tensor.baddbmm`: method form of `torch_baddbmm`."""
    return torch_baddbmm(self, batch1, batch2, beta=beta, alpha=alpha, out=out)
def torch_einsum(equation, *operands):
    """Meta override for `torch.einsum`.

    The real einsum is executed on uninitialised CPU tensors shaped like the operands, and the result is
    moved to the meta device.
    """
    cpu_operands = [torch.empty_like(operand, device="cpu") for operand in operands]
    result = torch.einsum(equation, *cpu_operands)
    return result.to("meta")
def torch_tensor_repeat(self, *sizes):
    """Meta override for `torch.Tensor.repeat`: empty meta tensor with each dimension scaled by its repeat count.

    `sizes` may be given as separate integers or as a single tuple/list. When more sizes than dimensions
    are supplied, the tensor's shape is first left-padded with 1s, as `Tensor.repeat` does.
    """
    if len(sizes) == 1 and isinstance(sizes[0], (tuple, list)):
        sizes = tuple(sizes[0])
    # Without the padding, `shape[i]` raised IndexError as soon as len(sizes) > self.dim().
    shape = [1] * (len(sizes) - self.dim()) + list(self.shape)
    for i, x in enumerate(sizes):
        shape[i] *= x
    return torch.empty(shape, device="meta")
def torch_repeat_interleave(*args, dim=None, output_size=None):
    """Meta override for `torch.repeat_interleave`.

    Positional forms are `(repeats)`, `(input, repeats)` and `(input, repeats, dim)`.

    - With only `repeats`, the result is 1-D, of length `output_size` if given, else `repeats.sum()`.
    - With `input` and no `dim`, the input is treated as flattened.
    - A scalar `repeats` multiplies the size along `dim`; otherwise that size becomes `output_size`
      if given, else `repeats.sum()`.
    """
    num_args = len(args)
    if num_args == 1:
        shape = [output_size if output_size is not None else args[0].sum()]
    else:
        shape = list(args[0].shape)
        if dim is None:
            if num_args > 2:
                dim = args[2]
            else:
                # The flattened input has numel() elements; the sum of the sizes under-counts it.
                shape = [args[0].numel()]
                dim = 0
        repeats = args[1]
        if isinstance(repeats, int) or torch.numel(repeats) == 1:
            shape[dim] *= int(repeats)
        else:
            shape[dim] = output_size if output_size is not None else repeats.sum()
    return torch.empty(*shape, device="meta")
def torch_index_select(input, dim, index, *, out=None):
    """Meta override for `torch.index_select`: `input`'s shape with dimension `dim` resized to `len(index)`.

    `out` is ignored.
    """
    out_shape = list(input.shape)
    out_shape[dim] = len(index)
    return torch.empty(out_shape, device="meta")
def torch_tensor_index_select(self, dim, index):
    """Meta override for `torch.Tensor.index_select`: method form of `torch_index_select`."""
    return torch_index_select(self, dim, index)
def torch_gather(input, dim, index, *, sparse_grad=False, out=None):
    """Meta override for `torch.gather`: empty meta tensor with the shape of `index`.

    `torch.gather` always returns a tensor shaped like `index`. `index` may be smaller than `input`
    along dimensions other than `dim`, so copying `input`'s shape and patching only `dim` is wrong.
    `dim`, `sparse_grad` and `out` do not affect the shape.
    """
    return torch.empty(tuple(index.shape), device="meta")
def torch_tensor_gather(self, dim, index):
    """Meta override for `torch.Tensor.gather`: method form of `torch_gather`."""
    return torch_gather(self, dim, index)
def torch_roll(input, shifts, dims=None):
    """Meta override registered for `torch.roll`: returns `input` unchanged."""
    return input
def torch_flip(input, dims):
    """Meta override registered for `torch.flip`: returns `input` unchanged."""
    return input
def torch_tensor_flip(self, dims):
    """Meta override registered for `torch.Tensor.flip`: returns `self` unchanged."""
    return self
def torch_nn_conv1d(self, input):
    """Meta override for `nn.Conv1d`: empty meta tensor with the convolution's output shape.

    With `padding="same"` the length is kept; otherwise it is computed from the module's padding,
    dilation, kernel size and stride (`"valid"` counts as zero padding). The channel dimension
    becomes `self.out_channels`.
    """
    out_shape = list(input.shape)
    pad = (0, 0) if self.padding == "valid" else self.padding
    if pad != "same":
        length_in = input.shape[-1]
        kernel_span = self.dilation[0] * (self.kernel_size[0] - 1)
        out_shape[-1] = math.floor((length_in + 2 * pad[0] - kernel_span - 1) / self.stride[0] + 1)
    out_shape[-2] = self.out_channels
    return torch.empty(out_shape, device="meta")
def torch_nn_conv2d(self, input):
    """Meta override for `nn.Conv2d`: empty meta tensor with the convolution's output shape.

    With `padding="same"` the spatial size is kept; otherwise height and width are computed from the
    module's padding, dilation, kernel size and stride (`"valid"` counts as zero padding). The channel
    dimension becomes `self.out_channels`.
    """
    out_shape = list(input.shape)
    pad = (0, 0) if self.padding == "valid" else self.padding
    if pad != "same":
        spatial_in = input.shape[-2:]
        spatial_out = []
        for axis in (0, 1):
            kernel_span = self.dilation[axis] * (self.kernel_size[axis] - 1)
            spatial_out.append(
                math.floor((spatial_in[axis] + 2 * pad[axis] - kernel_span - 1) / self.stride[axis] + 1)
            )
        out_shape[-2:] = spatial_out
    out_shape[-3] = self.out_channels
    return torch.empty(out_shape, device="meta")
def torch_squeeze(input, dim=None):
    """Meta override for `torch.squeeze`.

    With `dim=None` every size-1 dimension is dropped; otherwise only `dim` is dropped, and only when
    its size is 1. Negative `dim` counts from the end.
    """
    sizes = list(input.shape)
    if dim is None:
        sizes = [size for size in sizes if size != 1]
    else:
        if dim < 0:
            dim += input.dim()
        if sizes[dim] == 1:
            sizes.pop(dim)
    return torch.empty(sizes, device="meta")
def torch_tensor_squeeze(self, dim=None):
    """Meta override for `torch.Tensor.squeeze`: method form of `torch_squeeze`."""
    return torch_squeeze(self, dim)
def torch_unsqueeze(input, dim):
    """Meta override for `torch.unsqueeze`: `input`'s shape with a size-1 dimension inserted at `dim`."""
    sizes = list(input.shape)
    # Negative dims index into the output, which has one more dimension than the input.
    position = dim if dim >= 0 else input.dim() + 1 + dim
    sizes.insert(position, 1)
    return torch.empty(sizes, device="meta")
def torch_tensor_unsqueeze(self, dim):
    """Meta override for `torch.Tensor.unsqueeze`: method form of `torch_unsqueeze`."""
    return torch_unsqueeze(self, dim)
def torch_unique_consecutive(input, **kwargs):
    """Meta override for `torch.unique_consecutive`.

    Runs the real op on a CPU tensor of zeros shaped like `input` and moves the result(s) to the meta
    device. When the op returns several tensors (e.g. `return_counts=True`), a tuple of meta tensors
    is returned.
    """
    output = torch.unique_consecutive(torch.zeros_like(input, device="cpu"), **kwargs)
    if isinstance(output, torch.Tensor):
        return output.to("meta")
    # `map(output, lambda ...)` had its arguments swapped and raised TypeError for tuple results.
    return tuple(x.to("meta") for x in output)
def torch_nn_functional_one_hot(tensor, num_classes=-1):
    """Meta override for `F.one_hot`: empty meta tensor of shape `tensor.shape + (num_classes,)`.

    Raises:
        ValueError: If `num_classes` is negative, i.e. the class count would have to be inferred from data.
    """
    if num_classes < 0:
        raise ValueError("Don't support automatic num_classes inference for MetaTensor analysis")
    out_shape = [*tensor.shape, num_classes]
    return torch.empty(out_shape, device="meta")
def torch_nn_functional_scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
):
    """Meta override for `F.scaled_dot_product_attention`.

    The result keeps `query`'s leading dimensions and its second-to-last dimension, and takes the last
    dimension from `value`. The other arguments do not affect the shape.
    """
    leading_dims = query.shape[:-2]
    out_shape = (*leading_dims, query.shape[-2], value.shape[-1])
    return torch.empty(out_shape, device="meta")
def torch_nn_mseloss(self, input, target):
    """Meta override for `nn.MSELoss`: `target.shape` when `reduction == "none"`, else shape `(1,)`."""
    out_shape = target.shape if self.reduction == "none" else (1,)
    return torch.empty(out_shape, device="meta")
def torch_nn_crossentropyloss(self, input, target):
    """Meta override for `nn.CrossEntropyLoss`: `target.shape` when `reduction == "none"`, else shape `(1,)`."""
    out_shape = target.shape if self.reduction == "none" else (1,)
    return torch.empty(out_shape, device="meta")
def torch_nn_bcewithlogitsloss(self, input, target):
    """Meta override for `nn.BCEWithLogitsLoss`: `target.shape` when `reduction == "none"`, else shape `(1,)`."""
    out_shape = target.shape if self.reduction == "none" else (1,)
    return torch.empty(out_shape, device="meta")
def operator_getitem(a, b):
    """Meta override for `operator.getitem`.

    For a non-tensor `a` this is plain indexing. For a tensor `a`, the indexing is performed on an
    uninitialised CPU tensor shaped like `a`, with every tensor found in the index replaced by a CPU
    tensor of ones, and the result is moved to the meta device.
    """
    if not isinstance(a, torch.Tensor):
        return operator.getitem(a, b)

    def to_concrete(t):
        # Non-tensor index components (ints, slices, None, ...) are used as they are.
        if not isinstance(t, torch.Tensor):
            return t
        concrete = torch.ones_like(t, device="cpu")
        # Floating point and int32 index tensors are converted to int64; other dtypes are kept.
        if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
            concrete = concrete.to(torch.int64)
        return concrete

    index = tuple(map(to_concrete, b)) if isinstance(b, tuple) else to_concrete(b)
    return operator.getitem(torch.empty_like(a, device="cpu"), index).to("meta")
# Maps a torch callable or module class (plus `operator.getitem`) to the function in this file that
# produces its result for meta tensor analysis.
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
    torch.nn.Embedding: torch_nn_embedding,
    torch.nn.functional.embedding: torch_nn_functional_embedding,
    torch.nn.LayerNorm: torch_nn_layernorm,
    torch.nn.GroupNorm: torch_nn_groupnorm,
    torch.nn.Linear: torch_nn_linear,
    torch.relu: torch_relu,
    torch.nn.functional.relu: torch_nn_functional_relu,
    torch.nn.ReLU: torch_nn_relu,
    torch.where: torch_where,
    torch.abs: torch_abs,
    torch.arange: torch_arange,
    torch.full: torch_full,
    torch.cat: torch_cat,
    torch.stack: torch_stack,
    torch.add: torch_add,
    torch.mul: torch_mul,
    torch.Tensor.mul: torch_tensor_mul,
    torch.matmul: torch_matmul,
    torch.bmm: torch_bmm,
    torch.baddbmm: torch_baddbmm,
    torch.Tensor.baddbmm: torch_tensor_baddbmm,
    torch.einsum: torch_einsum,
    torch.Tensor.repeat: torch_tensor_repeat,
    torch.repeat_interleave: torch_repeat_interleave,
    torch.roll: torch_roll,
    torch.flip: torch_flip,
    torch.Tensor.flip: torch_tensor_flip,
    torch.index_select: torch_index_select,
    torch.Tensor.index_select: torch_tensor_index_select,
    torch.gather: torch_gather,
    torch.Tensor.gather: torch_tensor_gather,
    torch.nn.Conv1d: torch_nn_conv1d,
    torch.nn.Conv2d: torch_nn_conv2d,
    torch.squeeze: torch_squeeze,
    torch.Tensor.squeeze: torch_tensor_squeeze,
    torch.unsqueeze: torch_unsqueeze,
    torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
    torch.unique_consecutive: torch_unique_consecutive,
    torch.nn.functional.one_hot: torch_nn_functional_one_hot,
    torch.nn.MSELoss: torch_nn_mseloss,
    torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
    torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
    operator.getitem: operator_getitem,
}

# The scaled dot product attention override is only registered when the torch >= 2.0 flag is set.
if is_torch_greater_or_equal_than_2_0:
    _MANUAL_META_OVERRIDES[
        torch.nn.functional.scaled_dot_product_attention
    ] = torch_nn_functional_scaled_dot_product_attention
class HFProxy(Proxy):
    """
    Proxy that uses metadata to handle data-dependent control-flow.
    """

    def install_metadata(self, metadata):
        """Attach `metadata` to this proxy; `__len__`, `__bool__` and `__contains__` consult it."""
        self._metadata = metadata

    @property
    def shape(self):
        """Record a `size` method call on this proxy and return the resulting proxy."""
        return self.tracer.create_proxy("call_method", "size", (self,), {})

    @property
    def device(self):
        """Return a `MetaDeviceAttribute` wrapping the `device` attribute of this proxy."""
        return MetaDeviceAttribute(self, "device")

    def __len__(self):
        # Answer from the installed metadata when there is one, otherwise defer to torch.fx.
        if hasattr(self, "_metadata") and self._metadata is not None:
            return len(self._metadata)
        return super().__len__()

    def __bool__(self):
        # Answer from the installed metadata when there is one, otherwise defer to torch.fx.
        if hasattr(self, "_metadata") and self._metadata is not None:
            return self._metadata
        return super().__bool__()

    def __getattr__(self, k):
        # `_metadata` must go through normal lookup (raising AttributeError when absent) so that the
        # `hasattr(self, "_metadata")` checks work; every other missing attribute becomes an HFAttribute.
        if k == "_metadata":
            return self.__getattribute__(k)
        return HFAttribute(self, k)

    def __setitem__(self, indices, values):
        """Record an `operator.setitem` call on this proxy."""
        return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})

    def __contains__(self, key):
        # Answer from the installed metadata when there is one, otherwise defer to torch.fx.
        if hasattr(self, "_metadata") and self._metadata is not None:
            return key in self._metadata
        return super().__contains__(key)
class HFAttribute(HFProxy):
    """Proxy for the attribute `attr` of another proxy `root`; its graph node is created lazily."""

    def __init__(self, root, attr: str):
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._node = None

        # When the root carries metadata, the same attribute read from it becomes this proxy's metadata.
        if hasattr(self.root, "_metadata"):
            self.install_metadata(getattr(self.root._metadata, attr))

    @property
    def node(self):
        """Graph node for `getattr(root, attr)`, created on first access and then reused."""
        if self._node is None:
            self._node = self.tracer.create_proxy("call_function", builtins.getattr, (self.root, self.attr), {}).node
        return self._node

    def __call__(self, *args, **kwargs):
        """Record a call of method `attr` on `root` with the given arguments."""
        return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
class MetaDeviceAttribute(HFAttribute):
    """Marker subclass returned by `HFProxy.device`; `_proxies_to_metas` maps instances to `"meta"`."""

    pass
def _proxies_to_metas(v):
    """Returns the underlying metadata for HFProxies, and behaves like the identity for the others.

    A `MetaDeviceAttribute` is mapped to the string `"meta"`.

    Raises:
        RuntimeError: If `v` is a `torch.fx.Proxy` that is not an `HFProxy` with metadata installed.
    """
    if isinstance(v, MetaDeviceAttribute):
        return "meta"
    if not isinstance(v, torch.fx.Proxy):
        return v
    carries_metadata = isinstance(v, HFProxy) and hasattr(v, "_metadata")
    if not carries_metadata:
        raise RuntimeError(f"No metadata was found for {v}")
    return v._metadata
def _gen_constructor_wrapper(target):
@functools.wraps(target)
def wrapper(*args, **kwargs):
proxy = None
def check_has_proxy(v):
if isinstance(v, Proxy):
nonlocal proxy
proxy = v
torch.fx.node.map_aggregate(args, check_has_proxy)
torch.fx.node.map_aggregate(kwargs, check_has_proxy)
if proxy is not None:
return proxy.tracer.create_proxy("call_function", target, args, kwargs)
else:
return target(*args, **kwargs)
return wrapper, target
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
if forbidden_values is None:
forbidden_values = []
value = random.randint(low, high)
while value in forbidden_values:
value = random.randint(low, high)
return value
class HFTracer(Tracer):
    """
    Tracer that is able to symbolically trace models from the library. To do that, it uses the HFProxy instead of the
    regular PyTorch torch.fx.Proxy.
    """

    # When True, `_module_getattr` also looks tensors up among the root's buffers and proxies them.
    proxy_buffer_attributes: bool = True
    # When True, `path_of_module` may register parameter-less, buffer-less modules on the root on the fly.
    allow_insert_stateless_mods: bool = True
    # NOTE(review): presumably the `torch.*` functions wrapped during tracing; not used in the visible code.
    _TORCH_METHODS_TO_PATCH = [
        "arange",
        "zeros",
        "ones",
        "full",
        "full_like",
        "eye",
        "empty",
        "tensor",
        "clamp",
        "finfo",
    ]

    # `PeftModel` is only part of the tuple when peft is available.
    supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
    """Initialise the underlying `Tracer`, then verify that the installed torch supports FX tracing.

    Raises:
        ImportError: If `is_torch_fx_available()` is false.
    """
    super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)

    if not is_torch_fx_available():
        raise ImportError(
            f"Found an incompatible version of torch. Found version {get_torch_version()}, but only version "
            f"{TORCH_FX_REQUIRED_VERSION} is supported."
        )
def _generate_dummy_input(
self, model: PreTrainedModel, input_name: str, shape: List[int], input_names: List[str]
):
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
    """Return a `get_attr` proxy for `attr_val` when it is a parameter (or buffer) of the root module.

    Args:
        attr: Name of the attribute being read (not used by the lookup itself).
        attr_val: The value that was read.
        parameter_proxy_cache: Maps qualified names to already created proxies; updated in place.

    Returns:
        The cached or newly created proxy, or `attr_val` unchanged when it matches no parameter or
        buffer, or when `self._disable_module_getattr` is truthy.
    """
    if getattr(self, "_disable_module_getattr", False):
        return attr_val
    else:

        def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
            # Identity search: find the qualified name under which this exact object is registered.
            for n, p in collection_to_search:
                if attr_val is p:
                    if n not in parameter_proxy_cache:
                        kwargs = {}
                        # Only pass `proxy_factory_fn` when this torch version's `create_proxy` accepts it.
                        if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
                            kwargs["proxy_factory_fn"] = (
                                None
                                if not self.param_shapes_constant
                                else lambda node: ParameterProxy(self, node, n, attr_val)
                            )
                        val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs)
                        parameter_proxy_cache[n] = val_proxy
                    return parameter_proxy_cache[n]
            return None

        if isinstance(attr_val, torch.nn.Parameter):
            maybe_parameter_proxy = maybe_get_proxy_for_attr(
                attr_val, self.root.named_parameters(), parameter_proxy_cache
            )
            if maybe_parameter_proxy is not None:
                return maybe_parameter_proxy

        if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
            maybe_buffer_proxy = maybe_get_proxy_for_attr(
                attr_val, self.root.named_buffers(), parameter_proxy_cache
            )
            if maybe_buffer_proxy is not None:
                return maybe_buffer_proxy

        return attr_val
def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
    """Delegate to `_module_getattr` with the same arguments."""
    return self._module_getattr(attr, attr_val, parameter_proxy_cache)
def call_module(self, m, forward, args, kwargs):
    """Store `forward` on `self.orig_forward`, then defer to `Tracer.call_module`."""
    self.orig_forward = forward
    return super().call_module(m, forward, args, kwargs)
def proxy(self, node):
    """Wrap `node` in an `HFProxy` bound to this tracer."""
    return HFProxy(node, self)
def trace(
self,
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
dummy_inputs: Optional[Dict[str, Any]] = None,
complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
):
"""
Trace method for tracing through the module hierarchy starting from `root`.
Args:
root (Union[torch.nn.Module, Callable[..., Any]]): The root module or callable to start tracing from.
concrete_args (Optional[Dict[str, Any]]): Concrete arguments for the traced function.
dummy_inputs (Optional[Dict[str, Any]]): Dummy inputs for the traced function.
complete_concrete_args_with_inputs_not_in_dummy_inputs (bool):
Flag indicating whether to complete concrete arguments with inputs not in dummy inputs.
Returns:
None
"""
def _stateless_mod_instanciation_depends_on_proxies(self, mod: nn.Module) -> bool:
"""
Check if the module's instantiation depends on Proxies.
Args:
mod (nn.Module): The module to check.
Returns:
bool: True if the module was instantiated with Proxies, otherwise False.
"""
return any(isinstance(attr, Proxy) for attr in mod.__dict__.values())
def _insert_module_as_submodule(self, mod: nn.Module) -> str:
"""
Try to insert a module that was not declared as a submodule.
Args:
mod (nn.Module): The module to insert.
Returns:
str: Path where the module was inserted as a submodule, or an empty string if insertion failed.
"""
if self._stateless_mod_instanciation_depends_on_proxies(mod):
return ""
idx = 0
mod_name = mod.__class__.__name__.lower()
path = f"{mod_name}_{idx}"
already_inserted = False
while hasattr(self.root, path):
if getattr(self.root, path) is mod:
already_inserted = True
break
path = f"{mod_name}_{idx}"
idx += 1
if not already_inserted:
self.root.add_module(path, mod)
return path
def path_of_module(self, mod: nn.Module) -> str:
    """
    Find the qualified name of `mod` in the Module hierarchy of `root`.

    If the lookup in the base `Tracer` raises `NameError`, and `mod` has neither parameters nor
    buffers while `allow_insert_stateless_mods` is set, the module is inserted on the fly through
    `_insert_module_as_submodule`.

    Args:
        mod (nn.Module): The module to retrieve the qualified name for.

    Returns:
        str: Qualified path of the module in the Module hierarchy of `root`. This is an empty string
        when the on-the-fly insertion was refused.

    Raises:
        NameError: Re-raised from the base lookup when the module cannot be inserted.
    """
    try:
        return super().path_of_module(mod)
    except NameError as e:
        # Only modules without parameters and buffers are eligible for on-the-fly insertion.
        if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
            path = self._insert_module_as_submodule(mod)
            return path
        raise e
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
    """
    Check if a module is a leaf module in the module hierarchy.

    A module holding Proxy attributes is never a leaf; otherwise the base `Tracer` decides.

    Args:
        m (torch.nn.Module): The module to check.
        module_qualified_name (str): Qualified name of the module in the hierarchy.

    Returns:
        bool: True if the module is a leaf module, otherwise False.
    """
    return (not self._stateless_mod_instanciation_depends_on_proxies(m)) and super().is_leaf_module(
        m, module_qualified_name
    )
@compatibility(is_backward_compatible=True)
def keys(self, obj: "Proxy") -> Any:
    """Called when a proxy object has the keys() method called.

    This is what happens when ** is called on a proxy. This should return an iterator if ** is supposed to work in
    your custom tracer.
    """
    # Record the `keys()` call on the proxy.
    attribute = HFAttribute(obj, "keys")()
    # For the node whose target is "**kwargs", return the metadata of the recorded call instead of a proxy.
    if obj.node.target == "**kwargs":
        return attribute._metadata
    return attribute
sig = inspect.signature(model.forward)
if not (set(input_names) <= set(sig.parameters.keys())):
formatted_input_names = input_names[0] if len(input_names) == 1 else ", ".join(input_names)
formatted_allowed_input_names = ", ".join(sig.parameters.keys())
raise ValueError(
f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:"
f" {formatted_allowed_input_names}"
)
return {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
.\utils\generic.py
import inspect
import tempfile
from collections import OrderedDict, UserDict
from collections.abc import MutableMapping
from contextlib import ExitStack, contextmanager
from dataclasses import fields, is_dataclass
from enum import Enum
from functools import partial
from typing import Any, ContextManager, Iterable, List, Tuple
import numpy as np
from packaging import version
from .import_utils import (
get_torch_version,
is_flax_available,
is_mlx_available,
is_tf_available,
is_torch_available,
is_torch_fx_proxy,
)
if is_flax_available():
import jax.numpy as jnp
class cached_property(property):
    """
    Descriptor that mimics @property but caches output in member variable.

    The value is stored on the instance under `"__cached_" + <getter name>`. A getter result of `None`
    is indistinguishable from "not cached yet", so it is recomputed on every access.

    From tensorflow_datasets

    Built-in in functools from Python 3.8.
    """

    def __get__(self, obj, objtype=None):
        # Access through the class returns the descriptor itself.
        if obj is None:
            return self
        getter = self.fget
        if getter is None:
            raise AttributeError("unreadable attribute")
        cache_name = "__cached_" + getter.__name__
        value = getattr(obj, cache_name, None)
        if value is None:
            value = getter(obj)
            setattr(obj, cache_name, value)
        return value
def strtobool(val):
    """Convert a string representation of truth to true (1) or false (0).

    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'.
    The comparison is case-insensitive. Raises ValueError if 'val' is anything else.
    """
    truthy = {"y", "yes", "t", "true", "on", "1"}
    falsy = {"n", "no", "f", "false", "off", "0"}
    lowered = val.lower()
    if lowered in truthy:
        return 1
    if lowered in falsy:
        return 0
    raise ValueError(f"invalid truth value {lowered!r}")
def infer_framework_from_repr(x):
    """
    Tries to guess the framework of an object `x` from its repr (brittle but will help in `is_tensor` to try the
    frameworks in a smart order, without the need to import the frameworks).

    Returns one of `"pt"`, `"tf"`, `"jax"`, `"np"`, `"mlx"`, or `None` when no prefix matches.
    """
    prefix_to_framework = (
        ("<class 'torch.", "pt"),
        ("<class 'tensorflow.", "tf"),
        ("<class 'jax", "jax"),
        ("<class 'numpy.", "np"),
        ("<class 'mlx.", "mlx"),
    )
    type_repr = str(type(x))
    for prefix, framework in prefix_to_framework:
        if type_repr.startswith(prefix):
            return framework
    return None
def _get_frameworks_and_test_func(x):
    """
    Returns an (ordered since we are in Python 3.7+) dictionary framework to test function, which places the framework
    we can guess from the repr first, then Numpy, then the others.
    """
    framework_to_test = {
        "pt": is_torch_tensor,
        "tf": is_tf_tensor,
        "jax": is_jax_tensor,
        "np": is_numpy_array,
        "mlx": is_mlx_array,
    }
    guessed = infer_framework_from_repr(x)

    ordered = []
    if guessed is not None:
        ordered.append(guessed)
    if guessed != "np":
        ordered.append("np")
    for name in framework_to_test:
        if name != guessed and name != "np":
            ordered.append(name)
    return {name: framework_to_test[name] for name in ordered}
def is_tensor(x):
    """
    Tests if `x` is recognised by one of the framework checks from `_get_frameworks_and_test_func`,
    is a torch.fx proxy, or (when Flax is available) is a `jax.core.Tracer`.
    """
    framework_checks = _get_frameworks_and_test_func(x).values()
    if any(check(x) for check in framework_checks):
        return True

    # Tracers
    if is_torch_fx_proxy(x):
        return True

    if is_flax_available():
        from jax.core import Tracer

        if isinstance(x, Tracer):
            return True

    return False
def is_numpy_array(x):
    """Return the result of `_is_numpy(x)`."""
    return _is_numpy(x)
def is_torch_tensor(x):
    """Return `False` when torch is unavailable, otherwise the result of `_is_torch(x)`."""
    return False if not is_torch_available() else _is_torch(x)
def is_torch_device(x):
    """Return `False` when torch is unavailable, otherwise the result of `_is_torch_device(x)`."""
    return False if not is_torch_available() else _is_torch_device(x)
def is_torch_dtype(x):
    """Return `False` when torch is unavailable, otherwise the result of `_is_torch_dtype(x)`."""
    return False if not is_torch_available() else _is_torch_dtype(x)
def is_tf_tensor(x):
    """Return `False` when TensorFlow is unavailable, otherwise the result of `_is_tensorflow(x)`."""
    return False if not is_tf_available() else _is_tensorflow(x)


def _is_tf_symbolic_tensor(x):
    """Tell whether `x` is a symbolic TensorFlow tensor.

    These statements used to sit, unreachable, after the `return` of `is_tf_tensor`, where `tf` was
    not bound. They are the helper that `is_tf_symbolic_tensor` calls.
    """
    # Imported lazily so that the module can be loaded without TensorFlow installed.
    import tensorflow as tf

    # `tf.is_symbolic_tensor` is probed for because it is not present in every TensorFlow build.
    if hasattr(tf, "is_symbolic_tensor"):
        return tf.is_symbolic_tensor(x)
    # Exact type comparison kept as in the original (NOTE(review): presumably deliberate, to exclude
    # subclasses of tf.Tensor — confirm).
    return type(x) == tf.Tensor
def is_tf_symbolic_tensor(x):
    """Return `False` when TensorFlow is unavailable, otherwise the result of `_is_tf_symbolic_tensor(x)`."""
    return False if not is_tf_available() else _is_tf_symbolic_tensor(x)
def _is_jax(x):
    """Tell whether `x` is a `jax.numpy.ndarray`; jax is imported lazily inside the function."""
    import jax.numpy as jnp  # noqa: F811

    return isinstance(x, jnp.ndarray)
def is_jax_tensor(x):
    """Return `False` when Flax is unavailable, otherwise the result of `_is_jax(x)`."""
    return False if not is_flax_available() else _is_jax(x)
def _is_mlx(x):
    """Tell whether `x` is an `mlx.core.array`; mlx is imported lazily inside the function."""
    import mlx.core as mx

    return isinstance(x, mx.array)
def is_mlx_array(x):
    """Return `False` when MLX is unavailable, otherwise the result of `_is_mlx(x)`."""
    return False if not is_mlx_available() else _is_mlx(x)
def to_py_obj(obj):
    """
    Convert a TensorFlow tensor, PyTorch tensor, Numpy array, MLX array or python list to a python list.

    Dicts (and `UserDict`s) are converted value by value, lists and tuples element by element (always
    yielding a list). Objects recognised by none of the framework checks are returned unchanged,
    except numpy scalars, which are converted with `tolist()`.
    """
    framework_to_py_obj = {
        "pt": lambda obj: obj.detach().cpu().tolist(),
        "tf": lambda obj: obj.numpy().tolist(),
        "jax": lambda obj: np.asarray(obj).tolist(),
        "np": lambda obj: obj.tolist(),
        # `_get_frameworks_and_test_func` can also report "mlx"; without this entry the lookup below
        # raised KeyError for MLX arrays.
        "mlx": lambda obj: obj.tolist(),
    }

    if isinstance(obj, (dict, UserDict)):
        return {k: to_py_obj(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return [to_py_obj(o) for o in obj]

    # This gives us a smart order to test the frameworks with the corresponding tests.
    framework_to_test_func = _get_frameworks_and_test_func(obj)
    for framework, test_func in framework_to_test_func.items():
        if test_func(obj):
            return framework_to_py_obj[framework](obj)

    # tolist also works on 0d np arrays
    if isinstance(obj, np.number):
        return obj.tolist()
    else:
        return obj
def to_numpy(obj):
    """
    Convert a TensorFlow tensor, PyTorch tensor, Numpy array, MLX array or python list to a Numpy array.

    Dicts (and `UserDict`s) are converted value by value; lists and tuples are passed to `np.array`.
    Objects recognised by none of the framework checks are returned unchanged.
    """
    framework_to_numpy = {
        "pt": lambda obj: obj.detach().cpu().numpy(),
        "tf": lambda obj: obj.numpy(),
        "jax": lambda obj: np.asarray(obj),
        "np": lambda obj: obj,
        # `_get_frameworks_and_test_func` can also report "mlx"; without this entry the lookup below
        # raised KeyError for MLX arrays.
        "mlx": lambda obj: np.asarray(obj),
    }

    if isinstance(obj, (dict, UserDict)):
        return {k: to_numpy(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return np.array(obj)

    # This gives us a smart order to test the frameworks with the corresponding tests.
    framework_to_test_func = _get_frameworks_and_test_func(obj)
    for framework, test_func in framework_to_test_func.items():
        if test_func(obj):
            return framework_to_numpy[framework](obj)

    return obj
class ModelOutput(OrderedDict):
    """
    Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
    tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
    python dictionary.
    """
# Register subclasses as pytree nodes.
def __init_subclass__(cls) -> None:
    """Register subclasses as pytree nodes.

    This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with
    `static_graph=True` with modules that output `ModelOutput` subclasses.
    """
    # Nothing is registered when PyTorch is unavailable.
    if is_torch_available():
        if version.parse(get_torch_version()) >= version.parse("2.2"):
            # torch >= 2.2: public registration function, which also takes a serialized type name.
            _torch_pytree.register_pytree_node(
                cls,
                _model_output_flatten,
                partial(_model_output_unflatten, output_type=cls),
                serialized_type_name=f"{cls.__module__}.{cls.__name__}",
            )
        else:
            # Older torch versions: use the private registration function.
            _torch_pytree._register_pytree_node(
                cls,
                _model_output_flatten,
                partial(_model_output_unflatten, output_type=cls),
            )
# Initializer: checks that subclasses of ModelOutput were decorated with @dataclass.
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)

    # Subclasses of ModelOutput must use the @dataclass decorator.
    # This check is done in __init__ because the @dataclass decorator operates after __init_subclass__.
    # True when the instance's class is not ModelOutput itself, i.e. it is a subclass.
    is_modeloutput_subclass = self.__class__ != ModelOutput

    # A subclass that is not a dataclass is rejected with a TypeError.
    if is_modeloutput_subclass and not is_dataclass(self):
        raise TypeError(
            f"{self.__module__}.{self.__class__.__name__} is not a dataclasss."
            " This is a subclass of ModelOutput and so must use the @dataclass decorator."
        )
def __post_init__(self):
"""初始化后检查ModelOutput数据类。
仅在使用@dataclass装饰器时发生。
"""
# 获取数据类的所有字段
class_fields = fields(self)
# 安全性和一致性检查
if not len(class_fields):
# 如果没有字段,则引发值错误异常
raise ValueError(f"{self.__class__.__name__} has no fields.")
if not all(field.default is None for field in class_fields[1:]):
# 如果有超过一个必需字段,则引发值错误异常
raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")
# 获取第一个字段的值
first_field = getattr(self, class_fields[0].name)
# 检查其它字段是否都为None
other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
if other_fields_are_none and not is_tensor(first_field):
if isinstance(first_field, dict):
# 如果第一个字段是字典,则遍历字典项
iterator = first_field.items()
first_field_iterator = True
else:
try:
# 尝试迭代第一个字段
iterator = iter(first_field)
first_field_iterator = True
except TypeError:
first_field_iterator = False
# 如果第一个字段是迭代器且是(key, value)形式的迭代器
if first_field_iterator:
for idx, element in enumerate(iterator):
if (
not isinstance(element, (list, tuple))
or not len(element) == 2
or not isinstance(element[0], str)
):
if idx == 0:
# 如果不是(key, value)形式的迭代器,将其设置为属性
self[class_fields[0].name] = first_field
else:
# 如果是混合迭代器,引发值错误异常
raise ValueError(
f"Cannot set key/value for {element}. It needs to be a tuple (key, value)."
)
break
# 设置属性为(key, value)对
setattr(self, element[0], element[1])
if element[1] is not None:
self[element[0]] = element[1]
elif first_field is not None:
# 如果第一个字段不为空,则将其设置为属性
self[class_fields[0].name] = first_field
else:
# 如果存在非None的字段,则将其设置为属性
for field in class_fields:
v = getattr(self, field.name)
if v is not None:
self[field.name] = v
def __delitem__(self, *args, **kwargs):
"""阻止对ModelOutput实例使用``__delitem__``方法。"""
# 抛出异常,不允许使用``__delitem__``方法
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
def setdefault(self, *args, **kwargs):
"""阻止对ModelOutput实例使用``setdefault``方法。"""
# 抛出异常,不允许使用``setdefault``方法
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
def pop(self, *args, **kwargs):
"""阻止对ModelOutput实例使用``pop``方法。"""
# 抛出异常,不允许使用``pop``方法
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
def update(self, *args, **kwargs):
# 抛出异常,阻止在该类实例上使用 `update` 方法
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
def __getitem__(self, k):
if isinstance(k, str):
# 将内部数据转换为字典,然后返回键 `k` 对应的值
inner_dict = dict(self.items())
return inner_dict[k]
else:
# 调用 `to_tuple()` 方法返回的元组,并使用 `k` 作为索引获取元组中的值
return self.to_tuple()[k]
def __setattr__(self, name, value):
if name in self.keys() and value is not None:
# 避免递归错误,不调用 `self.__setitem__` 方法
super().__setitem__(name, value)
# 设置对象的属性 `name` 为 `value`
super().__setattr__(name, value)
def __setitem__(self, key, value):
# 调用父类的 `__setitem__` 方法设置键 `key` 对应的值 `value`
super().__setitem__(key, value)
# 避免递归错误,不调用 `self.__setattr__` 方法
super().__setattr__(key, value)
def __reduce__(self):
if not is_dataclass(self):
# 如果对象不是数据类,则调用父类的 `__reduce__` 方法
return super().__reduce__()
# 否则,获取对象所有非 `None` 属性或键的元组,并返回
callable, _args, *remaining = super().__reduce__()
args = tuple(getattr(self, field.name) for field in fields(self))
return callable, args, *remaining
def to_tuple(self) -> Tuple[Any]:
"""
Convert self to a tuple containing all the attributes/keys that are not `None`.
"""
# 返回包含所有非 `None` 属性或键的元组
return tuple(self[k] for k in self.keys())
# 检查是否安装了 Torch
if is_torch_available():
    # Private torch module that provides the pytree registration functions used below.
    import torch.utils._pytree as _torch_pytree

    def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]:
        """Split a `ModelOutput` into its list of values and, as context, its list of keys."""
        leaves = list(output.values())
        key_order = list(output.keys())
        return leaves, key_order

    def _model_output_unflatten(
        values: Iterable[Any],
        context: "_torch_pytree.Context",
        output_type=None,
    ) -> ModelOutput:
        """Rebuild an `output_type` instance by pairing the context keys with `values` as keyword arguments."""
        restored_fields = dict(zip(context, values))
        return output_type(**restored_fields)

    # Register the base class itself; subclasses are registered from `ModelOutput.__init_subclass__`.
    if version.parse(get_torch_version()) < version.parse("2.2"):
        # Before torch 2.2 only the private registration function is used.
        _torch_pytree._register_pytree_node(
            ModelOutput,
            _model_output_flatten,
            partial(_model_output_unflatten, output_type=ModelOutput),
        )
    else:
        _torch_pytree.register_pytree_node(
            ModelOutput,
            _model_output_flatten,
            partial(_model_output_unflatten, output_type=ModelOutput),
            serialized_type_name=f"{ModelOutput.__module__}.{ModelOutput.__name__}",
        )
# 定义一个显式枚举类 ExplicitEnum,继承自 str 和 Enum
class ExplicitEnum(str, Enum):
    """
    String-valued enum whose lookup failure lists every accepted value in the error message.
    """

    @classmethod
    def _missing_(cls, value):
        """Called by `Enum` when `value` matches no member; always raises a `ValueError`."""
        accepted_values = list(cls._value2member_map_.keys())
        raise ValueError(f"{value} is not a valid {cls.__name__}, please select one of {accepted_values}")
# 定义一个填充策略枚举类 PaddingStrategy,继承自 ExplicitEnum
class PaddingStrategy(ExplicitEnum):
    """
    Possible values for the `padding` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for tab-completion in an
    IDE.
    """

    # Members are plain strings (`ExplicitEnum` mixes in `str`), so each compares equal to its literal value.
    LONGEST = "longest"
    MAX_LENGTH = "max_length"
    DO_NOT_PAD = "do_not_pad"
# 定义一个张量类型枚举类 TensorType,继承自 ExplicitEnum
class TensorType(ExplicitEnum):
    """
    Possible values for the `return_tensors` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for
    tab-completion in an IDE.
    """

    # Members are plain strings (`ExplicitEnum` mixes in `str`), so each compares equal to its literal value.
    PYTORCH = "pt"
    TENSORFLOW = "tf"
    NUMPY = "np"
    JAX = "jax"
    MLX = "mlx"
# 定义一个上下文管理器类 ContextManagers
class ContextManagers:
    """
    Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
    in the `fastcore` library.
    """

    def __init__(self, context_managers: List[ContextManager]):
        """Store the context managers to enter and create the `ExitStack` that will track them."""
        self.context_managers = context_managers
        self.stack = ExitStack()

    def __enter__(self):
        """Enter every context manager in order; if one fails, exit those already entered before re-raising."""
        try:
            for context_manager in self.context_managers:
                self.stack.enter_context(context_manager)
        except BaseException:
            # `__exit__` is not called when `__enter__` raises, so unwind here to avoid leaking
            # the managers that were entered successfully.
            self.stack.close()
            raise

    def __exit__(self, *args, **kwargs):
        """Exit the managers in reverse order and propagate whether one of them suppressed the exception."""
        return self.stack.__exit__(*args, **kwargs)
# 定义一个函数,检查给定的模型类是否能返回损失值
def can_return_loss(model_class):
    """
    Check if a given model can return loss.

    Args:
        model_class (`type`): The class of the model.

    Returns:
        `bool`: `True` only when the model's call signature has a `return_loss` parameter whose default is `True`.
    """
    framework = infer_framework(model_class)
    # Pick the method whose signature carries the model inputs for the inferred framework.
    if framework == "tf":
        entry_point = model_class.call
    elif framework == "pt":
        entry_point = model_class.forward
    else:
        entry_point = model_class.__call__

    return_loss_param = inspect.signature(entry_point).parameters.get("return_loss")
    return return_loss_param is not None and return_loss_param.default is True
# 查找给定模型使用的标签参数列表
def find_labels(model_class):
    """
    Find the label arguments used by a given model.

    Args:
        model_class (`type`): The class of the model.

    Returns:
        `List[str]`: Names of the signature parameters containing `"label"`; for classes whose name contains
        `"QuestionAnswering"`, `start_positions` and `end_positions` are included as well.
    """
    model_name = model_class.__name__
    framework = infer_framework(model_class)
    # Pick the method whose signature carries the model inputs for the inferred framework.
    if framework == "tf":
        entry_point = model_class.call
    elif framework == "pt":
        entry_point = model_class.forward
    else:
        entry_point = model_class.__call__
    parameter_names = inspect.signature(entry_point).parameters

    is_question_answering = "QuestionAnswering" in model_name
    span_arguments = ("start_positions", "end_positions")
    return [
        p for p in parameter_names if "label" in p or (is_question_answering and p in span_arguments)
    ]
# 将嵌套字典展开为单层字典
def flatten_dict(d: MutableMapping, parent_key: str = "", delimiter: str = "."):
    """
    Flatten a nested dict into a single level dict.

    Nested keys are joined with `delimiter` (and prefixed by `parent_key` when it is truthy). Empty nested
    mappings are kept as values rather than expanded.
    """

    def _walk(mapping, prefix):
        # Depth-first traversal yielding (flattened_key, value) pairs in insertion order.
        for k, v in mapping.items():
            key = str(prefix) + delimiter + str(k) if prefix else k
            if v and isinstance(v, MutableMapping):
                yield from _walk(v, key)
            else:
                yield key, v

    return dict(_walk(d, parent_key))
# 提供工作目录或临时目录的上下文管理器
@contextmanager
def working_or_temp_dir(working_dir, use_temp_dir: bool = False):
    """
    Context manager yielding `working_dir`, or a fresh temporary directory when `use_temp_dir` is `True`.

    The temporary directory is created with `tempfile.TemporaryDirectory` and removed when the context exits.
    """
    if not use_temp_dir:
        yield working_dir
        return

    with tempfile.TemporaryDirectory() as tmp_dir:
        yield tmp_dir
# 框架无关的数组转置函数,支持numpy、torch、tensorflow和jax的数组
def transpose(array, axes=None):
    """
    Framework-agnostic version of `numpy.transpose` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.

    Raises:
        `ValueError`: If `array` is none of the supported array/tensor types.
    """
    if is_numpy_array(array):
        return np.transpose(array, axes=axes)
    if is_torch_tensor(array):
        # torch has no `axes` keyword: use `.T` for the default case and `permute` otherwise.
        if axes is None:
            return array.T
        return array.permute(*axes)
    if is_tf_tensor(array):
        import tensorflow as tf

        return tf.transpose(array, perm=axes)
    if is_jax_tensor(array):
        return jnp.transpose(array, axes=axes)
    raise ValueError(f"Type not supported for transpose: {type(array)}.")
# 框架无关的数组重塑函数,支持numpy、torch、tensorflow和jax的数组
def reshape(array, newshape):
    """
    Framework-agnostic version of `numpy.reshape` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.

    Raises:
        `ValueError`: If `array` is none of the supported array/tensor types.
    """
    if is_numpy_array(array):
        return np.reshape(array, newshape)
    if is_torch_tensor(array):
        # `newshape` is unpacked into positional arguments for the torch call.
        return array.reshape(*newshape)
    if is_tf_tensor(array):
        import tensorflow as tf

        return tf.reshape(array, newshape)
    if is_jax_tensor(array):
        return jnp.reshape(array, newshape)
    raise ValueError(f"Type not supported for reshape: {type(array)}.")
# 框架无关的数组挤压函数,支持numpy、torch、tensorflow和jax的数组
def squeeze(array, axis=None):
    """
    Framework-agnostic version of `numpy.squeeze` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.

    Raises:
        `ValueError`: If `array` is none of the supported array/tensor types.
    """
    if is_numpy_array(array):
        return np.squeeze(array, axis=axis)
    if is_torch_tensor(array):
        # The torch call is made without arguments when no axis is given, and with `dim=axis` otherwise.
        if axis is None:
            return array.squeeze()
        return array.squeeze(dim=axis)
    if is_tf_tensor(array):
        import tensorflow as tf

        return tf.squeeze(array, axis=axis)
    if is_jax_tensor(array):
        return jnp.squeeze(array, axis=axis)
    raise ValueError(f"Type not supported for squeeze: {type(array)}.")
# 定义一个函数,用于在不同深度学习框架下扩展张量的维度
def expand_dims(array, axis):
    """
    Framework-agnostic version of `numpy.expand_dims` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.

    Raises:
        `ValueError`: If `array` is none of the supported array/tensor types.
    """
    if is_numpy_array(array):
        return np.expand_dims(array, axis)
    if is_torch_tensor(array):
        # torch spells this operation `unsqueeze`.
        return array.unsqueeze(dim=axis)
    if is_tf_tensor(array):
        import tensorflow as tf

        return tf.expand_dims(array, axis=axis)
    if is_jax_tensor(array):
        return jnp.expand_dims(array, axis=axis)
    raise ValueError(f"Type not supported for expand_dims: {type(array)}.")
# 定义一个函数,用于计算不同深度学习框架下张量的大小
def tensor_size(array):
    """
    Framework-agnostic version of `numpy.size` that will work on torch/TensorFlow/Jax tensors as well as NumPy arrays.

    Raises:
        `ValueError`: If `array` is none of the supported array/tensor types.
    """
    if is_numpy_array(array):
        return np.size(array)
    if is_torch_tensor(array):
        return array.numel()
    if is_tf_tensor(array):
        import tensorflow as tf

        return tf.size(array)
    if is_jax_tensor(array):
        return array.size
    raise ValueError(f"Type not supported for tensor_size: {type(array)}.")
# 定义一个函数,将 repo_id 的信息添加到给定的自动映射 auto_map 中
def add_model_info_to_auto_map(auto_map, repo_id):
    """
    Adds the information of the repo_id to a given auto map.

    Every string entry that does not already contain `"--"` is rewritten in place as `"{repo_id}--{entry}"`;
    `None` entries are left untouched. Tuple/list values are replaced by lists. Returns the same `auto_map` object.
    """

    def _qualify(entry):
        # Entries that are None or already carry a "--" prefix are returned unchanged.
        if entry is None or "--" in entry:
            return entry
        return f"{repo_id}--{entry}"

    for key, value in auto_map.items():
        if isinstance(value, (tuple, list)):
            auto_map[key] = [_qualify(item) for item in value]
            continue
        qualified = _qualify(value)
        if qualified is not value:
            auto_map[key] = qualified

    return auto_map
# 定义一个函数,推断给定模型类的深度学习框架
def infer_framework(model_class):
    """
    Infers the framework of a given model without using isinstance(), because we cannot guarantee that the relevant
    classes are imported or available.

    Returns:
        `str`: `"tf"`, `"pt"` or `"flax"`, decided by the first class in the MRO that matches.

    Raises:
        `TypeError`: If no class in the MRO matches any framework.
    """
    for base_class in inspect.getmro(model_class):
        module, name = base_class.__module__, base_class.__name__
        if module.startswith(("tensorflow", "keras")) or name == "TFPreTrainedModel":
            return "tf"
        if module.startswith("torch") or name == "PreTrainedModel":
            return "pt"
        if module.startswith(("flax", "jax")) or name == "FlaxPreTrainedModel":
            return "flax"
    # The whole MRO was scanned without a match.
    raise TypeError(f"Could not infer framework from class {model_class}.")
.\utils\hp_naming.py
import copy
import re
class TrialShortNamer:
    """
    Builds short run names from hyper-parameter dicts and parses them back.

    Each parameter name in `DEFAULTS` gets a short alias made from unique prefixes of its `_`-separated words.
    `shortname` encodes only the parameters that differ from their defaults; `parse_repr` decodes such a name.
    """

    PREFIX = "hp"
    DEFAULTS = {}
    NAMING_INFO = None

    @classmethod
    def set_defaults(cls, prefix, defaults):
        """Set the name prefix and default values, then (re)build the naming tables for them."""
        cls.PREFIX = prefix
        cls.DEFAULTS = defaults
        # Drop any previously built tables so they always match the new defaults.
        cls.NAMING_INFO = None
        cls.build_naming_info()

    @staticmethod
    def shortname_for_word(info, word):
        """
        Return the short alias for `word`, allocating and recording a new one in `info` if needed.

        Raises:
            `Exception`: If `word` contains a digit.
        """
        if len(word) == 0:
            return ""
        short_word = None
        if any(char.isdigit() for char in word):
            raise Exception(f"Parameters should not contain numbers: '{word}' contains a number")
        if word in info["short_word"]:
            return info["short_word"][word]
        # Use the shortest prefix of the word that is not already taken.
        for prefix_len in range(1, len(word) + 1):
            prefix = word[:prefix_len]
            if prefix in info["reverse_short_word"]:
                continue
            else:
                short_word = prefix
                break

        if short_word is None:
            # Every prefix is taken: fall back to "<word>#<suffix>" with an alphabetic counter.
            def int_to_alphabetic(integer):
                s = ""
                while integer != 0:
                    s = chr(ord("A") + integer % 10) + s
                    integer //= 10
                return s

            i = 0
            while True:
                sword = word + "#" + int_to_alphabetic(i)
                if sword in info["reverse_short_word"]:
                    # Advance the counter, otherwise the same candidate is retried forever.
                    i += 1
                    continue
                else:
                    short_word = sword
                    break

        info["short_word"][word] = short_word
        info["reverse_short_word"][short_word] = word
        return short_word

    @staticmethod
    def shortname_for_key(info, param_name):
        """
        Return a short alias for `param_name` built from the aliases of its `_`-separated words.

        The word aliases are joined without a separator first, then with `_`; the first unused result is
        recorded in `info` and returned. If both are taken, `param_name` itself is returned.
        """
        words = param_name.split("_")

        shortname_parts = [TrialShortNamer.shortname_for_word(info, word) for word in words]

        separators = ["", "_"]

        for separator in separators:
            shortname = separator.join(shortname_parts)
            if shortname not in info["reverse_short_param"]:
                info["short_param"][param_name] = shortname
                info["reverse_short_param"][shortname] = param_name
                return shortname

        return param_name

    @staticmethod
    def add_new_param_name(info, param_name):
        """Allocate an alias for `param_name` and record it in both lookup directions."""
        short_name = TrialShortNamer.shortname_for_key(info, param_name)
        info["short_param"][param_name] = short_name
        info["reverse_short_param"][short_name] = param_name

    @classmethod
    def build_naming_info(cls):
        """Build the alias tables for the keys of `DEFAULTS`; does nothing if they already exist."""
        if cls.NAMING_INFO is not None:
            return

        info = {
            "short_word": {},
            "reverse_short_word": {},
            "short_param": {},
            "reverse_short_param": {},
        }

        field_keys = list(cls.DEFAULTS.keys())

        for k in field_keys:
            cls.add_new_param_name(info, k)

        cls.NAMING_INFO = info

    @classmethod
    def shortname(cls, params):
        """
        Encode `params` as `"<PREFIX>_<alias><value>_..."`, skipping parameters equal to their default.

        Numeric values (booleans are written as 0/1) follow the alias directly; other values are separated
        from it by `-`.

        Raises:
            `Exception`: If a key of `params` has no entry in `DEFAULTS`.
        """
        cls.build_naming_info()
        assert cls.PREFIX is not None
        name = [copy.copy(cls.PREFIX)]

        for k, v in params.items():
            if k not in cls.DEFAULTS:
                raise Exception(f"You should provide a default value for the param name {k} with value {v}")
            if v == cls.DEFAULTS[k]:
                # The default value is not added to the name
                continue

            key = cls.NAMING_INFO["short_param"][k]

            if isinstance(v, bool):
                v = 1 if v else 0

            sep = "" if isinstance(v, (int, float)) else "-"
            e = f"{key}{sep}{v}"
            name.append(e)

        return "_".join(name)

    @classmethod
    def parse_repr(cls, repr):
        """
        Decode a name produced by `shortname` back into a full parameter dict.

        Entries containing `-` are returned as strings; other entries are parsed as floats. Parameters that
        do not appear in the name take their value from `DEFAULTS`.
        """
        # Make sure the alias tables exist even if `shortname` was never called on this class.
        cls.build_naming_info()
        repr = repr[len(cls.PREFIX) + 1 :]
        if repr == "":
            values = []
        else:
            values = repr.split("_")

        parameters = {}

        for value in values:
            if "-" in value:
                # Split on the first "-" only, so values that themselves contain "-" survive.
                p_k, p_v = value.split("-", 1)
            else:
                p_k = re.sub("[0-9.]", "", value)
                p_v = float(re.sub("[^0-9.]", "", value))

            key = cls.NAMING_INFO["reverse_short_param"][p_k]

            parameters[key] = p_v

        for k in cls.DEFAULTS:
            if k not in parameters:
                parameters[k] = cls.DEFAULTS[k]

        return parameters
.\utils\hub.py
import json
import os
import re
import shutil
import sys
import tempfile
import traceback
import warnings
from concurrent import futures
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse
from uuid import uuid4
import huggingface_hub
import requests
from huggingface_hub import (
_CACHED_NO_EXIST,
CommitOperationAdd,
ModelCard,
ModelCardData,
constants,
create_branch,
create_commit,
create_repo,
get_hf_file_metadata,
hf_hub_download,
hf_hub_url,
try_to_load_from_cache,
)
from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get
from huggingface_hub.utils import (
EntryNotFoundError,
GatedRepoError,
HFValidationError,
LocalEntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
build_hf_headers,
hf_raise_for_status,
send_telemetry,
)
from huggingface_hub.utils._deprecation import _deprecate_method
from requests.exceptions import HTTPError
from . import __version__, logging
from .generic import working_or_temp_dir
from .import_utils import (
ENV_VARS_TRUE_VALUES,
_tf_version,
_torch_version,
is_tf_available,
is_torch_available,
is_training_run_on_sagemaker,
)
from .logging import tqdm
logger = logging.get_logger(__name__)
# Evaluated once at import time from the `TRANSFORMERS_OFFLINE` environment variable.
_is_offline_mode = os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES


def is_offline_mode():
    """Return whether offline mode was enabled through `TRANSFORMERS_OFFLINE` when this module was imported."""
    return _is_offline_mode
# Legacy cache location: `$TORCH_HOME`, else `$XDG_CACHE_HOME/torch`, else `~/.cache/torch`.
torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
default_cache_path = constants.default_cache_path
old_default_cache_path = os.path.join(torch_cache_home, "transformers")

# Each cache variable falls back to the previous one, ending at the huggingface_hub cache constant.
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", constants.HF_HUB_CACHE)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)

# One-time migration at import: move the legacy cache directory to the new location, but only when the
# legacy directory exists, the new one does not, and none of the cache environment variables is set.
if (
    os.path.isdir(old_default_cache_path)
    and not os.path.isdir(constants.HF_HUB_CACHE)
    and "PYTORCH_PRETRAINED_BERT_CACHE" not in os.environ
    and "PYTORCH_TRANSFORMERS_CACHE" not in os.environ
    and "TRANSFORMERS_CACHE" not in os.environ
):
    logger.warning(
        "In Transformers v4.22.0, the default path to cache downloaded models changed from"
        " '~/.cache/torch/transformers' to '~/.cache/huggingface/hub'. Since you don't seem to have"
        " overridden and '~/.cache/torch/transformers' is a directory that exists, we're moving it to"
        " '~/.cache/huggingface/hub' to avoid redownloading models you have already in the cache. You should"
        " only see this message once."
    )
    shutil.move(old_default_cache_path, constants.HF_HUB_CACHE)

HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(constants.HF_HOME, "modules"))
TRANSFORMERS_DYNAMIC_MODULE_NAME = "transformers_modules"
# Random identifier generated once per import of this module.
SESSION_ID = uuid4().hex

# Warn about each deprecated cache environment variable that is set.
for key in ("PYTORCH_PRETRAINED_BERT_CACHE", "PYTORCH_TRANSFORMERS_CACHE", "TRANSFORMERS_CACHE"):
    if os.getenv(key) is not None:
        warnings.warn(
            f"Using `{key}` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.",
            FutureWarning,
        )

S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"

# Endpoint resolution order: staging/production default, then the deprecated
# `HUGGINGFACE_CO_RESOLVE_ENDPOINT` variable, then `HF_ENDPOINT` (which wins when set).
_staging_mode = os.environ.get("HUGGINGFACE_CO_STAGING", "NO").upper() in ENV_VARS_TRUE_VALUES
_default_endpoint = "https://hub-ci.huggingface.co" if _staging_mode else "https://huggingface.co"

HUGGINGFACE_CO_RESOLVE_ENDPOINT = _default_endpoint
if os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOINT", None) is not None:
    warnings.warn(
        "Using the environment variable `HUGGINGFACE_CO_RESOLVE_ENDPOINT` is deprecated and will be removed in "
        "Transformers v5. Use `HF_ENDPOINT` instead.",
        FutureWarning,
    )
    HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOINT", None)
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", HUGGINGFACE_CO_RESOLVE_ENDPOINT)
# URL templates derived from the resolved endpoint.
HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"
def _get_cache_file_to_return(
    path_or_repo_id: str, full_filename: str, cache_dir: Union[str, Path, None] = None, revision: Optional[str] = None
):
    """
    Return the path of a cached copy of `full_filename` for `path_or_repo_id`, or `None` when there is none.

    The lookup is delegated to `try_to_load_from_cache` as imported from `huggingface_hub`; both a missing
    entry (`None`) and the `_CACHED_NO_EXIST` marker are reported as `None`.
    """
    # Use the imported hub helper directly: redefining `try_to_load_from_cache` here would shadow it for
    # the rest of the module, including callers that pass `repo_type`.
    resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision)
    if resolved_file is not None and resolved_file != _CACHED_NO_EXIST:
        return resolved_file
    return None
def is_remote_url(url_or_filename):
    """Return `True` when `url_or_filename` parses as a URL whose scheme is `http` or `https`."""
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"
@_deprecate_method(version="4.39.0", message="This method is outdated and does not support the new cache system.")
def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
    """
    Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape `(model_url,
    etag, size_MB)`. Only URLs ending with *.bin* are added.

    Args:
        cache_dir (`Union[str, Path]`, *optional*):
            The cache directory to search for models within. Will default to the transformers cache if unset.

    Returns:
        List[Tuple]: List of tuples each with shape `(model_url, etag, size_MB)`
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    elif isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    if not os.path.isdir(cache_dir):
        return []

    metadata_suffix = ".json"
    cached_models = []
    for file in os.listdir(cache_dir):
        if file.endswith(metadata_suffix):
            meta_path = os.path.join(cache_dir, file)
            with open(meta_path, encoding="utf-8") as meta_file:
                metadata = json.load(meta_file)
                url = metadata["url"]
                etag = metadata["etag"]
                if url.endswith(".bin"):
                    # The cached binary sits next to its metadata file, without the ".json" suffix.
                    # Slice the suffix off: `str.strip(".json")` would strip any of those characters
                    # from both ends of the path and could point at the wrong file.
                    binary_path = meta_path[: -len(metadata_suffix)]
                    size_MB = os.path.getsize(binary_path) / 1e6
                    cached_models.append((url, etag, size_MB))

    return cached_models
def define_sagemaker_information():
    """
    Collect SageMaker-related information from the container metadata endpoint and environment variables.

    Returns:
        `dict`: Framework, region, GPU/CPU counts, whether distributed data-parallel training is enabled,
        the container image and its tag (`None` when the metadata lookup fails), and the account id taken
        from `TRAINING_JOB_ARN` (`None` when that variable is unset).
    """
    dlc_container_used = None
    dlc_tag = None
    try:
        # Best effort: any failure here simply leaves both container fields as None.
        instance_data = requests.get(os.environ["ECS_CONTAINER_METADATA_URI"]).json()
        image = instance_data["Image"]
        tag = instance_data["Image"].split(":")[1]
        dlc_container_used, dlc_tag = image, tag
    except Exception:
        dlc_container_used = None
        dlc_tag = None

    sagemaker_params = json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}"))
    runs_distributed_training = "sagemaker_distributed_dataparallel_enabled" in sagemaker_params

    account_id = None
    if "TRAINING_JOB_ARN" in os.environ:
        # The account id is the fifth ":"-separated field of the ARN.
        account_id = os.getenv("TRAINING_JOB_ARN").split(":")[4]

    return {
        "sm_framework": os.getenv("SM_FRAMEWORK_MODULE", None),
        "sm_region": os.getenv("AWS_REGION", None),
        "sm_number_gpu": os.getenv("SM_NUM_GPUS", 0),
        "sm_number_cpu": os.getenv("SM_NUM_CPUS", 0),
        "sm_distributed_training": runs_distributed_training,
        "sm_deep_learning_container": dlc_container_used,
        "sm_deep_learning_container_tag": dlc_tag,
        "sm_account_id": account_id,
    }
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """
    Formats a user-agent string with basic info about a request.

    Args:
        user_agent (`dict` or `str`, *optional*):
            Extra information appended to the string: `key/value` pairs for a dict, the raw text for a string.

    Returns:
        `str`: The user-agent string. When `constants.HF_HUB_DISABLE_TELEMETRY` is set, only the library,
        Python, session and framework versions are included, followed by `telemetry/off`.
    """
    ua = f"transformers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
    if is_torch_available():
        ua += f"; torch/{_torch_version}"
    if is_tf_available():
        ua += f"; tensorflow/{_tf_version}"
    if constants.HF_HUB_DISABLE_TELEMETRY:
        # Stop here so that no environment or caller-provided details are added.
        return ua + "; telemetry/off"
    if is_training_run_on_sagemaker():
        ua += "; " + "; ".join(f"{k}/{v}" for k, v in define_sagemaker_information().items())
    # CI will set this value to True
    if os.environ.get("TRANSFORMERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
        ua += "; is_ci/true"
    if isinstance(user_agent, dict):
        ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        ua += "; " + user_agent
    return ua
# 从已解析的文件名中提取提交哈希值,并用于缓存文件。
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]) -> Optional[str]:
    """
    Extracts the commit hash from a resolved filename toward a cache file.

    `commit_hash` is returned unchanged when it is already set or when `resolved_file` is `None`. Otherwise
    the path component following `snapshots/` is returned if it matches `REGEX_COMMIT_HASH`, else `None`.
    """
    if resolved_file is None or commit_hash is not None:
        return commit_hash

    # Normalize to forward slashes so the pattern below works regardless of the path separator.
    posix_path = str(Path(resolved_file).as_posix())
    snapshot_match = re.search(r"snapshots/([^/]+)/", posix_path)
    if snapshot_match is None:
        return None

    candidate = snapshot_match.group(1)
    if REGEX_COMMIT_HASH.match(candidate):
        return candidate
    return None
# 尝试在本地文件夹或存储库中定位文件,如果必要则下载并缓存它。
def cached_file(
    path_or_repo_id: Union[str, os.PathLike],
    filename: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
    repo_type: Optional[str] = None,
    user_agent: Optional[Union[str, Dict[str, str]]] = None,
    _raise_exceptions_for_gated_repo: bool = True,
    _raise_exceptions_for_missing_entries: bool = True,
    _raise_exceptions_for_connection_errors: bool = True,
    _commit_hash: Optional[str] = None,
    **deprecated_kwargs,
) -> Optional[str]:
    """
    Tries to locate a file in a local folder and repo, downloads and cache it if necessary.

    If `path_or_repo_id` is an existing directory, the file is looked up there and nothing is downloaded.
    Otherwise the file is fetched with `hf_hub_download`, and on gated-repo, missing-local-entry or HTTP
    errors a previously cached copy is returned when one exists.

    Returns:
        `Optional[str]`: Path to the resolved file, or `None` when the file could not be resolved and the
        matching `_raise_exceptions_for_*` flag is `False`.

    Raises:
        `ValueError`: If both `token` and the deprecated `use_auth_token` are given.
        `EnvironmentError`: If the file, repo or revision cannot be resolved and the matching
            `_raise_exceptions_for_*` flag is `True`.
    """
    use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    # Private arguments
    #     _raise_exceptions_for_gated_repo: if False, do not raise an exception for gated repo error but return
    #         None.
    #     _raise_exceptions_for_missing_entries: if False, do not raise an exception for missing entries but return
    #         None.
    #     _raise_exceptions_for_connection_errors: if False, do not raise an exception for connection errors but return
    #         None.
    #     _commit_hash: passed when we are chaining several calls to various files (e.g. when loading a tokenizer or
    #         a pipeline). If files are cached for this commit hash, avoid calls to head and get from the cache.
    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True
    if subfolder is None:
        subfolder = ""

    path_or_repo_id = str(path_or_repo_id)
    full_filename = os.path.join(subfolder, filename)
    if os.path.isdir(path_or_repo_id):
        # Local folder: resolve the file directly, never touching the Hub.
        resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename)
        if not os.path.isfile(resolved_file):
            if _raise_exceptions_for_missing_entries:
                raise EnvironmentError(
                    f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
                    f"'https://huggingface.co/{path_or_repo_id}/tree/{revision}' for available files."
                )
            else:
                return None
        return resolved_file

    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if _commit_hash is not None and not force_download:
        # If the file is cached under that commit hash, we return it directly.
        resolved_file = try_to_load_from_cache(
            path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
        )
        if resolved_file is not None:
            if resolved_file is not _CACHED_NO_EXIST:
                return resolved_file
            elif not _raise_exceptions_for_missing_entries:
                return None
            else:
                raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.")

    user_agent = http_user_agent(user_agent)
    try:
        # Load from URL or cache if already cached
        resolved_file = hf_hub_download(
            path_or_repo_id,
            filename,
            subfolder=None if len(subfolder) == 0 else subfolder,
            repo_type=repo_type,
            revision=revision,
            cache_dir=cache_dir,
            user_agent=user_agent,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            token=token,
            local_files_only=local_files_only,
        )
    except GatedRepoError as e:
        # Fall back to a cached copy if there is one.
        resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
        if resolved_file is not None or not _raise_exceptions_for_gated_repo:
            return resolved_file
        raise EnvironmentError(
            "You are trying to access a gated repo.\nMake sure to have access to it at "
            f"https://huggingface.co/{path_or_repo_id}.\n{str(e)}"
        ) from e
    except RepositoryNotFoundError as e:
        raise EnvironmentError(
            f"{path_or_repo_id} is not a local folder and is not a valid model identifier "
            "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token "
            "having permission to this repo either by logging in with `huggingface-cli login` or by passing "
            "`token=<your_token>`"
        ) from e
    except RevisionNotFoundError as e:
        raise EnvironmentError(
            f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
            "for this model name. Check the model page at "
            f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
        ) from e
    except LocalEntryNotFoundError as e:
        # Fall back to a cached copy if there is one.
        resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
        if (
            resolved_file is not None
            or not _raise_exceptions_for_missing_entries
            or not _raise_exceptions_for_connection_errors
        ):
            return resolved_file
        raise EnvironmentError(
            f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
            f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
            f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
            " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
        ) from e
    except EntryNotFoundError as e:
        if not _raise_exceptions_for_missing_entries:
            return None
        if revision is None:
            revision = "main"
        # Point at the repo's file listing (`/tree/<revision>`), as the local-folder branch above does.
        raise EnvironmentError(
            f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
            f"'https://huggingface.co/{path_or_repo_id}/tree/{revision}' for available files."
        ) from e
    except HTTPError as err:
        # Fall back to a cached copy if there is one.
        resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
        if resolved_file is not None or not _raise_exceptions_for_connection_errors:
            return resolved_file
        # Chain the original error explicitly, like the other handlers.
        raise EnvironmentError(
            f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}"
        ) from err
    except HFValidationError as e:
        raise EnvironmentError(
            f"Incorrect path_or_model_id: '{path_or_repo_id}'. Please provide either the path to a local folder or the repo_id of a model on the Hub."
        ) from e
    return resolved_file
# TODO: deprecate `get_file_from_repo` or document it differently?
# Docstring is exactly the same as `cached_repo` but behavior is slightly different. If file is missing or if
# there is a connection error, `cached_repo` will return None while `get_file_from_repo` will raise an error.
# IMO we should keep only 1 method and have a single `raise_error` argument (to be discussed).
# 定义了一个函数 `get_file_from_repo`,用于从本地文件夹或仓库中获取文件,并在需要时下载和缓存它。
def get_file_from_repo(
    path_or_repo: Union[str, os.PathLike],
    filename: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
    **deprecated_kwargs,
):
    """
    Tries to locate a file in a local folder and repo, downloads and cache it if necessary.

    Args:
        path_or_repo (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a model repo on huggingface.co.
            - a path to a *directory* potentially containing the file.
        filename (`str`):
            The name of the file to locate in `path_or_repo`.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        subfolder (`str`, *optional*, defaults to `""`):
            In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
            specify the folder name here.
        deprecated_kwargs:
            Only the deprecated `use_auth_token` key is read; it is used as `token` after emitting a
            `FutureWarning`.

    <Tip>

    Passing `token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo) or `None` if the
        file does not exist.

    Raises:
        ValueError: If both `token` and the deprecated `use_auth_token` are given.

    Examples:

    ```python
    # Download a tokenizer configuration from huggingface.co and cache.
    tokenizer_config = get_file_from_repo("google-bert/bert-base-uncased", "tokenizer_config.json")
    # This model does not have a tokenizer config so the result will be None.
    tokenizer_config = get_file_from_repo("FacebookAI/xlm-roberta-base", "tokenizer_config.json")
    ```
    """
    use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    # All `_raise_exceptions_for_*` switches are turned off so that `cached_file` reports a gated repo, a missing
    # entry or a connection problem through its return value instead of raising.
    return cached_file(
        path_or_repo_id=path_or_repo,
        filename=filename,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        token=token,
        revision=revision,
        local_files_only=local_files_only,
        subfolder=subfolder,
        _raise_exceptions_for_gated_repo=False,
        _raise_exceptions_for_missing_entries=False,
        _raise_exceptions_for_connection_errors=False,
    )
def download_url(url, proxies=None):
    """
    Downloads `url` into a newly created temporary file and returns the path of that file.

    A `FutureWarning` is emitted on every call because loading from a bare URL is deprecated.

    Args:
        url (`str`): The url of the file to download.
        proxies (*optional*): Forwarded unchanged to `http_get`.

    Returns:
        `str`: The path of the temporary file the content was written to.
    """
    deprecation_notice = (
        f"Using `from_pretrained` with the url of a file (here {url}) is deprecated and won't be possible anymore in"
        " v5 of Transformers. You should host your file on the Hub (hf.co) instead and use the repository ID. Note"
        " that this is not compatible with the caching system (your file will be downloaded at each execution) or"
        " multiple processes (each process will download the file in a different temporary file)."
    )
    warnings.warn(deprecation_notice, FutureWarning)

    # mkstemp hands back an already-open OS-level descriptor; wrap it so it gets closed when the download ends.
    descriptor, destination = tempfile.mkstemp()
    with os.fdopen(descriptor, "wb") as sink:
        http_get(url, sink, proxies=proxies)
    return destination
def has_file(
    path_or_repo: Union[str, os.PathLike],
    filename: str,
    revision: Optional[str] = None,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    **deprecated_kwargs,
):
    """
    Checks whether `path_or_repo` contains `filename`, without downloading it.

    When `path_or_repo` is an existing local directory the answer comes from the file system only. Otherwise a
    HEAD request is sent to the url built by `hf_hub_url`.

    Args:
        path_or_repo (`str` or `os.PathLike`): A local directory or a repo id.
        filename (`str`): The name of the file to look for.
        revision (`str`, *optional*): The revision used to build the remote url.
        proxies (`Dict[str, str]`, *optional*): Proxies passed to the HEAD request.
        token (`str` or `bool`, *optional*): The token used to build the request headers.
        deprecated_kwargs: Only the deprecated `use_auth_token` key is read; it is used as `token` after emitting a
            `FutureWarning`.

    Returns:
        `bool`: `True` if the file exists, `False` if it does not (or, for a remote repo, if the request ends with
        a plain `requests.HTTPError`).

    Raises:
        ValueError: If both `token` and `use_auth_token` are given.
        EnvironmentError: If the repo is gated, the repo is not found, or the revision is not found.
    """
    use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    if os.path.isdir(path_or_repo):
        return os.path.isfile(os.path.join(path_or_repo, filename))

    url = hf_hub_url(path_or_repo, filename=filename, revision=revision)
    headers = build_hf_headers(token=token, user_agent=http_user_agent())

    r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=10)
    try:
        hf_raise_for_status(r)
        return True
    except GatedRepoError as e:
        logger.error(e)
        raise EnvironmentError(
            f"{path_or_repo} is a gated repository. Make sure to request access at "
            f"https://huggingface.co/{path_or_repo} and pass a token having permission to this repo either by "
            "logging in with `huggingface-cli login` or by passing `token=<your_token>`."
        ) from e
    except RepositoryNotFoundError as e:
        logger.error(e)
        # Chain the cause explicitly, like the gated-repo branch above does.
        raise EnvironmentError(
            f"{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'."
        ) from e
    except RevisionNotFoundError as e:
        logger.error(e)
        raise EnvironmentError(
            f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
            f"model name. Check the model page at 'https://huggingface.co/{path_or_repo}' for available revisions."
        ) from e
    except requests.HTTPError:
        # Any other HTTP failure is reported as "file not there". This clause must stay last so the more specific
        # handlers above get a chance to run first.
        return False
def _create_repo(
    self,
    repo_id: str,
    private: Optional[bool] = None,
    token: Optional[Union[bool, str]] = None,
    repo_url: Optional[str] = None,
    organization: Optional[str] = None,
) -> str:
    """
    Creates the repo if needed, cleans up `repo_id` when the deprecated kwargs `repo_url` or `organization` are
    used, and returns the `repo_id` reported by `create_repo`.

    Raises:
        ValueError: If both `repo_id` and `repo_url` are given.
    """
    if repo_url is not None:
        warnings.warn(
            "The `repo_url` argument is deprecated and will be removed in v5 of Transformers. Use `repo_id` instead."
        )
        if repo_id is not None:
            raise ValueError("`repo_id` and `repo_url` are both specified. Please set only the argument `repo_id`.")
        # Strip the endpoint prefix so that only the repo id part of the url remains.
        endpoint_prefix = f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/"
        repo_id = repo_url.replace(endpoint_prefix, "")

    if organization is not None:
        warnings.warn(
            "The `organization` argument is deprecated and will be removed in v5 of Transformers. Set your "
            "organization directly in the `repo_id` passed instead (`repo_id={organization}/{model_id}`)."
        )
        if not repo_id.startswith(organization):
            # Keep only the last path segment (the whole string when there is no "/") and re-prefix it.
            bare_name = repo_id.split("/")[-1]
            repo_id = f"{organization}/{bare_name}"

    created = create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True)
    return created.repo_id
def _get_files_timestamps(self, working_dir: Union[str, os.PathLike]):
    """
    Returns a dict mapping the name of every entry in `working_dir` to its last modification timestamp.
    """
    timestamps = {}
    for entry in os.listdir(working_dir):
        timestamps[entry] = os.path.getmtime(os.path.join(working_dir, entry))
    return timestamps
def _upload_modified_files(
    self,
    working_dir: Union[str, os.PathLike],
    repo_id: str,
    files_timestamps: Dict[str, float],
    commit_message: Optional[str] = None,
    token: Optional[Union[bool, str]] = None,
    create_pr: bool = False,
    revision: Optional[str] = None,
    commit_description: Optional[str] = None,
):
    """
    Uploads all modified files in `working_dir` to `repo_id`, based on `files_timestamps`.

    An entry of `working_dir` counts as modified when it is absent from `files_timestamps` or when its current
    modification time is greater than the recorded one.

    Args:
        working_dir (`str` or `os.PathLike`): The directory whose entries are uploaded.
        repo_id (`str`): The repo receiving the commit.
        files_timestamps (`Dict[str, float]`): Entry name to reference modification time.
        commit_message (`str`, *optional*): When `None`, a default is derived from the class name of `self`.
        token (`str` or `bool`, *optional*): Forwarded to `create_branch` and `create_commit`.
        create_pr (`bool`, *optional*, defaults to `False`): Forwarded to `create_commit`.
        revision (`str`, *optional*): When given, the branch is created first (with `exist_ok=True`).
        commit_description (`str`, *optional*): Forwarded to `create_commit`.

    Returns:
        Whatever `create_commit` returns.
    """
    if commit_message is None:
        # The order matters: the first matching fragment of the class name wins.
        if "Model" in self.__class__.__name__:
            commit_message = "Upload model"
        elif "Config" in self.__class__.__name__:
            commit_message = "Upload config"
        elif "Tokenizer" in self.__class__.__name__:
            commit_message = "Upload tokenizer"
        elif "FeatureExtractor" in self.__class__.__name__:
            commit_message = "Upload feature extractor"
        elif "Processor" in self.__class__.__name__:
            commit_message = "Upload processor"
        else:
            commit_message = f"Upload {self.__class__.__name__}"

    modified_files = [
        f
        for f in os.listdir(working_dir)
        if f not in files_timestamps or os.path.getmtime(os.path.join(working_dir, f)) > files_timestamps[f]
    ]

    # Keep only regular files and directories.
    modified_files = [
        f
        for f in modified_files
        if os.path.isfile(os.path.join(working_dir, f)) or os.path.isdir(os.path.join(working_dir, f))
    ]

    operations = []
    for file in modified_files:
        if os.path.isdir(os.path.join(working_dir, file)):
            # Only the direct children of a modified directory are added; there is no recursion.
            for f in os.listdir(os.path.join(working_dir, file)):
                operations.append(
                    CommitOperationAdd(
                        path_or_fileobj=os.path.join(working_dir, file, f), path_in_repo=os.path.join(file, f)
                    )
                )
        else:
            operations.append(
                CommitOperationAdd(path_or_fileobj=os.path.join(working_dir, file), path_in_repo=file)
            )

    if revision is not None:
        create_branch(repo_id=repo_id, branch=revision, token=token, exist_ok=True)

    logger.info(f"Uploading the following files to {repo_id}: {','.join(modified_files)}")
    return create_commit(
        repo_id=repo_id,
        operations=operations,
        commit_message=commit_message,
        commit_description=commit_description,
        token=token,
        create_pr=create_pr,
        revision=revision,
    )
def push_to_hub(
self,
repo_id: str,
use_temp_dir: Optional[bool] = None,
commit_message: Optional[str] = None,
private: Optional[bool] = None,
token: Optional[Union[bool, str]] = None,
max_shard_size: Optional[Union[int, str]] = "5GB",
create_pr: bool = False,
safe_serialization: bool = True,
revision: str = None,
commit_description: str = None,
tags: Optional[List[str]] = None,
**deprecated_kwargs,
def send_example_telemetry(example_name, *example_args, framework="pytorch"):
    """
    Sends telemetry that helps tracking the examples use.

    Nothing is sent when `is_offline_mode()` is true.

    Args:
        example_name (`str`): The name of the example.
        *example_args (dataclasses or `argparse.ArgumentParser`): The arguments to the script. This function will only
            try to extract the model and dataset name from those. Nothing else is tracked.
        framework (`str`, *optional*, defaults to `"pytorch"`): The framework for the example.
    """
    if is_offline_mode():
        return

    payload = {"example": example_name, "framework": framework}
    for args in example_args:
        # Public, non-None attributes only.
        public_fields = {
            key: value for key, value in vars(args).items() if not key.startswith("_") and value is not None
        }

        if "model_name_or_path" in public_fields:
            # Local directories are not reported.
            if not os.path.isdir(public_fields["model_name_or_path"]):
                payload["model_name"] = public_fields["model_name_or_path"]

        if "dataset_name" in public_fields:
            payload["dataset_name"] = public_fields["dataset_name"]
        elif "task_name" in public_fields:
            # Derive a script name from the example name by removing the known fragments, in this order.
            script_name = example_name
            for fragment in ("tf_", "flax_", "run_", "_no_trainer"):
                script_name = script_name.replace(fragment, "")
            payload["dataset_name"] = f"{script_name}-{public_fields['task_name']}"

    send_telemetry(
        topic="examples",
        library_name="transformers",
        library_version=__version__,
        user_agent=http_user_agent(payload),
    )
def convert_file_size_to_int(size: Union[int, str]):
    """
    Converts a size expressed as a string with digits and a unit (like `"5MB"`) to an integer (in bytes).

    Binary units (`GiB`, `MiB`, `KiB`) and decimal units (`GB`, `MB`, `KB`) are matched case-insensitively. For the
    decimal units, a lowercase trailing `b` makes the value count bits, so the result is divided by 8.

    Args:
        size (`int` or `str`): The size to convert. Will be directly returned if an `int`.

    Raises:
        ValueError: If the unit is not recognized or the numeric part is not a valid integer.

    Example:

    ```py
    >>> convert_file_size_to_int("1MiB")
    1048576
    ```
    """
    if isinstance(size, int):
        return size

    uppercased = size.upper()

    # Three-letter binary units: always bytes.
    for unit, multiplier in (("GIB", 2**30), ("MIB", 2**20), ("KIB", 2**10)):
        if uppercased.endswith(unit):
            return int(size[:-3]) * multiplier

    # Two-letter decimal units: bytes, or bits when the original string ends with a lowercase "b".
    for unit, multiplier in (("GB", 10**9), ("MB", 10**6), ("KB", 10**3)):
        if uppercased.endswith(unit):
            amount = int(size[:-2]) * multiplier
            if size.endswith("b"):
                return amount // 8
            return amount

    raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")
def get_checkpoint_shard_files(
    pretrained_model_name_or_path,
    index_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    local_files_only=False,
    token=None,
    user_agent=None,
    revision=None,
    subfolder="",
    _commit_hash=None,
    **deprecated_kwargs,
):
    """
    For a given model:

    - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the
      Hub
    - returns the list of paths to all the shards, as well as some metadata.

    For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
    index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).

    Returns:
        `Tuple[List[str], dict]`: The shard paths (sorted by shard file name) and the index's `"metadata"` dict,
        extended with `"all_checkpoint_keys"` and a copy of `"weight_map"`.

    Raises:
        ValueError: If both `token` and the deprecated `use_auth_token` are given, or if `index_filename` is not an
            existing file.
        EnvironmentError: If a shard listed in the index cannot be found or fetched.
    """
    import json

    use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    if not os.path.isfile(index_filename):
        raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")

    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(index_filename, "r", encoding="utf-8") as f:
        index = json.load(f)

    # Several weights usually live in the same shard: de-duplicate and sort the shard file names.
    shard_filenames = sorted(set(index["weight_map"].values()))
    sharded_metadata = index["metadata"]
    sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
    sharded_metadata["weight_map"] = index["weight_map"].copy()

    # Local folder: the shards are simply resolved next to it, nothing has to be downloaded.
    if os.path.isdir(pretrained_model_name_or_path):
        shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames]
        return shard_filenames, sharded_metadata

    # Otherwise every shard goes through `cached_file`.
    cached_filenames = []
    # Only the last shard is probed in the cache to decide whether the progress bar is worth showing.
    last_shard = try_to_load_from_cache(
        pretrained_model_name_or_path, shard_filenames[-1], cache_dir=cache_dir, revision=_commit_hash
    )
    show_progress_bar = last_shard is None or force_download
    for shard_filename in tqdm(shard_filenames, desc="Downloading shards", disable=not show_progress_bar):
        try:
            cached_filename = cached_file(
                pretrained_model_name_or_path,
                shard_filename,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                token=token,
                user_agent=user_agent,
                revision=revision,
                subfolder=subfolder,
                _commit_hash=_commit_hash,
            )
        except EntryNotFoundError as e:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
                "required according to the checkpoint index."
            ) from e
        except HTTPError as e:
            raise EnvironmentError(
                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {shard_filename}. You should try"
                " again after checking your internet connection."
            ) from e

        cached_filenames.append(cached_filename)

    return cached_filenames, sharded_metadata
def get_all_cached_files(cache_dir=None):
    """
    Returns a list describing every file in `cache_dir` that has a sibling `<name>.json` metadata file.

    Args:
        cache_dir (*optional*): The directory to scan. Defaults to `TRANSFORMERS_CACHE`; any other value is
            converted with `str`.

    Returns:
        `List[dict]`: One `{"file", "url", "etag"}` dict per entry, where `url` and `etag` are read from the
        metadata file and the etag has its double quotes removed. Empty when `cache_dir` is not a directory.
    """
    cache_dir = TRANSFORMERS_CACHE if cache_dir is None else str(cache_dir)
    if not os.path.isdir(cache_dir):
        return []

    described = []
    for name in os.listdir(cache_dir):
        sidecar = os.path.join(cache_dir, f"{name}.json")
        if os.path.isfile(sidecar):
            with open(sidecar, encoding="utf-8") as handle:
                info = json.load(handle)
            described.append({"file": name, "url": info["url"], "etag": info["etag"].replace('"', "")})
    return described
def extract_info_from_url(url):
    """
    Splits a `https://huggingface.co/<repo>/resolve/<revision>/<filename>` url into its parts.

    Returns:
        `Optional[dict]`: `{"repo", "revision", "filename"}`, where `repo` is rewritten as `models--` followed by the
        repo path segments joined with `--`; `None` when the url does not have that shape.
    """
    parsed = re.search(r"^https://huggingface\.co/(.*)/resolve/([^/]*)/(.*)$", url)
    if parsed is None:
        return None
    repo_path, revision, filename = parsed.groups()
    cache_folder_name = "--".join(["models", *repo_path.split("/")])
    return {"repo": cache_folder_name, "revision": revision, "filename": filename}
def create_and_tag_model_card(
    repo_id: str,
    tags: Optional[List[str]] = None,
    token: Optional[str] = None,
    ignore_metadata_errors: bool = False,
):
    """
    Loads the model card of `repo_id`, or builds one from the template when loading raises `EntryNotFoundError`,
    then makes sure every tag in `tags` is listed on the card.

    Args:
        repo_id (`str`): The repo whose model card is loaded.
        tags (`List[str]`, *optional*): Tags to add to the card when they are not already present.
        token (`str`, *optional*): Forwarded to `ModelCard.load`.
        ignore_metadata_errors (`bool`, *optional*, defaults to `False`): Forwarded to `ModelCard.load`.

    Returns:
        The model card object (it is returned, not pushed, by this function).
    """
    try:
        model_card = ModelCard.load(repo_id, token=token, ignore_metadata_errors=ignore_metadata_errors)
    except EntryNotFoundError:
        # No card could be loaded: start from the template with a generic description.
        model_description = "This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated."
        initial_tags = [] if tags is None else tags
        card_data = ModelCardData(tags=initial_tags, library_name="transformers")
        model_card = ModelCard.from_template(card_data, model_description=model_description)

    for model_tag in tags or ():
        if model_tag not in model_card.data.tags:
            model_card.data.tags.append(model_tag)

    return model_card
def clean_files_for(file):
    """
    Removes `file` together with its `.json` and `.lock` siblings, skipping any that is not an existing file.
    """
    candidates = (file, f"{file}.json", f"{file}.lock")
    for path in candidates:
        if not os.path.isfile(path):
            continue
        os.remove(path)
def move_to_new_cache(file, repo, filename, revision, etag, commit_hash):
    """
    Moves `file` into the `refs` / `blobs` / `snapshots` layout rooted at `repo`.

    Steps, in order: create `repo` and `repo/refs`; when `revision` differs from `commit_hash`, write the commit
    hash to `repo/refs/<revision>`; move `file` to `repo/blobs/<etag>`; create `repo/snapshots/<commit_hash>` and
    link `<filename>` there to the blob; finally call `clean_files_for(file)`.
    """
    os.makedirs(repo, exist_ok=True)

    # refs
    refs_dir = os.path.join(repo, "refs")
    os.makedirs(refs_dir, exist_ok=True)
    if revision != commit_hash:
        with open(os.path.join(repo, "refs", revision), "w") as ref_file:
            ref_file.write(commit_hash)

    # blobs
    blobs_dir = os.path.join(repo, "blobs")
    os.makedirs(blobs_dir, exist_ok=True)
    blob_path = os.path.join(repo, "blobs", etag)
    shutil.move(file, blob_path)

    # snapshots
    snapshots_dir = os.path.join(repo, "snapshots")
    os.makedirs(snapshots_dir, exist_ok=True)
    os.makedirs(os.path.join(repo, "snapshots", commit_hash), exist_ok=True)
    pointer_path = os.path.join(repo, "snapshots", commit_hash, filename)
    huggingface_hub.file_download._create_relative_symlink(blob_path, pointer_path)

    # Remove whatever is left at the old location (see `clean_files_for`).
    clean_files_for(file)
def move_cache(cache_dir=None, new_cache_dir=None, token=None):
    """
    Moves the files listed by `get_all_cached_files(cache_dir)` to the layout used under `new_cache_dir`.

    Args:
        cache_dir (*optional*): The source cache. When `None`, the `transformers` folder next to
            `TRANSFORMERS_CACHE` is used if it exists, otherwise `new_cache_dir`.
        new_cache_dir (*optional*): The destination cache. Defaults to `TRANSFORMERS_CACHE`.
        token (*optional*): Forwarded to `get_hf_file_metadata`.
    """
    if new_cache_dir is None:
        new_cache_dir = TRANSFORMERS_CACHE
    if cache_dir is None:
        legacy_cache = Path(TRANSFORMERS_CACHE).parent / "transformers"
        cache_dir = str(legacy_cache) if os.path.isdir(str(legacy_cache)) else new_cache_dir

    cached_files = get_all_cached_files(cache_dir=cache_dir)
    logger.info(f"Moving {len(cached_files)} files to the new cache system")

    # Metadata is fetched once per url and reused for later entries with the same url.
    hub_metadata = {}
    for file_info in tqdm(cached_files):
        url = file_info.pop("url")
        if url not in hub_metadata:
            try:
                hub_metadata[url] = get_hf_file_metadata(url, token=token)
            except requests.HTTPError:
                continue

        remote = hub_metadata[url]
        etag, commit_hash = remote.etag, remote.commit_hash
        if etag is None or commit_hash is None:
            continue

        old_location = os.path.join(cache_dir, file_info["file"])
        if file_info["etag"] != etag:
            # The cached copy does not match the remote etag: remove it instead of migrating it.
            clean_files_for(old_location)
            continue

        url_info = extract_info_from_url(url)
        if url_info is None:
            # The url does not have the shape `extract_info_from_url` understands.
            continue

        move_to_new_cache(
            file=old_location,
            repo=os.path.join(new_cache_dir, url_info["repo"]),
            filename=url_info["filename"],
            revision=url_info["revision"],
            etag=etag,
            commit_hash=commit_hash,
        )
class PushInProgress:
    """
    Internal class to keep track of a push in progress (which might contain multiple `Future` jobs).
    """

    def __init__(self, jobs: Optional[futures.Future] = None) -> None:
        # `jobs` is iterated by every other method; an omitted value becomes a fresh empty list.
        if jobs is None:
            jobs = []
        self.jobs = jobs

    def is_done(self):
        """Return `True` when every tracked job reports `done()`; also `True` when there are no jobs."""
        for job in self.jobs:
            if not job.done():
                return False
        return True

    def wait_until_done(self):
        """Block until `futures.wait` returns for the tracked jobs."""
        futures.wait(self.jobs)

    def cancel(self) -> None:
        """Try to cancel every job and keep only those that could not be cancelled and are not done."""
        still_active = []
        for job in self.jobs:
            # `cancel()` is attempted first; `done()` is only consulted when cancelling failed.
            if job.cancel() or job.done():
                continue
            still_active.append(job)
        self.jobs = still_active
# Import-time cache migration. The cache version is read from `version.txt` inside TRANSFORMERS_CACHE; a missing or
# unparseable file counts as version 0.
# NOTE(review): the indentation of this section was reconstructed from a flattened listing; in particular, confirm
# that the `try: move_cache(...)` statement belongs directly under `if cache_version < 1 and cache_is_not_empty:`
# (i.e. runs in both the offline and the online branch) against the upstream file.
cache_version_file = os.path.join(TRANSFORMERS_CACHE, "version.txt")
if not os.path.isfile(cache_version_file):
    cache_version = 0
else:
    with open(cache_version_file) as f:
        try:
            cache_version = int(f.read())
        except ValueError:
            # Content that is not an integer is treated like a missing version file.
            cache_version = 0

cache_is_not_empty = os.path.isdir(TRANSFORMERS_CACHE) and len(os.listdir(TRANSFORMERS_CACHE)) > 0

# Migrate only when the recorded version is older than 1 and there is something in the cache folder.
if cache_version < 1 and cache_is_not_empty:
    if is_offline_mode():
        logger.warning(
            "You are offline and the cache for model files in Transformers v4.22.0 has been updated while your local "
            "cache seems to be the one of a previous version. It is very likely that all your calls to any "
            "`from_pretrained()` method will fail. Remove the offline mode and enable internet connection to have "
            "your cache be updated automatically, then you can go back to offline mode."
        )
    else:
        logger.warning(
            "The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a "
            "one-time only operation. You can interrupt this and resume the migration later on by calling "
            "`transformers.utils.move_cache()`."
        )
    try:
        if TRANSFORMERS_CACHE != constants.HF_HUB_CACHE:
            # TRANSFORMERS_CACHE differs from the hub default: migrate in place inside that folder.
            move_cache(TRANSFORMERS_CACHE, TRANSFORMERS_CACHE)
        else:
            move_cache()
    except Exception as e:
        # A failed migration must not break the import: log the traceback and continue.
        trace = "\n".join(traceback.format_tb(e.__traceback__))
        logger.error(
            f"There was a problem when trying to move your cache:\n\n{trace}\n{e.__class__.__name__}: {e}\n\nPlease "
            "file an issue at https://github.com/huggingface/transformers/issues/new/choose and copy paste this whole "
            "message and we will do our best to help."
        )

# Record version 1 so the block above is skipped on the next import. This also runs when the cache was empty.
if cache_version < 1:
    try:
        os.makedirs(TRANSFORMERS_CACHE, exist_ok=True)
        with open(cache_version_file, "w") as f:
            f.write("1")
    except Exception:
        logger.warning(
            f"There was a problem when trying to write in your cache folder ({TRANSFORMERS_CACHE}). You should set "
            "the environment variable TRANSFORMERS_CACHE to a writable directory."
        )
.\utils\import_utils.py
import importlib.metadata
import importlib.util
import json
import os
import shutil
import subprocess
import sys
import warnings
from collections import OrderedDict
from functools import lru_cache
from itertools import chain
from types import ModuleType
from typing import Any, Tuple, Union
from packaging import version
from . import logging
logger = logging.get_logger(__name__)
def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
    """
    Checks whether a package is available and, optionally, reports its version.

    A package counts as available when `importlib.util.find_spec` finds it and `importlib.metadata.version` knows
    a distribution of the same name. For `"torch"` only, a missing distribution is tolerated when the imported
    module's `__version__` contains `"dev"`.

    Args:
        pkg_name (`str`): The name of the package to check.
        return_version (`bool`, *optional*, defaults to `False`): Whether to also return the detected version.

    Returns:
        `bool` or `Tuple[bool, str]`: The availability flag alone, or `(flag, version)` when `return_version` is
        `True`. The version is `"N/A"` when it could not be determined.
    """
    if importlib.util.find_spec(pkg_name) is None:
        return (False, "N/A") if return_version else False

    found = True
    detected_version = "N/A"
    try:
        detected_version = importlib.metadata.version(pkg_name)
    except importlib.metadata.PackageNotFoundError:
        # Importable but without distribution metadata: unavailable, except for a dev build of torch.
        found = False
        if pkg_name == "torch":
            try:
                fallback_version = getattr(importlib.import_module(pkg_name), "__version__", "N/A")
                if "dev" in fallback_version:
                    detected_version = fallback_version
                    found = True
            except ImportError:
                found = False
    logger.debug(f"Detected {pkg_name} version: {detected_version}")

    return (found, detected_version) if return_version else found
# Upper-cased strings that count as "true" for the environment switches read below.
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})

# Framework switches, read once at import time and upper-cased before comparison.
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
# Note that the variable is named USE_JAX but the environment variable read is USE_FLAX.
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()

USE_TORCH_XLA = os.environ.get("USE_TORCH_XLA", "1").upper()

FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()

TORCH_FX_REQUIRED_VERSION = version.parse("1.10")

ACCELERATE_MIN_VERSION = "0.21.0"
FSDP_MIN_VERSION = "1.12.0"


# Availability flags, computed once at import time. Most go through `_is_package_available`; a few only check
# importability with `importlib.util.find_spec`.
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex")
_aqlm_available = _is_package_available("aqlm")
_bitsandbytes_available = _is_package_available("bitsandbytes")
_galore_torch_available = _is_package_available("galore_torch")
_bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
_cv2_available = importlib.util.find_spec("cv2") is not None
_datasets_available = _is_package_available("datasets")
_decord_available = importlib.util.find_spec("decord") is not None
_detectron2_available = _is_package_available("detectron2")
# faiss: the distribution metadata is looked up as "faiss" first, then as "faiss-cpu"; when neither is found the
# flag is reset to False even if the module is importable.
_faiss_available = importlib.util.find_spec("faiss") is not None
try:
    _faiss_version = importlib.metadata.version("faiss")
    logger.debug(f"Successfully imported faiss version {_faiss_version}")
except importlib.metadata.PackageNotFoundError:
    try:
        _faiss_version = importlib.metadata.version("faiss-cpu")
        logger.debug(f"Successfully imported faiss version {_faiss_version}")
    except importlib.metadata.PackageNotFoundError:
        _faiss_available = False
_ftfy_available = _is_package_available("ftfy")
_g2p_en_available = _is_package_available("g2p_en")
_ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True)
_jieba_available = _is_package_available("jieba")
_jinja_available = _is_package_available("jinja2")
_kenlm_available = _is_package_available("kenlm")
_keras_nlp_available = _is_package_available("keras_nlp")
_levenshtein_available = _is_package_available("Levenshtein")
_librosa_available = _is_package_available("librosa")
_natten_available = _is_package_available("natten")
_nltk_available = _is_package_available("nltk")
_onnx_available = _is_package_available("onnx")
_openai_available = _is_package_available("openai")
_optimum_available = _is_package_available("optimum")
_auto_gptq_available = _is_package_available("auto_gptq")
# The flag is named after auto_awq but the module probed is `awq`.
_auto_awq_available = importlib.util.find_spec("awq") is not None
_quanto_available = _is_package_available("quanto")
_pandas_available = _is_package_available("pandas")
_peft_available = _is_package_available("peft")
_phonemizer_available = _is_package_available("phonemizer")
_psutil_available = _is_package_available("psutil")
_py3nvml_available = _is_package_available("py3nvml")
_pyctcdecode_available = _is_package_available("pyctcdecode")
_pytesseract_available = _is_package_available("pytesseract")
_pytest_available = _is_package_available("pytest")
_pytorch_quantization_available = _is_package_available("pytorch_quantization")
_rjieba_available = _is_package_available("rjieba")
_sacremoses_available = _is_package_available("sacremoses")
_safetensors_available = _is_package_available("safetensors")
_scipy_available = _is_package_available("scipy")
_sentencepiece_available = _is_package_available("sentencepiece")
_is_seqio_available = _is_package_available("seqio")
# sklearn: the importable module is `sklearn` while the distribution is looked up as "scikit-learn".
_sklearn_available = importlib.util.find_spec("sklearn") is not None
if _sklearn_available:
    try:
        importlib.metadata.version("scikit-learn")
    except importlib.metadata.PackageNotFoundError:
        _sklearn_available = False
_smdistributed_available = importlib.util.find_spec("smdistributed") is not None
_soundfile_available = _is_package_available("soundfile")
_spacy_available = _is_package_available("spacy")
_sudachipy_available, _sudachipy_version = _is_package_available("sudachipy", return_version=True)
_tensorflow_probability_available = _is_package_available("tensorflow_probability")
_tensorflow_text_available = _is_package_available("tensorflow_text")
_tf2onnx_available = _is_package_available("tf2onnx")
_timm_available = _is_package_available("timm")
_tokenizers_available = _is_package_available("tokenizers")
_torchaudio_available = _is_package_available("torchaudio")
_torchdistx_available = _is_package_available("torchdistx")
_torchvision_available = _is_package_available("torchvision")
_mlx_available = _is_package_available("mlx")
# PyTorch is probed only when USE_TORCH is true/auto and USE_TF is not explicitly true.
_torch_version = "N/A"
_torch_available = False
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
    _torch_available, _torch_version = _is_package_available("torch", return_version=True)
else:
    # NOTE(review): this branch is also taken when USE_TORCH itself is off, although the message only mentions
    # USE_TF.
    logger.info("Disabling PyTorch because USE_TF is set")
    _torch_available = False


# TensorFlow: FORCE_TF_AVAILABLE short-circuits the detection (and leaves `_tf_version` at "N/A").
_tf_version = "N/A"
_tf_available = False
if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES:
    _tf_available = True
else:
    if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
        _tf_available = importlib.util.find_spec("tensorflow") is not None
        if _tf_available:
            # The `tensorflow` module can come from any of these distributions; the first one with metadata
            # provides the version.
            candidates = (
                "tensorflow",
                "tensorflow-cpu",
                "tensorflow-gpu",
                "tf-nightly",
                "tf-nightly-cpu",
                "tf-nightly-gpu",
                "tf-nightly-rocm",
                "intel-tensorflow",
                "intel-tensorflow-avx512",
                "tensorflow-rocm",
                "tensorflow-macos",
                "tensorflow-aarch64",
            )
            _tf_version = None
            for pkg in candidates:
                try:
                    _tf_version = importlib.metadata.version(pkg)
                    break
                except importlib.metadata.PackageNotFoundError:
                    pass
            # Importable but no known distribution found: treated as unavailable.
            _tf_available = _tf_version is not None
        if _tf_available:
            if version.parse(_tf_version) < version.parse("2"):
                logger.info(
                    f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum."
                )
                _tf_available = False
    else:
        logger.info("Disabling Tensorflow because USE_TORCH is set")
_essentia_available = importlib.util.find_spec("essentia") is not None
try:
    _essentia_version = importlib.metadata.version("essentia")
    logger.debug(f"Successfully imported essentia version {_essentia_version}")
except importlib.metadata.PackageNotFoundError:
    # NOTE(review): unlike the pretty_midi and faiss sections, this handler sets the *version* to False and leaves
    # `_essentia_available` untouched. It looks like a typo for `_essentia_available = False`, but changing it
    # would alter what `is_essentia_available()` returns when the module is importable without an "essentia"
    # distribution -- confirm the intent before fixing.
    _essentia_version = False


_pretty_midi_available = importlib.util.find_spec("pretty_midi") is not None
try:
    _pretty_midi_version = importlib.metadata.version("pretty_midi")
    logger.debug(f"Successfully imported pretty_midi version {_pretty_midi_version}")
except importlib.metadata.PackageNotFoundError:
    _pretty_midi_available = False


# CCL: either module name makes it importable, but the distribution looked up is "oneccl_bind_pt"; when that lookup
# fails the flag is reset to False.
ccl_version = "N/A"
_is_ccl_available = (
    importlib.util.find_spec("torch_ccl") is not None
    or importlib.util.find_spec("oneccl_bindings_for_pytorch") is not None
)
try:
    ccl_version = importlib.metadata.version("oneccl_bind_pt")
    logger.debug(f"Detected oneccl_bind_pt version {ccl_version}")
except importlib.metadata.PackageNotFoundError:
    _is_ccl_available = False


# Flax is only considered available when jax is available too.
# NOTE(review): `_jax_available` / `_jax_version` are only assigned when flax was found -- confirm nothing reads
# them unconditionally.
_flax_available = False
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
    _flax_available, _flax_version = _is_package_available("flax", return_version=True)
    if _flax_available:
        _jax_available, _jax_version = _is_package_available("jax", return_version=True)
        if _jax_available:
            logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
        else:
            _flax_available = _jax_available = False
            _jax_version = _flax_version = "N/A"


# torch.fx: requires the installed torch (major, minor) to be at least TORCH_FX_REQUIRED_VERSION.
_torch_fx_available = False
if _torch_available:
    torch_version = version.parse(_torch_version)
    _torch_fx_available = (torch_version.major, torch_version.minor) >= (
        TORCH_FX_REQUIRED_VERSION.major,
        TORCH_FX_REQUIRED_VERSION.minor,
    )


# torch_xla is probed only when USE_TORCH_XLA is a "true" value (it defaults to "1"; "AUTO" is not accepted here).
_torch_xla_available = False
if USE_TORCH_XLA in ENV_VARS_TRUE_VALUES:
    _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla", return_version=True)
    if _torch_xla_available:
        logger.info(f"Torch XLA version {_torch_xla_version} available.")
def is_kenlm_available():
    """Return the module-level `_kenlm_available` flag computed at import time."""
    return _kenlm_available
def is_cv2_available():
    """Return the module-level `_cv2_available` flag (set at import time from `find_spec("cv2")`)."""
    return _cv2_available
def is_torch_available():
    """Return the module-level `_torch_available` flag computed at import time."""
    return _torch_available
def get_torch_version():
    """Return the module-level `_torch_version` string (`"N/A"` unless a torch version was detected at import)."""
    return _torch_version
def is_torch_sdpa_available():
    """
    Return `True` when torch is available, its version is known, and that version is at least 2.1.1.
    """
    if not is_torch_available() or _torch_version == "N/A":
        return False
    installed = version.parse(_torch_version)
    return installed >= version.parse("2.1.1")
def is_torchvision_available():
    """Return the module-level `_torchvision_available` flag computed at import time."""
    return _torchvision_available
def is_galore_torch_available():
    """Return the module-level `_galore_torch_available` flag computed at import time."""
    return _galore_torch_available
def is_pyctcdecode_available():
    """Return the module-level `_pyctcdecode_available` flag computed at import time."""
    return _pyctcdecode_available
def is_librosa_available():
    """Return the module-level `_librosa_available` flag computed at import time."""
    return _librosa_available
def is_essentia_available():
    """Return the module-level `_essentia_available` flag (set at import time from `find_spec("essentia")`)."""
    return _essentia_available
def is_pretty_midi_available():
    """Return the module-level `_pretty_midi_available` flag computed at import time."""
    return _pretty_midi_available
def is_torch_cuda_available():
    """
    Return `torch.cuda.is_available()`, or `False` when torch itself is not available.
    """
    if not is_torch_available():
        return False

    # Imported lazily so that this module can be imported without torch.
    import torch

    return torch.cuda.is_available()
def is_mamba_ssm_available():
    """
    Return whether `mamba_ssm` is available; always `False` without torch or without CUDA.
    """
    if not is_torch_available():
        return False

    import torch

    if torch.cuda.is_available():
        # The package lookup is only performed once CUDA is confirmed.
        return _is_package_available("mamba_ssm")
    return False
def is_causal_conv1d_available():
    """
    Return whether `causal_conv1d` is available; always `False` without torch or without CUDA.
    """
    if not is_torch_available():
        return False

    import torch

    if torch.cuda.is_available():
        return _is_package_available("causal_conv1d")
    return False
def is_torch_mps_available():
    """
    Return `torch.backends.mps.is_available()`, or `False` when torch is not available or `torch.backends` has no
    `mps` attribute.
    """
    if not is_torch_available():
        return False

    import torch

    if not hasattr(torch.backends, "mps"):
        return False
    return torch.backends.mps.is_available()
def is_torch_bf16_gpu_available():
    """
    Return whether CUDA is available and reports bf16 support; `False` when torch is not available.
    """
    if is_torch_available():
        import torch

        return torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    return False
def is_torch_bf16_cpu_available():
    """
    Return `True` when torch is available and exposes `torch.cpu.amp.autocast`, `False` otherwise.
    """
    if not is_torch_available():
        return False

    import torch

    try:
        # Only the existence of the attribute chain is probed; the value is not used.
        _ = torch.cpu.amp.autocast
    except AttributeError:
        return False
    else:
        return True
def is_torch_bf16_available():
    """Deprecated entry point: emits a `FutureWarning`, then returns `is_torch_bf16_gpu_available()`."""
    deprecation_message = (
        "The util is_torch_bf16_available is deprecated, please use is_torch_bf16_gpu_available "
        "or is_torch_bf16_cpu_available instead according to whether it's used with cpu or gpu"
    )
    warnings.warn(deprecation_message, FutureWarning)
    return is_torch_bf16_gpu_available()
def is_torch_fp16_available_on_device(device):
    """
    Probe whether basic float16 operations run on `device`.

    Args:
        device: Device specifier handed to `Tensor.to(...)` and to the `device=` argument of torch factories.

    Returns:
        `True` if a float16 matrix multiplication and a float16 `LayerNorm` forward pass both complete on
        `device`; `False` if torch is unavailable or any step of the probe raises.
    """
    if not is_torch_available():
        return False

    import torch

    try:
        x = torch.zeros(2, 2, dtype=torch.float16).to(device)
        _ = x @ x

        # Matmul alone is not taken as sufficient: a float16 LayerNorm forward pass is exercised as well.
        batch, sentence_length, embedding_dim = 3, 4, 5
        embedding = torch.randn(batch, sentence_length, embedding_dim, dtype=torch.float16, device=device)
        layer_norm = torch.nn.LayerNorm(embedding_dim, dtype=torch.float16, device=device)
        _ = layer_norm(embedding)
    except Exception:
        # Best-effort probe: any ordinary failure means "not supported". `Exception` (rather than a bare
        # `except`) lets KeyboardInterrupt and SystemExit propagate.
        return False

    return True
@lru_cache()
def is_torch_bf16_available_on_device(device):
    """
    Probe whether bfloat16 matrix multiplication runs on `device` (result memoized per `device` by `lru_cache`).

    Args:
        device: Device specifier. The exact string "cuda" is delegated to `is_torch_bf16_gpu_available()`;
            anything else is handed to `Tensor.to(...)`.

    Returns:
        `True` if the check succeeds, `False` if torch is unavailable or the probe raises.
    """
    if not is_torch_available():
        return False

    import torch

    if device == "cuda":
        return is_torch_bf16_gpu_available()

    try:
        x = torch.zeros(2, 2, dtype=torch.bfloat16).to(device)
        _ = x @ x
    except Exception:
        # Best-effort probe: any ordinary failure means "not supported". `Exception` (rather than a bare
        # `except`) lets KeyboardInterrupt and SystemExit propagate.
        return False

    return True
def is_torch_tf32_available():
    """
    Check the conditions this module requires for TF32.

    Returns `True` only when torch is available, CUDA is available with a non-`None` `torch.version.cuda`, the
    current device reports a `major` of at least 8, the CUDA major version is at least 11 and the base torch
    version is at least 1.7.
    """
    if not is_torch_available():
        return False

    import torch

    if not (torch.cuda.is_available() and torch.version.cuda is not None):
        return False

    current_device = torch.cuda.current_device()
    if torch.cuda.get_device_properties(current_device).major < 8:
        return False

    cuda_major = int(torch.version.cuda.split(".")[0])
    if cuda_major < 11:
        return False

    # `base_version` drops local/dev suffixes before the comparison.
    torch_base_version = version.parse(torch.__version__).base_version
    return version.parse(torch_base_version) >= version.parse("1.7")
def is_torch_fx_available():
    """Return the module-level `_torch_fx_available` value."""
    return _torch_fx_available
def is_peft_available():
    """Return the module-level `_peft_available` value."""
    return _peft_available
def is_bs4_available():
    """Return the module-level `_bs4_available` value."""
    return _bs4_available
def is_tf_available():
    """Return the module-level `_tf_available` value."""
    return _tf_available
def is_coloredlogs_available():
    """Return the module-level `_coloredlogs_available` value."""
    return _coloredlogs_available
def is_tf2onnx_available():
    """Return the module-level `_tf2onnx_available` value."""
    return _tf2onnx_available
def is_onnx_available():
    """Return the module-level `_onnx_available` value."""
    return _onnx_available
def is_openai_available():
    """Return the module-level `_openai_available` value."""
    return _openai_available
def is_flax_available():
    """Return the module-level `_flax_available` value."""
    return _flax_available
def is_ftfy_available():
    """Return the module-level `_ftfy_available` value."""
    return _ftfy_available
def is_g2p_en_available():
    """Return the module-level `_g2p_en_available` value."""
    return _g2p_en_available
@lru_cache()
def is_torch_tpu_available(check_device=True):
    """
    Deprecated check for `torch_xla`; always emits a `FutureWarning` pointing at `is_torch_xla_available`.

    Args:
        check_device: When true, additionally require that `xm.xla_device()` can be called without raising
            `RuntimeError`.

    Returns:
        `False` if torch is unavailable, `torch_xla` cannot be located, or the device probe raises
        `RuntimeError`; `True` otherwise.
    """
    warnings.warn(
        "`is_torch_tpu_available` is deprecated and will be removed in 4.41.0. "
        "Please use the `is_torch_xla_available` instead.",
        FutureWarning,
    )

    if not _torch_available:
        return False
    if importlib.util.find_spec("torch_xla") is None:
        return False
    if not check_device:
        return True

    try:
        import torch_xla.core.xla_model as xm

        xm.xla_device()
    except RuntimeError:
        return False
    return True
@lru_cache
def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
    """
    Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set
    the USE_TORCH_XLA to false.

    Args:
        check_is_tpu: Additionally require `torch_xla.runtime.device_type()` to equal "TPU".
        check_is_gpu: Additionally require `torch_xla.runtime.device_type()` to be "GPU" or "CUDA".

    The two flags are mutually exclusive (enforced with an `assert`).
    """
    assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true."

    if not _torch_xla_available:
        return False

    import torch_xla

    # The runtime is only queried when a specific device kind was requested.
    if not (check_is_gpu or check_is_tpu):
        return True

    device_type = torch_xla.runtime.device_type()
    return device_type in ["GPU", "CUDA"] if check_is_gpu else device_type == "TPU"
@lru_cache()
def is_torch_neuroncore_available(check_device=True):
    """
    Return `is_torch_xla_available()` when the `torch_neuronx` package can be located, else `False`.

    `check_device` is accepted but not used by this function.
    """
    if importlib.util.find_spec("torch_neuronx") is None:
        return False
    return is_torch_xla_available()
@lru_cache()
def is_torch_npu_available(check_device=False):
    """
    Report whether torch's `npu` backend is usable.

    Args:
        check_device: When true, call `torch.npu.device_count()` first and treat a `RuntimeError` as "not
            available".

    Returns:
        `False` if torch is unavailable or `torch_npu` cannot be located; otherwise the outcome of
        `torch.npu.is_available()` (guarded by `hasattr(torch, "npu")` when `check_device` is false).
    """
    if not (_torch_available and importlib.util.find_spec("torch_npu") is not None):
        return False

    import torch
    import torch_npu  # noqa: F401

    if not check_device:
        return hasattr(torch, "npu") and torch.npu.is_available()

    try:
        # Called only for its potential RuntimeError; the count itself is not used.
        torch.npu.device_count()
        return torch.npu.is_available()
    except RuntimeError:
        return False
def is_torchdynamo_available():
    """Return `True` if `torch._dynamo` can be imported; `False` if torch is unavailable or the import raises."""
    if not is_torch_available():
        return False

    try:
        import torch._dynamo as dynamo  # noqa: F401
    except Exception:
        return False
    else:
        return True
def is_torch_compile_available():
    """Return whether the `torch` module has a `compile` attribute; `False` when torch is unavailable."""
    if is_torch_available():
        import torch

        return hasattr(torch, "compile")
    return False
def is_torchdynamo_compiling():
    """
    Return `torch._dynamo.is_compiling()`.

    Returns `False` when torch is unavailable, or when importing `torch._dynamo` or calling `is_compiling()`
    raises any `Exception`.
    """
    if not is_torch_available():
        return False

    try:
        import torch._dynamo as dynamo

        compiling = dynamo.is_compiling()
    except Exception:
        compiling = False
    return compiling
def is_torch_tensorrt_fx_available():
    """Return `True` when both `torch_tensorrt` and its `fx` submodule can be located by `find_spec`."""
    find_spec = importlib.util.find_spec
    # The submodule is only looked up when the parent package was found (short-circuit).
    return find_spec("torch_tensorrt") is not None and find_spec("torch_tensorrt.fx") is not None
def is_datasets_available():
    """Return the module-level `_datasets_available` value."""
    return _datasets_available
def is_detectron2_available():
    """Return the module-level `_detectron2_available` value."""
    return _detectron2_available
def is_rjieba_available():
    """Return the module-level `_rjieba_available` value."""
    return _rjieba_available
def is_psutil_available():
    """Return the module-level `_psutil_available` value."""
    return _psutil_available
def is_py3nvml_available():
    """Return the module-level `_py3nvml_available` value."""
    return _py3nvml_available
def is_sacremoses_available():
    """Return the module-level `_sacremoses_available` value."""
    return _sacremoses_available
def is_apex_available():
    """Return the module-level `_apex_available` value."""
    return _apex_available
def is_aqlm_available():
    """Return the module-level `_aqlm_available` value."""
    return _aqlm_available
def is_ninja_available():
    r"""
    Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
    [ninja](https://ninja-build.org/) build system is available on the system, `False` otherwise.
    """
    ninja_command = ["ninja", "--version"]
    try:
        subprocess.check_output(ninja_command)
    except Exception:
        # Missing executable or non-zero exit status both count as "not available".
        return False
    return True
def is_ipex_available():
    """
    Report whether Intel Extension for PyTorch is available and matches torch's major.minor version.

    Returns `False` when torch or IPEX is unavailable. When the "major.minor" parts of `_torch_version` and
    `_ipex_version` differ, a warning is logged and `False` is returned.
    """

    def get_major_and_minor_from_version(full_version):
        # Reduce a full version string to "<major>.<minor>".
        parsed = version.parse(full_version)
        return f"{parsed.major}.{parsed.minor}"

    if not (is_torch_available() and _ipex_available):
        return False

    torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
    ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
    if torch_major_and_minor == ipex_major_and_minor:
        return True

    logger.warning(
        f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
        f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
    )
    return False
@lru_cache
def is_torch_xpu_available(check_device=False):
    """
    Report whether torch's `xpu` backend is usable.

    Args:
        check_device: When true, call `torch.xpu.device_count()` first and treat a `RuntimeError` as "not
            available".

    Returns:
        `False` if `is_ipex_available()` is false; otherwise the outcome of `torch.xpu.is_available()`
        (guarded by `hasattr(torch, "xpu")` when `check_device` is false).
    """
    if not is_ipex_available():
        return False

    import intel_extension_for_pytorch  # noqa: F401
    import torch

    if not check_device:
        return hasattr(torch, "xpu") and torch.xpu.is_available()

    try:
        # Called only for its potential RuntimeError; the count itself is not used.
        torch.xpu.device_count()
        return torch.xpu.is_available()
    except RuntimeError:
        return False
def is_bitsandbytes_available():
    """Return `_bitsandbytes_available and torch.cuda.is_available()`; `False` when torch is unavailable."""
    if is_torch_available():
        import torch

        return _bitsandbytes_available and torch.cuda.is_available()
    return False
def is_flash_attn_2_available():
    """
    Report whether an acceptable `flash_attn` release is installed for the current torch build.

    Requires torch, the `flash_attn` package and an available CUDA device. The minimum `flash_attn` version is
    2.1.0 when `torch.version.cuda` is set and 2.0.4 when `torch.version.hip` is set; if neither is set the
    result is `False`.
    """
    if not is_torch_available() or not _is_package_available("flash_attn"):
        return False

    import torch

    if not torch.cuda.is_available():
        return False

    if torch.version.cuda:
        minimum = "2.1.0"
    elif torch.version.hip:
        minimum = "2.0.4"
    else:
        return False

    installed = version.parse(importlib.metadata.version("flash_attn"))
    return installed >= version.parse(minimum)
def is_flash_attn_greater_or_equal_2_10():
    """Return whether `flash_attn` is installed with a version of at least 2.1.0 (the value the code compares)."""
    if _is_package_available("flash_attn"):
        installed = version.parse(importlib.metadata.version("flash_attn"))
        return installed >= version.parse("2.1.0")
    return False
def is_torchdistx_available():
    """Return the module-level `_torchdistx_available` value."""
    return _torchdistx_available
def is_faiss_available():
    """Return the module-level `_faiss_available` value."""
    return _faiss_available
def is_scipy_available():
    """Return the module-level `_scipy_available` value."""
    return _scipy_available
def is_sklearn_available():
    """Return the module-level `_sklearn_available` value."""
    return _sklearn_available
def is_sentencepiece_available():
    """Return the module-level `_sentencepiece_available` value."""
    return _sentencepiece_available
def is_seqio_available():
    """Return the module-level `_is_seqio_available` value."""
    return _is_seqio_available
def is_protobuf_available():
    """Return `True` when both the `google` package and `google.protobuf` can be located by `find_spec`."""
    find_spec = importlib.util.find_spec
    # `google.protobuf` is only looked up when the parent `google` package was found (short-circuit).
    return find_spec("google") is not None and find_spec("google.protobuf") is not None
def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION):
    """
    Report whether accelerate is available, optionally at a minimum version.

    Args:
        min_version: Minimum acceptable version string. Pass `None` to skip the version comparison.

    Returns:
        `_accelerate_available` when `min_version` is `None`; otherwise that flag combined (via `and`) with
        `_accelerate_version >= min_version`.
    """
    if min_version is None:
        return _accelerate_available
    return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)
def is_fsdp_available(min_version: str = FSDP_MIN_VERSION):
    """
    Return whether the recorded torch version is at least `min_version`; `False` when torch is unavailable.

    Args:
        min_version: Minimum torch version string to compare `_torch_version` against.
    """
    if not is_torch_available():
        return False
    return version.parse(_torch_version) >= version.parse(min_version)
def is_optimum_available():
    """Return the module-level `_optimum_available` value."""
    return _optimum_available
def is_auto_awq_available():
    """Return the module-level `_auto_awq_available` value."""
    return _auto_awq_available
def is_quanto_available():
    """Return the module-level `_quanto_available` value."""
    return _quanto_available
def is_auto_gptq_available():
    """Return the module-level `_auto_gptq_available` value."""
    return _auto_gptq_available
def is_levenshtein_available():
    """Return the module-level `_levenshtein_available` value."""
    return _levenshtein_available
def is_optimum_neuron_available():
    """Return `_optimum_available and _is_package_available("optimum.neuron")`."""
    # The package lookup is skipped when `_optimum_available` is falsy (short-circuit).
    return _optimum_available and _is_package_available("optimum.neuron")
def is_safetensors_available():
    """Return the module-level `_safetensors_available` value."""
    return _safetensors_available
def is_tokenizers_available():
    """Return the module-level `_tokenizers_available` value."""
    return _tokenizers_available
@lru_cache
def is_vision_available():
    """
    Report whether PIL is importable and backed by a known distribution.

    Returns `False` when the `PIL` module cannot be located, or when neither the "Pillow" nor the
    "Pillow-SIMD" distribution has installed metadata. Otherwise logs the detected version at debug level and
    returns `True`.
    """
    if importlib.util.find_spec("PIL") is None:
        return False

    # "Pillow" is tried first, "Pillow-SIMD" only as a fallback.
    for distribution_name in ("Pillow", "Pillow-SIMD"):
        try:
            package_version = importlib.metadata.version(distribution_name)
        except importlib.metadata.PackageNotFoundError:
            continue
        logger.debug(f"Detected PIL version {package_version}")
        return True
    return False
def is_pytesseract_available():
    """Return the module-level `_pytesseract_available` value."""
    return _pytesseract_available
def is_pytest_available():
    """Return the module-level `_pytest_available` value."""
    return _pytest_available
def is_spacy_available():
    """Return the module-level `_spacy_available` value."""
    return _spacy_available
def is_tensorflow_text_available():
    """Return `is_tf_available() and _tensorflow_text_available`."""
    return is_tf_available() and _tensorflow_text_available
def is_keras_nlp_available():
    """Return `is_tensorflow_text_available() and _keras_nlp_available`."""
    return is_tensorflow_text_available() and _keras_nlp_available
def is_in_notebook():
    """
    Best-effort detection of a notebook kernel.

    Returns:
        `True` only if IPython is already loaded, the active shell's config contains "IPKernelApp", the
        `VSCODE_PID` environment variable is not set and, when `DATABRICKS_RUNTIME_VERSION` is set, that runtime
        is not older than 11.0. Any `AttributeError`, `ImportError` or `KeyError` raised along the way is
        treated as "not a notebook" and yields `False`.
    """
    try:
        # Looked up in `sys.modules` rather than imported: a KeyError means IPython was never loaded.
        get_ipython = sys.modules["IPython"].get_ipython
        if "IPKernelApp" not in get_ipython().config:
            raise ImportError("console")
        if "VSCODE_PID" in os.environ:
            raise ImportError("vscode")

        runtime_version = os.environ.get("DATABRICKS_RUNTIME_VERSION")
        if runtime_version is not None:
            # Compare the major component numerically: a plain string comparison ranks "9.1" above "11.0".
            # Non-numeric values fall back to the string comparison so the check never raises.
            major = runtime_version.split(".", 1)[0]
            too_old = int(major) < 11 if major.isdecimal() else runtime_version < "11.0"
            if too_old:
                raise ImportError("databricks")

        return importlib.util.find_spec("IPython") is not None
    except (AttributeError, ImportError, KeyError):
        return False
def is_pytorch_quantization_available():
    """Return the module-level `_pytorch_quantization_available` value."""
    return _pytorch_quantization_available
def is_tensorflow_probability_available():
    """Return the module-level `_tensorflow_probability_available` value."""
    return _tensorflow_probability_available
def is_pandas_available():
    """Return the module-level `_pandas_available` value."""
    return _pandas_available
def is_sagemaker_dp_enabled():
    """
    Report whether SageMaker distributed data parallelism is requested and usable.

    Reads the `SM_FRAMEWORK_PARAMS` environment variable (default "{}") as JSON. Returns `False` when it is not
    valid JSON or when its "sagemaker_distributed_dataparallel_enabled" entry is missing or falsy; otherwise
    returns `_smdistributed_available`.
    """
    raw_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
    try:
        framework_params = json.loads(raw_params)
    except json.JSONDecodeError:
        return False

    if not framework_params.get("sagemaker_distributed_dataparallel_enabled", False):
        return False
    return _smdistributed_available
def is_sagemaker_mp_enabled():
    """
    Report whether SageMaker model parallelism is requested and usable.

    Two JSON environment variables are consulted, each defaulting to "{}":
    `SM_HP_MP_PARAMETERS` must contain a "partitions" key, and `SM_FRAMEWORK_PARAMS` must have a truthy
    "sagemaker_mpi_enabled" entry. Invalid JSON or a failed check yields `False`; otherwise the function
    returns `_smdistributed_available`.
    """
    try:
        smp_options = json.loads(os.getenv("SM_HP_MP_PARAMETERS", "{}"))
    except json.JSONDecodeError:
        return False
    if "partitions" not in smp_options:
        return False

    try:
        mpi_options = json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}"))
    except json.JSONDecodeError:
        return False
    if not mpi_options.get("sagemaker_mpi_enabled", False):
        return False

    return _smdistributed_available
def is_training_run_on_sagemaker():
    """Return `True` when the `SAGEMAKER_JOB_NAME` environment variable is set."""
    job_name = os.environ.get("SAGEMAKER_JOB_NAME")
    return job_name is not None
def is_soundfile_availble():
    """Return the module-level `_soundfile_available` value."""
    return _soundfile_available
def is_timm_available():
    """Return the module-level `_timm_available` value."""
    return _timm_available
def is_natten_available():
    """Return the module-level `_natten_available` value."""
    return _natten_available
def is_nltk_available():
    """Return the module-level `_nltk_available` value."""
    return _nltk_available
def is_torchaudio_available():
    """Return the module-level `_torchaudio_available` value."""
    return _torchaudio_available
def is_speech_available():
    """Return the module-level `_torchaudio_available` value (the same flag `is_torchaudio_available` reads)."""
    return _torchaudio_available
def is_phonemizer_available():
    """Return the module-level `_phonemizer_available` value."""
    return _phonemizer_available
def torch_only_method(fn):
    """
    Decorator that makes `fn` raise `ImportError` when called while torch is unavailable.

    Args:
        fn: The callable to guard.

    Returns:
        A wrapper that forwards all arguments to `fn` when `_torch_available` is truthy and raises
        `ImportError` otherwise. The wrapper keeps `fn`'s name, docstring and other metadata.
    """
    # Imported locally so the decorator does not depend on how the module imports `functools`.
    import functools

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # The flag is read at call time, not at decoration time.
        if not _torch_available:
            raise ImportError(
                "You need to install pytorch to use this method or class, "
                "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
            )
        return fn(*args, **kwargs)

    return wrapper
def is_ccl_available():
    """Return the module-level `_is_ccl_available` value."""
    return _is_ccl_available
def is_decord_available():
    """Return the module-level `_decord_available` value."""
    return _decord_available
def is_sudachi_available():
    """Return the module-level `_sudachipy_available` value."""
    return _sudachipy_available
def get_sudachi_version():
    """Return the module-level `_sudachipy_version` value."""
    return _sudachipy_version
def is_sudachi_projection_available():
    """Return whether sudachi is available with `_sudachipy_version` of at least 0.6.8."""
    if is_sudachi_available():
        minimum = version.parse("0.6.8")
        return version.parse(_sudachipy_version) >= minimum
    return False
def is_jumanpp_available():
    """Return `True` when the `rhoknp` module can be located and a `jumanpp` executable is found on PATH."""
    if importlib.util.find_spec("rhoknp") is None:
        return False
    return shutil.which("jumanpp") is not None
def is_cython_available():
    """Return `True` when the `pyximport` module can be located by `find_spec`."""
    pyximport_spec = importlib.util.find_spec("pyximport")
    return pyximport_spec is not None
def is_jieba_available():
    """Return the module-level `_jieba_available` value."""
    return _jieba_available
def is_jinja_available():
    """Return the module-level `_jinja_available` value."""
    return _jinja_available
def is_mlx_available():
    """Return the module-level `_mlx_available` value."""
    return _mlx_available
# Import-error message template for OpenCV. `{0}` is a positional format placeholder (presumably the name of the
# object that needs the library -- confirm against the code that formats it).
CV2_IMPORT_ERROR = """
{0} requires the OpenCV library but it was not found in your environment. You can install it with:
pip install opencv-python
Please note that you may need to restart your runtime after installation.
"""
# Import-error message template for the Datasets library; also warns about local `datasets` folders/files that
# could shadow the real package.
DATASETS_IMPORT_ERROR = """
{0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with:
pip install datasets
In a notebook or a colab, you can install it by executing a cell with
!pip install datasets
then restarting your kernel.
Note that if you have a local folder named `datasets` or a local python file named `datasets.py` in your current
working directory, python may try to import this instead of the 🤗 Datasets library. You should rename this folder or
that python file if that's the case. Please note that you may need to restart your runtime after installation.
"""
# Import-error message template for the Tokenizers library. The literal was left unterminated, which swallowed the
# following definitions; the body below restores a complete message in the same shape as DATASETS_IMPORT_ERROR.
# NOTE(review): wording reconstructed from the sibling templates -- confirm against the upstream text.
TOKENIZERS_IMPORT_ERROR = """
{0} requires the 🤗 Tokenizers library but it was not found in your environment. You can install it with:
pip install tokenizers
In a notebook or a colab, you can install it by executing a cell with
!pip install tokenizers
Please note that you may need to restart your runtime after installation.
"""
# Format string for the import-error message shown when SentencePiece is missing; `{0}` is a positional placeholder.
SENTENCEPIECE_IMPORT_ERROR = """
{0} requires the SentencePiece library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/google/sentencepiece
that match your environment. Please note that you may need to restart your runtime after installation.
"""
# Format string for the import-error message shown when protobuf is missing.
PROTOBUF_IMPORT_ERROR = """
{0} requires the protobuf library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python
that match your environment. Please note that you may need to restart your runtime after installation.
"""
# Format string for the import-error message shown when faiss is missing.
FAISS_IMPORT_ERROR = """
{0} requires the faiss library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/facebookresearch/faiss/blob/master/INSTALL.md and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""
# Format string for the import-error message shown when PyTorch is missing.
PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.
"""
# Format string for the import-error message shown when Torchvision is missing.
TORCHVISION_IMPORT_ERROR = """
{0} requires the Torchvision library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.
"""
# Format string for the case where PyTorch is missing but a TensorFlow installation was found; it points the user
# at the "TF"-prefixed class (`{0}` appears twice in the text).
PYTORCH_IMPORT_ERROR_WITH_TF = """
{0} requires the PyTorch library but it was not found in your environment.
However, we were able to find a TensorFlow installation. TensorFlow classes begin
with "TF", but are otherwise identically named to our PyTorch classes. This
means that the TF equivalent of the class you tried to import would be "TF{0}".
If you want to use TensorFlow, please use TF classes instead!
If you really do want to use PyTorch please go to
https://pytorch.org/get-started/locally/ and follow the instructions that
match your environment.
"""
# Format string for the case where TensorFlow is missing but a PyTorch installation was found. The literal was cut
# off mid-sentence and never closed; the tail below completes it as the mirror image of
# PYTORCH_IMPORT_ERROR_WITH_TF.
# NOTE(review): the last four lines are reconstructed -- confirm against the upstream text.
TF_IMPORT_ERROR_WITH_PYTORCH = """
{0} requires the TensorFlow library but it was not found in your environment.
However, we were able to find a PyTorch installation. PyTorch classes do not begin
with "TF", but are otherwise identically named to our TF classes.
If you want to use PyTorch, please use those classes instead!
If you really do want to use TensorFlow, please follow the instructions on the
installation page https://www.tensorflow.org/install that match your environment.
"""
# Import-error message templates, one per optional dependency. In each, `{0}` is a positional format placeholder.

# Beautiful Soup
BS4_IMPORT_ERROR = """
{0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip:
`pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation.
"""
# scikit-learn
SKLEARN_IMPORT_ERROR = """
{0} requires the scikit-learn library but it was not found in your environment. You can install it with:
pip install -U scikit-learn
In a notebook or a colab, you can install it by executing a cell with
!pip install -U scikit-learn
Please note that you may need to restart your runtime after installation.
"""
# TensorFlow
TENSORFLOW_IMPORT_ERROR = """
{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the
installation page: https://www.tensorflow.org/install and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.
"""
# detectron2
DETECTRON2_IMPORT_ERROR = """
{0} requires the detectron2 library but it was not found in your environment. Checkout the instructions on the
installation page: https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""
# Flax
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
installation page: https://github.com/google/flax and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.
"""
# ftfy
FTFY_IMPORT_ERROR = """
{0} requires the ftfy library but it was not found in your environment. Checkout the instructions on the
installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""
# python-Levenshtein
LEVENSHTEIN_IMPORT_ERROR = """
{0} requires the python-Levenshtein library but it was not found in your environment. You can install it with pip: `pip
install python-Levenshtein`. Please note that you may need to restart your runtime after installation.
"""
# g2p-en
G2P_EN_IMPORT_ERROR = """
{0} requires the g2p-en library but it was not found in your environment. You can install it with pip:
`pip install g2p-en`. Please note that you may need to restart your runtime after installation.
"""
# Import-error message template for pytorch-quantization. The template was empty (a single newline), so the error
# built from it carried no information; it now follows the same shape as the sibling templates.
# NOTE(review): the install command is reconstructed -- confirm the package index URL against the upstream text.
PYTORCH_QUANTIZATION_IMPORT_ERROR = """
{0} requires the pytorch-quantization library but it was not found in your environment. You can install it with pip:
`pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com`
Please note that you may need to restart your runtime after installation.
"""
# Import-error message templates, one per optional dependency. In each, `{0}` is a positional format placeholder.

# tensorflow_probability
TENSORFLOW_PROBABILITY_IMPORT_ERROR = """
{0} requires the tensorflow_probability library but it was not found in your environment. You can install it with pip as
explained here: https://github.com/tensorflow/probability. Please note that you may need to restart your runtime after installation.
"""
# tensorflow_text
TENSORFLOW_TEXT_IMPORT_ERROR = """
{0} requires the tensorflow_text library but it was not found in your environment. You can install it with pip as
explained here: https://www.tensorflow.org/text/guide/tf_text_intro.
Please note that you may need to restart your runtime after installation.
"""
# pandas
PANDAS_IMPORT_ERROR = """
{0} requires the pandas library but it was not found in your environment. You can install it with pip as
explained here: https://pandas.pydata.org/pandas-docs/stable/getting_started/install.html.
Please note that you may need to restart your runtime after installation.
"""
# phonemizer
PHONEMIZER_IMPORT_ERROR = """
{0} requires the phonemizer library but it was not found in your environment. You can install it with pip:
`pip install phonemizer`. Please note that you may need to restart your runtime after installation.
"""
# sacremoses
SACREMOSES_IMPORT_ERROR = """
{0} requires the sacremoses library but it was not found in your environment. You can install it with pip:
`pip install sacremoses`. Please note that you may need to restart your runtime after installation.
"""
# scipy
SCIPY_IMPORT_ERROR = """
{0} requires the scipy library but it was not found in your environment. You can install it with pip:
`pip install scipy`. Please note that you may need to restart your runtime after installation.
"""
# torchaudio (registered under the "speech" backend key)
SPEECH_IMPORT_ERROR = """
{0} requires the torchaudio library but it was not found in your environment. You can install it with pip:
`pip install torchaudio`. Please note that you may need to restart your runtime after installation.
"""
# timm
TIMM_IMPORT_ERROR = """
{0} requires the timm library but it was not found in your environment. You can install it with pip:
`pip install timm`. Please note that you may need to restart your runtime after installation.
"""
# natten
NATTEN_IMPORT_ERROR = """
{0} requires the natten library but it was not found in your environment. You can install it by referring to:
shi-labs.com/natten . You can also install it with pip (may take longer to build):
`pip install natten`. Please note that you may need to restart your runtime after installation.
"""
# Import-error message template for NLTK. The literal stopped after "by referring to:" and was never closed; the
# missing reference and the closing quotes are restored here.
# NOTE(review): the URL line is reconstructed -- confirm against the upstream text.
NLTK_IMPORT_ERROR = """
{0} requires the NLTK library but it was not found in your environment. You can install it by referring to:
https://www.nltk.org/install.html. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore: the definitions below are string templates for import-error messages.
# Import-error message template for the vision backend (PIL / Pillow).
VISION_IMPORT_ERROR = """
{0} requires the PIL library but it was not found in your environment. You can install it with pip:
`pip install pillow`. Please note that you may need to restart your runtime after installation.
"""
# Import-error message template for PyTesseract.
PYTESSERACT_IMPORT_ERROR = """
{0} requires the PyTesseract library but it was not found in your environment. You can install it with pip:
`pip install pytesseract`. Please note that you may need to restart your runtime after installation.
"""
# Import-error message template for pyctcdecode.
PYCTCDECODE_IMPORT_ERROR = """
{0} requires the pyctcdecode library but it was not found in your environment. You can install it with pip:
`pip install pyctcdecode`. Please note that you may need to restart your runtime after installation.
"""
# Import-error message template for accelerate.
# Written as an f-string so the minimum version is filled in when the module is imported, while `{{0}}` leaves a
# literal `{0}` positional placeholder like the sibling templates. As a plain string, the named field
# `{ACCELERATE_MIN_VERSION}` would make a positional-only `.format(name)` call raise KeyError.
# NOTE(review): assumes the template is formatted positionally like its siblings -- confirm against the caller.
ACCELERATE_IMPORT_ERROR = f"""
{{0}} requires the accelerate library >= {ACCELERATE_MIN_VERSION} it was not found in your environment.
You can install or update it with pip: `pip install --upgrade accelerate`. Please note that you may need to restart your
runtime after installation.
"""
# Import-error message template for torch ccl (oneccl_bind_pt).
CCL_IMPORT_ERROR = """
{0} requires the torch ccl library but it was not found in your environment. You can install it with pip:
`pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable`
Please note that you may need to restart your runtime after installation.
"""
# Import-error message template for essentia.
ESSENTIA_IMPORT_ERROR = """
{0} requires essentia library. But that was not found in your environment. You can install them with pip:
`pip install essentia==2.1b6.dev1034`
Please note that you may need to restart your runtime after installation.
"""
# Import-error message template for librosa.
LIBROSA_IMPORT_ERROR = """
{0} requires thes librosa library. But that was not found in your environment. You can install them with pip:
`pip install librosa`
Please note that you may need to restart your runtime after installation.
"""
# Import-error message template for pretty_midi.
PRETTY_MIDI_IMPORT_ERROR = """
{0} requires thes pretty_midi library. But that was not found in your environment. You can install them with pip:
`pip install pretty_midi`
Please note that you may need to restart your runtime after installation.
"""
# Import-error message template for decord.
DECORD_IMPORT_ERROR = """
{0} requires the decord library but it was not found in your environment. You can install it with pip: `pip install
decord`. Please note that you may need to restart your runtime after installation.
"""
# Import-error message template for Cython.
CYTHON_IMPORT_ERROR = """
{0} requires the Cython library but it was not found in your environment. You can install it with pip: `pip install
Cython`. Please note that you may need to restart your runtime after installation.
"""
# Import-error message template for jieba.
JIEBA_IMPORT_ERROR = """
{0} requires the jieba library but it was not found in your environment. You can install it with pip: `pip install
jieba`. Please note that you may need to restart your runtime after installation.
"""
# Import-error message template for peft. The literal was opened but never given a body or closed; it is restored
# here in the same shape as the neighbouring pip-installable templates.
# NOTE(review): wording reconstructed from the sibling templates -- confirm against the upstream text.
PEFT_IMPORT_ERROR = """
{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install
peft`. Please note that you may need to restart your runtime after installation.
"""
# Import-error message template for jinja. It is referenced by the mapping below, but no definition is visible in
# this part of the file, so one is supplied in the same shape as the sibling templates.
# NOTE(review): drop this definition if the template is already defined elsewhere in the module.
JINJA_IMPORT_ERROR = """
{0} requires the jinja library but it was not found in your environment. You can install it with pip: `pip install
jinja2`. Please note that you may need to restart your runtime after installation.
"""

# Maps a backend key to a pair `(availability check, import-error message template)`. The check is stored as a
# callable, not called here. The closing parenthesis of the `OrderedDict(...)` call was missing and is restored.
BACKENDS_MAPPING = OrderedDict(
    [
        ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
        ("cv2", (is_cv2_available, CV2_IMPORT_ERROR)),
        ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
        ("detectron2", (is_detectron2_available, DETECTRON2_IMPORT_ERROR)),
        ("essentia", (is_essentia_available, ESSENTIA_IMPORT_ERROR)),
        ("faiss", (is_faiss_available, FAISS_IMPORT_ERROR)),
        ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
        ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
        ("g2p_en", (is_g2p_en_available, G2P_EN_IMPORT_ERROR)),
        ("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)),
        ("phonemizer", (is_phonemizer_available, PHONEMIZER_IMPORT_ERROR)),
        ("pretty_midi", (is_pretty_midi_available, PRETTY_MIDI_IMPORT_ERROR)),
        ("levenshtein", (is_levenshtein_available, LEVENSHTEIN_IMPORT_ERROR)),
        ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
        ("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
        ("pyctcdecode", (is_pyctcdecode_available, PYCTCDECODE_IMPORT_ERROR)),
        ("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)),
        ("sacremoses", (is_sacremoses_available, SACREMOSES_IMPORT_ERROR)),
        ("pytorch_quantization", (is_pytorch_quantization_available, PYTORCH_QUANTIZATION_IMPORT_ERROR)),
        ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
        ("sklearn", (is_sklearn_available, SKLEARN_IMPORT_ERROR)),
        ("speech", (is_speech_available, SPEECH_IMPORT_ERROR)),
        ("tensorflow_probability", (is_tensorflow_probability_available, TENSORFLOW_PROBABILITY_IMPORT_ERROR)),
        ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
        ("tensorflow_text", (is_tensorflow_text_available, TENSORFLOW_TEXT_IMPORT_ERROR)),
        ("timm", (is_timm_available, TIMM_IMPORT_ERROR)),
        ("natten", (is_natten_available, NATTEN_IMPORT_ERROR)),
        ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)),
        ("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
        ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
        ("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)),
        ("vision", (is_vision_available, VISION_IMPORT_ERROR)),
        ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
        ("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
        ("oneccl_bind_pt", (is_ccl_available, CCL_IMPORT_ERROR)),
        ("decord", (is_decord_available, DECORD_IMPORT_ERROR)),
        ("cython", (is_cython_available, CYTHON_IMPORT_ERROR)),
        ("jieba", (is_jieba_available, JIEBA_IMPORT_ERROR)),
        ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
        ("jinja", (is_jinja_available, JINJA_IMPORT_ERROR)),
    ]
)
class DummyObject(type):
    """
    Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
    `requires_backend` each time a user tries to access any method of that class.
    """

    def __getattribute__(cls, key):
        # Underscore-prefixed names resolve normally, with the single exception of `_from_config`.
        is_private = key.startswith("_")
        if not is_private or key == "_from_config":
            # Every other lookup goes through the backend check; if the check returns, the attribute lookup
            # itself yields None.
            requires_backends(cls, cls._backends)
            return None
        return super().__getattribute__(key)
def is_torch_fx_proxy(x):
    """Return whether `x` is an instance of `torch.fx.Proxy`; `False` when torch.fx is reported unavailable."""
    if not is_torch_fx_available():
        return False

    import torch.fx

    return isinstance(x, torch.fx.Proxy)
class _LazyModule(ModuleType):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
super().__init__(name)
self._modules = set(import_structure.keys())
self._class_to_module = {}
for key, values in import_structure.items():
for value in values:
self._class_to_module[value] = key
self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
self.__file__ = module_file
self.__spec__ = module_spec
self.__path__ = [os.path.dirname(module_file)]
self._objects = {} if extra_objects is None else extra_objects
self._name = name
self._import_structure = import_structure
def __dir__(self):
result = super().__dir__()
for attr in self.__all__:
if attr not in result:
result.append(attr)
return result
def __getattr__(self, name: str) -> Any:
if name in self._objects:
return self._objects[name]
if name in self._modules:
value = self._get_module(name)
elif name in self._class_to_module.keys():
module = self._get_module(self._class_to_module[name])
value = getattr(module, name)
else:
raise AttributeError(f"module {self.__name__} has no attribute {name}")
setattr(self, name, value)
return value
def _get_module(self, module_name: str):
try:
return importlib.import_module("." + module_name, self.__name__)
except Exception as e:
raise RuntimeError(
f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
f" traceback):\n{e}"
) from e
def __reduce__(self):
return (self.__class__, (self._name, self.__file__, self._import_structure))
class OptionalDependencyNotAvailable(BaseException):
    """Internal error class used to signal that an optional dependency was not found."""
def direct_transformers_import(path: str, file="__init__.py") -> ModuleType:
    """Import transformers directly from a source location.

    Args:
        path (`str`): The path to the source directory.
        file (`str`, optional): The file name to join with `path`. Defaults to "__init__.py".

    Returns:
        `ModuleType`: The module registered in `sys.modules` under "transformers" once the file has been executed.
    """
    module_name = "transformers"
    init_location = os.path.join(path, file)
    spec = importlib.util.spec_from_file_location(module_name, init_location, submodule_search_locations=[path])
    fresh_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(fresh_module)
    # The result is read back from `sys.modules`, not taken from `fresh_module`; this raises KeyError unless
    # something has registered "transformers" there (presumably the executed file itself -- confirm).
    return sys.modules[module_name]