Transformers 源码解析(六十九)
.\models\mamba\__init__.py
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
# Public names exported by each submodule. `_LazyModule` (from ...utils) receives this table at the bottom
# of the module and presumably defers importing a submodule until one of its names is first accessed —
# confirm against `_LazyModule`.
_import_structure = {
    "configuration_mamba": ["MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP", "MambaConfig", "MambaOnnxConfig"],
}

# Modeling objects are only advertised when PyTorch is installed.
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_mamba"] = [
        "MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST",
        "MambaForCausalLM",
        "MambaModel",
        "MambaPreTrainedModel",
    ]

if TYPE_CHECKING:
    # Eager imports seen only by static type checkers; they mirror `_import_structure` above.
    from .configuration_mamba import MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP, MambaConfig, MambaOnnxConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_mamba import (
            MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST,
            MambaForCausalLM,
            MambaModel,
            MambaPreTrainedModel,
        )
else:
    import sys

    # At runtime this module object is replaced in sys.modules by a `_LazyModule` built from `_import_structure`.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\marian\configuration_marian.py
"""
Marian model configuration
"""
from collections import OrderedDict
from typing import Any, Mapping, Optional
from ... import PreTrainedTokenizer
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
from ...onnx.utils import compute_effective_axis_dimension
from ...utils import TensorType, is_torch_available, logging
logger = logging.get_logger(__name__)

# Hub checkpoint name -> URL of its config.json.
MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "Helsinki-NLP/opus-mt-en-de": "https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/config.json",
}
class MarianConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MarianModel`]. It is used to instantiate a
    Marian model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Marian
    [Helsinki-NLP/opus-mt-en-de](https://huggingface.co/Helsinki-NLP/opus-mt-en-de) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Notes:
        - `decoder_vocab_size` falls back to `vocab_size` when it is not given (or is any falsy value).
        - `num_hidden_layers` is set to `encoder_layers`.
        - `num_attention_heads` and `hidden_size` are aliases of `encoder_attention_heads` and `d_model`
          (see `attribute_map`).

    Examples:

    ```
    >>> from transformers import MarianModel, MarianConfig

    >>> # Initializing a Marian Helsinki-NLP/opus-mt-en-de style configuration
    >>> configuration = MarianConfig()

    >>> # Initializing a model from the Helsinki-NLP/opus-mt-en-de style configuration
    >>> model = MarianModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "marian"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}

    def __init__(
        self,
        vocab_size=58101,
        decoder_vocab_size=None,
        max_position_embeddings=1024,
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        use_cache=True,
        is_encoder_decoder=True,
        activation_function="gelu",
        d_model=1024,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        decoder_start_token_id=58100,
        scale_embedding=False,
        pad_token_id=58100,
        eos_token_id=0,
        forced_eos_token_id=0,
        share_encoder_decoder_embeddings=True,
        **kwargs,
    ):
        # Vocabulary / embedding geometry.
        self.vocab_size = vocab_size
        self.decoder_vocab_size = decoder_vocab_size or vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        self.scale_embedding = scale_embedding
        self.share_encoder_decoder_embeddings = share_encoder_decoder_embeddings

        # Encoder stack.
        self.encoder_layers = encoder_layers
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_attention_heads = encoder_attention_heads
        self.encoder_layerdrop = encoder_layerdrop
        self.num_hidden_layers = encoder_layers

        # Decoder stack.
        self.decoder_layers = decoder_layers
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_attention_heads = decoder_attention_heads
        self.decoder_layerdrop = decoder_layerdrop

        # Regularisation, activation and initialisation.
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.use_cache = use_cache

        # Token ids and encoder-decoder flags are handled by PretrainedConfig; call it last, as before.
        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            forced_eos_token_id=forced_eos_token_id,
            **kwargs,
        )
class MarianOnnxConfig(OnnxSeq2SeqConfigWithPast):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
    """Dynamic axes of every ONNX model input for the current task.

    Returns:
        An ``OrderedDict`` mapping each input name to ``{axis_index: axis_name}``.
    """
    # Built fresh on every call so callers may mutate the returned mapping freely.
    encoder_axes = [
        ("input_ids", {0: "batch", 1: "encoder_sequence"}),
        ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
    ]

    if self.task in ["default", "seq2seq-lm"]:
        common_inputs = OrderedDict(encoder_axes)
        if not self.use_past:
            common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
            common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
        else:
            # With a cache the decoder feeds one token at a time; its mask also spans the past length.
            common_inputs["decoder_input_ids"] = {0: "batch"}
            common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
            self.fill_with_past_key_values_(common_inputs, direction="inputs")
    elif self.task == "causal-lm":
        common_inputs = OrderedDict(encoder_axes)
        if self.use_past:
            num_encoder_layers, _ = self.num_layers
            for layer_idx in range(num_encoder_layers):
                for kind in ("key", "value"):
                    common_inputs[f"past_key_values.{layer_idx}.{kind}"] = {0: "batch", 2: "past_sequence + sequence"}
    else:
        common_inputs = OrderedDict(
            encoder_axes
            + [
                ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
                ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
            ]
        )

    return common_inputs
@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
    """Dynamic axes of the ONNX model outputs for the current task.

    Seq2seq tasks defer entirely to the parent implementation. Other tasks start from the implementation that
    follows ``OnnxConfigWithPast`` in the MRO and, when ``use_past`` is set, add one ``present.<i>.key`` /
    ``present.<i>.value`` pair per encoder layer.
    """
    if self.task in ["default", "seq2seq-lm"]:
        return super().outputs

    # super(OnnxConfigWithPast, self) skips OnnxConfigWithPast itself when resolving `outputs`.
    common_outputs = super(OnnxConfigWithPast, self).outputs
    if self.use_past:
        num_encoder_layers, _ = self.num_layers
        for layer_idx in range(num_encoder_layers):
            for kind in ("key", "value"):
                common_outputs[f"present.{layer_idx}.{kind}"] = {0: "batch", 2: "past_sequence + sequence"}
    return common_outputs
def _generate_dummy_inputs_for_default_and_seq2seq_lm(
    self,
    tokenizer: PreTrainedTokenizer,
    batch_size: int = -1,
    seq_length: int = -1,
    is_pair: bool = False,
    framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
    """Build dummy encoder + decoder inputs (and dummy past key/values when ``use_past``) for ONNX export.

    Args:
        tokenizer: tokenizer used to produce the dummy token ids.
        batch_size: batch size; -1 means "dynamic" and is resolved by the encoder/decoder helper.
        seq_length: sequence length; -1 means "dynamic".
        is_pair: forwarded to the helper (whether inputs are sentence pairs).
        framework: tensor framework requested from the tokenizer.

    Returns:
        A dict with the encoder inputs, the same inputs prefixed with ``decoder_``, and — with ``use_past`` — a
        ``past_key_values`` list with one tuple per layer.

    Raises:
        ValueError: if ``use_past`` is set but PyTorch is not installed.
    """
    encoder_inputs = self._generate_dummy_inputs_for_encoder_and_decoder(
        tokenizer, batch_size, seq_length, is_pair, framework
    )

    # With a cache the decoder only consumes one new token per step.
    decoder_seq_length = seq_length if not self.use_past else 1
    decoder_inputs = self._generate_dummy_inputs_for_encoder_and_decoder(
        tokenizer, batch_size, decoder_seq_length, is_pair, framework
    )
    decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
    common_inputs = dict(**encoder_inputs, **decoder_inputs)

    if self.use_past:
        if not is_torch_available():
            raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
        else:
            import torch
        batch, encoder_seq_length = common_inputs["input_ids"].shape
        decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
        num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
        encoder_shape = (
            batch,
            num_encoder_attention_heads,
            encoder_seq_length,
            self._config.hidden_size // num_encoder_attention_heads,
        )
        # Past length deliberately differs from the current decoder length (presumably so the exporter
        # does not tie the two axes together).
        decoder_past_length = decoder_seq_length + 3
        decoder_shape = (
            batch,
            num_decoder_attention_heads,
            decoder_past_length,
            self._config.hidden_size // num_decoder_attention_heads,
        )

        # Keep the padded mask in the mask's own dtype (as the causal-lm path does) rather than letting
        # torch.cat promote it to float.
        mask_dtype = common_inputs["decoder_attention_mask"].dtype
        common_inputs["decoder_attention_mask"] = torch.cat(
            [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length, dtype=mask_dtype)],
            dim=1,
        )

        common_inputs["past_key_values"] = []
        num_encoder_layers, num_decoder_layers = self.num_layers
        min_num_layers = min(num_encoder_layers, num_decoder_layers)
        # Number of layers that exist on only one side (encoder or decoder).
        num_remaining_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
        remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"

        # Layers present on both sides: (decoder key, decoder value, encoder key, encoder value).
        for _ in range(min_num_layers):
            common_inputs["past_key_values"].append(
                (
                    torch.zeros(decoder_shape),
                    torch.zeros(decoder_shape),
                    torch.zeros(encoder_shape),
                    torch.zeros(encoder_shape),
                )
            )
        # Bug fix: previously `range(min_num_layers, max - min)`, which produced the wrong number of extra
        # layers (e.g. none at all for 12 encoder / 6 decoder layers).
        shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
        for _ in range(num_remaining_layers):
            common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
    return common_inputs
