Transformers Source Code Analysis (92)
.\models\rag\tokenization_rag.py
import os
import warnings
from typing import List, Optional
from ...tokenization_utils_base import BatchEncoding
from ...utils import logging
from .configuration_rag import RagConfig
logger = logging.get_logger(__name__)
class RagTokenizer:
def __init__(self, question_encoder, generator):
self.question_encoder = question_encoder
self.generator = generator
self.current_tokenizer = self.question_encoder
def save_pretrained(self, save_directory):
if os.path.isfile(save_directory):
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)
question_encoder_path = os.path.join(save_directory, "question_encoder_tokenizer")
generator_path = os.path.join(save_directory, "generator_tokenizer")
self.question_encoder.save_pretrained(question_encoder_path)
self.generator.save_pretrained(generator_path)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
from ..auto.tokenization_auto import AutoTokenizer
config = kwargs.pop("config", None)
if config is None:
config = RagConfig.from_pretrained(pretrained_model_name_or_path)
question_encoder = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path, config=config.question_encoder, subfolder="question_encoder_tokenizer"
)
generator = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path, config=config.generator, subfolder="generator_tokenizer"
)
return cls(question_encoder=question_encoder, generator=generator)
def __call__(self, *args, **kwargs):
return self.current_tokenizer(*args, **kwargs)
def batch_decode(self, *args, **kwargs):
return self.generator.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
return self.generator.decode(*args, **kwargs)
def _switch_to_input_mode(self):
self.current_tokenizer = self.question_encoder
def _switch_to_target_mode(self):
self.current_tokenizer = self.generator
    def prepare_seq2seq_batch(
        self,
        src_texts: List[str],
        tgt_texts: Optional[List[str]] = None,
        max_length: Optional[int] = None,
        max_target_length: Optional[int] = None,
        padding: str = "longest",
        return_tensors: str = None,
        truncation: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        warnings.warn(
"`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of 🤗 Transformers. Use the "
"regular `__call__` method to prepare your inputs and the tokenizer under the `with_target_tokenizer` "
"context manager to prepare your targets. See the documentation of your specific tokenizer for more "
"details",
FutureWarning,
)
if max_length is None:
max_length = self.current_tokenizer.model_max_length
model_inputs = self(
src_texts,
add_special_tokens=True,
return_tensors=return_tensors,
max_length=max_length,
padding=padding,
truncation=truncation,
**kwargs,
)
if tgt_texts is None:
return model_inputs
if max_target_length is None:
max_target_length = self.current_tokenizer.model_max_length
labels = self(
text_target=tgt_texts,
add_special_tokens=True,
return_tensors=return_tensors,
padding=padding,
max_length=max_target_length,
truncation=truncation,
**kwargs,
)
model_inputs["labels"] = labels["input_ids"]
return model_inputs
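Reading the class as a whole: encoding goes through the question-encoder tokenizer, decoding through the generator tokenizer. A minimal usage sketch, assuming the public `facebook/rag-token-nq` checkpoint (which stores both sub-tokenizers in the subfolders `from_pretrained` expects); `generated_ids` below is a hypothetical tensor of ids produced by a RAG model:

```python
from transformers import RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")

# __call__ delegates to current_tokenizer, which defaults to the question-encoder tokenizer
inputs = tokenizer("who holds the record in 100m freestyle", return_tensors="pt")

# decode/batch_decode always delegate to the generator tokenizer, since that is the
# vocabulary the model generates in (e.g. for ids returned by model.generate):
# answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
```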
.\models\rag\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {
"configuration_rag": ["RagConfig"],
"retrieval_rag": ["RagRetriever"],
"tokenization_rag": ["RagTokenizer"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_rag"] = [
"RagModel",
"RagPreTrainedModel",
"RagSequenceForGeneration",
"RagTokenForGeneration",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_rag"] = [
"TFRagModel",
"TFRagPreTrainedModel",
"TFRagSequenceForGeneration",
"TFRagTokenForGeneration",
]
if TYPE_CHECKING:
from .configuration_rag import RagConfig
from .retrieval_rag import RagRetriever
from .tokenization_rag import RagTokenizer
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_rag import (
TFRagModel,
TFRagPreTrainedModel,
TFRagSequenceForGeneration,
TFRagTokenForGeneration,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
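Because the module is swapped out for a `_LazyModule`, importing the package stays cheap: each backend-specific submodule is imported only when one of its attributes is first accessed. A small sketch of the effect:

```python
# Importing the rag package does not import torch or TensorFlow yet.
from transformers.models import rag

config_cls = rag.RagConfig  # resolved from configuration_rag immediately
model_cls = rag.RagModel    # first access triggers the import of modeling_rag (requires torch)
```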
.\models\realm\configuration_realm.py
""" REALM model configuration."""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
REALM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google/realm-cc-news-pretrained-embedder": (
"https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/config.json"
),
"google/realm-cc-news-pretrained-encoder": (
"https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/config.json"
),
"google/realm-cc-news-pretrained-scorer": (
"https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/config.json"
),
"google/realm-cc-news-pretrained-openqa": (
"https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/config.json"
),
"google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/config.json",
"google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/config.json",
"google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/config.json",
"google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/config.json",
}
class RealmConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of
1. [`RealmEmbedder`]
2. [`RealmScorer`]
3. [`RealmKnowledgeAugEncoder`]
4. [`RealmRetriever`]
5. [`RealmReader`]
6. [`RealmForOpenQA`]
    It is used to instantiate a REALM model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the REALM
[google/realm-cc-news-pretrained-embedder](https://huggingface.co/google/realm-cc-news-pretrained-embedder)
architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```
>>> from transformers import RealmConfig, RealmEmbedder
>>> # Initializing a REALM realm-cc-news-pretrained-* style configuration
>>> configuration = RealmConfig()
    >>> # Initializing a model (with random weights) from the google/realm-cc-news-pretrained-embedder style configuration
    >>> model = RealmEmbedder(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
.\models\realm\modeling_realm.py
import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
MaskedLMOutput,
ModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_realm import RealmConfig
logger = logging.get_logger(__name__)
_EMBEDDER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-embedder"
_ENCODER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-encoder"
_SCORER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-scorer"
_CONFIG_FOR_DOC = "RealmConfig"
REALM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"google/realm-cc-news-pretrained-embedder",
"google/realm-cc-news-pretrained-encoder",
"google/realm-cc-news-pretrained-scorer",
"google/realm-cc-news-pretrained-openqa",
"google/realm-orqa-nq-openqa",
"google/realm-orqa-nq-reader",
"google/realm-orqa-wq-openqa",
"google/realm-orqa-wq-reader",
]
def load_tf_weights_in_realm(model, config, tf_checkpoint_path):
"""Load tf checkpoints in a pytorch model."""
try:
import re
import numpy as np
import tensorflow as tf
except ImportError:
logger.error(
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
tf_path = os.path.abspath(tf_checkpoint_path)
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
logger.info(f"Loading TF weight {name} with shape {shape}")
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array)
return model
class RealmEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
if token_type_ids is None:
if hasattr(self, "token_type_ids"):
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class RealmSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
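`transpose_for_scores` is pure reshaping: it splits the flat hidden dimension into per-head slices and moves the head axis in front of the sequence axis. A standalone shape check (the sizes below are made up):

```python
import torch

batch, seq_len, num_heads, head_size = 2, 5, 12, 64
x = torch.randn(batch, seq_len, num_heads * head_size)  # e.g. the output of self.query

x = x.view(batch, seq_len, num_heads, head_size)  # (2, 5, 12, 64)
x = x.permute(0, 2, 1, 3)                         # (2, 12, 5, 64)
assert x.shape == (batch, num_heads, seq_len, head_size)
```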
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        ...
class RealmSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class RealmAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = RealmSelfAttention(config, position_embedding_type=position_embedding_type)
self.output = RealmSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
class RealmIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class RealmOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class RealmLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = RealmAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = RealmAttention(config, position_embedding_type="absolute")
self.intermediate = RealmIntermediate(config)
self.output = RealmOutput(config)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0]
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:]
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
" by setting `config.add_cross_attention=True`"
)
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1]
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class RealmEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([RealmLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
class RealmPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
@dataclass
class RealmEmbedderOutput(ModelOutput):
"""
    Outputs of [`RealmEmbedder`] models.

    Args:
        projected_score (`torch.FloatTensor` of shape `(batch_size, config.retriever_proj_size)`):
            Projected score.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
"""
projected_score: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class RealmScorerOutput(ModelOutput):
"""
    Outputs of [`RealmScorer`] models.

    Args:
        relevance_score (`torch.FloatTensor` of shape `(batch_size, config.num_candidates)`):
            The relevance score of document candidates (before softmax).
        query_score (`torch.FloatTensor` of shape `(batch_size, config.retriever_proj_size)`):
            Query score derived from the query embedder.
        candidate_score (`torch.FloatTensor` of shape `(batch_size, config.num_candidates, config.retriever_proj_size)`):
            Candidate score derived from the embedder.
"""
relevance_score: torch.FloatTensor = None
query_score: torch.FloatTensor = None
candidate_score: torch.FloatTensor = None
@dataclass
class RealmReaderOutput(ModelOutput):
"""
RealmReader 模型的输出。
这里没有特定的参数和注释,仅作为占位使用。
"""
pass
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `start_positions`, `end_positions`, `has_answers` are provided):
总损失。
retriever_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `start_positions`, `end_positions`, `has_answers` are provided):
检索器损失。
reader_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `start_positions`, `end_positions`, `has_answers` are provided):
阅读器损失。
retriever_correct (`torch.BoolTensor` of shape `(config.searcher_beam_size,)`, *optional*):
检索器是否正确检测到包含答案的证据块。
reader_correct (`torch.BoolTensor` of shape `(config.reader_beam_size, num_candidates)`, *optional*):
阅读器是否正确检测到包含答案的文本片段候选。
block_idx (`torch.LongTensor` of shape `()`):
预测答案最有可能出现的检索到的证据块的索引。
candidate (`torch.LongTensor` of shape `()`):
预测答案最有可能出现的检索到的文本片段候选的索引。
start_pos (`torch.IntTensor` of shape `()`):
预测答案在 *RealmReader* 输入中起始位置的索引。
end_pos (`torch.IntTensor` of shape `()`):
预测答案在 *RealmReader* 输入中结束位置的索引。
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
模型每一层的隐藏状态,包括初始嵌入输出。
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
self-attention 头部的注意力权重,用于计算注意力头部的加权平均值。
loss: torch.FloatTensor = None
retriever_loss: torch.FloatTensor = None
reader_loss: torch.FloatTensor = None
retriever_correct: torch.BoolTensor = None
reader_correct: torch.BoolTensor = None
block_idx: torch.LongTensor = None
candidate: torch.LongTensor = None
start_pos: torch.int32 = None
end_pos: torch.int32 = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class RealmForOpenQAOutput(ModelOutput):
"""
Outputs of [`RealmForOpenQA`] models.
Args:
reader_output (`dict`):
Reader output.
predicted_answer_ids (`torch.LongTensor` of shape `(answer_sequence_length)`):
Predicted answer ids.
"""
reader_output: dict = None
predicted_answer_ids: torch.LongTensor = None
class RealmPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class RealmLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = RealmPredictionHeadTransform(config)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
class RealmOnlyMLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = RealmLMPredictionHead(config)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class RealmScorerProjection(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = RealmLMPredictionHead(config)
self.dense = nn.Linear(config.hidden_size, config.retriever_proj_size)
self.LayerNorm = nn.LayerNorm(config.retriever_proj_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class RealmReaderProjection(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.dense_intermediate = nn.Linear(config.hidden_size, config.span_hidden_size * 2)
self.dense_output = nn.Linear(config.span_hidden_size, 1)
self.layer_normalization = nn.LayerNorm(config.span_hidden_size, eps=config.reader_layer_norm_eps)
self.relu = nn.ReLU()
def forward(self, hidden_states, block_mask):
def span_candidates(masks):
"""
Generate span candidates.
Args:
masks: <bool> [num_retrievals, max_sequence_len]
Returns:
starts: <int32> [num_spans] ends: <int32> [num_spans] span_masks: <int32> [num_retrievals, num_spans]
whether spans locate in evidence block.
"""
_, max_sequence_len = masks.shape
def _spans_given_width(width):
current_starts = torch.arange(max_sequence_len - width + 1, device=masks.device)
current_ends = torch.arange(width - 1, max_sequence_len, device=masks.device)
return current_starts, current_ends
starts, ends = zip(*(_spans_given_width(w + 1) for w in range(self.config.max_span_width)))
starts = torch.cat(starts, 0)
ends = torch.cat(ends, 0)
start_masks = torch.index_select(masks, dim=-1, index=starts)
end_masks = torch.index_select(masks, dim=-1, index=ends)
span_masks = start_masks * end_masks
return starts, ends, span_masks
def mask_to_score(mask, dtype=torch.float32):
return (1.0 - mask.type(dtype)) * torch.finfo(dtype).min
hidden_states = self.dense_intermediate(hidden_states)
start_projection, end_projection = hidden_states.chunk(2, dim=-1)
candidate_starts, candidate_ends, candidate_mask = span_candidates(block_mask)
candidate_start_projections = torch.index_select(start_projection, dim=1, index=candidate_starts)
candidate_end_projections = torch.index_select(end_projection, dim=1, index=candidate_ends)
candidate_hidden = candidate_start_projections + candidate_end_projections
candidate_hidden = self.relu(candidate_hidden)
candidate_hidden = self.layer_normalization(candidate_hidden)
reader_logits = self.dense_output(candidate_hidden).squeeze(-1)
reader_logits += mask_to_score(candidate_mask, dtype=reader_logits.dtype)
return reader_logits, candidate_starts, candidate_ends
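To make the span enumeration concrete, here is the same `_spans_given_width` logic run standalone with toy numbers (`max_span_width=2`, sequence length 4): all width-1 spans come first, then all width-2 spans. `mask_to_score` then pushes the logits of spans falling outside the evidence block down to the dtype minimum, so they can never win the argmax.

```python
import torch

max_sequence_len, max_span_width = 4, 2

def _spans_given_width(width):
    starts = torch.arange(max_sequence_len - width + 1)
    ends = torch.arange(width - 1, max_sequence_len)
    return starts, ends

starts, ends = zip(*(_spans_given_width(w + 1) for w in range(max_span_width)))
starts, ends = torch.cat(starts), torch.cat(ends)
print(list(zip(starts.tolist(), ends.tolist())))
# [(0, 0), (1, 1), (2, 2), (3, 3), (0, 1), (1, 2), (2, 3)]
```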
REALM_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`RealmConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
REALM_INPUTS_DOCSTRING = r"""
"""
Args:
input_ids (`torch.LongTensor` of shape `({0})`):
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
output_attentions (`bool`, *optional*):
output_hidden_states (`bool`, *optional*):
return_dict (`bool`, *optional*):
"""
class RealmPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """
config_class = RealmConfig
load_tf_weights = load_tf_weights_in_realm
base_model_prefix = "realm"
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _flatten_inputs(self, *inputs):
"""Flatten inputs' shape to (-1, input_shape[-1])"""
flattened_inputs = []
for tensor in inputs:
if tensor is None:
flattened_inputs.append(None)
else:
input_shape = tensor.shape
if len(input_shape) > 2:
tensor = tensor.view((-1, input_shape[-1]))
flattened_inputs.append(tensor)
return flattened_inputs
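`_flatten_inputs` folds any leading candidate dimension into the batch dimension, so `(batch, num_candidates, seq_len)` candidate tensors can be encoded in one pass; 2-D tensors pass through unchanged. A quick shape sketch of the same `view` call:

```python
import torch

candidate_input_ids = torch.zeros(8, 5, 128, dtype=torch.long)  # (batch, num_candidates, seq_len)
flattened = candidate_input_ids.view(-1, candidate_input_ids.shape[-1])
assert flattened.shape == (8 * 5, 128)
```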
class RealmBertModel(RealmPreTrainedModel):
"""
Same as the original BertModel but remove docstrings.
"""
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
self.embeddings = RealmEmbeddings(config)
self.encoder = RealmEncoder(config)
self.pooler = RealmPooler(config) if add_pooling_layer else None
self.post_init()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
        return_dict=None,
    ):
        ...
@add_start_docstrings(
"The embedder of REALM outputting projected score that will be used to calculate relevance score.",
REALM_START_DOCSTRING,
)
class RealmEmbedder(RealmPreTrainedModel):
_tied_weights_keys = ["cls.predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
self.realm = RealmBertModel(self.config)
self.cls = RealmScorerProjection(self.config)
self.post_init()
def get_input_embeddings(self):
return self.realm.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.realm.embeddings.word_embeddings = value
@add_start_docstrings_to_model_forward(REALM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=RealmEmbedderOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, RealmEmbedderOutput]:
"""
RealmEmbedder 的前向传播方法。
Returns:
如果 return_dict 为 False,则返回元组 (projected_score, hidden_states, attentions)。
如果 return_dict 为 True,则返回 RealmEmbedderOutput 对象,其中包含 projected_score、hidden_states 和 attentions。
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
realm_outputs = self.realm(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooler_output = realm_outputs[1]
projected_score = self.cls(pooler_output)
if not return_dict:
return (projected_score,) + realm_outputs[2:4]
else:
return RealmEmbedderOutput(
projected_score=projected_score,
hidden_states=realm_outputs.hidden_states,
attentions=realm_outputs.attentions,
)
"The scorer of REALM outputting relevance scores representing the score of document candidates (before softmax).",
REALM_START_DOCSTRING,
class RealmScorer(RealmPreTrainedModel):
r"""
Args:
query_embedder ([`RealmEmbedder`]):
Embedder for input sequences. If not specified, it will use the same embedder as candidate sequences.
"""
def __init__(self, config, query_embedder=None):
super().__init__(config)
self.embedder = RealmEmbedder(self.config)
self.query_embedder = query_embedder if query_embedder is not None else self.embedder
self.post_init()
@add_start_docstrings_to_model_forward(REALM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=RealmScorerOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
candidate_input_ids: Optional[torch.LongTensor] = None,
candidate_attention_mask: Optional[torch.FloatTensor] = None,
candidate_token_type_ids: Optional[torch.LongTensor] = None,
candidate_inputs_embeds: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, RealmScorerOutput]:
        ...
@add_start_docstrings(
"The knowledge-augmented encoder of REALM outputting masked language model logits and marginal log-likelihood"
" loss.",
REALM_START_DOCSTRING,
)
class RealmKnowledgeAugEncoder(RealmPreTrainedModel):
_tied_weights_keys = ["cls.predictions.decoder"]
def __init__(self, config):
super().__init__(config)
self.realm = RealmBertModel(self.config)
self.cls = RealmOnlyMLMHead(self.config)
self.post_init()
def get_input_embeddings(self):
return self.realm.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.realm.embeddings.word_embeddings = value
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
@add_start_docstrings_to_model_forward(
REALM_INPUTS_DOCSTRING.format("batch_size, num_candidates, sequence_length")
)
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
relevance_score: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
mlm_mask: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        ...
@add_start_docstrings("The reader of REALM.", REALM_START_DOCSTRING)
class RealmReader(RealmPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.realm = RealmBertModel(config)
self.cls = RealmOnlyMLMHead(config)
self.qa_outputs = RealmReaderProjection(config)
self.post_init()
@add_start_docstrings_to_model_forward(REALM_INPUTS_DOCSTRING.format("reader_beam_size, sequence_length"))
@replace_return_docstrings(output_type=RealmReaderOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
relevance_score: Optional[torch.FloatTensor] = None,
block_mask: Optional[torch.BoolTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
has_answers: Optional[torch.BoolTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, RealmReaderOutput]:
        ...
REALM_FOR_OPEN_QA_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
1]`:
- 0 corresponds to a *sentence A* token,
- 1 corresponds to a *sentence B* token (should not be used in this model by design).
[What are token type IDs?](../glossary#token-type-ids)
answer_ids (`list` of shape `(num_answers, answer_length)`, *optional*):
Answer ids for computing the marginal log-likelihood loss. Indices should be in `[-1, 0, ...,
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-1` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    "`RealmForOpenQA` for end-to-end open domain question answering. This class inherits from `RealmPreTrainedModel`.",
    REALM_START_DOCSTRING,
)
class RealmForOpenQA(RealmPreTrainedModel):
def __init__(self, config, retriever=None):
"""
初始化方法,用于实例化一个 `RealmForOpenQA` 对象。
Args:
config (`PretrainedConfig`): 包含该模型配置信息的配置对象。
retriever (`Optional`): 用于检索的对象,默认为 `None`。
"""
super().__init__(config)
        self.embedder = RealmEmbedder(config)  # instantiate a `RealmEmbedder`
        self.reader = RealmReader(config)  # instantiate a `RealmReader`
self.register_buffer(
"block_emb",
torch.zeros(()).new_empty(
size=(config.num_block_records, config.retriever_proj_size),
dtype=torch.float32,
device=torch.device("cpu"),
),
)
        self.retriever = retriever  # the retriever object
        self.post_init()  # run post-initialization processing
@property
def searcher_beam_size(self):
"""
获取搜索器的 beam size。在训练模式下返回 `config.searcher_beam_size`,
否则返回 `config.reader_beam_size`。
Returns:
`int`: beam size 的大小。
"""
if self.training:
return self.config.searcher_beam_size
return self.config.reader_beam_size
def block_embedding_to(self, device):
"""
将 `self.block_emb` 发送到指定的设备。
Args:
device (`str` or `torch.device`):
要发送 `self.block_emb` 的目标设备。
"""
self.block_emb = self.block_emb.to(device)
@add_start_docstrings_to_model_forward(REALM_FOR_OPEN_QA_DOCSTRING.format("1, sequence_length"))
@replace_return_docstrings(output_type=RealmForOpenQAOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor],
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
answer_ids: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
):
"""
模型的前向传播方法,用于执行推断或训练。
Args:
input_ids (`Optional[torch.LongTensor]`):
输入的 token IDs。
attention_mask (`Optional[torch.FloatTensor]`, optional):
注意力掩码张量,默认为 `None`。
token_type_ids (`Optional[torch.LongTensor]`, optional):
分段 token IDs,默认为 `None`。
answer_ids (`Optional[torch.LongTensor]`, optional):
答案 token IDs,默认为 `None`。
return_dict (`Optional[bool]`, optional):
是否返回字典作为输出,默认为 `None`。
Returns:
`RealmForOpenQAOutput` 或者是一个字典,包含模型输出的各种信息。
"""
.\models\realm\retrieval_realm.py
"""REALM Retriever model implementation."""
import os
from typing import Optional, Union
import numpy as np
from huggingface_hub import hf_hub_download
from ... import AutoTokenizer
from ...utils import logging
_REALM_BLOCK_RECORDS_FILENAME = "block_records.npy"
logger = logging.get_logger(__name__)
class ScaNNSearcher:
"""Note that ScaNNSearcher cannot currently be used within the model. In future versions, it might however be included."""
def __init__(
self,
db,
num_neighbors,
dimensions_per_block=2,
num_leaves=1000,
num_leaves_to_search=100,
training_sample_size=100000,
):
"""Build scann searcher."""
from scann.scann_ops.py.scann_ops_pybind import builder as Builder
builder = Builder(db=db, num_neighbors=num_neighbors, distance_measure="dot_product")
builder = builder.tree(
num_leaves=num_leaves, num_leaves_to_search=num_leaves_to_search, training_sample_size=training_sample_size
)
builder = builder.score_ah(dimensions_per_block=dimensions_per_block)
self.searcher = builder.build()
def search_batched(self, question_projection):
"""Perform batched search using the constructed SCANN searcher."""
retrieved_block_ids, _ = self.searcher.search_batched(question_projection.detach().cpu())
return retrieved_block_ids.astype("int64")
class RealmRetriever:
"""The retriever of REALM outputting the retrieved evidence block and whether the block has answers as well as answer
positions."
Parameters:
block_records (`np.ndarray`):
A numpy array which cantains evidence texts.
tokenizer ([`RealmTokenizer`]):
The tokenizer to encode retrieved texts.
"""
def __init__(self, block_records, tokenizer):
"""Initialize RealmRetriever with block records and tokenizer."""
super().__init__()
self.block_records = block_records
self.tokenizer = tokenizer
def __call__(self, retrieved_block_ids, question_input_ids, answer_ids, max_length=None, return_tensors="pt"):
retrieved_blocks = np.take(self.block_records, indices=retrieved_block_ids, axis=0)
question = self.tokenizer.decode(question_input_ids[0], skip_special_tokens=True)
text = []
text_pair = []
for retrieved_block in retrieved_blocks:
text.append(question)
text_pair.append(retrieved_block.decode())
concat_inputs = self.tokenizer(
text, text_pair, padding=True, truncation=True, return_special_tokens_mask=True, max_length=max_length
)
concat_inputs_tensors = concat_inputs.convert_to_tensors(return_tensors)
if answer_ids is not None:
return self.block_has_answer(concat_inputs, answer_ids) + (concat_inputs_tensors,)
else:
return (None, None, None, concat_inputs_tensors)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *init_inputs, **kwargs):
if os.path.isdir(pretrained_model_name_or_path):
block_records_path = os.path.join(pretrained_model_name_or_path, _REALM_BLOCK_RECORDS_FILENAME)
else:
block_records_path = hf_hub_download(
repo_id=pretrained_model_name_or_path, filename=_REALM_BLOCK_RECORDS_FILENAME, **kwargs
)
block_records = np.load(block_records_path, allow_pickle=True)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
return cls(block_records, tokenizer)
def save_pretrained(self, save_directory):
np.save(os.path.join(save_directory, _REALM_BLOCK_RECORDS_FILENAME), self.block_records)
self.tokenizer.save_pretrained(save_directory)
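A usage sketch for the two methods above, in a Transformers version that ships REALM (note that `block_records.npy` is a large download):

```python
from transformers import RealmRetriever

retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa")
retriever.save_pretrained("./my_retriever")  # writes block_records.npy plus the tokenizer files
```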
def block_has_answer(self, concat_inputs, answer_ids):
"""check if retrieved_blocks has answers."""
has_answers = []
start_pos = []
end_pos = []
max_answers = 0
for input_id in concat_inputs.input_ids:
input_id_list = input_id.tolist()
first_sep_idx = input_id_list.index(self.tokenizer.sep_token_id)
second_sep_idx = first_sep_idx + 1 + input_id_list[first_sep_idx + 1:].index(self.tokenizer.sep_token_id)
start_pos.append([])
end_pos.append([])
for answer in answer_ids:
for idx in range(first_sep_idx + 1, second_sep_idx):
if answer[0] == input_id_list[idx]:
if input_id_list[idx: idx + len(answer)] == answer:
start_pos[-1].append(idx)
end_pos[-1].append(idx + len(answer) - 1)
if len(start_pos[-1]) == 0:
has_answers.append(False)
else:
has_answers.append(True)
if len(start_pos[-1]) > max_answers:
max_answers = len(start_pos[-1])
for start_pos_, end_pos_ in zip(start_pos, end_pos):
if len(start_pos_) < max_answers:
padded = [-1] * (max_answers - len(start_pos_))
start_pos_ += padded
end_pos_ += padded
return has_answers, start_pos, end_pos
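The matching loop above only scans the second segment, i.e. the tokens between the first and second `[SEP]`. A self-contained sketch of that inner loop with made-up token ids (`101`/`102` standing in for `[CLS]`/`[SEP]`):

```python
input_id_list = [101, 7, 8, 102, 5, 6, 5, 6, 102]  # [CLS] question [SEP] evidence block [SEP]
answer = [5, 6]
sep_id = 102

first_sep_idx = input_id_list.index(sep_id)  # 3
second_sep_idx = first_sep_idx + 1 + input_id_list[first_sep_idx + 1 :].index(sep_id)  # 8

starts, ends = [], []
for idx in range(first_sep_idx + 1, second_sep_idx):
    if input_id_list[idx : idx + len(answer)] == answer:
        starts.append(idx)
        ends.append(idx + len(answer) - 1)
print(starts, ends)  # [4, 6] [5, 7] -- the answer occurs twice in the block
```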
.\models\realm\tokenization_realm.py
"""Tokenization classes for REALM."""
import collections
import os
import unicodedata
from typing import List, Optional, Tuple
from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from ...tokenization_utils_base import BatchEncoding
from ...utils import PaddingStrategy, logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"google/realm-cc-news-pretrained-embedder": (
"https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt"
),
"google/realm-cc-news-pretrained-encoder": (
"https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt"
),
"google/realm-cc-news-pretrained-scorer": (
"https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt"
),
"google/realm-cc-news-pretrained-openqa": (
"https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt"
),
"google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/vocab.txt",
"google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/vocab.txt",
"google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/vocab.txt",
"google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/vocab.txt",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"google/realm-cc-news-pretrained-embedder": 512,
"google/realm-cc-news-pretrained-encoder": 512,
"google/realm-cc-news-pretrained-scorer": 512,
"google/realm-cc-news-pretrained-openqa": 512,
"google/realm-orqa-nq-openqa": 512,
"google/realm-orqa-nq-reader": 512,
"google/realm-orqa-wq-openqa": 512,
"google/realm-orqa-wq-reader": 512,
}
PRETRAINED_INIT_CONFIGURATION = {
"google/realm-cc-news-pretrained-embedder": {"do_lower_case": True},
"google/realm-cc-news-pretrained-encoder": {"do_lower_case": True},
"google/realm-cc-news-pretrained-scorer": {"do_lower_case": True},
"google/realm-cc-news-pretrained-openqa": {"do_lower_case": True},
"google/realm-orqa-nq-openqa": {"do_lower_case": True},
"google/realm-orqa-nq-reader": {"do_lower_case": True},
"google/realm-orqa-wq-openqa": {"do_lower_case": True},
"google/realm-orqa-wq-reader": {"do_lower_case": True},
}
def load_vocab(vocab_file):
vocab = collections.OrderedDict()
with open(vocab_file, "r", encoding="utf-8") as reader:
tokens = reader.readlines()
for index, token in enumerate(tokens):
token = token.rstrip("\n")
vocab[token] = index
return vocab
def whitespace_tokenize(text):
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class RealmTokenizer(PreTrainedTokenizer):
r"""
Construct a REALM tokenizer.
[`RealmTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation splitting and
wordpiece.
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
this superclass for more information regarding those methods.
"""
def __init__(
self,
vocab_file,
do_lower_case=True,
do_basic_tokenize=True,
never_split=None,
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
tokenize_chinese_chars=True,
strip_accents=None,
**kwargs,
):
if not os.path.isfile(vocab_file):
raise ValueError(
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
" model use `tokenizer = RealmTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
self.do_basic_tokenize = do_basic_tokenize
if do_basic_tokenize:
self.basic_tokenizer = BasicTokenizer(
do_lower_case=do_lower_case,
never_split=never_split,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
super().__init__(
do_lower_case=do_lower_case,
do_basic_tokenize=do_basic_tokenize,
never_split=never_split,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)
@property
def do_lower_case(self):
return self.basic_tokenizer.do_lower_case
@property
def vocab_size(self):
return len(self.vocab)
def get_vocab(self):
return dict(self.vocab, **self.added_tokens_encoder)
def _tokenize(self, text):
split_tokens = []
if self.do_basic_tokenize:
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
if token in self.basic_tokenizer.never_split:
split_tokens.append(token)
else:
split_tokens += self.wordpiece_tokenizer.tokenize(token)
else:
split_tokens = self.wordpiece_tokenizer.tokenize(text)
return split_tokens
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
return self.vocab.get(token, self.vocab.get(self.unk_token))
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.ids_to_tokens.get(index, self.unk_token)
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
out_string = " ".join(tokens).replace(" ##", "").strip()
return out_string
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
adding special tokens. A REALM sequence has the following format:
- single sequence: `[CLS] X [SEP]`
- pair of sequences: `[CLS] A [SEP] B [SEP]`
Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of input IDs with the appropriate special tokens.
"""
if token_ids_1 is None:
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
cls = [self.cls_token_id]
sep = [self.sep_token_id]
return cls + token_ids_0 + sep + token_ids_1 + sep
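For instance, with hypothetical special-token ids `cls_token_id=101` and `sep_token_id=102`:

```python
cls, sep = [101], [102]  # hypothetical [CLS]/[SEP] ids
token_ids_0, token_ids_1 = [7, 8], [9]
print(cls + token_ids_0 + sep)                      # [101, 7, 8, 102]
print(cls + token_ids_0 + sep + token_ids_1 + sep)  # [101, 7, 8, 102, 9, 102]
```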
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)
if token_ids_1 is not None:
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1]
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A REALM sequence
pair mask has the following format:
```
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
```
If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
Args:
token_ids_0 (`List[int]`):
List of IDs for the first sequence.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
index = 0
if os.path.isdir(save_directory):
vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
else:
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning(
f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
" Please check that the vocabulary is not corrupted!"
)
index = token_index
writer.write(token + "\n")
index += 1
return (vocab_file,)
class BasicTokenizer(object):
"""
    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).

    Args:
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            `do_basic_tokenize=True`.
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters.
            This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
"""
def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
if never_split is None:
never_split = []
self.do_lower_case = do_lower_case
self.never_split = set(never_split)
self.tokenize_chinese_chars = tokenize_chinese_chars
self.strip_accents = strip_accents
def tokenize(self, text, never_split=None):
"""
Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see
WordPieceTokenizer.
Args:
never_split (`List[str]`, *optional*):
Kept for backward compatibility purposes. Now implemented directly at the base class level (see
[`PreTrainedTokenizer.tokenize`]). List of tokens not to split.
"""
never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
text = self._clean_text(text)
if self.tokenize_chinese_chars:
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if token not in never_split:
if self.do_lower_case:
token = token.lower()
if self.strip_accents is not False:
token = self._run_strip_accents(token)
elif self.strip_accents:
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token, never_split))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text, never_split=None):
"""Splits punctuation on a piece of text."""
if never_split is not None and text in never_split:
return [text]
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
if (
(cp >= 0x4E00 and cp <= 0x9FFF)
or (cp >= 0x3400 and cp <= 0x4DBF)
or (cp >= 0x20000 and cp <= 0x2A6DF)
or (cp >= 0x2A700 and cp <= 0x2B73F)
or (cp >= 0x2B740 and cp <= 0x2B81F)
or (cp >= 0x2B820 and cp <= 0x2CEAF)
or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F)
):
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xFFFD or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenization."""
def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""
Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
tokenization using the given vocabulary.
For example, `input = "unaffable"` wil return as output `["un", "##aff", "##able"]`.
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through *BasicTokenizer*.
Returns:
A list of wordpiece tokens.
"""
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
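To see the greedy longest-match-first behavior in isolation, here is an illustrative sketch with a toy vocabulary, reproducing the `"unaffable"` example from the docstring (assumes `whitespace_tokenize` from this module is in scope):
```python
toy_vocab = {"un", "##aff", "##able", "aff", "able"}
wp = WordpieceTokenizer(vocab=toy_vocab, unk_token="[UNK]")
print(wp.tokenize("unaffable"))  # ['un', '##aff', '##able']
print(wp.tokenize("xyz"))        # ['[UNK]'] -- no prefix of "xyz" is in the vocabulary
```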
.\models\realm\tokenization_realm_fast.py
"""REALM 的快速分词类。"""
import json
from typing import List, Optional, Tuple
from tokenizers import normalizers
from ...tokenization_utils_base import BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import PaddingStrategy, logging
from .tokenization_realm import RealmTokenizer
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"google/realm-cc-news-pretrained-embedder": (
"https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt"
),
"google/realm-cc-news-pretrained-encoder": (
"https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt"
),
"google/realm-cc-news-pretrained-scorer": (
"https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt"
),
"google/realm-cc-news-pretrained-openqa": (
"https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt"
),
"google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/vocab.txt",
"google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/vocab.txt",
"google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/vocab.txt",
"google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/vocab.txt",
},
"tokenizer_file": {
"google/realm-cc-news-pretrained-embedder": (
"https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/tokenizer.jsont"
),
"google/realm-cc-news-pretrained-encoder": (
"https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/tokenizer.json"
),
"google/realm-cc-news-pretrained-scorer": (
"https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/tokenizer.json"
),
"google/realm-cc-news-pretrained-openqa": (
"https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/tokenizer.json"
),
"google/realm-orqa-nq-openqa": (
"https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/tokenizer.json"
),
"google/realm-orqa-nq-reader": (
"https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/tokenizer.json"
),
"google/realm-orqa-wq-openqa": (
"https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/tokenizer.json"
),
"google/realm-orqa-wq-reader": (
"https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/tokenizer.json"
),
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"google/realm-cc-news-pretrained-embedder": 512,
"google/realm-cc-news-pretrained-encoder": 512,
"google/realm-cc-news-pretrained-scorer": 512,
"google/realm-cc-news-pretrained-openqa": 512,
"google/realm-orqa-nq-openqa": 512,
"google/realm-orqa-nq-reader": 512,
"google/realm-orqa-wq-openqa": 512,
"google/realm-orqa-wq-reader": 512,
}
PRETRAINED_INIT_CONFIGURATION = {
"google/realm-cc-news-pretrained-embedder": {"do_lower_case": True},
"google/realm-cc-news-pretrained-encoder": {"do_lower_case": True},
"google/realm-cc-news-pretrained-scorer": {"do_lower_case": True},
"google/realm-cc-news-pretrained-openqa": {"do_lower_case": True},
"google/realm-orqa-nq-openqa": {"do_lower_case": True},
"google/realm-orqa-nq-reader": {"do_lower_case": True},
"google/realm-orqa-wq-openqa": {"do_lower_case": True},
"google/realm-orqa-wq-reader": {"do_lower_case": True},
}
class RealmTokenizerFast(PreTrainedTokenizerFast):
r"""
Construct a "fast" REALM tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece.
[`RealmTokenizerFast`] is identical to [`BertTokenizerFast`] and runs end-to-end tokenization: punctuation
splitting and wordpiece.
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods.
"""
# Names of the vocabulary files expected by this tokenizer
vocab_files_names = VOCAB_FILES_NAMES
# Mapping from pretrained checkpoint names to their vocabulary file URLs
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
# Default initialization kwargs for each pretrained checkpoint
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
# Maximum input lengths (in tokens) accepted by each pretrained model
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
# The matching "slow" tokenizer class
slow_tokenizer_class = RealmTokenizer
# Constructor for a new tokenizer instance
def __init__(
self,
vocab_file=None,  # path to the vocabulary file
tokenizer_file=None,  # path to a serialized `tokenizers` tokenizer file
do_lower_case=True,  # whether to lowercase the input when tokenizing
unk_token="[UNK]",  # token used for out-of-vocabulary words
sep_token="[SEP]",  # separator token used when joining multiple sequences
pad_token="[PAD]",  # padding token used to batch sequences of different lengths
cls_token="[CLS]",  # classifier token placed at the start of every sequence
mask_token="[MASK]",  # mask token predicted during masked language modeling
tokenize_chinese_chars=True,  # whether to tokenize Chinese characters
strip_accents=None,  # whether to strip all accents
**kwargs,  # remaining keyword arguments forwarded to the base class
):
# Initialize via the parent class constructor, passing all parameters through
super().__init__(
vocab_file,
tokenizer_file=tokenizer_file,
do_lower_case=do_lower_case,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)
# Deserialize the normalizer state from the backend tokenizer
normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
# If the stored normalizer state does not match the arguments passed here, rebuild the normalizer
if (
normalizer_state.get("lowercase", do_lower_case) != do_lower_case
or normalizer_state.get("strip_accents", strip_accents) != strip_accents
or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars
):
normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
normalizer_state["lowercase"] = do_lower_case
normalizer_state["strip_accents"] = strip_accents
normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars
self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)
# Record the lowercasing flag on the instance
self.do_lower_case = do_lower_case
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. A REALM sequence has the following format:
- single sequence: `[CLS] X [SEP]`
- pair of sequences: `[CLS] A [SEP] B [SEP]`
Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
# Concatenate the sequences together with the special tokens
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
if token_ids_1 is not None:
output += token_ids_1 + [self.sep_token_id]
return output
def create_token_type_ids_from_sequences(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A REALM sequence
pair mask has the following format:
```
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
```
If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
Args:
token_ids_0 (`List[int]`):
List of IDs for the first sequence.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
"""
# Define the separator and classification token IDs
sep = [self.sep_token_id]
cls = [self.cls_token_id]
# If token_ids_1 is None, return a mask with zeros for only the first sequence
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
# Otherwise, concatenate masks for both sequences where the first sequence has 0s and the second has 1s
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
"""
Save the tokenizer's vocabulary to the specified directory.
Args:
save_directory (str):
Directory path where the vocabulary will be saved.
filename_prefix (str, *optional*):
Optional prefix for the saved vocabulary files.
Returns:
Tuple[str]: Tuple containing the paths to the saved files.
"""
# Save the tokenizer's model (vocabulary) to the specified directory with an optional filename prefix
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
return tuple(files)
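For completeness, end-to-end usage of the fast tokenizer looks like this (illustrative; `google/realm-cc-news-pretrained-encoder` is one of the checkpoints listed in the maps above and is downloaded on first use):
```python
from transformers import RealmTokenizerFast

tokenizer = RealmTokenizerFast.from_pretrained("google/realm-cc-news-pretrained-encoder")
enc = tokenizer("what is realm?", "a retrieval-augmented encoder.")
# [CLS] A [SEP] B [SEP], mirroring build_inputs_with_special_tokens
print(tokenizer.decode(enc["input_ids"]))
# 0s over the first segment, 1s over the second, as in create_token_type_ids_from_sequences
print(enc["token_type_ids"])
```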
.\models\realm\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
_import_structure = {
"configuration_realm": ["REALM_PRETRAINED_CONFIG_ARCHIVE_MAP", "RealmConfig"],
"tokenization_realm": ["RealmTokenizer"],
}
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_realm_fast"] = ["RealmTokenizerFast"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_realm"] = [
"REALM_PRETRAINED_MODEL_ARCHIVE_LIST",
"RealmEmbedder",
"RealmForOpenQA",
"RealmKnowledgeAugEncoder",
"RealmPreTrainedModel",
"RealmReader",
"RealmScorer",
"load_tf_weights_in_realm",
]
_import_structure["retrieval_realm"] = ["RealmRetriever"]
if TYPE_CHECKING:
from .configuration_realm import REALM_PRETRAINED_CONFIG_ARCHIVE_MAP, RealmConfig
from .tokenization_realm import RealmTokenizer
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_realm_fast import RealmTokenizerFast
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_realm import (
REALM_PRETRAINED_MODEL_ARCHIVE_LIST,
RealmEmbedder,
RealmForOpenQA,
RealmKnowledgeAugEncoder,
RealmPreTrainedModel,
RealmReader,
RealmScorer,
load_tf_weights_in_realm,
)
from .retrieval_realm import RealmRetriever
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\reformer\configuration_reformer.py
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google/reformer-crime-and-punishment": (
"https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/config.json"
),
"google/reformer-enwik8": "https://huggingface.co/google/reformer-enwik8/resolve/main/config.json",
}
class ReformerConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`ReformerModel`]. It is used to instantiate a
Reformer model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the ReFormer
[google/reformer-crime-and-punishment](https://huggingface.co/google/reformer-crime-and-punishment) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Examples:
```
>>> from transformers import ReformerConfig, ReformerModel
>>> # Initializing a Reformer configuration
>>> configuration = ReformerConfig()
>>> # Initializing a Reformer model (with random weights)
>>> model = ReformerModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "reformer"
keys_to_ignore_at_inference = ["past_buckets_states"]
attribute_map = {}
def __init__(
self,
attention_head_size=64,
attn_layers=["local", "lsh", "local", "lsh", "local", "lsh"],
axial_norm_std=1.0,
axial_pos_embds=True,
axial_pos_shape=[64, 64],
axial_pos_embds_dim=[64, 192],
chunk_size_lm_head=0,
eos_token_id=2,
feed_forward_size=512,
hash_seed=None,
hidden_act="relu",
hidden_dropout_prob=0.05,
hidden_size=256,
initializer_range=0.02,
is_decoder=False,
layer_norm_eps=1e-12,
local_num_chunks_before=1,
local_num_chunks_after=0,
local_attention_probs_dropout_prob=0.05,
local_attn_chunk_length=64,
lsh_attn_chunk_length=64,
lsh_attention_probs_dropout_prob=0.0,
lsh_num_chunks_before=1,
lsh_num_chunks_after=0,
max_position_embeddings=4096,
num_attention_heads=12,
num_buckets=None,
num_hashes=1,
pad_token_id=0,
vocab_size=320,
tie_word_embeddings=False,
use_cache=True,
classifier_dropout=None,
**kwargs,
):
self.hash_seed = hash_seed
self.vocab_size = vocab_size
self.attention_head_size = attention_head_size
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.num_hashes = num_hashes
self.num_hidden_layers = len(attn_layers)
self.num_buckets = tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets
self.lsh_attn_chunk_length = lsh_attn_chunk_length
self.local_attn_chunk_length = local_attn_chunk_length
self.lsh_num_chunks_after = lsh_num_chunks_after
self.lsh_num_chunks_before = lsh_num_chunks_before
self.local_num_chunks_after = local_num_chunks_after
self.local_num_chunks_before = local_num_chunks_before
self.hidden_act = hidden_act
self.feed_forward_size = feed_forward_size
self.hidden_dropout_prob = hidden_dropout_prob
self.lsh_attention_probs_dropout_prob = lsh_attention_probs_dropout_prob
self.local_attention_probs_dropout_prob = local_attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.axial_pos_embds = axial_pos_embds
self.axial_pos_shape = tuple(axial_pos_shape)
self.axial_pos_embds_dim = tuple(axial_pos_embds_dim)
self.axial_norm_std = axial_norm_std
self.chunk_size_lm_head = chunk_size_lm_head
self.attn_layers = attn_layers
self.use_cache = use_cache
self.classifier_dropout = classifier_dropout
super().__init__(
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
is_decoder=is_decoder,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
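Since `num_hidden_layers` is derived from `len(attn_layers)` in the constructor above, a configuration with a custom layer pattern can be sanity-checked directly (illustrative):
```python
from transformers import ReformerConfig

config = ReformerConfig(attn_layers=["local", "local", "lsh", "local"])
assert config.num_hidden_layers == 4
assert config.axial_pos_shape == (64, 64)  # list arguments are normalized to tuples
```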
.\models\reformer\convert_reformer_trax_checkpoint_to_pytorch.py
import argparse
import pickle
import numpy as np
import torch
from torch import nn
from transformers import ReformerConfig, ReformerModelWithLMHead
from transformers.utils import logging
logging.set_verbosity_info()
def set_param(torch_layer, weight, bias=None):
assert torch_layer.weight.shape == weight.shape, f"{torch_layer} layer.weight does not match"
torch_layer.weight = nn.Parameter(weight)
if bias is not None:
assert torch_layer.bias.shape == bias.shape, f"{torch_layer} layer.bias does not match"
torch_layer.bias = nn.Parameter(bias)
def set_layer_weights_in_torch_lsh(weights, torch_layer, hidden_size):
np_query_key = np.asarray(weights[0])
np_value = np.asarray(weights[1])
np_dense = np.asarray(weights[2])
set_param(
torch_layer.self_attention.query_key,
torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, hidden_size),
)
set_param(
torch_layer.self_attention.value,
torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size),
)
set_param(
torch_layer.output.dense,
torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1),
)
def set_layer_weights_in_torch_local(weights, torch_layer, hidden_size):
np_query = np.asarray(weights[0])
np_key = np.asarray(weights[1])
np_value = np.asarray(weights[2])
np_dense = np.asarray(weights[3])
set_param(
torch_layer.self_attention.query,
torch.tensor(np_query).transpose(1, 2).contiguous().view(-1, hidden_size),
)
set_param(
torch_layer.self_attention.key,
torch.tensor(np_key).transpose(1, 2).contiguous().view(-1, hidden_size),
)
set_param(
torch_layer.self_attention.value,
torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size),
)
set_param(
torch_layer.output.dense,
torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1),
)
def set_block_weights_in_torch(weights, torch_block, hidden_size):
layer_norm_1 = weights[0][0][0]
layer_norm_1_weight = np.asarray(layer_norm_1[0])
layer_norm_1_bias = np.asarray(layer_norm_1[1])
set_param(
torch_block.attention.layer_norm,
torch.tensor(layer_norm_1_weight),
torch.tensor(layer_norm_1_bias),
)
attn_weights = weights[0][1]
if len(attn_weights) < 4:
set_layer_weights_in_torch_lsh(attn_weights, torch_block.attention, hidden_size)
else:
set_layer_weights_in_torch_local(attn_weights, torch_block.attention, hidden_size)
intermediate_weights = weights[2][0][1][2]
if len(intermediate_weights) == 4:
intermediate_weights = intermediate_weights[2]
layer_norm_2_weight = np.asarray(intermediate_weights[0][0])
layer_norm_2_bias = np.asarray(intermediate_weights[0][1])
set_param(
torch_block.feed_forward.layer_norm,
torch.tensor(layer_norm_2_weight),
torch.tensor(layer_norm_2_bias),
)
inter_dense_weight = np.asarray(intermediate_weights[1][0])
inter_dense_bias = np.asarray(intermediate_weights[1][1])
set_param(
torch_block.feed_forward.dense.dense,
torch.tensor(inter_dense_weight).transpose(0, 1).contiguous(),
torch.tensor(inter_dense_bias),
)
out_dense_weight = np.asarray(intermediate_weights[4][0])
out_dense_bias = np.asarray(intermediate_weights[4][1])
set_param(
torch_block.feed_forward.output.dense,
torch.tensor(out_dense_weight).transpose(0, 1).contiguous(),
torch.tensor(out_dense_bias),
)
def set_model_weights_in_torch(weights, torch_model, hidden_size):
torch_model_reformer = torch_model.reformer
word_embeddings = np.asarray(weights[1])
set_param(
torch_model_reformer.embeddings.word_embeddings,
torch.tensor(word_embeddings),
)
if isinstance(weights[3], tuple):
position_embeddings = torch_model_reformer.embeddings.position_embeddings
for emb_idx in range(len(position_embeddings.weights)):
emb_weights = np.asarray(weights[3][emb_idx][0])
assert (
position_embeddings.weights[emb_idx].shape == emb_weights.shape
), f"{position_embeddings[emb_idx]} emb does not match"
position_embeddings.weights[emb_idx] = nn.Parameter(torch.tensor(emb_weights))
trax_layer_weights = weights[5]
assert len(torch_model_reformer.encoder.layers) * 4 == len(
trax_layer_weights
), "HF and trax model do not have the same number of layers"
for layer_idx, layer in enumerate(torch_model_reformer.encoder.layers):
block_weights = trax_layer_weights[4 * layer_idx : 4 * (layer_idx + 1)]
set_block_weights_in_torch(block_weights, layer, hidden_size)
layer_norm_out_weight = np.asarray(weights[7][0])
layer_norm_out_bias = np.asarray(weights[7][1])
set_param(
torch_model_reformer.encoder.layer_norm,
torch.tensor(layer_norm_out_weight),
torch.tensor(layer_norm_out_bias),
)
output_embed_weights = np.asarray(weights[9][0])
output_embed_bias = np.asarray(weights[9][1])
set_param(
torch_model.lm_head.decoder,
torch.tensor(output_embed_weights).transpose(0, 1).contiguous(),
torch.tensor(output_embed_bias),
)
def convert_trax_checkpoint_to_pytorch(trax_model_pkl_path, config_file, pytorch_dump_path):
config = ReformerConfig.from_json_file(config_file)
print(f"Building PyTorch model from configuration: {config}")
model = ReformerModelWithLMHead(config)
with open(trax_model_pkl_path, "rb") as f:
model_weights = pickle.load(f)["weights"]
set_model_weights_in_torch(model_weights, model, config.hidden_size)
print(f"Save PyTorch model to {pytorch_dump_path}")
torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--trax_model_pkl_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--config_file",
default=None,
type=str,
required=True,
help=(
"The config json file corresponding to the pre-trained Reformer model. \n"
"This specifies the model architecture."
),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_trax_checkpoint_to_pytorch(args.trax_model_pkl_path, args.config_file, args.pytorch_dump_path)
.\models\reformer\modeling_reformer.py
# Imports restored for this excerpt (the full file additionally defines the
# AttentionOutput/ReformerOutput/ReformerBackwardOutput/ReformerEncoderOutput
# namedtuple containers used below, which are elided here).
import sys
from dataclasses import dataclass
from typing import List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from torch.autograd.function import Function
from ...activations import ACT2FN
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
DUMMY_INPUTS,
DUMMY_MASK,
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
from .configuration_reformer import ReformerConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "google/reformer-crime-and-punishment"
_CONFIG_FOR_DOC = "ReformerConfig"
def _stable_argsort(vector, dim):
scale_offset = torch.arange(vector.shape[dim], device=vector.device).view(1, 1, -1)
scale_offset = scale_offset.expand(vector.shape)
scaled_vector = vector.shape[dim] * vector + (scale_offset % vector.shape[dim])
return torch.argsort(scaled_vector, dim=dim)
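`torch.argsort` makes no stability guarantee for equal entries, so `_stable_argsort` folds each element's position into its value as a tie-breaker before sorting: values are scaled by the dimension size and the index (mod size) is added. A minimal check (illustrative, with the function above in scope):
```python
import torch

v = torch.tensor([[[1, 0, 1, 0]]])  # shape (1, 1, 4): duplicated bucket ids
print(_stable_argsort(v, dim=-1))   # tensor([[[1, 3, 0, 2]]]) -- ties keep left-to-right order
```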
def _get_least_common_mult_chunk_len(config):
attn_types = config.attn_layers
attn_types_set = set(attn_types)
if len(attn_types_set) == 1 and attn_types[0] == "lsh":
return config.lsh_attn_chunk_length
elif len(attn_types_set) == 1 and attn_types[0] == "local":
return config.local_attn_chunk_length
elif len(attn_types_set) == 2 and attn_types_set == {"lsh", "local"}:
return np.lcm(config.lsh_attn_chunk_length, config.local_attn_chunk_length)
else:
raise NotImplementedError(
f"Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {config.attn_layers}. Select "
"attn layer types from ['lsh', 'local'] only."
)
def _get_min_chunk_len(config):
attn_types = config.attn_layers
attn_types_set = set(attn_types)
if len(attn_types_set) == 1 and attn_types[0] == "lsh":
return config.lsh_attn_chunk_length
elif len(attn_types_set) == 1 and attn_types[0] == "local":
return config.local_attn_chunk_length
elif len(attn_types_set) == 2 and attn_types_set == {"lsh", "local"}:
return min(config.lsh_attn_chunk_length, config.local_attn_chunk_length)
else:
raise NotImplementedError(
f"Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {config.attn_layers}. Select "
"attn layer types from ['lsh', 'local'] only."
)
class AxialPositionEmbeddings(nn.Module):
"""
Constructs axial position embeddings. Useful for very long input sequences to save memory and time.
"""
def __init__(self, config):
super().__init__()
self.axial_pos_shape = config.axial_pos_shape
self.axial_pos_embds_dim = config.axial_pos_embds_dim
self.dropout = config.hidden_dropout_prob
self.least_common_mult_chunk_length = _get_least_common_mult_chunk_len(config)
self.weights = nn.ParameterList()
if sum(self.axial_pos_embds_dim) != config.hidden_size:
raise ValueError(
f"Make sure that config.axial_pos_embds factors: {self.axial_pos_embds_dim} sum to "
f"config.hidden_size: {config.hidden_size}"
)
for axis, axial_pos_embd_dim in enumerate(self.axial_pos_embds_dim):
ax_shape = [1] * len(self.axial_pos_shape)
ax_shape[axis] = self.axial_pos_shape[axis]
ax_shape = tuple(ax_shape) + (axial_pos_embd_dim,)
self.weights.append(nn.Parameter(torch.ones(ax_shape, dtype=torch.float32)))
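Two constraints are implicit in this constructor: `axial_pos_embds_dim` must sum to `hidden_size` (checked above), and the product of `axial_pos_shape` must equal the padded sequence length. With the `ReformerConfig` defaults, as illustrative arithmetic:
```python
axial_pos_shape = (64, 64)       # 64 * 64 = 4096 = max_position_embeddings
axial_pos_embds_dim = (64, 192)  # 64 + 192 = 256 = hidden_size

# Parameter shapes produced by the loop above:
#   axis 0 -> (64, 1, 64)   broadcast along the second position axis
#   axis 1 -> (1, 64, 192)  broadcast along the first position axis
assert axial_pos_shape[0] * axial_pos_shape[1] == 4096
assert sum(axial_pos_embds_dim) == 256
```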
class PositionEmbeddings(nn.Module):
"""Constructs conventional position embeddings of shape `[max_pos_embeddings, hidden_size]`."""
def __init__(self, config):
super().__init__()
self.dropout = config.hidden_dropout_prob
self.embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
def forward(self, position_ids):
position_embeddings = self.embedding(position_ids)
position_embeddings = nn.functional.dropout(position_embeddings, p=self.dropout, training=self.training)
return position_embeddings
class ReformerEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config):
super().__init__()
self.max_position_embeddings = config.max_position_embeddings
self.dropout = config.hidden_dropout_prob
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.position_embeddings = (
AxialPositionEmbeddings(config) if config.axial_pos_embds else PositionEmbeddings(config)
)
def forward(self, input_ids=None, position_ids=None, inputs_embeds=None, start_idx_pos_encodings=0):
if input_ids is not None:
input_shape = input_ids.size()
device = input_ids.device
else:
input_shape = inputs_embeds.size()[:-1]
device = inputs_embeds.device
seq_length = input_shape[1]
if position_ids is None:
position_ids = torch.arange(
start_idx_pos_encodings, start_idx_pos_encodings + seq_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).expand(input_shape)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
if position_ids.shape[-1] > self.max_position_embeddings:
raise ValueError(
f"Sequence Length: {position_ids.shape[-1]} has to be less or equal than "
f"config.max_position_embeddings {self.max_position_embeddings}."
)
embeddings = nn.functional.dropout(inputs_embeds, p=self.dropout, training=self.training)
position_embeddings = self.position_embeddings(position_ids)
embeddings = embeddings + position_embeddings
return embeddings
class EfficientAttentionMixin:
"""
A few utilities for nn.Modules in Reformer, to be used as a mixin.
"""
def _look_adjacent(self, vectors, num_chunks_before, num_chunks_after):
"""
Used to implement attention between consecutive chunks.
Args:
vectors: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...]
num_chunks_before: chunks before current chunk to include in attention
num_chunks_after: chunks after current chunk to include in attention
Returns:
tensor of shape [num_chunks, N * chunk_length, ...], where N = (1 + num_chunks_before + num_chunks_after).
"""
if num_chunks_before == 0 and num_chunks_after == 0:
return vectors
slices = []
for i in range(-num_chunks_before, num_chunks_after + 1):
if i == 0:
slices.append(vectors)
else:
slices.append(torch.cat([vectors[:, :, i:, ...], vectors[:, :, :i, ...]], dim=2))
return torch.cat(slices, dim=3)
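The rotation trick in `_look_adjacent` is easiest to see on a tiny tensor: each chunk is concatenated, along the chunk-length axis, with rolled copies of the sequence so that it "sees" its neighbors (illustrative, with the mixin above in scope):
```python
import torch

# (batch=1, heads=1, n_chunks=3, chunk_len=1), chunk ids 0, 1, 2
chunks = torch.tensor([[[[0], [1], [2]]]])
out = EfficientAttentionMixin()._look_adjacent(chunks, num_chunks_before=1, num_chunks_after=0)
print(out.squeeze())  # tensor([[2, 0], [0, 1], [1, 2]]) -- each chunk now also sees its predecessor
```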
def _split_hidden_size_dim(self, x, num_attn_heads, attn_head_size):
new_x_shape = x.size()[:-1] + (num_attn_heads, attn_head_size)
x = x.view(*new_x_shape)
return x.transpose(2, 1)
def _merge_hidden_size_dims(self, x, num_attn_heads, attn_head_size):
x = x.permute(0, 2, 1, 3)
return torch.reshape(x, (x.size()[0], -1, num_attn_heads * attn_head_size))
def _split_seq_length_dim_to(self, vectors, dim_factor_1, dim_factor_2, num_attn_heads, attn_head_size=None):
batch_size = vectors.shape[0]
split_dim_shape = (batch_size, num_attn_heads, dim_factor_1, dim_factor_2)
if len(vectors.shape) == 4:
return torch.reshape(vectors, split_dim_shape + (attn_head_size,))
elif len(vectors.shape) == 3:
return torch.reshape(vectors, split_dim_shape)
else:
raise ValueError(f"Input vector rank should be one of [3, 4], but is: {len(vectors.shape)}")
class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
def __init__(self, config):
super().__init__()
self.config = config
self.chunk_length = config.lsh_attn_chunk_length
self.num_hashes = config.num_hashes
self.num_buckets = config.num_buckets
self.num_chunks_before = config.lsh_num_chunks_before
self.num_chunks_after = config.lsh_num_chunks_after
self.hash_seed = config.hash_seed
self.is_decoder = config.is_decoder
self.max_position_embeddings = config.max_position_embeddings
self.dropout = config.lsh_attention_probs_dropout_prob
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = config.attention_head_size
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.hidden_size = config.hidden_size
self.query_key = nn.Linear(self.hidden_size, self.all_head_size, bias=False)
self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False)
self.register_buffer("self_mask_value_float16", torch.tensor(-1e3), persistent=False)
self.register_buffer("self_mask_value_float32", torch.tensor(-1e5), persistent=False)
self.register_buffer("mask_value_float16", torch.tensor(-1e4), persistent=False)
self.register_buffer("mask_value_float32", torch.tensor(-1e9), persistent=False)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
num_hashes=None,
buckets=None,
past_buckets_states=None,
use_cache=False,
output_attentions=False,
**kwargs,
):
def _query_per_attn_head(self, hidden_states):
per_head_query_key = self.query_key.weight.reshape(
self.num_attention_heads, self.attention_head_size, self.hidden_size
).transpose(-2, -1)
query_key_vectors = torch.einsum("balh,ahr->balr", hidden_states, per_head_query_key)
return query_key_vectors
def _value_per_attn_head(self, hidden_states):
per_head_value = self.value.weight.reshape(
self.num_attention_heads, self.attention_head_size, self.hidden_size
).transpose(-2, -1)
value_vectors = torch.einsum("balh,ahr->balr", hidden_states, per_head_value)
return value_vectors
def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(self, sequence_length, buckets, num_hashes):
with torch.no_grad():
sorted_bucket_idx = _stable_argsort(buckets, dim=-1)
indices = (
torch.arange(sorted_bucket_idx.shape[-1], device=buckets.device)
.view(1, 1, -1)
.expand(sorted_bucket_idx.shape)
)
undo_sorted_bucket_idx = sorted_bucket_idx.new(*sorted_bucket_idx.size())
undo_sorted_bucket_idx.scatter_(-1, sorted_bucket_idx, indices)
return sorted_bucket_idx, undo_sorted_bucket_idx
def _set_num_buckets(self, sequence_length):
num_buckets_pow_2 = (2 * (sequence_length // self.chunk_length)).bit_length() - 1
num_buckets = 2**num_buckets_pow_2
num_buckets_limit = 2 * max(
int((self.max_position_embeddings // self.chunk_length) ** (0.5)),
self.chunk_length,
)
if num_buckets > num_buckets_limit:
num_buckets = [2 ** (num_buckets_pow_2 // 2), 2 ** (num_buckets_pow_2 - num_buckets_pow_2 // 2)]
logger.warning(f"config.num_buckets 未设置。将 config.num_buckets 设置为 {num_buckets}...")
self.config.num_buckets = num_buckets
self.num_buckets = num_buckets
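A worked example of the bucket heuristic, plugging in the `ReformerConfig` defaults (illustrative arithmetic only):
```python
sequence_length, chunk_length, max_position_embeddings = 4096, 64, 4096

num_buckets_pow_2 = (2 * (sequence_length // chunk_length)).bit_length() - 1  # (128).bit_length() - 1 = 7
num_buckets = 2**num_buckets_pow_2                                            # 128
num_buckets_limit = 2 * max(
    int((max_position_embeddings // chunk_length) ** 0.5),  # int(sqrt(64)) = 8
    chunk_length,                                           # 64
)                                                           # 128
assert num_buckets == 128 and num_buckets <= num_buckets_limit  # no factorization into a list needed
```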
def _attend(
self,
query_vectors,
key_vectors,
value_vectors,
sorted_bucket_idx_per_hash,
attention_mask,
head_mask,
do_standard_self_attention,
do_cached_attention,
):
def _compute_attn_mask(
self, query_indices, key_indices, attention_mask, query_key_dot_shape, do_standard_self_attention
):
if attention_mask is not None:
attention_mask = attention_mask.to(torch.bool)[:, None, :]
if not do_standard_self_attention:
attention_mask = attention_mask[:, None, :]
attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,))
attention_mask = torch.gather(attention_mask, -1, key_indices)
attention_mask = attention_mask.unsqueeze(-2).expand(query_key_dot_shape)
if self.is_decoder is True:
causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device)
if attention_mask is not None:
attention_mask = causal_mask * attention_mask
else:
attention_mask = causal_mask
return attention_mask
def _get_relevant_hid_states_and_buckets(
self, query_vectors, attention_mask, num_hashes, hidden_states, past_states, past_buckets
):
def _expand_to_indices_in_relevant_chunk(self, indices, sequence_length):
start_indices_chunk = ((indices[:, -1] // self.chunk_length) - self.num_chunks_before) * self.chunk_length
total_chunk_size = self.chunk_length * (1 + self.num_chunks_before + self.num_chunks_after)
expanded_start_indices = start_indices_chunk.unsqueeze(-1).expand(indices.shape[0], total_chunk_size)
chunk_sequence_indices = expanded_start_indices + torch.arange(
total_chunk_size, device=indices.device, dtype=torch.long
).unsqueeze(0).expand(indices.shape[0], total_chunk_size)
chunk_sequence_indices = chunk_sequence_indices.flatten() % sequence_length
indices = indices.unsqueeze(1).expand((indices.shape[0], total_chunk_size, -1)).flatten(0, 1).clone()
indices[:, -1] = chunk_sequence_indices
return indices
def _len_and_dim_norm(self, vectors, sqrt_num):
"""
length and attention head size dim normalization
"""
vectors = self._len_norm(vectors)
vectors = vectors / sqrt_num
return vectors
def _len_norm(self, x, epsilon=1e-6):
"""
length normalization
"""
variance = torch.mean(x**2, -1, keepdim=True)
norm_x = x * torch.rsqrt(variance + epsilon)
return norm_x
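`_len_norm` is an RMS-style normalization: dividing by the root mean square of the last dimension gives vectors with (approximately) unit mean square, which stabilizes the shared query-key dot products. A quick standalone check (illustrative):
```python
import torch

x = torch.randn(2, 8)
normed = x * torch.rsqrt(torch.mean(x**2, -1, keepdim=True) + 1e-6)
print(normed.pow(2).mean(-1))  # close to 1.0 for every row
```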
def _gather_by_expansion(self, vectors, idxs, num_hashes):
expanded_idxs = idxs.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size)
vectors = vectors.repeat(1, 1, num_hashes, 1)
return torch.gather(vectors, 2, expanded_idxs)
class ReverseSort(Function):
"""
After chunked attention is applied which sorted clusters, original ordering has to be restored. Since customized
backward function is used for Reformer, the gradients of the output vectors have to be explicitly sorted here.
"""
@staticmethod
def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx):
with torch.no_grad():
ctx.sorted_bucket_idx = sorted_bucket_idx
expanded_undo_sort_indices = undo_sorted_bucket_idx.unsqueeze(-1).expand(out_vectors.shape)
out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices)
logits = torch.gather(logits, 2, undo_sorted_bucket_idx)
return out_vectors, logits
@staticmethod
def backward(ctx, grad_out_vectors, grad_logits):
sorted_bucket_idx = ctx.sorted_bucket_idx
expanded_sort_indices = sorted_bucket_idx.unsqueeze(-1).expand(grad_out_vectors.shape)
grad_out_vectors = torch.gather(grad_out_vectors, 2, expanded_sort_indices)
grad_logits = torch.gather(grad_logits, 2, sorted_bucket_idx)
return grad_out_vectors, grad_logits, None, None
class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
def __init__(self, config):
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.chunk_length = config.local_attn_chunk_length
self.num_chunks_before = config.local_num_chunks_before
self.num_chunks_after = config.local_num_chunks_after
self.is_decoder = config.is_decoder
self.pad_token_id = config.pad_token_id
self.attention_head_size = config.attention_head_size
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.hidden_size = config.hidden_size
self.query = nn.Linear(self.hidden_size, self.all_head_size, bias=False)
self.key = nn.Linear(self.hidden_size, self.all_head_size, bias=False)
self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False)
self.dropout = config.local_attention_probs_dropout_prob
self.register_buffer("mask_value_float16", torch.tensor(-1e4), persistent=False)
self.register_buffer("mask_value_float32", torch.tensor(-1e9), persistent=False)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
past_buckets_states=None,
use_cache=False,
output_attentions=False,
**kwargs,
):
"""
Performs the forward pass for local self-attention mechanism.
Args:
hidden_states (torch.Tensor): Input embeddings or hidden states.
attention_mask (torch.Tensor, optional): Mask indicating which elements should be attended to.
head_mask (torch.Tensor, optional): Mask indicating heads to be masked out.
past_buckets_states (torch.Tensor, optional): States from previous attention buckets.
use_cache (bool, optional): Whether to use caching mechanism.
output_attentions (bool, optional): Whether to output attention scores.
Returns:
torch.Tensor: Output embeddings or hidden states.
"""
def _compute_attn_mask(
self, query_indices, key_indices, attention_mask, query_key_dots_shape, do_standard_self_attention
):
"""
Computes the attention mask based on query and key indices.
Args:
query_indices (torch.Tensor): Indices for queries.
key_indices (torch.Tensor): Indices for keys.
attention_mask (torch.Tensor): Attention mask.
query_key_dots_shape (Tuple): Shape of the query-key dot product.
do_standard_self_attention (bool): Whether to perform standard self-attention.
Returns:
torch.Tensor: Computed attention mask.
"""
if attention_mask is not None:
attention_mask = attention_mask.to(torch.bool)[:, None, :]
if not do_standard_self_attention:
attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1)
attention_mask = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after)
attention_mask = attention_mask.unsqueeze(-2).expand(query_key_dots_shape)
if self.is_decoder is True:
causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device)
if attention_mask is not None:
attention_mask = causal_mask * attention_mask
else:
attention_mask = causal_mask
return attention_mask
@staticmethod
def _retrieve_relevant_hidden_states(previous_hidden_states, chunk_length, num_chunks_before):
start_position = ((previous_hidden_states.shape[1] // chunk_length) - num_chunks_before) * chunk_length
return previous_hidden_states[:, start_position:]
class ReformerSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
all_head_size = config.num_attention_heads * config.attention_head_size
self.dropout = config.hidden_dropout_prob
self.dense = nn.Linear(all_head_size, config.hidden_size, bias=False)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
return hidden_states
class ReformerAttention(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.layer_id = layer_id
self.attn_layers = config.attn_layers
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
if len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "lsh":
self.self_attention = LSHSelfAttention(config)
elif len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "local":
self.self_attention = LocalSelfAttention(config)
elif len(set(self.attn_layers)) == 2 and set(self.attn_layers) == {"lsh", "local"}:
if self.attn_layers[self.layer_id] == "lsh":
self.self_attention = LSHSelfAttention(config)
else:
self.self_attention = LocalSelfAttention(config)
else:
raise NotImplementedError(
f"Only attn layer types 'lsh' and 'local' exist, but got `config.attn_layers`: {self.attn_layers}. "
"Select attn layer types from ['lsh', 'local'] only."
)
self.output = ReformerSelfOutput(config)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
num_hashes=None,
past_buckets_states=None,
use_cache=False,
orig_sequence_length=None,
output_attentions=False,
buckets=None,
):
hidden_states = self.layer_norm(hidden_states)
if past_buckets_states is not None:
past_buckets_states_layer = past_buckets_states[self.layer_id]
else:
past_buckets_states_layer = None
self_attention_outputs = self.self_attention(
hidden_states=hidden_states,
head_mask=head_mask,
attention_mask=attention_mask,
num_hashes=num_hashes,
past_buckets_states=past_buckets_states_layer,
use_cache=use_cache,
output_attentions=output_attentions,
buckets=buckets,
)
if hasattr(self_attention_outputs, "buckets"):
buckets = self_attention_outputs.buckets
else:
buckets = None
if use_cache:
if past_buckets_states[self.layer_id][0] is None:
past_buckets = (
buckets[:, :, :, :orig_sequence_length]
if (buckets is not None and orig_sequence_length > 1)
else buckets
)
else:
past_buckets = torch.cat([past_buckets_states[self.layer_id][0], buckets], dim=-1)
if past_buckets_states[self.layer_id][1] is None:
past_states = hidden_states[:, :orig_sequence_length]
else:
past_states = torch.cat([past_buckets_states[self.layer_id][1], hidden_states], dim=1)
past_buckets_states[self.layer_id] = (past_buckets, past_states)
attention_output = self.output(self_attention_outputs.hidden_states)
return AttentionOutput(
hidden_states=attention_output,
attention_probs=self_attention_outputs.attention_probs,
buckets=buckets,
)
class ReformerFeedForwardDense(nn.Module):
def __init__(self, config):
super().__init__()
self.dropout = config.hidden_dropout_prob
if isinstance(config.hidden_act, str):
self.act_fn = ACT2FN[config.hidden_act]
else:
self.act_fn = config.hidden_act
self.dense = nn.Linear(config.hidden_size, config.feed_forward_size)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = self.act_fn(hidden_states)
return hidden_states
class ReformerFeedForwardOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dropout = config.hidden_dropout_prob
self.dense = nn.Linear(config.feed_forward_size, config.hidden_size)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
return hidden_states
class ChunkReformerFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dense = ReformerFeedForwardDense(config)
self.output = ReformerFeedForwardOutput(config)
def forward(self, attention_output):
return apply_chunking_to_forward(
self.forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output,
)
def forward_chunk(self, hidden_states):
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.dense(hidden_states)
return self.output(hidden_states)
class ReformerLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.attention = ReformerAttention(config, layer_id)
self.attention_seed = None
self.feed_forward_seed = None
self.feed_forward = ChunkReformerFeedForward(config)
def _init_attention_seed(self):
"""
This function sets a new seed for the attention layer to make dropout deterministic for both forward calls: 1
normal forward call and 1 forward call in backward to recalculate activations.
"""
if hasattr(torch.cuda, "default_generators") and len(torch.cuda.default_generators) > 0:
device_idx = torch.cuda.current_device()
self.attention_seed = torch.cuda.default_generators[device_idx].seed()
else:
self.attention_seed = int(torch.seed() % sys.maxsize)
torch.manual_seed(self.attention_seed)
def _init_feed_forward_seed(self):
"""
This function sets a new seed for the feed forward layer to make dropout deterministic for both forward calls:
1 normal forward call and 1 forward call in backward to recalculate activations.
"""
if hasattr(torch.cuda, "default_generators") and len(torch.cuda.default_generators) > 0:
device_idx = torch.cuda.current_device()
self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed()
else:
self.feed_forward_seed = int(torch.seed() % sys.maxsize)
torch.manual_seed(self.feed_forward_seed)
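The reason for reseeding is that the reversible scheme below recomputes the forward pass during backpropagation, so dropout has to draw the same mask both times. The mechanism in isolation (illustrative):
```python
import torch

seed = int(torch.seed() % 2**31)
torch.manual_seed(seed)
first = torch.nn.functional.dropout(torch.ones(4), p=0.5, training=True)
torch.manual_seed(seed)  # replay the seed before recomputing
second = torch.nn.functional.dropout(torch.ones(4), p=0.5, training=True)
assert torch.equal(first, second)  # identical dropout masks
```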
def forward(
self,
prev_attn_output,
hidden_states,
attention_mask=None,
head_mask=None,
num_hashes=None,
past_buckets_states=None,
use_cache=False,
orig_sequence_length=None,
output_attentions=False,
):
with torch.no_grad():
if self.training:
self._init_attention_seed()
attn_outputs = self.attention(
hidden_states=hidden_states,
head_mask=head_mask,
attention_mask=attention_mask,
num_hashes=num_hashes,
past_buckets_states=past_buckets_states,
use_cache=use_cache,
orig_sequence_length=orig_sequence_length,
output_attentions=output_attentions,
)
attn_output = attn_outputs.hidden_states
attn_output = prev_attn_output + attn_output
del prev_attn_output
if self.training:
self._init_feed_forward_seed()
hidden_states = hidden_states + self.feed_forward(attn_output)
return ReformerOutput(
attn_output=attn_output,
hidden_states=hidden_states,
attention_probs=attn_outputs.attention_probs,
buckets=attn_outputs.buckets,
)
def backward_pass(
self,
next_attn_output,
hidden_states,
grad_attn_output,
grad_hidden_states,
attention_mask=None,
head_mask=None,
buckets=None,
):
assert self.training, (
"If you want to train `ReformerModel` and its variations, make sure to use `model.train()` to put the"
" model into training mode."
)
with torch.enable_grad():
next_attn_output.requires_grad = True
torch.manual_seed(self.feed_forward_seed)
res_hidden_states = self.feed_forward(next_attn_output)
res_hidden_states.backward(grad_hidden_states, retain_graph=True)
with torch.no_grad():
hidden_states = hidden_states - res_hidden_states
del res_hidden_states
grad_attn_output = grad_attn_output + next_attn_output.grad
next_attn_output.grad = None
with torch.enable_grad():
hidden_states.requires_grad = True
torch.manual_seed(self.attention_seed)
output = self.attention(
hidden_states=hidden_states,
head_mask=head_mask,
attention_mask=attention_mask,
buckets=buckets,
).hidden_states
output.backward(grad_attn_output, retain_graph=True)
with torch.no_grad():
attn_output = next_attn_output - output
del output, next_attn_output
grad_hidden_states = grad_hidden_states + hidden_states.grad
hidden_states.grad = None
hidden_states = hidden_states.detach()
return ReformerBackwardOutput(
attn_output=attn_output,
hidden_states=hidden_states,
grad_attn_output=grad_attn_output,
grad_hidden_states=grad_hidden_states,
)
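Putting `forward` and `backward_pass` together: the layer computes `attn_out = prev_attn + Attn(hidden)` and `hidden_out = hidden + FF(attn_out)`, which is invertible, so activations can be reconstructed during backward instead of stored. The algebra in a self-contained sketch (illustrative; two toy linear maps stand in for attention and feed-forward):
```python
import torch

f = torch.nn.Linear(4, 4)  # stands in for the attention block
g = torch.nn.Linear(4, 4)  # stands in for the feed-forward block

x1, x2 = torch.randn(4), torch.randn(4)
# forward, as in ReformerLayer.forward
y1 = x1 + f(x2)
y2 = x2 + g(y1)
# inversion, as in ReformerLayer.backward_pass: recover the inputs from the outputs alone
x2_rec = y2 - g(y1)
x1_rec = y1 - f(x2_rec)
assert torch.allclose(x1, x1_rec, atol=1e-6) and torch.allclose(x2, x2_rec, atol=1e-6)
```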
"""
针对可逆函数的自定义反向传播函数,以防止 PyTorch 执行通常的反向传播。
通过这种方式确保在前向传播期间不保存内存消耗昂贵的激活值。
本函数的实现受到 https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py 的启发。
"""
@staticmethod
def forward(
ctx,
hidden_states,
layers,
attention_mask,
head_mask,
num_hashes,
all_hidden_states,
all_attentions,
past_buckets_states,
use_cache,
orig_sequence_length,
output_hidden_states,
output_attentions,
):
all_buckets = ()
hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1)
for layer_id, (layer, layer_head_mask) in enumerate(zip(layers, head_mask)):
if output_hidden_states is True:
all_hidden_states.append(hidden_states)
layer_outputs = layer(
prev_attn_output=attn_output,
hidden_states=hidden_states,
attention_mask=attention_mask,
head_mask=layer_head_mask,
num_hashes=num_hashes,
past_buckets_states=past_buckets_states,
use_cache=use_cache,
orig_sequence_length=orig_sequence_length,
output_attentions=output_attentions,
)
attn_output = layer_outputs.attn_output
hidden_states = layer_outputs.hidden_states
all_buckets = all_buckets + (layer_outputs.buckets,)
if output_attentions:
all_attentions.append(layer_outputs.attention_probs)
if output_hidden_states is True:
all_hidden_states.append(hidden_states)
ctx.save_for_backward(attn_output.detach(), hidden_states.detach())
ctx.layers = layers
ctx.all_buckets = all_buckets
ctx.head_mask = head_mask
ctx.attention_mask = attention_mask
return torch.cat([attn_output, hidden_states], dim=-1)
@staticmethod
def backward(ctx, grad_hidden_states):
grad_attn_output, grad_hidden_states = torch.chunk(grad_hidden_states, 2, dim=-1)
attn_output, hidden_states = ctx.saved_tensors
output = ReformerBackwardOutput(
attn_output=attn_output,
hidden_states=hidden_states,
grad_attn_output=grad_attn_output,
grad_hidden_states=grad_hidden_states,
)
del grad_attn_output, grad_hidden_states, attn_output, hidden_states
layers = ctx.layers
all_buckets = ctx.all_buckets
head_mask = ctx.head_mask
attention_mask = ctx.attention_mask
for idx, layer in enumerate(layers[::-1]):
buckets = all_buckets[-1]
all_buckets = all_buckets[:-1]
output = layer.backward_pass(
next_attn_output=output.attn_output,
hidden_states=output.hidden_states,
grad_attn_output=output.grad_attn_output,
grad_hidden_states=output.grad_hidden_states,
head_mask=head_mask[len(layers) - idx - 1],
attention_mask=attention_mask,
buckets=buckets,
)
assert all_buckets == (), "buckets have to be empty after backpropagation"
grad_hidden_states = torch.cat([output.grad_attn_output, output.grad_hidden_states], dim=-1)
return grad_hidden_states, None, None, None, None, None, None, None, None, None, None, None
class ReformerEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.dropout = config.hidden_dropout_prob
self.layers = nn.ModuleList([ReformerLayer(config, i) for i in range(config.num_hidden_layers)])
self.layer_norm = nn.LayerNorm(2 * config.hidden_size, eps=config.layer_norm_eps)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
num_hashes=None,
past_buckets_states=None,
use_cache=False,
orig_sequence_length=None,
output_hidden_states=False,
output_attentions=False,
):
all_hidden_states = []
all_attentions = []
if past_buckets_states is None:
past_buckets_states = [((None), (None)) for i in range(len(self.layers))]
hidden_states = torch.cat([hidden_states, hidden_states], dim=-1)
hidden_states = _ReversibleFunction.apply(
hidden_states,
self.layers,
attention_mask,
head_mask,
num_hashes,
all_hidden_states,
all_attentions,
past_buckets_states,
use_cache,
orig_sequence_length,
output_hidden_states,
output_attentions,
)
hidden_states = self.layer_norm(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
return ReformerEncoderOutput(
hidden_states=hidden_states,
all_hidden_states=all_hidden_states,
all_attentions=all_attentions,
past_buckets_states=past_buckets_states,
)
class ReformerOnlyLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.seq_len_dim = 1
self.chunk_size_lm_head = config.chunk_size_lm_head
self.decoder = nn.Linear(2 * config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
self.decoder.bias = self.bias
def forward(self, hidden_states):
return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
def forward_chunk(self, hidden_states):
hidden_states = self.decoder(hidden_states)
return hidden_states
def _tie_weights(self):
self.bias = self.decoder.bias
class ReformerPreTrainedModel(PreTrainedModel):
config_class = ReformerConfig
base_model_prefix = "reformer"
@property
def dummy_inputs(self):
input_ids = torch.tensor(DUMMY_INPUTS)
input_mask = torch.tensor(DUMMY_MASK)
dummy_inputs = {
"input_ids": input_ids,
"attention_mask": input_mask,
}
return dummy_inputs
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, AxialPositionEmbeddings):
for weight in module.weights:
nn.init.normal_(weight, std=self.config.axial_norm_std)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
@dataclass
class ReformerModelOutput(ModelOutput):
"""
Output type of [`ReformerModel`].
"""
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_predict, hidden_size)`):
模型最后一层的隐藏状态序列。
`num_predict` 对应于 `target_mapping.shape[1]`。如果 `target_mapping` 是 `None`,那么 `num_predict` 对应于 `sequence_length`。
past_buckets_states (`List[Tuple(torch.LongTensor, torch.FloatTensor)]`, *optional*, 在 `use_cache=True` 或 `config.use_cache=True` 时返回):
包含预先计算的桶和隐藏状态的列表,用于加速顺序解码。
每个元素是一个元组,第一个元素是形状为 `(batch_size, num_heads, num_hashes, sequence_length)` 的先前*桶*,
第二个元素是形状为 `(batch_size, sequence_length, hidden_size)` 的先前*隐藏状态*。
hidden_states (`tuple(torch.FloatTensor)`, *optional*, 在 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回):
元组中包含了每层的输出 (`embeddings` 输出的一个和每层输出的一个) 的 `torch.FloatTensor`,形状为 `(batch_size, sequence_length, hidden_size)`。
模型每一层的隐藏状态,加上初始嵌入的输出。
attentions (`tuple(torch.FloatTensor)`, *optional*, 在 `output_attentions=True` 或 `config.output_attentions=True` 时返回):
元组中包含了每层的注意力权重 `torch.FloatTensor` (每层一个),形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
经过注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
@dataclass
class ReformerModelWithLMHeadOutput(ModelOutput):
"""
Output type of [`ReformerModelWithLMHead`].
Args:
loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided)
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, num_predict, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
`num_predict` corresponds to `target_mapping.shape[1]`. If `target_mapping` is `None`, then `num_predict`
corresponds to `sequence_length`.
past_buckets_states (`List[Tuple(torch.LongTensor, torch.FloatTensor)]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
List of `Tuple(torch.LongTensor, torch.FloatTensor)` of length `config.n_layers`, with the first element
being the previous *buckets* of shape `(batch_size, num_heads, num_hashes, sequence_length)` and the
second being the previous *hidden_states* of shape `(batch_size, sequence_length, hidden_size)`.
Contains precomputed buckets and hidden-states that can be used (see `past_buckets_states` input) to speed
up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer)
of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_buckets_states: Optional[List[Tuple[torch.LongTensor, torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
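To make the documented shapes concrete, here is a minimal sketch (not part of the original file; all tensors are random stand-ins with made-up sizes) that builds this output type by hand and reads it back both by attribute and by index:

```python
import torch
from transformers.models.reformer.modeling_reformer import ReformerModelWithLMHeadOutput

batch_size, seq_len, vocab_size, hidden_size = 2, 64, 320, 256
num_heads, num_hashes = 2, 1

out = ReformerModelWithLMHeadOutput(
    loss=torch.tensor(3.2),
    logits=torch.randn(batch_size, seq_len, vocab_size),
    past_buckets_states=[(
        torch.zeros(batch_size, num_heads, num_hashes, seq_len, dtype=torch.long),  # buckets
        torch.randn(batch_size, seq_len, hidden_size),                              # hidden states
    )],
)
assert out.logits.shape == (batch_size, seq_len, vocab_size)
# ModelOutput also allows tuple-style indexing over its non-None fields
assert out[0] is out.loss
```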
REFORMER_START_DOCSTRING = r"""
Reformer was proposed in [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev,
Łukasz Kaiser, Anselm Levskaya.
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
etc.)
"""
REFORMER_INPUTS_DOCSTRING = r"""
"""
# Decorator attaches a docstring describing the bare Reformer model that outputs raw hidden states without any specific head
@add_start_docstrings(
"The bare Reformer Model transformer outputting raw hidden-states without any specific head on top.",
REFORMER_START_DOCSTRING,
)
# Define the ReformerModel class, inheriting from ReformerPreTrainedModel
class ReformerModel(ReformerPreTrainedModel):
def __init__(self, config):
# Call the parent class constructor
super().__init__(config)
# Keep the configuration on self.config
self.config = config
# Sanity-check that at least one attention layer is configured
assert (
self.config.num_hidden_layers > 0
), "`config.attn_layers` is empty. Select at least one attn layer from ['lsh', 'local']"
# Initialize the word embeddings and the encoder
self.embeddings = ReformerEmbeddings(config)
self.encoder = ReformerEncoder(config)
# Initialize weights and apply final processing
self.post_init()
# Return the input embeddings
def get_input_embeddings(self):
return self.embeddings.word_embeddings
# Set the input embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
# Prune attention heads of the model
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
# Decorator adds the standard inputs docstring to the forward method
@add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
# Add a code-sample docstring
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=ReformerModelOutput,
config_class=_CONFIG_FOR_DOC,
)
# Define the forward pass (its body is omitted in this excerpt)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
num_hashes: Optional[int] = None,
past_buckets_states: Optional[List[Tuple[torch.Tensor]]] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
# Helper that pads the inputs so the sequence length is a multiple of the chunk length
def _pad_to_mult_of_chunk_length(
self,
input_ids,
inputs_embeds=None,
attention_mask=None,
position_ids=None,
input_shape=None,
padding_length=None,
padded_seq_length=None,
device=None,
):
# Warn once that the inputs are being padded because the sequence length is not a multiple of `config.chunk_length`
logger.warning_once(
f"Input ids are automatically padded from {input_shape[-1]} to {input_shape[-1] + padding_length} to be a "
f"multiple of `config.chunk_length`: {padded_seq_length}"
)
# Build the padding tensor of input ids, filled with pad_token_id
padded_input_ids = torch.full(
(input_shape[0], padding_length),
self.config.pad_token_id,
device=device,
dtype=torch.long,
)
# Extend `attention_mask`
if attention_mask is not None:
# Mask out the padded positions: append zeros of shape (batch_size, padding_length)
pad_attention_mask = torch.zeros(input_shape[0], padding_length, device=device, dtype=attention_mask.dtype)
attention_mask = torch.cat([attention_mask, pad_attention_mask], dim=-1)
else:
# No mask was given: attend to all original positions and none of the padded ones
attention_mask = torch.cat(
[
torch.ones(input_shape, device=device, dtype=torch.bool),
torch.zeros((input_shape[0], padding_length), device=device, dtype=torch.bool),
],
dim=-1,
)
# If input ids are given, append the padded ids
if input_ids is not None:
input_ids = torch.cat([input_ids, padded_input_ids], dim=-1)
# If position ids are given, extend them over the padded positions
if position_ids is not None:
# New positions run from the original sequence length up to the padded length
padded_position_ids = torch.arange(input_shape[-1], padded_seq_length, dtype=torch.long, device=device)
padded_position_ids = padded_position_ids.unsqueeze(0).expand(input_shape[0], padding_length)
position_ids = torch.cat([position_ids, padded_position_ids], dim=-1)
# Refresh the input shape now that the ids are padded
input_shape = input_ids.size()
# If input embeddings are given, embed the padded ids (with the padded position ids) and append them
if inputs_embeds is not None:
padded_inputs_embeds = self.embeddings(padded_input_ids, position_ids)
inputs_embeds = torch.cat([inputs_embeds, padded_inputs_embeds], dim=-2)
input_shape = inputs_embeds.size()
# Return the padded input_ids, inputs_embeds, attention_mask, position_ids and the updated input_shape
return input_ids, inputs_embeds, attention_mask, position_ids, input_shape
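The padding arithmetic that feeds this helper is easy to check in isolation. A standalone sketch (the numbers and the `chunk_length`/`pad_token_id` values below are made up, standing in for the config fields):

```python
import torch

chunk_length, pad_token_id = 64, 0           # stand-ins for config.chunk_length / config.pad_token_id
input_ids = torch.randint(1, 100, (2, 100))  # batch of 2 sequences of length 100

seq_length = input_ids.shape[-1]
padding_length = (chunk_length - seq_length % chunk_length) % chunk_length  # 28 here
padded_seq_length = seq_length + padding_length                             # 128 == 2 * chunk_length

padding = torch.full((input_ids.shape[0], padding_length), pad_token_id, dtype=torch.long)
padded_input_ids = torch.cat([input_ids, padding], dim=-1)
assert padded_input_ids.shape[-1] % chunk_length == 0
```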
# Decorator adds a docstring describing the Reformer model with a language-modeling head on top
@add_start_docstrings("""Reformer Model with a `language modeling` head on top.""", REFORMER_START_DOCSTRING)
class ReformerModelWithLMHead(ReformerPreTrainedModel):
# Keys whose weights are tied to the language-modeling head's decoder weight and bias
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
# Call the parent constructor with the configuration
super().__init__(config)
# The model must be configured as a decoder
assert config.is_decoder, "If you want to use `ReformerModelWithLMHead` make sure that `is_decoder=True`."
# With a causal mask, "local" attention layers must not look at future chunks
assert "local" not in self.config.attn_layers or config.local_num_chunks_after == 0, (
"If causal mask is enabled, make sure that `config.local_num_chunks_after` is set to 0 and not"
f" {config.local_num_chunks_after}."
)
# Likewise, "lsh" attention layers must not look at future chunks
assert "lsh" not in self.config.attn_layers or config.lsh_num_chunks_after == 0, (
"If causal mask is enabled, make sure that `config.lsh_num_chunks_after` is set to 0 and not"
f" {config.lsh_num_chunks_after}."
)
# Instantiate the Reformer backbone
self.reformer = ReformerModel(config)
# Instantiate the language-modeling head
self.lm_head = ReformerOnlyLMHead(config)
# Initialize weights and apply final processing
self.post_init()
# Return the output embeddings of the language-modeling head
def get_output_embeddings(self):
return self.lm_head.decoder
# Replace the output embeddings of the language-modeling head
def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
# Forward pass, with the standard input and code-sample docstrings attached
@add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
num_hashes: Optional[int] = None,
past_buckets_states: Optional[List[Tuple[torch.Tensor]]] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
) -> Union[Tuple, CausalLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the language modeling loss (next-token prediction). Indices should be in `[-100, 0,
..., config.vocab_size - 1]`. All labels set to `-100` are ignored (masked); the loss is only computed for
labels in `[0, ..., config.vocab_size - 1]`.
"""
# Fall back to the config's use_return_dict when return_dict is not given
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Run the Reformer backbone with all relevant arguments
reformer_outputs = self.reformer(
input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
num_hashes=num_hashes,
past_buckets_states=past_buckets_states,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
)
# Take the sequence output from the Reformer outputs
sequence_output = reformer_outputs[0]
# Project the sequence output to vocabulary logits with the LM head
logits = self.lm_head(sequence_output)
# The loss stays None unless labels are provided
loss = None
if labels is not None:
# Shift so that tokens < n predict n: drop the last logit and the first label
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Compute the cross-entropy loss over the flattened tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
# Without return_dict, return a plain tuple (loss first, when present)
if not return_dict:
output = (logits,) + reformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
# With return_dict, wrap everything in a ReformerModelWithLMHeadOutput
return ReformerModelWithLMHeadOutput(
loss=loss,
logits=logits,
past_buckets_states=reformer_outputs.past_buckets_states,
hidden_states=reformer_outputs.hidden_states,
attentions=reformer_outputs.attentions,
)
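The shift above is the standard causal-LM alignment: the logits at position i are scored against the token at position i + 1. A toy-sized sketch of just that step (shapes and names are illustrative only):

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
logits = torch.randn(2, 5, vocab_size)           # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 5))    # (batch, seq_len)

shift_logits = logits[..., :-1, :].contiguous()  # positions 0..3 predict ...
shift_labels = labels[..., 1:].contiguous()      # ... tokens 1..4
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(float(loss))  # scalar cross-entropy averaged over 2 * 4 positions
```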
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, use_cache=None, num_hashes=None, **kwargs
):
# With cached past states, only the last token needs to be fed in
if past_key_values is not None:
input_ids = input_ids[:, -1:]
# Assemble the inputs dictionary used by generate()
inputs_dict = {
"input_ids": input_ids,
"past_buckets_states": past_key_values,
"use_cache": use_cache,
"num_hashes": num_hashes,
}
return inputs_dict
# Reorder the cached buckets and hidden states for beam search
def _reorder_cache(self, past_key_values, beam_idx):
# Collect the reordered (buckets, hidden_states) tuple for each layer
reord_past_buckets_states = []
for layer_past in past_key_values:
# Reorder the buckets along the batch dimension if they exist
if layer_past[0] is not None:
reord_buckets = layer_past[0].index_select(0, beam_idx.to(layer_past[0].device))
else:
reord_buckets = None
# Reorder the hidden states along the batch dimension
reord_hidden_states = layer_past[1].index_select(0, beam_idx.to(layer_past[1].device))
reord_past_buckets_states.append((reord_buckets, reord_hidden_states))
return reord_past_buckets_states
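End to end, generation with this class looks like the sketch below; `generate` calls `prepare_inputs_for_generation` at every step and `_reorder_cache` during beam search. (A plausible usage sketch: the checkpoint is the public Reformer LM, and the sampling settings are arbitrary.)

```python
from transformers import ReformerModelWithLMHead, ReformerTokenizer

tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment")

input_ids = tokenizer("A few months later", return_tensors="pt").input_ids
# With use_cache=True (the default), each step feeds only the last token plus past_buckets_states.
output_ids = model.generate(input_ids, do_sample=True, max_length=100)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```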
@add_start_docstrings("""Reformer Model with a `language modeling` head on top.""", REFORMER_START_DOCSTRING)
class ReformerForMaskedLM(ReformerPreTrainedModel):
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
super().__init__(config)
# Masked LM uses bi-directional self-attention, so the model must not be a decoder
assert not config.is_decoder, (
"If you want to use `ReformerForMaskedLM` make sure `config.is_decoder=False` for bi-directional"
" self-attention."
)
# Instantiate the Reformer backbone and the LM head
self.reformer = ReformerModel(config)
self.lm_head = ReformerOnlyLMHead(config)
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
# Return the decoder of the language-modeling head
return self.lm_head.decoder
def set_output_embeddings(self, new_embeddings):
# Replace the decoder of the language-modeling head
self.lm_head.decoder = new_embeddings
@add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
num_hashes: Optional[int] = None,
labels: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
@add_start_docstrings(
"""
Reformer Model transformer with a sequence classification/regression head on top (a linear layer on top of the
pooled output) e.g. for GLUE tasks.
""",
REFORMER_START_DOCSTRING,
)
class ReformerForSequenceClassification(ReformerPreTrainedModel):
def __init__(self, config):
super().__init__(config)
# Store the number of labels and the configuration
self.num_labels = config.num_labels
self.config = config
# Instantiate the Reformer backbone and the classification head
self.reformer = ReformerModel(config)
self.classifier = ReformerClassificationHead(config)
if config.is_decoder is True:
# Sequence classification generally wants bi-directional attention, so warn about causal masking
logger.warning("You might want to disable causal masking for sequence classification")
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
num_hashes: Optional[int] = None,
labels: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
class ReformerClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(self, config):
super().__init__()
# The dense layer takes the 2 * hidden_size Reformer output and projects it back to hidden_size
self.dense = nn.Linear(2 * config.hidden_size, config.hidden_size)
# Use the classifier dropout if configured, otherwise fall back to the hidden dropout probability
classifier_dropout = (
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.dropout = nn.Dropout(classifier_dropout)
# Final linear projection from hidden_size to the number of labels
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, hidden_states, **kwargs):
hidden_states = hidden_states[:, 0, :]  # take <s> token (equiv. to [CLS])
# Dropout, dense projection, tanh non-linearity, dropout again, then project to label logits
hidden_states = self.dropout(hidden_states)
hidden_states = self.dense(hidden_states)
hidden_states = torch.tanh(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.out_proj(hidden_states)
return hidden_states
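The head runs fine on its own, which makes the 2 * hidden_size input requirement easy to verify. A minimal sketch with a stand-in config (SimpleNamespace just mimics the attributes the head reads; sizes are arbitrary):

```python
import torch
from types import SimpleNamespace
from transformers.models.reformer.modeling_reformer import ReformerClassificationHead

config = SimpleNamespace(hidden_size=256, num_labels=3, classifier_dropout=None, hidden_dropout_prob=0.1)
head = ReformerClassificationHead(config)

# Reformer hidden states are 2 * hidden_size wide because of the reversible residual streams
hidden_states = torch.randn(4, 128, 2 * config.hidden_size)
logits = head(hidden_states)
assert logits.shape == (4, 3)  # one score per label, computed from the first (<s>) position
```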
@add_start_docstrings(
"""
Reformer Model with a span classification head on top for extractive question-answering tasks like SQuAD / TriviaQA
(a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
"""
REFORMER_START_DOCSTRING,
)
class ReformerForQuestionAnswering(ReformerPreTrainedModel):
def __init__(self, config):
super().__init__(config)
# Store the number of labels
self.num_labels = config.num_labels
# The Reformer backbone; its hidden states are 2 * hidden_size wide because of the reversible residual layers
self.reformer = ReformerModel(config)
# QA output layer mapping the 2 * hidden_size states to num_labels (start/end) logits
self.qa_outputs = nn.Linear(2 * config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
num_hashes: Optional[int] = None,
start_positions: Optional[torch.Tensor] = None,
end_positions: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
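The forward body is elided above; under the standard extractive-QA pattern (an assumption about the elided code, shown as a standalone sketch), `qa_outputs` produces two logits per position that are split into span-start and span-end scores:

```python
import torch
import torch.nn as nn

hidden_size = 256
qa_outputs = nn.Linear(2 * hidden_size, 2)              # mirrors self.qa_outputs with num_labels == 2

sequence_output = torch.randn(4, 128, 2 * hidden_size)  # stand-in for the Reformer output
logits = qa_outputs(sequence_output)                    # (batch, seq_len, 2)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)                 # (batch, seq_len): score for span start
end_logits = end_logits.squeeze(-1)                     # (batch, seq_len): score for span end
```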
.\models\reformer\tokenization_reformer.py
""" Tokenization class for model Reformer."""
import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple
import sentencepiece as spm
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
logger = logging.get_logger(__name__)
SPIECE_UNDERLINE = "▁"
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"google/reformer-crime-and-punishment": (
"https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/spiece.model"
)
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"google/reformer-crime-and-punishment": 524288,
}
class ReformerTokenizer(PreTrainedTokenizer):
"""
Construct a Reformer tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece) .
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
this superclass for more information regarding those methods.
"""
"""
Args:
vocab_file (`str`):
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
contains the vocabulary necessary to instantiate a tokenizer.
eos_token (`str`, *optional*, defaults to `"</s>"`):
The end of sequence token.
<Tip>
When building a sequence using special tokens, this is not the token that is used for the end of sequence.
The token used is the `sep_token`.
</Tip>
unk_token (`str`, *optional*, defaults to `"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
additional_special_tokens (`List[str]`, *optional*, defaults to `[]`):
Additional special tokens used by the tokenizer.
sp_model_kwargs (`dict`, *optional*):
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
to set:
- `enable_sampling`: Enable subword regularization.
- `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
- `nbest_size = {0,1}`: No sampling is performed.
- `nbest_size > 1`: samples from the nbest_size results.
- `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
using forward-filtering-and-backward-sampling algorithm.
- `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
BPE-dropout.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file,
eos_token="</s>",
unk_token="<unk>",
additional_special_tokens=[],
sp_model_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> None:
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
self.vocab_file = vocab_file
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(vocab_file)
super().__init__(
eos_token=eos_token,
unk_token=unk_token,
additional_special_tokens=additional_special_tokens,
sp_model_kwargs=self.sp_model_kwargs,
**kwargs,
)
@property
def vocab_size(self):
return self.sp_model.get_piece_size()
def get_vocab(self) -> Dict[str, int]:
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
def __getstate__(self):
state = self.__dict__.copy()
state["sp_model"] = None
return state
def __setstate__(self, d):
self.__dict__ = d
if not hasattr(self, "sp_model_kwargs"):
self.sp_model_kwargs = {}
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(self.vocab_file)
def _tokenize(self, text: str) -> List[str]:
"""接受一个字符串作为输入,并返回一个由单词/子词组成的列表(tokens)"""
return self.sp_model.encode(text, out_type=str)
def _convert_token_to_id(self, token):
"""将一个token(字符串)转换为其对应的ID,使用词汇表"""
return self.sp_model.piece_to_id(token)
def _convert_id_to_token(self, index):
"""将一个索引(整数)转换为其对应的token(字符串),使用词汇表"""
if index < self.sp_model.get_piece_size():
token = self.sp_model.IdToPiece(index)
return token
def convert_tokens_to_string(self, tokens):
"""将一系列token(字符串)转换为单个字符串"""
current_sub_tokens = []
out_string = ""
for token in tokens:
if token in self.all_special_tokens:
out_string += self.sp_model.decode(current_sub_tokens) + token
current_sub_tokens = []
else:
current_sub_tokens.append(token)
out_string += self.sp_model.decode(current_sub_tokens)
return out_string.strip()
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)
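A short round trip ties the tokenizer methods above together; the `sp_model_kwargs` shown switch on the subword regularization described in the class docstring (they are optional, and the sampled tokenization will differ from run to run):

```python
from transformers import ReformerTokenizer

tokenizer = ReformerTokenizer.from_pretrained(
    "google/reformer-crime-and-punishment",
    sp_model_kwargs={"enable_sampling": True, "nbest_size": -1, "alpha": 0.1},
)

tokens = tokenizer.tokenize("Crime and Punishment")  # -> _tokenize, i.e. sp_model.encode(..., out_type=str)
ids = tokenizer.convert_tokens_to_ids(tokens)        # -> _convert_token_to_id per token
text = tokenizer.convert_tokens_to_string(tokens)    # -> sp_model.decode over non-special pieces
print(tokens, ids, text)
```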