# Transformers source-code walkthrough (part 66)
# File: .\models\longformer\modeling_tf_longformer.py
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...modeling_tf_utils import (
TFMaskedLanguageModelingLoss,
TFModelInputType,
TFMultipleChoiceLoss,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
TFSequenceClassificationLoss,
TFTokenClassificationLoss,
get_initializer,
keras,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
from .configuration_longformer import LongformerConfig
# Module-level logger obtained from the library's logging helper.
logger = logging.get_logger(__name__)
# Checkpoint / config names. Presumably consumed by the docstring decorators imported above;
# no use is visible in this part of the file — TODO confirm.
_CHECKPOINT_FOR_DOC = "allenai/longformer-base-4096"
_CONFIG_FOR_DOC = "LongformerConfig"
# NOTE(review): not referenced in the visible part of the file; presumably an additive
# attention-mask value used further down — confirm.
LARGE_NEGATIVE = -1e8
# Identifiers of the pretrained Longformer checkpoints.
TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "allenai/longformer-base-4096",
    "allenai/longformer-large-4096",
    "allenai/longformer-large-4096-finetuned-triviaqa",
    "allenai/longformer-base-4096-extra.pos.embd.only",
    "allenai/longformer-large-4096-extra.pos.embd.only",
]
@dataclass
class TFLongformerBaseModelOutput(ModelOutput):
    """
    Base class for Longformer's outputs, with potential hidden states, local and global attentions.

    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of per-layer local attention weights of shape `(batch_size, num_heads, sequence_length, x +
            attention_window + 1)`, where `x` is the number of tokens with global attention mask.

            Local attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token in the sequence to every token with
            global attention (first `x` values) and to every token in the attention window (remaining
            `attention_window + 1` values). Note that the first `x` values refer to tokens with fixed positions in the
            text, but the remaining `attention_window + 1` values refer to tokens with relative positions: the
            attention weight of a token to itself is located at index `x + attention_window / 2` and the
            `attention_window / 2` preceding (succeeding) values are the attention weights to the preceding
            (succeeding) tokens. If the attention window contains a token with global attention, the attention weight
            at the corresponding index is set to 0; the value should be accessed from the first `x` attention weights.
            If a token has global attention, the attention weights to all other tokens in `attentions` is set to 0,
            the values should be accessed from `global_attentions`.
        global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of per-layer global attention weights of shape `(batch_size, num_heads, sequence_length, x)`, where
            `x` is the number of tokens with global attention mask.

            Global attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token with global attention to every
            token in the sequence.
    """

    # Fields follow the order in which they are documented above; every field defaults to None,
    # matching the other output dataclasses in this file.
    last_hidden_state: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor, ...] | None = None
    attentions: Tuple[tf.Tensor, ...] | None = None
    global_attentions: Tuple[tf.Tensor, ...] | None = None
@dataclass
class TFLongformerBaseModelOutputWithPooling(ModelOutput):
    """
    Base class for Longformer's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`tf.Tensor`):
            Hidden states of the last layer.
        pooler_output (`tf.Tensor`):
            Pooled representation. Presumably the result of `TFLongformerPooler` (first token passed through a
            tanh-activated dense layer) — TODO confirm against the caller that fills this field.
        hidden_states (`tuple(tf.Tensor)`, *optional*):
            Per-layer hidden states; `None` unless requested.
        attentions (`tuple(tf.Tensor)`, *optional*):
            Per-layer local attention weights; `None` unless requested.
        global_attentions (`tuple(tf.Tensor)`, *optional*):
            Per-layer global attention weights; `None` unless requested.
    """

    last_hidden_state: tf.Tensor = None
    pooler_output: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor, ...] | None = None
    attentions: Tuple[tf.Tensor, ...] | None = None
    global_attentions: Tuple[tf.Tensor, ...] | None = None
@dataclass
class TFLongformerMaskedLMOutput(ModelOutput):
    """
    Base class for masked language models outputs.

    Args:
        loss (`tf.Tensor`, *optional*):
            Presumably the masked language modeling loss, set only when labels are supplied — TODO confirm.
        logits (`tf.Tensor`):
            Presumably the prediction scores produced by the language modeling head — TODO confirm.
        hidden_states (`tuple(tf.Tensor)`, *optional*):
            Per-layer hidden states; `None` unless requested.
        attentions (`tuple(tf.Tensor)`, *optional*):
            Per-layer local attention weights; `None` unless requested.
        global_attentions (`tuple(tf.Tensor)`, *optional*):
            Per-layer global attention weights; `None` unless requested.
    """

    loss: tf.Tensor | None = None
    logits: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor, ...] | None = None
    attentions: Tuple[tf.Tensor, ...] | None = None
    global_attentions: Tuple[tf.Tensor, ...] | None = None
@dataclass
class TFLongformerQuestionAnsweringModelOutput(ModelOutput):
    """
    Base class for outputs of question answering Longformer models.

    Args:
        loss (`tf.Tensor`, *optional*):
            Presumably the span-extraction loss, set only when start/end labels are supplied — TODO confirm.
        start_logits (`tf.Tensor`):
            Presumably the span-start scores (before SoftMax) — TODO confirm.
        end_logits (`tf.Tensor`):
            Presumably the span-end scores (before SoftMax) — TODO confirm.
        hidden_states (`tuple(tf.Tensor)`, *optional*):
            Per-layer hidden states; `None` unless requested.
        attentions (`tuple(tf.Tensor)`, *optional*):
            Per-layer local attention weights; `None` unless requested.
        global_attentions (`tuple(tf.Tensor)`, *optional*):
            Per-layer global attention weights; `None` unless requested.
    """

    loss: tf.Tensor | None = None
    start_logits: tf.Tensor = None
    end_logits: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor, ...] | None = None
    attentions: Tuple[tf.Tensor, ...] | None = None
    global_attentions: Tuple[tf.Tensor, ...] | None = None
@dataclass
class TFLongformerSequenceClassifierOutput(ModelOutput):
    """
    Base class for outputs of sentence classification models.

    NOTE(review): only `loss` is declared here, whereas the sibling output classes in this file also declare
    `logits`, `hidden_states`, `attentions` and `global_attentions` — confirm whether fields were dropped.

    Args:
        loss (`tf.Tensor`, *optional*):
            Presumably the classification loss, set only when labels are supplied — TODO confirm.
    """

    loss: tf.Tensor | None = None
@dataclass
class TFLongformerMultipleChoiceModelOutput(ModelOutput):
    """
    Base class for outputs of multiple choice models.

    Args:
        loss (`tf.Tensor` of shape *(1,)*, *optional*, returned when `labels` is provided):
            Classification loss.
        logits (`tf.Tensor` of shape `(batch_size, num_choices)`):
            *num_choices* is the second dimension of the input tensors. (see *input_ids* above).

            Classification scores (before SoftMax).
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` of shape `(batch_size, num_heads, sequence_length, x + attention_window + 1)`, where
            `x` is the number of tokens with global attention mask.

            Local attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token in the sequence to every token with
            global attention (first `x` values) and to every token in the attention window (remaining
            `attention_window + 1` values). Note that the first `x` values refer to tokens with fixed positions in the
            text, but the remaining `attention_window + 1` values refer to tokens with relative positions: the
            attention weight of a token to itself is located at index `x + attention_window / 2` and the
            `attention_window / 2` preceding (succeeding) values are the attention weights to the
            `attention_window / 2` preceding (succeeding) tokens. If the attention window contains a token with global
            attention, the attention weight at the corresponding index is set to 0; the value should be accessed from
            the first `x` attention weights. If a token has global attention, the attention weights to all other
            tokens in `attentions` is set to 0, the values should be accessed from `global_attentions`.
        global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` of shape `(batch_size, num_heads, sequence_length, x)`, where `x` is the number of
            tokens with global attention mask.

            Global attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token with global attention to every
            token in the sequence.
    """

    # `loss` is documented above and comes first, as in the other loss-carrying output classes of this file.
    loss: tf.Tensor | None = None
    logits: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor, ...] | None = None
    attentions: Tuple[tf.Tensor, ...] | None = None
    global_attentions: Tuple[tf.Tensor, ...] | None = None
@dataclass
class TFLongformerTokenClassifierOutput(ModelOutput):
    """
    Base class for outputs of token classification models.

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided) :
            Classification loss.
        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
            attention_window + 1)`, where `x` is the number of tokens with global attention mask.

            Local attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token in the sequence to every token with
            global attention (first `x` values) and to every token in the attention window (remaining `attention_window
            + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
            remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
            token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
            (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
            If the attention window contains a token with global attention, the attention weight at the corresponding
            index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
            attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be
            accessed from `global_attentions`.
        global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
            is the number of tokens with global attention mask.

            Global attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token with global attention to every token
            in the sequence.
    """

    loss: tf.Tensor | None = None
    logits: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor, ...] | None = None
    attentions: Tuple[tf.Tensor, ...] | None = None
    global_attentions: Tuple[tf.Tensor, ...] | None = None
def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_sep_token=True):
    """
    Build a 0/1 mask over sequence positions relative to the first separator token of every sample.

    Args:
        input_ids_shape: Shape of the input ids; index 0 is used as the batch size and index 1 (and -1) as the
            sequence length.
        sep_token_indices: Tensor with exactly two columns. It is reshaped to `(batch_size, 3, 2)`, so three rows
            per sample are expected; the second column of the first row of each sample is taken as the position of
            the first separator. (The first column is presumably the batch index — TODO confirm with the caller.)
        before_sep_token: If `True`, positions strictly before the first separator are set to 1. Otherwise positions
            strictly greater than `first_separator + 1` (and smaller than the sequence length) are set to 1.

    Returns:
        Tensor of shape `(batch_size, sequence_length)` with the dtype of `sep_token_indices`, holding 1 at selected
        positions and 0 elsewhere.
    """
    assert shape_list(sep_token_indices)[1] == 2, "`input_ids` should have two dimensions"
    # Position of the first of the three separators of every sample, kept as a column vector.
    question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1][:, None]
    # Row vector 0..seq_len-1 repeated for every sample.
    attention_mask = tf.expand_dims(tf.range(input_ids_shape[1], dtype=tf.int64), axis=0)
    attention_mask = tf.tile(attention_mask, (input_ids_shape[0], 1))
    if before_sep_token is True:
        question_end_index = tf.tile(question_end_index, (1, input_ids_shape[1]))
        attention_mask = tf.cast(attention_mask < question_end_index, dtype=question_end_index.dtype)
    else:
        question_end_index = tf.tile(question_end_index + 1, (1, input_ids_shape[1]))
        # Combine the two boolean conditions with a logical AND: multiplying boolean tensors with `*`
        # is not supported by TensorFlow, so the conjunction has to be formed before the cast.
        attention_mask = tf.cast(
            tf.math.logical_and(attention_mask > question_end_index, attention_mask < input_ids_shape[-1]),
            dtype=question_end_index.dtype,
        )
    return attention_mask