def _generate_dummy_inputs_for_causal_lm(
    self,
    tokenizer: PreTrainedTokenizer,
    batch_size: int = -1,
    seq_length: int = -1,
    is_pair: bool = False,
    framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
    """Build dummy inputs for the causal-lm task, plus dummy past key/values when ``use_past`` is set.

    Raises:
        ValueError: if ``use_past`` is set but PyTorch is not installed.
    """
    common_inputs = self._generate_dummy_inputs_for_encoder_and_decoder(
        tokenizer, batch_size, seq_length, is_pair, framework
    )
    if not self.use_past:
        return common_inputs

    if not is_torch_available():
        raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
    import torch

    batch, seqlen = common_inputs["input_ids"].shape
    # The past length is intentionally different from the current sequence length.
    past_len = seqlen + 2
    num_encoder_layers, _ = self.num_layers
    num_heads, _ = self.num_attention_heads
    past_shape = (batch, num_heads, past_len, self._config.hidden_size // num_heads)

    # Extend the attention mask over the past positions, keeping its dtype.
    mask = common_inputs["attention_mask"]
    common_inputs["attention_mask"] = torch.cat([mask, torch.ones(batch, past_len, dtype=mask.dtype)], dim=1)
    common_inputs["past_key_values"] = [
        (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
    ]
    return common_inputs
def _generate_dummy_inputs_for_encoder_and_decoder(
    self,
    tokenizer: PreTrainedTokenizer,
    batch_size: int = -1,
    seq_length: int = -1,
    is_pair: bool = False,
    framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
    """Tokenize a dummy batch made of the tokenizer's unknown token.

    The method previously had a signature but no body. Both dummy-input builders call it and expect a mapping
    with ``input_ids`` / ``attention_mask`` tensors of shape (batch, sequence).

    Args:
        tokenizer: tokenizer used for the dummy batch.
        batch_size: batch size; -1 (dynamic) is replaced by ``OnnxConfig.default_fixed_batch``.
        seq_length: sequence length; -1 (dynamic) is replaced by ``OnnxConfig.default_fixed_sequence``,
            adjusted for the tokenizer's special tokens.
        is_pair: whether special tokens for sentence pairs are accounted for.
        framework: tensor framework passed to the tokenizer as ``return_tensors``.

    Returns:
        A dict of the tokenizer's outputs.
    """
    # Replace dynamic axes (-1) by fixed sizes so the exporter does not specialise on a degenerate size.
    batch_size = compute_effective_axis_dimension(
        batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
    )
    token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
    seq_length = compute_effective_axis_dimension(
        seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
    )

    dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
    common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
    return common_inputs
def generate_dummy_inputs(
    self,
    tokenizer: PreTrainedTokenizer,
    batch_size: int = -1,
    seq_length: int = -1,
    is_pair: bool = False,
    framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
    """Build dummy export inputs, dispatching on ``self.task``.

    "default" and "seq2seq-lm" use the encoder-decoder builder; every other task uses the causal-lm builder.
    """
    if self.task in ["default", "seq2seq-lm"]:
        builder = self._generate_dummy_inputs_for_default_and_seq2seq_lm
    else:
        builder = self._generate_dummy_inputs_for_causal_lm
    return builder(tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework)
def _flatten_past_key_values_(self, flattened_output, name, idx, t):
    """Flatten one layer's past key/values into ``flattened_output``, dispatching on the task.

    The parent implementations' return values were only stored in a discarded local before; that dead
    rebinding is removed. This method still returns None, and presumably the parents update
    ``flattened_output`` in place — confirm against the OnnxConfig base classes.
    """
    if self.task in ["default", "seq2seq-lm"]:
        super()._flatten_past_key_values_(flattened_output, name, idx, t)
    else:
        # Skip OnnxSeq2SeqConfigWithPast in the MRO for the non-seq2seq tasks.
        super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(flattened_output, name, idx, t)
@property
def atol_for_validation(self) -> float:
    """Absolute tolerance reported for validating the exported ONNX model (a constant 1e-4)."""
    tolerance = 1e-4
    return tolerance
.\models\marian\convert_marian_tatoeba_to_pytorch.py
import argparse
import datetime
import json
import os
import re
from pathlib import Path
from typing import Tuple
import yaml
from tqdm import tqdm
from transformers.models.marian.convert_marian_to_pytorch import (
FRONT_MATTER_TEMPLATE,
convert,
convert_opus_name_to_hf_name,
download_and_unzip,
get_system_metadata,
)
# Local clone of Helsinki-NLP/Tatoeba-Challenge (relative to the working directory) and its models/ folder.
DEFAULT_REPO = "Tatoeba-Challenge"
DEFAULT_MODEL_DIR = os.path.join(DEFAULT_REPO, "models")
# Language-code tables: download URLs and the local paths they are cached at.
LANG_CODE_URL = "https://datahub.io/core/language-codes/r/language-codes-3b2.csv"
ISO_URL = "https://cdn-datasets.huggingface.co/language_codes/iso-639-3.csv"
ISO_PATH = "lang_code_data/iso-639-3.csv"
LANG_CODE_PATH = "lang_code_data/language-codes-3b2.csv"
# Base URL that a model's release name is appended to when downloading its archive.
TATOEBA_MODELS_URL = "https://object.pouta.csc.fi/Tatoeba-MT-models"
class TatoebaConverter:
    """
    Convert Tatoeba-Challenge models to the Hugging Face format.

    Steps:
        1. Convert the numpy state dict to the HF format (same code as the OPUS-MT-Train conversion).
        2. Rename the OPUS model to the HF naming scheme: each alpha3 code is replaced with an alpha2 code when a
           unique one exists, e.g. aav-eng -> aav-en, heb-eng -> he-en.
        3. Select the best model for a language pair, parse its yml file and write a model card. By default the
           best model is the one listed first in released-model-results, but the most recent one can be selected
           instead.
    """
def __init__(self, save_dir="marian_converted"):
    """Load the language-code tables and the released-model results used during conversion.

    Args:
        save_dir: directory that receives converted models / model cards (stored as ``self.model_card_dir``).

    Raises:
        AssertionError: if the Tatoeba-Challenge checkout (``DEFAULT_REPO``) is missing.
    """
    assert Path(DEFAULT_REPO).exists(), "need git clone git@github.com:Helsinki-NLP/Tatoeba-Challenge.git"
    self.download_lang_info()

    # Same path as before ("Tatoeba-Challenge/models/..."), spelled via the module constant. Files are
    # now closed deterministically and read as UTF-8 rather than the platform's locale encoding.
    results_path = os.path.join(DEFAULT_MODEL_DIR, "released-model-results.json")
    with open(results_path, encoding="utf-8") as f:
        self.model_results = json.load(f)

    self.alpha3_to_alpha2 = {}
    # ISO table: tab-separated; the 3-letter code is read from column 0 and the 2-letter code from column 3.
    with open(ISO_PATH, encoding="utf-8") as f:
        for line in f:
            parts = line.split("\t")
            if len(parts) > 3 and len(parts[0]) == 3 and len(parts[3]) == 2:
                self.alpha3_to_alpha2[parts[0]] = parts[3]
    # Language-code CSV: comma-separated; 3-letter code in column 0, 2-letter code in column 1.
    # Entries here overwrite those from the ISO table.
    with open(LANG_CODE_PATH, encoding="utf-8") as f:
        for line in f:
            parts = line.split(",")
            if len(parts) > 1 and len(parts[0]) == 3 and len(parts[1]) == 2:
                self.alpha3_to_alpha2[parts[0]] = parts[1]

    self.model_card_dir = Path(save_dir)
    # Tag -> human-readable name (first element of each GROUP_MEMBERS tuple).
    self.tag2name = {key: value[0] for key, value in GROUP_MEMBERS.items()}
def convert_models(self, tatoeba_ids, dry_run=False):
    """Download, convert, and write a model card for each Tatoeba model id.

    Models whose metadata "pre-processing" entry does not mention SentencePiece are skipped. Checkpoints
    are cached under ``marian_ckpt/`` and converted models go to ``self.model_card_dir``.
    """
    parsed = [self.parse_metadata(model_id) for model_id in tatoeba_ids]
    ckpt_root = Path("marian_ckpt")
    out_root = Path(self.model_card_dir)
    out_root.mkdir(exist_ok=True)

    for meta in tqdm(parsed):
        if "SentencePiece" not in meta["pre-processing"]:
            print(f"Skipping {meta['release']} because it doesn't appear to use SentencePiece")
            continue
        local_dir = ckpt_root / meta["_name"]
        if not os.path.exists(local_dir):
            download_and_unzip(f"{TATOEBA_MODELS_URL}/{meta['release']}", local_dir)
        hf_pair = convert_opus_name_to_hf_name(meta["_name"])
        convert(local_dir, out_root / f"opus-mt-{hf_pair}")
        self.write_model_card(meta, dry_run=dry_run)
def expand_group_to_two_letter_codes(self, grp_name):
    """Return the members of group ``grp_name``, each mapped to its 2-letter code when one is known.

    Members are stored in a ``set`` in GROUP_MEMBERS, and string-set iteration order changes between
    interpreter runs (hash randomization). That made generated tag lists unstable, so members are now
    visited in sorted order.
    """
    return [self.alpha3_to_alpha2.get(code, code) for code in sorted(GROUP_MEMBERS[grp_name][1])]
def is_group(self, code, name):
    """Return True when ``name`` mentions "languages" or ``code`` has a GROUP_MEMBERS entry of length > 1.

    NOTE(review): GROUP_MEMBERS values are (name, members) 2-tuples, so the length check is True for every
    code present in GROUP_MEMBERS. It probably meant ``len(members) > 1``. Left as is because get_tags relies
    on the resulting expansion to add 2-letter tags for monolingual codes.
    """
    if "languages" in name:
        return True
    return len(GROUP_MEMBERS.get(code, [])) > 1
def get_tags(self, code, name):
    """Return the language tags for ``code``.

    2-letter codes are returned as-is. Groups are expanded to their members' codes followed by the group
    code. Any other code is returned alone, after a diagnostic print.

    Raises:
        AssertionError: if a 2-letter code's ``name`` mentions "languages".
    """
    if len(code) == 2:
        assert "languages" not in name, f"{code}: {name}"
        return [code]
    if not self.is_group(code, name):
        print(f"Three letter monolingual code: {code}")
        return [code]
    return self.expand_group_to_two_letter_codes(code) + [code]
def resolve_lang_code(self, src, tgt) -> Tuple[list, list]:
    """Return the tag lists (from ``get_tags``) for the source and target codes.

    The annotation previously claimed ``Tuple[str, str]``, but ``get_tags`` returns lists.

    Raises:
        KeyError: if ``src`` or ``tgt`` has no entry in ``self.tag2name``.
    """
    src_tags = self.get_tags(src, self.tag2name[src])
    tgt_tags = self.get_tags(tgt, self.tag2name[tgt])
    return src_tags, tgt_tags
@staticmethod
def model_type_info_from_model_name(name):
info = {"_has_backtranslated_data": False}
if "1m" in name:
info["_data_per_pair"] = str(1e6)
if "2m" in name:
info["_data_per_pair"] = str(2e6)
if "4m" in name:
info["_data_per_pair"] = str(4e6)
if "+bt" in name:
info["_has_backtranslated_data"] = True
if "tuned4" in name:
info["_tuned"] = re.search(r"tuned4[^-]+", name).group()
return info
# NOTE(review): the header of the method containing this code (presumably
# `def write_model_card(self, model_dict, dry_run=False)`, given the call in convert_models) is missing from
# this copy. So is the first part of its body, which must define model_dict-derived locals used below:
# src_multilingual, a2_src_tags, tgt_multilingual, a2_tgt_tags, backtranslated_data, multilingual_data,
# tuned, download, langtoken, datainfo, testset, testscores, scorestable, lang_tags, extra_markdown and
# metadata. Restore them from upstream.
# Card body: a bullet list of model facts followed by the pre-built sections.
content = (
    f"""
* model: {model_dict['modeltype']}
* source language code{src_multilingual*'s'}: {', '.join(a2_src_tags)}
* target language code{tgt_multilingual*'s'}: {', '.join(a2_tgt_tags)}
* dataset: opus {backtranslated_data}
* release date: {model_dict['release-date']}
* pre-processing: {model_dict['pre-processing']}
"""
    + multilingual_data
    + tuned
    + download
    + langtoken
    + datainfo
    + testset
    + testscores
    + scorestable
)
# Prepend the YAML front matter and extra header; append a "System Info" section listing the metadata.
content = FRONT_MATTER_TEMPLATE.format(lang_tags) + extra_markdown + content
items = "\n".join([f"* {k}: {v}" for k, v in metadata.items()])
sec3 = "\n### System Info: \n" + items
content += sec3
# Dry run: print the card and metadata instead of writing files.
if dry_run:
    print("CONTENT:")
    print(content)
    print("METADATA:")
    print(metadata)
    return
sub_dir = self.model_card_dir / model_dict["_hf_model_id"]
sub_dir.mkdir(exist_ok=True)
dest = sub_dir / "README.md"
# NOTE(review): this handle is never explicitly closed; consider `with dest.open("w", encoding="utf-8")`.
dest.open("w").write(content)
# json cannot serialise datetime.date values, so convert them to "YYYY-MM-DD" strings first.
for k, v in metadata.items():
    if isinstance(v, datetime.date):
        metadata[k] = datetime.datetime.strftime(v, "%Y-%m-%d")
with open(sub_dir / "metadata.json", "w", encoding="utf-8") as writeobj:
    json.dump(metadata, writeobj)
def download_lang_info(self):
    """Download the ISO 639-3 and language-code tables into their local paths unless they already exist."""
    Path(LANG_CODE_PATH).parent.mkdir(exist_ok=True)
    import wget

    for local_path, url in ((ISO_PATH, ISO_URL), (LANG_CODE_PATH, LANG_CODE_URL)):
        if not os.path.exists(local_path):
            wget.download(url, local_path)
def parse_metadata(self, model_name, repo_path=DEFAULT_MODEL_DIR, method="best"):
    """Load the YAML metadata of one released model for ``model_name``.

    Args:
        model_name: sub-directory of ``repo_path`` holding the model's ``.yml`` files.
        repo_path: directory containing the per-model folders.
        method: "best" picks the yml whose name is ranked first in released-model-results.
            "newest" picks the yml with the latest date in its file name. "newest" is forced when
            ``model_name`` has no entry in ``self.model_results``.

    Returns:
        The parsed YAML dict, updated with ``model_type_info_from_model_name`` hints and ``_name``.

    Raises:
        NotImplementedError: for an unknown ``method``.
        FileNotFoundError: when no candidate ``.yml`` file exists (previously an opaque IndexError).
    """
    p = Path(repo_path) / model_name

    def url_to_name(url):
        return url.split("/")[-1].split(".")[0]

    def load_yaml(filename):
        # Close the file deterministically (it used to be left open).
        with open(p / filename, encoding="utf-8") as f:
            return yaml.safe_load(f)

    if model_name not in self.model_results:
        method = "newest"

    if method == "best":
        results = [url_to_name(model["download"]) for model in self.model_results[model_name]]
        ymls = [f for f in os.listdir(p) if f.endswith(".yml") and f[:-4] in results]
        ymls.sort(key=lambda x: results.index(x[:-4]))
        if not ymls:
            raise FileNotFoundError(f"No released .yml metadata found in {p}")
        chosen = ymls[0]
    elif method == "newest":
        ymls = [f for f in os.listdir(p) if f.endswith(".yml")]
        ymls.sort(
            key=lambda x: datetime.datetime.strptime(re.search(r"\d\d\d\d-\d\d?-\d\d?", x).group(), "%Y-%m-%d")
        )
        if not ymls:
            raise FileNotFoundError(f"No .yml metadata found in {p}")
        chosen = ymls[-1]
    else:
        raise NotImplementedError(f"Don't know argument method='{method}' to parse_metadata()")

    metadata = load_yaml(chosen)
    metadata.update(self.model_type_info_from_model_name(chosen[:-4]))
    metadata["_name"] = model_name
    return metadata
# Tatoeba group / language code -> (English name, set of member language codes).
# NOTE(review): this copy of the table was damaged. A stray "}" closed the dict right after "deu", and the
# "nic" and "trk" member sets were cut off mid-literal, which made the module a SyntaxError. Only the
# visible entries are kept here. Members and entries that probably sat between "nic"/"roa" and
# "trk"/"zho" are missing — restore them from upstream.
GROUP_MEMBERS = {
    "aav": ("Austro-Asiatic languages", {"hoc", "hoc_Latn", "kha", "khm", "khm_Latn", "mnw", "vie", "vie_Hani"}),
    "afa": (
        "Afro-Asiatic languages",
        {
            "acm", "afb", "amh", "apc", "ara", "arq", "ary", "arz", "hau_Latn", "heb", "kab", "mlt",
            "rif_Latn", "shy_Latn", "som", "thv", "tir",
        },
    ),
    "afr": ("Afrikaans", {"afr"}),
    "alv": (
        "Atlantic-Congo languages",
        {
            "ewe", "fuc", "fuv", "ibo", "kin", "lin", "lug", "nya", "run", "sag", "sna", "swh", "toi_Latn",
            "tso", "umb", "wol", "xho", "yor", "zul",
        },
    ),
    "ara": ("Arabic", {"afb", "apc", "apc_Latn", "ara", "ara_Latn", "arq", "arq_Latn", "arz"}),
    "art": (
        "Artificial languages",
        {
            "afh_Latn", "avk_Latn", "dws_Latn", "epo", "ido", "ido_Latn", "ile_Latn", "ina_Latn", "jbo",
            "jbo_Cyrl", "jbo_Latn", "ldn_Latn", "lfn_Cyrl", "lfn_Latn", "nov_Latn", "qya", "qya_Latn",
            "sjn_Latn", "tlh_Latn", "tzl", "tzl_Latn", "vol_Latn",
        },
    ),
    "aze": ("Azerbaijani", {"aze_Latn"}),
    "bat": ("Baltic languages", {"lit", "lav", "prg_Latn", "ltg", "sgs"}),
    "bel": ("Belarusian", {"bel", "bel_Latn"}),
    "ben": ("Bengali", {"ben"}),
    "bnt": (
        "Bantu languages",
        {"kin", "lin", "lug", "nya", "run", "sna", "swh", "toi_Latn", "tso", "umb", "xho", "zul"},
    ),
    "bul": ("Bulgarian", {"bul", "bul_Latn"}),
    "cat": ("Catalan", {"cat"}),
    "cau": ("Caucasian languages", {"abk", "kat", "che", "ady"}),
    "ccs": ("South Caucasian languages", {"kat"}),
    "ceb": ("Cebuano", {"ceb"}),
    "cel": ("Celtic languages", {"gla", "gle", "bre", "cor", "glv", "cym"}),
    "ces": ("Czech", {"ces"}),
    "cpf": ("Creoles and pidgins, French‑based", {"gcf_Latn", "hat", "mfe"}),
    "cpp": (
        "Creoles and pidgins, Portuguese-based",
        {"zsm_Latn", "ind", "pap", "min", "tmw_Latn", "max_Latn", "zlm_Latn"},
    ),
    "cus": ("Cushitic languages", {"som"}),
    "dan": ("Danish", {"dan"}),
    "deu": ("German", {"deu"}),
    "dra": ("Dravidian languages", {"tam", "kan", "mal", "tel"}),
    "ell": ("Modern Greek (1453-)", {"ell"}),
    "eng": ("English", {"eng"}),
    "epo": ("Esperanto", {"epo"}),
    "est": ("Estonian", {"est"}),
    "euq": ("Basque (family)", {"eus"}),
    "eus": ("Basque", {"eus"}),
    "fin": ("Finnish", {"fin"}),
    "fiu": (
        "Finno-Ugrian languages",
        {
            "est", "fin", "fkv_Latn", "hun", "izh", "kpv", "krl", "liv_Latn", "mdf", "mhr", "myv", "sma",
            "sme", "udm", "vep", "vro",
        },
    ),
    "fra": ("French", {"fra"}),
    "gem": (
        "Germanic languages",
        {
            "afr", "ang_Latn", "dan", "deu", "eng", "enm_Latn", "fao", "frr", "fry", "gos", "got_Goth",
            "gsw", "isl", "ksh", "ltz", "nds", "nld", "nno", "nob", "nob_Hebr", "non_Latn", "pdc", "sco",
            "stq", "swe", "swg", "yid",
        },
    ),
    "gle": ("Irish", {"gle"}),
    "glg": ("Galician", {"glg"}),
    "gmq": (
        "North Germanic languages",
        {"dan", "nob", "nob_Hebr", "swe", "isl", "nno", "non_Latn", "fao"},
    ),
    "gmw": (
        "West Germanic languages",
        {
            "afr", "ang_Latn", "deu", "eng", "enm_Latn", "frr", "fry", "gos", "gsw", "ksh", "ltz", "nds",
            "nld", "pdc", "sco", "stq", "swg", "yid",
        },
    ),
    "grk": ("Greek languages", {"grc_Grek", "ell"}),
    "hbs": ("Serbo-Croatian", {"hrv", "srp_Cyrl", "bos_Latn", "srp_Latn"}),
    "heb": ("Hebrew", {"heb"}),
    "hin": ("Hindi", {"hin"}),
    "inc": (
        "Indic languages",
        {
            "asm", "awa", "ben", "bho", "gom", "guj", "hif_Latn", "hin", "mai", "mar", "npi", "ori",
            "pan_Guru", "pnb", "rom", "san_Deva", "sin", "snd_Arab", "urd",
        },
    ),
    "ine": (
        "Indo-European languages",
        {
            "afr", "afr_Arab", "aln", "ang_Latn", "arg", "asm", "ast", "awa", "bel",
            "bel_Latn", "ben", "bho", "bjn", "bos_Latn", "bre", "bul", "bul_Latn", "cat",
            "ces", "cor", "cos", "csb_Latn", "cym", "dan", "deu", "dsb", "egl", "ell",
            "eng", "enm_Latn", "ext", "fao", "fra", "frm_Latn", "frr", "fry", "gcf_Latn",
            "gla", "gle", "glg", "glv", "gom", "gos", "got_Goth", "grc_Grek", "gsw",
            "guj", "hat", "hif_Latn", "hin", "hrv", "hsb", "hye", "hye_Latn", "ind",
            "isl", "ita", "jdt_Cyrl", "ksh", "kur_Arab", "kur_Latn", "lad", "lad_Latn",
            "lat_Grek", "lat_Latn", "lav", "lij", "lit", "lld_Latn", "lmo", "ltg", "ltz",
            "mai", "mar", "max_Latn", "mfe", "min", "mkd", "mwl", "nds", "nld", "nno",
            "nob", "nob_Hebr", "non_Latn", "npi", "oci", "ori", "orv_Cyrl", "oss",
            "pan_Guru", "pap", "pcd", "pdc", "pes", "pes_Latn", "pes_Thaa", "pms",
            "pnb", "pol", "por", "prg_Latn", "pus", "roh", "rom", "ron", "rue", "rus",
            "rus_Latn", "san_Deva", "scn", "sco", "sgs", "sin", "slv", "snd_Arab",
            "spa", "sqi", "srd", "srp_Cyrl", "srp_Latn", "stq", "swe", "swg", "tgk_Cyrl",
            "tly_Latn", "tmw_Latn", "ukr", "urd", "vec", "wln", "yid", "zlm_Latn",
            "zsm_Latn", "zza",
        },
    ),
    "isl": ("Icelandic", {"isl"}),
    "ita": ("Italian", {"ita"}),
    "itc": (
        "Italic languages",
        {
            "arg", "ast", "bjn", "cat", "cos", "egl", "ext", "fra", "frm_Latn", "gcf_Latn", "glg", "hat",
            "ind", "ita", "lad", "lad_Latn", "lat_Grek", "lat_Latn", "lij", "lld_Latn", "lmo", "max_Latn",
            "mfe", "min", "mwl", "oci", "pap", "pcd", "pms", "por", "roh", "ron", "scn", "spa", "srd",
            "tmw_Latn", "vec", "wln", "zlm_Latn", "zsm_Latn",
        },
    ),
    "jpn": (
        "Japanese",
        {"jpn", "jpn_Bopo", "jpn_Hang", "jpn_Hani", "jpn_Hira", "jpn_Kana", "jpn_Latn", "jpn_Yiii"},
    ),
    "jpx": ("Japanese (family)", {"jpn"}),
    "kat": ("Georgian", {"kat"}),
    "kor": ("Korean", {"kor_Hani", "kor_Hang", "kor_Latn", "kor"}),
    "lav": ("Latvian", {"lav"}),
    "lit": ("Lithuanian", {"lit"}),
    "mkd": ("Macedonian", {"mkd"}),
    "mkh": ("Mon-Khmer languages", {"vie_Hani", "mnw", "vie", "kha", "khm_Latn", "khm"}),
    "msa": ("Malay (macrolanguage)", {"zsm_Latn", "ind", "max_Latn", "zlm_Latn", "min"}),
    # NOTE(review): member list truncated in this copy; only the visible members are kept.
    "nic": ("Niger-Kordofanian languages", {"bam_Latn", "ewe", "fuc", "fuv", "ibo", "kin"}),
    "roa": (
        "Romance languages",
        {
            "arg", "ast", "cat", "cos", "egl", "ext", "fra", "frm_Latn", "gcf_Latn", "glg", "hat", "ind",
            "ita", "lad", "lad_Latn", "lij", "lld_Latn", "lmo", "max_Latn", "mfe", "min", "mwl", "oci",
            "pap", "pms", "por", "roh", "ron", "scn", "spa", "tmw_Latn", "vec", "wln", "zlm_Latn",
            "zsm_Latn",
        },
    ),
    "ron": ("Romanian", {"ron"}),
    "run": ("Rundi", {"run"}),
    "rus": ("Russian", {"rus"}),
    "sal": ("Salishan languages", {"shs_Latn"}),
    "sem": (
        "Semitic languages",
        {"acm", "afb", "amh", "apc", "ara", "arq", "ary", "arz", "heb", "mlt", "tir"},
    ),
    "sla": (
        "Slavic languages",
        {
            "bel", "bel_Latn", "bos_Latn", "bul", "bul_Latn", "ces", "csb_Latn", "dsb", "hrv", "hsb", "mkd",
            "orv_Cyrl", "pol", "rue", "rus", "slv", "srp_Cyrl", "srp_Latn", "ukr",
        },
    ),
    "slv": ("Slovenian", {"slv"}),
    "spa": ("Spanish", {"spa"}),
    "swe": ("Swedish", {"swe"}),
    "taw": ("Tai", {"lao", "tha"}),
    "tgl": ("Tagalog", {"tgl_Latn"}),
    "tha": ("Thai", {"tha"}),
    # NOTE(review): member list truncated in this copy; only the visible members are kept.
    "trk": (
        "Turkic languages",
        {
            "aze_Latn", "bak", "chv", "crh", "crh_Latn", "kaz_Cyrl", "kaz_Latn", "kir_Cyrl", "kjh", "kum",
            "ota_Arab", "ota_Latn", "sah", "tat", "tat_Arab", "tat_Latn", "tuk", "tuk_Latn", "tur", "tyv",
            "uig_Arab", "uig_Cyrl", "uzb_Cyrl",
        },
    ),
    "zho": (
        "Chinese",
        {
            "cjy_Hans", "cjy_Hant", "cmn", "cmn_Bopo", "cmn_Hang", "cmn_Hani", "cmn_Hans", "cmn_Hant",
            "cmn_Hira", "cmn_Kana", "cmn_Latn", "cmn_Yiii", "gan", "hak_Hani", "lzh", "lzh_Bopo",
            "lzh_Hang", "lzh_Hani", "lzh_Hans", "lzh_Hira", "lzh_Kana", "lzh_Yiii", "nan", "nan_Hani",
            "wuu", "wuu_Bopo", "wuu_Hani", "wuu_Latn", "yue", "yue_Bopo", "yue_Hang", "yue_Hani",
            "yue_Hans", "yue_Hant", "yue_Hira", "yue_Kana", "zho", "zho_Hans", "zho_Hant",
        },
    ),
    "zle": ("East Slavic languages", {"bel", "orv_Cyrl", "bel_Latn", "rus", "ukr", "rue"}),
    "zls": (
        "South Slavic languages",
        {"bos_Latn", "bul", "bul_Latn", "hrv", "mkd", "slv", "srp_Cyrl", "srp_Latn"},
    ),
    "zlw": ("West Slavic languages", {"csb_Latn", "dsb", "hsb", "pol", "ces"}),
}
def l2front_matter(langs):
    """Render ``langs`` as YAML list items, one ``- <lang>`` line per entry."""
    lines = [f"- {lang}\n" for lang in langs]
    return "".join(lines)
def dedup(lst):
    """Return the truthy items of ``lst`` with duplicates removed, preserving first-seen order."""
    unique = []
    for candidate in lst:
        if candidate and candidate not in unique:
            unique.append(candidate)
    return unique
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m", "--models", action="append", help="<Required> Set flag", required=True, nargs="+", dest="models"
    )
    parser.add_argument("-save_dir", "--save_dir", default="marian_converted", help="where to save converted models")
    args = parser.parse_args()
    resolver = TatoebaConverter(save_dir=args.save_dir)
    # `action="append"` with `nargs="+"` yields one list per -m flag. Previously only the first flag's
    # models were converted and any later -m flags were silently ignored.
    model_ids = [model_id for group in args.models for model_id in group]
    resolver.convert_models(model_ids)
.\models\marian\convert_marian_to_pytorch.py
import argparse
import json
import os
import socket
import time
import warnings
from pathlib import Path
from typing import Dict, List, Union
from zipfile import ZipFile
import numpy as np
import torch
from huggingface_hub.hf_api import list_models
from torch import nn
from tqdm import tqdm
from transformers import MarianConfig, MarianMTModel, MarianTokenizer
def remove_suffix(text: str, suffix: str):
    """Return ``text`` without a trailing ``suffix``; return it unchanged if it does not end with ``suffix``.

    An empty ``suffix`` now leaves ``text`` unchanged. Previously ``text[:-0]`` returned "".
    """
    if suffix and text.endswith(suffix):
        return text[: -len(suffix)]
    return text
def remove_prefix(text: str, prefix: str):
    """Return ``text`` without a leading ``prefix``; return it unchanged if it does not start with ``prefix``."""
    return text[len(prefix) :] if text.startswith(prefix) else text
def convert_encoder_layer(opus_dict, layer_prefix: str, converter: dict):
    """Build a torch state dict from the ``opus_dict`` entries whose keys start with ``layer_prefix``.

    Each matching array is transposed, wrapped in a tensor and squeezed. Its key, with the prefix stripped,
    is renamed through ``converter``.

    Raises:
        KeyError: if a stripped key has no entry in ``converter``.
    """
    return {
        converter[remove_prefix(key, layer_prefix)]: torch.tensor(opus_dict[key].T).squeeze()
        for key in opus_dict
        if key.startswith(layer_prefix)
    }
def load_layers_(layer_lst: nn.ModuleList, opus_state: dict, converter, is_decoder=False):
    """Load OPUS weights into every layer of ``layer_lst`` in place, non-strictly.

    The i-th layer (1-based) reads the keys prefixed ``decoder_l{i}_`` or ``encoder_l{i}_``.
    """
    side = "decoder" if is_decoder else "encoder"
    for position, layer in enumerate(layer_lst, start=1):
        state = convert_encoder_layer(opus_state, f"{side}_l{position}_", converter)
        layer.load_state_dict(state, strict=False)
def find_pretrained_model(src_lang: str, tgt_lang: str) -> List[str]:
    """Find Helsinki-NLP OPUS-MT models that accept ``src_lang`` and produce ``tgt_lang``.

    Matching is a substring test against the lower-cased ``<src>-<tgt>`` part of each model id. Ids that
    contain "+", do not start with ``Helsinki-NLP/opus-mt-``, or whose remainder does not split into exactly
    two dash-separated parts are skipped. Previously such ids crashed the tuple unpacking or yielded garbage
    names.

    Returns:
        Matching model ids, rebuilt from the lower-cased source/target parts.
    """
    prefix = "Helsinki-NLP/opus-mt-"
    matching = []
    for model_info in list_models():
        model_id = model_info.modelId
        if not model_id.startswith(prefix) or "+" in model_id:
            continue
        parts = remove_prefix(model_id, prefix).lower().split("-")
        if len(parts) != 2:
            continue
        src, tgt = parts
        if src_lang in src and tgt_lang in tgt:
            matching.append(f"{prefix}{src}-{tgt}")
    return matching
def add_emb_entries(wemb, final_bias, n_special_tokens=1):
    """Append ``n_special_tokens`` zero rows to ``wemb`` and matching zero columns to ``final_bias``.

    Args:
        wemb: 2-D embedding matrix; rows are appended (unpacking its shape enforces 2-D).
        final_bias: bias whose vocabulary axis is axis 1, e.g. shape (1, vocab_size).
        n_special_tokens: number of entries to add.

    Returns:
        ``(new_embs, new_bias)``.
    """
    _, d_model = wemb.shape
    embs_to_add = np.zeros((n_special_tokens, d_model))
    new_embs = np.concatenate([wemb, embs_to_add])
    # Bug fix: the padding used to be shaped (n_special_tokens, 1), which only concatenates along axis 1
    # when n_special_tokens == 1 and final_bias has a single row.
    bias_to_add = np.zeros((final_bias.shape[0], n_special_tokens))
    new_bias = np.concatenate((final_bias, bias_to_add), axis=1)
    return new_embs, new_bias
def _cast_yaml_str(v):
bool_dct = {"true": True, "false": False}
if not isinstance(v, str):
return v
elif v in bool_dct:
return bool_dct[v]
try:
return int(v)
except (TypeError, ValueError):
return v
def cast_marian_config(raw_cfg: Dict[str, str]) -> Dict:
    """Return a new dict with every value of ``raw_cfg`` passed through ``_cast_yaml_str``."""
    casted = {}
    for key, value in raw_cfg.items():
        casted[key] = _cast_yaml_str(value)
    return casted
# Key under which the OPUS checkpoint stores its model YAML config.
CONFIG_KEY = "special:model.yml"


def load_config_from_state_dict(opus_dict):
    """Decode the YAML model config embedded in an OPUS checkpoint and cast its values.

    ``opus_dict[CONFIG_KEY]`` holds the text as a sequence of character codes. The final character is dropped
    before parsing (presumably a trailing terminator — confirm). BaseLoader yields plain strings, which
    ``cast_marian_config`` then converts.
    """
    import yaml

    raw_text = "".join(map(chr, opus_dict[CONFIG_KEY]))
    parsed = yaml.load(raw_text[:-1], Loader=yaml.BaseLoader)
    return cast_marian_config(parsed)
def find_model_file(dest_dir):
    """Return the single ``*.npz`` checkpoint file in ``dest_dir``.

    Raises:
        ValueError: if ``dest_dir`` holds no ``.npz`` file or more than one. The zero-file case used to be
            misreported as "more than one".
    """
    model_files = list(Path(dest_dir).glob("*.npz"))
    if not model_files:
        raise ValueError(f"Found no model file (*.npz) in {dest_dir}")
    if len(model_files) > 1:
        raise ValueError(f"Found more than one model file: {model_files}")
    return model_files[0]
# "+"-joined OPUS-MT codes that make up the multilingual "ROMANCE" side (see GROUPS below).
ROM_GROUP = (
    "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
    "+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
    "+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la"
)
# (OPUS code string, HF group name) pairs. convert_opus_name_to_hf_name applies them in list order, so
# NORTH_EU must come before SCANDINAVIA, whose code string is a suffix of NORTH_EU's.
GROUPS = [
    ("cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh", "ZH"),
    (ROM_GROUP, "ROMANCE"),
    ("de+nl+fy+af+da+fo+is+no+nb+nn+sv", "NORTH_EU"),
    ("da+fo+is+no+nb+nn+sv", "SCANDINAVIA"),
    ("se+sma+smj+smn+sms", "SAMI"),
    ("nb_NO+nb+nn_NO+nn+nog+no_nb+no", "NORWAY"),
    ("ga+cy+br+gd+kw+gv", "CELTIC"),
]
# Explicit HF model name -> OPUS-MT name mapping. convert_hf_name_to_opus_name consults it before
# falling back to replacing "_" with "+".
GROUP_TO_OPUS_NAME = {
    "opus-mt-ZH-de": "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-de",
    "opus-mt-ZH-fi": "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-fi",
    "opus-mt-ZH-sv": "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-sv",
    "opus-mt-SCANDINAVIA-SCANDINAVIA": "da+fo+is+no+nb+nn+sv-da+fo+is+no+nb+nn+sv",
    "opus-mt-NORTH_EU-NORTH_EU": "de+nl+fy+af+da+fo+is+no+nb+nn+sv-de+nl+fy+af+da+fo+is+no+nb+nn+sv",
    "opus-mt-de-ZH": "de-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
    "opus-mt-en_el_es_fi-en_el_es_fi": "en+el+es+fi-en+el+es+fi",
    "opus-mt-en-ROMANCE": (
        "en-fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
        "+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
        "+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la"
    ),
    "opus-mt-en-CELTIC": "en-ga+cy+br+gd+kw+gv",
    "opus-mt-es-NORWAY": "es-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
    "opus-mt-fi_nb_no_nn_ru_sv_en-SAMI": "fi+nb+no+nn+ru+sv+en-se+sma+smj+smn+sms",
    "opus-mt-fi-ZH": "fi-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
    "opus-mt-fi-NORWAY": "fi-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
    "opus-mt-ROMANCE-en": (
        "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
        "+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
        "+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la-en"
    ),
    "opus-mt-CELTIC-en": "ga+cy+br+gd+kw+gv-en",
    "opus-mt-sv-ZH": "sv-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
    "opus-mt-sv-NORWAY": "sv-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
}
# Base URL of the OPUS-MT-train models directory on GitHub.
OPUS_GITHUB_URL = "https://github.com/Helsinki-NLP/OPUS-MT-train/blob/master/models/"
# Hub organisation prefix stripped from model names before conversion.
ORG_NAME = "Helsinki-NLP/"
def convert_opus_name_to_hf_name(x):
    """Convert an OPUS-MT-train model name to a Hugging Face model name (deprecated).

    Each code string in GROUPS is collapsed to its group name (in list order), then any remaining "+" becomes "_".
    """
    name = x
    for codes, group_name in GROUPS:
        name = name.replace(codes, group_name)
    return name.replace("+", "_")
def convert_hf_name_to_opus_name(hf_model_name):
    """Convert a Hugging Face model name back to its OPUS-MT-train name.

    Names listed in GROUP_TO_OPUS_NAME use the stored mapping. Any other name has every "_" turned into
    "+", which assumes it contains no codes such as pt_br that legitimately use "_". The organisation prefix
    and the "opus-mt-" prefix are stripped.
    """
    short_name = remove_prefix(hf_model_name, ORG_NAME)
    opus_w_prefix = GROUP_TO_OPUS_NAME.get(short_name)
    if opus_w_prefix is None:
        opus_w_prefix = short_name.replace("_", "+")
    return remove_prefix(opus_w_prefix, "opus-mt-")
def get_system_metadata(repo_root):
    """Return provenance info: HEAD SHAs of ``repo_root``'s repo and of the current directory's repo, plus host and time."""
    import git

    def head_sha(path):
        return git.Repo(path=path, search_parent_directories=True).head.object.hexsha

    return {
        "helsinki_git_sha": head_sha(repo_root),
        "transformers_git_sha": head_sha("."),
        "port_machine": socket.gethostname(),
        "port_time": time.strftime("%Y-%m-%d-%H:%M"),
    }
FRONT_MATTER_TEMPLATE = """---
language:
{}
tags:
- translation
license: apache-2.0
---
"""
DEFAULT_REPO = "Tatoeba-Challenge"
DEFAULT_MODEL_DIR = os.path.join(DEFAULT_REPO, "models")
def write_model_card(
    hf_model_name: str,
    repo_root=DEFAULT_REPO,
    save_dir=Path("marian_converted"),
    dry_run=False,
    extra_metadata=None,
) -> tuple:
    """Build a model card from the latest model's OPUS README plus metadata, and write it unless ``dry_run``.

    Upload command: aws s3 sync model_card_dir s3://models.huggingface.co/bert/Helsinki-NLP/ --dryrun

    Args:
        hf_model_name: model name, with or without the "Helsinki-NLP/" prefix.
        repo_root: "OPUS-MT-train" or "Tatoeba-Challenge"; must contain models/<opus_name>/README.md.
        save_dir: directory under which ``opus-mt-<hf_model_name>/`` is created (str or Path).
        dry_run: if True, only build and return the card.
        extra_metadata: merged into the metadata. It must supply "src_name", "tgt_name" and "src_alpha2",
            which the card text uses.

    Returns:
        ``(content, metadata)``: the README text and the metadata dict. The old annotation said ``str``.

    Raises:
        ValueError: for an unexpected ``repo_root`` or a missing OPUS README.
        KeyError: if one of the required metadata keys is absent.
    """
    import pandas as pd

    # Avoid a shared mutable default argument.
    extra_metadata = {} if extra_metadata is None else extra_metadata

    hf_model_name = remove_prefix(hf_model_name, ORG_NAME)
    opus_name: str = convert_hf_name_to_opus_name(hf_model_name)
    if repo_root not in ("OPUS-MT-train", "Tatoeba-Challenge"):
        raise ValueError(f"Repos root is {repo_root}. Expected either OPUS-MT-train or Tatoeba-Challenge")
    opus_readme_path = Path(repo_root).joinpath("models", opus_name, "README.md")
    if not (opus_readme_path.exists()):
        raise ValueError(f"Readme file {opus_readme_path} not found")

    opus_src, opus_tgt = [x.split("+") for x in opus_name.split("-")]

    readme_url = f"https://github.com/Helsinki-NLP/{repo_root}/tree/master/models/{opus_name}/README.md"

    s, t = ",".join(opus_src), ",".join(opus_tgt)
    metadata = {
        "hf_name": hf_model_name,
        "source_languages": s,
        "target_languages": t,
        "opus_readme_url": readme_url,
        "original_repo": repo_root,
        "tags": ["translation"],
    }
    metadata.update(extra_metadata)
    metadata.update(get_system_metadata(repo_root))

    extra_markdown = (
        f"### {hf_model_name}\n\n* source group: {metadata['src_name']} \n* target group: "
        f"{metadata['tgt_name']} \n* OPUS readme: [{opus_name}]({readme_url})\n"
    )

    with opus_readme_path.open(encoding="utf-8") as f:
        content = f.read()
    # Keep the last top-level section of the README, minus its first two "*"-separated chunks.
    # (A stray debug `print(splat[3])` used to crash here when the section had fewer than five chunks.)
    content = content.split("\n# ")[-1]
    splat = content.split("*")[2:]
    content = "*".join(splat)
    content = (
        FRONT_MATTER_TEMPLATE.format(metadata["src_alpha2"])
        + extra_markdown
        + "\n* "
        + content.replace("download", "download original weights")
    )

    items = "\n\n".join([f"- {k}: {v}" for k, v in metadata.items()])
    sec3 = "\n### System Info: \n" + items
    content += sec3
    if dry_run:
        return content, metadata

    sub_dir = Path(save_dir) / f"opus-mt-{hf_model_name}"
    sub_dir.mkdir(parents=True, exist_ok=True)
    with (sub_dir / "README.md").open("w", encoding="utf-8") as f:
        f.write(content)
    pd.Series(metadata).to_json(sub_dir / "metadata.json")

    return content, metadata
def make_registry(repo_path="Opus-MT-train/models"):
    """
    Scan a local OPUS-MT-train ``models`` checkout and describe every language-pair model.

    Each sub-directory whose name contains a dash is treated as a model directory and its README.md is
    parsed with ``_parse_readme``.

    Returns:
        list of ``(name, pre-processing, download_url, test_set_url)`` tuples. The test-set URL is derived by
        replacing the last four characters of the download URL (presumably ".zip") with ".test.txt".

    Raises:
        ValueError: if ``<repo_path>/fr-en/README.md`` does not exist (checkout missing).
    """
    repo = Path(repo_path)
    if not (repo / "fr-en" / "README.md").exists():
        raise ValueError(
            f"repo_path:{repo_path} does not exist: "
            "You must run: git clone git@github.com:Helsinki-NLP/Opus-MT-train.git before calling."
        )
    results = {}
    for p in repo.iterdir():
        # Only "<src>-<tgt>" directories hold models; stray dashed files would crash the README read.
        if p.name.count("-") == 0 or not p.is_dir():
            continue
        with open(p / "README.md", encoding="utf-8") as readme:
            results[p.name] = _parse_readme(readme.readlines())
    return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()]
def convert_all_sentencepiece_models(model_list=None, repo_path=None, dest_dir=Path("marian_converted")):
    """
    Download (if needed) and convert every SentencePiece-based OPUS model. Requires 300GB.

    Args:
        model_list: iterable of ``(name, pre-processing, download_url, test_set_url)`` tuples; when None it is
            built with ``make_registry``.
        repo_path: OPUS-MT-train models checkout passed to ``make_registry``; when None, ``make_registry``'s
            own default location is used.
        dest_dir: output root; each model is written to ``dest_dir/opus-mt-<hf pair name>``.

    Returns:
        list of the output directories that were converted.
    """
    save_dir = Path("marian_ckpt")
    dest_dir = Path(dest_dir)
    dest_dir.mkdir(exist_ok=True)
    save_paths = []
    if model_list is None:
        # Passing repo_path=None through would reach Path(None) inside make_registry and raise TypeError.
        model_list = make_registry() if repo_path is None else make_registry(repo_path=repo_path)
    for k, prepro, download, test_set_url in tqdm(model_list):
        if "SentencePiece" not in prepro:  # only SentencePiece models are converted
            continue
        ckpt_dir = save_dir / k
        if not ckpt_dir.exists():
            download_and_unzip(download, ckpt_dir)
        out_dir = dest_dir / f"opus-mt-{convert_opus_name_to_hf_name(k)}"
        convert(ckpt_dir, out_dir)
        save_paths.append(out_dir)
    return save_paths
def lmap(f, x) -> List:
    """Apply ``f`` to every element of ``x`` and return the results as a list."""
    return [f(item) for item in x]
def fetch_test_set(test_set_url):
    """
    Download an OPUS test-set file and split it into source, Marian-output and reference lists.

    The file is read in groups of four lines: the 1st line of each group is the source, the 2nd the gold
    reference and the 3rd the Marian translation; the 4th is ignored. The downloaded file is always deleted.

    Returns:
        ``(src, mar_model, gold)``: lists of stripped strings.

    Raises:
        ValueError: if the three lists have different lengths.
    """
    import wget

    fname = wget.download(test_set_url, "opus_test.txt")
    try:
        with Path(fname).open() as f:
            lns = f.readlines()
    finally:
        # Delete the download even if reading fails, so no stale file is left behind.
        os.remove(fname)
    src = lmap(str.strip, lns[::4])
    gold = lmap(str.strip, lns[1::4])
    mar_model = lmap(str.strip, lns[2::4])
    if not (len(gold) == len(mar_model) == len(src)):
        raise ValueError(f"Gold, marian and source lengths {len(gold)}, {len(mar_model)}, {len(src)} mismatched")
    return src, mar_model, gold
def convert_whole_dir(path=Path("marian_ckpt/")):
    """
    Convert every entry of ``path`` into ``marian_converted/<entry name>``.

    Entries whose output directory already contains ``pytorch_model.bin`` are skipped.
    Fixes: the previous body called the non-existent ``Path.ls``, applied ``/`` to a ``str`` and passed an
    undefined ``source_dir`` to ``convert``.
    """
    for subdir in tqdm(list(Path(path).iterdir())):
        dest_dir = Path("marian_converted") / subdir.name
        if (dest_dir / "pytorch_model.bin").exists():
            continue
        convert(subdir, dest_dir)
def _parse_readme(lns):
"""Get link and metadata from opus model card equivalent."""
subres = {}
for ln in [x.strip() for x in lns]:
if not ln.startswith("*"):
continue
ln = ln[1:].strip()
for k in ["download", "dataset", "models", "model", "pre-processing"]:
if ln.startswith(k):
break
else:
continue
if k in ["dataset", "model", "pre-processing"]:
splat = ln.split(":")
_, v = splat
subres[k] = v
elif k == "download":
v = ln.split("(")[-1][:-1]
subres[k] = v
return subres
def save_tokenizer_config(dest_dir: Path, separate_vocabs=False):
    """Write ``tokenizer_config.json`` into ``dest_dir``, deriving languages from the directory name.

    The last dash-separated part of the directory name is the target language; everything before it,
    re-joined with dashes, is the source language.
    """
    *source_parts, target = dest_dir.name.split("-")
    config = {
        "target_lang": target,
        "source_lang": "-".join(source_parts),
        "separate_vocabs": separate_vocabs,
    }
    save_json(config, dest_dir / "tokenizer_config.json")
def add_special_tokens_to_vocab(model_dir: Path, separate_vocab=False) -> None:
    """Add a ``<pad>`` token to the Marian vocab(s) in ``model_dir`` and write JSON vocab + tokenizer config.

    With a shared vocab, ``*vocab.yml`` becomes ``vocab.json``. With separate vocabs, ``*src.vocab.yml``
    becomes ``vocab.json`` and ``*trg.vocab.yml`` becomes ``target_vocab.json``.
    """

    def _load_int_vocab(vocab_path):
        # YAML vocab values are coerced to int ids before new tokens are appended.
        return {token: int(idx) for token, idx in load_yaml(vocab_path).items()}

    if not separate_vocab:
        shared = _load_int_vocab(find_vocab_file(model_dir))
        n_new = add_to_vocab_(shared, ["<pad>"])
        print(f"added {n_new} tokens to vocab")
        save_json(shared, model_dir / "vocab.json")
        save_tokenizer_config(model_dir)
        return

    for finder, out_name in ((find_src_vocab_file, "vocab.json"), (find_tgt_vocab_file, "target_vocab.json")):
        table = _load_int_vocab(finder(model_dir))
        add_to_vocab_(table, ["<pad>"])
        save_json(table, model_dir / out_name)
    save_tokenizer_config(model_dir, separate_vocabs=separate_vocab)
def check_equal(marian_cfg, k1, k2):
    """Raise ValueError when hyper-parameters ``k1`` and ``k2`` of ``marian_cfg`` have different values."""
    first = marian_cfg[k1]
    second = marian_cfg[k2]
    if first == second:
        return
    raise ValueError(f"hparams {k1},{k2} differ: {first} != {second}")
def _first_glob_match(model_dir, pattern):
    # Materialise the glob and take its first hit; raises IndexError when nothing matches.
    matches = [*model_dir.glob(pattern)]
    return matches[0]


def find_vocab_file(model_dir):
    """Return the first ``*vocab.yml`` file in ``model_dir`` (IndexError if none)."""
    return _first_glob_match(model_dir, "*vocab.yml")


def find_src_vocab_file(model_dir):
    """Return the first ``*src.vocab.yml`` file in ``model_dir`` (IndexError if none)."""
    return _first_glob_match(model_dir, "*src.vocab.yml")


def find_tgt_vocab_file(model_dir):
    """Return the first ``*trg.vocab.yml`` file in ``model_dir`` (IndexError if none)."""
    return _first_glob_match(model_dir, "*trg.vocab.yml")
def add_to_vocab_(vocab: Dict[str, int], special_tokens: List[str]) -> int:
    """
    Append the ``special_tokens`` missing from ``vocab`` in place, with consecutive ids after the current max.

    Tokens already present are left untouched. An empty vocab starts numbering at 0; previously
    ``max()`` raised ValueError in that case.

    Returns:
        Number of tokens actually added.
    """
    start = max(vocab.values(), default=-1) + 1
    added = 0
    for tok in special_tokens:
        if tok in vocab:
            continue
        vocab[tok] = start + added
        added += 1
    return added
def check_marian_cfg_assumptions(marian_cfg):
    """Verify that ``marian_cfg`` uses the architecture settings this converter supports.

    Settings are checked in a fixed order. The first mismatch raises ValueError; a missing key raises
    KeyError.
    """
    expected = (
        ("layer-normalization", False),
        ("right-left", False),
        ("transformer-ffn-depth", 2),
        ("transformer-aan-depth", 2),
        ("transformer-no-projection", False),
        ("transformer-postprocess-emb", "d"),
        ("transformer-postprocess", "dan"),
        ("transformer-preprocess", ""),
        ("type", "transformer"),
        ("ulr-dim-emb", 0),
        ("dec-cell-base-depth", 2),
        ("dec-cell-high-depth", 1),
        ("transformer-aan-nogate", False),
    )
    for key, wanted in expected:
        got = marian_cfg[key]
        if got == wanted:
            continue
        raise ValueError(f"Unexpected config value for {key} expected {wanted} got {got}")
# Marian parameter name of the output-logit bias.
BIAS_KEY = "decoder_ff_logit_out_b"
# Maps per-layer Marian parameter-name suffixes to Hugging Face (BART-style) layer attribute paths.
# self_* and ffn_* entries cover self-attention and the feed-forward block. context_* entries are the
# cross-attention weights, presumably present only in decoder layers (26 keys total vs. 16 for encoder layers).
# Fix: the dict was previously closed after "ffn_b1", leaving the remaining entries as a syntax error.
BART_CONVERTER = {
    "self_Wq": "self_attn.q_proj.weight",
    "self_Wk": "self_attn.k_proj.weight",
    "self_Wv": "self_attn.v_proj.weight",
    "self_Wo": "self_attn.out_proj.weight",
    "self_bq": "self_attn.q_proj.bias",
    "self_bk": "self_attn.k_proj.bias",
    "self_bv": "self_attn.v_proj.bias",
    "self_bo": "self_attn.out_proj.bias",
    "self_Wo_ln_scale": "self_attn_layer_norm.weight",
    "self_Wo_ln_bias": "self_attn_layer_norm.bias",
    "ffn_W1": "fc1.weight",
    "ffn_b1": "fc1.bias",
    "ffn_W2": "fc2.weight",
    "ffn_b2": "fc2.bias",
    "ffn_ffn_ln_scale": "final_layer_norm.weight",
    "ffn_ffn_ln_bias": "final_layer_norm.bias",
    "context_Wk": "encoder_attn.k_proj.weight",
    "context_Wo": "encoder_attn.out_proj.weight",
    "context_Wq": "encoder_attn.q_proj.weight",
    "context_Wv": "encoder_attn.v_proj.weight",
    "context_bk": "encoder_attn.k_proj.bias",
    "context_bo": "encoder_attn.out_proj.bias",
    "context_bq": "encoder_attn.q_proj.bias",
    "context_bv": "encoder_attn.v_proj.bias",
    "context_Wo_ln_scale": "encoder_attn_layer_norm.weight",
    "context_Wo_ln_bias": "encoder_attn_layer_norm.bias",
}
class OpusState:
    """
    Holds a loaded Marian/OPUS checkpoint and converts it into a Hugging Face ``MarianMTModel``.

    NOTE(review): this copy of the class only contains the methods below. The constructor and the attributes
    they read (``state_dict``, ``state_keys``, ``hf_config``, ``cfg``, ``wemb``, ``dec_wemb``, ``final_bias``,
    ``pad_token_id``, ``source_dir``, ``share_encoder_decoder_embeddings``) are assumed to be defined
    elsewhere — confirm against the full file.
    """

    def _check_layer_entries(self):
        """Collect first-layer key suffixes and warn when their counts differ from 16 (encoder) / 26 (decoder)."""
        self.encoder_l1 = self.sub_keys("encoder_l1")
        self.decoder_l1 = self.sub_keys("decoder_l1")
        self.decoder_l2 = self.sub_keys("decoder_l2")
        if len(self.encoder_l1) != 16:
            warnings.warn(f"Expected 16 keys for each encoder layer, got {len(self.encoder_l1)}")
        if len(self.decoder_l1) != 26:
            warnings.warn(f"Expected 26 keys for each decoder layer, got {len(self.decoder_l1)}")
        if len(self.decoder_l2) != 26:
            # Report the length that was actually checked (previously printed decoder_l1's length).
            warnings.warn(f"Expected 26 keys for each decoder layer, got {len(self.decoder_l2)}")

    @property
    def extra_keys(self):
        """State keys the conversion does not handle.

        Excluded (handled) keys are: layer weights (``encoder_l*`` / ``decoder_l*``), CONFIG_KEY, the
        embedding matrices, ``Wpos`` and the output-logit bias.
        """
        extra = []
        for k in self.state_keys:
            if (
                k.startswith("encoder_l")
                or k.startswith("decoder_l")
                or k in [CONFIG_KEY, "Wemb", "encoder_Wemb", "decoder_Wemb", "Wpos", "decoder_ff_logit_out_b"]
            ):
                continue
            extra.append(k)
        return extra

    def sub_keys(self, layer_prefix):
        """Return the ``state_dict`` keys starting with ``layer_prefix``, with that prefix removed."""
        return [remove_prefix(k, layer_prefix) for k in self.state_dict if k.startswith(layer_prefix)]

    def load_tokenizer(self):
        """Add special tokens to the vocab file(s) in ``source_dir`` and load a MarianTokenizer from it."""
        add_special_tokens_to_vocab(self.source_dir, not self.share_encoder_decoder_embeddings)
        return MarianTokenizer.from_pretrained(str(self.source_dir))

    def load_marian_model(self) -> MarianMTModel:
        """
        Build a ``MarianMTModel`` from ``hf_config`` and copy the Marian weights into it.

        Raises:
            ValueError: on unexpected config flags, unconverted extra keys or a padding-index mismatch.
            NotImplementedError: if the config requests embedding layer-normalisation.
        """
        state_dict, cfg = self.state_dict, self.hf_config

        if not cfg.static_position_embeddings:
            raise ValueError("config.static_position_embeddings should be True")
        model = MarianMTModel(cfg)

        if "hidden_size" in cfg.to_dict():
            raise ValueError("hidden_size is in config")
        load_layers_(
            model.model.encoder.layers,
            state_dict,
            BART_CONVERTER,
        )
        load_layers_(model.model.decoder.layers, state_dict, BART_CONVERTER, is_decoder=True)

        # Embeddings: one shared matrix when source embeddings are tied, otherwise separate enc/dec matrices.
        if self.cfg["tied-embeddings-src"]:
            wemb_tensor = nn.Parameter(torch.FloatTensor(self.wemb))
            bias_tensor = nn.Parameter(torch.FloatTensor(self.final_bias))
            model.model.shared.weight = wemb_tensor
            model.model.encoder.embed_tokens = model.model.decoder.embed_tokens = model.model.shared
        else:
            wemb_tensor = nn.Parameter(torch.FloatTensor(self.wemb))
            model.model.encoder.embed_tokens.weight = wemb_tensor

            decoder_wemb_tensor = nn.Parameter(torch.FloatTensor(self.dec_wemb))
            bias_tensor = nn.Parameter(torch.FloatTensor(self.final_bias))
            model.model.decoder.embed_tokens.weight = decoder_wemb_tensor

        model.final_logits_bias = bias_tensor

        if "Wpos" in state_dict:
            print("Unexpected: got Wpos")
            # Module weights must be nn.Parameter: assigning a plain tensor to an existing parameter attribute
            # raises TypeError in torch.nn.Module.__setattr__.
            wpos_tensor = nn.Parameter(torch.tensor(state_dict["Wpos"]))
            model.model.encoder.embed_positions.weight = wpos_tensor
            model.model.decoder.embed_positions.weight = wpos_tensor

        if cfg.normalize_embedding:
            if "encoder_emb_ln_scale_pre" not in state_dict:
                raise ValueError("encoder_emb_ln_scale_pre is not in state dictionary")
            raise NotImplementedError("Need to convert layernorm_embedding")

        if self.extra_keys:
            raise ValueError(f"Failed to convert {self.extra_keys}")

        if model.get_input_embeddings().padding_idx != self.pad_token_id:
            raise ValueError(
                f"Padding tokens {model.get_input_embeddings().padding_idx} and {self.pad_token_id} mismatched"
            )
        return model
if __name__ == "__main__":
    # Tatoeba conversion instructions in scripts/tatoeba/README.md
    # Guarded so that importing this module no longer parses sys.argv and starts a conversion.
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", type=str, help="path to marian model sub dir", default="en-de")
    parser.add_argument("--dest", type=str, default=None, help="Path to the output PyTorch model.")
    args = parser.parse_args()

    source_dir = Path(args.src)
    if not source_dir.exists():
        raise ValueError(f"Source directory {source_dir} not found")
    dest_dir = f"converted-{source_dir.name}" if args.dest is None else args.dest
    convert(source_dir, dest_dir)
.\models\marian\modeling_flax_marian.py
""" Flax Marian model."""
import math
import random
from functools import partial
from typing import Callable, Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxCausalLMOutputWithCrossAttentions,
FlaxSeq2SeqLMOutput,
FlaxSeq2SeqModelOutput,
)
from ...modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_call_sample_docstring,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_marian import MarianConfig
# Module-level logger.
logger = logging.get_logger(__name__)
# Checkpoint / config class names used for documentation; presumably consumed by the docstring
# helpers later in this file — confirm.
_CHECKPOINT_FOR_DOC = "Helsinki-NLP/opus-mt-en-de"
_CONFIG_FOR_DOC = "MarianConfig"
# Shared docstring for the Flax Marian model classes: model overview followed by the constructor parameters.
# Fix: a stray closing triple-quote used to end the string before "Parameters:", leaving that text as invalid code.
MARIAN_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a Flax Linen
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`MarianConfig`]): Model configuration class with all the parameters of the model.
初始化模型配置类,包含所有模型参数。
通过配置文件初始化不会加载与模型相关的权重,仅加载配置。
参考 [`~FlaxPreTrainedModel.from_pretrained`] 方法以加载模型权重。
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
计算的数据类型。可以是 `jax.numpy.float32`、`jax.numpy.float16`(在GPU上)和 `jax.numpy.bfloat16`(在TPU上)之一。
可用于在GPU或TPU上启用混合精度训练或半精度推断。
如果指定,所有计算将使用给定的 `dtype` 执行。
**注意,这仅指定计算的数据类型,不影响模型参数的数据类型。**
如果要更改模型参数的数据类型,请参阅 [`~FlaxPreTrainedModel.to_fp16`] 和 [`~FlaxPreTrainedModel.to_bf16`]。
"""
# NOTE(review): this docstring is empty in this copy of the file; the input-argument documentation
# appears to be missing — confirm and restore.
MARIAN_INPUTS_DOCSTRING = r"""
"""
# Argument documentation for the encoder entry point.
# Fix: restored the "#anchor)" suffix of the two glossary links, which was missing and left broken markdown.
MARIAN_ENCODE_INPUTS_DOCSTRING = r"""
Args:
input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
输入序列标记的索引,用于词汇表中的标记。默认情况下会忽略填充。
可以使用 [`AutoTokenizer`] 获得这些索引。参见 [`PreTrainedTokenizer.encode`] 和 [`PreTrainedTokenizer.__call__`] 了解详情。
[什么是输入 ID?](../glossary#input-ids)
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *可选*):
遮罩,用于避免在填充的标记索引上进行注意力计算。遮罩值选在 `[0, 1]` 范围内:
- 1 表示**不遮罩**的标记,
- 0 表示**遮罩**的标记。
[什么是注意力遮罩?](../glossary#attention-mask)
position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *可选*):
输入序列标记在位置嵌入中的位置索引。选择范围是 `[0, config.max_position_embeddings - 1]`。
output_attentions (`bool`, *可选*):
是否返回所有注意力层的注意力张量。更多细节请参见返回的张量中的 `attentions` 字段。
output_hidden_states (`bool`, *可选*):
是否返回所有层的隐藏状态。更多细节请参见返回的张量中的 `hidden_states` 字段。
return_dict (`bool`, *可选*):
是否返回一个 [`~utils.ModelOutput`] 而不是简单的元组。
"""
# NOTE(review): this docstring is empty in this copy of the file; the decoder-argument documentation
# appears to be missing — confirm and restore.
MARIAN_DECODE_INPUTS_DOCSTRING = r"""
"""
def create_sinusoidal_positions(n_pos, dim):
    """
    Build a fixed sinusoidal position table of shape ``(n_pos, dim)``.

    The angle for position ``pos`` and column ``j`` is ``pos / 10000 ** (2 * (j // 2) / dim)``. The first
    ``ceil(dim / 2)`` output columns are the sines of the even-indexed angles; the remaining columns are the
    cosines of the odd-indexed angles.

    Args:
        n_pos (int): number of positions (rows).
        dim (int): encoding dimension (columns).

    Returns:
        jnp.ndarray: the position table.
    """
    angle_rows = []
    for pos in range(n_pos):
        angle_rows.append([pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)])
    angles = np.array(angle_rows)

    half = dim // 2 + dim % 2
    table = np.zeros_like(angles)
    table[:, :half] = np.sin(angles[:, ::2])
    table[:, half:] = np.cos(angles[:, 1::2])
    return jnp.array(table)
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """
    Shift ``input_ids`` one position to the right along axis 1 and put ``decoder_start_token_id`` in column 0.

    The last column is dropped. Any ``-100`` (label-ignore marker) in the shifted result is replaced by
    ``pad_token_id``. The output has the same shape and dtype as ``input_ids``.
    """
    result = jnp.zeros_like(input_ids).at[:, 1:].set(input_ids[:, :-1]).at[:, 0].set(decoder_start_token_id)
    return jnp.where(result == -100, pad_token_id, result)
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->Marian
class FlaxMarianAttention(nn.Module):
    """
    Multi-head attention for Marian. Used for encoder self-attention, causal decoder self-attention
    (``causal=True``) and decoder cross-attention (``key_value_states`` given).

    Fix: the ``def _concatenate_to_cache`` line under ``@nn.compact`` and the ``__call__`` method were missing
    in this copy, which left a decorator with no function and a module the layers could not call. Both are
    restored to match the upstream FlaxBartAttention that this class is declared to be copied from.
    """

    config: MarianConfig
    embed_dim: int
    num_heads: int
    dropout: float = 0.0
    causal: bool = False
    bias: bool = True
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self) -> None:
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        # All four projections share one Dense configuration (output size embed_dim).
        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
        self.out_proj = dense()

        self.dropout_layer = nn.Dropout(rate=self.dropout)

        if self.causal:
            # Causal mask over the maximum sequence length; sliced to the actual lengths in __call__.
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
            )

    def _split_heads(self, hidden_states):
        # (d0, d1, embed_dim) -> (d0, d1, num_heads, head_dim); presumably (batch, seq_len, ...).
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        # (d0, d1, num_heads, head_dim) -> (d0, d1, embed_dim).
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    @nn.compact
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py
        """
        # Detect whether the cache has already been initialised.
        is_initialized = self.has_variable("cache", "cached_key")
        # Cache variables; created as zeros shaped like the incoming key/value on first use.
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # Write the new key/value slice at the current cache position.
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            # Advance the cache index by the number of newly written positions.
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # Causal mask for cached decoder self-attention: the query may only attend to key positions that
            # have already been written, not to the remaining zero-filled cache slots.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """
        Input shape: Batch x Time x Channel.

        When ``key_value_states`` is given, keys and values are projected from it (cross-attention); otherwise
        from ``hidden_states`` (self-attention). Returns ``(attn_output, attn_weights)``.
        """
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        query_states = self.q_proj(hidden_states)
        if is_cross_attention:
            key_states = self.k_proj(key_value_states)
            value_states = self.v_proj(key_value_states)
        else:
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # Prepare the causal mask, offset by the cache index during incremental decoding.
        if self.causal:
            query_length, key_length = query_states.shape[1], key_states.shape[1]
            if self.has_variable("cache", "cached_key"):
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
                )
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

        # Combine the padding mask with the causal mask where applicable.
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During fast autoregressive decoding positions are fed one at a time and keys/values are cached.
        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )

        # Convert the boolean mask into an additive bias (0 where allowed, dtype minimum where masked).
        if attention_mask is not None:
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayer with Bart->Marian
class FlaxMarianEncoderLayer(nn.Module):
    """
    One Marian encoder layer: self-attention followed by a two-layer feed-forward block.

    Each sub-block is followed by dropout, a residual add and LayerNorm (post-norm). Submodule attribute names
    and the order of dropout calls are significant (parameter names / RNG use), so keep them stable.
    """

    # Model configuration.
    config: MarianConfig
    # Computation dtype (defaults to float32).
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        # Hidden size of the layer.
        self.embed_dim = self.config.d_model
        # Self-attention over the input sequence.
        self.self_attn = FlaxMarianAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.encoder_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        # LayerNorm applied after the self-attention residual add.
        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        # Dropout applied to each sub-block's output before the residual add.
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
        # Feed-forward activation function, selected by name from the config.
        self.activation_fn = ACT2FN[self.config.activation_function]
        # Dropout applied after the feed-forward activation.
        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
        # Feed-forward up-projection to encoder_ffn_dim.
        self.fc1 = nn.Dense(
            self.config.encoder_ffn_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        # Feed-forward down-projection back to embed_dim.
        self.fc2 = nn.Dense(
            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        # LayerNorm applied after the feed-forward residual add.
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """
        Run one encoder layer.

        Args:
            hidden_states: input states of the layer.
            attention_mask: mask passed to the self-attention.
            output_attentions: if True, also return the self-attention weights.
            deterministic: passed to this layer's nn.Dropout calls.

        Returns:
            ``(hidden_states,)`` or ``(hidden_states, attn_weights)``.
        """
        residual = hidden_states
        # NOTE(review): `deterministic` is not forwarded to self_attn, so attention-probability dropout depends on
        # FlaxMarianAttention's own default — confirm this is intended.
        hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Feed-forward block: fc1 -> activation -> dropout -> fc2 -> dropout, then residual + LayerNorm.
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs
# Adapted from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayerCollection with Bart->Marian.
# Difference: a LayerDrop-skipped layer passes hidden_states through unchanged instead of replacing them with None.
class FlaxMarianEncoderLayerCollection(nn.Module):
    """Stack of ``config.encoder_layers`` FlaxMarianEncoderLayer modules with optional LayerDrop."""

    config: MarianConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layers = [
            FlaxMarianEncoderLayer(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.encoder_layers)
        ]
        # Probability of skipping each layer when not deterministic.
        self.layerdrop = self.config.encoder_layerdrop

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """
        Run all encoder layers in order.

        Returns:
            ``FlaxBaseModelOutput`` when ``return_dict`` is True. Otherwise a tuple of the non-None values among
            ``(last_hidden_state, all_hidden_states, all_attentions)``.
        """
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for encoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            # LayerDrop (see https://arxiv.org/abs/1909.11556). NOTE(review): Python's `random` runs at trace
            # time under jax.jit, so the skip decision may be fixed in a compiled function — confirm.
            dropout_probability = random.uniform(0, 1)
            if not deterministic and (dropout_probability < self.layerdrop):
                # A skipped layer acts as identity. Previously (None, None) replaced hidden_states with None and
                # broke the next layer / final output.
                layer_outputs = (hidden_states, None)
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    output_attentions,
                    deterministic,
                )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states, all_hidden_states, all_attentions)

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayer with Bart->Marian
class FlaxMarianDecoderLayer(nn.Module):
    """
    One Marian decoder layer: causal self-attention, optional cross-attention over encoder states, and a
    two-layer feed-forward block.

    Each sub-block is followed by dropout, a residual add and LayerNorm (post-norm). Submodule attribute names
    and the order of dropout calls are significant (parameter names / RNG use), so keep them stable.
    """

    # Model configuration.
    config: MarianConfig
    # Computation dtype (defaults to float32).
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        # Hidden size of the layer.
        self.embed_dim = self.config.d_model
        # Causal self-attention over the decoder sequence.
        self.self_attn = FlaxMarianAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            causal=True,
            dtype=self.dtype,
        )
        # Dropout applied to each sub-block's output before the residual add.
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
        # Feed-forward activation function, selected by name from the config.
        self.activation_fn = ACT2FN[self.config.activation_function]
        # Dropout applied after the feed-forward activation.
        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)

        # LayerNorm applied after the self-attention residual add.
        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        # Cross-attention: queries from the decoder, keys/values from encoder_hidden_states.
        self.encoder_attn = FlaxMarianAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        # LayerNorm applied after the cross-attention residual add.
        self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        # Feed-forward up-projection: output size decoder_ffn_dim.
        self.fc1 = nn.Dense(
            self.config.decoder_ffn_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        # Feed-forward down-projection back to embed_dim.
        self.fc2 = nn.Dense(
            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        # LayerNorm applied after the feed-forward residual add.
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    def __call__(
        self,
        hidden_states: jnp.ndarray,  # decoder input states
        attention_mask: jnp.ndarray,  # mask for the causal self-attention
        encoder_hidden_states: Optional[jnp.ndarray] = None,  # encoder output; None skips cross-attention
        encoder_attention_mask: Optional[jnp.ndarray] = None,  # mask for the cross-attention
        init_cache: bool = False,  # forwarded to self_attn (presumably initialises the decoding cache)
        output_attentions: bool = True,  # also return self- and cross-attention weights
        deterministic: bool = True,  # passed to this layer's nn.Dropout calls
    ) -> Tuple[jnp.ndarray]:
        """
        Run one decoder layer.

        Returns:
            ``(hidden_states,)`` or ``(hidden_states, self_attn_weights, cross_attn_weights)``;
            ``cross_attn_weights`` is None when ``encoder_hidden_states`` is None.
        """
        residual = hidden_states

        # Self Attention
        # NOTE(review): `deterministic` is not forwarded to the attention modules, so attention-probability
        # dropout depends on FlaxMarianAttention's own default — confirm this is intended.
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
        )
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block (only when encoder states are provided)
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states

            hidden_states, cross_attn_weights = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
            hidden_states = residual + hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # Fully Connected: fc1 -> activation -> dropout -> fc2 -> dropout, then residual + LayerNorm.
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        return outputs
# Adapted from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayerCollection with Bart->Marian.
# Difference: a LayerDrop-skipped layer passes hidden_states through unchanged instead of replacing them with None.
class FlaxMarianDecoderLayerCollection(nn.Module):
    """Stack of ``config.decoder_layers`` FlaxMarianDecoderLayer modules with optional LayerDrop."""

    config: MarianConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layers = [
            FlaxMarianDecoderLayer(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.decoder_layers)
        ]
        # Probability of skipping each layer when not deterministic.
        self.layerdrop = self.config.decoder_layerdrop

    def __call__(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """
        Run all decoder layers in order.

        Returns:
            ``FlaxBaseModelOutputWithPastAndCrossAttentions`` when ``return_dict`` is True. Otherwise a tuple of
            the non-None values among
            ``(last_hidden_state, all_hidden_states, all_self_attns, all_cross_attentions)``.
            Cross-attentions are collected only when ``encoder_hidden_states`` is given.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            # LayerDrop (see https://arxiv.org/abs/1909.11556). NOTE(review): Python's `random` runs at trace
            # time under jax.jit, so the skip decision may be fixed in a compiled function — confirm.
            dropout_probability = random.uniform(0, 1)
            if not deterministic and (dropout_probability < self.layerdrop):
                # A skipped layer acts as identity. Previously (None, None, None) replaced hidden_states with None
                # and broke the next layer / final output.
                layer_outputs = (hidden_states, None, None)
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    init_cache=init_cache,
                    output_attentions=output_attentions,
                    deterministic=deterministic,
                )

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # Hidden states after the last decoder layer.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
# Marian encoder: scaled token embeddings plus fixed sinusoidal positions, followed by the encoder layer stack.
class FlaxMarianEncoder(nn.Module):
    config: MarianConfig
    embed_tokens: nn.Embed
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        d_model = self.config.d_model
        n_positions = self.config.max_position_embeddings

        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
        self.max_source_positions = n_positions
        # Token embeddings are multiplied by sqrt(d_model) only when scale_embedding is set.
        self.embed_scale = math.sqrt(d_model) if self.config.scale_embedding else 1.0
        # Fixed table stored as a plain array attribute (not a Flax param).
        self.embed_positions = create_sinusoidal_positions(n_positions, d_model)
        self.layers = FlaxMarianEncoderLayerCollection(self.config, self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """
        Embed ``input_ids`` (leading dims flattened to 2-D), add sinusoidal position rows selected by
        ``position_ids``, apply dropout and run the encoder layer stack.

        Returns:
            ``FlaxBaseModelOutput`` when ``return_dict`` is True, otherwise the layer stack's tuple output.
        """
        flat_ids = input_ids.reshape(-1, input_ids.shape[-1])
        token_embeds = self.embed_tokens(flat_ids) * self.embed_scale

        # The position table is not a param, so cast it to the embedding dtype explicitly.
        pos_embeds = jnp.take(self.embed_positions, position_ids, axis=0).astype(token_embeds.dtype)

        hidden_states = self.dropout_layer(token_embeds + pos_embeds, deterministic=deterministic)

        stack_out = self.layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if return_dict:
            return FlaxBaseModelOutput(
                last_hidden_state=stack_out.last_hidden_state,
                hidden_states=stack_out.hidden_states,
                attentions=stack_out.attentions,
            )
        return stack_out
class FlaxMarianDecoder(nn.Module):
    """Decoder of the Flax Marian model.

    Embeds (scaled) target tokens, adds fixed sinusoidal position rows selected by ``position_ids``,
    applies dropout and runs ``FlaxMarianDecoderLayerCollection`` with the encoder states.

    Attributes:
        config: model configuration.
        embed_tokens: token embedding module passed in by the owning module.
        dtype: dtype of the computation.
    """

    config: MarianConfig
    embed_tokens: nn.Embed
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        cfg = self.config
        d_model = cfg.d_model
        self.dropout_layer = nn.Dropout(rate=cfg.dropout)
        self.max_target_positions = cfg.max_position_embeddings
        # sqrt(d_model) when scale_embedding is enabled, otherwise a no-op factor.
        self.embed_scale = 1.0 if not cfg.scale_embedding else math.sqrt(d_model)
        self.embed_positions = create_sinusoidal_positions(cfg.max_position_embeddings, d_model)
        self.layers = FlaxMarianDecoderLayerCollection(cfg, self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Run the decoder.

        Args:
            input_ids: target token ids; reshaped to ``(-1, input_ids.shape[-1])`` before embedding.
            attention_mask: forwarded to the decoder layer stack.
            position_ids: row indices into the sinusoidal table.
            encoder_hidden_states: encoder output, forwarded to the layer stack.
            encoder_attention_mask: mask for ``encoder_hidden_states``, forwarded to the layer stack.
            init_cache: forwarded to the layer stack.
            output_attentions / output_hidden_states: forwarded to the layer stack.
            return_dict: when True wrap the result in ``FlaxBaseModelOutputWithPastAndCrossAttentions``.
            deterministic: passed to dropout and to the layer stack.

        Returns:
            ``FlaxBaseModelOutputWithPastAndCrossAttentions`` if ``return_dict`` else the raw layer output.
        """
        width = input_ids.shape[-1]
        token_embeds = self.embed_tokens(input_ids.reshape(-1, width)) * self.embed_scale

        # Cast explicitly: self.embed_positions is a plain array, not a registered parameter.
        pos_embeds = jnp.take(self.embed_positions, position_ids, axis=0).astype(token_embeds.dtype)

        hidden = self.dropout_layer(token_embeds + pos_embeds, deterministic=deterministic)
        stack_out = self.layers(
            hidden,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if return_dict:
            return FlaxBaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=stack_out.last_hidden_state,
                hidden_states=stack_out.hidden_states,
                attentions=stack_out.attentions,
                cross_attentions=stack_out.cross_attentions,
            )
        return stack_out
class FlaxMarianModule(nn.Module):
    """Flax Marian encoder-decoder without a head.

    Encoder and decoder receive the same ``self.shared`` token embedding. The decoder gets the
    encoder's first output as ``encoder_hidden_states`` and the source ``attention_mask`` as
    ``encoder_attention_mask``.
    """

    config: MarianConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        cfg = self.config
        self.shared = nn.Embed(
            cfg.vocab_size,
            cfg.d_model,
            embedding_init=jax.nn.initializers.normal(cfg.init_std),
        )
        self.encoder = FlaxMarianEncoder(cfg, dtype=self.dtype, embed_tokens=self.shared)
        self.decoder = FlaxMarianDecoder(cfg, dtype=self.dtype, embed_tokens=self.shared)

    def _get_encoder_module(self):
        """Return the encoder submodule."""
        return self.encoder

    def _get_decoder_module(self):
        """Return the decoder submodule."""
        return self.decoder

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Encode the source sequence, then decode the target sequence against it.

        Returns:
            ``FlaxSeq2SeqModelOutput`` if ``return_dict``; otherwise the decoder outputs concatenated
            with the encoder outputs.
        """
        # Flags passed identically to both halves of the model.
        shared_flags = {
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
            "deterministic": deterministic,
        }

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **shared_flags,
        )
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            **shared_flags,
        )

        if return_dict:
            return FlaxSeq2SeqModelOutput(
                last_hidden_state=decoder_outputs.last_hidden_state,
                decoder_hidden_states=decoder_outputs.hidden_states,
                decoder_attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
                encoder_hidden_states=encoder_outputs.hidden_states,
                encoder_attentions=encoder_outputs.attentions,
            )
        # Tuple mode: decoder outputs first, then encoder outputs.
        return decoder_outputs + encoder_outputs
class FlaxMarianPreTrainedModel(FlaxPreTrainedModel):
    """Shared base for the Flax Marian models.

    Builds ``module_class`` from the config, initializes parameters, creates the decoder cache used for
    fast autoregressive decoding, and exposes ``encode`` / ``decode`` / ``__call__`` entry points that
    fill in default masks and position ids before calling ``self.module.apply``.
    """

    config_class = MarianConfig
    base_model_prefix: str = "model"
    # Concrete subclasses set this to the Flax module they wrap.
    module_class: nn.Module = None

    def __init__(
        self,
        config: MarianConfig,
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        """Initialize parameters by tracing the module on dummy inputs.

        Args:
            rng: PRNG key, split into the ``params`` and ``dropout`` streams.
            input_shape: ``(batch_size, sequence_length)`` of the dummy ``input_ids``.
            params: optional existing parameters; only keys listed in ``self._missing_keys`` are filled
                from the fresh random initialization.

        Returns:
            The parameter tree.
        """
        # Dummy inputs: zeros with an EOS token in the last position.
        input_ids = jnp.zeros(input_shape, dtype="i4")
        input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
        attention_mask = jnp.ones_like(input_ids)
        decoder_input_ids = input_ids
        decoder_attention_mask = jnp.ones_like(input_ids)

        batch_size, sequence_length = input_ids.shape
        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
        decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
            position_ids,
            decoder_position_ids,
        )["params"]

        if params is not None:
            # Fill only the keys that were missing from the loaded checkpoint.
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def init_cache(self, batch_size, max_length, encoder_outputs):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
            encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
                `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
                `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
                is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
                cross-attention of the decoder.
        """
        # Dummy decoder inputs of the full cache length, used only to create the cache variables.
        decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
        )

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            return decoder_module(decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs)

        init_variables = self.module.init(
            jax.random.PRNGKey(0),
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            init_cache=True,
            method=_decoder_forward,  # we only need to call the decoder to init the cache
        )
        return unfreeze(init_variables["cache"])

    @add_start_docstrings(MARIAN_ENCODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=MarianConfig)
    def encode(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, FlaxMarianMTModel

        >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
        >>> model = FlaxMarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-de")

        >>> text = "My friends are cool but they eat too many carbs."
        >>> inputs = tokenizer(text, max_length=64, return_tensors="jax")
        >>> encoder_outputs = model.encode(**inputs)
        ```"""
        # Flags left as None fall back to the config values.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
            encode_module = module._get_encoder_module()
            return encode_module(input_ids, attention_mask, position_ids, **kwargs)

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            method=_encoder_forward,
        )

    @add_start_docstrings(MARIAN_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=MarianConfig)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```python
        >>> import jax.numpy as jnp
        >>> from transformers import AutoTokenizer, FlaxMarianModel

        >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
        >>> model = FlaxMarianModel.from_pretrained("Helsinki-NLP/opus-mt-en-de")

        >>> text = "My friends are cool but they eat too many carbs."
        >>> inputs = tokenizer(text, max_length=64, return_tensors="jax")
        >>> encoder_outputs = model.encode(**inputs)

        >>> decoder_start_token_id = model.config.decoder_start_token_id
        >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

        >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
        >>> last_decoder_hidden_states = outputs.last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            # With a cache the absolute positions cannot be inferred from the (short) input.
            if past_key_values is not None:
                raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # A passed-in cache must be marked mutable so the decoder can update it during this call.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            return decoder_module(decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs)

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # With a mutable cache, apply() returns (outputs, updated_variables); surface the new cache.
        if past_key_values is not None and return_dict:
            outputs, past = outputs
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past = outputs
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Full encoder-decoder forward pass.

        Missing masks default to all ones, missing position ids to ``0..seq_len-1``, and missing
        ``decoder_input_ids`` to ``input_ids`` shifted right with ``config.decoder_start_token_id``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Prepare encoder inputs.
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # Prepare decoder inputs.
        if decoder_input_ids is None:
            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
            )
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        if decoder_position_ids is None:
            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )
@add_start_docstrings(
    "The bare Marian Model transformer outputting raw hidden-states without any specific head on top.",
    MARIAN_START_DOCSTRING,
)
class FlaxMarianModel(FlaxMarianPreTrainedModel):
    # Bare encoder-decoder: all behavior comes from FlaxMarianPreTrainedModel, wrapping FlaxMarianModule.
    config: MarianConfig
    dtype: jnp.dtype = jnp.float32
    module_class = FlaxMarianModule


# Presumably appends a generated usage sample (checkpoint _CHECKPOINT_FOR_DOC, output type
# FlaxSeq2SeqModelOutput) to FlaxMarianModel's call docstring; the helper is defined elsewhere.
append_call_sample_docstring(FlaxMarianModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
class FlaxMarianMTModule(nn.Module):
    """Flax Marian encoder-decoder with a language-modeling head.

    Projects the decoder's last hidden state to vocabulary logits, either with its own ``lm_head``
    kernel or, when ``config.tie_word_embeddings`` is set, with the transposed shared embedding, and
    adds a ``final_logits_bias`` parameter.
    """

    config: MarianConfig
    dtype: jnp.dtype = jnp.float32
    bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        self.model = FlaxMarianModule(config=self.config, dtype=self.dtype)
        vocab_size = self.model.shared.num_embeddings
        self.lm_head = nn.Dense(
            vocab_size,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        # Parameter of shape (1, vocab_size) added to every logit vector.
        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, vocab_size))

    def _get_encoder_module(self):
        """Return the encoder of the wrapped model."""
        return self.model.encoder

    def _get_decoder_module(self):
        """Return the decoder of the wrapped model."""
        return self.model.decoder

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Run the seq2seq model and compute vocabulary logits.

        Returns:
            ``FlaxSeq2SeqLMOutput`` if ``return_dict``; otherwise ``(lm_logits,) + outputs[1:]``.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )
        decoder_last_hidden = outputs[0]

        if self.config.tie_word_embeddings:
            # Reuse the shared embedding matrix (transposed) as the output projection kernel.
            tied_kernel = self.model.variables["params"]["shared"]["embedding"].T
            lm_logits = self.lm_head.apply({"params": {"kernel": tied_kernel}}, decoder_last_hidden)
        else:
            lm_logits = self.lm_head(decoder_last_hidden)
        lm_logits = lm_logits + self.final_logits_bias.astype(self.dtype)

        if not return_dict:
            return (lm_logits,) + outputs[1:]

        return FlaxSeq2SeqLMOutput(
            logits=lm_logits,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
@add_start_docstrings(
    "The MARIAN Model with a language modeling head. Can be used for translation.", MARIAN_START_DOCSTRING
)
class FlaxMarianMTModel(FlaxMarianPreTrainedModel):
    """Marian translation model (Flax): FlaxMarianMTModule plus generation helpers."""

    module_class = FlaxMarianMTModule
    dtype: jnp.dtype = jnp.float32

    @add_start_docstrings(MARIAN_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=MarianConfig)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```python
        >>> import jax.numpy as jnp
        >>> from transformers import AutoTokenizer, FlaxMarianMTModel

        >>> model = FlaxMarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-de")
        >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")

        >>> text = "My friends are cool but they eat too many carbs."
        >>> inputs = tokenizer(text, max_length=64, return_tensors="jax")
        >>> encoder_outputs = model.encode(**inputs)

        >>> decoder_start_token_id = model.config.decoder_start_token_id
        >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

        >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
        >>> logits = outputs.logits
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            # With a cache the absolute positions cannot be inferred from the (short) input.
            if past_key_values is not None:
                raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # A passed-in cache must be marked mutable so the decoder can update it during this call.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            outputs = decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                **kwargs,
            )
            hidden_states = outputs[0]

            # Same logits computation as FlaxMarianMTModule.__call__.
            if self.config.tie_word_embeddings:
                shared_embedding = module.model.variables["params"]["shared"]["embedding"]
                lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
            else:
                lm_logits = module.lm_head(hidden_states)
            lm_logits += module.final_logits_bias.astype(self.dtype)

            return lm_logits, outputs

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # With a mutable cache, apply() returns (method_output, updated_variables).
        if past_key_values is None:
            lm_logits, decoder_outputs = outputs
        else:
            (lm_logits, decoder_outputs), past = outputs

        if return_dict:
            outputs = FlaxCausalLMOutputWithCrossAttentions(
                logits=lm_logits,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
            )
        else:
            outputs = (lm_logits,) + decoder_outputs[1:]

        # Surface the updated cache to the caller.
        if past_key_values is not None and return_dict:
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs

    def _adapt_logits_for_beam_search(self, logits):
        """This function enforces the padding token never to be generated."""
        logits = logits.at[:, :, self.config.pad_token_id].set(float("-inf"))
        return logits

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jax.Array] = None,
        decoder_attention_mask: Optional[jax.Array] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        """Build the initial generation state: a fresh cache of length ``max_length`` plus masks/positions."""
        batch_size, seq_length = decoder_input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
        # One static (batch_size, max_length) mask of ones, overwritten at the start by any given mask;
        # a fixed shape avoids recompilation during generation.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
            "decoder_position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """Carry the updated cache forward and advance the decoder position by one."""
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
        return model_kwargs
# Translation example appended to FlaxMarianMTModel's call documentation below.
FLAX_MARIAN_MT_DOCSTRING = """
Returns:
Example:
```
>>> from transformers import AutoTokenizer, FlaxMarianMTModel
>>> model = FlaxMarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-de")
>>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
>>> text = "My friends are cool but they eat too many carbs."
>>> input_ids = tokenizer(text, max_length=64, return_tensors="jax").input_ids
>>> sequences = model.generate(input_ids, max_length=64, num_beams=2).sequences
>>> outputs = tokenizer.batch_decode(sequences, skip_special_tokens=True)
>>> # should give *Meine Freunde sind cool, aber sie essen zu viele Kohlenhydrate.*
```
"""
# Presumably replaces FlaxMarianMTModel's call docstring with the inputs documentation followed by the
# example above; helper defined elsewhere.
overwrite_call_docstring(
    FlaxMarianMTModel,
    MARIAN_INPUTS_DOCSTRING + FLAX_MARIAN_MT_DOCSTRING,
)
# Presumably fills the "Returns:" section with FlaxSeq2SeqLMOutput documentation; helper defined elsewhere.
append_replace_return_docstrings(FlaxMarianMTModel, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
.\models\marian\modeling_marian.py
"""从 Marian C++ 仓库移植的 PyTorch MarianMTModel 模型。"""
import copy
import math
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_end_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_marian import MarianConfig
# Module-level logger.
logger = logging.get_logger(__name__)
# Config class / checkpoint names, presumably substituted into generated documentation examples.
_CONFIG_FOR_DOC = "MarianConfig"
_CHECKPOINT_FOR_DOC = "Helsinki-NLP/opus-mt-en-de"
# Pretrained checkpoints listed for this architecture.
MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "Helsinki-NLP/opus-mt-en-de",
]
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Build decoder inputs by shifting ``input_ids`` one position to the right.

    Column 0 becomes ``decoder_start_token_id`` and the last input column is dropped. Any ``-100``
    remaining in the shifted result is replaced by ``pad_token_id``. The input tensor is not modified.

    Raises:
        ValueError: if ``pad_token_id`` is None.
    """
    shifted = torch.zeros_like(input_ids)
    shifted[:, 1:] = input_ids[:, :-1]
    shifted[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id 必须被定义。")
    # -100 is the label "ignore" value; it must not reach the embedding lookup.
    shifted[shifted == -100] = pad_token_id
    return shifted
class MarianSinusoidalPositionalEmbedding(nn.Embedding):
    """Fixed (non-trainable) sinusoidal positional embeddings.

    ``padding_idx`` is accepted for signature compatibility but is not used.
    """

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
        super().__init__(num_positions, embedding_dim)
        self.weight = self._init_weight(self.weight)

    @staticmethod
    def _init_weight(out: nn.Parameter) -> nn.Parameter:
        """
        Fill ``out`` in place with sinusoidal features, like XLM's create_sinusoidal_embeddings but without
        interleaving: sin features occupy the first ``ceil(dim / 2)`` columns, cos features the rest.

        The angle for position ``pos`` and column ``j`` is ``pos / 10000 ** (2 * (j // 2) / dim)``.
        The table is built with numpy broadcasting instead of an O(n_pos * dim) Python double loop.
        This also handles ``n_pos == 0``, which made the old list-based ``position_enc`` 1-D.
        """
        n_pos, dim = out.shape
        positions = np.arange(n_pos, dtype=np.float64)[:, None]
        exponents = 2 * (np.arange(dim) // 2) / dim
        position_enc = positions / np.power(10000.0, exponents)[None, :]
        out.requires_grad = False  # set early to avoid an error in pytorch-1.8+
        sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
        out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
        out.detach_()
        return out

    @torch.no_grad()
    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
        """
        Return the embeddings for positions ``past_key_values_length .. past_key_values_length + seq_len - 1``.

        `input_ids_shape` is expected to be [bsz x seqlen]; only its second entry is used.
        """
        bsz, seq_len = input_ids_shape[:2]
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
        )
        return super().forward(positions)
class MarianAttention(nn.Module):
    """Multi-head attention ("Attention Is All You Need") with separate k/v/q/out linear projections."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[MarianConfig] = None,
    ):
        super().__init__()
        per_head = embed_dim // num_heads
        self.embed_dim, self.num_heads, self.head_dim = embed_dim, num_heads, per_head
        self.dropout = dropout
        self.config = config

        if per_head * num_heads != embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # head_dim ** -0.5; presumably applied to queries in forward (body not shown in this copy).
        self.scaling = per_head**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        # Creation order (k, v, q, out) is kept so parameter initialization draws from the RNG identically.
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """View ``tensor`` as (bsz, seq_len, num_heads, head_dim) and return it as a contiguous
        (bsz, num_heads, seq_len, head_dim) tensor."""
        per_head_view = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head_view.transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        # NOTE(review): the attention computation is elided ("...") in this copy of the file, so this
        # returns None, while MarianEncoderLayer unpacks three values from it. Restore from upstream.
        ...
class MarianEncoderLayer(nn.Module):
    """Single Marian encoder layer.

    Self-attention, then a feed-forward network (fc1 -> activation -> fc2). Each sub-block is followed by
    dropout, a residual add and a LayerNorm (post-norm).
    """

    def __init__(self, config: MarianConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        layer_head_mask: torch.FloatTensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            `(hidden_states,)`, plus the self-attention weights when `output_attentions` is True.
        """
        # Self-attention sub-block (post-norm).
        residual = hidden_states
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Feed-forward sub-block (post-norm).
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        # fp16 guard: clamp inf/nan overflow to just inside the representable range.
        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs
# Maps config._attn_implementation to the attention class; only the eager implementation is registered here.
MARIAN_ATTENTION_CLASSES = {"eager": MarianAttention}
class MarianDecoderLayer(nn.Module):
    """Marian decoder layer.

    Holds a causal self-attention module, a second attention module ``encoder_attn``, a feed-forward
    network (``fc1``/``fc2``) and three LayerNorms (``self_attn_layer_norm``,
    ``encoder_attn_layer_norm``, ``final_layer_norm``).
    """

    def __init__(self, config: MarianConfig):
        super().__init__()
        attention_cls = MARIAN_ATTENTION_CLASSES[config._attn_implementation]
        d_model = config.d_model
        heads = config.decoder_attention_heads
        attn_dropout = config.attention_dropout

        self.embed_dim = d_model
        # Causal self-attention.
        self.self_attn = attention_cls(
            embed_dim=d_model,
            num_heads=heads,
            dropout=attn_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.self_attn_layer_norm = nn.LayerNorm(d_model)
        # Presumably cross-attention over encoder_hidden_states (forward body not present in this copy).
        self.encoder_attn = attention_cls(
            d_model,
            heads,
            dropout=attn_dropout,
            is_decoder=True,
            config=config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(d_model)
        self.fc1 = nn.Linear(d_model, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, d_model)
        self.final_layer_norm = nn.LayerNorm(d_model)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ):
        # NOTE(review): the layer computation is missing in this copy of the file (body is just `pass`),
        # so this returns None. Restore the implementation from upstream.
        pass
class MarianPreTrainedModel(PreTrainedModel):
    """Base class for the Marian PyTorch models: config class, weight initialization and dummy inputs."""

    config_class = MarianConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True

    def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalPositionalEmbedding]):
        """Initialize ``module`` in place.

        Linear weights and ordinary embeddings are drawn from N(0, config.init_std). Linear biases and the
        embedding padding row are zeroed. Sinusoidal position tables are left untouched.
        """
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return
        # Must be tested before nn.Embedding, which it subclasses; its fixed table is never re-drawn.
        if isinstance(module, MarianSinusoidalPositionalEmbedding):
            return
        if isinstance(module, nn.Embedding):
            embedding_weight = module.weight.data
            embedding_weight.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                embedding_weight[module.padding_idx].zero_()

    @property
    def dummy_inputs(self):
        """Two short fixed sequences (the second ending in the pad token) with their attention mask."""
        pad = self.config.pad_token_id
        sample_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad]], device=self.device)
        return {
            "attention_mask": sample_ids.ne(pad),
            "input_ids": sample_ids,
            "decoder_input_ids": sample_ids,
        }
