Transformers 源码解析(八十三)
.\models\oneformer\image_processing_oneformer.py
"""
Image processor class for OneFormer.
"""
import json
import os
import warnings
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import numpy as np
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import RepositoryNotFoundError
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
PaddingMode,
get_resize_output_image_size,
pad,
rescale,
resize,
to_channel_dimension_format,
)
from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
validate_kwargs,
validate_preprocess_arguments,
)
from ...utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
TensorType,
is_torch_available,
is_torch_tensor,
logging,
)
# Module-level logger named after this module.
logger = logging.get_logger(__name__)

# torch is an optional dependency: it is only imported when the availability check succeeds,
# so the torch-based helpers below require it to be installed.
if is_torch_available():
    import torch
    from torch import nn
def max_across_indices(values: Iterable[Any]) -> List[Any]:
    """
    Compute the position-wise maximum over an iterable of sequences.

    The sequences are transposed with `zip`, so the result has one entry per
    position and is as long as the shortest input sequence.
    """
    return list(map(max, zip(*values)))
def get_max_height_width(
    images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> List[int]:
    """
    Return the largest height and the largest width found in a batch of images.

    Args:
        images (`List[np.ndarray]`):
            Batch of images; each image shape must have exactly three entries.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            Channel layout of the images. Inferred from the first image when not given.

    Returns:
        A `(max_height, max_width)` tuple.

    Raises:
        `ValueError`: If the channel layout is neither channels-first nor channels-last.
    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(images[0])

    is_channels_first = input_data_format == ChannelDimension.FIRST
    if not is_channels_first and not input_data_format == ChannelDimension.LAST:
        raise ValueError(f"Invalid channel dimension format: {input_data_format}")

    per_axis_max = max_across_indices([img.shape for img in images])
    if is_channels_first:
        _, max_height, max_width = per_axis_max
    else:
        max_height, max_width, _ = per_axis_max
    return (max_height, max_width)
def make_pixel_mask(
    image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> np.ndarray:
    """
    Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.

    Args:
        image (`np.ndarray`):
            Image to make the pixel mask for.
        output_size (`Tuple[int, int]`):
            Output size of the mask.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            Channel layout of `image`, forwarded to `get_image_size` as `channel_dim`.

    Returns:
        `np.ndarray`: An `int64` array of shape `output_size` whose top-left
        `(image_height, image_width)` region is 1 and the rest is 0.
    """
    input_height, input_width = get_image_size(image, channel_dim=input_data_format)
    mask = np.zeros(output_size, dtype=np.int64)
    # The image sits in the top-left corner; everything else is padding.
    mask[:input_height, :input_width] = 1
    return mask
def binary_mask_to_rle(mask):
    """
    Encode a binary mask of shape `(height, width)` as a run-length encoding (RLE).

    Args:
        mask (`torch.Tensor` or `numpy.array`):
            Binary mask where 0 denotes background and 1 denotes the target segment_id or class_id.

    Returns:
        `List`: Flat list alternating `start, length, start, length, ...` over the flattened mask,
        with starts counted from 1. Refer to the COCO API for more information about the RLE format.
    """
    if is_torch_tensor(mask):
        mask = mask.numpy()

    flat = mask.flatten()
    # Pad with a zero on both sides so runs touching either border still produce a transition.
    padded = np.concatenate([[0], flat, [0]])
    transitions = np.where(padded[1:] != padded[:-1])[0] + 1
    # Transitions alternate run-start / run-end; turn every run-end into a run length.
    transitions[1::2] -= transitions[::2]
    return list(transitions)
def convert_segmentation_to_rle(segmentation):
    """
    Encode every id of a `(height, width)` segmentation map as a run-length encoding (RLE).

    Args:
        segmentation (`torch.Tensor` or `numpy.array`):
            Segmentation map where each value denotes a segment or class id.

    Returns:
        `List[List]`: One run-length encoding per distinct id, in the order returned by `torch.unique`.
    """
    return [
        binary_mask_to_rle(torch.where(segmentation == segment_id, 1, 0))
        for segment_id in torch.unique(segmentation)
    ]
def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
    """
    Filter out queries that predict the "no object" class or whose score is too low.

    A query is kept when its label differs from `num_labels` and its score is strictly greater than
    `object_mask_threshold`.

    Args:
        masks (`torch.Tensor`):
            A tensor of shape `(num_queries, height, width)`.
        scores (`torch.Tensor`):
            A tensor of shape `(num_queries)`.
        labels (`torch.Tensor`):
            A tensor of shape `(num_queries)`.
        object_mask_threshold (`float`):
            Minimum score (exclusive) a query needs to be kept.
        num_labels (`int`):
            Label id that marks a query to be discarded.

    Returns:
        The filtered `(masks, scores, labels)` triple.

    Raises:
        `ValueError`: Raised when the first dimension doesn't match in all input tensors.
    """
    num_masks, num_scores, num_label_entries = masks.shape[0], scores.shape[0], labels.shape[0]
    if num_masks != num_scores or num_scores != num_label_entries:
        raise ValueError("mask, scores and labels must have the same shape!")

    is_real_class = labels != num_labels
    is_confident = scores > object_mask_threshold
    keep = is_real_class & is_confident
    return masks[keep], scores[keep], labels[keep]
def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8):
    """
    Decide whether query `k` yields a usable segment.

    The segment is valid when query `k` wins at least one pixel in `mask_labels`, its thresholded
    probability map (`mask_probs[k] >= mask_threshold`) is non-empty, and the ratio between the two
    areas is strictly greater than `overlap_mask_area_threshold`.

    Returns:
        A `(mask_exists, mask_k)` pair, where `mask_k` is the boolean mask `mask_labels == k`.
    """
    mask_k = mask_labels == k
    assigned_area = mask_k.sum()
    thresholded_area = (mask_probs[k] >= mask_threshold).sum()

    mask_exists = assigned_area > 0 and thresholded_area > 0
    # Reject segments that lost most of their thresholded area to other queries.
    if mask_exists and not (assigned_area / thresholded_area).item() > overlap_mask_area_threshold:
        mask_exists = False
    return mask_exists, mask_k
def compute_segments(
    mask_probs,
    pred_scores,
    pred_labels,
    mask_threshold: float = 0.5,
    overlap_mask_area_threshold: float = 0.8,
    label_ids_to_fuse: Optional[Set[int]] = None,
    target_size: Tuple[int, int] = None,
):
    """
    Build a segment-id map and per-segment metadata from per-query mask probabilities.

    Args:
        mask_probs (`torch.Tensor`):
            Per-query mask probabilities, indexed as `(num_queries, height, width)`.
        pred_scores (`torch.Tensor`):
            One score per query.
        pred_labels (`torch.Tensor`):
            One class id per query.
        mask_threshold (`float`, *optional*, defaults to 0.5):
            Threshold used when measuring each query's own mask area.
        overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
            Minimum ratio of assigned area to thresholded area for a query to be kept.
        label_ids_to_fuse (`Set[int]`, *optional*):
            Class ids whose segments share a single segment id. `None` means no class is fused.
        target_size (`Tuple[int, int]`, *optional*):
            If given, `mask_probs` is bilinearly resized to this `(height, width)` first.

    Returns:
        `(segmentation, segments)`: an `int32` map of shape `(height, width)` holding segment ids
        (0 where no segment was assigned) and a list of dicts with keys `id`, `label_id`,
        `was_fused` and `score`.
    """
    # The default used to reach `pred_class in None` and raise a TypeError.
    if label_ids_to_fuse is None:
        label_ids_to_fuse = set()

    height = mask_probs.shape[1] if target_size is None else target_size[0]
    width = mask_probs.shape[2] if target_size is None else target_size[1]

    segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device)
    segments: List[Dict] = []

    if target_size is not None:
        mask_probs = nn.functional.interpolate(
            mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
        )[0]

    # Weight each mask by its query score, then let every pixel pick its best query.
    # NOTE: this multiplication is in place; without `target_size` it scales the caller's tensor.
    mask_probs *= pred_scores.view(-1, 1, 1)
    mask_labels = mask_probs.argmax(0)

    # Highest id handed out so far. Kept separate from the id of the current segment so that
    # revisiting a fused class does not rewind the counter and make a later segment reuse an id.
    last_new_segment_id = 0
    # Maps a fused class id to the segment id shared by all of its segments.
    stuff_memory_list: Dict[int, int] = {}

    for k in range(pred_labels.shape[0]):
        pred_class = pred_labels[k].item()
        should_fuse = pred_class in label_ids_to_fuse

        mask_exists, mask_k = check_segment_validity(
            mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold
        )
        if not mask_exists:
            continue

        if pred_class in stuff_memory_list:
            segment_id = stuff_memory_list[pred_class]
        else:
            last_new_segment_id += 1
            segment_id = last_new_segment_id

        segmentation[mask_k] = segment_id
        segments.append(
            {
                "id": segment_id,
                "label_id": pred_class,
                "was_fused": should_fuse,
                "score": round(pred_scores[k].item(), 6),
            }
        )
        if should_fuse:
            stuff_memory_list[pred_class] = segment_id

    return segmentation, segments
def convert_segmentation_map_to_binary_masks(
    segmentation_map: "np.ndarray",
    instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
    ignore_index: Optional[int] = None,
    reduce_labels: bool = False,
):
    """
    Split a segmentation map into one binary mask per label.

    Args:
        segmentation_map (`np.ndarray`):
            Map whose values are label (or instance) ids.
        instance_id_to_semantic_id (`Dict[int, int]`, *optional*):
            Mapping from the ids found in the map to class ids. When given, the returned labels are
            the mapped class ids instead of the raw ids.
        ignore_index (`int`, *optional*):
            Id that is excluded from the output. Required when `reduce_labels` is `True`.
        reduce_labels (`bool`, *optional*, defaults to `False`):
            If `True`, id 0 is replaced by `ignore_index` and every other id is decremented by 1.

    Returns:
        `(binary_masks, labels)`: a `float32` array of shape `(num_labels, *segmentation_map.shape)`
        and an `int64` array of shape `(num_labels,)`. Both are empty along the first axis when no
        label remains after removing `ignore_index`.

    Raises:
        `ValueError`: If `reduce_labels` is `True` and `ignore_index` is `None`.
    """
    if reduce_labels and ignore_index is None:
        raise ValueError("If `reduce_labels` is True, `ignore_index` must be provided.")

    if reduce_labels:
        segmentation_map = np.where(segmentation_map == 0, ignore_index, segmentation_map - 1)

    all_labels = np.unique(segmentation_map)
    if ignore_index is not None:
        all_labels = all_labels[all_labels != ignore_index]

    binary_masks = [(segmentation_map == i) for i in all_labels]
    if binary_masks:
        binary_masks = np.stack(binary_masks, axis=0)
    else:
        # `np.stack` rejects an empty list; return a well-shaped empty batch instead of raising.
        binary_masks = np.zeros((0, *segmentation_map.shape), dtype=bool)

    if instance_id_to_semantic_id is not None:
        labels = np.zeros(all_labels.shape[0])
        for label in all_labels:
            # The mapping is keyed by the original ids, so undo the label reduction for the lookup
            # and re-apply it to the looked-up class id.
            class_id = instance_id_to_semantic_id[label + 1 if reduce_labels else label]
            labels[all_labels == label] = class_id - 1 if reduce_labels else class_id
    else:
        labels = all_labels

    return binary_masks.astype(np.float32), labels.astype(np.int64)
def get_oneformer_resize_output_image_size(
    image: np.ndarray,
    size: Union[int, Tuple[int, int], List[int], Tuple[int]],
    max_size: Optional[int] = None,
    default_to_square: bool = True,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> tuple:
    """
    Compute the output size for a resize, given the requested size.

    This is a thin wrapper that forwards every argument to `get_resize_output_image_size`.

    Args:
        image (`np.ndarray`):
            The input image.
        size (`int` or `Tuple[int, int]` or `List[int]` or `Tuple[int]`):
            The requested output size.
        max_size (`int`, *optional*):
            The maximum size of the output image.
        default_to_square (`bool`, *optional*, defaults to `True`):
            Whether to default to a square output when a single size is provided.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format of the input image. If unset, it is inferred from the input.

    Returns:
        `Tuple[int, int]`: The output size.
    """
    return get_resize_output_image_size(
        input_image=image,
        size=size,
        default_to_square=default_to_square,
        max_size=max_size,
        input_data_format=input_data_format,
    )
def prepare_metadata(class_info):
    """
    Flatten a class-info mapping into the metadata dict used by the image processor.

    Args:
        class_info (`dict`):
            Maps a class id (as a string key) to a dict with at least `"name"` and `"isthing"`.

    Returns:
        `dict`: Every class id key mapped to its name, plus `"thing_ids"` (integer ids of the
        entries whose `"isthing"` is truthy) and `"class_names"` (all names in input order).
    """
    metadata = {key: info["name"] for key, info in class_info.items()}
    class_names = [info["name"] for info in class_info.values()]
    thing_ids = [int(key) for key, info in class_info.items() if info["isthing"]]

    metadata["thing_ids"] = thing_ids
    metadata["class_names"] = class_names
    return metadata
def load_metadata(repo_id, class_info_file):
    """
    Load the class-info JSON from a local path, falling back to the Hugging Face Hub.

    The local candidate is `os.path.join(repo_id or "", class_info_file)`. If it is not an existing
    file, the file is downloaded from `repo_id`, first as a dataset repository and then, if that
    repository is not found, with the default repository type.

    Args:
        repo_id (`str` or `None`):
            Local directory or Hub repository id that contains the file.
        class_info_file (`str`):
            Name (or path) of the JSON file.

    Returns:
        The decoded JSON content.

    Raises:
        `ValueError`: If the file is not found locally and `repo_id` is `None`.
    """
    fname = os.path.join("" if repo_id is None else repo_id, class_info_file)

    # `isfile` is already False for paths that do not exist.
    if not os.path.isfile(fname):
        if repo_id is None:
            raise ValueError(f"Could not find {fname} locally. repo_id must be defined if loading from the hub")
        try:
            fname = hf_hub_download(repo_id, class_info_file, repo_type="dataset")
        except RepositoryNotFoundError:
            fname = hf_hub_download(repo_id, class_info_file)

    # Decode explicitly as UTF-8 so the result does not depend on the platform's locale encoding.
    with open(fname, "r", encoding="utf-8") as f:
        class_info = json.load(f)

    return class_info
class OneFormerImageProcessor(BaseImageProcessor):
    r"""
    Constructs a OneFormer image processor. The image processor can be used to prepare image(s), task input(s) and
    optional text inputs and targets for the model.

    This image processor inherits from `BaseImageProcessor`, which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.
    """

    # Names of the inputs this processor produces for the model.
    model_input_names = ["pixel_values", "pixel_mask", "task_inputs"]
def __init__(
    self,
    do_resize: bool = True,
    size: Dict[str, int] = None,
    resample: PILImageResampling = PILImageResampling.BILINEAR,
    do_rescale: bool = True,
    rescale_factor: float = 1 / 255,
    do_normalize: bool = True,
    image_mean: Union[float, List[float]] = None,
    image_std: Union[float, List[float]] = None,
    ignore_index: Optional[int] = None,
    do_reduce_labels: bool = False,
    repo_path: Optional[str] = "shi-labs/oneformer_demo",
    class_info_file: str = None,
    num_text: Optional[int] = None,
    **kwargs,
):
    """
    Store the preprocessing defaults and load the class metadata.

    Legacy keyword arguments are still honoured: `max_size` (defaults to 1333) bounds the longest
    edge, and the deprecated `reduce_labels` overrides `do_reduce_labels` with a `FutureWarning`.

    Raises:
        `ValueError`: If `class_info_file` is not provided.
    """
    # Consume the legacy `max_size` kwarg before the remaining kwargs reach the base class.
    self._max_size = kwargs.pop("max_size", 1333)

    if size is None:
        size = {"shortest_edge": 800, "longest_edge": self._max_size}
    size = get_size_dict(size, max_size=self._max_size, default_to_square=False)

    if "reduce_labels" in kwargs:
        warnings.warn(
            "The `reduce_labels` argument is deprecated and will be removed in v4.27. "
            "Please use `do_reduce_labels` instead.",
            FutureWarning,
        )
        do_reduce_labels = kwargs.pop("reduce_labels")

    if class_info_file is None:
        raise ValueError("You must provide a `class_info_file`")

    super().__init__(**kwargs)

    # Image transformation settings.
    self.do_resize = do_resize
    self.size = size
    self.resample = resample
    self.do_rescale = do_rescale
    self.rescale_factor = rescale_factor
    self.do_normalize = do_normalize
    self.image_mean = IMAGENET_DEFAULT_MEAN if image_mean is None else image_mean
    self.image_std = IMAGENET_DEFAULT_STD if image_std is None else image_std

    # Label handling settings.
    self.ignore_index = ignore_index
    self.do_reduce_labels = do_reduce_labels

    # Class metadata is read from `class_info_file`, resolved against `repo_path`.
    self.class_info_file = class_info_file
    self.repo_path = repo_path
    self.metadata = prepare_metadata(load_metadata(repo_path, class_info_file))
    self.num_text = num_text

    self._valid_processor_keys = [
        "images",
        "task_inputs",
        "segmentation_maps",
        "instance_id_to_semantic_id",
        "do_resize",
        "size",
        "resample",
        "do_rescale",
        "rescale_factor",
        "do_normalize",
        "image_mean",
        "image_std",
        "ignore_index",
        "do_reduce_labels",
        "return_tensors",
        "data_format",
        "input_data_format",
    ]
def resize(
    self,
    image: np.ndarray,
    size: Dict[str, int],
    resample: PILImageResampling = PILImageResampling.BILINEAR,
    data_format=None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
    **kwargs,
) -> np.ndarray:
    """
    Resize the image to the requested size.

    `size` must resolve to either `shortest_edge` + `longest_edge` or `height` + `width`. The
    deprecated `max_size` keyword is still accepted and emits a `FutureWarning`.

    Raises:
        `ValueError`: If `size` contains neither of the two supported key pairs.
    """
    max_size = None
    if "max_size" in kwargs:
        warnings.warn(
            "The `max_size` parameter is deprecated and will be removed in v4.27. "
            "Please specify in `size['longest_edge'] instead`.",
            FutureWarning,
        )
        max_size = kwargs.pop("max_size")

    size = get_size_dict(size, max_size=max_size, default_to_square=False)

    uses_edges = "shortest_edge" in size and "longest_edge" in size
    uses_height_width = "height" in size and "width" in size
    if uses_edges:
        requested_size, max_size = size["shortest_edge"], size["longest_edge"]
    elif uses_height_width:
        requested_size = (size["height"], size["width"])
        max_size = None
    else:
        raise ValueError(
            "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
            f" {size.keys()}."
        )

    output_size = get_oneformer_resize_output_image_size(
        image=image,
        size=requested_size,
        max_size=max_size,
        default_to_square=False,
        input_data_format=input_data_format,
    )
    return resize(
        image, size=output_size, resample=resample, data_format=data_format, input_data_format=input_data_format
    )
def rescale(
    self,
    image: np.ndarray,
    rescale_factor: float,
    data_format: Optional[Union[str, ChannelDimension]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Rescale the image by the given factor. image = image * rescale_factor.

    Args:
        image (`np.ndarray`):
            Image to rescale.
        rescale_factor (`float`):
            The value to use for rescaling.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the output image. If unset, the channel dimension format of the input
            image is used. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        input_data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the input image. If unset, is inferred from the input image. Can be
            one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.

    Returns:
        `np.ndarray`: The rescaled image.
    """
    # Delegates to the module-level `rescale` transform imported at the top of the file.
    return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
def convert_segmentation_map_to_binary_masks(
    self,
    segmentation_map: "np.ndarray",
    instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
    ignore_index: Optional[int] = None,
    reduce_labels: bool = False,
):
    """
    Split a segmentation map into binary masks, falling back to the processor's settings.

    `reduce_labels` and `ignore_index` default to `self.do_reduce_labels` and `self.ignore_index`
    when passed as `None`; everything is then forwarded to the module-level
    `convert_segmentation_map_to_binary_masks`.

    Returns:
        Whatever the module-level function returns for the resolved arguments.
    """
    # The processor stores this flag as `do_reduce_labels`; there is no `reduce_labels` attribute.
    reduce_labels = reduce_labels if reduce_labels is not None else self.do_reduce_labels
    ignore_index = ignore_index if ignore_index is not None else self.ignore_index
    return convert_segmentation_map_to_binary_masks(
        segmentation_map=segmentation_map,
        instance_id_to_semantic_id=instance_id_to_semantic_id,
        ignore_index=ignore_index,
        reduce_labels=reduce_labels,
    )
def __call__(self, images, task_inputs=None, segmentation_maps=None, **kwargs) -> BatchFeature:
    """Make the processor callable by forwarding every argument to `preprocess`."""
    features = self.preprocess(
        images,
        task_inputs=task_inputs,
        segmentation_maps=segmentation_maps,
        **kwargs,
    )
    return features
def _preprocess(
    self,
    image: ImageInput,
    do_resize: bool = None,
    size: Dict[str, int] = None,
    resample: PILImageResampling = None,
    do_rescale: bool = None,
    rescale_factor: float = None,
    do_normalize: bool = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
    """
    Run the enabled transformations on one image, in the fixed order resize, rescale, normalize.

    Each step is skipped when its `do_*` flag is falsy.
    """
    pipeline = (
        (
            do_resize,
            lambda img: self.resize(img, size=size, resample=resample, input_data_format=input_data_format),
        ),
        (
            do_rescale,
            lambda img: self.rescale(img, rescale_factor=rescale_factor, input_data_format=input_data_format),
        ),
        (
            do_normalize,
            lambda img: self.normalize(img, mean=image_mean, std=image_std, input_data_format=input_data_format),
        ),
    )
    for enabled, transform in pipeline:
        if enabled:
            image = transform(image)
    return image
def _preprocess_image(
    self,
    image: ImageInput,
    do_resize: bool = None,
    size: Dict[str, int] = None,
    resample: PILImageResampling = None,
    do_rescale: bool = None,
    rescale_factor: float = None,
    do_normalize: bool = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    data_format: Optional[Union[str, ChannelDimension]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Preprocess a single image.

    The image is converted to a numpy array, run through `_preprocess`, and finally moved to
    `data_format` when one is requested. A warning is logged once when rescaling is requested for
    an image that already looks rescaled.
    """
    # All transformations expect numpy arrays.
    image = to_numpy_array(image)

    if is_scaled_image(image) and do_rescale:
        logger.warning_once(
            "It looks like you are trying to rescale already rescaled images. If the input"
            " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
        )

    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)

    transform_kwargs = {
        "do_resize": do_resize,
        "size": size,
        "resample": resample,
        "do_rescale": do_rescale,
        "rescale_factor": rescale_factor,
        "do_normalize": do_normalize,
        "image_mean": image_mean,
        "image_std": image_std,
        "input_data_format": input_data_format,
    }
    image = self._preprocess(image=image, **transform_kwargs)

    if data_format is None:
        return image
    return to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
def _preprocess_mask(
    self,
    segmentation_map: ImageInput,
    do_resize: bool = None,
    size: Dict[str, int] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Preprocess a single segmentation map.

    Only resizing is applied (nearest-neighbour resampling, no rescaling, no normalization). A 2D
    map temporarily receives a leading channel axis so it can go through the image pipeline.
    """
    segmentation_map = to_numpy_array(segmentation_map)

    added_channel_dim = segmentation_map.ndim == 2
    if added_channel_dim:
        # Add a channel axis in front; the layout is then channels-first by construction.
        segmentation_map = segmentation_map[None, ...]
        input_data_format = ChannelDimension.FIRST
    elif input_data_format is None:
        input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)

    segmentation_map = self._preprocess(
        image=segmentation_map,
        do_resize=do_resize,
        resample=PILImageResampling.NEAREST,
        size=size,
        do_rescale=False,
        do_normalize=False,
        input_data_format=input_data_format,
    )

    # Drop the temporary channel axis again.
    return segmentation_map.squeeze(0) if added_channel_dim else segmentation_map
def preprocess(
self,
images: ImageInput,
task_inputs: Optional[List[str]] = None,
segmentation_maps: Optional[ImageInput] = None,
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
ignore_index: Optional[int] = None,
do_reduce_labels: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
def _pad_image(
    self,
    image: np.ndarray,
    output_size: Tuple[int, int],
    constant_values: Union[float, Iterable[float]] = 0,
    data_format: Optional[ChannelDimension] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Pad an image on its bottom and right sides with `constant_values` up to `output_size`.

    Args:
        image (`np.ndarray`):
            Image to pad.
        output_size (`Tuple[int, int]`):
            Target `(height, width)`.
    """
    current_height, current_width = get_image_size(image, channel_dim=input_data_format)
    target_height, target_width = output_size

    # ((top, bottom), (left, right)): nothing is added on the top or left side.
    bottom_right_padding = ((0, target_height - current_height), (0, target_width - current_width))
    return pad(
        image,
        bottom_right_padding,
        mode=PaddingMode.CONSTANT,
        constant_values=constant_values,
        data_format=data_format,
        input_data_format=input_data_format,
    )
def pad(
    self,
    images: List[np.ndarray],
    constant_values: Union[float, Iterable[float]] = 0,
    return_pixel_mask: bool = True,
    return_tensors: Optional[Union[str, TensorType]] = None,
    data_format: Optional[ChannelDimension] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> BatchFeature:
    """
    Pad every image of a batch on the bottom and right up to the largest height and width in the
    batch, and optionally return the matching pixel masks.

    Args:
        images (`List[np.ndarray]`):
            Images to pad.
        constant_values (`float` or `Iterable[float]`, *optional*):
            The value used to fill the padded area.
        return_pixel_mask (`bool`, *optional*, defaults to `True`):
            Whether to return a pixel mask.
        return_tensors (`str` or `TensorType`, *optional*):
            The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format of the image. If not provided, it will be the same as the input image.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format of the input image. If not provided, it will be inferred.

    Returns:
        `BatchFeature` with `pixel_values` and, when requested, `pixel_mask`.
    """
    pad_size = get_max_height_width(images, input_data_format=input_data_format)

    data = {
        "pixel_values": [
            self._pad_image(
                image,
                pad_size,
                constant_values=constant_values,
                data_format=data_format,
                input_data_format=input_data_format,
            )
            for image in images
        ]
    }
    if return_pixel_mask:
        data["pixel_mask"] = [
            make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
            for image in images
        ]

    return BatchFeature(data=data, tensor_type=return_tensors)
def get_semantic_annotations(self, label, num_class_obj):
    """
    Build semantic-task targets: one merged mask per class plus the text prompts.

    Args:
        label (`dict`):
            Holds `"classes"` and `"masks"`, indexed in parallel.
        num_class_obj (`dict`):
            Per-class-name counter, updated in place for every class that gets a mask.

    Returns:
        `(classes, masks, texts)`: arrays of the kept class ids and their merged masks, and a list
        of `self.num_text` prompts. Prompts default to `"a semantic photo"`; the first ones are
        replaced by `"a photo with a {class name}"`, one per counted object.
    """
    annotation_classes = label["classes"]
    annotation_masks = label["masks"]

    texts = ["a semantic photo"] * self.num_text
    classes = []
    masks = []

    for idx, class_id in enumerate(annotation_classes):
        mask = annotation_masks[idx]
        # Skip empty masks. The previous `np.all(mask is False)` compared identities and was
        # always False, so nothing was ever skipped.
        if not np.asarray(mask).any():
            continue

        if class_id not in classes:
            cls_name = self.metadata[str(class_id)]
            classes.append(class_id)
            masks.append(mask)
            num_class_obj[cls_name] += 1
        else:
            # Merge further masks of an already seen class into its first mask.
            # NOTE(review): `+=` accumulates into the stored object; for numpy arrays that is the
            # array taken from `label["masks"]` — confirm callers do not rely on it staying intact.
            merged_idx = classes.index(class_id)
            masks[merged_idx] += mask
            masks[merged_idx] = np.clip(masks[merged_idx], 0, 1)

    # Fill the prompts class by class, never writing past the fixed number of text slots.
    num = 0
    for cls_name in self.metadata["class_names"]:
        for _ in range(num_class_obj[cls_name]):
            if num >= len(texts):
                break
            texts[num] = f"a photo with a {cls_name}"
            num += 1

    classes = np.array(classes)
    masks = np.array(masks)
    return classes, masks, texts
def get_instance_annotations(self, label, num_class_obj):
    """
    Build instance-task targets: one mask per "thing" instance plus the text prompts.

    Args:
        label (`dict`):
            Holds `"classes"` and `"masks"`, indexed in parallel.
        num_class_obj (`dict`):
            Per-class-name counter, updated in place for every kept instance.

    Returns:
        `(classes, masks, texts)`: arrays of the kept class ids and masks, and a list of
        `self.num_text` prompts. Prompts default to `"an instance photo"`; the first ones are
        replaced by `"a photo with a {class name}"`, one per counted object.
    """
    annotation_classes = label["classes"]
    annotation_masks = label["masks"]

    texts = ["an instance photo"] * self.num_text
    classes = []
    masks = []

    for idx, class_id in enumerate(annotation_classes):
        mask = annotation_masks[idx]
        # Only classes listed in the metadata's `thing_ids` take part in instance segmentation.
        if class_id not in self.metadata["thing_ids"]:
            continue
        # Skip empty masks. The previous `np.all(mask is False)` compared identities and was
        # always False, so nothing was ever skipped.
        if not np.asarray(mask).any():
            continue

        cls_name = self.metadata[str(class_id)]
        classes.append(class_id)
        masks.append(mask)
        num_class_obj[cls_name] += 1

    # Fill the prompts class by class, never writing past the fixed number of text slots.
    num = 0
    for cls_name in self.metadata["class_names"]:
        for _ in range(num_class_obj[cls_name]):
            if num >= len(texts):
                break
            texts[num] = f"a photo with a {cls_name}"
            num += 1

    classes = np.array(classes)
    masks = np.array(masks)
    return classes, masks, texts
def get_panoptic_annotations(self, label, num_class_obj):
    """
    Build panoptic-task targets: one mask per annotation plus the text prompts.

    Args:
        label (`dict`):
            Holds `"classes"` and `"masks"`, indexed in parallel.
        num_class_obj (`dict`):
            Per-class-name counter, updated in place for every kept annotation.

    Returns:
        `(classes, masks, texts)`: arrays of the kept class ids and masks, and a list of
        `self.num_text` prompts. The first prompts are replaced by `"a photo with a {class name}"`,
        one per counted object.
    """
    annotation_classes = label["classes"]
    annotation_masks = label["masks"]

    # NOTE(review): the default prompt's wording ("an panoptic photo") is a runtime string and is
    # kept byte-for-byte; presumably it has to match what the model was trained with — confirm
    # before changing it.
    texts = ["an panoptic photo"] * self.num_text
    classes = []
    masks = []

    for idx, class_id in enumerate(annotation_classes):
        # NOTE(review): `.data` is kept from the original; what it returns depends on the mask
        # type (a buffer view for numpy arrays) — confirm the expected input type.
        mask = annotation_masks[idx].data
        # Skip empty masks. The previous `np.all(mask is False)` compared identities and was
        # always False, so nothing was ever skipped.
        if not np.asarray(mask).any():
            continue

        cls_name = self.metadata[str(class_id)]
        classes.append(class_id)
        masks.append(mask)
        num_class_obj[cls_name] += 1

    # Fill the prompts class by class, never writing past the fixed number of text slots.
    num = 0
    for cls_name in self.metadata["class_names"]:
        for _ in range(num_class_obj[cls_name]):
            if num >= len(texts):
                break
            texts[num] = f"a photo with a {cls_name}"
            num += 1

    classes = np.array(classes)
    masks = np.array(masks)
    return classes, masks, texts
def post_process_semantic_segmentation(
    self, outputs, target_sizes: Optional[List[Tuple[int, int]]] = None
) -> List["torch.Tensor"]:
    """
    Converts raw model outputs into semantic segmentation maps. Only supports PyTorch.

    Args:
        outputs:
            Raw outputs of the model; must expose `class_queries_logits` and `masks_queries_logits`.
        target_sizes (`List[Tuple[int, int]]`, *optional*):
            List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested
            final size (height, width) of each prediction. If left to None, predictions will not be resized.

    Returns:
        `List[torch.Tensor]`:
            A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
            corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
            `torch.Tensor` correspond to a semantic class id.

    Raises:
        `ValueError`: If `target_sizes` is given but its length differs from the batch size.
    """
    class_queries_logits = outputs.class_queries_logits
    masks_queries_logits = outputs.masks_queries_logits

    # Drop the last class column before combining class and mask probabilities.
    masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
    masks_probs = masks_queries_logits.sigmoid()

    # Per-class score maps: sum over queries of class probability times mask probability.
    segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
    batch_size = class_queries_logits.shape[0]

    if target_sizes is not None:
        if batch_size != len(target_sizes):
            raise ValueError(
                "Make sure that you pass in as many target sizes as the batch dimension of the logits"
            )

        semantic_segmentation = []
        for idx in range(batch_size):
            # Resize the score maps first, then pick the best class per pixel.
            resized_logits = torch.nn.functional.interpolate(
                segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
            )
            semantic_segmentation.append(resized_logits[0].argmax(dim=0))
    else:
        semantic_segmentation = segmentation.argmax(dim=1)
        semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]

    return semantic_segmentation
def post_process_panoptic_segmentation(
self,
outputs,
threshold: float = 0.5,
mask_threshold: float = 0.5,
overlap_mask_area_threshold: float = 0.8,
label_ids_to_fuse: Optional[Set[int]] = None,
target_sizes: Optional[List[Tuple[int, int]]] = None,
.\models\oneformer\modeling_oneformer.py
""" PyTorch OneFormer model. """
import copy
import math
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import Tensor, nn
from torch.cuda.amp import autocast
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_accelerate_available,
is_scipy_available,
logging,
replace_return_docstrings,
requires_backends,
)
from ...utils.backbone_utils import load_backbone
from .configuration_oneformer import OneFormerConfig
# accelerate is optional: these helpers are only imported when it is available.
if is_accelerate_available():
    from accelerate import PartialState
    from accelerate.utils import reduce

# Module-level logger named after this module.
logger = logging.get_logger(__name__)

# Config class name and example checkpoint.
# NOTE(review): presumably consumed by the docstring decorators imported above — their use is not
# visible in this excerpt.
_CONFIG_FOR_DOC = "OneFormerConfig"
_CHECKPOINT_FOR_DOC = "shi-labs/oneformer_ade20k_swin_tiny"

# Known pretrained checkpoint identifiers.
ONEFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "shi-labs/oneformer_ade20k_swin_tiny",
]

# scipy is optional: the assignment solver is only imported when it is available.
if is_scipy_available():
    from scipy.optimize import linear_sum_assignment
def _get_clones(module, N):
    """Return an `nn.ModuleList` holding `N` independent deep copies of `module`."""
    clones = []
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return nn.ModuleList(clones)
def multi_scale_deformable_attention(
    value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
) -> Tensor:
    """
    Sample `value` at the given locations on every feature level and combine the samples with the
    attention weights.

    Args:
        value (`Tensor`):
            Unpacked as `(batch_size, sequence, num_heads, hidden_dim)`; the sequence axis is the
            concatenation of all levels' flattened `height * width` positions.
        value_spatial_shapes (`Tensor`):
            One `(height, width)` pair per level.
        sampling_locations (`Tensor`):
            Unpacked as `(batch_size, num_queries, num_heads, num_levels, num_points, 2)`; mapped
            from `[0, 1]` to the `[-1, 1]` range used by `grid_sample`.
        attention_weights (`Tensor`):
            Reshaped to `(batch_size * num_heads, 1, num_queries, num_levels * num_points)`.

    Returns:
        `Tensor` of shape `(batch_size, num_queries, num_heads * hidden_dim)`.
    """
    batch_size, _, num_heads, hidden_dim = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape

    # Cut the flattened sequence back into one chunk per feature level.
    level_lengths = [height.item() * width.item() for height, width in value_spatial_shapes]
    values_per_level = value.split(level_lengths, dim=1)
    grids = 2 * sampling_locations - 1

    samples_per_level = []
    for level, (height, width) in enumerate(value_spatial_shapes):
        # Fold the heads into the batch axis and restore the 2D layout of this level.
        level_value = values_per_level[level].flatten(2).transpose(1, 2)
        level_value = level_value.reshape(batch_size * num_heads, hidden_dim, height, width)
        # Matching sampling grid for this level, with heads folded into the batch axis.
        level_grid = grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        samples_per_level.append(
            nn.functional.grid_sample(
                level_value, level_grid, mode="bilinear", padding_mode="zeros", align_corners=False
            )
        )

    weights = attention_weights.transpose(1, 2).reshape(
        batch_size * num_heads, 1, num_queries, num_levels * num_points
    )
    stacked_samples = torch.stack(samples_per_level, dim=-2).flatten(-2)
    attended = (stacked_samples * weights).sum(-1)
    attended = attended.view(batch_size, num_heads * hidden_dim, num_queries)
    return attended.transpose(1, 2).contiguous()
def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor:
    r"""
    Compute the DICE loss, similar to generalized IOU for masks as follows:

    $$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 * x \cap y}{x \cup y + 1} $$

    In practice, since `labels` is a binary mask (only 0s and 1s), the code computes

    $$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 * x * y + 1}{x + y + 1} $$

    on the sigmoid of `inputs`.

    Args:
        inputs (`torch.Tensor`):
            A tensor representing a mask (logits).
        labels (`torch.Tensor`):
            A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_masks (`int`):
            The number of masks present in the current batch, used for normalization.

    Returns:
        `torch.Tensor`: The computed loss.
    """
    probs = inputs.sigmoid().flatten(1)
    overlap = 2 * (probs * labels).sum(-1)
    total = probs.sum(-1) + labels.sum(-1)
    # The +1 in both terms keeps the ratio defined when both masks are empty.
    per_mask_loss = 1 - (overlap + 1) / (total + 1)
    return per_mask_loss.sum() / num_masks
def sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor, num_masks: int) -> torch.Tensor:
    r"""
    Binary cross entropy on logits, averaged over dimension 1, summed and divided by `num_masks`.

    Args:
        inputs (`torch.Tensor`):
            A float tensor of arbitrary shape.
        labels (`torch.Tensor`):
            A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_masks (`int`):
            Normalization constant the summed loss is divided by.

    Returns:
        loss (`torch.Tensor`): The computed loss.
    """
    elementwise_loss = nn.functional.binary_cross_entropy_with_logits(inputs, labels, reduction="none")
    return elementwise_loss.mean(1).sum() / num_masks
def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
    """
    A pair wise version of the dice loss, see `dice_loss` for usage.

    Args:
        inputs (`torch.Tensor`):
            A tensor representing a mask (logits); flattened from dimension 1 on.
        labels (`torch.Tensor`):
            A 2D tensor of binary labels (0 for the negative class and 1 for the positive class) whose second
            dimension matches the flattened inputs.

    Returns:
        `torch.Tensor`: The computed loss between each pairs, one row per input and one column per label.
    """
    probs = inputs.sigmoid().flatten(1)
    # Entry (i, j) is twice the overlap between prediction i and label j.
    overlap = 2 * torch.matmul(probs, labels.T)
    # Broadcast the per-prediction and per-label sums into the same pairwise grid.
    total = probs.sum(-1)[:, None] + labels.sum(-1)[None, :]
    return 1 - (overlap + 1) / (total + 1)
def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    r"""
    A pair wise version of the cross entropy loss, see `sigmoid_cross_entropy_loss` for usage.

    Args:
        inputs (`torch.Tensor`):
            A 2D tensor of mask logits.
        labels (`torch.Tensor`):
            A 2D tensor of binary labels (0 for the negative class and 1 for the positive class) whose second
            dimension matches the inputs.

    Returns:
        loss (`torch.Tensor`): The computed loss between each pairs, one row per input and one column per label.
    """
    num_elements = inputs.shape[1]

    # Element-wise cost of calling every position positive / negative.
    cost_if_positive = nn.functional.binary_cross_entropy_with_logits(
        inputs, torch.ones_like(inputs), reduction="none"
    )
    cost_if_negative = nn.functional.binary_cross_entropy_with_logits(
        inputs, torch.zeros_like(inputs), reduction="none"
    )

    # The matmuls select, for each label, the positive cost where it is 1 and the negative cost
    # where it is 0, averaged over the positions.
    positive_part = torch.matmul(cost_if_positive / num_elements, labels.T)
    negative_part = torch.matmul(cost_if_negative / num_elements, (1 - labels).T)
    return positive_part + negative_part
def sample_point(
    input_features: torch.Tensor, point_coordinates: torch.Tensor, add_dim=False, **kwargs
) -> torch.Tensor:
    """
    A wrapper around `torch.nn.functional.grid_sample` to support 3D point_coordinates tensors.

    Args:
        input_features (`torch.Tensor` of shape (batch_size, channels, height, width)):
            A tensor that contains features map on a height * width grid
        point_coordinates (`torch.Tensor` of shape (batch_size, num_points, 2) or (batch_size, grid_height, grid_width,
        2)):
            A tensor that contains [0, 1] * [0, 1] normalized point coordinates
        add_dim (`bool`):
            Whether the last dimension of the sampled features should be squeezed away. It is forced to `True`
            when 3D coordinates are given, because a dimension is inserted for them.
        **kwargs:
            Forwarded to `torch.nn.functional.grid_sample`.

    Returns:
        point_features (`torch.Tensor` of shape (batch_size, channels, num_points) or (batch_size, channels,
        height_grid, width_grid)):
            A tensor that contains features for points in `point_coordinates`.
    """
    squeeze_output = add_dim
    if point_coordinates.dim() == 3:
        # grid_sample needs a 4D grid: insert a singleton axis and remember to remove it again.
        squeeze_output = True
        point_coordinates = point_coordinates.unsqueeze(2)

    # Map [0, 1] coordinates to the [-1, 1] range expected by grid_sample.
    grid = 2.0 * point_coordinates - 1.0
    sampled = torch.nn.functional.grid_sample(input_features, grid, **kwargs)

    return sampled.squeeze(3) if squeeze_output else sampled
class OneFormerHungarianMatcher(nn.Module):
    """
    Stores the weights used to build the matching cost between predictions and labels.

    For efficiency reasons, the labels don't include the no_object. Because of this, in general, there are more
    predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are
    un-matched (and thus treated as non-objects).
    """

    def __init__(
        self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0, num_points: int = 12544
    ):
        """
        Initialize the OneFormer Hungarian Matcher.

        Params:
            cost_class (float, *optional*, defaults to 1.0):
                This is the relative weight of the classification error in the matching cost.
            cost_mask (float, *optional*, defaults to 1.0):
                This is the relative weight of the sigmoid ce loss of the binary mask in the matching cost.
            cost_dice (float, *optional*, defaults to 1.0):
                This is the relative weight of the dice loss of the binary mask in the matching cost
            num_points (int, *optional*, defaults to 12544):
                Number of points to be sampled for dice and mask loss matching cost.

        Raises:
            ValueError: If `cost_class`, `cost_mask` and `cost_dice` are all zero.
        """
        super().__init__()
        # With every weight at zero the cost matrix would carry no information to match on.
        if cost_class == 0 and cost_mask == 0 and cost_dice == 0:
            raise ValueError("All costs cant be 0")
        self.cost_class = cost_class
        self.cost_mask = cost_mask
        self.cost_dice = cost_dice
        self.num_points = num_points

    # NOTE(review): this excerpt ends the class with a bare `@torch.no_grad()` decorator whose target
    # method is not shown. It is kept as a comment so that it cannot accidentally decorate the next
    # class definition in the file; restore the decorated method from the full source.
class OneFormerLoss(nn.Module):
def __init__(
    self,
    num_classes: int,
    matcher: OneFormerHungarianMatcher,
    weight_dict: Dict[str, float],
    eos_coef: float,
    num_points: int,
    oversample_ratio: float,
    importance_sample_ratio: float,
    contrastive_temperature: float = None,
):
    """
    This class computes the losses using the class predictions, mask predictions and the contrastive queries.
    Oneformer calculates the classification CE loss on the class predictions. Mask predictions are used for
    calculating the binary CE loss and dice loss. The contrastive queries are used for calculating the contrastive
    loss.

    Args:
        num_classes (`int`):
            The number of classes.
        matcher (`OneFormerHungarianMatcher`):
            A torch module that computes the assignments between the predictions and labels.
        weight_dict (`Dict[str, float]`):
            A dictionary of weights to be applied to the different losses.
        eos_coef (`float`):
            Weight to apply to the null class.
        num_points (`int`):
            Number of points to be sampled for dice and mask loss calculations.
        oversample_ratio (`float`):
            Required for pointwise loss calculation.
        importance_sample_ratio (`float`):
            Required for pointwise loss calculation.
        contrastive_temperature (`float`, *optional*):
            Temperature for scaling the contrastive logits. When `None`, no `logit_scale` parameter is created.
    """
    requires_backends(self, ["scipy"])
    super().__init__()

    self.num_classes = num_classes
    self.matcher = matcher
    self.weight_dict = weight_dict
    self.eos_coef = eos_coef

    # One weight per class plus a final slot for the null class, which is weighted by `eos_coef`.
    class_weights = torch.ones(num_classes + 1)
    class_weights[-1] = eos_coef
    self.register_buffer("empty_weight", class_weights)

    # Settings for the point-wise mask losses.
    self.num_points = num_points
    self.oversample_ratio = oversample_ratio
    self.importance_sample_ratio = importance_sample_ratio

    self.contrastive_temperature = contrastive_temperature
    if contrastive_temperature is not None:
        # Stored in log space: log(1 / temperature).
        self.logit_scale = nn.Parameter(torch.tensor(np.log(1 / contrastive_temperature)))
def _max_by_axis(self, the_list: List[List[int]]) -> List[int]:
"""
Compute the maximum values along each axis of a 2D list.
Args:
the_list (`List[List[int]]`): 二维列表,用于计算每个轴向的最大值
Returns:
`List[int]`: 包含每个轴向最大值的列表
"""
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
def _pad_images_to_max_in_batch(self, tensors: List[Tensor]) -> Tuple[Tensor, Tensor]:
max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors])
batch_size = len(tensors)
batch_shape = [batch_size] + max_size
b, _, h, w = batch_shape
dtype = tensors[0].dtype
device = tensors[0].device
padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device)
padding_masks = torch.ones((b, h, w), dtype=torch.bool, device=device)
for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks):
padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor)
padding_mask[: tensor.shape[1], : tensor.shape[2]] = False
return padded_tensors, padding_masks
def loss_contrastive(self, contrastive_queries_logits: Tensor, text_queries: Tensor):
"""Compute the query-text contrastive loss.
Args:
contrastive_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, hidden_dim`
text_queries (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, hidden_dim`
Returns:
`Dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key:
- **loss_contrastive** -- The query-text contrastive loss computed using task-guided queries
and text queries derived from input text list.
"""
image_queries = contrastive_queries_logits.float()
image_queries = nn.functional.normalize(image_queries.flatten(1), dim=-1)
text_queries = nn.functional.normalize(text_queries.flatten(1), dim=-1)
logit_scale = torch.clamp(self.logit_scale.exp(), max=100)
logits_per_text = torch.matmul(text_queries, image_queries.t()) * logit_scale
logits_per_img = logits_per_text.t()
loss_img = nn.functional.cross_entropy(
logits_per_img, torch.arange(len(logits_per_img), device=logits_per_text.device)
)
loss_text = nn.functional.cross_entropy(
logits_per_text, torch.arange(len(logits_per_text), device=logits_per_text.device)
)
loss_contrastive = loss_img + loss_text
losses = {"loss_contrastive": loss_contrastive}
return losses
def loss_labels(
self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.array]
):
) -> Dict[str, Tensor]:
"""Compute the losses related to the labels using cross entropy.
Args:
class_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, num_labels`
class_labels (`List[torch.Tensor]`):
List of class labels of shape `(labels)`.
indices (`Tuple[np.array])`:
The indices computed by the Hungarian matcher.
Returns:
`Dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key:
- **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
"""
pred_logits = class_queries_logits
batch_size, num_queries, _ = pred_logits.shape
criterion = nn.CrossEntropyLoss(weight=self.empty_weight)
idx = self._get_predictions_permutation_indices(indices)
target_classes_o = torch.cat([target[j] for target, (_, j) in zip(class_labels, indices)])
target_classes = torch.full(
(batch_size, num_queries), fill_value=self.num_classes, dtype=torch.int64, device=pred_logits.device
)
target_classes[idx] = target_classes_o
pred_logits_transposed = pred_logits.transpose(1, 2)
loss_ce = criterion(pred_logits_transposed, target_classes)
losses = {"loss_cross_entropy": loss_ce}
return losses
) -> Dict[str, Tensor]:
    """Compute the losses related to the masks using focal and dice loss.

    NOTE(review): the `def` line and parameter list of this method are missing from this excerpt; the
    parameters below are the names the body uses.

    Args:
        masks_queries_logits (`torch.Tensor`):
            A tensor of shape `batch_size, num_queries, height, width` holding the predicted mask logits.
        mask_labels (`torch.Tensor`):
            List of mask labels of shape `(labels, height, width)`.
        indices (`Tuple[np.array])`:
            The indices computed by the Hungarian matcher.
        num_masks (`int)`:
            The number of masks, used for normalization.

    Returns:
        `Dict[str, Tensor]`: A dict of `torch.Tensor` containing two keys:
        - **loss_mask** -- The loss computed using sigmoid cross entropy loss on the predicted and ground truth
          masks.
        - **loss_dice** -- The loss computed using dice loss on the predicted and ground truth masks.
    """
    src_idx = self._get_predictions_permutation_indices(indices)
    tgt_idx = self._get_targets_permutation_indices(indices)
    # Keep only the predictions that were matched to a target.
    pred_masks = masks_queries_logits[src_idx]
    # Pad the target masks to a common size, then pick the matched ones in the same order.
    target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
    target_masks = target_masks[tgt_idx]
    # Add a singleton dimension at position 1 before point sampling.
    pred_masks = pred_masks[:, None]
    target_masks = target_masks[:, None]
    # Point selection and target sampling are done without tracking gradients.
    with torch.no_grad():
        point_coords = self.sample_points_using_uncertainty(
            pred_masks,
            self.calculate_uncertainty,
            self.num_points,
            self.oversample_ratio,
            self.importance_sample_ratio,
        )
        point_labels = sample_point(target_masks, point_coords, align_corners=False).squeeze(1)
    # Predictions are sampled at the same points, outside the no_grad block.
    point_logits = sample_point(pred_masks, point_coords, align_corners=False).squeeze(1)
    losses = {
        "loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks),
        "loss_dice": dice_loss(point_logits, point_labels, num_masks),
    }
    # Drop the local references to the full-resolution masks.
    del pred_masks
    del target_masks
    return losses
def calculate_uncertainty(self, logits: torch.Tensor) -> torch.Tensor:
    """
    Compute uncertainty scores. Following the Mask2Former paper, uncertainty is estimated as the L1 distance
    between 0.0 and the logit prediction in `logits` for the foreground class.

    Args:
        logits (`torch.Tensor`):
            A tensor of logits, expected to have shape (R, 1, ...) for class-specific or class-agnostic
            predictions, where R is the total number of predicted masks in all images.

    Returns:
        scores (`torch.Tensor`): A tensor with the same shape as `logits` that contains uncertainty scores, with
        the most uncertain locations having the highest score.
    """
    # Logits near zero are the least confident, so the negated magnitude ranks them highest.
    return -logits.abs()
# NOTE(review): only the opening of this signature is present in this excerpt -- the closing
# parenthesis, return annotation and body are missing. The one visible call passes
# (pred_masks, self.calculate_uncertainty, self.num_points, self.oversample_ratio,
# self.importance_sample_ratio); restore the rest from the full source before use.
def sample_points_using_uncertainty(
    self,
    logits: torch.Tensor,
    uncertainty_function,
    num_points: int,
    oversample_ratio: int,
    importance_sample_ratio: float,
def _get_predictions_permutation_indices(self, indices):
"""
This method generates permutation indices for predictions based on the provided indices.
Args:
indices (list):
List of tuples containing indices for predictions and targets.
Returns:
batch_indices (torch.Tensor):
Indices indicating batch-wise association of predictions.
predictions_indices (torch.Tensor):
Indices indicating the order of predictions after permutation.
"""
batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
predictions_indices = torch.cat([src for (src, _) in indices])
return batch_indices, predictions_indices
def _get_targets_permutation_indices(self, indices):
"""
This method generates permutation indices for targets based on the provided indices.
Args:
indices (list):
List of tuples containing indices for predictions and targets.
Returns:
batch_indices (torch.Tensor):
Indices indicating batch-wise association of targets.
target_indices (torch.Tensor):
Indices indicating the order of targets after permutation.
"""
batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
target_indices = torch.cat([tgt for (_, tgt) in indices])
return batch_indices, target_indices
def forward(
    self,
    masks_queries_logits: Tensor,
    class_queries_logits: Tensor,
    contrastive_queries_logits: Tensor,
    mask_labels: List[Tensor],
    class_labels: List[Tensor],
    text_queries: Tensor,
    auxiliary_predictions: Optional[Dict[str, Tensor]] = None,
    calculate_contrastive_loss: bool = True,
):
    """
    Forward pass of the loss module.

    NOTE(review): in this excerpt the body is a placeholder that performs no computation and returns `None`.

    Args:
        masks_queries_logits (Tensor): Logits for mask prediction queries.
        class_queries_logits (Tensor): Logits for class prediction queries.
        contrastive_queries_logits (Tensor): Logits for contrastive learning queries.
        mask_labels (List[Tensor]): List of mask labels.
        class_labels (List[Tensor]): List of class labels.
        text_queries (Tensor): Queries related to text inputs.
        auxiliary_predictions (Optional[Dict[str, Tensor]]): Optional auxiliary predictions.
        calculate_contrastive_loss (bool): Flag indicating whether to calculate contrastive loss.

    Returns:
        None
    """
    return None
def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor:
    """
    Computes the average number of target masks across the batch, for normalization purposes.

    Args:
        class_labels (torch.Tensor): Iterable of per-sample class labels; the length of each entry is counted.
        device (torch.device): Device on which the resulting tensor is created.

    Returns:
        torch.Tensor: One-element float tensor with the mask count divided by the world size, clamped to be at
        least 1.
    """
    total_masks = sum(len(classes) for classes in class_labels)
    num_masks = torch.as_tensor([total_masks], dtype=torch.float, device=device)

    world_size = 1
    # Only reduce across processes when accelerate is installed and its shared state has been populated.
    if is_accelerate_available() and PartialState._shared_state != {}:
        num_masks = reduce(num_masks)
        world_size = PartialState().num_processes

    # The lower bound of 1 avoids dividing by zero when the count is used as a normalizer.
    return torch.clamp(num_masks / world_size, min=1)
@dataclass
class OneFormerTransformerDecoderOutput(BaseModelOutput):
    """
    Base class for outputs of the Transformer decoder. This class adds attributes for class predictions, mask
    predictions and contrastive logits to its base output class.

    Args:
        object_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`):
            Queries representation for the region proposals.
        contrastive_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`):
            Queries representation for the contrastive loss.
        prediction_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
            Mask predictions from last layer of the transformer decoder.
        prediction_class (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes+1)`):
            Class predictions from last layer of the transformer decoder.
        auxiliary_predictions (Tuple of Dict of `str, torch.FloatTensor`, *optional*):
            Tuple of class and mask predictions from each layer of the transformer decoder.
    """

    object_queries: torch.FloatTensor = None
    contrastive_logits: Optional[torch.FloatTensor] = None
    prediction_masks: torch.FloatTensor = None
    prediction_class: torch.FloatTensor = None
    auxiliary_predictions: Optional[Tuple[Dict[str, torch.FloatTensor]]] = None
@dataclass
class OneFormerPixelDecoderOutput(ModelOutput):
    """
    OneFormer's pixel decoder module output, practically a Multi-Scale Deformable Attention based decoder. It returns
    the mask features and the multiscale features.

    Args:
        multi_scale_features (`tuple(torch.FloatTensor)`):
            Tuple of multi-scale features of scales [1/8, 1/16, 1/32] and shape `(batch_size, num_channels, height,
            width)` from the Multi-Scale Deformable Attention based Pixel Decoder.
        mask_features (`torch.FloatTensor`):
            Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel Decoder
            Layer.
        attentions (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attention weights from the pixel decoder. Returned when `output_attentions=True` is
            passed or when `config.output_attentions=True`.
    """

    multi_scale_features: Tuple[torch.FloatTensor] = None
    mask_features: torch.FloatTensor = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class OneFormerPixelLevelModuleOutput(ModelOutput):
    """
    OneFormer's pixel level module output. It returns the feature maps produced by the encoder and by the pixel
    decoder, together with the last pixel decoder feature map.

    Args:
        encoder_features (List of `(torch.FloatTensor)`):
            List of `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
            called feature maps) of the model at the output of each stage.
        decoder_features (List of `(torch.FloatTensor)`):
            List of `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
            called feature maps) of the model at the output of each stage.
        decoder_last_feature (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width))`:
            1/4 scale features from the last Pixel Decoder Layer.
    """

    # NOTE(review): the field declarations were missing from this excerpt. They are restored from the
    # documented arguments above and from the keyword arguments used when this class is instantiated;
    # confirm the annotations against the full source.
    encoder_features: List[torch.FloatTensor] = None
    decoder_features: List[torch.FloatTensor] = None
    decoder_last_feature: torch.FloatTensor = None
@dataclass
class OneFormerModelOutput(ModelOutput):
    """
    Class for outputs of [`OneFormerModel`]. This class returns all the needed hidden states to compute the logits.

    Every field defaults to `None`.
    """

    # NOTE(review): per-field meanings below are inferred from the field names only -- confirm against the
    # full source.
    # Hidden states of the encoder and of the pixel decoder.
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Outputs of the transformer decoder.
    transformer_decoder_hidden_states: Optional[torch.FloatTensor] = None
    transformer_decoder_object_queries: torch.FloatTensor = None
    transformer_decoder_contrastive_queries: Optional[torch.FloatTensor] = None
    transformer_decoder_mask_predictions: torch.FloatTensor = None
    transformer_decoder_class_predictions: torch.FloatTensor = None
    transformer_decoder_auxiliary_predictions: Optional[Tuple[Dict[str, torch.FloatTensor]]] = None
    # Text / task conditioning and attention weights.
    text_queries: Optional[torch.FloatTensor] = None
    task_token: torch.FloatTensor = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class OneFormerForUniversalSegmentationOutput(ModelOutput):
    """
    Class for outputs of [`OneFormerForUniversalSegmentation`].

    This output can be directly passed to [`~OneFormerImageProcessor.post_process_semantic_segmentation`] or
    [`~OneFormerImageProcessor.post_process_instance_segmentation`] or
    [`~OneFormerImageProcessor.post_process_panoptic_segmentation`] depending on the task. Please, see
    [`~OneFormerImageProcessor`] for details regarding usage.

    Every field defaults to `None`.
    """

    loss: Optional[torch.FloatTensor] = None
    class_queries_logits: torch.FloatTensor = None
    masks_queries_logits: torch.FloatTensor = None
    auxiliary_predictions: List[Dict[str, torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    pixel_decoder_hidden_states: Optional[List[torch.FloatTensor]] = None
    transformer_decoder_hidden_states: Optional[torch.FloatTensor] = None
    transformer_decoder_object_queries: torch.FloatTensor = None
    transformer_decoder_contrastive_queries: Optional[torch.FloatTensor] = None
    transformer_decoder_mask_predictions: torch.FloatTensor = None
    transformer_decoder_class_predictions: torch.FloatTensor = None
    transformer_decoder_auxiliary_predictions: Optional[List[Dict[str, torch.FloatTensor]]] = None
    text_queries: Optional[torch.FloatTensor] = None
    task_token: torch.FloatTensor = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
class OneFormerPixelDecoderFrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which any other models than
    torchvision.models.resnet[18,34,50,101] produce nans.
    """

    def __init__(self, n):
        super().__init__()
        # Everything is a buffer (not a parameter), so nothing here is exposed to an optimizer.
        buffer_factories = (
            ("weight", torch.ones),
            ("bias", torch.zeros),
            ("running_mean", torch.zeros),
            ("running_var", torch.ones),
        )
        for buffer_name, factory in buffer_factories:
            self.register_buffer(buffer_name, factory(n))

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # This module has no `num_batches_tracked` buffer; drop the entry so that loading does not
        # report it as an unexpected key.
        state_dict.pop(prefix + "num_batches_tracked", None)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x):
        # Reshape the per-channel statistics so that they broadcast over (batch, channel, height, width).
        broadcast_shape = (1, -1, 1, 1)
        epsilon = 1e-5
        scale = self.weight.reshape(broadcast_shape) * (self.running_var.reshape(broadcast_shape) + epsilon).rsqrt()
        shift = self.bias.reshape(broadcast_shape) - self.running_mean.reshape(broadcast_shape) * scale
        return x * scale + shift
class OneFormerPixelDecoderEncoderMultiscaleDeformableAttention(nn.Module):
    """
    Multiscale deformable attention as proposed in Deformable DETR.
    """

    def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim (d_model) must be divisible by num_heads, but got {embed_dim} and {num_heads}"
            )

        head_dim = embed_dim // num_heads
        # A positive integer is a power of two exactly when it shares no bits with its predecessor.
        head_dim_is_power_of_two = head_dim != 0 and (head_dim & (head_dim - 1)) == 0
        if not head_dim_is_power_of_two:
            warnings.warn(
                "You'd better set embed_dim (d_model) in DeformableDetrMultiscaleDeformableAttention to make the"
                " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA"
                " implementation."
            )

        self.im2col_step = 128

        self.d_model = embed_dim
        self.n_levels = n_levels
        self.n_heads = num_heads
        self.n_points = n_points

        # Each query predicts, per head / level / point, one 2D offset and one attention weight.
        num_samples = num_heads * n_levels * n_points
        self.sampling_offsets = nn.Linear(embed_dim, num_samples * 2)
        self.attention_weights = nn.Linear(embed_dim, num_samples)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.output_proj = nn.Linear(embed_dim, embed_dim)

    def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
        if position_embeddings is None:
            return tensor
        return tensor + position_embeddings

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        position_embeddings: Optional[torch.Tensor] = None,
        reference_points=None,
        spatial_shapes=None,
        level_start_index=None,
        output_attentions: bool = False,
    ):
        """
        Attend from `hidden_states` (queries) to `encoder_hidden_states` (values) at learned sampling locations.

        Returns a tuple `(output, attention_weights)`. Raises `ValueError` when the areas in `spatial_shapes` do
        not sum to the encoder sequence length, or when the last dimension of `reference_points` is neither 2
        nor 4. `encoder_attention_mask`, `level_start_index` and `output_attentions` are accepted but unused.
        """
        # Position embeddings (if any) are added to the queries before projecting them.
        hidden_states = self.with_pos_embed(hidden_states, position_embeddings)

        _, num_queries, _ = hidden_states.shape
        batch_size, sequence_length, _ = encoder_hidden_states.shape
        total_positions = (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum()
        if total_positions != sequence_length:
            raise ValueError(
                "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
            )

        value = self.value_proj(encoder_hidden_states)
        if attention_mask is not None:
            # Zero the values wherever the mask is set.
            value = value.masked_fill(attention_mask[..., None], 0.0)
        value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)

        sampling_offsets = self.sampling_offsets(hidden_states).view(
            batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
        )

        # The softmax is taken jointly over all levels and points of a head, then split back apart.
        attention_weights = self.attention_weights(hidden_states).view(
            batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
        )
        attention_weights = nn.functional.softmax(attention_weights, -1).view(
            batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
        )

        num_coordinates = reference_points.shape[-1]
        if num_coordinates == 2:
            # Offsets are divided by (width, height) of each level before being added to the reference points.
            offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            normalized_offsets = sampling_offsets / offset_normalizer[None, None, None, :, None, :]
            sampling_locations = reference_points[:, :, None, :, None, :] + normalized_offsets
        elif num_coordinates == 4:
            # The last two reference coordinates scale the offsets in this branch.
            scaled_offsets = sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
            sampling_locations = reference_points[:, :, None, :, None, :2] + scaled_offsets
        else:
            raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

        output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
        output = self.output_proj(output)

        return output, attention_weights
class OneFormerPixelDecoderEncoderLayer(nn.Module):
    """
    One pixel-decoder encoder layer: multiscale deformable self-attention followed by a two-layer feed-forward
    block, each wrapped in dropout, a residual connection and a layer norm.
    """

    def __init__(self, config: OneFormerConfig):
        super().__init__()
        self.embed_dim = config.conv_dim
        self.self_attn = OneFormerPixelDecoderEncoderMultiscaleDeformableAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            n_levels=3,
            n_points=4,
        )

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.dropout = config.dropout
        self.activation_fn = nn.functional.relu
        self.activation_dropout = config.dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_feedforward_dim)
        self.fc2 = nn.Linear(config.encoder_feedforward_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

        # Dropout and the overflow clamp in `forward` are driven by this config flag, not by
        # the module's own `training` attribute.
        self.is_training = config.is_training

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor = None,
        reference_points=None,
        spatial_shapes=None,
        level_start_index=None,
        output_attentions: bool = False,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Input to the layer.
            attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
                Attention mask.
            position_embeddings (`torch.FloatTensor`, *optional*):
                Position embeddings, to be added to `hidden_states`.
            reference_points (`torch.FloatTensor`, *optional*):
                Reference points.
            spatial_shapes (`torch.LongTensor`, *optional*):
                Spatial shapes of the backbone feature maps.
            level_start_index (`torch.LongTensor`, *optional*):
                Level start index.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            `tuple`: `(hidden_states,)`, followed by the attention weights when `output_attentions` is truthy.
        """
        residual = hidden_states

        # Self-attention: the same tensor serves as queries and as encoder states.
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            output_attentions=output_attentions,
        )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.is_training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Feed-forward block.
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.is_training)

        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.is_training)

        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        if self.is_training:
            # Clamp only when an inf/nan has appeared, to a bound just inside the dtype's range.
            if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
"""
Modified from transformers.models.detr.modeling_deformable_detr.DeformableDetrEncoder with DeformableDetrEncoder->OneFormerPixelDecoderEncoderOnly
"""
class OneFormerPixelDecoderEncoderOnly(nn.Module):
"""
Transformer encoder consisting of *config.encoder_layers* deformable attention layers. Each layer is a
[`OneFormerPixelDecoderEncoderLayer`].
The encoder updates the flattened multi-scale feature maps through multiple deformable attention layers.
Args:
config: OneFormerConfig
"""
def __init__(self, config: OneFormerConfig):
    """Store `config` and its dropout rate, and build `config.encoder_layers` encoder layers."""
    super().__init__()

    self.config = config
    self.dropout = config.dropout

    encoder_layers = []
    for _ in range(config.encoder_layers):
        encoder_layers.append(OneFormerPixelDecoderEncoderLayer(config))
    self.layers = nn.ModuleList(encoder_layers)
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
"""
Get reference points for each feature map. Used in decoder.
Args:
spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
Spatial shapes of each feature map.
valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
Valid ratios of each feature map.
device (`torch.device`):
Device on which to create the tensors.
Returns:
`torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)`
"""
reference_points_list = []
for lvl, (height, width) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
)
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * height)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * width)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
def forward(
    self,
    inputs_embeds=None,
    attention_mask=None,
    position_embeddings=None,
    spatial_shapes=None,
    level_start_index=None,
    valid_ratios=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    """
    Forward pass of the encoder.

    NOTE(review): in this excerpt the body is a placeholder that ignores all arguments and returns `None`.
    """
    return None
def get_valid_ratio(self, mask, dtype=torch.float32):
    """
    Get the valid ratio of all feature maps.

    `mask` is a 3D boolean tensor in which `True` marks invalid positions. The valid height is counted along
    the first column and the valid width along the first row. The result stacks `(width_ratio, height_ratio)`
    in the last dimension, cast to `dtype`.
    """
    _, total_height, total_width = mask.shape
    is_valid = ~mask

    # Count the valid entries of the first column / first row of every sample.
    valid_rows = is_valid[:, :, 0].sum(1)
    valid_columns = is_valid[:, 0, :].sum(1)

    height_ratio = valid_rows.to(dtype) / total_height
    width_ratio = valid_columns.to(dtype) / total_width
    return torch.stack([width_ratio, height_ratio], -1)
# NOTE(review): only the opening of this signature is present in this excerpt -- the closing
# parenthesis and the body are missing, and the header of the class it belongs to is not shown.
# Restore the rest from the full source before use.
def forward(
    self,
    features,
    encoder_outputs=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
class OneFormerPixelLevelModule(nn.Module):
    def __init__(self, config: OneFormerConfig):
        """
        Pixel Level Module proposed in [Masked-attention Mask Transformer for Universal Image
        Segmentation](https://arxiv.org/abs/2112.01527). It runs the input image through a backbone and a pixel
        decoder, generating multi-scale feature maps and pixel embeddings.

        Args:
            config ([`OneFormerConfig`]):
                The configuration used to instantiate this model.
        """
        super().__init__()
        backbone = load_backbone(config)
        self.encoder = backbone
        # The pixel decoder is sized from the channel counts reported by the backbone.
        self.decoder = OneFormerPixelDecoder(config, feature_channels=backbone.channels)

    def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> OneFormerPixelLevelModuleOutput:
        """
        Run the backbone on `pixel_values`, feed its feature maps to the pixel decoder, and return the encoder
        feature maps together with the decoder's multi-scale and mask features.
        """
        feature_maps: List[Tensor] = self.encoder(pixel_values).feature_maps
        pixel_decoder_output: OneFormerPixelDecoderOutput = self.decoder(
            feature_maps, output_hidden_states=output_hidden_states
        )
        encoder_features = tuple(feature_maps)
        return OneFormerPixelLevelModuleOutput(
            encoder_features=encoder_features,
            decoder_features=pixel_decoder_output.multi_scale_features,
            decoder_last_feature=pixel_decoder_output.mask_features,
        )
class OneFormerAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Here, we add position embeddings to the queries and
    keys (as explained in the DETR paper).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout

        head_dim = embed_dim // num_heads
        self.head_dim = head_dim
        # The heads must tile the embedding exactly.
        if head_dim * num_heads != embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {num_heads})."
            )
        # 1 / sqrt(head_dim)
        self.scaling = head_dim**-0.5

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
        # (batch, seq, embed) -> (batch, heads, seq, head_dim), made contiguous.
        per_head = tensor.view(batch_size, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
        if position_embeddings is None:
            return tensor
        return tensor + position_embeddings
class OneFormerTransformerDecoderSelfAttentionLayer(nn.Module):
    """
    Self-attention sub-layer with dropout, a residual connection and a layer norm. `normalize_before` selects
    whether the norm is applied before the attention (`forward_pre`) or after the residual (`forward_post`).
    """

    def __init__(
        self, embed_dim, num_heads, dropout=0.0, activation="relu", normalize_before=False, layer_norm_eps=1e-05
    ):
        super().__init__()
        self.self_attn = OneFormerAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, is_decoder=True)

        self.norm = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.dropout = nn.Dropout(dropout)

        self.activation = ACT2FN[activation]
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(
        self,
        output,
        output_mask: Optional[Tensor] = None,
        output_key_padding_mask: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        # Post-norm: attention -> dropout -> residual -> norm.
        attended, attention_weights = self.self_attn(
            hidden_states=output, position_embeddings=query_pos, attention_mask=output_mask, output_attentions=True
        )
        residual_sum = output + self.dropout(attended)
        return self.norm(residual_sum), attention_weights

    def forward_pre(
        self,
        output,
        output_mask: Optional[Tensor] = None,
        output_key_padding_mask: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        # Pre-norm: norm -> attention -> dropout -> residual.
        normalized = self.norm(output)
        attended, attention_weights = self.self_attn(
            hidden_states=normalized, position_embeddings=query_pos, attention_mask=output_mask, output_attentions=True
        )
        return output + self.dropout(attended), attention_weights

    def forward(
        self,
        output,
        output_mask: Optional[Tensor] = None,
        output_key_padding_mask: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        variant = self.forward_pre if self.normalize_before else self.forward_post
        return variant(output, output_mask, output_key_padding_mask, query_pos)
class OneFormerTransformerDecoderCrossAttentionLayer(nn.Module):
    """
    Cross-attention sub-layer built on `nn.MultiheadAttention`, with dropout, a residual connection and a layer
    norm. `normalize_before` selects whether the norm is applied before the attention (`forward_pre`) or after
    the residual (`forward_post`).
    """

    def __init__(
        self, embed_dim, num_heads, dropout=0.0, activation="relu", normalize_before=False, layer_norm_eps=1e-05
    ):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)

        self.norm = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.dropout = nn.Dropout(dropout)

        self.activation = ACT2FN[activation]
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(
        self,
        output,
        memory,
        memory_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        # Post-norm: attention -> dropout -> residual -> norm. Position embeddings are added to the
        # query and key only; the value is the raw memory.
        attended, attention_weights = self.multihead_attn(
            query=self.with_pos_embed(output, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        residual_sum = output + self.dropout(attended)
        return self.norm(residual_sum), attention_weights

    def forward_pre(
        self,
        output,
        memory,
        memory_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        # Pre-norm: norm -> attention -> dropout -> residual.
        normalized = self.norm(output)
        attended, attention_weights = self.multihead_attn(
            query=self.with_pos_embed(normalized, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        return output + self.dropout(attended), attention_weights

    def forward(
        self,
        output,
        memory,
        memory_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        variant = self.forward_pre if self.normalize_before else self.forward_post
        return variant(output, memory, memory_mask, memory_key_padding_mask, pos, query_pos)
class OneFormerTransformerDecoderFFNLayer(nn.Module):
    """Position-wise feed-forward sub-layer (linear -> activation -> dropout -> linear) with residual and LayerNorm."""

    def __init__(
        self,
        d_model,
        dim_feedforward=2048,
        dropout=0.0,
        activation="relu",
        normalize_before=False,
        layer_norm_eps=1e-05,
    ):
        super().__init__()
        # Feed-forward projection pair; the same dropout module is applied inside and after the block.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self.activation = ACT2FN[activation]
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return `tensor + pos`, or `tensor` unchanged when `pos` is None."""
        return tensor if pos is None else tensor + pos

    def forward_post(self, output):
        """Post-norm variant: feed-forward, add the residual, then normalize."""
        output2 = self.linear2(self.dropout(self.activation(self.linear1(output))))
        output = output + self.dropout(output2)
        output = self.norm(output)
        return output

    def forward_pre(self, output):
        """Pre-norm variant: normalize, feed-forward, then add the un-normalized residual."""
        output2 = self.norm(output)
        # Fixed: the original line was missing the closing parenthesis of the `linear2(...)` call.
        output2 = self.linear2(self.dropout(self.activation(self.linear1(output2))))
        output = output + self.dropout(output2)
        return output

    def forward(self, output):
        """Dispatch to the pre-norm or post-norm variant according to `self.normalize_before`."""
        if self.normalize_before:
            return self.forward_pre(output)
        return self.forward_post(output)
class OneFormerMLPPredictionHead(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 3):
        """
        A classic Multi Layer Perceptron (MLP) built from `PredictionBlock`s.

        Args:
            input_dim (`int`):
                The input dimensions.
            hidden_dim (`int`):
                The hidden dimensions.
            output_dim (`int`):
                The output dimensions.
            num_layers (int, *optional*, defaults to 3):
                The number of layers.
        """
        super().__init__()
        hidden_dims = [hidden_dim] * (num_layers - 1)
        block_inputs = [input_dim] + hidden_dims
        block_outputs = hidden_dims + [output_dim]
        last_index = num_layers - 1

        # Every block but the last is followed by ReLU; the last one uses Identity.
        blocks = [
            PredictionBlock(in_dim, out_dim, activation=nn.ReLU() if idx < last_index else nn.Identity())
            for idx, (in_dim, out_dim) in enumerate(zip(block_inputs, block_outputs))
        ]
        self.layers = nn.Sequential(*blocks)

    def forward(self, input: Tensor) -> Tensor:
        """Run `input` through the stacked prediction blocks."""
        return self.layers(input)
class OneFormerTransformerDecoderLayer(nn.Module):
    """One decoder layer: masked cross-attention, then self-attention, then a feed-forward block."""

    def __init__(self, config: OneFormerConfig):
        super().__init__()
        self.embed_dim = config.hidden_dim
        self.num_feature_levels = 3

        self.cross_attn = OneFormerTransformerDecoderCrossAttentionLayer(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            dropout=0.0,
            normalize_before=config.pre_norm,
            layer_norm_eps=config.layer_norm_eps,
        )

        self.self_attn = OneFormerTransformerDecoderSelfAttentionLayer(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            dropout=0.0,
            normalize_before=config.pre_norm,
            layer_norm_eps=config.layer_norm_eps,
        )

        self.ffn = OneFormerTransformerDecoderFFNLayer(
            d_model=self.embed_dim,
            dim_feedforward=config.dim_feedforward,
            dropout=0.0,
            normalize_before=config.pre_norm,
            layer_norm_eps=config.layer_norm_eps,
        )

    def forward(
        self,
        index: int,
        output: torch.Tensor,
        multi_stage_features: List[torch.Tensor],
        multi_stage_positional_embeddings: List[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        query_embeddings: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ):
        """
        Args:
            index (`int`): Index of this layer in the Transformer decoder.
            output (`torch.FloatTensor`): The object queries, of shape `(N, batch, hidden_dim)`.
            multi_stage_features (`List[torch.Tensor]`): The multi-scale features from the pixel decoder.
            multi_stage_positional_embeddings (`List[torch.Tensor]`):
                Positional embeddings for the multi-scale features.
            attention_mask (`torch.FloatTensor`, *optional*): Attention mask used for the masked cross-attention.
                Modified in place (see below). May be `None`, in which case cross-attention is unmasked.
            query_embeddings (`torch.FloatTensor`, *optional*):
                Position embeddings that are added to the queries and keys in the self-attention layer.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            A tuple whose first element is the updated queries; when `output_attentions` is truthy it is followed
            by the self-attention weights and the cross-attention weights.
        """
        # Layers cycle through the feature levels.
        level_index = index % self.num_feature_levels

        if attention_mask is not None:
            # Rows that are entirely True are reset to False in place. In `nn.MultiheadAttention` a True entry
            # blocks attention, so a fully-True row would leave that query with nothing to attend to.
            attention_mask[torch.where(attention_mask.sum(-1) == attention_mask.shape[-1])] = False

        # Masked cross-attention
        output, cross_attn_weights = self.cross_attn(
            output,
            multi_stage_features[level_index],
            memory_mask=attention_mask,
            memory_key_padding_mask=None,  # no padding positions are masked here
            pos=multi_stage_positional_embeddings[level_index],
            query_pos=query_embeddings,
        )

        # Self-attention
        output, self_attn_weights = self.self_attn(
            output,
            output_mask=None,
            output_key_padding_mask=None,
            query_pos=query_embeddings,
        )

        # Fully connected
        output = self.ffn(output)

        outputs = (output,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        return outputs
class OneFormerTransformerDecoderQueryTransformerDecoderLayer(nn.Module):
    """
    Decoder layer of the query transformer: self-attention, cross-attention over `memory`, and a feed-forward
    block, each with a residual connection and its own LayerNorm.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        layer_norm_eps=1e-05,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feed-forward block
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = ACT2FN[activation]
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return `tensor + pos`, or `tensor` unchanged when `pos` is None."""
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        output,
        memory,
        output_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        output_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Post-norm variant: each sub-layer is followed by residual addition and then LayerNorm."""
        # Self-attention: queries and keys carry the positional term, values do not.
        q = k = self.with_pos_embed(output, query_pos)
        output2 = self.self_attn(q, k, value=output, attn_mask=output_mask, key_padding_mask=output_key_padding_mask)
        output2 = output2[0]  # keep the attended values, drop the weights
        output = output + self.dropout1(output2)
        output = self.norm1(output)

        # Cross-attention over `memory`.
        output2 = self.multihead_attn(
            query=self.with_pos_embed(output, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        output2 = output2[0]
        output = output + self.dropout2(output2)
        output = self.norm2(output)

        # Feed-forward.
        output2 = self.linear2(self.dropout(self.activation(self.linear1(output))))
        output = output + self.dropout3(output2)
        output = self.norm3(output)
        return output

    def forward_pre(
        self,
        output,
        memory,
        output_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        output_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Pre-norm variant: LayerNorm is applied to the input of each sub-layer, the residual is added after."""
        output2 = self.norm1(output)
        q = k = self.with_pos_embed(output2, query_pos)
        output2 = self.self_attn(q, k, value=output2, attn_mask=output_mask, key_padding_mask=output_key_padding_mask)
        output2 = output2[0]
        output = output + self.dropout1(output2)

        output2 = self.norm2(output)
        output2 = self.multihead_attn(
            query=self.with_pos_embed(output2, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        output2 = output2[0]
        output = output + self.dropout2(output2)

        output2 = self.norm3(output)
        output2 = self.linear2(self.dropout(self.activation(self.linear1(output2))))
        output = output + self.dropout3(output2)
        return output

    def forward(
        self,
        output,
        memory,
        output_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        output_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Dispatch to the pre-norm or post-norm variant according to `self.normalize_before`."""
        # `normalize_before` is the attribute set in `__init__`; this class has no `pre_norm` attribute.
        if self.normalize_before:
            return self.forward_pre(
                output,
                memory,
                output_mask,
                memory_mask,
                output_key_padding_mask,
                memory_key_padding_mask,
                pos,
                query_pos,
            )
        return self.forward_post(
            output,
            memory,
            output_mask,
            memory_mask,
            output_key_padding_mask,
            memory_key_padding_mask,
            pos,
            query_pos,
        )


class OneFormerTransformerDecoderQueryTransformerDecoder(nn.Module):
    """Stack of `num_layers` clones of `decoder_layer`, with an optional final norm."""

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        output,
        memory,
        output_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        output_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """
        Run `output` through every layer in order.

        Returns:
            With `return_intermediate`, the stacked normalized outputs of all layers; otherwise the final
            output with a new leading dimension of size 1.
        """
        intermediate = []

        for layer in self.layers:
            output = layer(
                output,
                memory,
                output_mask=output_mask,
                memory_mask=memory_mask,
                output_key_padding_mask=output_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                # Replace the last entry with the final normalized output.
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)
class OneFormerTransformerDecoderQueryTransformer(nn.Module):
    """
    Transformer decoder module used to produce the object queries.

    Args:
        d_model (int): Model (input and output) dimension, defaults to 512.
        nhead (int): Number of attention heads, defaults to 8.
        num_decoder_layers (int): Number of stacked decoder layers, defaults to 6.
        dim_feedforward (int): Dimension of the feed-forward network, defaults to 2048.
        dropout (float): Dropout probability, defaults to 0.1.
        activation (str): Name of the activation function, defaults to "relu".
        normalize_before (bool): Whether LayerNorm is applied before each sub-layer, defaults to False.
        return_intermediate_dec (bool): Whether to return the output of every decoder layer, defaults to False.
        layer_norm_eps (float): Epsilon used by LayerNorm, defaults to 1e-05.
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
        layer_norm_eps=1e-05,
    ):
        super().__init__()

        # One template layer; the decoder below is built from it.
        layer_template = OneFormerTransformerDecoderQueryTransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before, layer_norm_eps
        )
        final_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.decoder = OneFormerTransformerDecoderQueryTransformerDecoder(
            layer_template,
            num_decoder_layers,
            final_norm,
            return_intermediate=return_intermediate_dec,
        )

        self.d_model = d_model
        self.nhead = nhead

    def forward(self, src, mask, query_embed, pos_embed, task_token=None):
        """
        Run the query decoder over a flattened feature map.

        Args:
            src (torch.Tensor): Source features. Flattened from dim 2 onward and permuted to `(2, 0, 1)`;
                assumes a `(batch, channels, height, width)` layout — TODO confirm against the caller.
            mask (torch.Tensor, optional): Flattened from dim 1 when given and passed to the decoder as
                `memory_key_padding_mask`.
            query_embed (torch.Tensor): Query embeddings; a batch dimension is inserted at dim 1 and repeated
                `batch_size` times.
            pos_embed (torch.Tensor): Positional embeddings, flattened and permuted the same way as `src`.
            task_token (torch.Tensor, optional): When given, it is repeated once per query to initialise the
                queries; otherwise the queries start as zeros.

        Returns:
            torch.Tensor: The decoder output with dims 1 and 2 swapped.
        """
        batch_size = src.shape[0]
        memory = src.flatten(2).permute(2, 0, 1)
        memory_pos = pos_embed.flatten(2).permute(2, 0, 1)
        query_pos = query_embed.unsqueeze(1).repeat(1, batch_size, 1)
        padding_mask = mask.flatten(1) if mask is not None else mask

        if task_token is None:
            queries = torch.zeros_like(query_pos)
        else:
            queries = task_token.repeat(query_pos.shape[0], 1, 1)

        decoded = self.decoder(
            queries, memory, memory_key_padding_mask=padding_mask, pos=memory_pos, query_pos=query_pos
        )
        return decoded.transpose(1, 2)
class OneFormerTransformerDecoder(nn.Module):
    """
    Transformer decoder: builds the initial queries with the query transformer, then refines them layer by
    layer, producing class and mask predictions after every step.
    """

    def __init__(self, in_channels: int, config: OneFormerConfig):
        super().__init__()
        self.config = config

        self.dropout = config.dropout
        self.num_heads = config.num_attention_heads
        self.is_training = config.is_training
        self.use_task_norm = config.use_task_norm
        self.use_auxiliary_loss = config.use_auxiliary_loss

        self.query_transformer = OneFormerTransformerDecoderQueryTransformer(
            d_model=config.hidden_dim,
            dropout=config.dropout,
            nhead=config.num_attention_heads,
            dim_feedforward=config.dim_feedforward,
            num_decoder_layers=config.query_dec_layers,
            normalize_before=config.pre_norm,
            return_intermediate_dec=False,
            layer_norm_eps=config.layer_norm_eps,
        )

        self.decoder_norm = nn.LayerNorm(config.hidden_dim, eps=config.layer_norm_eps)

        self.num_feature_levels = 3

        self.layers = nn.ModuleList(
            [OneFormerTransformerDecoderLayer(config) for _ in range(config.decoder_layers - 1)]
        )

        self.query_input_projection = nn.Conv2d(in_channels, config.hidden_dim, kernel_size=1)

        self.class_embed = nn.Linear(config.hidden_dim, config.num_labels + 1)
        self.mask_embed = OneFormerMLPPredictionHead(
            config.hidden_dim,
            config.hidden_dim,
            config.mask_dim,
            3,
        )

    def forward(
        self,
        task_token=None,
        multi_stage_features=None,
        multi_stage_positional_embeddings=None,
        mask_features=None,
        query_features=None,
        query_embeddings=None,
        query_embedder=None,
        size_list=None,
        output_attentions=None,
    ):
        """
        Decode object queries and produce per-step class/mask predictions.

        Args:
            task_token: Task token; normalized with `decoder_norm` when `use_task_norm` is set and appended to
                the object queries along dim 0.
            multi_stage_features: Per-level features handed to every decoder layer.
            multi_stage_positional_embeddings: Per-level positional embeddings handed to every decoder layer.
            mask_features: Features projected for the query transformer and combined with the mask embeddings
                in `forward_prediction_heads`.
            query_features: First positional argument of the query transformer.
            query_embeddings: Passed to every decoder layer as `query_embeddings`.
            query_embedder: Module whose `weight[:-1]` is used as the query embedding of the query transformer.
            size_list: Per-level target sizes for the attention masks.
            output_attentions: Forwarded to every decoder layer.

        Returns:
            `OneFormerTransformerDecoderOutput`.

        Raises:
            ValueError: If the number of intermediate predictions is not `len(self.layers) + 1`.
        """
        if self.use_task_norm:
            task_token = self.decoder_norm(task_token)

        object_queries = self.query_transformer(
            query_features,
            None,
            query_embedder.weight[:-1],
            self.query_input_projection(mask_features),
            task_token if self.use_task_norm else None,
        )

        object_queries = object_queries[0].permute(1, 0, 2)

        queries = torch.cat([object_queries, task_token], dim=0)

        output = queries.clone()

        intermediate_class_predictions = []
        intermediate_mask_predictions = []

        # Prediction heads on the learnable query features, before any decoder layer has run.
        outputs_class, outputs_mask, attention_mask = self.forward_prediction_heads(
            output, mask_features, attention_mask_target_size=size_list[0]
        )
        intermediate_class_predictions.append(outputs_class)
        intermediate_mask_predictions.append(outputs_mask)

        attentions = ()

        for index, layer in enumerate(self.layers):
            layer_outputs = layer(
                index=index,
                output=output,
                multi_stage_features=multi_stage_features,
                multi_stage_positional_embeddings=multi_stage_positional_embeddings,
                attention_mask=attention_mask,
                query_embeddings=query_embeddings,
                output_attentions=output_attentions,
            )

            output = layer_outputs[0]
            attentions += (layer_outputs[1:],)

            # The attention mask for the next layer is computed at the next feature level's resolution.
            outputs_class, outputs_mask, attention_mask = self.forward_prediction_heads(
                output, mask_features, attention_mask_target_size=size_list[(index + 1) % self.num_feature_levels]
            )
            intermediate_class_predictions.append(outputs_class)
            intermediate_mask_predictions.append(outputs_mask)

        if len(intermediate_mask_predictions) != len(self.layers) + 1:
            raise ValueError(
                "Intermediate predictions in the transformer decoder must have the same number of elements as number"
                " of layers"
            )

        # `output` equals the last layer's first output, and stays defined when there are no layers.
        object_queries = output.permute(1, 0, 2)

        contrastive_logits = queries.permute(1, 0, 2)

        return OneFormerTransformerDecoderOutput(
            object_queries=object_queries,
            contrastive_logits=contrastive_logits,
            prediction_masks=intermediate_mask_predictions[-1],
            prediction_class=intermediate_class_predictions[-1],
            auxiliary_predictions=self._get_aux_predictions(
                intermediate_class_predictions, intermediate_mask_predictions
            )
            if self.use_auxiliary_loss
            else None,
            attentions=attentions,
        )

    def forward_prediction_heads(self, output, mask_features, attention_mask_target_size):
        """
        Compute class logits, mask logits and the attention mask derived from the mask logits.

        Returns:
            A tuple `(outputs_class, outputs_mask, attention_mask)`; `attention_mask` is a detached boolean
            tensor, repeated once per attention head.
        """
        decoder_output = self.decoder_norm(output)
        decoder_output = decoder_output.transpose(0, 1)
        outputs_class = self.class_embed(decoder_output)
        mask_embed = self.mask_embed(decoder_output)
        outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)

        attention_mask = nn.functional.interpolate(
            outputs_mask, size=attention_mask_target_size, mode="bilinear", align_corners=False
        )

        # True where the predicted mask probability is below 0.5.
        attention_mask = (
            attention_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5
        ).bool()
        attention_mask = attention_mask.detach()

        return outputs_class, outputs_mask, attention_mask

    def _get_aux_predictions(self, outputs_class, outputs_seg_masks):
        """Pair up every class/mask prediction except the last into a tuple of dicts."""
        aux_list = [
            {"class_queries_logits": a, "masks_queries_logits": b}
            for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
        ]
        return tuple(aux_list)
class OneFormerTransformerModule(nn.Module):
    """
    The OneFormer's transformer module.
    """

    def __init__(self, in_features: int, config: OneFormerConfig):
        super().__init__()
        hidden_dim = config.hidden_dim
        self.num_feature_levels = 3
        self.position_embedder = OneFormerSinePositionEmbedding(num_pos_feats=hidden_dim // 2, normalize=True)
        self.queries_embedder = nn.Embedding(config.num_queries, hidden_dim)

        # Held in an nn.ModuleList so the projections are registered as sub-modules: their parameters then
        # follow `.to()`, and appear in `parameters()` and `state_dict()`. A plain list would hide them.
        self.input_projections = nn.ModuleList()
        for _ in range(self.num_feature_levels):
            if in_features != hidden_dim or config.enforce_input_proj:
                self.input_projections.append(nn.Conv2d(in_features, hidden_dim, kernel_size=1))
            else:
                # Empty Sequential acts as a pass-through.
                self.input_projections.append(nn.Sequential())

        self.decoder = OneFormerTransformerDecoder(in_channels=in_features, config=config)
        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)

    def forward(
        self,
        multi_scale_features: List[Tensor],
        mask_features: Tensor,
        task_token: Tensor,
        output_attentions: bool = False,
    ) -> OneFormerTransformerDecoderOutput:
        """
        Prepare per-level features and positional embeddings, then run the decoder.

        Args:
            multi_scale_features: One feature map per level; must contain `num_feature_levels` entries.
            mask_features: Mask features forwarded to the decoder and used to build `query_features`.
            task_token: Task token; a leading dimension is added before it is handed to the decoder.
            output_attentions: Forwarded to the decoder.

        Raises:
            ValueError: If `multi_scale_features` does not have `num_feature_levels` entries.
        """
        if len(multi_scale_features) != self.num_feature_levels:
            raise ValueError(
                f"Number of elements in multi_scale_features ({len(multi_scale_features)}) and num_feature_levels"
                f" ({self.num_feature_levels}) do not match!"
            )
        multi_stage_features = []
        multi_stage_positional_embeddings = []
        size_list = []

        for i in range(self.num_feature_levels):
            size_list.append(multi_scale_features[i].shape[-2:])
            multi_stage_positional_embeddings.append(self.position_embedder(multi_scale_features[i], None).flatten(2))
            multi_stage_features.append(
                self.input_projections[i](multi_scale_features[i]).flatten(2)
                + self.level_embed.weight[i][None, :, None]
            )

            # Move the flattened spatial axis to the front (sequence-first layout).
            multi_stage_positional_embeddings[-1] = multi_stage_positional_embeddings[-1].permute(2, 0, 1)
            multi_stage_features[-1] = multi_stage_features[-1].permute(2, 0, 1)

        _, batch_size, _ = multi_stage_features[0].shape

        # One copy of the query embeddings per batch element.
        query_embeddings = self.queries_embedder.weight.unsqueeze(1).repeat(1, batch_size, 1)
        task_token = task_token.unsqueeze(0)

        query_features = self.position_embedder(mask_features, None)

        return self.decoder(
            task_token=task_token,
            multi_stage_features=multi_stage_features,
            multi_stage_positional_embeddings=multi_stage_positional_embeddings,
            mask_features=mask_features,
            query_features=query_features,
            query_embeddings=query_embeddings,
            query_embedder=self.queries_embedder,
            size_list=size_list,
            output_attentions=output_attentions,
        )
class OneFormerSinePositionEmbedding(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
    need paper, generalized to work on images.
    """

    def __init__(
        self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None
    ):
        super().__init__()
        # A custom scale is only meaningful for normalized coordinates.
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """
        Build sine/cosine position embeddings for the spatial grid of `x`.

        Args:
            x: 4-D tensor; dims 0, 2 and 3 give the batch size, height and width of the grid.
            mask: Optional boolean mask over `(batch, height, width)`; True marks positions to exclude.
                Defaults to an all-False mask.

        Returns:
            Tensor with `2 * num_pos_feats` channels at dim 1 (y features first, then x features).
        """
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        valid = (~mask).to(x.dtype)

        # Running counts of valid positions along height and width.
        row_coord = valid.cumsum(1)
        col_coord = valid.cumsum(2)
        if self.normalize:
            eps = 1e-6
            row_coord = row_coord / (row_coord[:, -1:, :] + eps) * self.scale
            col_coord = col_coord / (col_coord[:, :, -1:] + eps) * self.scale

        feature_index = torch.arange(self.num_pos_feats, dtype=torch.int64, device=x.device).type_as(x)
        frequencies = self.temperature ** (
            2 * torch.div(feature_index, 2, rounding_mode="floor") / self.num_pos_feats
        )

        def _encode(coord: Tensor) -> Tensor:
            # Even feature slots take sin, odd slots take cos; the pairs are then interleaved.
            scaled = coord[:, :, :, None] / frequencies
            return torch.stack((scaled[:, :, :, 0::2].sin(), scaled[:, :, :, 1::2].cos()), dim=4).flatten(3)

        x_features = _encode(col_coord)
        y_features = _encode(row_coord)
        return torch.cat((y_features, x_features), dim=3).permute(0, 3, 1, 2)
class PredictionBlock(nn.Module):
    """A linear layer followed by an activation module."""

    def __init__(self, in_dim: int, out_dim: int, activation: nn.Module) -> None:
        super().__init__()
        linear = nn.Linear(in_dim, out_dim)
        self.layers = [linear, activation]
        # Register both under numeric names ("0", "1"), matching how they would be named in an nn.Sequential.
        self.add_module("0", linear)
        self.add_module("1", activation)

    def forward(self, input: Tensor) -> Tensor:
        """Apply the linear layer and then the activation."""
        linear, activation = self.layers
        return activation(linear(input))
class OneFormerTextMapperAttention(nn.Module):
    """Multi-head attention with separate q/k/v projections and an output projection."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        per_head_dim = dim // num_heads
        # A falsy `qk_scale` (None or 0) falls back to 1/sqrt(head_dim).
        self.scale = qk_scale or per_head_dim ** -0.5

        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)

        # Created here but not applied in `forward`.
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, k, v):
        """
        Attend from `q` over `k`/`v`.

        Args:
            q: Queries of shape `(batch, q_len, channels)`.
            k: Keys of shape `(batch, k_len, channels)`.
            v: Values; must have the same shape as `k`.

        Returns:
            Tensor of shape `(batch, q_len, channels)`.

        Raises:
            ValueError: If `k` and `v` have different shapes.
        """
        batch_size, q_len, num_channels = q.shape
        if k.shape != v.shape:
            raise ValueError(f"keys ({list(k.shape)}) and values ({list(v.shape)}) have different shapes!")
        batch_size, k_len, num_channels = k.shape
        head_dim = num_channels // self.num_heads

        # Split the channel axis into (heads, head_dim).
        q = self.q_proj(q).reshape(batch_size, q_len, self.num_heads, head_dim)
        k = self.k_proj(k).reshape(batch_size, k_len, self.num_heads, head_dim)
        v = self.v_proj(v).reshape(batch_size, k_len, self.num_heads, head_dim)

        scores = torch.einsum("bnkc,bmkc->bknm", q, k) * self.scale
        weights = scores.softmax(dim=-1)

        context = torch.einsum("bknm,bmkc->bnkc", weights, v).reshape(batch_size, q_len, num_channels)

        return self.proj_drop(self.proj(context))
class OneFormerTextTransformerDecoderLayer(nn.Module):
    """Pre-norm decoder layer: self-attention, cross-attention over `mem`, then an MLP, each with a residual."""

    def __init__(
        self,
        d_model,
        nhead,
        dropout=0.1,
        layer_norm_eps=1e-05,
    ):
        super().__init__()
        self.self_attn = OneFormerTextMapperAttention(d_model, nhead, proj_drop=dropout)
        self.cross_attn = OneFormerTextMapperAttention(d_model, nhead, proj_drop=dropout)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout = nn.Dropout(dropout)

        expanded_dim = d_model * 4
        self.mlp = nn.Sequential(
            nn.Linear(d_model, expanded_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(expanded_dim, d_model)
        )

    def forward(self, hidden_state, mem):
        """Update `hidden_state` by attending to itself and then to `mem`, followed by the MLP."""
        # Self-attention on the normalized state.
        normed = self.norm1(hidden_state)
        hidden_state = hidden_state + self.self_attn(normed, normed, normed)

        # Cross-attention: normalized state queries the memory.
        query = self.norm2(hidden_state)
        hidden_state = hidden_state + self.cross_attn(query, mem, mem)

        mlp_out = self.mlp(self.norm3(hidden_state))
        return hidden_state + self.dropout(mlp_out)
class OneFormerTextContextDecoder(nn.Module):
    """Decoder in which projected text features attend to projected visual features."""

    def __init__(
        self,
        transformer_width=256,
        transformer_heads=4,
        transformer_layers=6,
        visual_dim=1024,
        dropout=0.1,
        layer_norm_eps=1e-05,
        **kwargs,
    ):
        super().__init__()

        # Visual memory: norm -> project to transformer width -> norm.
        self.memory_proj = nn.Sequential(
            nn.LayerNorm(visual_dim, eps=layer_norm_eps),
            nn.Linear(visual_dim, transformer_width),
            nn.LayerNorm(transformer_width, eps=layer_norm_eps),
        )

        # Text input: norm -> project to transformer width.
        self.text_proj = nn.Sequential(
            nn.LayerNorm(visual_dim, eps=layer_norm_eps),
            nn.Linear(visual_dim, transformer_width),
        )

        decoder_layers = [
            OneFormerTextTransformerDecoderLayer(transformer_width, transformer_heads, dropout, layer_norm_eps)
            for _ in range(transformer_layers)
        ]
        self.decoder = nn.ModuleList(decoder_layers)

        # Project back to the visual dimension.
        self.out_proj = nn.Sequential(
            nn.LayerNorm(transformer_width, eps=layer_norm_eps), nn.Linear(transformer_width, visual_dim)
        )

    def forward(self, text, visual):
        """Project both inputs, run every decoder layer with `visual` as memory, and project the result back."""
        memory = self.memory_proj(visual)
        state = self.text_proj(text)

        for decoder_layer in self.decoder:
            state = decoder_layer(state, memory)

        return self.out_proj(state)
class OneFormerTextMLP(nn.Module):
    """Two-layer MLP with a "quick_gelu" activation between the layers."""

    def __init__(
        self,
        hidden_size: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        output_size: Optional[int] = None,
    ):
        """
        Args:
            hidden_size: Input feature size of the first linear layer.
            intermediate_size: Output size of the first and input size of the second linear layer.
            output_size: Output feature size of the second linear layer.
        """
        # Fixed: the signature carried a stray `...` entry and the `):` / `super().__init__()` pair was
        # duplicated; the no-op self-assignments of the three sizes were dropped.
        super().__init__()
        self.activation_fn = ACT2FN["quick_gelu"]
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, output_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `fc1`, the activation, then `fc2`."""
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states
class OneFormerTextTransformerLayer(nn.Module):
    """Pre-norm transformer layer: self-attention and an MLP, each wrapped in a residual connection."""

    def __init__(self, width: int, heads: int, attn_mask: torch.Tensor, layer_norm_eps=1e-05):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(width, heads)
        self.layer_norm1 = nn.LayerNorm(width, eps=layer_norm_eps)
        self.mlp = OneFormerTextMLP(width, width * 4, width)
        self.layer_norm2 = nn.LayerNorm(width, eps=layer_norm_eps)
        # Stored on the layer; `forward` does not pass it to the attention call.
        self.attn_mask = attn_mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """
        Args:
            hidden_states: Input to the layer.
            key_padding_mask: Forwarded to the self-attention module.

        Returns:
            The updated hidden states.
        """
        # Attention sub-layer.
        normed = self.layer_norm1(hidden_states)
        attended = self.self_attn(
            normed,
            normed,
            normed,
            need_weights=False,
            key_padding_mask=key_padding_mask,
        )[0]
        after_attention = hidden_states + attended

        # MLP sub-layer.
        mlp_out = self.mlp(self.layer_norm2(after_attention))
        return after_attention + mlp_out
class OneFormerTextTransformer(nn.Module):
    """Stack of `OneFormerTextTransformerLayer`s, optionally run through gradient checkpointing."""

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        attn_mask: torch.Tensor = None,
        use_checkpoint=False,
        layer_norm_eps=1e-05,
    ):
        super().__init__()
        self.width = width
        self.num_layers = layers
        stacked = [OneFormerTextTransformerLayer(width, heads, attn_mask, layer_norm_eps) for _ in range(layers)]
        self.layers = nn.Sequential(*stacked)
        self.use_checkpoint = use_checkpoint

    def forward(self, hidden_states: torch.Tensor):
        """Apply every layer in order; with `use_checkpoint`, each call goes through `_gradient_checkpointing_func`."""
        for layer in self.layers:
            hidden_states = (
                self._gradient_checkpointing_func(layer, hidden_states)
                if self.use_checkpoint
                else layer(hidden_states)
            )
        return hidden_states
class OneFormerTextEncoder(nn.Module):
    """Text encoder: token + positional embeddings, a transformer stack, and a final LayerNorm."""

    def __init__(
        self,
        context_length: int,
        width: int,
        layers: int,
        vocab_size,
        use_checkpoint=False,
        layer_norm_eps=1e-05,
    ):
        super().__init__()
        heads = width // 64
        self.context_length = context_length
        self.width = width
        self.transformer = OneFormerTextTransformer(
            width=width,
            layers=layers,
            heads=heads,
            attn_mask=self.build_attention_mask(),
            use_checkpoint=use_checkpoint,
            layer_norm_eps=layer_norm_eps,
        )

        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
        self.ln_final = nn.LayerNorm(width, eps=layer_norm_eps)
        self.token_embedding = nn.Embedding(vocab_size, width)

    # Fixed: this method's `def` line was missing, which fused its body (including a `return`) into `__init__`.
    def build_attention_mask(self):
        """
        Build an additive `(context_length, context_length)` mask.

        Entries strictly above the diagonal are `-inf`; the diagonal and everything below it are 0.
        """
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the diagonal and everything below it
        return mask

    # Fixed: this method's `def` line was missing as well.
    def forward(self, text):
        """
        Encode a batch of token-id sequences into one feature vector per sequence.

        Args:
            text: Integer token ids of shape `(batch, sequence)`; the sequence length must be compatible
                with `positional_embedding`.

        Returns:
            Tensor with one row per sequence, taken at the position of the largest token id.
        """
        hidden_state = self.token_embedding(text)
        hidden_state = hidden_state + self.positional_embedding
        # The transformer runs sequence-first; swap back afterwards.
        hidden_state = hidden_state.permute(1, 0, 2)
        hidden_state = self.transformer(hidden_state)
        hidden_state = hidden_state.permute(1, 0, 2)
        hidden_state = self.ln_final(hidden_state)
        # Select, per sequence, the features at the position holding the largest token id.
        hidden_state = hidden_state[torch.arange(hidden_state.shape[0]), text.argmax(dim=-1)]

        return hidden_state
class OneFormerTextMapper(nn.Module):
    """Encode tokenized text into query vectors, optionally followed by learned context embeddings."""

    def __init__(self, config: OneFormerConfig):
        super().__init__()
        self.text_encoder = OneFormerTextEncoder(
            context_length=config.text_encoder_context_length,
            width=config.text_encoder_width,
            layers=config.text_encoder_num_layers,
            vocab_size=config.text_encoder_vocab_size,
            layer_norm_eps=config.layer_norm_eps,
        )

        self.text_projector = OneFormerMLPPredictionHead(
            config.text_encoder_width,
            config.hidden_dim,
            config.hidden_dim,
            config.text_encoder_proj_layers,
        )

        # Learned context embeddings are only created when `text_encoder_n_ctx` is positive.
        self.prompt_ctx = None
        if config.text_encoder_n_ctx > 0:
            self.prompt_ctx = nn.Embedding(
                config.text_encoder_n_ctx,
                config.text_encoder_width,
            )

    def forward(
        self,
        inputs: Tensor,
    ) -> Tensor:
        """Return the text queries for `inputs` (see `encode_text`)."""
        return self.encode_text(inputs)

    def encode_text(self, text):
        """
        Encode and project `text`.

        Args:
            text: 2-D `(num_text, length)` or 3-D `(batch, num_text, length)` tensor.

        Returns:
            The projected text queries. For 3-D input they are reshaped back to `(batch, num_text, dim)` and,
            when `prompt_ctx` exists, its weights are concatenated along dim 1.

        Raises:
            ValueError: If `text.ndim` is None or is neither 2 nor 3.
        """
        if text.ndim is None:
            raise ValueError("text must not be NoneType")
        if text.ndim not in [2, 3]:
            raise ValueError("Number of dimensions in text must be 2 or 3")

        is_batched = text.ndim == 3
        if is_batched:
            # Fold the batch and text axes together for the encoder.
            batch_size, num_text, hidden_dim = text.shape
            text = text.reshape(batch_size * num_text, hidden_dim)

        text_queries = self.text_projector(self.text_encoder(text))
        if not is_batched:
            return text_queries

        _, projected_dim = text_queries.shape
        text_queries = text_queries.reshape(batch_size, num_text, projected_dim)
        if self.prompt_ctx is not None:
            context = self.prompt_ctx.weight.unsqueeze(0).repeat(text_queries.shape[0], 1, 1)
            text_queries = torch.cat([text_queries, context], dim=1)

        return text_queries
class OneFormerTaskModel(nn.Module):
    """Maps task inputs to task tokens with a two-layer MLP head."""

    def __init__(self, config: OneFormerConfig):
        super().__init__()
        num_mlp_layers = 2
        self.task_mlp = OneFormerMLPPredictionHead(
            config.task_seq_len,
            config.hidden_dim,
            config.hidden_dim,
            num_mlp_layers,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        """Project `inputs` through `task_mlp`."""
        return self.task_mlp(inputs)
# Class-level docstring preamble; passed to `add_start_docstrings` on the model classes defined below.
# NOTE(review): the last three lines inside this literal are Chinese annotation text. Roughly: "Initializes the
# model configuration class with all model parameters. Initializing from a config file does not load the model
# weights, only the configuration. To load the weights, see `~PreTrainedModel.from_pretrained`."
# They are part of the runtime string, so they are left untouched here.
ONEFORMER_START_DOCSTRING = r"""
This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use it as a
regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
Parameters:
config ([`OneFormerConfig`]): Model configuration class with all the parameters of the model.
初始化模型配置类,包含模型的所有参数。
使用配置文件初始化时,不会加载与模型相关的权重,只加载配置信息。
若要加载模型权重,请参阅 [`~PreTrainedModel.from_pretrained`] 方法。
"""
# Argument documentation for the model `forward` methods; passed to `add_start_docstrings_to_model_forward`.
ONEFORMER_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`OneFormerProcessor`]. See
[`OneFormerProcessor.__call__`] for details.
task_inputs (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Task inputs. Task inputs can be obtained using [`AutoImageProcessor`]. See [`OneFormerProcessor.__call__`]
for details.
pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:
- 1 for pixels that are real (i.e. **not masked**),
- 0 for pixels that are padding (i.e. **masked**).
[What are attention masks?](../glossary#attention-mask)
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of Detr's decoder attention layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~OneFormerModelOutput`] instead of a plain tuple.
"""
class OneFormerPreTrainedModel(PreTrainedModel):
    """Base class that binds the OneFormer config class and model naming conventions to `PreTrainedModel`."""

    # Name of the primary model input.
    main_input_name = "pixel_values"
    # Attribute name under which head models hold the base model.
    base_model_prefix = "model"
    # Configuration class associated with this model family.
    config_class = OneFormerConfig
@add_start_docstrings(
"The bare OneFormer Model outputting raw hidden-states without any specific head on top.",
ONEFORMER_START_DOCSTRING,
)
class OneFormerModel(OneFormerPreTrainedModel):
main_input_name = ["pixel_values", "task_inputs"]
def __init__(self, config: OneFormerConfig):
    """
    Assemble the pixel-level module, the transformer module and the task encoder.

    Args:
        config ([`OneFormerConfig`]): Model configuration. `config.conv_dim` sizes the transformer module's
            input features and `config.is_training` decides whether a text mapper is created.
    """
    super().__init__(config)
    self.pixel_level_module = OneFormerPixelLevelModule(config)
    self.transformer_module = OneFormerTransformerModule(in_features=config.conv_dim, config=config)
    self.task_encoder = OneFormerTaskModel(config)
    self.is_training = config.is_training

    # The text mapper is only instantiated in training mode; otherwise the attribute is left as None.
    self.text_mapper = OneFormerTextMapper(config) if self.is_training else None

    self.post_init()
@add_start_docstrings_to_model_forward(ONEFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=OneFormerModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Tensor,
task_inputs: Tensor,
text_inputs: Optional[Tensor] = None,
pixel_mask: Optional[Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
@add_start_docstrings(
"OneFormer Model for instance, semantic and panoptic image segmentation.",
ONEFORMER_START_DOCSTRING,
)
class OneFormerForUniversalSegmentation(OneFormerPreTrainedModel):
main_input_name = ["pixel_values", "task_inputs"]
def __init__(self, config: OneFormerConfig):
    """
    Build the underlying `OneFormerModel` together with the matcher and criterion used to compute losses.

    Args:
        config ([`OneFormerConfig`]): Configuration supplying the loss weights and sampling settings read below.
    """
    super().__init__(config)
    self.model = OneFormerModel(config)
    # The matcher receives the class / dice / mask weights as its matching costs.
    self.matcher = OneFormerHungarianMatcher(
        cost_class=config.class_weight,
        cost_dice=config.dice_weight,
        cost_mask=config.mask_weight,
        num_points=config.train_num_points,
    )
    # Scaling factors per loss family; `get_loss_dict` multiplies every loss whose name contains one of these
    # keys by the matching weight.
    self.weight_dict: Dict[str, float] = {
        "loss_cross_entropy": config.class_weight,
        "loss_mask": config.mask_weight,
        "loss_dice": config.dice_weight,
        "loss_contrastive": config.contrastive_weight,
    }
    self.criterion = OneFormerLoss(
        num_classes=config.num_labels,
        matcher=self.matcher,
        weight_dict=self.weight_dict,
        eos_coef=config.no_object_weight,
        num_points=config.train_num_points,
        oversample_ratio=config.oversample_ratio,
        importance_sample_ratio=config.importance_sample_ratio,
        contrastive_temperature=config.contrastive_temperature,
    )
    self.post_init()
def get_loss_dict(
    self,
    masks_queries_logits: Tensor,
    class_queries_logits: Tensor,
    contrastive_queries_logits: Tensor,
    mask_labels: Tensor,
    class_labels: Tensor,
    text_queries: Tensor,
    auxiliary_predictions: Dict[str, Tensor],
    calculate_contrastive_loss: bool,
) -> Dict[str, Tensor]:
    """
    Run the criterion and scale each returned loss by its configured weight.

    All arguments are forwarded unchanged, by keyword, to `self.criterion`.

    Returns:
        The dictionary produced by the criterion. Every entry whose name contains a key of `self.weight_dict`
        has been multiplied in place by the corresponding weight.
    """
    criterion_kwargs = {
        "masks_queries_logits": masks_queries_logits,
        "class_queries_logits": class_queries_logits,
        "contrastive_queries_logits": contrastive_queries_logits,
        "mask_labels": mask_labels,
        "class_labels": class_labels,
        "text_queries": text_queries,
        "auxiliary_predictions": auxiliary_predictions,
        "calculate_contrastive_loss": calculate_contrastive_loss,
    }
    losses: Dict[str, Tensor] = self.criterion(**criterion_kwargs)

    for weight_name, factor in self.weight_dict.items():
        for loss_name, loss_value in losses.items():
            if weight_name not in loss_name:
                continue
            # Substring match on the loss name. The in-place multiply is what updates the value held by
            # the dictionary (assumes the values are tensors, as annotated).
            loss_value *= factor

    return losses
def get_loss(self, loss_dict: Dict[str, Tensor]) -> Tensor:
    """Add up all values of `loss_dict`, starting from 0, and return the total."""
    total = 0
    for value in loss_dict.values():
        total = total + value
    return total
@add_start_docstrings_to_model_forward(ONEFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=OneFormerForUniversalSegmentationOutput, config_class=_CONFIG_FOR_DOC)
.\models\oneformer\processing_oneformer.py
"""
Image/Text processor class for OneFormer
"""
from typing import List
from ...processing_utils import ProcessorMixin
from ...utils import is_torch_available
if is_torch_available():
import torch
class OneFormerProcessor(ProcessorMixin):
    r"""
    Constructs an OneFormer processor which wraps [`OneFormerImageProcessor`] and
    [`CLIPTokenizer`]/[`CLIPTokenizerFast`] into a single processor that inherits both the image processor and
    tokenizer functionalities.

    Args:
        image_processor ([`OneFormerImageProcessor`]):
            The image processor is a required input.
        tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`]):
            The tokenizer is a required input.
        max_seq_length (`int`, *optional*, defaults to 77):
            Sequence length for input text list.
        task_seq_length (`int`, *optional*, defaults to 77):
            Sequence length for input task token.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "OneFormerImageProcessor"
    tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")

    def __init__(
        self, image_processor=None, tokenizer=None, max_seq_length: int = 77, task_seq_length: int = 77, **kwargs
    ):
        """
        Store the sequence lengths and register the wrapped image processor and tokenizer.

        Raises:
            ValueError: If `image_processor` or `tokenizer` is `None`.
        """
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        self.max_seq_length = max_seq_length
        self.task_seq_length = task_seq_length

        super().__init__(image_processor, tokenizer)

    def _preprocess_text(self, text_list=None, max_length=77):
        """
        Tokenize `text_list` and return one row of token ids per text, stacked along a new first dimension.

        Every text is padded or truncated to `max_length`. The ids are multiplied element-wise by the attention
        mask, so positions where the mask is 0 end up as 0.

        Raises:
            ValueError: If `text_list` is `None`.
        """
        if text_list is None:
            raise ValueError("tokens cannot be None.")

        tokens = self.tokenizer(text_list, padding="max_length", max_length=max_length, truncation=True)
        attention_masks, input_ids = tokens["attention_mask"], tokens["input_ids"]

        token_inputs = [
            torch.tensor(attn_mask) * torch.tensor(input_id)
            for attn_mask, input_id in zip(attention_masks, input_ids)
        ]
        return torch.stack(token_inputs, dim=0)

    def encode_inputs(self, images=None, task_inputs=None, segmentation_maps=None, **kwargs):
        """
        This method forwards all its arguments to [`OneFormerImageProcessor.encode_inputs`] and then tokenizes the
        task_inputs. Please refer to the docstring of this method for more information.

        `task_inputs` may be a single task name or a list of task names; each must be one of `"semantic"`,
        `"instance"` or `"panoptic"`. A single name is treated as a one-element list.

        Raises:
            ValueError: If `task_inputs` or `images` is `None`, or if a task name is not supported.
            TypeError: If `task_inputs` is neither a string nor a list of strings.
        """
        if task_inputs is None:
            raise ValueError("You have to specify the task_input. Found None.")
        elif images is None:
            raise ValueError("You have to specify the image. Found None.")

        # Normalise a lone task name BEFORE validating. Iterating a `str` yields single characters, so the
        # membership check below would otherwise reject every plain-string input.
        if isinstance(task_inputs, str):
            task_inputs = [task_inputs]

        if not all(task in ["semantic", "instance", "panoptic"] for task in task_inputs):
            raise ValueError("task_inputs must be semantic, instance, or panoptic.")

        encoded_inputs = self.image_processor.encode_inputs(images, task_inputs, segmentation_maps, **kwargs)

        if isinstance(task_inputs, list) and all(isinstance(task_input, str) for task_input in task_inputs):
            task_token_inputs = [f"the task is {task}" for task in task_inputs]
            encoded_inputs["task_inputs"] = self._preprocess_text(task_token_inputs, max_length=self.task_seq_length)
        else:
            raise TypeError("Task Inputs should be a string or a list of strings.")

        # When the image processor returned `text_inputs`, tokenize each group of texts and replace the
        # entry with the stacked result.
        if hasattr(encoded_inputs, "text_inputs"):
            texts_list = encoded_inputs.text_inputs

            text_inputs = [self._preprocess_text(texts, max_length=self.max_seq_length) for texts in texts_list]
            encoded_inputs["text_inputs"] = torch.stack(text_inputs, dim=0)

        return encoded_inputs

    def post_process_semantic_segmentation(self, *args, **kwargs):
        """
        This method forwards all its arguments to [`OneFormerImageProcessor.post_process_semantic_segmentation`].
        Please refer to the docstring of this method for more information.
        """
        return self.image_processor.post_process_semantic_segmentation(*args, **kwargs)

    def post_process_instance_segmentation(self, *args, **kwargs):
        """
        This method forwards all its arguments to [`OneFormerImageProcessor.post_process_instance_segmentation`].
        Please refer to the docstring of this method for more information.
        """
        return self.image_processor.post_process_instance_segmentation(*args, **kwargs)

    def post_process_panoptic_segmentation(self, *args, **kwargs):
        """
        This method forwards all its arguments to [`OneFormerImageProcessor.post_process_panoptic_segmentation`].
        Please refer to the docstring of this method for more information.
        """
        return self.image_processor.post_process_panoptic_segmentation(*args, **kwargs)
.\models\oneformer\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
# Maps each submodule name to the public names it provides. Entries that need optional dependencies are
# only added further down, once the dependency is confirmed to be available.
_import_structure = {
    "configuration_oneformer": ["ONEFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "OneFormerConfig"],
    "processing_oneformer": ["OneFormerProcessor"],
}
# The image processor is registered only when the vision dependencies are available.
try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["image_processing_oneformer"] = ["OneFormerImageProcessor"]
# The modeling classes are registered only when torch is available.
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_oneformer"] = [
        "ONEFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
        "OneFormerForUniversalSegmentation",
        "OneFormerModel",
        "OneFormerPreTrainedModel",
    ]
# Under static type checking the names are imported directly, mirroring the availability checks above.
if TYPE_CHECKING:
    from .configuration_oneformer import ONEFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, OneFormerConfig
    from .processing_oneformer import OneFormerProcessor
    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .image_processing_oneformer import OneFormerImageProcessor
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_oneformer import (
            ONEFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
            OneFormerForUniversalSegmentation,
            OneFormerModel,
            OneFormerPreTrainedModel,
        )
else:
    import sys
    # At runtime this module's entry in `sys.modules` is replaced by a `_LazyModule` built from
    # `_import_structure` (presumably so submodules are imported on first attribute access -- confirm
    # against `_LazyModule`).
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
.\models\openai\configuration_openai.py
""" OpenAI GPT configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"openai-community/openai-gpt": "https://huggingface.co/openai-community/openai-gpt/resolve/main/config.json"
}
class OpenAIGPTConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`OpenAIGPTModel`] or a [`TFOpenAIGPTModel`]. It is
    used to instantiate a GPT model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the GPT
    [openai-community/openai-gpt](https://huggingface.co/openai-community/openai-gpt) architecture from OpenAI.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 40478):
            Size of the vocabulary.
        n_positions (`int`, *optional*, defaults to 512):
            Number of positions; also reachable as `max_position_embeddings` through `attribute_map`.
        n_embd (`int`, *optional*, defaults to 768):
            Embedding dimension; also reachable as `hidden_size`.
        n_layer (`int`, *optional*, defaults to 12):
            Number of layers; also reachable as `num_hidden_layers`.
        n_head (`int`, *optional*, defaults to 12):
            Number of attention heads; also reachable as `num_attention_heads`.
        afn (`str`, *optional*, defaults to `"gelu"`):
            Name of the activation function.
        resid_pdrop (`float`, *optional*, defaults to 0.1):
            Dropout probability for residual connections.
        embd_pdrop (`float`, *optional*, defaults to 0.1):
            Dropout probability for the embeddings.
        attn_pdrop (`float`, *optional*, defaults to 0.1):
            Dropout probability for the attention weights.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
            Epsilon used by the layer normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation of the normal distribution used to initialize weights.
        summary_type (`str`, *optional*, defaults to `"cls_index"`):
            Setting for the sequence-summary head (see `SequenceSummary`).
        summary_use_proj (`bool`, *optional*, defaults to `True`):
            Setting for the sequence-summary head (see `SequenceSummary`).
        summary_activation (`str`, *optional*):
            Setting for the sequence-summary head (see `SequenceSummary`).
        summary_proj_to_labels (`bool`, *optional*, defaults to `True`):
            Setting for the sequence-summary head (see `SequenceSummary`).
        summary_first_dropout (`float`, *optional*, defaults to 0.1):
            Setting for the sequence-summary head (see `SequenceSummary`).

    Examples:

    ```
    >>> from transformers import OpenAIGPTConfig, OpenAIGPTModel

    >>> # Initializing a GPT configuration
    >>> configuration = OpenAIGPTConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = OpenAIGPTModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "openai-gpt"
    # Generic attribute names resolved to the GPT-specific names stored on this config.
    attribute_map = {
        "max_position_embeddings": "n_positions",
        "hidden_size": "n_embd",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size=40478,
        n_positions=512,
        n_embd=768,
        n_layer=12,
        n_head=12,
        afn="gelu",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=0.1,
        **kwargs,
    ):
        """Store every architecture argument on the instance, then pass the remaining `kwargs` to the base class."""
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.afn = afn
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.summary_type = summary_type
        self.summary_use_proj = summary_use_proj
        self.summary_activation = summary_activation
        self.summary_first_dropout = summary_first_dropout
        self.summary_proj_to_labels = summary_proj_to_labels
        # The base class is initialized exactly once, with only the arguments not consumed above.
        super().__init__(**kwargs)
