Transformers 源码解析(七十二)
.\models\maskformer\modeling_maskformer.py
"""
import math
from dataclasses import dataclass
from numbers import Number
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import Tensor, nn
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import is_torch_greater_or_equal_than_2_1
from ...utils import (
    ModelOutput,  # base class used to wrap model outputs
    add_start_docstrings,  # adds the model-level docstring
    add_start_docstrings_to_model_forward,  # documents the forward-pass inputs
    is_accelerate_available,  # checks whether accelerate is installed
    is_scipy_available,  # checks whether scipy is installed
    logging,  # logging utilities
    replace_return_docstrings,  # replaces the return-value docstring
    requires_backends,  # decorator requiring specific backends
)
from ...utils.backbone_utils import load_backbone
from ..detr import DetrConfig
from .configuration_maskformer import MaskFormerConfig
from .configuration_maskformer_swin import MaskFormerSwinConfig
if is_accelerate_available():  # accelerate is available
    from accelerate import PartialState
    from accelerate.utils import reduce
if is_scipy_available():  # scipy is available
    from scipy.optimize import linear_sum_assignment
logger = logging.get_logger(__name__)  # module-level logger
# Name of the configuration class, used in the docstrings
_CONFIG_FOR_DOC = "MaskFormerConfig"
# Checkpoint name used in the docstrings
_CHECKPOINT_FOR_DOC = "facebook/maskformer-swin-base-ade"
# List of pretrained MaskFormer checkpoints
MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = ["facebook/maskformer-swin-base-ade"]
@dataclass
# 定义"DetrDecoderOutput"类,扩展了"BaseModelOutputWithCrossAttentions"类,用于处理"DETR"解码器的输出项
class DetrDecoderOutput(BaseModelOutputWithCrossAttentions):
"""
"DetrDecoderOutput"类继承自"BaseModelOutputWithCrossAttentions"类,用于封装"DETREncoder"模块的输出。
接收一个"CrossAttentions"对象作为属性,并在此基础上添加了一个可选的解码器中间层激活堆栈。
用于单辅助解码器损失训练时提供额外的特征信息。
"""
# 定义函数的参数列表,包括最后一个隐藏层的输出状态
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
模型最后一层的隐藏状态的序列输出。
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
可选参数,当传递 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回,
包含元组中的 `torch.FloatTensor`(一个用于嵌入层输出,每层输出一个)的形状为 `(batch_size, sequence_length, hidden_size)`。
模型每一层的隐藏状态,以及初始嵌入层的输出。
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
可选参数,当传递 `output_attentions=True` 或 `config.output_attentions=True` 时返回,
包含元组中的 `torch.FloatTensor`(每一层一个)的形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
经过注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
可选参数,当同时传递 `output_attentions=True` 和 `config.add_cross_attention=True` 或 `config.output_attentions=True` 时返回,
包含元组中的 `torch.FloatTensor`(每一层一个)的形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
解码器交叉注意力层的注意力权重,经过注意力 softmax 后,用于计算交叉注意力头中的加权平均值。
intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
可选参数,当传递 `config.auxiliary_loss=True` 时返回,
形状为 `(config.decoder_layers, batch_size, num_queries, hidden_size)` 的中间解码器激活状态。
每个解码器层的中间激活状态,每个状态经过了层归一化。
"""
intermediate_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class MaskFormerPixelLevelModuleOutput(ModelOutput):
"""
MaskFormer's pixel level module output. It returns both the last and (optionally) the hidden states from the
`encoder` and `decoder`. By default, the `encoder` is a MaskFormerSwin Transformer and the `decoder` is a Feature
Pyramid Network (FPN).
    The `encoder_last_hidden_state` is referred to in the paper as **image features**, while
    `decoder_last_hidden_state` is referred to as **pixel embeddings**.
Args:
encoder_last_hidden_state (`torch.FloatTensor` of shape`(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last stage of the encoder.
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the model at
the output of each stage.
decoder_last_hidden_state (`torch.FloatTensor` of shape`(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last stage of the decoder.
decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the model at
the output of each stage.
"""
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
decoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MaskFormerPixelDecoderOutput(ModelOutput):
"""
MaskFormer's pixel decoder module output, practically a Feature Pyramid Network. It returns the last hidden state
and (optionally) the hidden states.
"""
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
模型最后阶段的最后隐藏状态(最终特征图)。
hidden_states (`tuple(torch.FloatTensor)`, *optional*, 当 `output_hidden_states=True` 时返回或当 `config.output_hidden_states=True` 时返回):
包含多个元素的元组,每个元素是 `torch.FloatTensor`,形状为 `(batch_size, num_channels, height, width)`。
模型在每一层输出的隐藏状态,还包括初始嵌入的输出。
attentions (`tuple(torch.FloatTensor)`, *optional*, 当 `output_attentions=True` 时返回或当 `config.output_attentions=True` 时返回):
包含多个元素的元组,每个元素是 `torch.FloatTensor`,形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
Detr 解码器中注意力权重经过 attention softmax 后的输出,用于计算自注意力头中的加权平均值。
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
# Dataclass storing the outputs of [`MaskFormerModel`]; it returns all the hidden states needed to compute the logits.
@dataclass
class MaskFormerModelOutput(ModelOutput):
"""
Class for outputs of [`MaskFormerModel`]. This class returns all the needed hidden states to compute the logits.
Args:
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last stage of the encoder model (backbone).
pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last stage of the pixel decoder model (FPN).
transformer_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Last hidden states (final feature map) of the last stage of the transformer decoder model.
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder
model at the output of each stage.
pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel
decoder model at the output of each stage.
transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the
transformer decoder at the output of each stage.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` containing `encoder_hidden_states`, `pixel_decoder_hidden_states` and
`decoder_hidden_states`
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Attentions weights from Detr's decoder after the attention softmax, used to compute the
weighted average in the self-attention heads.
"""
    # Last hidden state of the encoder (backbone)
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # Last hidden state of the pixel decoder (FPN)
    pixel_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # Last hidden state of the transformer decoder
    transformer_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # Hidden states of the encoder, one per stage
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Hidden states of the pixel decoder, one per stage
    pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Hidden states of the transformer decoder, one per layer
    transformer_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Concatenation of all the hidden states above
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Attention weights from the transformer decoder
    attentions: Optional[Tuple[torch.FloatTensor]] = None
# Dataclass holding the outputs of the instance segmentation head, extends ModelOutput.
@dataclass
class MaskFormerForInstanceSegmentationOutput(ModelOutput):
"""
Class for outputs of [`MaskFormerForInstanceSegmentation`].
    This output can be directly passed to [`~MaskFormerImageProcessor.post_process_semantic_segmentation`] or
    [`~MaskFormerImageProcessor.post_process_instance_segmentation`] or
    [`~MaskFormerImageProcessor.post_process_panoptic_segmentation`] depending on the task. Please see
    [`~MaskFormerImageProcessor`] for details regarding usage.
"""
    # Optional training loss
    loss: Optional[torch.FloatTensor] = None
    # Class logits for each query
    class_queries_logits: torch.FloatTensor = None
    # Mask logits for each query
    masks_queries_logits: torch.FloatTensor = None
    # Logits from the intermediate decoder layers (auxiliary loss)
    auxiliary_logits: torch.FloatTensor = None
    # Last hidden state of the encoder (backbone)
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # Last hidden state of the pixel decoder (FPN)
    pixel_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # Last hidden state of the transformer decoder
    transformer_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # Hidden states of the encoder, one per stage
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Hidden states of the pixel decoder, one per stage
    pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Hidden states of the transformer decoder, one per layer
    transformer_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Concatenation of all the hidden states above
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Attention weights
    attentions: Optional[Tuple[torch.FloatTensor]] = None
# Copied and adapted from the original implementation
def upsample_like(pixel_values: Tensor, like: Tensor, mode: str = "bilinear") -> Tensor:
"""
    A utility function that upsamples `pixel_values` to match the dimension of `like`.
Args:
pixel_values (`torch.Tensor`):
The tensor we wish to upsample.
like (`torch.Tensor`):
The tensor we wish to use as size target.
mode (str, *optional*, defaults to `"bilinear"`):
The interpolation mode.
Returns:
`torch.Tensor`: The upsampled tensor
"""
    # Read the spatial size of `like`
    _, _, height, width = like.shape
    # Interpolate `pixel_values` so that its size matches `like`
    upsampled = nn.functional.interpolate(pixel_values, size=(height, width), mode=mode, align_corners=False)
return upsampled
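# Illustrative usage sketch (not part of the original file): `upsample_like` resizes a coarse mask
# to the spatial resolution of a reference tensor. All shapes below are made up for demonstration.
import torch

coarse_masks = torch.randn(2, 1, 14, 14)   # (batch, channels, height, width)
reference = torch.randn(2, 3, 224, 224)    # only its height/width are used
resized = upsample_like(coarse_masks, reference)
print(resized.shape)  # torch.Size([2, 1, 224, 224])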
# Computes the DICE loss
def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor:
r"""
Compute the DICE loss, similar to generalized IOU for masks as follows:
$$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x \cap y }{x \cup y + 1}} $$
In practice, since `labels` is a binary mask, (only 0s and 1s), dice can be computed as follow
$$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x * y }{x + y + 1}} $$
Args:
inputs (`torch.Tensor`):
A tensor representing a mask.
labels (`torch.Tensor`):
A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
(0 for the negative class and 1 for the positive class).
num_masks (`int`):
The number of masks present in the current batch, used for normalization.
Returns:
`torch.Tensor`: The computed loss.
"""
    # Apply sigmoid to the logits and flatten each mask into a 1D vector of probabilities
    probs = inputs.sigmoid().flatten(1)
    # Numerator of the dice score: 2 * intersection between probabilities and labels
    numerator = 2 * (probs * labels).sum(-1)
    # Denominator: sum of probabilities plus sum of labels along the last dimension
    denominator = probs.sum(-1) + labels.sum(-1)
    # Dice loss with the +1 smoothing term
    loss = 1 - (numerator + 1) / (denominator + 1)
    # Sum over all masks and normalize by the number of masks
    loss = loss.sum() / num_masks
    # Return the averaged loss
return loss
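# Illustrative sanity check (not part of the original file): dice_loss is close to 0 for a perfect
# prediction and much larger for a completely wrong one. The toy tensors below are assumptions.
import torch

labels = torch.tensor([[1.0, 1.0, 0.0, 0.0]])              # one ground-truth mask, 4 pixels
good_logits = torch.tensor([[10.0, 10.0, -10.0, -10.0]])   # sigmoid ~ labels
bad_logits = -good_logits                                   # sigmoid ~ 1 - labels
print(dice_loss(good_logits, labels, num_masks=1))  # ~0.0
print(dice_loss(bad_logits, labels, num_masks=1))   # ~0.8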
# Refactored from the original implementation
def sigmoid_focal_loss(
inputs: Tensor, labels: Tensor, num_masks: int, alpha: float = 0.25, gamma: float = 2
) -> Tensor:
r"""
Focal loss,最初在 [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) 中提出,最初用于 RetinaNet。该损失计算如下:
$$ \mathcal{L}_{\text{focal loss}} = -(1 - p_t)^{\gamma}\log{(p_t)} $$
其中 \\(CE(p_t) = -\log{(p_t)}}\\),CE 是标准交叉熵损失。
请参考论文中的方程式 (1,2,3) 以获得更好的理解。
Args:
inputs (`torch.Tensor`):
任意形状的浮点张量。
labels (`torch.Tensor`):
与 inputs 相同形状的张量。存储每个元素的二元分类标签 (0 表示负类,1 表示正类)。
num_masks (`int`):
当前批次中存在的掩码数量,用于归一化。
alpha (float, *可选*, 默认为 0.25):
在范围 (0,1) 内的加权因子,用于平衡正负例。
gamma (float, *可选*, 默认为 2.0):
调整因子 \\(1 - p_t\\) 的指数,用于平衡简单与困难的例子。
Returns:
`torch.Tensor`: 计算得到的损失。
"""
    # Binary cross entropy with logits, without reduction
    criterion = nn.BCEWithLogitsLoss(reduction="none")
    # Convert logits to probabilities
    probs = inputs.sigmoid()
    # Standard cross-entropy loss
    cross_entropy_loss = criterion(inputs, labels)
    # p_t as defined in the paper
    p_t = probs * labels + (1 - probs) * (1 - labels)
    # Modulate the cross entropy with the focal term
    loss = cross_entropy_loss * ((1 - p_t) ** gamma)
    # If alpha is non-negative, apply the alpha balancing term
    if alpha >= 0:
        alpha_t = alpha * labels + (1 - alpha) * (1 - labels)
        loss = alpha_t * loss
    # Average over the elements of each mask and normalize by the number of masks
loss = loss.mean(1).sum() / num_masks
return loss
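# Illustrative sanity check (not part of the original file): confidently wrong predictions dominate
# the focal loss because of the (1 - p_t)**gamma modulation. The toy tensors below are assumptions.
import torch

labels = torch.tensor([[1.0, 0.0]])
easy_logits = torch.tensor([[4.0, -4.0]])   # confident and correct -> p_t close to 1
hard_logits = torch.tensor([[-4.0, 4.0]])   # confident and wrong   -> p_t close to 0
print(sigmoid_focal_loss(easy_logits, labels, num_masks=1))  # very small value
print(sigmoid_focal_loss(hard_logits, labels, num_masks=1))  # much larger value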
# Refactored from the original implementation
def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
"""
Dice Loss 的逐对版本,参见 `dice_loss` 以了解用法。
Args:
inputs (`torch.Tensor`):
表示掩码的张量
labels (`torch.Tensor`):
与 inputs 相同形状的张量。存储每个元素的二元分类标签 (0 表示负类,1 表示正类)。
Returns:
`torch.Tensor`: 每对之间计算得到的损失。
"""
# 对输入进行 sigmoid 操作并展平为一维
inputs = inputs.sigmoid().flatten(1)
# 计算分子,使用矩阵乘法
numerator = 2 * torch.matmul(inputs, labels.T)
# 使用广播获取 [num_queries, NUM_CLASSES] 矩阵
# 计算分母
denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
# 计算 Dice Loss
loss = 1 - (numerator + 1) / (denominator + 1)
return loss
# Refactored from the original implementation
def pair_wise_sigmoid_focal_loss(inputs: Tensor, labels: Tensor, alpha: float = 0.25, gamma: float = 2.0) -> Tensor:
    r"""
    A pair-wise version of the sigmoid focal loss, see `sigmoid_focal_loss` for usage.
    """
    # alpha must be positive
    if alpha < 0:
        raise ValueError("alpha must be positive")
    # Number of elements (height * width) in each flattened mask
    height_and_width = inputs.shape[1]
    # Binary cross entropy with logits, without reduction
    criterion = nn.BCEWithLogitsLoss(reduction="none")
    # Convert logits to probabilities
    prob = inputs.sigmoid()
    # Cross-entropy loss against an all-ones target (positive term)
    cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
    # Focal modulation for the positive term, focusing on hard examples
    focal_pos = ((1 - prob) ** gamma) * cross_entropy_loss_pos
    focal_pos *= alpha
    # Cross-entropy loss against an all-zeros target (negative term)
    cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))
    # Focal modulation for the negative term
    focal_neg = (prob**gamma) * cross_entropy_loss_neg
    focal_neg *= 1 - alpha
    # Pair-wise loss: weight the positive term by the labels and the negative term by their complement
    loss = torch.matmul(focal_pos, labels.T) + torch.matmul(focal_neg, (1 - labels).T)
    # Normalize by the number of elements per mask
    return loss / height_and_width
# Copied from transformers.models.detr.modeling_detr.DetrAttention
class DetrAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper.
Here, we add position embeddings to the queries and keys (as explained in the DETR paper).
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
bias: bool = True,
):
super().__init__()
        self.embed_dim = embed_dim  # embedding dimension
        self.num_heads = num_heads  # number of attention heads
        self.dropout = dropout  # dropout probability
        self.head_dim = embed_dim // num_heads  # dimension of each attention head
if self.head_dim * num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {num_heads})."
)
        self.scaling = self.head_dim**-0.5  # scaling factor applied to the queries
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # key projection
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # value projection
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # query projection
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # output projection
    def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
        # Reshape to (batch_size, num_heads, seq_len, head_dim) for multi-head attention
        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor], **kwargs):
position_embeddings = kwargs.pop("position_embeddings", None)
if kwargs:
raise ValueError(f"Unexpected arguments {kwargs.keys()}")
if position_embeddings is not None and object_queries is not None:
raise ValueError(
"Cannot specify both position_embeddings and object_queries. Please use just object_queries"
)
if position_embeddings is not None:
logger.warning_once(
"position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
)
object_queries = position_embeddings
return tensor if object_queries is None else tensor + object_queries
        # Adds the position embeddings (object queries) to the input tensor, if provided
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
object_queries: Optional[torch.Tensor] = None,
key_value_states: Optional[torch.Tensor] = None,
spatial_position_embeddings: Optional[torch.Tensor] = None,
output_attentions: bool = False,
**kwargs,
):
        # Forward pass (body omitted in this excerpt): computes the multi-head attention
        pass
class DetrDecoderLayer(nn.Module):
    # Initializer: builds one DETR decoder layer
    def __init__(self, config: DetrConfig):
        # Call the parent initializer
        super().__init__()
        # The embedding dimension equals d_model from the config
        self.embed_dim = config.d_model
        # Self-attention layer, implemented with DetrAttention
        self.self_attn = DetrAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
        )
        # Dropout probability from the config
        self.dropout = config.dropout
        # Activation function specified in the config
        self.activation_fn = ACT2FN[config.activation_function]
        # Dropout applied after the activation
        self.activation_dropout = config.activation_dropout
        # LayerNorm applied to the self-attention output
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        # Cross-attention over the encoder outputs, implemented with DetrAttention
        self.encoder_attn = DetrAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
        )
        # LayerNorm applied to the cross-attention output
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        # Feed-forward network: embed_dim -> decoder_ffn_dim
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        # Feed-forward network: decoder_ffn_dim -> embed_dim
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
        # Final LayerNorm of the layer
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
    # Forward pass describing how the decoder layer processes its inputs
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        object_queries: Optional[torch.Tensor] = None,
        query_position_embeddings: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ):
        # Body omitted in this excerpt: self-attention, cross-attention and feed-forward with residuals
        pass
class DetrDecoder(nn.Module):
"""
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetrDecoderLayer`].
The decoder updates the query embeddings through multiple self-attention and cross-attention layers.
Some small tweaks for DETR:
- object_queries and query_position_embeddings are added to the forward pass.
- if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.
Args:
config: DetrConfig
"""
def __init__(self, config: DetrConfig):
super().__init__()
self.config = config
self.dropout = config.dropout
self.layerdrop = config.decoder_layerdrop
# Initialize layers as a list of DetrDecoderLayer modules based on config.decoder_layers
self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)])
# Apply LayerNorm to the output of the last decoder layer
self.layernorm = nn.LayerNorm(config.d_model)
# Gradient checkpointing is disabled by default
self.gradient_checkpointing = False
def forward(
self,
inputs_embeds=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
object_queries=None,
query_position_embeddings=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
# Forward pass through the decoder layers
# Each layer updates the query embeddings using self-attention and cross-attention mechanisms
# object_queries and query_position_embeddings are incorporated if provided
# If auxiliary_loss is True, also returns hidden states from all decoding layers
# The method returns a dictionary of output values
pass # Placeholder for actual implementation
# refactored from original implementation
class MaskFormerHungarianMatcher(nn.Module):
"""This class computes an assignment between the labels and the predictions of the network.
For efficiency reasons, the labels don't include the no_object. Because of this, in general, there are more
predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are
un-matched (and thus treated as non-objects).
"""
def __init__(self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0):
"""Creates the matcher
Params:
cost_class (float, *optional*, defaults to 1.0):
This is the relative weight of the classification error in the matching cost.
cost_mask (float, *optional*, defaults to 1.0):
This is the relative weight of the focal loss of the binary mask in the matching cost.
cost_dice (float, *optional*, defaults to 1.0):
This is the relative weight of the dice loss of the binary mask in the matching cost
"""
super().__init__()
if cost_class == 0 and cost_mask == 0 and cost_dice == 0:
raise ValueError("All costs cant be 0")
# Initialize the relative weights for classification, mask focal loss, and dice loss
self.cost_class = cost_class
self.cost_mask = cost_mask
self.cost_dice = cost_dice
@torch.no_grad()
    def forward(
        self, masks_queries_logits: Tensor, class_queries_logits: Tensor, mask_labels: List[Tensor], class_labels: List[Tensor]
    ):
        # Body omitted in this excerpt: builds a cost matrix per image and solves it with scipy's
        # linear_sum_assignment (see the sketch after this class).
        pass
    # String representation of the matcher, useful for debugging and display
    def __repr__(self):
        # Header with the class name
        head = "Matcher " + self.__class__.__name__
        # Body listing the classification, mask and dice cost weights
        body = [
            f"cost_class: {self.cost_class}",
            f"cost_mask: {self.cost_mask}",
            f"cost_dice: {self.cost_dice}",
        ]
        _repr_indent = 4
        # Indent the body lines under the header
        lines = [head] + [" " * _repr_indent + line for line in body]
        # Join everything into a multi-line string
        return "\n".join(lines)
# Copied and adapted from the original implementation
class MaskFormerLoss(nn.Module):
def __init__(
self,
num_labels: int,
matcher: MaskFormerHungarianMatcher,
weight_dict: Dict[str, float],
eos_coef: float,
):
"""
MaskFormer Loss类。损失计算与DETR非常类似。过程分为两步:
1) 计算真实标签掩码与模型输出之间的匈牙利分配
2) 监督每对匹配的真实标签/预测(监督类别和掩码)
Args:
num_labels (`int`):
类别数量。
matcher (`MaskFormerHungarianMatcher`):
计算预测和标签之间分配的Torch模块。
weight_dict (`Dict[str, float]`):
不同损失要应用的权重字典。
eos_coef (`float`):
应用于空类别的权重。
"""
super().__init__()
requires_backends(self, ["scipy"])
self.num_labels = num_labels
self.matcher = matcher
self.weight_dict = weight_dict
self.eos_coef = eos_coef
# 创建一个权重张量,包含所有类别和一个额外的EOS类别
empty_weight = torch.ones(self.num_labels + 1)
empty_weight[-1] = self.eos_coef
self.register_buffer("empty_weight", empty_weight)
    def _max_by_axis(self, the_list: List[List[int]]) -> List[int]:
        # Element-wise maximum across the list of shape lists
        maxes = the_list[0]
        for sublist in the_list[1:]:
            for index, item in enumerate(sublist):
                maxes[index] = max(maxes[index], item)
        return maxes
    def _pad_images_to_max_in_batch(self, tensors: List[Tensor]) -> Tuple[Tensor, Tensor]:
        # Maximum size along each axis within the batch
        max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors])
        batch_size = len(tensors)
        # Final padded shape
        batch_shape = [batch_size] + max_size
        b, _, h, w = batch_shape
        # Metadata taken from the first tensor
        dtype = tensors[0].dtype
        device = tensors[0].device
        # Zero-filled padded tensors and the corresponding padding masks
        padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device)
        padding_masks = torch.ones((b, h, w), dtype=torch.bool, device=device)
        # Copy each tensor into the padded buffer and mark the valid region in the padding mask
        for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks):
            padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor)
            padding_mask[: tensor.shape[1], : tensor.shape[2]] = False
        return padded_tensors, padding_masks
def loss_labels(
self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.array]
) -> Dict[str, Tensor]:
"""Compute the losses related to the labels using cross entropy.
Args:
class_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, num_labels`
class_labels (`List[torch.Tensor]`):
List of class labels of shape `(labels)`.
indices (`Tuple[np.array])`:
The indices computed by the Hungarian matcher.
Returns:
`Dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key:
- **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
"""
pred_logits = class_queries_logits
batch_size, num_queries, _ = pred_logits.shape
# Define CrossEntropyLoss criterion with empty_weight
criterion = nn.CrossEntropyLoss(weight=self.empty_weight)
# Obtain indices for permutation based on the Hungarian matcher output
idx = self._get_predictions_permutation_indices(indices)
        # Concatenate the matched target classes across the batch (1-D tensor of matched labels)
        target_classes_o = torch.cat([target[j] for target, (_, j) in zip(class_labels, indices)])
# Initialize target_classes tensor with default values
# Shape: (batch_size, num_queries)
target_classes = torch.full(
(batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device
)
# Update target_classes tensor using the permutation indices
target_classes[idx] = target_classes_o
# Transpose pred_logits from "batch_size x num_queries x num_labels" to "batch_size x num_labels x num_queries"
pred_logits_transposed = pred_logits.transpose(1, 2)
# Compute cross entropy loss between transposed pred_logits and target_classes
loss_ce = criterion(pred_logits_transposed, target_classes)
# Prepare losses dictionary with cross entropy loss
losses = {"loss_cross_entropy": loss_ce}
return losses
    def loss_masks(
        self, masks_queries_logits: Tensor, mask_labels: List[Tensor], indices: Tuple[np.array], num_masks: int
    ) -> Dict[str, Tensor]:
"""Compute the losses related to the masks using focal and dice loss.
Args:
masks_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, height, width`
mask_labels (`torch.Tensor`):
List of mask labels of shape `(labels, height, width)`.
indices (`Tuple[np.array])`:
The indices computed by the Hungarian matcher.
num_masks (`int)`:
The number of masks, used for normalization.
Returns:
`Dict[str, Tensor]`: A dict of `torch.Tensor` containing two keys:
- **loss_mask** -- The loss computed using sigmoid focal loss on the predicted and ground truth masks.
- **loss_dice** -- The loss computed using dice loss on the predicted and ground truth
masks.
"""
# Get permutation indices for predictions based on Hungarian matcher results
src_idx = self._get_predictions_permutation_indices(indices)
# Get permutation indices for targets based on Hungarian matcher results
tgt_idx = self._get_targets_permutation_indices(indices)
# Select predicted masks using the permutation indices
pred_masks = masks_queries_logits[src_idx]
# Pad and stack target masks to match the shape of predictions
target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
target_masks = target_masks[tgt_idx]
# Upsample predicted masks to match the size of target masks
pred_masks = nn.functional.interpolate(
pred_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
)
# Flatten the predictions and targets for loss computation
pred_masks = pred_masks[:, 0].flatten(1)
target_masks = target_masks.flatten(1)
# Compute losses using sigmoid focal loss and dice loss
losses = {
"loss_mask": sigmoid_focal_loss(pred_masks, target_masks, num_masks),
"loss_dice": dice_loss(pred_masks, target_masks, num_masks),
}
return losses
def _get_predictions_permutation_indices(self, indices):
# Concatenate batch indices for predictions
batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
# Concatenate prediction indices based on permutation results
predictions_indices = torch.cat([src for (src, _) in indices])
return batch_indices, predictions_indices
def _get_targets_permutation_indices(self, indices):
# Concatenate batch indices for targets
batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
# Concatenate target indices based on permutation results
target_indices = torch.cat([tgt for (_, tgt) in indices])
return batch_indices, target_indices
def forward(
self,
masks_queries_logits: Tensor,
class_queries_logits: Tensor,
mask_labels: List[Tensor],
class_labels: List[Tensor],
        auxiliary_predictions: Optional[Dict[str, Tensor]] = None,
    ) -> Dict[str, Tensor]:
        """
This performs the loss computation.
Args:
            masks_queries_logits (`torch.Tensor`):
                A tensor of shape `batch_size, num_queries, height, width`.
            class_queries_logits (`torch.Tensor`):
                A tensor of shape `batch_size, num_queries, num_labels`.
            mask_labels (`torch.Tensor`):
                List of mask labels of shape `(labels, height, width)`.
            class_labels (`List[torch.Tensor]`):
                List of class labels of shape `(labels)`.
            auxiliary_predictions (`Dict[str, torch.Tensor]`, *optional*):
                If `use_auxiliary_loss` was set to `true` in [`MaskFormerConfig`], this contains the logits from
                the inner layers of Detr's decoder.

        Returns:
            `Dict[str, Tensor]`: A dict of `torch.Tensor` containing the following keys:
            - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
            - **loss_mask** -- The loss computed using sigmoid focal loss on the predicted and ground truth masks.
            - **loss_dice** -- The loss computed using dice loss on the predicted and ground truth masks.
            If `use_auxiliary_loss` was set to `true` in [`MaskFormerConfig`], the dictionary contains additional
            losses for each auxiliary prediction.
        """
        # Retrieve the matching between the outputs of the last layer and the labels
        indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
        # Compute the average number of target masks for normalization purposes
        num_masks: Number = self.get_num_masks(class_labels, device=class_labels[0].device)
        # Compute all the requested losses
        losses: Dict[str, Tensor] = {
            **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks),
            **self.loss_labels(class_queries_logits, class_labels, indices),
        }
        # In case of auxiliary losses, repeat this process with the output of each intermediate layer
        if auxiliary_predictions is not None:
            for idx, aux_outputs in enumerate(auxiliary_predictions):
                masks_queries_logits = aux_outputs["masks_queries_logits"]
                class_queries_logits = aux_outputs["class_queries_logits"]
                loss_dict = self.forward(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
                loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()}
                losses.update(loss_dict)
        return losses
    def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor:
        """
        Computes the average number of target masks across the batch, for normalization purposes.
        """
        # Total number of masks across all images in the batch
        num_masks = sum([len(classes) for classes in class_labels])
        # Convert to a tensor on the requested device
        num_masks = torch.as_tensor(num_masks, dtype=torch.float, device=device)
        # Default world size for single-process training
        world_size = 1
        if is_accelerate_available():
            # If accelerate has been initialized, reduce the count across processes
            if PartialState._shared_state != {}:
                num_masks = reduce(num_masks)
                world_size = PartialState().num_processes
        # Normalize by the number of processes and clamp to at least 1
        num_masks = torch.clamp(num_masks / world_size, min=1)
        return num_masks
class MaskFormerFPNConvLayer(nn.Module):
def __init__(self, in_features: int, out_features: int, kernel_size: int = 3, padding: int = 1):
"""
        A basic module that applies a convolution followed by group normalization and a ReLU, used in MaskFormer.
Args:
in_features (`int`):
The number of input features (channels).
out_features (`int`):
                The number of output features (channels).
"""
super().__init__()
# Define layers for convolution, group normalization, and ReLU activation
self.layers = [
nn.Conv2d(in_features, out_features, kernel_size=kernel_size, padding=padding, bias=False),
nn.GroupNorm(32, out_features),
nn.ReLU(inplace=True),
]
# Add each layer to the module and name them with their index
for i, layer in enumerate(self.layers):
self.add_module(str(i), layer)
def forward(self, input: Tensor) -> Tensor:
# Apply each layer sequentially to the input tensor
hidden_state = input
for layer in self.layers:
hidden_state = layer(hidden_state)
return hidden_state
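# Illustrative shape check (not part of the original file): with kernel_size=3 and padding=1 the
# conv block keeps the spatial size and only changes the channel dimension. Shapes are made up.
import torch

block = MaskFormerFPNConvLayer(in_features=768, out_features=256)
x = torch.randn(1, 768, 7, 7)
print(block(x).shape)  # torch.Size([1, 256, 7, 7])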
class MaskFormerFPNLayer(nn.Module):
def __init__(self, in_features: int, lateral_features: int):
"""
A Feature Pyramid Network Layer (FPN) layer. It creates a feature map by aggregating features from the previous
and backbone layer. Due to the spatial mismatch, the tensor coming from the previous layer is upsampled.
Args:
in_features (`int`):
The number of input features (channels).
lateral_features (`int`):
The number of lateral features (channels).
"""
super().__init__()
# Project features from the lateral connection to match in_features using 1x1 convolution and group normalization
self.proj = nn.Sequential(
nn.Conv2d(lateral_features, in_features, kernel_size=1, padding=0, bias=False),
nn.GroupNorm(32, in_features),
)
# Create a convolutional block for further processing of features
self.block = MaskFormerFPNConvLayer(in_features, in_features)
def forward(self, down: Tensor, left: Tensor) -> Tensor:
# Project features from the lateral connection
left = self.proj(left)
# Upsample the downsampled features to match the size of the lateral features
down = nn.functional.interpolate(down, size=left.shape[-2:], mode="nearest")
# Aggregate features by element-wise addition
down += left
# Process the aggregated features using the convolutional block
down = self.block(down)
return down
class MaskFormerFPNModel(nn.Module):
    # Initializer defining the structure of the Feature Pyramid Network
def __init__(self, in_features: int, lateral_widths: List[int], feature_size: int = 256):
"""
Feature Pyramid Network, given an input tensor and a set of feature maps of different feature/spatial sizes,
it creates a list of feature maps with the same feature size.
Args:
in_features (`int`):
The number of input features (channels).
lateral_widths (`List[int]`):
A list with the feature (channel) sizes of each lateral connection.
feature_size (int, *optional*, defaults to 256):
The feature (channel) size of the resulting feature maps.
"""
        super().__init__()
        # Stem convolution applied to the last (deepest) backbone feature map
        self.stem = MaskFormerFPNConvLayer(in_features, feature_size)
        # One MaskFormerFPNLayer per lateral connection, from deepest to shallowest
        self.layers = nn.Sequential(
            *[MaskFormerFPNLayer(feature_size, lateral_width) for lateral_width in lateral_widths[::-1]]
        )
    def forward(self, features: List[Tensor]) -> List[Tensor]:
        # List collecting the output feature maps of the FPN
        fpn_features = []
        # The deepest backbone feature map goes through the stem
        last_feature = features[-1]
        other_features = features[:-1]
        output = self.stem(last_feature)
        # Walk up the pyramid, fusing each lateral feature map with the upsampled output
        for layer, left in zip(self.layers, other_features[::-1]):
            output = layer(output, left)
            fpn_features.append(output)
        # Return all the FPN feature maps
        return fpn_features
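# Illustrative shape walk-through (not part of the original file), using Swin-Tiny-like channel
# counts and a 224x224 input; all numbers here are assumptions for demonstration only.
import torch

features = [
    torch.randn(1, 96, 56, 56),   # stage 1
    torch.randn(1, 192, 28, 28),  # stage 2
    torch.randn(1, 384, 14, 14),  # stage 3
    torch.randn(1, 768, 7, 7),    # stage 4 (deepest)
]
fpn = MaskFormerFPNModel(in_features=768, lateral_widths=[96, 192, 384], feature_size=256)
for feature_map in fpn(features):
    print(feature_map.shape)
# torch.Size([1, 256, 14, 14]), torch.Size([1, 256, 28, 28]), torch.Size([1, 256, 56, 56])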
# Pixel decoder module of MaskFormer
class MaskFormerPixelDecoder(nn.Module):
    # Initializer: sets up the FPN and the final mask projection
    def __init__(self, *args, feature_size: int = 256, mask_feature_size: int = 256, **kwargs):
r"""
Pixel Decoder Module proposed in [Per-Pixel Classification is Not All You Need for Semantic
Segmentation](https://arxiv.org/abs/2107.06278). It first runs the backbone's features into a Feature Pyramid
Network creating a list of feature maps. Then, it projects the last one to the correct `mask_size`.
Args:
feature_size (`int`, *optional*, defaults to 256):
The feature size (channel dimension) of the FPN feature maps.
mask_feature_size (`int`, *optional*, defaults to 256):
The features (channels) of the target masks size \\(C_{\epsilon}\\) in the paper.
"""
        super().__init__()
        # Feature Pyramid Network producing the list of FPN feature maps
        self.fpn = MaskFormerFPNModel(*args, feature_size=feature_size, **kwargs)
        # Projects the last FPN feature map to the mask feature size
        self.mask_projection = nn.Conv2d(feature_size, mask_feature_size, kernel_size=3, padding=1)
    def forward(
        self, features: List[Tensor], output_hidden_states: bool = False, return_dict: bool = True
    ) -> MaskFormerPixelDecoderOutput:
        # Run the backbone features through the FPN
        fpn_features = self.fpn(features)
        # Project the last FPN feature map to the mask feature size
        last_feature_projected = self.mask_projection(fpn_features[-1])
        # Return a plain tuple if return_dict is False
        if not return_dict:
            return (last_feature_projected, tuple(fpn_features)) if output_hidden_states else (last_feature_projected,)
        # Otherwise wrap the outputs into a MaskFormerPixelDecoderOutput
        return MaskFormerPixelDecoderOutput(
            last_hidden_state=last_feature_projected, hidden_states=tuple(fpn_features) if output_hidden_states else ()
        )
# Copied and adapted from the original implementation, nearly identical to DetrSinePositionEmbedding
class MaskFormerSinePositionEmbedding(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
need paper, generalized to work on images.
"""
    # Initializer: stores the position-embedding hyper-parameters
    def __init__(
        self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None
    ):
        super().__init__()
        # scale can only be used together with normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale
    # Builds the 2D sine-cosine position encodings
    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        # If no mask is given, assume every position is valid
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        # Invert the mask and cast it to the input dtype (1 for valid positions, 0 for padding)
        not_mask = (~mask).to(x.dtype)
        # Cumulative sums over the valid positions give the (y, x) coordinates
        y_embed = not_mask.cumsum(1)  # cumulative sum along the height dimension
        x_embed = not_mask.cumsum(2)  # cumulative sum along the width dimension
        # Optionally normalize the coordinates to [0, scale]
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
        # Frequencies used for the sine / cosine encodings
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=x.device).type_as(x)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
        # Divide the coordinates by the frequencies
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sine and cosine for the x and y encodings
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        # Concatenate y and x encodings and move channels to dim 1: (batch, channels, height, width)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        # Return the position encodings
return pos
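# Illustrative shape check (not part of the original file): with num_pos_feats=128 per axis the
# embedding has 256 channels, matching a typical decoder hidden size. Shapes are made up.
import torch

pos_embedder = MaskFormerSinePositionEmbedding(num_pos_feats=128, normalize=True)
image_features = torch.randn(1, 256, 7, 7)
pos = pos_embedder(image_features)
print(pos.shape)  # torch.Size([1, 256, 7, 7])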
class PredictionBlock(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, activation: nn.Module) -> None:
        super().__init__()
        # A linear layer followed by an activation
        self.layers = [nn.Linear(in_dim, out_dim), activation]
        # Register each layer as a sub-module so it is tracked properly and usable in forward
        for i, layer in enumerate(self.layers):
            self.add_module(str(i), layer)
    def forward(self, input: Tensor) -> Tensor:
        hidden_state = input
        # Apply the linear layer and the activation in sequence
        for layer in self.layers:
            hidden_state = layer(hidden_state)
        return hidden_state
class MaskformerMLPPredictionHead(nn.Module):
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 3):
"""
A classic Multi Layer Perceptron (MLP).
        Args:
            input_dim (`int`):
                The input dimension.
            hidden_dim (`int`):
                The hidden dimension.
            output_dim (`int`):
                The output dimension.
            num_layers (int, *optional*, defaults to 3):
                The number of layers.
"""
        super().__init__()
        # Input and output dimensions for each prediction block
        in_dims = [input_dim] + [hidden_dim] * (num_layers - 1)
        out_dims = [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = []
        # Build one prediction block per layer
        for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)):
            # ReLU for every layer except the last one, which uses the identity
            activation = nn.ReLU() if i < num_layers - 1 else nn.Identity()
            layer = PredictionBlock(in_dim, out_dim, activation=activation)
            self.layers.append(layer)
            # Register the block as a sub-module, named by its index
            self.add_module(str(i), layer)
def forward(self, input: Tensor) -> Tensor:
hidden_state = input
        # Apply the prediction blocks in sequence
for layer in self.layers:
hidden_state = layer(hidden_state)
return hidden_state
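# Illustrative usage (not part of the original file): the mask embedder maps each query's hidden
# state to a mask embedding of size mask_feature_size. The dimensions below are assumptions.
import torch

mask_embedder = MaskformerMLPPredictionHead(input_dim=256, hidden_dim=256, output_dim=256, num_layers=3)
queries = torch.randn(1, 100, 256)   # (batch, num_queries, hidden_size)
mask_embeddings = mask_embedder(queries)
print(mask_embeddings.shape)         # torch.Size([1, 100, 256])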
class MaskFormerPixelLevelModule(nn.Module):
    def __init__(self, config: MaskFormerConfig):
"""
Pixel Level Module proposed in [Per-Pixel Classification is Not All You Need for Semantic
Segmentation](https://arxiv.org/abs/2107.06278). It runs the input image through a backbone and a pixel
decoder, generating an image feature map and pixel embeddings.
Args:
config ([`MaskFormerConfig`]):
The configuration used to instantiate this model.
"""
        super().__init__()
        # Backwards compatibility: if the backbone config is a plain Swin config, convert it to MaskFormerSwinConfig
        if getattr(config, "backbone_config") is not None and config.backbone_config.model_type == "swin":
            backbone_config = config.backbone_config
            backbone_config = MaskFormerSwinConfig.from_dict(backbone_config.to_dict())
            # Output the feature maps of all four stages
            backbone_config.out_features = ["stage1", "stage2", "stage3", "stage4"]
            config.backbone_config = backbone_config
        # Load the backbone specified by the configuration
        self.encoder = load_backbone(config)
        # Channel dimensions of the backbone feature maps
        feature_channels = self.encoder.channels
        # Pixel decoder: FPN over the backbone features, followed by the mask projection
        self.decoder = MaskFormerPixelDecoder(
            in_features=feature_channels[-1],
            feature_size=config.fpn_feature_size,
            mask_feature_size=config.mask_feature_size,
            lateral_widths=feature_channels[:-1],
        )
    def forward(
        self, pixel_values: Tensor, output_hidden_states: bool = False, return_dict: bool = True
    ) -> MaskFormerPixelLevelModuleOutput:
        # Run the image through the backbone to get the multi-scale feature maps
        features = self.encoder(pixel_values).feature_maps
        # Run the feature maps through the pixel decoder
        decoder_output = self.decoder(features, output_hidden_states, return_dict=return_dict)
        # Return a plain tuple if return_dict is False
        if not return_dict:
            last_hidden_state = decoder_output[0]  # last hidden state of the decoder
            outputs = (features[-1], last_hidden_state)  # last backbone feature map + last decoder hidden state
            if output_hidden_states:
                hidden_states = decoder_output[1]  # all decoder hidden states
                outputs = outputs + (tuple(features),) + (hidden_states,)  # add all feature maps and hidden states
            return outputs
        # Otherwise wrap everything into a MaskFormerPixelLevelModuleOutput
        return MaskFormerPixelLevelModuleOutput(
            encoder_last_hidden_state=features[-1],  # last backbone feature map
            decoder_last_hidden_state=decoder_output.last_hidden_state,  # last decoder hidden state
            encoder_hidden_states=tuple(features) if output_hidden_states else (),  # all backbone feature maps
            decoder_hidden_states=decoder_output.hidden_states if output_hidden_states else (),  # all decoder hidden states
        )
class MaskFormerTransformerModule(nn.Module):
"""
The MaskFormer's transformer module.
"""
def __init__(self, in_features: int, config: MaskFormerConfig):
super().__init__()
hidden_size = config.decoder_config.hidden_size
should_project = in_features != hidden_size
        # Sine position embeddings added to the projected image features
        self.position_embedder = MaskFormerSinePositionEmbedding(num_pos_feats=hidden_size // 2, normalize=True)
        # Learned embeddings for the object queries
        self.queries_embedder = nn.Embedding(config.decoder_config.num_queries, hidden_size)
        # 1x1 convolution projecting the image features to the decoder hidden size, if needed
        self.input_projection = nn.Conv2d(in_features, hidden_size, kernel_size=1) if should_project else None
        # The DETR decoder
        self.decoder = DetrDecoder(config=config.decoder_config)
def forward(
self,
image_features: Tensor,
output_hidden_states: bool = False,
output_attentions: bool = False,
return_dict: Optional[bool] = None,
) -> DetrDecoderOutput:
        if self.input_projection is not None:
            # Project the image features to the decoder hidden size
            image_features = self.input_projection(image_features)
        # Sine position embeddings computed on the (projected) image features
        object_queries = self.position_embedder(image_features)
        # Repeat the learned query embeddings for every element of the batch
        batch_size = image_features.shape[0]
        queries_embeddings = self.queries_embedder.weight.unsqueeze(0).repeat(batch_size, 1, 1)
        # The decoder input embeddings start from zeros and are refined by the decoder
        inputs_embeds = torch.zeros_like(queries_embeddings, requires_grad=True)
        batch_size, num_channels, height, width = image_features.shape
        # Flatten the spatial dimensions so that image features and position embeddings become sequences
        image_features = image_features.view(batch_size, num_channels, height * width).permute(0, 2, 1)
        object_queries = object_queries.view(batch_size, num_channels, height * width).permute(0, 2, 1)
        # Run the DETR decoder
        decoder_output: DetrDecoderOutput = self.decoder(
            inputs_embeds=inputs_embeds,
            attention_mask=None,
            encoder_hidden_states=image_features,
            encoder_attention_mask=None,
            object_queries=object_queries,
            query_position_embeddings=queries_embeddings,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Return the decoder output
        return decoder_output
# MASKFORMER_INPUTS_DOCSTRING documents the inputs shared by the MaskFormer models:
MASKFORMER_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`MaskFormerImageProcessor.__call__`] for details.
        pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:
            - 1 for pixels that are real (i.e. **not masked**),
            - 0 for pixels that are padding (i.e. **masked**).
            [What are attention masks?](../glossary#attention-mask)
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of Detr's decoder attention layers.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~MaskFormerModelOutput`] instead of a plain tuple.
"""
# MaskFormerModel extends MaskFormerPreTrainedModel; the decorator attaches the shared docstring.
@add_start_docstrings(
    "The bare MaskFormer Model outputting raw hidden-states without any specific head on top.",
    MASKFORMER_START_DOCSTRING,
)
class MaskFormerModel(MaskFormerPreTrainedModel):
"""
Initializes a MaskFormerModel instance.
Args:
config (MaskFormerConfig): Configuration object specifying model parameters.
Inherits:
MaskFormerPreTrainedModel: Base class for MaskFormerModel, pre-trained model.
Attributes:
pixel_level_module (MaskFormerPixelLevelModule): Pixel-level module for MaskFormer.
transformer_module (MaskFormerTransformerModule): Transformer module for MaskFormer.
"""
def __init__(self, config: MaskFormerConfig):
"""
Constructor for MaskFormerModel.
Args:
config (MaskFormerConfig): Configuration object specifying model parameters.
Calls super() to initialize from MaskFormerPreTrainedModel, initializes:
- pixel_level_module (MaskFormerPixelLevelModule): Module for pixel-level operations.
- transformer_module (MaskFormerTransformerModule): Transformer module for MaskFormer.
Post-initialization handled by self.post_init().
"""
super().__init__(config)
self.pixel_level_module = MaskFormerPixelLevelModule(config)
self.transformer_module = MaskFormerTransformerModule(
in_features=self.pixel_level_module.encoder.channels[-1], config=config
)
self.post_init()
    # Forward pass of the model
    def forward(
        self,
        # `pixel_values`: the input pixel values
        pixel_values: Tensor,
        # `pixel_mask`: optional mask marking real pixels (1) vs. padding pixels (0)
        pixel_mask: Optional[Tensor] = None,
        # whether to return the hidden states of all layers
        output_hidden_states: Optional[bool] = None,
        # whether to return the attention weights
        output_attentions: Optional[bool] = None,
        # whether to return a `MaskFormerModelOutput` instead of a plain tuple
        return_dict: Optional[bool] = None,
    ) -> MaskFormerModelOutput:
        # Body omitted in this excerpt: runs the pixel level module and the transformer module
        pass
class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
    def __init__(self, config: MaskFormerConfig):
        super().__init__(config)
        # The base MaskFormerModel
        self.model = MaskFormerModel(config)
        # Hidden size of the transformer decoder
        hidden_size = config.decoder_config.hidden_size
        # Linear classifier over num_labels + 1 classes (the extra one is the "no object" class)
        self.class_predictor = nn.Linear(hidden_size, config.num_labels + 1)
        # MLP head mapping each query to a mask embedding
        self.mask_embedder = MaskformerMLPPredictionHead(hidden_size, hidden_size, config.mask_feature_size)
        # Hungarian matcher used to assign predictions to ground-truth masks
        self.matcher = MaskFormerHungarianMatcher(
            cost_class=1.0, cost_dice=config.dice_weight, cost_mask=config.mask_weight
        )
        # Loss weights used by MaskFormerLoss
        self.weight_dict: Dict[str, float] = {
            "loss_cross_entropy": config.cross_entropy_weight,
            "loss_mask": config.mask_weight,
            "loss_dice": config.dice_weight,
        }
        # The MaskFormer loss criterion
        self.criterion = MaskFormerLoss(
            config.num_labels,
            matcher=self.matcher,
            weight_dict=self.weight_dict,
            eos_coef=config.no_object_weight,
        )
        # Final weight initialization and post-processing
        self.post_init()
    # Computes the weighted loss dictionary
    def get_loss_dict(
        self,
        masks_queries_logits: Tensor,
        class_queries_logits: Tensor,
        mask_labels: Tensor,
        class_labels: Tensor,
        auxiliary_logits: Dict[str, Tensor],
    ) -> Dict[str, Tensor]:
        loss_dict: Dict[str, Tensor] = self.criterion(
            masks_queries_logits, class_queries_logits, mask_labels, class_labels, auxiliary_logits
        )
        # Weight each loss according to the weight dictionary
        for key, weight in self.weight_dict.items():
            for loss_key, loss in loss_dict.items():
                if key in loss_key:
                    loss *= weight
        return loss_dict
    # Sums the weighted losses into a single scalar
    def get_loss(self, loss_dict: Dict[str, Tensor]) -> Tensor:
        return sum(loss_dict.values())
    # Forward pass taking pixel values and, optionally, mask and class labels
@add_start_docstrings_to_model_forward(MASKFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MaskFormerForInstanceSegmentationOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Tensor,
mask_labels: Optional[List[Tensor]] = None,
class_labels: Optional[List[Tensor]] = None,
pixel_mask: Optional[Tensor] = None,
output_auxiliary_logits: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
    ):
        # The rest of the forward pass is omitted in this excerpt
        pass
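# Illustrative end-to-end usage sketch (not part of the original file). It follows the standard
# Transformers API for this model family; the checkpoint name comes from _CHECKPOINT_FOR_DOC above
# and "example.jpg" is a placeholder image path.
import torch
from PIL import Image
from transformers import AutoImageProcessor, MaskFormerForInstanceSegmentation

image = Image.open("example.jpg").convert("RGB")
processor = AutoImageProcessor.from_pretrained("facebook/maskformer-swin-base-ade")
model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-ade")
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)  # MaskFormerForInstanceSegmentationOutput
# class_queries_logits: (batch, num_queries, num_labels + 1); masks_queries_logits: (batch, num_queries, h, w)
semantic_map = processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
print(semantic_map.shape)  # (height, width) map of predicted class ids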
.\models\maskformer\modeling_maskformer_swin.py
"""MaskFormer Swin Transformer. The reason Swin Transformer is implemented here is because MaskFormer uses the hidden
states before downsampling, which is different from the default Swin Transformer."""
import collections.abc
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from ...activations import ACT2FN
from ...file_utils import ModelOutput
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils.backbone_utils import BackboneMixin
from .configuration_maskformer_swin import MaskFormerSwinConfig
@dataclass
class MaskFormerSwinModelOutputWithPooling(ModelOutput):
"""
Class for MaskFormerSwinModel's outputs that also contains the spatial dimensions of the hidden states.
"""
"""
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
模型最后一层的隐藏状态序列。
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
经过平均池化操作后的最后一层隐藏状态。
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
一个元组,包含每一层的隐藏状态,形状为 `(batch_size, sequence_length, hidden_size)`。
模型在每一层的输出隐藏状态,以及初始嵌入输出。
hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*):
包含每个隐藏状态的空间维度元组,用于将 `hidden_states` 重塑为 `batch, channels, height, width` 的形式。
由于填充存在,无法在 `forward` 方法之前推断它们的空间大小。
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
一个元组,包含每一层的注意力权重,形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
"""
last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MaskFormerSwinBaseModelOutput(ModelOutput):
"""
    Class for outputs of the MaskFormerSwin encoder.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*):
            A tuple containing the spatial dimension of each `hidden_state`, needed to reshape the `hidden_states` to
            `batch, channels, height, width`. Due to padding, their spatial size cannot be inferred before the
            `forward` method.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
def window_partition(input_feature, window_size):
"""
将给定输入分割为窗口。
"""
batch_size, height, width, num_channels = input_feature.shape
input_feature = input_feature.view(
batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
)
windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
return windows
def window_reverse(windows, window_size, height, width):
"""
合并窗口以产生更高分辨率的特征。
"""
num_channels = windows.shape[-1]
windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
return windows
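# Illustrative round-trip check (not part of the original file): partitioning followed by merging
# recovers the original tensor when height and width are multiples of the window size.
import torch

x = torch.randn(2, 28, 28, 96)                     # (batch, height, width, channels)
windows = window_partition(x, window_size=7)        # (2 * 4 * 4, 7, 7, 96)
restored = window_reverse(windows, window_size=7, height=28, width=28)
print(windows.shape, torch.allclose(x, restored))   # torch.Size([32, 7, 7, 96]) True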
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (stochastic depth) per sample, applied in the main path of residual blocks.
    Args:
        input (torch.Tensor): The input tensor.
        drop_prob (float, optional): The drop probability. Defaults to 0.0.
        training (bool, optional): Whether the model is in training mode. Defaults to False.
    Returns:
        torch.Tensor: The processed tensor.
    """
if drop_prob == 0.0 or not training:
return input
keep_prob = 1 - drop_prob
shape = (input.shape[0],) + (1,) * (input.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
random_tensor.floor_()
output = input.div(keep_prob) * random_tensor
return output
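# Illustrative behaviour check (not part of the original file): in eval mode (training=False) the
# input passes through unchanged; in training mode roughly drop_prob of the samples are zeroed out
# and the survivors are rescaled by 1 / (1 - drop_prob) so the expected value is preserved.
import torch

x = torch.ones(8, 4)
print(torch.equal(drop_path(x, drop_prob=0.5, training=False), x))  # True
out = drop_path(x, drop_prob=0.5, training=True)
print(sorted(out.unique().tolist()))  # subset of [0.0, 2.0]: dropped rows are 0, kept rows are scaled to 2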
class MaskFormerSwinEmbeddings(nn.Module):
"""
Construct the patch and position embeddings.
"""
def __init__(self, config):
super().__init__()
self.patch_embeddings = MaskFormerSwinPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.patch_grid = self.patch_embeddings.grid_size
if config.use_absolute_embeddings:
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
else:
self.position_embeddings = None
self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, pixel_values):
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
embeddings = self.norm(embeddings)
if self.position_embeddings is not None:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings, output_dimensions
class MaskFormerSwinPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.embed_dim
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def maybe_pad(self, pixel_values, height, width):
if width % self.patch_size[1] != 0:
pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
pixel_values = nn.functional.pad(pixel_values, pad_values)
if height % self.patch_size[0] != 0:
pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
_, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
pixel_values = self.maybe_pad(pixel_values, height, width)
embeddings = self.projection(pixel_values)
_, _, height, width = embeddings.shape
output_dimensions = (height, width)
embeddings = embeddings.flatten(2).transpose(1, 2)
return embeddings, output_dimensions
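A short shape walkthrough may help here. The sketch below assumes the installed `transformers` package that these files come from, and the default `MaskFormerSwinConfig` values (224x224 images, `patch_size=4`, `embed_dim=96`):
```
import torch
from transformers import MaskFormerSwinConfig
from transformers.models.maskformer.modeling_maskformer_swin import MaskFormerSwinPatchEmbeddings

config = MaskFormerSwinConfig()
patch_embed = MaskFormerSwinPatchEmbeddings(config)
pixel_values = torch.randn(1, config.num_channels, 224, 224)
embeddings, (height, width) = patch_embed(pixel_values)
print(embeddings.shape)                           # torch.Size([1, 3136, 96]): a 56x56 grid of patch tokens
print(height, width)                              # 56 56
```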
class MaskFormerSwinPatchMerging(nn.Module):
"""
Patch Merging Layer.
Args:
input_resolution (`Tuple[int]`):
Resolution of input feature.
dim (`int`):
Number of input channels.
norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
Normalization layer class.
"""
def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def maybe_pad(self, input_feature, height, width):
should_pad = (height % 2 == 1) or (width % 2 == 1)
if should_pad:
pad_values = (0, 0, 0, width % 2, 0, height % 2)
input_feature = nn.functional.pad(input_feature, pad_values)
return input_feature
def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
height, width = input_dimensions
batch_size, dim, num_channels = input_feature.shape
input_feature = input_feature.view(batch_size, height, width, num_channels)
input_feature = self.maybe_pad(input_feature, height, width)
input_feature_0 = input_feature[:, 0::2, 0::2, :]
input_feature_1 = input_feature[:, 1::2, 0::2, :]
input_feature_2 = input_feature[:, 0::2, 1::2, :]
input_feature_3 = input_feature[:, 1::2, 1::2, :]
input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
input_feature = input_feature.view(batch_size, -1, 4 * num_channels)
input_feature = self.norm(input_feature)
input_feature = self.reduction(input_feature)
return input_feature
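Patch merging halves the spatial resolution and doubles the channel dimension (the concatenated `4 * dim` features are projected down to `2 * dim`). An illustrative shape check, under the same installed-library assumption as above:
```
import torch
from transformers.models.maskformer.modeling_maskformer_swin import MaskFormerSwinPatchMerging

dim, height, width = 96, 56, 56
merging = MaskFormerSwinPatchMerging(input_resolution=(height, width), dim=dim)
tokens = torch.randn(2, height * width, dim)
merged = merging(tokens, (height, width))
print(merged.shape)                               # torch.Size([2, 784, 192]): (56/2)*(56/2) tokens, 2*dim channels
```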
class MaskFormerSwinDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: Optional[float] = None) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return drop_path(hidden_states, self.drop_prob, self.training)
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
class MaskFormerSwinSelfAttention(nn.Module):
def __init__(self, config, dim, num_heads, window_size):
super().__init__()
if dim % num_heads != 0:
raise ValueError(
f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
)
self.num_attention_heads = num_heads
self.attention_head_size = int(dim / num_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.window_size = (
window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
)
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
)
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)
self.register_buffer("relative_position_index", relative_position_index)
self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
relative_position_bias = relative_position_bias.view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
if attention_mask is not None:
mask_shape = attention_mask.shape[0]
attention_scores = attention_scores.view(
batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
)
attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
class MaskFormerSwinSelfOutput(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(dim, dim)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class MaskFormerSwinAttention(nn.Module):
def __init__(self, config, dim, num_heads, window_size):
super().__init__()
self.self = MaskFormerSwinSelfAttention(config, dim, num_heads, window_size)
self.output = MaskFormerSwinSelfOutput(config, dim)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
class MaskFormerSwinIntermediate(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class MaskFormerSwinOutput(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class MaskFormerSwinLayer(nn.Module):
def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
super().__init__()
self.shift_size = shift_size
self.window_size = config.window_size
self.input_resolution = input_resolution
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.attention = MaskFormerSwinAttention(config, dim, num_heads, self.window_size)
self.drop_path = (
MaskFormerSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
)
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.intermediate = MaskFormerSwinIntermediate(config, dim)
self.output = MaskFormerSwinOutput(config, dim)
def get_attn_mask(self, input_resolution):
if self.shift_size > 0:
height, width = input_resolution
img_mask = torch.zeros((1, height, width, 1))
height_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
width_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
count = 0
for height_slice in height_slices:
for width_slice in width_slices:
img_mask[:, height_slice, width_slice, :] = count
count += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
return attn_mask
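To see what this mask looks like concretely, the self-contained snippet below (illustrative only; the sizes are assumptions) rebuilds it for an 8x8 feature map with `window_size=4` and `shift_size=2`: token pairs from different shifted regions get -100, same-region pairs get 0.
```
import torch

height = width = 8
window_size, shift_size = 4, 2
img_mask = torch.zeros((1, height, width, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
count = 0
for hs in slices:
    for ws in slices:
        img_mask[:, hs, ws, :] = count            # label the nine shifted regions 0..8
        count += 1
# same reshape as `window_partition`, then pairwise region comparison per window
mask_windows = (
    img_mask.view(1, height // window_size, window_size, width // window_size, window_size, 1)
    .permute(0, 1, 3, 2, 4, 5)
    .reshape(-1, window_size * window_size)
)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
print(attn_mask.shape)                            # torch.Size([4, 16, 16]): one mask per window
print((attn_mask[0] == 0).all().item())           # True: the top-left window spans a single region
```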
def maybe_pad(self, hidden_states, height, width):
pad_left = pad_top = 0
pad_right = (self.window_size - width % self.window_size) % self.window_size
pad_bottom = (self.window_size - height % self.window_size) % self.window_size
pad_values = (0, 0, pad_left, pad_right, pad_top, pad_bottom)
hidden_states = nn.functional.pad(hidden_states, pad_values)
return hidden_states, pad_values
def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False):
height, width = input_dimensions
batch_size, dim, channels = hidden_states.size()
shortcut = hidden_states
hidden_states = self.layernorm_before(hidden_states)
hidden_states = hidden_states.view(batch_size, height, width, channels)
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
_, height_pad, width_pad, _ = hidden_states.shape
if self.shift_size > 0:
shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_hidden_states = hidden_states
hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
attn_mask = self.get_attn_mask((height_pad, width_pad))
if attn_mask is not None:
attn_mask = attn_mask.to(hidden_states_windows.device)
self_attention_outputs = self.attention(
hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:]
attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
shifted_windows = window_reverse(
attention_windows, self.window_size, height_pad, width_pad
)
if self.shift_size > 0:
attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
attention_windows = shifted_windows
was_padded = pad_values[3] > 0 or pad_values[5] > 0
if was_padded:
attention_windows = attention_windows[:, :height, :width, :].contiguous()
attention_windows = attention_windows.view(batch_size, height * width, channels)
hidden_states = shortcut + self.drop_path(attention_windows)
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = hidden_states + self.output(layer_output)
outputs = (layer_output,) + outputs
return outputs
class MaskFormerSwinStage(nn.Module):
def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
super().__init__()
self.config = config
self.dim = dim
self.blocks = nn.ModuleList(
[
MaskFormerSwinLayer(
config=config,
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
shift_size=0 if (i % 2 == 0) else config.window_size // 2,
)
for i in range(depth)
]
)
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
else:
self.downsample = None
self.pointing = False
def forward(
self, hidden_states, input_dimensions, head_mask=None, output_attentions=False, output_hidden_states=False
):
all_hidden_states = () if output_hidden_states else None
height, width = input_dimensions
for i, block_module in enumerate(self.blocks):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
block_hidden_states = block_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
hidden_states = block_hidden_states[0]
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.downsample is not None:
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
output_dimensions = (height, width, height_downsampled, width_downsampled)
hidden_states = self.downsample(hidden_states, input_dimensions)
else:
output_dimensions = (height, width, height, width)
return hidden_states, output_dimensions, all_hidden_states
class MaskFormerSwinEncoder(nn.Module):
def __init__(self, config, grid_size):
super().__init__()
self.num_layers = len(config.depths)
self.config = config
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
self.layers = nn.ModuleList(
[
MaskFormerSwinStage(
config=config,
dim=int(config.embed_dim * 2**i_layer),
input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
depth=config.depths[i_layer],
num_heads=config.num_heads[i_layer],
drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
downsample=MaskFormerSwinPatchMerging if (i_layer < self.num_layers - 1) else None,
)
for i_layer in range(self.num_layers)
]
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
input_dimensions,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
all_hidden_states = () if output_hidden_states else None
all_input_dimensions = ()
all_self_attentions = () if output_attentions else None
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
for i, layer_module in enumerate(self.layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_hidden_states, output_dimensions, layer_all_hidden_states = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module(
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
output_hidden_states,
)
input_dimensions = (output_dimensions[-2], output_dimensions[-1])
all_input_dimensions += (input_dimensions,)
if output_hidden_states:
all_hidden_states += (layer_all_hidden_states,)
hidden_states = layer_hidden_states
if output_attentions:
all_self_attentions = all_self_attentions + (layer_all_hidden_states[1],)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return MaskFormerSwinBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
hidden_states_spatial_dimensions=all_input_dimensions,
attentions=all_self_attentions,
)
class MaskFormerSwinPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = MaskFormerSwinConfig
base_model_prefix = "model"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
self.num_layers = len(config.depths)
self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
self.embeddings = MaskFormerSwinEmbeddings(config)
self.encoder = MaskFormerSwinEncoder(config, self.embeddings.patch_grid)
self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def forward(
self,
pixel_values=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
embedding_output, input_dimensions = self.embeddings(pixel_values)
encoder_outputs = self.encoder(
embedding_output,
input_dimensions,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs.last_hidden_state if return_dict else encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
pooled_output = None
if self.pooler is not None:
pooled_output = self.pooler(sequence_output.transpose(1, 2))
pooled_output = torch.flatten(pooled_output, 1)
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
hidden_states_spatial_dimensions = (input_dimensions,) + encoder_outputs.hidden_states_spatial_dimensions
return MaskFormerSwinModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
hidden_states_spatial_dimensions=hidden_states_spatial_dimensions,
attentions=encoder_outputs.attentions,
)
"""
MaskFormerSwin backbone, designed especially for the MaskFormer framework.
This classes reshapes `hidden_states` from (`batch_size, sequence_length, hidden_size)` to (`batch_size,
num_channels, height, width)`). It also adds additional layernorms after each stage.
Args:
config (`MaskFormerSwinConfig`):
The configuration used by [`MaskFormerSwinModel`].
"""
def __init__(self, config: MaskFormerSwinConfig):
super().__init__(config)
super()._init_backbone(config)
self.model = MaskFormerSwinModel(config)
if "stem" in self.out_features:
raise ValueError("This backbone does not support 'stem' in the `out_features`.")
self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
self.hidden_states_norms = nn.ModuleList(
[nn.LayerNorm(num_channels) for num_channels in self.num_features[1:]]
)
self.post_init()
def forward(
self,
pixel_values: Tensor,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> BackboneOutput:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
outputs = self.model(
pixel_values, output_hidden_states=True, output_attentions=output_attentions, return_dict=True
)
hidden_states = outputs.hidden_states[1:]
spatial_dimensions: Tuple[Tuple[int, int]] = outputs.hidden_states_spatial_dimensions
feature_maps = ()
for i, (hidden_state, stage, (height, width)) in enumerate(
zip(hidden_states, self.stage_names[1:], spatial_dimensions)
):
norm = self.hidden_states_norms[i]
hidden_state_unpolled = hidden_state[-1]
hidden_state_norm = norm(hidden_state_unpolled)
batch_size, _, hidden_size = hidden_state_norm.shape
hidden_state_permuted = (
hidden_state_norm.permute(0, 2, 1).view((batch_size, hidden_size, height, width)).contiguous()
)
if stage in self.out_features:
feature_maps += (hidden_state_permuted,)
if not return_dict:
output = (feature_maps,)
if output_hidden_states:
output += (outputs.hidden_states,)
if output_attentions:
output += (outputs.attentions,)
return output
return BackboneOutput(
feature_maps=feature_maps,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)
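A minimal usage sketch for the backbone (randomly initialised, no checkpoint download): the class and config names are the ones exported by the package `__init__` below, and passing `out_features` to the config is assumed to work as for the library's other backbone configs. `out_features` selects which stages are returned as feature maps.
```
import torch
from transformers.models.maskformer import MaskFormerSwinBackbone, MaskFormerSwinConfig

config = MaskFormerSwinConfig(out_features=["stage1", "stage2", "stage3", "stage4"])
backbone = MaskFormerSwinBackbone(config)
pixel_values = torch.randn(1, config.num_channels, 224, 224)
outputs = backbone(pixel_values)
for feature_map in outputs.feature_maps:
    print(feature_map.shape)
# expected with the default config: (1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)
```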
.\models\maskformer\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {
"configuration_maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"],
"configuration_maskformer_swin": ["MaskFormerSwinConfig"],
}
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["feature_extraction_maskformer"] = ["MaskFormerFeatureExtractor"]
_import_structure["image_processing_maskformer"] = ["MaskFormerImageProcessor"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_maskformer"] = [
"MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"MaskFormerForInstanceSegmentation",
"MaskFormerModel",
"MaskFormerPreTrainedModel",
]
_import_structure["modeling_maskformer_swin"] = [
"MaskFormerSwinBackbone",
"MaskFormerSwinModel",
"MaskFormerSwinPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig
from .configuration_maskformer_swin import MaskFormerSwinConfig
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .feature_extraction_maskformer import MaskFormerFeatureExtractor
from .image_processing_maskformer import MaskFormerImageProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_maskformer import (
MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
MaskFormerForInstanceSegmentation,
MaskFormerModel,
MaskFormerPreTrainedModel,
)
from .modeling_maskformer_swin import (
MaskFormerSwinBackbone,
MaskFormerSwinModel,
MaskFormerSwinPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
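Illustrative effect of the `_LazyModule` registration above: importing a name from `transformers.models.maskformer` resolves it on first access, so the heavy torch/vision submodules are only loaded when a name that needs them is actually requested.
```
from transformers.models.maskformer import MaskFormerConfig

config = MaskFormerConfig()
print(config.model_type)                          # "maskformer"
```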
.\models\mbart\configuration_mbart.py
""" MBART model configuration"""
from collections import OrderedDict
from typing import Any, Mapping, Optional
from ... import PreTrainedTokenizer
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
from ...onnx.utils import compute_effective_axis_dimension
from ...utils import TensorType, is_torch_available, logging
logger = logging.get_logger(__name__)
MBART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/mbart-large-cc25": "https://huggingface.co/facebook/mbart-large-cc25/resolve/main/config.json",
}
class MBartConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MBartModel`]. It is used to instantiate an MBART
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the MBART
[facebook/mbart-large-cc25](https://huggingface.co/facebook/mbart-large-cc25) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```
>>> from transformers import MBartConfig, MBartModel
>>> # Initializing a MBART facebook/mbart-large-cc25 style configuration
>>> configuration = MBartConfig()
>>> # Initializing a model (with random weights) from the facebook/mbart-large-cc25 style configuration
>>> model = MBartModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "mbart"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
def __init__(
self,
vocab_size=50265,
max_position_embeddings=1024,
encoder_layers=12,
encoder_ffn_dim=4096,
encoder_attention_heads=16,
decoder_layers=12,
decoder_ffn_dim=4096,
decoder_attention_heads=16,
encoder_layerdrop=0.0,
decoder_layerdrop=0.0,
use_cache=True,
is_encoder_decoder=True,
activation_function="gelu",
d_model=1024,
dropout=0.1,
attention_dropout=0.0,
activation_dropout=0.0,
init_std=0.02,
classifier_dropout=0.0,
scale_embedding=False,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
forced_eos_token_id=2,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.decoder_ffn_dim = decoder_ffn_dim
self.decoder_layers = decoder_layers
self.decoder_attention_heads = decoder_attention_heads
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.activation_function = activation_function
self.init_std = init_std
self.encoder_layerdrop = encoder_layerdrop
self.decoder_layerdrop = decoder_layerdrop
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.scale_embedding = scale_embedding
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
forced_eos_token_id=forced_eos_token_id,
**kwargs
)
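One consequence of the `attribute_map` defined above: the library-wide generic attribute names are transparently aliased onto MBART's own parameter names. For example:
```
from transformers import MBartConfig

config = MBartConfig(d_model=512, encoder_attention_heads=8, decoder_attention_heads=8)
print(config.hidden_size)                         # 512: aliased to d_model
print(config.num_attention_heads)                 # 8: aliased to encoder_attention_heads
```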
class MBartOnnxConfig(OnnxSeq2SeqConfigWithPast):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task in ["default", "seq2seq-lm"]:
common_inputs = OrderedDict(
[
("input_ids", {0: "batch", 1: "encoder_sequence"}),
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
]
)
if self.use_past:
common_inputs["decoder_input_ids"] = {0: "batch"}
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
else:
common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
if self.use_past:
self.fill_with_past_key_values_(common_inputs, direction="inputs")
elif self.task == "causal-lm":
common_inputs = OrderedDict(
[
("input_ids", {0: "batch", 1: "encoder_sequence"}),
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
]
)
if self.use_past:
num_encoder_layers, _ = self.num_layers
for i in range(num_encoder_layers):
common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
else:
common_inputs = OrderedDict(
[
("input_ids", {0: "batch", 1: "encoder_sequence"}),
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
]
)
return common_inputs
@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task in ["default", "seq2seq-lm"]:
common_outputs = super().outputs
else:
common_outputs = super(OnnxConfigWithPast, self).outputs
if self.use_past:
num_encoder_layers, _ = self.num_layers
for i in range(num_encoder_layers):
common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
return common_outputs
def _generate_dummy_inputs_for_default_and_seq2seq_lm(
self,
tokenizer: PreTrainedTokenizer,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
tokenizer, batch_size, seq_length, is_pair, framework
)
decoder_seq_length = seq_length if not self.use_past else 1
decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
tokenizer, batch_size, decoder_seq_length, is_pair, framework
)
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
common_inputs = dict(**encoder_inputs, **decoder_inputs)
if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch
batch, encoder_seq_length = common_inputs["input_ids"].shape
decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
encoder_shape = (
batch,
num_encoder_attention_heads,
encoder_seq_length,
self._config.hidden_size // num_encoder_attention_heads,
)
decoder_past_length = decoder_seq_length + 3
decoder_shape = (
batch,
num_decoder_attention_heads,
decoder_past_length,
self._config.hidden_size // num_decoder_attention_heads,
)
common_inputs["decoder_attention_mask"] = torch.cat(
[common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
)
common_inputs["past_key_values"] = []
num_encoder_layers, num_decoder_layers = self.num_layers
min_num_layers = min(num_encoder_layers, num_decoder_layers)
max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
for _ in range(min_num_layers):
common_inputs["past_key_values"].append(
(
torch.zeros(decoder_shape),
torch.zeros(decoder_shape),
torch.zeros(encoder_shape),
torch.zeros(encoder_shape),
)
)
shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
for _ in range(min_num_layers, max_num_layers):
common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
return common_inputs
def _generate_dummy_inputs_for_causal_lm(
self,
tokenizer: PreTrainedTokenizer,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
tokenizer, batch_size, seq_length, is_pair, framework
)
if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch
batch, seqlen = common_inputs["input_ids"].shape
past_key_values_length = seqlen + 2
num_encoder_layers, _ = self.num_layers
num_encoder_attention_heads, _ = self.num_attention_heads
past_shape = (
batch,
num_encoder_attention_heads,
past_key_values_length,
self._config.hidden_size // num_encoder_attention_heads,
)
mask_dtype = common_inputs["attention_mask"].dtype
common_inputs["attention_mask"] = torch.cat(
[common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
)
common_inputs["past_key_values"] = [
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
]
return common_inputs
def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
self,
tokenizer: PreTrainedTokenizer,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
batch_size = compute_effective_axis_dimension(
batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
)
token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
seq_length = compute_effective_axis_dimension(
seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
)
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
return common_inputs
def generate_dummy_inputs(
self,
tokenizer: PreTrainedTokenizer,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
if self.task in ["default", "seq2seq-lm"]:
common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm(
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
)
elif self.task == "causal-lm":
common_inputs = self._generate_dummy_inputs_for_causal_lm(
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
)
else:
common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
)
return common_inputs
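A sketch of how `generate_dummy_inputs` is typically driven for the default seq2seq export task (it assumes the `facebook/mbart-large-cc25` tokenizer can be downloaded and that `sentencepiece` is installed):
```
from transformers import AutoTokenizer, MBartConfig
from transformers.models.mbart.configuration_mbart import MBartOnnxConfig
from transformers.utils import TensorType

tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
onnx_config = MBartOnnxConfig(MBartConfig(), task="default")
dummy_inputs = onnx_config.generate_dummy_inputs(
    tokenizer, batch_size=2, seq_length=8, framework=TensorType.PYTORCH
)
print(sorted(dummy_inputs.keys()))
# ['attention_mask', 'decoder_attention_mask', 'decoder_input_ids', 'input_ids']
```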
def _flatten_past_key_values_(self, flattened_output, name, idx, t):
if self.task in ["default", "seq2seq-lm"]:
flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
else:
flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
flattened_output, name, idx, t
)
.\models\mbart\convert_mbart_original_checkpoint_to_pytorch.py
import argparse
import torch
from torch import nn
from transformers import MBartConfig, MBartForConditionalGeneration
def remove_ignore_keys_(state_dict):
ignore_keys = [
"encoder.version",
"decoder.version",
"model.encoder.version",
"model.decoder.version",
"_float_tensor",
"decoder.output_projection.weight",
]
for k in ignore_keys:
state_dict.pop(k, None)
def make_linear_from_emb(emb):
vocab_size, emb_size = emb.weight.shape
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
lin_layer.weight.data = emb.weight.data
return lin_layer
def convert_fairseq_mbart_checkpoint_from_disk(
checkpoint_path, hf_config_path="facebook/mbart-large-en-ro", finetuned=False, mbart_50=False
):
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
remove_ignore_keys_(state_dict)
vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
mbart_config = MBartConfig.from_pretrained(hf_config_path, vocab_size=vocab_size)
if mbart_50 and finetuned:
mbart_config.activation_function = "relu"
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
model = MBartForConditionalGeneration(mbart_config)
model.model.load_state_dict(state_dict)
if finetuned:
model.lm_head = make_linear_from_emb(model.model.shared)
return model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem."
)
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument(
"--hf_config",
default="facebook/mbart-large-cc25",
type=str,
help="Which huggingface architecture to use: mbart-large",
)
parser.add_argument("--mbart_50", action="store_true", help="whether the model is mMART-50 checkpoint")
parser.add_argument("--finetuned", action="store_true", help="whether the model is a fine-tuned checkpoint")
args = parser.parse_args()
model = convert_fairseq_mbart_checkpoint_from_disk(
args.fairseq_path, hf_config_path=args.hf_config, finetuned=args.finetuned, mbart_50=args.mbart_50
)
model.save_pretrained(args.pytorch_dump_folder_path)
.\models\mbart\modeling_flax_mbart.py
""" Flax MBart model."""
import math
import random
from functools import partial
from typing import Callable, Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxCausalLMOutputWithCrossAttentions,
FlaxSeq2SeqLMOutput,
FlaxSeq2SeqModelOutput,
FlaxSeq2SeqQuestionAnsweringModelOutput,
FlaxSeq2SeqSequenceClassifierOutput,
)
from ...modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_call_sample_docstring,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_mbart import MBartConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "facebook/mbart-large-cc25"
_CONFIG_FOR_DOC = "MBartConfig"
MBART_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a Flax Linen
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`MBartConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and `jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If specified, all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of the model parameters.**
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and [`~FlaxPreTrainedModel.to_bf16`].
"""
MBART_INPUTS_DOCSTRING = r"""
"""
MBART_ENCODE_INPUTS_DOCSTRING = r"""
Args:
input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
MBART_DECODE_INPUTS_DOCSTRING = r"""
"""
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right, and wrap the last non-pad token (the <LID> token). Note that, unlike other Bart-like models, MBart does not have a single `decoder_start_token_id`.
prev_output_tokens = jnp.array(input_ids).copy()
if pad_token_id is None:
raise ValueError("self.model.config.pad_token_id has to be defined.")
prev_output_tokens = jnp.where(prev_output_tokens == -100, pad_token_id, input_ids)
index_of_eos = (jnp.where(prev_output_tokens != pad_token_id, 1, 0).sum(axis=-1) - 1).reshape(-1, 1)
decoder_start_tokens = jnp.array(
[prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)], dtype=jnp.int32
).squeeze()
prev_output_tokens = prev_output_tokens.at[:, 1:].set(prev_output_tokens[:, :-1])
prev_output_tokens = prev_output_tokens.at[:, 0].set(decoder_start_tokens)
return prev_output_tokens
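An illustrative call with made-up token ids (2 stands for `</s>`, 250004 stands in for a language-id `<LID>` token, 1 is pad): the last non-pad token is rotated to the front, turning `X </s> <LID>` into `<LID> X </s>`.
```
import jax.numpy as jnp

input_ids = jnp.array([[10, 11, 12, 2, 250004]])
print(shift_tokens_right(input_ids, pad_token_id=1))
# [[250004     10     11     12      2]]
```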
class FlaxMBartAttention(nn.Module):
config: MBartConfig
embed_dim: int
num_heads: int
dropout: float = 0.0
causal: bool = False
bias: bool = True
def setup(self) -> None:
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {self.num_heads})."
)
dense = partial(
nn.Dense,
self.embed_dim,
use_bias=self.bias,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
self.out_proj = dense()
self.dropout_layer = nn.Dropout(rate=self.dropout)
if self.causal:
self.causal_mask = make_causal_mask(
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
)
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
@nn.compact
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slightly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
"""
is_initialized = self.has_variable("cache", "cached_key")
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
cur_index = cache_index.value
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
key = lax.dynamic_update_slice(cached_key.value, key, indices)
value = lax.dynamic_update_slice(cached_value.value, value, indices)
cached_key.value = key
cached_value.value = value
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
)
attention_mask = combine_masks(pad_mask, attention_mask)
return key, value, attention_mask
class FlaxMBartEncoderLayer(nn.Module):
config: MBartConfig
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.embed_dim = self.config.d_model
self.self_attn = FlaxMBartAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.encoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
hidden_states: jnp.ndarray,
attention_mask: jnp.ndarray,
output_attentions: bool = True,
deterministic: bool = True,
) -> Tuple[jnp.ndarray]:
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class FlaxMBartEncoderLayerCollection(nn.Module):
config: MBartConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.layers = [
FlaxMBartEncoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.encoder_layers)
]
self.layerdrop = self.config.encoder_layerdrop
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for encoder_layer in self.layers:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
dropout_probability = random.uniform(0, 1)
if not deterministic and (dropout_probability < self.layerdrop):
layer_outputs = (None, None)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions,
deterministic,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = (hidden_states, all_hidden_states, all_attentions)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
class FlaxMBartDecoderLayer(nn.Module):
config: MBartConfig
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.embed_dim = self.config.d_model
self.self_attn = FlaxMBartAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
causal=True,
dtype=self.dtype,
)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.encoder_attn = FlaxMBartAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.fc1 = nn.Dense(
self.config.decoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
hidden_states: jnp.ndarray,
attention_mask: jnp.ndarray,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
output_attentions: bool = True,
deterministic: bool = True,
) -> Tuple[jnp.ndarray]:
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
hidden_states, cross_attn_weights = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
return outputs
class FlaxMBartDecoderLayerCollection(nn.Module):
config: MBartConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.layers = [
FlaxMBartDecoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.decoder_layers)
]
self.layerdrop = self.config.decoder_layerdrop
def __call__(
self,
hidden_states,
attention_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if not deterministic and (dropout_probability < self.layerdrop):
layer_outputs = (None, None, None)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
output_attentions=output_attentions,
deterministic=deterministic,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
"""Head for sentence-level classification tasks."""
config: MBartConfig
inner_dim: int
num_classes: int
pooler_dropout: float
dtype: jnp.dtype = jnp.float32
def setup(self):
self.dense = nn.Dense(
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.dropout = nn.Dropout(rate=self.pooler_dropout)
self.out_proj = nn.Dense(
self.num_classes,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
def __call__(self, hidden_states: jnp.ndarray, deterministic: bool):
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.dense(hidden_states)
hidden_states = jnp.tanh(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.out_proj(hidden_states)
return hidden_states
class FlaxMBartEncoder(nn.Module):
config: MBartConfig
embed_tokens: nn.Embed
dtype: jnp.dtype = jnp.float32
def setup(self):
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
embed_dim = self.config.d_model
self.padding_idx = self.config.pad_token_id
self.max_source_positions = self.config.max_position_embeddings
self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
self.offset = 2
self.embed_positions = nn.Embed(
self.config.max_position_embeddings + self.offset,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.layers = FlaxMBartEncoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
input_ids,
attention_mask,
position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
input_shape = input_ids.shape
input_ids = input_ids.reshape(-1, input_shape[-1])
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(position_ids + self.offset)
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
outputs = self.layers(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_states = outputs[0]
last_hidden_states = self.layer_norm(last_hidden_states)
hidden_states = None
if output_hidden_states:
hidden_states = outputs[1]
hidden_states = hidden_states[:-1] + (last_hidden_states,)
if not return_dict:
outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=last_hidden_states,
hidden_states=hidden_states,
attentions=outputs.attentions,
)
class FlaxMBartDecoder(nn.Module):
config: MBartConfig
embed_tokens: nn.Embed
dtype: jnp.dtype = jnp.float32
def setup(self):
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
embed_dim = self.config.d_model
self.padding_idx = self.config.pad_token_id
self.max_target_positions = self.config.max_position_embeddings
self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
self.offset = 2
self.embed_positions = nn.Embed(
self.config.max_position_embeddings + self.offset,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.layers = FlaxMBartDecoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
input_ids,
attention_mask,
position_ids,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
input_shape = input_ids.shape
input_ids = input_ids.reshape(-1, input_shape[-1])
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
positions = self.embed_positions(position_ids + self.offset)
hidden_states = inputs_embeds + positions
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
outputs = self.layers(
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_states = outputs[0]
last_hidden_states = self.layer_norm(last_hidden_states)
hidden_states = None
if output_hidden_states:
hidden_states = outputs[1]
hidden_states = hidden_states[:-1] + (last_hidden_states,)
if not return_dict:
outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=last_hidden_states,
hidden_states=hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
class FlaxMBartModule(nn.Module):
config: MBartConfig
def setup(self):
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
)
self.encoder = FlaxMBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
self.decoder = FlaxMBartDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
def _get_encoder_module(self):
return self.encoder
def _get_decoder_module(self):
return self.decoder
def __call__(
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
position_ids,
decoder_position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
if not return_dict:
return decoder_outputs + encoder_outputs
return FlaxSeq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
class FlaxMBartPreTrainedModel(FlaxPreTrainedModel):
config_class = MBartConfig
base_model_prefix: str = "model"
def __init__(
self,
config: MBartConfig,
input_shape: Tuple[int] = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
input_ids = jnp.zeros(input_shape, dtype="i4")
input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
attention_mask = jnp.ones_like(input_ids)
decoder_input_ids = input_ids
decoder_attention_mask = jnp.ones_like(input_ids)
batch_size, sequence_length = input_ids.shape
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
random_params = self.module.init(
rngs,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
position_ids,
decoder_position_ids,
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def init_cache(self, batch_size, max_length, encoder_outputs):
r"""
Args:
batch_size (`int`):
用于快速自回归解码的批大小。定义初始化缓存时的批大小。
max_length (`int`):
自回归解码的最大可能长度。定义初始化缓存时的序列长度。
encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
`encoder_outputs` 包括 (`last_hidden_state`, *可选*: `hidden_states`, *可选*: `attentions`)。
`last_hidden_state` 的形状为 `(batch_size, sequence_length, hidden_size)`,*可选* 是编码器最后一层的隐藏状态的序列,
用于解码器的交叉注意力。
"""
decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
decoder_position_ids = jnp.broadcast_to(
jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
)
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
decoder_module = module._get_decoder_module()
return decoder_module(
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
**kwargs,
)
init_variables = self.module.init(
jax.random.PRNGKey(0),
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
init_cache=True,
method=_decoder_forward,
)
return unfreeze(init_variables["cache"])
@add_start_docstrings(MBART_ENCODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=MBartConfig)
def encode(
self,
input_ids: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
position_ids: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
r"""
Returns:
Example:
```
>>> from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration
>>> model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
>>> text = "My friends are cool but they eat too many carbs."
>>> inputs = tokenizer(text, max_length=1024, return_tensors="jax")
>>> encoder_outputs = model.encode(**inputs)
```
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
if position_ids is None:
batch_size, sequence_length = input_ids.shape
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
encode_module = module._get_encoder_module()
return encode_module(input_ids, attention_mask, position_ids, **kwargs)
return self.module.apply(
{"params": params or self.params},
input_ids=jnp.array(input_ids, dtype="i4"),
attention_mask=jnp.array(attention_mask, dtype="i4"),
position_ids=jnp.array(position_ids, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
method=_encoder_forward,
)
@add_start_docstrings(MBART_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=MBartConfig)
def decode(
self,
decoder_input_ids,
encoder_outputs,
encoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
r"""
"""
def __call__(
self,
input_ids: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
decoder_input_ids: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
position_ids: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
if position_ids is None:
batch_size, sequence_length = input_ids.shape
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id)
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
if decoder_position_ids is None:
batch_size, sequence_length = decoder_input_ids.shape
decoder_position_ids = jnp.broadcast_to(
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
)
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
return self.module.apply(
{"params": params or self.params},
input_ids=jnp.array(input_ids, dtype="i4"),
attention_mask=jnp.array(attention_mask, dtype="i4"),
position_ids=jnp.array(position_ids, dtype="i4"),
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
)
@add_start_docstrings(
"The bare MBart Model transformer outputting raw hidden-states without any specific head on top.",
MBART_START_DOCSTRING,
)
class FlaxMBartModel(FlaxMBartPreTrainedModel):
config: MBartConfig
dtype: jnp.dtype = jnp.float32
module_class = FlaxMBartModule
append_call_sample_docstring(FlaxMBartModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
class FlaxMBartForConditionalGenerationModule(nn.Module):
config: MBartConfig
dtype: jnp.dtype = jnp.float32
bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros
def setup(self):
self.model = FlaxMBartModule(config=self.config, dtype=self.dtype)
self.lm_head = nn.Dense(
self.model.shared.num_embeddings,
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
def _get_encoder_module(self):
return self.model.encoder
def _get_decoder_module(self):
return self.model.decoder
def __call__(
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
position_ids,
decoder_position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
position_ids=position_ids,
decoder_position_ids=decoder_position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_embedding = self.model.variables["params"]["shared"]["embedding"]
lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
else:
lm_logits = self.lm_head(hidden_states)
lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype))
if not return_dict:
output = (lm_logits,) + outputs[1:]
return output
return FlaxSeq2SeqLMOutput(
logits=lm_logits,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
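The `tie_word_embeddings` branch above re-applies `lm_head` with the shared token-embedding matrix (transposed) as its kernel instead of using a separate output projection. A minimal sketch of that pattern with toy shapes (the sizes and the constant embedding table are illustrative):
```
import jax.numpy as jnp
import flax.linen as nn

vocab_size, d_model = 10, 4
shared_embedding = jnp.ones((vocab_size, d_model))  # stand-in for the shared token embedding table
lm_head = nn.Dense(vocab_size, use_bias=False)

hidden_states = jnp.ones((1, 3, d_model))
# A Dense kernel has shape (in_features, out_features), hence the transpose of the embedding table.
logits = lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
assert logits.shape == (1, 3, vocab_size)
```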
@add_start_docstrings(
"The MMBart Model with a language modeling head. Can be used for summarization.", MBART_START_DOCSTRING
)
class FlaxMBartForConditionalGeneration(FlaxMBartPreTrainedModel):
module_class = FlaxMBartForConditionalGenerationModule
dtype: jnp.dtype = jnp.float32
@add_start_docstrings(MBART_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=MBartConfig)
def decode(
self,
decoder_input_ids,
encoder_outputs,
encoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
"""
Performs decoding with the model.
Args:
decoder_input_ids: Input IDs for the decoder.
encoder_outputs: Outputs from the encoder.
encoder_attention_mask: Optional attention mask for the encoder outputs.
decoder_attention_mask: Optional attention mask for the decoder inputs.
decoder_position_ids: Optional position IDs for the decoder inputs.
past_key_values: Cached key values for efficient generation.
output_attentions: Whether to output attentions.
output_hidden_states: Whether to output hidden states.
return_dict: Whether to return a dictionary or a tuple.
train: Whether in training mode.
params: Optional parameters.
dropout_rng: Dropout random number generator key.
Returns:
Model output with cross attentions.
"""
def prepare_inputs_for_generation(
self,
decoder_input_ids,
max_length,
attention_mask: Optional[jax.Array] = None,
decoder_attention_mask: Optional[jax.Array] = None,
encoder_outputs=None,
**kwargs,
):
"""
Prepares inputs for generation.
Args:
decoder_input_ids: Input IDs for the decoder.
max_length: Maximum length for generation.
attention_mask: Optional attention mask for the encoder outputs.
decoder_attention_mask: Optional attention mask for the decoder inputs.
encoder_outputs: Outputs from the encoder.
**kwargs: Additional keyword arguments.
Returns:
Dictionary with prepared inputs for generation.
"""
batch_size, seq_length = decoder_input_ids.shape
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if decoder_attention_mask is not None:
position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
else:
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
return {
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"encoder_attention_mask": attention_mask,
"decoder_attention_mask": extended_attention_mask,
"decoder_position_ids": position_ids,
}
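The `cumsum(axis=-1) - 1` above turns a (possibly left-padded) decoder attention mask into position ids that start counting at the first real token. A minimal numeric sketch (values are illustrative):
```
import jax.numpy as jnp

decoder_attention_mask = jnp.array([[0, 0, 1, 1, 1]])   # two padding positions on the left
position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
# position_ids -> [[-1, -1, 0, 1, 2]]: padded slots get -1, real tokens count up from 0
```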
def update_inputs_for_generation(self, model_outputs, model_kwargs):
"""
Updates inputs for generation.
Args:
model_outputs: Model outputs from the generation.
model_kwargs: Original model keyword arguments.
Returns:
Updated model keyword arguments.
"""
model_kwargs["past_key_values"] = model_outputs.past_key_values
model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
return model_kwargs
FLAX_MBART_CONDITIONAL_GENERATION_DOCSTRING = r"""
Returns:
Summarization example:
```
>>> from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration, MBartConfig
>>> model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
Load Facebook's mbart-large-cc25 model and the matching tokenizer from the pretrained checkpoints.
>>> ARTICLE_TO_SUMMARIZE = "Meine Freunde sind cool, aber sie essen zu viel Kuchen."
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="np")
Define the article to summarize and convert it with the tokenizer into the input format the model expects.
>>> # Generate Summary
>>> summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=5).sequences
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
Generate a summary with 4 beams and a maximum length of 5, then decode the generated ids and print the result.
>>> from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration
>>> model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
Load the MBart model and tokenizer again so the environment is ready for the mask-filling example.
>>> # de_DE is the language symbol id <LID> for German
>>> TXT = "</s> Meine Freunde sind <mask> nett aber sie essen zu viel Kuchen. </s> de_DE"
>>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="np")["input_ids"]
Define a mask-filling example: `TXT` contains a `<mask>` token marking the position to fill, then encode `TXT` into the model's input format.
>>> logits = model(input_ids).logits
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
>>> probs = logits[0, masked_index].softmax(dim=0)
>>> values, predictions = probs.topk(5)
Use the model to predict the probability distribution at the masked position and keep the five highest-probability candidates.
>>> tokenizer.decode(predictions).split()
Decode the predictions back into text and split them into a list of words.
"""
overwrite_call_docstring(
FlaxMBartForConditionalGeneration, MBART_INPUTS_DOCSTRING + FLAX_MBART_CONDITIONAL_GENERATION_DOCSTRING
)
append_replace_return_docstrings(
FlaxMBartForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)
class FlaxMBartForSequenceClassificationModule(nn.Module):
config: MBartConfig
dtype: jnp.dtype = jnp.float32
num_labels: Optional[int] = None
def setup(self):
self.model = FlaxMBartModule(config=self.config, dtype=self.dtype)
self.classification_head = FlaxMBartClassificationHead(
config=self.config,
inner_dim=self.config.d_model,
num_classes=self.num_labels if self.num_labels is not None else self.config.num_labels,
pooler_dropout=self.config.classifier_dropout,
)
def _get_encoder_module(self):
return self.model.encoder
def _get_decoder_module(self):
return self.model.decoder
def __call__(
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
position_ids,
decoder_position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
position_ids=position_ids,
decoder_position_ids=decoder_position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
hidden_states = outputs[0]
eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0)
if type(eos_mask) != jax.interpreters.partial_eval.DynamicJaxprTracer:
if len(jnp.unique(eos_mask.sum(1))) > 1:
raise ValueError("所有示例必须具有相同数量的 <eos> 标记。")
if any(eos_mask.sum(1) == 0):
raise ValueError("输入中缺少 <eos> 标记。")
eos_mask_noised = eos_mask + jnp.arange(eos_mask.shape[1]) * 1e-6
eos_mask = jnp.where(eos_mask_noised == eos_mask_noised.max(1).reshape(-1, 1), 1, 0)
sentence_representation = jnp.einsum("ijk, ij -> ijk", hidden_states, eos_mask).sum(1)
logits = self.classification_head(sentence_representation, deterministic=deterministic)
if not return_dict:
output = (logits,) + outputs[1:]
return output
return FlaxSeq2SeqSequenceClassifierOutput(
logits=logits,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
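The `__call__` above pools the decoder output at each example's final `<eos>` position: the noise term breaks ties so only the last `<eos>` survives, and the resulting one-hot `eos_mask` selects exactly one time step per example via the einsum. A minimal sketch with toy shapes (all values illustrative):
```
import jax.numpy as jnp

batch, seq_len, hidden = 2, 5, 4
hidden_states = jnp.arange(batch * seq_len * hidden, dtype=jnp.float32).reshape(batch, seq_len, hidden)
eos_mask = jnp.array([[0, 0, 0, 0, 1],
                      [0, 0, 0, 1, 0]], dtype=jnp.float32)  # one <eos> per example

# Zero out every position except <eos>, then sum over time: one hidden vector per example.
sentence_representation = jnp.einsum("ijk,ij->ijk", hidden_states, eos_mask).sum(1)
assert sentence_representation.shape == (batch, hidden)
```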
@add_start_docstrings(
"""
MBart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
tasks.
""",
MBART_START_DOCSTRING,
)
"""
使用`add_start_docstrings`装饰器为`FlaxMBartForSequenceClassification`类添加文档字符串,描述其作为带有顶部序列分类/头部的MBart模型。
"""
class FlaxMBartForSequenceClassification(FlaxMBartPreTrainedModel):
"""
MBart序列分类模型,继承自`FlaxMBartPreTrainedModel`。
"""
module_class = FlaxMBartForSequenceClassificationModule
dtype = jnp.float32
append_call_sample_docstring(
FlaxMBartForSequenceClassification,
_CHECKPOINT_FOR_DOC,
FlaxSeq2SeqSequenceClassifierOutput,
_CONFIG_FOR_DOC,
)
"""
使用`append_call_sample_docstring`函数为`FlaxMBartForSequenceClassification`类添加示例调用文档字符串。
"""
"""
从`transformers.models.bart.modeling_flax_bart.FlaxBartForQuestionAnsweringModule`复制代码,并将Bart替换为MBart。
"""
class FlaxMBartForQuestionAnsweringModule(nn.Module):
"""
MBart问答模块定义,继承自`nn.Module`。
"""
config: MBartConfig
dtype: jnp.dtype = jnp.float32
num_labels = 2
def setup(self):
"""
设置方法,初始化模型和输出层。
"""
self.model = FlaxMBartModule(config=self.config, dtype=self.dtype)
self.qa_outputs = nn.Dense(
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
def _get_encoder_module(self):
"""
获取编码器模块的私有方法。
"""
return self.model.encoder
def _get_decoder_module(self):
"""
获取解码器模块的私有方法。
"""
return self.model.decoder
def __call__(
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
position_ids,
decoder_position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
"""
模型调用方法,接受多个输入和参数,返回包含多个输出的字典或元组。
Args:
input_ids: 输入的编码器输入id。
attention_mask: 编码器的注意力掩码。
decoder_input_ids: 解码器输入id。
decoder_attention_mask: 解码器的注意力掩码。
position_ids: 输入的位置id。
decoder_position_ids: 解码器的位置id。
output_attentions: 是否输出注意力权重。
output_hidden_states: 是否输出隐藏状态。
return_dict: 是否以字典形式返回输出。
deterministic: 是否确定性计算。
Returns:
根据return_dict返回不同结构的输出。
"""
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
position_ids=position_ids,
decoder_position_ids=decoder_position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = jnp.split(logits, logits.shape[-1], axis=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
if not return_dict:
output = (start_logits, end_logits) + outputs[1:]
return output
return FlaxSeq2SeqQuestionAnsweringModelOutput(
start_logits=start_logits,
end_logits=end_logits,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
@add_start_docstrings(
"""
MBart Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
MBART_START_DOCSTRING,
)
class FlaxMBartForQuestionAnswering(FlaxMBartPreTrainedModel):
module_class = FlaxMBartForQuestionAnsweringModule
dtype = jnp.float32
append_call_sample_docstring(
FlaxMBartForQuestionAnswering,
_CHECKPOINT_FOR_DOC,
FlaxSeq2SeqQuestionAnsweringModelOutput,
_CONFIG_FOR_DOC,
)
.\models\mbart\modeling_mbart.py
""" PyTorch MBART 模型定义 """
import copy
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
Seq2SeqQuestionAnsweringModelOutput,
Seq2SeqSequenceClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_end_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_mbart import MBartConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "facebook/mbart-large-cc25"
_CONFIG_FOR_DOC = "MBartConfig"
_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
MBART_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/mbart-large-cc25",
]
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
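For reference, a small sketch of what `_get_unpad_data` produces for a toy padding mask (the values in the comments are what the call returns):
```
import torch

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
indices, cu_seqlens, max_seqlen = _get_unpad_data(attention_mask)
# indices    -> tensor([0, 1, 2, 4, 5]): flat positions of the non-padding tokens
# cu_seqlens -> tensor([0, 3, 5], dtype=torch.int32): cumulative sequence lengths per example
# max_seqlen -> 3: longest unpadded sequence in the batch
```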
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
"""
将输入的 ID 向右移动一个位置,并包装最后一个非填充标记(即 <LID> 标记)。需要注意的是,与其他类似 Bart 的模型不同,MBart 没有单一的 `decoder_start_token_id`。
"""
prev_output_tokens = input_ids.clone()
if pad_token_id is None:
raise ValueError("self.model.config.pad_token_id has to be defined.")
prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
prev_output_tokens[:, 0] = decoder_start_tokens
return prev_output_tokens
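A minimal sketch of the shift with made-up token ids (250004 merely stands in for a language-id `<LID>` token; the ids are not real vocabulary entries):
```
import torch

pad = 1
input_ids = torch.tensor([[57, 12, 99, 2, 250004, pad]])  # "... </s> <LID> <pad>"
shifted = shift_tokens_right(input_ids, pad_token_id=pad)
# shifted -> tensor([[250004, 57, 12, 99, 2, 250004]]): the <LID> token is wrapped around to position 0
```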
class MBartLearnedPositionalEmbedding(nn.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings: int, embedding_dim: int):
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)
def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
"""`input_ids' shape is expected to be [bsz x seqlen]."""
bsz, seq_len = input_ids.shape[:2]
positions = torch.arange(
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
).expand(bsz, -1)
return super().forward(positions + self.offset)
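A quick sanity sketch of the learned positional embedding: because of `offset = 2`, position `p` is looked up at row `p + 2`, so the table carries two extra rows (a carry-over from the original fairseq checkpoints). The sizes below are illustrative.
```
import torch

emb = MBartLearnedPositionalEmbedding(num_embeddings=1024, embedding_dim=16)
input_ids = torch.zeros(2, 5, dtype=torch.long)   # only the shape matters, the content is ignored
positions = emb(input_ids)                        # rows 2..6 of the table
assert emb.weight.shape == (1024 + 2, 16)
assert positions.shape == (2, 5, 16)
```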
class MBartAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[MBartConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
):
pass
class MBartFlashAttention2(MBartAttention):
"""
Placeholder class for future extension or modification.
"""
pass
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
causal = self.is_causal and query_length != 1
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
)
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
MBART_ATTENTION_CLASSES = {
"eager": MBartAttention,
"flash_attention_2": MBartFlashAttention2,
}
class MBartEncoderLayer(nn.Module):
def __init__(self, config: MBartConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
config=config,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
layer_head_mask: torch.Tensor,
output_attentions: bool = False,
) -> torch.Tensor:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16 and (
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class MBartDecoderLayer(nn.Module):
def __init__(self, config: MBartConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
is_causal=True,
config=config,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
):
...
class MBartClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(
self,
input_dim: int,
inner_dim: int,
num_classes: int,
pooler_dropout: float,
):
super().__init__()
self.dense = nn.Linear(input_dim, inner_dim)
self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = nn.Linear(inner_dim, num_classes)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dropout(hidden_states)
hidden_states = self.dense(hidden_states)
hidden_states = torch.tanh(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.out_proj(hidden_states)
return hidden_states
class MBartPreTrainedModel(PreTrainedModel):
config_class = MBartConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["MBartDecoderLayer", "MBartAttention"]
_supports_flash_attn_2 = True
def _init_weights(self, module):
std = self.config.init_std
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
dummy_inputs = {
"attention_mask": input_ids.ne(pad_token),
"input_ids": input_ids,
}
return dummy_inputs
MBART_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`MBartConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
MBART_GENERATION_EXAMPLE = r"""
Translation example:
```
>>> from transformers import AutoTokenizer, MBartForConditionalGeneration
>>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro")
>>> example_english_phrase = "42 is the answer"
>>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
>>> # Translate
>>> generated_ids = model.generate(**inputs, num_beams=4, max_length=5)
>>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
'42 este răspuns'
```
Mask filling example:
```
>>> from transformers import AutoTokenizer, MBartForConditionalGeneration
>>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
>>> # de_DE is the language symbol id <LID> for German
>>> TXT = "</s> Meine Freunde sind <mask> nett aber sie essen zu viel Kuchen. </s> de_DE"
>>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="pt")["input_ids"]
>>> logits = model(input_ids).logits
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
>>> probs = logits[0, masked_index].softmax(dim=0)
>>> values, predictions = probs.topk(5)
>>> tokenizer.decode(predictions).split()
['nett', 'sehr', 'ganz', 'nicht', 'so']
```
"""
MBART_INPUTS_DOCSTRING = r"""
"""
class MBartEncoder(MBartPreTrainedModel):
"""
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
[`MBartEncoderLayer`].
Args:
config: MBartConfig
Model configuration class with all the parameters of the model.
embed_tokens (nn.Embedding): output embedding
The output embedding for the model.
"""
def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.encoder_layerdrop
embed_dim = config.d_model
self.padding_idx = config.pad_token_id
self.max_source_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = MBartLearnedPositionalEmbedding(
config.max_position_embeddings,
embed_dim,
)
self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.layer_norm = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
self.post_init()
def _backward_compatibility_gradient_checkpointing(self):
if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
self.gradient_checkpointing_enable()
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
...
class MBartDecoder(MBartPreTrainedModel):
"""
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MBartDecoderLayer`]
Args:
config: MBartConfig
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.decoder_layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = MBartLearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
)
self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.layer_norm = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
"""
MBartDecoder的前向传播函数
Args:
input_ids (torch.LongTensor, optional): 输入的token IDs
attention_mask (torch.Tensor, optional): 注意力掩码
encoder_hidden_states (torch.FloatTensor, optional): 编码器的隐藏状态
encoder_attention_mask (torch.LongTensor, optional): 编码器的注意力掩码
head_mask (torch.Tensor, optional): 多头注意力的头部掩码
cross_attn_head_mask (torch.Tensor, optional): 跨注意力头部的掩码
past_key_values (Tuple[Tuple[torch.FloatTensor]], optional): 缓存的键值对
inputs_embeds (torch.FloatTensor, optional): 输入的嵌入表示
use_cache (bool, optional): 是否使用缓存
output_attentions (bool, optional): 是否输出注意力
output_hidden_states (bool, optional): 是否输出隐藏状态
return_dict (bool, optional): 是否返回字典
Returns:
根据配置返回不同的输出
"""
pass
class MBartModel(MBartPreTrainedModel):
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, value):
self.shared = value
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.get_input_embeddings())
self._tie_or_clone_weights(self.decoder.embed_tokens, self.get_input_embeddings())
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Seq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
...
@add_start_docstrings(
"The MBART Model with a language modeling head. Can be used for summarization, after fine-tuning the pretrained models.",
MBART_START_DOCSTRING,
)
class MBartForConditionalGeneration(MBartPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
_tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: MBartConfig):
super().__init__(config)
self.model = MBartModel(config)
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
self.post_init()
def get_encoder(self):
return self.model.get_encoder()
def get_decoder(self):
return self.model.get_decoder()
def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
return new_embeddings
def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
old_num_tokens = self.final_logits_bias.shape[-1]
if new_num_tokens <= old_num_tokens:
new_bias = self.final_logits_bias[:, :new_num_tokens]
else:
extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
self.register_buffer("final_logits_bias", new_bias)
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@add_end_docstrings(MBART_GENERATION_EXAMPLE)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Either a Seq2SeqLMOutput containing loss, logits, and other optional outputs, or a tuple of
torch.FloatTensor containing logits and optional outputs.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return Seq2SeqLMOutput(
loss=masked_lm_loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
remove_prefix_length = decoder_input_ids.shape[1] - 1
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
return {
"input_ids": None,
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
}
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id)
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+ layer_past[2:],
)
return reordered_past
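`_reorder_cache` realigns every cached key/value tensor with the beams that survived the current search step. A minimal sketch of the underlying `index_select` on a toy tensor (shapes and the beam permutation are illustrative):
```
import torch

num_beams, heads, seq, dim = 3, 2, 4, 8
past_key = torch.randn(num_beams, heads, seq, dim)
beam_idx = torch.tensor([2, 0, 0])               # beam 2 and two copies of beam 0 survive

reordered_key = past_key.index_select(0, beam_idx)
assert torch.equal(reordered_key[0], past_key[2])
assert torch.equal(reordered_key[1], past_key[0])
```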
@add_start_docstrings(
"""
MBart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
tasks.
""",
MBART_START_DOCSTRING,
)
class MBartForSequenceClassification(MBartPreTrainedModel):
_tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]
def __init__(self, config: MBartConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = MBartModel(config)
self.classification_head = MBartClassificationHead(
config.d_model,
config.d_model,
config.num_labels,
config.classifier_dropout,
)
self.post_init()
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Seq2SeqSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
...
@add_start_docstrings(
"""
MBART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
MBART_START_DOCSTRING,
)
class MBartForQuestionAnswering(MBartPreTrainedModel):
_tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]
def __init__(self, config):
super().__init__(config)
config.num_labels = 2
self.num_labels = config.num_labels
self.model = MBartModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
self.post_init()
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Seq2SeqQuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
...
class MBartDecoderWrapper(MBartPreTrainedModel):
"""
这个包装类是一个辅助类,用于在因果语言模型与 [`EncoderDecoderModel`] 框架结合使用时正确加载预训练检查点。
"""
def __init__(self, config):
super().__init__(config)
self.decoder = MBartDecoder(config)
def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)
class MBartForCausalLM(MBartPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
config.is_decoder = True
config.is_encoder_decoder = False
super().__init__(config)
self.model = MBartDecoderWrapper(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.decoder.embed_tokens
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model.decoder = decoder
def get_decoder(self):
return self.model.decoder
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
pass
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past_key_values:
past_length = past_key_values[0][0].shape[2]
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"use_cache": use_cache,
}
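During cached generation only the not-yet-processed suffix of `input_ids` is fed to the model. A minimal sketch of the cropping above with made-up ids:
```
import torch

input_ids = torch.tensor([[5, 17, 23, 42]])
past_length = 3                                   # three tokens are already in the cache
remove_prefix_length = past_length if input_ids.shape[1] > past_length else input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
# input_ids -> tensor([[42]]): only the newest token is processed in this step
```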
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past