# Shared class-level documentation text for the Marian PyTorch models.
MARIAN_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`MarianConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Translation usage example (fr -> en) for the Marian PyTorch models.
MARIAN_GENERATION_EXAMPLE = r"""
Pytorch version of marian-nmt's transformer.h (c++). Designed for the OPUS-NMT translation checkpoints. Available
models are listed [here](https://huggingface.co/models?search=Helsinki-NLP).
Examples:
```
>>> from transformers import AutoTokenizer, MarianMTModel
>>> src = "fr" # source language
>>> trg = "en" # target language
>>> model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}"
>>> model = MarianMTModel.from_pretrained(model_name)
>>> tokenizer = AutoTokenizer.from_pretrained(model_name)
>>> sample_text = "où est l'arrêt de bus ?"
>>> batch = tokenizer([sample_text], return_tensors="pt")
>>> generated_ids = model.generate(**batch)
>>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
"Where's the bus stop?"
```
"""
# NOTE(review): this inputs docstring is empty in this copy of the file; the per-argument documentation
# appears to be missing and should be restored.
MARIAN_INPUTS_DOCSTRING = r"""
"""
class MarianEncoder(MarianPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`MarianEncoderLayer`].

    Args:
        config: MarianConfig
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
        embed_tokens (nn.Embedding, *optional*):
            Token embedding layer to use. When None, a new `nn.Embedding(config.vocab_size, config.d_model,
            config.pad_token_id)` is created.
    """

    def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop
        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_position_embeddings
        # sqrt(d_model) when config.scale_embedding is set, else 1.0.
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
        # Fixed sinusoidal table (padding_idx is accepted but unused by that class).
        self.embed_positions = MarianSinusoidalPositionalEmbedding(
            config.max_position_embeddings, embed_dim, self.padding_idx
        )
        self.layers = nn.ModuleList([MarianEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.gradient_checkpointing = False
        # Presumably PreTrainedModel's final initialization hook (defined elsewhere).
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding layer."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding layer."""
        self.embed_tokens = value

    # NOTE(review): in this copy of the file the forward() signature below is truncated (no closing
    # parenthesis and no body before the next class); the implementation must be restored from upstream.
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
class MarianDecoder(MarianPreTrainedModel):
"""
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MarianDecoderLayer`]
Args:
config: MarianConfig
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = None):
    """
    Build the decoder stack from ``config``.

    Args:
        config (`MarianConfig`): model configuration.
        embed_tokens (`nn.Embedding`, *optional*): token embedding to reuse. When omitted, a fresh
            `nn.Embedding(config.decoder_vocab_size, config.d_model, config.pad_token_id)` is created.
    """
    super().__init__(config)

    d_model = config.d_model
    self.dropout = config.dropout
    self.layerdrop = config.decoder_layerdrop
    self.padding_idx = config.pad_token_id
    self.max_target_positions = config.max_position_embeddings
    self.embed_scale = math.sqrt(d_model) if config.scale_embedding else 1.0

    # Module creation order (token embedding -> positions -> layers) fixes parameter registration
    # and random-initialisation order.
    if embed_tokens is None:
        embed_tokens = nn.Embedding(config.decoder_vocab_size, d_model, self.padding_idx)
    self.embed_tokens = embed_tokens
    self.embed_positions = MarianSinusoidalPositionalEmbedding(
        config.max_position_embeddings, d_model, self.padding_idx
    )
    self.layers = nn.ModuleList(MarianDecoderLayer(config) for _ in range(config.decoder_layers))
    self.gradient_checkpointing = False

    self.post_init()
