Transformers Source Code Analysis (75)
.\models\mistral\modeling_mistral.py
""" PyTorch Mistral model. """
import inspect
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_mistral import MistralConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MistralConfig"
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
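# Illustrative example (not part of the original file): the unpadding metadata produced by
# _get_unpad_data for a small left-padded batch, where 0 marks padding in the attention mask.
demo_mask = torch.tensor([[0, 1, 1],
                          [1, 1, 1]])
indices, cu_seqlens, max_len = _get_unpad_data(demo_mask)
print(indices)     # tensor([1, 2, 3, 4, 5])              -> flat positions of the real tokens
print(cu_seqlens)  # tensor([0, 2, 5], dtype=torch.int32) -> cumulative sequence lengths
print(max_len)     # 3                                    -> longest sequence in the batch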
class MistralRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
MistralRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
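# Quick numerical sanity check (illustrative, not part of the original file): RMSNorm divides by
# the root mean square of the activations and applies a learned scale, without centering the mean.
rms_norm = MistralRMSNorm(hidden_size=4)
demo_x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
manual = demo_x / torch.sqrt(demo_x.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(rms_norm(demo_x), manual, atol=1e-5)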
class MistralRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None):
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
# Unsqueezing cos and sin tensors along the specified dimension to match q and k tensor shapes
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
# Applying rotary position embedding to q and k tensors
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
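# Minimal usage sketch (illustrative, not part of the original file): the rotary embedding module
# produces cos/sin tables that apply_rotary_pos_emb broadcasts onto the query and key tensors.
head_dim, seq_len = 64, 8
rotary = MistralRotaryEmbedding(head_dim, max_position_embeddings=32)
q = torch.randn(1, 4, seq_len, head_dim)  # (batch, num_heads, seq_len, head_dim)
k = torch.randn(1, 2, seq_len, head_dim)  # fewer KV heads is fine, RoPE acts per head
cos, sin = rotary(q, seq_len=seq_len)
position_ids = torch.arange(seq_len).unsqueeze(0)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape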
class MistralMLP(nn.Module):
# MistralMLP: the gated feed-forward (MLP) block used in every decoder layer.
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size  # hidden size from the config
self.intermediate_size = config.intermediate_size  # intermediate (FFN) size from the config
# Gate projection: hidden_size -> intermediate_size, no bias
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
# Up projection: hidden_size -> intermediate_size, no bias
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
# Down projection: intermediate_size -> hidden_size, no bias
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
# Activation function selected from the config's hidden_act
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
# Gated FFN: down_proj(act(gate_proj(x)) * up_proj(x))
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
# Repeat the KV heads along dim 1 n_rep times so they match the number of query heads
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
# Insert a new axis, expand it n_rep times, then fold it into the head dimension
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
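# Sanity check (illustrative, not part of the original file): repeat_kv matches
# torch.repeat_interleave along the head dimension, which is how grouped-query attention
# expands num_key_value_heads KV heads to num_attention_heads before the attention product.
kv = torch.randn(2, 2, 5, 4)       # (batch, num_key_value_heads, seq_len, head_dim)
expanded = repeat_kv(kv, n_rep=4)  # -> (2, 8, 5, 4)
assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=4, dim=1))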
class MistralAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
and "Generating Long Sequences with Sparse Transformers".
"""
# Constructor for a Mistral attention layer
def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
# Warn when no layer index is passed: the forward call will fail if caching is used
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads  # per-head dimension
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads  # query heads per KV head
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.attention_dropout = config.attention_dropout
# hidden_size must be divisible by the number of attention heads
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
# Query projection: hidden_size -> num_heads * head_dim
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
# Key projection: hidden_size -> num_key_value_heads * head_dim
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
# Value projection: hidden_size -> num_key_value_heads * head_dim
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
# Output projection: num_heads * head_dim -> hidden_size
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
# Rotary embedding used to inject positional information into queries and keys
self.rotary_emb = MistralRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
# Reshape helper: (bsz, seq_len, hidden) -> (bsz, num_heads, seq_len, head_dim)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
# Forward pass of the Mistral attention layer
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
# MistralFlashAttention2 subclasses MistralAttention; the weights are unchanged and only the
# forward pass differs, calling the flash-attention kernels.
class MistralFlashAttention2(MistralAttention):
"""
Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: remove this once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 produces a top-left aligned causal mask, while we need a bottom-right alignment,
# which is the default for flash_attn>=2.1. This attribute handles the difference.
# Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for q_seqlen == 1)
# produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
# Forward pass
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
# Forward logic of the module: hidden_states is required; attention_mask, position_ids,
# past_key_value, output_attentions and use_cache are optional, and **kwargs accepts any
# extra keyword arguments.
# Private helper _flash_attention_forward
def _flash_attention_forward(
self,
query_states,
key_states,
value_states,
attention_mask,
query_length,
dropout=0.0,
softmax_scale=None,
use_sliding_windows=False,
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
# If the KV sequence length differs from the attention mask length, trim the mask
if kv_seq_len != attention_mask.shape[-1]:
attention_mask_num_tokens = attention_mask.shape[-1]
attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
# Indices of non-padding tokens, cumulative sequence lengths and the max length in the batch
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
# Gather the key and value layers so they line up with the unpadded tokens
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
if query_length == kv_seq_len:
# Prefill: query and KV lengths match, so reuse the same indices and lengths
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
# Decode step: query length is 1, so every sequence contributes exactly one query token
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
)  # There is a memcpy here, which is quite inefficient.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# Otherwise assume left padding, trim the attention mask and unpad the query layer
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
# Return the unpadded query/key/value layers plus the index and length metadata
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
# TODO @Arthur no longer copied from LLama after static cache
class MistralSdpaAttention(MistralAttention):
"""
Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module
inherits from `MistralAttention`; the weights stay untouched and the only change is in the
forward pass to adapt to the SDPA API.
"""
# Adapted from MistralAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
):
"""
Runs the attention computation using torch.nn.functional.scaled_dot_product_attention.
Args:
hidden_states (torch.Tensor): input hidden states.
attention_mask (Optional[torch.Tensor], optional): attention mask, defaults to None.
position_ids (Optional[torch.LongTensor], optional): position ids, defaults to None.
past_key_value (Optional[Cache], optional): cached key/value states, defaults to None.
output_attentions (bool, optional): whether to return attention weights, defaults to False.
use_cache (bool, optional): whether to use the KV cache, defaults to False.
Returns:
The attention output; the concrete body is elided in this walkthrough.
"""
MISTRAL_ATTENTION_CLASSES = {
"eager": MistralAttention,
"flash_attention_2": MistralFlashAttention2,
"sdpa": MistralSdpaAttention,
}
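# Illustrative only: users select one of the implementations above through the model config's
# `_attn_implementation`, typically via the `attn_implementation` argument of `from_pretrained`
# in recent transformers releases. The checkpoint id below is just a placeholder.
#
# from transformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(
#     "mistralai/Mistral-7B-v0.1",              # placeholder checkpoint
#     torch_dtype=torch.bfloat16,
#     attn_implementation="flash_attention_2",  # or "eager" / "sdpa"
# )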
class MistralDecoderLayer(nn.Module):
def __init__(self, config: MistralConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
# Self-attention, picking the implementation class from the config
self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
# Feed-forward (MLP) block
self.mlp = MistralMLP(config)
# RMSNorm applied to the layer input
self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# RMSNorm applied after attention, before the MLP
self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
# Warn if the deprecated `padding_mask` argument is passed; it is removed in v4.37 in favor of `attention_mask`
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
)
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)`,
where padding elements are indicated by 0.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
# Input layernorm
hidden_states = self.input_layernorm(hidden_states)
# Self attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
# Residual connection
hidden_states = residual + hidden_states
# Fully-connected block with post-attention layernorm
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
# Residual connection
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
# Long docstring describing how MistralPreTrainedModel inherits from PreTrainedModel and how to use it
MISTRAL_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
and behavior.
Parameters:
config ([`MistralConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# MistralPreTrainedModel outputs raw hidden states without any specific head on top
@add_start_docstrings(
"The bare Mistral Model outputting raw hidden-states without any specific head on top.",
MISTRAL_START_DOCSTRING,
)
class MistralPreTrainedModel(PreTrainedModel):
config_class = MistralConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
# Modules that must not be split across devices
_no_split_modules = ["MistralDecoderLayer"]
# Keys whose device placement is skipped
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
# Initialize weights depending on the module type
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
MISTRAL_INPUTS_DOCSTRING = r"""
"""
# MistralModel is a Transformer decoder made of num_hidden_layers MistralDecoderLayer blocks
@add_start_docstrings(
"The bare Mistral Model outputting raw hidden-states without any specific head on top.",
MISTRAL_START_DOCSTRING,
)
class MistralModel(MistralPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`]
Args:
config: MistralConfig
"""
def __init__(self, config: MistralConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
# Token embedding table with the configured vocab size, hidden size and padding index
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
# Stack of num_hidden_layers decoder layers
self.layers = nn.ModuleList(
[MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self._attn_implementation = config._attn_implementation
# Final RMSNorm over the hidden states
self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Gradient checkpointing is off by default
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
# Return the model's input embeddings
def get_input_embeddings(self):
return self.embed_tokens
# Set the model's input embeddings
def set_input_embeddings(self, value):
self.embed_tokens = value
# MISTRAL_INPUTS_DOCSTRING is attached to the forward method via the decorator
def forward(
self,
input_ids: torch.LongTensor = None,  # input token ids
attention_mask: Optional[torch.Tensor] = None,  # attention mask
position_ids: Optional[torch.LongTensor] = None,  # position ids
past_key_values: Optional[List[torch.FloatTensor]] = None,  # cached key/value states
inputs_embeds: Optional[torch.FloatTensor] = None,  # precomputed input embeddings
use_cache: Optional[bool] = None,  # whether to return a KV cache
output_attentions: Optional[bool] = None,  # whether to return attention weights
output_hidden_states: Optional[bool] = None,  # whether to return all hidden states
return_dict: Optional[bool] = None,  # whether to return a ModelOutput dict
class MistralForCausalLM(MistralPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
# Backbone decoder built from MistralModel
self.model = MistralModel(config)
self.vocab_size = config.vocab_size
# LM head mapping hidden states to vocabulary logits, no bias
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
# Return the model's input embedding layer
return self.model.embed_tokens
def set_input_embeddings(self, value):
# Set the model's input embedding layer
self.model.embed_tokens = value
def get_output_embeddings(self):
# Return the LM head (output embeddings)
return self.lm_head
def set_output_embeddings(self, new_embeddings):
# Set the LM head (output embeddings)
self.lm_head = new_embeddings
def set_decoder(self, decoder):
# Set the decoder backbone
self.model = decoder
def get_decoder(self):
# Get the decoder backbone
return self.model
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
# Forward pass of the causal LM; the full body is elided in this walkthrough (see the decorators above)
pass
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
# Prepare the inputs for generation: trim input_ids against the cache and build position_ids on the fly
# If past_key_values is provided, work out how much of input_ids has already been processed
if past_key_values is not None:
# Cache objects expose the cached length, the number of seen tokens and the maximum cache length
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()  # length of the cached sequence
past_length = past_key_values.seen_tokens  # number of tokens already processed
max_cache_length = past_key_values.get_max_length()  # maximum cache length
else:
# Legacy tuple cache: take the sequence length from the first layer's key tensor
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Keep only the unprocessed tokens:
# 1 - If the attention_mask is longer than input_ids, some inputs were passed exclusively as part of the cache
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If past_length is smaller than input_ids.shape[1], input_ids holds all tokens; drop the already-cached prefix
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), assume input_ids holds only unprocessed tokens.
# If we are about to exceed the maximum cache length, crop the input attention mask accordingly.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
# Create position_ids on the fly for batch generation when only an attention_mask is given
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
position_ids = attention_mask.long().cumsum(-1) - 1  # running count of non-padding tokens
position_ids.masked_fill_(attention_mask == 0, 1)  # padding positions get a dummy value of 1
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]  # keep only the positions for the new tokens
# If inputs_embeds are passed, only use them in the first generation step (no cache yet)
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
# Add position_ids, past_key_values, use_cache and attention_mask to the model inputs
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs  # final dictionary of model inputs
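# Toy illustration (not part of the original file) of the cumsum trick above: positions are the
# running count of non-padding tokens, and padded slots get the dummy value 1.
demo_attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                    [1, 1, 1, 1, 1]])
demo_position_ids = demo_attention_mask.long().cumsum(-1) - 1
demo_position_ids.masked_fill_(demo_attention_mask == 0, 1)
print(demo_position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])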
@staticmethod
# Reorder the cached key/value states to follow the selected beams during beam search
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
# Walk over the cache of every layer
for layer_past in past_key_values:
reordered_past += (
# Select the batch entries given by beam_idx, moved to each state's device
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
# Return the reordered cache
return reordered_past
# Mistral model with a sequence-classification head (a linear layer) on top.
# Like other causal models (e.g. GPT-2), it classifies using the last token. If `pad_token_id` is set
# in the config, the last non-padding token of each row is used; otherwise the last token of each row
# is used. The same fallback applies when `inputs_embeds` are passed instead of `input_ids`, since
# padding tokens cannot be inferred from embeddings.
@add_start_docstrings(
"""
The Mistral Model transformer with a sequence classification head on top (linear layer).
[`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-2) do.
Since it does classification on the last token, it requires to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
""",
MISTRAL_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral
class MistralForSequenceClassification(MistralPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = MistralModel(config)
# Classification head: a bias-free linear layer mapping hidden states to num_labels scores
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
# Return the model's input embedding layer
return self.model.embed_tokens
def set_input_embeddings(self, value):
# Set the model's input embedding layer
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
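# Sketch (illustrative, not part of the original file) of the "last non-padding token" selection
# described in the docstring above; the pad_token_id and inputs are made-up values.
demo_pad_token_id = 0
demo_input_ids = torch.tensor([[5, 6, 7, 0, 0],
                               [3, 4, 5, 6, 7]])
# Index of the last non-padding token per row (rows with no padding wrap around to the last index)
sequence_lengths = (torch.eq(demo_input_ids, demo_pad_token_id).int().argmax(-1) - 1) % demo_input_ids.shape[-1]
print(sequence_lengths)  # tensor([2, 4])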
.\models\mistral\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_torch_available
_import_structure = {
"configuration_mistral": ["MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP", "MistralConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_mistral"] = [
"MistralForCausalLM",
"MistralModel",
"MistralPreTrainedModel",
"MistralForSequenceClassification",
]
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_mistral"] = [
"FlaxMistralForCausalLM",
"FlaxMistralModel",
"FlaxMistralPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_mistral import MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP, MistralConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_mistral import (
MistralForCausalLM,
MistralForSequenceClassification,
MistralModel,
MistralPreTrainedModel,
)
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_mistral import (
FlaxMistralForCausalLM,
FlaxMistralModel,
FlaxMistralPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\mixtral\configuration_mixtral.py
""" Mixtral model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"mistral-ai/Mixtral-8x7B": "https://huggingface.co/mistral-ai/Mixtral-8x7B/resolve/main/config.json",
}
class MixtralConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MixtralModel`]. It is used to instantiate a
Mixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Mixtral-7B-v0.1 or Mixtral-7B-Instruct-v0.1.
[mixtralai/Mixtral-8x7B](https://huggingface.co/mixtralai/Mixtral-8x7B)
[mixtralai/Mixtral-7B-Instruct-v0.1](https://huggingface.co/mixtralai/Mixtral-7B-Instruct-v0.1)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
```
>>> from transformers import MixtralModel, MixtralConfig
>>> # Initializing a Mixtral 7B style configuration
>>> configuration = MixtralConfig()
>>> # Initializing a model from the Mixtral 7B style configuration
>>> model = MixtralModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "mixtral"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=1e6,
sliding_window=None,
attention_dropout=0.0,
num_experts_per_tok=2,
num_local_experts=8,
output_router_logits=False,
router_aux_loss_coef=0.001,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.num_experts_per_tok = num_experts_per_tok
self.num_local_experts = num_local_experts
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
.\models\mixtral\convert_mixtral_weights_to_hf.py
import argparse
import json
import os
import torch
from transformers import (
MixtralConfig,
MixtralForCausalLM,
)
"""
Sample usage:
python src/transformers/models/mixtral/convert_mixtral_weights_to_hf.py \
--input_dir /path/to/downloaded/mixtral/weights --model_size 7B --output_dir /output/path
Thereafter, models can be loaded via:
from transformers import MixtralForCausalLM
model = MixtralForCausalLM.from_pretrained("/output/path")
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest
versions are split into several checkpoints, each only contains a part of the weights, so all must be loaded in RAM).
"""
def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
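# Quick check (illustrative, not part of the original script): for a 4096-wide model the helper
# rounds int(8 * 4096 / 3) = 10922 up to the next multiple of 256.
assert compute_intermediate_size(4096) == 11008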
def read_json(path):
with open(path, "r") as f:
return json.load(f)
def write_json(text, path):
with open(path, "w") as f:
json.dump(text, f)
def write_model(model_path, input_base_path, model_size, safe_serialization=True):
os.makedirs(model_path, exist_ok=True)
params = read_json(os.path.join(input_base_path, "params.json"))
num_shards = 1
sliding_window = int(params["sliding_window"]) if "sliding_window" in params else None
n_layers = params["num_hidden_layers"]
n_heads = params["num_attention_heads"]
n_heads_per_shard = n_heads // num_shards
dim = params["hidden_size"]
dims_per_head = dim // n_heads
base = params.get("rope_theta", 10000.0)
max_position_embeddings = 4096 * 8
num_local_experts = params["num_local_experts"]
ffn_dim = params["intermediate_size"]
vocab_size = params["vocab_size"]
if "num_key_value_heads" in params:
num_key_value_heads = params["num_key_value_heads"]
num_local_key_value_heads = num_key_value_heads // num_shards
key_value_dim = dims_per_head * num_local_key_value_heads
else:
num_key_value_heads = n_heads
num_local_key_value_heads = n_heads_per_shard
key_value_dim = dim
def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
loaded = [
torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pt"), map_location="cpu") for i in range(8)
]
merged_state_dict = {}
for state_dict in loaded:
merged_state_dict.update(state_dict)
state_dict = {}
state_dict.update(
{
"model.norm.weight": merged_state_dict["norm.weight"],
"model.embed_tokens.weight": merged_state_dict["tok_embeddings.weight"],
"lm_head.weight": merged_state_dict["output.weight"],
}
)
config = MixtralConfig(
hidden_size=dim,
intermediate_size=ffn_dim,
num_attention_heads=params["num_attention_heads"],
num_hidden_layers=params["num_hidden_layers"],
rms_norm_eps=params["rms_norm_eps"],
num_key_value_heads=num_key_value_heads,
vocab_size=vocab_size,
rope_theta=base,
max_position_embeddings=max_position_embeddings,
sliding_window=sliding_window,
num_local_experts=num_local_experts,
)
print("Loading the checkpoint in a Mixtral model.")
with torch.device("meta"):
model = MixtralForCausalLM(config)
del model.config._name_or_path
model.config.torch_dtype = torch.float16
print("Saving in the Transformers format.")
model.load_state_dict(state_dict, strict=True, assign=True)
for n, p in model.named_parameters():
assert p.device.type != "meta", f"{n} has not been loaded!"
model.save_pretrained(model_path, safe_serialization=safe_serialization)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir",
help="Location of Mixtral weights, which contains tokenizer.model and model folders",
required=True,
)
parser.add_argument(
"--model_size",
choices=["7B"],
help="'f' models correspond to the finetuned versions, and are specific to the Mixtral official release. For more details on Mixtral, checkout the original repo: https://huggingface.co/mistral-ai",
default="7B",
)
parser.add_argument("--output_dir", help="Location to write HF model", required=True)
parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
args = parser.parse_args()
write_model(
model_path=args.output_dir,
input_base_path=args.input_dir,
model_size=args.model_size,
safe_serialization=args.safe_serialization,
)
if __name__ == "__main__":
main()
.\models\mixtral\modeling_mixtral.py
""" PyTorch Mixtral model."""
import inspect
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from ...utils.import_utils import is_torch_fx_available
from .configuration_mixtral import MixtralConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
if is_torch_fx_available():
if not is_torch_greater_or_equal_than_1_13:
import torch.fx
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MixtralConfig"
def load_balancing_loss_func(
gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
) -> float:
r"""
Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
experts is too unbalanced.
Args:
gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
shape [batch_size X sequence_length, num_experts].
attention_mask (`torch.Tensor`, None):
The attention_mask used in forward function
shape [batch_size X sequence_length] if not None.
num_experts (`int`, *optional*):
Number of experts
Returns:
The auxiliary loss.
"""
# If no gate logits were collected (or they are not a tuple), there is nothing to penalize
if gate_logits is None or not isinstance(gate_logits, tuple):
return 0
# Concatenate the per-layer gate logits on a single device
if isinstance(gate_logits, tuple):
compute_device = gate_logits[0].device
concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
# Routing weights are the softmax of the gate logits
routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
# Indices of the top_k experts chosen for each token
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
# One-hot mask of the selected experts
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
if attention_mask is None:
# Without an attention mask, compute the fraction of tokens routed to each expert
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
# Average routing probability per expert
router_prob_per_expert = torch.mean(routing_weights, dim=0)
else:
batch_size, sequence_length = attention_mask.shape
num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
# Expert attention mask, used to ignore padding tokens
expert_attention_mask = (
attention_mask[None, :, :, None, None]
.expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
.reshape(-1, top_k, num_experts)
.to(compute_device)
)
# Fraction of tokens routed to each expert, taking the attention mask into account
tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
expert_attention_mask, dim=0
)
# Attention mask for the router probabilities, used to ignore padding tokens
router_per_expert_attention_mask = (
attention_mask[None, :, :, None]
.expand((num_hidden_layers, batch_size, sequence_length, num_experts))
.reshape(-1, num_experts)
.to(compute_device)
)
# Average routing probability per expert, taking the attention mask into account
router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
router_per_expert_attention_mask, dim=0
)
# Multiply the token fraction per expert by the router probability per expert and sum over experts
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
# Scale by the number of experts to obtain the final loss
return overall_loss * num_experts
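# Toy invocation (illustrative, not part of the original file): gate logits from 2 layers for
# 4 tokens routed over 8 experts with top-2 routing.
demo_gate_logits = tuple(torch.randn(4, 8) for _ in range(2))
demo_aux_loss = load_balancing_loss_func(demo_gate_logits, num_experts=8, top_k=2)
print(demo_aux_loss)  # scalar tensor; equals top_k (here 2.0) when routing is perfectly balanced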
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
# Compute the indices of non-padding tokens, the cumulative sequence lengths and the max sequence length in the batch
def _get_unpad_data(attention_mask):
# Number of real (non-padding) tokens in each sequence
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
# Flattened indices of all non-padding positions
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
# Longest sequence in the batch
max_seqlen_in_batch = seqlens_in_batch.max().item()
# Cumulative sequence lengths, padded with a leading zero
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral
# MixtralRMSNorm mirrors T5LayerNorm: RMS normalization with a learned scale
class MixtralRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
MixtralRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
# Learnable scale parameter
self.weight = nn.Parameter(torch.ones(hidden_size))
# Epsilon added to the variance for numerical stability
self.variance_epsilon = eps
def forward(self, hidden_states):
# Remember the input dtype
input_dtype = hidden_states.dtype
# Compute in float32 for stability
hidden_states = hidden_states.to(torch.float32)
# Mean of the squared activations over the last dimension
variance = hidden_states.pow(2).mean(-1, keepdim=True)
# Normalize by the root mean square
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# Scale and cast back to the original dtype
return self.weight * hidden_states.to(input_dtype)
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
# MixtralRotaryEmbedding precomputes the rotary cos/sin tables used by self-attention
class MixtralRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
# Store the head dimension, maximum positions and base frequency
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# Inverse frequencies used to build the sin/cos tables
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
# Register as a (non-persistent) buffer for later use
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Pre-build the cos/sin cache up to max_position_embeddings
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
# Remember how many positions are cached
self.max_seq_len_cached = seq_len
# Position indices 0..seq_len-1
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
# Outer product of positions and inverse frequencies gives the rotation angles
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# Grow the cos/sin cache if the requested sequence length exceeds what is cached
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
# Return the cached cos/sin values for the first seq_len positions
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
# First half of the hidden dimension
x1 = x[..., : x.shape[-1] // 2]
# Second half of the hidden dimension
x2 = x[..., x.shape[-1] // 2 :]
# Negate the second half and concatenate it in front of the first half
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
# Unsqueezes cos and sin tensors along unsqueeze_dim to match dimensions of q and k
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
# Apply rotary position embedding to q and k tensors
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
# Extract dimensions from hidden_states tensor
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
# If n_rep is 1, return the original hidden_states tensor
if n_rep == 1:
return hidden_states
# Expand hidden_states tensor to repeat along the specified dimension
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
# Reshape expanded tensor to the desired shape
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
class MixtralAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
and "Generating Long Sequences with Sparse Transformers".
"""
# Constructor, taking the config and an optional layer index
def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
# Warn when no layer index is passed: the forward call will fail if caching is used
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads  # per-head dimension
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads  # query heads per KV head
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.attention_dropout = config.attention_dropout
# hidden_size must be divisible by the number of attention heads
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
# Query projection
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
# Key projection
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
# Value projection
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
# Output projection
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
# Rotary embedding
self.rotary_emb = MixtralRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
# Reshape helper: (bsz, seq_len, hidden) -> (bsz, num_heads, seq_len, head_dim)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
# Forward pass, taking hidden states, attention mask, position ids and the KV cache
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
class MixtralFlashAttention2(MixtralAttention):
"""
Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: remove this once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 produces a top-left aligned causal mask, while we need a bottom-right alignment,
# which is the default for flash_attn>=2.1. This attribute handles the difference.
# Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for q_seqlen == 1)
# produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
"""
Override of the forward method to integrate Mixtral flash attention with handling of padding tokens.
"""
# The actual forward pass (elided in this walkthrough) calls the flash-attention kernels and handles padding tokens
pass
def _flash_attention_forward(
self,
query_states,
key_states,
value_states,
attention_mask,
query_length,
dropout=0.0,
softmax_scale=None,
use_sliding_windows=False,
# _upad_input removes padding from the query/key/value layers before the varlen flash-attention call
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
# If the KV sequence length differs from the attention mask length, trim the mask
if kv_seq_len != attention_mask.shape[-1]:
attention_mask_num_tokens = attention_mask.shape[-1]
attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
# Indices of non-padding tokens, cumulative sequence lengths and the max length in the batch
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
# Gather the key layer at the non-padding indices, flattening batch and sequence dims
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
# Same reorganization for the value layer
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
# The query layer is handled differently depending on query_length
if query_length == kv_seq_len:
# Prefill: query and KV lengths match, so reuse the same indices and lengths
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
# Decode step: query length is 1, so every sequence contributes exactly one query token
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
)  # There is a memcpy here, which is quite inefficient.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# Otherwise assume left padding, trim the attention mask and unpad the query layer
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
# Return the unpadded query/key/value layers plus the index and length metadata
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral
class MixtralSdpaAttention(MixtralAttention):
"""
Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module
inherits from `MixtralAttention`; the weights stay untouched and the only change is in the
forward pass to adapt to the SDPA API.
"""
# Adapted from MixtralAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
):
pass  # The actual SDPA implementation is elided in this walkthrough
# Mapping from attention implementation name to the corresponding Mixtral attention class
MIXTRAL_ATTENTION_CLASSES = {
"eager": MixtralAttention,
"flash_attention_2": MixtralFlashAttention2,
"sdpa": MixtralSdpaAttention,
}
class MixtralBlockSparseTop2MLP(nn.Module):
def __init__(self, config: MixtralConfig):
super().__init__()
self.ffn_dim = config.intermediate_size
self.hidden_dim = config.hidden_size
# Expert feed-forward projections (w1 and w3 up-project, w2 down-projects)
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
# Activation function selected from the ACT2FN mapping
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states):
# Gated FFN: w2(act(w1(x)) * w3(x))
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
current_hidden_states = self.w2(current_hidden_states)
return current_hidden_states
# MixtralBLockSparseTop2MLP (misspelled name) is deprecated in favor of MixtralBlockSparseTop2MLP and warns once
class MixtralBLockSparseTop2MLP(MixtralBlockSparseTop2MLP):
def __init__(self, *args, **kwargs):
logger.warning_once(
"MixtralBLockSparseTop2MLP is deprecated by MixtralBlockSparseTop2MLP and will be removed in v4.40."
)
super().__init__(*args, **kwargs)
class MixtralSparseMoeBlock(nn.Module):
"""
This implementation is strictly equivalent to standard MoE with full capacity (no dropped tokens). It is faster
since it formulates MoE operations in terms of block-sparse operations to accommodate imbalanced assignments of
tokens to experts, whereas standard MoE either (1) drops tokens at the cost of reduced performance, or (2) sets
capacity factor to number of experts and thus wastes computation and memory on padding.
"""
def __init__(self, config):
super().__init__()
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
# gating
# Linear gating layer that produces router logits over the experts
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
# One MixtralBlockSparseTop2MLP expert module per local expert
self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Routes each token to its top-k experts and returns the mixed hidden states together with the router logits.
Args:
hidden_states (torch.Tensor): input hidden states of shape (batch_size, sequence_length, hidden_dim)
Returns:
torch.Tensor: final hidden states of shape (batch_size, sequence_length, hidden_dim)
torch.Tensor: router logits of shape (batch * sequence_length, n_experts)
"""
# Unpack the input tensor's dimensions
batch_size, sequence_length, hidden_dim = hidden_states.shape
# Flatten (batch, seq, hidden) to (batch * seq, hidden) for the router computation
hidden_states = hidden_states.view(-1, hidden_dim)
# Compute the router logits with the gating layer
router_logits = self.gate(hidden_states)
# Softmax-normalize the logits into routing weights
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
# Keep the top-k weights per token and renormalize them
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# Cast the routing weights back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
# Zero-initialized buffer for the final hidden states
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)
# One-hot encode the selected experts to build an expert mask;
# this is used to easily index which tokens each expert should process
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
# Loop over all experts and run each one on the tokens routed to it
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])
if top_x.shape[0] == 0:
continue
# Convert the index tensors to Python lists for faster indexing in PyTorch
top_x_list = top_x.tolist()
idx_list = idx.tolist()
# Gather the hidden states routed to this expert and apply the expert, scaled by its routing weight
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
# Accumulate this expert's output into the final hidden states via index_add_
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
# Reshape back to the original (batch_size, sequence_length, hidden_dim)
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
# Return the final hidden states and the router logits
return final_hidden_states, router_logits
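To see the routing arithmetic in isolation, here is a self-contained sketch of the same top-2 dispatch on a few random tokens; the doubling "expert" is a stand-in for the real `MixtralBlockSparseTop2MLP` modules:

import torch
import torch.nn.functional as F

num_experts, top_k, hidden_dim = 4, 2, 3
tokens = torch.randn(5, hidden_dim)                     # 5 flattened tokens
router_logits = torch.randn(5, num_experts)             # stand-in for self.gate(tokens)

weights = F.softmax(router_logits, dim=1, dtype=torch.float)
weights, selected = torch.topk(weights, top_k, dim=-1)  # top-2 experts per token
weights /= weights.sum(dim=-1, keepdim=True)            # each row sums to 1 again

# one_hot -> (num_experts, top_k, num_tokens): expert_mask[e] marks which
# (slot, token) pairs were routed to expert e
expert_mask = F.one_hot(selected, num_classes=num_experts).permute(2, 1, 0)
out = torch.zeros_like(tokens)
for e in range(num_experts):
    slot, tok = torch.where(expert_mask[e])
    if tok.numel() == 0:
        continue
    expert_out = tokens[tok] * 2.0                      # placeholder for experts[e](...)
    out.index_add_(0, tok, expert_out * weights[tok, slot, None])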
# MixtralDecoderLayer implements a single decoder layer of the Mixtral model
class MixtralDecoderLayer(nn.Module):
# Constructor takes a MixtralConfig and the layer index
def __init__(self, config: MixtralConfig, layer_idx: int):
super().__init__()
# Hidden size of the layer
self.hidden_size = config.hidden_size
# Self-attention; the implementation class is chosen from the config
self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
# Block-sparse mixture-of-experts feed-forward block
self.block_sparse_moe = MixtralSparseMoeBlock(config)
# Normalization applied to the layer input
self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Normalization applied after the attention block
self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Forward pass: takes hidden states, attention mask, position ids, cached key/values, etc.
def forward(
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
output_router_logits: Optional[bool] = False,
use_cache: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if "padding_mask" in kwargs:
# 如果传入了 `padding_mask` 参数,发出警告提示
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
"""
Args:
hidden_states (`torch.FloatTensor`): 输入到层的张量,形状为 `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *可选*): 注意力掩码张量,形状为 `(batch, sequence_length)`,其中填充元素为0
past_key_value (`Tuple(torch.FloatTensor)`, *可选*): 缓存的过去键值投影状态
output_attentions (`bool`, *可选*):
是否返回所有注意力层的注意力张量。详见返回的张量中的 `attentions` 了解更多细节。
output_router_logits (`bool`, *可选*):
是否返回所有路由器的logits。这对计算路由器损失很有用,在推理时不应返回。
use_cache (`bool`, *可选*):
如果设置为 `True`,则返回 `past_key_values` 键值状态,可用于加速解码 (参见 `past_key_values`).
"""
residual = hidden_states
# Pre-attention normalization
hidden_states = self.input_layernorm(hidden_states)
# Self-attention block
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + hidden_states
# Feed-forward: block-sparse mixture-of-experts
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
if output_router_logits:
outputs += (router_logits,)
return outputs
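Both sub-blocks follow the same pre-norm residual wiring: normalize, transform, then add the untouched residual back. A toy module showing just that wiring (LayerNorm and Linear stand in for the real RMSNorm, attention, and MoE modules):

import torch
from torch import nn

class PreNormBlock(nn.Module):
    """Toy pre-norm residual block with the same wiring as MixtralDecoderLayer."""
    def __init__(self, dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)   # stands in for input_layernorm (RMSNorm in Mixtral)
        self.norm2 = nn.LayerNorm(dim)   # stands in for post_attention_layernorm
        self.attn = nn.Linear(dim, dim)  # placeholder for self_attn
        self.mlp = nn.Linear(dim, dim)   # placeholder for block_sparse_moe

    def forward(self, x):
        x = x + self.attn(self.norm1(x))  # residual around the attention block
        x = x + self.mlp(self.norm2(x))   # residual around the feed-forward block
        return x

print(PreNormBlock(8)(torch.randn(2, 4, 8)).shape)  # torch.Size([2, 4, 8])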
# MIXTRAL_START_DOCSTRING is a raw multi-line string used as the docstring preamble for MixtralPreTrainedModel.
# It documents the inheritance from PreTrainedModel, PyTorch usage notes, and the parameter list.
MIXTRAL_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`MixtralConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# add_start_docstrings decorates MixtralPreTrainedModel with a docstring:
# the first argument is a summary of the model's raw hidden-state output,
# the second is MIXTRAL_START_DOCSTRING with the configuration and parameter details.
@add_start_docstrings(
"The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
MIXTRAL_START_DOCSTRING,
)
# MixtralPreTrainedModel inherits from PreTrainedModel and handles weight initialization and pretrained loading for Mixtral.
class MixtralPreTrainedModel(PreTrainedModel):
# Configuration class for the model
config_class = MixtralConfig
# Prefix used for the base model's parameter names
base_model_prefix = "model"
# Gradient checkpointing is supported
supports_gradient_checkpointing = True
# Modules that must not be split across devices
_no_split_modules = ["MixtralDecoderLayer"]
# Keys to skip during device placement
_skip_keys_device_placement = "past_key_values"
# Flash Attention 2 is supported
_supports_flash_attn_2 = True
# SDPA (scaled dot-product attention) is supported
_supports_sdpa = True
# The Cache class is supported
_supports_cache_class = True
# Weight initialization
def _init_weights(self, module):
std = self.config.initializer_range
# Linear layers: normal-init the weights, zero-init the bias
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
# Embedding layers: normal-init the weights, zero the padding row
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
# MIXTRAL_INPUTS_DOCSTRING is a raw multi-line string meant to document MixtralModel's inputs; it is left empty in this excerpt.
MIXTRAL_INPUTS_DOCSTRING = r"""
"""
# add_start_docstrings decorates MixtralModel the same way:
# a summary of the raw hidden-state output plus MIXTRAL_START_DOCSTRING
# with the configuration and parameter details.
@add_start_docstrings(
"The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
MIXTRAL_START_DOCSTRING,
)
# MixtralModel inherits from MixtralPreTrainedModel and implements the Mixtral decoder stack.
class MixtralModel(MixtralPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
Args:
config: MixtralConfig
"""
# Constructor takes a MixtralConfig
def __init__(self, config: MixtralConfig):
# Initialize the parent class with the config
super().__init__(config)
# Padding token id from the config
self.padding_idx = config.pad_token_id
# Vocabulary size from the config
self.vocab_size = config.vocab_size
# Embedding layer that maps input tokens to vectors
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
# Stack of MixtralDecoderLayer modules, one per layer index
self.layers = nn.ModuleList(
[MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
# Attention implementation chosen in the config
self._attn_implementation = config._attn_implementation
# Final RMS normalization
self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Gradient checkpointing is off by default
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
# Return the input embedding layer
def get_input_embeddings(self):
return self.embed_tokens
# Replace the input embedding layer
def set_input_embeddings(self, value):
self.embed_tokens = value
# Ignore copy; docstring decorator for the forward method
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
pass  # forward body omitted in this excerpt
# MixtralForCausalLM builds on MixtralPreTrainedModel for causal language modeling with the mixture-of-experts backbone
class MixtralForCausalLM(MixtralPreTrainedModel):
# Keys whose weights are tied (shared with the input embeddings)
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
# Initialize the parent class with the config
super().__init__(config)
# Underlying MixtralModel backbone
self.model = MixtralModel(config)
# Vocabulary size from the config
self.vocab_size = config.vocab_size
# lm_head projects hidden states to vocabulary logits, without bias
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Coefficient of the router auxiliary (load-balancing) loss
self.router_aux_loss_coef = config.router_aux_loss_coef
# Number of local experts
self.num_experts = config.num_local_experts
# Number of experts activated per token
self.num_experts_per_tok = config.num_experts_per_tok
# Initialize weights and apply final processing
self.post_init()
# Input embeddings live in the underlying MixtralModel
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
# Output embeddings are the lm_head linear layer
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# The decoder is the underlying MixtralModel
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
# Forward pass; returns a MoeCausalLMOutputWithPast
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
# Forward pass; see the decorator docstrings above for parameter details
pass  # body omitted in this excerpt
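The `router_aux_loss_coef` set in `__init__` scales an auxiliary load-balancing loss that the real forward computes from the collected router logits via a helper not shown in this excerpt. As a rough, Switch-Transformers-style sketch of the idea (my own illustration, not the library's exact function):

import torch
import torch.nn.functional as F

def load_balancing_loss_sketch(router_logits, num_experts=8, top_k=2):
    # router_logits: (num_tokens, num_experts), as emitted by MixtralSparseMoeBlock
    probs = F.softmax(router_logits, dim=-1)
    _, selected = torch.topk(probs, top_k, dim=-1)
    expert_mask = F.one_hot(selected, num_experts)              # (tokens, top_k, experts)
    tokens_per_expert = expert_mask.float().mean(dim=(0, 1))    # fraction of routing slots per expert
    router_prob_per_expert = probs.mean(dim=0)                  # mean gate probability per expert
    # balanced routing minimizes this product-of-means penalty
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)

print(load_balancing_loss_sketch(torch.randn(16, 8)))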
# Prepare inputs for generation
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
output_router_logits=False,
**kwargs,
):
# Omit tokens covered by past_key_values
if past_key_values is not None:
# Check whether past_key_values is a Cache instance
if isinstance(past_key_values, Cache):
# Sequence length recorded in the cache
cache_length = past_key_values.get_seq_length()
# Number of tokens the cache has already seen
past_length = past_key_values.seen_tokens
# Maximum capacity of the cache
max_cache_length = past_key_values.get_max_length()
else:
# Legacy format: past_key_values is a tuple; read the length from the tensor shapes
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
# If position_ids were not passed, create them on the fly from the attention mask for batched generation
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
# Fill in the remaining generation parameters
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"output_router_logits": output_router_logits,
}
)
# Return the constructed model_inputs dictionary
return model_inputs
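A worked example of the trimming above: with three tokens already in the cache and one new token appended, only the unprocessed tail of `input_ids` survives, and `position_ids` are rebuilt from the attention mask so padding positions are clamped to 1:

import torch

past_length = 3
input_ids = torch.tensor([[5, 6, 7, 8]])       # 3 cached tokens + 1 new one
attention_mask = torch.tensor([[0, 1, 1, 1]])  # leading pad token

input_ids = input_ids[:, past_length:]                 # tensor([[8]]) -- branch 2 above
position_ids = attention_mask.long().cumsum(-1) - 1    # tensor([[-1, 0, 1, 2]])
position_ids.masked_fill_(attention_mask == 0, 1)      # tensor([[ 1, 0, 1, 2]])
position_ids = position_ids[:, -input_ids.shape[1] :]  # tensor([[2]]) -- only the new token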
# _reorder_cache reorders the cached key/value states for beam search
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
# Tuple collecting the reordered per-layer caches
reordered_past = ()
# Reorder every layer's cache along the batch (beam) dimension
for layer_past in past_key_values:
reordered_past += (
# index_select each past_state with beam_idx, moved to the state's device
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
# Return the reordered cache
return reordered_past
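For example, if beam search decides that both surviving beams should continue from what was originally beam 2, every cached state gets its batch dimension re-indexed:

import torch

past_state = torch.arange(3 * 2 * 4).reshape(3, 2, 4)  # toy per-layer cache, dim 0 = beams
beam_idx = torch.tensor([2, 2, 0])                      # beams 0 and 1 both copy beam 2
reordered = past_state.index_select(0, beam_idx)
assert torch.equal(reordered[0], past_state[2])
assert torch.equal(reordered[1], past_state[2])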
"""
The Mixtral Model transformer with a sequence classification head on top (linear layer).
[`MixtralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-2) do.
Since it does classification on the last token, it requires to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
"""
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL
class MixtralForSequenceClassification(MixtralPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = MixtralModel(config)  # Mixtral backbone
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)  # linear head for classification scores
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens  # return the input embedding layer
def set_input_embeddings(self, value):
self.model.embed_tokens = value  # replace the input embedding layer
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
"""
Forward pass for MixtralForSequenceClassification.
Args:
input_ids (torch.LongTensor, optional): Input token IDs.
attention_mask (torch.Tensor, optional): Mask to avoid performing attention on padding tokens.
position_ids (torch.LongTensor, optional): IDs to mark each token's position.
past_key_values (List[torch.FloatTensor], optional): Cached key/value states for faster decoding.
inputs_embeds (torch.FloatTensor, optional): Precomputed embeddings for the input tokens.
labels (torch.LongTensor, optional): Labels for computing the sequence classification loss.
use_cache (bool, optional): Whether or not to use cached key/value states.
output_attentions (bool, optional): Whether or not to output attentions weights.
output_hidden_states (bool, optional): Whether or not to output hidden states.
return_dict (bool, optional): Whether or not to return a dictionary as the output.
Returns:
Depending on `return_dict`, either a model output dictionary or a tuple of logits and loss.
Notes:
This method defines how inputs are processed through the Mixtral model for sequence classification.
"""
# Forward pass of MixtralForSequenceClassification: runs the backbone, pools the last
# non-padding token, and computes the classification loss when labels are given
pass  # body omitted in this excerpt
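The class docstring above says classification reads the last non-padding token. A minimal sketch of that pooling step, assuming right padding and a configured `pad_token_id` (the real forward also covers the `inputs_embeds` case):

import torch

pad_token_id = 0
input_ids = torch.tensor([[11, 12, 13, 0, 0],
                          [21, 22, 23, 24, 0]])
logits = torch.randn(2, 5, 3)  # (batch, seq_len, num_labels), as from self.score(hidden_states)

# index of the last token that is not padding, per row
sequence_lengths = input_ids.ne(pad_token_id).sum(-1) - 1   # tensor([2, 3])
pooled_logits = logits[torch.arange(2), sequence_lengths]   # (batch, num_labels)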
.\models\mixtral\__init__.py
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
_import_structure = {
"configuration_mixtral": ["MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP", "MixtralConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_mixtral"] = [
"MixtralForCausalLM",
"MixtralModel",
"MixtralPreTrainedModel",
"MixtralForSequenceClassification",
]
if TYPE_CHECKING:
from .configuration_mixtral import MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP, MixtralConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_mixtral import (
MixtralForCausalLM,
MixtralForSequenceClassification,
MixtralModel,
MixtralPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
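The `_LazyModule` substitution above makes importing the package essentially free: nothing torch-dependent is imported until one of the registered names is first looked up. Roughly (assuming transformers and torch are installed):

# Importing the package only registers _import_structure; modeling_mixtral
# (and therefore the torch-heavy code) is not imported yet.
import transformers.models.mixtral as mixtral

# The first attribute lookup triggers the real import of modeling_mixtral.
model_cls = mixtral.MixtralForCausalLM
print(model_cls.__name__)  # MixtralForCausalLM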
.\models\mluke\convert_mluke_original_pytorch_checkpoint_to_pytorch.py
"""Convert mLUKE checkpoint."""
import argparse
import json
import os
from collections import OrderedDict
import torch
from transformers import LukeConfig, LukeForMaskedLM, MLukeTokenizer, XLMRobertaTokenizer
from transformers.tokenization_utils_base import AddedToken
@torch.no_grad()
def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, pytorch_dump_folder_path, model_size):
with open(metadata_path) as metadata_file:
metadata = json.load(metadata_file)
config = LukeConfig(use_entity_aware_attention=True, **metadata["model_config"])
state_dict = torch.load(checkpoint_path, map_location="cpu")["module"]
entity_vocab = load_original_entity_vocab(entity_vocab_path)
entity_vocab["[MASK2]"] = max(entity_vocab.values()) + 1
config.entity_vocab_size += 1
tokenizer = XLMRobertaTokenizer.from_pretrained(metadata["model_config"]["bert_model_name"])
entity_token_1 = AddedToken("<ent>", lstrip=False, rstrip=False)
entity_token_2 = AddedToken("<ent2>", lstrip=False, rstrip=False)
tokenizer.add_special_tokens({"additional_special_tokens": [entity_token_1, entity_token_2]})
config.vocab_size += 2
print(f"Saving tokenizer to {pytorch_dump_folder_path}")
tokenizer.save_pretrained(pytorch_dump_folder_path)
with open(os.path.join(pytorch_dump_folder_path, "tokenizer_config.json"), "r") as f:
tokenizer_config = json.load(f)
tokenizer_config["tokenizer_class"] = "MLukeTokenizer"
with open(os.path.join(pytorch_dump_folder_path, "tokenizer_config.json"), "w") as f:
json.dump(tokenizer_config, f)
with open(os.path.join(pytorch_dump_folder_path, MLukeTokenizer.vocab_files_names["entity_vocab_file"]), "w") as f:
json.dump(entity_vocab, f)
tokenizer = MLukeTokenizer.from_pretrained(pytorch_dump_folder_path)
ent_init_index = tokenizer.convert_tokens_to_ids(["@"])[0]
ent2_init_index = tokenizer.convert_tokens_to_ids(["#"])[0]
word_emb = state_dict["embeddings.word_embeddings.weight"]
ent_emb = word_emb[ent_init_index].unsqueeze(0)
ent2_emb = word_emb[ent2_init_index].unsqueeze(0)
state_dict["embeddings.word_embeddings.weight"] = torch.cat([word_emb, ent_emb, ent2_emb])
for bias_name in ["lm_head.decoder.bias", "lm_head.bias"]:
decoder_bias = state_dict[bias_name]
ent_decoder_bias = decoder_bias[ent_init_index].unsqueeze(0)
ent2_decoder_bias = decoder_bias[ent2_init_index].unsqueeze(0)
state_dict[bias_name] = torch.cat([decoder_bias, ent_decoder_bias, ent2_decoder_bias])
for layer_index in range(config.num_hidden_layers):
for matrix_name in ["query.weight", "query.bias"]:
prefix = f"encoder.layer.{layer_index}.attention.self."
state_dict[prefix + "w2e_" + matrix_name] = state_dict[prefix + matrix_name]
state_dict[prefix + "e2w_" + matrix_name] = state_dict[prefix + matrix_name]
state_dict[prefix + "e2e_" + matrix_name] = state_dict[prefix + matrix_name]
entity_emb = state_dict["entity_embeddings.entity_embeddings.weight"]
entity_mask_emb = entity_emb[entity_vocab["[MASK]"]].unsqueeze(0)
state_dict["entity_embeddings.entity_embeddings.weight"] = torch.cat([entity_emb, entity_mask_emb])
entity_prediction_bias = state_dict["entity_predictions.bias"]
entity_mask_bias = entity_prediction_bias[entity_vocab["[MASK]"]].unsqueeze(0)
state_dict["entity_predictions.bias"] = torch.cat([entity_prediction_bias, entity_mask_bias])
model = LukeForMaskedLM(config=config).eval()
state_dict.pop("entity_predictions.decoder.weight")
state_dict.pop("lm_head.decoder.weight")
state_dict.pop("lm_head.decoder.bias")
state_dict_for_hugging_face = OrderedDict()
for key, value in state_dict.items():
if not (key.startswith("lm_head") or key.startswith("entity_predictions")):
state_dict_for_hugging_face[f"luke.{key}"] = state_dict[key]
else:
state_dict_for_hugging_face[key] = state_dict[key]
missing_keys, unexpected_keys = model.load_state_dict(state_dict_for_hugging_face, strict=False)
if set(unexpected_keys) != {"luke.embeddings.position_ids"}:
raise ValueError(f"Unexpected unexpected_keys: {unexpected_keys}")
if set(missing_keys) != {
"lm_head.decoder.weight",
"lm_head.decoder.bias",
"entity_predictions.decoder.weight",
}:
raise ValueError(f"Unexpected missing_keys: {missing_keys}")
model.tie_weights()
assert (model.luke.embeddings.word_embeddings.weight == model.lm_head.decoder.weight).all()
assert (model.luke.entity_embeddings.entity_embeddings.weight == model.entity_predictions.decoder.weight).all()
tokenizer = MLukeTokenizer.from_pretrained(pytorch_dump_folder_path, task="entity_classification")
text = "ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン (Afghanistan)."
span = (0, 9)
encoding = tokenizer(text, entity_spans=[span], return_tensors="pt")
outputs = model(**encoding)
if model_size == "large":
raise NotImplementedError
else:
expected_shape = torch.Size((1, 33, 768))
expected_slice = torch.tensor([[0.0892, 0.0596, -0.2819], [0.0134, 0.1199, 0.0573], [-0.0169, 0.0927, 0.0644]])
if not (outputs.last_hidden_state.shape == expected_shape):
raise ValueError(
f"Outputs.last_hidden_state.shape is {outputs.last_hidden_state.shape}, Expected shape is {expected_shape}"
)
if not torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4):
raise ValueError
if model_size == "large":
raise NotImplementedError
else:
expected_shape = torch.Size((1, 1, 768))
expected_slice = torch.tensor([[-0.1482, 0.0609, 0.0322]])
if not (outputs.entity_last_hidden_state.shape == expected_shape):
raise ValueError(
f"Outputs.entity_last_hidden_state.shape is {outputs.entity_last_hidden_state.shape}, Expected shape is"
f" {expected_shape}"
)
if not torch.allclose(outputs.entity_last_hidden_state[0, :3, :3], expected_slice, atol=1e-4):
raise ValueError
tokenizer = MLukeTokenizer.from_pretrained(pytorch_dump_folder_path)
text = "Tokyo is the capital of <mask>."
span = (24, 30)
encoding = tokenizer(text, entity_spans=[span], return_tensors="pt")
outputs = model(**encoding)
input_ids = encoding["input_ids"][0].tolist()
mask_position_id = input_ids.index(tokenizer.convert_tokens_to_ids("<mask>"))
predicted_id = outputs.logits[0][mask_position_id].argmax(dim=-1)
assert "Japan" == tokenizer.decode(predicted_id)
predicted_entity_id = outputs.entity_logits[0][0].argmax().item()
multilingual_predicted_entities = [
entity for entity, entity_id in tokenizer.entity_vocab.items() if entity_id == predicted_entity_id
]
assert [e for e in multilingual_predicted_entities if e.startswith("en:")][0] == "en:Japan"
print("Saving PyTorch model to {}".format(pytorch_dump_folder_path))
model.save_pretrained(pytorch_dump_folder_path)
def load_original_entity_vocab(entity_vocab_path):
SPECIAL_TOKENS = ["[MASK]", "[PAD]", "[UNK]"]
data = [json.loads(line) for line in open(entity_vocab_path)]
new_mapping = {}
for entry in data:
entity_id = entry["id"]
for entity_name, language in entry["entities"]:
if entity_name in SPECIAL_TOKENS:
new_mapping[entity_name] = entity_id
break
new_entity_name = f"{language}:{entity_name}"
new_mapping[new_entity_name] = entity_id
return new_mapping
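A toy line from such a JSON-lines entity vocabulary makes the mapping concrete: each language-specific alias gets the shared id, while special tokens like "[MASK]" would keep their bare name:

entry = {"id": 7, "entities": [["Japan", "en"], ["日本", "ja"]]}
new_mapping = {}
for entity_name, language in entry["entities"]:
    new_mapping[f"{language}:{entity_name}"] = entry["id"]
print(new_mapping)  # {'en:Japan': 7, 'ja:日本': 7}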
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, help="Path to a pytorch_model.bin file.")
parser.add_argument(
"--metadata_path", default=None, type=str, help="Path to a metadata.json file, defining the configuration."
)
parser.add_argument(
"--entity_vocab_path",
default=None,
type=str,
help="Path to an entity_vocab.tsv file, containing the entity vocabulary.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to where to dump the output PyTorch model."
)
parser.add_argument(
"--model_size", default="base", type=str, choices=["base", "large"], help="Size of the model to be converted."
)
args = parser.parse_args()
convert_luke_checkpoint(
args.checkpoint_path,
args.metadata_path,
args.entity_vocab_path,
args.pytorch_dump_folder_path,
args.model_size,
)
.\models\mluke\tokenization_mluke.py
import itertools
import json
import os
from collections.abc import Mapping
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import sentencepiece as spm
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import (
ENCODE_KWARGS_DOCSTRING,
AddedToken,
BatchEncoding,
EncodedInput,
PaddingStrategy,
TensorType,
TextInput,
TextInputPair,
TruncationStrategy,
to_py_obj,
)
from ...utils import add_end_docstrings, is_tf_tensor, is_torch_tensor, logging
logger = logging.get_logger(__name__)
EntitySpan = Tuple[int, int]
EntitySpanInput = List[EntitySpan]
Entity = str
EntityInput = List[Entity]
SPIECE_UNDERLINE = "▁"
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "entity_vocab_file": "entity_vocab.json"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"studio-ousia/mluke-base": "https://huggingface.co/studio-ousia/mluke-base/resolve/main/vocab.json",
},
"merges_file": {
"studio-ousia/mluke-base": "https://huggingface.co/studio-ousia/mluke-base/resolve/main/merges.txt",
},
"entity_vocab_file": {
"studio-ousia/mluke-base": "https://huggingface.co/studio-ousia/mluke-base/resolve/main/entity_vocab.json",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"studio-ousia/mluke-base": 512,
}
class MLukeTokenizer(PreTrainedTokenizer):
"""
Adapted from [`XLMRobertaTokenizer`] and [`LukeTokenizer`]. Based on
[SentencePiece](https://github.com/google/sentencepiece).
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
this superclass for more information regarding those methods.
Attributes:
sp_model (`SentencePieceProcessor`):
The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file,
entity_vocab_file,
bos_token="<s>",
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
task=None,
max_entity_length=32,
max_mention_length=30,
entity_token_1="<ent>",
entity_token_2="<ent2>",
entity_unk_token="[UNK]",
entity_pad_token="[PAD]",
entity_mask_token="[MASK]",
entity_mask2_token="[MASK2]",
sp_model_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
):
super().__init__(**kwargs)
@property
def vocab_size(self):
return len(self.sp_model) + self.fairseq_offset + 1
def get_vocab(self):
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
def _tokenize(self, text: str) -> List[str]:
return self.sp_model.encode(text, out_type=str)
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
if token in self.fairseq_tokens_to_ids:
return self.fairseq_tokens_to_ids[token]
spm_id = self.sp_model.PieceToId(token)
return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
if index in self.fairseq_ids_to_tokens:
return self.fairseq_ids_to_tokens[index]
return self.sp_model.IdToPiece(index - self.fairseq_offset)
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
return out_string
def __getstate__(self):
state = self.__dict__.copy()
state["sp_model"] = None
state["sp_model_proto"] = self.sp_model.serialized_model_proto()
return state
def __setstate__(self, d):
self.__dict__ = d
if not hasattr(self, "sp_model_kwargs"):
self.sp_model_kwargs = {}
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
def __call__(
self,
text: Union[TextInput, List[TextInput]],
text_pair: Optional[Union[TextInput, List[TextInput]]] = None,
entity_spans: Optional[Union[EntitySpanInput, List[EntitySpanInput]]] = None,
entity_spans_pair: Optional[Union[EntitySpanInput, List[EntitySpanInput]]] = None,
entities: Optional[Union[EntityInput, List[EntityInput]]] = None,
entities_pair: Optional[Union[EntityInput, List[EntityInput]]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
max_entity_length: Optional[int] = None,
stride: int = 0,
is_split_into_words: Optional[bool] = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
**kwargs,
):
pass  # body omitted in this excerpt
def _encode_plus(
self,
text: Union[TextInput],
text_pair: Optional[Union[TextInput]] = None,
entity_spans: Optional[EntitySpanInput] = None,
entity_spans_pair: Optional[EntitySpanInput] = None,
entities: Optional[EntityInput] = None,
entities_pair: Optional[EntityInput] = None,
add_special_tokens: bool = True,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
max_length: Optional[int] = None,
max_entity_length: Optional[int] = None,
stride: int = 0,
is_split_into_words: Optional[bool] = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
**kwargs,
) -> BatchEncoding:
if return_offsets_mapping:
raise NotImplementedError(
"return_offset_mapping is not available when using Python tokenizers. "
"To use this feature, change your tokenizer to one deriving from "
"transformers.PreTrainedTokenizerFast. "
"More information on available tokenizers at "
"https://github.com/huggingface/transformers/pull/2674"
)
if is_split_into_words:
raise NotImplementedError("is_split_into_words is not supported in this tokenizer.")
(
first_ids,
second_ids,
first_entity_ids,
second_entity_ids,
first_entity_token_spans,
second_entity_token_spans,
) = self._create_input_sequence(
text=text,
text_pair=text_pair,
entities=entities,
entities_pair=entities_pair,
entity_spans=entity_spans,
entity_spans_pair=entity_spans_pair,
**kwargs,
)
return self.prepare_for_model(
first_ids,
pair_ids=second_ids,
entity_ids=first_entity_ids,
pair_entity_ids=second_entity_ids,
entity_token_spans=first_entity_token_spans,
pair_entity_token_spans=second_entity_token_spans,
add_special_tokens=add_special_tokens,
padding=padding_strategy.value,
truncation=truncation_strategy.value,
max_length=max_length,
max_entity_length=max_entity_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors=return_tensors,
prepend_batch_axis=True,
return_attention_mask=return_attention_mask,
return_token_type_ids=return_token_type_ids,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_length=return_length,
verbose=verbose,
)
def _batch_encode_plus(
self,
batch_text_or_text_pairs: Union[List[TextInput], List[TextInputPair]],
batch_entity_spans_or_entity_spans_pairs: Optional[
Union[List[EntitySpanInput], List[Tuple[EntitySpanInput, EntitySpanInput]]]
] = None,
batch_entities_or_entities_pairs: Optional[
Union[List[EntityInput], List[Tuple[EntityInput, EntityInput]]]
] = None,
add_special_tokens: bool = True,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
max_length: Optional[int] = None,
max_entity_length: Optional[int] = None,
stride: int = 0,
is_split_into_words: Optional[bool] = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
**kwargs,
):
pass
def _check_entity_input_format(self, entities: Optional[EntityInput], entity_spans: Optional[EntitySpanInput]):
if not isinstance(entity_spans, list):
raise ValueError("entity_spans should be given as a list")
elif len(entity_spans) > 0 and not isinstance(entity_spans[0], tuple):
raise ValueError(
"entity_spans should be given as a list of tuples containing the start and end character indices"
)
if entities is not None:
if not isinstance(entities, list):
raise ValueError("If you specify entities, they should be given as a list")
if len(entities) > 0 and not isinstance(entities[0], str):
raise ValueError("If you specify entities, they should be given as a list of entity names")
if len(entities) != len(entity_spans):
raise ValueError("If you specify entities, entities and entity_spans must be the same length")
def _create_input_sequence(
self,
text: Union[TextInput],
text_pair: Optional[Union[TextInput]] = None,
entities: Optional[EntityInput] = None,
entities_pair: Optional[EntityInput] = None,
entity_spans: Optional[EntitySpanInput] = None,
entity_spans_pair: Optional[EntitySpanInput] = None,
**kwargs,
):
pass
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
def _batch_prepare_for_model(
self,
batch_ids_pairs: List[Tuple[List[int], None]],
batch_entity_ids_pairs: List[Tuple[Optional[List[int]], Optional[List[int]]]],
batch_entity_token_spans_pairs: List[Tuple[Optional[List[Tuple[int, int]]], Optional[List[Tuple[int, int]]]]],
add_special_tokens: bool = True,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
max_length: Optional[int] = None,
max_entity_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[str] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_length: bool = False,
verbose: bool = True,
) -> BatchEncoding:
"""
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
manages a moving window (with user defined stride) for overflowing tokens
Args:
batch_ids_pairs: list of tokenized input ids or input ids pairs
batch_entity_ids_pairs: list of entity ids or entity ids pairs
batch_entity_token_spans_pairs: list of entity spans or entity spans pairs
max_entity_length: The maximum length of the entity sequence.
"""
batch_outputs = {}
for input_ids, entity_ids, entity_token_span_pairs in zip(
batch_ids_pairs, batch_entity_ids_pairs, batch_entity_token_spans_pairs
):
first_ids, second_ids = input_ids
first_entity_ids, second_entity_ids = entity_ids
first_entity_token_spans, second_entity_token_spans = entity_token_span_pairs
outputs = self.prepare_for_model(
first_ids,
second_ids,
entity_ids=first_entity_ids,
pair_entity_ids=second_entity_ids,
entity_token_spans=first_entity_token_spans,
pair_entity_token_spans=second_entity_token_spans,
add_special_tokens=add_special_tokens,
padding=PaddingStrategy.DO_NOT_PAD.value,
truncation=truncation_strategy.value,
max_length=max_length,
max_entity_length=max_entity_length,
stride=stride,
pad_to_multiple_of=None,
return_attention_mask=False,
return_token_type_ids=return_token_type_ids,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_length=return_length,
return_tensors=None,
prepend_batch_axis=False,
verbose=verbose,
)
for key, value in outputs.items():
if key not in batch_outputs:
batch_outputs[key] = []
batch_outputs[key].append(value)
batch_outputs = self.pad(
batch_outputs,
padding=padding_strategy.value,
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
return batch_outputs
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
def prepare_for_model(
self,
ids: List[int],
pair_ids: Optional[List[int]] = None,
entity_ids: Optional[List[int]] = None,
pair_entity_ids: Optional[List[int]] = None,
entity_token_spans: Optional[List[Tuple[int, int]]] = None,
pair_entity_token_spans: Optional[List[Tuple[int, int]]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
max_entity_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
prepend_batch_axis: bool = False,
**kwargs,
):
pass
def pad(
self,
encoded_inputs: Union[
BatchEncoding,
List[BatchEncoding],
Dict[str, EncodedInput],
Dict[str, List[EncodedInput]],
List[Dict[str, EncodedInput]],
],
padding: Union[bool, str, PaddingStrategy] = True,
max_length: Optional[int] = None,
max_entity_length: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
verbose: bool = True,
):
pass
def _pad(
self,
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
max_length: Optional[int] = None,
max_entity_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
):
pass
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, str]:
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
entity_vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["entity_vocab_file"]
)
with open(entity_vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.entity_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return out_vocab_file, entity_vocab_file
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. An XLM-RoBERTa sequence has the following format:
- single sequence: `<s> X </s>`
- pair of sequences: `<s> A </s></s> B </s>`
Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
if token_ids_1 is None:
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
cls = [self.cls_token_id]
sep = [self.sep_token_id]
return cls + token_ids_0 + sep + sep + token_ids_1 + sep
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)
if token_ids_1 is None:
return [1] + ([0] * len(token_ids_0)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
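For concreteness, the sequence-pair branch above follows the XLM-RoBERTa layout `<s> A </s></s> B </s>`, so the mask marks exactly four special-token positions:

token_ids_0, token_ids_1 = [10, 11], [20]
# mirrors the return statement above for the sequence-pair case
mask = [1] + [0] * len(token_ids_0) + [1, 1] + [0] * len(token_ids_1) + [1]
print(mask)  # [1, 0, 0, 1, 1, 0, 1] -> <s>, A, A, </s>, </s>, B, </s>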
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. XLM-RoBERTa does
not make use of token type ids, therefore a list of zeros is returned.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of zeros.
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
.\models\mluke\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available
_import_structure = {}
try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_mluke"] = ["MLukeTokenizer"]
if TYPE_CHECKING:
try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_mluke import MLukeTokenizer
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\mobilebert\configuration_mobilebert.py
""" MobileBERT model configuration"""
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
logger = logging.get_logger(__name__)
MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google/mobilebert-uncased": "https://huggingface.co/google/mobilebert-uncased/resolve/main/config.json"
}
class MobileBertConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MobileBertModel`] or a [`TFMobileBertModel`]. It
is used to instantiate a MobileBERT model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the MobileBERT
[google/mobilebert-uncased](https://huggingface.co/google/mobilebert-uncased) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Examples:
```
>>> from transformers import MobileBertConfig, MobileBertModel
>>> # Initializing a MobileBERT configuration
>>> configuration = MobileBertConfig()
>>> # Initializing a model (with random weights) from the configuration above
>>> model = MobileBertModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
Attributes: pretrained_config_archive_map (Dict[str, str]): A dictionary containing all the available pre-trained
checkpoints.
"""
pretrained_config_archive_map = MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "mobilebert"
def __init__(
self,
vocab_size=30522,
hidden_size=512,
num_hidden_layers=24,
num_attention_heads=4,
intermediate_size=512,
hidden_act="relu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
embedding_size=128,
trigram_input=True,
use_bottleneck=True,
intra_bottleneck_size=128,
use_bottleneck_attention=False,
key_query_shared_bottleneck=True,
num_feedforward_networks=4,
normalization_type="no_norm",
classifier_activation=True,
classifier_dropout=None,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.embedding_size = embedding_size
self.trigram_input = trigram_input
self.use_bottleneck = use_bottleneck
self.intra_bottleneck_size = intra_bottleneck_size
self.use_bottleneck_attention = use_bottleneck_attention
self.key_query_shared_bottleneck = key_query_shared_bottleneck
self.num_feedforward_networks = num_feedforward_networks
self.normalization_type = normalization_type
self.classifier_activation = classifier_activation
if self.use_bottleneck:
self.true_hidden_size = intra_bottleneck_size
else:
self.true_hidden_size = hidden_size
self.classifier_dropout = classifier_dropout
class MobileBertOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
("token_type_ids", dynamic_axis),
]
)
.\models\mobilebert\convert_mobilebert_original_tf_checkpoint_to_pytorch.py
import argparse
import torch
from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert
from transformers.utils import logging
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, mobilebert_config_file, pytorch_dump_path):
config = MobileBertConfig.from_json_file(mobilebert_config_file)
print(f"Building PyTorch model from configuration: {config}")
model = MobileBertForPreTraining(config)
model = load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path)
print(f"Save PyTorch model to {pytorch_dump_path}")
torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--mobilebert_config_file",
default=None,
type=str,
required=True,
help=(
"The config json file corresponding to the pre-trained MobileBERT model. \n"
"This specifies the model architecture."
),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.mobilebert_config_file, args.pytorch_dump_path)
.\models\mobilebert\modeling_mobilebert.py
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
MaskedLMOutput,
MultipleChoiceModelOutput,
NextSentencePredictorOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_mobilebert import MobileBertConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "google/mobilebert-uncased"
_CONFIG_FOR_DOC = "MobileBertConfig"
_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "mrm8488/mobilebert-finetuned-ner"
_TOKEN_CLASS_EXPECTED_OUTPUT = "['I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC']"
_TOKEN_CLASS_EXPECTED_LOSS = 0.03
_CHECKPOINT_FOR_QA = "csarron/mobilebert-uncased-squad-v2"
_QA_EXPECTED_OUTPUT = "'a nice puppet'"
_QA_EXPECTED_LOSS = 3.98
_QA_TARGET_START_INDEX = 12
_QA_TARGET_END_INDEX = 13
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "lordtt13/emo-mobilebert"
_SEQ_CLASS_EXPECTED_OUTPUT = "'others'"
_SEQ_CLASS_EXPECTED_LOSS = "4.72"
MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = ["google/mobilebert-uncased"]
def load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path):
"""Load tf checkpoints in a pytorch model.
Loads MobileBERT TensorFlow checkpoint weights and converts them into PyTorch model weights.
Args:
model (PreTrainedModel): the MobileBERT model instance to load weights into.
config (MobileBertConfig): configuration object of the MobileBERT model.
tf_checkpoint_path (str): path to the TensorFlow checkpoint.
Returns:
The model with the TF weights loaded.
Raises:
ImportError: if TensorFlow cannot be imported.
Example usage:
config = MobileBertConfig.from_pretrained('google/mobilebert-uncased')
model = MobileBertForPreTraining(config)
load_tf_weights_in_mobilebert(model, config, 'path/to/tf_checkpoint')
"""
try:
import re
import numpy as np
import tensorflow as tf
except ImportError:
logger.error(
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
tf_path = os.path.abspath(tf_checkpoint_path)
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
logger.info(f"Loading TF weight {name} with shape {shape}")
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array)
for name, array in zip(names, arrays):
name = name.replace("ffn_layer", "ffn")
name = name.replace("FakeLayerNorm", "LayerNorm")
name = name.replace("extra_output_weights", "dense/kernel")
name = name.replace("bert", "mobilebert")
name = name.split("/")
if any(
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
for n in name
):
logger.info(f"Skipping {'/'.join(name)}")
continue
pointer = model
for m_name in name:
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
scope_names = re.split(r"_(\d+)", m_name)
else:
scope_names = [m_name]
if scope_names[0] == "kernel" or scope_names[0] == "gamma":
pointer = getattr(pointer, "weight")
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
pointer = getattr(pointer, "bias")
elif scope_names[0] == "output_weights":
pointer = getattr(pointer, "weight")
elif scope_names[0] == "squad":
pointer = getattr(pointer, "classifier")
else:
try:
pointer = getattr(pointer, scope_names[0])
except AttributeError:
logger.info(f"Skipping {'/'.join(name)}")
continue
if len(scope_names) >= 2:
num = int(scope_names[1])
pointer = pointer[num]
if m_name[-11:] == "_embeddings":
pointer = getattr(pointer, "weight")
elif m_name == "kernel":
array = np.transpose(array)
try:
assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
logger.info(f"Initialize PyTorch weight {name}")
pointer.data = torch.from_numpy(array)
return model
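The scope-name parsing in the loop above hinges on `re.split` with a capturing group: an indexed TF scope such as "layer_3" yields both the attribute name and the numeric index. A quick check of that behavior:

import re

m_name = "layer_3"
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
    scope_names = re.split(r"_(\d+)", m_name)
print(scope_names)  # ['layer', '3', ''] -> getattr(pointer, 'layer'), then pointer[3]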
class MobileBertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config):
super().__init__()
self.trigram_input = config.trigram_input
self.embedding_size = config.embedding_size
self.hidden_size = config.hidden_size
self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
embed_dim_multiplier = 3 if self.trigram_input else 1
embedded_input_size = self.embedding_size * embed_dim_multiplier
self.embedding_transformation = nn.Linear(embedded_input_size, config.hidden_size)
self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
if self.trigram_input:
inputs_embeds = torch.cat(
[
nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0),
inputs_embeds,
nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0.0),
],
dim=2,
)
if self.trigram_input or self.embedding_size != self.hidden_size:
inputs_embeds = self.embedding_transformation(inputs_embeds)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
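The trigram_input branch concatenates each token's embedding with its right and left neighbors (zero-padding at the sequence edges), tripling the feature dimension before embedding_transformation projects it to hidden_size. A small shape check of that padding trick:

import torch
from torch import nn

emb = torch.arange(1.0 * 2 * 3 * 4).reshape(2, 3, 4)  # (batch=2, seq=3, embedding_size=4)
trigram = torch.cat(
    [
        nn.functional.pad(emb[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0),   # next token (zeros at the end)
        emb,                                                            # current token
        nn.functional.pad(emb[:, :-1], [0, 0, 1, 0, 0, 0], value=0.0),  # previous token (zeros at the start)
    ],
    dim=2,
)
print(trigram.shape)  # torch.Size([2, 3, 12]) -- 3 * embedding_size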
class MobileBertSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.true_hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.true_hidden_size, self.all_head_size)
self.key = nn.Linear(config.true_hidden_size, self.all_head_size)
self.value = nn.Linear(
config.true_hidden_size if config.use_bottleneck_attention else config.hidden_size, self.all_head_size
)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
query_tensor: torch.Tensor,
key_tensor: torch.Tensor,
value_tensor: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
) -> Tuple[torch.Tensor]:
mixed_query_layer = self.query(query_tensor)
mixed_key_layer = self.key(key_tensor)
mixed_value_layer = self.value(value_tensor)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
attention_scores = attention_scores + attention_mask
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
class MobileBertSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.use_bottleneck = config.use_bottleneck
self.dense = nn.Linear(config.true_hidden_size, config.true_hidden_size)
self.LayerNorm = NORM2FN[config.normalization_type](config.true_hidden_size, eps=config.layer_norm_eps)
if not self.use_bottleneck:
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
layer_outputs = self.dense(hidden_states)
if not self.use_bottleneck:
layer_outputs = self.dropout(layer_outputs)
layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
return layer_outputs
class MobileBertAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.self = MobileBertSelfAttention(config)
self.output = MobileBertSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
query_tensor: torch.Tensor,
key_tensor: torch.Tensor,
value_tensor: torch.Tensor,
layer_input: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
) -> Tuple[torch.Tensor]:
self_outputs = self.self(
query_tensor,
key_tensor,
value_tensor,
attention_mask,
head_mask,
output_attentions,
)
attention_output = self.output(self_outputs[0], layer_input)
outputs = (attention_output,) + self_outputs[1:]
return outputs
class MobileBertIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.true_hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class OutputBottleneck(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.true_hidden_size, config.hidden_size)
self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
layer_outputs = self.dense(hidden_states)
layer_outputs = self.dropout(layer_outputs)
layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
return layer_outputs
class MobileBertOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.use_bottleneck = config.use_bottleneck
self.dense = nn.Linear(config.intermediate_size, config.true_hidden_size)
self.LayerNorm = NORM2FN[config.normalization_type](config.true_hidden_size)
if not self.use_bottleneck:
self.dropout = nn.Dropout(config.hidden_dropout_prob)
else:
self.bottleneck = OutputBottleneck(config)
def forward(
self, intermediate_states: torch.Tensor, residual_tensor_1: torch.Tensor, residual_tensor_2: torch.Tensor
) -> torch.Tensor:
layer_output = self.dense(intermediate_states)
if not self.use_bottleneck:
layer_output = self.dropout(layer_output)
layer_output = self.LayerNorm(layer_output + residual_tensor_1)
else:
layer_output = self.LayerNorm(layer_output + residual_tensor_1)
layer_output = self.bottleneck(layer_output, residual_tensor_2)
return layer_output
class BottleneckLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intra_bottleneck_size)
self.LayerNorm = NORM2FN[config.normalization_type](config.intra_bottleneck_size, eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
layer_input = self.dense(hidden_states)
layer_input = self.LayerNorm(layer_input)
return layer_input
class Bottleneck(nn.Module):
def __init__(self, config):
super().__init__()
self.key_query_shared_bottleneck = config.key_query_shared_bottleneck
self.use_bottleneck_attention = config.use_bottleneck_attention
self.input = BottleneckLayer(config)
if self.key_query_shared_bottleneck:
self.attention = BottleneckLayer(config)
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
bottlenecked_hidden_states = self.input(hidden_states)
if self.use_bottleneck_attention:
return (bottlenecked_hidden_states,) * 4
elif self.key_query_shared_bottleneck:
shared_attention_input = self.attention(hidden_states)
return (shared_attention_input, shared_attention_input, hidden_states, bottlenecked_hidden_states)
else:
return (hidden_states, hidden_states, hidden_states, bottlenecked_hidden_states)
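The three branches above choose what feeds the attention's query/key/value and what acts as the residual (`layer_input`). A self-contained sketch of the routing only, with assumed sizes (`hidden_size=512`, `intra_bottleneck_size=128`) and plain `nn.Linear` layers standing in for `BottleneckLayer` (the real layer also applies the configured normalization):
```python
import torch
import torch.nn as nn

hidden_size, intra_bottleneck_size = 512, 128
hidden_states = torch.randn(2, 5, hidden_size)
bottleneck_input = nn.Linear(hidden_size, intra_bottleneck_size)      # stands in for self.input
bottleneck_attention = nn.Linear(hidden_size, intra_bottleneck_size)  # stands in for self.attention

bottlenecked = bottleneck_input(hidden_states)

use_bottleneck_attention, key_query_shared_bottleneck = False, True
if use_bottleneck_attention:
    # Query, key, value and the residual all run in the narrow bottleneck width.
    query, key, value, layer_input = (bottlenecked,) * 4
elif key_query_shared_bottleneck:
    # Query and key share a second bottleneck, value keeps the full hidden width,
    # and the bottlenecked states become the residual added after attention.
    shared = bottleneck_attention(hidden_states)
    query, key, value, layer_input = shared, shared, hidden_states, bottlenecked
else:
    # Only the residual is bottlenecked; query/key/value stay at hidden_size.
    query, key, value, layer_input = hidden_states, hidden_states, hidden_states, bottlenecked

print(query.shape, value.shape, layer_input.shape)
```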
class MobileBertLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.use_bottleneck = config.use_bottleneck
self.num_feedforward_networks = config.num_feedforward_networks
self.attention = MobileBertAttention(config)
self.intermediate = MobileBertIntermediate(config)
self.output = MobileBertOutput(config)
if self.use_bottleneck:
self.bottleneck = Bottleneck(config)
if config.num_feedforward_networks > 1:
self.ffn = nn.ModuleList([FFNLayer(config) for _ in range(config.num_feedforward_networks - 1)])
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
) -> Tuple[torch.Tensor]:
if self.use_bottleneck:
query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states)
else:
query_tensor, key_tensor, value_tensor, layer_input = [hidden_states] * 4
self_attention_outputs = self.attention(
query_tensor,
key_tensor,
value_tensor,
layer_input,
attention_mask,
head_mask,
output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0]
s = (attention_output,)
outputs = self_attention_outputs[1:]
if self.num_feedforward_networks != 1:
for i, ffn_module in enumerate(self.ffn):
attention_output = ffn_module(attention_output)
s += (attention_output,)
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output, hidden_states)
outputs = (
(layer_output,)
+ outputs
+ (
torch.tensor(1000),
query_tensor,
key_tensor,
value_tensor,
layer_input,
attention_output,
intermediate_output,
)
+ s
)
return outputs
class MobileBertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.layer = nn.ModuleList([MobileBertLayer(config) for _ in range(config.num_hidden_layers)])
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
hidden_states,
attention_mask,
head_mask[i],
output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
class MobileBertPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.do_activate = config.classifier_activation
if self.do_activate:
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
first_token_tensor = hidden_states[:, 0]
if not self.do_activate:
return first_token_tensor
else:
pooled_output = self.dense(first_token_tensor)
pooled_output = torch.tanh(pooled_output)
return pooled_output
class MobileBertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = NORM2FN["layer_norm"](config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class MobileBertLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = MobileBertPredictionHeadTransform(config)
self.dense = nn.Linear(config.vocab_size, config.hidden_size - config.embedding_size, bias=False)
self.decoder = nn.Linear(config.embedding_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
self.decoder.bias = self.bias
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.transform(hidden_states)
hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0))
hidden_states += self.decoder.bias
return hidden_states
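The head above avoids a full `hidden_size x vocab_size` matrix: `decoder.weight` (tied to the word embeddings) covers the first `embedding_size` input features and `dense.weight` covers the remaining `hidden_size - embedding_size`, so their concatenation forms a `(hidden_size, vocab_size)` projection. A quick shape check with assumed MobileBERT-like sizes:
```python
import torch

vocab_size, hidden_size, embedding_size = 30522, 512, 128

# decoder = nn.Linear(embedding_size, vocab_size) -> weight shape (vocab_size, embedding_size)
decoder_weight = torch.randn(vocab_size, embedding_size)
# dense = nn.Linear(vocab_size, hidden_size - embedding_size) -> weight shape (hidden_size - embedding_size, vocab_size)
dense_weight = torch.randn(hidden_size - embedding_size, vocab_size)

# Same concatenation as in the forward above.
full_projection = torch.cat([decoder_weight.t(), dense_weight], dim=0)
print(full_projection.shape)  # torch.Size([512, 30522])

hidden_states = torch.randn(2, 5, hidden_size)
print(hidden_states.matmul(full_projection).shape)  # torch.Size([2, 5, 30522])
```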
class MobileBertOnlyMLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = MobileBertLMPredictionHead(config)
def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class MobileBertPreTrainingHeads(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = MobileBertLMPredictionHead(config)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output: torch.Tensor, pooled_output: torch.Tensor) -> Tuple[torch.Tensor]:
prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class MobileBertPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = MobileBertConfig
pretrained_model_archive_map = MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST
load_tf_weights = load_tf_weights_in_mobilebert
base_model_prefix = "mobilebert"
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, (nn.LayerNorm, NoNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
@dataclass
class MobileBertForPreTrainingOutput(ModelOutput):
"""
Output type of [`MobileBertForPreTraining`].
Args:
loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
Total loss as the sum of the masked language modeling loss and the next sequence prediction
(classification) loss.
prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
prediction_logits: torch.FloatTensor = None
seq_relationship_logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
MOBILEBERT_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`MobileBertConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
MOBILEBERT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
[What are input IDs?](../glossary
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
- 0 corresponds to a *sentence A* token,
- 1 corresponds to a *sentence B* token.
[What are token type IDs?](../glossary
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
[What are position IDs?](../glossary
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
output_attentions (`bool`, *optional*):
output_hidden_states (`bool`, *optional*):
return_dict (`bool`, *optional*):
"""
@add_start_docstrings(
"The bare MobileBert Model transformer outputting raw hidden-states without any specific head on top.",
MOBILEBERT_START_DOCSTRING,
)
class MobileBertModel(MobileBertPreTrainedModel):
"""
MobileBertModel class implementing the MobileBERT architecture.
https://arxiv.org/pdf/2004.02984.pdf
"""
def __init__(self, config, add_pooling_layer=True):
"""
Initializes a MobileBertModel instance.
Args:
config (MobileBertConfig): Configuration class for MobileBERT.
add_pooling_layer (bool): Whether to add a pooling layer. Defaults to True.
"""
super().__init__(config)
self.config = config
self.embeddings = MobileBertEmbeddings(config)
self.encoder = MobileBertEncoder(config)
self.pooler = MobileBertPooler(config) if add_pooling_layer else None
self.post_init()
def get_input_embeddings(self):
"""
Retrieves the input word embeddings from MobileBertEmbeddings.
Returns:
torch.nn.Embedding: The word embedding layer.
"""
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
"""
Sets the input word embeddings in MobileBertEmbeddings.
Args:
value (torch.Tensor): New tensor for word embeddings.
"""
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model.
Args:
heads_to_prune (dict): Dictionary of {layer_num: list of heads to prune in this layer}.
See base class PreTrainedModel.
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(
MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
"""
Forward pass for the MobileBertModel.
Args:
input_ids (Optional[torch.LongTensor]): Input ids of shape (batch_size, sequence_length).
attention_mask (Optional[torch.FloatTensor]): Attention mask of shape (batch_size, sequence_length).
token_type_ids (Optional[torch.LongTensor]): Token type ids of shape (batch_size, sequence_length).
position_ids (Optional[torch.LongTensor]): Position ids of shape (batch_size, sequence_length).
head_mask (Optional[torch.FloatTensor]): Mask to nullify selected heads of shape (num_heads,).
inputs_embeds (Optional[torch.FloatTensor]): Embedded inputs of shape (batch_size, sequence_length, embedding_size).
output_hidden_states (Optional[bool]): Whether to return hidden states.
output_attentions (Optional[bool]): Whether to return attentions.
return_dict (Optional[bool]): Whether to return a dictionary.
Returns:
BaseModelOutputWithPooling or tuple:
`BaseModelOutputWithPooling` when `return_dict=True`, otherwise a tuple of `torch.FloatTensor`.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
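A minimal usage sketch for the bare model's forward pass above, assuming the `google/mobilebert-uncased` checkpoint:
```python
import torch
from transformers import AutoTokenizer, MobileBertModel

tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
model = MobileBertModel.from_pretrained("google/mobilebert-uncased")

inputs = tokenizer("Hello, MobileBERT!", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
print(outputs.pooler_output.shape)      # (batch_size, hidden_size), derived from the first token
```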
@add_start_docstrings(
"""
MobileBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
`next sentence prediction (classification)` head.
""",
MOBILEBERT_START_DOCSTRING,
)
class MobileBertForPreTraining(MobileBertPreTrainedModel):
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
self.mobilebert = MobileBertModel(config)
self.cls = MobileBertPreTrainingHeads(config)
self.post_init()
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
self.cls.predictions.dense = self._get_resized_lm_head(
self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True
)
return super().resize_token_embeddings(new_num_tokens=new_num_tokens)
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=MobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
next_sentence_label: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
class MobileBertForMaskedLM(MobileBertPreTrainedModel):
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
self.cls.predictions.dense = self._get_resized_lm_head(
self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True
)
return super().resize_token_embeddings(new_num_tokens=new_num_tokens)
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
expected_output="'paris'",
expected_loss=0.57,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MaskedLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.mobilebert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
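The loss above ignores positions labelled `-100` and only scores the real mask targets. A hedged fill-mask sketch with the pretrained checkpoint (the `expected_output="'paris'"` in the decorator above corresponds to this kind of prompt):
```python
import torch
from transformers import AutoTokenizer, MobileBertForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
model = MobileBertForMaskedLM.from_pretrained("google/mobilebert-uncased")

inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# Locate the [MASK] position and take the highest-scoring vocabulary token there.
mask_index = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
predicted_id = logits[0, mask_index].argmax(dim=-1)
print(tokenizer.decode(predicted_id))  # "paris" with the pretrained weights
```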
class MobileBertOnlyNSPHead(nn.Module):
def __init__(self, config):
super().__init__()
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
seq_relationship_score = self.seq_relationship(pooled_output)
return seq_relationship_score
@add_start_docstrings(
"""MobileBert Model with a `next sentence prediction (classification)` head on top.""",
MOBILEBERT_START_DOCSTRING,
)
class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.mobilebert = MobileBertModel(config)
self.cls = MobileBertOnlyNSPHead(config)
self.post_init()
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, NextSentencePredictorOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
(see `input_ids` docstring) Indices should be in `[0, 1]`.
- 0 indicates sequence B is a continuation of sequence A,
- 1 indicates sequence B is a random sequence.
Returns:
Depending on `return_dict`:
- If `return_dict` is `False`, returns a tuple with `seq_relationship_score` and additional outputs.
- If `return_dict` is `True`, returns a `NextSentencePredictorOutput` object.
Examples:
Example usage of the `MobileBertForNextSentencePrediction` model.
```
>>> from transformers import AutoTokenizer, MobileBertForNextSentencePrediction
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
>>> model = MobileBertForNextSentencePrediction.from_pretrained("google/mobilebert-uncased")
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
>>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
>>> outputs = model(**encoding, labels=torch.LongTensor([1]))
>>> loss = outputs.loss
>>> logits = outputs.logits
```"""
if "next_sentence_label" in kwargs:
warnings.warn(
"The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
" `labels` instead.",
FutureWarning,
)
labels = kwargs.pop("next_sentence_label")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.mobilebert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]
seq_relationship_score = self.cls(pooled_output)
next_sentence_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), labels.view(-1))
if not return_dict:
output = (seq_relationship_score,) + outputs[2:]
return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
return NextSentencePredictorOutput(
loss=next_sentence_loss,
logits=seq_relationship_score,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config
self.mobilebert = MobileBertModel(config)
classifier_dropout = (
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.post_init()
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.mobilebert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
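A hedged sketch of the `problem_type` branching above for single-label classification: integer (`long`) labels with `num_labels > 1` select `CrossEntropyLoss`. The `num_labels=2` here is an assumption, and the classification head is freshly initialized, so the loss value itself is arbitrary:
```python
import torch
from transformers import AutoTokenizer, MobileBertForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
model = MobileBertForSequenceClassification.from_pretrained("google/mobilebert-uncased", num_labels=2)

inputs = tokenizer("A short sentence to classify.", return_tensors="pt")
labels = torch.tensor([1])  # long dtype -> "single_label_classification" -> CrossEntropyLoss

outputs = model(**inputs, labels=labels)
print(outputs.logits.shape)  # (1, 2)
print(outputs.loss)          # scalar cross-entropy loss
```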
@add_start_docstrings(
"""
MobileBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
MOBILEBERT_START_DOCSTRING,
)
class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.mobilebert = MobileBertModel(config, add_pooling_layer=False)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
self.post_init()
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_QA,
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
qa_target_start_index=_QA_TARGET_START_INDEX,
qa_target_end_index=_QA_TARGET_END_INDEX,
expected_output=_QA_EXPECTED_OUTPUT,
expected_loss=_QA_EXPECTED_LOSS,
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
start_positions: Optional[torch.Tensor] = None,
end_positions: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.mobilebert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
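A hedged sketch of turning the start/end logits above into an answer span. With the base `google/mobilebert-uncased` checkpoint the QA head is randomly initialized, so the extracted span is only meaningful with a SQuAD-finetuned checkpoint; the mechanics are the same either way:
```python
import torch
from transformers import AutoTokenizer, MobileBertForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
model = MobileBertForQuestionAnswering.from_pretrained("google/mobilebert-uncased")

question = "Where is the Eiffel Tower?"
context = "The Eiffel Tower is located in Paris."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Pick the most likely start and end token positions and decode the span between them.
start = outputs.start_logits.argmax(dim=-1).item()
end = outputs.end_logits.argmax(dim=-1).item()
print(tokenizer.decode(inputs["input_ids"][0, start : end + 1]))
```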
class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.mobilebert = MobileBertModel(config)
classifier_dropout = (
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, 1)
self.post_init()
@add_start_docstrings_to_model_forward(
MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
`input_ids` above)
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
inputs_embeds = (
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
if inputs_embeds is not None
else None
)
outputs = self.mobilebert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
if not return_dict:
output = (reshaped_logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput(
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
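The forward above expects inputs of shape `(batch_size, num_choices, seq_len)`, flattens them to `(batch_size * num_choices, seq_len)` for a single encoder pass, and reshapes the one-logit classifier output back to `(batch_size, num_choices)`. A hedged input-shaping sketch (prompt and choices are made up; the choice head is untrained here):
```python
import torch
from transformers import AutoTokenizer, MobileBertForMultipleChoice

tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
model = MobileBertForMultipleChoice.from_pretrained("google/mobilebert-uncased")

prompt = "France is famous for its"
choices = ["cheese and wine.", "penguins and glaciers."]

# Encode each (prompt, choice) pair, then add the num_choices dimension expected by the model.
encoding = tokenizer([prompt] * len(choices), choices, return_tensors="pt", padding=True)
inputs = {k: v.unsqueeze(0) for k, v in encoding.items()}  # each tensor: (1, num_choices, seq_len)

outputs = model(**inputs, labels=torch.tensor([0]))
print(outputs.logits.shape)  # (1, 2): one score per choice
```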
@add_start_docstrings(
"""
MobileBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
for Named-Entity-Recognition (NER) tasks.
""",
MOBILEBERT_START_DOCSTRING,
)
class MobileBertForTokenClassification(MobileBertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.mobilebert = MobileBertModel(config, add_pooling_layer=False)
classifier_dropout = (
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.post_init()
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT,
expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.mobilebert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
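A hedged NER-style sketch for the token classification head above. The `num_labels=9` (a typical CoNLL-style tag set size) is an assumption and the head is untrained here, so the predicted tag ids are placeholders; the point is the per-token logits shape:
```python
import torch
from transformers import AutoTokenizer, MobileBertForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
model = MobileBertForTokenClassification.from_pretrained("google/mobilebert-uncased", num_labels=9)

inputs = tokenizer("Ada Lovelace lived in London.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # (batch_size, sequence_length, num_labels)

predicted_tags = logits.argmax(dim=-1)
print(logits.shape)
print(predicted_tags[0].tolist())  # one tag id per token, including special tokens
```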