class TFLongformerLMHead(keras.layers.Layer):
    """
    Longformer Head for masked language modeling.

    Applies dense -> GELU -> layer norm, then projects onto the vocabulary with the (transposed) weight of the
    `input_embeddings` layer and adds a separate output bias.
    """

    def __init__(self, config, input_embeddings, **kwargs):
        """
        Args:
            config: Model configuration; `hidden_size`, `initializer_range`, `layer_norm_eps` and `vocab_size` are read.
            input_embeddings: Layer exposing a `weight` attribute that is reused as the output projection matrix.
        """
        super().__init__(**kwargs)
        self.config = config
        self.hidden_size = config.hidden_size
        self.dense = keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
        self.act = get_tf_activation("gelu")
        # The output projection shares its weight with the input embeddings; only the bias is owned by this head.
        self.decoder = input_embeddings

    def build(self, input_shape=None):
        """Create the output bias and build the sub-layers; later calls are no-ops."""
        # Guard first so that a repeated `build` call cannot register a second bias weight.
        if self.built:
            return
        self.built = True
        self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
        if getattr(self, "layer_norm", None) is not None:
            with tf.name_scope(self.layer_norm.name):
                self.layer_norm.build([None, None, self.config.hidden_size])

    def get_output_embeddings(self):
        """Return the layer whose weight is used for the output projection."""
        return self.decoder

    def set_output_embeddings(self, value):
        """Replace the projection weight and update the decoder's vocabulary size from its first dimension."""
        self.decoder.weight = value
        self.decoder.vocab_size = shape_list(value)[0]

    def get_bias(self):
        """Return the output bias wrapped in a dict under the key `"bias"`."""
        return {"bias": self.bias}

    def set_bias(self, value):
        """Replace the output bias from `value["bias"]` and update `config.vocab_size` to its length."""
        self.bias = value["bias"]
        self.config.vocab_size = shape_list(value["bias"])[0]

    def call(self, hidden_states):
        """
        Map hidden states to vocabulary scores.

        Args:
            hidden_states: Tensor whose second dimension is the sequence length and whose last dimension is
                `hidden_size`.

        Returns:
            Tensor of shape `(-1, seq_length, vocab_size)`.
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.act(hidden_states)
        # The normalization layer is created and built above; it has to be applied before the projection.
        hidden_states = self.layer_norm(hidden_states)

        # Project back to the size of the vocabulary (flatten -> matmul with the tied weight -> restore) and add bias.
        seq_length = shape_list(tensor=hidden_states)[1]
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size])
        hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True)
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
        hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)
        return hidden_states
class TFLongformerEmbeddings(keras.layers.Layer):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing and some extra casting.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        # Token id treated as padding when position ids are derived from `input_ids`.
        self.padding_idx = 1
        self.config = config
        self.hidden_size = config.hidden_size
        self.max_position_embeddings = config.max_position_embeddings
        self.initializer_range = config.initializer_range
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)

    def build(self, input_shape=None):
        """Create the word, token-type and position embedding tables and build the layer norm; idempotent."""
        # Guard first so that a repeated `build` call cannot register the embedding tables twice.
        if self.built:
            return
        self.built = True

        with tf.name_scope("word_embeddings"):
            self.weight = self.add_weight(
                name="weight",
                shape=[self.config.vocab_size, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )

        with tf.name_scope("token_type_embeddings"):
            self.token_type_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.config.type_vocab_size, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )

        with tf.name_scope("position_embeddings"):
            self.position_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.max_position_embeddings, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )

        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])

    def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0):
        """
        Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
        symbols are ignored. This is modified from fairseq's `utils.make_positions`.

        Args:
            input_ids: tf.Tensor
            past_key_values_length: Offset added to the running position count of non-padding tokens.

        Returns: tf.Tensor
        """
        # 1 for real tokens, 0 for padding; the cumulative sum numbers the real tokens, padding stays at 0.
        mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype)
        incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask

        return incremental_indices + self.padding_idx

    def call(
        self,
        input_ids=None,
        position_ids=None,
        token_type_ids=None,
        inputs_embeds=None,
        past_key_values_length=0,
        training=False,
    ):
        """
        Applies embedding based on inputs tensor.

        Args:
            input_ids: Token ids; looked up in the word embedding table. Takes precedence over `inputs_embeds`.
            position_ids: Position ids; derived from `input_ids` (or as a plain range) when omitted.
            token_type_ids: Token type ids; all zeros when omitted.
            inputs_embeds: Pre-computed word embeddings, used when `input_ids` is `None`.
            past_key_values_length: Offset forwarded to `create_position_ids_from_input_ids`.
            training: Forwarded to the dropout layer.

        Returns:
            final_embeddings (`tf.Tensor`): output embedding tensor.
        """
        assert not (input_ids is None and inputs_embeds is None)

        if input_ids is not None:
            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        input_shape = shape_list(inputs_embeds)[:-1]

        if token_type_ids is None:
            token_type_ids = tf.cast(tf.fill(dims=input_shape, value=0), tf.int64)

        if position_ids is None:
            if input_ids is not None:
                # Create the position ids from the input token ids. Any padded tokens remain padded.
                position_ids = self.create_position_ids_from_input_ids(
                    input_ids=input_ids, past_key_values_length=past_key_values_length
                )
            else:
                # Without token ids padding cannot be detected: number positions consecutively after padding_idx.
                position_ids = tf.expand_dims(
                    tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1, dtype=tf.int64),
                    axis=0,
                )

        position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
        token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
        final_embeddings = inputs_embeds + position_embeds + token_type_embeds
        final_embeddings = self.LayerNorm(inputs=final_embeddings)
        final_embeddings = self.dropout(inputs=final_embeddings, training=training)

        return final_embeddings
class TFLongformerIntermediate(keras.layers.Layer):
    """Dense projection to `config.intermediate_size` followed by the configured activation function."""

    def __init__(self, config: LongformerConfig, **kwargs):
        super().__init__(**kwargs)

        kernel_init = get_initializer(config.initializer_range)
        self.dense = keras.layers.Dense(units=config.intermediate_size, kernel_initializer=kernel_init, name="dense")

        # `hidden_act` may be given either by name or directly as a callable.
        activation = config.hidden_act
        self.intermediate_act_fn = get_tf_activation(activation) if isinstance(activation, str) else activation
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        """Project `hidden_states` and apply the activation."""
        projected = self.dense(inputs=hidden_states)
        return self.intermediate_act_fn(projected)

    def build(self, input_shape=None):
        """Build the dense sub-layer for inputs whose last dimension is `config.hidden_size`; runs only once."""
        if not self.built:
            self.built = True
            dense_layer = getattr(self, "dense", None)
            if dense_layer is not None:
                with tf.name_scope(dense_layer.name):
                    dense_layer.build([None, None, self.config.hidden_size])
class TFLongformerOutput(keras.layers.Layer):
    """Projects intermediate activations back to `config.hidden_size`, applies dropout, then a residual layer norm."""

    def __init__(self, config: LongformerConfig, **kwargs):
        super().__init__(**kwargs)

        kernel_init = get_initializer(config.initializer_range)
        self.dense = keras.layers.Dense(units=config.hidden_size, kernel_initializer=kernel_init, name="dense")
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Return `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`."""
        projected = self.dense(inputs=hidden_states)
        regularized = self.dropout(inputs=projected, training=training)
        return self.LayerNorm(inputs=regularized + input_tensor)

    def build(self, input_shape=None):
        """Build the sub-layers once; the dense layer consumes `intermediate_size`, the norm `hidden_size`."""
        if self.built:
            return
        self.built = True

        # (sub-layer attribute, config attribute giving the size of its last input dimension)
        for layer_attr, width_attr in (("dense", "intermediate_size"), ("LayerNorm", "hidden_size")):
            sublayer = getattr(self, layer_attr, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build([None, None, getattr(self.config, width_attr)])
class TFLongformerPooler(keras.layers.Layer):
    """Pools a sequence by passing the hidden state at index 0 through a tanh-activated dense layer."""

    def __init__(self, config: LongformerConfig, **kwargs):
        super().__init__(**kwargs)

        dense_kwargs = {
            "units": config.hidden_size,
            "kernel_initializer": get_initializer(config.initializer_range),
            "activation": "tanh",
            "name": "dense",
        }
        self.dense = keras.layers.Dense(**dense_kwargs)
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        """Return the dense transformation of `hidden_states[:, 0]`."""
        # Only the entry at index 0 of the second axis is used for pooling.
        return self.dense(inputs=hidden_states[:, 0])

    def build(self, input_shape=None):
        """Build the dense sub-layer for inputs whose last dimension is `config.hidden_size`; runs only once."""
        if not self.built:
            self.built = True
            dense_layer = getattr(self, "dense", None)
            if dense_layer is not None:
                with tf.name_scope(dense_layer.name):
                    dense_layer.build([None, None, self.config.hidden_size])
class TFLongformerSelfOutput(keras.layers.Layer):
    """Dense projection of the attention result, followed by dropout and a residual layer norm."""

    def __init__(self, config: LongformerConfig, **kwargs):
        super().__init__(**kwargs)

        kernel_init = get_initializer(config.initializer_range)
        self.dense = keras.layers.Dense(units=config.hidden_size, kernel_initializer=kernel_init, name="dense")
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Return `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`."""
        projected = self.dense(inputs=hidden_states)
        regularized = self.dropout(inputs=projected, training=training)
        return self.LayerNorm(inputs=regularized + input_tensor)

    def build(self, input_shape=None):
        """Build both sub-layers once for inputs whose last dimension is `config.hidden_size`."""
        if self.built:
            return
        self.built = True

        for layer_attr in ("dense", "LayerNorm"):
            sublayer = getattr(self, layer_attr, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build([None, None, self.config.hidden_size])
class TFLongformerSelfAttention(keras.layers.Layer):
    """
    Longformer self-attention layer holding the local (`query`/`key`/`value`) and global
    (`query_global`/`key_global`/`value_global`) projections plus the chunking helpers for windowed attention.
    """

    def __init__(self, config, layer_id, **kwargs):
        """
        Args:
            config: Model configuration; `hidden_size`, `num_attention_heads`, `initializer_range`,
                `attention_probs_dropout_prob` and `attention_window` are read.
            layer_id: Index of this layer, used to pick its entry of `config.attention_window`.

        Raises:
            ValueError: If `hidden_size` is not a multiple of `num_attention_heads`.
        """
        super().__init__(**kwargs)
        self.config = config

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads}"
            )

        self.num_heads = config.num_attention_heads
        self.head_dim = int(config.hidden_size / config.num_attention_heads)
        self.embed_dim = config.hidden_size
        self.query = keras.layers.Dense(
            self.embed_dim,
            kernel_initializer=get_initializer(config.initializer_range),
            name="query",
        )
        self.key = keras.layers.Dense(
            self.embed_dim,
            kernel_initializer=get_initializer(config.initializer_range),
            name="key",
        )
        self.value = keras.layers.Dense(
            self.embed_dim,
            kernel_initializer=get_initializer(config.initializer_range),
            name="value",
        )

        # Separate projection layers for tokens with global attention.
        self.query_global = keras.layers.Dense(
            self.embed_dim,
            kernel_initializer=get_initializer(config.initializer_range),
            name="query_global",
        )
        self.key_global = keras.layers.Dense(
            self.embed_dim,
            kernel_initializer=get_initializer(config.initializer_range),
            name="key_global",
        )
        self.value_global = keras.layers.Dense(
            self.embed_dim,
            kernel_initializer=get_initializer(config.initializer_range),
            name="value_global",
        )
        self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)
        self.global_dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)
        self.layer_id = layer_id
        attention_window = config.attention_window[self.layer_id]

        assert (
            attention_window % 2 == 0
        ), f"`attention_window` for layer {self.layer_id} has to be an even value. Given {attention_window}"
        assert (
            attention_window > 0
        ), f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}"

        self.one_sided_attn_window_size = attention_window // 2

    def build(self, input_shape=None):
        """Build all six projection layers; a no-op once the layer is built."""
        # NOTE(review): on the first call the three global projections are built here and then again
        # below. Left unchanged because the effect on variable creation/naming cannot be judged from
        # this file alone — confirm whether the first pass is still needed.
        if not self.built:
            with tf.name_scope("query_global"):
                self.query_global.build((self.config.hidden_size,))
            with tf.name_scope("key_global"):
                self.key_global.build((self.config.hidden_size,))
            with tf.name_scope("value_global"):
                self.value_global.build((self.config.hidden_size,))

        if self.built:
            return
        self.built = True
        if getattr(self, "query", None) is not None:
            with tf.name_scope(self.query.name):
                self.query.build([None, None, self.config.hidden_size])
        if getattr(self, "key", None) is not None:
            with tf.name_scope(self.key.name):
                self.key.build([None, None, self.config.hidden_size])
        if getattr(self, "value", None) is not None:
            with tf.name_scope(self.value.name):
                self.value.build([None, None, self.config.hidden_size])
        if getattr(self, "query_global", None) is not None:
            with tf.name_scope(self.query_global.name):
                self.query_global.build([None, None, self.config.hidden_size])
        if getattr(self, "key_global", None) is not None:
            with tf.name_scope(self.key_global.name):
                self.key_global.build([None, None, self.config.hidden_size])
        if getattr(self, "value_global", None) is not None:
            with tf.name_scope(self.value_global.name):
                self.value_global.build([None, None, self.config.hidden_size])

    def call(
        self,
        inputs,
        training=False,
    ):
        # NOTE(review): the forward pass is only a stub in this file and returns None, while
        # `TFLongformerAttention.call` indexes the result (`self_outputs[0]`). The real body must be
        # restored before this layer can be used.
        pass

    @staticmethod
    def _mask_invalid_locations(input_tensor, window_overlap):
        """
        Set the out-of-sequence corners of the banded attention scores to `-inf`.

        Args:
            input_tensor: 4-D tensor; axis 0 is tiled over, axis 1 and axis 3 determine the size of the 2-D mask.
            window_overlap: One-sided window size; a `window_overlap x (window_overlap + 1)` triangle is masked at
                the beginning and, mirrored, at the end.

        Returns:
            Tensor like `input_tensor` with masked locations replaced by `-inf`.
        """
        # Create the triangular mask of the first rows (positions before the sequence start).
        mask_2d_upper = tf.reverse(
            tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0),
            axis=[0],
        )

        # Pad the triangle to the full (axis 1, axis 3) extent of the input.
        padding = tf.convert_to_tensor(
            [[0, shape_list(input_tensor)[1] - window_overlap], [0, shape_list(input_tensor)[3] - window_overlap - 1]]
        )
        mask_2d = tf.pad(mask_2d_upper, padding)

        # Add the mirrored triangle for the last rows (positions after the sequence end).
        mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])

        # Broadcast to 4-D.
        mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))

        inf_tensor = -float("inf") * tf.ones_like(input_tensor)
        input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor, input_tensor)

        return input_tensor

    def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_overlap):
        """
        Multiply banded attention probabilities with the value vectors chunk by chunk.

        Args:
            attn_probs: Tensor of shape `(batch_size, seq_len, num_heads, 2 * window_overlap + 1)`.
            value: Tensor of shape `(batch_size, seq_len, num_heads, head_dim)`.
            window_overlap: One-sided window size; `seq_len` has to be a multiple of `2 * window_overlap`.

        Returns:
            Tensor of shape `(batch_size, seq_len, num_heads, head_dim)`.
        """
        batch_size, seq_len, num_heads, head_dim = shape_list(value)

        tf.debugging.assert_equal(
            seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap"
        )
        tf.debugging.assert_equal(
            shape_list(attn_probs)[:3],
            shape_list(value)[:3],
            message="value and attn_probs must have same dims (except head_dim)",
        )
        tf.debugging.assert_equal(
            shape_list(attn_probs)[3],
            2 * window_overlap + 1,
            message="attn_probs last dim has to be 2 * window_overlap + 1",
        )

        chunks_count = seq_len // window_overlap - 1

        # Group batch_size and num_heads into one dimension, then split seq_len into chunks of window_overlap rows.
        chunked_attn_probs = tf.reshape(
            tf.transpose(attn_probs, (0, 2, 1, 3)),
            (
                batch_size * num_heads,
                seq_len // window_overlap,
                window_overlap,
                2 * window_overlap + 1,
            ),
        )

        # Group batch_size and num_heads into one dimension.
        value = tf.reshape(
            tf.transpose(value, (0, 2, 1, 3)),
            (batch_size * num_heads, seq_len, head_dim),
        )

        # Pad seq_len with window_overlap entries at the beginning and at the end.
        paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]])
        padded_value = tf.pad(value, paddings, constant_values=-1)

        # Chunk padded_value into frames of 3 * window_overlap rows taken every window_overlap rows.
        frame_size = 3 * window_overlap * head_dim
        frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count
        chunked_value = tf.signal.frame(
            tf.reshape(padded_value, (batch_size * num_heads, -1)),
            frame_size,
            frame_hop_size,
        )
        chunked_value = tf.reshape(
            chunked_value,
            (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
        )

        tf.debugging.assert_equal(
            shape_list(chunked_value),
            [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
            message="Chunked value has the wrong shape",
        )

        chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
        context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
        context = tf.transpose(
            tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)),
            (0, 2, 1, 3),
        )

        return context

    @staticmethod
    def _pad_and_transpose_last_two_dims(hidden_states_padded, paddings):
        """
        Pad a 4-D tensor and reinterpret it with its last two dimension sizes swapped.

        Note that the swap is done with `tf.reshape` (the element order in memory is kept), not with `tf.transpose`.

        Args:
            hidden_states_padded: 4-D input tensor.
            paddings: Padding specification passed to `tf.pad`.

        Returns:
            Tensor of shape `(d0, d1, d3, d2)`, where `(d0, d1, d2, d3)` is the shape after padding.
        """
        hidden_states_padded = tf.pad(hidden_states_padded, paddings)
        batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded)
        hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length))

        return hidden_states_padded

    @staticmethod
    def _pad_and_diagonalize(chunked_hidden_states):
        """
        Shift every row one step further right than the previous one, converting columns into diagonals.

        Example for a single chunk with `window_overlap = 2` and `hidden_dim = 2`::

            [[a, b],          [[a, b, 0],
             [c, d]]    =>     [0, c, d]]

        Args:
            chunked_hidden_states: 4-D tensor of shape `(total_num_heads, num_chunks, window_overlap, hidden_dim)`.

        Returns:
            Tensor of shape `(total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim - 1)`.
        """
        total_num_heads, num_chunks, window_overlap, hidden_dim = shape_list(chunked_hidden_states)
        # Append window_overlap + 1 zeros to every row ...
        paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]])
        chunked_hidden_states = tf.pad(chunked_hidden_states, paddings)
        # ... flatten the rows of each chunk ...
        chunked_hidden_states = tf.reshape(chunked_hidden_states, (total_num_heads, num_chunks, -1))
        # ... drop the last window_overlap elements so that re-wrapping with rows that are one element
        # shorter shifts each row by one more position than the previous one ...
        chunked_hidden_states = chunked_hidden_states[:, :, :-window_overlap]
        chunked_hidden_states = tf.reshape(
            chunked_hidden_states,
            (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim),
        )
        # ... and remove the surplus last column.
        chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]

        return chunked_hidden_states

    @staticmethod
    def _chunk(hidden_states, window_overlap):
        """
        Convert into overlapping chunks. Chunk size = 2w, overlap size = w.

        Args:
            hidden_states: Tensor of shape `(batch_size, seq_length, hidden_dim)`.
            window_overlap: Overlap size `w`.

        Returns:
            Tensor of shape `(batch_size, num_output_chunks, 2 * window_overlap, hidden_dim)` with
            `num_output_chunks = 2 * (seq_length // (2 * window_overlap)) - 1`.
        """
        batch_size, seq_length, hidden_dim = shape_list(hidden_states)
        num_output_chunks = 2 * (seq_length // (2 * window_overlap)) - 1

        # Frame sizes are expressed on the flattened (seq_length * hidden_dim) axis.
        frame_hop_size = window_overlap * hidden_dim
        frame_size = 2 * frame_hop_size
        hidden_states = tf.reshape(hidden_states, (batch_size, seq_length * hidden_dim))

        chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)

        tf.debugging.assert_equal(
            shape_list(chunked_hidden_states),
            [batch_size, num_output_chunks, frame_size],
            message=(
                "Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
                f" {[batch_size, num_output_chunks, frame_size]}, but got {shape_list(chunked_hidden_states)}."
            ),
        )

        chunked_hidden_states = tf.reshape(
            chunked_hidden_states,
            (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim),
        )

        return chunked_hidden_states

    @staticmethod
    def _get_global_attn_indices(is_index_global_attn):
        """
        Compute global attn indices required throughout forward pass.

        Args:
            is_index_global_attn: Tensor whose axis 1 runs over sequence positions; non-zero/True entries mark
                positions with global attention.

        Returns:
            Tuple `(max_num_global_attn_indices, is_index_global_attn_nonzero, is_local_index_global_attn_nonzero,
            is_local_index_no_global_attn_nonzero)`: the largest number of global positions in any sample, the
            indices of the global positions, the indices of the used slots in a `(batch, max_num_global)` layout and
            the indices of the unused (padding) slots in that layout.
        """
        # Number of global positions per sample.
        num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1)
        num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype)

        # Largest number of global positions in the batch.
        max_num_global_attn_indices = tf.reduce_max(num_global_attn_indices)

        # Indices of the global positions.
        is_index_global_attn_nonzero = tf.where(is_index_global_attn)

        # Slot j of sample i is "used" when j < number of global positions of sample i.
        is_local_index_global_attn = tf.range(max_num_global_attn_indices) < tf.expand_dims(
            num_global_attn_indices, axis=-1
        )

        # Indices of used slots ...
        is_local_index_global_attn_nonzero = tf.where(is_local_index_global_attn)

        # ... and of padding slots.
        is_local_index_no_global_attn_nonzero = tf.where(tf.math.logical_not(is_local_index_global_attn))

        return (
            max_num_global_attn_indices,
            is_index_global_attn_nonzero,
            is_local_index_global_attn_nonzero,
            is_local_index_no_global_attn_nonzero,
        )

    def _concat_with_global_key_attn_probs(
        self,
        attn_scores,
        key_vectors,
        query_vectors,
        max_num_global_attn_indices,
        is_index_global_attn_nonzero,
        is_local_index_global_attn_nonzero,
        is_local_index_no_global_attn_nonzero,
    ):
        """
        Prepend the scores of every query against the global keys to the local attention scores.

        Key vectors at the global positions are gathered into a `(batch_size, max_num_global_attn_indices, num_heads,
        head_dim)` tensor, multiplied with `query_vectors`, the padding slots are overwritten with `-10000.0`, and
        the result is concatenated in front of `attn_scores` along the last axis.

        Returns:
            `attn_scores` with `max_num_global_attn_indices` additional leading entries on the last axis.
        """
        batch_size = shape_list(key_vectors)[0]

        # Select the key vectors at the global positions.
        global_key_vectors = tf.gather_nd(key_vectors, is_index_global_attn_nonzero)

        # Scatter them into the dense per-sample layout; unused slots stay zero.
        key_vectors_only_global = tf.scatter_nd(
            is_local_index_global_attn_nonzero,
            global_key_vectors,
            shape=(
                batch_size,
                max_num_global_attn_indices,
                self.num_heads,
                self.head_dim,
            ),
        )

        # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
        attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors, key_vectors_only_global)

        # Move the slot axis to position 1 so that the (sample, slot) index pairs can address it:
        # (batch_size, max_num_global_attn_indices, seq_len, num_heads)
        attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key, (0, 3, 1, 2))
        mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
            shape_list(attn_probs_from_global_key_trans)[-2:]
        )
        mask = tf.ones(mask_shape) * -10000.0
        mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype)

        # Overwrite the padding slots with a large negative score.
        attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(
            attn_probs_from_global_key_trans,
            is_local_index_no_global_attn_nonzero,
            mask,
        )

        # Back to (batch_size, seq_len, num_heads, max_num_global_attn_indices).
        attn_probs_from_global_key = tf.transpose(attn_probs_from_global_key_trans, (0, 2, 3, 1))

        # Concatenate in front of the local scores.
        attn_scores = tf.concat((attn_probs_from_global_key, attn_scores), axis=-1)

        return attn_scores

    def _compute_attn_output_with_global_indices(
        self,
        value_vectors,
        attn_probs,
        max_num_global_attn_indices,
        is_index_global_attn_nonzero,
        is_local_index_global_attn_nonzero,
    ):
        """
        Combine the attention output over global tokens with the sliding-window attention output.

        The first `max_num_global_attn_indices` entries on the last axis of `attn_probs` are applied to the value
        vectors at the global positions; the remaining entries are applied with
        `_sliding_chunks_matmul_attn_probs_value`.

        Returns:
            Sum of both partial outputs.
        """
        batch_size = shape_list(attn_probs)[0]

        # Probabilities that refer to the global slots.
        attn_probs_only_global = attn_probs[:, :, :, :max_num_global_attn_indices]

        # Select the value vectors at the global positions.
        global_value_vectors = tf.gather_nd(value_vectors, is_index_global_attn_nonzero)

        # Scatter them into the dense per-sample layout; unused slots stay zero.
        value_vectors_only_global = tf.scatter_nd(
            is_local_index_global_attn_nonzero,
            global_value_vectors,
            shape=(
                batch_size,
                max_num_global_attn_indices,
                self.num_heads,
                self.head_dim,
            ),
        )

        # Attention output over the global tokens.
        attn_output_only_global = tf.einsum("blhs,bshd->blhd", attn_probs_only_global, value_vectors_only_global)

        # Remaining (windowed) probabilities.
        attn_probs_without_global = attn_probs[:, :, :, max_num_global_attn_indices:]

        # Attention output over the local window.
        attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value(
            attn_probs_without_global, value_vectors, self.one_sided_attn_window_size
        )

        return attn_output_only_global + attn_output_without_global

    def _compute_global_attn_output_from_hidden(
        self,
        attn_output,
        hidden_states,
        max_num_global_attn_indices,
        layer_head_mask,
        is_local_index_global_attn_nonzero,
        is_index_global_attn_nonzero,
        is_local_index_no_global_attn_nonzero,
        is_index_masked,
        training,
    ):
        """
        Compute the attention output for tokens with global attention.

        NOTE(review): only the signature of this method is present in this file; the implementation is missing and
        the method currently returns `None`. Restore the body before use.
        """

    def reshape_and_transpose(self, vector, batch_size):
        """
        Split the last dimension into heads and fold the heads into the batch axis.

        Args:
            vector: Tensor that can be reshaped to `(batch_size, -1, num_heads, head_dim)`.
            batch_size: Size of the leading batch dimension.

        Returns:
            Tensor of shape `(batch_size * num_heads, -1, head_dim)`.
        """
        return tf.reshape(
            tf.transpose(
                tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)),
                (0, 2, 1, 3),
            ),
            (batch_size * self.num_heads, -1, self.head_dim),
        )
class TFLongformerAttention(keras.layers.Layer):
    """Self-attention followed by its output projection (dense, dropout, residual layer norm)."""

    def __init__(self, config, layer_id=0, **kwargs):
        super().__init__(**kwargs)

        self.self_attention = TFLongformerSelfAttention(config, layer_id, name="self")
        self.dense_output = TFLongformerSelfOutput(config, name="output")

    def prune_heads(self, heads):
        """Head pruning is not supported."""
        raise NotImplementedError

    def call(self, inputs, training=False):
        """
        Run self-attention and the output projection.

        Args:
            inputs: Sequence of exactly six items: `hidden_states`, `attention_mask`, `layer_head_mask`,
                `is_index_masked`, `is_index_global_attn`, `is_global_attn`.
            training: Forwarded to both sub-layers.

        Returns:
            Tuple whose first item is the projected attention output, followed by any additional outputs of the
            self-attention layer.
        """
        hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn = inputs

        attn_inputs = [
            hidden_states,
            attention_mask,
            layer_head_mask,
            is_index_masked,
            is_index_global_attn,
            is_global_attn,
        ]
        attn_results = self.self_attention(attn_inputs, training=training)

        # The residual connection uses the layer input.
        projected = self.dense_output(attn_results[0], hidden_states, training=training)
        return (projected,) + attn_results[1:]

    def build(self, input_shape=None):
        """Build both sub-layers inside their own name scopes; runs only once."""
        if self.built:
            return
        self.built = True

        for layer_attr in ("self_attention", "dense_output"):
            sublayer = getattr(self, layer_attr, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build(None)
class TFLongformerLayer(keras.layers.Layer):
    """One encoder layer: attention block, intermediate projection and output projection."""

    def __init__(self, config, layer_id=0, **kwargs):
        super().__init__(**kwargs)

        self.attention = TFLongformerAttention(config, layer_id, name="attention")
        self.intermediate = TFLongformerIntermediate(config, name="intermediate")
        self.longformer_output = TFLongformerOutput(config, name="output")

    def call(self, inputs, training=False):
        """
        Apply the layer.

        Args:
            inputs: Sequence of exactly six items: `hidden_states`, `attention_mask`, `layer_head_mask`,
                `is_index_masked`, `is_index_global_attn`, `is_global_attn`.
            training: Forwarded to the attention and output sub-layers.

        Returns:
            Tuple whose first item is the layer output, followed by any additional outputs of the attention block.
        """
        hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn = inputs

        attn_inputs = [
            hidden_states,
            attention_mask,
            layer_head_mask,
            is_index_masked,
            is_index_global_attn,
            is_global_attn,
        ]
        attn_results = self.attention(attn_inputs, training=training)
        attn_hidden = attn_results[0]

        expanded = self.intermediate(attn_hidden)
        # The residual connection uses the attention output.
        contracted = self.longformer_output(expanded, attn_hidden, training=training)
        return (contracted,) + attn_results[1:]

    def build(self, input_shape=None):
        """Build the three sub-layers inside their own name scopes; runs only once."""
        if self.built:
            return
        self.built = True

        for layer_attr in ("attention", "intermediate", "longformer_output"):
            sublayer = getattr(self, layer_attr, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build(None)
class TFLongformerEncoder(keras.layers.Layer):
    """Stack of `config.num_hidden_layers` Longformer layers applied in sequence."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.layer = [TFLongformerLayer(config, i, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]

    def call(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        padding_len=0,
        is_index_masked=None,
        is_index_global_attn=None,
        is_global_attn=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        """Run every layer and optionally collect hidden states and attentions.

        When `padding_len > 0`, the trailing `padding_len` positions are cut
        from the returned hidden states (axis 1) and from the local attentions
        (axis 2). Returns a `TFLongformerBaseModelOutput` when `return_dict` is
        truthy, otherwise a tuple of the non-`None` results.
        """

        def strip_padding(states):
            # Drop the trailing padded positions along axis 1, if any.
            return states[:, :-padding_len] if padding_len > 0 else states

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_global_attentions = () if output_attentions else None

        for idx, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states += (strip_padding(hidden_states),)

            layer_head_mask = head_mask[idx] if head_mask is not None else None
            layer_outputs = layer_module(
                [
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    is_index_masked,
                    is_index_global_attn,
                    is_global_attn,
                ],
                training=training,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)
                all_global_attentions += (tf.transpose(layer_outputs[2], (0, 1, 3, 2)),)

        # The final hidden state is unpadded once and reused for both outputs.
        hidden_states = strip_padding(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if output_attentions and padding_len > 0:
            all_attentions = tuple(state[:, :, :-padding_len, :] for state in all_attentions)

        if return_dict:
            return TFLongformerBaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=all_hidden_states,
                attentions=all_attentions,
                global_attentions=all_global_attentions,
            )

        candidates = (hidden_states, all_hidden_states, all_attentions, all_global_attentions)
        return tuple(item for item in candidates if item is not None)

    def build(self, input_shape=None):
        """Build every layer of the stack once, each under its own name scope."""
        if self.built:
            return
        self.built = True
        stack = getattr(self, "layer", None)
        if stack is None:
            return
        for sublayer in stack:
            with tf.name_scope(sublayer.name):
                sublayer.build(None)
@keras_serializable
class TFLongformerMainLayer(keras.layers.Layer):
config_class = LongformerConfig
def __init__(self, config, add_pooling_layer=True, **kwargs):
    """Validate the attention window, cache config values and create the sub-layers.

    An integer `config.attention_window` must be even and positive; it is
    expanded in place on `config` to one entry per hidden layer. A sequence
    must already contain exactly `config.num_hidden_layers` entries.
    """
    super().__init__(**kwargs)

    if not isinstance(config.attention_window, int):
        assert len(config.attention_window) == config.num_hidden_layers, (
            "`len(config.attention_window)` should equal `config.num_hidden_layers`. "
            f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}"
        )
    else:
        assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value"
        assert config.attention_window > 0, "`config.attention_window` has to be positive"
        # Written back to the config object: one window size per layer.
        config.attention_window = [config.attention_window] * config.num_hidden_layers

    self.config = config
    self.num_hidden_layers = config.num_hidden_layers
    self.initializer_range = config.initializer_range
    self.output_attentions = config.output_attentions
    self.output_hidden_states = config.output_hidden_states
    self.return_dict = config.use_return_dict
    self.pad_token_id = config.pad_token_id
    self.attention_window = config.attention_window

    self.embeddings = TFLongformerEmbeddings(config, name="embeddings")
    self.encoder = TFLongformerEncoder(config, name="encoder")
    if add_pooling_layer:
        self.pooler = TFLongformerPooler(config, name="pooler")
    else:
        self.pooler = None
def get_input_embeddings(self):
    """Return the embeddings layer held by this main layer."""
    return self.embeddings
def set_input_embeddings(self, value):
    """Replace the embedding weight and set `vocab_size` to its first dimension."""
    embeddings = self.embeddings
    embeddings.weight = value
    embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune):
    """
    Prune heads of the model (not implemented for this layer).

    Args:
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base class
            PreTrainedModel.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
@unpack_inputs
def call(
self,
input_ids=None,
attention_mask=None,
head_mask=None,
global_attention_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
):
def _pad_to_window_size(
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
inputs_embeds,
pad_token_id,
):
):
"""A helper function to pad tokens and mask to work with implementation of Longformer selfattention."""
attention_window = (
self.attention_window if isinstance(self.attention_window, int) else max(self.attention_window)
)
assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}"
input_shape = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)
batch_size, seq_len = input_shape[:2]
padding_len = (attention_window - seq_len % attention_window) % attention_window
paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])
if input_ids is not None:
input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)
if position_ids is not None:
position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id)
if inputs_embeds is not None:
if padding_len > 0:
input_ids_padding = tf.cast(tf.fill((batch_size, padding_len), self.pad_token_id), tf.int64)
inputs_embeds_padding = self.embeddings(input_ids_padding)
inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
def build(self, input_shape=None):
    """Build embeddings, encoder and pooler (when present) once, each under its own name scope."""
    if self.built:
        return
    self.built = True
    for attr_name in ("embeddings", "encoder", "pooler"):
        sublayer = getattr(self, attr_name, None)
        if sublayer is None:
            continue
        with tf.name_scope(sublayer.name):
            sublayer.build(None)
class TFLongformerPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = LongformerConfig
    base_model_prefix = "longformer"

    @property
    def input_signature(self):
        """Return the inherited input signature extended with a `global_attention_mask` spec."""
        sig = super().input_signature
        sig["global_attention_mask"] = tf.TensorSpec((None, None), tf.int32, name="global_attention_mask")
        return sig
# Class-level documentation text shared by the model classes in this file; it is passed to
# `add_start_docstrings` together with a per-class summary.
LONGFORMER_START_DOCSTRING = r"""
这个模型继承自[`TFPreTrainedModel`]。请查看超类文档,了解库实现的所有通用方法(如下载或保存、调整输入嵌入大小、修剪头等)。
这个模型也是一个 [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) 的子类。可以像使用常规的 TF 2.0 Keras 模型一样使用它,并参考 TF 2.0 的文档了解有关一般使用和行为的所有内容。
<Tip>
`transformers` 中的 TensorFlow 模型和层接受两种输入格式:
- 将所有输入作为关键字参数(类似于 PyTorch 模型)传递,或者
- 将所有输入作为列表、元组或字典传递给第一个位置参数。
支持第二种格式的原因是,Keras 方法在传递输入给模型和层时更喜欢这种格式。由于这种支持,当使用 `model.fit()` 等方法时,只需将输入和标签以 `model.fit()` 支持的任何格式传递即可!然而,如果要在 Keras 方法之外(如在使用 Keras `Functional` API 创建自己的层或模型时)使用第二种格式,有三种可能的方法可以使用来收集第一个位置参数中的所有输入张量:
- 只有 `input_ids` 的单个张量,没有其他内容:`model(input_ids)`
- 长度可变的列表,按照文档字符串中给定的顺序包含一个或多个输入张量:`model([input_ids, attention_mask])` 或 `model([input_ids, attention_mask, token_type_ids])`
- 一个字典,其中包含一个或多个输入张量,与文档字符串中给定的输入名称相关联:`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
注意,当使用 [子类化](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) 创建模型和层时,不需要担心这些内容,因为可以像将输入传递给任何其他 Python 函数一样传递输入!
</Tip>
Parameters:
config ([`LongformerConfig`]): 包含模型所有参数的模型配置类。
使用配置文件初始化不会加载与模型相关的权重,仅加载配置。
查看 [`~PreTrainedModel.from_pretrained`] 方法以加载模型权重。
"""
# Documentation template for the forward-pass inputs. The `call` methods in this file apply
# `.format(...)` with a shape description before handing it to `add_start_docstrings_to_model_forward`.
# The template body is empty in this excerpt.
LONGFORMER_INPUTS_DOCSTRING = r"""
"""
@add_start_docstrings(
    "The bare Longformer Model outputting raw hidden-states without any specific head on top.",
    LONGFORMER_START_DOCSTRING,
)
class TFLongformerModel(TFLongformerPreTrainedModel):
    """
    Bare Longformer model that returns raw hidden states without any task-specific head.

    This class copies code from [`TFRobertaModel`] and overwrites standard self-attention with longformer
    self-attention to provide the ability to process long sequences following the self-attention approach described in
    [Longformer: the Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, and
    Arman Cohan. Longformer self-attention combines a local (sliding window) and global attention to extend to long
    documents without the O(n^2) increase in memory and compute.

    The self-attention module `TFLongformerSelfAttention` implemented here supports the combination of local and global
    attention but it lacks support for autoregressive attention and dilated attention. Autoregressive and dilated
    attention are more relevant for autoregressive language modeling than finetuning on downstream tasks. Future
    release will add support for autoregressive attention, but the support for dilated attention requires a custom CUDA
    kernel to be memory and compute efficient.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.longformer = TFLongformerMainLayer(config, name="longformer")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        global_attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFLongformerBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
        # Every argument is forwarded unchanged; the main layer's result is returned as is.
        return self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            global_attention_mask=global_attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

    def build(self, input_shape=None):
        """Build the wrapped main layer once, under its own name scope."""
        if self.built:
            return
        self.built = True
        main_layer = getattr(self, "longformer", None)
        if main_layer is None:
            return
        with tf.name_scope(main_layer.name):
            main_layer.build(None)
@add_start_docstrings(
    """Longformer Model with a `language modeling` head on top.""",
    LONGFORMER_START_DOCSTRING,
)
class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss):
    # Checkpoint keys matching "pooler" are ignored when they are unexpected: this model is built
    # without a pooling layer.
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer")
        # The LM head is constructed with the main layer's embeddings object.
        self.lm_head = TFLongformerLMHead(config, self.longformer.embeddings, name="lm_head")

    def get_lm_head(self):
        """Return the language-modeling head layer."""
        return self.lm_head

    def get_prefix_bias_name(self):
        """Return `<model name>/<lm head name>`. Deprecated: emits a `FutureWarning`."""
        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        return self.name + "/" + self.lm_head.name

    @unpack_inputs
    @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="allenai/longformer-base-4096",
        output_type=TFLongformerMaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
        mask="<mask>",
        expected_output="' Paris'",
        expected_loss=0.44,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        global_attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
        **kwargs,
    ) -> Union[TFLongformerMaskedLMOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        # Every argument except `labels` is forwarded unchanged to the main layer.
        # Extra `**kwargs` are accepted by the signature but not used in this body.
        outputs = self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            global_attention_mask=global_attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output, training=training)
        # The loss is only computed when labels are supplied.
        loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)

        if not return_dict:
            # Tuple output: element 1 of the main layer's outputs is skipped.
            output = (prediction_scores,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFLongformerMaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            global_attentions=outputs.global_attentions,
        )

    def build(self, input_shape=None):
        """Build the main layer and the LM head once, each under its own name scope."""
        if self.built:
            return
        self.built = True
        if getattr(self, "longformer", None) is not None:
            with tf.name_scope(self.longformer.name):
                self.longformer.build(None)
        if getattr(self, "lm_head", None) is not None:
            with tf.name_scope(self.lm_head.name):
                self.lm_head.build(None)
@add_start_docstrings(
    """
Longformer Model with a span classification head on top for extractive question-answering tasks like SQuAD /
TriviaQA (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
    LONGFORMER_START_DOCSTRING,
)
class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss):
    # Checkpoint keys matching "pooler" are ignored when they are unexpected: this model is built
    # without a pooling layer.
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels
        self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer")
        # Dense projection with `config.num_labels` output units.
        self.qa_outputs = keras.layers.Dense(
            config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            name="qa_outputs",
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="allenai/longformer-large-4096-finetuned-triviaqa",
        output_type=TFLongformerQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output="' puppet'",
        expected_loss=0.96,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        global_attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        start_positions: np.ndarray | tf.Tensor | None = None,
        end_positions: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ):
        # NOTE(review): the forward-pass body is not present in this excerpt; the stub is kept
        # as-is and returns None. Restore the real implementation before using this class.
        pass

    def build(self, input_shape=None):
        """Build the main layer and the QA projection once, each under its own name scope."""
        if self.built:
            return
        self.built = True
        if getattr(self, "longformer", None) is not None:
            with tf.name_scope(self.longformer.name):
                self.longformer.build(None)
        if getattr(self, "qa_outputs", None) is not None:
            with tf.name_scope(self.qa_outputs.name):
                # The projection is built for inputs whose last dimension is `hidden_size`.
                self.qa_outputs.build([None, None, self.config.hidden_size])
class TFLongformerClassificationHead(keras.layers.Layer):
    """Head for sentence-level classification tasks."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        initializer_range = config.initializer_range
        self.dense = keras.layers.Dense(
            config.hidden_size,
            kernel_initializer=get_initializer(initializer_range),
            activation="tanh",
            name="dense",
        )
        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
        self.out_proj = keras.layers.Dense(
            config.num_labels,
            kernel_initializer=get_initializer(initializer_range),
            name="out_proj",
        )
        self.config = config

    def call(self, hidden_states, training=False):
        """Classify from the state at index 0 of axis 1: dropout, dense, dropout, projection."""
        first_position = hidden_states[:, 0, :]
        first_position = self.dropout(first_position, training=training)
        features = self.dense(first_position)
        features = self.dropout(features, training=training)
        return self.out_proj(features)

    def build(self, input_shape=None):
        """Build both dense layers once, each under its own name scope."""
        if self.built:
            return
        self.built = True
        # Both layers are built for inputs whose last dimension is `hidden_size`.
        for attr_name in ("dense", "out_proj"):
            sublayer = getattr(self, attr_name, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build([None, None, self.config.hidden_size])
@add_start_docstrings(
    """
Longformer Model transformer with a sequence classification/regression head on top (a linear layer on top of the
pooled output) e.g. for GLUE tasks.
""",
    LONGFORMER_START_DOCSTRING,
)
class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSequenceClassificationLoss):
    # Checkpoint keys matching "pooler" are ignored when they are unexpected: this model is built
    # without a pooling layer.
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels
        self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer")
        self.classifier = TFLongformerClassificationHead(config, name="classifier")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFLongformerSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        global_attention_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFLongformerSequenceClassifierOutput, Tuple[tf.Tensor]]:
        def as_int64(value):
            # `None` passes through, existing tensors are cast, anything else is converted.
            if value is None:
                return None
            if isinstance(value, tf.Tensor):
                return tf.cast(value, tf.int64)
            return tf.convert_to_tensor(value, dtype=tf.int64)

        input_ids = as_int64(input_ids)
        attention_mask = as_int64(attention_mask)
        global_attention_mask = as_int64(global_attention_mask)

        if global_attention_mask is None and input_ids is not None:
            logger.warning_once("Initializing global attention on CLS token...")
            batch_size = shape_list(input_ids)[0]
            # Each scatter index is (row, 0): the row ids padded with one zero column.
            row_ids = tf.expand_dims(tf.range(batch_size, dtype=tf.int64), axis=1)
            first_token_indices = tf.pad(tensor=row_ids, paddings=[[0, 0], [0, 1]], constant_values=0)
            # All-zero mask with a 1 written at position 0 of every row.
            global_attention_mask = tf.tensor_scatter_nd_update(
                tf.zeros_like(input_ids),
                first_token_indices,
                tf.ones(batch_size, dtype=tf.int64),
            )

        outputs = self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            global_attention_mask=global_attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        logits = self.classifier(outputs[0])
        loss = self.hf_compute_loss(labels, logits) if labels is not None else None

        if return_dict:
            return TFLongformerSequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                global_attentions=outputs.global_attentions,
            )

        # Tuple output: element 1 of the main layer's outputs is skipped.
        output = (logits,) + outputs[2:]
        if loss is not None:
            return (loss,) + output
        return output

    def build(self, input_shape=None):
        """Build the main layer and the classification head once, each under its own name scope."""
        if self.built:
            return
        self.built = True
        for attr_name in ("longformer", "classifier"):
            sublayer = getattr(self, attr_name, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build(None)
@add_start_docstrings(
    """
Longformer Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
a softmax) e.g. for RocStories/SWAG tasks.
""",
    LONGFORMER_START_DOCSTRING,
)
class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoiceLoss):
    """
    Multiple-choice classifier built on the Longformer model: a linear layer (and a softmax) on top of the pooled
    output, e.g. for RocStories/SWAG tasks.
    """

    # Names of layers that may be missing when loading checkpoint weights
    # (presumably when loading a TF model from a PyTorch checkpoint - confirm against the loader).
    _keys_to_ignore_on_load_missing = [r"dropout"]

    def __init__(self, config, *inputs, **kwargs):
        """
        Create the model.

        Args:
            config: Longformer model configuration object.
            *inputs: Positional arguments forwarded to the parent constructor.
            **kwargs: Keyword arguments forwarded to the parent constructor.
        """
        super().__init__(config, *inputs, **kwargs)
        # Main Longformer layer, created with the default `add_pooling_layer`.
        self.longformer = TFLongformerMainLayer(config, name="longformer")
        # Dropout applied to the pooled output before classification.
        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
        # Single-unit dense layer: one score per flattened (example, choice) row.
        self.classifier = keras.layers.Dense(
            1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
        # Kept for later use (`build` reads `hidden_size`).
        self.config = config

    @property
    def input_signature(self):
        """Input signature: three rank-3 int32 tensor specs with unspecified dimensions."""
        rank3_shape = (None, None, None)
        return {
            "input_ids": tf.TensorSpec(rank3_shape, tf.int32, name="input_ids"),
            "attention_mask": tf.TensorSpec(rank3_shape, tf.int32, name="attention_mask"),
            "global_attention_mask": tf.TensorSpec(rank3_shape, tf.int32, name="global_attention_mask"),
        }

    @unpack_inputs
    @add_start_docstrings_to_model_forward(
        LONGFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
    )
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFLongformerMultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        global_attention_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFLongformerMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
            where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
        """
        # Dimensions 1 and 2 of whichever input was given supply the choice count and sequence length.
        reference_shape = shape_list(input_ids if input_ids is not None else inputs_embeds)
        num_choices, seq_length = reference_shape[1], reference_shape[2]

        def flatten_choices(tensor):
            # Reshape to (-1, seq_length); `None` passes through.
            if tensor is None:
                return None
            return tf.reshape(tensor, (-1, seq_length))

        flat_input_ids = flatten_choices(input_ids)
        flat_attention_mask = flatten_choices(attention_mask)
        flat_token_type_ids = flatten_choices(token_type_ids)
        flat_position_ids = flatten_choices(position_ids)

        # The global mask is flattened with its own last dimension rather than `seq_length`.
        flat_global_attention_mask = None
        if global_attention_mask is not None:
            flat_global_attention_mask = tf.reshape(
                global_attention_mask, (-1, shape_list(global_attention_mask)[-1])
            )

        flat_inputs_embeds = None
        if inputs_embeds is not None:
            flat_inputs_embeds = tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))

        outputs = self.longformer(
            flat_input_ids,
            position_ids=flat_position_ids,
            token_type_ids=flat_token_type_ids,
            attention_mask=flat_attention_mask,
            head_mask=head_mask,
            global_attention_mask=flat_global_attention_mask,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        pooled_output = self.dropout(outputs[1])
        logits = self.classifier(pooled_output)
        # Regroup the per-row scores into one row of `num_choices` scores per example.
        reshaped_logits = tf.reshape(logits, (-1, num_choices))
        loss = self.hf_compute_loss(labels, reshaped_logits) if labels is not None else None

        if return_dict:
            return TFLongformerMultipleChoiceModelOutput(
                loss=loss,
                logits=reshaped_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                global_attentions=outputs.global_attentions,
            )

        output = (reshaped_logits,) + outputs[2:]
        if loss is not None:
            return (loss,) + output
        return output

    def build(self, input_shape=None):
        """Build the main layer and the classifier once, each under its own name scope."""
        if self.built:
            return
        self.built = True
        main_layer = getattr(self, "longformer", None)
        if main_layer is not None:
            with tf.name_scope(main_layer.name):
                main_layer.build(None)
        classifier = getattr(self, "classifier", None)
        if classifier is not None:
            with tf.name_scope(classifier.name):
                # The classifier is built for inputs whose last dimension is `hidden_size`.
                classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
    """
Longformer Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
for Named-Entity-Recognition (NER) tasks.
""",
    LONGFORMER_START_DOCSTRING,
)
class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenClassificationLoss):
    # Checkpoint keys matching "pooler" are ignored when they are unexpected: this model is built
    # without a pooling layer. Keys matching "dropout" are ignored when they are missing.
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"dropout"]

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels
        self.longformer = TFLongformerMainLayer(config=config, add_pooling_layer=False, name="longformer")
        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
        # Dense projection with `config.num_labels` output units.
        self.classifier = keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFLongformerTokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        global_attention_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # Annotated with the array type `np.ndarray` (not the `np.array` factory function),
        # matching the other signatures in this file.
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFLongformerTokenClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        outputs = self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            global_attention_mask=global_attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        # The loss is only computed when labels are supplied.
        loss = None if labels is None else self.hf_compute_loss(labels, logits)

        if not return_dict:
            # Tuple output: element 1 of the main layer's outputs is skipped.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFLongformerTokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            global_attentions=outputs.global_attentions,
        )

    def build(self, input_shape=None):
        """Build the main layer and the classifier once, each under its own name scope."""
        if self.built:
            return
        self.built = True
        if getattr(self, "longformer", None) is not None:
            with tf.name_scope(self.longformer.name):
                self.longformer.build(None)
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                # The classifier is built for inputs whose last dimension is `hidden_size`.
                self.classifier.build([None, None, self.config.hidden_size])
.\models\longformer\tokenization_longformer.py
import json
import os
from functools import lru_cache
from typing import List, Optional, Tuple
import regex as re
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging
logger = logging.get_logger(__name__)
# File names of the two tokenizer resources, keyed by their logical name.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}

# URL of each tokenizer resource for every listed checkpoint.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "allenai/longformer-base-4096": "https://huggingface.co/allenai/longformer-base-4096/resolve/main/vocab.json",
        "allenai/longformer-large-4096": (
            "https://huggingface.co/allenai/longformer-large-4096/resolve/main/vocab.json"
        ),
        "allenai/longformer-large-4096-finetuned-triviaqa": (
            "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/vocab.json"
        ),
        "allenai/longformer-base-4096-extra.pos.embd.only": (
            "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/vocab.json"
        ),
        "allenai/longformer-large-4096-extra.pos.embd.only": (
            "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/vocab.json"
        ),
    },
    "merges_file": {
        "allenai/longformer-base-4096": "https://huggingface.co/allenai/longformer-base-4096/resolve/main/merges.txt",
        "allenai/longformer-large-4096": (
            "https://huggingface.co/allenai/longformer-large-4096/resolve/main/merges.txt"
        ),
        "allenai/longformer-large-4096-finetuned-triviaqa": (
            "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/merges.txt"
        ),
        "allenai/longformer-base-4096-extra.pos.embd.only": (
            "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/merges.txt"
        ),
        "allenai/longformer-large-4096-extra.pos.embd.only": (
            "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/merges.txt"
        ),
    },
}

# Positional-embedding size per checkpoint. The stray closing brace that cut the
# last entry out of the dict has been removed, so all five checkpoints are listed.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "allenai/longformer-base-4096": 4096,
    "allenai/longformer-large-4096": 4096,
    "allenai/longformer-large-4096-finetuned-triviaqa": 4096,
    "allenai/longformer-base-4096-extra.pos.embd.only": 4096,
    "allenai/longformer-large-4096-extra.pos.embd.only": 4096,
}
@lru_cache()
def bytes_to_unicode():
    """
    Build the reversible byte -> unicode character table used by the byte-level BPE.

    Bytes in the ranges "!".."~", "¡".."¬" and "®".."ÿ" map to the character with the same code point. Every
    remaining byte (whitespace/control characters the BPE code cannot handle) is assigned, in increasing byte
    order, the code points 256, 257, ... so that each of the 256 byte values gets its own distinct character.

    Working on these unicode strings instead of raw bytes lets the BPE cover any input without needing a huge
    vocabulary of unicode characters to avoid UNKs.

    Returns:
        `Dict[int, str]`: mapping from every byte value (0-255) to a single-character string.
    """
    kept_as_is = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    byte_values = list(kept_as_is)
    code_points = list(kept_as_is)

    shifted = 0
    for byte in range(256):
        if byte in kept_as_is:
            continue
        # Bytes outside the kept ranges are relocated above the 8-bit range, one code point each.
        byte_values.append(byte)
        code_points.append(256 + shifted)
        shifted += 1

    return {byte: chr(point) for byte, point in zip(byte_values, code_points)}
def get_pairs(word):
    """
    Return the set of adjacent symbol pairs in a word.

    Args:
        word: Sequence of symbols, e.g. a tuple of variable-length strings (a plain string also works).

    Returns:
        `Set[Tuple[str, str]]`: every `(word[i], word[i + 1])` pair. The set is empty when the word holds fewer
        than two symbols.
    """
    # zip() stops at the shorter iterable, so empty and single-symbol words yield no pairs instead of raising
    # IndexError on `word[0]`.
    return set(zip(word, word[1:]))
class LongformerTokenizer(PreTrainedTokenizer):
    """
    Constructs a Longformer tokenizer, derived from the GPT-2 tokenizer, using byte-level Byte-Pair-Encoding.

    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
    be encoded differently whether it is at the beginning of the sentence (without space) or not:

    ```
    >>> from transformers import LongformerTokenizer

    >>> tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
    >>> tokenizer("Hello world")["input_ids"]
    [0, 31414, 232, 2]

    >>> tokenizer(" Hello world")["input_ids"]
    [0, 20920, 232, 2]
    ```

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.

    <Tip>

    When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
        errors (`str`, *optional*, defaults to `"replace"`):
            Paradigm to follow when decoding bytes to UTF-8. See
            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        sep_token (`str`, *optional*, defaults to `"</s>"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        cls_token (`str`, *optional*, defaults to `"<s>"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        mask_token (`str`, *optional*, defaults to `"<mask>"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
            other word. (Longformer tokenizer detect beginning of words by the preceding space).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        merges_file,
        errors="replace",
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        add_prefix_space=False,
        **kwargs,
    ):
        # Plain-string special tokens are wrapped in `AddedToken` so their stripping behaviour is explicit.
        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
        sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
        cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token

        # The mask token is built with `lstrip=True`, i.e. it also takes the whitespace preceding it.
        mask_token = (
            AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False)
            if isinstance(mask_token, str)
            else mask_token
        )

        # token -> id table read from the JSON vocabulary, plus its inverse.
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # error policy used when decoding bytes back to text
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        # The first line of the merges file (a version header, see `save_vocabulary`) and the empty string after
        # the final newline are dropped; each remaining line is one merge rule, ranked by its position.
        with open(merges_file, encoding="utf-8") as merges_handle:
            bpe_merges = merges_handle.read().split("\n")[1:-1]
        bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}
        self.add_prefix_space = add_prefix_space

        # Pre-tokenization pattern: contractions, letter runs, digit runs, other symbol runs and whitespace.
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

        super().__init__(
            errors=errors,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """`int`: Number of entries in the base vocabulary (`self.encoder`)."""
        return len(self.encoder)

    def get_vocab(self):
        """Return a new dict holding the base vocabulary updated with `self.added_tokens_encoder`."""
        vocab = dict(self.encoder)
        vocab.update(self.added_tokens_encoder)
        return vocab

    def bpe(self, token):
        """
        Apply the BPE merges to one byte-encoded token.

        Args:
            token (`str`): A token whose characters come from the `bytes_to_unicode` table.

        Returns:
            `str`: The resulting symbols joined by single spaces. Results are memoised in `self.cache`; a token
            without any symbol pair is returned unchanged (and not cached).
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            # Pick the pair with the lowest merge rank; unknown pairs rank as +inf.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # No further occurrence of `first`: copy the tail and stop scanning.
                    new_word.extend(word[i:])
                    break
                else:
                    new_word.extend(word[i:j])
                    i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def _tokenize(self, text):
        """Tokenize a string."""
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            # Map every UTF-8 byte of the piece to its printable unicode stand-in before running BPE.
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.decoder.get(index)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        text = "".join(tokens)
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
        return text

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Write the vocabulary (JSON) and the BPE merges (text) into `save_directory`.

        Args:
            save_directory (`str`): Existing directory to write into.
            filename_prefix (`str`, *optional*): When given, `"<prefix>-"` is prepended to both file names.

        Returns:
            `Tuple[str]`: `(vocab_file, merge_file)` paths. If `save_directory` is not a directory, an error is
            logged and `None` is returned.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        merge_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write("#version: 0.2\n")
            # Merges are written in rank order; a gap in the ranks is reported but does not abort the save.
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!"
                    )
                    index = token_index
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        return vocab_file, merge_file

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A Longformer sequence has the following format:

        - single sequence: `<s> X </s>`
        - pair of sequences: `<s> A </s></s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # Mirrors the layouts produced by `build_inputs_with_special_tokens`.
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        else:
            return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. Longformer does not
        make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        else:
            return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
        """
        Prepare text for tokenization by optionally adding a space prefix based on the arguments.

        Args:
            text (`str`): Input text to be tokenized.
            is_split_into_words (`bool`, *optional*): Whether the text is already split into words.
            **kwargs: Additional keyword arguments; `add_prefix_space` is consumed here and overrides the
                instance-level setting for this call.

        Returns:
            `Tuple[str, dict]`: Processed text and the remaining keyword arguments.
        """
        add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
        # Only non-empty text that does not already start with whitespace gets the extra leading space.
        if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
            text = " " + text
        return (text, kwargs)
.\models\longformer\tokenization_longformer_fast.py
"""Fast Tokenization classes for Longformer."""
import json
from typing import List, Optional, Tuple
from tokenizers import pre_tokenizers, processors
from ...tokenization_utils_base import AddedToken, BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
from .tokenization_longformer import LongformerTokenizer
# Module-level logger for the fast tokenizer module.
logger = logging.get_logger(__name__)
# File names for the vocabulary, the BPE merges and the serialized fast-tokenizer file.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
# URLs of the tokenizer files, keyed first by file kind and then by checkpoint name.
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"allenai/longformer-base-4096": "https://huggingface.co/allenai/longformer-base-4096/resolve/main/vocab.json",
"allenai/longformer-large-4096": (
"https://huggingface.co/allenai/longformer-large-4096/resolve/main/vocab.json"
),
"allenai/longformer-large-4096-finetuned-triviaqa": (
"https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/vocab.json"
),
"allenai/longformer-base-4096-extra.pos.embd.only": (
"https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/vocab.json"
),
"allenai/longformer-large-4096-extra.pos.embd.only": (
"https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/vocab.json"
),
},
"merges_file": {
"allenai/longformer-base-4096": "https://huggingface.co/allenai/longformer-base-4096/resolve/main/merges.txt",
"allenai/longformer-large-4096": (
"https://huggingface.co/allenai/longformer-large-4096/resolve/main/merges.txt"
),
"allenai/longformer-large-4096-finetuned-triviaqa": (
"https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/merges.txt"
),
"allenai/longformer-base-4096-extra.pos.embd.only": (
"https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/merges.txt"
),
"allenai/longformer-large-4096-extra.pos.embd.only": (
"https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/merges.txt"
),
},
"tokenizer_file": {
"allenai/longformer-base-4096": (
"https://huggingface.co/allenai/longformer-base-4096/resolve/main/tokenizer.json"
),
"allenai/longformer-large-4096": (
"https://huggingface.co/allenai/longformer-large-4096/resolve/main/tokenizer.json"
),
"allenai/longformer-large-4096-finetuned-triviaqa": (
"https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/tokenizer.json"
),
"allenai/longformer-base-4096-extra.pos.embd.only": (
"https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/tokenizer.json"
),
"allenai/longformer-large-4096-extra.pos.embd.only": (
"https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/tokenizer.json"
),
},
}
# Per-checkpoint size exposed by the tokenizer class as `max_model_input_sizes`.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"allenai/longformer-base-4096": 4096,
"allenai/longformer-large-4096": 4096,
"allenai/longformer-large-4096-finetuned-triviaqa": 4096,
"allenai/longformer-base-4096-extra.pos.embd.only": 4096,
"allenai/longformer-large-4096-extra.pos.embd.only": 4096,
}
class LongformerTokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a "fast" Longformer tokenizer (backed by HuggingFace's *tokenizers* library), derived from the GPT-2
    tokenizer, using byte-level Byte-Pair-Encoding.

    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
    be encoded differently whether it is at the beginning of the sentence (without space) or not:

    ```
    >>> from transformers import LongformerTokenizerFast

    >>> tokenizer = LongformerTokenizerFast.from_pretrained("allenai/longformer-base-4096")
    >>> tokenizer("Hello world")["input_ids"]
    [0, 31414, 232, 2]

    >>> tokenizer(" Hello world")["input_ids"]
    [0, 20920, 232, 2]
    ```

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.

    <Tip>

    When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
        errors (`str`, *optional*, defaults to `"replace"`):
            Paradigm to follow when decoding bytes to UTF-8. See
            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        sep_token (`str`, *optional*, defaults to `"</s>"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        cls_token (`str`, *optional*, defaults to `"<s>"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        mask_token (`str`, *optional*, defaults to `"<mask>"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
            other word. (Longformer tokenizer detect beginning of words by the preceding space).
        trim_offsets (`bool`, *optional*, defaults to `True`):
            Whether the post processing step should trim offsets to avoid including whitespaces.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]
    slow_tokenizer_class = LongformerTokenizer

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        errors="replace",
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        add_prefix_space=False,
        trim_offsets=True,
        **kwargs,
    ):
        # The mask token is built with `lstrip=True`, i.e. it also takes the whitespace preceding it.
        mask_token = (
            AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False)
            if isinstance(mask_token, str)
            else mask_token
        )
        super().__init__(
            vocab_file,
            merges_file,
            tokenizer_file=tokenizer_file,
            errors=errors,
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            unk_token=unk_token,
            pad_token=pad_token,
            mask_token=mask_token,
            add_prefix_space=add_prefix_space,
            trim_offsets=trim_offsets,
            **kwargs,
        )

        # Rebuild the backend pre-tokenizer when its serialized `add_prefix_space` disagrees with the argument.
        pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
        if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
            pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
            pre_tok_state["add_prefix_space"] = add_prefix_space
            self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)

        self.add_prefix_space = add_prefix_space

        # Same reconciliation for the post-processor, covering both `add_prefix_space` and `trim_offsets`.
        tokenizer_component = "post_processor"
        tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
        if tokenizer_component_instance:
            state = json.loads(tokenizer_component_instance.__getstate__())

            # JSON decoding yields lists for `sep`/`cls`; they are converted to tuples before the state is
            # passed back to the processor constructor.
            if "sep" in state:
                state["sep"] = tuple(state["sep"])
            if "cls" in state:
                state["cls"] = tuple(state["cls"])

            changes_to_apply = False

            if state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
                state["add_prefix_space"] = add_prefix_space
                changes_to_apply = True

            if state.get("trim_offsets", trim_offsets) != trim_offsets:
                state["trim_offsets"] = trim_offsets
                changes_to_apply = True

            if changes_to_apply:
                component_class = getattr(processors, state.pop("type"))
                new_value = component_class(**state)
                setattr(self.backend_tokenizer, tokenizer_component, new_value)

    @property
    def mask_token(self) -> str:
        """
        `str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while not
        having been set.

        Longformer tokenizer has a special mask token to be usable in the fill-mask pipeline. The mask token will
        greedily comprise the space before the *<mask>*.
        """
        if self._mask_token is None:
            if self.verbose:
                logger.error("Using mask_token, but it is not set yet.")
            return None
        return str(self._mask_token)

    @mask_token.setter
    def mask_token(self, value):
        """
        Overriding the default behavior of the mask token to have it eat the space before it.

        This is needed to preserve backward compatibility with all the previously used models based on Longformer.
        """
        # A plain string is wrapped so that the token strips the whitespace on its left (`lstrip=True`).
        value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value
        self._mask_token = value

    def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
        """Encode a batch; pre-tokenized input is only accepted when `add_prefix_space` is enabled."""
        is_split_into_words = kwargs.get("is_split_into_words", False)
        assert self.add_prefix_space or not is_split_into_words, (
            f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
            "to use it with pretokenized inputs."
        )

        return super()._batch_encode_plus(*args, **kwargs)

    def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
        """Encode a single input; pre-tokenized input is only accepted when `add_prefix_space` is enabled."""
        is_split_into_words = kwargs.get("is_split_into_words", False)

        assert self.add_prefix_space or not is_split_into_words, (
            f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
            "to use it with pretokenized inputs."
        )

        return super()._encode_plus(*args, **kwargs)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the vocabulary through the backend tokenizer model.

        Args:
            save_directory (`str`): Directory passed to the backend model's `save`.
            filename_prefix (`str`, *optional*): Passed to the backend model's `save` as `name`.

        Returns:
            `Tuple[str]`: The values returned by the backend `save` call, as a tuple.
        """
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(files)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs by surrounding the sequence(s) with the BOS and EOS token ids.

        - single sequence: `[bos] + token_ids_0 + [eos]`
        - pair of sequences: `[bos] + token_ids_0 + [eos] + [eos] + token_ids_1 + [eos]`

        Args:
            token_ids_0 (`List[int]`): IDs of the first sequence.
            token_ids_1 (`List[int]`, *optional*): IDs of the second sequence, for sequence pairs.

        Returns:
            `List[int]`: The IDs with the special tokens added.
        """
        output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        if token_ids_1 is None:
            return output

        return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. Longformer does not
        make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
.\models\longformer\__init__.py
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)
# Maps each submodule name to the public names it provides; handed to `_LazyModule` at the bottom of this file.
_import_structure = {
    "configuration_longformer": [
        "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "LongformerConfig",
        "LongformerOnnxConfig",
    ],
    "tokenization_longformer": ["LongformerTokenizer"],
}

# Fast tokenizer: registered only when `is_tokenizers_available()` is true.
try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tokenization_longformer_fast"] = ["LongformerTokenizerFast"]

# PyTorch models: registered only when `is_torch_available()` is true.
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_longformer"] = [
        "LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
        "LongformerForMaskedLM",
        "LongformerForMultipleChoice",
        "LongformerForQuestionAnswering",
        "LongformerForSequenceClassification",
        "LongformerForTokenClassification",
        "LongformerModel",
        "LongformerPreTrainedModel",
        "LongformerSelfAttention",
    ]

# TensorFlow models: registered only when `is_tf_available()` is true.
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_tf_longformer"] = [
        "TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
        "TFLongformerForMaskedLM",
        "TFLongformerForMultipleChoice",
        "TFLongformerForQuestionAnswering",
        "TFLongformerForSequenceClassification",
        "TFLongformerForTokenClassification",
        "TFLongformerModel",
        "TFLongformerPreTrainedModel",
        "TFLongformerSelfAttention",
    ]


if TYPE_CHECKING:
    # Under TYPE_CHECKING the same names are imported eagerly, with the same availability guards as above.
    from .configuration_longformer import (
        LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
        LongformerConfig,
        LongformerOnnxConfig,
    )
    from .tokenization_longformer import LongformerTokenizer

    try:
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_longformer_fast import LongformerTokenizerFast

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_longformer import (
            LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
            LongformerForMaskedLM,
            LongformerForMultipleChoice,
            LongformerForQuestionAnswering,
            LongformerForSequenceClassification,
            LongformerForTokenClassification,
            LongformerModel,
            LongformerPreTrainedModel,
            LongformerSelfAttention,
        )

    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_tf_longformer import (
            TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFLongformerForMaskedLM,
            TFLongformerForMultipleChoice,
            TFLongformerForQuestionAnswering,
            TFLongformerForSequenceClassification,
            TFLongformerForTokenClassification,
            TFLongformerModel,
            TFLongformerPreTrainedModel,
            TFLongformerSelfAttention,
        )

else:
    import sys

    # At runtime this module object is replaced by a `_LazyModule` built from `_import_structure`.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\longt5\configuration_longt5.py
""" LongT5 model configuration"""
from typing import Mapping
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxSeq2SeqConfigWithPast
from ...utils import logging
# Module-level logger for the LongT5 configuration module.
logger = logging.get_logger(__name__)
# URL of the `config.json` file for each listed LongT5 checkpoint.
LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google/long-t5-local-base": "https://huggingface.co/google/long-t5-local-base/blob/main/config.json",
"google/long-t5-local-large": "https://huggingface.co/google/long-t5-local-large/blob/main/config.json",
"google/long-t5-tglobal-base": "https://huggingface.co/google/long-t5-tglobal-base/blob/main/config.json",
"google/long-t5-tglobal-large": "https://huggingface.co/google/long-t5-tglobal-large/blob/main/config.json",
}
class LongT5Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`LongT5Model`] or a [`FlaxLongT5Model`]. It is
    used to instantiate a LongT5 model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the LongT5
    [google/long-t5-local-base](https://huggingface.co/google/long-t5-local-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    """

    # Model type identifier.
    model_type = "longt5"
    # Output keys to ignore at inference time.
    keys_to_ignore_at_inference = ["past_key_values"]
    # Aliases from common attribute names to the attribute names stored on this config.
    attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}

    def __init__(
        self,
        vocab_size=32128,  # vocabulary size
        d_model=512,  # hidden size
        d_kv=64,  # key/value dimension
        d_ff=2048,  # inner dimension of the feed-forward layer
        num_layers=6,  # number of layers
        num_decoder_layers=None,  # number of decoder layers; `None` falls back to `num_layers`
        num_heads=8,  # number of attention heads
        local_radius=127,  # radius of the local attention
        global_block_size=16,  # global block size
        relative_attention_num_buckets=32,  # number of relative-attention buckets
        relative_attention_max_distance=128,  # maximum distance for relative attention
        dropout_rate=0.1,  # dropout rate
        layer_norm_epsilon=1e-6,  # epsilon of the layer normalization
        initializer_factor=1.0,  # initialization factor
        feed_forward_proj="relu",  # feed-forward activation, `{ACT_FN}` or `gated-{ACT_FN}`
        is_encoder_decoder=True,  # whether the model is an encoder-decoder
        encoder_attention_type="local",  # type of encoder attention
        use_cache=True,  # whether to use the cache
        pad_token_id=0,  # id of the padding token
        eos_token_id=1,  # id of the end-of-sequence token
        **kwargs,  # forwarded to `PretrainedConfig.__init__`
    ):
        """
        Store the given hyper-parameters on the configuration.

        Raises:
            ValueError: If `feed_forward_proj` is neither `{ACT_FN}` nor `gated-{ACT_FN}`.
        """
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_kv = d_kv
        self.d_ff = d_ff
        self.num_layers = num_layers
        # The decoder depth defaults to the encoder depth when it is not given explicitly.
        self.num_decoder_layers = num_decoder_layers if num_decoder_layers is not None else self.num_layers
        self.num_heads = num_heads
        self.local_radius = local_radius
        self.global_block_size = global_block_size
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance
        self.dropout_rate = dropout_rate
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_factor = initializer_factor
        self.feed_forward_proj = feed_forward_proj
        self.encoder_attention_type = encoder_attention_type
        self.use_cache = use_cache

        # Split `feed_forward_proj` into an optional "gated" prefix and the activation name.
        act_info = self.feed_forward_proj.split("-")
        self.dense_act_fn = act_info[-1]
        self.is_gated_act = act_info[0] == "gated"

        # Reject a two-part value whose prefix is not "gated", and any value with more than two parts.
        if (len(act_info) > 1 and act_info[0] != "gated") or len(act_info) > 2:
            raise ValueError(
                f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
                "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
                "'gated-gelu' or 'relu'"
            )

        # For backwards compatibility, "gated-gelu" selects the "gelu_new" activation.
        if feed_forward_proj == "gated-gelu":
            self.dense_act_fn = "gelu_new"

        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )
class LongT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
    """ONNX export configuration for LongT5, built on `OnnxSeq2SeqConfigWithPast`."""

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """
        Describe the model inputs as a mapping from input name to `{axis index: axis name}`.

        When `self.use_past` is set, the attention masks use the "past + sequence" axis names,
        `decoder_input_ids` only declares its batch axis, and `fill_with_past_key_values_` is called on the
        mapping before it is returned.
        """
        with_past = self.use_past

        if with_past:
            encoder_mask_axis = "past_encoder_sequence + sequence"
            decoder_ids_axes = {0: "batch"}
            decoder_mask_axes = {0: "batch", 1: "past_decoder_sequence + sequence"}
        else:
            encoder_mask_axis = "encoder_sequence"
            decoder_ids_axes = {0: "batch", 1: "decoder_sequence"}
            decoder_mask_axes = {0: "batch", 1: "decoder_sequence"}

        common_inputs = {
            "input_ids": {0: "batch", 1: "encoder_sequence"},
            "attention_mask": {0: "batch", 1: encoder_mask_axis},
            "decoder_input_ids": decoder_ids_axes,
            "decoder_attention_mask": decoder_mask_axes,
        }

        if with_past:
            self.fill_with_past_key_values_(common_inputs, direction="inputs")

        return common_inputs

    @property
    def default_onnx_opset(self) -> int:
        """`int`: ONNX opset version used by default (13)."""
        return 13
.\models\longt5\convert_longt5x_checkpoint_to_flax.py
import argparse
from t5x import checkpoints
from transformers import AutoConfig, FlaxAutoModelForSeq2SeqLM
def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path):
    """Copy weights from a T5X checkpoint into a Flax seq2seq model and save it.

    Args:
        t5x_checkpoint_path: Path passed to `checkpoints.load_t5x_checkpoint`.
        config_name: Name or path passed to `AutoConfig.from_pretrained`.
        flax_dump_folder_path: Directory passed to `flax_model.save_pretrained`.

    Raises:
        ValueError: If the config is neither `model_type='t5'` nor
            `model_type='longt5'` with `encoder_attention_type` equal to
            `'local'` or `'transient-global'`.
    """
    config = AutoConfig.from_pretrained(config_name)
    flax_model = FlaxAutoModelForSeq2SeqLM.from_config(config=config)
    t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)

    # NOTE(review): this flag is not consumed by any statement visible in this
    # function; the lookup is kept so a malformed checkpoint still fails early.
    split_mlp_wi = "wi_0" in t5x_model["target"]["encoder"]["layers_0"]["mlp"]

    # Pick the name of the encoder self-attention sub-module. This must be a
    # single if/elif chain: with two separate `if` statements a plain `t5`
    # config fell through to the `else` branch and raised ValueError.
    if config.model_type == "t5":
        encoder_attn_name = "SelfAttention"
    elif config.model_type == "longt5" and config.encoder_attention_type == "local":
        encoder_attn_name = "LocalSelfAttention"
    elif config.model_type == "longt5" and config.encoder_attention_type == "transient-global":
        encoder_attn_name = "TransientGlobalSelfAttention"
    else:
        raise ValueError(
            "Given config is expected to have `model_type='t5'`, or `model_type='longt5` with `encoder_attention_type`"
            " attribute with a value from ['local', 'transient-global]."
        )

    # Encoder relative position embedding (transposed relative to T5X layout).
    t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T
    flax_model.params["encoder"]["block"]["0"]["layer"]["0"][encoder_attn_name]["relative_attention_bias"][
        "embedding"
    ] = t5x_encoder_rel_embedding

    # Side/global relative position embedding, only for transient-global attention.
    if config.model_type == "longt5" and config.encoder_attention_type == "transient-global":
        t5x_encoder_global_rel_embedding = t5x_model["target"]["encoder"]["side_relpos_bias"]["rel_embedding"].T
        flax_model.params["encoder"]["block"]["0"]["layer"]["0"][encoder_attn_name]["global_relative_attention_bias"][
            "embedding"
        ] = t5x_encoder_global_rel_embedding

    # Final layer norms of encoder and decoder.
    t5x_encoder_norm = t5x_model["target"]["encoder"]["encoder_norm"]["scale"]
    flax_model.params["encoder"]["final_layer_norm"]["weight"] = t5x_encoder_norm

    tx5_decoder_norm = t5x_model["target"]["decoder"]["decoder_norm"]["scale"]
    flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm

    # Decoder relative position embedding.
    t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T
    flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][
        "embedding"
    ] = t5x_decoder_rel_embedding

    # Shared token embeddings.
    tx5_token_embeddings = t5x_model["target"]["token_embedder"]["embedding"]
    flax_model.params["shared"]["embedding"] = tx5_token_embeddings

    # The LM head is copied only when the checkpoint carries a `logits_dense` entry.
    if "logits_dense" in t5x_model["target"]["decoder"]:
        flax_model.params["lm_head"]["kernel"] = t5x_model["target"]["decoder"]["logits_dense"]["kernel"]

    flax_model.save_pretrained(flax_dump_folder_path)
    print("T5X Model was sucessfully converted!")