.\models\openai\convert_openai_original_tf_checkpoint_to_pytorch.py
"""Convert OpenAI GPT checkpoint."""
import argparse
import torch
from transformers import OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging
logging.set_verbosity_info()
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
    """
    Convert an original OpenAI GPT checkpoint into a PyTorch weights file plus a JSON configuration file.

    Args:
        openai_checkpoint_folder_path: Location of the original checkpoint, handed to
            `load_tf_weights_in_openai_gpt`.
        openai_config_file: Path of a JSON config file; an empty string selects the default `OpenAIGPTConfig`.
        pytorch_dump_folder_path: Folder that receives the `WEIGHTS_NAME` and `CONFIG_NAME` files.
    """
    # Build the model from either the default or the user-supplied configuration.
    config = OpenAIGPTConfig() if openai_config_file == "" else OpenAIGPTConfig.from_json_file(openai_config_file)
    model = OpenAIGPTModel(config)

    # Copy the original weights into the freshly built model.
    load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path)

    # Write the state dict first, then the configuration.
    pytorch_weights_dump_path = "/".join((pytorch_dump_folder_path, WEIGHTS_NAME))
    pytorch_config_dump_path = "/".join((pytorch_dump_folder_path, CONFIG_NAME))

    print(f"Save PyTorch model to {pytorch_weights_dump_path}")
    torch.save(model.state_dict(), pytorch_weights_dump_path)

    print(f"Save configuration file to {pytorch_config_dump_path}")
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())
# Command-line entry point: parse the three path arguments and run the conversion.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--openai_checkpoint_folder_path",
        default=None,
        type=str,
        required=True,
        help="Path to the TensorFlow checkpoint path.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        required=True,
        help="Path to the output PyTorch model.",
    )
    # Optional: the empty-string default makes the converter fall back to the default `OpenAIGPTConfig`.
    parser.add_argument(
        "--openai_config_file",
        default="",
        type=str,
        help=(
            "An optional config json file corresponding to the pre-trained OpenAI model. \n"
            "This specifies the model architecture."
        ),
    )
    args = parser.parse_args()
    # The converter takes (checkpoint folder, config file, dump folder) in that order.
    convert_openai_checkpoint_to_pytorch(
        args.openai_checkpoint_folder_path, args.openai_config_file, args.pytorch_dump_folder_path
    )
