Transformers Source Code Walkthrough (116)
.\models\vilt\modeling_vilt.py
""" PyTorch ViLT 模型。"""
import collections.abc
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
MaskedLMOutput,
ModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import (
find_pruneable_heads_and_indices,
meshgrid,
prune_linear_layer,
)
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_vilt import ViltConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "ViltConfig"
_CHECKPOINT_FOR_DOC = "dandelin/vilt-b32-mlm"
VILT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"dandelin/vilt-b32-mlm",
]
@dataclass
class ViltForImagesAndTextClassificationOutput(ModelOutput):
"""
[`ViltForImagesAndTextClassification`] 的输出类。
"""
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
分类(如果config.num_labels==1,则为回归)损失值。
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
分类(如果config.num_labels==1,则为回归)得分(SoftMax之前)。
hidden_states (`List[tuple(torch.FloatTensor)]`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
`torch.FloatTensor`的列表(每个图像-文本对一个,每个元组包含嵌入的输出+每层输出)的元组,形状为`(batch_size, sequence_length, hidden_size)`。
模型在每一层输出的隐藏状态加上初始嵌入的输出。
attentions (`List[tuple(torch.FloatTensor)]`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
`torch.FloatTensor`的列表(每个图像-文本对一个,每个元组包含注意力权重的输出)的元组,形状为`(batch_size, num_heads, sequence_length, sequence_length)`。
注意力softmax之后的注意力权重,用于计算自注意力头中的加权平均值。
"""
# Classification (or regression) loss; None when no labels are provided
loss: Optional[torch.FloatTensor] = None
# Classification scores (before SoftMax)
logits: torch.FloatTensor = None
# Optional list of hidden-state tuples, one per image-text pair
hidden_states: Optional[List[Tuple[torch.FloatTensor]]] = None
# Optional list of attention-weight tuples, one per image-text pair
attentions: Optional[List[Tuple[torch.FloatTensor]]] = None
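# --- Illustrative sketch (not part of the original file) -----------------------
# ModelOutput subclasses like the one above behave both as dataclasses and as
# tuples, with None fields dropped from the tuple view; this is why code later in
# the file can switch between attribute access and integer indexing depending on
# `return_dict`. A minimal demo with dummy tensors (assumes transformers + torch):
import torch
from transformers.models.vilt.modeling_vilt import ViltForImagesAndTextClassificationOutput

_demo_out = ViltForImagesAndTextClassificationOutput(logits=torch.randn(2, 3))
print(_demo_out.logits.shape)     # attribute access -> torch.Size([2, 3])
print(_demo_out[0].shape)         # index access; None fields (loss, ...) are skipped
print(len(_demo_out.to_tuple()))  # -> 1, only the non-None fields are kept
# --------------------------------------------------------------------------------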
class ViltEmbeddings(nn.Module):
"""
Construct the text and patch embeddings.
Text embeddings are equivalent to BERT embeddings.
Patch embeddings are equivalent to ViT embeddings.
"""
def __init__(self, config):
super().__init__()
# text embeddings
self.text_embeddings = TextEmbeddings(config)
# patch embeddings
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.patch_embeddings = ViltPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
# modality type (text/patch) embeddings
self.token_type_embeddings = nn.Embedding(config.modality_type_vocab_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.config = config
def forward(
self,
input_ids,
attention_mask,
token_type_ids,
pixel_values,
pixel_mask,
inputs_embeds,
image_embeds,
image_token_type_idx=1,
):
# PART 1: text embeddings
text_embeds = self.text_embeddings(
input_ids=input_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
)
# PART 2: patch embeddings (with interpolated position encodings)
if image_embeds is None:
# Generate visual embeddings and masks from pixel values
image_embeds, image_masks, patch_index = self.visual_embed(
pixel_values, pixel_mask, max_image_length=self.config.max_image_length
)
else:
# Flatten pixel masks
image_masks = pixel_mask.flatten(1)
# PART 3: add modality type embeddings
# 0 indicates text, 1 indicates image, 2 is optionally used when a second image is provided (NLVR2)
if image_token_type_idx is None:
image_token_type_idx = 1
# Add token type embeddings to text embeddings
text_embeds = text_embeds + self.token_type_embeddings(
torch.zeros_like(attention_mask, dtype=torch.long, device=text_embeds.device)
)
# Add token type embeddings to image embeddings
image_embeds = image_embeds + self.token_type_embeddings(
torch.full_like(image_masks, image_token_type_idx, dtype=torch.long, device=text_embeds.device)
)
# PART 4: concatenate text and image embeddings
embeddings = torch.cat([text_embeds, image_embeds], dim=1)
# Concatenate attention masks and image masks
masks = torch.cat([attention_mask, image_masks], dim=1)
return embeddings, masks
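# --- Illustrative sketch (not part of the original file) -----------------------
# How the modality-type embeddings and the concatenation above combine the text
# and image sequences, on toy tensors (sizes are made up for the demo):
import torch
from torch import nn

hidden, text_len, img_len, bsz = 8, 5, 4, 2
token_type_embeddings = nn.Embedding(3, hidden)   # 0 = text, 1 = image, 2 = optional second image
text_embeds = torch.randn(bsz, text_len, hidden)
image_embeds = torch.randn(bsz, img_len, hidden)
attention_mask = torch.ones(bsz, text_len, dtype=torch.long)
image_masks = torch.ones(bsz, img_len, dtype=torch.long)

text_embeds = text_embeds + token_type_embeddings(torch.zeros_like(attention_mask))
image_embeds = image_embeds + token_type_embeddings(torch.full_like(image_masks, 1))
embeddings = torch.cat([text_embeds, image_embeds], dim=1)  # (2, 9, 8): text tokens first, then patches
masks = torch.cat([attention_mask, image_masks], dim=1)     # (2, 9)
# --------------------------------------------------------------------------------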
class TextEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
# Constructor, takes a config object
def __init__(self, config):
# Call the parent class constructor
super().__init__()
# Word embedding layer mapping vocabulary indices to hidden-size vectors
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
# Position embedding layer mapping position indices to hidden-size vectors
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
# Token-type embedding layer mapping segment indices to hidden-size vectors
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# LayerNorm for normalizing the embeddings
# Note: self.LayerNorm is not snake-cased so the variable name matches TensorFlow checkpoints and they can be loaded
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# Dropout applied to the embeddings during training to reduce overfitting
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# Position embedding type, "absolute" by default
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
# Register a non-persistent buffer holding the position indices
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
# Register a non-persistent buffer holding all-zero token-type indices
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
# Forward pass: compute the text embeddings from the inputs
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
# If input_ids is provided, take its shape
if input_ids is not None:
input_shape = input_ids.size()
else:
# Otherwise take the shape of inputs_embeds (excluding the last dimension)
input_shape = inputs_embeds.size()[:-1]
# Sequence length is the second dimension of the input
seq_length = input_shape[1]
# If position_ids is not given, use the registered position_ids buffer
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
# If token_type_ids is not given, fall back to the registered buffer
if token_type_ids is None:
if hasattr(self, "token_type_ids"):
# Slice the registered token_type_ids and expand to the input shape
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
# No registered buffer: create an all-zero tensor matching the input shape
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
# If inputs_embeds is not given, look up the word embeddings from input_ids
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
# Look up the token-type embeddings
token_type_embeddings = self.token_type_embeddings(token_type_ids)
# Sum word and token-type embeddings
embeddings = inputs_embeds + token_type_embeddings
# For absolute position encodings, add the position embeddings as well
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
# Apply LayerNorm
embeddings = self.LayerNorm(embeddings)
# Apply dropout
embeddings = self.dropout(embeddings)
# Return the final embeddings
return embeddings
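# --- Illustrative sketch (not part of the original file) -----------------------
# The text embedding above is the element-wise sum of word, position and
# token-type embeddings, followed by LayerNorm and dropout. Toy demo with
# arbitrary sizes:
import torch
from torch import nn

vocab, max_pos, n_types, hidden = 100, 16, 2, 8
word_emb = nn.Embedding(vocab, hidden, padding_idx=0)
pos_emb = nn.Embedding(max_pos, hidden)
type_emb = nn.Embedding(n_types, hidden)
norm, drop = nn.LayerNorm(hidden), nn.Dropout(0.1)

input_ids = torch.tensor([[5, 7, 9, 0]])                      # (1, 4), 0 is the padding id
position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0)  # (1, 4)
token_type_ids = torch.zeros_like(input_ids)

embeddings = word_emb(input_ids) + pos_emb(position_ids) + type_emb(token_type_ids)
embeddings = drop(norm(embeddings))                           # (1, 4, 8)
# --------------------------------------------------------------------------------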
"""
Image to Patch Embedding.
"""
# 初始化函数,设置类的初始状态
def __init__(self, config):
super().__init__()
# 从配置中获取图像大小和patch大小
image_size, patch_size = config.image_size, config.patch_size
# 从配置中获取通道数和隐藏层大小
num_channels, hidden_size = config.num_channels, config.hidden_size
# 确保image_size和patch_size是可迭代对象,若不是则转为元组
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
# 计算patch的数量
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
# 设置类的成员变量
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
# 使用卷积层进行投影,将图像转换为patch embeddings
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
# Forward pass: image tensor -> patch embeddings
def forward(self, pixel_values):
# Unpack the input shape
batch_size, num_channels, height, width = pixel_values.shape
# Raise if the channel dimension does not match the configuration
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
# Cast the input to the dtype of the projection weights
target_dtype = self.projection.weight.dtype
# Project the pixel values into patch embeddings
x = self.projection(pixel_values.to(dtype=target_dtype))
# Return the projected feature map
return x
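# --- Illustrative sketch (not part of the original file) -----------------------
# A Conv2d with kernel_size == stride == patch_size turns the image into a grid
# of patch embeddings. With the ViLT-B/32 defaults (384x384 image, 32x32 patches,
# hidden size 768) that is a 12x12 grid of 144 patches:
import torch
from torch import nn

projection = nn.Conv2d(3, 768, kernel_size=32, stride=32)
pixel_values = torch.randn(1, 3, 384, 384)
x = projection(pixel_values)             # (1, 768, 12, 12)
patches = x.flatten(2).transpose(1, 2)   # (1, 144, 768): one row per patch
# --------------------------------------------------------------------------------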
class ViltSelfAttention(nn.Module):
# Constructor for the self-attention module
def __init__(self, config):
super().__init__()
# Raise if the hidden size is not a multiple of the number of attention heads (and no separate embedding_size is configured)
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
f"heads {config.num_attention_heads}."
)
# Number of heads and per-head size
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
# Linear projections for query, key and value (bias controlled by the config)
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
# Dropout applied to the attention probabilities
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
# Reshape the input so attention scores can be computed per head
def transpose_for_scores(self, x):
# New shape: (batch, seq_len, num_heads, head_size)
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
# Move the head dimension forward so all heads are computed in parallel
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
# Forward pass: compute the attention context vectors
def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
# Project the hidden states to queries
mixed_query_layer = self.query(hidden_states)
# Project to keys and reshape per head
key_layer = self.transpose_for_scores(self.key(hidden_states))
# Project to values and reshape per head
value_layer = self.transpose_for_scores(self.value(hidden_states))
# Reshape the queries per head
query_layer = self.transpose_for_scores(mixed_query_layer)
# Raw attention scores: dot product between "query" and "key"
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
# Scale by the square root of the head size
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Apply the attention mask if provided
if attention_mask is not None:
# The attention mask is precomputed once for all layers in the model's forward() function
attention_scores = attention_scores + attention_mask
# Normalize the scores into probabilities
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# Apply dropout to the attention probabilities
attention_probs = self.dropout(attention_probs)
# Apply the head mask if provided
if head_mask is not None:
attention_probs = attention_probs * head_mask
# Weighted sum of the values gives the context layer
context_layer = torch.matmul(attention_probs, value_layer)
# Merge the heads back: permute and reshape to (batch, seq_len, all_head_size)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
# Include the attention probabilities in the output if requested
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
# Return the results
return outputs
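# --- Illustrative sketch (not part of the original file) -----------------------
# The core of the self-attention above, written linearly on toy tensors: project
# to Q/K/V, split into heads, scaled dot-product, softmax, weighted sum of the
# values, then merge the heads back:
import math
import torch
from torch import nn

bsz, seq, hidden, heads = 2, 6, 8, 2
head_dim = hidden // heads
q_proj, k_proj, v_proj = (nn.Linear(hidden, hidden) for _ in range(3))

def split_heads(t):
    # (bsz, seq, hidden) -> (bsz, heads, seq, head_dim)
    return t.view(bsz, seq, heads, head_dim).permute(0, 2, 1, 3)

x = torch.randn(bsz, seq, hidden)
q, k, v = split_heads(q_proj(x)), split_heads(k_proj(x)), split_heads(v_proj(x))
scores = q @ k.transpose(-1, -2) / math.sqrt(head_dim)               # (2, 2, 6, 6)
probs = scores.softmax(dim=-1)
context = (probs @ v).permute(0, 2, 1, 3).reshape(bsz, seq, hidden)  # (2, 6, 8)
# --------------------------------------------------------------------------------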
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vilt
class ViltSelfOutput(nn.Module):
"""
The residual connection is defined in ViltLayer instead of here (as is the case with other models), due to the
layernorm applied before each block.
"""
def __init__(self, config: ViltConfig) -> None:
super().__init__()
# Dense layer mapping hidden_size to hidden_size
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
# Dropout applied to the projected hidden states
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
# Project the hidden states back to hidden_size
hidden_states = self.dense(hidden_states)
# Apply dropout
hidden_states = self.dropout(hidden_states)
return hidden_states
class ViltAttention(nn.Module):
def __init__(self, config):
super().__init__()
# Self-attention and output sublayers built from the same config
self.attention = ViltSelfAttention(config)
self.output = ViltSelfOutput(config)
# Set of already-pruned attention heads
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
# Find the prunable heads and the corresponding indices
heads, index = find_pruneable_heads_and_indices(
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
)
# Prune the linear layers
self.attention.query = prune_linear_layer(self.attention.query, index)
self.attention.key = prune_linear_layer(self.attention.key, index)
self.attention.value = prune_linear_layer(self.attention.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update the hyperparameters and record the pruned heads
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
# Run self-attention over the hidden states with the given masks
self_outputs = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
# Project the attention output (the residual add happens in ViltLayer)
attention_output = self.output(self_outputs[0], hidden_states)
# Append the attention weights to the outputs if they were requested
outputs = (attention_output,) + self_outputs[1:]
return outputs
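# --- Illustrative sketch (not part of the original file) -----------------------
# What head pruning amounts to: keep only the projection rows that belong to the
# surviving heads (a hand-rolled version of what prune_linear_layer does), here
# keeping only head 0 of a toy 2-head query projection:
import torch
from torch import nn

hidden, heads = 8, 2
head_dim = hidden // heads
query = nn.Linear(hidden, hidden)

keep = torch.arange(0, head_dim)                       # output indices of head 0
pruned = nn.Linear(hidden, head_dim)
pruned.weight.data = query.weight.data[keep].clone()   # select the kept output rows
pruned.bias.data = query.bias.data[keep].clone()
print(pruned(torch.randn(1, hidden)).shape)            # torch.Size([1, 4])
# --------------------------------------------------------------------------------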
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Vilt
class ViltIntermediate(nn.Module):
def __init__(self, config: ViltConfig) -> None:
super().__init__()
# Dense layer mapping hidden_size to intermediate_size
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
# Resolve the activation function: map from string if needed, otherwise use the callable directly
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# Project the hidden states up to intermediate_size
hidden_states = self.dense(hidden_states)
# Apply the intermediate activation function
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->Vilt
class ViltOutput(nn.Module):
def __init__(self, config: ViltConfig) -> None:
super().__init__()
# Dense layer mapping intermediate_size back to hidden_size
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
# Dropout to reduce overfitting
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
# Project the intermediate states back to hidden_size
hidden_states = self.dense(hidden_states)
# Apply dropout
hidden_states = self.dropout(hidden_states)
# Add the residual (input_tensor) to produce the final output
hidden_states = hidden_states + input_tensor
return hidden_states
class ViltLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config):
super().__init__()
# Chunk size for chunked feed-forward
self.chunk_size_feed_forward = config.chunk_size_feed_forward
# Index of the sequence-length dimension
self.seq_len_dim = 1
# Attention, intermediate and output sublayers
self.attention = ViltAttention(config)
self.intermediate = ViltIntermediate(config)
self.output = ViltOutput(config)
# In ViLT, layernorm is applied before self-attention
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# In ViLT, layernorm is also applied after self-attention
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
# Normalize the hidden states first, then run self-attention
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states),
attention_mask,
head_mask,
output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:]  # add the self-attention weights if they were requested
# First residual connection: add the attention output to the original hidden states
hidden_states = attention_output + hidden_states.to(attention_output.device)
# In ViLT, layernorm is also applied after the first residual
layer_output = self.layernorm_after(hidden_states)
# Feed-forward (intermediate) sublayer
layer_output = self.intermediate(layer_output)
# Second residual connection is applied inside the output sublayer
layer_output = self.output(layer_output, hidden_states)
# Prepend the final layer output to the outputs
outputs = (layer_output,) + outputs
return outputs
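# --- Illustrative sketch (not part of the original file) -----------------------
# The pre-LayerNorm block structure used by ViltLayer, written out linearly with
# stand-in modules for the attention and MLP sublayers:
import torch
from torch import nn

hidden = 8
ln1, ln2 = nn.LayerNorm(hidden), nn.LayerNorm(hidden)
attn = nn.Linear(hidden, hidden)  # stand-in for the self-attention sublayer
mlp = nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Linear(4 * hidden, hidden))

x = torch.randn(2, 5, hidden)
x = x + attn(ln1(x))  # residual 1: norm -> attention -> add
x = x + mlp(ln2(x))   # residual 2: norm -> MLP -> add
# --------------------------------------------------------------------------------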
class ViltEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
# Stack of ViltLayer modules
self.layer = nn.ModuleList([ViltLayer(config) for _ in range(config.num_hidden_layers)])
# Gradient checkpointing is disabled by default
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
# Collect hidden states only when requested, otherwise keep None
all_hidden_states = () if output_hidden_states else None
# Collect attention weights only when requested, otherwise keep None
all_self_attentions = () if output_attentions else None
# Iterate over the Transformer layers
for i, layer_module in enumerate(self.layer):
# If hidden states are requested, record the current hidden states
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# Per-layer head mask, or None if none was provided
layer_head_mask = head_mask[i] if head_mask is not None else None
# Use gradient checkpointing when enabled and in training mode
if self.gradient_checkpointing and self.training:
# Call the layer through the gradient-checkpointing helper
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
output_attentions,
)
else:
# Otherwise call the layer directly
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
# The first element of the layer output is the new hidden states
hidden_states = layer_outputs[0]
# Record the attention weights if requested
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
# Record the final hidden states if requested
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# Without return_dict, return a tuple of the non-None values
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
# Otherwise wrap the results in a BaseModelOutput
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
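# --- Illustrative sketch (not part of the original file) -----------------------
# What `gradient_checkpointing` buys in the encoder loop above: activations of
# the wrapped call are not stored and get recomputed during backward, trading
# compute for memory. Minimal torch.utils.checkpoint demo (use_reentrant=False
# requires a reasonably recent PyTorch):
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

layer = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))
x = torch.randn(4, 8, requires_grad=True)
y = checkpoint(layer, x, use_reentrant=False)  # forward without caching intermediates
y.sum().backward()                             # intermediates are recomputed here
# --------------------------------------------------------------------------------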
class ViltPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
# Configuration class for this model
config_class = ViltConfig
# Prefix for the base model
base_model_prefix = "vilt"
# Gradient checkpointing is supported
supports_gradient_checkpointing = True
# Modules that should not be split across devices
_no_split_modules = ["ViltEmbeddings", "ViltSelfAttention"]
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Linear and convolution weights: normal initialization
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
# Biases are initialized to zero
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
# Embedding weights: normal initialization
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
# Zero out the padding row, if any
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
# LayerNorm: zero bias, unit weight
module.bias.data.zero_()
module.weight.data.fill_(1.0)
# Start docstring for the ViLT model
VILT_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html>`_ subclass. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`ViltConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Inputs docstring for the ViLT model (blank)
VILT_INPUTS_DOCSTRING = r"""
"""
# Inputs docstring for images-and-text classification (blank)
VILT_IMAGES_AND_TEXT_CLASSIFICATION_INPUTS_DOCSTRING = r"""
"""
# Attach the start docstring to ViltModel
@add_start_docstrings(
"The bare ViLT Model transformer outputting raw hidden-states without any specific head on top.",
VILT_START_DOCSTRING,
)
class ViltModel(ViltPreTrainedModel):
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
# Embeddings and encoder
self.embeddings = ViltEmbeddings(config)
self.encoder = ViltEncoder(config)
# LayerNorm applied to the encoder output
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# Optional pooling head
self.pooler = ViltPooler(config) if add_pooling_layer else None
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embeddings.text_embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.text_embeddings.word_embeddings = value
# Prune attention heads of the model
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
# Iterate over the layers and the heads to prune in each of them
for layer, heads in heads_to_prune.items():
# Prune the heads of the corresponding encoder layer
self.encoder.layer[layer].attention.prune_heads(heads)
# Attach the inputs docstring to the forward method
@add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING)
# Replace the return section of the docstring
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
# Forward pass of the model
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
image_embeds: Optional[torch.FloatTensor] = None,
image_token_type_idx: Optional[int] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
class ViltPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
# Dense layer mapping hidden_size to hidden_size
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
# Resolve the activation function: map from string if needed, otherwise use the callable directly
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
# LayerNorm over the hidden states, with the configured epsilon
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
# Linear transformation
hidden_states = self.dense(hidden_states)
# Non-linear activation
hidden_states = self.transform_act_fn(hidden_states)
# LayerNorm
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class ViltMLMHead(nn.Module):
# Constructor for the masked language modeling head
def __init__(self, config, weight=None):
# Call the parent class constructor
super().__init__()
# Keep a reference to the config
self.config = config
# Transform applied to the hidden states before decoding
self.transform = ViltPredictionHeadTransform(config)
# Decoder: a bias-free linear layer mapping hidden_size to vocab_size
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Learnable bias of size vocab_size
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# If a weight tensor is given (e.g. tied word embeddings), use it as the decoder weight
if weight is not None:
self.decoder.weight = weight
# Link the two variables so the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
# Forward pass: hidden states -> vocabulary logits
def forward(self, x):
# Apply the prediction head transform
x = self.transform(x)
# Decode to vocabulary logits
x = self.decoder(x)
# Return the logits
return x
@add_start_docstrings(
"""
Vilt Model transformer with a classifier head on top (a linear layer on top of the final hidden state of the [CLS]
token) for visual question answering, e.g. for VQAv2.
""",
VILT_START_DOCSTRING,
)
class ViltForQuestionAnswering(ViltPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.vilt = ViltModel(config)
# Classifier head
self.classifier = nn.Sequential(
nn.Linear(config.hidden_size, config.hidden_size * 2), # Linear layer to expand hidden size
nn.LayerNorm(config.hidden_size * 2), # Layer normalization
nn.GELU(), # GELU activation function
nn.Linear(config.hidden_size * 2, config.num_labels), # Final linear layer for classification
)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
image_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
r"""
labels (`torch.FloatTensor` of shape `(batch_size, num_labels)`, *optional*):
Labels for computing the visual question answering loss. This tensor must be either a one-hot encoding of
all answers that are applicable for a given example in the batch, or a soft encoding indicating which
answers are applicable, where 1.0 is the highest score.
Returns:
Depending on `return_dict`, returns either a `SequenceClassifierOutput` or a tuple containing logits and optionally other outputs.
Examples:
```
>>> from transformers import ViltProcessor, ViltForQuestionAnswering
>>> import requests
>>> from PIL import Image
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = "How many cats are there?"
>>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
>>> model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
>>> # prepare inputs
>>> encoding = processor(image, text, return_tensors="pt")
>>> # forward pass
>>> outputs = model(**encoding)
>>> logits = outputs.logits
>>> idx = logits.argmax(-1).item()
>>> print("Predicted answer:", model.config.id2label[idx])
Predicted answer: 2
```"""
# Determine whether to use the return_dict provided or the class attribute for return settings
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Perform forward pass through the VILT model
outputs = self.vilt(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
pixel_values=pixel_values,
pixel_mask=pixel_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
image_embeds=image_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Extract the pooler_output from the outputs based on return_dict setting
pooler_output = outputs.pooler_output if return_dict else outputs[1]
# Pass the pooler_output through the classifier layer to obtain logits
logits = self.classifier(pooler_output)
# Initialize loss variable
loss = None
# Calculate loss if labels are provided
if labels is not None:
# Move labels tensor to the same device as logits for compatibility
labels = labels.to(logits.device)
# Binary cross-entropy with soft targets, scaled by the number of answer classes
# so the per-element mean becomes a per-example sum over classes (the usual VQA convention)
loss = nn.functional.binary_cross_entropy_with_logits(logits, labels) * labels.shape[1]
# Prepare output based on return_dict flag
if not return_dict:
# If return_dict is False, prepare tuple output
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
# If return_dict is True, prepare SequenceClassifierOutput object
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
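# --- Illustrative sketch (not part of the original file) -----------------------
# Why the VQA loss above multiplies the mean BCE by `labels.shape[1]`: with soft
# multi-label targets, scaling the per-element mean by the number of answer
# classes equals averaging a per-example *sum* over classes, a convention common
# in VQA codebases. Numeric check on toy tensors:
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)   # 4 examples, 10 candidate answers
labels = torch.zeros(4, 10)
labels[torch.arange(4), torch.randint(0, 10, (4,))] = 1.0

scaled_mean = F.binary_cross_entropy_with_logits(logits, labels) * labels.shape[1]
per_example_sum = F.binary_cross_entropy_with_logits(logits, labels, reduction="none").sum(-1).mean()
assert torch.allclose(scaled_mean, per_example_sum)
# --------------------------------------------------------------------------------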
# The decorator below attaches the class docstring: a ViLT model with a classifier head for image-to-text or text-to-image retrieval
@add_start_docstrings(
"""
Vilt Model transformer with a classifier head on top (a linear layer on top of the final hidden state of the [CLS]
token) for image-to-text or text-to-image retrieval, e.g. MSCOCO and F30K.
""",
VILT_START_DOCSTRING,
)
# ViltForImageAndTextRetrieval, a subclass of ViltPreTrainedModel
class ViltForImageAndTextRetrieval(ViltPreTrainedModel):
# Constructor taking a config object
def __init__(self, config):
# Call the parent class constructor
super().__init__(config)
# Backbone ViLT model
self.vilt = ViltModel(config)
# Classifier head: a linear layer mapping the final [CLS] hidden state to a single score
self.rank_output = nn.Linear(config.hidden_size, 1)
# Initialize weights and apply final processing
self.post_init()
# Attach the inputs docstring to the forward method
@add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING)
# Replace the return section of the docstring with SequenceClassifierOutput, using _CONFIG_FOR_DOC
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
# Forward pass of the model
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
image_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels are currently not supported.
Returns:
Depending on `return_dict` flag:
- If `return_dict` is False, returns a tuple containing `logits` and additional outputs.
- If `return_dict` is True, returns a `SequenceClassifierOutput` object containing `loss`, `logits`, `hidden_states`, and `attentions`.
Examples:
```
>>> from transformers import ViltProcessor, ViltForImageAndTextRetrieval
>>> import requests
>>> from PIL import Image
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> texts = ["An image of two cats chilling on a couch", "A football player scoring a goal"]
>>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-coco")
>>> model = ViltForImageAndTextRetrieval.from_pretrained("dandelin/vilt-b32-finetuned-coco")
>>> # forward pass
>>> scores = dict()
>>> for text in texts:
... # prepare inputs
... encoding = processor(image, text, return_tensors="pt")
... outputs = model(**encoding)
... scores[text] = outputs.logits[0, :].item()
```
"""
# Determine whether to use the return_dict flag or the model's default configuration
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Perform the forward pass through the VILT model
outputs = self.vilt(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
pixel_values=pixel_values,
pixel_mask=pixel_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
image_embeds=image_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Select the pooler output based on whether return_dict is True or False
pooler_output = outputs.pooler_output if return_dict else outputs[1]
# Generate logits using the rank_output method
logits = self.rank_output(pooler_output)
# Initialize loss as None
loss = None
# Handle labels if provided (currently raises NotImplementedError)
if labels is not None:
# Move labels to the device where logits are located
labels = labels.to(logits.device)
raise NotImplementedError("Training is not yet supported.")
# Return the output based on whether return_dict is True or False
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
# Return a SequenceClassifierOutput object containing loss, logits, hidden_states, and attentions
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
Vilt Model transformer with a classifier head on top for natural language visual reasoning, e.g. NLVR2.
""",
VILT_IMAGES_AND_TEXT_CLASSIFICATION_INPUTS_DOCSTRING,
)
class ViltForImagesAndTextClassification(ViltPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.vilt = ViltModel(config)
# Classifier head
num_images = config.num_images
# Classifier: linear layers with LayerNorm and a GELU activation
self.classifier = nn.Sequential(
nn.Linear(config.hidden_size * num_images, config.hidden_size * num_images),
nn.LayerNorm(config.hidden_size * num_images),
nn.GELU(),
nn.Linear(config.hidden_size * num_images, config.num_labels),
)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=ViltForImagesAndTextClassificationOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
image_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@add_start_docstrings(
"""
ViLT Model with a token classification head on top (a linear layer on top of the final hidden-states of the text
tokens) e.g. for Named-Entity-Recognition (NER) tasks.
""",
VILT_START_DOCSTRING,
)
class ViltForTokenClassification(ViltPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
# ViLT backbone without the pooling layer
self.vilt = ViltModel(config, add_pooling_layer=False)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# Classifier: a linear layer with config.num_labels outputs
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
image_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[TokenClassifierOutput, Tuple[torch.FloatTensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
Returns:
Either a `TokenClassifierOutput` containing loss, logits, hidden states, and attentions,
or a tuple with logits and optional hidden states and attentions.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Pass inputs to the VILT model for processing
outputs = self.vilt(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
pixel_values=pixel_values,
pixel_mask=pixel_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
image_embeds=image_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
# Determine the size of the text input
text_input_size = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
# Apply dropout to the sequence output
sequence_output = self.dropout(sequence_output)
# Classify tokens using the classifier layer
logits = self.classifier(sequence_output[:, :text_input_size])
loss = None
if labels is not None:
# Calculate the cross-entropy loss
loss_fct = CrossEntropyLoss()
# Move labels to the same device as logits
labels = labels.to(logits.device)
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
# If return_dict is False, return a tuple of logits and optionally hidden states and attentions
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
# If return_dict is True, return a TokenClassifierOutput object
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
.\models\vilt\processing_vilt.py
"""
Processor class for ViLT.
"""
import warnings
from typing import List, Optional, Union
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType
class ViltProcessor(ProcessorMixin):
r"""
Constructs a ViLT processor which wraps a BERT tokenizer and ViLT image processor into a single processor.
[`ViltProcessor`] offers all the functionalities of [`ViltImageProcessor`] and [`BertTokenizerFast`]. See the
docstring of [`~ViltProcessor.__call__`] and [`~ViltProcessor.decode`] for more information.
Args:
image_processor (`ViltImageProcessor`, *optional*):
An instance of [`ViltImageProcessor`]. The image processor is a required input.
tokenizer (`BertTokenizerFast`, *optional*):
An instance of ['BertTokenizerFast`]. The tokenizer is a required input.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "ViltImageProcessor"
tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
def __init__(self, image_processor=None, tokenizer=None, **kwargs):
feature_extractor = None
if "feature_extractor" in kwargs:
warnings.warn(
"The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
" instead.",
FutureWarning,
)
feature_extractor = kwargs.pop("feature_extractor")
image_processor = image_processor if image_processor is not None else feature_extractor
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
if tokenizer is None:
raise ValueError("You need to specify a `tokenizer`.")
super().__init__(image_processor, tokenizer)
self.current_processor = self.image_processor
def __call__(
self,
images,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
) -> BatchEncoding:
"""
This method uses [`ViltImageProcessor.__call__`] method to prepare image(s) for the model, and
[`BertTokenizerFast.__call__`] to prepare text for the model.
Please refer to the docstring of the above two methods for more information.
"""
encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
**kwargs,
)
encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)
encoding.update(encoding_image_processor)
return encoding
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
@property
def feature_extractor_class(self):
warnings.warn(
"`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
FutureWarning,
)
return self.image_processor_class
@property
def feature_extractor(self):
warnings.warn(
"`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.",
FutureWarning,
)
return self.image_processor
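# --- Illustrative usage sketch (not part of the original file) -----------------
# The processor simply merges the tokenizer output (input_ids, token_type_ids,
# attention_mask) with the image processor output (pixel_values, pixel_mask).
# Assumes an internet connection and the "dandelin/vilt-b32-mlm" checkpoint:
import requests
from PIL import Image
from transformers import ViltProcessor

processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

encoding = processor(image, "a photo of two cats", return_tensors="pt")
print(sorted(encoding.keys()))
# expected with the default settings:
# ['attention_mask', 'input_ids', 'pixel_mask', 'pixel_values', 'token_type_ids']
# --------------------------------------------------------------------------------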
.\models\vilt\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {"configuration_vilt": ["VILT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViltConfig"]}
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["feature_extraction_vilt"] = ["ViltFeatureExtractor"]
_import_structure["image_processing_vilt"] = ["ViltImageProcessor"]
_import_structure["processing_vilt"] = ["ViltProcessor"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_vilt"] = [
"VILT_PRETRAINED_MODEL_ARCHIVE_LIST",
"ViltForImageAndTextRetrieval",
"ViltForImagesAndTextClassification",
"ViltForTokenClassification",
"ViltForMaskedLM",
"ViltForQuestionAnswering",
"ViltLayer",
"ViltModel",
"ViltPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_vilt import VILT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViltConfig
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .feature_extraction_vilt import ViltFeatureExtractor
from .image_processing_vilt import ViltImageProcessor
from .processing_vilt import ViltProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_vilt import (
VILT_PRETRAINED_MODEL_ARCHIVE_LIST,
ViltForImageAndTextRetrieval,
ViltForImagesAndTextClassification,
ViltForMaskedLM,
ViltForQuestionAnswering,
ViltForTokenClassification,
ViltLayer,
ViltModel,
ViltPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
.\models\vipllava\configuration_vipllava.py
""" VipLlava model configuration"""
import warnings
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
VIPLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"ybelkada/vip-llava-7b-hf": "https://huggingface.co/llava-hf/vip-llava-7b-hf/resolve/main/config.json",
}
class VipLlavaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`VipLlavaForConditionalGeneration`]. It is used to instantiate an
VipLlava model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the VipLlava-9B.
e.g. [ybelkada/vip-llava-7b-hf](https://huggingface.co/ybelkada/vip-llava-7b-hf)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vision_config (`VipLlavaVisionConfig`, *optional*):
Custom vision config or dict
text_config (`Union[AutoConfig, dict]`, *optional*):
The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 32000):
The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
The activation function used by the multimodal projector.
projector_layernorm_eps (`float`, *optional*, defaults to 1e-05):
The layer norm epsilon of the projector layernorm
vision_feature_layers (`List[int]`, *optional*, defaults to `[-2, -5, -8, -11, 6]`):
The list of layers to select the vision features from.
Example:
```
>>> from transformers import VipLlavaForConditionalGeneration, VipLlavaConfig, CLIPVisionConfig, LlamaConfig
>>> # Initializing a CLIP-vision config
>>> vision_config = CLIPVisionConfig()
>>> # Initializing a Llama config
>>> text_config = LlamaConfig()
>>> configuration = VipLlavaConfig(vision_config, text_config)
>>> model = VipLlavaForConditionalGeneration(configuration)
>>> configuration = model.config
```"""
.\models\vipllava\convert_vipllava_weights_to_hf.py
import argparse
import torch
from huggingface_hub import hf_hub_download
from transformers import (
AddedToken,
AutoConfig,
AutoTokenizer,
CLIPImageProcessor,
LlavaProcessor,
VipLlavaConfig,
VipLlavaForConditionalGeneration,
)
KEYS_TO_MODIFY_MAPPING = {
"model.vision_tower.": "",
"model.mm_projector": "multi_modal_projector",
"model": "model.model",
"vision_model.model": "vision_model",
"lm_head": "language_model.lm_head",
"model.model": "language_model.model",
"multi_modal_projector.0": "multi_modal_projector.linear_1",
"multi_modal_projector.2": "multi_modal_projector.linear_2",
"final_linear.0": "linear_1",
"final_linear.2": "linear_2",
"multi_modal_projector.clip_layernorm": "multi_modal_projector.projector_layernorm",
}
def convert_state_dict_to_hf(state_dict):
new_state_dict = {}
for key, value in state_dict.items():
if key.endswith(".inv_freq"):
continue
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)
new_state_dict[key] = value
return new_state_dict
def convert_vipllava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id):
torch.set_default_dtype(torch.float16)
text_config = AutoConfig.from_pretrained(text_model_id)
tokenizer = AutoTokenizer.from_pretrained(text_model_id)
tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
tokenizer.add_special_tokens({"pad_token": "<pad>"})
image_processor = CLIPImageProcessor.from_pretrained(vision_model_id)
processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)
config = VipLlavaConfig(text_config=text_config)
config.pad_token_id = 32001
with torch.device("meta"):
model = VipLlavaForConditionalGeneration(config)
pad_shape = 64
state_dict_path = hf_hub_download(old_state_dict_id, "model_state_dict_7b.bin")
state_dict = torch.load(state_dict_path, map_location="cpu")
state_dict = convert_state_dict_to_hf(state_dict)
model.load_state_dict(state_dict, strict=True, assign=True)
pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data
mu = torch.mean(pre_expansion_embeddings, dim=0).float()
n = pre_expansion_embeddings.size()[0]
sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n
dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma)
model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape)
model.language_model.model.embed_tokens.weight.data[32000:] = torch.stack(
tuple((dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[32000:].shape[0]))),
dim=0,
)
model.language_model.lm_head.weight.data[32000:] = torch.stack(
tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[32000:].shape[0]))),
dim=0,
)
model.push_to_hub(output_hub_path)
processor.push_to_hub(output_hub_path)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--text_model_id",
help="Hub location of the text model",
)
parser.add_argument(
"--vision_model_id",
help="Hub location of the vision model",
)
parser.add_argument(
"--output_hub_path",
help="Location on the hub of the converted model",
)
parser.add_argument(
"--old_state_dict_id",
help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`",
)
args = parser.parse_args()
convert_vipllava_llama_to_hf(
args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id
)
if __name__ == "__main__":
main()
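# --- Illustrative sketch (not part of the original file) -----------------------
# The conversion script above initializes the embedding rows added by
# resize_token_embeddings by sampling from a multivariate normal fitted to the
# existing embedding matrix (its mean and a scaled covariance). The same idea on
# a toy matrix:
import torch

old = torch.randn(1000, 64)   # stand-in for the pre-expansion embedding matrix
mu = old.mean(dim=0)
centered = old - mu
sigma = centered.T @ centered / old.shape[0]
dist = torch.distributions.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma)
new_rows = dist.sample((2,))  # embeddings for 2 newly added tokens
print(new_rows.shape)         # torch.Size([2, 64])
# --------------------------------------------------------------------------------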
.\models\vipllava\modeling_vipllava.py
""" PyTorch VipLlava model."""
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ... import PreTrainedModel
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...modeling_outputs import ModelOutput
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from ..auto import AutoModel, AutoModelForCausalLM
from .configuration_vipllava import VipLlavaConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "VipLlavaConfig"
VIPLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = [
"llava-hf/vip-llava-7b-hf",
]
@dataclass
class VipLlavaCausalLMOutputWithPast(ModelOutput):
"""
Base class for VipLlava causal language model (or autoregressive) outputs.
"""
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
"""
# Optional loss value for language modeling
loss: Optional[torch.FloatTensor] = None
# Predicted logits for each token in the batch
logits: torch.FloatTensor = None
# Cached key and value states for speeding up sequential decoding
past_key_values: Optional[List[torch.FloatTensor]] = None
# Hidden states of the model at each layer's output and optional initial embeddings
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
# Attention weights after softmax, used for self-attention computation
attentions: Optional[Tuple[torch.FloatTensor]] = None
# Hidden states produced by the vision encoder for image embeddings
image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@add_start_docstrings(
"The bare VipLlava Model outputting raw hidden-states without any specific head on top.",
VIPLLAVA_START_DOCSTRING,
)
# The decorator above documents this as the bare VipLlava model outputting raw hidden states without any specific head on top
# VipLlavaPreTrainedModel inherits from PreTrainedModel
class VipLlavaPreTrainedModel(PreTrainedModel):
# Configuration class for this model
config_class = VipLlavaConfig
# Prefix for the base model
base_model_prefix = "model"
# Gradient checkpointing is supported
supports_gradient_checkpointing = True
# Modules that should not be split across devices
_no_split_modules = ["VipLlavaVisionAttention"]
# Keys to skip during device placement
_skip_keys_device_placement = "past_key_values"
# Flash Attention 2 is supported
_supports_flash_attn_2 = True
# Initialize the weights of a given module
def _init_weights(self, module):
# Note: this ported version of VipLlava is not meant for training from scratch, only for inference and fine-tuning,
# so the proper weight-initialization code has been removed. The original code base at
# https://github.com/haotian-liu/LLaVA/tree/main/vipllava can be used for training purposes.
# Read the initialization std from the config
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
# If the module has a class_embedding attribute, initialize it from a normal distribution
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
# Linear and Conv2d layers: normal initialization for the weights, zeros for the bias
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
# Embedding layers: normal initialization, zero out the padding row if there is one
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def _supports_sdpa(self):
"""
Retrieve the language model's attribute to check whether the model supports SDPA (scaled dot-product attention).
"""
return self.language_model._supports_sdpa
# Docstring describing the inputs of the VIPLLAVA model (blank)
VIPLLAVA_INPUTS_DOCSTRING = r"""
"""
@add_start_docstrings(
"""The VIPLLAVA model which consists of a vision backbone and a language model.""",
VIPLLAVA_START_DOCSTRING,
)
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration with LLAVA->VIPLLAVA, Llava->VipLlava
class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
def __init__(self, config: VipLlavaConfig):
super().__init__(config)
# Vision tower, built from the vision config
self.vision_tower = AutoModel.from_config(config.vision_config)
# Multi-modal projector
self.multi_modal_projector = VipLlavaMultiModalProjector(config)
# Vocabulary size taken from the text config
self.vocab_size = config.text_config.vocab_size
# Language model, built from the text config and the configured attention implementation
self.language_model = AutoModelForCausalLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
# Use pad_token_id from the config if defined, otherwise -1
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
# Post-initialization processing
self.post_init()
def get_input_embeddings(self):
# Input embeddings of the language model
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
# Set the input embeddings of the language model
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
# Output embeddings of the language model
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
# Set the output embeddings of the language model
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
# Set the decoder of the language model
self.language_model.set_decoder(decoder)
def get_decoder(self):
# Get the decoder of the language model
return self.language_model.get_decoder()
def tie_weights(self):
# Tie the weights of the language model
return self.language_model.tie_weights()
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
# Resize the language model's token embeddings and update the vocab size in the config
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
@add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
# Ignore copy
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layers: Optional[List[int]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
pass  # placeholder before the actual implementation; does nothing
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
):
# If past key values are passed, handle the cache-related logic
if past_key_values is not None:
# If the cache is a `Cache` instance, read the sequence length and the number of seen tokens from it
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
else:
# Otherwise assume a legacy cache and read the length from the shape of the first key tensor
cache_length = past_length = past_key_values[0][0].shape[2]
# Keep only the unprocessed tokens:
# 1 - If attention_mask is longer than input_ids, some inputs are passed exclusively as part of the cache
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If past_length is smaller than input_ids' length, input_ids holds all input tokens and the ones already covered by past_length can be discarded
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), assume input_ids only contains unprocessed tokens
elif self.config.image_token_index in input_ids:
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
# If the cache has seen more tokens than it can hold, it has a size limit: drop the earlier attention values, as their corresponding tokens are no longer part of the input
if cache_length < past_length and attention_mask is not None:
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
position_ids = kwargs.get("position_ids", None)
# If an attention_mask is given but position_ids is not, create position_ids on the fly for batched generation
if attention_mask is not None and position_ids is None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# If inputs_embeds are passed, only use them in the first generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
# Fill in the remaining generation inputs
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"pixel_values": pixel_values,
}
)
return model_inputs
# Cache reordering is delegated to the language model's _reorder_cache
def _reorder_cache(self, *args, **kwargs):
return self.language_model._reorder_cache(*args, **kwargs)
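As a usage sketch for the class above: `generate()` invokes `prepare_inputs_for_generation` at every decoding step, which is where the cache and attention-mask bookkeeping shown here is applied. The checkpoint name and prompt template below are assumptions for illustration only:
```
import requests
from PIL import Image
from transformers import AutoProcessor, VipLlavaForConditionalGeneration

checkpoint = "llava-hf/vip-llava-7b-hf"  # assumed checkpoint name
processor = AutoProcessor.from_pretrained(checkpoint)
model = VipLlavaForConditionalGeneration.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"  # assumed prompt template

inputs = processor(text=prompt, images=image, return_tensors="pt")
# Each call to generate() repeatedly runs prepare_inputs_for_generation above.
output_ids = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
```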
.\models\vipllava\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {"configuration_vipllava": ["VIPLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP", "VipLlavaConfig"]}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_vipllava"] = [
"VIPLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST",
"VipLlavaForConditionalGeneration",
"VipLlavaPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_vipllava import VIPLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP, VipLlavaConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_vipllava import (
VIPLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST,
VipLlavaForConditionalGeneration,
VipLlavaPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
.\models\vision_encoder_decoder\configuration_vision_encoder_decoder.py
from typing import TYPE_CHECKING, Any, Mapping, Optional, OrderedDict
from packaging import version
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
if TYPE_CHECKING:
from ... import PreTrainedTokenizerBase, TensorType
logger = logging.get_logger(__name__)
class VisionEncoderDecoderConfig(PretrainedConfig):
r"""
[`VisionEncoderDecoderConfig`] is the configuration class to store the configuration of a
[`VisionEncoderDecoderModel`]. It is used to instantiate a Vision-Encoder-Text-Decoder model according to the
specified arguments, defining the encoder and decoder configs.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
kwargs (*optional*):
Dictionary of keyword arguments. Notably:
- **encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines the encoder config.
- **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines the decoder config.
Examples:
```
>>> from transformers import BertConfig, ViTConfig, VisionEncoderDecoderConfig, VisionEncoderDecoderModel
>>> # Initializing ViT- and BERT-style configurations
>>> config_encoder = ViTConfig()
>>> config_decoder = BertConfig()
>>> config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
>>> # Initializing a ViTBert model (with random weights) from ViT and google-bert/bert-base-uncased style configurations
>>> model = VisionEncoderDecoderModel(config=config)
>>> # Accessing the model configuration
>>> config_encoder = model.config.encoder
>>> config_decoder = model.config.decoder
>>> # Setting the decoder config to a causal lm
>>> config_decoder.is_decoder = True
>>> config_decoder.add_cross_attention = True
>>> # Saving the model, including its configuration
>>> model.save_pretrained("my-model")
>>> # Loading the model and config from the pretrained folder
>>> encoder_decoder_config = VisionEncoderDecoderConfig.from_pretrained("my-model")
>>> model = VisionEncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config)
```
"""
model_type = "vision-encoder-decoder"
is_composition = True
def __init__(self, **kwargs):
super().__init__(**kwargs)
if "encoder" not in kwargs or "decoder" not in kwargs:
raise ValueError(
f"A configuraton of type {self.model_type} cannot be instantiated because "
f"not both `encoder` and `decoder` sub-configurations are passed, but only {kwargs}"
)
encoder_config = kwargs.pop("encoder")
encoder_model_type = encoder_config.pop("model_type")
decoder_config = kwargs.pop("decoder")
decoder_model_type = decoder_config.pop("model_type")
self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config)
self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config)
self.is_encoder_decoder = True
@classmethod
def from_encoder_decoder_configs(
cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
) -> PretrainedConfig:
"""
Instantiate a [`VisionEncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model
configuration and a decoder model configuration.
Returns:
[`VisionEncoderDecoderConfig`]: An instance of a configuration object
"""
logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
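A small sketch of what `from_encoder_decoder_configs` produces beyond merging the two config dicts: it enables the decoder flags needed for cross-attention and marks the result as an encoder-decoder config (config classes as in the example above):
```
from transformers import BertConfig, ViTConfig, VisionEncoderDecoderConfig

config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(ViTConfig(), BertConfig())
assert config.is_encoder_decoder
assert config.decoder.is_decoder and config.decoder.add_cross_attention
print(config.encoder.model_type, config.decoder.model_type)  # vit bert
```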
class VisionEncoderDecoderEncoderOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
]
)
@property
def atol_for_validation(self) -> float:
return 1e-4
@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict({"last_hidden_state": {0: "batch", 1: "encoder_sequence"}})
class VisionEncoderDecoderDecoderOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = OrderedDict()
common_inputs["input_ids"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
common_inputs["attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
common_inputs["encoder_hidden_states"] = {0: "batch", 1: "encoder_sequence"}
return common_inputs
def generate_dummy_inputs(
self,
tokenizer: "PreTrainedTokenizerBase",
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional["TensorType"] = None,
) -> Mapping[str, Any]:
import torch
common_inputs = OrderedDict()
dummy_input = super().generate_dummy_inputs(
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
)
batch, encoder_sequence = dummy_input["input_ids"].shape
encoder_hidden_states_shape = (batch, encoder_sequence, self._config.encoder_hidden_size)
common_inputs["input_ids"] = dummy_input.pop("input_ids")
common_inputs["attention_mask"] = dummy_input.pop("attention_mask")
common_inputs["encoder_hidden_states"] = torch.zeros(encoder_hidden_states_shape)
return common_inputs
class VisionEncoderDecoderOnnxConfig(OnnxConfig):
@property
def inputs(self) -> None:
pass
def get_encoder_config(self, encoder_config: PretrainedConfig) -> OnnxConfig:
r"""
Returns the ONNX encoder config for the `VisionEncoderDecoder` model.
Args:
encoder_config (`PretrainedConfig`):
The encoder model's configuration to use when exporting to ONNX.
Returns:
[`VisionEncoderDecoderEncoderOnnxConfig`]: An instance of the ONNX configuration object
"""
return VisionEncoderDecoderEncoderOnnxConfig(encoder_config)
def get_decoder_config(
self, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, feature: str = "default"
) -> OnnxConfig:
r"""
Returns ONNX decoder config for `VisionEncoderDecoder` model.
Args:
encoder_config (`PretrainedConfig`):
The encoder model's configuration to use when exporting to ONNX.
decoder_config (`PretrainedConfig`):
The decoder model's configuration to use when exporting to ONNX
feature (`str`, *optional*):
The type of feature to export the model with.
Returns:
[`VisionEncoderDecoderDecoderOnnxConfig`]: An instance of the ONNX configuration object.
"""
decoder_config.encoder_hidden_size = encoder_config.hidden_size
return VisionEncoderDecoderDecoderOnnxConfig(decoder_config, feature)
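A hedged sketch of how the two helpers above might be combined when exporting both halves of a model; it assumes `model` is an already-loaded `VisionEncoderDecoderModel` and only uses the methods shown in this file:
```
# Sketch only: derive per-component ONNX configs from a loaded model's config.
onnx_config = VisionEncoderDecoderOnnxConfig(model.config)
encoder_onnx_config = onnx_config.get_encoder_config(model.config.encoder)
decoder_onnx_config = onnx_config.get_decoder_config(model.config.encoder, model.config.decoder)

print(encoder_onnx_config.inputs)   # pixel_values with batch/num_channels/height/width axes
print(decoder_onnx_config.inputs)   # input_ids, attention_mask, encoder_hidden_states axes
```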
.\models\vision_encoder_decoder\modeling_flax_vision_encoder_decoder.py
import os
from typing import Optional, Tuple, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
from ...modeling_flax_utils import FlaxPreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ..auto.configuration_auto import AutoConfig
from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "VisionEncoderDecoderConfig"
VISION_ENCODER_DECODER_START_DOCSTRING = r"""
This class can be used to initialize an image-to-text-sequence model with any pretrained vision autoencoding model
as the encoder and any pretrained text autoregressive model as the decoder. The encoder is loaded via
[`~AutoModel.from_pretrained`] function and the decoder is loaded via [`~AutoModelForCausalLM.from_pretrained`]
function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream
generative task, like image captioning.
The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi
Zhou, Wei Li, Peter J. Liu.
Additionally, in [TrOCR: Transformer-based Optical Character Recognition with Pre-trained
Models](https://arxiv.org/abs/2109.10282) it is shown how leveraging large pretrained vision models for optical
character recognition (OCR) yields a significant performance improvement.
After such a Vision-Encoder-Text-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any
other models (see the examples for more information).
"""
"""
VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using the vision model's image processor. For example, using
[`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] for details.
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are decoder input IDs?](../glossary#decoder-input-ids)
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
decoder_position_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.decoder.max_position_embeddings - 1]`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple.
"""
VISION_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using the vision model's image processor. For example, using
[`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple.
"""
VISION_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
Args:
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`):
Indices of decoder input sequence tokens in the vocabulary. These tokens are generated by the model during
decoding based on the provided `decoder_start_token_id`.
decoder_start_token_id (`int`):
The id of the token to start decoding with. This is usually the beginning-of-sequence token.
encoder_outputs (`Union[FlaxBaseModelOutput, Tuple[jnp.ndarray]]`):
Tuple comprising various elements depending on the configuration and inputs: logits as a jnp.ndarray of
shape `(batch_size, sequence_length, vocab_size)`, hidden_states as a tuple of length `num_layers` with
each element being a jnp.ndarray of shape `(batch_size, sequence_length, hidden_size)`, attentions as a
tuple of length `num_layers` with each element being a jnp.ndarray of shape `(batch_size, num_heads,
sequence_length, sequence_length)`, and others.
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
1's in positions corresponding to input tokens to ignore and 0's in positions corresponding to input tokens
to attend to. It's used to mask pad tokens in input sentences. It's also used to indicate the position of
input tokens.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple.
"""
# The following documents the decode arguments and their types:
Args:
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
[What are decoder input IDs?](../glossary#decoder-input-ids)
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
For sequence-to-sequence training, `decoder_input_ids` should be provided. If no `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right for denoising pre-training.
encoder_outputs (`tuple(tuple(jnp.ndarray)`):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, is a sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default.
decoder_position_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence token in the position embeddings. Selected in the range `[0, config.decoder.max_position_embeddings - 1]`.
past_key_values (`Dict[str, jnp.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail.
return_dict (`bool`, *optional*):
If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a plain tuple.
"""
class FlaxVisionEncoderDecoderModule(nn.Module):
config: VisionEncoderDecoderConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
encoder_config = self.config.encoder
decoder_config = self.config.decoder
from ...models.auto.modeling_flax_auto import FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_MAPPING
encoder_module = FLAX_MODEL_MAPPING[encoder_config.__class__].module_class
decoder_module = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING[decoder_config.__class__].module_class
self.encoder = encoder_module(encoder_config, dtype=self.dtype)
self.decoder = decoder_module(decoder_config, dtype=self.dtype)
if (
self.encoder.config.hidden_size != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
):
self.enc_to_dec_proj = nn.Dense(
self.decoder.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
dtype=self.dtype,
)
else:
self.enc_to_dec_proj = None
def _get_encoder_module(self):
return self.encoder
def _get_projection_module(self):
return self.enc_to_dec_proj
def _get_decoder_module(self):
return self.decoder
def __call__(
self,
pixel_values,
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
encoder_outputs = self.encoder(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
encoder_hidden_states = encoder_outputs[0]
if self.enc_to_dec_proj is not None:
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
batch_size, sequence_length = encoder_hidden_states.shape[:2]
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
if not return_dict:
return decoder_outputs + encoder_outputs
return FlaxSeq2SeqLMOutput(
logits=decoder_outputs.logits,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
@add_start_docstrings(VISION_ENCODER_DECODER_START_DOCSTRING)
class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
r"""
[`FlaxVisionEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture
with the module (flax.nn.Module) of one of the base vision model classes of the library as encoder module and
another one as decoder module when created with the :meth*~transformers.FlaxAutoModel.from_pretrained* class method
for the encoder and :meth*~transformers.FlaxAutoModelForCausalLM.from_pretrained* class method for the decoder.
"""
config_class = VisionEncoderDecoderConfig
base_model_prefix = "vision_encoder_decoder"
main_input_name = "pixel_values"
module_class = FlaxVisionEncoderDecoderModule
def __init__(
self,
config: VisionEncoderDecoderConfig,
input_shape: Optional[Tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
if not _do_init:
raise ValueError(
"`FlaxVisionEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
)
if input_shape is None:
num_channels = getattr(config.encoder, "num_channels", 3)
input_shape = (
(1, config.encoder.image_size, config.encoder.image_size, num_channels),
(1, 1),
)
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
"If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
" `config.encoder.hidden_size`."
)
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
encoder_input_shape, decoder_input_shape = input_shape
pixel_values = jnp.zeros(encoder_input_shape, dtype=self.dtype)
decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
batch_size, _, _, _ = pixel_values.shape
decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
if not decoder_batch_size == batch_size:
raise ValueError(
f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder "
f"and {decoder_batch_size} for decoder."
)
decoder_position_ids = jnp.broadcast_to(
jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
random_params = self.module.init(
rngs,
pixel_values,
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
@add_start_docstrings(VISION_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
def encode(
self,
pixel_values: jnp.ndarray,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
r"""
Args:
pixel_values (`jnp.ndarray`):
Pixel values of the input images. A tensor of shape `(batch_size, channels, height, width)`.
output_attentions (`Optional[bool]`, optional):
Whether to return attentions weights. Defaults to `None`.
output_hidden_states (`Optional[bool]`, optional):
Whether to return hidden states. Defaults to `None`.
return_dict (`Optional[bool]`, optional):
Whether to return a dictionary instead of a tuple of outputs. Defaults to `None`.
train (`bool`, optional):
Whether in training mode. Defaults to `False`.
params (`dict`, optional):
Optional parameters for the encoding process. Defaults to `None`.
dropout_rng (`PRNGKey`, optional):
Random number generator key for dropout. Defaults to `None`.
"""
):
r"""
Returns:
Example:
```
>>> from transformers import AutoImageProcessor, FlaxVisionEncoderDecoderModel
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> # initialize a vit-gpt2 from pretrained ViT and GPT2 models. Note that the cross-attention layers will be randomly initialized
>>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
... "google/vit-base-patch16-224-in21k", "openai-community/gpt2"
... )
>>> pixel_values = image_processor(images=image, return_tensors="np").pixel_values
>>> encoder_outputs = model.encode(pixel_values)
```
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
def _encoder_forward(module, pixel_values, **kwargs):
encode_module = module._get_encoder_module()
return encode_module(pixel_values, **kwargs)
outputs = self.module.apply(
{"params": params or self.params},
pixel_values=jnp.array(pixel_values, dtype=self.dtype),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
method=_encoder_forward,
)
if return_dict:
outputs = FlaxBaseModelOutput(
last_hidden_state=outputs.last_hidden_state,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return outputs
@add_start_docstrings(VISION_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def decode(
self,
decoder_input_ids,
encoder_outputs,
decoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
pass
@add_start_docstrings_to_model_forward(VISION_ENCODER_DECODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def __call__(
self,
pixel_values: jnp.ndarray,
decoder_input_ids: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
pass
def prepare_inputs_for_generation(
self,
decoder_input_ids,
max_length,
decoder_attention_mask: Optional[jax.Array] = None,
encoder_outputs=None,
**kwargs,
):
batch_size, seq_length = decoder_input_ids.shape
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if decoder_attention_mask is not None:
decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
else:
decoder_position_ids = jnp.broadcast_to(
jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
)
return {
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"decoder_attention_mask": extended_attention_mask,
"decoder_position_ids": decoder_position_ids,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
model_kwargs["past_key_values"] = model_outputs.past_key_values
model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
return model_kwargs
@classmethod
def from_encoder_decoder_pretrained(
cls,
encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
*model_args,
**kwargs,
):
pass  # the rest of this method is omitted in this excerpt
.\models\vision_encoder_decoder\modeling_tf_vision_encoder_decoder.py
import re
from typing import Optional, Tuple, Union
import tensorflow as tf
from ...configuration_utils import PretrainedConfig
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
from ...modeling_tf_utils import TFCausalLanguageModelingLoss, TFPreTrainedModel, get_initializer, keras, unpack_inputs
from ...tf_utils import shape_list
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from ..auto.configuration_auto import AutoConfig
from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "VisionEncoderDecoderConfig"
DEPRECATION_WARNING = (
"Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the"
" encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if"
" fine-tuning a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the"
" labels, no need to pass them yourself anymore."
)
VISION_ENCODER_DECODER_START_DOCSTRING = r"""
This class can be used to initialize an image-to-text-sequence model with any pretrained vision autoencoding model
as the encoder and any pretrained text autoregressive model as the decoder. The encoder is loaded via
[`~TFAutoModel.from_pretrained`] function and the decoder is loaded via [`~TFAutoModelForCausalLM.from_pretrained`]
function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream
generative task, like image captioning.
The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi
Zhou, Wei Li, Peter J. Liu.
Additionally, in [TrOCR: Transformer-based Optical Character Recognition with Pre-trained
Models](https://arxiv.org/abs/2109.10282) it is shown how leveraging large pretrained vision models for optical
character recognition (OCR) yields a significant performance improvement.
After such a Vision-Encoder-Text-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any
other models (see the examples for more information).
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it as a
regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and behavior.
Parameters:
config ([`VisionEncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
"""
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
if pad_token_id is None:
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
if decoder_start_token_id is None:
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids
@add_start_docstrings(VISION_ENCODER_DECODER_START_DOCSTRING)
class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
r"""
[`TFVisionEncoderDecoderModel`] 是一个通用模型类,当使用 [`~TFAutoModel.from_pretrained`] 类方法为编码器创建一个基本视觉模型类,
并使用 [`~TFAutoModelForCausalLM.from_pretrained`] 类方法为解码器创建另一个基本模型类时,它将被实例化为一个转换器架构。
"""
config_class = VisionEncoderDecoderConfig
base_model_prefix = "vision_encoder_decoder"
load_weight_prefix = "tf_vision_encoder_decoder_model"
main_input_name = "pixel_values"
def __init__(
self,
config: Optional[PretrainedConfig] = None,
encoder: Optional[TFPreTrainedModel] = None,
decoder: Optional[TFPreTrainedModel] = None,
):
if config is None and (encoder is None or decoder is None):
raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
if config is None:
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
else:
if not isinstance(config, self.config_class):
raise ValueError(f"config: {config} has to be of type {self.config_class}")
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
"If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
" `config.encoder.hidden_size`."
)
super().__init__(config)
if encoder is None:
encoder = TFAutoModel.from_config(config.encoder, name="encoder")
if decoder is None:
decoder = TFAutoModelForCausalLM.from_config(config.decoder, name="decoder")
self.encoder = encoder
self.decoder = decoder
if self.encoder.config.to_dict() != self.config.encoder.to_dict():
logger.warning(
f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
f" {self.config.encoder}"
)
if self.decoder.config.to_dict() != self.config.decoder.to_dict():
logger.warning(
f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
f" {self.config.decoder}"
)
self.encoder.config = self.config.encoder
self.decoder.config = self.config.decoder
if self.encoder.get_output_embeddings() is not None:
raise ValueError(
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
)
@property
def input_signature(self):
vision_config = self.config.encoder
if hasattr(vision_config, "vision_config"):
vision_config = vision_config.vision_config
if hasattr(vision_config, "image_size"):
image_size = vision_config.image_size
else:
image_size = vision_config.input_size
return {
"pixel_values": tf.TensorSpec(
shape=(
None,
vision_config.num_channels,
image_size,
image_size,
),
dtype=tf.float32,
),
"decoder_input_ids": tf.TensorSpec(shape=(None, None), dtype=tf.int32, name="decoder_input_ids"),
}
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
def get_input_embeddings(self):
return self.encoder.get_input_embeddings()
def get_output_embeddings(self):
return self.decoder.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
return self.decoder.set_output_embeddings(new_embeddings)
def tf_to_pt_weight_rename(self, tf_weight):
encoder_model_type = self.config.encoder.model_type
if "encoder" in tf_weight and "decoder" not in tf_weight:
return (re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight),)
else:
return (tf_weight,)
@classmethod
def from_encoder_decoder_pretrained(
cls,
encoder_pretrained_model_name_or_path: str = None,
decoder_pretrained_model_name_or_path: str = None,
*model_args,
**kwargs,
):
pass
@unpack_inputs
@add_start_docstrings_to_model_forward(
VISION_ENCODER_DECODER_INPUTS_DOCSTRING.format("batch_size, sequence_length")
)
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def call(self, **kwargs):
pass
def serving_output(self, output):
pkv = tf.tuple(output.past_key_values)[1] if self.config.decoder.use_cache else None
dec_hs = (
tf.convert_to_tensor(output.decoder_hidden_states) if self.config.decoder.output_hidden_states else None
)
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.decoder.output_attentions else None
enc_hs = (
tf.convert_to_tensor(output.encoder_hidden_states) if self.config.encoder.output_hidden_states else None
)
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.encoder.output_attentions else None
cross_attns = (
tf.convert_to_tensor(output.cross_attentions)
if self.config.decoder.output_attentions and output.cross_attentions is not None
else None
)
return TFSeq2SeqLMOutput(
logits=output.logits,
past_key_values=pkv,
decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns,
cross_attentions=cross_attns,
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
past_key_values = decoder_inputs.get("past_key_values")
input_dict = {
"pixel_values": None,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"decoder_input_ids": decoder_inputs["input_ids"],
"encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]),
"past_key_values": past_key_values,
"use_cache": use_cache,
}
return input_dict
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
"Resizing the embedding layers via the TFVisionEncoderDecoderModel directly is not supported. "
"Please use the respective methods of the wrapped objects (model.decoder.resize_token_embeddings(...))"
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "enc_to_dec_proj", None) is not None:
with tf.name_scope(self.enc_to_dec_proj.name):
self.enc_to_dec_proj.build([None, None, self.encoder.config.hidden_size])
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)
.\models\vision_encoder_decoder\modeling_vision_encoder_decoder.py
""" 用于支持 Vision-Encoder-Text-Decoder 结构的类"""
import gc
import os
import tempfile
from typing import Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from ...configuration_utils import PretrainedConfig
from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ..auto.configuration_auto import AutoConfig
from ..auto.modeling_auto import AutoModel, AutoModelForCausalLM
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
Shift input ids one token to the right.
"""
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
if decoder_start_token_id is None:
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
shifted_input_ids[:, 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
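A tiny worked example of the shift and the `-100` handling (arbitrary values, with `decoder_start_token_id=0` and `pad_token_id=1`):
```
import torch

labels = torch.tensor([[5, 6, -100, -100]])
shifted = shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=0)
print(shifted)  # tensor([[0, 5, 6, 1]]): start token prepended, labels shifted right, -100 replaced by pad
```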
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "VisionEncoderDecoderConfig"
VISION_ENCODER_DECODER_START_DOCSTRING = r"""
This class can be used to initialize an image-to-text-sequence model with any pretrained vision autoencoding model
as the encoder and any pretrained text autoregressive model as the decoder. The encoder is loaded via the
[`~AutoModel.from_pretrained`] function and the decoder is loaded via the [`~AutoModelForCausalLM.from_pretrained`]
function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream
generative task, like image captioning.
The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi
Zhou, Wei Li, Peter J. Liu.
Additionally, in [TrOCR: Transformer-based Optical Character Recognition with Pre-trained
Models](https://arxiv.org/abs/2109.10282) it is shown how leveraging large pretrained vision models for optical
character recognition (OCR) yields a significant performance improvement.
After such a Vision-Encoder-Text-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any
other models (see the examples for more information).
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`VisionEncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
"""
@add_start_docstrings(VISION_ENCODER_DECODER_START_DOCSTRING)
class VisionEncoderDecoderModel(PreTrainedModel):
r"""
[`VisionEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with
one of the base vision model classes of the library as encoder and another one as decoder when created with the
:meth*~transformers.AutoModel.from_pretrained* class method for the encoder and
:meth*~transformers.AutoModelForCausalLM.from_pretrained* class method for the decoder.
"""
config_class = VisionEncoderDecoderConfig
base_model_prefix = "vision_encoder_decoder"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def __init__(
self,
config: Optional[PretrainedConfig] = None,
encoder: Optional[PreTrainedModel] = None,
decoder: Optional[PreTrainedModel] = None,
):
raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
if config is None:
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
else:
if not isinstance(config, self.config_class):
raise ValueError(f"Config: {config} has to be of type {self.config_class}")
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
"If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
" `config.encoder.hidden_size`."
)
config.tie_word_embeddings = False
super().__init__(config)
if encoder is None:
encoder = AutoModel.from_config(config.encoder)
if decoder is None:
decoder = AutoModelForCausalLM.from_config(config.decoder)
self.encoder = encoder
self.decoder = decoder
if self.encoder.config.to_dict() != self.config.encoder.to_dict():
logger.warning(
f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
f" {self.config.encoder}"
)
if self.decoder.config.to_dict() != self.config.decoder.to_dict():
logger.warning(
f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
f" {self.config.decoder}"
)
self.encoder.config = self.config.encoder
self.decoder.config = self.config.decoder
if (
self.encoder.config.hidden_size != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
):
self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
if self.encoder.get_output_embeddings() is not None:
raise ValueError(
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
)
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
def get_output_embeddings(self):
return self.decoder.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
return self.decoder.set_output_embeddings(new_embeddings)
@classmethod
def from_encoder_decoder_pretrained(
cls,
encoder_pretrained_model_name_or_path: str = None,
decoder_pretrained_model_name_or_path: str = None,
*model_args,
**kwargs,
):
pass
@add_start_docstrings_to_model_forward(VISION_ENCODER_DECODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
):
pass
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
input_dict = {
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"decoder_input_ids": decoder_inputs["input_ids"],
"encoder_outputs": encoder_outputs,
"past_key_values": decoder_inputs["past_key_values"],
"use_cache": use_cache,
}
return input_dict
def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
"Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the"
" respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
)
def _reorder_cache(self, past_key_values, beam_idx):
return self.decoder._reorder_cache(past_key_values, beam_idx)
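As a usage sketch for `VisionEncoderDecoderModel`, a minimal image-captioning loop; the checkpoint name is an assumption, and any trained vision-encoder-decoder checkpoint can be substituted:
```
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoTokenizer, VisionEncoderDecoderModel

checkpoint = "nlpconnect/vit-gpt2-image-captioning"  # assumed checkpoint name
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = VisionEncoderDecoderModel.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
pixel_values = image_processor(images=image, return_tensors="pt").pixel_values

with torch.no_grad():
    generated_ids = model.generate(pixel_values, max_new_tokens=16)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
```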
.\models\vision_encoder_decoder\__init__.py
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tf_available,
is_torch_available,
)
_import_structure = {
"configuration_vision_encoder_decoder": ["VisionEncoderDecoderConfig", "VisionEncoderDecoderOnnxConfig"]
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_vision_encoder_decoder"] = ["VisionEncoderDecoderModel"]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_vision_encoder_decoder"] = ["TFVisionEncoderDecoderModel"]
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_vision_encoder_decoder"] = ["FlaxVisionEncoderDecoderModel"]
if TYPE_CHECKING:
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig, VisionEncoderDecoderOnnxConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_vision_encoder_decoder import VisionEncoderDecoderModel
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_vision_encoder_decoder import TFVisionEncoderDecoderModel
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_vision_encoder_decoder import FlaxVisionEncoderDecoderModel
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\vision_text_dual_encoder\configuration_vision_text_dual_encoder.py
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
from ..chinese_clip.configuration_chinese_clip import ChineseCLIPVisionConfig
from ..clip.configuration_clip import CLIPVisionConfig
from ..siglip.configuration_siglip import SiglipVisionConfig
logger = logging.get_logger(__name__)
VISION_MODEL_CONFIGS = {
"clip_vision_model": CLIPVisionConfig,
"chinese_clip_vision_model": ChineseCLIPVisionConfig,
"siglip_vision_model": SiglipVisionConfig,
}
class VisionTextDualEncoderConfig(PretrainedConfig):
r"""
[`VisionTextDualEncoderConfig`] is the configuration class to store the configuration of a
[`VisionTextDualEncoderModel`]. It is used to instantiate a [`VisionTextDualEncoderModel`] according to the
specified arguments, defining the text model and vision model configs.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
projection_dim (`int`, *optional*, defaults to 512):
Dimensionality of the text and vision projection layers.
logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
The initial value of the *logit_scale* parameter. The default is used as per the original CLIP implementation.
kwargs (*optional*):
Dictionary of keyword arguments.
Examples:
```
>>> from transformers import ViTConfig, BertConfig, VisionTextDualEncoderConfig, VisionTextDualEncoderModel
>>> # Initializing BERT and ViT configurations
>>> config_vision = ViTConfig()
>>> config_text = BertConfig()
>>> config = VisionTextDualEncoderConfig.from_vision_text_configs(config_vision, config_text, projection_dim=512)
>>> # Initializing a BERT and ViT model (with random weights)
>>> model = VisionTextDualEncoderModel(config=config)
>>> # Accessing the model configuration
>>> config_vision = model.config.vision_config
>>> config_text = model.config.text_config
>>> # Saving the model, including its configuration
>>> model.save_pretrained("vit-bert")
>>> # Loading the model and config from the pretrained folder
>>> vision_text_config = VisionTextDualEncoderConfig.from_pretrained("vit-bert")
>>> model = VisionTextDualEncoderModel.from_pretrained("vit-bert", config=vision_text_config)
```
"""
model_type = "vision-text-dual-encoder"
is_composition = True
def __init__(self, projection_dim=512, logit_scale_init_value=2.6592, **kwargs):
# Call the parent class constructor
super().__init__(**kwargs)
# Make sure a vision config was provided
if "vision_config" not in kwargs:
raise ValueError("`vision_config` can not be `None`.")
# Make sure a text config was provided
if "text_config" not in kwargs:
raise ValueError("`text_config` can not be `None`.")
# Pop the vision config dict
vision_config = kwargs.pop("vision_config")
# Pop the text config dict
text_config = kwargs.pop("text_config")
# Read the vision model type
vision_model_type = vision_config.pop("model_type")
# Read the text model type
text_model_type = text_config.pop("model_type")
# Look up a dedicated vision config class for this model type
vision_config_class = VISION_MODEL_CONFIGS.get(vision_model_type)
# If one exists, instantiate it with the provided vision config arguments
if vision_config_class is not None:
self.vision_config = vision_config_class(**vision_config)
# Otherwise build a config automatically from the model type and arguments
else:
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
# If that config wraps its own `vision_config` attribute, unwrap it
if hasattr(self.vision_config, "vision_config"):
self.vision_config = self.vision_config.vision_config
# Build the text config automatically from its model type and arguments
self.text_config = AutoConfig.for_model(text_model_type, **text_config)
# Store the projection dimension
self.projection_dim = projection_dim
# Store the initial value of the logit scale
self.logit_scale_init_value = logit_scale_init_value
@classmethod
def from_vision_text_configs(cls, vision_config: PretrainedConfig, text_config: PretrainedConfig, **kwargs):
"""
Instantiate a [`VisionTextDualEncoderConfig`] (or a derived class) from a vision model configuration and a text
model configuration.
Args:
vision_config (`PretrainedConfig`): An instance of the vision model configuration
text_config (`PretrainedConfig`): An instance of the text model configuration
**kwargs: Additional keyword arguments
Returns:
[`VisionTextDualEncoderConfig`]: An instance of a configuration object
"""
return cls(vision_config=vision_config.to_dict(), text_config=text_config.to_dict(), **kwargs)
.\models\vision_text_dual_encoder\modeling_flax_vision_text_dual_encoder.py
""" Flax VisionTextDualEncoder model."""
from typing import Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from ...modeling_flax_utils import FlaxPreTrainedModel, append_replace_return_docstrings, overwrite_call_docstring
from ...utils import add_start_docstrings, logging
from ..auto.configuration_auto import AutoConfig
from ..auto.modeling_flax_auto import FLAX_MODEL_MAPPING, FlaxAutoModel
from ..clip.modeling_flax_clip import FlaxCLIPOutput, FlaxCLIPVisionModel
from .configuration_vision_text_dual_encoder import VisionTextDualEncoderConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "VisionTextDualEncoderConfig"
VISION_TEXT_DUAL_ENCODER_START_DOCSTRING = r"""
This class can be used to initialize a vision-text dual encoder model with any pretrained vision autoencoding model
as the vision encoder and any pretrained text model as the text encoder. The vision and text encoders are loaded
via the [`~FlaxAutoModel.from_pretrained`] method. The projection layers are automatically added to the model and
should be fine-tuned on a downstream task, like contrastive image-text modeling.
In [LiT: Zero-Shot Transfer with Locked-image Text Tuning](https://arxiv.org/abs/2111.07991) it is shown how
leveraging pre-trained (locked/frozen) image and text model for contrastive learning yields significant improvement
on new zero-shot vision tasks such as image classification or retrieval.
After such a Vision-Text-Dual-Encoder model has been trained/fine-tuned, it can be saved/loaded just like any other
models (see the examples for more information).
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a
[flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it
as a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
behavior.
Finally, this model supports inherent JAX features such as:
- Just-In-Time (JIT) compilation
- Automatic Differentiation
- Vectorization
- Parallelization
Parameters:
config ([`VisionTextDualEncoderConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
`jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision
inference on GPUs or TPUs. If specified, all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.** If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`]
and [`~FlaxPreTrainedModel.to_bf16`].
"""
VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING = r"""
Args:
input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. Padding will be ignored by default.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default. Pixel values can be obtained using an image processor (e.g. `AutoImageProcessor` if ViT is used as the encoder).
See [`ViTImageProcessor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class FlaxVisionTextDualEncoderModule(nn.Module):
config: VisionTextDualEncoderConfig
dtype: jnp.dtype = jnp.float32
# setup() creates the sub-modules and parameters of the model
def setup(self):
# Pull the vision and text sub-configurations from the shared config
vision_config = self.config.vision_config
text_config = self.config.text_config
# Embedding dimensions of the two towers and the shared projection dimension
self.vision_embed_dim = vision_config.hidden_size
self.text_embed_dim = text_config.hidden_size
self.projection_dim = self.config.projection_dim
# Pick the Flax module classes matching the vision and text configs
vision_module = FLAX_MODEL_MAPPING.get(self.config.vision_config.__class__, FlaxCLIPVisionModel).module_class
text_module = FLAX_MODEL_MAPPING[self.config.text_config.__class__].module_class
# Instantiate the vision and text sub-modules
self.vision_model = vision_module(vision_config, dtype=self.dtype)
self.text_model = text_module(text_config, dtype=self.dtype)
# Projection layers that map both towers into the shared embedding space
self.visual_projection = nn.Dense(
self.projection_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.02),
use_bias=False,
)
self.text_projection = nn.Dense(
self.projection_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.02),
use_bias=False,
)
# Learnable logit scale, registered as a model parameter
self.logit_scale = self.param(
"logit_scale", lambda _, shape: jnp.ones(shape) * self.config.logit_scale_init_value, []
)
# Forward pass: encode both modalities and return CLIP-style similarity logits
def __call__(
self,
input_ids=None,
pixel_values=None,
attention_mask=None,
position_ids=None,
token_type_ids=None,
deterministic: bool = True,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
# Use the provided return_dict if given, otherwise fall back to the config default
return_dict = return_dict if return_dict is not None else self.config.return_dict
# Run the vision tower on the pixel values, optionally returning attentions and hidden states
vision_outputs = self.vision_model(
pixel_values=pixel_values,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Run the text tower on the token inputs, optionally returning attentions and hidden states
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Take the pooled vision output and project it into the shared embedding space
image_embeds = vision_outputs[1]
image_embeds = self.visual_projection(image_embeds)
# Take the pooled text output and project it into the shared embedding space
text_embeds = text_outputs[1]
text_embeds = self.text_projection(text_embeds)
# L2-normalize the image embeddings
image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
# L2-normalize the text embeddings
text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
# Cosine similarity scaled by exp(logit_scale) gives the text-image logits
logit_scale = jnp.exp(self.logit_scale)
logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * logit_scale
logits_per_image = logits_per_text.T
# When return_dict is False, return a plain tuple of outputs
if not return_dict:
return (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
# Otherwise wrap everything in a FlaxCLIPOutput
return FlaxCLIPOutput(
logits_per_image=logits_per_image,
logits_per_text=logits_per_text,
text_embeds=text_embeds,
image_embeds=image_embeds,
text_model_output=text_outputs,
vision_model_output=vision_outputs,
)
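Note that the module only returns similarity logits; no loss is computed here. As a rough sketch of how the symmetric contrastive objective (the `clip_loss` helpers shown in the TF and PyTorch files below) could be expressed in JAX, assuming `optax` is available (the helper name is illustrative):
```
import jax.numpy as jnp
import optax


def flax_clip_loss(logits_per_text: jnp.ndarray) -> jnp.ndarray:
    # Matched image-text pairs sit on the diagonal, so the integer label of row i is i.
    labels = jnp.arange(logits_per_text.shape[0])
    caption_loss = optax.softmax_cross_entropy_with_integer_labels(logits_per_text, labels).mean()
    image_loss = optax.softmax_cross_entropy_with_integer_labels(logits_per_text.T, labels).mean()
    return (caption_loss + image_loss) / 2.0
```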
# Model that encodes vision and text inputs into a shared embedding space, built on FlaxPreTrainedModel
@add_start_docstrings(VISION_TEXT_DUAL_ENCODER_START_DOCSTRING)
class FlaxVisionTextDualEncoderModel(FlaxPreTrainedModel):
# Configuration class for this model
config_class = VisionTextDualEncoderConfig
# Underlying Flax module that implements the computation
module_class = FlaxVisionTextDualEncoderModule
def __init__(
self,
config: VisionTextDualEncoderConfig,
input_shape: Optional[Tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
# This model cannot be created without initializing its weights
if not _do_init:
raise ValueError(
"`FlaxVisionTextDualEncoderModel` cannot be created without initializing, `_do_init` must be `True`."
)
# Fall back to a default input shape if none is provided
if input_shape is None:
input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
# Instantiate the underlying Flax module
module = self.module_class(config=config, dtype=dtype, **kwargs)
# Initialize the FlaxPreTrainedModel machinery with config, module, input shape, seed and dtype
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# Dummy token ids used to initialize the parameters
input_ids = jnp.zeros(input_shape[0], dtype="i4")
# Position ids 0..seq_len-1, broadcast over the batch
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
# Token type ids, defaulting to ones
token_type_ids = jnp.ones_like(input_ids)
# Attention mask of all ones
attention_mask = jnp.ones_like(input_ids)
# Pixel values drawn from a standard normal distribution
pixel_values = jax.random.normal(rng, input_shape[1])
# Split the PRNG key into one for parameter init and one for dropout
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
# Initialize the module to obtain random parameters
random_params = self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)[
"params"
]
# If parameters were provided, fill in any missing keys with the freshly initialized ones
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def __call__(
self,
input_ids,
pixel_values,
attention_mask=None,
position_ids=None,
token_type_ids=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
# Use the given output_attentions if set, otherwise fall back to the config default
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# Use the given output_hidden_states if set, otherwise fall back to the config default
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# Use the given return_dict if set, otherwise fall back to the config default
return_dict = return_dict if return_dict is not None else self.config.return_dict
# Transpose pixel values from NCHW to NHWC, the layout expected by the Flax vision model
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
# Default position_ids: 0..seq_len-1, broadcast over the batch
if position_ids is None:
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
# Default token_type_ids: all zeros
if token_type_ids is None:
token_type_ids = jnp.zeros_like(input_ids)
# Default attention_mask: all ones
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
# Handle any PRNG keys that may be needed (e.g. for dropout)
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
# Apply the Flax module with the resolved parameters and inputs
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(pixel_values, dtype=jnp.float32),
jnp.array(attention_mask, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
not train,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
)
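Because all the tensor work happens inside `self.module.apply`, the whole forward pass can be wrapped in `jax.jit` for repeated inference. A minimal sketch, assuming `model` is a `FlaxVisionTextDualEncoderModel` and `inputs` comes from the processor sketch above (the wrapper name is illustrative):
```
import jax


@jax.jit
def image_text_logits(params, input_ids, attention_mask, pixel_values):
    # Run the dual-encoder forward pass with explicit parameters so jit can trace them.
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        pixel_values=pixel_values,
        params=params,
    )
    return outputs.logits_per_image


logits_per_image = image_text_logits(
    model.params, inputs.input_ids, inputs.attention_mask, inputs.pixel_values
)
```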
# get_text_features: embed the text inputs with the text tower and the text projection
def get_text_features(
self,
input_ids,
attention_mask=None,
position_ids=None,
token_type_ids=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train=False,
):
r"""
Args:
input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
Returns:
text_features (`jnp.ndarray` of shape `(batch_size, output_dim)`): The text embeddings obtained by applying
the projection layer to the pooled output of text model.
"""
# Default position_ids: 0..seq_len-1, broadcast over the batch
if position_ids is None:
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
# Default token_type_ids: all zeros
if token_type_ids is None:
token_type_ids = jnp.zeros_like(input_ids)
# Default attention_mask: all ones
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
# Handle any PRNG keys that may be needed (e.g. for dropout)
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
def _get_features(module, input_ids, attention_mask, position_ids, token_type_ids, deterministic):
# Run the text tower
text_outputs = module.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
token_type_ids=token_type_ids,
deterministic=deterministic,
)
# Take the pooled output
pooled_output = text_outputs[1]
# Project it into the shared embedding space to get the text features
text_features = module.text_projection(pooled_output)
return text_features
# Apply the module, routing through _get_features so that only the text branch runs
return self.module.apply(
{"params": params or self.params},  # model parameters
jnp.array(input_ids, dtype="i4"),  # token ids
jnp.array(attention_mask, dtype="i4"),  # attention mask
jnp.array(position_ids, dtype="i4"),  # position ids
jnp.array(token_type_ids, dtype="i4"),  # token type ids
not train,  # deterministic flag (inference mode)
method=_get_features,  # method used to compute the features
rngs=rngs,  # PRNG keys
)
def get_image_features(
self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train=False
):
r"""
Args:
pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained
using [`ImageFeatureExtractionMixin`]. See [`ImageFeatureExtractionMixin.__call__`] for details.
Returns:
image_features (`jnp.ndarray` of shape `(batch_size, output_dim)`): The image embeddings obtained by
applying the projection layer to the pooled output of vision model.
"""
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
# Inner function that extracts image features from the vision tower
def _get_features(module, pixel_values, deterministic):
# Run the vision tower on the pixel values
vision_outputs = module.vision_model(pixel_values=pixel_values, deterministic=deterministic)
# Take the pooled output (the second element of the outputs)
pooled_output = vision_outputs[1]  # pooled_output
# Project it into the shared embedding space to get the image features
image_features = module.visual_projection(pooled_output)
return image_features
# Apply the module, routing through _get_features so that only the vision branch runs
return self.module.apply(
{"params": params or self.params},  # model parameters
jnp.array(pixel_values, dtype=jnp.float32),  # pixel values as a jax array
not train,  # deterministic flag (inference mode)
method=_get_features,  # method used to compute the features
rngs=rngs,  # PRNG keys
)
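The two feature methods above are convenient for retrieval-style use. A hedged sketch that mirrors the normalization performed inside the module's `__call__`, assuming `model` and `inputs` from the earlier sketches:
```
import jax.numpy as jnp

text_features = model.get_text_features(inputs.input_ids, attention_mask=inputs.attention_mask)
image_features = model.get_image_features(inputs.pixel_values)

# L2-normalize so the dot product becomes a cosine similarity, as in the module's __call__.
text_features = text_features / jnp.linalg.norm(text_features, axis=-1, keepdims=True)
image_features = image_features / jnp.linalg.norm(image_features, axis=-1, keepdims=True)
similarity = jnp.matmul(text_features, image_features.T)
```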
@classmethod
def from_vision_text_pretrained(
cls,
vision_model_name_or_path: str = None,
text_model_name_or_path: str = None,
*model_args,
**kwargs,
):
# Docstring for the model's __call__, containing the return value and example usage
VISION_TEXT_DUAL_ENCODER_MODEL_DOCSTRING = r"""
Returns:
Examples:
```
>>> from PIL import Image
>>> import requests
>>> import jax
>>> from transformers import (
... FlaxVisionTextDualEncoderModel,
... VisionTextDualEncoderProcessor,
... AutoImageProcessor,
... AutoTokenizer,
... )
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
>>> processor = VisionTextDualEncoderProcessor(image_processor, tokenizer)
>>> model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(
... "google/vit-base-patch16-224", "google-bert/bert-base-uncased"
... )
>>>
>>> urls = [
... "http://images.cocodataset.org/val2017/000000039769.jpg",
... "https://farm3.staticflickr.com/2674/5850229113_4fe05d5265_z.jpg",
... ]
>>> images = [Image.open(requests.get(url, stream=True).raw) for url in urls]
>>> inputs = processor(
... text=["a photo of a cat", "a photo of a dog"], images=images, return_tensors="np", padding=True
... )
>>> outputs = model(
... input_ids=inputs.input_ids,
... attention_mask=inputs.attention_mask,
... pixel_values=inputs.pixel_values,
... )
>>> logits_per_image = outputs.logits_per_image
>>>
>>> model.save_pretrained("vit-bert")
>>> model = FlaxVisionTextDualEncoderModel.from_pretrained("vit-bert")
>>>
>>> outputs = model(**inputs)
>>> logits_per_image = outputs.logits_per_image
>>> probs = jax.nn.softmax(logits_per_image, axis=1)
```
"""
# Overwrite the __call__ docstring of FlaxVisionTextDualEncoderModel with the combined inputs + model docstring
overwrite_call_docstring(
FlaxVisionTextDualEncoderModel,
VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING + VISION_TEXT_DUAL_ENCODER_MODEL_DOCSTRING,
)
# Append/replace the return-value docstring of FlaxVisionTextDualEncoderModel with FlaxCLIPOutput
append_replace_return_docstrings(
FlaxVisionTextDualEncoderModel, output_type=FlaxCLIPOutput, config_class=_CONFIG_FOR_DOC
)
.\models\vision_text_dual_encoder\modeling_tf_vision_text_dual_encoder.py
VISION_TEXT_DUAL_ENCODER_START_DOCSTRING = r"""
This class can be used to initialize a vision-text dual encoder model with any pretrained vision autoencoding model
as the vision encoder and any pretrained text model as the text encoder. The vision and text encoders are loaded
via the [`~TFAutoModel.from_pretrained`] method. The projection layers are automatically added to the model and
should be fine-tuned on a downstream task, like contrastive image-text modeling.
In [LiT: Zero-Shot Transfer with Locked-image Text Tuning](https://arxiv.org/abs/2111.07991) it is shown how
leveraging pre-trained (locked/frozen) image and text model for contrastive learning yields significant improvement
on new zero-shot vision tasks such as image classification or retrieval.
After such a Vision-Text-Dual-Encoder model has been trained/fine-tuned, it can be saved/loaded just like any other
models (see the examples for more information).
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a Keras [Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it as a
regular Keras Model and refer to the TF documentation for all matter related to general usage and behavior.
"""
Parameters:
config ([`VisionEncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
VISION_TEXT_DUAL_ENCODER_TEXT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default.
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence token in the position embeddings. Selected in the range
`[0, config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING = r"""
Args:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
return_loss (`bool`, *optional*):
output_attentions (`bool`, *optional*):
output_hidden_states (`bool`, *optional*):
return_dict (`bool`, *optional*):
"""
# Copied from transformers.models.clip.modeling_tf_clip.contrastive_loss
def contrastive_loss(logits: tf.Tensor) -> tf.Tensor:
# Mean sparse categorical cross-entropy against the diagonal labels 0..N-1
return tf.math.reduce_mean(
keras.metrics.sparse_categorical_crossentropy(
y_true=tf.range(shape_list(logits)[0]), y_pred=logits, from_logits=True
)
)
# Copied from transformers.models.clip.modeling_tf_clip.clip_loss
def clip_loss(similarity: tf.Tensor) -> tf.Tensor:
# Contrastive loss along the caption axis
caption_loss = contrastive_loss(similarity)
# Contrastive loss along the image axis (transposed similarity matrix)
image_loss = contrastive_loss(tf.transpose(similarity))
# Return the average of the caption and image losses as the CLIP loss
return (caption_loss + image_loss) / 2.0
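The two helpers above implement the symmetric contrastive objective: matched image-text pairs lie on the diagonal of the similarity matrix, so the target label of row `i` is simply `i`. A tiny worked illustration with made-up numbers, restated with plain `tf.keras` calls so it runs on its own:
```
import tensorflow as tf

similarity = tf.constant([[10.0, 0.1], [0.2, 8.0]])  # 2 images x 2 captions, diagonal = matching pairs
labels = tf.range(2)
caption_loss = tf.reduce_mean(
    tf.keras.losses.sparse_categorical_crossentropy(labels, similarity, from_logits=True)
)
image_loss = tf.reduce_mean(
    tf.keras.losses.sparse_categorical_crossentropy(labels, tf.transpose(similarity), from_logits=True)
)
loss = (caption_loss + image_loss) / 2.0  # small, since the diagonal already dominates each row/column
```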
# Dual encoder model class, decorated with the shared start docstring
@add_start_docstrings(VISION_TEXT_DUAL_ENCODER_START_DOCSTRING)
class TFVisionTextDualEncoderModel(TFPreTrainedModel):
# Configuration class for this model
config_class = VisionTextDualEncoderConfig
# Prefix used for the base model weights
base_model_prefix = "vision_text_dual_encoder"
# Prefix used when loading weights
load_weight_prefix = "tf_vision_text_dual_encoder_model"
def __init__(
self,
config: Optional[VisionTextDualEncoderConfig] = None,
vision_model: Optional[TFPreTrainedModel] = None,
text_model: Optional[TFPreTrainedModel] = None,
):
# If no config is given, both a vision model and a text model must be provided
if config is None and (vision_model is None or text_model is None):
raise ValueError("Either a configuration or an vision and a text model has to be provided")
# If no config is provided, build one from the vision and text model configs
if config is None:
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_model.config, text_model.config)
else:
# If a config is provided, it must be a VisionTextDualEncoderConfig
if not isinstance(config, self.config_class):
raise ValueError(f"config: {config} has to be of type {self.config_class}")
# Initialize the parent class with the config
super().__init__(config)
# If no vision model is provided, create an appropriate one from the vision config
if vision_model is None:
if isinstance(config.vision_config, CLIPVisionConfig):
vision_model = TFCLIPVisionModel.from_config(config.vision_config, name="vision_model")
else:
vision_model = TFAutoModel.from_config(config.vision_config, name="vision_model")
# If no text model is provided, create an appropriate one from the text config
if text_model is None:
text_model = TFAutoModel.from_config(config.text_config, name="text_model")
# Store the vision and text sub-models
self.vision_model = vision_model
self.text_model = text_model
# Make sure the sub-models reference the shared config so that config updates stay in sync
self.vision_model.config = self.config.vision_config
self.text_model.config = self.config.text_config
# Embedding dimensions of the two towers and the shared projection dimension
self.vision_embed_dim = config.vision_config.hidden_size
self.text_embed_dim = config.text_config.hidden_size
self.projection_dim = config.projection_dim
# Bias-free projection layers for the vision and text embeddings
self.visual_projection = keras.layers.Dense(self.projection_dim, use_bias=False, name="visual_projection")
self.text_projection = keras.layers.Dense(self.projection_dim, use_bias=False, name="text_projection")
# The logit scale is created later in build()
self.logit_scale = None
# Store the shared config
self.config = config
# Build the model, making sure each sub-layer gets the correct name scope
def build(self, input_shape=None):
# Return early if the model has already been built
if self.built:
return
# Mark the model as built
self.built = True
# Create the logit_scale weight of shape (1,) with a constant initializer
initializer = keras.initializers.Constant(self.config.logit_scale_init_value)
self.logit_scale = self.add_weight(shape=(1,), initializer=initializer, name="logit_scale")
# Build the visual projection under its own name scope, if present
if getattr(self, "visual_projection", None) is not None:
with tf.name_scope(self.visual_projection.name):
self.visual_projection.build([None, None, self.vision_embed_dim])
# Build the text projection under its own name scope, if present
if getattr(self, "text_projection", None) is not None:
with tf.name_scope(self.text_projection.name):
self.text_projection.build([None, None, self.text_embed_dim])
# Build the vision model under its own name scope
with tf.name_scope(self.vision_model.name):
self.vision_model.build(None)
# Build the text model under its own name scope
with tf.name_scope(self.text_model.name):
self.text_model.build(None)
# Rename TensorFlow weight names to their PyTorch-style counterparts
def tf_to_pt_weight_rename(self, tf_weight):
# Weights belonging to the vision model: strip the intermediate Keras scope name
if "vision_model" in tf_weight:
if tf_weight.count("vision_model") == 1:
return (re.sub(r"vision_model\..*?\.", "vision_model.", tf_weight),)
elif tf_weight.count("vision_model") == 2:
return (re.sub(r"vision_model\..*?\.vision_model", "vision_model.vision_model", tf_weight),)
else:
raise ValueError(
f"Unexpected weight name {tf_weight}. Please file an issue on the"
" Transformers repo to let us know about this error!"
)
# Weights belonging to the text model: apply the corresponding renaming
elif "text_model" in tf_weight:
return (re.sub(r"text_model\..*?\.", "text_model.", tf_weight),)
# Otherwise keep the weight name unchanged
else:
return (tf_weight,)
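A quick illustration of the renaming rule for the common single-nesting case, on a hypothetical TF weight name (the inner scope name is made up):
```
import re

tf_weight = "vision_model.tf_vit_model.embeddings.cls_token"
# The non-greedy pattern strips the intermediate Keras scope between the first two dots.
print(re.sub(r"vision_model\..*?\.", "vision_model.", tf_weight))
# -> vision_model.embeddings.cls_token
```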
# Add the forward-pass docstring, annotated with VISION_TEXT_DUAL_ENCODER_TEXT_INPUTS_DOCSTRING
@add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_TEXT_INPUTS_DOCSTRING)
def get_text_features(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
token_type_ids=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Returns:
text_features (`tf.Tensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by applying
the projection layer to the pooled output of [`TFCLIPTextModel`].
Examples:
```
>>> from transformers import TFVisionTextDualEncoderModel, AutoTokenizer
>>> model = TFVisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian", from_pt=True)
>>> tokenizer = AutoTokenizer.from_pretrained("clip-italian/clip-italian")
>>> inputs = tokenizer(["una foto di un gatto", "una foto di un cane"], padding=True, return_tensors="np")
>>> text_features = model.get_text_features(**inputs)
```"""
# Run the text model on the inputs
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
token_type_ids=token_type_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Take the pooled output from the text outputs
pooled_output = text_outputs[1]
# Project the pooled output with self.text_projection to get the text features
text_features = self.text_projection(pooled_output)
return text_features
@add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_VISION_INPUTS_DOCSTRING)
def get_image_features(
self,
pixel_values=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Returns:
image_features (`tf.Tensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by applying
the projection layer to the pooled output of [`TFCLIPVisionModel`].
Examples:
```
>>> from PIL import Image
>>> import requests
>>> from transformers import TFVisionTextDualEncoderModel, AutoImageProcessor
>>> model = TFVisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian", from_pt=True)
>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = image_processor(images=image, return_tensors="np")
>>> image_features = model.get_image_features(**inputs)
```"""
# Run the vision model on the pixel values
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Take the pooled output from the vision outputs
pooled_output = vision_outputs[1]  # pooled_output
# Project the pooled output with self.visual_projection to get the image features
image_features = self.visual_projection(pooled_output)
return image_features
@unpack_inputs
@add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFCLIPOutput, config_class=_CONFIG_FOR_DOC)
# call(): the forward pass used for both inference and training
def call(
self,
input_ids: tf.Tensor | None = None,  # token ids of the text input
pixel_values: tf.Tensor | None = None,  # pixel values of the image input
attention_mask: tf.Tensor | None = None,  # attention mask
position_ids: tf.Tensor | None = None,  # position ids
return_loss: Optional[bool] = None,  # whether to return the contrastive loss
token_type_ids: tf.Tensor | None = None,  # token type ids
output_attentions: Optional[bool] = None,  # whether to return the attention tensors
output_hidden_states: Optional[bool] = None,  # whether to return the hidden states
return_dict: Optional[bool] = None,  # whether to return a ModelOutput instead of a tuple
training: bool = False,  # whether the model is run in training mode
):
# Class method that builds the model from pretrained vision and text models
@classmethod
def from_vision_text_pretrained(
cls,
vision_model_name_or_path: str = None,  # name or path of the vision model
text_model_name_or_path: str = None,  # name or path of the text model
*model_args,  # positional arguments forwarded to the models
**kwargs,  # keyword arguments forwarded to the models
):
# Property returning dummy inputs used to build the network
@property
def dummy_inputs(self):
"""
Dummy inputs to build the network.
Returns:
`Dict[str, tf.Tensor]`: The dummy inputs.
"""
# Token ids built from the predefined DUMMY_INPUTS constant
input_ids = tf.constant(DUMMY_INPUTS, dtype=tf.int32)
batch_size, seq_len = input_ids.shape
# Random pixel values matching the configured number of channels and image size
VISION_DUMMY_INPUTS = tf.random.uniform(
shape=(
batch_size,
self.config.vision_config.num_channels,
self.config.vision_config.image_size,
self.config.vision_config.image_size,
),
dtype=tf.float32,
)
pixel_values = tf.constant(VISION_DUMMY_INPUTS)
# Return the dummy inputs as a dict
dummy = {"pixel_values": pixel_values, "input_ids": input_ids}
return dummy
.\models\vision_text_dual_encoder\modeling_vision_text_dual_encoder.py
""" PyTorch VisionTextDualEncoder model. """
from typing import Optional, Tuple, Union
import torch
from torch import nn
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ..auto.configuration_auto import AutoConfig
from ..auto.modeling_auto import AutoModel
from ..clip.modeling_clip import CLIPOutput, CLIPVisionConfig, CLIPVisionModel
from .configuration_vision_text_dual_encoder import VisionTextDualEncoderConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "VisionTextDualEncoderConfig"
VISION_TEXT_DUAL_ENCODER_START_DOCSTRING = r"""
This class can be used to initialize a vision-text dual encoder model with any pretrained vision autoencoding model
as the vision encoder and any pretrained text model as the text encoder. The vision and text encoders are loaded
via the [`~AutoModel.from_pretrained`] method. The projection layers are automatically added to the model and
should be fine-tuned on a downstream task, like contrastive image-text modeling.
In [LiT: Zero-Shot Transfer with Locked-image Text Tuning](https://arxiv.org/abs/2111.07991) it is shown how
leveraging pre-trained (locked/frozen) image and text model for contrastive learning yields significant improvement
on new zero-shot vision tasks such as image classification or retrieval.
After such a Vision-Text-Dual-Encoder model has been trained/fine-tuned, it can be saved/loaded just like any other
models (see the examples for more information).
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`VisionTextDualEncoderConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
VISION_TEXT_DUAL_ENCODER_TEXT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default.
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence token in the position embeddings. Selected in the range
`[0, config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence token in the position embeddings. Selected in the range
`[0, config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default. Pixel values can be obtained using an image processor
(e.g. if you use ViT as the encoder, you should use [`AutoImageProcessor`]). See
[`ViTImageProcessor.__call__`] for details.
return_loss (`bool`, *optional*):
Whether or not to return the contrastive loss.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Copied from transformers.models.clip.modeling_clip.contrastive_loss
# Contrastive loss: takes logits as input and returns the cross-entropy against the diagonal labels
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
"""
Copied from transformers.models.clip.modeling_clip.clip_loss
定义 CLIP 损失函数,输入为相似性张量,输出为损失值
"""
def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
# Contrastive loss along the caption axis
caption_loss = contrastive_loss(similarity)
# Contrastive loss along the image axis (transposed similarity matrix)
image_loss = contrastive_loss(similarity.t())
# Return the average of the caption and image losses as the CLIP loss
return (caption_loss + image_loss) / 2.0
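For orientation, a hedged sketch (illustrative tensors, not the actual `forward` implementation) of how the model combines the learned logit scale with the `clip_loss` defined above when `return_loss=True`; `2.6592` stands in for `config.logit_scale_init_value`:
```
import torch

image_embeds = torch.nn.functional.normalize(torch.randn(4, 512), dim=-1)
text_embeds = torch.nn.functional.normalize(torch.randn(4, 512), dim=-1)
logit_scale = torch.tensor(2.6592).exp()  # exp of the logit scale parameter
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
loss = clip_loss(logits_per_text)
```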
"""
@add_start_docstrings(VISION_TEXT_DUAL_ENCODER_START_DOCSTRING)
双编码器模型,结合视觉和文本输入进行编码
"""
class VisionTextDualEncoderModel(PreTrainedModel):
config_class = VisionTextDualEncoderConfig
base_model_prefix = "vision_text_dual_encoder"
def __init__(
self,
config: Optional[VisionTextDualEncoderConfig] = None,
vision_model: Optional[PreTrainedModel] = None,
text_model: Optional[PreTrainedModel] = None,
):
if config is None and (vision_model is None or text_model is None):
raise ValueError("Either a configuration or an vision and a text model has to be provided")
if config is None:
# If no config is provided, build one from the vision and text model configs
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_model.config, text_model.config)
else:
if not isinstance(config, self.config_class):
raise ValueError(f"config: {config} has to be of type {self.config_class}")
# Initialize the parent class with the config
super().__init__(config)
# If no vision model is provided, create a default one from the vision config
if vision_model is None:
if isinstance(config.vision_config, CLIPVisionConfig):
vision_model = CLIPVisionModel(config.vision_config)
else:
vision_model = AutoModel.from_config(config.vision_config)
# If no text model is provided, create a default one from the text config
if text_model is None:
text_model = AutoModel.from_config(config.text_config)
# Store the vision and text sub-models
self.vision_model = vision_model
self.text_model = text_model
# Make sure the sub-models reference the shared config so that config updates stay in sync
self.vision_model.config = self.config.vision_config
self.text_model.config = self.config.text_config
# Embedding dimensions of the two towers and the shared projection dimension
self.vision_embed_dim = config.vision_config.hidden_size
self.text_embed_dim = config.text_config.hidden_size
self.projection_dim = config.projection_dim
# Bias-free linear projection layers for the vision and text embeddings
self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
# Learnable logit scale parameter
self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
@add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_TEXT_INPUTS_DOCSTRING)
def get_text_features(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
token_type_ids=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
# Note: the signature is not listed in full here; additional parameters may follow
@add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_VISION_INPUTS_DOCSTRING)
# Decorator adding the forward docstring that describes the shapes and meaning of the inputs and outputs
def get_image_features(
self,
pixel_values=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Returns:
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
applying the projection layer to the pooled output of [`CLIPVisionModel`].
Examples:
```
>>> from PIL import Image
>>> import requests
>>> from transformers import VisionTextDualEncoderModel, AutoImageProcessor
>>> model = VisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian")
>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = image_processor(images=image, return_tensors="pt")
>>> image_features = model.get_image_features(**inputs)
```"""
# Run the vision model on the pixel values, optionally returning attentions and hidden states
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Take the pooled output (the second element of the vision outputs)
pooled_output = vision_outputs[1]  # pooled_output
# Project the pooled output to get the final image features
image_features = self.visual_projection(pooled_output)
# Return the image features
return image_features
# forward(): the forward pass of the model
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,  # token ids of the text input
pixel_values: Optional[torch.FloatTensor] = None,  # pixel values of the image input
attention_mask: Optional[torch.Tensor] = None,  # attention mask
position_ids: Optional[torch.LongTensor] = None,  # position ids
return_loss: Optional[bool] = None,  # whether to return the contrastive loss
token_type_ids: Optional[torch.LongTensor] = None,  # token type ids
output_attentions: Optional[bool] = None,  # whether to return the attention tensors
output_hidden_states: Optional[bool] = None,  # whether to return the hidden states
return_dict: Optional[bool] = None,  # whether to return a ModelOutput instead of a tuple
@classmethod
def from_pretrained(cls, *args, **kwargs):
# Fast initialization is not yet supported for composite models
kwargs["_fast_init"] = False
# Delegate to the parent class's `from_pretrained`, forwarding all arguments
return super().from_pretrained(*args, **kwargs)
@classmethod
def from_vision_text_pretrained(
cls,
vision_model_name_or_path: str = None,  # name or path of the vision model
text_model_name_or_path: str = None,  # name or path of the text model
*model_args,  # positional arguments forwarded to the models
**kwargs,  # keyword arguments forwarded to the models
.\models\vision_text_dual_encoder\processing_vision_text_dual_encoder.py
"""
Processor class for VisionTextDualEncoder
"""
import warnings
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
class VisionTextDualEncoderProcessor(ProcessorMixin):
r"""
Constructs a VisionTextDualEncoder processor which wraps an image processor and a tokenizer into a single
processor.
[`VisionTextDualEncoderProcessor`] offers all the functionalities of [`AutoImageProcessor`] and [`AutoTokenizer`].
See the [`~VisionTextDualEncoderProcessor.__call__`] and [`~VisionTextDualEncoderProcessor.decode`] for more
information.
Args:
image_processor ([`AutoImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`PreTrainedTokenizer`], *optional*):
The tokenizer is a required input.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor=None, tokenizer=None, **kwargs):
feature_extractor = None
if "feature_extractor" in kwargs:
warnings.warn(
"The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
" instead.",
FutureWarning,
)
feature_extractor = kwargs.pop("feature_extractor")
image_processor = image_processor if image_processor is not None else feature_extractor
if image_processor is None:
raise ValueError("You have to specify an image_processor.")
if tokenizer is None:
raise ValueError("You have to specify a tokenizer.")
super().__init__(image_processor, tokenizer)
self.current_processor = self.image_processor
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to VisionTextDualEncoderTokenizer's
[`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
@property
def feature_extractor_class(self):
warnings.warn(
"`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
FutureWarning,
)
return self.image_processor_class
@property
def feature_extractor(self):
warnings.warn(
"`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.",
FutureWarning,
)
return self.image_processor
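A tiny illustration of the de-duplication done by the `model_input_names` property above (the input-name lists are hypothetical but typical for a BERT tokenizer and a ViT image processor):
```
tokenizer_input_names = ["input_ids", "token_type_ids", "attention_mask"]
image_processor_input_names = ["pixel_values"]
print(list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)))
# -> ['input_ids', 'token_type_ids', 'attention_mask', 'pixel_values']
```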