if __name__ == "__main__":
    # Command-line entry point: three required string options, forwarded
    # positionally to `convert_t5x_checkpoint_to_flax`.
    cli = argparse.ArgumentParser()
    for flag, description in (
        ("--t5x_checkpoint_path", "Path the T5X checkpoint."),
        ("--config_name", "Config name of LongT5/T5 model."),
        ("--flax_dump_folder_path", "Path to the output FLAX model."),
    ):
        cli.add_argument(flag, default=None, type=str, required=True, help=description)

    options = cli.parse_args()
    convert_t5x_checkpoint_to_flax(
        options.t5x_checkpoint_path,
        options.config_name,
        options.flax_dump_folder_path,
    )
.\models\longt5\modeling_flax_longt5.py
import copy
from typing import Any, Callable, List, Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxCausalLMOutputWithCrossAttentions,
FlaxSeq2SeqLMOutput,
FlaxSeq2SeqModelOutput,
)
from ...modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_call_sample_docstring,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_longt5 import LongT5Config
# Module-level logger obtained from the library's logging utilities.
logger = logging.get_logger(__name__)
# Checkpoint / config names; presumably consumed by the docstring helpers
# imported above -- confirm at their use sites.
_CHECKPOINT_FOR_DOC = "google/long-t5-local-base"
_CONFIG_FOR_DOC = "LongT5Config"
# Short alias for `flax.linen.partitioning.remat`.
remat = nn_partitioning.remat
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """
    Shift input ids one token to the right.

    The first column becomes `decoder_start_token_id`, the last input column is
    dropped, and any `-100` entry in the result is replaced by `pad_token_id`.
    """
    start_column = jnp.full_like(input_ids[:, :1], decoder_start_token_id)
    shifted = jnp.concatenate([start_column, input_ids[:, :-1]], axis=1)
    return jnp.where(shifted == -100, pad_token_id, shifted)