def get_input_embeddings(self):
    """Return the decoder's token-embedding module (``self.embed_tokens``)."""
    embeddings = self.embed_tokens
    return embeddings


def set_input_embeddings(self, value):
    """Replace the decoder's token-embedding module with ``value``."""
    self.embed_tokens = value
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.LongTensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    cross_attn_head_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    """
    Forward pass for the MarianDecoder module.

    NOTE(review): only this docstring is present in this listing. The implementation body is missing, so as
    written the method returns None. The entries below describe the intended parameters.

    Args:
        input_ids (torch.LongTensor): Input token IDs
        attention_mask (torch.Tensor): Attention mask for masking out padded tokens
        encoder_hidden_states (torch.FloatTensor): Hidden states from the encoder
        encoder_attention_mask (torch.LongTensor): Attention mask for the encoder's hidden states
        head_mask (torch.Tensor): Mask for heads in the self-attention layers
        cross_attn_head_mask (torch.Tensor): Mask for heads in the cross-attention layers
        past_key_values (Tuple[Tuple[torch.FloatTensor]]): Cached key-value pairs for fast decoding
        inputs_embeds (torch.FloatTensor): Optional tensor of already-embedded inputs (alternative to `input_ids`)
        use_cache (bool): Whether to return cached key-value pairs
        output_attentions (bool): Whether to output attentions
        output_hidden_states (bool): Whether to output hidden states
        return_dict (bool): Whether to return a model-output object instead of a plain tuple
    Returns:
        Various outputs depending on the configuration (return_dict or not)
    """
