Transformers Source Code Analysis (74)
.\models\mega\__init__.py
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
_import_structure = {
"configuration_mega": ["MEGA_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegaConfig", "MegaOnnxConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_mega"] = [
"MEGA_PRETRAINED_MODEL_ARCHIVE_LIST",
"MegaForCausalLM",
"MegaForMaskedLM",
"MegaForMultipleChoice",
"MegaForQuestionAnswering",
"MegaForSequenceClassification",
"MegaForTokenClassification",
"MegaModel",
"MegaPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_mega import MEGA_PRETRAINED_CONFIG_ARCHIVE_MAP, MegaConfig, MegaOnnxConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_mega import (
MEGA_PRETRAINED_MODEL_ARCHIVE_LIST,
MegaForCausalLM,
MegaForMaskedLM,
MegaForMultipleChoice,
MegaForQuestionAnswering,
MegaForSequenceClassification,
MegaForTokenClassification,
MegaModel,
MegaPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
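# A minimal sketch (assuming a transformers version that still ships MEGA) of what the
# lazy structure above buys: importing the package resolves only the config symbols
# eagerly, while the torch-backed modeling classes are looked up on first access.
from transformers.models.mega import MegaConfig

config = MegaConfig()  # configuration classes load without touching modeling_mega / torch
print(config.model_type)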
.\models\megatron_bert\configuration_megatron_bert.py
""" MEGATRON_BERT 模型配置"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
}
class MegatronBertConfig(PretrainedConfig):
r"""
    This is the configuration class to store the configuration of a [`MegatronBertModel`]. It is used to instantiate a
    MEGATRON_BERT model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a configuration similar to that of the MEGATRON_BERT
    [nvidia/megatron-bert-uncased-345m](https://huggingface.co/nvidia/megatron-bert-uncased-345m) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
Examples:
```
>>> from transformers import MegatronBertConfig, MegatronBertModel
    >>> # Initializing a MEGATRON_BERT google-bert/bert-base-uncased style configuration
    >>> configuration = MegatronBertConfig()

    >>> # Initializing a model (with random weights) from the google-bert/bert-base-uncased style configuration
    >>> model = MegatronBertModel(configuration)

    >>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "megatron-bert"
def __init__(
self,
vocab_size=29056,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
position_embedding_type="absolute",
use_cache=True,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
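# A short usage sketch: override a few hyperparameters (the rest keep the 345M defaults
# above) and round-trip the config through save_pretrained. The local path is hypothetical.
from transformers import MegatronBertConfig

config = MegatronBertConfig(num_hidden_layers=2, hidden_size=128, num_attention_heads=4)
config.save_pretrained("./tiny-megatron-bert")  # writes config.json
reloaded = MegatronBertConfig.from_pretrained("./tiny-megatron-bert")
assert reloaded.hidden_size == 128 and reloaded.model_type == "megatron-bert"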
.\models\megatron_bert\convert_megatron_bert_checkpoint.py
import argparse
import os
import re
import zipfile
import torch
from transformers import MegatronBertConfig
def recursive_print(name, val, spaces=0):
if name is None:
msg = None
else:
fmt = "." * max(0, spaces - 2) + "# {:" + str(50 - spaces) + "s}"
msg = fmt.format(name)
if isinstance(val, dict):
if msg is not None:
print(msg)
for k in val.keys():
recursive_print(k, val[k], spaces + 2)
elif isinstance(val, torch.Tensor):
print(msg, ":", val.size())
else:
print(msg, ":", val)
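# Tiny demo of the pretty-printer above on a nested, state-dict-like mapping: dict keys
# print as headers, tensors print with their sizes, and anything else prints verbatim.
import torch

recursive_print(None, {"encoder": {"weight": torch.zeros(2, 3)}, "step": 10})
# prints "# encoder", an indented "# weight : torch.Size([2, 3])", and "# step : 10"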
def fix_query_key_value_ordering(param, checkpoint_version, num_splits, num_heads, hidden_size):
input_shape = param.size()
if checkpoint_version == 1.0:
saved_shape = (num_heads, hidden_size, num_splits) + input_shape[1:]
param = param.view(*saved_shape)
param = param.transpose(0, 2)
param = param.transpose(1, 2).contiguous()
elif checkpoint_version >= 2.0:
saved_shape = (num_heads, num_splits, hidden_size) + input_shape[1:]
param = param.view(*saved_shape)
param = param.transpose(0, 1).contiguous()
param = param.view(*input_shape)
return param
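# A toy check of the re-ordering above, with hypothetical sizes: 2 heads, per-head dim 3,
# and 3 splits for the fused Q/K/V projection. Note that the `hidden_size` argument is the
# per-head size, as in the full conversion script. The shape is preserved; only the row
# grouping changes so that rows come out as [all Q | all K | all V].
import torch

num_heads, head_dim, num_splits = 2, 3, 3
fused = torch.arange(float(num_splits * num_heads * head_dim * 4)).view(num_splits * num_heads * head_dim, 4)
fixed = fix_query_key_value_ordering(fused, checkpoint_version=2.0, num_splits=num_splits,
                                     num_heads=num_heads, hidden_size=head_dim)
print(fixed.shape)  # torch.Size([18, 4])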
def convert_megatron_checkpoint(args, input_state_dict, config):
output_state_dict = {}
ds_args = input_state_dict.get("args", None)
if ds_args is not None:
config.tokenizer_type = ds_args.tokenizer_type
config.vocab_size = ds_args.padded_vocab_size
config.max_position_embeddings = ds_args.max_position_embeddings
config.hidden_size = ds_args.hidden_size
config.num_hidden_layers = ds_args.num_layers
config.num_attention_heads = ds_args.num_attention_heads
config.intermediate_size = ds_args.ffn_hidden_size if "ffn_hidden_size" in ds_args else 4 * ds_args.hidden_size
heads = config.num_attention_heads
hidden_size_per_head = config.hidden_size // heads
if "checkpoint_version" in input_state_dict.keys():
checkpoint_version = input_state_dict["checkpoint_version"]
else:
checkpoint_version = 0.0
model = input_state_dict["model"]
lm = model["language_model"]
embeddings = lm["embedding"]
word_embeddings = embeddings["word_embeddings"]["weight"]
word_embeddings = word_embeddings[: config.vocab_size, :]
output_state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings
pos_embeddings = embeddings["position_embeddings"]["weight"]
assert pos_embeddings.size(0) == config.max_position_embeddings and pos_embeddings.size(1) == config.hidden_size
output_state_dict["bert.embeddings.position_embeddings.weight"] = pos_embeddings
tokentype_embeddings = embeddings["tokentype_embeddings"]["weight"]
output_state_dict["bert.embeddings.token_type_embeddings.weight"] = tokentype_embeddings
transformer = lm["transformer"] if "transformer" in lm.keys() else lm["encoder"]
layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
megatron_to_transformers = {
"attention.dense": ".attention.output.dense.",
"self_attention.dense": ".attention.output.dense.",
"mlp.dense_h_to_4h": ".intermediate.dense.",
"mlp.dense_4h_to_h": ".output.dense.",
}
attention_qkv_weight = None
output_state_dict["bert.encoder.ln.weight"] = transformer["final_layernorm.weight"]
output_state_dict["bert.encoder.ln.bias"] = transformer["final_layernorm.bias"]
pooler = lm["pooler"]
output_state_dict["bert.pooler.dense.weight"] = pooler["dense.weight"]
output_state_dict["bert.pooler.dense.bias"] = pooler["dense.bias"]
output_state_dict["cls.predictions.transform.dense.weight"] = lm_head["dense.weight"]
output_state_dict["cls.predictions.transform.dense.bias"] = lm_head["dense.bias"]
output_state_dict["cls.predictions.transform.LayerNorm.weight"] = lm_head["layernorm.weight"]
output_state_dict["cls.predictions.transform.LayerNorm.bias"] = lm_head["layernorm.bias"]
output_state_dict["cls.predictions.decoder.weight"] = word_embeddings
output_state_dict["cls.predictions.bias"] = lm_head["bias"]
output_state_dict["cls.seq_relationship.weight"] = binary_head["weight"]
output_state_dict["cls.seq_relationship.bias"] = binary_head["bias"]
return output_state_dict
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--print-checkpoint-structure", action="store_true")
parser.add_argument("path_to_checkpoint", type=str, help="Path to the ZIP file containing the checkpoint")
parser.add_argument(
"--config_file",
default="",
type=str,
help="An optional config json file describing the pre-trained model.",
)
args = parser.parse_args()
basename = os.path.dirname(args.path_to_checkpoint)
print(f'Extracting PyTorch state dictionary from "{args.path_to_checkpoint}"')
if args.path_to_checkpoint.endswith(".zip"):
with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint:
with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
input_state_dict = torch.load(pytorch_dict, map_location="cpu")
else:
input_state_dict = torch.load(args.path_to_checkpoint, map_location="cpu")
if args.config_file == "":
config = MegatronBertConfig()
config.vocab_size = input_state_dict["model"]["lm_head"]["bias"].numel()
else:
config = MegatronBertConfig.from_json_file(args.config_file)
print("Converting")
output_state_dict = convert_megatron_checkpoint(args, input_state_dict, config)
if args.print_checkpoint_structure:
recursive_print(None, output_state_dict)
print("Saving config")
config.save_pretrained(basename)
output_checkpoint_file = os.path.join(basename, "pytorch_model.bin")
print(f'Saving checkpoint to "{output_checkpoint_file}"')
torch.save(output_state_dict, output_checkpoint_file)
if __name__ == "__main__":
main()
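# Hedged usage sketch: after running, e.g.,
#   python convert_megatron_bert_checkpoint.py --print-checkpoint-structure /ckpts/release.zip
# the directory containing the checkpoint holds config.json and pytorch_model.bin and loads
# as a regular Transformers model (the /ckpts path is hypothetical):
from transformers import MegatronBertModel

model = MegatronBertModel.from_pretrained("/ckpts")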
.\models\megatron_bert\modeling_megatron_bert.py
""" PyTorch MegatronBERT model."""
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
MaskedLMOutput,
MultipleChoiceModelOutput,
NextSentencePredictorOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_megatron_bert import MegatronBertConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MegatronBertConfig"
_CHECKPOINT_FOR_DOC = "nvidia/megatron-bert-cased-345m"
MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"nvidia/megatron-bert-cased-345m",
]
def load_tf_weights_in_megatron_bert(model, config, tf_checkpoint_path):
"""Load tf checkpoints in a pytorch model."""
try:
import re
import numpy as np
import tensorflow as tf
except ImportError:
logger.error(
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
tf_path = os.path.abspath(tf_checkpoint_path)
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
logger.info(f"Loading TF weight {name} with shape {shape}")
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array)
for name, array in zip(names, arrays):
name = name.split("/")
if any(
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
for n in name
):
logger.info(f"Skipping {'/'.join(name)}")
continue
pointer = model
for m_name in name:
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
scope_names = re.split(r"_(\d+)", m_name)
else:
scope_names = [m_name]
if scope_names[0] == "kernel" or scope_names[0] == "gamma":
pointer = getattr(pointer, "weight")
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
pointer = getattr(pointer, "bias")
elif scope_names[0] == "output_weights":
pointer = getattr(pointer, "weight")
elif scope_names[0] == "squad":
pointer = getattr(pointer, "classifier")
else:
try:
pointer = getattr(pointer, scope_names[0])
except AttributeError:
logger.info(f"Skipping {'/'.join(name)}")
continue
if len(scope_names) >= 2:
num = int(scope_names[1])
pointer = pointer[num]
if m_name[-11:] == "_embeddings":
pointer = getattr(pointer, "weight")
elif m_name == "kernel":
array = np.transpose(array)
if pointer.shape != array.shape:
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
logger.info("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
return model
class MegatronBertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
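# A quick smoke test of the module above with a deliberately tiny config. Note that,
# unlike BERT's embeddings, there is no LayerNorm here: Megatron-BERT normalizes inside
# the attention and feed-forward blocks instead (pre-LN).
import torch

tiny = MegatronBertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1,
                          num_attention_heads=4, intermediate_size=64, max_position_embeddings=64)
emb = MegatronBertEmbeddings(tiny)
out = emb(input_ids=torch.randint(0, 100, (2, 10)))
print(out.shape)  # torch.Size([2, 10, 32])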
class MegatronBertSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
class MegatronBertSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return residual + hidden_states
class MegatronBertAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.self = MegatronBertSelfAttention(config)
self.output = MegatronBertSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
ln_outputs = self.ln(hidden_states)
self_outputs = self.self(
ln_outputs,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
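# The residual pattern above is the "pre-LN" Megatron variant: LayerNorm runs before
# self-attention and the un-normalized input is residual-added, which is why
# MegatronBertSelfOutput carries no LayerNorm of its own. A stand-in sketch:
import torch
from torch import nn

ln, attn = nn.LayerNorm(8), nn.Linear(8, 8)  # nn.Linear stands in for self-attention
x = torch.randn(2, 4, 8)
out = x + attn(ln(x))           # Megatron (pre-LN):  x + Attn(LN(x))
# BERT (post-LN) would compute:   LayerNorm(x + Attn(x))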
class MegatronBertIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class MegatronBertOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return input_tensor + hidden_states
class MegatronBertLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = MegatronBertAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise TypeError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = MegatronBertAttention(config)
self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.intermediate = MegatronBertIntermediate(config)
self.output = MegatronBertOutput(config)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0]
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:]
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise AttributeError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
" by setting `config.add_cross_attention=True`"
)
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1]
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output):
ln_output = self.ln(attention_output)
intermediate_output = self.intermediate(ln_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
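# apply_chunking_to_forward slices its inputs along `seq_len_dim` and applies the function
# chunk by chunk, trading scheduling overhead for lower peak memory. For a position-wise
# function like the feed-forward above, the result is identical to a single full pass:
import torch
from transformers.pytorch_utils import apply_chunking_to_forward

def pointwise(x):
    return x * 2 + 1

x = torch.randn(2, 6, 4)
full = pointwise(x)
chunked = apply_chunking_to_forward(pointwise, 3, 1, x)  # chunk_size=3 along dim 1
assert torch.allclose(full, chunked)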
class MegatronBertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([MegatronBertLayer(config) for _ in range(config.num_hidden_layers)])
self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        pass
class MegatronBertPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class MegatronBertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class MegatronBertLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = MegatronBertPredictionHeadTransform(config)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
class MegatronBertOnlyMLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = MegatronBertLMPredictionHead(config)
def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class MegatronBertOnlyNSPHead(nn.Module):
def __init__(self, config):
super().__init__()
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, pooled_output):
seq_relationship_score = self.seq_relationship(pooled_output)
return seq_relationship_score
class MegatronBertPreTrainingHeads(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = MegatronBertLMPredictionHead(config)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output):
prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class MegatronBertPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = MegatronBertConfig
load_tf_weights = load_tf_weights_in_megatron_bert
base_model_prefix = "bert"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@dataclass
class MegatronBertForPreTrainingOutput(ModelOutput):
"""
Output type of [`MegatronBertForPreTraining`].
"""
loss: Optional[torch.FloatTensor] = None
prediction_logits: torch.FloatTensor = None
seq_relationship_logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
MEGATRON_BERT_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`MegatronBertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
MEGATRON_BERT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare MegatronBert Model transformer outputting raw hidden-states without any specific head on top.",
MEGATRON_BERT_START_DOCSTRING,
)
class MegatronBertModel(MegatronBertPreTrainedModel):
"""
    The bare MegatronBert Model transformer outputting raw hidden-states without any specific head on top.

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder` and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` input is then expected in the forward pass.
"""
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
self.embeddings = MegatronBertEmbeddings(config)
self.encoder = MegatronBertEncoder(config)
self.pooler = MegatronBertPooler(config) if add_pooling_layer else None
self.post_init()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
        base class PreTrainedModel.
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        pass


@add_start_docstrings(
    """
    MegatronBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
    `next sentence prediction (classification)` head.
    """,
    MEGATRON_BERT_START_DOCSTRING,
)
class MegatronBertForPreTraining(MegatronBertPreTrainedModel):
_tied_weights_keys = ["cls.predictions.decoder"]
def __init__(self, config, add_binary_head=True):
super().__init__(config)
self.bert = MegatronBertModel(config)
self.cls = MegatronBertPreTrainingHeads(config)
self.post_init()
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
@add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=MegatronBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
next_sentence_label: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
    ):
        pass


@add_start_docstrings(
    """MegatronBert Model with a `language modeling` head on top for CLM fine-tuning.""", MEGATRON_BERT_START_DOCSTRING
)
class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
_tied_weights_keys = ["cls.predictions.decoder"]
def __init__(self, config):
super().__init__(config)
if not config.is_decoder:
logger.warning("If you want to use `MegatronBertForCausalLM` as a standalone, add `is_decoder=True.`")
self.bert = MegatronBertModel(config, add_pooling_layer=False)
self.cls = MegatronBertOnlyMLMHead(config)
self.post_init()
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
@add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
pass
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
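# Sketch of the prefix trimming in prepare_inputs_for_generation above: with a cache that
# already covers 5 positions, only the tokens beyond the cache are fed in on the next step.
import torch

input_ids = torch.arange(7).unsqueeze(0)   # shape (1, 7); cache holds the first 5 positions
past_length = 5
keep_from = past_length if input_ids.shape[1] > past_length else input_ids.shape[1] - 1
print(input_ids[:, keep_from:])            # tensor([[5, 6]])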
@add_start_docstrings("""MegatronBert Model with a `language modeling` head on top.""", MEGATRON_BERT_START_DOCSTRING)
class MegatronBertForMaskedLM(MegatronBertPreTrainedModel):
_tied_weights_keys = ["cls.predictions.decoder"]
def __init__(self, config):
super().__init__(config)
if config.is_decoder:
logger.warning(
"If you want to use `MegatronBertForMaskedLM` make sure `config.is_decoder=False` for "
"bi-directional self-attention."
)
self.bert = MegatronBertModel(config, add_pooling_layer=False)
self.cls = MegatronBertOnlyMLMHead(config)
self.post_init()
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
@add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
effective_batch_size = input_shape[0]
if self.config.pad_token_id is None:
raise ValueError("The PAD token should be defined for generation")
attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
dummy_token = torch.full(
(effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
)
input_ids = torch.cat([input_ids, dummy_token], dim=1)
return {"input_ids": input_ids, "attention_mask": attention_mask}
@add_start_docstrings(
"""MegatronBert Model with a `next sentence prediction (classification)` head on top.""",
MEGATRON_BERT_START_DOCSTRING,
)
class MegatronBertForNextSentencePrediction(MegatronBertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.bert = MegatronBertModel(config)
self.cls = MegatronBertOnlyNSPHead(config)
self.post_init()
@add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
    ) -> Union[Tuple, NextSentencePredictorOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
(see `input_ids` docstring). Indices should be in `[0, 1]`:
- 0 indicates sequence B is a continuation of sequence A,
- 1 indicates sequence B is a random sequence.
Returns:
            Depending on `return_dict`:
            - If `return_dict=False`: a tuple with `seq_relationship_scores` followed by `outputs[2:]`.
            - If `return_dict=True`: a `NextSentencePredictorOutput` containing loss, logits, hidden states, and attentions.
Example:
```
>>> from transformers import AutoTokenizer, MegatronBertForNextSentencePrediction
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("nvidia/megatron-bert-cased-345m")
>>> model = MegatronBertForNextSentencePrediction.from_pretrained("nvidia/megatron-bert-cased-345m")
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
>>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
>>> outputs = model(**encoding, labels=torch.LongTensor([1]))
>>> logits = outputs.logits
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
        ```"""

        if "next_sentence_label" in kwargs:
warnings.warn(
"The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
" `labels` instead.",
FutureWarning,
)
labels = kwargs.pop("next_sentence_label")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Pass input tensors through the BERT model to get outputs
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Get the pooled output from BERT's outputs
pooled_output = outputs[1]
# Predict next sentence relationship using a classifier layer
seq_relationship_scores = self.cls(pooled_output)
next_sentence_loss = None
# Compute loss if labels are provided
if labels is not None:
loss_fct = CrossEntropyLoss()
next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
# Return outputs based on `return_dict` flag
if not return_dict:
output = (seq_relationship_scores,) + outputs[2:]
return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
# Return a `NextSentencePredictorOutput` object if `return_dict=True`
return NextSentencePredictorOutput(
loss=next_sentence_loss,
logits=seq_relationship_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
MegatronBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
pooled output) e.g. for GLUE tasks.
""",
MEGATRON_BERT_START_DOCSTRING,
)
class MegatronBertForSequenceClassification(MegatronBertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
        # The underlying MegatronBert encoder
        self.bert = MegatronBertModel(config)
        # Dropout applied before classification to reduce overfitting
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Linear classifier mapping the pooled BERT output to the label space
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()
@add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
        # Use the config default when return_dict is not provided
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Run the inputs through the BERT encoder
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
        # Take the pooled representation from the BERT outputs
        pooled_output = outputs[1]
        # Apply dropout to the pooled representation
        pooled_output = self.dropout(pooled_output)
        # Map the pooled representation to label logits
        logits = self.classifier(pooled_output)
        # No loss unless labels are provided
        loss = None
        # Compute the loss when labels are given
        if labels is not None:
            # If problem_type is unset, infer it from num_labels and the label dtype
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
            # Pick the loss function matching the problem type
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
        # Pack the outputs as a plain tuple when a dict is not requested
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
        # Otherwise return a SequenceClassifierOutput
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
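# Sketch of the problem_type dispatch implemented above, pulled out of the model: a single
# label dimension means regression, integer labels mean single-label classification, and
# anything else falls through to multi-label classification.
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

def pick_loss(num_labels, labels):
    if num_labels == 1:
        return MSELoss()
    if labels.dtype in (torch.long, torch.int):
        return CrossEntropyLoss()
    return BCEWithLogitsLoss()

print(type(pick_loss(1, torch.tensor([0.5]))).__name__)              # MSELoss
print(type(pick_loss(3, torch.tensor([2]))).__name__)                # CrossEntropyLoss
print(type(pick_loss(3, torch.tensor([[1.0, 0.0, 1.0]]))).__name__)  # BCEWithLogitsLoss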
# MegatronBertForMultipleChoice: Megatron-BERT with a multiple-choice classification head on top.
@add_start_docstrings(
"""
MegatronBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output
and a softmax) e.g. for RocStories/SWAG tasks.
""",
MEGATRON_BERT_START_DOCSTRING,
)
class MegatronBertForMultipleChoice(MegatronBertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
        # The underlying Megatron-BERT encoder
        self.bert = MegatronBertModel(config)
        # Dropout to reduce overfitting
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Linear layer mapping each pooled choice representation to a single score
        self.classifier = nn.Linear(config.hidden_size, 1)
        # Initialize weights and apply final processing
        self.post_init()
@add_start_docstrings_to_model_forward(
MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
    # Forward pass: takes the usual inputs plus control flags and returns the model output
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MultipleChoiceModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
`input_ids` above)
"""
        # Use the config default when return_dict is not provided
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The number of choices is the second dimension of the inputs
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
        # Flatten the inputs to (batch_size * num_choices, seq_len)
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        # Flatten inputs_embeds to three dimensions if present
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )
        # Run the flattened inputs through BERT
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
        # Take the pooled output
        pooled_output = outputs[1]
        # Apply dropout
        pooled_output = self.dropout(pooled_output)
        # Score each choice with the classifier
        logits = self.classifier(pooled_output)
        # Reshape the logits back to (batch_size, num_choices)
        reshaped_logits = logits.view(-1, num_choices)
        loss = None
        # Compute the cross-entropy loss when labels are provided
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)
        # Return a plain tuple when a dict is not requested
        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
        # Otherwise return a MultipleChoiceModelOutput
return MultipleChoiceModelOutput(
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
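# The head above flattens (batch, num_choices, seq_len) into (batch * num_choices, seq_len)
# before BERT, scores each row with a 1-unit classifier, then folds the scores back to
# (batch, num_choices). A shape-only sketch with random stand-in tensors:
import torch

batch, num_choices, seq_len = 2, 4, 16
input_ids = torch.randint(0, 100, (batch, num_choices, seq_len))
flat = input_ids.view(-1, input_ids.size(-1))   # (8, 16) fed to BERT
scores = torch.randn(batch * num_choices, 1)    # classifier output, one score per row
reshaped = scores.view(-1, num_choices)         # (2, 4), ready for CrossEntropyLoss
print(flat.shape, reshaped.shape)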
"""
MegatronBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
for Named-Entity-Recognition (NER) tasks.
"""
@add_start_docstrings(
"""
MegatronBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
for Named-Entity-Recognition (NER) tasks.
""",
MEGATRON_BERT_START_DOCSTRING,
)
class MegatronBertForTokenClassification(MegatronBertPreTrainedModel):
def __init__(self, config):
"""
Initialize the MegatronBertForTokenClassification model.
Args:
config (MegatronBertConfig): Configuration object specifying the model architecture and hyperparameters.
"""
super().__init__(config)
self.num_labels = config.num_labels
# Initialize the MegatronBertModel with pooling layer excluded
self.bert = MegatronBertModel(config, add_pooling_layer=False)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TokenClassifierOutput]:
"""
Forward pass of the MegatronBertForTokenClassification model.
Args:
input_ids (torch.LongTensor, optional): Tensor of shape `(batch_size, sequence_length)` containing input token IDs.
attention_mask (torch.FloatTensor, optional): Tensor of shape `(batch_size, sequence_length)` containing attention masks.
token_type_ids (torch.LongTensor, optional): Tensor of shape `(batch_size, sequence_length)` containing token type IDs.
position_ids (torch.LongTensor, optional): Tensor of shape `(batch_size, sequence_length)` containing position IDs.
head_mask (torch.FloatTensor, optional): Tensor of shape `(batch_size, sequence_length)` containing attention head masks.
inputs_embeds (torch.FloatTensor, optional): Tensor of shape `(batch_size, sequence_length, hidden_size)` containing precomputed embeddings.
labels (torch.LongTensor, optional): Tensor of shape `(batch_size, sequence_length)` containing labels for computing token classification loss.
output_attentions (bool, optional): Whether to output attentions.
output_hidden_states (bool, optional): Whether to output hidden states.
return_dict (bool, optional): Whether to return outputs as a dictionary.
Returns:
Union[Tuple, TokenClassifierOutput]: Depending on `return_dict`, either a tuple or a `TokenClassifierOutput` object.
Notes:
- Labels should be in the range `[0, ..., config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Perform the forward pass through MegatronBertModel
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
# Apply dropout on the output of the BERT model
sequence_output = self.dropout(sequence_output)
# Pass the modified output through the classifier layer
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
# Compute the token classification loss
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
# Prepare output tuple if return_dict is False
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
# Return TokenClassifierOutput object if return_dict is True
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
# MegatronBertForQuestionAnswering: Megatron-BERT for extractive question answering, adding a linear
# layer on top of the hidden states to compute `span start logits` and `span end logits`.
@add_start_docstrings(
"""
MegatronBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
MEGATRON_BERT_START_DOCSTRING,
)
class MegatronBertForQuestionAnswering(MegatronBertPreTrainedModel):
def __init__(self, config):
        super().__init__(config)
        # Number of output labels (start and end logits)
        self.num_labels = config.num_labels
        # Megatron-BERT encoder without the pooling layer
        self.bert = MegatronBertModel(config, add_pooling_layer=False)
        # QA head: a linear layer from hidden size to num_labels
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()
    # Document the forward inputs via the shared docstring template
    @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    # Attach a code-sample docstring describing the return type and configuration
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
"""
        # Use the config default when return_dict is not provided
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Run the inputs through the BERT encoder
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Take the sequence representation from the BERT output
sequence_output = outputs[0]
# Project the sequence representation to start/end logits with the QA head
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()  # drop the trailing dimension and make memory contiguous
end_logits = end_logits.squeeze(-1).contiguous()  # drop the trailing dimension and make memory contiguous
total_loss = None
# Compute the loss only when both start and end positions are provided
if start_positions is not None and end_positions is not None:
# On multi-GPU setups the position tensors may carry an extra dimension; squeeze it away
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# Clamp positions that fall outside the model input so the loss ignores them
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
# Cross-entropy loss over the start and end positions
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
# If a plain tuple is requested, return the results as-is
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]  # append any extra model outputs
return ((total_loss,) + output) if total_loss is not None else output
# Otherwise return a structured QuestionAnsweringModelOutput
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
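# A minimal usage sketch of the QA head above -- hypothetical: a tiny, randomly initialized
# configuration is used here only to illustrate input/output shapes and the span-loss computation.
import torch
from transformers import MegatronBertConfig, MegatronBertForQuestionAnswering
config = MegatronBertConfig(vocab_size=128, hidden_size=64, num_hidden_layers=2, num_attention_heads=4, intermediate_size=128)
model = MegatronBertForQuestionAnswering(config)
input_ids = torch.randint(0, 128, (1, 16))  # (batch_size, sequence_length)
outputs = model(input_ids, start_positions=torch.tensor([3]), end_positions=torch.tensor([5]))
print(outputs.loss, outputs.start_logits.shape, outputs.end_logits.shape)  # scalar loss, (1, 16), (1, 16)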
.\models\megatron_bert\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_megatron_bert"] = [
"MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"MegatronBertForCausalLM",
"MegatronBertForMaskedLM",
"MegatronBertForMultipleChoice",
"MegatronBertForNextSentencePrediction",
"MegatronBertForPreTraining",
"MegatronBertForQuestionAnswering",
"MegatronBertForSequenceClassification",
"MegatronBertForTokenClassification",
"MegatronBertModel",
"MegatronBertPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_megatron_bert import MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronBertConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_megatron_bert import (
MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MegatronBertForCausalLM,
MegatronBertForMaskedLM,
MegatronBertForMultipleChoice,
MegatronBertForNextSentencePrediction,
MegatronBertForPreTraining,
MegatronBertForQuestionAnswering,
MegatronBertForSequenceClassification,
MegatronBertForTokenClassification,
MegatronBertModel,
MegatronBertPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\megatron_gpt2\checkpoint_reshaping_and_interoperability.py
import argparse
import json
import os
import re
import sys
import types
import torch
from transformers import AutoTokenizer, GPT2Config
from transformers.modeling_utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, shard_checkpoint
def add_checkpointing_args(parser):
parser.add_argument("--megatron-path", type=str, default=None, help="Base directory of Megatron repository")
parser.add_argument(
"--convert_checkpoint_from_megatron_to_transformers",
action="store_true",
help=(
"If True, convert a Megatron checkpoint to a Transformers checkpoint. "
"If False, convert a Transformers checkpoint to a Megatron checkpoint."
),
)
parser.add_argument(
"--load_path",
type=str,
required=True,
help="Path to the checkpoint to convert.",
)
parser.add_argument(
"--save_path",
type=str,
required=True,
help="Path to the converted checkpoint.",
)
parser.add_argument("--print-checkpoint-structure", action="store_true")
return parser
def add_megatron_checkpoint_args(parser):
parser.add_argument(
"--target_tensor_model_parallel_size",
type=int,
default=1,
help=(
"The tensor model parallel size of the converted checkpoint. "
"Only used when converting a Transformers checkpoint to a Megatron checkpoint."
),
)
parser.add_argument(
"--target_pipeline_model_parallel_size",
type=int,
default=1,
help=(
"The pipeline model parallel size of the converted checkpoint. "
"Only used when converting a Transformers checkpoint to a Megatron checkpoint."
),
)
parser.add_argument(
"--target_data_parallel_size",
type=int,
default=1,
help=(
"The data parallel size of the converted checkpoint. "
"Only used when converting a Transformers checkpoint to a Megatron checkpoint."
),
)
parser.add_argument(
"--target_params_dtype",
type=str,
default="fp32",
help=(
"The dtype of the converted checkpoint. "
"Only used when converting a Transformers checkpoint to a Megatron checkpoint."
),
)
parser.add_argument(
"--make_vocab_size_divisible_by",
type=int,
default=128,
help=(
"Pad the vocab size to be divisible by this value. "
"This is added for computational efficiency reasons. "
"Only used when converting a Transformers checkpoint to a Megatron checkpoint."
),
)
parser.add_argument(
"--use_distributed_optimizer",
action="store_true",
help=(
"If True, use the distributed optimizer. "
"Only used when converting a Transformers checkpoint to a Megatron checkpoint."
),
)
return parser
def add_transformers_checkpoint_args(parser):
"""
Add the Transformers checkpoint arguments to the parser.
Args:
parser (ArgumentParser): the parser to add the arguments to
Returns:
ArgumentParser: the updated parser
"""
parser.add_argument(
"--tokenizer_name",
type=str,
default=None,
help=(
"要保存的预训练分词器的名称。如果不是 None,则会保存分词器。"
"仅在将 Megatron 检查点转换为 Transformers 检查点时使用。"
),
)
parser.add_argument(
"--max_shard_size",
type=str,
default="10GB",
help=(
"在分片之前检查点的最大大小。检查点分片将小于此大小。"
"如果表示为字符串,需由数字后跟单位(如 `5MB`)组成。"
"仅在将 Megatron 检查点转换为 Transformers 检查点时使用。"
),
)
return parser
megatron_to_transformers = {
"attention.dense": ".attn.c_proj.",
"self_attention.dense": ".attn.c_proj.",
"mlp.dense_h_to_4h": ".mlp.c_fc.",
"mlp.dense_4h_to_h": ".mlp.c_proj.",
}
transformers_to_megatron = {v[1:-1]: k for k, v in megatron_to_transformers.items()}
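# A quick sketch of how the two mapping tables relate (the layer key below is just an example entry):
print(megatron_to_transformers["mlp.dense_h_to_4h"])  # ".mlp.c_fc." (note the surrounding dots)
print(transformers_to_megatron["mlp.c_fc"])  # "mlp.dense_h_to_4h" -- v[1:-1] stripped the dots when inverting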
tensor_parallel_params = [
"self_attention.query_key_value.weight",
"self_attention.query_key_value.bias",
"self_attention.dense.weight",
"mlp.dense_h_to_4h.weight",
"mlp.dense_h_to_4h.bias",
"mlp.dense_4h_to_h.weight",
"attention.query_key_value.weight",
"attention.query_key_value.bias",
"attention.dense.weight",
"attn.c_attn.weight",
"attn.c_attn.bias",
"attn.c_proj.weight",
"mlp.c_fc.weight",
"mlp.c_fc.bias",
"mlp.c_proj.weight",
]
def recursive_print(name, val, spaces=0):
"""
Recursively print the structure of a checkpoint. This function is taken from `convert_megatron_gpt2_checkpoint.py`.
Args:
name (str): the name of the current tensor parameter
val (Tuple(int)): the shape of the current tensor parameter
spaces (int): the number of spaces to print before the output, for nested structures
"""
if name is None:
msg = None
else:
fmt = "." * max(0, spaces - 2) + "# {:" + str(50 - spaces) + "s}"
msg = fmt.format(name)
if isinstance(val, dict):
if msg is not None:
print(msg)
for k in val.keys():
recursive_print(k, val[k], spaces + 2)
elif isinstance(val, torch.Tensor):
print(msg, ":", val.size())
else:
print(msg, ":", val)
def megatron_to_transformers_fix_query_key_value_ordering(
param, checkpoint_version, num_splits, num_heads, hidden_size
):
"""
Permutes the layout of the `param` tensor to [num_splits * num_heads * hidden_size, :] for compatibility with
later checkpoint versions.
Args:
param: the tensor to permute
checkpoint_version: the version of the checkpoint
num_splits: the number of splits (usually 3, for Query/Key/Value)
num_heads: the number of attention heads
hidden_size: the hidden size per head
"""
input_shape = param.size()
if checkpoint_version == 1.0:
saved_shape = (num_heads, hidden_size, num_splits) + input_shape[1:]
param = param.view(*saved_shape)
param = param.transpose(0, 2)
param = param.transpose(1, 2).contiguous()
elif checkpoint_version >= 2.0:
saved_shape = (num_heads, num_splits, hidden_size) + input_shape[1:]
param = param.view(*saved_shape)
param = param.transpose(0, 1).contiguous()
param = param.view(*input_shape)
return param
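# The effect of this permutation can be checked on a tiny tensor. A sketch with hypothetical sizes
# (2 heads, hidden size 3 per head, 3 splits for Q/K/V, checkpoint version 2.0):
_param = torch.arange(2 * 3 * 3 * 4, dtype=torch.float32).view(2 * 3 * 3, 4)
_out = megatron_to_transformers_fix_query_key_value_ordering(_param, 2.0, num_splits=3, num_heads=2, hidden_size=3)
print(_out.shape)  # torch.Size([18, 4]) -- same shape; rows regrouped from (head, split, hidden) to (split, head, hidden)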
def transformers_to_megatron_fix_query_key_value_ordering(
param, checkpoint_version, num_splits, num_heads, hidden_size
):
"""
Permutes layout of param tensor to the one compatible with respective NVIDIA Megatron-LM checkpoint versions. Input
is [num_splits * num_heads * hidden_size, :] and output is [num_heads * hidden_size * num_splits, :] for version
1.0 and [num_heads * num_splits * hidden_size, :] for version 2.0 and later. If param is the weight tensor of the
self-attention block, the param needs to be already transposed before calling this function.
Args:
param (torch.Tensor): the tensor to permute
checkpoint_version (int): the version of the checkpoint.
num_splits (int): the number of projections, usually 3 for (Query, Key, Value)
num_heads (int): the number of attention heads
hidden_size (int): the hidden size per head
"""
input_shape = param.size()
if checkpoint_version == 1.0:
current_shape = (num_splits, num_heads, hidden_size) + input_shape[1:]
param = param.view(*current_shape)
param = param.transpose(0, 2)
param = param.transpose(1, 2).contiguous()
elif checkpoint_version >= 2.0:
current_shape = (num_splits, num_heads, hidden_size) + input_shape[1:]
param = param.view(*current_shape)
param = param.transpose(0, 1).contiguous()
param = param.view(*input_shape)
return param
def merge_transformers_sharded_states(path, num_checkpoints):
"""
Merge sharded checkpoints from transformers into a single checkpoint.
Args:
path (str): the path to the sharded checkpoints
num_checkpoints (int): the number of checkpoints to merge
"""
state_dict = {}
for i in range(1, num_checkpoints + 1):
checkpoint_path = os.path.join(path, f"pytorch_model-{i:05d}-of-{num_checkpoints:05d}.bin")
current_chunk = torch.load(checkpoint_path, map_location="cpu")
state_dict.update(current_chunk)
return state_dict
def get_megatron_sharded_states(args, tp_size, pp_size, pp_rank):
"""
Get sharded checkpoints from NVIDIA Megatron-LM checkpoint based on the provided tensor parallel size, pipeline
parallel size and pipeline parallel rank.
Args:
args (argparse.Namespace): the arguments to the script
tp_size (int): the tensor parallel size
pp_size (int): the pipeline parallel size
pp_rank (int): the pipeline parallel rank
"""
tp_state_dicts = []
for i in range(tp_size):
sub_dir_name = f"mp_rank_{i:02d}" if pp_size == 1 else f"mp_rank_{i:02d}_{pp_rank:03d}"
for checkpoint_name in ["model_optim_rng.pt", "model_rng.pt"]:
checkpoint_path = os.path.join(args.load_path, sub_dir_name, checkpoint_name)
if os.path.isfile(checkpoint_path):
break
state_dict = torch.load(checkpoint_path, map_location="cpu")
tp_state_dicts.append(state_dict)
return tp_state_dicts
def get_element_from_dict_by_path(d, path):
path = path.split(".")
for k in path:
if k not in d:
d[k] = {}
d = d[k]
return d
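# get_element_from_dict_by_path both navigates and lazily creates nested dictionaries, which the
# conversion code relies on when building Megatron state dicts. A quick sketch:
_d = {}
get_element_from_dict_by_path(_d, "model.language_model.embedding")["weight"] = "placeholder"  # hypothetical value
print(_d)  # {'model': {'language_model': {'embedding': {'weight': 'placeholder'}}}}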
def convert_checkpoint_from_megatron_to_transformers(args):
"""
Convert NVIDIA Megatron-LM checkpoint to HuggingFace Transformers checkpoint. This handles Megatron checkpoints
with different tensor parallelism and pipeline parallelism sizes. It saves the converted checkpoint into shards
using HuggingFace Transformers checkpoint sharding functionality. This greatly extends the functionality of
`convert_megatron_gpt2_checkpoint.py`
Args:
args (argparse.Namespace): the arguments to the script
"""
sub_dirs = os.listdir(args.load_path)
possible_sub_dirs = ["mp_rank_00", "mp_rank_00_000"]
for sub_dir in possible_sub_dirs:
if sub_dir in sub_dirs:
rank0_checkpoint_name = os.listdir(os.path.join(args.load_path, sub_dir))[0]
rank0_checkpoint_path = os.path.join(args.load_path, sub_dir, rank0_checkpoint_name)
break
print(f"Loading Megatron-LM checkpoint arguments from: {rank0_checkpoint_path}")
state_dict = torch.load(rank0_checkpoint_path, map_location="cpu")
megatron_args = state_dict.get("args", None)
if megatron_args is None:
raise ValueError(
"Megatron-LM checkpoint does not contain arguments. This utility only supports Megatron-LM checkpoints"
" containing all the megatron arguments. This is because it loads all config related to model"
" architecture, the tensor and pipeline model parallel size from the checkpoint insead of user having to"
" manually specify all the details. Please save Megatron-LM checkpoint along with all the megatron"
" arguments to use this utility."
)
if megatron_args is not None:
if megatron_args.bias_gelu_fusion:
activation_function = "gelu_fast"
elif megatron_args.openai_gelu:
activation_function = "gelu_new"
else:
activation_function = "gelu"
else:
activation_function = "gelu_new"
vocab_size = (
megatron_args.padded_vocab_size
if getattr(megatron_args, "orig_vocab_size", None) is None
else megatron_args.orig_vocab_size
)
print(vocab_size)
config = GPT2Config(
vocab_size=vocab_size,
n_positions=megatron_args.max_position_embeddings,
n_embd=megatron_args.hidden_size,
n_layer=megatron_args.num_layers,
n_head=megatron_args.num_attention_heads,
n_inner=megatron_args.ffn_hidden_size,
activation_function=activation_function,
resid_pdrop=0.1,
embd_pdrop=0.1,
attn_pdrop=0.1,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
summary_type="cls_index",
summary_use_proj=True,
summary_activation=None,
summary_proj_to_labels=True,
summary_first_dropout=0.1,
scale_attn_weights=True,
use_cache=True,
bos_token_id=vocab_size - 1,
eos_token_id=vocab_size - 1,
architectures=["GPT2LMHeadModel"],
)
output_state_dict = {}
checkpoint_version = state_dict.get("checkpoint_version", 0.0)
tp_size = megatron_args.tensor_model_parallel_size
pp_size = megatron_args.pipeline_model_parallel_size
dtype = torch.float32
layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
print("Converting")
print("Converting embeddings")
tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, 0)
position_embeddings = get_element_from_dict_by_path(
tp_state_dicts[0], "model.language_model.embedding.position_embeddings.weight"
)
output_state_dict["transformer.wpe.weight"] = position_embeddings.to(dtype)
word_embeddings = torch.cat(
[
get_element_from_dict_by_path(
tp_state_dicts[tp_rank], "model.language_model.embedding.word_embeddings.weight"
)
for tp_rank in range(tp_size)
],
dim=0,
)
word_embeddings = word_embeddings[:vocab_size].to(dtype)
output_state_dict["transformer.wte.weight"] = word_embeddings
print("Converting transformer layers")
heads = config.n_head
hidden_size_per_head = config.n_embd // config.n_head
n_positions = config.n_positions
num_layers = config.num_hidden_layers // pp_size
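# NOTE: the per-layer conversion loop -- which walks every pipeline stage, defines `layer_idx` and the
# state-dict `path`, and fills `output_state_dict` with the transformer block weights -- is elided in this excerpt.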
if config.n_layer != (layer_idx + 1):
raise ValueError(f"Expected {config.n_layer} layers but found {layer_idx + 1}")
print("Converting final layernorm")
params = get_element_from_dict_by_path(tp_state_dicts[0], str(path))
output_state_dict["transformer.ln_f.weight"] = params["final_layernorm.weight"].to(dtype)
output_state_dict["transformer.ln_f.bias"] = params["final_layernorm.bias"].to(dtype)
print("Converting LM head")
output_state_dict["lm_head.weight"] = word_embeddings.to(dtype)
print("Conversion from Megatron-LM to Transformers is done!")
if args.print_checkpoint_structure:
recursive_print(None, output_state_dict)
if args.tokenizer_name is None:
tokenizer_name = "openai-community/gpt2"
else:
tokenizer_name = args.tokenizer_name
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
tokenizer_class = type(tokenizer).__name__
config.tokenizer_class = tokenizer_class
print("Saving config")
config.save_pretrained(args.save_path)
if args.tokenizer_name is not None:
print(f"Adding {tokenizer_class} tokenizer files")
tokenizer.save_pretrained(args.save_path)
max_shard_size = int(args.max_shard_size) if args.max_shard_size.isdigit() else args.max_shard_size
shards, index = shard_checkpoint(output_state_dict, max_shard_size=max_shard_size)
for shard_file, shard in shards.items():
torch.save(shard, os.path.join(args.save_path, shard_file))
if index is None:
print(f"Model weights saved in {os.path.join(args.save_path, WEIGHTS_NAME)}")
else:
save_index_file = os.path.join(args.save_path, WEIGHTS_INDEX_NAME)
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
print(
f"The model is bigger than the maximum size per checkpoint ({args.max_shard_size}) and is going to be "
f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
def convert_checkpoint_from_transformers_to_megatron(args):
"""
Convert a checkpoint from HuggingFace Transformers to Megatron-LM. The converted checkpoint can target any
tensor parallelism and pipeline parallelism sizes. It takes as input a checkpoint from HuggingFace Transformers
which can have multiple shards.
Args:
args (argparse.Namespace): the arguments to the script
"""
os.makedirs(args.save_path, exist_ok=True)
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
if args.megatron_path is not None:
sys.path.insert(0, args.megatron_path)
try:
from megatron.tokenizer.tokenizer import _vocab_size_with_padding
except ModuleNotFoundError:
print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
exit(1)
sub_dirs = [x for x in os.listdir(args.load_path) if x.startswith("pytorch_model")]
if len(sub_dirs) == 1:
checkpoint_name = "pytorch_model.bin"
state_dict = torch.load(os.path.join(args.load_path, checkpoint_name), map_location="cpu")
else:
num_checkpoints = len(sub_dirs) - 1
state_dict = merge_transformers_sharded_states(args.load_path, num_checkpoints)
config = GPT2Config.from_pretrained(args.load_path)
tracker_filepath = os.path.join(args.save_path, "latest_checkpointed_iteration.txt")
with open(tracker_filepath, "w") as f:
f.write("release")
release_dir = os.path.join(args.save_path, "release")
os.makedirs(release_dir, exist_ok=True)
megatron_args = {
"orig_vocab_size": config.vocab_size,
"max_position_embeddings": config.n_positions,
"hidden_size": config.n_embd,
"num_layers": config.n_layer,
"num_attention_heads": config.n_head,
"ffn_hidden_size": config.n_inner,
"tensor_model_parallel_size": args.target_tensor_model_parallel_size,
"pipeline_model_parallel_size": args.target_pipeline_model_parallel_size,
"data_parallel_size": args.target_data_parallel_size,
"make_vocab_size_divisible_by": args.make_vocab_size_divisible_by,
"rank": 0,
"tokenizer_type": "GPT2BPETokenizer",
}
if config.activation_function == "gelu":
megatron_args["bias_gelu_fusion"] = False
megatron_args["openai_gelu"] = False
elif config.activation_function == "gelu_fast":
megatron_args["bias_gelu_fusion"] = True
megatron_args["openai_gelu"] = False
elif config.activation_function == "gelu_new":
megatron_args["bias_gelu_fusion"] = False
megatron_args["openai_gelu"] = True
margs = types.SimpleNamespace()
for k, v in megatron_args.items():
setattr(margs, k, v)
if args.target_params_dtype == "fp16":
dtype = torch.float16
elif args.target_params_dtype == "bf16":
dtype = torch.bfloat16
else:
dtype = torch.float32
setattr(margs, "params_dtype", dtype)
dummy_optim_state_dict = {}
dummy_optim_state_dict["optimizer"] = {
"step": 0,
"param_groups": [
{
"lr": 0.0,
"beta1": 0.0,
"beta2": 0.0,
"eps": 0.0,
"weight_decay": 0.0,
"correct_bias": False,
"params": [],
}
],
}
if args.use_distributed_optimizer:
for i in range(args.target_pipeline_model_parallel_size):
for j in range(args.target_tensor_model_parallel_size):
for k in range(args.target_data_parallel_size):
if args.target_pipeline_model_parallel_size == 1:
checkpoint_dir = f"mp_rank_{j:02d}_{k:03d}"
else:
checkpoint_dir = f"mp_rank_{j:02d}_{i:03d}_{k:03d}"
checkpoint_dir = os.path.join(release_dir, checkpoint_dir)
os.makedirs(checkpoint_dir, exist_ok=True)
torch.save(
dummy_optim_state_dict,
os.path.join(checkpoint_dir, "optim.pt"),
)
print("Converting")
output_state_dict = []
for i in range(args.target_tensor_model_parallel_size):
output_state_dict.append({})
print("converting embedding layer")
pos_embedding = state_dict["transformer.wpe.weight"].to(dtype)
word_embedding = state_dict["transformer.wte.weight"].to(dtype)
orig_vocab_size = config.vocab_size
padded_vocab_size = _vocab_size_with_padding(orig_vocab_size, margs)
setattr(margs, "padded_vocab_size", padded_vocab_size)
if orig_vocab_size > padded_vocab_size:
full_word_embed = word_embedding[0:padded_vocab_size, :]
elif orig_vocab_size < padded_vocab_size:
padding_size = padded_vocab_size - orig_vocab_size
full_word_embed = torch.cat((word_embedding, word_embedding[-1].unsqueeze(0).expand(padding_size, -1)))
else:
full_word_embed = word_embedding
out_word_embed = torch.chunk(full_word_embed, args.target_tensor_model_parallel_size, dim=0)
for i in range(args.target_tensor_model_parallel_size):
pos_emb_dict = get_element_from_dict_by_path(
output_state_dict[i], "model.language_model.embedding.position_embeddings"
)
pos_emb_dict["weight"] = pos_embedding
word_emb_dict = get_element_from_dict_by_path(
output_state_dict[i], "model.language_model.embedding.word_embeddings"
)
word_emb_dict["weight"] = out_word_embed[i].clone()
print("converting transformer layers")
if config.num_attention_heads % args.target_tensor_model_parallel_size != 0:
raise ValueError(
f"Number of attention heads ({config.num_attention_heads}) must be divisible by number of tensor parallelism"
f" ({args.target_tensor_model_parallel_size})"
)
if config.num_hidden_layers % args.target_pipeline_model_parallel_size != 0:
raise ValueError(
f"Number of layers ({config.num_hidden_layers}) must be divisible by number of pipeline parallelism"
f" ({args.target_pipeline_model_parallel_size})"
)
num_layers = config.num_hidden_layers // args.target_pipeline_model_parallel_size
layer_re = re.compile(r"transformer.h\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
heads = config.n_head
hidden_size_per_head = config.n_embd // config.n_head
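# The embedding step above relies on Megatron's `_vocab_size_with_padding`. A minimal sketch of the
# rounding rule, assuming the usual Megatron formula (pad until the vocabulary is divisible by
# `make_vocab_size_divisible_by * tensor_model_parallel_size`); the real helper lives in Megatron-LM:
def vocab_size_with_padding_sketch(orig_vocab_size, make_vocab_size_divisible_by, tp_size):
    multiple = make_vocab_size_divisible_by * tp_size
    padded = orig_vocab_size
    while padded % multiple != 0:
        padded += 1
    return padded

print(vocab_size_with_padding_sketch(50257, 128, 2))  # 50432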
def main():
parser = argparse.ArgumentParser()
parser = add_checkpointing_args(parser)
parser = add_megatron_checkpoint_args(parser)
parser = add_transformers_checkpoint_args(parser)
args = parser.parse_args()
if args.convert_checkpoint_from_megatron_to_transformers:
convert_checkpoint_from_megatron_to_transformers(args)
else:
convert_checkpoint_from_transformers_to_megatron(args)
if __name__ == "__main__":
main()
.\models\megatron_gpt2\convert_megatron_gpt2_checkpoint.py
def recursive_print(name, val, spaces=0):
if name is None:
msg = None
else:
fmt = "." * max(0, spaces - 2) + "# {:" + str(50 - spaces) + "s}"
msg = fmt.format(name)
if isinstance(val, dict):
if msg is not None:
print(msg)
for k in val.keys():
recursive_print(k, val[k], spaces + 2)
elif isinstance(val, torch.Tensor):
print(msg, ":", val.size())
else:
print(msg, ":", val)
def fix_query_key_value_ordering(param, checkpoint_version, num_splits, num_heads, hidden_size):
input_shape = param.size()
if checkpoint_version == 1.0:
saved_shape = (num_heads, hidden_size, num_splits) + input_shape[1:]
param = param.view(*saved_shape)
param = param.transpose(0, 2)
param = param.transpose(1, 2).contiguous()
elif checkpoint_version >= 2.0:
saved_shape = (num_heads, num_splits, hidden_size) + input_shape[1:]
param = param.view(*saved_shape)
param = param.transpose(0, 1).contiguous()
param = param.view(*input_shape)
return param
def convert_megatron_checkpoint(args, input_state_dict, config):
output_state_dict = {}
ds_args = input_state_dict.get("args", None)
if ds_args is not None:
config.vocab_size = ds_args.padded_vocab_size
config.n_positions = ds_args.max_position_embeddings
config.n_embd = ds_args.hidden_size
config.n_layer = ds_args.num_layers
config.n_head = ds_args.num_attention_heads
config.n_inner = ds_args.ffn_hidden_size
heads = config.n_head
hidden_size_per_head = config.n_embd // config.n_head
if "checkpoint_version" in input_state_dict.keys():
checkpoint_version = input_state_dict["checkpoint_version"]
else:
checkpoint_version = 0.0
model = input_state_dict["model"]
lm = model["language_model"]
embeddings = lm["embedding"]
word_embeddings = embeddings["word_embeddings"]["weight"]
word_embeddings = word_embeddings[: config.vocab_size, :]
output_state_dict["transformer.wte.weight"] = word_embeddings
pos_embeddings = embeddings["position_embeddings"]["weight"]
n_positions = pos_embeddings.size(0)
if n_positions != config.n_positions:
raise ValueError(
f"pos_embeddings.max_sequence_length={n_positions} and config.n_positions={config.n_positions} don't match"
)
output_state_dict["transformer.wpe.weight"] = pos_embeddings
transformer = lm["transformer"] if "transformer" in lm.keys() else lm["encoder"]
layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
megatron_to_transformers = {
"attention.dense": ".attn.c_proj.",
"self_attention.dense": ".attn.c_proj.",
"mlp.dense_h_to_4h": ".mlp.c_fc.",
"mlp.dense_4h_to_h": ".mlp.c_proj.",
}
for key, val in transformer.items():
m = layer_re.match(key)
if m is None:
break
layer_idx = int(m.group(1))
op_name = m.group(2)
weight_or_bias = m.group(3)
layer_name = f"transformer.h.{layer_idx}"
if op_name.endswith("layernorm"):
ln_name = "ln_1" if op_name.startswith("input") else "ln_2"
output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = val
elif (
op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
) and weight_or_bias == "weight":
causal_mask = torch.tril(torch.ones((n_positions, n_positions), dtype=torch.float16)).view(
1, 1, n_positions, n_positions
)
output_state_dict[layer_name + ".attn.bias"] = causal_mask
masked_bias = torch.tensor(-1e4, dtype=torch.float16)
output_state_dict[layer_name + ".attn.masked_bias"] = masked_bias
out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
out_val = out_val.transpose(0, 1).contiguous()
output_state_dict[layer_name + ".attn.c_attn.weight"] = out_val
elif (
op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
) and weight_or_bias == "bias":
out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
output_state_dict[layer_name + ".attn.c_attn.bias"] = out_val
elif weight_or_bias == "weight":
out_name = megatron_to_transformers[op_name]
output_state_dict[layer_name + out_name + "weight"] = val.transpose(0, 1)
elif weight_or_bias == "bias":
out_name = megatron_to_transformers[op_name]
output_state_dict[layer_name + out_name + "bias"] = val
assert config.n_layer == layer_idx + 1
output_state_dict["transformer.ln_f.weight"] = transformer["final_layernorm.weight"]
output_state_dict["transformer.ln_f.bias"] = transformer["final_layernorm.bias"]
output_state_dict["lm_head.weight"] = word_embeddings
return output_state_dict
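# A quick sketch of what the layer-name regular expression above extracts (the parameter name is hypothetical):
import re
_layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
_m = _layer_re.match("layers.3.self_attention.query_key_value.weight")
print(_m.group(1), _m.group(2), _m.group(3))  # 3 self_attention.query_key_value weight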
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--print-checkpoint-structure", action="store_true")
parser.add_argument(
"path_to_checkpoint",
type=str,
help="Path to the checkpoint file (.zip archive or direct .pt file)",
)
parser.add_argument(
"--config_file",
default="",
type=str,
help="An optional config json file describing the pre-trained model.",
)
args = parser.parse_args()
basename = os.path.dirname(args.path_to_checkpoint)
print(f"Extracting PyTorch state dictionary from {args.path_to_checkpoint}")
if args.path_to_checkpoint.endswith(".zip"):
with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint:
with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
input_state_dict = torch.load(pytorch_dict, map_location="cpu")
else:
input_state_dict = torch.load(args.path_to_checkpoint, map_location="cpu")
ds_args = input_state_dict.get("args", None)
if args.config_file == "":
if ds_args is not None:
if ds_args.bias_gelu_fusion:
activation_function = "gelu_fast"
elif ds_args.openai_gelu:
activation_function = "gelu_new"
else:
activation_function = "gelu"
else:
activation_function = "gelu_new"
config = GPT2Config(
vocab_size=50257,
n_positions=1024,
n_embd=1024,
n_layer=24,
n_head=16,
n_inner=4096,
activation_function=activation_function,
resid_pdrop=0.1,
embd_pdrop=0.1,
attn_pdrop=0.1,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
summary_type="cls_index",
summary_use_proj=True,
summary_activation=None,
summary_proj_to_labels=True,
summary_first_dropout=0.1,
scale_attn_weights=True,
use_cache=True,
bos_token_id=50256,
eos_token_id=50256,
)
else:
config = GPT2Config.from_json_file(args.config_file)
config.architectures = ["GPT2LMHeadModel"]
print("Converting")
output_state_dict = convert_megatron_checkpoint(args, input_state_dict, config)
if args.print_checkpoint_structure:
recursive_print(None, output_state_dict)
if ds_args is not None:
tokenizer_type = ds_args.tokenizer_type
if tokenizer_type == "GPT2BPETokenizer":
tokenizer_model_name = "openai-community/gpt2"
elif tokenizer_type == "PretrainedFromHF":
tokenizer_model_name = ds_args.tokenizer_name_or_path
else:
raise ValueError(f"Unrecognized tokenizer_type {tokenizer_type}")
else:
tokenizer_model_name = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_name)
tokenizer_class = type(tokenizer).__name__
config.tokenizer_class = tokenizer_class
print("Saving config")
config.save_pretrained(basename)
print(f"Adding {tokenizer_class} tokenizer files")
tokenizer.save_pretrained(basename)
output_checkpoint_file = os.path.join(basename, "pytorch_model.bin")
print(f'Saving checkpoint to "{output_checkpoint_file}"')
torch.save(output_state_dict, output_checkpoint_file)
if __name__ == "__main__":
main()
.\models\megatron_gpt2\__init__.py
.\models\mgp_str\configuration_mgp_str.py
""" MGP-STR model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
MGP_STR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"alibaba-damo/mgp-str-base": "https://huggingface.co/alibaba-damo/mgp-str-base/resolve/main/config.json",
}
class MgpstrConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of an [`MgpstrModel`]. It is used to instantiate an
MGP-STR model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the MGP-STR
[alibaba-damo/mgp-str-base](https://huggingface.co/alibaba-damo/mgp-str-base) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
image_size (`List[int]`, *optional*, defaults to `[32, 128]`):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 4):
The size (resolution) of each patch.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
max_token_length (`int`, *optional*, defaults to 27):
The maximum number of output tokens.
num_character_labels (`int`, *optional*, defaults to 38):
The number of classes for the character head.
num_bpe_labels (`int`, *optional*, defaults to 50257):
The number of classes for the bpe head.
num_wordpiece_labels (`int`, *optional*, defaults to 30522):
The number of classes for the wordpiece head.
hidden_size (`int`, *optional*, defaults to 768):
The embedding dimension.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
mlp_ratio (`float`, *optional*, defaults to 4.0):
The ratio of the mlp hidden dim to the embedding dim.
qkv_bias (`bool`, *optional*, defaults to `True`):
Whether to add a bias to the queries, keys and values.
distilled (`bool`, *optional*, defaults to `False`):
Model includes a distillation token and head as in DeiT models.
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the layer normalization layers.
drop_rate (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings and encoder.
attn_drop_rate (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
drop_path_rate (`float`, *optional*, defaults to 0.0):
The stochastic depth rate.
output_a3_attentions (`bool`, *optional*, defaults to `False`):
Whether or not the model should return the A^3 module attentions.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
Example:
```
>>> from transformers import MgpstrConfig, MgpstrForSceneTextRecognition
>>> # Initializing a Mgpstr mgp-str-base style configuration
>>> configuration = MgpstrConfig()
>>> # Initializing a model (with random weights) from the mgp-str-base style configuration
>>> model = MgpstrForSceneTextRecognition(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
# Set the model type to "mgp-str"
model_type = "mgp-str"
# Initialize the configuration object with the model hyperparameters
def __init__(
self,
image_size=[32, 128],  # image size, defaults to [32, 128]
patch_size=4,  # patch size, defaults to 4
num_channels=3,  # number of image channels, defaults to 3
max_token_length=27,  # maximum token length, defaults to 27
num_character_labels=38,  # number of character labels, defaults to 38
num_bpe_labels=50257,  # number of BPE labels, defaults to 50257
num_wordpiece_labels=30522,  # number of WordPiece labels, defaults to 30522
hidden_size=768,  # hidden size, defaults to 768
num_hidden_layers=12,  # number of hidden layers, defaults to 12
num_attention_heads=12,  # number of attention heads, defaults to 12
mlp_ratio=4.0,  # MLP (multi-layer perceptron) ratio, defaults to 4.0
qkv_bias=True,  # whether to use a bias in the QKV projections, defaults to True
distilled=False,  # whether the model is distilled, defaults to False
layer_norm_eps=1e-5,  # layer-norm epsilon, defaults to 1e-5
drop_rate=0.0,  # dropout rate, defaults to 0.0
attn_drop_rate=0.0,  # attention dropout rate, defaults to 0.0
drop_path_rate=0.0,  # drop-path (stochastic depth) rate, defaults to 0.0
output_a3_attentions=False,  # whether to output A^3 attentions, defaults to False
initializer_range=0.02,  # initializer range, defaults to 0.02
**kwargs,  # additional keyword arguments
):
super().__init__(**kwargs)  # call the parent class initializer
# Store all hyperparameters on the configuration object
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.max_token_length = max_token_length
self.num_character_labels = num_character_labels
self.num_bpe_labels = num_bpe_labels
self.num_wordpiece_labels = num_wordpiece_labels
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.mlp_ratio = mlp_ratio
self.distilled = distilled
self.layer_norm_eps = layer_norm_eps
self.drop_rate = drop_rate
self.qkv_bias = qkv_bias
self.attn_drop_rate = attn_drop_rate
self.drop_path_rate = drop_path_rate
self.output_a3_attentions = output_a3_attentions
self.initializer_range = initializer_range
.\models\mgp_str\modeling_mgp_str.py
""" PyTorch MGP-STR model."""
import collections.abc
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_mgp_str import MgpstrConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MgpstrConfig"
_TOKENIZER_FOR_DOC = "MgpstrTokenizer"
_CHECKPOINT_FOR_DOC = "alibaba-damo/mgp-str-base"
MGP_STR_PRETRAINED_MODEL_ARCHIVE_LIST = [
"alibaba-damo/mgp-str-base",
]
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
argument.
"""
if drop_prob == 0.0 or not training:
return input
keep_prob = 1 - drop_prob
shape = (input.shape[0],) + (1,) * (input.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
random_tensor.floor_()
output = input.div(keep_prob) * random_tensor
return output
class MgpstrDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: Optional[float] = None) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return drop_path(hidden_states, self.drop_prob, self.training)
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
@dataclass
class MgpstrModelOutput(ModelOutput):
"""
Output class for the MGP-STR model, bundling the character/bpe/wordpiece logits with optional hidden states and attentions.
Args:
logits (`tuple(torch.FloatTensor)` of shape `(batch_size, config.num_character_labels)`):
Tuple of `torch.FloatTensor` containing classification scores (before SoftMax) for characters, bpe, and wordpiece.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` containing hidden states of the model at each layer and optional initial embeddings.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` containing attention weights for each layer after softmax computation.
a3_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_a3_attentions=True` is passed or when `config.output_a3_attentions=True`):
Tuple of `torch.FloatTensor` containing attention weights for character, bpe, and wordpiece after softmax computation.
"""
logits: Tuple[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
a3_attentions: Optional[Tuple[torch.FloatTensor]] = None
class MgpstrEmbeddings(nn.Module):
"""2D Image to Patch Embedding"""
def __init__(self, config: MgpstrConfig):
super().__init__()
image_size = (
config.image_size
if isinstance(config.image_size, collections.abc.Iterable)
else (config.image_size, config.image_size)
)
patch_size = (
config.patch_size
if isinstance(config.patch_size, collections.abc.Iterable)
else (config.patch_size, config.patch_size)
)
self.image_size = image_size
self.patch_size = patch_size
self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.num_tokens = 2 if config.distilled else 1
self.proj = nn.Conv2d(config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size)
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + self.num_tokens, config.hidden_size))
self.pos_drop = nn.Dropout(p=config.drop_rate)
def forward(self, pixel_values):
batch_size, channel, height, width = pixel_values.shape
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
)
patch_embeddings = self.proj(pixel_values)
patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
embedding_output = torch.cat((cls_tokens, patch_embeddings), dim=1)
embedding_output = embedding_output + self.pos_embed
embedding_output = self.pos_drop(embedding_output)
return embedding_output
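# Shape sketch for the embedding module above, using the default configuration
# (image_size=[32, 128], patch_size=4, hidden_size=768, distilled=False); a hypothetical
# smoke test, not part of the original file:
_cfg = MgpstrConfig()
_emb = MgpstrEmbeddings(_cfg)
_pixels = torch.randn(2, 3, 32, 128)  # (batch, channels, height, width)
print(_emb.grid_size, _emb.num_patches)  # (8, 32) 256
print(_emb(_pixels).shape)  # torch.Size([2, 257, 768]): 256 patches + 1 [CLS] token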
class MgpstrMlp(nn.Module):
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__(self, config: MgpstrConfig, hidden_features):
super().__init__()
hidden_features = hidden_features or config.hidden_size
self.fc1 = nn.Linear(config.hidden_size, hidden_features)
self.act = nn.GELU()
self.fc2 = nn.Linear(hidden_features, config.hidden_size)
self.drop = nn.Dropout(config.drop_rate)
def forward(self, hidden_states):
hidden_states = self.fc1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.drop(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = self.drop(hidden_states)
return hidden_states
class MgpstrAttention(nn.Module):
def __init__(self, config: MgpstrConfig):
super().__init__()
self.num_heads = config.num_attention_heads
head_dim = config.hidden_size // config.num_attention_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias)
self.attn_drop = nn.Dropout(config.attn_drop_rate)
self.proj = nn.Linear(config.hidden_size, config.hidden_size)
self.proj_drop = nn.Dropout(config.drop_rate)
def forward(self, hidden_states):
batch_size, num, channel = hidden_states.shape
qkv = (
self.qkv(hidden_states)
.reshape(batch_size, num, 3, self.num_heads, channel // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
query, key, value = qkv[0], qkv[1], qkv[2]
attention_probs = (query @ key.transpose(-2, -1)) * self.scale
attention_probs = attention_probs.softmax(dim=-1)
attention_probs = self.attn_drop(attention_probs)
context_layer = (attention_probs @ value).transpose(1, 2).reshape(batch_size, num, channel)
context_layer = self.proj(context_layer)
context_layer = self.proj_drop(context_layer)
return (context_layer, attention_probs)
class MgpstrLayer(nn.Module):
def __init__(self, config: MgpstrConfig, drop_path=None):
super().__init__()
self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attn = MgpstrAttention(config)
self.drop_path = MgpstrDropPath(drop_path) if drop_path is not None else nn.Identity()
self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
mlp_hidden_dim = int(config.hidden_size * config.mlp_ratio)
self.mlp = MgpstrMlp(config, mlp_hidden_dim)
def forward(self, hidden_states):
self_attention_outputs = self.attn(self.norm1(hidden_states))
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1]
hidden_states = self.drop_path(attention_output) + hidden_states
layer_output = hidden_states + self.drop_path(self.mlp(self.norm2(hidden_states)))
outputs = (layer_output, outputs)
return outputs
class MgpstrEncoder(nn.Module):
def __init__(self, config: MgpstrConfig):
super().__init__()
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
self.blocks = nn.Sequential(
*[MgpstrLayer(config=config, drop_path=dpr[i]) for i in range(config.num_hidden_layers)]
)
def forward(self, hidden_states, output_attentions=False, output_hidden_states=False, return_dict=True):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for _, blk in enumerate(self.blocks):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = blk(hidden_states)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class MgpstrA3Module(nn.Module):
def __init__(self, config: MgpstrConfig):
super().__init__()
self.token_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.tokenLearner = nn.Sequential(
nn.Conv2d(config.hidden_size, config.hidden_size, kernel_size=(1, 1), stride=1, groups=8, bias=False),
nn.Conv2d(config.hidden_size, config.max_token_length, kernel_size=(1, 1), stride=1, bias=False),
)
self.feat = nn.Conv2d(
config.hidden_size, config.hidden_size, kernel_size=(1, 1), stride=1, groups=8, bias=False
)
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
hidden_states = self.token_norm(hidden_states)
hidden_states = hidden_states.transpose(1, 2).unsqueeze(-1)
selected = self.tokenLearner(hidden_states)
selected = selected.flatten(2)
attentions = F.softmax(selected, dim=-1)
feat = self.feat(hidden_states)
feat = feat.flatten(2).transpose(1, 2)
feat = torch.einsum("...si,...id->...sd", attentions, feat)
a3_out = self.norm(feat)
return (a3_out, attentions)
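# Shape-wise, the A^3 module above compresses the N patch tokens into exactly max_token_length output
# positions: tokenLearner yields attentions of shape (batch, max_token_length, N) and the einsum
# aggregates the projected features accordingly. A sketch with the defaults (hidden 768, 27 output
# tokens, 257 input tokens); a hypothetical smoke test, not part of the original file:
_cfg = MgpstrConfig()  # hidden_size=768, max_token_length=27
_a3 = MgpstrA3Module(_cfg)
_h = torch.randn(2, 257, 768)  # (batch, seq_len, hidden)
_out, _attn = _a3(_h)
print(_out.shape, _attn.shape)  # torch.Size([2, 27, 768]) torch.Size([2, 27, 257])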
class MgpstrPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = MgpstrConfig
base_model_prefix = "mgp_str"
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
if isinstance(module, MgpstrEmbeddings):
nn.init.trunc_normal_(module.pos_embed, mean=0.0, std=self.config.initializer_range)
nn.init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range)
elif isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
MGP_STR_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`MgpstrConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
MGP_STR_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare MGP-STR Model transformer outputting raw hidden-states without any specific head on top.",
MGP_STR_START_DOCSTRING,
)
class MgpstrModel(MgpstrPreTrainedModel):
def __init__(self, config: MgpstrConfig):
super().__init__(config)
self.config = config
self.embeddings = MgpstrEmbeddings(config)
self.encoder = MgpstrEncoder(config)
def get_input_embeddings(self) -> nn.Module:
return self.embeddings.proj
@add_start_docstrings_to_model_forward(MGP_STR_INPUTS_DOCSTRING)
def forward(
self,
pixel_values: torch.FloatTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
embedding_output = self.embeddings(pixel_values)
encoder_outputs = self.encoder(
embedding_output,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
return encoder_outputs
return BaseModelOutput(
last_hidden_state=encoder_outputs.last_hidden_state,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class MgpstrForSceneTextRecognition(MgpstrPreTrainedModel):
config_class = MgpstrConfig
main_input_name = "pixel_values"
def __init__(self, config: MgpstrConfig) -> None:
super().__init__(config)
self.num_labels = config.num_labels
self.mgp_str = MgpstrModel(config)
self.char_a3_module = MgpstrA3Module(config)
self.bpe_a3_module = MgpstrA3Module(config)
self.wp_a3_module = MgpstrA3Module(config)
self.char_head = nn.Linear(config.hidden_size, config.num_character_labels)
self.bpe_head = nn.Linear(config.hidden_size, config.num_bpe_labels)
self.wp_head = nn.Linear(config.hidden_size, config.num_wordpiece_labels)
@add_start_docstrings_to_model_forward(MGP_STR_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MgpstrModelOutput, config_class=MgpstrConfig)
def forward(
self,
pixel_values: torch.FloatTensor,
output_attentions: Optional[bool] = None,
output_a3_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], MgpstrModelOutput]:
r"""
output_a3_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of a3 modules. See `a3_attentions` under returned tensors
for more detail.
Returns:
This function returns either a tuple of torch.FloatTensor or an instance of MgpstrModelOutput.
Example:
```
>>> from transformers import (
... MgpstrProcessor,
... MgpstrForSceneTextRecognition,
... )
>>> import requests
>>> from PIL import Image
>>> # load image from the IIIT-5k dataset
>>> url = "https://i.postimg.cc/ZKwLg2Gw/367-14.png"
>>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
>>> processor = MgpstrProcessor.from_pretrained("alibaba-damo/mgp-str-base")
>>> pixel_values = processor(images=image, return_tensors="pt").pixel_values
>>> model = MgpstrForSceneTextRecognition.from_pretrained("alibaba-damo/mgp-str-base")
>>> # inference
>>> outputs = model(pixel_values)
>>> out_strs = processor.batch_decode(outputs.logits)
>>> out_strs["generated_text"]
'["ticket"]'
```
"""
# Set `output_attentions`, `output_hidden_states`, and `return_dict` from the model configuration
# when the caller does not provide them.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
mgp_outputs = self.mgp_str(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = mgp_outputs[0]
char_a3_out, char_attention = self.char_a3_module(sequence_output)
bpe_a3_out, bpe_attention = self.bpe_a3_module(sequence_output)
wp_a3_out, wp_attention = self.wp_a3_module(sequence_output)
char_logits = self.char_head(char_a3_out)
bpe_logits = self.bpe_head(bpe_a3_out)
wp_logits = self.wp_head(wp_a3_out)
all_a3_attentions = (char_attention, bpe_attention, wp_attention) if output_a3_attentions else None
all_logits = (char_logits, bpe_logits, wp_logits)
if not return_dict:
outputs = (all_logits, all_a3_attentions) + mgp_outputs[1:]
return tuple(output for output in outputs if output is not None)
return MgpstrModelOutput(
logits=all_logits,
hidden_states=mgp_outputs.hidden_states,
attentions=mgp_outputs.attentions,
a3_attentions=all_a3_attentions,
)
.\models\mgp_str\processing_mgp_str.py
import warnings
from transformers import AutoTokenizer
from transformers.utils import is_torch_available
from transformers.utils.generic import ExplicitEnum
from ...processing_utils import ProcessorMixin
if is_torch_available():
import torch
class DecodeType(ExplicitEnum):
CHARACTER = "char"
BPE = "bpe"
WORDPIECE = "wp"
SUPPORTED_ANNOTATION_FORMATS = (DecodeType.CHARACTER, DecodeType.BPE, DecodeType.WORDPIECE)
class MgpstrProcessor(ProcessorMixin):
"""
Constructs an MGP-STR processor which wraps an image processor and an MGP-STR tokenizer into a single processor.
[`MgpstrProcessor`] offers all the functionality of `ViTImageProcessor` and `MgpstrTokenizer`. See
[`~MgpstrProcessor.__call__`] and [`~MgpstrProcessor.batch_decode`] for more information.
Args:
image_processor (`ViTImageProcessor`, *optional*):
An instance of `ViTImageProcessor`. The image processor is a required input.
tokenizer ([`MgpstrTokenizer`], *optional*):
The tokenizer is a required input.
"""
attributes = ["image_processor", "char_tokenizer"]
image_processor_class = "ViTImageProcessor"
char_tokenizer_class = "MgpstrTokenizer"
def __init__(self, image_processor=None, tokenizer=None, **kwargs):
feature_extractor = None
if "feature_extractor" in kwargs:
warnings.warn(
"The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
" instead.",
FutureWarning,
)
feature_extractor = kwargs.pop("feature_extractor")
image_processor = image_processor if image_processor is not None else feature_extractor
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
if tokenizer is None:
raise ValueError("You need to specify a `tokenizer`.")
self.char_tokenizer = tokenizer
self.bpe_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
self.wp_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
super().__init__(image_processor, tokenizer)
def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
"""
When used in normal mode, this method forwards all its arguments to ViTImageProcessor's
[`~ViTImageProcessor.__call__`] and returns its output. If `text` is not `None`, it also forwards the `text`
and `kwargs` arguments to MgpstrTokenizer's [`~MgpstrTokenizer.__call__`] to encode the text. Please refer to
the docstrings of the methods above for more information.
"""
if images is None and text is None:
raise ValueError("You need to specify either an `images` or `text` input to process.")
if images is not None:
inputs = self.image_processor(images, return_tensors=return_tensors, **kwargs)
if text is not None:
encodings = self.char_tokenizer(text, return_tensors=return_tensors, **kwargs)
if text is None:
return inputs
elif images is None:
return encodings
else:
inputs["labels"] = encodings["input_ids"]
return inputs
def batch_decode(self, sequences):
"""
Convert a list of lists of token ids into a list of strings by calling decode.
Args:
sequences (`torch.Tensor`):
List of tokenized input ids.
Returns:
`Dict[str, any]`: Dictionary of all the outputs of the decoded results.
generated_text (`List[str]`): The final results after fusing char, bpe, and wp.
scores (`List[float]`): The final scores after fusing char, bpe, and wp.
char_preds (`List[str]`): The list of char decoded sentences.
bpe_preds (`List[str]`): The list of bpe decoded sentences.
wp_preds (`List[str]`): The list of wp decoded sentences.
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
char_preds, bpe_preds, wp_preds = sequences
batch_size = char_preds.size(0)
char_strs, char_scores = self._decode_helper(char_preds, "char")
bpe_strs, bpe_scores = self._decode_helper(bpe_preds, "bpe")
wp_strs, wp_scores = self._decode_helper(wp_preds, "wp")
final_strs = []
final_scores = []
for i in range(batch_size):
scores = [char_scores[i], bpe_scores[i], wp_scores[i]]
strs = [char_strs[i], bpe_strs[i], wp_strs[i]]
max_score_index = scores.index(max(scores))
final_strs.append(strs[max_score_index])
final_scores.append(scores[max_score_index])
out = {}
out["generated_text"] = final_strs
out["scores"] = final_scores
out["char_preds"] = char_strs
out["bpe_preds"] = bpe_strs
out["wp_preds"] = wp_strs
return out
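# A toy illustration of the fusion rule above (hypothetical scores): if one image decodes to
# ("ticket", 0.91) from char, ("ticke", 0.47) from bpe, and ("ticket", 0.88) from wp, the maximum
# score selects index 0, so generated_text is "ticket" with score 0.91.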
def _decode_helper(self, pred_logits, format):
"""
Decode model prediction logits into strings and confidence scores, using the decoder for the given format.
Args:
pred_logits (`torch.Tensor`):
List of model prediction logits.
format (`Union[DecoderType, str]`):
Type of model prediction. Must be one of ['char', 'bpe', 'wp'].
Returns:
`tuple`:
dec_strs (`List[str]`): The decoded strings of the model predictions.
conf_scores (`List[float]`): The confidence scores of the model predictions.
"""
if format == DecodeType.CHARACTER:
decoder = self.char_decode
eos_token = 1
eos_str = "[s]"
elif format == DecodeType.BPE:
decoder = self.bpe_decode
eos_token = 2
eos_str = "#"
elif format == DecodeType.WORDPIECE:
decoder = self.wp_decode
eos_token = 102
eos_str = "[SEP]"
else:
raise ValueError(f"Format {format} is not supported.")
dec_strs, conf_scores = [], []
batch_size = pred_logits.size(0)
batch_max_length = pred_logits.size(1)
_, preds_index = pred_logits.topk(1, dim=-1, largest=True, sorted=True)
preds_index = preds_index.view(-1, batch_max_length)[:, 1:]
preds_str = decoder(preds_index)
preds_max_prob, _ = torch.nn.functional.softmax(pred_logits, dim=2).max(dim=2)
preds_max_prob = preds_max_prob[:, 1:]
for index in range(batch_size):
pred_eos = preds_str[index].find(eos_str)
pred = preds_str[index][:pred_eos]
pred_index = preds_index[index].cpu().tolist()
pred_eos_index = pred_index.index(eos_token) if eos_token in pred_index else -1
pred_max_prob = preds_max_prob[index][: pred_eos_index + 1]
confidence_score = pred_max_prob.cumprod(dim=0)[-1] if pred_max_prob.nelement() != 0 else 0.0
dec_strs.append(pred)
conf_scores.append(confidence_score)
return dec_strs, conf_scores
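The confidence score computed above is simply the product of the per-step maximum softmax probabilities up to the EOS token, read off as the last element of a running `cumprod`. A small numeric sketch with made-up probabilities:
```
import torch

# Hypothetical per-step max probabilities (up to and including EOS).
pred_max_prob = torch.tensor([0.99, 0.97, 0.95, 0.90])

# Product of all steps, taken from the end of the cumulative product.
confidence = pred_max_prob.cumprod(dim=0)[-1]
print(confidence.item())  # ~0.821
```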
def char_decode(self, sequences):
"""
Convert a list of lists of char token ids into a list of strings by calling char tokenizer.
Args:
sequences (`torch.Tensor`):
List of tokenized input ids.
Returns:
`List[str]`: The list of char decoded sentences.
"""
decode_strs = [seq.replace(" ", "") for seq in self.char_tokenizer.batch_decode(sequences)]
return decode_strs
def bpe_decode(self, sequences):
"""
Convert a list of lists of bpe token ids into a list of strings by calling bpe tokenizer.
Args:
sequences (`torch.Tensor`):
List of tokenized input ids.
Returns:
`List[str]`: The list of bpe decoded sentences.
"""
return self.bpe_tokenizer.batch_decode(sequences)
def wp_decode(self, sequences):
"""
Convert a list of lists of word piece token ids into a list of strings by calling word piece tokenizer.
Args:
sequences (`torch.Tensor`):
List of tokenized input ids.
Returns:
`List[str]`: The list of wp decoded sentences.
"""
decode_strs = [seq.replace(" ", "") for seq in self.wp_tokenizer.batch_decode(sequences)]
return decode_strs
.\models\mgp_str\tokenization_mgp_str.py
"""
Tokenization classes for MGP-STR CHAR.
"""
import json
import os
from typing import Optional, Tuple
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"mgp-str": "https://huggingface.co/alibaba-damo/mgp-str-base/blob/main/vocab.json",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"mgp-str": 27}
class MgpstrTokenizer(PreTrainedTokenizer):
"""
Construct a MGP-STR char tokenizer.
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
this superclass for more information regarding those methods.
Args:
vocab_file (`str`):
Path to the vocabulary file.
unk_token (`str`, *optional*, defaults to `"[GO]"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
bos_token (`str`, *optional*, defaults to `"[GO]"`):
The beginning of sequence token.
eos_token (`str`, *optional*, defaults to `"[s]"`):
The end of sequence token.
pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[GO]"`):
A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
attention mechanisms or loss computation.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, unk_token="[GO]", bos_token="[GO]", eos_token="[s]", pad_token="[GO]", **kwargs):
"""
Initialize a tokenizer instance.
Args:
vocab_file (`str`):
Path to the vocabulary file.
unk_token (`str`, *optional*, defaults to `"[GO]"`):
The unknown token.
bos_token (`str`, *optional*, defaults to `"[GO]"`):
The beginning of sequence token.
eos_token (`str`, *optional*, defaults to `"[s]"`):
The end of sequence token.
pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[GO]"`):
The padding token used in batching.
**kwargs:
Additional keyword arguments passed to the parent class constructor.
"""
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.vocab = json.load(vocab_handle)
self.decoder = {v: k for k, v in self.vocab.items()}
super().__init__(
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
**kwargs,
)
@property
def vocab_size(self):
"""
Return the size of the vocabulary.
Returns:
int: Number of tokens in the vocabulary.
"""
return len(self.vocab)
def get_vocab(self):
"""
Get the vocabulary (including any additional tokens).
Returns:
dict: A dictionary containing the vocabulary tokens and their IDs.
"""
vocab = dict(self.vocab).copy()
vocab.update(self.added_tokens_encoder)
return vocab
def _tokenize(self, text):
char_tokens = []
for s in text:
char_tokens.extend(s)
return char_tokens
def _convert_token_to_id(self, token):
return self.vocab.get(token, self.vocab.get(self.unk_token))
def _convert_id_to_token(self, index):
return self.decoder.get(index)
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return (vocab_file,)
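Since `_tokenize` splits text into single characters, encoding works strictly at the character level. A minimal sketch, assuming the `alibaba-damo/mgp-str-base` vocabulary is available:
```
from transformers import MgpstrTokenizer

tokenizer = MgpstrTokenizer.from_pretrained("alibaba-damo/mgp-str-base")

# One token per character.
print(tokenizer.tokenize("hello"))  # ['h', 'e', 'l', 'l', 'o']

# Encoding maps each character to its id in vocab.json.
print(tokenizer("hello")["input_ids"])
```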
.\models\mgp_str\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_mgp_str": ["MGP_STR_PRETRAINED_CONFIG_ARCHIVE_MAP", "MgpstrConfig"],
"processing_mgp_str": ["MgpstrProcessor"],
"tokenization_mgp_str": ["MgpstrTokenizer"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_mgp_str"] = [
"MGP_STR_PRETRAINED_MODEL_ARCHIVE_LIST",
"MgpstrModel",
"MgpstrPreTrainedModel",
"MgpstrForSceneTextRecognition",
]
if TYPE_CHECKING:
from .configuration_mgp_str import MGP_STR_PRETRAINED_CONFIG_ARCHIVE_MAP, MgpstrConfig
from .processing_mgp_str import MgpstrProcessor
from .tokenization_mgp_str import MgpstrTokenizer
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_mgp_str import (
MGP_STR_PRETRAINED_MODEL_ARCHIVE_LIST,
MgpstrForSceneTextRecognition,
MgpstrModel,
MgpstrPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
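The `_LazyModule` registration above defers the heavy torch-dependent imports until an attribute is actually accessed. A minimal sketch of the same idea (not the actual `_LazyModule` implementation) using a PEP 562 module-level `__getattr__` in a package's `__init__.py`:
```
import importlib

# Hypothetical attribute -> submodule map for a package's __init__.py.
_LAZY_ATTRS = {"MgpstrModel": "modeling_mgp_str"}

def __getattr__(name):
    # Import the owning submodule only on first access to the attribute.
    if name in _LAZY_ATTRS:
        module = importlib.import_module("." + _LAZY_ATTRS[name], __name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```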
.\models\mistral\configuration_mistral.py
""" Mistral model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json",
"mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json",
}
class MistralConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an
Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1.
[mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
[mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
```
>>> from transformers import MistralModel, MistralConfig
>>> # Initializing a Mistral 7B style configuration
>>> configuration = MistralConfig()
>>> # Initializing a model from the Mistral 7B style configuration
>>> model = MistralModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "mistral"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=10000.0,
sliding_window=4096,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
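Note the grouped-query attention defaults above: 32 attention heads share 8 key/value heads, so each key/value head serves 4 query heads. A small sketch instantiating a scaled-down, purely illustrative config:
```
from transformers import MistralConfig

# A deliberately tiny configuration (not a real checkpoint).
config = MistralConfig(
    hidden_size=512,
    intermediate_size=1024,
    num_hidden_layers=4,
    num_attention_heads=8,
    num_key_value_heads=2,  # grouped-query attention: 4 query heads per kv head
)
print(config.num_attention_heads // config.num_key_value_heads)  # 4
```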
.\models\mistral\convert_mistral_weights_to_hf.py
import argparse
import gc
import json
import os
import shutil
import warnings
import torch
from transformers import (
LlamaTokenizer,
MistralConfig,
MistralForCausalLM,
)
try:
from transformers import LlamaTokenizerFast
tokenizer_class = LlamaTokenizerFast
except ImportError as e:
warnings.warn(e)
warnings.warn(
"The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
)
tokenizer_class = LlamaTokenizer
"""
Example usage:
python src/transformers/models/mistral/convert_mistral_weights_to_hf.py \
--input_dir /path/to/downloaded/mistral/weights --model_size 7B --output_dir /output/path
"""
NUM_SHARDS = {"7B": 1}
def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
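As a worked example of the rounding above: `int(8 * 4096 / 3) = 10922`, and the next multiple of 256 at or above that is 11008 (note the Mistral script itself reads `hidden_dim` straight from `params.json` rather than deriving it):
```
# int(8 * 4096 / 3) = 10922; next multiple of 256 >= 10922 is 11008.
assert compute_intermediate_size(4096) == 11008
```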
def read_json(path):
with open(path, "r") as f:
return json.load(f)
def write_json(text, path):
with open(path, "w") as f:
json.dump(text, f)
def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True):
if not os.path.isfile(os.path.join(input_base_path, "params.json")):
input_base_path = os.path.join(input_base_path, model_size)
os.makedirs(model_path, exist_ok=True)
tmp_model_path = os.path.join(model_path, "tmp")
os.makedirs(tmp_model_path, exist_ok=True)
params = read_json(os.path.join(input_base_path, "params.json"))
num_shards = NUM_SHARDS[model_size]
sliding_window = int(params["sliding_window"])
n_layers = params["n_layers"]
n_heads = params["n_heads"]
n_heads_per_shard = n_heads // num_shards
dim = params["dim"]
dims_per_head = dim // n_heads
base = params.get("rope_theta", 10000.0)
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
max_position_embeddings = 4096 * 8
if tokenizer_path is not None:
tokenizer = tokenizer_class(tokenizer_path)
tokenizer.save_pretrained(model_path)
vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000
if "n_kv_heads" in params:
num_key_value_heads = params["n_kv_heads"]
num_local_key_value_heads = num_key_value_heads // num_shards
key_value_dim = dims_per_head * num_local_key_value_heads
else:
num_key_value_heads = n_heads
num_local_key_value_heads = n_heads_per_shard
key_value_dim = dim
def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
loaded = [
torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
for i in range(num_shards)
]
param_count = 0
index_dict = {"weight_map": {}}
filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
state_dict = {
"model.norm.weight": loaded[0]["norm.weight"],
"model.embed_tokens.weight": torch.cat([loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1),
"lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
}
for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(tmp_model_path, filename))
index_dict["metadata"] = {"total_size": param_count * 2}
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
config = MistralConfig(
hidden_size=dim,
intermediate_size=params["hidden_dim"],
num_attention_heads=params["n_heads"],
num_hidden_layers=params["n_layers"],
rms_norm_eps=params["norm_eps"],
num_key_value_heads=num_key_value_heads,
vocab_size=vocab_size,
rope_theta=base,
max_position_embeddings=max_position_embeddings,
sliding_window=sliding_window,
)
config.save_pretrained(tmp_model_path)
del state_dict
del loaded
gc.collect()
print("Loading the checkpoint in a Mistral model.")
model = MistralForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
del model.config._name_or_path
model.config.torch_dtype = torch.float16
print("Saving in the Transformers format.")
model.save_pretrained(model_path, safe_serialization=safe_serialization)
shutil.rmtree(tmp_model_path)
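The nested `permute` helper inside `write_model` rewrites a rotary projection from the interleaved pair layout `(x0, x1), (x2, x3), ...` used by the original checkpoint into the half-split layout expected by the Transformers rotary implementation. A shape-level sketch of the same transformation:
```
import torch

def permute(w, n_heads, dim1, dim2):
    # Same reshape/transpose dance as the helper inside write_model.
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

w = torch.arange(64, dtype=torch.float32).view(8, 8)
out = permute(w, n_heads=2, dim1=8, dim2=8)
print(out.shape)  # torch.Size([8, 8]) -- same shape, rows regrouped per head
```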
def write_tokenizer(tokenizer_path, input_tokenizer_path):
print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
tokenizer = tokenizer_class(input_tokenizer_path)
tokenizer.save_pretrained(tokenizer_path)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir",
help="Location of Mistral weights, which contains tokenizer.model and model folders",
)
parser.add_argument(
"--model_size",
choices=["7B", "tokenizer_only"],
help="Size of the Mistral weights to convert (only 7B is available), or 'tokenizer_only' to convert just the tokenizer.",
)
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
args = parser.parse_args()
spm_path = os.path.join(args.input_dir, "tokenizer.model")
if args.model_size != "tokenizer_only":
write_model(
model_path=args.output_dir,
input_base_path=args.input_dir,
model_size=args.model_size,
safe_serialization=args.safe_serialization,
tokenizer_path=spm_path,
)
else:
write_tokenizer(args.output_dir, spm_path)
if __name__ == "__main__":
main()
.\models\mistral\modeling_flax_mistral.py
from typing import Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPast,
FlaxCausalLMOutput,
FlaxCausalLMOutputWithCrossAttentions,
)
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, logging
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
from .configuration_mistral import MistralConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MistralConfig"
_REAL_CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1"
_CHECKPOINT_FOR_DOC = "ksmcg/Mistral-tiny"
MISTRAL_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a Flax Linen
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`MistralConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16`, or
`jax.numpy.bfloat16`. This can be used to enable mixed-precision training or half-precision inference on
GPUs or TPUs. If specified, all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of the model
parameters.**
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
[`~FlaxPreTrainedModel.to_bf16`].
"""
# Docstring constant describing the inputs accepted by the Mistral Flax models
MISTRAL_INPUTS_DOCSTRING = r"""
Args:
input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify it to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
[What are attention masks?](../glossary#attention-mask)
position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0, config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
Dictionary of pre-computed hidden states (key and values in the attention blocks) that can be used for fast auto-regressive decoding. Pre-computed key and value hidden states are of shape *[batch_size, max_length]*.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail.
return_dict (`bool`, *optional*):
Whether to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaRMSNorm with Llama->Mistral
class FlaxMistralRMSNorm(nn.Module):
# Type annotation: the `config` attribute holds a `MistralConfig`
config: MistralConfig
# Default computation dtype is `jnp.float32`
dtype: jnp.dtype = jnp.float32
def setup(self):
# Read the normalization epsilon from the config's rms_norm_eps
self.epsilon = self.config.rms_norm_eps
# Create the weight parameter as an all-ones array of shape (hidden_size,)
self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.hidden_size)
# The module's call method takes hidden_states as input
def __call__(self, hidden_states):
# Cast hidden_states to a float32 array for the variance computation
variance = jnp.asarray(hidden_states, dtype=jnp.float32)
# Square each element
variance = jnp.power(variance, 2)
# Average over the last dimension, keeping it as a singleton
variance = variance.mean(-1, keepdims=True)
# Normalize by the root mean square plus epsilon
# Note: `jax.numpy.sqrt` is used instead of `jax.lax.rsqrt` because the latter does not match `torch.rsqrt`
hidden_states = hidden_states / jnp.sqrt(variance + self.epsilon)
# Scale the normalized hidden states by the learned weight and cast back to the module dtype
return self.weight * jnp.asarray(hidden_states, dtype=self.dtype)
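A quick numeric sketch of the computation above, assuming a weight of all ones:
```
import jax.numpy as jnp

x = jnp.array([[3.0, 4.0]])
eps = 1e-6
variance = jnp.power(x, 2).mean(-1, keepdims=True)  # mean of squares = 12.5
normed = x / jnp.sqrt(variance + eps)               # divide by RMS ~= 3.5355
print(normed)  # ~[[0.8485 1.1314]]
```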
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaRotaryEmbedding with Llama->Mistral
class FlaxMistralRotaryEmbedding(nn.Module):
# Uses the MistralConfig configuration
config: MistralConfig
# Default dtype is jnp.float32
dtype: jnp.dtype = jnp.float32
def setup(self):
# Dimension of each attention head
head_dim = self.config.hidden_size // self.config.num_attention_heads
# Precompute the sine and cosine position encodings
self.sincos = create_sinusoidal_positions(self.config.max_position_embeddings, head_dim)
def __call__(self, key, query, position_ids):
# Look up the sine/cosine values for the given positions
sincos = self.sincos[position_ids]
sin_pos, cos_pos = jnp.split(sincos, 2, axis=-1)
# Apply the rotary position embedding to the key and query tensors
key = apply_rotary_pos_emb(key, sin_pos, cos_pos)
query = apply_rotary_pos_emb(query, sin_pos, cos_pos)
# Cast to the requested dtype
key = jnp.asarray(key, dtype=self.dtype)
query = jnp.asarray(query, dtype=self.dtype)
# Return the processed key and query tensors
return key, query
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaMLP with Llama->Mistral
class FlaxMistralMLP(nn.Module):
# Uses the MistralConfig configuration
config: MistralConfig
# Default dtype is jnp.float32
dtype: jnp.dtype = jnp.float32
def setup(self):
# Embedding dimension and inner (intermediate) dimension
embed_dim = self.config.hidden_size
inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim
# Kernel initializer and activation function
kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
self.act = ACT2FN[self.config.hidden_act]
# Gate, down, and up projections
self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
self.up_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
def __call__(self, hidden_states):
# Up projection of the hidden states
up_proj_states = self.up_proj(hidden_states)
# Gate projection followed by the activation function
gate_states = self.act(self.gate_proj(hidden_states))
# Combine gate and up projections, then project back down
hidden_states = self.down_proj(up_proj_states * gate_states)
# Return the transformed hidden states
return hidden_states
# Copied from transformers.models.llama.modeling_flax_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(tensor, sin_pos, cos_pos):
# Apply the rotary position embedding to the tensor
return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos)
# Copied from transformers.models.llama.modeling_flax_llama.create_sinusoidal_positions
def create_sinusoidal_positions(num_pos, dim):
# Compute the inverse frequencies
inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32")
# Build the sine and cosine position encodings
emb = np.concatenate((freqs, freqs), axis=-1)
out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1)
return jnp.array(out[:, :, :num_pos])
# Copied from transformers.models.llama.modeling_flax_llama.rotate_half
def rotate_half(tensor):
"""Rotates half the hidden dims of the input tensor."""
rotate_half_tensor = jnp.concatenate(
(-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1
)
return rotate_half_tensor
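Two quick checks on the helpers above: `rotate_half` swaps the negated second half of the feature dimension in front of the first half, and at angle zero (`sin = 0`, `cos = 1`) `apply_rotary_pos_emb` is the identity:
```
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])
print(rotate_half(x))                                      # [-3. -4.  1.  2.]
print(apply_rotary_pos_emb(x, jnp.zeros(4), jnp.ones(4)))  # [1. 2. 3. 4.]
```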
# FlaxMistralAttention: the attention module (the listing here is not a complete copy of the upstream file)
class FlaxMistralAttention(nn.Module):
# Uses the MistralConfig configuration
config: MistralConfig
# Default dtype is jnp.float32
dtype: jnp.dtype = jnp.float32
def setup(self):
# Read parameters from the config
config = self.config
# Hidden size
self.hidden_size = config.hidden_size
# Number of attention heads
self.num_heads = config.num_attention_heads
# Dimension of each attention head
self.head_dim = self.hidden_size // self.num_heads
# Number of key/value heads
self.num_key_value_heads = config.num_key_value_heads
# Number of query heads per key/value group
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
# Maximum number of position embeddings
self.max_position_embeddings = config.max_position_embeddings
# Whether the attention softmax should be computed in fp32
self.attention_softmax_in_fp32 = self.dtype is not jnp.float32
# Rope theta
self.rope_theta = config.rope_theta
# Check that the hidden size is divisible by the number of heads
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
# Linear projections for query, key, value, and output
self.q_proj = nn.Dense(self.num_heads * self.head_dim, use_bias=False, dtype=self.dtype)
self.k_proj = nn.Dense(self.num_key_value_heads * self.head_dim, use_bias=False, dtype=self.dtype)
self.v_proj = nn.Dense(self.num_key_value_heads * self.head_dim, use_bias=False, dtype=self.dtype)
self.o_proj = nn.Dense(self.hidden_size, use_bias=False, dtype=self.dtype)
# Build the causal mask (the upstream variable is misspelled `casual_mask`; renamed here)
causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
# Restrict the causal mask to the sliding-window size
self.causal_mask = jnp.triu(causal_mask, k=-config.sliding_window)
# Rotary embedding
self.rotary_emb = FlaxMistralRotaryEmbedding(config, dtype=self.dtype)
def _split_heads(self, hidden_states, num_heads):
# Split the hidden states into multiple heads
return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))
def _merge_heads(self, hidden_states):
# Merge the per-head hidden states back together
return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))
@nn.compact
# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention._concatenate_to_cache
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slightly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py
"""
# Check whether the cache has been initialized
is_initialized = self.has_variable("cache", "cached_key")
# Fetch or initialize the cached key and value (zero tensors if absent)
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
# Fetch or initialize the cache index (0 if absent)
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
# Update the key and value caches with a new 1D spatial slice
cur_index = cache_index.value
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
key = lax.dynamic_update_slice(cached_key.value, key, indices)
value = lax.dynamic_update_slice(cached_value.value, value, indices)
# Write the updated key and value back into the cache
cached_key.value = key
cached_value.value = value
# Advance the cache index by the number of updated cache vectors
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
# Causal mask for the cached decoder self-attention: our single query position should only attend to
# the key positions that have already been generated and cached, not the remaining zero elements.
pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
)
# Combine the existing mask with the given attention mask
attention_mask = combine_masks(pad_mask, attention_mask)
# Return the updated key, value, and attention mask
return key, value, attention_mask
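A minimal sketch of the buffer update used above: `lax.dynamic_update_slice` writes a new block into a preallocated array at a runtime-computed offset, which is how one decoding step's key/value states are appended to the cache:
```
import jax.numpy as jnp
from jax import lax

cache = jnp.zeros((1, 6, 2))   # (batch, max_length, depth)
new_kv = jnp.ones((1, 1, 2))   # states for one newly generated token
cur_index = 3
cache = lax.dynamic_update_slice(cache, new_kv, (0, cur_index, 0))
print(cache[0, :, 0])  # [0. 0. 0. 1. 0. 0.]
```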
def __call__(
self,
hidden_states,
attention_mask=None,
position_ids=None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
# Project the hidden states to query states with self.q_proj
query_states = self.q_proj(hidden_states)
# Project the hidden states to key states with self.k_proj
key_states = self.k_proj(hidden_states)
# Project the hidden states to value states with self.v_proj
value_states = self.v_proj(hidden_states)
# Split the query states across the attention heads
query_states = self._split_heads(query_states, self.num_heads)
# Split the key states across the key/value heads
key_states = self._split_heads(key_states, self.num_key_value_heads)
# Split the value states across the key/value heads
value_states = self._split_heads(value_states, self.num_key_value_heads)
# Apply the rotary embedding to the key and query states
key_states, query_states = self.rotary_emb(key_states, query_states, position_ids)
# Query and key lengths
query_length, key_length = query_states.shape[1], key_states.shape[1]
# Determine the mask offset and maximum decoder length depending on whether a cached key exists
if self.has_variable("cache", "cached_key"):
mask_shift = self.variables["cache"]["cache_index"]
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
# Dynamically slice the causal mask
causal_mask = lax.dynamic_slice(
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
)
else:
# Use the precomputed causal mask
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
# Batch size
batch_size = hidden_states.shape[0]
# Broadcast the causal mask to the batch dimension
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
# Expand the attention mask to the same shape as the causal mask
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
# Combine the attention mask with the causal mask
attention_mask = combine_masks(attention_mask, causal_mask)
# If a cached key exists or the cache should be initialized, concatenate the key/value states and mask into the cache
if self.has_variable("cache", "cached_key") or init_cache:
key_states, value_states, attention_mask = self._concatenate_to_cache(
key_states, value_states, query_states, attention_mask
)
# Repeat the key states across the key/value groups (grouped-query attention)
key_states = jnp.repeat(key_states, self.num_key_value_groups, axis=2)
# Repeat the value states across the key/value groups
value_states = jnp.repeat(value_states, self.num_key_value_groups, axis=2)
# Build the attention bias: 0 where attending is allowed, a large negative value where masked
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
)
# Standard dot-product attention
attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype
attn_weights = dot_product_attention_weights(
query_states,
key_states,
bias=attention_bias,
deterministic=deterministic,
dropout_rate=self.config.attention_dropout,
dtype=attention_dtype,
)
# If the softmax ran in float32, cast the attention weights back to the target dtype
if self.attention_softmax_in_fp32:
attn_weights = attn_weights.astype(self.dtype)
# Weighted sum over the values via einsum
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
# Merge the per-head outputs
attn_output = self._merge_heads(attn_output)
# Apply the output projection
attn_output = self.o_proj(attn_output)
# Assemble the outputs, including the attention weights if requested
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs
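The two `jnp.repeat` calls above implement grouped-query attention: each key/value head is duplicated `num_key_value_groups` times along the head axis so it lines up with the query heads. A shape-level sketch:
```
import jax.numpy as jnp

# (batch, seq, num_kv_heads, head_dim) with 2 kv heads serving 4 query heads
key_states = jnp.zeros((1, 5, 2, 8))
key_states = jnp.repeat(key_states, 2, axis=2)  # num_key_value_groups = 2
print(key_states.shape)  # (1, 5, 4, 8)
```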
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaDecoderLayer with Llama->Mistral
class FlaxMistralDecoderLayer(nn.Module):
config: MistralConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
# Layer normalization applied to the block input
self.input_layernorm = FlaxMistralRMSNorm(self.config, dtype=self.dtype)
# Self-attention module
self.self_attn = FlaxMistralAttention(self.config, dtype=self.dtype)
# Layer normalization applied after self-attention
self.post_attention_layernorm = FlaxMistralRMSNorm(self.config, dtype=self.dtype)
# The MLP (feed-forward) module
self.mlp = FlaxMistralMLP(self.config, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask=None,
position_ids=None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
):
# Save the residual
residual = hidden_states
# Apply the input layer normalization
hidden_states = self.input_layernorm(hidden_states)
# Apply self-attention
outputs = self.self_attn(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
)
# Residual connection around the attention block
attn_output = outputs[0]
hidden_states = residual + attn_output
# Save the residual again
residual = hidden_states
# Apply the post-attention layer normalization
hidden_states = self.post_attention_layernorm(hidden_states)
# Apply the MLP
hidden_states = self.mlp(hidden_states)
# Residual connection around the MLP block
hidden_states = residual + hidden_states
return (hidden_states,) + outputs[1:]
# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo->Mistral, GPT_NEO->MISTRAL, transformer->model
class FlaxMistralPreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = MistralConfig
base_model_prefix = "model"
module_class: nn.Module = None
def __init__(
self,
config: MistralConfig,
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
# Instantiate the module object
module = self.module_class(config=config, dtype=dtype, **kwargs)
# Call the parent class initializer
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# Build dummy input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
# All-ones attention mask with the same shape as input_ids
attention_mask = jnp.ones_like(input_ids)
# Broadcast position ids from the input shape
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
# Split the RNG into parameter and dropout streams
params_rng, dropout_rng = jax.random.split(rng)
# Collect the RNG streams
rngs = {"params": params_rng, "dropout": dropout_rng}
# Initialize the model parameters with self.module.init and return the unfrozen parameter dict
random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]
# If pretrained parameters were passed in, merge them with the randomly initialized ones
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
# Return the merged, frozen parameter dict
return freeze(unflatten_dict(params))
else:
# Otherwise return the randomly initialized parameters
return random_params
def init_cache(self, batch_size, max_length):
r"""
Args:
batch_size (`int`):
用于快速自回归解码的批处理大小,定义了初始化缓存的批处理大小。
max_length (`int`):
自回归解码的最大可能长度,定义了初始化缓存的序列长度。
"""
# 初始化用于检索缓存的输入变量
input_ids = jnp.ones((batch_size, max_length))
# 创建与input_ids形状相同的全1张量作为注意力掩码
attention_mask = jnp.ones_like(input_ids)
# 根据input_ids的形状广播生成位置编码张量
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
# 使用self.module的初始化方法初始化模型变量,设置init_cache=True以初始化缓存
init_variables = self.module.init(
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
)
# 返回未解冻的缓存字典
return unfreeze(init_variables["cache"])
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
def __call__(
self,
input_ids,
attention_mask=None,
position_ids=None,
params: dict = None,
past_key_values: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
# Fall back to the config value if output_attentions was not passed explicitly
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# Fall back to the config value if output_hidden_states was not passed explicitly
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# Fall back to the config value if return_dict was not passed explicitly
return_dict = return_dict if return_dict is not None else self.config.return_dict
# Batch size and sequence length of the input tensor
batch_size, sequence_length = input_ids.shape
# If no position_ids were passed, broadcast them from the sequence length and batch size
if position_ids is None:
if past_key_values is not None:
raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
# If no attention_mask was passed, use an all-ones mask
if attention_mask is None:
attention_mask = jnp.ones((batch_size, sequence_length))
# Handle any PRNG streams that are needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
inputs = {"params": params or self.params}
# If past_key_values were passed, feed them to the module as the cache and mark the cache mutable
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
# Run the module's forward pass via apply
outputs = self.module.apply(
inputs,
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
False,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
mutable=mutable,
)
# If past_key_values were passed and return_dict is set, add the updated cache to the model outputs
if past_key_values is not None and return_dict:
outputs, past_key_values = outputs
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
return outputs
# If past_key_values were passed but return_dict is not set, splice the updated cache into the output tuple
elif past_key_values is not None and not return_dict:
outputs, past_key_values = outputs
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
# Return the model outputs
return outputs
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaLayerCollection with Llama->Mistral
class FlaxMistralLayerCollection(nn.Module):
# MistralConfig instance; dtype defaults to jnp.float32
config: MistralConfig
dtype: jnp.dtype = jnp.float32
# Module setup
def setup(self):
# Build a list of self.config.num_hidden_layers FlaxMistralDecoderLayer objects
self.blocks = [
FlaxMistralDecoderLayer(self.config, dtype=self.dtype, name=str(i))
for i in range(self.config.num_hidden_layers)
]
# Module call
def __call__(
self,
hidden_states,
attention_mask=None,
position_ids=None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = False,
):
# Initialize all_attentions as an empty tuple if attentions are requested, else None
all_attentions = () if output_attentions else None
# Initialize all_hidden_states as an empty tuple if hidden states are requested, else None
all_hidden_states = () if output_hidden_states else None
# Iterate over the FlaxMistralDecoderLayer blocks
for block in self.blocks:
# If hidden states are requested, append the current hidden_states
if output_hidden_states:
all_hidden_states += (hidden_states,)
# Run the block's forward pass
layer_outputs = block(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
)
# The block's first output becomes the new hidden_states
hidden_states = layer_outputs[0]
# If attentions are requested, append this layer's attentions
if output_attentions:
all_attentions += (layer_outputs[1],)
# The outputs tuple may contain None entries, which FlaxMistralModule filters out
outputs = (hidden_states, all_hidden_states, all_attentions)
# Return the module outputs
return outputs
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaModule with Llama->Mistral
class FlaxMistralModule(nn.Module):
# MistralConfig instance; dtype defaults to jnp.float32
config: MistralConfig
dtype: jnp.dtype = jnp.float32
# Module setup
def setup(self):
# Store the hidden size from the config
self.hidden_size = self.config.hidden_size
# Normal initializer for the token embedding
embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range)
# Token embedding table
self.embed_tokens = nn.Embed(
self.config.vocab_size,
self.hidden_size,
embedding_init=embedding_init,
dtype=self.dtype,
)
# Stack of decoder layers
self.layers = FlaxMistralLayerCollection(self.config, dtype=self.dtype)
# Final RMS normalization
self.norm = FlaxMistralRMSNorm(self.config, dtype=self.dtype)
# Module call
def __call__(
self,
input_ids,
attention_mask=None,
position_ids=None,
deterministic=True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Embed the input token ids (cast to int32)
input_embeds = self.embed_tokens(input_ids.astype("i4"))
# Run the inputs through the Transformer layers
outputs = self.layers(
input_embeds,
position_ids=position_ids,
attention_mask=attention_mask,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Hidden states from the layer stack
hidden_states = outputs[0]
# Apply the final normalization
hidden_states = self.norm(hidden_states)
# If all hidden states are requested, append the final hidden states to the collection
if output_hidden_states:
all_hidden_states = outputs[1] + (hidden_states,)
outputs = (hidden_states, all_hidden_states) + outputs[2:]
else:
outputs = (hidden_states,) + outputs[1:]
# If a dict-style output is not requested, drop all None entries and return a tuple
if not return_dict:
return tuple(v for v in outputs if v is not None)
# Return a FlaxBaseModelOutput with the last hidden state, all hidden states, and attentions
return FlaxBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=outputs[1],
attentions=outputs[-1],
)
# Docstring for FlaxMistralModel: the bare Mistral transformer outputting raw hidden states without a specific head
@add_start_docstrings(
"The bare Mistral Model transformer outputting raw hidden-states without any specific head on top.",
MISTRAL_START_DOCSTRING,
)
class FlaxMistralModel(FlaxMistralPreTrainedModel):
# Set the module class to FlaxMistralModule
module_class = FlaxMistralModule
# Attach a sample-call docstring to FlaxMistralModel
append_call_sample_docstring(
FlaxMistralModel,
_CHECKPOINT_FOR_DOC,
FlaxBaseModelOutputWithPast,
_CONFIG_FOR_DOC,
real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
)
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaForCausalLMModule with Llama->Mistral
class FlaxMistralForCausalLMModule(nn.Module):
config: MistralConfig  # configuration of type MistralConfig
dtype: jnp.dtype = jnp.float32  # computation dtype, defaults to float32
def setup(self):
# Build the FlaxMistralModule backbone from the config and dtype
self.model = FlaxMistralModule(self.config, dtype=self.dtype)
# LM head: a dense layer over the vocabulary used for language modeling
self.lm_head = nn.Dense(
self.config.vocab_size,
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
def __call__(
self,
input_ids,
attention_mask=None,
position_ids=None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Run the backbone forward pass
outputs = self.model(
input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Extract the hidden states from the model outputs
hidden_states = outputs[0]
# Compute the language-modeling logits
lm_logits = self.lm_head(hidden_states)
# If a dict is not requested, return a tuple of lm_logits plus the remaining outputs
if not return_dict:
return (lm_logits,) + outputs[1:]
# Return a FlaxCausalLMOutput with the logits, hidden states, and attentions
return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
# Docstring for FlaxMistralForCausalLM: the Mistral transformer with a language-modeling head (linear layer) on top
@add_start_docstrings(
"""
The Mistral Model transformer with a language modeling head (linear layer) on top.
""",
MISTRAL_START_DOCSTRING,
)
# Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJForCausalLM with GPTJ->Mistral
class FlaxMistralForCausalLM(FlaxMistralPreTrainedModel):
module_class = FlaxMistralForCausalLMModule
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
# initializing the cache
# Batch size and sequence length of the inputs
batch_size, seq_length = input_ids.shape
# Initialize the past key/value cache
past_key_values = self.init_cache(batch_size, max_length)
# Since Mistral uses a causal mask, positions beyond input_ids.shape[-1] and before cache_length are
# already masked, so we can build a single static attention mask here, which is more efficient to compile
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if attention_mask is not None:
# Derive the position ids from the given attention mask
position_ids = attention_mask.cumsum(axis=-1) - 1
# Dynamically copy the given attention mask into the static one
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
else:
# Without an attention mask, fall back to default position ids
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
# Return the prepared inputs
return {
"past_key_values": past_key_values,
"attention_mask": extended_attention_mask,
"position_ids": position_ids,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
# Carry the updated cache over to the next generation step
model_kwargs["past_key_values"] = model_outputs.past_key_values
# Advance the position ids by one step
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
return model_kwargs
# Attach a sample-call docstring to FlaxMistralForCausalLM: _CHECKPOINT_FOR_DOC is the checkpoint used in the
# example, FlaxCausalLMOutputWithCrossAttentions the documented output class, _CONFIG_FOR_DOC the documented
# config, and real_checkpoint=_REAL_CHECKPOINT_FOR_DOC the full-size checkpoint referenced alongside it.
append_call_sample_docstring(
FlaxMistralForCausalLM,
_CHECKPOINT_FOR_DOC,
FlaxCausalLMOutputWithCrossAttentions,
_CONFIG_FOR_DOC,
real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
)
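To close the file out, a minimal end-to-end generation sketch, assuming the tiny debug checkpoint referenced by `_CHECKPOINT_FOR_DOC` ships Flax weights (substitute `mistralai/Mistral-7B-v0.1` for realistic outputs):
```
from transformers import AutoTokenizer, FlaxMistralForCausalLM

tokenizer = AutoTokenizer.from_pretrained("ksmcg/Mistral-tiny")
model = FlaxMistralForCausalLM.from_pretrained("ksmcg/Mistral-tiny")

inputs = tokenizer("My favourite condiment is", return_tensors="np")
# prepare_inputs_for_generation / update_inputs_for_generation are invoked internally by generate()
generated = model.generate(inputs["input_ids"], max_length=20, do_sample=False)
print(tokenizer.batch_decode(generated.sequences, skip_special_tokens=True))
```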