Transformers 源码解析(二)
.\commands\convert.py
from argparse import ArgumentParser, Namespace
from ..utils import logging
from . import BaseTransformersCLICommand
def convert_command_factory(args: Namespace):
    """
    Factory function used to convert a model TF 1.0 checkpoint in a PyTorch checkpoint.

    Args:
        args: Namespace produced by the ``convert`` sub-parser; its ``model_type``, ``tf_checkpoint``,
            ``pytorch_dump_output``, ``config`` and ``finetuning_task_name`` attributes are forwarded.

    Returns: ConvertCommand
    """
    return ConvertCommand(
        args.model_type, args.tf_checkpoint, args.pytorch_dump_output, args.config, args.finetuning_task_name
    )
# Error text about the TensorFlow requirement of the conversion command. It is not referenced in the
# lines shown here (presumably used by ConvertCommand.run when TensorFlow is missing — TODO confirm).
IMPORT_ERROR_MESSAGE = """
transformers can only be used from the commandline to convert TensorFlow models in PyTorch, In that case, it requires
TensorFlow to be installed. Please see https://www.tensorflow.org/install/ for installation instructions.
"""
class ConvertCommand(BaseTransformersCLICommand):
    """`transformers-cli convert`: turn an original-author (TensorFlow) checkpoint into a PyTorch one."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Add the ``convert`` sub-command and its options to the transformers-cli parser.

        Args:
            parser: Root parser to register command-specific arguments
        """
        convert_parser = parser.add_parser(
            "convert",
            help="CLI tool to run convert model from original author checkpoints to Transformers PyTorch checkpoints.",
        )
        # The three mandatory string options share the same shape, so declare them from a table
        # (order preserved, which keeps the generated --help identical).
        required_options = (
            ("--model_type", "Model's type."),
            ("--tf_checkpoint", "TensorFlow checkpoint path or folder."),
            ("--pytorch_dump_output", "Path to the PyTorch saved model output."),
        )
        for flag, description in required_options:
            convert_parser.add_argument(flag, type=str, required=True, help=description)
        convert_parser.add_argument("--config", type=str, default="", help="Configuration file path or folder.")
        convert_parser.add_argument(
            "--finetuning_task_name",
            type=str,
            default=None,
            help="Optional fine-tuning task name if the TF model was a finetuned model.",
        )
        convert_parser.set_defaults(func=convert_command_factory)

    def __init__(
        self,
        model_type: str,
        tf_checkpoint: str,
        pytorch_dump_output: str,
        config: str,
        finetuning_task_name: str,
        *args,
    ):
        """Store the conversion settings; extra positional ``args`` are accepted and ignored."""
        converting_logger = logging.get_logger("transformers-cli/converting")
        self._logger = converting_logger
        converting_logger.info(f"Loading model {model_type}")
        self._model_type = model_type
        self._tf_checkpoint = tf_checkpoint
        self._pytorch_dump_output = pytorch_dump_output
        self._config = config
        self._finetuning_task_name = finetuning_task_name
.\commands\download.py
from argparse import ArgumentParser
from . import BaseTransformersCLICommand
def download_command_factory(args):
    """Build a DownloadCommand from the parsed ``download`` sub-command arguments."""
    return DownloadCommand(
        model=args.model,
        cache=args.cache_dir,
        force=args.force,
        trust_remote_code=args.trust_remote_code,
    )
class DownloadCommand(BaseTransformersCLICommand):
    """`transformers-cli download`: fetch a model and its tokenizer with ``from_pretrained``."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``download`` sub-command and its options."""
        sub = parser.add_parser("download")
        sub.add_argument("--cache-dir", type=str, default=None, help="Path to location to store the models")
        sub.add_argument(
            "--force", action="store_true", help="Force the model to be download even if already in cache-dir"
        )
        sub.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files. Use only if you've reviewed the code as it will execute on your local machine",
        )
        sub.add_argument("model", type=str, help="Name of the model to download")
        sub.set_defaults(func=download_command_factory)

    def __init__(self, model: str, cache: str, force: bool, trust_remote_code: bool):
        """
        Args:
            model: Model id or path handed to ``from_pretrained``.
            cache: Value for ``cache_dir`` (may be ``None``).
            force: Value for ``force_download``.
            trust_remote_code: Value for ``trust_remote_code``.
        """
        self._model = model
        self._cache = cache
        self._force = force
        self._trust_remote_code = trust_remote_code

    def run(self):
        """Call ``from_pretrained`` for the model, then for the tokenizer, with the same options."""
        from ..models.auto import AutoModel, AutoTokenizer

        shared_kwargs = {
            "cache_dir": self._cache,
            "force_download": self._force,
            "trust_remote_code": self._trust_remote_code,
        }
        for auto_cls in (AutoModel, AutoTokenizer):
            auto_cls.from_pretrained(self._model, **shared_kwargs)
.\commands\env.py
import importlib.util
import os
import platform
from argparse import ArgumentParser
import huggingface_hub
from .. import __version__ as version
from ..utils import (
is_accelerate_available,
is_flax_available,
is_safetensors_available,
is_tf_available,
is_torch_available,
)
from . import BaseTransformersCLICommand
def info_command_factory(_):
    """
    Build an EnvironmentCommand without an accelerate config file.

    The parsed arguments are ignored. ``accelerate_config_file`` is passed explicitly because
    ``EnvironmentCommand.__init__`` takes it as a positional parameter; calling
    ``EnvironmentCommand()`` with no arguments would raise ``TypeError``.
    """
    return EnvironmentCommand(accelerate_config_file=None)
def download_command_factory(args):
    """Build the ``env`` command, forwarding the optional ``--accelerate-config_file`` value."""
    config_file = args.accelerate_config_file
    return EnvironmentCommand(accelerate_config_file=config_file)
class EnvironmentCommand(BaseTransformersCLICommand):
    """`transformers-cli env` sub-command; accepts an optional accelerate config file path."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register the ``env`` sub-command.

        The flag keeps its historical spelling ``--accelerate-config_file``; argparse stores it as
        ``args.accelerate_config_file``.
        """
        env_parser = parser.add_parser("env")
        env_parser.add_argument(
            "--accelerate-config_file",
            default=None,
            help="The accelerate config file to use for the default values in the launching script.",
        )
        # A single default is enough. A former ``set_defaults(func=info_command_factory)`` was always
        # overwritten by this call.
        env_parser.set_defaults(func=download_command_factory)

    def __init__(self, accelerate_config_file=None, *args) -> None:
        """
        Args:
            accelerate_config_file: Optional config file path from ``--accelerate-config_file``. It defaults
                to ``None`` so the command can also be created without arguments.
        """
        self._accelerate_config_file = accelerate_config_file

    @staticmethod
    def format_dict(d):
        """Render ``d`` as one ``- key: value`` line per item, followed by a trailing newline."""
        return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
.\commands\lfs.py
"""
Implementation of a custom transfer agent for the transfer type "multipart" for git-lfs.
Inspired by: github.com/cbartz/git-lfs-swift-transfer-agent/blob/master/git_lfs_swift_transfer.py
Spec is: github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md
To launch debugger while developing:

```
[lfs "customtransfer.multipart"]
path = /path/to/transformers/.env/bin/python
args = -m debugpy --listen 5678 --wait-for-client /path/to/transformers/src/transformers/commands/transformers_cli.py lfs-multipart-upload
```
"""
import json
import os
import subprocess
import sys
import warnings
from argparse import ArgumentParser
from contextlib import AbstractContextManager
from typing import Dict, List, Optional
import requests
from ..utils import logging
from . import BaseTransformersCLICommand
# Module logger; read_msg reports unexpected messages through it.
logger = logging.get_logger(__name__)
# Sub-command name registered as the git-lfs custom transfer agent's arguments (see LfsEnableCommand).
LFS_MULTIPART_UPLOAD_COMMAND = "lfs-multipart-upload"
class LfsCommands(BaseTransformersCLICommand):
    """
    Implementation of a custom transfer agent for the transfer type "multipart" for git-lfs. This lets users upload
    large files >5GB 🔥. Spec for LFS custom transfer agent is:
    https://github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md

    This introduces two commands to the CLI:

    1. $ transformers-cli lfs-enable-largefiles

    Run this once for each model repo that contains a model file >5GB. The error message you get when you
    try to git push a 5GB file without having enabled it first points you here.

    2. $ transformers-cli lfs-multipart-upload

    This command is called by lfs directly and is not meant to be called by the user.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register ``lfs-enable-largefiles`` and the git-lfs-facing ``lfs-multipart-upload`` sub-commands."""
        largefiles_parser = parser.add_parser(
            "lfs-enable-largefiles",
            help="Deprecated: use `huggingface-cli` instead. Configure your repository to enable upload of files > 5GB.",
        )
        largefiles_parser.add_argument("path", type=str, help="Local path to repository you want to configure.")
        # The command classes are themselves callables taking the parsed Namespace.
        largefiles_parser.set_defaults(func=LfsEnableCommand)

        multipart_parser = parser.add_parser(
            LFS_MULTIPART_UPLOAD_COMMAND,
            help="Deprecated: use `huggingface-cli` instead. Command will get called by git-lfs, do not call it directly.",
        )
        multipart_parser.set_defaults(func=LfsUploadCommand)
class LfsEnableCommand:
    """Configure a local repository so git-lfs uses ``transformers-cli`` as its multipart transfer agent."""

    def __init__(self, args):
        # Parsed argparse namespace; only ``args.path`` is read.
        self.args = args

    def run(self):
        """
        Write the ``lfs.customtransfer.multipart.*`` git config entries in ``self.args.path``.

        Raises:
            SystemExit: With status 1 if the path is not a directory.
            subprocess.CalledProcessError: If a ``git config`` call fails.
        """
        warnings.warn(
            "Managing repositories through transformers-cli is deprecated. Please use `huggingface-cli` instead."
        )
        local_path = os.path.abspath(self.args.path)
        if not os.path.isdir(local_path):
            print("This does not look like a valid git repo.")
            # sys.exit rather than the site-provided ``exit`` builtin, which is not always available.
            sys.exit(1)
        # Explicit argv lists: no reliance on whitespace splitting of a command string.
        subprocess.run(
            ["git", "config", "lfs.customtransfer.multipart.path", "transformers-cli"],
            check=True,
            cwd=local_path,
        )
        subprocess.run(
            ["git", "config", "lfs.customtransfer.multipart.args", LFS_MULTIPART_UPLOAD_COMMAND],
            check=True,
            cwd=local_path,
        )
        print("Local repo set up for largefiles")
def write_msg(msg: Dict):
    """Write ``msg`` to stdout as one JSON line and flush right away, since the protocol is line based."""
    payload = f"{json.dumps(msg)}\n"
    out = sys.stdout
    out.write(payload)
    out.flush()
def read_msg() -> Optional[Dict]:
    """
    Read the next JSON message that git-lfs sends on stdin.

    Returns:
        The decoded message for ``upload``/``download`` events, or ``None`` when git-lfs asks us to
        terminate (``type`` or ``event`` equal to ``"terminate"``) or when stdin has reached EOF.

    Exits the process with status 1 (after a critical log) on any other event.
    """
    line = sys.stdin.readline()
    if not line:
        # EOF: no further requests can arrive. Treat it like "terminate" instead of crashing in json.loads("").
        return None
    msg = json.loads(line.strip())

    if "terminate" in (msg.get("type"), msg.get("event")):
        # terminate message received
        return None

    if msg.get("event") not in ("download", "upload"):
        logger.critical("Received unexpected message")
        sys.exit(1)

    return msg
class FileSlice(AbstractContextManager):
    """
    File-like object that only reads a slice of a file

    Inspired by stackoverflow.com/a/29838711/593036

    The slice starts at byte offset ``seek_from`` and covers at most ``read_limit`` bytes. The file is
    opened in ``__enter__`` and closed in ``__exit__``, so the object must be used inside a ``with`` block.
    """

    # Size of the blocks produced when iterating over the slice.
    _ITER_CHUNK_SIZE = 4 * 1024 * 1024

    def __init__(self, filepath: str, seek_from: int, read_limit: int):
        self.filepath = filepath
        self.seek_from = seek_from
        self.read_limit = read_limit
        self.n_seen = 0  # bytes already handed out by read()

    def __enter__(self):
        self.f = open(self.filepath, "rb")
        self.f.seek(self.seek_from)
        return self

    def __len__(self):
        # Bytes available in the slice. Clamp at 0 when the slice starts past the end of the file,
        # because len() must never return a negative value.
        total_length = os.fstat(self.f.fileno()).st_size
        return max(0, min(self.read_limit, total_length - self.seek_from))

    def read(self, n=-1):
        """Read up to ``n`` bytes (everything left in the slice if ``n < 0``); ``b""`` once exhausted."""
        if self.n_seen >= self.read_limit:
            return b""
        remaining_amount = self.read_limit - self.n_seen
        data = self.f.read(remaining_amount if n < 0 else min(n, remaining_amount))
        self.n_seen += len(data)
        return data

    def __iter__(self):
        # Yield successive blocks until the slice is exhausted.
        while True:
            chunk = self.read(n=self._ITER_CHUNK_SIZE)
            if not chunk:
                return
            yield chunk

    def __exit__(self, *args):
        self.f.close()
class LfsUploadCommand:
    """
    git-lfs custom transfer agent (``lfs-multipart-upload``): uploads each file in parts to the presigned
    URLs found in the request header, then POSTs the collected part ETags to the completion URL.
    """

    def __init__(self, args):
        self.args = args

    def run(self):
        # Immediately after invoking a custom transfer process, git-lfs sends initiation data to the
        # process over stdin. This tells the process useful information about the configuration.
        init_msg = json.loads(sys.stdin.readline().strip())
        if not (init_msg.get("event") == "init" and init_msg.get("operation") == "upload"):
            write_msg({"error": {"code": 32, "message": "Wrong lfs init operation"}})
            sys.exit(1)

        # The transfer process should use the information it needs from the initiation structure and
        # perform any one-off setup. It then answers on stdout with an empty confirmation structure.
        write_msg({})

        # After the initiation exchange, git-lfs sends any number of transfer requests on stdin,
        # one after another.
        while True:
            msg = read_msg()
            if msg is None:
                # git-lfs asked us to terminate (or closed stdin).
                sys.exit(0)

            oid = msg["oid"]
            filepath = msg["path"]
            completion_url = msg["action"]["href"]
            header = msg["action"]["header"]
            # The header carries the part size under "chunk_size"; every remaining value is used as a
            # presigned part URL, in the order the header lists them.
            chunk_size = int(header.pop("chunk_size"))
            presigned_urls: List[str] = list(header.values())

            parts = []
            bytes_so_far = 0
            for i, presigned_url in enumerate(presigned_urls):
                with FileSlice(filepath, seek_from=i * chunk_size, read_limit=chunk_size) as data:
                    # Real size of this part; the last part is usually shorter than chunk_size.
                    part_size = len(data)
                    r = requests.put(presigned_url, data=data)
                    r.raise_for_status()
                    parts.append(
                        {
                            "etag": r.headers.get("etag"),
                            "partNumber": i + 1,
                        }
                    )
                bytes_so_far += part_size
                # Report actual byte counts so the progress never exceeds the file size.
                write_msg(
                    {
                        "event": "progress",
                        "oid": oid,
                        "bytesSoFar": bytes_so_far,
                        "bytesSinceLast": part_size,
                    }
                )

            r = requests.post(
                completion_url,
                json={
                    "oid": oid,
                    "parts": parts,
                },
            )
            r.raise_for_status()

            write_msg({"event": "complete", "oid": oid})
.\commands\pt_to_tf.py
import inspect
import os
from argparse import ArgumentParser, Namespace
from importlib import import_module
import huggingface_hub
import numpy as np
from packaging import version
from .. import (
FEATURE_EXTRACTOR_MAPPING,
IMAGE_PROCESSOR_MAPPING,
PROCESSOR_MAPPING,
TOKENIZER_MAPPING,
AutoConfig,
AutoFeatureExtractor,
AutoImageProcessor,
AutoProcessor,
AutoTokenizer,
is_datasets_available,
is_tf_available,
is_torch_available,
)
from ..utils import TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
from . import BaseTransformersCLICommand
# Optional framework imports: each backend is imported only when it is installed.
if is_tf_available():
    import tensorflow as tf

    # Turn off TensorFloat-32 execution in TensorFlow (presumably so PT/TF output comparisons
    # against MAX_ERROR are not perturbed by reduced-precision math — TODO confirm).
    tf.config.experimental.enable_tensor_float_32_execution(False)
if is_torch_available():
    import torch
if is_datasets_available():
    from datasets import load_dataset

# Default maximum error tolerance; used as the default of the ``--max-error`` option.
MAX_ERROR = 5e-5
def convert_command_factory(args: Namespace):
    """
    Factory function used to convert a model PyTorch checkpoint in a TensorFlow 2 checkpoint.

    Args:
        args: Namespace produced by the ``pt-to-tf`` sub-parser.

    Returns: PTtoTFCommand
    """
    return PTtoTFCommand(
        args.model_name,
        args.local_dir,
        args.max_error,
        args.new_weights,
        args.no_pr,
        args.push,
        args.extra_commit_description,
        args.override_model_class,
    )
class PTtoTFCommand(BaseTransformersCLICommand):
    """`transformers-cli pt-to-tf`: convert a Hub model's PyTorch weights to TensorFlow and validate them."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformer-cli

        Args:
            parser: Root parser to register command-specific arguments
        """
        train_parser = parser.add_parser(
            "pt-to-tf",
            help=(
                "CLI tool to run convert a transformers model from a PyTorch checkpoint to a TensorFlow checkpoint."
                " Can also be used to validate existing weights without opening PRs, with --no-pr."
            ),
        )
        train_parser.add_argument(
            "--model-name",
            type=str,
            required=True,
            help="The model name, including owner/organization, as seen on the hub.",
        )
        train_parser.add_argument(
            "--local-dir",
            type=str,
            default="",
            help="Optional local directory of the model repository. Defaults to /tmp/{model_name}",
        )
        train_parser.add_argument(
            "--max-error",
            type=float,
            default=MAX_ERROR,
            help=(
                f"Maximum error tolerance. Defaults to {MAX_ERROR}. This flag should be avoided, use at your own risk."
            ),
        )
        train_parser.add_argument(
            "--new-weights",
            action="store_true",
            help="Optional flag to create new TensorFlow weights, even if they already exist.",
        )
        train_parser.add_argument(
            "--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights."
        )
        train_parser.add_argument(
            "--push",
            action="store_true",
            help="Optional flag to push the weights directly to `main` (requires permissions)",
        )
        train_parser.add_argument(
            "--extra-commit-description",
            type=str,
            default="",
            help="Optional additional commit description to use when opening a PR (e.g. to tag the owner).",
        )
        train_parser.add_argument(
            "--override-model-class",
            type=str,
            default=None,
            help="If you think you know better than the auto-detector, you can specify the model class here. "
            "Can be either an AutoModel class or a specific model class like BertForSequenceClassification.",
        )
        train_parser.set_defaults(func=convert_command_factory)

    @staticmethod
    def find_pt_tf_differences(pt_outputs, tf_outputs):
        """
        Compares the TensorFlow and PyTorch outputs, returning a dictionary with all tensor differences.

        Declared as a ``staticmethod`` because it takes no ``self``. Without the decorator, a call through
        an instance would bind the instance to ``pt_outputs``.

        Args:
            pt_outputs: PyTorch model output exposing ``keys()``; values may be tensors or nested sequences.
            tf_outputs: TensorFlow model output with the same keys.

        Returns:
            Dict mapping each (nested) attribute name, e.g. ``"hidden_states[0]"``, to the maximum absolute
            elementwise difference between the two outputs.

        Raises:
            ValueError: If the two outputs do not expose the same set of keys.
        """
        # 1. All output attributes must be the same
        pt_out_attrs = set(pt_outputs.keys())
        tf_out_attrs = set(tf_outputs.keys())
        if pt_out_attrs != tf_out_attrs:
            raise ValueError(
                f"The model outputs have different attributes, aborting. (Pytorch: {pt_out_attrs}, TensorFlow:"
                f" {tf_out_attrs})"
            )

        # 2. For each output attribute, compute the difference
        def _find_pt_tf_differences(pt_out, tf_out, differences, attr_name=""):
            # Leaf: a torch tensor is compared with its TF counterpart through numpy.
            if isinstance(pt_out, torch.Tensor):
                tensor_difference = np.max(np.abs(pt_out.numpy() - tf_out.numpy()))
                differences[attr_name] = tensor_difference
            else:
                root_name = attr_name
                for i, pt_item in enumerate(pt_out):
                    # Iterating a mapping-like output yields its string keys: recurse by key.
                    if isinstance(pt_item, str):
                        branch_name = root_name + pt_item
                        tf_item = tf_out[pt_item]
                        pt_item = pt_out[pt_item]
                    # Otherwise recurse by position.
                    else:
                        branch_name = root_name + f"[{i}]"
                        tf_item = tf_out[i]
                    differences = _find_pt_tf_differences(pt_item, tf_item, differences, branch_name)

            return differences

        return _find_pt_tf_differences(pt_outputs, tf_outputs, {})

    def __init__(
        self,
        model_name: str,
        local_dir: str,
        max_error: float,
        new_weights: bool,
        no_pr: bool,
        push: bool,
        extra_commit_description: str,
        override_model_class: str,
        *args,
    ):
        """Store the CLI options; an empty ``local_dir`` defaults to ``/tmp/<model_name>``."""
        self._logger = logging.get_logger("transformers-cli/pt_to_tf")
        self._model_name = model_name
        self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
        self._max_error = max_error
        self._new_weights = new_weights
        self._no_pr = no_pr
        self._push = push
        self._extra_commit_description = extra_commit_description
        self._override_model_class = override_model_class