def _pad_to_multiple(x: jnp.ndarray, block_len: int, axis: int, pad_value: int = 0) -> jnp.ndarray:
"""将数组填充到长度为block_len的倍数"""
pad_len = -x.shape[axis] % block_len
pad = [(0, 0)] * x.ndim
pad[axis] = (0, pad_len)
x = jnp.pad(x, pad_width=pad, mode="constant", constant_values=pad_value)
return x
def _split_into_blocks(x: jnp.ndarray, block_len: int, axis: int) -> jnp.ndarray:
"""沿着指定轴将输入数组分割成指定长度的块"""
if x.shape[axis] % block_len != 0:
x = _pad_to_multiple(x, block_len, axis, pad_value=0)
num_blocks = x.shape[axis] // block_len
output_shape = x.shape[:axis] + (num_blocks, block_len) + x.shape[(axis + 1):]
return x.reshape(output_shape)
def _concatenate_3_blocks(x: jnp.ndarray, block_axis: int, sequence_axis: int, pad_value: int = 0) -> jnp.ndarray:
"""Concatenate three consecutive blocks for each input block for local attentiont.
For more information, see: https://arxiv.org/pdf/2112.07916.pdf.
"""
num_blocks = x.shape[block_axis]
pad = [(0, 0)] * x.ndim
pad[block_axis] = (1, 1)
x = jnp.pad(x, pad_width=pad, mode="constant", constant_values=pad_value)
blocks_list: List[np.array] = []
for i in range(3):
indices = [slice(0, None)] * x.ndim
indices[block_axis] = slice(i, i + num_blocks)
indices = tuple(indices)
blocks_list.append(x[indices])
return jnp.concatenate(blocks_list, axis=sequence_axis)
def _make_3block_relative_position_ids(block_len: int) -> jnp.ndarray:
"""Makes 3-blocked relative position ids for local attention."""
position_ids = jnp.arange(3 * block_len, dtype=jnp.int32)
center_position_ids = position_ids[block_len:-block_len]
relative_position_ids = position_ids[None, :] - center_position_ids[:, None]
return relative_position_ids
def _mask_local_attention_mask(local_attention_mask: np.ndarray, block_len: int) -> jnp.ndarray:
    """Mask local attention mask to enforce that tokens are not allowed to attend tokens farther than ``local_radius``.

    Positions whose absolute 3-block relative distance is `>= block_len` are
    switched off; the result is the logical AND with `local_attention_mask`.
    """
    distances = jnp.abs(_make_3block_relative_position_ids(block_len))
    within_radius = (distances < block_len)[None, None, :, :]
    return jnp.logical_and(local_attention_mask, within_radius)
