Transformers Source Code Analysis (88)
.\models\perceiver\tokenization_perceiver.py
""" Perceiver 的分词器类。"""
from typing import Dict, List, Optional, Tuple
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging
logger = logging.get_logger(__name__)
class PerceiverTokenizer(PreTrainedTokenizer):
"""
Constructs a Perceiver tokenizer. The Perceiver simply uses raw bytes utf-8 encoding.
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
this superclass for more information regarding those methods.
Args:
pad_token (`str`, *optional*, defaults to `"[PAD]"`):
The token used for padding, for example when batching sequences of different lengths.
bos_token (`str`, *optional*, defaults to `"[BOS]"`):
The BOS token (reserved in the vocab, but not actually used).
eos_token (`str`, *optional*, defaults to `"[EOS]"`):
The end of sequence token (reserved in the vocab, but not actually used).
<Tip>
When building a sequence using special tokens, this is not the token that is used for the end of sequence.
The token that is actually used to end a sequence is the `sep_token`.
</Tip>
mask_token (`str`, *optional*, defaults to `"[MASK]"`):
The MASK token, useful for masked language modeling.
cls_token (`str`, *optional*, defaults to `"[CLS]"`):
The CLS token (reserved in the vocab, but not actually used).
sep_token (`str`, *optional*, defaults to `"[SEP]"`):
The separator token, which is used when building a sequence from two sequences.
"""
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
pad_token="[PAD]",
bos_token="[BOS]",
eos_token="[EOS]",
mask_token="[MASK]",
cls_token="[CLS]",
sep_token="[SEP]",
model_max_length=2048,
**kwargs,
) -> None:
pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
mask_token = AddedToken(mask_token, lstrip=False, rstrip=False) if isinstance(mask_token, str) else mask_token
cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
self._utf_vocab_size = 2**8
self._added_tokens_decoder: Dict[int, AddedToken] = {
0: pad_token,
1: bos_token,
2: eos_token,
3: mask_token,
4: cls_token,
5: sep_token,
}
self._num_special_tokens = len(self._added_tokens_decoder)
super().__init__(
pad_token=pad_token,
bos_token=bos_token,
eos_token=eos_token,
mask_token=mask_token,
cls_token=cls_token,
sep_token=sep_token,
model_max_length=model_max_length,
**kwargs,
)
def get_vocab(self) -> Dict[str, int]:
vocab = {}
for i in range(self._utf_vocab_size):
token = chr(i)
vocab[token] = i + self._num_special_tokens
vocab.update(self.added_tokens_encoder)
return vocab
@property
def vocab_size(self):
return self._utf_vocab_size
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)
if token_ids_1 is None:
return [1] + [0] * len(token_ids_0) + [1]
else:
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequences for sequence classification tasks. A sequence has the
following format:
- single sequence: `[CLS] X [SEP]`
- pair of sequences: `[CLS] A [SEP] B [SEP]`
Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
if token_ids_1 is None:
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
else:
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + token_ids_1 + [self.sep_token_id]
def _tokenize(self, text: str) -> List[str]:
"""Take as input a string and return a list of strings (tokens) for words/sub-words"""
tokens = [chr(i) for i in text.encode("utf-8")]
return tokens
def _convert_token_to_id(self, token):
"""Converts a token (str) into an ID using the vocabulary."""
if len(token) != 1:
token_id = self.unk_token_id
else:
token_id = ord(token) + self._num_special_tokens
return token_id
def _convert_id_to_token(self, index):
"""Converts an index (integer) into a token (str) using the vocabulary."""
token = chr(index - self._num_special_tokens)
return token
def convert_tokens_to_string(self, tokens):
bstring = b""
for token in tokens:
if token in self.added_tokens_encoder:
tok_string = str(token).encode("utf-8")
else:
tok_string = bytes([ord(token)])
bstring += tok_string
string = bstring.decode("utf-8", errors="replace")
return string
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
return ()
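Since the whole scheme is just "byte value + 6 reserved ids", a quick sanity check is easy to write. A minimal sketch (the input string is arbitrary):

```python
from transformers import PerceiverTokenizer

# No vocab file is needed: ids 0-5 are the reserved special tokens and every
# UTF-8 byte b maps to b + 6.
tokenizer = PerceiverTokenizer()

enc = tokenizer("café")  # "café" is 5 UTF-8 bytes
print(enc["input_ids"])
# [4, 105, 103, 108, 201, 175, 5]  -> [CLS]=4, then ord(byte) + 6 per byte, then [SEP]=5
print(tokenizer.decode(enc["input_ids"], skip_special_tokens=True))  # café
```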
.\models\perceiver\__init__.py
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tokenizers_available,
is_torch_available,
is_vision_available,
)
_import_structure = {
"configuration_perceiver": ["PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PerceiverConfig", "PerceiverOnnxConfig"],
"tokenization_perceiver": ["PerceiverTokenizer"],
}
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["feature_extraction_perceiver"] = ["PerceiverFeatureExtractor"]
_import_structure["image_processing_perceiver"] = ["PerceiverImageProcessor"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_perceiver"] = [
"PERCEIVER_PRETRAINED_MODEL_ARCHIVE_LIST",
"PerceiverForImageClassificationConvProcessing",
"PerceiverForImageClassificationFourier",
"PerceiverForImageClassificationLearned",
"PerceiverForMaskedLM",
"PerceiverForMultimodalAutoencoding",
"PerceiverForOpticalFlow",
"PerceiverForSequenceClassification",
"PerceiverLayer",
"PerceiverModel",
"PerceiverPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_perceiver import PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP, PerceiverConfig, PerceiverOnnxConfig
from .tokenization_perceiver import PerceiverTokenizer
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .feature_extraction_perceiver import PerceiverFeatureExtractor
from .image_processing_perceiver import PerceiverImageProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_perceiver import (
PERCEIVER_PRETRAINED_MODEL_ARCHIVE_LIST,
PerceiverForImageClassificationConvProcessing,
PerceiverForImageClassificationFourier,
PerceiverForImageClassificationLearned,
PerceiverForMaskedLM,
PerceiverForMultimodalAutoencoding,
PerceiverForOpticalFlow,
PerceiverForSequenceClassification,
PerceiverLayer,
PerceiverModel,
PerceiverPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\persimmon\configuration_persimmon.py
""" Persimmon model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
PERSIMMON_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"adept/persimmon-8b-base": "https://huggingface.co/adept/persimmon-8b-base/resolve/main/config.json",
}
class PersimmonConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`PersimmonModel`]. It is used to instantiate a
Persimmon model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the
[adept/persimmon-8b-base](https://huggingface.co/adept/persimmon-8b-base).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
```
>>> from transformers import PersimmonModel, PersimmonConfig
>>> # Initializing a Persimmon persimmon-7b style configuration
>>> configuration = PersimmonConfig()
```
"""
model_type = "persimmon"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=262144,
hidden_size=4096,
intermediate_size=16384,
num_hidden_layers=36,
num_attention_heads=64,
hidden_act="relu2",
max_position_embeddings=16384,
initializer_range=0.02,
layer_norm_eps=1e-5,
use_cache=True,
tie_word_embeddings=False,
rope_theta=25000.0,
rope_scaling=None,
qk_layernorm=True,
hidden_dropout=0.0,
attention_dropout=0.0,
partial_rotary_factor=0.5,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.qk_layernorm = qk_layernorm
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.partial_rotary_factor = partial_rotary_factor
self._rope_scaling_validation()
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
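To make the long list of defaults above concrete, here is a small, hedged sketch that builds a deliberately tiny configuration (the sizes are arbitrary and only for illustration; the defaults describe the 8B model) and turns on linear RoPE scaling, which `_rope_scaling_validation` checks:

```python
from transformers import PersimmonConfig

# Arbitrary small values, chosen only to keep the example light.
config = PersimmonConfig(
    vocab_size=1000,
    hidden_size=256,
    intermediate_size=1024,
    num_hidden_layers=2,
    num_attention_heads=4,
    rope_scaling={"type": "linear", "factor": 2.0},  # must be a dict with exactly `type` and `factor`
)
print(config.partial_rotary_factor)  # 0.5 -> only half of each head dim receives rotary embeddings
```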
.\models\persimmon\convert_persimmon_weights_to_hf.py
import argparse
import os
import warnings
import flatdict
import torch
from transformers import LlamaTokenizer, PersimmonConfig, PersimmonForCausalLM
try:
from transformers import LlamaTokenizerFast
tokenizer_class = LlamaTokenizerFast
except ImportError as e:
warnings.warn(e)
warnings.warn(
"The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
)
tokenizer_class = LlamaTokenizer
"""
Sample usage:
git clone https://github.com/persimmon-ai-labs/adept-inference
wget https://axtkn4xl5cip.objectstorage.us-phoenix-1.oci.customer-oci.com/n/axtkn4xl5cip/b/adept-public-data/o/8b_base_model_release.tar
wget https://axtkn4xl5cip.objectstorage.us-phoenix-1.oci.customer-oci.com/n/axtkn4xl5cip/b/adept-public-data/o/8b_chat_model_release.tar
python src/transformers/models/persimmon/convert_persimmon_weights_to_hf.py --input_dir /path/to/downloaded/persimmon/weights/ --output_dir /output/path
"""
KEYS_TO_MODIFY_MAPPING = {
"self_attention": "self_attn",
"language_model.encoder": "model",
"word_embeddings_for_head": "lm_head",
"language_model.embedding.word_embeddings": "model.embed_tokens",
}
KEYS_TO_REMOVE = "rotary_emb.inv_freq"
def rename_state_dict(state_dict):
model_state_dict = {}
for key, value in state_dict.items():
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)
if KEYS_TO_REMOVE in key:
continue
model_state_dict[key] = value
return model_state_dict
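A tiny illustration of what `rename_state_dict` does to a single checkpoint key; the key below is made up, but follows the naming pattern that `KEYS_TO_MODIFY_MAPPING` expects:

```python
# Hypothetical original key, value omitted for brevity.
example = {"language_model.encoder.layers.0.self_attention.dense.weight": None}
print(rename_state_dict(example))
# {'model.layers.0.self_attn.dense.weight': None}
```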
def convert_persimmon_checkpoint(pytorch_dump_folder_path, ada_lib_path, pt_model_path, safe_serialization=False):
import sys
sys.path.insert(0, ada_lib_path)
model_state_dict_base = torch.load(pt_model_path, map_location="cpu")
state_dict = flatdict.FlatDict(model_state_dict_base["model"], ".")
state_dict = rename_state_dict(state_dict)
transformers_config = PersimmonConfig()
model = PersimmonForCausalLM(transformers_config, eos_token_id=71013, bos_token_id=71013).to(torch.bfloat16)
model.load_state_dict(state_dict)
model.save_pretrained(pytorch_dump_folder_path, safe_serialization=safe_serialization)
transformers_config.save_pretrained(pytorch_dump_folder_path)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir",
help="Location of Persimmon weights, which contains tokenizer.model and model folders",
)
parser.add_argument(
"--pt_model_path",
help="Location of Persimmon `model_optim_rng.pt`",
)
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--ada_lib_path",
help="Location to write HF model and tokenizer",
)
parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
args = parser.parse_args()
spm_path = os.path.join(args.input_dir, "adept_vocab.model")
convert_persimmon_checkpoint(
pytorch_dump_folder_path=args.output_dir,
pt_model_path=args.pt_model_path,
safe_serialization=args.safe_serialization,
ada_lib_path=args.ada_lib_path,
)
tokenizer = tokenizer_class(spm_path, bos_token="|ENDOFTEXT|", eos_token="|ENDOFTEXT|")
tokenizer.save_pretrained(args.output_dir)
if __name__ == "__main__":
main()
.\models\persimmon\modeling_persimmon.py
""" PyTorch Persimmon model."""
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_persimmon import PersimmonConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "PersimmonConfig"
class PersimmonRotaryEmbedding(nn.Module):
"""
Rotary positional embedding for Persimmon model.
"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
"""
Initialize the PersimmonRotaryEmbedding module.
Args:
dim (int): Dimensionality of the embedding.
max_position_embeddings (int): Maximum number of positions to embed.
base (int): Base value for rotational frequencies.
device (Optional[torch.device]): Device to store the embeddings.
"""
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
"""
Precompute and store the cosine and sine caches.
Args:
seq_len (int): Length of the sequence to compute values for.
device (torch.device): Device to store the cache tensors.
dtype (torch.dtype): Data type of the cache tensors.
"""
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None):
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
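A short sketch of the shapes involved, assuming a toy rotary module (all sizes arbitrary); `x` is only used to pick the device and dtype of the returned tensors:

```python
import torch

rotary = PersimmonRotaryEmbedding(dim=8, max_position_embeddings=16)
x = torch.randn(1, 2, 16, 8)            # (batch, heads, seq_len, head_dim), dtype carrier only
cos, sin = rotary(x, seq_len=16)
print(cos.shape, sin.shape)              # torch.Size([16, 8]) torch.Size([16, 8])
# Row i holds cos/sin of i * inv_freq, duplicated once so the last dim matches `dim`.
```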
class PersimmonLinearScalingRotaryEmbedding(PersimmonRotaryEmbedding):
"""PersimmonRotaryEmbedding扩展了线性缩放。鸣谢Reddit用户/u/kaiokendev"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
t = t / self.scaling_factor
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
class PersimmonDynamicNTKScalingRotaryEmbedding(PersimmonRotaryEmbedding):
"""PersimmonRotaryEmbedding扩展了动态NTK缩放。鸣谢Reddit用户/u/bloc97和/u/emozilla"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def rotate_half(x):
"""旋转输入张量一半的隐藏维度。"""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
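A minimal sketch of the two helpers in action on random tensors (shapes follow the `[batch, heads, seq_len, head_dim]` convention from the docstring). Because each pair of channels is rotated by an angle, per-position vector norms are preserved, which the last line checks:

```python
import torch

batch, heads, seq_len, head_dim = 1, 2, 4, 8
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)

rotary = PersimmonRotaryEmbedding(dim=head_dim, max_position_embeddings=seq_len)
cos, sin = rotary(q, seq_len=seq_len)                 # each of shape (seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0)     # (1, seq_len)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(q_rot.shape, k_rot.shape)                       # unchanged: (1, 2, 4, 8)
torch.testing.assert_close(q_rot.norm(dim=-1), q.norm(dim=-1))  # rotation preserves norms
```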
class PersimmonMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size)
self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size)
self.act = ACT2FN[config.hidden_act]
def forward(self, hidden_states):
hidden_states = self.dense_h_to_4h(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.dense_4h_to_h(hidden_states)
return hidden_states
class PersimmonAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: PersimmonConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.partial_rotary_factor = config.partial_rotary_factor
self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
self.qk_layernorm = config.qk_layernorm
if self.qk_layernorm:
self.q_layernorm = nn.LayerNorm(
config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
)
self.k_layernorm = nn.LayerNorm(
config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
)
self.attention_dropout = nn.Dropout(config.attention_dropout)
self._init_rope()
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = PersimmonRotaryEmbedding(
int(self.partial_rotary_factor * self.head_dim),
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = PersimmonLinearScalingRotaryEmbedding(
int(self.partial_rotary_factor * self.head_dim),
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "dynamic":
self.rotary_emb = PersimmonDynamicNTKScalingRotaryEmbedding(
int(self.partial_rotary_factor * self.head_dim),
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
storage as `fused_qkv`
Args:
fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
Returns:
query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
value: [batch_size, seq_length, num_heads, head_dim]
"""
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
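A shape-only sketch of the fused-QKV split, using a deliberately tiny (arbitrary) configuration:

```python
import torch

tiny = PersimmonConfig(vocab_size=100, hidden_size=64, intermediate_size=128,
                       num_hidden_layers=1, num_attention_heads=4)
attn = PersimmonAttention(tiny, layer_idx=0)

fused = torch.randn(2, 5, 3 * tiny.hidden_size)   # what query_key_value would produce
q, k, v = attn._split_heads(fused)
print(q.shape, k.shape, v.shape)                   # each torch.Size([2, 5, 4, 16])
```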
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
):
pass
class PersimmonDecoderLayer(nn.Module):
def __init__(self, config: PersimmonConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = PersimmonAttention(config=config, layer_idx=layer_idx)
self.mlp = PersimmonMLP(config)
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
`[0, config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_value (`Tuple(torch.FloatTensor)`, *optional*):
cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states + residual
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
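A hedged end-to-end sketch of one decoder layer with a tiny, arbitrary configuration. Note the sequential residual layout above: attention first, then the MLP applied to the post-attention layernorm output. No causal mask is passed here, so this only demonstrates shapes; `PersimmonModel` normally supplies the mask:

```python
import torch

tiny = PersimmonConfig(vocab_size=100, hidden_size=64, intermediate_size=128,
                       num_hidden_layers=1, num_attention_heads=4)
layer = PersimmonDecoderLayer(tiny, layer_idx=0)

hidden = torch.randn(1, 6, tiny.hidden_size)
position_ids = torch.arange(6).unsqueeze(0)
(out,) = layer(hidden, position_ids=position_ids)
print(out.shape)  # torch.Size([1, 6, 64])
```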
PERSIMMON_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`PersimmonConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare Persimmon Model outputting raw hidden-states without any specific head on top.",
PERSIMMON_START_DOCSTRING,
)
class PersimmonPreTrainedModel(PreTrainedModel):
config_class = PersimmonConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["PersimmonDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
PERSIMMON_INPUTS_DOCSTRING = r"""
"""
@add_start_docstrings(
"The bare Persimmon Model outputting raw hidden-states without any specific head on top.",
PERSIMMON_START_DOCSTRING,
)
class PersimmonModel(PersimmonPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PersimmonDecoderLayer`]
Args:
config: PersimmonConfig
"""
def __init__(self, config: PersimmonConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[PersimmonDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.gradient_checkpointing = False
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
@add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
pass
class PersimmonForCausalLM(PersimmonPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = PersimmonModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
@add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
pass
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
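A hedged usage sketch of the causal LM head above. The checkpoint name is the one referenced in the configuration file earlier; it is an 8B-parameter model, so treat this purely as an illustration:

```python
from transformers import AutoTokenizer, PersimmonForCausalLM

tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-base")
model = PersimmonForCausalLM.from_pretrained("adept/persimmon-8b-base")

inputs = tokenizer("human: Hey, what should I eat for dinner?", return_tensors="pt")
# generate() calls prepare_inputs_for_generation() above at each step, feeding only the
# newly generated token together with the cached past_key_values.
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```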
.\models\persimmon\__init__.py
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
_import_structure = {
"configuration_persimmon": ["PERSIMMON_PRETRAINED_CONFIG_ARCHIVE_MAP", "PersimmonConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_persimmon"] = [
"PersimmonForCausalLM",
"PersimmonModel",
"PersimmonPreTrainedModel",
"PersimmonForSequenceClassification",
]
if TYPE_CHECKING:
from .configuration_persimmon import PERSIMMON_PRETRAINED_CONFIG_ARCHIVE_MAP, PersimmonConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_persimmon import (
PersimmonForCausalLM,
PersimmonForSequenceClassification,
PersimmonModel,
PersimmonPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\phi\configuration_phi.py
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
PHI_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"microsoft/phi-1": "https://huggingface.co/microsoft/phi-1/resolve/main/config.json",
"microsoft/phi-1_5": "https://huggingface.co/microsoft/phi-1_5/resolve/main/config.json",
"microsoft/phi-2": "https://huggingface.co/microsoft/phi-2/resolve/main/config.json",
}
class PhiConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`PhiModel`]. It is used to instantiate a Phi
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Phi
[microsoft/phi-1](https://huggingface.co/microsoft/phi-1).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```
>>> from transformers import PhiModel, PhiConfig
>>> # Initializing a Phi-1 style configuration
>>> configuration = PhiConfig.from_pretrained("microsoft/phi-1")
>>> # Initializing a model from the configuration
>>> model = PhiModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "phi"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=51200,
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=24,
num_attention_heads=32,
num_key_value_heads=None,
resid_pdrop=0.0,
embd_pdrop=0.0,
attention_dropout=0.0,
hidden_act="gelu_new",
max_position_embeddings=2048,
initializer_range=0.02,
layer_norm_eps=1e-5,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
partial_rotary_factor=0.5,
qk_layernorm=False,
bos_token_id=1,
eos_token_id=2,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.resid_pdrop = resid_pdrop
self.embd_pdrop = embd_pdrop
self.attention_dropout = attention_dropout
self.hidden_act = hidden_act
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.partial_rotary_factor = partial_rotary_factor
self.qk_layernorm = qk_layernorm
self._rope_scaling_validation()
super().__init__(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
.\models\phi\convert_phi_weights_to_hf.py
import argparse
import gc
import os
import safetensors
import torch
from huggingface_hub import hf_hub_download
from transformers import PhiConfig, PhiForCausalLM
_MODELS = {
"microsoft/phi-1": ["https://huggingface.co/microsoft/phi-1/blob/main/pytorch_model.bin"],
"microsoft/phi-1_5": ["https://huggingface.co/microsoft/phi-1_5/blob/main/pytorch_model.bin"],
"microsoft/phi-2": [
"https://huggingface.co/microsoft/phi-2/blob/main/model-00001-of-00002.safetensors",
"https://huggingface.co/microsoft/phi-2/blob/main/model-00002-of-00002.safetensors",
],
}
PHI_MAPPING = {
"transformer.embd.wte.weight": "model.embed_tokens.weight",
"lm_head.linear": "lm_head",
"lm_head.ln": "model.final_layernorm",
"layers": "model.layers",
"transformer": "model",
".h.": ".layers.",
"ln": "input_layernorm",
"mixer": "self_attn",
"Wqkv": "query_key_value",
"out_proj": "dense",
}
def convert_weights(original_weights, mapping, config):
converted_weights = {}
original_weights_keys = sorted(original_weights.keys())
for original_weights_key in original_weights_keys:
new_key = original_weights_key
if "rotary_emb" in new_key:
continue
if "Wqkv" in new_key:
if "weight" in new_key:
weight = original_weights[new_key]
weights_shape = weight.shape
weight = (
weight.view(3, config.num_attention_heads, -1, config.hidden_size)
.transpose(0, 1)
.reshape(*weights_shape)
)
original_weights[new_key] = weight
elif "bias" in new_key:
bias = original_weights[new_key]
bias_shape = bias.shape
bias = bias.view(3, config.num_attention_heads, -1).transpose(0, 1).reshape(*bias_shape)
original_weights[new_key] = bias
for k, v in mapping.items():
if k in new_key:
new_key = new_key.replace(k, v)
converted_weights[new_key] = original_weights.pop(original_weights_key)
return converted_weights
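A shape-level sketch (toy, made-up sizes) of the `Wqkv` permutation above: the fused projection is reinterpreted as `(qkv, heads, head_dim, hidden)`, the first two axes are swapped, and the result is flattened back, so rows end up grouped per attention head instead of per q/k/v block:

```python
import torch

num_heads, hidden_size = 4, 32                   # arbitrary toy sizes
w = torch.randn(3 * hidden_size, hidden_size)    # original fused Wqkv weight
w_converted = (
    w.view(3, num_heads, -1, hidden_size)        # (qkv, heads, head_dim, hidden)
    .transpose(0, 1)                             # (heads, qkv, head_dim, hidden)
    .reshape(w.shape)                            # back to (3 * hidden, hidden)
)
print(w_converted.shape)                         # torch.Size([96, 32])
```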
repo_id = f"{url.split('/')[3]}/{url.split('/')[4]}"
filename = f"{url.split('/')[-1]}"
hf_hub_download(
repo_id=repo_id,
filename=filename,
force_filename=root,
local_dir_use_symlinks=False,
)
def convert_phi_weights(
model_name, checkpoint_path, pytorch_dump_folder_path, use_cuda, save_weights_directly, _MODELS
):
_MODELS = _MODELS if model_name not in _MODELS.keys() else {model_name: _MODELS.get(model_name)}
device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
for model_name, model_url in _MODELS.items():
converted_checkpoint = {}
model_checkpoint = {}
for model_each_url in model_url:
model_path = os.path.join(checkpoint_path, model_name + "_" + model_each_url.split("/")[-1])
if not os.path.exists(model_path):
print(f"\n{model_name} was not found! Downloading it to {model_path}")
_download(url=model_each_url, root=model_path)
if model_path.endswith("safetensors"):
loaded_weights = safetensors.torch.load_file(model_path, device=device)
else:
loaded_weights = torch.load(model_path, map_location=device)
model_checkpoint.update(**loaded_weights)
model_type = model_name.split("/")[1]
config = PhiConfig()
if model_type == "phi-2":
config.hidden_size = 2560
config.intermediate_size = 10240
config.num_hidden_layers = 32
config.resid_pdrop = 0.1
config.partial_rotary_factor = 0.4
config.torch_dtype = "float16"
converted_checkpoint.update(**convert_weights(model_checkpoint, PHI_MAPPING, config))
if save_weights_directly:
save_weights_path = os.path.join(pytorch_dump_folder_path, model_type + "_pytorch_model.bin")
torch.save(converted_checkpoint, save_weights_path)
print(f"Model weights saved at {save_weights_path}!")
else:
model = PhiForCausalLM(config).to(device)
model.load_state_dict(converted_checkpoint, strict=True)
save_model_path = os.path.join(pytorch_dump_folder_path, model_type)
model.save_pretrained(save_model_path)
print(f"Model saved at {save_model_path}!")
del config, model
del model_checkpoint, converted_checkpoint
if use_cuda:
torch.cuda.empty_cache()
gc.collect()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
help="要转换的模型名称。请选择其中之一:phi-1, phi-1_5, phi-2。如果未提供,则转换所有模型。",
default=None,
)
parser.add_argument(
"--checkpoint_path", type=str, help="已下载检查点文件夹的路径。(请输入完整路径)"
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
help="PyTorch 模型输出路径。(请输入完整路径)",
)
parser.add_argument(
"--use_cuda",
default=False,
type=bool,
help="在转换过程中是否将权重加载到 GPU 上,默认为 False",
)
parser.add_argument(
"--save_weights_directly",
default=True,
type=bool,
help="是否直接保存转换后的权重,或者将权重加载到 Phi 模型中再保存。默认为 True",
)
args = parser.parse_args()
convert_phi_weights(
args.model_name,
args.checkpoint_path,
args.pytorch_dump_folder_path,
args.use_cuda,
args.save_weights_directly,
_MODELS,
)
.\models\phi\modeling_phi.py
""" PyTorch Phi model. """
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
get_torch_version,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_phi import PhiConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "microsoft/phi-1"
_CONFIG_FOR_DOC = "PhiConfig"
PHI_PRETRAINED_MODEL_ARCHIVE_LIST = [
"microsoft/phi-1",
"microsoft/phi-1_5",
"microsoft/phi-2",
]
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
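A tiny worked example of the helper above for a 2-sequence batch where the second sequence has one padding position:

```python
import torch

mask = torch.tensor([[1, 1, 1],
                     [1, 1, 0]])
indices, cu_seqlens, max_len = _get_unpad_data(mask)
print(indices)      # tensor([0, 1, 2, 3, 4]) -> flat positions of the non-padding tokens
print(cu_seqlens)   # tensor([0, 3, 5], dtype=torch.int32) -> cumulative sequence lengths
print(max_len)      # 3
```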
class PhiRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None):
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
"""PhiRotaryEmbedding扩展了线性缩放。感谢Reddit用户/u/kaiokendev"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
t = t / self.scaling_factor
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
"""PhiRotaryEmbedding扩展了动态NTK缩放。感谢Reddit用户/u/bloc97和/u/emozilla"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def rotate_half(x):
"""旋转输入张量一半的隐藏维度。"""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`): The position indices of the tokens corresponding to the query and key tensors.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, if
cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim] and q and k have
the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes them broadcastable
to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim],
then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)`: the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class PhiMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
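A one-line shape sketch of `repeat_kv` (arbitrary sizes), as used for grouped-query attention:

```python
import torch

kv = torch.randn(2, 4, 10, 16)        # (batch, num_key_value_heads, seq_len, head_dim)
print(repeat_kv(kv, n_rep=8).shape)   # torch.Size([2, 32, 10, 16]) -> 4 KV heads serve 32 query heads
```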
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.partial_rotary_factor = config.partial_rotary_factor
self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
self.qk_layernorm = config.qk_layernorm
if self.qk_layernorm:
self.q_layernorm = nn.LayerNorm(
config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
)
self.k_layernorm = nn.LayerNorm(
config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
)
self._init_rope()
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = PhiRotaryEmbedding(
int(self.partial_rotary_factor * self.head_dim),
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = PhiLinearScalingRotaryEmbedding(
int(self.partial_rotary_factor * self.head_dim),
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "dynamic":
self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
int(self.partial_rotary_factor * self.head_dim),
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
):
pass
class PhiFlashAttention2(PhiAttention):
"""
Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stay
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def _flash_attention_forward(
self,
query_states,
key_states,
value_states,
attention_mask,
query_length,
dropout=0.0,
softmax_scale=None,
**kwargs,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpads the input, then computes the attention scores and pads the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
causal = self.is_causal and query_length != 1
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
)
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
class PhiSdpaAttention(PhiAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
"""
SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`PhiAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to the
SDPA API.
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
):
pass
PHI_ATTENTION_CLASSES = {
"eager": PhiAttention,
"flash_attention_2": PhiFlashAttention2,
"sdpa": PhiSdpaAttention,
}
class PhiDecoderLayer(nn.Module):
def __init__(self, config: PhiConfig, layer_idx: int):
super().__init__()
self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
self.mlp = PhiMLP(config)
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
):
"""
Args:
hidden_states (`torch.FloatTensor`):
Input tensor to the layer of shape `(batch, seq_len, embed_dim)`.
attention_mask (`torch.FloatTensor`, *optional*):
Attention mask of shape `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Indices of the position of each input sequence token in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
output_attentions (`bool`, *optional*):
Whether or not to return the attention tensors of all attention layers. See `attentions` under the returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*):
Cached past key and value projection states.
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
attn_outputs, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
attn_outputs = self.resid_dropout(attn_outputs)
feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
hidden_states = attn_outputs + feed_forward_hidden_states + residual
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
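Note that the layer applies attention and the MLP in parallel on the same layer-normed input and adds both to the residual, rather than chaining them sequentially. A toy sketch of that structure (stand-in linear modules, not the real PhiAttention/PhiMLP):

```python
import torch
import torch.nn as nn

hidden = torch.randn(1, 4, 8)          # (batch, seq_len, hidden_size)
ln = nn.LayerNorm(8)
attn = nn.Linear(8, 8)                 # stand-in for self-attention
mlp = nn.Linear(8, 8)                  # stand-in for PhiMLP
resid_dropout = nn.Dropout(0.0)

normed = ln(hidden)
out = resid_dropout(attn(normed)) + resid_dropout(mlp(normed)) + hidden   # attn + mlp + residual
print(out.shape)                       # torch.Size([1, 4, 8])
```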
PHI_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`PhiConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare Phi Model outputting raw hidden-states without any specific head on top.",
PHI_START_DOCSTRING,
)
class PhiPreTrainedModel(PreTrainedModel):
config_class = PhiConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["PhiDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
PHI_INPUTS_DOCSTRING = r"""
"""
@add_start_docstrings(
"The bare Phi Model outputting raw hidden-states without any specific head on top.",
PHI_START_DOCSTRING,
)
class PhiModel(PhiPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]
Args:
config: PhiConfig
"""
def __init__(self, config: PhiConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.embed_dropout = nn.Dropout(config.embd_pdrop)
self.layers = nn.ModuleList(
[PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"
self.gradient_checkpointing = False
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
pass
class PhiForCausalLM(PhiPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = PhiModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
pass
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
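The cumsum trick above assigns consecutive positions to the real tokens while pinning padding positions to a dummy value; a small illustration with a left-padded sequence:

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])        # left-padded sequence
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)                                      # tensor([[1, 1, 0, 1, 2]]): real tokens get 0, 1, 2
```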
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
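`index_select` along the batch dimension is what re-aligns the cached keys and values with the surviving beams during beam search; a minimal illustration with made-up cache contents:

```python
import torch

past_state = torch.arange(6).reshape(3, 2)   # pretend cache slice for num_beams=3
beam_idx = torch.tensor([2, 0, 0])           # beam 2 and two copies of beam 0 survive this step
print(past_state.index_select(0, beam_idx))  # tensor([[4, 5], [0, 1], [0, 1]])
```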
@add_start_docstrings(
"""
The PhiModel with a sequence classification head on top (linear layer).
[`PhiForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-2) do.
Since it does classification on the last token, it needs to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
""",
PHI_START_DOCSTRING,
)
class PhiForSequenceClassification(PhiPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = PhiModel(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
pass
@add_start_docstrings(
"""
PhiModel with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
PHI_START_DOCSTRING,
)
class PhiForTokenClassification(PhiPreTrainedModel):
def __init__(self, config: PhiConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.model = PhiModel(config)
if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
classifier_dropout = config.classifier_dropout
elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.post_init()
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
model_outputs = self.model(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = model_outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.classifier(hidden_states)
loss = None
if labels is not None:
labels = labels.to(logits.device)
batch_size, seq_length = labels.shape
loss_fct = CrossEntropyLoss()
loss = loss_fct(
logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
)
if not return_dict:
output = (logits,) + model_outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=model_outputs.hidden_states,
attentions=model_outputs.attentions,
)
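A hedged usage sketch for the token-classification head: a randomly initialized model built from a deliberately tiny `PhiConfig` (the hyper-parameter values here are purely illustrative, not a real checkpoint):

```python
import torch
from transformers import PhiConfig, PhiForTokenClassification

config = PhiConfig(
    vocab_size=100, hidden_size=32, intermediate_size=64,
    num_hidden_layers=2, num_attention_heads=4, num_labels=3,
)
model = PhiForTokenClassification(config)
input_ids = torch.randint(0, 100, (1, 6))
labels = torch.randint(0, 3, (1, 6))
outputs = model(input_ids=input_ids, labels=labels)
print(outputs.logits.shape)      # torch.Size([1, 6, 3]): one logit per label for every token
print(outputs.loss is not None)  # True: the cross-entropy loss above was computed
```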
.\models\phi\__init__.py
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_sentencepiece_available,
is_tokenizers_available,
is_torch_available,
)
_import_structure = {
"configuration_phi": ["PHI_PRETRAINED_CONFIG_ARCHIVE_MAP", "PhiConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_phi"] = [
"PHI_PRETRAINED_MODEL_ARCHIVE_LIST",
"PhiPreTrainedModel",
"PhiModel",
"PhiForCausalLM",
"PhiForSequenceClassification",
"PhiForTokenClassification",
]
if TYPE_CHECKING:
from .configuration_phi import PHI_PRETRAINED_CONFIG_ARCHIVE_MAP, PhiConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_phi import (
PHI_PRETRAINED_MODEL_ARCHIVE_LIST,
PhiForCausalLM,
PhiForSequenceClassification,
PhiForTokenClassification,
PhiModel,
PhiPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\phobert\tokenization_phobert.py
""" PhoBERT 的分词类 """
import os
import re
from shutil import copyfile
from typing import List, Optional, Tuple
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {
"vocab_file": "vocab.txt",
"merges_file": "bpe.codes",
}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"vinai/phobert-base": "https://huggingface.co/vinai/phobert-base/resolve/main/vocab.txt",
"vinai/phobert-large": "https://huggingface.co/vinai/phobert-large/resolve/main/vocab.txt",
},
"merges_file": {
"vinai/phobert-base": "https://huggingface.co/vinai/phobert-base/resolve/main/bpe.codes",
"vinai/phobert-large": "https://huggingface.co/vinai/phobert-large/resolve/main/bpe.codes",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"vinai/phobert-base": 256,
"vinai/phobert-large": 256,
}
def get_pairs(word):
"""
返回单词中的符号对集合。
单词表示为符号元组(符号是长度可变的字符串)。
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
pairs = set(pairs)
return pairs
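For example, on a BPE-style symbol tuple the function yields every adjacent pair:

```python
word = ("l", "o", "w", "e", "r</w>")
print(get_pairs(word))
# {('l', 'o'), ('o', 'w'), ('w', 'e'), ('e', 'r</w>')}  (set order may vary)
```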
class PhobertTokenizer(PreTrainedTokenizer):
"""
构造一个 PhoBERT 分词器。基于字节对编码(Byte-Pair-Encoding)。
此分词器继承自 PreTrainedTokenizer,其中包含大多数主要方法。用户应参考这个超类以获取有关这些方法的更多信息。
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(
self,
vocab_file,
merges_file,
bos_token="<s>",
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
**kwargs,
):
self.vocab_file = vocab_file
self.merges_file = merges_file
self.encoder = {}
self.encoder[str(bos_token)] = 0
self.encoder[str(pad_token)] = 1
self.encoder[str(eos_token)] = 2
self.encoder[str(unk_token)] = 3
self.add_from_file(vocab_file)
self.decoder = {v: k for k, v in self.encoder.items()}
with open(merges_file, encoding="utf-8") as merges_handle:
merges = merges_handle.read().split("\n")[:-1]
merges = [tuple(merge.split()[:-1]) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
**kwargs,
)
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. A PhoBERT sequence has the following format:
- single sequence: `<s> X </s>`
- pair of sequences: `<s> A </s></s> B </s>`
Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
if token_ids_1 is None:
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
cls = [self.cls_token_id]
sep = [self.sep_token_id]
return cls + token_ids_0 + sep + sep + token_ids_1 + sep
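With the vocabulary built in `__init__` (`<s>` = 0, `</s>` = 2), the two layouts look like this (the other token ids are made up for illustration):

```python
token_ids_0, token_ids_1 = [10, 11], [20]
cls_id, sep_id = 0, 2
print([cls_id] + token_ids_0 + [sep_id])                                   # [0, 10, 11, 2]           <s> A </s>
print([cls_id] + token_ids_0 + [sep_id, sep_id] + token_ids_1 + [sep_id])  # [0, 10, 11, 2, 2, 20, 2] <s> A </s></s> B </s>
```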
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)
if token_ids_1 is None:
return [1] + ([0] * len(token_ids_0)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. PhoBERT does not
make use of token type ids, therefore a list of zeros is returned.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of zeros.
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
@property
def vocab_size(self):
return len(self.encoder)
def get_vocab(self):
return dict(self.encoder, **self.added_tokens_encoder)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
word = tuple(list(word[:-1]) + [word[-1] + "</w>"])
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
except ValueError:
new_word.extend(word[i:])
break
else:
new_word.extend(word[i:j])
i = j
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = "@@ ".join(word)
word = word[:-4]
self.cache[token] = word
return word
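A standalone walk-through of the merge loop with a hand-made merge table (this toy version breaks as soon as no remaining pair is in the table, which matches the behavior of the `float("inf")` ranking above):

```python
bpe_ranks = {("l", "o"): 0, ("lo", "w"): 1, ("e", "r</w>"): 2}
word = ("l", "o", "w", "e", "r</w>")
while True:
    pairs = get_pairs(word)
    candidates = [p for p in pairs if p in bpe_ranks]
    if not candidates:
        break
    first, second = min(candidates, key=lambda pair: bpe_ranks[pair])
    merged, i = [], 0
    while i < len(word):
        if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
            merged.append(first + second)   # apply the lowest-ranked merge
            i += 2
        else:
            merged.append(word[i])
            i += 1
    word = tuple(merged)
print(word)                   # ('low', 'er</w>')
print("@@ ".join(word)[:-4])  # 'low@@ er'
```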
def _tokenize(self, text):
"""Tokenize a string."""
split_tokens = []
words = re.findall(r"\S+\n?", text)
for token in words:
split_tokens.extend(list(self.bpe(token).split(" ")))
return split_tokens
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
return self.encoder.get(token, self.encoder.get(self.unk_token))
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.decoder.get(index, self.unk_token)
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
out_string = " ".join(tokens).replace("@@ ", "").strip()
return out_string
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
out_merge_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
if os.path.abspath(self.merges_file) != os.path.abspath(out_merge_file):
copyfile(self.merges_file, out_merge_file)
return out_vocab_file, out_merge_file
def add_from_file(self, f):
"""
从文本文件加载预先存在的字典,并将其符号添加到此实例中。
"""
if isinstance(f, str):
try:
with open(f, "r", encoding="utf-8") as fd:
self.add_from_file(fd)
except FileNotFoundError as fnfe:
raise fnfe
except UnicodeError:
raise Exception(f"Incorrect encoding detected in {f}, please rebuild the dataset")
return
lines = f.readlines()
for lineTmp in lines:
line = lineTmp.strip()
idx = line.rfind(" ")
if idx == -1:
raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'")
word = line[:idx]
self.encoder[word] = len(self.encoder)
.\models\phobert\__init__.py
from typing import TYPE_CHECKING
from ...utils import _LazyModule
_import_structure = {"tokenization_phobert": ["PhobertTokenizer"]}
if TYPE_CHECKING:
from .tokenization_phobert import PhobertTokenizer
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\pix2struct\configuration_pix2struct.py
""" Pix2Struct 模型配置 """
import os
from typing import Union
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
PIX2STRUCT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google/pix2struct-textcaps-base": (
"https://huggingface.co/google/pix2struct-textcaps-base/resolve/main/config.json"
),
}
class Pix2StructTextConfig(PretrainedConfig):
r"""
这是用于存储 [`Pix2StructTextModel`] 配置的配置类。它用于根据指定的参数实例化
Pix2Struct 文本模型,定义模型架构。使用默认值实例化配置将产生类似于
[google/pix2struct-base](https://huggingface.co/google/pix2struct-base) 架构使用的 Pix2Struct 文本解码器的配置。
配置对象继承自 [`PretrainedConfig`],可用于控制模型的输出。有关更多信息,请阅读
[`PretrainedConfig`] 的文档。
# 定义模型类型为 "pix2struct_text_model"
model_type = "pix2struct_text_model"
python
# 在推断时要忽略的键列表
keys_to_ignore_at_inference = ["past_key_values"]
# 属性映射字典,将类参数名映射到配置文件中的属性名
attribute_map = {
"hidden_size": "hidden_size",
"num_attention_heads": "num_heads",
"num_hidden_layers": "num_layers",
}
# 类的初始化方法,定义了模型配置的默认参数
def __init__(
self,
vocab_size=50244,
hidden_size=768,
d_kv=64,
d_ff=2048,
num_layers=12,
num_heads=12,
relative_attention_num_buckets=32,
relative_attention_max_distance=128,
dropout_rate=0.1,
layer_norm_epsilon=1e-6,
initializer_factor=1.0,
dense_act_fn="gelu_new",
decoder_start_token_id=0,
use_cache=False,
pad_token_id=0,
eos_token_id=1,
tie_word_embeddings=False,
is_decoder=True,
**kwargs,
):
# Store the configuration parameters on the instance
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.d_kv = d_kv
self.d_ff = d_ff
self.num_layers = num_layers
self.num_heads = num_heads
self.relative_attention_num_buckets = relative_attention_num_buckets
self.relative_attention_max_distance = relative_attention_max_distance
self.dropout_rate = dropout_rate
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_factor = initializer_factor
self.use_cache = use_cache
self.eos_token_id = eos_token_id
self.decoder_start_token_id = decoder_start_token_id
# Kept for backward compatibility: the dense layer activation function
self.dense_act_fn = dense_act_fn
# Call the parent initializer with the remaining parameters
super().__init__(
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
tie_word_embeddings=tie_word_embeddings,
is_decoder=is_decoder,
**kwargs,
)
# Class method: load a configuration from a pretrained model
@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PretrainedConfig":
# Put the token-related parameters into kwargs
cls._set_token_in_kwargs(kwargs)
# Fetch the configuration dict of the pretrained model together with the remaining kwargs
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# If the config dict describes a composite "pix2struct" model, pull out its text config
if config_dict.get("model_type") == "pix2struct":
config_dict = config_dict["text_config"]
# Warn if the model type in the config dict does not match the model type defined on this class
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
# Build and return a class instance from the config dict and kwargs
return cls.from_dict(config_dict, **kwargs)
# Pix2StructVisionConfig class, inheriting from PretrainedConfig
class Pix2StructVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Pix2StructVisionModel`]. It is used to
instantiate a Pix2Struct vision model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a configuration similar to that of the Pix2Struct-base
[google/pix2struct-base](https://huggingface.co/google/pix2struct-base) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
"""
model_type = "pix2struct_vision_model"
def __init__(
self,
hidden_size=768,
patch_embed_hidden_size=768,
d_ff=2048,
d_kv=64,
num_hidden_layers=12,
num_attention_heads=12,
dense_act_fn="gelu_new",
layer_norm_eps=1e-6,
dropout_rate=0.0,
attention_dropout=0.0,
initializer_range=1e-10,
initializer_factor=1.0,
seq_len=4096,
relative_attention_num_buckets=32,
relative_attention_max_distance=128,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.patch_embed_hidden_size = patch_embed_hidden_size
self.d_ff = d_ff
self.dropout_rate = dropout_rate
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.initializer_range = initializer_range
self.initializer_factor = initializer_factor
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.dense_act_fn = dense_act_fn
self.seq_len = seq_len
self.relative_attention_num_buckets = relative_attention_num_buckets
self.relative_attention_max_distance = relative_attention_max_distance
self.d_kv = d_kv
@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
if config_dict.get("model_type") == "pix2struct":
config_dict = config_dict["vision_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
class Pix2StructConfig(PretrainedConfig):
r"""
[`Pix2StructConfig`] is the configuration class to store the configuration of a
[`Pix2StructForConditionalGeneration`]. It is used to instantiate a Pix2Struct model according to the specified
arguments, defining the text model and vision model configs. Instantiating a configuration with the defaults will
yield a similar configuration to that of the Pix2Struct-base
[google/pix2struct-base](https://huggingface.co/google/pix2struct-base) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
text_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`Pix2StructTextConfig`].
vision_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`Pix2StructVisionConfig`].
initializer_factor (`float`, *optional*, defaults to 1.0):
Factor to multiply the initialization range with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
is_vqa (`bool`, *optional*, defaults to `False`):
Whether the model has been fine-tuned for VQA or not.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie the word embeddings between the text and vision models.
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
Whether the model follows an encoder-decoder architecture.
kwargs (*optional*):
Dictionary of keyword arguments.
Example:
```
>>> from transformers import Pix2StructConfig, Pix2StructForConditionalGeneration
>>> # Initializing a Pix2StructConfig with google/pix2struct-base style configuration
>>> configuration = Pix2StructConfig()
>>> # Initializing a Pix2StructForConditionalGeneration (with random weights) from the google/pix2struct-base style configuration
>>> model = Pix2StructForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
>>> # We can also initialize a Pix2StructConfig from a Pix2StructTextConfig and a Pix2StructVisionConfig
>>> # Initializing a Pix2Struct text and Pix2Struct vision configuration
>>> config_text = Pix2StructTextConfig()
>>> config_vision = Pix2StructVisionConfig()
>>> config = Pix2StructConfig.from_text_vision_configs(config_text, config_vision)
```"""
model_type = "pix2struct"
def __init__(
self,
text_config=None,
vision_config=None,
initializer_factor=1.0,
initializer_range=0.02,
is_vqa=False,
tie_word_embeddings=False,
is_encoder_decoder=True,
**kwargs,
):
super().__init__(tie_word_embeddings=tie_word_embeddings, is_encoder_decoder=is_encoder_decoder, **kwargs)
if text_config is None:
text_config = {}
logger.info("text_config is None. Initializing the Pix2StructTextConfig with default values.")
if vision_config is None:
vision_config = {}
logger.info("vision_config is None. Initializing the Pix2StructVisionConfig with default values.")
self.text_config = Pix2StructTextConfig(**text_config)
self.vision_config = Pix2StructVisionConfig(**vision_config)
self.decoder_start_token_id = self.text_config.decoder_start_token_id
self.pad_token_id = self.text_config.pad_token_id
self.eos_token_id = self.text_config.eos_token_id
self.initializer_factor = initializer_factor
self.initializer_range = initializer_range
self.text_config.initializer_range = self.initializer_range
self.vision_config.initializer_range = self.initializer_range
self.is_vqa = is_vqa
@classmethod
def from_text_vision_configs(
cls, text_config: Pix2StructTextConfig, vision_config: Pix2StructVisionConfig, **kwargs
):
r"""
Instantiate a [`Pix2StructConfig`] (or a derived class) from pix2struct text model configuration and pix2struct
vision model configuration.
Returns:
[`Pix2StructConfig`]: An instance of a configuration object
"""
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
.\models\pix2struct\convert_pix2struct_original_pytorch_to_hf.py
import argparse
import os
import re
import torch
from flax.traverse_util import flatten_dict
from t5x import checkpoints
from transformers import (
AutoTokenizer,
Pix2StructConfig,
Pix2StructForConditionalGeneration,
Pix2StructImageProcessor,
Pix2StructProcessor,
Pix2StructTextConfig,
Pix2StructVisionConfig,
)
def get_flax_param(t5x_checkpoint_path):
flax_params = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
flax_params = flatten_dict(flax_params)
return flax_params
def rename_and_convert_flax_params(flax_dict):
converted_dict = {}
CONVERSION_MAPPING = {
"token_embedder": "embeddings",
"encoder_norm": "layernorm",
"kernel": "weight",
".out": ".output",
"scale": "weight",
"embedders_0.pos_embedding": "row_embedder.weight",
"embedders_1.pos_embedding": "column_embedder.weight",
}
DECODER_CONVERSION_MAPPING = {
"query": "attention.query",
"key": "attention.key",
"value": "attention.value",
"output.dense": "output",
"encoder_decoder_attention.o": "encoder_decoder_attention.attention.o",
"pre_self_attention_layer_norm": "self_attention.layer_norm",
"pre_cross_attention_layer_norm": "encoder_decoder_attention.layer_norm",
"mlp.": "mlp.DenseReluDense.",
"pre_mlp_layer_norm": "mlp.layer_norm",
"self_attention.o": "self_attention.attention.o",
"decoder.embeddings.embedding": "decoder.embed_tokens.weight",
"decoder.relpos_bias.rel_embedding": "decoder.layer.0.self_attention.attention.relative_attention_bias.weight",
"decoder.decoder_norm.weight": "decoder.final_layer_norm.weight",
"decoder.logits_dense.weight": "decoder.lm_head.weight",
}
for key in flax_dict.keys():
if "target" in key:
new_key = ".".join(key[1:])
for old, new in CONVERSION_MAPPING.items():
new_key = new_key.replace(old, new)
if "decoder" in new_key:
for old, new in DECODER_CONVERSION_MAPPING.items():
new_key = new_key.replace(old, new)
if "layers" in new_key and "decoder" not in new_key:
new_key = re.sub(r"layers_(\d+)", r"layer.\1", new_key)
new_key = new_key.replace("encoder", "encoder.encoder")
elif "layers" in new_key and "decoder" in new_key:
new_key = re.sub(r"layers_(\d+)", r"layer.\1", new_key)
converted_dict[new_key] = flax_dict[key]
converted_torch_dict = {}
for key in converted_dict.keys():
if ("embed_tokens" not in key) and ("embedder" not in key):
converted_torch_dict[key] = torch.from_numpy(converted_dict[key].T)
else:
converted_torch_dict[key] = torch.from_numpy(converted_dict[key])
return converted_torch_dict
def convert_pix2struct_original_pytorch_checkpoint_to_hf(
t5x_checkpoint_path, pytorch_dump_folder_path, use_large=False, is_vqa=False
):
flax_params = get_flax_param(t5x_checkpoint_path)
if not use_large:
encoder_config = Pix2StructVisionConfig()
decoder_config = Pix2StructTextConfig()
else:
encoder_config = Pix2StructVisionConfig(
hidden_size=1536, d_ff=3968, num_attention_heads=24, num_hidden_layers=18
)
decoder_config = Pix2StructTextConfig(hidden_size=1536, d_ff=3968, num_heads=24, num_layers=18)
config = Pix2StructConfig(
vision_config=encoder_config.to_dict(), text_config=decoder_config.to_dict(), is_vqa=is_vqa
)
model = Pix2StructForConditionalGeneration(config)
torch_params = rename_and_convert_flax_params(flax_params)
model.load_state_dict(torch_params)
tok = AutoTokenizer.from_pretrained("ybelkada/test-pix2struct-tokenizer")
image_processor = Pix2StructImageProcessor()
processor = Pix2StructProcessor(image_processor=image_processor, tokenizer=tok)
if use_large:
processor.image_processor.max_patches = 4096
processor.image_processor.is_vqa = True
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)
print("Model saved in {}".format(pytorch_dump_folder_path))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--t5x_checkpoint_path", default=None, type=str, help="Path to the original T5x checkpoint.")
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument("--use_large", action="store_true", help="Use large model.")
parser.add_argument("--is_vqa", action="store_true", help="Use large model.")
args = parser.parse_args()
convert_pix2struct_original_pytorch_checkpoint_to_hf(
args.t5x_checkpoint_path, args.pytorch_dump_folder_path, args.use_large
)
.\models\pix2struct\image_processing_pix2struct.py
"""Pix2Struct 的图像处理类"""
import io
import math
from typing import Dict, Optional, Union
import numpy as np
from huggingface_hub import hf_hub_download
from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_transforms import convert_to_rgb, normalize, to_channel_dimension_format, to_pil_image
from ...image_utils import (
ChannelDimension,
ImageInput,
get_image_size,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
valid_images,
)
from ...utils import TensorType, is_torch_available, is_vision_available, logging
from ...utils.import_utils import requires_backends
if is_vision_available():
import textwrap
from PIL import Image, ImageDraw, ImageFont
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
DEFAULT_FONT_PATH = "ybelkada/fonts"
def torch_extract_patches(image_tensor, patch_height, patch_width):
"""
从给定的图像张量中提取补丁的实用函数。返回形状为 (1, `patch_height`, `patch_width`, `num_channels`x `patch_height` x `patch_width`) 的张量
Args:
image_tensor (torch.Tensor):
要从中提取补丁的图像张量。
patch_height (int):
要提取的补丁的高度。
patch_width (int):
要提取的补丁的宽度。
"""
requires_backends(torch_extract_patches, ["torch"])
image_tensor = image_tensor.unsqueeze(0)
patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width))
patches = patches.reshape(image_tensor.size(0), image_tensor.size(1), patch_height, patch_width, -1)
patches = patches.permute(0, 4, 2, 3, 1).reshape(
image_tensor.size(2) // patch_height,
image_tensor.size(3) // patch_width,
image_tensor.size(1) * patch_height * patch_width,
)
return patches.unsqueeze(0)
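Shape check on a dummy image: the grid ends up as `height // patch_height` rows by `width // patch_width` columns, each patch flattened to `channels * patch_height * patch_width` values.

```python
import torch

image = torch.randn(3, 32, 48)   # (channels, height, width)
patches = torch_extract_patches(image, patch_height=16, patch_width=16)
print(patches.shape)             # torch.Size([1, 2, 3, 768])
```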
def render_text(
text: str,
text_size: int = 36,
text_color: str = "black",
background_color: str = "white",
left_padding: int = 5,
right_padding: int = 5,
top_padding: int = 5,
bottom_padding: int = 5,
font_bytes: Optional[bytes] = None,
font_path: Optional[str] = None,
) -> "Image.Image":
"""
Utility function that renders a piece of text as a PIL image.
Args:
text (`str`):
The text to render.
text_size (`int`, *optional*, defaults to 36):
The font size of the rendered text.
text_color (`str`, *optional*, defaults to `"black"`):
The color of the text.
background_color (`str`, *optional*, defaults to `"white"`):
The color of the background.
left_padding / right_padding / top_padding / bottom_padding (`int`, *optional*, defaults to 5):
Padding, in pixels, added around the rendered text.
font_bytes (`bytes`, *optional*):
The raw bytes of the font to use; mutually exclusive with `font_path`.
font_path (`str`, *optional*):
The path of the font file to use.
"""
requires_backends(render_text, "vision")
# Wrap the text so that no line exceeds 80 characters.
wrapper = textwrap.TextWrapper(width=80)
lines = wrapper.wrap(text=text)
wrapped_text = "\n".join(lines)
if font_bytes is not None and font_path is None:
font = io.BytesIO(font_bytes)
elif font_path is not None:
font = font_path
else:
font = hf_hub_download(DEFAULT_FONT_PATH, "Arial.TTF")
font = ImageFont.truetype(font, encoding="UTF-8", size=text_size)
# Use a temporary canvas to measure the width and height of the rendered text.
temp_draw = ImageDraw.Draw(Image.new("RGB", (1, 1), background_color))
_, _, text_width, text_height = temp_draw.textbbox((0, 0), wrapped_text, font)
# Create the actual image with some padding around the text.
image_width = text_width + left_padding + right_padding
image_height = text_height + top_padding + bottom_padding
image = Image.new("RGB", (image_width, image_height), background_color)
draw = ImageDraw.Draw(image)
draw.text(xy=(left_padding, top_padding), text=wrapped_text, fill=text_color, font=font)
return image
def render_header(
image: np.ndarray, header: str, input_data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs
):
"""
Renders the input text as a header on the input image.
Adapted from:
https://github.com/google-research/pix2struct/blob/0e1779af0f4db4b652c1d92b3bbd2550a7399123/pix2struct/preprocessing/preprocessing_utils.py#L87
Args:
image (`np.ndarray`):
The image to render the header on.
header (`str`):
The header text.
input_data_format (`Union[ChannelDimension, str]`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
**kwargs:
Additional keyword arguments forwarded to `render_text`.
Returns:
`np.ndarray`: The image with the header rendered on top.
"""
requires_backends(render_header, "vision")
# Convert the input image to a PIL image if needed.
image = to_pil_image(image, input_data_format=input_data_format)
# Render the header text into its own image.
header_image = render_text(header, **kwargs)
# The new width is the wider of the header image and the original image.
new_width = max(header_image.width, image.width)
# Resize both images to the new width while preserving their aspect ratios.
new_height = int(image.height * (new_width / image.width))
new_header_height = int(header_image.height * (new_width / header_image.width))
# Create a white canvas tall enough for the header plus the original image.
new_image = Image.new("RGB", (new_width, new_height + new_header_height), "white")
# Paste the resized header at the top and the resized original image below it.
new_image.paste(header_image.resize((new_width, new_header_height)), (0, 0))
new_image.paste(image.resize((new_width, new_height)), (0, new_header_height))
# Convert back to a NumPy array, keeping a channels-last layout.
new_image = to_numpy_array(new_image)
if infer_channel_dimension_format(new_image) == ChannelDimension.LAST:
new_image = to_channel_dimension_format(new_image, ChannelDimension.LAST)
return new_image
r"""
Constructs a Pix2Struct image processor.
构造一个 Pix2Struct 图像处理器。
Args:
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
是否将图像转换为 RGB。
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method. According to Pix2Struct paper and code, the image is normalized with its own mean and standard
deviation.
是否对图像进行归一化。可以通过 `preprocess` 方法中的 `do_normalize` 参数进行覆盖。
根据 Pix2Struct 论文和代码,图像使用其自身的均值和标准差进行归一化。
patch_size (`Dict[str, int]`, *optional*, defaults to `{"height": 16, "width": 16}`):
The patch size to use for the image. According to Pix2Struct paper and code, the patch size is 16x16.
图像使用的补丁大小。根据 Pix2Struct 论文和代码,补丁大小为 16x16。
max_patches (`int`, *optional*, defaults to 2048):
The maximum number of patches to extract from the image as per the [Pix2Struct
paper](https://arxiv.org/pdf/2210.03347.pdf).
从图像中提取的最大补丁数,根据 Pix2Struct 论文。
is_vqa (`bool`, *optional*, defaults to `False`):
Whether or not the image processor is for the VQA task. If `True` and `header_text` is passed in, text is
rendered onto the input images.
图像处理器是否用于 VQA 任务。如果为 `True` 并且传入了 `header_text`,则将文本渲染到输入图像上。
"""
model_input_names = ["flattened_patches"]
def __init__(
self,
do_convert_rgb: bool = True,
do_normalize: bool = True,
patch_size: Dict[str, int] = None,
max_patches: int = 2048,
is_vqa: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)  # Call the parent initializer with any remaining keyword arguments
self.patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16}  # Default patch size is 16x16
self.do_normalize = do_normalize  # Whether to normalize the image
self.do_convert_rgb = do_convert_rgb  # Whether to convert the image to RGB
self.max_patches = max_patches  # Maximum number of extracted patches
self.is_vqa = is_vqa  # Whether the processor is used for the VQA task
def extract_flattened_patches(
self,
image: np.ndarray,
max_patches: int,
patch_size: dict,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Extract flattened patches from an image.
Args:
image (`np.ndarray`):
Image to extract flattened patches from.
max_patches (`int`):
Maximum number of patches to extract.
patch_size (`dict`):
Dictionary containing the patch height and width.
Returns:
result (`np.ndarray`):
A sequence of `max_patches` flattened patches.
"""
# This method requires the torch backend
requires_backends(self.extract_flattened_patches, "torch")
# Convert the image to a torch tensor in channels-first format
image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)
image = torch.from_numpy(image)
# Get the patch height and width
patch_height, patch_width = patch_size["height"], patch_size["width"]
# Get the image height and width
image_height, image_width = get_image_size(image, ChannelDimension.FIRST)
# Maximize the scale so that the resized image fits within the given maximum number of patches
scale = math.sqrt(max_patches * (patch_height / image_height) * (patch_width / image_width))
num_feasible_rows = max(min(math.floor(scale * image_height / patch_height), max_patches), 1)
num_feasible_cols = max(min(math.floor(scale * image_width / patch_width), max_patches), 1)
resized_height = max(num_feasible_rows * patch_height, 1)
resized_width = max(num_feasible_cols * patch_width, 1)
# Resize the image with bilinear interpolation
image = torch.nn.functional.interpolate(
image.unsqueeze(0),
size=(resized_height, resized_width),
mode="bilinear",
align_corners=False,
antialias=True,
).squeeze(0)
# Extract the patches from the image
# [1, rows, columns, patch_height * patch_width * image_channels]
patches = torch_extract_patches(image, patch_height, patch_width)
# Read off the patch grid shape
patches_shape = patches.shape
rows = patches_shape[1]
columns = patches_shape[2]
depth = patches_shape[3]
# Reshape the patch tensor for further processing
# [rows * columns, patch_height * patch_width * image_channels]
patches = patches.reshape([rows * columns, depth])
# Build the row and column index tensors
# [rows * columns, 1]
row_ids = torch.arange(rows).reshape([rows, 1]).repeat(1, columns).reshape([rows * columns, 1])
col_ids = torch.arange(columns).reshape([1, columns]).repeat(rows, 1).reshape([rows * columns, 1])
# Offset the indices by one so that zero can be reserved for padding
row_ids += 1
col_ids += 1
# Prepare the additional patch features
# [rows * columns, 1]
row_ids = row_ids.to(torch.float32)
col_ids = col_ids.to(torch.float32)
# Concatenate the row ids, column ids and patch data into the final result
# [rows * columns, 2 + patch_height * patch_width * image_channels]
result = torch.cat([row_ids, col_ids, patches], -1)
# Pad the result so that the number of output patches equals max_patches
# [max_patches, 2 + patch_height * patch_width * image_channels]
result = torch.nn.functional.pad(result, [0, 0, 0, max_patches - (rows * columns)]).float()
# Convert the result to a NumPy array
result = to_numpy_array(result)
return result
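The resize arithmetic above picks the largest patch grid that respects both the aspect ratio and `max_patches`; a worked example for a hypothetical 480x640 image with the default settings:

```python
import math

max_patches, patch_height, patch_width = 2048, 16, 16
image_height, image_width = 480, 640
scale = math.sqrt(max_patches * (patch_height / image_height) * (patch_width / image_width))
num_rows = max(min(math.floor(scale * image_height / patch_height), max_patches), 1)
num_cols = max(min(math.floor(scale * image_width / patch_width), max_patches), 1)
print(round(scale, 3), num_rows, num_cols, num_rows * num_cols)
# 1.306 39 52 2028 -> 2028 patches <= max_patches, roughly preserving the 3:4 aspect ratio
```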
# Standardize the image so that its values have zero mean and unit standard deviation
def normalize(
self,
image: np.ndarray,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Normalize an image. image = (image - image_mean) / image_std.
The image std is to mimic the tensorflow implementation of the `per_image_standardization`:
https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization
Args:
image (`np.ndarray`):
Image to normalize.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
# If the image is uint8, convert it to float32
if image.dtype == np.uint8:
image = image.astype(np.float32)
# Compute the mean and standard deviation of the image
mean = np.mean(image)
std = np.std(image)
adjusted_stddev = max(std, 1.0 / math.sqrt(np.prod(image.shape)))
# Delegate to the shared normalize helper using the per-image statistics
return normalize(
image,
mean=mean,
std=adjusted_stddev,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
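A quick numeric check of the per-image standardization: the normalized image has (approximately) zero mean and unit standard deviation.

```python
import math
import numpy as np

image = np.arange(12, dtype=np.float32).reshape(3, 4)
mean, std = np.mean(image), np.std(image)
adjusted_stddev = max(std, 1.0 / math.sqrt(np.prod(image.shape)))
normalized = (image - mean) / adjusted_stddev
print(round(float(normalized.mean()), 6), round(float(normalized.std()), 6))  # ~0.0 1.0
```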
# Image preprocessing entry point: RGB conversion, normalization and patch extraction
def preprocess(
self,
images: ImageInput,
header_text: Optional[str] = None,
do_convert_rgb: bool = None,
do_normalize: Optional[bool] = None,
max_patches: Optional[int] = None,
patch_size: Optional[Dict[str, int]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
):
.\models\pix2struct\modeling_pix2struct.py
""" Pix2Struct modeling file"""
import math
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
CausalLMOutputWithCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
DUMMY_INPUTS,
DUMMY_MASK,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_fx_proxy,
logging,
replace_return_docstrings,
)
from .configuration_pix2struct import Pix2StructConfig, Pix2StructTextConfig, Pix2StructVisionConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Pix2StructConfig"
PIX2STRUCT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"google/pix2struct-textcaps-base",
"google/pix2struct-textcaps-large",
"google/pix2struct-base",
"google/pix2struct-large",
"google/pix2struct-ai2d-base",
"google/pix2struct-ai2d-large",
"google/pix2struct-widget-captioning-base",
"google/pix2struct-widget-captioning-large",
"google/pix2struct-screen2words-base",
"google/pix2struct-screen2words-large",
"google/pix2struct-docvqa-base",
"google/pix2struct-docvqa-large",
"google/pix2struct-ocrvqa-base",
"google/pix2struct-ocrvqa-large",
"google/pix2struct-chartqa-base",
"google/pix2struct-inforgraphics-vqa-base",
"google/pix2struct-inforgraphics-vqa-large",
]
class Pix2StructLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
try:
from apex.normalization import FusedRMSNorm
Pix2StructLayerNorm = FusedRMSNorm
logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of Pix2StructLayerNorm")
except ImportError:
pass
except Exception:
logger.warning("Discovered apex but it failed to load, falling back to Pix2StructLayerNorm")
pass
ALL_LAYERNORM_LAYERS.append(Pix2StructLayerNorm)
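To make the T5-style normalization concrete: each hidden vector is divided by its root mean square (no mean subtraction, no bias), then scaled by the learned weight, which is all ones at initialization. A self-contained numeric check of that forward math:

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
weight = torch.ones(4)                           # the layer's learned scale, ones at init
variance = x.pow(2).mean(-1, keepdim=True)       # mean of squares; the mean is never subtracted
out = weight * x * torch.rsqrt(variance + 1e-6)
print(out)  # tensor([[0.3651, 0.7303, 1.0954, 1.4606]]): each value divided by the RMS (~2.739)
```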
class Pix2StructVisionEmbeddings(nn.Module):
r"""
Construct the embeddings from patch. In `Pix2Struct` the input is different from classic Vision-transformer models.
Here the input is a sequence of `seq_len` flattened patches that also combines padding patches (tokens). Each patch
is represented by a vector of `hidden_size` values.
"""
def __init__(self, config: Pix2StructConfig) -> None:
super().__init__()
self.patch_projection = nn.Linear(config.patch_embed_hidden_size, config.hidden_size)
self.row_embedder = nn.Embedding(config.seq_len, config.hidden_size)
self.column_embedder = nn.Embedding(config.seq_len, config.hidden_size)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, flattened_patches: torch.Tensor) -> torch.Tensor:
row_indices = flattened_patches[:, :, 0].long()
col_indices = flattened_patches[:, :, 1].long()
flattened_patches = flattened_patches[:, :, 2:]
embeddings = self.patch_projection(flattened_patches)
row_embeddings = self.row_embedder(row_indices)
col_embeddings = self.column_embedder(col_indices)
embeddings = embeddings + row_embeddings + col_embeddings
embeddings = self.dropout(embeddings)
return embeddings
class Pix2StructVisionAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_size = config.hidden_size
self.key_value_proj_dim = config.d_kv
self.n_heads = config.num_attention_heads
self.dropout = config.attention_dropout
self.inner_dim = self.n_heads * self.key_value_proj_dim
self.query = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
self.key = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
self.value = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
self.output = nn.Linear(self.inner_dim, self.hidden_size, bias=False)
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
attention_mask=None,
position_bias=None,
layer_head_mask=None,
output_attentions=False,
):
"""
Self-attention block
"""
batch_size, seq_length = hidden_states.shape[:2]
def to_projection_shape(states):
"""将输入状态调整为投影形状"""
return states.contiguous().view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
query_states = to_projection_shape(self.query(hidden_states))
key_states = to_projection_shape(self.key(hidden_states))
value_states = to_projection_shape(self.value(hidden_states))
scores = torch.matmul(query_states, key_states.transpose(3, 2))
if position_bias is None:
position_bias = torch.zeros(
(1, self.n_heads, seq_length, seq_length), device=scores.device, dtype=scores.dtype
)
if self.gradient_checkpointing and self.training:
position_bias.requires_grad = True
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length), device=scores.device, dtype=scores.dtype)
if attention_mask.dim() == 2:
position_bias = position_bias + attention_mask[:, None, None, :].to(position_bias.device)
else:
position_bias = position_bias + attention_mask.to(position_bias.device)
position_bias = 1 - position_bias
position_bias_masked = position_bias.masked_fill(position_bias == 1, torch.finfo(scores.dtype).min)
scores += position_bias_masked
scores = torch.max(scores, torch.tensor(torch.finfo(scores.dtype).min))
attn_weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).type_as(scores)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
if layer_head_mask is not None:
attn_weights = attn_weights * layer_head_mask
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
attn_output = self.output(attn_output)
outputs = (attn_output,) + (position_bias,)
if output_attentions:
outputs = outputs + (attn_weights,)
return outputs
class Pix2StructVisionMlp(nn.Module):
def __init__(self, config: Pix2StructVisionConfig):
super().__init__()
self.wi_0 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
self.wi_1 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
self.wo = nn.Linear(config.d_ff, config.hidden_size, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
self.act = ACT2FN[config.dense_act_fn]
def forward(self, hidden_states):
hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states)
if (
isinstance(self.wo.weight, torch.Tensor)
and hidden_states.dtype != self.wo.weight.dtype
and self.wo.weight.dtype != torch.int8
):
hidden_states = hidden_states.to(self.wo.weight.dtype)
hidden_states = self.wo(hidden_states)
return hidden_states
class Pix2StructVisionLayer(nn.Module):
def __init__(self, config: Pix2StructConfig) -> None:
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = Pix2StructVisionAttention(config)
self.mlp = Pix2StructVisionMlp(config)
self.pre_mlp_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.pre_attention_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
residual = hidden_states
hidden_states = self.pre_attention_layer_norm(hidden_states)
self_attention_outputs = self.attention(
hidden_states,
attention_mask=attention_mask,
layer_head_mask=head_mask,
output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:]
hidden_states = attention_output + residual
layer_output = self.pre_mlp_layer_norm(hidden_states)
layer_output = self.mlp(layer_output) + hidden_states
outputs = (layer_output,) + outputs
return outputs
class Pix2StructVisionEncoder(nn.Module):
def __init__(self, config: Pix2StructConfig) -> None:
super().__init__()
self.config = config
self.layer = nn.ModuleList([Pix2StructVisionLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class Pix2StructPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = Pix2StructConfig
@property
def dummy_inputs(self):
input_ids = torch.tensor(DUMMY_INPUTS)
input_mask = torch.tensor(DUMMY_MASK)
dummy_inputs = {
"decoder_input_ids": input_ids,
"input_ids": input_ids,
"decoder_attention_mask": input_mask,
}
return dummy_inputs
def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id
pad_token_id = self.config.pad_token_id
if decoder_start_token_id is None:
raise ValueError(
"self.model.config.decoder_start_token_id has to be defined. In Pix2Struct it is usually set to the pad_token_id. "
"See Pix2Struct docs for more information."
)
if is_torch_fx_proxy(input_ids):
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
else:
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError("self.model.config.pad_token_id has to be defined.")
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
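# Worked example of `_shift_right`: labels become decoder inputs by prepending the
# decoder start token and dropping the last position; `-100` label padding is replaced
# with the pad token. Token ids below are made up; Pix2Struct usually uses the pad
# token id as the decoder start token id.
import torch

labels = torch.tensor([[42, 77, -100, -100]])    # -100 marks ignored label positions
decoder_start_token_id = pad_token_id = 0        # assumed ids for this example
shifted = labels.new_zeros(labels.shape)
shifted[..., 1:] = labels[..., :-1].clone()
shifted[..., 0] = decoder_start_token_id
shifted.masked_fill_(shifted == -100, pad_token_id)
print(shifted)                                   # tensor([[ 0, 42, 77,  0]])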
@add_start_docstrings(
"The bare Pix2StructVision Model transformer outputting raw hidden-states without any specific head on top.",
PIX2STRUCT_VISION_START_DOCSTRING,
)
class Pix2StructVisionModel(Pix2StructPreTrainedModel):
config_class = Pix2StructVisionConfig
main_input_name = "flattened_patches"
supports_gradient_checkpointing = True
_no_split_modules = ["Pix2StructVisionLayer"]
def __init__(self, config: Pix2StructConfig):
super().__init__(config)
self.config = config
self.embeddings = Pix2StructVisionEmbeddings(config)
self.encoder = Pix2StructVisionEncoder(config)
self.layernorm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_init()
def get_input_embeddings(self):
return self.embeddings.patch_projection
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def forward(
self,
flattened_patches: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if flattened_patches is None:
raise ValueError("You have to specify flattened_patches")
if attention_mask is None:
attention_mask = (flattened_patches.sum(dim=-1) != 0).float()
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(flattened_patches)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
if not return_dict:
head_outputs = (sequence_output,)
return head_outputs + encoder_outputs[1:]
return BaseModelOutput(
last_hidden_state=sequence_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
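# How Pix2StructVisionModel derives the default attention mask when none is passed:
# padded patch rows are all zeros, so any row whose values sum to zero is masked out.
# Tiny illustrative tensor (a real flattened patch row is much wider).
import torch

flattened_patches = torch.tensor(
    [[[0.5, 0.2, 0.1],   # real patch
      [0.0, 0.0, 0.0],   # padding row
      [0.3, 0.0, 0.4]]]  # real patch
)
attention_mask = (flattened_patches.sum(dim=-1) != 0).float()
print(attention_mask)    # tensor([[1., 0., 1.]])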
class Pix2StructTextDenseGatedActDense(nn.Module):
def __init__(self, config: Pix2StructTextConfig):
super().__init__()
self.wi_0 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
self.wi_1 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
self.wo = nn.Linear(config.d_ff, config.hidden_size, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
self.act = ACT2FN[config.dense_act_fn]
def forward(self, hidden_states):
hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states)
if (
isinstance(self.wo.weight, torch.Tensor)
and hidden_states.dtype != self.wo.weight.dtype
and self.wo.weight.dtype != torch.int8
):
hidden_states = hidden_states.to(self.wo.weight.dtype)
hidden_states = self.wo(hidden_states)
return hidden_states
class Pix2StructTextLayerFF(nn.Module):
def __init__(self, config: Pix2StructTextConfig):
super().__init__()
self.DenseReluDense = Pix2StructTextDenseGatedActDense(config)
self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, hidden_states):
forwarded_states = self.layer_norm(hidden_states)
forwarded_states = self.DenseReluDense(forwarded_states)
hidden_states = hidden_states + self.dropout(forwarded_states)
return hidden_states
class Pix2StructTextAttention(nn.Module):
    def __init__(self, config: Pix2StructTextConfig, has_relative_attention_bias=False):
super().__init__()
self.has_relative_attention_bias = has_relative_attention_bias
self.relative_attention_num_buckets = config.relative_attention_num_buckets
self.relative_attention_max_distance = config.relative_attention_max_distance
self.hidden_size = config.hidden_size
self.key_value_proj_dim = config.d_kv
self.n_heads = config.num_heads
self.dropout = config.dropout_rate
self.inner_dim = self.n_heads * self.key_value_proj_dim
self.query = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
self.key = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
self.value = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
self.output = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
self.pruned_heads = set()
self.gradient_checkpointing = False
@staticmethod
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
"""
Adapted from Mesh Tensorflow:
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
Translate relative position to a bucket number for relative attention. The relative position is defined as
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the model has been trained on
Args:
relative_position: an int32 Tensor - the relative position between memory and query
bidirectional: a boolean - whether the attention is bidirectional or not
num_buckets: an integer - number of buckets to categorize relative positions into
max_distance: an integer - maximum distance to consider for bucketing
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
max_exact = num_buckets // 2
is_small = relative_position < max_exact
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_position_if_large = torch.min(
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
def compute_bias(self, query_length, key_length, device=None):
"""Compute binned relative position bias"""
if device is None:
device = self.relative_attention_bias.weight.device
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position
relative_position_bucket = self._relative_position_bucket(
relative_position,
bidirectional=False,
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(relative_position_bucket)
values = values.permute([2, 0, 1]).unsqueeze(0)
return values
def forward(
self,
hidden_states,
mask=None,
key_value_states=None,
position_bias=None,
past_key_value=None,
layer_head_mask=None,
query_length=None,
use_cache=False,
        output_attentions=False,
    ):
        pass
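# Worked example of the relative-position bucketing above, in the causal setting used
# by `compute_bias` (bidirectional=False): distances smaller than max_exact
# (= num_buckets // 2 = 16) each get their own bucket, larger distances share
# logarithmically spaced buckets, and everything beyond max_distance is clamped into
# the last bucket. Illustrative snippet using the module path of this file.
import torch
from transformers.models.pix2struct.modeling_pix2struct import Pix2StructTextAttention

relative_position = torch.tensor([0, -1, -2, -15, -16, -64, -200])  # memory_pos - query_pos
buckets = Pix2StructTextAttention._relative_position_bucket(
    relative_position, bidirectional=False, num_buckets=32, max_distance=128
)
print(buckets)  # tensor([ 0,  1,  2, 15, 16, 26, 31])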
class Pix2StructTextLayerSelfAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=has_relative_attention_bias)
self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
hidden_states,
attention_mask=None,
position_bias=None,
layer_head_mask=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
):
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.attention(
normed_hidden_states,
mask=attention_mask,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = hidden_states + self.dropout(attention_output[0])
outputs = (hidden_states,) + attention_output[1:]
return outputs
class Pix2StructTextLayerCrossAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=False)
self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
hidden_states,
key_value_states,
attention_mask=None,
position_bias=None,
layer_head_mask=None,
past_key_value=None,
use_cache=False,
query_length=None,
output_attentions=False,
):
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.attention(
normed_hidden_states,
mask=attention_mask,
key_value_states=key_value_states,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
query_length=query_length,
output_attentions=output_attentions,
)
layer_output = hidden_states + self.dropout(attention_output[0])
outputs = (layer_output,) + attention_output[1:]
return outputs
class Pix2StructTextBlock(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.self_attention = Pix2StructTextLayerSelfAttention(
config, has_relative_attention_bias=has_relative_attention_bias
)
self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(config)
self.mlp = Pix2StructTextLayerFF(config)
def forward(
self,
hidden_states,
attention_mask=None,
position_bias=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
layer_head_mask=None,
cross_attn_layer_head_mask=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
return_dict=True,
):
pass
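# The forward body of Pix2StructTextBlock is omitted above (`pass`). A simplified,
# assumed sketch of its control flow, ignoring cache handling (the real implementation
# also splits `past_key_value` between self- and cross-attention, clamps float16
# activations, and returns present key/value states and attention weights):
#
#     self_attn_out = self.self_attention(
#         hidden_states, attention_mask=attention_mask, position_bias=position_bias,
#         layer_head_mask=layer_head_mask, past_key_value=past_key_value, use_cache=use_cache,
#     )
#     hidden_states = self_attn_out[0]
#     if encoder_hidden_states is not None:
#         cross_attn_out = self.encoder_decoder_attention(
#             hidden_states, encoder_hidden_states, attention_mask=encoder_attention_mask,
#             position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask,
#         )
#         hidden_states = cross_attn_out[0]
#     hidden_states = self.mlp(hidden_states)
#     return (hidden_states,) + self_attn_out[1:]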
PIX2STRUCT_START_DOCSTRING = r"""
The Pix2Struct model was proposed in [Pix2Struct: Screenshot Parsing as Pretraining for Visual Language
Understanding](https://arxiv.org/abs/2210.03347) by Kenton Lee, Mandar Joshi, Iulia Turc, Hexiang Hu, Fangyu Liu,
Julian Eisenschlos, Urvashi Khandelwal, Peter Shaw, Ming-Wei Chang, Kristina Toutanova. It's an encoder-decoder
transformer pre-trained in an image-to-text setting.
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
and behavior.
Parameters:
config (Union[`Pix2StructConfig`, `Pix2StructTextConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
PIX2STRUCT_TEXT_INPUTS_DOCSTRING = r"""
"""
PIX2STRUCT_INPUTS_DOCSTRING = r"""
"""
@add_start_docstrings(
"The standalone text decoder of Pix2Struct",
PIX2STRUCT_START_DOCSTRING,
)
class Pix2StructTextModel(Pix2StructPreTrainedModel):
config_class = Pix2StructTextConfig
_no_split_modules = ["Pix2StructTextBlock"]
_tied_weights_keys = ["lm_head.weight"]
supports_gradient_checkpointing = True
def __init__(self, config):
super().__init__(config)
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layer = nn.ModuleList(
[Pix2StructTextBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
)
self.final_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
self.gradient_checkpointing = False
def _reorder_cache(self, past_key_values, beam_idx):
if past_key_values is None:
logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
return past_key_values
reordered_decoder_past = ()
for layer_past_states in past_key_values:
reordered_layer_past_states = ()
for layer_past_state in layer_past_states:
reordered_layer_past_states = reordered_layer_past_states + (
layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
)
if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
raise ValueError(
f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
)
if len(reordered_layer_past_states) != len(layer_past_states):
raise ValueError(
f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
)
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
return reordered_decoder_past
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, new_embeddings):
self.embed_tokens = new_embeddings
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
@add_start_docstrings_to_model_forward(PIX2STRUCT_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
        **kwargs,
    ):
        pass
@add_start_docstrings(
"A conditional generation model with a language modeling head. Can be used for sequence generation tasks.",
PIX2STRUCT_START_DOCSTRING,
)
class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
config_class = Pix2StructConfig
main_input_name = "flattened_patches"
_tied_weights_keys = ["decoder.lm_head.weight"]
def __init__(self, config: Pix2StructConfig):
super().__init__(config)
self.encoder = Pix2StructVisionModel(config.vision_config)
self.decoder = Pix2StructTextModel(config.text_config)
self.is_vqa = config.is_vqa
self.post_init()
def get_input_embeddings(self):
return self.decoder.get_input_embeddings()
def set_input_embeddings(self, new_embeddings):
self.decoder.set_input_embeddings(new_embeddings)
def get_output_embeddings(self) -> nn.Module:
return self.decoder.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.decoder.set_output_embeddings(new_embeddings)
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
model_embeds = self.decoder.resize_token_embeddings(new_num_tokens)
self.config.text_config.vocab_size = new_num_tokens
return model_embeds
def get_decoder(self):
return self.decoder
def get_encoder(self):
return self.encoder
@add_start_docstrings_to_model_forward(PIX2STRUCT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
flattened_patches: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
decoder_inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        pass
def prepare_inputs_for_generation(
self,
input_ids,
flattened_patches: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
past_key_values=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
if decoder_attention_mask is None:
decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device)
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
return {
"flattened_patches": flattened_patches,
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
}
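# Hedged end-to-end usage sketch for Pix2StructForConditionalGeneration. The checkpoint
# name and image URL are assumptions taken from common examples, not from this file;
# the processor produces `flattened_patches` and `attention_mask` for the vision encoder.
import requests
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")

url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")
generated_ids = model.generate(**inputs, max_new_tokens=50)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))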