.\commands\run.py
from argparse import ArgumentParser
from ..pipelines import Pipeline, PipelineDataFormat, get_supported_tasks, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand
# Module logger; RunCommand.run uses it to warn where binary pipeline output is saved.
logger = logging.get_logger(__name__)
def try_infer_format_from_ext(path: str, supported_formats=None):
    """
    Guess the pipeline data format from the extension of ``path``.

    Args:
        path: Input file path. An empty value (``""`` or ``None``) selects the ``"pipe"`` format.
        supported_formats: Candidate format names, tried in order as filename suffixes. Defaults to
            ``PipelineDataFormat.SUPPORTED_FORMATS``.

    Returns:
        The first format name that ``path`` ends with.

    Raises:
        ValueError: If no candidate matches. It subclasses ``Exception``, so existing
            ``except Exception`` handlers keep working.
    """
    if not path:
        return "pipe"

    if supported_formats is None:
        supported_formats = PipelineDataFormat.SUPPORTED_FORMATS

    for ext in supported_formats:
        if path.endswith(ext):
            return ext

    raise ValueError(
        f"Unable to determine file format from file extension {path}. "
        f"Please provide the format through --format {supported_formats}"
    )
def run_command_factory(args):
    """
    Build a RunCommand from the parsed ``run`` arguments.

    The pipeline is created from ``task``/``model``/``config``/``tokenizer``/``device``. The reader/writer
    is created from ``input``/``output``/``column``/``overwrite``. When ``format`` is ``"infer"``, the
    format is guessed from the input file extension.

    Returns: RunCommand
    """
    nlp = pipeline(
        task=args.task,
        model=args.model if args.model else None,
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    )
    # Named ``data_format`` so the builtin ``format`` is not shadowed.
    data_format = try_infer_format_from_ext(args.input) if args.format == "infer" else args.format
    reader = PipelineDataFormat.from_str(
        format=data_format,
        output_path=args.output,
        input_path=args.input,
        column=args.column if args.column else nlp.default_input_names,
        overwrite=args.overwrite,
    )
    return RunCommand(nlp, reader)
class RunCommand(BaseTransformersCLICommand):
    """`transformers-cli run`: run a pipeline over every entry of an input source and save the outputs."""

    def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
        self._nlp = nlp
        self._reader = reader

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register the ``run`` sub-command. Every option below is read by ``run_command_factory``.

        Args:
            parser: Root parser to register command-specific arguments
        """
        run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
        run_parser.add_argument("--task", choices=get_supported_tasks(), help="Task to run")
        run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
        run_parser.add_argument("--output", type=str, help="Path to the file where results will be written.")
        run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")
        run_parser.add_argument("--config", type=str, help="Name or path to the model's config to instantiate.")
        run_parser.add_argument(
            "--tokenizer", type=str, help="Name of the tokenizer to use. (default: same as the model name)"
        )
        run_parser.add_argument(
            "--column",
            type=str,
            help="Name of the column to use as input. (For multi columns input as QA use column1,columns2)",
        )
        run_parser.add_argument(
            "--format",
            type=str,
            default="infer",
            # "infer" is the sentinel run_command_factory checks for; it is listed so it can also be
            # passed explicitly.
            choices=["infer", *PipelineDataFormat.SUPPORTED_FORMATS],
            help="Input format to read from",
        )
        run_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        run_parser.add_argument("--overwrite", action="store_true", help="Allow overwriting the output file.")
        run_parser.set_defaults(func=run_command_factory)

    def run(self):
        """
        Feed each reader entry to the pipeline and save the collected outputs.

        This is a regular instance method: ``main()`` calls ``service.run()`` on an instance.
        """
        nlp, outputs = self._nlp, []

        for entry in self._reader:
            # Multi-column entries are passed as keyword arguments.
            output = nlp(**entry) if self._reader.is_multi_columns else nlp(entry)
            # A dict is a single result; anything else is treated as a sequence of results.
            if isinstance(output, dict):
                outputs.append(output)
            else:
                outputs += output

        # Saving data
        if self._nlp.binary_output:
            binary_path = self._reader.save_binary(outputs)
            logger.warning(f"Current pipeline requires output to be in binary format, saving at {binary_path}")
        else:
            self._reader.save(outputs)
.\commands\serving.py
from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional
from ..pipelines import Pipeline, get_supported_tasks, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand
# Optional serving dependencies. If any import fails, placeholders are defined so this module can still
# be imported; ServeCommand.__init__ checks ``_serve_dependencies_installed`` and raises a RuntimeError
# when it is False.
try:
    from fastapi import Body, FastAPI, HTTPException
    from fastapi.routing import APIRoute
    from pydantic import BaseModel
    from starlette.responses import JSONResponse
    from uvicorn import run

    _serve_dependencies_installed = True
except (ImportError, AttributeError):
    # Fallback: a plain ``object`` base so the result model classes can still be defined, and a no-op
    # ``Body`` so the endpoint parameter defaults can still be evaluated.
    BaseModel = object

    def Body(*x, **y):
        pass

    _serve_dependencies_installed = False


logger = logging.get_logger("transformers-cli/serving")
def serve_command_factory(args: Namespace):
    """
    Factory function used to instantiate a serving server from the provided command line arguments.

    Returns: ServeCommand instance
    """
    pipeline_kwargs = {
        "task": args.task,
        # An empty/missing --model is passed as None.
        "model": args.model or None,
        "config": args.config,
        "tokenizer": args.tokenizer,
        "device": args.device,
    }
    nlp = pipeline(**pipeline_kwargs)
    return ServeCommand(nlp, host=args.host, port=args.port, workers=args.workers)
class ServeModelInfoResult(BaseModel):
    """
    Response model exposing the model information (the model config as a dict).
    """

    infos: dict
class ServeTokenizeResult(BaseModel):
    """
    Tokenization result model.

    ``tokens_ids`` is only filled when ids were requested. It defaults to ``None`` so the model can be
    built from ``tokens`` alone; under pydantic v2, an ``Optional`` field without a default is required.
    """

    tokens: List[str]
    tokens_ids: Optional[List[int]] = None
class ServeDeTokenizeResult(BaseModel):
    """
    Detokenization result model: the decoded text.
    """

    text: str
class ServeForwardResult(BaseModel):
    """
    Forward-pass result model: whatever the pipeline returned.
    """

    output: Any
class ServeCommand(BaseTransformersCLICommand):
    """`transformers-cli serve`: expose a pipeline through a small FastAPI REST application."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformer-cli

        Args:
            parser: Root parser to register command-specific arguments
        """
        serve_parser = parser.add_parser(
            "serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
        )
        serve_parser.add_argument(
            "--task",
            type=str,
            choices=get_supported_tasks(),
            help="The task to run the pipeline on",
        )
        serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
        serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
        serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
        serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
        serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
        serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
        serve_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        serve_parser.set_defaults(func=serve_command_factory)

    def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
        """
        Args:
            pipeline: Pipeline whose tokenizer and model back the endpoints.
            host: Interface passed to uvicorn.
            port: Port passed to uvicorn.
            workers: Number of uvicorn workers.

        Raises:
            RuntimeError: If the optional FastAPI/uvicorn dependencies could not be imported.
        """
        self._pipeline = pipeline

        self.host = host
        self.port = port
        self.workers = workers

        if not _serve_dependencies_installed:
            raise RuntimeError(
                "Using serve command requires FastAPI and uvicorn. "
                'Please install transformers with [serving]: pip install "transformers[serving]". '
                "Or install FastAPI and uvicorn separately."
            )

        logger.info(f"Serving model over {host}:{port}")
        self._app = FastAPI(
            routes=[
                APIRoute(
                    "/",
                    self.model_info,
                    response_model=ServeModelInfoResult,
                    response_class=JSONResponse,
                    methods=["GET"],
                ),
                APIRoute(
                    "/tokenize",
                    self.tokenize,
                    response_model=ServeTokenizeResult,
                    response_class=JSONResponse,
                    methods=["POST"],
                ),
                APIRoute(
                    "/detokenize",
                    self.detokenize,
                    response_model=ServeDeTokenizeResult,
                    response_class=JSONResponse,
                    methods=["POST"],
                ),
                APIRoute(
                    "/forward",
                    self.forward,
                    response_model=ServeForwardResult,
                    response_class=JSONResponse,
                    methods=["POST"],
                ),
            ],
            timeout=600,
        )

    def run(self):
        """Start the uvicorn server for the FastAPI app."""
        run(self._app, host=self.host, port=self.port, workers=self.workers)

    def model_info(self):
        """GET / : return the attributes of the pipeline's model config."""
        return ServeModelInfoResult(infos=vars(self._pipeline.model.config))

    def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
        """
        Tokenize the provided input and optionally return the corresponding token ids.

        - **text_input**: String to tokenize
        - **return_ids**: Boolean flag indicating whether the tokens have to be converted to their integer ids.
        """
        try:
            tokens_txt = self._pipeline.tokenizer.tokenize(text_input)
            tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt) if return_ids else None
            # tokens_ids is always passed, so this also works with result models where the field has no default.
            return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)}) from e

    def detokenize(
        self,
        tokens_ids: List[int] = Body(None, embed=True),
        skip_special_tokens: bool = Body(False, embed=True),
        cleanup_tokenization_spaces: bool = Body(True, embed=True),
    ):
        """
        Detokenize the provided tokens ids to readable text:

        - **tokens_ids**: List of tokens ids
        - **skip_special_tokens**: Flag indicating to not try to decode special tokens
        - **cleanup_tokenization_spaces**: Flag indicating to remove all leading/trailing spaces and intermediate ones.
        """
        try:
            decoded_str = self._pipeline.tokenizer.decode(
                tokens_ids,
                skip_special_tokens=skip_special_tokens,
                clean_up_tokenization_spaces=cleanup_tokenization_spaces,
            )
            return ServeDeTokenizeResult(text=decoded_str)
        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)}) from e

    async def forward(self, inputs=Body(None, embed=True)):
        """
        Run the pipeline on ``inputs`` and return its output.

        A missing (``None``) or empty input returns an empty output instead of failing.
        """
        if not inputs:
            return ServeForwardResult(output=[])

        try:
            # Forward through the model
            output = self._pipeline(inputs)
            return ServeForwardResult(output=output)
        except Exception as e:
            raise HTTPException(500, {"error": str(e)}) from e
.\commands\train.py
import os
from argparse import ArgumentParser, Namespace
from ..data import SingleSentenceClassificationProcessor as Processor
from ..pipelines import TextClassificationPipeline
from ..utils import is_tf_available, is_torch_available, logging
from . import BaseTransformersCLICommand
# Fail at import time: CLI training needs at least one of PyTorch or TensorFlow.
if not is_tf_available() and not is_torch_available():
    raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")

# Module-level switches, presumably for XLA compilation and automatic mixed precision. They are not
# referenced in the code shown here — TODO confirm whether they are still used.
USE_XLA = False
USE_AMP = False
def train_command_factory(args: Namespace):
    """
    Factory function that instantiates the training command from the given command line arguments.

    Returns:
        TrainCommand: the instantiated training command
    """
    command = TrainCommand(args)
    return command
class TrainCommand(BaseTransformersCLICommand):
    """`transformers-cli train`: fine-tune a pipeline on a CSV dataset (only the TF path is implemented)."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformer-cli

        Args:
            parser: Root parser to register command-specific arguments
        """
        train_parser = parser.add_parser("train", help="CLI tool to train a model on a task.")

        train_parser.add_argument(
            "--train_data",
            type=str,
            required=True,
            help="path to train (and optionally evaluation) dataset as a csv with tab separated labels and sentences.",
        )
        train_parser.add_argument(
            "--column_label", type=int, default=0, help="Column of the dataset csv file with example labels."
        )
        train_parser.add_argument(
            "--column_text", type=int, default=1, help="Column of the dataset csv file with example texts."
        )
        train_parser.add_argument(
            "--column_id", type=int, default=2, help="Column of the dataset csv file with example ids."
        )
        train_parser.add_argument(
            "--skip_first_row", action="store_true", help="Skip the first row of the csv file (headers)."
        )

        train_parser.add_argument("--validation_data", type=str, default="", help="path to validation dataset.")
        train_parser.add_argument(
            "--validation_split",
            type=float,
            default=0.1,
            help="if validation dataset is not provided, fraction of train dataset to use as validation dataset.",
        )

        train_parser.add_argument("--output", type=str, default="./", help="path to saved the trained model.")

        train_parser.add_argument(
            "--task", type=str, default="text_classification", help="Task to train the model on."
        )
        train_parser.add_argument(
            "--model", type=str, default="google-bert/bert-base-uncased", help="Model's name or path to stored model."
        )
        train_parser.add_argument("--train_batch_size", type=int, default=32, help="Batch size for training.")
        train_parser.add_argument("--valid_batch_size", type=int, default=64, help="Batch size for validation.")
        train_parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.")
        train_parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon for Adam optimizer.")
        train_parser.set_defaults(func=train_command_factory)

    def __init__(self, args: Namespace):
        """
        Load the pipeline and the datasets described by ``args``.

        The output directory is created if needed. TensorFlow is preferred when it is available.

        Raises:
            NotImplementedError: For ``token_classification`` and ``question_answering``.
            ValueError: For any other unknown task name. Failing here is better than leaving
                ``self.pipeline`` unset, which would only surface later as an AttributeError.
        """
        self.logger = logging.get_logger("transformers-cli/training")

        self.framework = "tf" if is_tf_available() else "torch"

        os.makedirs(args.output, exist_ok=True)
        self.output = args.output

        self.column_label = args.column_label
        self.column_text = args.column_text
        self.column_id = args.column_id

        self.logger.info(f"Loading {args.task} pipeline for {args.model}")
        if args.task == "text_classification":
            self.pipeline = TextClassificationPipeline.from_pretrained(args.model)
        elif args.task == "token_classification":
            raise NotImplementedError
        elif args.task == "question_answering":
            raise NotImplementedError
        else:
            raise ValueError(
                f"Unsupported task {args.task!r}; expected one of 'text_classification', "
                "'token_classification' or 'question_answering'."
            )

        self.logger.info(f"Loading dataset from {args.train_data}")
        self.train_dataset = Processor.create_from_csv(
            args.train_data,
            column_label=args.column_label,
            column_text=args.column_text,
            column_id=args.column_id,
            skip_first_row=args.skip_first_row,
        )
        self.valid_dataset = None
        if args.validation_data:
            self.logger.info(f"Loading validation dataset from {args.validation_data}")
            self.valid_dataset = Processor.create_from_csv(
                args.validation_data,
                column_label=args.column_label,
                column_text=args.column_text,
                column_id=args.column_id,
                skip_first_row=args.skip_first_row,
            )

        self.validation_split = args.validation_split
        self.train_batch_size = args.train_batch_size
        self.valid_batch_size = args.valid_batch_size
        self.learning_rate = args.learning_rate
        self.adam_epsilon = args.adam_epsilon

    def run(self):
        """Dispatch to the framework-specific training routine."""
        if self.framework == "tf":
            return self.run_tf()
        return self.run_torch()

    def run_torch(self):
        raise NotImplementedError

    def run_tf(self):
        """Fit the pipeline with the configured datasets and hyper-parameters, then save it to ``self.output``."""
        self.pipeline.fit(
            self.train_dataset,
            validation_data=self.valid_dataset,
            validation_split=self.validation_split,
            learning_rate=self.learning_rate,
            adam_epsilon=self.adam_epsilon,
            train_batch_size=self.train_batch_size,
            valid_batch_size=self.valid_batch_size,
        )

        # Save trained pipeline
        self.pipeline.save_pretrained(self.output)