def _get_local_attention_mask(attention_mask: np.ndarray, block_len: int) -> jnp.ndarray:
    """Prepare attention mask to be applied for a local attention.

    The mask is split into blocks along axis 1, each block is paired with its
    3-block neighbourhood, the locality restriction is applied, and a new axis
    is inserted at position 1 of the result.
    """
    query_blocks = _split_into_blocks(attention_mask, block_len, axis=1)
    key_blocks = _concatenate_3_blocks(query_blocks, block_axis=1, sequence_axis=2)

    # Outer AND: query position (trailing new axis) against key position.
    pairwise = jnp.logical_and(query_blocks[..., None], key_blocks[..., None, :])
    pairwise = _mask_local_attention_mask(pairwise, block_len)
    return pairwise[:, None, ...]
def _make_global_fixed_block_ids(attention_mask: np.ndarray, global_block_size: int) -> Tuple[jnp.ndarray, np.ndarray]:
"""Make global fixed block ids for global attention."""
...
"""Obtain the "fixed block" global id corresponding to each input token.
This implementation is a simlified version of the original Flaxformr implementation adopted from:
https://github.com/google/flaxformer/blob/main/flaxformer/architectures/longt5/long_attention.py.
In our scenario, as we use this strategy only for a decoder, orphan tokens, i.e. those tokens which do not make for
the whole fixed block, are assigned to the preceding block.
Padding tokens from the original sequence are represented by -1.
"""
batch_size, seq_len = attention_mask.shape[:2]
def handle_orphan_tokens(block_ids: np.ndarray) -> jnp.ndarray:
block_ends = (jnp.arange(seq_len) % global_block_size) == global_block_size - 1
true_block_ends = jnp.logical_and(block_ends, block_ids >= 0)
full_blocks = true_block_ends.sum(-1)[..., None]
block_ids = jnp.minimum(block_ids, full_blocks - 1)
return block_ids
fixed_block_mask = jnp.ones_like(attention_mask) / global_block_size
fixed_block_mask = jnp.cumsum(fixed_block_mask, axis=1) - fixed_block_mask
mask = jnp.where(attention_mask != 0.0, 1.0, -1000.0)
global_block_ids = jnp.maximum(
jnp.floor(mask + fixed_block_mask - 1.0), jnp.array(-1.0, dtype=attention_mask.dtype)
)
global_block_ids = (global_block_ids * attention_mask) + (attention_mask - 1)
global_block_ids = handle_orphan_tokens(global_block_ids)
num_globals = seq_len // global_block_size
if num_globals > 0:
_sequence_block_ids_max = jnp.repeat(global_block_ids.max(axis=-1)[:, None], repeats=num_globals, axis=1)
else:
_sequence_block_ids_max = jnp.zeros((batch_size, 0), dtype=global_block_ids.dtype)
global_segment_ids = jnp.cumsum(jnp.ones((batch_size, num_globals)), axis=-1) - 1
global_segment_ids = jnp.where(global_segment_ids <= _sequence_block_ids_max, 1, 0)
return global_block_ids, global_segment_ids
def _make_side_relative_position_ids(attention_mask: np.ndarray, global_block_size: int) -> np.ndarray:
    """Return, for every token, the offset of each global slot from the token's block id.

    The result has one trailing axis of length equal to the number of global
    slots; entry `[..., i, g]` is `g - block_ids[..., i]`.
    """
    block_ids, global_segment_ids = _make_global_fixed_block_ids(attention_mask, global_block_size)
    num_global_slots = global_segment_ids.shape[-1]
    slot_positions = jnp.arange(num_global_slots)
    return slot_positions - block_ids[..., None]
