Transformers Source Code Walkthrough (135)
.\tools\agent_types.py
import os
import pathlib
import tempfile
import uuid
import numpy as np
from ..utils import (
is_soundfile_availble,
is_torch_available,
is_vision_available,
logging
)
logger = logging.get_logger(__name__)
if is_vision_available():
import PIL.Image
from PIL import Image
from PIL.Image import Image as ImageType
else:
ImageType = object
if is_torch_available():
import torch
if is_soundfile_availble():
import soundfile as sf
class AgentType:
"""
Abstract class defining the types of objects that can be returned by agents.
These objects serve three purposes:
- they behave like the type they represent, e.g. a string for text, a PIL.Image for images
- they can be stringified: str(object) returns the string defined by the object
- they should display correctly in ipython notebooks/colab/jupyter
"""
def __init__(self, value):
self._value = value
def __str__(self):
return self.to_string()
def to_raw(self):
logger.error(
"This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
)
return self._value
def to_string(self) -> str:
logger.error(
"This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
)
return str(self._value)
class AgentText(AgentType, str):
"""
Text type returned by the agent. Behaves like a string.
"""
def to_raw(self):
return self._value
def to_string(self):
return self._value
class AgentImage(AgentType, ImageType):
"""
Image type returned by the agent. Behaves like a PIL.Image.
"""
def __init__(self, value):
super().__init__(value)
if not is_vision_available():
raise ImportError("PIL must be installed in order to handle images.")
self._path = None
self._raw = None
self._tensor = None
if isinstance(value, ImageType):
self._raw = value
elif isinstance(value, (str, pathlib.Path)):
self._path = value
elif isinstance(value, torch.Tensor):
self._tensor = value
else:
raise ValueError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")
def _ipython_display_(self, include=None, exclude=None):
"""
Displays correctly this type in an ipython notebook (ipython, colab, jupyter, ...)
"""
from IPython.display import Image, display
display(Image(self.to_string()))
def to_raw(self):
"""
Returns the "raw" version of that object. In the case of an AgentImage, it is a PIL.Image.
"""
if self._raw is not None:
return self._raw
if self._path is not None:
self._raw = Image.open(self._path)
return self._raw
def to_string(self):
"""
Returns the stringified version of that object. In the case of an AgentImage, it is a path to the serialized
version of the image.
"""
if self._path is not None:
return self._path
if self._raw is not None:
directory = tempfile.mkdtemp()
self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
self._raw.save(self._path)
return self._path
if self._tensor is not None:
array = self._tensor.cpu().detach().numpy()
img = Image.fromarray((array * 255).astype(np.uint8))
directory = tempfile.mkdtemp()
self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
img.save(self._path)
return self._path
class AgentAudio(AgentType):
"""
Audio type returned by the agent.
"""
def __init__(self, value, samplerate=16_000):
super().__init__(value)
if not is_soundfile_availble():
raise ImportError("soundfile must be installed in order to handle audio.")
self._path = None
self._tensor = None
self.samplerate = samplerate
if isinstance(value, (str, pathlib.Path)):
self._path = value
elif isinstance(value, torch.Tensor):
self._tensor = value
else:
raise ValueError(f"Unsupported audio type: {type(value)}")
def _ipython_display_(self, include=None, exclude=None):
"""
Displays correctly this type in an ipython notebook (ipython, colab, jupyter, ...)
"""
from IPython.display import Audio, display
display(Audio(self.to_string(), rate=self.samplerate))
def to_raw(self):
"""
Returns the "raw" version of that object. It is a `torch.Tensor` object.
"""
if self._tensor is not None:
return self._tensor
if self._path is not None:
tensor, self.samplerate = sf.read(self._path)
self._tensor = torch.tensor(tensor)
return self._tensor
def to_string(self):
"""
Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized
version of the audio.
"""
if self._path is not None:
return self._path
if self._tensor is not None:
directory = tempfile.mkdtemp()
self._path = os.path.join(directory, str(uuid.uuid4()) + ".wav")
sf.write(self._path, self._tensor, samplerate=self.samplerate)
return self._path
AGENT_TYPE_MAPPING = {"text": AgentText, "image": AgentImage, "audio": AgentAudio}
INSTANCE_TYPE_MAPPING = {str: AgentText}
if is_vision_available():
INSTANCE_TYPE_MAPPING[PIL.Image] = AgentImage
def handle_agent_inputs(*args, **kwargs):
"""
Handles input arguments by converting AgentType objects to their raw form (if applicable).
"""
args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
kwargs = {k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()}
return args, kwargs
def handle_agent_outputs(outputs, output_types=None):
"""
Placeholder function to handle agent outputs.
"""
if isinstance(outputs, dict):
decoded_outputs = {}
for i, (k, v) in enumerate(outputs.items()):
if output_types is not None:
if output_types[i] in AGENT_TYPE_MAPPING:
decoded_outputs[k] = AGENT_TYPE_MAPPING[output_types[i]](v)
else:
decoded_outputs[k] = AgentType(v)
else:
for _k, _v in INSTANCE_TYPE_MAPPING.items():
if isinstance(v, _k):
decoded_outputs[k] = _v(v)
if k not in decoded_outputs:
decoded_outputs[k] = AgentType(v)
elif isinstance(outputs, (list, tuple)):
decoded_outputs = type(outputs)()
for i, v in enumerate(outputs):
if output_types is not None:
if output_types[i] in AGENT_TYPE_MAPPING:
decoded_outputs.append(AGENT_TYPE_MAPPING[output_types[i]](v))
else:
decoded_outputs.append(AgentType(v))
else:
found = False
for _k, _v in INSTANCE_TYPE_MAPPING.items():
if isinstance(v, _k):
decoded_outputs.append(_v(v))
found = True
if not found:
decoded_outputs.append(AgentType(v))
else:
if output_types[0] in AGENT_TYPE_MAPPING:
decoded_outputs = AGENT_TYPE_MAPPING[output_types[0]](outputs)
else:
for _k, _v in INSTANCE_TYPE_MAPPING.items():
if isinstance(outputs, _k):
return _v(outputs)
return AgentType(outputs)
return decoded_outputs
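A quick sketch of how these helpers behave (illustrative only; it assumes the module is importable as `transformers.tools.agent_types`):

```python
from transformers.tools.agent_types import AgentText, handle_agent_inputs, handle_agent_outputs

# Wrap a raw string output using the "text" output type
wrapped = handle_agent_outputs("a beaver swimming", output_types=["text"])
print(type(wrapped).__name__)  # AgentText
print(wrapped.upper())         # still behaves like a str

# Unwrap agent-typed arguments back to raw values before calling a tool
args, _ = handle_agent_inputs(wrapped)
print(type(args[0]))           # <class 'str'>
```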
.\tools\base.py
import base64
import importlib
import inspect
import io
import json
import os
import tempfile
from typing import Any, Dict, List, Optional, Union
from huggingface_hub import create_repo, hf_hub_download, metadata_update, upload_folder
from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
from ..dynamic_module_utils import custom_object_save, get_class_from_dynamic_module, get_imports
from ..image_utils import is_pil_image
from ..models.auto import AutoProcessor
from ..utils import (
CONFIG_NAME,
cached_file,
is_accelerate_available,
is_torch_available,
is_vision_available,
logging,
)
from .agent_types import handle_agent_inputs, handle_agent_outputs
logger = logging.get_logger(__name__)
if is_torch_available():
import torch
if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import send_to_device
TOOL_CONFIG_FILE = "tool_config.json"
def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
if repo_type is not None:
return repo_type
try:
hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space", **hub_kwargs)
return "space"
except RepositoryNotFoundError:
try:
hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
return "model"
except RepositoryNotFoundError:
raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.")
except Exception:
return "model"
except Exception:
return "space"
APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
from {module_name} import {class_name}
launch_gradio_demo({class_name})
"""
class Tool:
"""
代理函数使用的基类,实现 `__call__` 方法以及以下类属性:
- **description** (`str`) -- 工具功能的简要描述,包括预期的输入和输出。例如,'这是一个从 `url` 下载文件的工具。它接受 `url` 作为输入,并返回文件中的文本内容'。
"""
class Tool:
description: str = "This is a tool that ..."
name: str = ""
inputs: List[str]
outputs: List[str]
def __init__(self, *args, **kwargs):
self.is_initialized = False
def __call__(self, *args, **kwargs):
return NotImplemented("Write this method in your subclass of `Tool`.")
def setup(self):
self.is_initialized = True
def save(self, output_dir):
"""
Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
tool in `output_dir` as well as autogenerate:
- a config file named `tool_config.json`
- an `app.py` file so that your tool can be converted to a space
- a `requirements.txt` containing the names of the module used by your tool (as detected when inspecting its
code)
You should only use this method to save tools that are defined in a separate module (not `__main__`).
Args:
output_dir (`str`): The folder in which you want to save your tool.
"""
os.makedirs(output_dir, exist_ok=True)
if self.__module__ == "__main__":
raise ValueError(
f"We can't save the code defining {self} in {output_dir} as it's been defined in __main__. You "
"have to put this code in a separate module so we can include it in the saved folder."
)
module_files = custom_object_save(self, output_dir)
module_name = self.__class__.__module__
last_module = module_name.split(".")[-1]
full_name = f"{last_module}.{self.__class__.__name__}"
config_file = os.path.join(output_dir, "tool_config.json")
if os.path.isfile(config_file):
with open(config_file, "r", encoding="utf-8") as f:
tool_config = json.load(f)
else:
tool_config = {}
tool_config = {"tool_class": full_name, "description": self.description, "name": self.name}
with open(config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tool_config, indent=2, sort_keys=True) + "\n")
app_file = os.path.join(output_dir, "app.py")
with open(app_file, "w", encoding="utf-8") as f:
f.write(APP_FILE_TEMPLATE.format(module_name=last_module, class_name=self.__class__.__name__))
requirements_file = os.path.join(output_dir, "requirements.txt")
imports = []
for module in module_files:
imports.extend(get_imports(module))
imports = list(set(imports))
with open(requirements_file, "w", encoding="utf-8") as f:
f.write("\n".join(imports) + "\n")
def push_to_hub(
self,
repo_id: str,
commit_message: str = "Upload tool",
private: Optional[bool] = None,
token: Optional[Union[bool, str]] = None,
create_pr: bool = False,
) -> str:
"""
Upload the tool to the Hub.
Parameters:
repo_id (`str`):
The name of the repository you want to push your tool to. It should contain your organization name when
pushing to a given organization.
commit_message (`str`, *optional*, defaults to `"Upload tool"`):
Message to commit while pushing.
private (`bool`, *optional*):
Whether or not the repository created should be private.
token (`bool` or `str`, *optional*):
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
create_pr (`bool`, *optional*, defaults to `False`):
Whether or not to create a PR with the uploaded files or directly commit.
"""
repo_url = create_repo(
repo_id=repo_id, token=token, private=private, exist_ok=True, repo_type="space", space_sdk="gradio"
)
repo_id = repo_url.repo_id
metadata_update(repo_id, {"tags": ["tool"]}, repo_type="space")
with tempfile.TemporaryDirectory() as work_dir:
self.save(work_dir)
logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
return upload_folder(
repo_id=repo_id,
commit_message=commit_message,
folder_path=work_dir,
token=token,
create_pr=create_pr,
repo_type="space",
)
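As a usage sketch (the tool class and repo id below are hypothetical, not part of the library), a minimal subclass can be saved and pushed like this:

```python
from transformers.tools.base import Tool

class CatFactTool(Tool):
    name = "cat_fact"
    description = "This is a tool that returns a canned fact about cats. It takes no input and returns a text."
    inputs = []
    outputs = ["text"]

    def __call__(self):
        return "Cats sleep for most of the day."

# The class must be defined in an importable module (not __main__) for save()/push_to_hub() to work
tool = CatFactTool()
tool.save("my_cat_tool")
# tool.push_to_hub("my-user/cat-fact-tool", commit_message="Upload tool")  # hypothetical repo id
```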
@staticmethod
def from_gradio(gradio_tool):
"""
Creates a [`Tool`] from a gradio tool.
"""
class GradioToolWrapper(Tool):
def __init__(self, _gradio_tool):
super().__init__()
self.name = _gradio_tool.name
self.description = _gradio_tool.description
GradioToolWrapper.__call__ = gradio_tool.run
return GradioToolWrapper(gradio_tool)
class RemoteTool(Tool):
"""
A [`Tool`] that will make requests to an inference endpoint.
Args:
endpoint_url (`str`, *optional*):
The url of the endpoint to use.
token (`str`, *optional*):
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
running `huggingface-cli login` (stored in `~/.huggingface`).
tool_class (`type`, *optional*):
The corresponding `tool_class` if this is a remote version of an existing tool. Will help determine when
the output should be converted to another type (like images).
"""
def __init__(self, endpoint_url=None, token=None, tool_class=None):
self.endpoint_url = endpoint_url
self.client = EndpointClient(endpoint_url, token=token)
self.tool_class = tool_class
def prepare_inputs(self, *args, **kwargs):
"""
Prepare the inputs received for the HTTP client sending data to the endpoint. Positional arguments will be
matched with the signature of the `tool_class` if it was provided at instantiation. Images will be encoded into
bytes.
You can override this method in your custom class of [`RemoteTool`].
"""
inputs = kwargs.copy()
if len(args) > 0:
if self.tool_class is not None:
if issubclass(self.tool_class, PipelineTool):
call_method = self.tool_class.encode
else:
call_method = self.tool_class.__call__
signature = inspect.signature(call_method).parameters
parameters = [
k
for k, p in signature.items()
if p.kind not in [inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD]
]
if parameters[0] == "self":
parameters = parameters[1:]
if len(args) > len(parameters):
raise ValueError(
f"{self.tool_class} only accepts {len(parameters)} arguments but {len(args)} were given."
)
for arg, name in zip(args, parameters):
inputs[name] = arg
elif len(args) > 1:
raise ValueError("A `RemoteTool` can only accept one positional input.")
elif len(args) == 1:
if is_pil_image(args[0]):
return {"inputs": self.client.encode_image(args[0])}
return {"inputs": args[0]}
for key, value in inputs.items():
if is_pil_image(value):
inputs[key] = self.client.encode_image(value)
return {"inputs": inputs}
def extract_outputs(self, outputs):
"""
You can override this method in your custom class of [`RemoteTool`] to apply some custom post-processing of the
outputs of the endpoint.
"""
return outputs
def __call__(self, *args, **kwargs):
args, kwargs = handle_agent_inputs(*args, **kwargs)
output_image = self.tool_class is not None and self.tool_class.outputs == ["image"]
inputs = self.prepare_inputs(*args, **kwargs)
if isinstance(inputs, dict):
outputs = self.client(**inputs, output_image=output_image)
else:
outputs = self.client(inputs, output_image=output_image)
if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list):
outputs = outputs[0]
outputs = handle_agent_outputs(outputs, self.tool_class.outputs if self.tool_class is not None else None)
return self.extract_outputs(outputs)
class PipelineTool(Tool):
"""
A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will
need to specify:
- **model_class** (`type`) -- The class to use to load the model in this tool.
- **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one.
- **pre_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
pre-processor
- **post_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
post-processor (when different from the pre-processor).
Args:
model (`str` or [`PreTrainedModel`], *optional*):
The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the
value of the class attribute `default_checkpoint`.
pre_processor (`str` or `Any`, *optional*):
The name of the checkpoint to use for the pre-processor, or the instantiated pre-processor (can be a
tokenizer, an image processor, a feature extractor or a processor). Will default to the value of `model` if
unset.
post_processor (`str` or `Any`, *optional*):
The name of the checkpoint to use for the post-processor, or the instantiated pre-processor (can be a
tokenizer, an image processor, a feature extractor or a processor). Will default to the `pre_processor` if
unset.
device (`int`, `str` or `torch.device`, *optional*):
The device on which to execute the model. Will default to any accelerator available (GPU, MPS etc...), the
CPU otherwise.
device_map (`str` or `dict`, *optional*):
If passed along, will be used to instantiate the model.
model_kwargs (`dict`, *optional*):
Any keyword argument to send to the model instantiation.
token (`str`, *optional*):
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
running `huggingface-cli login` (stored in `~/.huggingface`).
hub_kwargs (additional keyword arguments, *optional*):
Any additional keyword argument to send to the methods that will load the data from the Hub.
"""
pre_processor_class = AutoProcessor
model_class = None
post_processor_class = AutoProcessor
default_checkpoint = None
def __init__(
self,
model=None,
pre_processor=None,
post_processor=None,
device=None,
device_map=None,
model_kwargs=None,
token=None,
**hub_kwargs,
):
if not is_torch_available():
raise ImportError("Please install torch in order to use this tool.")
if not is_accelerate_available():
raise ImportError("Please install accelerate in order to use this tool.")
if model is None:
if self.default_checkpoint is None:
raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
model = self.default_checkpoint
if pre_processor is None:
pre_processor = model
self.model = model
self.pre_processor = pre_processor
self.post_processor = post_processor
self.device = device
self.device_map = device_map
self.model_kwargs = {} if model_kwargs is None else model_kwargs
if device_map is not None:
self.model_kwargs["device_map"] = device_map
self.hub_kwargs = hub_kwargs
self.hub_kwargs["token"] = token
super().__init__()
def setup(self):
"""
Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
"""
if isinstance(self.pre_processor, str):
self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)
if isinstance(self.model, str):
self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)
if self.post_processor is None:
self.post_processor = self.pre_processor
elif isinstance(self.post_processor, str):
self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)
if self.device is None:
if self.device_map is not None:
self.device = list(self.model.hf_device_map.values())[0]
else:
self.device = PartialState().default_device
if self.device_map is None:
self.model.to(self.device)
super().setup()
def encode(self, raw_inputs):
"""
Uses the `pre_processor` to prepare the inputs for the `model`.
"""
return self.pre_processor(raw_inputs)
def forward(self, inputs):
"""
Sends the inputs through the `model`.
"""
with torch.no_grad():
return self.model(**inputs)
def decode(self, outputs):
"""
Uses the `post_processor` to decode the model output.
"""
return self.post_processor(outputs)
def __call__(self, *args, **kwargs):
args, kwargs = handle_agent_inputs(*args, **kwargs)
if not self.is_initialized:
self.setup()
encoded_inputs = self.encode(*args, **kwargs)
encoded_inputs = send_to_device(encoded_inputs, self.device)
outputs = self.forward(encoded_inputs)
outputs = send_to_device(outputs, "cpu")
decoded_outputs = self.decode(outputs)
return handle_agent_outputs(decoded_outputs, self.outputs)
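To see the encode/forward/decode flow end to end, here is a minimal hypothetical `PipelineTool` subclass (the checkpoint and prompt format are assumptions; any seq2seq model would do):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.tools.base import PipelineTool

class EnglishToFrenchTool(PipelineTool):
    default_checkpoint = "t5-small"   # assumed checkpoint
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSeq2SeqLM
    description = "This is a tool that translates English text to French."
    name = "en_to_fr"
    inputs = ["text"]
    outputs = ["text"]

    def encode(self, text):
        return self.pre_processor("translate English to French: " + text, return_tensors="pt")

    def forward(self, inputs):
        return self.model.generate(**inputs)

    def decode(self, outputs):
        return self.pre_processor.decode(outputs[0], skip_special_tokens=True)

# tool = EnglishToFrenchTool()               # setup() runs lazily on the first call
# print(tool("The beaver builds a dam."))    # returns an AgentText
```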
def launch_gradio_demo(tool_class: Tool):
try:
import gradio as gr
except ImportError:
raise ImportError("Gradio 应该安装才能启动 gradio 演示。")
tool = tool_class()
def fn(*args, **kwargs):
return tool(*args, **kwargs)
gr.Interface(
fn=fn,
inputs=tool_class.inputs,
outputs=tool_class.outputs,
title=tool_class.__name__,
article=tool.description,
).launch()
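For example, a tool class shipped with the library can be turned into a demo in one line (the `transformers.tools` import path is assumed):

```python
from transformers import launch_gradio_demo
from transformers.tools import ImageCaptioningTool

# Builds a gr.Interface whose components mirror the tool's `inputs`/`outputs` class attributes
launch_gradio_demo(ImageCaptioningTool)
```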
TASK_MAPPING = {
"document-question-answering": "DocumentQuestionAnsweringTool",
"image-captioning": "ImageCaptioningTool",
"image-question-answering": "ImageQuestionAnsweringTool",
"image-segmentation": "ImageSegmentationTool",
"speech-to-text": "SpeechToTextTool",
"summarization": "TextSummarizationTool",
"text-classification": "TextClassificationTool",
"text-question-answering": "TextQuestionAnsweringTool",
"text-to-speech": "TextToSpeechTool",
"translation": "TranslationTool",
}
def get_default_endpoints():
endpoints_file = cached_file("huggingface-tools/default-endpoints", "default_endpoints.json", repo_type="dataset")
with open(endpoints_file, "r", encoding="utf-8") as f:
endpoints = json.load(f)
return endpoints
def supports_remote(task_or_repo_id):
endpoints = get_default_endpoints()
return task_or_repo_id in endpoints
def load_tool(task_or_repo_id, model_repo_id=None, remote=False, token=None, **kwargs):
"""
Main function to quickly load a tool, whether it lives on the Hub or in the Transformers library.
<Tip warning={true}>
Loading a tool means downloading it and executing it locally.
Always inspect the tool you are about to download before loading it into your runtime, just as you would when
installing a package with pip/npm/apt.
</Tip>
"""
if task_or_repo_id in TASK_MAPPING:
tool_class_name = TASK_MAPPING[task_or_repo_id]
main_module = importlib.import_module("transformers")
tools_module = main_module.tools
tool_class = getattr(tools_module, tool_class_name)
if remote:
if model_repo_id is None:
endpoints = get_default_endpoints()
if task_or_repo_id not in endpoints:
raise ValueError(
f"Could not infer a default endpoint for {task_or_repo_id}, you need to pass one using the "
"`model_repo_id` argument."
)
model_repo_id = endpoints[task_or_repo_id]
return RemoteTool(model_repo_id, token=token, tool_class=tool_class)
else:
return tool_class(model_repo_id, token=token, **kwargs)
else:
logger.warning_once(
f"You're loading a tool from the Hub from {model_repo_id}. Please make sure this is a source that you "
f"trust as the code within that tool will be executed on your machine. Always verify the code of "
f"the tools that you load. We recommend specifying a `revision` to ensure you're loading the "
f"code that you have checked."
)
return Tool.from_hub(task_or_repo_id, model_repo_id=model_repo_id, token=token, remote=remote, **kwargs)
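Typical usage (a sketch):

```python
from transformers import load_tool

# Load a library tool by task name (see TASK_MAPPING above)
captioner = load_tool("image-captioning")

# Or use the remote version backed by a default inference endpoint, when one exists
# captioner = load_tool("image-captioning", remote=True)
```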
def add_description(description):
"""
A decorator that adds a description to a function.
"""
def inner(func):
func.description = description
func.name = func.__name__
return func
return inner
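For instance (hypothetical function used only for illustration):

```python
@add_description("This is a tool that reverses a string. It takes a `text` input and returns the reversed text.")
def reverse_text(text):
    return text[::-1]

print(reverse_text.description)
print(reverse_text.name)  # "reverse_text"
```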
class EndpointClient:
def __init__(self, endpoint_url: str, token: Optional[str] = None):
self.headers = {**build_hf_headers(token=token), "Content-Type": "application/json"}
self.endpoint_url = endpoint_url
@staticmethod
def encode_image(image):
_bytes = io.BytesIO()
image.save(_bytes, format="PNG")
b64 = base64.b64encode(_bytes.getvalue())
return b64.decode("utf-8")
@staticmethod
def decode_image(raw_image):
if not is_vision_available():
raise ImportError(
"This tool returned an image but Pillow is not installed. Please install it (`pip install Pillow`)."
)
from PIL import Image
b64 = base64.b64decode(raw_image)
_bytes = io.BytesIO(b64)
return Image.open(_bytes)
def __call__(
self,
inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
params: Optional[Dict] = None,
data: Optional[bytes] = None,
output_image: bool = False,
) -> Any:
payload = {}
if inputs:
payload["inputs"] = inputs
if params:
payload["parameters"] = params
response = get_session().post(self.endpoint_url, headers=self.headers, json=payload, data=data)
if output_image:
return self.decode_image(response.content)
else:
return response.json()
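The two static helpers round-trip images through base64, which can be checked in isolation (a small sketch; the import path is assumed):

```python
from PIL import Image
from transformers.tools.base import EndpointClient

img = Image.new("RGB", (8, 8), color="red")      # tiny in-memory test image
b64_png = EndpointClient.encode_image(img)        # base64-encoded PNG string placed in the JSON payload
restored = EndpointClient.decode_image(b64_png)   # back to a PIL.Image
print(restored.size)                              # (8, 8)
```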
.\tools\document_question_answering.py
import re
from ..models.auto import AutoProcessor
from ..models.vision_encoder_decoder import VisionEncoderDecoderModel
from ..utils import is_vision_available
from .base import PipelineTool
if is_vision_available():
from PIL import Image
class DocumentQuestionAnsweringTool(PipelineTool):
default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa"
description = (
"This is a tool that answers a question about an document (pdf). It takes an input named `document` which "
"should be the document containing the information, as well as a `question` that is the question about the "
"document. It returns a text that contains the answer to the question."
)
name = "document_qa"
pre_processor_class = AutoProcessor
model_class = VisionEncoderDecoderModel
inputs = ["image", "text"]
outputs = ["text"]
def __init__(self, *args, **kwargs):
if not is_vision_available():
raise ValueError("Pillow must be installed to use the DocumentQuestionAnsweringTool.")
super().__init__(*args, **kwargs)
def encode(self, document: "Image", question: str):
task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
prompt = task_prompt.replace("{user_input}", question)
decoder_input_ids = self.pre_processor.tokenizer(
prompt, add_special_tokens=False, return_tensors="pt"
).input_ids
pixel_values = self.pre_processor(document, return_tensors="pt").pixel_values
return {"decoder_input_ids": decoder_input_ids, "pixel_values": pixel_values}
def forward(self, inputs):
return self.model.generate(
inputs["pixel_values"].to(self.device),
decoder_input_ids=inputs["decoder_input_ids"].to(self.device),
max_length=self.model.decoder.config.max_position_embeddings,
early_stopping=True,
pad_token_id=self.pre_processor.tokenizer.pad_token_id,
eos_token_id=self.pre_processor.tokenizer.eos_token_id,
use_cache=True,
num_beams=1,
bad_words_ids=[[self.pre_processor.tokenizer.unk_token_id]],
return_dict_in_generate=True,
).sequences
def decode(self, outputs):
sequence = self.pre_processor.batch_decode(outputs)[0]
sequence = sequence.replace(self.pre_processor.tokenizer.eos_token, "")
sequence = sequence.replace(self.pre_processor.tokenizer.pad_token, "")
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
sequence = self.pre_processor.token2json(sequence)
return sequence["answer"]
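A usage sketch (the file name is hypothetical):

```python
from PIL import Image
from transformers import load_tool

document_qa = load_tool("document-question-answering")
invoice = Image.open("invoice.png")  # hypothetical scanned document
print(document_qa(invoice, question="What is the total amount?"))
```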
.\tools\evaluate_agent.py
# Imports needed by the evaluation helpers below (module paths assumed from the library layout)
from .agents import BASE_PYTHON_TOOLS, clean_code_for_chat
from .python_interpreter import InterpretorError, evaluate
def classifier(text, labels):
return f"This is the classification of {text} along {labels}."
def translator(text, src_lang, tgt_lang):
return f"This is the translation of {text} from {src_lang} to {tgt_lang}."
def speaker(text):
return f"This is actually a sound reading {text}."
def transcriber(audio):
if "sound" not in audio:
raise ValueError(f"`audio` ({audio}) is not a sound.")
return f"This is the transcribed text from {audio}."
def image_generator(prompt):
return f"This is actually an image representing {prompt}."
def image_captioner(image):
if "image" not in image:
raise ValueError(f"`image` ({image}) is not an image.")
return f"This is a description of {image}."
def image_transformer(image, prompt):
if "image" not in image:
raise ValueError(f"`image` ({image}) is not an image.")
return f"This is a transformation of {image} according to {prompt}."
def question_answerer(text, question):
return f"This is the answer to {question} from {text}."
def image_qa(image, question):
if "image" not in image:
raise ValueError(f"`image` ({image}) is not an image.")
return f"This is the answer to {question} from {image}."
def text_downloader(url):
return f"This is the content of {url}."
def summarizer(text):
return f"This is a summary of {text}."
def video_generator(prompt, seconds=2):
return f"A video of {prompt}"
def document_qa(image, question):
return f"This is the answer to {question} from the document {image}."
def image_segmenter(image, prompt):
return f"This is the mask of {prompt} in {image}"
TEST_TOOLS = {
"text_classifier": classifier,
"translator": translator,
"text_reader": speaker,
"summarizer": summarizer,
"transcriber": transcriber,
"image_generator": image_generator,
"image_captioner": image_captioner,
"image_transformer": image_transformer,
"text_qa": question_answerer,
"text_downloader": text_downloader,
"image_qa": image_qa,
"video_generator": video_generator,
"document_qa": document_qa,
"image_segmenter": image_segmenter,
}
class Problem:
"""
占位符类,暂时没有定义任何内容
"""
Args:
task (`str` 或 `list[str]`):
要执行任务的一个或多个描述。如果是列表,则应包含相同任务的不同表达方式。
inputs (`list[str]` 或 `dict[str, str]`):
将提供给工具的输入。在这个测试环境中,只接受字符串作为值。当你想要指定每个输入的值时,请传递一个字典;或者直接传递期望的输入列表(在这种情况下,使用 `<<input_name>>` 作为值)。
answer (`str` 或 `list[str]`):
问题的理论答案(或可能的有效答案列表),作为代码。
"""
# 初始化方法,用于设置实例的属性
def __init__(self, task, inputs, answer):
self.task = task # 将传入的任务描述存储在实例的属性中
self.inputs = inputs # 将传入的输入数据存储在实例的属性中
self.answer = answer # 将传入的答案数据存储在实例的属性中
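For example, an extra (purely illustrative) problem in the same spirit as the entries below could look like this:

```python
example = Problem(
    task=["Caption the `image` and read the caption out loud."],
    inputs=["image"],
    answer="text_reader(image_captioner(image))",
)
# With the list form of `inputs`, the evaluation harness substitutes the placeholder "<<image>>" for `image`.
```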
### Evaluation tasks: a list of Problem instances exercising the test tools
EVALUATION_TASKS = [
# Problem: decide whether a Spanish `text` is positive or negative
Problem(
task=[
"Is the following `text` (in Spanish) positive or negative?",
"Is the text in the variable `text` (in Spanish) positive or negative?",
"Translate the following `text` from Spanish to English then tell me if its positive or negative.",
],
inputs=["text"],
# The answer chains translation and text classification
answer="""text_classifier(translator(text, src_lang="Spanish", tgt_lang="English"), labels=["positive", "negative"])""",
),
# Problem: describe out loud what the `image` contains
Problem(
task=[
"Tell me out loud what the `image` contains.",
"Describe the following `image` out loud.",
"Find what is in the picture stored in `image` then read it out loud.",
],
inputs=["image"],
# The answer is a list of two acceptable ways to describe the image
answer=[
"text_reader(image_captioner(image))",
"text_reader(image_qa(image, question='What is in the image?'))",
],
),
# Problem: generate an image from `text_input`, then transform it according to `prompt`
Problem(
task=[
"Generate an image from the text given in `text_input`. Then transform it according to the text in `prompt`.",
"Use the following `text_input` to generate an image, then transform it by using the text in `prompt`.",
],
inputs=["text_input", "prompt"],
# The answer chains image generation and image transformation
answer="image_transformer(image_generator(text_input), prompt)",
),
# Problem: download the content of `url`, summarize it and generate an image from it
Problem(
task=[
"Download the content of `url`, summarize it then generate an image from its content.",
"Use a summary of the web page at `url` to generate an image.",
"Summarize the content of the web page at `url`, and use the result to generate an image.",
],
inputs=["url"],
# The answer chains download, summarization and image generation
answer="image_generator(summarizer(text_downloader(url)))",
),
# Problem: transform the `image` using the Spanish prompt stored in `text`
Problem(
task=[
"Transform the following `image` using the prompt in `text`. The prompt is in Spanish.",
"Use the text prompt in `text` (in Spanish) to transform the following `image`.",
"Translate the `text` from Spanish to English then use it to transform the picture in `image`.",
],
inputs=["text", "image"],
# The answer chains translation and image transformation
answer="image_transformer(image, translator(text, src_lang='Spanish', tgt_lang='English'))",
),
# Problem: download the content of `url`, summarize it and read the summary out loud
Problem(
task=[
"Download the content of `url`, summarize it then read it out loud to me.",
"Read me a summary of the web page at `url`.",
],
inputs=["url"],
# The answer chains download, summarization and text-to-speech
answer="text_reader(summarizer(text_downloader(url)))",
),
# Problem: generate an image from `text_input`
Problem(
task=[
"Generate an image from the text given in `text_input`.",
],
inputs=["text_input"],
# The answer is a single call to the image generator
answer="image_generator(text_input)",
),
Problem(
task=[
"Replace the beaver in the `image` by the `prompt`.",
"Transform the `image` so that it contains the `prompt`.",
"Use `prompt` to transform this `image`.",
],
inputs=["image", "prompt"],
answer="image_transformer(image, prompt)",
),
Problem(
task=[
"Provide me the summary of the `text`, then read it to me before transcribing it and translating it in French.",
"Summarize `text`, read it out loud then transcribe the audio and translate it in French.",
"Read me a summary of the `text` out loud. Transcribe this and translate it in French.",
],
inputs=["text"],
answer="translator(transcriber(text_reader(summarizer(text))), src_lang='English', tgt_lang='French')",
),
Problem(
task=["Generate a video of the `prompt`", "Animate a `prompt`", "Make me a short video using `prompt`."],
inputs={"prompt": "A lobster swimming"},
answer="video_generator('A lobster swimming')",
),
Problem(
task=[
"Download the following file `url`, summarize it in a few words and generate a video from it."
"Fetch the file at this `url`, summarize it, and create an animation out of it."
],
inputs=["url"],
answer="video_generator(summarizer(text_downloader(url)))",
),
]
EVALUATION_CHATS = [
[  # First chat: translate a Spanish text, then classify its sentiment
Problem(  # First problem of the chat
task=[  # Task descriptions
"Translate the following `text` from Spanish to English.",  # Translate the text from Spanish to English
"Translate the following `text` from Spanish to English.",  # Same task, repeated phrasing
],
inputs=["text"],  # Input: a text string
answer="translated_text=translator(text, src_lang='Spanish', tgt_lang='English')",  # Answer: call the translator
),
Problem(  # Second problem of the chat
task=[  # Task descriptions
"Is it positive or negative?",  # Decide whether the text is positive or negative
"Tell me if its positive or negative.",  # Same task, repeated phrasing
],
inputs=[],  # No inputs (uses the previous result)
answer="text_classifier(translated_text, labels=['positive', 'negative'])",  # Answer: classify the translated text
),
],
[  # Second chat: describe an image, then read the description out loud
Problem(  # First problem of the chat
task=[  # Task descriptions
"What does this `image` contain?",  # Describe what the image contains
"Describe the following `image`.",  # Describe the image
"Find what is in the picture stored in `image`",  # Find what is in the picture stored in `image`
],
inputs=["image"],  # Input: an image
answer=[  # Two acceptable answers
"description=image_captioner(image)",  # Caption the image
"description=image_qa(image, question='What is in the image?')",  # Ask the image QA tool what the image contains
],
),
Problem(  # Second problem of the chat
task=[  # Task descriptions
"Now, read the description out loud.",  # Read the description out loud
"Great! Can you read it out loud?",  # Same task, different phrasing
"Read it out loud.",  # Read it out loud
],
inputs=[],  # No inputs (uses the previous description)
answer=["audio=text_reader(description)", "audio=text_reader(description)"],  # Read the description with the text reader
),
],
[  # Third chat: generate an image from text, then transform it with a prompt
Problem(  # First problem of the chat
task=[  # Task descriptions
"Generate an image from the text given in `text_input`.",  # Generate an image from `text_input`
"Use the following `text_input` to generate an image",  # Use `text_input` to generate an image
],
inputs=["text_input"],  # Input: a text input
answer="image = image_generator(text_input)",  # Generate the image
),
Problem(  # Second problem of the chat
task=[  # Task descriptions
"Transform it according to the text in `prompt`.",  # Transform the image according to `prompt`
"Transform it by using the text in `prompt`.",  # Transform it using the text in `prompt`
],
inputs=["prompt"],  # Input: a prompt text
answer="image_transformer(image, prompt)",  # Transform the image
),
],
[  # Fourth chat: download and summarize a web page, then generate an image from the summary
Problem(  # First problem of the chat
task=[  # Task descriptions
"Download the content of `url` and summarize it.",  # Download the content of `url` and summarize it
"Summarize the content of the web page at `url`.",  # Summarize the content of the web page at `url`
],
inputs=["url"],  # Input: a URL
answer="summary = summarizer(text_downloader(url))",  # Download the page content then summarize it
),
Problem(  # Second problem of the chat
task=[  # Task descriptions
"Generate an image from its content.",  # Generate an image from its content
"Use the previous result to generate an image.",  # Use the previous result to generate an image
],
inputs=[],  # No inputs (uses the previous summary)
answer="image_generator(summary)",  # Generate an image from the summary
),
],
[
# Chat: translate a Spanish `text`, then use the translation to transform an `image`
Problem(
# Task: translate this Spanish `text` into English
task=[
"Translate this Spanish `text` in English.",
"Translate the `text` from Spanish to English.",
],
# Input: the text to translate
inputs=["text"],
# Answer: call the translator from Spanish to English and store the result
answer="translated_text = translator(text, src_lang='Spanish', tgt_lang='English')",
),
Problem(
# Task: transform the following `image` using the translated `text`
task=[
"Transform the following `image` using the translated `text`.",
"Use the previous result to transform the following `image`.",
],
# Input: the image to transform (the translated text comes from the previous step)
inputs=["image"],
# Answer: call image_transformer with the translated text
answer="image_transformer(image, translated_text)",
),
],
[
# Chat: download a web page, summarize it, then read the summary out loud
Problem(
# Task: download the content of `url`
task=["Download the content of `url`.", "Get me the text on the web page `url`."],
# Input: the url to fetch
inputs=["url"],
# Answer: download the page content
answer="text = text_downloader(url)",
),
Problem(
# Task: summarize the downloaded text
task=["Summarize this text.", "Summarize this text."],
# Input: none (uses the text downloaded in the previous step)
inputs=[],
# Answer: summarize the text
answer="summary = summarizer(text)",
),
Problem(
# Task: read the summary out loud
task=["Read it out loud to me.", "Read me the previous result."],
# Input: none (uses the summary from the previous step)
inputs=[],
# Answer: read the summary with the text reader
answer="text_reader(summary)",
),
],
[
# Chat: generate an image from `text_input`
Problem(
# Task: generate an image from the given `text_input`
task=["Generate an image from the text given in `text_input`."],
# Input: the text used to generate the image
inputs=["text_input"],
# Answer: call the image generator
answer="image_generator(text_input)",
),
],
[
# Chat: replace the beaver in `image` according to `prompt`
Problem(
# Task: replace the beaver in the `image` by the `prompt`
task=[
"Replace the beaver in the `image` by the `prompt`.",
"Transform the `image` so that it contains the `prompt`.",
"Use `prompt` to transform this `image`.",
],
# Inputs: the image to transform and the prompt to apply
inputs=["image", "prompt"],
# Answer: call image_transformer with the prompt
answer="image_transformer(image, prompt)",
),
],
[
# Chat: summarize `text`, read it out loud, transcribe the audio, then translate it into French
Problem(
# Task: provide the summary of `text`
task=["Provide me the summary of the `text`.", "Summarize `text`."],
# Input: the text to summarize
inputs=["text"],
# Answer: summarize the text
answer="summary = summarizer(text)",
),
Problem(
# Task: read the summary out loud
task=["Read this summary to me.", "Read it out loud."],
# Input: none (uses the summary from the previous step)
inputs=[],
# Answer: read the summary with the text reader
answer="audio = text_reader(summarizer(text))",
),
Problem(
# Task: transcribe the previous result back into text
task=["Transcribing the previous result back in text.", "Transcribe the audio."],
# Input: none (uses the audio from the previous step)
inputs=[],
# Answer: transcribe the audio
answer="text = transcriber(audio)",
),
Problem(
# Task: translate the last result into French
task=["Translating the last result in French.", "Translate this in French."],
# Input: none (uses the text from the previous step)
inputs=[],
# Answer: translate from English to French
answer="translator(text, src_lang='English', tgt_lang='French')",
),
],
[
# Chat: generate a video from `prompt`
Problem(
# Task: generate a video of the `prompt`
task=[
"Generate a video of the `prompt`",
"Animate a `prompt`",
"Make me a short video using `prompt`.",
],
# Input: the prompt to animate
inputs={"prompt": "A lobster swimming"},
# Answer: call the video generator
answer="video_generator('A lobster swimming')",
),
],
[
# Chat: download and summarize a web page, then turn the summary into a video
Problem(
# Task: download the content of `url` and summarize it
task=[
"Download the content of `url` and summarize it.",
"Summarize the content of the web page at `url`."
],
# Input: the url to fetch
inputs=["url"],
# Answer: download the page content then summarize it
answer="summary = summarizer(text_downloader(url))"
),
Problem(
# Task: generate a video from the previous summary
task=["generate a video from it.", "Create an animation from the last result."],
# Input: none (uses the summary from the previous step)
inputs=[],
# Answer: generate the video from the summary
answer="video_generator(summary)"
),
],
]
# Returns the set of test tools that the theoretical answer is expected to use, given the agent's answer
def get_theoretical_tools(agent_answer, theoretical_answer, code_answer):
# If the theoretical answer is not a list, return the test tools that appear in the reference code
if not isinstance(theoretical_answer, list):
return {name for name in TEST_TOOLS if name in code_answer}
# If the agent answer is a dict, compare each theoretical answer with the agent's values
if isinstance(agent_answer, dict):
for one_answer, one_code in zip(theoretical_answer, code_answer):
# If one of the theoretical answers appears in the agent's values, use the matching reference code
if one_answer in agent_answer.values():
return {name for name in TEST_TOOLS if name in one_code}
# Otherwise compare the agent answer directly with each theoretical answer
for one_answer, one_code in zip(theoretical_answer, code_answer):
# If the agent answer equals one of the theoretical answers, use the matching reference code
if agent_answer == one_answer:
return {name for name in TEST_TOOLS if name in one_code}
# Fall back to the tools used in the first reference code
return {name for name in TEST_TOOLS if name in code_answer[0]}
# Evaluates the given code with the test tools and returns the result (or an error message / None)
def evaluate_code(code, inputs=None, state=None, verbose=False, return_interpretor_error=False):
# Start from a copy of the base Python tools
tools = BASE_PYTHON_TOOLS.copy()
# Add every test tool that actually appears in the code
for name, tool in TEST_TOOLS.items():
if name not in code:
continue
tools[name] = tool
# If the inputs are a dict, copy them
if isinstance(inputs, dict):
inputs = inputs.copy()
# Otherwise map each input name to a placeholder of the form <<input_name>>
elif inputs is not None:
inputs = {inp: f"<<{inp}>>" for inp in inputs}
# If a state is provided, update it with the inputs; otherwise use the inputs as the state
if state is not None:
state.update(inputs)
else:
state = inputs
try:
# Evaluate the code with the selected tools and the current state
return evaluate(code, tools, state)
except InterpretorError as e:
# On interpreter errors, return the error message as a string
return str(e)
except Exception as e:
# On any other exception, optionally print it and return None
if verbose:
print(e)
return None
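Because the test tools above only return descriptive strings, `evaluate_code` can be exercised without any model (a sketch; the exact output depends on the fake tools above and on the `evaluate` helper from `python_interpreter.py`):

```python
result = evaluate_code(
    "text_reader(summarizer(text_downloader(url)))",
    inputs=["url"],
)
print(result)
# roughly: "This is actually a sound reading This is a summary of This is the content of <<url>>.."
```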
# Scores the agent's answer against the theoretical answer
def score_code(agent_answer, theoretical_answer, verbose: bool = False):
# If verbose, print both answers
if verbose:
print(agent_answer, theoretical_answer)
# Normalize the theoretical answer to a list
theoretical_answer = theoretical_answer if isinstance(theoretical_answer, list) else [theoretical_answer]
# Exact match: score 1.0
if agent_answer in theoretical_answer:
if verbose:
print("Perfect!")
return 1
# The answer is a dict and one of its values matches: score 0.75
elif isinstance(agent_answer, dict) and any(v in theoretical_answer for v in agent_answer.values()):
if verbose:
print("Almost perfect, result in state!")
return 0.75
# Otherwise the code executed but produced the wrong result: score 0.3
else:
if verbose:
print("Result is not the right one but code executed.")
return 0.3
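The three score levels in action (a sketch):

```python
reference = ["This is a summary of <<text>>."]
print(score_code("This is a summary of <<text>>.", reference))               # 1    (exact match)
print(score_code({"summary": "This is a summary of <<text>>."}, reference))  # 0.75 (match found in the state)
print(score_code("something else entirely", reference))                      # 0.3  (code ran, wrong result)
```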
# Evaluates a single result: the explanation, the generated code and the produced answer
def evaluate_one_result(explanation, code, agent_answer, theoretical_answer, answer, verbose=False):
# Tools mentioned in the explanation
tools_in_explanation = {name for name in TEST_TOOLS if f"`{name}`" in explanation}
# Tools that the theoretical answer is expected to use
theoretical_tools = get_theoretical_tools(agent_answer, theoretical_answer, answer)
# If the tools mentioned in the explanation match the theoretical tools exactly, full tool-selection score
if tools_in_explanation == theoretical_tools:
tool_selection_score = 1.0
tool_selection_errors = None
else:
# Otherwise penalize 0.25 per missing or unexpected tool
missing_tools = len(theoretical_tools - tools_in_explanation)
unexpected_tools = len(tools_in_explanation - theoretical_tools)
tool_selection_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)
# Record the tool-selection errors
tool_selection_errors = {
"selected_tools": tools_in_explanation,
"theoretical_tools": theoretical_tools,
}
# Tools actually used in the generated code
tools_in_code = {name for name in TEST_TOOLS if name in code}
# If the tools used in the code match the theoretical tools
if tools_in_code == theoretical_tools:
# Full tool-usage score
tool_used_score = 1.0
# No errors to report
tool_used_errors = None
else:
# Number of missing tools
missing_tools = len(theoretical_tools - tools_in_code)
# Number of unexpected tools
unexpected_tools = len(tools_in_code - theoretical_tools)
# Tool-usage score with the same 0.25 penalty per missing/unexpected tool
tool_used_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)
# Record the tool-usage errors, with the selected and theoretical tools
tool_used_errors = {
"selected_tools": tools_in_explanation,
"theoretical_tools": theoretical_tools,
}
# Score the code itself
score = score_code(agent_answer, theoretical_answer, verbose=verbose)
# If the score is below 1.0
if score < 1.0:
# Record the code errors: produced code, evaluated answer and theoretical answer
code_errors = {
"code_produced": code,
"evaluation": agent_answer,
"theoretical_answer": theoretical_answer,
}
else:
# Perfect score: no code errors
code_errors = None
# Return the three scores and the corresponding errors
return (tool_selection_score, tool_used_score, score), (tool_selection_errors, tool_used_errors, code_errors)
def evaluate_agent(agent, batch_size=8, verbose=False, return_errors=False):
"""
Evaluates a new agent on all `EVALUATION_TASKS`.
"""
# Consistency check: the agent's toolbox must contain exactly the required test tools
agent_tools = set(agent.toolbox.keys())
if agent_tools != set(TEST_TOOLS):
# Compute the missing and extra tools, then raise an error
missing_tools = set(TEST_TOOLS) - agent_tools
unexpected_tools = agent_tools - set(TEST_TOOLS)
raise ValueError(
f"Fix the test tools in the evaluate_agent module. Tools missing: {missing_tools}. Extra tools: {unexpected_tools}."
)
# Build the flat list of evaluation tasks and the index of the Problem each one belongs to
eval_tasks = []
eval_idx = []
for idx, pb in enumerate(EVALUATION_TASKS):
if isinstance(pb.task, list):
# Expand every phrasing of the task, keeping track of its Problem index
eval_tasks.extend(pb.task)
eval_idx.extend([idx] * len(pb.task))
else:
# Single phrasing: append the task and its index
eval_tasks.append(pb.task)
eval_idx.append(idx)
# Initialize the score accumulators
tool_selection_score = 0
tool_used_score = 0
code_score = 0
# If errors should be returned, initialize the error dictionaries
if return_errors:
tool_selection_errors = {}
tool_used_errors = {}
code_errors = {}
# Process the evaluation tasks in batches
for start_idx in range(0, len(eval_tasks), batch_size):
end_idx = min(start_idx + batch_size, len(eval_tasks))
batch_tasks = eval_tasks[start_idx:end_idx]
# Build a prompt for each task in the batch
prompts = [agent.format_prompt(task) for task in batch_tasks]
# Let the agent generate code for every prompt, stopping at "Task:"
results = agent.generate_many(prompts, stop=["Task:"])
# Iterate over the results of the batch
for idx, result in enumerate(results):
# Retrieve the Problem corresponding to this task
problem = EVALUATION_TASKS[eval_idx[start_idx + idx]]
if verbose:
# In verbose mode, print the task
print(f"====Task {start_idx + idx}====\n{batch_tasks[idx]}\n")
# Split the generated output into an explanation and runnable code
explanation, code = agent.clean_code_for_run(result)
# Evaluate the agent's code, then the theoretical answer(s) for comparison
agent_answer = evaluate_code(code, problem.inputs, verbose=verbose)
if isinstance(problem.answer, list):
theoretical_answer = [evaluate_code(answer, problem.inputs) for answer in problem.answer]
else:
theoretical_answer = evaluate_code(problem.answer, problem.inputs)
# Score this result and collect the potential errors
scores, errors = evaluate_one_result(
explanation, code, agent_answer, theoretical_answer, problem.answer, verbose=verbose
)
# Accumulate the three scores
tool_selection_score += scores[0]
tool_used_score += scores[1]
code_score += scores[2]
# If requested, record the errors for this task
if return_errors:
if errors[0] is not None:
tool_selection_errors[batch_tasks[idx]] = errors[0]
if errors[1] is not None:
tool_used_errors[batch_tasks[idx]] = errors[1]
if errors[2] is not None:
code_errors[batch_tasks[idx]] = errors[2]
# Build the score dictionary, each entry as a percentage over the number of tasks
scores = {
"tool selection score": 100 * (tool_selection_score / len(eval_tasks)),
"tool used score": 100 * (tool_used_score / len(eval_tasks)),
"code score": 100 * (code_score / len(eval_tasks)),
}
# Return the scores, plus the error dictionaries if requested
if return_errors:
return scores, tool_selection_errors, tool_used_errors, code_errors
else:
return scores
# Evaluates the given agent on the chat problems, checking that it has the correct toolbox
def evaluate_chat_agent(agent, verbose=False, return_errors=False):
"""
Evaluates a new agent on all `EVALUATION_CHATS`.
Example:
```
agent = NewOpenAiAgent(model="text-davinci-003", api_key=your_api_key)
bads = new_evaluate_agent(agent)
for bad in bads:
print(bad)
```
"""
# Check that the agent's toolbox matches the expected test tools exactly
agent_tools = set(agent.toolbox.keys())
if agent_tools != set(TEST_TOOLS):
# Compute the missing and extra tools
missing_tools = set(TEST_TOOLS) - agent_tools
unexpected_tools = agent_tools - set(TEST_TOOLS)
# Raise a ValueError asking to fix the test tools in this module
raise ValueError(
f"Fix the test tools in the evaluate_agent module. Tools mising: {missing_tools}. Extra tools: {unexpected_tools}."
)
# Initialize the score accumulators
tool_selection_score = 0
tool_used_score = 0
code_score = 0
total_steps = 0
# If errors should be returned, initialize the error dictionaries
if return_errors:
tool_selection_errors = {}
tool_used_errors = {}
code_errors = {}
# Iterate over every chat problem
for chat_problem in EVALUATION_CHATS:
# If the first task is a string, there is a single resolved version of the chat
if isinstance(chat_problem[0].task, str):
resolved_problems = [chat_problem]
else:
# Otherwise build one resolved chat per phrasing of the tasks
resolved_problems = [
[Problem(task=pb.task[i], inputs=pb.inputs, answer=pb.answer) for pb in chat_problem]
for i in range(len(chat_problem[0].task))
]
# Iterate over the resolved chats
for problem in resolved_problems:
# Reset the agent before starting a new chat
agent.prepare_for_new_chat()
agent_state = {}  # Reset the agent's state
# The theoretical state is a list of dicts when the first answer is a list, a single dict otherwise
theoretical_state = (
[{} for _ in range(len(problem[0].answer))] if isinstance(problem[0].answer, list) else {}
)
# Iterate over every step of the chat
for step, step_problem in enumerate(problem):
# In verbose mode, print the current task
if verbose:
print(step_problem.task)
total_steps += 1  # Increment the step counter
# Format the prompt for this step in chat mode
prompt = agent.format_prompt(step_problem.task, chat_mode=True)
# Generate the agent's answer, stopping on the chat delimiters
result = agent.generate_one(prompt, stop=["Human:", "====="])
# Append the exchange to the agent's chat history
agent.chat_history = prompt + result + "\n"
# Split the generated output into an explanation and code
explanation, code = clean_code_for_chat(result)
# In verbose mode, print the explanation and the generated code
if verbose:
print(f"==Explanation from the agent==\n{explanation}")
print(f"\n==Code generated by the agent==\n{code}")
# Evaluate the agent's code against the current agent state
agent_answer = evaluate_code(code, step_problem.inputs, state=agent_state, verbose=verbose)
answer = step_problem.answer
if isinstance(answer, list):
# When the answer is a list, evaluate each theoretical answer with its own state
theoretical_answer = [
evaluate_code(a, step_problem.inputs, state=state)
for a, state in zip(answer, theoretical_state)
]
else:
# Otherwise evaluate the single theoretical answer
theoretical_answer = evaluate_code(answer, step_problem.inputs, state=theoretical_state)
# Score this step and collect the potential errors
scores, errors = evaluate_one_result(
explanation, code, agent_answer, theoretical_answer, answer, verbose=verbose
)
# Accumulate the three scores
tool_selection_score += scores[0]
tool_used_score += scores[1]
code_score += scores[2]
# If requested, record the errors for this step
if return_errors:
if errors[0] is not None:
tool_selection_errors[step_problem.task] = errors[0]
if errors[1] is not None:
tool_used_errors[step_problem.task] = errors[1]
if errors[2] is not None:
code_errors[step_problem.task] = errors[2]
# Compute the overall scores (and return the errors if requested)
scores = {
"tool selection score": 100 * (tool_selection_score / total_steps),
"tool used score": 100 * (tool_used_score / total_steps),
"code score": 100 * (code_score / total_steps),
}
if return_errors:
return scores, tool_selection_errors, tool_used_errors, code_errors
else:
return scores
.\tools\image_captioning.py
from typing import TYPE_CHECKING
from ..models.auto import AutoModelForVision2Seq
from ..utils import requires_backends
from .base import PipelineTool
if TYPE_CHECKING:
from PIL import Image
class ImageCaptioningTool(PipelineTool):
default_checkpoint = "Salesforce/blip-image-captioning-base"
description = (
"This is a tool that generates a description of an image. It takes an input named `image` which should be the "
"image to caption, and returns a text that contains the description in English."
)
name = "image_captioner"
model_class = AutoModelForVision2Seq
inputs = ["image"]
outputs = ["text"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
super().__init__(*args, **kwargs)
def encode(self, image: "Image"):
return self.pre_processor(images=image, return_tensors="pt")
def forward(self, inputs):
return self.model.generate(**inputs)
def decode(self, outputs):
return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
.\tools\image_question_answering.py
from typing import TYPE_CHECKING
import torch
from ..models.auto import AutoModelForVisualQuestionAnswering, AutoProcessor
from ..utils import requires_backends
from .base import PipelineTool
if TYPE_CHECKING:
from PIL import Image
class ImageQuestionAnsweringTool(PipelineTool):
default_checkpoint = "dandelin/vilt-b32-finetuned-vqa"
description = (
"This is a tool that answers a question about an image. It takes an input named `image` which should be the "
"image containing the information, as well as a `question` which should be the question in English. It "
"returns a text that is the answer to the question."
)
name = "image_qa"
pre_processor_class = AutoProcessor
model_class = AutoModelForVisualQuestionAnswering
inputs = ["image", "text"]
outputs = ["text"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
super().__init__(*args, **kwargs)
def encode(self, image: "Image", question: str):
return self.pre_processor(image, question, return_tensors="pt")
def forward(self, inputs):
with torch.no_grad():
return self.model(**inputs).logits
def decode(self, outputs):
idx = outputs.argmax(-1).item()
return self.model.config.id2label[idx]
.\tools\image_segmentation.py
import numpy as np
import torch
from ..models.clipseg import CLIPSegForImageSegmentation
from ..utils import is_vision_available, requires_backends
from .base import PipelineTool
if is_vision_available():
from PIL import Image
class ImageSegmentationTool(PipelineTool):
description = (
"This is a tool that creates a segmentation mask of an image according to a label. It cannot create an image. "
"It takes two arguments named `image` which should be the original image, and `label` which should be a text "
"describing the elements what should be identified in the segmentation mask. The tool returns the mask."
)
default_checkpoint = "CIDAS/clipseg-rd64-refined"
name = "image_segmenter"
model_class = CLIPSegForImageSegmentation
inputs = ["image", "text"]
outputs = ["image"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
super().__init__(*args, **kwargs)
def encode(self, image: "Image", label: str):
return self.pre_processor(text=[label], images=[image], padding=True, return_tensors="pt")
def forward(self, inputs):
with torch.no_grad():
logits = self.model(**inputs).logits
return logits
def decode(self, outputs):
array = outputs.cpu().detach().numpy()
array[array <= 0] = 0
array[array > 0] = 1
return Image.fromarray((array * 255).astype(np.uint8))
.\tools\prompts.py
CHAT_MESSAGE_PROMPT = """
Human: <<task>>
Assistant: """
DEFAULT_PROMPTS_REPO = "huggingface-tools/default-prompts"
PROMPT_FILES = {"chat": "chat_prompt_template.txt", "run": "run_prompt_template.txt"}
def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
"""
Downloads and caches the prompt from a prompt string or a repo ID, and returns its contents (if necessary).
"""
if prompt_or_repo_id is None:
prompt_or_repo_id = DEFAULT_PROMPTS_REPO
if re.search("\\s", prompt_or_repo_id) is not None:
return prompt_or_repo_id
prompt_file = cached_file(
prompt_or_repo_id, PROMPT_FILES[mode], repo_type="dataset", user_agent={"agent": agent_name}
)
with open(prompt_file, "r", encoding="utf-8") as f:
return f.read()
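Usage sketch:

```python
# A string containing whitespace is treated as the prompt itself; otherwise it is treated as a dataset repo id
# holding `run_prompt_template.txt` / `chat_prompt_template.txt`.
run_prompt = download_prompt(None, agent_name="my_agent", mode="run")  # fetches the default prompts repo
inline_prompt = download_prompt("Answer the following task: <<task>>", agent_name="my_agent")
```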
.\tools\python_interpreter.py
"""
Evaluate an abstract syntax tree (AST) node representing a Python expression, using variables from `state` and
restricted to functions in `tools`.
Args:
expression (`ast.AST`):
The AST node to evaluate.
state (`Dict[str, Any]`):
A dictionary mapping variable names to their current values.
tools (`Dict[str, Callable]`):
Allowed functions that can be called during evaluation.
Returns:
Any:
The result of evaluating the expression represented by `expression`.
Raises:
InterpretorError:
If evaluation encounters an unsupported operation or other error.
"""
try:
line_result = ast.literal_eval(expression, globals=state, locals=tools)
except (ValueError, TypeError, SyntaxError) as e:
raise InterpretorError(f"Failed to evaluate expression: {e}")
return line_result
This function will recurse trough the nodes of the tree provided.
Args:
expression (`ast.AST`):
The code to evaluate, as an abstract syntax tree.
state (`Dict[str, Any]`):
A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
encounters assignments.
tools (`Dict[str, Callable]`):
The functions that may be called during the evaluation. Any call to another function will fail with an
`InterpretorError`.
"""
if isinstance(expression, ast.Assign):
# If the expression is an assignment statement
# Evaluate the assignment and return the assigned variable's value
return evaluate_assign(expression, state, tools)
elif isinstance(expression, ast.Call):
# If the expression is a function call
# Evaluate the function call and return its value
return evaluate_call(expression, state, tools)
elif isinstance(expression, ast.Constant):
# If the expression is a constant value (literal)
# Return the constant's value
return expression.value
elif isinstance(expression, ast.Dict):
# If the expression is a dictionary literal
# Evaluate all keys and values recursively and return a dictionary
keys = [evaluate_ast(k, state, tools) for k in expression.keys]
values = [evaluate_ast(v, state, tools) for v in expression.values]
return dict(zip(keys, values))
elif isinstance(expression, ast.Expr):
# If the expression is an expression statement
# Evaluate the expression and return its value
return evaluate_ast(expression.value, state, tools)
elif isinstance(expression, ast.For):
# If the expression is a for loop
# Evaluate the loop and return its result
return evaluate_for(expression, state, tools)
elif isinstance(expression, ast.FormattedValue):
# If the expression is a formatted value in an f-string
# Evaluate the content and return its value
return evaluate_ast(expression.value, state, tools)
elif isinstance(expression, ast.If):
# If the expression is an if statement
# Evaluate the condition and execute the appropriate branch
return evaluate_if(expression, state, tools)
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
# If the expression is an index operation
# Evaluate the indexed value and return it
return evaluate_ast(expression.value, state, tools)
elif isinstance(expression, ast.JoinedStr):
# If the expression is a joined string (part of an f-string)
# Evaluate the concatenated parts and return the resulting string
return "".join([str(evaluate_ast(v, state, tools)) for v in expression.values])
elif isinstance(expression, ast.List):
# If the expression is a list literal
# Evaluate all elements recursively and return a list
return [evaluate_ast(elt, state, tools) for elt in expression.elts]
elif isinstance(expression, ast.Name):
# If the expression is a variable name
# Retrieve its value from the state dictionary
return evaluate_name(expression, state, tools)
elif isinstance(expression, ast.Subscript):
# If the expression is a subscript operation
# Evaluate the subscripted value and return it
return evaluate_subscript(expression, state, tools)
else:
# If the expression type is not recognized
# Raise an interpreter error indicating the unsupported expression type
raise InterpretorError(f"{expression.__class__.__name__} is not supported.")
# 对赋值表达式进行求值,更新状态并返回结果
def evaluate_assign(assign, state, tools):
# 获取赋值表达式左侧的变量名列表
var_names = assign.targets
# 调用 evaluate_ast 函数求解赋值表达式右侧的值
result = evaluate_ast(assign.value, state, tools)
# 如果只有一个变量名,则直接将结果赋给状态中的对应变量
if len(var_names) == 1:
state[var_names[0].id] = result
else:
# 否则,检查结果的长度是否与变量名列表相符
if len(result) != len(var_names):
raise InterpretorError(f"Expected {len(var_names)} values but got {len(result)}.")
# 遍历变量名列表和结果,逐个更新状态中的变量值
for var_name, r in zip(var_names, result):
state[var_name.id] = r
# 返回结果
return result
# 对函数调用表达式进行求值,返回调用结果
def evaluate_call(call, state, tools):
# 如果调用的函数不是一个简单的名称,抛出错误
if not isinstance(call.func, ast.Name):
raise InterpretorError(
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func} of "
f"type {type(call.func)}."
)
# 获取函数名
func_name = call.func.id
# 如果函数名不在提供的工具集中,抛出错误
if func_name not in tools:
raise InterpretorError(
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func.id})."
)
# 获取函数对象
func = tools[func_name]
# 处理函数调用的参数
args = [evaluate_ast(arg, state, tools) for arg in call.args]
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
# 调用函数并返回结果
return func(*args, **kwargs)
# 对下标表达式进行求值,返回索引后的值
def evaluate_subscript(subscript, state, tools):
# 求解下标和值
index = evaluate_ast(subscript.slice, state, tools)
value = evaluate_ast(subscript.value, state, tools)
# 如果值是列表或元组,则返回索引对应的值
if isinstance(value, (list, tuple)):
return value[int(index)]
# 如果索引存在于值中,则返回相应的值
if index in value:
return value[index]
# 如果索引是字符串且值是映射类型,则找出最接近的键并返回其对应的值
if isinstance(index, str) and isinstance(value, Mapping):
close_matches = difflib.get_close_matches(index, list(value.keys()))
if len(close_matches) > 0:
return value[close_matches[0]]
# 抛出错误,表示无法进行索引操作
raise InterpretorError(f"Could not index {value} with '{index}'.")
# 对名称表达式进行求值,返回变量的值
def evaluate_name(name, state, tools):
# 如果变量名存在于状态中,则返回其对应的值
if name.id in state:
return state[name.id]
# 否则,查找变量名的最接近匹配,并返回对应的值
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
if len(close_matches) > 0:
return state[close_matches[0]]
# 抛出错误,表示变量未定义
raise InterpretorError(f"The variable `{name.id}` is not defined.")
# 对条件表达式进行求值,返回布尔值表示的条件结果
def evaluate_condition(condition, state, tools):
# 如果条件包含多个操作符,抛出错误
if len(condition.ops) > 1:
raise InterpretorError("Cannot evaluate conditions with multiple operators")
# 求解条件左侧和右侧的值
left = evaluate_ast(condition.left, state, tools)
comparator = condition.ops[0]
right = evaluate_ast(condition.comparators[0], state, tools)
# 根据比较符的类型,比较左右两侧的值并返回结果
if isinstance(comparator, ast.Eq):
return left == right
elif isinstance(comparator, ast.NotEq):
return left != right
elif isinstance(comparator, ast.Lt):
return left < right
elif isinstance(comparator, ast.LtE):
return left <= right
elif isinstance(comparator, ast.Gt):
return left > right
elif isinstance(comparator, ast.GtE):
return left >= right
elif isinstance(comparator, ast.Is):
return left is right
elif isinstance(comparator, ast.IsNot):
return left is not right
# 如果比较符号是 'in',则返回左操作数是否包含在右操作数中的布尔值
elif isinstance(comparator, ast.In):
return left in right
# 如果比较符号是 'not in',则返回左操作数是否不包含在右操作数中的布尔值
elif isinstance(comparator, ast.NotIn):
return left not in right
else:
# 如果比较符号不是以上两种情况,抛出解释器错误,显示不支持的操作符信息
raise InterpretorError(f"Operator not supported: {comparator}")
# 根据条件语句评估条件并执行相应的操作,返回最后一个操作的结果
def evaluate_if(if_statement, state, tools):
result = None
# 如果条件为真,执行条件体内的语句
if evaluate_condition(if_statement.test, state, tools):
# 遍历条件体内的每一行语句
for line in if_statement.body:
# 评估并执行当前行的抽象语法树节点
line_result = evaluate_ast(line, state, tools)
# 如果结果不为空,更新结果
if line_result is not None:
result = line_result
else:
# 如果条件为假,执行否定体内的语句
for line in if_statement.orelse:
# 评估并执行当前行的抽象语法树节点
line_result = evaluate_ast(line, state, tools)
# 如果结果不为空,更新结果
if line_result is not None:
result = line_result
# 返回最后执行的结果
return result
# 根据for循环语句评估迭代器,并依次执行循环体内的操作,返回最后一个操作的结果
def evaluate_for(for_loop, state, tools):
result = None
# 评估迭代器表达式,获取迭代器对象
iterator = evaluate_ast(for_loop.iter, state, tools)
# 遍历迭代器对象中的每一个元素
for counter in iterator:
# 将当前元素赋值给循环目标变量
state[for_loop.target.id] = counter
# 遍历for循环体内的每一个表达式
for expression in for_loop.body:
# 评估并执行当前表达式的抽象语法树节点
line_result = evaluate_ast(expression, state, tools)
# 如果结果不为空,更新结果
if line_result is not None:
result = line_result
# 返回最后执行的结果
return result
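Putting the pieces together, a hedged end-to-end sketch: parse a small snippet, feed each top-level node to `evaluate_ast`, and let assignments and the for loop mutate `state`:
```
import ast

code = """
total = 0
for n in values:
    total = add(total, n)
msg = f"sum={total}"
"""
state = {"values": [1, 2, 3]}
tools = {"add": lambda a, b: a + b}

for node in ast.parse(code).body:
    evaluate_ast(node, state, tools)

print(state["msg"])  # sum=6
```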
.\tools\speech_to_text.py
from ..models.whisper import WhisperForConditionalGeneration, WhisperProcessor
from .base import PipelineTool
class SpeechToTextTool(PipelineTool):
default_checkpoint = "openai/whisper-base"
description = (
"This is a tool that transcribes an audio into text. It takes an input named `audio` and returns the "
"transcribed text."
)
name = "transcriber"
pre_processor_class = WhisperProcessor
model_class = WhisperForConditionalGeneration
inputs = ["audio"]
outputs = ["text"]
def encode(self, audio):
return self.pre_processor(audio, return_tensors="pt").input_features
def forward(self, inputs):
return self.model.generate(inputs=inputs)
def decode(self, outputs):
return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]
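A hedged usage sketch; the `PipelineTool` base class (defined in `tools/base.py`) makes the tool callable and chains `encode` -> `forward` -> `decode`. The silent waveform below is only a stand-in input:
```
import numpy as np
from transformers.tools import SpeechToTextTool

transcriber = SpeechToTextTool()
waveform = np.zeros(16_000, dtype=np.float32)  # one second of (silent) 16 kHz audio as a placeholder
print(transcriber(waveform))
```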
.\tools\text_classification.py
import torch
from ..models.auto import AutoModelForSequenceClassification, AutoTokenizer
from .base import PipelineTool
class TextClassificationTool(PipelineTool):
"""
文本分类工具类,继承自PipelineTool基类。
Example:
```
from transformers.tools import TextClassificationTool
classifier = TextClassificationTool()
classifier("This is a super nice API!", labels=["positive", "negative"])
```
"""
default_checkpoint = "facebook/bart-large-mnli"
description = (
"This is a tool that classifies an English text using provided labels. It takes two inputs: `text`, which "
"should be the text to classify, and `labels`, which should be the list of labels to use for classification. "
"It returns the most likely label in the list of provided `labels` for the input text."
)
name = "text_classifier"
pre_processor_class = AutoTokenizer
model_class = AutoModelForSequenceClassification
inputs = ["text", ["text"]]
outputs = ["text"]
def setup(self):
super().setup()
config = self.model.config
self.entailment_id = -1
for idx, label in config.id2label.items():
if label.lower().startswith("entail"):
self.entailment_id = int(idx)
if self.entailment_id == -1:
raise ValueError("Could not determine the entailment ID from the model config, please pass it at init.")
def encode(self, text, labels):
self._labels = labels
return self.pre_processor(
[text] * len(labels),
[f"This example is {label}" for label in labels],
return_tensors="pt",
padding="max_length",
)
def decode(self, outputs):
logits = outputs.logits
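# Index 2 is the "entailment" class id of the default facebook/bart-large-mnli checkpoint (cf. self.entailment_id computed in setup()).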
label_id = torch.argmax(logits[:, 2]).item()
return self._labels[label_id]
.\tools\text_question_answering.py
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool
QA_PROMPT = """Here is a text containing a lot of information: '''{text}'''.
Can you answer this question about the text: '{question}'"""
class TextQuestionAnsweringTool(PipelineTool):
default_checkpoint = "google/flan-t5-base"
description = (
"This is a tool that answers questions related to a text. It takes two arguments named `text`, which is the "
"text where to find the answer, and `question`, which is the question, and returns the answer to the question."
)
name = "text_qa"
pre_processor_class = AutoTokenizer
model_class = AutoModelForSeq2SeqLM
inputs = ["text", "text"]
outputs = ["text"]
def encode(self, text: str, question: str):
prompt = QA_PROMPT.format(text=text, question=question)
return self.pre_processor(prompt, return_tensors="pt")
def forward(self, inputs):
output_ids = self.model.generate(**inputs)
in_b, _ = inputs["input_ids"].shape
out_b = output_ids.shape[0]
return output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])[0][0]
def decode(self, outputs):
return self.pre_processor.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
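A hedged usage sketch, mirroring the docstring examples of the neighbouring tools (the tool is callable through the `PipelineTool` base class):
```
from transformers.tools import TextQuestionAnsweringTool

qa = TextQuestionAnsweringTool()
answer = qa(text="Hugging Face is based in New York and Paris.", question="Where is Hugging Face based?")
print(answer)
```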
.\tools\text_summarization.py
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool
class TextSummarizationTool(PipelineTool):
"""
Example:
```
from transformers.tools import TextSummarizationTool
summarizer = TextSummarizationTool()
summarizer(long_text)
```
"""
default_checkpoint = "philschmid/bart-large-cnn-samsum"
description = (
"This is a tool that summarizes an English text. It takes an input `text` containing the text to summarize, "
"and returns a summary of the text."
)
name = "summarizer"
pre_processor_class = AutoTokenizer
model_class = AutoModelForSeq2SeqLM
inputs = ["text"]
outputs = ["text"]
def encode(self, text):
return self.pre_processor(text, return_tensors="pt", truncation=True)
def forward(self, inputs):
return self.model.generate(**inputs)[0]
def decode(self, outputs):
return self.pre_processor.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
.\tools\text_to_speech.py
import torch
from ..models.speecht5 import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
from ..utils import is_datasets_available
from .base import PipelineTool
if is_datasets_available():
from datasets import load_dataset
class TextToSpeechTool(PipelineTool):
default_checkpoint = "microsoft/speecht5_tts"
description = (
"This is a tool that reads an English text out loud. It takes an input named `text` which should contain the "
"text to read (in English) and returns a waveform object containing the sound."
)
name = "text_reader"
pre_processor_class = SpeechT5Processor
model_class = SpeechT5ForTextToSpeech
post_processor_class = SpeechT5HifiGan
inputs = ["text"]
outputs = ["audio"]
def setup(self):
if self.post_processor is None:
self.post_processor = "microsoft/speecht5_hifigan"
super().setup()
def encode(self, text, speaker_embeddings=None):
inputs = self.pre_processor(text=text, return_tensors="pt", truncation=True)
if speaker_embeddings is None:
if not is_datasets_available():
raise ImportError("Datasets needs to be installed if not passing speaker embeddings.")
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = torch.tensor(embeddings_dataset[7305]["xvector"]).unsqueeze(0)
return {"input_ids": inputs["input_ids"], "speaker_embeddings": speaker_embeddings}
def forward(self, inputs):
with torch.no_grad():
return self.model.generate_speech(**inputs)
def decode(self, outputs):
with torch.no_grad():
return self.post_processor(outputs).cpu().detach()
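A hedged usage sketch; SpeechT5 generates 16 kHz audio and `decode` returns a plain 1-D CPU tensor, so `soundfile` can write it directly:
```
import soundfile as sf
from transformers.tools import TextToSpeechTool

reader = TextToSpeechTool()
waveform = reader("Hello, my dog is cute.")  # torch.Tensor of shape (num_samples,)
sf.write("speech.wav", waveform.numpy(), samplerate=16000)
```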
.\tools\translation.py
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool
LANGUAGE_CODES = {
"Acehnese Arabic": "ace_Arab",
"Acehnese Latin": "ace_Latn",
"Mesopotamian Arabic": "acm_Arab",
"Ta'izzi-Adeni Arabic": "acq_Arab",
"Tunisian Arabic": "aeb_Arab",
"Afrikaans": "afr_Latn",
"South Levantine Arabic": "ajp_Arab",
"Akan": "aka_Latn",
"Amharic": "amh_Ethi",
"North Levantine Arabic": "apc_Arab",
"Modern Standard Arabic": "arb_Arab",
"Modern Standard Arabic Romanized": "arb_Latn",
"Najdi Arabic": "ars_Arab",
"Moroccan Arabic": "ary_Arab",
"Egyptian Arabic": "arz_Arab",
"Assamese": "asm_Beng",
"Asturian": "ast_Latn",
"Awadhi": "awa_Deva",
"Central Aymara": "ayr_Latn",
"South Azerbaijani": "azb_Arab",
"North Azerbaijani": "azj_Latn",
"Bashkir": "bak_Cyrl",
"Bambara": "bam_Latn",
"Balinese": "ban_Latn",
"Belarusian": "bel_Cyrl",
"Bemba": "bem_Latn",
"Bengali": "ben_Beng",
"Bhojpuri": "bho_Deva",
"Banjar Arabic": "bjn_Arab",
"Banjar Latin": "bjn_Latn",
"Standard Tibetan": "bod_Tibt",
"Bosnian": "bos_Latn",
"Buginese": "bug_Latn",
"Bulgarian": "bul_Cyrl",
"Catalan": "cat_Latn",
"Cebuano": "ceb_Latn",
"Czech": "ces_Latn",
"Chokwe": "cjk_Latn",
"Central Kurdish": "ckb_Arab",
"Crimean Tatar": "crh_Latn",
"Welsh": "cym_Latn",
"Danish": "dan_Latn",
"German": "deu_Latn",
"Southwestern Dinka": "dik_Latn",
"Dyula": "dyu_Latn",
"Dzongkha": "dzo_Tibt",
"Greek": "ell_Grek",
"English": "eng_Latn",
"Esperanto": "epo_Latn",
"Estonian": "est_Latn",
"Basque": "eus_Latn",
"Ewe": "ewe_Latn",
"Faroese": "fao_Latn",
"Fijian": "fij_Latn",
"Finnish": "fin_Latn",
"Fon": "fon_Latn",
"French": "fra_Latn",
"Friulian": "fur_Latn",
"Nigerian Fulfulde": "fuv_Latn",
"Scottish Gaelic": "gla_Latn",
"Irish": "gle_Latn",
"Galician": "glg_Latn",
"Guarani": "grn_Latn",
"Gujarati": "guj_Gujr",
"Haitian Creole": "hat_Latn",
"Hausa": "hau_Latn",
"Hebrew": "heb_Hebr",
"Hindi": "hin_Deva",
"Chhattisgarhi": "hne_Deva",
"Croatian": "hrv_Latn",
"Hungarian": "hun_Latn",
"Armenian": "hye_Armn",
"Igbo": "ibo_Latn",
"Ilocano": "ilo_Latn",
"Indonesian": "ind_Latn",
"Icelandic": "isl_Latn",
"Italian": "ita_Latn",
"Javanese": "jav_Latn",
"Japanese": "jpn_Jpan",
"Kabyle": "kab_Latn",
"Jingpho": "kac_Latn",
"Kamba": "kam_Latn",
"Kannada": "kan_Knda",
"Kashmiri Arabic": "kas_Arab",
"Kashmiri Devanagari": "kas_Deva",
"Georgian": "kat_Geor",
"Central Kanuri Arabic": "knc_Arab",
"Central Kanuri Latin": "knc_Latn",
"Kazakh": "kaz_Cyrl",
"Kabiyè": "kbp_Latn",
"Kabuverdianu": "kea_Latn",
"Khmer": "khm_Khmr",
"Kikuyu": "kik_Latn",
"Kinyarwanda": "kin_Latn",
"Kyrgyz": "kir_Cyrl",
"Kimbundu": "kmb_Latn",
"Northern Kurdish": "kmr_Latn",
"Kikongo": "kon_Latn",
"Korean": "kor_Hang",
"Lao": "lao_Laoo",
"Ligurian": "lij_Latn",
"Limburgish": "lim_Latn",
"Lingala": "lin_Latn",
"Lithuanian": "lit_Latn",
"Lombard": "lmo_Latn",
"Latgalian": "ltg_Latn",
"Luxembourgish": "ltz_Latn",
"Luba-Kasai": "lua_Latn",
"Ganda": "lug_Latn",
"Luo": "luo_Latn",
"Mizo": "lus_Latn",
"Standard Latvian": "lvs_Latn",
"Magahi": "mag_Deva",
"Maithili": "mai_Deva",
"Malayalam": "mal_Mlym",
"Marathi": "mar_Deva",
"Minangkabau Arabic ": "min_Arab",
"Minangkabau Latin": "min_Latn",
"Macedonian": "mkd_Cyrl",
"Plateau Malagasy": "plt_Latn",
"Maltese": "mlt_Latn",
"Meitei Bengali": "mni_Beng",
"Halh Mongolian": "khk_Cyrl",
"Mossi": "mos_Latn",
"Maori": "mri_Latn",
"Burmese": "mya_Mymr",
"Dutch": "nld_Latn",
"Norwegian Nynorsk": "nno_Latn",
"Norwegian Bokmål": "nob_Latn",
"Nepali": "npi_Deva",
"Northern Sotho": "nso_Latn",
"Nuer": "nus_Latn",
"Nyanja": "nya_Latn",
"Occitan": "oci_Latn",
"West Central Oromo": "gaz_Latn",
"Odia": "ory_Orya",
"Pangasinan": "pag_Latn",
"Eastern Panjabi": "pan_Guru",
"Papiamento": "pap_Latn",
"Western Persian": "pes_Arab",
# ... (a long run of language codes is omitted in this excerpt) ...
"Tamasheq Latin": "taq_Latn",
"Tamasheq Tifinagh": "taq_Tfng",
"Tok Pisin": "tpi_Latn",
"Tswana": "tsn_Latn",
"Tsonga": "tso_Latn",
"Turkmen": "tuk_Latn",
"Tumbuka": "tum_Latn",
"Turkish": "tur_Latn",
"Twi": "twi_Latn",
"Central Atlas Tamazight": "tzm_Tfng",
"Uyghur": "uig_Arab",
"Ukrainian": "ukr_Cyrl",
"Umbundu": "umb_Latn",
"Urdu": "urd_Arab",
"Northern Uzbek": "uzn_Latn",
"Venetian": "vec_Latn",
"Vietnamese": "vie_Latn",
"Waray": "war_Latn",
"Wolof": "wol_Latn",
"Xhosa": "xho_Latn",
"Eastern Yiddish": "ydd_Hebr",
"Yoruba": "yor_Latn",
"Yue Chinese": "yue_Hant",
"Chinese Simplified": "zho_Hans",
"Chinese Traditional": "zho_Hant",
"Standard Malay": "zsm_Latn",
"Zulu": "zul_Latn",
}
class TranslationTool(PipelineTool):
"""
Example:
```
from transformers.tools import TranslationTool
translator = TranslationTool()
translator("This is a super nice API!", src_lang="English", tgt_lang="French")
```
"""
default_checkpoint = "facebook/nllb-200-distilled-600M"
description = (
"This is a tool that translates text from a language to another. It takes three inputs: `text`, which should "
"be the text to translate, `src_lang`, which should be the language of the text to translate and `tgt_lang`, "
"which should be the language for the desired ouput language. Both `src_lang` and `tgt_lang` are written in "
"plain English, such as 'Romanian', or 'Albanian'. It returns the text translated in `tgt_lang`."
)
name = "translator"
pre_processor_class = AutoTokenizer
model_class = AutoModelForSeq2SeqLM
lang_to_code = LANGUAGE_CODES
inputs = ["text", "text", "text"]
outputs = ["text"]
def encode(self, text, src_lang, tgt_lang):
if src_lang not in self.lang_to_code:
raise ValueError(f"{src_lang} is not a supported language.")
if tgt_lang not in self.lang_to_code:
raise ValueError(f"{tgt_lang} is not a supported language.")
src_lang = self.lang_to_code[src_lang]
tgt_lang = self.lang_to_code[tgt_lang]
return self.pre_processor._build_translation_inputs(
text, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang
)
def forward(self, inputs):
return self.model.generate(**inputs)
def decode(self, outputs):
return self.post_processor.decode(outputs[0].tolist(), skip_special_tokens=True)
.\tools\__init__.py
from typing import TYPE_CHECKING
from ..utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
_import_structure = {
"agents": ["Agent", "AzureOpenAiAgent", "HfAgent", "LocalAgent", "OpenAiAgent"],
"base": ["PipelineTool", "RemoteTool", "Tool", "launch_gradio_demo", "load_tool"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["document_question_answering"] = ["DocumentQuestionAnsweringTool"]
_import_structure["image_captioning"] = ["ImageCaptioningTool"]
_import_structure["image_question_answering"] = ["ImageQuestionAnsweringTool"]
_import_structure["image_segmentation"] = ["ImageSegmentationTool"]
_import_structure["speech_to_text"] = ["SpeechToTextTool"]
_import_structure["text_classification"] = ["TextClassificationTool"]
_import_structure["text_question_answering"] = ["TextQuestionAnsweringTool"]
_import_structure["text_summarization"] = ["TextSummarizationTool"]
_import_structure["text_to_speech"] = ["TextToSpeechTool"]
_import_structure["translation"] = ["TranslationTool"]
if TYPE_CHECKING:
from .agents import Agent, AzureOpenAiAgent, HfAgent, LocalAgent, OpenAiAgent
from .base import PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .document_question_answering import DocumentQuestionAnsweringTool
from .image_captioning import ImageCaptioningTool
from .image_question_answering import ImageQuestionAnsweringTool
from .image_segmentation import ImageSegmentationTool
from .speech_to_text import SpeechToTextTool
from .text_classification import TextClassificationTool
from .text_question_answering import TextQuestionAnsweringTool
from .text_summarization import TextSummarizationTool
from .text_to_speech import TextToSpeechTool
from .translation import TranslationTool
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\trainer.py
"""
The Trainer class, to easily train a 🤗 Transformers model from scratch or finetune it on a new task.
"""
import contextlib
import copy
import functools
import glob
import importlib.metadata
import inspect
import math
import os
import random
import re
import shutil
import sys
import tempfile
import time
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from .integrations import (
get_reporting_integration_callbacks,
hp_params,
)
import huggingface_hub.utils as hf_hub_utils
import numpy as np
import torch
import torch.distributed as dist
from huggingface_hub import ModelCard, create_repo, upload_folder
from packaging import version
from torch import nn
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from . import __version__
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
from .integrations.tpu import tpu_spmd_dataloader
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_MAPPING_NAMES,
)
from .optimization import Adafactor, get_scheduler
from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
CallbackHandler,
DefaultFlowCallback,
PrinterCallback,
ProgressCallback,
TrainerCallback,
TrainerControl,
TrainerState,
)
from .trainer_pt_utils import (
DistributedTensorGatherer,
IterableDatasetShard,
LabelSmoother,
LayerWiseDummyOptimizer,
LengthGroupedSampler,
SequentialDistributedSampler,
distributed_broadcast_scalars,
distributed_concat,
find_batch_size,
get_dataloader_sampler,
get_model_param_count,
)
from .trainer_pt_utils import (
get_module_class_from_name,
get_parameter_names,
nested_concat,
nested_detach,
nested_numpify,
nested_xla_mesh_reduce,
reissue_pt_warnings,
remove_dummy_checkpoint,
)
from .trainer_utils import (
PREFIX_CHECKPOINT_DIR,
BestRun,
EvalLoopOutput,
EvalPrediction,
HPSearchBackend,
HubStrategy,
IntervalStrategy,
PredictionOutput,
RemoveColumnsCollator,
TrainerMemoryTracker,
TrainOutput,
check_target_module_exists,
default_compute_objective,
denumpify_detensorize,
enable_full_determinism,
find_executable_batch_size,
get_last_checkpoint,
has_length,
neftune_post_forward_hook,
number_of_arguments,
seed_worker,
set_seed,
speed_metrics,
)
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
from .utils import (
ADAPTER_CONFIG_NAME,
ADAPTER_SAFE_WEIGHTS_NAME,
ADAPTER_WEIGHTS_NAME,
CONFIG_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
PushInProgress,
PushToHubMixin,
can_return_loss,
find_labels,
is_accelerate_available,
is_apex_available,
is_bitsandbytes_available,
is_datasets_available,
is_galore_torch_available,
is_in_notebook,
is_ipex_available,
is_peft_available,
is_safetensors_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_compile_available,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_xla_available,
logging,
strtobool,
)
from .utils.quantization_config import QuantizationMethod
DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
if is_in_notebook():
from .utils.notebook import NotebookProgressCallback
DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
if is_apex_available():
from apex import amp
if is_datasets_available():
import datasets
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
from smdistributed.modelparallel import __version__ as SMP_VERSION
IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
else:
IS_SAGEMAKER_MP_POST_1_10 = False
if is_safetensors_available():
import safetensors.torch
if is_peft_available():
from peft import PeftModel
if is_accelerate_available():
from accelerate import Accelerator, skip_first_batches
from accelerate import __version__ as accelerate_version
from accelerate.utils import (
DistributedDataParallelKwargs,
DistributedType,
GradientAccumulationPlugin,
load_fsdp_model,
load_fsdp_optimizer,
save_fsdp_model,
save_fsdp_optimizer,
)
DATA_SAMPLERS = [RandomSampler]
if version.parse(accelerate_version) > version.parse("0.23.0"):
from accelerate.data_loader import SeedableRandomSampler
DATA_SAMPLERS += [SeedableRandomSampler]
if is_deepspeed_available():
from accelerate.utils import DeepSpeedSchedulerWrapper
def _is_peft_model(model):
if is_peft_available():
classes_to_check = (PeftModel,)
if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"):
from peft import PeftMixedModel
classes_to_check = (*classes_to_check, PeftMixedModel)
return isinstance(model, classes_to_check)
return False
def _get_fsdp_ckpt_kwargs():
if is_accelerate_available() and "adapter_only" in list(inspect.signature(save_fsdp_model).parameters):
return {"adapter_only": True}
else:
return {}
if TYPE_CHECKING:
import optuna
logger = logging.get_logger(__name__)
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
OPTIMIZER_NAME_BIN = "optimizer.bin"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"
FSDP_MODEL_NAME = "pytorch_model_fsdp"
class Trainer:
"""
Trainer 是一个简单但功能齐全的 PyTorch 训练和评估循环,专为 🤗 Transformers 优化。
重要属性:
- **model** -- 始终指向核心模型。如果使用的是 transformers 模型,它将是 [`PreTrainedModel`] 的子类。
- **model_wrapped** -- 始终指向最外层的模型。如果使用 `DeepSpeed`,内部模型会被包装成 `DeepSpeed` 和 `torch.nn.DistributedDataParallel`。
如果内部模型没有被包装,则 `self.model_wrapped` 与 `self.model` 相同。
- **is_model_parallel** -- 是否将模型切换到模型并行模式(不同于数据并行,意味着一些模型层在不同的 GPU 上拆分)。
- **place_model_on_device** -- 是否自动将模型放置在设备上。如果使用模型并行或 DeepSpeed,或者默认的 `TrainingArguments.place_model_on_device`
被覆盖为返回 `False`,它将设置为 `False`。
- **is_in_train** -- 当前模型是否正在执行 `train`(例如,在 `train` 运行时调用 `evaluate`)。
"""
from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
def __init__(
self,
model: Union[PreTrainedModel, nn.Module] = None,
args: TrainingArguments = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
):
"""
初始化一个 `Trainer` 对象,用于模型训练和评估。
Args:
model (Union[PreTrainedModel, nn.Module], optional): 要训练的模型。
args (TrainingArguments, optional): 训练和评估的参数设置。
data_collator (Optional[DataCollator], optional): 用于批处理数据的数据收集器。
train_dataset (Optional[Dataset], optional): 训练数据集。
eval_dataset (Optional[Union[Dataset, Dict[str, Dataset]]], optional): 评估数据集。
tokenizer (Optional[PreTrainedTokenizerBase], optional): 用于处理输入数据的分词器。
model_init (Optional[Callable[[], PreTrainedModel]], optional): 初始化模型的函数。
compute_metrics (Optional[Callable[[EvalPrediction], Dict]], optional): 用于计算评估指标的函数。
callbacks (Optional[List[TrainerCallback]], optional): 训练过程中使用的回调函数列表。
optimizers (Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR], optional): 优化器和学习率调度器的元组。
preprocess_logits_for_metrics (Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]], optional): 对预测结果进行预处理的函数。
"""
def _activate_neftune(self, model):
r"""
激活 NEFTune 方法,参考代码和论文:
https://github.com/neelsjain/NEFTune
https://arxiv.org/abs/2310.05914
"""
unwrapped_model = unwrap_model(model)
if _is_peft_model(unwrapped_model):
embeddings = unwrapped_model.base_model.model.get_input_embeddings()
else:
embeddings = unwrapped_model.get_input_embeddings()
del unwrapped_model
embeddings.neftune_noise_alpha = self.neftune_noise_alpha
hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook)
self.neftune_hook_handle = hook_handle
return model
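For context, a rough, illustrative re-implementation of what the registered hook does (the real logic lives in `neftune_post_forward_hook` in `trainer_utils.py`): during training, uniform noise scaled by `neftune_noise_alpha / sqrt(seq_len * hidden_dim)` is added to the embedding output.
```
import torch

def neftune_like_hook(module, inputs, output):
    # Illustrative sketch, not the library function.
    if module.training:
        dims = torch.tensor(output.size(1) * output.size(2), dtype=output.dtype, device=output.device)
        mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
        output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
    return output
```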
def _deactivate_neftune(self, model):
"""
停用 NEFTune 方法。确保先调用 `_activate_neftune`。
"""
if not hasattr(self, "neftune_hook_handle"):
raise ValueError("Neftune is not activated make sure to call `trainer._activate_neftune()` first")
unwrapped_model = unwrap_model(model)
if _is_peft_model(unwrapped_model):
embeddings = unwrapped_model.base_model.model.get_input_embeddings()
else:
embeddings = unwrapped_model.get_input_embeddings()
self.neftune_hook_handle.remove()
del embeddings.neftune_noise_alpha, unwrapped_model
def add_callback(self, callback):
"""
向当前的 [`~transformers.TrainerCallback`] 列表中添加一个回调函数。
Args:
callback (type or [`~transformers.TrainerCallback`]):
[`~transformers.TrainerCallback`] 类或其实例。如果是类,则实例化一个该类的成员。
"""
def pop_callback(self, callback):
"""
Remove a callback from the current list of [`~transformers.TrainerCallback`] and returns it.
If the callback is not found, returns `None` (and no error is raised).
Args:
callback (`type` or [`~transformers.TrainerCallback`]):
A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
first case, will pop the first member of that class found in the list of callbacks.
Returns:
[`~transformers.TrainerCallback`]: The callback removed, if found.
"""
return self.callback_handler.pop_callback(callback)
def remove_callback(self, callback):
"""
Remove a callback from the current list of [`~transformers.TrainerCallback`].
Args:
callback (`type` or [`~transformers.TrainerCallback`]):
A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
first case, will remove the first member of that class found in the list of callbacks.
"""
self.callback_handler.remove_callback(callback)
def _move_model_to_device(self, model, device):
model = model.to(device)
if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
model.tie_weights()
def _set_signature_columns_if_needed(self):
if self._signature_columns is None:
model_to_inspect = self.model
if _is_peft_model(self.model):
if hasattr(self.model, "get_base_model"):
model_to_inspect = self.model.get_base_model()
else:
model_to_inspect = self.model.base_model.model
signature = inspect.signature(model_to_inspect.forward)
self._signature_columns = list(signature.parameters.keys())
self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
if not self.args.remove_unused_columns:
return dataset
self._set_signature_columns_if_needed()
signature_columns = self._signature_columns
ignored_columns = list(set(dataset.column_names) - set(signature_columns))
dset_description = "" if description is None else f"in the {description} set"
logger.info(
f"The following columns {dset_description} don't have a corresponding argument in "
f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
" you can safely ignore this message."
)
columns = [k for k in signature_columns if k in dataset.column_names]
if version.parse(datasets.__version__) < version.parse("1.4.0"):
dataset.set_format(
type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
)
return dataset
else:
return dataset.remove_columns(ignored_columns)
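The filtering relies purely on the model's `forward` signature; a toy illustration of the mechanism (the tiny module below is hypothetical):
```
import inspect
import torch.nn as nn

class TinyModel(nn.Module):
    def forward(self, input_ids=None, attention_mask=None, labels=None):
        return input_ids

model = TinyModel()
signature_columns = list(inspect.signature(model.forward).parameters.keys())
print(signature_columns)  # ['input_ids', 'attention_mask', 'labels'] -- any other dataset column would be dropped
```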
def _get_collator_with_removed_columns(
self, data_collator: Callable, description: Optional[str] = None
) -> Callable:
"""Wrap the data collator in a callable removing unused columns."""
if not self.args.remove_unused_columns:
return data_collator
self._set_signature_columns_if_needed()
signature_columns = self._signature_columns
remove_columns_collator = RemoveColumnsCollator(
data_collator=data_collator,
signature_columns=signature_columns,
logger=logger,
description=description,
model_name=self.model.__class__.__name__,
)
return remove_columns_collator
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.train_dataset is None or not has_length(self.train_dataset):
return None
if self.args.group_by_length:
if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
lengths = (
self.train_dataset[self.args.length_column_name]
if self.args.length_column_name in self.train_dataset.column_names
else None
)
else:
lengths = None
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
return LengthGroupedSampler(
self.args.train_batch_size * self.args.gradient_accumulation_steps,
dataset=self.train_dataset,
lengths=lengths,
model_input_name=model_input_name,
)
else:
return RandomSampler(self.train_dataset)
def get_train_dataloader(self) -> DataLoader:
"""
返回训练数据加载器 [`~torch.utils.data.DataLoader`]。
如果 `train_dataset` 不实现 `__len__`,则不使用采样器;否则使用随机采样器(适应分布式训练)。
如果需要注入自定义行为,请子类化并重写此方法。
"""
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
train_dataset = self.train_dataset
data_collator = self.data_collator
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description="training")
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
if self.args.use_legacy_prediction_loop:
if is_torch_xla_available():
return SequentialDistributedSampler(
eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
)
elif is_sagemaker_mp_enabled():
return SequentialDistributedSampler(
eval_dataset,
num_replicas=smp.dp_size(),
rank=smp.dp_rank(),
batch_size=self.args.per_device_eval_batch_size,
)
else:
return SequentialSampler(eval_dataset)
if self.args.world_size <= 1:
return SequentialSampler(eval_dataset)
else:
return None
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
"""
Returns the evaluation `torch.utils.data.DataLoader`.
Subclass and override this method if you want to inject some custom behavior.
Args:
eval_dataset (`torch.utils.data.Dataset`, *optional*):
If provided, will override `self.eval_dataset`. If it is a `datasets.Dataset`, columns not accepted
by the `model.forward()` method are automatically removed. It must implement `__len__`.
"""
if eval_dataset is None and self.eval_dataset is None:
raise ValueError("Trainer: evaluation requires an eval_dataset.")
if hasattr(self, "_eval_dataloader") and self.args.dataloader_persistent_workers:
return self.accelerator.prepare(self._eval_dataloader)
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
data_collator = self.data_collator
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
if self.args.dataloader_persistent_workers:
self._eval_dataloader = eval_dataloader
return self.accelerator.prepare(eval_dataloader)
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
"""
Returns the test [`~torch.utils.data.DataLoader`].
Subclass and override this method if you want to inject some custom behavior.
Args:
test_dataset (`torch.utils.data.Dataset`, *optional*):
The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed. It must implement `__len__`.
"""
data_collator = self.data_collator
if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
test_dataset = self._remove_unused_columns(test_dataset, description="test")
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description="test")
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
if not isinstance(test_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_eval_sampler(test_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))
def create_optimizer_and_scheduler(self, num_training_steps: int):
"""
Setup the optimizer and the learning rate scheduler.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
`create_scheduler`) in a subclass.
"""
self.create_optimizer()
if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
optimizer = self.optimizer.optimizer
else:
optimizer = self.optimizer
self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
def get_decay_parameter_names(self, model) -> List[str]:
"""
Get all parameter names that weight decay will be applied to
Note that some models implement their own layernorm instead of calling nn.LayerNorm, so weight decay could still
apply to those modules since this function only filters out instances of nn.LayerNorm.
"""
decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
return decay_parameters
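A small self-contained check of that behaviour, using the same helpers the Trainer imports:
```
import torch.nn as nn
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names

model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
decay = [n for n in get_parameter_names(model, ALL_LAYERNORM_LAYERS) if "bias" not in n]
print(decay)  # ['0.weight'] -- LayerNorm parameters and all biases get no weight decay
```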
def create_optimizer(self):
"""
Setup the optimizer.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
"""
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None:
decay_parameters = self.get_decay_parameter_names(opt_model)
optimizer_grouped_parameters = [
{
"params": [
p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model)
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped/2**20}M params")
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer(self.optimizer)
return self.optimizer
@staticmethod
def get_optimizer_cls_and_kwargs(
args: TrainingArguments, model: Optional[PreTrainedModel] = None
):
"""
Helper function to retrieve the optimizer class and its keyword arguments based on training arguments and model.
"""
pass
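The body is omitted here; as a hedged sketch of the default branch only (the real method dispatches on `args.optim` and supports many more optimizers, e.g. Adafactor and the bitsandbytes variants):
```
from torch.optim import AdamW

def get_optimizer_cls_and_kwargs_sketch(args):
    # Default `adamw_torch` case only; illustrative, not the library implementation.
    optimizer_kwargs = {"lr": args.learning_rate}
    optimizer_kwargs.update({"betas": (args.adam_beta1, args.adam_beta2), "eps": args.adam_epsilon})
    return AdamW, optimizer_kwargs
```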
def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
"""
设置调度器。在调用此方法之前,训练器的优化器必须已经设置好,或者作为参数传递进来。
Args:
num_training_steps (int): 要执行的训练步数。
"""
if self.lr_scheduler is None:
self.lr_scheduler = get_scheduler(
self.args.lr_scheduler_type,
optimizer=self.optimizer if optimizer is None else optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
scheduler_specific_kwargs=self.args.lr_scheduler_kwargs,
)
self._created_lr_scheduler = True
return self.lr_scheduler
def num_examples(self, dataloader: DataLoader) -> int:
"""
辅助函数,通过访问其数据集来获取 [`~torch.utils.data.DataLoader`] 中的样本数量。
当 dataloader.dataset 不存在或长度为零时,尽可能进行估计。
"""
try:
dataset = dataloader.dataset
if isinstance(dataset, IterableDatasetShard):
return len(dataloader.dataset.dataset)
return len(dataloader.dataset)
except (NameError, AttributeError, TypeError):
return len(dataloader) * self.args.per_device_train_batch_size
def num_tokens(self, train_dl: DataLoader, max_steps: Optional[int] = None) -> int:
"""
辅助函数,通过枚举数据加载器来获取 [`~torch.utils.data.DataLoader`] 中的令牌数量。
"""
train_tokens = 0
try:
for step, batch in enumerate(train_dl):
tokens = batch["input_ids"].numel()
if max_steps is not None:
return tokens * max_steps
train_tokens += tokens
return train_tokens
except KeyError:
logger.warning("Cannot get num_tokens from dataloader")
return train_tokens
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
"""HP search setup code"""
self._trial = trial
if self.hp_search_backend is None or trial is None:
return
if self.hp_search_backend == HPSearchBackend.OPTUNA:
params = self.hp_space(trial)
elif self.hp_search_backend == HPSearchBackend.RAY:
params = trial
params.pop("wandb", None)
elif self.hp_search_backend == HPSearchBackend.SIGOPT:
params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}
elif self.hp_search_backend == HPSearchBackend.WANDB:
params = trial
for key, value in params.items():
if not hasattr(self.args, key):
logger.warning(
f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
" `TrainingArguments`."
)
continue
old_attr = getattr(self.args, key, None)
if old_attr is not None:
value = type(old_attr)(value)
setattr(self.args, key, value)
if self.hp_search_backend == HPSearchBackend.OPTUNA:
logger.info(f"Trial: {trial.params}")
if self.hp_search_backend == HPSearchBackend.SIGOPT:
logger.info(f"SigOpt Assignments: {trial.assignments}")
if self.hp_search_backend == HPSearchBackend.WANDB:
logger.info(f"W&B Sweep parameters: {trial}")
if self.is_deepspeed_enabled:
if self.args.deepspeed is None:
raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")
from accelerate.utils import DeepSpeedPlugin
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
self.args.hf_deepspeed_config.trainer_config_process(self.args)
self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
self.create_accelerator_and_postprocess()
def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
if self.hp_search_backend is None or trial is None:
return
metrics = metrics.copy()
self.objective = self.compute_objective(metrics)
if self.hp_search_backend == HPSearchBackend.OPTUNA:
import optuna
if not trial.study._is_multi_objective():
trial.report(self.objective, step)
if trial.should_prune():
self.callback_handler.on_train_end(self.args, self.state, self.control)
raise optuna.TrialPruned()
elif self.hp_search_backend == HPSearchBackend.RAY:
import ray.train
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
checkpoint = None
if self.control.should_save:
self._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
metrics["objective"] = self.objective
ray.train.report(metrics, checkpoint=checkpoint)
def _tune_save_checkpoint(self, checkpoint_dir: str):
output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
self.save_model(output_dir, _internal_call=True)
if self.args.should_save:
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
def call_model_init(self, trial=None):
model_init_argcount = number_of_arguments(self.model_init)
if model_init_argcount == 0:
model = self.model_init()
elif model_init_argcount == 1:
model = self.model_init(trial)
else:
raise RuntimeError("model_init should have 0 or 1 argument.")
if model is None:
raise RuntimeError("model_init should not return None.")
return model
def torch_jit_model_eval(self, model, dataloader, training=False):
if not training:
if dataloader is None:
logger.warning("failed to use PyTorch jit mode due to current dataloader is none.")
return model
example_batch = next(iter(dataloader))
example_batch = self._prepare_inputs(example_batch)
try:
jit_model = copy.copy(model)
jit_model.eval()
original_forward = jit_model.__dict__.pop("_original_forward", None)
if original_forward:
jit_model.forward = original_forward
with self.accelerator.autocast(cache_enabled=False), torch.no_grad():
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.0.0"):
if isinstance(example_batch, dict):
jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
else:
jit_model = torch.jit.trace(
jit_model,
example_kwarg_inputs={key: example_batch[key] for key in example_batch},
strict=False,
)
else:
jit_inputs = []
for key in example_batch:
example_tensor = torch.ones_like(example_batch[key])
jit_inputs.append(example_tensor)
jit_inputs = tuple(jit_inputs)
jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False)
jit_model = torch.jit.freeze(jit_model)
with torch.no_grad():
jit_model(**example_batch)
jit_model(**example_batch)
model = jit_model
self.use_cpu_amp = False
except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
logger.warning(f"failed to use PyTorch jit mode due to: {e}.")
return model
def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
if not is_ipex_available():
raise ImportError(
"Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
" to https://github.com/intel/intel-extension-for-pytorch."
)
import intel_extension_for_pytorch as ipex
if not training:
model.eval()
dtype = torch.bfloat16 if not self.is_in_train and self.args.bf16_full_eval else dtype
model = ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train)
else:
if not model.training:
model.train()
model, self.optimizer = ipex.optimize(
model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1"
)
return model
def train(
self,
resume_from_checkpoint: Optional[Union[str, bool]] = None,
trial: Union["optuna.Trial", Dict[str, Any]] = None,
ignore_keys_for_eval: Optional[List[str]] = None,
**kwargs,
):
def _inner_training_loop(
self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
):
def _get_output_dir(self, trial):
if self.hp_search_backend is not None and trial is not None:
if self.hp_search_backend == HPSearchBackend.OPTUNA:
run_id = trial.number
elif self.hp_search_backend == HPSearchBackend.RAY:
import ray.train
run_id = ray.train.get_context().get_trial_id()
elif self.hp_search_backend == HPSearchBackend.SIGOPT:
run_id = trial.id
elif self.hp_search_backend == HPSearchBackend.WANDB:
import wandb
run_id = wandb.run.id
run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
run_dir = os.path.join(self.args.output_dir, run_name)
else:
run_dir = self.args.output_dir
return run_dir
def _issue_warnings_after_load(self, load_result):
if len(load_result.missing_keys) != 0:
if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
self.model._keys_to_ignore_on_save
):
self.model.tie_weights()
else:
logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
if len(load_result.unexpected_keys) != 0:
logger.warning(
f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
)
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
if is_torch_xla_available():
xm.mark_step()
logs: Dict[str, float] = {}
tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
tr_loss -= tr_loss
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
if grad_norm is not None:
logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
logs["learning_rate"] = self._get_learning_rate()
self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step
self.store_flos()
self.log(logs)
metrics = None
if self.control.should_evaluate:
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
self._report_to_hp_search(trial, self.state.global_step, metrics)
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
self.lr_scheduler.step(metrics[metric_to_check])
if self.control.should_save:
self._save_checkpoint(model, trial, metrics=metrics)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
def _load_rng_state(self, checkpoint):
if checkpoint is None:
return
if self.args.world_size > 1:
process_index = self.args.process_index
rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
if not os.path.isfile(rng_file):
logger.info(
f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
)
return
else:
rng_file = os.path.join(checkpoint, "rng_state.pth")
if not os.path.isfile(rng_file):
logger.info(
"Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
"fashion, reproducibility is not guaranteed."
)
return
checkpoint_rng_state = torch.load(rng_file)
random.setstate(checkpoint_rng_state["python"])
np.random.set_state(checkpoint_rng_state["numpy"])
torch.random.set_rng_state(checkpoint_rng_state["cpu"])
if torch.cuda.is_available():
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
else:
try:
torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
except Exception as e:
logger.info(
f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
"\nThis won't yield the same results as if the training had not been interrupted."
)
if is_torch_xla_available():
xm.set_rng_state(checkpoint_rng_state["xla"])
if is_torch_npu_available():
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
torch.npu.random.set_rng_state_all(checkpoint_rng_state["npu"])
else:
try:
torch.npu.random.set_rng_state(checkpoint_rng_state["npu"])
except Exception as e:
logger.info(
f"Didn't manage to set back the RNG states of the NPU because of the following error:\n {e}"
"\nThis won't yield the same results as if the training had not been interrupted."
)
def _save_checkpoint(self, model, trial, metrics=None):
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
if self.hp_search_backend is None and trial is None:
self.store_flos()
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
self.save_model(output_dir, _internal_call=True)
if not self.args.save_only_model:
self._save_optimizer_and_scheduler(output_dir)
self._save_rng_state(output_dir)
if metrics is not None and self.args.metric_for_best_model is not None:
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
metric_value = metrics[metric_to_check]
operator = np.greater if self.args.greater_is_better else np.less
if (
self.state.best_metric is None
or self.state.best_model_checkpoint is None
or operator(metric_value, self.state.best_metric)
):
self.state.best_metric = metric_value
self.state.best_model_checkpoint = output_dir
if self.args.should_save:
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
if self.args.push_to_hub:
self._push_from_checkpoint(output_dir)
if self.args.should_save:
self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)
def _save_rng_state(self, output_dir):
rng_states = {
"python": random.getstate(),
"numpy": np.random.get_state(),
"cpu": torch.random.get_rng_state(),
}
if torch.cuda.is_available():
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
else:
rng_states["cuda"] = torch.cuda.random.get_rng_state()
if is_torch_xla_available():
rng_states["xla"] = xm.get_rng_state()
if is_torch_npu_available():
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
rng_states["npu"] = torch.npu.random.get_rng_state_all()
else:
rng_states["npu"] = torch.npu.random.get_rng_state()
os.makedirs(output_dir, exist_ok=True)
if self.args.world_size <= 1:
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
else:
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
def _save_optimizer_and_scheduler(self, output_dir):
if is_torch_xla_available():
xm.rendezvous("saving_optimizer_states")
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
with warnings.catch_warnings(record=True) as caught_warnings:
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
elif is_sagemaker_mp_enabled():
opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
smp.barrier()
if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
smp.save(
opt_state_dict,
os.path.join(output_dir, OPTIMIZER_NAME),
partial=True,
v3=smp.state.cfg.shard_optimizer_state,
)
elif self.is_deepspeed_enabled:
accept_exclude_frozen_parameters = "exclude_frozen_parameters" in set(
inspect.signature(self.model_wrapped.save_checkpoint).parameters.keys()
)
if accept_exclude_frozen_parameters and _is_peft_model(self.model):
self.model_wrapped.save_checkpoint(output_dir, exclude_frozen_parameters=True)
else:
self.model_wrapped.save_checkpoint(output_dir)
elif self.is_fsdp_enabled:
save_fsdp_model(
self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir, **_get_fsdp_ckpt_kwargs()
)
save_fsdp_optimizer(
self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
)
elif self.args.should_save:
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance(
self.lr_scheduler, DeepSpeedSchedulerWrapper
)
if (
self.args.should_save
and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler)
and not is_torch_xla_available()
):
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
def hyperparameter_search(
self,
hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
n_trials: int = 20,
direction: Union[str, List[str]] = "minimize",
backend: Optional[Union["str", HPSearchBackend]] = None,
hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
**kwargs,
):
"""
Launch a hyperparameter search using the backend given by `backend` (e.g. Optuna, Ray Tune, SigOpt or W&B).
Args:
hp_space (Optional[Callable[["optuna.Trial"], Dict[str, float]]]):
Function defining the hyperparameter search space.
compute_objective (Optional[Callable[[Dict[str, float]], float]]):
Function to compute the objective given a set of hyperparameters.
n_trials (int):
Number of trials (hyperparameter combinations) to evaluate.
direction (Union[str, List[str]]):
Direction to optimize the objective, either 'minimize' or 'maximize'.
backend (Optional[Union[str, HPSearchBackend]]):
Backend for hyperparameter search.
hp_name (Optional[Callable[["optuna.Trial"], str]]):
Function to generate a name for the hyperparameter set.
**kwargs:
Additional keyword arguments passed to the hyperparameter search.
Returns:
None
"""
pass
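# A minimal usage sketch (not from the Trainer source): invoking `hyperparameter_search` with an
# Optuna-style search space. `my_trainer` and the hyperparameter ranges are assumptions for the example;
# the Trainer must have been created with `model_init` so each trial can re-instantiate the model.
def my_hp_space(trial):
    # Each trial samples a learning rate and a batch size to try.
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-4, log=True),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [8, 16, 32]),
    }

# In the full upstream implementation this returns a `BestRun` (trial id, objective, hyperparameters).
best_run = my_trainer.hyperparameter_search(
    hp_space=my_hp_space,
    n_trials=10,
    direction="minimize",
    backend="optuna",
)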
def log(self, logs: Dict[str, float]) -> None:
"""
Log `logs` on the various objects watching training.
Subclass and override this method to inject custom behavior.
Args:
logs (`Dict[str, float]`):
The values to log.
"""
if self.state.epoch is not None:
logs["epoch"] = round(self.state.epoch, 2)
if self.args.include_num_input_tokens_seen:
logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen
output = {**logs, **{"step": self.state.global_step}}
self.state.log_history.append(output)
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
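# Illustrative sketch (not part of Trainer): the docstring above invites subclassing `log`; a minimal
# override that tags every log entry before delegating to the stock path (log_history + callbacks).
# `MyLoggingTrainer` and the injected field are hypothetical names.
class MyLoggingTrainer(Trainer):
    def log(self, logs):
        logs["experiment"] = "baseline-run"  # custom field added to every log record
        super().log(logs)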
def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
"""
Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
Args:
data (Union[torch.Tensor, Any]):
The input data to prepare.
Returns:
Union[torch.Tensor, Any]:
Prepared data ready to be fed into the model.
"""
if isinstance(data, Mapping):
return type(data)({k: self._prepare_input(v) for k, v in data.items()})
elif isinstance(data, (tuple, list)):
return type(data)(self._prepare_input(v) for v in data)
elif isinstance(data, torch.Tensor):
kwargs = {"device": self.args.device}
if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):
kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
return data.to(**kwargs)
return data
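# Illustrative sketch (not part of Trainer): `_prepare_input` recurses through mappings and sequences and
# only moves `torch.Tensor` leaves, so arbitrarily nested batches keep their structure. Below is a
# standalone re-implementation of the same idea for a plain device move (no DeepSpeed dtype handling).
import torch

def move_to_device(data, device):
    if isinstance(data, dict):
        return type(data)({k: move_to_device(v, device) for k, v in data.items()})
    if isinstance(data, (tuple, list)):
        return type(data)(move_to_device(v, device) for v in data)
    if isinstance(data, torch.Tensor):
        return data.to(device=device)
    return data  # non-tensor leaves (ints, strings, None) pass through unchanged

batch = {"input_ids": torch.ones(2, 8, dtype=torch.long), "meta": [torch.zeros(2), "raw-text"]}
batch = move_to_device(batch, "cpu")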
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
"""
Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
handling potential state.
"""
inputs = self._prepare_input(inputs)
if len(inputs) == 0:
raise ValueError(
"The batch received was empty, your model won't be able to train on it. Double-check that your "
f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
)
if self.args.past_index >= 0 and self._past is not None:
inputs["mems"] = self._past
return inputs
def compute_loss_context_manager(self):
"""
A helper wrapper to group together context managers.
"""
return self.autocast_smart_context_manager()
def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
"""
A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
arguments, depending on the situation.
"""
if self.use_cpu_amp:
ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
else:
ctx_manager = contextlib.nullcontext()
return ctx_manager
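# Illustrative sketch (not part of Trainer): what the context manager above amounts to. When CPU AMP is
# active, ops inside the block run under autocast; otherwise it is a null context. The flag and dtype
# below are assumptions mirroring `self.use_cpu_amp` / `self.amp_dtype`.
import contextlib
import torch

use_cpu_amp = True
amp_dtype = torch.bfloat16

ctx = torch.cpu.amp.autocast(cache_enabled=True, dtype=amp_dtype) if use_cpu_amp else contextlib.nullcontext()
with ctx:
    out = torch.nn.Linear(4, 4)(torch.randn(2, 4))  # runs in bfloat16 autocast when enabled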
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
Subclass and override to inject custom behavior.
Args:
model (`nn.Module`):
The model to train.
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument `labels`. Check your model's documentation for all accepted arguments.
Return:
`torch.Tensor`: The tensor with training loss on this batch.
"""
model.train()
inputs = self._prepare_inputs(inputs)
if is_sagemaker_mp_enabled():
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
return loss_mb.reduce_mean().detach().to(self.args.device)
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs)
if self.args.n_gpu > 1:
loss = loss.mean()
if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
self.accelerator.backward(loss)
return loss.detach() / self.args.gradient_accumulation_steps
def compute_loss(self, model, inputs, return_outputs=False):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
Subclass and override for custom behavior.
"""
if self.label_smoother is not None and "labels" in inputs:
labels = inputs.pop("labels")
else:
labels = None
outputs = model(**inputs)
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
if labels is not None:
unwrapped_model = unwrap_model(model)
if _is_peft_model(unwrapped_model):
model_name = unwrapped_model.base_model.model._get_name()
else:
model_name = unwrapped_model._get_name()
if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
loss = self.label_smoother(outputs, labels, shift_labels=True)
else:
loss = self.label_smoother(outputs, labels)
else:
if isinstance(outputs, dict) and "loss" not in outputs:
raise ValueError(
"The model did not return a loss from the inputs, only the following keys: "
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
)
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
return (loss, outputs) if return_outputs else loss
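# Illustrative sketch (not part of Trainer): as the docstring above suggests, the usual way to customize
# the loss is to subclass and override `compute_loss`. The two-class weighting below is a hypothetical
# example for a sequence-classification model.
import torch
from torch import nn

class WeightedLossTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # Weighted cross-entropy instead of the model's built-in loss.
        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 3.0], device=logits.device))
        loss = loss_fct(logits.view(-1, model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss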
def is_local_process_zero(self) -> bool:
"""
Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
machines) main process.
"""
return self.args.local_process_index == 0
def is_world_process_zero(self) -> bool:
"""
Whether or not this process is the global main process (when training in a distributed fashion on several
machines, this is only going to be `True` for one process).
"""
if is_sagemaker_mp_enabled():
return smp.rank() == 0
else:
return self.args.process_index == 0
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
"""
Will save the model, so you can reload it using `from_pretrained()`.
Will only save from the main process.
"""
if output_dir is None:
output_dir = self.args.output_dir
if is_torch_xla_available():
self._save_tpu(output_dir)
elif is_sagemaker_mp_enabled():
os.makedirs(output_dir, exist_ok=True)
state_dict = self.model_wrapped.state_dict()
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
if IS_SAGEMAKER_MP_POST_1_10:
Path(os.path.join(output_dir, "user_content.pt")).touch()
elif self.is_fsdp_enabled:
if ("FULL_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type)) and (
version.parse(accelerate_version) > version.parse("0.24.1")
):
state_dict = self.accelerator.get_state_dict(self.model)
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
elif self.is_deepspeed_enabled:
try:
state_dict = self.accelerator.get_state_dict(self.deepspeed)
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
except ValueError:
logger.warning(
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
" zero_to_fp32.py to recover weights"
)
if self.args.should_save:
self._save(output_dir, state_dict={})
remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
self.model_wrapped.save_checkpoint(output_dir)
elif self.args.should_save:
self._save(output_dir)
if self.args.push_to_hub and not _internal_call:
self.push_to_hub(commit_message="Model save")
def _save_tpu(self, output_dir: Optional[str] = None):
output_dir = output_dir if output_dir is not None else self.args.output_dir
logger.info(f"Saving model checkpoint to {output_dir}")
model = self.model
xm.mark_step()
model.to("cpu")
if xm.is_master_ordinal():
os.makedirs(output_dir, exist_ok=True)
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
supported_classes = (PushToHubMixin,)
xm.rendezvous("saving_checkpoint")
if not isinstance(model, supported_classes):
if isinstance(unwrap_model(model), supported_classes):
unwrap_model(model).save_pretrained(
output_dir,
is_main_process=self.args.should_save,
state_dict=model.state_dict(),
save_function=xm.save,
safe_serialization=self.args.save_safetensors,
)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = model.state_dict()
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
model.save_pretrained(
output_dir,
is_main_process=self.args.should_save,
save_function=xm.save,
safe_serialization=self.args.save_safetensors,
)
if self.tokenizer is not None and self.args.should_save:
self.tokenizer.save_pretrained(output_dir)
model.to(self.args.device)
def _save(self, output_dir: Optional[str] = None, state_dict=None):
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
if not isinstance(self.model, supported_classes):
if state_dict is None:
state_dict = self.model.state_dict()
if isinstance(unwrap_model(self.model), supported_classes):
unwrap_model(self.model).save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
if self.args.save_safetensors:
safetensors.torch.save_file(
state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
)
else:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
def store_flos(self):
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
self.state.total_flos += (
distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
)
self.current_flos = 0
else:
self.state.total_flos += self.current_flos
self.current_flos = 0
def _sorted_checkpoints(
self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
) -> List[str]:
ordering_and_checkpoint_path = []
glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
for path in glob_checkpoints:
if use_mtime:
ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
else:
regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
if regex_match is not None and regex_match.groups() is not None:
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
checkpoints_sorted = sorted(ordering_and_checkpoint_path)
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
if (
self.state.best_model_checkpoint is not None
and str(Path(self.state.best_model_checkpoint)) in checkpoints_sorted
):
best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
for i in range(best_model_index, len(checkpoints_sorted) - 2):
checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
return checkpoints_sorted
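# Illustrative trace (not part of Trainer): the swap loop above pushes the best checkpoint towards the
# end of the step-sorted list (to the second-to-last slot), so the deletion slice in `_rotate_checkpoints`,
# which removes from the front, keeps it. Hypothetical checkpoints at steps 500/1000/1500, best at 500:
checkpoints_sorted = ["ckpt-500", "ckpt-1000", "ckpt-1500"]
best_model_index = 0  # "ckpt-500" is the best checkpoint
for i in range(best_model_index, len(checkpoints_sorted) - 2):
    checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
print(checkpoints_sorted)  # ['ckpt-1000', 'ckpt-500', 'ckpt-1500'] -> 'ckpt-1000' is deleted first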
def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
return
checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
if len(checkpoints_sorted) <= self.args.save_total_limit:
return
save_total_limit = self.args.save_total_limit
if (
self.state.best_model_checkpoint is not None
and self.args.save_total_limit == 1
and checkpoints_sorted[-1] != self.state.best_model_checkpoint
):
save_total_limit = 2
number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
for checkpoint in checkpoints_to_be_deleted:
logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
shutil.rmtree(checkpoint, ignore_errors=True)
"""
# 运行预测并返回预测结果和可能的指标。
# 根据数据集和使用情况,测试数据集可能包含标签。在这种情况下,该方法还会返回指标,例如在 `evaluate()` 中一样。
Args:
test_dataset (`Dataset`):
要运行预测的数据集。如果是 `datasets.Dataset`,则会自动删除模型 `forward()` 方法不接受的列。必须实现 `__len__` 方法。
ignore_keys (`List[str]`, *可选*):
在模型输出中应忽略的键列表(如果是字典)。
metric_key_prefix (`str`, *可选*, 默认为 `"test"`):
用作指标键前缀的可选前缀。例如,如果前缀是 "test"(默认),则指标 "bleu" 将命名为 "test_bleu"。
<Tip>
如果您的预测或标签具有不同的序列长度(例如,因为您在标记分类任务中进行动态填充),则会对预测进行填充(在右侧),以允许串联到一个数组中。填充索引为 -100。
</Tip>
Returns: *NamedTuple* 具有以下键的命名元组:
- predictions (`np.ndarray`): 对 `test_dataset` 的预测。
- label_ids (`np.ndarray`, *可选*): 标签(如果数据集包含)。
- metrics (`Dict[str, float]`, *可选*): 可能包含标签的字典。
"""
self._memory_tracker.start()
test_dataloader = self.get_test_dataloader(test_dataset)
start_time = time.time()
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
output = eval_loop(
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
)
total_batch_size = self.args.eval_batch_size * self.args.world_size
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
output.metrics.update(
speed_metrics(
metric_key_prefix,
start_time,
num_samples=output.num_samples,
num_steps=math.ceil(output.num_samples / total_batch_size),
)
)
self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)
self._memory_tracker.stop_and_update_metrics(output.metrics)
return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)
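# A minimal usage sketch (not from the Trainer source): calling `predict` and consuming its output.
# `my_trainer` and `test_dataset` are assumptions; for a classification model the predictions are logits,
# so a final argmax yields class ids.
import numpy as np

output = my_trainer.predict(test_dataset, metric_key_prefix="test")
preds = np.argmax(output.predictions, axis=-1)  # class id per example
print(output.metrics)  # e.g. {'test_loss': ..., 'test_runtime': ..., 'test_samples_per_second': ...}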
def evaluation_loop(
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
):
"""
Runs evaluation loop over `dataloader`.
Args:
dataloader (DataLoader): The data loader for evaluation.
description (str): Description of the evaluation loop.
prediction_loss_only (Optional[bool], optional): Whether to compute only prediction loss. Defaults to None.
ignore_keys (Optional[List[str]], optional): List of keys to ignore during evaluation. Defaults to None.
metric_key_prefix (str, optional): Prefix for metric keys. Defaults to "eval".
"""
def _nested_gather(self, tensors, name=None):
"""
Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
concatenating them to `gathered`.
"""
if tensors is None:
return
if is_torch_xla_available():
if name is None:
name = "nested_gather"
tensors = nested_xla_mesh_reduce(tensors, name)
elif is_sagemaker_mp_enabled():
tensors = smp_gather(tensors)
elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != "NO") or (
self.args.distributed_state is None and self.args.local_rank != -1
):
tensors = distributed_concat(tensors)
return tensors
def prediction_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
):
"""
Perform a prediction step using `model` on `inputs`.
Args:
model (nn.Module): The model for prediction.
inputs (Dict[str, Union[torch.Tensor, Any]]): Dictionary of inputs for the model.
prediction_loss_only (bool): Whether to compute only prediction loss.
ignore_keys (Optional[List[str]], optional): List of keys to ignore during prediction.
Returns:
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple of (loss, logits,
labels); elements may be `None` depending on `prediction_loss_only` and whether the inputs contain labels.
"""
def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
"""
Computes the number of floating point operations for the model.
Args:
inputs (Dict[str, Union[torch.Tensor, Any]]): The inputs and targets of the model.
Returns:
int: The number of floating-point operations.
"""
if hasattr(self.model, "floating_point_ops"):
return self.model.floating_point_ops(inputs)
else:
return 0
def init_hf_repo(self):
"""
Initializes a git repository in `self.args.hub_model_id`.
"""
if not self.is_world_process_zero():
return
if self.args.hub_model_id is None:
repo_name = Path(self.args.output_dir).absolute().name
else:
repo_name = self.args.hub_model_id
repo_url = create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True)
self.hub_model_id = repo_url.repo_id
self.push_in_progress = None
def create_model_card(
self,
language: Optional[str] = None,
license: Optional[str] = None,
tags: Union[str, List[str], None] = None,
model_name: Optional[str] = None,
finetuned_from: Optional[str] = None,
tasks: Union[str, List[str], None] = None,
dataset_tags: Union[str, List[str], None] = None,
dataset: Union[str, List[str], None] = None,
dataset_args: Union[str, List[str], None] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.
Args:
language (`str`, *optional*):
The language of the model (if applicable)
license (`str`, *optional*):
The license of the model. Will default to the license of the pretrained model used, if the original
model given to the `Trainer` comes from a repo on the Hub.
tags (`str` or `List[str]`, *optional*):
Some tags to be included in the metadata of the model card.
model_name (`str`, *optional*):
The name of the model.
finetuned_from (`str`, *optional*):
The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
of the original model given to the `Trainer` (if it comes from the Hub).
tasks (`str` or `List[str]`, *optional*):
One or several task identifiers, to be included in the metadata of the model card.
dataset_tags (`str` or `List[str]`, *optional*):
One or several dataset tags, to be included in the metadata of the model card.
dataset (`str` or `List[str]`, *optional*):
One or several dataset identifiers, to be included in the metadata of the model card.
dataset_args (`str` or `List[str]`, *optional*):
One or several dataset arguments, to be included in the metadata of the model card.
"""
if not self.is_world_process_zero():
return
model_card_filepath = os.path.join(self.args.output_dir, "README.md")
is_peft_library = False
if os.path.exists(model_card_filepath):
library_name = ModelCard.load(model_card_filepath).data.get("library_name")
is_peft_library = library_name == "peft"
existing_tags = ModelCard.load(model_card_filepath).data.tags
if tags is not None and existing_tags is not None:
if isinstance(tags, str):
tags = [tags]
for tag in existing_tags:
if tag not in tags:
tags.append(tag)
training_summary = TrainingSummary.from_trainer(
self,
language=language,
license=license,
tags=tags,
model_name=model_name,
finetuned_from=finetuned_from,
tasks=tasks,
dataset_tags=dataset_tags,
dataset=dataset,
dataset_args=dataset_args,
)
model_card = training_summary.to_model_card()
with open(model_card_filepath, "w") as f:
f.write(model_card)
if is_peft_library:
unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
def _push_from_checkpoint(self, checkpoint_folder):
if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
return
if not self.args.hub_always_push and self.push_in_progress is not None and not self.push_in_progress.is_done():
return
output_dir = self.args.output_dir
modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
if is_peft_available():
modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])
for modeling_file in modeling_files:
if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
if self.args.save_strategy == IntervalStrategy.STEPS:
commit_message = f"Training in progress, step {self.state.global_step}"
else:
commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
model_push_job = upload_folder(
repo_id=self.hub_model_id,
folder_path=output_dir,
commit_message=commit_message,
token=self.args.hub_token,
run_as_future=True,
ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
)
push_jobs = [model_push_job]
if self.args.hub_strategy in [HubStrategy.CHECKPOINT, HubStrategy.ALL_CHECKPOINTS]:
path_in_repo = (
"last-checkpoint" if self.args.hub_strategy == HubStrategy.CHECKPOINT else Path(checkpoint_folder).name
)
checkpoint_push = upload_folder(
repo_id=self.hub_model_id,
folder_path=checkpoint_folder,
path_in_repo=path_in_repo,
commit_message=commit_message + ", checkpoint",
token=self.args.hub_token,
run_as_future=True,
)
push_jobs.append(checkpoint_push)
if self.push_in_progress is None or self.push_in_progress.is_done():
self.push_in_progress = PushInProgress(push_jobs)
else:
self.push_in_progress.jobs.extend(push_jobs)
if not hasattr(self, "push_in_progress"):
return
if self.push_in_progress is not None and not self.push_in_progress.is_done():
logger.info("Waiting for the current checkpoint push to be finished, this might take a couple of minutes.")
self.push_in_progress.wait_until_done()
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
"""
Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`.
Parameters:
commit_message (`str`, *optional*, defaults to `"End of training"`):
Message to commit while pushing.
blocking (`bool`, *optional*, defaults to `True`):
Whether the function should return only when the `git push` has finished.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to [`~Trainer.create_model_card`].
Returns:
The URL of the repository where the model was pushed if `blocking=True`, or a `Future` object tracking the
progress of the commit if `blocking=False`.
"""
model_name = kwargs.pop("model_name", None)
if model_name is None and self.args.should_save:
if self.args.hub_model_id is None:
model_name = Path(self.args.output_dir).name
else:
model_name = self.args.hub_model_id.split("/")[-1]
if self.hub_model_id is None:
self.init_hf_repo()
self.save_model(_internal_call=True)
if not self.is_world_process_zero():
return
if getattr(self.model, "model_tags", None) is not None:
if "tags" not in kwargs:
kwargs["tags"] = []
if isinstance(kwargs["tags"], str):
kwargs["tags"] = [kwargs["tags"]]
for model_tag in self.model.model_tags:
if model_tag not in kwargs["tags"]:
kwargs["tags"].append(model_tag)
self.create_model_card(model_name=model_name, **kwargs)
self._finish_current_push()
return upload_folder(
repo_id=self.hub_model_id,
folder_path=self.args.output_dir,
commit_message=commit_message,
token=self.args.hub_token,
run_as_future=not blocking,
ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
)
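# A minimal usage sketch (not from the Trainer source): pushing the final model after training.
# `my_trainer` is assumed to have been created with `push_to_hub=True` and a `hub_model_id` in its
# TrainingArguments; extra keyword arguments are forwarded to `create_model_card`.
url_or_future = my_trainer.push_to_hub(
    commit_message="End of training",
    tags=["text-classification"],  # hypothetical model-card tags
    blocking=True,                 # wait for the upload to finish and get the repo URL back
)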
def prediction_loop(
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
):
"""
Perform a prediction loop over the given data loader.
Args:
dataloader (DataLoader): The data loader containing the data to predict on.
description (str): Description of the prediction loop.
prediction_loss_only (Optional[bool], optional): Whether to calculate only prediction loss.
ignore_keys (Optional[List[str]], optional): Keys to ignore during prediction.
metric_key_prefix (str, optional): Prefix for metric keys.
Returns:
An evaluation-loop output exposing `predictions`, `label_ids`, `metrics` and `num_samples`, as consumed by
`predict()` above.
"""
def _gather_and_numpify(self, tensors, name):
"""
Gather values of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy arrays.
Args:
tensors: Tensor or list/tuple of nested tensors to gather and convert.
name: Name associated with the gathering operation.
Returns:
numpy.ndarray or list/tuple/nested structure of numpy arrays corresponding to `tensors`.
"""
if tensors is None:
return
if is_torch_xla_available():
tensors = nested_xla_mesh_reduce(tensors, name)
elif is_sagemaker_mp_enabled():
tensors = smp_gather(tensors)
elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
tensors = distributed_concat(tensors)
return nested_numpify(tensors)
def _add_sm_patterns_to_gitignore(self) -> None:
"""
Add SageMaker Checkpointing patterns to .gitignore file if running on the main process.
Returns:
None
"""
if not self.is_world_process_zero():
return
patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"]
if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")):
with open(os.path.join(self.repo.local_dir, ".gitignore"), "r") as f:
current_content = f.read()
else:
current_content = ""
content = current_content
for pattern in patterns:
if pattern not in content:
if content.endswith("\n"):
content += pattern
else:
content += f"\n{pattern}"
if content != current_content:
with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f:
logger.debug(f"Writing .gitignore file. Content: {content}")
f.write(content)
self.repo.git_add(".gitignore")
time.sleep(0.5)
if not self.repo.is_repo_clean():
self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
self.repo.git_push()
def create_accelerator_and_postprocess(self):
grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
grad_acc_kwargs["sync_with_dataloader"] = False
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
self.accelerator = Accelerator(
deepspeed_plugin=self.args.deepspeed_plugin,
gradient_accumulation_plugin=gradient_accumulation_plugin,
**self.args.accelerator_config.to_dict(),
)
self.gather_function = self.accelerator.gather_for_metrics
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
if self.is_fsdp_enabled:
fsdp_plugin = self.accelerator.state.fsdp_plugin
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
"limit_all_gathers", fsdp_plugin.limit_all_gathers
)
if is_accelerate_available("0.23.0"):
fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get(
"activation_checkpointing", fsdp_plugin.activation_checkpointing
)
if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
raise ValueError(
"The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
"can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
"when using FSDP."
)
if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
self.propagate_args_to_deepspeed()
if (
self.args.save_only_model
and (self.is_deepspeed_enabled or self.is_fsdp_enabled)
and self.args.load_best_model_at_end
):
wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.")
if (self.is_deepspeed_enabled or self.is_fsdp_enabled) and self.args.auto_find_batch_size:
wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
raise NotImplementedError(f"`{wrapper}` doesn't support `auto_find_batch_size`.")
def propagate_args_to_deepspeed(self, auto_find_batch_size=False):
"""
Sets values in the DeepSpeed plugin based on the Trainer args.
"""
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
ds_plugin = self.accelerator.state.deepspeed_plugin
ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
ds_plugin.hf_ds_config.trainer_config_process(self.args, auto_find_batch_size)
def _fsdp_qlora_plugin_updates(self):
"""
Updates the FSDP plugin with QLoRA-related settings if applicable.
"""
if self.is_fsdp_enabled and _is_peft_model(self.model):
from peft import LoraConfig
from peft.utils.other import fsdp_auto_wrap_policy
if isinstance(self.model.active_peft_config, LoraConfig):
fsdp_plugin = self.accelerator.state.fsdp_plugin
fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model)
if (
getattr(self.model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
and self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage.is_floating_point
and version.parse(accelerate_version) > version.parse("0.27.0")
):
fsdp_plugin.set_mixed_precision(
self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
)