.\commands\transformers_cli.py
from argparse import ArgumentParser
from .add_new_model import AddNewModelCommand
from .add_new_model_like import AddNewModelLikeCommand
from .convert import ConvertCommand
from .download import DownloadCommand
from .env import EnvironmentCommand
from .lfs import LfsCommands
from .pt_to_tf import PTtoTFCommand
from .run import RunCommand
from .serving import ServeCommand
from .user import UserCommands
def main():
    """Entry point of ``transformers-cli``: register every sub-command, parse argv and run the chosen one."""
    # ArgumentParser's first positional parameter is ``prog`` (used in error messages), so the
    # human-readable title goes in ``description``.
    parser = ArgumentParser(
        prog="transformers-cli",
        description="Transformers CLI tool",
        usage="transformers-cli <command> [<args>]",
    )
    commands_parser = parser.add_subparsers(help="transformers-cli command helpers")

    # Register commands (order determines the order in --help)
    for command in (
        ConvertCommand,
        DownloadCommand,
        EnvironmentCommand,
        RunCommand,
        ServeCommand,
        UserCommands,
        AddNewModelCommand,
        AddNewModelLikeCommand,
        LfsCommands,
        PTtoTFCommand,
    ):
        command.register_subcommand(commands_parser)

    # Let's go
    args = parser.parse_args()

    if not hasattr(args, "func"):
        # No sub-command was given, so no ``func`` default was set.
        parser.print_help()
        parser.exit(1)

    # Run
    service = args.func(args)
    service.run()
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
.\commands\user.py
import subprocess
from argparse import ArgumentParser
from typing import List, Union
from huggingface_hub.hf_api import HfFolder, create_repo, whoami
from requests.exceptions import HTTPError
from . import BaseTransformersCLICommand
class UserCommands(BaseTransformersCLICommand):
    """Register the (mostly deprecated) account and repository sub-commands of transformers-cli."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        # Account commands take no options; each maps straight to its command class, which is
        # called with the parsed Namespace.
        account_commands = (
            ("login", "Log in using the same credentials as on huggingface.co", LoginCommand),
            ("whoami", "Find out which huggingface.co account you are logged in as.", WhoamiCommand),
            ("logout", "Log out", LogoutCommand),
        )
        for command_name, command_help, command_cls in account_commands:
            parser.add_parser(command_name, help=command_help).set_defaults(func=command_cls)

        # new system: git-based repo system
        repo_parser = parser.add_parser(
            "repo",
            help="Deprecated: use `huggingface-cli` instead. Commands to interact with your huggingface.co repos.",
        )
        repo_subparsers = repo_parser.add_subparsers(
            help="Deprecated: use `huggingface-cli` instead. huggingface.co repos related commands"
        )
        create_parser = repo_subparsers.add_parser(
            "create", help="Deprecated: use `huggingface-cli` instead. Create a new repo on huggingface.co"
        )
        create_parser.add_argument(
            "name",
            type=str,
            help="Name for your model's repo. Will be namespaced under your username to build the model id.",
        )
        create_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
        create_parser.add_argument("-y", "--yes", action="store_true", help="Optional: answer Yes to the prompt")
        create_parser.set_defaults(func=RepoCreateCommand)
class ANSI:
    """
    Helper for en.wikipedia.org/wiki/ANSI_escape_code

    Each helper wraps ``s`` in the matching escape sequence and appends a reset.
    """

    _bold = "\u001b[1m"
    _red = "\u001b[31m"
    _gray = "\u001b[90m"
    _reset = "\u001b[0m"

    @classmethod
    def bold(cls, s):
        return f"{cls._bold}{s}{cls._reset}"

    @classmethod
    def red(cls, s):
        # Red text is rendered bold as well.
        return f"{cls._bold}{cls._red}{s}{cls._reset}"

    @classmethod
    def gray(cls, s):
        # Needs @classmethod like its siblings: callers use ``ANSI.gray(text)``.
        return f"{cls._gray}{s}{cls._reset}"
def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
    """
    Render ``rows`` under ``headers`` as a plain-text table.

    Each column is as wide as its longest cell (header included, measured with ``str``). Every cell is
    followed by one space, and a dashed separator line sits between the header row and the data rows.

    Inspired by:

    - stackoverflow.com/a/8356620/593036
    - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
    """
    widths = [max(map(len, map(str, column))) for column in zip(*rows, headers)]
    template = ("{{:{}}} " * len(headers)).format(*widths)
    separator = ["-" * width for width in widths]
    rendered = [template.format(*headers), template.format(*separator)]
    rendered.extend(template.format(*row) for row in rows)
    return "\n".join(rendered)
class BaseUserCommand:
    """Common base for the user/repo commands; keeps the parsed CLI namespace on ``self.args``."""

    def __init__(self, args):
        self.args = args
class LoginCommand(BaseUserCommand):
    """Deprecated `transformers-cli login`: only prints an error pointing to `huggingface-cli login`."""

    def run(self):
        # The message names this CLI (transformers-cli) as the outdated one and closes the backtick
        # around the replacement command.
        print(
            ANSI.red(
                "ERROR! `transformers-cli login` uses an outdated login mechanism "
                "that is not compatible with the Hugging Face Hub backend anymore. "
                "Please use `huggingface-cli login` instead."
            )
        )
class WhoamiCommand(BaseUserCommand):
    """Deprecated `transformers-cli whoami`: print the logged-in account name and its organizations."""

    def run(self):
        print(
            ANSI.red(
                "WARNING! `transformers-cli whoami` is deprecated and will be removed in v5. Please use "
                "`huggingface-cli whoami` instead."
            )
        )
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            # Same status as the previous bare ``exit()`` (0), without relying on the site-provided builtin.
            raise SystemExit()
        try:
            account = whoami(token)
        except HTTPError as e:
            print(e)
            print(ANSI.red(e.response.text))
            raise SystemExit(1)

        if isinstance(account, dict):
            # Current huggingface_hub returns the account payload as a dict; its organizations are
            # dicts carrying a "name".
            user = account.get("name")
            orgs = [org.get("name", "") if isinstance(org, dict) else org for org in account.get("orgs") or []]
        else:
            # Legacy huggingface_hub returned a ``(username, organizations)`` tuple.
            user, orgs = account

        print(user)
        if orgs:
            print(ANSI.bold("orgs: "), ",".join(orgs))
class LogoutCommand(BaseUserCommand):
    """Deprecated `transformers-cli logout`: only prints an error pointing to `huggingface-cli logout`."""

    def run(self):
        print(
            ANSI.red(
                "ERROR! `transformers-cli logout` uses an outdated logout mechanism "
                "that is not compatible with the Hugging Face Hub backend anymore. "
                "Please use `huggingface-cli logout` instead."
            )
        )
class RepoCreateCommand(BaseUserCommand):
    """Deprecated `transformers-cli repo create`: create a model repository on huggingface.co."""

    def run(self):
        print(
            ANSI.red(
                "WARNING! Managing repositories through transformers-cli is deprecated. "
                "Please use `huggingface-cli` instead."
            )
        )
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            raise SystemExit(1)

        # Report git / git-lfs availability; a missing binary is reported but does not stop the command.
        try:
            stdout = subprocess.check_output(["git", "--version"]).decode("utf-8")
            print(ANSI.gray(stdout.strip()))
        except FileNotFoundError:
            print("Looks like you do not have git installed, please install.")

        try:
            stdout = subprocess.check_output(["git-lfs", "--version"]).decode("utf-8")
            print(ANSI.gray(stdout.strip()))
        except FileNotFoundError:
            print(
                ANSI.red(
                    "Looks like you do not have git-lfs installed, please install."
                    " You can install from https://git-lfs.github.com/."
                    " Then run `git lfs install` (you only have to do this once)."
                )
            )
        print("")

        account = whoami(token)
        # Current huggingface_hub returns a dict payload; legacy versions returned (username, orgs).
        user = account.get("name") if isinstance(account, dict) else account[0]
        namespace = self.args.organization if self.args.organization is not None else user
        full_name = f"{namespace}/{self.args.name}"
        print(f"You are about to create {ANSI.bold(full_name)}")

        if not self.args.yes:
            choice = input("Proceed? [Y/n] ").lower()
            if choice not in ("", "y", "yes"):
                print("Abort")
                raise SystemExit()
        try:
            # Current create_repo takes the full ``namespace/name`` repo id; the token is keyword-only.
            url = create_repo(full_name, token=token)
        except HTTPError as e:
            print(e)
            print(ANSI.red(e.response.text))
            raise SystemExit(1)
        print("\nYour repo now lives at:")
        print(f"  {ANSI.bold(url)}")
        print("\nYou can clone it locally with the command below, and commit/push as usual.")
        print(f"\n  git clone {url}")
        print("")
.\commands\__init__.py
from abc import ABC, abstractmethod
from argparse import ArgumentParser
class BaseTransformersCLICommand(ABC):
    """
    Abstract interface that each `transformers-cli` sub-command implements.
    """

    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        """Add this command's sub-parser and its arguments to `parser`."""
        raise NotImplementedError()

    @abstractmethod
    def run(self):
        """Execute the command."""
        raise NotImplementedError()