def __init__(self, config: MarianConfig):
    """
    Build the encoder-decoder model.

    A single `nn.Embedding(config.vocab_size, config.d_model, config.pad_token_id)` is always created as
    ``self.shared``. It is either shared by encoder and decoder, or deep-copied into two independent
    embeddings, in which case ``self.shared`` is reset to None.
    """
    super().__init__(config)
    pad_idx = config.pad_token_id
    self.shared = nn.Embedding(config.vocab_size, config.d_model, pad_idx)

    if not self.config.share_encoder_decoder_embeddings:
        # Independent copies: encoder and decoder embeddings start identical but are trained separately.
        enc_embeddings = copy.deepcopy(self.shared)
        dec_embeddings = copy.deepcopy(self.shared)
        self.shared = None
    else:
        enc_embeddings = dec_embeddings = self.shared

    self.encoder = MarianEncoder(config, enc_embeddings)
    self.decoder = MarianDecoder(config, dec_embeddings)

    self.post_init()
def get_input_embeddings(self):
    """Return the encoder's input embedding module."""
    encoder = self.get_encoder()
    return encoder.get_input_embeddings()


def set_input_embeddings(self, value):
    """
    Install ``value`` as the input embedding.

    With `config.share_encoder_decoder_embeddings`, ``value`` becomes ``self.shared`` and is wired into both
    encoder and decoder. Otherwise only the encoder embedding is replaced.
    """
    if not self.config.share_encoder_decoder_embeddings:
        self.encoder.embed_tokens = value
        return
    self.shared = value
    self.encoder.embed_tokens = self.shared
    self.decoder.embed_tokens = self.shared
