Transformers Source Code Analysis (6)
.\image_utils.py
import base64
import os
from io import BytesIO
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import requests
from packaging import version
from .utils import (
ExplicitEnum,
is_jax_tensor,
is_tf_tensor,
is_torch_available,
is_torch_tensor,
is_vision_available,
logging,
requires_backends,
to_numpy,
)
from .utils.constants import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
)
if is_vision_available():
import PIL.Image
import PIL.ImageOps
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
PILImageResampling = PIL.Image.Resampling
else:
PILImageResampling = PIL.Image
if TYPE_CHECKING:
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
ImageInput = Union[
"PIL.Image.Image", np.ndarray, "torch.Tensor", List["PIL.Image.Image"], List[np.ndarray], List["torch.Tensor"]
]
class ChannelDimension(ExplicitEnum):
FIRST = "channels_first"
LAST = "channels_last"
class AnnotationFormat(ExplicitEnum):
COCO_DETECTION = "coco_detection"
COCO_PANOPTIC = "coco_panoptic"
class AnnotionFormat(ExplicitEnum):
COCO_DETECTION = AnnotationFormat.COCO_DETECTION.value
COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value
AnnotationType = Dict[str, Union[int, str, List[Dict]]]
def is_pil_image(img):
return is_vision_available() and isinstance(img, PIL.Image.Image)
def is_valid_image(img):
return (
(is_vision_available() and isinstance(img, PIL.Image.Image))
or isinstance(img, np.ndarray)
or is_torch_tensor(img)
or is_tf_tensor(img)
or is_jax_tensor(img)
)
def valid_images(imgs):
if isinstance(imgs, (list, tuple)):
for img in imgs:
if not valid_images(img):
return False
elif not is_valid_image(imgs):
return False
return True
def is_batched(img):
if isinstance(img, (list, tuple)):
return is_valid_image(img[0])
return False
def is_scaled_image(image: np.ndarray) -> bool:
if image.dtype == np.uint8:
return False
return np.min(image) >= 0 and np.max(image) <= 1
def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
if is_batched(images):
return images
if isinstance(images, PIL.Image.Image):
return [images]
if is_valid_image(images):
if images.ndim == expected_ndims + 1:
images = list(images)
elif images.ndim == expected_ndims:
images = [images]
else:
raise ValueError(
f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
f" {images.ndim} dimensions."
)
return images
raise ValueError(
"Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or "
f"jax.ndarray, but got {type(images)}."
)
def to_numpy_array(img) -> np.ndarray:
if not is_valid_image(img):
raise ValueError(f"Invalid image type: {type(img)}")
if is_vision_available() and isinstance(img, PIL.Image.Image):
return np.array(img)
return to_numpy(img)
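# A minimal usage sketch (illustrative, not from the transformers source) of how the helpers
# above normalize inputs; the shapes and dtypes below are arbitrary assumptions.
def _demo_make_list_of_images():
    single = np.zeros((224, 224, 3), dtype=np.uint8)    # one HWC image
    batch = np.zeros((2, 224, 224, 3), dtype=np.uint8)  # a batch of two HWC images
    assert len(make_list_of_images(single)) == 1        # wrapped into a one-element list
    assert len(make_list_of_images(batch)) == 2         # the leading batch axis is unpacked
    assert not is_scaled_image(single)                  # uint8 is never treated as already scaled
    assert is_scaled_image(single.astype(np.float32) / 255.0)  # floats in [0, 1] are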
def infer_channel_dimension_format(
image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
) -> ChannelDimension:
"""
Infers the channel dimension format of `image`.
Args:
image (`np.ndarray`):
The image to infer the channel dimension of.
num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
The number of channels of the image.
Returns:
The channel dimension of the image.
"""
num_channels = num_channels if num_channels is not None else (1, 3)
num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels
if image.ndim == 3:
first_dim, last_dim = 0, 2
elif image.ndim == 4:
first_dim, last_dim = 1, 3
else:
raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
if image.shape[first_dim] in num_channels:
return ChannelDimension.FIRST
elif image.shape[last_dim] in num_channels:
return ChannelDimension.LAST
raise ValueError("Unable to infer channel dimension format")
def get_channel_dimension_axis(
image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
) -> int:
"""
Returns the channel dimension axis of the image.
Args:
image (`np.ndarray`):
The image to get the channel dimension axis of.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the image. If `None`, will infer the channel dimension from the image.
Returns:
The channel dimension axis of the image.
"""
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
if input_data_format == ChannelDimension.FIRST:
return image.ndim - 3
elif input_data_format == ChannelDimension.LAST:
return image.ndim - 1
raise ValueError(f"Unsupported data format: {input_data_format}")
def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
"""
Returns the (height, width) dimensions of the image.
Args:
image (`np.ndarray`):
The image to get the dimensions of.
channel_dim (`ChannelDimension`, *optional*):
Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.
Returns:
A tuple of the image's height and width.
"""
if channel_dim is None:
channel_dim = infer_channel_dimension_format(image)
if channel_dim == ChannelDimension.FIRST:
return image.shape[-2], image.shape[-1]
elif channel_dim == ChannelDimension.LAST:
return image.shape[-3], image.shape[-2]
else:
raise ValueError(f"Unsupported data format: {channel_dim}")
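# A short sketch (illustrative, not from the transformers source) of the channel-format
# inference on unambiguous shapes; the arrays are dummy assumptions.
def _demo_channel_dimension():
    chw = np.zeros((3, 32, 64))  # channels-first
    hwc = np.zeros((32, 64, 3))  # channels-last
    assert infer_channel_dimension_format(chw) is ChannelDimension.FIRST
    assert infer_channel_dimension_format(hwc) is ChannelDimension.LAST
    # get_image_size relies on the same inference and always returns (height, width)
    assert get_image_size(chw) == (32, 64)
    assert get_image_size(hwc) == (32, 64)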
def is_valid_annotation_coco_detection(annotation: Dict[str, Union[List, Tuple]]) -> bool:
"""
Checks if the given annotation is a valid COCO detection annotation.
Args:
annotation (`Dict[str, Union[List, Tuple]]`):
The annotation dictionary to validate.
Returns:
`True` if the annotation is valid, `False` otherwise.
"""
if (
isinstance(annotation, dict)
and "image_id" in annotation
and "annotations" in annotation
and isinstance(annotation["annotations"], (list, tuple))
and (
len(annotation["annotations"]) == 0 or isinstance(annotation["annotations"][0], dict)
)
):
return True
return False
def is_valid_annotation_coco_panoptic(annotation: Dict[str, Union[List, Tuple]]) -> bool:
"""
Checks if the given annotation is a valid COCO panoptic segmentation annotation.
Args:
annotation (`Dict[str, Union[List, Tuple]]`):
The annotation dictionary to validate.
Returns:
`True` if the annotation is valid, `False` otherwise.
"""
if (
isinstance(annotation, dict)
and "image_id" in annotation
and "segments_info" in annotation
and "file_name" in annotation
and isinstance(annotation["segments_info"], (list, tuple))
and (
len(annotation["segments_info"]) == 0 or isinstance(annotation["segments_info"][0], dict)
)
):
return True
return False
def valid_coco_detection_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
"""
Checks if all annotations in the given iterable are valid COCO detection annotations.
Args:
annotations (`Iterable[Dict[str, Union[List, Tuple]]]`):
The iterable of annotations to validate.
Returns:
`True` if all annotations are valid, `False` otherwise.
"""
return all(is_valid_annotation_coco_detection(ann) for ann in annotations)
def valid_coco_panoptic_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
return all(is_valid_annotation_coco_panoptic(ann) for ann in annotations)
def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image":
"""
Loads `image` into a PIL Image.
Args:
image (`str` or `PIL.Image.Image`):
The image to convert to the PIL Image format.
timeout (`float`, *optional*):
The timeout value in seconds for the URL request.
Returns:
`PIL.Image.Image`: A PIL Image.
"""
requires_backends(load_image, ["vision"])
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
image = PIL.Image.open(requests.get(image, stream=True, timeout=timeout).raw)
elif os.path.isfile(image):
image = PIL.Image.open(image)
else:
if image.startswith("data:image/"):
image = image.split(",")[1]
try:
b64 = base64.b64decode(image, validate=True)
image = PIL.Image.open(BytesIO(b64))
except Exception as e:
raise ValueError(
f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
)
elif isinstance(image, PIL.Image.Image):
image = image
else:
raise ValueError(
"Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image."
)
image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")
return image
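# An illustrative sketch (not from the transformers source) that exercises the base64 branch of
# `load_image` without touching the network; it assumes Pillow is installed.
def _demo_load_image():
    buf = BytesIO()
    PIL.Image.new("L", (4, 4)).save(buf, format="PNG")  # build a tiny grayscale PNG in memory
    data_uri = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()
    img = load_image(data_uri)
    assert img.mode == "RGB"  # load_image always EXIF-transposes and converts to RGB
    # Passing a PIL image straight through also works and just normalizes it:
    assert load_image(img).mode == "RGB"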
def validate_preprocess_arguments(
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
size_divisibility: Optional[int] = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[Dict[str, int]] = None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: Optional["PILImageResampling"] = None,
):
"""
Checks the validity of the arguments typically used in an `ImageProcessor`'s `preprocess` method.
Raises a `ValueError` if an incompatible combination of arguments is found.
Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`, sometimes `size_divisibility`, and sometimes `size`.
New models and processors should try to follow the existing argument conventions when possible.
"""
if do_rescale and rescale_factor is None:
raise ValueError("rescale_factor must be specified if do_rescale is True.")
if do_pad and size_divisibility is None:
    raise ValueError(
        "Depending on the model, size_divisibility, size_divisor, pad_size or size must be specified if do_pad is True."
    )
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("image_mean and image_std must both be specified if do_normalize is True.")
if do_center_crop and crop_size is None:
raise ValueError("crop_size must be specified if do_center_crop is True.")
if do_resize and (size is None or resample is None):
raise ValueError("size and resample must be specified if do_resize is True.")
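# A small sketch (illustrative, not from the transformers source) of the kind of argument
# incompatibilities the validator above catches; the values are arbitrary assumptions and
# Pillow is assumed to be installed for `PILImageResampling`.
def _demo_validate_preprocess_arguments():
    # Consistent arguments pass silently:
    validate_preprocess_arguments(
        do_rescale=True,
        rescale_factor=1 / 255,
        do_normalize=True,
        image_mean=IMAGENET_DEFAULT_MEAN,
        image_std=IMAGENET_DEFAULT_STD,
    )
    # Asking for a resize without a target size raises:
    try:
        validate_preprocess_arguments(do_resize=True, size=None, resample=PILImageResampling.BILINEAR)
    except ValueError as err:
        print(err)  # "size and resample must be specified if do_resize is True."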
class ImageFeatureExtractionMixin:
"""
Mixin that contains utilities to prepare image features.
"""
def _ensure_format_supported(self, image):
"""
Ensures that the given image is of a supported format, raising a ValueError otherwise.
Args:
image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
The image to check.
"""
if not isinstance(image, (PIL.Image.Image, np.ndarray)) and not is_torch_tensor(image):
raise ValueError(
f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.array` and "
"`torch.Tensor` are."
)
def to_pil_image(self, image, rescale=None):
"""
Converts `image` to a PIL Image. Optionally rescales it and, if needed, puts the channel dimension back as the last axis.
Args:
image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
The image to convert to the PIL Image format.
rescale (`bool`, *optional*):
Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default to `True` if the image type is a floating type, `False` otherwise.
"""
self._ensure_format_supported(image)
if is_torch_tensor(image):
image = image.numpy()
if isinstance(image, np.ndarray):
if rescale is None:
rescale = isinstance(image.flat[0], np.floating)
if image.ndim == 3 and image.shape[0] in [1, 3]:
image = image.transpose(1, 2, 0)
if rescale:
image = image * 255
image = image.astype(np.uint8)
return PIL.Image.fromarray(image)
return image
def convert_rgb(self, image):
"""
Converts a `PIL.Image.Image` to RGB format.
Args:
image (`PIL.Image.Image`):
The image to convert.
"""
self._ensure_format_supported(image)
if not isinstance(image, PIL.Image.Image):
return image
return image.convert("RGB")
def rescale(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray:
"""
Rescales a numpy image by `scale` amount.
Args:
image (`numpy.ndarray`):
The image to rescale.
scale (`float` or `int`):
The scale factor to apply.
Returns:
`numpy.ndarray`: The rescaled image.
"""
self._ensure_format_supported(image)
return image * scale
def to_numpy_array(self, image, rescale=None, channel_first=True):
"""
Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first
dimension.
Args:
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
The image to convert to a NumPy array.
rescale (`bool`, *optional*):
Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will
default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise.
channel_first (`bool`, *optional*, defaults to `True`):
Whether or not to permute the dimensions of the image to put the channel dimension first.
"""
self._ensure_format_supported(image)
if isinstance(image, PIL.Image.Image):
image = np.array(image)
if is_torch_tensor(image):
image = image.numpy()
rescale = isinstance(image.flat[0], np.integer) if rescale is None else rescale
if rescale:
image = self.rescale(image.astype(np.float32), 1 / 255.0)
if channel_first and image.ndim == 3:
image = image.transpose(2, 0, 1)
return image
def expand_dims(self, image):
"""
Expands 2-dimensional `image` to 3 dimensions.
Args:
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
The image to expand.
"""
self._ensure_format_supported(image)
if isinstance(image, PIL.Image.Image):
return image
if is_torch_tensor(image):
image = image.unsqueeze(0)
else:
image = np.expand_dims(image, axis=0)
return image
def normalize(self, image, mean, std, rescale=False):
"""
Normalizes `image` with `mean` and `std`. Note that this will trigger a conversion of `image` to a NumPy array
if it's a PIL Image.
Args:
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
The image to normalize.
mean (`List[float]` or `np.ndarray` or `torch.Tensor`):
The mean (per channel) to use for normalization.
std (`List[float]` or `np.ndarray` or `torch.Tensor`):
The standard deviation (per channel) to use for normalization.
rescale (`bool`, *optional*, defaults to `False`):
Whether or not to rescale the image to be between 0 and 1. If a PIL image is provided, scaling will
happen automatically.
"""
self._ensure_format_supported(image)
if isinstance(image, PIL.Image.Image):
image = self.to_numpy_array(image, rescale=True)
elif rescale:
if isinstance(image, np.ndarray):
image = self.rescale(image.astype(np.float32), 1 / 255.0)
elif is_torch_tensor(image):
image = self.rescale(image.float(), 1 / 255.0)
if isinstance(image, np.ndarray):
if not isinstance(mean, np.ndarray):
mean = np.array(mean).astype(image.dtype)
if not isinstance(std, np.ndarray):
std = np.array(std).astype(image.dtype)
elif is_torch_tensor(image):
import torch
if not isinstance(mean, torch.Tensor):
mean = torch.tensor(mean)
if not isinstance(std, torch.Tensor):
std = torch.tensor(std)
if image.ndim == 3 and image.shape[0] in [1, 3]:
return (image - mean[:, None, None]) / std[:, None, None]
else:
return (image - mean) / std
def flip_channel_order(self, image):
"""
Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of
`image` to a NumPy array if it's a PIL Image.
Args:
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
The image whose color channels to flip. If `np.ndarray` or `torch.Tensor`, the channel dimension should
be first.
"""
self._ensure_format_supported(image)
if isinstance(image, PIL.Image.Image):
image = self.to_numpy_array(image)
return image[::-1, :, :]
def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None):
"""
Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees
counter clockwise around its centre.
Args:
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image` before
rotating.
angle (float or int):
The rotation angle in degrees. Positive angles are counter-clockwise.
resample (int, optional):
An optional resampling filter. Default is `PIL.Image.NEAREST`.
expand (bool or int, optional):
Optional expansion flag. If true, the output image size is expanded to contain the entire rotated image.
If integer, it specifies the desired size of the output image (tuple). Default is 0.
center (tuple of int, optional):
Optional center of rotation. Default is None, which means the center is calculated as the center of the image.
translate (tuple of int, optional):
Optional translation offset. Default is None.
fillcolor (tuple or int, optional):
Optional background color given as a single integer value or a tuple of three integers.
Returns:
`PIL.Image.Image`: A rotated `PIL.Image.Image` instance.
"""
resample = resample if resample is not None else PIL.Image.NEAREST
self._ensure_format_supported(image)
if not isinstance(image, PIL.Image.Image):
image = self.to_pil_image(image)
return image.rotate(
angle, resample=resample, expand=expand, center=center, translate=translate, fillcolor=fillcolor
)
def promote_annotation_format(annotation_format: Union[AnnotionFormat, AnnotationFormat]) -> AnnotationFormat:
return AnnotationFormat(annotation_format.value)
def validate_annotations(
annotation_format: AnnotationFormat,
supported_annotation_formats: Tuple[AnnotationFormat, ...],
annotations: List[Dict],
) -> None:
if isinstance(annotation_format, AnnotionFormat):
logger.warning_once(
f"`{annotation_format.__class__.__name__}` is deprecated and will be removed in v4.38. "
f"Please use `{AnnotationFormat.__name__}` instead."
)
annotation_format = promote_annotation_format(annotation_format)
if annotation_format not in supported_annotation_formats:
raise ValueError(f"Unsupported annotation format: {annotation_format} must be one of {supported_annotation_formats}")
if annotation_format is AnnotationFormat.COCO_DETECTION:
if not valid_coco_detection_annotations(annotations):
raise ValueError(
"Invalid COCO detection annotations. Annotations must be a dict (single image) or list of dicts "
"(batch of images) with the following keys: `image_id` and `annotations`, with the latter "
"being a list of annotations in the COCO format."
)
if annotation_format is AnnotationFormat.COCO_PANOPTIC:
if not valid_coco_panoptic_annotations(annotations):
raise ValueError(
"Invalid COCO panoptic annotations. Annotations must be a dict (single image) or list of dicts "
"(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with "
"the latter being a list of annotations in the COCO format."
)
def validate_kwargs(valid_processor_keys: List[str], captured_kwargs: List[str]):
unused_keys = set(captured_kwargs).difference(set(valid_processor_keys))
if unused_keys:
unused_key_str = ", ".join(unused_keys)
logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.")
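# An illustrative call (not from the transformers source): unrecognized kwargs are only logged,
# never raised, so a hypothetical typo such as "sizes" surfaces as a warning rather than an error.
def _demo_validate_kwargs():
    validate_kwargs(valid_processor_keys=["do_resize", "size"], captured_kwargs=["size", "sizes"])
    # -> logs "Unused or unrecognized kwargs: sizes."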
.\integrations\aqlm.py
"AQLM (Additive Quantization of Language Model) integration file"
from ..utils import is_accelerate_available, is_aqlm_available, is_torch_available
if is_torch_available():
import torch.nn as nn
def replace_with_aqlm_linear(
model,
quantization_config=None,
linear_weights_not_to_quantize=None,
current_key_name=None,
has_been_replaced=False,
):
"""
Public method that recursively replaces the Linear layers of the given model with AQLM quantized layers.
`accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
conversion has been successful or not.
Args:
model (`torch.nn.Module`):
The model to convert, can be any `torch.nn.Module` instance.
quantization_config (`AqlmConfig`):
The quantization config object that contains the quantization parameters.
linear_weights_not_to_quantize (`list[str]`, *optional*):
A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be
converted.
current_key_name (`list`, *optional*):
A list that contains the current key name. This is used for recursion and should not be passed by the user.
has_been_replaced (`bool`, *optional*):
A boolean that indicates if the conversion has been successful or not. This is used for recursion and
should not be passed by the user.
"""
if not is_aqlm_available():
raise ValueError("AQLM is not available. Please install it with `pip install aqlm[cpu,gpu]`")
if not is_accelerate_available():
raise ValueError("AQLM requires Accelerate to be installed: `pip install accelerate`")
if linear_weights_not_to_quantize is None:
linear_weights_not_to_quantize = []
from accelerate import init_empty_weights
from aqlm import QuantizedLinear
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if isinstance(module, nn.Linear):
if ".".join(current_key_name) + ".weight" not in linear_weights_not_to_quantize:
with init_empty_weights():
in_features = module.in_features
out_features = module.out_features
model._modules[name] = QuantizedLinear(
in_features,
out_features,
bias=module.bias is not None,
in_group_size=quantization_config.in_group_size,
out_group_size=quantization_config.out_group_size,
num_codebooks=quantization_config.num_codebooks,
nbits_per_codebook=quantization_config.nbits_per_codebook,
)
has_been_replaced = True
model._modules[name].source_cls = type(module)
model._modules[name].requires_grad_(False)
if len(list(module.children())) > 0:
_, has_been_replaced = replace_with_aqlm_linear(
module,
quantization_config=quantization_config,
linear_weights_not_to_quantize=linear_weights_not_to_quantize,
current_key_name=current_key_name,
has_been_replaced=has_been_replaced,
)
current_key_name.pop(-1)
return model, has_been_replaced
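# A hedged end-to-end sketch (not part of this file) of how the helper above is typically driven.
# In practice the AQLM quantizer calls it during `from_pretrained`; the checkpoint name and the
# `AqlmConfig` values below are illustrative, and `aqlm` plus `accelerate` must be installed.
def _demo_replace_with_aqlm_linear():
    from transformers import AutoModelForCausalLM
    from transformers.utils.quantization_config import AqlmConfig

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # small stand-in model
    config = AqlmConfig(in_group_size=8, out_group_size=1, num_codebooks=1, nbits_per_codebook=16)
    model, replaced = replace_with_aqlm_linear(
        model,
        quantization_config=config,
        linear_weights_not_to_quantize=["lm_head.weight"],
    )
    assert replaced  # at least one nn.Linear was swapped for an aqlm.QuantizedLinear shell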
.\integrations\awq.py
"AWQ (Activation aware Weight Quantization) integration file"
from ..activations import ACT2FN
from ..modeling_utils import PreTrainedModel
from ..utils import is_auto_awq_available, is_torch_available
from ..utils.quantization_config import (
AwqBackendPackingMethod,
AwqConfig,
AWQLinearVersion,
ExllamaVersion,
)
if is_torch_available():
import torch
import torch.nn as nn
AWQ_FUSED_MAPPINGS = {
"mistral": {
"attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
"mlp": ["gate_proj", "up_proj", "down_proj"],
"layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
"use_alibi": False,
},
"mixtral": {
"attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
"mlp": ["w1", "w3", "w2"],
"layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
"use_alibi": False,
"rope_theta": 1000000.0,
},
"llama": {
"attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
"mlp": ["gate_proj", "up_proj", "down_proj"],
"layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
"use_alibi": False,
},
"llava": {
"attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
"mlp": ["gate_proj", "up_proj", "down_proj"],
"layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
"use_alibi": False,
},
}
def replace_with_awq_linear(
model,
modules_to_not_convert=None,
quantization_config=None,
current_key_name=None,
has_been_replaced=False,
) -> bool:
"""
Public method that recursively replaces the Linear layers of the given model with AWQ quantized layers.
`accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
conversion has been successful or not.
During the module replacement, we also infer the backend to use through the `quantization_config` object.
Args:
model (`torch.nn.Module`):
The model to convert, can be any `torch.nn.Module` instance.
quantization_config (`AwqConfig`):
The quantization config object that contains the quantization parameters.
modules_to_not_convert (`list`, *optional*):
A list of modules to not convert. If a module name is in the list (e.g. `lm_head`), it will not be converted.
current_key_name (`list`, *optional*):
A list that contains the current key name. This is used for recursion and should not be passed by the user.
has_been_replaced (`bool`, *optional*):
A boolean that indicates if the conversion has been successful or not. This is used for recursion and should not be passed by the user.
"""
# If no list of modules to skip was given, initialize it as an empty list
if modules_to_not_convert is None:
modules_to_not_convert = []
# Get the backend from the quantization config
backend = quantization_config.backend
# Check that auto-awq support is available
if not is_auto_awq_available():
raise ValueError(
"AWQ (either `autoawq` or `llmawq`) is not available. Please install it with `pip install autoawq` or check out the installation guide in https://github.com/mit-han-lab/llm-awq"
)
# Pick the appropriate quantized linear class based on the quantization config
if backend == AwqBackendPackingMethod.AUTOAWQ:
if quantization_config.version == AWQLinearVersion.GEMM:
# Import the GEMM version of the quantized linear layer
from awq.modules.linear.gemm import WQLinear_GEMM
target_cls = WQLinear_GEMM
elif quantization_config.version == AWQLinearVersion.GEMV:
# Import the GEMV version of the quantized linear layer
from awq.modules.linear.gemv import WQLinear_GEMV
target_cls = WQLinear_GEMV
elif quantization_config.version == AWQLinearVersion.EXLLAMA:
if quantization_config.exllama_config["version"] == ExllamaVersion.ONE:
# Import the Exllama v1 quantized linear layer
from awq.modules.linear.exllama import WQLinear_Exllama
target_cls = WQLinear_Exllama
elif quantization_config.exllama_config["version"] == ExllamaVersion.TWO:
# Import the Exllama v2 quantized linear layer
from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2
target_cls = WQLinear_ExllamaV2
else:
raise ValueError(f"Unrecognized Exllama version: {quantization_config.exllama_config['version']}")
else:
raise ValueError(f"Unrecognized AWQ version: {quantization_config.version}")
else:
# If the AUTOAWQ backend was not selected, fall back to the default quantized linear class
from awq.quantize.qmodule import WQLinear
target_cls = WQLinear
# Iterate over all immediate child modules of the model, getting each name and instance
for name, module in model.named_children():
# If the current key name is None, initialize it as an empty list
if current_key_name is None:
current_key_name = []
# Append the current child module's name to the current key name list
current_key_name.append(name)
# If the current module is an nn.Linear and its name is not in the skip list
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
# Check that no key in the skip list appears in the joined current key name
if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
# Get the input and output feature sizes of the linear layer
in_features = module.in_features
out_features = module.out_features
# Replace the current linear layer in the model with an instance of the target class
model._modules[name] = target_cls(
w_bit=quantization_config.bits,
group_size=quantization_config.group_size,
in_features=in_features,
out_features=out_features,
bias=module.bias is not None,
dev=module.weight.device,
)
# Mark that at least one module has been replaced
has_been_replaced = True
# Force requires_grad to False on the new module to avoid unexpected errors
model._modules[name].requires_grad_(False)
# If the current module has children of its own, recurse to replace them as well
if len(list(module.children())) > 0:
_, has_been_replaced = replace_with_awq_linear(
module,
modules_to_not_convert=modules_to_not_convert,
current_key_name=current_key_name,
quantization_config=quantization_config,
has_been_replaced=has_been_replaced,
)
# Remove the last key name from the list in preparation for the next step of the recursion
current_key_name.pop(-1)
# Return the converted model and the flag indicating whether any module was replaced
return model, has_been_replaced
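# A hedged sketch (not part of this file): in normal use this function is invoked by the AWQ
# quantizer while loading a pre-quantized checkpoint. The model name and config values are
# illustrative, and `autoawq` must be installed for the kernel imports above to succeed.
def _demo_replace_with_awq_linear():
    from transformers import AutoConfig, AutoModelForCausalLM
    from transformers.utils.quantization_config import AwqConfig

    quant_config = AwqConfig(bits=4, group_size=128)  # defaults to the GEMM kernels
    model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained("facebook/opt-125m"))
    model, replaced = replace_with_awq_linear(
        model, quantization_config=quant_config, modules_to_not_convert=["lm_head"]
    )
    assert replaced  # eligible nn.Linear layers are now empty WQLinear_GEMM shells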
# Return the module fusing mapping for the model, given the quantization config
def get_modules_to_fuse(model, quantization_config):
"""
Returns the fusing mapping given the quantization config and the model
Args:
model (`~PreTrainedModel`):
The model to fuse - note this model should have been converted into AWQ format beforehand.
quantization_config (`~transformers.quantization_config.AWQConfig`):
The quantization configuration to use.
"""
# Raise an error if the model is not an instance of PreTrainedModel
if not isinstance(model, PreTrainedModel):
raise ValueError(f"The model should be an instance of `PreTrainedModel`, got {model.__class__.__name__}")
# Always default to `quantization_config.modules_to_fuse` when it is provided
if quantization_config.modules_to_fuse is not None:
current_fused_mapping = quantization_config.modules_to_fuse
current_fused_mapping["max_seq_len"] = quantization_config.fuse_max_seq_len
# If `quantization_config.modules_to_fuse` is None, look up the mapping for the model type in `AWQ_FUSED_MAPPINGS`
elif model.config.model_type in AWQ_FUSED_MAPPINGS:
current_fused_mapping = AWQ_FUSED_MAPPINGS[model.config.model_type]
# Handle multimodal models (e.g. Llava) by distinguishing `model.config` from `model.config.text_config`
if not hasattr(model.config, "text_config"):
config = model.config
else:
config = model.config.text_config
# Handle the `hidden_size`, `num_attention_heads` and `num_key_value_heads` fields separately
hidden_size = config.hidden_size
num_attention_heads = config.num_attention_heads
num_key_value_heads = getattr(config, "num_key_value_heads", num_attention_heads)
# Fill in the expected values in `current_fused_mapping`
current_fused_mapping["hidden_size"] = hidden_size
current_fused_mapping["num_attention_heads"] = num_attention_heads
current_fused_mapping["num_key_value_heads"] = num_key_value_heads
current_fused_mapping["max_seq_len"] = quantization_config.fuse_max_seq_len
# If no suitable fusing mapping was found, raise an error
else:
raise ValueError(
"Fusing mapping not found either on the quantization config or the supported `AWQ_FUSED_MAPPINGS`. Please pass a `fused_mapping` argument"
" in the `quantization_config` or raise an issue on transformers https://github.com/huggingface/transformers to add its support."
)
return current_fused_mapping
# Optionally fuse some modules in the model to speed up inference
def fuse_awq_modules(model, quantization_config):
"""
Optionally fuse some modules in the model to speedup inference.
Args:
model (`~PreTrainedModel`):
The model to fuse - note this model should have been converted into AWQ format beforehand.
quantization_config (`Union[AwqConfig, dict]`):
The quantization configuration to use.
"""
# If `quantization_config` is a dict, convert it to an AwqConfig object so that fields such as `backend` are available
# Otherwise those fields would not be accessible
# https://github.com/huggingface/transformers/pull/27411#discussion_r1414044495
if isinstance(quantization_config, dict):
quantization_config = AwqConfig.from_dict(quantization_config)
# Get the backend from the quantization config
backend = quantization_config.backend
# Get the mapping of modules to fuse
modules_to_fuse = get_modules_to_fuse(model, quantization_config)
# Get the list of modules not to convert, if any
modules_to_not_convert = getattr(quantization_config, "modules_to_not_convert", None)
# Check whether the AutoAWQ backend is used
if backend == AwqBackendPackingMethod.AUTOAWQ:
# Import the fused modules from the AWQ backend
from awq.modules.fused.attn import QuantAttentionFused
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm
else:
# Raise a ValueError: fusing is only supported for the AutoAWQ backend
raise ValueError("Fusing is only supported for the AutoAWQ backend")
# Iterate over all named modules of the model
for name, module in model.named_modules():
# If a list of modules not to convert was provided
if modules_to_not_convert is not None:
# Check whether the current module name matches any entry in that list
if any(module_name_to_not_convert in name for module_name_to_not_convert in modules_to_not_convert):
# If so, skip this module
continue
# Replace the LayerNorm layers in the model
_fuse_awq_layernorm(modules_to_fuse["layernorm"], module, FasterTransformerRMSNorm)
# Replace the MLP layers in the model
_fuse_awq_mlp(model, name, modules_to_fuse["mlp"], module, QuantFusedMLP)
# Replace the attention layers in the model
_fuse_awq_attention_layers(model, module, modules_to_fuse, name, QuantAttentionFused)
# Return the fused model
return model
# Fuse the LayerNorm layers into the target class using autoawq
def _fuse_awq_layernorm(fuse_module_names, module, target_cls):
"""
Fuse the LayerNorm layers into a target class using autoawq
Args:
fuse_module_names (`List[str]`):
The list of module names to fuse
module (`nn.Module`):
The pytorch parent module that has layernorm modules to fuse
target_cls (`~autoawq.FasterTransformerRMSNorm`):
The `FasterTransformerRMSNorm` class as it only supports that class
for now.
"""
# Iterate over the module names to fuse
for module_name in fuse_module_names:
# Check that the parent module has an attribute with this name
if hasattr(module, module_name):
# Get the old LayerNorm module
old_module = getattr(module, module_name)
# Create a new target_cls instance to replace the old module
module._modules[module_name] = target_cls(
old_module.weight,
old_module.variance_epsilon,
).to(old_module.weight.device)
# Delete the old module to free memory
del old_module
# Fuse the MLP layers into the target class using autoawq
def _fuse_awq_mlp(model, current_module_name, fuse_module_names, module, target_cls):
"""
Fuse the MLP layers into a target class using autoawq
Args:
model (`~PreTrainedModel`):
The input pretrained model
current_module_name (`str`):
The current submodule name
fuse_module_names (`List[str]`):
The list of module names to fuse. For the MLP layers it has to be an array
of length 3 that consists of the 3 MLP layers in the order (gate (dense layer post-attention) / up / down layers)
module (`nn.Module`):
The pytorch parent module that has layernorm modules to fuse
target_cls (`~autoawq.QuantFusedMLP`):
The `QuantFusedMLP` class as it only supports that class
for now.
"""
# If there are no module names to fuse, return immediately
if len(fuse_module_names) == 0:
return
# Check that the parent module has the first module name to fuse
if hasattr(module, fuse_module_names[0]):
# Get references to the three MLP layers
gate_proj = getattr(module, fuse_module_names[0])
up_proj = getattr(module, fuse_module_names[1])
down_proj = getattr(module, fuse_module_names[2])
# Record the device of gate_proj
previous_device = gate_proj.qweight.device
# Handle the case where the model has a `text_config` attribute
hidden_act = (
model.config.hidden_act
if not hasattr(model.config, "text_config")
else model.config.text_config.hidden_act
)
# Look up the activation function from hidden_act
activation_fn = ACT2FN[hidden_act]
# Create the new QuantFusedMLP instance
new_module = target_cls(gate_proj, down_proj, up_proj, activation_fn)
# Split the current module name into parent and child names
parent_name, child_name = current_module_name.rsplit(".", 1)
# Get the parent module
parent = model.get_submodule(parent_name)
# Set the new module as the child attribute and move it to the previous device
setattr(parent, child_name, new_module.to(previous_device))
# Delete the temporary references to free memory
del gate_proj, up_proj, down_proj
def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_name, target_cls):
"""
Fuse the Attention layers into a target class using autoawq
"""
# Import the required classes: WQLinear_GEMM and WQLinear_GEMV from awq.modules.linear
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
# If the attention list in modules_to_fuse is empty, return immediately
if len(modules_to_fuse["attention"]) == 0:
return
# Check that the module has the attribute with the given name
if hasattr(module, modules_to_fuse["attention"][0]):
# Get a reference to the specified attention projection
q_proj = getattr(module, modules_to_fuse["attention"][0])
# Choose the linear layer class and concatenation dimension based on the q_proj type
if isinstance(q_proj, WQLinear_GEMV):
linear_target_cls = WQLinear_GEMV
cat_dim = 0
elif isinstance(q_proj, WQLinear_GEMM):
linear_target_cls = WQLinear_GEMM
cat_dim = 1
else:
# Raise an error for unsupported q_proj types
raise ValueError(f"Unsupported q_proj type: {type(q_proj)}")
# Record the device of q_proj
previous_device = q_proj.qweight.device
# Get the remaining attention projections
k_proj = getattr(module, modules_to_fuse["attention"][1])
v_proj = getattr(module, modules_to_fuse["attention"][2])
o_proj = getattr(module, modules_to_fuse["attention"][3])
# Concatenate the bias terms, if present
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
# Create the new quantized linear layer that merges the QKV weights, quantization zeros and scales
qkv_layer = linear_target_cls(
q_proj.w_bit,
q_proj.group_size,
q_proj.in_features,
q_proj.out_features + k_proj.out_features + v_proj.out_features,
q_proj.bias is not None,
next(iter(module.state_dict().values())).device,
)
# Concatenate the QKV weights, quantization zeros and scales
qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=cat_dim)
qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=cat_dim)
qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=cat_dim)
# For WQLinear_GEMV layers, also set the layer-specific split_k_iters attribute
if isinstance(qkv_layer, WQLinear_GEMV):
qkv_layer.split_k_iters = q_proj.split_k_iters
# Set the bias of the merged attention layer
qkv_layer.bias = bias
# Create the fused attention layer with the given parameters
fused_attention_layer = target_cls(
modules_to_fuse["hidden_size"],
modules_to_fuse["num_attention_heads"],
modules_to_fuse["num_key_value_heads"],
qkv_layer,
o_proj,
previous_device,
modules_to_fuse["max_seq_len"],
use_alibi=modules_to_fuse["use_alibi"],
rope_theta=modules_to_fuse.get("rope_theta", 10000.0),  # default rope_theta value of 10000.0
)
# Mark the fused attention layer as part of HF Transformers
fused_attention_layer.is_hf_transformers = True
# Set the fused attention layer as the corresponding submodule of the model
parent_name, child_name = current_module_name.rsplit(".", 1)
parent = model.get_submodule(parent_name)
setattr(parent, child_name, fused_attention_layer.to(previous_device))
# Drop the references that are no longer needed to free memory
del q_proj, k_proj, v_proj, o_proj
def post_init_awq_exllama_modules(model, exllama_config):
"""
Runs post init for Exllama layers which performs:
- Weights unpacking, reordering and repacking
- Devices scratch space allocation
"""
# Check whether the configured Exllama version is version one
if exllama_config["version"] == ExllamaVersion.ONE:
# If so, import the v1 post-init function and run it on the model
from awq.modules.linear.exllama import exllama_post_init
model = exllama_post_init(model)
# Check whether the configured Exllama version is version two
elif exllama_config["version"] == ExllamaVersion.TWO:
# If so, import the v2 post-init function and run it on the model with the configured parameters
from awq.modules.linear.exllamav2 import exllamav2_post_init
model = exllamav2_post_init(
model,
max_input_len=exllama_config["max_input_len"],
max_batch_size=exllama_config["max_batch_size"],
)
else:
# If the configured Exllama version is neither one nor two, raise an error
raise ValueError(f"Unrecognized Exllama version: {exllama_config['version']}")
# Return the model after the Exllama post-initialization
return model
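# A hedged usage sketch (not part of this file) tying the helpers above together: loading a
# pre-quantized AWQ checkpoint with `do_fuse=True` makes the quantizer call
# `replace_with_awq_linear` and then `fuse_awq_modules`. The checkpoint name is an example from
# the Hub and `autoawq` must be installed.
def _demo_awq_loading():
    from transformers import AutoModelForCausalLM, AwqConfig

    quant_config = AwqConfig(bits=4, fuse_max_seq_len=2048, do_fuse=True)
    model = AutoModelForCausalLM.from_pretrained(
        "TheBloke/Mistral-7B-Instruct-v0.1-AWQ",
        quantization_config=quant_config,
        device_map="auto",
    )
    return model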
.\integrations\bitsandbytes.py
import importlib.metadata
import warnings
from copy import deepcopy
from inspect import signature
from packaging import version
from ..utils import is_accelerate_available, is_bitsandbytes_available, logging
if is_bitsandbytes_available():
import bitsandbytes as bnb
import torch
import torch.nn as nn
from ..pytorch_utils import Conv1D
if is_accelerate_available():
from accelerate import init_empty_weights
from accelerate.utils import find_tied_parameters
logger = logging.get_logger(__name__)
def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, quantized_stats=None):
"""
A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
`param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
function is adapted from `set_module_tensor_to_device` function from accelerate that is adapted to support the
class `Int8Params` from `bitsandbytes`.
Args:
module (`torch.nn.Module`):
The module in which the tensor we want to move lives.
tensor_name (`str`):
The full name of the parameter/buffer.
device (`int`, `str` or `torch.device`):
The device on which to set the tensor.
value (`torch.Tensor`, *optional*):
The value of the tensor (useful when going from the meta device to any other device).
quantized_stats (`dict[str, Any]`, *optional*):
Dict with items for either 4-bit or 8-bit serialization
"""
if "." in tensor_name:
splits = tensor_name.split(".")
for split in splits[:-1]:
new_module = getattr(module, split)
if new_module is None:
raise ValueError(f"{module} has no attribute {split}.")
module = new_module
tensor_name = splits[-1]
if tensor_name not in module._parameters and tensor_name not in module._buffers:
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
is_buffer = tensor_name in module._buffers
old_value = getattr(module, tensor_name)
if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
prequantized_loading = quantized_stats is not None
if is_buffer or not is_bitsandbytes_available():
is_8bit = False
is_4bit = False
else:
is_4bit = hasattr(bnb.nn, "Params4bit") and isinstance(module._parameters[tensor_name], bnb.nn.Params4bit)
is_8bit = isinstance(module._parameters[tensor_name], bnb.nn.Int8Params)
if is_8bit or is_4bit:
param = module._parameters[tensor_name]
if param.device.type != "cuda":
if value is None:
new_value = old_value.to(device)
elif isinstance(value, torch.Tensor):
new_value = value.to("cpu")
else:
new_value = torch.tensor(value, device="cpu")
if issubclass(module.source_cls, Conv1D) and not prequantized_loading:
new_value = new_value.T
kwargs = old_value.__dict__
if prequantized_loading != (new_value.dtype in (torch.int8, torch.uint8)):
raise ValueError(
f"Value dtype `{new_value.dtype}` is not compatible with parameter quantization status."
)
if is_8bit:
is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
"0.37.2"
)
if new_value.dtype in (torch.int8, torch.uint8) and not is_8bit_serializable:
raise ValueError(
"Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
)
new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
if prequantized_loading:
setattr(new_value, "SCB", quantized_stats["SCB"].to(device))
elif is_4bit:
if prequantized_loading:
is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse(
"0.41.3"
)
if new_value.dtype in (torch.int8, torch.uint8) and not is_4bit_serializable:
raise ValueError(
"Detected 4-bit weights but the version of bitsandbytes is not compatible with 4-bit serialization. "
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
)
new_value = bnb.nn.Params4bit.from_prequantized(
data=new_value,
quantized_stats=quantized_stats,
requires_grad=False,
device=device,
**kwargs,
)
else:
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device)
module._parameters[tensor_name] = new_value
else:
if value is None:
new_value = old_value.to(device)
elif isinstance(value, torch.Tensor):
new_value = value.to(device)
else:
new_value = torch.tensor(value, device=device)
if is_buffer:
module._buffers[tensor_name] = new_value
else:
new_value = nn.Parameter(new_value, requires_grad=old_value.requires_grad)
module._parameters[tensor_name] = new_value
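# A minimal sketch (not part of this file) showing the non-quantized path of the helper above:
# the parameter is replaced in place on the module rather than returned as a detached copy, which
# is why bnb loading relies on it. It assumes `torch` (and `bitsandbytes` for the module-level
# imports above) is installed.
def _demo_set_module_tensor():
    import torch
    from torch import nn

    linear = nn.Linear(4, 4)
    set_module_quantized_tensor_to_device(linear, "weight", "cpu", value=torch.ones(4, 4))
    assert torch.all(linear.weight == 1)  # the nn.Parameter itself was swapped in place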
def _replace_with_bnb_linear(
model,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
has_been_replaced=False,
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successful or not.
"""
return model, has_been_replaced
def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
"""
A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
library. This will enable running your models using mixed int8 precision as described by the paper `LLM.int8():
8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
bitsandbytes`
The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
CPU/GPU memory is required to run this function. Int8 mixed-precision matrix decomposition works by separating a
matrix multiplication into two streams: (1) a systematic feature outlier stream matrix multiplied in fp16
(0.01%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no
predictive degradation is possible for very large models (>=176B parameters).
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision
for numerical stability reasons.
current_key_name (`List[`str`]`, *optional*):
An array to track the current key of the recursion. This is used to check whether the current key (part of
it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
`disk`).
"""
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
model, has_been_replaced = _replace_with_bnb_linear(
model, modules_to_not_convert, current_key_name, quantization_config
)
if not has_been_replaced:
logger.warning(
"You are loading your model in 8bit or 4bit but no linear modules were found in your model."
" Please double check your model architecture, or submit an issue on github if you think this is"
" a bug."
)
return model
def replace_8bit_linear(*args, **kwargs):
warnings.warn(
"`replace_8bit_linear` will be deprecated in a future version, please use `replace_with_bnb_linear` instead",
FutureWarning,
)
return replace_with_bnb_linear(*args, **kwargs)
def set_module_8bit_tensor_to_device(*args, **kwargs):
warnings.warn(
"`set_module_8bit_tensor_to_device` will be deprecated in a future version, please use `set_module_quantized_tensor_to_device` instead",
FutureWarning,
)
return set_module_quantized_tensor_to_device(*args, **kwargs)
def get_keys_to_not_convert(model):
r"""
Returns the list of module keys that should not be converted to int8. For example, for CausalLM modules
we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures,
we may want to keep the tied weights of the model. The function returns a list of the keys of the modules
that should not be converted to int8.
Parameters:
model (`torch.nn.Module`):
The input model
tied_model = deepcopy(model)
tied_model.tie_weights()
tied_params = find_tied_parameters(tied_model)
if isinstance(tied_params, dict):
tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys())
else:
tied_keys = sum(tied_params, [])
has_tied_params = len(tied_keys) > 0
if not has_tied_params:
output_emb = model.get_output_embeddings()
if output_emb is not None:
list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
return list_last_module
list_modules = list(model.named_parameters())
list_last_module = [list_modules[-1][0]]
intersection = set(list_last_module) - set(tied_keys)
list_untouched = list(set(tied_keys)) + list(intersection)
names_to_remove = [".weight", ".bias"]
filtered_module_names = []
for name in list_untouched:
for name_to_remove in names_to_remove:
if name_to_remove in name:
name = name.replace(name_to_remove, "")
filtered_module_names.append(name)
return filtered_module_names
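# An illustrative call (not part of this file); the checkpoint name is an assumption and
# `accelerate` must be installed. For a causal LM with tied input/output embeddings this typically
# returns the output-embedding / lm_head module keys, which `replace_with_bnb_linear` then leaves
# in full precision.
def _demo_get_keys_to_not_convert():
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    print(get_keys_to_not_convert(model))  # e.g. something like ['lm_head']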
.\integrations\deepspeed.py
"""
Integration with Deepspeed
"""
import copy
import importlib.metadata as importlib_metadata
import importlib.util
import weakref
from functools import partialmethod
from ..dependency_versions_check import dep_version_check
from ..utils import is_accelerate_available, is_torch_available, logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
def is_deepspeed_available():
package_exists = importlib.util.find_spec("deepspeed") is not None
if package_exists:
try:
_ = importlib_metadata.metadata("deepspeed")
return True
except importlib_metadata.PackageNotFoundError:
return False
if is_accelerate_available() and is_deepspeed_available():
from accelerate.utils.deepspeed import HfDeepSpeedConfig as DeepSpeedConfig
else:
from builtins import object as DeepSpeedConfig
class HfDeepSpeedConfig(DeepSpeedConfig):
"""
This object contains a DeepSpeed configuration dictionary and can be quickly queried for things like zero stage.
A `weakref` of this object is stored in the module's globals to be able to access the config from areas where
things like the Trainer object is not available (e.g. `from_pretrained` and `_get_resized_embeddings`). Therefore
it's important that this object remains alive while the program is still running.
[`Trainer`] uses the `HfTrainerDeepSpeedConfig` subclass instead. That subclass has logic to sync the configuration
with values of [`TrainingArguments`] by replacing special placeholder values: `"auto"`. Without this special logic
the DeepSpeed configuration is not modified in any way.
Args:
config_file_or_dict (`Union[str, Dict]`): path to DeepSpeed config file or dict.
"""
def __init__(self, config_file_or_dict):
set_hf_deepspeed_config(self)
dep_version_check("accelerate")
dep_version_check("deepspeed")
super().__init__(config_file_or_dict)
class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
"""
The `HfTrainerDeepSpeedConfig` object is meant to be created during `TrainingArguments` object creation and has the
same lifespan as the latter.
"""
def __init__(self, config_file_or_dict):
super().__init__(config_file_or_dict)
self._dtype = None
self.mismatches = []
def dtype(self):
if self._dtype is None:
raise ValueError("trainer_config_process() wasn't called yet to tell dtype")
return self._dtype
def is_auto(self, ds_key_long):
val = self.get_value(ds_key_long)
if val is None:
return False
else:
return val == "auto"
def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):
"""
A utility method that massages the config file and can optionally verify that the values match.
1. Replace "auto" values with `TrainingArguments` value.
2. If it wasn't "auto" and `must_match` is true, then check that DS config matches Trainer
config values and if mismatched add the entry to `self.mismatches` - will assert during
`trainer_config_finalize` for one or more mismatches.
"""
config, ds_key = self.find_config_node(ds_key_long)
if config is None:
return
if config.get(ds_key) == "auto":
config[ds_key] = hf_val
return
if not must_match:
return
ds_val = config.get(ds_key)
if ds_val is not None and ds_val != hf_val:
self.mismatches.append(f"- ds {ds_key_long}={ds_val} vs hf {hf_key}={hf_val}")
fill_only = partialmethod(fill_match, must_match=False)
def trainer_config_finalize(self, args, model, num_training_steps):
    hidden_size_based_keys = [
"zero_optimization.reduce_bucket_size",
"zero_optimization.stage3_prefetch_bucket_size",
"zero_optimization.stage3_param_persistence_threshold",
]
hidden_size_auto_keys = [x for x in hidden_size_based_keys if self.is_auto(x)]
if len(hidden_size_auto_keys) > 0:
if hasattr(model.config, "hidden_size"):
hidden_size = model.config.hidden_size
elif hasattr(model.config, "hidden_sizes"):
hidden_size = max(model.config.hidden_sizes)
else:
raise ValueError(
"The model's config file has neither `hidden_size` nor `hidden_sizes` entry, "
"therefore it's not possible to automatically fill out the following `auto` entries "
f"in the DeepSpeed config file: {hidden_size_auto_keys}. You can fix that by replacing "
"`auto` values for these keys with an integer value of your choice."
)
self.fill_only("zero_optimization.reduce_bucket_size", hidden_size * hidden_size)
if self.is_zero3():
self.fill_only(
"zero_optimization.stage3_prefetch_bucket_size",
0.9 * hidden_size * hidden_size,
)
self.fill_only(
"zero_optimization.stage3_param_persistence_threshold",
10 * hidden_size,
)
self.fill_match(
"scheduler.params.total_num_steps",
num_training_steps,
"num_training_steps (calculated)",
)
self.fill_match(
"scheduler.params.warmup_num_steps",
args.get_warmup_steps(num_training_steps),
"warmup_steps",
)
if len(self.mismatches) > 0:
mismatches = "\n".join(self.mismatches)
raise ValueError(
"Please correct the following DeepSpeed config values that mismatch TrainingArguments"
f" values:\n{mismatches}\nThe easiest method is to set these DeepSpeed config values to 'auto'."
)
_hf_deepspeed_config_weak_ref = None
def set_hf_deepspeed_config(hf_deepspeed_config_obj):
global _hf_deepspeed_config_weak_ref
_hf_deepspeed_config_weak_ref = weakref.ref(hf_deepspeed_config_obj)
def unset_hf_deepspeed_config():
global _hf_deepspeed_config_weak_ref
_hf_deepspeed_config_weak_ref = None
def is_deepspeed_zero3_enabled():
if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
return _hf_deepspeed_config_weak_ref().is_zero3()
else:
return False
def deepspeed_config():
if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
return _hf_deepspeed_config_weak_ref().config
else:
return None
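# A hedged sketch (not part of this file) of the weakref pattern the helpers above enable:
# keeping an `HfDeepSpeedConfig` object alive *before* calling `from_pretrained` is how ZeRO-3
# parameter partitioning is detected outside the Trainer. The config dict is minimal and
# illustrative; `deepspeed` and `accelerate` must be installed.
def _demo_zero3_detection():
    ds_config = {
        "zero_optimization": {"stage": 3},
        "train_micro_batch_size_per_gpu": 1,
    }
    dschf = HfDeepSpeedConfig(ds_config)  # must stay referenced while the model is built
    assert is_deepspeed_zero3_enabled()
    assert deepspeed_config()["zero_optimization"]["stage"] == 3
    return dschf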
def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
"""
A convenience wrapper that deals with optimizer and lr scheduler configuration.
"""
from accelerate.utils import DummyOptim, DummyScheduler
config = hf_deepspeed_config.config
if "optimizer" in config:
if args.adafactor:
raise ValueError(
"--adafactor was passed, but also found `optimizer` configured in the DeepSpeed config. "
"Only one optimizer can be configured."
)
optimizer = DummyOptim(params=model_parameters)
else:
if hf_deepspeed_config.is_offload():
logger.info(
"Detected ZeRO Offload and non-DeepSpeed optimizers: This combination should work as long as the"
" custom optimizer has both CPU and GPU implementation (except LAMB)"
)
optimizer = trainer.create_optimizer()
config["zero_allow_untested_optimizer"] = True
lr_scheduler = None
if "scheduler" in config:
lr_scheduler = DummyScheduler(optimizer)
else:
if isinstance(optimizer, DummyOptim):
def _lr_scheduler_callable(optimizer):
trainer_copy = copy.copy(trainer)
trainer_copy.lr_scheduler = None
lr_scheduler = trainer_copy.create_scheduler(
num_training_steps=num_training_steps, optimizer=optimizer
)
return lr_scheduler
lr_scheduler = DummyScheduler(optimizer, lr_scheduler_callable=_lr_scheduler_callable)
else:
lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
return optimizer, lr_scheduler
def deepspeed_init(trainer, num_training_steps, inference=False):
from deepspeed.utils import logger as ds_logger
model = trainer.model
args = trainer.args
hf_deepspeed_config = trainer.accelerator.state.deepspeed_plugin.hf_ds_config
hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps)
ds_logger.setLevel(args.get_process_log_level())
if inference:
if not hf_deepspeed_config.is_zero3():
raise ValueError("ZeRO inference only makes sense with ZeRO Stage 3 - please adjust your config")
hf_deepspeed_config.del_config_sub_tree("optimizer")
hf_deepspeed_config.del_config_sub_tree("lr_scheduler")
optimizer, lr_scheduler = None, None
model_parameters = None
else:
trainer.optimizer = None
model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer, lr_scheduler = deepspeed_optim_sched(
trainer, hf_deepspeed_config, args, num_training_steps, model_parameters
)
return optimizer, lr_scheduler
def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path, load_module_strict=True):
import glob
deepspeed_checkpoint_dirs = sorted(glob.glob(f"{checkpoint_path}/global_step*"))
if len(deepspeed_checkpoint_dirs) > 0:
logger.info(f"Attempting to resume from {checkpoint_path}")
load_path, _ = deepspeed_engine.load_checkpoint(
checkpoint_path,
load_module_strict=load_module_strict,
load_optimizer_states=True,
load_lr_scheduler_states=True,
)
if load_path is None:
raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}")
else:
raise ValueError(f"Can't find a valid checkpoint at {checkpoint_path}")
.\integrations\integration_utils.py
"""
Integrations with other Python libraries.
"""
import functools
import importlib.metadata
import importlib.util
import json
import numbers
import os
import pickle
import shutil
import sys
import tempfile
from dataclasses import asdict, fields
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
import numpy as np
import packaging.version
from .. import __version__ as version
from ..utils import flatten_dict, is_datasets_available, is_pandas_available, is_torch_available, logging
logger = logging.get_logger(__name__)
if is_torch_available():
import torch
_has_comet = importlib.util.find_spec("comet_ml") is not None and os.getenv("COMET_MODE", "").upper() != "DISABLED"
if _has_comet:
try:
import comet_ml
if hasattr(comet_ml, "config") and comet_ml.config.get_config("comet.api_key"):
_has_comet = True
else:
if os.getenv("COMET_MODE", "").upper() != "DISABLED":
logger.warning("comet_ml is installed but `COMET_API_KEY` is not set.")
_has_comet = False
except (ImportError, ValueError):
_has_comet = False
_has_neptune = (
importlib.util.find_spec("neptune") is not None or importlib.util.find_spec("neptune-client") is not None
)
if TYPE_CHECKING and _has_neptune:
try:
_neptune_version = importlib.metadata.version("neptune")
logger.info(f"Neptune version {_neptune_version} available.")
except importlib.metadata.PackageNotFoundError:
try:
_neptune_version = importlib.metadata.version("neptune-client")
logger.info(f"Neptune-client version {_neptune_version} available.")
except importlib.metadata.PackageNotFoundError:
_has_neptune = False
from ..trainer_callback import ProgressCallback, TrainerCallback
from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy
from ..training_args import ParallelMode
from ..utils import ENV_VARS_TRUE_VALUES, is_torch_xla_available
def is_wandb_available():
if os.getenv("WANDB_DISABLED", "").upper() in ENV_VARS_TRUE_VALUES:
logger.warning(
"Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the "
"--report_to flag to control the integrations used for logging result (for instance --report_to none)."
)
return False
return importlib.util.find_spec("wandb") is not None
def is_clearml_available():
return importlib.util.find_spec("clearml") is not None
def is_comet_available():
return _has_comet
def is_tensorboard_available():
return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None
def is_optuna_available():
return importlib.util.find_spec("optuna") is not None
def is_ray_available():
return importlib.util.find_spec("ray") is not None
def is_ray_tune_available():
if not is_ray_available():
return False
return importlib.util.find_spec("ray.tune") is not None
def is_sigopt_available():
return importlib.util.find_spec("sigopt") is not None
def is_azureml_available():
if importlib.util.find_spec("azureml") is None:
return False
if importlib.util.find_spec("azureml.core") is None:
return False
return importlib.util.find_spec("azureml.core.run") is not None
def is_mlflow_available():
if os.getenv("DISABLE_MLFLOW_INTEGRATION", "FALSE").upper() == "TRUE":
return False
return importlib.util.find_spec("mlflow") is not None
def is_dagshub_available():
return None not in [importlib.util.find_spec("dagshub"), importlib.util.find_spec("mlflow")]
def is_neptune_available():
return _has_neptune
def is_codecarbon_available():
return importlib.util.find_spec("codecarbon") is not None
def is_flytekit_available():
return importlib.util.find_spec("flytekit") is not None
def is_flyte_deck_standard_available():
if not is_flytekit_available():
return False
return importlib.util.find_spec("flytekitplugins.deck") is not None
def is_dvclive_available():
return importlib.util.find_spec("dvclive") is not None
def hp_params(trial):
if is_optuna_available():
import optuna
if isinstance(trial, optuna.Trial):
return trial.params
if is_ray_tune_available():
if isinstance(trial, dict):
return trial
if is_sigopt_available():
if isinstance(trial, dict):
return trial
if is_wandb_available():
if isinstance(trial, dict):
return trial
raise RuntimeError(f"Unknown type for trial {trial.__class__}")
def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
import optuna
if trainer.args.process_index == 0:
def _objective(trial, checkpoint_dir=None):
checkpoint = None
if checkpoint_dir:
for subdir in os.listdir(checkpoint_dir):
if subdir.startswith(PREFIX_CHECKPOINT_DIR):
checkpoint = os.path.join(checkpoint_dir, subdir)
trainer.objective = None
if trainer.args.world_size > 1:
if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
trainer._hp_search_setup(trial)
torch.distributed.broadcast_object_list(pickle.dumps(trainer.args), src=0)
trainer.train(resume_from_checkpoint=checkpoint)
else:
trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
if getattr(trainer, "objective", None) is None:
metrics = trainer.evaluate()
trainer.objective = trainer.compute_objective(metrics)
return trainer.objective
timeout = kwargs.pop("timeout", None)
n_jobs = kwargs.pop("n_jobs", 1)
directions = direction if isinstance(direction, list) else None
direction = None if directions is not None else direction
study = optuna.create_study(direction=direction, directions=directions, **kwargs)
study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
if not study._is_multi_objective():
best_trial = study.best_trial
return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
else:
best_trials = study.best_trials
return [BestRun(str(best.number), best.values, best.params) for best in best_trials]
else:
for i in range(n_trials):
trainer.objective = None
args_main_rank = list(pickle.dumps(trainer.args))
if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
torch.distributed.broadcast_object_list(args_main_rank, src=0)
args = pickle.loads(bytes(args_main_rank))
for key, value in asdict(args).items():
if key != "local_rank":
setattr(trainer.args, key, value)
trainer.train(resume_from_checkpoint=None)
if getattr(trainer, "objective", None) is None:
metrics = trainer.evaluate()
trainer.objective = trainer.compute_objective(metrics)
return None
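# Minimal sketch (assumed setup): Trainer.hyperparameter_search with backend="optuna" ends up in
# run_hp_search_optuna above. Model name, datasets and the search space are illustrative placeholders.
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments


def model_init(trial):
    # A fresh model has to be created for every trial
    return AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")


def optuna_hp_space(trial):
    # `trial` is an optuna.Trial; the returned dict is applied through trainer._hp_search_setup
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-4, log=True),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [8, 16, 32]),
    }


trainer = Trainer(
    model_init=model_init,
    args=TrainingArguments(output_dir="hpo_out", evaluation_strategy="steps", eval_steps=500),
    train_dataset=train_dataset,  # assumed to be defined elsewhere
    eval_dataset=eval_dataset,  # assumed to be defined elsewhere
)
best_run = trainer.hyperparameter_search(
    hp_space=optuna_hp_space, backend="optuna", direction="minimize", n_trials=10
)
print(best_run.run_id, best_run.objective, best_run.hyperparameters)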
def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
import ray
import ray.train
def _objective(trial: dict, local_trainer):
try:
from transformers.utils.notebook import NotebookProgressCallback
if local_trainer.pop_callback(NotebookProgressCallback):
local_trainer.add_callback(ProgressCallback)
except ModuleNotFoundError:
pass
local_trainer.objective = None
checkpoint = ray.train.get_checkpoint()
if checkpoint:
local_trainer.objective = "objective"
with checkpoint.as_directory() as checkpoint_dir:
checkpoint_path = next(Path(checkpoint_dir).glob(f"{PREFIX_CHECKPOINT_DIR}*")).as_posix()
local_trainer.train(resume_from_checkpoint=checkpoint_path, trial=trial)
else:
local_trainer.train(trial=trial)
if getattr(local_trainer, "objective", None) is None:
metrics = local_trainer.evaluate()
local_trainer.objective = local_trainer.compute_objective(metrics)
metrics.update({"objective": local_trainer.objective, "done": True})
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
local_trainer._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
ray.train.report(metrics, checkpoint=checkpoint)
if not trainer._memory_tracker.skip_memory_metrics:
from ..trainer_utils import TrainerMemoryTracker
logger.warning(
"Memory tracking for your Trainer is currently "
"enabled. Automatically disabling the memory tracker "
"since the memory tracker is not serializable."
)
trainer._memory_tracker = TrainerMemoryTracker(skip_memory_metrics=True)
_tb_writer = trainer.pop_callback(TensorBoardCallback)
trainer.model = None
if "resources_per_trial" not in kwargs:
kwargs["resources_per_trial"] = {"cpu": 1}
if trainer.args.n_gpu > 0:
kwargs["resources_per_trial"]["gpu"] = 1
resource_msg = "1 CPU" + (" and 1 GPU" if trainer.args.n_gpu > 0 else "")
logger.info(
"No `resources_per_trial` arg was passed into "
"`hyperparameter_search`. Setting it to a default value "
f"of {resource_msg} for each trial."
)
gpus_per_trial = kwargs["resources_per_trial"].get("gpu", 0)
trainer.args._n_gpu = gpus_per_trial
if "progress_reporter" not in kwargs:
from ray.tune import CLIReporter
kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
if "scheduler" in kwargs:
from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining
if isinstance(
kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
) and (not trainer.args.do_eval or trainer.args.evaluation_strategy == IntervalStrategy.NO):
raise RuntimeError(
"You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
"This means your trials will not report intermediate results to Ray Tune, and "
"can thus not be stopped early or used to exploit other trials parameters. "
"If this is what you want, do not use {cls}. If you would like to use {cls}, "
"make sure you pass `do_eval=True` and `evaluation_strategy='steps'` in the "
"Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
)
trainable = ray.tune.with_parameters(_objective, local_trainer=trainer)
@functools.wraps(trainable)
def dynamic_modules_import_trainable(*args, **kwargs):
"""
Wrapper around `tune.with_parameters` to ensure datasets_modules are loaded on each Actor.
Without this, an ImportError will be thrown. See https://github.com/huggingface/transformers/issues/11565.
Assumes that `_objective`, defined above, is a function.
"""
if is_datasets_available():
import datasets.load
dynamic_modules_path = os.path.join(datasets.load.init_dynamic_modules(), "__init__.py")
spec = importlib.util.spec_from_file_location("datasets_modules", dynamic_modules_path)
datasets_modules = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = datasets_modules
spec.loader.exec_module(datasets_modules)
return trainable(*args, **kwargs)
if hasattr(trainable, "__mixins__"):
dynamic_modules_import_trainable.__mixins__ = trainable.__mixins__
analysis = ray.tune.run(
dynamic_modules_import_trainable,
config=trainer.hp_space(None),
num_samples=n_trials,
**kwargs,
)
best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3], scope=trainer.args.ray_scope)
best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config, analysis)
if _tb_writer is not None:
trainer.add_callback(_tb_writer)
return best_run
def run_hp_search_wandb(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
from ..integrations import is_wandb_available
if not is_wandb_available():
raise ImportError("This function needs wandb installed: `pip install wandb`")
import wandb
reporting_to_wandb = False
for callback in trainer.callback_handler.callbacks:
if isinstance(callback, WandbCallback):
reporting_to_wandb = True
break
if not reporting_to_wandb:
trainer.add_callback(WandbCallback())
trainer.args.report_to = ["wandb"]
best_trial = {"run_id": None, "objective": None, "hyperparameters": None}
sweep_id = kwargs.pop("sweep_id", None)
project = kwargs.pop("project", None)
name = kwargs.pop("name", None)
entity = kwargs.pop("entity", None)
metric = kwargs.pop("metric", "eval/loss")
sweep_config = trainer.hp_space(None)
sweep_config["metric"]["goal"] = direction
sweep_config["metric"]["name"] = metric
if name:
sweep_config["name"] = name
def _objective():
run = wandb.run if wandb.run else wandb.init()
trainer.state.trial_name = run.name
run.config.update({"assignments": {}, "metric": metric})
config = wandb.config
trainer.objective = None
trainer.train(resume_from_checkpoint=None, trial=vars(config)["_items"])
if getattr(trainer, "objective", None) is None:
metrics = trainer.evaluate()
trainer.objective = trainer.compute_objective(metrics)
format_metrics = rewrite_logs(metrics)
if metric not in format_metrics:
logger.warning(
f"Provided metric {metric} not found. This might result in unexpected sweeps charts. The available"
f" metrics are {format_metrics.keys()}"
)
best_score = False
if best_trial["run_id"] is not None:
if direction == "minimize":
best_score = trainer.objective < best_trial["objective"]
elif direction == "maximize":
best_score = trainer.objective > best_trial["objective"]
if best_score or best_trial["run_id"] is None:
best_trial["run_id"] = run.id
best_trial["objective"] = trainer.objective
best_trial["hyperparameters"] = dict(config)
return trainer.objective
sweep_id = wandb.sweep(sweep_config, project=project, entity=entity) if not sweep_id else sweep_id
logger.info(f"wandb sweep id - {sweep_id}")
wandb.agent(sweep_id, function=_objective, count=n_trials)
return BestRun(best_trial["run_id"], best_trial["objective"], best_trial["hyperparameters"])
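# Minimal sketch (assumed search space): for the wandb backend, `trainer.hp_space(None)` must return a W&B
# sweep configuration with a "metric" key; run_hp_search_wandb above then injects the metric name and goal.
# The parameter names and ranges below are illustrative placeholders.
def wandb_hp_space(trial):
    return {
        "method": "random",  # wandb sweep search strategy
        "metric": {"name": "objective", "goal": "minimize"},  # overwritten with `metric` / `direction` above
        "parameters": {
            "learning_rate": {"distribution": "log_uniform_values", "min": 1e-5, "max": 5e-4},
            "per_device_train_batch_size": {"values": [8, 16, 32]},
        },
    }


# best_run = trainer.hyperparameter_search(hp_space=wandb_hp_space, backend="wandb", direction="minimize", n_trials=8)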
def get_available_reporting_integrations():
integrations = []
if is_azureml_available() and not is_mlflow_available():
integrations.append("azure_ml")
if is_comet_available():
integrations.append("comet_ml")
if is_dagshub_available():
integrations.append("dagshub")
if is_dvclive_available():
integrations.append("dvclive")
if is_mlflow_available():
integrations.append("mlflow")
if is_neptune_available():
integrations.append("neptune")
if is_tensorboard_available():
integrations.append("tensorboard")
if is_wandb_available():
integrations.append("wandb")
if is_codecarbon_available():
integrations.append("codecarbon")
if is_clearml_available():
integrations.append("clearml")
return integrations
def rewrite_logs(d):
new_d = {}
eval_prefix = "eval_"
eval_prefix_len = len(eval_prefix)
test_prefix = "test_"
test_prefix_len = len(test_prefix)
for k, v in d.items():
if k.startswith(eval_prefix):
new_d["eval/" + k[eval_prefix_len:]] = v
elif k.startswith(test_prefix):
new_d["test/" + k[test_prefix_len:]] = v
else:
new_d["train/" + k] = v
return new_d
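# Quick illustration of the key rewriting performed by rewrite_logs (values are made up):
# eval_/test_ prefixes become namespaces, everything else goes under train/.
print(rewrite_logs({"eval_loss": 0.25, "test_accuracy": 0.9, "loss": 0.31}))
# -> {'eval/loss': 0.25, 'test/accuracy': 0.9, 'train/loss': 0.31}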
class TensorBoardCallback(TrainerCallback):
"""
A [`TrainerCallback`] that sends the logs to [TensorBoard](https://www.tensorflow.org/tensorboard).
Args:
tb_writer (`SummaryWriter`, *optional*):
The writer to use. Will instantiate one if not set.
"""
def __init__(self, tb_writer=None):
has_tensorboard = is_tensorboard_available()
if not has_tensorboard:
raise RuntimeError(
"TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or"
" install tensorboardX."
)
if has_tensorboard:
try:
from torch.utils.tensorboard import SummaryWriter
self._SummaryWriter = SummaryWriter
except ImportError:
try:
from tensorboardX import SummaryWriter
self._SummaryWriter = SummaryWriter
except ImportError:
self._SummaryWriter = None
else:
self._SummaryWriter = None
self.tb_writer = tb_writer
def _init_summary_writer(self, args, log_dir=None):
log_dir = log_dir or args.logging_dir
if self._SummaryWriter is not None:
self.tb_writer = self._SummaryWriter(log_dir=log_dir)
def on_train_begin(self, args, state, control, **kwargs):
if not state.is_world_process_zero:
return
log_dir = None
if state.is_hyper_param_search:
trial_name = state.trial_name
if trial_name is not None:
log_dir = os.path.join(args.logging_dir, trial_name)
if self.tb_writer is None:
self._init_summary_writer(args, log_dir)
if self.tb_writer is not None:
self.tb_writer.add_text("args", args.to_json_string())
if "model" in kwargs:
model = kwargs["model"]
if hasattr(model, "config") and model.config is not None:
model_config_json = model.config.to_json_string()
self.tb_writer.add_text("model_config", model_config_json)
def on_log(self, args, state, control, logs=None, **kwargs):
if not state.is_world_process_zero:
return
if self.tb_writer is None:
self._init_summary_writer(args)
if self.tb_writer is not None:
logs = rewrite_logs(logs)
for k, v in logs.items():
if isinstance(v, (int, float)):
self.tb_writer.add_scalar(k, v, state.global_step)
else:
logger.warning(
"Trainer is attempting to log a value of "
f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
"This invocation of Tensorboard's writer.add_scalar() "
"is incorrect so we dropped this attribute."
)
self.tb_writer.flush()
def on_train_end(self, args, state, control, **kwargs):
if self.tb_writer:
self.tb_writer.close()
self.tb_writer = None
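# Minimal sketch (assumed script): TensorBoardCallback is rarely constructed by hand; it is selected
# through TrainingArguments.report_to. Paths below are placeholders.
from transformers import TrainingArguments

tb_args = TrainingArguments(
    output_dir="out",  # placeholder path
    logging_dir="out/tb_logs",  # where _init_summary_writer will create the SummaryWriter
    report_to=["tensorboard"],
    logging_steps=50,
)
# Trainer(..., args=tb_args) then attaches TensorBoardCallback via get_reporting_integration_callbacks below.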
"""
A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
"""
def __init__(self):
has_wandb = is_wandb_available()
if not has_wandb:
raise RuntimeError("WandbCallback requires wandb to be installed. Run `pip install wandb`.")
if has_wandb:
import wandb
self._wandb = wandb
self._initialized = False
if os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}):
DeprecationWarning(
f"Setting `WANDB_LOG_MODEL` as {os.getenv('WANDB_LOG_MODEL')} is deprecated and will be removed in "
"version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead."
)
logger.info(f"Setting `WANDB_LOG_MODEL` from {os.getenv('WANDB_LOG_MODEL')} to `end` instead")
self._log_model = "end"
else:
self._log_model = os.getenv("WANDB_LOG_MODEL", "false").lower()
def on_train_begin(self, args, state, control, model=None, **kwargs):
if self._wandb is None:
return
hp_search = state.is_hyper_param_search
if hp_search:
self._wandb.finish()
self._initialized = False
args.run_name = None
if not self._initialized:
self.setup(args, state, model, **kwargs)
def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
if self._wandb is None:
return
if self._log_model in ("end", "checkpoint") and self._initialized and state.is_world_process_zero:
from ..trainer import Trainer
fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
with tempfile.TemporaryDirectory() as temp_dir:
fake_trainer.save_model(temp_dir)
metadata = (
{
k: v
for k, v in dict(self._wandb.summary).items()
if isinstance(v, numbers.Number) and not k.startswith("_")
}
if not args.load_best_model_at_end
else {
f"eval/{args.metric_for_best_model}": state.best_metric,
"train/total_floss": state.total_flos,
}
)
logger.info("Logging model artifacts. ...")
model_name = (
f"model-{self._wandb.run.id}"
if (args.run_name is None or args.run_name == args.output_dir)
else f"model-{self._wandb.run.name}"
)
artifact = self._wandb.Artifact(name=model_name, type="model", metadata=metadata)
for f in Path(temp_dir).glob("*"):
if f.is_file():
with artifact.new_file(f.name, mode="wb") as fa:
fa.write(f.read_bytes())
self._wandb.run.log_artifact(artifact)
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
single_value_scalars = [
"train_runtime",
"train_samples_per_second",
"train_steps_per_second",
"train_loss",
"total_flos",
]
if self._wandb is None:
return
if not self._initialized:
self.setup(args, state, model)
if state.is_world_process_zero:
for k, v in logs.items():
if k in single_value_scalars:
self._wandb.run.summary[k] = v
non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars}
non_scalar_logs = rewrite_logs(non_scalar_logs)
self._wandb.log({**non_scalar_logs, "train/global_step": state.global_step})
def on_save(self, args, state, control, **kwargs):
if self._log_model == "checkpoint" and self._initialized and state.is_world_process_zero:
checkpoint_metadata = {
k: v
for k, v in dict(self._wandb.summary).items()
if isinstance(v, numbers.Number) and not k.startswith("_")
}
ckpt_dir = f"checkpoint-{state.global_step}"
artifact_path = os.path.join(args.output_dir, ckpt_dir)
logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. ...")
checkpoint_name = (
f"checkpoint-{self._wandb.run.id}"
if (args.run_name is None or args.run_name == args.output_dir)
else f"checkpoint-{self._wandb.run.name}"
)
artifact = self._wandb.Artifact(name=checkpoint_name, type="model", metadata=checkpoint_metadata)
artifact.add_dir(artifact_path)
self._wandb.log_artifact(artifact, aliases=[f"checkpoint-{state.global_step}"])
class CometCallback(TrainerCallback):
"""
A [`TrainerCallback`] that sends the logs to [Comet ML](https://www.comet.ml/site/).
"""
def __init__(self):
if not _has_comet:
raise RuntimeError("CometCallback requires comet-ml to be installed. Run `pip install comet-ml`.")
self._initialized = False
self._log_assets = False
def setup(self, args, state, model):
"""
Setup the optional Comet.ml integration.
Environment:
- **COMET_MODE** (`str`, *optional*, defaults to `ONLINE`):
Whether to create an online, offline experiment or disable Comet logging. Can be `OFFLINE`, `ONLINE`, or
`DISABLED`.
- **COMET_PROJECT_NAME** (`str`, *optional*):
Comet project name for experiments.
- **COMET_OFFLINE_DIRECTORY** (`str`, *optional*):
Folder to use for saving offline experiments when `COMET_MODE` is `OFFLINE`.
- **COMET_LOG_ASSETS** (`str`, *optional*, defaults to `TRUE`):
Whether or not to log training assets (tf event logs, checkpoints, etc), to Comet. Can be `TRUE`, or
`FALSE`.
For a number of configurable items in the environment, see
[here](https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables).
"""
self._initialized = True
log_assets = os.getenv("COMET_LOG_ASSETS", "FALSE").upper()
if log_assets in {"TRUE", "1"}:
self._log_assets = True
if state.is_world_process_zero:
comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
experiment = None
experiment_kwargs = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
if comet_mode == "ONLINE":
experiment = comet_ml.Experiment(**experiment_kwargs)
experiment.log_other("Created from", "transformers")
logger.info("Automatic Comet.ml online logging enabled")
elif comet_mode == "OFFLINE":
experiment_kwargs["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
experiment = comet_ml.OfflineExperiment(**experiment_kwargs)
experiment.log_other("Created from", "transformers")
logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
if experiment is not None:
experiment._set_model_graph(model, framework="transformers")
experiment._log_parameters(args, prefix="args/", framework="transformers")
if hasattr(model, "config"):
experiment._log_parameters(model.config, prefix="config/", framework="transformers")
def on_train_begin(self, args, state, control, model=None, **kwargs):
if not self._initialized:
self.setup(args, state, model)
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
if not self._initialized:
self.setup(args, state, model)
if state.is_world_process_zero:
experiment = comet_ml.config.get_global_experiment()
if experiment is not None:
experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")
def on_train_end(self, args, state, control, **kwargs):
if self._initialized and state.is_world_process_zero:
experiment = comet_ml.config.get_global_experiment()
if experiment is not None:
if self._log_assets is True:
logger.info("Logging checkpoints. This may take time.")
experiment.log_asset_folder(
args.output_dir, recursive=True, log_file_name=True, step=state.global_step
)
experiment.end()
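# Minimal sketch (assumed environment): configuring the Comet integration through the COMET_* environment
# variables documented in CometCallback.setup above. Values are illustrative placeholders.
import os

os.environ["COMET_MODE"] = "OFFLINE"  # or "ONLINE" / "DISABLED"
os.environ["COMET_PROJECT_NAME"] = "my-experiments"  # placeholder project name
os.environ["COMET_OFFLINE_DIRECTORY"] = "./comet_offline"  # used only when COMET_MODE=OFFLINE
os.environ["COMET_LOG_ASSETS"] = "TRUE"  # upload checkpoints / tf event logs at the end of training
# Trainer(..., args=TrainingArguments(report_to=["comet_ml"], ...)) then picks up CometCallback.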
class AzureMLCallback(TrainerCallback):
"""
A [`TrainerCallback`] that sends the logs to [AzureML](https://pypi.org/project/azureml-sdk/).
"""
def __init__(self, azureml_run=None):
if not is_azureml_available():
raise RuntimeError("AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`.")
self.azureml_run = azureml_run
def on_init_end(self, args, state, control, **kwargs):
from azureml.core.run import Run
if self.azureml_run is None and state.is_world_process_zero:
self.azureml_run = Run.get_context()
def on_log(self, args, state, control, logs=None, **kwargs):
if self.azureml_run and state.is_world_process_zero:
for k, v in logs.items():
if isinstance(v, (int, float)):
self.azureml_run.log(k, v, description=k)
class MLflowCallback(TrainerCallback):
"""
A [`TrainerCallback`] that sends the logs to [MLflow](https://www.mlflow.org/). Can be disabled by setting
environment variable `DISABLE_MLFLOW_INTEGRATION = TRUE`.
"""
def __init__(self):
if not is_mlflow_available():
raise RuntimeError("MLflowCallback requires mlflow to be installed. Run `pip install mlflow`.")
import mlflow
self._MAX_PARAM_VAL_LENGTH = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
self._MAX_PARAMS_TAGS_PER_BATCH = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH
self._initialized = False
self._auto_end_run = False
self._log_artifacts = False
self._ml_flow = mlflow
def on_train_begin(self, args, state, control, model=None, **kwargs):
if not self._initialized:
self.setup(args, state, model)
def on_log(self, args, state, control, logs, model=None, **kwargs):
if not self._initialized:
self.setup(args, state, model)
if state.is_world_process_zero:
metrics = {}
for k, v in logs.items():
if isinstance(v, (int, float)):
metrics[k] = v
else:
logger.warning(
f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. '
"MLflow's log_metric() only accepts float and int types so we dropped this attribute."
)
if self._async_log:
self._ml_flow.log_metrics(metrics=metrics, step=state.global_step, synchronous=False)
else:
self._ml_flow.log_metrics(metrics=metrics, step=state.global_step)
def on_train_end(self, args, state, control, **kwargs):
if self._initialized and state.is_world_process_zero:
if self._auto_end_run and self._ml_flow.active_run():
self._ml_flow.end_run()
def on_save(self, args, state, control, **kwargs):
if self._initialized and state.is_world_process_zero and self._log_artifacts:
ckpt_dir = f"checkpoint-{state.global_step}"
artifact_path = os.path.join(args.output_dir, ckpt_dir)
logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. This may take time.")
self._ml_flow.pyfunc.log_model(
ckpt_dir,
artifacts={"model_path": artifact_path},
python_model=self._ml_flow.pyfunc.PythonModel(),
)
def __del__(self):
if (
self._auto_end_run
and callable(getattr(self._ml_flow, "active_run", None))
and self._ml_flow.active_run() is not None
):
self._ml_flow.end_run()
class DagsHubCallback(MLflowCallback):
"""
A [`TrainerCallback`] that logs to [DagsHub](https://dagshub.com/). Extends [`MLflowCallback`]
"""
def __init__(self):
super().__init__()
if not is_dagshub_available():
raise ImportError("DagsHubCallback requires dagshub to be installed. Run `pip install dagshub`.")
from dagshub.upload import Repo
self.Repo = Repo
def setup(self, *args, **kwargs):
"""
Setup the DagsHub's Logging integration.
Environment:
- **HF_DAGSHUB_LOG_ARTIFACTS** (`str`, *optional*):
Whether to save the data and model artifacts for the experiment. Default to `False`.
"""
self.log_artifacts = os.getenv("HF_DAGSHUB_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
self.name = os.getenv("HF_DAGSHUB_MODEL_NAME") or "main"
self.remote = os.getenv("MLFLOW_TRACKING_URI")
self.repo = self.Repo(
owner=self.remote.split(os.sep)[-2],
name=self.remote.split(os.sep)[-1].split(".")[0],
branch=os.getenv("BRANCH") or "main",
)
self.path = Path("artifacts")
if self.remote is None:
raise RuntimeError(
"DagsHubCallback requires the `MLFLOW_TRACKING_URI` environment variable to be set. Did you run"
" `dagshub.init()`?"
)
super().setup(*args, **kwargs)
def on_train_end(self, args, state, control, **kwargs):
if self.log_artifacts:
if getattr(self, "train_dataloader", None):
torch.save(self.train_dataloader.dataset, os.path.join(args.output_dir, "dataset.pt"))
self.repo.directory(str(self.path)).add_dir(args.output_dir)
class NeptuneMissingConfiguration(Exception):
def __init__(self):
super().__init__(
"""
------ Unsupported ---- We were not able to create new runs. You provided a custom Neptune run to
`NeptuneCallback` with the `run` argument. For the integration to work fully, provide your `api_token` and
`project` by saving them as environment variables or passing them to the callback.
"""
)
class NeptuneCallback(TrainerCallback):
"""TrainerCallback that sends the logs to [Neptune](https://app.neptune.ai).
# NeptuneLogger 类定义,用于集成 Transformers 框架与 Neptune 平台
class NeptuneLogger:
# 集成版本键,用于 Neptune 运行日志中的源代码路径
integration_version_key = "source_code/integrations/transformers"
# 模型参数键,用于 Neptune 运行日志中的模型参数
model_parameters_key = "model_parameters"
# 实验名称键,用于 Neptune 运行日志中的实验名称
trial_name_key = "trial"
# 实验参数键,用于 Neptune 运行日志中的实验参数
trial_params_key = "trial_params"
# 训练器参数键,用于 Neptune 运行日志中的训练器参数
trainer_parameters_key = "trainer_parameters"
# 扁平化指标,用于 Neptune 运行日志中的扁平化指标记录
flat_metrics = {"train/epoch"}
# NeptuneCallback constructor
def __init__(
self,
*,
api_token: Optional[str] = None,  # Neptune API token used for authentication
project: Optional[str] = None,  # name of the Neptune project to log to
name: Optional[str] = None,  # custom name for the Neptune run
base_namespace: str = "finetuning",  # root namespace for everything the callback logs
run=None,  # existing Neptune run to resume logging to
log_parameters: bool = True,  # whether to log Trainer and model parameters
log_checkpoints: Optional[str] = None,  # when to upload checkpoint files (e.g. "last" or "best")
**neptune_run_kwargs,  # extra keyword arguments forwarded to Neptune's init_run when a new run is created
):
# Raise if the Neptune client library is not installed
if not is_neptune_available():
raise ValueError(
"NeptuneCallback requires the Neptune client library to be installed. "
"To install the library, run `pip install neptune`."
)
try:
# Import from the current `neptune` package layout
from neptune import Run
from neptune.internal.utils import verify_type
except ImportError:
# Fall back to the legacy `neptune.new` import paths
from neptune.new.internal.utils import verify_type
from neptune.new.metadata_containers.run import Run
# Validate argument types
verify_type("api_token", api_token, (str, type(None)))
verify_type("project", project, (str, type(None)))
verify_type("name", name, (str, type(None)))
verify_type("base_namespace", base_namespace, str)
verify_type("run", run, (Run, type(None)))
verify_type("log_parameters", log_parameters, bool)
verify_type("log_checkpoints", log_checkpoints, (str, type(None)))
# Store configuration
self._base_namespace_path = base_namespace
self._log_parameters = log_parameters
self._log_checkpoints = log_checkpoints
self._initial_run: Optional[Run] = run
# Internal state
self._run = None
self._is_monitoring_run = False
self._run_id = None
self._force_reset_monitoring_run = False
self._init_run_kwargs = {"api_token": api_token, "project": project, "name": name, **neptune_run_kwargs}
self._volatile_checkpoints_dir = None
self._should_upload_checkpoint = self._log_checkpoints is not None
self._recent_checkpoint_path = None
# Pick the target checkpoints namespace and whether previously uploaded checkpoints should be cleaned up
if self._log_checkpoints in {"last", "best"}:
self._target_checkpoints_namespace = f"checkpoints/{self._log_checkpoints}"
self._should_clean_recently_uploaded_checkpoint = True
else:
self._target_checkpoints_namespace = "checkpoints"
self._should_clean_recently_uploaded_checkpoint = False
# Stop and drop the current Neptune run, if any
def _stop_run_if_exists(self):
if self._run:
self._run.stop()
del self._run
self._run = None
# Create (or re-open) a Neptune run
def _initialize_run(self, **additional_neptune_kwargs):
try:
# Import `init_run` and the relevant exceptions from the current package layout
from neptune import init_run
from neptune.exceptions import NeptuneMissingApiTokenException, NeptuneMissingProjectNameException
except ImportError:
# Fall back to the legacy `neptune.new` import paths
from neptune.new import init_run
from neptune.new.exceptions import NeptuneMissingApiTokenException, NeptuneMissingProjectNameException
# Stop any run that is already open
self._stop_run_if_exists()
try:
# Merge the per-call kwargs with the kwargs captured in __init__
run_params = additional_neptune_kwargs.copy()
run_params.update(self._init_run_kwargs)
# Start the run and remember its id
self._run = init_run(**run_params)
self._run_id = self._run["sys/id"].fetch()
except (NeptuneMissingProjectNameException, NeptuneMissingApiTokenException) as e:
# Without a project name or API token, new runs cannot be created
raise NeptuneMissingConfiguration() from e
# Adopt the run passed by the user as the current, monitored run
def _use_initial_run(self):
self._run = self._initial_run
self._is_monitoring_run = True
self._run_id = self._run["sys/id"].fetch()
self._initial_run = None
# Make sure a run with system monitoring enabled is active
def _ensure_run_with_monitoring(self):
if self._initial_run is not None:
# Prefer the run supplied by the user
self._use_initial_run()
else:
if not self._force_reset_monitoring_run and self._is_monitoring_run:
return
if self._run and not self._is_monitoring_run and not self._force_reset_monitoring_run:
# A run exists but without monitoring: re-open it with monitoring enabled
self._initialize_run(with_id=self._run_id)
self._is_monitoring_run = True
else:
# Otherwise start a fresh run
self._initialize_run()
self._force_reset_monitoring_run = False
# Make sure at least a run without monitoring is active
def _ensure_at_least_run_without_monitoring(self):
if self._initial_run is not None:
# Prefer the run supplied by the user
self._use_initial_run()
else:
if not self._run:
# No run yet: open one without capturing stdout, stderr, hardware metrics or tracebacks
self._initialize_run(
with_id=self._run_id,
capture_stdout=False,
capture_stderr=False,
capture_hardware_metrics=False,
capture_traceback=False,
)
self._is_monitoring_run = False
# The currently active Neptune run (created lazily)
@property
def run(self):
if self._run is None:
self._ensure_at_least_run_without_monitoring()
return self._run
# Handle on the base namespace of the run where all metadata is stored
@property
def _metadata_namespace(self):
return self.run[self._base_namespace_path]
# Record the transformers version under the integration key
def _log_integration_version(self):
self.run[NeptuneCallback.integration_version_key] = version
# Log the sanitized TrainingArguments into the metadata namespace
def _log_trainer_parameters(self, args):
self._metadata_namespace[NeptuneCallback.trainer_parameters_key] = args.to_sanitized_dict()
# Log the model config into the metadata namespace
def _log_model_parameters(self, model):
from neptune.utils import stringify_unsupported
if model and hasattr(model, "config") and model.config is not None:
self._metadata_namespace[NeptuneCallback.model_parameters_key] = stringify_unsupported(
model.config.to_dict()
)
# Log the hyperparameter-search trial name and parameters into the metadata namespace
def _log_hyper_param_search_parameters(self, state):
if state and hasattr(state, "trial_name"):
self._metadata_namespace[NeptuneCallback.trial_name_key] = state.trial_name
if state and hasattr(state, "trial_params") and state.trial_params is not None:
self._metadata_namespace[NeptuneCallback.trial_params_key] = state.trial_params
# Upload a checkpoint directory to the run; the signature matches the call in `on_save` below
def _log_model_checkpoint(self, source_directory, checkpoint):
# Join the source directory and checkpoint name into the path to upload
target_path = relative_path = os.path.join(source_directory, checkpoint)
# If a volatile checkpoints directory is used, copy the checkpoint there first
if self._volatile_checkpoints_dir is not None:
consistent_checkpoint_path = os.path.join(self._volatile_checkpoints_dir, checkpoint)
try:
# Strip "../" components and leading path separators from the relative path
cpkt_path = relative_path.replace("..", "").lstrip(os.path.sep)
copy_path = os.path.join(consistent_checkpoint_path, cpkt_path)
# Copy the whole checkpoint tree into the consistent location
shutil.copytree(relative_path, copy_path)
# Upload from the copy instead of the original path
target_path = consistent_checkpoint_path
except IOError as e:
# If the copy fails, warn and fall back to uploading the original path
logger.warning(
"NeptuneCallback was unable to make a copy of checkpoint due to I/O exception: '{}'. "
"Could fail trying to upload.".format(e)
)
# Upload the files under the target path to Neptune
self._metadata_namespace[self._target_checkpoints_namespace].upload_files(target_path)
# If configured, delete the previously uploaded checkpoint
if self._should_clean_recently_uploaded_checkpoint and self._recent_checkpoint_path is not None:
self._metadata_namespace[self._target_checkpoints_namespace].delete_files(self._recent_checkpoint_path)
# Remember the most recently uploaded checkpoint path
self._recent_checkpoint_path = relative_path
def on_init_end(self, args, state, control, **kwargs):
self._volatile_checkpoints_dir = None
# If checkpoints are logged and the output dir may be overwritten or pruned (save_total_limit),
# stage checkpoints in a temporary directory before uploading
if self._log_checkpoints and (args.overwrite_output_dir or args.save_total_limit is not None):
self._volatile_checkpoints_dir = tempfile.TemporaryDirectory().name
# Logging only the best checkpoint requires load_best_model_at_end
if self._log_checkpoints == "best" and not args.load_best_model_at_end:
raise ValueError("To save the best model checkpoint, the load_best_model_at_end argument must be enabled.")
def on_train_begin(self, args, state, control, model=None, **kwargs):
# Only the main process logs
if not state.is_world_process_zero:
return
# Make sure a monitored run is active and force a reset for any later training run
self._ensure_run_with_monitoring()
self._force_reset_monitoring_run = True
# Record the integration version
self._log_integration_version()
# Optionally log Trainer and model parameters
if self._log_parameters:
self._log_trainer_parameters(args)
self._log_model_parameters(model)
# During a hyperparameter search, also log the trial name and parameters
if state.is_hyper_param_search:
self._log_hyper_param_search_parameters(state)
def on_train_end(self, args, state, control, **kwargs):
# Stop the run when training finishes
self._stop_run_if_exists()
def __del__(self):
# Clean up the volatile checkpoints directory, ignoring any errors
if self._volatile_checkpoints_dir is not None:
shutil.rmtree(self._volatile_checkpoints_dir, ignore_errors=True)
# Stop the Neptune run if one is still open
self._stop_run_if_exists()
def on_save(self, args, state, control, **kwargs):
# Upload the checkpoint created for the current global step, if uploading is enabled
if self._should_upload_checkpoint:
self._log_model_checkpoint(args.output_dir, f"checkpoint-{state.global_step}")
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
# When only the best checkpoints are logged, decide after each evaluation whether to upload
if self._log_checkpoints == "best":
best_metric_name = args.metric_for_best_model
if not best_metric_name.startswith("eval_"):
best_metric_name = f"eval_{best_metric_name}"
metric_value = metrics.get(best_metric_name)
# Pick the comparison operator according to greater_is_better
operator = np.greater if args.greater_is_better else np.less
# Upload only if the new metric beats the best metric seen so far
self._should_upload_checkpoint = state.best_metric is None or operator(metric_value, state.best_metric)
# Return the Neptune run of the NeptuneCallback attached to the given trainer
@classmethod
def get_run(cls, trainer):
for callback in trainer.callback_handler.callbacks:
if isinstance(callback, cls):
return callback.run
raise Exception("The trainer doesn't have a NeptuneCallback configured.")
def on_log(self, args, state, control, logs: Optional[Dict[str, float]] = None, **kwargs):
# Only the main process logs
if not state.is_world_process_zero:
return
if logs is not None:
# Rewrite the log keys into train/eval/test namespaces and log numeric values
for name, value in rewrite_logs(logs).items():
if isinstance(value, (int, float)):
if name in NeptuneCallback.flat_metrics:
# Flat metrics are stored as single values
self._metadata_namespace[name] = value
else:
# Everything else is logged as a series indexed by the global step
self._metadata_namespace[name].log(value, step=state.global_step)
# Callback that tracks training CO2 emissions via codecarbon
class CodeCarbonCallback(TrainerCallback):
"""
A [`TrainerCallback`] that tracks the CO2 emission of training.
"""
def __init__(self):
# codecarbon must be installed
if not is_codecarbon_available():
raise RuntimeError(
"CodeCarbonCallback requires `codecarbon` to be installed. Run `pip install codecarbon`."
)
import codecarbon
self._codecarbon = codecarbon
self.tracker = None
def on_init_end(self, args, state, control, **kwargs):
# Create the emissions tracker once, on the local main process only
if self.tracker is None and state.is_local_process_zero:
self.tracker = self._codecarbon.EmissionsTracker(output_dir=args.output_dir)
def on_train_begin(self, args, state, control, model=None, **kwargs):
# Start tracking when training starts
if self.tracker and state.is_local_process_zero:
self.tracker.start()
def on_train_end(self, args, state, control, **kwargs):
# Stop the tracker when training ends
if self.tracker and state.is_local_process_zero:
self.tracker.stop()
# Callback that integrates the Trainer with ClearML
class ClearMLCallback(TrainerCallback):
"""
A [`TrainerCallback`] that sends the logs to [ClearML](https://clear.ml/).
Environment:
- **CLEARML_PROJECT** (`str`, *optional*, defaults to `HuggingFace Transformers`):
ClearML project name.
- **CLEARML_TASK** (`str`, *optional*, defaults to `Trainer`):
ClearML task name.
- **CLEARML_LOG_MODEL** (`bool`, *optional*, defaults to `False`):
Whether to log models as artifacts during training.
"""
# Class-level attributes shared across ClearMLCallback instances
log_suffix = ""
_hparams_section = "Transformers"
_model_config_section = "Model Configuration"
_ignore_hparams_overrides = "_ignore_hparams_ui_overrides_"
_ignoge_model_config_overrides = "_ignore_model_config_ui_overrides_"
_model_config_description = "The configuration of model number {}."
_model_config_description_note = (
"Note that, when cloning this task and running it remotely,"
" the configuration might be applied to another model instead of this one."
" To avoid this, initialize the task externally by calling `Task.init`"
" before the `ClearMLCallback` is instantiated."
)
_train_run_counter = 0
_model_connect_counter = 0
_task_created_in_callback = False
_should_close_on_train_end = None
def __init__(self):
# clearml must be installed
if is_clearml_available():
import clearml
self._clearml = clearml
else:
raise RuntimeError("ClearMLCallback requires 'clearml' to be installed. Run `pip install clearml`.")
self._initialized = False
self._clearml_task = None
self._log_model = False
self._checkpoints_saved = []
def on_train_begin(self, args, state, control, model=None, tokenizer=None, **kwargs):
# Skip if clearml could not be imported
if self._clearml is None:
return
# Reset the list of checkpoints saved during this run
self._checkpoints_saved = []
# During a hyperparameter search, force a fresh setup for every trial
if state.is_hyper_param_search:
self._initialized = False
if not self._initialized:
self.setup(args, state, model, tokenizer, **kwargs)
def on_train_end(self, args, state, control, **kwargs):
# Close the ClearML task if this callback created it, and reset the run counter
if ClearMLCallback._should_close_on_train_end:
self._clearml_task.close()
ClearMLCallback._train_run_counter = 0
def on_log(self, args, state, control, model=None, tokenizer=None, logs=None, **kwargs):
# Skip if clearml could not be imported
if self._clearml is None:
return
if not self._initialized:
self.setup(args, state, model, tokenizer, **kwargs)
# Only the main process reports to ClearML
if state.is_world_process_zero:
eval_prefix = "eval_"
eval_prefix_len = len(eval_prefix)
test_prefix = "test_"
test_prefix_len = len(test_prefix)
# Keys that are reported as single values rather than as series
single_value_scalars = [
"train_runtime",
"train_samples_per_second",
"train_steps_per_second",
"train_loss",
"total_flos",
"epoch",
]
for k, v in logs.items():
if isinstance(v, (int, float)):
# One-off summary values
if k in single_value_scalars:
self._clearml_task.get_logger().report_single_value(
name=k + ClearMLCallback.log_suffix, value=v
)
# Evaluation metrics go into the "eval" plot
elif k.startswith(eval_prefix):
self._clearml_task.get_logger().report_scalar(
title="eval" + ClearMLCallback.log_suffix,
series=k[eval_prefix_len:],
value=v,
iteration=state.global_step,
)
# Test metrics go into the "test" plot
elif k.startswith(test_prefix):
self._clearml_task.get_logger().report_scalar(
title="test" + ClearMLCallback.log_suffix,
series=k[test_prefix_len:],
value=v,
iteration=state.global_step,
)
# Everything else is reported as a training metric
else:
self._clearml_task.get_logger().report_scalar(
title="train" + ClearMLCallback.log_suffix,
series=k,
value=v,
iteration=state.global_step,
)
else:
# Non-numeric values cannot be reported as scalars
logger.warning(
"Trainer is attempting to log a value of "
f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
"This invocation of ClearML logger's report_scalar() "
"is incorrect so we dropped this attribute."
)
def on_save(self, args, state, control, **kwargs):
# Upload the checkpoint as a ClearML OutputModel when model logging is enabled
if self._log_model and self._clearml_task and state.is_world_process_zero:
ckpt_dir = f"checkpoint-{state.global_step}"
artifact_path = os.path.join(args.output_dir, ckpt_dir)
name = ckpt_dir + ClearMLCallback.log_suffix
logger.info(f"Logging checkpoint artifact `{name}`. This may take some time.")
# Create an OutputModel attached to the current task and upload the checkpoint directory
output_model = self._clearml.OutputModel(task=self._clearml_task, name=name)
output_model.connect(task=self._clearml_task, name=name)
output_model.update_weights_package(
weights_path=artifact_path,
target_filename=ckpt_dir,
iteration=state.global_step,
auto_delete_file=False,
)
self._checkpoints_saved.append(output_model)
# Enforce save_total_limit by removing the oldest uploaded checkpoints
while args.save_total_limit and args.save_total_limit < len(self._checkpoints_saved):
try:
self._clearml.model.Model.remove(
self._checkpoints_saved[0],
delete_weights_file=True,
force=True,
raise_on_errors=True,
)
except Exception as e:
# Warn and stop pruning if a checkpoint cannot be removed
logger.warning(
"Could not remove checkpoint `{}` after going over the `save_total_limit`. Error is: {}".format(
self._checkpoints_saved[0].name, e
)
)
break
self._checkpoints_saved = self._checkpoints_saved[1:]
def _copy_training_args_as_hparams(self, training_args, prefix):
# Convert the TrainingArguments to a dict, skipping fields that end with "_token"
as_dict = {
field.name: getattr(training_args, field.name)
for field in fields(training_args)
if field.init and not field.name.endswith("_token")
}
# Flatten the dict and stringify the keys
flat_dict = {str(k): v for k, v in self._clearml.utilities.proxy_object.flatten_dictionary(as_dict).items()}
# Register the flattened dict as hyperparameters on the ClearML task
self._clearml_task._arguments.copy_from_dict(flat_dict, prefix=prefix)
# Callback that integrates the Trainer with Flyte
class FlyteCallback(TrainerCallback):
"""
A [`TrainerCallback`] that sends the logs to [Flyte](https://flyte.org/).
NOTE: This callback only works within a Flyte task.
Args:
save_log_history (`bool`, *optional*, defaults to `True`):
When set to True, the training logs are saved as a Flyte Deck.
sync_checkpoints (`bool`, *optional*, defaults to `True`):
When set to True, checkpoints are synced with Flyte and can be used to resume training in the case of an
interruption.
Example:
```
from flytekit import current_context, task
@task
def train_hf_transformer():
cp = current_context().checkpoint
trainer = Trainer(..., callbacks=[FlyteCallback()])
output = trainer.train(resume_from_checkpoint=cp.restore())
```
"""
def __init__(self, save_log_history: bool = True, sync_checkpoints: bool = True):
super().__init__()
# flytekit must be installed
if not is_flytekit_available():
raise ImportError("FlyteCallback requires flytekit to be installed. Run `pip install flytekit`.")
# Saving the log history additionally requires flytekitplugins-deck-standard and pandas
if not is_flyte_deck_standard_available() or not is_pandas_available():
logger.warning(
"Syncing log history requires both flytekitplugins-deck-standard and pandas to be installed. "
"Run `pip install flytekitplugins-deck-standard pandas` to enable this feature."
)
save_log_history = False
# Grab the checkpoint handle of the current Flyte task context
from flytekit import current_context
self.cp = current_context().checkpoint
self.save_log_history = save_log_history
self.sync_checkpoints = sync_checkpoints
def on_save(self, args, state, control, **kwargs):
# Sync the newly saved checkpoint to Flyte from the main process
if self.sync_checkpoints and state.is_world_process_zero:
ckpt_dir = f"checkpoint-{state.global_step}"
artifact_path = os.path.join(args.output_dir, ckpt_dir)
logger.info(f"Syncing checkpoint in {ckpt_dir} to Flyte. This may take time.")
self.cp.save(artifact_path)
def on_train_end(self, args, state, control, **kwargs):
if self.save_log_history:
import pandas as pd
from flytekit import Deck
from flytekitplugins.deck.renderer import TableRenderer
# Render the Trainer log history as a table in a Flyte Deck
log_history_df = pd.DataFrame(state.log_history)
Deck("Log History", TableRenderer().to_html(log_history_df))
# Annotations for DVCLiveCallback are omitted here
class DVCLiveCallback(TrainerCallback):
"""
A [`TrainerCallback`] that sends the logs to [DVCLive](https://www.dvc.org/doc/dvclive).
Use the environment variables below in `setup` to configure the integration. To customize this callback beyond
those environment variables, see [here](https://dvc.org/doc/dvclive/ml-frameworks/huggingface).
Args:
live (`dvclive.Live`, *optional*, defaults to `None`):
Optional Live instance. If None, a new instance will be created using **kwargs.
log_model (Union[Literal["all"], bool], *optional*, defaults to `None`):
Whether to use `dvclive.Live.log_artifact()` to log checkpoints created by [`Trainer`]. If set to `True`,
the final checkpoint is logged at the end of training. If set to `"all"`, the entire
[`TrainingArguments`]'s `output_dir` is logged at each checkpoint.
"""
def __init__(
self,
live: Optional[Any] = None,
log_model: Optional[Union[Literal["all"], bool]] = None,
**kwargs,
):
if not is_dvclive_available():
raise RuntimeError("DVCLiveCallback requires dvclive to be installed. Run `pip install dvclive`.")
from dvclive import Live
self._initialized = False
self.live = None
if isinstance(live, Live):
self.live = live
elif live is not None:
raise RuntimeError(f"Found class {live.__class__} for live, expected dvclive.Live")
self._log_model = log_model
if self._log_model is None:
log_model_env = os.getenv("HF_DVCLIVE_LOG_MODEL", "FALSE")
if log_model_env.upper() in ENV_VARS_TRUE_VALUES:
self._log_model = True
elif log_model_env.lower() == "all":
self._log_model = "all"
def setup(self, args, state, model):
"""
Setup the optional DVCLive integration. To customize this callback beyond the environment variables below, see
[here](https://dvc.org/doc/dvclive/ml-frameworks/huggingface).
Environment:
- **HF_DVCLIVE_LOG_MODEL** (`str`, *optional*):
Whether to use `dvclive.Live.log_artifact()` to log checkpoints created by [`Trainer`]. If set to `True` or
*1*, the final checkpoint is logged at the end of training. If set to `all`, the entire
[`TrainingArguments`]'s `output_dir` is logged at each checkpoint.
"""
from dvclive import Live
self._initialized = True
if state.is_world_process_zero:
if not self.live:
self.live = Live()
self.live.log_params(args.to_dict())
def on_train_begin(self, args, state, control, model=None, **kwargs):
if not self._initialized:
self.setup(args, state, model)
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
if not self._initialized:
self.setup(args, state, model)
if state.is_world_process_zero:
from dvclive.plots import Metric
from dvclive.utils import standardize_metric_name
for key, value in logs.items():
if Metric.could_log(value):
self.live.log_metric(standardize_metric_name(key, "dvclive.huggingface"), value)
else:
logger.warning(
"Trainer is attempting to log a value of "
f'"{value}" of type {type(value)} for key "{key}" as a scalar. '
"This invocation of DVCLive's Live.log_metric() "
"is incorrect so we dropped this attribute."
)
self.live.next_step()
def on_save(self, args, state, control, **kwargs):
if self._log_model == "all" and self._initialized and state.is_world_process_zero:
self.live.log_artifact(args.output_dir)
def on_train_end(self, args, state, control, **kwargs):
if self._initialized and state.is_world_process_zero:
from transformers.trainer import Trainer
if self._log_model is True:
fake_trainer = Trainer(args=args, model=kwargs.get("model"), tokenizer=kwargs.get("tokenizer"))
name = "best" if args.load_best_model_at_end else "last"
output_dir = os.path.join(args.output_dir, name)
fake_trainer.save_model(output_dir)
self.live.log_artifact(output_dir, name=name, type="model", copy=True)
self.live.end()
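# Minimal sketch (assumed setup): enabling DVCLive logging and final-model artifact logging through
# report_to plus the HF_DVCLIVE_LOG_MODEL environment variable. Paths are illustrative placeholders.
import os

from transformers import TrainingArguments

os.environ["HF_DVCLIVE_LOG_MODEL"] = "true"  # log the final checkpoint via Live.log_artifact at train end
dvc_args = TrainingArguments(output_dir="out", report_to=["dvclive"], logging_steps=50)
# Trainer(..., args=dvc_args) will attach DVCLiveCallback through INTEGRATION_TO_CALLBACK below.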
INTEGRATION_TO_CALLBACK = {
"azure_ml": AzureMLCallback,
"comet_ml": CometCallback,
"mlflow": MLflowCallback,
"neptune": NeptuneCallback,
"tensorboard": TensorBoardCallback,
"wandb": WandbCallback,
"codecarbon": CodeCarbonCallback,
"clearml": ClearMLCallback,
"dagshub": DagsHubCallback,
"flyte": FlyteCallback,
"dvclive": DVCLiveCallback,
}
def get_reporting_integration_callbacks(report_to):
for integration in report_to:
if integration not in INTEGRATION_TO_CALLBACK:
raise ValueError(
f"{integration} is not supported, only {', '.join(INTEGRATION_TO_CALLBACK.keys())} are supported."
)
return [INTEGRATION_TO_CALLBACK[integration] for integration in report_to]
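# Short illustration of how the Trainer resolves report_to values into callback classes (made-up input):
callback_classes = get_reporting_integration_callbacks(["tensorboard", "wandb"])
# -> [TensorBoardCallback, WandbCallback]; an unknown string such as "foo" raises ValueError.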
.\integrations\peft.py
import inspect
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from ..utils import (
check_peft_version,
find_adapter_config_file,
is_accelerate_available,
is_peft_available,
is_torch_available,
logging,
)
if is_accelerate_available():
from accelerate import dispatch_model
from accelerate.utils import get_balanced_memory, infer_auto_device_map
MIN_PEFT_VERSION = "0.5.0"
if TYPE_CHECKING:
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class PeftAdapterMixin:
"""
A class containing all functions for loading and using adapter weights that are supported by the PEFT library.
For more details about adapters and how to inject them into a transformer-based model, check out the PEFT
library documentation: https://huggingface.co/docs/peft/index
Currently supported PEFT methods are all non-prefix-tuning methods. Below is the list of supported PEFT methods
that can be loaded, trained and run with this mixin class:
- Low Rank Adapters (LoRA): https://huggingface.co/docs/peft/conceptual_guides/lora
- IA3: https://huggingface.co/docs/peft/conceptual_guides/ia3
- AdaLora: https://arxiv.org/abs/2303.10512
Other PEFT models such as prompt tuning or prompt learning are out of scope, as their adapters cannot be
"injected" into a torch module. To use those methods, please refer to the usage guide of the PEFT library.
With this mixin, if the correct PEFT version is installed, it is possible to:
- Load an adapter stored on a local path or on a remote Hub repository and inject it into the model
- Attach new adapters to the model and train them with Trainer or your own methods
- Attach multiple adapters and iteratively activate / deactivate them
- Activate / deactivate all adapters in the model
- Get the `state_dict` of the active adapter
"""
_hf_peft_config_loaded = False
def load_adapter(
self,
peft_model_id: Optional[str] = None,
adapter_name: Optional[str] = None,
revision: Optional[str] = None,
token: Optional[str] = None,
device_map: Optional[str] = "auto",
max_memory: Optional[str] = None,
offload_folder: Optional[str] = None,
offload_index: Optional[int] = None,
peft_config: Dict[str, Any] = None,
adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
adapter_kwargs: Optional[Dict[str, Any]] = None,
):
r"""
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft
Adds a fresh new adapter to the current model for training purpose. If no adapter name is passed, a default
name is assigned to the adapter to follow the convention of PEFT library (in PEFT we use "default" as the
default adapter name).
Args:
adapter_config (`~peft.PeftConfig`):
The configuration of the adapter to add, supported adapters are non-prefix tuning and adaption prompts
methods
adapter_name (`str`, *optional*, defaults to `"default"`):
The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
from peft import PeftConfig, inject_adapter_in_model
adapter_name = adapter_name or "default"
if not self._hf_peft_config_loaded:
self._hf_peft_config_loaded = True
elif adapter_name in self.peft_config:
raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
if not isinstance(adapter_config, PeftConfig):
raise ValueError(
f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
)
adapter_config.base_model_name_or_path = self.__dict__.get("name_or_path", None)
inject_adapter_in_model(adapter_config, self, adapter_name)
self.set_adapter(adapter_name)
def set_adapter(self, adapter_name: Union[List[str], str]) -> None:
"""
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft
Sets a specific adapter by forcing the model to use that adapter and disable the other adapters.
Args:
adapter_name (`Union[List[str], str]`):
The name of the adapter to set. Can be also a list of strings to set multiple adapters.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
elif isinstance(adapter_name, list):
missing = set(adapter_name) - set(self.peft_config)
if len(missing) > 0:
raise ValueError(
f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
f" current loaded adapters are: {list(self.peft_config.keys())}"
)
elif adapter_name not in self.peft_config:
raise ValueError(
f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}"
)
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import ModulesToSaveWrapper
_adapters_has_been_set = False
for _, module in self.named_modules():
if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_name)
else:
module.active_adapter = adapter_name
_adapters_has_been_set = True
if not _adapters_has_been_set:
raise ValueError(
"Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
)
def disable_adapters(self) -> None:
r"""
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft
Disable all adapters that are attached to the model. This leads to inferring with the base model only.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import ModulesToSaveWrapper
for _, module in self.named_modules():
if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=False)
else:
module.disable_adapters = True
def enable_adapters(self) -> None:
"""
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft
Enable adapters that are attached to the model. The model will use `self.active_adapter()`
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
from peft.tuners.tuners_utils import BaseTunerLayer
for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=True)
else:
module.disable_adapters = False
def active_adapters(self) -> List[str]:
"""
Gets the current list of active adapters of the model. In case of multi-adapter inference (combining multiple
adapters for inference), this returns the list of all active adapters so that users can deal with them accordingly.
For previous PEFT versions that do not support multi-adapter inference, `module.active_adapter` returns a single string.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not is_peft_available():
raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
from peft.tuners.tuners_utils import BaseTunerLayer
for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
active_adapters = module.active_adapter
break
if isinstance(active_adapters, str):
active_adapters = [active_adapters]
return active_adapters
def active_adapter(self) -> str:
"""
Warning: the `active_adapter` method is deprecated and will be removed in a future version.
"""
warnings.warn(
"The `active_adapter` method is deprecated and will be removed in a future version.", FutureWarning
)
return self.active_adapters()[0]
def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
"""
Gets the adapter state dict, which should only contain the weight tensors of the specified adapter. If no
adapter name is passed, the active adapter is used.
Args:
adapter_name (`str`, *optional*):
The name of the adapter to get the state dict from. If no name is passed, the active adapter is used.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
from peft import get_peft_model_state_dict
if adapter_name is None:
adapter_name = self.active_adapter()
adapter_state_dict = get_peft_model_state_dict(self, adapter_name=adapter_name)
return adapter_state_dict
def _dispatch_accelerate_model(
self,
device_map: str,
max_memory: Optional[int] = None,
offload_folder: Optional[str] = None,
offload_index: Optional[int] = None,
) -> None:
"""
Optional re-dispatch the model and attach new hooks to the model in case the model has been loaded with
accelerate (i.e. with `device_map=xxx`)
Args:
device_map (`str` or `Dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be refined to each
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
same device. If we only pass the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank
like `1`) on which the model will be allocated, the device map will map the entire model to this
device. Passing `device_map = 0` means put the whole model on GPU 0.
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
more information about each option see [designing a device
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
max_memory (`Dict`, *optional*):
A dictionary mapping device identifiers to maximum memory. Will default to the maximum memory available for each
GPU and the available CPU RAM if unset.
offload_folder (`str` or `os.PathLike`, *optional*):
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
offload_index (`int`, *optional*):
The `offload_index` argument to be passed to the `accelerate.dispatch_model` method.
"""
dispatch_model_kwargs = {}
if "offload_index" in inspect.signature(dispatch_model).parameters:
dispatch_model_kwargs["offload_index"] = offload_index
no_split_module_classes = self._no_split_modules
if device_map != "sequential":
max_memory = get_balanced_memory(
self,
max_memory=max_memory,
no_split_module_classes=no_split_module_classes,
low_zero=(device_map == "balanced_low_0"),
)
if isinstance(device_map, str):
device_map = infer_auto_device_map(
self, max_memory=max_memory, no_split_module_classes=no_split_module_classes
)
dispatch_model(
self,
device_map=device_map,
offload_dir=offload_folder,
**dispatch_model_kwargs,
)
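Seen from the Accelerate side, the method boils down to the three calls sketched below (a standalone sketch, assuming `model` is an already-instantiated `PreTrainedModel` and that offloaded weights, if any, go to a local `"offload"` folder):
```
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory

# Roughly the path taken above when `device_map` is a string such as "auto":
max_memory = get_balanced_memory(
    model, no_split_module_classes=model._no_split_modules
)
device_map = infer_auto_device_map(
    model, max_memory=max_memory, no_split_module_classes=model._no_split_modules
)
dispatch_model(model, device_map=device_map, offload_dir="offload")
```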
.\integrations\quanto.py
from ..utils import is_torch_available
if is_torch_available():
import torch
def replace_with_quanto_layers(
model,
quantization_config=None,
modules_to_not_convert=None,
current_key_name=None,
has_been_replaced=False,
):
"""
Public method that recursively replaces the Linear layers of the given model with Quanto quantized layers.
Returns the converted model and a boolean that indicates if the conversion has been successful or not.
Args:
model (`torch.nn.Module`):
The model to convert, can be any `torch.nn.Module` instance.
quantization_config (`QuantoConfig`, defaults to `None`):
The quantization config object that contains the quantization parameters.
modules_to_not_convert (`list`, *optional*, defaults to `None`):
A list of modules to not convert. If a module name is in the list (e.g. `lm_head`), it will not be
converted.
current_key_name (`list`, *optional*, defaults to `None`):
A list that contains the current key name. This is used for recursion and should not be passed by the user.
has_been_replaced (`bool`, *optional*, defaults to `False`):
A boolean that indicates if the conversion has been successful or not. This is used for recursion and
should not be passed by the user.
"""
from accelerate import init_empty_weights
from quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8
w_mapping = {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}
a_mapping = {None: None, "float8": qfloat8, "int8": qint8}
if modules_to_not_convert is None:
modules_to_not_convert = []
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
with init_empty_weights():
if isinstance(module, torch.nn.Linear):
model._modules[name] = QLinear(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
dtype=module.weight.dtype,
weights=w_mapping[quantization_config.weights],
activations=a_mapping[quantization_config.activations],
)
model._modules[name].requires_grad_(False)
has_been_replaced = True
elif isinstance(module, torch.nn.LayerNorm):
if quantization_config.activations is not None:
model._modules[name] = QLayerNorm(
module.normalized_shape,
module.eps,
module.elementwise_affine,
module.bias is not None,
activations=a_mapping[quantization_config.activations],
)
has_been_replaced = True
if len(list(module.children())) > 0:
_, has_been_replaced = replace_with_quanto_layers(
module,
quantization_config=quantization_config,
modules_to_not_convert=modules_to_not_convert,
current_key_name=current_key_name,
has_been_replaced=has_been_replaced,
)
current_key_name.pop(-1)
return model, has_been_replaced
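In practice this replacement function is not called directly; it is driven by the quanto quantizer when a `QuantoConfig` is passed to `from_pretrained`. A minimal sketch (the checkpoint name is a placeholder, and `quanto` plus `accelerate` must be installed):
```
from transformers import AutoModelForCausalLM, QuantoConfig

# Quantize weights to int8; activations stay in full precision (activations=None).
quant_config = QuantoConfig(weights="int8")
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",                 # placeholder checkpoint
    quantization_config=quant_config,
)
```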
.\integrations\tpu.py
from torch.utils.data import DataLoader
from ..utils import is_torch_xla_available
def tpu_spmd_dataloader(dataloader: DataLoader):
if is_torch_xla_available():
import torch_xla.distributed.parallel_loader as pl
assert isinstance(
dataloader, pl.MpDeviceLoader
), "The dataloader must be a `torch_xla.distributed.parallel_loader.MpDeviceLoader`."
import torch_xla.distributed.spmd as xs
sharding_spec = xs.ShardingSpec(xs.get_global_mesh(), ("fsdp", None))
dataloader._parallel_loader_kwargs["input_sharding"] = sharding_spec
return dataloader
else:
return dataloader
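This helper only has an effect inside an XLA SPMD setup; elsewhere it returns the dataloader untouched. A rough sketch of the expected call site, assuming a TPU environment with `torch_xla`, an SPMD global mesh already configured, and an ordinary `train_dataloader` built elsewhere (placeholder name):
```
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()
# Wrap the torch DataLoader so batches are moved to the XLA device.
device_loader = pl.MpDeviceLoader(train_dataloader, device)
# Attach the ("fsdp", None) input sharding spec; requires the global SPMD mesh to be set.
device_loader = tpu_spmd_dataloader(device_loader)
```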
.\integrations\__init__.py
from typing import TYPE_CHECKING
from ..utils import _LazyModule
_import_structure = {
"aqlm": ["replace_with_aqlm_linear"],
"awq": [
"fuse_awq_modules",
"post_init_awq_exllama_modules",
"replace_with_awq_linear",
],
"bitsandbytes": [
"get_keys_to_not_convert",
"replace_8bit_linear",
"replace_with_bnb_linear",
"set_module_8bit_tensor_to_device",
"set_module_quantized_tensor_to_device",
],
"deepspeed": [
"HfDeepSpeedConfig",
"HfTrainerDeepSpeedConfig",
"deepspeed_config",
"deepspeed_init",
"deepspeed_load_checkpoint",
"deepspeed_optim_sched",
"is_deepspeed_available",
"is_deepspeed_zero3_enabled",
"set_hf_deepspeed_config",
"unset_hf_deepspeed_config",
],
"integration_utils": [
"INTEGRATION_TO_CALLBACK",
"AzureMLCallback",
"ClearMLCallback",
"CodeCarbonCallback",
"CometCallback",
"DagsHubCallback",
"DVCLiveCallback",
"FlyteCallback",
"MLflowCallback",
"NeptuneCallback",
"NeptuneMissingConfiguration",
"TensorBoardCallback",
"WandbCallback",
"get_available_reporting_integrations",
"get_reporting_integration_callbacks",
"hp_params",
"is_azureml_available",
"is_clearml_available",
"is_codecarbon_available",
"is_comet_available",
"is_dagshub_available",
"is_dvclive_available",
"is_flyte_deck_standard_available",
"is_flytekit_available",
"is_mlflow_available",
"is_neptune_available",
"is_optuna_available",
"is_ray_available",
"is_ray_tune_available",
"is_sigopt_available",
"is_tensorboard_available",
"is_wandb_available",
"rewrite_logs",
"run_hp_search_optuna",
"run_hp_search_ray",
"run_hp_search_sigopt",
"run_hp_search_wandb",
],
"peft": ["PeftAdapterMixin"],
"quanto": ["replace_with_quanto_layers"],
}
if TYPE_CHECKING:
from .aqlm import replace_with_aqlm_linear
from .awq import (
fuse_awq_modules,
post_init_awq_exllama_modules,
replace_with_awq_linear,
)
from .bitsandbytes import (
get_keys_to_not_convert,
replace_8bit_linear,
replace_with_bnb_linear,
set_module_8bit_tensor_to_device,
set_module_quantized_tensor_to_device
)
from .deepspeed import (
HfDeepSpeedConfig,
HfTrainerDeepSpeedConfig,
deepspeed_config,
deepspeed_init,
deepspeed_load_checkpoint,
deepspeed_optim_sched,
is_deepspeed_available,
is_deepspeed_zero3_enabled,
set_hf_deepspeed_config,
unset_hf_deepspeed_config
)
from .integration_utils import (
INTEGRATION_TO_CALLBACK,
AzureMLCallback,
ClearMLCallback,
CodeCarbonCallback,
CometCallback,
DagsHubCallback,
DVCLiveCallback,
FlyteCallback,
MLflowCallback,
NeptuneCallback,
NeptuneMissingConfiguration,
TensorBoardCallback,
WandbCallback,
get_available_reporting_integrations,
get_reporting_integration_callbacks,
hp_params,
is_azureml_available,
is_clearml_available,
is_codecarbon_available,
is_comet_available,
is_dagshub_available,
is_dvclive_available,
is_flyte_deck_standard_available,
is_flytekit_available,
is_mlflow_available,
is_neptune_available,
is_optuna_available,
is_ray_available,
is_ray_tune_available,
is_sigopt_available,
is_tensorboard_available,
is_wandb_available,
rewrite_logs,
run_hp_search_optuna,
run_hp_search_ray,
run_hp_search_sigopt,
run_hp_search_wandb
)
from .peft import PeftAdapterMixin
from .quanto import replace_with_quanto_layers
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
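The `_LazyModule` indirection means none of these integration back-ends are imported until an attribute is actually accessed, so the import below stays cheap even if, say, wandb or deepspeed are not installed:
```
# Nothing heavy is imported at this point -- only the lazy module skeleton.
from transformers import integrations

# The underlying sub-module is resolved on first attribute access.
if integrations.is_wandb_available():
    from transformers.integrations import WandbCallback
```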
.\keras_callbacks.py
import logging
import os
from pathlib import Path
from time import sleep
from typing import Callable, List, Optional, Union
import numpy as np
import tensorflow as tf
from huggingface_hub import Repository, create_repo
from packaging.version import parse
from . import IntervalStrategy, PreTrainedTokenizerBase
from .modelcard import TrainingSummary
from .modeling_tf_utils import keras
logger = logging.getLogger(__name__)
class KerasMetricCallback(keras.callbacks.Callback):
"""
Callback to compute metrics at the end of every epoch. Unlike normal Keras metrics, these do not need to be
compilable by TF. It is particularly useful for common NLP metrics like BLEU and ROUGE that require string
operations or generation loops that cannot be compiled. Predictions (or generations) will be computed on the
`eval_dataset` before being passed to the `metric_fn` in `np.ndarray` format. The `metric_fn` should compute
metrics and return a dict mapping metric names to metric values.
We provide an example of a suitable metric_fn that computes ROUGE scores for a summarization model below. Note that
this example skips some post-processing for readability and simplicity, and should probably not be used as-is!
```
from datasets import load_metric
rouge_metric = load_metric("rouge")
def rouge_fn(predictions, labels):
decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
result = rouge_metric.compute(predictions=decoded_predictions, references=decoded_labels)
return {key: value.mid.fmeasure * 100 for key, value in result.items()}
```
The above function will return a dict containing values which will be logged like any other Keras metric:
```
{'rouge1': 37.4199, 'rouge2': 13.9768, 'rougeL': 34.361, 'rougeLsum': 35.0781}
```
"""
def __init__(
self,
metric_fn: Callable,
eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
output_cols: Optional[List[str]] = None,
label_cols: Optional[List[str]] = None,
batch_size: Optional[int] = None,
predict_with_generate: bool = False,
use_xla_generation: bool = False,
generate_kwargs: Optional[dict] = None,
):
super().__init__()
self.metric_fn = metric_fn
self.batch_size = batch_size
if not isinstance(eval_dataset, tf.data.Dataset):
if batch_size is None:
raise ValueError(
"When passing data to KerasMetricCallback that is not a pre-batched tf.data.Dataset "
"the batch_size argument must be set."
)
eval_dataset = tf.data.Dataset.from_tensor_slices(eval_dataset).batch(batch_size, drop_remainder=False)
self.eval_dataset = eval_dataset
self.predict_with_generate = predict_with_generate
self.output_cols = output_cols
if isinstance(eval_dataset.element_spec, tuple) and len(eval_dataset.element_spec) == 2:
input_spec, label_spec = eval_dataset.element_spec
else:
input_spec = eval_dataset.element_spec
label_spec = None
if label_cols is not None:
for label in label_cols:
if label not in input_spec:
raise ValueError(f"Label {label} is in label_cols but could not be found in the dataset inputs!")
self.label_cols = label_cols
self.use_keras_label = False
elif label_spec is not None:
self.label_cols = None
self.use_keras_label = True
elif "labels" in input_spec:
self.label_cols = ["labels"]
self.use_keras_label = False
logging.warning("No label_cols specified for KerasMetricCallback, assuming you want the 'labels' key.")
elif "start_positions" in input_spec and "end_positions" in input_spec:
self.label_cols = ["start_positions", "end_positions"]
self.use_keras_label = False
logging.warning(
"No label_cols specified for KerasMetricCallback, assuming you want the "
"start_positions and end_positions keys."
)
else:
raise ValueError("Could not autodetect label_cols for KerasMetricCallback, please specify them!")
if parse(tf.__version__) < parse("2.7"):
logging.warning("TF versions less than 2.7 may encounter issues with KerasMetricCallback!")
self.use_xla_generation = use_xla_generation
self.generate_kwargs = {} if generate_kwargs is None else generate_kwargs
self.generation_function = None
@staticmethod
def _concatenate_batches(batches, padding_index=-100):
if batches[0].ndim == 1 or all(batch.shape[1] == batches[0].shape[1] for batch in batches):
return np.concatenate(batches, axis=0)
max_len = max([batch.shape[1] for batch in batches])
num_samples = sum([batch.shape[0] for batch in batches])
output = np.full_like(
batches[0], fill_value=padding_index, shape=[num_samples, max_len] + list(batches[0].shape[2:])
)
i = 0
for batch in batches:
output[i : i + len(batch), : batch.shape[1]] = batch
i += len(batch)
return output
def _postprocess_predictions_or_labels(self, inputs):
if isinstance(inputs[0], dict):
outputs = {}
for key in inputs[0].keys():
outputs[key] = self._concatenate_batches([batch[key] for batch in inputs])
if len(outputs) == 1:
outputs = list(outputs.values())[0]
elif isinstance(inputs[0], list) or isinstance(inputs[0], tuple):
outputs = []
for input_list in zip(*inputs):
outputs.append(self._concatenate_batches(input_list))
if len(outputs) == 1:
outputs = outputs[0]
elif isinstance(inputs[0], np.ndarray):
outputs = self._concatenate_batches(inputs)
elif isinstance(inputs[0], tf.Tensor):
outputs = self._concatenate_batches([tensor.numpy() for tensor in inputs])
else:
raise TypeError(f"Couldn't handle batch of type {type(inputs[0])}!")
return outputs
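Putting the pieces together, a typical way to wire this callback into `model.fit()` looks like the sketch below; `rouge_fn` is the metric function from the class docstring above, and `tf_train_dataset` / `tf_eval_dataset` are placeholder `tf.data.Dataset` objects:
```
metric_callback = KerasMetricCallback(
    metric_fn=rouge_fn,
    eval_dataset=tf_eval_dataset,
    predict_with_generate=True,   # use model.generate() instead of a plain forward pass
)
model.fit(tf_train_dataset, epochs=3, callbacks=[metric_callback])
```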
class PushToHubCallback(keras.callbacks.Callback):
"""
Callback that will save and push the model to the Hub regularly. By default, it pushes once per epoch, but this can
be changed with the `save_strategy` argument. Pushed models can be accessed like any other model on the hub, such
as with the `from_pretrained` method.
```
from transformers.keras_callbacks import PushToHubCallback
push_to_hub_callback = PushToHubCallback(
output_dir="./model_save",
tokenizer=tokenizer,
hub_model_id="gpt5-7xlarge",
)
model.fit(train_dataset, callbacks=[push_to_hub_callback])
```
Args:
output_dir (`str`):
The output directory where the model predictions and checkpoints will be written and synced with the
repository on the Hub.
save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"epoch"`):
The checkpoint save strategy to adopt during training. Possible values are:
- `"no"`: Save is done at the end of training.
- `"epoch"`: Save is done at the end of each epoch.
- `"steps"`: Save is done every `save_steps`
save_steps (`int`, *optional*):
The number of steps between saves when using the "steps" `save_strategy`.
tokenizer (`PreTrainedTokenizerBase`, *optional*):
The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
hub_model_id (`str`, *optional*):
The name of the repository to keep in sync with the local `output_dir`. It can be a simple model ID in
which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
for instance `"user_name/model"`, which allows you to push to an organization you are a member of with
`"organization_name/model"`.
Will default to the name of `output_dir`.
hub_token (`str`, *optional*):
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
`huggingface-cli login`.
checkpoint (`bool`, *optional*, defaults to `False`):
Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be
resumed. Only usable when `save_strategy` is `"epoch"`.
"""
def __init__(
self,
output_dir: Union[str, Path],
save_strategy: Union[str, IntervalStrategy] = "epoch",
save_steps: Optional[int] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
hub_model_id: Optional[str] = None,
hub_token: Optional[str] = None,
checkpoint: bool = False,
**model_card_args,
):
super().__init__()
if checkpoint and save_strategy != "epoch":
raise ValueError("Cannot save checkpoints when save_strategy is not 'epoch'!")
if isinstance(save_strategy, str):
save_strategy = IntervalStrategy(save_strategy.lower())
self.save_strategy = save_strategy
if self.save_strategy == IntervalStrategy.STEPS and (not isinstance(save_steps, int) or save_steps <= 0):
raise ValueError("Please supply a positive integer argument for save_steps when save_strategy == 'steps'!")
self.save_steps = save_steps
output_dir = Path(output_dir)
if hub_model_id is None:
hub_model_id = output_dir.absolute().name
self.hub_model_id = create_repo(repo_id=hub_model_id, exist_ok=True, token=hub_token).repo_id
self.output_dir = output_dir
self.repo = Repository(str(self.output_dir), clone_from=self.hub_model_id, token=hub_token)
self.tokenizer = tokenizer
self.last_job = None
self.checkpoint = checkpoint
self.training_history = None
self.model_card_args = model_card_args
def on_train_begin(self, logs=None):
self.training_history = []
def on_train_batch_end(self, batch, logs=None):
if self.save_strategy == IntervalStrategy.STEPS and (batch + 1) % self.save_steps == 0:
if self.last_job is not None and not self.last_job.is_done:
return
self.model.save_pretrained(self.output_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(self.output_dir)
_, self.last_job = self.repo.push_to_hub(
commit_message=f"Training in progress steps {batch}", blocking=False
)
def on_epoch_end(self, epoch, logs=None):
logs = logs.copy()
if "epoch" not in logs:
logs["epoch"] = epoch
self.training_history.append(logs)
if self.save_strategy == IntervalStrategy.EPOCH:
if self.last_job is not None and not self.last_job.is_done:
return
self.model.save_pretrained(self.output_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(self.output_dir)
if self.checkpoint:
checkpoint_dir = os.path.join(self.output_dir, "checkpoint")
self.model._save_checkpoint(checkpoint_dir, epoch)
train_summary = TrainingSummary.from_keras(
model=self.model,
model_name=self.hub_model_id,
keras_history=self.training_history,
**self.model_card_args,
)
model_card = train_summary.to_model_card()
with (self.output_dir / "README.md").open("w") as f:
f.write(model_card)
_, self.last_job = self.repo.push_to_hub(
commit_message=f"Training in progress epoch {epoch}", blocking=False
)
def on_train_end(self, logs=None):
if self.last_job is not None and not self.last_job.is_done:
logging.info("Pushing the last epoch to the Hub, this may take a while...")
while not self.last_job.is_done:
sleep(1)
else:
self.model.save_pretrained(self.output_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(self.output_dir)
train_summary = TrainingSummary.from_keras(
model=self.model,
model_name=self.hub_model_id,
keras_history=self.training_history,
**self.model_card_args,
)
model_card = train_summary.to_model_card()
with (self.output_dir / "README.md").open("w") as f:
f.write(model_card)
self.repo.push_to_hub(commit_message="End of training", blocking=True)