.\configuration_utils.py
""" Configuration base class and utilities."""
import copy
import json
import os
import re
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
from packaging import version
from . import __version__
from .dynamic_module_utils import custom_object_save
from .utils import (
CONFIG_NAME,
PushToHubMixin,
add_model_info_to_auto_map,
cached_file,
copy_func,
download_url,
extract_commit_hash,
is_remote_url,
is_torch_available,
logging,
)
# Module-level logger used by the configuration loading/saving code below.
logger = logging.get_logger(__name__)
# Matches versioned configuration file names like `config.<something>.json`; group 1 captures `<something>`
# (used by `get_configuration_file` to pick a version-specific config file).
_re_configuration_file = re.compile(r"config\.(.*)\.json")
class PretrainedConfig(PushToHubMixin):
    r"""
    Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
    methods for loading/downloading/saving configurations.

    <Tip>

    A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to
    initialize a model does **not** load the model weights. It only affects the model's configuration.

    </Tip>

    Class attributes (overridden by derived classes):

    - **model_type** (`str`) -- An identifier for the model type, serialized into the JSON file, and used to recreate
      the correct object in [`~transformers.AutoConfig`].
    - **is_composition** (`bool`) -- Whether the config class is composed of multiple sub-configs. In this case the
      config has to be initialized from two or more configs of type [`~transformers.PretrainedConfig`] like:
      [`~transformers.EncoderDecoderConfig`] or [`~RagConfig`].
    - **keys_to_ignore_at_inference** (`List[str]`) -- A list of keys to ignore by default when looking at dictionary
      outputs of the model during inference.
    - **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
      naming of attributes.

    Common attributes (present in all subclasses):

    - **vocab_size** (`int`) -- The number of tokens in the vocabulary, which is also the first dimension of the
      embeddings matrix (this attribute may be missing for models that don't have a text modality like ViT).
    - **hidden_size** (`int`) -- The hidden size of the model.
    """

    model_type: str = ""
    is_composition: bool = False
    attribute_map: Dict[str, str] = {}
    _auto_class: Optional[str] = None

    # NOTE(review): this excerpt contains no `__init__`. Members below read attributes such as `return_dict`,
    # `torchscript`, `id2label` and `pruned_heads`, and `from_dict` calls `cls(**config_dict)`. These presumably
    # rely on an `__init__` defined upstream; confirm before relying on this class standalone.

    def __setattr__(self, key, value):
        # Writes to an aliased name (see `attribute_map`) are redirected to the canonical attribute.
        if key in super().__getattribute__("attribute_map"):
            key = super().__getattribute__("attribute_map")[key]
        super().__setattr__(key, value)

    def __getattribute__(self, key):
        # Reads of an aliased name are redirected the same way; `attribute_map` itself is never redirected.
        if key != "attribute_map" and key in super().__getattribute__("attribute_map"):
            key = super().__getattribute__("attribute_map")[key]
        return super().__getattribute__(key)

    @property
    def name_or_path(self) -> Optional[str]:
        """The name or path this configuration was loaded from, or None if it was never set."""
        return getattr(self, "_name_or_path", None)

    @name_or_path.setter
    def name_or_path(self, value):
        # Stored as a string so that `os.PathLike` values serialize cleanly.
        self._name_or_path = str(value)

    @property
    def use_return_dict(self) -> bool:
        """
        `bool`: Whether to return a [`~utils.ModelOutput`] instead of a tuple.
        """
        # TorchScript mode always disables dict outputs.
        return self.return_dict and not self.torchscript

    @property
    def num_labels(self) -> int:
        """
        `int`: The number of labels for classification models.
        """
        return len(self.id2label)

    @num_labels.setter
    def num_labels(self, num_labels: int):
        # Regenerate generic `LABEL_i` mappings only when the current mapping is missing or has the wrong size.
        if not hasattr(self, "id2label") or self.id2label is None or len(self.id2label) != num_labels:
            self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
            self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))

    @property
    def _attn_implementation(self):
        """
        `str`: The attention implementation to use; "eager" when none has been recorded.
        """
        # The fallback covers both "attribute never set" and "attribute set to None".
        implementation = getattr(self, "_attn_implementation_internal", None)
        return "eager" if implementation is None else implementation

    @_attn_implementation.setter
    def _attn_implementation(self, value):
        self._attn_implementation_internal = value

    @staticmethod
    def _set_token_in_kwargs(kwargs, token=None):
        """
        Normalize the authentication token into `kwargs["token"]`.

        Shared by configuration classes that override `from_pretrained`, so the deprecated `use_auth_token`
        handling lives in one place (a later PR should remove `use_auth_token` entirely).

        This must be a `staticmethod`. As a `classmethod`, the class object would be bound to `kwargs`, and
        `kwargs["token"] = ...` would fail on the class.

        Args:
            kwargs (`dict`): Keyword arguments, updated in place.
            token (`str` or `bool`, *optional*): Explicit token; when None, `kwargs["token"]` is used instead.

        Raises:
            ValueError: If both a token and the deprecated `use_auth_token` are given.
        """
        if token is None:
            token = kwargs.pop("token", None)
        use_auth_token = kwargs.pop("use_auth_token", None)

        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if token is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            token = use_auth_token

        if token is not None:
            kwargs["token"] = token

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        **kwargs,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Resolve `pretrained_model_name_or_path` into a dictionary of parameters that can be used to instantiate a
        [`PretrainedConfig`] through `from_dict`.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Identifier of the pretrained checkpoint whose parameter dictionary should be obtained.
            cache_dir, force_download, local_files_only, revision:
                Download/caching options forwarded to `_get_config_dict`.
            token (`str` or `bool`, *optional*):
                Hub token. It must not be combined with the deprecated `use_auth_token` keyword.
            kwargs:
                Further options. They are forwarded, and the ones not consumed are returned.

        Returns:
            `Tuple[Dict, Dict]`: The configuration dictionary and the keyword arguments that were not consumed.
        """
        # These explicit options used to be silently dropped; forward them so they actually take effect.
        kwargs["cache_dir"] = cache_dir
        kwargs["force_download"] = force_download
        kwargs["local_files_only"] = local_files_only
        kwargs["revision"] = revision
        cls._set_token_in_kwargs(kwargs, token)

        # Kept so that a second lookup (below) starts from the caller's options, not the consumed ones.
        original_kwargs = copy.deepcopy(kwargs)
        config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
        if "_commit_hash" in config_dict:
            original_kwargs["_commit_hash"] = config_dict["_commit_hash"]

        # The base config may list version-specific config files; reload using the one matching this version.
        if "configuration_files" in config_dict:
            configuration_file = get_configuration_file(config_dict["configuration_files"])
            config_dict, kwargs = cls._get_config_dict(
                pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs
            )

        return config_dict, kwargs

    @classmethod
    def _get_config_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Locate and parse the configuration JSON for `pretrained_model_name_or_path`.

        The argument may be a local file, a local directory or Hub identifier (resolved with `cached_file`), or a
        remote URL (fetched with `download_url`).

        NOTE(review): the body of this method was missing from the excerpt. This implementation is rebuilt from the
        helpers this module imports for it (`cached_file`, `download_url`, `is_remote_url`,
        `extract_commit_hash`, `add_model_info_to_auto_map`). Confirm it against upstream.

        Returns:
            `Tuple[Dict, Dict]`: The parsed configuration and the keyword arguments that were not consumed.

        Raises:
            EnvironmentError: If the file cannot be resolved or is not valid JSON.
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        subfolder = kwargs.pop("subfolder", "")
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)
        commit_hash = kwargs.pop("_commit_hash", None)
        configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)

        if trust_remote_code is True:
            logger.warning(
                "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
                " ignored."
            )

        user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        is_local = os.path.isdir(pretrained_model_name_or_path)
        local_file = os.path.join(subfolder, pretrained_model_name_or_path)

        if os.path.isfile(local_file):
            # The argument points directly at a configuration file on disk.
            resolved_config_file = local_file
            is_local = True
        elif is_remote_url(pretrained_model_name_or_path):
            configuration_file = pretrained_model_name_or_path
            resolved_config_file = download_url(pretrained_model_name_or_path)
        else:
            try:
                # Local folder, local cache, or download from the Hub.
                resolved_config_file = cached_file(
                    pretrained_model_name_or_path,
                    configuration_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    token=token,
                    user_agent=user_agent,
                    revision=revision,
                    subfolder=subfolder,
                    _commit_hash=commit_hash,
                )
                commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
            except EnvironmentError:
                # `cached_file` already raises a descriptive EnvironmentError.
                raise
            except Exception as err:
                raise EnvironmentError(
                    f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
                    " from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
                    f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
                    f" containing a {configuration_file} file"
                ) from err

        try:
            config_dict = cls._dict_from_json_file(resolved_config_file)
        except (json.JSONDecodeError, UnicodeDecodeError) as err:
            raise EnvironmentError(
                f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
            ) from err
        config_dict["_commit_hash"] = commit_hash

        if is_local:
            logger.info(f"loading configuration file {resolved_config_file}")
        else:
            logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")

        if "auto_map" in config_dict and not is_local:
            config_dict["auto_map"] = add_model_info_to_auto_map(
                config_dict["auto_map"], pretrained_model_name_or_path
            )
        return config_dict, kwargs

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
        """
        Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.

        Args:
            config_dict (`Dict[str, Any]`):
                Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
                retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
                It is not modified.
            kwargs (`Dict[str, Any]`):
                Additional parameters from which to initialize the configuration object. Keys naming an existing
                attribute override it. Pass `return_unused_kwargs=True` to also get back the keys that were not used.

        Returns:
            [`PretrainedConfig`]: The configuration object instantiated from those parameters (plus the unused
            kwargs when `return_unused_kwargs=True`).

        Raises:
            ValueError: If `num_labels` and `id2label` are both given and disagree.
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
        # Internal loading markers; they must not become config attributes.
        kwargs.pop("_from_auto", None)
        kwargs.pop("_from_pipeline", None)
        # The commit hash may have been updated in `config_dict`; don't let the kwargs erase that update.
        if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
            kwargs["_commit_hash"] = config_dict["_commit_hash"]

        # Work on a shallow copy so the caller's dictionary is not mutated.
        config_dict = dict(config_dict)
        config_dict["attn_implementation"] = kwargs.pop("attn_implementation", None)

        config = cls(**config_dict)

        if hasattr(config, "pruned_heads"):
            # JSON object keys are strings; layer indices are ints.
            config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}

        if "num_labels" in kwargs and "id2label" in kwargs:
            num_labels = kwargs["num_labels"]
            id2label = kwargs["id2label"] if kwargs["id2label"] is not None else []
            if len(id2label) != num_labels:
                raise ValueError(
                    f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
                    f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove "
                    "one of them."
                )

        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                current_attr = getattr(config, key)
                # A dict given for a nested sub-config is turned into an instance of that sub-config's class.
                if isinstance(current_attr, PretrainedConfig) and isinstance(value, dict):
                    value = current_attr.__class__(**value)
                setattr(config, key, value)
                # `torch_dtype` is applied but deliberately left in the returned unused kwargs.
                if key != "torch_dtype":
                    to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        logger.info(f"Model config {config}")
        if return_unused_kwargs:
            return config, kwargs
        else:
            return config

    @classmethod
    def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig":
        """
        Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters.

        Args:
            json_file (`str` or `os.PathLike`):
                Path to the JSON file containing the parameters.

        Returns:
            [`PretrainedConfig`]: The configuration object instantiated from that JSON file.
        """
        config_dict = cls._dict_from_json_file(json_file)
        return cls(**config_dict)

    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
        """
        Read a UTF-8 JSON file and return its parsed content.

        Args:
            json_file (`str` or `os.PathLike`):
                Path to the JSON file.

        Returns:
            dict: Dictionary containing the parsed JSON content.
        """
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    def __eq__(self, other):
        # Two configs are equal when they are both configs with identical instance attributes.
        return isinstance(other, PretrainedConfig) and (self.__dict__ == other.__dict__)

    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    def to_diff_dict(self) -> Dict[str, Any]:
        """
        Removes all attributes from config which correspond to the default config attributes for better readability and
        serializes to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        config_dict = self.to_dict()
        # Defaults of the base class and of this concrete class (composite configs have no standalone default).
        default_config_dict = PretrainedConfig().to_dict()
        class_config_dict = self.__class__().to_dict() if not self.is_composition else {}

        serializable_config_dict = {}
        for key, value in config_dict.items():
            if (
                isinstance(getattr(self, key, None), PretrainedConfig)
                and key in class_config_dict
                and isinstance(class_config_dict[key], dict)
            ):
                # Nested configs are diffed recursively.
                diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None))
                if "model_type" in value:
                    # Always kept so the nested config can be re-identified.
                    diff["model_type"] = value["model_type"]
                if len(diff) > 0:
                    serializable_config_dict[key] = diff
            elif (
                key not in default_config_dict
                or key == "transformers_version"
                or value != default_config_dict[key]
                or (key in class_config_dict and value != class_config_dict[key])
            ):
                serializable_config_dict[key] = value

        if hasattr(self, "quantization_config"):
            if isinstance(self.quantization_config, dict):
                serializable_config_dict["quantization_config"] = self.quantization_config
            else:
                serializable_config_dict["quantization_config"] = self.quantization_config.to_dict()
            # Internal value; never serialized.
            _ = serializable_config_dict.pop("_pre_quantization_dtype", None)

        self.dict_torch_dtype_to_str(serializable_config_dict)

        if "_attn_implementation_internal" in serializable_config_dict:
            del serializable_config_dict["_attn_implementation_internal"]

        return serializable_config_dict

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        if hasattr(self.__class__, "model_type"):
            output["model_type"] = self.__class__.model_type
        # Internal bookkeeping fields are not part of the serialized configuration.
        if "_auto_class" in output:
            del output["_auto_class"]
        if "_commit_hash" in output:
            del output["_commit_hash"]
        if "_attn_implementation_internal" in output:
            del output["_attn_implementation_internal"]

        output["transformers_version"] = __version__

        for key, value in output.items():
            # Nested configs become plain dicts; their version stamp would duplicate the top-level one.
            if isinstance(value, PretrainedConfig):
                value = value.to_dict()
                del value["transformers_version"]
            output[key] = value

        if hasattr(self, "quantization_config"):
            output["quantization_config"] = (
                self.quantization_config.to_dict()
                if not isinstance(self.quantization_config, dict)
                else self.quantization_config
            )
            _ = output.pop("_pre_quantization_dtype", None)

        self.dict_torch_dtype_to_str(output)

        return output

    def to_json_string(self, use_diff: bool = True) -> str:
        """
        Serializes this instance to a JSON string.

        Args:
            use_diff (`bool`, *optional*, defaults to `True`):
                If `True`, only the differences between this config and the default `PretrainedConfig()` are
                serialized.

        Returns:
            `str`: String containing all the attributes that make up this configuration instance in JSON format.
        """
        if use_diff is True:
            config_dict = self.to_diff_dict()
        else:
            config_dict = self.to_dict()
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path of the JSON file in which this configuration instance's parameters will be saved.
            use_diff (`bool`, *optional*, defaults to `True`):
                If `True`, only the differences from the default `PretrainedConfig()` are serialized.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string(use_diff=use_diff))

    def update(self, config_dict: Dict[str, Any]):
        """
        Updates attributes of this class with attributes from `config_dict`.

        Args:
            config_dict (`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
        """
        for key, value in config_dict.items():
            setattr(self, key, value)

    def update_from_string(self, update_str: str):
        """
        Updates attributes of this class with attributes from `update_str`.

        The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example:
        "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"

        The keys to change have to already exist in the config object.

        Args:
            update_str (`str`): String with attributes that should be updated for this class.

        Raises:
            ValueError: For unknown keys, unparsable booleans, or attributes of an unsupported type.
        """
        # Split each pair on the first "=" only, so a string value may itself contain "=".
        d = dict(x.split("=", 1) for x in update_str.split(","))
        for k, v in d.items():
            if not hasattr(self, k):
                raise ValueError(f"key {k} isn't in the original config dict")

            old_v = getattr(self, k)
            # bool must be tested before int because bool is a subclass of int.
            if isinstance(old_v, bool):
                if v.lower() in ["true", "1", "y", "yes"]:
                    v = True
                elif v.lower() in ["false", "0", "n", "no"]:
                    v = False
                else:
                    raise ValueError(f"can't derive true or false from {v} (key {k})")
            elif isinstance(old_v, int):
                v = int(v)
            elif isinstance(old_v, float):
                v = float(v)
            elif not isinstance(old_v, str):
                raise ValueError(
                    f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
                )

            setattr(self, k, v)

    def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
        """
        Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
        converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
        string, which can then be stored in the json format.
        """
        if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
            # str(torch.float32) == "torch.float32"; keep the part after the module prefix.
            d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
        for value in d.values():
            if isinstance(value, dict):
                self.dict_torch_dtype_to_str(value)

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoConfig"):
        """
        Register this class with a given auto class. This should only be used for custom configurations as the ones in
        the library are already mapped with `AutoConfig`.

        <Tip warning={true}>

        This API is experimental and may have some slight breaking changes in the next releases.

        </Tip>

        Args:
            auto_class (`str` or `type`, *optional*, defaults to `"AutoConfig"`):
                The auto class to register this new configuration with.

        Raises:
            ValueError: If `auto_class` does not name a class in `transformers.models.auto`.
        """
        if not isinstance(auto_class, str):
            auto_class = auto_class.__name__

        # Imported lazily, presumably to avoid an import cycle at module load time.
        import transformers.models.auto as auto_module

        if not hasattr(auto_module, auto_class):
            raise ValueError(f"{auto_class} is not a valid auto class.")

        cls._auto_class = auto_class

    @staticmethod
    def _get_generation_defaults() -> Dict[str, Any]:
        # Default values of the generation-related parameters that may be stored on a model config.
        return {
            "max_length": 20,
            "min_length": 0,
            "do_sample": False,
            "early_stopping": False,
            "num_beams": 1,
            "num_beam_groups": 1,
            "diversity_penalty": 0.0,
            "temperature": 1.0,
            "top_k": 50,
            "top_p": 1.0,
            "typical_p": 1.0,
            "repetition_penalty": 1.0,
            "length_penalty": 1.0,
            "no_repeat_ngram_size": 0,
            "encoder_no_repeat_ngram_size": 0,
            "bad_words_ids": None,
            "num_return_sequences": 1,
            "output_scores": False,
            "return_dict_in_generate": False,
            "forced_bos_token_id": None,
            "forced_eos_token_id": None,
            "remove_invalid_values": False,
            "exponential_decay_length_penalty": None,
            "suppress_tokens": None,
            "begin_suppress_tokens": None,
        }

    def _has_non_default_generation_parameters(self) -> bool:
        """
        Whether or not this instance holds non-default generation parameters.
        """
        defaults = self._get_generation_defaults()
        return any(
            hasattr(self, parameter_name) and getattr(self, parameter_name) != default_value
            for parameter_name, default_value in defaults.items()
        )
def get_configuration_file(configuration_files: List[str]) -> str:
    """
    Get the configuration file to use for this version of transformers.

    Versioned files are named `config.<version>.json`. The one with the highest version that is not newer than the
    installed transformers is selected. If there is none, `CONFIG_NAME` is used.

    Args:
        configuration_files (`List[str]`): The list of available configuration files.

    Returns:
        `str`: The configuration file to use.
    """
    configuration_files_map = {}
    for file_name in configuration_files:
        search = _re_configuration_file.search(file_name)
        if search is None:
            continue
        try:
            parsed_version = version.parse(search.groups()[0])
        except version.InvalidVersion:
            # The captured suffix is not a version number; this is not a versioned config file.
            continue
        configuration_files_map[parsed_version] = file_name

    configuration_file = CONFIG_NAME
    transformers_version = version.parse(__version__)
    # Sort on parsed versions: a plain string sort puts "4.10.0" before "4.9.0", which breaks the early exit below.
    for v in sorted(configuration_files_map):
        if v <= transformers_version:
            configuration_file = configuration_files_map[v]
        else:
            # Every remaining file targets a newer transformers version.
            break

    return configuration_file
def recursive_diff_dict(dict_a, dict_b, config_obj=None):
    """
    Recursively compute the entries of `dict_a` that are not redundant with `dict_b`.

    An entry is dropped only if it equals the value in `dict_b`. When `config_obj` is given, it must also equal
    the value from a freshly constructed instance of `config_obj`'s class. Entries backed by a nested
    `PretrainedConfig` are compared recursively, and empty nested diffs are omitted.
    """
    reference = {} if config_obj is None else config_obj.__class__().to_dict()
    diff = {}
    for key, value in dict_a.items():
        sub_config = getattr(config_obj, str(key), None)
        is_nested = isinstance(sub_config, PretrainedConfig) and key in dict_b and isinstance(dict_b[key], dict)
        if is_nested:
            nested_diff = recursive_diff_dict(value, dict_b[key], config_obj=sub_config)
            if nested_diff:
                diff[key] = nested_diff
            continue
        redundant = key in dict_b and value == dict_b[key] and key in reference and value == reference[key]
        if not redundant:
            diff[key] = value
    return diff
# Give PretrainedConfig its own copy of the inherited `push_to_hub`, presumably so the docstring edit below does
# not leak into other users of the mixin's function.
PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub)
# Fill the docstring's placeholders for configs; the docstring can be None (e.g. when docstrings are stripped).
if PretrainedConfig.push_to_hub.__doc__ is not None:
    PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
        object="config", object_class="AutoConfig", object_files="configuration file"
    )
.\convert_graph_to_onnx.py
import warnings
from argparse import ArgumentParser
from os import listdir, makedirs
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from packaging.version import Version, parse
from transformers.pipelines import Pipeline, pipeline
from transformers.tokenization_utils import BatchEncoding
from transformers.utils import ModelOutput, is_tf_available, is_torch_available
# Minimum onnxruntime version checked before quantizing (see the `--quantize` path of the script entry point).
ORT_QUANTIZE_MINIMUM_VERSION = parse("1.4.0")

# Pipeline tasks accepted by the `--pipeline` command-line option.
SUPPORTED_PIPELINES = [
    "feature-extraction",
    "ner",
    "sentiment-analysis",
    "fill-mask",
    "question-answering",
    "text-generation",
    "translation_en_to_fr",
    "translation_en_to_de",
    "translation_en_to_ro",
]
class OnnxConverterArgumentParser(ArgumentParser):
    """
    Command-line parser exposing every option understood by the ONNX export script.
    """

    def __init__(self):
        super().__init__("ONNX Converter")

        self.add_argument("--pipeline", type=str, choices=SUPPORTED_PIPELINES, default="feature-extraction")
        self.add_argument(
            "--model", type=str, required=True, help="Model's id or path (ex: google-bert/bert-base-cased)"
        )
        self.add_argument("--tokenizer", type=str, help="Tokenizer's id or path (ex: google-bert/bert-base-cased)")
        self.add_argument("--framework", type=str, choices=["pt", "tf"], help="Framework for loading the model")
        self.add_argument("--opset", type=int, default=11, help="ONNX opset to use")

        # On/off switches: each becomes True when present on the command line.
        switches = (
            ("--check-loading", "Check ONNX is able to load the model"),
            ("--use-external-format", "Allow exporting model >= than 2Gb"),
            ("--quantize", "Quantize the neural network to be run with int8"),
        )
        for flag, description in switches:
            self.add_argument(flag, action="store_true", help=description)

        # Positional argument: where the exported graph is written.
        self.add_argument("output")
def generate_identified_filename(filename: Path, identifier: str) -> Path:
    """
    Append a string identifier to the end of a file name, before its extension (if any).

    Args:
        filename: pathlib.Path whose final component should receive the identifier suffix.
        identifier: The suffix to add.

    Returns: A new Path in the same directory, e.g. `model.onnx` + `-optimized` -> `model-optimized.onnx`.
    """
    # `with_name` rebuilds the whole final component. Calling `with_suffix` on `stem + identifier` would instead
    # treat a dot inside the stem (e.g. "model.v1") as the extension and drop it.
    return filename.with_name(filename.stem + identifier + filename.suffix)
def check_onnxruntime_requirements(minimum_version: Version):
    """
    Check that onnxruntime is installed and that the installed version is at least `minimum_version`.

    Args:
        minimum_version: The lowest acceptable onnxruntime version (a version string is also accepted).

    Raises:
        ImportError: If onnxruntime is not installed, or if the installed version is older than `minimum_version`.
    """
    if isinstance(minimum_version, str):
        minimum_version = parse(minimum_version)

    # Only the import is guarded. Previously the "too old" ImportError below was caught by this handler and
    # re-reported as "not installed".
    try:
        import onnxruntime
    except ImportError as err:
        raise ImportError(
            "onnxruntime doesn't seem to be currently installed. "
            "Please install the onnxruntime by running `pip install onnxruntime`"
            " and relaunch the conversion."
        ) from err

    ort_version = parse(onnxruntime.__version__)

    # Compare against the requested minimum (previously the global ORT_QUANTIZE_MINIMUM_VERSION was always used).
    if ort_version < minimum_version:
        raise ImportError(
            f"We found an older version of onnxruntime ({onnxruntime.__version__}) "
            f"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\n"
            "Please update onnxruntime by running `pip install --upgrade onnxruntime`"
        )
def ensure_valid_input(model, tokens, input_names):
    """
    Order the generated inputs to match the positional parameters of `model.forward`.

    The parameters of `forward` are walked in declaration order (skipping `self`). Collection stops at the first
    parameter that is not among `input_names`.

    Args:
        model: The model used to forward the input data.
        tokens: BatchEncoding (or mapping) holding the input data, indexed by input name.
        input_names: The names of the generated inputs.

    Returns: Tuple of (ordered input names, tuple of the corresponding input values).
    """
    print("Ensuring inputs are in correct order")

    code = model.forward.__code__
    # `co_varnames` also lists keyword-only parameters and local variables after the positional parameters.
    # Restrict the walk to positional parameters so such names are never passed positionally.
    model_args_name = code.co_varnames[: code.co_argcount]
    model_args, ordered_input_names = [], []
    for arg_name in model_args_name[1:]:  # start at index 1 to skip "self"
        if arg_name in input_names:
            ordered_input_names.append(arg_name)
            model_args.append(tokens[arg_name])
        else:
            print(f"{arg_name} is not present in the generated input list.")
            break

    print(f"Generated inputs order: {ordered_input_names}")
    return ordered_input_names, tuple(model_args)
def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
    """
    Work out which axes of each model input and output are dynamic (batch / sequence) for the ONNX export.

    Args:
        nlp: Pipeline holding the model and tokenizer to export.
        framework: "pt" or "tf"; selects the tensor type returned by the tokenizer and how the model is called.

    Returns:
        - the input variable names
        - the output variable names (`output_0`, `output_1`, ...)
        - a mapping from every input/output name to its dynamic-axes description
        - the BatchEncoding used for the probe forward pass
    """

    def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int):
        # Groups (tuples/lists) are described element by element.
        if isinstance(tensor, (tuple, list)):
            return [build_shape_dict(name, item, is_input, seq_len) for item in tensor]

        dims = tensor.shape
        # The first axis of size 1 is taken as the batch axis (the probe sentence is a batch of one).
        batch_axis = [idx for idx, size in enumerate(dims) if size == 1][0]
        axes = {batch_axis: "batch"}

        if not is_input:
            # Any output axis whose length equals the probe sequence length is treated as a sequence axis.
            axes.update({idx: "sequence" for idx, size in enumerate(dims) if size == seq_len})
        elif len(dims) == 2:
            axes[1] = "sequence"
        else:
            raise ValueError(f"Unable to infer tensor axes ({len(dims)})")

        print(f"Found {'input' if is_input else 'output'} {name} with shape: {axes}")
        return axes

    tokens = nlp.tokenizer("This is a sample output", return_tensors=framework)
    seq_len = tokens.input_ids.shape[-1]

    if framework == "pt":
        outputs = nlp.model(**tokens)
    else:
        outputs = nlp.model(tokens)
    if isinstance(outputs, ModelOutput):
        outputs = outputs.to_tuple()
    if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)

    input_vars = list(tokens.keys())
    input_dynamic_axes = {key: build_shape_dict(key, value, True, seq_len) for key, value in tokens.items()}

    # Flatten one level of grouping (e.g. past states or attentions, as in GPT-2) so every tensor gets its own name.
    outputs_flat = [
        tensor for out in outputs for tensor in (out if isinstance(out, (tuple, list)) else (out,))
    ]

    output_names = [f"output_{idx}" for idx in range(len(outputs_flat))]
    output_dynamic_axes = {
        out_name: build_shape_dict(out_name, tensor, False, seq_len)
        for out_name, tensor in zip(output_names, outputs_flat)
    }

    dynamic_axes = {**input_dynamic_axes, **output_dynamic_axes}
    return input_vars, output_names, dynamic_axes, tokens
def load_graph_from_args(
    pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None, **models_kwargs
) -> Pipeline:
    """
    Build the pipeline (tokenizer + model) described by the command-line arguments.

    Args:
        pipeline_name: The kind of pipeline to use (ner, question-answering, etc.)
        framework: The backend to load the model with ("pt" or "tf")
        model: The model name which will be loaded by the pipeline
        tokenizer: The tokenizer name which will be loaded by the pipeline; defaults to `model`
        models_kwargs: Forwarded to the pipeline as `model_kwargs`

    Returns: Pipeline object

    Raises:
        Exception: If the requested backend is not installed.
    """
    tokenizer = model if tokenizer is None else tokenizer

    # Fail early with an explicit message when the requested backend is missing.
    if framework == "pt":
        if not is_torch_available():
            raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
    elif framework == "tf":
        if not is_tf_available():
            raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")

    print(f"Loading pipeline (model: {model}, tokenizer: {tokenizer})")
    return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework, model_kwargs=models_kwargs)
def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format: bool):
    """
    Export a PyTorch-backed pipeline to the ONNX Intermediate Representation (IR).

    Args:
        nlp: The pipeline to be exported
        opset: The ONNX operator-set version to target
        output: Destination path of the generated ONNX model
        use_external_format: Meant to split the model definition from its parameters for models over 2GB.
            NOTE(review): it is accepted but not forwarded to `torch.onnx.export` here.

    Raises:
        Exception: If PyTorch is not installed.
    """
    if not is_torch_available():
        raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")

    import torch
    from torch.onnx import export

    print(f"Using framework PyTorch: {torch.__version__}")

    # Shape inference and the export both run without gradient tracking.
    with torch.no_grad():
        in_names, out_names, axes, encoded = infer_shapes(nlp, "pt")
        ordered_names, positional_args = ensure_valid_input(nlp.model, encoded, in_names)

        export(
            nlp.model,
            positional_args,
            f=output.as_posix(),
            input_names=ordered_names,
            output_names=out_names,
            dynamic_axes=axes,
            do_constant_folding=True,
            opset_version=opset,
        )
def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
    """
    Export a TensorFlow-backed pipeline to the ONNX Intermediate Representation (IR).

    Args:
        nlp: The pipeline to be exported
        opset: The ONNX operator-set version to target
        output: Destination path of the generated ONNX model

    Raises:
        Exception: If TensorFlow, or one of the packages needed for the conversion, is not installed.
    """
    # A single availability check (it was duplicated).
    if not is_tf_available():
        raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")

    print("/!\\ Please note TensorFlow doesn't support exporting model > 2Gb /!\\")

    # Only the imports are guarded, so an ImportError raised later, during the conversion itself, is not
    # misreported as a missing package.
    try:
        import tensorflow as tf
        import tf2onnx
        from tf2onnx import __version__ as t2ov
    except ImportError as e:
        raise Exception(
            f"Cannot import {e.name} required to convert TF model to ONNX. Please install {e.name} first. {e}"
        ) from e

    print(f"Using framework TensorFlow: {tf.version.VERSION}, tf2onnx: {t2ov}")

    input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")

    # One forward pass on the probe tokens (presumably so the Keras model is built before conversion).
    nlp.model.predict(tokens.data)
    input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in tokens.items()]
    tf2onnx.convert.from_keras(nlp.model, input_signature, opset=opset, output_path=output.as_posix())
def convert(
    framework: str,
    model: str,
    output: Path,
    opset: int,
    tokenizer: Optional[str] = None,
    use_external_format: bool = False,
    pipeline_name: str = "feature-extraction",
    **model_kwargs,
):
    """
    Convert the pipeline object to the ONNX Intermediate Representation (IR) format.

    Args:
        framework: The framework the pipeline is backed by ("pt" or "tf"). When None, the framework picked by the
            loaded pipeline is used.
        model: The name of the model to load for the pipeline
        output: The path where the ONNX graph will be stored
        opset: The ONNX operator-set version to target
        tokenizer: The name of the tokenizer to load for the pipeline; defaults to the model's name
        use_external_format:
            Split the model definition from its parameters to allow models bigger than 2GB (PyTorch only)
        pipeline_name: The kind of pipeline to instantiate (ner, question-answering, etc.)
        model_kwargs: Keyword arguments forwarded to the model constructor

    Raises:
        Exception: If the output's parent folder already exists and is not empty.
    """
    warnings.warn(
        "The `transformers.convert_graph_to_onnx` package is deprecated and will be removed in version 5 of"
        " Transformers",
        FutureWarning,
    )
    print(f"ONNX opset version set to: {opset}")

    nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer, **model_kwargs)

    if framework is None:
        # `--framework` is optional on the command line. Dispatch on what the pipeline actually loaded instead of
        # falling through to the TensorFlow exporter.
        framework = nlp.framework

    if not output.parent.exists():
        print(f"Creating folder {output.parent}")
        makedirs(output.parent.as_posix())
    elif listdir(output.parent.as_posix()):
        raise Exception(f"Folder {output.parent.as_posix()} is not empty, aborting conversion")

    if framework == "pt":
        convert_pytorch(nlp, opset, output, use_external_format)
    else:
        convert_tensorflow(nlp, opset, output)
def optimize(onnx_model_path: Path) -> Path:
    """
    Let onnxruntime apply its graph optimizations to the model at `onnx_model_path` and save the result.

    Args:
        onnx_model_path: Path to the ONNX model file

    Returns: Path of the optimized model (the input name suffixed with "-optimized")
    """
    from onnxruntime import InferenceSession, SessionOptions

    optimized_path = generate_identified_filename(onnx_model_path, "-optimized")
    options = SessionOptions()
    # With `optimized_model_filepath` set, creating the session writes the optimized graph to that path.
    options.optimized_model_filepath = optimized_path.as_posix()
    InferenceSession(onnx_model_path.as_posix(), options)

    print(f"Optimized model has been written at {optimized_path}: \N{heavy check mark}")
    print("/!\\ Optimized model contains hardware specific operators which might not be portable. /!\\")
    return optimized_path
def quantize(onnx_model_path: Path) -> Path:
    """
    Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPU.

    Args:
        onnx_model_path: Path to the ONNX model file

    Returns: Path where the quantized model was saved (the input name suffixed with "-quantized")
    """
    import onnx
    import onnxruntime
    from onnx.onnx_pb import ModelProto
    from onnxruntime.quantization import QuantizationMode
    from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
    from onnxruntime.quantization.registry import IntegerOpsRegistry

    onnx_model = onnx.load(onnx_model_path.as_posix())

    # NOTE(review): this checks the `onnx` version but the message names onnxruntime; confirm which is intended.
    if parse(onnx.__version__) < parse("1.5.0"):
        print(
            "Models larger than 2GB will fail to quantize due to protobuf constraint.\n"
            "Please upgrade to onnxruntime >= 1.5.0."
        )

    # The quantizer works on a copy of the loaded model.
    copy_model = ModelProto()
    copy_model.CopyFrom(onnx_model)

    quantizer_kwargs = {
        "model": copy_model,
        "per_channel": False,
        "reduce_range": False,
        "mode": QuantizationMode.IntegerOps,
        "static": False,
        "weight_qType": True,
        "tensors_range": None,
        "nodes_to_quantize": None,
        "nodes_to_exclude": None,
        "op_types_to_quantize": list(IntegerOpsRegistry),
    }
    # onnxruntime < 1.13.1 takes `input_qType`; newer versions take `activation_qType` instead.
    if parse(onnxruntime.__version__) < parse("1.13.1"):
        quantizer_kwargs["input_qType"] = False
    else:
        quantizer_kwargs["activation_qType"] = False
    quantizer = ONNXQuantizer(**quantizer_kwargs)

    quantizer.quantize_model()

    quantized_model_path = generate_identified_filename(onnx_model_path, "-quantized")

    # Save first, so the success message is only printed once the file has actually been written.
    onnx.save_model(quantizer.model.model, quantized_model_path.as_posix())
    print(f"Quantized model has been written at {quantized_model_path}: \N{heavy check mark}")

    return quantized_model_path
def verify(path: Path):
    """
    Check that the ONNX model stored at `path` can be opened by an onnxruntime CPU inference session.

    The outcome is printed; a `RuntimeException` raised by onnxruntime while loading is reported instead of
    being propagated.
    """
    from onnxruntime import InferenceSession, SessionOptions
    from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException

    print(f"Checking ONNX model loading from: {path} ...")
    try:
        session_options = SessionOptions()
        # Building the session is the whole check; the session itself is not used afterwards.
        InferenceSession(path.as_posix(), session_options, providers=["CPUExecutionProvider"])
        print(f"Model {path} correctly loaded: \N{heavy check mark}")
    except RuntimeException as load_error:
        print(f"Error while loading the model {load_error}: \N{heavy ballot x}")
if __name__ == "__main__":
    import sys

    # Parse the command-line arguments.
    parser = OnnxConverterArgumentParser()
    args = parser.parse_args()

    # Make sure the output path is absolute.
    args.output = Path(args.output).absolute()

    try:
        print("\n====== Converting model to ONNX ======")
        # Export the model to ONNX.
        convert(
            args.framework,
            args.model,
            args.output,
            args.opset,
            args.tokenizer,
            args.use_external_format,
            args.pipeline,
        )

        if args.quantize:
            # Ensure the onnxruntime requirements for quantization are met.
            check_onnxruntime_requirements(ORT_QUANTIZE_MINIMUM_VERSION)

            # onnxruntime optimizations do not reach the same level on TensorFlow exports as on PyTorch ones.
            if args.framework == "tf":
                print(
                    "\t Using TensorFlow might not provide the same optimization level compared to PyTorch.\n"
                    "\t For TensorFlow users you can try optimizing the model directly through onnxruntime_tools.\n"
                    "\t For more information, please refer to the onnxruntime documentation:\n"
                    "\t\thttps://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers\n"
                )

            print("\n====== Optimizing ONNX model ======")

            # Optimize the exported graph first ...
            args.optimized_output = optimize(args.output)

            # ... then quantize the optimized graph.
            args.quantized_output = quantize(args.optimized_output)

        # Verify every model file that was produced.
        if args.check_loading:
            print("\n====== Check exported ONNX model(s) ======")
            verify(args.output)

            if hasattr(args, "optimized_output"):
                verify(args.optimized_output)

            if hasattr(args, "quantized_output"):
                verify(args.quantized_output)

    except Exception as e:
        # Report any conversion failure and exit with a non-zero status.
        print(f"Error while converting the model: {e}")
        # `exit` is injected by the `site` module and is not guaranteed to exist; `sys.exit` always is.
        sys.exit(1)
.\convert_pytorch_checkpoint_to_tf2.py
""" Convert pytorch checkpoints to TensorFlow"""
import argparse
import os
from . import (
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
AlbertConfig,
BartConfig,
BertConfig,
CamembertConfig,
CTRLConfig,
DistilBertConfig,
DPRConfig,
ElectraConfig,
FlaubertConfig,
GPT2Config,
LayoutLMConfig,
LxmertConfig,
OpenAIGPTConfig,
RobertaConfig,
T5Config,
TFAlbertForPreTraining,
TFBartForConditionalGeneration,
TFBartForSequenceClassification,
TFBertForPreTraining,
TFBertForQuestionAnswering,
TFBertForSequenceClassification,
TFCamembertForMaskedLM,
TFCTRLLMHeadModel,
TFDistilBertForMaskedLM,
TFDistilBertForQuestionAnswering,
TFDPRContextEncoder,
TFDPRQuestionEncoder,
TFDPRReader,
TFElectraForPreTraining,
TFFlaubertWithLMHeadModel,
TFGPT2LMHeadModel,
TFLayoutLMForMaskedLM,
TFLxmertForPreTraining,
TFLxmertVisualFeatureEncoder,
TFOpenAIGPTLMHeadModel,
TFRobertaForCausalLM,
TFRobertaForMaskedLM,
TFRobertaForSequenceClassification,
TFT5ForConditionalGeneration,
TFTransfoXLLMHeadModel,
TFWav2Vec2Model,
TFXLMRobertaForMaskedLM,
TFXLMWithLMHeadModel,
TFXLNetLMHeadModel,
TransfoXLConfig,
Wav2Vec2Config,
Wav2Vec2Model,
XLMConfig,
XLMRobertaConfig,
XLNetConfig,
is_torch_available,
load_pytorch_checkpoint_in_tf2_model,
)
from .utils import CONFIG_NAME, WEIGHTS_NAME, cached_file, logging
if is_torch_available():
import numpy as np
import torch
from . import (
AlbertForPreTraining,
BartForConditionalGeneration,
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
CamembertForMaskedLM,
CTRLLMHeadModel,
DistilBertForMaskedLM,
DistilBertForQuestionAnswering,
DPRContextEncoder,
DPRQuestionEncoder,
DPRReader,
ElectraForPreTraining,
FlaubertWithLMHeadModel,
GPT2LMHeadModel,
LayoutLMForMaskedLM,
LxmertForPreTraining,
LxmertVisualFeatureEncoder,
OpenAIGPTLMHeadModel,
RobertaForMaskedLM,
RobertaForSequenceClassification,
T5ForConditionalGeneration,
TransfoXLLMHeadModel,
XLMRobertaForMaskedLM,
XLMWithLMHeadModel,
XLNetLMHeadModel,
)
from .pytorch_utils import is_torch_greater_or_equal_than_1_13
# Set the transformers logging verbosity to INFO.
logging.set_verbosity_info()
# Lookup table for the conversion functions: key -> tuple of classes plus a pretrained archive collection.
# Keys are either model types ("bert", "t5", ...) or full names of fine-tuned checkpoints ("...-squad",
# "...-mrpc", "...-mnli"); the conversion loop uses such a checkpoint name as the model type for that checkpoint.
# Most entries are 4-tuples: (config class, TF 2 model class, PyTorch model class, *_ARCHIVE_MAP / *_ARCHIVE_LIST).
# NOTE(review): "bart" and "roberta" have 5 items and "dpr" has 10, so code that unpacks these tuples with a fixed
# arity cannot consume those entries as-is — confirm the intended layout before relying on them.
MODEL_CLASSES = {
"bart": (
BartConfig,
TFBartForConditionalGeneration,
# NOTE(review): extra TF class — makes this entry a 5-tuple.
TFBartForSequenceClassification,
BartForConditionalGeneration,
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
),
"bert": (
BertConfig,
TFBertForPreTraining,
BertForPreTraining,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"google-bert/bert-large-uncased-whole-word-masking-finetuned-squad": (
BertConfig,
TFBertForQuestionAnswering,
BertForQuestionAnswering,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"google-bert/bert-large-cased-whole-word-masking-finetuned-squad": (
BertConfig,
TFBertForQuestionAnswering,
BertForQuestionAnswering,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"google-bert/bert-base-cased-finetuned-mrpc": (
BertConfig,
TFBertForSequenceClassification,
BertForSequenceClassification,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"dpr": (
# NOTE(review): three TF classes, three PyTorch classes and three archive lists (10 items), unlike the
# 4-tuple layout of the other entries.
DPRConfig,
TFDPRQuestionEncoder,
TFDPRContextEncoder,
TFDPRReader,
DPRQuestionEncoder,
DPRContextEncoder,
DPRReader,
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
),
"openai-community/gpt2": (
GPT2Config,
TFGPT2LMHeadModel,
GPT2LMHeadModel,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"xlnet": (
XLNetConfig,
TFXLNetLMHeadModel,
XLNetLMHeadModel,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"xlm": (
XLMConfig,
TFXLMWithLMHeadModel,
XLMWithLMHeadModel,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"xlm-roberta": (
XLMRobertaConfig,
TFXLMRobertaForMaskedLM,
XLMRobertaForMaskedLM,
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"transfo-xl": (
TransfoXLConfig,
TFTransfoXLLMHeadModel,
TransfoXLLMHeadModel,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"openai-community/openai-gpt": (
OpenAIGPTConfig,
TFOpenAIGPTLMHeadModel,
OpenAIGPTLMHeadModel,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"roberta": (
RobertaConfig,
TFRobertaForCausalLM,
# NOTE(review): extra TF class — makes this entry a 5-tuple.
TFRobertaForMaskedLM,
RobertaForMaskedLM,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"layoutlm": (
LayoutLMConfig,
TFLayoutLMForMaskedLM,
LayoutLMForMaskedLM,
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
),
"FacebookAI/roberta-large-mnli": (
RobertaConfig,
TFRobertaForSequenceClassification,
RobertaForSequenceClassification,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"camembert": (
CamembertConfig,
TFCamembertForMaskedLM,
CamembertForMaskedLM,
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"flaubert": (
FlaubertConfig,
TFFlaubertWithLMHeadModel,
FlaubertWithLMHeadModel,
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"distilbert": (
DistilBertConfig,
TFDistilBertForMaskedLM,
DistilBertForMaskedLM,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"distilbert-base-distilled-squad": (
DistilBertConfig,
TFDistilBertForQuestionAnswering,
DistilBertForQuestionAnswering,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"lxmert": (
LxmertConfig,
TFLxmertForPreTraining,
LxmertForPreTraining,
LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"lxmert-visual-feature-encoder": (
LxmertConfig,
TFLxmertVisualFeatureEncoder,
LxmertVisualFeatureEncoder,
LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"Salesforce/ctrl": (
CTRLConfig,
TFCTRLLMHeadModel,
CTRLLMHeadModel,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"albert": (
AlbertConfig,
TFAlbertForPreTraining,
AlbertForPreTraining,
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"t5": (
T5Config,
TFT5ForConditionalGeneration,
T5ForConditionalGeneration,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"electra": (
ElectraConfig,
TFElectraForPreTraining,
ElectraForPreTraining,
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"wav2vec2": (
Wav2Vec2Config,
TFWav2Vec2Model,
Wav2Vec2Model,
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
}
def convert_pt_checkpoint_to_tf(
    model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True
):
    """
    Convert a single PyTorch checkpoint into TensorFlow 2 weights saved in h5 format.

    Args:
        model_type: Key of `MODEL_CLASSES` selecting the config, TF and PyTorch classes.
        pytorch_checkpoint_path: Local weights path, or a shortcut name present in the entry's archive collection
            (then resolved with `cached_file`).
        config_file: Local JSON config path, or a shortcut name present in the entry's archive collection.
        tf_dump_path: Output path of the h5 weights file.
        compare_with_pt_model: If True, run both models on the TF model's dummy inputs and require the maximum
            absolute difference of their first outputs to be <= 2e-2.
        use_cached_models: If False, shortcut-name downloads are forced (`force_download=True`).

    Raises:
        ValueError: If `model_type` is unknown or its `MODEL_CLASSES` entry is not a 4-tuple.
        AssertionError: If `compare_with_pt_model` is set and the outputs differ by more than 2e-2.
    """
    if model_type not in MODEL_CLASSES:
        raise ValueError(f"Unrecognized model type, should be one of {list(MODEL_CLASSES.keys())}.")

    model_entry = MODEL_CLASSES[model_type]
    # Some entries carry more than four items; report that clearly instead of failing on tuple unpacking.
    if len(model_entry) != 4:
        raise ValueError(
            f"Model type {model_type} has {len(model_entry)} items in MODEL_CLASSES; expected "
            "(config class, TF model class, PyTorch model class, archive map)."
        )
    config_class, model_class, pt_model_class, aws_config_map = model_entry

    # Initialise the TF model from its configuration (resolving a shortcut name to a config file if needed).
    if config_file in aws_config_map:
        config_file = cached_file(config_file, CONFIG_NAME, force_download=not use_cached_models)
    config = config_class.from_json_file(config_file)
    config.output_hidden_states = True
    config.output_attentions = True
    print(f"Building TensorFlow model from configuration: {config}")
    tf_model = model_class(config)

    # Load the PyTorch weights into the TF model (resolving a shortcut name to a weights file if needed).
    if pytorch_checkpoint_path in aws_config_map.keys():
        pytorch_checkpoint_path = cached_file(
            pytorch_checkpoint_path, WEIGHTS_NAME, force_download=not use_cached_models
        )
    tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)

    if compare_with_pt_model:
        tfo = tf_model(tf_model.dummy_inputs, training=False)  # build the network

        # Only request `weights_only` loading on torch versions that support it.
        weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
        state_dict = torch.load(
            pytorch_checkpoint_path,
            map_location="cpu",
            **weights_only_kwarg,
        )
        pt_model = pt_model_class.from_pretrained(
            pretrained_model_name_or_path=None, config=config, state_dict=state_dict
        )

        with torch.no_grad():
            pto = pt_model(**pt_model.dummy_inputs)

        np_pt = pto[0].numpy()
        np_tf = tfo[0].numpy()
        diff = np.amax(np.abs(np_pt - np_tf))
        print(f"Max absolute difference between models outputs {diff}")
        # Explicit raise instead of `assert` so the check survives `python -O`; `not <=` also rejects NaN.
        if not diff <= 2e-2:
            raise AssertionError(f"Error, model absolute difference is >2e-2: {diff}")

    # Save the converted weights.
    print(f"Save TensorFlow model to {tf_dump_path}")
    tf_model.save_weights(tf_dump_path, save_format="h5")
def convert_all_pt_checkpoints_to_tf(
    args_model_type,
    tf_dump_path,
    model_shortcut_names_or_path=None,
    config_shortcut_names_or_path=None,
    compare_with_pt_model=False,
    use_cached_models=False,
    remove_cached_files=False,
    only_convert_finetuned_models=False,
):
    """
    Convert one or several PyTorch checkpoints to TensorFlow 2 `<name>-tf_model.h5` files.

    Args:
        args_model_type: A `MODEL_CLASSES` key, or None to iterate over every key.
        tf_dump_path: Directory in which the `.h5` files are written.
        model_shortcut_names_or_path: Checkpoint shortcut names or local paths. Defaults, separately for each
            model type, to every name in that model type's archive collection.
        config_shortcut_names_or_path: Config shortcut names or local paths, paired with the checkpoints by
            position. Defaults to the checkpoint names.
        compare_with_pt_model: Forwarded to `convert_pt_checkpoint_to_tf`.
        use_cached_models: If False, shortcut-name downloads are forced.
        remove_cached_files: Delete the config/weights files downloaded with `cached_file` after each conversion.
            Files given as local paths are never deleted.
        only_convert_finetuned_models: Convert only checkpoints whose name contains "-squad", "-mrpc" or "-mnli";
            when False, exactly those checkpoints are skipped instead.

    Raises:
        ValueError: If a model type is not a `MODEL_CLASSES` key.
    """
    model_types = list(MODEL_CLASSES.keys()) if args_model_type is None else [args_model_type]

    for j, model_type in enumerate(model_types, start=1):
        print("=" * 100)
        print(f" Converting model type {j}/{len(model_types)}: {model_type}")
        print("=" * 100)
        if model_type not in MODEL_CLASSES:
            raise ValueError(f"Unrecognized model type {model_type}, should be one of {list(MODEL_CLASSES.keys())}.")

        # Only the archive collection (last item of every entry) is needed here. It is used for both config and
        # weights shortcut names, as in `convert_pt_checkpoint_to_tf`. Unpacking a fixed number of items failed
        # for every entry because the tuples have different lengths.
        aws_config_map = MODEL_CLASSES[model_type][-1]

        # Per-model-type defaults kept in locals: overwriting the arguments made every later model type reuse
        # the shortcut names of the first one.
        model_names = list(aws_config_map) if model_shortcut_names_or_path is None else model_shortcut_names_or_path
        config_names = model_names if config_shortcut_names_or_path is None else config_shortcut_names_or_path

        for i, (model_shortcut_name, config_shortcut_name) in enumerate(zip(model_names, config_names), start=1):
            print("-" * 100)
            checkpoint_model_type = model_type
            if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name:
                if not only_convert_finetuned_models:
                    print(f"    Skipping finetuned checkpoint {model_shortcut_name}")
                    continue
                # Fine-tuned checkpoints have their own MODEL_CLASSES entry keyed by checkpoint name.
                checkpoint_model_type = model_shortcut_name
            elif only_convert_finetuned_models:
                print(f"    Skipping not finetuned checkpoint {model_shortcut_name}")
                continue
            print(
                f"    Converting checkpoint {i}/{len(model_names)}: {model_shortcut_name} - model_type {checkpoint_model_type}"
            )
            print("-" * 100)

            # Track what was fetched via `cached_file` so only those files are removed afterwards.
            downloaded_files = []
            if config_shortcut_name in aws_config_map:
                config_file = cached_file(config_shortcut_name, CONFIG_NAME, force_download=not use_cached_models)
                downloaded_files.append(config_file)
            else:
                config_file = config_shortcut_name

            if model_shortcut_name in aws_config_map:
                model_file = cached_file(model_shortcut_name, WEIGHTS_NAME, force_download=not use_cached_models)
                downloaded_files.append(model_file)
            else:
                model_file = model_shortcut_name

            # A local file path cannot be used as an output file name prefix.
            if os.path.isfile(model_shortcut_name):
                model_shortcut_name = "converted_model"

            convert_pt_checkpoint_to_tf(
                model_type=checkpoint_model_type,
                pytorch_checkpoint_path=model_file,
                config_file=config_file,
                tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + "-tf_model.h5"),
                compare_with_pt_model=compare_with_pt_model,
            )
            if remove_cached_files:
                for cached_path in downloaded_files:
                    os.remove(cached_path)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file."
    )
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        help=(
            f"Model type selected in the list of {list(MODEL_CLASSES.keys())}. If not given, will download and "
            "convert all the models from AWS."
        ),
    )
    parser.add_argument(
        "--pytorch_checkpoint_path",
        default=None,
        type=str,
        help=(
            "Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
            "If not given, will download and convert all the checkpoints from AWS."
        ),
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        help=(
            "The config json file corresponding to the pre-trained model. \n"
            "This specifies the model architecture. If not given and "
            "--pytorch_checkpoint_path is not given or is a shortcut name "
            "use the configuration associated to the shortcut name on the AWS"
        ),
    )
    parser.add_argument(
        "--compare_with_pt_model", action="store_true", help="Compare Tensorflow and PyTorch model predictions."
    )
    parser.add_argument(
        "--use_cached_models",
        action="store_true",
        help="Use cached models if possible instead of updating to latest checkpoint versions.",
    )
    parser.add_argument(
        "--remove_cached_files",
        action="store_true",
        help="Remove pytorch models after conversion (save memory when converting in batches).",
    )
    parser.add_argument("--only_convert_finetuned_models", action="store_true", help="Only convert finetuned models.")
    args = parser.parse_args()

    # Some MODEL_CLASSES keys contain upper-case letters (e.g. "Salesforce/ctrl"), which unconditional
    # lower-casing made unreachable. Keep an exact match as-is and lower-case only otherwise.
    model_type = args.model_type
    if model_type is not None and model_type not in MODEL_CLASSES:
        model_type = model_type.lower()

    convert_all_pt_checkpoints_to_tf(
        model_type,
        args.tf_dump_path,
        model_shortcut_names_or_path=[args.pytorch_checkpoint_path] if args.pytorch_checkpoint_path is not None else None,
        config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
        compare_with_pt_model=args.compare_with_pt_model,
        use_cached_models=args.use_cached_models,
        remove_cached_files=args.remove_cached_files,
        only_convert_finetuned_models=args.only_convert_finetuned_models,
    )
.\convert_slow_tokenizer.py
"""
Utilities to convert slow tokenizers in their fast tokenizers counterparts.
All the conversions are grouped here to gather SentencePiece dependencies outside of the fast tokenizers files and
allow to make our dependency on SentencePiece optional.
"""
import warnings
from typing import Dict, List, Tuple
from packaging import version
from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE, Unigram, WordPiece
from .utils import is_protobuf_available, requires_backends
from .utils.import_utils import PROTOBUF_IMPORT_ERROR
def import_protobuf(error_message=""):
    """
    Return the generated `sentencepiece_model_pb2` module that matches the installed protobuf version.

    protobuf < 4.0.0 gets `transformers.utils.sentencepiece_model_pb2`; newer versions get
    `transformers.utils.sentencepiece_model_pb2_new`.

    Raises:
        ImportError: If protobuf is not available (`error_message` is formatted into the message).
    """
    if not is_protobuf_available():
        raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))

    import google.protobuf

    if version.parse(google.protobuf.__version__) >= version.parse("4.0.0"):
        from transformers.utils import sentencepiece_model_pb2_new as sentencepiece_model_pb2
    else:
        from transformers.utils import sentencepiece_model_pb2
    return sentencepiece_model_pb2
class SentencePieceExtractor:
    """
    Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece
    """

    def __init__(self, model: str):
        requires_backends(self, "sentencepiece")
        from sentencepiece import SentencePieceProcessor

        processor = SentencePieceProcessor()
        self.sp = processor
        processor.Load(model)

    def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
        """
        Return `(vocab, merges)` for the loaded SentencePiece model.

        `vocab` maps every piece to its id. A merge `(left, right)` is produced for every split of a piece into
        two pieces that are both in the vocabulary. By default merges are ordered by the id of the merged piece
        (ascending). When `vocab_scores` (`(piece, score)` pairs) is given, they are ordered by that score
        (descending) instead. Splits of the same piece are ordered by the ids of their halves.
        """
        processor = self.sp
        vocab = {processor.id_to_piece(idx): idx for idx in range(processor.GetPieceSize())}

        use_scores = vocab_scores is not None
        ranking = dict(vocab_scores) if use_scores else vocab

        candidates = []
        for piece, rank in ranking.items():
            splits = [
                (piece[:cut], piece[cut:], rank)
                for cut in range(1, len(piece))
                if piece[:cut] in vocab and piece[cut:] in vocab
            ]
            splits.sort(key=lambda triple: (vocab[triple[0]], vocab[triple[1]]))
            candidates += splits

        # Stable sort: equal ranks keep the order in which they were collected.
        candidates.sort(key=lambda triple: triple[2], reverse=use_scores)
        return vocab, [(left, right) for left, right, _ in candidates]
class GemmaSentencePieceExtractor(SentencePieceExtractor):
    def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
        """
        Like the base extractor, but the byte piece "<0x09>" is renamed to a literal tab character in the
        vocabulary. It keeps its id and moves to the end of the dict, and the tab piece is what merges see.

        Merges are ordered by piece id (ascending) by default, or by the given `vocab_scores` (descending).
        """
        processor = self.sp
        vocab = {processor.id_to_piece(idx): idx for idx in range(processor.GetPieceSize())}
        vocab["\t"] = vocab.pop("<0x09>")

        use_scores = vocab_scores is not None
        ranking = dict(vocab_scores) if use_scores else vocab

        candidates = []
        for piece, rank in ranking.items():
            splits = [
                (piece[:cut], piece[cut:], rank)
                for cut in range(1, len(piece))
                if piece[:cut] in vocab and piece[cut:] in vocab
            ]
            splits.sort(key=lambda triple: (vocab[triple[0]], vocab[triple[1]]))
            candidates += splits

        candidates.sort(key=lambda triple: triple[2], reverse=use_scores)
        return vocab, [(left, right) for left, right, _ in candidates]
def check_number_comma(piece: str) -> bool:
    """Return False only when `piece` has at least two characters and ends with a digit followed by a comma."""
    ends_with_digit_comma = len(piece) >= 2 and piece.endswith(",") and piece[-2].isdigit()
    return not ends_with_digit_comma
class Converter:
    """Base class of the slow-to-fast tokenizer converters; subclasses implement `converted`."""

    def __init__(self, original_tokenizer):
        # The slow tokenizer whose settings are translated into a fast tokenizer.
        self.original_tokenizer = original_tokenizer

    def converted(self) -> Tokenizer:
        """Return the fast-tokenizer equivalent of `original_tokenizer`. Must be overridden."""
        raise NotImplementedError
class BertConverter(Converter):
    def converted(self) -> Tokenizer:
        """
        Build a WordPiece fast tokenizer mirroring a slow BERT tokenizer.

        The output uses "[CLS] A [SEP]" for single sequences and "[CLS] A [SEP] B [SEP]" for pairs, with
        B-side tokens getting type id 1.
        """
        slow = self.original_tokenizer
        tokenizer = Tokenizer(WordPiece(slow.vocab, unk_token=str(slow.unk_token)))

        # Normalization options come from the slow tokenizer's basic tokenizer when it has one.
        if hasattr(slow, "basic_tokenizer"):
            basic = slow.basic_tokenizer
            chinese_chars, accents, lowercase = basic.tokenize_chinese_chars, basic.strip_accents, basic.do_lower_case
        else:
            chinese_chars = accents = lowercase = False

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=chinese_chars,
            strip_accents=accents,
            lowercase=lowercase,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls, sep = str(slow.cls_token), str(slow.sep_token)
        cls_id, sep_id = slow.cls_token_id, slow.sep_token_id
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=[(cls, cls_id), (sep, sep_id)],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")
        return tokenizer
class SplinterConverter(Converter):
    def converted(self) -> Tokenizer:
        """
        Build a WordPiece fast tokenizer mirroring a slow Splinter tokenizer.

        Pair templates also contain the question token followed by "." (no explicit type id). They are placed
        after sequence A when `padding_side == "right"` and after sequence B otherwise.
        """
        slow = self.original_tokenizer
        tokenizer = Tokenizer(WordPiece(slow.vocab, unk_token=str(slow.unk_token)))

        if hasattr(slow, "basic_tokenizer"):
            basic = slow.basic_tokenizer
            chinese_chars, accents, lowercase = basic.tokenize_chinese_chars, basic.strip_accents, basic.do_lower_case
        else:
            chinese_chars = accents = lowercase = False

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=chinese_chars,
            strip_accents=accents,
            lowercase=lowercase,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls, sep = str(slow.cls_token), str(slow.sep_token)
        question = str(slow.question_token)
        dot = "."
        special_tokens = [
            (cls, slow.cls_token_id),
            (sep, slow.sep_token_id),
            (question, slow.question_token_id),
            (dot, slow.convert_tokens_to_ids(".")),
        ]

        if slow.padding_side == "right":
            pair = f"{cls}:0 $A:0 {question} {dot} {sep}:0 $B:1 {sep}:1"
        else:
            pair = f"{cls}:0 $A:0 {sep}:0 $B:1 {question} {dot} {sep}:1"

        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=pair,
            special_tokens=special_tokens,
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")
        return tokenizer
class FunnelConverter(Converter):
    """
    Convert a slow Funnel Transformer tokenizer to a WordPiece fast tokenizer.

    The class header was missing here, which left `converted` either as a stray top-level function or as a
    second, overriding `converted` on the preceding converter class.
    """

    def converted(self) -> Tokenizer:
        """
        Build the fast tokenizer. The leading CLS token gets token type id 2 (per the templates below), while
        sequence A and its SEP get 0 and sequence B and its SEP get 1.
        """
        vocab = self.original_tokenizer.vocab
        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))

        # Normalization options come from the slow tokenizer's basic tokenizer when it has one.
        tokenize_chinese_chars = False
        strip_accents = False
        do_lower_case = False
        if hasattr(self.original_tokenizer, "basic_tokenizer"):
            tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
            strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
            do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls = str(self.original_tokenizer.cls_token)
        sep = str(self.original_tokenizer.sep_token)
        cls_token_id = self.original_tokenizer.cls_token_id
        sep_token_id = self.original_tokenizer.sep_token_id

        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:2 $A:0 {sep}:0",
            pair=f"{cls}:2 $A:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=[
                (cls, cls_token_id),
                (sep, sep_token_id),
            ],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        return tokenizer
class MPNetConverter(Converter):
    def converted(self) -> Tokenizer:
        """
        Build a WordPiece fast tokenizer mirroring a slow MPNet tokenizer.

        Pairs are joined with two consecutive separators: "<cls> A <sep> <sep> B <sep>".
        """
        slow = self.original_tokenizer
        tokenizer = Tokenizer(WordPiece(slow.vocab, unk_token=str(slow.unk_token)))

        if hasattr(slow, "basic_tokenizer"):
            basic = slow.basic_tokenizer
            chinese_chars, accents, lowercase = basic.tokenize_chinese_chars, basic.strip_accents, basic.do_lower_case
        else:
            chinese_chars = accents = lowercase = False

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=chinese_chars,
            strip_accents=accents,
            lowercase=lowercase,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls, sep = str(slow.cls_token), str(slow.sep_token)
        cls_id, sep_id = slow.cls_token_id, slow.sep_token_id
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=f"{cls}:0 $A:0 {sep}:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=[(cls, cls_id), (sep, sep_id)],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")
        return tokenizer
class OpenAIGPTConverter(Converter):
    def converted(self) -> Tokenizer:
        """
        Build a BPE fast tokenizer (end-of-word suffix "</w>") for a slow OpenAI GPT tokenizer.

        Input is lower-cased with the BERT normalizer and split with the BERT pre-tokenizer.
        """
        slow = self.original_tokenizer
        vocab = slow.encoder
        merges = list(slow.bpe_ranks.keys())
        unk = str(slow.unk_token)

        bpe_model = BPE(
            vocab=vocab,
            merges=merges,
            dropout=None,
            unk_token=unk,
            end_of_word_suffix="</w>",
            fuse_unk=False,
        )
        tokenizer = Tokenizer(bpe_model)

        # Register the unknown token as special only when the vocabulary actually contains it.
        if tokenizer.token_to_id(unk) is not None:
            tokenizer.add_special_tokens([unk])

        tokenizer.normalizer = normalizers.BertNormalizer(lowercase=True)
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
        tokenizer.decoder = decoders.BPEDecoder(suffix="</w>")
        return tokenizer
class GPT2Converter(Converter):
    def converted(self) -> Tokenizer:
        """
        Build a byte-level BPE fast tokenizer for a slow GPT-2 tokenizer.

        When the slow tokenizer has `add_bos_token` set, the BOS token is prepended through a template.
        Otherwise a ByteLevel post-processor (without offset trimming) is used.
        """
        slow = self.original_tokenizer
        tokenizer = Tokenizer(
            BPE(
                vocab=slow.encoder,
                merges=list(slow.bpe_ranks.keys()),
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=slow.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()

        if not slow.add_bos_token:
            tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
            return tokenizer

        bos = slow.bos_token
        bos_id = slow.bos_token_id
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{bos}:0 $A:0",
            pair=f"{bos}:0 $A:0 $B:1",
            special_tokens=[(bos, bos_id)],
        )
        return tokenizer
class HerbertConverter(Converter):
    def converted(self) -> Tokenizer:
        """
        Build a BPE fast tokenizer (end-of-word suffix "</w>") for a slow HerBERT tokenizer, with BERT-style
        normalization (no lower-casing, accents kept) and BERT-style [CLS]/[SEP] post-processing.

        A first merges entry containing "#version:" (presumably a merges-file header) is dropped.
        """
        tokenizer_info_str = "#version:"
        token_suffix = "</w>"

        vocab = self.original_tokenizer.encoder
        merges = list(self.original_tokenizer.bpe_ranks.keys())
        # Check `merges` first: indexing merges[0] raised IndexError on an empty merges table.
        if merges and tokenizer_info_str in merges[0][0]:
            merges = merges[1:]

        tokenizer = Tokenizer(
            BPE(
                vocab,
                merges,
                dropout=None,
                unk_token=self.original_tokenizer.unk_token,
                end_of_word_suffix=token_suffix,
            )
        )

        tokenizer.normalizer = normalizers.BertNormalizer(lowercase=False, strip_accents=False)
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
        tokenizer.decoder = decoders.BPEDecoder(suffix=token_suffix)
        tokenizer.post_processor = processors.BertProcessing(
            sep=(self.original_tokenizer.sep_token, self.original_tokenizer.sep_token_id),
            cls=(self.original_tokenizer.cls_token, self.original_tokenizer.cls_token_id),
        )

        return tokenizer
class Qwen2Converter(Converter):
    def converted(self) -> Tokenizer:
        """
        Build a byte-level BPE fast tokenizer for a slow Qwen2 tokenizer.

        Text is NFC-normalized, split by a regex into isolated pieces, then byte-level encoded with the
        ByteLevel pre-tokenizer's own regex disabled. No unknown token and no byte fallback are configured.
        """
        slow = self.original_tokenizer
        bpe_model = BPE(
            vocab=slow.encoder,
            merges=list(slow.bpe_ranks.keys()),
            dropout=None,
            unk_token=None,
            continuing_subword_prefix="",
            end_of_word_suffix="",
            fuse_unk=False,
            byte_fallback=False,
        )
        tokenizer = Tokenizer(bpe_model)
        tokenizer.normalizer = normalizers.NFC()

        split_regex = Regex(
            r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
        )
        split_step = pre_tokenizers.Split(split_regex, behavior="isolated", invert=False)
        byte_step = pre_tokenizers.ByteLevel(
            add_prefix_space=getattr(slow, "add_prefix_space", False),
            use_regex=False,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence([split_step, byte_step])

        tokenizer.decoder = decoders.ByteLevel()
        tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
        return tokenizer
class RobertaConverter(Converter):
    """
    Convert a slow RoBERTa tokenizer to a byte-level BPE fast tokenizer with RoBERTa post-processing.

    The class header was missing here, which left `converted` either as a stray top-level function or as a
    second, overriding `converted` on the preceding converter class.
    """

    def converted(self) -> Tokenizer:
        """Build the fast tokenizer; special tokens are added by `RobertaProcessing` with offset trimming on."""
        ot = self.original_tokenizer
        vocab = ot.encoder
        merges = list(ot.bpe_ranks.keys())

        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,
                merges=merges,
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()
        tokenizer.post_processor = processors.RobertaProcessing(
            sep=(ot.sep_token, ot.sep_token_id),
            cls=(ot.cls_token, ot.cls_token_id),
            add_prefix_space=ot.add_prefix_space,
            trim_offsets=True,  # True by default on Roberta (historical)
        )

        return tokenizer
class RoFormerConverter(Converter):
    def converted(self) -> Tokenizer:
        """
        Build a WordPiece fast tokenizer for a slow RoFormer tokenizer.

        Chinese-character handling is disabled in the normalizer. Word splitting is delegated to the custom
        `JiebaPreTokenizer` built from the vocabulary.
        """
        from .models.roformer.tokenization_utils import JiebaPreTokenizer

        slow = self.original_tokenizer
        vocab = slow.vocab
        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(slow.unk_token)))

        if hasattr(slow, "basic_tokenizer"):
            accents = slow.basic_tokenizer.strip_accents
            lowercase = slow.basic_tokenizer.do_lower_case
        else:
            accents = lowercase = False

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=False,
            strip_accents=accents,
            lowercase=lowercase,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(JiebaPreTokenizer(vocab))

        cls, sep = str(slow.cls_token), str(slow.sep_token)
        cls_id, sep_id = slow.cls_token_id, slow.sep_token_id
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=[(cls, cls_id), (sep, sep_id)],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")
        return tokenizer
class DebertaConverter(Converter):
    def converted(self) -> Tokenizer:
        """
        Build a byte-level BPE fast tokenizer for a slow DeBERTa tokenizer.

        It uses literal "[CLS]" / "[SEP]" templates whose ids are looked up in the slow tokenizer.
        """
        slow = self.original_tokenizer
        bpe_model = BPE(
            vocab=slow.encoder,
            merges=list(slow.bpe_ranks.keys()),
            dropout=None,
            continuing_subword_prefix="",
            end_of_word_suffix="",
            fuse_unk=False,
        )
        tokenizer = Tokenizer(bpe_model)
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=slow.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()

        cls_id = slow.convert_tokens_to_ids("[CLS]")
        sep_id = slow.convert_tokens_to_ids("[SEP]")
        tokenizer.post_processor = processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[("[CLS]", cls_id), ("[SEP]", sep_id)],
        )
        return tokenizer
class SpmConverter(Converter):
    """
    Base converter for slow tokenizers whose vocabulary is a SentencePiece model file
    (`original_tokenizer.vocab_file`).

    Subclasses customize the conversion by overriding `vocab`, `unk_id`, `tokenizer`, `normalizer`,
    `pre_tokenizer`, `post_processor` and `decoder`.
    """

    def __init__(self, *args):
        requires_backends(self, "protobuf")

        super().__init__(*args)

        model_pb2 = import_protobuf()

        # Deserialize the SentencePiece model proto. `m` must be created before parsing into it; without this
        # line `__init__` raised NameError for every SentencePiece-based converter.
        m = model_pb2.ModelProto()
        with open(self.original_tokenizer.vocab_file, "rb") as f:
            m.ParseFromString(f.read())
        self.proto = m

        if self.proto.trainer_spec.byte_fallback:
            # A truthy `handle_byte_fallback` attribute (presumably set by subclasses that support byte fallback)
            # silences this warning.
            if not getattr(self, "handle_byte_fallback", None):
                warnings.warn(
                    "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
                    " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
                    " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
                    "unknown tokens into a sequence of byte tokens matching the original piece of text."
                )

    def vocab(self, proto):
        """Return the `(piece, score)` pairs stored in the proto, in stored order."""
        return [(piece.piece, piece.score) for piece in proto.pieces]

    def unk_id(self, proto):
        """Return the unknown-token id recorded in the proto's trainer spec."""
        return proto.trainer_spec.unk_id

    def tokenizer(self, proto):
        """
        Build the core model from the proto.

        `trainer_spec.model_type == 1` gives a `Unigram` model. `trainer_spec.model_type == 2` gives a `BPE`
        model whose merges are extracted from the SentencePiece file.

        Raises:
            Exception: For any other model type.
        """
        model_type = proto.trainer_spec.model_type
        vocab_scores = self.vocab(proto)
        unk_id = self.unk_id(proto)

        if model_type == 1:
            tokenizer = Tokenizer(Unigram(vocab_scores, unk_id))
        elif model_type == 2:
            _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
            bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
            tokenizer = Tokenizer(
                BPE(
                    bpe_vocab,
                    merges,
                    unk_token=proto.trainer_spec.unk_piece,
                    fuse_unk=True,
                )
            )
        else:
            raise Exception(
                "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
            )

        return tokenizer

    def normalizer(self, proto):
        """
        Apply the proto's precompiled charsmap (when present), then strip trailing whitespace and replace runs
        of two or more spaces with "▁".
        """
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        _normalizers = [
            normalizers.Strip(left=False, right=True),
            normalizers.Replace(Regex(" {2,}"), "▁"),
        ]
        if not precompiled_charsmap:
            return normalizers.Sequence(_normalizers)
        else:
            return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)

    def pre_tokenizer(self, replacement, add_prefix_space):
        """Metaspace pre-tokenizer; prepend "first" when the slow tokenizer has a falsy `legacy`, else "always"."""
        prepend_scheme = "always"
        if hasattr(self.original_tokenizer, "legacy") and not self.original_tokenizer.legacy:
            prepend_scheme = "first"
        return pre_tokenizers.Metaspace(
            replacement=replacement, add_prefix_space=add_prefix_space, prepend_scheme=prepend_scheme
        )

    def post_processor(self):
        """No post-processing by default; subclasses return a processor to add special tokens."""
        return None

    def decoder(self, replacement, add_prefix_space):
        """Metaspace decoder matching the pre-tokenizer."""
        return decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)

    def converted(self) -> Tokenizer:
        """Assemble the fast tokenizer from the overridable component builders above."""
        tokenizer = self.tokenizer(self.proto)

        normalizer = self.normalizer(self.proto)
        if normalizer is not None:
            tokenizer.normalizer = normalizer

        replacement = "▁"
        add_prefix_space = True
        if hasattr(self.original_tokenizer, "add_prefix_space"):
            add_prefix_space = self.original_tokenizer.add_prefix_space

        pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
        if pre_tokenizer is not None:
            tokenizer.pre_tokenizer = pre_tokenizer

        tokenizer.decoder = self.decoder(replacement, add_prefix_space)
        post_processor = self.post_processor()
        if post_processor:
            tokenizer.post_processor = post_processor

        return tokenizer
class AlbertConverter(SpmConverter):
    """SentencePiece converter for ALBERT-style tokenizers."""

    def vocab(self, proto):
        # Pieces for which check_number_comma() is falsy get their score lowered by 100.
        return [
            (piece.piece, piece.score if check_number_comma(piece.piece) else piece.score - 100)
            for piece in proto.pieces
        ]

    def normalizer(self, proto):
        """
        Build the normalizer, applying steps in this order:

        1. replace `` and '' with a double quote;
        2. NFKD + accent stripping, unless ``keep_accents`` is set;
        3. lowercasing, if ``do_lower_case`` is set;
        4. the precompiled charsmap, if any;
        5. collapse runs of spaces into one space.
        """
        slow = self.original_tokenizer
        steps = [normalizers.Replace("``", '"'), normalizers.Replace("''", '"')]
        if not slow.keep_accents:
            steps += [normalizers.NFKD(), normalizers.StripAccents()]
        if slow.do_lower_case:
            steps.append(normalizers.Lowercase())
        charsmap = proto.normalizer_spec.precompiled_charsmap
        if charsmap:
            steps.append(normalizers.Precompiled(charsmap))
        steps.append(normalizers.Replace(Regex(" {2,}"), " "))
        return normalizers.Sequence(steps)

    def post_processor(self):
        """BERT-style template: ``[CLS] A [SEP]`` / ``[CLS] A [SEP] B [SEP]`` with type ids 0 and 1."""
        to_id = self.original_tokenizer.convert_tokens_to_ids
        cls_id = to_id("[CLS]")
        sep_id = to_id("[SEP]")
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[("[CLS]", cls_id), ("[SEP]", sep_id)],
        )
class BarthezConverter(SpmConverter):
    """SentencePiece converter for BARThez, using a fixed unknown-token id and ``<s> … </s>`` templates."""

    def unk_id(self, proto):
        return 3

    def post_processor(self):
        lookup = self.original_tokenizer.convert_tokens_to_ids
        bos_id = lookup("<s>")
        eos_id = lookup("</s>")
        return processors.TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> </s> $B </s>",
            special_tokens=[("<s>", bos_id), ("</s>", eos_id)],
        )
class CamembertConverter(SpmConverter):
    """SentencePiece converter for CamemBERT."""

    def vocab(self, proto):
        """
        Build the vocabulary in three parts:

        - five fixed leading entries;
        - the SentencePiece pieces from index 1 on;
        - a trailing ``<mask>``.
        """
        # "<unk>NOTUSED" is scored -100 — presumably so it is never preferred (TODO confirm intent).
        leading = [
            ("<s>NOTUSED", 0.0),
            ("<pad>", 0.0),
            ("</s>NOTUSED", 0.0),
            ("<unk>", 0.0),
            ("<unk>NOTUSED", -100),
        ]
        body = [(piece.piece, piece.score) for piece in proto.pieces[1:]]
        return leading + body + [("<mask>", 0.0)]

    def unk_id(self, proto):
        return 3

    def post_processor(self):
        lookup = self.original_tokenizer.convert_tokens_to_ids
        bos_id = lookup("<s>")
        eos_id = lookup("</s>")
        return processors.TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> </s> $B </s>",
            special_tokens=[("<s>", bos_id), ("</s>", eos_id)],
        )
class DebertaV2Converter(SpmConverter):
    """
    SentencePiece converter for DeBERTa-v2 tokenizers.

    - Optionally isolates punctuation before the Metaspace step (``split_by_punct``).
    - Applies an optional lowercase + strip normalizer.
    - Wraps sequences in ``[CLS]``/``[SEP]``.
    """

    # NOTE: these methods previously had no class line of their own, so they were attached to the preceding
    # CamembertConverter (replacing its <s>/</s> post-processor), and the DebertaV2Converter name used by
    # SLOW_TO_FAST_CONVERTERS was undefined.

    def pre_tokenizer(self, replacement, add_prefix_space):
        list_pretokenizers = []
        if self.original_tokenizer.split_by_punct:
            # Punctuation becomes its own pre-token before Metaspace handling.
            list_pretokenizers.append(pre_tokenizers.Punctuation(behavior="isolated"))
        list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space))
        return pre_tokenizers.Sequence(list_pretokenizers)

    def normalizer(self, proto):
        """
        Build the normalizer, applying steps in this order:

        1. optional lowercasing (``do_lower_case``);
        2. strip;
        3. the precompiled charsmap, if any;
        4. collapse runs of spaces into one space.
        """
        list_normalizers = []
        if self.original_tokenizer.do_lower_case:
            list_normalizers.append(normalizers.Lowercase())
        list_normalizers.append(normalizers.Strip())
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        if precompiled_charsmap:
            list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
        list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
        return normalizers.Sequence(list_normalizers)

    def post_processor(self):
        """BERT-style template: ``[CLS] A [SEP]`` / ``[CLS] A [SEP] B [SEP]`` with type ids 0 and 1."""
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
                ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
            ],
        )
class MBartConverter(SpmConverter):
    """SentencePiece converter for mBART (25 language codes appended to the vocabulary)."""

    def vocab(self, proto):
        """
        Build the vocabulary in four parts:

        - four fixed specials;
        - the pieces from index 3 on;
        - the language codes;
        - ``<mask>``.
        """
        language_codes = (
            "ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX", "gu_IN",
            "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV", "my_MM", "ne_NP",
            "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN", "zh_CN",
        )
        entries = [("<s>", 0.0), ("<pad>", 0.0), ("</s>", 0.0), ("<unk>", 0.0)]
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[3:])
        entries.extend((code, 0.0) for code in language_codes)
        entries.append(("<mask>", 0.0))
        return entries

    def unk_id(self, proto):
        return 3

    def post_processor(self):
        # The language code (en_XX) follows </s> at the end of the sequence.
        lookup = self.original_tokenizer.convert_tokens_to_ids
        lang_id = lookup("en_XX")
        eos_id = lookup("</s>")
        return processors.TemplateProcessing(
            single="$A </s> en_XX",
            pair="$A $B </s> en_XX",
            special_tokens=[("en_XX", lang_id), ("</s>", eos_id)],
        )
class MBart50Converter(SpmConverter):
    """SentencePiece converter for mBART-50 (52 language codes appended to the vocabulary)."""

    def vocab(self, proto):
        """
        Build the vocabulary in four parts:

        - four fixed specials;
        - the pieces from index 3 on;
        - the language codes;
        - ``<mask>``.
        """
        language_codes = (
            "ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI",
            "fr_XX", "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR",
            "lt_LT", "lv_LV", "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU",
            "si_LK", "tr_TR", "vi_VN", "zh_CN", "af_ZA", "az_AZ", "bn_IN",
            "fa_IR", "he_IL", "hr_HR", "id_ID", "ka_GE", "km_KH", "mk_MK",
            "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF", "pt_XX", "sv_SE",
            "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA", "ur_PK",
            "xh_ZA", "gl_ES", "sl_SI",
        )
        entries = [("<s>", 0.0), ("<pad>", 0.0), ("</s>", 0.0), ("<unk>", 0.0)]
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[3:])
        entries.extend((code, 0.0) for code in language_codes)
        entries.append(("<mask>", 0.0))
        return entries

    def unk_id(self, proto):
        return 3

    def post_processor(self):
        # The language code (en_XX) leads the sequence and </s> ends it.
        lookup = self.original_tokenizer.convert_tokens_to_ids
        lang_id = lookup("en_XX")
        eos_id = lookup("</s>")
        return processors.TemplateProcessing(
            single="en_XX $A </s>",
            pair="en_XX $A $B </s>",
            special_tokens=[("en_XX", lang_id), ("</s>", eos_id)],
        )
class NllbConverter(SpmConverter):
    """SentencePiece converter for NLLB; sequences are tagged with ``eng_Latn`` and closed with ``</s>``."""

    def vocab(self, proto):
        head = [("<s>", 0.0), ("<pad>", 0.0), ("</s>", 0.0), ("<unk>", 0.0)]
        return head + [(piece.piece, piece.score) for piece in proto.pieces[3:]]

    def unk_id(self, proto):
        return 3

    def post_processor(self):
        lookup = self.original_tokenizer.convert_tokens_to_ids
        lang_id = lookup("eng_Latn")
        eos_id = lookup("</s>")
        return processors.TemplateProcessing(
            single="eng_Latn $A </s>",
            pair="eng_Latn $A $B </s>",
            special_tokens=[("eng_Latn", lang_id), ("</s>", eos_id)],
        )
class SeamlessM4TConverter(SpmConverter):
    """SentencePiece converter for SeamlessM4T; the unknown-token id comes from the slow tokenizer."""

    def vocab(self, proto):
        head = [("<pad>", 0.0), ("<unk>", 0.0), ("<s>", 0.0), ("</s>", 0.0)]
        return head + [(piece.piece, piece.score) for piece in proto.pieces[3:]]

    def unk_id(self, proto):
        return self.original_tokenizer.unk_token_id

    def post_processor(self):
        lookup = self.original_tokenizer.convert_tokens_to_ids
        lang_id = lookup("__eng__")
        eos_id = lookup("</s>")
        return processors.TemplateProcessing(
            single="__eng__ $A </s>",
            pair="__eng__ $A $B </s>",
            special_tokens=[("__eng__", lang_id), ("</s>", eos_id)],
        )
class XLMRobertaConverter(SpmConverter):
    """SentencePiece converter for XLM-RoBERTa: fixed specials, trailing ``<mask>`` and ``<s> … </s>`` templates."""

    def vocab(self, proto):
        head = [("<s>", 0.0), ("<pad>", 0.0), ("</s>", 0.0), ("<unk>", 0.0)]
        body = [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return head + body + [("<mask>", 0.0)]

    def unk_id(self, proto):
        return 3

    def post_processor(self):
        lookup = self.original_tokenizer.convert_tokens_to_ids
        bos_id = lookup("<s>")
        eos_id = lookup("</s>")
        return processors.TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> </s> $B </s>",
            special_tokens=[("<s>", bos_id), ("</s>", eos_id)],
        )
class XLNetConverter(SpmConverter):
    """SentencePiece converter for XLNet; special tokens are placed at the end (``<sep>`` then ``<cls>``)."""

    def vocab(self, proto):
        # Pieces for which check_number_comma() is falsy get their score lowered by 100.
        scored = []
        for piece in proto.pieces:
            score = piece.score if check_number_comma(piece.piece) else piece.score - 100
            scored.append((piece.piece, score))
        return scored

    def normalizer(self, proto):
        """
        Build the normalizer, applying steps in this order:

        1. replace `` and '' with a double quote;
        2. NFKD + accent stripping, unless ``keep_accents`` is set;
        3. lowercasing, if ``do_lower_case`` is set;
        4. the precompiled charsmap, if any;
        5. collapse runs of spaces into one space.
        """
        slow = self.original_tokenizer
        steps = [normalizers.Replace("``", '"'), normalizers.Replace("''", '"')]
        if not slow.keep_accents:
            steps += [normalizers.NFKD(), normalizers.StripAccents()]
        if slow.do_lower_case:
            steps.append(normalizers.Lowercase())
        charsmap = proto.normalizer_spec.precompiled_charsmap
        if charsmap:
            steps.append(normalizers.Precompiled(charsmap))
        steps.append(normalizers.Replace(Regex(" {2,}"), " "))
        return normalizers.Sequence(steps)

    def post_processor(self):
        lookup = self.original_tokenizer.convert_tokens_to_ids
        sep_id = lookup("<sep>")
        cls_id = lookup("<cls>")
        return processors.TemplateProcessing(
            single="$A:0 <sep>:0 <cls>:2",
            pair="$A:0 <sep>:0 $B:1 <sep>:1 <cls>:2",
            special_tokens=[("<sep>", sep_id), ("<cls>", cls_id)],
        )
class ReformerConverter(SpmConverter):
    """Reformer tokenizers use the default SentencePiece conversion without any overrides."""
class RemBertConverter(SpmConverter):
    """SentencePiece converter for RemBERT with BERT-style ``[CLS]``/``[SEP]`` templates."""

    def normalizer(self, proto):
        """
        Build the normalizer, applying steps in this order:

        1. quote replacement and space collapsing;
        2. NFKD + accent stripping, unless ``keep_accents`` is set;
        3. lowercasing, if ``do_lower_case`` is set;
        4. the precompiled charsmap, if any.
        """
        slow = self.original_tokenizer
        steps = [
            normalizers.Replace("``", '"'),
            normalizers.Replace("''", '"'),
            normalizers.Replace(Regex(" {2,}"), " "),
        ]
        if not slow.keep_accents:
            steps += [normalizers.NFKD(), normalizers.StripAccents()]
        if slow.do_lower_case:
            steps.append(normalizers.Lowercase())
        charsmap = proto.normalizer_spec.precompiled_charsmap
        if charsmap:
            steps.append(normalizers.Precompiled(charsmap))
        return normalizers.Sequence(steps)

    def post_processor(self):
        to_id = self.original_tokenizer.convert_tokens_to_ids
        cls_id = to_id("[CLS]")
        sep_id = to_id("[SEP]")
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[("[CLS]", cls_id), ("[SEP]", sep_id)],
        )
class BertGenerationConverter(SpmConverter):
    """BertGeneration tokenizers use the default SentencePiece conversion without any overrides."""
class PegasusConverter(SpmConverter):
    """SentencePiece converter for Pegasus, whose vocabulary is shifted by the slow tokenizer's ``offset``."""

    def vocab(self, proto):
        """
        Build the vocabulary in this order:

        - pad and eos;
        - optional sentence-mask and mask tokens;
        - ``<unk_i>`` fillers (score -100) for i in ``2..offset-1``;
        - the pieces from index 2 on.
        """
        tok = self.original_tokenizer
        entries = [(tok.pad_token, 0.0), (tok.eos_token, 0.0)]
        if tok.mask_token_sent is not None:
            entries.append((tok.mask_token_sent, 0.0))
        # The mask token only gets a slot here when its id falls below the offset.
        if tok.mask_token is not None and tok.mask_token_id < tok.offset:
            entries.append((tok.mask_token, 0.0))
        entries.extend((f"<unk_{i}>", -100.0) for i in range(2, tok.offset))
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[2:])
        return entries

    def unk_id(self, proto):
        return proto.trainer_spec.unk_id + self.original_tokenizer.offset

    def pre_tokenizer(self, replacement, add_prefix_space):
        steps = [
            pre_tokenizers.WhitespaceSplit(),
            pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
        ]
        return pre_tokenizers.Sequence(steps)

    def post_processor(self):
        tok = self.original_tokenizer
        eos_token = tok.eos_token
        eos_id = tok.eos_token_id
        return processors.TemplateProcessing(
            single=["$A", eos_token],
            pair=["$A", "$B", eos_token],
            special_tokens=[(eos_token, eos_id)],
        )
class T5Converter(SpmConverter):
    """
    SentencePiece converter for T5.

    - Appends the ``<extra_id_*>`` sentinel tokens after the SentencePiece pieces.
    - Adds ``</s>`` after each input sequence.
    """

    # The stray `pass` that used to precede these methods was dead code and made the class look empty.

    def vocab(self, proto):
        num_extra_ids = self.original_tokenizer._extra_ids
        vocab = [(piece.piece, piece.score) for piece in proto.pieces]
        # Sentinels are appended in descending order: <extra_id_{n-1}>, ..., <extra_id_0>.
        vocab += [(f"<extra_id_{i}>", 0.0) for i in range(num_extra_ids - 1, -1, -1)]
        return vocab

    def post_processor(self):
        return processors.TemplateProcessing(
            single=["$A", "</s>"],
            pair=["$A", "</s>", "$B", "</s>"],
            special_tokens=[
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
            ],
        )
class UdopConverter(SpmConverter):
    """SentencePiece converter for UDOP; each input sequence is followed by ``</s>``."""

    def post_processor(self):
        eos_id = self.original_tokenizer.convert_tokens_to_ids("</s>")
        return processors.TemplateProcessing(
            single=["$A", "</s>"],
            pair=["$A", "</s>", "$B", "</s>"],
            special_tokens=[("</s>", eos_id)],
        )
class WhisperConverter(Converter):
    """
    Byte-level BPE converter for Whisper tokenizers.

    The post-processor:
    - prepends the slow tokenizer's ``prefix_tokens``;
    - appends its EOS token (for pairs, once after the second sequence).
    """

    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
        bpe = BPE(
            vocab=ot.encoder,
            merges=list(ot.bpe_ranks.keys()),
            dropout=None,
            continuing_subword_prefix="",
            end_of_word_suffix="",
            fuse_unk=False,
        )
        tokenizer = Tokenizer(bpe)
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()

        prefix_token_ids = ot.prefix_tokens
        prefixes = ot.convert_ids_to_tokens(prefix_token_ids)
        eos = ot.eos_token
        eos_token_id = ot.eos_token_id
        prefix_template = " ".join(f"{token}:0" for token in prefixes)

        special_tokens = [(eos, eos_token_id)]
        special_tokens.extend(zip(prefixes, prefix_token_ids))
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{prefix_template} $A:0 {eos}:0",
            pair=f"{prefix_template} $A:0 $B:1 {eos}:1",
            special_tokens=special_tokens,
        )
        return tokenizer
class BigBirdConverter(SpmConverter):
    """SentencePiece converter for BigBird with BERT-style ``[CLS]``/``[SEP]`` templates."""

    def post_processor(self):
        to_id = self.original_tokenizer.convert_tokens_to_ids
        cls_id = to_id("[CLS]")
        sep_id = to_id("[SEP]")
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[("[CLS]", cls_id), ("[SEP]", sep_id)],
        )
class CLIPConverter(Converter):
    """
    BPE converter for CLIP tokenizers.

    - BPE uses a ``</w>`` end-of-word suffix.
    - The normalizer applies NFC, whitespace collapsing and lowercasing.
    - A regex split is followed by byte-level pre-tokenization.
    - Post-processing is RoBERTa-style, with BOS as cls and EOS as sep.
    """

    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
        vocab = ot.encoder
        merges = list(ot.bpe_ranks.keys())
        unk_token = ot.unk_token

        model = BPE(
            vocab=vocab,
            merges=merges,
            dropout=None,
            continuing_subword_prefix="",
            end_of_word_suffix="</w>",
            fuse_unk=False,
            unk_token=str(unk_token),
        )
        tokenizer = Tokenizer(model)

        normalizer_steps = [normalizers.NFC(), normalizers.Replace(Regex(r"\s+"), " "), normalizers.Lowercase()]
        tokenizer.normalizer = normalizers.Sequence(normalizer_steps)

        splitter = pre_tokenizers.Split(
            Regex(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+"""),
            behavior="removed",
            invert=True,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence([splitter, pre_tokenizers.ByteLevel(add_prefix_space=False)])
        tokenizer.decoder = decoders.ByteLevel()

        tokenizer.post_processor = processors.RobertaProcessing(
            sep=(ot.eos_token, ot.eos_token_id),
            cls=(ot.bos_token, ot.bos_token_id),
            add_prefix_space=False,
            trim_offsets=False,
        )
        return tokenizer
class LayoutLMv2Converter(Converter):
    """
    WordPiece converter for LayoutLMv2.

    The normalizer reads its settings from the slow tokenizer's ``basic_tokenizer`` when one is present.
    Otherwise it uses fixed defaults:
    - no Chinese-char handling;
    - no accent stripping;
    - lowercasing.
    """

    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
        tokenizer = Tokenizer(WordPiece(ot.vocab, unk_token=str(ot.unk_token)))

        if hasattr(ot, "basic_tokenizer"):
            basic = ot.basic_tokenizer
            tokenize_chinese_chars = basic.tokenize_chinese_chars
            strip_accents = basic.strip_accents
            do_lower_case = basic.do_lower_case
        else:
            tokenize_chinese_chars, strip_accents, do_lower_case = False, False, True

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls, sep = str(ot.cls_token), str(ot.sep_token)
        cls_token_id, sep_token_id = ot.cls_token_id, ot.sep_token_id
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=[(cls, cls_token_id), (sep, sep_token_id)],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")
        return tokenizer
class BlenderbotConverter(Converter):
    """Byte-level BPE converter for Blenderbot; only a single-sequence template (``A eos``) is defined."""

    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
        model = BPE(
            vocab=ot.encoder,
            merges=list(ot.bpe_ranks.keys()),
            dropout=None,
            continuing_subword_prefix="",
            end_of_word_suffix="",
            fuse_unk=False,
        )
        tokenizer = Tokenizer(model)
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()

        eos = ot.eos_token
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"$A:0 {eos}:0",
            special_tokens=[(eos, ot.eos_token_id)],
        )
        return tokenizer
class XGLMConverter(SpmConverter):
    """SentencePiece converter for XGLM; sequences are prefixed with ``</s>``."""

    def vocab(self, proto):
        entries = [("<s>", 0.0), ("<pad>", 0.0), ("</s>", 0.0), ("<unk>", 0.0)]
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[3:])
        return entries

    def unk_id(self, proto):
        return 3

    def post_processor(self):
        # "<s>" is registered as a special token even though neither template uses it.
        lookup = self.original_tokenizer.convert_tokens_to_ids
        bos_id = lookup("<s>")
        eos_id = lookup("</s>")
        return processors.TemplateProcessing(
            single="</s> $A",
            pair="</s> $A </s> </s> $B",
            special_tokens=[("<s>", bos_id), ("</s>", eos_id)],
        )
class GemmaConvert(SpmConverter):
    """
    SentencePiece converter for Gemma tokenizers (byte-fallback enabled).

    SentencePiece trainer settings recorded for Gemma models (informational only; this converter does not check
    them):

        split_by_unicode_script: true
        split_by_number: true
        split_by_whitespace: true
        treat_whitespace_as_suffix: false
        allow_whitespace_only_pieces: true
        split_digits: true
        byte_fallback: true
    """

    handle_byte_fallback = True

    def normalizer(self, proto):
        # Spaces become the SentencePiece meta-symbol; there is no pre-tokenizer (see pre_tokenizer below).
        return normalizers.Replace(" ", "▁")

    def vocab(self, proto):
        """Build the vocabulary: pad, eos and bos from the slow tokenizer, then the pieces from index 3 on."""
        vocab = [
            (self.original_tokenizer.pad_token, 0.0),
            (self.original_tokenizer.eos_token, 0.0),
            (self.original_tokenizer.bos_token, 0.0),
        ]
        for piece in proto.pieces[3:]:
            # The byte piece "<0x09>" is stored as a literal tab character.
            if piece.piece == "<0x09>":
                vocab += [("\t", piece.score)]
            else:
                vocab += [(piece.piece, piece.score)]
        return vocab

    def pre_tokenizer(self, replacement, add_prefix_space):
        return None

    def unk_id(self, proto):
        unk_id = 3
        return unk_id

    def decoder(self, replacement, add_prefix_space):
        return decoders.Sequence(
            [
                decoders.Replace("▁", " "),
                decoders.ByteFallback(),
                decoders.Fuse(),
            ]
        )

    def tokenizer(self, proto):
        """
        Build the backend model from the SentencePiece proto.

        - ``model_type == 1`` builds a Unigram model.
        - ``model_type == 2`` builds a byte-fallback BPE model plus the Gemma special tokens.
        - Any other model type raises ``Exception``.
        In every case, the proto's user-defined symbols are added as non-special tokens.
        """
        model_type = proto.trainer_spec.model_type
        vocab_scores = self.vocab(proto)
        if model_type == 1:
            import tokenizers

            # Index 0 of the Gemma vocab is the pad token (see vocab()), so use this converter's unk id
            # instead of a hard-coded 0.
            unk_id = self.unk_id(proto)
            if version.parse(tokenizers.__version__) < version.parse("0.14.0"):
                tokenizer = Tokenizer(Unigram(vocab_scores, unk_id))
            else:
                tokenizer = Tokenizer(Unigram(vocab_scores, unk_id, byte_fallback=True))
        elif model_type == 2:
            _, merges = GemmaSentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
            bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
            tokenizer = Tokenizer(
                BPE(
                    bpe_vocab,
                    merges,
                    unk_token=proto.trainer_spec.unk_piece,
                    fuse_unk=True,
                    byte_fallback=True,
                    dropout=None,
                )
            )
            tokenizer.add_special_tokens(
                [
                    AddedToken("<pad>", normalized=False, special=True),
                    AddedToken("<eos>", normalized=False, special=True),
                    AddedToken("<bos>", normalized=False, special=True),
                    AddedToken("<unk>", normalized=False, special=True),
                ]
            )
        else:
            raise Exception(
                "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
            )
        user_defined_symbols = [
            AddedToken(token, normalized=False, special=False) for token in proto.trainer_spec.user_defined_symbols
        ]
        tokenizer.add_tokens(user_defined_symbols)
        return tokenizer
class LlamaConverter(SpmConverter):
    """
    SentencePiece converter for Llama tokenizers (byte-fallback enabled).

    - The prefix space and space→"▁" mapping are done in the normalizer; there is no pre-tokenizer.
    - There is no post-processor.
    """

    handle_byte_fallback = True

    def vocab(self, proto):
        specials = [("<unk>", 0.0), ("<s>", 0.0), ("</s>", 0.0)]
        return specials + [(piece.piece, piece.score) for piece in proto.pieces[3:]]

    def unk_id(self, proto):
        return 0

    def decoder(self, replacement, add_prefix_space):
        steps = [decoders.Replace("▁", " "), decoders.ByteFallback(), decoders.Fuse()]
        if add_prefix_space:
            # Drop the single leading space introduced by the prefix "▁".
            steps.append(decoders.Strip(content=" ", left=1))
        return decoders.Sequence(steps)

    def tokenizer(self, proto):
        """
        Build the backend model from the SentencePiece proto.

        - ``model_type == 1`` builds a Unigram model.
        - ``model_type == 2`` builds a byte-fallback BPE model plus special tokens.
        - Any other model type raises ``Exception``.
        """
        model_type = proto.trainer_spec.model_type
        vocab_scores = self.vocab(proto)

        if model_type == 2:
            _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
            bpe_vocab = {token: index for index, (token, _score) in enumerate(vocab_scores)}
            backend = Tokenizer(
                BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
            )
            backend.add_special_tokens(
                [
                    AddedToken("<unk>", normalized=False, special=True),
                    AddedToken("<s>", normalized=False, special=True),
                    AddedToken("</s>", normalized=False, special=True),
                ]
            )
            return backend

        if model_type == 1:
            import tokenizers

            # `byte_fallback` for Unigram is only passed on tokenizers >= 0.14.0.
            if version.parse(tokenizers.__version__) < version.parse("0.14.0"):
                return Tokenizer(Unigram(vocab_scores, 0))
            return Tokenizer(Unigram(vocab_scores, 0, byte_fallback=True))

        raise Exception(
            "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
        )

    def normalizer(self, proto):
        steps = []
        if getattr(self.original_tokenizer, "add_prefix_space", False):
            steps.append(normalizers.Prepend(prepend="▁"))
        steps.append(normalizers.Replace(pattern=" ", content="▁"))
        return normalizers.Sequence(steps)

    def pre_tokenizer(self, replacement, add_prefix_space):
        return None

    def post_processor(self):
        return None
class MarkupLMConverter(Converter):
    """
    Byte-level BPE converter for MarkupLM tokenizers.

    - The BPE model is built from the slow tokenizer's ``encoder`` and ``bpe_ranks``.
    - Sequences are wrapped as ``cls A sep`` or ``cls A sep B sep``.
    """

    # NOTE: this method previously had no class line of its own, so it overrode LlamaConverter.converted
    # (breaking Llama conversion, which has no `encoder`), and the MarkupLMConverter name used by
    # SLOW_TO_FAST_CONVERTERS was undefined.

    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
        vocab = ot.encoder
        merges = list(ot.bpe_ranks.keys())

        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,
                merges=merges,
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
                unk_token=ot.unk_token,
            )
        )
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()

        cls = str(ot.cls_token)
        sep = str(ot.sep_token)
        cls_token_id = ot.cls_token_id
        sep_token_id = ot.sep_token_id
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls} $A {sep}",
            pair=f"{cls} $A {sep} $B {sep}",
            special_tokens=[
                (cls, cls_token_id),
                (sep, sep_token_id),
            ],
        )
        return tokenizer
# Registry used by convert_slow_tokenizer(): maps a slow tokenizer's class name (looked up via
# `tokenizer.__class__.__name__`, so the name must match exactly) to the converter class that builds its
# fast backend. Several tokenizer classes share one converter.
# NOTE: entry order is user-visible; the keys are listed in this order in the "no converter found" error.
SLOW_TO_FAST_CONVERTERS = {
    "AlbertTokenizer": AlbertConverter,
    "BartTokenizer": RobertaConverter,
    "BarthezTokenizer": BarthezConverter,
    "BertTokenizer": BertConverter,
    "BigBirdTokenizer": BigBirdConverter,
    "BlenderbotTokenizer": BlenderbotConverter,
    "CamembertTokenizer": CamembertConverter,
    "CLIPTokenizer": CLIPConverter,
    "CodeGenTokenizer": GPT2Converter,
    "ConvBertTokenizer": BertConverter,
    "DebertaTokenizer": DebertaConverter,
    "DebertaV2Tokenizer": DebertaV2Converter,
    "DistilBertTokenizer": BertConverter,
    "DPRReaderTokenizer": BertConverter,
    "DPRQuestionEncoderTokenizer": BertConverter,
    "DPRContextEncoderTokenizer": BertConverter,
    "ElectraTokenizer": BertConverter,
    "FNetTokenizer": AlbertConverter,
    "FunnelTokenizer": FunnelConverter,
    "GPT2Tokenizer": GPT2Converter,
    "HerbertTokenizer": HerbertConverter,
    "LayoutLMTokenizer": BertConverter,
    # NOTE(review): LayoutLMv2 maps to BertConverter, not LayoutLMv2Converter — confirm this is intended.
    "LayoutLMv2Tokenizer": BertConverter,
    "LayoutLMv3Tokenizer": RobertaConverter,
    "LayoutXLMTokenizer": XLMRobertaConverter,
    "LongformerTokenizer": RobertaConverter,
    "LEDTokenizer": RobertaConverter,
    "LxmertTokenizer": BertConverter,
    "MarkupLMTokenizer": MarkupLMConverter,
    "MBartTokenizer": MBartConverter,
    "MBart50Tokenizer": MBart50Converter,
    "MPNetTokenizer": MPNetConverter,
    "MobileBertTokenizer": BertConverter,
    "MvpTokenizer": RobertaConverter,
    "NllbTokenizer": NllbConverter,
    "OpenAIGPTTokenizer": OpenAIGPTConverter,
    "PegasusTokenizer": PegasusConverter,
    "Qwen2Tokenizer": Qwen2Converter,
    "RealmTokenizer": BertConverter,
    "ReformerTokenizer": ReformerConverter,
    "RemBertTokenizer": RemBertConverter,
    "RetriBertTokenizer": BertConverter,
    "RobertaTokenizer": RobertaConverter,
    "RoFormerTokenizer": RoFormerConverter,
    "SeamlessM4TTokenizer": SeamlessM4TConverter,
    "SqueezeBertTokenizer": BertConverter,
    "T5Tokenizer": T5Converter,
    "UdopTokenizer": UdopConverter,
    "WhisperTokenizer": WhisperConverter,
    "XLMRobertaTokenizer": XLMRobertaConverter,
    "XLNetTokenizer": XLNetConverter,
    "SplinterTokenizer": SplinterConverter,
    "XGLMTokenizer": XGLMConverter,
    "LlamaTokenizer": LlamaConverter,
    "CodeLlamaTokenizer": LlamaConverter,
    "GemmaTokenizer": GemmaConvert,
}
def convert_slow_tokenizer(transformer_tokenizer) -> Tokenizer:
    """
    Convert a slow tokenizer instance into the backend `tokenizers.Tokenizer` of a fast tokenizer.

    The converter is chosen by the exact class name of `transformer_tokenizer` in `SLOW_TO_FAST_CONVERTERS`.

    Args:
        transformer_tokenizer ([`~tokenization_utils_base.PreTrainedTokenizer`]):
            Instance of a slow tokenizer to convert into the backend tokenizer of a
            [`~tokenization_utils_base.PreTrainedTokenizerFast`].

    Return:
        An instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a
        [`~tokenization_utils_base.PreTrainedTokenizerFast`].

    Raises:
        ValueError: if no converter is registered for the tokenizer's class name.
    """
    tokenizer_class_name = transformer_tokenizer.__class__.__name__
    converter_class = SLOW_TO_FAST_CONVERTERS.get(tokenizer_class_name)
    if converter_class is None:
        raise ValueError(
            f"An instance of tokenizer class {tokenizer_class_name} cannot be converted in a Fast tokenizer instance."
            " No converter was found. Currently available slow->fast convertors:"
            f" {list(SLOW_TO_FAST_CONVERTERS.keys())}"
        )
    converter = converter_class(transformer_tokenizer)
    return converter.converted()