Transformers 源码解析(十三)
.\models\bart\modeling_flax_bart.py
""" Flax Bart model."""
import math
import random
from functools import partial
from typing import Callable, Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxCausalLMOutputWithCrossAttentions,
FlaxSeq2SeqLMOutput,
FlaxSeq2SeqModelOutput,
FlaxSeq2SeqQuestionAnsweringModelOutput,
FlaxSeq2SeqSequenceClassifierOutput,
)
from ...modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_call_sample_docstring,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_bart import BartConfig
# Module-level logger obtained through the library's logging helper.
logger = logging.get_logger(__name__)
# Checkpoint / config names handed to the docstring helpers further down in this file.
_CHECKPOINT_FOR_DOC = "facebook/bart-base"
_CONFIG_FOR_DOC = "BartConfig"
# Shared introductory docstring; passed to `add_start_docstrings` on the model classes.
BART_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a Flax Linen
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
"""
# Docstring prepended to the full-model `__call__` (kept as the original near-empty placeholder).
BART_INPUTS_DOCSTRING = r"""
"""

# Docstring prepended to `encode`. The original excerpt left this text as a bare string literal,
# so the name referenced by the `encode` decorator was never bound (NameError at import time).
BART_ENCODE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# Docstring prepended to `decode` (kept as the original near-empty placeholder).
BART_DECODE_INPUTS_DOCSTRING = r"""
"""
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """
    Shift input ids one token to the right.

    Args:
        input_ids: 2-D array of token ids; indexed as `[row, position]` below.
        pad_token_id: Id written wherever the shifted array contains `-100`.
        decoder_start_token_id: Id written into position 0 of every row.

    Returns:
        An array with the same shape and dtype as `input_ids` where column 0 holds
        `decoder_start_token_id`, columns `1:` hold the original columns `:-1`, and any `-100`
        entry has been replaced by `pad_token_id`.
    """
    shifted_input_ids = jnp.zeros_like(input_ids)
    # Move every token one slot to the right; the last input token is dropped.
    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)

    # -100 entries (presumably the label ignore index - confirm with callers) become padding.
    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids
class FlaxBartAttention(nn.Module):
    """
    Multi-head attention layer used by the BART encoder and decoder layers.

    With `causal=True` a causal mask is pre-computed in `setup` and the layer can keep a key/value
    cache for step-by-step decoding. When `key_value_states` is passed to `__call__`, keys and
    values are projected from it instead of from `hidden_states` (cross-attention).
    """

    config: BartConfig
    embed_dim: int
    num_heads: int
    dropout: float = 0.0
    causal: bool = False
    bias: bool = True
    dtype: jnp.dtype = jnp.float32  # forwarded to the Dense projections

    def setup(self) -> None:
        """Create the q/k/v/out projections, the dropout layer and (if causal) the causal mask.

        Raises:
            ValueError: If `embed_dim` is not divisible by `num_heads`.
        """
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        # All four projections share the same configuration.
        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
        self.out_proj = dense()

        self.dropout_layer = nn.Dropout(rate=self.dropout)

        if self.causal:
            # Built once for the maximum length; sliced per call in `__call__`.
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
            )

    def _split_heads(self, hidden_states):
        """Reshape `(..., ..., embed_dim)` into `(..., ..., num_heads, head_dim)`, keeping the first two axes."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        """Inverse of `_split_heads`: collapse the trailing axes back into `embed_dim`."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    @nn.compact
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # Detect whether the cache already exists; the first call only creates the variables.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # Write the new key/value slices into the cache at the current index.
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # Only cache positions that have been written so far may be attended to;
            # the remaining slots still hold zeros.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """
        Compute multi-head attention.

        NOTE(review): this method was absent from the excerpt although the encoder/decoder layers call
        the module with these keyword arguments; it is restored here and should be checked against
        the upstream implementation.

        Args:
            hidden_states: Query input; the first two axes are preserved by `_split_heads`.
            key_value_states: If given, keys/values are projected from it (cross-attention).
            attention_mask: Optional mask; positions with value `> 0` are attended to.
            init_cache: If True (and the layer is causal), create the key/value cache.
            deterministic: If False and `dropout > 0`, attention dropout is applied.

        Returns:
            A tuple `(attn_output, attn_weights)`.
        """
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        query_states = self.q_proj(hidden_states)
        if is_cross_attention:
            key_states = self.k_proj(key_value_states)
            value_states = self.v_proj(key_value_states)
        else:
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # Slice the pre-computed causal mask; with a cache, the slice starts at the cache index.
        if self.causal:
            query_length, key_length = query_states.shape[1], key_states.shape[1]
            if self.has_variable("cache", "cached_key"):
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
                )
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

        # Combine the padding mask with the causal mask where both exist.
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During cached decoding keys/values are accumulated step by step.
        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )

        # Turn the mask into an additive bias: 0 where attended, dtype-min where masked.
        if attention_mask is not None:
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
class FlaxBartEncoderLayer(nn.Module):
    """
    One BART encoder layer: self-attention followed by a two-layer feed-forward block, each
    wrapped as `dropout -> residual add -> LayerNorm`.
    """

    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # forwarded to every submodule

    def setup(self) -> None:
        """Instantiate the attention, feed-forward, dropout and normalisation submodules."""
        self.embed_dim = self.config.d_model
        self.self_attn = FlaxBartAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.encoder_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
        self.activation_fn = ACT2FN[self.config.activation_function]
        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
        self.fc1 = nn.Dense(
            self.config.encoder_ffn_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.fc2 = nn.Dense(
            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """
        Apply the layer.

        Args:
            hidden_states: Input to the layer.
            attention_mask: Mask forwarded to the self-attention module.
            output_attentions: If True, the attention weights are appended to the result.
            deterministic: If True, all dropout layers are disabled.

        Returns:
            `(hidden_states,)` or `(hidden_states, attn_weights)`.
        """
        # Self-attention sub-block (post-LayerNorm).
        residual = hidden_states
        hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Feed-forward sub-block (post-LayerNorm).
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs
class FlaxBartEncoderLayerCollection(nn.Module):
    """Stack of `config.encoder_layers` encoder layers with optional LayerDrop."""

    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # forwarded to every layer

    def setup(self):
        """Create the layers (named "0", "1", ...) and read the LayerDrop probability."""
        self.layers = [
            FlaxBartEncoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.encoder_layers)
        ]
        self.layerdrop = self.config.encoder_layerdrop

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """
        Run all encoder layers in order.

        Args:
            hidden_states: Input to the first layer.
            attention_mask: Mask forwarded to every layer.
            deterministic: If False, dropout is active and layers may be skipped (LayerDrop).
            output_attentions: If True, collect each layer's attention weights.
            output_hidden_states: If True, collect the input of every layer plus the final output.
            return_dict: If False, return a tuple with the `None` entries removed.

        Returns:
            A `FlaxBaseModelOutput` or the equivalent tuple.
        """
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for encoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            # LayerDrop: the draw uses Python's `random`, so it happens when this code is executed/traced.
            dropout_probability = random.uniform(0, 1)
            if not deterministic and (dropout_probability < self.layerdrop):
                # A skipped layer passes its input through unchanged. The original wrote `None`
                # into `hidden_states` here, which broke every subsequent layer.
                layer_outputs = (hidden_states, None)
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    output_attentions,
                    deterministic,
                )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states, all_hidden_states, all_attentions)

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
class FlaxBartDecoderLayer(nn.Module):
    """
    One BART decoder layer: causal self-attention, optional cross-attention over the encoder
    output, and a two-layer feed-forward block, each wrapped as
    `dropout -> residual add -> LayerNorm`.
    """

    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # forwarded to every submodule

    def setup(self) -> None:
        """Instantiate the attention, feed-forward, dropout and normalisation submodules."""
        self.embed_dim = self.config.d_model
        self.self_attn = FlaxBartAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            causal=True,
            dtype=self.dtype,
        )
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
        self.activation_fn = ACT2FN[self.config.activation_function]
        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)

        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        self.encoder_attn = FlaxBartAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        self.fc1 = nn.Dense(
            self.config.decoder_ffn_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.fc2 = nn.Dense(
            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """
        Apply the layer.

        Args:
            hidden_states: Input to the layer.
            attention_mask: Mask forwarded to the self-attention module.
            encoder_hidden_states: If given, cross-attention is computed against it.
            encoder_attention_mask: Mask forwarded to the cross-attention module.
            init_cache: Forwarded to the self-attention module.
            output_attentions: If True, append self- and cross-attention weights to the result.
            deterministic: If True, all dropout layers are disabled.

        Returns:
            `(hidden_states,)` or `(hidden_states, self_attn_weights, cross_attn_weights)`;
            `cross_attn_weights` is `None` when no encoder states were given.
        """
        # Self-attention sub-block.
        residual = hidden_states
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
        )
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-attention sub-block (skipped when there is nothing to attend to).
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states
            hidden_states, cross_attn_weights = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
            hidden_states = residual + hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # Feed-forward sub-block.
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)
        return outputs
class FlaxBartDecoderLayerCollection(nn.Module):
    """Stack of `config.decoder_layers` decoder layers with optional LayerDrop."""

    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # forwarded to every layer

    def setup(self):
        """Create the layers (named "0", "1", ...) and read the LayerDrop probability."""
        self.layers = [
            FlaxBartDecoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.decoder_layers)
        ]
        self.layerdrop = self.config.decoder_layerdrop

    def __call__(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """
        Run all decoder layers in order.

        Args:
            hidden_states: Input to the first layer.
            attention_mask: Self-attention mask forwarded to every layer.
            encoder_hidden_states: Encoder output used for cross-attention, if any.
            encoder_attention_mask: Cross-attention mask forwarded to every layer.
            deterministic: If False, dropout is active and layers may be skipped (LayerDrop).
            init_cache: Forwarded to every layer.
            output_attentions: If True, collect self-attention (and, when encoder states are given,
                cross-attention) weights of each layer.
            output_hidden_states: If True, collect the input of every layer plus the final output.
            return_dict: If False, return a tuple with the `None` entries removed.

        Returns:
            A `FlaxBaseModelOutputWithPastAndCrossAttentions` or the equivalent tuple.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            # LayerDrop: the draw uses Python's `random`, so it happens when this code is executed/traced.
            dropout_probability = random.uniform(0, 1)
            if not deterministic and (dropout_probability < self.layerdrop):
                # A skipped layer passes its input through unchanged. The original wrote `None`
                # into `hidden_states` here, which broke every subsequent layer.
                layer_outputs = (hidden_states, None, None)
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    init_cache=init_cache,
                    output_attentions=output_attentions,
                    deterministic=deterministic,
                )

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # Add the hidden states produced by the last decoder layer.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
class FlaxBartClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    config: BartConfig
    inner_dim: int
    num_classes: int
    pooler_dropout: float
    dtype: jnp.dtype = jnp.float32  # forwarded to both Dense layers

    def setup(self):
        """Create the hidden projection, the dropout layer and the output projection."""
        self.dense = nn.Dense(
            self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        self.dropout = nn.Dropout(rate=self.pooler_dropout)
        self.out_proj = nn.Dense(
            self.num_classes,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

    def __call__(self, hidden_states: jnp.ndarray, deterministic: bool):
        """
        Compute `out_proj(dropout(tanh(dense(dropout(hidden_states)))))`.

        Args:
            hidden_states: Input features.
            deterministic: If True, dropout is disabled.

        Returns:
            The output of `out_proj` (last axis of size `num_classes`).
        """
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.dense(hidden_states)
        hidden_states = jnp.tanh(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states
class FlaxBartEncoder(nn.Module):
    """BART encoder: token + position embeddings, LayerNorm, dropout, then the encoder layer stack."""

    config: BartConfig
    embed_tokens: nn.Embed
    dtype: jnp.dtype = jnp.float32  # forwarded to every submodule

    def setup(self):
        """Create the position embedding, the embedding LayerNorm, dropout and the layer stack."""
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)

        embed_dim = self.config.d_model
        self.padding_idx = self.config.pad_token_id
        self.max_source_positions = self.config.max_position_embeddings
        self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

        # Position ids are shifted by `offset` before lookup, so the table is enlarged by the same amount.
        self.offset = 2
        self.embed_positions = nn.Embed(
            self.config.max_position_embeddings + self.offset,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
            dtype=self.dtype,
        )
        self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
        self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """
        Encode `input_ids`.

        Args:
            input_ids: Token ids; flattened to `(-1, last_dim)` before the embedding lookup.
            attention_mask: Mask forwarded to the layer stack.
            position_ids: Position indices (before the internal offset is added).
            output_attentions: Forwarded to the layer stack.
            output_hidden_states: Forwarded to the layer stack.
            return_dict: If False, the layer stack's tuple is returned as is.
            deterministic: If True, dropout is disabled.

        Returns:
            A `FlaxBaseModelOutput` or the equivalent tuple.
        """
        input_shape = input_ids.shape
        input_ids = input_ids.reshape(-1, input_shape[-1])

        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        embed_pos = self.embed_positions(position_ids + self.offset)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)

        outputs = self.layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return outputs

        return FlaxBaseModelOutput(
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class FlaxBartDecoder(nn.Module):
    """
    BART decoder: token + position embeddings, LayerNorm, dropout, then the decoder layer stack.

    The class header and field declarations were missing from the excerpt; they are restored from
    the way `FlaxBartModule.setup` constructs this module and from the attributes used below.
    """

    config: BartConfig
    embed_tokens: nn.Embed
    dtype: jnp.dtype = jnp.float32  # forwarded to every submodule

    def setup(self):
        """Create the position embedding, the embedding LayerNorm, dropout and the layer stack."""
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)

        embed_dim = self.config.d_model
        self.padding_idx = self.config.pad_token_id
        self.max_target_positions = self.config.max_position_embeddings
        self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

        # Position ids are shifted by `offset` before lookup, so the table is enlarged by the same amount.
        self.offset = 2
        self.embed_positions = nn.Embed(
            self.config.max_position_embeddings + self.offset,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
            dtype=self.dtype,
        )

        self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
        self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """
        Decode `input_ids`, optionally attending to `encoder_hidden_states`.

        Args:
            input_ids: Token ids; flattened to `(-1, last_dim)` before the embedding lookup.
            attention_mask: Self-attention mask forwarded to the layer stack.
            position_ids: Position indices (before the internal offset is added).
            encoder_hidden_states: Encoder output used for cross-attention, if any.
            encoder_attention_mask: Cross-attention mask forwarded to the layer stack.
            init_cache: Forwarded to the layer stack.
            output_attentions: Forwarded to the layer stack.
            output_hidden_states: Forwarded to the layer stack.
            return_dict: If False, the layer stack's tuple is returned as is.
            deterministic: If True, dropout is disabled.

        Returns:
            A `FlaxBaseModelOutputWithPastAndCrossAttentions` or the equivalent tuple.
        """
        input_shape = input_ids.shape
        input_ids = input_ids.reshape(-1, input_shape[-1])

        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        positions = self.embed_positions(position_ids + self.offset)

        hidden_states = inputs_embeds + positions
        hidden_states = self.layernorm_embedding(hidden_states)

        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)

        outputs = self.layers(
            hidden_states,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return outputs

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
class FlaxBartModule(nn.Module):
    """Encoder-decoder BART module; encoder and decoder share one token embedding table."""

    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # forwarded to every submodule

    def setup(self):
        """Create the shared embedding and hand it to both the encoder and the decoder."""
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
            dtype=self.dtype,
        )

        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
        self.decoder = FlaxBartDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared)

    def _get_encoder_module(self):
        """Return the encoder submodule (used by `encode` wrappers)."""
        return self.encoder

    def _get_decoder_module(self):
        """Return the decoder submodule (used by `decode` / `init_cache` wrappers)."""
        return self.decoder

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """
        Run the encoder, then the decoder on top of the encoder's first output.

        The encoder `attention_mask` is reused as the decoder's cross-attention mask.

        Returns:
            A `FlaxSeq2SeqModelOutput`, or `decoder_outputs + encoder_outputs` as a tuple when
            `return_dict` is False.
        """
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return FlaxSeq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
    """
    Base wrapper around a BART Flax module: builds the module from `module_class`, initialises
    weights and the decoding cache, and exposes `encode`, `decode` and `__call__`.
    """

    config_class = BartConfig
    base_model_prefix: str = "model"
    module_class: nn.Module = None  # set by concrete subclasses

    def __init__(
        self,
        config: BartConfig,
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        """Instantiate `module_class` and delegate the rest to `FlaxPreTrainedModel.__init__`."""
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        """
        Initialise the module parameters with dummy inputs of shape `input_shape`.

        Args:
            rng: PRNG key; split into a "params" and a "dropout" stream.
            input_shape: Shape of the dummy `input_ids` (unpacked as `(batch_size, sequence_length)`).
            params: Optional existing parameters; only keys listed in `self._missing_keys` are
                filled in from the freshly initialised values.

        Returns:
            The random parameters, or `params` completed with the missing keys (frozen).
        """
        # Dummy inputs: all zeros with the last position set to the EOS token.
        input_ids = jnp.zeros(input_shape, dtype="i4")
        input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
        attention_mask = jnp.ones_like(input_ids)
        decoder_input_ids = input_ids
        decoder_attention_mask = jnp.ones_like(input_ids)

        batch_size, sequence_length = input_ids.shape
        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
        decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
            position_ids,
            decoder_position_ids,
        )["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def init_cache(self, batch_size, max_length, encoder_outputs):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
            encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
                `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
                `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
                is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
                cross-attention of the decoder.
        """
        # Dummy decoder inputs of the full cache length; only their shapes matter here.
        decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
        )

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                **kwargs,
            )

        # Only the decoder is run; `init_cache=True` makes the attention layers create their cache variables.
        init_variables = self.module.init(
            jax.random.PRNGKey(0),
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            init_cache=True,
            method=_decoder_forward,
        )
        return unfreeze(init_variables["cache"])

    @add_start_docstrings(BART_ENCODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=BartConfig)
    def encode(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

        >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

        >>> text = "My friends are cool but they eat too many carbs."
        >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax")
        >>> encoder_outputs = model.encode(**inputs)
        ```"""
        # Fall back to the config for every flag the caller left unset.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # A dropout PRNG stream is only supplied when the caller provides one.
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
            encode_module = module._get_encoder_module()
            return encode_module(input_ids, attention_mask, position_ids, **kwargs)

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            method=_encoder_forward,
        )

    @add_start_docstrings(BART_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=BartConfig)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Run only the decoder on top of pre-computed `encoder_outputs`.

        NOTE(review): the excerpt truncated this method after its parameter list; the body is restored
        following the pattern of `encode` / `init_cache` and should be checked against upstream.

        Raises `ValueError` when `past_key_values` is given without `decoder_position_ids`.

        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            # With a cache the absolute position cannot be inferred from the input length.
            if past_key_values is not None:
                raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")

            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # When a cache is passed it must be marked mutable so the attention layers can update it.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                **kwargs,
            )

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # With a mutable collection, `apply` returns `(outputs, updated_variables)`.
        if past_key_values is not None and return_dict:
            outputs, past = outputs
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past = outputs
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs

    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        # Full encoder-decoder forward pass. Every optional input that is left as None is
        # derived below before the wrapped module is applied.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Encoder inputs.
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # Decoder inputs: by default the encoder input shifted one position to the right.
        if decoder_input_ids is None:
            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
            )
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        if decoder_position_ids is None:
            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # A dropout PRNG stream is only supplied when the caller provides one.
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )
class FlaxBartModel(FlaxBartPreTrainedModel):
    """Pretrained-model wrapper whose underlying module is `FlaxBartModule`."""

    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # default dtype annotation for this wrapper
    module_class = FlaxBartModule


# Attach a usage sample to `FlaxBartModel.__call__`'s docstring.
append_call_sample_docstring(FlaxBartModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
class FlaxBartForConditionalGenerationModule(nn.Module):
    """`FlaxBartModule` followed by a language-modelling head and an additive logits bias."""

    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # forwarded to the submodules
    bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        """Create the base model, the bias-free LM head and the `final_logits_bias` parameter."""
        self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.model.shared.num_embeddings,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))

    def _get_encoder_module(self):
        """Return the encoder of the wrapped base model."""
        return self.model.encoder

    def _get_decoder_module(self):
        """Return the decoder of the wrapped base model."""
        return self.model.decoder

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """
        Run the base model and project its first output to vocabulary logits.

        Returns:
            A `FlaxSeq2SeqLMOutput`, or `(lm_logits,) + outputs[1:]` when `return_dict` is False.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        hidden_states = outputs[0]

        if self.config.tie_word_embeddings:
            # Tied weights: use the transposed shared embedding as the head's kernel.
            shared_embedding = self.model.variables["params"]["shared"]["embedding"]
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
        else:
            lm_logits = self.lm_head(hidden_states)

        # The bias is added to the logits but no gradient flows into it.
        lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return output

        return FlaxSeq2SeqLMOutput(
            logits=lm_logits,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
@add_start_docstrings(
    "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
)
class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel):
    module_class = FlaxBartForConditionalGenerationModule
    dtype: jnp.dtype = jnp.float32

    @add_start_docstrings(BART_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=BartConfig)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Run the decoder on pre-computed `encoder_outputs` and project the result to vocabulary logits.

        NOTE(review): the excerpt reduced this method to `pass` (no docstring, returns `None`), which
        breaks `replace_return_docstrings` and generation. The body applies the same LM-head logic as
        `FlaxBartForConditionalGenerationModule.__call__`; check it against upstream.

        Raises `ValueError` when `past_key_values` is given without `decoder_position_ids`.

        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            # With a cache the absolute position cannot be inferred from the input length.
            if past_key_values is not None:
                raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")

            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # When a cache is passed it must be marked mutable so the attention layers can update it.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            outputs = decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                **kwargs,
            )
            hidden_states = outputs[0]

            if self.config.tie_word_embeddings:
                shared_embedding = module.model.variables["params"]["shared"]["embedding"]
                lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
            else:
                lm_logits = module.lm_head(hidden_states)

            lm_logits += module.final_logits_bias.astype(self.dtype)
            return lm_logits, outputs

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # With a mutable collection, `apply` returns `(outputs, updated_variables)`.
        if past_key_values is None:
            lm_logits, decoder_outputs = outputs
        else:
            (lm_logits, decoder_outputs), past = outputs

        if return_dict:
            outputs = FlaxCausalLMOutputWithCrossAttentions(
                logits=lm_logits,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
            )
        else:
            outputs = (lm_logits,) + decoder_outputs[1:]

        # Expose the updated cache to the caller.
        if past_key_values is not None and return_dict:
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jax.Array] = None,
        decoder_attention_mask: Optional[jax.Array] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        """
        Build the keyword arguments for the first `decode` call of a generation loop.

        Args:
            decoder_input_ids: Initial decoder tokens; only the shape is used here.
            max_length: Cache length and width of the extended decoder attention mask.
            attention_mask: Encoder attention mask, returned as `encoder_attention_mask`.
            decoder_attention_mask: Optional mask for the initial decoder tokens.
            encoder_outputs: Pre-computed encoder outputs, passed through and used to build the cache.

        Returns:
            A dict with `past_key_values`, `encoder_outputs`, `encoder_attention_mask`,
            `decoder_attention_mask` and `decoder_position_ids`.
        """
        batch_size, seq_length = decoder_input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
        # A single all-ones mask of the full cache length is created up front; any provided
        # decoder mask overwrites its leading columns.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
            "decoder_position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """Carry the cache forward and advance the decoder position by one for the next step."""
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
        return model_kwargs
# Example section appended to the `__call__` documentation of FlaxBartForConditionalGeneration.
# Both examples are wrapped in balanced code fences so the rendered documentation does not swallow the
# text that follows an unterminated fence. The annotation lines inside the literal are part of the
# runtime string and are left untouched.
FLAX_BART_CONDITIONAL_GENERATION_DOCSTRING = """
Returns:
Summarization example:
```
>>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration
>>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
# 使用预训练的 FlaxBart 模型加载条件生成模型,用于生成文本摘要
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
# 使用预训练的 tokenizer 加载 BART 模型的分词器
>>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
# 待摘要的文章内容
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="np")
# 使用分词器对文章进行分词,并封装成适合模型输入的格式
>>> # Generate Summary
# 生成摘要的过程
>>> summary_ids = model.generate(inputs["input_ids"]).sequences
# 使用模型生成输入文章的摘要序列
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
# 打印生成的摘要,跳过特殊标记并保持分词时的空格处理方式
```
Mask filling example:
```
>>> import jax
# 导入 JAX 库,用于高性能数值计算
>>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration
>>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large")
# 使用预训练的 FlaxBart 模型加载条件生成模型
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
# 使用预训练的 tokenizer 加载 BART 模型的分词器
>>> TXT = "My friends are <mask> but they eat too many carbs."
# 带有掩码填充的文本示例
>>> input_ids = tokenizer([TXT], return_tensors="jax")["input_ids"]
# 使用分词器对带有掩码的文本进行分词,并封装成适合模型输入的格式
>>> logits = model(input_ids).logits
# 通过模型生成输入文本的 logits,用于获取每个词的预测概率
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
# 找到掩码位置的索引
>>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
# 对掩码位置的 logits 进行 softmax 处理,得到预测概率分布
>>> values, predictions = jax.lax.top_k(probs, k=1)
# 获取最高概率的预测值和其对应的索引
>>> tokenizer.decode(predictions).split()
# 解码预测的标记并拆分成词汇列表
```
"""

# The call docstring is overwritten below with the BART inputs docstring followed by the examples above.
# Replace the `__call__` docstring of FlaxBartForConditionalGeneration with the BART inputs docstring
# followed by the conditional-generation examples.
overwrite_call_docstring(
    FlaxBartForConditionalGeneration,
    BART_INPUTS_DOCSTRING + FLAX_BART_CONDITIONAL_GENERATION_DOCSTRING,
)

# Append/replace the return section of that docstring using the FlaxSeq2SeqLMOutput output type.
append_replace_return_docstrings(
    FlaxBartForConditionalGeneration,
    output_type=FlaxSeq2SeqLMOutput,
    config_class=_CONFIG_FOR_DOC,
)

# The module used for sequence classification, FlaxBartForSequenceClassificationModule, follows.
class FlaxBartForSequenceClassificationModule(nn.Module):
    """BART module followed by a classification head fed with the hidden state at the last `<eos>` token."""

    # BART configuration.
    config: BartConfig
    # dtype forwarded to the inner BART module; defaults to 32-bit float.
    dtype: jnp.dtype = jnp.float32
    # Optional number of labels; when None, `config.num_labels` is used instead.
    num_labels: Optional[int] = None

    def setup(self):
        """Create the BART backbone and the classification head."""
        self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
        self.classification_head = FlaxBartClassificationHead(
            config=self.config,
            inner_dim=self.config.d_model,
            num_classes=self.num_labels if self.num_labels is not None else self.config.num_labels,
            pooler_dropout=self.config.classifier_dropout,
        )

    def _get_encoder_module(self):
        """Return the encoder of the wrapped BART module."""
        return self.model.encoder

    def _get_decoder_module(self):
        """Return the decoder of the wrapped BART module."""
        return self.model.decoder

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """
        Run BART and classify each sequence from the first model output taken at its last `<eos>` position.

        Args:
            input_ids: input token ids; also compared against `config.eos_token_id` to locate `<eos>`.
            attention_mask: attention mask.
            decoder_input_ids: decoder input token ids.
            decoder_attention_mask: decoder attention mask.
            position_ids: position ids.
            decoder_position_ids: decoder position ids.
            output_attentions: whether to output attention weights.
            output_hidden_states: whether to output hidden states.
            return_dict: whether to return a `FlaxSeq2SeqSequenceClassifierOutput` instead of a tuple.
            deterministic: forwarded to the BART module and to the classification head.

        Raises:
            ValueError: when not tracing, if the examples do not all contain the same number of `<eos>`
                tokens, or if an example contains none.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        hidden_states = outputs[0]  # last hidden state

        # 1 where the input token is <eos>, 0 elsewhere.
        eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0)

        # The value-dependent checks need concrete arrays, so they are skipped while tracing (e.g. under jit).
        if not isinstance(eos_mask, jax.interpreters.partial_eval.DynamicJaxprTracer):
            eos_counts = eos_mask.sum(1)
            if len(jnp.unique(eos_counts)) > 1:
                raise ValueError("所有示例必须具有相同数量的 <eos> 标记。")
            if jnp.any(eos_counts == 0):
                raise ValueError("输入中缺少 <eos> 标记。")

        # Keep only the last <eos> per example: a tiny offset that grows with the position makes the row
        # maximum unique and located at the right-most <eos>.
        eos_mask_noised = eos_mask + jnp.arange(eos_mask.shape[1]) * 1e-6
        eos_mask = jnp.where(eos_mask_noised == eos_mask_noised.max(1).reshape(-1, 1), 1, 0)

        # Zero out every position except the selected <eos>, then sum over the sequence axis.
        sentence_representation = jnp.einsum("ijk, ij -> ijk", hidden_states, eos_mask).sum(1)
        logits = self.classification_head(sentence_representation, deterministic=deterministic)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return output

        return FlaxSeq2SeqSequenceClassifierOutput(
            logits=logits,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
# Prepend a task description and the shared BART introduction to the docstring of
# FlaxBartForSequenceClassification.
@add_start_docstrings(
    """
Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
tasks.
""",
    BART_START_DOCSTRING,  # shared BART introduction defined earlier in the file
)
class FlaxBartForSequenceClassification(FlaxBartPreTrainedModel):
    module_class = FlaxBartForSequenceClassificationModule  # module class wrapped by this model
    dtype = jnp.float32  # default dtype
# Add a call-sample docstring to FlaxBartForSequenceClassification.
append_call_sample_docstring(
    FlaxBartForSequenceClassification,
    _CHECKPOINT_FOR_DOC,  # checkpoint name used in the documentation
    FlaxSeq2SeqSequenceClassifierOutput,  # output class referenced by the documentation
    _CONFIG_FOR_DOC,  # config class name referenced by the documentation
)
# Flax Linen module: BART followed by a dense layer that yields span start/end logits.
class FlaxBartForQuestionAnsweringModule(nn.Module):
    config: BartConfig  # BART configuration
    dtype: jnp.dtype = jnp.float32  # dtype forwarded to the sub-modules
    num_labels = 2  # one output for the span start, one for the span end

    def setup(self):
        """Create the BART backbone and the dense span-logit layer."""
        self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
        self.qa_outputs = nn.Dense(
            self.num_labels,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

    def _get_encoder_module(self):
        """Return the encoder of the wrapped BART module."""
        return self.model.encoder

    def _get_decoder_module(self):
        """Return the decoder of the wrapped BART module."""
        return self.model.decoder

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Run BART and project its first output to start/end logits (tuple or model-output object)."""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        span_logits = self.qa_outputs(outputs[0])
        # Split the last axis into one slice per label, then drop that now-singleton axis.
        start_logits, end_logits = (
            piece.squeeze(-1) for piece in jnp.split(span_logits, span_logits.shape[-1], axis=-1)
        )

        if return_dict:
            return FlaxSeq2SeqQuestionAnsweringModelOutput(
                start_logits=start_logits,
                end_logits=end_logits,
                decoder_hidden_states=outputs.decoder_hidden_states,
                decoder_attentions=outputs.decoder_attentions,
                cross_attentions=outputs.cross_attentions,
                encoder_last_hidden_state=outputs.encoder_last_hidden_state,
                encoder_hidden_states=outputs.encoder_hidden_states,
                encoder_attentions=outputs.encoder_attentions,
            )

        # Tuple form: the two logit arrays followed by the remaining model outputs.
        return (start_logits, end_logits) + outputs[1:]
# Prepend a task description and the shared BART introduction to the docstring of
# FlaxBartForQuestionAnswering. The description must be ONE string literal: the head is a linear layer on
# top of the hidden-states output that computes the span start logits and span end logits.
@add_start_docstrings(
    """
BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
    BART_START_DOCSTRING,
)
class FlaxBartForQuestionAnswering(FlaxBartPreTrainedModel):
    module_class = FlaxBartForQuestionAnsweringModule  # module class wrapped by this model
    dtype = jnp.float32  # default dtype
# Add a call-sample docstring to FlaxBartForQuestionAnswering, referencing the documentation checkpoint,
# the question-answering output class and the config class name.
append_call_sample_docstring(
    FlaxBartForQuestionAnswering,
    _CHECKPOINT_FOR_DOC,
    FlaxSeq2SeqQuestionAnsweringModelOutput,
    _CONFIG_FOR_DOC,
)
class FlaxBartDecoderPreTrainedModel(FlaxPreTrainedModel):
    # Base class for decoder-only BART models: forces a decoder-only config and provides parameter
    # initialisation, cache initialisation and the forward call.
    config_class = BartConfig
    base_model_prefix: str = "model"
    module_class: nn.Module = None

    def __init__(
        self,
        config: BartConfig,
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # The config is mutated in place before the module is built from it.
        config.is_decoder = True
        config.is_encoder_decoder = False
        decoder_module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, decoder_module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # Initialise the module on dummy inputs of shape `input_shape` and return its "params" collection.
        token_ids = jnp.zeros(input_shape, dtype="i4")
        mask = jnp.ones_like(token_ids)
        num_rows, num_cols = token_ids.shape
        positions = jnp.broadcast_to(jnp.arange(num_cols)[None, :], (num_rows, num_cols))

        params_key, dropout_key = jax.random.split(rng)
        dummy_encoder_states = jnp.zeros(input_shape + (self.config.d_model,))

        variables = self.module.init(
            {"params": params_key, "dropout": dropout_key},
            token_ids,
            mask,
            positions,
            dummy_encoder_states,
            mask,  # the same all-ones mask doubles as the encoder attention mask
            return_dict=False,
        )
        return variables["params"]

    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                batch size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the
                initialized cache.
        """
        # Dummy inputs with the cache's shape; only the resulting "cache" collection is kept.
        dummy_ids = jnp.ones((batch_size, max_length), dtype="i4")
        dummy_mask = jnp.ones_like(dummy_ids, dtype="i4")
        dummy_positions = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(dummy_ids).shape[-1]), dummy_ids.shape)

        variables = self.module.init(
            jax.random.PRNGKey(0), dummy_ids, dummy_mask, dummy_positions, return_dict=False, init_cache=True
        )
        return unfreeze(variables["cache"])

    @add_start_docstrings_to_model_forward(BART_DECODE_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        # Documentation stays in comments here: the docstring of this method is produced by the decorator.
        # Unset output flags fall back to the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.return_dict

        # Default masks / positions: attend everywhere, positions 0..seq_len-1.
        if encoder_hidden_states is not None and encoder_attention_mask is None:
            enc_rows, enc_len = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((enc_rows, enc_len))
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            num_rows, num_cols = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(num_cols)[None, :], (num_rows, num_cols))

        rngs = {} if dropout_rng is None else {"dropout": dropout_rng}
        variables = {"params": params or self.params}

        # A provided cache is passed in as the "cache" collection and marked mutable so the module can
        # update it; `apply` then returns `(outputs, mutated_variables)`.
        if past_key_values:
            variables["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            variables,
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
        )

        if past_key_values is None:
            return outputs

        # Attach the updated cache to the model output.
        outputs, mutated = outputs
        updated_cache = unfreeze(mutated["cache"])
        if return_dict:
            outputs["past_key_values"] = updated_cache
            return outputs
        return outputs[:1] + (updated_cache,) + outputs[1:]
class FlaxBartDecoderWrapper(nn.Module):
    """
    Helper wrapper that lets pretrained checkpoints load correctly when the causal language model is used together
    with the [`EncoderDecoderModel`] framework.
    """

    config: BartConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the token embedding table and the decoder that uses it."""
        token_embedding = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
            dtype=self.dtype,
        )
        self.decoder = FlaxBartDecoder(config=self.config, embed_tokens=token_embedding, dtype=self.dtype)

    def __call__(self, *args, **kwargs):
        """Forward every argument unchanged to the wrapped decoder."""
        return self.decoder(*args, **kwargs)
class FlaxBartForCausalLMModule(nn.Module):
    """BART decoder with a bias-free dense language-modeling head, optionally tied to the input embeddings."""

    config: BartConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the wrapped decoder and the vocabulary projection."""
        self.model = FlaxBartDecoderWrapper(config=self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Run the decoder and project its first output to vocabulary logits."""
        decoder_outputs = self.model(
            input_ids,
            attention_mask,
            position_ids,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden = decoder_outputs[0]

        if not self.config.tie_word_embeddings:
            lm_logits = self.lm_head(last_hidden)
        else:
            # Tied weights: apply the head with the transposed token-embedding matrix as its kernel.
            embedding_matrix = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"]
            lm_logits = self.lm_head.apply({"params": {"kernel": embedding_matrix.T}}, last_hidden)

        if return_dict:
            return FlaxCausalLMOutputWithCrossAttentions(
                logits=lm_logits,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
            )
        return (lm_logits,) + decoder_outputs[1:]
@add_start_docstrings(
    """
Bart Decoder Model with a language modeling head on top (linear layer with weights tied to the input embeddings)
e.g for autoregressive tasks.
""",
    BART_START_DOCSTRING,
)
class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel):
    # Documentation is kept in `#` comments: the class docstring is assembled by the decorator.
    module_class = FlaxBartForCausalLMModule

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        # Build the keyword arguments for the first decoding step: a fresh cache sized for `max_length`,
        # an attention mask covering the whole generation budget, and position ids for the prompt.
        num_rows, prompt_len = input_ids.shape
        cache = self.init_cache(num_rows, max_length)

        # Start from an all-ones mask of length `max_length`; a user-provided mask overwrites its prefix.
        full_mask = jnp.ones((num_rows, max_length), dtype="i4")
        if attention_mask is None:
            positions = jnp.broadcast_to(jnp.arange(prompt_len, dtype="i4")[None, :], (num_rows, prompt_len))
        else:
            # Positions count the attended tokens seen so far (running sum of the mask, zero-based).
            positions = attention_mask.cumsum(axis=-1) - 1
            full_mask = lax.dynamic_update_slice(full_mask, attention_mask, (0, 0))

        return {
            "past_key_values": cache,
            "attention_mask": full_mask,
            "position_ids": positions,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        # Carry the updated cache forward and advance the position by one from the last position used.
        last_position = model_kwargs["position_ids"][:, -1:]
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = last_position + 1
        return model_kwargs
# Add a call-sample docstring to FlaxBartForCausalLM, referencing the documentation checkpoint,
# the causal-LM output class and the config class name.
append_call_sample_docstring(
    FlaxBartForCausalLM,
    _CHECKPOINT_FOR_DOC,
    FlaxCausalLMOutputWithCrossAttentions,
    _CONFIG_FOR_DOC,
)
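# ---- .\models\bart\modeling_tf_bart.py ----
# NOTE: the beginning of this file (imports and the headers of the functions whose bodies follow) is not
# included in this excerpt.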
# "Verify that `labels` has only positive values and -100"
# NOTE(review): fragment of a function whose `def` line is not visible in this excerpt.
# Assert that every value of `shifted_input_ids` is >= 0.
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
    # The identity op depends on the assertion, so the check runs; the tensor values are unchanged.
    shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids
# NOTE(review): fragment of a function whose `def` line is not visible in this excerpt.
# Build a causal attention mask (the original annotation calls it the causal mask used for bi-directional
# self-attention).
bsz = input_ids_shape[0]  # batch size
tgt_len = input_ids_shape[1]  # target length
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE  # start with every entry set to LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])  # 0 .. n-1, where n is the size of the mask's last axis
# Zero the entries where column index < row index + 1 (lower triangle including the diagonal), so each
# position can only attend to itself and earlier positions.
mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
if past_key_values_length > 0:
    # Prepend zero columns for the cached positions so the mask width matches past + current length.
    mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
# Add two leading axes and tile along the first one to the batch size.
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.

    The result is `(1 - mask) * LARGE_NEGATIVE`, i.e. entries equal to 1 map to 0 and entries equal to 0 map to
    `LARGE_NEGATIVE`. When `tgt_len` is None the source length is used as the target length.
    """
    source_length = shape_list(mask)[1]
    if tgt_len is None:
        tgt_len = source_length

    # The float constant fixes the dtype that the mask is cast to.
    unit = tf.constant(1.0)
    float_mask = tf.cast(mask, dtype=unit.dtype)

    # Insert the two middle axes and repeat the mask along the target axis.
    tiled_mask = tf.tile(float_mask[:, None, None, :], (1, 1, tgt_len, 1))
    return (unit - tiled_mask) * LARGE_NEGATIVE
class TFBartLearnedPositionalEmbedding(keras.layers.Embedding):
    """
    Learned positional embeddings up to a fixed maximum size; every position id is shifted by a fixed offset.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
        # Bart-specific hack: embedding ids are offset by 2 and the table is enlarged accordingly
        # (the original annotation ties this to `padding_idx` being specified). Other models do not need it.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim, **kwargs)

    def call(
        self,
        input_shape: Optional[tf.TensorShape] = None,
        past_key_values_length: int = 0,
        position_ids: tf.Tensor | None = None,
    ):
        """Input is expected to be of size [bsz x seqlen]."""
        if position_ids is None:
            # No explicit positions: use 0..seq_len-1 shifted by the number of already-cached positions.
            seq_len = input_shape[1]
            position_ids = tf.range(seq_len, delta=1, name="range") + past_key_values_length

        # Build the offset constant in the dtype of the position ids (int32 for non-tensor inputs).
        if isinstance(position_ids, tf.Tensor):
            offset_dtype = position_ids.dtype
        else:
            offset_dtype = tf.int32
        shifted_positions = position_ids + tf.constant(self.offset, dtype=offset_dtype)
        return super().call(shifted_positions)
class TFBartAttention(keras.layers.Layer):
    """Multi-headed attention from "Attention Is All You Need"."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        **kwargs,
    ):
        """
        Args:
            embed_dim: model width; must be divisible by `num_heads`.
            num_heads: number of attention heads.
            dropout: rate of the dropout layer created for this attention block.
            is_decoder: stored on the layer as `self.is_decoder`.
            bias: whether the four projection layers use a bias.

        Raises:
            ValueError: if `embed_dim` is not divisible by `num_heads`.
        """
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = keras.layers.Dropout(dropout)
        self.head_dim = embed_dim // num_heads
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
        self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
        self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
        self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")

    def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
        """Reshape to `(bsz, seq_len, num_heads, head_dim)` and move the head axis before the sequence axis."""
        return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))

    def call(
        self,
        hidden_states: tf.Tensor,
        key_value_states: tf.Tensor | None = None,
        past_key_value: Tuple[Tuple[tf.Tensor]] | None = None,
        attention_mask: tf.Tensor | None = None,
        layer_head_mask: tf.Tensor | None = None,
        training: Optional[bool] = False,
    ):
        # NOTE(review): the forward-pass body is not included in this excerpt; `pass` is a placeholder
        # (the same convention this file uses for other elided bodies). The weight-building statements
        # that had been fused under this signature belong to `build` below.
        pass

    def build(self, input_shape=None):
        """Build the four projection layers once, each inside its own name scope."""
        if self.built:
            return
        self.built = True
        if getattr(self, "k_proj", None) is not None:
            with tf.name_scope(self.k_proj.name):
                self.k_proj.build([None, None, self.embed_dim])
        if getattr(self, "q_proj", None) is not None:
            with tf.name_scope(self.q_proj.name):
                self.q_proj.build([None, None, self.embed_dim])
        if getattr(self, "v_proj", None) is not None:
            with tf.name_scope(self.v_proj.name):
                self.v_proj.build([None, None, self.embed_dim])
        if getattr(self, "out_proj", None) is not None:
            with tf.name_scope(self.out_proj.name):
                self.out_proj.build([None, None, self.embed_dim])
class TFBartEncoderLayer(keras.layers.Layer):
    """Encoder layer: self-attention and a two-layer feed-forward block, each with residual + layer norm."""

    def __init__(self, config: BartConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.embed_dim = config.d_model

        # Self-attention sub-block.
        self.self_attn = TFBartAttention(
            self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn"
        )
        self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")

        # Feed-forward sub-block.
        self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
        self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
        self.activation_fn = get_tf_activation(config.activation_function)
        self.activation_dropout = keras.layers.Dropout(config.activation_dropout)
        self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")

        # Dropout shared by both residual branches.
        self.dropout = keras.layers.Dropout(config.dropout)

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: np.ndarray | tf.Tensor | None,
        layer_head_mask: tf.Tensor | None,
        training: Optional[bool] = False,
    ) -> tf.Tensor:
        """
        Args:
            hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`tf.Tensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding
                elements are indicated by very large negative values.
            layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`
            training (`Optional[bool]`): whether the layer runs in training mode, defaults to False
        """
        residual = hidden_states
        hidden_states, self_attn_weights, _ = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
        )

        # Self-attention must not change the shape of its input.
        tf.debugging.assert_equal(
            shape_list(hidden_states),
            shape_list(residual),
            message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
        )

        # Residual + norm around the attention output.
        attn_block = self.self_attn_layer_norm(residual + self.dropout(hidden_states, training=training))

        # Residual + norm around the feed-forward output.
        ffn_out = self.activation_fn(self.fc1(attn_block))
        ffn_out = self.activation_dropout(ffn_out, training=training)
        ffn_out = self.dropout(self.fc2(ffn_out), training=training)
        return self.final_layer_norm(attn_block + ffn_out), self_attn_weights

    def build(self, input_shape=None):
        """Build every sub-layer once, each inside its own name scope."""
        if self.built:
            return
        self.built = True

        # (attribute name, shape handed to that sub-layer's `build`), in build order.
        sublayer_shapes = (
            ("self_attn", None),
            ("self_attn_layer_norm", [None, None, self.embed_dim]),
            ("fc1", [None, None, self.embed_dim]),
            ("fc2", [None, None, self.config.encoder_ffn_dim]),
            ("final_layer_norm", [None, None, self.embed_dim]),
        )
        for attr_name, build_shape in sublayer_shapes:
            sublayer = getattr(self, attr_name, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build(build_shape)
class TFBartDecoderLayer(keras.layers.Layer):
    """Decoder layer holding self-attention, encoder (cross) attention and a two-layer feed-forward block."""

    def __init__(self, config: BartConfig, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = config.d_model
        self.self_attn = TFBartAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            name="self_attn",
            is_decoder=True,
        )
        self.dropout = keras.layers.Dropout(config.dropout)
        self.activation_fn = get_tf_activation(config.activation_function)
        self.activation_dropout = keras.layers.Dropout(config.activation_dropout)

        self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
        self.encoder_attn = TFBartAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            name="encoder_attn",
            is_decoder=True,
        )
        self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")
        self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
        self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
        self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
        self.config = config

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        layer_head_mask: tf.Tensor | None = None,
        cross_attn_layer_head_mask: tf.Tensor | None = None,
        past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        training: Optional[bool] = False,
    ):
        # NOTE(review): the forward-pass body is not included in this excerpt; `pass` is the placeholder
        # left in its place. Only the missing `):` that terminates the signature has been restored.
        pass

    def build(self, input_shape=None):
        """Build every sub-layer once, each inside its own name scope."""
        if self.built:
            return
        self.built = True
        if getattr(self, "self_attn", None) is not None:
            with tf.name_scope(self.self_attn.name):
                self.self_attn.build(None)
        if getattr(self, "self_attn_layer_norm", None) is not None:
            with tf.name_scope(self.self_attn_layer_norm.name):
                self.self_attn_layer_norm.build([None, None, self.embed_dim])
        if getattr(self, "encoder_attn", None) is not None:
            with tf.name_scope(self.encoder_attn.name):
                self.encoder_attn.build(None)
        if getattr(self, "encoder_attn_layer_norm", None) is not None:
            with tf.name_scope(self.encoder_attn_layer_norm.name):
                self.encoder_attn_layer_norm.build([None, None, self.embed_dim])
        if getattr(self, "fc1", None) is not None:
            with tf.name_scope(self.fc1.name):
                self.fc1.build([None, None, self.embed_dim])
        if getattr(self, "fc2", None) is not None:
            # fc2 consumes the output of fc1, whose width is `decoder_ffn_dim`.
            with tf.name_scope(self.fc2.name):
                self.fc2.build([None, None, self.config.decoder_ffn_dim])
        if getattr(self, "final_layer_norm", None) is not None:
            with tf.name_scope(self.final_layer_norm.name):
                self.final_layer_norm.build([None, None, self.embed_dim])
class TFBartClassificationHead(keras.layers.Layer):
    """Head for sentence-level classification tasks."""

    def __init__(self, inner_dim: int, num_classes: int, pooler_dropout: float, name: str, **kwargs):
        super().__init__(name=name, **kwargs)
        # `inner_dim` is recorded twice: as the input width of `dense` and as the input width of `out_proj`.
        self.input_dim = inner_dim
        self.inner_dim = inner_dim
        self.dense = keras.layers.Dense(inner_dim, name="dense")
        self.dropout = keras.layers.Dropout(pooler_dropout)
        self.out_proj = keras.layers.Dense(num_classes, name="out_proj")

    def call(self, inputs):
        """Apply dropout -> dense -> tanh -> dropout -> output projection."""
        pooled = keras.activations.tanh(self.dense(self.dropout(inputs)))
        return self.out_proj(self.dropout(pooled))

    def build(self, input_shape=None):
        """Build the two dense layers once, each inside its own name scope."""
        if self.built:
            return
        self.built = True

        # (attribute name, last-axis width of that layer's input), in build order.
        for attr_name, width in (("dense", self.input_dim), ("out_proj", self.inner_dim)):
            sublayer = getattr(self, attr_name, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build([None, None, width])
class TFBartPretrainedModel(TFPreTrainedModel):
    """Common base for the TF BART models: config class, weight prefix and dummy-input tweaks."""

    config_class = BartConfig
    base_model_prefix = "model"

    @property
    def dummy_inputs(self):
        """Return the parent's dummy inputs with the (decoder) input ids doubled."""
        inputs = super().dummy_inputs
        inputs["input_ids"] = inputs["input_ids"] * 2
        if "decoder_input_ids" in inputs:
            inputs["decoder_input_ids"] = inputs["decoder_input_ids"] * 2
        return inputs

    def tf_to_pt_weight_rename(self, tf_weight):
        """Map a weight name to a tuple of names; the shared embedding also maps to the decoder embedding."""
        if tf_weight != "model.shared.weight":
            return (tf_weight,)
        return tf_weight, "model.decoder.embed_tokens.weight"
# Shared introduction for the TF BART model docstrings.
# NOTE(review): part of the <Tip> section is not included in this excerpt; the tip is closed right after
# the last sentence that is present so the literal is one well-formed string.
BART_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
behavior.
<Tip>
TensorFlow models and layers in `transformers` accept two formats as input:
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional argument.
The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
and layers.
</Tip>
Args:
config ([`BartConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration.
Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Usage examples (summarization and mask filling) for TFBartForConditionalGeneration, kept as a raw string.
BART_GENERATION_EXAMPLE = r"""
Summarization example:
```
>>> from transformers import AutoTokenizer, TFBartForConditionalGeneration
>>> model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large")
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
>>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="tf")
>>> # Generate Summary
>>> summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=5)
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
```
Mask filling example:
```
>>> from transformers import AutoTokenizer, TFBartForConditionalGeneration
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
>>> TXT = "My friends are <mask> but they eat too many carbs."
>>> model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large")
>>> input_ids = tokenizer([TXT], return_tensors="tf")["input_ids"]
>>> logits = model(input_ids).logits
>>> probs = tf.nn.softmax(logits[0])
>>> # probs[5] is associated with the mask token
```
"""
# Inputs docstring for the TF BART models.
# NOTE(review): the literal contains only a newline here; its content appears to be elided in this excerpt.
BART_INPUTS_DOCSTRING = r"""
"""
@keras_serializable
class TFBartEncoder(keras.layers.Layer):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`TFBartEncoderLayer`].

    Args:
        config: BartConfig
    """

    config_class = BartConfig

    def __init__(self, config: BartConfig, embed_tokens: Optional[keras.layers.Embedding] = None, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.dropout = keras.layers.Dropout(config.dropout)
        self.layerdrop = config.encoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_position_embeddings
        # Embeddings are scaled by sqrt(d_model) only when the config asks for it.
        self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0

        self.embed_tokens = embed_tokens
        self.embed_positions = TFBartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
            name="embed_positions",
        )
        self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
        self.layernorm_embedding = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
        self.embed_dim = config.d_model

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ):
        # NOTE(review): the forward-pass body is not included in this excerpt; `pass` is a placeholder
        # (the same convention this file uses for other elided bodies). The weight-building statements
        # that had been fused under this signature belong to `build` below.
        pass

    def build(self, input_shape=None):
        """Build the positional embedding, the embedding layer norm and every encoder layer once."""
        if self.built:
            return
        self.built = True
        if getattr(self, "embed_positions", None) is not None:
            with tf.name_scope(self.embed_positions.name):
                self.embed_positions.build(None)
        if getattr(self, "layernorm_embedding", None) is not None:
            with tf.name_scope(self.layernorm_embedding.name):
                self.layernorm_embedding.build([None, None, self.embed_dim])
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build(None)
@keras_serializable
class TFBartDecoder(keras.layers.Layer):
    config_class = BartConfig
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFBartDecoderLayer`]

    Args:
        config: BartConfig
        embed_tokens: output embedding
    """

    def __init__(self, config: BartConfig, embed_tokens: Optional[keras.layers.Embedding] = None, **kwargs):
        """Instantiate the decoder's sub-layers according to `config`.

        Args:
            config: Model configuration (padding id, layerdrop, positions, width, layer count, dropout).
            embed_tokens: Optional token-embedding layer, kept on the instance unchanged.
            **kwargs: Passed through to the parent layer constructor.
        """
        super().__init__(**kwargs)
        self.config = config
        self.padding_idx = config.pad_token_id
        self.embed_tokens = embed_tokens
        self.layerdrop = config.decoder_layerdrop
        self.embed_positions = TFBartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
            name="embed_positions",
        )
        # Scale factor is sqrt(d_model) when enabled in the config, otherwise 1.0.
        if config.scale_embedding:
            self.embed_scale = tf.math.sqrt(float(config.d_model))
        else:
            self.embed_scale = 1.0
        self.layers = [
            TFBartDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)
        ]
        self.layernorm_embedding = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")

        self.dropout = keras.layers.Dropout(config.dropout)

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ):
        # Body intentionally left as a placeholder in this listing.
        ...

    def build(self, input_shape=None):
        """Build the positional embedding, the embedding layer norm and each decoder layer once.

        Args:
            input_shape: Unused; kept so the method matches the Keras `build` signature.
        """
        if self.built:
            return
        self.built = True

        positions = getattr(self, "embed_positions", None)
        if positions is not None:
            with tf.name_scope(positions.name):
                positions.build(None)

        embedding_norm = getattr(self, "layernorm_embedding", None)
        if embedding_norm is not None:
            with tf.name_scope(embedding_norm.name):
                embedding_norm.build([None, None, self.config.d_model])

        decoder_layers = getattr(self, "layers", None)
        if decoder_layers is not None:
            for decoder_layer in decoder_layers:
                with tf.name_scope(decoder_layer.name):
                    decoder_layer.build(None)
@keras_serializable
class TFBartMainLayer(keras.layers.Layer):
    """Core BART layer holding a shared token embedding plus the encoder and decoder."""

    config_class = BartConfig

    def __init__(self, config: BartConfig, load_weight_prefix=None, **kwargs):
        """Create the shared embedding and hand it to both encoder and decoder.

        Args:
            config: Model configuration (vocabulary size, model width, init std, ...).
            load_weight_prefix: Name prefix recorded on the shared embedding and used as part of its
                name scope in `build`; defaults to `"model.shared"` when `None`.
            **kwargs: Forwarded to the parent layer constructor.
        """
        super().__init__(**kwargs)
        self.config = config
        self.shared = keras.layers.Embedding(
            input_dim=config.vocab_size,
            output_dim=config.d_model,
            embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std),
            name="model.shared",
        )
        # Additional attribute read back in `build` to form the embedding's name scope.
        self.shared.load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix

        self.encoder = TFBartEncoder(config, self.shared, name="encoder")
        self.decoder = TFBartDecoder(config, self.shared, name="decoder")

    def get_input_embeddings(self):
        """Return the embedding layer shared by encoder and decoder."""
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        """Replace the shared embedding and point both encoder and decoder at the new layer."""
        self.shared = new_embeddings
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        decoder_input_ids: np.ndarray | tf.Tensor | None = None,
        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        decoder_position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        decoder_head_mask: np.ndarray | tf.Tensor | None = None,
        cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
        encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
        **kwargs,
    ):
        # The forward-pass body is not part of this listing; the placeholder keeps the
        # definition well-formed, in the same way as `TFBartDecoder.call`.
        ...

    def build(self, input_shape=None):
        """Build the shared embedding, then the encoder and decoder; does nothing once built.

        Args:
            input_shape: Unused; accepted for compatibility with the Keras `build` signature.
        """
        if self.built:
            return
        self.built = True
        # The shared embedding is built under "<load_weight_prefix>/<layer name>/".
        with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"):
            self.shared.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "decoder", None) is not None:
            with tf.name_scope(self.decoder.name):
                self.decoder.build(None)
@add_start_docstrings(
    "The bare BART Model outputting raw hidden-states without any specific head on top.",
    BART_START_DOCSTRING,
)
class TFBartModel(TFBartPretrainedModel):
    _requires_load_weight_prefix = True

    def __init__(self, config: BartConfig, load_weight_prefix=None, *inputs, **kwargs):
        """Wrap a `TFBartMainLayer` named "model".

        Args:
            config: Model configuration.
            load_weight_prefix: Passed on to the main layer.
            *inputs, **kwargs: Passed on to the parent model constructor.
        """
        super().__init__(config, *inputs, **kwargs)

        self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model")

    def get_encoder(self):
        """Return the encoder owned by the wrapped main layer."""
        return self.model.encoder

    def get_decoder(self):
        """Return the decoder owned by the wrapped main layer."""
        return self.model.decoder

    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFSeq2SeqModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        decoder_input_ids: np.ndarray | tf.Tensor | None = None,
        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        decoder_position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        decoder_head_mask: np.ndarray | tf.Tensor | None = None,
        cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
        encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
        **kwargs,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        # Every named argument is handed to the main layer unchanged; `kwargs` is accepted
        # but not forwarded.
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

    def serving_output(self, output):
        """Repackage `output` as a `TFSeq2SeqModelOutput` for serving.

        Cache, hidden states and attentions are included only when the matching config flag
        (`use_cache`, `output_hidden_states`, `output_attentions`) is set; otherwise they are `None`.
        """
        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
        dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
        cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
        enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None

        served = TFSeq2SeqModelOutput(
            last_hidden_state=output.last_hidden_state,
            past_key_values=pkv,
            decoder_hidden_states=dec_hs,
            decoder_attentions=dec_attns,
            cross_attentions=cross_attns,
            encoder_last_hidden_state=output.encoder_last_hidden_state,
            encoder_hidden_states=enc_hs,
            encoder_attentions=enc_attns,
        )
        return served

    def build(self, input_shape=None):
        """Build the wrapped main layer under its name scope; does nothing once built."""
        if self.built:
            return
        self.built = True
        main_layer = getattr(self, "model", None)
        if main_layer is not None:
            with tf.name_scope(main_layer.name):
                main_layer.build(None)
class BiasLayer(keras.layers.Layer):
    """
    Holds a single bias weight inside a layer. `keras.Model.save_weights` stores weights per layer, so the
    bias has to be registered in a layer to be serialized.
    """

    def __init__(self, shape, initializer, trainable, name, **kwargs):
        """Register the bias weight; the weight is given the same name as the layer."""
        super().__init__(name=name, **kwargs)
        self.bias = self.add_weight(
            name=name,
            shape=shape,
            initializer=initializer,
            trainable=trainable,
        )

    def call(self, x):
        """Return `x` with the bias added."""
        shifted = x + self.bias
        return shifted
@add_start_docstrings(
    "The BART Model with a language modeling head. Can be used for summarization.",
    BART_START_DOCSTRING,
)
class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageModelingLoss):
    _keys_to_ignore_on_load_missing = [r"final_logits_bias"]
    _requires_load_weight_prefix = True

    def __init__(self, config, load_weight_prefix=None, *inputs, **kwargs):
        """Create the main BART layer and the non-trainable final-logits bias layer."""
        super().__init__(config, *inputs, **kwargs)
        self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model")
        self.use_cache = config.use_cache
        # The bias is registered as a layer so that it is saved with the other weights.
        self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
        )

    def get_decoder(self):
        """Return the decoder of the wrapped main layer."""
        return self.model.decoder

    def get_encoder(self):
        """Return the encoder of the wrapped main layer."""
        return self.model.encoder

    def get_output_embeddings(self):
        """Return the input embeddings, which this class also uses as output embeddings."""
        return self.get_input_embeddings()

    def set_output_embeddings(self, value):
        """Set the output embeddings by replacing the input embeddings."""
        self.set_input_embeddings(value)

    def get_bias(self):
        """Return the final-logits bias in a dict keyed by `"final_logits_bias"`."""
        return {"final_logits_bias": self.bias_layer.bias}

    def set_bias(self, value):
        """Recreate the bias layer sized to `value["final_logits_bias"]` and assign that value to it."""
        # The bias layer is replaced rather than resized; its width follows the new value.
        vocab_size = value["final_logits_bias"].shape[-1]
        self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False
        )
        self.bias_layer.bias.assign(value["final_logits_bias"])

    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    @add_end_docstrings(BART_GENERATION_EXAMPLE)
    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        decoder_input_ids: np.ndarray | tf.Tensor | None = None,
        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        decoder_position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        decoder_head_mask: np.ndarray | tf.Tensor | None = None,
        cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
        encoder_outputs: Optional[TFBaseModelOutput] = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFSeq2SeqLMOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
            Either a `TFSeq2SeqLMOutput` object or a tuple containing a `tf.Tensor` depending on the `return_dict` parameter.
        """
        if labels is not None:
            # Padding positions in the labels are replaced by -100.
            labels = tf.where(
                labels == self.config.pad_token_id,
                tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
                labels,
            )
            # Caching is switched off whenever labels are supplied.
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                # Without explicit decoder inputs, derive them from the labels.
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # Logits come from multiplying the first model output by the transposed shared embedding
        # weights, followed by the final-logits bias.
        lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
        lm_logits = self.bias_layer(lm_logits)
        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
        else:
            return TFSeq2SeqLMOutput(
                loss=masked_lm_loss,
                logits=lm_logits,
                past_key_values=outputs.past_key_values,
                decoder_hidden_states=outputs.decoder_hidden_states,
                decoder_attentions=outputs.decoder_attentions,
                cross_attentions=outputs.cross_attentions,
                encoder_last_hidden_state=outputs.encoder_last_hidden_state,
                encoder_hidden_states=outputs.encoder_hidden_states,
                encoder_attentions=outputs.encoder_attentions,
            )

    def serving_output(self, output):
        """Repackage `output` as a `TFSeq2SeqLMOutput` for serving.

        Cache, hidden states and attentions are included only when the matching config flag is set.
        """
        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
        dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
        cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
        enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None

        return TFSeq2SeqLMOutput(
            logits=output.logits,
            past_key_values=pkv,
            decoder_hidden_states=dec_hs,
            decoder_attentions=dec_attns,
            cross_attentions=cross_attns,
            encoder_last_hidden_state=output.encoder_last_hidden_state,
            encoder_hidden_states=enc_hs,
            encoder_attentions=enc_attns,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one generation step.

        Returns:
            A dict of model inputs; `input_ids` is always `None` and `encoder_outputs` is passed along.
        """
        # With a cache present only the last decoder token is fed.
        if past_key_values is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        if decoder_attention_mask is not None:
            # Position of the last token = exclusive cumulative sum of the mask at that token.
            decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
        elif past_key_values is not None:
            # Position taken from dimension 2 of the first cached tensor.
            decoder_position_ids = past_key_values[0][0].shape[2]
        else:
            decoder_position_ids = tf.range(decoder_input_ids.shape[1])

        return {
            "input_ids": None,
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_position_ids": decoder_position_ids,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
        """Derive decoder input ids from `labels` via `shift_tokens_right`."""
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    def build(self, input_shape=None):
        """Build the main layer and the bias layer under their name scopes; does nothing once built."""
        if self.built:
            return
        self.built = True
        if getattr(self, "model", None) is not None:
            with tf.name_scope(self.model.name):
                self.model.build(None)
        if getattr(self, "bias_layer", None) is not None:
            with tf.name_scope(self.bias_layer.name):
                self.bias_layer.build(None)
@add_start_docstrings(
    """
    Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
    tasks.
    """,
    BART_START_DOCSTRING,
)
class TFBartForSequenceClassification(TFBartPretrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config: BartConfig, load_weight_prefix=None, *inputs, **kwargs):
        """Create the main BART layer and the classification head.

        Args:
            config: Model configuration; `d_model`, `num_labels` and `classifier_dropout` size the head.
            load_weight_prefix: Passed on to the main layer.
            *inputs, **kwargs: Passed on to the parent model constructor.
        """
        super().__init__(config, *inputs, **kwargs)
        self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model")
        self.classification_head = TFBartClassificationHead(
            config.d_model, config.num_labels, config.classifier_dropout, name="classification_head"
        )

    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFSeq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        decoder_input_ids: np.ndarray | tf.Tensor | None = None,
        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        decoder_position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        decoder_head_mask: np.ndarray | tf.Tensor | None = None,
        cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
        encoder_outputs: Optional[TFBaseModelOutput] = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: tf.Tensor | None = None,
        training: Optional[bool] = False,
    ):
        # The forward-pass body is not part of this listing; the placeholder closes the
        # signature so the methods that follow parse as members of the class.
        ...

    def serving_output(self, output):
        """Repackage `output` as a `TFSeq2SeqSequenceClassifierOutput` for serving.

        Cache, hidden states and attentions are included only when the matching config flag is set.
        """
        logits = tf.convert_to_tensor(output.logits)
        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
        dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
        cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
        enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None

        return TFSeq2SeqSequenceClassifierOutput(
            logits=logits,
            past_key_values=pkv,
            decoder_hidden_states=dec_hs,
            decoder_attentions=dec_attns,
            cross_attentions=cross_attns,
            encoder_last_hidden_state=output.encoder_last_hidden_state,
            encoder_hidden_states=enc_hs,
            encoder_attentions=enc_attns,
        )

    def build(self, input_shape=None):
        """Build the main layer and the classification head under their name scopes; does nothing once built."""
        if self.built:
            return
        self.built = True
        if getattr(self, "model", None) is not None:
            with tf.name_scope(self.model.name):
                self.model.build(None)
        if getattr(self, "classification_head", None) is not None:
            with tf.name_scope(self.classification_head.name):
                self.classification_head.build(None)
.\models\bart\tokenization_bart.py
@lru_cache()
def bytes_to_unicode():
    """
    Return a reversible mapping from every byte value (0-255) to a single unicode character.

    Bytes that are already printable, non-space characters map to themselves; every other byte
    (whitespace, control characters, ...) is assigned a code point starting at 256, so no byte is
    represented by a whitespace or control character that would upset the BPE code.

    Returns:
        `dict[int, str]` with exactly 256 entries and pairwise-distinct values. The result is cached,
        so callers share one dict object and should not mutate it.
    """
    # Bytes kept as-is: "!".."~", "¡".."¬" and "®".."ÿ".
    byte_values = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    code_points = byte_values[:]
    kept = set(byte_values)
    shifted = 0
    for byte in range(2**8):
        if byte not in kept:
            # Remaining bytes get fresh code points above the single-byte range.
            byte_values.append(byte)
            code_points.append(2**8 + shifted)
            shifted += 1
    return dict(zip(byte_values, (chr(point) for point in code_points)))
def make_utf8_to_unicode_lookup():
    """
    Build a byte-value -> unicode-character table covering all 256 byte values.

    Bytes in the ranges "!".."~", "¡".."¬" and "®".."ÿ" map to the character with the same code
    point; the remaining bytes are mapped, in increasing byte order, to code points 256, 257, ...
    """
    identity_bytes = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    lookup = {byte: chr(byte) for byte in identity_bytes}
    next_offset = 0
    for byte in range(2**8):
        if byte in lookup:
            continue
        lookup[byte] = chr(2**8 + next_offset)
        next_offset += 1
    return lookup
def get_pairs(word):
    """
    Return the set of adjacent symbol pairs in a word.

    Word is represented as a tuple of symbols (symbols being variable-length strings).

    Args:
        word: Sliceable sequence of symbols (a tuple of strings, or a plain string).

    Returns:
        A set of `(left, right)` tuples, one per pair of neighbouring symbols. Words with fewer than
        two symbols -- including an empty word -- yield an empty set.
    """
    # Pairing the word with itself shifted by one enumerates every neighbouring pair.
    return set(zip(word, word[1:]))
class BartTokenizer(PreTrainedTokenizer):
"""
Constructs a BART tokenizer, which is similar to the RoBERTa tokenizer, using byte-level Byte-Pair-Encoding.
This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
be encoded differently whether it is at the beginning of the sentence (without space) or not:
```
>>> from transformers import BartTokenizer
>>> tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
>>> tokenizer("Hello world")["input_ids"]
[0, 31414, 232, 2]
>>> tokenizer(" Hello world")["input_ids"]
[0, 20920, 232, 2]
```
You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
<Tip>
When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).
</Tip>
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
this superclass for more information regarding those methods.
"""
# NOTE(review): this class body defines `__init__` a second time further down; Python keeps the
# later definition, so this version is presumably never the one that runs -- confirm before relying on it.
# As written, it forwards the vocab/merges paths, `errors`, `special_tokens_dict` and any extra
# keyword arguments to the parent constructor, then records `max_len` on the instance.
def __init__(self, vocab_file, merges_file, errors='replace', special_tokens_dict=None, max_len=None, **kwargs):
super().__init__(vocab_file, merges_file, errors=errors, special_tokens_dict=special_tokens_dict, **kwargs)
self.max_len = max_len
# Alternate constructor: delegates unchanged to the parent class's `from_pretrained`,
# passing every positional and keyword argument through.
@classmethod
def from_pretrained(cls, *inputs, **kwargs):
return super().from_pretrained(*inputs, **kwargs)
# Calling the tokenizer delegates unchanged to the parent class's `__call__`
# with `text` and any keyword arguments.
def __call__(self, text, **kwargs):
return super().__call__(text, **kwargs)
Args:
vocab_file (`str`):
词汇表文件的路径。
merges_file (`str`):
合并文件的路径。
errors (`str`, *optional*, defaults to `"replace"`):
当解码字节为 UTF-8 时的错误处理方式。详见 [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode)。
bos_token (`str`, *optional*, defaults to `"<s>"`):
预训练过程中用作序列开头的特殊 token。可以用作序列分类器的 token。
<Tip>
在构建序列时使用特殊 token 时,实际用于序列开头的 token 是 `cls_token`。
</Tip>
eos_token (`str`, *optional*, defaults to `"</s>"`):
序列结尾的特殊 token。
<Tip>
在构建序列时使用特殊 token 时,实际用于序列结尾的 token 是 `sep_token`。
</Tip>
sep_token (`str`, *optional*, defaults to `"</s>"`):
分隔符 token,在构建多个序列的合并序列时使用,例如序列分类或问答任务中的问题和文本序列。也作为使用特殊 token 构建序列时的最后一个 token。
cls_token (`str`, *optional*, defaults to `"<s>"`):
分类器 token,在序列分类任务中使用(整个序列的分类而不是每个 token 的分类)。在使用特殊 token 构建序列时,它是序列的第一个 token。
unk_token (`str`, *optional*, defaults to `"<unk>"`):
未知 token,如果词汇表中不存在某个 token,则将其替换为该 token。
pad_token (`str`, *optional*, defaults to `"<pad>"`):
用于填充的 token,在批处理不同长度的序列时使用。
mask_token (`str`, *optional*, defaults to `"<mask>"`):
用于掩码值的 token,在进行掩码语言建模训练时使用,模型将尝试预测该 token。
add_prefix_space (`bool`, *optional*, defaults to `False`):
是否在输入的开头添加一个空格,这允许将第一个词视为其他词一样处理。(BART tokenizer 通过前导空格检测单词的开头)。
"""
# Class-level settings bound to module-level tables (those tables are defined outside this listing).
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
# Names of the model inputs associated with this tokenizer -- presumably consumed by the base
# class when assembling encodings; TODO confirm.
model_input_names = ["input_ids", "attention_mask"]
def __init__(
    self,
    vocab_file,
    merges_file,
    errors="replace",
    bos_token="<s>",
    eos_token="</s>",
    sep_token="</s>",
    cls_token="<s>",
    unk_token="<unk>",
    pad_token="<pad>",
    mask_token="<mask>",
    add_prefix_space=False,
    **kwargs,
):
    """
    Load the vocabulary and merge rules from disk and set up the byte-level BPE state.

    Args:
        vocab_file: Path to a JSON file mapping tokens to ids.
        merges_file: Path to a text file of merges, one per line; the first line and the final
            element after splitting on newlines are dropped.
        errors: Error-handling mode used when decoding bytes back to UTF-8.
        bos_token, eos_token, sep_token, cls_token, unk_token, pad_token, mask_token: Special
            tokens; plain strings are wrapped in `AddedToken`.
        add_prefix_space: Default for whether `prepare_for_tokenization` prepends a space.
        **kwargs: Forwarded to the parent constructor.
    """

    def _as_added_token(token, lstrip=False):
        # Plain strings become AddedToken objects; anything else is passed through untouched.
        return AddedToken(token, lstrip=lstrip, rstrip=False) if isinstance(token, str) else token

    bos_token = _as_added_token(bos_token)
    eos_token = _as_added_token(eos_token)
    sep_token = _as_added_token(sep_token)
    cls_token = _as_added_token(cls_token)
    unk_token = _as_added_token(unk_token)
    pad_token = _as_added_token(pad_token)

    # The mask token behaves like a normal word, i.e. it includes the space before it.
    mask_token = _as_added_token(mask_token, lstrip=True)

    # token -> id table and its inverse.
    with open(vocab_file, encoding="utf-8") as vocab_handle:
        self.encoder = json.load(vocab_handle)
    self.decoder = {index: token for token, index in self.encoder.items()}
    self.errors = errors  # how to handle errors in decoding

    # byte -> unicode-character table and its inverse.
    self.byte_encoder = bytes_to_unicode()
    self.byte_decoder = {char: byte for byte, char in self.byte_encoder.items()}

    # Merge rules ranked by their position in the merges file.
    with open(merges_file, encoding="utf-8") as merges_handle:
        merge_lines = merges_handle.read().split("\n")[1:-1]
    merge_pairs = [tuple(line.split()) for line in merge_lines]
    self.bpe_ranks = {pair: rank for rank, pair in enumerate(merge_pairs)}

    self.cache = {}
    self.add_prefix_space = add_prefix_space

    # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
    self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    # The parent constructor runs last, after the tables above exist on the instance.
    super().__init__(
        errors=errors,
        bos_token=bos_token,
        eos_token=eos_token,
        unk_token=unk_token,
        sep_token=sep_token,
        cls_token=cls_token,
        pad_token=pad_token,
        mask_token=mask_token,
        add_prefix_space=add_prefix_space,
        **kwargs,
    )
@property
def vocab_size(self):
    """Number of entries in the base token-to-id table (`self.encoder`)."""
    base_vocab = self.encoder
    return len(base_vocab)
def get_vocab(self):
    """Return a new dict combining `self.encoder` with `self.added_tokens_encoder` (added tokens win on clashes)."""
    extra_tokens = self.added_tokens_encoder
    return dict(self.encoder, **extra_tokens)
def _tokenize(self, text):
"""Tokenize a string."""
# 初始化空列表,用于存储BPE处理后的token
bpe_tokens = []
# 使用正则表达式找到所有匹配self.pat的token,并进行处理
for token in re.findall(self.pat, text):
# 将token按utf-8编码,并映射到unicode字符串,避免BPE的控制token(在我们的情况下是空格)
token = "".join(
self.byte_encoder[b] for b in token.encode("utf-8")
) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
# 将BPE处理后的token按空格分割,并加入到bpe_tokens列表中
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
# 返回处理后的token列表
return bpe_tokens
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
# 使用self.encoder获取token对应的id,若token不存在,则使用self.unk_token的id
return self.encoder.get(token, self.encoder.get(self.unk_token))
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
# 使用self.decoder获取index对应的token
return self.decoder.get(index)
def convert_tokens_to_string(self, tokens):
    """Join `tokens`, map each character back to its byte and decode the bytes as UTF-8."""
    joined = "".join(tokens)
    # Each character stands for exactly one byte in `self.byte_decoder`.
    raw_bytes = bytearray(self.byte_decoder[char] for char in joined)
    return raw_bytes.decode("utf-8", errors=self.errors)
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
    """
    Write the vocabulary JSON and the merges file into `save_directory`.

    Args:
        save_directory: Existing directory to write into.
        filename_prefix: Optional prefix; when given, `"<prefix>-"` is prepended to both file names.

    Returns:
        `(vocab_file, merge_file)` paths, or `None` (after logging an error) when `save_directory`
        is not a directory.
    """
    if not os.path.isdir(save_directory):
        logger.error(f"Vocabulary path ({save_directory}) should be a directory")
        return

    name_prefix = filename_prefix + "-" if filename_prefix else ""
    vocab_file = os.path.join(save_directory, name_prefix + VOCAB_FILES_NAMES["vocab_file"])
    merge_file = os.path.join(save_directory, name_prefix + VOCAB_FILES_NAMES["merges_file"])

    with open(vocab_file, "w", encoding="utf-8") as f:
        f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

    expected_rank = 0
    with open(merge_file, "w", encoding="utf-8") as writer:
        writer.write("#version: 0.2\n")
        # Merges are written in rank order; a gap in the ranks triggers a warning.
        for merge_pair, rank in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
            if expected_rank != rank:
                logger.warning(
                    f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                    " Please check that the tokenizer is not corrupted!"
                )
                expected_rank = rank
            writer.write(" ".join(merge_pair) + "\n")
            expected_rank += 1

    return vocab_file, merge_file
def build_inputs_with_special_tokens(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """
    Wrap one sequence, or a pair of sequences, in the special tokens. The resulting layouts are:

    - single sequence: `<s> X </s>`
    - pair of sequences: `<s> A </s></s> B </s>`

    Args:
        token_ids_0 (`List[int]`):
            Ids of the first (or only) sequence.
        token_ids_1 (`List[int]`, *optional*):
            Ids of the second sequence, when building a pair.

    Returns:
        `List[int]`: The ids with `cls_token_id` / `sep_token_id` inserted as shown above.
    """
    first_segment = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
    if token_ids_1 is None:
        return first_segment
    # A pair uses a doubled separator between the two segments.
    return first_segment + [self.sep_token_id] + token_ids_1 + [self.sep_token_id]
def get_special_tokens_mask(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
    """
    Retrieve the special-token mask for a token list that has no special tokens added yet. The mask
    follows the layout `<s> X </s>` for a single sequence and `<s> A </s></s> B </s>` for a pair.

    Args:
        token_ids_0 (`List[int]`):
            List of IDs.
        token_ids_1 (`List[int]`, *optional*):
            Optional second list of IDs for sequence pairs.
        already_has_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not the token list is already formatted with special tokens for the model.

    Returns:
        `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
    """
    # Already-formatted input is handled by the parent class.
    if already_has_special_tokens:
        return super().get_special_tokens_mask(
            token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
        )

    # Single sequence: one special token on each side.
    if token_ids_1 is None:
        return [1] + ([0] * len(token_ids_0)) + [1]
    # Pair: the two adjacent 1s in the middle are the doubled separator.
    return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
def create_token_type_ids_from_sequences(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """
    Produce the token-type-id list for a sequence or a sequence pair. BART does not make use of
    token type ids, so the result is all zeros, with one entry per position of the special-token layout.

    Args:
        token_ids_0 (`List[int]`):
            List of IDs.
        token_ids_1 (`List[int]`, *optional*):
            Optional second list of IDs for sequence pairs.

    Returns:
        `List[int]`: List of zeros.
    """
    # Assemble the full layout first; only its length is used.
    layout = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
    if token_ids_1 is not None:
        layout = layout + [self.sep_token_id] + token_ids_1 + [self.sep_token_id]
    return [0] * len(layout)
def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
    """
    Optionally prepend a single space to `text` before tokenization.

    A space is added when the text is pre-split into words or a prefix space is requested (via the
    `add_prefix_space` keyword, defaulting to `self.add_prefix_space`), provided the text is
    non-empty and does not already start with whitespace.

    Args:
        text (str): The input text to be tokenized.
        is_split_into_words (bool, optional): Whether the text is already split into words.
        **kwargs: Additional keyword arguments; `add_prefix_space` is consumed here.

    Returns:
        tuple: `(text, kwargs)` with the possibly modified text and the remaining keyword arguments.
    """
    add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
    prefix_requested = is_split_into_words or add_prefix_space
    if prefix_requested and len(text) > 0 and not text[0].isspace():
        text = " " + text
    return (text, kwargs)
.\models\bart\tokenization_bart_fast.py
import json
from typing import List, Optional, Tuple
from tokenizers import pre_tokenizers, processors
from ...tokenization_utils_base import AddedToken, BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
from .tokenization_bart import BartTokenizer
logger = logging.get_logger(__name__)
# File names under which each tokenizer resource is stored.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}

# Download location of every resource, keyed first by resource kind and then by checkpoint name.
# See all BART models at https://huggingface.co/models?filter=bart
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "facebook/bart-base": "https://huggingface.co/facebook/bart-base/resolve/main/vocab.json",
        "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/vocab.json",
        "facebook/bart-large-mnli": "https://huggingface.co/facebook/bart-large-mnli/resolve/main/vocab.json",
        "facebook/bart-large-cnn": "https://huggingface.co/facebook/bart-large-cnn/resolve/main/vocab.json",
        "facebook/bart-large-xsum": "https://huggingface.co/facebook/bart-large-xsum/resolve/main/vocab.json",
        "yjernite/bart_eli5": "https://huggingface.co/yjernite/bart_eli5/resolve/main/vocab.json",
    },
    "merges_file": {
        "facebook/bart-base": "https://huggingface.co/facebook/bart-base/resolve/main/merges.txt",
        "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/merges.txt",
        "facebook/bart-large-mnli": "https://huggingface.co/facebook/bart-large-mnli/resolve/main/merges.txt",
        "facebook/bart-large-cnn": "https://huggingface.co/facebook/bart-large-cnn/resolve/main/merges.txt",
        "facebook/bart-large-xsum": "https://huggingface.co/facebook/bart-large-xsum/resolve/main/merges.txt",
        "yjernite/bart_eli5": "https://huggingface.co/yjernite/bart_eli5/resolve/main/merges.txt",
    },
    "tokenizer_file": {
        "facebook/bart-base": "https://huggingface.co/facebook/bart-base/resolve/main/tokenizer.json",
        "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/tokenizer.json",
        "facebook/bart-large-mnli": "https://huggingface.co/facebook/bart-large-mnli/resolve/main/tokenizer.json",
        "facebook/bart-large-cnn": "https://huggingface.co/facebook/bart-large-cnn/resolve/main/tokenizer.json",
        "facebook/bart-large-xsum": "https://huggingface.co/facebook/bart-large-xsum/resolve/main/tokenizer.json",
        "yjernite/bart_eli5": "https://huggingface.co/yjernite/bart_eli5/resolve/main/tokenizer.json",
    },
}

# Size of the positional embeddings of each checkpoint; exposed on the tokenizer as `max_model_input_sizes`.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "facebook/bart-base": 1024,
    "facebook/bart-large": 1024,
    "facebook/bart-large-mnli": 1024,
    "facebook/bart-large-cnn": 1024,
    "facebook/bart-large-xsum": 1024,
    "yjernite/bart_eli5": 1024,
}
class BartTokenizerFast(PreTrainedTokenizerFast):
    r"""
    Construct a "fast" BART tokenizer (backed by HuggingFace's *tokenizers* library), derived from the GPT-2
    tokenizer, using byte-level Byte-Pair-Encoding.

    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word
    will be encoded differently whether it is at the beginning of the sentence (without space) or not:

    ```python
    >>> from transformers import BartTokenizerFast

    >>> tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-base")
    >>> tokenizer("Hello world")["input_ids"]
    [0, 31414, 232, 2]

    >>> tokenizer(" Hello world")["input_ids"]
    [0, 20920, 232, 2]
    ```

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when
    you call it on some text, but since the model was not pretrained this way, it might yield a decrease in
    performance.

    <Tip>

    When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
        errors (`str`, *optional*, defaults to `"replace"`):
            Paradigm to follow when decoding bytes to UTF-8. See
            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
            token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        sep_token (`str`, *optional*, defaults to `"</s>"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        cls_token (`str`, *optional*, defaults to `"<s>"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        mask_token (`str`, *optional*, defaults to `"<mask>"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
            other word. (BART tokenizer detects beginning of words by the preceding space).
        trim_offsets (`bool`, *optional*, defaults to `True`):
            Whether the post processing step should trim offsets to avoid including whitespaces.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]
    # Slow counterpart of this tokenizer, mirroring what the other fast tokenizers in this package declare.
    slow_tokenizer_class = BartTokenizer

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        errors="replace",
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        add_prefix_space=False,
        trim_offsets=True,
        **kwargs,
    ):
        """
        Build the tokenizer and align the backend components with `add_prefix_space` / `trim_offsets`.

        After the parent constructor has created the backend tokenizer, its pre-tokenizer and post-processor are
        rebuilt whenever their serialized settings disagree with the values requested here.
        """
        # A plain-string mask token is wrapped so that it absorbs the whitespace on its left and is flagged as
        # a normalized special token; an `AddedToken` passed by the caller is used untouched.
        mask_token = (
            AddedToken(mask_token, lstrip=True, normalized=True, special=True)
            if isinstance(mask_token, str)
            else mask_token
        )
        super().__init__(
            vocab_file,
            merges_file,
            tokenizer_file=tokenizer_file,
            errors=errors,
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            unk_token=unk_token,
            pad_token=pad_token,
            mask_token=mask_token,
            add_prefix_space=add_prefix_space,
            trim_offsets=trim_offsets,
            **kwargs,
        )

        # Re-create the pre-tokenizer if its stored `add_prefix_space` differs from the requested value.
        pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
        if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
            pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
            pre_tok_state["add_prefix_space"] = add_prefix_space
            self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)

        self.add_prefix_space = add_prefix_space

        # Same treatment for the post-processor, which may carry both `add_prefix_space` and `trim_offsets`.
        tokenizer_component = "post_processor"
        tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
        if tokenizer_component_instance:
            state = json.loads(tokenizer_component_instance.__getstate__())

            # JSON decoding yields lists for `sep` and `cls`; they are turned back into tuples before the
            # state is fed to the processor constructor (presumably because it expects tuples).
            if "sep" in state:
                state["sep"] = tuple(state["sep"])
            if "cls" in state:
                state["cls"] = tuple(state["cls"])

            changes_to_apply = False

            if state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
                state["add_prefix_space"] = add_prefix_space
                changes_to_apply = True

            if state.get("trim_offsets", trim_offsets) != trim_offsets:
                state["trim_offsets"] = trim_offsets
                changes_to_apply = True

            # Only rebuild the component when one of the settings actually changed.
            if changes_to_apply:
                component_class = getattr(processors, state.pop("type"))
                new_value = component_class(**state)
                setattr(self.backend_tokenizer, tokenizer_component, new_value)

    @property
    def mask_token(self) -> str:
        """
        `str`: Mask token, to use when training a model with masked-language modeling. Logs an error (when
        `verbose` is set) and returns `None` if it has not been set.

        The BART tokenizer has a special mask token meant for the fill-mask pipeline. The mask token greedily
        includes the space before the *<mask>*.
        """
        if self._mask_token is None:
            if self.verbose:
                logger.error("Using mask_token, but it is not set yet.")
            return None
        return str(self._mask_token)

    @mask_token.setter
    def mask_token(self, value):
        """
        Override the default behavior of the mask token so that it eats the space before it.

        This is needed to preserve backward compatibility with all the previously used models based on BART.
        """
        # A plain string is wrapped so the token strips whitespace on its left but not on its right.
        value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value
        self._mask_token = value

    def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
        """Reject pre-tokenized batches unless `add_prefix_space` is set, then defer to the parent class."""
        is_split_into_words = kwargs.get("is_split_into_words", False)

        if is_split_into_words and not self.add_prefix_space:
            raise ValueError(
                f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
                "to use it with pretokenized inputs."
            )

        return super()._batch_encode_plus(*args, **kwargs)

    def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
        """Reject pre-tokenized input unless `add_prefix_space` is set, then defer to the parent class."""
        is_split_into_words = kwargs.get("is_split_into_words", False)

        if is_split_into_words and not self.add_prefix_space:
            raise ValueError(
                f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
                "to use it with pretokenized inputs."
            )

        return super()._encode_plus(*args, **kwargs)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the vocabulary into `save_directory`.

        Delegates to the `save` method of the backend tokenizer model and returns the saved file paths as a tuple.
        """
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(files)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs by adding special tokens around the sequence(s).

        `token_ids_0` is wrapped as `bos A eos`. When `token_ids_1` is given, it is appended as `eos B eos`,
        giving `bos A eos eos B eos`.
        """
        output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        if token_ids_1 is None:
            return output

        return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create the token type IDs for a sequence or a sequence pair. BART does not make use of token type ids,
        therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs of the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                List of IDs of the second sequence, for sequence pairs.

        Returns:
            `List[int]`: List of zeros whose length matches the inputs once special tokens are added.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        if token_ids_1 is None:
            # Single sequence: cls + A + sep.
            return len(cls + token_ids_0 + sep) * [0]
        # Sequence pair: cls + A + sep + sep + B + sep.
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
.\models\bart\__init__.py
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)
# Maps each submodule name to the public names it provides. The table is handed to `_LazyModule` at the
# bottom of this file. Entries for optional backends are added below only when the matching
# `is_*_available()` check succeeds.
# NOTE(review): indentation was lost in the listing this was taken from; the nesting shown here is
# reconstructed from the statement order.
_import_structure = {
    "configuration_bart": ["BART_PRETRAINED_CONFIG_ARCHIVE_MAP", "BartConfig", "BartOnnxConfig"],
    "tokenization_bart": ["BartTokenizer"],
}

# Fast tokenizer: registered only when `is_tokenizers_available()` is true.
try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tokenization_bart_fast"] = ["BartTokenizerFast"]

# PyTorch models: registered only when `is_torch_available()` is true.
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_bart"] = [
        "BART_PRETRAINED_MODEL_ARCHIVE_LIST",
        "BartForCausalLM",
        "BartForConditionalGeneration",
        "BartForQuestionAnswering",
        "BartForSequenceClassification",
        "BartModel",
        "BartPreTrainedModel",
        "BartPretrainedModel",
        "PretrainedBartModel",
    ]

# TensorFlow models: registered only when `is_tf_available()` is true.
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_tf_bart"] = [
        "TFBartForConditionalGeneration",
        "TFBartForSequenceClassification",
        "TFBartModel",
        "TFBartPretrainedModel",
    ]

# Flax models: registered only when `is_flax_available()` is true.
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_flax_bart"] = [
        "FlaxBartDecoderPreTrainedModel",
        "FlaxBartForCausalLM",
        "FlaxBartForConditionalGeneration",
        "FlaxBartForQuestionAnswering",
        "FlaxBartForSequenceClassification",
        "FlaxBartModel",
        "FlaxBartPreTrainedModel",
    ]

# Under static type checking the real imports are spelled out, guarded by the same availability checks
# as the table above, so the two stay in step.
if TYPE_CHECKING:
    from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig, BartOnnxConfig
    from .tokenization_bart import BartTokenizer

    try:
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_bart_fast import BartTokenizerFast

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_bart import (
            BART_PRETRAINED_MODEL_ARCHIVE_LIST,
            BartForCausalLM,
            BartForConditionalGeneration,
            BartForQuestionAnswering,
            BartForSequenceClassification,
            BartModel,
            BartPreTrainedModel,
            BartPretrainedModel,
            PretrainedBartModel,
        )

    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_tf_bart import (
            TFBartForConditionalGeneration,
            TFBartForSequenceClassification,
            TFBartModel,
            TFBartPretrainedModel,
        )

    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_flax_bart import (
            FlaxBartDecoderPreTrainedModel,
            FlaxBartForCausalLM,
            FlaxBartForConditionalGeneration,
            FlaxBartForQuestionAnswering,
            FlaxBartForSequenceClassification,
            FlaxBartModel,
            FlaxBartPreTrainedModel,
        )

else:
    import sys

    # At runtime this module object is replaced in `sys.modules` by a `_LazyModule` built from the table above.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\barthez\tokenization_barthez.py
""" BARThez 模型的分词类。"""
import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple
import sentencepiece as spm
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging
logger = logging.get_logger(__name__)
# File name under which the SentencePiece model is stored.
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}

# Download location of the SentencePiece model of each checkpoint.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "moussaKam/mbarthez": "https://huggingface.co/moussaKam/mbarthez/resolve/main/sentencepiece.bpe.model",
        "moussaKam/barthez": "https://huggingface.co/moussaKam/barthez/resolve/main/sentencepiece.bpe.model",
        "moussaKam/barthez-orangesum-title": (
            "https://huggingface.co/moussaKam/barthez-orangesum-title/resolve/main/sentencepiece.bpe.model"
        ),
    },
}

# Every listed checkpoint has positional embeddings of size 1024.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    checkpoint: 1024 for checkpoint in PRETRAINED_VOCAB_FILES_MAP["vocab_file"]
}

# Marker character used by SentencePiece pieces.
SPIECE_UNDERLINE = "▁"
class BarthezTokenizer(PreTrainedTokenizer):
    """
    Adapted from `CamembertTokenizer` and `BartTokenizer`. Construct a BARThez tokenizer. Based on
    [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from `PreTrainedTokenizer` which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Attributes:
        sp_model (`SentencePieceProcessor`):
            The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        """
        Initialize a new `BarthezTokenizer`.

        Args:
            vocab_file (`str`):
                Path to the SentencePiece vocabulary file.
            bos_token (`str`, *optional*):
                Special token used as the beginning-of-sequence token.
            eos_token (`str`, *optional*):
                Special token used as the end-of-sequence token.
            sep_token (`str`, *optional*):
                Special token used as the separator token.
            cls_token (`str`, *optional*):
                Special token used as the classifier token.
            unk_token (`str`, *optional*):
                Special token used as the unknown token.
            pad_token (`str`, *optional*):
                Special token used as the padding token.
            mask_token (`str` or `AddedToken`, *optional*):
                Special token used as the mask token. A plain string is wrapped in an `AddedToken` with
                `lstrip=True, special=True`.
            sp_model_kwargs (`dict`, *optional*):
                Extra arguments for the `SentencePieceProcessor` constructor. Defaults to an empty dict.
            **kwargs:
                Remaining arguments, forwarded to the parent constructor.
        """
        # A plain-string mask token is wrapped so that it absorbs the whitespace on its left.
        mask_token = AddedToken(mask_token, lstrip=True, special=True) if isinstance(mask_token, str) else mask_token

        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        # The SentencePiece model is loaded before the parent constructor runs.
        self.vocab_file = vocab_file
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(str(vocab_file))

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            sp_model_kwargs=self.sp_model_kwargs,
            **kwargs,
        )

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences by concatenating them and adding special
        tokens. A BARThez sequence has the following format:

        - single sequence: `<s> X </s>`
        - pair of sequences: `<s> A </s></s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of input IDs with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Return a mask marking the positions of special tokens in the built input.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model. When
                `True`, the computation is delegated to the parent class.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # The masks mirror the layouts produced by `build_inputs_with_special_tokens`.
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create the token type IDs for a sequence or a sequence pair. The result is always a list of zeros whose
        length matches the inputs once special tokens are added.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

    @property
    def vocab_size(self):
        """Number of pieces in the SentencePiece model."""
        return len(self.sp_model)

    def get_vocab(self):
        """Return the token-to-ID mapping of the SentencePiece vocabulary, updated with the added tokens."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> List[str]:
        """Split `text` into SentencePiece pieces (as strings)."""
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) into an ID using the vocabulary."""
        return self.sp_model.PieceToId(token)

    def _convert_id_to_token(self, index):
        """Converts an ID (integer) into a token (str) using the vocabulary."""
        return self.sp_model.IdToPiece(index)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings) into a single string."""
        # The special-token list does not change during the loop, so it is looked up once.
        special_tokens = self.all_special_tokens
        current_sub_tokens = []
        out_string = ""
        prev_is_special = False
        for token in tokens:
            # Special tokens are emitted verbatim instead of being decoded by the SentencePiece model.
            if token in special_tokens:
                if not prev_is_special:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string.strip()

    def __getstate__(self):
        """Return the instance state for pickling, with the SentencePiece processor replaced by `None`."""
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        """Restore the instance state and reload the SentencePiece model from `self.vocab_file`."""
        self.__dict__ = d

        # States saved without `sp_model_kwargs` fall back to an empty dict.
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Write the SentencePiece model file into `save_directory`.

        The original vocabulary file is copied when it still exists on disk; otherwise the serialized model
        held by the processor is written out.

        Args:
            save_directory (`str`):
                Existing directory in which to save the vocabulary.
            filename_prefix (`str`, *optional*):
                Prefix prepended (followed by `-`) to the saved file name.

        Returns:
            `Tuple[str]`: One-element tuple with the path of the saved file. If `save_directory` is not a
            directory, an error is logged and `None` is returned.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            # The source file is gone: dump the in-memory model instead.
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)
.\models\barthez\tokenization_barthez_fast.py
""" BARThez 模型的分词类。"""
import os
from shutil import copyfile
from typing import List, Optional, Tuple
from ...tokenization_utils import AddedToken
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import is_sentencepiece_available, logging
if is_sentencepiece_available():
from .tokenization_barthez import BarthezTokenizer
else:
BarthezTokenizer = None
logger = logging.get_logger(__name__)
# File names under which each tokenizer resource is stored.
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}

# Download location of every resource, keyed first by resource kind and then by checkpoint name.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "moussaKam/mbarthez": "https://huggingface.co/moussaKam/mbarthez/resolve/main/sentencepiece.bpe.model",
        "moussaKam/barthez": "https://huggingface.co/moussaKam/barthez/resolve/main/sentencepiece.bpe.model",
        "moussaKam/barthez-orangesum-title": (
            "https://huggingface.co/moussaKam/barthez-orangesum-title/resolve/main/sentencepiece.bpe.model"
        ),
    },
    "tokenizer_file": {
        "moussaKam/mbarthez": "https://huggingface.co/moussaKam/mbarthez/resolve/main/tokenizer.json",
        "moussaKam/barthez": "https://huggingface.co/moussaKam/barthez/resolve/main/tokenizer.json",
        "moussaKam/barthez-orangesum-title": (
            "https://huggingface.co/moussaKam/barthez-orangesum-title/resolve/main/tokenizer.json"
        ),
    },
}

# Every listed checkpoint has positional embeddings of size 1024.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    checkpoint: 1024 for checkpoint in PRETRAINED_VOCAB_FILES_MAP["vocab_file"]
}

# Marker character used by SentencePiece pieces.
SPIECE_UNDERLINE = "▁"
class BarthezTokenizerFast(PreTrainedTokenizerFast):
    """
    Adapted from `CamembertTokenizer` and `BartTokenizer`. Construct a "fast" BARThez tokenizer. Based on
    [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from `PreTrainedTokenizerFast` which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
            token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        sep_token (`str`, *optional*, defaults to `"</s>"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        cls_token (`str`, *optional*, defaults to `"<s>"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        mask_token (`str`, *optional*, defaults to `"<mask>"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        additional_special_tokens (`List[str]`, *optional*, defaults to `["<s>NOTUSED", "</s>NOTUSED"]`):
            Additional special tokens used by the tokenizer.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]
    slow_tokenizer_class = BarthezTokenizer

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        **kwargs,
    ):
        """Build the fast tokenizer and remember `vocab_file` so the slow vocabulary can be saved later."""
        # A plain-string mask token is wrapped in an `AddedToken` that strips whitespace on its left but not
        # on its right; any other value is passed through unchanged.
        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            **kwargs,
        )

        self.vocab_file = vocab_file

    @property
    def can_save_slow_tokenizer(self) -> bool:
        """`bool`: `True` only when `self.vocab_file` is set and points to an existing file."""
        return os.path.isfile(self.vocab_file) if self.vocab_file else False

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences by concatenating them and adding special
        tokens. A BARThez sequence has the following format:

        - single sequence: `<s> X </s>`
        - pair of sequences: `<s> A </s></s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of input IDs with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + sep + token_ids_1 + sep

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create the token type IDs for a sequence or a sequence pair. The result is always a list of zeros whose
        length matches the inputs once special tokens are added.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs of the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Copy the slow-tokenizer vocabulary file into `save_directory`.

        Args:
            save_directory (`str`):
                Existing directory in which to save the vocabulary.
            filename_prefix (`str`, *optional*):
                Prefix prepended (followed by `-`) to the saved file name.

        Returns:
            `Tuple[str]`: One-element tuple with the path of the saved file. If `save_directory` is not a
            directory, an error is logged and `None` is returned.

        Raises:
            ValueError: If `can_save_slow_tokenizer` is `False`.
        """
        if not self.can_save_slow_tokenizer:
            raise ValueError(
                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
                "tokenizer."
            )

        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        # Skip the copy when source and destination are the same file.
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)
.\models\barthez\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule
from ...utils import is_sentencepiece_available, is_tokenizers_available
# Maps each submodule name to the public names it provides. The table is handed to `_LazyModule` at the
# bottom of this file. It starts empty because both tokenizers depend on optional packages.
# NOTE(review): indentation was lost in the listing this was taken from; the nesting shown here is
# reconstructed from the statement order.
_import_structure = {}

# Slow tokenizer: registered only when `is_sentencepiece_available()` is true.
try:
    if not is_sentencepiece_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tokenization_barthez"] = ["BarthezTokenizer"]

# Fast tokenizer: registered only when `is_tokenizers_available()` is true.
try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tokenization_barthez_fast"] = ["BarthezTokenizerFast"]

# Under static type checking the real imports are spelled out, guarded by the same availability checks
# as the table above.
if TYPE_CHECKING:
    try:
        if not is_sentencepiece_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_barthez import BarthezTokenizer

    try:
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_barthez_fast import BarthezTokenizerFast

else:
    import sys

    # At runtime this module object is replaced in `sys.modules` by a `_LazyModule` built from the table above.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\bartpho\tokenization_bartpho.py
""" Tokenization classes for BARTpho-syllable model."""
import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple
import sentencepiece as spm
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging
logger = logging.get_logger(__name__)
# Marker character used by SentencePiece pieces.
SPIECE_UNDERLINE = "▁"

# File names under which each tokenizer resource is stored.
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "monolingual_vocab_file": "dict.txt"}

# Download location of every resource, keyed first by resource kind and then by checkpoint name.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "vinai/bartpho-syllable": "https://huggingface.co/vinai/bartpho-syllable/resolve/main/sentencepiece.bpe.model",
    },
    "monolingual_vocab_file": {
        "vinai/bartpho-syllable": "https://huggingface.co/vinai/bartpho-syllable/resolve/main/dict.txt",
    },
}

# The single listed checkpoint has positional embeddings of size 1024.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    checkpoint: 1024 for checkpoint in PRETRAINED_VOCAB_FILES_MAP["vocab_file"]
}
class BartphoTokenizer(PreTrainedTokenizer):
"""
自 [`XLMRobertaTokenizer`] 改编。基于 [SentencePiece](https://github.com/google/sentencepiece)。
此分词器继承自 [`PreTrainedTokenizer`],包含大多数主要方法。用户应参考超类以获取更多有关这些方法的信息。
Attributes:
sp_model (`SentencePieceProcessor`):
每次转换(字符串、标记和 ID)都使用的 SentencePiece 处理器。
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
    self,
    vocab_file,
    monolingual_vocab_file,
    bos_token="<s>",
    eos_token="</s>",
    sep_token="</s>",
    cls_token="<s>",
    unk_token="<unk>",
    pad_token="<pad>",
    mask_token="<mask>",
    sp_model_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> None:
    """
    Load the SentencePiece model and build the reduced token/ID tables from the monolingual vocabulary.

    Args:
        vocab_file (`str`):
            Path to the SentencePiece model file.
        monolingual_vocab_file (`str`):
            Path to a UTF-8 text file whose first whitespace-separated field on each line is a token. Blank
            lines are ignored.
        bos_token, eos_token, sep_token, cls_token, unk_token, pad_token (`str`, *optional*):
            Special tokens. They receive the lowest IDs, in the order bos, pad, eos, unk, sep, cls, with
            duplicates counted once.
        mask_token (`str` or `AddedToken`, *optional*):
            Mask token. A plain string is wrapped in an `AddedToken` with `lstrip=True, rstrip=False`. It is
            given the next free ID unless it is already present.
        sp_model_kwargs (`dict`, *optional*):
            Extra arguments for the `SentencePieceProcessor` constructor. Defaults to an empty dict.
        **kwargs:
            Remaining arguments, forwarded to the parent constructor.
    """
    # A plain-string mask token is wrapped so that it absorbs the whitespace on its left.
    mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

    self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

    self.vocab_file = vocab_file
    self.monolingual_vocab_file = monolingual_vocab_file
    self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
    self.sp_model.Load(str(vocab_file))

    # Special tokens come first. Each new token gets the next free ID, i.e. the current table size.
    self.fairseq_tokens_to_ids = {}
    for token in [bos_token, pad_token, eos_token, unk_token, sep_token, cls_token]:
        if str(token) not in self.fairseq_tokens_to_ids:
            self.fairseq_tokens_to_ids[str(token)] = len(self.fairseq_tokens_to_ids)

    # Then the monolingual vocabulary, one token per line.
    with open(monolingual_vocab_file, "r", encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # A blank line carries no token; skipping it avoids an IndexError.
                continue
            self.fairseq_tokens_to_ids[fields[0]] = len(self.fairseq_tokens_to_ids)

    if str(mask_token) not in self.fairseq_tokens_to_ids:
        self.fairseq_tokens_to_ids[str(mask_token)] = len(self.fairseq_tokens_to_ids)

    # Reverse table for ID -> token lookups.
    self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}

    super().__init__(
        bos_token=bos_token,
        eos_token=eos_token,
        unk_token=unk_token,
        sep_token=sep_token,
        cls_token=cls_token,
        pad_token=pad_token,
        mask_token=mask_token,
        sp_model_kwargs=self.sp_model_kwargs,
        **kwargs,
    )
def __getstate__(self):
state = self.__dict__.copy()
state["sp_model"] = None
state["sp_model_proto"] = self.sp_model.serialized_model_proto()
return state
def __setstate__(self, d):
self.__dict__ = d
if not hasattr(self, "sp_model_kwargs"):
self.sp_model_kwargs = {}
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
adding special tokens. An BARTPho sequence has the following format:
- single sequence: `<s> X </s>`
- pair of sequences: `<s> A </s></s> B </s>`
Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of input IDs with the appropriate special tokens.
"""
if token_ids_1 is None:
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
cls = [self.cls_token_id]
sep = [self.sep_token_id]
return cls + token_ids_0 + sep + sep + token_ids_1 + sep
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieve sequence IDs from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
`List[int]`: A list of integers indicating the presence of special tokens (1) or sequence tokens (0).
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)
if token_ids_1 is None:
return [1] + ([0] * len(token_ids_0)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
):
"""
Create token type IDs tensor from sequences for sequence classification tasks. This method assigns each token in the input
sequences a token type ID (0 or 1) depending on whether it belongs to the first or the second sequence.
Args:
token_ids_0 (`List[int]`):
List of IDs for the first sequence.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: A list of token type IDs where each ID corresponds to the respective input token.
"""
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. BARTPho does not
make use of token type ids, therefore a list of zeros is returned.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of zeros.
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
@property
def vocab_size(self):
return len(self.fairseq_ids_to_tokens)
def get_vocab(self):
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
def _tokenize(self, text: str) -> List[str]:
return self.sp_model.encode(text, out_type=str)
def _convert_token_to_id(self, token):
"""Converts a token (str) into an ID using the vocabulary."""
if token in self.fairseq_tokens_to_ids:
return self.fairseq_tokens_to_ids[token]
else:
return self.unk_token_id
def _convert_id_to_token(self, index):
"""Converts an index (integer) into a token (str) using the vocabulary."""
return self.fairseq_ids_to_tokens[index]
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (strings for sub-words) into a single string."""
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
return out_string
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
out_monolingual_vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["monolingual_vocab_file"],
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
if os.path.abspath(self.monolingual_vocab_file) != os.path.abspath(
out_monolingual_vocab_file
) and os.path.isfile(self.monolingual_vocab_file):
copyfile(self.monolingual_vocab_file, out_monolingual_vocab_file)
elif not os.path.isfile(self.monolingual_vocab_file):
with open(out_monolingual_vocab_file, "w", encoding="utf-8") as fp:
for token in self.fairseq_tokens_to_ids:
if token not in self.all_special_tokens:
fp.write(f"{str(token)} \n")
return out_vocab_file, out_monolingual_vocab_file