Transformers Source Code Analysis (44)
.\models\electra\modeling_flax_electra.py
from typing import Callable, Optional, Tuple
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxCausalLMOutputWithCrossAttentions,
FlaxMaskedLMOutput,
FlaxMultipleChoiceModelOutput,
FlaxQuestionAnsweringModelOutput,
FlaxSequenceClassifierOutput,
FlaxTokenClassifierOutput,
)
from ...modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_call_sample_docstring,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_electra import ElectraConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "google/electra-small-discriminator"
_CONFIG_FOR_DOC = "ElectraConfig"
remat = nn_partitioning.remat
@flax.struct.dataclass
class FlaxElectraForPreTrainingOutput(ModelOutput):
"""
    Output type of [`ElectraForPreTraining`].

    Args:
logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
    # Prediction scores of the head, shape (batch_size, sequence_length, config.vocab_size)
    logits: jnp.ndarray = None
    # Tuple of hidden states, returned when `output_hidden_states=True`;
    # each entry has shape (batch_size, sequence_length, hidden_size)
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # Tuple of attention weights, returned when `output_attentions=True`; each entry has shape
    # (batch_size, num_heads, sequence_length, sequence_length): the post-softmax attention
    # weights used to compute the weighted average in the self-attention heads
    attentions: Optional[Tuple[jnp.ndarray]] = None
# Docstring for the model classes below: describes inheritance from FlaxPreTrainedModel and lists the
# generic methods the library implements for all models (downloading, saving, converting weights from PyTorch)
ELECTRA_START_DOCSTRING = r"""
    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading, saving and converting weights from PyTorch models).
This model is also a Flax Linen
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`ElectraConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Docstring describing the inputs accepted by the model forward pass
ELECTRA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`numpy.ndarray` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in
            `[0, 1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range
            `[0, config.max_position_embeddings - 1]`.
        head_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class FlaxElectraEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    config: ElectraConfig
    # Holds the Electra model configuration (vocabulary size, embedding dimension, etc.)
    dtype: jnp.dtype = jnp.float32
    # Data type used for the computation, jnp.float32 by default

    def setup(self):
        # Initialize the components of the module
self.word_embeddings = nn.Embed(
self.config.vocab_size,
self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
        # Word embedding table sized (vocab_size, embedding_size)
self.position_embeddings = nn.Embed(
self.config.max_position_embeddings,
self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
        # Position embedding table sized (max_position_embeddings, embedding_size)
self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size,
self.config.embedding_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
        # Token type embedding table sized (type_vocab_size, embedding_size)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        # Layer normalization with the configured epsilon
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        # Dropout with the configured hidden dropout probability
    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
        # Forward pass: embed the inputs and combine the three embedding types
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        # Look up word embeddings for the input token ids
        position_embeds = self.position_embeddings(position_ids.astype("i4"))
        # Look up position embeddings for the position ids
        token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))
        # Look up token type embeddings for the token type ids
        hidden_states = inputs_embeds + token_type_embeddings + position_embeds
        # Sum the word, position and token type embeddings
        hidden_states = self.LayerNorm(hidden_states)
        # Apply layer normalization
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        # Apply dropout (disabled when deterministic=True)
        return hidden_states
"""
class FlaxElectraSelfAttention(nn.Module):
config: ElectraConfig
causal: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
if self.config.hidden_size % self.config.num_attention_heads != 0:
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of "
                f"`config.num_attention_heads`: {self.config.num_attention_heads}"
            )
self.query = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
if self.causal:
self.causal_mask = make_causal_mask(
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
)
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
@nn.compact
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slightly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
"""
is_initialized = self.has_variable("cache", "cached_key")
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
cur_index = cache_index.value
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
key = lax.dynamic_update_slice(cached_key.value, key, indices)
value = lax.dynamic_update_slice(cached_value.value, value, indices)
cached_key.value = key
cached_value.value = value
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
)
attention_mask = combine_masks(pad_mask, attention_mask)
return key, value, attention_mask
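The cache update above is easiest to follow on plain arrays. A minimal sketch with assumed dummy shapes (not the model code):

```python
import jax.numpy as jnp
from jax import lax

# One batch, a cache of max_length=4 positions, one head of size 2.
max_length = 4
cached_key = jnp.zeros((1, max_length, 1, 2))
new_key = jnp.ones((1, 1, 1, 2))  # the key projected from a single new token

cur_index = 2  # two positions already filled
cached_key = lax.dynamic_update_slice(cached_key, new_key, (0, cur_index, 0, 0))
print(cached_key[0, :, 0, 0])  # [0. 0. 1. 0.]: the new key lands at position 2

# The pad mask then allows attention only up to the freshly written position.
pad_mask = jnp.arange(max_length) < cur_index + 1
print(pad_mask)  # [ True  True  True False]
```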
class FlaxElectraSelfOutput(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class FlaxElectraAttention(nn.Module):
config: ElectraConfig
causal: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
        self.self = FlaxElectraSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
        self.output = FlaxElectraSelfOutput(self.config, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask,
layer_head_mask,
key_value_states=None,
init_cache=False,
deterministic=True,
output_attentions: bool = False,
):
attn_outputs = self.self(
hidden_states,
attention_mask,
layer_head_mask=layer_head_mask,
key_value_states=key_value_states,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
)
attn_output = attn_outputs[0]
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_outputs[1],)
return outputs
class FlaxElectraIntermediate(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.dense = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.hidden_act]
def __call__(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class FlaxElectraOutput(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.LayerNorm(hidden_states + attention_output)
return hidden_states
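`FlaxElectraIntermediate` and `FlaxElectraOutput` together form the standard transformer feed-forward block. With the electra-small sizes (hidden_size=256, intermediate_size=1024, stated here as an assumption about the checkpoint), the shape flow is:

```python
# (batch, seq, 256)  --Dense + activation-->  (batch, seq, 1024)   # FlaxElectraIntermediate
# (batch, seq, 1024) --Dense--> dropout --> +residual, LayerNorm   # FlaxElectraOutput
#                    --> (batch, seq, 256)
```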
class FlaxElectraLayer(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.attention = FlaxElectraAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
self.intermediate = FlaxElectraIntermediate(self.config, dtype=self.dtype)
self.output = FlaxElectraOutput(self.config, dtype=self.dtype)
if self.config.add_cross_attention:
self.crossattention = FlaxElectraAttention(self.config, causal=False, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
):
        # Self Attention
attention_outputs = self.attention(
hidden_states,
attention_mask,
layer_head_mask=layer_head_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
)
attention_output = attention_outputs[0]
if encoder_hidden_states is not None:
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask=encoder_attention_mask,
layer_head_mask=layer_head_mask,
key_value_states=encoder_hidden_states,
deterministic=deterministic,
output_attentions=output_attentions,
)
attention_output = cross_attention_outputs[0]
hidden_states = self.intermediate(attention_output)
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
outputs = (hidden_states,)
if output_attentions:
outputs += (attention_outputs[1],)
if encoder_hidden_states is not None:
outputs += (cross_attention_outputs[1],)
return outputs
class FlaxElectraLayerCollection(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
if self.gradient_checkpointing:
FlaxElectraCheckpointLayer = remat(FlaxElectraLayer, static_argnums=(5, 6, 7))
self.layers = [
FlaxElectraCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.num_hidden_layers)
]
else:
self.layers = [
FlaxElectraLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.num_hidden_layers)
]
def __call__(
self,
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
if head_mask is not None:
if head_mask.shape[0] != (len(self.layers)):
raise ValueError(
f"The head_mask should be specified for {len(self.layers)} layers, but it is for "
f" {head_mask.shape[0]}."
)
for i, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = layer(
hidden_states,
attention_mask,
head_mask[i] if head_mask is not None else None,
encoder_hidden_states,
encoder_attention_mask,
init_cache,
deterministic,
output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
cross_attentions=all_cross_attentions,
)
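The `remat(FlaxElectraLayer, static_argnums=(5, 6, 7))` call in `setup` above wraps each layer so that its intermediate activations are recomputed during the backward pass instead of being stored; arguments 5-7 of `FlaxElectraLayer.__call__` (`init_cache`, `deterministic`, `output_attentions`) are plain Python bools, hence marked static. A minimal function-level sketch of the same idea via `jax.checkpoint` (toy function and values, not the model code):

```python
import jax
import jax.numpy as jnp

def f(x, flag):
    # `flag` is a Python bool that controls tracing, like deterministic/output_attentions above.
    return jnp.sin(x) * x if flag else x

# static_argnums tells checkpointing not to trace through `flag`.
f_ckpt = jax.checkpoint(f, static_argnums=(1,))
print(jax.grad(f_ckpt)(1.0, True))  # activations of f are recomputed on the backward pass
```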
class FlaxElectraEncoder(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.layer = FlaxElectraLayerCollection(
self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
def __call__(
self,
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
return self.layer(
hidden_states,
attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
class FlaxElectraGeneratorPredictions(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dense = nn.Dense(self.config.embedding_size, dtype=self.dtype)
def __call__(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = ACT2FN[self.config.hidden_act](hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class FlaxElectraDiscriminatorPredictions(nn.Module):
"""用于鉴别器的预测模块,由两个密集层组成。"""
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
self.dense_prediction = nn.Dense(1, dtype=self.dtype)
def __call__(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = ACT2FN[self.config.hidden_act](hidden_states)
hidden_states = self.dense_prediction(hidden_states).squeeze(-1)
return hidden_states
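The head returns one unnormalized score per token. Downstream, replaced-token detection treats these as binary logits; a minimal sketch of that interpretation (assumed example values):

```python
import jax
import jax.numpy as jnp

token_logits = jnp.array([[-2.0, 0.5, 3.1]])  # (batch=1, seq=3), assumed values
replaced_prob = jax.nn.sigmoid(token_logits)  # probability each token was replaced
is_replaced = replaced_prob > 0.5             # [[False, True, True]]
```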
class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
"""
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
"""
config_class = ElectraConfig
base_model_prefix = "electra"
module_class: nn.Module = None
def __init__(
self,
config: ElectraConfig,
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
gradient_checkpointing: bool = False,
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def enable_gradient_checkpointing(self):
self._module = self.module_class(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=True,
)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
input_ids = jnp.zeros(input_shape, dtype="i4")
token_type_ids = jnp.zeros_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
attention_mask = jnp.ones_like(input_ids)
head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
if self.config.add_cross_attention:
encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
encoder_attention_mask = attention_mask
module_init_outputs = self.module.init(
rngs,
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
return_dict=False,
)
else:
module_init_outputs = self.module.init(
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
)
random_params = module_init_outputs["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def init_cache(self, batch_size, max_length):
r"""
Args:
batch_size (`int`):
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
max_length (`int`):
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
cache.
"""
input_ids = jnp.ones((batch_size, max_length), dtype="i4")
attention_mask = jnp.ones_like(input_ids, dtype="i4")
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
init_variables = self.module.init(
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
)
return unfreeze(init_variables["cache"])
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
past_key_values: dict = None,
class FlaxElectraModule(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.embeddings = FlaxElectraEmbeddings(self.config, dtype=self.dtype)
if self.config.embedding_size != self.config.hidden_size:
self.embeddings_project = nn.Dense(self.config.hidden_size, dtype=self.dtype)
self.encoder = FlaxElectraEncoder(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask: Optional[np.ndarray] = None,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
embeddings = self.embeddings(
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
)
if hasattr(self, "embeddings_project"):
embeddings = self.embeddings_project(embeddings)
return self.encoder(
embeddings,
attention_mask,
head_mask=head_mask,
deterministic=deterministic,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
@add_start_docstrings(
"The bare Electra Model transformer outputting raw hidden-states without any specific head on top.",
ELECTRA_START_DOCSTRING,
)
class FlaxElectraModel(FlaxElectraPreTrainedModel):
module_class = FlaxElectraModule
append_call_sample_docstring(FlaxElectraModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC)
class FlaxElectraTiedDense(nn.Module):
embedding_size: int
dtype: jnp.dtype = jnp.float32
precision = None
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
def setup(self):
self.bias = self.param("bias", self.bias_init, (self.embedding_size,))
def __call__(self, x, kernel):
x = jnp.asarray(x, self.dtype)
kernel = jnp.asarray(kernel, self.dtype)
y = lax.dot_general(
x,
kernel,
(((x.ndim - 1,), (0,)), ((), ())),
precision=self.precision,
)
bias = jnp.asarray(self.bias, self.dtype)
return y + bias
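`FlaxElectraTiedDense` has no kernel of its own; the caller passes the transposed word-embedding matrix, which ties the input and output embeddings. The effective computation, on assumed toy shapes:

```python
import jax.numpy as jnp

E = jnp.ones((100, 16))   # word-embedding matrix: (vocab_size, embedding_size), assumed toy values
x = jnp.ones((2, 8, 16))  # hidden states from the generator head: (batch, seq, embedding_size)

# The module computes x @ kernel + bias with kernel = E.T, i.e. vocabulary logits:
logits = x @ E.T          # (2, 8, 100)
```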
class FlaxElectraForMaskedLMModule(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.electra = FlaxElectraModule(
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
if self.config.tie_word_embeddings:
self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
else:
self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
outputs = self.electra(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
prediction_scores = self.generator_predictions(hidden_states)
if self.config.tie_word_embeddings:
shared_embedding = self.electra.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
prediction_scores = self.generator_lm_head(prediction_scores, shared_embedding.T)
else:
prediction_scores = self.generator_lm_head(prediction_scores)
if not return_dict:
return (prediction_scores,) + outputs[1:]
return FlaxMaskedLMOutput(
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings("""Electra Model with a `language modeling` head on top.""", ELECTRA_START_DOCSTRING)
class FlaxElectraForMaskedLM(FlaxElectraPreTrainedModel):
module_class = FlaxElectraForMaskedLMModule
append_call_sample_docstring(FlaxElectraForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC)
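A minimal usage sketch for the masked-LM head, mirroring the pretraining example further down in this file. The generator checkpoint named here is the natural fit for masked LM; the checkpoint name and its Flax weights are an assumption:

```python
from transformers import AutoTokenizer, FlaxElectraForMaskedLM

# Assumed checkpoint; the generator (not the discriminator) carries the LM head weights.
tokenizer = AutoTokenizer.from_pretrained("google/electra-small-generator")
model = FlaxElectraForMaskedLM.from_pretrained("google/electra-small-generator")

inputs = tokenizer("The capital of France is [MASK].", return_tensors="np")
logits = model(**inputs).logits  # (batch, seq, vocab_size)
```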
class FlaxElectraForPreTrainingModule(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.electra = FlaxElectraModule(
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.discriminator_predictions = FlaxElectraDiscriminatorPredictions(config=self.config, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
outputs = self.electra(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.discriminator_predictions(hidden_states)
if not return_dict:
return (logits,) + outputs[1:]
return FlaxElectraForPreTrainingOutput(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
Electra model with a binary classification head on top as used during pretraining for identifying generated tokens.
It is recommended to load the discriminator checkpoint into that model.
""",
ELECTRA_START_DOCSTRING,
)
class FlaxElectraForPreTraining(FlaxElectraPreTrainedModel):
module_class = FlaxElectraForPreTrainingModule
FLAX_ELECTRA_FOR_PRETRAINING_DOCSTRING = """
Returns:
Example:
    ```python
>>> from transformers import AutoTokenizer, FlaxElectraForPreTraining
>>> tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
>>> model = FlaxElectraForPreTraining.from_pretrained("google/electra-small-discriminator")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits
```
"""
overwrite_call_docstring(
FlaxElectraForPreTraining,
ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_ELECTRA_FOR_PRETRAINING_DOCSTRING,
)
append_replace_return_docstrings(
    FlaxElectraForPreTraining, output_type=FlaxElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
)
class FlaxElectraForTokenClassificationModule(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.electra = FlaxElectraModule(
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
classifier_dropout = (
self.config.classifier_dropout
if self.config.classifier_dropout is not None
else self.config.hidden_dropout_prob
)
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
outputs = self.electra(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
logits = self.classifier(hidden_states)
if not return_dict:
return (logits,) + outputs[1:]
return FlaxTokenClassifierOutput(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
Electra model with a token classification head on top.
Both the discriminator and generator may be loaded into this model.
""",
ELECTRA_START_DOCSTRING,
)
class FlaxElectraForTokenClassification(FlaxElectraPreTrainedModel):
module_class = FlaxElectraForTokenClassificationModule
append_call_sample_docstring(
FlaxElectraForTokenClassification,
_CHECKPOINT_FOR_DOC,
FlaxTokenClassifierOutput,
_CONFIG_FOR_DOC,
)
def identity(x, **kwargs):
return x
class FlaxElectraSequenceSummary(nn.Module):
r"""
Compute a single vector summary of a sequence hidden states.
"""
Args:
config ([`PretrainedConfig`]):
The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
config class of your model for the default values it uses):
- **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
- **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
(otherwise to `config.hidden_size`).
- **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
another string or `None` will add no activation.
- **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
- **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
"""
    # The model configuration
    config: ElectraConfig
    # Data type used for the computation, jnp.float32 by default
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Default the summary to the identity function
        self.summary = identity
        # If the config asks for a projection after the vector extraction
        if hasattr(self.config, "summary_use_proj") and self.config.summary_use_proj:
            # Project to num_labels classes when summary_proj_to_labels is set and num_labels > 0,
            # otherwise project back to hidden_size
            if (
                hasattr(self.config, "summary_proj_to_labels")
                and self.config.summary_proj_to_labels
                and self.config.num_labels > 0
            ):
                num_classes = self.config.num_labels
            else:
                num_classes = self.config.hidden_size
            # Replace the identity summary with a dense projection
            self.summary = nn.Dense(num_classes, dtype=self.dtype)

        # Resolve the activation from the summary_activation config string;
        # fall back to the identity when no activation is configured
        activation_string = getattr(self.config, "summary_activation", None)
        self.activation = ACT2FN[activation_string] if activation_string else lambda x: x  # noqa F407

        # Optional dropout before the projection and activation
        self.first_dropout = identity
        if hasattr(self.config, "summary_first_dropout") and self.config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(self.config.summary_first_dropout)

        # Optional dropout after the projection and activation
        self.last_dropout = identity
        if hasattr(self.config, "summary_last_dropout") and self.config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(self.config.summary_last_dropout)
def __call__(self, hidden_states, cls_index=None, deterministic: bool = True):
"""
Compute a single vector summary of a sequence hidden states.
Args:
hidden_states (`jnp.ndarray` of shape `[batch_size, seq_len, hidden_size]`):
The hidden states of the last layer.
cls_index (`jnp.ndarray` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
Returns:
`jnp.ndarray`: The summary of the sequence hidden states.
"""
# NOTE: This function computes a summary vector of the sequence hidden states.
# Extract the first token's hidden state from each sequence in the batch
output = hidden_states[:, 0]
# Apply dropout to the extracted hidden state
output = self.first_dropout(output, deterministic=deterministic)
# Compute the summary vector using a predefined method
output = self.summary(output)
# Apply an activation function to the computed summary vector
output = self.activation(output)
# Apply dropout to the final output vector before returning
output = self.last_dropout(output, deterministic=deterministic)
# Return the final summary vector
return output
# Flax module for the Electra multiple-choice model
class FlaxElectraForMultipleChoiceModule(nn.Module):
    # Electra model configuration
    config: ElectraConfig
    # Computation dtype, jnp.float32 by default
    dtype: jnp.dtype = jnp.float32
    # Gradient checkpointing, disabled by default
    gradient_checkpointing: bool = False

    def setup(self):
        # Backbone Electra module
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        # Sequence summary that pools the hidden states into one vector
        self.sequence_summary = FlaxElectraSequenceSummary(config=self.config, dtype=self.dtype)
        # Dense classifier producing a single score per choice
        self.classifier = nn.Dense(1, dtype=self.dtype)

    # Forward pass: flatten the choices, run the backbone, pool and score
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
        # Number of answer choices per example
        num_choices = input_ids.shape[1]
        # Flatten (batch, num_choices, seq_len) inputs to (batch * num_choices, seq_len)
        input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
        attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
        token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
        position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None
        # Run the Electra backbone
outputs = self.electra(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
        # Extract the last hidden states
        hidden_states = outputs[0]
        # Pool the hidden states into one vector per (example, choice) pair
        pooled_output = self.sequence_summary(hidden_states, deterministic=deterministic)
        # Score each choice with the classifier
        logits = self.classifier(pooled_output)
        # Reshape the flat logits back to (batch, num_choices)
        reshaped_logits = logits.reshape(-1, num_choices)
        # Return a plain tuple when return_dict=False
        if not return_dict:
            return (reshaped_logits,) + outputs[1:]
        # Otherwise return the structured multiple-choice output with logits, hidden states and attentions
return FlaxMultipleChoiceModelOutput(
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
# Add docstrings describing the multiple-choice model
@add_start_docstrings(
"""
ELECTRA Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
softmax) e.g. for RocStories/SWAG tasks.
""",
ELECTRA_START_DOCSTRING,
)
class FlaxElectraForMultipleChoice(FlaxElectraPreTrainedModel):
module_class = FlaxElectraForMultipleChoiceModule
# Overwrite the call docstring to use the multiple-choice input shapes
overwrite_call_docstring(
FlaxElectraForMultipleChoice, ELECTRA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
# Append a sample call docstring
append_call_sample_docstring(
FlaxElectraForMultipleChoice,
_CHECKPOINT_FOR_DOC,
FlaxMultipleChoiceModelOutput,
_CONFIG_FOR_DOC,
)
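The reshaping in the module above is the crux of multiple choice: choices are folded into the batch for the backbone, then unfolded for the final scores. On assumed toy shapes:

```python
import jax.numpy as jnp

batch, num_choices, seq_len = 2, 4, 8
input_ids = jnp.ones((batch, num_choices, seq_len), dtype="i4")

flat_ids = input_ids.reshape(-1, input_ids.shape[-1])  # (8, 8): choices folded into the batch
per_choice = jnp.ones((batch * num_choices, 1))        # classifier output, one score per choice
scores = per_choice.reshape(-1, num_choices)           # (2, 4): back to one row per example
```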
# Flax module for the Electra question-answering model
class FlaxElectraForQuestionAnsweringModule(nn.Module):
    # Electra model configuration
    config: ElectraConfig
    # Computation dtype, jnp.float32 by default
    dtype: jnp.dtype = jnp.float32
    # Gradient checkpointing, disabled by default
    gradient_checkpointing: bool = False

    def setup(self):
        # Backbone Electra module built with the configured dtype and checkpointing
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        # Span head projecting hidden states to num_labels (start/end) logits
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

    # Forward pass: run the backbone and compute span start/end logits
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
        # Run the Electra backbone
outputs = self.electra(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
        # Extract the last hidden states
        hidden_states = outputs[0]
        # Compute the combined start/end logits with the span head
        logits = self.qa_outputs(hidden_states)
        # Split the logits into start and end logits along the label dimension
        start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
        # Drop the trailing singleton dimension
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        # Return a plain tuple when return_dict=False
        if not return_dict:
            return (start_logits, end_logits) + outputs[1:]
        # Otherwise return the structured question-answering output with start/end logits,
        # hidden states and attentions
return FlaxQuestionAnsweringModelOutput(
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
    ELECTRA Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
ELECTRA_START_DOCSTRING,
)
class FlaxElectraForQuestionAnswering(FlaxElectraPreTrainedModel):
module_class = FlaxElectraForQuestionAnsweringModule
append_call_sample_docstring(
FlaxElectraForQuestionAnswering,
_CHECKPOINT_FOR_DOC,
FlaxQuestionAnsweringModelOutput,
_CONFIG_FOR_DOC,
)
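The span head emits `num_labels=2` channels per token, which `split` separates into start and end logits. A minimal shape sketch (dummy values):

```python
import jax.numpy as jnp

logits = jnp.zeros((2, 8, 2))                      # (batch, seq, 2) from qa_outputs
start_logits, end_logits = logits.split(2, axis=-1)
start_logits = start_logits.squeeze(-1)            # (2, 8)
end_logits = end_logits.squeeze(-1)                # (2, 8)
```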
class FlaxElectraClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
# Initialize a fully connected layer with hidden_size neurons
self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
# Determine dropout rate based on config values
classifier_dropout = (
self.config.classifier_dropout
if self.config.classifier_dropout is not None
else self.config.hidden_dropout_prob
)
# Apply dropout with computed rate
self.dropout = nn.Dropout(classifier_dropout)
# Final output layer with num_labels neurons
self.out_proj = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(self, hidden_states, deterministic: bool = True):
        # Take the representation of the first token (equivalent to [CLS])
x = hidden_states[:, 0, :]
# Apply dropout to the extracted token representation
x = self.dropout(x, deterministic=deterministic)
# Pass through the fully connected layer
x = self.dense(x)
        # Apply GELU (BERT uses tanh here; the ELECTRA authors used gelu)
x = ACT2FN["gelu"](x)
# Apply dropout again
x = self.dropout(x, deterministic=deterministic)
# Pass through the output layer
x = self.out_proj(x)
# Return the logits for sequence classification
return x
class FlaxElectraForSequenceClassificationModule(nn.Module):
config: ElectraConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
# Initialize Electra module with specified configuration
self.electra = FlaxElectraModule(
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
# Initialize classification head using the same configuration
self.classifier = FlaxElectraClassificationHead(config=self.config, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
    ):
        # Run the Electra backbone
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        # Classify from the pooled first-token representation
        logits = self.classifier(hidden_states, deterministic=deterministic)
        # When return_dict=False, return a tuple of logits plus the remaining outputs
        if not return_dict:
            return (logits,) + outputs[1:]
        # Otherwise return a FlaxSequenceClassifierOutput with logits, hidden_states and attentions
return FlaxSequenceClassifierOutput(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
Electra Model transformer with a sequence classification/regression head on top (a linear layer on top of the
pooled output) e.g. for GLUE tasks.
""",
ELECTRA_START_DOCSTRING,
)
class FlaxElectraForSequenceClassification(FlaxElectraPreTrainedModel):
module_class = FlaxElectraForSequenceClassificationModule
append_call_sample_docstring(
FlaxElectraForSequenceClassification,
_CHECKPOINT_FOR_DOC,
FlaxSequenceClassifierOutput,
_CONFIG_FOR_DOC,
)
# Flax module for the Electra causal language model
class FlaxElectraForCausalLMModule(nn.Module):
    # Electra model configuration
    config: ElectraConfig
    # Computation dtype, jnp.float32 by default
    dtype: jnp.dtype = jnp.float32
    # Gradient checkpointing, disabled by default
    gradient_checkpointing: bool = False

    def setup(self):
        # Backbone Electra module built from the config, dtype and checkpointing settings
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        # Generator prediction head
        self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
        # Tie the LM head to the word embeddings when configured, otherwise use a plain Dense layer
        if self.config.tie_word_embeddings:
            self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
        else:
            self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype)

    # Forward pass of the causal language model
def __call__(
self,
input_ids,
attention_mask: Optional[jnp.ndarray] = None,
token_type_ids: Optional[jnp.ndarray] = None,
position_ids: Optional[jnp.ndarray] = None,
head_mask: Optional[jnp.ndarray] = None,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
        # Run the ELECTRA backbone and collect its outputs
outputs = self.electra(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
        # Extract the last hidden states
        hidden_states = outputs[0]
        # Compute prediction scores with the generator head
        prediction_scores = self.generator_predictions(hidden_states)
        if self.config.tie_word_embeddings:
            # Reuse the shared word-embedding matrix as the LM head kernel
            shared_embedding = self.electra.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
            prediction_scores = self.generator_lm_head(prediction_scores, shared_embedding.T)
        else:
            # Otherwise use the standalone dense LM head
            prediction_scores = self.generator_lm_head(prediction_scores)
        # When return_dict=False, return a tuple of prediction scores plus the remaining outputs
        if not return_dict:
            return (prediction_scores,) + outputs[1:]
        # Otherwise return a structured output with logits, hidden states, attentions
        # and cross-attentions
return FlaxCausalLMOutputWithCrossAttentions(
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
@add_start_docstrings(
"""
Electra Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for
autoregressive tasks.
""",
ELECTRA_START_DOCSTRING,
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForCausalLM with Bert->Electra
class FlaxElectraForCausalLM(FlaxElectraPreTrainedModel):
module_class = FlaxElectraForCausalLMModule
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        batch_size, seq_length = input_ids.shape
        # Initialize the cache of past key/value states for fast auto-regressive decoding
        past_key_values = self.init_cache(batch_size, max_length)
        # Note: positions in attention_mask beyond input_ids.shape[-1] and below cache_length
        # would usually need to be set to 0, but since the decoder uses a causal mask those
        # positions are already masked. We can therefore create a single static attention_mask
        # here, which is more efficient for compilation.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            # Derive position ids as the cumulative sum of the attention mask, minus one
            position_ids = attention_mask.cumsum(axis=-1) - 1
            # Write the provided attention_mask into the static extended mask
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            # Without an attention_mask, broadcast sequential position ids
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
return {
"past_key_values": past_key_values,
"attention_mask": extended_attention_mask,
"position_ids": position_ids,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
        # Update the model kwargs for the next generation step
model_kwargs["past_key_values"] = model_outputs.past_key_values
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
return model_kwargs
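The cumulative-sum trick in `prepare_inputs_for_generation` assigns positions only to real tokens. A minimal sketch with an assumed left-padded mask:

```python
import jax.numpy as jnp

attention_mask = jnp.array([[0, 0, 1, 1, 1]])       # two left-padding tokens, assumed
position_ids = attention_mask.cumsum(axis=-1) - 1   # [[-1, -1, 0, 1, 2]]
# Padding positions get -1, but they are never attended to, so the value is harmless.
```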
# Append a sample call docstring to FlaxElectraForCausalLM
append_call_sample_docstring(
FlaxElectraForCausalLM,
_CHECKPOINT_FOR_DOC,
FlaxCausalLMOutputWithCrossAttentions,
_CONFIG_FOR_DOC,
)
.\models\electra\modeling_tf_electra.py
""" TF Electra model."""
from __future__ import annotations
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import (
TFBaseModelOutputWithPastAndCrossAttentions,
TFMaskedLMOutput,
TFMultipleChoiceModelOutput,
TFQuestionAnsweringModelOutput,
TFSequenceClassifierOutput,
TFTokenClassifierOutput,
)
from ...modeling_tf_utils import (
TFMaskedLanguageModelingLoss,
TFModelInputType,
TFMultipleChoiceLoss,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
TFSequenceClassificationLoss,
TFSequenceSummary,
TFTokenClassificationLoss,
get_initializer,
keras,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import (
check_embeddings_within_bounds,
shape_list,
stable_softmax,
)
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_electra import ElectraConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "google/electra-small-discriminator"
_CONFIG_FOR_DOC = "ElectraConfig"
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = [
"google/electra-small-generator",
"google/electra-base-generator",
"google/electra-large-generator",
"google/electra-small-discriminator",
"google/electra-base-discriminator",
"google/electra-large-discriminator",
]
class TFElectraSelfAttention(keras.layers.Layer):
def __init__(self, config: ElectraConfig, **kwargs):
super().__init__(**kwargs)
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number "
f"of attention heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
self.query = keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
)
self.key = keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
)
self.value = keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
)
self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
self.is_decoder = config.is_decoder
self.config = config
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
return tf.transpose(tensor, perm=[0, 2, 1, 3])
def call(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor,
head_mask: tf.Tensor,
encoder_hidden_states: tf.Tensor,
encoder_attention_mask: tf.Tensor,
past_key_value: Tuple[tf.Tensor],
output_attentions: bool,
training: bool = False,
):
pass
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
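`transpose_for_scores` rearranges the fused projection into per-head tensors so attention can run per head. A shape sketch with assumed toy sizes (3 heads of size 4):

```python
import tensorflow as tf

x = tf.ones((2, 8, 12))                 # (batch, seq, all_head_size)
x = tf.reshape(x, (2, -1, 3, 4))        # (batch, seq, num_heads, head_size)
x = tf.transpose(x, perm=[0, 2, 1, 3])  # (batch, num_heads, seq, head_size)
print(x.shape)                          # (2, 3, 8, 4)
```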
class TFElectraSelfOutput(keras.layers.Layer):
def __init__(self, config: ElectraConfig, **kwargs):
super().__init__(**kwargs)
self.dense = keras.layers.Dense(
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.dropout(inputs=hidden_states, training=training)
hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFElectraAttention(keras.layers.Layer):
def __init__(self, config: ElectraConfig, **kwargs):
super().__init__(**kwargs)
self.self_attention = TFElectraSelfAttention(config, name="self")
self.dense_output = TFElectraSelfOutput(config, name="output")
def prune_heads(self, heads):
raise NotImplementedError
def call(
self,
input_tensor: tf.Tensor,
attention_mask: tf.Tensor,
head_mask: tf.Tensor,
encoder_hidden_states: tf.Tensor,
encoder_attention_mask: tf.Tensor,
past_key_value: Tuple[tf.Tensor],
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
self_outputs = self.self_attention(
hidden_states=input_tensor,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
training=training,
)
attention_output = self.dense_output(
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
)
outputs = (attention_output,) + self_outputs[1:]
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attention", None) is not None:
with tf.name_scope(self.self_attention.name):
self.self_attention.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
class TFElectraIntermediate(keras.layers.Layer):
def __init__(self, config: ElectraConfig, **kwargs):
super().__init__(**kwargs)
self.dense = keras.layers.Dense(
units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFElectraOutput(keras.layers.Layer):
def __init__(self, config: ElectraConfig, **kwargs):
super().__init__(**kwargs)
self.dense = keras.layers.Dense(
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.dropout(inputs=hidden_states, training=training)
hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFElectraLayer(keras.layers.Layer):
def __init__(self, config: ElectraConfig, **kwargs):
super().__init__(**kwargs)
self.attention = TFElectraAttention(config, name="attention")
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = TFElectraAttention(config, name="crossattention")
self.intermediate = TFElectraIntermediate(config, name="intermediate")
self.bert_output = TFElectraOutput(config, name="output")
def call(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor,
head_mask: tf.Tensor,
encoder_hidden_states: tf.Tensor | None,
encoder_attention_mask: tf.Tensor | None,
past_key_value: Tuple[tf.Tensor] | None,
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
input_tensor=hidden_states,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=self_attn_past_key_value,
output_attentions=output_attentions,
training=training,
)
attention_output = self_attention_outputs[0]
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:]
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
" by setting `config.add_cross_attention=True`"
)
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
input_tensor=attention_output,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
training=training,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1]
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
intermediate_output = self.intermediate(hidden_states=attention_output)
layer_output = self.bert_output(
hidden_states=intermediate_output, input_tensor=attention_output, training=training
)
outputs = (layer_output,) + outputs
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "bert_output", None) is not None:
with tf.name_scope(self.bert_output.name):
self.bert_output.build(None)
if getattr(self, "crossattention", None) is not None:
with tf.name_scope(self.crossattention.name):
self.crossattention.build(None)
class TFElectraEncoder(keras.layers.Layer):
def __init__(self, config: ElectraConfig, **kwargs):
super().__init__(**kwargs)
self.config = config
self.layer = [TFElectraLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
def call(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor,
head_mask: tf.Tensor,
encoder_hidden_states: tf.Tensor | None,
encoder_attention_mask: tf.Tensor | None,
past_key_values: Tuple[Tuple[tf.Tensor]] | None,
use_cache: Optional[bool],
output_attentions: bool,
output_hidden_states: bool,
return_dict: bool,
training: bool = False,
) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module(
hidden_states=hidden_states,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
training=training,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if self.config.add_cross_attention and encoder_hidden_states is not None:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
)
return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_attentions,
cross_attentions=all_cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
class TFElectraPooler(keras.layers.Layer):
def __init__(self, config: ElectraConfig, **kwargs):
super().__init__(**kwargs)
self.dense = keras.layers.Dense(
units=config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
activation="tanh",
name="dense",
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(inputs=first_token_tensor)
return pooled_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFElectraEmbeddings(keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config: ElectraConfig, **kwargs):
super().__init__(**kwargs)
self.config = config
self.embedding_size = config.embedding_size
self.max_position_embeddings = config.max_position_embeddings
self.initializer_range = config.initializer_range
self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
shape=[self.config.vocab_size, self.embedding_size],
initializer=get_initializer(self.initializer_range),
)
with tf.name_scope("token_type_embeddings"):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.config.type_vocab_size, self.embedding_size],
initializer=get_initializer(self.initializer_range),
)
with tf.name_scope("position_embeddings"):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.embedding_size],
initializer=get_initializer(self.initializer_range),
)
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.embedding_size])
def call(
self,
input_ids: tf.Tensor = None,
position_ids: tf.Tensor = None,
token_type_ids: tf.Tensor = None,
inputs_embeds: tf.Tensor = None,
past_key_values_length=0,
training: bool = False,
) -> tf.Tensor:
"""
Applies embedding based on inputs tensor.
Returns:
final_embeddings (`tf.Tensor`): output embedding tensor.
"""
if input_ids is None and inputs_embeds is None:
raise ValueError("Need to provide either `input_ids` or `input_embeds`.")
if input_ids is not None:
check_embeddings_within_bounds(input_ids, self.config.vocab_size)
inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
input_shape = shape_list(inputs_embeds)[:-1]
if token_type_ids is None:
token_type_ids = tf.fill(dims=input_shape, value=0)
if position_ids is None:
position_ids = tf.expand_dims(
tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
)
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
final_embeddings = inputs_embeds + position_embeds + token_type_embeds
final_embeddings = self.LayerNorm(inputs=final_embeddings)
final_embeddings = self.dropout(inputs=final_embeddings, training=training)
return final_embeddings
class TFElectraDiscriminatorPredictions(keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = keras.layers.Dense(config.hidden_size, name="dense")
self.dense_prediction = keras.layers.Dense(1, name="dense_prediction")
self.config = config
def call(self, discriminator_hidden_states, training=False):
hidden_states = self.dense(discriminator_hidden_states)
hidden_states = get_tf_activation(self.config.hidden_act)(hidden_states)
logits = tf.squeeze(self.dense_prediction(hidden_states), -1)
return logits
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "dense_prediction", None) is not None:
with tf.name_scope(self.dense_prediction.name):
self.dense_prediction.build([None, None, self.config.hidden_size])
class TFElectraGeneratorPredictions(keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dense = keras.layers.Dense(config.embedding_size, name="dense")
self.config = config
def call(self, generator_hidden_states, training=False):
hidden_states = self.dense(generator_hidden_states)
hidden_states = get_tf_activation("gelu")(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.embedding_size])
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFElectraPreTrainedModel(TFPreTrainedModel):
"""
一个抽象类,用于处理权重初始化以及下载和加载预训练模型的简单接口。
"""
config_class = ElectraConfig
base_model_prefix = "electra"
_keys_to_ignore_on_load_unexpected = [r"generator_lm_head.weight"]
_keys_to_ignore_on_load_missing = [r"dropout"]
@keras_serializable
class TFElectraMainLayer(keras.layers.Layer):
config_class = ElectraConfig
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.config = config
self.is_decoder = config.is_decoder
self.embeddings = TFElectraEmbeddings(config, name="embeddings")
if config.embedding_size != config.hidden_size:
self.embeddings_project = keras.layers.Dense(config.hidden_size, name="embeddings_project")
self.encoder = TFElectraEncoder(config, name="encoder")
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, value):
self.embeddings.weight = value
self.embeddings.vocab_size = shape_list(value)[0]
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
raise NotImplementedError
def get_extended_attention_mask(self, attention_mask, input_shape, dtype, past_key_values_length=0):
batch_size, seq_length = input_shape
if attention_mask is None:
attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
attention_mask_shape = shape_list(attention_mask)
mask_seq_length = seq_length + past_key_values_length
if self.is_decoder:
seq_ids = tf.range(mask_seq_length)
causal_mask = tf.less_equal(
tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
seq_ids[None, :, None],
)
causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
extended_attention_mask = causal_mask * attention_mask[:, None, :]
attention_mask_shape = shape_list(extended_attention_mask)
extended_attention_mask = tf.reshape(
extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
)
if past_key_values_length > 0:
extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
else:
extended_attention_mask = tf.reshape(
attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
)
extended_attention_mask = tf.cast(extended_attention_mask, dtype=dtype)
one_cst = tf.constant(1.0, dtype=dtype)
ten_thousand_cst = tf.constant(-10000.0, dtype=dtype)
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
return extended_attention_mask
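To make the additive-mask trick above concrete, here is a minimal standalone sketch (plain TensorFlow, not part of the model): positions with mask value 1 are kept, positions with 0 receive a `-10000` bias and end up with essentially zero weight after the softmax.

```python
import tensorflow as tf

# Hedged, illustrative example: batch of 1, seq_length 4, last position is padding.
attention_mask = tf.constant([[1.0, 1.0, 1.0, 0.0]])

# Broadcast to (batch, 1, 1, seq_len) and turn 0s into a large negative bias,
# mirroring the (1 - mask) * -10000 computation above.
extended = tf.reshape(attention_mask, (1, 1, 1, 4))
extended = (1.0 - extended) * -10000.0

scores = tf.random.normal((1, 1, 4, 4))            # dummy attention scores
probs = tf.nn.softmax(scores + extended, axis=-1)  # padded column gets ~0 weight
print(probs[0, 0, 0].numpy())                      # last entry is effectively 0
```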
def get_head_mask(self, head_mask):
if head_mask is not None:
raise NotImplementedError
else:
head_mask = [None] * self.config.num_hidden_layers
return head_mask
def call(
self,
input_ids: TFModelInputType | None = None,
attention_mask: np.ndarray | tf.Tensor | None = None,
token_type_ids: np.ndarray | tf.Tensor | None = None,
position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
inputs_embeds: np.ndarray | tf.Tensor | None = None,
encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
):
# NOTE: the body of `call` is not shown in this excerpt; the lines that follow
# belong to `build`, whose `def` line was dropped during extraction.
...
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "embeddings_project", None) is not None:
with tf.name_scope(self.embeddings_project.name):
self.embeddings_project.build([None, None, self.config.embedding_size])
@dataclass
class TFElectraForPreTrainingOutput(ModelOutput):
"""
[`TFElectraForPreTraining`]的输出类型。
Args:
loss (*可选*, 当提供 `labels` 时返回, `tf.Tensor` 形状为 `(1,)`):
ELECTRA 目标的总损失。
logits (`tf.Tensor` 形状为 `(batch_size, sequence_length)`):
头部的预测分数(SoftMax 前每个标记的分数)。
hidden_states (`tuple(tf.Tensor)`, *可选*, 当 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回):
元组的 `tf.Tensor`(一个用于嵌入输出 + 每个层的输出)形状为 `(batch_size, sequence_length, hidden_size)`。
模型在每层输出的隐藏状态以及初始嵌入输出。
attentions (`tuple(tf.Tensor)`, *可选*, 当 `output_attentions=True` 或 `config.output_attentions=True` 时返回):
元组的 `tf.Tensor`(每个层一个)形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
经过注意力 softmax 后的注意力权重,用于计算自注意力头部的加权平均。
"""
logits: tf.Tensor = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None
ELECTRA_START_DOCSTRING = r"""
此模型继承自 [`TFPreTrainedModel`]。查看超类文档以获取库实现的所有模型的通用方法(如下载或保存、调整输入嵌入、修剪头部等)。
此模型还是 [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) 的子类。将其视为常规的 TF 2.0 Keras 模型,并参考 TF 2.0 文档,了解有关一般用法和行为的所有内容。
<Tip>
`transformers` 中的 TensorFlow 模型和层接受两种输入格式:
- 将所有输入作为关键字参数(类似于 PyTorch 模型);
- 将所有输入作为列表、元组或字典的第一个位置参数。
支持第二种格式的原因是,当传递输入给模型和层时,Keras 方法更喜欢此格式。由于这种支持,在使用诸如 `model.fit()` 等方法时,您应该能够“只需传递”您的输入和标签 - 只需使用 `model.fit()` 支持的任何格式!但是,如果您想在 Keras 方法如 `fit()` 和 `predict()` 之外使用第二种格式,比如在使用 Keras `Functional` API 创建自己的层或模型时,有三种可能性可以用于在第一个位置参数中收集所有输入张量:
- 只有 `input_ids` 的单个张量:`model(input_ids)`
- 可变长度列表,其中按文档字符串中给出的顺序包含一个或多个输入张量:
"""
`model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
- 当使用模型对象 `model` 时,可以传入一个包含输入张量的字典,键名需与文档字符串中给出的输入名称对应:
`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
Note that when creating models and layers with
[subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
about any of this, as you can just pass inputs like you would to any other Python function!
- 当使用子类化创建模型和层时,您无需担心这些细节,可以像传递任何其他 Python 函数的输入一样操作!
Parameters:
config ([`ElectraConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
- config ([`ElectraConfig`]): 包含模型所有参数的配置类。
使用配置文件初始化模型时,并不会加载与模型关联的权重,只加载配置信息。
查看 [`~PreTrainedModel.from_pretrained`] 方法以加载模型的权重。
"""
ELECTRA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
[`PreTrainedTokenizer.encode`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
config will be used instead.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
used instead.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
eager mode, in graph mode the value will always be set to True.
training (`bool`, *optional*, defaults to `False`):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).
"""
@add_start_docstrings(
"The bare Electra Model transformer outputting raw hidden-states without any specific head on top. Identical to "
"the BERT model except that it uses an additional linear layer between the embedding layer and the encoder if the "
"hidden size and embedding size are different. "
""
"Both the generator and discriminator checkpoints may be loaded into this model.",
ELECTRA_START_DOCSTRING,
)
class TFElectraModel(TFElectraPreTrainedModel):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
# Shared Electra main layer.
self.electra = TFElectraMainLayer(config, name="electra")
@unpack_inputs
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFBaseModelOutputWithPastAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids: TFModelInputType | None = None,
attention_mask: np.ndarray | tf.Tensor | None = None,
token_type_ids: np.ndarray | tf.Tensor | None = None,
position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
inputs_embeds: np.ndarray | tf.Tensor | None = None,
encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
r"""
encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`)
contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*, defaults to `True`):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`). Set to `False` during training, `True` during generation
"""
# Forward all inputs through the Electra main layer.
outputs = self.electra(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
# Return the main layer's outputs as-is.
return outputs
def build(self, input_shape=None):
# Skip if the model has already been built.
if self.built:
return
self.built = True
# Build the Electra main layer under its own name scope (input shape None uses the default).
if getattr(self, "electra", None) is not None:
with tf.name_scope(self.electra.name):
self.electra.build(None)
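A minimal usage sketch for the bare model; it assumes `transformers` and TensorFlow are installed, and reuses the `google/electra-small-discriminator` checkpoint named in `_CHECKPOINT_FOR_DOC`:

```python
from transformers import AutoTokenizer, TFElectraModel

tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
model = TFElectraModel.from_pretrained("google/electra-small-discriminator")

inputs = tokenizer("ELECTRA is sample-efficient.", return_tensors="tf")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, seq_len, hidden_size)
```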
# The decorator below attaches the class description and the shared Electra docstring.
@add_start_docstrings(
"""
Electra model with a binary classification head on top as used during pretraining for identifying generated tokens.
Even though both the discriminator and generator may be loaded into this model, the discriminator is the only model
of the two to have the correct classification head to be used for this model.
""",
ELECTRA_START_DOCSTRING,
)
class TFElectraForPreTraining(TFElectraPreTrainedModel):
def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
# Shared Electra main layer.
self.electra = TFElectraMainLayer(config, name="electra")
# Discriminator prediction head used during pretraining.
self.discriminator_predictions = TFElectraDiscriminatorPredictions(config, name="discriminator_predictions")
@unpack_inputs
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
input_ids: TFModelInputType | None = None,
attention_mask: np.ndarray | tf.Tensor | None = None,
token_type_ids: np.ndarray | tf.Tensor | None = None,
position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
inputs_embeds: np.ndarray | tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
) -> Union[TFElectraForPreTrainingOutput, Tuple[tf.Tensor]]:
discriminator_hidden_states = self.electra(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
# Take the sequence output (first element) of the discriminator's hidden states.
discriminator_sequence_output = discriminator_hidden_states[0]
# Score every token with the discriminator prediction head.
logits = self.discriminator_predictions(discriminator_sequence_output)
# With return_dict=False, return a plain tuple of logits plus the remaining hidden states.
if not return_dict:
return (logits,) + discriminator_hidden_states[1:]
# Otherwise wrap logits, hidden states and attentions in a TFElectraForPreTrainingOutput.
return TFElectraForPreTrainingOutput(
logits=logits,
hidden_states=discriminator_hidden_states.hidden_states,
attentions=discriminator_hidden_states.attentions,
)
def build(self, input_shape=None):
# Skip if already built.
if self.built:
return
self.built = True
# Build the Electra main layer and the discriminator head under their name scopes.
if getattr(self, "electra", None) is not None:
with tf.name_scope(self.electra.name):
self.electra.build(None)
if getattr(self, "discriminator_predictions", None) is not None:
with tf.name_scope(self.discriminator_predictions.name):
self.discriminator_predictions.build(None)
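A short, hedged usage sketch of the pretraining head: the discriminator emits one logit per token, and a sigmoid above 0.5 flags a token as likely replaced. The input sentence (with "fake" standing in for "jumps") is only an example:

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFElectraForPreTraining

tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
model = TFElectraForPreTraining.from_pretrained("google/electra-small-discriminator")

# The discriminator scores every token: a high logit means "probably replaced".
inputs = tokenizer("The quick brown fox fake over the lazy dog", return_tensors="tf")
logits = model(**inputs).logits             # shape (1, seq_len)
predictions = tf.round(tf.nn.sigmoid(logits))
print(predictions.numpy())                  # 1.0 marks suspected replaced tokens
```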
class TFElectraMaskedLMHead(keras.layers.Layer):
# Masked language modeling head for the Electra generator; tied to the input embeddings.
def __init__(self, config, input_embeddings, **kwargs):
super().__init__(**kwargs)
self.config = config
self.embedding_size = config.embedding_size
self.input_embeddings = input_embeddings
def build(self, input_shape):
# A single bias vector over the vocabulary, initialized to zeros.
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
def get_output_embeddings(self):
# The output projection is tied to the input embedding layer.
return self.input_embeddings
def set_output_embeddings(self, value):
# Update the tied embedding weights and the recorded vocabulary size.
self.input_embeddings.weight = value
self.input_embeddings.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.config.vocab_size = shape_list(value["bias"])[0]
def call(self, hidden_states):
# Project hidden states onto the vocabulary with the transposed embedding matrix, then add the bias.
seq_length = shape_list(tensor=hidden_states)[1]
hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size])
hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)
hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)
return hidden_states
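The `call` above implements weight tying: there is no separate output projection, so vocabulary logits are dot products between hidden states and rows of the input embedding matrix. A shape-only sketch with dummy tensors (all sizes here are illustrative):

```python
import tensorflow as tf

batch, seq_len, embedding_size, vocab_size = 2, 8, 128, 30522
hidden = tf.random.normal((batch, seq_len, embedding_size))
embedding_matrix = tf.random.normal((vocab_size, embedding_size))  # input embeddings

flat = tf.reshape(hidden, (-1, embedding_size))               # (batch*seq, emb)
logits = tf.matmul(flat, embedding_matrix, transpose_b=True)  # (batch*seq, vocab)
logits = tf.reshape(logits, (batch, seq_len, vocab_size))
print(logits.shape)  # (2, 8, 30522)
```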
@add_start_docstrings(
"""
Electra model with a language modeling head on top.
Even though both the discriminator and generator may be loaded into this model, the generator is the only model of
the two to have been trained for the masked language modeling task.
""",
ELECTRA_START_DOCSTRING,
)
class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLoss):
# Electra model with a language modeling head on top.
def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
self.config = config
# Shared Electra main layer.
self.electra = TFElectraMainLayer(config, name="electra")
# Generator prediction head (dense projection + activation + LayerNorm).
self.generator_predictions = TFElectraGeneratorPredictions(config, name="generator_predictions")
if isinstance(config.hidden_act, str):
self.activation = get_tf_activation(config.hidden_act)
else:
self.activation = config.hidden_act
# Masked language modeling head, tied to the input embeddings.
self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head")
def get_lm_head(self):
# Return the masked language modeling head.
return self.generator_lm_head
def get_prefix_bias_name(self):
# Deprecated: use `get_bias` instead.
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.generator_lm_head.name
@unpack_inputs
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint="google/electra-small-generator",
output_type=TFMaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
mask="[MASK]",
expected_output="'paris'",
expected_loss=1.22,
)
def call(
self,
input_ids: TFModelInputType | None = None,
attention_mask: np.ndarray | tf.Tensor | None = None,
token_type_ids: np.ndarray | tf.Tensor | None = None,
position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
inputs_embeds: np.ndarray | tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: np.ndarray | tf.Tensor | None = None,
training: Optional[bool] = False,
) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
r"""
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
"""
# Generate hidden states using the Electra model with provided inputs
generator_hidden_states = self.electra(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
# Extract sequence output from generator hidden states
generator_sequence_output = generator_hidden_states[0]
# Generate prediction scores using the generator predictions function
prediction_scores = self.generator_predictions(generator_sequence_output, training=training)
# Apply language modeling head to the generator prediction scores
prediction_scores = self.generator_lm_head(prediction_scores, training=training)
# Compute loss only if labels are provided using the provided loss computation function
loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)
# Prepare output based on whether return_dict is False or True
if not return_dict:
output = (prediction_scores,) + generator_hidden_states[1:]
return ((loss,) + output) if loss is not None else output
# Return TFMaskedLMOutput with detailed components if return_dict is True
return TFMaskedLMOutput(
loss=loss,
logits=prediction_scores,
hidden_states=generator_hidden_states.hidden_states,
attentions=generator_hidden_states.attentions,
)
def build(self, input_shape=None):
# Skip if the model has already been built.
if self.built:
return
self.built = True
# Build each submodule under its own name scope (input shape None uses the default).
if getattr(self, "electra", None) is not None:
with tf.name_scope(self.electra.name):
self.electra.build(None)
if getattr(self, "generator_predictions", None) is not None:
with tf.name_scope(self.generator_predictions.name):
self.generator_predictions.build(None)
if getattr(self, "generator_lm_head", None) is not None:
with tf.name_scope(self.generator_lm_head.name):
self.generator_lm_head.build(None)
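A minimal masked-LM usage sketch, assuming the `google/electra-small-generator` checkpoint referenced in the code-sample docstring above:

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFElectraForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("google/electra-small-generator")
model = TFElectraForMaskedLM.from_pretrained("google/electra-small-generator")

inputs = tokenizer("The capital of France is [MASK].", return_tensors="tf")
logits = model(**inputs).logits

# Locate the [MASK] position and take the highest-scoring vocabulary id there.
mask_index = int(tf.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0][0])
predicted_id = int(tf.argmax(logits[0, mask_index]))
print(tokenizer.decode([predicted_id]))  # expected: "paris", per the docstring above
```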
"""
ELECTRA Model transformer with a sequence classification/regression head on top (a linear layer on top of the
pooled output) e.g. for GLUE tasks.
"""
@add_start_docstrings(
"""
ELECTRA Model transformer with a sequence classification/regression head on top (a linear layer on top of the
pooled output) e.g. for GLUE tasks.
""",
ELECTRA_START_DOCSTRING,
)
class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceClassificationLoss):
"""
ELECTRA模型的转换器,顶部带有序列分类/回归头(在汇聚输出顶部的线性层),例如用于GLUE任务。
"""
def __init__(self, config, *inputs, **kwargs):
"""
初始化方法。
Args:
config (ElectraConfig): 模型的配置对象,包含模型的超参数。
*inputs: 可变长度的输入参数。
**kwargs: 其他关键字参数。
"""
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels # 设置模型的标签数
self.electra = TFElectraMainLayer(config, name="electra") # ELECTRA主层对象
self.classifier = TFElectraClassificationHead(config, name="classifier") # 分类头部对象
@unpack_inputs
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint="bhadresh-savani/electra-base-emotion",
output_type=TFSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output="'joy'",
expected_loss=0.06,
)
def call(
self,
input_ids: TFModelInputType | None = None,
attention_mask: np.ndarray | tf.Tensor | None = None,
token_type_ids: np.ndarray | tf.Tensor | None = None,
position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
inputs_embeds: np.ndarray | tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: np.ndarray | tf.Tensor | None = None,  # labels for the classification/regression loss
training: Optional[bool] = False,
) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
r"""
labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
# Run the Electra main layer.
outputs = self.electra(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
# Feed the sequence output into the classification head.
logits = self.classifier(outputs[0])
# Compute the loss only when labels are provided.
loss = None if labels is None else self.hf_compute_loss(labels, logits)
# Return either a plain tuple or a TFSequenceClassifierOutput, depending on return_dict.
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
else:
return TFSequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
# Skip if already built.
if self.built:
return
self.built = True
# Build the Electra main layer and the classification head under their name scopes.
if getattr(self, "electra", None) is not None:
with tf.name_scope(self.electra.name):
self.electra.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build(None)
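A minimal usage sketch; the checkpoint name is taken from the code-sample docstring above and its availability is only an assumption:

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFElectraForSequenceClassification

name = "bhadresh-savani/electra-base-emotion"  # checkpoint from the docstring above
tokenizer = AutoTokenizer.from_pretrained(name)
model = TFElectraForSequenceClassification.from_pretrained(name)

inputs = tokenizer("What a wonderful day!", return_tensors="tf")
logits = model(**inputs).logits               # (1, num_labels)
predicted = int(tf.argmax(logits, axis=-1)[0])
print(model.config.id2label[predicted])       # e.g. "joy"
```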
"""
ELECTRA 模型,顶部带有多选分类头部(在池化输出的基础上是一个线性层和一个 softmax),例如用于 RocStories/SWAG 任务。
继承自 TFElectraPreTrainedModel 和 TFMultipleChoiceLoss。
"""
@add_start_docstrings(
"""
ELECTRA Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
softmax) e.g. for RocStories/SWAG tasks.
""",
ELECTRA_START_DOCSTRING,
)
class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
# Electra main layer.
self.electra = TFElectraMainLayer(config, name="electra")
# Sequence summary layer that pools the hidden states.
self.sequence_summary = TFSequenceSummary(
config, initializer_range=config.initializer_range, name="sequence_summary"
)
# Classifier producing a single score per choice.
self.classifier = keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFMultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids: TFModelInputType | None = None,
attention_mask: np.ndarray | tf.Tensor | None = None,
token_type_ids: np.ndarray | tf.Tensor | None = None,
position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
inputs_embeds: np.ndarray | tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: np.ndarray | tf.Tensor | None = None,
training: Optional[bool] = False,
"""
调用方法,执行 ELECTRA 模型的前向传播。
Parameters:
- input_ids: 输入的 token IDs。
- attention_mask: 注意力掩码。
- token_type_ids: token 类型 IDs。
- position_ids: 位置 IDs。
- head_mask: 头部掩码。
- inputs_embeds: 输入的嵌入。
- output_attentions: 是否输出注意力。
- output_hidden_states: 是否输出隐藏状态。
- return_dict: 是否返回字典形式结果。
- labels: 标签数据。
- training: 是否处于训练模式。
Returns:
ELECTRA 模型的输出对象。
"""
...
) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
r"""
labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
"""
# When input_ids is given, read num_choices and seq_length from its second and third dims;
# otherwise read them from inputs_embeds.
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
# Flatten every (batch, num_choices, ...) tensor into (batch * num_choices, ...) when present.
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None
else None
)
# Run the Electra main layer over the flattened inputs.
outputs = self.electra(
input_ids=flat_input_ids,
attention_mask=flat_attention_mask,
token_type_ids=flat_token_type_ids,
position_ids=flat_position_ids,
head_mask=head_mask,
inputs_embeds=flat_inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
# Pool the sequence output, score each choice, and fold the scores back to (batch, num_choices).
logits = self.sequence_summary(outputs[0])
logits = self.classifier(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
# Compute the loss only when labels are provided.
loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)
# Return a plain tuple when return_dict=False.
if not return_dict:
output = (reshaped_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
# Otherwise return a TFMultipleChoiceModelOutput.
return TFMultipleChoiceModelOutput(
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
# Skip if already built.
if self.built:
return
self.built = True
# Build the main layer, the sequence summary and the classifier under their name scopes.
if getattr(self, "electra", None) is not None:
with tf.name_scope(self.electra.name):
self.electra.build(None)
if getattr(self, "sequence_summary", None) is not None:
with tf.name_scope(self.sequence_summary.name):
self.sequence_summary.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
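The flatten-then-fold trick in `call` above, isolated as a standalone sketch: choices are merged into the batch dimension for a single encoder pass, then per-choice scores are reshaped back to `(batch, num_choices)` (all sizes illustrative):

```python
import tensorflow as tf

batch, num_choices, seq_len = 2, 4, 16
input_ids = tf.zeros((batch, num_choices, seq_len), dtype=tf.int32)

# Flatten choices into the batch dimension, as the call above does ...
flat_input_ids = tf.reshape(input_ids, (-1, seq_len))   # (8, 16)

# ... run the encoder once over all choices, then fold the scores back:
per_choice_score = tf.zeros((batch * num_choices, 1))   # stand-in for classifier output
reshaped_logits = tf.reshape(per_choice_score, (-1, num_choices))  # (2, 4)
print(flat_input_ids.shape, reshaped_logits.shape)
```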
@add_start_docstrings(
"""
Electra model with a token classification head on top.
Both the discriminator and generator may be loaded into this model.
""",
ELECTRA_START_DOCSTRING,
)
class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassificationLoss):
def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
# Electra main layer.
self.electra = TFElectraMainLayer(config, name="electra")
# Use config.classifier_dropout when set, else fall back to hidden_dropout_prob.
classifier_dropout = (
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.dropout = keras.layers.Dropout(classifier_dropout)
# Dense layer projecting each token onto the label space.
self.classifier = keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint="bhadresh-savani/electra-base-discriminator-finetuned-conll03-english",
output_type=TFTokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output="['B-LOC', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'B-LOC', 'I-LOC']",
expected_loss=0.11,
)
def call(
self,
input_ids: TFModelInputType | None = None,
attention_mask: np.ndarray | tf.Tensor | None = None,
token_type_ids: np.ndarray | tf.Tensor | None = None,
position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
inputs_embeds: np.ndarray | tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: np.ndarray | tf.Tensor | None = None,
training: Optional[bool] = False,
) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
r"""
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
# Run the Electra (discriminator) main layer.
discriminator_hidden_states = self.electra(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
# Take the sequence output, apply dropout, and score each token with the classifier.
discriminator_sequence_output = discriminator_hidden_states[0]
discriminator_sequence_output = self.dropout(discriminator_sequence_output)
logits = self.classifier(discriminator_sequence_output)
# Compute the loss only when labels are provided.
loss = None if labels is None else self.hf_compute_loss(labels, logits)
# Return a plain tuple when return_dict=False.
if not return_dict:
output = (logits,) + discriminator_hidden_states[1:]
return ((loss,) + output) if loss is not None else output
# Otherwise return a TFTokenClassifierOutput.
return TFTokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=discriminator_hidden_states.hidden_states,
attentions=discriminator_hidden_states.attentions,
)
def build(self, input_shape=None):
# Skip if already built.
if self.built:
return
self.built = True
# Build the Electra main layer and the classifier under their name scopes.
if getattr(self, "electra", None) is not None:
with tf.name_scope(self.electra.name):
self.electra.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
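A minimal token-classification usage sketch, reusing the fine-tuned checkpoint named in the code-sample docstring above (its availability is an assumption):

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFElectraForTokenClassification

name = "bhadresh-savani/electra-base-discriminator-finetuned-conll03-english"
tokenizer = AutoTokenizer.from_pretrained(name)
model = TFElectraForTokenClassification.from_pretrained(name)

inputs = tokenizer("HuggingFace is based in New York City", return_tensors="tf")
logits = model(**inputs).logits                 # (1, seq_len, num_labels)
ids = tf.argmax(logits, axis=-1)[0]
print([model.config.id2label[int(i)] for i in ids])  # one tag per token
```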
# The decorator below documents the span-classification head used for extractive QA tasks like SQuAD.
@add_start_docstrings(
"""
Electra Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
ELECTRA_START_DOCSTRING,
)
class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnsweringLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
# Electra main layer.
self.electra = TFElectraMainLayer(config, name="electra")
# Dense output layer producing the span start/end logits.
self.qa_outputs = keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
self.config = config
# The decorators below attach the forward-pass docstring and a code sample to `call`.
@unpack_inputs
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint="bhadresh-savani/electra-base-squad2",
output_type=TFQuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
qa_target_start_index=11,
qa_target_end_index=12,
expected_output="'a nice puppet'",
expected_loss=2.64,
)
# Forward pass; accepts the usual encoder inputs plus optional span labels.
def call(
self,
input_ids: TFModelInputType | None = None,
attention_mask: np.ndarray | tf.Tensor | None = None,
token_type_ids: np.ndarray | tf.Tensor | None = None,
position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
inputs_embeds: np.ndarray | tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
start_positions: np.ndarray | tf.Tensor | None = None,
end_positions: np.ndarray | tf.Tensor | None = None,
training: Optional[bool] = False,
) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
r"""
start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
"""
discriminator_hidden_states = self.electra(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
# Sequence output of the discriminator.
discriminator_sequence_output = discriminator_hidden_states[0]
# Project each token onto two scores and split them into start and end logits.
logits = self.qa_outputs(discriminator_sequence_output)
start_logits, end_logits = tf.split(logits, 2, axis=-1)
start_logits = tf.squeeze(start_logits, axis=-1)
end_logits = tf.squeeze(end_logits, axis=-1)
loss = None
# When both span boundaries are given, compute the loss against them.
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
loss = self.hf_compute_loss(labels, (start_logits, end_logits))
# Return a plain tuple when return_dict=False.
if not return_dict:
output = (
start_logits,
end_logits,
) + discriminator_hidden_states[1:]
return ((loss,) + output) if loss is not None else output
# Otherwise return a TFQuestionAnsweringModelOutput.
return TFQuestionAnsweringModelOutput(
loss=loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=discriminator_hidden_states.hidden_states,
attentions=discriminator_hidden_states.attentions,
)
def build(self, input_shape=None):
# Skip if already built.
if self.built:
return
self.built = True
# Build the Electra main layer and the QA output layer under their name scopes.
if getattr(self, "electra", None) is not None:
with tf.name_scope(self.electra.name):
self.electra.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])
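A minimal extractive-QA usage sketch: argmax over the start and end logits picks a span, which is then decoded back to text. The checkpoint name comes from the code-sample docstring above and is assumed to be available:

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFElectraForQuestionAnswering

name = "bhadresh-savani/electra-base-squad2"
tokenizer = AutoTokenizer.from_pretrained(name)
model = TFElectraForQuestionAnswering.from_pretrained(name)

question = "Who wrote Hamlet?"
context = "Hamlet is a tragedy written by William Shakespeare."
inputs = tokenizer(question, context, return_tensors="tf")
outputs = model(**inputs)

start = int(tf.argmax(outputs.start_logits, axis=-1)[0])
end = int(tf.argmax(outputs.end_logits, axis=-1)[0])
print(tokenizer.decode(inputs["input_ids"][0][start : end + 1]))
```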
.\models\electra\tokenization_electra.py
import collections
import os
from typing import List, Optional, Tuple
from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from ...utils import logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"google/electra-small-generator": (
"https://huggingface.co/google/electra-small-generator/resolve/main/vocab.txt"
),
"google/electra-base-generator": "https://huggingface.co/google/electra-base-generator/resolve/main/vocab.txt",
"google/electra-large-generator": (
"https://huggingface.co/google/electra-large-generator/resolve/main/vocab.txt"
),
"google/electra-small-discriminator": (
"https://huggingface.co/google/electra-small-discriminator/resolve/main/vocab.txt"
),
"google/electra-base-discriminator": (
"https://huggingface.co/google/electra-base-discriminator/resolve/main/vocab.txt"
),
"google/electra-large-discriminator": (
"https://huggingface.co/google/electra-large-discriminator/resolve/main/vocab.txt"
),
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"google/electra-small-generator": 512,
"google/electra-base-generator": 512,
"google/electra-large-generator": 512,
"google/electra-small-discriminator": 512,
"google/electra-base-discriminator": 512,
"google/electra-large-discriminator": 512,
}
PRETRAINED_INIT_CONFIGURATION = {
"google/electra-small-generator": {"do_lower_case": True},
"google/electra-base-generator": {"do_lower_case": True},
"google/electra-large-generator": {"do_lower_case": True},
"google/electra-small-discriminator": {"do_lower_case": True},
"google/electra-base-discriminator": {"do_lower_case": True},
"google/electra-large-discriminator": {"do_lower_case": True},
}
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
with open(vocab_file, "r", encoding="utf-8") as reader:
tokens = reader.readlines()
for index, token in enumerate(tokens):
token = token.rstrip("\n")
vocab[token] = index
return vocab
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class ElectraTokenizer(PreTrainedTokenizer):
r"""
Construct a Electra tokenizer. Based on WordPiece.
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
this superclass for more information regarding those methods.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(
self,
vocab_file,
do_lower_case=True,
do_basic_tokenize=True,
never_split=None,
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
tokenize_chinese_chars=True,
strip_accents=None,
**kwargs,
):
"""
Args:
vocab_file (`str`):
包含词汇表的文件。
do_lower_case (`bool`, *optional*, defaults to `True`):
是否在进行分词时将输入转换为小写。
do_basic_tokenize (`bool`, *optional*, defaults to `True`):
是否在WordPiece分词前进行基本分词。
never_split (`Iterable`, *optional*):
在分词过程中不应拆分的标记集合。仅在 `do_basic_tokenize=True` 时有效。
unk_token (`str`, *optional*, defaults to `"[UNK]"`):
未知标记。当输入中的标记不在词汇表中时,将其替换为此标记。
sep_token (`str`, *optional*, defaults to `"[SEP]"`):
分隔符标记,在构建多个序列的序列时使用,例如序列分类或问答问题时使用。也用作构建带有特殊标记的序列的最后一个标记。
pad_token (`str`, *optional*, defaults to `"[PAD]"`):
用于填充的标记,例如在批处理不同长度的序列时使用。
cls_token (`str`, *optional*, defaults to `"[CLS]"`):
分类器标记,在进行序列分类时使用(整个序列的分类而不是每个标记的分类)。它是构建带有特殊标记的序列的第一个标记。
mask_token (`str`, *optional*, defaults to `"[MASK]"`):
用于屏蔽值的标记。这是在进行遮蔽语言建模训练时使用的标记。模型将尝试预测此标记。
tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
是否对中文字符进行分词。
对于日语可能需要禁用此选项(参见此问题: https://github.com/huggingface/transformers/issues/328)。
strip_accents (`bool`, *optional*):
是否删除所有重音符号。如果未指定此选项,则将根据 `lowercase` 的值来确定(与原始Electra一样)。
"""
):
if not os.path.isfile(vocab_file):
raise ValueError(
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
" model use `tokenizer = ElectraTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
self.do_basic_tokenize = do_basic_tokenize
if do_basic_tokenize:
self.basic_tokenizer = BasicTokenizer(
do_lower_case=do_lower_case,
never_split=never_split,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
super().__init__(
do_lower_case=do_lower_case,
do_basic_tokenize=do_basic_tokenize,
never_split=never_split,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)
@property
def do_lower_case(self):
return self.basic_tokenizer.do_lower_case
@property
def vocab_size(self):
return len(self.vocab)
def get_vocab(self):
return dict(self.vocab, **self.added_tokens_encoder)
def _tokenize(self, text, split_special_tokens=False):
split_tokens = []
if self.do_basic_tokenize:
for token in self.basic_tokenizer.tokenize(
text, never_split=self.all_special_tokens if not split_special_tokens else None
):
if token in self.basic_tokenizer.never_split:
split_tokens.append(token)
else:
split_tokens += self.wordpiece_tokenizer.tokenize(token)
else:
split_tokens = self.wordpiece_tokenizer.tokenize(text)
return split_tokens
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
return self.vocab.get(token, self.vocab.get(self.unk_token))
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.ids_to_tokens.get(index, self.unk_token)
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
out_string = " ".join(tokens).replace(" ##", "").strip()
return out_string
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. A Electra sequence has the following format:
- single sequence: `[CLS] X [SEP]`
- pair of sequences: `[CLS] A [SEP] B [SEP]`
Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
if token_ids_1 is None:
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
cls = [self.cls_token_id]
sep = [self.sep_token_id]
return cls + token_ids_0 + sep + token_ids_1 + sep
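For concreteness, a tiny sketch of the two layouts with hypothetical token ids (101/102 are the conventional [CLS]/[SEP] ids; the other ids are made up):

```python
cls_token_id, sep_token_id = 101, 102   # hypothetical ids (BERT-style defaults)
token_ids_0 = [7592, 2088]              # e.g. "hello world"
token_ids_1 = [2129, 2024, 2017]        # e.g. "how are you"

single = [cls_token_id] + token_ids_0 + [sep_token_id]
pair = [cls_token_id] + token_ids_0 + [sep_token_id] + token_ids_1 + [sep_token_id]
print(single)  # [101, 7592, 2088, 102]
print(pair)    # [101, 7592, 2088, 102, 2129, 2024, 2017, 102]
```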
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)
if token_ids_1 is not None:
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1]
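A quick standalone check of the mask layout, mirroring the two return statements above:
```
def special_tokens_mask(token_ids_0, token_ids_1=None):
    if token_ids_1 is not None:
        # [CLS] A ... [SEP] B ... [SEP]
        return [1] + [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1]
    return [1] + [0] * len(token_ids_0) + [1]

print(special_tokens_mask([7, 8], [20, 21, 22]))  # [1, 0, 0, 1, 0, 0, 0, 1]
print(special_tokens_mask([7, 8]))                # [1, 0, 0, 1]
```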
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create token type IDs tensor from token id tensors. `0` for the first sentence tokens, `1` for the second sentence
tokens.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of token type IDs according to the sequences provided.
"""
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. An ELECTRA sequence
pair mask has the following format:
```
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
```
If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
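The segment-ID layout follows directly from the two return statements; a minimal standalone sketch:
```
def token_type_ids(token_ids_0, token_ids_1=None):
    # [CLS] + first sequence + [SEP] are all segment 0
    first = [0] * (len(token_ids_0) + 2)
    if token_ids_1 is None:
        return first
    # second sequence + trailing [SEP] are segment 1
    return first + [1] * (len(token_ids_1) + 1)

print(token_type_ids([7, 8], [20, 21, 22]))  # [0, 0, 0, 0, 1, 1, 1, 1]
```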
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
index = 0
if os.path.isdir(save_directory):
vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
else:
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning(
f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
" Please check that the vocabulary is not corrupted!"
)
index = token_index
writer.write(token + "\n")
index += 1
return (vocab_file,)
class BasicTokenizer(object):
"""
Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
Args:
do_lower_case (`bool`, *optional*, defaults to `True`):
Whether or not to lowercase the input when tokenizing.
never_split (`Iterable`, *optional*):
Collection of tokens which will never be split during tokenization. Only has an effect when
`do_basic_tokenize=True`
tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
Whether or not to tokenize Chinese characters.
This should likely be deactivated for Japanese (see this issue).
strip_accents (`bool`, *optional*):
Whether or not to strip all accents. If this option is not specified, then it will be determined by the
value for `lowercase` (as in the original BERT).
do_split_on_punc (`bool`, *optional*, defaults to `True`):
In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
the full context of the words, such as contractions.
"""
def __init__(
self,
do_lower_case=True,
never_split=None,
tokenize_chinese_chars=True,
strip_accents=None,
do_split_on_punc=True,
):
if never_split is None:
never_split = []
self.do_lower_case = do_lower_case
self.never_split = set(never_split)
self.tokenize_chinese_chars = tokenize_chinese_chars
self.strip_accents = strip_accents
self.do_split_on_punc = do_split_on_punc
def tokenize(self, text, never_split=None):
"""
Basic Tokenization of a piece of text. For sub-word tokenization, see WordpieceTokenizer.
Args:
never_split (`List[str]`, *optional*)
Kept for backward compatibility purposes. Now implemented directly at the base class level (see
[`PreTrainedTokenizer.tokenize`]) List of tokens not to split.
"""
never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
text = self._clean_text(text)
if self.tokenize_chinese_chars:
text = self._tokenize_chinese_chars(text)
unicode_normalized_text = unicodedata.normalize("NFC", text)
orig_tokens = whitespace_tokenize(unicode_normalized_text)
split_tokens = []
for token in orig_tokens:
if token not in never_split:
if self.do_lower_case:
token = token.lower()
if self.strip_accents is not False:
token = self._run_strip_accents(token)
elif self.strip_accents:
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token, never_split))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text, never_split=None):
"""Splits punctuation on a piece of text."""
if not self.do_split_on_punc or (never_split is not None and text in never_split):
return [text]
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
if (
(cp >= 0x4E00 and cp <= 0x9FFF)
or (cp >= 0x3400 and cp <= 0x4DBF)
or (cp >= 0x20000 and cp <= 0x2A6DF)
or (cp >= 0x2A700 and cp <= 0x2B73F)
or (cp >= 0x2B740 and cp <= 0x2B81F)
or (cp >= 0x2B820 and cp <= 0x2CEAF)
or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F)
):
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xFFFD or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenization."""
def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""
Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
tokenization using the given vocabulary.
For example, `input = "unaffable"` wil return as output `["un", "##aff", "##able"]`.
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through *BasicTokenizer*.
Returns:
A list of wordpiece tokens.
"""
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
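The greedy longest-match-first loop can be exercised with a toy vocabulary (made up purely for illustration), assuming the class is importable from this module:
```
from transformers.models.electra.tokenization_electra import WordpieceTokenizer

toy_vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}
wp = WordpieceTokenizer(vocab=toy_vocab, unk_token="[UNK]")
print(wp.tokenize("unaffable"))  # ['un', '##aff', '##able']
print(wp.tokenize("xyz"))        # ['[UNK]'] -- no matching subword at position 0
```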
.\models\electra\tokenization_electra_fast.py
import json
from typing import List, Optional, Tuple
from tokenizers import normalizers
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from .tokenization_electra import ElectraTokenizer
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"google/electra-small-generator": (
"https://huggingface.co/google/electra-small-generator/resolve/main/vocab.txt"
),
"google/electra-base-generator": "https://huggingface.co/google/electra-base-generator/resolve/main/vocab.txt",
"google/electra-large-generator": (
"https://huggingface.co/google/electra-large-generator/resolve/main/vocab.txt"
),
"google/electra-small-discriminator": (
"https://huggingface.co/google/electra-small-discriminator/resolve/main/vocab.txt"
),
"google/electra-base-discriminator": (
"https://huggingface.co/google/electra-base-discriminator/resolve/main/vocab.txt"
),
"google/electra-large-discriminator": (
"https://huggingface.co/google/electra-large-discriminator/resolve/main/vocab.txt"
),
},
"tokenizer_file": {
"google/electra-small-generator": (
"https://huggingface.co/google/electra-small-generator/resolve/main/tokenizer.json"
),
"google/electra-base-generator": (
"https://huggingface.co/google/electra-base-generator/resolve/main/tokenizer.json"
),
"google/electra-large-generator": (
"https://huggingface.co/google/electra-large-generator/resolve/main/tokenizer.json"
),
"google/electra-small-discriminator": (
"https://huggingface.co/google/electra-small-discriminator/resolve/main/tokenizer.json"
),
"google/electra-base-discriminator": (
"https://huggingface.co/google/electra-base-discriminator/resolve/main/tokenizer.json"
),
"google/electra-large-discriminator": (
"https://huggingface.co/google/electra-large-discriminator/resolve/main/tokenizer.json"
),
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"google/electra-small-generator": 512,
"google/electra-base-generator": 512,
"google/electra-large-generator": 512,
"google/electra-small-discriminator": 512,
"google/electra-base-discriminator": 512,
"google/electra-large-discriminator": 512,
}
PRETRAINED_INIT_CONFIGURATION = {
"google/electra-small-generator": {"do_lower_case": True},
"google/electra-base-generator": {"do_lower_case": True},
"google/electra-large-generator": {"do_lower_case": True},
"google/electra-small-discriminator": {"do_lower_case": True},
"google/electra-base-discriminator": {"do_lower_case": True},
"google/electra-large-discriminator": {"do_lower_case": True},
}
class ElectraTokenizerFast(PreTrainedTokenizerFast):
r"""
Construct a "fast" ELECTRA tokenizer (backed by HuggingFace's *tokenizers* library), based on WordPiece.
This tokenizer inherits from [`PreTrainedTokenizerFast`], which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods.
"""
# Default names of the vocabulary files
vocab_files_names = VOCAB_FILES_NAMES
# Map from pretrained checkpoints to their vocabulary files
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
# Initialization configuration of the pretrained checkpoints
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
# Maximum model input sizes (from the pretrained positional embeddings)
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
# Slow tokenizer counterpart of this class
slow_tokenizer_class = ElectraTokenizer
# Constructor for the fast ELECTRA tokenizer
def __init__(
self,
vocab_file=None,
tokenizer_file=None,
do_lower_case=True,
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
tokenize_chinese_chars=True,
strip_accents=None,
**kwargs,
):
super().__init__(
vocab_file,
tokenizer_file=tokenizer_file,
do_lower_case=do_lower_case,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)
# Call the parent class's initializer, passing the required arguments and keyword arguments.
normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
# Fetch the normalizer's state from the backend tokenizer and deserialize it into a Python dict.
if (
normalizer_state.get("lowercase", do_lower_case) != do_lower_case
or normalizer_state.get("strip_accents", strip_accents) != strip_accents
or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars
):
# Check whether the normalizer's state matches the arguments passed here; if not, rebuild the normalizer.
normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
normalizer_state["lowercase"] = do_lower_case
normalizer_state["strip_accents"] = strip_accents
normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars
self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)
# On a mismatch, rebuild a normalizer of the recorded type so it agrees with the arguments passed here.
self.do_lower_case = do_lower_case
# Record the lowercasing setting on the tokenizer object itself.
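The normalizer check above keeps the Rust backend consistent with arguments overridden at load time. A hedged usage sketch (downloads the checkpoint from the hub):
```
from transformers import ElectraTokenizerFast

tok = ElectraTokenizerFast.from_pretrained("google/electra-small-discriminator")
enc = tok("Hello world", "Second sentence")
print(enc["input_ids"])       # [CLS] ... [SEP] ... [SEP], as vocabulary IDs
print(enc["token_type_ids"])  # zeros for the first segment, ones for the second
```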
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. An ELECTRA sequence has the following format:
- single sequence: `[CLS] X [SEP]`
- pair of sequences: `[CLS] A [SEP] B [SEP]`
Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
# Build model inputs by concatenating the sequence(s) and adding the special tokens used for sequence
# classification; single sequences and sequence pairs use different special-token layouts.
if token_ids_1 is not None:
output += token_ids_1 + [self.sep_token_id]
return output
# If a second sequence token_ids_1 is provided, append it followed by another separator token, then return the assembled list.
# Create token type IDs from the given sequence(s), distinguishing the two segments.
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. An ELECTRA sequence
pair mask has the following format:
```
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
```
If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
Args:
token_ids_0 (`List[int]`):
List of IDs for the first sequence.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of token type IDs according to the given sequence(s).
"""
# Define the separation and classification tokens
sep = [self.sep_token_id]
cls = [self.cls_token_id]
# If token_ids_1 is not provided, return a mask with zeros for the first sequence only
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
# Return a mask with zeros for the first sequence and ones for the second sequence
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
"""
Save the tokenizer's vocabulary files to the specified directory.
Args:
save_directory (str):
Directory where the vocabulary files will be saved.
filename_prefix (str, *optional*):
Optional prefix for the saved files.
Returns:
`Tuple[str]`: Tuple containing the filenames of the saved vocabulary files.
"""
# Save the model's vocabulary files using the tokenizer's internal method
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
return tuple(files)
.\models\electra\__init__.py
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)
_import_structure = {
"configuration_electra": ["ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "ElectraConfig", "ElectraOnnxConfig"],
"tokenization_electra": ["ElectraTokenizer"],
}
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_electra_fast"] = ["ElectraTokenizerFast"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_electra"] = [
"ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST",
"ElectraForCausalLM",
"ElectraForMaskedLM",
"ElectraForMultipleChoice",
"ElectraForPreTraining",
"ElectraForQuestionAnswering",
"ElectraForSequenceClassification",
"ElectraForTokenClassification",
"ElectraModel",
"ElectraPreTrainedModel",
"load_tf_weights_in_electra",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_electra"] = [
"TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFElectraForMaskedLM",
"TFElectraForMultipleChoice",
"TFElectraForPreTraining",
"TFElectraForQuestionAnswering",
"TFElectraForSequenceClassification",
"TFElectraForTokenClassification",
"TFElectraModel",
"TFElectraPreTrainedModel",
]
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_electra"] = [
"FlaxElectraForCausalLM",
"FlaxElectraForMaskedLM",
"FlaxElectraForMultipleChoice",
"FlaxElectraForPreTraining",
"FlaxElectraForQuestionAnswering",
"FlaxElectraForSequenceClassification",
"FlaxElectraForTokenClassification",
"FlaxElectraModel",
"FlaxElectraPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig, ElectraOnnxConfig
from .tokenization_electra import ElectraTokenizer
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_electra_fast import ElectraTokenizerFast
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_electra import (
ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
ElectraForCausalLM,
ElectraForMaskedLM,
ElectraForMultipleChoice,
ElectraForPreTraining,
ElectraForQuestionAnswering,
ElectraForSequenceClassification,
ElectraForTokenClassification,
ElectraModel,
ElectraPreTrainedModel,
load_tf_weights_in_electra,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_electra import (
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFElectraForMaskedLM,
TFElectraForMultipleChoice,
TFElectraForPreTraining,
TFElectraForQuestionAnswering,
TFElectraForSequenceClassification,
TFElectraForTokenClassification,
TFElectraModel,
TFElectraPreTrainedModel,
)
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_electra import (
FlaxElectraForCausalLM,
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
FlaxElectraForPreTraining,
FlaxElectraForQuestionAnswering,
FlaxElectraForSequenceClassification,
FlaxElectraForTokenClassification,
FlaxElectraModel,
FlaxElectraPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
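The `_LazyModule` replacement means the heavy framework-specific submodules are only imported on first attribute access. A small sketch of the observable behavior (assuming `transformers` is installed):
```
import transformers.models.electra as electra

# At this point none of the torch/tf/flax modeling files have been imported;
# the attribute access below triggers the real import through _LazyModule.
print(electra.ElectraConfig())  # prints the default ElectraConfig
```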
.\models\encodec\configuration_encodec.py
""" EnCodec model configuration"""
import math
from typing import Optional
import numpy as np
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
ENCODEC_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/encodec_24khz": "https://huggingface.co/facebook/encodec_24khz/resolve/main/config.json",
"facebook/encodec_48khz": "https://huggingface.co/facebook/encodec_48khz/resolve/main/config.json",
}
class EncodecConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of an [`EncodecModel`]. It is used to instantiate a
Encodec model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the
[facebook/encodec_24khz](https://huggingface.co/facebook/encodec_24khz) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```
>>> from transformers import EncodecModel, EncodecConfig
>>> # Initializing a "facebook/encodec_24khz" style configuration
>>> configuration = EncodecConfig()
>>> # Initializing a model (with random weights) from the "facebook/encodec_24khz" style configuration
>>> model = EncodecModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "encodec"
def __init__(
self,
target_bandwidths=[1.5, 3.0, 6.0, 12.0, 24.0],
sampling_rate=24_000,
audio_channels=1,
normalize=False,
chunk_length_s=None,
overlap=None,
hidden_size=128,
num_filters=32,
num_residual_layers=1,
upsampling_ratios=[8, 5, 4, 2],
norm_type="weight_norm",
kernel_size=7,
last_kernel_size=7,
residual_kernel_size=3,
dilation_growth_rate=2,
use_causal_conv=True,
pad_mode="reflect",
compress=2,
num_lstm_layers=2,
trim_right_ratio=1.0,
codebook_size=1024,
codebook_dim=None,
use_conv_shortcut=True,
**kwargs,
):
self.target_bandwidths = target_bandwidths
self.sampling_rate = sampling_rate
self.audio_channels = audio_channels
self.normalize = normalize
self.chunk_length_s = chunk_length_s
self.overlap = overlap
self.hidden_size = hidden_size
self.num_filters = num_filters
self.num_residual_layers = num_residual_layers
self.upsampling_ratios = upsampling_ratios
self.norm_type = norm_type
self.kernel_size = kernel_size
self.last_kernel_size = last_kernel_size
self.residual_kernel_size = residual_kernel_size
self.dilation_growth_rate = dilation_growth_rate
self.use_causal_conv = use_causal_conv
self.pad_mode = pad_mode
self.compress = compress
self.num_lstm_layers = num_lstm_layers
self.trim_right_ratio = trim_right_ratio
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim if codebook_dim is not None else hidden_size
self.use_conv_shortcut = use_conv_shortcut
if self.norm_type not in ["weight_norm", "time_group_norm"]:
raise ValueError(
f'self.norm_type must be one of `"weight_norm"`, `"time_group_norm"`, got {self.norm_type}'
)
super().__init__(**kwargs)
@property
def chunk_length(self) -> Optional[int]:
if self.chunk_length_s is None:
return None
else:
return int(self.chunk_length_s * self.sampling_rate)
@property
def chunk_stride(self) -> Optional[int]:
if self.chunk_length_s is None or self.overlap is None:
return None
else:
return max(1, int((1.0 - self.overlap) * self.chunk_length))
@property
def frame_rate(self) -> int:
hop_length = np.prod(self.upsampling_ratios)
return math.ceil(self.sampling_rate / hop_length)
@property
def num_quantizers(self) -> int:
return int(1000 * self.target_bandwidths[-1] // (self.frame_rate * 10))
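The derived properties follow directly from the constructor arguments; for the default 24 kHz configuration with chunking enabled, the values compute as follows:
```
from transformers import EncodecConfig

cfg = EncodecConfig(chunk_length_s=1.0, overlap=0.01)
# hop_length = 8 * 5 * 4 * 2 = 320, so frame_rate = ceil(24000 / 320) = 75
print(cfg.frame_rate)      # 75
print(cfg.chunk_length)    # int(1.0 * 24000) = 24000
print(cfg.chunk_stride)    # max(1, int((1 - 0.01) * 24000)) = 23760
# num_quantizers = int(1000 * 24.0 // (75 * 10)) = 32
print(cfg.num_quantizers)  # 32
```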
.\models\encodec\convert_encodec_checkpoint_to_pytorch.py
"""Convert EnCodec checkpoints."""
import argparse
import torch
from transformers import (
EncodecConfig,
EncodecFeatureExtractor,
EncodecModel,
logging,
)
logging.set_verbosity_info()
logger = logging.get_logger("transformers.models.encodec")
MAPPING_QUANTIZER = {
"quantizer.vq.layers.*._codebook.inited": "quantizer.layers.*.codebook.inited",
"quantizer.vq.layers.*._codebook.cluster_size": "quantizer.layers.*.codebook.cluster_size",
"quantizer.vq.layers.*._codebook.embed": "quantizer.layers.*.codebook.embed",
"quantizer.vq.layers.*._codebook.embed_avg": "quantizer.layers.*.codebook.embed_avg",
}
MAPPING_ENCODER = {
"encoder.model.0.conv.conv": "encoder.layers.0.conv",
"encoder.model.1.block.1.conv.conv": "encoder.layers.1.block.1.conv",
"encoder.model.1.block.3.conv.conv": "encoder.layers.1.block.3.conv",
"encoder.model.1.shortcut.conv.conv": "encoder.layers.1.shortcut.conv",
"encoder.model.3.conv.conv": "encoder.layers.3.conv",
"encoder.model.4.block.1.conv.conv": "encoder.layers.4.block.1.conv",
"encoder.model.4.block.3.conv.conv": "encoder.layers.4.block.3.conv",
"encoder.model.4.shortcut.conv.conv": "encoder.layers.4.shortcut.conv",
"encoder.model.6.conv.conv": "encoder.layers.6.conv",
"encoder.model.7.block.1.conv.conv": "encoder.layers.7.block.1.conv",
"encoder.model.7.block.3.conv.conv": "encoder.layers.7.block.3.conv",
"encoder.model.7.shortcut.conv.conv": "encoder.layers.7.shortcut.conv",
"encoder.model.9.conv.conv": "encoder.layers.9.conv",
"encoder.model.10.block.1.conv.conv": "encoder.layers.10.block.1.conv",
"encoder.model.10.block.3.conv.conv": "encoder.layers.10.block.3.conv",
"encoder.model.10.shortcut.conv.conv": "encoder.layers.10.shortcut.conv",
"encoder.model.12.conv.conv": "encoder.layers.12.conv",
"encoder.model.13.lstm": "encoder.layers.13.lstm",
"encoder.model.15.conv.conv": "encoder.layers.15.conv",
}
MAPPING_ENCODER_48K = {
"encoder.model.0.conv.norm": "encoder.layers.0.norm",
"encoder.model.1.block.1.conv.norm": "encoder.layers.1.block.1.norm",
"encoder.model.1.block.3.conv.norm": "encoder.layers.1.block.3.norm",
"encoder.model.1.shortcut.conv.norm": "encoder.layers.1.shortcut.norm",
"encoder.model.3.conv.norm": "encoder.layers.3.norm",
"encoder.model.4.block.1.conv.norm": "encoder.layers.4.block.1.norm",
"encoder.model.4.block.3.conv.norm": "encoder.layers.4.block.3.norm",
"encoder.model.4.shortcut.conv.norm": "encoder.layers.4.shortcut.norm",
"encoder.model.6.conv.norm": "encoder.layers.6.norm",
"encoder.model.7.block.1.conv.norm": "encoder.layers.7.block.1.norm",
"encoder.model.7.block.3.conv.norm": "encoder.layers.7.block.3.norm",
"encoder.model.7.shortcut.conv.norm": "encoder.layers.7.shortcut.norm",
"encoder.model.9.conv.norm": "encoder.layers.9.norm",
"encoder.model.10.block.1.conv.norm": "encoder.layers.10.block.1.norm",
"encoder.model.10.block.3.conv.norm": "encoder.layers.10.block.3.norm",
"encoder.model.10.shortcut.conv.norm": "encoder.layers.10.shortcut.norm",
"encoder.model.12.conv.norm": "encoder.layers.12.norm",
"encoder.model.15.conv.norm": "encoder.layers.15.norm",
}
MAPPING_DECODER = {
"decoder.model.0.conv.conv": "decoder.layers.0.conv",
"decoder.model.1.lstm": "decoder.layers.1.lstm",
"decoder.model.3.convtr.convtr": "decoder.layers.3.conv",
"decoder.model.4.block.1.conv.conv": "decoder.layers.4.block.1.conv",
"decoder.model.4.block.3.conv.conv": "decoder.layers.4.block.3.conv",
"decoder.model.4.shortcut.conv.conv": "decoder.layers.4.shortcut.conv",
"decoder.model.6.convtr.convtr": "decoder.layers.6.conv",
"decoder.model.7.block.1.conv.conv": "decoder.layers.7.block.1.conv",
"decoder.model.7.block.3.conv.conv": "decoder.layers.7.block.3.conv",
"decoder.model.7.shortcut.conv.conv": "decoder.layers.7.shortcut.conv",
"decoder.model.9.convtr.convtr": "decoder.layers.9.conv",
"decoder.model.10.block.1.conv.conv": "decoder.layers.10.block.1.conv",
"decoder.model.10.block.3.conv.conv": "decoder.layers.10.block.3.conv",
"decoder.model.10.shortcut.conv.conv": "decoder.layers.10.shortcut.conv",
"decoder.model.12.convtr.convtr": "decoder.layers.12.conv",
"decoder.model.13.block.1.conv.conv": "decoder.layers.13.block.1.conv",
"decoder.model.13.block.3.conv.conv": "decoder.layers.13.block.3.conv",
"decoder.model.13.shortcut.conv.conv": "decoder.layers.13.shortcut.conv",
"decoder.model.15.conv.conv": "decoder.layers.15.conv",
}
MAPPING_DECODER_48K = {
"decoder.model.0.conv.norm": "decoder.layers.0.norm",
"decoder.model.3.convtr.norm": "decoder.layers.3.norm",
"decoder.model.4.block.1.conv.norm": "decoder.layers.4.block.1.norm",
"decoder.model.4.block.3.conv.norm": "decoder.layers.4.block.3.norm",
"decoder.model.4.shortcut.conv.norm": "decoder.layers.4.shortcut.norm",
"decoder.model.6.convtr.norm": "decoder.layers.6.norm",
"decoder.model.7.block.1.conv.norm": "decoder.layers.7.block.1.norm",
"decoder.model.7.block.3.conv.norm": "decoder.layers.7.block.3.norm",
"decoder.model.7.shortcut.conv.norm": "decoder.layers.7.shortcut.norm",
"decoder.model.9.convtr.norm": "decoder.layers.9.norm",
"decoder.model.10.block.1.conv.norm": "decoder.layers.10.block.1.norm",
"decoder.model.10.block.3.conv.norm": "decoder.layers.10.block.3.norm",
"decoder.model.10.shortcut.conv.norm": "decoder.layers.10.shortcut.norm",
"decoder.model.12.convtr.norm": "decoder.layers.12.norm",
"decoder.model.13.block.1.conv.norm": "decoder.layers.13.block.1.norm",
"decoder.model.13.block.3.conv.norm": "decoder.layers.13.block.3.norm",
"decoder.model.13.shortcut.conv.norm": "decoder.layers.13.shortcut.norm",
"decoder.model.15.conv.norm": "decoder.layers.15.norm",
}
MAPPING_24K = {
**MAPPING_QUANTIZER,
**MAPPING_ENCODER,
**MAPPING_DECODER,
}
MAPPING_48K = {
**MAPPING_QUANTIZER,
**MAPPING_ENCODER,
**MAPPING_ENCODER_48K,
**MAPPING_DECODER,
**MAPPING_DECODER_48K,
}
TOP_LEVEL_KEYS = []
IGNORE_KEYS = []
def set_recursively(hf_pointer, key, value, full_name, weight_type):
for attribute in key.split("."):
hf_pointer = getattr(hf_pointer, attribute)
if weight_type is not None:
hf_shape = getattr(hf_pointer, weight_type).shape
else:
hf_shape = hf_pointer.shape
if hf_shape != value.shape:
raise ValueError(
f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
f" {value.shape} for {full_name}"
)
if weight_type == "weight":
hf_pointer.weight.data = value
elif weight_type == "weight_g":
hf_pointer.weight_g.data = value
elif weight_type == "weight_v":
hf_pointer.weight_v.data = value
elif weight_type == "bias":
hf_pointer.bias.data = value
elif weight_type == "running_mean":
hf_pointer.running_mean.data = value
elif weight_type == "running_var":
hf_pointer.running_var.data = value
elif weight_type == "num_batches_tracked":
hf_pointer.num_batches_tracked.data = value
elif weight_type == "weight_ih_l0":
hf_pointer.weight_ih_l0.data = value
elif weight_type == "weight_hh_l0":
hf_pointer.weight_hh_l0.data = value
elif weight_type == "bias_ih_l0":
hf_pointer.bias_ih_l0.data = value
elif weight_type == "bias_hh_l0":
hf_pointer.bias_hh_l0.data = value
elif weight_type == "weight_ih_l1":
hf_pointer.weight_ih_l1.data = value
elif weight_type == "weight_hh_l1":
hf_pointer.weight_hh_l1.data = value
elif weight_type == "bias_ih_l1":
hf_pointer.bias_ih_l1.data = value
elif weight_type == "bias_hh_l1":
hf_pointer.bias_hh_l1.data = value
else:
hf_pointer.data = value
logger.info(f"{key + ('.' + weight_type if weight_type is not None else '')} was initialized from {full_name}.")
def should_ignore(name, ignore_keys):
for key in ignore_keys:
if key.endswith(".*"):
if name.startswith(key[:-1]):
return True
elif ".*." in key:
prefix, suffix = key.split(".*.")
if prefix in name and suffix in name:
return True
elif key in name:
return True
return False
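`should_ignore` accepts three pattern shapes: a trailing `.*` prefix wildcard, an infix `.*.` wildcard, and a plain substring. A quick check of each, assuming the function above is in scope (the patterns are illustrative; the script's actual IGNORE_KEYS list is empty):
```
print(should_ignore("quantizer.vq.layers.0.foo", ["quantizer.vq.*"]))  # True  (prefix wildcard)
print(should_ignore("encoder.model.3.conv.bias", ["model.*.bias"]))    # True  (infix wildcard)
print(should_ignore("decoder.lstm.weight_ih_l0", ["lstm"]))            # True  (plain substring)
print(should_ignore("decoder.conv.weight", ["lstm"]))                  # False
```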
def recursively_load_weights(orig_dict, hf_model, model_name):
unused_weights = []
if model_name == "encodec_24khz" or "encodec_32khz":
MAPPING = MAPPING_24K
elif model_name == "encodec_48khz":
MAPPING = MAPPING_48K
else:
raise ValueError(f"Unsupported model: {model_name}")
for name, value in orig_dict.items():
if should_ignore(name, IGNORE_KEYS):
logger.info(f"{name} was ignored")
continue
is_used = False
for key, mapped_key in MAPPING.items():
if "*" in key:
prefix, suffix = key.split(".*.")
if prefix in name and suffix in name:
key = suffix
if key in name:
if key.endswith("embed") and name.endswith("embed_avg"):
continue
is_used = True
if "*" in mapped_key:
layer_index = name.split(key)[0].split(".")[-2]
mapped_key = mapped_key.replace("*", layer_index)
if "weight_g" in name:
weight_type = "weight_g"
elif "weight_v" in name:
weight_type = "weight_v"
elif "weight_ih_l0" in name:
weight_type = "weight_ih_l0"
elif "weight_hh_l0" in name:
weight_type = "weight_hh_l0"
elif "bias_ih_l0" in name:
weight_type = "bias_ih_l0"
elif "bias_hh_l0" in name:
weight_type = "bias_hh_l0"
elif "weight_ih_l1" in name:
weight_type = "weight_ih_l1"
elif "weight_hh_l1" in name:
weight_type = "weight_hh_l1"
elif "bias_ih_l1" in name:
weight_type = "bias_ih_l1"
elif "bias_hh_l1" in name:
weight_type = "bias_hh_l1"
elif "bias" in name:
weight_type = "bias"
elif "weight" in name:
weight_type = "weight"
elif "running_mean" in name:
weight_type = "running_mean"
elif "running_var" in name:
weight_type = "running_var"
elif "num_batches_tracked" in name:
weight_type = "num_batches_tracked"
else:
weight_type = None
set_recursively(hf_model, mapped_key, value, name, weight_type)
continue
if not is_used:
unused_weights.append(name)
logger.warning(f"Unused weights: {unused_weights}")
def convert_checkpoint(
model_name,
checkpoint_path,
pytorch_dump_folder_path,
config_path=None,
repo_id=None,
):
"""
Copy/paste/tweak model's weights to transformers design.
"""
if config_path is not None:
config = EncodecConfig.from_pretrained(config_path)
else:
config = EncodecConfig()
if model_name == "encodec_24khz":
pass  # the default EncodecConfig already matches the 24 kHz checkpoint
elif model_name == "encodec_32khz":
config.upsampling_ratios = [8, 5, 4, 4]
config.target_bandwidths = [2.2]
config.num_filters = 64
config.sampling_rate = 32_000
config.codebook_size = 2048
config.use_causal_conv = False
config.normalize = False
config.use_conv_shortcut = False
elif model_name == "encodec_48khz":
config.upsampling_ratios = [8, 5, 4, 2]
config.target_bandwidths = [3.0, 6.0, 12.0, 24.0]
config.sampling_rate = 48_000
config.audio_channels = 2
config.use_causal_conv = False
config.norm_type = "time_group_norm"
config.normalize = True
config.chunk_length_s = 1.0
config.overlap = 0.01
else:
raise ValueError(f"Unknown model name: {model_name}")
model = EncodecModel(config)
feature_extractor = EncodecFeatureExtractor(
feature_size=config.audio_channels,
sampling_rate=config.sampling_rate,
chunk_length_s=config.chunk_length_s,
overlap=config.overlap,
)
feature_extractor.save_pretrained(pytorch_dump_folder_path)
original_checkpoint = torch.load(checkpoint_path)
if "best_state" in original_checkpoint:
original_checkpoint = original_checkpoint["best_state"]
recursively_load_weights(original_checkpoint, model, model_name)
model.save_pretrained(pytorch_dump_folder_path)
if repo_id:
print("Pushing to the hub...")
feature_extractor.push_to_hub(repo_id)
model.push_to_hub(repo_id)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
default="encodec_24khz",
type=str,
help="The model to convert. Should be one of 'encodec_24khz', 'encodec_32khz', 'encodec_48khz'.",
)
parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
parser.add_argument(
"--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model."
)
parser.add_argument(
"--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
)
args = parser.parse_args()
convert_checkpoint(
args.model,
args.checkpoint_path,
args.pytorch_dump_folder_path,
args.config_path,
args.push_to_hub,
)