Transformers Source Code Walkthrough (49)
.\models\flaubert\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {
"configuration_flaubert": ["FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FlaubertConfig", "FlaubertOnnxConfig"],
"tokenization_flaubert": ["FlaubertTokenizer"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flaubert"] = [
"FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"FlaubertForMultipleChoice",
"FlaubertForQuestionAnswering",
"FlaubertForQuestionAnsweringSimple",
"FlaubertForSequenceClassification",
"FlaubertForTokenClassification",
"FlaubertModel",
"FlaubertWithLMHeadModel",
"FlaubertPreTrainedModel",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_flaubert"] = [
"TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFFlaubertForMultipleChoice",
"TFFlaubertForQuestionAnsweringSimple",
"TFFlaubertForSequenceClassification",
"TFFlaubertForTokenClassification",
"TFFlaubertModel",
"TFFlaubertPreTrainedModel",
"TFFlaubertWithLMHeadModel",
]
if TYPE_CHECKING:
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig, FlaubertOnnxConfig
from .tokenization_flaubert import FlaubertTokenizer
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flaubert import (
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
FlaubertForMultipleChoice,
FlaubertForQuestionAnswering,
FlaubertForQuestionAnsweringSimple,
FlaubertForSequenceClassification,
FlaubertForTokenClassification,
FlaubertModel,
FlaubertPreTrainedModel,
FlaubertWithLMHeadModel,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_flaubert import (
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFFlaubertForMultipleChoice,
TFFlaubertForQuestionAnsweringSimple,
TFFlaubertForSequenceClassification,
TFFlaubertForTokenClassification,
TFFlaubertModel,
TFFlaubertPreTrainedModel,
TFFlaubertWithLMHeadModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
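# Usage sketch (assuming `transformers` is installed with PyTorch available): because of the
# `_LazyModule` registration above, the heavy submodules are only imported when an attribute is
# first accessed.
from transformers.models.flaubert import FlaubertConfig

config = FlaubertConfig()  # accessing the name triggers the lazy import of configuration_flaubert
print(config.model_type)   # "flaubert"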
.\models\flava\configuration_flava.py
""" FLAVA model configurations"""
import os
from typing import Any, Dict, Union
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/flava-full": "https://huggingface.co/facebook/flava-full/resolve/main/config.json",
}
class FlavaImageConfig(PretrainedConfig):
r"""
    This is the configuration class to store the configuration of a [`FlavaImageModel`]. It is used to instantiate a
FLAVA model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
[facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
    """
    # The model type identifier for the FLAVA image model
model_type = "flava_image_model"
    # Initializer setting the model configuration parameters
def __init__(
self,
        hidden_size: int = 768,  # hidden size, defaults to 768
        num_hidden_layers: int = 12,  # number of hidden layers, defaults to 12
        num_attention_heads: int = 12,  # number of attention heads, defaults to 12
        intermediate_size: int = 3072,  # intermediate (feed-forward) size, defaults to 3072
        hidden_act: str = "gelu",  # hidden activation function, defaults to "gelu"
        hidden_dropout_prob: float = 0.0,  # hidden-layer dropout probability, defaults to 0.0
        attention_probs_dropout_prob: float = 0.0,  # attention dropout probability, defaults to 0.0
        initializer_range: float = 0.02,  # initializer range, defaults to 0.02
        layer_norm_eps: float = 1e-12,  # LayerNorm epsilon, defaults to 1e-12
        image_size: int = 224,  # input image size, defaults to 224
        patch_size: int = 16,  # patch size, defaults to 16
        num_channels: int = 3,  # number of image channels, defaults to 3
        qkv_bias: bool = True,  # whether to use a bias in the QKV projections, defaults to True
        mask_token: bool = True,  # whether to use a mask token, defaults to True
        vocab_size: int = 8192,  # codebook vocabulary size, defaults to 8192
        **kwargs,  # remaining keyword arguments
):
        super().__init__(**kwargs)  # call the parent-class initializer
        self.hidden_size = hidden_size  # hidden size
        self.num_hidden_layers = num_hidden_layers  # number of hidden layers
        self.num_attention_heads = num_attention_heads  # number of attention heads
        self.intermediate_size = intermediate_size  # intermediate size
        self.hidden_act = hidden_act  # hidden activation function
        self.hidden_dropout_prob = hidden_dropout_prob  # hidden-layer dropout probability
        self.attention_probs_dropout_prob = attention_probs_dropout_prob  # attention dropout probability
        self.initializer_range = initializer_range  # initializer range
        self.layer_norm_eps = layer_norm_eps  # LayerNorm epsilon
        self.image_size = image_size  # input image size
        self.patch_size = patch_size  # patch size
        self.num_channels = num_channels  # number of image channels
        self.qkv_bias = qkv_bias  # whether to use a bias in the QKV projections
        self.mask_token = mask_token  # whether to use a mask token
        self.vocab_size = vocab_size  # codebook vocabulary size
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)  # set the auth token in kwargs
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)  # fetch the config dict and the updated kwargs
        # If loading from a full FlavaConfig, pull out the image sub-config
if config_dict.get("model_type") == "flava":
config_dict = config_dict["image_config"]
        # Warn if the config dict declares a model_type that differs from this class's model_type
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
        return cls.from_dict(config_dict, **kwargs)  # build the instance from the config dict and kwargs
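# A short usage sketch mirroring the example docstrings of the other FLAVA configs: build a default
# image config and inspect a few of the fields stored by __init__ above.
image_config = FlavaImageConfig()
print(image_config.hidden_size, image_config.image_size, image_config.patch_size)  # 768 224 16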
class FlavaTextConfig(PretrainedConfig):
r"""
    This is the configuration class to store the configuration of a [`FlavaTextModel`]. It is used to instantiate a
FLAVA model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
[facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```
    >>> from transformers import FlavaTextConfig, FlavaTextModel
    >>> # Initializing a FlavaTextConfig with style configuration
    >>> configuration = FlavaTextConfig()
    >>> # Initializing a FlavaTextModel (with random weights) from the style configuration
    >>> model = FlavaTextModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
```"""
model_type = "flava_text_model"
def __init__(
self,
vocab_size: int = 30522,
type_vocab_size: int = 2,
max_position_embeddings: int = 512,
position_embedding_type: str = "absolute",
hidden_size: int = 768,
num_hidden_layers: int = 12,
num_attention_heads: int = 12,
intermediate_size: int = 3072,
hidden_act: str = "gelu",
hidden_dropout_prob: float = 0.0,
attention_probs_dropout_prob: float = 0.0,
initializer_range: float = 0.02,
layer_norm_eps: float = 1e-12,
pad_token_id: int = 0,
qkv_bias: bool = True,
**kwargs,
):
        # Call the parent constructor to initialize the attributes inherited from PretrainedConfig
        super().__init__(**kwargs)
        # Store the configuration parameters
        self.vocab_size = vocab_size  # vocabulary size
        self.type_vocab_size = type_vocab_size  # token-type vocabulary size
        self.max_position_embeddings = max_position_embeddings  # maximum number of position embeddings
        self.position_embedding_type = position_embedding_type  # position embedding type
        self.hidden_size = hidden_size  # hidden size
        self.num_hidden_layers = num_hidden_layers  # number of hidden layers
        self.num_attention_heads = num_attention_heads  # number of attention heads
        self.intermediate_size = intermediate_size  # intermediate size
        self.hidden_act = hidden_act  # hidden activation function
        self.hidden_dropout_prob = hidden_dropout_prob  # hidden-layer dropout probability
        self.attention_probs_dropout_prob = attention_probs_dropout_prob  # attention dropout probability
        self.initializer_range = initializer_range  # initializer range
        self.layer_norm_eps = layer_norm_eps  # LayerNorm epsilon
        self.qkv_bias = qkv_bias  # whether to use a bias in the QKV projections
        self.pad_token_id = pad_token_id  # padding token id
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        # Set the auth token in kwargs
        cls._set_token_in_kwargs(kwargs)
        # Fetch the config dict of the pretrained model and the updated kwargs
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        # If the config dict describes a full "flava" model, use its text_config as the config dict
        if config_dict.get("model_type") == "flava":
            config_dict = config_dict["text_config"]
        # Warn if the config dict declares a model_type that differs from this class's model_type
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )
        # Build an instance from the config dict and kwargs
return cls.from_dict(config_dict, **kwargs)
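# Because `from_pretrained` above special-cases `model_type == "flava"`, the text sub-configuration can
# be loaded directly from the full checkpoint. A hedged sketch (requires access to the Hugging Face Hub):
text_config = FlavaTextConfig.from_pretrained("facebook/flava-full")
print(text_config.model_type)  # "flava_text_model"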
# Configuration class storing the configuration of FlavaMultimodalModel; it inherits from PretrainedConfig.
class FlavaMultimodalConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`FlavaMultimodalModel`]. It is used to instantiate
    a FLAVA model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
[facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 6):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` are supported.
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
qkv_bias (`bool`, *optional*, defaults to `True`):
Whether to add a bias to the queries, keys and values.
use_cls_token (`bool`, *optional*, defaults to `True`):
Whether to use an extra CLS token for multimodal settings. Usually needed by the FLAVA model.
Example:
```
    >>> from transformers import FlavaMultimodalConfig, FlavaMultimodalModel
    >>> # Initializing a FlavaMultimodalConfig with style configuration
    >>> configuration = FlavaMultimodalConfig()
    >>> # Initializing a FlavaMultimodalModel (with random weights) from the style configuration
    >>> model = FlavaMultimodalModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
```"""
    # Class attribute declaring the model type as flava_multimodal_model
    model_type = "flava_multimodal_model"
    # Initializer setting the model configuration parameters
def __init__(
self,
        hidden_size: int = 768,  # hidden size, defaults to 768
        num_hidden_layers: int = 6,  # number of hidden layers, defaults to 6
        num_attention_heads: int = 12,  # number of attention heads, defaults to 12
        intermediate_size: int = 3072,  # intermediate size, defaults to 3072
        hidden_act: str = "gelu",  # hidden activation function, defaults to "gelu"
        hidden_dropout_prob: float = 0.0,  # hidden-layer dropout probability, defaults to 0.0
        attention_probs_dropout_prob: float = 0.0,  # attention dropout probability, defaults to 0.0
        initializer_range: float = 0.02,  # initializer range, defaults to 0.02
        layer_norm_eps: float = 1e-12,  # LayerNorm epsilon, defaults to 1e-12
        qkv_bias: bool = True,  # whether to use a bias in the QKV projections, defaults to True
        use_cls_token: bool = True,  # whether to use an extra CLS token, defaults to True
        **kwargs,  # remaining keyword arguments
):
        # Call the parent-class initializer with the remaining keyword arguments
        super().__init__(**kwargs)
        # Store the configuration parameters
        self.hidden_size = hidden_size  # hidden size
        self.num_hidden_layers = num_hidden_layers  # number of hidden layers
        self.num_attention_heads = num_attention_heads  # number of attention heads
        self.intermediate_size = intermediate_size  # intermediate size
        self.hidden_act = hidden_act  # hidden activation function
        self.hidden_dropout_prob = hidden_dropout_prob  # hidden-layer dropout probability
        self.attention_probs_dropout_prob = attention_probs_dropout_prob  # attention dropout probability
        self.initializer_range = initializer_range  # initializer range
        self.layer_norm_eps = layer_norm_eps  # LayerNorm epsilon
        self.qkv_bias = qkv_bias  # whether to use a bias in the QKV projections
        self.use_cls_token = use_cls_token  # whether to use an extra CLS token
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        # Set the auth token in kwargs
        cls._set_token_in_kwargs(kwargs)
        # Fetch the config dict of the model and the remaining kwargs
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        # If the config dict describes a full "flava" model, use its multimodal_config
        if config_dict.get("model_type") == "flava":
            config_dict = config_dict["multimodal_config"]
        # Warn if the config dict declares a model_type that differs from this class's model_type
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )
        # Build a new instance from the config dict and kwargs
return cls.from_dict(config_dict, **kwargs)
class FlavaImageCodebookConfig(PretrainedConfig):
model_type = "flava_image_codebook"
r"""
[`FlavaImageCodebookConfig`] is the configuration class to store the configuration of a [`FlavaImageCodebook`]. It
    is used to instantiate a FLAVA model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
[facebook/flava-image-codebook](https://huggingface.co/facebook/flava-image-codebook) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
num_groups (`int`, defaults to 4):
Number of groups to be created. This parameter as of now doesn't affect the model and is used for some
internal calculation and estimations.
input_channels (`int`, defaults to 3):
Number of channels in the image to be passed.
num_blocks_per_group (`int`, defaults to 2):
Number of conv-based blocks per group.
hidden_size (`int`, defaults to 256):
Size of hidden dim for the blocks.
vocab_size (`int`, defaults to 8192):
Size of the output vocabulary for the codebook.
freeze (`bool`, defaults to `True`):
Whether to freeze the weights of the model.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
kwargs (*optional*):
Dictionary of keyword arguments.
Example:
```
>>> from transformers import FlavaImageCodebookConfig, FlavaImageCodebook
>>> # Initializing a FlavaImageCodebook with style configuration
>>> configuration = FlavaImageCodebookConfig()
>>> # Initializing a FlavaImageCodebook model (with random weights) from the style configuration
>>> model = FlavaImageCodebook(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
def __init__(
self,
num_groups: int = 4,
input_channels: int = 3,
num_blocks_per_group: int = 2,
hidden_size: int = 256,
vocab_size: int = 8192,
        freeze: bool = True,
initializer_range: float = 0.02,
**kwargs,
):
        # Call the parent-class initializer with any extra keyword arguments
        super().__init__(**kwargs)
        # Store the configuration parameters
        self.num_groups = num_groups  # number of groups to create
        self.input_channels = input_channels  # number of channels of the input image
        self.num_blocks_per_group = num_blocks_per_group  # number of conv-based blocks per group
        self.hidden_size = hidden_size  # hidden dimension of the blocks
        self.vocab_size = vocab_size  # size of the codebook output vocabulary
        self.freeze = freeze  # whether to freeze the model weights
        self.initializer_range = initializer_range  # standard deviation for weight initialization
@classmethod
    # Build a pretrained config object from a model name or path plus extra keyword arguments
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        # Set the auth token in kwargs
        cls._set_token_in_kwargs(kwargs)
        # Fetch the config dict of the pretrained model and the updated kwargs
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        # If the config dict describes a full "flava" model, use its image_codebook_config
        if config_dict.get("model_type") == "flava":
            config_dict = config_dict["image_codebook_config"]
        # Warn if the config dict declares a model_type that differs from this class's model_type,
        # since mixing configurations across model types can yield errors.
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )
        # Build the pretrained config object from the config dict and kwargs and return it
return cls.from_dict(config_dict, **kwargs)
# `FlavaConfig` is a configuration class that inherits from `PretrainedConfig`
class FlavaConfig(PretrainedConfig):
r"""
[`FlavaConfig`] is the configuration class to store the configuration of a [`FlavaModel`]. It is used to
    instantiate a FLAVA model according to the specified arguments, defining the text model, image model, image codebook
and multimodal model configs. Instantiating a configuration with the defaults will yield a similar configuration to
that of the FLAVA [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
        text_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`FlavaTextConfig`].
        image_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`FlavaImageConfig`].
        multimodal_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`FlavaMultimodalConfig`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        projection_dim (`int`, *optional*, defaults to 768):
            Dimensionality of the text and image projection layers.
        logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
            The initial value of the *logit_scale* parameter, following the default used by the FLAVA/CLIP
            implementation.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated normal initializer for initializing all weight matrices.
        ce_ignore_index (`int`, *optional*, defaults to -100):
            Index to ignore in the cross-entropy loss.
        mim_weight (`float`, *optional*, defaults to 1.0):
            Weight assigned to the unimodal MIM (Masked Image Modeling) loss.
        mlm_weight (`float`, *optional*, defaults to 1.0):
            Weight assigned to the unimodal MLM (Masked Language Modeling) loss.
        global_contrastive_weight (`float`, *optional*, defaults to 1.0):
            Weight assigned to the global contrastive cross-alignment loss.
        itm_weight (`float`, *optional*, defaults to 1.0):
            Weight assigned to the image-text matching multimodal loss.
        mmm_image_weight (`float`, *optional*, defaults to 1.0):
            Weight assigned to the image part of the MMM loss.
        mmm_text_weight (`float`, *optional*, defaults to 1.0):
            Weight assigned to the text part of the MMM loss.
        global_backprop_contrastive (`bool`, *optional*, defaults to `True`):
            Whether to backpropagate through all workers in the global contrastive loss.
        skip_unmasked_multimodal_encoder (`bool`, *optional*, defaults to `True`):
            Whether to skip running the unmasked multimodal encoder, whose outputs are not used by the FLAVA losses.
        return_loss (`bool`, *optional*, defaults to `True`):
            Whether or not to return the loss.
        kwargs (*optional*):
            Dictionary of keyword arguments.
    Example:
    ```
    >>> from transformers import FlavaConfig, FlavaModel, FlavaForPreTraining
    >>> # Initializing a FlavaConfig with style configuration
    >>> configuration = FlavaConfig()
    >>> # Initializing FlavaModel and FlavaForPreTraining models (with random weights) from the style configuration
    >>> model = FlavaModel(configuration)
    >>> model_pre = FlavaForPreTraining(configuration)
    >>> # Accessing the model configurations
    >>> configuration = model.config
    >>> configuration_pre = model_pre.config
    ```"""
    model_type = "flava"
def __init__(
self,
image_config: Dict[str, Any] = None,
text_config: Dict[str, Any] = None,
multimodal_config: Dict[str, Any] = None,
image_codebook_config: Dict[str, Any] = None,
hidden_size: int = 768,
layer_norm_eps: float = 1e-12,
projection_dim: int = 768,
init_codebook: bool = True,
logit_scale_init_value: float = 2.6592,
initializer_range: float = 0.02,
ce_ignore_index: int = -100,
mim_weight: float = 1.0,
mlm_weight: float = 1.0,
global_contrastive_weight: float = 1.0,
itm_weight: float = 1.0,
mmm_image_weight: float = 1.0,
mmm_text_weight: float = 1.0,
global_backprop_contrastive: bool = True,
skip_unmasked_multimodal_encoder: bool = True,
return_loss: bool = True,
**kwargs,
    ):
        ...  # (the body of __init__ is elided in this excerpt)
    @classmethod
    def from_configs(cls, image_config, text_config, multimodal_config, image_codebook_config, **kwargs):
        r"""
        Instantiate a [`FlavaConfig`] (or a derived class) from flava text model configuration, flava image model
        configuration, flava multimodal model and flava codebook model configuration.
        Returns:
            [`FlavaConfig`]: An instance of a configuration object
        """
        # Build the FlavaConfig instance from the given sub-configs, serialized to dictionaries
return cls(
image_config=image_config.to_dict(),
text_config=text_config.to_dict(),
multimodal_config=multimodal_config.to_dict(),
image_codebook_config=image_codebook_config.to_dict(),
**kwargs,
)
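# A minimal sketch of composing the full configuration from its sub-configs via the `from_configs`
# classmethod shown above (defaults only, no pretrained weights involved):
config = FlavaConfig.from_configs(
    image_config=FlavaImageConfig(),
    text_config=FlavaTextConfig(),
    multimodal_config=FlavaMultimodalConfig(),
    image_codebook_config=FlavaImageCodebookConfig(),
)
print(config.model_type)  # "flava"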
.\models\flava\convert_dalle_to_flava_codebook.py
import argparse
import os
import torch
from transformers import FlavaImageCodebook, FlavaImageCodebookConfig
def rreplace(s, old, new, occurrence):
li = s.rsplit(old, occurrence)
return new.join(li)
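# `rreplace` replaces only the last `occurrence` matches of `old`, scanning from the right; it is used
# below to rewrite trailing ".w"/".b" suffixes without touching earlier parts of the key. Illustration
# (the key name here is made up for demonstration):
print(rreplace("group_1.0.conv.w", ".w", ".weight", 1))  # group_1.0.conv.weight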
def count_parameters(state_dict):
return sum(param.float().sum() if "encoder.embeddings" not in key else 0 for key, param in state_dict.items())
def upgrade_state_dict(state_dict):
upgrade = {}
group_keys = ["group_1", "group_2", "group_3", "group_4"]
for key, value in state_dict.items():
for group_key in group_keys:
if group_key in key:
key = key.replace(f"{group_key}.", f"{group_key}.group.")
if "res_path" in key:
key = key.replace("res_path.", "res_path.path.")
if key.endswith(".w"):
key = rreplace(key, ".w", ".weight", 1)
if key.endswith(".b"):
key = rreplace(key, ".b", ".bias", 1)
upgrade[key] = value.float()
return upgrade
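# A quick illustration of the renaming performed by `upgrade_state_dict` (the key name is illustrative,
# not taken from a real DALL-E checkpoint):
sample = {"blocks.group_1.res_path.0.w": torch.ones(1)}
print(list(upgrade_state_dict(sample)))  # ['blocks.group_1.group.res_path.path.0.weight']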
@torch.no_grad()
def convert_dalle_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None, save_checkpoint=True):
"""
    Copy/paste/tweak the model's weights to the transformers design.
"""
from dall_e import Encoder
encoder = Encoder()
if os.path.exists(checkpoint_path):
ckpt = torch.load(checkpoint_path)
else:
ckpt = torch.hub.load_state_dict_from_url(checkpoint_path)
if isinstance(ckpt, Encoder):
ckpt = ckpt.state_dict()
encoder.load_state_dict(ckpt)
if config_path is not None:
config = FlavaImageCodebookConfig.from_pretrained(config_path)
else:
config = FlavaImageCodebookConfig()
hf_model = FlavaImageCodebook(config).eval()
state_dict = encoder.state_dict()
hf_state_dict = upgrade_state_dict(state_dict)
hf_model.load_state_dict(hf_state_dict)
hf_state_dict = hf_model.state_dict()
hf_count = count_parameters(hf_state_dict)
state_dict_count = count_parameters(state_dict)
assert torch.allclose(hf_count, state_dict_count, atol=1e-3)
if save_checkpoint:
hf_model.save_pretrained(pytorch_dump_folder_path)
else:
return hf_state_dict
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to flava checkpoint")
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
args = parser.parse_args()
convert_dalle_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
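    # Example invocation (paths are placeholders; the checkpoint may also be a URL, in which case it is
    # fetched via `torch.hub.load_state_dict_from_url` as handled above):
    #   python convert_dalle_to_flava_codebook.py \
    #       --checkpoint_path /path/to/dalle_encoder.pkl \
    #       --pytorch_dump_folder_path ./flava-image-codebook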
.\models\flava\convert_flava_original_pytorch_to_hf.py
import argparse
import os
import torch
from transformers import FlavaConfig, FlavaForPreTraining
from transformers.models.flava.convert_dalle_to_flava_codebook import convert_dalle_checkpoint
def count_parameters(state_dict):
return sum(param.float().sum() if "encoder.embeddings" not in key else 0 for key, param in state_dict.items())
def upgrade_state_dict(state_dict, codebook_state_dict):
upgrade = {}
for key, value in state_dict.items():
if "text_encoder.embeddings" in key or "image_encoder.embeddings" in key:
continue
key = key.replace("heads.cmd.mim_head.cls.predictions", "mmm_image_head")
key = key.replace("heads.cmd.mlm_head.cls.predictions", "mmm_text_head")
key = key.replace("heads.cmd.itm_head.cls", "itm_head")
key = key.replace("heads.cmd.itm_head.pooler", "itm_head.pooler")
key = key.replace("heads.cmd.clip_head.logit_scale", "flava.logit_scale")
key = key.replace("heads.fairseq_mlm.cls.predictions", "mlm_head")
key = key.replace("heads.imagenet.mim_head.cls.predictions", "mim_head")
key = key.replace("mm_text_projection", "flava.text_to_mm_projection")
key = key.replace("mm_image_projection", "flava.image_to_mm_projection")
key = key.replace("image_encoder.module", "flava.image_model")
key = key.replace("text_encoder.module", "flava.text_model")
key = key.replace("mm_encoder.module.encoder.cls_token", "flava.multimodal_model.cls_token")
key = key.replace("mm_encoder.module", "flava.multimodal_model")
key = key.replace("text_projection", "flava.text_projection")
key = key.replace("image_projection", "flava.image_projection")
upgrade[key] = value.float()
for key, value in codebook_state_dict.items():
upgrade[f"image_codebook.{key}"] = value
return upgrade
@torch.no_grad()
def convert_flava_checkpoint(checkpoint_path, codebook_path, pytorch_dump_folder_path, config_path=None):
"""
Copy/paste/tweak model's weights to transformers design.
"""
if config_path is not None:
config = FlavaConfig.from_pretrained(config_path)
else:
config = FlavaConfig()
hf_model = FlavaForPreTraining(config).eval()
codebook_state_dict = convert_dalle_checkpoint(codebook_path, None, save_checkpoint=False)
if os.path.exists(checkpoint_path):
state_dict = torch.load(checkpoint_path, map_location="cpu")
else:
state_dict = torch.hub.load_state_dict_from_url(checkpoint_path, map_location="cpu")
hf_state_dict = upgrade_state_dict(state_dict, codebook_state_dict)
hf_model.load_state_dict(hf_state_dict)
hf_state_dict = hf_model.state_dict()
hf_count = count_parameters(hf_state_dict)
state_dict_count = count_parameters(state_dict) + count_parameters(codebook_state_dict)
assert torch.allclose(hf_count, state_dict_count, atol=1e-3)
hf_model.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to flava checkpoint")
parser.add_argument("--codebook_path", default=None, type=str, help="Path to flava codebook checkpoint")
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
args = parser.parse_args()
convert_flava_checkpoint(args.checkpoint_path, args.codebook_path, args.pytorch_dump_folder_path, args.config_path)
.\models\flava\feature_extraction_flava.py
"""FLAVA 的特征提取器类。"""
import warnings
from ...utils import logging
from .image_processing_flava import FlavaImageProcessor
logger = logging.get_logger(__name__)
class FlavaFeatureExtractor(FlavaImageProcessor):
def __init__(self, *args, **kwargs) -> None:
warnings.warn(
"The class FlavaFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
" use FlavaImageProcessor instead.",
FutureWarning,
)
super().__init__(*args, **kwargs)
.\models\flava\image_processing_flava.py
"""Image processor class for Flava."""
import math
import random
from functools import lru_cache
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import resize, to_channel_dimension_format
from ...image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
validate_kwargs,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging
if is_vision_available():
import PIL
logger = logging.get_logger(__name__)
FLAVA_IMAGE_MEAN = OPENAI_CLIP_MEAN
FLAVA_IMAGE_STD = OPENAI_CLIP_STD
FLAVA_CODEBOOK_MEAN = [0.0, 0.0, 0.0]
FLAVA_CODEBOOK_STD = [1.0, 1.0, 1.0]
LOGIT_LAPLACE_EPS: float = 0.1
class FlavaMaskingGenerator:
def __init__(
self,
input_size: Union[int, Tuple[int, int]] = 14,
total_mask_patches: int = 75,
mask_group_max_patches: Optional[int] = None,
mask_group_min_patches: int = 16,
mask_group_min_aspect_ratio: Optional[float] = 0.3,
mask_group_max_aspect_ratio: float = None,
):
if not isinstance(input_size, tuple):
input_size = (input_size,) * 2
self.height, self.width = input_size
self.num_patches = self.height * self.width
self.total_mask_patches = total_mask_patches
self.mask_group_min_patches = mask_group_min_patches
self.mask_group_max_patches = total_mask_patches if mask_group_max_patches is None else mask_group_max_patches
mask_group_max_aspect_ratio = mask_group_max_aspect_ratio or 1 / mask_group_min_aspect_ratio
self.log_aspect_ratio = (math.log(mask_group_min_aspect_ratio), math.log(mask_group_max_aspect_ratio))
def __repr__(self):
repr_str = "MaskingGenerator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
self.height,
self.width,
self.mask_group_min_patches,
self.mask_group_max_patches,
self.total_mask_patches,
self.log_aspect_ratio[0],
self.log_aspect_ratio[1],
)
return repr_str
def get_shape(self):
return self.height, self.width
def _mask(self, mask, max_mask_patches):
delta = 0
for _attempt in range(10):
target_area = random.uniform(self.mask_group_min_patches, max_mask_patches)
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
height = int(round(math.sqrt(target_area * aspect_ratio)))
width = int(round(math.sqrt(target_area / aspect_ratio)))
if width < self.width and height < self.height:
top = random.randint(0, self.height - height)
left = random.randint(0, self.width - width)
num_masked = mask[top : top + height, left : left + width].sum()
if 0 < height * width - num_masked <= max_mask_patches:
for i in range(top, top + height):
for j in range(left, left + width):
if mask[i, j] == 0:
mask[i, j] = 1
delta += 1
if delta > 0:
break
return delta
def __call__(self):
mask = np.zeros(shape=self.get_shape(), dtype=int)
mask_count = 0
while mask_count < self.total_mask_patches:
max_mask_patches = self.total_mask_patches - mask_count
max_mask_patches = min(max_mask_patches, self.mask_group_max_patches)
delta = self._mask(mask, max_mask_patches)
if delta == 0:
break
else:
mask_count += delta
return mask
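# A minimal sketch of the masking generator: on the default 14x14 patch grid it keeps adding random
# rectangular groups until close to `total_mask_patches` entries are set to 1 (never more than that,
# possibly fewer if placement repeatedly fails).
generator = FlavaMaskingGenerator(input_size=14, total_mask_patches=75)
mask = generator()
print(mask.shape, int(mask.sum()))  # (14, 14) and a count of at most 75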
class FlavaImageProcessor(BaseImageProcessor):
r"""
    Constructs a FLAVA image processor.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_center_crop: bool = True,
crop_size: Dict[str, int] = None,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, Iterable[float]]] = None,
image_std: Optional[Union[float, Iterable[float]]] = None,
return_image_mask: bool = False,
input_size_patches: int = 14,
total_mask_patches: int = 75,
mask_group_min_patches: int = 16,
mask_group_max_patches: Optional[int] = None,
mask_group_min_aspect_ratio: float = 0.3,
mask_group_max_aspect_ratio: Optional[float] = None,
return_codebook_pixels: bool = False,
codebook_do_resize: bool = True,
        codebook_size: Dict[str, int] = None,
        codebook_resample: PILImageResampling = PILImageResampling.LANCZOS,
        codebook_do_center_crop: bool = True,
        codebook_crop_size: Dict[str, int] = None,
codebook_do_rescale: bool = True,
codebook_rescale_factor: Union[int, float] = 1 / 255,
codebook_do_map_pixels: bool = True,
codebook_do_normalize: bool = True,
codebook_image_mean: Optional[Union[float, Iterable[float]]] = None,
codebook_image_std: Optional[Union[float, Iterable[float]]] = None,
        **kwargs,
    ):
        ...  # (the __init__ body, which stores the settings above, is elided in this excerpt)
    @classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
重写基类的 `from_dict` 方法,以确保在使用 from_dict 和 kwargs 创建图像处理器时更新参数,
例如 `FlavaImageProcessor.from_pretrained(checkpoint, codebook_size=600)`
"""
image_processor_dict = image_processor_dict.copy()
if "codebook_size" in kwargs:
image_processor_dict["codebook_size"] = kwargs.pop("codebook_size")
if "codebook_crop_size" in kwargs:
image_processor_dict["codebook_crop_size"] = kwargs.pop("codebook_crop_size")
return super().from_dict(image_processor_dict, **kwargs)
@lru_cache()
def masking_generator(
self,
input_size_patches,
total_mask_patches,
mask_group_min_patches,
mask_group_max_patches,
mask_group_min_aspect_ratio,
mask_group_max_aspect_ratio,
) -> FlavaMaskingGenerator:
return FlavaMaskingGenerator(
input_size=input_size_patches,
total_mask_patches=total_mask_patches,
mask_group_min_patches=mask_group_min_patches,
mask_group_max_patches=mask_group_max_patches,
mask_group_min_aspect_ratio=mask_group_min_aspect_ratio,
mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
)
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize an image to `(size["height"], size["width"])`.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
Returns:
`np.ndarray`: The resized image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
output_size = (size["height"], size["width"])
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def map_pixels(self, image: np.ndarray) -> np.ndarray:
"""
Maps pixel values of an image using a specific constant.
Args:
image (`np.ndarray`):
Input image.
Returns:
`np.ndarray`: Processed image with mapped pixel values.
"""
return (1 - 2 * LOGIT_LAPLACE_EPS) * image + LOGIT_LAPLACE_EPS
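    # With LOGIT_LAPLACE_EPS = 0.1 the mapping above is x -> 0.8 * x + 0.1, which squeezes [0, 1] pixel
    # values into [0.1, 0.9] (0.0 -> 0.1, 0.5 -> 0.5, 1.0 -> 0.9), the logit-Laplace convention expected
    # by the DALL-E style image codebook.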
def _preprocess_image(
self,
image: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_map_pixels: bool = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[ChannelDimension] = None,
) -> np.ndarray:
"""Preprocesses a single image."""
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
image = to_numpy_array(image)
if is_scaled_image(image) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
if do_resize:
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
if do_center_crop:
image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
if do_rescale:
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
if do_normalize:
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
if do_map_pixels:
image = self.map_pixels(image)
if data_format is not None:
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
return image
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[Dict[str, int]] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_image_mask: Optional[bool] = None,
input_size_patches: Optional[int] = None,
total_mask_patches: Optional[int] = None,
mask_group_min_patches: Optional[int] = None,
mask_group_max_patches: Optional[int] = None,
mask_group_min_aspect_ratio: Optional[float] = None,
mask_group_max_aspect_ratio: Optional[float] = None,
return_codebook_pixels: Optional[bool] = None,
codebook_do_resize: Optional[bool] = None,
codebook_size: Optional[Dict[str, int]] = None,
codebook_resample: Optional[int] = None,
codebook_do_center_crop: Optional[bool] = None,
codebook_crop_size: Optional[Dict[str, int]] = None,
codebook_do_rescale: Optional[bool] = None,
codebook_rescale_factor: Optional[float] = None,
codebook_do_map_pixels: Optional[bool] = None,
codebook_do_normalize: Optional[bool] = None,
codebook_image_mean: Optional[Iterable[float]] = None,
codebook_image_std: Optional[Iterable[float]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
    ):
        ...  # (the preprocessing body, which applies the transforms configured above, is elided in this excerpt)
.\models\flava\modeling_flava.py
""" PyTorch FLAVA model. """
import collections
import math
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_flava import (
FlavaConfig,
FlavaImageCodebookConfig,
FlavaImageConfig,
FlavaMultimodalConfig,
FlavaTextConfig,
)
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "facebook/flava-full"
_CHECKPOINT_FOR_CODEBOOK_DOC = "facebook/flava-image-codebook"
_CONFIG_CLASS_FOR_IMAGE_MODEL_DOC = "FlavaImageConfig"
_CONFIG_CLASS_FOR_TEXT_MODEL_DOC = "FlavaTextConfig"
_CONFIG_CLASS_FOR_MULTIMODAL_MODEL_DOC = "FlavaMultimodalConfig"
_EXPECTED_IMAGE_OUTPUT_SHAPE = [1, 197, 768]
FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/flava-full",
]
FLAVA_CODEBOOK_PRETRAINED_MODEL_ARCHIVE_LIST = ["facebook/flava-image-codebook"]
LOGIT_SCALE_CLAMP_MIN = 0
LOGIT_SCALE_CLAMP_MAX = 4.6052
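# The logit scale is stored in log space, so these bounds clamp the contrastive temperature to
# [exp(0), exp(4.6052)] ~= [1, 100], the same maximum of 100 used by CLIP.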
FlavaPossibleConfigs = Union[FlavaTextConfig, FlavaImageConfig, FlavaMultimodalConfig]
@dataclass
class FlavaModelOutput(ModelOutput):
"""
    Output from FlavaModel containing embeddings and outputs from individual encoders.
    Note that `image_embeddings` and `text_embeddings` returned are similar to pooled output returned from a
    transformer. If you want embeddings for contrastive loss or retrieval, use a FLAVA model's `image_projection` and
    `text_projection` layers on `image_embeddings` and `text_embeddings` respectively.
Args:
image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
The image embeddings which are basically the pooled output of [`FlavaImageModel`].
image_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
The output of the [`FlavaImageModel`].
text_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present):
The text embeddings which are basically the pooled output of [`FlavaTextModel`].
text_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present):
The output of the [`FlavaTextModel`].
multimodal_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`):
            The multimodal embeddings which are basically the pooled output of [`FlavaMultimodalModel`].
multimodal_output (`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`):
The output of the [`FlavaMultimodalModel`].
"""
image_embeddings: Optional[torch.FloatTensor] = None
image_output: Optional[BaseModelOutputWithPooling] = None
text_embeddings: Optional[torch.FloatTensor] = None
text_output: Optional[BaseModelOutputWithPooling] = None
multimodal_embeddings: Optional[torch.FloatTensor] = None
multimodal_output: Optional[BaseModelOutputWithPooling] = None
def to_tuple(self) -> Tuple[Any]:
return tuple(
self[k] if k not in ["text_output", "image_output", "multimodal_output"] else getattr(self, k).to_tuple()
for k in self.keys()
)
@dataclass
class FlavaLosses(ModelOutput):
"""Class representing pretraining losses from FLAVA model
Args:
mim (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels` and `pixel_values` are present, `input_ids_masked` is absent and `mim_weight` > 0.:
Masked Image Modeling loss as used in BeIT calculated only for unimodal image data.
mlm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels` and `input_ids_masked` are present, `pixel_values` is absent and `mlm_weight` > 0.:
Masked Language Modeling loss as used in BERT calculated only for unimodal text data.
itm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `itm_labels`, `input_ids_masked`, `pixel_values` are present and `itm_weight` > 0.:
Image Text Matching (ITM) loss calculated for paired image-text data. Note that ITM loss is calculated on
masked pairs in FLAVA.
global_contrastive (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `input_ids` and `pixel_values` are present and `global_contrastive_weight` > 0.:
Contrastive loss for image-text similarity similar to CLIP but calculated globally for paired image-text
data. This is calculated on unmasked images and texts.
mmm_image (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_image_weight` > 0.:
Masked Multimodal Modeling loss's image component calculated on paired image-text data.
mmm_text (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_text_weight` > 0.:
Masked Multimodal Modeling loss's text component calculated on paired image-text data.
"""
mim: Optional[torch.FloatTensor] = None
mlm: Optional[torch.FloatTensor] = None
itm: Optional[torch.FloatTensor] = None
global_contrastive: Optional[torch.FloatTensor] = None
mmm_image: Optional[torch.FloatTensor] = None
mmm_text: Optional[torch.FloatTensor] = None
def all_none(self) -> bool:
all_none = True
for v in self.values():
if v is not None:
all_none = False
break
return all_none
@dataclass
class FlavaForPreTrainingOutput(ModelOutput):
"""
Output from FlavaForPreTraining containing embeddings, and outputs from individual encoders.
Note that `image_embeddings` and `text_embeddings` returned are similar to pooled output returned from a
transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and
`text_projection` layers on `image_embeddings` and `text_embeddings` respectively.
"""
loss: Optional[torch.FloatTensor] = None
loss_info: FlavaLosses = None
image_embeddings: Optional[torch.FloatTensor] = None
image_output: Optional[BaseModelOutputWithPooling] = None
text_embeddings: Optional[torch.FloatTensor] = None
text_output: Optional[BaseModelOutputWithPooling] = None
multimodal_embeddings: Optional[torch.FloatTensor] = None
multimodal_output: Optional[BaseModelOutputWithPooling] = None
image_masked_embeddings: Optional[torch.FloatTensor] = None
image_masked_output: Optional[BaseModelOutputWithPooling] = None
text_masked_embeddings: Optional[torch.FloatTensor] = None
text_masked_output: Optional[BaseModelOutputWithPooling] = None
multimodal_masked_embeddings: Optional[torch.FloatTensor] = None
multimodal_masked_output: Optional[BaseModelOutputWithPooling] = None
mim_logits: Optional[torch.FloatTensor] = None
mlm_logits: Optional[torch.FloatTensor] = None
itm_logits: Optional[torch.FloatTensor] = None
contrastive_logits_per_image: Optional[torch.FloatTensor] = None
contrastive_logits_per_text: Optional[torch.FloatTensor] = None
mmm_image_logits: Optional[torch.FloatTensor] = None
mmm_text_logits: Optional[torch.FloatTensor] = None
def to_tuple(self) -> Tuple[Any]:
transformer_outputs = [
"text_output",
"image_output",
"multimodal_output",
"text_masked_output",
"image_masked_output",
"multimodal_masked_output",
]
return tuple(self[k] if k not in transformer_outputs else getattr(self, k).to_tuple() for k in self.keys())
class FlavaImageEmbeddings(nn.Module):
"""
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
"""
def __init__(self, config: FlavaImageConfig, use_mask_token: bool = False) -> None:
super().__init__()
use_mask_token = use_mask_token or config.mask_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
self.patch_embeddings = PatchEmbeddings(
image_size=config.image_size,
patch_size=config.patch_size,
num_channels=config.num_channels,
embed_dim=config.hidden_size,
)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.config = config
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/image_transformer.py#L174
"""
npatch = embeddings.shape[1] - 1
num_pos = self.position_embeddings.shape[1] - 1
if npatch == num_pos and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
num_h_patches = height // self.config.patch_size
num_w_patches = width // self.config.patch_size
num_h_patches, num_w_patches = num_h_patches + 0.1, num_w_patches + 0.1
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(num_pos)), int(math.sqrt(num_pos)), dim).permute(0, 3, 1, 2),
scale_factor=(num_h_patches / math.sqrt(num_pos), num_w_patches / math.sqrt(num_pos)),
mode="bicubic",
align_corners=False,
)
if int(num_h_patches) != patch_pos_embed.shape[-2] or int(num_w_patches) != patch_pos_embed.shape[-1]:
raise ValueError(
f"Number of patches for images ({int(num_h_patches), int(num_w_patches)}) don't match the "
f"shape of position embedding ({patch_pos_embed.shape[-2], patch_pos_embed.shape[-1]})"
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
batch_size, seq_len, _ = embeddings.size()
if bool_masked_pos is not None:
mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
if bool_masked_pos.dim() == 3:
bool_masked_pos = bool_masked_pos.view(bool_masked_pos.size(0), -1)
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
class PatchEmbeddings(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
image_size: int = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
num_channels: int = 3,
embed_dim: int = 768,
):
super().__init__()
if not isinstance(image_size, collections.abc.Iterable):
image_size = (image_size, image_size)
if not isinstance(patch_size, collections.abc.Iterable):
patch_size = (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = num_patches
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
if not interpolate_pos_encoding:
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
x = self.projection(pixel_values).flatten(2).transpose(1, 2)
return x
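# Shape sketch for the default configuration: a 224x224 image with 16x16 patches yields
# (224 // 16) ** 2 = 196 patch tokens; FlavaImageEmbeddings then prepends the CLS token, giving the
# 197 positions in _EXPECTED_IMAGE_OUTPUT_SHAPE = [1, 197, 768] above.
patch_embed = PatchEmbeddings(image_size=224, patch_size=16, num_channels=3, embed_dim=768)
print(patch_embed(torch.zeros(1, 3, 224, 224)).shape)  # torch.Size([1, 196, 768])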
class FlavaTextEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
):
input_shape = input_ids.size()
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if token_type_ids is None:
if hasattr(self, "token_type_ids"):
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class FlavaSelfAttention(nn.Module):
def __init__(self, config: FlavaPossibleConfigs) -> None:
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
f"heads {config.num_attention_heads}."
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
attention_scores = attention_scores + attention_mask
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
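# Shape walk-through of the attention above with the default image config (hidden_size=768, 12 heads,
# head size 64): hidden_states (B, N, 768) -> query/key/value (B, 12, N, 64) after transpose_for_scores,
# attention_scores (B, 12, N, N) scaled by 1/sqrt(64), and context_layer reshaped back to (B, N, 768).
# A quick standalone check:
attention = FlavaSelfAttention(FlavaImageConfig())
context, probs = attention(torch.zeros(2, 197, 768), output_attentions=True)
print(context.shape, probs.shape)  # torch.Size([2, 197, 768]) torch.Size([2, 12, 197, 197])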
class FlavaSelfOutput(nn.Module):
"""
The residual connection is defined in FlavaLayer (same as ViTLayer) instead of here (as is the case with other
models), due to the layernorm applied before each block.
"""
def __init__(self, config: FlavaPossibleConfigs) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class FlavaAttention(nn.Module):
def __init__(self, config: FlavaPossibleConfigs) -> None:
super().__init__()
self.attention = FlavaSelfAttention(config)
self.output = FlavaSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads: Set[int]) -> None:
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
)
self.attention.query = prune_linear_layer(self.attention.query, index)
self.attention.key = prune_linear_layer(self.attention.key, index)
self.attention.value = prune_linear_layer(self.attention.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_outputs = self.attention(
hidden_states, attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
class FlavaIntermediate(nn.Module):
def __init__(self, config: FlavaPossibleConfigs) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class FlavaOutput(nn.Module):
def __init__(self, config: FlavaPossibleConfigs) -> None:
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states + input_tensor
return hidden_states
class FlavaLayer(nn.Module):
"""这对应于timm实现中的Block类。"""
def __init__(self, config: FlavaPossibleConfigs) -> None:
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = FlavaAttention(config)
self.intermediate = FlavaIntermediate(config)
self.output = FlavaOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states),
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:]
hidden_states = attention_output + hidden_states
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = self.output(layer_output, hidden_states)
outputs = (layer_output,) + outputs
return outputs
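# For illustration, a tiny sketch of the pre-norm ordering used by FlavaLayer: LayerNorm runs before
# each sub-block and the residual re-uses the un-normalized input, which is why FlavaSelfOutput adds
# no residual itself. Identity lambdas stand in for the real attention and MLP modules.
import torch
def pre_norm_block(hidden_states, attention, mlp, ln_before, ln_after):
    hidden_states = attention(ln_before(hidden_states)) + hidden_states  # first residual
    return mlp(ln_after(hidden_states)) + hidden_states  # second residual
example = pre_norm_block(
    torch.randn(2, 5, 8), lambda h: h, lambda h: h, torch.nn.LayerNorm(8), torch.nn.LayerNorm(8)
)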
class FlavaEncoder(nn.Module):
def __init__(self, config: FlavaConfig) -> None:
super().__init__()
self.config = config
self.layer = nn.ModuleList([FlavaLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
)
class FlavaPooler(nn.Module):
def __init__(self, config: FlavaPossibleConfigs):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states: torch.Tensor):
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
FLAVA_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`{config}`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
FLAVA_INPUTS_DOCSTRING_COMMON = r"""
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
FLAVA_IMAGE_INPUTS_DOCSTRING_BASE = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values of the input images. This tensor represents the image data with dimensions:
- `batch_size`: Number of images in the batch.
- `num_channels`: Number of color channels (e.g., 3 for RGB images).
- `height`: Height of each image.
- `width`: Width of each image.
Pixel values can be obtained using an `AutoImageProcessor`. Refer to the documentation
of [`FlavaImageProcessor.__call__`] for more details.
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
Boolean tensor indicating masked positions within each image. Each element:
- `1`: Indicates the corresponding image patch is masked.
- `0`: Indicates the corresponding image patch is not masked.
interpolate_pos_encoding (`bool`, *optional*):
Optional flag indicating whether to interpolate pre-trained position encodings. If set to `True`,
the model will interpolate existing position encodings; if `False` or not provided, no interpolation
will be performed.
"""
FLAVA_IMAGE_INPUTS_DOCSTRING = FLAVA_IMAGE_INPUTS_DOCSTRING_BASE + FLAVA_INPUTS_DOCSTRING_COMMON
FLAVA_TEXT_INPUTS_DOCSTRING_BASE = r"""
Args:
input_ids (`torch.LongTensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
IDs?](../glossary#input-ids)
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
1]`:
- 0 corresponds to a *sentence A* token,
- 1 corresponds to a *sentence B* token.
[What are token type IDs?](../glossary#token-type-ids)
"""
FLAVA_TEXT_INPUTS_DOCSTRING = FLAVA_TEXT_INPUTS_DOCSTRING_BASE + FLAVA_INPUTS_DOCSTRING_COMMON
FLAVA_MULTIMODAL_INPUTS_DOCSTRING = (
r"""
Args:
hidden_states (`torch.FloatTensor` of shape `(batch_size, image_num_patches + text_seq_len, hidden_size)`):
The concatenated hidden states of unimodal encoders.
"""
+ FLAVA_INPUTS_DOCSTRING_COMMON
)
FLAVA_MODEL_INPUTS_DOCSTRING_BASE = r"""
Args:
skip_multimodal_encoder (*bool*, *optional*):
Skip any calculations for multimodal encoder. Useful if multimodal encoding is not going to be used.
"""
FLAVA_MODEL_INPUTS_DOCSTRING = (
FLAVA_IMAGE_INPUTS_DOCSTRING_BASE
+ FLAVA_TEXT_INPUTS_DOCSTRING_BASE
+ FLAVA_INPUTS_DOCSTRING_COMMON
+ FLAVA_MODEL_INPUTS_DOCSTRING_BASE
)
FLAVA_PRETRAINING_INPUTS_DOCSTRING = (
r"""
Args:
input_ids_masked (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary. These are the masked versions of the original tokens
to be used with MLM. Indices can be obtained using [`AutoTokenizer`] along with
[`DataCollatorForMaskedLanguageModeling`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
"""
+ FLAVA_TEXT_INPUTS_DOCSTRING_BASE
+ FLAVA_IMAGE_INPUTS_DOCSTRING_BASE
    + FLAVA_INPUTS_DOCSTRING_COMMON
)
FLAVA_PRETRAINING_START_DOCSTRING_EXTRA = r"""
Parameters:
        image_codebook ([`nn.Module`]): If passed, the image codebook will be set to this. Otherwise, it will
            be initialized using the image_codebook_config defined in the config.
"""
class FlavaPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = FlavaConfig
base_model_prefix = "flava"
supports_gradient_checkpointing = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Initialize linear and convolution weights from a normal distribution.
            # Slightly different from the TF version, which uses truncated_normal for initialization;
            # cf. https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                # Zero-initialize the bias if one exists
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            # Initialize embedding weights from a normal distribution
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # Zero the row corresponding to padding_idx, if one is set
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm: zero bias, ones for the weight
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
@add_start_docstrings(
"The bare FLAVA Image Model transformer outputting raw hidden-states without any specific head on top.",
FLAVA_START_DOCSTRING.format(config="FlavaImageConfig"),
)
class FlavaImageModel(FlavaPreTrainedModel):
config_class = FlavaImageConfig
# This override allows us to load FlavaImageModel from FlavaModel/FlavaForPreTraining checkpoints.
base_model_prefix = "flava.image_model"
main_input_name = "pixel_values"
def __init__(self, config: FlavaImageConfig, add_pooling_layer: bool = True):
super().__init__(config)
self.config = config
        # Build the sub-modules of the model
        self.embeddings = FlavaImageEmbeddings(config)
        self.encoder = FlavaEncoder(config)
        # Final LayerNorm and optional pooler
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = FlavaPooler(config) if add_pooling_layer else None
        # Weight initialization and final post-processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.embeddings.patch_embeddings
def set_input_embeddings(self, value: nn.Module):
self.embeddings.patch_embeddings = value
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
            # Prune the attention heads in the given encoder layer
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(FLAVA_IMAGE_INPUTS_DOCSTRING.format("batch_size, image_num_patches"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPooling,
config_class=_CONFIG_CLASS_FOR_IMAGE_MODEL_DOC,
modality="vision",
expected_output=_EXPECTED_IMAGE_OUTPUT_SHAPE,
)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: Optional[bool] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, BaseModelOutputWithPooling]:
        # Fall back to the config defaults when the output flags are not given
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        # Embed the pixel values (optionally masking patches and interpolating position encodings)
        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )
        # Run the transformer encoder on the patch embeddings
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        # Pool the sequence output if a pooler was added
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        if not return_dict:
            # Return a plain tuple when a ModelOutput is not requested
            return (sequence_output, pooled_output) + encoder_outputs[1:]
        # Otherwise wrap the outputs in a BaseModelOutputWithPooling
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
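# For reference, a minimal usage sketch of the image tower on its own; it assumes the
# "facebook/flava-full" checkpoint is reachable and mirrors the library's doc examples.
from PIL import Image
import requests
from transformers import AutoImageProcessor, FlavaImageModel
image_processor = AutoImageProcessor.from_pretrained("facebook/flava-full")
image_model = FlavaImageModel.from_pretrained("facebook/flava-full")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = image_processor(images=image, return_tensors="pt")
outputs = image_model(**inputs)
# last_hidden_state: (batch, 1 + num_patches, hidden); pooler_output: (batch, hidden)
print(outputs.last_hidden_state.shape, outputs.pooler_output.shape)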
# The decorator below attaches a docstring describing a bare FLAVA text transformer that outputs
# raw hidden states without any task-specific head, formatting FLAVA_START_DOCSTRING with FlavaTextConfig.
@add_start_docstrings(
"The bare FLAVA Text Model transformer outputting raw hidden-states without any specific head on top.",
FLAVA_START_DOCSTRING.format(config="FlavaTextConfig"),
)
class FlavaTextModel(FlavaPreTrainedModel):
    config_class = FlavaTextConfig
    # This override allows us to load FlavaTextModel from FlavaModel/FlavaForPreTraining checkpoints.
    base_model_prefix = "flava.text_model"
    def __init__(self, config: FlavaTextConfig, add_pooling_layer: bool = True):
        super().__init__(config)
        self.config = config
        # Text embeddings and transformer encoder
        self.embeddings = FlavaTextEmbeddings(config)
        self.encoder = FlavaEncoder(config)
        # Final LayerNorm and optional pooler
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = FlavaPooler(config) if add_pooling_layer else None
        # Weight initialization and final post-processing
        self.post_init()
    def get_input_embeddings(self) -> PatchEmbeddings:
        return self.embeddings.word_embeddings
    def set_input_embeddings(self, value: nn.Module):
        self.embeddings.word_embeddings = value
    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)
    # The decorators below attach the forward docstring and a code-sample docstring (checkpoint, output type, config).
    @add_start_docstrings_to_model_forward(FLAVA_TEXT_INPUTS_DOCSTRING.format("batch_size, text_seq_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_CLASS_FOR_TEXT_MODEL_DOC,
    )
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, BaseModelOutputWithPooling]:
        # Fall back to the config defaults when the output flags are not given
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if input_ids is None:
            raise ValueError("You have to specify input_ids")
        input_shape = input_ids.size()
        # Default to an all-ones attention mask with the same shape as input_ids
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=input_ids.device)
        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        # Broadcast the attention mask to the shape expected by the attention layers
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
            attention_mask, input_shape, input_ids.device
        )
        # Token, token-type and position embeddings
        embedding_output = self.embeddings(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )
        # Run the transformer encoder on the embedded tokens
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        # Pool the sequence output if a pooler was added
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        if not return_dict:
            # Return a plain tuple when a ModelOutput is not requested
            return (sequence_output, pooled_output) + encoder_outputs[1:]
        # Otherwise wrap the outputs in a BaseModelOutputWithPooling
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
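# Likewise, a minimal usage sketch of the text tower, again assuming the "facebook/flava-full"
# checkpoint; the tokenizer is the BERT-style one bundled with the FLAVA processor.
from transformers import AutoTokenizer, FlavaTextModel
tokenizer = AutoTokenizer.from_pretrained("facebook/flava-full")
text_model = FlavaTextModel.from_pretrained("facebook/flava-full")
text_inputs = tokenizer(["a photo of a cat"], return_tensors="pt", padding=True)
text_outputs = text_model(**text_inputs)
# last_hidden_state: (batch, seq_len, hidden); pooler_output: (batch, hidden)
print(text_outputs.last_hidden_state.shape, text_outputs.pooler_output.shape)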
# The decorator below attaches the docstring for the bare FLAVA multimodal transformer.
@add_start_docstrings(
"The bare FLAVA Multimodal Model transformer outputting raw hidden-states without any specific head on top.",
FLAVA_START_DOCSTRING.format(config="FlavaMultimodalConfig"),
)
class FlavaMultimodalModel(FlavaPreTrainedModel):
    config_class = FlavaMultimodalConfig
    # This override allows us to load FlavaMultimodalModel from FlavaModel/FlavaForPreTraining checkpoints.
    base_model_prefix = "flava.multimodal_model"
    main_input_name = "hidden_states"
    def __init__(self, config: FlavaMultimodalConfig, add_pooling_layer=True):
        super().__init__(config)
        self.config = config
        # Optionally prepend a learnable CLS token to the concatenated unimodal hidden states
        self.use_cls_token = self.config.use_cls_token
        if self.use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.encoder = FlavaEncoder(config)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = FlavaPooler(config) if add_pooling_layer else None
        # Weight initialization and final post-processing
        self.post_init()
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
        # Prune the requested heads layer by layer via each layer's attention module
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(
FLAVA_MULTIMODAL_INPUTS_DOCSTRING.format("batch_size, image_num_patches + text_seq_len")
)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPooling,
config_class=_CONFIG_CLASS_FOR_MULTIMODAL_MODEL_DOC,
)
    # Forward pass: takes the concatenated hidden states plus optional masks and returns the model outputs.
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, BaseModelOutputWithPooling]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size, seq_length, _ = hidden_states.size()
        # If a CLS token is used, expand it over the batch and prepend it to the hidden states
        if self.use_cls_token:
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)
            hidden_states = torch.cat((cls_tokens, hidden_states), dim=1)
            seq_length += 1
        # Default to an all-ones attention mask if none was provided
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length), device=hidden_states.device)
        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        # Broadcast the attention mask to the shape expected by the attention layers
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
            attention_mask, (batch_size, seq_length), hidden_states.device
        )
        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        # Pool the sequence output if a pooler was added
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]
        # Otherwise wrap the outputs in a BaseModelOutputWithPooling
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
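# For illustration, how the multimodal input grows by one position when a CLS token is prepended
# to the concatenated image/text hidden states; the sizes below are made up.
import torch
batch, num_patches, text_len, hidden = 2, 4, 3, 8
image_states = torch.randn(batch, num_patches, hidden)
text_states = torch.randn(batch, text_len, hidden)
mm_states = torch.cat([image_states, text_states], dim=1)  # (2, 7, 8) concatenated sequence
cls_token = torch.zeros(1, 1, hidden).expand(batch, -1, -1)
mm_states = torch.cat([cls_token, mm_states], dim=1)  # (2, 8, 8) with the CLS slot in front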
# The decorator below attaches the docstring for the bare FLAVA model (image, text and multimodal towers),
# formatting FLAVA_START_DOCSTRING with FlavaConfig.
@add_start_docstrings(
"The bare FLAVA Model transformer outputting raw hidden-states without any specific head on top.",
FLAVA_START_DOCSTRING.format(config="FlavaConfig"),
)
class FlavaModel(FlavaPreTrainedModel):
    config_class = FlavaConfig
    def __init__(self, config: FlavaConfig):
        super().__init__(config)
        # Validate that config.text_config is a FlavaTextConfig, otherwise raise a ValueError
if not isinstance(config.text_config, FlavaTextConfig):
raise ValueError(
"config.text_config is expected to be of type FlavaTextConfig but is of type"
f" {type(config.text_config)}."
)
        # Validate that config.image_config is a FlavaImageConfig, otherwise raise a ValueError
if not isinstance(config.image_config, FlavaImageConfig):
raise ValueError(
"config.image_config is expected to be of type FlavaImageConfig but is of type"
f" {type(config.image_config)}."
)
        # Validate that config.multimodal_config is a FlavaMultimodalConfig, otherwise raise a ValueError
if not isinstance(config.multimodal_config, FlavaMultimodalConfig):
raise ValueError(
"config.multimodal_config is expected to be of type FlavaMultimodalConfig but "
+ f"is of type {type(config.multimodal_config)}."
)
        text_config = config.text_config
        image_config = config.image_config
        multimodal_config = config.multimodal_config
        # Projection dimension plus the hidden sizes of the text, image and multimodal towers
        self.projection_dim = config.projection_dim
        self.text_hidden_size = text_config.hidden_size
        self.image_hidden_size = image_config.hidden_size
        self.mm_hidden_size = multimodal_config.hidden_size
        # Unimodal towers and the multimodal encoder
        self.text_model = FlavaTextModel(text_config)
        self.image_model = FlavaImageModel(image_config)
        self.multimodal_model = FlavaMultimodalModel(multimodal_config)
        # Projections into the shared contrastive space, plus the learnable logit scale
        self.image_projection = nn.Linear(self.image_hidden_size, self.projection_dim)
        self.text_projection = nn.Linear(self.text_hidden_size, self.projection_dim)
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
        # Projections from the unimodal hidden sizes into the multimodal encoder's hidden size
        self.image_to_mm_projection = nn.Linear(self.image_hidden_size, self.mm_hidden_size)
        self.text_to_mm_projection = nn.Linear(self.text_hidden_size, self.mm_hidden_size)
        # Initialize weights and apply final processing
        self.post_init()
    # The decorator below attaches the forward docstring for get_text_features.
@add_start_docstrings_to_model_forward(FLAVA_TEXT_INPUTS_DOCSTRING.format("batch_size, text_seq_length"))
def get_text_features(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> torch.FloatTensor:
r"""
Returns:
text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
applying the projection layer to the pooled output of [`FlavaTextModel`].
Examples:
```
>>> from transformers import AutoProcessor, FlavaModel
>>> model = FlavaModel.from_pretrained("{0}")
>>> processor = AutoProcessor.from_pretrained("{0}")
>>> inputs = processor(
... text=["a photo of a cat", "a photo of a dog"], max_length=77, padding="max_length", return_tensors="pt"
... )
>>> text_features = model.get_text_features(**inputs)
```""".format(_CHECKPOINT_FOR_DOC)
        # Run the text tower on the given inputs
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Project the last hidden state into the shared contrastive space
        pooled_output = text_outputs[0]  # last_hidden_state
        text_features = self.text_projection(pooled_output)
return text_features
@add_start_docstrings_to_model_forward(FLAVA_IMAGE_INPUTS_DOCSTRING.format("batch_size, image_num_patches"))
def get_image_features(
self,
pixel_values: Optional[torch.Tensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: Optional[bool] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> torch.FloatTensor:
r"""
Returns:
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
applying the projection layer to the pooled output of [`FlavaImageModel`].
Examples:
```
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, FlavaModel
>>> model = FlavaModel.from_pretrained("{0}")
>>> processor = AutoProcessor.from_pretrained("{0}")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> image_features = model.get_image_features(**inputs)
```""".format(_CHECKPOINT_FOR_DOC)
        # Run the image tower on the given inputs
        image_outputs = self.image_model(
            pixel_values=pixel_values,
            bool_masked_pos=bool_masked_pos,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )
        # Project the last hidden state into the shared contrastive space
        pooled_output = image_outputs[0]  # last_hidden_state
        image_features = self.image_projection(pooled_output)
return image_features
@add_start_docstrings_to_model_forward(
FLAVA_MODEL_INPUTS_DOCSTRING.format("batch_size, image_num_patches + text_seq_len")
)
@replace_return_docstrings(output_type=FlavaModelOutput, config_class=FlavaConfig)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
bool_masked_pos: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
image_attention_mask: Optional[torch.Tensor] = None,
skip_multimodal_encoder: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: bool = True,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, FlavaModelOutput]:
        ...
class FlavaImageCodebookResPath(nn.Module):
def __init__(self, in_size: int, out_size: int, **kwargs):
super().__init__()
hid_size = out_size // 4
        # Four ReLU + conv stages, kept in an ordered dict so they run in insertion order
        path = OrderedDict()
        path["relu_1"] = nn.ReLU()
        path["conv_1"] = nn.Conv2d(in_size, hid_size, kernel_size=3, padding=1)
        path["relu_2"] = nn.ReLU()
        path["conv_2"] = nn.Conv2d(hid_size, hid_size, kernel_size=3, padding=1)
        path["relu_3"] = nn.ReLU()
        path["conv_3"] = nn.Conv2d(hid_size, hid_size, kernel_size=3, padding=1)
        path["relu_4"] = nn.ReLU()
        path["conv_4"] = nn.Conv2d(hid_size, out_size, kernel_size=1, padding=0)  # 1x1 conv to the output size
        # Wrap the ordered dict in a sequential container
        self.path = nn.Sequential(path)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.path(x)
class FlavaImageCodebookBlock(nn.Module):
def __init__(self, in_size: int, out_size: int, num_layers: int, **kwargs):
super().__init__()
        # Post-gain scales the residual path's contribution (smaller for deeper stacks)
        self.post_gain = 1 / (num_layers**2)
        # Use a 1x1 conv on the identity path when the channel counts differ, else a plain identity
        if in_size != out_size:
            self.id_path = nn.Conv2d(in_size, out_size, kernel_size=1, padding=0)
        else:
            self.id_path = nn.Identity()
        # Residual path of the FLAVA image codebook block
        self.res_path = FlavaImageCodebookResPath(in_size, out_size)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Identity path plus the scaled residual path
        return self.id_path(x) + self.post_gain * self.res_path(x)
class FlavaImageCodebookLayerGroup(nn.Module):
def __init__(self, num_blocks: int, num_layers: int, in_size: int, out_size: int, use_pool: bool = True):
super().__init__()
blocks = OrderedDict()
        # Stack several codebook blocks; only the first one changes the channel count
        for i in range(num_blocks):
            if i == 0:
                blocks[f"block_{i+1}"] = FlavaImageCodebookBlock(in_size, out_size, num_layers)
            else:
                blocks[f"block_{i+1}"] = FlavaImageCodebookBlock(out_size, out_size, num_layers)
        # Optionally halve the spatial resolution at the end of the group
        if use_pool:
            blocks["pool"] = nn.MaxPool2d(kernel_size=2)
        self.group = nn.Sequential(blocks)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.group(x)
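# For illustration, how the four layer groups change the feature map: each of the first three groups
# ends in a 2x2 max pool (halving height and width) while the channel count grows from 1x to 8x the
# hidden size. The numbers below are made up, not the real codebook config values.
hidden_size = 16
height = width = 64
for i, out_channels in enumerate([1 * hidden_size, 2 * hidden_size, 4 * hidden_size, 8 * hidden_size], start=1):
    if i < 4:  # group_4 is built with use_pool=False
        height //= 2
        width //= 2
    print(f"after group_{i}: ({out_channels}, {height}, {width})")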
@add_start_docstrings(
    """
    The FLAVA's image codebook model inspired from DALL-E's original encoder. Outputs raw hidden states and can be
    used to generate image tokens for an image based on DALL-E's vocab. Used to generate labels for MIM. Use
    `get_codebook_indices` to get image tokens for an image.
    """,
FLAVA_START_DOCSTRING.format(config="FlavaImageCodebookConfig"),
)
class FlavaImageCodebook(FlavaPreTrainedModel):
base_model_prefix = ""
config_class = FlavaImageCodebookConfig
main_input_name = "pixel_values"
supports_gradient_checkpointing = False
def __init__(
self,
config: FlavaImageCodebookConfig,
**kwargs: Any,
):
        super().__init__(config)
        self.config = config
        self.num_groups = config.num_groups
        self.input_channels = config.input_channels
        self.num_blocks_per_group = config.num_blocks_per_group
        self.hidden_size = config.hidden_size
        self.vocab_size = config.vocab_size
        num_layers = self.num_groups * self.num_blocks_per_group
        # Final ReLU + 1x1 conv mapping 8 * hidden_size channels to the codebook vocabulary size
        output_blocks = OrderedDict()
        output_blocks["relu"] = nn.ReLU()
        output_blocks["conv"] = nn.Conv2d(8 * self.hidden_size, self.vocab_size, kernel_size=1, padding=0)
        # Input conv followed by four layer groups that grow the channel count (the last one skips pooling)
        blocks = OrderedDict()
        blocks["input"] = nn.Conv2d(self.input_channels, 1 * self.hidden_size, kernel_size=7, padding=3)
        blocks["group_1"] = FlavaImageCodebookLayerGroup(
            self.num_blocks_per_group, num_layers, 1 * self.hidden_size, 1 * self.hidden_size
        )
        blocks["group_2"] = FlavaImageCodebookLayerGroup(
            self.num_blocks_per_group, num_layers, 1 * self.hidden_size, 2 * self.hidden_size
        )
        blocks["group_3"] = FlavaImageCodebookLayerGroup(
            self.num_blocks_per_group, num_layers, 2 * self.hidden_size, 4 * self.hidden_size
        )
        blocks["group_4"] = FlavaImageCodebookLayerGroup(
            self.num_blocks_per_group, num_layers, 4 * self.hidden_size, 8 * self.hidden_size, use_pool=False
        )
        blocks["output"] = nn.Sequential(output_blocks)
        self.blocks = nn.Sequential(blocks)
        self.post_init()
        # Optionally freeze all parameters if the config asks for it
        if self.config.freeze:
            for param in self.parameters():
                param.requires_grad = False
def get_codebook_indices(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Codebook pixel values can be obtained using [`AutoImageProcessor`] by passing
`return_codebook_pixels=True`. See [`FlavaImageProcessor.__call__`] for details.
Examples:
```
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoImageProcessor, FlavaImageCodebook
>>> model = FlavaImageCodebook.from_pretrained("{0}")
>>> image_processor = AutoImageProcessor.from_pretrained("{0}")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = image_processor([image], return_codebook_pixels=True, return_tensors="pt")
>>> inputs = dict(pixel_values=inputs.codebook_pixel_values)
>>> outputs = model.get_codebook_indices(**inputs)
```
""".format(_CHECKPOINT_FOR_CODEBOOK_DOC)
        z_logits = self.blocks(pixel_values)  # run the conv stack to get per-location codebook logits
        return torch.argmax(z_logits, axis=1)  # pick the most likely codebook entry at each spatial location
    def get_codebook_probs(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # Same conv stack, but return a probability distribution over the codebook instead of hard indices
        z_logits = self.blocks(pixel_values)
        return nn.Softmax(dim=1)(z_logits)
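# For illustration, a sketch of how codebook indices can be turned into MIM labels: positions that are
# not masked get an ignore value so the loss only covers masked patches. The ignore index (-100) and
# the shapes are made up here; FLAVA reads the actual ignore index from its config.
import torch
batch, num_patches = 2, 6
codebook_indices = torch.randint(0, 8192, (batch, num_patches))  # flattened image tokens
bool_masked_pos = torch.zeros(batch, num_patches, dtype=torch.bool)
bool_masked_pos[:, ::3] = True  # pretend every third patch is masked
mim_labels = codebook_indices.clone()
mim_labels[~bool_masked_pos] = -100  # un-masked patches are ignored by the loss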
class FlavaPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Dense layer keeping the hidden size, followed by an activation and LayerNorm
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states
class FlavaMaskedPredictionHead(nn.Module):
    def __init__(self, config, weight=None):
        super().__init__()
        self.config = config
        self.transform = FlavaPredictionHeadTransform(config)
        # Project the transformed hidden states onto the vocabulary; the bias lives in a separate parameter
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        if weight is not None:
            self.decoder.weight = weight
        # Tie the decoder bias to self.bias so it is resized correctly with resize_token_embeddings
        self.decoder.bias = self.bias
    def forward(self, x):
        x = self.transform(x)
        x = self.decoder(x)
        return x
class FlavaITMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Pool the multimodal sequence, then classify into 2 classes for image-text matching (ITM)
        self.pooler = FlavaPooler(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)
    def forward(self, x):
        x = self.pooler(x)
        x = self.seq_relationship(x)
        return x
class FlavaGlobalContrastiveHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Whether gradients should flow through the all-gathered embeddings of other workers
        self.global_backprop_contrastive = config.global_backprop_contrastive
    def forward(self, image_embeddings, text_embeddings, logit_scale):
        # The learnable logit scale is stored in log space; exponentiate to get the temperature
        temperature = torch.exp(logit_scale)
        if not torch.distributed.is_available() or not torch.distributed.is_initialized():
            # Single-process case: the positives simply sit on the diagonal
            labels = torch.arange(image_embeddings.size(0), device=image_embeddings.device)
            image_embeddings_all = [image_embeddings]
            text_embeddings_all = [text_embeddings]
        else:
            local_batch_size = image_embeddings.size(0)
            world_size = torch.distributed.get_world_size()
            if self.global_backprop_contrastive:
                # Differentiable all-gather: gradients flow back to every worker's embeddings
                image_embeddings_all = torch.distributed.nn.functional.all_gather(image_embeddings)
                text_embeddings_all = torch.distributed.nn.functional.all_gather(text_embeddings)
            else:
                # Plain all-gather into pre-allocated buffers; no gradient through remote embeddings
                image_embeddings_all = [torch.zeros_like(text_embeddings) for _ in range(world_size)]
                text_embeddings_all = [torch.zeros_like(image_embeddings) for _ in range(world_size)]
                torch.distributed.all_gather(image_embeddings_all, image_embeddings)
                torch.distributed.all_gather(text_embeddings_all, text_embeddings)
            # Offset the diagonal labels by this worker's rank so they index into the gathered batch
            labels = local_batch_size * torch.distributed.get_rank() + torch.arange(
                local_batch_size, device=image_embeddings.device
            )
            image_embeddings_all = torch.cat(image_embeddings_all)
            text_embeddings_all = torch.cat(text_embeddings_all)
        # Similarity of the local embeddings against all gathered embeddings, scaled by the temperature
        logits_per_image = torch.matmul(image_embeddings, text_embeddings_all.transpose(0, 1)) * temperature
        logits_per_text = torch.matmul(text_embeddings, image_embeddings_all.transpose(0, 1)) * temperature
        return logits_per_image, logits_per_text, labels
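# The logits and labels returned above are meant for a symmetric cross-entropy, as in CLIP-style
# contrastive training. A single-process sketch with made-up sizes (this is not the pretraining
# head itself, just the objective its outputs are shaped for):
import torch
import torch.nn.functional as F
batch, dim = 4, 8
image_embeddings = F.normalize(torch.randn(batch, dim), dim=-1)
text_embeddings = F.normalize(torch.randn(batch, dim), dim=-1)
temperature = torch.exp(torch.tensor(2.6592))  # exp(logit_scale); 2.6592 is just an example value
logits_per_image = image_embeddings @ text_embeddings.t() * temperature
logits_per_text = text_embeddings @ image_embeddings.t() * temperature
labels = torch.arange(batch)  # matching image/text pairs sit on the diagonal
loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2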
# The decorator below attaches the docstring for the FLAVA pretraining model, which outputs losses,
# embeddings, logits and transformer outputs.
@add_start_docstrings(
"""
The FLAVA model for pretraining which outputs losses, embeddings, logits and transformer outputs.
""",
FLAVA_START_DOCSTRING.format(config="FlavaConfig") + FLAVA_PRETRAINING_START_DOCSTRING_EXTRA,
)
class FlavaForPreTraining(FlavaPreTrainedModel):
    # These keys are linked to the prediction heads' decoder .bias parameters (tied weights)
_tied_weights_keys = [
"mmm_text_head.decoder.bias",
"mmm_image_head.decoder.bias",
"mlm_head.decoder.bias",
"mim_head.decoder.bias",
]
def __init__(self, config: FlavaConfig, image_codebook: Optional[nn.Module] = None):
        super().__init__(config)
        # The full FLAVA model (image, text and multimodal towers)
        self.flava = FlavaModel(config)
        # Image codebook: use the one passed in, or build it from the config if requested
        self.image_codebook = image_codebook
        if self.image_codebook is None and config.init_codebook:
            self.image_codebook = FlavaImageCodebook(config.image_codebook_config)
        # Build the masked-prediction heads from the image/text encoder configs so they use the right vocabularies
        self.mim_head = FlavaMaskedPredictionHead(config.image_config)
        self.mlm_head = FlavaMaskedPredictionHead(config.text_config)
        self.itm_head = FlavaITMHead(config)
        self.mmm_image_head = FlavaMaskedPredictionHead(config.image_config)
        self.mmm_text_head = FlavaMaskedPredictionHead(config.text_config)
        self.global_contrastive_head = FlavaGlobalContrastiveHead(config)
        # Vocabulary sizes for the image and text streams
        self.image_vocab_size = config.image_config.vocab_size
        self.text_vocab_size = config.text_config.vocab_size
        # Loss weights for MLM, MIM and the global contrastive objective
        self.mlm_weight = config.mlm_weight
        self.mim_weight = config.mim_weight
        self.global_contrastive_weight = config.global_contrastive_weight
        # Cross-entropy ignore index and ITM loss weight
        self.ce_ignore_index = config.ce_ignore_index
        self.itm_weight = config.itm_weight
        # Loss weights for the multimodal masked (MMM) image and text objectives
        self.mmm_image_weight = config.mmm_image_weight
        self.mmm_text_weight = config.mmm_text_weight
        # Whether to skip running the multimodal encoder on the un-masked inputs
        self.skip_unmasked_multimodal_encoder = config.skip_unmasked_multimodal_encoder
        # Weight initialization and final post-processing
        self.post_init()
def _resize_to_2d(self, x: torch.Tensor):
        # Flatten anything beyond the batch dimension so labels/masks become 2D
        if x.dim() > 2:
            x = x.view(x.size(0), -1)
return x
    @add_start_docstrings_to_model_forward(
        FLAVA_PRETRAINING_INPUTS_DOCSTRING.format("batch_size, text_seq_len", "batch_size, image_num_patches")
    )
    @replace_return_docstrings(output_type=FlavaForPreTrainingOutput, config_class=FlavaConfig)
    def forward(
        self,
        # Token IDs of the original (un-masked) text input
        input_ids: Optional[torch.LongTensor] = None,
        # Masked token IDs used for the MLM objective
        input_ids_masked: Optional[torch.LongTensor] = None,
        # Pixel values of the input images
        pixel_values: Optional[torch.FloatTensor] = None,
        # Pixel values preprocessed for the image codebook
        codebook_pixel_values: Optional[torch.FloatTensor] = None,
        # Attention mask marking padding tokens
        attention_mask: Optional[torch.Tensor] = None,
        # Token type IDs (segment IDs), as in BERT-style models
        token_type_ids: Optional[torch.Tensor] = None,
        # Boolean mask marking which image patches are masked
        bool_masked_pos: Optional[torch.Tensor] = None,
        # Position IDs of the text tokens
        position_ids: Optional[torch.LongTensor] = None,
        # Attention mask for the image patches
        image_attention_mask: Optional[torch.Tensor] = None,
        # Whether to skip the multimodal encoder on the un-masked inputs
        skip_unmasked_multimodal_encoder: bool = None,
        # Labels for the MLM, MIM and ITM objectives
        mlm_labels: Optional[torch.Tensor] = None,
        mim_labels: Optional[torch.Tensor] = None,
        itm_labels: Optional[torch.Tensor] = None,
        # Output controls
        output_attentions: Optional[bool] = None,
        output_hidden_states: bool = True,
        return_dict: Optional[bool] = None,
        return_loss: Optional[bool] = None,
    ) -> Union[Tuple[Any], FlavaForPreTrainingOutput]:
        ...
.\models\flava\processing_flava.py
"""
Image/Text processor class for FLAVA
"""
import warnings
from typing import List, Optional, Union
from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType
class FlavaProcessor(ProcessorMixin):
r"""
Constructs a FLAVA processor which wraps a FLAVA image processor and a FLAVA tokenizer into a single processor.
[`FlavaProcessor`] offers all the functionalities of [`FlavaImageProcessor`] and [`BertTokenizerFast`]. See the
[`~FlavaProcessor.__call__`] and [`~FlavaProcessor.decode`] for more information.
Args:
image_processor ([`FlavaImageProcessor`], *optional*): The image processor is a required input.
tokenizer ([`BertTokenizerFast`], *optional*): The tokenizer is a required input.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "FlavaImageProcessor"
tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
def __init__(self, image_processor=None, tokenizer=None, **kwargs):
feature_extractor = None
if "feature_extractor" in kwargs:
warnings.warn(
"The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
" instead.",
FutureWarning,
)
feature_extractor = kwargs.pop("feature_extractor")
image_processor = image_processor if image_processor is not None else feature_extractor
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
if tokenizer is None:
raise ValueError("You need to specify a `tokenizer`.")
super().__init__(image_processor, tokenizer)
self.current_processor = self.image_processor
def __call__(
self,
images: Optional[ImageInput] = None,
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_image_mask: Optional[bool] = None,
return_codebook_pixels: Optional[bool] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
):
"""
This method uses [`FlavaImageProcessor.__call__`] method to prepare image(s) for the model, and
[`BertTokenizerFast.__call__`] to prepare text for the model.
Please refer to the docstring of the above two methods for more information.
"""
if text is None and images is None:
raise ValueError("You have to specify either text or images. Both cannot be none.")
if text is not None:
encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
**kwargs,
)
if images is not None:
image_features = self.image_processor(
images,
return_image_mask=return_image_mask,
return_codebook_pixels=return_codebook_pixels,
return_tensors=return_tensors,
**kwargs,
)
if text is not None and images is not None:
encoding.update(image_features)
return encoding
elif text is not None:
return encoding
else:
return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
def batch_decode(self, *args, **kwargs):
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
@property
def feature_extractor_class(self):
warnings.warn(
"`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
FutureWarning,
)
return self.image_processor_class
@property
def feature_extractor(self):
warnings.warn(
"`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.",
FutureWarning,
)
return self.image_processor
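# For reference, a minimal usage sketch of the processor, assuming the "facebook/flava-full"
# checkpoint; the extra flags request the tensors needed for the masked-image-modeling objective.
from PIL import Image
import requests
from transformers import FlavaProcessor
processor = FlavaProcessor.from_pretrained("facebook/flava-full")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(
    text=["a photo of two cats"],
    images=image,
    return_codebook_pixels=True,  # pixel tensor for the image codebook
    return_image_mask=True,  # bool_masked_pos for masked image modeling
    return_tensors="pt",
    padding=True,
)
print(sorted(inputs.keys()))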
.\models\flava\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {
"configuration_flava": [
"FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP",
"FlavaConfig",
"FlavaImageCodebookConfig",
"FlavaImageConfig",
"FlavaMultimodalConfig",
"FlavaTextConfig",
],
}
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["feature_extraction_flava"] = ["FlavaFeatureExtractor"]
_import_structure["image_processing_flava"] = ["FlavaImageProcessor"]
_import_structure["processing_flava"] = ["FlavaProcessor"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flava"] = [
"FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST",
"FlavaForPreTraining",
"FlavaImageCodebook",
"FlavaImageModel",
"FlavaModel",
"FlavaMultimodalModel",
"FlavaPreTrainedModel",
"FlavaTextModel",
]
if TYPE_CHECKING:
from .configuration_flava import (
FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP,
FlavaConfig,
FlavaImageCodebookConfig,
FlavaImageConfig,
FlavaMultimodalConfig,
FlavaTextConfig,
)
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .feature_extraction_flava import FlavaFeatureExtractor
from .image_processing_flava import FlavaImageProcessor
from .processing_flava import FlavaProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flava import (
FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST,
FlavaForPreTraining,
FlavaImageCodebook,
FlavaImageModel,
FlavaModel,
FlavaMultimodalModel,
FlavaPreTrainedModel,
FlavaTextModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__
)
.\models\fnet\configuration_fnet.py
"""
FNet model configuration
This module defines the configuration for the FNet model, specifying how to instantiate
different variants of the model architecture based on provided arguments.
"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
FNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google/fnet-base": "https://huggingface.co/google/fnet-base/resolve/main/config.json",
"google/fnet-large": "https://huggingface.co/google/fnet-large/resolve/main/config.json",
}
class FNetConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`FNetModel`]. It is used to instantiate an FNet
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the FNet
[google/fnet-base](https://huggingface.co/google/fnet-base) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
"""
model_type = "fnet"
def __init__(
self,
vocab_size=32000,
hidden_size=768,
num_hidden_layers=12,
intermediate_size=3072,
hidden_act="gelu_new",
hidden_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=4,
initializer_range=0.02,
layer_norm_eps=1e-12,
use_tpu_fourier_optimizations=False,
tpu_short_seq_length=512,
pad_token_id=3,
bos_token_id=1,
eos_token_id=2,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.initializer_range = initializer_range
self.type_vocab_size = type_vocab_size
self.layer_norm_eps = layer_norm_eps
self.use_tpu_fourier_optimizations = use_tpu_fourier_optimizations
self.tpu_short_seq_length = tpu_short_seq_length
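# For illustration, building a default configuration and a smaller variant, then an untrained
# model from it; the smaller sizes are arbitrary.
from transformers import FNetConfig, FNetModel
default_config = FNetConfig()  # defaults mirror google/fnet-base
small_config = FNetConfig(num_hidden_layers=4, hidden_size=256, intermediate_size=1024)
model = FNetModel(small_config)  # randomly initialized, no pretrained weights
print(model.config.num_hidden_layers)  # 4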
.\models\fnet\convert_fnet_original_flax_checkpoint_to_pytorch.py
import argparse
import torch
from flax.training.checkpoints import restore_checkpoint
from transformers import FNetConfig, FNetForPreTraining
from transformers.utils import logging
logging.set_verbosity_info()
def convert_flax_checkpoint_to_pytorch(flax_checkpoint_path, fnet_config_file, save_path):
config = FNetConfig.from_json_file(fnet_config_file)
print(f"Building PyTorch model from configuration: {config}")
fnet_pretraining_model = FNetForPreTraining(config)
checkpoint_dict = restore_checkpoint(flax_checkpoint_path, None)
pretrained_model_params = checkpoint_dict["target"]
state_dict = fnet_pretraining_model.state_dict()
position_ids = state_dict["fnet.embeddings.position_ids"]
new_state_dict = {"fnet.embeddings.position_ids": position_ids}
new_state_dict["fnet.embeddings.word_embeddings.weight"] = torch.tensor(
pretrained_model_params["encoder"]["embedder"]["word"]["embedding"]
)
new_state_dict["fnet.embeddings.position_embeddings.weight"] = torch.tensor(
pretrained_model_params["encoder"]["embedder"]["position"]["embedding"][0]
)
new_state_dict["fnet.embeddings.token_type_embeddings.weight"] = torch.tensor(
pretrained_model_params["encoder"]["embedder"]["type"]["embedding"]
)
new_state_dict["fnet.embeddings.projection.weight"] = torch.tensor(
pretrained_model_params["encoder"]["embedder"]["hidden_mapping_in"]["kernel"]
).T
new_state_dict["fnet.embeddings.projection.bias"] = torch.tensor(
pretrained_model_params["encoder"]["embedder"]["hidden_mapping_in"]["bias"]
)
new_state_dict["fnet.embeddings.LayerNorm.weight"] = torch.tensor(
pretrained_model_params["encoder"]["embedder"]["layer_norm"]["scale"]
)
new_state_dict["fnet.embeddings.LayerNorm.bias"] = torch.tensor(
pretrained_model_params["encoder"]["embedder"]["layer_norm"]["bias"]
)
new_state_dict[f"fnet.encoder.layer.{layer}.fourier.output.LayerNorm.weight"] = torch.tensor(
pretrained_model_params["encoder"][f"encoder_{layer}"]["mixing_layer_norm"]["scale"]
)
new_state_dict[f"fnet.encoder.layer.{layer}.fourier.output.LayerNorm.bias"] = torch.tensor(
pretrained_model_params["encoder"][f"encoder_{layer}"]["mixing_layer_norm"]["bias"]
)
new_state_dict[f"fnet.encoder.layer.{layer}.intermediate.dense.weight"] = torch.tensor(
pretrained_model_params["encoder"][f"feed_forward_{layer}"]["intermediate"]["kernel"]
).T
new_state_dict[f"fnet.encoder.layer.{layer}.intermediate.dense.bias"] = torch.tensor(
pretrained_model_params["encoder"][f"feed_forward_{layer}"]["intermediate"]["bias"]
)
new_state_dict[f"fnet.encoder.layer.{layer}.output.dense.weight"] = torch.tensor(
pretrained_model_params["encoder"][f"feed_forward_{layer}"]["output"]["kernel"]
).T
new_state_dict[f"fnet.encoder.layer.{layer}.output.dense.bias"] = torch.tensor(
pretrained_model_params["encoder"][f"feed_forward_{layer}"]["output"]["bias"]
)
new_state_dict[f"fnet.encoder.layer.{layer}.output.LayerNorm.weight"] = torch.tensor(
pretrained_model_params["encoder"][f"encoder_{layer}"]["output_layer_norm"]["scale"]
)
new_state_dict[f"fnet.encoder.layer.{layer}.output.LayerNorm.bias"] = torch.tensor(
pretrained_model_params["encoder"][f"encoder_{layer}"]["output_layer_norm"]["bias"]
)
new_state_dict["fnet.pooler.dense.weight"] = torch.tensor(pretrained_model_params["encoder"]["pooler"]["kernel"]).T
new_state_dict["fnet.pooler.dense.bias"] = torch.tensor(pretrained_model_params["encoder"]["pooler"]["bias"])
new_state_dict["cls.predictions.transform.dense.weight"] = torch.tensor(
pretrained_model_params["predictions_dense"]["kernel"]
).T
new_state_dict["cls.predictions.transform.dense.bias"] = torch.tensor(
pretrained_model_params["predictions_dense"]["bias"]
)
new_state_dict["cls.predictions.transform.LayerNorm.weight"] = torch.tensor(
pretrained_model_params["predictions_layer_norm"]["scale"]
)
new_state_dict["cls.predictions.transform.LayerNorm.bias"] = torch.tensor(
pretrained_model_params["predictions_layer_norm"]["bias"]
)
new_state_dict["cls.predictions.decoder.weight"] = torch.tensor(
pretrained_model_params["encoder"]["embedder"]["word"]["embedding"]
)
new_state_dict["cls.predictions.decoder.bias"] = torch.tensor(
pretrained_model_params["predictions_output"]["output_bias"]
)
new_state_dict["cls.predictions.bias"] = torch.tensor(pretrained_model_params["predictions_output"]["output_bias"])
new_state_dict["cls.seq_relationship.weight"] = torch.tensor(
pretrained_model_params["classification"]["output_kernel"]
)
new_state_dict["cls.seq_relationship.bias"] = torch.tensor(
pretrained_model_params["classification"]["output_bias"]
)
fnet_pretraining_model.load_state_dict(new_state_dict)
print(f"Saving pretrained model to {save_path}")
fnet_pretraining_model.save_pretrained(save_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--flax_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--fnet_config_file",
default=None,
type=str,
required=True,
help=(
"The config json file corresponding to the pre-trained FNet model. \n"
"This specifies the model architecture."
),
)
parser.add_argument("--save_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args()
convert_flax_checkpoint_to_pytorch(args.flax_checkpoint_path, args.fnet_config_file, args.save_path)
.\models\fnet\modeling_fnet.py
""" PyTorch FNet model."""
import warnings
from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...utils import is_scipy_available
if is_scipy_available():
from scipy import linalg
from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
MaskedLMOutput,
ModelOutput,
MultipleChoiceModelOutput,
NextSentencePredictorOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_fnet import FNetConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "google/fnet-base"
_CONFIG_FOR_DOC = "FNetConfig"
FNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
"google/fnet-base",
"google/fnet-large",
]
def _two_dim_matmul(x, matrix_dim_one, matrix_dim_two):
"""Applies 2D matrix multiplication to 3D input arrays."""
seq_length = x.shape[1]
matrix_dim_one = matrix_dim_one[:seq_length, :seq_length]
x = x.type(torch.complex64)
return torch.einsum("bij,jk,ni->bnk", x, matrix_dim_two, matrix_dim_one)
def two_dim_matmul(x, matrix_dim_one, matrix_dim_two):
"""Applies 2D matrix multiplication to 3D input arrays."""
return _two_dim_matmul(x, matrix_dim_one, matrix_dim_two)
def fftn(x):
"""
Applies n-dimensional Fast Fourier Transform (FFT) to input array.
Args:
x: Input n-dimensional array.
Returns:
n-dimensional Fourier transform of input n-dimensional array.
"""
out = x
for axis in reversed(range(x.ndim)[1:]):
out = torch.fft.fft(out, axis=axis)
return out
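A small standalone check (not part of the original file): applying a 1-D FFT over each non-batch dimension in turn, as `fftn` above does, gives the same result as a single `torch.fft.fftn` over those dimensions, which is the form `FNetBasicFourierTransform` uses on CPU/GPU below.
x = torch.randn(2, 8, 16)  # (batch, seq_len, hidden)
step_by_step = torch.fft.fft(torch.fft.fft(x, dim=2), dim=1)   # innermost axis first, as in fftn()
assert torch.allclose(step_by_step, torch.fft.fftn(x, dim=(1, 2)), atol=1e-4)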
class FNetEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.projection = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if token_type_ids is None:
if hasattr(self, "token_type_ids"):
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.projection(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
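A quick shape sketch reusing the classes defined in this file; it assumes a default `FNetConfig` (hidden_size 768, vocab_size 32000) and only checks that the embedding pipeline preserves the `(batch, seq_len, hidden_size)` layout.
config = FNetConfig()                                   # default FNet hyperparameters
embeddings = FNetEmbeddings(config)
input_ids = torch.randint(0, config.vocab_size, (2, 10))
out = embeddings(input_ids=input_ids)                   # word + token-type + position, LayerNorm, projection, dropout
assert out.shape == (2, 10, config.hidden_size)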
class FNetBasicFourierTransform(nn.Module):
def __init__(self, config):
super().__init__()
self._init_fourier_transform(config)
def _init_fourier_transform(self, config):
if not config.use_tpu_fourier_optimizations:
self.fourier_transform = partial(torch.fft.fftn, dim=(1, 2))
elif config.max_position_embeddings <= 4096:
if is_scipy_available():
self.register_buffer(
"dft_mat_hidden", torch.tensor(linalg.dft(config.hidden_size), dtype=torch.complex64)
)
self.register_buffer(
"dft_mat_seq", torch.tensor(linalg.dft(config.tpu_short_seq_length), dtype=torch.complex64)
)
self.fourier_transform = partial(
two_dim_matmul, matrix_dim_one=self.dft_mat_seq, matrix_dim_two=self.dft_mat_hidden
)
else:
logging.warning(
"SciPy is needed for DFT matrix calculation and is not found. Using TPU optimized fast fourier"
" transform instead."
)
self.fourier_transform = fftn
else:
self.fourier_transform = fftn
def forward(self, hidden_states):
outputs = self.fourier_transform(hidden_states).real
return (outputs,)
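A sketch of why the two branches above are interchangeable, assuming SciPy is installed: multiplying by explicit DFT matrices through `two_dim_matmul` (the TPU path) matches `torch.fft.fftn` over the sequence and hidden dimensions (the CPU/GPU path); only the real part is kept, exactly as in the forward above.
from scipy import linalg

seq_len, hidden = 8, 16
x = torch.randn(2, seq_len, hidden)
dft_seq = torch.tensor(linalg.dft(seq_len), dtype=torch.complex64)
dft_hidden = torch.tensor(linalg.dft(hidden), dtype=torch.complex64)
via_matmul = two_dim_matmul(x, matrix_dim_one=dft_seq, matrix_dim_two=dft_hidden).real
via_fft = torch.fft.fftn(x, dim=(1, 2)).real
assert torch.allclose(via_matmul, via_fft, atol=1e-3)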
class FNetBasicOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, input_tensor):
hidden_states = self.LayerNorm(input_tensor + hidden_states)
return hidden_states
class FNetFourierTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.self = FNetBasicFourierTransform(config)
self.output = FNetBasicOutput(config)
def forward(self, hidden_states):
self_outputs = self.self(hidden_states)
fourier_output = self.output(self_outputs[0], hidden_states)
outputs = (fourier_output,)
return outputs
class FNetIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class FNetOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class FNetLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.fourier = FNetFourierTransform(config)
self.intermediate = FNetIntermediate(config)
self.output = FNetOutput(config)
def forward(self, hidden_states):
self_fourier_outputs = self.fourier(hidden_states)
fourier_output = self_fourier_outputs[0]
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, fourier_output
)
outputs = (layer_output,)
return outputs
def feed_forward_chunk(self, fourier_output):
intermediate_output = self.intermediate(fourier_output)
layer_output = self.output(intermediate_output, fourier_output)
return layer_output
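A sketch of the chunking used in `forward` above: `apply_chunking_to_forward` splits the tensor along `seq_len_dim`, runs `feed_forward_chunk` on each slice, and concatenates the results, trading peak memory for extra calls. The check below reuses the classes from this file with a default `FNetConfig` and puts the layer in eval mode so dropout does not perturb the comparison.
config = FNetConfig()
layer = FNetLayer(config).eval()                        # eval() disables dropout for an exact comparison
hidden_states = torch.randn(2, 8, config.hidden_size)
chunked = apply_chunking_to_forward(layer.feed_forward_chunk, 4, 1, hidden_states)  # chunks of 4 along dim 1
plain = layer.feed_forward_chunk(hidden_states)
assert torch.allclose(chunked, plain, atol=1e-5)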
class FNetEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([FNetLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(self, hidden_states, output_hidden_states=False, return_dict=True):
all_hidden_states = () if output_hidden_states else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(layer_module.__call__, hidden_states)
else:
layer_outputs = layer_module(hidden_states)
hidden_states = layer_outputs[0]
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
class FNetPooler(nn.Module):
    def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class FNetPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class FNetLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = FNetPredictionHeadTransform(config)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
def _tie_weights(self):
self.bias = self.decoder.bias
class FNetOnlyMLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = FNetLMPredictionHead(config)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class FNetOnlyNSPHead(nn.Module):
def __init__(self, config):
super().__init__()
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, pooled_output):
seq_relationship_score = self.seq_relationship(pooled_output)
return seq_relationship_score
class FNetPreTrainingHeads(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = FNetLMPredictionHead(config)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output):
prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class FNetPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """
config_class = FNetConfig
base_model_prefix = "fnet"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
        # nn.Linear: normal init with the configured initializer range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # NOTE: in the original code the bias is initialized the same way as the weight
            if module.bias is not None:
                # Zero out the bias
                module.bias.data.zero_()
        # nn.Embedding: normal init, then zero the padding row if there is one
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        # nn.LayerNorm: bias -> 0, weight -> 1
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
"""
FNetForPreTrainingOutput 类定义了预训练模型的输出类型,继承自 ModelOutput。
Args:
loss (torch.FloatTensor, 可选): 当提供 `labels` 时返回,表示总损失,包括掩码语言建模损失和下一个序列预测(分类)损失,形状为 `(1,)`。
prediction_logits (torch.FloatTensor): 语言建模头部的预测分数,即 SoftMax 之前的每个词汇标记的分数,形状为 `(batch_size, sequence_length, config.vocab_size)`。
seq_relationship_logits (torch.FloatTensor): 下一个序列预测(分类)头部的预测分数,即 SoftMax 之前的 True/False 连续性的分数,形状为 `(batch_size, 2)`。
hidden_states (tuple(torch.FloatTensor), 可选): 当 `output_hidden_states=True` 被传递或 `config.output_hidden_states=True` 时返回,包含模型每一层的隐藏状态,形状为 `(batch_size, sequence_length, hidden_size)`。
"""
FNET_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`FNetConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
FNET_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
            [What are input IDs?](../glossary#input-ids)
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
1]`:
- 0 corresponds to a *sentence A* token,
- 1 corresponds to a *sentence B* token.
            [What are token type IDs?](../glossary#token-type-ids)
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
            [What are position IDs?](../glossary#position-ids)
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
model's internal embedding lookup matrix.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare FNet Model transformer outputting raw hidden-states without any specific head on top.",
FNET_START_DOCSTRING,
)
class FNetModel(FNetPreTrainedModel):
"""
The model can behave as an encoder, following the architecture described in [FNet: Mixing Tokens with Fourier
Transforms](https://arxiv.org/abs/2105.03824) by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon.
"""
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
        # Embedding module (word, position and token-type embeddings)
        self.embeddings = FNetEmbeddings(config)
        # Encoder stack of Fourier-mixing layers
        self.encoder = FNetEncoder(config)
        # Optional pooling layer over the first token
        self.pooler = FNetPooler(config) if add_pooling_layer else None
        # Initialize weights and apply final processing
        self.post_init()
    def get_input_embeddings(self):
        # Return the word embedding matrix
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        # Replace the word embedding matrix
        self.embeddings.word_embeddings = value
@add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, BaseModelOutput]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # input_ids and inputs_embeds are mutually exclusive
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            # Derive batch size and sequence length from input_ids
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
        elif inputs_embeds is not None:
            # Derive batch size and sequence length from inputs_embeds (drop the hidden dim)
            input_shape = inputs_embeds.size()[:-1]
            batch_size, seq_length = input_shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        # When TPU Fourier optimizations are enabled, the configured tpu_short_seq_length must
        # match the actual sequence length (for sequences up to 4096 tokens)
        if (
            self.config.use_tpu_fourier_optimizations
            and seq_length <= 4096
            and self.config.tpu_short_seq_length != seq_length
        ):
            raise ValueError(
                "The `tpu_short_seq_length` in FNetConfig should be set equal to the sequence length being passed to"
                " the model when using TPU optimizations."
            )
        # Pick the device from whichever input was provided
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                # Slice the buffered all-zero token_type_ids to the current sequence length
                # and broadcast it across the batch
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                # No buffer available: default to all-zero token type IDs
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
        # Compute the input embeddings
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )
        # Run the Fourier-mixing encoder
        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        # Pool the first token if a pooler was added
        pooler_output = self.pooler(sequence_output) if self.pooler is not None else None
        # Tuple output when return_dict is disabled
        if not return_dict:
            return (sequence_output, pooler_output) + encoder_outputs[1:]
        # Otherwise wrap everything in a BaseModelOutputWithPooling
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooler_output,
hidden_states=encoder_outputs.hidden_states,
)
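Minimal usage sketch for the bare model, mirroring the docstring examples later in this file; it assumes the `google/fnet-base` checkpoint can be downloaded from the Hugging Face Hub.
from transformers import AutoTokenizer, FNetModel

tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
model = FNetModel.from_pretrained("google/fnet-base")
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)   # (1, sequence_length, hidden_size)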
"""
FNet Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
sentence prediction (classification)` head.
"""
# Attach the shared model docstring to the pretraining head class below
@add_start_docstrings(
    """
    FNet Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    """,
    FNET_START_DOCSTRING,
)
class FNetForPreTraining(FNetPreTrainedModel):
    # Keys whose weights are tied to the input word embeddings
    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)
        # Backbone FNet encoder
        self.fnet = FNetModel(config)
        # Masked-LM and next-sentence-prediction heads used during pretraining
        self.cls = FNetPreTrainingHeads(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    # Forward pass
@add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=FNetForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
next_sentence_label: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, FNetForPreTrainingOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
(see `input_ids` docstring) Indices should be in `[0, 1]`:
- 0 indicates sequence B is a continuation of sequence A,
- 1 indicates sequence B is a random sequence.
kwargs (`Dict[str, any]`, optional, defaults to *{}*):
Used to hide legacy arguments that have been deprecated.
Returns:
Returns an instance of `FNetForPreTrainingOutput` containing various outputs from the model.
Example:
```
>>> from transformers import AutoTokenizer, FNetForPreTraining
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
>>> model = FNetForPreTraining.from_pretrained("google/fnet-base")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.prediction_logits
>>> seq_relationship_logits = outputs.seq_relationship_logits
```"""
# Determine whether to use the return_dict mode based on the input argument or default configuration
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Forward pass through the FNet model with specified inputs and configurations
outputs = self.fnet(
input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Extract the relevant outputs from the FNet model's output
sequence_output, pooled_output = outputs[:2]
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
total_loss = None
if labels is not None and next_sentence_label is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
total_loss = masked_lm_loss + next_sentence_loss
if not return_dict:
output = (prediction_scores, seq_relationship_score) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
else:
return FNetForPreTrainingOutput(
loss=total_loss,
prediction_logits=prediction_scores,
seq_relationship_logits=seq_relationship_score,
hidden_states=outputs.hidden_states,
)
@add_start_docstrings("""FNet Model with a `language modeling` head on top.""", FNET_START_DOCSTRING)
class FNetForMaskedLM(FNetPreTrainedModel):
_tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
def __init__(self, config):
super().__init__(config)
self.fnet = FNetModel(config)
self.cls = FNetOnlyMLMHead(config)
self.post_init()
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
@add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MaskedLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.fnet(
input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(loss=masked_lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states)
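A masked-LM usage sketch; it assumes Hub access and that `[MASK]` is the tokenizer's mask token for the released `google/fnet-base` checkpoint.
import torch
from transformers import AutoTokenizer, FNetForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
model = FNetForMaskedLM.from_pretrained("google/fnet-base")
inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
logits = model(**inputs).logits
mask_positions = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
predicted_ids = logits[0, mask_positions].argmax(-1)
print(tokenizer.decode(predicted_ids))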
@add_start_docstrings(
"""FNet Model with a `next sentence prediction (classification)` head on top.""",
FNET_START_DOCSTRING,
)
class FNetForNextSentencePrediction(FNetPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.fnet = FNetModel(config)
self.cls = FNetOnlyNSPHead(config)
self.post_init()
@add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, NextSentencePredictorOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
(see `input_ids` docstring). Indices should be in `[0, 1]`:
- 0 indicates sequence B is a continuation of sequence A,
- 1 indicates sequence B is a random sequence.
Returns:
Example:
```
>>> from transformers import AutoTokenizer, FNetForNextSentencePrediction
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
>>> model = FNetForNextSentencePrediction.from_pretrained("google/fnet-base")
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
>>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
>>> outputs = model(**encoding, labels=torch.LongTensor([1]))
>>> logits = outputs.logits
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
```"""
if "next_sentence_label" in kwargs:
warnings.warn(
"The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
" `labels` instead.",
FutureWarning,
)
labels = kwargs.pop("next_sentence_label")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.fnet(
input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]
seq_relationship_scores = self.cls(pooled_output)
next_sentence_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
if not return_dict:
output = (seq_relationship_scores,) + outputs[2:]
return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
return NextSentencePredictorOutput(
loss=next_sentence_loss,
logits=seq_relationship_scores,
hidden_states=outputs.hidden_states,
)
@add_start_docstrings(
"""
FNet Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
output) e.g. for GLUE tasks.
""",
FNET_START_DOCSTRING,
)
class FNetForSequenceClassification(FNetPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.fnet = FNetModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.post_init()
@add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.fnet(
input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
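A fine-tuning-style usage sketch; `num_labels=2` is a hypothetical setup, the classification head is freshly initialized until trained, and Hub access is assumed. With integer labels and more than one label, the `problem_type` logic above selects the cross-entropy branch.
import torch
from transformers import AutoTokenizer, FNetForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
model = FNetForSequenceClassification.from_pretrained("google/fnet-base", num_labels=2)
inputs = tokenizer("A great movie.", return_tensors="pt")
outputs = model(**inputs, labels=torch.tensor([1]))
print(outputs.loss, outputs.logits.shape)   # cross-entropy loss, logits of shape (1, 2)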
@add_start_docstrings(
"""
FNet Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
softmax) e.g. for RocStories/SWAG tasks.
""",
FNET_START_DOCSTRING,
)
class FNetForMultipleChoice(FNetPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.fnet = FNetModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)
self.post_init()
@add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]`, where `num_choices` is the size of the second dimension of the input tensors.
        """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
inputs_embeds = (
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
if inputs_embeds is not None
else None
)
outputs = self.fnet(
input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
if not return_dict:
output = (reshaped_logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput(loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states)
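A shaping sketch for the multiple-choice head: inputs arrive as `(batch, num_choices, seq_len)`, are flattened to `(batch * num_choices, seq_len)` for the encoder, and the per-choice scores are reshaped back to `(batch, num_choices)` before the loss. Hub access is assumed and the choice-scoring head is untrained here.
import torch
from transformers import AutoTokenizer, FNetForMultipleChoice

tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
model = FNetForMultipleChoice.from_pretrained("google/fnet-base")
prompt = "France is famous for its"
choices = ["cheese.", "volcanoes on the moon."]
encoding = tokenizer([prompt, prompt], choices, return_tensors="pt", padding=True)
inputs = {key: value.unsqueeze(0) for key, value in encoding.items()}   # add the batch dimension
outputs = model(**inputs, labels=torch.tensor([0]))
print(outputs.logits.shape)   # (1, 2): one score per choice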
"""
FNet Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
"""
@add_start_docstrings(
"""
FNet Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
FNET_START_DOCSTRING,
)
class FNetForTokenClassification(FNetPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.fnet = FNetModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.post_init()
@add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.fnet(
input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
@add_start_docstrings(
    """
    FNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    FNET_START_DOCSTRING,
)
class FNetForQuestionAnswering(FNetPreTrainedModel):
    def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.fnet = FNetModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
self.post_init()
@add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
start_positions: Optional[torch.Tensor] = None,
end_positions: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.fnet(
input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss, start_logits=start_logits, end_logits=end_logits, hidden_states=outputs.hidden_states
)
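A span-extraction sketch for the question-answering head; it assumes Hub access, and since `google/fnet-base` carries no fine-tuned QA head, the decoded span is only meaningful after fine-tuning on a dataset such as SQuAD.
from transformers import AutoTokenizer, FNetForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
model = FNetForQuestionAnswering.from_pretrained("google/fnet-base")
question, context = "Who proposed FNet?", "FNet was proposed by researchers at Google Research."
inputs = tokenizer(question, context, return_tensors="pt")
outputs = model(**inputs)
start = outputs.start_logits.argmax(-1).item()
end = outputs.end_logits.argmax(-1).item()
print(tokenizer.decode(inputs["input_ids"][0, start : end + 1]))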