Transformers Source Code Analysis (117)
.\models\vision_text_dual_encoder\__init__.py
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tf_available,
is_torch_available,
)
_import_structure = {
"configuration_vision_text_dual_encoder": ["VisionTextDualEncoderConfig"],
"processing_vision_text_dual_encoder": ["VisionTextDualEncoderProcessor"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_vision_text_dual_encoder"] = ["VisionTextDualEncoderModel"]
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_vision_text_dual_encoder"] = ["FlaxVisionTextDualEncoderModel"]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_vision_text_dual_encoder"] = ["TFVisionTextDualEncoderModel"]
if TYPE_CHECKING:
from .configuration_vision_text_dual_encoder import VisionTextDualEncoderConfig
from .processing_vision_text_dual_encoder import VisionTextDualEncoderProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_vision_text_dual_encoder import VisionTextDualEncoderModel
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_vision_text_dual_encoder import FlaxVisionTextDualEncoderModel
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_vision_text_dual_encoder import TFVisionTextDualEncoderModel
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
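As a quick illustration of the lazy-import pattern above, the framework-specific classes are only resolved on first attribute access. A minimal sketch of the user-facing behavior (assuming PyTorch is installed; if it were not, the torch-backed name would simply not be registered in `_import_structure` and attribute access would fail):

```
# The torch-backed class is resolved lazily by _LazyModule on first access.
from transformers.models.vision_text_dual_encoder import VisionTextDualEncoderModel

print(VisionTextDualEncoderModel.__name__)  # VisionTextDualEncoderModel
```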
.\models\visual_bert\configuration_visual_bert.py
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"uclanlp/visualbert-vqa": "https://huggingface.co/uclanlp/visualbert-vqa/resolve/main/config.json",
"uclanlp/visualbert-vqa-pre": "https://huggingface.co/uclanlp/visualbert-vqa-pre/resolve/main/config.json",
"uclanlp/visualbert-vqa-coco-pre": (
"https://huggingface.co/uclanlp/visualbert-vqa-coco-pre/resolve/main/config.json"
),
"uclanlp/visualbert-vcr": "https://huggingface.co/uclanlp/visualbert-vcr/resolve/main/config.json",
"uclanlp/visualbert-vcr-pre": "https://huggingface.co/uclanlp/visualbert-vcr-pre/resolve/main/config.json",
"uclanlp/visualbert-vcr-coco-pre": (
"https://huggingface.co/uclanlp/visualbert-vcr-coco-pre/resolve/main/config.json"
),
"uclanlp/visualbert-nlvr2": "https://huggingface.co/uclanlp/visualbert-nlvr2/resolve/main/config.json",
"uclanlp/visualbert-nlvr2-pre": "https://huggingface.co/uclanlp/visualbert-nlvr2-pre/resolve/main/config.json",
"uclanlp/visualbert-nlvr2-coco-pre": (
"https://huggingface.co/uclanlp/visualbert-nlvr2-coco-pre/resolve/main/config.json"
),
}
class VisualBertConfig(PretrainedConfig):
r"""
    This is the configuration class to store the configuration of a [`VisualBertModel`]. It is used to instantiate a
    VisualBERT model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the VisualBERT
    [uclanlp/visualbert-vqa-coco-pre](https://huggingface.co/uclanlp/visualbert-vqa-coco-pre) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Example:

    ```
    >>> from transformers import VisualBertConfig, VisualBertModel

    >>> # Initializing a VisualBERT visualbert-vqa-coco-pre style configuration
    >>> configuration = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = VisualBertModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
.\models\visual_bert\convert_visual_bert_original_pytorch_checkpoint_to_pytorch.py
"""Convert VisualBert checkpoint."""
import argparse
from collections import OrderedDict
from pathlib import Path
import torch
from transformers import (
VisualBertConfig,
VisualBertForMultipleChoice,
VisualBertForPreTraining,
VisualBertForQuestionAnswering,
VisualBertForVisualReasoning,
)
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
rename_keys_prefix = [
("bert.bert", "visual_bert"),
("bert.cls", "cls"),
("bert.classifier", "cls"),
("token_type_embeddings_visual", "visual_token_type_embeddings"),
("position_embeddings_visual", "visual_position_embeddings"),
("projection", "visual_projection"),
]
ACCEPTABLE_CHECKPOINTS = [
"nlvr2_coco_pre_trained.th",
"nlvr2_fine_tuned.th",
"nlvr2_pre_trained.th",
"vcr_coco_pre_train.th",
"vcr_fine_tune.th",
"vcr_pre_train.th",
"vqa_coco_pre_trained.th",
"vqa_fine_tuned.th",
"vqa_pre_trained.th",
]
def load_state_dict(checkpoint_path):
sd = torch.load(checkpoint_path, map_location="cpu")
return sd
def get_new_dict(d, config, rename_keys_prefix=rename_keys_prefix):
new_d = OrderedDict()
new_d["visual_bert.embeddings.position_ids"] = torch.arange(config.max_position_embeddings).expand((1, -1))
for key in d:
if "detector" in key:
continue
new_key = key
for name_pair in rename_keys_prefix:
new_key = new_key.replace(name_pair[0], name_pair[1])
new_d[new_key] = d[key]
if key == "bert.cls.predictions.decoder.weight":
new_d["cls.predictions.decoder.bias"] = new_d["cls.predictions.bias"]
return new_d
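To make the renaming rule concrete, here is a small sketch of how the prefix pairs rewrite one hypothetical original key (the key name is illustrative, not taken from a real checkpoint; it assumes `rename_keys_prefix` from above is in scope):

```
# Sketch: applying rename_keys_prefix to a hypothetical original key.
key = "bert.bert.encoder.layer.0.attention.self.query.weight"
for old, new in rename_keys_prefix:
    key = key.replace(old, new)
print(key)  # visual_bert.encoder.layer.0.attention.self.query.weight
```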
@torch.no_grad()
def convert_visual_bert_checkpoint(checkpoint_path, pytorch_dump_folder_path):
"""
Copy/paste/tweak model's weights to our VisualBERT structure.
"""
assert (
checkpoint_path.split("/")[-1] in ACCEPTABLE_CHECKPOINTS
), f"The checkpoint provided must be in {ACCEPTABLE_CHECKPOINTS}."
if "pre" in checkpoint_path:
model_type = "pretraining"
if "vcr" in checkpoint_path:
config_params = {"visual_embedding_dim": 512}
elif "vqa_advanced" in checkpoint_path:
config_params = {"visual_embedding_dim": 2048}
elif "vqa" in checkpoint_path:
config_params = {"visual_embedding_dim": 2048}
elif "nlvr" in checkpoint_path:
config_params = {"visual_embedding_dim": 1024}
else:
raise NotImplementedError(f"No implementation found for `{checkpoint_path}`.")
else:
if "vcr" in checkpoint_path:
config_params = {"visual_embedding_dim": 512}
model_type = "multichoice"
elif "vqa_advanced" in checkpoint_path:
config_params = {"visual_embedding_dim": 2048}
model_type = "vqa_advanced"
elif "vqa" in checkpoint_path:
config_params = {"visual_embedding_dim": 2048, "num_labels": 3129}
model_type = "vqa"
elif "nlvr" in checkpoint_path:
config_params = {
"visual_embedding_dim": 1024,
"num_labels": 2,
}
model_type = "nlvr"
config = VisualBertConfig(**config_params)
state_dict = load_state_dict(checkpoint_path)
new_state_dict = get_new_dict(state_dict, config)
if model_type == "pretraining":
model = VisualBertForPreTraining(config)
elif model_type == "vqa":
model = VisualBertForQuestionAnswering(config)
elif model_type == "nlvr":
model = VisualBertForVisualReasoning(config)
elif model_type == "multichoice":
model = VisualBertForMultipleChoice(config)
model.load_state_dict(new_state_dict)
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("orig_checkpoint_path", type=str, help="A path to .th on local filesystem.")
parser.add_argument("pytorch_dump_folder_path", type=str, help="Path to the output PyTorch model.")
args = parser.parse_args()
convert_visual_bert_checkpoint(args.orig_checkpoint_path, args.pytorch_dump_folder_path)
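For reference, a hypothetical invocation of this script (the checkpoint filename must be one of `ACCEPTABLE_CHECKPOINTS`; the paths below are placeholders):

```
# python convert_visual_bert_original_pytorch_checkpoint_to_pytorch.py \
#     vqa_pre_trained.th ./visualbert-vqa-pre
convert_visual_bert_checkpoint("vqa_pre_trained.th", "./visualbert-vqa-pre")
```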
.\models\visual_bert\modeling_visual_bert.py
""" PyTorch VisualBERT model."""
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, KLDivLoss, LogSoftmax
from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
MultipleChoiceModelOutput,
SequenceClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_visual_bert import VisualBertConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "VisualBertConfig"
_CHECKPOINT_FOR_DOC = "uclanlp/visualbert-vqa-coco-pre"
VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"uclanlp/visualbert-vqa",
"uclanlp/visualbert-vqa-pre",
"uclanlp/visualbert-vqa-coco-pre",
"uclanlp/visualbert-vcr",
"uclanlp/visualbert-vcr-pre",
"uclanlp/visualbert-vcr-coco-pre",
"uclanlp/visualbert-nlvr2",
"uclanlp/visualbert-nlvr2-pre",
"uclanlp/visualbert-nlvr2-coco-pre",
]
class VisualBertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings and visual embeddings."""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.visual_token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
self.visual_position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
if config.special_visual_initialize:
self.visual_token_type_embeddings.weight.data = nn.Parameter(
self.token_type_embeddings.weight.data.clone(), requires_grad=True
)
self.visual_position_embeddings.weight.data = nn.Parameter(
self.position_embeddings.weight.data.clone(), requires_grad=True
)
self.visual_projection = nn.Linear(config.visual_embedding_dim, config.hidden_size)
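As a rough sketch (with hypothetical region features), this is what `visual_projection` does to incoming visual embeddings before they are combined with the text embeddings:

```
import torch
from torch import nn

# Sketch: project hypothetical 2048-dim region features into a 768-dim
# hidden space, mirroring self.visual_projection above.
visual_embeds = torch.randn(1, 36, 2048)  # (batch, num_regions, visual_embedding_dim)
projection = nn.Linear(2048, 768)
print(projection(visual_embeds).shape)  # torch.Size([1, 36, 768])
```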
class VisualBertSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
output_attentions=False,
):
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
attention_scores = attention_scores + attention_mask
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
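A small sketch of what `transpose_for_scores` does to the tensor layout (default config assumed, and it assumes the classes defined in this file are importable):

```
import torch

# Sketch: (batch, seq, hidden) -> (batch, heads, seq, head_size);
# with hidden_size=768 and 12 heads, each head gets 64 dimensions.
config = VisualBertConfig()
attn = VisualBertSelfAttention(config)
x = torch.randn(2, 16, config.hidden_size)
print(attn.transpose_for_scores(x).shape)  # torch.Size([2, 12, 16, 64])
```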
class VisualBertSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class VisualBertAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.self = VisualBertSelfAttention(config)
self.output = VisualBertSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
output_attentions=False,
):
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
class VisualBertIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class VisualBertOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class VisualBertLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = VisualBertAttention(config)
self.intermediate = VisualBertIntermediate(config)
self.output = VisualBertOutput(config)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
output_attentions=False,
):
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:]
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
return outputs
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class VisualBertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([VisualBertLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
all_hidden_states,
all_self_attentions,
]
if v is not None
)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
)
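A minimal sketch of running the encoder standalone on random hidden states (default config; assumes the classes in this file are importable):

```
import torch

# Sketch: the encoder maps (batch, seq, hidden) -> (batch, seq, hidden).
config = VisualBertConfig()
encoder = VisualBertEncoder(config)
hidden = torch.randn(1, 4, config.hidden_size)
output = encoder(hidden, attention_mask=None, head_mask=None)
print(output.last_hidden_state.shape)  # torch.Size([1, 4, 768])
```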
class VisualBertPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class VisualBertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class VisualBertLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = VisualBertPredictionHeadTransform(config)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
class VisualBertPreTrainingHeads(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = VisualBertLMPredictionHead(config)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output):
prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class VisualBertPreTrainedModel(PreTrainedModel):
"""
VisualBert 的预训练模型基类
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = VisualBertConfig
base_model_prefix = "visual_bert"
supports_gradient_checkpointing = True
    def _init_weights(self, module):
        """Initialize the weights"""
        # Linear and Embedding weights are drawn from a normal distribution
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version, which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        # LayerNorm: zero the bias and fill the weight with 1.0
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        # Linear layers with a bias additionally get a zero-initialized bias
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
"""
Output type of `VisualBertForPreTraining`.
Args:
loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
Total loss as the sum of the masked language modeling loss and the sentence-image prediction
(classification) loss.
prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
Prediction scores of the sentence-image prediction (classification) head (scores of True/False continuation
before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
    loss: Optional[torch.Tensor] = None
    prediction_logits: torch.FloatTensor = None
    seq_relationship_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@add_start_docstrings(
    "The bare VisualBert Model transformer outputting raw hidden-states without any specific head on top.",
    VISUAL_BERT_START_DOCSTRING,
)
class VisualBertModel(VisualBertPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) following the architecture described in [Attention is
all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
"""
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
# Initialize VisualBert model components
self.embeddings = VisualBertEmbeddings(config) # VisualBert embeddings module
self.encoder = VisualBertEncoder(config) # VisualBert encoder module
# Optionally add a pooling layer based on add_pooling_layer flag
self.pooler = VisualBertPooler(config) if add_pooling_layer else None
        # Whether to bypass the transformer for visual inputs; if so, an additional layer is used
self.bypass_transformer = config.bypass_transformer
if self.bypass_transformer:
self.additional_layer = VisualBertLayer(config) # Additional layer for bypassing transformer
# Initialize weights and perform final processing
self.post_init()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
# Iterate over layers and prune specified attention heads
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
visual_embeds: Optional[torch.FloatTensor] = None,
visual_attention_mask: Optional[torch.LongTensor] = None,
visual_token_type_ids: Optional[torch.LongTensor] = None,
image_text_alignment: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
"""
Forward pass of the VisualBert model.
Args:
input_ids (torch.LongTensor, optional): Input token IDs. Defaults to None.
attention_mask (torch.LongTensor, optional): Mask to avoid performing attention on padding token indices.
Defaults to None.
token_type_ids (torch.LongTensor, optional): Segment token indices to differentiate image and text tokens.
Defaults to None.
            position_ids (torch.LongTensor, optional): Position indices of each input token. Defaults to None.
head_mask (torch.LongTensor, optional): Mask to nullify selected heads of the self-attention modules.
Defaults to None.
inputs_embeds (torch.FloatTensor, optional): Embedded input tokens. Defaults to None.
visual_embeds (torch.FloatTensor, optional): Embedded visual features. Defaults to None.
visual_attention_mask (torch.LongTensor, optional): Mask for visual features to avoid attending on padding.
Defaults to None.
visual_token_type_ids (torch.LongTensor, optional): Segment token indices for visual inputs. Defaults to None.
image_text_alignment (torch.LongTensor, optional): Alignment between image and text tokens. Defaults to None.
output_attentions (bool, optional): Whether to output attentions weights. Defaults to None.
output_hidden_states (bool, optional): Whether to output hidden states. Defaults to None.
return_dict (bool, optional): Whether to return a dictionary. Defaults to None.
Returns:
BaseModelOutputWithPooling: Output of the VisualBert model with optional pooling.
"""
        # The full forward implementation is omitted in this excerpt of the walkthrough;
        # see the upstream modeling_visual_bert.py for the complete body.
        pass
@add_start_docstrings(
"""
VisualBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
`sentence-image prediction (classification)` head.
""",
VISUAL_BERT_START_DOCSTRING,
)
class VisualBertForPreTraining(VisualBertPreTrainedModel):
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
    # The constructor takes a configuration object
    def __init__(self, config):
        super().__init__(config)
        # The base VisualBertModel
        self.visual_bert = VisualBertModel(config)
        # The MLM + sentence-image prediction heads
        self.cls = VisualBertPreTrainingHeads(config)
        # Initialize weights and apply final processing
        self.post_init()

    # Return the output embedding decoder of the prediction head
    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    # Replace the output embedding decoder of the prediction head
    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings
    # Forward pass, taking the usual text inputs (token IDs, attention mask, type IDs, ...) plus visual inputs and optional labels
@add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=VisualBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
visual_embeds: Optional[torch.FloatTensor] = None,
visual_attention_mask: Optional[torch.LongTensor] = None,
visual_token_type_ids: Optional[torch.LongTensor] = None,
image_text_alignment: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
        sentence_image_labels: Optional[torch.LongTensor] = None,
    ):
        # The full forward implementation is omitted in this excerpt.
        ...

@add_start_docstrings(
"""
VisualBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
a softmax) e.g. for VCR tasks.
""",
VISUAL_BERT_START_DOCSTRING,
)
class VisualBertForMultipleChoice(VisualBertPreTrainedModel):
    # The constructor takes a configuration object
    def __init__(self, config):
        super().__init__(config)
        # The base VisualBertModel
        self.visual_bert = VisualBertModel(config)
        # Dropout with the configured hidden dropout probability
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Linear head mapping the pooled output to a single score per choice
        self.cls = nn.Linear(config.hidden_size, 1)
        # Initialize weights and apply final processing
        self.post_init()
@add_start_docstrings_to_model_forward(
VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
@replace_return_docstrings(output_type=MultipleChoiceModelOutput, config_class=_CONFIG_FOR_DOC)
    # Forward pass with the usual text and visual inputs, shaped per choice
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
visual_embeds: Optional[torch.FloatTensor] = None,
visual_attention_mask: Optional[torch.LongTensor] = None,
visual_token_type_ids: Optional[torch.LongTensor] = None,
image_text_alignment: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        # The full forward implementation is omitted in this excerpt.
        ...

@add_start_docstrings(
"""
VisualBert Model with a classification/regression head on top (a dropout and a linear layer on top of the pooled
output) for VQA.
""",
VISUAL_BERT_START_DOCSTRING,
)
class VisualBertForQuestionAnswering(VisualBertPreTrainedModel):
    # The constructor takes a configuration object
    def __init__(self, config):
        super().__init__(config)
        # Number of labels, taken from the configuration
        self.num_labels = config.num_labels
        # The base VisualBertModel
        self.visual_bert = VisualBertModel(config)
        # Dropout with the configured hidden dropout probability
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Linear head mapping the pooled output to num_labels scores
        self.cls = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()
@add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    # Forward pass; all inputs are optional tensors
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,       # token IDs of the input sequence
        attention_mask: Optional[torch.LongTensor] = None,  # mask marking which tokens may be attended to
        token_type_ids: Optional[torch.LongTensor] = None,  # segment IDs distinguishing sentences/segments
        position_ids: Optional[torch.LongTensor] = None,    # position index of each token in the sequence
        head_mask: Optional[torch.LongTensor] = None,       # mask nullifying selected attention heads
        inputs_embeds: Optional[torch.FloatTensor] = None,  # precomputed embeddings replacing input_ids
        visual_embeds: Optional[torch.FloatTensor] = None,  # embedded visual features, e.g. image regions
        visual_attention_mask: Optional[torch.LongTensor] = None,  # attention mask for the visual inputs
        visual_token_type_ids: Optional[torch.LongTensor] = None,  # segment IDs for the visual inputs
        image_text_alignment: Optional[torch.LongTensor] = None,   # image-text alignment information
        output_attentions: Optional[bool] = None,           # whether to return attention weights
        output_hidden_states: Optional[bool] = None,        # whether to return hidden states
        return_dict: Optional[bool] = None,                 # whether to return a ModelOutput dict
        labels: Optional[torch.LongTensor] = None,          # supervision labels for training
    ):
        # The full forward implementation is omitted in this excerpt.
        ...

@add_start_docstrings(
"""
VisualBert Model with a sequence classification head on top (a dropout and a linear layer on top of the pooled
output) for Visual Reasoning e.g. for NLVR task.
""",
VISUAL_BERT_START_DOCSTRING,
)
class VisualBertForVisualReasoning(VisualBertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
# Initialize the VisualBertModel with the provided configuration
self.visual_bert = VisualBertModel(config)
# Dropout layer with dropout probability as specified in the configuration
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# Linear layer for classification, mapping hidden_size to num_labels
self.cls = nn.Linear(config.hidden_size, config.num_labels) # 2
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
visual_embeds: Optional[torch.FloatTensor] = None,
visual_attention_mask: Optional[torch.LongTensor] = None,
visual_token_type_ids: Optional[torch.LongTensor] = None,
image_text_alignment: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
    ):
        # The full forward implementation is omitted in this excerpt.
        ...

class VisualBertRegionToPhraseAttention(nn.Module):
def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
# Number of attention heads is set to 1 for this module
self.num_attention_heads = 1 # config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
# Linear transformations for query, key, and value vectors
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
# Dropout layer for attention probabilities
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
# Reshape and permute dimensions for multi-head attention
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
    def forward(self, query, key, attention_mask):
        # Cast the attention mask to the query dtype
        attention_mask = attention_mask.to(query.dtype)
        # Expand the mask so it broadcasts against the attention score tensor
        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        # Invert the mask and scale it to a large negative value so masked
        # positions are suppressed by the downstream softmax
        attention_mask = (1.0 - attention_mask) * torch.finfo(query.dtype).min
        # Project the query and key tensors
        mixed_query_layer = self.query(query)
        mixed_key_layer = self.key(key)
        # Reshape for (single-head) attention score computation
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        # Raw attention scores: query x key^T
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        # Scale the scores to stabilize gradients
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the additive mask to hide invalid positions
        attention_scores = attention_scores + attention_mask
        # Drop the singleton head dimension
        attention_scores = attention_scores.squeeze(1)
        # The attention scores themselves are the output of this module
        return attention_scores
# The decorator below documents this class as a VisualBert model for region-to-phrase
# alignment, e.g. for the Flickr30 Entities task.
@add_start_docstrings(
    """
    VisualBert Model with a Masked Language Modeling head and an attention layer on top for Region-to-Phrase Alignment
    e.g. for Flickr30 Entities task.
    """,
    VISUAL_BERT_START_DOCSTRING,
)
class VisualBertForRegionToPhraseAlignment(VisualBertPreTrainedModel):
    # Keys whose weights are tied to the input embeddings
    _tied_weights_keys = ["cls.predictions.decoder.bias"]
    # The constructor takes a configuration object
    def __init__(self, config):
        super().__init__(config)
        # The base VisualBertModel
        self.visual_bert = VisualBertModel(config)
        # Dropout with the configured hidden dropout probability
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Prediction heads reused from pretraining
        self.cls = VisualBertPreTrainingHeads(config)
        # Attention module aligning image regions to phrases
        self.attention = VisualBertRegionToPhraseAttention(config)
        # Initialize weights and apply final processing
        self.post_init()
    # Docstring decorators describing the inputs and the return type
    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
visual_embeds: Optional[torch.FloatTensor] = None,
visual_attention_mask: Optional[torch.LongTensor] = None,
visual_token_type_ids: Optional[torch.LongTensor] = None,
image_text_alignment: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
region_to_phrase_position: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        # The full forward implementation is omitted in this excerpt.
        ...

.\models\visual_bert\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {"configuration_visual_bert": ["VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "VisualBertConfig"]}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_visual_bert"] = [
"VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"VisualBertForMultipleChoice",
"VisualBertForPreTraining",
"VisualBertForQuestionAnswering",
"VisualBertForRegionToPhraseAlignment",
"VisualBertForVisualReasoning",
"VisualBertLayer",
"VisualBertModel",
"VisualBertPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_visual_bert import (
VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
VisualBertForMultipleChoice,
VisualBertForPreTraining,
VisualBertForQuestionAnswering,
VisualBertForRegionToPhraseAlignment,
VisualBertForVisualReasoning,
VisualBertLayer,
VisualBertModel,
VisualBertPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\vit\configuration_vit.py
"""
ViT model configuration
"""
from collections import OrderedDict
from typing import Mapping
from packaging import version
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
logger = logging.get_logger(__name__)
VIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google/vit-base-patch16-224": "https://huggingface.co/vit-base-patch16-224/resolve/main/config.json",
}
class ViTConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`ViTModel`]. It is used to instantiate an ViT
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the ViT
[google/vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
"""
Args:
hidden_size (`int`, *optional*, defaults to 768):
编码器层和池化层的维度。
num_hidden_layers (`int`, *optional*, defaults to 12):
Transformer 编码器中的隐藏层数量。
num_attention_heads (`int`, *optional*, defaults to 12):
Transformer 编码器中每个注意力层的注意力头数量。
intermediate_size (`int`, *optional*, defaults to 3072):
Transformer 编码器中“中间”(即前馈)层的维度。
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
编码器和池化器中的非线性激活函数(函数或字符串)。如果是字符串,支持 "gelu"、"relu"、"selu" 和 "gelu_new"。
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
嵌入、编码器和池化层中所有全连接层的 dropout 概率。
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
注意力概率的 dropout 比例。
initializer_range (`float`, *optional*, defaults to 0.02):
用于初始化所有权重矩阵的截断正态初始化器的标准差。
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
层规范化层使用的 epsilon。
image_size (`int`, *optional*, defaults to 224):
每个图像的大小(分辨率)。
patch_size (`int`, *optional*, defaults to 16):
每个图块的大小(分辨率)。
num_channels (`int`, *optional*, defaults to 3):
输入通道的数量。
qkv_bias (`bool`, *optional*, defaults to `True`):
是否为查询、键和值添加偏置。
encoder_stride (`int`, *optional*, defaults to 16):
用于遮蔽图像建模中解码器头部中空间分辨率的增加因子。
Example:
```
>>> from transformers import ViTConfig, ViTModel
>>>
>>> configuration = ViTConfig()
>>>
>>> model = ViTModel(configuration)
>>>
>>> configuration = model.config
```"""
    # Constructor; the defaults below follow the documented defaults above
    def __init__(
        self,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        image_size=224,
        patch_size=16,
        num_channels=3,
        qkv_bias=True,
        encoder_stride=16,
        **kwargs,
    ):
        # Forward any extra keyword arguments to the parent class
        super().__init__(**kwargs)
        # Store the architecture hyperparameters on the config instance
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.qkv_bias = qkv_bias
        self.encoder_stride = encoder_stride
class ViTOnnxConfig(OnnxConfig):
    # Minimum torch version required for ONNX export of this model
    torch_onnx_minimum_version = version.parse("1.11")

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # Describes the model input for ONNX export: a 4D pixel_values tensor
        # with named dynamic axes
        return OrderedDict(
            [
                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
            ]
        )

    @property
    def atol_for_validation(self) -> float:
        # Absolute tolerance used when validating the exported model
        return 1e-4
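As a sketch, the export spec declared above can be inspected directly (assuming the classes in this file are importable, and that the base `OnnxConfig` constructor accepts a config as its first argument):

```
# Sketch: the ONNX config exposes the dynamic axes of the single input.
onnx_config = ViTOnnxConfig(ViTConfig())
print(onnx_config.inputs)               # OrderedDict([('pixel_values', {0: 'batch', ...})])
print(onnx_config.atol_for_validation)  # 0.0001
```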
.\models\vit\convert_dino_to_pytorch.py
"""Convert ViT checkpoints trained with the DINO method."""
import argparse
import json
from pathlib import Path
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor, ViTModel
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def create_rename_keys(config, base_model=False):
rename_keys = []
for i in range(config.num_hidden_layers):
rename_keys.append((f"blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight"))
rename_keys.append((f"blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias"))
rename_keys.append((f"blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight"))
rename_keys.append((f"blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias"))
rename_keys.append((f"blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight"))
rename_keys.append((f"blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias"))
rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight"))
rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias"))
rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight"))
rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias"))
rename_keys.extend(
[
("cls_token", "vit.embeddings.cls_token"),
("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"),
("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"),
("pos_embed", "vit.embeddings.position_embeddings"),
]
)
if base_model:
rename_keys.extend(
[
("norm.weight", "layernorm.weight"),
("norm.bias", "layernorm.bias"),
]
)
rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys]
else:
rename_keys.extend(
[
("norm.weight", "vit.layernorm.weight"),
("norm.bias", "vit.layernorm.bias"),
("head.weight", "classifier.weight"),
("head.bias", "classifier.bias"),
]
)
return rename_keys
def read_in_q_k_v(state_dict, config, base_model=False):
for i in range(config.num_hidden_layers):
if base_model:
prefix = ""
else:
prefix = "vit."
in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
: config.hidden_size, :
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
config.hidden_size : config.hidden_size * 2, :
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
config.hidden_size : config.hidden_size * 2
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
-config.hidden_size :, :
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
def remove_classification_head_(state_dict):
ignore_keys = ["head.weight", "head.bias"]
for k in ignore_keys:
state_dict.pop(k, None)
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_vit_checkpoint(model_name, pytorch_dump_folder_path, base_model=True):
"""
将模型的权重复制/粘贴/调整到我们的ViT结构中。
"""
config = ViTConfig()
if model_name[-1] == "8":
config.patch_size = 8
if not base_model:
config.num_labels = 1000
repo_id = "huggingface/label-files"
filename = "imagenet-1k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
if model_name in ["dino_vits8", "dino_vits16"]:
config.hidden_size = 384
config.intermediate_size = 1536
config.num_hidden_layers = 12
config.num_attention_heads = 6
original_model = torch.hub.load("facebookresearch/dino:main", model_name)
original_model.eval()
state_dict = original_model.state_dict()
if base_model:
remove_classification_head_(state_dict)
rename_keys = create_rename_keys(config, base_model=base_model)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
read_in_q_k_v(state_dict, config, base_model)
if base_model:
model = ViTModel(config, add_pooling_layer=False).eval()
else:
model = ViTForImageClassification(config).eval()
model.load_state_dict(state_dict)
image_processor = ViTImageProcessor()
encoding = image_processor(images=prepare_img(), return_tensors="pt")
pixel_values = encoding["pixel_values"]
outputs = model(pixel_values)
if base_model:
final_hidden_state_cls_token = original_model(pixel_values)
assert torch.allclose(final_hidden_state_cls_token, outputs.last_hidden_state[:, 0, :], atol=1e-1)
else:
logits = original_model(pixel_values)
assert logits.shape == outputs.logits.shape
assert torch.allclose(logits, outputs.logits, atol=1e-3)
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"Saving image processor to {pytorch_dump_folder_path}")
image_processor.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
default="dino_vitb16",
type=str,
help="Name of the model trained with DINO you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--base_model",
action="store_true",
help="Whether to only convert the base model (no projection head weights).",
)
parser.set_defaults(base_model=True)
args = parser.parse_args()
convert_vit_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.base_model)
.\models\vit\convert_vit_timm_to_pytorch.py
# Imports inferred from the usage below
import argparse
from pathlib import Path

import requests
import timm
import torch
from PIL import Image
from timm.data import ImageNetInfo, infer_imagenet_subset

from transformers import DeiTImageProcessor, ViTConfig, ViTForImageClassification, ViTImageProcessor, ViTModel


def create_rename_keys(config, base_model=False):
rename_keys = []
for i in range(config.num_hidden_layers):
rename_keys.append((f"blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight"))
rename_keys.append((f"blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias"))
rename_keys.append((f"blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight"))
rename_keys.append((f"blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias"))
rename_keys.append((f"blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight"))
rename_keys.append((f"blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias"))
rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight"))
rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias"))
rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight"))
rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias"))
rename_keys.append(("cls_token", "vit.embeddings.cls_token"))
rename_keys.append(("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"))
rename_keys.append(("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"))
rename_keys.append(("pos_embed", "vit.embeddings.position_embeddings"))
if base_model:
rename_keys.extend(
[
("norm.weight", "layernorm.weight"),
("norm.bias", "layernorm.bias"),
]
)
rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys]
else:
rename_keys.extend(
[
("norm.weight", "vit.layernorm.weight"),
("norm.bias", "vit.layernorm.bias"),
("head.weight", "classifier.weight"),
("head.bias", "classifier.bias"),
]
)
return rename_keys
def read_in_q_k_v(state_dict, config, base_model=False):
for i in range(config.num_hidden_layers):
if base_model:
prefix = ""
else:
prefix = "vit."
in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
: config.hidden_size, :
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
config.hidden_size : config.hidden_size * 2, :
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
config.hidden_size : config.hidden_size * 2
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
-config.hidden_size :, :
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
def remove_classification_head_(state_dict):
ignore_keys = ["head.weight", "head.bias"]
for k in ignore_keys:
state_dict.pop(k, None)
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
"""
Copy/paste/tweak model's weights to our ViT structure.
"""
config = ViTConfig()
base_model = False
timm_model = timm.create_model(vit_name, pretrained=True)
timm_model.eval()
if not isinstance(getattr(timm_model, "fc_norm", None), torch.nn.Identity):
raise ValueError(f"{vit_name} is not supported in transformers because of the presence of fc_norm.")
if getattr(timm_model, "global_pool", None) == "avg":
raise ValueError(f"{vit_name} is not supported in transformers because of use of global average pooling.")
if "clip" in vit_name and not isinstance(getattr(timm_model, "norm_pre", None), torch.nn.Identity):
raise ValueError(
f"{vit_name} is not supported in transformers because it's a CLIP style ViT with norm_pre layer."
)
if "siglip" in vit_name and getattr(timm_model, "global_pool", None) == "map":
raise ValueError(
f"{vit_name} is not supported in transformers because it's a SigLIP style ViT with attn_pool."
)
if not isinstance(getattr(timm_model.blocks[0], "ls1", None), torch.nn.Identity) or not isinstance(
getattr(timm_model.blocks[0], "ls2", None), torch.nn.Identity
):
raise ValueError(f"{vit_name} is not supported in transformers because it uses a layer scale in its blocks.")
if not isinstance(timm_model.patch_embed, timm.layers.PatchEmbed):
raise ValueError(f"{vit_name} is not supported in transformers because it is a hybrid ResNet-ViT.")
config.patch_size = timm_model.patch_embed.patch_size[0]
config.image_size = timm_model.patch_embed.img_size[0]
config.hidden_size = timm_model.embed_dim
config.intermediate_size = timm_model.blocks[0].mlp.fc1.out_features
config.num_hidden_layers = len(timm_model.blocks)
config.num_attention_heads = timm_model.blocks[0].attn.num_heads
if timm_model.num_classes != 0:
config.num_labels = timm_model.num_classes
imagenet_subset = infer_imagenet_subset(timm_model)
dataset_info = ImageNetInfo(imagenet_subset)
config.id2label = {i: dataset_info.index_to_label_name(i) for i in range(dataset_info.num_classes())}
config.label2id = {v: k for k, v in config.id2label.items()}
else:
print(f"{vit_name} is going to be converted as a feature extractor only.")
base_model = True
state_dict = timm_model.state_dict()
if base_model:
remove_classification_head_(state_dict)
rename_keys = create_rename_keys(config, base_model)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
read_in_q_k_v(state_dict, config, base_model)
if base_model:
model = ViTModel(config, add_pooling_layer=False).eval()
else:
model = ViTForImageClassification(config).eval()
model.load_state_dict(state_dict)
if "deit" in vit_name:
image_processor = DeiTImageProcessor(size=config.image_size)
else:
image_processor = ViTImageProcessor(size=config.image_size)
encoding = image_processor(images=prepare_img(), return_tensors="pt")
pixel_values = encoding["pixel_values"]
outputs = model(pixel_values)
if base_model:
timm_pooled_output = timm_model.forward_features(pixel_values)
assert timm_pooled_output.shape == outputs.last_hidden_state.shape
assert torch.allclose(timm_pooled_output, outputs.last_hidden_state, atol=1e-1)
else:
timm_logits = timm_model(pixel_values)
assert timm_logits.shape == outputs.logits.shape
assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model {vit_name} to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"Saving image processor to {pytorch_dump_folder_path}")
image_processor.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--vit_name",
default="vit_base_patch16_224",
type=str,
help="Name of the ViT timm model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
args = parser.parse_args()
convert_vit_checkpoint(args.vit_name, args.pytorch_dump_folder_path)
This code is the entry point of the command-line program: it parses the command-line arguments with the argparse module and then calls `convert_vit_checkpoint` to do the actual conversion.
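A hypothetical invocation (paths are placeholders; requires timm with the pretrained weights available):

```
# python convert_vit_timm_to_pytorch.py --vit_name vit_base_patch16_224 \
#     --pytorch_dump_folder_path ./vit-base-patch16-224
convert_vit_checkpoint("vit_base_patch16_224", "./vit-base-patch16-224")
```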
.\models\vit\feature_extraction_vit.py
"""Feature extractor class for ViT."""
import warnings
from ...utils import logging
from .image_processing_vit import ViTImageProcessor
logger = logging.get_logger(__name__)
class ViTFeatureExtractor(ViTImageProcessor):
def __init__(self, *args, **kwargs) -> None:
warnings.warn(
"The class ViTFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
" use ViTImageProcessor instead.",
FutureWarning,
)
super().__init__(*args, **kwargs)
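The shim keeps old imports working while steering users to the new class; a minimal sketch of the observable behavior (assuming the classes above are importable):

```
import warnings

# Instantiating the deprecated alias emits a FutureWarning but still returns
# a fully functional ViTImageProcessor subclass.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    extractor = ViTFeatureExtractor()
assert isinstance(extractor, ViTImageProcessor)
assert any(issubclass(w.category, FutureWarning) for w in caught)
```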
.\models\vit\image_processing_vit.py
from typing import Dict, List, Optional, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import resize, to_channel_dimension_format
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
validate_kwargs,
validate_preprocess_arguments,
)
from ...utils import TensorType, logging
logger = logging.get_logger(__name__)
class ViTImageProcessor(BaseImageProcessor):
r"""
Constructs a ViT image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
method.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
`preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
`preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
"""
# The model input name is "pixel_values"
model_input_names = ["pixel_values"]
# Initializer; every default set below can be overridden by the corresponding argument of the `preprocess` method
def __init__(
self,
do_resize: bool = True,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
**kwargs,
) -> None:
# Call the parent initializer, forwarding any remaining keyword arguments
super().__init__(**kwargs)
# Fall back to the default size {"height": 224, "width": 224} when `size` is not given
size = size if size is not None else {"height": 224, "width": 224}
# Normalize `size` into the canonical {"height": ..., "width": ...} dictionary
size = get_size_dict(size)
# Store the flag controlling the resize step
self.do_resize = do_resize
# Store the flag controlling the rescale step
self.do_rescale = do_rescale
# Store the flag controlling the normalize step
self.do_normalize = do_normalize
# Store the resize target
self.size = size
# Store the resampling filter used when resizing
self.resample = resample
# Store the factor used when rescaling
self.rescale_factor = rescale_factor
# Default to the standard ImageNet mean when none is provided
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
# Default to the standard ImageNet std when none is provided
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
# Keys accepted by the processor, used to validate incoming kwargs
self._valid_processor_keys = [
"images",
"do_resize",
"size",
"resample",
"do_rescale",
"rescale_factor",
"do_normalize",
"image_mean",
"image_std",
"return_tensors",
"data_format",
"input_data_format",
]
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize an image to `(size["height"], size["width"])`.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
Returns:
`np.ndarray`: The resized image.
"""
size = get_size_dict(size)  # normalize `size` into a {"height", "width"} dictionary
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
output_size = (size["height"], size["width"])  # target (height, width) of the resized image
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
):
"""
Preprocesses a batch of images including resizing, rescaling, normalization, and tensor conversion.
Args:
images (`ImageInput`):
Batch of input images.
do_resize (`bool`, *optional*):
Whether to resize the images. If `True`, resizing will be performed according to `size`.
size (`Dict[str, int]`, *optional*):
Dictionary specifying the target size for resizing each image in the batch.
resample (`PILImageResampling`, *optional*):
Resampling method to use for resizing, default is `PILImageResampling.BILINEAR`.
do_rescale (`bool`, *optional*):
Whether to rescale the images. If `True`, images will be scaled by `rescale_factor`.
rescale_factor (`float`, *optional*):
Scaling factor for rescaling images.
do_normalize (`bool`, *optional*):
Whether to normalize the images.
image_mean (`float` or `List[float]`, *optional*):
Mean values for normalization.
image_std (`float` or `List[float]`, *optional*):
Standard deviation values for normalization.
return_tensors (`str` or `TensorType`, *optional*):
If specified, converts the output to the desired tensor type (e.g., `"torch"` or `"tensorflow"`).
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output images.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input images.
Returns:
`np.ndarray` or `TensorType`: Preprocessed batch of images.
"""
.\models\vit\modeling_flax_vit.py
from typing import Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling, FlaxSequenceClassifierOutput
from ...modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
from .configuration_vit import ViTConfig
VIT_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading, saving and converting weights from PyTorch models)
This model is also a
[flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
"""
"""
Define a docstring for the module `FlaxViTPatchEmbeddings`.
This docstring provides detailed information about the inputs expected by the module, including:
- `pixel_values`: A numpy array containing pixel values of shape `(batch_size, num_channels, height, width)`.
This array can be obtained using an `AutoImageProcessor`.
- `output_attentions`: An optional boolean indicating whether to return attention tensors from all layers.
- `output_hidden_states`: An optional boolean indicating whether to return hidden states from all layers.
- `return_dict`: An optional boolean indicating whether to return a `ModelOutput` instead of a tuple.
"""
class FlaxViTPatchEmbeddings(nn.Module):
config: ViTConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
image_size = self.config.image_size
patch_size = self.config.patch_size
num_patches = (image_size // patch_size) * (image_size // patch_size)
self.num_patches = num_patches
self.num_channels = self.config.num_channels
self.projection = nn.Conv(
self.config.hidden_size,
kernel_size=(patch_size, patch_size),
strides=(patch_size, patch_size),
padding="VALID",
dtype=self.dtype,
kernel_init=jax.nn.initializers.variance_scaling(
self.config.initializer_range**2, "fan_in", "truncated_normal"
),
)
def __call__(self, pixel_values):
num_channels = pixel_values.shape[-1]
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
embeddings = self.projection(pixel_values)
batch_size, _, _, channels = embeddings.shape
return jnp.reshape(embeddings, (batch_size, -1, channels))
class FlaxViTEmbeddings(nn.Module):
"""Construct the CLS token, position and patch embeddings."""
config: ViTConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.cls_token = self.param(
"cls_token",
jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
(1, 1, self.config.hidden_size),
)
self.patch_embeddings = FlaxViTPatchEmbeddings(self.config, dtype=self.dtype)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = self.param(
"position_embeddings",
jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
(1, num_patches + 1, self.config.hidden_size),
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, pixel_values, deterministic=True):
batch_size = pixel_values.shape[0]
embeddings = self.patch_embeddings(pixel_values)
cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size))
embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1)
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings, deterministic=deterministic)
return embeddings
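The sequence length produced by these two modules can be checked with plain arithmetic; the numbers below assume the default ViT-Base configuration (224x224 images, 16x16 patches, hidden size 768):
```
image_size, patch_size, hidden_size = 224, 16, 768   # assumed ViT-Base defaults
num_patches = (image_size // patch_size) ** 2        # 14 * 14 = 196 patch embeddings
seq_length = num_patches + 1                         # +1 for the prepended CLS token
print(num_patches, seq_length)                       # 196 197
# position_embeddings therefore has shape (1, 197, 768) and is broadcast over the batch
```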
class FlaxViTSelfAttention(nn.Module):
config: ViTConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError(
f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`:"
f" {self.config.num_attention_heads}"
)
self.query = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.variance_scaling(
self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
),
use_bias=self.config.qkv_bias,
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.variance_scaling(
self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
),
use_bias=self.config.qkv_bias,
)
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.variance_scaling(
self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
),
use_bias=self.config.qkv_bias,
)
def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False):
head_dim = self.config.hidden_size // self.config.num_attention_heads
query_states = self.query(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
value_states = self.value(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
key_states = self.key(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
dropout_rng = None
if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
dropout_rng = self.make_rng("dropout")
attn_weights = dot_product_attention_weights(
query_states,
key_states,
dropout_rng=dropout_rng,
dropout_rate=self.config.attention_probs_dropout_prob,
broadcast_dropout=True,
deterministic=deterministic,
dtype=self.dtype,
precision=None,
)
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs
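A shape walk-through of the reshapes in this attention module, mimicked with NumPy (sizes are the assumed ViT-Base defaults, not taken from a real config):
```
import numpy as np

batch, seq, hidden, heads = 2, 197, 768, 12
head_dim = hidden // heads                                  # 64
states = np.zeros((batch, seq, hidden))
# query/key/value projections are reshaped to (batch, seq, heads, head_dim)
per_head = states.reshape(states.shape[:2] + (heads, head_dim))
print(per_head.shape)                                       # (2, 197, 12, 64)
# the einsum "...hqk,...khd->...qhd" returns the per-head context in the same layout,
# and the final reshape flattens it back to (batch, seq, hidden)
print(per_head.reshape(per_head.shape[:2] + (-1,)).shape)   # (2, 197, 768)
```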
class FlaxViTSelfOutput(nn.Module):
config: ViTConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.variance_scaling(
self.config.initializer_range**2, "fan_in", "truncated_normal"
),
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
return hidden_states
class FlaxViTAttention(nn.Module):
config: ViTConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.attention = FlaxViTSelfAttention(self.config, dtype=self.dtype)
self.output = FlaxViTSelfOutput(self.config, dtype=self.dtype)
def __call__(self, hidden_states, deterministic=True, output_attentions: bool = False):
attn_outputs = self.attention(hidden_states, deterministic=deterministic, output_attentions=output_attentions)
attn_output = attn_outputs[0]
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_outputs[1],)
return outputs
class FlaxViTIntermediate(nn.Module):
config: ViTConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.dense = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.variance_scaling(
self.config.initializer_range**2, "fan_in", "truncated_normal"
),
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.hidden_act]
def __call__(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class FlaxViTOutput(nn.Module):
config: ViTConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.variance_scaling(
self.config.initializer_range**2, "fan_in", "truncated_normal"
),
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = hidden_states + attention_output
return hidden_states
class FlaxViTLayer(nn.Module):
config: ViTConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.attention = FlaxViTAttention(self.config, dtype=self.dtype)
self.intermediate = FlaxViTIntermediate(self.config, dtype=self.dtype)
self.output = FlaxViTOutput(self.config, dtype=self.dtype)
self.layernorm_before = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False):
attention_outputs = self.attention(
self.layernorm_before(hidden_states),
deterministic=deterministic,
output_attentions=output_attentions,
)
attention_output = attention_outputs[0]
attention_output = attention_output + hidden_states
layer_output = self.layernorm_after(attention_output)
hidden_states = self.intermediate(layer_output)
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
outputs = (hidden_states,)
if output_attentions:
outputs += (attention_outputs[1],)
return outputs
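FlaxViTLayer is a pre-LayerNorm block: normalization is applied before the attention and before the MLP, and each residual is added to the un-normalized input. A minimal NumPy sketch of that dataflow (the attention and MLP are stand-in lambdas, not the real sub-modules):
```
import numpy as np

def layer_norm(x, eps=1e-6):
    mean = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def pre_norm_block(x, attention, mlp):
    x = attention(layer_norm(x)) + x   # first residual, around self-attention
    return mlp(layer_norm(x)) + x      # second residual, around the MLP

hidden_states = np.random.randn(1, 197, 768)
out = pre_norm_block(hidden_states, attention=lambda h: 0.5 * h, mlp=lambda h: 0.1 * h)
print(out.shape)                       # (1, 197, 768)
```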
class FlaxViTLayerCollection(nn.Module):
config: ViTConfig
dtype: jnp.dtype = jnp.float32  # the dtype of the computation
def setup(self):
self.layers = [
FlaxViTLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
]
def __call__(
self,
hidden_states,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = layer(hidden_states, deterministic=deterministic, output_attentions=output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions += (layer_outputs[1],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = (hidden_states,)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
class FlaxViTEncoder(nn.Module):
config: ViTConfig
dtype: jnp.dtype = jnp.float32  # the dtype of the computation
def setup(self):
self.layer = FlaxViTLayerCollection(self.config, dtype=self.dtype)
def __call__(
self,
hidden_states,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
return self.layer(
hidden_states,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
class FlaxViTPooler(nn.Module):
config: ViTConfig
dtype: jnp.dtype = jnp.float32  # the dtype of the computation
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.variance_scaling(
self.config.initializer_range**2, "fan_in", "truncated_normal"
),
dtype=self.dtype,
)
def __call__(self, hidden_states):
cls_hidden_state = hidden_states[:, 0]
cls_hidden_state = self.dense(cls_hidden_state)
return nn.tanh(cls_hidden_state)
class FlaxViTPreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class that handles weights initialization and provides a simple interface for downloading and loading pretrained models.
"""
config_class = ViTConfig
base_model_prefix = "vit"
main_input_name = "pixel_values"
module_class: nn.Module = None
def __init__(
self,
config: ViTConfig,
input_shape=None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
if input_shape is None:
input_shape = (1, config.image_size, config.image_size, config.num_channels)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
pixel_values = jnp.zeros(input_shape, dtype=self.dtype)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
pixel_values,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(pixel_values, dtype=jnp.float32),
not train,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
)
class FlaxViTModule(nn.Module):
config: ViTConfig
dtype: jnp.dtype = jnp.float32
add_pooling_layer: bool = True
def setup(self):
self.embeddings = FlaxViTEmbeddings(self.config, dtype=self.dtype)
self.encoder = FlaxViTEncoder(self.config, dtype=self.dtype)
self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.pooler = FlaxViTPooler(self.config, dtype=self.dtype) if self.add_pooling_layer else None
def __call__(
self,
pixel_values,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
hidden_states = self.embeddings(pixel_values, deterministic=deterministic)
outputs = self.encoder(
hidden_states,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.layernorm(hidden_states)
pooled = self.pooler(hidden_states) if self.add_pooling_layer else None
if not return_dict:
if pooled is None:
return (hidden_states,) + outputs[1:]
return (hidden_states, pooled) + outputs[1:]
return FlaxBaseModelOutputWithPooling(
last_hidden_state=hidden_states,
pooler_output=pooled,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"The bare ViT Model transformer outputting raw hidden-states without any specific head on top.",
VIT_START_DOCSTRING,
)
class FlaxViTModel(FlaxViTPreTrainedModel):
module_class = FlaxViTModule
FLAX_VISION_MODEL_DOCSTRING = """
Returns:
Examples:
```
>>> from transformers import AutoImageProcessor, FlaxViTModel
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> model = FlaxViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
>>> inputs = image_processor(images=image, return_tensors="np")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```
"""
overwrite_call_docstring(FlaxViTModel, FLAX_VISION_MODEL_DOCSTRING)
append_replace_return_docstrings(FlaxViTModel, output_type=FlaxBaseModelOutputWithPooling, config_class=ViTConfig)
class FlaxViTForImageClassificationModule(nn.Module):
config: ViTConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.vit = FlaxViTModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
self.classifier = nn.Dense(
self.config.num_labels,
dtype=self.dtype,
kernel_init=jax.nn.initializers.variance_scaling(
self.config.initializer_range**2, "fan_in", "truncated_normal"
),
)
def __call__(
self,
pixel_values=None,
deterministic: bool = True,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.vit(
pixel_values,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.classifier(hidden_states[:, 0, :])
if not return_dict:
output = (logits,) + outputs[2:]
return output
return FlaxSequenceClassifierOutput(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
the [CLS] token) e.g. for ImageNet.
""",
VIT_START_DOCSTRING,
)
class FlaxViTForImageClassification(FlaxViTPreTrainedModel):
module_class = FlaxViTForImageClassificationModule
FLAX_VISION_CLASSIF_DOCSTRING = """
Returns:
Example:
```
>>> from transformers import AutoImageProcessor, FlaxViTForImageClassification
>>> from PIL import Image
>>> import jax
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
>>> model = FlaxViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
>>> inputs = image_processor(images=image, return_tensors="np")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
>>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
```
"""
overwrite_call_docstring(FlaxViTForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING)
append_replace_return_docstrings(
FlaxViTForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=ViTConfig
)
.\models\vit\modeling_tf_vit.py
""" TF 2.0 ViT model."""
from __future__ import annotations
import collections.abc
import math
from typing import Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
from ...modeling_tf_utils import (
TFModelInputType,
TFPreTrainedModel,
TFSequenceClassificationLoss,
get_initializer,
keras,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_vit import ViTConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "ViTConfig"
_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224-in21k"
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
_IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
class TFViTEmbeddings(keras.layers.Layer):
"""
Construct the CLS token, position and patch embeddings.
"""
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
self.patch_embeddings = TFViTPatchEmbeddings(config, name="patch_embeddings")
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def build(self, input_shape=None):
num_patches = self.patch_embeddings.num_patches
self.cls_token = self.add_weight(
shape=(1, 1, self.config.hidden_size),
initializer=get_initializer(self.config.initializer_range),
trainable=True,
name="cls_token",
)
self.position_embeddings = self.add_weight(
shape=(1, num_patches + 1, self.config.hidden_size),
initializer=get_initializer(self.config.initializer_range),
trainable=True,
name="position_embeddings",
)
if self.built:
return
self.built = True
if getattr(self, "patch_embeddings", None) is not None:
with tf.name_scope(self.patch_embeddings.name):
self.patch_embeddings.build(None)
def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
batch_size, seq_len, dim = shape_list(embeddings)
num_patches = seq_len - 1
_, num_positions, _ = shape_list(self.position_embeddings)
num_positions -= 1
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, :1]
patch_pos_embed = self.position_embeddings[:, 1:]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
patch_pos_embed = tf.image.resize(
images=tf.reshape(
patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
),
size=(h0, w0),
method="bicubic",
)
shape = shape_list(patch_pos_embed)
assert h0 == shape[-3] and w0 == shape[-2]
patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)
def call(
self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
) -> tf.Tensor:
batch_size, num_channels, height, width = shape_list(pixel_values)
embeddings = self.patch_embeddings(
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, training=training
)
cls_tokens = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
embeddings = tf.concat((cls_tokens, embeddings), axis=1)
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings, training=training)
return embeddings
class TFViTPatchEmbeddings(keras.layers.Layer):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = num_patches
self.num_channels = num_channels
self.config = config
self.projection = keras.layers.Conv2D(
filters=hidden_size,
kernel_size=patch_size,
strides=patch_size,
padding="valid",
data_format="channels_last",
use_bias=True,
kernel_initializer=get_initializer(self.config.initializer_range),
bias_initializer="zeros",
name="projection",
)
def call(
self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
) -> tf.Tensor:
batch_size, num_channels, height, width = shape_list(pixel_values)
if tf.executing_eagerly() and num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
if not interpolate_pos_encoding:
if tf.executing_eagerly():
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
projection = self.projection(pixel_values)
num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
embeddings = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))
return embeddings
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "projection", None) is not None:
with tf.name_scope(self.projection.name):
self.projection.build([None, None, None, self.num_channels])
class TFViTSelfAttention(keras.layers.Layer):
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number "
f"of attention heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
self.query = keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
)
self.key = keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
)
self.value = keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
)
self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
self.config = config
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
return tf.transpose(tensor, perm=[0, 2, 1, 3])
def call(
self,
hidden_states: tf.Tensor,
head_mask: tf.Tensor,
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
batch_size = shape_list(hidden_states)[0]
mixed_query_layer = self.query(inputs=hidden_states)
mixed_key_layer = self.key(inputs=hidden_states)
mixed_value_layer = self.value(inputs=hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
attention_scores = tf.divide(attention_scores, dk)
attention_probs = stable_softmax(logits=attention_scores, axis=-1)
attention_probs = self.dropout(inputs=attention_probs, training=training)
if head_mask is not None:
attention_probs = tf.multiply(attention_probs, head_mask)
attention_output = tf.matmul(attention_probs, value_layer)
attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
class TFViTSelfOutput(keras.layers.Layer):
"""
The residual connection is defined in TFViTLayer instead of here (as is the case with other models), due to the
layernorm applied before each block.
"""
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
self.dense = keras.layers.Dense(
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.dropout(inputs=hidden_states, training=training)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFViTAttention(keras.layers.Layer):
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
self.self_attention = TFViTSelfAttention(config, name="attention")
self.dense_output = TFViTSelfOutput(config, name="output")
def prune_heads(self, heads):
raise NotImplementedError
def call(
self,
input_tensor: tf.Tensor,
head_mask: tf.Tensor,
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
self_outputs = self.self_attention(
hidden_states=input_tensor, head_mask=head_mask, output_attentions=output_attentions, training=training
)
attention_output = self.dense_output(
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
)
outputs = (attention_output,) + self_outputs[1:]
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attention", None) is not None:
with tf.name_scope(self.self_attention.name):
self.self_attention.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
class TFViTIntermediate(keras.layers.Layer):
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
self.dense = keras.layers.Dense(
units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFViTOutput(keras.layers.Layer):
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
self.dense = keras.layers.Dense(
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.dropout(inputs=hidden_states, training=training)
hidden_states = hidden_states + input_tensor
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
class TFViTLayer(keras.layers.Layer):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
self.attention = TFViTAttention(config, name="attention")
self.intermediate = TFViTIntermediate(config, name="intermediate")
self.vit_output = TFViTOutput(config, name="output")
self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")
self.config = config
def call(
self,
hidden_states: tf.Tensor,
head_mask: tf.Tensor,
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
attention_outputs = self.attention(
input_tensor=self.layernorm_before(inputs=hidden_states),
head_mask=head_mask,
output_attentions=output_attentions,
training=training,
)
attention_output = attention_outputs[0]
hidden_states = attention_output + hidden_states
layer_output = self.layernorm_after(inputs=hidden_states)
intermediate_output = self.intermediate(hidden_states=layer_output)
layer_output = self.vit_output(
hidden_states=intermediate_output, input_tensor=hidden_states, training=training
)
outputs = (layer_output,) + attention_outputs[1:]
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "vit_output", None) is not None:
with tf.name_scope(self.vit_output.name):
self.vit_output.build(None)
if getattr(self, "layernorm_before", None) is not None:
with tf.name_scope(self.layernorm_before.name):
self.layernorm_before.build([None, None, self.config.hidden_size])
if getattr(self, "layernorm_after", None) is not None:
with tf.name_scope(self.layernorm_after.name):
self.layernorm_after.build([None, None, self.config.hidden_size])
class TFViTEncoder(keras.layers.Layer):
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
self.layer = [TFViTLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
def call(
self,
hidden_states: tf.Tensor,
head_mask: tf.Tensor,
output_attentions: bool,
output_hidden_states: bool,
return_dict: bool,
training: bool = False,
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
hidden_states=hidden_states,
head_mask=head_mask[i],
output_attentions=output_attentions,
training=training,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return TFBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFViTMainLayer(keras.layers.Layer):
config_class = ViTConfig
def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, **kwargs):
super().__init__(**kwargs)
self.config = config
self.embeddings = TFViTEmbeddings(config, name="embeddings")
self.encoder = TFViTEncoder(config, name="encoder")
self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
self.pooler = TFViTPooler(config, name="pooler") if add_pooling_layer else None
def get_input_embeddings(self) -> keras.layers.Layer:
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
raise NotImplementedError
@unpack_inputs
def call(
self,
pixel_values: TFModelInputType | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
embedding_output = self.embeddings(
pixel_values=pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
training=training,
)
if head_mask is not None:
raise NotImplementedError
else:
head_mask = [None] * self.config.num_hidden_layers
encoder_outputs = self.encoder(
hidden_states=embedding_output,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(inputs=sequence_output)
pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return TFBaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "layernorm", None) is not None:
with tf.name_scope(self.layernorm.name):
self.layernorm.build([None, None, self.config.hidden_size])
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
VIT_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
behavior.
<Tip>
TensorFlow models and layers in `transformers` accept two formats as input:
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional argument.
The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
positional argument:
- a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
`model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
`model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`
Note that when creating models and layers with
[subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
about any of this, as you can just pass inputs like you would to any other Python function!
</Tip>
Args:
config ([`ViTConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
VIT_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
for details.
head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
config will be used instead.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
used instead.
interpolate_pos_encoding (`bool`, *optional*):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
eager mode, in graph mode the value will always be set to True.
training (`bool`, *optional*, defaults to `False`):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).
"""
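To illustrate the three input formats the docstring above describes, a short sketch using the checkpoint referenced elsewhere in this file (it downloads weights, and the random pixel values are only meant to demonstrate shapes):
```
import numpy as np
from transformers import TFViTModel

model = TFViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
pixel_values = np.random.rand(1, 3, 224, 224).astype(np.float32)

out1 = model(pixel_values)                       # a single positional tensor
out2 = model({"pixel_values": pixel_values})     # a dict keyed by input name
out3 = model(pixel_values=pixel_values)          # keyword arguments, PyTorch-style
print(out1.last_hidden_state.shape)              # (1, 197, 768)
```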
@add_start_docstrings(
"The bare ViT Model transformer outputting raw hidden-states without any specific head on top.",
VIT_START_DOCSTRING,
)
class TFViTModel(TFViTPreTrainedModel):
def __init__(self, config: ViTConfig, *inputs, add_pooling_layer=True, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.vit = TFViTMainLayer(config, add_pooling_layer=add_pooling_layer, name="vit")
@unpack_inputs
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFBaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def call(
self,
pixel_values: TFModelInputType | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
outputs = self.vit(
pixel_values=pixel_values,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
training=training,
)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "vit", None) is not None:
with tf.name_scope(self.vit.name):
self.vit.build(None)
class TFViTPooler(keras.layers.Layer):
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
self.dense = keras.layers.Dense(
units=config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
activation="tanh",
name="dense",
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(inputs=first_token_tensor)
return pooled_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
the [CLS] token) e.g. for ImageNet.
<Tip>
Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on, by
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
position embeddings to the higher resolution.
</Tip>
""",
VIT_START_DOCSTRING,
)
class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassificationLoss):
def __init__(self, config: ViTConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.vit = TFViTMainLayer(config, add_pooling_layer=False, name="vit")
self.classifier = keras.layers.Dense(
units=config.num_labels,
kernel_initializer=get_initializer(config.initializer_range),
name="classifier",
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=TFSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def call(
self,
pixel_values: TFModelInputType | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: np.ndarray | tf.Tensor | None = None,
training: Optional[bool] = False,
) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
r"""
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
outputs = self.vit(
pixel_values=pixel_values,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
training=training,
)
sequence_output = outputs[0]
logits = self.classifier(inputs=sequence_output[:, 0, :])
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TFSequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "vit", None) is not None:
with tf.name_scope(self.vit.name):
self.vit.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
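Putting the classification head together with `interpolate_pos_encoding`, a sketch of running the TF model at a higher resolution than it was trained on; the checkpoint and the 384x384 size are illustrative assumptions:
```
import requests
from PIL import Image
from transformers import ViTImageProcessor, TFViTForImageClassification

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

# Preprocess to 384x384 instead of the 224x224 the checkpoint was trained with,
# and let the model interpolate its position embeddings to the larger patch grid.
inputs = image_processor(images=image, size={"height": 384, "width": 384}, return_tensors="tf")
outputs = model(**inputs, interpolate_pos_encoding=True)
print(outputs.logits.shape)  # (1, 1000)
```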
.\models\vit\modeling_vit.py
""" PyTorch ViT model."""
import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
ImageClassifierOutput,
MaskedImageModelingOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_vit import ViTConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "ViTConfig"
_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224-in21k"
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
_IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
VIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"google/vit-base-patch16-224",
]
class ViTEmbeddings(nn.Module):
"""
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
"""
def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
super().__init__()
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
self.patch_embeddings = ViTPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.config = config
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
if bool_masked_pos is not None:
seq_length = embeddings.shape[1]
mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
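`ViTEmbeddings` prepends a learnable [CLS] token, optionally swaps masked patches for the mask token, and adds (possibly interpolated) position embeddings. A quick shape check, as a sketch run with the default ViT-Base configuration and random weights:
```
import torch
from transformers import ViTConfig
from transformers.models.vit.modeling_vit import ViTEmbeddings

config = ViTConfig()                      # image_size=224, patch_size=16, hidden_size=768
embeddings = ViTEmbeddings(config)

pixel_values = torch.randn(1, 3, 224, 224)
print(embeddings(pixel_values).shape)     # torch.Size([1, 197, 768]) = 14*14 patches + 1 CLS token
```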
class ViTPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
f" Expected {self.num_channels} but got {num_channels}."
)
if not interpolate_pos_encoding:
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
return embeddings
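The strided `Conv2d` with `kernel_size == stride == patch_size` is exactly a per-patch linear projection: each non-overlapping 16x16x3 patch is flattened and multiplied by the same weight matrix. A small numerical check of that equivalence (a sketch with ViT-Base sizes, not code from the file):
```
import torch
from torch import nn

hidden_size, patch, channels = 768, 16, 3
proj = nn.Conv2d(channels, hidden_size, kernel_size=patch, stride=patch)

x = torch.randn(2, channels, 224, 224)
conv_out = proj(x).flatten(2).transpose(1, 2)                          # (2, 196, 768)

# same computation via unfold + matmul with the conv weight reshaped to (768, 3*16*16)
patches = nn.functional.unfold(x, kernel_size=patch, stride=patch)     # (2, 3*16*16, 196)
weight = proj.weight.reshape(hidden_size, -1)
lin_out = (weight @ patches).transpose(1, 2) + proj.bias               # (2, 196, 768)

print(torch.allclose(conv_out, lin_out, atol=1e-4))                    # True
```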
class ViTSelfAttention(nn.Module):
def __init__(self, config: ViTConfig) -> None:
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
f"heads {config.num_attention_heads}."
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
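`transpose_for_scores` reshapes `(batch, seq, hidden)` into `(batch, heads, seq, head_dim)` so the scaled dot-product runs per head. The shape flow with ViT-Base numbers (12 heads x 64 dims = 768), as a standalone sketch:
```
import torch

batch, seq, heads, head_dim = 2, 197, 12, 64
hidden = torch.randn(batch, seq, heads * head_dim)

per_head = hidden.view(batch, seq, heads, head_dim).permute(0, 2, 1, 3)   # (2, 12, 197, 64)
scores = per_head @ per_head.transpose(-1, -2) / head_dim**0.5            # (2, 12, 197, 197)
print(per_head.shape, scores.shape)
```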
class ViTSelfOutput(nn.Module):
"""
The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
layernorm applied before each block.
"""
def __init__(self, config: ViTConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class ViTAttention(nn.Module):
def __init__(self, config: ViTConfig) -> None:
super().__init__()
self.attention = ViTSelfAttention(config)
self.output = ViTSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads: Set[int]) -> None:
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
)
self.attention.query = prune_linear_layer(self.attention.query, index)
self.attention.key = prune_linear_layer(self.attention.key, index)
self.attention.value = prune_linear_layer(self.attention.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
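`prune_heads` shrinks the query/key/value projections on their output side and the attention output projection on its input side, so the pruned heads disappear from the computation entirely. A sketch of the resulting shapes (random weights, default configuration):
```
from transformers import ViTConfig
from transformers.models.vit.modeling_vit import ViTAttention

attn = ViTAttention(ViTConfig())
print(attn.attention.query.weight.shape)   # torch.Size([768, 768]) -> 12 heads x 64 dims
print(attn.output.dense.weight.shape)      # torch.Size([768, 768])

attn.prune_heads({0, 1})                   # remove two heads
print(attn.attention.query.weight.shape)   # torch.Size([640, 768]) -> 10 heads remain
print(attn.output.dense.weight.shape)      # torch.Size([768, 640]) -> input side pruned (dim=1)
```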
class ViTIntermediate(nn.Module):
def __init__(self, config: ViTConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class ViTOutput(nn.Module):
def __init__(self, config: ViTConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states + input_tensor
return hidden_states
class ViTLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config: ViTConfig) -> None:
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = ViTAttention(config)
self.intermediate = ViTIntermediate(config)
self.output = ViTOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states),
head_mask,
output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:]
hidden_states = attention_output + hidden_states
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = self.output(layer_output, hidden_states)
outputs = (layer_output,) + outputs
return outputs
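`ViTLayer` is a pre-LayerNorm block: `x + Attention(LN1(x))` followed by `y + MLP(LN2(y))`, with the residual additions living in `forward` and `ViTOutput` rather than inside the sub-modules. A shape-preserving sanity check (sketch, random weights):
```
import torch
from transformers import ViTConfig
from transformers.models.vit.modeling_vit import ViTLayer

layer = ViTLayer(ViTConfig())
(out,) = layer(torch.randn(1, 197, 768))   # output_attentions=False -> 1-tuple
print(out.shape)                           # torch.Size([1, 197, 768])
```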
class ViTEncoder(nn.Module):
def __init__(self, config: ViTConfig) -> None:
super().__init__()
self.config = config
self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
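When `output_hidden_states=True`, the encoder records the input embeddings plus the output of every layer, so ViT-Base yields 13 entries. A sketch with random weights:
```
import torch
from transformers import ViTConfig
from transformers.models.vit.modeling_vit import ViTEncoder

encoder = ViTEncoder(ViTConfig())
out = encoder(torch.randn(1, 197, 768), output_hidden_states=True)
print(len(out.hidden_states), out.last_hidden_state.shape)   # 13 torch.Size([1, 197, 768])
```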
"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`ViTConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
"""
@add_start_docstrings(
"The bare ViT Model transformer outputting raw hidden-states without any specific head on top.",
VIT_START_DOCSTRING,
)
"""
class ViTModel(ViTPreTrainedModel):
"""
ViT Model class inheriting from ViTPreTrainedModel.
"""
def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
"""
Initializes a ViTModel instance.
Args:
config (ViTConfig): Configuration class instance defining model architecture.
add_pooling_layer (bool): Whether to add a pooling layer on top of the encoder.
use_mask_token (bool): Whether to use a mask token for the model.
"""
super().__init__(config)
self.config = config
self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
self.encoder = ViTEncoder(config)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.pooler = ViTPooler(config) if add_pooling_layer else None
self.post_init()
def get_input_embeddings(self) -> ViTPatchEmbeddings:
"""
Returns the patch embeddings used as input to the model.
"""
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
"""
Prunes heads of the model.
Args:
heads_to_prune (Dict[int, List[int]]): Dictionary of layers and heads to prune.
See base class PreTrainedModel for more details.
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
"""
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
"""
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
if pixel_values.dtype != expected_dtype:
pixel_values = pixel_values.to(expected_dtype)
embedding_output = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)
encoder_outputs = self.encoder(
embedding_output,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
return head_outputs + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
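End-to-end usage of the bare `ViTModel`, matching `_EXPECTED_OUTPUT_SHAPE` above (a sketch that assumes Hub access, Pillow, and the `google/vit-base-patch16-224-in21k` checkpoint):
```
from PIL import Image
import requests
from transformers import ViTImageProcessor, ViTModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

outputs = model(**processor(images=image, return_tensors="pt"))
print(outputs.last_hidden_state.shape)   # torch.Size([1, 197, 768])
```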
class ViTPooler(nn.Module):
def __init__(self, config: ViTConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
@add_start_docstrings(
"""ViT Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://arxiv.org/abs/2111.09886).
<Tip>
Note that we provide a script to pre-train this model on custom data in our [examples
directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
</Tip>
""",
VIT_START_DOCSTRING,
)
class ViTForMaskedImageModeling(ViTPreTrainedModel):
def __init__(self, config: ViTConfig) -> None:
super().__init__(config)
self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True)
self.decoder = nn.Sequential(
nn.Conv2d(
in_channels=config.hidden_size,
out_channels=config.encoder_stride**2 * config.num_channels,
kernel_size=1,
),
nn.PixelShuffle(config.encoder_stride),
)
self.post_init()
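The decoder is a 1x1 convolution that maps each patch token to `encoder_stride**2 * num_channels` values, and `PixelShuffle` rearranges those values back into full-resolution pixels, so a 14x14 token grid becomes a 224x224 image. A shape sketch with the default configuration numbers (standalone, not the module above):
```
import torch
from torch import nn

hidden_size, stride, channels = 768, 16, 3
decoder = nn.Sequential(
    nn.Conv2d(hidden_size, stride**2 * channels, kernel_size=1),   # (B, 768, 14, 14) -> (B, 768, 14, 14)
    nn.PixelShuffle(stride),                                       # (B, 768, 14, 14) -> (B, 3, 224, 224)
)
print(decoder(torch.randn(1, hidden_size, 14, 14)).shape)          # torch.Size([1, 3, 224, 224])
```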
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
    ) -> Union[tuple, MaskedImageModelingOutput]:
        ...
@add_start_docstrings(
    """
    ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.
    """,
    VIT_START_DOCSTRING,
)
class ViTForImageClassification(ViTPreTrainedModel):
    def __init__(self, config: ViTConfig) -> None:
super().__init__(config)
self.num_labels = config.num_labels
self.vit = ViTModel(config, add_pooling_layer=False)
self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
self.post_init()
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=ImageClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, ImageClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.vit(
pixel_values,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.classifier(sequence_output[:, 0, :])
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return ImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
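Classification usage mirroring the doc-sample checkpoint `_IMAGE_CLASS_CHECKPOINT` (a sketch assuming Hub access and Pillow; the expected prediction for the COCO cats image is "Egyptian cat"):
```
import torch
from PIL import Image
import requests
from transformers import ViTImageProcessor, ViTForImageClassification

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

with torch.no_grad():
    logits = model(**processor(images=image, return_tensors="pt")).logits
print(model.config.id2label[logits.argmax(-1).item()])   # "Egyptian cat"
```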