Transformers Source Code Analysis (Part 5)
.\generation\stopping_criteria.py
import time
import warnings
from abc import ABC
from copy import deepcopy
from typing import Optional
import torch
from ..utils import add_start_docstrings, logging
logger = logging.get_logger(__name__)
STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
or scores for each vocabulary token after SoftMax. If this stopping criteria depends on the `scores` input,
make sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`.
kwargs (`Dict[str, Any]`, *optional*):
Additional stopping criteria specific kwargs.
Return:
`torch.BoolTensor` of shape `(batch_size, 1)`, where `True` indicates we stop generation for a particular row,
and `False` indicates we should continue.
"""
class StoppingCriteria(ABC):
"""Abstract base class for all stopping criteria that can be applied during generation.
If your stopping criteria depends on the `scores` input, make sure you pass `return_dict_in_generate=True,
output_scores=True` to `generate`.
"""
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
raise NotImplementedError("StoppingCriteria needs to be subclassed")
class MaxLengthCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever the full generated number of tokens exceeds `max_length`. Keep
in mind for decoder-only type of transformers, this will include the initial prompted tokens.
Args:
max_length (`int`):
The maximum length that the output sequence can have in number of tokens.
max_position_embeddings (`int`, *optional*):
The maximum model length, as defined by the model's `config.max_position_embeddings` attribute.
"""
def __init__(self, max_length: int, max_position_embeddings: Optional[int] = None):
self.max_length = max_length
self.max_position_embeddings = max_position_embeddings
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
cur_len = input_ids.shape[-1]
is_done = cur_len >= self.max_length
if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
logger.warning_once(
"This is a friendly reminder - the current text generation call will exceed the model's predefined "
f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
"exceptions, performance degradation, or nothing at all."
)
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)
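A minimal usage sketch (not part of the library source): calling `MaxLengthCriteria` directly on a dummy batch, assuming the per-row `torch.BoolTensor` return value shown above.
```
import torch
from transformers import MaxLengthCriteria

criteria = MaxLengthCriteria(max_length=5)
input_ids = torch.ones((2, 5), dtype=torch.long)  # batch of 2 rows, already 5 tokens long
scores = torch.zeros((2, 100))                    # dummy logits; this criterion ignores them
print(criteria(input_ids, scores))                # tensor([True, True]) -> both rows should stop
```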
class MaxNewTokensCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever the generated number of tokens exceeds `max_new_tokens`. Keep in
mind for decoder-only type of transformers, this will **not** include the initial prompted tokens. This is very
close to `MaxLengthCriteria` but ignores the number of initial tokens.
Args:
start_length (`int`):
The number of initial tokens.
max_new_tokens (`int`):
The maximum number of tokens to generate.
"""
def __init__(self, start_length: int, max_new_tokens: int):
warnings.warn(
"The class `MaxNewTokensCriteria` is deprecated. "
f"Please use `MaxLengthCriteria(max_length={start_length + max_new_tokens})` "
"with `max_length = start_length + max_new_tokens` instead.",
FutureWarning,
)
self.start_length = start_length
self.max_new_tokens = max_new_tokens
self.max_length = start_length + max_new_tokens
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
is_done = input_ids.shape[-1] >= self.max_length
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)
class MaxTimeCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the
time will start being counted when you initialize this function. You can override this by passing an
`initial_timestamp`.
Args:
max_time (`float`):
The maximum allowed time in seconds for the generation.
initial_timestamp (`float`, *optional*, defaults to `time.time()`):
The start of the generation allowed time.
"""
def __init__(self, max_time: float, initial_timestamp: Optional[float] = None):
self.max_time = max_time
self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
is_done = time.time() - self.initial_timestamp > self.max_time
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)
class StoppingCriteriaList(list):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device)
for criteria in self:
is_done = is_done | criteria(input_ids, scores, **kwargs)
return is_done
def max_length(self) -> Optional[int]:
for stopping_criterium in self:
if isinstance(stopping_criterium, MaxLengthCriteria):
return stopping_criterium.max_length
elif isinstance(stopping_criterium, MaxNewTokensCriteria):
return stopping_criterium.max_length
return None
def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList:
stopping_max_length = stopping_criteria.max_length
new_stopping_criteria = deepcopy(stopping_criteria)
if stopping_max_length is not None and stopping_max_length != max_length:
warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
elif stopping_max_length is None:
new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
return new_stopping_criteria
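A small sketch of how these pieces are typically combined (illustrative only): several criteria collected in a `StoppingCriteriaList`, which ORs their per-row results.
```
import torch
from transformers import MaxLengthCriteria, MaxTimeCriteria, StoppingCriteriaList

criteria = StoppingCriteriaList(
    [
        MaxLengthCriteria(max_length=50),
        MaxTimeCriteria(max_time=2.0),  # also stop after roughly two seconds
    ]
)
input_ids = torch.ones((1, 10), dtype=torch.long)
scores = torch.zeros((1, 100))
print(criteria(input_ids, scores))  # tensor([False]) while neither limit has been reached
```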
.\generation\streamers.py
from queue import Queue
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from ..models.auto import AutoTokenizer
class BaseStreamer:
"""
Base class from which `.generate()` streamers should inherit.
"""
def put(self, value):
"""Function that is called by `.generate()` to push new tokens"""
raise NotImplementedError()
def end(self):
"""Function that is called by `.generate()` to signal the end of generation"""
raise NotImplementedError()
class TextStreamer(BaseStreamer):
"""
Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.
<Tip warning={true}>
The API for the streamer classes is still under development and may change in the future.
</Tip>
Parameters:
tokenizer (`AutoTokenizer`):
The tokenizer used to decode the tokens.
skip_prompt (`bool`, *optional*, defaults to `False`):
Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
decode_kwargs (`dict`, *optional*):
Additional keyword arguments to pass to the tokenizer's `decode` method.
Examples:
```
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
>>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
>>> streamer = TextStreamer(tok)
>>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
>>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
```
"""
def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
self.tokenizer = tokenizer
self.skip_prompt = skip_prompt
self.decode_kwargs = decode_kwargs
self.token_cache = []
self.print_len = 0
self.next_tokens_are_prompt = True
def put(self, value):
"""
Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
"""
if len(value.shape) > 1 and value.shape[0] > 1:
raise ValueError("TextStreamer only supports batch size 1")
elif len(value.shape) > 1:
value = value[0]
if self.skip_prompt and self.next_tokens_are_prompt:
self.next_tokens_are_prompt = False
return
self.token_cache.extend(value.tolist())
text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
if text.endswith("\n"):
printable_text = text[self.print_len :]
self.token_cache = []
self.print_len = 0
elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
printable_text = text[self.print_len :]
self.print_len += len(printable_text)
else:
printable_text = text[self.print_len : text.rfind(" ") + 1]
self.print_len += len(printable_text)
self.on_finalized_text(printable_text)
def end(self):
"""Flushes any remaining cache and prints a newline to stdout."""
if len(self.token_cache) > 0:
text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
printable_text = text[self.print_len :]
self.token_cache = []
self.print_len = 0
else:
printable_text = ""
self.next_tokens_are_prompt = True
self.on_finalized_text(printable_text, stream_end=True)
def on_finalized_text(self, text: str, stream_end: bool = False):
"""Prints the new text to stdout. If the stream is ending, also prints a newline."""
print(text, flush=True, end="" if not stream_end else None)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
if (
(cp >= 0x4E00 and cp <= 0x9FFF)
or (cp >= 0x3400 and cp <= 0x4DBF)
or (cp >= 0x20000 and cp <= 0x2A6DF)
or (cp >= 0x2A700 and cp <= 0x2B73F)
or (cp >= 0x2B740 and cp <= 0x2B81F)
or (cp >= 0x2B820 and cp <= 0x2CEAF)
or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F)
):
return True
return False
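Because all output goes through `on_finalized_text`, a custom streamer only needs to override that hook. A hypothetical sketch (the `CollectingStreamer` name is not part of the library):
```
from transformers import TextStreamer

class CollectingStreamer(TextStreamer):
    """Collects decoded chunks in a list instead of printing them."""

    def __init__(self, tokenizer, **decode_kwargs):
        super().__init__(tokenizer, **decode_kwargs)
        self.chunks = []

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # `text` is word-complete; `stream_end` is True on the final flush from `end()`.
        if text:
            self.chunks.append(text)
```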
class TextIteratorStreamer(TextStreamer):
"""
Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is
useful for applications that benefit from accessing the generated text in a non-blocking way (e.g. in an interactive
Gradio demo).
<Tip warning={true}>
The API for the streamer classes is still under development and may change in the future.
</Tip>
Parameters:
tokenizer (`AutoTokenizer`):
The tokenizer used to decode the tokens.
skip_prompt (`bool`, *optional*, defaults to `False`):
Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
timeout (`float`, *optional*):
The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
in `.generate()`, when it is called in a separate thread.
decode_kwargs (`dict`, *optional*):
Additional keyword arguments to pass to the tokenizer's `decode` method.
Examples:
```
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
>>> from threading import Thread
>>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
>>> streamer = TextIteratorStreamer(tok)
>>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
>>> generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
>>> thread = Thread(target=model.generate, kwargs=generation_kwargs)
>>> thread.start()
>>> generated_text = ""
>>> for new_text in streamer:
... generated_text += new_text
>>> generated_text
'An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,'
```
"""
def __init__(
self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: Optional[float] = None, **decode_kwargs
):
super().__init__(tokenizer, skip_prompt, **decode_kwargs)
self.text_queue = Queue()
self.stop_signal = None
self.timeout = timeout
def on_finalized_text(self, text: str, stream_end: bool = False):
"""Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
self.text_queue.put(text, timeout=self.timeout)
if stream_end:
self.text_queue.put(self.stop_signal, timeout=self.timeout)
def __iter__(self):
return self
def __next__(self):
value = self.text_queue.get(timeout=self.timeout)
if value == self.stop_signal:
raise StopIteration()
else:
return value
.\generation\tf_logits_process.py
import inspect
from typing import List, Tuple
import numpy as np
import tensorflow as tf
from ..tf_utils import stable_softmax
from ..utils import add_start_docstrings
from ..utils.logging import get_logger
logger = get_logger(__name__)
TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
Args:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
scores (`tf.Tensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary token when not using beam
search or log softmax for each vocabulary token when using beam search.
cur_len (`int`):
The current length of valid input sequence tokens. In the TF implementation, the input_ids' sequence length
is the maximum length generate can produce, and we need to know which of its tokens are valid.
kwargs (`Dict[str, Any]`, *optional*):
Additional logits processor specific kwargs.
Return:
`tf.Tensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
"""
class TFLogitsProcessor:
"""Abstract base class for all logit processors that can be applied during generation."""
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
"""TF method for processing logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
class TFLogitsWarper:
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
"""TF method for warping logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
class TFLogitsProcessorList(list):
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int, **kwargs) -> tf.Tensor:
for processor in self:
function_args = inspect.signature(processor.__call__).parameters
if len(function_args) > 3:
if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
raise ValueError(
f"Make sure that all the required parameters: {list(function_args.keys())} for "
f"{processor.__class__} are passed to the logits processor."
)
scores = processor(input_ids, scores, cur_len, **kwargs)
else:
scores = processor(input_ids, scores, cur_len)
return scores
class TFTemperatureLogitsWarper(TFLogitsWarper):
def __init__(self, temperature: float):
if not isinstance(temperature, float) or not (temperature > 0):
raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
self.temperature = temperature
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
scores = scores / self.temperature
return scores
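A toy illustration (not from the source, assuming a transformers version that still ships the TF generation utilities) of what dividing by the temperature does to the sampling distribution:
```
import tensorflow as tf
from transformers.generation.tf_logits_process import TFTemperatureLogitsWarper

logits = tf.constant([[2.0, 1.0, 0.0]])
warper = TFTemperatureLogitsWarper(temperature=0.5)        # < 1.0 sharpens the distribution
warped = warper(input_ids=None, scores=logits, cur_len=1)  # input_ids is unused here
print(tf.nn.softmax(logits).numpy())   # ~[[0.665 0.245 0.090]]
print(tf.nn.softmax(warped).numpy())   # ~[[0.867 0.117 0.016]] -- more peaked
```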
class TFTopKLogitsWarper(TFLogitsWarper):
def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
self.top_k = max(top_k, min_tokens_to_keep)
self.filter_value = filter_value
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
top_k = min(self.top_k, scores.shape[-1])
indices_to_remove = scores < tf.math.top_k(scores, k=top_k)[0][..., -1:]
next_scores = tf.where(indices_to_remove, self.filter_value, scores)
return next_scores
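A toy example (illustrative only) of the top-k filter: everything below the k-th highest logit becomes `filter_value`.
```
import tensorflow as tf
from transformers.generation.tf_logits_process import TFTopKLogitsWarper

logits = tf.constant([[0.1, 3.0, 1.5, 2.0, -1.0]])
warper = TFTopKLogitsWarper(top_k=2)
print(warper(input_ids=None, scores=logits, cur_len=1).numpy())
# [[-inf   3. -inf   2. -inf]] -- only the two highest logits survive
```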
class TFTopPLogitsWarper(TFLogitsWarper):
r"""
[`TFLogitsWarper`] that performs top-p filtering, i.e. restricting generation to the smallest set of most likely
tokens whose cumulative probability reaches `top_p`.
Args:
top_p (`float`):
If set to a value < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.
filter_value (`float`, *optional*, defaults to negative infinity):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
# Check that top_p is a float between 0 and 1
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
# Check that min_tokens_to_keep is a positive integer
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
# Store the instance attributes
self.top_p = top_p
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
# Sort the scores in descending order and keep the corresponding indices
topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])
# Tensor with the same shape as scores, filled with filter_value
mask_scores = tf.fill(scores.shape, self.filter_value)
# Stable softmax over the sorted scores, then cumulative probabilities
cumulative_probs = tf.math.cumsum(stable_softmax(topk_scores, axis=-1), axis=-1)
# Boolean mask marking positions whose cumulative probability is below top_p
score_mask = cumulative_probs < self.top_p
# Shift the mask one position to the right, so the token that pushes the cumulative probability above top_p is also kept
score_mask = tf.concat((tf.ones([score_mask.shape[0], 1], dtype=tf.bool), score_mask[:, :-1]), axis=-1)
# Make sure at least min_tokens_to_keep tokens are kept
score_mask = tf.concat(
(
tf.ones([score_mask.shape[0], self.min_tokens_to_keep], dtype=tf.bool),
score_mask[:, self.min_tokens_to_keep:],
),
axis=-1,
)
# Set the rejected positions to filter_value
topk_next_scores = tf.where(score_mask, topk_scores, mask_scores)
# Undo the top-k sorting: scatter the values back to their original indices
scatter_rows = tf.tile(tf.expand_dims(tf.range(topk_indices.shape[0]), axis=-1), [1, topk_indices.shape[-1]])
scatter_indices = tf.stack((scatter_rows, topk_indices), axis=-1)
next_scores = tf.scatter_nd(scatter_indices, topk_next_scores, shape=topk_next_scores.shape)
return next_scores
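A toy example (illustrative only): with these logits the softmax probabilities are roughly [0.64, 0.24, 0.09, 0.03], so `top_p=0.9` keeps the first three tokens and filters the tail.
```
import tensorflow as tf
from transformers.generation.tf_logits_process import TFTopPLogitsWarper

logits = tf.constant([[3.0, 2.0, 1.0, 0.0]])
warper = TFTopPLogitsWarper(top_p=0.9)
print(warper(input_ids=None, scores=logits, cur_len=1).numpy())
# [[  3.   2.   1. -inf]]
```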
class TFMinLengthLogitsProcessor(TFLogitsProcessor):
r"""
[`TFLogitsProcessor`] enforcing a minimum length by setting the EOS probability to 0.
Args:
min_length (`int`):
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`int`):
The id of the *end-of-sequence* (EOS) token.
"""
def __init__(self, min_length: int, eos_token_id: int):
if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
if not isinstance(eos_token_id, int) or eos_token_id < 0:
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
self.min_length = min_length
self.eos_token_id = eos_token_id
def _apply_eos_token_mask(self, scores: tf.Tensor) -> tf.Tensor:
eos_token_id_mask = tf.range(scores.shape[-1]) == self.eos_token_id
scores = tf.where(eos_token_id_mask, float("-inf"), scores)
return scores
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
scores = tf.cond(
tf.less(cur_len, self.min_length),
lambda: self._apply_eos_token_mask(scores),
lambda: tf.identity(scores),
)
return scores
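A quick illustration (not from the source): while `cur_len` is below `min_length`, only the EOS column is masked.
```
import tensorflow as tf
from transformers.generation.tf_logits_process import TFMinLengthLogitsProcessor

processor = TFMinLengthLogitsProcessor(min_length=5, eos_token_id=2)
scores = tf.zeros((1, 4))
print(processor(input_ids=None, scores=scores, cur_len=3).numpy())
# [[  0.   0. -inf   0.]] -- EOS (token 2) cannot be sampled yet
```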
class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
r"""
[`TFLogitsProcessor`] enforcing an exponential penalty on repeated sequences.
Args:
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
"""
def __init__(self, penalty: float):
if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
self.penalty = penalty
def _create_score_penalties(self, input_ids: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
logit_penalties = tf.gather(logits, input_ids, axis=1, batch_dims=1)
logit_penalties = tf.where(logit_penalties > 0, 1 / self.penalty, logit_penalties)
logit_penalties = tf.where(logit_penalties < 0, self.penalty, logit_penalties)
token_penalties = tf.ones(logits.shape)
batch_size = input_ids.shape[0]
seq_len = tf.shape(input_ids)[1]
indexable_prev_input_ids = tf.concat(
(
tf.expand_dims(tf.repeat(tf.range(batch_size), seq_len), axis=-1),
tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1),
),
axis=1,
)
token_penalties = tf.tensor_scatter_nd_update(
token_penalties, indices=indexable_prev_input_ids, updates=tf.reshape(logit_penalties, [-1])
)
return token_penalties
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
score_penalties = self._create_score_penalties(input_ids[:, :cur_len], scores)
scores = tf.math.multiply(scores, score_penalties)
return scores
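A toy example (illustrative only): tokens 0 and 2 were already generated, so their positive logits are divided by the penalty and their negative logits are multiplied by it.
```
import tensorflow as tf
from transformers.generation.tf_logits_process import TFRepetitionPenaltyLogitsProcessor

processor = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
input_ids = tf.constant([[0, 2]])                # previously generated tokens
scores = tf.constant([[1.0, 1.0, -1.0, 1.0]])
print(processor(input_ids, scores, cur_len=2).numpy())
# [[ 0.5  1.  -2.   1. ]]
```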
class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
"""
[`TFLogitsProcessor`] that enforces that specified sequences will never be sampled.
Args:
bad_words_ids (`List[List[int]]`):
List of lists of token ids that are not allowed to be generated. In order to get the token ids of the words that
should not appear in the generated text, make sure to set `add_prefix_space=True` when initializing the tokenizer,
and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The `add_prefix_space` argument is only
supported for some slow tokenizers, as fast tokenizers' prefixing behaviour comes from `pre tokenizers`.
Read more [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).
eos_token_id (`int`):
The id of the *end-of-sequence* (EOS) token.
"""
def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int):
# Check that `bad_words_ids` is a non-empty list
if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
# Check that every element of `bad_words_ids` is a list
if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
# Check that every element of `bad_words_ids` is a list of positive integers
if any(
any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
for bad_word_ids in bad_words_ids
):
raise ValueError(
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
)
# Information about the forbidden sequences is stored in three tensors:
# 1. A rectangular tensor with the forbidden sequences (padded with `-1`), used for full-data comparisons
self.bad_word_seqs_ids = tf.ragged.constant(bad_words_ids).to_tensor(default_value=-1)
# 2. A tensor with the unpadded length of each forbidden sequence, for quick length comparisons
bad_word_seqs_len = [len(bad_words) for bad_words in bad_words_ids]
# Forbidden sequences must not be empty
if any(word_len == 0 for word_len in bad_word_seqs_len):
raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list")
self.bad_word_seqs_len = tf.convert_to_tensor(bad_word_seqs_len, dtype=tf.int32)
# 3. A tensor with the last token of each sequence, for easy access to the tokens that may be banned
self.seq_forbidden_tokens = tf.convert_to_tensor([bad_words[-1] for bad_words in bad_words_ids])
def _calc_row_banned_bad_tokens(self, row_input_ids: tf.Tensor) -> tf.Tensor:
def _tokens_match(bad_word_seq_number):
def _len_one():
# If the bad sequence only has one token, always mask it
return tf.cond(
tf.math.equal(self.bad_word_seqs_len[bad_word_seq_number], 1),
lambda: tf.ones((), dtype=tf.bool),
_len_greater_than_cur_len,
)
def _len_greater_than_cur_len():
# Otherwise, if the bad sequence is longer than the current length, it can never match
return tf.cond(
tf.math.greater(self.bad_word_seqs_len[bad_word_seq_number], tf.shape(row_input_ids)[0]),
lambda: tf.zeros((), dtype=tf.bool),
_match_found,
)
def _match_found():
# Finally, run the actual comparison. Can only be called if the previous comparisons were inconclusive
# (otherwise it would trigger an indexing exception)
compare_len = self.bad_word_seqs_len[bad_word_seq_number] - 1
return tf.cond(
tf.math.reduce_all(
tf.math.equal(
row_input_ids[-compare_len:], self.bad_word_seqs_ids[bad_word_seq_number, :compare_len]
)
),
lambda: tf.ones((), dtype=tf.bool),
lambda: tf.zeros((), dtype=tf.bool),
)
match = _len_one()
return match
# Compare the current row against all bad word sequences, obtaining a mask of matches
match_mask = tf.map_fn(_tokens_match, tf.range(self.bad_word_seqs_ids.shape[0]), fn_output_signature=tf.bool)
row_banned_tokens = self.seq_forbidden_tokens[match_mask]
return row_banned_tokens
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
# We want to mask some banned tokens at the score level. Since the banned tokens depend on the previous
# `input_ids`, they may have a different length for each row, and may even be empty for some rows.
# To keep things simple and XLA-compatible, we work row by row.
# TODO (Joao): this function might trigger XLA retracing as `cur_len` increases. Fix it if it becomes
# a frequent bottleneck. (Make `cur_len` a tensor?)
def _get_row_updated_score(row_inputs: Tuple[tf.Tensor]) -> tf.Tensor:
# Unpack the row's input ids and scores
row_input_ids, row_score = row_inputs
# Compute the banned tokens for this row, based on the first `cur_len` tokens of `row_input_ids`
banned_tokens = self._calc_row_banned_bad_tokens(row_input_ids[:cur_len])
# Boolean tensor with the same shape as `row_score`, marking the positions of banned tokens
banned_tokens_mask = tf.scatter_nd(
indices=tf.expand_dims(banned_tokens, axis=-1),
updates=tf.ones_like(banned_tokens, dtype=tf.bool),
shape=row_score.shape,
)
# Replace the scores of banned tokens with `-inf`, leaving the other positions untouched
row_score = tf.where(banned_tokens_mask, -float("inf"), row_score)
return row_score
# Apply `_get_row_updated_score` to every row and return the updated scores
scores = tf.map_fn(_get_row_updated_score, (input_ids, scores), fn_output_signature=tf.float32)
return scores
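A small sketch (illustrative only, assuming eager execution): token 3 is banned unconditionally because it is a one-token bad word, and token 2 is banned because the last generated token is 1 and [1, 2] is a bad sequence.
```
import tensorflow as tf
from transformers.generation.tf_logits_process import TFNoBadWordsLogitsProcessor

processor = TFNoBadWordsLogitsProcessor(bad_words_ids=[[3], [1, 2]], eos_token_id=0)
input_ids = tf.constant([[5, 1]])
scores = tf.zeros((1, 6))
print(processor(input_ids, scores, cur_len=2).numpy())
# [[  0.   0. -inf -inf   0.   0.]]
```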
class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
r"""
[`TFLogitsProcessor`] that enforces no repetition of n-grams. See
[Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py).
Args:
ngram_size (`int`):
All ngrams of size `ngram_size` can only occur once.
"""
def __init__(self, ngram_size: int):
# Validate and store the ngram_size parameter
if not isinstance(ngram_size, int) or ngram_size <= 0:
raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
self.ngram_size = ngram_size
def calc_banned_ngram_tokens(self, input_ids, num_hypos, cur_len):
# Compute the banned ngram tokens, used to prevent ngram repetition
# Copied from fairseq, which uses it for no_repeat_ngram in beam search
if cur_len + 1 < self.ngram_size:
# If the current length + 1 is smaller than ngram_size, no token is banned yet: return empty lists
return [[] for _ in range(num_hypos)]
generated_ngrams = [{} for _ in range(num_hypos)]
prev_input_ids = input_ids[:, :cur_len]
for idx in range(num_hypos):
gen_tokens = prev_input_ids[idx].numpy().tolist()
generated_ngram = generated_ngrams[idx]
for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]):
prev_ngram_tuple = tuple(ngram[:-1])
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
def _get_generated_ngrams(hypo_idx):
# Before decoding the next token, prevent decoding ngrams that have already been generated
start_idx = cur_len + 1 - self.ngram_size
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist())
return generated_ngrams[hypo_idx].get(ngram_idx, [])
# Return the list of banned tokens
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
return banned_tokens
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
# Logit processing performed when the object is called
# TODO (joao): enable XLA on this logits processor. See discussion and attempts in
# https://github.com/huggingface/transformers/pull/16974
if not tf.executing_eagerly():
raise NotImplementedError("TFNoRepeatNGramLogitsProcessor is only implemented for eager execution.")
batch_size, vocab_size = scores.shape
# Compute the banned ngram tokens
banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len)
# Build a boolean mask over the vocabulary for the banned tokens
banned_tokens_indices_mask = []
for banned_tokens_slice in banned_tokens:
banned_tokens_indices_mask.append(
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
)
# Set the logits of the banned tokens to negative infinity
scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)
return scores
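A small eager-mode sketch (illustrative only): the bigram (1, 2) already occurred, and the last generated token is 1 again, so token 2 is banned.
```
import tensorflow as tf
from transformers.generation.tf_logits_process import TFNoRepeatNGramLogitsProcessor

processor = TFNoRepeatNGramLogitsProcessor(ngram_size=2)
input_ids = tf.constant([[1, 2, 1]])
scores = tf.zeros((1, 4))
print(processor(input_ids, scores, cur_len=3).numpy())
# [[  0.   0. -inf   0.]]
```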
class TFForcedBOSTokenLogitsProcessor(TFLogitsProcessor):
r"""
[`TFLogitsProcessor`] that enforces the specified token as the first generated token.
Args:
bos_token_id (`int`):
The id of the token to force as the first generated token.
"""
def __init__(self, bos_token_id: int):
if bos_token_id < 0:
raise ValueError(f"The forced bos token id must be a non-negative integer, got {bos_token_id}")
self.bos_token_id = bos_token_id
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
if cur_len == 1:
batch_size, num_tokens = scores.shape
scores = tf.zeros((batch_size, 1))
if self.bos_token_id > 0:
scores = tf.concat((tf.broadcast_to(-float("inf"), (batch_size, self.bos_token_id)), scores), axis=-1)
if self.bos_token_id < (num_tokens - 1):
scores = tf.concat(
(scores, tf.broadcast_to(-float("inf"), (batch_size, (num_tokens - 1) - self.bos_token_id))),
axis=-1,
)
return scores
class TFForcedEOSTokenLogitsProcessor(TFLogitsProcessor):
r"""
[`TFLogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached.
Args:
max_length (`int`):
The maximum length of the sequence to be generated.
eos_token_id (`int`):
The id of the token to force as the last generated token when `max_length` is reached.
"""
def __init__(self, max_length: int, eos_token_id: int):
self.max_length = max_length
if eos_token_id < 0:
raise ValueError(f"The forced eos token id must be a non-negative integer, got {eos_token_id}")
self.eos_token_id = eos_token_id
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
if cur_len == self.max_length - 1:
batch_size, num_tokens = scores.shape
scores = tf.zeros((batch_size, 1))
if self.eos_token_id > 0:
scores = tf.concat((tf.broadcast_to(-float("inf"), (batch_size, self.eos_token_id)), scores), axis=-1)
if self.eos_token_id < (num_tokens - 1):
scores = tf.concat(
(scores, tf.broadcast_to(-float("inf"), (batch_size, (num_tokens - 1) - self.eos_token_id))),
axis=-1,
)
return scores
class TFSuppressTokensAtBeginLogitsProcessor(TFLogitsProcessor):
r"""
[`TFSuppressTokensAtBeginLogitsProcessor`] suppresses a list of tokens as soon as the `generate` function starts
generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are not
sampled at the beginning of the generation.
"""
def __init__(self, begin_suppress_tokens, begin_index):
self.begin_suppress_tokens = list(begin_suppress_tokens)
self.begin_index = begin_index
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
scores = tf.cond(
tf.equal(cur_len, self.begin_index),
lambda: tf.tensor_scatter_nd_update(
scores,
indices=[[i, token] for i in range(scores.shape[0]) for token in self.begin_suppress_tokens],
updates=[-float("inf") for _ in range(scores.shape[0] * len(self.begin_suppress_tokens))],
),
lambda: scores,
)
return scores
class TFSuppressTokensLogitsProcessor(TFLogitsProcessor):
r"""This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so that they
are not sampled."""
def __init__(self, suppress_tokens):
self.suppress_tokens = list(suppress_tokens)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
scores = tf.tensor_scatter_nd_update(
scores,
indices=[[i, token] for i in range(scores.shape[0]) for token in self.suppress_tokens],
updates=[-float("inf") for _ in range(scores.shape[0] * len(self.suppress_tokens))],
)
return scores
class TFForceTokensLogitsProcessor(TFLogitsProcessor):
r"""This processor takes a list of pairs of integers which indicates a mapping from generation indices to token
indices that will be forced before sampling. The processor will set their log probs to `0` and all other tokens to
`-inf` so that they are sampled at their corresponding index."""
def __init__(self, force_token_map: List[List[int]]):
force_token_map = dict(force_token_map)
force_token_array = np.ones((max(force_token_map.keys()) + 1), dtype=np.int32) * -1
for index, token in force_token_map.items():
if token is not None:
force_token_array[index] = token
self.force_token_array = tf.convert_to_tensor(force_token_array, dtype=tf.int32)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
def _force_token(generation_idx):
batch_size = scores.shape[0]
current_token = self.force_token_array[generation_idx]
new_scores = tf.ones_like(scores, dtype=scores.dtype) * -float("inf")
indices = tf.stack((tf.range(batch_size), tf.tile([current_token], [batch_size])), axis=1)
updates = tf.zeros((batch_size,), dtype=scores.dtype)
new_scores = tf.tensor_scatter_nd_update(new_scores, indices, updates)
return new_scores
scores = tf.cond(
tf.greater_equal(cur_len, tf.shape(self.force_token_array)[0]),
lambda: tf.identity(scores),
lambda: tf.cond(
tf.greater_equal(self.force_token_array[cur_len], 0),
lambda: _force_token(cur_len),
lambda: scores,
),
)
return scores
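How these processors are typically combined (a hedged sketch, not taken verbatim from the library): `generate` builds a `TFLogitsProcessorList` and applies it left to right at every generation step.
```
import tensorflow as tf
from transformers.generation.tf_logits_process import (
    TFLogitsProcessorList,
    TFMinLengthLogitsProcessor,
    TFTemperatureLogitsWarper,
    TFTopKLogitsWarper,
)

processors = TFLogitsProcessorList(
    [
        TFMinLengthLogitsProcessor(min_length=5, eos_token_id=0),
        TFTemperatureLogitsWarper(temperature=0.7),
        TFTopKLogitsWarper(top_k=3),
    ]
)
input_ids = tf.constant([[1, 2, 3]])
scores = tf.random.normal((1, 10))
filtered = processors(input_ids, scores, cur_len=3)  # applied in order
print(filtered.shape)  # (1, 10)
```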
.\generation\tf_utils.py
import copy
import inspect
import warnings
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice
from ..modeling_tf_outputs import TFCausalLMOutputWithPast, TFSeq2SeqLMOutput
from ..models.auto import (
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from ..tf_utils import shape_list, stable_softmax
from ..utils import ModelOutput, logging
from .configuration_utils import GenerationConfig
from .tf_logits_process import (
TFForcedBOSTokenLogitsProcessor,
TFForcedEOSTokenLogitsProcessor,
TFForceTokensLogitsProcessor,
TFLogitsProcessorList,
TFMinLengthLogitsProcessor,
TFNoBadWordsLogitsProcessor,
TFNoRepeatNGramLogitsProcessor,
TFRepetitionPenaltyLogitsProcessor,
TFSuppressTokensAtBeginLogitsProcessor,
TFSuppressTokensLogitsProcessor,
TFTemperatureLogitsWarper,
TFTopKLogitsWarper,
TFTopPLogitsWarper,
)
logger = logging.get_logger(__name__)
@dataclass
class TFGreedySearchDecoderOnlyOutput(ModelOutput):
"""
Base class for outputs of decoder-only generation models using greedy search.
"""
sequences: tf.Tensor = None
scores: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@dataclass
class TFGreedySearchEncoderDecoderOutput(ModelOutput):
"""
Base class for outputs of encoder-decoder generation models using greedy search. Hidden states and attention
weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the
encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)
"""
sequences: tf.Tensor = None
scores: Optional[Tuple[tf.Tensor]] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
cross_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@dataclass
class TFSampleDecoderOnlyOutput(ModelOutput):
"""
Base class for outputs of decoder-only generation models using sampling.
Args:
sequences (`tf.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all
batches finished early due to the `eos_token_id`.
scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each
generation step. Tuple with up to `max_new_tokens` elements (one element for each generated token), each tensor of
shape `(batch_size*num_return_sequences, config.vocab_size)`.
attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `tf.Tensor`
of shape `(num_return_sequences*batch_size, num_heads, generated_length, sequence_length)`.
hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `tf.Tensor`
of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`.
"""
sequences: tf.Tensor = None
scores: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@dataclass
class TFSampleEncoderDecoderOutput(ModelOutput):
"""
Encoder-decoder生成模型使用采样生成的输出的基类。可以通过encoder_attentions和encoder_hidden_states属性(分别通过decoder_attentions和decoder_hidden_states属性)访问解码器(分别是编码器)的隐藏状态和注意力权重。
"""
Args:
sequences (`tf.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
生成的序列。第二维(sequence_length)要么等于 `max_length`,要么因为 `eos_token_id` 导致所有批次提前结束而更短。
scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
语言建模头部处理后的预测分数(在SoftMax之前的每个词汇标记的分数),每一代步骤有一个元组,包含最多 `max_new_tokens` 个元素,
每个张量的形状为 `(batch_size*num_return_sequences, config.vocab_size)`。
encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
编码器注意力的元组(每个解码器层一个张量),形状为 `(batch_size*num_return_sequences, num_heads, sequence_length, sequence_length)`。
encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
编码器隐藏状态的元组(每个解码器层一个张量),形状为 `(batch_size*num_return_sequences, sequence_length, hidden_size)`。
decoder_attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
解码器注意力的元组(每个生成的令牌一个元组,每个解码器层一个张量),形状为 `(batch_size*num_return_sequences, num_heads, generated_length, sequence_length)`。
cross_attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
交叉注意力的元组(每个生成的令牌一个元组,每个解码器层一个张量),形状为 `(batch_size, num_heads, generated_length, sequence_length)`。
decoder_hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
解码器隐藏状态的元组(每个生成的令牌一个元组,每个解码器层一个张量),形状为 `(batch_size*num_return_sequences, generated_length, hidden_size)`。
sequences: tf.Tensor = None
scores: Optional[Tuple[tf.Tensor]] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
cross_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@dataclass
class TFBeamSearchDecoderOnlyOutput(ModelOutput):
"""
Base class for outputs of decoder-only generation models using beam search.
Args:
sequences (`tf.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
if all batches finished early due to the `eos_token_id`.
sequences_scores (`tf.Tensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Final beam scores of the generated `sequences`.
scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this
beam. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each generated token),
with each tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`.
beam_indices (`tf.Tensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam indices of generated token id at each generation step. `tf.Tensor` of shape
`(batch_size*num_return_sequences, sequence_length)`.
attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`tf.Tensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`tf.Tensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
"""
sequences: tf.Tensor = None
sequences_scores: Optional[tf.Tensor] = None
scores: Optional[Tuple[tf.Tensor]] = None
beam_indices: Optional[tf.Tensor] = None
attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@dataclass
class TFBeamSearchEncoderDecoderOutput(ModelOutput):
"""
Base class for outputs of encoder-decoder generation models using beam search. Hidden states and attention weights
of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states
attributes (respectively the decoder_attentions and the decoder_hidden_states attributes).
"""
sequences: tf.Tensor = None
sequences_scores: Optional[tf.Tensor] = None
scores: Optional[Tuple[tf.Tensor]] = None
beam_indices: Optional[tf.Tensor] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
cross_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@dataclass
class TFBeamSampleDecoderOnlyOutput(ModelOutput):
"""
Base class for outputs of decoder-only generation models using beam sampling.
Args:
sequences (`tf.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all
batches finished early due to the `eos_token_id`.
sequences_scores (`tf.Tensor` of shape `(batch_size * num_return_sequence)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Final beam scores of the generated `sequences`.
scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Processed beam scores for each vocabulary token at each generation step. Tuple of `tf.Tensor` with up to
`max_new_tokens` elements (one element for each generated token), with each tensor of shape
`(batch_size*num_beams*num_return_sequences, config.vocab_size)`.
beam_indices (`tf.Tensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam indices of the generated token id at each generation step. `tf.Tensor` of shape
`(batch_size*num_return_sequences, sequence_length)`.
attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `tf.Tensor`
of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `tf.Tensor`
of shape `(batch_size*num_beams, generated_length, hidden_size)`.
"""
sequences: tf.Tensor = None
sequences_scores: Optional[tf.Tensor] = None
scores: Optional[Tuple[tf.Tensor]] = None
beam_indices: Optional[tf.Tensor] = None
attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@dataclass
class TFBeamSampleEncoderDecoderOutput(ModelOutput):
"""
Base class for outputs of encoder-decoder generation models using beam sampling. Hidden states and attention weights
of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states
attributes (respectively the decoder_attentions and the decoder_hidden_states attributes).
"""
sequences: tf.Tensor = None
sequences_scores: Optional[tf.Tensor] = None
scores: Optional[Tuple[tf.Tensor]] = None
beam_indices: Optional[tf.Tensor] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
cross_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@dataclass
class TFContrastiveSearchDecoderOnlyOutput(ModelOutput):
"""
Decoder-only generation model output class for contrastive search.
Args:
sequences (`tf.Tensor` of shape `(batch_size, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
if all batches finished early due to the `eos_token_id`.
scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each
generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`tf.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`tf.Tensor` of shape `(batch_size, generated_length, hidden_size)`.
"""
sequences: tf.Tensor = None
scores: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@dataclass
class TFContrastiveSearchEncoderDecoderOutput(ModelOutput):
"""
Base class for outputs of encoder-decoder generation models using contrastive search. Hidden states and attention
weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the
encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes).
Args:
sequences (`tf.Tensor` of shape `(batch_size, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all
batches finished early due to the `eos_token_id`.
scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each
generation step. Tuple with up to `max_new_tokens` elements (one element for each generated token), each tensor of
shape `(batch_size, config.vocab_size)`.
encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for each layer, including the initial embedding outputs) of shape `(batch_size, sequence_length, hidden_size)`.
decoder_attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `tf.Tensor`
of shape `(batch_size, num_heads, generated_length, sequence_length)`.
cross_attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `tf.Tensor`
of shape `(batch_size, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `tf.Tensor`
of shape `(batch_size, generated_length, hidden_size)`.
"""
sequences: tf.Tensor = None
scores: Optional[Tuple[tf.Tensor]] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
cross_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
TFGreedySearchOutput = Union[TFGreedySearchEncoderDecoderOutput, TFGreedySearchDecoderOnlyOutput]
TFSampleOutput = Union[TFSampleEncoderDecoderOutput, TFSampleDecoderOnlyOutput]
TFBeamSearchOutput = Union[TFBeamSearchEncoderDecoderOutput, TFBeamSearchDecoderOnlyOutput]
TFBeamSampleOutput = Union[TFBeamSampleEncoderDecoderOutput, TFBeamSampleDecoderOnlyOutput]
TFContrastiveSearchOutput = Union[TFContrastiveSearchEncoderDecoderOutput, TFContrastiveSearchDecoderOnlyOutput]
TFGenerateOutput = Union[
TFGreedySearchOutput, TFSampleOutput, TFBeamSearchOutput, TFBeamSampleOutput, TFContrastiveSearchOutput
]
class TFGenerationMixin:
"""
A class containing all of the functions supporting generation, to be used as a mixin in [`TFPreTrainedModel`].
The class exposes [`~generation.TFGenerationMixin.generate`], which can be used for:
- *greedy decoding* by calling [`~generation.TFGenerationMixin.greedy_search`] if `num_beams=1` and `do_sample=False`
- *contrastive search* by calling [`~generation.TFGenerationMixin.contrastive_search`] if `penalty_alpha>0` and `top_k>1`
- *multinomial sampling* by calling [`~generation.TFGenerationMixin.sample`] if `num_beams=1` and `do_sample=True`
- *beam-search decoding* by calling [`~generation.TFGenerationMixin.beam_search`] if `num_beams>1`
You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To learn
more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
"""
_seed_generator = None
@property
def seed_generator(self):
warnings.warn("`seed_generator` is deprecated and will be removed in a future version.", UserWarning)
if self._seed_generator is None:
self._seed_generator = tf.random.Generator.from_non_deterministic_state()
return self._seed_generator
supports_xla_generation = True
def prepare_inputs_for_generation(self, *args, **kwargs):
raise NotImplementedError(
"A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
)
def compute_transition_scores(
self,
sequences: tf.Tensor,
scores: Tuple[tf.Tensor],
beam_indices: Optional[tf.Tensor] = None,
normalize_logits: bool = False,
) -> tf.Tensor:
def _validate_model_class(self):
"""
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
right class to use.
"""
if not self.can_generate():
generate_compatible_mappings = [
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
]
generate_compatible_classes = set()
for model_mapping in generate_compatible_mappings:
supported_models = model_mapping.get(type(self.config), default=None)
if supported_models is not None:
generate_compatible_classes.add(supported_models.__name__)
exception_message = (
f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
"it doesn't have a language model head."
)
if generate_compatible_classes:
exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
raise TypeError(exception_message)
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
if self.config.is_encoder_decoder:
for key in ["decoder_input_ids"]:
model_kwargs.pop(key, None)
unused_model_args = []
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
if "kwargs" in model_args or "model_kwargs" in model_args:
model_args |= set(inspect.signature(self.call).parameters)
for key, value in model_kwargs.items():
if value is not None and key not in model_args:
unused_model_args.append(key)
if unused_model_args:
raise ValueError(
f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
" generate arguments will also show up in this list)"
)
def _prepare_attention_mask_for_generation(
self,
inputs: tf.Tensor,
pad_token_id: Optional[int],
eos_token_id: Optional[int],
) -> tf.Tensor:
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in (tf.int32, tf.int64)
is_pad_token_in_inputs = (pad_token_id is not None) and tf.math.reduce_any(inputs == pad_token_id)
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id)
if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
return tf.cast(tf.math.not_equal(inputs, pad_token_id), dtype=tf.int32)
else:
return tf.ones(inputs.shape[:2], dtype=tf.int32)
def _prepare_encoder_decoder_kwargs_for_generation(
self, inputs_tensor: tf.Tensor, model_kwargs, model_input_name: Optional[str] = None
) -> Dict[str, Any]:
encoder = self.get_encoder()
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
encoder_kwargs = {
argument: value
for argument, value in model_kwargs.items()
if not any(argument.startswith(p) for p in irrelevant_prefix)
}
encoder_signature = set(inspect.signature(encoder.call).parameters)
encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
if not encoder_accepts_wildcard:
encoder_kwargs = {
argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
}
encoder_kwargs["return_dict"] = True
encoder_kwargs[model_input_name] = inputs_tensor
if model_input_name != self.main_input_name:
encoder_kwargs[self.main_input_name] = None
encoder_outputs = encoder(**encoder_kwargs)
model_kwargs["encoder_outputs"] = encoder_outputs
return model_kwargs
def _prepare_decoder_input_ids_for_generation(
self,
batch_size: int,
model_input_name: str,
model_kwargs: Dict[str, tf.Tensor],
decoder_start_token_id: int = None,
bos_token_id: int = None,
) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]:
"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
elif "input_ids" in model_kwargs and model_input_name != "input_ids":
decoder_input_ids = model_kwargs.pop("input_ids")
else:
decoder_input_ids = None
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
decoder_input_ids_start = tf.ones((batch_size, 1), dtype=tf.int32) * decoder_start_token_id
if decoder_input_ids is None:
decoder_input_ids = decoder_input_ids_start
elif tf.reduce_all(decoder_input_ids[:, 0] != decoder_start_token_id):
decoder_input_ids = tf.concat([decoder_input_ids_start, decoder_input_ids], axis=-1)
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
decoder_attention_mask = tf.concat(
(tf.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
axis=-1,
)
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
return decoder_input_ids, model_kwargs
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
decoder_start_token_id = (
decoder_start_token_id
if decoder_start_token_id is not None
else self.generation_config.decoder_start_token_id
)
bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
if decoder_start_token_id is not None:
return decoder_start_token_id
elif bos_token_id is not None:
return bos_token_id
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
def _expand_inputs_for_generation(
expand_size: int = 1,
is_encoder_decoder: bool = False,
input_ids: Optional[tf.Tensor] = None,
expand_in_new_axis: bool = False,
**model_kwargs,
) -> Tuple[tf.Tensor, Dict[str, Any]]:
"""
Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...] or [batch_size, expand_size, ...],
depending on `expand_in_new_axis`. Beam-based approaches expect this function to be used with
`expand_in_new_axis=True`
"""
def _expand_tensor(tensor: tf.Tensor):
if expand_in_new_axis:
shape = shape_list(tensor)
return tf.broadcast_to(tensor[:, None], (shape[0], expand_size) + tuple(shape[1:]))
else:
return tf.repeat(tensor, expand_size, axis=0)
def _expand_dict_for_generation(dict_to_expand):
for key in dict_to_expand:
if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], tf.Tensor):
dict_to_expand[key] = _expand_tensor(dict_to_expand[key])
return dict_to_expand
if input_ids is not None:
input_ids = _expand_tensor(input_ids)
model_kwargs = _expand_dict_for_generation(model_kwargs)
if is_encoder_decoder:
if model_kwargs.get("encoder_outputs") is None:
raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
return input_ids, model_kwargs
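A minimal illustration (not from the source) of the two expansion modes that `_expand_tensor` implements: flat repetition along the batch axis versus inserting a new beam axis.
```
import tensorflow as tf

x = tf.reshape(tf.range(6), (2, 3))                   # batch_size=2
print(tf.repeat(x, 2, axis=0).shape)                  # (4, 3)   -- expand_in_new_axis=False
print(tf.broadcast_to(x[:, None], (2, 2, 3)).shape)   # (2, 2, 3) -- expand_in_new_axis=True
```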
def _prepare_model_inputs(
self,
inputs: Optional[tf.Tensor] = None,
bos_token_id: Optional[int] = None,
model_kwargs: Optional[Dict[str, tf.Tensor]] = None,
):
"""
Prepares inputs for the model, optionally including a beginning-of-sequence token ID (`bos_token_id`).
"""
def _maybe_initialize_input_ids_for_generation(
self,
inputs: Optional[tf.Tensor] = None,
bos_token_id: Optional[int] = None,
model_kwargs: Optional[Dict[str, tf.Tensor]] = None,
) -> tf.Tensor:
"""Initializes input ids for generation, if necessary."""
if inputs is not None:
return inputs
encoder_outputs = model_kwargs.get("encoder_outputs")
if self.config.is_encoder_decoder and encoder_outputs is not None:
shape = encoder_outputs.last_hidden_state.shape[:-1]
return tf.ones(shape, dtype=tf.int32) * -100
if bos_token_id is None:
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
batch_size = 1
for value in model_kwargs.values():
if isinstance(value, tf.Tensor):
batch_size = value.shape[0]
break
return tf.ones((batch_size, 1), dtype=tf.int32) * bos_token_id
@staticmethod
def _extract_past_from_model_output(outputs: ModelOutput):
"""Extracts past key values from model outputs."""
past_key_values = None
if "past_key_values" in outputs:
past_key_values = outputs.past_key_values
elif "mems" in outputs:
past_key_values = outputs.mems
elif "past_buckets_states" in outputs:
past_key_values = outputs.past_buckets_states
return past_key_values
def _update_model_kwargs_for_generation(
self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
) -> Dict[str, Any]:
"""Updates model keyword arguments for generation."""
model_kwargs["past_key_values"] = self._extract_past_from_model_output(outputs)
if not is_encoder_decoder:
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = tf.concat(
[attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
)
return model_kwargs
def _update_model_kwargs_for_xla_generation(
self,
model_outputs: ModelOutput,
model_kwargs: Dict[str, Any],
cur_len: int,
max_length: int,
batch_size: int,
is_encoder_decoder: bool = False,
batch_axis: int = 0,
):
"""Updates model keyword arguments for XLA generation."""
pass
def _get_logits_warper(
self,
generation_config: GenerationConfig,
) -> TFLogitsProcessorList:
"""
This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsWarper`]
instances used for multinomial sampling.
"""
warpers = TFLogitsProcessorList()
if generation_config.num_beams > 1:
if isinstance(generation_config.eos_token_id, list):
min_tokens_to_keep = len(generation_config.eos_token_id) + 1
else:
min_tokens_to_keep = 2
else:
min_tokens_to_keep = 1
if generation_config.temperature is not None and generation_config.temperature != 1.0:
warpers.append(TFTemperatureLogitsWarper(generation_config.temperature))
if generation_config.top_k is not None and generation_config.top_k != 0:
warpers.append(TFTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.top_p is not None and generation_config.top_p < 1.0:
warpers.append(TFTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
return warpers
def _get_logits_processor(
self,
generation_config: GenerationConfig,
input_ids_seq_length: int,
logits_processor: Optional[TFLogitsProcessorList],
) -> TFLogitsProcessorList:
"""
This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsProcessor`]
instances used to modify the scores of the language model head.
"""
processors = TFLogitsProcessorList()
if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
processors.append(TFRepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
processors.append(TFNoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size))
if generation_config.bad_words_ids is not None:
processors.append(
TFNoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id)
)
if (
generation_config.min_length is not None
and generation_config.eos_token_id is not None
and generation_config.min_length > 0
):
processors.append(TFMinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id))
if generation_config.forced_bos_token_id is not None:
processors.append(TFForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id))
if generation_config.forced_eos_token_id is not None:
processors.append(
TFForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
)
if generation_config.suppress_tokens is not None:
processors.append(TFSuppressTokensLogitsProcessor(generation_config.suppress_tokens))
if generation_config.begin_suppress_tokens is not None:
begin_index = input_ids_seq_length
begin_index = (
begin_index
if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
else begin_index + 1
)
if generation_config.forced_decoder_ids is not None:
begin_index += generation_config.forced_decoder_ids[-1][
0
]
processors.append(
TFSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
)
if generation_config.forced_decoder_ids is not None:
processors.append(TFForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
processors = self._merge_criteria_processor_list(processors, logits_processor)
return processors
def _merge_criteria_processor_list(
self,
default_list: TFLogitsProcessorList,
custom_list: TFLogitsProcessorList,
) -> TFLogitsProcessorList:
if len(custom_list) == 0:
return default_list
for default in default_list:
for custom in custom_list:
if type(custom) is type(default):
object_type = "logits processor"
raise ValueError(
f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
f" `generate`, but it has already been created with the values {default}. {default} has been"
" created by passing the corresponding arguments to generate or by the model's config default"
f" values. If you just want to change the default values of {object_type} consider passing"
f" them as arguments to `generate` instead of using a custom {object_type}."
)
default_list.extend(custom_list)
return default_list
def greedy_search(
self,
input_ids: tf.Tensor,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
logits_processor: Optional[TFLogitsProcessorList] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
**model_kwargs,
):
...
def sample(
self,
input_ids: tf.Tensor,
logits_processor: Optional[TFLogitsProcessorList] = None,
logits_warper: Optional[TFLogitsProcessorList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
seed: Optional[Tuple[int, int]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
**model_kwargs,
):
...
@staticmethod
def _gather_beams(nested, beam_indices, batch_axis=0):
"""Gathers the beam slices indexed by beam_indices into new beam array."""
def gather_fn(tensor):
if batch_axis > 0:
perm = tf.concat((tf.range(tf.rank(tensor))[batch_axis:], tf.range(batch_axis)), axis=0)
tensor = tf.transpose(tensor, perm=perm)
gathered_tensor = tf.gather(params=tensor, indices=beam_indices, axis=1, batch_dims=1)
if batch_axis > 0:
perm = tf.concat((tf.range(tf.rank(tensor))[batch_axis:], tf.range(batch_axis)), axis=0)
perm = tf.math.invert_permutation(perm)
gathered_tensor = tf.transpose(gathered_tensor, perm=perm)
return gathered_tensor
return tf.nest.map_structure(gather_fn, nested)
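# Example (not part of the library): the core gather performed by `_gather_beams`, on toy values with
# batch_size=2, num_beams=3 and seq_len=2. `batch_dims=1` means the beam indices are looked up per batch row.
import tensorflow as tf
sequences = tf.constant([[[1, 1], [2, 2], [3, 3]],
                         [[4, 4], [5, 5], [6, 6]]])   # (batch=2, beams=3, seq=2)
beam_indices = tf.constant([[2, 0], [1, 1]])          # beams to keep for each batch row
gathered = tf.gather(params=sequences, indices=beam_indices, axis=1, batch_dims=1)
print(gathered.numpy())  # [[[3 3] [1 1]] [[5 5] [5 5]]]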
def scatter_values_on_batch_indices(values, batch_indices):
shape = shape_list(batch_indices)
broad_casted_batch_dims = tf.reshape(tf.broadcast_to(tf.expand_dims(tf.range(shape[0]), axis=-1), shape), [1, -1])
pair_indices = tf.transpose(tf.concat([broad_casted_batch_dims, tf.reshape(batch_indices, [1, -1])], 0))
return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), shape)
def sample_without_replacement(logits, num_samples):
"""
Categorical sampling without replacement is currently not implemented; the Gumbel-Max trick is used instead. See
https://github.com/tensorflow/tensorflow/issues/9260 for more information
"""
z = -tf.math.log(-tf.math.log(tf.random.uniform(shape_list(logits), 0, 1)))
_, indices = tf.nn.top_k(logits + z, num_samples)
return indices
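# Example (not part of the library): the Gumbel-Max trick used above, on toy logits. Adding Gumbel(0, 1) noise to
# the logits and taking the top-k yields k distinct samples from the categorical distribution, i.e. sampling
# without replacement through a single top_k call.
import tensorflow as tf
logits = tf.math.log(tf.constant([[0.1, 0.2, 0.3, 0.4]]))
z = -tf.math.log(-tf.math.log(tf.random.uniform(tf.shape(logits), 0, 1)))
_, indices = tf.nn.top_k(logits + z, k=2)
print(indices.numpy())  # e.g. [[3 1]] -- two distinct token indices; the result varies with the random draw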
def _ranking_fast(
context_hidden: tf.Tensor,
next_hidden: tf.Tensor,
next_top_k_probs: tf.Tensor,
alpha: float,
beam_width: int,
) -> tf.Tensor:
"""
Reranks the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens), as described
in the paper "A Contrastive Framework for Neural Text Generation". Returns the index of the best candidate for each row in the batch.
"""
norm_context_hidden = context_hidden / tf.norm(context_hidden, axis=2, keepdims=True)
norm_next_hidden = next_hidden / tf.norm(next_hidden, axis=2, keepdims=True)
cosine_matrix = tf.squeeze(tf.linalg.matmul(norm_context_hidden, norm_next_hidden, transpose_b=True), axis=-1)
degeneration_penalty = tf.reduce_max(cosine_matrix, axis=-1)
next_top_k_probs = tf.reshape(next_top_k_probs, shape=[-1])
contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
contrastive_score = tf.reshape(contrastive_score, shape=[-1, beam_width])
selected_idx = tf.argmax(contrastive_score, axis=1)
return selected_idx
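# Example (not part of the library): a toy walk-through of the contrastive score with beam_width=2 and
# hidden_size=2, using unit-norm vectors so the normalization step can be skipped. Candidate 0 repeats the context
# exactly and is penalized; candidate 1 wins despite its lower model probability.
import tensorflow as tf
context_hidden = tf.constant([[[1.0, 0.0]], [[1.0, 0.0]]])  # (beam_width, ctx_len=1, hidden=2)
next_hidden = tf.constant([[[1.0, 0.0]], [[0.0, 1.0]]])     # candidate hidden states
next_top_k_probs = tf.constant([0.6, 0.4])
alpha = 0.6
cosine = tf.squeeze(tf.linalg.matmul(context_hidden, next_hidden, transpose_b=True), axis=-1)
degeneration_penalty = tf.reduce_max(cosine, axis=-1)                                  # [1.0, 0.0]
contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
print(tf.argmax(tf.reshape(contrastive_score, [-1, 2]), axis=1).numpy())               # [1]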
.\generation\utils.py
import copy
import inspect
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch import nn
from ..cache_utils import Cache, DynamicCache, StaticCache
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from ..utils import ModelOutput, is_accelerate_available, is_torchdynamo_compiling, logging
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .candidate_generator import (
AssistedCandidateGenerator,
CandidateGenerator,
PromptLookupCandidateGenerator,
_crop_past_key_values,
_prepare_attention_mask,
_prepare_token_type_ids,
)
from .configuration_utils import GenerationConfig, GenerationMode
from .logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor,
EpsilonLogitsWarper,
EtaLogitsWarper,
ExponentialDecayLengthPenalty,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
ForceTokensLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitNormalization,
LogitsProcessorList,
MinLengthLogitsProcessor,
MinNewTokensLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
SequenceBiasLogitsProcessor,
SuppressTokensAtBeginLogitsProcessor,
SuppressTokensLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
)
from .stopping_criteria import (
MaxLengthCriteria,
MaxTimeCriteria,
StoppingCriteria,
StoppingCriteriaList,
validate_stopping_criteria,
)
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from .streamers import BaseStreamer
logger = logging.get_logger(__name__)
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module
NEED_SETUP_CACHE_CLASSES_MAPPING = {
"static": StaticCache,
}
@dataclass
class GenerateDecoderOnlyOutput(ModelOutput):
"""
Outputs of decoder-only generation models, when using non-beam methods.
"""
Args:
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
if all batches finished early due to the `eos_token_id`.
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`):
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
"""
# Optional field: a tuple of tuples of torch.FloatTensor values
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
# Optional field: a tuple of tuples of tuples of torch.FloatTensor values
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
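# Example usage (the model name is illustrative only): with `return_dict_in_generate=True`, `generate` returns one
# of the dataclasses above instead of a bare tensor, so sequences and per-step scores are available as attributes.
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True)
print(type(out).__name__)   # GenerateDecoderOnlyOutput (decoder-only model, non-beam decoding)
print(out.sequences.shape)  # (1, prompt_length + 5)
print(len(out.scores))      # 5 -- one score tensor per generated token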
# Outputs of encoder-decoder generation models, when not using beam methods
@dataclass
class GenerateEncoderDecoderOutput(ModelOutput):
"""
Outputs of encoder-decoder generation models, when using non-beam methods.
"""
sequences: torch.LongTensor = None # generated sequences (token IDs)
scores: Optional[Tuple[torch.FloatTensor]] = None # processed scores at each generation step
logits: Optional[Tuple[torch.FloatTensor]] = None # raw logits at each generation step
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None # encoder attention weights
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None # encoder hidden states
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None # decoder attention weights
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None # cross-attention weights
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None # decoder hidden states
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None # cached key/value states from generation
# Outputs of decoder-only generation models, when using beam methods
@dataclass
class GenerateBeamDecoderOnlyOutput(ModelOutput):
"""
Outputs of decoder-only generation models, when using beam methods.
"""
sequences: torch.LongTensor = None # generated sequences (token IDs)
sequences_scores: Optional[torch.FloatTensor] = None # final scores of the generated sequences
scores: Optional[Tuple[torch.FloatTensor]] = None # processed scores at each generation step
logits: Optional[Tuple[torch.FloatTensor]] = None # raw logits at each generation step
beam_indices: Optional[torch.LongTensor] = None # beam indices used during beam search
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None # attention weights
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None # hidden states
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None # cached key/value states from generation
# Outputs of encoder-decoder generation models, when using beam methods
@dataclass
class GenerateBeamEncoderDecoderOutput(ModelOutput):
"""
Outputs of encoder-decoder generation models, when using beam methods.
"""
sequences: torch.LongTensor = None # generated sequences (token IDs)
sequences_scores: Optional[torch.FloatTensor] = None # final scores of the generated sequences
scores: Optional[Tuple[torch.FloatTensor]] = None # processed scores at each generation step
logits: Optional[Tuple[torch.FloatTensor]] = None # raw logits at each generation step
beam_indices: Optional[torch.LongTensor] = None # beam indices used during beam search
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None # encoder attention weights
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None # encoder hidden states
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None # decoder attention weights
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None # cross-attention weights
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None # decoder hidden states
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None # cached key/value states from generation
# Equivalent classes kept for backward compatibility
GreedySearchDecoderOnlyOutput = GenerateDecoderOnlyOutput # greedy-search output for decoder-only models
ContrastiveSearchDecoderOnlyOutput = GenerateDecoderOnlyOutput # contrastive-search output for decoder-only models
SampleDecoderOnlyOutput = GenerateDecoderOnlyOutput # sampling output for decoder-only models
ContrastiveSearchEncoderDecoderOutput = GenerateEncoderDecoderOutput # contrastive-search output for encoder-decoder models
GreedySearchEncoderDecoderOutput = GenerateEncoderDecoderOutput # greedy-search output for encoder-decoder models
SampleEncoderDecoderOutput = GenerateEncoderDecoderOutput # sampling output for encoder-decoder models
BeamSearchDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput # beam-search output for decoder-only models
BeamSampleDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput # beam-sampling output for decoder-only models
BeamSearchEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput # beam-search output for encoder-decoder models
BeamSampleEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput # beam-sampling output for encoder-decoder models
GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput] # greedy-search output type
# Typing shortcuts for specific types of model outputs
SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput]
# Typing shortcut for non-beam text generation output
GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput]
# Typing shortcut for beam search text generation output
GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput]
# Typing shortcut for any text generation output
GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput]
class GenerationMixin:
"""
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
The class exposes [`~generation.GenerationMixin.generate`], which can be used for:
- *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
`do_sample=False`
- *contrastive search* by calling [`~generation.GenerationMixin._contrastive_search`] if `penalty_alpha>0` and
`top_k>1`
- *multinomial sampling* by calling [`~generation.GenerationMixin._sample`] if `num_beams=1` and
`do_sample=True`
- *beam-search decoding* by calling [`~generation.GenerationMixin._beam_search`] if `num_beams>1` and
`do_sample=False`
- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin._beam_sample`] if `num_beams>1`
and `do_sample=True`
- *diverse beam-search decoding* by calling [`~generation.GenerationMixin._group_beam_search`], if `num_beams>1`
and `num_beam_groups>1`
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin._constrained_beam_search`], if
`constraints!=None` or `force_words_ids!=None`
- *assisted decoding* by calling [`~generation.GenerationMixin._assisted_decoding`], if
`assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To
learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
"""
def prepare_inputs_for_generation(self, *args, **kwargs):
# Raise an error if this method is not implemented in the subclass
raise NotImplementedError(
"A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`."
)
def _prepare_model_inputs(
self,
inputs: Optional[torch.Tensor] = None,
bos_token_id: Optional[int] = None,
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
):
# Internal method for preparing model inputs for text generation
...
def _maybe_initialize_input_ids_for_generation(
self,
inputs: Optional[torch.Tensor] = None,
bos_token_id: Optional[int] = None,
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.LongTensor:
"""Initializes input ids for generation, if necessary."""
# If inputs were already provided, return them directly
if inputs is not None:
return inputs
# Fetch encoder_outputs from the model keyword arguments
encoder_outputs = model_kwargs.get("encoder_outputs")
# If the model is an encoder-decoder model and encoder_outputs is not None
if self.config.is_encoder_decoder and encoder_outputs is not None:
# Create an input-ids tensor with the same shape as the encoder's last hidden state, filled with -100
shape = encoder_outputs.last_hidden_state.size()[:-1]
return torch.ones(shape, dtype=torch.long, device=self.device) * -100
# If no input_ids were provided and bos_token_id is undefined, raise an error
if bos_token_id is None:
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
# If model_kwargs contains a tensor, the batch size can be inferred from it
batch_size = 1
for value in model_kwargs.values():
if isinstance(value, torch.Tensor):
batch_size = value.shape[0]
break
# If model_kwargs contains an "inputs_embeds" key
if "inputs_embeds" in model_kwargs:
# Return an empty tensor of shape (batch_size, 0), dtype torch.long
return torch.ones((batch_size, 0), dtype=torch.long, device=self.device)
# Otherwise return a tensor of shape (batch_size, 1) filled with bos_token_id, dtype torch.long
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
def _prepare_attention_mask_for_generation(
self,
inputs: torch.Tensor,
pad_token_id: Optional[int],
eos_token_id: Optional[Union[int, List[int]]],
) -> torch.LongTensor:
# Check whether the inputs are padded input_ids; an attention_mask is only inferred in that case
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id)
# If the inputs are padded input_ids and the pad token is not an eos token, derive the attention_mask
if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
return inputs.ne(pad_token_id).long()
else:
# Otherwise return an all-ones tensor matching the first two dimensions of `inputs`, dtype torch.long
return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
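# Example (not part of the library): the default attention-mask inference above, on toy token ids. Positions equal
# to `pad_token_id` become 0 and everything else becomes 1.
import torch
input_ids = torch.tensor([[5, 6, 7, 0, 0], [8, 9, 0, 0, 0]])  # pad_token_id = 0
attention_mask = input_ids.ne(0).long()
print(attention_mask)  # tensor([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]])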
def _prepare_encoder_decoder_kwargs_for_generation(
self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
) -> Dict[str, Any]:
# 1. Get the encoder
encoder = self.get_encoder()
# 2. Compatibility with accelerate big-model inference: make sure the encoder outputs live on the same device as the inputs
if hasattr(self, "hf_device_map"):
# If the encoder has an `_hf_hook` attribute, set its `io_same_device` flag to True
if hasattr(encoder, "_hf_hook"):
encoder._hf_hook.io_same_device = True
# Otherwise, attach an AlignDevicesHook to the encoder with `io_same_device=True`
else:
add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True))
# 3. Prepare the encoder args and kwargs from the model kwargs
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
# Select the arguments and values from `model_kwargs` that are relevant to the encoder
encoder_kwargs = {
argument: value
for argument, value in model_kwargs.items()
if not any(argument.startswith(p) for p in irrelevant_prefix)
}
# Inspect the encoder's input signature to see whether it accepts `kwargs` or `model_kwargs`
encoder_signature = set(inspect.signature(encoder.forward).parameters)
encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
if not encoder_accepts_wildcard:
# If the encoder does not accept wildcard arguments, keep only the arguments present in its signature
encoder_kwargs = {
argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
}
# 4. Make sure the encoder returns a `ModelOutput`
model_input_name = model_input_name if model_input_name is not None else self.main_input_name
encoder_kwargs["return_dict"] = True
encoder_kwargs[model_input_name] = inputs_tensor
# Call the encoder and store the result under the `encoder_outputs` key of `model_kwargs`
model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
return model_kwargs
# Prepares decoder input IDs for generation
def _prepare_decoder_input_ids_for_generation(
self,
batch_size: int,
model_input_name: str,
model_kwargs: Dict[str, torch.Tensor],
decoder_start_token_id: Union[int, List[int]] = None,
bos_token_id: int = None,
device: torch.device = None,
) -> Dict[str, torch.Tensor]:
...
# Gets the decoder start token ID
def _get_decoder_start_token_id(
self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
) -> int:
...
# Expands the inputs for generation
@staticmethod
def _expand_inputs_for_generation(
expand_size: int = 1,
is_encoder_decoder: bool = False,
input_ids: Optional[torch.LongTensor] = None,
**model_kwargs,
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
"""Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
# The signature specifies a tuple return type: an expanded LongTensor and a dict of arbitrary values
def _expand_dict_for_generation(dict_to_expand):
# Expand the tensors in the dict for generation
for key in dict_to_expand:
if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor):
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
return dict_to_expand
if input_ids is not None:
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
# If input IDs are provided, repeat them along the batch dimension by the expand size
model_kwargs = _expand_dict_for_generation(model_kwargs)
# Expand the tensors in the model kwargs dict
if is_encoder_decoder:
if model_kwargs.get("encoder_outputs") is None:
raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
# For encoder-decoder models, make sure the encoder outputs are defined in the model kwargs and expand them too
return input_ids, model_kwargs
# Returns the expanded input IDs and the expanded model kwargs dict
def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False):
past_key_values = None
if "past_key_values" in outputs:
past_key_values = outputs.past_key_values
elif "mems" in outputs:
past_key_values = outputs.mems
elif "past_buckets_states" in outputs:
past_key_values = outputs.past_buckets_states
# Extracts the past key-value pairs from the model output
# Bloom fix: standardizes the cache format when requested
if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"):
batch_size = outputs.logits.shape[0]
past_key_values = self._convert_to_standard_cache(past_key_values, batch_size=batch_size)
# Standardize the cache format when requested
return past_key_values
# Returns the extracted past key-value pairs
def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
# Updates the model kwargs dict used for generation
) -> Dict[str, Any]:
# Update past_key_values in model_kwargs, extracting them from the model output
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
outputs, standardize_cache_format=standardize_cache_format
)
# If the outputs have a `state` attribute, update `state` in model_kwargs
if getattr(outputs, "state", None) is not None:
model_kwargs["state"] = outputs.state
# Update token_type_ids by repeating the last value
if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
# If this is not an encoder-decoder architecture
if not is_encoder_decoder:
# Update attention_mask
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
else:
# Update decoder_attention_mask
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
model_kwargs["decoder_attention_mask"] = torch.cat(
[decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
dim=-1,
)
# If cache_position exists in model_kwargs and is not None, advance it by one
if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
# Return the updated model_kwargs
return model_kwargs
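# Example (not part of the library): one update step for a decoder-only model, with toy values. After each
# generated token the attention mask gains a column of ones and `cache_position` advances by one.
import torch
model_kwargs = {"attention_mask": torch.ones(2, 4, dtype=torch.long), "cache_position": torch.arange(4)}
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
print(model_kwargs["attention_mask"].shape)  # torch.Size([2, 5])
print(model_kwargs["cache_position"])        # tensor([4])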
# Raises NotImplementedError, prompting users to implement `_reorder_cache` in the model's module to enable beam search
def _reorder_cache(self, past_key_values, beam_idx):
raise NotImplementedError(
f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to"
f" enable beam search for {self.__class__}"
)
# Returns the candidate generator used for assisted generation
def _get_candidate_generator(
self,
generation_config: GenerationConfig,
input_ids: torch.LongTensor,
inputs_tensor: torch.Tensor,
assistant_model: "PreTrainedModel",
logits_processor: LogitsProcessorList,
model_kwargs: Dict,
) -> CandidateGenerator:
"""
Returns the candidate generator to be used in `assisted_generation`
"""
# If prompt_lookup_num_tokens is specified, use PromptLookupCandidateGenerator
if generation_config.prompt_lookup_num_tokens is not None:
candidate_generator = PromptLookupCandidateGenerator(
num_output_tokens=generation_config.prompt_lookup_num_tokens,
max_matching_ngram_size=generation_config.max_matching_ngram_size,
)
else:
# Otherwise, use AssistedCandidateGenerator
candidate_generator = AssistedCandidateGenerator(
input_ids=input_ids,
assistant_model=assistant_model,
generation_config=generation_config,
logits_processor=logits_processor,
model_kwargs=model_kwargs,
inputs_tensor=inputs_tensor,
)
return candidate_generator
def _get_logits_warper(
self,
generation_config: GenerationConfig,
) -> LogitsProcessorList:
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
used for multinomial sampling.
"""
# instantiate warpers list
warpers = LogitsProcessorList()
# In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
# better score (i.e. keep len(list(generation_config.eos_token_id)) + 1)
if generation_config.num_beams > 1:
if isinstance(generation_config.eos_token_id, list):
min_tokens_to_keep = len(generation_config.eos_token_id) + 1
else:
min_tokens_to_keep = 2
else:
min_tokens_to_keep = 1
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
# Apply temperature warping if temperature is defined and not equal to 1.0
if generation_config.temperature is not None and generation_config.temperature != 1.0:
warpers.append(TemperatureLogitsWarper(generation_config.temperature))
# Apply top-k warping if top-k is defined and not equal to 0
if generation_config.top_k is not None and generation_config.top_k != 0:
warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
# Apply top-p warping if top-p is defined and less than 1.0
if generation_config.top_p is not None and generation_config.top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
# Apply typical-p warping if typical-p is defined and less than 1.0
if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
warpers.append(
TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
)
# Apply epsilon cutoff warping if epsilon cutoff is defined and within (0, 1)
if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
warpers.append(
EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
)
# Apply eta cutoff warping if eta cutoff is defined and within (0, 1)
if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
warpers.append(
EtaLogitsWarper(epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep)
)
# `LogitNormalization` should always be the last logit processor, when present
# Apply logit normalization if renormalize_logits flag is True
if generation_config.renormalize_logits is True:
warpers.append(LogitNormalization())
# Return the list of warpers containing all relevant LogitsWarper instances
return warpers
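# Example (not part of the library): chaining the same warpers by hand on toy logits. Each warper reshapes the
# next-token distribution (setting filtered entries to -inf) before multinomial sampling; the top_k/top_p values
# here are arbitrary.
import torch
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
warpers = LogitsProcessorList([
    TemperatureLogitsWarper(0.7),
    TopKLogitsWarper(top_k=3),
    TopPLogitsWarper(top_p=0.9),
])
input_ids = torch.tensor([[1, 2, 3]])
scores = torch.randn(1, 10)          # pretend vocabulary of 10 tokens
warped = warpers(input_ids, scores)  # everything outside the kept set is now -inf
next_token = torch.multinomial(torch.softmax(warped, dim=-1), num_samples=1)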
# Builds the list of logits processors from the given configuration and arguments
def _get_logits_processor(
self,
generation_config: GenerationConfig, # generation configuration object
input_ids_seq_length: int, # length of the input sequence
encoder_input_ids: torch.LongTensor, # encoder input tensor
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], # function restricting the allowed next tokens
logits_processor: Optional[LogitsProcessorList], # optional list of logits processors
model_kwargs: Optional[Dict[str, Any]] = None, # optional dict of model kwargs, defaults to None
negative_prompt_ids: Optional[torch.Tensor] = None, # optional negative-prompt ids, defaults to None
negative_prompt_attention_mask: Optional[torch.Tensor] = None, # optional negative-prompt attention mask, defaults to None
) -> LogitsProcessorList:
...
def _get_stopping_criteria(
self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList]
) -> StoppingCriteriaList:
# Create the stopping-criteria object, initialized as an empty list
criteria = StoppingCriteriaList()
# If a maximum length is specified in the generation config
if generation_config.max_length is not None:
# Fetch the maximum number of position embeddings from the model config
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
# Add a max-length stopping criterion to criteria
criteria.append(
MaxLengthCriteria(
max_length=generation_config.max_length,
max_position_embeddings=max_position_embeddings,
)
)
# If a maximum generation time is specified in the generation config
if generation_config.max_time is not None:
# Add a max-time stopping criterion to criteria
criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
# Merge the custom stopping_criteria into criteria
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
# Return the final criteria list
return criteria
# Merges the default and custom lists of logits processors or stopping criteria
def _merge_criteria_processor_list(
self,
default_list: Union[LogitsProcessorList, StoppingCriteriaList], # default list of processors or stopping criteria
custom_list: Union[LogitsProcessorList, StoppingCriteriaList], # custom list of processors or stopping criteria
) -> Union[LogitsProcessorList, StoppingCriteriaList]: # returns the merged list of processors or stopping criteria
# If the custom list is empty, return the default list directly
if len(custom_list) == 0:
return default_list
# Iterate over the default list
for default in default_list:
# Iterate over the custom list
for custom in custom_list:
# If a custom object has the same type as a default one
if type(custom) is type(default):
# Determine whether the object is a stopping criterion or a logits processor
object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
# Raise a ValueError: a custom object may not duplicate the type of a default one
raise ValueError(
f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
f" `.generate()`, but it has already been created with the values {default}. {default} has been"
" created by passing the corresponding arguments to generate or by the model's config default"
f" values. If you just want to change the default values of {object_type} consider passing"
f" them as arguments to `.generate()` instead of using a custom {object_type}."
)
# Extend the default list with the contents of the custom list
default_list.extend(custom_list)
# Return the merged default list
return default_list
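# Example (not part of the library): the merge rule above, seen from the user side. A custom criterion whose type
# was already created from `generate` arguments (e.g. a second MaxLengthCriteria) would be rejected with the
# ValueError above, while a criterion of a new type is simply appended to the defaults.
from transformers.generation import MaxLengthCriteria, MaxTimeCriteria, StoppingCriteriaList
default_list = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
custom_list = StoppingCriteriaList([MaxTimeCriteria(max_time=5.0)])
clashing = any(type(c) is type(d) for c in custom_list for d in default_list)
if not clashing:
    default_list.extend(custom_list)
print([type(c).__name__ for c in default_list])  # ['MaxLengthCriteria', 'MaxTimeCriteria']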
# Computes the transition scores of generated sequences
def compute_transition_scores(
self,
sequences: torch.Tensor, # tensor of generated sequences
scores: Tuple[torch.Tensor], # tuple of per-step scores
beam_indices: Optional[torch.Tensor] = None, # optional tensor of beam indices, defaults to None
normalize_logits: bool = False, # whether to normalize the logits, defaults to False
) -> torch.Tensor:
...
def _validate_model_class(self):
"""
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
right class to use.
"""
# Check whether the current model is able to generate text
if not self.can_generate():
# Mappings of model classes that support generation
generate_compatible_mappings = [
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
]
generate_compatible_classes = set()
# Iterate over the generation-compatible mappings and collect the names of supported model classes
for model_mapping in generate_compatible_mappings:
supported_models = model_mapping.get(type(self.config), default=None)
if supported_models is not None:
generate_compatible_classes.add(supported_models.__name__)
# Error message used for the exception
exception_message = (
f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
"it doesn't have a language model head."
)
# If compatible model classes were found, add them to the exception message
if generate_compatible_classes:
exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
# Raise a TypeError with the detailed exception message
raise TypeError(exception_message)
# Performs generation-length validation, including warnings and error handling
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
# 1. Warnings about poorly parameterized maximum lengths
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
# Warn when the model-agnostic default `max_length` (=20) is used to control the generation length
warnings.warn(
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
"generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
"generation.",
UserWarning,
)
# Raise an error if the input ids are already at least as long as the configured maximum length
if input_ids_length >= generation_config.max_length:
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
raise ValueError(
f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_length` or, better yet, setting `max_new_tokens`."
)
# 2. Min-length warnings caused by unfeasible parameter combinations
min_length_error_suffix = (
" Generation will stop at the defined maximum length. You should decrease the minimum length and/or "
"increase the maximum length."
)
if has_default_max_length:
min_length_error_suffix += (
f" Note that `max_length` is set to {generation_config.max_length}, its default value."
)
# Warn when min_length is set and is larger than the maximum possible length
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
warnings.warn(
f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than"
f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
UserWarning,
)
# Warn when min_new_tokens is set and the resulting minimum length exceeds the maximum length
if generation_config.min_new_tokens is not None:
min_length = generation_config.min_new_tokens + input_ids_length
if min_length > generation_config.max_length:
warnings.warn(
f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when "
f"added to the prompt length ({input_ids_length}), is larger than"
f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
UserWarning,
)
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
):
"""
Generates sequences based on the provided inputs and configuration.
Args:
inputs (Optional[torch.Tensor]): Input tensor for generation.
generation_config (Optional[GenerationConfig]): Configuration for generation.
logits_processor (Optional[LogitsProcessorList]): Processors for logits during generation.
stopping_criteria (Optional[StoppingCriteriaList]): Criteria for stopping generation.
prefix_allowed_tokens_fn (Optional[Callable[[int, torch.Tensor], List[int]]]): Function to allow tokens during generation.
synced_gpus (Optional[bool]): Whether to synchronize generation across GPUs.
assistant_model (Optional["PreTrainedModel"]): Model used for generation assistance.
streamer (Optional["BaseStreamer"]): Streamer for generation.
negative_prompt_ids (Optional[torch.Tensor]): IDs for negative prompts.
negative_prompt_attention_mask (Optional[torch.Tensor]): Attention mask for negative prompts.
**kwargs: Additional keyword arguments.
Returns:
dict: Dictionary containing generated sequences and other relevant outputs.
"""
...
def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
"""
Returns whether there are still unfinished sequences on the specified device.
Args:
this_peer_finished (bool): Flag indicating if the current peer has finished generation.
synced_gpus (bool): Whether generation is synchronized across GPUs.
device (torch.device): Device on which generation is performed.
Returns:
bool: True if there are unfinished sequences, False otherwise.
"""
if synced_gpus:
# Under synced_gpus, ensure all GPUs complete their sequence generation.
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
# Send 0.0 if this peer finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# Check if all peers finished (sum should be 0.0 if all finished)
if this_peer_finished_flag.item() == 0.0:
return False
elif this_peer_finished:
return False
return True
def contrastive_search(self, *args, **kwargs):
"""
Deprecated method for performing contrastive search. Use `generate` or a custom generation loop instead.
Args:
*args: Positional arguments passed to `_contrastive_search`.
**kwargs: Keyword arguments passed to `_contrastive_search`.
Returns:
Any: Result from `_contrastive_search`.
"""
logger.warning_once(
"Calling `contrastive_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
return self._contrastive_search(*args, **kwargs)
@torch.no_grad()
def _contrastive_search(
self,
input_ids: torch.LongTensor,
top_k: Optional[int] = 1,
penalty_alpha: Optional[float] = 0,
logits_processor: Optional[LogitsProcessorList] = None,
logits_warper: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
output_logits: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: bool = False,
streamer: Optional["BaseStreamer"] = None,
sequential: Optional[bool] = None,
**model_kwargs,
):
"""
Performs contrastive search to generate sequences based on the input_ids and additional arguments.
Args:
input_ids (torch.LongTensor): Input tensor containing token IDs.
top_k (Optional[int]): Number of top-k results to consider.
penalty_alpha (Optional[float]): Penalty factor for contrastive search.
logits_processor (Optional[LogitsProcessorList]): Processors for logits during contrastive search.
logits_warper (Optional[LogitsProcessorList]): Processors for logits warping during contrastive search.
stopping_criteria (Optional[StoppingCriteriaList]): Criteria for stopping contrastive search.
pad_token_id (Optional[int]): Token ID for padding.
eos_token_id (Optional[Union[int, List[int]]]): Token ID(s) for end-of-sequence.
output_attentions (Optional[bool]): Whether to output attention weights.
output_hidden_states (Optional[bool]): Whether to output hidden states.
output_scores (Optional[bool]): Whether to output scores.
output_logits (Optional[bool]): Whether to output logits.
return_dict_in_generate (Optional[bool]): Whether to return results in a dictionary format.
synced_gpus (bool): Whether generation is synchronized across GPUs.
streamer (Optional["BaseStreamer"]): Streamer for contrastive search.
sequential (Optional[bool]): Whether to generate sequentially.
**model_kwargs: Additional keyword arguments.
Returns:
Any: Result of contrastive search, typically sequences or generated outputs.
"""
...
# Emit a warning: calling this method directly is deprecated and will be removed in v4.41; use `generate` or a custom generation loop instead
def greedy_search(self, *args, **kwargs):
logger.warning_once(
"Calling `greedy_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
# Call `_greedy_search`, forwarding all positional and keyword arguments
return self._greedy_search(*args, **kwargs)
# Private implementation backing the deprecated `greedy_search` wrapper above
def _greedy_search(
self,
input_ids: torch.LongTensor,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
output_logits: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: bool = False,
streamer: Optional["BaseStreamer"] = None,
**model_kwargs,
):
# Implementation omitted; performs greedy-search decoding
pass
# Emit a warning: calling this method directly is deprecated and will be removed in v4.41; use `generate` or a custom generation loop instead
def sample(self, *args, **kwargs):
logger.warning_once(
"Calling `sample` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
# Call `_sample`, forwarding all positional and keyword arguments
return self._sample(*args, **kwargs)
# Private implementation backing the deprecated `sample` wrapper above
def _sample(
self,
input_ids: torch.LongTensor,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_warper: Optional[LogitsProcessorList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
output_logits: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: bool = False,
streamer: Optional["BaseStreamer"] = None,
**model_kwargs,
):
# Implementation omitted; performs multinomial sampling
pass
def _temporary_reorder_cache(self, past_key_values, beam_idx):
"""
Temporary function to handle the different types of cache reordering processes while we roll out `Cache`.
TODO: standardize cache formats and make all models compatible with `Cache`. It would remove the need
for this function, with `Cache.reorder_cache` being the sole remaining code path
"""
# Get the lowercased name of the current class
model_class = self.__class__.__name__.lower()
# Exception 1: code path for models that use the legacy cache format
if isinstance(past_key_values, (tuple, list)):
past_key_values = self._reorder_cache(past_key_values, beam_idx)
# Exception 2: models with a different cache format. These are currently limited to `DynamicCache`, until their cache format is standardized.
elif "bloom" in model_class or "gptbigcode" in model_class:
if not isinstance(past_key_values, DynamicCache):
raise ValueError(
f"Using an unsupported cache format with {model_class}. Currently, it only supports the "
"legacy tuple format or `DynamicCache`"
)
past_key_values = self._reorder_cache(past_key_values, beam_idx)
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
# Standard code path: use `Cache.reorder_cache`
else:
past_key_values.reorder_cache(beam_idx)
return past_key_values
def beam_search(self, *args, **kwargs):
logger.warning_once(
"Calling `beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
return self._beam_search(*args, **kwargs)
def _beam_search(
self,
input_ids: torch.LongTensor,
beam_scorer: BeamScorer,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
output_logits: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: bool = False,
sequential: Optional[bool] = None,
**model_kwargs,
):
"""
Perform beam search to generate sequences based on input_ids and beam_scorer.
"""
...
def beam_sample(self, *args, **kwargs):
logger.warning_once(
"Calling `beam_sample` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
return self._beam_sample(*args, **kwargs)
# Private method `_beam_sample`, performing beam-search multinomial sampling
def _beam_sample(
self,
input_ids: torch.LongTensor,
beam_scorer: BeamScorer,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_warper: Optional[LogitsProcessorList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
output_logits: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: bool = False,
**model_kwargs,
):
# Implementation omitted
pass
# Warn that `group_beam_search` is deprecated and will be removed in v4.41; use `generate` or a custom generation loop instead
def group_beam_search(self, *args, **kwargs):
logger.warning_once(
"Calling `group_beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
# Call `_group_beam_search` to perform the actual diverse beam search
return self._group_beam_search(*args, **kwargs)
# Private method `_group_beam_search`, performing diverse (grouped) beam search
def _group_beam_search(
self,
input_ids: torch.LongTensor,
beam_scorer: BeamScorer,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
output_logits: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: bool = False,
**model_kwargs,
):
# Implementation omitted
pass
# Warn that `constrained_beam_search` is deprecated and will be removed in v4.41; use `generate` or a custom generation loop instead
def constrained_beam_search(self, *args, **kwargs):
logger.warning_once(
"Calling `constrained_beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
# Call `_constrained_beam_search` to perform the actual constrained beam search
return self._constrained_beam_search(*args, **kwargs)
# Private method `_constrained_beam_search`, performing constrained beam search
def _constrained_beam_search(
self,
input_ids: torch.LongTensor,
constrained_beam_scorer: ConstrainedBeamSearchScorer,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
output_logits: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = None,
**model_kwargs,
):
# Implementation omitted
pass
# Deprecation wrapper: calling `_assisted_decoding` through the public name is deprecated and will be removed in
# v4.41; use `generate` or a custom generation loop instead
def assisted_decoding(self, *args, **kwargs):
logger.warning_once(
"Calling `_assisted_decoding` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
# Forward all positional and keyword arguments to `_assisted_decoding` and return its result
return self._assisted_decoding(*args, **kwargs)
def _speculative_sampling(
candidate_input_ids,
candidate_logits,
candidate_length,
new_logits,
last_assistant_token_is_eos,
max_matches,
):
"""
Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
the selected tokens, as well as the number of candidate matches.
NOTE: Unless otherwise stated, the variable names match those in the paper.
"""
# Selects the last `candidate_length` tokens from `candidate_input_ids`
new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
# Converts logits to probabilities and extracts assistant (q_i) and model (p_i) probabilities for selected tokens
q = candidate_logits.softmax(dim=-1)
q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
p = new_logits.softmax(dim=-1)
p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
probability_ratio = p_i / q_i
# Determines which tokens to accept based on probability ratios
r_i = torch.rand_like(probability_ratio)
is_accepted = r_i <= probability_ratio
# Computes the number of accepted tokens (`n_matches` in algorithm 1)
n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum()
# Ensures the generated sequence does not exceed `max_matches` or end with an EOS token
if last_assistant_token_is_eos and n_matches == candidate_length:
# Adjusts `n_matches` if the sequence ends with an EOS token
n_matches -= 1
valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
else:
n_matches = min(n_matches, max_matches)
# Selects the next token considering rejection and adjusts probabilities if needed
gamma = min(candidate_logits.shape[1], max_matches)
p_n_plus_1 = p[:, n_matches, :]
if n_matches < gamma:
q_n_plus_1 = q[:, n_matches, :]
p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
p_prime.div_(p_prime.sum())
else:
p_prime = p_n_plus_1
t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
# Constructs the final sequence of valid tokens
if n_matches > 0:
valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
else:
valid_tokens = t
return valid_tokens, n_matches
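# Example (not part of the library): a toy run of the acceptance rule above. A drafted token is kept while a
# uniform draw r_i stays at or below p_i / q_i (target probability over assistant probability); the first rejection
# stops the run, so only the tokens before it count as matches.
import torch
q_i = torch.tensor([0.50, 0.40, 0.10])  # assistant probabilities of 3 drafted tokens
p_i = torch.tensor([0.45, 0.10, 0.30])  # target-model probabilities of the same tokens
r_i = torch.tensor([0.70, 0.60, 0.20])  # pretend uniform draws
is_accepted = r_i <= (p_i / q_i)                       # tensor([True, False, True])
n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum()  # everything after the first rejection is discarded
print(int(n_matches))                                  # 1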
def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
"""
Given the decoder attentions or hidden states of multiple generated tokens, splits them into a tuple in which each
member corresponds to a single generated token.
"""
if len(outputs) == 0:
new_tuple = ()
for layer in new_outputs:
last_dim_size = cur_len if is_decoder_attention else layer.shape[-1]
new_tuple += (layer[..., :cur_len, :last_dim_size],)
outputs += (new_tuple,)
cur_len += 1
added_len -= cur_len
for i in range(added_len):
new_tuple = ()
for layer in new_outputs:
last_dim_size = cur_len + i if is_decoder_attention else layer.shape[-1]
new_tuple += (layer[..., i : i + 1, :last_dim_size],)
outputs += (new_tuple,)
return outputs
def _ranking_fast(context_hidden, next_hidden, next_top_k_probs, alpha, beam_width):
"""
Reranks the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens). Returns the
index of the best candidate for each row in the batch.
"""
norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1)
degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1)
next_top_k_probs = next_top_k_probs.view(-1)
contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
contrastive_score = torch.stack(torch.split(contrastive_score, beam_width))
_, selected_idx = contrastive_score.max(dim=-1)
return selected_idx
def _split(data, full_batch_size: int, split_size: int = None):
if data is None:
return [None] * (full_batch_size // split_size)
if isinstance(data, torch.Tensor):
return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
elif isinstance(data, tuple):
if isinstance(data[0], tuple):
return [
tuple(tuple(tensor[i : i + split_size] for tensor in inner_tuple) for inner_tuple in data)
for i in range(0, full_batch_size, split_size)
]
else:
return [
tuple(sub_tensor[i : i + split_size] for sub_tensor in data)
for i in range(0, full_batch_size, split_size)
]
else:
raise ValueError(f"Unexpected attribute type: {type(data)}")
def _split_model_inputs(
model_input: Union[ModelOutput, Dict], split_size: int, full_batch_size: int
) -> List[Union[ModelOutput, Dict]]:
"""
Split a ModelOutput object (or its subclasses) or Dict into a list of same-class objects based on a specified split
size. The input object is dict when it was prepared for forward pass and ModelOutput when it was returned from
previous forward pass.
"""
if model_input is None:
return [model_input] * (full_batch_size // split_size)
model_output_cls = type(model_input)
if (full_batch_size % split_size) != 0:
raise ValueError("`full_batch_size` must be divisible by `split_size`")
if split_size > full_batch_size:
raise ValueError("`split_size` must be smaller or equal to `full_batch_size`")
keys = (
model_input.__dataclass_fields__.keys() if hasattr(model_input, "__dataclass_fields__") else model_input.keys()
)
keys = [k for k in keys if k in model_input]
bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"]
keys_to_ignore = ["cache_position", "encoder_outputs"]
non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]
data_split_list = [
{k: _split(model_input[k], full_batch_size, split_size)[i] for k in non_bool_keys}
for i in range(full_batch_size // split_size)
]
bool_data = {k: model_input[k] for k in bool_keys}
if "encoder_outputs" in model_input:
encoder_outputs_split = _split_model_inputs(model_input["encoder_outputs"], split_size, full_batch_size)
data_split_list = [
{**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list)
]
split_model_inputs: List[Union[ModelOutput, Dict]] = [
model_output_cls(**data_split, **bool_data) for data_split in data_split_list
]
return split_model_inputs
def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
"""
Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the
specific ModelOutput subclass from the list provided.
"""
if not model_outputs:
raise ValueError("Input list is empty.")
model_output_cls = type(model_outputs[0])
if not all(isinstance(obj, model_output_cls) for obj in model_outputs):
raise ValueError("All elements in the list should be of the same type.")
def _concat(data):
"""
Reverse of `_split` function above.
"""
if any(data is None for data in data):
return None
if isinstance(data[0], torch.Tensor):
return torch.cat(data, dim=0)
elif isinstance(data[0], tuple):
if isinstance(data[0][0], tuple):
return tuple(
tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0])))
for i in range(len(data[0]))
)
else:
return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0])))
elif isinstance(data[0], (int, float)):
return torch.tensor(data)
else:
raise ValueError(f"Unexpected attribute type: {type(data[0])}")
concatenated_data = {
k: _concat([getattr(model_output, k) for model_output in model_outputs])
for k in model_output_cls.__dataclass_fields__.keys()
}
return model_output_cls(**concatenated_data)
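`_split_model_inputs` and `stack_model_outputs` are used as a pair (e.g. for low-memory beam search): the batch is split into micro-batches, each micro-batch is run through the model, and the partial outputs are concatenated back along the batch dimension. A minimal tensor-level sketch of that round trip:

```python
import torch

full_batch_size, split_size = 4, 2
logits = torch.randn(full_batch_size, 10)

# Split along the batch dimension (what `_split` does for tensors) ...
parts = [logits[i : i + split_size] for i in range(0, full_batch_size, split_size)]
# ... process each part independently, then concatenate back (what `_concat` does).
processed = [part.softmax(dim=-1) for part in parts]
restored = torch.cat(processed, dim=0)

assert restored.shape == (full_batch_size, 10)
```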
.\generation\__init__.py
from typing import TYPE_CHECKING
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available
_import_structure = {
"configuration_utils": ["GenerationConfig", "GenerationMode"],
"streamers": ["TextIteratorStreamer", "TextStreamer"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["beam_constraints"] = [
"Constraint",
"ConstraintListState",
"DisjunctiveConstraint",
"PhrasalConstraint",
]
_import_structure["beam_search"] = [
"BeamHypotheses",
"BeamScorer",
"BeamSearchScorer",
"ConstrainedBeamSearchScorer",
]
_import_structure["candidate_generator"] = [
"AssistedCandidateGenerator",
"CandidateGenerator",
"PromptLookupCandidateGenerator",
]
_import_structure["logits_process"] = [
"AlternatingCodebooksLogitsProcessor",
"ClassifierFreeGuidanceLogitsProcessor",
"EncoderNoRepeatNGramLogitsProcessor",
"EncoderRepetitionPenaltyLogitsProcessor",
"EpsilonLogitsWarper",
"EtaLogitsWarper",
"ExponentialDecayLengthPenalty",
"ForcedBOSTokenLogitsProcessor",
"ForcedEOSTokenLogitsProcessor",
"ForceTokensLogitsProcessor",
"HammingDiversityLogitsProcessor",
"InfNanRemoveLogitsProcessor",
"LogitNormalization",
"LogitsProcessor",
"LogitsProcessorList",
"LogitsWarper",
"MinLengthLogitsProcessor",
"MinNewTokensLengthLogitsProcessor",
"NoBadWordsLogitsProcessor",
"NoRepeatNGramLogitsProcessor",
"PrefixConstrainedLogitsProcessor",
"RepetitionPenaltyLogitsProcessor",
"SequenceBiasLogitsProcessor",
"SuppressTokensLogitsProcessor",
"SuppressTokensAtBeginLogitsProcessor",
"TemperatureLogitsWarper",
"TopKLogitsWarper",
"TopPLogitsWarper",
"TypicalLogitsWarper",
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
"WhisperTimeStampLogitsProcessor",
]
_import_structure["stopping_criteria"] = [
"MaxNewTokensCriteria",
"MaxLengthCriteria",
"MaxTimeCriteria",
"StoppingCriteria",
"StoppingCriteriaList",
"validate_stopping_criteria",
]
_import_structure["utils"] = [
"GenerationMixin",
"GreedySearchEncoderDecoderOutput",
"GreedySearchDecoderOnlyOutput",
"SampleEncoderDecoderOutput",
"SampleDecoderOnlyOutput",
"BeamSearchEncoderDecoderOutput",
"BeamSearchDecoderOnlyOutput",
"BeamSampleEncoderDecoderOutput",
"BeamSampleDecoderOnlyOutput",
"ContrastiveSearchEncoderDecoderOutput",
"ContrastiveSearchDecoderOnlyOutput",
"GenerateBeamDecoderOnlyOutput",
"GenerateBeamEncoderDecoderOutput",
"GenerateDecoderOnlyOutput",
"GenerateEncoderDecoderOutput",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tf_logits_process"] = [
"TFForcedBOSTokenLogitsProcessor",
"TFForcedEOSTokenLogitsProcessor",
"TFForceTokensLogitsProcessor",
"TFLogitsProcessor",
"TFLogitsProcessorList",
"TFLogitsWarper",
"TFMinLengthLogitsProcessor",
"TFNoBadWordsLogitsProcessor",
"TFNoRepeatNGramLogitsProcessor",
"TFRepetitionPenaltyLogitsProcessor",
"TFSuppressTokensAtBeginLogitsProcessor",
"TFSuppressTokensLogitsProcessor",
"TFTemperatureLogitsWarper",
"TFTopKLogitsWarper",
"TFTopPLogitsWarper",
]
_import_structure["tf_utils"] = [
"TFGenerationMixin",
"TFGreedySearchDecoderOnlyOutput",
"TFGreedySearchEncoderDecoderOutput",
"TFSampleEncoderDecoderOutput",
"TFSampleDecoderOnlyOutput",
"TFBeamSearchEncoderDecoderOutput",
"TFBeamSearchDecoderOnlyOutput",
"TFBeamSampleEncoderDecoderOutput",
"TFBeamSampleDecoderOnlyOutput",
"TFContrastiveSearchEncoderDecoderOutput",
"TFContrastiveSearchDecoderOnlyOutput",
]
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["flax_logits_process"] = [
"FlaxForcedBOSTokenLogitsProcessor",
"FlaxForcedEOSTokenLogitsProcessor",
"FlaxForceTokensLogitsProcessor",
"FlaxLogitsProcessor",
"FlaxLogitsProcessorList",
"FlaxLogitsWarper",
"FlaxMinLengthLogitsProcessor",
"FlaxSuppressTokensAtBeginLogitsProcessor",
"FlaxSuppressTokensLogitsProcessor",
"FlaxTemperatureLogitsWarper",
"FlaxTopKLogitsWarper",
"FlaxTopPLogitsWarper",
"FlaxWhisperTimeStampLogitsProcessor",
]
_import_structure["flax_utils"] = [
"FlaxGenerationMixin",
"FlaxGreedySearchOutput",
"FlaxSampleOutput",
"FlaxBeamSearchOutput",
]
if TYPE_CHECKING:
from .configuration_utils import GenerationConfig, GenerationMode
from .streamers import TextIteratorStreamer, TextStreamer
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .candidate_generator import AssistedCandidateGenerator, CandidateGenerator, PromptLookupCandidateGenerator
from .logits_process import (
AlternatingCodebooksLogitsProcessor,
ClassifierFreeGuidanceLogitsProcessor,
EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor,
EpsilonLogitsWarper,
EtaLogitsWarper,
ExponentialDecayLengthPenalty,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
ForceTokensLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitNormalization,
LogitsProcessor,
LogitsProcessorList,
LogitsWarper,
MinLengthLogitsProcessor,
MinNewTokensLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
SequenceBiasLogitsProcessor,
SuppressTokensAtBeginLogitsProcessor,
SuppressTokensLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WhisperTimeStampLogitsProcessor,
)
from .stopping_criteria import (
MaxLengthCriteria,
MaxNewTokensCriteria,
MaxTimeCriteria,
StoppingCriteria,
StoppingCriteriaList,
validate_stopping_criteria,
)
from .utils import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
BeamSearchDecoderOnlyOutput,
BeamSearchEncoderDecoderOutput,
ContrastiveSearchDecoderOnlyOutput,
ContrastiveSearchEncoderDecoderOutput,
GenerateBeamDecoderOnlyOutput,
GenerateBeamEncoderDecoderOutput,
GenerateDecoderOnlyOutput,
GenerateEncoderDecoderOutput,
GenerationMixin,
GreedySearchDecoderOnlyOutput,
GreedySearchEncoderDecoderOutput,
SampleDecoderOnlyOutput,
SampleEncoderDecoderOutput,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tf_logits_process import (
TFForcedBOSTokenLogitsProcessor,
TFForcedEOSTokenLogitsProcessor,
TFForceTokensLogitsProcessor,
TFLogitsProcessor,
TFLogitsProcessorList,
TFLogitsWarper,
TFMinLengthLogitsProcessor,
TFNoBadWordsLogitsProcessor,
TFNoRepeatNGramLogitsProcessor,
TFRepetitionPenaltyLogitsProcessor,
TFSuppressTokensAtBeginLogitsProcessor,
TFSuppressTokensLogitsProcessor,
TFTemperatureLogitsWarper,
TFTopKLogitsWarper,
TFTopPLogitsWarper,
)
from .tf_utils import (
TFBeamSampleDecoderOnlyOutput,
TFBeamSampleEncoderDecoderOutput,
TFBeamSearchDecoderOnlyOutput,
TFBeamSearchEncoderDecoderOutput,
TFContrastiveSearchDecoderOnlyOutput,
TFContrastiveSearchEncoderDecoderOutput,
TFGenerationMixin,
TFGreedySearchDecoderOnlyOutput,
TFGreedySearchEncoderDecoderOutput,
TFSampleDecoderOnlyOutput,
TFSampleEncoderDecoderOutput,
)
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .flax_logits_process import (
FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor,
FlaxForceTokensLogitsProcessor,
FlaxLogitsProcessor,
FlaxLogitsProcessorList,
FlaxLogitsWarper,
FlaxMinLengthLogitsProcessor,
FlaxSuppressTokensAtBeginLogitsProcessor,
FlaxSuppressTokensLogitsProcessor,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
FlaxWhisperTimeStampLogitsProcessor,
)
from .flax_utils import (
FlaxBeamSearchOutput,
FlaxGenerationMixin,
FlaxGreedySearchOutput,
FlaxSampleOutput,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
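`_LazyModule` defers the actual submodule imports until a name listed in `_import_structure` is first accessed, so `import transformers.generation` stays cheap. Assuming `transformers` with PyTorch is installed, the exported names are used as usual:

```python
from transformers.generation import GenerationConfig, MaxLengthCriteria, StoppingCriteriaList

config = GenerationConfig(max_new_tokens=20, do_sample=True, top_p=0.9)
stopping = StoppingCriteriaList([MaxLengthCriteria(max_length=32)])
print(config.max_new_tokens, len(stopping))  # 20 1
```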
.\generation_flax_utils.py
import warnings
from .generation import FlaxGenerationMixin
class FlaxGenerationMixin(FlaxGenerationMixin):
warnings.warn(
"Importing `FlaxGenerationMixin` from `src/transformers/generation_flax_utils.py` is deprecated and will "
"be removed in Transformers v4.40. Import as `from transformers import FlaxGenerationMixin` instead.",
FutureWarning,
)
.\generation_tf_utils.py
import warnings
from .generation import TFGenerationMixin
class TFGenerationMixin(TFGenerationMixin):
warnings.warn(
"Importing `TFGenerationMixin` from `src/transformers/generation_tf_utils.py` is deprecated and will "
"be removed in Transformers v4.40. Import as `from transformers import TFGenerationMixin` instead.",
FutureWarning,
)
.\generation_utils.py
import warnings
from .generation import GenerationMixin
class GenerationMixin(GenerationMixin):
warnings.warn(
"Importing `GenerationMixin` from `src/transformers/generation_utils.py` is deprecated and will "
"be removed in Transformers v4.40. Import as `from transformers import GenerationMixin` instead.",
FutureWarning,
)
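All three shim modules follow the same pattern: re-export the class from its new home and emit a `FutureWarning` when the old path is imported. A quick check of the behavior (assuming a Transformers version that still ships these shims):

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Importing from the deprecated module path triggers the FutureWarning above.
    from transformers.generation_utils import GenerationMixin as OldGenerationMixin

print([str(w.message) for w in caught])

# Preferred import path going forward.
from transformers import GenerationMixin
print(issubclass(OldGenerationMixin, GenerationMixin))  # True: the shim subclasses the real mixin
```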
.\hf_argparser.py
import dataclasses
import json
import sys
import types
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError
from copy import copy
from enum import Enum
from inspect import isclass
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Literal, NewType, Optional, Tuple, Union, get_type_hints
import yaml
DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)
def string_to_bool(v):
"""
Parse a string representation of a boolean value.
Args:
v (str): The input string.
Returns:
bool: True if the string represents a truthy value, False otherwise.
Raises:
ArgumentTypeError: If the string cannot be parsed as a boolean.
"""
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise ArgumentTypeError(
f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
)
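`string_to_bool` is designed to be passed as an argparse `type`, so flags can be supplied as `yes/no`, `true/false`, `1/0`, etc. A small hypothetical parser:

```python
from argparse import ArgumentParser

from transformers.hf_argparser import string_to_bool

parser = ArgumentParser()
parser.add_argument("--fp16", type=string_to_bool, default=False)
print(parser.parse_args(["--fp16", "yes"]).fp16)  # True
print(parser.parse_args(["--fp16", "0"]).fp16)    # False
```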
def make_choice_type_function(choices: list) -> Callable[[str], Any]:
"""
Creates a mapping function from the string representation of each choice to its actual value. Used to support
multiple value types for a single argument.
Args:
choices (list): List of choices.
Returns:
Callable[[str], Any]: Mapping function from a string representation to the actual value of each choice.
"""
str_to_choice = {str(choice): choice for choice in choices}
return lambda arg: str_to_choice.get(arg, arg)
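`make_choice_type_function` lets argparse accept the string form of non-string choices (such as `Enum` members or ints) and map them back to the original value; unknown strings fall through unchanged. A hypothetical illustration:

```python
from enum import Enum

from transformers.hf_argparser import make_choice_type_function

class Precision(Enum):
    FP32 = "fp32"
    FP16 = "fp16"

to_choice = make_choice_type_function(list(Precision) + [32, 16])
print(to_choice("Precision.FP16"))  # Precision.FP16 (the Enum member)
print(to_choice("32"))              # 32 (the int, not the string)
print(to_choice("unknown"))         # falls back to the raw string
```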
def HfArg(
*,
aliases: Union[str, List[str]] = None,
help: str = None,
default: Any = dataclasses.MISSING,
default_factory: Callable[[], Any] = dataclasses.MISSING,
metadata: dict = None,
**kwargs,
) -> dataclasses.Field:
"""
Argument helper enabling a concise syntax to create dataclass fields for parsing with `HfArgumentParser`.
Example comparing the use of `HfArg` and `dataclasses.field`:
```
@dataclass
class Args:
regular_arg: str = dataclasses.field(default="Huggingface", metadata={"aliases": ["--example", "-e"], "help": "This syntax could be better!"})
hf_arg: str = HfArg(default="Huggingface", aliases=["--example", "-e"], help="What a nice syntax!")
```
"""
if metadata is None:
# Important: don't use a mutable dict as the parameter default; create a fresh one per call.
metadata = {}
if aliases is not None:
metadata["aliases"] = aliases
if help is not None:
metadata["help"] = help
return dataclasses.field(metadata=metadata, default=default, default_factory=default_factory, **kwargs)
def make_field(aliases=None, help=None, default=dataclasses.MISSING, default_factory=dataclasses.MISSING, metadata=None, **kwargs):
"""
Construct a `dataclasses.Field` object with specified properties.
Args:
aliases (Union[str, List[str]], optional):
Single string or list of strings of aliases to pass on to argparse, e.g. `aliases=["--example", "-e"]`.
Defaults to None.
help (str, optional): Help string to pass on to argparse that can be displayed with --help. Defaults to None.
default (Any, optional):
Default value for the argument. If not default or default_factory is specified, the argument is required.
Defaults to dataclasses.MISSING.
default_factory (Callable[[], Any], optional):
The default_factory is a 0-argument function called to initialize a field's value. It is useful to provide
default values for mutable types, e.g. lists: `default_factory=list`. Mutually exclusive with `default=`.
Defaults to dataclasses.MISSING.
metadata (dict, optional): Further metadata to pass on to `dataclasses.field`. Defaults to None.
Returns:
Field: A `dataclasses.Field` with the desired properties.
"""
if metadata is None:
metadata = {}
if aliases is not None:
metadata["aliases"] = aliases
if help is not None:
metadata["help"] = help
return dataclasses.field(metadata=metadata, default=default, default_factory=default_factory, **kwargs)
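Putting the helper to work: fields declared with `HfArg` carry their aliases and help text in `metadata`, which `HfArgumentParser` (defined below) turns into argparse arguments. A hypothetical two-field config:

```python
from dataclasses import dataclass

from transformers import HfArgumentParser
from transformers.hf_argparser import HfArg

@dataclass
class TrainingConfig:
    model_name: str = HfArg(default="bert-base-uncased", aliases=["-m"], help="Model checkpoint to fine-tune.")
    learning_rate: float = HfArg(default=5e-5, help="Peak learning rate.")

parser = HfArgumentParser(TrainingConfig)
(config,) = parser.parse_args_into_dataclasses(args=["-m", "roberta-base", "--learning_rate", "3e-5"])
print(config.model_name, config.learning_rate)  # roberta-base 3e-05
```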
class HfArgumentParser(ArgumentParser):
"""
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.
The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed)
arguments to the parser after initialization and you'll get the output back after parsing as an additional
namespace. Optional: To create sub argument groups use the `_argument_group_name` attribute in the dataclass.
"""
dataclass_types: Iterable[DataClassType]
def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
"""
Args:
dataclass_types:
Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args.
kwargs (`Dict[str, Any]`, *optional*):
Passed to `argparse.ArgumentParser()` in the regular way.
"""
if "formatter_class" not in kwargs:
kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter
super().__init__(**kwargs)
if dataclasses.is_dataclass(dataclass_types):
dataclass_types = [dataclass_types]
self.dataclass_types = list(dataclass_types)
for dtype in self.dataclass_types:
self._add_dataclass_arguments(dtype)
def _add_dataclass_arguments(self, dtype: DataClassType):
if hasattr(dtype, "_argument_group_name"):
parser = self.add_argument_group(dtype._argument_group_name)
else:
parser = self
try:
type_hints: Dict[str, type] = get_type_hints(dtype)
except NameError:
raise RuntimeError(
f"Type resolution failed for {dtype}. Try declaring the class in global scope or "
"removing line of `from __future__ import annotations` which opts in Postponed "
"Evaluation of Annotations (PEP 563)"
)
except TypeError as ex:
if sys.version_info[:2] < (3, 10) and "unsupported operand type(s) for |" in str(ex):
python_version = ".".join(map(str, sys.version_info[:3]))
raise RuntimeError(
f"Type resolution failed for {dtype} on Python {python_version}. Try removing "
"line of `from __future__ import annotations` which opts in union types as "
"`X | Y` (PEP 604) via Postponed Evaluation of Annotations (PEP 563). To "
"support Python versions that lower than 3.10, you need to use "
"`typing.Union[X, Y]` instead of `X | Y` and `typing.Optional[X]` instead of "
"`X | None`."
) from ex
raise
for field in dataclasses.fields(dtype):
if not field.init:
continue
field.type = type_hints[field.name]
self._parse_dataclass_field(parser, field)
def parse_args_into_dataclasses(
self,
args=None,
return_remaining_strings=False,
look_for_args_file=True,
args_filename=None,
args_file_flag=None,
) -> Tuple[DataClass, ...]:
...
def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
"""
Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
types.
Args:
args (`dict`):
dict containing config values
allow_extra_keys (`bool`, *optional*, defaults to `False`):
Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed.
Returns:
Tuple consisting of:
- the dataclass instances in the same order as they were passed to the initializer.
"""
unused_keys = set(args.keys())
outputs = []
for dtype in self.dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
inputs = {k: v for k, v in args.items() if k in keys}
unused_keys.difference_update(inputs.keys())
obj = dtype(**inputs)
outputs.append(obj)
if not allow_extra_keys and unused_keys:
raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
return tuple(outputs)
def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
"""
Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
dataclass types.
Args:
json_file (`str` or `os.PathLike`):
File name of the json file to parse
allow_extra_keys (`bool`, *optional*, defaults to `False`):
Defaults to False. If False, will raise an exception if the json file contains keys that are not
parsed.
Returns:
Tuple consisting of:
- the dataclass instances in the same order as they were passed to the initializer.
"""
with open(Path(json_file), encoding="utf-8") as open_json_file:
data = json.loads(open_json_file.read())
outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys)
return tuple(outputs)
def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
"""
Alternative helper method that does not use `argparse` at all, instead loading a yaml file and populating the
dataclass types.
Args:
yaml_file (`str` or `os.PathLike`):
File name of the yaml file to parse
allow_extra_keys (`bool`, *optional*, defaults to `False`):
Defaults to False. If False, will raise an exception if the json file contains keys that are not
parsed.
Returns:
Tuple consisting of:
- the dataclass instances in the same order as they were passed to the initializer.
"""
outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
return tuple(outputs)
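The three argparse-free helpers (`parse_dict`, `parse_json_file`, `parse_yaml_file`) share one contract: every key must map onto a dataclass field unless `allow_extra_keys=True`. A small hypothetical round trip through `parse_dict`:

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser

@dataclass
class DataArgs:
    dataset_name: str = field(default="imdb")
    max_samples: int = field(default=1000)

parser = HfArgumentParser(DataArgs)
(data_args,) = parser.parse_dict({"dataset_name": "ag_news", "max_samples": 256})
print(data_args)

# An unknown key raises a ValueError unless extra keys are explicitly allowed.
(data_args,) = parser.parse_dict({"dataset_name": "ag_news", "typo_key": 1}, allow_extra_keys=True)
```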
.\hyperparameter_search.py
from .integrations import (
is_optuna_available,
is_ray_tune_available,
is_sigopt_available,
is_wandb_available,
run_hp_search_optuna,
run_hp_search_ray,
run_hp_search_sigopt,
run_hp_search_wandb,
)
from .trainer_utils import (
HPSearchBackend,
default_hp_space_optuna,
default_hp_space_ray,
default_hp_space_sigopt,
default_hp_space_wandb,
)
from .utils import logging
logger = logging.get_logger(__name__)
class HyperParamSearchBackendBase:
name: str
pip_package: str = None
@staticmethod
def is_available():
raise NotImplementedError
def run(self, trainer, n_trials: int, direction: str, **kwargs):
raise NotImplementedError
def default_hp_space(self, trial):
raise NotImplementedError
def ensure_available(self):
if not self.is_available():
raise RuntimeError(
f"You picked the {self.name} backend, but it is not installed. Run {self.pip_install()}."
)
@classmethod
def pip_install(cls):
return f"`pip install {cls.pip_package or cls.name}`"
class OptunaBackend(HyperParamSearchBackendBase):
name = "optuna"
@staticmethod
def is_available():
return is_optuna_available()
def run(self, trainer, n_trials: int, direction: str, **kwargs):
return run_hp_search_optuna(trainer, n_trials, direction, **kwargs)
def default_hp_space(self, trial):
return default_hp_space_optuna(trial)
class RayTuneBackend(HyperParamSearchBackendBase):
name = "ray"
pip_package = "'ray[tune]'"
@staticmethod
def is_available():
return is_ray_tune_available()
def run(self, trainer, n_trials: int, direction: str, **kwargs):
return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
def default_hp_space(self, trial):
return default_hp_space_ray(trial)
class SigOptBackend(HyperParamSearchBackendBase):
name = "sigopt"
@staticmethod
def is_available():
return is_sigopt_available()
def run(self, trainer, n_trials: int, direction: str, **kwargs):
return run_hp_search_sigopt(trainer, n_trials, direction, **kwargs)
def default_hp_space(self, trial):
return default_hp_space_sigopt(trial)
class WandbBackend(HyperParamSearchBackendBase):
name = "wandb"
@staticmethod
def is_available():
return is_wandb_available()
def run(self, trainer, n_trials: int, direction: str, **kwargs):
return run_hp_search_wandb(trainer, n_trials, direction, **kwargs)
def default_hp_space(self, trial):
return default_hp_space_wandb(trial)
ALL_HYPERPARAMETER_SEARCH_BACKENDS = {
HPSearchBackend(backend.name): backend for backend in [OptunaBackend, RayTuneBackend, SigOptBackend, WandbBackend]
}
def default_hp_search_backend() -> str:
available_backends = [backend for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values() if backend.is_available()]
if len(available_backends) > 0:
name = available_backends[0].name
if len(available_backends) > 1:
logger.info(
f"{len(available_backends)} hyperparameter search backends available. Using {name} as the default."
)
return name
raise RuntimeError(
"No hyperparameter search backend available.\n"
+ "\n".join(
f" - To install {backend.name} run {backend.pip_install()}"
for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values()
)
)
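These backend classes are normally consumed through `Trainer.hyperparameter_search`, which resolves the backend via `ALL_HYPERPARAMETER_SEARCH_BACKENDS` and `default_hp_search_backend`. A hedged sketch of the lookup on its own (assuming at least one backend, e.g. `optuna`, is installed):

```python
from transformers.hyperparameter_search import (
    ALL_HYPERPARAMETER_SEARCH_BACKENDS,
    default_hp_search_backend,
)

# List which backends are usable in the current environment.
for name, backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.items():
    status = "available" if backend.is_available() else f"missing, install with {backend.pip_install()}"
    print(name, status)

# Picks the first available backend, or raises RuntimeError if none is installed.
print(default_hp_search_backend())
```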
.\image_processing_utils.py
import copy
import json
import os
import warnings
from io import BytesIO
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import requests
from .dynamic_module_utils import custom_object_save
from .feature_extraction_utils import BatchFeature as BaseBatchFeature
from .image_transforms import center_crop, normalize, rescale
from .image_utils import ChannelDimension
from .utils import (
IMAGE_PROCESSOR_NAME,
PushToHubMixin,
add_model_info_to_auto_map,
cached_file,
copy_func,
download_url,
is_offline_mode,
is_remote_url,
is_vision_available,
logging,
)
if is_vision_available():
from PIL import Image
logger = logging.get_logger(__name__)
class BatchFeature(BaseBatchFeature):
r"""
Holds the output of the image processor specific `__call__` methods.
This class is derived from a python dictionary and can be used as a dictionary.
Args:
data (`dict`):
Dictionary of lists/arrays/tensors returned by the __call__ method ('pixel_values', etc.).
tensor_type (`Union[None, str, TensorType]`, *optional*):
You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at
initialization.
"""
class ImageProcessingMixin(PushToHubMixin):
"""
This is an image processor mixin used to provide saving/loading functionality for sequential and image feature
extractors.
"""
_auto_class = None
def __init__(self, **kwargs):
"""Set elements of `kwargs` as attributes."""
kwargs.pop("feature_extractor_type", None)
self._processor_class = kwargs.pop("processor_class", None)
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
logger.error(f"Can't set {key} with value {value} for {self}")
raise err
def _set_processor_class(self, processor_class: str):
"""Sets processor class as an attribute."""
self._processor_class = processor_class
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs,
):
...
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save an image processor object to the directory `save_directory`, so that it can be re-loaded using the
[`~image_processing_utils.ImageProcessingMixin.from_pretrained`] class method.
Args:
save_directory (`str` or `os.PathLike`):
Directory where the image processor JSON file will be saved (will be created if it does not exist).
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
kwargs (`Dict[str, Any]`, *optional*):
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
use_auth_token = kwargs.pop("use_auth_token", None)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
FutureWarning,
)
if kwargs.get("token", None) is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
kwargs["token"] = use_auth_token
if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)
if self._auto_class is not None:
custom_object_save(self, save_directory, config=self)
output_image_processor_file = os.path.join(save_directory, IMAGE_PROCESSOR_NAME)
self.to_json_file(output_image_processor_file)
logger.info(f"Image processor saved in {output_image_processor_file}")
if push_to_hub:
self._upload_modified_files(
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("token"),
)
return [output_image_processor_file]
@classmethod
def get_image_processor_dict(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
):
"""
Creates a dictionary of parameters (`image_processor_dict`) needed to instantiate an image processor.
Args:
cls: Class method descriptor.
pretrained_model_name_or_path (Union[str, os.PathLike]):
Name or path of the pretrained model for the image processor.
kwargs (Dict[str, Any]):
Additional keyword arguments to customize the image processor.
Returns:
Dict[str, Any]: Dictionary of parameters (`image_processor_dict`) required to instantiate
the image processor.
"""
image_processor_dict = {
"pretrained_model_name_or_path": pretrained_model_name_or_path,
**kwargs
}
return image_processor_dict
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Instantiates an image processor object from a dictionary of parameters.
Args:
image_processor_dict (Dict[str, Any]):
Dictionary containing parameters to instantiate the image processor.
Typically obtained from a pretrained checkpoint using `to_dict` method.
kwargs (Dict[str, Any]):
Additional parameters to initialize the image processor object.
Returns:
ImageProcessingMixin: The instantiated image processor object.
"""
image_processor_dict = image_processor_dict.copy()
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
if "size" in kwargs and "size" in image_processor_dict:
image_processor_dict["size"] = kwargs.pop("size")
if "crop_size" in kwargs and "crop_size" in image_processor_dict:
image_processor_dict["crop_size"] = kwargs.pop("crop_size")
image_processor = cls(**image_processor_dict)
to_remove = []
for key, value in kwargs.items():
if hasattr(image_processor, key):
setattr(image_processor, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
logger.info(f"Image processor {image_processor}")
if return_unused_kwargs:
return image_processor, kwargs
else:
return image_processor
def to_dict(self) -> Dict[str, Any]:
"""
Serializes the instance attributes of this image processor to a Python dictionary.
Returns:
Dict[str, Any]: Dictionary containing all attributes of the image processor instance.
"""
output = copy.deepcopy(self.__dict__)
output["image_processor_type"] = self.__class__.__name__
return output
@classmethod
def from_json_file(cls, json_file: Union[str, os.PathLike]):
"""
Instantiates an image processor of type `~image_processing_utils.ImageProcessingMixin` from the path to a JSON
file of parameters.
Args:
json_file (`str` or `os.PathLike`):
Path to the JSON file containing the parameters.
Returns:
An image processor of type `~image_processing_utils.ImageProcessingMixin`: the image processor object
instantiated from that JSON file.
"""
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
image_processor_dict = json.loads(text)
return cls(**image_processor_dict)
def to_json_string(self) -> str:
"""
Serializes this instance to a JSON string.
Returns:
`str`: String containing all the attributes that make up this image processor instance in JSON format.
"""
dictionary = self.to_dict()
for key, value in dictionary.items():
if isinstance(value, np.ndarray):
dictionary[key] = value.tolist()
_processor_class = dictionary.pop("_processor_class", None)
if _processor_class is not None:
dictionary["processor_class"] = _processor_class
return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
"""
Saves this instance to a JSON file.
Args:
json_file_path (`str` or `os.PathLike`):
Path to the JSON file in which this image processor instance's parameters will be saved.
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string())
def __repr__(self):
"""
Returns the string representation of this instance.
Returns:
`str`: The class name followed by the JSON-formatted string of this instance.
"""
return f"{self.__class__.__name__} {self.to_json_string()}"
@classmethod
def register_for_auto_class(cls, auto_class="AutoImageProcessor"):
"""
Register this class with a given auto class. This should only be used for custom image processors, as the ones
in the library are already mapped with `AutoImageProcessor`.
<Tip warning={true}>
This API is experimental and may have some slight breaking changes in the next releases.
</Tip>
Args:
auto_class (`str` or `type`, *optional*, defaults to `"AutoImageProcessor"`):
The auto class to register this new image processor with.
"""
if not isinstance(auto_class, str):
auto_class = auto_class.__name__
import transformers.models.auto as auto_module
if not hasattr(auto_module, auto_class):
raise ValueError(f"{auto_class} 不是有效的自动类。")
cls._auto_class = auto_class
def fetch_images(self, image_url_or_urls: Union[str, List[str]]):
"""
Convert a single or a list of URLs into corresponding `PIL.Image` objects.
If a single URL is passed, the return value will be a single object. If a list is passed, a list of objects is
returned.
"""
headers = {
"User-Agent": (
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0"
" Safari/537.36"
)
}
if isinstance(image_url_or_urls, list):
return [self.fetch_images(x) for x in image_url_or_urls]
elif isinstance(image_url_or_urls, str):
response = requests.get(image_url_or_urls, stream=True, headers=headers)
response.raise_for_status()
return Image.open(BytesIO(response.content))
else:
raise ValueError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
class BaseImageProcessor(ImageProcessingMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def __call__(self, images, **kwargs) -> BatchFeature:
"""Preprocess an image or a batch of images."""
return self.preprocess(images, **kwargs)
def preprocess(self, images, **kwargs) -> BatchFeature:
raise NotImplementedError("Each image processor must implement its own preprocess method")
def rescale(
self,
image: np.ndarray,
scale: float,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`float`):
The scaling factor to rescale pixel values by.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The rescaled image.
"""
return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs)
def normalize(
self,
image: np.ndarray,
mean: Union[float, Iterable[float]],
std: Union[float, Iterable[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Normalize an image by subtracting mean and dividing by standard deviation.
Args:
image (`np.ndarray`):
Image to normalize.
mean (`float` or `Iterable[float]`):
Mean value(s) for normalization.
std (`float` or `Iterable[float]`):
Standard deviation value(s) for normalization.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The normalized image.
"""
return normalize(image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs)
def center_crop(
self,
image: np.ndarray,
size: Dict[str, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
any edge, the image is padded with 0's and then center cropped.
Args:
image (`np.ndarray`):
Image to center crop.
size (`Dict[str, int]`):
Size of the output image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
return center_crop(
image,
size=(size["height"], size["width"]),
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
VALID_SIZE_DICT_KEYS = ({"height", "width"}, {"shortest_edge"}, {"shortest_edge", "longest_edge"}, {"longest_edge"})
def is_valid_size_dict(size_dict):
if not isinstance(size_dict, dict):
return False
size_dict_keys = set(size_dict.keys())
for allowed_keys in VALID_SIZE_DICT_KEYS:
if size_dict_keys == allowed_keys:
return True
return False
def convert_to_size_dict(
size, max_size: Optional[int] = None, default_to_square: bool = True, height_width_order: bool = True
):
if isinstance(size, int) and default_to_square:
if max_size is not None:
raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
return {"height": size, "width": size}
elif isinstance(size, int) and not default_to_square:
size_dict = {"shortest_edge": size}
if max_size is not None:
size_dict["longest_edge"] = max_size
return size_dict
elif isinstance(size, (tuple, list)) and height_width_order:
return {"height": size[0], "width": size[1]}
elif isinstance(size, (tuple, list)) and not height_width_order:
return {"height": size[1], "width": size[0]}
elif size is None and max_size is not None:
if default_to_square:
raise ValueError("Cannot specify both default_to_square=True and max_size")
return {"longest_edge": max_size}
raise ValueError(f"Could not convert size input to size dict: {size}")
def get_size_dict(
size: Union[int, Iterable[int], Dict[str, int]] = None,
max_size: Optional[int] = None,
height_width_order: bool = True,
default_to_square: bool = True,
param_name="size",
) -> dict:
"""
Converts the old size parameter in the config into the new dict expected in the config. This is to ensure backwards
compatibility with the old image processor configs and removes ambiguity over whether the tuple is in (height,
width) or (width, height) format.
- If `size` is tuple, it is converted to `{"height": size[0], "width": size[1]}` or `{"height": size[1], "width":
size[0]}` if `height_width_order` is `False`.
- If `size` is an int, and `default_to_square` is `True`, it is converted to `{"height": size, "width": size}`.
- If `size` is an int and `default_to_square` is False, it is converted to `{"shortest_edge": size}`. If `max_size`
is set, it is added to the dict as `{"longest_edge": max_size}`.
"""
return convert_to_size_dict(size, max_size, default_to_square, height_width_order)
"""
Casts the `size` parameter into a standardized size dictionary.
Args:
size (`Union[int, Iterable[int], Dict[str, int]]`, *optional*):
The `size` parameter to be cast into a size dictionary.
max_size (`Optional[int]`, *optional*):
The `max_size` parameter to be cast into a size dictionary.
height_width_order (`bool`, *optional*, defaults to `True`):
If `size` is a tuple, specifies whether it's in (height, width) or (width, height) order.
default_to_square (`bool`, *optional*, defaults to `True`):
If `size` is an int, specifies whether to default to a square image or not.
"""
if not isinstance(size, dict):
size_dict = convert_to_size_dict(size, max_size, default_to_square, height_width_order)
logger.info(
f"{param_name} should be a dictionary on of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size}."
f" Converted to {size_dict}.",
)
else:
size_dict = size
if not is_valid_size_dict(size_dict):
raise ValueError(
f"{param_name} must have one of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size_dict.keys()}"
)
return size_dict
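The accepted `size` spellings normalize as follows; a quick check assuming `transformers` is installed:

```python
from transformers.image_processing_utils import get_size_dict

print(get_size_dict(224))                                         # {'height': 224, 'width': 224}
print(get_size_dict((256, 192)))                                  # {'height': 256, 'width': 192}
print(get_size_dict(256, default_to_square=False, max_size=512))  # {'shortest_edge': 256, 'longest_edge': 512}
print(get_size_dict({"shortest_edge": 300}))                      # returned unchanged after validation
```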
def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
"""
Selects the best resolution from a list of possible resolutions based on the original size.
This is done by calculating the effective and wasted resolution for each possible resolution.
The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution.
Args:
original_size (tuple):
The original size of the image in the format (height, width).
possible_resolutions (list):
A list of possible resolutions in the format [(height1, width1), (height2, width2), ...].
Returns:
tuple: The best fit resolution in the format (height, width).
"""
original_height, original_width = original_size
best_fit = None
max_effective_resolution = 0
min_wasted_resolution = float("inf")
for height, width in possible_resolutions:
scale = min(width / original_width, height / original_height)
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
wasted_resolution = (width * height) - effective_resolution
if effective_resolution > max_effective_resolution or (
effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
):
max_effective_resolution = effective_resolution
min_wasted_resolution = wasted_resolution
best_fit = (height, width)
return best_fit
ImageProcessingMixin.push_to_hub = copy_func(ImageProcessingMixin.push_to_hub)
if ImageProcessingMixin.push_to_hub.__doc__ is not None:
ImageProcessingMixin.push_to_hub.__doc__ = ImageProcessingMixin.push_to_hub.__doc__.format(
object="image processor", object_class="AutoImageProcessor", object_files="image processor file"
)
.\image_transforms.py
import warnings
from typing import Iterable, List, Optional, Tuple, Union
import numpy as np
from .image_utils import (
ChannelDimension,
ImageInput,
get_channel_dimension_axis,
get_image_size,
infer_channel_dimension_format,
)
from .utils import ExplicitEnum, TensorType, is_jax_tensor, is_tf_tensor, is_torch_tensor
from .utils.import_utils import (
is_flax_available,
is_tf_available,
is_torch_available,
is_vision_available,
requires_backends,
)
if is_vision_available():
import PIL
from .image_utils import PILImageResampling
if is_torch_available():
import torch
if is_tf_available():
import tensorflow as tf
if is_flax_available():
import jax.numpy as jnp
def to_channel_dimension_format(
image: np.ndarray,
channel_dim: Union[ChannelDimension, str],
input_channel_dim: Optional[Union[ChannelDimension, str]] = None,
) -> np.ndarray:
"""
Converts `image` to the channel dimension format specified by `channel_dim`.
Args:
image (`numpy.ndarray`):
The image to have its channel dimension set.
channel_dim (`ChannelDimension`):
The channel dimension format to use.
input_channel_dim (`ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred from the input image.
Returns:
`np.ndarray`: The image with the channel dimension set to `channel_dim`.
"""
if not isinstance(image, np.ndarray):
raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
if input_channel_dim is None:
input_channel_dim = infer_channel_dimension_format(image)
target_channel_dim = ChannelDimension(channel_dim)
if input_channel_dim == target_channel_dim:
return image
if target_channel_dim == ChannelDimension.FIRST:
image = image.transpose((2, 0, 1))
elif target_channel_dim == ChannelDimension.LAST:
image = image.transpose((1, 2, 0))
else:
raise ValueError("Unsupported channel dimension format: {}".format(channel_dim))
return image
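A quick example of moving a `(height, width, channels)` array to channels-first layout with this function (assuming the public `transformers.image_transforms` module):

```python
import numpy as np

from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import ChannelDimension

image_hwc = np.zeros((224, 224, 3), dtype=np.float32)
image_chw = to_channel_dimension_format(image_hwc, ChannelDimension.FIRST)
print(image_hwc.shape, "->", image_chw.shape)  # (224, 224, 3) -> (3, 224, 224)
```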
def rescale(
image: np.ndarray,
scale: float,
data_format: Optional[ChannelDimension] = None,
dtype: np.dtype = np.float32,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Rescales the input `image` by a factor of `scale`.
Args:
image (`numpy.ndarray`):
The image to be rescaled.
scale (`float`):
The scaling factor to be applied to the image.
data_format (`ChannelDimension`, *optional*):
The desired channel dimension format of the output image.
dtype (`np.dtype`, *optional*):
The desired data type of the output image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
Returns:
`np.ndarray`: The rescaled image.
"""
if not isinstance(image, np.ndarray):
raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
rescaled_image = image * scale
if data_format is not None:
rescaled_image = to_channel_dimension_format(rescaled_image, data_format, input_data_format)
rescaled_image = rescaled_image.astype(dtype)
return rescaled_image
def _rescale_for_pil_conversion(image):
if image.dtype == np.uint8:
do_rescale = False
elif np.allclose(image, image.astype(int)):
if np.all(0 <= image) and np.all(image <= 255):
do_rescale = False
else:
raise ValueError(
"The image to be converted to a PIL image contains values outside the range [0, 255], "
f"got [{image.min()}, {image.max()}] which cannot be converted to uint8."
)
elif np.all(0 <= image) and np.all(image <= 1):
do_rescale = True
else:
raise ValueError(
"The image to be converted to a PIL image contains values outside the range [0, 1], "
f"got [{image.min()}, {image.max()}] which cannot be converted to uint8."
)
return do_rescale
def to_pil_image(
image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
do_rescale: Optional[bool] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> "PIL.Image.Image":
"""
Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
needed.
Args:
image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor` or `tf.Tensor`):
The image to convert to the `PIL.Image` format.
do_rescale (`bool`, *optional*):
Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default
to `True` if the image type is a floating type and casting to `int` would result in a loss of precision,
and `False` otherwise.
input_data_format (`ChannelDimension`, *optional*):
The channel dimension format of the input image. If unset, will use the inferred format from the input.
Returns:
`PIL.Image.Image`: The converted image.
"""
requires_backends(to_pil_image, ["vision"])
if isinstance(image, PIL.Image.Image):
return image
if is_torch_tensor(image) or is_tf_tensor(image):
image = image.numpy()
elif is_jax_tensor(image):
image = np.array(image)
elif not isinstance(image, np.ndarray):
raise ValueError("Input image type not supported: {}".format(type(image)))
image = to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format)
image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image
do_rescale = _rescale_for_pil_conversion(image) if do_rescale is None else do_rescale
if do_rescale:
image = rescale(image, 255)
image = image.astype(np.uint8)
return PIL.Image.fromarray(image)
def get_resize_output_image_size(
input_image: np.ndarray,
size: Union[int, Tuple[int, int], List[int], Tuple[int]],
default_to_square: bool = True,
max_size: Optional[int] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> tuple:
"""
Find the target (height, width) dimension of the output image after resizing given the input image and the desired
size.
Args:
input_image (`np.ndarray`):
The image to resize.
size (`int` or `Tuple[int, int]` or List[int] or Tuple[int]):
The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be matched to
this.
If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
`size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to this
number. i.e, if height > width, then image will be rescaled to (size * height / width, size).
default_to_square (`bool`, *optional*, defaults to `True`):
How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a square
(`size`,`size`). If set to `False`, will replicate
[`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
with support for resizing only the smallest edge and providing an optional `max_size`.
max_size (`int`, *optional*):
The maximum allowed for the longer edge of the resized image: if the longer edge of the image is greater
than `max_size` after being resized according to `size`, then the image is resized again so that the longer
edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller edge may be shorter
than `size`. Only used if `default_to_square` is `False`.
input_data_format (`ChannelDimension`, *optional*):
The channel dimension format of the input image. If unset, will use the inferred format from the input.
Returns:
`tuple`: The target (height, width) dimension of the output image after resizing.
"""
if isinstance(size, (tuple, list)):
if len(size) == 2:
return tuple(size)
elif len(size) == 1:
size = size[0]
else:
raise ValueError("size must have 1 or 2 elements if it is a list or tuple")
if default_to_square:
return (size, size)
height, width = get_image_size(input_image, input_data_format)
short, long = (width, height) if width <= height else (height, width)
requested_new_short = size
new_short, new_long = requested_new_short, int(requested_new_short * long / short)
if max_size is not None:
if max_size <= requested_new_short:
raise ValueError(
f"max_size = {max_size} must be strictly greater than the requested "
f"size for the smaller edge size = {size}"
)
if new_long > max_size:
new_short, new_long = int(max_size * new_short / new_long), max_size
return (new_long, new_short) if width <= height else (new_short, new_long)
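For example, resizing a 480x640 (height x width) image with `size=256` and `default_to_square=False` matches the shorter edge and preserves the aspect ratio, while `max_size` caps the longer edge. A quick check against the public function (assuming `transformers` is installed):

```python
import numpy as np

from transformers.image_transforms import get_resize_output_image_size

image = np.zeros((480, 640, 3), dtype=np.uint8)  # height=480, width=640
print(get_resize_output_image_size(image, size=256, default_to_square=False))
# (256, 341): the shorter edge (height) becomes 256, the width scales by the same factor
print(get_resize_output_image_size(image, size=256, default_to_square=False, max_size=300))
# (225, 300): the scaled longer edge exceeded max_size, so both edges are shrunk again
```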
"""
使用 PIL 库将 `image` 调整大小为 `size` 指定的尺寸。
Args:
image (`np.ndarray`):
要调整大小的图像。
size (`Tuple[int, int]`):
用于调整图像大小的尺寸。
resample (`int`, *optional*, 默认为 `PILImageResampling.BILINEAR`):
用于重采样的滤波器。
reducing_gap (`int`, *optional*):
通过两步骤优化图像调整大小。`reducing_gap` 越大,结果越接近公平重采样。详细信息请参考 Pillow 文档。
data_format (`ChannelDimension`, *optional*):
输出图像的通道维度格式。如果未设置,将从输入中推断格式。
return_numpy (`bool`, *optional*, 默认为 `True`):
是否将调整大小后的图像作为 numpy 数组返回。如果为 False,则返回 `PIL.Image.Image` 对象。
input_data_format (`ChannelDimension`, *optional*):
输入图像的通道维度格式。如果未设置,将从输入中推断格式。
Returns:
`np.ndarray`: 调整大小后的图像。
"""
requires_backends(resize, ["vision"])
resample = resample if resample is not None else PILImageResampling.BILINEAR
if not len(size) == 2:
raise ValueError("size must have 2 elements")
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
data_format = input_data_format if data_format is None else data_format
do_rescale = False
if not isinstance(image, PIL.Image.Image):
do_rescale = _rescale_for_pil_conversion(image)
image = to_pil_image(image, do_rescale=do_rescale, input_data_format=input_data_format)
height, width = size
resized_image = image.resize((width, height), resample=resample, reducing_gap=reducing_gap)
resized_image = np.array(resized_image)
resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image
resized_image = to_channel_dimension_format(
resized_image, data_format, input_channel_dim=ChannelDimension.LAST
)
resized_image = rescale(resized_image, 1 / 255) if do_rescale else resized_image
return resized_image
def center_crop(
image: np.ndarray,
size: Tuple[int, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
return_numpy: Optional[bool] = None,
) -> np.ndarray:
"""
Crops the `image` to the specified `size` using a center crop. Note that if the image is too small to be cropped to
the size given, it will be padded (so the returned result will always be of size `size`).
Args:
image (`np.ndarray`):
The image to crop.
size (`Tuple[int, int]`):
The target size for the cropped image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use the inferred format of the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use the inferred format of the input image.
return_numpy (`bool`, *optional*):
Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the
previous ImageFeatureExtractionMixin method.
- Unset: will return the same type as the input image.
- `True`: will return a numpy array.
- `False`: will return a `PIL.Image.Image` object.
Returns:
`np.ndarray`: The cropped image.
"""
requires_backends(center_crop, ["vision"])
if return_numpy is not None:
warnings.warn("return_numpy is deprecated and will be removed in v.4.33", FutureWarning)
return_numpy = True if return_numpy is None else return_numpy
if not isinstance(image, np.ndarray):
raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
if not isinstance(size, Iterable) or len(size) != 2:
raise ValueError("size must have 2 elements representing the height and width of the output image")
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
output_data_format = data_format if data_format is not None else input_data_format
image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)
orig_height, orig_width = get_image_size(image, ChannelDimension.FIRST)
crop_height, crop_width = size
crop_height, crop_width = int(crop_height), int(crop_width)
top = (orig_height - crop_height) // 2
bottom = top + crop_height
left = (orig_width - crop_width) // 2
right = left + crop_width
# If the crop window lies fully inside the image, slice it out directly.
if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width:
image = image[..., top:bottom, left:right]
image = to_channel_dimension_format(image, output_data_format, ChannelDimension.FIRST)
return image
# Otherwise the image is smaller than the crop on at least one axis: pad it with zeros before cropping.
new_height = max(crop_height, orig_height)
new_width = max(crop_width, orig_width)
new_shape = image.shape[:-2] + (new_height, new_width)
new_image = np.zeros_like(image, shape=new_shape)
top_pad = (new_height - orig_height) // 2
bottom_pad = top_pad + orig_height
left_pad = (new_width - orig_width) // 2
right_pad = left_pad + orig_width
new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image
# Shift the crop window into the coordinate frame of the padded image.
top += top_pad
bottom += top_pad
left += left_pad
right += left_pad
new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)]
new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST)
if not return_numpy:
new_image = to_pil_image(new_image)
return new_image
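A minimal usage sketch for `center_crop` (illustrative shapes and values, not from the source; assumes Pillow is installed since the function declares the `vision` backend requirement). Because the crop is centered and the input is zero-padded when it is too small, the output always has exactly the requested size:
import numpy as np
from transformers.image_transforms import center_crop
image = np.random.randint(0, 256, (300, 500, 3), dtype=np.uint8)  # (height, width, channels)
cropped = center_crop(image, (224, 224))
print(cropped.shape)  # (224, 224, 3) -- channel order of the input is preserved
small = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
padded_crop = center_crop(small, (224, 224))
print(padded_crop.shape)  # (224, 224, 3) -- zero-padded because the input was smaller than the crop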
def _center_to_corners_format_torch(bboxes_center: "torch.Tensor") -> "torch.Tensor":
center_x, center_y, width, height = bboxes_center.unbind(-1)
bbox_corners = torch.stack(
[(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)],
dim=-1,
)
return bbox_corners
def _center_to_corners_format_numpy(bboxes_center: np.ndarray) -> np.ndarray:
center_x, center_y, width, height = bboxes_center.T
bboxes_corners = np.stack(
[center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height],
axis=-1,
)
return bboxes_corners
def _center_to_corners_format_tf(bboxes_center: "tf.Tensor") -> "tf.Tensor":
center_x, center_y, width, height = tf.unstack(bboxes_center, axis=-1)
bboxes_corners = tf.stack(
[center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height],
axis=-1,
)
return bboxes_corners
def center_to_corners_format(bboxes_center: TensorType) -> TensorType:
"""
Converts bounding boxes from center format to corners format.
center format: contains the coordinate for the center of the box and its width, height dimensions
(center_x, center_y, width, height)
corners format: contains the coordinates for the top-left and bottom-right corners of the box
(top_left_x, top_left_y, bottom_right_x, bottom_right_y)
"""
if is_torch_tensor(bboxes_center):
return _center_to_corners_format_torch(bboxes_center)
elif isinstance(bboxes_center, np.ndarray):
return _center_to_corners_format_numpy(bboxes_center)
elif is_tf_tensor(bboxes_center):
return _center_to_corners_format_tf(bboxes_center)
raise ValueError(f"Unsupported input type {type(bboxes_center)}")
def _corners_to_center_format_torch(bboxes_corners: "torch.Tensor") -> "torch.Tensor":
top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.unbind(-1)
b = [
(top_left_x + bottom_right_x) / 2,
(top_left_y + bottom_right_y) / 2,
(bottom_right_x - top_left_x),
(bottom_right_y - top_left_y),
]
return torch.stack(b, dim=-1)
def _corners_to_center_format_numpy(bboxes_corners: np.ndarray) -> np.ndarray:
top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.T
bboxes_center = np.stack(
[
(top_left_x + bottom_right_x) / 2,
(top_left_y + bottom_right_y) / 2,
(bottom_right_x - top_left_x),
(bottom_right_y - top_left_y),
],
axis=-1,
)
return bboxes_center
def _corners_to_center_format_tf(bboxes_corners: "tf.Tensor") -> "tf.Tensor":
"""
Converts bounding boxes from corners format to center format using TensorFlow operations.
Args:
bboxes_corners (tf.Tensor): Tensor containing bounding box coordinates in corners format
(top_left_x, top_left_y, bottom_right_x, bottom_right_y)
Returns:
tf.Tensor: Tensor containing bounding box coordinates in center format
(center_x, center_y, width, height)
"""
top_left_x, top_left_y, bottom_right_x, bottom_right_y = tf.unstack(bboxes_corners, axis=-1)
bboxes_center = tf.stack(
[
(top_left_x + bottom_right_x) / 2,
(top_left_y + bottom_right_y) / 2,
(bottom_right_x - top_left_x),
(bottom_right_y - top_left_y),
],
axis=-1,
)
return bboxes_center
def corners_to_center_format(bboxes_corners: TensorType) -> TensorType:
"""
Converts bounding boxes from corners format to center format.
Args:
bboxes_corners (TensorType): Tensor or array containing bounding box coordinates in corners format
(top_left_x, top_left_y, bottom_right_x, bottom_right_y)
Returns:
TensorType: Tensor or array containing bounding box coordinates in center format
(center_x, center_y, width, height)
Raises:
ValueError: If the input type is unsupported
"""
if is_torch_tensor(bboxes_corners):
return _corners_to_center_format_torch(bboxes_corners)
elif isinstance(bboxes_corners, np.ndarray):
return _corners_to_center_format_numpy(bboxes_corners)
elif is_tf_tensor(bboxes_corners):
return _corners_to_center_format_tf(bboxes_corners)
raise ValueError(f"Unsupported input type {type(bboxes_corners)}")
def rgb_to_id(color):
"""
Converts RGB color to unique ID.
Args:
color (np.ndarray or list): RGB color values
Returns:
int: Unique ID corresponding to the RGB color
"""
if isinstance(color, np.ndarray) and len(color.shape) == 3:
if color.dtype == np.uint8:
color = color.astype(np.int32)
return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
def id_to_rgb(id_map):
"""
Converts unique ID to RGB color.
Args:
id_map (np.ndarray or int): Unique ID or array of IDs
Returns:
np.ndarray or list: RGB color corresponding to the unique ID or array of RGB colors
"""
if isinstance(id_map, np.ndarray):
id_map_copy = id_map.copy()
rgb_shape = tuple(list(id_map.shape) + [3])
rgb_map = np.zeros(rgb_shape, dtype=np.uint8)
for i in range(3):
rgb_map[..., i] = id_map_copy % 256
id_map_copy //= 256
return rgb_map
color = []
for _ in range(3):
color.append(id_map % 256)
id_map //= 256
return color
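These two helpers implement the COCO panoptic-style convention in which an RGB triple encodes a single integer id in base 256, with the R channel as the least significant byte. A small roundtrip sketch (illustrative values):
from transformers.image_transforms import rgb_to_id, id_to_rgb
segment_id = rgb_to_id([10, 20, 30])   # 10 + 20 * 256 + 30 * 256**2
print(segment_id)                      # 1971210
print(id_to_rgb(1971210))              # [10, 20, 30]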
class PaddingMode(ExplicitEnum):
"""
Enum class for the different padding modes to use when padding images.
"""
CONSTANT = "constant"
REFLECT = "reflect"
REPLICATE = "replicate"
SYMMETRIC = "symmetric"
def pad(
image: np.ndarray,
padding: Union[int, Tuple[int, int], Iterable[Tuple[int, int]]],
mode: PaddingMode = PaddingMode.CONSTANT,
constant_values: Union[float, Iterable[float]] = 0.0,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Pads the `image` with the specified (height, width) `padding` and `mode`.
Args:
image (`np.ndarray`):
The image to pad.
padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`):
Padding to apply to the edges of the height, width axes. Can be one of three formats:
- `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
- `((before, after),)` yields same before and after pad for height and width.
- `(pad,)` or int is a shortcut for before = after = pad width for all axes.
mode (`PaddingMode`):
The padding mode to use. Can be one of:
- `"constant"`: pads with a constant value.
- `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
vector along each axis.
- `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
- `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
constant_values (`float` or `Iterable[float]`, *optional*):
The value to use for the padding if `mode` is `"constant"`.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use same as the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use the inferred format of the input image.
Returns:
`np.ndarray`: The padded image.
"""
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
def _expand_for_data_format(values):
"""
Convert values to be in the format expected by np.pad based on the data format.
"""
if isinstance(values, (int, float)):
values = ((values, values), (values, values))
elif isinstance(values, tuple) and len(values) == 1:
values = ((values[0], values[0]), (values[0], values[0]))
elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], int):
values = (values, values)
elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple):
values = values
else:
raise ValueError(f"Unsupported format: {values}")
# Insert a no-op pad for the channel axis, in the position that matches the input data format.
values = ((0, 0), *values) if input_data_format == ChannelDimension.FIRST else (*values, (0, 0))
# Prepend an extra leading entry when the image also has a batch dimension.
values = (0, *values) if image.ndim == 4 else values
return values
padding = _expand_for_data_format(padding)
if mode == PaddingMode.CONSTANT:
constant_values = _expand_for_data_format(constant_values)
image = np.pad(image, padding, mode="constant", constant_values=constant_values)
elif mode == PaddingMode.REFLECT:
image = np.pad(image, padding, mode="reflect")
elif mode == PaddingMode.REPLICATE:
image = np.pad(image, padding, mode="edge")
elif mode == PaddingMode.SYMMETRIC:
image = np.pad(image, padding, mode="symmetric")
else:
raise ValueError(f"Invalid padding mode: {mode}")
image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
return image
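A short sketch of the three padding formats accepted by `pad` (illustrative shapes, not from the source; a channels-first array is assumed so the channel axis is left unpadded):
import numpy as np
from transformers.image_transforms import pad, PaddingMode
image = np.ones((3, 4, 4))                                                    # (channels, height, width)
print(pad(image, padding=2).shape)                                            # (3, 8, 8): 2 on every side
print(pad(image, padding=(1, 3)).shape)                                       # (3, 8, 8): 1 before / 3 after on both axes
print(pad(image, padding=((1, 1), (0, 2)), mode=PaddingMode.REFLECT).shape)   # (3, 6, 6)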
def convert_to_rgb(image: ImageInput) -> ImageInput:
"""
Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
as is.
Args:
image (Image):
The image to convert.
"""
requires_backends(convert_to_rgb, ["vision"])
if not isinstance(image, PIL.Image.Image):
return image
image = image.convert("RGB")
return image
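`convert_to_rgb` is essentially a guard for preprocessing pipelines that may receive RGBA or palette PIL images; non-PIL inputs pass through untouched. A tiny illustrative sketch:
from PIL import Image
from transformers.image_transforms import convert_to_rgb
rgba = Image.new("RGBA", (4, 4), (255, 0, 0, 128))
print(convert_to_rgb(rgba).mode)   # "RGB" -- the alpha channel is dropped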
def flip_channel_order(
image: np.ndarray,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Flips the channel order of the image.
If the image is in RGB format, it will be converted to BGR and vice versa.
Args:
image (`np.ndarray`):
The image to flip.
data_format (`ChannelDimension`, *optional*):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use same as the input image.
input_data_format (`ChannelDimension`, *optional*):
The channel dimension format for the input image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use the inferred format of the input image.
"""
input_data_format = infer_channel_dimension_format(image) if input_data_format is None else input_data_format
if input_data_format == ChannelDimension.LAST:
image = image[..., ::-1]
elif input_data_format == ChannelDimension.FIRST:
image = image[::-1, ...]
else:
raise ValueError(f"Unsupported channel dimension: {input_data_format}")
if data_format is not None:
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
return image
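A short sketch of `flip_channel_order` on a channels-last array (illustrative values); for channels-first inputs the same flip is applied on axis 0 instead:
import numpy as np
from transformers.image_transforms import flip_channel_order
rgb = np.zeros((2, 2, 3), dtype=np.uint8)
rgb[..., 0] = 255                      # pure red in RGB, channels-last
bgr = flip_channel_order(rgb)
print(bgr[0, 0])                       # [  0   0 255] -- the red value now sits in the last slot, i.e. the R position of BGR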