.\models\openai\modeling_openai.py
import json
import math
import os
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import gelu_new, silu
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel, SequenceSummary
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_openai import OpenAIGPTConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "openai-community/openai-gpt"
_CONFIG_FOR_DOC = "OpenAIGPTConfig"
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"openai-community/openai-gpt",
]
def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
    """
    Load tf pre-trained weights in a pytorch model (from NumPy arrays here).

    Args:
        model: Model exposing `tokens_embed` and `positions_embed`, plus the sub-modules named in the checkpoint.
        config: Accepted but not used by this function.
        openai_checkpoint_folder_path: Folder holding `parameters_names.json`, `params_shapes.json` and the
            shards `params_0.npy` ... `params_9.npy`. If the string contains ".ckpt", its parent directory
            is used instead.

    Returns:
        The same `model`, with its weights replaced in place.

    Raises:
        ValueError: If an embedding or parameter shape does not match the checkpoint, or if a parameter name
            does not end with ":0".
    """
    import re
    import numpy as np
    if ".ckpt" in openai_checkpoint_folder_path:
        openai_checkpoint_folder_path = os.path.dirname(openai_checkpoint_folder_path)
    logger.info(f"Loading weights from {openai_checkpoint_folder_path}")
    with open(openai_checkpoint_folder_path + "/parameters_names.json", "r", encoding="utf-8") as names_handle:
        names = json.load(names_handle)
    with open(openai_checkpoint_folder_path + "/params_shapes.json", "r", encoding="utf-8") as shapes_handle:
        shapes = json.load(shapes_handle)
    # Cumulative element counts mark where each parameter ends in the flat, concatenated array.
    offsets = np.cumsum([np.prod(shape) for shape in shapes])
    # The checkpoint is read from exactly ten shard files.
    init_params = [np.load(openai_checkpoint_folder_path + f"/params_{n}.npy") for n in range(10)]
    # Splitting at every offset leaves one extra trailing piece after the last parameter; drop it.
    init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
    init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
    # Remove size-1 dimensions so the arrays can be compared with the PyTorch parameter shapes.
    init_params = [arr.squeeze() for arr in init_params]
    # Array 1 is loaded as the token embeddings and array 0 as the position embeddings.
    if model.tokens_embed.weight.shape != init_params[1].shape:
        raise ValueError(
            f"tokens_embed.weight.shape: {model.tokens_embed.weight.shape} does not match init_param[1].shape:"
            f" {init_params[1].shape}"
        )
    if model.positions_embed.weight.shape != init_params[0].shape:
        raise ValueError(
            f"positions_embed.weight.shape: {model.positions_embed.weight.shape} does not match init_param[0].shape:"
            f" {init_params[0].shape}"
        )
    model.tokens_embed.weight.data = torch.from_numpy(init_params[1])
    model.positions_embed.weight.data = torch.from_numpy(init_params[0])
    # NOTE(review): one name but two arrays are dropped here -- presumably the first name covers both
    # embedding tables handled above; confirm against the checkpoint's name list.
    names.pop(0)
    init_params.pop(0)
    init_params.pop(0)
    for name, array in zip(names, init_params):
        # Drop the first six characters (presumably a "model/" prefix -- confirm) and the ":0" suffix.
        name = name[6:]
        if name[-2:] != ":0":
            raise ValueError(f"Layer {name} does not end with :0")
        name = name[:-2]
        name = name.split("/")
        # Walk from the model down to the target parameter, one path component at a time.
        pointer = model
        for m_name in name:
            # Components of the form letters+digits are split into an attribute name and an index.
            if re.fullmatch(r"[A-Za-z]+\d+", m_name):
                scope_names = re.split(r"(\d+)", m_name)
            else:
                scope_names = [m_name]
            # "g" and "w" both resolve to `weight`, "b" to `bias`; anything else is a sub-module name.
            if scope_names[0] == "g":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "b":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "w":
                pointer = getattr(pointer, "weight")
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        # Refuse to copy when the destination and source shapes differ.
        if pointer.shape != array.shape:
            raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model