def get_decoder_input_embeddings(self):
    """
    Return the decoder's own input embedding.

    Raises:
        ValueError: if encoder and decoder share embeddings (use `get_input_embeddings` then).
    """
    if not self.config.share_encoder_decoder_embeddings:
        return self.get_decoder().get_input_embeddings()
    raise ValueError(
        "`get_decoder_input_embeddings` should not be called if `config.share_encoder_decoder_embeddings` "
        "is `True`. Please use `get_input_embeddings` instead."
    )


def set_decoder_input_embeddings(self, value):
    """
    Replace the decoder's own input embedding with ``value``.

    Raises:
        ValueError: if encoder and decoder share embeddings (use `set_input_embeddings` then).
    """
    if not self.config.share_encoder_decoder_embeddings:
        self.decoder.embed_tokens = value
        return
    raise ValueError(
        "`config.share_encoder_decoder_embeddings` is set to `True` meaning the decoder input embeddings "
        "are shared with the encoder. In order to set the decoder input embeddings, you should simply set "
        "the encoder input embeddings by calling `set_input_embeddings` with the appropriate embeddings."
    )
def get_encoder(self):
    """Return the encoder sub-module (``self.encoder``)."""
    encoder = self.encoder
    return encoder


def get_decoder(self):
    """Return the decoder sub-module (``self.decoder``)."""
    decoder = self.decoder
    return decoder
@add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Union[Tuple[torch.Tensor], BaseModelOutput]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@add_start_docstrings(
    "The Marian Model with a language modeling head. Can be used for summarization.", MARIAN_START_DOCSTRING
)
class MarianMTModel(MarianPreTrainedModel):
    # Marian encoder-decoder (`self.model`) topped with a bias-free linear LM head (`self.lm_head`) plus an
    # additive `final_logits_bias` buffer sized to the target vocabulary.
    base_model_prefix = "model"
    _keys_to_ignore_on_load_missing = [
        "final_logits_bias",
        "encoder.embed_positions.weight",
        "decoder.embed_positions.weight",
    ]
    _keys_to_ignore_on_save = ["model.encoder.embed_positions.weight", "model.decoder.embed_positions.weight"]
    _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"]

    def __init__(self, config: MarianConfig):
        super().__init__(config)
        self.model = MarianModel(config)

        # With shared embeddings the decoder predicts over the joint vocabulary (`vocab_size`); otherwise over the
        # separate decoder vocabulary (`decoder_vocab_size`).
        target_vocab_size = config.vocab_size if config.share_encoder_decoder_embeddings else config.decoder_vocab_size
        self.register_buffer("final_logits_bias", torch.zeros((1, target_vocab_size)))
        self.lm_head = nn.Linear(config.d_model, target_vocab_size, bias=False)

        self.post_init()

    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()

    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
        """
        Resize the input embeddings and, when embeddings are shared, `final_logits_bias`.

        The bias is resized to the *actual* number of rows of the resized embedding matrix. That number can differ
        from `new_num_tokens`:

        - when `pad_to_multiple_of` rounds the size;
        - when `new_num_tokens` is None (no resize).

        Passing `new_num_tokens` straight through would leave the bias shorter than the logits in the first case,
        and raise `TypeError` (`None <= int`) in the second.
        """
        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        if self.config.share_encoder_decoder_embeddings:
            self._resize_final_logits_bias(new_embeddings.weight.shape[0])
        return new_embeddings

    def _resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of=None) -> nn.Embedding:
        old_embeddings = self.get_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
        self.set_input_embeddings(new_embeddings)

        # The resized matrix may be larger than requested (padding), so use its real size from here on.
        new_num_tokens = new_embeddings.weight.shape[0]
        # Shared embeddings mean the decoder vocabulary follows the encoder vocabulary.
        if self.config.share_encoder_decoder_embeddings:
            self.config.decoder_vocab_size = new_num_tokens

        # An untied LM head is not updated by `tie_weights`, so resize it explicitly.
        if (
            self.config.share_encoder_decoder_embeddings
            and self.get_output_embeddings() is not None
            and not self.config.tie_word_embeddings
        ):
            old_lm_head = self.get_output_embeddings()
            new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
            self.set_output_embeddings(new_lm_head)

        return self.get_input_embeddings()

    def resize_decoder_token_embeddings(self, new_num_tokens):
        """
        Resize the decoder's own embeddings (only valid when embeddings are NOT shared).

        Also resizes an untied LM head, re-ties weights, and resizes `final_logits_bias`. With
        `new_num_tokens=None` the current decoder embedding is returned and config and bias are left untouched.

        Raises:
            ValueError: if `config.share_encoder_decoder_embeddings` is True.
        """
        if self.config.share_encoder_decoder_embeddings:
            raise ValueError(
                "`resize_decoder_token_embeddings` should not be called if `config.share_encoder_decoder_embeddings` "
                "is `True`. Please use `resize_token_embeddings` instead."
            )

        old_embeddings = self.model.get_decoder_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.model.set_decoder_input_embeddings(new_embeddings)

        if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
            old_lm_head = self.get_output_embeddings()
            new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
            self.set_output_embeddings(new_lm_head)

        model_embeds = self.model.get_decoder_input_embeddings()

        if new_num_tokens is None:
            return model_embeds

        self.config.decoder_vocab_size = new_num_tokens
        self.tie_weights()
        self._resize_final_logits_bias(new_num_tokens)

        return model_embeds

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        """Truncate or zero-pad `final_logits_bias` to `new_num_tokens` entries, keeping its device and dtype."""
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            # Match the existing buffer's dtype so a half-precision bias is not silently promoted to float32.
            extra_bias = torch.zeros(
                (1, new_num_tokens - old_num_tokens),
                device=self.final_logits_bias.device,
                dtype=self.final_logits_bias.dtype,
            )
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Embedding):
        self.lm_head = new_embeddings

    def tie_weights(self):
        """
        Tie the weights between the input embeddings and the output embeddings.

        If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
        weights instead.
        """
        # The LM head is tied to the *decoder* input embeddings (which may differ from the encoder's).
        output_embeddings = self.get_output_embeddings()
        if output_embeddings is not None and getattr(self.config, "tie_word_embeddings", True):
            word_embeddings = self.get_decoder().get_input_embeddings()
            self._tie_or_clone_weights(output_embeddings, word_embeddings)

        # When encoder/decoder weights are tied, the work (and the `_tie_weights` sweep below) happens on the
        # base model rather than on this wrapper.
        base_model = self
        if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
            if hasattr(self, self.base_model_prefix):
                base_model = getattr(self, self.base_model_prefix)
            base_model._tie_encoder_decoder_weights(
                base_model.encoder, base_model.decoder, base_model.base_model_prefix
            )

        for module in base_model.modules():
            if hasattr(module, "_tie_weights"):
                module._tie_weights()

    @add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    @add_end_docstrings(MARIAN_GENERATION_EXAMPLE)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Union[Tuple[torch.Tensor], BaseModelOutput]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Seq2SeqLMOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        Returns:
            `Seq2SeqLMOutput`: A class representing the outputs of the Seq2Seq language model.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            # Training with labels: caching is useless and is turned off.
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False
            # Teacher forcing: derive decoder inputs from the labels when none were supplied.
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Flatten with the logits' own vocabulary dimension. It always matches `lm_head`, whereas
            # `config.decoder_vocab_size` can disagree with it: with shared embeddings the head is sized from
            # `config.vocab_size`. A mismatch would make `view` fail or silently mis-group rows.
            masked_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids: torch.LongTensor,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        encoder_outputs: Optional[Union[Tuple[torch.Tensor], BaseModelOutput]] = None,
        **kwargs,
    ) -> Dict:
        """
        Prepare inputs for one generation step.

        Args:
            decoder_input_ids: Input IDs for the decoder.
            past_key_values: Tuple of cached key and value tensors.
            attention_mask: Mask to avoid attention on padding tokens.
            head_mask: Mask to nullify selected heads of the attention modules.
            decoder_head_mask: Mask to nullify selected heads of the decoder self-attention modules.
            cross_attn_head_mask: Mask to nullify selected heads of the cross-attention modules.
            use_cache: Flag to control whether to use caching.
            encoder_outputs: Output tensors from the encoder.

        Returns:
            Dictionary of keyword arguments for `forward`. `input_ids` is None because the encoder output is
            passed directly.
        """
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]
            # Drop the already-cached prefix. If the caller already passed fewer ids than are cached,
            # keep only the last one.
            if decoder_input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                remove_prefix_length = decoder_input_ids.shape[1] - 1
            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

        return {
            "input_ids": None,
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """
        Shift labels to the right to prepare inputs for the decoder.

        Args:
            labels: Tensor of labels.

        Returns:
            Tensor of shifted labels.
        """
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """
        Reorder cached key and value tensors along the batch dimension to follow `beam_idx`.

        Only the first two entries of each layer's cache are re-indexed. Any further entries are kept as-is
        (presumably the cross-attention cache, which is identical across beams; confirm).

        Args:
            past_key_values: Tuple (one entry per layer) of cached tensors.
            beam_idx: Tensor containing the indices to reorder with.

        Returns:
            Reordered tuple of cached tensors.
        """
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
                + layer_past[2:],
            )
        return reordered_past