def _create_global_aggregates(hidden_states: np.ndarray, block_ids: np.ndarray, global_seq_len: int) -> np.ndarray:
"""Compute individual block aggregates by summing over individual blocks."""
one_hot_block_ids = jax.nn.one_hot(block_ids, global_seq_len)
return jnp.einsum("...nd,...ng->...gd", hidden_states, one_hot_block_ids)
class FlaxLongT5LayerNorm(nn.Module):
    """Layer norm in the LongT5 style: scale only, no bias, no mean subtraction."""

    hidden_size: int
    dtype: jnp.dtype = jnp.float32
    eps: float = 1e-6
    weight_init: Callable[..., np.ndarray] = jax.nn.initializers.ones

    def setup(self):
        # One learned scale per hidden unit.
        self.weight = self.param("weight", self.weight_init, (self.hidden_size,))

    def __call__(self, hidden_states):
        """
        Divide `hidden_states` by the root of the mean square over the last
        axis (computed in float32, stabilised by `eps`) and apply the scale.
        """
        squared = jnp.power(hidden_states.astype("f4"), 2)
        mean_square = squared.mean(axis=-1, keepdims=True)
        normalized = hidden_states / jnp.sqrt(mean_square + self.eps)
        return self.weight * normalized
class FlaxLongT5DenseActDense(nn.Module):
    """Feed-forward block: dense -> activation -> dropout -> dense (no biases)."""

    config: LongT5Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config

        def projection(features, fan_in):
            # Bias-free dense layer initialised with std = factor * fan_in**-0.5.
            std = cfg.initializer_factor * (fan_in**-0.5)
            return nn.Dense(
                features,
                use_bias=False,
                kernel_init=jax.nn.initializers.normal(std),
                dtype=self.dtype,
            )

        self.wi = projection(cfg.d_ff, cfg.d_model)
        self.act = ACT2FN[cfg.dense_act_fn]
        self.dropout = nn.Dropout(cfg.dropout_rate)
        self.wo = projection(cfg.d_model, cfg.d_ff)

    def __call__(self, hidden_states, deterministic=True):
        """Apply the feed-forward block; `deterministic` is forwarded to dropout."""
        activated = self.act(self.wi(hidden_states))
        activated = self.dropout(activated, deterministic=deterministic)
        return self.wo(activated)
