Transformers Source Code Analysis (Part 4)
.\deepspeed.py
import warnings
warnings.warn(
"transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations",
FutureWarning,
)
from .integrations.deepspeed import (
HfDeepSpeedConfig,
HfTrainerDeepSpeedConfig,
deepspeed_config,
deepspeed_init,
deepspeed_load_checkpoint,
deepspeed_optim_sched,
is_deepspeed_available,
is_deepspeed_zero3_enabled,
set_hf_deepspeed_config,
unset_hf_deepspeed_config,
)
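The file above is only a thin deprecation shim: it emits a `FutureWarning` on import and re-exports the DeepSpeed helpers from their new home. A minimal usage sketch (assuming `transformers` is installed; the `deepspeed` package itself is only needed for the helpers that actually talk to DeepSpeed):

```python
# Prefer the new location; importing from `transformers.deepspeed` still works,
# but triggers the FutureWarning defined above.
from transformers.integrations.deepspeed import (
    is_deepspeed_available,
    is_deepspeed_zero3_enabled,
)

print(is_deepspeed_available())      # True only if the `deepspeed` package is installed
print(is_deepspeed_zero3_enabled())  # False unless an active HfDeepSpeedConfig enables ZeRO-3
```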
.\dependency_versions_check.py
from .dependency_versions_table import deps
from .utils.versions import require_version, require_version_core
pkgs_to_check_at_runtime = [
"python",
"tqdm",
"regex",
"requests",
"packaging",
"filelock",
"numpy",
"tokenizers",
"huggingface-hub",
"safetensors",
"accelerate",
"pyyaml",
]
for pkg in pkgs_to_check_at_runtime:
if pkg in deps:
if pkg == "tokenizers":
from .utils import is_tokenizers_available
if not is_tokenizers_available():
continue
elif pkg == "accelerate":
from .utils import is_accelerate_available
if not is_accelerate_available():
continue
require_version_core(deps[pkg])
else:
raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
def dep_version_check(pkg, hint=None):
require_version(deps[pkg], hint)
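A short sketch of how these checks are consumed (assuming `transformers` and the listed packages are installed): `dep_version_check` simply looks up the pin in `deps` and delegates to `require_version`.

```python
from transformers.dependency_versions_check import dep_version_check
from transformers.utils.versions import require_version

dep_version_check("tqdm")  # verifies the installed tqdm satisfies "tqdm>=4.27"
require_version("numpy>=1.17", hint="Try: pip install -U numpy")  # raises ImportError when unmet
```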
.\dependency_versions_table.py
deps = {
"Pillow": "Pillow>=10.0.1,<=15.0",
"accelerate": "accelerate>=0.21.0",
"av": "av==9.2.0",
"beautifulsoup4": "beautifulsoup4",
"codecarbon": "codecarbon==1.2.0",
"cookiecutter": "cookiecutter==1.7.3",
"dataclasses": "dataclasses",
"datasets": "datasets!=2.5.0",
"decord": "decord==0.6.0",
"deepspeed": "deepspeed>=0.9.3",
"diffusers": "diffusers",
"dill": "dill<0.3.5",
"evaluate": "evaluate>=0.2.0",
"faiss-cpu": "faiss-cpu",
"fastapi": "fastapi",
"filelock": "filelock",
"flax": "flax>=0.4.1,<=0.7.0",
"fsspec": "fsspec<2023.10.0",
"ftfy": "ftfy",
"fugashi": "fugashi>=1.0",
"GitPython": "GitPython<3.1.19",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.19.3,<1.0",
"importlib_metadata": "importlib_metadata",
"ipadic": "ipadic>=1.0.0,<2.0",
"isort": "isort>=5.5.4",
"jax": "jax>=0.4.1,<=0.4.13",
"jaxlib": "jaxlib>=0.4.1,<=0.4.13",
"jieba": "jieba",
"kenlm": "kenlm",
"keras": "keras<2.16",
"keras-nlp": "keras-nlp>=0.3.1",
"librosa": "librosa",
"nltk": "nltk",
"natten": "natten>=0.14.6,<0.15.0",
"numpy": "numpy>=1.17",
"onnxconverter-common": "onnxconverter-common",
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
"onnxruntime": "onnxruntime>=1.4.0",
"opencv-python": "opencv-python",
"optuna": "optuna",
"optax": "optax>=0.0.8,<=0.1.4",
"packaging": "packaging>=20.0",
"parameterized": "parameterized",
"phonemizer": "phonemizer",
"protobuf": "protobuf",
"psutil": "psutil",
"pyyaml": "pyyaml>=5.1",
"pydantic": "pydantic",
"pytest": "pytest>=7.2.0,<8.0.0",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"python": "python>=3.8.0",
"ray[t
"pyctcdecode": "pyctcdecode>=0.4.0",
# Dependency: pyctcdecode, version >= 0.4.0
"tqdm": "tqdm>=4.27",
# Dependency: tqdm, version >= 4.27
"unidic": "unidic>=1.0.2",
# Dependency: unidic, version >= 1.0.2
"unidic_lite": "unidic_lite>=1.0.7",
# Dependency: unidic_lite, version >= 1.0.7
"urllib3": "urllib3<2.0.0",
# Dependency: urllib3, version < 2.0.0
"uvicorn": "uvicorn",
# Dependency: uvicorn, no version requirement
}
Note:
# The closing brace above ends the `deps` dictionary opened with '{'.
.\dynamic_module_utils.py
"""Utilities to dynamically load objects from the Hub."""
import filecmp
import importlib
import os
import re
import shutil
import signal
import sys
import typing
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from huggingface_hub import try_to_load_from_cache
from .utils import (
HF_MODULES_CACHE,
TRANSFORMERS_DYNAMIC_MODULE_NAME,
cached_file,
extract_commit_hash,
is_offline_mode,
logging,
)
logger = logging.get_logger(__name__)
def init_hf_modules():
"""
Creates the cache directory for modules with an init, and adds it to the Python path.
"""
if HF_MODULES_CACHE in sys.path:
return
sys.path.append(HF_MODULES_CACHE)
os.makedirs(HF_MODULES_CACHE, exist_ok=True)
init_path = Path(HF_MODULES_CACHE) / "__init__.py"
if not init_path.exists():
init_path.touch()
importlib.invalidate_caches()
def create_dynamic_module(name: Union[str, os.PathLike]):
"""
Creates a dynamic module in the cache directory for modules.
Args:
name (`str` or `os.PathLike`):
The name of the dynamic module to create.
"""
init_hf_modules()
dynamic_module_path = (Path(HF_MODULES_CACHE) / name).resolve()
if not dynamic_module_path.parent.exists():
create_dynamic_module(dynamic_module_path.parent)
os.makedirs(dynamic_module_path, exist_ok=True)
init_path = dynamic_module_path / "__init__.py"
if not init_path.exists():
init_path.touch()
importlib.invalidate_caches()
def get_relative_imports(module_file: Union[str, os.PathLike]) -> List[str]:
"""
Get the list of modules that are relatively imported in a module file.
Args:
module_file (`str` or `os.PathLike`): The module file to inspect.
Returns:
`List[str]`: The list of relative imports in the module.
"""
with open(module_file, "r", encoding="utf-8") as f:
content = f.read()
relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
return list(set(relative_imports))
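A quick illustration of the two regexes in `get_relative_imports`, run against a throwaway module (the file contents and module names below are made up for the example):

```python
import tempfile
import textwrap

from transformers.dynamic_module_utils import get_relative_imports

source = textwrap.dedent(
    """
    from .configuration_my_model import MyModelConfig
    from .modeling_helpers import helper_fn
    import torch
    """
)
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as tmp:
    tmp.write(source)

# Only relative imports are reported; `import torch` is ignored.
print(get_relative_imports(tmp.name))  # ['configuration_my_model', 'modeling_helpers'] (in some order)
```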
def get_relative_import_files(module_file: Union[str, os.PathLike]) -> List[str]:
"""
Get the list of all files that are needed for a given module. Note that this function recurses through the relative
imports (if a imports b and b imports c, it will return module files for b and c).
Args:
module_file (`str` or `os.PathLike`): The module file to inspect.
Returns:
`List[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
of module files a given module needs.
"""
no_change = False
files_to_check = [module_file]
all_relative_imports = []
while not no_change:
new_imports = []
for f in files_to_check:
new_imports.extend(get_relative_imports(f))
module_path = Path(module_file).parent
new_import_files = [str(module_path / m) for m in new_imports]
new_import_files = [f for f in new_import_files if f not in all_relative_imports]
files_to_check = [f"{f}.py" for f in new_import_files]
no_change = len(new_import_files) == 0
all_relative_imports.extend(files_to_check)
return all_relative_imports
def get_imports(filename: Union[str, os.PathLike]) -> List[str]:
"""
Extracts all the libraries (not relative imports this time) that are imported in a file.
Args:
filename (`str` or `os.PathLike`): The module file to inspect.
Returns:
`List[str]`: The list of all packages required to use the input module.
"""
with open(filename, "r", encoding="utf-8") as f:
content = f.read()
content = re.sub(r"\s*try\s*:\s*.*?\s*except\s*.*?:", "", content, flags=re.MULTILINE | re.DOTALL)
imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
return list(set(imports))
def check_imports(filename: Union[str, os.PathLike]) -> List[str]:
"""
Check if the current Python environment contains all the libraries that are imported in a file. Will raise if a
library is missing.
Args:
filename (`str` or `os.PathLike`): The module file to check.
Returns:
`List[str]`: The list of relative imports in the file.
"""
imports = get_imports(filename)
missing_packages = []
for imp in imports:
try:
importlib.import_module(imp)
except ImportError:
missing_packages.append(imp)
if len(missing_packages) > 0:
raise ImportError(
"This modeling file requires the following packages that were not found in your environment: "
f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
)
return get_relative_imports(filename)
def get_class_in_module(class_name: str, module_path: Union[str, os.PathLike]) -> typing.Type:
"""
Import a module on the cache directory for modules and extract a class from it.
Args:
class_name (`str`): The name of the class to import.
module_path (`str` or `os.PathLike`): The path to the module to import.
Returns:
`typing.Type`: The class looked for.
"""
name = os.path.normpath(module_path).replace(".py", "").replace(os.path.sep, ".")
module_path = str(Path(HF_MODULES_CACHE) / module_path)
module = importlib.machinery.SourceFileLoader(name, module_path).load_module()
return getattr(module, class_name)
def get_cached_module_file(
pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
resume_download: bool = False,
proxies: Optional[Dict[str, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
repo_type: Optional[str] = None,
_commit_hash: Optional[str] = None,
**deprecated_kwargs,
) -> str:
"""
Downloads a module from a local folder or a distant repo and returns its path inside the cached
Transformers module.
"""
use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
FutureWarning,
)
if token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
token = use_auth_token
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if is_local:
submodule = os.path.basename(pretrained_model_name_or_path)
else:
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
cached_module = try_to_load_from_cache(
pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
)
new_files = []
try:
resolved_module_file = cached_file(
pretrained_model_name_or_path,
module_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
revision=revision,
repo_type=repo_type,
_commit_hash=_commit_hash,
)
if not is_local and cached_module != resolved_module_file:
new_files.append(module_file)
except EnvironmentError:
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
raise
modules_needed = check_imports(resolved_module_file)
full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
create_dynamic_module(full_submodule)
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
if submodule == os.path.basename(pretrained_model_name_or_path):
if not (submodule_path / module_file).exists() or not filecmp.cmp(
resolved_module_file, str(submodule_path / module_file)
):
shutil.copy(resolved_module_file, submodule_path / module_file)
importlib.invalidate_caches()
for module_needed in modules_needed:
module_needed = f"{module_needed}.py"
module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed)
if not (submodule_path / module_needed).exists() or not filecmp.cmp(
module_needed_file, str(submodule_path / module_needed)
):
shutil.copy(module_needed_file, submodule_path / module_needed)
importlib.invalidate_caches()
else:
commit_hash = extract_commit_hash(resolved_module_file, _commit_hash)
submodule_path = submodule_path / commit_hash
full_submodule = full_submodule + os.path.sep + commit_hash
create_dynamic_module(full_submodule)
if not (submodule_path / module_file).exists():
shutil.copy(resolved_module_file, submodule_path / module_file)
importlib.invalidate_caches()
for module_needed in modules_needed:
if not (submodule_path / f"{module_needed}.py").exists():
get_cached_module_file(
pretrained_model_name_or_path,
f"{module_needed}.py",
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
_commit_hash=commit_hash,
)
new_files.append(f"{module_needed}.py")
if len(new_files) > 0 and revision is None:
new_files = "\n".join([f"- {f}" for f in new_files])
repo_type_str = "" if repo_type is None else f"{repo_type}s/"
url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}"
logger.warning(
f"A new version of the following files was downloaded from {url}:\n{new_files}"
"\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
"versions of the code file, you can pin a revision."
)
return os.path.join(full_submodule, module_file)
def get_class_from_dynamic_module(
class_reference: str,
pretrained_model_name_or_path: Union[str, os.PathLike],
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
resume_download: bool = False,
proxies: Optional[Dict[str, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
repo_type: Optional[str] = None,
code_revision: Optional[str] = None,
**kwargs,
) -> typing.Type:
"""
Extracts the definition of a class from a local folder or a model repository.
<Tip warning={true}>
Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
therefore only be called on trusted repos.
</Tip>
# Loads the configuration and model data for the specified class
Args:
class_reference (`str`):
The full name of the class to load, including its module and optionally its repo.
pretrained_model_name_or_path (`str` or `os.PathLike`):
This can be either:
- a string, the *model id* of a pretrained model configuration hosted on the huggingface.co model hub.
- a path to a *directory* containing a configuration file saved using the
[`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
This is used when `class_reference` does not specify another repo.
module_file (`str`):
The name of the module file containing the class to look for.
class_name (`str`):
The name of the class to import in the module.
cache_dir (`str` or `os.PathLike`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the configuration files, overriding any cached versions.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a file
exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id. Since we use a
git-based system for storing models and other artifacts on huggingface.co, `revision` can be any identifier
allowed by git.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
repo_type (`str`, *optional*):
Specify the repo type (useful when downloading from a Space, for instance).
code_revision (`str`, *optional*, defaults to `"main"`):
The specific revision to use for the code on the Hub, if the code lives in a different repository than the
rest of the model. It can be a branch name, a tag name, or a commit id. Since we use a git-based system for
storing models and other artifacts on huggingface.co, `revision` can be any identifier allowed by git.
<Tip>
Passing `token=True` is required when you want to use a private model.
</Tip>
Returns:
`typing.Type`: The class, dynamically imported from the module.
Examples:
```
# Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
# module.
cls = get_class_from_dynamic_module("modeling.MyBertModel", "sgugger/my-bert-model")
# Download module `modeling.py` from a given repo and cache then extract the class `MyBertModel` from this
# module.
cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model")
```"""
use_auth_token = kwargs.pop("use_auth_token", None)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
FutureWarning,
)
if token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
token = use_auth_token
if "--" in class_reference:
repo_id, class_reference = class_reference.split("--")
else:
repo_id = pretrained_model_name_or_path
module_file, class_name = class_reference.split(".")
if code_revision is None and pretrained_model_name_or_path == repo_id:
code_revision = revision
final_module = get_cached_module_file(
repo_id,
module_file + ".py",
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
revision=code_revision,
local_files_only=local_files_only,
repo_type=repo_type,
)
return get_class_in_module(class_name, final_module)
def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[Dict] = None) -> List[str]:
"""
Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally
adds the proper fields in a config.
Args:
obj (`Any`): The object for which to save the module files.
folder (`str` or `os.PathLike`): The folder where to save.
config (`PretrainedConfig` or dictionary, `optional`):
A config in which to register the auto_map corresponding to this custom object.
Returns:
`List[str]`: The list of files saved.
"""
if obj.__module__ == "__main__":
logger.warning(
f"We can't save the code defining {obj} in {folder} as it's been defined in __main__. You should put "
"this code in a separate module so we can include it in the saved folder and make it easier to share via "
"the Hub."
)
return
def _set_auto_map_in_config(_config):
module_name = obj.__class__.__module__
last_module = module_name.split(".")[-1]
full_name = f"{last_module}.{obj.__class__.__name__}"
if "Tokenizer" in full_name:
slow_tokenizer_class = None
fast_tokenizer_class = None
if obj.__class__.__name__.endswith("Fast"):
fast_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
if getattr(obj, "slow_tokenizer_class", None) is not None:
slow_tokenizer = getattr(obj, "slow_tokenizer_class")
slow_tok_module_name = slow_tokenizer.__module__
last_slow_tok_module = slow_tok_module_name.split(".")[-1]
slow_tokenizer_class = f"{last_slow_tok_module}.{slow_tokenizer.__name__}"
else:
slow_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
full_name = (slow_tokenizer_class, fast_tokenizer_class)
if isinstance(_config, dict):
auto_map = _config.get("auto_map", {})
auto_map[obj._auto_class] = full_name
_config["auto_map"] = auto_map
elif getattr(_config, "auto_map", None) is not None:
_config.auto_map[obj._auto_class] = full_name
else:
_config.auto_map = {obj._auto_class: full_name}
if isinstance(config, (list, tuple)):
for cfg in config:
_set_auto_map_in_config(cfg)
elif config is not None:
_set_auto_map_in_config(config)
result = []
object_file = sys.modules[obj.__module__].__file__
dest_file = Path(folder) / (Path(object_file).name)
shutil.copy(object_file, dest_file)
result.append(dest_file)
for needed_file in get_relative_import_files(object_file):
dest_file = Path(folder) / (Path(needed_file).name)
shutil.copy(needed_file, dest_file)
result.append(dest_file)
return result
def _raise_timeout_error(signum, frame):
raise ValueError(
"Loading this model requires you to execute custom code contained in the model repository on your local "
"machine. Please set the option `trust_remote_code=True` to permit loading of this model."
)
TIME_OUT_REMOTE_CODE = 15
def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
if trust_remote_code is None:
if has_local_code:
trust_remote_code = False
elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
try:
signal.signal(signal.SIGALRM, _raise_timeout_error)
signal.alarm(TIME_OUT_REMOTE_CODE)
while trust_remote_code is None:
answer = input(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
f"Do you wish to run the custom code? [y/N] "
)
if answer.lower() in ["yes", "y", "1"]:
trust_remote_code = True
elif answer.lower() in ["no", "n", "0", ""]:
trust_remote_code = False
signal.alarm(0)
except Exception:
raise ValueError(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
elif has_remote_code:
_raise_timeout_error(None, None)
if has_remote_code and not has_local_code and not trust_remote_code:
raise ValueError(
f"Loading {model_name} requires you to execute the configuration file in that"
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
" set the option `trust_remote_code=True` to remove this error."
)
return trust_remote_code
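A minimal sketch of the decision table implemented by `resolve_trust_remote_code` (the model name is hypothetical). When `trust_remote_code` is left as `None` and only remote code exists, the function instead prompts interactively, guarded by the `TIME_OUT_REMOTE_CODE` alarm shown above.

```python
from transformers.dynamic_module_utils import resolve_trust_remote_code

# Local code is available and nothing was passed -> defaults to False (no remote code executed).
assert resolve_trust_remote_code(None, "org/some-model", has_local_code=True, has_remote_code=True) is False

# An explicit True is passed straight through (and is required when only remote code exists).
assert resolve_trust_remote_code(True, "org/some-model", has_local_code=False, has_remote_code=True) is True
```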
.\feature_extraction_sequence_utils.py
"""
Sequence feature extraction class for common feature extractors to preprocess sequences.
"""
from typing import Dict, List, Optional, Union
import numpy as np
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from .utils import PaddingStrategy, TensorType, is_tf_tensor, is_torch_tensor, logging, to_numpy
logger = logging.get_logger(__name__)
class SequenceFeatureExtractor(FeatureExtractionMixin):
"""
This is a general feature extraction class for speech recognition.
Args:
feature_size (`int`):
The feature dimension of the extracted features.
sampling_rate (`int`):
The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
padding_value (`float`):
The value that is used to fill the padding values / vectors.
"""
def __init__(self, feature_size: int, sampling_rate: int, padding_value: float, **kwargs):
self.feature_size = feature_size
self.sampling_rate = sampling_rate
self.padding_value = padding_value
self.padding_side = kwargs.pop("padding_side", "right")
self.return_attention_mask = kwargs.pop("return_attention_mask", True)
super().__init__(**kwargs)
def pad(
self,
processed_features: Union[
BatchFeature,
List[BatchFeature],
Dict[str, BatchFeature],
Dict[str, List[BatchFeature]],
List[Dict[str, BatchFeature]],
],
padding: Union[bool, str, PaddingStrategy] = True,
max_length: Optional[int] = None,
truncation: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
):
"""
Pad sequences of features to the same length.
Args:
processed_features (Union[BatchFeature, List[BatchFeature], Dict[str, BatchFeature], ...]):
The processed features to be padded.
padding (Union[bool, str, PaddingStrategy]):
Strategy for padding. Can be a boolean, string, or enum from PaddingStrategy.
max_length (Optional[int]):
Maximum length to pad or truncate the sequences.
truncation (bool):
Whether to truncate sequences that exceed `max_length`.
pad_to_multiple_of (Optional[int]):
Pad to a multiple of this value.
return_attention_mask (Optional[bool]):
Whether to return attention masks.
return_tensors (Optional[Union[str, TensorType]]):
The type of tensor(s) to be returned.
Returns:
Padded sequences of features.
"""
pass
def _pad(
self,
processed_features: Union[Dict[str, np.ndarray], BatchFeature],
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
):
"""
Internal method for padding sequences of features.
Args:
processed_features (Union[Dict[str, np.ndarray], BatchFeature]):
The processed features to be padded.
max_length (Optional[int]):
Maximum length to pad or truncate the sequences.
padding_strategy (PaddingStrategy):
Strategy for padding. Default is DO_NOT_PAD.
pad_to_multiple_of (Optional[int]):
Pad to a multiple of this value.
return_attention_mask (Optional[bool]):
Whether to return attention masks.
"""
pass
def _truncate(
self,
processed_features: Union[Dict[str, np.ndarray], BatchFeature],
max_length: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
truncation: Optional[bool] = None,
):
"""
Internal method for truncating sequences of features.
Args:
processed_features (Union[Dict[str, np.ndarray], BatchFeature]):
The processed features to be truncated.
max_length (Optional[int]):
Maximum length to truncate the sequences.
pad_to_multiple_of (Optional[int]):
Pad to a multiple of this value.
truncation (Optional[bool]):
Whether to truncate sequences that exceed `max_length`.
"""
pass
"""
Truncate inputs to predefined length or max length in the batch
Args:
processed_features(`Union[Dict[str, np.ndarray], BatchFeature]`):
Dictionary of input values (`np.ndarray[float]`) / input vectors (`List[np.ndarray[float]]`) or batch
of inputs values (`List[np.ndarray[int]]`) / input vectors (`List[np.ndarray[int]]`)
max_length (`int`, *optional*):
maximum length of the returned list and optionally padding length (see below)
pad_to_multiple_of (`int`, *optional*) :
Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to
enable the use of Tensor Core on NVIDIA hardware with compute capability `>= 7.5` (Volta), or on TPUs
which benefit from having sequence lengths be a multiple of 128.
truncation (`bool`, *optional*):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
"""
if not truncation:
return processed_features
elif truncation and max_length is None:
raise ValueError("When setting ``truncation=True``, make sure that ``max_length`` is defined.")
required_input = processed_features[self.model_input_names[0]]
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
needs_to_be_truncated = len(required_input) > max_length
if needs_to_be_truncated:
processed_features[self.model_input_names[0]] = processed_features[self.model_input_names[0]][:max_length]
if "attention_mask" in processed_features:
processed_features["attention_mask"] = processed_features["attention_mask"][:max_length]
return processed_features
def _get_padding_strategies(self, padding=False, max_length=None):
"""
Find the correct padding strategy
"""
if padding is not False:
if padding is True:
padding_strategy = PaddingStrategy.LONGEST
elif not isinstance(padding, PaddingStrategy):
padding_strategy = PaddingStrategy(padding)
elif isinstance(padding, PaddingStrategy):
padding_strategy = padding
else:
padding_strategy = PaddingStrategy.DO_NOT_PAD
if max_length is None:
if padding_strategy == PaddingStrategy.MAX_LENGTH:
raise ValueError(
f"When setting ``padding={PaddingStrategy.MAX_LENGTH}``, make sure that max_length is defined"
)
if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.padding_value is None):
raise ValueError(
"Asking to pad but the feature_extractor does not have a padding value. Please select a value to use"
" as `padding_value`. For example: `feature_extractor.padding_value = 0.0`."
)
return padding_strategy
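To see this padding machinery in action, here is a small sketch using `Wav2Vec2FeatureExtractor`, a concrete `SequenceFeatureExtractor` subclass; the shapes in the comments assume `padding=True`, i.e. the `LONGEST` strategy returned by `_get_padding_strategies`.

```python
import numpy as np
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0)
raw_speech = [np.zeros(8000, dtype=np.float32), np.zeros(16000, dtype=np.float32)]

batch = feature_extractor(
    raw_speech,
    sampling_rate=16000,
    padding=True,               # PaddingStrategy.LONGEST: pad to the longest example in the batch
    return_attention_mask=True,
    return_tensors="np",
)
print(batch["input_values"].shape)    # (2, 16000)
print(batch["attention_mask"].shape)  # (2, 16000); zeros mark the padded positions
```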
.\feature_extraction_utils.py
"""
Feature extraction saving/loading class for common feature extractors.
"""
import copy
import json
import os
import warnings
from collections import UserDict
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import numpy as np
from .dynamic_module_utils import custom_object_save
from .utils import (
FEATURE_EXTRACTOR_NAME,
PushToHubMixin,
TensorType,
add_model_info_to_auto_map,
cached_file,
copy_func,
download_url,
is_flax_available,
is_jax_tensor,
is_numpy_array,
is_offline_mode,
is_remote_url,
is_tf_available,
is_torch_available,
is_torch_device,
is_torch_dtype,
logging,
requires_backends,
)
if TYPE_CHECKING:
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
PreTrainedFeatureExtractor = Union["SequenceFeatureExtractor"]
class BatchFeature(UserDict):
r"""
Holds the output of the [`~SequenceFeatureExtractor.pad`] and feature extractor specific `__call__` methods.
This class is derived from a python dictionary and can be used as a dictionary.
Args:
data (`dict`, *optional*):
Dictionary of lists/arrays/tensors returned by the __call__/pad methods ('input_values', 'attention_mask',
etc.).
tensor_type (`Union[None, str, TensorType]`, *optional*):
You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at
initialization.
"""
def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
super().__init__(data)
self.convert_to_tensors(tensor_type=tensor_type)
def __getitem__(self, item: str) -> Union[Any]:
"""
If the key is a string, returns the value of the dict associated to `key` ('input_values', 'attention_mask',
etc.).
"""
if isinstance(item, str):
return self.data[item]
else:
raise KeyError("Indexing with integers is not available when using Python based feature extractors")
def __getattr__(self, item: str):
try:
return self.data[item]
except KeyError:
raise AttributeError
def __getstate__(self):
return {"data": self.data}
def __setstate__(self, state):
if "data" in state:
self.data = state["data"]
def keys(self):
return self.data.keys()
def values(self):
return self.data.values()
def items(self):
return self.data.items()
def _get_is_as_tensor_fns(self, tensor_type: Optional[Union[str, TensorType]] = None):
if tensor_type is None:
return None, None
if not isinstance(tensor_type, TensorType):
tensor_type = TensorType(tensor_type)
if tensor_type == TensorType.TENSORFLOW:
if not is_tf_available():
raise ImportError(
"Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
)
import tensorflow as tf
as_tensor = tf.constant
is_tensor = tf.is_tensor
elif tensor_type == TensorType.PYTORCH:
if not is_torch_available():
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
import torch
def as_tensor(value):
if isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], np.ndarray):
value = np.array(value)
return torch.tensor(value)
is_tensor = torch.is_tensor
elif tensor_type == TensorType.JAX:
if not is_flax_available():
raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
import jax.numpy as jnp
as_tensor = jnp.array
is_tensor = is_jax_tensor
else:
def as_tensor(value, dtype=None):
if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)):
value_lens = [len(val) for val in value]
if len(set(value_lens)) > 1 and dtype is None:
value = as_tensor([np.asarray(val) for val in value], dtype=object)
return np.asarray(value, dtype=dtype)
is_tensor = is_numpy_array
return is_tensor, as_tensor
def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
"""
Convert the inner content to tensors.
Args:
tensor_type (`str` or [`~utils.TensorType`], *optional*):
The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
`None`, no modification is done.
"""
if tensor_type is None:
return self
is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)
for key, value in self.items():
try:
if not is_tensor(value):
tensor = as_tensor(value)
self[key] = tensor
except:
if key == "overflowing_values":
raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
raise ValueError(
"Unable to create tensor, you should probably activate padding "
"with 'padding=True' to have batched tensors with the same length."
)
return self
def to(self, *args, **kwargs) -> "BatchFeature":
"""
Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
different `dtypes` and sending the `BatchFeature` to a different `device`.
Args:
args (`Tuple`):
Will be passed to the `to(...)` function of the tensors.
kwargs (`Dict`, *optional*):
Will be passed to the `to(...)` function of the tensors.
Returns:
[`BatchFeature`]: The same instance after modification.
"""
requires_backends(self, ["torch"])
import torch
new_data = {}
device = kwargs.get("device")
if device is None and len(args) > 0:
arg = args[0]
if is_torch_dtype(arg):
pass
elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
device = arg
else:
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
for k, v in self.items():
if torch.is_floating_point(v):
new_data[k] = v.to(*args, **kwargs)
elif device is not None:
new_data[k] = v.to(device=device)
else:
new_data[k] = v
self.data = new_data
return self
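A small sketch of `BatchFeature` in isolation (assumes PyTorch is installed): tensor conversion happens at construction time via `convert_to_tensors`, and `.to(...)` only moves or casts the floating-point tensors.

```python
from transformers.feature_extraction_utils import BatchFeature

batch = BatchFeature({"input_values": [[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]]}, tensor_type="pt")
print(type(batch["input_values"]))  # <class 'torch.Tensor'>, shape (2, 3)

batch = batch.to("cpu")             # "cuda" would move the floating-point tensors instead
print(batch.input_values.dtype)     # torch.float32 (attribute access is routed through __getattr__)
```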
"""
# 这是一个特征提取的 Mixin 类,用于为顺序数据和图像特征提取器提供保存和加载功能。
"""
_auto_class = None
def __init__(self, **kwargs):
"""
Initialization method that sets the elements of `kwargs` as attributes of the object.
"""
self._processor_class = kwargs.pop("processor_class", None)
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
logger.error(f"Can't set {key} with value {value} for {self}")
raise err
def _set_processor_class(self, processor_class: str):
"""
Sets the processor class as an attribute of the object.
"""
self._processor_class = processor_class
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs,
):
"""
Instantiates the class from a pretrained model name or path and configures the relevant parameters.
"""
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a feature_extractor object to the directory `save_directory`, so that it can be re-loaded using the
[`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] class method.
Args:
save_directory (`str` or `os.PathLike`):
Directory where the feature extractor JSON file will be saved (will be created if it does not exist).
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
kwargs (`Dict[str, Any]`, *optional*):
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
use_auth_token = kwargs.pop("use_auth_token", None)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
FutureWarning,
)
if kwargs.get("token", None) is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
kwargs["token"] = use_auth_token
if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)
if self._auto_class is not None:
custom_object_save(self, save_directory, config=self)
output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME)
self.to_json_file(output_feature_extractor_file)
logger.info(f"Feature extractor saved in {output_feature_extractor_file}")
if push_to_hub:
self._upload_modified_files(
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("token"),
)
return [output_feature_extractor_file]
@classmethod
def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> PreTrainedFeatureExtractor:
"""
Instantiates a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a Python dictionary of
parameters.
Args:
feature_extractor_dict (`Dict[str, Any]`):
Dictionary that will be used to instantiate the feature extractor object. Such a dictionary can be
retrieved from a pretrained checkpoint by leveraging the
[`~feature_extraction_utils.FeatureExtractionMixin.to_dict`] method.
kwargs (`Dict[str, Any]`):
Additional parameters from which to initialize the feature extractor object.
Returns:
[`~feature_extraction_utils.FeatureExtractionMixin`]: The feature extractor object instantiated from those
parameters.
"""
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
feature_extractor = cls(**feature_extractor_dict)
to_remove = []
for key, value in kwargs.items():
if hasattr(feature_extractor, key):
setattr(feature_extractor, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
logger.info(f"Feature extractor {feature_extractor}")
if return_unused_kwargs:
return feature_extractor, kwargs
else:
return feature_extractor
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary. Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
output["feature_extractor_type"] = self.__class__.__name__
if "mel_filters" in output:
del output["mel_filters"]
if "window" in output:
del output["window"]
return output
@classmethod
def from_json_file(cls, json_file: Union[str, os.PathLike]) -> PreTrainedFeatureExtractor:
"""
Instantiates a feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] from the path to
a JSON file of parameters.
Args:
json_file (`str` or `os.PathLike`):
Path to the JSON file containing the parameters.
Returns:
A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature_extractor
object instantiated from that JSON file.
"""
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
feature_extractor_dict = json.loads(text)
return cls(**feature_extractor_dict)
def to_json_string(self) -> str:
"""
Serializes this instance to a JSON string.
Returns:
`str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
"""
dictionary = self.to_dict()
for key, value in dictionary.items():
if isinstance(value, np.ndarray):
dictionary[key] = value.tolist()
_processor_class = dictionary.pop("_processor_class", None)
if _processor_class is not None:
dictionary["processor_class"] = _processor_class
return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
"""
Save this instance to a JSON file.
Args:
json_file_path (`str` or `os.PathLike`):
Path to the JSON file in which this feature_extractor instance's parameters will be saved.
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string())
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
@classmethod
def register_for_auto_class(cls, auto_class="AutoFeatureExtractor"):
"""
Register this class with a given auto class. This should only be used for custom feature extractors as the ones
in the library are already mapped with `AutoFeatureExtractor`.
<Tip warning={true}>
This API is experimental and may have some slight breaking changes in the next releases.
</Tip>
Args:
auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`):
The auto class to register this new feature extractor with.
"""
if not isinstance(auto_class, str):
auto_class = auto_class.__name__
import transformers.models.auto as auto_module
if not hasattr(auto_module, auto_class):
raise ValueError(f"{auto_class} is not a valid auto class.")
cls._auto_class = auto_class
FeatureExtractionMixin.push_to_hub = copy_func(FeatureExtractionMixin.push_to_hub)
if FeatureExtractionMixin.push_to_hub.__doc__ is not None:
FeatureExtractionMixin.push_to_hub.__doc__ = FeatureExtractionMixin.push_to_hub.__doc__.format(
object="feature extractor", object_class="AutoFeatureExtractor", object_files="feature extractor file"
)
.\file_utils.py
"""
File utilities: utilities related to download and cache models
This module should not be updated anymore and is only left for backward compatibility.
"""
from huggingface_hub import get_full_repo_name
from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY
from . import __version__
from .utils import (
CLOUDFRONT_DISTRIB_PREFIX,
CONFIG_NAME,
DUMMY_INPUTS,
DUMMY_MASK,
ENV_VARS_TRUE_AND_AUTO_VALUES,
ENV_VARS_TRUE_VALUES,
FEATURE_EXTRACTOR_NAME,
FLAX_WEIGHTS_NAME,
HF_MODULES_CACHE,
HUGGINGFACE_CO_PREFIX,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
MODEL_CARD_NAME,
MULTIPLE_CHOICE_DUMMY_INPUTS,
PYTORCH_PRETRAINED_BERT_CACHE,
PYTORCH_TRANSFORMERS_CACHE,
S3_BUCKET_PREFIX,
SENTENCEPIECE_UNDERLINE,
SPIECE_UNDERLINE,
TF2_WEIGHTS_NAME,
TF_WEIGHTS_NAME,
TORCH_FX_REQUIRED_VERSION,
TRANSFORMERS_CACHE,
TRANSFORMERS_DYNAMIC_MODULE_NAME,
USE_JAX,
USE_TF,
USE_TORCH,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
ContextManagers,
DummyObject,
EntryNotFoundError,
ExplicitEnum,
ModelOutput,
PaddingStrategy,
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
TensorType,
_LazyModule,
add_code_sample_docstrings,
add_end_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
cached_property,
copy_func,
default_cache_path,
define_sagemaker_information,
get_cached_models,
get_file_from_repo,
get_torch_version,
has_file,
http_user_agent,
is_apex_available,
is_bs4_available,
is_coloredlogs_available,
is_datasets_available,
is_detectron2_available,
is_faiss_available,
is_flax_available,
is_ftfy_available,
is_g2p_en_available,
is_in_notebook,
is_ipex_available,
is_librosa_available,
is_offline_mode,
is_onnx_available,
is_pandas_available,
is_phonemizer_available,
is_protobuf_available,
is_psutil_available,
is_py3nvml_available,
is_pyctcdecode_available,
is_pytesseract_available,
is_pytorch_quantization_available,
is_rjieba_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_scipy_available,
is_sentencepiece_available,
is_seqio_available,
is_sklearn_available,
is_soundfile_availble,
is_spacy_available,
is_speech_available,
is_tensor,
is_tensorflow_probability_available,
is_tf2onnx_available,
is_tf_available,
is_timm_available,
is_tokenizers_available,
is_torch_available,
is_torch_bf16_available,
is_torch_cuda_available,
is_torch_fx_available,
is_torch_fx_proxy,
is_torch_mps_available,
is_torch_tf32_available,
is_torch_xla_available,
is_torchaudio_available,
is_training_run_on_sagemaker,
is_vision_available,
replace_return_docstrings,
requires_backends,
to_numpy,
to_py_obj,
torch_only_method,
)
.\generation\beam_constraints.py
from abc import ABC, abstractmethod
from typing import List, Optional
class Constraint(ABC):
r"""Abstract base class for all constraints that can be applied during generation.
It must define how the constraint can be satisfied.
All classes that inherit Constraint must follow the requirement that
```
completed = False
while not completed:
_, completed = constraint.update(constraint.advance())
```
will always terminate (halt).
"""
def __init__(self):
self.test()
def test(self):
"""
Tests whether this constraint has been properly defined.
"""
counter = 0
completed = False
while not completed:
if counter == 1:
self.reset()
advance = self.advance()
if not self.does_advance(advance):
raise Exception(
"Custom Constraint is not defined correctly. self.does_advance(self.advance()) must be true."
)
stepped, completed, reset = self.update(advance)
counter += 1
if counter > 10000:
raise Exception("update() does not fulfill the constraint.")
if self.remaining() != 0:
raise Exception("Custom Constraint is not defined correctly.")
@abstractmethod
def advance(self):
"""
When called, returns the token that would take this constraint one step closer to being fulfilled.
Return:
token_ids(`torch.tensor`): Must be a tensor of a list of indexable tokens, not some integer.
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
@abstractmethod
def does_advance(self, token_id: int):
"""
Reads in a token and returns whether it creates progress.
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
@abstractmethod
def update(self, token_id: int):
"""
Reads in a token and returns booleans that indicate the progress made by it. This function will update the
state of this object unlike `does_advance(self, token_id: int)`.
This isn't to test whether a certain token will advance the progress; it's to update its state as if it has
been generated. This becomes important if token_id != desired token (refer to else statement in
PhrasalConstraint)
Args:
token_id(`int`):
The id of a newly generated token in the beam search.
Return:
stepped(`bool`):
Whether this constraint has become one step closer to being fulfilled.
completed(`bool`):
Whether this constraint has been completely fulfilled by this token being generated.
reset (`bool`):
Whether this constraint has reset its progress by this token being generated.
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
@abstractmethod
def reset(self):
"""
Resets the state of this constraint to its initialization. We would call this in cases where the fulfillment of
a constraint is aborted by an unwanted token.
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
@abstractmethod
def remaining(self):
"""
Returns the number of remaining steps of `advance()` in order to complete this constraint.
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
@abstractmethod
def copy(self, stateful=False):
"""
Creates a new instance of this constraint.
Args:
stateful(`bool`): Whether to not only copy the constraint for new instance, but also its state.
Return:
constraint(`Constraint`): The same constraint as the one being called from.
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
class PhrasalConstraint(Constraint):
r"""
[`Constraint`] enforcing that an ordered sequence of tokens is included in the output.
Args:
token_ids (`List[int]`):
The id of the token that must be generated by the output.
"""
def __init__(self, token_ids: List[int]):
super(Constraint, self).__init__()
if not isinstance(token_ids, list) or len(token_ids) == 0:
raise ValueError(f"`token_ids` has to be a non-empty list, but is {token_ids}.")
if any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids):
raise ValueError(f"Each list in `token_ids` has to be a list of positive integers, but is {token_ids}.")
self.token_ids = token_ids
self.seqlen = len(self.token_ids)
self.fulfilled_idx = -1
self.completed = False
def advance(self):
if self.completed:
return None
return self.token_ids[self.fulfilled_idx + 1]
def does_advance(self, token_id: int):
if not isinstance(token_id, int):
raise ValueError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")
if self.completed:
return False
return token_id == self.token_ids[self.fulfilled_idx + 1]
def update(self, token_id: int):
if not isinstance(token_id, int):
raise ValueError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")
stepped = False
completed = False
reset = False
if self.does_advance(token_id):
self.fulfilled_idx += 1
stepped = True
if self.fulfilled_idx == (self.seqlen - 1):
completed = True
self.completed = completed
else:
reset = True
self.reset()
return stepped, completed, reset
def reset(self):
self.completed = False
self.fulfilled_idx = 0
def remaining(self):
return self.seqlen - (self.fulfilled_idx + 1)
def copy(self, stateful=False):
new_constraint = PhrasalConstraint(self.token_ids)
if stateful:
new_constraint.seq_len = self.seqlen
new_constraint.fulfilled_idx = self.fulfilled_idx
new_constraint.completed = self.completed
return new_constraint
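A minimal walk through `PhrasalConstraint`'s state machine (the token ids below are arbitrary):

```python
from transformers import PhrasalConstraint

constraint = PhrasalConstraint([5, 9, 2])
print(constraint.advance())    # 5: the next token needed to make progress
print(constraint.update(5))    # (True, False, False): stepped, not completed, not reset
print(constraint.update(7))    # (False, False, True): wrong token, so progress is reset
print(constraint.remaining())  # number of advance() steps still required
```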
class DisjunctiveTrie:
def __init__(self, nested_token_ids: List[List[int]], no_subsets=True):
r"""
A helper class that builds a trie with the words represented in `nested_token_ids`.
"""
self.max_height = max([len(one) for one in nested_token_ids])
root = {}
for token_ids in nested_token_ids:
level = root
for tidx, token_id in enumerate(token_ids):
if token_id not in level:
level[token_id] = {}
level = level[token_id]
if no_subsets and self.has_subsets(root, nested_token_ids):
raise ValueError(
"Each list in `nested_token_ids` can't be a complete subset of another list, but is"
f" {nested_token_ids}."
)
self.trie = root
def next_tokens(self, current_seq):
"""
The next possible tokens that will progress the trie, given the current sequence of tokens in `current_seq`.
"""
start = self.trie
for current_token in current_seq:
start = start[current_token]
next_tokens = list(start.keys())
return next_tokens
def reached_leaf(self, current_seq):
next_tokens = self.next_tokens(current_seq)
return len(next_tokens) == 0
def count_leaves(self, root):
next_nodes = list(root.values())
if len(next_nodes) == 0:
return 1
else:
return sum([self.count_leaves(nn) for nn in next_nodes])
def has_subsets(self, trie, nested_token_ids):
"""
Returns whether # of leaves == # of words. Otherwise some word is a subset of another.
"""
leaf_count = self.count_leaves(trie)
return len(nested_token_ids) != leaf_count
class DisjunctiveConstraint(Constraint):
r"""
A special [`Constraint`] that is fulfilled by fulfilling just one of several constraints.
Args:
nested_token_ids (`List[List[int]]`):
A list of words, where each word is a list of ids. This constraint is fulfilled by generating just one from
the list of words.
"""
def __init__(self, nested_token_ids: List[List[int]]):
super(Constraint, self).__init__()
if not isinstance(nested_token_ids, list) or len(nested_token_ids) == 0:
raise ValueError(f"`nested_token_ids` has to be a non-empty list, but is {nested_token_ids}.")
if any(not isinstance(token_ids, list) for token_ids in nested_token_ids):
raise ValueError(f"`nested_token_ids` has to be a list of lists, but is {nested_token_ids}.")
if any(
any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
for token_ids in nested_token_ids
):
raise ValueError(
f"Each list in `nested_token_ids` has to be a list of positive integers, but is {nested_token_ids}."
)
self.trie = DisjunctiveTrie(nested_token_ids)
self.token_ids = nested_token_ids
self.seqlen = self.trie.max_height
self.current_seq = []
self.completed = False
def advance(self):
token_list = self.trie.next_tokens(self.current_seq)
if len(token_list) == 0:
return None
else:
return token_list
def does_advance(self, token_id: int):
if not isinstance(token_id, int):
raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")
next_tokens = self.trie.next_tokens(self.current_seq)
return token_id in next_tokens
def update(self, token_id: int):
if not isinstance(token_id, int):
raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")
stepped = False
completed = False
reset = False
if self.does_advance(token_id):
self.current_seq.append(token_id)
stepped = True
else:
reset = True
self.reset()
completed = self.trie.reached_leaf(self.current_seq)
self.completed = completed
return stepped, completed, reset
def reset(self):
self.completed = False
self.current_seq = []
def remaining(self):
if self.completed:
return 0
else:
return self.seqlen - len(self.current_seq)
def copy(self, stateful=False):
new_constraint = DisjunctiveConstraint(self.token_ids)
if stateful:
new_constraint.seq_len = self.seqlen
new_constraint.current_seq = self.current_seq
new_constraint.completed = self.completed
return new_constraint
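A sketch of `DisjunctiveConstraint`: it is satisfied by completing any single one of the alternative words (the ids below are arbitrary, and neither alternative is a subset of the other):

```python
from transformers import DisjunctiveConstraint

constraint = DisjunctiveConstraint([[1, 2, 3], [1, 4]])
print(constraint.advance())   # [1]: both alternatives start with token 1
constraint.update(1)
print(constraint.advance())   # [2, 4]: either branch of the trie can be followed next
stepped, completed, reset = constraint.update(4)
print(completed)              # True: the word [1, 4] has been fully generated
```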
class ConstraintListState:
r"""
A class for beam scorers to track its progress through a list of constraints.
Args:
constraints (`List[Constraint]`):
A list of [`Constraint`] objects that must be fulfilled by the beam scorer.
"""
def __init__(self, constraints: List[Constraint]):
self.constraints = constraints
self.max_seqlen = max([c.seqlen for c in constraints])
self.n_constraints = len(constraints)
self.completed = False
self.init_state()
def init_state(self):
self.complete_constraints = []
self.inprogress_constraint = None
self.pending_constraints = [constraint.copy(stateful=False) for constraint in self.constraints]
def get_bank(self):
add = 0
if self.inprogress_constraint:
add += self.max_seqlen - self.inprogress_constraint.remaining()
return (len(self.complete_constraints) * self.max_seqlen) + add
def advance(self):
"""The list of tokens to generate such that we can make progress.
By "list" we don't mean the list of token that will fully fulfill a constraint.
Given constraints `c_i = {t_ij | j == # of tokens}`, If we're not in the middle of progressing through a
specific constraint `c_i`, we return:
`[t_k1 for k in indices of unfulfilled constraints]`
If we are in the middle of a constraint, then we return:
`[t_ij]`, where `i` is the index of the inprogress constraint, `j` is the next step for the constraint.
Though we don't care which constraint is fulfilled first, if we are in the progress of fulfilling a constraint,
that's the only one we'll return.
"""
token_list = []
if self.inprogress_constraint is None:
for constraint in self.pending_constraints:
advance = constraint.advance()
if isinstance(advance, int):
token_list.append(advance)
elif isinstance(advance, list):
token_list.extend(advance)
else:
advance = self.inprogress_constraint.advance()
if isinstance(advance, int):
token_list.append(advance)
elif isinstance(advance, list):
token_list.extend(advance)
if len(token_list) == 0:
return None
else:
return token_list
def reset(self, token_ids: Optional[List[int]]):
"""
Resets the state of this object, re-establishing progress through the constraints according to the given token_ids.
token_ids: the tokens generated so far, used to reset the state of progress through the constraints.
"""
self.init_state()
if token_ids is not None:
for token in token_ids:
complete, stepped = self.add(token)
if self.completed:
break
def copy(self, stateful=True):
"""
Creates and returns a copy of the current object, optionally preserving its state.
stateful: whether to keep the state; defaults to True.
"""
new_state = ConstraintListState(self.constraints)
if stateful:
new_state.complete_constraints = [
constraint.copy(stateful=True) for constraint in self.complete_constraints
]
if self.inprogress_constraint is not None:
new_state.inprogress_constraint = self.inprogress_constraint.copy(stateful=True)
new_state.pending_constraints = [constraint.copy() for constraint in self.pending_constraints]
return new_state
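Finally, a sketch of `ConstraintListState` coordinating several constraints the way a beam scorer would (token ids are arbitrary):

```python
from transformers import PhrasalConstraint
from transformers.generation.beam_constraints import ConstraintListState

state = ConstraintListState([PhrasalConstraint([7, 8]), PhrasalConstraint([3])])
print(state.advance())   # [7, 3]: the first token of each still-pending constraint
print(state.get_bank())  # 0: no progress has been banked yet
```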
.\generation\beam_search.py
from abc import ABC, abstractmethod
from collections import UserDict
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from ..utils import add_start_docstrings
from .beam_constraints import Constraint, ConstraintListState
PROCESS_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
Current scores of the top `2 * num_beams` non-finished beam hypotheses.
next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
`input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
beam_indices (`torch.LongTensor`, *optional*):
Beam indices indicating to which beam hypothesis each token correspond.
group_index (`int`, *optional*):
The index of the group of beams. Used with [`~PreTrainedModel.group_beam_search`].
Return:
`UserDict`: A dictionary composed of the fields as defined above:
- **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of all
non-finished beams.
- **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be added
to the non-finished beam_hypotheses.
- **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices
indicating to which beam the next tokens shall be added.
FINALIZE_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
final_beam_scores (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
The final scores of all non-finished beams.
final_beam_tokens (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
The last tokens to be added to the non-finished beam_hypotheses.
final_beam_indices (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
The beam indices indicating to which beam the `final_beam_tokens` shall be added.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
Return:
`torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences.
The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early
due to the `eos_token_id`.
"""
class BeamScorer(ABC):
"""
Abstract base class for all beam scorers that are used for [`~PreTrainedModel.beam_search`] and
[`~PreTrainedModel.beam_sample`].
"""
@abstractmethod
@add_start_docstrings(PROCESS_INPUTS_DOCSTRING) # 添加输入处理方法的文档字符串
def process(
self,
input_ids: torch.LongTensor,
next_scores: torch.FloatTensor,
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
**kwargs,
) -> Tuple[torch.Tensor]:
raise NotImplementedError("This is an abstract method.")
@abstractmethod
@add_start_docstrings(FINALIZE_INPUTS_DOCSTRING) # 添加最终处理方法的文档字符串
def finalize(
self,
input_ids: torch.LongTensor,
next_scores: torch.FloatTensor,
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
max_length: int,
**kwargs,
) -> torch.LongTensor:
raise NotImplementedError("This is an abstract method.")
class BeamSearchScorer(BeamScorer):
r"""
[`BeamScorer`] implementing standard beam search decoding.
Args:
batch_size (`int`):
Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
num_beams (`int`):
Number of beams for beam search.
device (`torch.device`):
Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
allocated.
length_penalty (`float`, *optional*, defaults to 1.0):
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
`length_penalty` < 0.0 encourages shorter sequences.
do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
`True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
heuristic is applied and the generation stops when is it very unlikely to find better candidates;
`"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
beam search algorithm).
num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
The number of beam hypotheses that shall be returned upon calling
[`~transformers.BeamSearchScorer.finalize`].
num_beam_groups (`int`, *optional*, defaults to 1):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
max_length (`int`, *optional*):
The maximum length of the sequence to be generated.
"""
def __init__(
self,
batch_size: int,
num_beams: int,
device: torch.device,
length_penalty: Optional[float] = 1.0,
do_early_stopping: Optional[Union[bool, str]] = False,
num_beam_hyps_to_keep: Optional[int] = 1,
num_beam_groups: Optional[int] = 1,
max_length: Optional[int] = None,
):
self.num_beams = num_beams
self.device = device
self.length_penalty = length_penalty
self.do_early_stopping = do_early_stopping
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
self.num_beam_groups = num_beam_groups
self.group_size = self.num_beams // self.num_beam_groups
self._is_init = False
self._beam_hyps = [
BeamHypotheses(
num_beams=self.group_size,
length_penalty=self.length_penalty,
early_stopping=self.do_early_stopping,
max_length=max_length,
)
for _ in range(batch_size * self.num_beam_groups)
]
self._done = torch.tensor(
[False for _ in range(batch_size * self.num_beam_groups)], dtype=torch.bool, device=self.device
)
if not isinstance(num_beams, int) or num_beams <= 1:
raise ValueError(
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
" one should make use of `greedy_search` instead."
)
if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
raise ValueError(
"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
)
@property
def is_done(self) -> bool:
return self._done.all()
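As a shape reference for the `process` inputs documented above, the sketch below builds a `BeamSearchScorer` with made-up sizes and random candidate tensors; it is only a check of the expected shapes, not a full generation loop:
```python
# Shape sketch (hypothetical sizes): what BeamSearchScorer.process expects and returns.
import torch
from transformers import BeamSearchScorer

batch_size, num_beams, vocab_size = 2, 3, 50
scorer = BeamSearchScorer(batch_size=batch_size, num_beams=num_beams, device=torch.device("cpu"))

input_ids = torch.zeros((batch_size * num_beams, 1), dtype=torch.long)       # current beam sequences
next_scores = torch.randn(batch_size, 2 * num_beams)                         # scores of the top 2*num_beams candidates
next_tokens = torch.randint(2, vocab_size, (batch_size, 2 * num_beams))      # their token ids (kept away from eos)
next_indices = torch.randint(0, num_beams, (batch_size, 2 * num_beams))      # beam each candidate came from

out = scorer.process(input_ids, next_scores, next_tokens, next_indices, pad_token_id=0, eos_token_id=1)
print(out["next_beam_scores"].shape)   # torch.Size([6]) == (batch_size * num_beams,)
```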
def process(
self,
input_ids: torch.LongTensor,
next_scores: torch.FloatTensor,
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
group_index: Optional[int] = 0,
decoder_prompt_len: Optional[int] = 0,
):
...
def finalize(
self,
input_ids: torch.LongTensor,
final_beam_scores: torch.FloatTensor,
final_beam_tokens: torch.LongTensor,
final_beam_indices: torch.LongTensor,
max_length: int,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
):
...
class ConstrainedBeamSearchScorer(BeamScorer):
r"""
[`BeamScorer`] implementing constrained beam search decoding.
Args:
batch_size (`int`):
Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
num_beams (`int`):
Number of beams for beam search.
constraints (`List[Constraint]`):
A list of positive constraints represented as `Constraint` objects that must be fulfilled in the generation
output. For more information, the documentation of [`Constraint`] should be read.
device (`torch.device`):
Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
allocated.
length_penalty (`float`, *optional*, defaults to 1.0):
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
`length_penalty` < 0.0 encourages shorter sequences.
do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
`True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
heuristic is applied and the generation stops when is it very unlikely to find better candidates;
`"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
beam search algorithm).
num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
The number of beam hypotheses that shall be returned upon calling
[`~transformers.BeamSearchScorer.finalize`].
num_beam_groups (`int`, *optional*, defaults to 1):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
max_length (`int`, *optional*):
The maximum length of the sequence to be generated.
"""
def __init__(
self,
batch_size: int,
num_beams: int,
constraints: List[Constraint],
device: torch.device,
length_penalty: Optional[float] = 1.0,
do_early_stopping: Optional[Union[bool, str]] = False,
num_beam_hyps_to_keep: Optional[int] = 1,
num_beam_groups: Optional[int] = 1,
max_length: Optional[int] = None,
):
self.num_beams = num_beams
self.device = device
self.length_penalty = length_penalty
self.do_early_stopping = do_early_stopping
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
self.num_beam_groups = num_beam_groups
self.group_size = self.num_beams // self.num_beam_groups
self.constraints = constraints
self._is_init = False
self._beam_hyps = [
BeamHypotheses(
num_beams=self.num_beams,
length_penalty=self.length_penalty,
early_stopping=self.do_early_stopping,
max_length=max_length,
)
for _ in range(batch_size)
]
self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
if not isinstance(num_beams, int) or num_beams <= 1:
raise ValueError(
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
" one should make use of `greedy_search` instead."
)
if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
raise ValueError(
"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
)
@property
def is_done(self) -> bool:
return self._done.all()
def make_constraint_states(self, n):
return [ConstraintListState([constraint.copy() for constraint in self.constraints]) for _ in range(n)]
def check_completes_constraints(self, sequence):
new_state = self.make_constraint_states(1)[0]
new_state.reset(sequence)
return new_state.completed
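A small sketch (hypothetical token ids) of what `check_completes_constraints` does: it replays a finished sequence through a fresh `ConstraintListState` and reports whether every constraint was fulfilled:
```python
# Sketch: replaying a candidate output through a fresh constraint state.
from transformers.generation.beam_constraints import ConstraintListState, PhrasalConstraint

constraints = [PhrasalConstraint(token_ids=[11, 12])]
replay = ConstraintListState([c.copy() for c in constraints])
replay.reset([7, 11, 12, 9])   # reset() feeds the tokens one by one through add()
print(replay.completed)        # True -> the phrase "11 12" occurs, so the constraint is satisfied
```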
def process(
self,
input_ids: torch.LongTensor,
next_scores: torch.FloatTensor,
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
scores_for_all_vocab: torch.FloatTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
):
...
def step_sentence_constraint(
self,
batch_idx: int,
input_ids: torch.LongTensor,
vocab_scores: torch.FloatTensor,
sent_beam_scores: torch.FloatTensor,
sent_beam_tokens: torch.LongTensor,
sent_beam_indices: torch.LongTensor,
push_progress: bool = False,
):
...
def finalize(
self,
input_ids: torch.LongTensor,
final_beam_scores: torch.FloatTensor,
final_beam_tokens: torch.LongTensor,
final_beam_indices: torch.LongTensor,
max_length: int,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
):
...
class BeamHypotheses:
def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None):
"""
Initialize n-best list of hypotheses.
Args:
num_beams (int): Beam size, i.e., number of beams to keep.
length_penalty (float): Length penalty to be applied to scores.
early_stopping (bool): Whether to stop generation early based on conditions.
max_length (Optional[int]): Optional maximum length for generated hypotheses.
"""
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.max_length = max_length
self.num_beams = num_beams
self.beams = []
self.worst_score = 1e9
if not isinstance(self.early_stopping, bool) and self.max_length is None:
raise ValueError(
"When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the"
" BeamScorer class instance at initialization time."
)
def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.beams)
def add(
self,
hyp: torch.LongTensor,
sum_logprobs: float,
beam_indices: Optional[torch.LongTensor] = None,
generated_len: Optional[int] = None,
):
"""
Add a new hypothesis to the list.
Args:
hyp (torch.LongTensor): Tensor representing the hypothesis.
sum_logprobs (float): Sum of log probabilities associated with the hypothesis.
beam_indices (Optional[torch.LongTensor]): Optional tensor of beam indices.
generated_len (Optional[int]): Optional length of the generated sequence.
"""
if generated_len is not None:
score = sum_logprobs / (generated_len**self.length_penalty)
else:
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp, beam_indices))
if len(self) > self.num_beams:
sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
del self.beams[sorted_next_scores[0][1]]
self.worst_score = sorted_next_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
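The score used in `add` is the sum of log-probabilities divided by `generated_len ** length_penalty`; a quick worked example with made-up numbers shows how the penalty shifts the ranking towards longer or shorter hypotheses:
```python
# Worked example (made-up values) of the hypothesis score used in BeamHypotheses.add.
sum_logprobs = -6.0   # log-likelihood of a 6-token hypothesis
generated_len = 6
for length_penalty in (0.0, 1.0, 2.0):
    score = sum_logprobs / (generated_len ** length_penalty)
    print(length_penalty, score)   # 0.0 -> -6.0, 1.0 -> -1.0, 2.0 -> -0.1667
# A larger positive length_penalty divides the (negative) log-likelihood by a bigger number,
# so longer hypotheses are penalized less and end up ranked higher.
```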
def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: Optional[int] = 0) -> bool:
"""
If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
if self.early_stopping is True:
return True
elif self.early_stopping is False:
highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
ret = self.worst_score >= highest_attainable_score
return ret
else:
if self.length_penalty > 0.0:
if self.max_length <= decoder_prompt_len:
raise ValueError("max_length is not larger than decoder prompt length")
highest_attainable_score = (
best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
)
else:
highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
ret = self.worst_score >= highest_attainable_score
return ret
.\generation\candidate_generator.py
import copy
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import torch
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from .configuration_utils import GenerationConfig
from .logits_process import LogitsProcessorList
class CandidateGenerator:
"""所有候选生成器的抽象基类,可在辅助生成过程中应用。"""
def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
获取当前输入的候选生成序列。
Args:
input_ids (`torch.LongTensor`,形状为 `(batch_size, sequence_length)`):
输入序列标记在词汇表中的索引。[什么是输入ID?](../glossary#input-ids)
Return:
`torch.LongTensor`,形状为 `(batch_size, candidate_length)`,包含模型评估的候选序列,
以及一个可选的 `torch.FloatTensor`,形状为 `(batch_size, candidate_length, vocabulary_size)`,
包含与每个候选相关的logits。
"""
raise NotImplementedError(
f"{self.__class__} 是一个抽象类。只有继承此类的类才能调用 `get_candidates`。"
)
def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
"""
根据结果更新候选生成策略。
Args:
input_ids (`torch.LongTensor`,形状为 `(batch_size, sequence_length)`):
输入序列标记在词汇表中的索引。[什么是输入ID?](../glossary#input-ids)
scores (`torch.FloatTensor`,形状为 `(batch_size, candidate_length, config.vocab_size)`):
语言建模头部的预测分数。当不使用beam搜索时,这些可以是每个词汇的logits,或者在使用beam搜索时,每个词汇token的log softmax。
num_matches (`int`):
候选序列与模型预测之间的匹配数。
"""
raise NotImplementedError(
f"{self.__class__} 是一个抽象类。只有继承此类的类才能调用 `update_candidate_strategy`。"
)
"""
`CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
candidates through the use of a smaller model. Read the following blog post for more information:
https://huggingface.co/blog/assisted-generation
"""
def __init__(
self,
input_ids: torch.LongTensor,
assistant_model: "PreTrainedModel",
generation_config: "GenerationConfig",
logits_processor: "LogitsProcessorList",
model_kwargs: Dict,
inputs_tensor: Optional[torch.Tensor] = None,
):
"""
Initialize the `AssistedCandidateGenerator` with necessary parameters.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
assistant_model (`PreTrainedModel`):
The model used for generating candidates, which is smaller than the main model.
generation_config (`~generation.GenerationConfig`, *optional*):
Configuration for the generation process.
logits_processor (`LogitsProcessorList`):
List of processors to modify prediction scores of the language modeling head during generation.
model_kwargs (`Dict`):
Keyword arguments passed to the main model and the assistant model.
inputs_tensor (`torch.Tensor`, *optional*):
The input tensor for the model, typically the encoder input in encoder-decoder models.
"""
super().__init__(input_ids, assistant_model, generation_config, logits_processor, model_kwargs)
self.inputs_tensor = inputs_tensor
def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
Fetches the candidates to be tried for the current input.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
Return:
`torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
vocabulary_size)` containing the logits associated to each candidate.
"""
input_ids = input_ids.to(self.assistant_model.device)
new_cur_len = input_ids.shape[-1]
max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
if max_new_tokens == 0:
return input_ids, None
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
if has_past_key_values:
new_cache_size = new_cur_len - 1
self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
)
self.assistant_kwargs = _prepare_attention_mask(
self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
)
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)
assistant_generation_kwargs = {
self.input_ids_key: input_ids,
"max_new_tokens": max_new_tokens,
"generation_config": self.generation_config,
"logits_processor": self.logits_processor,
}
assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
candidate_logits = torch.stack(assistant_output.scores, dim=1)
candidate_ids = assistant_output.sequences
return candidate_ids, candidate_logits
def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
"""
Updates the candidate generation strategy based on the outcomes.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
beam search or log softmax for each vocabulary token when using beam search
num_matches (`int`):
The number of matches between the candidate sequences and the model predictions.
"""
if self.assistant_model.generation_config.num_assistant_tokens_schedule in {
"heuristic",
"heuristic_transient",
}:
if num_matches == int(self.num_assistant_tokens):
self.num_assistant_tokens += 2.0
else:
self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)
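A toy simulation (made-up match counts) of the `heuristic` schedule above: the assistant window grows by 2 whenever every candidate token was accepted and shrinks by 1 otherwise, never dropping below 1:
```python
# Toy simulation of the num_assistant_tokens schedule (values are illustrative).
num_assistant_tokens = 5.0
for num_matches in (5, 5, 2, 0, 4):
    if num_matches == int(num_assistant_tokens):
        num_assistant_tokens += 2.0
    else:
        num_assistant_tokens = max(1.0, num_assistant_tokens - 1.0)
    print(num_matches, num_assistant_tokens)   # 5 -> 7.0, 5 -> 6.0, 2 -> 5.0, 0 -> 4.0, 4 -> 6.0
```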
class PromptLookupCandidateGenerator(CandidateGenerator):
"""
`CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up
likely continuations in the provided prompt (input_ids) itself.
Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding
"""
def __init__(
self,
num_output_tokens: int = 10,
max_matching_ngram_size: int = None,
):
self.num_output_tokens = num_output_tokens
self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
Fetches the candidates to be tried for the current input.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
Return:
`torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
"""
input_length = input_ids.size(1)
chosen_ids = None
match_found = False
for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
ngram_tensor = input_ids[0, -ngram_size:]
matches = (windows == ngram_tensor).all(dim=2)
match_indices = matches.nonzero(as_tuple=True)[1]
for idx in match_indices:
start_idx = idx + ngram_size
end_idx = start_idx + self.num_output_tokens
end_idx = min(end_idx, input_length)
if start_idx < end_idx:
chosen_ids = input_ids[0, start_idx:end_idx]
match_found = True
break
if match_found:
break
if chosen_ids is None or len(chosen_ids) == 0:
return input_ids, None
chosen_ids = chosen_ids.unsqueeze(0)
candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1)
return candidate_input_ids, None
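A toy run (hypothetical token ids) of prompt lookup: the trailing n-gram of the prompt is searched for earlier in the prompt, and the tokens that followed that earlier occurrence are proposed as the candidate continuation. This assumes the constructor shown above and that the class is importable from `transformers.generation.candidate_generator`:
```python
# Toy prompt-lookup example: the trailing 2-gram [3, 4] already occurred at position 2,
# so the tokens that followed it there (5, 6, 7) are appended as candidates.
import torch
from transformers.generation.candidate_generator import PromptLookupCandidateGenerator

input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 3, 4]])
generator = PromptLookupCandidateGenerator(num_output_tokens=3, max_matching_ngram_size=2)
candidates, _ = generator.get_candidates(input_ids)
print(candidates)   # tensor([[1, 2, 3, 4, 5, 6, 7, 3, 4, 5, 6, 7]])
```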
def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
"""
Updates the candidate generation strategy based on the outcomes.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
beam search or log softmax for each vocabulary token when using beam search
num_matches (`int`):
The number of matches between the candidate sequences and the model predictions.
"""
return
def _crop_past_key_values(model, past_key_values, maximum_length):
new_past = []
if model.config.is_encoder_decoder:
for idx in range(len(past_key_values)):
new_past.append(
(
past_key_values[idx][0][:, :, :maximum_length, :],
past_key_values[idx][1][:, :, :maximum_length, :],
past_key_values[idx][2],
past_key_values[idx][3],
)
)
past_key_values = tuple(new_past)
elif "bloom" in model.__class__.__name__.lower() or (
model.config.architectures is not None and "bloom" in model.config.architectures[0].lower()
):
for idx in range(len(past_key_values)):
new_past.append(
(
past_key_values[idx][0][:, :, :maximum_length],
past_key_values[idx][1][:, :maximum_length, :],
)
)
past_key_values = tuple(new_past)
elif "gptbigcode" in model.__class__.__name__.lower() or (
model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
):
if model.config.multi_query:
for idx in range(len(past_key_values)):
past_key_values[idx] = past_key_values[idx][:, :maximum_length, :]
else:
for idx in range(len(past_key_values)):
past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
else:
for idx in range(len(past_key_values)):
new_past.append(
(
past_key_values[idx][0][:, :, :maximum_length, :],
past_key_values[idx][1][:, :, :maximum_length, :],
)
)
past_key_values = tuple(new_past)
return past_key_values
def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]:
"""Expands or crops the model's mask for decoding purposes, to the defined length"""
mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
if mask_key not in model_kwargs:
return model_kwargs
mask = model_kwargs[mask_key]
mask_length_diff = new_length - mask.shape[1]
if mask_length_diff < 0:
model_kwargs[mask_key] = mask[:, :mask_length_diff]
elif mask_length_diff > 0:
model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
return model_kwargs
def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
"""Expands or crops the model's token_type_ids for decoding purposes, to the defined length"""
if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
return model_kwargs
token_type_ids = model_kwargs["token_type_ids"]
final_token_type = token_type_ids[:, -1].unsqueeze(-1)
type_length_diff = new_length - token_type_ids.shape[1]
if type_length_diff < 0:
token_type_ids = token_type_ids[:, :type_length_diff]
elif type_length_diff > 0:
token_type_copies = final_token_type.repeat(1, type_length_diff)
model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1)
return model_kwargs
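A quick check of `_prepare_attention_mask` with a made-up mask (assuming the helper is importable from `transformers.generation.candidate_generator`, the file annotated here): the mask is grown with ones or cropped so that its length matches `new_length`:
```python
# Growing and cropping a decoder mask to a target length.
import torch
from transformers.generation.candidate_generator import _prepare_attention_mask

kwargs = {"attention_mask": torch.ones((1, 4), dtype=torch.long)}
grown = _prepare_attention_mask(dict(kwargs), new_length=6, is_encoder_decoder=False)
cropped = _prepare_attention_mask(dict(kwargs), new_length=3, is_encoder_decoder=False)
print(grown["attention_mask"].shape, cropped["attention_mask"].shape)   # torch.Size([1, 6]) torch.Size([1, 3])
```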
.\generation\configuration_utils.py
""" Generation configuration class and utilities."""
import copy
import json
import os
import warnings
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from .. import __version__
from ..configuration_utils import PretrainedConfig
from ..utils import (
GENERATION_CONFIG_NAME,
ExplicitEnum,
PushToHubMixin,
cached_file,
download_url,
extract_commit_hash,
is_remote_url,
logging,
)
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
logger = logging.get_logger(__name__)
METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
class GenerationMode(ExplicitEnum):
"""
Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method.
"""
CONTRASTIVE_SEARCH = "contrastive_search"
GREEDY_SEARCH = "greedy_search"
SAMPLE = "sample"
ASSISTED_GENERATION = "assisted_generation"
BEAM_SEARCH = "beam_search"
BEAM_SAMPLE = "beam_sample"
CONSTRAINED_BEAM_SEARCH = "constrained_beam_search"
GROUP_BEAM_SEARCH = "group_beam_search"
class GenerationConfig(PushToHubMixin):
r"""
Class that holds a configuration for a generation task. A `generate` call supports several generation methods
(greedy search, contrastive search, multinomial sampling, beam search, constrained beam search, ...); which one is
used is determined by the configuration values, see the [`GenerationMode`] enum above.
"""
# The methods below define hashing, equality comparison, and the string representation of GenerationConfig objects.
# Compute the object's hash from its JSON string representation, ignoring metadata
def __hash__(self):
return hash(self.to_json_string(ignore_metadata=True))
# Two GenerationConfig objects are compared for equality with metadata ignored
def __eq__(self, other):
# If `other` is not a GenerationConfig, they cannot be equal
if not isinstance(other, GenerationConfig):
return False
# Get both JSON strings with metadata stripped
self_without_metadata = self.to_json_string(use_diff=False, ignore_metadata=True)
other_without_metadata = other.to_json_string(use_diff=False, ignore_metadata=True)
# Compare the two JSON strings
return self_without_metadata == other_without_metadata
# String representation: class name plus the JSON string without metadata
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string(ignore_metadata=True)}"
def get_generation_mode(self, assistant_model: Optional["PreTrainedModel"] = None) -> GenerationMode:
"""
Returns the generation mode triggered by the [`GenerationConfig`] instance.
Arg:
assistant_model (`PreTrainedModel`, *optional*):
The assistant model to be used for assisted generation. If set, the generation mode will be
assisted generation.
Returns:
`GenerationMode`: The generation mode triggered by the instance.
"""
# Determine generation mode based on various configuration parameters
if self.constraints is not None or self.force_words_ids is not None:
generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
elif self.num_beams == 1:
if self.do_sample is False:
if (
self.top_k is not None
and self.top_k > 1
and self.penalty_alpha is not None
and self.penalty_alpha > 0
):
generation_mode = GenerationMode.CONTRASTIVE_SEARCH
else:
generation_mode = GenerationMode.GREEDY_SEARCH
else:
generation_mode = GenerationMode.SAMPLE
else:
if self.num_beam_groups > 1:
generation_mode = GenerationMode.GROUP_BEAM_SEARCH
elif self.do_sample is True:
generation_mode = GenerationMode.BEAM_SAMPLE
else:
generation_mode = GenerationMode.BEAM_SEARCH
# Modify generation mode if assistant model is specified for assisted generation
if assistant_model is not None or self.prompt_lookup_num_tokens is not None:
if generation_mode in (GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE):
generation_mode = GenerationMode.ASSISTED_GENERATION
else:
raise ValueError(
"You've set `assistant_model`, which triggers assisted generation. Currently, assisted generate "
"is only supported with Greedy Search and Sample."
)
# Return the determined generation mode
return generation_mode
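A sketch of how a few `GenerationConfig` settings map to a `GenerationMode` through the branching above (illustrative values only):
```python
# Mapping configuration values to generation modes via get_generation_mode().
from transformers import GenerationConfig

print(GenerationConfig().get_generation_mode())                                # GenerationMode.GREEDY_SEARCH
print(GenerationConfig(do_sample=True).get_generation_mode())                  # GenerationMode.SAMPLE
print(GenerationConfig(num_beams=4).get_generation_mode())                     # GenerationMode.BEAM_SEARCH
print(GenerationConfig(num_beams=4, num_beam_groups=2,
                       diversity_penalty=1.0).get_generation_mode())           # GenerationMode.GROUP_BEAM_SEARCH
print(GenerationConfig(top_k=4, penalty_alpha=0.6).get_generation_mode())      # GenerationMode.CONTRASTIVE_SEARCH
```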
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
config_file_name: Optional[Union[str, os.PathLike]] = None,
push_to_hub: bool = False,
**kwargs,
):
"""
Saves the current configuration to the specified directory.
Args:
save_directory (Union[str, os.PathLike]): Directory where the configuration should be saved.
config_file_name (Optional[Union[str, os.PathLike]], *optional*):
Name for the configuration file. If not provided, a default name will be used.
push_to_hub (bool, *optional*):
Whether to push the saved configuration to the model hub (if applicable).
**kwargs:
Additional keyword arguments for future expansion.
"""
@classmethod
def from_pretrained(
cls,
pretrained_model_name: Union[str, os.PathLike],
config_file_name: Optional[Union[str, os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs,
):
"""
Creates an instance of the class from a pretrained model.
Args:
pretrained_model_name (Union[str, os.PathLike]): Name or path of the pretrained model.
config_file_name (Optional[Union[str, os.PathLike]], *optional*):
Name for the configuration file. If not provided, a default name will be used.
cache_dir (Optional[Union[str, os.PathLike]], *optional*):
Directory to cache downloaded files (if applicable).
force_download (bool, *optional*):
Whether to force re-download of the model files, ignoring any cached versions.
local_files_only (bool, *optional*):
Whether to only consider local files as sources for the model, ignoring any remote repositories.
token (Optional[Union[str, bool]], *optional*):
Access token for private model repositories (if applicable).
revision (str, *optional*):
Revision or version of the model to load.
**kwargs:
Additional keyword arguments for future expansion.
Returns:
Instance of the class loaded from the pretrained model.
"""
# Read the given JSON file and parse its contents into a Python dict
@classmethod
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
return json.loads(text)
@classmethod
def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig":
"""
从一个 Python 字典参数实例化一个 GenerationConfig 对象。
Args:
config_dict (`Dict[str, Any]`):
将用于实例化配置对象的字典。
kwargs (`Dict[str, Any]`):
用于初始化配置对象的额外参数。
Returns:
[`GenerationConfig`]: 从这些参数实例化的配置对象。
"""
# 是否返回未使用的关键字参数,默认为 False
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
# 移除内部遥测用的参数,以防止它们出现在 `return_unused_kwargs` 中
kwargs.pop("_from_auto", None)
kwargs.pop("_from_pipeline", None)
# 如果 `_commit_hash` 在 kwargs 中且在 config_dict 中,则更新 `_commit_hash`
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
kwargs["_commit_hash"] = config_dict["_commit_hash"]
# 下面的语句允许通过 kwargs 加载特定于模型的配置,并进行安全检查。
# 参考:https://github.com/huggingface/transformers/pull/21269
config = cls(**{**config_dict, **kwargs})
# 更新配置,并返回未使用的关键字参数
unused_kwargs = config.update(**kwargs)
# 记录生成的配置信息
logger.info(f"Generate config {config}")
if return_unused_kwargs:
return config, unused_kwargs
else:
return config
# Convert the `torch_dtype` entries of a dict (and of any nested dicts) to strings, e.g. `torch.float32` -> `"float32"`
def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
for value in d.values():
if isinstance(value, dict):
self.dict_torch_dtype_to_str(value)
def to_diff_dict(self) -> Dict[str, Any]:
"""
Removes all attributes from config which correspond to the default config attributes for better readability and
serializes to a Python dictionary.
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
"""
# Serialize the current configuration to a dictionary
config_dict = self.to_dict()
# Get the dictionary form of the default configuration
default_config_dict = GenerationConfig().to_dict()
# Dictionary holding only the values that differ from the defaults
serializable_config_dict = {}
# Only serialize values that differ from the default config
for key, value in config_dict.items():
# Keep a key if it is missing from the default config, is the special `transformers_version` field, or has a different value
if key not in default_config_dict or key == "transformers_version" or value != default_config_dict[key]:
serializable_config_dict[key] = value
# Convert torch dtypes in the dictionary to their string representation
self.dict_torch_dtype_to_str(serializable_config_dict)
return serializable_config_dict
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary.
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
# Deep-copy the instance's __dict__ to get a detached dictionary
output = copy.deepcopy(self.__dict__)
# Fields that are ignored at serialization time
if "_commit_hash" in output:
del output["_commit_hash"]
if "_original_object_hash" in output:
del output["_original_object_hash"]
# Record the Transformers version when serializing
output["transformers_version"] = __version__
# Convert torch dtypes in the dictionary to their string representation
self.dict_torch_dtype_to_str(output)
return output
def to_json_string(self, use_diff: bool = True, ignore_metadata: bool = False) -> str:
"""
Serializes this instance to a JSON string.
Args:
use_diff (`bool`, *optional*, defaults to `True`):
If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
is serialized to JSON string.
ignore_metadata (`bool`, *optional*, defaults to `False`):
Whether to ignore the metadata fields present in the instance
Returns:
`str`: String containing all the attributes that make up this configuration instance in JSON format.
"""
# Depending on `use_diff`, serialize either only the differences from the default GenerationConfig() or everything
if use_diff is True:
config_dict = self.to_diff_dict()  # only the attributes that differ from the default config
else:
config_dict = self.to_dict()  # the full configuration dictionary
# If `ignore_metadata` is set, drop the metadata fields from the dictionary
if ignore_metadata:
for metadata_field in METADATA_FIELDS:
config_dict.pop(metadata_field, None)
# Helper that converts all dictionary keys (including nested ones) to strings
def convert_keys_to_string(obj):
if isinstance(obj, dict):
return {str(key): convert_keys_to_string(value) for key, value in obj.items()}
elif isinstance(obj, list):
return [convert_keys_to_string(item) for item in obj]
else:
return obj
# Convert every key in the configuration dictionary to a string
config_dict = convert_keys_to_string(config_dict)
# Dump the dictionary as an indented, key-sorted JSON string with a trailing newline
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
"""
Save this instance to a JSON file.
Args:
json_file_path (`str` or `os.PathLike`):
Path to the JSON file in which this configuration instance's parameters will be saved.
use_diff (`bool`, *optional*, defaults to `True`):
If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
is serialized to JSON file.
"""
# Open the target JSON file and write the instance serialized as a JSON string
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string(use_diff=use_diff))
@classmethod
def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig":
"""
从一个预训练配置 (`PretrainedConfig`) 实例化一个生成配置 (`GenerationConfig`)。
这个函数用于将可能包含生成参数的旧式预训练配置对象转换为独立的生成配置对象。
Args:
model_config (`PretrainedConfig`):
将用于实例化生成配置的模型配置。
Returns:
[`GenerationConfig`]: 从这些参数实例化的配置对象。
"""
# 将模型配置转换为字典
config_dict = model_config.to_dict()
# 移除特定的属性,这些属性不应该用于构建生成配置
config_dict.pop("_from_model_config", None)
# 通过字典创建生成配置对象,确保不返回未使用的关键字参数
config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)
# 特殊情况:某些模型在解码器中设置了生成属性。如果生成配置中仍未设置这些属性,则使用解码器中的值。
for decoder_name in ("decoder", "generator", "text_config"):
if decoder_name in config_dict:
default_generation_config = GenerationConfig()
decoder_config = config_dict[decoder_name]
# 检查生成配置中的每个属性,如果属性在解码器配置中存在且生成配置中未设置,则设置为解码器中的值
for attr in config.to_dict().keys():
if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
setattr(config, attr, decoder_config[attr])
# 计算对象的哈希值,用于检测实例是否已修改
config._original_object_hash = hash(config)
return config
def update(self, **kwargs):
"""
使用 `kwargs` 中的属性更新该类实例的属性,如果属性匹配现有属性,则返回所有未使用的 kwargs。
Args:
kwargs (`Dict[str, Any]`):
尝试更新此类的属性的属性字典。
Returns:
`Dict[str, Any]`: 包含所有未用于更新实例的键值对的字典。
"""
to_remove = []
# 遍历传入的关键字参数
for key, value in kwargs.items():
# 如果类实例具有这个属性,则更新为传入的值,并记录已更新的属性名
if hasattr(self, key):
setattr(self, key, value)
to_remove.append(key)
# 确保更新后的实例仍然有效
self.validate()
# 返回所有未使用的关键字参数,即未更新到类实例的参数
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
return unused_kwargs
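A small sketch of `update` in action: attributes the config knows about are overwritten, anything else is handed back as unused (the flag name below is made up):
```python
# update() overwrites known attributes and returns the rest untouched.
from transformers import GenerationConfig

config = GenerationConfig()
unused = config.update(max_new_tokens=32, not_a_generation_flag=True)
print(config.max_new_tokens)   # 32
print(unused)                  # {'not_a_generation_flag': True}
```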
.\generation\flax_logits_process.py
import inspect
import jax
import jax.lax as lax
import jax.numpy as jnp
from ..utils import add_start_docstrings
from ..utils.logging import get_logger
logger = get_logger(__name__)
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
Args:
input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
scores (`jnp.ndarray` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
search or log softmax for each vocabulary token when using beam search.
kwargs (`Dict[str, Any]`, *optional*):
Additional kwargs that are specific to a logits processor.
Return:
`jnp.ndarray` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
"""
class FlaxLogitsProcessor:
"""用于生成过程中可以应用的所有logits处理器的抽象基类。"""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray:
"""处理logits的Flax方法。"""
raise NotImplementedError(
f"{self.__class__}是一个抽象类。只有继承了这个类的类才能被调用。"
)
class FlaxLogitsWarper:
"""用于使用多项式采样生成过程中可以应用的所有logit变形器的抽象基类。"""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray:
"""变形logits的Flax方法。"""
raise NotImplementedError(
f"{self.__class__}是一个抽象类。只有继承了这个类的类才能被调用。"
)
class FlaxLogitsProcessorList(list):
"""
此类可用于创建[`FlaxLogitsProcessor`]或[`FlaxLogitsWarper`]的列表,以随后处理`scores`输入张量。
此类继承自列表,并添加了一个特定的*__call__*方法来应用每个[`FlaxLogitsProcessor`]或[`FlaxLogitsWarper`]到输入上。
"""
"""
对象方法,根据给定的输入和参数处理逻辑
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int, **kwargs) -> jnp.ndarray:
# Iterate over every processor in the list
for processor in self:
# Inspect the signature of the processor's __call__ method
function_args = inspect.signature(processor.__call__).parameters
# Processors that take more than the three standard arguments need their extra kwargs
if len(function_args) > 3:
# Check that all required extra parameters were passed via kwargs
if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
# Raise a ValueError if any required parameter is missing
raise ValueError(
f"Make sure that all the required parameters: {list(function_args.keys())} for "
f"{processor.__class__} are passed to the logits processor."
)
# Call the processor with the inputs, scores, current length and the extra kwargs
scores = processor(input_ids, scores, cur_len, **kwargs)
else:
# Otherwise call the processor with just the inputs, scores and current length
scores = processor(input_ids, scores, cur_len)
# Return the processed scores
return scores
class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
r"""
[`FlaxLogitsWarper`] for temperature (exponential scaling output probability distribution).
Args:
temperature (`float`):
The value used to module the logits distribution.
"""
def __init__(self, temperature: float):
# 检查温度参数是否为正浮点数,如果不是则抛出异常
if not isinstance(temperature, float) or not (temperature > 0):
raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
self.temperature = temperature
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
# 将得分按温度值缩放,用于温度调节输出概率分布
scores = scores / self.temperature
return scores
class FlaxTopPLogitsWarper(FlaxLogitsWarper):
"""
[`FlaxLogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
Args:
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
filter_value (`float`, *optional*, defaults to -inf):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
# 检查 top_p 是否为介于 0 和 1 之间的浮点数,否则抛出异常
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
# 检查 min_tokens_to_keep 是否为正整数,否则抛出异常
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
self.top_p = top_p
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
# 获取前 k 个最高得分和其对应的索引
topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1])
# 创建一个与 scores 形状相同的数组,填充为 filter_value
mask_scores = jnp.full_like(scores, self.filter_value)
# 计算 softmax 后的累积概率
cumulative_probs = jax.nn.softmax(topk_scores, axis=-1).cumsum(axis=-1)
# 创建用于掩码的布尔数组,仅保留累积概率小于 top_p 的部分
score_mask = cumulative_probs < self.top_p
# 将累积概率大于 top_p 的位置移到 score_mask 中
score_mask = jnp.roll(score_mask, 1)
score_mask |= score_mask.at[:, 0].set(True)
# 至少保留 min_tokens_to_keep 个 token
score_mask = score_mask.at[:, : self.min_tokens_to_keep].set(True)
# 根据 score_mask 选择相应的得分值或者 filter_value
topk_next_scores = jnp.where(score_mask, topk_scores, mask_scores)
# 按照 topk_indices 排序,获取排序后的最终得分
next_scores = jax.lax.sort_key_val(topk_indices, topk_next_scores)[-1]
return next_scores
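A toy example (made-up probabilities, assuming `jax`/`flax` are installed) of the top-p filtering above: tokens outside the smallest set whose cumulative probability reaches `top_p` are pushed to `-inf`:
```python
# Top-p filtering on a 4-token vocabulary: with top_p=0.8 only the two most likely tokens survive.
import jax.numpy as jnp
from transformers.generation.flax_logits_process import FlaxTopPLogitsWarper

warper = FlaxTopPLogitsWarper(top_p=0.8)
scores = jnp.log(jnp.array([[0.5, 0.3, 0.1, 0.1]]))        # logits written as log-probabilities
filtered = warper(input_ids=None, scores=scores, cur_len=1)
print(jnp.exp(filtered))   # roughly [[0.5, 0.3, 0.0, 0.0]]
```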
class FlaxTopKLogitsWarper(FlaxLogitsWarper):
r"""
[`FlaxLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
Args:
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
filter_value (`float`, *optional*, defaults to -inf):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
# Defines the top-k filtering operation: keep only the highest-probability vocabulary tokens.
def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
# 初始化方法,设置Top-K值,并确保不小于最小保留标记数
self.top_k = max(top_k, min_tokens_to_keep)
self.filter_value = filter_value
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
# 调用实例时,执行Top-K筛选操作
# 获取输入的批次大小和词汇表大小
batch_size, vocab_size = scores.shape
# 初始化一个数组,用来存储被过滤后的分数值,默认为filter_value
next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value)
# 确定实际的Top-K值,避免超过分数数组的长度
topk = min(self.top_k, scores.shape[-1])
# 使用JAX库中的top_k函数找到每个批次中前Top-K个分数及其对应的索引
topk_scores, topk_indices = lax.top_k(scores, topk)
# 计算扁平化后的索引偏移,以便在一维数组中正确设置Top-K分数
shift = jnp.broadcast_to((jnp.arange(batch_size) * vocab_size)[:, None], (batch_size, topk)).flatten()
topk_scores_flat = topk_scores.flatten()
topk_indices_flat = topk_indices.flatten() + shift
# 在next_scores_flat数组中设置Top-K分数值
next_scores_flat = next_scores_flat.at[topk_indices_flat].set(topk_scores_flat)
# 将扁平化后的数组重新形状为(batch_size, vocab_size),得到最终的Top-K分数数组
next_scores = next_scores_flat.reshape(batch_size, vocab_size)
return next_scores
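The same kind of toy check for top-k filtering: only the `top_k` highest logits keep their values, everything else is replaced by the filter value:
```python
# Top-k filtering on a 4-token vocabulary with top_k=2.
import jax.numpy as jnp
from transformers.generation.flax_logits_process import FlaxTopKLogitsWarper

warper = FlaxTopKLogitsWarper(top_k=2)
scores = jnp.array([[1.0, 4.0, 2.0, 3.0]])
print(warper(input_ids=None, scores=scores, cur_len=1))   # [[-inf  4. -inf  3.]]
```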
class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor):
r"""
[`FlaxLogitsProcessor`] that enforces the specified token as the first generated token.
Args:
bos_token_id (`int`):
The id of the token to force as the first generated token.
"""
def __init__(self, bos_token_id: int):
self.bos_token_id = bos_token_id # 初始化函数,保存要强制作为第一个生成token的token id
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
new_scores = jnp.full(scores.shape, -float("inf")) # 创建一个形状与scores相同的全负无穷数组
apply_penalty = 1 - jnp.bool_(cur_len - 1) # 根据当前生成长度是否为0,决定是否应用惩罚
scores = jnp.where(apply_penalty, new_scores.at[:, self.bos_token_id].set(0), scores)
# 根据apply_penalty条件,将scores中对应bos_token_id列的值设置为0,其它位置不变
return scores
class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor):
r"""
[`FlaxLogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached.
Args:
max_length (`int`):
The maximum length of the sequence to be generated.
eos_token_id (`int`):
The id of the token to force as the last generated token when `max_length` is reached.
"""
def __init__(self, max_length: int, eos_token_id: int):
self.max_length = max_length # 初始化函数,保存最大生成长度
self.eos_token_id = eos_token_id # 初始化函数,保存要强制作为末尾生成token的token id
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
new_scores = jnp.full(scores.shape, -float("inf")) # 创建一个形状与scores相同的全负无穷数组
apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1)
# 根据当前生成长度是否为max_length,决定是否应用惩罚
scores = jnp.where(apply_penalty, new_scores.at[:, self.eos_token_id].set(0), scores)
# 根据apply_penalty条件,将scores中对应eos_token_id列的值设置为0,其它位置不变
return scores
class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
r"""
[`FlaxLogitsProcessor`] enforcing a min-length by setting EOS probability to 0.
Args:
min_length (`int`):
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`int`):
The id of the *end-of-sequence* token.
"""
def __init__(self, min_length: int, eos_token_id: int):
if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
if not isinstance(eos_token_id, int) or eos_token_id < 0:
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
self.min_length = min_length # 初始化函数,保存最小生成长度
self.eos_token_id = eos_token_id # 初始化函数,保存要设置其概率为负无穷的token id
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
# create boolean flag to decide if min length penalty should be applied
apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)
# 根据当前生成长度是否小于min_length,决定是否应用惩罚
scores = jnp.where(apply_penalty, scores.at[:, self.eos_token_id].set(-float("inf")), scores)
# 根据apply_penalty条件,将scores中对应eos_token_id列的值设置为负无穷,其它位置不变
return scores
class FlaxSuppressTokensAtBeginLogitsProcessor(FlaxLogitsProcessor):
r"""
Args:
begin_suppress_tokens (`List[int]`):
不抽样的 token 列表。
begin_index (`int`):
开始抑制 token 的索引位置。
"""
class FlaxLogitsProcessor:
def __init__(self, begin_suppress_tokens, begin_index):
# 将输入的 begin_suppress_tokens 转换为列表
self.begin_suppress_tokens = list(begin_suppress_tokens)
# 设置开始抑制 token 的索引位置
self.begin_index = begin_index
def __call__(self, input_ids, scores, cur_len: int):
# 根据当前生成长度 `cur_len` 和开始抑制的索引 `begin_index` 计算是否应用惩罚
apply_penalty = 1 - jnp.bool_(cur_len - self.begin_index)
# 根据应用的惩罚,将指定的 `begin_suppress_tokens` 的分数设置为负无穷大
scores = jnp.where(apply_penalty, scores.at[:, self.begin_suppress_tokens].set(-float("inf")), scores)
# 返回处理后的分数
return scores
class FlaxSuppressTokensLogitsProcessor(FlaxLogitsProcessor):
r"""
[`FlaxLogitsProcessor`] suppressing a list of tokens at each decoding step. The processor will set their log probs
to be `-inf` so they are not sampled.
Args:
suppress_tokens (`list`):
Tokens to not sample.
"""
def __init__(self, suppress_tokens: list):
# 初始化方法,接收一个要抑制的token列表
self.suppress_tokens = list(suppress_tokens)
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
# 在scores张量的指定位置设置为负无穷,以便在采样时不被选中
scores = scores.at[..., self.suppress_tokens].set(-float("inf"))
return scores
class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor):
r"""
[`FlaxLogitsProcessor`] that takes a list of pairs of integers which indicates a mapping from generation indices to
token indices that will be forced before sampling. The processor will set their log probs to 0 and all other tokens
to `-inf` so that they are sampled at their corresponding index.
Args:
force_token_map (`list`):
Map giving token ids and indices where they will be forced to be sampled.
"""
def __init__(self, force_token_map):
# 将force_token_map转换为字典格式,并初始化一个强制token的数组以提高XLA的兼容性
force_token_map = dict(force_token_map)
force_token_array = jnp.ones((max(force_token_map.keys()) + 1), dtype=jnp.int32) * -1
for index, token in force_token_map.items():
if token is not None:
force_token_array = force_token_array.at[index].set(token)
self.force_token_array = jnp.int32(force_token_array)
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
def _force_token(generation_idx):
# 根据generation_idx确定要强制采样的token,并更新scores张量
batch_size = scores.shape[0]
current_token = self.force_token_array[generation_idx]
new_scores = jnp.ones_like(scores, dtype=scores.dtype) * -float("inf")
updates = jnp.zeros((batch_size, 1), dtype=scores.dtype)
new_scores = lax.dynamic_update_slice(new_scores, updates, (0, current_token))
return new_scores
# 使用lax.cond根据cur_len的值来决定是否进行token强制操作
scores = lax.cond(
cur_len >= self.force_token_array.shape[0],
# 如果当前长度大于等于force_token_array的长度,则不进行强制操作
lambda: scores,
# 否则,根据force_token_array[cur_len]的值来判断是否强制采样特定token
lambda: lax.cond(
self.force_token_array[cur_len] >= 0,
# 只有有效(非负)的token才会被强制采样
lambda: _force_token(cur_len),
# 否则不进行强制操作
lambda: scores,
),
)
return scores
class FlaxWhisperTimeStampLogitsProcessor(FlaxLogitsProcessor):
r"""
Whisper specific Processor. This processor can be used to force a list of tokens. The processor will set their log
probs to `inf` so that they are sampled at their corresponding index.
Args:
generate_config (`GenerateConfig`):
The generate config used to generate the output. The following parameters are required:
eos_token_id (`int`, *optional*, defaults to 50257):
The id of the *end-of-sequence* token.
no_timestamps_token_id (`int`, *optional*, defaults to 50363):
The id of the `"<|notimestamps|>"` token.
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
predicting timestamps that are too far in the future.
"""
def __init__(self, generate_config, model_config, decoder_input_length):
# 初始化方法,设置对象的初始属性
self.eos_token_id = generate_config.eos_token_id
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
# 设置时间戳开始的位置
self.timestamp_begin = generate_config.no_timestamps_token_id + 1
# 设置开始索引,考虑解码器输入长度
self.begin_index = decoder_input_length + 1
# 如果是多语言模型,为语言标记和任务标记预留空间
if generate_config.is_multilingual:
self.begin_index += 2
# 如果生成配置有最大初始时间戳索引属性,使用该值;否则使用模型词汇表大小
if hasattr(generate_config, "max_initial_timestamp_index"):
self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index
else:
self.max_initial_timestamp_index = model_config.vocab_size
# 如果最大初始时间戳索引为 None,则设为模型词汇表大小
if self.max_initial_timestamp_index is None:
self.max_initial_timestamp_index = model_config.vocab_size
def __call__(self, input_ids, scores, cur_len):
# 将包含 self.no_timestamps_token_id 的列设为负无穷,这由 without_timestamps 处理
scores = scores.at[:, self.no_timestamps_token_id].set(-float("inf"))
def handle_pairs(input_ids_k, scores_k):
# 判断前一个 token 是否为时间戳,如果是,则设置为 True,否则为 False
last_was_timestamp = jnp.where((cur_len - self.begin_index) >= 1, True, False)
last_was_timestamp = jnp.where(
input_ids_k[cur_len - 1] >= self.timestamp_begin,
True and last_was_timestamp,
False,
)
# 判断倒数第二个 token 是否为时间戳,如果是,则设置为 True,否则为 False
penultimate_was_timestamp = jnp.where((cur_len - self.begin_index) < 2, True, False)
penultimate_was_timestamp = jnp.where(
input_ids_k[cur_len - 2] >= self.timestamp_begin,
True,
penultimate_was_timestamp,
)
return jnp.where(
last_was_timestamp,
jnp.where(
penultimate_was_timestamp > 0,
scores_k.at[self.timestamp_begin :].set(-float("inf")), # 如果倒数第二个是时间戳,则将时间戳之后的分数设为负无穷
scores_k.at[: self.eos_token_id].set(-float("inf")), # 否则将句子结束符之前的分数设为负无穷
),
scores_k, # 如果前一个不是时间戳,则保持分数不变
)
# 对每对 (input_ids, scores) 应用 handle_pairs 函数
scores = jax.vmap(handle_pairs)(input_ids, scores)
# 判断是否应用最大初始时间戳策略
apply_max_initial_timestamp = jnp.where(cur_len == self.begin_index, True, False)
apply_max_initial_timestamp = jnp.where(
self.max_initial_timestamp_index is not None,
True and apply_max_initial_timestamp,
False,
)
# 计算最大允许的时间戳
last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
# 如果应用最大初始时间戳策略,则将分数矩阵中大于最大允许时间戳之后的分数设为负无穷
scores = jnp.where(
apply_max_initial_timestamp,
scores.at[:, last_allowed + 1 :].set(-float("inf")),
scores,
)
# 如果时间戳的概率总和超过其它 token 的概率总和,则将时间戳之前的分数设为负无穷
logprobs = jax.nn.log_softmax(scores, axis=-1)
def handle_cumulative_probs(logprobs_k, scores_k):
timestamp_logprob = jax.nn.logsumexp(logprobs_k[self.timestamp_begin :], axis=-1)
max_text_token_logprob = jnp.max(logprobs_k[: self.timestamp_begin])
return jnp.where(
timestamp_logprob > max_text_token_logprob,
scores_k.at[: self.timestamp_begin].set(-float("inf")), # 如果时间戳的概率总和高于其它 token,则将时间戳之前的分数设为负无穷
scores_k, # 否则保持分数不变
)
# 对每个 (logprobs, scores) 应用 handle_cumulative_probs 函数
scores = jax.vmap(handle_cumulative_probs)(logprobs, scores)
# 返回处理后的分数矩阵
return scores
.\generation\flax_utils.py
import copy
import inspect
import warnings
from functools import partial
from typing import Any, Dict, Optional, Union
import flax
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from ..models.auto import (
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from ..utils import ModelOutput, logging
from .configuration_utils import GenerationConfig
from .flax_logits_process import (
FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor,
FlaxForceTokensLogitsProcessor,
FlaxLogitsProcessorList,
FlaxMinLengthLogitsProcessor,
FlaxSuppressTokensAtBeginLogitsProcessor,
FlaxSuppressTokensLogitsProcessor,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
)
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class FlaxGreedySearchOutput(ModelOutput):
"""
Flax Base class for outputs of decoder-only generation models using greedy search.
Args:
sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
The generated sequences.
"""
sequences: jnp.ndarray = None
@flax.struct.dataclass
class FlaxSampleOutput(ModelOutput):
"""
Flax Base class for outputs of decoder-only generation models using sampling.
Args:
sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
The generated sequences.
"""
sequences: jnp.ndarray = None
@flax.struct.dataclass
class FlaxBeamSearchOutput(ModelOutput):
"""
Flax Base class for outputs of decoder-only generation models using beam search.
Args:
sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
The generated sequences.
scores (`jnp.ndarray` of shape `(batch_size,)`):
The scores (log probabilities) of the generated sequences.
"""
sequences: jnp.ndarray = None
scores: jnp.ndarray = None
@flax.struct.dataclass
class GreedyState:
"""
Dataclass to store state during greedy decoding.
Args:
cur_len (`jnp.ndarray`): Current lengths of sequences.
sequences (`jnp.ndarray`): Generated sequences.
running_token (`jnp.ndarray`): Running tokens for decoding.
is_sent_finished (`jnp.ndarray`): Boolean array indicating finished sentences.
model_kwargs (Dict[str, jnp.ndarray]): Additional model arguments.
"""
cur_len: jnp.ndarray
sequences: jnp.ndarray
running_token: jnp.ndarray
is_sent_finished: jnp.ndarray
model_kwargs: Dict[str, jnp.ndarray]
@flax.struct.dataclass
class SampleState:
"""
Dataclass to store state during sampling.
Args:
cur_len (`jnp.ndarray`): Current lengths of sequences.
"""
cur_len: jnp.ndarray
sequences: jnp.ndarray
running_token: jnp.ndarray
is_sent_finished: jnp.ndarray
prng_key: jnp.ndarray
model_kwargs: Dict[str, jnp.ndarray]
@flax.struct.dataclass
class BeamSearchState:
cur_len: jnp.ndarray
running_sequences: jnp.ndarray
running_scores: jnp.ndarray
sequences: jnp.ndarray
scores: jnp.ndarray
is_sent_finished: jnp.ndarray
model_kwargs: Dict[str, jnp.ndarray]
class FlaxGenerationMixin:
"""
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`FlaxPreTrainedModel`].
The class exposes [`~generation.FlaxGenerationMixin.generate`], which can be used for:
- *greedy decoding* by calling [`~generation.FlaxGenerationMixin._greedy_search`] if `num_beams=1` and `do_sample=False`
- *multinomial sampling* by calling [`~generation.FlaxGenerationMixin._sample`] if `num_beams=1` and `do_sample=True`
- *beam-search decoding* by calling [`~generation.FlaxGenerationMixin._beam_search`] if `num_beams>1` and `do_sample=False`
You do not need to call any of the above methods directly; pass custom parameter values to 'generate' instead. To learn
more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
"""
def prepare_inputs_for_generation(self, *args, **kwargs):
raise NotImplementedError(
"A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
)
@staticmethod
def _run_loop_in_debug(cond_fn, body_fn, init_state):
"""
Run the generation loop in non-traced mode. This is only used for debugging purposes.
"""
state = init_state
while cond_fn(state):
state = body_fn(state)
return state
def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, model_kwargs):
encoder_kwargs = {
argument: value
for argument, value in model_kwargs.items()
if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
}
model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs)
return model_kwargs
def _prepare_decoder_input_ids_for_generation(
self,
batch_size: int,
decoder_start_token_id: int = None,
bos_token_id: int = None,
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
) -> jnp.ndarray:
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
if decoder_input_ids is not None:
return decoder_input_ids
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
return jnp.array(decoder_start_token_id, dtype="i4").reshape(1, -1).repeat(batch_size, axis=0)
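A small sketch of what the broadcast above produces, with assumed values for `batch_size` and `decoder_start_token_id`:
```
import jax.numpy as jnp

batch_size, decoder_start_token_id = 3, 2
decoder_input_ids = jnp.array(decoder_start_token_id, dtype="i4").reshape(1, -1).repeat(batch_size, axis=0)
# decoder_input_ids.shape == (3, 1); every row is [2], i.e. one start token per batch entry
```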
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
decoder_start_token_id = (
decoder_start_token_id
if decoder_start_token_id is not None
else self.generation_config.decoder_start_token_id
)
bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
if decoder_start_token_id is not None:
return decoder_start_token_id
elif (
hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "decoder_start_token_id")
and self.config.decoder.decoder_start_token_id is not None
):
return self.config.decoder.decoder_start_token_id
elif bos_token_id is not None:
return bos_token_id
elif (
hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "bos_token_id")
and self.config.decoder.bos_token_id is not None
):
return self.config.decoder.bos_token_id
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
@staticmethod
def _expand_to_num_beams(tensor, num_beams):
return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
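A quick shape check of `_expand_to_num_beams`, using an assumed batch size of 2 and 3 beams:
```
import jax.numpy as jnp

tensor = jnp.arange(10).reshape(2, 5)                        # (batch_size, seq_len)
expanded = jnp.broadcast_to(tensor[:, None], (2, 3) + tensor.shape[1:])
# expanded.shape == (2, 3, 5): each sequence is replicated once per beam
```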
def _adapt_logits_for_beam_search(self, logits):
"""
This function can be overwritten in the specific modeling_flax_<model-name>.py classes to allow for custom beam
search behavior. Note that the only model that overwrites this method is [`~transformers.FlaxMarianMTModel`].
"""
return logits
def _validate_model_class(self):
"""
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
right class to use.
"""
if not self.can_generate():
generate_compatible_mappings = [
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
]
generate_compatible_classes = set()
for model_mapping in generate_compatible_mappings:
supported_models = model_mapping.get(type(self.config), default=None)
if supported_models is not None:
generate_compatible_classes.add(supported_models.__name__)
exception_message = (
f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
"it doesn't have a language model head."
)
if generate_compatible_classes:
exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
raise TypeError(exception_message)
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
unused_model_args = []
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
if "kwargs" in model_args or "model_kwargs" in model_args:
model_args |= set(inspect.signature(self.__call__).parameters)
for key, value in model_kwargs.items():
if value is not None and key not in model_args:
unused_model_args.append(key)
if unused_model_args:
raise ValueError(
f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
" generate arguments will also show up in this list)"
)
def generate(
self,
input_ids: jnp.ndarray,
generation_config: Optional[GenerationConfig] = None,
prng_key: Optional[jnp.ndarray] = None,
trace: bool = True,
params: Optional[Dict[str, jnp.ndarray]] = None,
logits_processor: Optional[FlaxLogitsProcessorList] = None,
**kwargs,
def _get_logits_warper(self, generation_config: GenerationConfig) -> FlaxLogitsProcessorList:
"""
Returns a [`FlaxLogitsProcessorList`] object containing all relevant [`FlaxLogitsWarper`] instances used for multinomial sampling.
"""
warpers = FlaxLogitsProcessorList()
if generation_config.temperature is not None and generation_config.temperature != 1.0:
warpers.append(FlaxTemperatureLogitsWarper(generation_config.temperature))
if generation_config.top_k is not None and generation_config.top_k != 0:
warpers.append(FlaxTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1))
if generation_config.top_p is not None and generation_config.top_p < 1.0:
warpers.append(FlaxTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1))
return warpers
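Assuming Flax is installed, the warpers this method would append for a hypothetical config with `temperature=0.7`, `top_k=50`, `top_p=0.9` are, in order:
```
from transformers import GenerationConfig
from transformers.generation.flax_logits_process import (
    FlaxTemperatureLogitsWarper,
    FlaxTopKLogitsWarper,
    FlaxTopPLogitsWarper,
)

config = GenerationConfig(do_sample=True, temperature=0.7, top_k=50, top_p=0.9)
# Equivalent to what `_get_logits_warper(config)` builds for this config
warpers = [
    FlaxTemperatureLogitsWarper(config.temperature),
    FlaxTopKLogitsWarper(top_k=config.top_k, min_tokens_to_keep=1),
    FlaxTopPLogitsWarper(top_p=config.top_p, min_tokens_to_keep=1),
]
```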
def _get_logits_processor(
self,
generation_config: GenerationConfig,
input_ids_seq_length: int,
logits_processor: Optional[FlaxLogitsProcessorList],
) -> FlaxLogitsProcessorList:
"""
This method returns a [`FlaxLogitsProcessorList`] object containing all relevant
[`FlaxLogitsProcessor`] instances used to modify the scores of the language model head.
"""
processors = FlaxLogitsProcessorList()
if (
generation_config.min_length is not None
and generation_config.eos_token_id is not None
and generation_config.min_length > -1
):
processors.append(
FlaxMinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)
)
if generation_config.forced_bos_token_id is not None:
processors.append(FlaxForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id))
if generation_config.forced_eos_token_id is not None:
processors.append(
FlaxForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
)
if generation_config.suppress_tokens is not None:
processors.append(FlaxSuppressTokensLogitsProcessor(generation_config.suppress_tokens))
if generation_config.begin_suppress_tokens is not None:
begin_index = input_ids_seq_length
begin_index = (
begin_index
if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
else begin_index + 1
)
if generation_config.forced_decoder_ids is not None and len(generation_config.forced_decoder_ids) > 0:
begin_index += generation_config.forced_decoder_ids[-1][0]
processors.append(
FlaxSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
)
if generation_config.forced_decoder_ids is not None:
forced_decoder_ids = [
[input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids
]
processors.append(FlaxForceTokensLogitsProcessor(forced_decoder_ids))
processors = self._merge_criteria_processor_list(processors, logits_processor)
return processors
def _merge_criteria_processor_list(
self,
default_list: FlaxLogitsProcessorList,
custom_list: FlaxLogitsProcessorList,
) -> FlaxLogitsProcessorList:
"""
This method merges a default list of logits processors with a custom list of logits processors.
It returns a combined [`FlaxLogitsProcessorList`] object.
"""
if len(custom_list) == 0:
return default_list
for default in default_list:
for custom in custom_list:
if type(custom) is type(default):
object_type = "logits processor"
raise ValueError(
f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
f" `generate`, but it has already been created with the values {default}. {default} has been"
" created by passing the corresponding arguments to generate or by the model's config default"
f" values. If you just want to change the default values of {object_type} consider passing"
f" them as arguments to `generate` instead of using a custom {object_type}."
)
default_list.extend(custom_list)
return default_list
def _greedy_search(
self,
input_ids: None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
logits_processor: Optional[FlaxLogitsProcessorList] = None,
trace: bool = True,
params: Optional[Dict[str, jnp.ndarray]] = None,
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
def _sample(
self,
input_ids: None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
prng_key: Optional[jnp.ndarray] = None,
logits_processor: Optional[FlaxLogitsProcessorList] = None,
logits_warper: Optional[FlaxLogitsProcessorList] = None,
trace: bool = True,
params: Optional[Dict[str, jnp.ndarray]] = None,
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
def _beam_search(
self,
input_ids: None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
length_penalty: Optional[float] = None,
early_stopping: Optional[Union[bool, str]] = None,
logits_processor: Optional[FlaxLogitsProcessorList] = None,
trace: bool = True,
params: Optional[Dict[str, jnp.ndarray]] = None,
num_return_sequences: Optional[int] = None,
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
.\generation\logits_process.py
import inspect
import math
import warnings
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
from ..utils import add_start_docstrings
from ..utils.logging import get_logger
logger = get_logger(__name__)
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
beam search or log softmax for each vocabulary token when using beam search
Return:
`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
"""
class LogitsProcessor:
"""所有生成过程中可以应用的 logits 处理器的抽象基类。"""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
class LogitsWarper:
"""所有多项式采样生成过程中可以应用的 logits 转换器的抽象基类。"""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
class LogitsProcessorList(list):
"""
This class can be used to create a list of [`LogitsProcessor`] or [`LogitsWarper`] to subsequently process a `scores` input tensor.
It inherits from list and adds a specific *__call__* method to apply each [`LogitsProcessor`] or [`LogitsWarper`] to the inputs.
"""
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
beam search or log softmax for each vocabulary token when using beam search
kwargs (`Dict[str, Any]`, *optional*):
Additional kwargs that are specific to a logits processor.
Return:
`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`:
The processed prediction scores.
"""
for processor in self:
function_args = inspect.signature(processor.__call__).parameters
if len(function_args) > 2:
if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
raise ValueError(
f"Make sure that all the required parameters: {list(function_args.keys())} for "
f"{processor.__class__} are passed to the logits processor."
)
scores = processor(input_ids, scores, **kwargs)
else:
scores = processor(input_ids, scores)
return scores
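A minimal sketch of chaining processors by hand; the processor choices, token ids, and tensor shapes are illustrative assumptions:
```
import torch
from transformers import LogitsProcessorList, MinLengthLogitsProcessor, RepetitionPenaltyLogitsProcessor

processors = LogitsProcessorList(
    [
        MinLengthLogitsProcessor(min_length=10, eos_token_id=2),
        RepetitionPenaltyLogitsProcessor(penalty=1.2),
    ]
)
input_ids = torch.tensor([[5, 7, 7]])   # (batch_size, sequence_length)
scores = torch.randn(1, 100)            # (batch_size, vocab_size)
scores = processors(input_ids, scores)  # each processor is applied in order
```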
class MinLengthLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] enforcing a min-length by setting EOS probability to 0. Note that, for decoder-only models
like most LLMs, the length includes the prompt.
Args:
min_length (`int`):
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
Examples:
```
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
>>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
>>> inputs = tokenizer("A number:", return_tensors="pt")
>>> gen_out = model.generate(**inputs)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
A number: one
>>> # setting `min_length` to a value smaller than the uncontrolled output length has no impact
>>> gen_out = model.generate(**inputs, min_length=3)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
A number: one
>>> # setting a larger `min_length` will force the model to generate beyond its natural ending point, which is not
>>> # necessarily incorrect
>>> gen_out = model.generate(**inputs, min_length=10)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
A number: one thousand, nine hundred and ninety-four
```
"""
def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
self.min_length = min_length
self.eos_token_id = eos_token_id
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len < self.min_length:
for i in self.eos_token_id:
scores[:, i] = -float("inf")
return scores
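A small tensor-level sketch of the masking above, with an assumed `min_length`, `eos_token_id`, and vocabulary size:
```
import torch
from transformers import MinLengthLogitsProcessor

processor = MinLengthLogitsProcessor(min_length=5, eos_token_id=2)
input_ids = torch.tensor([[11, 12, 13]])            # cur_len = 3 < min_length
scores = processor(input_ids, torch.zeros(1, 6))
# scores[0, 2] is now -inf, so the EOS token cannot be selected yet
```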
class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] enforcing a min-length of new tokens by setting EOS (End-Of-Sequence) token probability to 0.
Contrarily to [`MinLengthLogitsProcessor`], this processor ignores the prompt.
Args:
prompt_length_to_skip (`int`):
The length of the input tokens to skip. Not a valid argument when used with `generate`, as it is automatically assigned the input length.
min_new_tokens (`int`):
The minimum *new* token length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
Examples:
```
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
>>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
>>> inputs = tokenizer(["A number:"], return_tensors="pt")
>>> gen_out = model.generate(**inputs)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
A number: one
>>> # setting `min_new_tokens` will force the model to generate beyond its natural ending point, which is not necessarily incorrect
>>> gen_out = model.generate(**inputs, min_new_tokens=2)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
A number: one thousand
```
"""
def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int]]):
for arg_name, arg_value in [
("prompt_length_to_skip", prompt_length_to_skip),
("min_new_tokens", min_new_tokens),
]:
if not isinstance(arg_value, int) or arg_value < 0:
raise ValueError(f"`{arg_name}` 必须是正整数,但其值为 {arg_value}")
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
logger.warning(f"`eos_token_id` 必须是正整数列表,但其值为 {eos_token_id}")
self.prompt_length_to_skip = prompt_length_to_skip
self.min_new_tokens = min_new_tokens
self.eos_token_id = eos_token_id
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
if new_tokens_length < self.min_new_tokens:
for i in self.eos_token_id:
scores[:, i] = -float("inf")
return scores
class TemperatureLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] for temperature (exponential scaling output probability distribution), which effectively means
that it can control the randomness of the predicted tokens. Often used together with [`TopPLogitsWarper`] and
[`TopKLogitsWarper`].
<Tip>
Make sure that `do_sample=True` is included in the `generate` arguments otherwise the temperature value won't have
any effect.
</Tip>
Args:
temperature (`float`):
Strictly positive float value used to modulate the logits distribution. A value smaller than `1` decreases
randomness (and vice versa), with `0` being equivalent to shifting all probability mass to the most likely
token.
Examples:
```
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
>>> set_seed(0) # for reproducibility
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> model.config.pad_token_id = model.config.eos_token_id
>>> inputs = tokenizer(["Hugging Face Company is"], return_tensors="pt")
>>> # With temperature=1.0, the default, we consistently get random outputs due to random sampling.
>>> generate_kwargs = {"max_new_tokens": 10, "do_sample": True, "temperature": 1.0, "num_return_sequences": 2}
>>> outputs = model.generate(**inputs, **generate_kwargs)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
['Hugging Face Company is a joint venture between GEO Group, one of',
'Hugging Face Company is not an exact science – but what we believe does']
>>> # However, with temperature close to 0, it approximates greedy decoding strategies (invariant)
>>> generate_kwargs["temperature"] = 0.0001
>>> outputs = model.generate(**inputs, **generate_kwargs)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
['Hugging Face Company is a company that has been around for over 20 years',
'Hugging Face Company is a company that has been around for over 20 years']
```
"""
def __init__(self, temperature: float):
if not isinstance(temperature, float) or not (temperature > 0):
except_msg = (
f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
"scores will be invalid."
)
if isinstance(temperature, float) and temperature == 0.0:
except_msg += " If you're looking for greedy decoding strategies, set `do_sample=False`."
raise ValueError(except_msg)
self.temperature = temperature
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
scores = scores / self.temperature
return scores
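A quick numeric sketch of the `scores / temperature` scaling above, with assumed logits:
```
import torch

logits = torch.tensor([[2.0, 1.0, 0.0]])
torch.softmax(logits / 0.5, dim=-1)  # ~[0.867, 0.117, 0.016]: sharper, close to greedy
torch.softmax(logits / 2.0, dim=-1)  # ~[0.506, 0.307, 0.186]: flatter, more random
```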
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that prevents the repetition of previous tokens through a penalty. This penalty is applied at
most once per token. Note that, for decoder-only models like most LLMs, the considered tokens include the prompt.
In the original [paper](https://arxiv.org/pdf/1909.05858.pdf), the authors suggest the use of a penalty of around
1.2 to achieve a good balance between truthful generation and lack of repetition. To penalize and reduce
repetition, use `penalty` values above 1.0, where a higher value penalizes more strongly. To reward and encourage
repetition, use `penalty` values between 0.0 and 1.0, where a lower value rewards more strongly.
Args:
penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated
tokens. Between 0.0 and 1.0 rewards previously generated tokens.
Examples:
```
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> # Initializing the model and tokenizer for it
>>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
>>> inputs = tokenizer(["I'm not going to"], return_tensors="pt")
>>> # This shows a normal generate without any specific parameters
>>> summary_ids = model.generate(**inputs)
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
I'm not going to be able to do that. I'm going to be able to do that
>>> # This generates a penalty for repeated tokens
>>> penalized_ids = model.generate(**inputs, repetition_penalty=1.1)
>>> print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])
I'm not going to be able to do that. I'll just have to go out and play
```
"""
def __init__(self, penalty: float):
if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
self.penalty = penalty
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
score = torch.gather(scores, 1, input_ids)
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
scores.scatter_(1, input_ids, score)
return scores
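A worked sketch of the gather/where/scatter penalty above, with assumed scores and a penalty of 1.2:
```
import torch

scores = torch.tensor([[2.0, -1.0, 0.5]])
input_ids = torch.tensor([[1]])                 # token 1 has already been generated
penalty = 1.2
score = torch.gather(scores, 1, input_ids)                        # tensor([[-1.0]])
score = torch.where(score < 0, score * penalty, score / penalty)  # tensor([[-1.2]])
scores.scatter_(1, input_ids, score)            # token 1 now scores -1.2, i.e. it is less likely
```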
class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that works similarly to [`RepetitionPenaltyLogitsProcessor`], but with an *inverse* penalty
that is applied to the tokens present in the prompt. In other words, a penalty above 1.0 increases the odds of
selecting tokens that were present in the prompt.
def __init__(self, penalty: float, encoder_input_ids: torch.LongTensor):
# Check that penalty is a float strictly greater than 0, otherwise raise a ValueError
if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
# The effective penalty is the inverse, i.e. 1 divided by penalty
self.penalty = 1 / penalty
# Store the provided encoder_input_ids on the instance
self.encoder_input_ids = encoder_input_ids
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Gather the scores at the positions of the encoder_input_ids
score = torch.gather(scores, 1, self.encoder_input_ids)
# Since self.penalty is the inverse (1 / penalty), negative scores are multiplied and
# non-negative scores are divided by it, which raises the odds of prompt tokens when penalty > 1
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
# Scatter the updated scores back into `scores` at the encoder_input_ids positions
scores.scatter_(1, self.encoder_input_ids, score)
# Return the updated scores
return scores
class TopPLogitsWarper(LogitsWarper):
"""
[`LogitsWarper`] that performs top-p, i.e. restricting to the smallest set of most probable tokens whose cumulative probability reaches `top_p`. Often
used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`].
Args:
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
filter_value (`float`, *optional*, defaults to -inf):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
Examples:
```
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
>>> set_seed(0)
>>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
>>> inputs = tokenizer("A sequence: 1, 2", return_tensors="pt")
>>>
>>> outputs = model.generate(**inputs, do_sample=True)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
A sequence: 1, 2, 0, 2, 2. 2, 2, 2, 2
>>>
>>>
>>> outputs = model.generate(**inputs, do_sample=True, top_p=0.1)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
A sequence: 1, 2, 3, 4, 5, 6, 7, 8, 9
```
"""
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
# Initialize the TopPLogitsWarper with the top-p probability cutoff
top_p = float(top_p)
# Check that top_p lies in the valid range (0, 1), otherwise raise a ValueError
if top_p < 0 or top_p > 1.0:
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
# Check that min_tokens_to_keep is a positive integer, otherwise raise a ValueError
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
# Store the attributes on the instance
self.top_p = top_p
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
# Attach the shared docstring describing the call arguments
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
# Takes the input token IDs and their scores and returns the processed scores
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Sort the scores in ascending order and keep the sorting indices
sorted_logits, sorted_indices = torch.sort(scores, descending=False)
# Softmax the sorted scores and compute the cumulative probabilities
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Mark for removal the low-probability tail whose cumulative probability stays at or below 1 - top_p
sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
# Keep at least min_tokens_to_keep tokens
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
# Scatter the sorted removal mask back to the original token ordering
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
# Replace the scores of the removed tokens with filter_value
scores = scores.masked_fill(indices_to_remove, self.filter_value)
# Return the processed score tensor
return scores
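A worked sketch of the masking above, using assumed probabilities and `top_p=0.8`:
```
import torch

probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
scores = probs.log()
top_p = 0.8
sorted_logits, sorted_indices = torch.sort(scores, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)  # [0.05, 0.20, 0.50, 1.00]
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)       # [True, True, False, False]
# the 0.05 and 0.15 tokens are filtered; the kept set {0.5, 0.3} just reaches top_p
```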
class TopKLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. Often used together
with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`].
Args:
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
filter_value (`float`, *optional*, defaults to -inf):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
Examples:
```
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
>>> set_seed(0)
>>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
>>> inputs = tokenizer("A sequence: A, B, C, D", return_tensors="pt")
>>>
>>> outputs = model.generate(**inputs, do_sample=True)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
A sequence: A, B, C, D, G, H, I. A, M
>>>
>>>
>>> outputs = model.generate(**inputs, do_sample=True, top_k=2)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
A sequence: A, B, C, D, E, F, G, H, I
```
"""
def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
# Validate that `top_k` is a strictly positive integer
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
# Clamp `top_k` to at least `min_tokens_to_keep` and store the filter value `filter_value`
self.top_k = max(top_k, min_tokens_to_keep)
self.filter_value = filter_value
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Make sure `top_k` does not exceed the last dimension of `scores` to avoid out-of-range indexing
top_k = min(self.top_k, scores.size(-1))  # Safety check
# Remove all tokens whose score is lower than the last (k-th) score of the top-k
indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
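A worked sketch of the top-k cutoff above, with assumed scores and `top_k=2`:
```
import torch

scores = torch.tensor([[1.0, 4.0, 2.0, 3.0]])
top_k = 2
kth_score = torch.topk(scores, top_k)[0][..., -1, None]        # 3.0, the k-th best score
indices_to_remove = scores < kth_score                         # [True, False, True, False]
scores = scores.masked_fill(indices_to_remove, -float("Inf"))  # only 4.0 and 3.0 remain
```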
class TypicalLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] that performs typical decoding. Inspired on how humans use language, it prioritizes tokens whose
log probability is close to the entropy of the token probability distribution. This means that the most likely
tokens may be discarded in the process.
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.
def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
mass = float(mass)
if not (mass > 0 and mass < 1):
raise ValueError(f"`typical_p` has to be a float > 0 and < 1, but is {mass}")
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
self.filter_value = filter_value
self.mass = mass
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
p = torch.exp(normalized)
ent = -(normalized * p).nansum(-1, keepdim=True)
shifted_scores = torch.abs((-normalized) - ent)
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
sorted_logits = scores.gather(-1, sorted_indices)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
last_ind = (cumulative_probs < self.mass).sum(dim=1)
last_ind.clamp_(max=sorted_scores.shape[-1] - 1)
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
class EpsilonLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the
largest min_tokens_to_keep tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model
Desmoothing](https://arxiv.org/abs/2210.15191) for more information.
Args:
epsilon (`float`):
If set to > 0, only the tokens with probabilities of `epsilon` or higher are kept for generation.
filter_value (`float`, *optional*, defaults to -inf):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
Examples:
```
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
>>> set_seed(0)
>>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
>>> inputs = tokenizer("A sequence: 1, 2", return_tensors="pt")
>>> # With sampling, the output is unexpected -- sometimes too unexpected.
>>> outputs = model.generate(**inputs, do_sample=True)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
A sequence: 1, 2, 0, 2, 2. 2, 2, 2, 2
>>> # With epsilon sampling, the output gets restricted to high-probability tokens. Note that this is similar to
>>> # Top P sampling, which restricts tokens based on their cumulative probability.
>>> # Pro tip: The paper recommends using `epsilon_cutoff` values between 3e-4 and 9e-4
>>> outputs = model.generate(**inputs, do_sample=True, epsilon_cutoff=0.1)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
A sequence: 1, 2, 3, 4, 5, 6, 7, 8, 9
```
"""
def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
epsilon = float(epsilon)
if epsilon <= 0 or epsilon >= 1:
raise ValueError(f"`epsilon_cutoff` has to be a float > 0 and < 1, but is {epsilon}")
min_tokens_to_keep = int(min_tokens_to_keep)
if min_tokens_to_keep < 1:
raise ValueError(
f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}"
)
self.epsilon = epsilon
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
probabilities = scores.softmax(dim=-1)
indices_to_remove = probabilities < self.epsilon
top_k = min(self.min_tokens_to_keep, scores.size(-1))
indices_to_remove = indices_to_remove & (scores < torch.topk(scores, top_k)[0][..., -1, None])
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
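A small sketch of the fixed cutoff above, with assumed probabilities and `epsilon=0.05` (far larger than the paper's suggested range, purely for readability):
```
import torch

probs = torch.tensor([[0.6, 0.3, 0.08, 0.02]])
scores = probs.log()
epsilon = 0.05
probabilities = scores.softmax(dim=-1)       # recovers [0.6, 0.3, 0.08, 0.02]
indices_to_remove = probabilities < epsilon  # [False, False, False, True]
# only the 0.02 token falls below the fixed cutoff and is filtered out
```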
class EtaLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] that performs eta-sampling, a technique to filter out tokens with probabilities below a dynamic
cutoff value, `eta`, which is calculated based on a combination of the hyperparameter `epsilon` and the entropy of
the token probabilities, i.e. `eta := min(epsilon, sqrt(epsilon * e^-entropy(probabilities)))`. Takes the largest
min_tokens_to_keep tokens if no tokens satisfy this constraint. It addresses the issue of poor quality in long
samples of text generated by neural language models leading to more coherent and fluent text. See [Truncation
Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more information. Note: `do_sample`
must be set to `True` for this `LogitsWarper` to work.
Args:
epsilon (`float`):
A float value in the range (0, 1). Hyperparameter used to calculate the dynamic cutoff value, `eta`. The
suggested values from the paper ranges from 3e-4 to 4e-3 depending on the size of the model.
filter_value (`float`, *optional*, defaults to -inf):
All values that are found to be below the dynamic cutoff value, `eta`, are set to this float value. This
parameter is useful when logits need to be modified for very low probability tokens that should be excluded
from generation entirely.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Specifies the minimum number of tokens that must be kept for generation, regardless of their probabilities.
For example, if `min_tokens_to_keep` is set to 1, at least one token will always be kept for generation,
even if all tokens have probabilities below the cutoff `eta`.
Examples:
```
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
>>> set_seed(0)
>>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
>>> inputs = tokenizer("A sequence: 1, 2", return_tensors="pt")
>>> # With sampling, the output is unexpected -- sometimes too unexpected.
>>> outputs = model.generate(**inputs, do_sample=True)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
A sequence: 1, 2, 0, 2, 2. 2, 2, 2, 2
>>> # With eta sampling, the output gets restricted to high-probability tokens. You can see it as a dynamic form of
>>> # epsilon sampling that adapts its cutoff probability based on the entropy (high entropy = lower cutoff).
>>> # Pro tip: The paper recommends using `eta_cutoff` values between 3e-4 to 4e-3
>>> outputs = model.generate(**inputs, do_sample=True, eta_cutoff=0.1)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
A sequence: 1, 2, 3, 4, 5, 6, 7, 8, 9
```
"""
def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
epsilon = float(epsilon)
if epsilon <= 0 or epsilon >= 1:
raise ValueError(f"`eta_cutoff` has to be a float > 0 and < 1, but is {epsilon}")
min_tokens_to_keep = int(min_tokens_to_keep)
if min_tokens_to_keep < 1:
raise ValueError(
f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}"
)
self.epsilon = torch.tensor(epsilon)
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
probabilities = scores.softmax(dim=-1)
entropy = torch.distributions.Categorical(logits=scores).entropy()
eta = torch.min(self.epsilon, torch.sqrt(self.epsilon) * torch.exp(-entropy))[..., None]
indices_to_remove = probabilities < eta
top_k = min(self.min_tokens_to_keep, scores.size(-1))
indices_to_remove = indices_to_remove & (scores < torch.topk(scores, top_k)[0][..., -1, None])
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
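A small sketch of how the dynamic cutoff `eta` is computed, with an assumed distribution and `epsilon=3e-4`:
```
import torch

scores = torch.log(torch.tensor([[0.7, 0.2, 0.08, 0.02]]))
epsilon = torch.tensor(3e-4)
entropy = torch.distributions.Categorical(logits=scores).entropy()
eta = torch.min(epsilon, torch.sqrt(epsilon) * torch.exp(-entropy))
# for this fairly peaked distribution, entropy ~= 0.85 nats, so
# sqrt(epsilon) * exp(-entropy) ~= 0.0074 > epsilon, and eta stays capped at 3e-4;
# a higher-entropy distribution would push the cutoff below epsilon
```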
def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
"""
Assume ngram_size=2 and prev_input_ids=tensor([[40, 2883, 2712, 4346]]). The output of generated ngrams look like
this {(40,): [2883], (2883,): [2712], (2712,): [4346]}.
Args:
ngram_size (`int`):
The number of sequential tokens taken as a group which may only occur once before being banned.
prev_input_ids (`torch.Tensor`):
Generated token ids for the current hypothesis.
num_hypos (`int`):
The number of hypotheses for which n-grams need to be generated.
Returns:
generated_ngrams (`dict`):
Dictionary of generated ngrams.
"""
generated_ngrams = [{} for _ in range(num_hypos)]
for idx in range(num_hypos):
gen_tokens = prev_input_ids[idx].tolist()
generated_ngram = generated_ngrams[idx]
for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
prev_ngram_tuple = tuple(ngram[:-1])
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
return generated_ngrams
def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
"""
Determines the banned tokens for the current hypothesis based on previously generated n-grams.
Args:
banned_ngrams (`dict`):
A dictionary containing previously generated n-grams for each hypothesis.
prev_input_ids (`torch.Tensor`):
Generated token ids for the current hypothesis.
ngram_size (`int`):
The number of sequential tokens taken as a group which may only occur once before being banned.
cur_len (`int`):
The current length of the token sequences for which the n-grams are being checked.
Returns:
List of tokens that are banned.
"""
start_idx = cur_len + 1 - ngram_size
ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
return banned_ngrams.get(ngram_idx, [])
def _calc_banned_ngram_tokens(
ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
) -> List[Iterable[int]]:
"""Copied from fairseq for no_repeat_ngram in beam_search"""
if cur_len + 1 < ngram_size:
return [[] for _ in range(num_hypos)]
generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
banned_tokens = [
_get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
for hypo_idx in range(num_hypos)
]
return banned_tokens
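A small sketch of the two helpers above, treating them as in scope (they are module-private in `logits_process.py`) and using assumed token ids:
```
import torch

prev_input_ids = torch.tensor([[10, 11, 10]])  # the bigram (10, 11) was already generated
banned = _calc_banned_ngram_tokens(ngram_size=2, prev_input_ids=prev_input_ids, num_hypos=1, cur_len=3)
# banned == [[11]]: emitting 11 next would repeat the (10, 11) bigram, so its score gets -inf
```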
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
r"""
N-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). In text generation,
avoiding repetitions of word sequences provides a more diverse output. This [`LogitsProcessor`] enforces no
repetition of n-grams by setting the scores of banned tokens to negative infinity which eliminates those tokens
from consideration when further processing the scores. Note that, for decoder-only models like most LLMs, the
prompt is also considered to obtain the n-grams.
[Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).
<Tip>
Use n-gram penalties with care. For instance, penalizing 2-grams (bigrams) in an article about the city of New York
might lead to undesirable outcomes where the city's name appears only once in the entire text.
[Reference](https://huggingface.co/blog/how-to-generate)
</Tip>
Args:
ngram_size (`int`):
All ngrams of size `ngram_size` can only occur once.
Examples:
```
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
>>> inputs = tokenizer(["Today I"], return_tensors="pt")
>>> output = model.generate(**inputs)
>>> print(tokenizer.decode(output[0], skip_special_tokens=True))
Today I’m not sure if I’m going to be able to do it.
>>> # Now let's add ngram size using `no_repeat_ngram_size`. This stops the repetitions ("I’m") in the output.
>>> output = model.generate(**inputs, no_repeat_ngram_size=2)
>>> print(tokenizer.decode(output[0], skip_special_tokens=True))
Today I’m not sure if I can get a better understanding of the nature of this issue
```
"""
def __init__(self, ngram_size: int):
if not isinstance(ngram_size, int) or ngram_size <= 0:
raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
self.ngram_size = ngram_size
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
num_batch_hypotheses = scores.shape[0]
cur_len = input_ids.shape[-1]
banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len)
for i, banned_tokens in enumerate(banned_batch_tokens):
scores[i, banned_tokens] = -float("inf")
return scores
class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that works similarly to [`NoRepeatNGramLogitsProcessor`], but applied exclusively to prevent
the repetition of n-grams present in the prompt.
Args:
encoder_ngram_size (`int`):
Size of the n-grams that should not be repeated in the decoder.
encoder_input_ids (`torch.LongTensor`):
Tensor containing input IDs for the encoder.
"""
def __init__(self, encoder_ngram_size: int, encoder_input_ids: torch.LongTensor):
# Check if encoder_ngram_size is a positive integer
if not isinstance(encoder_ngram_size, int) or encoder_ngram_size <= 0:
raise ValueError(
f"`encoder_ngram_size` has to be a strictly positive integer, but is {encoder_ngram_size}"
)
# Store the n-gram size
self.ngram_size = encoder_ngram_size
# Ensure encoder_input_ids is 2-dimensional
if len(encoder_input_ids.shape) == 1:
encoder_input_ids = encoder_input_ids.unsqueeze(0)
# Store batch size
self.batch_size = encoder_input_ids.shape[0]
# Generate n-grams from the encoder input IDs
self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size)
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Calculate number of hypotheses
num_hypos = scores.shape[0]
# Calculate number of beams per hypothesis
num_beams = num_hypos // self.batch_size
# Current length of input_ids
cur_len = input_ids.shape[-1]
# List of banned tokens for each hypothesis
banned_batch_tokens = [
_get_generated_ngrams(
self.generated_ngrams[hypo_idx // num_beams], input_ids[hypo_idx], self.ngram_size, cur_len
)
for hypo_idx in range(num_hypos)
]
# Apply -inf score to banned tokens in scores tensor
for i, banned_tokens in enumerate(banned_batch_tokens):
scores[i, banned_tokens] = -float("inf")
return scores
class SequenceBiasLogitsProcessor(LogitsProcessor):
"""
[`LogitsProcessor`] that applies an additive bias on sequences. The bias is applied to the last token of a sequence
when the next generated token can complete it. Consequently, to take the most of biasing sequences with more than
one token, consider using beam methods (to gracefully work around partially completed sequences that have a
negative bias) and applying the bias to their prefixes (to ensure the bias is applied earlier).
<Tip>
In order to get the token ids of the sequences that you want to bias, make sure to set `add_prefix_space=True` when
initializing the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The
`add_prefix_space` argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours
come from `pre tokenizers`. Read more [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).
</Tip>
Args:
sequence_bias (`Dict[Tuple[int], float]`):
Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
sequence being selected, while negative biases do the opposite. If a sequence has a length of 1, its bias
will always be applied. Otherwise, the bias will only be applied if the sequence in question is about to be
completed (in the token selection step after this processor is applied).
Examples:
```
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pt")
>>> summary_ids = model.generate(inputs["input_ids"], max_new_tokens=4)
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
The full name of Donald is Donald J. Trump Jr
>>> # Now let's control generation through a bias. Please note that the tokenizer is initialized differently!
>>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("openai-community/gpt2", add_prefix_space=True)
>>> def get_tokens_as_tuple(word):
... return tuple(tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0])
>>>
>>> sequence_bias = {get_tokens_as_tuple("Trump"): -10.0}
>>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, sequence_bias=sequence_bias)
>>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
The full name of Donald is Donald J. Donald,
>>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
"""
def __init__(self, sequence_bias):
"""
Initialize the SequenceBiasLogitsProcessor with a sequence bias dictionary.
Args:
sequence_bias (`Dict[Tuple[int], float]`): A dictionary mapping sequences of tokens to their bias values.
"""
super().__init__()
self.sequence_bias = sequence_bias
def __call__(self, input_ids, scores):
"""
Apply the sequence bias to the logits.
Args:
input_ids (torch.Tensor): Input token IDs.
scores (torch.Tensor): Logits (scores) for each token.
Returns:
torch.Tensor: Modified logits after applying sequence bias.
"""
# Determine the sequence length
seq_len = input_ids.size(1)
# Get the last token's token IDs
last_token_ids = input_ids[:, -1].tolist()
# Check if the last token is in the sequence_bias dictionary
if tuple(last_token_ids) in self.sequence_bias:
# Apply bias to the last token's logits
scores[:, -1] += self.sequence_bias[tuple(last_token_ids)]
return scores
>>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
The full name of Donald is Donald Rumsfeld,
>>> # We can also add a positive bias to nudge the model towards specific tokens or continuations
>>> sequence_bias = {get_tokens_as_tuple("Donald Duck"): 10.0}
>>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
>>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
The full name of Donald is Donald Duck.
```
"""
def __init__(self, sequence_bias: Dict[Tuple[int], float]):
self.sequence_bias = sequence_bias
self._validate_arguments()
self.length_1_bias = None
self.prepared_bias_variables = False
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if not self.prepared_bias_variables:
self._prepare_bias_variables(scores)
bias = torch.zeros_like(scores)
bias += self.length_1_bias
for sequence_ids, sequence_bias in self.sequence_bias.items():
if len(sequence_ids) == 1:
continue
if len(sequence_ids) > input_ids.shape[1]:
continue
prefix_length = len(sequence_ids) - 1
last_token = sequence_ids[-1]
matching_rows = torch.eq(
input_ids[:, -prefix_length:],
torch.tensor(sequence_ids[:-1], dtype=input_ids.dtype, device=input_ids.device),
).prod(dim=1)
bias[:, last_token] += torch.where(
matching_rows.bool(),
torch.tensor(sequence_bias, device=input_ids.device),
torch.tensor(0.0, device=input_ids.device),
)
scores = scores + bias
return scores
def _prepare_bias_variables(self, scores: torch.FloatTensor):
vocabulary_size = scores.shape[-1]
invalid_biases = []
for sequence_ids in self.sequence_bias:
for token_id in sequence_ids:
if token_id >= vocabulary_size:
invalid_biases.append(token_id)
if len(invalid_biases) > 0:
raise ValueError(
f"The model vocabulary size is {vocabulary_size}, but the following tokens were being biased: "
f"{invalid_biases}"
)
self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
for sequence_ids, bias in self.sequence_bias.items():
if len(sequence_ids) == 1:
self.length_1_bias[sequence_ids[-1]] = bias
self.prepared_bias_variables = True
def _validate_arguments(self):
sequence_bias = self.sequence_bias
if not isinstance(sequence_bias, dict) or len(sequence_bias) == 0:
raise ValueError(f"`sequence_bias` has to be a non-empty dictionary, but is {sequence_bias}.")
if any(not isinstance(sequence_ids, tuple) for sequence_ids in sequence_bias.keys()):
raise ValueError(f"`sequence_bias` has to be a dict with tuples as keys, but is {sequence_bias}.")
if any(
any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in sequence_ids)
or len(sequence_ids) == 0
for sequence_ids in sequence_bias.keys()
):
raise ValueError(
f"Each key in `sequence_bias` has to be a non-empty tuple of positive integers, but is "
f"{sequence_bias}."
)
if any(not isinstance(bias, float) for bias in sequence_bias.values()):
raise ValueError(f"`sequence_bias` has to be a dict with floats as values, but is {sequence_bias}.")
"""
[`LogitsProcessor`] that enforces that specified sequences will never be selected.
<Tip>
In order to get the token ids of the words that should not appear in the generated text, make sure to set
`add_prefix_space=True` when initializing the tokenizer, and use `tokenizer(bad_words,
add_special_tokens=False).input_ids`. The `add_prefix_space` argument is only supported for some slow tokenizers,
as fast tokenizers' prefixing behaviours come from `pre tokenizers`. Read more
[here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).
</Tip>
Args:
bad_words_ids (`List[List[int]]`):
List of list of token ids that are not allowed to be generated.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
Examples:
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> inputs = tokenizer(["In a word, the cake is a"], return_tensors="pt")
>>> output_ids = model.generate(inputs["input_ids"], max_new_tokens=5, pad_token_id=tokenizer.eos_token_id)
>>> print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
In a word, the cake is a bit of a mess.
>>> # Now let's take the bad words out. Please note that the tokenizer is initialized differently
>>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("openai-community/gpt2", add_prefix_space=True)
>>> def get_tokens_as_list(word_list):
... "Converts a sequence of words into a list of tokens"
... tokens_list = []
... for word in word_list:
... tokenized_word = tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0]
... tokens_list.append(tokenized_word)
... return tokens_list
>>> bad_words_ids = get_tokens_as_list(word_list=["mess"])
>>> output_ids = model.generate(
... inputs["input_ids"], max_new_tokens=5, bad_words_ids=bad_words_ids, pad_token_id=tokenizer.eos_token_id
... )
>>> print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
In a word, the cake is a bit of a surprise.
"""
def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
self.bad_word_ids = bad_words_ids
self._validate_arguments()
if eos_token_id is None:
eos_token_id = []
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
bad_words_ids = list(
filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
)
sequence_bias = {tuple(sequence): float("-inf") for sequence in bad_words_ids}
super().__init__(sequence_bias=sequence_bias)
def _validate_arguments(self):
bad_words_ids = self.bad_word_ids
if not isinstance(bad_words_ids, list) or len(bad_words_ids) == 0:
raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
if any(
any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
for bad_word_ids in bad_words_ids
):
raise ValueError(
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
)
class PrefixConstrainedLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that enforces constrained generation and is useful for prefix-conditioned constrained
generation. See [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904) for more information.
Args:
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`):
This function constraints the beam search to allowed tokens only at each step. This function takes 2
arguments `inputs_ids` and the batch ID `batch_id`. It has to return a list with the allowed tokens for the
next generation step conditioned on the previously generated tokens `inputs_ids` and the batch ID
`batch_id`.
Examples:
```
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
>>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
>>> inputs = tokenizer("Alice and Bob", return_tensors="pt")
>>> # By default, it continues generating according to the model's logits
>>> outputs = model.generate(**inputs, max_new_tokens=5)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
Alice and Bob are friends
>>> # We can contrain it with `prefix_allowed_tokens_fn` to force a certain behavior based on a prefix.
>>> # For instance, we can force an entire entity to be generated when its beginning is detected.
>>> entity = tokenizer(" Bob Marley", return_tensors="pt").input_ids[0] # 3 tokens
>>> def prefix_allowed_tokens_fn(batch_id, input_ids):
... '''
... Attempts to generate 'Bob Marley' when 'Bob' is detected.
... In this case, `batch_id` is not used, but you can set rules for each batch member.
... '''
... if input_ids[-1] == entity[0]:
... return entity[1]
... elif input_ids[-2] == entity[0] and input_ids[-1] == entity[1]:
... return entity[2]
... return list(range(tokenizer.vocab_size)) # If no match, allow all tokens
>>> outputs = model.generate(**inputs, max_new_tokens=5, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
Alice and Bob Marley
```
"""
def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int):
self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
self._num_beams = num_beams
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
mask = torch.full_like(scores, -math.inf)
for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
for beam_id, sent in enumerate(beam_sent):
prefix_allowed_tokens = self._prefix_allowed_tokens_fn(batch_id, sent)
if len(prefix_allowed_tokens) == 0:
raise ValueError(
f"`prefix_allowed_tokens_fn` returned an empty list for batch ID {batch_id}."
f"This means that the constraint is unsatisfiable. Please check your implementation"
f"of `prefix_allowed_tokens_fn` "
)
mask[batch_id * self._num_beams + beam_id, prefix_allowed_tokens] = 0
return scores + mask
class HammingDiversityLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that enforces diverse beam search.
Note that this logits processor is only effective for [`PreTrainedModel.group_beam_search`]. See [Diverse Beam
Search: Decoding Diverse Solutions from Neural Sequence Models](https://arxiv.org/pdf/1610.02424.pdf) for more
details.
Traditional beam search often generates very similar sequences across different beams.
`HammingDiversityLogitsProcessor` addresses this by penalizing beams that generate tokens already chosen by other
beams in the same time step.
Args:
diversity_penalty (`float`):
This value is subtracted from a beam's score if it generates a token same as any beam from other group at a
particular time. A higher `diversity_penalty` will enforce greater diversity among the beams. Adjusting
this value can help strike a balance between diversity and natural likelihood.
num_beams (`int`):
Number of beams for beam search. 1 means no beam search.
num_beam_groups (`int`):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
Examples:
```
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
>>> import torch
>>> # Initialize the model and tokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
>>> # A long text about the solar system
>>> text = (
... "The Solar System is a gravitationally bound system comprising the Sun and the objects that orbit it, "
... "either directly or indirectly. Of the objects that orbit the Sun directly, the largest are the eight "
... "planets, with the remainder being smaller objects, such as the five dwarf planets and small Solar System "
... "bodies. The Solar System formed 4.6 billion years ago from the gravitational collapse of a giant "
... "interstellar molecular cloud."
... )
>>> inputs = tokenizer("summarize: " + text, return_tensors="pt")
>>> # Generate diverse summary
>>> outputs_diverse = model.generate(
... **inputs,
... num_beam_groups=2,
... diversity_penalty=10.0,
... max_length=100,
... num_beams=4,
... num_return_sequences=2,
... )
>>> summaries_diverse = tokenizer.batch_decode(outputs_diverse, skip_special_tokens=True)
>>> # Generate non-diverse summary
>>> outputs_non_diverse = model.generate(
... **inputs,
... max_length=100,
... num_beams=4,
... num_return_sequences=2,
... )
>>> summary_non_diverse = tokenizer.batch_decode(outputs_non_diverse, skip_special_tokens=True)
```
"""
# Initializer: set the diversity penalty, the number of beams, and the number of beam groups
def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int):
# Check that diversity_penalty is a float strictly greater than 0
if not isinstance(diversity_penalty, float) or (not diversity_penalty > 0.0):
raise ValueError("`diversity_penalty` should be a float strictly larger than 0.")
self._diversity_penalty = diversity_penalty
# Check that num_beams is an integer strictly greater than 1
if not isinstance(num_beams, int) or num_beams < 2:
raise ValueError("`num_beams` should be an integer strictly larger than 1.")
self._num_beams = num_beams
# Check that num_beam_groups is an integer strictly greater than 1 and does not exceed num_beams
if not isinstance(num_beam_groups, int) or num_beam_groups < 2:
raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.")
if num_beam_groups > num_beams:
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`.")
self._num_sub_beams = num_beams // num_beam_groups  # number of sub-beams per beam group
# Called during group beam search to apply the Hamming diversity penalty
def __call__(
self,
input_ids: torch.LongTensor,
scores: torch.FloatTensor,
current_tokens: torch.LongTensor,
beam_group_idx: int,
) -> torch.FloatTensor:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
beam search or log softmax for each vocabulary token when using beam search
current_tokens (`torch.LongTensor` of shape `(batch_size)`):
Indices of input sequence tokens in the vocabulary, corresponding to the tokens selected by the other
beam groups in the current generation step.
beam_group_idx (`int`):
The index of the beam group currently being processed.
Return:
`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`:
The processed prediction scores.
"""
# hamming diversity: penalise using same token in current group which was used in previous groups at
# the same time step
batch_size = current_tokens.shape[0] // self._num_beams  # number of batch elements
group_start_idx = beam_group_idx * self._num_sub_beams  # first beam index of the current group
group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams)  # end index of the current group, capped at num_beams
group_size = group_end_idx - group_start_idx  # number of beams in the current group
vocab_size = scores.shape[-1]  # vocabulary size
if group_start_idx == 0:
return scores  # the first group is not penalized; return the original scores
for batch_idx in range(batch_size):
# predicted tokens of last time step of previous groups
previous_group_tokens = current_tokens[
batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx
]  # tokens chosen by the earlier groups at the current time step
token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(scores.device)
# frequency of the tokens used by the earlier groups, moved to the same device as `scores`
scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency
# subtract the diversity penalty from the current group's scores for those tokens
return scores
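A minimal sketch (illustration only, not part of the module) of how the penalty acts on toy tensors; the import path is assumed to be `transformers.generation.logits_process`:
import torch
from transformers.generation.logits_process import HammingDiversityLogitsProcessor

# 1 batch element, 4 beams split into 2 groups of 2, toy vocabulary of 6 tokens
processor = HammingDiversityLogitsProcessor(diversity_penalty=1.0, num_beams=4, num_beam_groups=2)
current_tokens = torch.tensor([5, 3, 0, 0])      # tokens picked this step; only the first group (5, 3) matters below
scores = torch.zeros(2, 6)                       # scores of the 2 beams in the group being processed
input_ids = torch.zeros(2, 1, dtype=torch.long)  # not used by the penalty itself
print(processor(input_ids, scores, current_tokens, beam_group_idx=1))
# Tokens 3 and 5 were already used by the first group, so they are penalized by 1.0:
# tensor([[ 0.,  0.,  0., -1.,  0., -1.],
#         [ 0.,  0.,  0., -1.,  0., -1.]])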
class ForcedBOSTokenLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that enforces the specified token as the first generated token. Used with encoder-decoder
models.
Args:
bos_token_id (`int`):
The id of the token to force as the first generated token.
Examples:
```
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
>>> model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
>>> tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
>>> inputs = tokenizer("Translate from English to German: I love cats.", return_tensors="pt")
>>>
>>> outputs = model.generate(**inputs, max_new_tokens=10)
>>> print(tokenizer.batch_decode(outputs)[0])
<pad> Ich liebe Kitty.</s>
>>>
>>>
>>> outputs = model.generate(**inputs, max_new_tokens=10, forced_bos_token_id=tokenizer.eos_token_id)
>>> print(tokenizer.batch_decode(outputs)[0])
<pad></s>
```
"""
def __init__(self, bos_token_id: int):
# Initializer: store the ID of the token to force as the first generated token
self.bos_token_id = bos_token_id
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Current length of the generated sequence
cur_len = input_ids.shape[-1]
# If only one token has been generated, i.e. this is the first decoding step
if cur_len == 1:
# Number of candidate tokens in the logits
num_tokens = scores.shape[1]
# Set the logits of every token except the forced BOS token to -inf so they cannot be generated
scores[:, [i for i in range(num_tokens) if i != self.bos_token_id]] = -float("inf")
# Set the forced BOS token's logit to 0 so it is the only token that can be generated
scores[:, self.bos_token_id] = 0
return scores
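The effect is easiest to see on raw tensors. A minimal sketch (illustration only; the import path is an assumption):
import torch
from transformers.generation.logits_process import ForcedBOSTokenLogitsProcessor

processor = ForcedBOSTokenLogitsProcessor(bos_token_id=3)
input_ids = torch.tensor([[0]])  # cur_len == 1, i.e. the first decoding step
scores = torch.randn(1, 5)
print(processor(input_ids, scores))
# Only index 3 stays finite (with logit 0.0); every other token is set to -inf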
class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached.
Args:
max_length (`int`):
The maximum length of the sequence to be generated.
eos_token_id (`Union[int, List[int]]`):
The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
list to set multiple *end-of-sequence* tokens.
Examples:
```
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
>>> inputs = tokenizer("A sequence: 1, 2, 3", return_tensors="pt")
>>>
>>> outputs = model.generate(**inputs, max_new_tokens=10)
>>> print(tokenizer.batch_decode(outputs)[0])
A sequence: 1, 2, 3, 4, 5, 6, 7, 8
>>>
```
"""
def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
# Initializer: store the maximum length and the forced end-of-sequence token ID(s)
self.max_length = max_length
self.eos_token_id = eos_token_id
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Current length of the generated sequence
cur_len = input_ids.shape[-1]
# If the maximum length has been reached, force the last generated token(s) to be EOS
if cur_len == self.max_length:
if isinstance(self.eos_token_id, int):
# Single EOS token ID: set the logits of every other token to -inf
scores[:, [i for i in range(scores.shape[1]) if i != self.eos_token_id]] = -float("inf")
# Set the EOS token's logit to 0 so it is generated
scores[:, self.eos_token_id] = 0
else:
# Multiple EOS token IDs: set the logits of every token outside the set to -inf
for eos_id in self.eos_token_id:
scores[:, [i for i in range(scores.shape[1]) if i != eos_id]] = -float("inf")
# Set every EOS token's logit to 0 so one of them is generated
for eos_id in self.eos_token_id:
scores[:, eos_id] = 0
return scores
# Generate text with the model, limiting generation to 10 new tokens and forcing the given eos_token_id as the final token
outputs = model.generate(**inputs, max_new_tokens=10, forced_eos_token_id=tokenizer.eos_token_id)
# Decode the generated sequence and print the first result
print(tokenizer.batch_decode(outputs)[0])
class InfNanRemoveLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that removes all `nan` and `inf` values to avoid the generation method to fail. Note that using
the logits processor should only be used if necessary since it can slow down the generation method.
This logits processor has no `generate` example, as there shouldn't be a correct combination of flags that warrants
its use.
"""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# set all nan values to 0.0
scores[scores != scores] = 0.0
# set all +/-inf values to max/min possible value
scores[scores == float("inf")] = torch.finfo(scores.dtype).max
scores[scores == float("-inf")] = torch.finfo(scores.dtype).min
return scores
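A minimal sketch of the clean-up on a toy tensor (illustration only; the import path is an assumption):
import torch
from transformers.generation.logits_process import InfNanRemoveLogitsProcessor

processor = InfNanRemoveLogitsProcessor()
input_ids = torch.tensor([[0]])  # not used by this processor
scores = torch.tensor([[float("nan"), float("inf"), float("-inf"), 1.5]])
print(processor(input_ids, scores))
# NaN becomes 0.0 and +/-inf are clamped to the dtype's finite extremes:
# tensor([[ 0.0000e+00,  3.4028e+38, -3.4028e+38,  1.5000e+00]])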
"""
该类的构造函数初始化对象的属性,并计算长度调整的起始点和衰减因子。
def __init__(
self,
exponential_decay_length_penalty: Tuple[int, float], # 接收一个元组,包含衰减长度和衰减因子
eos_token_id: Union[int, List[int]], # 接收结束标记的 ID,可以是单个整数或整数列表
input_ids_seq_length: int, # 输入的序列长度
):
# 计算调整起始点,考虑输入序列的长度
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
# 设置衰减因子
self.regulation_factor = exponential_decay_length_penalty[1]
# 如果结束标记是整数,则转换为列表
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
# 存储结束标记的 ID
self.eos_token_id = eos_token_id
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# 获取当前输入序列的长度
cur_len = input_ids.shape[-1]
# 如果当前长度超过了调整起始点
if cur_len > self.regulation_start:
# 对每个结束标记执行以下操作
for i in self.eos_token_id:
# 计算惩罚的索引,基于当前长度和调整起始点
penalty_idx = cur_len - self.regulation_start
# 支持负对数,计算绝对值的惩罚,并添加到原始对数中
scores[:, i] = scores[:, i] + torch.abs(scores[:, i]) * (pow(self.regulation_factor, penalty_idx) - 1)
# 返回调整后的分数
return scores
"""
class LogitNormalization(LogitsProcessor, LogitsWarper):
r"""
[`LogitsWarper`] and [`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize
the scores during beam search, after applying the logits processors or warpers, since the search algorithm used in
this library doesn't do it (it only normalizes the scores beforehand, and the processors may de-normalize them),
yet it still assumes that the scores are normalized when comparing the hypotheses.
Examples:
```
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> import torch
>>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
>>> inputs = tokenizer("A sequence: 1, 2, 3", return_tensors="pt")
>>> # By default, the scores are not normalized -- the sum of their exponentials is NOT a normalized probability
>>> # distribution, summing to 1
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
>>> print(torch.sum(torch.exp(outputs.scores[-1])))
tensor(816.3250)
>>> # Normalizing them may have a positive impact on beam methods, or when using the scores on your application
>>> outputs = model.generate(**inputs, renormalize_logits=True, return_dict_in_generate=True, output_scores=True)
>>> print(torch.sum(torch.exp(outputs.scores[-1])))
tensor(1.0000)
```
"""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Normalize the scores with a log-softmax over the last (vocabulary) dimension
scores = scores.log_softmax(dim=-1)
# Return the normalized scores
return scores
class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
r"""
[`SuppressTokensAtBeginLogitsProcessor`] suppresses a list of tokens as soon as the `generate` function starts
generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are
not generated at the beginning. Originally created for
[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).
Examples:
```
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
>>> # Whisper has `begin_suppress_tokens` set by default (= `[220, 50256]`). 50256 is the EOS token, so this means
>>> # it can't generate an EOS token in the first iteration, but it can in the others.
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
>>> print(outputs.scores[1][0, 50256]) # 1 (and not 0) is the first freely generated token
tensor(-inf)
>>> print(outputs.scores[-1][0, 50256]) # in other places we can see some probability mass for EOS
tensor(29.9010)
>>> # If we disable `begin_suppress_tokens`, we can generate EOS in the first iteration.
>>> outputs = model.generate(
... **inputs, return_dict_in_generate=True, output_scores=True, begin_suppress_tokens=None
... )
>>> print(outputs.scores[1][0, 50256])
tensor(11.2027)
```
"""
def __init__(self, begin_suppress_tokens, begin_index):
self.begin_suppress_tokens = list(begin_suppress_tokens)
self.begin_index = begin_index
def set_begin_index(self, begin_index):
self.begin_index = begin_index
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if input_ids.shape[1] == self.begin_index:
scores[:, self.begin_suppress_tokens] = -float("inf")
return scores
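A minimal sketch showing that the suppression only fires when the sequence length equals `begin_index` (illustration only; the import path is an assumption):
import torch
from transformers.generation.logits_process import SuppressTokensAtBeginLogitsProcessor

processor = SuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens=[1, 3], begin_index=2)
scores = torch.zeros(1, 5)
at_begin = torch.tensor([[7, 8]])           # length 2 == begin_index -> tokens 1 and 3 get -inf
print(processor(at_begin, scores.clone()))
later = torch.tensor([[7, 8, 9]])           # length 3 != begin_index -> scores are returned unchanged
print(processor(later, scores.clone()))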
class SuppressTokensLogitsProcessor(LogitsProcessor):
r"""
This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so
that they are not generated. Originally created for
[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).
Examples:
```
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
>>> # Whisper has a long list of suppressed tokens. For instance, in this case, the token 1 is suppressed by default.
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
>>> print(outputs.scores[1][0, 1]) # 1 (and not 0) is the first freely generated token
tensor(-inf)
>>> # If we disable `suppress_tokens`, we can generate it.
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, suppress_tokens=None)
>>> print(outputs.scores[1][0, 1])
tensor(5.7738)
```
"""
def __init__(self, suppress_tokens):
self.suppress_tokens = list(suppress_tokens)
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
scores[:, self.suppress_tokens] = -float("inf")
return scores
>>> print(
...     all(outputs.scores[0][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362)
... )
True
>>>
>>> print(outputs.scores[0][0, 50362])
tensor(0.)
>>>
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, forced_decoder_ids=None)
>>>
>>> print(
... all(outputs.scores[0][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362)
... )
False
>>>
>>> print(outputs.scores[0][0, 50362])
tensor(19.3140)
```
"""
def __init__(self, force_token_map: List[List[int]], _has_warned: Optional[bool] = False):
# Initialize the ForceTokensLogitsProcessor with a forced-token map `force_token_map` and a `_has_warned` flag
self.force_token_map = dict(force_token_map)
if not _has_warned:
# If _has_warned is False, warn that this processor will be removed in v4.40
warnings.warn(
"This `ForceTokensLogitsProcessor` has been deprecated and will be removed in v4.40. Should you need to provide prompt ids for generation, specify `input_ids` to the generate method for decoder-only models, or `decoder_input_ids` for encoder-decoder models.",
FutureWarning,
)
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Process the logits according to the incoming input_ids and scores
generation_idx = input_ids.shape[-1]  # current generation index
current_token = self.force_token_map.get(generation_idx, None)  # forced token mapped to this index, if any
if current_token is not None:
# If a token is forced at this position, set all scores to -inf and the forced token's score to 0
scores[:, :] = -float("inf")
scores[:, current_token] = 0
return scores
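A minimal sketch of the forced-token map on toy tensors (illustration only; the import path is an assumption, and `_has_warned=True` just silences the deprecation warning shown above):
import torch
from transformers.generation.logits_process import ForceTokensLogitsProcessor

processor = ForceTokensLogitsProcessor(force_token_map=[[1, 7], [2, 9]], _has_warned=True)
scores = torch.zeros(1, 12)
input_ids = torch.tensor([[5]])      # generation_idx = 1 -> token 7 is forced at this step
out = processor(input_ids, scores)
print(out[0, 7], out[0, 8])          # tensor(0.) tensor(-inf)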
class WhisperTimeStampLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that modifies the logits for the generation of timestamps in the transcription. When the input
tokens are at a specific threshold, the processor sets the scores to negative infinity. The processor makes sure
that timestamp tokens appear in pairs, by masking out the logits that would break this pairing pattern. This is
done to maintain the consistency and structure of generated timestamps. It also ensures that when the predicted
probability of sampling any of the timestamp tokens is greater than that of any individual non-timestamp token, those
non-timestamp logits are set to negative infinity. This is done to ensure the generation of timestamps over other
potential tokens.
See [the paper](https://arxiv.org/abs/2212.04356) for more information.
Args:
generate_config (`GenerateConfig`):
The generate config used to generate the output. The following parameters are required:
eos_token_id (`int`, *optional*, defaults to 50257):
The id of the *end-of-sequence* token.
no_timestamps_token_id (`int`, *optional*, defaults to 50363):
The id of the `"<|notimestamps|>"` token.
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
predicting timestamps that are too far in the future.
begin_index (`Optional`, *optional*): Token index of the first token that is generated by the model.
_detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps.
Examples:
``` python
>>> import torch
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration, GenerationConfig
>>> from datasets import load_dataset
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[3]["audio"]["array"], return_tensors="pt")
>>> input_features = inputs.input_features
>>>
>>> generated_ids = model.generate(inputs=input_features, return_timestamps=True)
>>> transcription = processor.batch_decode(generated_ids, decode_with_timestamps=True)[0]
>>> print("Transcription:", transcription)
Transcription: <|startoftranscript|><|0.00|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all, and can<|6.44|><|6.44|> discover in it but little of rocky Ithaca.<|9.44|><|endoftext|>
>>> #No timestamps & change EOS:
```
"""
# Initializer: takes the generation config, an optional begin index, and a flag for detecting timestamps from log-probs
def __init__(
self, generate_config, begin_index: Optional[int] = None, _detect_timestamp_from_logprob: Optional[bool] = None
):  # support for the kwargs
# Special token ID that disables timestamps
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
# Token ID at which the timestamp tokens start
self.timestamp_begin = generate_config.no_timestamps_token_id + 1
# Token ID that ends generation, taken from the config's EOS (or, failing that, BOS) token ID
self.eos_token_id = generate_config.eos_token_id or generate_config.bos_token_id
# Variable used for testing: controls whether timestamps can be detected from log-probabilities
self._detect_timestamp_from_logprob = (
_detect_timestamp_from_logprob
if _detect_timestamp_from_logprob is not None
else getattr(generate_config, "_detect_timestamp_from_logprob", True)
)
# Compute the begin index, accounting for the number of forced decoder IDs
num_forced_ids = (
len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0
)
self.begin_index = begin_index or (num_forced_ids + 1)
# Maximum initial timestamp index, taken from the generation config (defaults to None)
self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
# TODO(Patrick): make sure official models have max_initial_timestamp_index set to 50
# self.max_initial_timestamp_index = 50
# Setter for the begin index
def set_begin_index(self, begin_index):
self.begin_index = begin_index
# Decorator that adds the standard logits-processor inputs docstring
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
"""
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# suppress <|notimestamps|> which is handled by without_timestamps
scores[:, self.no_timestamps_token_id] = -float("inf")
# timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
for k in range(input_ids.shape[0]):
sampled_tokens = input_ids[k, self.begin_index :]
seq = list(sampled_tokens.tolist())
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin
if last_was_timestamp:
if penultimate_was_timestamp:
# has to be non-timestamp
scores[k, self.timestamp_begin :] = -float("inf")
else:
# cannot be normal text tokens
scores[k, : self.eos_token_id] = -float("inf")
timestamps = sampled_tokens[sampled_tokens.ge(self.timestamp_begin)]
if timestamps.numel() > 0:
# `timestamps` shouldn't decrease; forbid timestamp tokens smaller than the last
if last_was_timestamp and not penultimate_was_timestamp:
timestamp_last = timestamps[-1]
else:
timestamp_last = timestamps[-1] + 1
scores[k, self.timestamp_begin : timestamp_last] = -float("inf")
if input_ids.shape[1] == self.begin_index:
scores[:, : self.timestamp_begin] = -float("inf")
if self.max_initial_timestamp_index is not None:
last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
scores[:, last_allowed + 1 :] = -float("inf")
logprobs = torch.nn.functional.log_softmax(scores.float(), dim=-1)
for k in range(input_ids.shape[0]):
timestamp_logprob = logprobs[k, self.timestamp_begin :].logsumexp(dim=-1)
max_text_token_logprob = logprobs[k, : self.timestamp_begin].max()
if timestamp_logprob > max_text_token_logprob and self._detect_timestamp_from_logprob:
scores[k, : self.timestamp_begin] = -float("inf")
return scores
class WhisperNoSpeechDetection(LogitsProcessor):
r"""This processor can be used to detect silence when using Whisper. It should take as input unprocessed logits to follow the original implementation"""
def __init__(self, no_speech_token: int, begin_index: int, scores_is_logprobs: bool = False):
self.no_speech_token = no_speech_token
self.start_of_trans_offset = begin_index
self.begin_index = begin_index
self._no_speech_prob = [0.0]
self.is_scores_logprobs = scores_is_logprobs
self.model = None
self.inputs = None
def set_model(self, model):
self.model = model
def set_inputs(self, inputs):
self.inputs = {**self.model.prepare_inputs_for_generation(**inputs), **inputs}
self.inputs["input_features"] = self.inputs.pop("inputs")
@property
def no_speech_prob(self):
return self._no_speech_prob
def set_begin_index(self, begin_index):
self.begin_index = begin_index
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if input_ids.shape[1] == self.begin_index:
if self.start_of_trans_offset > 1:
with torch.no_grad():
logits = self.model(**self.inputs).logits
no_speech_index = self.begin_index - self.start_of_trans_offset
no_speech_scores = logits[:, no_speech_index]
else:
no_speech_scores = scores
if self.is_scores_logprobs:
probs = no_speech_scores.exp()
else:
probs = no_speech_scores.float().softmax(dim=-1)
self._no_speech_prob = probs[:, self.no_speech_token]
return scores
class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension,
where the first half corresponds to the conditional logits (predicted from the input prompt) and the second half
corresponds to the unconditional logits (predicted from an empty or 'null' prompt). The processor computes a
weighted average across the conditional and unconditional logits, parameterised by the `guidance_scale`.
See [the paper](https://arxiv.org/abs/2306.05284) for more information.
<Tip warning={true}>
This logits processor is exclusively compatible with
[MusicGen](https://huggingface.co/docs/transformers/main/en/model_doc/musicgen)
</Tip>
"""
def __init__(self, guidance_scale):
# Initializer: takes `guidance_scale`, the scale for classifier-free guidance (CFG). CFG is enabled by setting
# `guidance_scale > 1`. Higher guidance scales encourage the model to generate samples more closely tied to the
# input prompt, usually at the expense of poorer quality.
if guidance_scale > 1:
# If guidance_scale is greater than 1, store it on the instance
self.guidance_scale = guidance_scale
else:
# Otherwise raise a ValueError: the classifier-free guidance processor requires guidance_scale > 1
raise ValueError(
"Require guidance scale >1 to use the classifier free guidance processor, got guidance scale "
f"{guidance_scale}."
)
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Simple check that the logits (conditional + unconditional) have a batch size compatible with the input_ids (conditional only)
if scores.shape[0] != 2 * input_ids.shape[0]:
# If the logits batch size is not twice the input_ids batch size, raise a ValueError
raise ValueError(
f"Logits should have twice the batch size of the input ids, the first half of batches corresponding to "
f"the conditional inputs, and the second half of batches corresponding to the unconditional inputs. Got "
f"batch size {scores.shape[0]} for the logits and {input_ids.shape[0]} for the input ids."
)
# Batch size of the unguided (conditional-only) half
unguided_bsz = scores.shape[0] // 2
# Split the scores into conditional and unconditional logits
cond_logits, uncond_logits = scores.split(unguided_bsz, dim=0)
# Apply the guidance scale as a weighted combination that sharpens the conditional output
scores = uncond_logits + (cond_logits - uncond_logits) * self.guidance_scale
# Return the processed scores
return scores
class AlternatingCodebooksLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] enforcing alternated generation between the two codebooks of Bark.
<Tip warning={true}>
This logits processor is exclusively compatible with
[Bark](https://huggingface.co/docs/transformers/en/model_doc/bark)'s fine submodel. See the model documentation
for examples.
</Tip>
Args:
input_start_len (`int`):
The length of the initial input sequence.
semantic_vocab_size (`int`):
Vocabulary size of the semantic part, i.e number of tokens associated to the semantic vocabulary.
codebook_size (`int`):
Number of tokens associated to the codebook.
"""
def __init__(self, input_start_len: int, semantic_vocab_size: int, codebook_size: int):
if not isinstance(input_start_len, int) or input_start_len < 0:
raise ValueError(f"`input_starting_length` has to be a non-negative integer, but is {input_start_len}")
# 初始化函数,验证并设置输入的起始长度、语义词汇表大小和码书大小
self.input_start_len = input_start_len
self.semantic_vocab_size = semantic_vocab_size
self.codebook_size = codebook_size
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Current length of the input sequence
curr_len = input_ids.shape[-1]
# The offset from the initial input length decides which codebook is active: even offsets use the first codebook, odd offsets the second
is_first_codebook = ((curr_len - self.input_start_len) % 2) == 0
if is_first_codebook:
# First codebook: mask everything outside its range (the semantic vocabulary and the second codebook)
scores[:, : self.semantic_vocab_size] = -float("inf")
scores[:, self.semantic_vocab_size + self.codebook_size :] = -float("inf")
else:
# Second codebook: mask the semantic vocabulary and the first codebook, leaving only the second codebook's range
scores[:, : self.semantic_vocab_size + self.codebook_size] = -float("inf")
# Return the processed scores tensor
return scores
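A minimal sketch of the alternation with toy sizes: semantic_vocab_size=4 and codebook_size=3 give a 10-token vocabulary where indices 4..6 are the first codebook and 7..9 the second (illustration only; the import path is an assumption):
import torch
from transformers.generation.logits_process import AlternatingCodebooksLogitsProcessor

processor = AlternatingCodebooksLogitsProcessor(input_start_len=2, semantic_vocab_size=4, codebook_size=3)
scores = torch.zeros(1, 10)
even_step = torch.zeros(1, 4, dtype=torch.long)  # curr_len - input_start_len = 2 (even)
print(processor(even_step, scores.clone()))      # finite only at indices 4, 5, 6 (first codebook)
odd_step = torch.zeros(1, 5, dtype=torch.long)   # curr_len - input_start_len = 3 (odd)
print(processor(odd_step, scores.clone()))       # finite only at indices 7, 8, 9 (second codebook)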
class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
r"""
Logits processor for Classifier-Free Guidance (CFG). The processor computes a weighted average across scores
from prompt conditional and prompt unconditional (or negative) logits, parameterized by the `guidance_scale`.
The unconditional scores are computed internally by prompting `model` with the `unconditional_ids` branch.
See [the paper](https://arxiv.org/abs/2306.17806) for more information.
Args:
guidance_scale (`float`):
The guidance scale for classifier-free guidance (CFG). CFG is enabled by setting `guidance_scale != 1`. Higher
guidance scales encourage the model to generate samples more closely tied to the input prompt, usually at the
expense of poorer quality. A value smaller than 1 has the opposite effect and makes the provided negative
prompt (if any) act as a positive prompt.
model (`PreTrainedModel`):
The model computing the unconditional scores. It is assumed to be the same model that computes the
conditional scores, and both models must use the same tokenizer.
unconditional_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of the input sequence tokens in the vocabulary for the unconditional branch. If unset, this defaults
to the last token of the prompt.
unconditional_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Attention mask for `unconditional_ids`.
use_cache (`bool`, *optional*, defaults to `True`):
Whether to cache key/value pairs during the negative-prompt forward pass.
Examples:
```
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> inputs = tokenizer(["Today, a dragon flew over Paris, France,"], return_tensors="pt")
>>> out = model.generate(inputs["input_ids"], guidance_scale=1.5)
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
'Today, a dragon flew over Paris, France, killing at least 50 people and injuring more than 100'
>>> # with a negative prompt
>>> neg_inputs = tokenizer(["A very happy event happened,"], return_tensors="pt")
>>> out = model.generate(inputs["input_ids"], guidance_scale=2, negative_prompt_ids=neg_inputs["input_ids"])
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
'Today, a dragon flew over Paris, France, killing at least 130 people. French media reported that'
>>> # with a positive prompt
>>> neg_inputs = tokenizer(["A very happy event happened,"], return_tensors="pt")
>>> out = model.generate(inputs["input_ids"], guidance_scale=0, negative_prompt_ids=neg_inputs["input_ids"])
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
"Today, a dragon flew over Paris, France, and I'm very happy to be here. I"
```
"""
def __init__(
self,
guidance_scale: float,
model,
unconditional_ids: Optional[torch.LongTensor] = None,
unconditional_attention_mask: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = True,
):
self.guidance_scale = guidance_scale
self.model = model
self.unconditional_context = {
"input_ids": unconditional_ids,
"attention_mask": unconditional_attention_mask,
"use_cache": use_cache,
"past_key_values": None,
"first_pass": True,
}
def get_unconditional_logits(self, input_ids):
if self.unconditional_context["first_pass"]:
if self.unconditional_context["input_ids"] is None:
self.unconditional_context["input_ids"] = input_ids[:, -1:]
if self.unconditional_context["attention_mask"] is None:
self.unconditional_context["attention_mask"] = torch.ones_like(
self.unconditional_context["input_ids"], dtype=torch.long
)
input_ids = self.unconditional_context["input_ids"]
attention_mask = self.unconditional_context["attention_mask"]
self.unconditional_context["first_pass"] = False
else:
attention_mask = torch.cat(
[
self.unconditional_context["attention_mask"],
torch.ones_like(input_ids[:, -1:], dtype=torch.long),
],
dim=1,
)
if not self.unconditional_context["use_cache"]:
input_ids = torch.cat([self.unconditional_context["input_ids"], input_ids[:, -1:]], dim=1)
else:
input_ids = input_ids[:, -1:]
self.unconditional_context["input_ids"] = input_ids
self.unconditional_context["attention_mask"] = attention_mask
# Run the model with the current input ids, attention mask, cache flag, and past key/values
out = self.model(
input_ids,
attention_mask=attention_mask,
use_cache=self.unconditional_context["use_cache"],
past_key_values=self.unconditional_context["past_key_values"],
)
self.unconditional_context["past_key_values"] = out.get("past_key_values", None)
return out.logits
def __call__(self, input_ids, scores):
scores = torch.nn.functional.log_softmax(scores, dim=-1)
if self.guidance_scale == 1:
return scores
logits = self.get_unconditional_logits(input_ids)
# Log-softmax of the unconditional logits
unconditional_logits = torch.nn.functional.log_softmax(logits[:, -1], dim=-1)
# Scale the gap between conditional and unconditional log-probabilities by the guidance scale and add back the unconditional term
out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
return out
class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
r"""This processor ensures that the EOS token is selected if its probability is greater than the `min_eos_p`.
<Tip warning={true}>
This logits processor is exclusively compatible with
[Bark](https://huggingface.co/docs/transformers/en/model_doc/bark). See the model documentation for examples.
</Tip>
Args:
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
min_eos_p (`float`, *optional*):
Minimum end of speech threshold.
"""
def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float):
# Convert eos_token_id to a list if it's provided as an integer
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id = eos_token_id
# Validate min_eos_p is a positive float if provided
if min_eos_p is not None and min_eos_p <= 0:
raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}")
self.min_eos_p = min_eos_p
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Check if min_eos_p is set
if self.min_eos_p:
# Compute softmax probabilities across the last dimension of scores tensor
probs = torch.nn.functional.softmax(scores.float(), dim=-1)
# Initialize a tensor with -inf values except for the eos_token_id
early_stop_scores = torch.ones_like(scores) * -float("inf")
early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id]
# Determine if any EOS token's probability exceeds min_eos_p
do_early_stop = probs[:, self.eos_token_id] > self.min_eos_p
do_early_stop = torch.any(do_early_stop, dim=1, keepdim=True)
# Conditionally replace scores with early_stop_scores where needed
scores = torch.where(do_early_stop, early_stop_scores, scores)
return scores
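A minimal sketch of the early-stop threshold on toy tensors (illustration only; the import path is an assumption):
import torch
from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor

processor = BarkEosPrioritizerLogitsProcessor(eos_token_id=0, min_eos_p=0.5)
input_ids = torch.tensor([[1, 2]])            # not used by this processor
confident = torch.tensor([[5.0, 1.0, 1.0]])   # softmax probability of EOS ~ 0.96 > 0.5
print(processor(input_ids, confident))        # tensor([[5., -inf, -inf]]) -> EOS is forced
uncertain = torch.tensor([[1.0, 1.0, 1.0]])   # softmax probability of EOS ~ 0.33 < 0.5
print(processor(input_ids, uncertain))        # unchanged: tensor([[1., 1., 1.]])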