class MarianDecoderWrapper(MarianPreTrainedModel):
    """
    Helper wrapper around `MarianDecoder`. It lets pretrained checkpoints load correctly when the causal
    language model is combined with the `EncoderDecoderModel` framework.
    """

    def __init__(self, config):
        super().__init__(config)
        self.decoder = MarianDecoder(config)

    def forward(self, *args, **kwargs):
        # Pure pass-through: every positional and keyword argument goes to the wrapped decoder unchanged.
        decoder = self.decoder
        return decoder(*args, **kwargs)
class MarianForCausalLM(MarianPreTrainedModel):
    # Decoder-only Marian model with a bias-free LM head.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        # Work on a copy so that flipping the decoder flags does not mutate the caller's config object.
        config = copy.deepcopy(config)
        config.is_decoder = True
        config.is_encoder_decoder = False
        super().__init__(config)
        self.model = MarianDecoderWrapper(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model.decoder = decoder

    def get_decoder(self):
        return self.model.decoder

    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Returns:
        """
        # The `Returns:` line above is the placeholder that `replace_return_docstrings` looks for and replaces.
        # NOTE(review): the implementation is elided (`...`) in this listing; as written this returns None.
        ...

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
    ):
        """
        Build the keyword arguments for one generation step.

        Defaults `attention_mask` to all ones. When a cache is present, drops the already-cached prefix from
        `input_ids`; if the caller passed fewer ids than are cached, only the last id is kept.
        """
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

        if past_key_values:
            past_length = past_key_values[0][0].shape[2]
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                remove_prefix_length = input_ids.shape[1] - 1
            input_ids = input_ids[:, remove_prefix_length:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Re-index every cached tensor of every layer along dim 0 to follow `beam_idx`."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
# .\models\marian\modeling_tf_marian.py
""" TF 2.0 Marian model."""
from __future__ import annotations
import random
from typing import Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPastAndCrossAttentions,
TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput,
)
from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFPreTrainedModel,
keras,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
add_code_sample_docstrings,
add_end_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_marian import MarianConfig
logger = logging.get_logger(__name__)

# Documentation constants (names suggest they feed the docstring helpers imported above; confirm usage).
_CHECKPOINT_FOR_DOC = "Helsinki-NLP/opus-mt-en-de"
_CONFIG_FOR_DOC = "MarianConfig"

# Additive value written into attention masks for positions that must not be attended to
# (used by `_make_causal_mask` and `_expand_mask`).
LARGE_NEGATIVE = -1e8
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift `input_ids` one position to the right and prepend `decoder_start_token_id`.

    Any `-100` (ignored-label marker) in the shifted result is replaced by `pad_token_id`. A
    `tf.debugging` assertion checks that no negative ids remain.
    """
    dtype = input_ids.dtype
    pad_id = tf.cast(pad_token_id, dtype)
    start_id = tf.cast(decoder_start_token_id, dtype)

    batch_size = shape_list(input_ids)[0]
    start_column = tf.fill((batch_size, 1), tf.convert_to_tensor(start_id, dtype))
    shifted = tf.concat([start_column, input_ids[:, :-1]], -1)

    pad_fill = tf.fill(shape_list(shifted), tf.convert_to_tensor(pad_id, dtype))
    shifted = tf.where(shifted == -100, pad_fill, shifted)

    # Attach the check to the returned tensor via a control dependency so it is not pruned.
    non_negative = tf.debugging.assert_greater_equal(shifted, tf.constant(0, dtype=dtype))
    with tf.control_dependencies([non_negative]):
        shifted = tf.identity(shifted)

    return shifted
def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
    """
    Make the additive causal (uni-directional) mask used for decoder self-attention.

    Row `i` may attend to columns `j <= i`. Those entries are 0.0; all later positions get `LARGE_NEGATIVE`.
    With `past_key_values_length > 0`, that many always-visible zero columns are prepended. The result is
    tiled to `(bsz, 1, tgt_len, past_key_values_length + tgt_len)`.
    """
    batch_size, tgt_len = input_ids_shape[0], input_ids_shape[1]

    mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
    size = shape_list(mask)[-1]
    positions = tf.range(size)
    # Broadcast (size,) against (size, 1): entry [i, j] is True exactly when j < i + 1.
    visible = positions < tf.reshape(positions + 1, (size, 1))
    mask = tf.where(visible, 0.0, mask)

    if past_key_values_length > 0:
        cached_columns = tf.zeros((tgt_len, past_key_values_length))
        mask = tf.concat([cached_columns, mask], axis=-1)

    return tf.tile(mask[None, None, :, :], (batch_size, 1, 1, 1))
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.

    Ones (keep) become 0.0 and zeros (masked) become `LARGE_NEGATIVE`, ready to be added to attention scores.
    `tgt_len` defaults to the source length.
    """
    src_len = shape_list(mask)[1]
    if tgt_len is None:
        tgt_len = src_len

    one = tf.constant(1.0)
    float_mask = tf.cast(mask, dtype=one.dtype)
    expanded = tf.tile(float_mask[:, None, None, :], (1, 1, tgt_len, 1))
    inverted = one - expanded
    return inverted * LARGE_NEGATIVE
class TFMarianSinusoidalPositionalEmbedding(keras.layers.Layer):
    """Produces sinusoidal positional embeddings for positions `0 .. num_positions - 1`."""

    def __init__(self, num_positions: int, embedding_dim: int, **kwargs):
        """
        Args:
            num_positions: number of rows (positions) in the embedding table.
            embedding_dim: width of each embedding; must be even.

        Raises:
            NotImplementedError: if `embedding_dim` is odd.
        """
        super().__init__(**kwargs)
        # Half of the features are sines and half cosines, so the width must split evenly.
        if embedding_dim % 2 != 0:
            raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
        self.embedding_dim = embedding_dim
        self.num_positions = num_positions

    def build(self, input_shape: tf.TensorShape):
        """
        Build shared token embedding layer Shared weights logic adapted from
        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        weight = self._init_weight(self.num_positions, self.embedding_dim)

        # NOTE(review): `add_weight` is called without `trainable=False`, so (under Keras defaults) the
        # sinusoid-initialised table is a trainable variable. Confirm whether this is intended.
        self.weight = self.add_weight(
            name="embeddings",
            shape=[self.num_positions, self.embedding_dim],
        )
        # The table is computed in float64 (numpy); cast it to the variable's dtype before copying it in.
        weight = tf.cast(weight, dtype=self.weight.dtype)

        self.weight.assign(weight)

        super().build(input_shape)

    @staticmethod
    def _init_weight(n_pos: int, dim: int):
        """
        Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
        the 2nd half of the vector. [dim // 2:]
        """
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        table = np.zeros_like(position_enc)
        # First half: sin of the even-indexed angles; second half: cos of the odd-indexed angles.
        table[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
        table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
        table = tf.convert_to_tensor(table)
        # `tf.stop_gradient` returns a new tensor; keep the result, otherwise the call has no effect at all.
        table = tf.stop_gradient(table)
        return table

    def call(
        self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: tf.Tensor | None = None
    ):
        """
        Look up positional embeddings.

        Args:
            input_shape: shape of the input ids, `[bsz, seq_len]`. Only `input_shape[1]` is read, and only when
                `position_ids` is None.
            past_key_values_length: offset added to the generated positions (presumably the number of
                already-cached positions).
            position_ids: explicit positions to gather; when given, `input_shape` and the offset are ignored.

        Returns:
            Rows of `self.weight` gathered at `position_ids`.
        """
        if position_ids is None:
            seq_len = input_shape[1]
            position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
        return tf.gather(self.weight, position_ids)
class TFMarianAttention(keras.layers.Layer):
    """Multi-headed attention as described in the paper "Attention Is All You Need"."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        **kwargs,
    ):
        """
        Args:
            embed_dim: model width, split evenly across `num_heads` heads.
            num_heads: number of attention heads; must divide `embed_dim`.
            dropout: rate of the `keras.layers.Dropout` stored on `self.dropout`.
            is_decoder: stored on `self.is_decoder`.
            bias: whether the q/k/v/out projections use a bias term.

        Raises:
            ValueError: if `embed_dim` is not divisible by `num_heads`.
        """
        super().__init__(**kwargs)
        self.embed_dim = embed_dim

        self.num_heads = num_heads
        self.dropout = keras.layers.Dropout(dropout)
        self.head_dim = embed_dim // num_heads
        # Integer division silently drops a remainder, so check that the split was exact.
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # 1 / sqrt(head_dim) scaling factor.
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        # Four same-width projections: key, query, value and output.
        self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
        self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
        self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
        self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")

    def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
        # Split heads: reshape to (bsz, seq_len, num_heads, head_dim), then transpose to
        # (bsz, num_heads, seq_len, head_dim). The input is presumably (bsz, seq_len, embed_dim).
        return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))

    def call(
        self,
        hidden_states: tf.Tensor,
        key_value_states: tf.Tensor | None = None,
        past_key_value: Tuple[Tuple[tf.Tensor]] | None = None,
        attention_mask: tf.Tensor | None = None,
        layer_head_mask: tf.Tensor | None = None,
        training: Optional[bool] = False,
    ):
        # NOTE(review): the implementation is elided (`...`) in this listing; as written this returns None.
        ...

    def build(self, input_shape=None):
        # Build only once; each projection is built for inputs whose last dimension is `embed_dim`,
        # so `input_shape` itself is not used.
        if self.built:
            return
        self.built = True
        if getattr(self, "k_proj", None) is not None:
            with tf.name_scope(self.k_proj.name):
                self.k_proj.build([None, None, self.embed_dim])
        if getattr(self, "q_proj", None) is not None:
            with tf.name_scope(self.q_proj.name):
                self.q_proj.build([None, None, self.embed_dim])
        if getattr(self, "v_proj", None) is not None:
            with tf.name_scope(self.v_proj.name):
                self.v_proj.build([None, None, self.embed_dim])
        if getattr(self, "out_proj", None) is not None:
            with tf.name_scope(self.out_proj.name):
                self.out_proj.build([None, None, self.embed_dim])
class TFMarianEncoderLayer(keras.layers.Layer):
    """
    One Marian encoder block with two sub-blocks: self-attention, then a two-layer feed-forward network.

    Each sub-block is followed by dropout, a residual add and a LayerNormalization.
    """
    ...

    def __init__(self, config: MarianConfig, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = config.d_model
        self.self_attn = TFMarianAttention(
            self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn"
        )
        self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
        self.dropout = keras.layers.Dropout(config.dropout)
        self.activation_fn = get_tf_activation(config.activation_function)
        self.activation_dropout = keras.layers.Dropout(config.activation_dropout)
        # Feed-forward network: embed_dim -> encoder_ffn_dim -> embed_dim.
        self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
        self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
        self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
        # Kept so that `build` can read `encoder_ffn_dim`.
        self.config = config

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: np.ndarray | tf.Tensor | None,
        layer_head_mask: tf.Tensor | None,
        training: Optional[bool] = False,
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Args:
            hidden_states (`tf.Tensor`): input to the layer, of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`tf.Tensor`): attention mask of shape `(batch, 1, tgt_len, src_len)`, where padding
                elements are indicated by very large negative values.
            layer_head_mask (`tf.Tensor`): mask for the attention heads of this layer, of shape
                `(encoder_attention_heads,)`.
            training (`Optional[bool]`, optional): whether the layer runs in training mode (passed to the dropout
                layers). Defaults to False.

        Returns:
            `Tuple[tf.Tensor, tf.Tensor]`: the transformed hidden states, asserted to have the same shape as the
            input, and the self-attention weights returned by `self.self_attn`.
        """
        residual = hidden_states
        # NOTE(review): `training` is not forwarded to `self_attn` explicitly; presumably Keras propagates it
        # through the call context. Confirm.
        hidden_states, self_attn_weights, _ = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
        )

        tf.debugging.assert_equal(
            shape_list(hidden_states),
            shape_list(residual),
            message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
        )

        # Residual add, then normalise (LayerNorm after the residual).
        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Feed-forward sub-block with the same residual-then-normalise pattern.
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = self.activation_dropout(hidden_states, training=training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states, self_attn_weights

    def build(self, input_shape=None):
        # Build only once. Sub-layers are built with explicit shapes inside their own name scopes,
        # so `input_shape` itself is not used.
        if self.built:
            return
        self.built = True
        if getattr(self, "self_attn", None) is not None:
            with tf.name_scope(self.self_attn.name):
                # TFMarianAttention.build ignores its input shape, so None is sufficient.
                self.self_attn.build(None)
        if getattr(self, "self_attn_layer_norm", None) is not None:
            with tf.name_scope(self.self_attn_layer_norm.name):
                self.self_attn_layer_norm.build([None, None, self.embed_dim])
        if getattr(self, "fc1", None) is not None:
            with tf.name_scope(self.fc1.name):
                self.fc1.build([None, None, self.embed_dim])
        if getattr(self, "fc2", None) is not None:
            with tf.name_scope(self.fc2.name):
                # fc2 consumes fc1's output, whose last dimension is encoder_ffn_dim.
                self.fc2.build([None, None, self.config.encoder_ffn_dim])
        if getattr(self, "final_layer_norm", None) is not None:
            with tf.name_scope(self.final_layer_norm.name):
                self.final_layer_norm.build([None, None, self.embed_dim])
class TFMarianDecoderLayer(keras.layers.Layer):
def __init__(self, config: MarianConfig, **kwargs):
    """
    Create the sub-layers of one Marian decoder block.

    The sub-layers are:

    - `self_attn` and `encoder_attn`, both `TFMarianAttention` with `is_decoder=True`, each with its own
      LayerNormalization;
    - a two-layer feed-forward network: `fc1` (-> `decoder_ffn_dim`) and `fc2` (-> `d_model`);
    - `final_layer_norm`.

    Dropout rates and the activation come from `config`.

    NOTE(review): `encoder_attn` is presumably the cross-attention over the encoder hidden states. This layer's
    `call` is not fully visible here, so confirm.
    """
    super().__init__(**kwargs)
    self.embed_dim = config.d_model
    self.self_attn = TFMarianAttention(
        embed_dim=self.embed_dim,
        num_heads=config.decoder_attention_heads,
        dropout=config.attention_dropout,
        name="self_attn",
        is_decoder=True,
    )
    self.dropout = keras.layers.Dropout(config.dropout)
    self.activation_fn = get_tf_activation(config.activation_function)
    self.activation_dropout = keras.layers.Dropout(config.activation_dropout)

    self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
    self.encoder_attn = TFMarianAttention(
        self.embed_dim,
        config.decoder_attention_heads,
        dropout=config.attention_dropout,
        name="encoder_attn",
        is_decoder=True,
    )
    self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")
    self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
    self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
    self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
    # Kept so that `build` can read `decoder_ffn_dim`.
    self.config = config
def call(
self,
hidden_states: tf.Tensor,
attention_mask: np.ndarray | tf.Tensor | None = None,
encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
layer_head_mask: tf.Tensor | None = None,
cross_attn_layer_head_mask: tf.Tensor | None = None,
past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
training: Optional[bool] = False,
def build(self, input_shape=None):
    """
    Build every sub-layer once, each inside its own name scope, using explicit shapes.

    `input_shape` is not used.
    """
    if self.built:
        return
    self.built = True
    if getattr(self, "self_attn", None) is not None:
        with tf.name_scope(self.self_attn.name):
            # TFMarianAttention.build ignores its input shape, so None is sufficient.
            self.self_attn.build(None)
    if getattr(self, "self_attn_layer_norm", None) is not None:
        with tf.name_scope(self.self_attn_layer_norm.name):
            self.self_attn_layer_norm.build([None, None, self.embed_dim])
    if getattr(self, "encoder_attn", None) is not None:
        with tf.name_scope(self.encoder_attn.name):
            self.encoder_attn.build(None)
    if getattr(self, "encoder_attn_layer_norm", None) is not None:
        with tf.name_scope(self.encoder_attn_layer_norm.name):
            self.encoder_attn_layer_norm.build([None, None, self.embed_dim])
    if getattr(self, "fc1", None) is not None:
        with tf.name_scope(self.fc1.name):
            self.fc1.build([None, None, self.embed_dim])
    if getattr(self, "fc2", None) is not None:
        with tf.name_scope(self.fc2.name):
            # fc2 consumes fc1's output, whose last dimension is decoder_ffn_dim.
            self.fc2.build([None, None, self.config.decoder_ffn_dim])
    if getattr(self, "final_layer_norm", None) is not None:
        with tf.name_scope(self.final_layer_norm.name):
            self.final_layer_norm.build([None, None, self.embed_dim])
class TFMarianPreTrainedModel(TFPreTrainedModel):
    """Common base for the TF Marian models.

    Binds the models to `MarianConfig` and declares `"model"` as the attribute/prefix
    under which the base model lives.
    """

    base_model_prefix = "model"
    config_class = MarianConfig
# Class-level documentation shared by the TF Marian models; it is passed to
# `add_start_docstrings` for `TFMarianMTModel`. The text is runtime content (it ends up
# in generated docstrings), so edit it with care.
MARIAN_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
behavior.
<Tip>
TensorFlow models and layers in `transformers` accept two formats as input:
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional argument.
The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
positional argument:
- a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
`model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
Note that when creating models and layers with
[subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
about any of this, as you can just pass inputs like you would to any other Python function!
</Tip>
Args:
config ([`MarianConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Usage example appended to `TFMarianMTModel.call`'s docstring via `add_end_docstrings`.
# `tokenizer.batch_decode` returns a list of strings, so the shown output is a list.
MARIAN_GENERATION_EXAMPLE = r"""
TF version of marian-nmt's transformer.h (c++). Designed for the OPUS-NMT translation checkpoints. Available
models are listed [here](https://huggingface.co/models?search=Helsinki-NLP).
Examples:
```
>>> from transformers import AutoTokenizer, TFMarianMTModel
>>> src = "fr" # source language
>>> trg = "en" # target language
>>> sample_text = "où est l'arrêt de bus ?"
>>> model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}"
>>> model = TFMarianMTModel.from_pretrained(model_name)
>>> tokenizer = AutoTokenizer.from_pretrained(model_name)
>>> batch = tokenizer([sample_text], return_tensors="tf")
>>> gen = model.generate(**batch)
>>> tokenizer.batch_decode(gen, skip_special_tokens=True)
['Where is the bus stop ?']
```
"""
# Per-argument documentation injected into `TFMarianModel.call` and `TFMarianMTModel.call`
# through `add_start_docstrings_to_model_forward`.
# NOTE(review): the string is empty in this copy, so those methods get no argument docs and
# the `.format("batch_size, sequence_length")` applied to it in `TFMarianModel` is a no-op.
# The original argument documentation appears to be missing; TODO restore.
MARIAN_INPUTS_DOCSTRING = r"""
"""
@keras_serializable
class TFMarianEncoder(keras.layers.Layer):
config_class = MarianConfig
"""
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
[`TFMarianEncoderLayer`].
Args:
config: MarianConfig
"""
def __init__(self, config: MarianConfig, embed_tokens: Optional[keras.layers.Embedding] = None, **kwargs):
    """Set up the encoder's sub-layers from ``config``.

    Args:
        config: model configuration; supplies dropout, layerdrop, padding id, maximum
            positions, ``d_model``, ``scale_embedding`` and the number of encoder layers.
        embed_tokens: optional token embedding layer, stored as-is (may be None).
        **kwargs: forwarded to the Keras ``Layer`` constructor.
    """
    super().__init__(**kwargs)
    self.config = config
    self.dropout = keras.layers.Dropout(config.dropout)
    self.layerdrop = config.encoder_layerdrop
    self.padding_idx = config.pad_token_id
    self.max_source_positions = config.max_position_embeddings
    # Embeddings are multiplied by sqrt(d_model) only when the config asks for it.
    if config.scale_embedding:
        self.embed_scale = tf.math.sqrt(float(config.d_model))
    else:
        self.embed_scale = 1.0
    self.embed_tokens = embed_tokens
    self.embed_positions = TFMarianSinusoidalPositionalEmbedding(
        config.max_position_embeddings,
        config.d_model,
        name="embed_positions",
    )
    # One encoder layer per `config.encoder_layers`, named "layers.0", "layers.1", ...
    encoder_stack = []
    for layer_idx in range(config.encoder_layers):
        encoder_stack.append(TFMarianEncoderLayer(config, name=f"layers.{layer_idx}"))
    self.layers = encoder_stack
def get_embed_tokens(self):
    """Return the token embedding layer currently attached to the encoder (may be None)."""
    current = self.embed_tokens
    return current
def set_embed_tokens(self, embed_tokens):
    """Attach ``embed_tokens`` as the encoder's token embedding, replacing the previous one."""
    setattr(self, "embed_tokens", embed_tokens)
@unpack_inputs
def call(
self,
input_ids: tf.Tensor | None = None,
inputs_embeds: tf.Tensor | None = None,
attention_mask: tf.Tensor | None = None,
head_mask: tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
**kwargs,
):
def build(self, input_shape=None):
    """Build the positional embedding and every encoder layer, once.

    Each sub-layer is built without a shape inside a name scope matching its own
    name. ``input_shape`` is accepted for the Keras interface but unused.
    """
    if not self.built:
        self.built = True
        positions = getattr(self, "embed_positions", None)
        if positions is not None:
            with tf.name_scope(positions.name):
                positions.build(None)
        # A missing `layers` attribute behaves like an empty stack.
        for sublayer in getattr(self, "layers", None) or ():
            with tf.name_scope(sublayer.name):
                sublayer.build(None)
@keras_serializable
class TFMarianDecoder(keras.layers.Layer):
config_class = MarianConfig
"""
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFMarianDecoderLayer`]
Args:
config: MarianConfig # configuration object of type MarianConfig
embed_tokens: output embedding # the output embedding tokens
"""
def __init__(self, config: MarianConfig, embed_tokens: Optional[keras.layers.Embedding] = None, **kwargs):
    """Set up the decoder's sub-layers from ``config``.

    Args:
        config: model configuration; supplies the padding id, decoder layerdrop, maximum
            positions, ``d_model``, ``scale_embedding``, number of decoder layers and dropout.
        embed_tokens: optional token embedding layer, stored as-is (may be None).
        **kwargs: forwarded to the Keras ``Layer`` constructor.
    """
    super().__init__(**kwargs)
    self.config = config
    self.padding_idx = config.pad_token_id
    self.embed_tokens = embed_tokens
    self.layerdrop = config.decoder_layerdrop
    self.embed_positions = TFMarianSinusoidalPositionalEmbedding(
        config.max_position_embeddings,
        config.d_model,
        name="embed_positions",
    )
    # Embeddings are multiplied by sqrt(d_model) only when the config asks for it.
    if config.scale_embedding:
        self.embed_scale = tf.math.sqrt(float(config.d_model))
    else:
        self.embed_scale = 1.0
    # One decoder layer per `config.decoder_layers`, named "layers.0", "layers.1", ...
    decoder_stack = []
    for layer_idx in range(config.decoder_layers):
        decoder_stack.append(TFMarianDecoderLayer(config, name=f"layers.{layer_idx}"))
    self.layers = decoder_stack
    self.dropout = keras.layers.Dropout(config.dropout)
def get_embed_tokens(self):
    """Return the token embedding layer currently attached to the decoder (may be None)."""
    current = self.embed_tokens
    return current
def set_embed_tokens(self, embed_tokens):
    """Attach ``embed_tokens`` as the decoder's token embedding, replacing the previous one."""
    setattr(self, "embed_tokens", embed_tokens)
@unpack_inputs
def call(
    self,
    input_ids: tf.Tensor | None = None,
    inputs_embeds: tf.Tensor | None = None,
    attention_mask: tf.Tensor | None = None,
    position_ids: tf.Tensor | None = None,
    encoder_hidden_states: tf.Tensor | None = None,
    encoder_attention_mask: tf.Tensor | None = None,
    head_mask: tf.Tensor | None = None,
    cross_attn_head_mask: tf.Tensor | None = None,
    past_key_values: Tuple[Tuple[tf.Tensor]] | None = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    training: bool = False,
):
    """Decoder forward pass.

    NOTE(review): the implementation is elided (`...`) in this copy, so as written the
    method returns None for every input. Restore the real body before relying on it.

    The parameters match what `TFMarianMainLayer.call` passes to `self.decoder`:
    decoder ids (positionally, as ``input_ids``) or ``inputs_embeds``, the decoder
    attention mask, position ids, encoder hidden states and mask, head masks, the cache
    (``past_key_values`` / ``use_cache``), output flags and ``training``.
    """
    ...
def build(self, input_shape=None):
    """Build the positional embedding and every decoder layer, once.

    Each sub-layer is built without a shape inside a name scope matching its own
    name. ``input_shape`` is accepted for the Keras interface but unused.
    """
    if not self.built:
        self.built = True
        positions = getattr(self, "embed_positions", None)
        if positions is not None:
            with tf.name_scope(positions.name):
                positions.build(None)
        # A missing `layers` attribute behaves like an empty stack.
        for sublayer in getattr(self, "layers", None) or ():
            with tf.name_scope(sublayer.name):
                sublayer.build(None)
@keras_serializable
class TFMarianMainLayer(keras.layers.Layer):
config_class = MarianConfig
def __init__(self, config: MarianConfig, **kwargs):
    """Create the shared token embedding and the encoder/decoder that both receive it.

    Args:
        config: model configuration; ``vocab_size`` and ``d_model`` size the shared
            embedding and ``init_std`` is the stddev of its truncated-normal initializer.
        **kwargs: forwarded to the Keras ``Layer`` constructor.
    """
    super().__init__(**kwargs)
    self.config = config
    embedding_init = keras.initializers.TruncatedNormal(stddev=self.config.init_std)
    self.shared = keras.layers.Embedding(
        input_dim=config.vocab_size,
        output_dim=config.d_model,
        embeddings_initializer=embedding_init,
        name="model.shared",
    )
    # Explicit prefix; `build` opens a name scope derived from it before building `shared`.
    self.shared.load_weight_prefix = "model.shared"
    # The same embedding object is handed to both halves of the model.
    self.encoder = TFMarianEncoder(config, self.shared, name="encoder")
    self.decoder = TFMarianDecoder(config, self.shared, name="decoder")
def get_input_embeddings(self):
    """Return the shared token embedding layer."""
    embeddings = self.shared
    return embeddings
def set_input_embeddings(self, new_embeddings):
    """Install ``new_embeddings`` as the shared embedding and point encoder and decoder at it."""
    self.shared = new_embeddings
    for half in (self.encoder, self.decoder):
        half.embed_tokens = self.shared
@unpack_inputs
def call(
    self,
    input_ids: tf.Tensor | None = None,
    attention_mask: tf.Tensor | None = None,
    decoder_input_ids: tf.Tensor | None = None,
    decoder_attention_mask: tf.Tensor | None = None,
    decoder_position_ids: tf.Tensor | None = None,
    head_mask: tf.Tensor | None = None,
    decoder_head_mask: tf.Tensor | None = None,
    cross_attn_head_mask: tf.Tensor | None = None,
    encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
    past_key_values: Tuple[Tuple[tf.Tensor]] = None,
    inputs_embeds: tf.Tensor | None = None,
    decoder_inputs_embeds: tf.Tensor | None = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    training: bool = False,
    **kwargs,
):
    """Run the encoder (unless ``encoder_outputs`` is given) and then the decoder.

    The decoder attends to ``encoder_outputs[0]`` using ``attention_mask`` as its
    cross-attention mask.

    Returns:
        When ``return_dict`` is falsy, the tuple ``decoder_outputs + encoder_outputs``.
        Otherwise, a ``TFSeq2SeqModelOutput`` combining the decoder's last hidden state,
        cache, hidden states and attentions with the encoder's last hidden state, hidden
        states and attentions.
    """
    # Without any decoder inputs, caching is switched off.
    if decoder_input_ids is None and decoder_inputs_embeds is None:
        use_cache = False
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    if encoder_outputs is None:
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
    elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput):
        # Caller passed precomputed encoder results as a tuple: wrap them so attribute
        # access below works.
        encoder_outputs = TFBaseModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )
    elif not return_dict and not isinstance(encoder_outputs, tuple):
        # Caller passed an output object but wants tuples: flatten it so it can be concatenated.
        encoder_outputs = encoder_outputs.to_tuple()
    decoder_outputs = self.decoder(
        decoder_input_ids,
        attention_mask=decoder_attention_mask,
        position_ids=decoder_position_ids,
        encoder_hidden_states=encoder_outputs[0],
        encoder_attention_mask=attention_mask,
        head_mask=decoder_head_mask,
        cross_attn_head_mask=cross_attn_head_mask,
        past_key_values=past_key_values,
        inputs_embeds=decoder_inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        training=training,
    )
    if not return_dict:
        return decoder_outputs + encoder_outputs
    return TFSeq2SeqModelOutput(
        last_hidden_state=decoder_outputs.last_hidden_state,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
    )
def build(self, input_shape=None):
    """Build the shared embedding, the encoder and the decoder, once.

    The shared embedding is built under a name scope made of its
    ``load_weight_prefix`` and name; encoder and decoder are built under their own
    names. ``input_shape`` is accepted for the Keras interface but unused.
    """
    if self.built:
        return
    self.built = True
    with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"):
        self.shared.build(None)
    if getattr(self, "encoder", None) is not None:
        with tf.name_scope(self.encoder.name):
            self.encoder.build(None)
    if getattr(self, "decoder", None) is not None:
        with tf.name_scope(self.decoder.name):
            self.decoder.build(None)
class TFMarianModel(TFMarianPreTrainedModel):
    """Marian encoder-decoder without a language-modeling head.

    All computation is delegated to a `TFMarianMainLayer` stored as ``self.model``.
    """

    def __init__(self, config: MarianConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.model = TFMarianMainLayer(config, name="model")

    def get_encoder(self):
        """Return the encoder of the wrapped main layer."""
        return self.model.encoder

    def get_decoder(self):
        """Return the decoder of the wrapped main layer."""
        return self.model.decoder

    @unpack_inputs
    @add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFSeq2SeqModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: tf.Tensor | None = None,
        attention_mask: tf.Tensor | None = None,
        decoder_input_ids: tf.Tensor | None = None,
        decoder_attention_mask: tf.Tensor | None = None,
        decoder_position_ids: tf.Tensor | None = None,
        head_mask: tf.Tensor | None = None,
        decoder_head_mask: tf.Tensor | None = None,
        cross_attn_head_mask: tf.Tensor | None = None,
        encoder_outputs: tf.Tensor | None = None,
        past_key_values: Tuple[Tuple[tf.Tensor]] | None = None,
        inputs_embeds: tf.Tensor | None = None,
        decoder_inputs_embeds: tf.Tensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        training: bool = False,
        **kwargs,
    ) -> Tuple[tf.Tensor] | TFSeq2SeqModelOutput:
        # Pure delegation: every named argument is forwarded unchanged to the main layer
        # (extra **kwargs are not forwarded), and its result is returned as-is.
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

    def serving_output(self, output):
        """Repackage ``output`` for serving.

        Tuple-valued fields are converted to tensors, and each optional field is kept
        only when the matching config flag (``use_cache``, ``output_hidden_states``,
        ``output_attentions``) is set; otherwise it is replaced by None.
        """
        cfg = self.config
        past = tf.tuple(output.past_key_values)[1] if cfg.use_cache else None
        decoder_hidden = tf.convert_to_tensor(output.decoder_hidden_states) if cfg.output_hidden_states else None
        decoder_attn = tf.convert_to_tensor(output.decoder_attentions) if cfg.output_attentions else None
        cross_attn = tf.convert_to_tensor(output.cross_attentions) if cfg.output_attentions else None
        encoder_hidden = tf.convert_to_tensor(output.encoder_hidden_states) if cfg.output_hidden_states else None
        encoder_attn = tf.convert_to_tensor(output.encoder_attentions) if cfg.output_attentions else None
        return TFSeq2SeqModelOutput(
            last_hidden_state=output.last_hidden_state,
            past_key_values=past,
            decoder_hidden_states=decoder_hidden,
            decoder_attentions=decoder_attn,
            cross_attentions=cross_attn,
            encoder_last_hidden_state=output.encoder_last_hidden_state,
            encoder_hidden_states=encoder_hidden,
            encoder_attentions=encoder_attn,
        )

    def build(self, input_shape=None):
        """Build the wrapped main layer once, under its own name scope."""
        if self.built:
            return
        self.built = True
        main_layer = getattr(self, "model", None)
        if main_layer is None:
            return
        with tf.name_scope(main_layer.name):
            main_layer.build(None)
class BiasLayer(keras.layers.Layer):
    """
    A layer that owns a single bias weight and adds it to its input.

    The bias gets its own layer for serialization: `keras.Model.save_weights` stores weights per layer, so every
    weight must be registered in some layer.
    """

    def __init__(self, shape, initializer, trainable, name, **kwargs):
        super().__init__(name=name, **kwargs)
        # The weight reuses the layer's own name.
        self.bias = self.add_weight(
            name=name,
            shape=shape,
            initializer=initializer,
            trainable=trainable,
        )

    def call(self, x):
        """Return ``x`` plus the stored bias."""
        shifted = x + self.bias
        return shifted
@add_start_docstrings(
    "The MARIAN Model with a language modeling head. Can be used for summarization.",
    MARIAN_START_DOCSTRING,
)
class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
    # Checkpoint keys tolerated as "unexpected" when loading. Presumably they are per-module
    # copies of the embedding; in this class encoder and decoder share `model.shared`.
    _keys_to_ignore_on_load_unexpected = [
        r"model.encoder.embed_tokens.weight",
        r"model.decoder.embed_tokens.weight",
    ]

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.model = TFMarianMainLayer(config, name="model")
        self.use_cache = config.use_cache
        # Non-trainable, zero-initialised bias added to the logits. It lives in its own layer
        # so that it is registered for weight saving.
        self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
        )

    def get_decoder(self):
        """Return the decoder of the wrapped main layer."""
        return self.model.decoder

    def get_encoder(self):
        """Return the encoder of the wrapped main layer."""
        return self.model.encoder

    def get_output_embeddings(self):
        """Return the output embedding, which is the input embedding (the LM head reuses it)."""
        return self.get_input_embeddings()

    def set_output_embeddings(self, value):
        """Set the output embedding by replacing the (shared) input embedding."""
        self.set_input_embeddings(value)

    def get_bias(self):
        """Return the logits bias variable keyed by ``"final_logits_bias"``."""
        return {"final_logits_bias": self.bias_layer.bias}

    def set_bias(self, value):
        """Replace the logits bias with ``value["final_logits_bias"]``.

        A fresh `BiasLayer` sized from the last dimension of the given bias is created,
        and the given values are copied into it.
        """
        vocab_size = value["final_logits_bias"].shape[-1]
        self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False
        )
        self.bias_layer.bias.assign(value["final_logits_bias"])

    @unpack_inputs
    @add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    @add_end_docstrings(MARIAN_GENERATION_EXAMPLE)
    def call(
        self,
        input_ids: tf.Tensor | None = None,
        attention_mask: tf.Tensor | None = None,
        decoder_input_ids: tf.Tensor | None = None,
        decoder_attention_mask: tf.Tensor | None = None,
        decoder_position_ids: tf.Tensor | None = None,
        head_mask: tf.Tensor | None = None,
        decoder_head_mask: tf.Tensor | None = None,
        cross_attn_head_mask: tf.Tensor | None = None,
        encoder_outputs: TFBaseModelOutput | None = None,
        past_key_values: Tuple[Tuple[tf.Tensor]] | None = None,
        inputs_embeds: tf.Tensor | None = None,
        decoder_inputs_embeds: tf.Tensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        labels: tf.Tensor | None = None,
        training: bool = False,
    ) -> Tuple[tf.Tensor] | TFSeq2SeqLMOutput:
        r"""
        labels (`tf.tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
            Depending on `return_dict`, either a tuple or `TFSeq2SeqLMOutput`.
        """
        if labels is not None:
            # Padding positions become -100 so the loss ignores them.
            labels = tf.where(
                labels == self.config.pad_token_id,
                tf.fill(shape_list(labels), tf.cast(-100, labels.dtype)),
                labels,
            )
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                # No decoder inputs given: derive them from the labels.
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # LM head: project the decoder output onto the shared embedding weights, then add the bias.
        lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
        lm_logits = self.bias_layer(lm_logits)
        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
        return TFSeq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def serving_output(self, output):
        """Repackage ``output`` for serving.

        Tuple-valued fields are converted to tensors, and each optional field is kept
        only when the matching config flag (``use_cache``, ``output_hidden_states``,
        ``output_attentions``) is set; otherwise it is replaced by None. The loss is not
        carried over.
        """
        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
        dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
        cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
        enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
        return TFSeq2SeqLMOutput(
            logits=output.logits,
            past_key_values=pkv,
            decoder_hidden_states=dec_hs,
            decoder_attentions=dec_attns,
            cross_attentions=cross_attns,
            encoder_last_hidden_state=output.encoder_last_hidden_state,
            encoder_hidden_states=enc_hs,
            encoder_attentions=enc_attns,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one generation step.

        With a cache, only the last decoder token is fed. Decoder position ids come from
        the exclusive cumulative sum of ``decoder_attention_mask`` when it is given, else
        from dimension 2 of the first cached tensor (presumably the cached length — TODO
        confirm), else from ``tf.range`` over the decoder input length. ``input_ids`` is
        always None because the encoder results arrive via ``encoder_outputs``.
        """
        if past_key_values is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]
        if decoder_attention_mask is not None:
            decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
        elif past_key_values is not None:
            decoder_position_ids = past_key_values[0][0].shape[2]
        else:
            decoder_position_ids = tf.range(decoder_input_ids.shape[1])
        return {
            "input_ids": None,
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_position_ids": decoder_position_ids,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
        """Derive decoder input ids from ``labels`` with `shift_tokens_right`, as `call` does."""
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    def build(self, input_shape=None):
        """Build the main layer and the logits bias layer once, each under its own name scope."""
        if self.built:
            return
        self.built = True
        if getattr(self, "model", None) is not None:
            with tf.name_scope(self.model.name):
                self.model.build(None)
        if getattr(self, "bias_layer", None) is not None:
            with tf.name_scope(self.bias_layer.name):
                self.bias_layer.build(None)