Transformers Source Code Analysis (119)
.\models\vit_hybrid\convert_vit_hybrid_timm_to_pytorch.py
"""Convert ViT hybrid checkpoints from the timm library."""
import argparse
import json
from pathlib import Path
import requests
import timm
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from transformers import (
BitConfig,
ViTHybridConfig,
ViTHybridForImageClassification,
ViTHybridImageProcessor,
ViTHybridModel,
)
from transformers.image_utils import PILImageResampling
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def create_rename_keys(config, base_model=False):
rename_keys = []
rename_keys.append(("cls_token", "vit.embeddings.cls_token"))
rename_keys.append(("pos_embed", "vit.embeddings.position_embeddings"))
rename_keys.append(("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"))
rename_keys.append(("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"))
rename_keys.append(("patch_embed.backbone.stem.conv.weight", "vit.embeddings.patch_embeddings.backbone.bit.embedder.convolution.weight"))
rename_keys.append(("patch_embed.backbone.stem.norm.weight", "vit.embeddings.patch_embeddings.backbone.bit.embedder.norm.weight"))
rename_keys.append(("patch_embed.backbone.stem.norm.bias", "vit.embeddings.patch_embeddings.backbone.bit.embedder.norm.bias"))
for stage_idx in range(len(config.backbone_config.depths)):
for layer_idx in range(config.backbone_config.depths[stage_idx]):
rename_keys.append((
f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.conv1.weight",
f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.conv1.weight"))
rename_keys.append((
f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.norm1.weight",
f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.norm1.weight"))
rename_keys.append((
f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.norm1.bias",
f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.norm1.bias"))
rename_keys.append((
f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.conv2.weight",
f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.conv2.weight"))
rename_keys.append((
f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.norm2.weight",
f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.norm2.weight"))
rename_keys.append((
f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.norm2.bias",
f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.norm2.bias"))
rename_keys.append((
f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.conv3.weight",
f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.conv3.weight"))
rename_keys.append((
f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.norm3.weight",
f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.norm3.weight"))
rename_keys.append((
f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.norm3.bias",
f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.norm3.bias"))
rename_keys.append((
f"patch_embed.backbone.stages.{stage_idx}.blocks.0.downsample.conv.weight",
f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.0.downsample.conv.weight"))
rename_keys.append((
f"patch_embed.backbone.stages.{stage_idx}.blocks.0.downsample.norm.weight",
f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.0.downsample.norm.weight"))
rename_keys.append((
f"patch_embed.backbone.stages.{stage_idx}.blocks.0.downsample.norm.bias",
f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.0.downsample.norm.bias"))
for i in range(config.num_hidden_layers):
rename_keys.append((f"blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight"))
rename_keys.append((f"blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias"))
rename_keys.append((f"blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight"))
rename_keys.append((f"blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias"))
rename_keys.append((f"blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight"))
rename_keys.append((f"blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias"))
rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight"))
rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias"))
rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight"))
rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias"))
if base_model:
rename_keys.extend(
[
("norm.weight", "layernorm.weight"),
("norm.bias", "layernorm.bias"),
("pre_logits.fc.weight", "pooler.dense.weight"),
("pre_logits.fc.bias", "pooler.dense.bias"),
]
)
rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys]
else:
rename_keys.extend(
[
("norm.weight", "vit.layernorm.weight"),
("norm.bias", "vit.layernorm.bias"),
("head.weight", "classifier.weight"),
("head.bias", "classifier.bias"),
]
)
return rename_keys
def read_in_q_k_v(state_dict, config, base_model=False):
for i in range(config.num_hidden_layers):
if base_model:
prefix = ""
else:
prefix = "vit."
in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
: config.hidden_size, :
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
config.hidden_size : config.hidden_size * 2, :
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
config.hidden_size : config.hidden_size * 2
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
-config.hidden_size :, :
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
def remove_classification_head_(state_dict):
ignore_keys = ["head.weight", "head.bias"]
for k in ignore_keys:
state_dict.pop(k, None)
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path, push_to_hub=False):
"""
Copy/paste/tweak the model's weights into our ViT Hybrid structure.
"""
backbone_config = BitConfig(
global_padding="same",
layer_type="bottleneck",
depths=(3, 4, 9),
out_features=["stage3"],
embedding_dynamic_padding=True,
)
config = ViTHybridConfig(backbone_config=backbone_config, image_size=384, num_labels=1000)
base_model = False
timm_model = timm.create_model(vit_name, pretrained=True)
timm_model.eval()
state_dict = timm_model.state_dict()
if base_model:
remove_classification_head_(state_dict)
rename_keys = create_rename_keys(config, base_model)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
read_in_q_k_v(state_dict, config, base_model)
repo_id = "huggingface/label-files"
filename = "imagenet-1k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
if vit_name[-5:] == "in21k":
model = ViTHybridModel(config).eval()
else:
model = ViTHybridForImageClassification(config).eval()
model.load_state_dict(state_dict)
transform = create_transform(**resolve_data_config({}, model=timm_model))
timm_transforms = transform.transforms
pillow_resamplings = {
"bilinear": PILImageResampling.BILINEAR,
"bicubic": PILImageResampling.BICUBIC,
"nearest": PILImageResampling.NEAREST,
}
processor = ViTHybridImageProcessor(
do_resize=True,
size={"shortest_edge": timm_transforms[0].size},
resample=pillow_resamplings[timm_transforms[0].interpolation.value],
do_center_crop=True,
crop_size={"height": timm_transforms[1].size[0], "width": timm_transforms[1].size[1]},
do_normalize=True,
image_mean=timm_transforms[-1].mean.tolist(),
image_std=timm_transforms[-1].std.tolist(),
)
image = prepare_img()
timm_pixel_values = transform(image).unsqueeze(0)
pixel_values = processor(image, return_tensors="pt").pixel_values
assert torch.allclose(timm_pixel_values, pixel_values)
with torch.no_grad():
outputs = model(pixel_values)
logits = outputs.logits
print("Predicted class:", logits.argmax(-1).item())
if base_model:
timm_pooled_output = timm_model.forward_features(pixel_values)
assert timm_pooled_output.shape == outputs.pooler_output.shape
assert torch.allclose(timm_pooled_output, outputs.pooler_output, atol=1e-3)
else:
timm_logits = timm_model(pixel_values)
assert timm_logits.shape == outputs.logits.shape
assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)
print("Looks ok!")
if pytorch_dump_folder_path is not None:
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model {vit_name} to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"Saving processor to {pytorch_dump_folder_path}")
processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
print(f"Pushing model and processor to the hub {vit_name}")
model.push_to_hub(f"ybelkada/{vit_name}")
processor.push_to_hub(f"ybelkada/{vit_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--vit_name",
default="vit_base_r50_s16_384",
type=str,
help="Name of the hybrid ViT timm model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--push_to_hub", action="store_true", help="Whether to upload the model to the HuggingFace hub."
)
args = parser.parse_args()
convert_vit_checkpoint(args.vit_name, args.pytorch_dump_folder_path, args.push_to_hub)
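For reference, a hedged sketch of calling the conversion function directly instead of via the CLI (the dump folder path is a placeholder):
# Illustrative only: converts the default timm checkpoint and saves it locally.
convert_vit_checkpoint(
    vit_name="vit_base_r50_s16_384",
    pytorch_dump_folder_path="./vit-hybrid-base-bit-384",  # placeholder path
    push_to_hub=False,
)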
.\models\vit_hybrid\image_processing_vit_hybrid.py
r"""
Constructs a ViT Hybrid image processor.
"""
def __init__(self, image_size: Union[int, List[int]], mean: Optional[List[float]] = OPENAI_CLIP_MEAN,
std: Optional[List[float]] = OPENAI_CLIP_STD, resampling: PILImageResampling = PILImageResampling.BILINEAR,
channel_dim: ChannelDimension = ChannelDimension.LAST, dtype: np.dtype = np.float32):
"""
Initialize the ViT Hybrid image processor.
Parameters:
- image_size (Union[int, List[int]]): Desired size of the output image.
- mean (Optional[List[float]]): Mean values for normalization, defaults to OPENAI_CLIP_MEAN.
- std (Optional[List[float]]): Standard deviation values for normalization, defaults to OPENAI_CLIP_STD.
- resampling (PILImageResampling): Resampling method for image resizing, defaults to PILImageResampling.BILINEAR.
- channel_dim (ChannelDimension): Channel dimension format, defaults to ChannelDimension.LAST.
- dtype (TensorType): Data type of the processed images, defaults to np.float32.
"""
super().__init__()
self.image_size = image_size
self.mean = mean
self.std = std
self.resampling = resampling
self.channel_dim = channel_dim
self.dtype = dtype
def __call__(self, images: Union[ImageInput, List[ImageInput]], return_tensors: bool = True,
**kwargs) -> Union[BatchFeature, List[BatchFeature]]:
"""
Process a single image or a batch of images.
Parameters:
- images (Union[ImageInput, List[ImageInput]]): Input image(s) to be processed.
- return_tensors (bool): Whether to return tensors (True) or numpy arrays (False), defaults to True.
- **kwargs: Additional keyword arguments for preprocessing.
Returns:
- Union[BatchFeature, List[BatchFeature]]: Processed image(s) as tensors or numpy arrays.
"""
images = make_list_of_images(images)
validate_kwargs(kwargs)
validate_preprocess_arguments(self.mean, self.std, self.image_size, self.channel_dim)
resized_images = [resize(image, self.image_size, self.resampling) for image in images]
rgb_images = [convert_to_rgb(image) for image in resized_images]
formatted_images = [to_channel_dimension_format(image, self.channel_dim) for image in rgb_images]
normalized_images = [self._normalize_image(image) for image in formatted_images]
if return_tensors:
return np.stack(normalized_images).astype(self.dtype)
else:
return normalized_images
def _normalize_image(self, image: np.ndarray) -> np.ndarray:
"""
Normalize the image data using mean and standard deviation.
Parameters:
- image (np.ndarray): Input image data.
Returns:
- np.ndarray: Normalized image data.
"""
mean = np.array(self.mean)
std = np.array(self.std)
return (image.astype(np.float32) - mean) / std
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
`do_resize` in the `preprocess` method.
size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
`preprocess` method.
crop_size (`Dict[str, int]` *optional*, defaults to 224):
Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
"""
# Names of the model inputs; a single entry, "pixel_values"
model_input_names = ["pixel_values"]
# Constructor that sets up the various parameters of the image processor
def __init__(
self,
do_resize: bool = True,  # whether to resize the image's (height, width) to the given size
size: Dict[str, int] = None,  # size after resizing; defaults to a shortest edge of 224
resample: PILImageResampling = PILImageResampling.BICUBIC,  # resampling filter used when resizing, bicubic by default
do_center_crop: bool = True,  # whether to center crop the image
crop_size: Dict[str, int] = None,  # size after center cropping; defaults to 224x224
do_rescale: bool = True,  # whether to rescale pixel values
rescale_factor: Union[int, float] = 1 / 255,  # rescale factor, 1/255 by default
do_normalize: bool = True,  # whether to normalize the image
image_mean: Optional[Union[float, List[float]]] = None,  # normalization mean; defaults to the OpenAI CLIP mean
image_std: Optional[Union[float, List[float]]] = None,  # normalization std; defaults to the OpenAI CLIP std
do_convert_rgb: bool = True,  # whether to convert the image to RGB
**kwargs,  # additional keyword arguments
) -> None:
# Call the parent constructor, forwarding any extra keyword arguments
super().__init__(**kwargs)
# If size is None, fall back to a default dict with a shortest edge of 224
size = size if size is not None else {"shortest_edge": 224}
# Normalize the size dict; do not default to a square
size = get_size_dict(size, default_to_square=False)
# If crop_size is None, fall back to a default crop of 224x224
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
# Normalize the crop_size dict; defaults to a square
crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
# Store the parameters as instance attributes
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
self.do_convert_rgb = do_convert_rgb
# Keyword arguments accepted by the processor
self._valid_processor_keys = [
"images",
"do_resize",
"size",
"resample",
"do_center_crop",
"crop_size",
"do_rescale",
"rescale_factor",
"do_normalize",
"image_mean",
"image_std",
"do_convert_rgb",
"return_tensors",
"data_format",
"input_data_format",
]
# Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize
def resize(
self,
image: np.ndarray,  # image to resize, as a NumPy array
size: Dict[str, int],  # target size dict (shortest_edge, or height and width)
resample: PILImageResampling = PILImageResampling.BICUBIC,  # resampling filter, bicubic by default
data_format: Optional[Union[str, ChannelDimension]] = None,  # output channel dimension format (optional)
input_data_format: Optional[Union[str, ChannelDimension]] = None,  # input channel dimension format (optional)
**kwargs,  # additional keyword arguments
) -> np.ndarray:
"""
Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
resized to keep the input aspect ratio.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
# By default, resize the image to a square
default_to_square = True
if "shortest_edge" in size:
# If a shortest edge is given, resize so that the shortest edge matches it
size = size["shortest_edge"]
default_to_square = False
elif "height" in size and "width" in size:
# If height and width are given, resize to exactly those dimensions
size = (size["height"], size["width"])
else:
# Neither a shortest edge nor height/width was provided
raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
# Compute the output size after resizing
output_size = get_resize_output_image_size(
image,
size=size,
default_to_square=default_to_square,
input_data_format=input_data_format,
)
# Resize the image and return the result
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
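A short, hedged usage sketch of this image processor, assuming network access to the Hub and using the checkpoint referenced in the modeling file below:
import requests
from PIL import Image
from transformers import ViTHybridImageProcessor

processor = ViTHybridImageProcessor.from_pretrained("google/vit-hybrid-base-bit-384")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
# Resize, center-crop, rescale by 1/255 and normalize with the CLIP mean/std,
# as configured in __init__ above.
inputs = processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)  # expected to be (1, 3, 384, 384) for this checkpoint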
.\models\vit_hybrid\modeling_vit_hybrid.py
""" PyTorch ViT Hybrid model. """
import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from ...utils.backbone_utils import load_backbone
from .configuration_vit_hybrid import ViTHybridConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "ViTHybridConfig"
_CHECKPOINT_FOR_DOC = "google/vit-hybrid-base-bit-384"
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
_IMAGE_CLASS_CHECKPOINT = "google/vit-hybrid-base-bit-384"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
VIT_HYBRID_PRETRAINED_MODEL_ARCHIVE_LIST = [
"google/vit-hybrid-base-bit-384",
]
class ViTHybridEmbeddings(nn.Module):
"""
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
"""
def __init__(self, config: ViTHybridConfig, use_mask_token: bool = False) -> None:
super().__init__()
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
self.patch_embeddings = ViTHybridPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.config = config
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
height = height // self.config.patch_size
width = width // self.config.patch_size
height, width = height + 0.1, width + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
raise ValueError(f"Invalid height or width: {height}, {width}")
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
if bool_masked_pos is not None:
seq_length = embeddings.shape[1]
mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
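A toy sketch (not part of the model) of the idea behind interpolate_pos_encoding above: reshape the patch position embeddings to a square grid, interpolate bicubically, and flatten back.
import math
import torch
from torch import nn

num_positions, dim = 16, 8  # a 4x4 grid of toy position embeddings
patch_pos_embed = torch.randn(1, num_positions, dim)
grid = int(math.sqrt(num_positions))
grid_embed = patch_pos_embed.reshape(1, grid, grid, dim).permute(0, 3, 1, 2)
resized = nn.functional.interpolate(grid_embed, scale_factor=1.5, mode="bicubic", align_corners=False)
flattened = resized.permute(0, 2, 3, 1).reshape(1, -1, dim)
print(flattened.shape)  # torch.Size([1, 36, 8]) -> a 6x6 grid after interpolation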
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config, feature_size=None):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
self.backbone = load_backbone(config)
if self.backbone.config.model_type != "bit":
raise ValueError(f"Backbone model type {self.backbone.model_type} is not supported.")
feature_dim = self.backbone.channels[-1]
if feature_size is None:
feature_map = config.backbone_featmap_shape
feature_size = feature_map[-2:]
feature_dim = feature_map[1]
else:
feature_size = (
feature_size if isinstance(feature_size, collections.abc.Iterable) else (feature_size, feature_size)
)
feature_dim = self.backbone.channels[-1]
self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.projection = nn.Conv2d(feature_dim, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
_, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
if not interpolate_pos_encoding:
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
features = self.backbone(pixel_values).feature_maps[-1]
embeddings = self.projection(features).flatten(2).transpose(1, 2)
return embeddings
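In shape terms, a toy sketch of the projection above (all sizes illustrative; the real values come from the BiT backbone and the config, and the patch size on top of the backbone features is assumed to be 1 here):
import torch
from torch import nn

feature_dim, hidden_size = 1024, 768  # illustrative values
features = torch.randn(2, feature_dim, 24, 24)  # stand-in for the backbone feature map
projection = nn.Conv2d(feature_dim, hidden_size, kernel_size=1, stride=1)
embeddings = projection(features).flatten(2).transpose(1, 2)
print(embeddings.shape)  # torch.Size([2, 576, 768]) -> one token per spatial position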
class ViTHybridSelfAttention(nn.Module):
def __init__(self, config: ViTHybridConfig) -> None:
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
f"heads {config.num_attention_heads}."
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
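A toy sketch of the reshapes performed by transpose_for_scores and the attention computation above (sizes illustrative):
import torch

batch, seq, heads, head_dim = 2, 5, 4, 3
hidden = torch.randn(batch, seq, heads * head_dim)
# (batch, seq, all_head_size) -> (batch, heads, seq, head_dim)
per_head = hidden.view(batch, seq, heads, head_dim).permute(0, 2, 1, 3)
scores = per_head @ per_head.transpose(-1, -2) / head_dim**0.5  # (batch, heads, seq, seq)
probs = scores.softmax(dim=-1)
context = (probs @ per_head).permute(0, 2, 1, 3).reshape(batch, seq, heads * head_dim)
print(context.shape)  # torch.Size([2, 5, 12])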
class ViTHybridSelfOutput(nn.Module):
"""
The residual connection is defined in ViTHybridLayer instead of here (as is the case with other models), due to the layernorm applied before each block.
"""
def __init__(self, config: ViTHybridConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class ViTHybridAttention(nn.Module):
def __init__(self, config: ViTHybridConfig) -> None:
super().__init__()
self.attention = ViTHybridSelfAttention(config)
self.output = ViTHybridSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads: Set[int]) -> None:
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
)
self.attention.query = prune_linear_layer(self.attention.query, index)
self.attention.key = prune_linear_layer(self.attention.key, index)
self.attention.value = prune_linear_layer(self.attention.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
class ViTHybridIntermediate(nn.Module):
def __init__(self, config: ViTHybridConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class ViTHybridOutput(nn.Module):
def __init__(self, config: ViTHybridConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states + input_tensor
return hidden_states
class ViTHybridLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config: ViTHybridConfig) -> None:
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = ViTHybridAttention(config)
self.intermediate = ViTHybridIntermediate(config)
self.output = ViTHybridOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states),
head_mask,
output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:]
hidden_states = attention_output + hidden_states.to(attention_output.device)
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = self.output(layer_output, hidden_states)
outputs = (layer_output,) + outputs
return outputs
class ViTHybridEncoder(nn.Module):
def __init__(self, config: ViTHybridConfig) -> None:
super().__init__()
self.config = config
self.layer = nn.ModuleList([ViTHybridLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values of the input image. This tensor represents the image in the form of batches,
channels, height, and width.
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask indicating which heads of the self-attention mechanism to mask out. It can be provided as
a 1D tensor for a single layer model or a 2D tensor for multi-layer models. Values are in the
range [0, 1]:
- 1 indicates that the head is **not masked**,
- 0 indicates that the head is **masked**.
output_attentions (`bool`, *optional*):
Whether or not to include the attention tensors from all attention layers in the output. Refer to
the returned tensors for details on the `attentions` field.
output_hidden_states (`bool`, *optional*):
Whether or not to include the hidden states from all layers in the output. Refer to the returned
tensors for details on the `hidden_states` field.
return_dict (`bool`, *optional*):
Whether to return a [`~utils.ModelOutput`] instead of a tuple. If True, the output will be wrapped
in a standardized model output format for ease of use and consistency.
"""
@add_start_docstrings(
"The bare ViT Hybrid Model transformer outputting raw hidden-states without any specific head on top.",
VIT_START_DOCSTRING,
)
# Copied from transformers.models.vit.modeling_vit.ViTModel with ViT->ViTHybrid
class ViTHybridModel(ViTHybridPreTrainedModel):
def __init__(self, config: ViTHybridConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
super().__init__(config)
self.config = config
# Build the embeddings, optionally with a mask token
self.embeddings = ViTHybridEmbeddings(config, use_mask_token=use_mask_token)
# Build the Transformer encoder
self.encoder = ViTHybridEncoder(config)
# Final LayerNorm, using hidden_size and layer_norm_eps from the config
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# Optional pooler, only created when add_pooling_layer is True
self.pooler = ViTHybridPooler(config) if add_pooling_layer else None
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> ViTHybridPatchEmbeddings:
# Return the patch embedding module
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
# Iterate over the layers and the heads to prune in each of them
for layer, heads in heads_to_prune.items():
# Delegate the pruning to the attention module of the corresponding encoder layer
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
"""
# Fall back to the config default when output_attentions is None
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# Fall back to the config default when output_hidden_states is None
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# Fall back to the config default when return_dict is None
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# pixel_values is required
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
# Prepare the head mask if needed
# 1.0 in head_mask means we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
# TODO: maybe have a cleaner way to cast the input (from ImageProcessor side?)
# Cast pixel_values to the dtype expected by the patch projection if needed
expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
if pixel_values.dtype != expected_dtype:
pixel_values = pixel_values.to(expected_dtype)
# Run the pixel values through the embedding layer
embedding_output = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)
# Feed the embedding output to the encoder
encoder_outputs = self.encoder(
embedding_output,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# The sequence output of the encoder
sequence_output = encoder_outputs[0]
# Apply the final layer norm
sequence_output = self.layernorm(sequence_output)
# Pool the sequence output if a pooler is configured
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
# When return_dict is False, return a plain tuple of outputs
if not return_dict:
head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
return head_outputs + encoder_outputs[1:]
# Otherwise return a BaseModelOutputWithPooling including the pooled output
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
# Copied from transformers.models.vit.modeling_vit.ViTPooler with ViT->ViTHybrid
class ViTHybridPooler(nn.Module):
def __init__(self, config: ViTHybridConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)  # linear layer with equal input/output size
self.activation = nn.Tanh()  # tanh activation
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state of the first token
first_token_tensor = hidden_states[:, 0]  # hidden state of the first (CLS) token
pooled_output = self.dense(first_token_tensor)  # project it through the linear layer
pooled_output = self.activation(pooled_output)  # apply the tanh activation
return pooled_output
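A toy sketch of the CLS-token pooling implemented above (shapes illustrative):
import torch
from torch import nn

hidden_states = torch.randn(2, 197, 768)  # (batch, 1 + num_patches, hidden_size)
dense, activation = nn.Linear(768, 768), nn.Tanh()
pooled = activation(dense(hidden_states[:, 0]))  # keep only the first (CLS) token
print(pooled.shape)  # torch.Size([2, 768])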
@add_start_docstrings(
"""
ViT Hybrid Model transformer with an image classification head on top (a linear layer on top of the final hidden
state of the [CLS] token) e.g. for ImageNet.
""",
VIT_START_DOCSTRING,
)
# Copied from transformers.models.vit.modeling_vit.ViTForImageClassification with ViT->ViTHybrid
class ViTHybridForImageClassification(ViTHybridPreTrainedModel):
def __init__(self, config: ViTHybridConfig) -> None:
super().__init__(config)
self.num_labels = config.num_labels  # number of classification labels
self.vit = ViTHybridModel(config, add_pooling_layer=False)  # ViTHybridModel backbone without a pooling layer
# Classifier head: a linear layer when there are labels, otherwise an identity mapping
self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=ImageClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, ImageClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
# Fall back to the config default when return_dict is not provided
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Forward pass through the ViT Hybrid backbone
outputs = self.vit(
pixel_values,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)
# Sequence output from the backbone
sequence_output = outputs[0]
# Classify using the hidden state of the first (CLS) token
logits = self.classifier(sequence_output[:, 0, :])
# Loss is None unless labels are provided
loss = None
if labels is not None:
# Move labels to the same device as the logits to enable model parallelism
labels = labels.to(logits.device)
# Determine the problem type from the config, or infer it from num_labels and the label dtype
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
# Compute the loss according to the problem type
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
# When return_dict is False, return a plain tuple of outputs
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
# Otherwise return an ImageClassifierOutput with loss, logits, hidden states and attentions
return ImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
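A hedged end-to-end usage sketch for the classification head above (requires network access; the checkpoint and expected label come from the doc constants near the top of this file):
import requests
import torch
from PIL import Image
from transformers import ViTHybridForImageClassification, ViTHybridImageProcessor

processor = ViTHybridImageProcessor.from_pretrained("google/vit-hybrid-base-bit-384")
model = ViTHybridForImageClassification.from_pretrained("google/vit-hybrid-base-bit-384")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

with torch.no_grad():
    logits = model(**processor(images=image, return_tensors="pt")).logits
predicted = logits.argmax(-1).item()
print(model.config.id2label[predicted])  # expected: "tabby, tabby cat"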
.\models\vit_hybrid\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {"configuration_vit_hybrid": ["VIT_HYBRID_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTHybridConfig"]}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_vit_hybrid"] = [
"VIT_HYBRID_PRETRAINED_MODEL_ARCHIVE_LIST",
"ViTHybridForImageClassification",
"ViTHybridModel",
"ViTHybridPreTrainedModel",
]
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_vit_hybrid"] = ["ViTHybridImageProcessor"]
if TYPE_CHECKING:
from .configuration_vit_hybrid import VIT_HYBRID_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTHybridConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_vit_hybrid import (
VIT_HYBRID_PRETRAINED_MODEL_ARCHIVE_LIST,
ViTHybridForImageClassification,
ViTHybridModel,
ViTHybridPreTrainedModel,
)
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_vit_hybrid import ViTHybridImageProcessor
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
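A minimal sketch of how this lazy module behaves (attribute access triggers the real submodule import; assumes transformers is installed with torch available):
import transformers.models.vit_hybrid as vit_hybrid

# The submodule is only imported when one of its attributes is first accessed.
config_cls = vit_hybrid.ViTHybridConfig
print(config_cls.__name__)  # "ViTHybridConfig"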
.\models\vit_mae\configuration_vit_mae.py
""" ViT MAE model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/vit-mae-base": "https://huggingface.co/facebook/vit-mae-base/resolve/main/config.json",
}
class ViTMAEConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`ViTMAEModel`]. It is used to instantiate an ViT
MAE model according to the specified arguments, defining the model architecture. Instantiating a configuration with
the defaults will yield a similar configuration to that of the ViT
[facebook/vit-mae-base](https://huggingface.co/facebook/vit-mae-base) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
"""
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` are supported.
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
qkv_bias (`bool`, *optional*, defaults to `True`):
Whether to add a bias to the queries, keys and values.
decoder_num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the decoder.
decoder_hidden_size (`int`, *optional*, defaults to 512):
Dimensionality of the decoder.
decoder_num_hidden_layers (`int`, *optional*, defaults to 8):
Number of hidden layers in the decoder.
decoder_intermediate_size (`int`, *optional*, defaults to 2048):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the decoder.
mask_ratio (`float`, *optional*, defaults to 0.75):
The ratio of the number of masked tokens in the input sequence.
norm_pix_loss (`bool`, *optional*, defaults to `False`):
Whether or not to train with normalized pixels (see Table 3 in the paper). Using normalized pixels improved
representation quality in the experiments of the authors.
>>> # Initializing a ViT MAE vit-mae-base style configuration
>>> configuration = ViTMAEConfig()
>>> # Initializing a model (with random weights) from the vit-mae-base style configuration
>>> model = ViTMAEModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
.\models\vit_mae\convert_vit_mae_to_pytorch.py
import argparse
import requests
import torch
from PIL import Image
from transformers import ViTMAEConfig, ViTMAEForPreTraining, ViTMAEImageProcessor
def rename_key(name):
if "cls_token" in name:
name = name.replace("cls_token", "vit.embeddings.cls_token")
if "mask_token" in name:
name = name.replace("mask_token", "decoder.mask_token")
if "decoder_pos_embed" in name:
name = name.replace("decoder_pos_embed", "decoder.decoder_pos_embed")
if "pos_embed" in name and "decoder" not in name:
name = name.replace("pos_embed", "vit.embeddings.position_embeddings")
if "patch_embed.proj" in name:
name = name.replace("patch_embed.proj", "vit.embeddings.patch_embeddings.projection")
if "patch_embed.norm" in name:
name = name.replace("patch_embed.norm", "vit.embeddings.norm")
if "decoder_blocks" in name:
name = name.replace("decoder_blocks", "decoder.decoder_layers")
if "blocks" in name:
name = name.replace("blocks", "vit.encoder.layer")
if "attn.proj" in name:
name = name.replace("attn.proj", "attention.output.dense")
if "attn" in name:
name = name.replace("attn", "attention.self")
if "norm1" in name:
name = name.replace("norm1", "layernorm_before")
if "norm2" in name:
name = name.replace("norm2", "layernorm_after")
if "mlp.fc1" in name:
name = name.replace("mlp.fc1", "intermediate.dense")
if "mlp.fc2" in name:
name = name.replace("mlp.fc2", "output.dense")
if "decoder_embed" in name:
name = name.replace("decoder_embed", "decoder.decoder_embed")
if "decoder_norm" in name:
name = name.replace("decoder_norm", "decoder.decoder_norm")
if "decoder_pred" in name:
name = name.replace("decoder_pred", "decoder.decoder_pred")
if "norm.weight" in name and "decoder" not in name:
name = name.replace("norm.weight", "vit.layernorm.weight")
if "norm.bias" in name and "decoder" not in name:
name = name.replace("norm.bias", "vit.layernorm.bias")
return name
def convert_state_dict(orig_state_dict, config):
for key in orig_state_dict.copy().keys():
val = orig_state_dict.pop(key)
if "qkv" in key:
key_split = key.split(".")
layer_num = int(key_split[1])
if "decoder_blocks" in key:
dim = config.decoder_hidden_size
prefix = "decoder.decoder_layers."
if "weight" in key:
orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.weight"] = val[:dim, :]
orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.weight"] = val[dim : dim * 2, :]
orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.weight"] = val[-dim:, :]
elif "bias" in key:
orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.bias"] = val[:dim]
orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.bias"] = val[dim : dim * 2]
orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.bias"] = val[-dim:]
else:
dim = config.hidden_size
prefix = "vit.encoder.layer."
if "weight" in key:
orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.weight"] = val[:dim, :]
orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.weight"] = val[dim : dim * 2, :]
orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.weight"] = val[-dim:, :]
elif "bias" in key:
orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.bias"] = val[:dim]
orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.bias"] = val[dim : dim * 2]
orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.bias"] = val[-dim:]
else:
orig_state_dict[rename_key(key)] = val
return orig_state_dict
def convert_vit_mae_checkpoint(checkpoint_url, pytorch_dump_folder_path):
config = ViTMAEConfig()
if "large" in checkpoint_url:
config.hidden_size = 1024
config.intermediate_size = 4096
config.num_hidden_layers = 24
config.num_attention_heads = 16
elif "huge" in checkpoint_url:
config.patch_size = 14
config.hidden_size = 1280
config.intermediate_size = 5120
config.num_hidden_layers = 32
config.num_attention_heads = 16
model = ViTMAEForPreTraining(config)
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["model"]
image_processor = ViTMAEImageProcessor(size=config.image_size)
new_state_dict = convert_state_dict(state_dict, config)
model.load_state_dict(new_state_dict)
model.eval()
url = "https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg"
image = Image.open(requests.get(url, stream=True).raw)
image_processor = ViTMAEImageProcessor(size=config.image_size)
inputs = image_processor(images=image, return_tensors="pt")
torch.manual_seed(2)
outputs = model(**inputs)
logits = outputs.logits
if "large" in checkpoint_url:
expected_slice = torch.tensor(
[[-0.7309, -0.7128, -1.0169], [-1.0161, -0.9058, -1.1878], [-1.0478, -0.9411, -1.1911]]
)
elif "huge" in checkpoint_url:
expected_slice = torch.tensor(
[[-1.1599, -0.9199, -1.2221], [-1.1952, -0.9269, -1.2307], [-1.2143, -0.9337, -1.2262]]
)
else:
expected_slice = torch.tensor(
[[-0.9192, -0.8481, -1.1259], [-1.1349, -1.0034, -1.2599], [-1.1757, -1.0429, -1.2726]]
)
assert torch.allclose(logits[0, :3, :3], expected_slice, atol=1e-4)
print(f"Saving model to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"Saving image processor to {pytorch_dump_folder_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpoint_url",
default="https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_base.pth",
type=str,
help="URL of the checkpoint you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
args = parser.parse_args()
convert_vit_mae_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
.\models\vit_mae\modeling_tf_vit_mae.py
""" TF 2.0 ViT MAE (masked autoencoder) model."""
from __future__ import annotations
import collections.abc
import math
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...file_utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_tf_outputs import TFBaseModelOutput
from ...modeling_tf_utils import (
TFModelInputType,
TFPreTrainedModel,
get_initializer,
keras,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import logging
from .configuration_vit_mae import ViTMAEConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "ViTMAEConfig"
_CHECKPOINT_FOR_DOC = "facebook/vit-mae-base"
@dataclass
class TFViTMAEModelOutput(ModelOutput):
"""
Class for TFViTMAEModel's outputs, with potential hidden states and attentions.
"""
Args:
last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Tensor indicating which patches are masked (1) and which are not (0).
ids_restore (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Tensor containing the original index of the (shuffled) masked patches.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
"""
# All fields default to None
last_hidden_state: tf.Tensor = None
mask: tf.Tensor = None
ids_restore: tf.Tensor = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None
@dataclass
class TFViTMAEDecoderOutput(ModelOutput):
"""
Class for TFViTMAEDecoder's outputs, with potential hidden states and attentions.
Args:
logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
Pixel reconstruction logits.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
"""
logits: tf.Tensor = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None
@dataclass
class TFViTMAEForPreTrainingOutput(ModelOutput):
"""
Class for TFViTMAEForPreTraining's outputs, with potential hidden states and attentions.
Args:
loss (`tf.Tensor` of shape `(1,)`):
Pixel reconstruction loss.
logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
Pixel reconstruction logits.
mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Tensor indicating which patches are masked (1) and which are not (0).
ids_restore (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Tensor containing the original index of the (shuffled) masked patches.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
"""
loss: tf.Tensor | None = None
logits: tf.Tensor = None
mask: tf.Tensor = None
ids_restore: tf.Tensor = None
hidden_states: Tuple[tf.Tensor] | None = None
# `attentions` is either a `Tuple[tf.Tensor]` or `None`
attentions: Tuple[tf.Tensor] | None = None
# Function that creates 2D sin/cos positional embeddings
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
"""
Create 2D sin/cos positional embeddings.
Args:
embed_dim (`int`):
Embedding dimension.
grid_size (`int`):
The grid height and width.
add_cls_token (`bool`, *optional*, defaults to `False`):
Whether or not to add a classification (CLS) token.
Returns:
(`tf.Tensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the position
embeddings (with or without classification token)
"""
# Build the height and width grid
grid_h = tf.range(grid_size, dtype=tf.float32)
grid_w = tf.range(grid_size, dtype=tf.float32)
grid = tf.meshgrid(grid_w, grid_h)  # width goes first here
grid = tf.stack(grid, axis=0)
grid = tf.reshape(grid, [2, 1, grid_size, grid_size])
# Get the 2D sin/cos positional embeddings from the grid
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if add_cls_token:
# If a CLS token is requested, prepend an all-zeros vector to the position embeddings
pos_embed = tf.concat([tf.zeros((1, embed_dim)), pos_embed], axis=0)
return pos_embed
# Get 2D sin/cos positional embeddings from a grid
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be even")
# Use half of the dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = tf.concat([emb_h, emb_w], axis=1) # (H*W, D)
return emb
# Get 1D sin/cos positional embeddings from a list of positions
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
"""
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be even")
omega = tf.range(embed_dim // 2, dtype="float32")
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = tf.reshape(pos, [-1]) # (M,)
out = tf.einsum("m,d->md", pos, omega) # (M, D/2), outer product
# Half of the dimensions get the sine pattern, the other half the cosine pattern, then concatenate
emb_sin = tf.sin(out) # (M, D/2)
emb_cos = tf.cos(out) # (M, D/2)
emb = tf.concat([emb_sin, emb_cos], axis=1) # (M, D)
return emb
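As a quick sanity check of the three helpers above, the sketch below (an illustrative snippet, not part of the original file; it assumes the functions defined above are in scope) builds the embeddings for a toy grid and verifies the expected shapes, with and without the CLS slot:

```python
import tensorflow as tf

embed_dim, grid_size = 8, 4  # toy sizes; real models use e.g. 768 and a 14x14 grid

pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size)                        # (16, 8)
pos_embed_cls = get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=True)  # (17, 8)

assert pos_embed.shape == (grid_size * grid_size, embed_dim)
assert pos_embed_cls.shape == (grid_size * grid_size + 1, embed_dim)
# The CLS slot is all zeros, the remaining entries are bounded sin/cos values.
assert float(tf.reduce_max(tf.abs(pos_embed_cls[0]))) == 0.0
assert float(tf.reduce_max(tf.abs(pos_embed))) <= 1.0
```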
class TFViTMAEEmbeddings(keras.layers.Layer):
"""
Construct the CLS token, position and patch embeddings.
"""
def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs)
self.patch_embeddings = TFViTMAEPatchEmbeddings(config, name="patch_embeddings")
self.num_patches = self.patch_embeddings.num_patches
self.config = config
def build(self, input_shape=None):
# Create the learnable CLS token weight of shape (1, 1, hidden_size)
self.cls_token = self.add_weight(
shape=(1, 1, self.config.hidden_size),
initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
trainable=True,
name="cls_token",
)
# Create the position embedding matrix of shape (1, num_patches + 1, hidden_size), zero-initialized
self.position_embeddings = self.add_weight(
shape=(1, self.num_patches + 1, self.config.hidden_size),
initializer="zeros",
trainable=False,  # fixed sin-cos position embeddings
name="position_embeddings",
)
# Generate the 2D sin-cos position embeddings with `get_2d_sincos_pos_embed`
pos_embed = get_2d_sincos_pos_embed(
self.position_embeddings.shape[-1],
int(self.patch_embeddings.num_patches**0.5),
add_cls_token=True,
)[None, ...]
# Assign the generated position embeddings to self.position_embeddings
self.position_embeddings.assign(pos_embed)
# If the layer has already been built, return early
if self.built:
return
# Mark the layer as built
self.built = True
# If the patch_embeddings sub-layer exists, build it
if getattr(self, "patch_embeddings", None) is not None:
with tf.name_scope(self.patch_embeddings.name):
self.patch_embeddings.build(None)
def random_masking(self, sequence: tf.Tensor, noise: tf.Tensor | None = None):
"""
Perform per-sample random masking by per-sample shuffling. The per-sample shuffling is done with argsort on random noise.
Args:
sequence (`tf.Tensor` of shape `(batch_size, sequence_length, dim)`)
noise (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*), mainly used for testing purposes to
control randomness and maintain reproducibility
"""
# Get the shape of `sequence`: batch_size, sequence_length, dim
batch_size, seq_length, dim = shape_list(sequence)
# Compute how many patches to keep so that the masked fraction equals self.config.mask_ratio
len_keep = int(seq_length * (1 - self.config.mask_ratio))
# If no noise is provided, sample uniform noise in [0, 1)
if noise is None:
noise = tf.random.uniform(shape=(batch_size, seq_length), minval=0.0, maxval=1.0)  # noise in [0, 1)
# Sort the noise for each sample
ids_shuffle = tf.argsort(noise, axis=1)  # ascending: small values are kept, large values are removed
ids_restore = tf.argsort(ids_shuffle, axis=1)
# Keep the first len_keep indices
ids_keep = ids_shuffle[:, :len_keep]
sequence_unmasked = tf.gather(
sequence,
axis=1,
batch_dims=1,
indices=ids_keep,
)
# Build the binary mask: 0 means keep, 1 means remove
# This approach is needed because TF's EagerTensors do not support in-place assignment
mask_keep = tf.zeros((batch_size, len_keep))
mask_remove = tf.ones((batch_size, seq_length - len_keep))
mask = tf.concat([mask_keep, mask_remove], axis=-1)
# Unshuffle with ids_restore to obtain the final binary mask in the original order
mask = tf.gather(mask, axis=1, batch_dims=1, indices=ids_restore)
return sequence_unmasked, mask, ids_restore
def call(self, pixel_values: tf.Tensor, noise: tf.Tensor = None) -> tf.Tensor:
# Convert the pixel values into patch embeddings
embeddings = self.patch_embeddings(pixel_values)
# Add the position embeddings, excluding the CLS token
embeddings = embeddings + self.position_embeddings[:, 1:, :]
# Apply random masking: mask part of the embeddings, producing `mask` and the restore indices `ids_restore`
embeddings, mask, ids_restore = self.random_masking(embeddings, noise)
# Append the CLS token
# Combine self.cls_token with its position embedding and tile it to the start of every sequence in the batch
cls_token = self.cls_token + self.position_embeddings[:, :1, :]
cls_tokens = tf.tile(cls_token, (shape_list(embeddings)[0], 1, 1))
# Concatenate the CLS tokens with the embeddings
embeddings = tf.concat([cls_tokens, embeddings], axis=1)
# Return the processed embeddings, the mask and ids_restore
return embeddings, mask, ids_restore
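To make the masking logic above concrete, here is a small illustrative run (toy sizes, not from the original file; it assumes the `TFViTMAEEmbeddings` class above is in scope) showing that `random_masking` keeps `1 - mask_ratio` of the patches, that `mask` marks removed positions with 1, and that `ids_restore` maps back to the original order:

```python
import tensorflow as tf
from transformers import ViTMAEConfig

# Toy config: 32x32 images with 16x16 patches -> 4 patches, half of them masked.
config = ViTMAEConfig(image_size=32, patch_size=16, hidden_size=8, mask_ratio=0.5)
embeddings = TFViTMAEEmbeddings(config)

sequence = tf.random.normal((2, embeddings.num_patches, config.hidden_size))
kept, mask, ids_restore = embeddings.random_masking(sequence)

print(kept.shape)         # (2, 2, 8): half of the 4 patches survive
print(mask.numpy())       # per sample: 0 where kept, 1 where masked
print(ids_restore.shape)  # (2, 4): permutation that undoes the shuffle
```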
class TFViTMAEPatchEmbeddings(keras.layers.Layer):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs)
# Read the image size and patch size from the config
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size
# If the image size and patch size are not iterables, convert them to tuples
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
# Compute the number of patches in the image
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = num_patches
self.num_channels = num_channels
self.config = config
# Convolution that projects the input pixel values to patch embeddings
self.projection = keras.layers.Conv2D(
filters=hidden_size,
kernel_size=patch_size,
strides=patch_size,
padding="valid",
data_format="channels_last",
kernel_initializer="glorot_uniform", # 使用glorot_uniform初始化卷积核
bias_initializer="zeros", # 使用零初始化偏置
name="projection",
)
def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
# Get the input tensor shape
batch_size, num_channels, height, width = shape_list(pixel_values)
# In eager mode, check that the channel dimension matches the configuration
if tf.executing_eagerly():
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the"
" configuration."
)
# Check that the input image size matches the configuration
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
# When running on CPU, keras.layers.Conv2D does not support NCHW, so convert the input from NCHW to NHWC
# shape = (batch_size, in_height, in_width, in_channels=num_channels)
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
# Project the pixel values into the hidden space
projection = self.projection(pixel_values)
# Collapse the 2D spatial dimensions into a single sequence dimension
# shape = (batch_size, num_patches, out_channels=embed_dim)
num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
x = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))
return x
# build method: creates the layer's variables
def build(self, input_shape=None):
# If already built, return to avoid building twice
if self.built:
return
# Mark the layer as built
self.built = True
# If the projection layer exists, build it
if getattr(self, "projection", None) is not None:
# In TensorFlow, name_scope defines a namespace for the ops
with tf.name_scope(self.projection.name):
# Build the projection layer; its input shape is [None, None, None, self.num_channels]
self.projection.build([None, None, None, self.num_channels])
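A small illustrative check (toy sizes; it assumes the `TFViTMAEPatchEmbeddings` class above is in scope) of how the Conv2D projection turns an image into a patch sequence:

```python
import tensorflow as tf
from transformers import ViTMAEConfig

config = ViTMAEConfig(image_size=32, patch_size=16, num_channels=3, hidden_size=8)
patch_embed = TFViTMAEPatchEmbeddings(config)

# NCHW input, as expected by `call`; it is transposed to NHWC internally.
pixel_values = tf.random.normal((2, 3, 32, 32))
patches = patch_embed(pixel_values)

# (32/16) * (32/16) = 4 patches, each projected to hidden_size=8.
print(patches.shape)  # (2, 4, 8)
```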
# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfAttention with ViT->ViTMAE
class TFViTMAESelfAttention(keras.layers.Layer):
def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs)
# Check that the hidden size is divisible by the number of attention heads
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number "
f"of attention heads ({config.num_attention_heads})"
)
# Set up the number of attention heads and the size of each head
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
# Dense layers for query, key and value, with initialized weights
self.query = keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
)
self.key = keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
)
self.value = keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
)
# Dropout applied to the attention probabilities
self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
self.config = config
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
# Transpose from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
return tf.transpose(tensor, perm=[0, 2, 1, 3])
def call(
self,
hidden_states: tf.Tensor,
head_mask: tf.Tensor,
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
# Batch size of the hidden states tensor
batch_size = shape_list(hidden_states)[0]
# Project the hidden states to the mixed query layer
mixed_query_layer = self.query(inputs=hidden_states)
# Project the hidden states to the mixed key layer
mixed_key_layer = self.key(inputs=hidden_states)
# Project the hidden states to the mixed value layer
mixed_value_layer = self.value(inputs=hidden_states)
# Transpose the mixed query layer for the attention score computation
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
# Transpose the mixed key layer for the attention score computation
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
# Transpose the mixed value layer for the attention score computation
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
# Take the dot product between the queries and keys to get the raw attention scores
# shape: (batch_size, num_heads, seq_len_q, seq_len_k)
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
# Scale the attention scores by dk
dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
attention_scores = tf.divide(attention_scores, dk)
# Normalize the attention scores to probabilities
attention_probs = stable_softmax(logits=attention_scores, axis=-1)
# Apply dropout to the attention probabilities
attention_probs = self.dropout(inputs=attention_probs, training=training)
# Apply the head mask if one is provided
if head_mask is not None:
attention_probs = tf.multiply(attention_probs, head_mask)
# Compute the attention output
attention_output = tf.matmul(attention_probs, value_layer)
# Reorder the dimensions of the attention output
attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
# Reshape the attention output to (batch_size, seq_len_q, all_head_size)
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
# Build the output tuple, optionally including the attention probabilities
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
return outputs
def build(self, input_shape=None):
# If already built, return
if self.built:
return
self.built = True
# Build the query layer if it exists
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
# Build the key layer if it exists
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
# Build the value layer if it exists
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
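The `transpose_for_scores` trick above is just a reshape followed by a transpose. A minimal standalone sketch (toy dimensions, independent of the class) of the same shape manipulation and of the scaled dot-product that follows it:

```python
import tensorflow as tf

batch_size, seq_len, num_heads, head_size = 2, 5, 4, 3
hidden = tf.random.normal((batch_size, seq_len, num_heads * head_size))

# [batch, seq, all_head_size] -> [batch, seq, heads, head_size] -> [batch, heads, seq, head_size]
split = tf.reshape(hidden, (batch_size, -1, num_heads, head_size))
split = tf.transpose(split, perm=[0, 2, 1, 3])
print(split.shape)  # (2, 4, 5, 3)

# Attention scores are computed per head and scaled by sqrt(head_size).
scores = tf.matmul(split, split, transpose_b=True) / tf.sqrt(float(head_size))
print(scores.shape)  # (2, 4, 5, 5)
```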
# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfOutput with ViT->ViTMAE
class TFViTMAESelfOutput(keras.layers.Layer):
"""
The residual connection is defined in TFViTMAELayer instead of here (as is the case with other models), due to the
layernorm applied before each block.
"""
def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs)
# Dense layer that maps the hidden states back to hidden_size
self.dense = keras.layers.Dense(
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
# Dropout layer that randomly zeroes units during training to reduce overfitting
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
# Transform the hidden states with the dense layer
hidden_states = self.dense(inputs=hidden_states)
# Apply dropout to the transformed output during training
hidden_states = self.dropout(inputs=hidden_states, training=training)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
# Build the dense layer if it has been defined
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.vit.modeling_tf_vit.TFViTAttention with ViT->ViTMAE
class TFViTMAEAttention(keras.layers.Layer):
def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs)
# Self-attention layer
self.self_attention = TFViTMAESelfAttention(config, name="attention")
# Output layer that processes the self-attention output
self.dense_output = TFViTMAESelfOutput(config, name="output")
def prune_heads(self, heads):
raise NotImplementedError
def call(
self,
input_tensor: tf.Tensor,
head_mask: tf.Tensor,
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
# Run the self-attention layer on the input tensor, returning its output and optionally the attention weights
self_outputs = self.self_attention(
hidden_states=input_tensor, head_mask=head_mask, output_attentions=output_attentions, training=training
)
# Run the output layer on the self-attention output and the original input tensor
attention_output = self.dense_output(
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
)
# Assemble the outputs: the attention output plus the attention weights if requested
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
# Build the self-attention layer if it has been defined
if getattr(self, "self_attention", None) is not None:
with tf.name_scope(self.self_attention.name):
self.self_attention.build(None)
# Build the output layer if it has been defined
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
# Copied from transformers.models.vit.modeling_tf_vit.TFViTIntermediate with ViT->ViTMAE
class TFViTMAEIntermediate(keras.layers.Layer):
# Constructor: creates a new ViTMAE intermediate layer
def __init__(self, config: ViTMAEConfig, **kwargs):
# Call the parent constructor
super().__init__(**kwargs)
# Dense layer that processes the input features
self.dense = keras.layers.Dense(
units=config.intermediate_size,  # output dimension of the layer
kernel_initializer=get_initializer(config.initializer_range),  # weight initializer
name="dense"  # layer name
)
# Select the activation function based on the config
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)  # look up the TensorFlow activation by name
else:
self.intermediate_act_fn = config.hidden_act  # use the given callable directly
self.config = config  # keep the config object
# Forward pass of the layer
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)  # run the dense layer
hidden_states = self.intermediate_act_fn(hidden_states)  # apply the intermediate activation
return hidden_states  # return the transformed features
# Build the layer's variables
def build(self, input_shape=None):
if self.built:  # if already built, return
return
self.built = True  # mark as built
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):  # name scope for the dense layer
self.dense.build([None, None, self.config.hidden_size])  # build the dense layer
# Copied from transformers.models.vit.modeling_tf_vit.TFViTOutput with ViT->ViTMAE
class TFViTMAEOutput(keras.layers.Layer):
def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs)
# Dense layer mapping to config.hidden_size, with the specified initializer
self.dense = keras.layers.Dense(
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
# Dropout layer with the configured dropout rate
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
# Run the dense layer on the hidden states
hidden_states = self.dense(inputs=hidden_states)
# Apply dropout during training
hidden_states = self.dropout(inputs=hidden_states, training=training)
# Add the residual: dropout output plus input_tensor
hidden_states = hidden_states + input_tensor
return hidden_states
def build(self, input_shape=None):
# If already built, return
if self.built:
return
self.built = True
# If self.dense exists, build it with the expected input shape
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
# Copied from transformers.models.vit.modeling_tf_vit.TFViTLayer with ViT->ViTMAE
class TFViTMAELayer(keras.layers.Layer):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs)
# TFViTMAEAttention layer
self.attention = TFViTMAEAttention(config, name="attention")
# TFViTMAEIntermediate layer
self.intermediate = TFViTMAEIntermediate(config, name="intermediate")
# TFViTMAEOutput layer
self.vit_output = TFViTMAEOutput(config, name="output")
# LayerNormalization layers applied before and after the attention block
self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")
self.config = config
def call(
self,
hidden_states: tf.Tensor,
head_mask: tf.Tensor,
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
# Run self-attention; in ViTMAE, layernorm is applied before self-attention
attention_outputs = self.attention(
input_tensor=self.layernorm_before(inputs=hidden_states),  # apply layernorm before self-attention
head_mask=head_mask,
output_attentions=output_attentions,
training=training,
)
attention_output = attention_outputs[0]
# First residual connection
hidden_states = attention_output + hidden_states
# In ViTMAE, layernorm is also applied after self-attention
layer_output = self.layernorm_after(inputs=hidden_states)
# Run the intermediate layer on the normalized output
intermediate_output = self.intermediate(hidden_states=layer_output)
# The second residual connection is done inside self.vit_output
layer_output = self.vit_output(
hidden_states=intermediate_output, input_tensor=hidden_states, training=training
)
outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
# Build the attention layer
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
# Build the intermediate layer
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
# Build the vit_output layer
if getattr(self, "vit_output", None) is not None:
with tf.name_scope(self.vit_output.name):
self.vit_output.build(None)
# Build the layernorm_before layer
if getattr(self, "layernorm_before", None) is not None:
with tf.name_scope(self.layernorm_before.name):
self.layernorm_before.build([None, None, self.config.hidden_size])
# Build the layernorm_after layer
if getattr(self, "layernorm_after", None) is not None:
with tf.name_scope(self.layernorm_after.name):
self.layernorm_after.build([None, None, self.config.hidden_size])
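As the corrected `call` above shows, this is a pre-LayerNorm block: normalization is applied to the input of each sub-layer, and the residual adds the un-normalized stream. A compact schematic of that ordering, using stand-in placeholder sub-layers rather than the real attention and MLP (an illustrative sketch only):

```python
import tensorflow as tf

norm1 = tf.keras.layers.LayerNormalization()
norm2 = tf.keras.layers.LayerNormalization()
attention = lambda x: x  # placeholder for self-attention
mlp = lambda x: x        # placeholder for intermediate + output dense

def pre_ln_block(hidden_states):
    # 1) normalize, attend, then add the first residual
    hidden_states = attention(norm1(hidden_states)) + hidden_states
    # 2) normalize again, run the MLP, then add the second residual
    return mlp(norm2(hidden_states)) + hidden_states

x = tf.random.normal((2, 5, 8))
print(pre_ln_block(x).shape)  # (2, 5, 8)
```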
# Copied from transformers.models.vit.modeling_tf_vit.TFViTEncoder with ViT->ViTMAE
class TFViTMAEEncoder(keras.layers.Layer):
def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs)
# Create the encoder's stack of TFViTMAELayer modules, named "layer_._{i}"
self.layer = [TFViTMAELayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
def call(
self,
hidden_states: tf.Tensor,
head_mask: tf.Tensor,
output_attentions: bool,
output_hidden_states: bool,
return_dict: bool,
training: bool = False,
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
# Tuple collecting all hidden states, or None if they are not requested
all_hidden_states = () if output_hidden_states else None
# Tuple collecting all attention weights, or None if they are not requested
all_attentions = () if output_attentions else None
# Iterate over the encoder layers
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
# Record the current hidden states before running the layer
all_hidden_states = all_hidden_states + (hidden_states,)
# Run the current encoder layer
layer_outputs = layer_module(
hidden_states=hidden_states,
head_mask=head_mask[i],
output_attentions=output_attentions,
training=training,
)
# The layer's first output becomes the new hidden states
hidden_states = layer_outputs[0]
if output_attentions:
# Record this layer's attention weights
all_attentions = all_attentions + (layer_outputs[1],)
# Add the hidden states of the last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# If a dict-like output is not requested, return the non-None values as a tuple
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
# Return a TFBaseModelOutput with the last hidden state, all hidden states and all attentions
return TFBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
def build(self, input_shape=None):
# If already built, return
if self.built:
return
self.built = True
# Build each encoder layer
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFViTMAEMainLayer(keras.layers.Layer):
config_class = ViTMAEConfig
def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs)
# Store the configuration of the ViTMAE main layer
self.config = config
# Embedding layer TFViTMAEEmbeddings, named "embeddings"
self.embeddings = TFViTMAEEmbeddings(config, name="embeddings")
# Encoder TFViTMAEEncoder, named "encoder"
self.encoder = TFViTMAEEncoder(config, name="encoder")
# Final LayerNormalization layer with the configured epsilon, named "layernorm"
self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
def get_input_embeddings(self) -> keras.layers.Layer:
# Return the patch embeddings of the TFViTMAEEmbeddings layer
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
# Not implemented; subclasses would implement the actual head-pruning logic
raise NotImplementedError
@unpack_inputs
# `call` runs the model for inference or training
def call(
self,
pixel_values: TFModelInputType | None = None,  # input pixel values, may be None
noise: tf.Tensor = None,  # optional noise tensor
head_mask: np.ndarray | tf.Tensor | None = None,  # head mask as a NumPy array, tensor, or None
output_attentions: Optional[bool] = None,  # whether to return attention weights
output_hidden_states: Optional[bool] = None,  # whether to return hidden states
return_dict: Optional[bool] = None,  # whether to return a ModelOutput instead of a tuple
training: bool = False,  # whether the model is in training mode
) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
# Run the embedding layer to get the embeddings, the mask and the restore indices
embedding_output, mask, ids_restore = self.embeddings(
pixel_values=pixel_values, training=training, noise=noise
)
# Prepare the head mask if needed
# 1.0 in head_mask means we keep the head
# attention_probs has shape bsz x n_heads x N x N
# the input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
# Handling a provided head mask is not implemented yet
raise NotImplementedError
else:
# With no head mask, use a list of Nones, one per hidden layer
head_mask = [None] * self.config.num_hidden_layers
# Run the encoder on the embedding output
encoder_outputs = self.encoder(
embedding_output,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
# Sequence output from the encoder
sequence_output = encoder_outputs[0]
# Apply layernorm to the sequence output
sequence_output = self.layernorm(inputs=sequence_output)
# If a dict-like output is not requested, return a tuple
if not return_dict:
return (sequence_output, mask, ids_restore) + encoder_outputs[1:]
# Return a TFViTMAEModelOutput with the last hidden state, mask, restore indices, hidden states and attentions
return TFViTMAEModelOutput(
last_hidden_state=sequence_output,
mask=mask,
ids_restore=ids_restore,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
# build method: constructs the sub-layers when needed
def build(self, input_shape=None):
# If already built, return
if self.built:
return
# Mark as built
self.built = True
# Build the embeddings layer if it exists
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
# Build the encoder if it exists
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
# Build the layernorm if it exists, with shape [None, None, self.config.hidden_size]
if getattr(self, "layernorm", None) is not None:
with tf.name_scope(self.layernorm.name):
self.layernorm.build([None, None, self.config.hidden_size])
"""
Documenting the expected input formats for ViT-MAE models when using TensorFlow. This docstring serves as a guide
for users on how to provide inputs to the model.
TensorFlow models in `transformers` support two input formats:
- Passing all inputs as keyword arguments.
- Passing all inputs in a list, tuple, or dictionary as the first positional argument.
This flexibility ensures compatibility with TensorFlow's Keras API and other functional usage scenarios.
Args:
pixel_values (Tensor): Input pixel values representing the image.
noise (Tensor, optional): Noise used for the random masking, mainly to keep results reproducible.
head_mask (Tensor, optional): Mask to nullify selected heads of the self-attention modules.
Usage Examples:
- Using keyword arguments: `model(pixel_values=pixel_values)`
- Using a list or tuple for the first positional argument:
`model([pixel_values, noise])` or `model([pixel_values, noise, head_mask])`
- Using a dictionary with input names: `model({"pixel_values": pixel_values, "noise": noise})`
Note:
For custom layers or models using Keras Functional API, ensure inputs match the documented formats.
Reference:
[Transformers documentation](https://huggingface.co/transformers/model_doc/vit.html)
"""
Args:
pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
for details.
head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
config will be used instead.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
used instead.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. This argument can be used
in eager mode, in graph mode the value will always be set to True.
training (`bool`, *optional*, defaults to `False`):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).
"""
TensorFlow implementation of the ViTMAE Model transformer outputting raw hidden-states without any specific head on top.
Args:
config (ViTMAEConfig): Configuration object for the ViTMAE model.
*inputs: Variable-length positional inputs.
**kwargs: Keyword arguments.
Attributes:
vit (TFViTMAEMainLayer): The ViTMAE main layer.
Methods:
get_input_embeddings(): Returns the input embedding layer.
call(): Forward pass of the model; accepts several arguments and returns the model outputs.
build(input_shape=None): Builds the model, initializing its sub-layers.
Examples:
```
>>> from transformers import AutoImageProcessor, TFViTMAEModel
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
>>> model = TFViTMAEModel.from_pretrained("facebook/vit-mae-base")
>>> inputs = image_processor(images=image, return_tensors="tf")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```
"""
class TFViTMAEModel(TFViTMAEPreTrainedModel):
def __init__(self, config: ViTMAEConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
# Instantiate the ViTMAE main layer
self.vit = TFViTMAEMainLayer(config, name="vit")
def get_input_embeddings(self):
# Delegate to the main layer's input embeddings
return self.vit.get_input_embeddings()
@unpack_inputs
@add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFViTMAEModelOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
pixel_values: TFModelInputType | None = None,
noise: tf.Tensor = None,
head_mask: np.ndarray | tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
r"""
Forward pass of the model; accepts several arguments and returns the model outputs.
Args:
pixel_values (TFModelInputType | None): Input pixel values, may be None.
noise (tf.Tensor): Optional noise tensor, defaults to None.
head_mask (np.ndarray | tf.Tensor | None): Optional head mask.
output_attentions (Optional[bool]): Whether to return attention weights.
output_hidden_states (Optional[bool]): Whether to return hidden states.
return_dict (Optional[bool]): Whether to return a dict-like output.
training (bool): Whether the model is in training mode.
Returns:
Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]: The model outputs.
Examples:
```
>>> from transformers import AutoImageProcessor, TFViTMAEModel
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
>>> model = TFViTMAEModel.from_pretrained("facebook/vit-mae-base")
>>> inputs = image_processor(images=image, return_tensors="tf")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```
"""
# Run the ViTMAE main layer
outputs = self.vit(
pixel_values=pixel_values,
noise=noise,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
return outputs
def build(self, input_shape=None):
# If already built, return
if self.built:
return
# Mark the model as built
self.built = True
# If the ViTMAE main layer exists, build it inside its name scope
if getattr(self, "vit", None) is not None:
with tf.name_scope(self.vit.name):
self.vit.build(None)
class TFViTMAEDecoder(keras.layers.Layer):
# Constructor
def __init__(self, config, num_patches, **kwargs):
# Call the parent constructor
super().__init__(**kwargs)
# Dense layer that maps the encoder output to the decoder hidden size
self.decoder_embed = keras.layers.Dense(config.decoder_hidden_size, name="decoder_embed")
# Deep-copy the config and override it with the decoder-specific sizes
decoder_config = deepcopy(config)
decoder_config.hidden_size = config.decoder_hidden_size
decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
decoder_config.num_attention_heads = config.decoder_num_attention_heads
decoder_config.intermediate_size = config.decoder_intermediate_size
# Stack of decoder layers, all sharing the same decoder config
self.decoder_layers = [
TFViTMAELayer(decoder_config, name=f"decoder_layers.{j}") for j in range(config.decoder_num_hidden_layers)
]
# LayerNormalization applied to the decoder output
self.decoder_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="decoder_norm")
# Prediction head that maps the decoder output back to patch_size**2 * num_channels pixel values
self.decoder_pred = keras.layers.Dense(
config.patch_size**2 * config.num_channels,
kernel_initializer=get_initializer(config.initializer_range),
name="decoder_pred",
) # encoder to decoder
# Keep the config object and the number of patches
self.config = config
self.num_patches = num_patches
# Build the decoder's variables once the layer is created
def build(self, input_shape=None):
# Learnable mask token of shape (1, 1, decoder_hidden_size)
self.mask_token = self.add_weight(
shape=(1, 1, self.config.decoder_hidden_size),
initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
trainable=True,
name="mask_token",
)
# Decoder position embeddings of shape (1, num_patches + 1, decoder_hidden_size), zero-initialized and frozen
self.decoder_pos_embed = self.add_weight(
shape=(1, self.num_patches + 1, self.config.decoder_hidden_size),
initializer="zeros",
trainable=False,
name="decoder_pos_embed",
)
# Generate the 2D sin-cos position embeddings and assign them to the decoder position embeddings
decoder_pos_embed = get_2d_sincos_pos_embed(
self.decoder_pos_embed.shape[-1],
int(self.num_patches**0.5),
add_cls_token=True,
)[None, ...]
self.decoder_pos_embed.assign(decoder_pos_embed)
# If already built, return
if self.built:
return
self.built = True
# Build the decoder embedding layer if it exists
if getattr(self, "decoder_embed", None) is not None:
with tf.name_scope(self.decoder_embed.name):
self.decoder_embed.build([None, None, self.config.hidden_size])
# Build the decoder norm layer if it exists
if getattr(self, "decoder_norm", None) is not None:
with tf.name_scope(self.decoder_norm.name):
self.decoder_norm.build([None, None, self.config.decoder_hidden_size])
# Build the decoder prediction head if it exists
if getattr(self, "decoder_pred", None) is not None:
with tf.name_scope(self.decoder_pred.name):
self.decoder_pred.build([None, None, self.config.decoder_hidden_size])
# Build each decoder layer
if getattr(self, "decoder_layers", None) is not None:
for layer in self.decoder_layers:
with tf.name_scope(layer.name):
layer.build(None)
# Forward pass of the decoder
def call(
self,
hidden_states,
ids_restore,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
# Embed the (visible) tokens into the decoder hidden size
x = self.decoder_embed(hidden_states)
# Append mask tokens to the sequence
mask_tokens = tf.tile(
self.mask_token,
(shape_list(x)[0], shape_list(ids_restore)[1] + 1 - shape_list(x)[1], 1),
)
x_ = tf.concat([x[:, 1:, :], mask_tokens], axis=1)  # no cls token
x_ = tf.gather(x_, axis=1, batch_dims=1, indices=ids_restore)  # unshuffle
x = tf.concat([x[:, :1, :], x_], axis=1)  # prepend the cls token again
# Add the decoder position embeddings
hidden_states = x + self.decoder_pos_embed
# Apply the Transformer decoder layers (blocks)
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.decoder_layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
hidden_states,
head_mask=None,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# Normalize the hidden states
hidden_states = self.decoder_norm(hidden_states)
# Predictor projection
logits = self.decoder_pred(hidden_states)
# Remove the cls token
logits = logits[:, 1:, :]
# Decide what to return based on return_dict
if not return_dict:
return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None)
return TFViTMAEDecoderOutput(logits=logits, hidden_states=all_hidden_states, attentions=all_self_attentions)
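The key step in the decoder `call` above is re-inserting mask tokens and undoing the encoder's shuffle with `ids_restore`. A small illustrative example of that `tf.gather(..., batch_dims=1)` unshuffle (toy values, independent of the model):

```python
import tensorflow as tf

# Two "visible" tokens per sample plus two appended mask tokens, in shuffled order.
visible = tf.constant([[[1.0], [2.0]]])              # (1, 2, 1): kept patches
mask_token = tf.zeros((1, 2, 1))                     # (1, 2, 1): appended mask tokens
shuffled = tf.concat([visible, mask_token], axis=1)  # (1, 4, 1)

# ids_restore says where each original position ended up in the shuffled sequence.
ids_restore = tf.constant([[2, 0, 3, 1]])
restored = tf.gather(shuffled, ids_restore, axis=1, batch_dims=1)
print(tf.squeeze(restored).numpy())  # [0. 1. 0. 2.] -> original order with masks filled in
```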
@add_start_docstrings(
"The ViTMAE Model transformer with the decoder on top for self-supervised pre-training.",
VIT_MAE_START_DOCSTRING,
)
class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config
# Instantiate the ViT main layer
self.vit = TFViTMAEMainLayer(config, name="vit")
# Instantiate the decoder, passing the config and the number of patches from the ViT main layer
self.decoder = TFViTMAEDecoder(
config,
num_patches=self.vit.embeddings.num_patches,
name="decoder",
)
def get_input_embeddings(self):
# Return the input embeddings of the ViT main layer
return self.vit.get_input_embeddings()
def _prune_heads(self, heads_to_prune):
# Head pruning is not implemented
raise NotImplementedError
def patchify(self, pixel_values):
"""
Args:
pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)` or `(batch_size, num_channels, height, width)`):
Pixel values.
Returns:
`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
Patchified pixel values.
"""
patch_size, num_channels = self.config.patch_size, self.config.num_channels
# Make sure channels are the last dimension
if shape_list(pixel_values)[1] == num_channels:
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
# Sanity checks
tf.debugging.assert_equal(
shape_list(pixel_values)[1],
shape_list(pixel_values)[2],
message="Make sure the pixel values have a squared size",
)
tf.debugging.assert_equal(
shape_list(pixel_values)[1] % patch_size,
0,
message="Make sure the pixel values have a size that is divisible by the patch size",
)
tf.debugging.assert_equal(
shape_list(pixel_values)[3],
num_channels,
message=(
"Make sure the number of channels of the pixel values is equal to the one set in the configuration"
),
)
# Patchify
batch_size = shape_list(pixel_values)[0]
num_patches_one_direction = shape_list(pixel_values)[2] // patch_size
patchified_pixel_values = tf.reshape(
pixel_values,
(batch_size, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size, num_channels),
)
patchified_pixel_values = tf.einsum("nhpwqc->nhwpqc", patchified_pixel_values)
patchified_pixel_values = tf.reshape(
patchified_pixel_values,
(batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels),
)
return patchified_pixel_values
def unpatchify(self, patchified_pixel_values):
"""
Args:
patchified_pixel_values (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
Patchified pixel values.
Returns:
`tf.Tensor` of shape `(batch_size, height, width, num_channels)`:
Pixel values.
"""
# Read the patch size and number of channels from the config
patch_size, num_channels = self.config.patch_size, self.config.num_channels
# Number of patches in each direction; must be an integer
num_patches_one_direction = int(shape_list(patchified_pixel_values)[1] ** 0.5)
# Sanity check: the number of patches must be a perfect square
tf.debugging.assert_equal(
num_patches_one_direction * num_patches_one_direction,
shape_list(patchified_pixel_values)[1],
message="Make sure that the number of patches can be squared",
)
# Unpatchify
batch_size = shape_list(patchified_pixel_values)[0]
patchified_pixel_values = tf.reshape(
patchified_pixel_values,
(batch_size, num_patches_one_direction, num_patches_one_direction, patch_size, patch_size, num_channels),
)
patchified_pixel_values = tf.einsum("nhwpqc->nhpwqc", patchified_pixel_values)
# Reassemble the full pixel values
pixel_values = tf.reshape(
patchified_pixel_values,
(batch_size, num_patches_one_direction * patch_size, num_patches_one_direction * patch_size, num_channels),
)
return pixel_values
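A standalone illustrative roundtrip (toy sizes, channels-last, not tied to the class) showing that the einsum/reshape pair used by `unpatchify` exactly inverts the one used by `patchify`:

```python
import tensorflow as tf

batch, patch, n, channels = 1, 2, 2, 3            # 4x4 image, 2x2 patches -> 2x2 patch grid
images = tf.random.normal((batch, n * patch, n * patch, channels))

# patchify: (b, H, W, C) -> (b, num_patches, patch**2 * C)
x = tf.reshape(images, (batch, n, patch, n, patch, channels))
patches = tf.reshape(tf.einsum("nhpwqc->nhwpqc", x), (batch, n * n, patch * patch * channels))

# unpatchify: (b, num_patches, patch**2 * C) -> (b, H, W, C)
y = tf.reshape(patches, (batch, n, n, patch, patch, channels))
restored = tf.reshape(tf.einsum("nhwpqc->nhpwqc", y), (batch, n * patch, n * patch, channels))

# The two transformations are exact inverses of each other.
print(bool(tf.reduce_all(tf.equal(images, restored))))  # True
```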
def forward_loss(self, pixel_values, pred, mask):
"""
Args:
pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)`):
Pixel values.
pred (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
Predicted pixel values.
mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Tensor indicating which patches are masked (1) and which are not (0).
Returns:
`tf.Tensor`: Pixel reconstruction loss.
"""
# Patchify the target pixel values
target = self.patchify(pixel_values)
# If pixel-loss normalization is enabled, standardize each target patch
if self.config.norm_pix_loss:
mean = tf.reduce_mean(target, axis=-1, keepdims=True)
var = tf.math.reduce_variance(target, axis=-1, keepdims=True)
target = (target - mean) / (var + 1.0e-6) ** 0.5
# Compute the loss: squared error between the prediction and the target
loss = (pred - target) ** 2
loss = tf.reduce_mean(loss, axis=-1) # [batch_size, num_patches], mean loss per patch
# Average the per-patch loss over the masked patches only
loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask) # mean loss on removed patches
loss = tf.reshape(loss, (1,))
return loss
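The masked-reconstruction loss above averages the per-patch squared error only over masked patches; with `norm_pix_loss` each target patch is first standardized. An illustrative stand-alone computation (toy tensors, not tied to the class):

```python
import tensorflow as tf

target = tf.random.normal((2, 4, 6))   # (batch, num_patches, patch_size**2 * channels)
pred = tf.random.normal((2, 4, 6))
mask = tf.constant([[1.0, 0.0, 1.0, 1.0],
                    [0.0, 1.0, 1.0, 0.0]])  # 1 = masked (reconstructed), 0 = visible

# Optional per-patch normalization of the target (norm_pix_loss=True)
mean = tf.reduce_mean(target, axis=-1, keepdims=True)
var = tf.math.reduce_variance(target, axis=-1, keepdims=True)
target = (target - mean) / (var + 1e-6) ** 0.5

per_patch = tf.reduce_mean((pred - target) ** 2, axis=-1)      # (2, 4): mean loss per patch
loss = tf.reduce_sum(per_patch * mask) / tf.reduce_sum(mask)   # mean over masked patches only
print(float(loss))
```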
@unpack_inputs
@add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFViTMAEForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
pixel_values: TFModelInputType | None = None,
noise: tf.Tensor = None,
head_mask: np.ndarray | tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[TFViTMAEForPreTrainingOutput, Tuple[tf.Tensor]]:
r"""
Returns:
Examples:
```
>>> from transformers import AutoImageProcessor, TFViTMAEForPreTraining
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
>>> model = TFViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
>>> inputs = image_processor(images=image, return_tensors="tf")
>>> outputs = model(**inputs)
>>> loss = outputs.loss
>>> mask = outputs.mask
>>> ids_restore = outputs.ids_restore
```"""
# Resolve return_dict from the argument or the config
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Run the ViT encoder
outputs = self.vit(
pixel_values=pixel_values,
noise=noise,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
# Read the last hidden state, the restore indices and the mask from the encoder outputs
latent = outputs.last_hidden_state
ids_restore = outputs.ids_restore
mask = outputs.mask
# Decode and compute the reconstruction loss from the decoder logits
decoder_outputs = self.decoder(latent, ids_restore) # [batch_size, num_patches, patch_size**2*3]
logits = decoder_outputs.logits
loss = self.forward_loss(pixel_values, logits, mask)
# Choose the output format based on return_dict
if not return_dict:
output = (logits, mask, ids_restore) + outputs[2:]
return ((loss,) + output) if loss is not None else output
# Return a TFViTMAEForPreTrainingOutput with loss, logits, mask, restore indices, hidden states and attentions
return TFViTMAEForPreTrainingOutput(
loss=loss,
logits=logits,
mask=mask,
ids_restore=ids_restore,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
# Mark the model as built
self.built = True
# If the ViT main layer is defined, build it
if getattr(self, "vit", None) is not None:
with tf.name_scope(self.vit.name):
self.vit.build(None)
# If the decoder is defined, build it
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)
.\models\vit_mae\modeling_vit_mae.py
""" PyTorch ViT MAE (masked autoencoder) model."""
import collections.abc
import math
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional, Set, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_vit_mae import ViTMAEConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "ViTMAEConfig"
_CHECKPOINT_FOR_DOC = "facebook/vit-mae-base"
VIT_MAE_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/vit-mae-base",
]
@dataclass
class ViTMAEModelOutput(ModelOutput):
last_hidden_state: torch.FloatTensor = None
mask: torch.LongTensor = None
ids_restore: torch.LongTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class ViTMAEDecoderOutput(ModelOutput):
"""
Class for ViTMAEDecoder's outputs, with potential hidden states and attentions.
Args:
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
Pixel reconstruction logits.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
"""
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class ViTMAEForPreTrainingOutput(ModelOutput):
"""
Class for ViTMAEForPreTraining's outputs, with potential hidden states and attentions.
Args:
loss (`torch.FloatTensor` of shape `(1,)`):
Pixel reconstruction loss.
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
Pixel reconstruction logits.
mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Tensor indicating which patches are masked (1) and which are not (0).
ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Tensor containing the original index of the (shuffled) masked patches.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
mask: torch.LongTensor = None
ids_restore: torch.LongTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
"""
Create 2D sin/cos positional embeddings.
Args:
embed_dim (`int`):
Embedding dimension.
grid_size (`int`):
The grid height and width.
add_cls_token (`bool`, *optional*, defaults to `False`):
Whether or not to add a classification (CLS) token.
Returns:
(`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the
position embeddings (with or without classification token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h)
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if add_cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
"""
Generate 2D sin/cos positional embeddings from grid coordinates.
Args:
embed_dim (`int`):
Embedding dimension.
grid (`numpy.ndarray`):
Grid coordinates of shape (2, 1, grid_size, grid_size).
Returns:
(`numpy.ndarray`): Positional embeddings of shape (grid_size*grid_size, embed_dim)
"""
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be even")
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
emb = np.concatenate([emb_h, emb_w], axis=1)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
Generate 1D sin/cos positional embeddings from positions.
Args:
embed_dim (`int`):
Embedding dimension.
pos (`numpy.ndarray`):
Positions to be encoded, shape (M,).
Returns:
(`numpy.ndarray`): Positional embeddings of shape (M, embed_dim)
"""
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be even")
omega = np.arange(embed_dim // 2, dtype=float)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega
pos = pos.reshape(-1)
out = np.einsum("m,d->md", pos, omega)
emb_sin = np.sin(out)
emb_cos = np.cos(out)
emb = np.concatenate([emb_sin, emb_cos], axis=1)
return emb
class ViTMAEEmbeddings(nn.Module):
"""
Construct the CLS token, position and patch embeddings.
"""
def __init__(self, config):
"""
Initializes ViTMAEEmbeddings module.
Args:
config (`object`):
Configuration object containing model parameters.
"""
super().__init__()
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.patch_embeddings = ViTMAEPatchEmbeddings(config)
self.num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(
torch.zeros(1, self.num_patches + 1, config.hidden_size), requires_grad=False
)
self.config = config
self.initialize_weights()
def initialize_weights(self):
"""
Initialize weights of the module.
"""
pos_embed = get_2d_sincos_pos_embed(
self.position_embeddings.shape[-1], int(self.patch_embeddings.num_patches**0.5), add_cls_token=True
)
self.position_embeddings.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
w = self.patch_embeddings.projection.weight.data
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)
def random_masking(self, sequence, noise=None):
"""
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort of random noise.
Args:
sequence (`torch.LongTensor` of shape `(batch_size, sequence_length, dim)`)
The input sequence tensor
noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*)
Mainly used for testing purposes to control randomness and maintain reproducibility
"""
batch_size, seq_length, dim = sequence.shape
len_keep = int(seq_length * (1 - self.config.mask_ratio))
if noise is None:
noise = torch.rand(batch_size, seq_length, device=sequence.device)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
ids_keep = ids_shuffle[:, :len_keep]
sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))
mask = torch.ones([batch_size, seq_length], device=sequence.device)
mask[:, :len_keep] = 0
mask = torch.gather(mask, dim=1, index=ids_restore)
return sequence_unmasked, mask, ids_restore
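The PyTorch `random_masking` above relies on the fact that `argsort(argsort(noise))` is the inverse permutation of the shuffle, which is what later lets the decoder restore the original patch order. A tiny illustrative check (toy tensor, independent of the model):

```python
import torch

noise = torch.tensor([[0.7, 0.1, 0.9, 0.4]])
ids_shuffle = torch.argsort(noise, dim=1)        # ascending noise: [[1, 3, 0, 2]]
ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation: [[2, 0, 3, 1]]

shuffled = torch.gather(noise, 1, ids_shuffle)
restored = torch.gather(shuffled, 1, ids_restore)
print(torch.equal(restored, noise))  # True: ids_restore undoes the shuffle
```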
def forward(self, pixel_values, noise=None):
batch_size, num_channels, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values)
embeddings = embeddings + self.position_embeddings[:, 1:, :]
embeddings, mask, ids_restore = self.random_masking(embeddings, noise)
cls_token = self.cls_token + self.position_embeddings[:, :1, :]
cls_tokens = cls_token.expand(embeddings.shape[0], -1, -1)
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
return embeddings, mask, ids_restore
class ViTMAEPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values):
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
)
x = self.projection(pixel_values).flatten(2).transpose(1, 2)
return x
class ViTMAESelfAttention(nn.Module):
def __init__(self, config: ViTMAEConfig) -> None:
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
f"heads {config.num_attention_heads}."
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
class ViTMAESelfOutput(nn.Module):
"""
The residual connection is defined in ViTMAELayer instead of here (as is the case with other models), due to the
layernorm applied before each block.
"""
def __init__(self, config: ViTMAEConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class ViTMAEAttention(nn.Module):
def __init__(self, config: ViTMAEConfig) -> None:
super().__init__()
self.attention = ViTMAESelfAttention(config)
self.output = ViTMAESelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads: Set[int]) -> None:
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
)
self.attention.query = prune_linear_layer(self.attention.query, index)
self.attention.key = prune_linear_layer(self.attention.key, index)
self.attention.value = prune_linear_layer(self.attention.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
class ViTMAEIntermediate(nn.Module):
def __init__(self, config: ViTMAEConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class ViTMAEOutput(nn.Module):
def __init__(self, config: ViTMAEConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states + input_tensor
return hidden_states
class ViTMAELayer(nn.Module):
"""对应 timm 实现中的 Block 类。"""
def __init__(self, config: ViTMAEConfig) -> None:
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = ViTMAEAttention(config)
self.intermediate = ViTMAEIntermediate(config)
self.output = ViTMAEOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states),
head_mask,
output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:]
hidden_states = attention_output + hidden_states
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = self.output(layer_output, hidden_states)
outputs = (layer_output,) + outputs
return outputs
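The note in ViTMAESelfOutput above refers to this pre-LayerNorm layout: the residual additions live in ViTMAELayer.forward rather than in the output modules. Schematically (a paraphrase of the code, not extra functionality):
# hidden_states = hidden_states + Attention(layernorm_before(hidden_states))
# hidden_states = hidden_states + Output(Intermediate(layernorm_after(hidden_states)))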
class ViTMAEEncoder(nn.Module):
def __init__(self, config: ViTMAEConfig) -> None:
super().__init__()
self.config = config
self.layer = nn.ModuleList([ViTMAELayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class ViTMAEPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = ViTMAEConfig
base_model_prefix = "vit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
@add_start_docstrings(
"The bare ViTMAE Model transformer outputting raw hidden-states without any specific head on top.",
VIT_MAE_START_DOCSTRING,
)
class ViTMAEModel(ViTMAEPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config
self.embeddings = ViTMAEEmbeddings(config)
self.encoder = ViTMAEEncoder(config)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_init()
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=ViTMAEModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
noise: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, ViTMAEModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output, mask, ids_restore = self.embeddings(pixel_values, noise=noise)
encoder_outputs = self.encoder(
embedding_output,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
if not return_dict:
return (sequence_output, mask, ids_restore) + encoder_outputs[1:]
return ViTMAEModelOutput(
last_hidden_state=sequence_output,
mask=mask,
ids_restore=ids_restore,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
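A minimal usage sketch for the encoder alone; it assumes network access to the facebook/vit-mae-base checkpoint and the COCO test image used elsewhere in these conversion scripts:
import requests
from PIL import Image
from transformers import AutoImageProcessor, ViTMAEModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
model = ViTMAEModel.from_pretrained("facebook/vit-mae-base")
inputs = image_processor(images=image, return_tensors="pt")
outputs = model(**inputs)
# Only the visible patches (plus the CLS token) are encoded: 1 + 196 * (1 - 0.75) = 50 tokens.
print(outputs.last_hidden_state.shape)                 # torch.Size([1, 50, 768])
print(outputs.mask.shape, outputs.ids_restore.shape)   # torch.Size([1, 196]) torch.Size([1, 196])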
class ViTMAEDecoder(nn.Module):
def __init__(self, config, num_patches):
super().__init__()
self.decoder_embed = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=True)
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
self.decoder_pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, config.decoder_hidden_size), requires_grad=False
)
decoder_config = deepcopy(config)
decoder_config.hidden_size = config.decoder_hidden_size
decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
decoder_config.num_attention_heads = config.decoder_num_attention_heads
decoder_config.intermediate_size = config.decoder_intermediate_size
self.decoder_layers = nn.ModuleList(
[ViTMAELayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)]
)
self.decoder_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
self.decoder_pred = nn.Linear(
config.decoder_hidden_size, config.patch_size**2 * config.num_channels, bias=True
)
self.gradient_checkpointing = False
self.config = config
self.initialize_weights(num_patches)
def initialize_weights(self, num_patches):
decoder_pos_embed = get_2d_sincos_pos_embed(
self.decoder_pos_embed.shape[-1], int(num_patches**0.5), add_cls_token=True
)
self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
torch.nn.init.normal_(self.mask_token, std=self.config.initializer_range)
def forward(
self,
hidden_states,
ids_restore,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
x = self.decoder_embed(hidden_states)
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
x = torch.cat([x[:, :1, :], x_], dim=1)
hidden_states = x + self.decoder_pos_embed
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.decoder_layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
None,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
hidden_states = self.decoder_norm(hidden_states)
logits = self.decoder_pred(hidden_states)
logits = logits[:, 1:, :]
if not return_dict:
return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None)
return ViTMAEDecoderOutput(
logits=logits,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
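The cat/gather block at the start of the decoder forward "unshuffles" the sequence: mask tokens are appended after the visible tokens and ids_restore puts every patch back at its original position. A toy illustration with hypothetical tensors (no CLS token, embedding dim 1):
import torch

kept = torch.tensor([[[10.0], [30.0]]])          # encoder output for 2 visible patches (original positions 0 and 2)
mask_tokens = torch.zeros(1, 2, 1)               # shared mask token for the 2 removed patches
x_ = torch.cat([kept, mask_tokens], dim=1)       # still in shuffled order: [visible..., masked...]
ids_restore = torch.tensor([[0, 2, 1, 3]])       # inverse of the shuffle used during masking
restored = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x_.shape[2]))
print(restored[0, :, 0])                         # tensor([10., 0., 30., 0.]) -> original order recovered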
@add_start_docstrings(
"""The ViTMAE Model transformer with the decoder on top for self-supervised pre-training.
<Tip>
Note that we provide a script to pre-train this model on custom data in our [examples
directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
</Tip>
""",
VIT_MAE_START_DOCSTRING,
)
class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config
self.vit = ViTMAEModel(config)
self.decoder = ViTMAEDecoder(config, num_patches=self.vit.embeddings.num_patches)
self.post_init()
def get_input_embeddings(self):
return self.vit.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def patchify(self, pixel_values):
"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values.
Returns:
`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
Patchified pixel values.
"""
patch_size, num_channels = self.config.patch_size, self.config.num_channels
if (pixel_values.shape[2] != pixel_values.shape[3]) or (pixel_values.shape[2] % patch_size != 0):
raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size")
if pixel_values.shape[1] != num_channels:
raise ValueError(
"Make sure the number of channels of the pixel values is equal to the one set in the configuration"
)
batch_size = pixel_values.shape[0]
num_patches_one_direction = pixel_values.shape[2] // patch_size
patchified_pixel_values = pixel_values.reshape(
batch_size, num_channels, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size
)
patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values)
patchified_pixel_values = patchified_pixel_values.reshape(
batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels
)
return patchified_pixel_values
def unpatchify(self, patchified_pixel_values):
"""
Args:
patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
Patchified pixel values.
Returns:
`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
Pixel values.
"""
patch_size, num_channels = self.config.patch_size, self.config.num_channels
num_patches_one_direction = int(patchified_pixel_values.shape[1] ** 0.5)
if num_patches_one_direction**2 != patchified_pixel_values.shape[1]:
raise ValueError("Make sure that the number of patches can be squared")
batch_size = patchified_pixel_values.shape[0]
patchified_pixel_values = patchified_pixel_values.reshape(
batch_size,
num_patches_one_direction,
num_patches_one_direction,
patch_size,
patch_size,
num_channels,
)
patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values)
pixel_values = patchified_pixel_values.reshape(
batch_size,
num_channels,
num_patches_one_direction * patch_size,
num_patches_one_direction * patch_size,
)
return pixel_values
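patchify and unpatchify are exact inverses of each other (pure reshapes and permutations). A quick round-trip sketch with a randomly initialized model and the default 224x224 / patch-16 configuration:
import torch
from transformers import ViTMAEConfig, ViTMAEForPreTraining

model = ViTMAEForPreTraining(ViTMAEConfig())    # random weights; no download needed
pixel_values = torch.randn(2, 3, 224, 224)
patches = model.patchify(pixel_values)          # (2, 196, 768): 14*14 patches of 16*16*3 values each
reconstructed = model.unpatchify(patches)       # (2, 3, 224, 224)
assert torch.allclose(pixel_values, reconstructed)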
def forward_loss(self, pixel_values, pred, mask):
"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values.
pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
Predicted pixel values.
mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Tensor indicating which patches are masked (1) and which are not (0).
Returns:
`torch.FloatTensor`: Pixel reconstruction loss.
"""
target = self.patchify(pixel_values)
if self.config.norm_pix_loss:
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.0e-6) ** 0.5
loss = (pred - target) ** 2
loss = loss.mean(dim=-1)
loss = (loss * mask).sum() / mask.sum()
return loss
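In words, forward_loss is a per-patch mean-squared error counted only on the masked patches; with norm_pix_loss the target patch is standardized first. A compact restatement of the code above:
# target        = (patch - mean(patch)) / sqrt(var(patch) + 1e-6)   # only if config.norm_pix_loss
# per_patch_mse = mean((pred - target) ** 2)                        # mean over the patch_size**2 * channels values
# loss          = sum(per_patch_mse * mask) / sum(mask)             # mask == 1 on removed patches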
@add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=ViTMAEForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
noise: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, ViTMAEForPreTrainingOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.vit(
pixel_values,
noise=noise,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
latent = outputs.last_hidden_state
ids_restore = outputs.ids_restore
mask = outputs.mask
decoder_outputs = self.decoder(latent, ids_restore)
logits = decoder_outputs.logits
loss = self.forward_loss(pixel_values, logits, mask)
if not return_dict:
output = (logits, mask, ids_restore) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return ViTMAEForPreTrainingOutput(
loss=loss,
logits=logits,
mask=mask,
ids_restore=ids_restore,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
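A minimal pre-training forward sketch; it assumes the facebook/vit-mae-base checkpoint is reachable:
import requests
from PIL import Image
from transformers import AutoImageProcessor, ViTMAEForPreTraining

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
inputs = image_processor(images=image, return_tensors="pt")
outputs = model(**inputs)
print(outputs.loss)           # scalar reconstruction loss over the masked patches
print(outputs.logits.shape)   # torch.Size([1, 196, 768]): predicted pixel values per patch
print(outputs.mask.shape)     # torch.Size([1, 196]): 1 = masked, 0 = visible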
.\models\vit_mae\__init__.py
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tf_available,
is_torch_available,
)
_import_structure = {"configuration_vit_mae": ["VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMAEConfig"]}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_vit_mae"] = [
"VIT_MAE_PRETRAINED_MODEL_ARCHIVE_LIST",
"ViTMAEForPreTraining",
"ViTMAELayer",
"ViTMAEModel",
"ViTMAEPreTrainedModel",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_vit_mae"] = [
"TFViTMAEForPreTraining",
"TFViTMAEModel",
"TFViTMAEPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_vit_mae import VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMAEConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_vit_mae import (
VIT_MAE_PRETRAINED_MODEL_ARCHIVE_LIST,
ViTMAEForPreTraining,
ViTMAELayer,
ViTMAEModel,
ViTMAEPreTrainedModel,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_vit_mae import TFViTMAEForPreTraining, TFViTMAEModel, TFViTMAEPreTrainedModel
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\vit_msn\configuration_vit_msn.py
""" ViT MSN model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"sayakpaul/vit-msn-base": "https://huggingface.co/sayakpaul/vit-msn-base/resolve/main/config.json",
}
class ViTMSNConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`ViTMSNModel`]. It is used to instantiate a ViT
MSN model according to the specified arguments, defining the model architecture. Instantiating a configuration with
the defaults will yield a similar configuration to that of the ViT
[facebook/vit_msn_base](https://huggingface.co/facebook/vit_msn_base) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
"""
model_type = "vit_msn"
def __init__(
self,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
layer_norm_eps=1e-06,
image_size=224,
patch_size=16,
num_channels=3,
qkv_bias=True,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.qkv_bias = qkv_bias
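A minimal instantiation sketch for this configuration class (random weights, nothing downloaded):
from transformers import ViTMSNConfig, ViTMSNModel

configuration = ViTMSNConfig()      # ViT-Base-like defaults: hidden_size=768, 12 layers, 12 heads
model = ViTMSNModel(configuration)  # randomly initialized ViT-MSN encoder
configuration = model.config        # the configuration can be read back from the model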
.\models\vit_msn\convert_msn_to_pytorch.py
"""Convert ViT MSN checkpoints from the original repository: https://github.com/facebookresearch/msn"""
import argparse
import json
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import ViTImageProcessor, ViTMSNConfig, ViTMSNModel
from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
torch.set_grad_enabled(False)
def create_rename_keys(config, base_model=False):
rename_keys = []
for i in range(config.num_hidden_layers):
rename_keys.append((f"module.blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight"))
rename_keys.append((f"module.blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias"))
rename_keys.append(
(f"module.blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight")
)
rename_keys.append((f"module.blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias"))
rename_keys.append((f"module.blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight"))
rename_keys.append((f"module.blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias"))
rename_keys.append((f"module.blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight"))
rename_keys.append((f"module.blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias"))
rename_keys.append((f"module.blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight"))
rename_keys.append((f"module.blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias"))
rename_keys.extend(
[
("module.cls_token", "vit.embeddings.cls_token"),
("module.patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"),
("module.patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"),
("module.pos_embed", "vit.embeddings.position_embeddings"),
]
)
if base_model:
rename_keys.extend(
[
("module.norm.weight", "layernorm.weight"),
("module.norm.bias", "layernorm.bias"),
]
)
rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys]
else:
rename_keys.extend(
[
("norm.weight", "vit.layernorm.weight"),
("norm.bias", "vit.layernorm.bias"),
("head.weight", "classifier.weight"),
("head.bias", "classifier.bias"),
]
)
return rename_keys
def read_in_q_k_v(state_dict, config, base_model=False):
for i in range(config.num_hidden_layers):
if base_model:
prefix = ""
else:
prefix = "vit."
in_proj_weight = state_dict.pop(f"module.blocks.{i}.attn.qkv.weight")
in_proj_bias = state_dict.pop(f"module.blocks.{i}.attn.qkv.bias")
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
: config.hidden_size, :
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
config.hidden_size : config.hidden_size * 2, :
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
config.hidden_size : config.hidden_size * 2
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
-config.hidden_size :, :
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
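read_in_q_k_v splits timm's fused qkv projection into the separate query/key/value weights that the HF implementation expects. A slicing sketch assuming the base model's hidden_size of 768:
# module.blocks.{i}.attn.qkv.weight : (2304, 768)
#   rows    0:768   -> encoder.layer.{i}.attention.attention.query.weight
#   rows  768:1536  -> encoder.layer.{i}.attention.attention.key.weight
#   rows 1536:2304  -> encoder.layer.{i}.attention.attention.value.weight
# module.blocks.{i}.attn.qkv.bias   : (2304,), split the same way.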
def remove_classification_head_(state_dict):
ignore_keys = ["head.weight", "head.bias"]
for k in ignore_keys:
state_dict.pop(k, None)
def remove_projection_head(state_dict):
ignore_keys = [
"module.fc.fc1.weight",
"module.fc.fc1.bias",
"module.fc.bn1.weight",
"module.fc.bn1.bias",
"module.fc.bn1.running_mean",
"module.fc.bn1.running_var",
"module.fc.bn1.num_batches_tracked",
"module.fc.fc2.weight",
"module.fc.fc2.bias",
"module.fc.bn2.weight",
"module.fc.bn2.bias",
"module.fc.bn2.running_mean",
"module.fc.bn2.running_var",
"module.fc.bn2.num_batches_tracked",
"module.fc.fc3.weight",
"module.fc.fc3.bias",
]
for k in ignore_keys:
state_dict.pop(k, None)
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
def convert_vit_msn_checkpoint(checkpoint_url, pytorch_dump_folder_path):
config = ViTMSNConfig()
config.num_labels = 1000
repo_id = "datasets/huggingface/label-files"
filename = "imagenet-1k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
if "s16" in checkpoint_url:
config.hidden_size = 384
config.intermediate_size = 1536
config.num_attention_heads = 6
elif "l16" in checkpoint_url:
config.hidden_size = 1024
config.intermediate_size = 4096
config.num_hidden_layers = 24
config.num_attention_heads = 16
config.hidden_dropout_prob = 0.1
elif "b4" in checkpoint_url:
config.patch_size = 4
elif "l7" in checkpoint_url:
config.patch_size = 7
config.hidden_size = 1024
config.intermediate_size = 4096
config.num_hidden_layers = 24
config.num_attention_heads = 16
config.hidden_dropout_prob = 0.1
model = ViTMSNModel(config)
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["target_encoder"]
image_processor = ViTImageProcessor(size=config.image_size)
remove_projection_head(state_dict)
rename_keys = create_rename_keys(config, base_model=True)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
read_in_q_k_v(state_dict, config, base_model=True)
model.load_state_dict(state_dict)
model.eval()
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
image_processor = ViTImageProcessor(
size=config.image_size, image_mean=IMAGENET_DEFAULT_MEAN, image_std=IMAGENET_DEFAULT_STD
)
inputs = image_processor(images=image, return_tensors="pt")
torch.manual_seed(2)
outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
if "s16" in checkpoint_url:
expected_slice = torch.tensor([[-1.0915, -1.4876, -1.1809]])
elif "b16" in checkpoint_url:
expected_slice = torch.tensor([[14.2889, -18.9045, 11.7281]])
elif "l16" in checkpoint_url:
expected_slice = torch.tensor([[41.5028, -22.8681, 45.6475]])
elif "b4" in checkpoint_url:
expected_slice = torch.tensor([[-4.3868, 5.2932, -0.4137]])
else:
expected_slice = torch.tensor([[-0.1792, -0.6465, 2.4263]])
assert torch.allclose(last_hidden_state[:, 0, :3], expected_slice, atol=1e-4)
print(f"Saving model to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"Saving image processor to {pytorch_dump_folder_path}")
image_processor.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpoint_url",
default="https://dl.fbaipublicfiles.com/msn/vits16_800ep.pth.tar",
type=str,
help="URL of the checkpoint you'd like to convert."
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
help="Path to the output PyTorch model directory."
)
args = parser.parse_args()
convert_vit_msn_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
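A sketch of how the script would be run (the checkpoint URL is the script's default; the output directory is hypothetical):
# python convert_msn_to_pytorch.py \
#     --checkpoint_url https://dl.fbaipublicfiles.com/msn/vits16_800ep.pth.tar \
#     --pytorch_dump_folder_path ./vit-msn-small-converted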
.\models\vit_msn\modeling_vit_msn.py
""" PyTorch ViT MSN(masked siamese network)模型。"""
import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_vit_msn import ViTMSNConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "ViTMSNConfig"
_CHECKPOINT_FOR_DOC = "facebook/vit-msn-small"
VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/vit-msn-small",
]
class ViTMSNEmbeddings(nn.Module):
"""
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
"""
def __init__(self, config: ViTMSNConfig, use_mask_token: bool = False) -> None:
super().__init__()
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
self.patch_embeddings = ViTMSNPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.config = config
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
patch_window_height = height // self.config.patch_size
patch_window_width = width // self.config.patch_size
patch_window_height, patch_window_width = patch_window_height + 0.1, patch_window_width + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(
patch_window_height / math.sqrt(num_positions),
patch_window_width / math.sqrt(num_positions),
),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
if bool_masked_pos is not None:
seq_length = embeddings.shape[1]
mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
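interpolate_pos_encoding lets a checkpoint pretrained at 224x224 be applied to larger inputs. A shape sketch assuming patch_size=16 and a 384x384 image:
# pretrained positions: 14 x 14 = 196 patch embeddings (+ 1 CLS position, kept as-is)
# new grid:             384 // 16 = 24 per side -> 576 patches
# the (1, 196, dim) grid is reshaped/permuted to (1, dim, 14, 14), resized bicubically to (1, dim, 24, 24),
# flattened back to (1, 576, dim) and concatenated behind the CLS position embedding.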
class ViTMSNPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a Transformer.
"""
def __init__(self, config):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
f" Expected {self.num_channels} but got {num_channels}."
)
if not interpolate_pos_encoding:
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
return embeddings
class ViTMSNSelfAttention(nn.Module):
def __init__(self, config: ViTMSNConfig) -> None:
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
f"heads {config.num_attention_heads}."
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
class ViTMSNSelfOutput(nn.Module):
"""
The residual connection is defined in ViTMSNLayer instead of here (as is the case with other models), due to the
layernorm applied before each block.
"""
def __init__(self, config: ViTMSNConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class ViTMSNAttention(nn.Module):
def __init__(self, config: ViTMSNConfig) -> None:
super().__init__()
self.attention = ViTMSNSelfAttention(config)
self.output = ViTMSNSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads: Set[int]) -> None:
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
)
self.attention.query = prune_linear_layer(self.attention.query, index)
self.attention.key = prune_linear_layer(self.attention.key, index)
self.attention.value = prune_linear_layer(self.attention.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
class ViTMSNIntermediate(nn.Module):
def __init__(self, config: ViTMSNConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class ViTMSNOutput(nn.Module):
def __init__(self, config: ViTMSNConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states + input_tensor
return hidden_states
class ViTMSNLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config: ViTMSNConfig) -> None:
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = ViTMSNAttention(config)
self.intermediate = ViTMSNIntermediate(config)
self.output = ViTMSNOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states),
head_mask,
output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:]
hidden_states = attention_output + hidden_states
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = self.output(layer_output, hidden_states)
outputs = (layer_output,) + outputs
return outputs
class ViTMSNEncoder(nn.Module):
def __init__(self, config: ViTMSNConfig) -> None:
super().__init__()
self.config = config
self.layer = nn.ModuleList([ViTMSNLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class ViTMSNPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = ViTMSNConfig
base_model_prefix = "vit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
VIT_MSN_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`ViTMSNConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
VIT_MSN_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
for details.
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare ViTMSN Model outputting raw hidden-states without any specific head on top.",
VIT_MSN_START_DOCSTRING,
)
class ViTMSNModel(ViTMSNPreTrainedModel):
def __init__(self, config: ViTMSNConfig, use_mask_token: bool = False):
super().__init__(config)
self.config = config
self.embeddings = ViTMSNEmbeddings(config, use_mask_token=use_mask_token)
self.encoder = ViTMSNEncoder(config)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_init()
def get_input_embeddings(self) -> ViTMSNPatchEmbeddings:
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(VIT_MSN_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)
encoder_outputs = self.encoder(
embedding_output,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
if not return_dict:
head_outputs = (sequence_output,)
return head_outputs + encoder_outputs[1:]
return BaseModelOutput(
last_hidden_state=sequence_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
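A minimal usage sketch for the bare ViT-MSN encoder; it assumes the facebook/vit-msn-small checkpoint is reachable:
import requests
from PIL import Image
from transformers import AutoImageProcessor, ViTMSNModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
image_processor = AutoImageProcessor.from_pretrained("facebook/vit-msn-small")
model = ViTMSNModel.from_pretrained("facebook/vit-msn-small")
inputs = image_processor(images=image, return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 197, 384]) for the small checkpoint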
@add_start_docstrings(
"""
ViTMSN Model with an image classification head on top, e.g. for ImageNet classification.
""",
VIT_MSN_START_DOCSTRING,
)
class ViTMSNForImageClassification(ViTMSNPreTrainedModel):
def __init__(self, config: ViTMSNConfig) -> None:
super().__init__(config)
self.num_labels = config.num_labels
self.vit = ViTMSNModel(config)
self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
self.post_init()
@add_start_docstrings_to_model_forward(VIT_MSN_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,