Transformers Source Code Analysis (71)
.\models\mask2former\modeling_mask2former.py
""" PyTorch Mask2Former model. """
import math
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import Tensor, nn
from ...activations import ACT2FN
from ...file_utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_scipy_available,
replace_return_docstrings,
requires_backends,
)
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithCrossAttentions,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import is_torch_greater_or_equal_than_2_1
from ...utils import is_accelerate_available, logging
from ...utils.backbone_utils import load_backbone
from .configuration_mask2former import Mask2FormerConfig
if is_scipy_available():
from scipy.optimize import linear_sum_assignment
if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import reduce
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Mask2FormerConfig"
_CHECKPOINT_FOR_DOC = "facebook/mask2former-swin-small-coco-instance"
_IMAGE_PROCESSOR_FOR_DOC = "Mask2FormerImageProcessor"
MASK2FORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/mask2former-swin-small-coco-instance",
]
@dataclass
class Mask2FormerPixelDecoderOutput(ModelOutput):
"""
Mask2Former's pixel decoder module output, practically a Multi-Scale Deformable Attention based decoder. It returns
the mask features and the multiscale features.
"""
pass
Args:
multi_scale_features (`tuple(torch.FloatTensor)`):
Tuple of multi-scale features of scales [1/8, 1/16, 1/32] and shape `(batch_size, num_channels, height,
width)`from the Multi-Scale Deformable Attention based Pixel Decoder.
多尺度特征的元组,包含比例为 [1/8, 1/16, 1/32] 的特征,形状为 `(batch_size, num_channels, height, width)`,
来自基于多尺度可变注意力的像素解码器。
mask_features (`torch.FloatTensor`):
Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel Decoder
Layer.
形状为 `(batch_size, num_channels, height, width)` 的张量,来自最后一个像素解码器层的1/4比例特征。
attentions (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Attentions weights from pixel decoder. Returned when `output_attentions=True` is passed
or when `config.output_attentions=True`
可选的注意力权重元组,每个元素的形状为 `(batch_size, num_heads, sequence_length, sequence_length)`,
表示像素解码器中的注意力权重。在设置 `output_attentions=True` 或 `config.output_attentions=True` 时返回。
"""
multi_scale_features: Tuple[torch.FloatTensor] = None
mask_features: torch.FloatTensor = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class Mask2FormerMaskedAttentionDecoderOutput(BaseModelOutputWithCrossAttentions):
"""
Base class for outputs of the Transformer decoder. This class adds two attributes on top of
BaseModelOutputWithCrossAttentions: the mask prediction logits and a tuple of intermediate decoder activations,
i.e. the output of each decoder layer, each of them gone through a layernorm.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` tensors. The first one is the output of the embeddings, the remaining ones
correspond to the output of each layer, each of shape `(batch_size, sequence_length, hidden_size)`.
Returned when `output_hidden_states=True` is passed.
attentions (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` tensors of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads. Returned when `output_attentions=True` is passed.
masks_queries_logits (`tuple(torch.FloatTensor)` of shape `(batch_size, num_queries, height, width)`):
Tuple of mask prediction logits from all layers of the transformer decoder.
intermediate_hidden_states (`tuple(torch.FloatTensor)` of shape `(num_queries, 1, hidden_size)`):
Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
layernorm.
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[torch.FloatTensor] = None
masks_queries_logits: Tuple[torch.FloatTensor] = None
intermediate_hidden_states: Tuple[torch.FloatTensor] = None
@dataclass
class Mask2FormerPixelLevelModuleOutput(ModelOutput):
"""
Mask2Former's pixel level module output. It returns the output of the encoder (optional) and all hidden states
(multi-scale features) of the `decoder`. By default, the `encoder` is a Swin backbone and the `decoder` is a
Multi-Scale Deformable Attention based decoder. The `decoder_last_hidden_state` are the per-pixel embeddings,
while `decoder_hidden_states` refer to the multi-scale feature maps produced with the multi-scaling strategy
defined in the paper.
Args:
encoder_last_hidden_state (`torch.FloatTensor`):
Last hidden state of the encoder, i.e. the final feature map of the last encoder stage, of shape
`(batch_size, num_channels, height, width)`.
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of the hidden states output at each stage of the encoder. Each element is a `torch.FloatTensor` of
shape `(batch_size, num_channels, height, width)`. Returned when `output_hidden_states=True` is passed.
decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
1/4 scale features (the per-pixel embeddings) from the last Pixel Decoder Layer.
decoder_hidden_states (`tuple(torch.FloatTensor)`):
Tuple of the hidden states output at each stage of the decoder. Each element is a `torch.FloatTensor` of
shape `(batch_size, num_channels, height, width)`.
"""
encoder_last_hidden_state: torch.FloatTensor = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_last_hidden_state: torch.FloatTensor = None
decoder_hidden_states: Tuple[torch.FloatTensor] = None
@dataclass
class Mask2FormerModelOutput(ModelOutput):
"""
Class for outputs of [`Mask2FormerModel`]. This class returns all the needed hidden states to compute the logits.
Args:
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
Last hidden states (final feature map) of the last stage of the encoder model (backbone). Returned when
`output_hidden_states=True` is passed.
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder
model at the output of each stage. Returned when `output_hidden_states=True` is passed.
pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
Last hidden states (final feature map) of the last stage of the pixel decoder model.
pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel
decoder model at the output of each stage. Returned when `output_hidden_states=True` is passed.
transformer_decoder_last_hidden_state (`tuple(torch.FloatTensor)`):
Final output of the transformer decoder `(batch_size, sequence_length, hidden_size)`.
transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the
transformer decoder at the output of each stage. Returned when `output_hidden_states=True` is passed.
transformer_decoder_intermediate_states (`tuple(torch.FloatTensor)` of shape `(num_queries, 1, hidden_size)`):
Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
layernorm.
masks_queries_logits (`tuple(torch.FloatTensor)` of shape `(batch_size, num_queries, height, width)`)
Mask Predictions from each layer in the transformer decoder.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed):
Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Self attentions weights from transformer decoder.
"""
encoder_last_hidden_state: torch.FloatTensor = None
pixel_decoder_last_hidden_state: torch.FloatTensor = None
transformer_decoder_last_hidden_state: torch.FloatTensor = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
transformer_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
transformer_decoder_intermediate_states: Tuple[torch.FloatTensor] = None
masks_queries_logits: Tuple[torch.FloatTensor] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
"""
@dataclass
class Mask2FormerForUniversalSegmentationOutput(ModelOutput):
"""
Class for outputs of [`Mask2FormerForUniversalSegmentation`].
This output can be directly passed to [`~Mask2FormerImageProcessor.post_process_semantic_segmentation`],
[`~Mask2FormerImageProcessor.post_process_instance_segmentation`] or
[`~Mask2FormerImageProcessor.post_process_panoptic_segmentation`] to compute the final segmentation maps. Please
see [`~Mask2FormerImageProcessor`] for details regarding usage.
"""
loss: Optional[torch.FloatTensor] = None
class_queries_logits: torch.FloatTensor = None
masks_queries_logits: torch.FloatTensor = None
auxiliary_logits: Optional[List[Dict[str, torch.FloatTensor]]] = None
encoder_last_hidden_state: torch.FloatTensor = None
pixel_decoder_last_hidden_state: torch.FloatTensor = None
transformer_decoder_last_hidden_state: torch.FloatTensor = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
transformer_decoder_hidden_states: Optional[torch.FloatTensor] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
def sample_point(
input_features: torch.Tensor, point_coordinates: torch.Tensor, add_dim=False, **kwargs
) -> torch.Tensor:
"""
A wrapper around `torch.nn.functional.grid_sample` to support 3D point coordinate tensors.
Args:
input_features (`torch.Tensor` of shape (batch_size, channels, height, width)):
A tensor that contains the feature map on a height * width grid.
point_coordinates (`torch.Tensor` of shape (batch_size, num_points, 2) or (batch_size, grid_height, grid_width,
2)):
A tensor that contains [0, 1] * [0, 1] normalized point coordinates.
add_dim (`bool`):
Boolean value used to keep track of an added dimension.
Returns:
point_features (`torch.Tensor` of shape (batch_size, channels, num_points) or (batch_size, channels,
height_grid, width_grid)):
A tensor that contains the features for the points in `point_coordinates`.
"""
if point_coordinates.dim() == 3:
add_dim = True
point_coordinates = point_coordinates.unsqueeze(2)
point_features = torch.nn.functional.grid_sample(input_features, 2.0 * point_coordinates - 1.0, **kwargs)
if add_dim:
point_features = point_features.squeeze(3)
return point_features
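# Illustrative sketch (not part of the original file): `grid_sample` expects coordinates in [-1, 1],
# which is why the wrapper maps the [0, 1] inputs through `2.0 * point_coordinates - 1.0`.
import torch

features = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)  # (batch_size, channels, height, width)
points = torch.tensor([[[0.0, 0.0], [1.0, 1.0]]])                     # (batch_size, num_points, 2) as (x, y) in [0, 1]
sampled = sample_point(features, points, align_corners=False)
print(sampled.shape)  # torch.Size([1, 1, 2]) -> (batch_size, channels, num_points)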
def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor:
r"""
Compute the DICE loss, similar to a generalized IOU for masks. With probabilities x = sigmoid(inputs) and binary
labels y, it is computed as:

$$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 \sum_i x_i y_i + 1}{\sum_i x_i + \sum_i y_i + 1} $$

Args:
inputs (`torch.Tensor`):
A tensor representing a mask.
labels (`torch.Tensor`):
A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
(0 for the negative class and 1 for the positive class).
num_masks (`int`):
The number of masks present in the current batch, used for normalization.
Returns:
`torch.Tensor`: The computed loss.
"""
# Apply sigmoid to get probabilities and flatten the spatial dimensions
probs = inputs.sigmoid().flatten(1)
# Numerator of the Dice score
numerator = 2 * (probs * labels).sum(-1)
# Denominator of the Dice score
denominator = probs.sum(-1) + labels.sum(-1)
# Dice loss per mask
loss = 1 - (numerator + 1) / (denominator + 1)
# Sum over masks and normalize by the number of masks
loss = loss.sum() / num_masks
return loss
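# Quick sanity check (illustrative, not part of the original file): a confidently correct
# prediction on a single mask drives the dice loss towards 0.
import torch

mask_logits = torch.full((1, 4), 10.0)  # one mask, 4 sampled points, strongly positive logits
mask_labels = torch.ones((1, 4))
print(dice_loss(mask_logits, mask_labels, num_masks=1))  # tensor(~0.), since sigmoid(10) ≈ 1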
# Computes the sigmoid (binary) cross-entropy loss between the input tensor and the labels
def sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor, num_masks: int) -> torch.Tensor:
r"""
Args:
inputs (`torch.Tensor`):
A float tensor of arbitrary shape.
labels (`torch.Tensor`):
A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
(0 for the negative class and 1 for the positive class).
Returns:
loss (`torch.Tensor`): The computed loss tensor.
"""
# Binary cross-entropy with logits, without reduction so the loss can be averaged per mask first
criterion = nn.BCEWithLogitsLoss(reduction="none")
# Per-point cross-entropy loss
cross_entropy_loss = criterion(inputs, labels)
# Average over sampled points, sum over masks and normalize by the number of masks
loss = cross_entropy_loss.mean(1).sum() / num_masks
return loss
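# Illustrative check (not part of the original file): a logit of 0 corresponds to p = 0.5,
# so the per-point binary cross-entropy equals -log(0.5) ≈ 0.6931 regardless of the label.
import torch

mask_logits = torch.zeros(2, 5)  # 2 masks, 5 sampled points each
mask_labels = torch.ones(2, 5)
print(sigmoid_cross_entropy_loss(mask_logits, mask_labels, num_masks=2))  # tensor(0.6931)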
# Copied from transformers.models.maskformer.modeling_maskformer.pair_wise_dice_loss
def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
"""
A pair-wise version of the Dice loss, see `dice_loss` for usage.
Args:
inputs (`torch.Tensor`):
A tensor representing a mask.
labels (`torch.Tensor`):
A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
(0 for the negative class and 1 for the positive class).
Returns:
`torch.Tensor`: The computed loss between each pair.
"""
# Apply sigmoid to the predictions and flatten the spatial dimensions
inputs = inputs.sigmoid().flatten(1)
numerator = 2 * torch.matmul(inputs, labels.T)
# Use broadcasting to get a [num_queries, NUM_CLASSES] matrix
denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
loss = 1 - (numerator + 1) / (denominator + 1)
return loss
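# Shape sketch (illustrative, not part of the original file): the pairwise losses produce the
# (num_queries, num_labels) cost matrices consumed by the Hungarian matcher below.
import torch

pred_masks = torch.randn(100, 4096)                      # 100 predicted query masks, flattened to 4096 points
gt_masks = torch.randint(0, 2, (7, 4096)).float()        # 7 ground-truth masks, flattened
print(pair_wise_dice_loss(pred_masks, gt_masks).shape)   # torch.Size([100, 7]) -> one cost per (query, label) pair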
# Computes a pair-wise sigmoid cross-entropy loss between the input tensor and the labels
def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
r"""
A pair-wise version of the cross entropy loss, see `sigmoid_cross_entropy_loss` for usage.
Args:
inputs (`torch.Tensor`):
A tensor representing a mask.
labels (`torch.Tensor`):
A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
(0 for the negative class and 1 for the positive class).
Returns:
loss (`torch.Tensor`): The computed loss between each pair.
"""
# Number of flattened points per mask (height * width), used to average the per-point losses
height_and_width = inputs.shape[1]
# Binary cross-entropy with logits, without reduction
criterion = nn.BCEWithLogitsLoss(reduction="none")
# Cross-entropy of every prediction against an all-ones and an all-zeros target
cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))
# Matrix products with the labels give the pairwise positive and negative terms
loss_pos = torch.matmul(cross_entropy_loss_pos / height_and_width, labels.T)
loss_neg = torch.matmul(cross_entropy_loss_neg / height_and_width, (1 - labels).T)
# Combine the positive and negative terms
loss = loss_pos + loss_neg
return loss
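# Illustrative sketch of the bipartite matching step (not part of the original file). The real
# Mask2FormerHungarianMatcher additionally samples `num_points` per mask and adds a classification
# cost; here the 5.0 factors are stand-ins for the `cost_mask` / `cost_dice` weights.
import torch
from scipy.optimize import linear_sum_assignment

pred_masks = torch.randn(100, 4096)                  # query mask logits, flattened
gt_masks = torch.randint(0, 2, (7, 4096)).float()    # ground-truth masks, flattened
cost_mask = pair_wise_sigmoid_cross_entropy_loss(pred_masks, gt_masks)  # (100, 7)
cost_dice = pair_wise_dice_loss(pred_masks, gt_masks)                   # (100, 7)
cost_matrix = 5.0 * cost_mask + 5.0 * cost_dice
pred_idx, gt_idx = linear_sum_assignment(cost_matrix.numpy())
print(pred_idx.shape, gt_idx.shape)  # (7,) (7,) -> each ground-truth mask is assigned to exactly one query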
# 从 https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/matcher.py 调整而来
class Mask2FormerHungarianMatcher(nn.Module):
"""这个类计算标签和网络预测之间的分配。
"""
For efficiency reasons, the labels don't include the no_object. Because of this, in general, there are more
predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are
un-matched (and thus treated as non-objects).
"""
def __init__(
self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0, num_points: int = 12544
):
"""Creates the matcher
Params:
cost_class (`float`, *optional*, defaults to 1.0):
Relative weight of the classification error in the matching cost.
cost_mask (`float`, *optional*, defaults to 1.0):
This is the relative weight of the focal loss of the binary mask in the matching cost.
cost_dice (`float`, *optional*, defaults to 1.0):
This is the relative weight of the dice loss of the binary mask in the matching cost.
num_points (`int`, *optional*, defaults to 12544):
No. of points to sample on which the mask loss will be calculated. The same set of K points are
uniformly sampled for all prediction and ground truth masks to construct the cost matrix for bipartite
matching.
"""
super().__init__()
if cost_class == 0 and cost_mask == 0 and cost_dice == 0:
raise ValueError("All costs cant be 0")
self.num_points = num_points
self.cost_class = cost_class
self.cost_mask = cost_mask
self.cost_dice = cost_dice
@torch.no_grad()
def forward(
self,
masks_queries_logits: torch.Tensor,
class_queries_logits: torch.Tensor,
mask_labels: torch.Tensor,
class_labels: torch.Tensor,
class Mask2FormerLoss(nn.Module):
def __init__(self, config: Mask2FormerConfig, weight_dict: Dict[str, float]):
"""
The Mask2Former Loss. The loss is computed very similar to DETR. The process happens in two steps: 1) we
compute hungarian assignment between ground truth masks and the outputs of the model 2) we supervise each pair
of matched ground-truth / prediction (supervise class and mask)
Args:
config (`Mask2FormerConfig`):
The configuration for Mask2Former model also containing loss calculation specific parameters.
weight_dict (`Dict[str, float]`):
A dictionary of weights to be applied to the different losses.
"""
super().__init__()
requires_backends(self, ["scipy"])
self.num_labels = config.num_labels
self.weight_dict = weight_dict
self.eos_coef = config.no_object_weight
empty_weight = torch.ones(self.num_labels + 1)
empty_weight[-1] = self.eos_coef
self.register_buffer("empty_weight", empty_weight)
self.num_points = config.train_num_points
self.oversample_ratio = config.oversample_ratio
self.importance_sample_ratio = config.importance_sample_ratio
self.matcher = Mask2FormerHungarianMatcher(
cost_class=1.0,
cost_dice=config.dice_weight,
cost_mask=config.mask_weight,
num_points=self.num_points,
)
def _max_by_axis(self, sizes: List[List[int]]) -> List[int]:
maxes = sizes[0]
for sublist in sizes[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
def _pad_images_to_max_in_batch(self, tensors: List[Tensor]) -> Tuple[Tensor, Tensor]:
max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors])
batch_shape = [len(tensors)] + max_size
batch_size, _, height, width = batch_shape
dtype = tensors[0].dtype
device = tensors[0].device
padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device)
padding_masks = torch.ones((batch_size, height, width), dtype=torch.bool, device=device)
for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks):
padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor)
padding_mask[: tensor.shape[1], : tensor.shape[2]] = False
return padded_tensors, padding_masks
def loss_labels(
self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.array]
) -> Dict[str, Tensor]:
"""Compute the losses related to the labels using cross entropy.
Args:
class_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, num_labels`
class_labels (`List[torch.Tensor]`):
List of class labels of shape `(labels)`.
indices (`Tuple[np.array])`:
The indices computed by the Hungarian matcher.
Returns:
`Dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key:
- **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
"""
pred_logits = class_queries_logits
batch_size, num_queries, _ = pred_logits.shape
criterion = nn.CrossEntropyLoss(weight=self.empty_weight)
idx = self._get_predictions_permutation_indices(indices)
target_classes_o = torch.cat(
[target[j] for target, (_, j) in zip(class_labels, indices)]
)
target_classes = torch.full(
(batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device
)
target_classes[idx] = target_classes_o
pred_logits_transposed = pred_logits.transpose(1, 2)
loss_ce = criterion(pred_logits_transposed, target_classes)
losses = {"loss_cross_entropy": loss_ce}
return losses
def loss_masks(
self,
masks_queries_logits: torch.Tensor,
mask_labels: List[torch.Tensor],
indices: Tuple[np.array],
num_masks: int,
) -> Dict[str, torch.Tensor]:
"""Compute the losses related to the masks using sigmoid_cross_entropy_loss and dice loss.
Args:
masks_queries_logits (`torch.Tensor`):
A tensor of shape `(batch_size, num_queries, height, width)`.
mask_labels (`torch.Tensor`):
List of mask labels of shape `(labels, height, width)`.
indices (`Tuple[np.array])`:
The indices computed by the Hungarian matcher.
num_masks (`int)`:
The number of masks, used for normalization.
Returns:
losses (`Dict[str, Tensor]`): A dict of `torch.Tensor` containing two keys:
- **loss_mask** -- The loss computed using sigmoid cross entropy loss on the predicted and ground truth
masks.
- **loss_dice** -- The loss computed using dice loss on the predicted and ground truth
masks.
"""
src_idx = self._get_predictions_permutation_indices(indices)
tgt_idx = self._get_targets_permutation_indices(indices)
pred_masks = masks_queries_logits[src_idx]
target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
target_masks = target_masks[tgt_idx]
pred_masks = pred_masks[:, None]
target_masks = target_masks[:, None]
with torch.no_grad():
point_coordinates = self.sample_points_using_uncertainty(
pred_masks,
lambda logits: self.calculate_uncertainty(logits),
self.num_points,
self.oversample_ratio,
self.importance_sample_ratio,
)
point_labels = sample_point(target_masks, point_coordinates, align_corners=False).squeeze(1)
point_logits = sample_point(pred_masks, point_coordinates, align_corners=False).squeeze(1)
losses = {
"loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks),
"loss_dice": dice_loss(point_logits, point_labels, num_masks),
}
del pred_masks
del target_masks
return losses
def _get_predictions_permutation_indices(self, indices):
batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
predictions_indices = torch.cat([src for (src, _) in indices])
return batch_indices, predictions_indices
def _get_targets_permutation_indices(self, indices):
batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
target_indices = torch.cat([tgt for (_, tgt) in indices])
return batch_indices, target_indices
def calculate_uncertainty(self, logits: torch.Tensor) -> torch.Tensor:
"""
In the Mask2Former paper, uncertainty is estimated as the L1 distance between the foreground-class prediction in
`logits` and 0.0.
Args:
logits (`torch.Tensor`): A tensor of shape (R, 1, ...), where R is the total number of predicted masks across
all images and C is the number of foreground classes; the values are logits.
Returns:
scores (`torch.Tensor`): A tensor of shape (R, 1, ...) containing uncertainty scores, with the highest scores
at the most uncertain locations.
"""
uncertainty_scores = -(torch.abs(logits))
return uncertainty_scores
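# Illustration (not part of the original file): -|logit| is largest where the prediction sits near
# the 0.5 decision boundary, i.e. where the mask is most uncertain.
import torch

example_logits = torch.tensor([[-4.0, -0.1, 0.0, 2.5]])
print(-(torch.abs(example_logits)))  # tensor([[-4.0000, -0.1000, -0.0000, -2.5000]]) -> highest score at logit 0.0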
def sample_points_using_uncertainty(
self,
logits: torch.Tensor,
uncertainty_function,
num_points: int,
oversample_ratio: int,
importance_sample_ratio: float,
) -> torch.Tensor:
"""
This function samples points in [0, 1] * [0, 1] coordinate space based on uncertainty of logits predictions.
Args:
logits (`torch.Tensor`):
Logit predictions for bounding boxes.
uncertainty_function:
Function to calculate uncertainties based on logit predictions.
num_points (`int`):
Number of points to sample.
oversample_ratio (`int`):
Oversampling ratio for point sampling.
importance_sample_ratio (`float`):
Ratio of points sampled via importance sampling.
Returns:
point_coordinates (`torch.Tensor`):
Coordinates of sampled points.
"""
num_boxes = logits.shape[0]
num_points_sampled = int(num_points * oversample_ratio)
point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)
point_logits = sample_point(logits, point_coordinates, align_corners=False)
point_uncertainties = uncertainty_function(point_logits)
num_uncertain_points = int(importance_sample_ratio * num_points)
num_random_points = num_points - num_uncertain_points
idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)
idx += shift[:, None]
point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)
if num_random_points > 0:
point_coordinates = torch.cat(
[point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],
dim=1,
)
return point_coordinates
"""
This performs the loss computation.
Args:
masks_queries_logits (`torch.Tensor`):
A tensor of shape `(batch_size, num_queries, height, width)`.
Contains logits for predicted masks.
class_queries_logits (`torch.Tensor`):
A tensor of shape `(batch_size, num_queries, num_labels)`.
Contains logits for predicted class labels.
mask_labels (`torch.Tensor`):
List of mask labels of shape `(labels, height, width)`.
Ground truth masks.
class_labels (`List[torch.Tensor]`):
List of class labels of shape `(labels)`.
Ground truth class labels.
auxiliary_predictions (`Dict[str, torch.Tensor]`, *optional*):
if `use_auxiliary_loss` was set to `true` in [`Mask2FormerConfig`], then it contains the logits from
the inner layers of the Mask2FormerMaskedAttentionDecoder.
Dictionary of auxiliary predictions from intermediate layers.
Returns:
losses (`Dict[str, Tensor]`): A dict of `torch.Tensor` containing three keys:
- **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
- **loss_mask** -- The loss computed using sigmoid cross_entropy loss on the predicted and ground truth
masks.
- **loss_dice** -- The loss computed using dice loss on the predicted and ground truth masks.
if `use_auxiliary_loss` was set to `true` in [`Mask2FormerConfig`], the dictionary contains additional
losses for each auxiliary predictions.
"""
indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
num_masks = self.get_num_masks(class_labels, device=class_labels[0].device)
losses: Dict[str, Tensor] = {
**self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks),
**self.loss_labels(class_queries_logits, class_labels, indices),
}
if auxiliary_predictions is not None:
for idx, aux_outputs in enumerate(auxiliary_predictions):
masks_queries_logits = aux_outputs["masks_queries_logits"]
class_queries_logits = aux_outputs["class_queries_logits"]
loss_dict = self.forward(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()}
losses.update(loss_dict)
return losses
def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor:
"""
Computes the average number of target masks across the batch, for normalization purposes.
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks = torch.as_tensor(num_masks, dtype=torch.float, device=device)
world_size = 1
if is_accelerate_available():
if PartialState._shared_state != {}:
num_masks = reduce(num_masks)
world_size = PartialState().num_processes
num_masks = torch.clamp(num_masks / world_size, min=1)
return num_masks
def multi_scale_deformable_attention(
value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
) -> Tensor:
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
value_l_ = (
value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width)
)
sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
sampling_value_l_ = nn.functional.grid_sample(
value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
)
sampling_value_list.append(sampling_value_l_)
attention_weights = attention_weights.transpose(1, 2).reshape(
batch_size * num_heads, 1, num_queries, num_levels * num_points
)
output = (
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
.sum(-1)
.view(batch_size, num_heads * hidden_dim, num_queries)
)
return output.transpose(1, 2).contiguous()
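# Shape walk-through (illustrative, not part of the original file): each of the num_queries
# positions gathers num_levels * num_points bilinear samples per head and combines them with
# normalized attention weights.
import torch

batch_size, num_heads, head_dim = 2, 8, 32
spatial_shapes = torch.tensor([[16, 16], [8, 8], [4, 4]])  # (num_levels, 2) as (height, width)
sequence_length = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())  # 256 + 64 + 16 = 336
num_queries, num_levels, num_points = sequence_length, 3, 4

value = torch.randn(batch_size, sequence_length, num_heads, head_dim)
sampling_locations = torch.rand(batch_size, num_queries, num_heads, num_levels, num_points, 2)
attention_weights = torch.rand(batch_size, num_queries, num_heads, num_levels, num_points)
attention_weights = attention_weights / attention_weights.sum(dim=(-2, -1), keepdim=True)

output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
print(output.shape)  # torch.Size([2, 336, 256]) -> (batch_size, num_queries, num_heads * head_dim)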
class Mask2FormerSinePositionEmbedding(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one used by the "Attention Is All
You Need" paper, generalized to work on images.
"""
def __init__(
self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None
):
super().__init__()
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
self.scale = 2 * math.pi if scale is None else scale
def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
if mask is None:
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
not_mask = (~mask).to(x.dtype)
y_embed = not_mask.cumsum(1)
x_embed = not_mask.cumsum(2)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=x.device).type_as(x)
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
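# Usage sketch (illustrative, not part of the original file): for a feature map with 256 channels,
# num_pos_feats=128 yields a position encoding with the same channel count (2 * num_pos_feats).
import torch

position_embedder = Mask2FormerSinePositionEmbedding(num_pos_feats=128, normalize=True)
feature_map = torch.randn(2, 256, 32, 32)
print(position_embedder(feature_map).shape)  # torch.Size([2, 256, 32, 32]) -> (batch_size, 2 * num_pos_feats, height, width)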
class Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention(nn.Module):
"""
Multi-scale deformable attention as proposed in Deformable DETR.
"""
def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int):
super().__init__()
if embed_dim % num_heads != 0:
raise ValueError(
f"embed_dim (d_model) must be divisible by num_heads, but got {embed_dim} and {num_heads}"
)
dim_per_head = embed_dim // num_heads
if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
warnings.warn(
"You'd better set embed_dim (d_model) in DeformableDetrMultiscaleDeformableAttention to make the"
" dimension of each attention head a power of 2 which is more efficient in the authors' CUDA"
" implementation."
)
self.im2col_step = 128
self.d_model = embed_dim
self.n_levels = n_levels
self.n_heads = num_heads
self.n_points = n_points
self.sampling_offsets = nn.Linear(embed_dim, num_heads * n_levels * n_points * 2)
self.attention_weights = nn.Linear(embed_dim, num_heads * n_levels * n_points)
self.value_proj = nn.Linear(embed_dim, embed_dim)
self.output_proj = nn.Linear(embed_dim, embed_dim)
def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
return tensor if position_embeddings is None else tensor + position_embeddings
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states=None,
encoder_attention_mask=None,
position_embeddings: Optional[torch.Tensor] = None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
output_attentions: bool = False,
):
if position_embeddings is not None:
hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
batch_size, num_queries, _ = hidden_states.shape
batch_size, sequence_length, _ = encoder_hidden_states.shape
if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
raise ValueError(
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
)
value = self.value_proj(encoder_hidden_states)
if attention_mask is not None:
value = value.masked_fill(attention_mask[..., None], float(0))
value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(hidden_states).view(
batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
)
attention_weights = self.attention_weights(hidden_states).view(
batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
)
attention_weights = nn.functional.softmax(attention_weights, -1).view(
batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
)
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = (
reference_points[:, :, None, :, None, :]
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
)
elif reference_points.shape[-1] == 4:
sampling_locations = (
reference_points[:, :, None, :, None, :2]
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
)
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
output = self.output_proj(output)
return output, attention_weights
class Mask2FormerPixelDecoderEncoderLayer(nn.Module):
def __init__(self, config: Mask2FormerConfig):
super().__init__()
self.embed_dim = config.feature_size
self.self_attn = Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
n_levels=3,
n_points=4,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = nn.functional.relu
self.activation_dropout = config.dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_feedforward_dim)
self.fc2 = nn.Linear(config.encoder_feedforward_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor = None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
output_attentions: bool = False,
):
"""
Args:
hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Input to the layer.
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Attention mask.
position_embeddings (`torch.FloatTensor`, *optional*):
Position embeddings, to be added to `hidden_states`.
reference_points (`torch.FloatTensor`, *optional*):
Reference points.
spatial_shapes (`torch.LongTensor`, *optional*):
Spatial shapes of the backbone feature maps.
level_start_index (`torch.LongTensor`, *optional*):
Level start index.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
"""
residual = hidden_states
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
if self.training:
if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights.transpose(1, 0),)
return outputs
class Mask2FormerPixelDecoderEncoderOnly(nn.Module):
"""
Transformer encoder consisting of *config.encoder_layers* deformable attention layers. Each layer is a
[`Mask2FormerPixelDecoderEncoderLayer`]. The encoder updates the flattened multi-scale feature maps through
multiple deformable attention layers.
Args:
config: Mask2FormerConfig
"""
def __init__(self, config: Mask2FormerConfig):
super().__init__()
self.config = config
self.dropout = config.dropout
self.layers = nn.ModuleList(
[Mask2FormerPixelDecoderEncoderLayer(config) for _ in range(config.encoder_layers)]
)
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
"""
Get reference points for each feature map. Used in decoder.
Args:
spatial_shapes (`torch.LongTensor`):
Spatial shapes of each feature map, has shape of `(num_feature_levels, 2)`.
valid_ratios (`torch.FloatTensor`):
Valid ratios of each feature map, has shape of `(batch_size, num_feature_levels, 2)`.
device (`torch.device`):
Device on which to create the tensors.
Returns:
`torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)`
"""
reference_points_list = []
for lvl, (height, width) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
indexing="ij",
)
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * height)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * width)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
def forward(
self,
inputs_embeds=None,
attention_mask=None,
position_embeddings=None,
spatial_shapes=None,
level_start_index=None,
valid_ratios=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
class Mask2FormerPixelDecoder(nn.Module):
def __init__(self, config: Mask2FormerConfig, feature_channels):
super().__init__()
self.config = config
feature_dim = config.feature_size
mask_dim = config.mask_feature_size
num_pos_features = feature_dim // 2
self.position_embedding = Mask2FormerSinePositionEmbedding(num_pos_feats=num_pos_features, normalize=True)
self.num_feature_levels = 3
transformer_in_channels = feature_channels[-self.num_feature_levels :]
self.transformer_feature_strides = config.feature_strides[-self.num_feature_levels :]
self.feature_channels = feature_channels
self.level_embed = nn.Parameter(torch.Tensor(self.num_feature_levels, feature_dim))
if self.num_feature_levels > 1:
input_projections_list = []
for in_channels in transformer_in_channels[::-1]:
input_projections_list.append(
nn.Sequential(
nn.Conv2d(in_channels, feature_dim, kernel_size=1),
nn.GroupNorm(32, feature_dim),
)
)
self.input_projections = nn.ModuleList(input_projections_list)
else:
self.input_projections = nn.ModuleList(
[
nn.Sequential(
nn.Conv2d(transformer_in_channels[-1], feature_dim, kernel_size=1),
nn.GroupNorm(32, feature_dim),
)
]
)
self.encoder = Mask2FormerPixelDecoderEncoderOnly(config)
self.mask_projection = nn.Conv2d(feature_dim, mask_dim, kernel_size=1, stride=1, padding=0)
stride = min(self.transformer_feature_strides)
self.common_stride = config.common_stride
self.num_fpn_levels = int(np.log2(stride) - np.log2(self.common_stride))
lateral_convs = []
output_convs = []
for idx, in_channels in enumerate(self.feature_channels[: self.num_fpn_levels]):
lateral_conv = nn.Sequential(
nn.Conv2d(in_channels, feature_dim, kernel_size=1, bias=False),
nn.GroupNorm(32, feature_dim),
)
output_conv = nn.Sequential(
nn.Conv2d(feature_dim, feature_dim, kernel_size=3, stride=1, padding=1, bias=False),
nn.GroupNorm(32, feature_dim),
nn.ReLU(),
)
self.add_module("adapter_{}".format(idx + 1), lateral_conv)
self.add_module("layer_{}".format(idx + 1), output_conv)
lateral_convs.append(lateral_conv)
output_convs.append(output_conv)
self.lateral_convolutions = lateral_convs[::-1]
self.output_convolutions = output_convs[::-1]
def get_valid_ratio(self, mask, dtype=torch.float32):
"""Get the valid ratio of all feature maps."""
_, height, width = mask.shape
valid_height = torch.sum(~mask[:, :, 0], 1)
valid_width = torch.sum(~mask[:, 0, :], 1)
valid_ratio_heigth = valid_height.to(dtype) / height
valid_ratio_width = valid_width.to(dtype) / width
valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1)
return valid_ratio
def forward(
self,
features,
encoder_outputs=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
class Mask2FormerPixelLevelModule(nn.Module):
def __init__(self, config: Mask2FormerConfig):
"""
Pixel Level Module proposed in [Masked-attention Mask Transformer for Universal Image
Segmentation](https://arxiv.org/abs/2112.01527). It runs the input image through a backbone and a pixel
decoder, generating multi-scale feature maps and pixel embeddings.
Args:
config ([`Mask2FormerConfig`]):
The configuration used to instantiate this model.
"""
super().__init__()
self.encoder = load_backbone(config)
self.decoder = Mask2FormerPixelDecoder(config, feature_channels=self.encoder.channels)
def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> Mask2FormerPixelLevelModuleOutput:
backbone_features = self.encoder(pixel_values).feature_maps
decoder_output = self.decoder(backbone_features, output_hidden_states=output_hidden_states)
return Mask2FormerPixelLevelModuleOutput(
encoder_last_hidden_state=backbone_features[-1],
encoder_hidden_states=tuple(backbone_features) if output_hidden_states else None,
decoder_last_hidden_state=decoder_output.mask_features,
decoder_hidden_states=decoder_output.multi_scale_features,
)
class Mask2FormerAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Here, we add position embeddings to the queries and
keys (as explained in the DETR paper).
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
if self.head_dim * num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
return tensor if position_embeddings is None else tensor + position_embeddings
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None,
key_value_states: Optional[torch.Tensor] = None,
key_value_position_embeddings: Optional[torch.Tensor] = None,
output_attentions: bool = False,
"""
Mask2FormerMaskedAttentionDecoderLayer由self-attention、交叉(masked)attention和FFN块组成。
在Mask2FormerMaskedAttentionDecoderLayer中使用的交叉attention实际上是一种限制注意力在预测段周围局部特征的masked attention块,
这导致更快的收敛和更好的性能。相比标准的DetrDecoder,Mask2FormerMaskedAttentionDecoder中的self和cross(即masked)attention块的顺序被交换,
这是一种优化改进。
Args:
config (`Mask2FormerConfig`):
用于初始化Mask2FormerMaskedAttentionDecoder的配置。
"""
def __init__(self, config: Mask2FormerConfig):
super().__init__()
self.config = config
self.embed_dim = self.config.hidden_dim
self.pre_norm = self.config.pre_norm
self.self_attn = Mask2FormerAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=config.dropout,
is_decoder=True,
)
self.dropout = self.config.dropout
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout = self.config.dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.cross_attn = nn.MultiheadAttention(self.embed_dim, self.config.num_attention_heads, self.config.dropout)
self.cross_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, self.config.dim_feedforward)
self.fc2 = nn.Linear(self.config.dim_feedforward, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
"""
Adds the position embedding `pos` to the input tensor if `pos` is not None; otherwise returns the tensor
unchanged.
Args:
tensor (torch.Tensor): The input tensor.
pos (Optional[Tensor]): The optional position embedding tensor.
Returns:
torch.Tensor: The resulting tensor.
"""
return tensor if pos is None else tensor + pos
def forward_post(
self,
hidden_states: torch.Tensor,
level_index: int = None,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None,
query_position_embeddings: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
):
cross_attn_weights = None
self_attn_weights = None
residual = hidden_states
hidden_states, cross_attn_weights = self.cross_attn(
query=self.with_pos_embed(hidden_states, query_position_embeddings),
key=self.with_pos_embed(encoder_hidden_states[level_index], position_embeddings[level_index]),
value=encoder_hidden_states[level_index],
attn_mask=encoder_attention_mask,
key_padding_mask=None,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.cross_attn_layer_norm(hidden_states)
residual = hidden_states
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=query_position_embeddings,
attention_mask=None,
output_attentions=True,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
return outputs
def forward_pre(
self,
hidden_states: torch.Tensor,
level_index: int = None,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None,
query_position_embeddings: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
):
cross_attn_weights = None
self_attn_weights = None
residual = hidden_states
hidden_states = self.cross_attn_layer_norm(hidden_states)
hidden_states, cross_attn_weights = self.cross_attn(
query=self.with_pos_embed(hidden_states, query_position_embeddings),
key=self.with_pos_embed(encoder_hidden_states[level_index], position_embeddings[level_index]),
value=encoder_hidden_states[level_index],
attn_mask=encoder_attention_mask,
key_padding_mask=None,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=query_position_embeddings,
attention_mask=None,
output_attentions=True,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
return outputs
def forward(
self,
hidden_states: torch.Tensor,
level_index: int = None,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None,
query_position_embeddings: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
"""
Args:
hidden_states (`torch.FloatTensor`):
输入到层的张量,形状为 `(seq_len, batch, embed_dim)`。
attention_mask (`torch.FloatTensor`):
注意力遮罩张量,形状为 `(1, seq_len, tgt_len, src_len)`。
position_embeddings (`torch.FloatTensor`, *可选*):
添加到掩码注意力层中键的位置嵌入。
query_position_embeddings (`torch.FloatTensor`, *可选*):
添加到自注意力层中查询和键的位置嵌入。
encoder_hidden_states (`torch.FloatTensor`):
层的交叉注意力输入张量,形状为 `(seq_len, batch, embed_dim)`。
encoder_attention_mask (`torch.FloatTensor`):
编码器注意力遮罩张量,大小为 `(1, seq_len, tgt_len, src_len)`。
output_attentions (`bool`, *可选*):
是否返回所有注意力层的注意力张量。查看返回的张量中的 `attentions` 以获取更多细节。
"""
if self.pre_norm:
outputs = self.forward_pre(
hidden_states=hidden_states,
level_index=level_index,
position_embeddings=position_embeddings,
query_position_embeddings=query_position_embeddings,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
else:
outputs = self.forward_post(
hidden_states=hidden_states,
level_index=level_index,
position_embeddings=position_embeddings,
query_position_embeddings=query_position_embeddings,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
return outputs
class Mask2FormerMaskedAttentionDecoder(nn.Module):
"""
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a
[`Mask2FormerMaskedAttentionDecoderLayer`]. The decoder updates the query embeddings through multiple cross
(masked) and self-attention layers. The decoder uses a new **masked attention** mechanism instead of the standard
cross-attention, which extracts localized features by constraining cross-attention to within the foreground region
of the predicted mask for each query, instead of attending to the full feature map.
Args:
config (`Mask2FormerConfig`):
Configuration used to instantiate Mask2FormerMaskedAttentionDecoder.
"""
def __init__(self, config: Mask2FormerConfig):
super().__init__()
self.config = config
self.mask_feature_size = config.mask_feature_size
self.dropout = config.dropout
self.layerdrop = config.dropout
self.num_feature_levels = 3
self.decoder_layers = config.decoder_layers - 1
self.layers = nn.ModuleList(
[Mask2FormerMaskedAttentionDecoderLayer(self.config) for _ in range(self.decoder_layers)]
)
self.layernorm = nn.LayerNorm(config.hidden_dim)
self.mask_predictor = Mask2FormerMaskPredictor(
hidden_size=config.hidden_dim,
num_heads=config.num_attention_heads,
mask_feature_size=self.mask_feature_size,
)
self.gradient_checkpointing = False
def forward(
self,
inputs_embeds: torch.Tensor = None,
multi_stage_positional_embeddings: torch.Tensor = None,
pixel_embeddings: torch.Tensor = None,
encoder_hidden_states: torch.Tensor = None,
query_position_embeddings: torch.Tensor = None,
feature_size_list: List = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
class Mask2FormerPredictionBlock(nn.Module):
def __init__(self, in_dim: int, out_dim: int, activation: nn.Module) -> None:
super().__init__()
self.layers = [nn.Linear(in_dim, out_dim), activation]
for i, layer in enumerate(self.layers):
self.add_module(str(i), layer)
def forward(self, input: Tensor) -> Tensor:
hidden_state = input
for layer in self.layers:
hidden_state = layer(hidden_state)
return hidden_state
class Mask2FormerMLPPredictionHead(nn.Module):
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 3):
"""
A classic Multi Layer Perceptron (MLP).
Args:
input_dim (`int`):
The input dimensions.
hidden_dim (`int`):
The hidden dimensions.
output_dim (`int`):
The output dimensions.
num_layers (int, *optional*, defaults to 3):
The number of layers.
"""
super().__init__()
in_dims = [input_dim] + [hidden_dim] * (num_layers - 1)
out_dims = [hidden_dim] * (num_layers - 1) + [output_dim]
self.layers = []
for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)):
activation = nn.ReLU() if i < num_layers - 1 else nn.Identity()
layer = Mask2FormerPredictionBlock(in_dim, out_dim, activation=activation)
self.layers.append(layer)
self.add_module(str(i), layer)
def forward(self, input: Tensor) -> Tensor:
hidden_state = input
for layer in self.layers:
hidden_state = layer(hidden_state)
return hidden_state
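# Usage sketch (illustrative, not part of the original file): a 3-layer MLP applied independently
# to each query embedding, as used by the mask embedder below.
import torch

mlp_head = Mask2FormerMLPPredictionHead(input_dim=256, hidden_dim=256, output_dim=256, num_layers=3)
query_embeddings = torch.randn(2, 100, 256)   # (batch_size, num_queries, hidden_size)
print(mlp_head(query_embeddings).shape)       # torch.Size([2, 100, 256])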
class Mask2FormerMaskPredictor(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mask_feature_size: torch.Tensor):
"""
This class is used to get the predicted mask for a given Mask2FormerMaskedAttentionDecoder layer. It also
generates the binarized attention mask associated with the given predicted mask. The attention mask obtained
using predicted mask of the (l-1)th decoder layer is fed to the cross(masked)-attention block of the next
decoder layer as input.
Args:
hidden_size (`int`):
The feature dimension of the Mask2FormerMaskedAttentionDecoder
num_heads (`int`):
The number of heads used in the Mask2FormerMaskedAttentionDecoder
mask_feature_size (`torch.Tensor`):
one of the output dimensions of the predicted masks for each query
"""
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.mask_embedder = Mask2FormerMLPPredictionHead(self.hidden_size, self.hidden_size, mask_feature_size)
def forward(self, outputs: torch.Tensor, pixel_embeddings: torch.Tensor, attention_mask_target_size: int = None):
mask_embeddings = self.mask_embedder(outputs.transpose(0, 1))
is_tracing = (
torch.jit.is_tracing()
or isinstance(outputs, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
if is_tracing and not is_torch_greater_or_equal_than_2_1:
batch_size, num_queries, num_channels = mask_embeddings.shape
_, _, height, width = pixel_embeddings.shape
outputs_mask = torch.zeros((batch_size, num_queries, height, width), device=mask_embeddings.device)
for c in range(num_channels):
outputs_mask += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]
else:
outputs_mask = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
attention_mask = nn.functional.interpolate(
outputs_mask, size=attention_mask_target_size, mode="bilinear", align_corners=False
)
attention_mask = attention_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1)
attention_mask = (attention_mask.flatten(0, 1) < 0.5).bool()
attention_mask = attention_mask.detach()
return outputs_mask, attention_mask
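# Shape sketch (illustrative, not part of the original file, with hypothetical sizes): the predictor
# turns query embeddings and 1/4-scale pixel embeddings into per-query mask logits, plus the
# binarized (< 0.5) attention mask fed to the next masked-attention layer.
import torch

predictor = Mask2FormerMaskPredictor(hidden_size=256, num_heads=8, mask_feature_size=256)
queries = torch.randn(100, 2, 256)               # (num_queries, batch_size, hidden_size)
pixel_embeddings = torch.randn(2, 256, 96, 96)   # (batch_size, mask_feature_size, height, width)
mask_logits, attention_mask = predictor(queries, pixel_embeddings, attention_mask_target_size=(12, 12))
print(mask_logits.shape)     # torch.Size([2, 100, 96, 96])
print(attention_mask.shape)  # torch.Size([16, 100, 144]) -> (batch_size * num_heads, num_queries, 12 * 12)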
class Mask2FormerTransformerModule(nn.Module):
"""
The Mask2Former's transformer module.
"""
def __init__(self, in_features: int, config: Mask2FormerConfig):
super().__init__()
hidden_dim = config.hidden_dim
self.num_feature_levels = 3
self.position_embedder = Mask2FormerSinePositionEmbedding(num_pos_feats=hidden_dim // 2, normalize=True)
self.queries_embedder = nn.Embedding(config.num_queries, hidden_dim)
self.queries_features = nn.Embedding(config.num_queries, hidden_dim)
self.input_projections = []
for _ in range(self.num_feature_levels):
if in_features != hidden_dim or config.enforce_input_projection:
self.input_projections.append(nn.Conv2d(in_features, hidden_dim, kernel_size=1))
else:
self.input_projections.append(nn.Sequential())
self.decoder = Mask2FormerMaskedAttentionDecoder(config=config)
self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
def forward(
self,
multi_scale_features: List[Tensor],
mask_features: Tensor,
output_hidden_states: bool = False,
output_attentions: bool = False,
) -> Mask2FormerMaskedAttentionDecoderOutput:
multi_stage_features = []
multi_stage_positional_embeddings = []
size_list = []
for i in range(self.num_feature_levels):
size_list.append(multi_scale_features[i].shape[-2:])
multi_stage_positional_embeddings.append(self.position_embedder(multi_scale_features[i], None).flatten(2))
multi_stage_features.append(
self.input_projections[i](multi_scale_features[i]).flatten(2)
+ self.level_embed.weight[i][None, :, None]
)
multi_stage_positional_embeddings[-1] = multi_stage_positional_embeddings[-1].permute(2, 0, 1)
multi_stage_features[-1] = multi_stage_features[-1].permute(2, 0, 1)
_, batch_size, _ = multi_stage_features[0].shape
query_embeddings = self.queries_embedder.weight.unsqueeze(1).repeat(1, batch_size, 1)
query_features = self.queries_features.weight.unsqueeze(1).repeat(1, batch_size, 1)
decoder_output = self.decoder(
inputs_embeds=query_features,
multi_stage_positional_embeddings=multi_stage_positional_embeddings,
pixel_embeddings=mask_features,
encoder_hidden_states=multi_stage_features,
query_position_embeddings=query_embeddings,
feature_size_list=size_list,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=True,
)
return decoder_output
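A small, self-contained sketch (dummy tensors, illustrative dimensions) of the flatten/permute step above, which turns one `(batch, channels, H, W)` feature level plus its level embedding into the `(H*W, batch, channels)` layout the masked-attention decoder expects:

```python
import torch

batch_size, channels = 2, 256                 # illustrative values
feature = torch.randn(batch_size, channels, 16, 16)
level_embed = torch.randn(channels)           # stands in for self.level_embed.weight[i]

flattened = feature.flatten(2) + level_embed[None, :, None]  # (batch, channels, H*W)
decoder_input = flattened.permute(2, 0, 1)                   # (H*W, batch, channels)
print(decoder_input.shape)  # torch.Size([256, 2, 256])
```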
MASK2FORMER_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`Mask2FormerConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
MASK2FORMER_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
[`AutoImageProcessor.preprocess`] for details.
pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:
- 1 for pixels that are real (i.e. **not masked**),
- 0 for pixels that are padding (i.e. **masked**).
[What are attention masks?](../glossary#attention-mask)
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of Detr's decoder attention layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~Mask2FormerModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare Mask2Former Model outputting raw hidden-states without any specific head on top.",
MASK2FORMER_START_DOCSTRING,
)
class Mask2FormerModel(Mask2FormerPreTrainedModel):
main_input_name = "pixel_values"
def __init__(self, config: Mask2FormerConfig):
super().__init__(config)
self.pixel_level_module = Mask2FormerPixelLevelModule(config)
self.transformer_module = Mask2FormerTransformerModule(in_features=config.feature_size, config=config)
self.post_init()
@add_start_docstrings_to_model_forward(MASK2FORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Mask2FormerModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Tensor,
pixel_mask: Optional[Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Mask2FormerModelOutput:
@add_start_docstrings(
"The Mask2Former Model with heads on top for instance/semantic/panoptic segmentation.",
MASK2FORMER_START_DOCSTRING,
)
class Mask2FormerForUniversalSegmentation(Mask2FormerPreTrainedModel):
main_input_name = "pixel_values"
def __init__(self, config: Mask2FormerConfig):
super().__init__(config)
self.model = Mask2FormerModel(config)
self.weight_dict: Dict[str, float] = {
"loss_cross_entropy": config.class_weight,
"loss_mask": config.mask_weight,
"loss_dice": config.dice_weight,
}
self.class_predictor = nn.Linear(config.hidden_dim, config.num_labels + 1)
self.criterion = Mask2FormerLoss(config=config, weight_dict=self.weight_dict)
self.post_init()
def get_loss_dict(
self,
masks_queries_logits: Tensor,
class_queries_logits: Tensor,
mask_labels: Tensor,
class_labels: Tensor,
auxiliary_predictions: Dict[str, Tensor],
) -> Dict[str, Tensor]:
loss_dict: Dict[str, Tensor] = self.criterion(
masks_queries_logits=masks_queries_logits,
class_queries_logits=class_queries_logits,
mask_labels=mask_labels,
class_labels=class_labels,
auxiliary_predictions=auxiliary_predictions,
)
for key, weight in self.weight_dict.items():
for loss_key, loss in loss_dict.items():
if key in loss_key:
loss *= weight
return loss_dict
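Note that `loss *= weight` multiplies each loss tensor in place, so the entries of `loss_dict` themselves end up rescaled even though the dict is never reassigned. A tiny sketch of that behaviour:

```python
import torch

loss_dict = {"loss_mask": torch.tensor(2.0), "loss_dice": torch.tensor(3.0)}
weight_dict = {"loss_mask": 20.0, "loss_dice": 1.0}

for key, weight in weight_dict.items():
    for loss_key, loss in loss_dict.items():
        if key in loss_key:
            loss *= weight  # in-place multiply mutates the tensor stored in loss_dict

print(loss_dict)  # {'loss_mask': tensor(40.), 'loss_dice': tensor(3.)}
```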
def get_loss(self, loss_dict: Dict[str, Tensor]) -> Tensor:
return sum(loss_dict.values())
def get_auxiliary_logits(self, classes: torch.Tensor, output_masks: torch.Tensor):
auxiliary_logits: List[Dict[str, Tensor]] = []
for aux_binary_masks, aux_classes in zip(output_masks[:-1], classes[:-1]):
auxiliary_logits.append({"masks_queries_logits": aux_binary_masks, "class_queries_logits": aux_classes})
return auxiliary_logits
@add_start_docstrings_to_model_forward(MASK2FORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Mask2FormerForUniversalSegmentationOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Tensor,
mask_labels: Optional[List[Tensor]] = None,
class_labels: Optional[List[Tensor]] = None,
pixel_mask: Optional[Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_auxiliary_logits: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
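The forward body is elided above; a minimal inference sketch, assuming the public `facebook/mask2former-swin-small-coco-instance` checkpoint and the COCO test image also used by the converters further below:

```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-coco-instance")
model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-small-coco-instance"
)

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Post-process the per-query class and mask logits into an instance segmentation map.
result = processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
print(result["segmentation"].shape)  # (height, width)
```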
.\models\mask2former\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {
"configuration_mask2former": [
"MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"Mask2FormerConfig",
],
}
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_mask2former"] = ["Mask2FormerImageProcessor"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_mask2former"] = [
"MASK2FORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"Mask2FormerForUniversalSegmentation",
"Mask2FormerModel",
"Mask2FormerPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_mask2former import MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, Mask2FormerConfig
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_mask2former import Mask2FormerImageProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_mask2former import (
MASK2FORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
Mask2FormerForUniversalSegmentation,
Mask2FormerModel,
Mask2FormerPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
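A quick illustration (a sketch, not a library test) of the lazy-import behaviour set up here: the package module is replaced by a `_LazyModule`, and the heavy submodules are only imported when one of their attributes is first accessed:

```python
import importlib
import sys

mod = importlib.import_module("transformers.models.mask2former")
print(type(mod).__name__)  # _LazyModule

# Accessing an attribute triggers the actual import of configuration_mask2former.
config_cls = mod.Mask2FormerConfig
print(config_cls.model_type)  # mask2former
print("transformers.models.mask2former.configuration_mask2former" in sys.modules)  # True
```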
.\models\maskformer\configuration_maskformer.py
""" MaskFormer model configuration"""
from typing import Dict, Optional
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING
from ..detr import DetrConfig
from ..swin import SwinConfig
MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/maskformer-swin-base-ade": (
"https://huggingface.co/facebook/maskformer-swin-base-ade/blob/main/config.json"
)
}
logger = logging.get_logger(__name__)
class MaskFormerConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MaskFormerModel`]. It is used to instantiate a
MaskFormer model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the MaskFormer
[facebook/maskformer-swin-base-ade](https://huggingface.co/facebook/maskformer-swin-base-ade) architecture trained
on [ADE20k-150](https://huggingface.co/datasets/scene_parse_150).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Currently, MaskFormer only supports the [Swin Transformer](swin) as backbone.
Args:
# Size of the mask features; this value is also used for the Feature Pyramid Network feature size (default 256)
mask_feature_size (`int`, *optional*, defaults to 256):
The masks' features size, this value will also be used to specify the Feature Pyramid Network features'
size.
# Weight applied to the null (no-object) class (default 0.1)
no_object_weight (`float`, *optional*, defaults to 0.1):
Weight to apply to the null (no object) class.
# Whether to use auxiliary losses (default False)
use_auxiliary_loss (`bool`, *optional*, defaults to `False`):
If `True` [`MaskFormerForInstanceSegmentationOutput`] will contain the auxiliary losses computed using the
logits from each decoder's stage.
# If backbone_config is unset, the default `swin-base-patch4-window12-384` configuration is used
backbone_config (`Dict`, *optional*):
The configuration passed to the backbone, if unset, the configuration corresponding to
`swin-base-patch4-window12-384` will be used.
# Name of the backbone to use when `backbone_config` is None
backbone (`str`, *optional*):
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
# Whether to use pretrained backbone weights (default False)
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
# Whether to load the backbone from the timm library (default False)
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
# Keyword arguments passed to AutoBackbone when loading from a checkpoint, e.g. `{'out_indices': (0, 1, 2, 3)}`
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint,
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
# Configuration passed to the transformer decoder model; if unset, the base `detr-resnet-50` config is used
decoder_config (`Dict`, *optional*):
The configuration passed to the transformer decoder model, if unset the base config for `detr-resnet-50`
will be used.
# Standard deviation of the truncated-normal initializer for all weight matrices (default 0.02)
init_std (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
# Scaling factor for the Xavier initialization gain in the HM Attention map module (default 1)
init_xavier_std (`float`, *optional*, defaults to 1):
The scaling factor used for the Xavier initialization gain in the HM Attention map module.
# Weight of the dice loss (default 1.0)
dice_weight (`float`, *optional*, defaults to 1.0):
The weight for the dice loss.
# Weight of the cross-entropy loss (default 1.0)
cross_entropy_weight (`float`, *optional*, defaults to 1.0):
The weight for the cross entropy loss.
# Weight of the mask loss (default 20.0)
mask_weight (`float`, *optional*, defaults to 20.0):
The weight for the mask loss.
# Whether the model should output its auxiliary logits (unset by default)
output_auxiliary_logits (`bool`, *optional*):
Should the model output its `auxiliary_logits` or not.
# A `ValueError` is raised if the chosen backbone type is not in `["swin"]` or the decoder type is not in `["detr"]`
Raises:
`ValueError`:
Raised if the backbone model type selected is not in `["swin"]` or the decoder model type selected is not
in `["detr"]`
Examples:
>>> from transformers import MaskFormerConfig, MaskFormerModel
>>> # Initializing a MaskFormer configuration object using default values
>>> configuration = MaskFormerConfig()
>>> # Initializing a MaskFormerModel (with random weights) from the configuration
>>> model = MaskFormerModel(configuration)
>>> # Accessing the configuration of the model instance
>>> configuration = model.config
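Beyond the defaults, the `backbone_config`/`decoder_config` arguments documented above also accept ready-made config objects. A hedged sketch (illustrative hyperparameters):

```python
from transformers import DetrConfig, MaskFormerConfig, SwinConfig

backbone_config = SwinConfig(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7)
decoder_config = DetrConfig(decoder_layers=6, d_model=256)

config = MaskFormerConfig(
    backbone_config=backbone_config,
    decoder_config=decoder_config,
    num_labels=150,        # e.g. ADE20k-150
    mask_feature_size=256,
)
print(config.backbone_config.model_type, config.decoder_config.model_type)  # swin detr
```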
.\models\maskformer\configuration_maskformer_swin.py
"""
MaskFormer Swin Transformer model configuration
"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
logger = logging.get_logger(__name__)
class MaskFormerSwinConfig(BackboneConfigMixin, PretrainedConfig):
"""
This is the configuration class to store the configuration of a [`MaskFormerSwinModel`]. It is used to instantiate
a MaskFormer Swin backbone according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Swin
[microsoft/swin-tiny-patch4-window7-224](https://huggingface.co/microsoft/swin-tiny-patch4-window7-224)
architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```
>>> from transformers import MaskFormerSwinConfig, MaskFormerSwinModel
>>> # Initializing a microsoft/swin-tiny-patch4-window7-224 style configuration
>>> configuration = MaskFormerSwinConfig()
>>> # Initializing a model (with random weights) from the microsoft/swin-tiny-patch4-window7-224 style configuration
>>> model = MaskFormerSwinModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "maskformer-swin"
attribute_map = {
"num_attention_heads": "num_heads",
"num_hidden_layers": "num_layers",
}
def __init__(
self,
image_size=224,
patch_size=4,
num_channels=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
drop_path_rate=0.1,
hidden_act="gelu",
use_absolute_embeddings=False,
initializer_range=0.02,
layer_norm_eps=1e-5,
out_features=None,
out_indices=None,
**kwargs,
):
super().__init__(**kwargs)
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.embed_dim = embed_dim
self.depths = depths
self.num_layers = len(depths)
self.num_heads = num_heads
self.window_size = window_size
self.mlp_ratio = mlp_ratio
self.qkv_bias = qkv_bias
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.drop_path_rate = drop_path_rate
self.hidden_act = hidden_act
self.use_absolute_embeddings = use_absolute_embeddings
self.layer_norm_eps = layer_norm_eps
self.initializer_range = initializer_range
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
)
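A short sketch of how `out_features` and `out_indices` end up aligned against the `stage_names` list built above (illustrative values):

```python
from transformers import MaskFormerSwinConfig

config = MaskFormerSwinConfig(depths=[2, 2, 6, 2], out_features=["stage1", "stage3"])
print(config.stage_names)   # ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
print(config.out_features)  # ['stage1', 'stage3']
print(config.out_indices)   # indices into stage_names, i.e. positions 1 and 3
```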
.\models\maskformer\convert_maskformer_original_pytorch_checkpoint_to_pytorch.py
import sys
from argparse import ArgumentParser
from dataclasses import dataclass
from pathlib import Path
from pprint import pformat
from typing import Any, Dict, Iterator, List, Set, Tuple
import requests
import torch
import torchvision.transforms as T
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.projects.deeplab import add_deeplab_config
from PIL import Image
from torch import Tensor, nn
from transformers.models.maskformer.feature_extraction_maskformer import MaskFormerImageProcessor
from transformers.models.maskformer.modeling_maskformer import (
MaskFormerConfig,
MaskFormerForInstanceSegmentation,
MaskFormerForInstanceSegmentationOutput,
MaskFormerModel,
MaskFormerModelOutput,
)
from transformers.utils import logging
StateDict = Dict[str, Tensor]
logging.set_verbosity_info()
logger = logging.get_logger()
torch.manual_seed(0)
class TrackedStateDict:
def __init__(self, to_track: Dict):
"""This class "tracks" a python dictionary by keeping track of which item is accessed.
Args:
to_track (Dict): The dictionary we wish to track
"""
self.to_track = to_track
self._seen: Set[str] = set()
def __getitem__(self, key: str) -> Any:
return self.to_track[key]
def __setitem__(self, key: str, item: Any):
self._seen.add(key)
self.to_track[key] = item
def diff(self) -> List[str]:
"""This method returns a set difference between the keys in the tracked state dict and the one we have access so far.
This is an effective method to check if we have update all the keys
Returns:
List[str]: List of keys not yet updated
"""
return set(self.to_track.keys()) - self._seen
def copy(self) -> Dict:
return self.to_track.copy()
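A toy usage example for `TrackedStateDict` (plain Python values instead of real weights), showing how `diff()` reveals the keys that were never written during conversion:

```python
tracked = TrackedStateDict({"a": 1, "b": 2, "c": 3})

tracked["a"] = 10  # marks "a" as seen
tracked["b"] = 20  # marks "b" as seen

print(tracked.diff())  # {'c'} -> keys that were never updated
print(tracked.copy())  # {'a': 10, 'b': 20, 'c': 3}
```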
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
img_data = requests.get(url, stream=True).raw
im = Image.open(img_data)
return im
@dataclass
class Args:
"""Fake command line arguments needed by maskformer/detectron implementation"""
config_file: str
def setup_cfg(args: Args):
cfg = get_cfg()
add_deeplab_config(cfg)
add_mask_former_config(cfg)
cfg.merge_from_file(args.config_file)
cfg.freeze()
return cfg
class OriginalMaskFormerConfigToOursConverter:
def __call__(self, original_config: object) -> MaskFormerConfig:
model = original_config.MODEL
mask_former = model.MASK_FORMER
swin = model.SWIN
dataset_catalog = MetadataCatalog.get(original_config.DATASETS.TEST[0])
id2label = dict(enumerate(dataset_catalog.stuff_classes))
label2id = {label: idx for idx, label in id2label.items()}
config: MaskFormerConfig = MaskFormerConfig(
fpn_feature_size=model.SEM_SEG_HEAD.CONVS_DIM,
mask_feature_size=model.SEM_SEG_HEAD.MASK_DIM,
num_labels=model.SEM_SEG_HEAD.NUM_CLASSES,
no_object_weight=mask_former.NO_OBJECT_WEIGHT,
num_queries=mask_former.NUM_OBJECT_QUERIES,
backbone_config={
"pretrain_img_size": swin.PRETRAIN_IMG_SIZE,
"image_size": swin.PRETRAIN_IMG_SIZE,
"in_channels": 3,
"patch_size": swin.PATCH_SIZE,
"embed_dim": swin.EMBED_DIM,
"depths": swin.DEPTHS,
"num_heads": swin.NUM_HEADS,
"window_size": swin.WINDOW_SIZE,
"drop_path_rate": swin.DROP_PATH_RATE,
"model_type": "swin",
},
dice_weight=mask_former.DICE_WEIGHT,
ce_weight=1.0,
mask_weight=mask_former.MASK_WEIGHT,
decoder_config={
"model_type": "detr",
"max_position_embeddings": 1024,
"encoder_layers": 6,
"encoder_ffn_dim": 2048,
"encoder_attention_heads": 8,
"decoder_layers": mask_former.DEC_LAYERS,
"decoder_ffn_dim": mask_former.DIM_FEEDFORWARD,
"decoder_attention_heads": mask_former.NHEADS,
"encoder_layerdrop": 0.0,
"decoder_layerdrop": 0.0,
"d_model": mask_former.HIDDEN_DIM,
"dropout": mask_former.DROPOUT,
"attention_dropout": 0.0,
"activation_dropout": 0.0,
"init_std": 0.02,
"init_xavier_std": 1.0,
"scale_embedding": False,
"auxiliary_loss": False,
"dilation": False,
},
id2label=id2label,
label2id=label2id,
)
return config
class OriginalMaskFormerConfigToImageProcessorConverter:
def __call__(self, original_config: object) -> MaskFormerImageProcessor:
model = original_config.MODEL
model_input = original_config.INPUT
dataset_catalog = MetadataCatalog.get(original_config.DATASETS.TEST[0])
return MaskFormerImageProcessor(
image_mean=(torch.tensor(model.PIXEL_MEAN) / 255).tolist(),
image_std=(torch.tensor(model.PIXEL_STD) / 255).tolist(),
size=model_input.MIN_SIZE_TEST,
max_size=model_input.MAX_SIZE_TEST,
num_labels=model.SEM_SEG_HEAD.NUM_CLASSES,
ignore_index=dataset_catalog.ignore_label,
size_divisibility=32,
)
class OriginalMaskFormerCheckpointToOursConverter:
def __init__(self, original_model: nn.Module, config: MaskFormerConfig):
self.original_model = original_model
self.config = config
def pop_all(self, renamed_keys: List[Tuple[str, str]], dst_state_dict: StateDict, src_state_dict: StateDict):
for src_key, dst_key in renamed_keys:
dst_state_dict[dst_key] = src_state_dict.pop(src_key)
def replace_pixel_module(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = "pixel_level_module.decoder"
src_prefix: str = "sem_seg_head.pixel_decoder"
self.replace_backbone(dst_state_dict, src_state_dict, self.config)
def rename_keys_for_conv(detectron_conv: str, mine_conv: str):
return [
(f"{detectron_conv}.weight", f"{mine_conv}.0.weight"),
(f"{detectron_conv}.norm.weight", f"{mine_conv}.1.weight"),
(f"{detectron_conv}.norm.bias", f"{mine_conv}.1.bias"),
]
renamed_keys = [
(f"{src_prefix}.mask_features.weight", f"{dst_prefix}.mask_projection.weight"),
(f"{src_prefix}.mask_features.bias", f"{dst_prefix}.mask_projection.bias"),
]
renamed_keys.extend(rename_keys_for_conv(f"{src_prefix}.layer_4", f"{dst_prefix}.fpn.stem"))
for src_i, dst_i in zip(range(3, 0, -1), range(0, 3)):
renamed_keys.extend(
rename_keys_for_conv(f"{src_prefix}.adapter_{src_i}", f"{dst_prefix}.fpn.layers.{dst_i}.proj")
)
renamed_keys.extend(
rename_keys_for_conv(f"{src_prefix}.layer_{src_i}", f"{dst_prefix}.fpn.layers.{dst_i}.block")
)
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
def rename_keys_in_detr_decoder(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = "transformer_module.decoder"
src_prefix: str = "sem_seg_head.predictor.transformer.decoder"
rename_keys = []
for i in range(self.config.decoder_config.decoder_layers):
rename_keys.append(
(
f"{src_prefix}.layers.{i}.self_attn.out_proj.weight",
f"{dst_prefix}.layers.{i}.self_attn.out_proj.weight",
)
)
rename_keys.append(
(
f"{src_prefix}.layers.{i}.self_attn.out_proj.bias",
f"{dst_prefix}.layers.{i}.self_attn.out_proj.bias",
)
)
rename_keys.append(
(
f"{src_prefix}.layers.{i}.multihead_attn.out_proj.weight",
f"{dst_prefix}.layers.{i}.encoder_attn.out_proj.weight",
)
)
rename_keys.append(
(
f"{src_prefix}.layers.{i}.multihead_attn.out_proj.bias",
f"{dst_prefix}.layers.{i}.encoder_attn.out_proj.bias",
)
)
rename_keys.append((f"{src_prefix}.layers.{i}.linear1.weight", f"{dst_prefix}.layers.{i}.fc1.weight"))
rename_keys.append((f"{src_prefix}.layers.{i}.linear1.bias", f"{dst_prefix}.layers.{i}.fc1.bias"))
rename_keys.append((f"{src_prefix}.layers.{i}.linear2.weight", f"{dst_prefix}.layers.{i}.fc2.weight"))
rename_keys.append((f"{src_prefix}.layers.{i}.linear2.bias", f"{dst_prefix}.layers.{i}.fc2.bias"))
rename_keys.append(
(f"{src_prefix}.layers.{i}.norm1.weight", f"{dst_prefix}.layers.{i}.self_attn_layer_norm.weight")
)
rename_keys.append(
(f"{src_prefix}.layers.{i}.norm1.bias", f"{dst_prefix}.layers.{i}.self_attn_layer_norm.bias")
)
rename_keys.append(
(f"{src_prefix}.layers.{i}.norm2.weight", f"{dst_prefix}.layers.{i}.encoder_attn_layer_norm.weight")
)
rename_keys.append(
(f"{src_prefix}.layers.{i}.norm2.bias", f"{dst_prefix}.layers.{i}.encoder_attn_layer_norm.bias")
)
rename_keys.append(
(f"{src_prefix}.layers.{i}.norm3.weight", f"{dst_prefix}.layers.{i}.final_layer_norm.weight")
)
rename_keys.append(
(f"{src_prefix}.layers.{i}.norm3.bias", f"{dst_prefix}.layers.{i}.final_layer_norm.bias")
)
return rename_keys
def replace_q_k_v_in_detr_decoder(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = "transformer_module.decoder"
src_prefix: str = "sem_seg_head.predictor.transformer.decoder"
for i in range(self.config.decoder_config.decoder_layers):
in_proj_weight = src_state_dict.pop(f"{src_prefix}.layers.{i}.self_attn.in_proj_weight")
in_proj_bias = src_state_dict.pop(f"{src_prefix}.layers.{i}.self_attn.in_proj_bias")
dst_state_dict[f"{dst_prefix}.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
dst_state_dict[f"{dst_prefix}.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
dst_state_dict[f"{dst_prefix}.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
dst_state_dict[f"{dst_prefix}.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
dst_state_dict[f"{dst_prefix}.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
dst_state_dict[f"{dst_prefix}.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
in_proj_weight_cross_attn = src_state_dict.pop(f"{src_prefix}.layers.{i}.multihead_attn.in_proj_weight")
in_proj_bias_cross_attn = src_state_dict.pop(f"{src_prefix}.layers.{i}.multihead_attn.in_proj_bias")
dst_state_dict[f"{dst_prefix}.layers.{i}.encoder_attn.q_proj.weight"] = in_proj_weight_cross_attn[:256, :]
dst_state_dict[f"{dst_prefix}.layers.{i}.encoder_attn.q_proj.bias"] = in_proj_bias_cross_attn[:256]
dst_state_dict[f"{dst_prefix}.layers.{i}.encoder_attn.k_proj.weight"] = in_proj_weight_cross_attn[256:512, :]
dst_state_dict[f"{dst_prefix}.layers.{i}.encoder_attn.k_proj.bias"] = in_proj_bias_cross_attn[256:512]
dst_state_dict[f"{dst_prefix}.layers.{i}.encoder_attn.v_proj.weight"] = in_proj_weight_cross_attn[-256:, :]
dst_state_dict[f"{dst_prefix}.layers.{i}.encoder_attn.v_proj.bias"] = in_proj_bias_cross_attn[-256:]
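A standalone sketch of the slicing performed above: PyTorch's `nn.MultiheadAttention` stores the query/key/value projections fused along the first dimension, and the converter splits them back into separate `q_proj`/`k_proj`/`v_proj` weights (256 is the DETR decoder hidden size):

```python
import torch

hidden_size = 256
in_proj_weight = torch.randn(3 * hidden_size, hidden_size)  # fused [q; k; v]
in_proj_bias = torch.randn(3 * hidden_size)

q_w = in_proj_weight[:hidden_size, :]
k_w = in_proj_weight[hidden_size : hidden_size * 2, :]
v_w = in_proj_weight[-hidden_size:, :]
q_b, k_b, v_b = in_proj_bias[:hidden_size], in_proj_bias[hidden_size : hidden_size * 2], in_proj_bias[-hidden_size:]
print(q_w.shape, k_w.shape, v_w.shape)  # each torch.Size([256, 256])
```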
def replace_detr_decoder(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = "transformer_module.decoder"
src_prefix: str = "sem_seg_head.predictor.transformer.decoder"
renamed_keys = self.rename_keys_in_detr_decoder(dst_state_dict, src_state_dict)
renamed_keys.extend(
[
(f"{src_prefix}.norm.weight", f"{dst_prefix}.layernorm.weight"),
(f"{src_prefix}.norm.bias", f"{dst_prefix}.layernorm.bias"),
]
)
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
self.replace_q_k_v_in_detr_decoder(dst_state_dict, src_state_dict)
def replace_transformer_module(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = "transformer_module"
src_prefix: str = "sem_seg_head.predictor"
self.replace_detr_decoder(dst_state_dict, src_state_dict)
renamed_keys = [
(f"{src_prefix}.query_embed.weight", f"{dst_prefix}.queries_embedder.weight"),
(f"{src_prefix}.input_proj.weight", f"{dst_prefix}.input_projection.weight"),
(f"{src_prefix}.input_proj.bias", f"{dst_prefix}.input_projection.bias"),
]
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
def replace_instance_segmentation_module(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = ""
src_prefix: str = "sem_seg_head.predictor"
renamed_keys = [
(f"{src_prefix}.class_embed.weight", f"{dst_prefix}class_predictor.weight"),
(f"{src_prefix}.class_embed.bias", f"{dst_prefix}class_predictor.bias"),
]
mlp_len = 3
for i in range(mlp_len):
renamed_keys.extend(
[
(f"{src_prefix}.mask_embed.layers.{i}.weight", f"{dst_prefix}mask_embedder.{i}.0.weight"),
(f"{src_prefix}.mask_embed.layers.{i}.bias", f"{dst_prefix}mask_embedder.{i}.0.bias"),
]
)
logger.info(f"Replacing keys {pformat(renamed_keys)}")
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
def convert(self, mask_former: MaskFormerModel) -> MaskFormerModel:
dst_state_dict = TrackedStateDict(mask_former.state_dict())
src_state_dict = self.original_model.state_dict()
self.replace_pixel_module(dst_state_dict, src_state_dict)
self.replace_transformer_module(dst_state_dict, src_state_dict)
logger.info(f"Missed keys are {pformat(dst_state_dict.diff())}")
logger.info(f"Not copied keys are {pformat(src_state_dict.keys())}")
logger.info("🙌 Done")
mask_former.load_state_dict(dst_state_dict)
return mask_former
def convert_instance_segmentation(
self, mask_former: MaskFormerForInstanceSegmentation
) -> MaskFormerForInstanceSegmentation:
dst_state_dict = TrackedStateDict(mask_former.state_dict())
src_state_dict = self.original_model.state_dict()
self.replace_instance_segmentation_module(dst_state_dict, src_state_dict)
mask_former.load_state_dict(dst_state_dict)
return mask_former
@staticmethod
def using_dirs(checkpoints_dir: Path, config_dir: Path) -> Iterator[Tuple[object, Path, Path]]:
checkpoints: List[Path] = checkpoints_dir.glob("**/*.pkl")
for checkpoint in checkpoints:
logger.info(f"💪 Converting {checkpoint.stem}")
config: Path = config_dir / checkpoint.parents[0].stem / "swin" / f"{checkpoint.stem}.yaml"
yield config, checkpoint
def test(original_model, our_model: MaskFormerForInstanceSegmentation, image_processor: MaskFormerImageProcessor):
with torch.no_grad():
original_model = original_model.eval()
our_model = our_model.eval()
im = prepare_img()
tr = T.Compose(
[
T.Resize((384, 384)),
T.ToTensor(),
T.Normalize(
mean=torch.tensor([123.675, 116.280, 103.530]) / 255.0,
std=torch.tensor([58.395, 57.120, 57.375]) / 255.0,
),
],
)
x = tr(im).unsqueeze(0)
original_model_backbone_features = original_model.backbone(x.clone())
our_model_output: MaskFormerModelOutput = our_model.model(x.clone(), output_hidden_states=True)
for original_model_feature, our_model_feature in zip(
original_model_backbone_features.values(), our_model_output.encoder_hidden_states
):
assert torch.allclose(
original_model_feature, our_model_feature, atol=1e-3
), "The backbone features are not the same."
original_model_pixel_out = original_model.sem_seg_head.pixel_decoder.forward_features(
original_model_backbone_features
)
assert torch.allclose(
original_model_pixel_out[0], our_model_output.pixel_decoder_last_hidden_state, atol=1e-4
), "The pixel decoder feature are not the same"
original_model_out = original_model([{"image": x.squeeze(0)}])
original_segmentation = original_model_out[0]["sem_seg"]
our_model_out: MaskFormerForInstanceSegmentationOutput = our_model(x)
our_segmentation = image_processor.post_process_segmentation(our_model_out, target_size=(384, 384))
assert torch.allclose(
original_segmentation, our_segmentation, atol=1e-3
), "The segmentation image is not the same."
logger.info("✅ Test passed!")
def get_name(checkpoint_file: Path):
model_name_raw: str = checkpoint_file.stem
parent_name: str = checkpoint_file.parents[0].stem
backbone = "swin"
dataset = ""
if "coco" in parent_name:
dataset = "coco"
elif "ade" in parent_name:
dataset = "ade"
else:
raise ValueError(f"{parent_name} must be wrong since we didn't find 'coco' or 'ade' in it ")
backbone_types = ["tiny", "small", "base", "large"]
backbone_type = list(filter(lambda x: x in model_name_raw, backbone_types))[0]
model_name = f"maskformer-{backbone}-{backbone_type}-{dataset}"
return model_name
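A quick example of the naming convention (the checkpoint path is hypothetical; only the dataset name in the parent directory and the backbone size in the file name matter):

```python
from pathlib import Path

checkpoint_file = Path("checkpoints/maskformer_ade20k/maskformer_swin_base_ade20k.pkl")  # hypothetical path
print(get_name(checkpoint_file))  # maskformer-swin-base-ade
```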
if __name__ == "__main__":
parser = ArgumentParser(
description="Command line to convert the original maskformers (with swin backbone) to our implementations."
)
parser.add_argument(
"--checkpoints_dir",
type=Path,
help=(
"A directory containing the model's checkpoints. The directory has to have the following structure:"
" <DIR_NAME>/<DATASET_NAME>/<CONFIG_NAME>.pkl"
),
)
parser.add_argument(
"--configs_dir",
type=Path,
help=(
"A directory containing the model's configs, see detectron2 doc. The directory has to have the following"
" structure: <DIR_NAME>/<DATASET_NAME>/<CONFIG_NAME>.yaml"
),
)
parser.add_argument(
"--pytorch_dump_folder_path",
required=True,
type=Path,
help="Path to the folder to output PyTorch models.",
)
parser.add_argument(
"--maskformer_dir",
required=True,
type=Path,
help=(
"A path to MaskFormer's original implementation directory. You can download from here:"
" https://github.com/facebookresearch/MaskFormer"
),
)
args = parser.parse_args()
checkpoints_dir: Path = args.checkpoints_dir
config_dir: Path = args.configs_dir
save_directory: Path = args.pytorch_dump_folder_path
maskformer_dir: Path = args.maskformer_dir
sys.path.append(str(maskformer_dir.parent))
from MaskFormer.mask_former import add_mask_former_config
from MaskFormer.mask_former.mask_former_model import MaskFormer as OriginalMaskFormer
if not save_directory.exists():
save_directory.mkdir(parents=True)
for config_file, checkpoint_file in OriginalMaskFormerCheckpointToOursConverter.using_dirs(
checkpoints_dir, config_dir
):
image_processor = OriginalMaskFormerConfigToImageProcessorConverter()(setup_cfg(Args(config_file=config_file)))
original_config = setup_cfg(Args(config_file=config_file))
mask_former_kwargs = OriginalMaskFormer.from_config(original_config)
original_model = OriginalMaskFormer(**mask_former_kwargs).eval()
DetectionCheckpointer(original_model).load(str(checkpoint_file))
config: MaskFormerConfig = OriginalMaskFormerConfigToOursConverter()(original_config)
mask_former = MaskFormerModel(config=config).eval()
converter = OriginalMaskFormerCheckpointToOursConverter(original_model, config)
maskformer = converter.convert(mask_former)
mask_former_for_instance_segmentation = MaskFormerForInstanceSegmentation(config=config).eval()
mask_former_for_instance_segmentation.model = mask_former
mask_former_for_instance_segmentation = converter.convert_instance_segmentation(
mask_former_for_instance_segmentation
)
test(original_model, mask_former_for_instance_segmentation, image_processor)
model_name = get_name(checkpoint_file)
logger.info(f"🪄 Saving {model_name}")
image_processor.save_pretrained(save_directory / model_name)
mask_former_for_instance_segmentation.save_pretrained(save_directory / model_name)
image_processor.push_to_hub(
repo_path_or_name=save_directory / model_name,
commit_message="Add model",
use_temp_dir=True,
)
mask_former_for_instance_segmentation.push_to_hub(
repo_path_or_name=save_directory / model_name,
commit_message="Add model",
use_temp_dir=True,
)
.\models\maskformer\convert_maskformer_resnet_to_pytorch.py
"""Convert MaskFormer checkpoints with ResNet backbone from the original repository. URL:
https://github.com/facebookresearch/MaskFormer"""
import argparse
import json
import pickle
from pathlib import Path
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation, MaskFormerImageProcessor, ResNetConfig
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_maskformer_config(model_name: str):
if "resnet101c" in model_name:
raise NotImplementedError("To do")
elif "resnet101" in model_name:
backbone_config = ResNetConfig.from_pretrained(
"microsoft/resnet-101", out_features=["stage1", "stage2", "stage3", "stage4"]
)
else:
backbone_config = ResNetConfig.from_pretrained(
"microsoft/resnet-50", out_features=["stage1", "stage2", "stage3", "stage4"]
)
config = MaskFormerConfig(backbone_config=backbone_config)
repo_id = "huggingface/label-files"
if "ade20k-full" in model_name:
config.num_labels = 847
filename = "maskformer-ade20k-full-id2label.json"
elif "ade" in model_name:
config.num_labels = 150
filename = "ade20k-id2label.json"
elif "coco-stuff" in model_name:
config.num_labels = 171
filename = "maskformer-coco-stuff-id2label.json"
elif "coco" in model_name:
config.num_labels = 133
filename = "coco-panoptic-id2label.json"
elif "cityscapes" in model_name:
config.num_labels = 19
filename = "cityscapes-id2label.json"
elif "vistas" in model_name:
config.num_labels = 65
filename = "mapillary-vistas-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
return config
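Example usage of the helper above (requires network access for the ResNet backbone config and the `huggingface/label-files` id2label mapping):

```python
config = get_maskformer_config("maskformer-resnet50-ade")
print(config.backbone_config.model_type)  # resnet
print(config.num_labels)                  # 150
print(config.id2label[0])                 # first ADE20k class, e.g. "wall"
```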
def create_rename_keys(config):
rename_keys = []
rename_keys.append(("backbone.stem.conv1.weight", "model.pixel_level_module.encoder.embedder.embedder.convolution.weight"))
rename_keys.append(("backbone.stem.conv1.norm.weight", "model.pixel_level_module.encoder.embedder.embedder.normalization.weight"))
rename_keys.append(("backbone.stem.conv1.norm.bias", "model.pixel_level_module.encoder.embedder.embedder.normalization.bias"))
rename_keys.append(("backbone.stem.conv1.norm.running_mean", "model.pixel_level_module.encoder.embedder.embedder.normalization.running_mean"))
rename_keys.append(("backbone.stem.conv1.norm.running_var", "model.pixel_level_module.encoder.embedder.embedder.normalization.running_var"))
rename_keys.append(("sem_seg_head.layer_4.weight", "model.pixel_level_module.decoder.fpn.stem.0.weight"))
rename_keys.append(("sem_seg_head.layer_4.norm.weight", "model.pixel_level_module.decoder.fpn.stem.1.weight"))
rename_keys.append(("sem_seg_head.layer_4.norm.bias", "model.pixel_level_module.decoder.fpn.stem.1.bias"))
for source_index, target_index in zip(range(3, 0, -1), range(0, 3)):
rename_keys.append((f"sem_seg_head.adapter_{source_index}.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.0.weight"))
rename_keys.append((f"sem_seg_head.adapter_{source_index}.norm.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.1.weight"))
rename_keys.append((f"sem_seg_head.adapter_{source_index}.norm.bias", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.1.bias"))
rename_keys.append((f"sem_seg_head.layer_{source_index}.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.0.weight"))
rename_keys.append((f"sem_seg_head.layer_{source_index}.norm.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.1.weight"))
rename_keys.append((f"sem_seg_head.layer_{source_index}.norm.bias", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.1.bias"))
rename_keys.append(("sem_seg_head.mask_features.weight", "model.pixel_level_module.decoder.mask_projection.weight"))
rename_keys.append(("sem_seg_head.mask_features.bias", "model.pixel_level_module.decoder.mask_projection.bias"))
for idx in range(config.decoder_config.decoder_layers):
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.out_proj.weight", f"model.transformer_module.decoder.layers.{idx}.self_attn.out_proj.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.out_proj.bias", f"model.transformer_module.decoder.layers.{idx}.self_attn.out_proj.bias"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.out_proj.weight", f"model.transformer_module.decoder.layers.{idx}.encoder_attn.out_proj.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.out_proj.bias", f"model.transformer_module.decoder.layers.{idx}.encoder_attn.out_proj.bias"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear1.weight", f"model.transformer_module.decoder.layers.{idx}.fc1.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear1.bias", f"model.transformer_module.decoder.layers.{idx}.fc1.bias"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear2.weight", f"model.transformer_module.decoder.layers.{idx}.fc2.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear2.bias", f"model.transformer_module.decoder.layers.{idx}.fc2.bias"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm1.weight", f"model.transformer_module.decoder.layers.{idx}.self_attn_layer_norm.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm1.bias", f"model.transformer_module.decoder.layers.{idx}.self_attn_layer_norm.bias"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm2.weight", f"model.transformer_module.decoder.layers.{idx}.encoder_attn_layer_norm.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm2.bias", f"model.transformer_module.decoder.layers.{idx}.encoder_attn_layer_norm.bias"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm3.weight", f"model.transformer_module.decoder.layers.{idx}.final_layer_norm.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm3.bias", f"model.transformer_module.decoder.layers.{idx}.final_layer_norm.bias"))
rename_keys.append(("sem_seg_head.predictor.transformer.decoder.norm.weight", "model.transformer_module.decoder.layernorm.weight"))
rename_keys.append(("sem_seg_head.predictor.transformer.decoder.norm.bias", "model.transformer_module.decoder.layernorm.bias"))
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
def read_in_decoder_q_k_v(state_dict, config):
hidden_size = config.decoder_config.hidden_size
for idx in range(config.decoder_config.decoder_layers):
in_proj_weight = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.in_proj_bias")
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.q_proj.weight"] = in_proj_weight[: hidden_size, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.q_proj.bias"] = in_proj_bias[:config.hidden_size]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.k_proj.weight"] = in_proj_weight[hidden_size : hidden_size * 2, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.k_proj.bias"] = in_proj_bias[hidden_size : hidden_size * 2]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.v_proj.weight"] = in_proj_weight[-hidden_size :, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.v_proj.bias"] = in_proj_bias[-hidden_size :]
in_proj_weight = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.in_proj_bias")
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.q_proj.weight"] = in_proj_weight[: hidden_size, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.q_proj.bias"] = in_proj_bias[:config.hidden_size]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.k_proj.weight"] = in_proj_weight[hidden_size : hidden_size * 2, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.k_proj.bias"] = in_proj_bias[hidden_size : hidden_size * 2]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.v_proj.weight"] = in_proj_weight[-hidden_size :, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.v_proj.bias"] = in_proj_bias[-hidden_size :]
def prepare_img() -> torch.Tensor:
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_maskformer_checkpoint(
model_name: str, checkpoint_path: str, pytorch_dump_folder_path: str, push_to_hub: bool = False
):
"""
Copy/paste/tweak model's weights to our MaskFormer structure.
"""
config = get_maskformer_config(model_name)
with open(checkpoint_path, "rb") as f:
data = pickle.load(f)
state_dict = data["model"]
rename_keys = create_rename_keys(config)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
read_in_decoder_q_k_v(state_dict, config)
for key, value in state_dict.items():
state_dict[key] = torch.from_numpy(value)
model = MaskFormerForInstanceSegmentation(config)
model.eval()
model.load_state_dict(state_dict)
image = prepare_img()
if "vistas" in model_name:
ignore_index = 65
elif "cityscapes" in model_name:
ignore_index = 65535
else:
ignore_index = 255
reduce_labels = True if "ade" in model_name else False
image_processor = MaskFormerImageProcessor(ignore_index=ignore_index, reduce_labels=reduce_labels)
inputs = image_processor(image, return_tensors="pt")
outputs = model(**inputs)
if model_name == "maskformer-resnet50-ade":
expected_logits = torch.tensor(
[[6.7710, -0.1452, -3.5687], [1.9165, -1.0010, -1.8614], [3.6209, -0.2950, -1.3813]]
)
elif model_name == "maskformer-resnet101-ade":
expected_logits = torch.tensor(
[[4.0381, -1.1483, -1.9688], [2.7083, -1.9147, -2.2555], [3.4367, -1.3711, -2.1609]]
)
elif model_name == "maskformer-resnet50-coco-stuff":
expected_logits = torch.tensor(
[[3.2309, -3.0481, -2.8695], [5.4986, -5.4242, -2.4211], [6.2100, -5.2279, -2.7786]]
)
elif model_name == "maskformer-resnet101-coco-stuff":
expected_logits = torch.tensor(
[[4.7188, -3.2585, -2.8857], [6.6871, -2.9181, -1.2487], [7.2449, -2.2764, -2.1874]]
)
elif model_name == "maskformer-resnet101-cityscapes":
expected_logits = torch.tensor(
[[-1.8861, -1.5465, 0.6749], [-2.3677, -1.6707, -0.0867], [-2.2314, -1.9530, -0.9132]]
)
elif model_name == "maskformer-resnet50-vistas":
expected_logits = torch.tensor(
[[-6.3917, -1.5216, -1.1392], [-5.5335, -4.5318, -1.8339], [-4.3576, -4.0301, 0.2162]]
)
elif model_name == "maskformer-resnet50-ade20k-full":
expected_logits = torch.tensor(
[[3.6146, -1.9367, -3.2534], [4.0099, 0.2027, -2.7576], [3.3913, -2.3644, -3.9519]]
)
elif model_name == "maskformer-resnet101-ade20k-full":
expected_logits = torch.tensor(
[[3.2211, -1.6550, -2.7605], [2.8559, -2.4512, -2.9574], [2.6331, -2.6775, -2.1844]]
)
assert torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_logits, atol=1e-4)
print("Looks ok!")
if pytorch_dump_folder_path is not None:
print(f"Saving model and image processor of {model_name} to {pytorch_dump_folder_path}")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
image_processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
print(f"Pushing model and image processor of {model_name} to the hub...")
model.push_to_hub(f"facebook/{model_name}")
image_processor.push_to_hub(f"facebook/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
default="maskformer-resnet50-ade",
type=str,
required=True,
choices=[
"maskformer-resnet50-ade",
"maskformer-resnet101-ade",
"maskformer-resnet50-coco-stuff",
"maskformer-resnet101-coco-stuff",
"maskformer-resnet101-cityscapes",
"maskformer-resnet50-vistas",
"maskformer-resnet50-ade20k-full",
"maskformer-resnet101-ade20k-full",
],
help=("Name of the MaskFormer model you'd like to convert",),
)
parser.add_argument(
"--checkpoint_path",
type=str,
required=True,
help=("Path to the original pickle file (.pkl) of the original checkpoint.",),
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
)
args = parser.parse_args()
convert_maskformer_checkpoint(
args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub
)
.\models\maskformer\convert_maskformer_swin_to_pytorch.py
import argparse
import json
import pickle
from pathlib import Path
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation, MaskFormerImageProcessor, SwinConfig
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_maskformer_config(model_name: str):
backbone_config = SwinConfig.from_pretrained(
"microsoft/swin-tiny-patch4-window7-224", out_features=["stage1", "stage2", "stage3", "stage4"]
)
config = MaskFormerConfig(backbone_config=backbone_config)
repo_id = "huggingface/label-files"
if "ade20k-full" in model_name:
config.num_labels = 847
filename = "maskformer-ade20k-full-id2label.json"
elif "ade" in model_name:
config.num_labels = 150
filename = "ade20k-id2label.json"
elif "coco-stuff" in model_name:
config.num_labels = 171
filename = "maskformer-coco-stuff-id2label.json"
elif "coco" in model_name:
config.num_labels = 133
filename = "coco-panoptic-id2label.json"
elif "cityscapes" in model_name:
config.num_labels = 19
filename = "cityscapes-id2label.json"
elif "vistas" in model_name:
config.num_labels = 65
filename = "mapillary-vistas-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
return config
def create_rename_keys(config):
rename_keys = []
rename_keys.append(("backbone.patch_embed.proj.weight", "model.pixel_level_module.encoder.model.embeddings.patch_embeddings.projection.weight"))
rename_keys.append(("backbone.patch_embed.proj.bias", "model.pixel_level_module.encoder.model.embeddings.patch_embeddings.projection.bias"))
rename_keys.append(("backbone.patch_embed.norm.weight", "model.pixel_level_module.encoder.model.embeddings.norm.weight"))
rename_keys.append(("backbone.patch_embed.norm.bias", "model.pixel_level_module.encoder.model.embeddings.norm.bias"))
rename_keys.append(("sem_seg_head.layer_4.weight", "model.pixel_level_module.decoder.fpn.stem.0.weight"))
rename_keys.append(("sem_seg_head.layer_4.norm.weight", "model.pixel_level_module.decoder.fpn.stem.1.weight"))
rename_keys.append(("sem_seg_head.layer_4.norm.bias", "model.pixel_level_module.decoder.fpn.stem.1.bias"))
for source_index, target_index in zip(range(3, 0, -1), range(0, 3)):
rename_keys.append((f"sem_seg_head.adapter_{source_index}.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.0.weight"))
rename_keys.append((f"sem_seg_head.adapter_{source_index}.norm.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.1.weight"))
rename_keys.append((f"sem_seg_head.adapter_{source_index}.norm.bias", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.1.bias"))
rename_keys.append((f"sem_seg_head.layer_{source_index}.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.0.weight"))
rename_keys.append((f"sem_seg_head.layer_{source_index}.norm.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.1.weight"))
rename_keys.append((f"sem_seg_head.layer_{source_index}.norm.bias", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.1.bias"))
rename_keys.append(("sem_seg_head.mask_features.weight", "model.pixel_level_module.decoder.mask_projection.weight"))
rename_keys.append(("sem_seg_head.mask_features.bias", "model.pixel_level_module.decoder.mask_projection.bias"))
for idx in range(config.decoder_config.decoder_layers):
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.out_proj.weight", f"model.transformer_module.decoder.layers.{idx}.self_attn.out_proj.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.out_proj.bias", f"model.transformer_module.decoder.layers.{idx}.self_attn.out_proj.bias"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.out_proj.weight", f"model.transformer_module.decoder.layers.{idx}.encoder_attn.out_proj.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.out_proj.bias", f"model.transformer_module.decoder.layers.{idx}.encoder_attn.out_proj.bias"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear1.weight", f"model.transformer_module.decoder.layers.{idx}.fc1.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear1.bias", f"model.transformer_module.decoder.layers.{idx}.fc1.bias"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear2.weight", f"model.transformer_module.decoder.layers.{idx}.fc2.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear2.bias", f"model.transformer_module.decoder.layers.{idx}.fc2.bias"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm1.weight", f"model.transformer_module.decoder.layers.{idx}.self_attn_layer_norm.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm1.bias", f"model.transformer_module.decoder.layers.{idx}.self_attn_layer_norm.bias"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm2.weight", f"model.transformer_module.decoder.layers.{idx}.encoder_attn_layer_norm.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm2.bias", f"model.transformer_module.decoder.layers.{idx}.encoder_attn_layer_norm.bias"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm3.weight", f"model.transformer_module.decoder.layers.{idx}.final_layer_norm.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm3.bias", f"model.transformer_module.decoder.layers.{idx}.final_layer_norm.bias"))
rename_keys.append(("sem_seg_head.predictor.transformer.decoder.norm.weight", "model.transformer_module.decoder.layernorm.weight"))
rename_keys.append(("sem_seg_head.predictor.transformer.decoder.norm.bias", "model.transformer_module.decoder.layernorm.bias"))
rename_keys.append(("sem_seg_head.predictor.query_embed.weight", "model.transformer_module.queries_embedder.weight"))
rename_keys.append(("sem_seg_head.predictor.input_proj.weight", "model.transformer_module.input_projection.weight"))
rename_keys.append(("sem_seg_head.predictor.input_proj.bias", "model.transformer_module.input_projection.bias"))
rename_keys.append(("sem_seg_head.predictor.class_embed.weight", "class_predictor.weight"))
rename_keys.append(("sem_seg_head.predictor.class_embed.bias", "class_predictor.bias"))
for i in range(3):
rename_keys.append((f"sem_seg_head.predictor.mask_embed.layers.{i}.weight", f"mask_embedder.{i}.0.weight"))
rename_keys.append((f"sem_seg_head.predictor.mask_embed.layers.{i}.bias", f"mask_embedder.{i}.0.bias"))
return rename_keys
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
def read_in_swin_q_k_v(state_dict, backbone_config):
num_features = [int(backbone_config.embed_dim * 2**i) for i in range(len(backbone_config.depths))]
for i in range(len(backbone_config.depths)):
dim = num_features[i]
for j in range(backbone_config.depths[i]):
in_proj_weight = state_dict.pop(f"backbone.layers.{i}.blocks.{j}.attn.qkv.weight")
in_proj_bias = state_dict.pop(f"backbone.layers.{i}.blocks.{j}.attn.qkv.bias")
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.query.weight"] = in_proj_weight[:dim, :]
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.query.bias"] = in_proj_bias[: dim]
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.key.weight"] = in_proj_weight[
dim : dim * 2, :
]
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.key.bias"] = in_proj_bias[
dim : dim * 2
]
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.value.weight"] = in_proj_weight[
-dim :, :
]
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.value.bias"] = in_proj_bias[-dim :]
def read_in_decoder_q_k_v(state_dict, config):
hidden_size = config.decoder_config.hidden_size
for idx in range(config.decoder_config.decoder_layers):
in_proj_weight = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.in_proj_bias")
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.q_proj.weight"] = in_proj_weight[: hidden_size, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.q_proj.bias"] = in_proj_bias[:config.hidden_size]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.k_proj.weight"] = in_proj_weight[hidden_size : hidden_size * 2, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.k_proj.bias"] = in_proj_bias[hidden_size : hidden_size * 2]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.v_proj.weight"] = in_proj_weight[-hidden_size :, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.v_proj.bias"] = in_proj_bias[-hidden_size :]
in_proj_weight = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.in_proj_bias")
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.q_proj.weight"] = in_proj_weight[: hidden_size, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.q_proj.bias"] = in_proj_bias[:config.hidden_size]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.k_proj.weight"] = in_proj_weight[hidden_size : hidden_size * 2, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.k_proj.bias"] = in_proj_bias[hidden_size : hidden_size * 2]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.v_proj.weight"] = in_proj_weight[-hidden_size :, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.v_proj.bias"] = in_proj_bias[-hidden_size :]
def prepare_img() -> torch.Tensor:
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_maskformer_checkpoint(
model_name: str, checkpoint_path: str, pytorch_dump_folder_path: str, push_to_hub: bool = False
):
"""
Copy/paste/tweak model's weights to our MaskFormer structure.
"""
config = get_maskformer_config(model_name)
with open(checkpoint_path, "rb") as f:
data = pickle.load(f)
state_dict = data["model"]
rename_keys = create_rename_keys(config)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
read_in_swin_q_k_v(state_dict, config.backbone_config)
read_in_decoder_q_k_v(state_dict, config)
for key, value in state_dict.items():
state_dict[key] = torch.from_numpy(value)
model = MaskFormerForInstanceSegmentation(config)
model.eval()
for name, param in model.named_parameters():
print(name, param.shape)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
assert missing_keys == [
"model.pixel_level_module.encoder.model.layernorm.weight",
"model.pixel_level_module.encoder.model.layernorm.bias",
]
assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
image = prepare_img()
if "vistas" in model_name:
ignore_index = 65
elif "cityscapes" in model_name:
ignore_index = 65535
else:
ignore_index = 255
reduce_labels = True if "ade" in model_name else False
image_processor = MaskFormerImageProcessor(ignore_index=ignore_index, reduce_labels=reduce_labels)
inputs = image_processor(image, return_tensors="pt")
outputs = model(**inputs)
print("Logits:", outputs.class_queries_logits[0, :3, :3])
if model_name == "maskformer-swin-tiny-ade":
expected_logits = torch.tensor(
[[3.6353, -4.4770, -2.6065], [0.5081, -4.2394, -3.5343], [2.1909, -5.0353, -1.9323]]
)
assert torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_logits, atol=1e-4)
print("Looks ok!")
if pytorch_dump_folder_path is not None:
print(f"Saving model and image processor to {pytorch_dump_folder_path}")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
image_processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
print("Pushing model and image processor to the hub...")
model.push_to_hub(f"nielsr/{model_name}")
image_processor.push_to_hub(f"nielsr/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
default="maskformer-swin-tiny-ade",
type=str,
help=("Name of the MaskFormer model you'd like to convert",),
)
parser.add_argument(
"--checkpoint_path",
default="/Users/nielsrogge/Documents/MaskFormer_checkpoints/MaskFormer-Swin-tiny-ADE20k/model.pkl",
type=str,
help="Path to the original state dict (.pth file).",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Whether or not to push the converted model to the 🤗 hub."
)
args = parser.parse_args()
convert_maskformer_checkpoint(
args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub
)
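Once the conversion has run, the dumped folder (or the pushed hub repo) is a regular Transformers checkpoint. A short usage sketch, loading an already-converted hub checkpoint of the same family instead of a local dump:
from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor

# "facebook/maskformer-swin-tiny-ade" is an existing hub checkpoint produced by this kind
# of conversion; a local `pytorch_dump_folder_path` would be loaded the same way.
model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-tiny-ade")
image_processor = MaskFormerImageProcessor.from_pretrained("facebook/maskformer-swin-tiny-ade")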
.\models\maskformer\feature_extraction_maskformer.py
"""
Feature extractor class for MaskFormer.
"""
import warnings
from ...utils import logging
from .image_processing_maskformer import MaskFormerImageProcessor
logger = logging.get_logger(__name__)
class MaskFormerFeatureExtractor(MaskFormerImageProcessor):
def __init__(self, *args, **kwargs) -> None:
warnings.warn(
"The class MaskFormerFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
" Please use MaskFormerImageProcessor instead.",
FutureWarning,
)
super().__init__(*args, **kwargs)
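The subclass exists only to emit a deprecation warning; its behaviour is otherwise identical to `MaskFormerImageProcessor`. A small sketch showing the warning being raised:
import warnings
from transformers import MaskFormerFeatureExtractor, MaskFormerImageProcessor

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    feature_extractor = MaskFormerFeatureExtractor(do_resize=False)  # still works, but warns
assert any(issubclass(w.category, FutureWarning) for w in caught)

# Preferred going forward: same arguments, no warning.
image_processor = MaskFormerImageProcessor(do_resize=False)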
.\models\maskformer\image_processing_maskformer.py
"""MaskFormer 的图像处理器类。"""
import math
import warnings
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
PaddingMode,
get_resize_output_image_size,
pad,
rescale,
resize,
to_channel_dimension_format,
)
from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
validate_kwargs,
validate_preprocess_arguments,
)
from ...utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
TensorType,
is_torch_available,
is_torch_tensor,
logging,
)
logger = logging.get_logger(__name__)
if TYPE_CHECKING:
from transformers import MaskFormerForInstanceSegmentationOutput
if is_torch_available():
import torch
from torch import nn
def max_across_indices(values: Iterable[Any]) -> List[Any]:
"""
Return the maximum value across all indices of an iterable of values.
"""
return [max(values_i) for values_i in zip(*values)]
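A quick illustration of `max_across_indices` on a toy batch of image shapes (the import path is the module being read here):
from transformers.models.maskformer.image_processing_maskformer import max_across_indices

shapes = [(3, 480, 640), (3, 512, 384), (3, 600, 600)]
# Element-wise maximum across the tuples: channels, height and width are maximized independently.
assert max_across_indices(shapes) == [3, 600, 640]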
def get_max_height_width(
images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> List[int]:
"""
Get the maximum height and width across all images in a batch.
"""
if input_data_format is None:
input_data_format = infer_channel_dimension_format(images[0])
if input_data_format == ChannelDimension.FIRST:
_, max_height, max_width = max_across_indices([img.shape for img in images])
elif input_data_format == ChannelDimension.LAST:
max_height, max_width, _ = max_across_indices([img.shape for img in images])
else:
raise ValueError(f"Invalid channel dimension format: {input_data_format}")
return (max_height, max_width)
def make_pixel_mask(
    image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> np.ndarray:
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
Args:
image (`np.ndarray`):
Image to make the pixel mask for.
output_size (`Tuple[int, int]`):
Output size of the mask.
"""
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
mask = np.zeros(output_size, dtype=np.int64)
mask[:input_height, :input_width] = 1
return mask
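With the corrected signature above, a 2x3 channels-first image padded onto a 4x4 canvas yields a mask with ones over the original region only. A minimal check (import path assumed to be this module):
import numpy as np
from transformers.image_utils import ChannelDimension
from transformers.models.maskformer.image_processing_maskformer import make_pixel_mask

image = np.zeros((3, 2, 3))  # (num_channels, height, width)
mask = make_pixel_mask(image, output_size=(4, 4), input_data_format=ChannelDimension.FIRST)
expected = np.array([[1, 1, 1, 0],
                     [1, 1, 1, 0],
                     [0, 0, 0, 0],
                     [0, 0, 0, 0]], dtype=np.int64)
assert np.array_equal(mask, expected)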
def binary_mask_to_rle(mask):
"""
Converts given binary mask of shape `(height, width)` to the run-length encoding (RLE) format.
Args:
mask (`torch.Tensor` or `numpy.array`):
A binary mask tensor of shape `(height, width)` where 0 denotes background and 1 denotes the target
segment_id or class_id.
Returns:
`List`: Run-length encoded list of the binary mask. Refer to COCO API for more information about the RLE
format.
"""
if is_torch_tensor(mask):
mask = mask.numpy()
pixels = mask.flatten()
pixels = np.concatenate([[0], pixels, [0]])
runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
runs[1::2] -= runs[::2]
return list(runs)
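A worked example of the encoding produced above: the returned list alternates (1-based start position, run length) pairs for the foreground runs of the flattened mask.
import numpy as np
from transformers.models.maskformer.image_processing_maskformer import binary_mask_to_rle

mask = np.array([[0, 1, 1],
                 [0, 1, 0]])
# Flattened mask: 0 1 1 0 1 0 -> a run of ones starting at position 2 (length 2)
# and another starting at position 5 (length 1).
assert binary_mask_to_rle(mask) == [2, 2, 5, 1]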
def convert_segmentation_to_rle(segmentation):
"""
Converts given segmentation map of shape `(height, width)` to the run-length encoding (RLE) format.
Args:
segmentation (`torch.Tensor` or `numpy.array`):
A segmentation map of shape `(height, width)` where each value denotes a segment or class id.
Returns:
`List[List]`: A list of lists, where each list is the run-length encoding of a segment / class id.
"""
segment_ids = torch.unique(segmentation)
run_length_encodings = []
for idx in segment_ids:
mask = torch.where(segmentation == idx, 1, 0)
rle = binary_mask_to_rle(mask)
run_length_encodings.append(rle)
return run_length_encodings
def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
"""
Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and
`labels`.
Args:
masks (`torch.Tensor`):
A tensor of shape `(num_queries, height, width)`.
scores (`torch.Tensor`):
A tensor of shape `(num_queries)`.
labels (`torch.Tensor`):
A tensor of shape `(num_queries)`.
object_mask_threshold (`float`):
A number between 0 and 1 used to binarize the masks.
Raises:
`ValueError`: Raised when the first dimension doesn't match in all input tensors.
"""
if not (masks.shape[0] == scores.shape[0] == labels.shape[0]):
raise ValueError("mask, scores and labels must have the same shape!")
to_keep = labels.ne(num_labels) & (scores > object_mask_threshold)
return masks[to_keep], scores[to_keep], labels[to_keep]
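A small sketch of the filtering: a query survives only if its predicted label is not the "no object" class (encoded as `num_labels`) and its score clears `object_mask_threshold`.
import torch
from transformers.models.maskformer.image_processing_maskformer import remove_low_and_no_objects

masks = torch.rand(3, 4, 4)
scores = torch.tensor([0.9, 0.4, 0.8])
labels = torch.tensor([2, 1, 5])  # with num_labels=5, label 5 means "no object"

kept_masks, kept_scores, kept_labels = remove_low_and_no_objects(
    masks, scores, labels, object_mask_threshold=0.5, num_labels=5
)
# Query 1 fails the score threshold, query 2 is the "no object" class.
assert kept_labels.tolist() == [2] and kept_masks.shape == (1, 4, 4)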
def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8):
mask_k = mask_labels == k
mask_k_area = mask_k.sum()
original_area = (mask_probs[k] >= mask_threshold).sum()
mask_exists = mask_k_area > 0 and original_area > 0
if mask_exists:
area_ratio = mask_k_area / original_area
if not area_ratio.item() > overlap_mask_area_threshold:
mask_exists = False
return mask_exists, mask_k
def compute_segments(
mask_probs,
pred_scores,
pred_labels,
mask_threshold: float = 0.5,
overlap_mask_area_threshold: float = 0.8,
label_ids_to_fuse: Optional[Set[int]] = None,
target_size: Tuple[int, int] = None,
):
height = mask_probs.shape[1] if target_size is None else target_size[0]
width = mask_probs.shape[2] if target_size is None else target_size[1]
segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device)
segments: List[Dict] = []
if target_size is not None:
mask_probs = nn.functional.interpolate(
mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
)[0]
current_segment_id = 0
mask_probs *= pred_scores.view(-1, 1, 1)
mask_labels = mask_probs.argmax(0)
stuff_memory_list: Dict[str, int] = {}
for k in range(pred_labels.shape[0]):
pred_class = pred_labels[k].item()
should_fuse = pred_class in label_ids_to_fuse
mask_exists, mask_k = check_segment_validity(
mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold
)
if mask_exists:
if pred_class in stuff_memory_list:
current_segment_id = stuff_memory_list[pred_class]
else:
current_segment_id += 1
segmentation[mask_k] = current_segment_id
segment_score = round(pred_scores[k].item(), 6)
segments.append(
{
"id": current_segment_id,
"label_id": pred_class,
"was_fused": should_fuse,
"score": segment_score,
}
)
if should_fuse:
stuff_memory_list[pred_class] = current_segment_id
return segmentation, segments
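A toy walk-through of `compute_segments`: two queries over a 2x2 grid, each confidently claiming one column. With an empty `label_ids_to_fuse` they become two separate segments even though they share a class id.
import torch
from transformers.models.maskformer.image_processing_maskformer import compute_segments

mask_probs = torch.tensor([[[0.9, 0.1],
                            [0.9, 0.1]],
                           [[0.1, 0.9],
                            [0.1, 0.9]]])  # (num_queries=2, height=2, width=2)
pred_scores = torch.tensor([0.95, 0.90])
pred_labels = torch.tensor([7, 7])  # same class predicted twice

segmentation, segments = compute_segments(mask_probs, pred_scores, pred_labels, label_ids_to_fuse=set())
assert segmentation.tolist() == [[1, 2], [1, 2]]  # segment id 1 = left column, 2 = right column
assert [s["label_id"] for s in segments] == [7, 7]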
def convert_segmentation_map_to_binary_masks(
segmentation_map: "np.ndarray",
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
ignore_index: Optional[int] = None,
reduce_labels: bool = False,
):
raise ValueError("If `reduce_labels` is True, `ignore_index` must be provided.")
if reduce_labels:
segmentation_map = np.where(segmentation_map == 0, ignore_index, segmentation_map - 1)
all_labels = np.unique(segmentation_map)
if ignore_index is not None:
all_labels = all_labels[all_labels != ignore_index]
binary_masks = [(segmentation_map == i) for i in all_labels]
binary_masks = np.stack(binary_masks, axis=0)
if instance_id_to_semantic_id is not None:
labels = np.zeros(all_labels.shape[0])
for label in all_labels:
class_id = instance_id_to_semantic_id[label + 1 if reduce_labels else label]
labels[all_labels == label] = class_id - 1 if reduce_labels else class_id
else:
labels = all_labels
return binary_masks.astype(np.float32), labels.astype(np.int64)
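With the guard restored above, a quick example: a 2x2 map containing classes {0, 3} produces one binary mask per present class, with labels returned in sorted order.
import numpy as np
from transformers.models.maskformer.image_processing_maskformer import (
    convert_segmentation_map_to_binary_masks,
)

segmentation_map = np.array([[0, 3],
                             [3, 0]])
masks, labels = convert_segmentation_map_to_binary_masks(segmentation_map)
assert labels.tolist() == [0, 3]
assert masks.shape == (2, 2, 2) and masks.dtype == np.float32
assert masks[1].astype(bool).tolist() == [[False, True], [True, False]]  # mask for class 3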
def get_maskformer_resize_output_image_size(
image: np.ndarray,
size: Union[int, Tuple[int, int], List[int], Tuple[int]],
max_size: Optional[int] = None,
size_divisor: int = 0,
default_to_square: bool = True,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[int, int]:
"""
Computes the output image size given the desired size.
Args:
    image (`np.ndarray`):
        The input image.
    size (`int` or `Tuple[int, int]` or `List[int]` or `Tuple[int]`):
        The size of the output image.
    max_size (`int`, *optional*):
        The maximum size of the output image.
    size_divisor (`int`, *optional*, defaults to 0):
        If `size_divisor` is given, the output image size will be divisible by the number.
    default_to_square (`bool`, *optional*, defaults to `True`):
        Whether to default to a square output if no size is provided.
    input_data_format (`ChannelDimension` or `str`, *optional*):
        The channel dimension format of the input image. If unset, the format is inferred from the input.
Returns:
    `Tuple[int, int]`: The output size of the image.
"""
output_size = get_resize_output_image_size(
input_image=image,
size=size,
default_to_square=default_to_square,
max_size=max_size,
input_data_format=input_data_format,
)
if size_divisor > 0:
height, width = output_size
height = int(math.ceil(height / size_divisor) * size_divisor)
width = int(math.ceil(width / size_divisor) * size_divisor)
output_size = (height, width)
return output_size
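An illustrative call (shapes only; the exact longer-edge value depends on the rounding inside `get_resize_output_image_size`): a 480x640 channels-last image with `size=800`, `max_size=1333` and `size_divisor=32` ends up with its shorter side at 800 and both sides rounded up to a multiple of 32, roughly (800, 1088).
import numpy as np
from transformers.models.maskformer.image_processing_maskformer import (
    get_maskformer_resize_output_image_size,
)

image = np.zeros((480, 640, 3), dtype=np.uint8)  # (height, width, channels)
height, width = get_maskformer_resize_output_image_size(
    image, size=800, max_size=1333, size_divisor=32, default_to_square=False
)
assert height % 32 == 0 and width % 32 == 0
assert min(height, width) == 800  # the shorter edge is matched to `size`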
class MaskFormerImageProcessor(BaseImageProcessor):
r"""
Constructs a MaskFormer image processor. The image processor can be used to prepare image(s) and optional targets
for the model.
This image processor inherits from [`BaseImageProcessor`] which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the input to a certain `size`.
size (`int`, *optional*, defaults to 800):
Resize the input to the given size. Only has an effect if `do_resize` is set to `True`. If size is a
sequence like `(width, height)`, output size will be matched to this. If size is an int, smaller edge of
the image will be matched to this number. i.e, if `height > width`, then image will be rescaled to `(size *
height / width, size)`.
size_divisor (`int`, *optional*, defaults to 32):
Some backbones need images divisible by a certain number. If not passed, it defaults to the value used in
Swin Transformer.
resample (`int`, *optional*, defaults to `Resampling.BILINEAR`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the input to a certain `scale`.
rescale_factor (`float`, *optional*, defaults to `1/ 255`):
Rescale the input by the given factor. Only has an effect if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with mean and standard deviation.
image_mean (`int`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean.
image_std (`int`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the
ImageNet std.
ignore_index (`int`, *optional*):
Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
denoted with 0 (background) will be replaced with `ignore_index`.
do_reduce_labels (`bool`, *optional*, defaults to `False`):
Whether or not to decrement all label values of segmentation maps by 1. Usually used for datasets where 0
is used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k).
The background label will be replaced by `ignore_index`.
"""
model_input_names = ["pixel_values", "pixel_mask"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
size_divisor: int = 32,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: float = 1 / 255,
do_normalize: bool = True,
image_mean: Union[float, List[float]] = None,
image_std: Union[float, List[float]] = None,
ignore_index: Optional[int] = None,
do_reduce_labels: bool = False,
**kwargs,
):
if "size_divisibility" in kwargs:
warnings.warn(
"The `size_divisibility` argument is deprecated and will be removed in v4.27. Please use "
"`size_divisor` instead.",
FutureWarning,
)
size_divisor = kwargs.pop("size_divisibility")
if "max_size" in kwargs:
warnings.warn(
"The `max_size` argument is deprecated and will be removed in v4.27. Please use size['longest_edge']"
" instead.",
FutureWarning,
)
self._max_size = kwargs.pop("max_size")
else:
self._max_size = 1333
if "reduce_labels" in kwargs:
warnings.warn(
"The `reduce_labels` argument is deprecated and will be removed in v4.27. Please use "
"`do_reduce_labels` instead.",
FutureWarning,
)
do_reduce_labels = kwargs.pop("reduce_labels")
size = size if size is not None else {"shortest_edge": 800, "longest_edge": self._max_size}
size = get_size_dict(size, max_size=self._max_size, default_to_square=False)
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.size_divisor = size_divisor
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.ignore_index = ignore_index
self.do_reduce_labels = do_reduce_labels
self._valid_processor_keys = [
"images",
"segmentation_maps",
"instance_id_to_semantic_id",
"do_resize",
"size",
"size_divisor",
"resample",
"do_rescale",
"rescale_factor",
"do_normalize",
"image_mean",
"image_std",
"ignore_index",
"do_reduce_labels",
"return_tensors",
"data_format",
"input_data_format",
]
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
created using from_dict and kwargs e.g. `MaskFormerImageProcessor.from_pretrained(checkpoint, max_size=800)`
"""
image_processor_dict = image_processor_dict.copy()
if "max_size" in kwargs:
image_processor_dict["max_size"] = kwargs.pop("max_size")
if "size_divisibility" in kwargs:
image_processor_dict["size_divisibility"] = kwargs.pop("size_divisibility")
return super().from_dict(image_processor_dict, **kwargs)
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
size_divisor: int = 0,
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format=None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize the image to the given size. Size can be min_size (scalar) or `(height, width)` tuple. If size is an
int, smaller edge of the image will be matched to this number.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
The size of the output image.
size_divisor (`int`, *optional*, defaults to 0):
If `size_divisor` is given, the output image size will be divisible by the number.
resample (`PILImageResampling` resampling filter, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use when resizing the image.
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
if "max_size" in kwargs:
warnings.warn(
"The `max_size` parameter is deprecated and will be removed in v4.27. "
"Please specify in `size['longest_edge'] instead`.",
FutureWarning,
)
max_size = kwargs.pop("max_size")
else:
max_size = None
size = get_size_dict(size, max_size=max_size, default_to_square=False)
if "shortest_edge" in size and "longest_edge" in size:
size, max_size = size["shortest_edge"], size["longest_edge"]
elif "height" in size and "width" in size:
size = (size["height"], size["width"])
max_size = None
else:
raise ValueError(
"Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
f" {size.keys()}."
)
size = get_maskformer_resize_output_image_size(
image=image,
size=size,
max_size=max_size,
size_divisor=size_divisor,
default_to_square=False,
input_data_format=input_data_format,
)
image = resize(
image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
)
return image
def rescale(
self,
image: np.ndarray,
rescale_factor: float,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Rescale the image by the given factor. image = image * rescale_factor.
Args:
image (`np.ndarray`):
Image to rescale.
rescale_factor (`float`):
The value to use for rescaling.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. If unset, is inferred from the input image. Can be
one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
def convert_segmentation_map_to_binary_masks(
self,
segmentation_map: "np.ndarray",
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
ignore_index: Optional[int] = None,
reduce_labels: bool = False,
):
"""
Convert a segmentation map to binary masks.
Args:
segmentation_map (`np.ndarray`):
The input segmentation map.
instance_id_to_semantic_id (Optional[Dict[int, int]]):
Mapping from instance IDs to semantic IDs. If not provided, no mapping is applied.
ignore_index (Optional[int]):
Index to ignore in the segmentation map.
reduce_labels (bool):
Whether to reduce the number of labels in the output.
Returns:
Binary masks corresponding to the segmentation map.
"""
reduce_labels = reduce_labels if reduce_labels is not None else self.reduce_labels
ignore_index = ignore_index if ignore_index is not None else self.ignore_index
return convert_segmentation_map_to_binary_masks(
segmentation_map=segmentation_map,
instance_id_to_semantic_id=instance_id_to_semantic_id,
ignore_index=ignore_index,
reduce_labels=reduce_labels,
)
def __call__(self, images, segmentation_maps=None, **kwargs) -> BatchFeature:
"""
Callable interface for preprocessing images and segmentation maps.
Args:
images:
Images to preprocess.
segmentation_maps:
Segmentation maps associated with the images.
**kwargs:
Additional keyword arguments for preprocessing.
Returns:
Preprocessed batch of features.
"""
return self.preprocess(images, segmentation_maps=segmentation_maps, **kwargs)
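End-to-end, `__call__` routes through `preprocess`, so the image processor can be applied directly to PIL images. A short usage sketch (the resulting height and width depend on the default `size` and `size_divisor` settings):
import numpy as np
from PIL import Image
from transformers import MaskFormerImageProcessor

image_processor = MaskFormerImageProcessor()
image = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))

inputs = image_processor(image, return_tensors="pt")
print(inputs["pixel_values"].shape)  # (1, 3, H, W), with H and W divisible by size_divisor
print(inputs["pixel_mask"].shape)    # (1, H, W), ones over the valid (unpadded) region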
def _preprocess(
self,
image: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
size_divisor: int = None,
resample: PILImageResampling = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""
Internal preprocessing function for handling various transformations on images.
Args:
image (ImageInput):
Input image to preprocess.
do_resize (bool, optional):
Whether to resize the image.
size (Dict[str, int], optional):
Desired size for resizing (width, height).
size_divisor (int, optional):
Divisor for resizing the image dimensions.
resample (PILImageResampling, optional):
Resampling method for resizing.
do_rescale (bool, optional):
Whether to rescale the image.
rescale_factor (float, optional):
Scaling factor for image rescaling.
do_normalize (bool, optional):
Whether to normalize the image.
image_mean (Union[float, List[float]], optional):
Mean values for image normalization.
image_std (Union[float, List[float]], optional):
Standard deviation values for image normalization.
input_data_format (Union[str, ChannelDimension], optional):
Format of the input image data.
Returns:
Preprocessed image based on the specified transformations.
"""
if do_resize:
image = self.resize(
image, size=size, size_divisor=size_divisor, resample=resample, input_data_format=input_data_format
)
if do_rescale:
image = self.rescale(image, rescale_factor=rescale_factor, input_data_format=input_data_format)
if do_normalize:
image = self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format)
return image
def _preprocess_image(
self,
image: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
size_divisor: int = None,
resample: PILImageResampling = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""Preprocesses a single image."""
image = to_numpy_array(image)
if is_scaled_image(image) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
image = self._preprocess(
image=image,
do_resize=do_resize,
size=size,
size_divisor=size_divisor,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
input_data_format=input_data_format,
)
if data_format is not None:
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
return image
def _preprocess_mask(
self,
segmentation_map: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
size_divisor: int = 0,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""Preprocesses a single mask."""
segmentation_map = to_numpy_array(segmentation_map)
if segmentation_map.ndim == 2:
added_channel_dim = True
segmentation_map = segmentation_map[None, ...]
input_data_format = ChannelDimension.FIRST
else:
added_channel_dim = False
if input_data_format is None:
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
segmentation_map = self._preprocess(
image=segmentation_map,
do_resize=do_resize,
resample=PILImageResampling.NEAREST,
size=size,
size_divisor=size_divisor,
do_rescale=False,
do_normalize=False,
input_data_format=input_data_format,
)
if added_channel_dim:
segmentation_map = segmentation_map.squeeze(0)
return segmentation_map
def preprocess(
self,
images: ImageInput,
segmentation_maps: Optional[ImageInput] = None,
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
size_divisor: Optional[int] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
ignore_index: Optional[int] = None,
do_reduce_labels: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
):
def _pad_image(
self,
image: np.ndarray,
output_size: Tuple[int, int],
constant_values: Union[float, Iterable[float]] = 0,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Pad an image with zeros to the given size.
"""
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
output_height, output_width = output_size
pad_bottom = output_height - input_height
pad_right = output_width - input_width
padding = ((0, pad_bottom), (0, pad_right))
padded_image = pad(
image,
padding,
mode=PaddingMode.CONSTANT,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
return padded_image
def pad(
self,
images: List[np.ndarray],
constant_values: Union[float, Iterable[float]] = 0,
return_pixel_mask: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> BatchFeature:
"""
Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
in the batch and optionally returns their corresponding pixel mask.
Args:
images (`List[np.ndarray]`):
    Images to pad.
constant_values (`float` or `Iterable[float]`, *optional*):
The value to use for the padding if `mode` is `"constant"`.
return_pixel_mask (`bool`, *optional*, defaults to `True`):
Whether to return a pixel mask.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
pad_size = get_max_height_width(images, input_data_format=input_data_format)
padded_images = [
self._pad_image(
image,
pad_size,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
for image in images
]
data = {"pixel_values": padded_images}
if return_pixel_mask:
masks = [
make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
for image in images
]
data["pixel_mask"] = masks
return BatchFeature(data=data, tensor_type=return_tensors)
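A small sketch of `pad` on two channels-first arrays of different sizes: both are padded bottom/right to the largest height and width in the batch, and the pixel mask marks the valid region of each image.
import numpy as np
from transformers import MaskFormerImageProcessor
from transformers.image_utils import ChannelDimension

image_processor = MaskFormerImageProcessor()
images = [np.zeros((3, 480, 640)), np.zeros((3, 512, 512))]

batch = image_processor.pad(images, return_tensors="np", input_data_format=ChannelDimension.FIRST)
assert batch["pixel_values"].shape == (2, 3, 512, 640)
assert batch["pixel_mask"].shape == (2, 512, 640)
assert batch["pixel_mask"][1, :, 512:].sum() == 0  # padded columns of the second image are masked out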
def encode_inputs(
self,
pixel_values_list: List[ImageInput],
segmentation_maps: ImageInput = None,
instance_id_to_semantic_id: Optional[Union[List[Dict[int, int]], Dict[int, int]]] = None,
ignore_index: Optional[int] = None,
reduce_labels: bool = False,
return_tensors: Optional[Union[str, TensorType]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""
Encodes input data into a format suitable for model input, optionally handling segmentation maps and instance IDs.
Args:
pixel_values_list (`List[ImageInput]`):
List of images to encode.
segmentation_maps (`ImageInput`, *optional*):
Segmentation maps corresponding to images.
instance_id_to_semantic_id (`Optional[Union[List[Dict[int, int]], Dict[int, int]]]`, *optional*):
Mapping from instance IDs to semantic IDs.
ignore_index (`Optional[int]`, *optional*):
Index to ignore during encoding.
reduce_labels (`bool`, *optional*, defaults to `False`):
Whether to reduce the number of unique labels.
return_tensors (`Optional[Union[str, TensorType]]`, *optional*):
The type of tensors to return (e.g., `'tf'`, `'pt'`, `'np'`, `'jax'`).
input_data_format (`Optional[Union[str, ChannelDimension]]`, *optional*):
The channel dimension format of the input data.
Returns:
BatchFeature:
Encoded inputs wrapped in a `BatchFeature` object.
"""
def post_process_segmentation(
self, outputs: "MaskFormerForInstanceSegmentationOutput", target_size: Tuple[int, int] = None
) -> "torch.Tensor":
"""
Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into image segmentation predictions. Only
supports PyTorch.
Args:
outputs ([`MaskFormerForInstanceSegmentationOutput`]):
The outputs from [`MaskFormerForInstanceSegmentation`].
target_size (`Tuple[int, int]`, *optional*):
If set, the `masks_queries_logits` will be resized to `target_size`.
Returns:
`torch.Tensor`:
A tensor of shape (`batch_size, num_class_labels, height, width`).
"""
logger.warning(
"`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_instance_segmentation`",
FutureWarning,
)
class_queries_logits = outputs.class_queries_logits
masks_queries_logits = outputs.masks_queries_logits
if target_size is not None:
masks_queries_logits = torch.nn.functional.interpolate(
masks_queries_logits,
size=target_size,
mode="bilinear",
align_corners=False,
)
masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
masks_probs = masks_queries_logits.sigmoid()
segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
return segmentation
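The einsum at the end is a per-pixel mixture of class probabilities weighted by mask probabilities. A shape-level sketch:
import torch

batch_size, num_queries, num_classes, height, width = 1, 5, 3, 4, 4
masks_classes = torch.rand(batch_size, num_queries, num_classes)  # (b, q, c)
masks_probs = torch.rand(batch_size, num_queries, height, width)  # (b, q, h, w)

# Sum over the query dimension: each class map is a mask-probability-weighted sum of class scores.
segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
assert segmentation.shape == (batch_size, num_classes, height, width)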
def post_process_instance_segmentation(
self,
outputs,
threshold: float = 0.5,
mask_threshold: float = 0.5,
overlap_mask_area_threshold: float = 0.8,
target_sizes: Optional[List[Tuple[int, int]]] = None,
return_coco_annotation: Optional[bool] = False,
return_binary_maps: Optional[bool] = False,
) -> "torch.Tensor":
"""
Post-processes outputs of an instance segmentation model, optionally converting them into semantic segmentation maps.
Args:
outputs ([MaskFormerForInstanceSegmentation]):
Raw outputs from the instance segmentation model.
threshold (float):
Threshold value for class probability to consider predictions.
mask_threshold (float):
Threshold value for mask probabilities to consider the mask prediction.
overlap_mask_area_threshold (float):
Threshold for overlapping mask areas.
target_sizes (List[Tuple[int, int]], optional):
List specifying the desired output sizes (height, width) for each prediction.
If `None`, predictions will not be resized.
return_coco_annotation (bool, optional):
Flag indicating whether to return COCO-style annotations.
return_binary_maps (bool, optional):
Flag indicating whether to return binary maps along with semantic segmentation.
Returns:
List[torch.Tensor]:
List of semantic segmentation maps, each of shape (height, width), corresponding to the target_sizes
entries if specified. Each entry contains semantic class IDs.
"""
class_queries_logits = outputs.class_queries_logits
masks_queries_logits = outputs.masks_queries_logits
masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
masks_probs = masks_queries_logits.sigmoid()
segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
batch_size = class_queries_logits.shape[0]
if target_sizes is not None:
if batch_size != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
semantic_segmentation = []
for idx in range(batch_size):
resized_logits = torch.nn.functional.interpolate(
segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
)
semantic_map = resized_logits[0].argmax(dim=0)
semantic_segmentation.append(semantic_map)
else:
semantic_segmentation = segmentation.argmax(dim=1)
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
return semantic_segmentation
def post_process_panoptic_segmentation(
self,
outputs,
threshold: float = 0.5,
mask_threshold: float = 0.5,
overlap_mask_area_threshold: float = 0.8,
label_ids_to_fuse: Optional[Set[int]] = None,
target_sizes: Optional[List[Tuple[int, int]]] = None,