Transformers Source Code Analysis (99)
.\models\sam\image_processing_sam.py
"""SAM 的图像处理类。"""
import math
from copy import deepcopy
from itertools import product
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import convert_to_rgb, pad, resize, to_channel_dimension_format
from ...image_utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
validate_kwargs,
validate_preprocess_arguments,
)
from ...utils import (
TensorType,
is_tf_available,
is_torch_available,
is_torchvision_available,
logging,
requires_backends,
)
if is_torch_available():
import torch
import torch.nn.functional as F
if is_torchvision_available():
from torchvision.ops.boxes import batched_nms
if is_tf_available():
import tensorflow as tf
from tensorflow.experimental import numpy as tnp
from ...tf_utils import flatten, shape_list
logger = logging.get_logger(__name__)
class SamImageProcessor(BaseImageProcessor):
r"""
    Constructs a SAM image processor.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
mask_size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: bool = True,
pad_size: int = None,
mask_pad_size: int = None,
do_convert_rgb: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"longest_edge": 1024}
size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size
pad_size = pad_size if pad_size is not None else {"height": 1024, "width": 1024}
pad_size = get_size_dict(pad_size, default_to_square=True)
mask_size = mask_size if mask_size is not None else {"longest_edge": 256}
mask_size = (
get_size_dict(max_size=mask_size, default_to_square=False)
if not isinstance(mask_size, dict)
else mask_size
)
mask_pad_size = mask_pad_size if mask_pad_size is not None else {"height": 256, "width": 256}
mask_pad_size = get_size_dict(mask_pad_size, default_to_square=True)
self.do_resize = do_resize
self.size = size
self.mask_size = mask_size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_pad = do_pad
self.pad_size = pad_size
self.mask_pad_size = mask_pad_size
self.do_convert_rgb = do_convert_rgb
self._valid_processor_keys = [
"images",
"segmentation_maps",
"do_resize",
"size",
"mask_size",
"resample",
"do_rescale",
"rescale_factor",
"do_normalize",
"image_mean",
"image_std",
"do_pad",
"pad_size",
"mask_pad_size",
"do_convert_rgb",
"return_tensors",
"data_format",
"input_data_format",
]
def pad_image(
self,
image: np.ndarray,
pad_size: Dict[str, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Pad an image to `(pad_size["height"], pad_size["width"])` with zeros to the right and bottom.
Args:
image (`np.ndarray`):
Image to pad.
pad_size (`Dict[str, int]`):
Size of the output image after padding.
data_format (`str` or `ChannelDimension`, *optional*):
The data format of the image. Can be either "channels_first" or "channels_last". If `None`, the
`data_format` of the `image` will be used.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
output_height, output_width = pad_size["height"], pad_size["width"]
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
pad_width = output_width - input_width
pad_height = output_height - input_height
padded_image = pad(
image,
((0, pad_height), (0, pad_width)),
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
return padded_image
def _get_preprocess_shape(self, old_shape: Tuple[int, int], longest_edge: int):
"""
Compute the output size given input size and target long side length.
"""
oldh, oldw = old_shape
scale = longest_edge * 1.0 / max(oldh, oldw)
newh, neww = oldh * scale, oldw * scale
newh = int(newh + 0.5)
neww = int(neww + 0.5)
return (newh, neww)
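    # Annotation (illustrative): for old_shape=(480, 640) and longest_edge=1024, the scale is
    # 1024 / 640 = 1.6, so this returns (768, 1024); the longest edge is mapped exactly to 1024.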
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
    ) -> np.ndarray:
"""
Resize an image to `(size["height"], size["width"])`.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Dictionary in the format `{"longest_edge": int}` specifying the size of the output image. The longest
edge of the image will be resized to the specified size, while the other edge will be resized to
maintain the aspect ratio.
resample:
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The resized image.
"""
size = get_size_dict(size)
if "longest_edge" not in size:
raise ValueError(f"The `size` dictionary must contain the key `longest_edge`. Got {size.keys()}")
input_size = get_image_size(image, channel_dim=input_data_format)
output_height, output_width = self._get_preprocess_shape(input_size, size["longest_edge"])
return resize(
image,
size=(output_height, output_width),
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def _preprocess(
self,
image: ImageInput,
do_resize: bool,
do_rescale: bool,
do_normalize: bool,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = None,
rescale_factor: Optional[float] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
pad_size: Optional[Dict[str, int]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
if do_resize:
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
reshaped_input_size = get_image_size(image, channel_dim=input_data_format)
if do_rescale:
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
if do_normalize:
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
if do_pad:
image = self.pad_image(image=image, pad_size=pad_size, input_data_format=input_data_format)
return image, reshaped_input_size
def _preprocess_image(
self,
image: ImageInput,
do_resize: Optional[bool] = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_rescale: bool = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
pad_size: Optional[Dict[str, int]] = None,
do_convert_rgb: Optional[bool] = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:
image = to_numpy_array(image)
if do_convert_rgb:
image = convert_to_rgb(image)
image = to_numpy_array(image)
if is_scaled_image(image) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
original_size = get_image_size(image, channel_dim=input_data_format)
image, reshaped_input_size = self._preprocess(
image=image,
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_pad=do_pad,
pad_size=pad_size,
input_data_format=input_data_format,
)
if data_format is not None:
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
return image, original_size, reshaped_input_size
def _preprocess_mask(
self,
segmentation_map: ImageInput,
do_resize: Optional[bool] = None,
mask_size: Dict[str, int] = None,
do_pad: Optional[bool] = None,
mask_pad_size: Optional[Dict[str, int]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
segmentation_map = to_numpy_array(segmentation_map)
if segmentation_map.ndim == 2:
added_channel_dim = True
segmentation_map = segmentation_map[None, ...]
input_data_format = ChannelDimension.FIRST
else:
added_channel_dim = False
if input_data_format is None:
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
original_size = get_image_size(segmentation_map, channel_dim=input_data_format)
segmentation_map, _ = self._preprocess(
image=segmentation_map,
do_resize=do_resize,
size=mask_size,
resample=PILImageResampling.NEAREST,
do_rescale=False,
do_normalize=False,
do_pad=do_pad,
pad_size=mask_pad_size,
input_data_format=input_data_format,
)
if added_channel_dim:
segmentation_map = segmentation_map.squeeze(0)
segmentation_map = segmentation_map.astype(np.int64)
return segmentation_map, original_size
def preprocess(
self,
images: ImageInput,
segmentation_maps: Optional[ImageInput] = None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
mask_size: Optional[Dict[str, int]] = None,
resample: Optional["PILImageResampling"] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[Union[int, float]] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
pad_size: Optional[Dict[str, int]] = None,
mask_pad_size: Optional[Dict[str, int]] = None,
do_convert_rgb: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
):
pass
def post_process_masks(
self,
masks,
original_sizes,
reshaped_input_sizes,
mask_threshold=0.0,
binarize=True,
pad_size=None,
return_tensors="pt",
**kwargs,
):
"""
Remove padding and upscale masks to the original image size.
Args:
masks (`Union[List[torch.Tensor], List[np.ndarray], List[tf.Tensor]]`):
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
original_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`):
The original sizes of each image before it was resized to the model's expected input shape, in (height,
width) format.
reshaped_input_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`):
The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
mask_threshold (`float`, *optional*, defaults to 0.0):
The threshold to use for binarizing the masks.
binarize (`bool`, *optional*, defaults to `True`):
Whether to binarize the masks.
pad_size (`int`, *optional*, defaults to `self.pad_size`):
The target size the images were padded to before being passed to the model. If None, the target size is
assumed to be the processor's `pad_size`.
return_tensors (`str`, *optional*, defaults to `"pt"`):
If `"pt"`, return PyTorch tensors. If `"tf"`, return TensorFlow tensors.
Returns:
            (`Union[torch.Tensor, tf.Tensor]`): Batched masks in `(batch_size, num_channels, height, width)` format,
            where (height, width) is given by original_size.
"""
if return_tensors == "pt":
return self._post_process_masks_pt(
masks=masks,
original_sizes=original_sizes,
reshaped_input_sizes=reshaped_input_sizes,
mask_threshold=mask_threshold,
binarize=binarize,
pad_size=pad_size,
)
elif return_tensors == "tf":
return self._post_process_masks_tf(
masks=masks,
original_sizes=original_sizes,
reshaped_input_sizes=reshaped_input_sizes,
mask_threshold=mask_threshold,
binarize=binarize,
pad_size=pad_size,
)
else:
raise ValueError("return_tensors must be either 'pt' or 'tf'")
def _post_process_masks_pt(
self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
):
"""
Remove padding and upscale masks to the original image size.
Args:
masks (`Union[List[torch.Tensor], List[np.ndarray]]`):
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
The original sizes of each image before it was resized to the model's expected input shape, in (height,
width) format.
reshaped_input_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
mask_threshold (`float`, *optional*, defaults to 0.0):
The threshold to use for binarizing the masks.
binarize (`bool`, *optional*, defaults to `True`):
Whether to binarize the masks.
pad_size (`int`, *optional*, defaults to `self.pad_size`):
The target size the images were padded to before being passed to the model. If None, the target size is
assumed to be the processor's `pad_size`.
Returns:
            (`torch.Tensor`): Batched masks in `(batch_size, num_channels, height, width)` format, where
            (height, width) is given by original_size.
"""
requires_backends(self, ["torch"])
pad_size = self.pad_size if pad_size is None else pad_size
target_image_size = (pad_size["height"], pad_size["width"])
if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
original_sizes = original_sizes.tolist()
if isinstance(reshaped_input_sizes, (torch.Tensor, np.ndarray)):
reshaped_input_sizes = reshaped_input_sizes.tolist()
output_masks = []
for i, original_size in enumerate(original_sizes):
if isinstance(masks[i], np.ndarray):
masks[i] = torch.from_numpy(masks[i])
elif not isinstance(masks[i], torch.Tensor):
raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
interpolated_mask = F.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False)
interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
interpolated_mask = F.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)
if binarize:
interpolated_mask = interpolated_mask > mask_threshold
output_masks.append(interpolated_mask)
return output_masks
    def _post_process_masks_tf(
        self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
    ):
"""
Remove padding and upscale masks to the original image size.
Args:
masks (`tf.Tensor`):
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
original_sizes (`tf.Tensor`):
The original size of the images before resizing for input to the model, in (height, width) format.
reshaped_input_sizes (`tf.Tensor`):
The size of the image input to the model, in (height, width) format. Used to remove padding.
mask_threshold (`float`, *optional*, defaults to 0.0):
The threshold to use for binarizing the masks.
binarize (`bool`, *optional*, defaults to `True`):
Whether to binarize the masks.
pad_size (`int`, *optional*, defaults to `self.pad_size`):
The target size the images were padded to before being passed to the model. If None, the target size is
assumed to be the processor's `pad_size`.
Returns:
            (`tf.Tensor`): Batched masks in `(batch_size, num_channels, height, width)` format, where (height, width)
            is given by original_size.
"""
requires_backends(self, ["tf"])
pad_size = self.pad_size if pad_size is None else pad_size
target_image_size = (pad_size["height"], pad_size["width"])
output_masks = []
for i, original_size in enumerate(original_sizes):
mask = tf.transpose(masks[i], perm=[0, 2, 3, 1])
interpolated_mask = tf.image.resize(mask, target_image_size, method="bilinear")
interpolated_mask = interpolated_mask[:, : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1], :]
interpolated_mask = tf.image.resize(interpolated_mask, original_size, method="bilinear")
if binarize:
interpolated_mask = interpolated_mask > mask_threshold
output_masks.append(tf.transpose(interpolated_mask, perm=[0, 3, 1, 2]))
return output_masks
    def post_process_for_mask_generation(
        self, all_masks, all_scores, all_boxes, crops_nms_thresh, return_tensors="pt"
    ):
"""
        Post-processes the generated masks by applying Non Maximum Suppression (NMS) to the predicted masks.
Args:
all_masks (`Union[List[torch.Tensor], List[tf.Tensor]]`):
List of all predicted segmentation masks
all_scores (`Union[List[torch.Tensor], List[tf.Tensor]]`):
List of all predicted iou scores
all_boxes (`Union[List[torch.Tensor], List[tf.Tensor]]`):
List of all bounding boxes of the predicted masks
crops_nms_thresh (`float`):
Threshold for NMS (Non Maximum Suppression) algorithm.
return_tensors (`str`, *optional*, defaults to `pt`):
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
"""
if return_tensors == "pt":
return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh)
elif return_tensors == "tf":
return _postprocess_for_mg_tf(all_masks, all_scores, all_boxes, crops_nms_thresh)
def generate_crop_boxes(
self,
image,
target_size,
crop_n_layers: int = 0,
overlap_ratio: float = 512 / 1500,
points_per_crop: Optional[int] = 32,
crop_n_points_downscale_factor: Optional[List[int]] = 1,
device: Optional["torch.device"] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
return_tensors: str = "pt",
):
"""
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
Args:
image (`np.array`):
Input original image
target_size (`int`):
Target size of the resized image
crop_n_layers (`int`, *optional*, defaults to 0):
If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where
each layer has 2**i_layer number of image crops.
overlap_ratio (`float`, *optional*, defaults to 512/1500):
Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
the image length. Later layers with more crops scale down this overlap.
points_per_crop (`int`, *optional*, defaults to 32):
Number of points to sample from each crop.
crop_n_points_downscale_factor (`List[int]`, *optional*, defaults to 1):
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
device (`torch.device`, *optional*, defaults to None):
Device to use for the computation. If None, cpu will be used.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
return_tensors (`str`, *optional*, defaults to `pt`):
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
"""
crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes(
image,
target_size,
crop_n_layers,
overlap_ratio,
points_per_crop,
crop_n_points_downscale_factor,
input_data_format,
)
if return_tensors == "pt":
if device is None:
device = torch.device("cpu")
crop_boxes = torch.tensor(crop_boxes, device=device)
points_per_crop = torch.tensor(points_per_crop, device=device)
input_labels = torch.tensor(input_labels, device=device)
elif return_tensors == "tf":
if device is not None:
raise ValueError("device is not a supported argument when return_tensors is tf!")
crop_boxes = tf.convert_to_tensor(crop_boxes)
points_per_crop = tf.convert_to_tensor(points_per_crop)
input_labels = tf.convert_to_tensor(input_labels)
else:
raise ValueError("return_tensors must be either 'pt' or 'tf'.")
return crop_boxes, points_per_crop, cropped_images, input_labels
"""
根据给定的条件过滤预测的掩码,并执行必要的转换和填充操作。
Args:
masks (`Union[torch.Tensor, tf.Tensor]`):
输入的掩码张量。
iou_scores (`Union[torch.Tensor, tf.Tensor]`):
IoU(Intersection over Union)分数的列表。
original_size (`Tuple[int,int]`):
原始图像的尺寸。
cropped_box_image (`np.array`):
裁剪后的图像数组。
pred_iou_thresh (`float`, *optional*, 默认为 0.88):
IoU 分数的阈值。
stability_score_thresh (`float`, *optional*, 默认为 0.95):
稳定性分数的阈值。
mask_threshold (`float`, *optional*, 默认为 0):
预测掩码的阈值。
stability_score_offset (`float`, *optional*, 默认为 1):
在 `_compute_stability_score` 方法中使用的稳定性分数的偏移量。
return_tensors (`str`, *optional*, 默认为 `pt`):
如果是 `pt`,返回 `torch.Tensor`;如果是 `tf`,返回 `tf.Tensor`。
"""
if return_tensors == "pt":
return self._filter_masks_pt(
masks=masks,
iou_scores=iou_scores,
original_size=original_size,
cropped_box_image=cropped_box_image,
pred_iou_thresh=pred_iou_thresh,
stability_score_thresh=stability_score_thresh,
mask_threshold=mask_threshold,
stability_score_offset=stability_score_offset,
)
elif return_tensors == "tf":
return self._filter_masks_tf(
masks=masks,
iou_scores=iou_scores,
original_size=original_size,
cropped_box_image=cropped_box_image,
pred_iou_thresh=pred_iou_thresh,
stability_score_thresh=stability_score_thresh,
mask_threshold=mask_threshold,
stability_score_offset=stability_score_offset,
)
"""
Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability
score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
bounding boxes and pad the predicted masks if necessary.
Args:
masks (`torch.Tensor`):
Input masks.
iou_scores (`torch.Tensor`):
List of IoU scores.
original_size (`Tuple[int,int]`):
Size of the original image.
cropped_box_image (`np.array`):
The cropped image.
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
The threshold for the IoU scores.
stability_score_thresh (`float`, *optional*, defaults to 0.95):
The threshold for the stability score.
mask_threshold (`float`, *optional*, defaults to 0):
The threshold for the predicted masks.
stability_score_offset (`float`, *optional*, defaults to 1):
The offset for the stability score used in the `_compute_stability_score` method.
"""
requires_backends(self, ["torch"])
original_height, original_width = original_size
iou_scores = iou_scores.flatten(0, 1)
masks = masks.flatten(0, 1)
if masks.shape[0] != iou_scores.shape[0]:
raise ValueError("masks and iou_scores must have the same batch size.")
if masks.device != iou_scores.device:
iou_scores = iou_scores.to(masks.device)
batch_size = masks.shape[0]
keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device)
if pred_iou_thresh > 0.0:
keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
if stability_score_thresh > 0.0:
stability_scores = _compute_stability_score_pt(masks, mask_threshold, stability_score_offset)
keep_mask = keep_mask & (stability_scores > stability_score_thresh)
scores = iou_scores[keep_mask]
masks = masks[keep_mask]
masks = masks > mask_threshold
converted_boxes = _batched_mask_to_box(masks)
keep_mask = ~_is_box_near_crop_edge(
converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
)
scores = scores[keep_mask]
masks = masks[keep_mask]
converted_boxes = converted_boxes[keep_mask]
masks = _pad_masks(masks, cropped_box_image, original_height, original_width)
masks = _mask_to_rle_pytorch(masks)
return masks, scores, converted_boxes
    def _filter_masks_tf(
        self,
        masks,
        iou_scores,
        original_size,
        cropped_box_image,
        pred_iou_thresh=0.88,
        stability_score_thresh=0.95,
        mask_threshold=0,
        stability_score_offset=1,
    ):
        """
        Filters the predicted masks by selecting only the ones that meet several criteria. The first criterion is
        that the IoU scores need to be greater than `pred_iou_thresh`. The second criterion is that the stability
        score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
        bounding boxes and pads the predicted masks if necessary.
Args:
masks (`tf.Tensor`):
Input masks.
iou_scores (`tf.Tensor`):
List of IoU scores.
original_size (`Tuple[int,int]`):
                Size of the original image.
cropped_box_image (`np.array`):
The cropped image.
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
The threshold for the iou scores.
stability_score_thresh (`float`, *optional*, defaults to 0.95):
The threshold for the stability score.
mask_threshold (`float`, *optional*, defaults to 0):
The threshold for the predicted masks.
stability_score_offset (`float`, *optional*, defaults to 1):
The offset for the stability score used in the `_compute_stability_score` method.
"""
requires_backends(self, ["tf"])
original_height, original_width = original_size
iou_scores = tf.reshape(iou_scores, [iou_scores.shape[0] * iou_scores.shape[1], iou_scores.shape[2:]])
masks = tf.reshape(masks, [masks.shape[0] * masks.shape[1], masks.shape[2:]])
if masks.shape[0] != iou_scores.shape[0]:
raise ValueError("masks and iou_scores must have the same batch size.")
batch_size = masks.shape[0]
keep_mask = tf.ones(batch_size, dtype=tf.bool)
if pred_iou_thresh > 0.0:
keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
if stability_score_thresh > 0.0:
stability_scores = _compute_stability_score_tf(masks, mask_threshold, stability_score_offset)
keep_mask = keep_mask & (stability_scores > stability_score_thresh)
scores = iou_scores[keep_mask]
masks = masks[keep_mask]
masks = masks > mask_threshold
converted_boxes = _batched_mask_to_box_tf(masks)
keep_mask = ~_is_box_near_crop_edge_tf(
converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
)
scores = scores[keep_mask]
masks = masks[keep_mask]
converted_boxes = converted_boxes[keep_mask]
masks = _pad_masks_tf(masks, cropped_box_image, original_height, original_width)
masks = _mask_to_rle_tf(masks)
return masks, scores, converted_boxes
def _compute_stability_score_pt(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
intersections = (
(masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
)
unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
stability_scores = intersections / unions
return stability_scores
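# Annotation (illustrative sketch, not part of the original source): the stability score is the IoU between the
# mask binarized at `mask_threshold + offset` and at `mask_threshold - offset`, so masks whose area barely changes
# under a threshold perturbation score close to 1.
def _example_stability_score_pt():  # hypothetical helper for illustration
    logits = torch.tensor([[[0.2, 1.5, 2.5], [2.0, -0.5, 3.0]]])  # (1, 2, 3) mask logits
    # 4 values are > +1.0 and 6 values are > -1.0, so the score is 4 / 6 ≈ 0.67
    return _compute_stability_score_pt(logits, mask_threshold=0.0, stability_score_offset=1.0)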
def _compute_stability_score_tf(masks: "tf.Tensor", mask_threshold: float, stability_score_offset: int):
intersections = tf.count_nonzero(
masks > (mask_threshold + stability_score_offset), axis=[-1, -2], dtype=tf.float32
)
unions = tf.count_nonzero(masks > (mask_threshold - stability_score_offset), axis=[-1, -2], dtype=tf.float32)
stability_scores = intersections / unions
return stability_scores
def _build_point_grid(n_per_side: int) -> np.ndarray:
offset = 1 / (2 * n_per_side)
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
points_y = np.tile(points_one_side[:, None], (1, n_per_side))
points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
return points
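# Annotation (illustrative sketch, not part of the original source): `_build_point_grid(2)` places one point at the
# center of each cell of a 2x2 partition of the unit square.
def _example_point_grid():  # hypothetical helper for illustration
    return _build_point_grid(2)  # array([[0.25, 0.25], [0.75, 0.25], [0.25, 0.75], [0.75, 0.75]])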
def _normalize_coordinates(
target_size: int, coords: np.ndarray, original_size: Tuple[int, int], is_bounding_box=False
) -> np.ndarray:
old_height, old_width = original_size
scale = target_size * 1.0 / max(old_height, old_width)
new_height, new_width = old_height * scale, old_width * scale
new_width, new_height = int(new_width + 0.5), int(new_height + 0.5)
coords = deepcopy(coords).astype(float)
if is_bounding_box:
coords = coords.reshape(-1, 2, 2)
coords[..., 0] = coords[..., 0] * (new_width / old_width)
coords[..., 1] = coords[..., 1] * (new_height / old_height)
if is_bounding_box:
coords = coords.reshape(-1, 4)
return coords
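# Annotation (illustrative sketch, not part of the original source): point coordinates are rescaled with the same
# longest-edge rule as the image itself. For a 480x640 image and target_size=1024 the scale is 1.6, so the point
# (320, 240) maps to (512, 384).
def _example_normalize_coordinates():  # hypothetical helper for illustration
    coords = np.array([[320.0, 240.0]])  # (x, y) in the original image
    return _normalize_coordinates(1024, coords, original_size=(480, 640))  # array([[512., 384.]])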
def _generate_crop_boxes(
image,
target_size: int,
crop_n_layers: int = 0,
overlap_ratio: float = 512 / 1500,
points_per_crop: Optional[int] = 32,
crop_n_points_downscale_factor: Optional[List[int]] = 1,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[List[List[int]], List[int]]:
"""
    Generates a list of crop boxes of different sizes according to the given parameters. Layer i contains (2**i)**2 boxes.
"""
if isinstance(image, list):
raise ValueError("Only one image is allowed for crop generation.")
image = to_numpy_array(image)
original_size = get_image_size(image, input_data_format)
points_grid = []
for i in range(crop_n_layers + 1):
n_points = int(points_per_crop / (crop_n_points_downscale_factor**i))
points_grid.append(_build_point_grid(n_points))
crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size)
cropped_images, point_grid_per_crop = _generate_crop_images(
crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format
)
crop_boxes = np.array(crop_boxes)
crop_boxes = crop_boxes.astype(np.float32)
points_per_crop = np.array([point_grid_per_crop])
points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3))
input_labels = np.ones_like(points_per_crop[:, :, :, 0], dtype=np.int64)
return crop_boxes, points_per_crop, cropped_images, input_labels
def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size):
crop_boxes, layer_idxs = [], []
im_height, im_width = original_size
short_side = min(im_height, im_width)
crop_boxes.append([0, 0, im_width, im_height])
layer_idxs.append(0)
for i_layer in range(crop_n_layers):
n_crops_per_side = 2 ** (i_layer + 1)
overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side))
crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side))
crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)]
crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)]
for left, top in product(crop_box_x0, crop_box_y0):
box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)]
crop_boxes.append(box)
layer_idxs.append(i_layer + 1)
return crop_boxes, layer_idxs
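# Annotation (illustrative sketch, not part of the original source): with one extra crop layer on a 600x800 image,
# the output is the full-image box followed by a 2x2 grid of overlapping crops (5 boxes in total).
def _example_per_layer_crops():  # hypothetical helper for illustration
    crop_boxes, layer_idxs = _generate_per_layer_crops(
        crop_n_layers=1, overlap_ratio=512 / 1500, original_size=(600, 800)
    )
    # crop_boxes[0] == [0, 0, 800, 600], len(crop_boxes) == 5, layer_idxs == [0, 1, 1, 1, 1]
    return crop_boxes, layer_idxs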
def _generate_crop_images(
crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format=None
):
cropped_images = []
total_points_per_crop = []
for i, crop_box in enumerate(crop_boxes):
left, top, right, bottom = crop_box
channel_dim = infer_channel_dimension_format(image, input_data_format)
if channel_dim == ChannelDimension.LAST:
cropped_im = image[top:bottom, left:right, :]
else:
cropped_im = image[:, top:bottom, left:right]
cropped_images.append(cropped_im)
cropped_im_size = get_image_size(cropped_im, channel_dim)
points_scale = np.array(cropped_im_size)[None, ::-1]
points = points_grid[layer_idxs[i]] * points_scale
normalized_points = _normalize_coordinates(target_size, points, original_size)
total_points_per_crop.append(normalized_points)
return cropped_images, total_points_per_crop
def _pad_masks(masks, crop_box: List[int], orig_height: int, orig_width: int):
left, top, right, bottom = crop_box
if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
return masks
pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
pad = (left, pad_x - left, top, pad_y - top)
return torch.nn.functional.pad(masks, pad, value=0)
def _pad_masks_tf(masks, crop_box: List[int], orig_height: int, orig_width: int):
left, top, right, bottom = crop_box
if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
return masks
pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
pad = (left, pad_x - left, top, pad_y - top)
return tf.pad(masks, pad, constant_values=0)
def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
left, top, _, _ = crop_box
offset = torch.tensor([[left, top, left, top]], device=boxes.device)
if len(boxes.shape) == 3:
offset = offset.unsqueeze(1)
boxes = (boxes + offset).float()
near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
return torch.any(near_crop_edge, dim=1)
def _is_box_near_crop_edge_tf(boxes, crop_box, orig_box, atol=20.0):
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
crop_box_tf = tf.convert_to_tensor(crop_box, dtype=tf.float32)
orig_box_tf = tf.convert_to_tensor(orig_box, dtype=tf.float32)
left, top, _, _ = crop_box
offset = tf.convert_to_tensor([[left, top, left, top]])
if len(boxes.shape) == 3:
offset = tf.expand_dims(offset, 1)
boxes = tf.cast(boxes + offset, tf.float32)
    near_crop_edge = tnp.isclose(boxes, crop_box_tf[None, :], atol=atol, rtol=0)
    near_image_edge = tnp.isclose(boxes, orig_box_tf[None, :], atol=atol, rtol=0)
near_crop_edge = tf.math.logical_and(near_crop_edge, ~near_image_edge)
return tf.reduce_any(near_crop_edge, axis=1)
def _batched_mask_to_box(masks: "torch.Tensor"):
"""
Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
corresponds the following required indices:
- LEFT: left hand side of the bounding box
- TOP: top of the bounding box
- RIGHT: right of the bounding box
- BOTTOM: bottom of the bounding box
Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
is channel_1 x channel_2 x ... x 4.
Args:
- masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`)
"""
if torch.numel(masks) == 0:
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
shape = masks.shape
height, width = shape[-2:]
in_height, _ = torch.max(masks, dim=-1)
in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :]
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
in_height_coords = in_height_coords + height * (~in_height)
top_edges, _ = torch.min(in_height_coords, dim=-1)
in_width, _ = torch.max(masks, dim=-2)
in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :]
right_edges, _ = torch.max(in_width_coords, dim=-1)
in_width_coords = in_width_coords + width * (~in_width)
left_edges, _ = torch.min(in_width_coords, dim=-1)
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
out = out * (~empty_filter).unsqueeze(-1)
out = out.reshape(*shape[:-2], 4)
return out
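# Annotation (illustrative sketch, not part of the original source): for a single 4x4 mask whose True pixels cover
# rows 1-2 and columns 1-3, the XYXY box is [1, 1, 3, 2].
def _example_batched_mask_to_box():  # hypothetical helper for illustration
    mask = torch.zeros(1, 1, 4, 4, dtype=torch.bool)
    mask[0, 0, 1:3, 1:4] = True
    return _batched_mask_to_box(mask)  # tensor([[[1, 1, 3, 2]]])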
def _mask_to_rle_tf(input_mask: "tf.Tensor"):
    """Encodes masks as uncompressed run-length encodings (RLE), in the format expected by pycocotools."""
    batch_size, height, width = input_mask.shape
input_mask = tf.transpose(input_mask, perm=[0, 2, 1])
input_mask = tf.reshape(input_mask, [batch_size, -1])
diff = input_mask[:, 1:] ^ input_mask[:, :-1]
change_indices = tf.where(diff)
out = []
for i in tf.range(batch_size):
cur_idxs = tf.boolean_mask(change_indices[:, 1], change_indices[:, 0] == i) + 1
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
counts = [] if input_mask[i, 0] == 0 else [0]
counts += [cur_idxs[0].numpy().item()] + btw_idxs.numpy().tolist() + [height * width - cur_idxs[-1]]
out.append({"size": [height, width], "counts": counts})
return out
def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
"""Compute a binary mask from an uncompressed RLE."""
height, width = rle["size"]
mask = np.empty(height * width, dtype=bool)
idx = 0
parity = False
for count in rle["counts"]:
mask[idx : idx + count] = parity
idx += count
parity = not parity
mask = mask.reshape(width, height)
return mask.transpose()
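# Annotation (illustrative sketch, not part of the original source): an uncompressed RLE alternates run lengths of
# 0s and 1s in column-major (Fortran) order, which is why the decoded array is reshaped as (width, height) and then
# transposed.
def _example_rle_to_mask():  # hypothetical helper for illustration
    rle = {"size": [2, 3], "counts": [2, 2, 2]}  # 2 zeros, 2 ones, 2 zeros
    return _rle_to_mask(rle)  # array([[False, True, False], [False, True, False]])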
def _postprocess_for_mg_tf(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
"""
Perform NMS (Non Maximum Suppression) on the outputs.
Args:
rle_masks (`tf.Tensor`):
binary masks in the RLE format
iou_scores (`tf.Tensor` of shape (nb_masks, 1)):
iou_scores predicted by the model
mask_boxes (`tf.Tensor`):
The bounding boxes corresponding to segmentation masks
amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
NMS threshold.
"""
keep_by_nms = tf.image.combined_non_max_suppression(
boxes=mask_boxes.float(),
scores=iou_scores,
idxs=torch.zeros(mask_boxes.shape[0]),
iou_threshold=amg_crops_nms_thresh,
)
iou_scores = iou_scores[keep_by_nms]
rle_masks = [rle_masks[i] for i in keep_by_nms]
mask_boxes = mask_boxes[keep_by_nms]
masks = [_rle_to_mask(rle) for rle in rle_masks]
return masks, iou_scores, rle_masks, mask_boxes
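# Annotation (illustrative usage sketch, not part of the original source): how the processor above is typically
# used end to end, assuming the `facebook/sam-vit-base` checkpoint. Preprocessing resizes the longest edge to 1024
# and pads to 1024x1024; `post_process_masks` removes the padding and upscales the predicted masks back to the
# original resolution.
def _example_sam_image_processor_roundtrip():  # hypothetical helper for illustration
    from transformers import SamImageProcessor, SamModel

    raw_image = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy RGB image standing in for a real one
    image_processor = SamImageProcessor.from_pretrained("facebook/sam-vit-base")
    model = SamModel.from_pretrained("facebook/sam-vit-base")

    inputs = image_processor(raw_image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    masks = image_processor.post_process_masks(
        outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
    )
    return masks  # list with one tensor of masks upscaled to the original 480x640 resolution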
.\models\sam\modeling_sam.py
""" PyTorch SAM model."""
import collections
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import Tensor, nn
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "SamConfig"
_CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge"
SAM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/sam-vit-huge",
"facebook/sam-vit-large",
"facebook/sam-vit-base",
]
@dataclass
class SamVisionEncoderOutput(ModelOutput):
"""
Base class for sam vision model's outputs that also contains image embeddings obtained by applying the projection
layer to the pooler_output.
"""
image_embeds: Optional[torch.FloatTensor] = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class SamImageSegmentationOutput(ModelOutput):
"""
Base class for Segment-Anything model's output
Args:
iou_scores (`torch.FloatTensor` of shape `(batch_size, num_masks)`):
The iou scores of the predicted masks.
pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`):
            The predicted low-resolution masks. They need to be post-processed by the processor.
vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
iou_scores: torch.FloatTensor = None
pred_masks: torch.FloatTensor = None
vision_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
vision_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
mask_decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
class SamPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values):
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
return embeddings
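# Annotation (illustrative sketch, not part of the original source): with the default vision config (1024x1024
# images, 16x16 patches), the convolutional projection yields a 64x64 grid of patch embeddings.
def _example_patch_embeddings():  # hypothetical helper for illustration
    config = SamVisionConfig()
    patch_embed = SamPatchEmbeddings(config)
    pixel_values = torch.zeros(1, config.num_channels, config.image_size, config.image_size)
    return patch_embed(pixel_values).shape  # (1, 64, 64, config.hidden_size)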
class SamMLPBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim)
self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size)
self.act = ACT2FN[config.hidden_act]
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.lin1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.lin2(hidden_states)
return hidden_states
class SamLayerNorm(nn.Module):
r"""支持两种数据格式(channels_last或channels_first)的LayerNorm。
channels_last对应输入形状为(batch_size, height, width, channels),而channels_first对应输入形状为(batch_size, channels, height, width)。
"""
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError(f"Unsupported data format: {self.data_format}")
self.normalized_shape = (normalized_shape,)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.data_format == "channels_last":
x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
input_dtype = x.dtype
x = x.float()
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = x.to(dtype=input_dtype)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
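# Annotation (illustrative sketch, not part of the original source): in "channels_first" mode the statistics are
# taken over the channel dimension of a (batch, channels, height, width) tensor, which is how the mask decoder and
# mask embedding below apply it to convolutional feature maps.
def _example_sam_layer_norm():  # hypothetical helper for illustration
    layer_norm = SamLayerNorm(64, data_format="channels_first")
    hidden_states = torch.randn(1, 64, 32, 32)
    return layer_norm(hidden_states).shape  # torch.Size([1, 64, 32, 32])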
class SamAttention(nn.Module):
"""
    SAM's attention layer, which allows downsizing the embedding after projecting the queries, keys, and values.
"""
def __init__(self, config, downsample_rate=None):
super().__init__()
self.hidden_size = config.hidden_size
downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
self.internal_dim = config.hidden_size // downsample_rate
self.num_attention_heads = config.num_attention_heads
if self.internal_dim % config.num_attention_heads != 0:
raise ValueError("num_attention_heads must divide hidden_size.")
self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
batch, point_batch_size, n_tokens, channel = hidden_states.shape
c_per_head = channel // num_attention_heads
hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
return hidden_states.transpose(1, 2)
def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
batch, n_heads, n_tokens, c_per_head = hidden_states.shape
hidden_states = hidden_states.transpose(1, 2)
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None) -> Tensor:
query = self.q_proj(query)
key = self.k_proj(key)
value = self.v_proj(value)
point_batch_size = query.shape[1]
query = self._separate_heads(query, self.num_attention_heads)
key = self._separate_heads(key, self.num_attention_heads)
value = self._separate_heads(value, self.num_attention_heads)
_, _, _, c_per_head = query.shape
attn = query @ key.permute(0, 1, 3, 2)
attn = attn / math.sqrt(c_per_head)
attn = torch.softmax(attn, dim=-1)
if attention_similarity is not None:
attn = attn + attention_similarity
attn = torch.softmax(attn, dim=-1)
out = attn @ value
out = self._recombine_heads(out, point_batch_size)
out = self.out_proj(out)
return out
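# Annotation (illustrative sketch, not part of the original source): with the default mask-decoder config
# (hidden_size=256, num_attention_heads=8 by default) and downsample_rate=2, the projections run at 128 channels
# internally (8 heads of 16 channels) and the output projection maps back to hidden_size.
def _example_sam_attention():  # hypothetical helper for illustration
    config = SamMaskDecoderConfig()
    attention = SamAttention(config, downsample_rate=2)
    tokens = torch.randn(1, 1, 5, config.hidden_size)  # (batch, point_batch, tokens, hidden)
    return attention(query=tokens, key=tokens, value=tokens).shape  # (1, 1, 5, config.hidden_size)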
class SamTwoWayAttentionBlock(nn.Module):
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
"""
A transformer block with four layers:
(1) self-attention of sparse inputs
(2) cross attention of sparse inputs -> dense inputs
(3) MLP block on sparse inputs
(4) cross attention of dense inputs -> sparse inputs
Arguments:
config (`SamMaskDecoderConfig`):
The configuration file used to instantiate the block
            attention_downsample_rate (*optional*, int, defaults to 2):
The downsample ratio of the block used to reduce the inner dim of the attention.
skip_first_layer_pe (*optional*, bool, defaults to `False`):
Whether or not to skip the addition of the query_point_embedding on the first layer.
"""
super().__init__()
self.hidden_size = config.hidden_size
self.layer_norm_eps = config.layer_norm_eps
self.self_attn = SamAttention(config, downsample_rate=1)
self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.mlp = SamMLPBlock(config)
self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)
self.skip_first_layer_pe = skip_first_layer_pe
def forward(
self,
queries: Tensor,
keys: Tensor,
query_point_embedding: Tensor,
key_point_embedding: Tensor,
attention_similarity: Tensor,
output_attentions: bool = False,
):
if self.skip_first_layer_pe:
queries = self.self_attn(query=queries, key=queries, value=queries)
else:
query = queries + query_point_embedding
attn_out = self.self_attn(query=query, key=query, value=queries)
queries = queries + attn_out
queries = self.layer_norm1(queries)
query = queries + query_point_embedding
key = keys + key_point_embedding
attn_out = self.cross_attn_token_to_image(
query=query, key=key, value=keys, attention_similarity=attention_similarity
)
queries = queries + attn_out
queries = self.layer_norm2(queries)
mlp_out = self.mlp(queries)
queries = queries + mlp_out
queries = self.layer_norm3(queries)
query = queries + query_point_embedding
key = keys + key_point_embedding
attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries)
keys = keys + attn_out
keys = self.layer_norm4(keys)
outputs = (queries, keys)
if output_attentions:
outputs = outputs + (attn_out,)
else:
outputs = outputs + (None,)
return outputs
class SamTwoWayTransformer(nn.Module):
def __init__(self, config: SamMaskDecoderConfig):
super().__init__()
self.config = config
self.num_hidden_layers = config.num_hidden_layers
self.layers = nn.ModuleList()
for i in range(self.num_hidden_layers):
self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
self.final_attn_token_to_image = SamAttention(config)
self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
def forward(
self,
point_embeddings: Tensor,
image_embeddings: Tensor,
image_positional_embeddings: Tensor,
attention_similarity: Tensor,
target_embedding=None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
all_attentions = ()
if image_embeddings is None:
raise ValueError("You have to specify an image_embedding")
image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
queries = point_embeddings
keys = image_embeddings
for layer in self.layers:
if target_embedding is not None:
queries += target_embedding
queries, keys, attention_outputs = layer(
queries=queries,
keys=keys,
query_point_embedding=point_embeddings,
key_point_embedding=image_positional_embeddings,
attention_similarity=attention_similarity,
output_attentions=output_attentions,
)
if output_attentions:
all_attentions = all_attentions + (attention_outputs,)
query = queries + point_embeddings
key = keys + image_positional_embeddings
attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)
queries = queries + attn_out
queries = self.layer_norm_final_attn(queries)
return queries, keys, all_attentions
class SamFeedForward(nn.Module):
    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False
    ):
super().__init__()
self.num_layers = num_layers
self.activation = nn.ReLU()
self.proj_in = nn.Linear(input_dim, hidden_dim)
self.proj_out = nn.Linear(hidden_dim, output_dim)
self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
self.sigmoid_output = sigmoid_output
def forward(self, hidden_states):
hidden_states = self.proj_in(hidden_states)
hidden_states = self.activation(hidden_states)
for layer in self.layers:
hidden_states = self.activation(layer(hidden_states))
hidden_states = self.proj_out(hidden_states)
if self.sigmoid_output:
hidden_states = F.sigmoid(hidden_states)
return hidden_states
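# Annotation (illustrative sketch, not part of the original source): the mask decoder's hypernetwork MLPs use this
# block with num_layers=3, e.g. mapping a 256-d mask token to a hidden_size // 8 = 32-d vector.
def _example_feed_forward():  # hypothetical helper for illustration
    mlp = SamFeedForward(input_dim=256, hidden_dim=256, output_dim=32, num_layers=3)
    return mlp(torch.randn(1, 256)).shape  # torch.Size([1, 32])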
class SamMaskDecoder(nn.Module):
def __init__(self, config: SamMaskDecoderConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.num_multimask_outputs = config.num_multimask_outputs
self.num_mask_tokens = config.num_multimask_outputs + 1
self.iou_token = nn.Embedding(1, self.hidden_size)
self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
self.transformer = SamTwoWayTransformer(config)
self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
self.upscale_layer_norm = SamLayerNorm(self.hidden_size // 4, data_format="channels_first")
self.activation = nn.GELU()
mlps_list = []
for _ in range(self.num_mask_tokens):
mlps_list += [SamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)]
self.output_hypernetworks_mlps = nn.ModuleList(mlps_list)
self.iou_prediction_head = SamFeedForward(
self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth
)
def forward(
self,
image_embeddings: torch.Tensor,
image_positional_embeddings: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
output_attentions: Optional[bool] = None,
attention_similarity: torch.Tensor = None,
target_embedding: torch.Tensor = None,
):
pass
class SamPositionalEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
self.scale = config.hidden_size // 2
self.register_buffer("positional_embedding", self.scale * torch.randn((2, config.num_pos_feats)))
def forward(self, input_coords, input_shape=None):
"""Positionally encode points that are normalized to [0,1]."""
coordinates = input_coords.clone()
if input_shape is not None:
coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
coordinates = 2 * coordinates - 1
coordinates = coordinates.to(self.positional_embedding.dtype)
coordinates = coordinates @ self.positional_embedding
coordinates = 2 * np.pi * coordinates
return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
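# Annotation (illustrative sketch, not part of the original source): the buffer is a random (2, num_pos_feats)
# Gaussian matrix, so each normalized (x, y) coordinate is mapped to 2 * num_pos_feats sin/cos features (random
# Fourier features). In SamModel this module is shared and built from the vision config.
def _example_positional_embedding():  # hypothetical helper for illustration
    config = SamVisionConfig()  # assumes the vision config carries `num_pos_feats`
    positional_embedding = SamPositionalEmbedding(config)
    coords = torch.rand(1, 1, 4, 2)  # 4 points with (x, y) already normalized to [0, 1]
    return positional_embedding(coords).shape  # (1, 1, 4, 2 * config.num_pos_feats)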
class SamMaskEmbedding(nn.Module):
def __init__(self, config: SamPromptEncoderConfig):
super().__init__()
self.mask_input_channels = config.mask_input_channels // 4
self.activation = ACT2FN[config.hidden_act]
self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
self.layer_norm1 = SamLayerNorm(
self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
)
self.layer_norm2 = SamLayerNorm(
self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
)
def forward(self, masks):
hidden_states = self.conv1(masks)
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.activation(hidden_states)
dense_embeddings = self.conv3(hidden_states)
return dense_embeddings
class SamPromptEncoder(nn.Module):
def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding):
super().__init__()
self.shared_embedding = shared_patch_embedding
self.mask_embed = SamMaskEmbedding(config)
self.no_mask_embed = nn.Embedding(1, config.hidden_size)
self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)
self.input_image_size = config.image_size
self.point_embed = nn.ModuleList(
[nn.Embedding(1, config.hidden_size) for i in range(config.num_point_embeddings)]
)
self.hidden_size = config.hidden_size
self.not_a_point_embed = nn.Embedding(1, config.hidden_size)
def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
"""Embeds point prompts."""
points = points + 0.5
if pad:
target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1])
target_labels_shape = (points.shape[0], points.shape[1], 1)
padding_point = torch.zeros(target_point_shape, device=points.device)
padding_label = -torch.ones(target_labels_shape, device=labels.device)
points = torch.cat([points, padding_point], dim=2)
labels = torch.cat([labels, padding_label], dim=2)
input_shape = (self.input_image_size, self.input_image_size)
point_embedding = self.shared_embedding(points, input_shape)
point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
point_embedding = torch.where(
labels[..., None] != -10,
point_embedding,
torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
)
point_embedding = torch.where(
(labels == 0)[:, :, :, None],
point_embedding + self.point_embed[0].weight[None, None, :, :],
point_embedding,
)
point_embedding = torch.where(
(labels == 1)[:, :, :, None],
point_embedding + self.point_embed[1].weight[None, None, :, :],
point_embedding,
)
return point_embedding
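`_embed_points` relies on the label convention to pick an embedding per point: 1 and 0 select the learned foreground/background point embeddings, -1 marks the extra padding point (replaced by `not_a_point_embed`), and -10 marks processor padding whose embedding is zeroed out. A toy sketch of the two `torch.where` selections, with an assumed `hidden_size` of 4 and stand-in tensors:

```python
# Toy illustration of the label-driven torch.where selection above.
import torch

hidden_size = 4
labels = torch.tensor([[[1, 0, -1, -10]]])                  # (batch, point_batch, n_points)
point_embedding = torch.ones(1, 1, 4, hidden_size)          # stand-in for the positional encoding
not_a_point = torch.full((1, hidden_size), 5.0)             # stand-in for not_a_point_embed.weight

out = torch.where(labels[..., None] == -1, not_a_point, point_embedding)
out = torch.where(labels[..., None] != -10, out, torch.tensor(0.0))
print(out[0, 0, :, 0])                                       # tensor([1., 1., 5., 0.])
```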
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
"""Embeds box prompts."""
boxes = boxes + 0.5
batch_size, nb_boxes = boxes.shape[:2]
coords = boxes.reshape(batch_size, nb_boxes, 2, 2)
input_shape = (self.input_image_size, self.input_image_size)
corner_embedding = self.shared_embedding(coords, input_shape)
corner_embedding[:, :, 0, :] += self.point_embed[2].weight
corner_embedding[:, :, 1, :] += self.point_embed[3].weight
return corner_embedding
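Each box is treated as a pair of corner points: the `(x1, y1, x2, y2)` coordinates are shifted to pixel centers and reshaped into two 2D points, which then receive the third and fourth learned point embeddings. A tiny sketch of that reshape:

```python
# Box -> two corner points, mirroring the reshape above.
import torch

boxes = torch.tensor([[[10.0, 20.0, 110.0, 220.0]]])        # (batch, nb_boxes, 4)
corners = (boxes + 0.5).reshape(1, 1, 2, 2)
print(corners)        # top-left corner (10.5, 20.5) and bottom-right corner (110.5, 220.5)
```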
def forward(
self,
input_points: Optional[Tuple[torch.Tensor, torch.Tensor]],
input_labels: Optional[torch.Tensor],
input_boxes: Optional[torch.Tensor],
input_masks: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Embeds different types of prompts, returning both sparse and dense embeddings.
Args:
points (`torch.Tensor`, *optional*):
point coordinates and labels to embed.
boxes (`torch.Tensor`, *optional*):
boxes to embed
masks (`torch.Tensor`, *optional*):
masks to embed
"""
sparse_embeddings = None
batch_size = 1
target_device = self.shared_embedding.positional_embedding.device
if input_points is not None:
batch_size, point_batch_size = input_points.shape[:2]
if input_labels is None:
raise ValueError("If points are provided, labels must also be provided.")
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
sparse_embeddings = point_embeddings
if input_boxes is not None:
batch_size = input_boxes.shape[0]
box_embeddings = self._embed_boxes(input_boxes)
if sparse_embeddings is None:
sparse_embeddings = box_embeddings
else:
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
if input_masks is not None:
dense_embeddings = self.mask_embed(input_masks)
else:
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
)
if sparse_embeddings is None:
sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device)
return sparse_embeddings, dense_embeddings
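The forward pass therefore always returns a sparse tensor of prompt tokens plus a dense `(batch, hidden_size, 64, 64)` map. A rough sketch of the shape contract, assuming the default `hidden_size=256` and `image_embedding_size=64`:

```python
# Rough shape contract of SamPromptEncoder.forward (illustrative values only).
import torch

batch_size, point_batch_size, n_points, hidden_size, emb_size = 2, 3, 5, 256, 64

# With point prompts: one sparse token per point, plus one padding point when no boxes are given.
sparse = torch.zeros(batch_size, point_batch_size, n_points + 1, hidden_size)

# Without a mask prompt: the learned "no mask" embedding is broadcast over the image grid.
no_mask_embed = torch.zeros(1, hidden_size)
dense = no_mask_embed.reshape(1, -1, 1, 1).expand(batch_size, -1, emb_size, emb_size)
print(sparse.shape, dense.shape)   # torch.Size([2, 3, 6, 256]) torch.Size([2, 256, 64, 64])
```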
"""
Add decomposed relative positional embeddings to the attention scores.
Args:
attn (`torch.Tensor`):
Attention scores between query and key.
query (`torch.Tensor`):
Query tensor.
rel_pos_h (`torch.Tensor`):
Relative positional embeddings along height.
rel_pos_w (`torch.Tensor`):
Relative positional embeddings along width.
q_size (`Tuple[int, int]`):
Size of the query tensor (height, width).
k_size (`Tuple[int, int]`):
Size of the key tensor (height, width).
Returns:
`torch.Tensor`: Attention scores modified by relative positional embeddings.
"""
max_rel_dist = int(2 * max(q_size[0], k_size[0]) - 1)
rel_pos_h_resized = F.interpolate(rel_pos_h.unsqueeze(0), size=max_rel_dist, mode="linear")
rel_pos_w_resized = F.interpolate(rel_pos_w.unsqueeze(0), size=max_rel_dist, mode="linear")
rel_pos_h_resized = rel_pos_h_resized.squeeze(0)
rel_pos_w_resized = rel_pos_w_resized.squeeze(0)
q_coords = torch.arange(q_size[0]).unsqueeze(1) * max(k_size[0] / q_size[0], 1.0)
k_coords = torch.arange(k_size[0]).unsqueeze(0) * max(q_size[0] / k_size[0], 1.0)
relative_coords_h = (q_coords - k_coords) + (k_size[0] - 1) * max(q_size[0] / k_size[0], 1.0)
q_coords = torch.arange(q_size[1]).unsqueeze(1) * max(k_size[1] / q_size[1], 1.0)
k_coords = torch.arange(k_size[1]).unsqueeze(0) * max(q_size[1] / k_size[1], 1.0)
relative_coords_w = (q_coords - k_coords) + (k_size[1] - 1) * max(q_size[1] / k_size[1], 1.0)
rel_pos_h = rel_pos_h_resized[relative_coords_h.long()]
rel_pos_w = rel_pos_w_resized[relative_coords_w.long()]
rel_pos = rel_pos_h + rel_pos_w.unsqueeze(0)
rel_pos = rel_pos.unsqueeze(0).expand(attn.size(0), -1, -1)
attn = attn + rel_pos
return attn
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
Args:
attn (`torch.Tensor`):
attention map.
query (`torch.Tensor`):
query q in the attention layer with shape (batch_size, query_height * query_width, channel).
rel_pos_h (`torch.Tensor`):
relative position embeddings (Lh, channel) for height axis.
rel_pos_w (`torch.Tensor`):
relative position embeddings (Lw, channel) for width axis.
q_size (tuple):
spatial sequence size of query q with (query_height, query_width).
k_size (tuple):
spatial sequence size of key k with (key_height, key_width).
Returns:
attn (`torch.Tensor`):
attention map with added relative positional embeddings.
"""
query_height, query_width = q_size
key_height, key_width = k_size
relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
batch_size, _, dim = query.shape
reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width)
return attn
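The decomposed trick never materializes a full `(H*W, H*W, channel)` relative-position table; it computes a height bias and a width bias with two einsums and adds them to the attention map through broadcasting. A small sketch of the shape bookkeeping with toy sizes:

```python
# Shape bookkeeping of the decomposed relative-position bias (toy sizes).
import torch

batch_heads, q_h, q_w, k_h, k_w, dim = 2, 4, 4, 4, 4, 8
query = torch.randn(batch_heads, q_h * q_w, dim)
rel_h = torch.randn(q_h, k_h, dim)            # already gathered per (q, k) pair along height
rel_w = torch.randn(q_w, k_w, dim)            # already gathered per (q, k) pair along width

reshaped_query = query.reshape(batch_heads, q_h, q_w, dim)
bias_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, rel_h)   # (B, q_h, q_w, k_h)
bias_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, rel_w)   # (B, q_h, q_w, k_w)

attn = torch.zeros(batch_heads, q_h * q_w, k_h * k_w)
attn = attn.reshape(batch_heads, q_h, q_w, k_h, k_w)
attn = attn + bias_h[:, :, :, :, None] + bias_w[:, :, :, None, :]
attn = attn.reshape(batch_heads, q_h * q_w, k_h * k_w)
print(attn.shape)                              # torch.Size([2, 16, 16])
```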
def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
batch_size, height, width, _ = hidden_states.shape
qkv = (
self.qkv(hidden_states)
.reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
.permute(2, 0, 3, 1, 4)
)
query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
attn_weights = (query * self.scale) @ key.transpose(-2, -1)
if self.use_rel_pos:
attn_weights = self.add_decomposed_rel_pos(
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
)
attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
attn_output = self.proj(attn_output)
if output_attentions:
outputs = (attn_output, attn_weights)
else:
outputs = (attn_output, None)
return outputs
class SamVisionLayer(nn.Module):
def __init__(self, config, window_size):
"""
Initialize the SamVisionLayer module.
Args:
config (object): Configuration object containing parameters like hidden_size and layer_norm_eps.
window_size (int): Size of the sliding window to partition the input.
"""
super().__init__()
self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attn = SamVisionAttention(config, window_size)
self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = SamMLPBlock(config)
self.window_size = window_size
def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Partition input tensor into non-overlapping windows with padding if necessary.
Args:
hidden_states (torch.Tensor): Input tensor with shape [batch_size, height, width, channel].
window_size (int): Size of the window.
Returns:
windows (torch.Tensor): Tensor of windows after partitioning with shape [batch_size * num_windows, window_size, window_size, channel].
(pad_height, pad_width) (Tuple[int, int]): Padded height and width before partitioning.
"""
batch_size, height, width, channel = hidden_states.shape
pad_h = (window_size - height % window_size) % window_size
pad_w = (window_size - width % window_size) % window_size
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
pad_height, pad_width = height + pad_h, width + pad_w
hidden_states = hidden_states.reshape(
batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel
)
windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel)
return windows, (pad_height, pad_width)
def window_unpartition(
self, windows: torch.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int]
):
"""
Reconstruct the original tensor from windows.
Args:
windows (torch.Tensor): Tensor of windows with shape [batch_size * num_windows, window_size, window_size, channel].
window_size (int): Size of the window.
padding_shape (Tuple[int, int]): Padded height and width before partitioning.
original_shape (Tuple[int, int]): Original height and width of the input tensor before padding.
Returns:
Tensor of original shape [batch_size, height, width, channel].
"""
"""
pad_height, pad_width = padding_shape
height, width = original_shape
batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size)
hidden_states = windows.reshape(
batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1
)
hidden_states = (
hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1)
)
hidden_states = hidden_states[:, :height, :width, :].contiguous()
return hidden_states
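`window_partition` pads the feature map so both sides are divisible by the window size and reshapes it into `(num_windows, window, window, channel)`; `window_unpartition` inverts the reshape and crops the padding away. A standalone round-trip check with toy sizes (it repeats the reshape/permute pattern rather than calling the module):

```python
# Round-trip check: partition into windows, then reconstruct the original map.
import torch
import torch.nn.functional as F

batch, height, width, channel, window = 1, 10, 10, 3, 4
x = torch.arange(batch * height * width * channel, dtype=torch.float32).reshape(batch, height, width, channel)

pad_h = (window - height % window) % window
pad_w = (window - width % window) % window
padded = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))                 # pad only on the right/bottom
ph, pw = height + pad_h, width + pad_w

windows = (
    padded.reshape(batch, ph // window, window, pw // window, window, channel)
    .permute(0, 1, 3, 2, 4, 5)
    .reshape(-1, window, window, channel)
)

restored = (
    windows.reshape(batch, ph // window, pw // window, window, window, channel)
    .permute(0, 1, 3, 2, 4, 5)
    .reshape(batch, ph, pw, channel)[:, :height, :width, :]
)
print(windows.shape, torch.equal(restored, x))                # torch.Size([9, 4, 4, 3]) True
```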
def forward(
self,
hidden_states: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
if self.window_size > 0:
height, width = hidden_states.shape[1], hidden_states.shape[2]
hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)
hidden_states, attn_weights = self.attn(
hidden_states=hidden_states,
output_attentions=output_attentions,
)
if self.window_size > 0:
hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width))
hidden_states = residual + hidden_states
layernorm_output = self.layer_norm2(hidden_states)
hidden_states = hidden_states + self.mlp(layernorm_output)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class SamVisionNeck(nn.Module):
def __init__(self, config: SamVisionConfig):
super().__init__()
self.config = config
self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False)
self.layer_norm1 = SamLayerNorm(config.output_channels, data_format="channels_first")
self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False)
self.layer_norm2 = SamLayerNorm(config.output_channels, data_format="channels_first")
def forward(self, hidden_states):
hidden_states = hidden_states.permute(0, 3, 1, 2)
hidden_states = self.conv1(hidden_states)
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = self.layer_norm2(hidden_states)
return hidden_states
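The neck converts the channels-last ViT output into a channels-first feature map and compresses `hidden_size` down to `output_channels` while keeping the 64x64 resolution. A shape sketch assuming ViT-B style sizes (`hidden_size=768`, `output_channels=256`); the two SamLayerNorms are left out here because they do not change shapes:

```python
# Shape sketch of the vision neck (assumed ViT-B sizes, layer norms omitted).
import torch
import torch.nn as nn

hidden_states = torch.randn(1, 64, 64, 768)                  # encoder output, channels-last
conv1 = nn.Conv2d(768, 256, kernel_size=1, bias=False)       # channel reduction
conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False)

out = conv2(conv1(hidden_states.permute(0, 3, 1, 2)))        # to channels-first, then convs
print(out.shape)                                              # torch.Size([1, 256, 64, 64])
```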
class SamVisionEncoder(nn.Module):
def __init__(self, config: SamVisionConfig):
super().__init__()
self.config = config
self.image_size = config.image_size
self.patch_embed = SamPatchEmbeddings(config)
self.pos_embed = None
if config.use_abs_pos:
self.pos_embed = nn.Parameter(
torch.zeros(
1,
config.image_size // config.patch_size,
config.image_size // config.patch_size,
config.hidden_size,
)
)
self.layers = nn.ModuleList()
for i in range(config.num_hidden_layers):
layer = SamVisionLayer(
config,
window_size=config.window_size if i not in config.global_attn_indexes else 0,
)
self.layers.append(layer)
self.neck = SamVisionNeck(config)
self.gradient_checkpointing = False
def get_input_embeddings(self):
return self.patch_embed
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SamVisionEncoderOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.patch_embed(pixel_values)
if self.pos_embed is not None:
hidden_states = hidden_states + self.pos_embed
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
)
else:
layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
hidden_states = self.neck(hidden_states)
if not return_dict:
outputs = (hidden_states,)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_self_attentions,)
return outputs
return SamVisionEncoderOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class SamPreTrainedModel(PreTrainedModel):
config_class = SamConfig
base_model_prefix = "sam"
main_input_name = "pixel_values"
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@add_start_docstrings(
"Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
" optional 2D location and bounding boxes.",
SAM_START_DOCSTRING,
)
class SamModel(SamPreTrainedModel):
_tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]
def __init__(self, config):
super().__init__(config)
self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)
self.vision_encoder = SamVisionEncoder(config.vision_config)
self.prompt_encoder = SamPromptEncoder(config.prompt_encoder_config, self.shared_image_embedding)
self.mask_decoder = SamMaskDecoder(config.mask_decoder_config)
self.post_init()
def get_input_embeddings(self):
return self.vision_encoder.get_input_embeddings()
def get_image_wide_positional_embeddings(self):
size = self.config.prompt_encoder_config.image_embedding_size
target_device = self.shared_image_embedding.positional_embedding.device
target_dtype = self.shared_image_embedding.positional_embedding.dtype
grid = torch.ones((size, size), device=target_device, dtype=target_dtype)
y_embed = grid.cumsum(dim=0) - 0.5
x_embed = grid.cumsum(dim=1) - 0.5
y_embed = y_embed / size
x_embed = x_embed / size
positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
return positional_embedding.permute(2, 0, 1).unsqueeze(0)
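The cumulative-sum trick builds, for every cell of the image-embedding grid, its (x, y) center normalized to (0, 1); feeding that grid through the shared positional embedding then yields a dense `(1, 2 * num_pos_feats, size, size)` positional map. A tiny sketch of the grid construction with `size=4` instead of the real `image_embedding_size`:

```python
# Sketch of the cumulative-sum grid used for image-wide positional embeddings.
import torch

size = 4                                  # illustrative; the model uses image_embedding_size (e.g. 64)
grid = torch.ones((size, size))
y_embed = (grid.cumsum(dim=0) - 0.5) / size
x_embed = (grid.cumsum(dim=1) - 0.5) / size
print(x_embed[0])                         # tensor([0.1250, 0.3750, 0.6250, 0.8750])
print(y_embed[:, 0])                      # tensor([0.1250, 0.3750, 0.6250, 0.8750])
```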
@torch.no_grad()
def get_image_embeddings(
self,
pixel_values,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
r"""
Returns the image embeddings by passing the pixel values through the vision encoder.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Input pixel values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
vision_output = self.vision_encoder(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
image_embeddings = vision_output[0]
return image_embeddings
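Because `get_image_embeddings` is decoupled from the prompts, the expensive vision encoder can run once per image and the resulting embeddings can be reused across many point or box prompts. A typical usage sketch (it downloads `facebook/sam-vit-base`, so it needs network access; the zero-filled image is just a placeholder):

```python
# Usage sketch: precompute image embeddings once, reuse them for many prompts.
import numpy as np
import torch
from transformers import SamModel, SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base")

image = np.zeros((480, 640, 3), dtype=np.uint8)              # placeholder RGB image
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
print(image_embeddings.shape)                                 # torch.Size([1, 256, 64, 64])
```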
@torch.no_grad()
def get_prompt_embeddings(
self,
input_points: Optional[torch.FloatTensor] = None,
input_labels: Optional[torch.LongTensor] = None,
input_boxes: Optional[torch.FloatTensor] = None,
input_masks: Optional[torch.LongTensor] = None,
):
r"""
Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
Args:
input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
Optional input points for the prompt encoder. The padding of the point is automatically done by the
processor. `point_batch_size` refers to the number of masks that we want the model to predict per
point. The model will output `point_batch_size` times 3 masks in total.
input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
processor, or can be fed by the user.
input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
processor. users can also pass manually the input boxes.
input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
Optional input masks for the prompt encoder.
"""
prompt_output = self.prompt_encoder(
input_points=input_points,
input_labels=input_labels,
input_boxes=input_boxes,
input_masks=input_masks,
)
return prompt_output
@add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
input_points: Optional[torch.FloatTensor] = None,
input_labels: Optional[torch.LongTensor] = None,
input_boxes: Optional[torch.FloatTensor] = None,
input_masks: Optional[torch.LongTensor] = None,
image_embeddings: Optional[torch.FloatTensor] = None,
multimask_output: bool = True,
attention_similarity: Optional[torch.FloatTensor] = None,
target_embedding: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
.\models\sam\modeling_tf_sam.py
"""
TensorFlow SAM model. This file was mostly generated by auto-translation from the PyTorch original. In the event of a
discrepancy, the original file should be regarded as the 'reference' version.
"""
from __future__ import annotations
import collections
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ...activations_tf import ACT2FN
from ...modeling_tf_outputs import TFBaseModelOutput
from ...modeling_tf_utils import (
TFModelInputType, TFPreTrainedModel,
keras, shape_list, unpack_inputs
)
from ...tf_utils import flatten, functional_layernorm
from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "SamConfig"
_CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge"
TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/sam-vit-huge",
"facebook/sam-vit-large",
"facebook/sam-vit-base",
]
@dataclass
class TFSamVisionEncoderOutput(ModelOutput):
"""
Base class for sam vision model's outputs that also contains image embeddings obtained by applying the projection
layer to the pooler_output.
"""
"""
Args:
image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
The image embeddings obtained by applying the projection layer to the pooler_output.
last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
image_embeds: tf.Tensor | None = None
last_hidden_state: tf.Tensor = None
hidden_states: Tuple[tf.Tensor, ...] | None = None
attentions: Tuple[tf.Tensor, ...] | None = None
@dataclass
class TFSamImageSegmentationOutput(ModelOutput):
"""
Base class for Segment-Anything model's output
Args:
iou_scores (`tf.Tensor` of shape `(batch_size, num_masks)`):
The iou scores of the predicted masks.
pred_masks (`tf.Tensor` of shape `(batch_size, num_masks, height, width)`):
The predicted low resolutions masks. Needs to be post-processed by the processor.
vision_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
vision_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
mask_decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
iou_scores: tf.Tensor = None
pred_masks: tf.Tensor = None
vision_hidden_states: Tuple[tf.Tensor, ...] | None = None
vision_attentions: Tuple[tf.Tensor, ...] | None = None
mask_decoder_attentions: Tuple[tf.Tensor, ...] | None = None
class TFSamPatchEmbeddings(keras.layers.Layer):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.projection = keras.layers.Conv2D(
hidden_size, kernel_size=patch_size, strides=patch_size, name="projection"
)
def call(self, pixel_values):
batch_size, num_channels, height, width = shape_list(pixel_values)
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings = self.projection(tf.transpose(pixel_values, perm=[0, 2, 3, 1]))
return embeddings
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "projection", None) is not None:
with tf.name_scope(self.projection.name):
self.projection.build([None, None, None, self.num_channels])
class TFSamMLPBlock(keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.lin1 = keras.layers.Dense(config.mlp_dim, name="lin1")
self.lin2 = keras.layers.Dense(config.hidden_size, name="lin2")
self.act = ACT2FN[config.hidden_act]
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.lin1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.lin2(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "lin1", None) is not None:
with tf.name_scope(self.lin1.name):
self.lin1.build([None, None, self.config.hidden_size])
if getattr(self, "lin2", None) is not None:
with tf.name_scope(self.lin2.name):
self.lin2.build([None, None, self.config.mlp_dim])
class TFSamLayerNorm(keras.layers.Layer):
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", **kwargs):
super().__init__(**kwargs)
self.eps = eps
self.data_format = data_format
self.normalized_shape = normalized_shape
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError(f"Unsupported data format: {self.data_format}")
def build(self, input_shape):
self.weight = self.add_weight(shape=self.normalized_shape, initializer="ones", name="weight")
self.bias = self.add_weight(shape=self.normalized_shape, initializer="zeros", name="bias")
super().build(input_shape)
def call(self, x: tf.Tensor) -> tf.Tensor:
if self.data_format == "channels_last":
x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=-1)
elif self.data_format == "channels_first":
x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=1)
return x
class TFSamAttention(keras.layers.Layer):
"""
SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
values.
"""
def __init__(self, config, downsample_rate=None, **kwargs):
super().__init__(**kwargs)
self.hidden_size = config.hidden_size
downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
self.internal_dim = config.hidden_size // downsample_rate
self.num_attention_heads = config.num_attention_heads
if self.internal_dim % config.num_attention_heads != 0:
raise ValueError("num_attention_heads must divide hidden_size.")
self.q_proj = keras.layers.Dense(self.internal_dim, name="q_proj")
self.k_proj = keras.layers.Dense(self.internal_dim, name="k_proj")
self.v_proj = keras.layers.Dense(self.internal_dim, name="v_proj")
self.out_proj = keras.layers.Dense(self.hidden_size, name="out_proj")
def _separate_heads(self, hidden_states: tf.Tensor, num_attention_heads: int) -> tf.Tensor:
batch, point_batch_size, n_tokens, channel = shape_list(hidden_states)
c_per_head = channel // num_attention_heads
hidden_states = tf.reshape(
hidden_states, (batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
)
return tf.transpose(hidden_states, perm=[0, 2, 1, 3])
def _recombine_heads(self, hidden_states: tf.Tensor, point_batch_size: int) -> tf.Tensor:
batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states)
hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3])
return tf.reshape(
hidden_states,
(batch // tf.reduce_max([1, point_batch_size]), point_batch_size, n_tokens, n_heads * c_per_head),
)
def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor:
query = self.q_proj(query)
key = self.k_proj(key)
value = self.v_proj(value)
point_batch_size = shape_list(query)[1]
query = self._separate_heads(query, self.num_attention_heads)
key = self._separate_heads(key, self.num_attention_heads)
value = self._separate_heads(value, self.num_attention_heads)
_, _, _, c_per_head = shape_list(query)
attn = tf.matmul(
query, tf.transpose(key, perm=[0, 1, 3, 2])
)
attn = attn / tf.math.sqrt(float(c_per_head))
attn = tf.nn.softmax(attn, axis=-1)
out = tf.matmul(attn, value)
out = self._recombine_heads(out, point_batch_size)
out = self.out_proj(out)
return out
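The attention layer folds `point_batch_size` into the batch axis and splits the channel axis across heads before the matmul, then undoes both steps afterwards. A shape sketch of `_separate_heads` / `_recombine_heads` with toy sizes:

```python
# Shape sketch of the head split / recombine used in TFSamAttention (toy sizes).
import tensorflow as tf

batch, point_batch, n_tokens, channel, heads = 2, 3, 5, 8, 4
x = tf.random.normal((batch, point_batch, n_tokens, channel))

# separate heads: fold the point batch into the batch axis, split channels per head
sep = tf.transpose(tf.reshape(x, (batch * point_batch, n_tokens, heads, channel // heads)), perm=[0, 2, 1, 3])
print(sep.shape)            # (6, 4, 5, 2)

# recombine heads: undo the transpose and restore the point batch axis
rec = tf.reshape(tf.transpose(sep, perm=[0, 2, 1, 3]), (batch, point_batch, n_tokens, channel))
print(rec.shape)            # (2, 3, 5, 8)
```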
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build([None, None, self.hidden_size])
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build([None, None, self.hidden_size])
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build([None, None, self.hidden_size])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.internal_dim])
class TFSamTwoWayAttentionBlock(keras.layers.Layer):
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, **kwargs):
"""
Initializes a transformer block with four layers:
(1) self-attention on sparse inputs
(2) cross-attention from sparse inputs to dense inputs
(3) an MLP block on sparse inputs
(4) cross-attention from dense inputs to sparse inputs
Arguments:
config (`SamMaskDecoderConfig`):
The configuration used to instantiate the block.
attention_downsample_rate (`int`, *optional*, defaults to 2):
The downsample ratio used to reduce the inner dimension of the attention.
skip_first_layer_pe (`bool`, *optional*, defaults to `False`):
Whether or not to skip adding the query_point_embedding in the first layer.
"""
super().__init__(**kwargs)
self.hidden_size = config.hidden_size
self.layer_norm_eps = config.layer_norm_eps
self.self_attn = TFSamAttention(config, downsample_rate=1, name="self_attn")
self.layer_norm1 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm1")
self.cross_attn_token_to_image = TFSamAttention(
config, downsample_rate=attention_downsample_rate, name="cross_attn_token_to_image"
)
self.layer_norm2 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm2")
self.mlp = TFSamMLPBlock(config, name="mlp")
self.layer_norm3 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm3")
self.layer_norm4 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm4")
self.cross_attn_image_to_token = TFSamAttention(
config, downsample_rate=attention_downsample_rate, name="cross_attn_image_to_token"
)
self.skip_first_layer_pe = skip_first_layer_pe
def call(
self,
queries: tf.Tensor,
keys: tf.Tensor,
query_point_embedding: tf.Tensor,
key_point_embedding: tf.Tensor,
output_attentions: bool = False,
):
if self.skip_first_layer_pe:
queries = self.self_attn(query=queries, key=queries, value=queries)
else:
query = queries + query_point_embedding
attn_out = self.self_attn(query=query, key=query, value=queries)
queries = queries + attn_out
queries = self.layer_norm1(queries)
query = queries + query_point_embedding
key = keys + key_point_embedding
attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys)
queries = queries + attn_out
queries = self.layer_norm2(queries)
mlp_out = self.mlp(queries)
queries = queries + mlp_out
queries = self.layer_norm3(queries)
query = queries + query_point_embedding
key = keys + key_point_embedding
attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries)
keys = keys + attn_out
keys = self.layer_norm4(keys)
outputs = (queries, keys)
if output_attentions:
outputs = outputs + (attn_out,)
else:
outputs = outputs + (None,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "layer_norm1", None) is not None:
with tf.name_scope(self.layer_norm1.name):
self.layer_norm1.build([None, None, None, self.hidden_size])
if getattr(self, "cross_attn_token_to_image", None) is not None:
with tf.name_scope(self.cross_attn_token_to_image.name):
self.cross_attn_token_to_image.build(None)
if getattr(self, "layer_norm2", None) is not None:
with tf.name_scope(self.layer_norm2.name):
self.layer_norm2.build([None, None, None, self.hidden_size])
if getattr(self, "mlp", None) is not None:
with tf.name_scope(self.mlp.name):
self.mlp.build(None)
if getattr(self, "layer_norm3", None) is not None:
with tf.name_scope(self.layer_norm3.name):
self.layer_norm3.build([None, None, None, self.hidden_size])
if getattr(self, "cross_attn_image_to_token", None) is not None:
with tf.name_scope(self.cross_attn_image_to_token.name):
self.cross_attn_image_to_token.build(None)
class TFSamTwoWayTransformer(keras.layers.Layer):
def __init__(self, config: SamMaskDecoderConfig, **kwargs):
super().__init__(**kwargs)
self.config = config
self.num_hidden_layers = config.num_hidden_layers
self.layers = []
for i in range(self.num_hidden_layers):
self.layers.append(TFSamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0), name=f"layers_._{i}"))
self.final_attn_token_to_image = TFSamAttention(config, name="final_attn_token_to_image")
self.layer_norm_final_attn = keras.layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="layer_norm_final_attn"
)
def call(
self,
point_embeddings: tf.Tensor,
image_embeddings: tf.Tensor,
image_positional_embeddings: tf.Tensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TFBaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
all_attentions = ()
if image_embeddings is None:
raise ValueError("You have to specify an image_embedding")
image_embeddings = tf.transpose(flatten(image_embeddings, 2), perm=(0, 2, 1))[:, None]
image_positional_embeddings = tf.transpose(flatten(image_positional_embeddings, 2), (0, 2, 1))[:, None]
queries = point_embeddings
keys = image_embeddings
for layer in self.layers:
queries, keys, attention_outputs = layer(
queries=queries,
keys=keys,
query_point_embedding=point_embeddings,
key_point_embedding=image_positional_embeddings,
output_attentions=output_attentions,
)
if output_attentions:
all_attentions = all_attentions + (attention_outputs,)
query = queries + point_embeddings
key = keys + image_positional_embeddings
attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)
queries = queries + attn_out
queries = self.layer_norm_final_attn(queries)
return queries, keys, all_attentions
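Before the two-way blocks run, the channels-first image embedding map is flattened into one token per spatial location, with an extra point-batch axis so it can be matched against the prompt tokens. A short sketch of that flattening with illustrative sizes:

```python
# How the two-way transformer turns the image embedding map into tokens (toy sizes).
import tensorflow as tf

batch, channels, height, width = 1, 256, 64, 64
image_embeddings = tf.random.normal((batch, channels, height, width))

tokens = tf.reshape(image_embeddings, (batch, channels, height * width))   # flatten spatial dims
tokens = tf.transpose(tokens, perm=(0, 2, 1))[:, None]                     # (batch, 1, H*W, channels)
print(tokens.shape)                                                         # (1, 1, 4096, 256)
```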
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "final_attn_token_to_image", None) is not None:
with tf.name_scope(self.final_attn_token_to_image.name):
self.final_attn_token_to_image.build(None)
if getattr(self, "layer_norm_final_attn", None) is not None:
with tf.name_scope(self.layer_norm_final_attn.name):
self.layer_norm_final_attn.build([None, None, None, self.config.hidden_size])
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
class TFSamFeedForward(keras.layers.Layer):
def __init__(
self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False, **kwargs
):
super().__init__(**kwargs)
self.num_layers = num_layers
self.activation = keras.layers.ReLU()
self.proj_in = keras.layers.Dense(hidden_dim, input_shape=(input_dim,), name="proj_in")
self.proj_out = keras.layers.Dense(output_dim, input_shape=(hidden_dim,), name="proj_out")
self.layers = [
keras.layers.Dense(hidden_dim, input_shape=(hidden_dim,), name=f"layers_._{i}")
for i in range(num_layers - 2)
]
self.sigmoid_output = sigmoid_output
self.hidden_dim = hidden_dim
self.input_dim = input_dim
def call(self, hidden_states):
hidden_states = self.proj_in(hidden_states)
hidden_states = self.activation(hidden_states)
for layer in self.layers:
hidden_states = self.activation(layer(hidden_states))
hidden_states = self.proj_out(hidden_states)
if self.sigmoid_output:
hidden_states = tf.sigmoid(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "proj_in", None) is not None:
with tf.name_scope(self.proj_in.name):
self.proj_in.build([None, None, self.input_dim])
if getattr(self, "proj_out", None) is not None:
with tf.name_scope(self.proj_out.name):
self.proj_out.build([None, None, self.hidden_dim])
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build([None, None, self.hidden_dim])
class TFSamMaskDecoder(keras.layers.Layer):
def __init__(self, config: SamMaskDecoderConfig, **kwargs):
super().__init__(**kwargs)
self.hidden_size = config.hidden_size
self.num_multimask_outputs = config.num_multimask_outputs
self.num_mask_tokens = config.num_multimask_outputs + 1
self.transformer = TFSamTwoWayTransformer(config, name="transformer")
self.upscale_conv1 = keras.layers.Conv2DTranspose(
self.hidden_size // 4, kernel_size=2, strides=2, name="upscale_conv1", data_format="channels_first"
)
self.upscale_conv2 = keras.layers.Conv2DTranspose(
self.hidden_size // 8, kernel_size=2, strides=2, name="upscale_conv2", data_format="channels_first"
)
self.upscale_layer_norm = TFSamLayerNorm(
self.hidden_size // 4, data_format="channels_first", name="upscale_layer_norm"
)
self.activation = tf.nn.gelu
mlps_list = []
for i in range(self.num_mask_tokens):
mlps_list += [
TFSamFeedForward(
self.hidden_size,
self.hidden_size,
self.hidden_size // 8,
3,
name=f"output_hypernetworks_mlps_._{i}",
)
]
self.output_hypernetworks_mlps = mlps_list
self.iou_prediction_head = TFSamFeedForward(
self.hidden_size,
config.iou_head_hidden_dim,
self.num_mask_tokens,
config.iou_head_depth,
name="iou_prediction_head",
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
self.iou_token = self.add_weight(shape=(1, self.hidden_size), name="iou_token.weight", trainable=True)
self.mask_tokens = self.add_weight(
shape=(self.num_mask_tokens, self.hidden_size), name="mask_tokens.weight", trainable=True
)
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "upscale_conv1", None) is not None:
with tf.name_scope(self.upscale_conv1.name):
self.upscale_conv1.build([None, self.hidden_size, None, None])
if getattr(self, "upscale_conv2", None) is not None:
with tf.name_scope(self.upscale_conv2.name):
self.upscale_conv2.build([None, self.hidden_size // 4, None, None])
if getattr(self, "upscale_layer_norm", None) is not None:
with tf.name_scope(self.upscale_layer_norm.name):
self.upscale_layer_norm.build(None)
if getattr(self, "iou_prediction_head", None) is not None:
with tf.name_scope(self.iou_prediction_head.name):
self.iou_prediction_head.build(None)
for mlp in self.output_hypernetworks_mlps:
with tf.name_scope(mlp.name):
mlp.build(None)
def call(
self,
image_embeddings: tf.Tensor,
image_positional_embeddings: tf.Tensor,
sparse_prompt_embeddings: tf.Tensor,
dense_prompt_embeddings: tf.Tensor,
multimask_output: bool,
output_attentions: Optional[bool] = None,
class TFSamPositionalEmbedding(keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.scale = config.hidden_size // 2
self.config = config
def build(self, input_shape):
self.positional_embedding = self.add_weight(
name="positional_embedding",
shape=(2, self.config.num_pos_feats),
initializer=keras.initializers.RandomNormal(mean=0.0, stddev=self.scale),
trainable=False,
)
super().build(input_shape)
def call(self, input_coords, input_shape=None):
"""Positionally encode points that are normalized to [0,1]."""
coordinates = tf.identity(input_coords)
if input_shape is not None:
coordinates = tf.stack(
[
tf.cast(coordinates[:, :, :, 0], tf.float32) / input_shape[1],
tf.cast(coordinates[:, :, :, 1], tf.float32) / input_shape[0],
],
axis=-1,
)
coordinates = 2 * coordinates - 1
coordinates = tf.cast(coordinates, self.positional_embedding.dtype)
coordinates = tf.matmul(coordinates, self.positional_embedding)
coordinates = 2 * np.pi * coordinates
return tf.concat([tf.sin(coordinates), tf.cos(coordinates)], axis=-1)
class TFSamMaskEmbedding(keras.layers.Layer):
def __init__(self, config: SamPromptEncoderConfig, **kwargs):
super().__init__(**kwargs)
self.mask_input_channels = config.mask_input_channels // 4
self.activation = ACT2FN[config.hidden_act]
self.conv1 = keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv1")
self.conv2 = keras.layers.Conv2D(config.mask_input_channels, kernel_size=2, strides=2, name="conv2")
self.conv3 = keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3")
self.layer_norm1 = TFSamLayerNorm(self.mask_input_channels, config.layer_norm_eps, name="layer_norm1")
self.layer_norm2 = TFSamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps, name="layer_norm2")
self.config = config
def call(self, masks):
masks = tf.transpose(masks, perm=(0, 2, 3, 1))
hidden_states = self.conv1(masks)
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.activation(hidden_states)
dense_embeddings = self.conv3(hidden_states)
dense_embeddings = tf.transpose(dense_embeddings, perm=(0, 3, 1, 2))
return dense_embeddings
def build(self, input_shape=None):
if self.built:
return
self.built = True
with tf.name_scope("conv1"):
self.conv1.build([None, None, None, 1])
with tf.name_scope("conv2"):
self.conv2.build([None, None, None, self.mask_input_channels])
with tf.name_scope("conv3"):
self.conv3.build([None, None, None, self.mask_input_channels * 4])
with tf.name_scope("layer_norm1"):
self.layer_norm1.build([None, None, None, self.mask_input_channels])
with tf.name_scope("layer_norm2"):
self.layer_norm2.build([None, None, None, self.mask_input_channels * 4])
class TFSamPromptEncoder(keras.layers.Layer):
def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding, **kwargs):
super().__init__(**kwargs)
self.shared_embedding = shared_patch_embedding
self.mask_embed = TFSamMaskEmbedding(config, name="mask_embed")
self.no_mask_embed = None
self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)
self.input_image_size = config.image_size
self.point_embed = []
self.hidden_size = config.hidden_size
self.not_a_point_embed = None
self.config = config
def build(self, input_shape=None):
self.no_mask_embed = self.add_weight(
name="no_mask_embed.weight",
shape=(1, self.hidden_size),
initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
trainable=True,
)
self.point_embed = [
self.add_weight(
name=f"point_embed_._{i}.weight",
shape=(1, self.hidden_size),
initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
trainable=True,
)
for i in range(self.config.num_point_embeddings)
]
self.not_a_point_embed = self.add_weight(
name="not_a_point_embed.weight",
shape=(1, self.hidden_size),
initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
trainable=True,
)
with tf.name_scope("mask_embed"):
self.mask_embed.build(
(None, self.config.mask_input_channels, self.config.image_size, self.config.image_size)
)
if self.built:
return
self.built = True
if getattr(self, "mask_embed", None) is not None:
with tf.name_scope(self.mask_embed.name):
self.mask_embed.build(None)
def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.Tensor:
"""Embeds point prompts."""
points = points + 0.5
if pad:
target_point_shape = (shape_list(points)[0], shape_list(points)[1], 1, shape_list(points)[-1])
target_labels_shape = (shape_list(points)[0], shape_list(points)[1], 1)
padding_point = tf.zeros(target_point_shape, dtype=points.dtype)
padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype)
points = tf.concat([points, padding_point], axis=2)
labels = tf.concat([labels, padding_label], axis=2)
input_shape = (self.input_image_size, self.input_image_size)
point_embedding = self.shared_embedding(points, input_shape)
point_embedding = tf.where(labels[..., None] == -1, self.not_a_point_embed[0], point_embedding)
point_embedding = tf.where(
labels[..., None] != -10,
point_embedding,
tf.zeros_like(point_embedding),
)
point_embedding = tf.where(
(labels == 0)[:, :, :, None], point_embedding + self.point_embed[0], point_embedding
)
point_embedding = tf.where(
(labels == 1)[:, :, :, None], point_embedding + self.point_embed[1], point_embedding
)
return point_embedding
def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor:
"""Embeds box prompts."""
boxes = boxes + 0.5
batch_size, nb_boxes = shape_list(boxes)[:2]
coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2))
input_shape = (self.input_image_size, self.input_image_size)
corner_embedding = self.shared_embedding(coords, input_shape)
corner_embedding += tf.where(
tf.range(shape_list(corner_embedding)[2])[None, None, :, None] == 0,
self.point_embed[2][0],
self.point_embed[3][0],
)
return corner_embedding
def call(
self,
batch_size: Optional[int],
input_points: Optional[Tuple[tf.Tensor, tf.Tensor]],
input_labels: tf.Tensor | None,
input_boxes: tf.Tensor | None,
input_masks: tf.Tensor | None,
"""
Embeds different types of prompts, returning both sparse and dense embeddings.
Args:
points (`tf.Tensor`, *optional*):
point coordinates and labels to embed.
boxes (`tf.Tensor`, *optional*):
boxes to embed
masks (`tf.Tensor`, *optional*):
masks to embed
"""
sparse_embeddings = None
if input_points is not None:
batch_size, point_batch_size = shape_list(input_points)[:2]
if input_labels is None:
raise ValueError("If points are provided, labels must also be provided.")
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
sparse_embeddings = tf.zeros(
(batch_size, point_batch_size, 0, self.hidden_size), dtype=point_embeddings.dtype
)
sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2)
if input_boxes is not None:
batch_size = shape_list(input_boxes)[0]
box_embeddings = self._embed_boxes(input_boxes)
if sparse_embeddings is None:
sparse_embeddings = box_embeddings
else:
sparse_embeddings = tf.concat([sparse_embeddings, box_embeddings], axis=2)
if input_masks is not None:
dense_embeddings = self.mask_embed(input_masks)
else:
dense_embeddings = self.no_mask_embed[0]
dense_embeddings = tf.reshape(dense_embeddings, (1, -1, 1, 1))
dense_embeddings = tf.tile(
dense_embeddings, (batch_size, 1, self.image_embedding_size[0], self.image_embedding_size[1])
)
if sparse_embeddings is None:
sparse_embeddings = tf.zeros((batch_size, 0, 1, self.hidden_size), dtype=dense_embeddings.dtype)
return sparse_embeddings, dense_embeddings
class TFSamVisionAttention(keras.layers.Layer):
"""Multi-head Attention block with relative position embeddings."""
def __init__(self, config, window_size, **kwargs):
super().__init__(**kwargs)
input_size = (
(config.image_size // config.patch_size, config.image_size // config.patch_size)
if window_size == 0
else (window_size, window_size)
)
self.input_size = input_size
self.num_attention_heads = config.num_attention_heads
head_dim = config.hidden_size // config.num_attention_heads
self.head_dim = head_dim
self.scale = head_dim**-0.5
self.dropout = config.attention_dropout
self.qkv = keras.layers.Dense(config.hidden_size * 3, use_bias=config.qkv_bias, name="qkv")
self.proj = keras.layers.Dense(config.hidden_size, name="proj")
self.use_rel_pos = config.use_rel_pos
if self.use_rel_pos:
if input_size is None:
raise ValueError("Input size must be provided if using relative positional encoding.")
self.config = config
def build(self, input_shape=None):
if self.input_size is not None:
self.rel_pos_h = self.add_weight(
shape=(2 * self.input_size[0] - 1, self.head_dim), initializer="zeros", name="rel_pos_h"
)
self.rel_pos_w = self.add_weight(
shape=(2 * self.input_size[1] - 1, self.head_dim), initializer="zeros", name="rel_pos_w"
)
if self.built:
return
self.built = True
if getattr(self, "qkv", None) is not None:
with tf.name_scope(self.qkv.name):
self.qkv.build([None, None, self.config.hidden_size])
if getattr(self, "proj", None) is not None:
with tf.name_scope(self.proj.name):
self.proj.build([None, None, self.config.hidden_size])
def get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int):
size of the query.
k_size (int):
size of key k.
rel_pos (`tf.Tensor`):
relative position embeddings (L, channel).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
if rel_pos.shape[0] != max_rel_dist:
rel_pos_resized = tf.image.resize(
tf.reshape(rel_pos, (1, rel_pos.shape[0], -1)),
size=(max_rel_dist, rel_pos.shape[1]),
method="bilinear",
)
rel_pos_resized = tf.reshape(rel_pos_resized, (-1, max_rel_dist))
else:
rel_pos_resized = rel_pos
q_coords = tf.expand_dims(tf.range(q_size, dtype=tf.float32), 1) * max(k_size / q_size, 1.0)
k_coords = tf.expand_dims(tf.range(k_size, dtype=tf.float32), 0) * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32))
def add_decomposed_rel_pos(
self,
attn: tf.Tensor,
query: tf.Tensor,
rel_pos_h: tf.Tensor,
rel_pos_w: tf.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> tf.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
Args:
attn (`tf.Tensor`):
attention map.
query (`tf.Tensor`):
query q in the attention layer with shape (batch_size, query_height * query_width, channel).
rel_pos_h (`tf.Tensor`):
relative position embeddings (Lh, channel) for height axis.
rel_pos_w (`tf.Tensor`):
relative position embeddings (Lw, channel) for width axis.
q_size (tuple):
spatial sequence size of query q with (query_height, query_width).
k_size (tuple):
spatial sequence size of key k with (key_height, key_width).
Returns:
attn (`tf.Tensor`):
attention map with added relative positional embeddings.
"""
query_height, query_width = q_size
key_height, key_width = k_size
relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
batch_size, _, dim = shape_list(query)
reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim))
rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
attn = tf.reshape(attn, (batch_size, query_height, query_width, key_height, key_width))
attn = attn + tf.expand_dims(rel_h, axis=-1) + tf.expand_dims(rel_w, axis=-2)
attn = tf.reshape(attn, (batch_size, query_height * query_width, key_height * key_width))
return attn
def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor:
batch_size, height, width, _ = shape_list(hidden_states)
qkv = tf.reshape(self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1))
qkv = tf.transpose(qkv, perm=(2, 0, 3, 1, 4))
query, key, value = tf.unstack(
tf.reshape(qkv, (3, batch_size * self.num_attention_heads, height * width, -1)), axis=0
)
attn_weights = tf.matmul(query * self.scale, key, transpose_b=True)
if self.use_rel_pos:
attn_weights = self.add_decomposed_rel_pos(
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
)
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if training:
attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout)
else:
attn_probs = attn_weights
attn_output = tf.reshape(attn_probs @ value, (batch_size, self.num_attention_heads, height, width, -1))
attn_output = tf.transpose(attn_output, perm=(0, 2, 3, 1, 4))
attn_output = tf.reshape(attn_output, (batch_size, height, width, self.config.hidden_size))
attn_output = self.proj(attn_output)
if output_attentions:
outputs = (attn_output, attn_weights)
else:
outputs = (attn_output, None)
return outputs
class TFSamVisionLayer(keras.layers.Layer):
def __init__(self, config, window_size, **kwargs):
super().__init__(**kwargs)
self.layer_norm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1")
self.attn = TFSamVisionAttention(config, window_size, name="attn")
self.layer_norm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2")
self.mlp = TFSamMLPBlock(config, name="mlp")
self.window_size = window_size
self.config = config
def window_partition(self, hidden_states: tf.Tensor, window_size: int) -> Tuple[tf.Tensor, Tuple[int, int]]:
batch_size, height, width, channel = shape_list(hidden_states)
pad_h = (window_size - height % window_size) % window_size
pad_w = (window_size - width % window_size) % window_size
if pad_h > 0 or pad_w > 0:
hidden_states = tf.pad(hidden_states, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]])
pad_height, pad_width = height + pad_h, width + pad_w
hidden_states = tf.reshape(
hidden_states,
[batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel],
)
windows = tf.reshape(
tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [-1, window_size, window_size, channel]
)
return windows, (pad_height, pad_width)
def window_unpartition(
self, windows: tf.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int]
) -> tf.Tensor:
pad_height, pad_width = padding_shape
height, width = original_shape
batch_size = shape_list(windows)[0] // (pad_height * pad_width // window_size // window_size)
hidden_states = tf.reshape(
windows, [batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1]
)
hidden_states = tf.reshape(
tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [batch_size, pad_height, pad_width, -1]
)
if pad_height > height or pad_width > width:
hidden_states = hidden_states[:, :height, :width, :]
return hidden_states
def call(
self,
hidden_states: tf.Tensor,
output_attentions: Optional[bool] = False,
training: Optional[bool] = False,
) -> Tuple[tf.Tensor]:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
if self.window_size > 0:
height, width = hidden_states.shape[1], hidden_states.shape[2]
hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)
hidden_states, attn_weights = self.attn(
hidden_states=hidden_states,
output_attentions=output_attentions,
training=training,
)
if self.window_size > 0:
hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width))
hidden_states = residual + hidden_states
layernorm_output = self.layer_norm2(hidden_states)
hidden_states = hidden_states + self.mlp(layernorm_output)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer_norm1", None) is not None:
with tf.name_scope(self.layer_norm1.name):
self.layer_norm1.build([None, None, None, self.config.hidden_size])
if getattr(self, "attn", None) is not None:
with tf.name_scope(self.attn.name):
self.attn.build(None)
if getattr(self, "layer_norm2", None) is not None:
with tf.name_scope(self.layer_norm2.name):
self.layer_norm2.build([None, None, None, self.config.hidden_size])
if getattr(self, "mlp", None) is not None:
with tf.name_scope(self.mlp.name):
self.mlp.build(None)
class TFSamVisionEncoder(keras.layers.Layer):
def __init__(self, config: SamVisionConfig, **kwargs):
super().__init__(**kwargs)
self.config = config
self.image_size = config.image_size
self.patch_embed = TFSamPatchEmbeddings(config, name="patch_embed")
self.pos_embed = None
self.layers = []
for i in range(config.num_hidden_layers):
layer = TFSamVisionLayer(
config,
window_size=config.window_size if i not in config.global_attn_indexes else 0,
name=f"layers_._{i}",
)
self.layers.append(layer)
self.neck = TFSamVisionNeck(config, name="neck")
def build(self, input_shape=None):
if self.built:
return
self.built = True
if self.config.use_abs_pos:
self.pos_embed = self.add_weight(
shape=[
1,
self.config.image_size // self.config.patch_size,
self.config.image_size // self.config.patch_size,
self.config.hidden_size,
],
initializer="zeros",
trainable=True,
name="pos_embed",
)
if getattr(self, "patch_embed", None) is not None:
with tf.name_scope(self.patch_embed.name):
self.patch_embed.build(None)
if getattr(self, "neck", None) is not None:
with tf.name_scope(self.neck.name):
self.neck.build(None)
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
def get_input_embeddings(self):
return self.patch_embed
def call(
self,
pixel_values: tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
) -> Union[Tuple, TFSamVisionEncoderOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.patch_embed(pixel_values)
if self.pos_embed is not None:
hidden_states = hidden_states + self.pos_embed
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(hidden_states, output_attentions=output_attentions, training=training)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
hidden_states = self.neck(hidden_states)
if not return_dict:
outputs = (hidden_states,)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_self_attentions,)
return outputs
return TFSamVisionEncoderOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
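# Rough shape walk-through of the encoder above (illustrative sketch; the concrete numbers
# assume a 1024x1024 input, patch_size=16, hidden_size=768 and a 256-channel neck, which are
# typical SAM ViT-B style values rather than anything read from this file).
image_size, patch_size, hidden_size, neck_channels = 1024, 16, 768, 256
grid = image_size // patch_size                      # 64 patches per side
patch_embed_shape = (1, grid, grid, hidden_size)     # after patch_embed (+ optional pos_embed)
layer_output_shape = patch_embed_shape               # every TFSamVisionLayer keeps this shape
neck_output_shape = (1, neck_channels, grid, grid)   # the neck projects and returns channels-first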
@add_start_docstrings(
"Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
" optional 2D location and bounding boxes.",
SAM_START_DOCSTRING,
)
class TFSamModel(TFSamPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]
def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
self.shared_image_embedding = TFSamPositionalEmbedding(config.vision_config, name="shared_image_embedding")
self.vision_encoder = TFSamVisionEncoder(config.vision_config, name="vision_encoder")
self.prompt_encoder = TFSamPromptEncoder(
config.prompt_encoder_config, self.shared_image_embedding, name="prompt_encoder"
)
self.mask_decoder = TFSamMaskDecoder(config.mask_decoder_config, name="mask_decoder")
self.config = config
def get_input_embeddings(self):
return self.vision_encoder.get_input_embeddings()
def get_image_wide_positional_embeddings(self):
size = self.config.prompt_encoder_config.image_embedding_size
grid = tf.ones((size, size))
y_embed = tf.math.cumsum(grid, axis=0) - 0.5
x_embed = tf.math.cumsum(grid, axis=1) - 0.5
y_embed = y_embed / size
x_embed = x_embed / size
positional_embedding = self.shared_image_embedding(tf.stack([x_embed, y_embed], axis=-1))
return tf.expand_dims(tf.transpose(positional_embedding, perm=[2, 0, 1]), axis=0)
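# Illustrative sketch (not in the original file): the cumsum trick above just builds normalized
# (x, y) cell-center coordinates in (0, 1), which TFSamPositionalEmbedding then maps to its
# positional features. Shown here for a tiny 4x4 grid with NumPy.
import numpy as np

size = 4
grid = np.ones((size, size))
y_embed = (np.cumsum(grid, axis=0) - 0.5) / size  # increases down the rows: 0.125, 0.375, 0.625, 0.875
x_embed = (np.cumsum(grid, axis=1) - 0.5) / size  # increases across the columns: 0.125, 0.375, 0.625, 0.875
coords = np.stack([x_embed, y_embed], axis=-1)    # shape (4, 4, 2), one (x, y) per cell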
def get_image_embeddings(
self,
pixel_values,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
r"""
Returns the image embeddings by passing the pixel values through the vision encoder.
Args:
pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
Input pixel values
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.TFModelOutput`] instead of a plain tuple.
"""
vision_output = self.vision_encoder(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
image_embeddings = vision_output[0]
return image_embeddings
def get_prompt_embeddings(
self,
input_points: tf.Tensor | None = None,
input_labels: tf.Tensor | None = None,
input_boxes: tf.Tensor | None = None,
input_masks: tf.Tensor | None = None,
):
r"""
Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
Args:
input_points (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
Optional input points for the prompt encoder. The padding of the point is automatically done by the
processor. `point_batch_size` refers to the number of masks that we want the model to predict per
point. The model will output `point_batch_size` times 3 masks in total.
input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
processor, or can be fed by the user.
input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes_per_image, 4)`):
Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
processor. Users can also pass the input boxes manually.
input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`):
Optional input masks for the prompt encoder.
"""
prompt_output = self.prompt_encoder(
input_points=input_points,
input_labels=input_labels,
input_boxes=input_boxes,
input_masks=input_masks,
)
return prompt_output
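# Illustrative sketch (not in the original file): prompt tensors with the shapes documented
# above. The coordinate values are arbitrary and are expected in the resized input frame
# (longest edge 1024); SamProcessor normally takes care of that rescaling for you.
import tensorflow as tf

input_points = tf.constant([[[[500.0, 375.0], [620.0, 400.0]]]])  # (batch=1, point_batch=1, nb_points=2, 2)
input_labels = tf.constant([[[1, 1]]])                            # (batch=1, point_batch=1, nb_points=2)
input_boxes = tf.constant([[[100.0, 100.0, 700.0, 600.0]]])       # (batch=1, nb_boxes=1, 4) as (x1, y1, x2, y2)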
@unpack_inputs
@add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING)
def call(
self,
pixel_values: TFModelInputType | None = None,
input_points: tf.Tensor | None = None,
input_labels: tf.Tensor | None = None,
input_boxes: tf.Tensor | None = None,
input_masks: tf.Tensor | None = None,
image_embeddings: tf.Tensor | None = None,
multimask_output: bool = True,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
training: bool = False,
**kwargs,
):
def serving_output(self, output: TFSamImageSegmentationOutput) -> TFSamImageSegmentationOutput:
hs = tf.convert_to_tensor(output.vision_hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.vision_attentions) if self.config.output_attentions else None
return TFSamImageSegmentationOutput(
iou_scores=output.iou_scores,
pred_masks=output.pred_masks,
vision_hidden_states=hs if self.config.output_hidden_states else None,
vision_attentions=attns if self.config.output_attentions else None,
mask_decoder_attentions=output.mask_decoder_attentions if self.config.output_attentions else None,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "shared_image_embedding", None) is not None:
with tf.name_scope(self.shared_image_embedding.name):
self.shared_image_embedding.build(None)
if getattr(self, "vision_encoder", None) is not None:
with tf.name_scope(self.vision_encoder.name):
self.vision_encoder.build(None)
if getattr(self, "prompt_encoder", None) is not None:
with tf.name_scope(self.prompt_encoder.name):
self.prompt_encoder.build(None)
if getattr(self, "mask_decoder", None) is not None:
with tf.name_scope(self.mask_decoder.name):
self.mask_decoder.build(None)
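# Hedged usage sketch (not part of the original file): running TFSamModel end to end with
# SamProcessor. It assumes the Hub checkpoint "facebook/sam-vit-base" provides TF weights and
# that the COCO demo image URL is reachable; adapt both to your environment.
import requests
from PIL import Image
from transformers import SamProcessor, TFSamModel

model = TFSamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

img_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]  # one (x, y) point prompt for one image

inputs = processor(raw_image, input_points=input_points, return_tensors="tf")
outputs = model(**inputs)

# Upscale the low-resolution mask logits back to the original image size.
masks = processor.post_process_masks(
    outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"], return_tensors="tf"
)
scores = outputs.iou_scores  # one predicted IoU per candidate mask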
.\models\sam\processing_sam.py
"""
Processor class for SAM.
"""
from copy import deepcopy
from typing import Optional, Union
import numpy as np
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...utils import TensorType, is_tf_available, is_torch_available
if is_torch_available():
import torch
if is_tf_available():
import tensorflow as tf
class SamProcessor(ProcessorMixin):
r"""
Constructs a SAM processor which wraps a SAM image processor and a 2D points & bounding boxes processor into a
single processor.
[`SamProcessor`] offers all the functionalities of [`SamImageProcessor`]. See the docstring of
[`~SamImageProcessor.__call__`] for more information.
Args:
image_processor (`SamImageProcessor`):
An instance of [`SamImageProcessor`]. The image processor is a required input.
"""
attributes = ["image_processor"]
image_processor_class = "SamImageProcessor"
def __init__(self, image_processor):
super().__init__(image_processor)
self.current_processor = self.image_processor
self.point_pad_value = -10
self.target_size = self.image_processor.size["longest_edge"]
def __call__(
self,
images=None,
segmentation_maps=None,
input_points=None,
input_labels=None,
input_boxes=None,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
) -> BatchEncoding:
def _pad_points_and_labels(self, input_points, input_labels):
"""
The method pads the 2D points and labels to the maximum number of points in the batch.
"""
expected_nb_points = max([point.shape[0] for point in input_points])
processed_input_points = []
for i, point in enumerate(input_points):
if point.shape[0] != expected_nb_points:
point = np.concatenate(
[point, np.zeros((expected_nb_points - point.shape[0], 2)) + self.point_pad_value], axis=0
)
input_labels[i] = np.append(input_labels[i], [self.point_pad_value])
processed_input_points.append(point)
input_points = processed_input_points
return input_points, input_labels
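# Illustrative sketch (not in the original file) of the padding above for a ragged batch:
# image 0 has 1 point and image 1 has 3, so image 0 is padded with the sentinel value
# self.point_pad_value (-10) up to the largest point count in the batch.
import numpy as np

pad_value = -10
points = [np.array([[10.0, 20.0]]), np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])]
expected_nb_points = max(p.shape[0] for p in points)
padded = [
    np.concatenate([p, np.zeros((expected_nb_points - p.shape[0], 2)) + pad_value], axis=0) for p in points
]
# padded[0] -> [[10., 20.], [-10., -10.], [-10., -10.]]; padded[1] is unchanged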
def _normalize_coordinates(
    self, target_size: int, coords: np.ndarray, original_size, is_bounding_box=False
) -> np.ndarray:
    """
    Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format.
    """
    old_h, old_w = original_size
    new_h, new_w = self.image_processor._get_preprocess_shape(original_size, longest_edge=target_size)
    coords = deepcopy(coords).astype(float)
    if is_bounding_box:
        coords = coords.reshape(-1, 2, 2)
    coords[..., 0] = coords[..., 0] * (new_w / old_w)
    coords[..., 1] = coords[..., 1] * (new_h / old_h)
    if is_bounding_box:
        coords = coords.reshape(-1, 4)
    return coords
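# Worked example (illustrative, not in the original file): an image of original size
# (H, W) = (480, 640) is resized so its longest edge becomes target_size = 1024, i.e. the new
# size from _get_preprocess_shape is (768, 1024). A point at (x=320, y=240) therefore maps
# to (512.0, 384.0).
old_h, old_w = 480, 640
target_size = 1024
scale = target_size / max(old_h, old_w)                            # 1.6
new_h, new_w = int(old_h * scale + 0.5), int(old_w * scale + 0.5)  # (768, 1024)
x, y = 320 * (new_w / old_w), 240 * (new_h / old_h)                # (512.0, 384.0)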
def _check_and_preprocess_points(
self,
input_points=None,
input_labels=None,
input_boxes=None,
):
r"""
Check and preprocesses the 2D points, labels and bounding boxes. It checks if the input is valid and if they
are, it converts the coordinates of the points and bounding boxes. If a user passes directly a `torch.Tensor`,
it is converted to a `numpy.ndarray` and then to a `list`.
"""
if input_points is not None:
if hasattr(input_points, "numpy"):
input_points = input_points.numpy().tolist()
if not isinstance(input_points, list) or not isinstance(input_points[0], list):
raise ValueError("Input points must be a list of list of floating points.")
input_points = [np.array(input_point) for input_point in input_points]
else:
input_points = None
if input_labels is not None:
if hasattr(input_labels, "numpy"):
input_labels = input_labels.numpy().tolist()
if not isinstance(input_labels, list) or not isinstance(input_labels[0], list):
raise ValueError("Input labels must be a list of list integers.")
input_labels = [np.array(label) for label in input_labels]
else:
input_labels = None
if input_boxes is not None:
if hasattr(input_boxes, "numpy"):
input_boxes = input_boxes.numpy().tolist()
if (
not isinstance(input_boxes, list)
or not isinstance(input_boxes[0], list)
or not isinstance(input_boxes[0][0], list)
):
raise ValueError("Input boxes must be a list of list of list of floating points.")
input_boxes = [np.array(box).astype(np.float32) for box in input_boxes]
else:
input_boxes = None
return input_points, input_labels, input_boxes
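# Illustrative sketch (not in the original file): input nestings the checks above accept.
# Points and labels are lists of per-image lists; boxes need one extra nesting level.
valid_points = [[[450.0, 600.0], [500.0, 620.0]]]  # 1 image, 2 (x, y) points
valid_labels = [[1, 1]]                            # matching foreground labels
valid_boxes = [[[100.0, 100.0, 700.0, 600.0]]]     # 1 image, 1 box as (x1, y1, x2, y2)
# torch.Tensor / tf.Tensor inputs are also fine: anything exposing .numpy() is first
# converted to nested lists by the hasattr(..., "numpy") branches above.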
@property
def model_input_names(self):
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(image_processor_input_names))
def post_process_masks(self, *args, **kwargs):
return self.image_processor.post_process_masks(*args, **kwargs)
.\models\sam\__init__.py
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tf_available,
is_torch_available,
is_vision_available,
)
_import_structure = {
"configuration_sam": [
"SAM_PRETRAINED_CONFIG_ARCHIVE_MAP",
"SamConfig",
"SamMaskDecoderConfig",
"SamPromptEncoderConfig",
"SamVisionConfig",
],
"processing_sam": ["SamProcessor"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_sam"] = [
"SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
"SamModel",
"SamPreTrainedModel",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_sam"] = [
"TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSamModel",
"TFSamPreTrainedModel",
]
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_sam"] = ["SamImageProcessor"]
if TYPE_CHECKING:
from .configuration_sam import (
SAM_PRETRAINED_CONFIG_ARCHIVE_MAP,
SamConfig,
SamMaskDecoderConfig,
SamPromptEncoderConfig,
SamVisionConfig,
)
from .processing_sam import SamProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_sam import SAM_PRETRAINED_MODEL_ARCHIVE_LIST, SamModel, SamPreTrainedModel
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST, TFSamModel, TFSamPreTrainedModel
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_sam import SamImageProcessor
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
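# Hedged sketch (not in the original file): what the lazy pattern above buys you. Importing the
# package only records names in _import_structure; the heavy torch/TF/vision modules are only
# imported on first attribute access (this example assumes PyTorch is installed).
from transformers.models import sam

print(type(sam))          # a _LazyModule standing in for the package
model_cls = sam.SamModel  # attribute access triggers `from .modeling_sam import SamModel`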
.\models\seamless_m4t\configuration_seamless_m4t.py
""" SeamlessM4T 模型配置"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
SEAMLESS_M4T_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/hf-seamless-m4t-medium": "https://huggingface.co/facebook/hf-seamless-m4t-medium/resolve/main/config.json",
}
class SeamlessM4TConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`~SeamlessM4TModel`]. It is used to instantiate a
SeamlessM4T model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a configuration similar to that of the SeamlessM4T
["facebook/hf-seamless-m4t-medium"](https://huggingface.co/facebook/hf-seamless-m4t-medium) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation of [`PretrainedConfig`] for more information.
```
>>> from transformers import SeamlessM4TModel, SeamlessM4TConfig
>>> # Initializing a SeamlessM4T "facebook/hf-seamless-m4t-medium" style configuration
>>> configuration = SeamlessM4TConfig()
>>> # Initializing a model from the "facebook/hf-seamless-m4t-medium" style configuration
>>> model = SeamlessM4TModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "seamless_m4t"
def __init__(
self,
vocab_size=256102,
t2u_vocab_size=10082,
hidden_size=1024,
initializer_range=0.02,
layer_norm_eps=1e-5,
use_cache=True,
max_position_embeddings=1024,
is_encoder_decoder=True,
encoder_layerdrop=0.05,
decoder_layerdrop=0.05,
activation_function="relu",
dropout=0.1,
attention_dropout=0.1,
activation_dropout=0.0,
scale_embedding=True,
encoder_layers=24,
encoder_ffn_dim=8192,
encoder_attention_heads=16,
decoder_layers=24,
decoder_ffn_dim=8192,
decoder_attention_heads=16,
decoder_start_token_id=3,
max_new_tokens=256,
pad_token_id=0,
bos_token_id=2,
eos_token_id=3,
speech_encoder_layers=24,
speech_encoder_attention_heads=16,
speech_encoder_intermediate_size=4096,
speech_encoder_hidden_act="swish",
speech_encoder_dropout=0.0,
add_adapter=True,
speech_encoder_layerdrop=0.1,
feature_projection_input_dim=160,
num_conv_pos_embeddings=128,
num_conv_pos_embedding_groups=16,
adaptor_kernel_size=8,
adaptor_stride=8,
adaptor_dropout=0.1,
num_adapter_layers=1,
position_embeddings_type="relative",
rotary_embedding_base=10000,
max_source_positions=4096,
conv_depthwise_kernel_size=31,
t2u_bos_token_id=0,
t2u_pad_token_id=1,
t2u_eos_token_id=2,
t2u_decoder_start_token_id=2,
t2u_max_new_tokens=1024,
t2u_encoder_layers=6,
t2u_encoder_ffn_dim=8192,
t2u_encoder_attention_heads=16,
t2u_decoder_layers=6,
t2u_decoder_ffn_dim=8192,
t2u_decoder_attention_heads=16,
t2u_max_position_embeddings=2048,
sampling_rate=16000,
upsample_initial_channel=512,
upsample_rates=[5, 4, 4, 2, 2],
upsample_kernel_sizes=[11, 8, 8, 4, 4],
resblock_kernel_sizes=[3, 7, 11],
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
leaky_relu_slope=0.1,
unit_hifi_gan_vocab_size=10000,
unit_embed_dim=1280,
lang_embed_dim=256,
spkr_embed_dim=256,
vocoder_num_langs=36,
vocoder_num_spkrs=200,
variance_predictor_kernel_size=3,
var_pred_dropout=0.5,
vocoder_offset=4,
**kwargs,