Transformers 源码解析(七十七)
.\models\mobilenet_v2\modeling_mobilenet_v2.py
""" PyTorch MobileNetV2 model."""
from typing import Optional, Union
import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention,
SemanticSegmenterOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_mobilenet_v2 import MobileNetV2Config
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MobileNetV2Config"
_CHECKPOINT_FOR_DOC = "google/mobilenet_v2_1.0_224"
_EXPECTED_OUTPUT_SHAPE = [1, 1280, 7, 7]
_IMAGE_CLASS_CHECKPOINT = "google/mobilenet_v2_1.0_224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
MOBILENET_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [
"google/mobilenet_v2_1.4_224",
"google/mobilenet_v2_1.0_224",
"google/mobilenet_v2_0.37_160",
"google/mobilenet_v2_0.35_96",
]
def _build_tf_to_pytorch_map(model, config, tf_weights=None):
    """
    A map of modules from TF to PyTorch.

    Args:
        model: Either a bare MobileNetV2 backbone, or a `MobileNetV2ForImageClassification` /
            `MobileNetV2ForSemanticSegmentation` model whose backbone is stored in `model.mobilenet_v2`.
        config: Not read by this function.
        tf_weights: Optional container of TF variable names (the loader passes its dict of checkpoint arrays). It
            is only used for membership tests, to prefer the `/ExponentialMovingAverage` copy of a variable when
            the checkpoint has one. `None` is treated as "no variable names known".

    Returns:
        A dict mapping TF variable names to the PyTorch parameter or buffer that should receive their value.
    """
    tf_to_pt_map = {}

    # The signature allows `tf_weights=None`; an empty tuple keeps the membership test in `ema` valid in that case.
    known_tf_names = () if tf_weights is None else tf_weights

    if isinstance(model, (MobileNetV2ForImageClassification, MobileNetV2ForSemanticSegmentation)):
        backbone = model.mobilenet_v2
    else:
        backbone = model

    def ema(x):
        # Use the ExponentialMovingAverage version of the variable if the checkpoint has one.
        ema_name = x + "/ExponentialMovingAverage"
        return ema_name if ema_name in known_tf_names else x

    def map_conv_layer(prefix, conv_layer, weights_name="weights", use_ema=True):
        # Register the convolution weight and the four BatchNorm tensors of one conv layer.
        # The moving statistics are always mapped under their plain name.
        resolve = ema if use_ema else (lambda x: x)
        tf_to_pt_map[resolve(prefix + weights_name)] = conv_layer.convolution.weight
        tf_to_pt_map[resolve(prefix + "BatchNorm/beta")] = conv_layer.normalization.bias
        tf_to_pt_map[resolve(prefix + "BatchNorm/gamma")] = conv_layer.normalization.weight
        tf_to_pt_map[prefix + "BatchNorm/moving_mean"] = conv_layer.normalization.running_mean
        tf_to_pt_map[prefix + "BatchNorm/moving_variance"] = conv_layer.normalization.running_var

    # Stem: the first convolution plus the un-numbered "expanded_conv" block.
    stem = backbone.conv_stem
    map_conv_layer("MobilenetV2/Conv/", stem.first_conv)
    map_conv_layer("MobilenetV2/expanded_conv/depthwise/", stem.conv_3x3, weights_name="depthwise_weights")
    map_conv_layer("MobilenetV2/expanded_conv/project/", stem.reduce_1x1)

    # Inverted-residual layers: TF numbering starts at 1 because the stem holds the un-numbered block.
    for tf_index, block in enumerate(backbone.layer, start=1):
        block_prefix = f"MobilenetV2/expanded_conv_{tf_index}/"
        map_conv_layer(block_prefix + "expand/", block.expand_1x1)
        map_conv_layer(block_prefix + "depthwise/", block.conv_3x3, weights_name="depthwise_weights")
        map_conv_layer(block_prefix + "project/", block.reduce_1x1)

    map_conv_layer("MobilenetV2/Conv_1/", backbone.conv_1x1)

    if isinstance(model, MobileNetV2ForImageClassification):
        prefix = "MobilenetV2/Logits/Conv2d_1c_1x1/"
        tf_to_pt_map[ema(prefix + "weights")] = model.classifier.weight
        tf_to_pt_map[ema(prefix + "biases")] = model.classifier.bias

    if isinstance(model, MobileNetV2ForSemanticSegmentation):
        head = model.segmentation_head

        # These three layers are looked up under their plain names only (no EMA substitution).
        map_conv_layer("image_pooling/", head.conv_pool, use_ema=False)
        map_conv_layer("aspp0/", head.conv_aspp, use_ema=False)
        map_conv_layer("concat_projection/", head.conv_projection, use_ema=False)

        prefix = "logits/semantic/"
        tf_to_pt_map[ema(prefix + "weights")] = head.classifier.convolution.weight
        tf_to_pt_map[ema(prefix + "biases")] = head.classifier.convolution.bias

    return tf_to_pt_map
def load_tf_weights_in_mobilenet_v2(model, config, tf_checkpoint_path):
    """
    Copy the variables of a TensorFlow checkpoint into `model` and return the model.

    Every variable of the checkpoint is read, matched against the TF-name -> PyTorch-tensor map, transposed to the
    PyTorch layout where needed and assigned to the tensor's `.data`. Raises `ValueError` when a mapped tensor and
    its checkpoint array end up with different shapes, and re-raises `ImportError` when NumPy or TensorFlow is
    not installed.
    """
    try:
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    # Load weights from TF model
    tf_weights = {}
    for name, shape in tf.train.list_variables(tf_checkpoint_path):
        logger.info(f"Loading TF weight {name} with shape {shape}")
        tf_weights[name] = tf.train.load_variable(tf_checkpoint_path, name)

    # Build TF to PyTorch weights loading map
    tf_to_pt_map = _build_tf_to_pytorch_map(model, config, tf_weights)

    # After a variable is copied, it and these companion variables are dropped from `tf_weights`.
    consumed_suffixes = ("", "/RMSProp", "/RMSProp_1", "/ExponentialMovingAverage", "/Momentum")

    for name, pointer in tf_to_pt_map.items():
        logger.info(f"Importing {name}")
        if name not in tf_weights:
            logger.info(f"{name} not in tf pre-trained weights, skipping")
            continue

        array = tf_weights[name]

        if "depthwise_weights" in name:
            logger.info("Transposing depthwise")
            array = np.transpose(array, (2, 3, 0, 1))
        elif "weights" in name:
            logger.info("Transposing")
            # A 2-D target gets the squeezed, transposed kernel; any other target gets the axes reordered.
            targets_matrix = len(pointer.shape) == 2
            array = array.squeeze().transpose() if targets_matrix else np.transpose(array, (3, 2, 0, 1))

        if pointer.shape != array.shape:
            raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")

        logger.info(f"Initialize PyTorch weight {name} {array.shape}")
        pointer.data = torch.from_numpy(array)

        for suffix in consumed_suffixes:
            tf_weights.pop(name + suffix, None)

    logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}")
    return model
def make_divisible(value: int, divisor: int = 8, min_value: Optional[int] = None) -> int:
    """
    Round `value` to a multiple of `divisor`, never going below `min_value` (which defaults to `divisor`). The
    logic follows the original TensorFlow repo:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    lower_bound = divisor if min_value is None else min_value

    # Round to the nearest multiple of `divisor` (halves round up), then clamp from below.
    nearest_multiple = int(value + divisor / 2) // divisor * divisor
    result = max(lower_bound, nearest_multiple)

    # Make sure that round down does not go down by more than 10%.
    if result < 0.9 * value:
        result = result + divisor
    return int(result)
def apply_depth_multiplier(config: MobileNetV2Config, channels: int) -> int:
    """
    Scale `channels` by `config.depth_multiplier` and round the result with `make_divisible`, using
    `config.depth_divisible_by` as the divisor and `config.min_depth` as the lower bound.
    """
    scaled_channels = int(round(channels * config.depth_multiplier))
    return make_divisible(scaled_channels, config.depth_divisible_by, config.min_depth)
def apply_tf_padding(features: torch.Tensor, conv_layer: nn.Conv2d) -> torch.Tensor:
    """
    Apply TensorFlow-style "SAME" padding to a convolution layer. See the notes at:
    https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2

    Args:
        features: Tensor whose last two dimensions are read as (height, width).
        conv_layer: Convolution whose `stride`, `kernel_size` and `dilation` determine the amount of padding.

    Returns:
        `features` zero-padded on its last two dimensions. When the total padding along an axis is odd, the extra
        element goes to the bottom / right side.
    """

    def total_padding(size: int, stride: int, kernel: int) -> int:
        # Total padding along one axis: `kernel - stride` when the size is a multiple of the stride,
        # otherwise `kernel - (size % stride)`; never negative.
        remainder = size % stride
        return max(kernel - (stride if remainder == 0 else remainder), 0)

    in_height = int(features.shape[-2])
    in_width = int(features.shape[-1])
    stride_height, stride_width = conv_layer.stride
    kernel_height, kernel_width = conv_layer.kernel_size
    dilation_height, dilation_width = conv_layer.dilation

    pad_along_height = total_padding(in_height, stride_height, kernel_height)
    pad_along_width = total_padding(in_width, stride_width, kernel_width)

    pad_left = pad_along_width // 2
    pad_right = pad_along_width - pad_left
    pad_top = pad_along_height // 2
    pad_bottom = pad_along_height - pad_top

    # `nn.functional.pad` takes the last dimension first: (left, right, top, bottom).
    # Each amount is scaled by the layer's dilation along that axis.
    padding = (
        pad_left * dilation_width,
        pad_right * dilation_width,
        pad_top * dilation_height,
        pad_bottom * dilation_height,
    )
    return nn.functional.pad(features, padding, "constant", 0.0)
class MobileNetV2ConvLayer(nn.Module):
    """
    A 2D convolution followed by optional batch normalization and an optional activation.

    Args:
        config: Model configuration; `tf_padding`, `layer_norm_eps` and `hidden_act` are read from it.
        in_channels: Number of input channels; must be divisible by `groups`.
        out_channels: Number of output channels; must be divisible by `groups`.
        kernel_size: Size of the convolution kernel.
        stride: Stride of the convolution.
        groups: Number of channel groups of the convolution.
        bias: Whether the convolution has a bias term.
        dilation: Dilation of the convolution.
        use_normalization: Whether to add a `BatchNorm2d` after the convolution.
        use_activation: `False` for no activation, a string to pick that entry of `ACT2FN`, or `True` to use
            `config.hidden_act`.
        layer_norm_eps: Epsilon of the batch normalization; `config.layer_norm_eps` is used when `None`.

    Raises:
        ValueError: If `in_channels` or `out_channels` is not divisible by `groups`.
    """

    def __init__(
        self,
        config: MobileNetV2Config,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        groups: int = 1,
        bias: bool = False,
        dilation: int = 1,
        use_normalization: bool = True,
        use_activation: Union[bool, str] = True,
        layer_norm_eps: Optional[float] = None,
    ) -> None:
        super().__init__()
        self.config = config

        if in_channels % groups != 0:
            raise ValueError(f"Input channels ({in_channels}) are not divisible by {groups} groups.")
        if out_channels % groups != 0:
            raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.")

        # With TF-style padding the input is padded explicitly in `forward`, so the convolution itself pads nothing.
        padding = 0 if config.tf_padding else int((kernel_size - 1) / 2) * dilation

        self.convolution = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode="zeros",
        )

        if use_normalization:
            # NOTE(review): PyTorch's `momentum` is the weight given to the newest batch statistic, so 0.997 makes
            # the running statistics follow the latest batch almost entirely. Left as is; confirm it is intended.
            self.normalization = nn.BatchNorm2d(
                num_features=out_channels,
                eps=config.layer_norm_eps if layer_norm_eps is None else layer_norm_eps,
                momentum=0.997,
                affine=True,
                track_running_stats=True,
            )
        else:
            self.normalization = None

        if use_activation:
            if isinstance(use_activation, str):
                self.activation = ACT2FN[use_activation]
            elif isinstance(config.hidden_act, str):
                self.activation = ACT2FN[config.hidden_act]
            else:
                # `config.hidden_act` is not a string here and is stored as the activation directly.
                self.activation = config.hidden_act
        else:
            self.activation = None

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Apply (optional TF-style padding,) convolution, normalization and activation, in that order."""
        if self.config.tf_padding:
            features = apply_tf_padding(features, self.convolution)
        features = self.convolution(features)
        if self.normalization is not None:
            features = self.normalization(features)
        if self.activation is not None:
            features = self.activation(features)
        return features
class MobileNetV2InvertedResidual(nn.Module):
    """
    Inverted residual block: a 1x1 expansion, a 3x3 grouped convolution with one group per expanded channel, and a
    1x1 projection without activation. The input is added to the result when the block uses stride 1 and keeps
    the channel count.
    """

    def __init__(
        self, config: MobileNetV2Config, in_channels: int, out_channels: int, stride: int, dilation: int = 1
    ) -> None:
        super().__init__()

        # Width of the expanded representation, rounded with `make_divisible`.
        hidden_channels = make_divisible(
            int(round(in_channels * config.expand_ratio)), config.depth_divisible_by, config.min_depth
        )

        if stride != 1 and stride != 2:
            raise ValueError(f"Invalid stride {stride}.")

        # A skip connection is only possible when input and output have the same shape.
        same_resolution = stride == 1
        same_width = in_channels == out_channels
        self.use_residual = same_resolution and same_width

        self.expand_1x1 = MobileNetV2ConvLayer(
            config, in_channels=in_channels, out_channels=hidden_channels, kernel_size=1
        )

        self.conv_3x3 = MobileNetV2ConvLayer(
            config,
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            kernel_size=3,
            stride=stride,
            groups=hidden_channels,
            dilation=dilation,
        )

        self.reduce_1x1 = MobileNetV2ConvLayer(
            config,
            in_channels=hidden_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_activation=False,
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Run expand -> 3x3 -> reduce and add the input back when `use_residual` is set."""
        transformed = self.reduce_1x1(self.conv_3x3(self.expand_1x1(features)))
        if self.use_residual:
            return features + transformed
        return transformed
class MobileNetV2Stem(nn.Module):
    """
    First part of the network: a strided 3x3 convolution, an optional 1x1 convolution (present only when
    `config.first_layer_is_expansion` is false), a 3x3 grouped convolution with one group per channel, and a 1x1
    projection without activation.
    """

    def __init__(self, config: MobileNetV2Config, in_channels: int, expanded_channels: int, out_channels: int) -> None:
        super().__init__()

        # The very first layer is a regular 3x3 convolution with stride 2 that expands to `expanded_channels`.
        self.first_conv = MobileNetV2ConvLayer(
            config,
            in_channels=in_channels,
            out_channels=expanded_channels,
            kernel_size=3,
            stride=2,
        )

        self.expand_1x1 = (
            None
            if config.first_layer_is_expansion
            else MobileNetV2ConvLayer(
                config, in_channels=expanded_channels, out_channels=expanded_channels, kernel_size=1
            )
        )

        self.conv_3x3 = MobileNetV2ConvLayer(
            config,
            in_channels=expanded_channels,
            out_channels=expanded_channels,
            kernel_size=3,
            stride=1,
            groups=expanded_channels,
        )

        self.reduce_1x1 = MobileNetV2ConvLayer(
            config,
            in_channels=expanded_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_activation=False,
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Apply the stem layers in order, skipping `expand_1x1` when it is `None`."""
        hidden = self.first_conv(features)
        if self.expand_1x1 is not None:
            hidden = self.expand_1x1(hidden)
        return self.reduce_1x1(self.conv_3x3(hidden))
class MobileNetV2PreTrainedModel(PreTrainedModel):
    """
    Abstract base class of the MobileNetV2 models. It sets the class attributes used by `PreTrainedModel` and
    provides the weight initialization.
    """

    config_class = MobileNetV2Config
    load_tf_weights = load_tf_weights_in_mobilenet_v2
    base_model_prefix = "mobilenet_v2"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = False

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d]) -> None:
        """Initialize the weights"""
        if isinstance(module, nn.BatchNorm2d):
            # Batch norm starts as the identity transform: zero shift, unit scale.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            return

        # Linear and convolution weights are drawn from N(0, initializer_range); biases start at zero.
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
MOBILENET_V2_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`MobileNetV2Config`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
MOBILENET_V2_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
[`MobileNetV2ImageProcessor.__call__`] for details.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    "The bare MobileNetV2 model outputting raw hidden-states without any specific head on top.",
    MOBILENET_V2_START_DOCSTRING,
)
class MobileNetV2Model(MobileNetV2PreTrainedModel):
    def __init__(self, config: MobileNetV2Config, add_pooling_layer: bool = True):
        """
        Build the stem, the 16 inverted-residual layers, the final 1x1 convolution and, when `add_pooling_layer`
        is true, a global average pooling layer.
        """
        super().__init__(config)
        self.config = config

        # Output channels of the stem (first entry) and of each inverted-residual layer, before depth scaling.
        channels = [16, 24, 24, 32, 32, 32, 64, 64, 64, 64, 96, 96, 96, 160, 160, 160, 320]
        channels = [apply_depth_multiplier(config, x) for x in channels]

        # Nominal stride of each inverted-residual layer.
        strides = [2, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1]

        self.conv_stem = MobileNetV2Stem(
            config,
            in_channels=config.num_channels,
            expanded_channels=apply_depth_multiplier(config, 32),
            out_channels=channels[0],
        )

        current_stride = 2  # the first convolution of the stem already has stride 2
        dilation = 1

        self.layer = nn.ModuleList()
        for in_channels, out_channels, stride in zip(channels[:-1], channels[1:], strides):
            if current_stride == config.output_stride:
                # Requested output stride reached: stop downsampling and dilate instead. The nominal stride
                # only takes effect as a larger dilation for the layers that follow.
                layer_stride = 1
                layer_dilation = dilation
                dilation *= stride
            else:
                layer_stride = stride
                layer_dilation = 1
                current_stride *= layer_stride

            self.layer.append(
                MobileNetV2InvertedResidual(
                    config,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    stride=layer_stride,
                    dilation=layer_dilation,
                )
            )

        # With `finegrained_output`, models narrower than 1.0 keep the full 1280 output channels.
        if config.finegrained_output and config.depth_multiplier < 1.0:
            output_channels = 1280
        else:
            output_channels = apply_depth_multiplier(config, 1280)

        self.conv_1x1 = MobileNetV2ConvLayer(
            config,
            in_channels=channels[-1],
            out_channels=output_channels,
            kernel_size=1,
        )

        self.pooler = nn.AdaptiveAvgPool2d((1, 1)) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError

    @add_start_docstrings_to_model_forward(MOBILENET_V2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndNoAttention,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Run the stem, the inverted-residual layers and the final 1x1 convolution on `pixel_values`.

        With `return_dict=False` the result is a plain tuple holding the non-`None` values among
        `(last_hidden_state, pooler_output, hidden_states)`; otherwise a
        [`BaseModelOutputWithPoolingAndNoAttention`] is returned. `hidden_states` contains the output of every
        inverted-residual layer (the stem output is not included). A `ValueError` is raised when `pixel_values`
        is `None`.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.conv_stem(pixel_values)

        all_hidden_states = () if output_hidden_states else None

        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states)

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        last_hidden_state = self.conv_1x1(hidden_states)

        if self.pooler is not None:
            # Global average pooling, flattened to (batch_size, channels).
            pooled_output = torch.flatten(self.pooler(last_hidden_state), start_dim=1)
        else:
            pooled_output = None

        if not return_dict:
            return tuple(v for v in [last_hidden_state, pooled_output, all_hidden_states] if v is not None)

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=all_hidden_states,
        )
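# NOTE: the class statement, `__init__` and `forward` signature that belong here (presumably those of
# `MobileNetV2ForImageClassification`) are missing from this excerpt; the lines below are the rest of `forward`.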
) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.mobilenet_v2(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
pooled_output = outputs.pooler_output if return_dict else outputs[1]
logits = self.classifier(self.dropout(pooled_output))
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return ImageClassifierOutputWithNoAttention(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
)
class MobileNetV2DeepLabV3Plus(nn.Module):
    """
    The neural network from the paper "Encoder-Decoder with Atrous Separable Convolution for Semantic Image
    Segmentation" https://arxiv.org/abs/1802.02611
    """

    def __init__(self, config: MobileNetV2Config) -> None:
        super().__init__()

        # Settings shared by the three 1x1 conv + batch norm + ReLU layers of the head.
        branch_kwargs = {
            "out_channels": 256,
            "kernel_size": 1,
            "stride": 1,
            "use_normalization": True,
            "use_activation": "relu",
            "layer_norm_eps": 1e-5,
        }

        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)

        # Branch applied to the globally pooled features.
        self.conv_pool = MobileNetV2ConvLayer(
            config, in_channels=apply_depth_multiplier(config, 320), **branch_kwargs
        )

        # Branch applied to the full-resolution features.
        self.conv_aspp = MobileNetV2ConvLayer(
            config, in_channels=apply_depth_multiplier(config, 320), **branch_kwargs
        )

        # Fuses the two concatenated 256-channel branches back to 256 channels.
        self.conv_projection = MobileNetV2ConvLayer(config, in_channels=512, **branch_kwargs)

        self.dropout = nn.Dropout2d(config.classifier_dropout_prob)

        self.classifier = MobileNetV2ConvLayer(
            config,
            in_channels=256,
            out_channels=config.num_labels,
            kernel_size=1,
            use_normalization=False,
            use_activation=False,
            bias=True,
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Return per-position class logits at the spatial resolution of `features`."""
        height_width = features.shape[-2:]

        # Pooled branch: average to 1x1, convolve, then resize back to the input resolution.
        pooled_branch = self.conv_pool(self.avg_pool(features))
        pooled_branch = nn.functional.interpolate(
            pooled_branch, size=height_width, mode="bilinear", align_corners=True
        )

        aspp_branch = self.conv_aspp(features)

        fused = torch.cat([pooled_branch, aspp_branch], dim=1)
        fused = self.conv_projection(fused)
        fused = self.dropout(fused)
        return self.classifier(fused)
@add_start_docstrings(
    """
    MobileNetV2 model with a semantic segmentation head on top, e.g. for Pascal VOC.
    """,
    MOBILENET_V2_START_DOCSTRING,
)
class MobileNetV2ForSemanticSegmentation(MobileNetV2PreTrainedModel):
    def __init__(self, config: MobileNetV2Config) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        # The backbone is built without its pooling layer; the head works on spatial feature maps.
        self.mobilenet_v2 = MobileNetV2Model(config, add_pooling_layer=False)
        self.segmentation_head = MobileNetV2DeepLabV3Plus(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MOBILENET_V2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SemanticSegmenterOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. When given, the logits are resized to
            the last two dimensions of `labels` and a cross-entropy loss is computed, ignoring the index
            `config.semantic_loss_ignore_index`. A `ValueError` is raised if `config.num_labels == 1`.

        Returns:
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.mobilenet_v2(
            pixel_values,
            output_hidden_states=True,  # the head consumes an intermediate feature map, so always request them
            return_dict=return_dict,
        )

        # The backbone has no pooler here, so in tuple form the hidden states sit at index 1.
        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]

        logits = self.segmentation_head(encoder_hidden_states[-1])

        loss = None
        if labels is not None:
            if self.config.num_labels == 1:
                # The message below says: "the number of labels should be greater than 1".
                raise ValueError("标签数量应大于1")
            else:
                # upsample logits to the images' original size
                upsampled_logits = nn.functional.interpolate(
                    logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
                )
                loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
                loss = loss_fct(upsampled_logits, labels)

        if not return_dict:
            # Hidden states are always computed above but only returned when the caller asked for them.
            if output_hidden_states:
                output = (logits,) + outputs[1:]
            else:
                output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=None,
        )
.\models\mobilenet_v2\__init__.py
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available


# Maps each submodule name to the public names it provides; handed to `_LazyModule` at the bottom of the file.
# The configuration entries are registered unconditionally.
_import_structure = {
    "configuration_mobilenet_v2": [
        "MOBILENET_V2_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "MobileNetV2Config",
        "MobileNetV2OnnxConfig",
    ],
}

# The feature extractor and image processor are only registered when `is_vision_available()` is true.
try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["feature_extraction_mobilenet_v2"] = ["MobileNetV2FeatureExtractor"]
    _import_structure["image_processing_mobilenet_v2"] = ["MobileNetV2ImageProcessor"]


# The modeling classes are only registered when `is_torch_available()` is true.
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_mobilenet_v2"] = [
        "MOBILENET_V2_PRETRAINED_MODEL_ARCHIVE_LIST",
        "MobileNetV2ForImageClassification",
        "MobileNetV2ForSemanticSegmentation",
        "MobileNetV2Model",
        "MobileNetV2PreTrainedModel",
        "load_tf_weights_in_mobilenet_v2",
    ]


if TYPE_CHECKING:
    # Under static type checking the names are imported directly, mirroring the availability checks above.
    from .configuration_mobilenet_v2 import (
        MOBILENET_V2_PRETRAINED_CONFIG_ARCHIVE_MAP,
        MobileNetV2Config,
        MobileNetV2OnnxConfig,
    )

    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .feature_extraction_mobilenet_v2 import MobileNetV2FeatureExtractor
        from .image_processing_mobilenet_v2 import MobileNetV2ImageProcessor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_mobilenet_v2 import (
            MOBILENET_V2_PRETRAINED_MODEL_ARCHIVE_LIST,
            MobileNetV2ForImageClassification,
            MobileNetV2ForSemanticSegmentation,
            MobileNetV2Model,
            MobileNetV2PreTrainedModel,
            load_tf_weights_in_mobilenet_v2,
        )


else:
    import sys

    # At runtime the module object is replaced by a `_LazyModule` built from `_import_structure`.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\mobilevit\configuration_mobilevit.py
from collections import OrderedDict
from typing import Mapping
from packaging import version
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
logger = logging.get_logger(__name__)
MOBILEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"apple/mobilevit-small": "https://huggingface.co/apple/mobilevit-small/resolve/main/config.json",
"apple/mobilevit-x-small": "https://huggingface.co/apple/mobilevit-x-small/resolve/main/config.json",
"apple/mobilevit-xx-small": "https://huggingface.co/apple/mobilevit-xx-small/resolve/main/config.json",
"apple/deeplabv3-mobilevit-small": (
"https://huggingface.co/apple/deeplabv3-mobilevit-small/resolve/main/config.json"
),
"apple/deeplabv3-mobilevit-x-small": (
"https://huggingface.co/apple/deeplabv3-mobilevit-x-small/resolve/main/config.json"
),
"apple/deeplabv3-mobilevit-xx-small": (
"https://huggingface.co/apple/deeplabv3-mobilevit-xx-small/resolve/main/config.json"
),
}
class MobileViTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MobileViTModel`]. It is used to instantiate a
    MobileViT model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the MobileViT
    [apple/mobilevit-small](https://huggingface.co/apple/mobilevit-small) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Every constructor argument is stored as an attribute of the same name; additional keyword arguments are
    forwarded to [`PretrainedConfig`]. The list-valued arguments default to `None`, which selects:

    - `hidden_sizes`: `[144, 192, 240]`
    - `neck_hidden_sizes`: `[16, 32, 64, 96, 128, 160, 640]`
    - `atrous_rates`: `[6, 12, 18]`

    Example:

    ```python
    >>> from transformers import MobileViTConfig, MobileViTModel

    >>> # Initializing a mobilevit-small style configuration
    >>> configuration = MobileViTConfig()

    >>> # Initializing a model from the mobilevit-small style configuration
    >>> model = MobileViTModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "mobilevit"

    def __init__(
        self,
        num_channels=3,
        image_size=256,
        patch_size=2,
        hidden_sizes=None,
        neck_hidden_sizes=None,
        num_attention_heads=4,
        mlp_ratio=2.0,
        expand_ratio=4.0,
        hidden_act="silu",
        conv_kernel_size=3,
        output_stride=32,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.0,
        classifier_dropout_prob=0.1,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        qkv_bias=True,
        aspp_out_channels=256,
        atrous_rates=None,
        aspp_dropout_prob=0.1,
        semantic_loss_ignore_index=255,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # List defaults are created per instance. A list literal in the signature would be one shared object,
        # so mutating it on one default-constructed config would change every later default-constructed config.
        if hidden_sizes is None:
            hidden_sizes = [144, 192, 240]
        if neck_hidden_sizes is None:
            neck_hidden_sizes = [16, 32, 64, 96, 128, 160, 640]
        if atrous_rates is None:
            atrous_rates = [6, 12, 18]

        self.num_channels = num_channels
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_sizes = hidden_sizes
        self.neck_hidden_sizes = neck_hidden_sizes
        self.num_attention_heads = num_attention_heads
        self.mlp_ratio = mlp_ratio
        self.expand_ratio = expand_ratio
        self.hidden_act = hidden_act
        self.conv_kernel_size = conv_kernel_size
        self.output_stride = output_stride
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.classifier_dropout_prob = classifier_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.qkv_bias = qkv_bias

        # decode head attributes for semantic segmentation
        self.aspp_out_channels = aspp_out_channels
        self.atrous_rates = atrous_rates
        self.aspp_dropout_prob = aspp_dropout_prob
        self.semantic_loss_ignore_index = semantic_loss_ignore_index
class MobileViTOnnxConfig(OnnxConfig):
    """ONNX export description for MobileViT: dynamic axes of inputs/outputs and validation tolerance."""

    torch_onnx_minimum_version = version.parse("1.11")

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Return the single `pixel_values` input together with the names of its four axes."""
        pixel_axes = {0: "batch", 1: "num_channels", 2: "height", 3: "width"}
        return OrderedDict([("pixel_values", pixel_axes)])

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        """Return the output names for the current task; only the batch axis is named for each."""
        if self.task == "image-classification":
            output_names = ("logits",)
        else:
            output_names = ("last_hidden_state", "pooler_output")
        return OrderedDict((output_name, {0: "batch"}) for output_name in output_names)

    @property
    def atol_for_validation(self) -> float:
        """Absolute tolerance used when validating the exported model."""
        return 1e-4
.\models\mobilevit\convert_mlcvnets_to_pytorch.py
import argparse
import json
from pathlib import Path
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import (
MobileViTConfig,
MobileViTForImageClassification,
MobileViTForSemanticSegmentation,
MobileViTImageProcessor,
)
from transformers.utils import logging
# Switch the library logging to the "info" verbosity for the duration of the conversion script.
logging.set_verbosity_info()
# Module-level logger for this conversion script.
logger = logging.get_logger(__name__)
def get_mobilevit_config(mobilevit_name):
    """
    Build the `MobileViTConfig` matching an original checkpoint name.

    Args:
        mobilevit_name (`str`):
            Name of the original model. The substrings `mobilevit_s`, `mobilevit_xs` and `mobilevit_xxs` select the
            layer sizes; a `deeplabv3_` prefix selects the semantic-segmentation settings.

    Returns:
        `MobileViTConfig`: configuration with sizes, number of labels and the `id2label` / `label2id` mappings set.
    """
    config = MobileViTConfig()

    # size of the architecture
    if "mobilevit_s" in mobilevit_name:
        config.hidden_sizes = [144, 192, 240]
        config.neck_hidden_sizes = [16, 32, 64, 96, 128, 160, 640]
    elif "mobilevit_xs" in mobilevit_name:
        config.hidden_sizes = [96, 120, 144]
        config.neck_hidden_sizes = [16, 32, 48, 64, 80, 96, 384]
    elif "mobilevit_xxs" in mobilevit_name:
        config.hidden_sizes = [64, 80, 96]
        config.neck_hidden_sizes = [16, 16, 24, 48, 64, 80, 320]
        config.hidden_dropout_prob = 0.05
        config.expand_ratio = 2.0

    # task-specific settings and the label file to fetch
    if mobilevit_name.startswith("deeplabv3_"):
        config.image_size = 512
        config.output_stride = 16
        config.num_labels = 21
        filename = "pascal-voc-id2label.json"
    else:
        config.num_labels = 1000
        filename = "imagenet-1k-id2label.json"

    repo_id = "huggingface/label-files"
    label_file = hf_hub_download(repo_id, filename, repo_type="dataset")
    # Use a context manager so the file handle is closed deterministically.
    with open(label_file, "r", encoding="utf-8") as handle:
        id2label = json.load(handle)
    # JSON object keys are strings; the config expects integer class ids.
    id2label = {int(k): v for k, v in id2label.items()}
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}

    return config
def rename_key(name, base_model=False):
    """
    Translate one parameter name of the original checkpoint into the name used by the converted model.

    The substitutions are applied strictly in sequence, so later rules see the output of earlier ones.
    Unless `base_model` is set, names that belong neither to the classifier nor to the segmentation head
    are finally prefixed with `mobilevit.`.
    """
    # "layer_N." (1-based) becomes "encoder.layer.{N-1}."
    for stage in range(1, 6):
        name = name.replace(f"layer_{stage}.", f"encoder.layer.{stage - 1}.")

    # Plain substring rewrites; `str.replace` is a no-op when the pattern is absent.
    early_rules = (
        ("conv_1.", "conv_stem."),
        (".block.", "."),
        ("exp_1x1", "expand_1x1"),
        ("red_1x1", "reduce_1x1"),
        (".local_rep.conv_3x3.", ".conv_kxk."),
        (".local_rep.conv_1x1.", ".conv_1x1."),
        (".norm.", ".normalization."),
        (".conv.", ".convolution."),
        (".conv_proj.", ".conv_projection."),
    )
    for old, new in early_rules:
        name = name.replace(old, new)

    # First two encoder stages: sub-blocks live in a `layer` list.
    for outer in range(0, 2):
        for inner in range(0, 4):
            name = name.replace(f".{outer}.{inner}.", f".{outer}.layer.{inner}.")

    # Remaining stages: the inner index is dropped, and the inverted-residual convolutions found
    # under such an index move into `downsampling_layer`.
    for outer in range(2, 6):
        for inner in range(0, 4):
            marker = f".{outer}.{inner}."
            if marker not in name:
                continue
            name = name.replace(marker, f".{outer}.")
            for conv_name in ("expand_1x1", "conv_3x3", "reduce_1x1"):
                name = name.replace(conv_name, f"downsampling_layer.{conv_name}")

    # A bare weight/bias directly under `global_rep.{2,3,4}` is the final layer norm.
    for index in range(2, 5):
        for param in ("weight", "bias"):
            name = name.replace(f".global_rep.{index}.{param}", f".layernorm.{param}")

    late_rules = (
        (".global_rep.", ".transformer."),
        (".pre_norm_mha.0.", ".layernorm_before."),
        (".pre_norm_mha.1.out_proj.", ".attention.output.dense."),
        (".pre_norm_ffn.0.", ".layernorm_after."),
        (".pre_norm_ffn.1.", ".intermediate.dense."),
        (".pre_norm_ffn.4.", ".output.dense."),
        (".transformer.", ".transformer.layer."),
        (".aspp_layer.", "."),
        (".aspp_pool.", "."),
        ("seg_head.", "segmentation_head."),
        ("segmentation_head.classifier.classifier.", "segmentation_head.classifier."),
    )
    for old, new in late_rules:
        name = name.replace(old, new)

    # Classifier weights keep their top-level name; everything else outside the segmentation head
    # gets the backbone prefix when converting a model with a head.
    if "classifier.fc." in name:
        return name.replace("classifier.fc.", "classifier.")
    if not base_model and "segmentation_head." not in name:
        return "mobilevit." + name
    return name
def convert_state_dict(orig_state_dict, model, base_model=False):
    """
    Rewrite `orig_state_dict` in place so that its keys match the converted model, and return it.

    Fused `qkv` projections are split into separate query / key / value entries, using the
    `all_head_size` of the matching attention module in `model`; every other key is renamed with
    `rename_key`.
    """
    model_prefix = "" if base_model else "mobilevit."

    # Iterate over a snapshot of the keys because entries are popped and re-inserted in the loop.
    for original_key in list(orig_state_dict.keys()):
        tensor = orig_state_dict.pop(original_key)

        key = original_key
        if key[:8] == "encoder.":
            key = key[8:]

        if "qkv" not in key:
            orig_state_dict[rename_key(key, base_model)] = tensor
            continue

        # The key starts with "layer_N" (1-based stage) and its fourth component is the
        # index of the transformer layer inside that stage.
        parts = key.split(".")
        layer_num = int(parts[0][6:]) - 1
        transformer_num = int(parts[3])

        stage = model.get_submodule(f"{model_prefix}encoder.layer.{layer_num}")
        dim = stage.transformer.layer[transformer_num].attention.attention.all_head_size
        prefix = (
            f"{model_prefix}encoder.layer.{layer_num}.transformer.layer.{transformer_num}.attention.attention."
        )

        # Query, key and value are stacked along the first axis, each `dim` rows/entries long.
        if "weight" in key:
            orig_state_dict[prefix + "query.weight"] = tensor[:dim, :]
            orig_state_dict[prefix + "key.weight"] = tensor[dim : dim * 2, :]
            orig_state_dict[prefix + "value.weight"] = tensor[-dim:, :]
        else:
            orig_state_dict[prefix + "query.bias"] = tensor[:dim]
            orig_state_dict[prefix + "key.bias"] = tensor[dim : dim * 2]
            orig_state_dict[prefix + "value.bias"] = tensor[-dim:]

    return orig_state_dict
def prepare_img():
    """
    Download and return the test image used to verify the converted model.

    NOTE(review): the hard-coded expected logits in `convert_movilevit_checkpoint` are assumed to
    have been computed on this COCO validation image — confirm if the URL is ever changed.
    """
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # A timeout keeps the conversion from hanging forever on a stalled connection.
    response = requests.get(url, stream=True, timeout=60)
    response.raise_for_status()
    return Image.open(response.raw)


@torch.no_grad()
def convert_movilevit_checkpoint(mobilevit_name, checkpoint_path, pytorch_dump_folder_path, push_to_hub=False):
    """
    Copy/paste/tweak model's weights to our MobileViT structure.

    Args:
        mobilevit_name (`str`):
            Name of the original model; must be one of the names handled in the logit checks below.
        checkpoint_path (`str`):
            Path to the original state dict.
        pytorch_dump_folder_path (`str`):
            Directory the converted model and image processor are saved to (created if missing).
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to also push the converted model and image processor to the hub.

    Raises:
        ValueError: if `mobilevit_name` has no reference logits to verify against.
        AssertionError: if the converted model's logits do not match the reference values.
    """
    config = get_mobilevit_config(mobilevit_name)

    # load original state_dict
    state_dict = torch.load(checkpoint_path, map_location="cpu")

    # load 🤗 model
    if mobilevit_name.startswith("deeplabv3_"):
        model = MobileViTForSemanticSegmentation(config).eval()
    else:
        model = MobileViTForImageClassification(config).eval()

    new_state_dict = convert_state_dict(state_dict, model)
    model.load_state_dict(new_state_dict)

    # Check outputs on a test image, prepared by the image processor
    image_processor = MobileViTImageProcessor(crop_size=config.image_size, size=config.image_size + 32)
    encoding = image_processor(images=prepare_img(), return_tensors="pt")
    outputs = model(**encoding)
    logits = outputs.logits

    if mobilevit_name.startswith("deeplabv3_"):
        assert logits.shape == (1, 21, 32, 32)

        if mobilevit_name == "deeplabv3_mobilevit_s":
            expected_logits = torch.tensor(
                [
                    [[6.2065, 6.1292, 6.2070], [6.1079, 6.1254, 6.1747], [6.0042, 6.1071, 6.1034]],
                    [[-6.9253, -6.8653, -7.0398], [-7.3218, -7.3983, -7.3670], [-7.1961, -7.2482, -7.1569]],
                    [[-4.4723, -4.4348, -4.3769], [-5.3629, -5.4632, -5.4598], [-5.1587, -5.3402, -5.5059]],
                ]
            )
        elif mobilevit_name == "deeplabv3_mobilevit_xs":
            expected_logits = torch.tensor(
                [
                    [[5.4449, 5.5733, 5.6314], [5.1815, 5.3930, 5.5963], [5.1656, 5.4333, 5.4853]],
                    [[-9.4423, -9.7766, -9.6714], [-9.1581, -9.5720, -9.5519], [-9.1006, -9.6458, -9.5703]],
                    [[-7.7721, -7.3716, -7.1583], [-8.4599, -8.0624, -7.7944], [-8.4172, -7.8366, -7.5025]],
                ]
            )
        elif mobilevit_name == "deeplabv3_mobilevit_xxs":
            expected_logits = torch.tensor(
                [
                    [[6.9811, 6.9743, 7.3123], [7.1777, 7.1931, 7.3938], [7.5633, 7.8050, 7.8901]],
                    [[-10.5536, -10.2332, -10.2924], [-10.2336, -9.8624, -9.5964], [-10.8840, -10.8158, -10.6659]],
                    [[-3.4938, -3.0631, -2.8620], [-3.4205, -2.8135, -2.6875], [-3.4179, -2.7945, -2.8750]],
                ]
            )
        else:
            raise ValueError(f"Unknown mobilevit_name: {mobilevit_name}")

        # Only the top-left 3x3 corner of the first three class maps is compared.
        assert torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-4)
    else:
        assert logits.shape == (1, 1000)

        if mobilevit_name == "mobilevit_s":
            expected_logits = torch.tensor([-0.9866, 0.2392, -1.1241])
        elif mobilevit_name == "mobilevit_xs":
            expected_logits = torch.tensor([-2.4761, -0.9399, -1.9587])
        elif mobilevit_name == "mobilevit_xxs":
            expected_logits = torch.tensor([-1.9364, -1.2327, -0.4653])
        else:
            raise ValueError(f"Unknown mobilevit_name: {mobilevit_name}")

        # Only the first three class logits are compared.
        assert torch.allclose(logits[0, :3], expected_logits, atol=1e-4)

    # `parents=True` lets the dump folder be nested inside directories that do not exist yet.
    Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    print(f"Saving model {mobilevit_name} to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Saving image processor to {pytorch_dump_folder_path}")
    image_processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        # Original checkpoint name -> repository name on the hub.
        model_mapping = {
            "mobilevit_s": "mobilevit-small",
            "mobilevit_xs": "mobilevit-x-small",
            "mobilevit_xxs": "mobilevit-xx-small",
            "deeplabv3_mobilevit_s": "deeplabv3-mobilevit-small",
            "deeplabv3_mobilevit_xs": "deeplabv3-mobilevit-x-small",
            "deeplabv3_mobilevit_xxs": "deeplabv3-mobilevit-xx-small",
        }

        print("Pushing to the hub...")
        model_name = model_mapping[mobilevit_name]
        image_processor.push_to_hub(model_name, organization="apple")
        model.push_to_hub(model_name, organization="apple")
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the conversion once.
    cli = argparse.ArgumentParser()

    # Optional: which original model to convert.
    cli.add_argument(
        "--mobilevit_name",
        type=str,
        default="mobilevit_s",
        help=(
            "Name of the MobileViT model you'd like to convert. Should be one of 'mobilevit_s', 'mobilevit_xs',"
            " 'mobilevit_xxs', 'deeplabv3_mobilevit_s', 'deeplabv3_mobilevit_xs', 'deeplabv3_mobilevit_xxs'."
        ),
    )
    # Required: input checkpoint and output directory.
    cli.add_argument(
        "--checkpoint_path",
        type=str,
        required=True,
        help="Path to the original state dict (.pt file).",
    )
    cli.add_argument(
        "--pytorch_dump_folder_path",
        type=str,
        required=True,
        help="Path to the output PyTorch model directory.",
    )
    # Flag: also upload the result.
    cli.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the converted model to the 🤗 hub.",
    )

    options = cli.parse_args()
    convert_movilevit_checkpoint(
        options.mobilevit_name,
        options.checkpoint_path,
        options.pytorch_dump_folder_path,
        options.push_to_hub,
    )
.\models\mobilevit\feature_extraction_mobilevit.py
"""Feature extractor class for MobileViT."""
import warnings
from ...utils import logging
from .image_processing_mobilevit import MobileViTImageProcessor
logger = logging.get_logger(__name__)
class MobileViTFeatureExtractor(MobileViTImageProcessor):
    """Deprecated alias of `MobileViTImageProcessor`; emits a `FutureWarning` on construction."""

    def __init__(self, *args, **kwargs) -> None:
        deprecation_message = (
            "The class MobileViTFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
            " Please use MobileViTImageProcessor instead."
        )
        warnings.warn(deprecation_message, FutureWarning)
        # All arguments are forwarded unchanged to the image processor.
        super().__init__(*args, **kwargs)
.\models\mobilevit\image_processing_mobilevit.py
"""Image processor class for MobileViT."""
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import flip_channel_order, get_resize_output_image_size, resize, to_channel_dimension_format
from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
validate_kwargs,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_torch_available, is_torch_tensor, is_vision_available, logging
if is_vision_available():
import PIL
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class MobileViTImageProcessor(BaseImageProcessor):
    r"""
    Constructs a MobileViT image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`.
            Can be overridden by the `do_resize` parameter in the `preprocess` method.
        size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
            Controls the size of the output image after resizing.
            Can be overridden by the `size` parameter in the `preprocess` method.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            Defines the resampling filter to use if resizing the image.
            Can be overridden by the `resample` parameter in the `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`.
            Can be overridden by the `do_rescale` parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image.
            Can be overridden by the `rescale_factor` parameter in the `preprocess` method.
        do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether to crop the input at the center.
            Can be overridden by the `do_center_crop` parameter in the `preprocess` method.
        crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 256}`):
            Desired output size `(size["height"], size["width"])` when applying center-cropping.
            Can be overridden by the `crop_size` parameter in the `preprocess` method.
        do_flip_channel_order (`bool`, *optional*, defaults to `True`):
            Whether to flip the color channels from RGB to BGR.
            Can be overridden by the `do_flip_channel_order` parameter in the `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_center_crop: bool = True,
        crop_size: Dict[str, int] = None,
        do_flip_channel_order: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # Fill in the defaults and normalize both size arguments through `get_size_dict`.
        size = size if size is not None else {"shortest_edge": 224}
        size = get_size_dict(size, default_to_square=False)
        crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
        crop_size = get_size_dict(crop_size, param_name="crop_size")

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_flip_channel_order = do_flip_channel_order
        # NOTE(review): these look like the keyword arguments accepted by a public `preprocess`
        # method, but no `preprocess` override is defined in this class as shown — confirm.
        self._valid_processor_keys = [
            "images",
            "segmentation_maps",
            "do_resize",
            "size",
            "resample",
            "do_rescale",
            "rescale_factor",
            "do_center_crop",
            "crop_size",
            "do_flip_channel_order",
            "return_tensors",
            "data_format",
            "input_data_format",
        ]

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
        resized to keep the input aspect ratio.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.

        Raises:
            ValueError: if `size` contains neither `"shortest_edge"` nor both `"height"` and `"width"`.
        """
        default_to_square = True
        if "shortest_edge" in size:
            size = size["shortest_edge"]
            default_to_square = False
        elif "height" in size and "width" in size:
            size = (size["height"], size["width"])
        else:
            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")

        output_size = get_resize_output_image_size(
            image,
            size=size,
            default_to_square=default_to_square,
            input_data_format=input_data_format,
        )
        # `resize` here is the module-level transform, not this method.
        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )

    def flip_channel_order(
        self,
        image: np.ndarray,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Flip the color channels from RGB to BGR or vice versa.

        Args:
            image (`np.ndarray`):
                The image, represented as a numpy array.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        # `flip_channel_order` here is the module-level transform, not this method.
        return flip_channel_order(image, data_format=data_format, input_data_format=input_data_format)

    def __call__(self, images, segmentation_maps=None, **kwargs):
        """
        Preprocesses a batch of images and optionally segmentation maps.

        Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
        passed in as positional arguments.
        """
        return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)

    def _preprocess(
        self,
        image: ImageInput,
        do_resize: bool,
        do_rescale: bool,
        do_center_crop: bool,
        do_flip_channel_order: bool,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = None,
        rescale_factor: Optional[float] = None,
        crop_size: Optional[Dict[str, int]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """Apply the enabled steps in a fixed order: resize, rescale, center crop, channel flip."""
        if do_resize:
            image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)

        if do_rescale:
            image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

        if do_center_crop:
            image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)

        if do_flip_channel_order:
            image = self.flip_channel_order(image, input_data_format=input_data_format)

        return image

    def _preprocess_image(
        self,
        image: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_center_crop: bool = None,
        crop_size: Dict[str, int] = None,
        do_flip_channel_order: bool = None,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """Preprocesses a single image."""
        # All transformations expect numpy arrays.
        image = to_numpy_array(image)
        if is_scaled_image(image) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )
        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)

        image = self._preprocess(
            image=image,
            do_resize=do_resize,
            size=size,
            resample=resample,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            do_flip_channel_order=do_flip_channel_order,
            input_data_format=input_data_format,
        )

        # Convert to the requested channel layout only once, after all transformations.
        image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)

        return image

    def _preprocess_mask(
        self,
        segmentation_map: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        do_center_crop: bool = None,
        crop_size: Dict[str, int] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """Preprocesses a single mask."""
        segmentation_map = to_numpy_array(segmentation_map)
        # Add a channel dimension if missing - needed for certain transformations
        if segmentation_map.ndim == 2:
            added_channel_dim = True
            segmentation_map = segmentation_map[None, ...]
            input_data_format = ChannelDimension.FIRST
        else:
            added_channel_dim = False
            if input_data_format is None:
                input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)

        # Masks hold class ids: nearest-neighbour resampling, no rescaling and no channel flip.
        segmentation_map = self._preprocess(
            image=segmentation_map,
            do_resize=do_resize,
            size=size,
            resample=PILImageResampling.NEAREST,
            do_rescale=False,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            do_flip_channel_order=False,
            input_data_format=input_data_format,
        )
        # Remove extra channel dimension if added for processing
        if added_channel_dim:
            segmentation_map = segmentation_map.squeeze(0)
        segmentation_map = segmentation_map.astype(np.int64)
        return segmentation_map

    def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
        """
        Converts the output of [`MobileViTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.

        Args:
            outputs ([`MobileViTForSemanticSegmentation`]):
                Raw outputs of the model.
            target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
                List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
                predictions will not be resized.

        Returns:
            semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
            segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
            specified). Each entry of each `torch.Tensor` correspond to a semantic class id.

        Raises:
            ValueError: if `target_sizes` is given and its length differs from the batch size of the logits.
        """
        logits = outputs.logits

        # Resize logits and compute semantic segmentation maps
        if target_sizes is not None:
            if len(logits) != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )

            if is_torch_tensor(target_sizes):
                target_sizes = target_sizes.numpy()

            semantic_segmentation = []

            # Each image is interpolated separately because the target sizes may differ.
            for idx in range(len(logits)):
                resized_logits = torch.nn.functional.interpolate(
                    logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
                )
                semantic_map = resized_logits[0].argmax(dim=0)
                semantic_segmentation.append(semantic_map)
        else:
            semantic_segmentation = logits.argmax(dim=1)
            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]

        return semantic_segmentation
.\models\mobilevit\modeling_mobilevit.py
""" PyTorch MobileViT model."""
import math
from typing import Dict, Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutputWithNoAttention,
BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention,
SemanticSegmenterOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
# Module-level logger for the MobileViT modeling module.
logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "MobileViTConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "apple/mobilevit-small"
_EXPECTED_OUTPUT_SHAPE = [1, 640, 8, 8]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "apple/mobilevit-small"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"

# Identifiers of the pretrained MobileViT checkpoints (classification and `deeplabv3-` variants).
MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "apple/mobilevit-small",
    "apple/mobilevit-x-small",
    "apple/mobilevit-xx-small",
    "apple/deeplabv3-mobilevit-small",
    "apple/deeplabv3-mobilevit-x-small",
    "apple/deeplabv3-mobilevit-xx-small",
]
def make_divisible(value: int, divisor: int = 8, min_value: Optional[int] = None) -> int:
    """
    Round `value` to a nearby multiple of `divisor`, so that layers get a channel count divisible by `divisor`.

    The result is never below `min_value` (which defaults to `divisor`), and it is bumped up by one
    `divisor` when rounding would otherwise shrink `value` by more than 10%. This function is taken
    from the original TensorFlow repo. It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    lower_bound = divisor if min_value is None else min_value
    # Round half up to a multiple of `divisor`.
    nearest_multiple = int(value + divisor / 2) // divisor * divisor
    rounded = max(lower_bound, nearest_multiple)
    # Make sure that rounding down does not go down by more than 10%.
    if rounded < 0.9 * value:
        rounded += divisor
    return int(rounded)
class MobileViTConvLayer(nn.Module):
    """
    A 2D convolution optionally followed by batch normalization and an activation.

    `use_activation` may be `False` (no activation), a string (looked up in `ACT2FN`), or any other
    truthy value, in which case `config.hidden_act` is used (looked up in `ACT2FN` when it is a string).
    """

    def __init__(
        self,
        config: MobileViTConfig,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        groups: int = 1,
        bias: bool = False,
        dilation: int = 1,
        use_normalization: bool = True,
        use_activation: Union[bool, str] = True,
    ) -> None:
        super().__init__()

        # Grouped convolutions need both channel counts to split evenly across the groups.
        if in_channels % groups != 0:
            raise ValueError(f"Input channels ({in_channels}) are not divisible by {groups} groups.")
        if out_channels % groups != 0:
            raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.")

        # Half the kernel extent, scaled by the dilation.
        half_kernel = int((kernel_size - 1) / 2)
        self.convolution = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=half_kernel * dilation,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode="zeros",
        )

        self.normalization = (
            nn.BatchNorm2d(
                num_features=out_channels,
                eps=1e-5,
                momentum=0.1,
                affine=True,
                track_running_stats=True,
            )
            if use_normalization
            else None
        )

        if not use_activation:
            self.activation = None
        elif isinstance(use_activation, str):
            self.activation = ACT2FN[use_activation]
        elif isinstance(config.hidden_act, str):
            self.activation = ACT2FN[config.hidden_act]
        else:
            self.activation = config.hidden_act

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Run the convolution, then whichever of normalization and activation are configured."""
        features = self.convolution(features)
        for optional_stage in (self.normalization, self.activation):
            if optional_stage is not None:
                features = optional_stage(features)
        return features
class MobileViTInvertedResidual(nn.Module):
    """
    Inverted residual block (MobileNetv2): https://arxiv.org/abs/1801.04381

    A 1x1 expansion, a 3x3 convolution with one group per channel, and a 1x1 reduction without
    activation. The input is added back only when `stride == 1` and the channel count is unchanged.
    """

    def __init__(
        self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int, dilation: int = 1
    ) -> None:
        super().__init__()
        wide_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8)

        if stride not in [1, 2]:
            raise ValueError(f"Invalid stride {stride}.")

        keeps_shape = stride == 1 and in_channels == out_channels
        self.use_residual = keeps_shape

        self.expand_1x1 = MobileViTConvLayer(
            config,
            in_channels=in_channels,
            out_channels=wide_channels,
            kernel_size=1,
        )
        self.conv_3x3 = MobileViTConvLayer(
            config,
            in_channels=wide_channels,
            out_channels=wide_channels,
            kernel_size=3,
            stride=stride,
            groups=wide_channels,
            dilation=dilation,
        )
        self.reduce_1x1 = MobileViTConvLayer(
            config,
            in_channels=wide_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_activation=False,
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Expand, filter and reduce the features, adding the skip connection when enabled."""
        shortcut = features
        for stage in (self.expand_1x1, self.conv_3x3, self.reduce_1x1):
            features = stage(features)
        if self.use_residual:
            return shortcut + features
        return features
class MobileViTMobileNetLayer(nn.Module):
    """A stack of `num_stages` inverted residual blocks; only the first block uses `stride`."""

    def __init__(
        self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int = 1, num_stages: int = 1
    ) -> None:
        super().__init__()

        self.layer = nn.ModuleList()
        block_in_channels = in_channels
        for stage_index in range(num_stages):
            block_stride = 1 if stage_index != 0 else stride
            self.layer.append(
                MobileViTInvertedResidual(
                    config,
                    in_channels=block_in_channels,
                    out_channels=out_channels,
                    stride=block_stride,
                )
            )
            # Every block after the first consumes the previous block's output channels.
            block_in_channels = out_channels

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Pass the features through each block in order."""
        for block in self.layer:
            features = block(features)
        return features
class MobileViTSelfAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention over the second-to-last axis of the input.

    Args:
        config: provides `num_attention_heads`, `qkv_bias` and `attention_probs_dropout_prob`.
        hidden_size: size of the last axis of the input; must be a multiple of the number of heads.

    Raises:
        ValueError: if `hidden_size` is not divisible by `config.num_attention_heads`.
    """

    def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
        super().__init__()

        if hidden_size % config.num_attention_heads != 0:
            # Format the integer itself; `{hidden_size,}` would render a one-element tuple, e.g. "(10,)".
            raise ValueError(
                f"The hidden size {hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last axis into (heads, head_size) and move the heads axis to position 1."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the attention output with the same shape as `hidden_states`' projection to `all_head_size`."""
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # Dropout is applied to the attention probabilities, i.e. whole tokens may be
        # dropped from the set attended to.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)

        # Move the heads axis back and merge (heads, head_size) into a single last axis.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer
class MobileViTSelfOutput(nn.Module):
    """Output projection of the attention block: a square linear layer followed by dropout."""

    def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states` and apply dropout; no residual is added here."""
        projected = self.dense(hidden_states)
        return self.dropout(projected)
class MobileViTAttention(nn.Module):
    """Self-attention followed by its output projection, with support for pruning attention heads."""

    def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
        super().__init__()
        self.attention = MobileViTSelfAttention(config, hidden_size)
        self.output = MobileViTSelfOutput(config, hidden_size)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        """Remove the given heads from the projections and update the head bookkeeping."""
        if len(heads) == 0:
            return

        self_attention = self.attention
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self_attention.num_attention_heads,
            self_attention.attention_head_size,
            self.pruned_heads,
        )

        # Prune linear layers
        self_attention.query = prune_linear_layer(self_attention.query, index)
        self_attention.key = prune_linear_layer(self_attention.key, index)
        self_attention.value = prune_linear_layer(self_attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        remaining_heads = self_attention.num_attention_heads - len(heads)
        self_attention.num_attention_heads = remaining_heads
        self_attention.all_head_size = self_attention.attention_head_size * remaining_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply self-attention, then the output projection."""
        return self.output(self.attention(hidden_states))
class MobileViTIntermediate(nn.Module):
    """First half of the feed-forward block: a linear layer to `intermediate_size` plus activation."""

    def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size)
        # `hidden_act` is either a name to look up in `ACT2FN` or already a callable.
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states` to the intermediate size and apply the activation."""
        return self.intermediate_act_fn(self.dense(hidden_states))
class MobileViTOutput(nn.Module):
    """Second half of the feed-forward block: project back to `hidden_size`, dropout, add residual."""

    def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        self.dense = nn.Linear(intermediate_size, hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return `dropout(dense(hidden_states)) + input_tensor`."""
        projected = self.dropout(self.dense(hidden_states))
        return projected + input_tensor
class MobileViTTransformerLayer(nn.Module):
    """One pre-norm transformer layer: LayerNorm -> attention -> residual, then LayerNorm -> MLP -> residual."""

    def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        self.attention = MobileViTAttention(config, hidden_size)
        self.intermediate = MobileViTIntermediate(config, hidden_size, intermediate_size)
        self.output = MobileViTOutput(config, hidden_size, intermediate_size)
        self.layernorm_before = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the attention and feed-forward sub-blocks on `hidden_states`."""
        # Attention sub-block: normalise first, then add the un-normalised input back.
        normed_input = self.layernorm_before(hidden_states)
        hidden_states = self.attention(normed_input) + hidden_states

        # Feed-forward sub-block: `self.output` adds the residual (its second argument) itself.
        expanded = self.intermediate(self.layernorm_after(hidden_states))
        return self.output(expanded, hidden_states)
class MobileViTTransformer(nn.Module):
    """A stack of `num_stages` identical `MobileViTTransformerLayer` modules applied in sequence."""

    def __init__(self, config: MobileViTConfig, hidden_size: int, num_stages: int) -> None:
        super().__init__()

        self.layer = nn.ModuleList()
        stage = 0
        while stage < num_stages:
            self.layer.append(
                MobileViTTransformerLayer(
                    config,
                    hidden_size=hidden_size,
                    # The MLP width is the hidden size scaled by `config.mlp_ratio`, truncated to an int.
                    intermediate_size=int(hidden_size * config.mlp_ratio),
                )
            )
            stage += 1

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Feed `hidden_states` through every layer in order and return the final result."""
        output = hidden_states
        for transformer_layer in self.layer:
            output = transformer_layer(output)
        return output
class MobileViTLayer(nn.Module):
    """
    MobileViT block: https://arxiv.org/abs/2110.02178

    Combines a local convolutional representation with a transformer applied over unfolded patches, then fuses the
    result with the block input.
    """

    def __init__(
        self,
        config: MobileViTConfig,
        in_channels: int,
        out_channels: int,
        stride: int,
        hidden_size: int,
        num_stages: int,
        dilation: int = 1,
    ) -> None:
        super().__init__()
        self.patch_width = config.patch_size
        self.patch_height = config.patch_size

        if stride != 2:
            self.downsampling_layer = None
        else:
            # When dilation is requested, the stride is replaced by 1 and half the dilation is applied instead.
            downsample_stride = stride if dilation == 1 else 1
            downsample_dilation = dilation // 2 if dilation > 1 else 1
            self.downsampling_layer = MobileViTInvertedResidual(
                config,
                in_channels=in_channels,
                out_channels=out_channels,
                stride=downsample_stride,
                dilation=downsample_dilation,
            )
            # Everything after the downsampling layer works on its output channels.
            in_channels = out_channels

        # Local representation: k x k conv, then 1x1 conv (no norm/activation) into the transformer width.
        self.conv_kxk = MobileViTConvLayer(
            config,
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=config.conv_kernel_size,
        )
        self.conv_1x1 = MobileViTConvLayer(
            config,
            in_channels=in_channels,
            out_channels=hidden_size,
            kernel_size=1,
            use_normalization=False,
            use_activation=False,
        )

        # Global representation over patches.
        self.transformer = MobileViTTransformer(
            config,
            hidden_size=hidden_size,
            num_stages=num_stages,
        )
        self.layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

        # Project back to the convolutional width and fuse with the (concatenated) residual.
        self.conv_projection = MobileViTConvLayer(
            config, in_channels=hidden_size, out_channels=in_channels, kernel_size=1
        )
        self.fusion = MobileViTConvLayer(
            config, in_channels=2 * in_channels, out_channels=in_channels, kernel_size=config.conv_kernel_size
        )

    def unfolding(self, features: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        """
        Rearrange a `(batch_size, channels, height, width)` feature map into patch sequences.

        Returns:
            A tuple of the patches, reshaped to `(batch_size * patch_area, num_patches, channels)`, and a dict with
            the information `folding` needs to invert the operation.
        """
        patch_w, patch_h = self.patch_width, self.patch_height
        patch_area = int(patch_w * patch_h)

        batch_size, channels, orig_height, orig_width = features.shape

        # Round both spatial dimensions up to the next multiple of the patch size.
        new_height = int(math.ceil(orig_height / patch_h) * patch_h)
        new_width = int(math.ceil(orig_width / patch_w) * patch_w)

        interpolate = new_width != orig_width or new_height != orig_height
        if interpolate:
            # Resize so the feature map divides evenly into patches; `folding` resizes back.
            features = nn.functional.interpolate(
                features, size=(new_height, new_width), mode="bilinear", align_corners=False
            )

        # Number of patches along each axis.
        num_patch_width = new_width // patch_w
        num_patch_height = new_height // patch_h
        num_patches = num_patch_height * num_patch_width

        # (batch_size, channels, height, width)
        # -> (batch_size * channels * num_patch_height, patch_height, num_patch_width, patch_width)
        patches = features.reshape(batch_size * channels * num_patch_height, patch_h, num_patch_width, patch_w)
        # -> (batch_size * channels * num_patch_height, num_patch_width, patch_height, patch_width)
        patches = patches.transpose(1, 2)
        # -> (batch_size, channels, num_patches, patch_area)
        patches = patches.reshape(batch_size, channels, num_patches, patch_area)
        # -> (batch_size, patch_area, num_patches, channels)
        patches = patches.transpose(1, 3)
        # -> (batch_size * patch_area, num_patches, channels)
        patches = patches.reshape(batch_size * patch_area, num_patches, -1)

        info_dict = {
            "orig_size": (orig_height, orig_width),
            "batch_size": batch_size,
            "channels": channels,
            "interpolate": interpolate,
            "num_patches": num_patches,
            "num_patches_width": num_patch_width,
            "num_patches_height": num_patch_height,
        }
        return patches, info_dict

    def folding(self, patches: torch.Tensor, info_dict: Dict) -> torch.Tensor:
        """
        Invert `unfolding`: turn patch sequences back into a `(batch_size, channels, height, width)` feature map.

        Args:
            patches: Tensor viewable as `(batch_size, patch_area, num_patches, -1)`.
            info_dict: The dict returned by `unfolding` for the same input.
        """
        patch_w, patch_h = self.patch_width, self.patch_height
        patch_area = int(patch_w * patch_h)

        batch_size = info_dict["batch_size"]
        channels = info_dict["channels"]
        num_patches = info_dict["num_patches"]
        num_patch_height = info_dict["num_patches_height"]
        num_patch_width = info_dict["num_patches_width"]

        # (batch_size * patch_area, num_patches, channels) -> (batch_size, patch_area, num_patches, channels)
        features = patches.contiguous().view(batch_size, patch_area, num_patches, -1)
        # -> (batch_size, channels, num_patches, patch_area)
        features = features.transpose(1, 3)
        # -> (batch_size * channels * num_patch_height, num_patch_width, patch_height, patch_width)
        features = features.reshape(batch_size * channels * num_patch_height, num_patch_width, patch_h, patch_w)
        # -> (batch_size * channels * num_patch_height, patch_height, num_patch_width, patch_width)
        features = features.transpose(1, 2)
        # -> (batch_size, channels, height, width)
        features = features.reshape(batch_size, channels, num_patch_height * patch_h, num_patch_width * patch_w)

        if info_dict["interpolate"]:
            # Undo the resize that `unfolding` applied.
            features = nn.functional.interpolate(
                features, size=info_dict["orig_size"], mode="bilinear", align_corners=False
            )
        return features

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Apply the MobileViT block to `features` and return the fused feature map."""
        # Reduce spatial dimensions if needed.
        if self.downsampling_layer:
            features = self.downsampling_layer(features)

        block_input = features

        # Local representation.
        features = self.conv_1x1(self.conv_kxk(features))

        # Convert the feature map to patches and learn the global representation.
        patches, info_dict = self.unfolding(features)
        patches = self.layernorm(self.transformer(patches))

        # Convert patches back to a feature map and fuse with the block input along the channel axis.
        features = self.conv_projection(self.folding(patches, info_dict))
        return self.fusion(torch.cat((block_input, features), dim=1))
class MobileViTEncoder(nn.Module):
    """Five-stage MobileViT encoder: two MobileNet stages followed by three MobileViT stages."""

    def __init__(self, config: MobileViTConfig) -> None:
        super().__init__()
        self.config = config

        self.layer = nn.ModuleList()
        self.gradient_checkpointing = False

        # `output_stride` 8 dilates stages 4 and 5; 16 dilates stage 5 only; anything else dilates neither.
        dilate_layer_4 = config.output_stride == 8
        dilate_layer_5 = dilate_layer_4 or config.output_stride == 16

        neck_sizes = config.neck_hidden_sizes
        dilation = 1

        self.layer.append(
            MobileViTMobileNetLayer(
                config,
                in_channels=neck_sizes[0],
                out_channels=neck_sizes[1],
                stride=1,
                num_stages=1,
            )
        )

        self.layer.append(
            MobileViTMobileNetLayer(
                config,
                in_channels=neck_sizes[1],
                out_channels=neck_sizes[2],
                stride=2,
                num_stages=3,
            )
        )

        self.layer.append(
            MobileViTLayer(
                config,
                in_channels=neck_sizes[2],
                out_channels=neck_sizes[3],
                stride=2,
                hidden_size=config.hidden_sizes[0],
                num_stages=2,
            )
        )

        # The dilation accumulates: it doubles for each dilated stage.
        if dilate_layer_4:
            dilation *= 2

        self.layer.append(
            MobileViTLayer(
                config,
                in_channels=neck_sizes[3],
                out_channels=neck_sizes[4],
                stride=2,
                hidden_size=config.hidden_sizes[1],
                num_stages=4,
                dilation=dilation,
            )
        )

        if dilate_layer_5:
            dilation *= 2

        self.layer.append(
            MobileViTLayer(
                config,
                in_channels=neck_sizes[4],
                out_channels=neck_sizes[5],
                stride=2,
                hidden_size=config.hidden_sizes[2],
                num_stages=3,
                dilation=dilation,
            )
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutputWithNoAttention]:
        """
        Run every stage in order.

        Args:
            hidden_states: Input feature map for the first stage.
            output_hidden_states: If true, collect the output of every stage.
            return_dict: If false, return a plain tuple with the `None` entries dropped.
        """
        collected_states = () if output_hidden_states else None
        use_checkpointing = self.gradient_checkpointing and self.training

        for stage in self.layer:
            if use_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    stage.__call__,
                    hidden_states,
                )
            else:
                hidden_states = stage(hidden_states)

            if output_hidden_states:
                collected_states = collected_states + (hidden_states,)

        if return_dict:
            return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=collected_states)
        return tuple(value for value in [hidden_states, collected_states] if value is not None)
class MobileViTPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = MobileViTConfig
    base_model_prefix = "mobilevit"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, nn.LayerNorm):
            # LayerNorm starts as the identity transform: zero bias, unit weight.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            # Other module types are left untouched.
            return

        # Linear/conv weights are drawn from N(0, initializer_range); biases, when present, are zeroed.
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
# Shared class-level docstring text; passed to `add_start_docstrings` on the model classes below.
MOBILEVIT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`MobileViTConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Shared description of the `forward` arguments; passed to `add_start_docstrings_to_model_forward` below.
MOBILEVIT_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
[`MobileViTImageProcessor.__call__`] for details.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    "The bare MobileViT model outputting raw hidden-states without any specific head on top.",
    MOBILEVIT_START_DOCSTRING,
)
class MobileViTModel(MobileViTPreTrainedModel):
    """
    Backbone made of a convolutional stem, the MobileViT encoder and, when `expand_output` is true, a final 1x1
    expansion convolution whose output is also average-pooled over the spatial dimensions.
    """

    def __init__(self, config: MobileViTConfig, expand_output: bool = True):
        super().__init__(config)
        self.config = config
        self.expand_output = expand_output

        self.conv_stem = MobileViTConvLayer(
            config,
            in_channels=config.num_channels,
            out_channels=config.neck_hidden_sizes[0],
            kernel_size=3,
            stride=2,
        )

        self.encoder = MobileViTEncoder(config)

        # The expansion layer only exists when requested, so `forward` must check the same flag.
        if self.expand_output:
            self.conv_1x1_exp = MobileViTConvLayer(
                config,
                in_channels=config.neck_hidden_sizes[5],
                out_channels=config.neck_hidden_sizes[6],
                kernel_size=1,
            )

        # Initialize weights and apply final processing
        self.post_init()

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel
        """
        for layer_index, heads in heads_to_prune.items():
            mobilevit_layer = self.encoder.layer[layer_index]
            # Only MobileViT stages contain a transformer; other stage types are skipped.
            if isinstance(mobilevit_layer, MobileViTLayer):
                for transformer_layer in mobilevit_layer.transformer.layer:
                    transformer_layer.attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndNoAttention,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
        # NOTE: no docstring here on purpose; the decorators above assemble this method's documentation.
        # Fall back to the config defaults for options the caller left unset.
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.conv_stem(pixel_values)

        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.expand_output:
            last_hidden_state = self.conv_1x1_exp(encoder_outputs[0])

            # Global average pooling over the two trailing (spatial) dimensions.
            pooled_output = torch.mean(last_hidden_state, dim=[-2, -1], keepdim=False)
        else:
            last_hidden_state = encoder_outputs[0]
            pooled_output = None

        if not return_dict:
            # The pooled output is omitted from the tuple entirely when it was not computed.
            output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,)
            return output + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )
@add_start_docstrings(
    """
    MobileViT model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """,
    MOBILEVIT_START_DOCSTRING,
)
class MobileViTForImageClassification(MobileViTPreTrainedModel):
    def __init__(self, config: MobileViTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.mobilevit = MobileViTModel(config)

        # Classifier head: dropout followed by a linear layer (or identity when there are no labels).
        self.dropout = nn.Dropout(config.classifier_dropout_prob, inplace=True)
        self.classifier = (
            nn.Linear(config.neck_hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.mobilevit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)

        # In tuple form the pooled output is the second element.
        pooled_output = outputs.pooler_output if return_dict else outputs[1]
        logits = self.classifier(self.dropout(pooled_output))

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                # Infer the problem type once from the label count and dtype, and remember it on the config.
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                mse = MSELoss()
                if self.num_labels == 1:
                    loss = mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = mse(logits, labels)
            elif problem_type == "single_label_classification":
                loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                loss = BCEWithLogitsLoss()(logits, labels)

        if return_dict:
            return ImageClassifierOutputWithNoAttention(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
            )

        # Tuple form: logits plus whatever follows the pooled output, with the loss prepended when present.
        output = (logits,) + outputs[2:]
        return output if loss is None else ((loss,) + output)
class MobileViTASPPPooling(nn.Module):
    """Image-level branch of ASPP: global average pool, 1x1 conv, then resize back to the input resolution."""

    def __init__(self, config: MobileViTConfig, in_channels: int, out_channels: int) -> None:
        super().__init__()

        self.global_pool = nn.AdaptiveAvgPool2d(output_size=1)

        self.conv_1x1 = MobileViTConvLayer(
            config,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            use_normalization=True,
            use_activation="relu",
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Return pooled-and-projected features resized to the spatial size of the input."""
        # Remember the two trailing (spatial) dimensions so the 1x1 map can be resized back to them.
        target_size = features.shape[-2:]
        pooled = self.conv_1x1(self.global_pool(features))
        return nn.functional.interpolate(pooled, size=target_size, mode="bilinear", align_corners=False)
class MobileViTASPP(nn.Module):
    """
    ASPP module defined in DeepLab papers: https://arxiv.org/abs/1606.00915, https://arxiv.org/abs/1706.05587
    """

    def __init__(self, config: MobileViTConfig) -> None:
        super().__init__()

        in_channels = config.neck_hidden_sizes[-2]
        out_channels = config.aspp_out_channels

        # `project` below expects exactly five branches: 1x1 + three atrous convs + pooling.
        if len(config.atrous_rates) != 3:
            raise ValueError("Expected 3 values for atrous_rates")

        branches = [
            MobileViTConvLayer(
                config,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                use_activation="relu",
            )
        ]
        for rate in config.atrous_rates:
            branches.append(
                MobileViTConvLayer(
                    config,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    dilation=rate,
                    use_activation="relu",
                )
            )
        branches.append(MobileViTASPPPooling(config, in_channels, out_channels))
        self.convs = nn.ModuleList(branches)

        self.project = MobileViTConvLayer(
            config, in_channels=5 * out_channels, out_channels=out_channels, kernel_size=1, use_activation="relu"
        )

        self.dropout = nn.Dropout(p=config.aspp_dropout_prob)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Run every branch on `features`, concatenate along dim 1, then project and apply dropout."""
        pyramid = torch.cat([branch(features) for branch in self.convs], dim=1)
        return self.dropout(self.project(pyramid))
class MobileViTDeepLabV3(nn.Module):
    """
    DeepLabv3 architecture: https://arxiv.org/abs/1706.05587
    """

    def __init__(self, config: MobileViTConfig) -> None:
        super().__init__()
        self.aspp = MobileViTASPP(config)

        self.dropout = nn.Dropout2d(config.classifier_dropout_prob)

        # Per-pixel classifier: a biased 1x1 conv with no normalization or activation.
        self.classifier = MobileViTConvLayer(
            config,
            in_channels=config.aspp_out_channels,
            out_channels=config.num_labels,
            kernel_size=1,
            use_normalization=False,
            use_activation=False,
            bias=True,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute segmentation logits from the last entry of `hidden_states`."""
        # NOTE(review): only `hidden_states[-1]` is used; the caller in this file passes the collection of
        # encoder hidden states, so the `torch.Tensor` annotation looks too narrow — confirm.
        last_stage = hidden_states[-1]
        return self.classifier(self.dropout(self.aspp(last_stage)))
@add_start_docstrings(
    """
    MobileViT model with a semantic segmentation head on top, e.g. for Pascal VOC.
    """,
    MOBILEVIT_START_DOCSTRING,
)
class MobileViTForSemanticSegmentation(MobileViTPreTrainedModel):
    def __init__(self, config: MobileViTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        # The backbone is built without the final expansion layer; the head consumes encoder hidden states.
        self.mobilevit = MobileViTModel(config, expand_output=False)
        self.segmentation_head = MobileViTDeepLabV3(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SemanticSegmenterOutput]:
        r"""
        labels (`torch.LongTensor`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. The logits are resized to the last two
            dimensions of `labels` before a Cross-Entropy loss is computed; `config.num_labels` must be greater
            than one.

        Returns:
        """
        # Resolve the hidden-state and return-type options, falling back to the model config when unset.
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.mobilevit(
            pixel_values,
            output_hidden_states=True,  # the segmentation head needs the intermediate hidden states
            return_dict=return_dict,
        )

        # Depending on the return type, the encoder hidden states are an attribute or the second tuple element.
        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]

        # Feed the encoder hidden states to the segmentation head to produce the predicted logits.
        logits = self.segmentation_head(encoder_hidden_states)

        # If labels were provided, compute a cross-entropy loss that honours the configured ignore index.
        loss = None
        if labels is not None:
            if self.config.num_labels == 1:
                raise ValueError("The number of labels should be greater than one")
            else:
                # Upsample the logits to the spatial size of the labels.
                upsampled_logits = nn.functional.interpolate(
                    logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
                )
                loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
                loss = loss_fct(upsampled_logits, labels)

        # Tuple output: include the hidden states only if the caller asked for them, and put the loss first
        # when there is one.
        if not return_dict:
            if output_hidden_states:
                output = (logits,) + outputs[1:]
            else:
                output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # Structured output: loss, logits, hidden states (if requested) and attentions (always None here).
        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=None,
        )
.\models\mobilevit\modeling_tf_mobilevit.py
""" TensorFlow 2.0 MobileViT model."""
from __future__ import annotations
from typing import Dict, Optional, Tuple, Union
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPooling,
TFImageClassifierOutputWithNoAttention,
TFSemanticSegmenterOutputWithNoAttention,
)
from ...modeling_tf_utils import (
TFPreTrainedModel,
TFSequenceClassificationLoss,
keras,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import logging
from .configuration_mobilevit import MobileViTConfig
# Module-level logger.
logger = logging.get_logger(__name__)
# Configuration class name used by the documentation decorators.
_CONFIG_FOR_DOC = "MobileViTConfig"
# Checkpoint name and expected output shape used for documentation code samples.
_CHECKPOINT_FOR_DOC = "apple/mobilevit-small"
_EXPECTED_OUTPUT_SHAPE = [1, 640, 8, 8]
# Checkpoint name and expected label used for the image-classification documentation sample.
_IMAGE_CLASS_CHECKPOINT = "apple/mobilevit-small"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
# Names of the pretrained TF MobileViT checkpoints (classification and DeepLabV3 segmentation variants).
TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"apple/mobilevit-small",
"apple/mobilevit-x-small",
"apple/mobilevit-xx-small",
"apple/deeplabv3-mobilevit-small",
"apple/deeplabv3-mobilevit-x-small",
"apple/deeplabv3-mobilevit-xx-small",
]
def make_divisible(value: int, divisor: int = 8, min_value: Optional[int] = None) -> int:
    """
    Ensure that all layers have a channel count that is divisible by `divisor`. This function is taken from the
    original TensorFlow repo. It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py

    Args:
        value: The requested channel count.
        divisor: The result is a multiple of this number (unless `min_value` is larger and not a multiple).
        min_value: Lower bound for the result; defaults to `divisor`.

    Returns:
        `value` rounded to the nearest multiple of `divisor`, bumped up by one `divisor` if rounding down would
        lose more than 10% of `value`.
    """
    lower_bound = divisor if min_value is None else min_value
    # Round to the nearest multiple of `divisor` (halves round up).
    nearest_multiple = int(value + divisor / 2) // divisor * divisor
    result = max(lower_bound, nearest_multiple)
    # Make sure that rounding down does not go down by more than 10%.
    if result < 0.9 * value:
        result += divisor
    return int(result)
class TFMobileViTConvLayer(keras.layers.Layer):
    """Explicit zero padding + Conv2D, optionally followed by batch normalization and an activation."""

    def __init__(
        self,
        config: MobileViTConfig,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        groups: int = 1,
        bias: bool = False,
        dilation: int = 1,
        use_normalization: bool = True,
        use_activation: Union[bool, str] = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        logger.warning(
            f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish "
            "to train/fine-tune this model, you need a GPU or a TPU"
        )

        # Padding is applied by a separate layer, so the convolution itself uses "VALID".
        pad_amount = int((kernel_size - 1) / 2) * dilation
        self.padding = keras.layers.ZeroPadding2D(pad_amount)

        if out_channels % groups != 0:
            raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.")

        self.convolution = keras.layers.Conv2D(
            filters=out_channels,
            kernel_size=kernel_size,
            strides=stride,
            padding="VALID",
            dilation_rate=dilation,
            groups=groups,
            use_bias=bias,
            name="convolution",
        )

        # NOTE(review): momentum=0.1 is passed to Keras as-is; Keras treats momentum as the decay of the moving
        # statistics, which is the opposite convention from PyTorch — confirm this value is intended.
        self.normalization = (
            keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization")
            if use_normalization
            else None
        )

        # `use_activation` may be falsy (no activation), a string (explicit activation name), or any other
        # truthy value (use `config.hidden_act`, itself either a name or a callable).
        if not use_activation:
            self.activation = None
        elif isinstance(use_activation, str):
            self.activation = get_tf_activation(use_activation)
        elif isinstance(config.hidden_act, str):
            self.activation = get_tf_activation(config.hidden_act)
        else:
            self.activation = config.hidden_act

        self.in_channels = in_channels
        self.out_channels = out_channels

    def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Pad and convolve `features`, then apply the optional normalization and activation."""
        features = self.convolution(self.padding(features))
        if self.normalization is not None:
            features = self.normalization(features, training=training)
        if self.activation is not None:
            features = self.activation(features)
        return features

    def build(self, input_shape=None):
        """Build the convolution and normalization sublayers under their own name scopes (once)."""
        if self.built:
            return
        self.built = True

        convolution = getattr(self, "convolution", None)
        if convolution is not None:
            with tf.name_scope(convolution.name):
                convolution.build([None, None, None, self.in_channels])

        normalization = getattr(self, "normalization", None)
        if normalization is not None and hasattr(normalization, "name"):
            with tf.name_scope(normalization.name):
                normalization.build([None, None, None, self.out_channels])
class TFMobileViTInvertedResidual(keras.layers.Layer):
    """
    Inverted residual block (MobileNetv2): https://arxiv.org/abs/1801.04381
    """

    def __init__(
        self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int, dilation: int = 1, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        # Width of the expanded representation, rounded to a multiple of 8.
        expanded_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8)

        if stride not in [1, 2]:
            raise ValueError(f"Invalid stride {stride}.")

        # The skip connection is only valid when the block preserves both resolution and channel count.
        self.use_residual = (stride == 1) and (in_channels == out_channels)

        self.expand_1x1 = TFMobileViTConvLayer(
            config, in_channels=in_channels, out_channels=expanded_channels, kernel_size=1, name="expand_1x1"
        )

        # Grouped 3x3 conv with one group per channel.
        self.conv_3x3 = TFMobileViTConvLayer(
            config,
            in_channels=expanded_channels,
            out_channels=expanded_channels,
            kernel_size=3,
            stride=stride,
            groups=expanded_channels,
            dilation=dilation,
            name="conv_3x3",
        )

        self.reduce_1x1 = TFMobileViTConvLayer(
            config,
            in_channels=expanded_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_activation=False,
            name="reduce_1x1",
        )

    def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Expand, filter and reduce `features`; add the input back when the residual is enabled."""
        block_input = features

        features = self.expand_1x1(features, training=training)
        features = self.conv_3x3(features, training=training)
        features = self.reduce_1x1(features, training=training)

        if self.use_residual:
            return block_input + features
        return features

    def build(self, input_shape=None):
        """Build the three convolution sublayers under their own name scopes (once)."""
        if self.built:
            return
        self.built = True

        for attribute_name in ("expand_1x1", "conv_3x3", "reduce_1x1"):
            sublayer = getattr(self, attribute_name, None)
            if sublayer is not None:
                with tf.name_scope(sublayer.name):
                    sublayer.build(None)
class TFMobileViTMobileNetLayer(keras.layers.Layer):
    """A sequence of `num_stages` inverted residual blocks; only the first block applies `stride`."""

    def __init__(
        self,
        config: MobileViTConfig,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        num_stages: int = 1,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        self.layers = []
        stage_in_channels = in_channels
        for stage_index in range(num_stages):
            is_first_stage = stage_index == 0
            self.layers.append(
                TFMobileViTInvertedResidual(
                    config,
                    in_channels=stage_in_channels,
                    out_channels=out_channels,
                    stride=stride if is_first_stage else 1,
                    name=f"layer.{stage_index}",
                )
            )
            # Every block after the first consumes the previous block's output channels.
            stage_in_channels = out_channels

    def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Apply the blocks in order."""
        output = features
        for block in self.layers:
            output = block(output, training=training)
        return output

    def build(self, input_shape=None):
        """Build every block under its own name scope (once)."""
        if self.built:
            return
        self.built = True

        blocks = getattr(self, "layers", None)
        if blocks is None:
            return
        for block in blocks:
            with tf.name_scope(block.name):
                block.build(None)
class TFMobileViTSelfAttention(keras.layers.Layer):
    """Multi-head scaled dot-product self-attention."""

    def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None:
        super().__init__(**kwargs)

        if hidden_size % config.num_attention_heads != 0:
            # Fixed: `{hidden_size,}` formatted a one-element tuple (e.g. "(144,)") instead of the number.
            raise ValueError(
                f"The hidden size {hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # Attention scores are divided by sqrt(head size).
        scale = tf.cast(self.attention_head_size, dtype=tf.float32)
        self.scale = tf.math.sqrt(scale)

        self.query = keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="query")
        self.key = keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="key")
        self.value = keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="value")

        self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)
        self.hidden_size = hidden_size

    def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor:
        """Reshape `x` to `(batch, -1, heads, head_size)` and move the heads axis to position 1."""
        batch_size = tf.shape(x)[0]
        x = tf.reshape(x, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
        """
        Compute self-attention over `hidden_states`.

        Returns:
            The context tensor reshaped to `(batch_size, -1, all_head_size)`.
        """
        batch_size = tf.shape(hidden_states)[0]

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(self.query(hidden_states))

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        attention_scores = attention_scores / self.scale

        # Normalize the attention scores to probabilities.
        attention_probs = stable_softmax(attention_scores, axis=-1)

        # Dropout is applied to the attention probabilities (active only when `training` is true).
        attention_probs = self.dropout(attention_probs, training=training)

        context_layer = tf.matmul(attention_probs, value_layer)

        # Move the heads axis back and merge it with the head-size axis.
        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
        context_layer = tf.reshape(context_layer, shape=(batch_size, -1, self.all_head_size))
        return context_layer

    def build(self, input_shape=None):
        """Build the query/key/value projections under their own name scopes (once)."""
        if self.built:
            return
        self.built = True
        if getattr(self, "query", None) is not None:
            with tf.name_scope(self.query.name):
                self.query.build([None, None, self.hidden_size])
        if getattr(self, "key", None) is not None:
            with tf.name_scope(self.key.name):
                self.key.build([None, None, self.hidden_size])
        if getattr(self, "value", None) is not None:
            with tf.name_scope(self.value.name):
                self.value.build([None, None, self.hidden_size])
class TFMobileViTSelfOutput(keras.layers.Layer):
    """Output projection of the attention block: a dense layer followed by dropout."""

    def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(hidden_size, name="dense")
        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
        self.hidden_size = hidden_size

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Project `hidden_states` and apply dropout (active only when `training` is true)."""
        projected = self.dense(hidden_states)
        return self.dropout(projected, training=training)

    def build(self, input_shape=None):
        """Build the dense sublayer under its own name scope (once)."""
        if self.built:
            return
        self.built = True

        dense = getattr(self, "dense", None)
        if dense is not None:
            with tf.name_scope(dense.name):
                dense.build([None, None, self.hidden_size])
class TFMobileViTAttention(keras.layers.Layer):
    """Self-attention followed by its output projection."""

    def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.attention = TFMobileViTSelfAttention(config, hidden_size, name="attention")
        self.dense_output = TFMobileViTSelfOutput(config, hidden_size, name="output")

    def prune_heads(self, heads):
        """Head pruning is not implemented for this layer."""
        raise NotImplementedError

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Run self-attention on `hidden_states`, then the output projection."""
        attended = self.attention(hidden_states, training=training)
        return self.dense_output(attended, training=training)

    def build(self, input_shape=None):
        """Build both sub-layers, each inside its own name scope."""
        if self.built:
            return
        self.built = True

        for attribute in ("attention", "dense_output"):
            sublayer = getattr(self, attribute, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build(None)
class TFMobileViTIntermediate(keras.layers.Layer):
    """Dense expansion to `intermediate_size` followed by the configured activation."""

    def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.dense = keras.layers.Dense(intermediate_size, name="dense")

        # `config.hidden_act` may be either an activation name or a callable.
        activation = config.hidden_act
        self.intermediate_act_fn = get_tf_activation(activation) if isinstance(activation, str) else activation

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        """Apply the dense layer and then the activation function."""
        return self.intermediate_act_fn(self.dense(hidden_states))

    def build(self, input_shape=None):
        """Build the dense layer for inputs whose last dimension is `self.hidden_size`."""
        if self.built:
            return
        self.built = True

        dense = getattr(self, "dense", None)
        if dense is None:
            return
        with tf.name_scope(dense.name):
            dense.build([None, None, self.hidden_size])
class TFMobileViTOutput(keras.layers.Layer):
    """Dense projection back to `hidden_size`, dropout, and a residual addition."""

    def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.intermediate_size = intermediate_size
        self.dense = keras.layers.Dense(hidden_size, name="dense")
        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Project `hidden_states`, apply dropout, and add `input_tensor` as the residual."""
        projected = self.dropout(self.dense(hidden_states), training=training)
        return projected + input_tensor

    def build(self, input_shape=None):
        """Build the dense layer for inputs whose last dimension is `self.intermediate_size`."""
        if self.built:
            return
        self.built = True

        dense = getattr(self, "dense", None)
        if dense is None:
            return
        with tf.name_scope(dense.name):
            dense.build([None, None, self.intermediate_size])
class TFMobileViTTransformerLayer(keras.layers.Layer):
    """Pre-norm transformer layer: attention and feed-forward, each with a residual connection."""

    def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.attention = TFMobileViTAttention(config, hidden_size, name="attention")
        self.intermediate = TFMobileViTIntermediate(config, hidden_size, intermediate_size, name="intermediate")
        self.mobilevit_output = TFMobileViTOutput(config, hidden_size, intermediate_size, name="output")
        self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
        self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Run the attention and feed-forward sub-blocks on `hidden_states`."""
        # Attention sub-block: layer norm is applied before attention.
        normalized = self.layernorm_before(hidden_states)
        hidden_states = self.attention(normalized, training=training) + hidden_states

        # Feed-forward sub-block; the residual is added inside `mobilevit_output`.
        feed_forward = self.intermediate(self.layernorm_after(hidden_states))
        return self.mobilevit_output(feed_forward, hidden_states, training=training)

    def build(self, input_shape=None):
        """Build every sub-layer inside its own name scope."""
        if self.built:
            return
        self.built = True

        # Composite sub-layers determine their own shapes.
        for attribute in ("attention", "intermediate", "mobilevit_output"):
            sublayer = getattr(self, attribute, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build(None)

        # The layer norms are built for a last dimension of `self.hidden_size`.
        for attribute in ("layernorm_before", "layernorm_after"):
            norm = getattr(self, attribute, None)
            if norm is None:
                continue
            with tf.name_scope(norm.name):
                norm.build([None, None, self.hidden_size])
class TFMobileViTTransformer(keras.layers.Layer):
    """A stack of `num_stages` transformer layers applied one after another."""

    def __init__(self, config: MobileViTConfig, hidden_size: int, num_stages: int, **kwargs) -> None:
        super().__init__(**kwargs)
        # The feed-forward width of each layer is `hidden_size * config.mlp_ratio`, truncated to int.
        self.layers = [
            TFMobileViTTransformerLayer(
                config,
                hidden_size=hidden_size,
                intermediate_size=int(hidden_size * config.mlp_ratio),
                name=f"layer.{index}",
            )
            for index in range(num_stages)
        ]

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Feed `hidden_states` through every layer in order."""
        output = hidden_states
        for block in self.layers:
            output = block(output, training=training)
        return output

    def build(self, input_shape=None):
        """Build each stacked layer inside its own name scope."""
        if self.built:
            return
        self.built = True

        stacked = getattr(self, "layers", None)
        if stacked is None:
            return
        for block in stacked:
            with tf.name_scope(block.name):
                block.build(None)
"""
MobileViT block: https://arxiv.org/abs/2110.02178
"""
class TFMobileViTLayer(keras.layers.Layer):
    """
    MobileViT block: https://arxiv.org/abs/2110.02178

    Optionally downsamples the input, applies a k x k and a 1 x 1 convolution, runs a transformer over
    patches unfolded from the feature map, folds the patches back into a feature map, and fuses the result
    with the features that entered the convolutions.
    """

    def __init__(
        self,
        config: MobileViTConfig,
        in_channels: int,
        out_channels: int,
        stride: int,
        hidden_size: int,
        num_stages: int,
        dilation: int = 1,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.patch_width = config.patch_size
        self.patch_height = config.patch_size

        if stride == 2:
            # When a dilation > 1 is requested, the stride is dropped to 1 and half the dilation is used.
            self.downsampling_layer = TFMobileViTInvertedResidual(
                config,
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride if dilation == 1 else 1,
                dilation=dilation // 2 if dilation > 1 else 1,
                name="downsampling_layer",
            )
            # Everything after the downsampling layer works on `out_channels` channels.
            in_channels = out_channels
        else:
            self.downsampling_layer = None

        self.conv_kxk = TFMobileViTConvLayer(
            config,
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=config.conv_kernel_size,
            name="conv_kxk",
        )

        self.conv_1x1 = TFMobileViTConvLayer(
            config,
            in_channels=in_channels,
            out_channels=hidden_size,
            kernel_size=1,
            use_normalization=False,
            use_activation=False,
            name="conv_1x1",
        )

        self.transformer = TFMobileViTTransformer(
            config, hidden_size=hidden_size, num_stages=num_stages, name="transformer"
        )

        self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")

        self.conv_projection = TFMobileViTConvLayer(
            config, in_channels=hidden_size, out_channels=in_channels, kernel_size=1, name="conv_projection"
        )

        # The fusion convolution receives the residual concatenated with the projected features.
        self.fusion = TFMobileViTConvLayer(
            config,
            in_channels=2 * in_channels,
            out_channels=in_channels,
            kernel_size=config.conv_kernel_size,
            name="fusion",
        )
        self.hidden_size = hidden_size

    def unfolding(self, features: tf.Tensor) -> Tuple[tf.Tensor, Dict]:
        """Split a feature map into patches.

        Args:
            features: Feature map read as `(batch, height, width, channels)`.

        Returns:
            A tuple `(patches, info_dict)`. `patches` has shape
            `(batch_size * patch_area, num_patches, channels)`; `info_dict` stores the sizes that
            `folding` needs to restore the feature map.
        """
        patch_width, patch_height = self.patch_width, self.patch_height
        patch_area = tf.cast(patch_width * patch_height, "int32")

        batch_size = tf.shape(features)[0]
        orig_height = tf.shape(features)[1]
        orig_width = tf.shape(features)[2]
        channels = tf.shape(features)[3]

        # Round the spatial size up to the next multiple of the patch size.
        new_height = tf.cast(tf.math.ceil(orig_height / patch_height) * patch_height, "int32")
        new_width = tf.cast(tf.math.ceil(orig_width / patch_width) * patch_width, "int32")

        interpolate = new_width != orig_width or new_height != orig_height
        if interpolate:
            # Note: Padding can be done, but then it needs to be handled in attention function.
            features = tf.image.resize(features, size=(new_height, new_width), method="bilinear")

        # Number of patches along width and height.
        num_patch_width = new_width // patch_width
        num_patch_height = new_height // patch_height
        num_patches = num_patch_height * num_patch_width

        # Move channels to axis 1, then regroup the spatial axes into patches.
        features = tf.transpose(features, [0, 3, 1, 2])
        patches = tf.reshape(
            features, (batch_size * channels * num_patch_height, patch_height, num_patch_width, patch_width)
        )
        patches = tf.transpose(patches, [0, 2, 1, 3])
        patches = tf.reshape(patches, (batch_size, channels, num_patches, patch_area))
        patches = tf.transpose(patches, [0, 3, 2, 1])
        patches = tf.reshape(patches, (batch_size * patch_area, num_patches, channels))

        info_dict = {
            "orig_size": (orig_height, orig_width),
            "batch_size": batch_size,
            "channels": channels,
            "interpolate": interpolate,
            "num_patches": num_patches,
            "num_patches_width": num_patch_width,
            "num_patches_height": num_patch_height,
        }
        return patches, info_dict

    def folding(self, patches: tf.Tensor, info_dict: Dict) -> tf.Tensor:
        """Reassemble patches produced by `unfolding` into a feature map.

        Args:
            patches: Patch tensor, reshaped here to `(batch_size, patch_area, num_patches, -1)`.
            info_dict: The dictionary returned by `unfolding`.

        Returns:
            The feature map with channels on the last axis, resized back to `info_dict["orig_size"]`
            when `unfolding` had to interpolate.
        """
        patch_width, patch_height = self.patch_width, self.patch_height
        patch_area = int(patch_width * patch_height)

        batch_size = info_dict["batch_size"]
        channels = info_dict["channels"]
        num_patches = info_dict["num_patches"]
        num_patch_height = info_dict["num_patches_height"]
        num_patch_width = info_dict["num_patches_width"]

        # Undo the reshapes and transposes of `unfolding` in reverse order.
        features = tf.reshape(patches, (batch_size, patch_area, num_patches, -1))
        features = tf.transpose(features, perm=(0, 3, 2, 1))
        features = tf.reshape(
            features, (batch_size * channels * num_patch_height, num_patch_width, patch_height, patch_width)
        )
        features = tf.transpose(features, perm=(0, 2, 1, 3))
        features = tf.reshape(
            features, (batch_size, channels, num_patch_height * patch_height, num_patch_width * patch_width)
        )
        features = tf.transpose(features, perm=(0, 2, 3, 1))

        if info_dict["interpolate"]:
            features = tf.image.resize(features, size=info_dict["orig_size"], method="bilinear")

        return features

    def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Apply the block to `features` and return the fused feature map."""
        # Reduce spatial dimensions if needed.
        if self.downsampling_layer:
            features = self.downsampling_layer(features, training=training)

        residual = features

        # Local representation.
        features = self.conv_kxk(features, training=training)
        features = self.conv_1x1(features, training=training)

        # Convert the feature map to patches.
        patches, info_dict = self.unfolding(features)

        # Learn global representations.
        patches = self.transformer(patches, training=training)
        patches = self.layernorm(patches)

        # Convert the patches back to a feature map.
        features = self.folding(patches, info_dict)

        features = self.conv_projection(features, training=training)
        features = self.fusion(tf.concat([residual, features], axis=-1), training=training)
        return features

    def build(self, input_shape=None):
        """Build every sub-layer inside its own name scope."""
        if self.built:
            return
        self.built = True
        if getattr(self, "conv_kxk", None) is not None:
            with tf.name_scope(self.conv_kxk.name):
                self.conv_kxk.build(None)
        if getattr(self, "conv_1x1", None) is not None:
            with tf.name_scope(self.conv_1x1.name):
                self.conv_1x1.build(None)
        if getattr(self, "transformer", None) is not None:
            with tf.name_scope(self.transformer.name):
                self.transformer.build(None)
        if getattr(self, "layernorm", None) is not None:
            with tf.name_scope(self.layernorm.name):
                self.layernorm.build([None, None, self.hidden_size])
        if getattr(self, "conv_projection", None) is not None:
            with tf.name_scope(self.conv_projection.name):
                self.conv_projection.build(None)
        if getattr(self, "fusion", None) is not None:
            with tf.name_scope(self.fusion.name):
                self.fusion.build(None)
        # `downsampling_layer` is None when the block was created with a stride other than 2.
        if getattr(self, "downsampling_layer", None) is not None:
            with tf.name_scope(self.downsampling_layer.name):
                self.downsampling_layer.build(None)
class TFMobileViTEncoder(keras.layers.Layer):
    """Five-stage encoder: two MobileNet-style layers followed by three MobileViT layers."""

    def __init__(self, config: MobileViTConfig, **kwargs) -> None:
        super().__init__(**kwargs)
        self.config = config

        self.layers = []

        # `config.output_stride` selects which of the last two stages use dilation:
        # 8 dilates stages 4 and 5, 16 dilates only stage 5, anything else dilates none.
        dilate_layer_4 = dilate_layer_5 = False
        if config.output_stride == 8:
            dilate_layer_4 = True
            dilate_layer_5 = True
        elif config.output_stride == 16:
            dilate_layer_5 = True

        dilation = 1

        layer_1 = TFMobileViTMobileNetLayer(
            config,
            in_channels=config.neck_hidden_sizes[0],
            out_channels=config.neck_hidden_sizes[1],
            stride=1,
            num_stages=1,
            name="layer.0",
        )
        self.layers.append(layer_1)

        layer_2 = TFMobileViTMobileNetLayer(
            config,
            in_channels=config.neck_hidden_sizes[1],
            out_channels=config.neck_hidden_sizes[2],
            stride=2,
            num_stages=3,
            name="layer.1",
        )
        self.layers.append(layer_2)

        layer_3 = TFMobileViTLayer(
            config,
            in_channels=config.neck_hidden_sizes[2],
            out_channels=config.neck_hidden_sizes[3],
            stride=2,
            hidden_size=config.hidden_sizes[0],
            num_stages=2,
            name="layer.2",
        )
        self.layers.append(layer_3)

        if dilate_layer_4:
            dilation *= 2

        layer_4 = TFMobileViTLayer(
            config,
            in_channels=config.neck_hidden_sizes[3],
            out_channels=config.neck_hidden_sizes[4],
            stride=2,
            hidden_size=config.hidden_sizes[1],
            num_stages=4,
            dilation=dilation,
            name="layer.3",
        )
        self.layers.append(layer_4)

        if dilate_layer_5:
            dilation *= 2

        layer_5 = TFMobileViTLayer(
            config,
            in_channels=config.neck_hidden_sizes[4],
            out_channels=config.neck_hidden_sizes[5],
            stride=2,
            hidden_size=config.hidden_sizes[2],
            num_stages=3,
            dilation=dilation,
            name="layer.4",
        )
        self.layers.append(layer_5)

    def call(
        self,
        hidden_states: tf.Tensor,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        training: bool = False,
    ) -> Union[tuple, TFBaseModelOutput]:
        """Run all stages in order.

        Args:
            hidden_states: Input to the first stage.
            output_hidden_states: When true, the output of every stage is collected and returned.
            return_dict: When true, return a `TFBaseModelOutput`; otherwise a plain tuple.
            training: Forwarded to every stage.

        Returns:
            A `TFBaseModelOutput`, or a tuple `(last_hidden_state[, all_hidden_states])` in which
            `all_hidden_states` is present only when it was requested.
        """
        all_hidden_states = () if output_hidden_states else None

        for layer_module in self.layers:
            hidden_states = layer_module(hidden_states, training=training)

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return TFBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)

    def build(self, input_shape=None):
        """Build every stage inside its own name scope."""
        if self.built:
            return
        self.built = True
        if getattr(self, "layers", None) is not None:
            for layer_module in self.layers:
                with tf.name_scope(layer_module.name):
                    layer_module.build(None)
@keras_serializable
class TFMobileViTMainLayer(keras.layers.Layer):
    """Convolutional stem, encoder and, when `expand_output` is set, a 1x1 expansion conv plus pooling."""

    config_class = MobileViTConfig

    def __init__(self, config: MobileViTConfig, expand_output: bool = True, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.expand_output = expand_output

        self.conv_stem = TFMobileViTConvLayer(
            config,
            in_channels=config.num_channels,
            out_channels=config.neck_hidden_sizes[0],
            kernel_size=3,
            stride=2,
            name="conv_stem",
        )

        self.encoder = TFMobileViTEncoder(config, name="encoder")

        if self.expand_output:
            self.conv_1x1_exp = TFMobileViTConvLayer(
                config,
                in_channels=config.neck_hidden_sizes[5],
                out_channels=config.neck_hidden_sizes[6],
                kernel_size=1,
                name="conv_1x1_exp",
            )

        self.pooler = keras.layers.GlobalAveragePooling2D(data_format="channels_first", name="pooler")

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        raise NotImplementedError

    @unpack_inputs
    def call(
        self,
        pixel_values: tf.Tensor | None = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[Tuple[tf.Tensor], TFBaseModelOutputWithPooling]:
        """Encode `pixel_values`.

        Args:
            pixel_values: Input images; axis 1 is moved to the last position before the stem is applied.
            output_hidden_states: Whether to return the per-stage encoder outputs. Defaults to
                `config.output_hidden_states`.
            return_dict: Whether to return a `TFBaseModelOutputWithPooling` instead of a tuple. Defaults to
                `config.use_return_dict`.
            training: Forwarded to the stem and the encoder.

        Returns:
            A `TFBaseModelOutputWithPooling`, or a tuple starting with the last hidden state, followed by the
            pooled output (only when `expand_output` is set) and any hidden states the encoder returned.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The convolution layers run channels-last, so move the channel axis to the end.
        pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))

        embedding_output = self.conv_stem(pixel_values, training=training)

        encoder_outputs = self.encoder(
            embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
        )

        if self.expand_output:
            last_hidden_state = self.conv_1x1_exp(encoder_outputs[0])

            # Change to channels-first so the output format matches the input format.
            last_hidden_state = tf.transpose(last_hidden_state, perm=[0, 3, 1, 2])

            # Global average pooling: (batch_size, channels, height, width) -> (batch_size, channels)
            pooled_output = self.pooler(last_hidden_state)
        else:
            last_hidden_state = encoder_outputs[0]
            # Change to channels-first so the output format matches the input format.
            last_hidden_state = tf.transpose(last_hidden_state, perm=[0, 3, 1, 2])
            pooled_output = None

        if not return_dict:
            output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,)

            if not self.expand_output:
                remaining_encoder_outputs = encoder_outputs[1:]
                # The encoder tuple has no second element when hidden states were not requested;
                # indexing it unconditionally would raise IndexError.
                if remaining_encoder_outputs:
                    remaining_encoder_outputs = (
                        tuple(tf.transpose(h, perm=(0, 3, 1, 2)) for h in remaining_encoder_outputs[0]),
                    )
                return output + remaining_encoder_outputs
            else:
                return output + encoder_outputs[1:]

        # Change the hidden states to channels-first as well.
        if output_hidden_states:
            hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])

        return TFBaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
        )

    def build(self, input_shape=None):
        """Build the stem, encoder, pooler and (if present) the expansion conv in their own name scopes."""
        if self.built:
            return
        self.built = True
        if getattr(self, "conv_stem", None) is not None:
            with tf.name_scope(self.conv_stem.name):
                self.conv_stem.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "pooler", None) is not None:
            with tf.name_scope(self.pooler.name):
                self.pooler.build([None, None, None, None])
        # `conv_1x1_exp` only exists when the layer was created with `expand_output=True`.
        if getattr(self, "conv_1x1_exp", None) is not None:
            with tf.name_scope(self.conv_1x1_exp.name):
                self.conv_1x1_exp.build(None)
"""
Documentation string defining the format of inputs accepted by models and layers in the MobileViT architecture.
It explains the two supported input formats: keyword arguments and positional list/tuple/dict for input tensors.
When using TensorFlow 2.0 Keras methods like `model.fit()`, the second format (list, tuple, dict) is preferred.
This enables flexibility in passing inputs such as `pixel_values`, `attention_mask`, and `token_type_ids`.
For Keras Functional API or subclassing, inputs can be:
- A single tensor: `model(pixel_values)`
- A list of tensors: `model([pixel_values, attention_mask])`
- A dictionary of tensors: `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`
This documentation guides users on how to interface with MobileViT models and layers effectively.
Parameters:
config ([`MobileViTConfig`]): Configuration class containing all model parameters.
Loading weights requires using [`~TFPreTrainedModel.from_pretrained`], which initializes the model with weights.
"""
"""
MobileViT model outputting raw hidden-states without any specific head on top.
This class defines a MobileViT model without any task-specific output head.
MOBILEVIT_START_DOCSTRING: an example docstring whose specific content is not provided here.
"""
class TFMobileViTModel(TFMobileViTPreTrainedModel):
    """Bare MobileViT model: a thin wrapper that delegates to `TFMobileViTMainLayer`."""

    def __init__(self, config: MobileViTConfig, expand_output: bool = True, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.config = config
        self.expand_output = expand_output

        self.mobilevit = TFMobileViTMainLayer(config, expand_output=expand_output, name="mobilevit")

    # No docstring here on purpose: the decorators below generate the documentation of `call`.
    @unpack_inputs
    @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def call(
        self,
        pixel_values: tf.Tensor | None = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[Tuple[tf.Tensor], TFBaseModelOutputWithPooling]:
        # Everything is forwarded unchanged to the main layer.
        return self.mobilevit(pixel_values, output_hidden_states, return_dict, training=training)

    def build(self, input_shape=None):
        """Build the wrapped main layer inside its own name scope."""
        if self.built:
            return
        self.built = True

        main_layer = getattr(self, "mobilevit", None)
        if main_layer is None:
            return
        with tf.name_scope(main_layer.name):
            main_layer.build(None)
"""
MobileViT model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
ImageNet.
This class defines a MobileViT model with an image classification head, e.g. for ImageNet.
MOBILEVIT_START_DOCSTRING: an example docstring whose specific content is not provided here.
"""
class TFMobileViTForImageClassification(TFMobileViTPreTrainedModel, TFSequenceClassificationLoss):
    """MobileViT main layer followed by dropout and a classifier applied to the pooled output."""

    def __init__(self, config: MobileViTConfig, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels
        self.mobilevit = TFMobileViTMainLayer(config, name="mobilevit")

        # Classifier head; falls back to the identity function when no labels are configured.
        self.dropout = keras.layers.Dropout(config.classifier_dropout_prob)
        self.classifier = (
            keras.layers.Dense(config.num_labels, name="classifier") if config.num_labels > 0 else tf.identity
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=TFImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def call(
        self,
        pixel_values: tf.Tensor | None = None,
        output_hidden_states: Optional[bool] = None,
        labels: tf.Tensor | None = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[tuple, TFImageClassifierOutputWithNoAttention]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.mobilevit(
            pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
        )

        # The pooled output is the second element of the tuple form.
        pooled_output = outputs.pooler_output if return_dict else outputs[1]

        logits = self.classifier(self.dropout(pooled_output, training=training))
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)

    def build(self, input_shape=None):
        """Build the main layer and, when it is a real layer, the classifier."""
        if self.built:
            return
        self.built = True
        if getattr(self, "mobilevit", None) is not None:
            with tf.name_scope(self.mobilevit.name):
                self.mobilevit.build(None)
        if getattr(self, "classifier", None) is not None:
            # `tf.identity` (used when `config.num_labels` is not positive) has no `name` and needs no build.
            if hasattr(self.classifier, "name"):
                with tf.name_scope(self.classifier.name):
                    self.classifier.build([None, None, self.config.neck_hidden_sizes[-1]])
class TFMobileViTASPPPooling(keras.layers.Layer):
    """Global average pooling, a 1x1 convolution, and a resize back to the input's spatial size."""

    def __init__(self, config: MobileViTConfig, in_channels: int, out_channels: int, **kwargs) -> None:
        super().__init__(**kwargs)

        self.global_pool = keras.layers.GlobalAveragePooling2D(keepdims=True, name="global_pool")

        self.conv_1x1 = TFMobileViTConvLayer(
            config,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            use_normalization=True,
            use_activation="relu",
            name="conv_1x1",
        )

    def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Pool `features` globally, convolve, and resize back to the original height and width."""
        # Remember the spatial axes (everything between the first and the last axis).
        target_size = shape_list(features)[1:-1]
        pooled = self.conv_1x1(self.global_pool(features), training=training)
        return tf.image.resize(pooled, size=target_size, method="bilinear")

    def build(self, input_shape=None):
        """Build the pooling layer and the convolution, each inside its own name scope."""
        if self.built:
            return
        self.built = True

        pool = getattr(self, "global_pool", None)
        if pool is not None:
            with tf.name_scope(pool.name):
                pool.build([None, None, None, None])

        conv = getattr(self, "conv_1x1", None)
        if conv is not None:
            with tf.name_scope(conv.name):
                conv.build(None)
class TFMobileViTASPP(keras.layers.Layer):
    """
    ASPP module defined in DeepLab papers: https://arxiv.org/abs/1606.00915, https://arxiv.org/abs/1706.05587

    Five parallel branches (a 1x1 conv, three dilated 3x3 convs and a pooling branch) are concatenated and
    projected back to `config.aspp_out_channels` channels.
    """

    def __init__(self, config: MobileViTConfig, **kwargs) -> None:
        super().__init__(**kwargs)

        in_channels = config.neck_hidden_sizes[-2]
        out_channels = config.aspp_out_channels

        if len(config.atrous_rates) != 3:
            raise ValueError("Expected 3 values for atrous_rates")

        # Branch 0: plain 1x1 projection.
        branches = [
            TFMobileViTConvLayer(
                config,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                use_activation="relu",
                name="convs.0",
            )
        ]

        # Branches 1..3: one dilated 3x3 convolution per atrous rate.
        for position, rate in enumerate(config.atrous_rates, start=1):
            branches.append(
                TFMobileViTConvLayer(
                    config,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    dilation=rate,
                    use_activation="relu",
                    name=f"convs.{position}",
                )
            )

        # Last branch: global pooling.
        branches.append(
            TFMobileViTASPPPooling(
                config, in_channels, out_channels, name=f"convs.{len(config.atrous_rates) + 1}"
            )
        )
        self.convs = branches

        # The projection consumes the concatenation of all five branches.
        self.project = TFMobileViTConvLayer(
            config,
            in_channels=5 * out_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_activation="relu",
            name="project",
        )

        self.dropout = keras.layers.Dropout(config.aspp_dropout_prob)

    def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Run every branch on `features`, concatenate on the last axis, project, and apply dropout."""
        # Move axis 1 to the end: the convolution layers run channels-last while the
        # incoming hidden state is channels-first.
        channels_last = tf.transpose(features, perm=[0, 2, 3, 1])
        branch_outputs = [branch(channels_last, training=training) for branch in self.convs]
        stacked = tf.concat(branch_outputs, axis=-1)

        projected = self.project(stacked, training=training)
        return self.dropout(projected, training=training)

    def build(self, input_shape=None):
        """Build the projection and every branch inside their own name scopes."""
        if self.built:
            return
        self.built = True

        projection = getattr(self, "project", None)
        if projection is not None:
            with tf.name_scope(projection.name):
                projection.build(None)

        branches = getattr(self, "convs", None)
        if branches is not None:
            for branch in branches:
                with tf.name_scope(branch.name):
                    branch.build(None)
class TFMobileViTDeepLabV3(keras.layers.Layer):
    """
    DeepLabv3 architecture: https://arxiv.org/abs/1706.05587

    Applies ASPP to the last encoder hidden state, then dropout and a 1x1 classifier convolution.
    """

    def __init__(self, config: MobileViTConfig, **kwargs) -> None:
        super().__init__(**kwargs)
        self.aspp = TFMobileViTASPP(config, name="aspp")

        self.dropout = keras.layers.Dropout(config.classifier_dropout_prob)

        # Per-pixel classifier: a biased 1x1 convolution without normalization or activation.
        self.classifier = TFMobileViTConvLayer(
            config,
            in_channels=config.aspp_out_channels,
            out_channels=config.num_labels,
            kernel_size=1,
            use_normalization=False,
            use_activation=False,
            bias=True,
            name="classifier",
        )

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Compute segmentation logits from the last element of `hidden_states`."""
        pooled = self.aspp(hidden_states[-1], training=training)
        regularized = self.dropout(pooled, training=training)
        return self.classifier(regularized, training=training)

    def build(self, input_shape=None):
        """Build the ASPP module and the classifier inside their own name scopes."""
        if self.built:
            return
        self.built = True

        for attribute in ("aspp", "classifier"):
            sublayer = getattr(self, attribute, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build(None)
@add_start_docstrings(
    """
    MobileViT model with a semantic segmentation head on top, e.g. for Pascal VOC.
    """,
    MOBILEVIT_START_DOCSTRING,
)
class TFMobileViTForSemanticSegmentation(TFMobileViTPreTrainedModel):
    def __init__(self, config: MobileViTConfig, **kwargs) -> None:
        super().__init__(config, **kwargs)

        self.num_labels = config.num_labels
        # The segmentation head consumes the raw encoder output, so the 1x1 expansion conv is disabled.
        self.mobilevit = TFMobileViTMainLayer(config, expand_output=False, name="mobilevit")
        self.segmentation_head = TFMobileViTDeepLabV3(config, name="segmentation_head")

    def hf_compute_loss(self, logits, labels):
        """Compute the masked sparse categorical cross-entropy between `logits` and `labels`.

        The logits are resized bilinearly to the spatial size of `labels` (every axis of `labels`
        after the first). Positions whose label equals `config.semantic_loss_ignore_index` do not
        contribute to the loss. Returns a tensor of shape `(1,)`.
        """
        # Upsample the logits to the size of the labels.
        # `labels` is of shape (batch_size, height, width).
        label_interp_shape = shape_list(labels)[1:]

        upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
        # Per-element loss (no reduction) so the ignore mask can be applied before averaging.
        loss_fct = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")

        def masked_loss(real, pred):
            unmasked_loss = loss_fct(real, pred)
            mask = tf.cast(real != self.config.semantic_loss_ignore_index, dtype=unmasked_loss.dtype)
            weighted_loss = unmasked_loss * mask
            # Average over the positions that were not ignored.
            reduced_masked_loss = tf.reduce_sum(weighted_loss) / tf.reduce_sum(mask)
            return tf.reshape(reduced_masked_loss, (1,))

        return masked_loss(labels, upsampled_logits)

    @unpack_inputs
    @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFSemanticSegmenterOutputWithNoAttention, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        pixel_values: tf.Tensor | None = None,
        labels: tf.Tensor | None = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[tuple, TFSemanticSegmenterOutputWithNoAttention]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).

        Returns:

        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.mobilevit(
            pixel_values,
            output_hidden_states=True,  # the segmentation head always needs the intermediate hidden states
            return_dict=return_dict,
            training=training,
        )

        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]

        logits = self.segmentation_head(encoder_hidden_states, training=training)

        loss = None
        if labels is not None:
            if not self.config.num_labels > 1:
                raise ValueError("The number of labels should be greater than one")
            else:
                loss = self.hf_compute_loss(logits=logits, labels=labels)

        # Make the logits channels-first; this happens after the loss, which resizes channels-last logits.
        logits = tf.transpose(logits, perm=[0, 3, 1, 2])

        if not return_dict:
            # Hidden states were always computed; expose them only when the caller asked for them.
            if output_hidden_states:
                output = (logits,) + outputs[1:]
            else:
                output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFSemanticSegmenterOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
        )

    def build(self, input_shape=None):
        """Build the main layer and the segmentation head inside their own name scopes."""
        if self.built:
            return
        self.built = True
        if getattr(self, "mobilevit", None) is not None:
            with tf.name_scope(self.mobilevit.name):
                self.mobilevit.build(None)
        if getattr(self, "segmentation_head", None) is not None:
            with tf.name_scope(self.segmentation_head.name):
                self.segmentation_head.build(None)