class FlaxLongT5DenseGatedActDense(nn.Module):
    """Gated feed-forward block: `act(wi_0(x)) * wi_1(x)` -> dropout -> `wo`."""

    config: LongT5Config
    # Declared explicitly: `setup` reads `self.dtype` and FlaxLongT5LayerFF
    # constructs this module with `dtype=...`; without the field both fail.
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
        wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)

        self.wi_0 = nn.Dense(
            self.config.d_ff,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(wi_init_std),
            dtype=self.dtype,
        )
        self.wi_1 = nn.Dense(
            self.config.d_ff,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(wi_init_std),
            dtype=self.dtype,
        )
        self.wo = nn.Dense(
            self.config.d_model,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(wo_init_std),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(self.config.dropout_rate)
        self.act = ACT2FN[self.config.dense_act_fn]

    def __call__(self, hidden_states, deterministic):
        """Apply the gated feed-forward block; `deterministic` is forwarded to dropout."""
        hidden_gelu = self.act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        # Element-wise gate: activated branch times linear branch.
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.wo(hidden_states)
        return hidden_states
class FlaxLongT5LayerFF(nn.Module):
    """Pre-norm residual feed-forward layer."""

    config: LongT5Config
    # Declared explicitly: `setup` forwards `self.dtype` to its sub-modules,
    # which fails when the field is missing.
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Gated or plain feed-forward, selected by `config.is_gated_act`.
        if self.config.is_gated_act:
            self.DenseReluDense = FlaxLongT5DenseGatedActDense(self.config, dtype=self.dtype)
        else:
            self.DenseReluDense = FlaxLongT5DenseActDense(self.config, dtype=self.dtype)

        self.layer_norm = FlaxLongT5LayerNorm(
            self.config.d_model,
            eps=self.config.layer_norm_epsilon,
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(self.config.dropout_rate)

    def __call__(self, hidden_states, deterministic=True):
        """Return `hidden_states + dropout(ff(layer_norm(hidden_states)))`."""
        forwarded_states = self.layer_norm(hidden_states)
        forwarded_states = self.DenseReluDense(forwarded_states, deterministic=deterministic)
        hidden_states = hidden_states + self.dropout(forwarded_states, deterministic=deterministic)
        return hidden_states
class FlaxLongT5Attention(nn.Module):
    """Multi-head attention with optional relative position bias and decoding cache."""

    config: LongT5Config
    has_relative_attention_bias: bool = False
    causal: bool = False
    # Declared explicitly: `setup` and `_create_position_bias` read
    # `self.dtype`, which fails when the field is missing.
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
        self.relative_attention_max_distance = self.config.relative_attention_max_distance
        self.d_model = self.config.d_model
        self.key_value_proj_dim = self.config.d_kv
        self.n_heads = self.config.num_heads
        self.dropout = self.config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
        kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
        o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)

        def _dense(features, std):
            # Bias-free projection with a normal initialiser of the given std.
            return nn.Dense(
                features,
                use_bias=False,
                kernel_init=jax.nn.initializers.normal(std),
                dtype=self.dtype,
            )

        self.q = _dense(self.inner_dim, q_init_std)
        self.k = _dense(self.inner_dim, kv_init_std)
        self.v = _dense(self.inner_dim, kv_init_std)
        self.o = _dense(self.d_model, o_init_std)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embed(
                self.relative_attention_num_buckets,
                self.n_heads,
                embedding_init=jax.nn.initializers.normal(kv_init_std),
                dtype=self.dtype,
            )

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on
        """
        relative_buckets = 0
        if bidirectional:
            # Half of the buckets for each sign of the relative position.
            num_buckets //= 2
            relative_buckets += (relative_position > 0) * num_buckets
            relative_position = jnp.abs(relative_position)
        else:
            # Same as clipping at an upper bound of 0; `jnp.minimum` avoids the
            # deprecated `a_max` keyword of `jnp.clip`.
            relative_position = -jnp.minimum(relative_position, 0)

        # Below `max_exact` every distance has its own bucket.
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # Above `max_exact`, bucket width grows logarithmically up to `max_distance`.
        relative_position_if_large = max_exact + (
            jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
        )
        relative_position_if_large = jnp.minimum(relative_position_if_large, num_buckets - 1)

        relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets.astype("i4")

    def compute_bias(self, query_length, key_length):
        """Compute binned relative position bias"""
        context_position = jnp.arange(query_length, dtype="i4")[:, None]
        memory_position = jnp.arange(key_length, dtype="i4")[None, :]
        relative_position = memory_position - context_position

        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=(not self.causal),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )

        values = self.relative_attention_bias(relative_position_bucket)
        # Move the head axis first and add a leading broadcast axis.
        values = values.transpose((2, 0, 1))[None, :, :, :]
        return values

    def _split_heads(self, hidden_states):
        """Reshape the last axis into `(n_heads, key_value_proj_dim)`, keeping the first two axes."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))

    def _merge_heads(self, hidden_states):
        """Collapse everything after the first two axes into `inner_dim`."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,))

    @nn.compact
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # Checked before the variables are declared below, so it is False on
        # the call that creates the cache.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # Write the new key/value slices at the current cache position.
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = jax.lax.dynamic_update_slice(cached_key.value, key, indices)
            value = jax.lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # Only positions written so far may be attended to; the remaining
            # cache slots are masked out.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask

    def _create_position_bias(
        self, key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift
    ):
        """Build the position bias: learned bias if available, otherwise zeros.

        When the module is causal, a cache already exists and `init_cache` is
        false, the bias is computed for the full key length and then sliced at
        `causal_attention_mask_shift` for `seq_length` query rows.
        """
        cache_is_filled = self.causal and self.has_variable("cache", "cached_key") and (not init_cache)
        key_length = key_states.shape[1]
        query_length = key_length if cache_is_filled else query_states.shape[1]

        if self.has_relative_attention_bias:
            position_bias = self.compute_bias(query_length, key_length)
        elif attention_mask is not None:
            position_bias = jnp.zeros_like(attention_mask)
        else:
            position_bias = jnp.zeros((1, self.n_heads, query_length, key_length), dtype=self.dtype)

        if cache_is_filled:
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            position_bias = jax.lax.dynamic_slice(
                position_bias,
                (0, 0, causal_attention_mask_shift, 0),
                (1, self.n_heads, seq_length, max_decoder_length),
            )
        return position_bias
def __call__(
self,
hidden_states,
attention_mask=None,
key_value_states=None,
position_bias=None,
use_cache=False,
output_attentions=False,
deterministic=True,
init_cache=False,
class FlaxLongT5LocalAttention(nn.Module):
    """Block-local attention: each block attends to a 3-block window around it."""

    config: LongT5Config
    has_relative_attention_bias: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
        self.relative_attention_max_distance = self.config.relative_attention_max_distance
        self.d_model = self.config.d_model
        self.key_value_proj_dim = self.config.d_kv
        self.n_heads = self.config.num_heads
        self.local_radius = self.config.local_radius
        self.block_len = self.local_radius + 1
        self.dropout = self.config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
        kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
        o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)

        def _dense(features, std):
            # Bias-free projection with a normal initialiser of the given std.
            return nn.Dense(
                features,
                use_bias=False,
                kernel_init=jax.nn.initializers.normal(std),
                dtype=self.dtype,
            )

        self.q = _dense(self.inner_dim, q_init_std)
        self.k = _dense(self.inner_dim, kv_init_std)
        self.v = _dense(self.inner_dim, kv_init_std)
        self.o = _dense(self.d_model, o_init_std)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embed(
                self.relative_attention_num_buckets,
                self.n_heads,
                embedding_init=jax.nn.initializers.normal(kv_init_std),
            )

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on
        """
        # A single definition: a duplicated outer signature wrapping this body
        # made the static method return None to `compute_bias`.
        relative_buckets = 0
        if bidirectional:
            # Half of the buckets for each sign of the relative position.
            num_buckets //= 2
            relative_buckets += (relative_position > 0) * num_buckets
            relative_position = jnp.abs(relative_position)
        else:
            # Same as clipping at an upper bound of 0; `jnp.minimum` avoids the
            # deprecated `a_max` keyword of `jnp.clip`.
            relative_position = -jnp.minimum(relative_position, 0)

        # Below `max_exact` every distance has its own bucket.
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # Above `max_exact`, bucket width grows logarithmically up to `max_distance`.
        relative_position_if_large = max_exact + (
            jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
        )
        relative_position_if_large = jnp.minimum(relative_position_if_large, num_buckets - 1)

        relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets.astype("i4")

    def compute_bias(self, block_length: int):
        """Compute binned relative position bias"""
        memory_position = jnp.arange(3 * block_length, dtype="i4")
        # Query positions are the middle block of the 3-block window.
        context_position = memory_position[block_length:-block_length]
        relative_position = memory_position[None, :] - context_position[:, None]

        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )

        values = self.relative_attention_bias(relative_position_bucket)
        # Move the head axis first and add two leading broadcast axes.
        values = values.transpose((2, 0, 1))[None, None, :, :, :]
        return values

    def _split_heads(self, hidden_states):
        """Reshape the last axis into `(n_heads, key_value_proj_dim)`, keeping the first two axes."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))

    def _merge_heads(self, hidden_states):
        """Reshape to `(batch, -1, inner_dim)`."""
        return hidden_states.reshape(hidden_states.shape[0], -1, self.inner_dim)

    def _create_position_bias(self, block_len: int, attention_mask: Optional[np.ndarray]) -> np.ndarray:
        """Return the learned bias if available, else zeros shaped like the mask or the default local shape."""
        if self.has_relative_attention_bias:
            position_bias = self.compute_bias(block_len)
        elif attention_mask is not None:
            position_bias = jnp.zeros_like(attention_mask)
        else:
            position_bias = jnp.zeros((1, 1, self.n_heads, block_len, 3 * block_len), dtype=self.dtype)
        return position_bias
def __call__(
self,
hidden_states,
attention_mask=None,
key_value_states=None,
position_bias=None,
output_attentions=False,
deterministic=True,
class FlaxLongT5TransientGlobalAttention(nn.Module):
    """Block-local attention extended with a side bias towards global block slots."""

    config: LongT5Config
    has_relative_attention_bias: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
        self.relative_attention_max_distance = self.config.relative_attention_max_distance
        self.d_model = self.config.d_model
        self.key_value_proj_dim = self.config.d_kv
        self.n_heads = self.config.num_heads
        self.local_radius = self.config.local_radius
        self.block_len = self.local_radius + 1
        self.global_block_size = self.config.global_block_size
        self.dropout = self.config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
        kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
        o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)

        def _dense(features, std):
            # Bias-free projection with a normal initialiser of the given std.
            return nn.Dense(
                features,
                use_bias=False,
                kernel_init=jax.nn.initializers.normal(std),
                dtype=self.dtype,
            )

        self.q = _dense(self.inner_dim, q_init_std)
        self.k = _dense(self.inner_dim, kv_init_std)
        self.v = _dense(self.inner_dim, kv_init_std)
        self.o = _dense(self.d_model, o_init_std)

        if self.has_relative_attention_bias:
            # Local relative attention bias.
            self.relative_attention_bias = nn.Embed(
                self.relative_attention_num_buckets,
                self.n_heads,
                embedding_init=jax.nn.initializers.normal(kv_init_std),
            )
            # Relative attention bias used by `compute_side_bias`.
            self.global_relative_attention_bias = nn.Embed(
                self.relative_attention_num_buckets,
                self.n_heads,
                embedding_init=jax.nn.initializers.normal(kv_init_std),
            )
        self.global_input_layer_norm = FlaxLongT5LayerNorm(
            self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
        )

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on
        """
        # A single definition: a duplicated outer signature wrapping this body
        # made the static method return None to its callers.
        relative_buckets = 0
        if bidirectional:
            # Half of the buckets for each sign of the relative position.
            num_buckets //= 2
            relative_buckets += (relative_position > 0) * num_buckets
            relative_position = jnp.abs(relative_position)
        else:
            # Same as clipping at an upper bound of 0; `jnp.minimum` avoids the
            # deprecated `a_max` keyword of `jnp.clip`.
            relative_position = -jnp.minimum(relative_position, 0)

        # Below `max_exact` every distance has its own bucket.
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # Above `max_exact`, bucket width grows logarithmically up to `max_distance`.
        relative_position_if_large = max_exact + (
            jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
        )
        relative_position_if_large = jnp.minimum(relative_position_if_large, num_buckets - 1)

        relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets.astype("i4")

    def compute_bias(self, block_length: int):
        """Compute binned relative position bias"""
        memory_position = jnp.arange(3 * block_length, dtype="i4")
        # Query positions are the middle block of the 3-block window.
        context_position = memory_position[block_length:-block_length]
        relative_position = memory_position[None, :] - context_position[:, None]

        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )

        values = self.relative_attention_bias(relative_position_bucket)
        # Move the head axis first and add two leading broadcast axes.
        values = values.transpose((2, 0, 1))[None, None, :, :, :]
        return values

    def compute_side_bias(self, attention_mask: np.ndarray, global_segment_ids: np.ndarray) -> np.ndarray:
        """Return the additive bias for attending from tokens to global slots.

        The bias is the sum of a mask term (0.0 where the token's mask value
        equals the slot's segment id, -1e10 elsewhere) and the learned
        `global_relative_attention_bias` looked up by bucketed side-relative
        position.
        """
        side_attention_mask = jnp.equal(attention_mask[..., None], global_segment_ids[:, None, :])[:, None, ...]
        attention_side_bias = jax.lax.select(
            side_attention_mask > 0,
            jnp.full(side_attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(side_attention_mask.shape, -1e10).astype(self.dtype),
        )

        side_relative_position = _make_side_relative_position_ids(attention_mask, self.global_block_size)
        side_relative_position_bucket = self._relative_position_bucket(
            side_relative_position,
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        side_bias = self.global_relative_attention_bias(side_relative_position_bucket)

        # Bring the head axis (last) to position 1 so it lines up with the mask term.
        side_bias = jnp.transpose(side_bias, (0, 3, 1, 2))
        attention_side_bias = attention_side_bias + side_bias
        return attention_side_bias

    def _split_heads(self, hidden_states):
        """Reshape the last axis into `(n_heads, key_value_proj_dim)`, keeping the first two axes."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))

    def _merge_heads(self, hidden_states):
        """Reshape to `(batch, -1, inner_dim)`."""
        return hidden_states.reshape(hidden_states.shape[0], -1, self.inner_dim)

    def _create_position_bias(self, block_len: int, attention_mask: Optional[np.ndarray]) -> np.ndarray:
        """Return the learned bias if available, else zeros shaped like the mask or the default local shape."""
        if self.has_relative_attention_bias:
            position_bias = self.compute_bias(block_len)
        elif attention_mask is not None:
            position_bias = jnp.zeros_like(attention_mask)
        else:
            position_bias = jnp.zeros((1, 1, self.n_heads, block_len, 3 * block_len), dtype=self.dtype)
        return position_bias
def __call__(
self,
hidden_states,
attention_mask=None,
key_value_states=None,
position_bias=None,
output_attentions=False,
deterministic=True,
class FlaxLongT5LayerLocalSelfAttention(nn.Module):
    """Local self attention used in encoder"""

    config: LongT5Config
    has_relative_attention_bias: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        self.LocalSelfAttention = FlaxLongT5LocalAttention(
            cfg, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype
        )
        self.layer_norm = FlaxLongT5LayerNorm(cfg.d_model, eps=cfg.layer_norm_epsilon, dtype=self.dtype)
        self.dropout = nn.Dropout(cfg.dropout_rate)

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        output_attentions=False,
        deterministic=True,
        **kwargs: Any,
    ):
        """Pre-norm attention with a residual connection.

        Returns a tuple whose first item is the updated hidden states, followed
        by whatever the attention module returned after its first item. Extra
        keyword arguments are accepted and ignored.
        """
        attn_result = self.LocalSelfAttention(
            self.layer_norm(hidden_states),
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            deterministic=deterministic,
        )
        residual_out = hidden_states + self.dropout(attn_result[0], deterministic=deterministic)
        return (residual_out,) + attn_result[1:]
class FlaxLongT5LayerTransientGlobalSelfAttention(nn.Module):
    """Transient-global self-attention sub-layer used in the encoder.

    Applies layer norm, transient-global attention and dropout, then adds the result back onto the input.
    """

    config: LongT5Config
    has_relative_attention_bias: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        self.TransientGlobalSelfAttention = FlaxLongT5TransientGlobalAttention(
            cfg, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype
        )
        self.layer_norm = FlaxLongT5LayerNorm(cfg.d_model, eps=cfg.layer_norm_epsilon, dtype=self.dtype)
        self.dropout = nn.Dropout(cfg.dropout_rate)

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        output_attentions=False,
        deterministic=True,
        **kwargs: Any,
    ):
        """Run the sub-layer.

        Extra keyword arguments are accepted and ignored. Returns a tuple whose first item is the
        updated hidden states, followed by whatever the attention module returned after its first item.
        """
        attn_results = self.TransientGlobalSelfAttention(
            self.layer_norm(hidden_states),
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            deterministic=deterministic,
        )
        # Residual connection around the (dropped-out) attention output.
        residual_sum = hidden_states + self.dropout(attn_results[0], deterministic=deterministic)
        return (residual_sum,) + attn_results[1:]
class FlaxLongT5LayerSelfAttention(nn.Module):
    """Self-attention sub-layer built on `FlaxLongT5Attention`.

    Applies layer norm, attention and dropout, then adds the result back onto the input. Whether the
    attention is causal is taken from `config.causal`.
    """

    config: LongT5Config
    has_relative_attention_bias: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        self.SelfAttention = FlaxLongT5Attention(
            cfg,
            has_relative_attention_bias=self.has_relative_attention_bias,
            causal=cfg.causal,
            dtype=self.dtype,
        )
        self.layer_norm = FlaxLongT5LayerNorm(cfg.d_model, eps=cfg.layer_norm_epsilon, dtype=self.dtype)
        self.dropout = nn.Dropout(cfg.dropout_rate)

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        output_attentions=False,
        deterministic=True,
        init_cache=False,
    ):
        """Run the sub-layer.

        Returns a tuple whose first item is the updated hidden states, followed by whatever the
        attention module returned after its first item.
        """
        attn_results = self.SelfAttention(
            self.layer_norm(hidden_states),
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            deterministic=deterministic,
            init_cache=init_cache,
        )
        # Residual connection around the (dropped-out) attention output.
        residual_sum = hidden_states + self.dropout(attn_results[0], deterministic=deterministic)
        return (residual_sum,) + attn_results[1:]
class FlaxLongT5LayerCrossAttention(nn.Module):
    """Encoder-decoder (cross) attention sub-layer.

    Applies layer norm, non-causal attention over `key_value_states` and dropout, then adds the result
    back onto the input.
    """

    config: LongT5Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.EncDecAttention = FlaxLongT5Attention(
            self.config, has_relative_attention_bias=False, causal=False, dtype=self.dtype
        )
        self.layer_norm = FlaxLongT5LayerNorm(
            self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
        )
        self.dropout = nn.Dropout(self.config.dropout_rate)

    def __call__(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        output_attentions=False,
        deterministic=True,
    ):
        """Run the sub-layer.

        Args:
            hidden_states: Query-side input.
            key_value_states: States the attention module uses for keys and values.
            attention_mask: Optional mask forwarded to the attention module.
            position_bias: Optional position bias forwarded to the attention module.
            output_attentions: Forwarded to the attention module.
            deterministic: Forwarded to both the attention module and the residual dropout.

        Returns:
            A tuple whose first item is the updated hidden states, followed by whatever the attention
            module returned after its first item.
        """
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.EncDecAttention(
            normed_hidden_states,
            attention_mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            output_attentions=output_attentions,
            # Forward the flag so the attention module follows the same train/eval mode as the
            # dropout below; it was previously dropped, leaving the module on its default.
            deterministic=deterministic,
        )
        hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
        outputs = (hidden_states,) + attention_output[1:]
        return outputs
class FlaxLongT5Block(nn.Module):
    """One LongT5 block: self-attention, optional cross-attention (causal only), then feed-forward.

    Sub-layers are stored in the tuple `self.layer` and named "0", "1", ... in that order.
    """

    config: LongT5Config
    has_relative_attention_bias: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.causal = self.config.causal

        # Causal blocks use full self-attention; otherwise the configured encoder attention type decides.
        if self.causal:
            self_attention_cls = FlaxLongT5LayerSelfAttention
        else:
            if self.config.encoder_attention_type == "local":
                self_attention_cls = FlaxLongT5LayerLocalSelfAttention
            elif self.config.encoder_attention_type == "transient-global":
                self_attention_cls = FlaxLongT5LayerTransientGlobalSelfAttention
            else:
                raise ValueError(
                    "For encoder attention mechanism, either `local` or `transient-global` attention type is expected, "
                    f"but got {self.config.encoder_attention_type}."
                )

        self.layer = (
            self_attention_cls(
                self.config,
                has_relative_attention_bias=self.has_relative_attention_bias,
                name=str(0),
                dtype=self.dtype,
            ),
        )

        # The feed-forward layer comes last; a cross-attention layer pushes its index from 1 to 2.
        ff_position = 1
        if self.causal:
            self.layer += (FlaxLongT5LayerCrossAttention(self.config, name=str(1), dtype=self.dtype),)
            ff_position += 1

        self.layer += (FlaxLongT5LayerFF(self.config, name=str(ff_position), dtype=self.dtype),)

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        output_attentions=False,
        return_dict=True,
        deterministic=True,
        init_cache=False,
    ):
        """Run the block.

        `return_dict` is accepted but not used. Returns a tuple: the new hidden states, followed by the
        extra outputs of the self-attention layer and, when cross-attention ran, those of the
        cross-attention layer.
        """
        self_attn_results = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            deterministic=deterministic,
            init_cache=init_cache,
        )
        hidden_states = self_attn_results[0]
        extras = self_attn_results[1:]

        # Cross-attention only runs for causal blocks that were handed encoder states.
        if self.causal and encoder_hidden_states is not None:
            cross_attn_results = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                output_attentions=output_attentions,
                deterministic=deterministic,
            )
            hidden_states = cross_attn_results[0]
            extras = extras + cross_attn_results[1:]

        # The feed-forward layer is always the last entry of `self.layer`.
        hidden_states = self.layer[-1](hidden_states, deterministic=deterministic)

        return (hidden_states,) + extras
class FlaxLongT5LayerCollection(nn.Module):
    """Thin wrapper that owns a single `FlaxLongT5Block` under the attribute `layer`."""

    config: LongT5Config
    has_relative_attention_bias: bool
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layer = FlaxLongT5Block(
            self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype
        )

    # NOTE: the parameter order is part of the contract; FlaxLongT5BlockCollection wraps this class with
    # `remat(..., static_argnums=...)` and calls it positionally.
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        output_attentions=False,
        deterministic=True,
        init_cache=False,
    ):
        """Forward every argument to the wrapped block and return its result unchanged."""
        block_kwargs = {
            "attention_mask": attention_mask,
            "position_bias": position_bias,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_attention_mask": encoder_attention_mask,
            "encoder_decoder_position_bias": encoder_decoder_position_bias,
            "output_attentions": output_attentions,
            "deterministic": deterministic,
            "init_cache": init_cache,
        }
        return self.layer(hidden_states, **block_kwargs)
class FlaxLongT5BlockCollection(nn.Module):
    """Stack of `config.num_layers` LongT5 layers applied in sequence.

    Only the first layer is created with `has_relative_attention_bias=True`; the position biases it
    returns are fed to the following layers. With `gradient_checkpointing` enabled, every layer is
    wrapped with `remat`.
    """

    config: LongT5Config
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.causal = self.config.causal

        # Pick the layer class once instead of duplicating the whole list construction.
        if self.gradient_checkpointing:
            layer_cls = remat(FlaxLongT5LayerCollection, static_argnums=(6, 7, 8))
        else:
            layer_cls = FlaxLongT5LayerCollection

        self.blocks = [
            layer_cls(
                self.config,
                has_relative_attention_bias=(i == 0),
                dtype=self.dtype,
                name=str(i),
            )
            for i in range(self.config.num_layers)
        ]

    def __call__(
        self,
        hidden_states=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        deterministic: bool = True,
        init_cache: bool = False,
    ):
        """Run all layers and collect the requested intermediate outputs.

        Returns:
            A `FlaxBaseModelOutputWithPastAndCrossAttentions` holding the final hidden states and,
            when requested, the per-layer input hidden states, self-attention outputs and (for
            causal stacks) cross-attention outputs.
        """
        # Optional accumulators stay `None` unless the corresponding output was requested.
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.causal) else None

        # Both biases start empty and are replaced by what each layer returns.
        position_bias = None
        encoder_decoder_position_bias = None

        for layer_module in self.blocks:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # Arguments are passed positionally because the remat wrapper selects static
            # arguments by position.
            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                position_bias,
                encoder_hidden_states,
                encoder_attention_mask,
                encoder_decoder_position_bias,
                output_attentions,
                deterministic,
                init_cache,
            )

            hidden_states = layer_outputs[0]

            # Layer output layout: hidden states, position bias, [self-attention output],
            # [encoder-decoder position bias, [cross-attention output]].
            position_bias = layer_outputs[1]

            if self.causal and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[2],)
                if self.causal:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[4],)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
class FlaxLongT5Stack(nn.Module):
    """Token embedding, block collection, final layer norm and dropout.

    The same class is used for the encoder and the decoder; `config.causal` is stored on the instance
    and passed on through the config.
    """

    config: LongT5Config
    embed_tokens: nn.Embed
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        cfg = self.config
        self.causal = cfg.causal
        self.block = FlaxLongT5BlockCollection(
            cfg, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        self.final_layer_norm = FlaxLongT5LayerNorm(cfg.d_model, eps=cfg.layer_norm_epsilon, dtype=self.dtype)
        self.dropout = nn.Dropout(cfg.dropout_rate)

    def __call__(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
        init_cache: bool = False,
    ):
        """Embed `input_ids`, run the block collection and normalise the result.

        Returns a `FlaxBaseModelOutputWithPastAndCrossAttentions` when `return_dict` is true, otherwise
        a plain tuple that starts with the final hidden states.
        """
        embedded = self.embed_tokens(input_ids)
        embedded = self.dropout(embedded, deterministic=deterministic)

        outputs = self.block(
            embedded,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            deterministic=deterministic,
            init_cache=init_cache,
        )

        hidden_states = self.final_layer_norm(outputs[0])
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)

        # The collected per-layer states do not include the final normalised output; append it here.
        all_hidden_states = None
        if output_hidden_states:
            all_hidden_states = outputs.hidden_states + (hidden_states,)

        if return_dict:
            return FlaxBaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=hidden_states,
                hidden_states=all_hidden_states,
                attentions=outputs.attentions,
                cross_attentions=outputs.cross_attentions,
            )

        if output_hidden_states:
            return (hidden_states, all_hidden_states) + outputs[2:]
        return (hidden_states,) + outputs[1:]
# Argument documentation for the encoder-only entry point; it is passed to `add_start_docstrings`
# on the `encode` method further down in this file.
# NOTE(review): everything between the triple quotes (including the line starting with `#`) is part
# of the string value, not a comment, and is therefore left untouched.
LONGT5_ENCODE_INPUTS_DOCSTRING = r"""
# 接收输入参数的函数定义,用于处理LongT5模型的输入数据
Args:
input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
输入序列标记在词汇表中的索引。LongT5模型具有相对位置嵌入,因此可以在序列的左右两侧进行填充。
可以使用[`AutoTokenizer`]获取索引。详见[`PreTrainedTokenizer.encode`]和[`PreTrainedTokenizer.__call__`]。
想要了解有关如何为预训练准备`input_ids`的更多信息,请查看[长T5训练](./longt5#training)。
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
遮罩,用于避免在填充标记索引上执行注意力操作。遮罩值选在 `[0, 1]` 之间:
- 对于**未被遮罩**的标记,值为1,
- 对于**被遮罩**的标记,值为0。
[什么是注意力遮罩?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
是否返回所有注意力层的注意力张量。查看返回的张量中的`attentions`以获取更多详细信息。
output_hidden_states (`bool`, *optional*):
是否返回所有层的隐藏状态。查看返回的张量中的`hidden_states`以获取更多详细信息。
return_dict (`bool`, *optional*):
是否返回[`~utils.ModelOutput`]而不是普通元组。
"""
# An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
# models.
class FlaxLongT5PreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
# 指定配置类为 LongT5Config
config_class = LongT5Config
# 基础模型前缀为 "transformer"
base_model_prefix = "transformer"
# 模块类默认为空
module_class: nn.Module = None
# Constructor: builds the Flax module from `module_class` and hands it to the base class.
def __init__(
    self,
    config: LongT5Config,
    input_shape: Tuple[int] = (1, 1),
    seed: int = 0,
    dtype: jnp.dtype = jnp.float32,
    _do_init: bool = True,
    **kwargs,
):
    """Create the model wrapper.

    Args:
        config: Model configuration, forwarded to the module and to the base class.
        input_shape: Forwarded to the base class.
        seed: Forwarded to the base class.
        dtype: Forwarded to the module and to the base class.
        _do_init: Forwarded to the base class.
        **kwargs: Extra keyword arguments for `module_class`.
    """
    # Any extra keyword arguments go to the module constructor only.
    flax_module = self.module_class(config=config, dtype=dtype, **kwargs)
    super().__init__(
        config,
        flax_module,
        input_shape=input_shape,
        seed=seed,
        dtype=dtype,
        _do_init=_do_init,
    )
# Rebuild the wrapped module with gradient checkpointing switched on.
def enable_gradient_checkpointing(self):
    """Replace `self._module` with a new `module_class` instance created with `gradient_checkpointing=True`."""
    module_kwargs = {
        "config": self.config,
        "dtype": self.dtype,
        "gradient_checkpointing": True,
    }
    self._module = self.module_class(**module_kwargs)
# Initialise parameters from `rng` with dummy inputs of shape `input_shape`.
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    """Create parameters, optionally only filling in keys missing from `params`.

    Args:
        rng: Key that is split into a "params" and a "dropout" stream.
        input_shape: Shape of the dummy integer inputs used for initialisation.
        params: Optional existing parameters. When given, only the keys listed in
            `self._missing_keys` are copied over from the freshly initialised parameters.

    Returns:
        The freshly initialised parameters when `params` is None, otherwise the completed `params`
        as a frozen dict.
    """
    # Dummy encoder and decoder inputs; every array shares `input_shape`.
    dummy_input_ids = jnp.zeros(input_shape, dtype="i4")
    dummy_attention_mask = jnp.ones_like(dummy_input_ids)
    dummy_decoder_input_ids = jnp.ones_like(dummy_input_ids)
    dummy_decoder_attention_mask = jnp.ones_like(dummy_input_ids)

    params_rng, dropout_rng = jax.random.split(rng)
    init_variables = self.module.init(
        {"params": params_rng, "dropout": dropout_rng},
        dummy_input_ids,
        dummy_attention_mask,
        dummy_decoder_input_ids,
        dummy_decoder_attention_mask,
    )
    random_params = init_variables["params"]

    if params is None:
        return random_params

    # Work on flat dicts so that missing entries can be copied key by key.
    flat_random = flatten_dict(unfreeze(random_params))
    flat_params = flatten_dict(unfreeze(params))
    for missing_key in self._missing_keys:
        flat_params[missing_key] = flat_random[missing_key]
    self._missing_keys = set()
    return freeze(unflatten_dict(flat_params))
# Forward pass of the full encoder-decoder model; the decorator receives LONGT5_INPUTS_DOCSTRING.
@add_start_docstrings_to_model_forward(LONGT5_INPUTS_DOCSTRING)
def __call__(
    self,
    input_ids: jnp.ndarray,
    attention_mask: Optional[jnp.ndarray] = None,
    decoder_input_ids: jnp.ndarray = None,
    decoder_attention_mask: Optional[jnp.ndarray] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    train: bool = False,
    params: dict = None,
    dropout_rng: PRNGKey = None,
):
    # Flags left as None fall back to the values stored on the config.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.return_dict

    # The decoder inputs are mandatory for a full forward pass.
    if decoder_input_ids is None:
        raise ValueError(
            "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed"
            " here."
        )

    # Missing masks default to all ones.
    if attention_mask is None:
        attention_mask = jnp.ones_like(input_ids)
    if decoder_attention_mask is None:
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)

    # Only hand a dropout stream to the module when the caller supplied one.
    rngs = {}
    if dropout_rng is not None:
        rngs["dropout"] = dropout_rng

    variables = {"params": params or self.params}
    return self.module.apply(
        variables,
        input_ids=jnp.array(input_ids, dtype="i4"),
        attention_mask=jnp.array(attention_mask, dtype="i4"),
        decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
        decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        deterministic=not train,  # dropout is active only in training mode
        rngs=rngs,
    )
def init_cache(self, batch_size, max_length, encoder_outputs):
    r"""
    Args:
        batch_size (`int`):
            batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
        max_length (`int`):
            maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
            cache.
        encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
            `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
            `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
            is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
            cross-attention of the decoder.
    """

    def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
        # Run only the decoder sub-module; this is all that is needed to create the cache.
        return module._get_decoder_module()(
            decoder_input_ids,
            decoder_attention_mask,
            **kwargs,
        )

    # Dummy decoder inputs of the requested cache size.
    dummy_decoder_ids = jnp.ones((batch_size, max_length), dtype="i4")
    dummy_decoder_mask = jnp.ones_like(dummy_decoder_ids)

    init_variables = self.module.init(
        jax.random.PRNGKey(0),
        decoder_input_ids=dummy_decoder_ids,
        decoder_attention_mask=dummy_decoder_mask,
        encoder_hidden_states=encoder_outputs[0],
        init_cache=True,
        method=_decoder_forward,
    )
    # Hand back the "cache" collection after passing it through `unfreeze`.
    cache_variables = init_variables["cache"]
    return unfreeze(cache_variables)
@add_start_docstrings(LONGT5_ENCODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=LongT5Config)
def encode(
    self,
    input_ids: jnp.ndarray,
    attention_mask: Optional[jnp.ndarray] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    train: bool = False,
    params: dict = None,
    dropout_rng: PRNGKey = None,
):
    r"""
    Returns:

    Example:

    ```
    >>> from transformers import AutoTokenizer, FlaxLongT5ForConditionalGeneration

    >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
    >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")

    >>> text = "My friends are cool but they eat too many carbs."
    >>> inputs = tokenizer(text, return_tensors="np")
    >>> encoder_outputs = model.encode(**inputs)
    ```"""

    def _encoder_forward(module, input_ids, attention_mask, **kwargs):
        # Dispatch to the encoder sub-module only.
        return module._get_encoder_module()(input_ids, attention_mask, **kwargs)

    # Flags left as None fall back to the values stored on the config.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.return_dict

    # A missing mask defaults to all ones.
    if attention_mask is None:
        attention_mask = jnp.ones_like(input_ids)

    # Only hand a dropout stream to the module when the caller supplied one.
    rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

    variables = {"params": params or self.params}
    return self.module.apply(
        variables,
        input_ids=jnp.array(input_ids, dtype="i4"),
        attention_mask=jnp.array(attention_mask, dtype="i4"),
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        deterministic=not train,  # dropout is active only in training mode
        rngs=rngs,
        method=_encoder_forward,
    )
@add_start_docstrings(LONGT5_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=LongT5Config)
def decode(
self,
decoder_input_ids,
encoder_outputs,
encoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
# Shared introduction text for the LongT5 model classes. It is passed to `add_start_docstrings` on the
# module classes below, so the literal has to be bound to this name rather than left as a bare string.
LONGT5_START_DOCSTRING = r"""
The LongT5 model was proposed in [LongT5: Efficient Text-To-Text Transformer for Long
Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo
Ni, Yun-Hsuan Sung and Yinfei Yang. It's an encoder-decoder transformer pre-trained in a text-to-text denoising
generative setting. LongT5 model is an extension of T5 model, and it enables using one of the two different
efficient attention mechanisms - (1) Local attention, or (2) Transient-Global attention.
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a Flax Linen
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`LongT5Config`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
[`~FlaxPreTrainedModel.to_bf16`].
"""
@add_start_docstrings(
    "The bare LONGT5 Model transformer outputting raw hidden-stateswithout any specific head on top.",
    LONGT5_START_DOCSTRING,
)
# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Module with T5->LongT5
class FlaxLongT5Module(nn.Module):
    """
    Flax module combining a shared embedding, an encoder stack and a decoder stack.

    Attributes:
        config (LongT5Config): Model configuration.
        dtype (jnp.dtype): Data type of the computation, `jnp.float32` by default.
        gradient_checkpointing (bool): Forwarded to both stacks, `False` by default.
    """

    config: LongT5Config
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False

    def _get_encoder_module(self):
        """Return the encoder stack."""
        return self.encoder

    def _get_decoder_module(self):
        """Return the decoder stack."""
        return self.decoder

    def setup(self):
        """Create the shared embedding and the encoder/decoder stacks."""
        # One embedding table, handed to both stacks.
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
            dtype=self.dtype,
        )

        # Each stack gets its own config copy so the flags below do not leak into `self.config`.
        enc_cfg = copy.deepcopy(self.config)
        enc_cfg.causal = False
        self.encoder = FlaxLongT5Stack(
            enc_cfg,
            embed_tokens=self.shared,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )

        dec_cfg = copy.deepcopy(self.config)
        dec_cfg.causal = True
        dec_cfg.num_layers = self.config.num_decoder_layers
        self.decoder = FlaxLongT5Stack(
            dec_cfg,
            embed_tokens=self.shared,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )

    def __call__(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        deterministic: bool = True,
    ):
        """Run the encoder, then the decoder on top of the encoder's first output.

        The `encoder_outputs` argument is overwritten by the result of the encoder call. Returns a
        `FlaxSeq2SeqModelOutput` when `return_dict` is true, otherwise the decoder outputs
        concatenated with the encoder outputs.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Flags shared by both stacks.
        shared_flags = {
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
            "deterministic": deterministic,
        }

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **shared_flags,
        )

        # The decoder attends over the encoder's first output, masked with the encoder mask.
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            **shared_flags,
        )

        if return_dict:
            return FlaxSeq2SeqModelOutput(
                last_hidden_state=decoder_outputs.last_hidden_state,
                past_key_values=decoder_outputs.past_key_values,
                decoder_hidden_states=decoder_outputs.hidden_states,
                decoder_attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
                encoder_hidden_states=encoder_outputs.hidden_states,
                encoder_attentions=encoder_outputs.attentions,
            )

        return decoder_outputs + encoder_outputs
# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Model with T5->LongT5
class FlaxLongT5Model(FlaxLongT5PreTrainedModel):
    """Model wrapper whose underlying Flax module is `FlaxLongT5Module`."""

    module_class = FlaxLongT5Module
# Append a call-sample docstring to FlaxLongT5Model, built from _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput
# and _CONFIG_FOR_DOC.
append_call_sample_docstring(FlaxLongT5Model, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
# Returns/Example section used for FlaxLongT5Model's call docstring.
FLAX_LONGT5_MODEL_DOCSTRING = """
Returns:
Example:
```
>>> from transformers import AutoTokenizer, FlaxLongT5Model
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
>>> model = FlaxLongT5Model.from_pretrained("google/long-t5-local-base")
>>> input_ids = tokenizer(
... "Studies have been shown that owning a dog is good for you", return_tensors="np"
... ).input_ids
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids
>>> # forward pass
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
>>> last_hidden_states = outputs.last_hidden_state
```
"""
# Overwrite FlaxLongT5Model's call docstring with LONGT5_INPUTS_DOCSTRING + FLAX_LONGT5_MODEL_DOCSTRING.
overwrite_call_docstring(FlaxLongT5Model, LONGT5_INPUTS_DOCSTRING + FLAX_LONGT5_MODEL_DOCSTRING)
# Fill in the return section of FlaxLongT5Model's docstring, with FlaxSeq2SeqLMOutput as the output type and
# _CONFIG_FOR_DOC as the config class.
append_replace_return_docstrings(FlaxLongT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@add_start_docstrings("""LONGT5 Model with a `language modeling` head on top.""", LONGT5_START_DOCSTRING)
# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5ForConditionalGenerationModule with T5->LongT5
class FlaxLongT5ForConditionalGenerationModule(nn.Module):
    """Encoder-decoder module with a bias-free dense head projecting onto the vocabulary."""

    config: LongT5Config
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False

    def _get_encoder_module(self):
        """Return the encoder stack."""
        return self.encoder

    def _get_decoder_module(self):
        """Return the decoder stack."""
        return self.decoder

    def setup(self):
        """Create the shared embedding, both stacks and the LM head."""
        cfg = self.config
        self.model_dim = cfg.d_model

        # One embedding table, handed to both stacks.
        self.shared = nn.Embed(
            cfg.vocab_size,
            cfg.d_model,
            embedding_init=jax.nn.initializers.normal(cfg.initializer_factor),
            dtype=self.dtype,
        )

        # Each stack gets its own config copy so the flags below do not leak into `self.config`.
        enc_cfg = copy.deepcopy(cfg)
        enc_cfg.causal = False
        enc_cfg.use_cache = False
        enc_cfg.is_encoder_decoder = False
        self.encoder = FlaxLongT5Stack(
            enc_cfg, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )

        dec_cfg = copy.deepcopy(cfg)
        dec_cfg.causal = True
        dec_cfg.is_encoder_decoder = False
        dec_cfg.num_layers = cfg.num_decoder_layers
        self.decoder = FlaxLongT5Stack(
            dec_cfg, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )

        self.lm_head = nn.Dense(
            cfg.vocab_size,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(cfg.initializer_factor),
            dtype=self.dtype,
        )

    def __call__(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        deterministic: bool = True,
    ):
        """Encode, decode and project the decoder output to vocabulary logits.

        The `encoder_outputs` argument is overwritten by the result of the encoder call. Returns a
        `FlaxSeq2SeqLMOutput` when `return_dict` is true, otherwise a tuple starting with the logits.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Flags shared by both stacks.
        shared_flags = {
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
            "deterministic": deterministic,
        }

        # Encode
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **shared_flags,
        )
        encoder_states = encoder_outputs[0]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_states,
            encoder_attention_mask=attention_mask,
            **shared_flags,
        )
        sequence_output = decoder_outputs[0]

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)
            # Reuse the transposed embedding table as the head's kernel.
            tied_kernel = self.shared.variables["params"]["embedding"].T
            lm_logits = self.lm_head.apply({"params": {"kernel": tied_kernel}}, sequence_output)
        else:
            lm_logits = self.lm_head(sequence_output)

        if return_dict:
            return FlaxSeq2SeqLMOutput(
                logits=lm_logits,
                past_key_values=decoder_outputs.past_key_values,
                decoder_hidden_states=decoder_outputs.hidden_states,
                decoder_attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
                encoder_hidden_states=encoder_outputs.hidden_states,
                encoder_attentions=encoder_outputs.attentions,
            )

        return (lm_logits,) + decoder_outputs[1:] + encoder_outputs
class FlaxLongT5ForConditionalGeneration(FlaxLongT5PreTrainedModel):
    """Model wrapper whose underlying Flax module is `FlaxLongT5ForConditionalGenerationModule`."""

    module_class = FlaxLongT5ForConditionalGenerationModule

    @add_start_docstrings(LONGT5_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=LongT5Config)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """
        Decode step: produce outputs from decoder inputs conditioned on the encoder outputs.

        Args:
            decoder_input_ids: Input ids for the decoder.
            encoder_outputs: Outputs of the encoder.
            encoder_attention_mask: Optional attention mask for the encoder.
            decoder_attention_mask: Optional attention mask for the decoder.
            past_key_values: Optional cached key/value states used to speed up generation.
            output_attentions: Optional flag, whether to return attention weights.
            output_hidden_states: Optional flag, whether to return hidden states.
            return_dict: Optional flag, whether to return the outputs in dict form.
            train: Whether the model runs in training mode.
            params: Optional model parameters.
            dropout_rng: Optional random key for dropout.

        Returns:
            The decoded outputs.
        """
        # NOTE(review): the body is elided in this listing; as written the method returns None.
        ...

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jax.Array] = None,
        decoder_attention_mask: Optional[jax.Array] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        """
        Prepare the keyword arguments for generation: initialise the cache and build the decoder mask.

        Args:
            decoder_input_ids: Input ids for the decoder; only its shape is read.
            max_length: Maximum generation length; sizes the cache and the decoder mask.
            attention_mask: Optional attention mask for the encoder.
            decoder_attention_mask: Optional attention mask for the decoder.
            encoder_outputs: Optional encoder outputs.
            **kwargs: Extra keyword arguments (ignored).

        Returns:
            A dict with `past_key_values`, `encoder_outputs`, `encoder_attention_mask` and
            `decoder_attention_mask`.
        """
        # Only the batch size is needed; the sequence length of the prompt is not used.
        batch_size, _ = decoder_input_ids.shape

        cache = self.init_cache(batch_size, max_length, encoder_outputs)

        # Start from an all-ones mask of cache length and overwrite its leading columns with the
        # caller's decoder mask, if one was given.
        full_decoder_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            full_decoder_mask = jax.lax.dynamic_update_slice(full_decoder_mask, decoder_attention_mask, (0, 0))

        return {
            "past_key_values": cache,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": full_decoder_mask,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """
        Carry the cache over to the next generation step.

        Args:
            model_outputs: Outputs of the model; `past_key_values` is read from it.
            model_kwargs: Keyword arguments of the model, updated in place.

        Returns:
            The same `model_kwargs` dict with `past_key_values` replaced.
        """
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        return model_kwargs
# Docstring fragment holding a `Returns:` section and a usage example for
# FlaxLongT5ForConditionalGeneration. It is concatenated after
# LONGT5_INPUTS_DOCSTRING and handed to `overwrite_call_docstring` further down.
# The text is a runtime string, so it is kept exactly as written.
FLAX_LONGT5_CONDITIONAL_GENERATION_DOCSTRING = """
Returns:
生成的结果。
Example:
```
>>> from transformers import AutoTokenizer, FlaxLongT5ForConditionalGeneration
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
>>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")
>>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs."
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors="np")
>>> # 生成摘要
>>> summary_ids = model.generate(inputs["input_ids"]).sequences
>>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False))
```
"""
# Hand the combined docstring (shared input documentation followed by the
# conditional-generation section) to `overwrite_call_docstring` for this class.
# NOTE(review): the helper is defined outside this excerpt; by its name it
# replaces the docstring of the model's `__call__` -- confirm against its definition.
overwrite_call_docstring(
    FlaxLongT5ForConditionalGeneration,
    LONGT5_INPUTS_DOCSTRING + FLAX_LONGT5_CONDITIONAL_GENERATION_DOCSTRING,
)
# Second, separate call: the class is passed positionally, `FlaxSeq2SeqLMOutput`
# as `output_type` and `_CONFIG_FOR_DOC` as `config_class`. The previous call
# must be closed first, otherwise this one is parsed as part of its argument
# list and the module fails to compile.
append_replace_return_docstrings(
    FlaxLongT5ForConditionalGeneration,
    output_type=FlaxSeq2SeqLMOutput,
    config_class=_CONFIG_FOR_DOC,
)