# Activation lookup table, indexed with `config.afn` by `MLP`; "silu" and "swish" map to the same function.
ACT_FNS = {"relu": nn.ReLU(), "silu": silu, "gelu": gelu_new, "swish": silu}
class Attention(nn.Module):
    """Multi-head self-attention that masks attention scores with a lower-triangular `bias` buffer."""

    def __init__(self, nx, n_positions, config, scale=False):
        """
        Args:
            nx: Feature size of the input; must be divisible by `config.n_head`.
            n_positions: Side length of the square lower-triangular mask buffer.
            config: Supplies `n_head`, `attn_pdrop` and `resid_pdrop`.
            scale: When true, scores are divided by the square root of the per-head value size.
        """
        super().__init__()
        n_state = nx
        if n_state % config.n_head != 0:
            raise ValueError(f"Attention n_state shape: {n_state} must be divisible by config.n_head {config.n_head}")

        # Lower-triangular matrix of ones, registered as a non-persistent buffer.
        lower_triangle = torch.tril(torch.ones(n_positions, n_positions))
        self.register_buffer("bias", lower_triangle.view(1, 1, n_positions, n_positions), persistent=False)

        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given heads from both projections and update the head bookkeeping."""
        if len(heads) == 0:
            return

        head_size = self.split_size // self.n_head
        heads, index = find_pruneable_heads_and_indices(heads, self.n_head, head_size, self.pruned_heads)
        # The same positions are selected in each of the three consecutive `split_size`-wide sections.
        qkv_index = torch.cat((index, index + self.split_size, index + (2 * self.split_size)))

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, qkv_index, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        heads_left = self.n_head - len(heads)
        self.split_size = head_size * heads_left
        self.n_head = heads_left
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
        """Return `[context]`, or `[context, attention_weights]` when `output_attentions` is true."""
        scores = torch.matmul(q, k)
        if self.scale:
            scores = scores / math.sqrt(v.size(-1))

        # Crop the mask buffer to the score matrix; positions where it is 0 are set to -1e4.
        mask = self.bias[:, :, : scores.size(-2), : scores.size(-1)]
        scores = scores * mask + -1e4 * (1 - mask)

        if attention_mask is not None:
            # Apply the attention mask
            scores = scores + attention_mask

        weights = self.attn_dropout(nn.functional.softmax(scores, dim=-1))

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)
        return [context, weights] if output_attentions else [context]

    def merge_heads(self, x):
        """Swap dimensions 1 and 2, then flatten the last two dimensions into one."""
        swapped = x.permute(0, 2, 1, 3).contiguous()
        *leading, dim_a, dim_b = swapped.size()
        return swapped.view(*leading, dim_a * dim_b)

    def split_heads(self, x, k=False):
        """Split the last dimension into `n_head` parts and move the head dimension to position 1."""
        head_size = x.size(-1) // self.n_head
        x = x.view(*x.size()[:-1], self.n_head, head_size)
        # With `k=True` the last two dimensions come out swapped relative to the default layout.
        return x.permute(0, 2, 3, 1) if k else x.permute(0, 2, 1, 3)

    def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False):
        """Project `x`, attend, and return `[output]` plus the attention weights when requested."""
        projected = self.c_attn(x)
        query, key, value = projected.split(self.split_size, dim=2)

        attn_outputs = self._attn(
            self.split_heads(query),
            self.split_heads(key, k=True),
            self.split_heads(value),
            attention_mask,
            head_mask,
            output_attentions,
        )

        a = self.resid_dropout(self.c_proj(self.merge_heads(attn_outputs[0])))
        return [a] + attn_outputs[1:]
class MLP(nn.Module):
    """Feed-forward sub-layer: projection, activation, projection back, then dropout."""

    def __init__(self, n_state, config):
        """
        Args:
            n_state: Size of the first projection's output.
            config: Supplies `n_embd`, the activation name `afn` and `resid_pdrop`.
        """
        super().__init__()
        embed_dim = config.n_embd
        self.c_fc = Conv1D(n_state, embed_dim)
        self.c_proj = Conv1D(embed_dim, n_state)
        self.act = ACT_FNS[config.afn]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        """Apply `c_fc`, the activation, `c_proj` and dropout, in that order."""
        return self.dropout(self.c_proj(self.act(self.c_fc(x))))
class Block(nn.Module):
    """One transformer layer: attention and MLP, each followed by a residual add and a layer norm."""

    def __init__(self, n_positions, config, scale=False):
        """
        Args:
            n_positions: Passed to `Attention`.
            config: Supplies `n_embd` and `layer_norm_epsilon`.
            scale: Passed to `Attention`.
        """
        super().__init__()
        width = config.n_embd
        self.attn = Attention(width, n_positions, config, scale)
        self.ln_1 = nn.LayerNorm(width, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * width, config)
        self.ln_2 = nn.LayerNorm(width, eps=config.layer_norm_epsilon)

    def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False):
        """Return `[hidden_states]` followed by whatever extra outputs the attention layer produced."""
        attn_outputs = self.attn(
            x,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )

        # Each layer norm is applied after its residual sum.
        normed = self.ln_1(x + attn_outputs[0])
        h = self.ln_2(normed + self.mlp(normed))

        return [h] + attn_outputs[1:]
class OpenAIGPTPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = OpenAIGPTConfig
    load_tf_weights = load_tf_weights_in_openai_gpt
    base_model_prefix = "transformer"

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, Conv1D)):
            # Normal-distributed weights, zero bias when a bias exists.
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            # Normal-distributed weights, with the padding row (if any) reset to zero.
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            # Zero bias and unit weight.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
@dataclass
class OpenAIGPTDoubleHeadsModelOutput(ModelOutput):
    """
    Base class for outputs of models predicting if two sentences are consecutive or not.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss.
        mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
            Multiple choice classification loss.
        logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # The field documentation lives in the class docstring above, so it is part of `__doc__` for the
    # docstring decorators that are given this class as `output_type`.
    loss: Optional[torch.FloatTensor] = None
    mc_loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    mc_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
OPENAI_GPT_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`OpenAIGPTConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Shared argument documentation that `add_start_docstrings_to_model_forward` attaches to the OpenAI GPT
# `forward` methods.
OPENAI_GPT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
"""
OpenAI GPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
embeddings).
"""
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.transformer = OpenAIGPTModel(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.post_init()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], CausalLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutput(
loss=loss,
logits=lm_logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
return {"input_ids": input_ids}
@add_start_docstrings(
"""
OpenAI GPT Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
input embeddings, the classification head takes as input the input of a specified classification token index in the
input sequence).
""",
OPENAI_GPT_START_DOCSTRING,
)
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
    """Build the transformer body, the language modeling head and the multiple-choice head."""
    super().__init__(config)
    # NOTE: this overwrites `num_labels` on the config object that was passed in, before the heads are built.
    config.num_labels = 1
    self.transformer = OpenAIGPTModel(config)
    # Bias-free projection from `n_embd` to `vocab_size`.
    self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
    self.multiple_choice_head = SequenceSummary(config)
    self.post_init()
def get_output_embeddings(self):
    """Return the language modeling head."""
    return self.lm_head
def set_output_embeddings(self, new_embeddings):
    """Replace the language modeling head with `new_embeddings`."""
    self.lm_head = new_embeddings
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=OpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
mc_token_ids: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
mc_labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
@add_start_docstrings(
"""
The Original OpenAI GPT Model transformer with a sequence classification head on top (linear layer).
[`OpenAIGPTForSequenceClassification`] uses the last token in order to do the classification, as other causal
models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the
last token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding
token in each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since
it cannot guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take
the last value in each row of the batch).
""",
OPENAI_GPT_START_DOCSTRING,
)
class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
def __init__(self, config):
    """
    Build the transformer body and a bias-free linear classification head with `config.num_labels` outputs.

    Args:
        config: Model configuration supplying `num_labels` and `n_embd`.
    """
    super().__init__(config)
    self.num_labels = config.num_labels
    self.transformer = OpenAIGPTModel(config)
    self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)

    # Initialize weights and apply final processing, as the other model constructors in this file do.
    self.post_init()
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,