Transformers Source Code Analysis (22)
.\models\bloom\modeling_flax_bloom.py
import math
from functools import partial
from typing import Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, dot_product_attention_weights, make_causal_mask
from flax.linen.activation import tanh
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxCausalLMOutput,
)
from ...modeling_flax_utils import FlaxPreTrainedModel, append_call_sample_docstring
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_bloom import BloomConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "bigscience/bloom"
_CONFIG_FOR_DOC = "BloomConfig"
BLOOM_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a Flax Linen
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`BloomConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
[`~FlaxPreTrainedModel.to_bf16`].
"""
BLOOM_INPUTS_DOCSTRING = r"""
Args:
input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
`input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
def build_alibi_tensor(attention_mask: jnp.ndarray, num_heads: int, dtype: Optional[jnp.dtype] = jnp.float32):
"""
Flax implementation of the BLOOM ALiBi tensor. Unlike in the original paper, the BLOOM ALiBi tensor is not causal;
it relies on the translation invariance of softmax for a quick implementation: for a tensor `l` and a fixed value
`a`, `softmax(l + a) = softmax(l)`. Based on
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
Link to paper: https://arxiv.org/abs/2108.12409
Args:
attention_mask (`jnp.ndarray`):
Token-wise attention mask, this should be of shape `(batch_size, max_seq_len)`.
num_heads (`int`):
Number of attention heads.
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
The data type (dtype) of the output tensor.
Returns: Alibi tensor of shape `(batch_size, num_heads, 1, max_seq_len)`.
"""
batch_size, seq_length = attention_mask.shape
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = jnp.array(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=jnp.float32)
powers = jnp.arange(1, 1 + closest_power_of_2, dtype=jnp.float32)
slopes = jax.lax.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = jnp.array(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=jnp.float32)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = jnp.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=jnp.float32)
slopes = jnp.concatenate([slopes, jax.lax.pow(extra_base, extra_powers)], axis=0)
arange_tensor = ((attention_mask.cumsum(axis=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor
alibi = jnp.expand_dims(alibi, axis=2)
return jnp.asarray(alibi, dtype)
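A minimal usage sketch of `build_alibi_tensor` (values are illustrative, not from the original file): the returned tensor carries one bias row per head, later broadcast over query positions and added to the attention scores.

```python
# Illustrative only: inspect the ALiBi tensor for a small padded batch.
import jax.numpy as jnp

attention_mask = jnp.array([[1, 1, 1, 1], [0, 0, 1, 1]])  # batch of 2, seq_len 4, row 1 left-padded
alibi = build_alibi_tensor(attention_mask, num_heads=16, dtype=jnp.float32)
print(alibi.shape)  # (2, 16, 1, 4): one bias row per head, broadcast over query positions
```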
class FlaxBloomAttention(nn.Module):
config: BloomConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.hidden_size = self.config.hidden_size
self.num_heads = self.config.n_head
self.head_dim = self.hidden_size // self.num_heads
self.attention_softmax_in_fp32 = self.dtype is not jnp.float32
if self.head_dim * self.num_heads != self.hidden_size:
raise ValueError(
f"`hidden_size` must be divisible by `num_heads` (got `hidden_size`: {self.hidden_size} and "
f"`num_heads`: {self.num_heads})."
)
dense = partial(
nn.Dense,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.query_key_value = dense(self.hidden_size * 3)
self.dense = dense(self.hidden_size)
self.resid_dropout = nn.Dropout(rate=self.config.hidden_dropout)
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:-1] + (self.num_heads, self.head_dim * 3))
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))
@nn.compact
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slightly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
"""
is_initialized = self.has_variable("cache", "cached_key")
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
cur_index = cache_index.value
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
key = lax.dynamic_update_slice(cached_key.value, key, indices)
value = lax.dynamic_update_slice(cached_value.value, value, indices)
cached_key.value = key
cached_value.value = value
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
)
attention_mask = combine_masks(pad_mask, attention_mask)
return key, value, attention_mask
class BloomGELU(nn.Module):
def setup(self):
self.dtype = jnp.float32
def __call__(self, x):
return x * 0.5 * (1.0 + tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
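The constants above are the usual tanh approximation of GELU (0.79788456 ≈ sqrt(2/π)); as a quick sanity check (illustrative, not part of the file), the formula should closely match `jax.nn.gelu` with `approximate=True`:

```python
# Illustrative only: compare BloomGELU's tanh formula with JAX's approximate GELU.
import jax.numpy as jnp
from jax.nn import gelu

x = jnp.linspace(-3.0, 3.0, 7)
bloom_gelu = x * 0.5 * (1.0 + jnp.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
print(jnp.allclose(bloom_gelu, gelu(x, approximate=True), atol=1e-4))  # True
```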
class FlaxBloomMLP(nn.Module):
config: BloomConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
hidden_size = self.config.hidden_size
kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
self.dense_h_to_4h = nn.Dense(4 * hidden_size, dtype=self.dtype, kernel_init=kernel_init)
self.dense_4h_to_h = nn.Dense(hidden_size, dtype=self.dtype, kernel_init=kernel_init)
self.hidden_dropout = nn.Dropout(self.config.hidden_dropout)
self.act = BloomGELU()
def __call__(self, hidden_states, residual, deterministic: bool = True):
hidden_states = self.dense_h_to_4h(hidden_states)
hidden_states = self.act(hidden_states)
intermediate_output = self.dense_4h_to_h(hidden_states)
intermediate_output = intermediate_output + residual
hidden_states = self.hidden_dropout(intermediate_output, deterministic=deterministic)
return hidden_states
class FlaxBloomBlock(nn.Module):
config: BloomConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.input_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
self.self_attention = FlaxBloomAttention(self.config, dtype=self.dtype)
self.post_attention_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
self.mlp = FlaxBloomMLP(self.config, dtype=self.dtype)
self.apply_residual_connection_post_layernorm = self.config.apply_residual_connection_post_layernorm
self.hidden_dropout = self.config.hidden_dropout
def __call__(
self,
hidden_states,
alibi,
attention_mask=None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
):
layernorm_output = self.input_layernorm(hidden_states)
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
attn_outputs = self.self_attention(
layernorm_output,
residual=residual,
alibi=alibi,
attention_mask=attention_mask,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
)
attention_output = attn_outputs[0]
outputs = attn_outputs[1:]
post_layernorm = self.post_attention_layernorm(attention_output)
if self.apply_residual_connection_post_layernorm:
residual = post_layernorm
else:
residual = attention_output
output = self.mlp(post_layernorm, residual, deterministic=deterministic)
outputs = (output,) + outputs
return outputs
class FlaxBloomPreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = BloomConfig
base_model_prefix = "transformer"
module_class: nn.Module = None
def __init__(
self,
config: BloomConfig,
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
input_ids = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_ids)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
random_params = self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def init_cache(self, batch_size, max_length):
"""
Args:
batch_size (`int`):
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
max_length (`int`):
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
cache.
"""
input_ids = jnp.ones((batch_size, max_length), dtype="i4")
attention_mask = jnp.ones_like(input_ids)
init_variables = self.module.init(
jax.random.PRNGKey(0), input_ids, attention_mask, return_dict=False, init_cache=True
)
return unfreeze(init_variables["cache"])
@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
def __call__(
self,
input_ids,
attention_mask=None,
past_key_values: dict = None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
batch_size, sequence_length = input_ids.shape
if attention_mask is None:
attention_mask = jnp.ones((batch_size, sequence_length))
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
inputs = {"params": params or self.params}
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
outputs = self.module.apply(
inputs,
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
not train,
False,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
mutable=mutable,
)
if past_key_values is not None and return_dict:
outputs, past_key_values = outputs
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
return outputs
elif past_key_values is not None and not return_dict:
outputs, past_key_values = outputs
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
return outputs
class FlaxBloomBlockCollection(nn.Module):
config: BloomConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.layers = [
FlaxBloomBlock(self.config, name=str(layer_number), dtype=self.dtype)
for layer_number in range(self.config.num_hidden_layers)
]
def __call__(
self,
hidden_states,
alibi,
attention_mask=None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for layer_number in range(self.config.num_hidden_layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = self.layers[layer_number](
hidden_states,
alibi=alibi,
attention_mask=attention_mask,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions += (layer_outputs[1],)
outputs = (hidden_states, all_hidden_states, all_attentions)
return outputs
class FlaxBloomModule(nn.Module):
config: BloomConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.embed_dim = self.config.hidden_size
self.word_embeddings = nn.Embed(
self.config.vocab_size,
self.embed_dim,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.word_embeddings_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
self.h = FlaxBloomBlockCollection(self.config, dtype=self.dtype)
self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
def __call__(
self,
input_ids=None,
attention_mask=None,
deterministic=True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
alibi = build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype)
outputs = self.h(
hidden_states,
alibi=alibi,
attention_mask=attention_mask,
deterministic=deterministic,
init_cache=init_cache,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
hidden_states = self.ln_f(hidden_states)
if output_hidden_states:
all_hidden_states = outputs[1] + (hidden_states,)
outputs = (hidden_states, all_hidden_states) + outputs[2:]
else:
outputs = (hidden_states,) + outputs[1:]
if not return_dict:
return tuple(v for v in [outputs[0], outputs[-1]] if v is not None)
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=outputs[1],
attentions=outputs[-1],
)
@add_start_docstrings(
"The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
BLOOM_START_DOCSTRING,
)
class FlaxBloomModel(FlaxBloomPreTrainedModel):
module_class = FlaxBloomModule
append_call_sample_docstring(FlaxBloomModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC)
class FlaxBloomForCausalLMModule(nn.Module):
config: BloomConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.transformer = FlaxBloomModule(self.config, dtype=self.dtype)
self.lm_head = nn.Dense(
self.config.vocab_size,
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
def __call__(
self,
input_ids,
attention_mask,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_kernel = self.transformer.variables["params"]["word_embeddings"]["embedding"].T
lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
else:
lm_logits = self.lm_head(hidden_states)
if not return_dict:
return (lm_logits,) + outputs[1:]
return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
@add_start_docstrings(
"""
The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
embeddings).
""",
BLOOM_START_DOCSTRING,
)
class FlaxBloomForCausalLM(FlaxBloomPreTrainedModel):
module_class = FlaxBloomForCausalLMModule
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
batch_size, seq_length = input_ids.shape
past_key_values = self.init_cache(batch_size, max_length)
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if attention_mask is not None:
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
return {
"past_key_values": past_key_values,
"attention_mask": extended_attention_mask,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
model_kwargs["past_key_values"] = model_outputs.past_key_values
return model_kwargs
append_call_sample_docstring(FlaxBloomForCausalLM, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC)
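A short, hedged end-to-end example for the causal-LM head defined above (checkpoint choice and generation settings are illustrative and assume Flax weights are available for it):

```python
# Illustrative only: greedy generation with the Flax BLOOM causal LM.
from transformers import AutoTokenizer, FlaxBloomForCausalLM

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = FlaxBloomForCausalLM.from_pretrained("bigscience/bloom-560m")

inputs = tokenizer("The capital of France is", return_tensors="np")
output_ids = model.generate(**inputs, max_new_tokens=8, do_sample=False).sequences
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```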
.\models\bloom\tokenization_bloom_fast.py
"""Tokenization classes for Bloom."""
import pickle
from typing import Optional, Tuple
from ...tokenization_utils_base import BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"tokenizer_file": "tokenizer.json"}
PRETRAINED_VOCAB_FILES_MAP = {
"tokenizer_file": {
"bigscience/tokenizer": "https://huggingface.co/bigscience/tokenizer/blob/main/tokenizer.json",
"bigscience/bloom-560m": "https://huggingface.co/bigscience/bloom-560m/blob/main/tokenizer.json",
"bigscience/bloom-1b1": "https://huggingface.co/bigscience/bloom-1b1/blob/main/tokenizer.json",
"bigscience/bloom-1b7": "https://huggingface.co/bigscience/bloom-1b7/blob/main/tokenizer.json",
"bigscience/bloom-3b": "https://huggingface.co/bigscience/bloom-3b/blob/main/tokenizer.json",
"bigscience/bloom-7b1": "https://huggingface.co/bigscience/bloom-7b1/blob/main/tokenizer.json",
"bigscience/bloom": "https://huggingface.co/bigscience/bloom/blob/main/tokenizer.json",
},
}
class BloomTokenizerFast(PreTrainedTokenizerFast):
"""
Construct a "fast" Bloom tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
Byte-Pair-Encoding.
This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
be encoded differently whether it is at the beginning of the sentence (without space) or not:
```
>>> from transformers import BloomTokenizerFast
>>> tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom")
>>> tokenizer("Hello world")["input_ids"]
[59414, 8876]
>>> tokenizer(" Hello world")["input_ids"]
[86153, 8876]
```
You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer, but since
the model was not pretrained this way, it might yield a decrease in performance.
<Tip>
When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.
</Tip>
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods.
"""
Args:
vocab_file (`str`):
Path to the vocabulary file.
merges_file (`str`):
Path to the merges file.
errors (`str`, *optional*, defaults to `"replace"`):
Paradigm to follow when decoding bytes to UTF-8. See
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
unk_token (`str`, *optional*, defaults to `<|endoftext|>`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
bos_token (`str`, *optional*, defaults to `<|endoftext|>`):
The beginning of sequence token.
eos_token (`str`, *optional*, defaults to `<|endoftext|>`):
The end of sequence token.
add_prefix_space (`bool`, *optional*, defaults to `False`):
Whether or not to add an initial space to the input. This allows treating the leading word just like any
other word (the Bloom tokenizer detects the beginning of words by the preceding space).
trim_offsets (`bool`, *optional*, defaults to `True`):
Whether or not the post-processing step should trim offsets to avoid including whitespaces.
"""
# File names expected for the pretrained vocabulary
vocab_files_names = VOCAB_FILES_NAMES
# Map from model ids to their pretrained vocabulary files
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
# Names of the inputs the model expects
model_input_names = ["input_ids", "attention_mask"]
# No slow tokenizer class is available, so this defaults to None
slow_tokenizer_class = None
# No `max_model_input_sizes` here, since BLOOM uses ALiBi positional embeddings
def __init__(
self,
vocab_file=None,
merges_file=None,
tokenizer_file=None,
unk_token="<unk>",
bos_token="<s>",
eos_token="</s>",
pad_token="<pad>",
add_prefix_space=False,
clean_up_tokenization_spaces=False,
**kwargs,
):
# Call the parent initializer with the required and optional arguments
super().__init__(
vocab_file,
merges_file,
tokenizer_file=tokenizer_file,
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
add_prefix_space=add_prefix_space,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)
# Serialize the backend pre-tokenizer and decoder state
pre_tok_state = pickle.dumps(self.backend_tokenizer.pre_tokenizer)
decoder_state = pickle.dumps(self.backend_tokenizer.decoder)
# If a prefix space is requested, patch the serialized state to match the configuration
if add_prefix_space:
pre_tok_state = pre_tok_state.replace(b'"add_prefix_space":false', b'"add_prefix_space": true')
decoder_state = decoder_state.replace(b'"add_prefix_space":false', b'"add_prefix_space": true')
# Deserialize and install the updated pre-tokenizer and decoder on the backend tokenizer
self.backend_tokenizer.pre_tokenizer = pickle.loads(pre_tok_state)
self.backend_tokenizer.decoder = pickle.loads(decoder_state)
# Remember whether a prefix space is added
self.add_prefix_space = add_prefix_space
# `_batch_encode_plus` forwards to the parent implementation and returns a `BatchEncoding`
def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
# Read `is_split_into_words` from the keyword arguments, defaulting to False
is_split_into_words = kwargs.get("is_split_into_words", False)
# Pre-tokenized inputs require `add_prefix_space=True`
if not (self.add_prefix_space or not is_split_into_words):
raise Exception(
f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True to use it with"
" pretokenized inputs."
)
# Delegate to the parent `_batch_encode_plus` with all positional and keyword arguments
return super()._batch_encode_plus(*args, **kwargs)
# `_encode_plus` applies the same check before delegating to the parent implementation
def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
# Read `is_split_into_words` from the keyword arguments, defaulting to False
is_split_into_words = kwargs.get("is_split_into_words", False)
# Pre-tokenized inputs require `add_prefix_space=True`
if not (self.add_prefix_space or not is_split_into_words):
raise Exception(
f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True to use it with"
" pretokenized inputs."
)
# Delegate to the parent `_encode_plus` with all positional and keyword arguments
return super()._encode_plus(*args, **kwargs)
# Save the tokenizer model files to `save_directory`, optionally prefixed with `filename_prefix`
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
# Delegate to the backend tokenizer's `model.save` and collect the written file names
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
# Return the saved file names as a tuple
return tuple(files)
@property
# `default_chat_template` returns a simple template that ignores roles and joins messages with the EOS token
def default_chat_template(self):
"""
A simple chat template that ignores role information and just concatenates messages with EOS tokens.
"""
# Warn once that no chat template is defined and the class default is being used
logger.warning_once(
"\nNo chat template is defined for this tokenizer - using the default template "
f"for the {self.__class__.__name__} class. If the default is not appropriate for "
"your model, please set `tokenizer.chat_template` to an appropriate template. "
"See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
)
# Return the default chat template string used to render messages
return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}"
.\models\bloom\__init__.py
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tokenizers_available,
is_torch_available,
)
_import_structure = {
"configuration_bloom": ["BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP", "BloomConfig", "BloomOnnxConfig"],
}
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_bloom_fast"] = ["BloomTokenizerFast"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_bloom"] = [
"BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST",
"BloomForCausalLM",
"BloomModel",
"BloomPreTrainedModel",
"BloomForSequenceClassification",
"BloomForTokenClassification",
"BloomForQuestionAnswering",
]
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_bloom"] = [
"FlaxBloomForCausalLM",
"FlaxBloomModel",
"FlaxBloomPreTrainedModel",
]
如果 "检查类型":
从 .configuration_bloom 导入 BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig, BloomOnnxConfig
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_bloom_fast import BloomTokenizerFast
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_bloom import (
BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST,
BloomForCausalLM,
BloomForQuestionAnswering,
BloomForSequenceClassification,
BloomForTokenClassification,
BloomModel,
BloomPreTrainedModel,
)
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_bloom import (
FlaxBloomForCausalLM,
FlaxBloomModel,
FlaxBloomPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\bridgetower\configuration_bridgetower.py
""" BridgeTower model configuration"""
import os
from typing import Union
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
BRIDGETOWER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"BridgeTower/bridgetower-base": "https://huggingface.co/BridgeTower/bridgetower-base/blob/main/config.json",
"BridgeTower/bridgetower-base-itm-mlm": (
"https://huggingface.co/BridgeTower/bridgetower-base-itm-mlm/blob/main/config.json"
),
}
class BridgeTowerVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the vision configuration of a [`BridgeTowerModel`]. Instantiating a
configuration with the defaults will yield a similar configuration to that of the bridgetower-base
[BridgeTower/bridgetower-base](https://huggingface.co/BridgeTower/bridgetower-base/) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in visual encoder model.
patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch.
image_size (`int`, *optional*, defaults to 288):
The size (resolution) of each image.
initializer_factor (`float`, *optional*, defaults to 1):
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
testing).
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the layer normalization layers.
stop_gradient (`bool`, *optional*, defaults to `False`):
Whether to stop gradient for training.
share_layernorm (`bool`, *optional*, defaults to `True`):
Whether LayerNorm layers are shared.
remove_last_layer (`bool`, *optional*, defaults to `False`):
Whether to remove the last layer from the vision encoder.
"""
def __init__(
self,
hidden_size=768,
num_hidden_layers=12,
patch_size=16,
image_size=288,
initializer_factor=1.0,
layer_norm_eps=1e-05,
stop_gradient=False,
share_layernorm=True,
remove_last_layer=False,
**kwargs
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.patch_size = patch_size
self.image_size = image_size
self.initializer_factor = initializer_factor
self.layer_norm_eps = layer_norm_eps
self.stop_gradient = stop_gradient
self.share_layernorm = share_layernorm
self.remove_last_layer = remove_last_layer
>>> from transformers import BridgeTowerVisionConfig
>>>
>>> configuration = BridgeTowerVisionConfig()
>>>
>>> configuration
```"""
model_type = "bridgetower_vision_model"
# Model type identifier for the vision sub-config
def __init__(
self,
hidden_size=768,
num_hidden_layers=12,
num_channels=3,
patch_size=16,
image_size=288,
initializer_factor=1,
layer_norm_eps=1e-05,
stop_gradient=False,
share_layernorm=True,
remove_last_layer=False,
**kwargs,
):
# Initializer that stores the configuration attributes of the vision model
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.initializer_factor = initializer_factor
self.layer_norm_eps = layer_norm_eps
self.stop_gradient = stop_gradient
self.share_layernorm = share_layernorm
self.remove_last_layer = remove_last_layer
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
# Class method: load this configuration from a pretrained model
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
if config_dict.get("model_type") == "bridgetower":
config_dict = config_dict["text_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
# BridgeTowerTextConfig inherits from PretrainedConfig and stores the text model configuration
class BridgeTowerTextConfig(PretrainedConfig):
r"""
This is the configuration class to store the text configuration of a [`BridgeTowerModel`]. The default values here
are copied from RoBERTa. Instantiating a configuration with the defaults will yield a similar configuration to that
of the bridgetower-base [BridgeTower/bridgetower-base](https://huggingface.co/BridgeTower/bridgetower-base/)
architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```
>>> from transformers import BridgeTowerTextConfig
>>>
>>> configuration = BridgeTowerTextConfig()
>>>
>>> configuration
```"""
# Model type identifier for the text sub-config
model_type = "bridgetower_text_model"
# Initializer setting the text model parameters
def __init__(
self,
vocab_size=50265,  # vocabulary size, defaults to 50265
hidden_size=768,  # hidden size, defaults to 768
num_hidden_layers=12,  # number of hidden layers, defaults to 12
num_attention_heads=12,  # number of attention heads, defaults to 12
initializer_factor=1,  # initializer factor, defaults to 1
intermediate_size=3072,  # intermediate (feed-forward) size, defaults to 3072
hidden_act="gelu",  # hidden activation function, defaults to "gelu"
hidden_dropout_prob=0.1,  # hidden dropout probability, defaults to 0.1
attention_probs_dropout_prob=0.1,  # attention dropout probability, defaults to 0.1
max_position_embeddings=514,  # maximum number of position embeddings, defaults to 514
type_vocab_size=1,  # type vocabulary size, defaults to 1
layer_norm_eps=1e-05,  # layer norm epsilon, defaults to 1e-05
pad_token_id=1,  # padding token id, defaults to 1
bos_token_id=0,  # beginning-of-sequence token id, defaults to 0
eos_token_id=2,  # end-of-sequence token id, defaults to 2
position_embedding_type="absolute",  # position embedding type, defaults to "absolute"
use_cache=True,  # whether to use the cache, defaults to True
**kwargs,
):
super().__init__(**kwargs)  # call the parent PretrainedConfig initializer
# Store each parameter
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.initializer_factor = initializer_factor
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
# Fetch the configuration dict and remaining kwargs for the given pretrained model name or path
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# If the loaded config is a full "bridgetower" config, use its "text_config" sub-dict instead
if config_dict.get("model_type") == "bridgetower":
config_dict = config_dict["text_config"]
# Warn when instantiating a config of a different model type than the loaded one,
# since this is not supported for every configuration and can yield errors
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
# Build and return the configuration object from the dict
return cls.from_dict(config_dict, **kwargs)
# BridgeTowerConfig stores the full configuration of a BridgeTowerModel
class BridgeTowerConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`BridgeTowerModel`]. It is used to instantiate a
BridgeTower model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the bridgetower-base
[BridgeTower/bridgetower-base](https://huggingface.co/BridgeTower/bridgetower-base/) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
share_cross_modal_transformer_layers (`bool`, *optional*, defaults to `True`):
Whether cross modal transformer layers are shared.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler.
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
initializer_factor (`float`, *optional*, defaults to 1):
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
testing).
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the layer normalization layers.
share_link_tower_layers (`bool`, *optional*, defaults to `False`):
Whether the bridge/link tower layers are shared.
link_tower_type (`str`, *optional*, defaults to `"add"`):
Type of the bridge/link layer.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 6):
Number of hidden layers in the Transformer encoder.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie input and output embeddings.
init_layernorm_from_vision_encoder (`bool`, *optional*, defaults to `False`):
Whether to init LayerNorm from the vision encoder.
text_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`BridgeTowerTextConfig`].
vision_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`BridgeTowerVisionConfig`].
Example:
```
>>> from transformers import BridgeTowerModel, BridgeTowerConfig
>>>
>>> configuration = BridgeTowerConfig()
>>>
>>> model = BridgeTowerModel(configuration)
>>>
>>> configuration = model.config
```"""
model_type = "bridgetower"
def __init__(
self,
share_cross_modal_transformer_layers=True,
hidden_act="gelu",
hidden_size=768,
initializer_factor=1,
layer_norm_eps=1e-05,
share_link_tower_layers=False,
link_tower_type="add",
num_attention_heads=12,
num_hidden_layers=6,
tie_word_embeddings=False,
init_layernorm_from_vision_encoder=False,
text_config=None,
vision_config=None,
**kwargs,
):
_ = kwargs.pop("text_config_dict", None)
_ = kwargs.pop("vision_config_dict", None)
super().__init__(**kwargs)
self.share_cross_modal_transformer_layers = share_cross_modal_transformer_layers
self.hidden_act = hidden_act
self.hidden_size = hidden_size
self.initializer_factor = initializer_factor
self.layer_norm_eps = layer_norm_eps
self.share_link_tower_layers = share_link_tower_layers
self.link_tower_type = link_tower_type
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = num_hidden_layers
self.tie_word_embeddings = tie_word_embeddings
self.init_layernorm_from_vision_encoder = init_layernorm_from_vision_encoder
if text_config is None:
text_config = {}
logger.info("`text_config` is `None`. Initializing the `BridgeTowerTextConfig` with default values.")
if vision_config is None:
vision_config = {}
logger.info("`vision_config` is `None`. Initializing the `BridgeTowerVisionConfig` with default values.")
self.text_config = BridgeTowerTextConfig(**text_config)
self.vision_config = BridgeTowerVisionConfig(**vision_config)
@classmethod
def from_text_vision_configs(
cls, text_config: BridgeTowerTextConfig, vision_config: BridgeTowerVisionConfig, **kwargs
):
r"""
Instantiate a [`BridgeTowerConfig`] (or a derived class) from BridgeTower text model and vision model configurations. Returns:
[`BridgeTowerConfig`]: An instance of a configuration object
"""
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
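A brief usage sketch for `from_text_vision_configs` (default values only, illustrative):

```python
# Illustrative only: compose a BridgeTowerConfig from its two sub-configurations.
from transformers import BridgeTowerConfig, BridgeTowerTextConfig, BridgeTowerVisionConfig

text_config = BridgeTowerTextConfig()
vision_config = BridgeTowerVisionConfig()
config = BridgeTowerConfig.from_text_vision_configs(text_config, vision_config)
print(config.text_config.hidden_size, config.vision_config.image_size)  # 768 288
```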
.\models\bridgetower\image_processing_bridgetower.py
"""BridgeTower 的图像处理器类。"""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import PaddingMode, center_crop, pad, resize, to_channel_dimension_format
from ...image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_batched,
is_scaled_image,
to_numpy_array,
valid_images,
validate_kwargs,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging
if is_vision_available():
import PIL
logger = logging.get_logger(__name__)
def max_across_indices(values: Iterable[Any]) -> List[Any]:
"""
Return the maximum value across all indices of an iterable of values.
"""
return [max(values_i) for values_i in zip(*values)]
def make_pixel_mask(
image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> np.ndarray:
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
Args:
image (`np.ndarray`):
Image to make the pixel mask for.
output_size (`Tuple[int, int]`):
Output size of the mask.
"""
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
mask = np.zeros(output_size, dtype=np.int64)
mask[:input_height, :input_width] = 1
return mask
def get_max_height_width(
images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> Tuple[int, int]:
"""
Get the maximum height and width across all images in a batch.
"""
if input_data_format is None:
input_data_format = infer_channel_dimension_format(images[0])
if input_data_format == ChannelDimension.FIRST:
_, max_height, max_width = max_across_indices([img.shape for img in images])
elif input_data_format == ChannelDimension.LAST:
max_height, max_width, _ = max_across_indices([img.shape for img in images])
else:
raise ValueError(f"Invalid channel dimension format: {input_data_format}")
return (max_height, max_width)
def get_resize_output_image_size(
input_image: np.ndarray,
shorter: int = 800,
longer: int = 1333,
size_divisor: int = 32,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[int, int]:
input_height, input_width = get_image_size(input_image, input_data_format)
min_size, max_size = shorter, longer
scale = min_size / min(input_height, input_width)
if input_height < input_width:
new_height = min_size
new_width = scale * input_width
else:
new_height = scale * input_height
new_width = min_size
if max(new_height, new_width) > max_size:
scale = max_size / max(new_height, new_width)
new_height = scale * new_height
new_width = scale * new_width
new_height, new_width = int(new_height + 0.5), int(new_width + 0.5)
new_height = new_height // size_divisor * size_divisor
new_width = new_width // size_divisor * size_divisor
return new_height, new_width
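A worked numeric trace of the resize rule above (illustrative): the shorter side is scaled to `shorter`, the result is capped so the longer side stays under `longer`, and both sides are snapped down to a multiple of `size_divisor`.

```python
# Illustrative only: trace get_resize_output_image_size for a 480x640 channels-first image.
import numpy as np

image = np.zeros((3, 480, 640), dtype=np.float32)
print(get_resize_output_image_size(image, shorter=288, longer=int(1333 / 800 * 288), size_divisor=32))
# scale = 288 / 480 = 0.6 -> (288.0, 384.0); under the cap of 479 and already multiples of 32 -> (288, 384)
```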
class BridgeTowerImageProcessor(BaseImageProcessor):
r"""
Constructs a BridgeTower image processor.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Optional[Dict[str, int]] = None,
size_divisor: int = 32,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True,
rescale_factor: Optional[float] = None,
do_normalize: bool = True,
image_mean: Optional[List[float]] = None,
image_std: Optional[List[float]] = None,
do_pad: bool = False,
do_center_crop: bool = False,
crop_size: Optional[Tuple[int, int]] = None,
**kwargs,
) -> None:
if "pad_and_return_pixel_mask" in kwargs:
do_pad = kwargs.pop("pad_and_return_pixel_mask")
super().__init__(**kwargs)
size = size if size is not None else {"shortest_edge": 288}
size = get_size_dict(size, default_to_square=False)
self.do_resize = do_resize
self.size = size
self.size_divisor = size_divisor
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
self.do_pad = do_pad
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self._valid_processor_keys = [
"images",
"do_resize",
"size",
"size_divisor",
"resample",
"do_rescale",
"rescale_factor",
"do_normalize",
"image_mean",
"image_std",
"do_pad",
"do_center_crop",
"crop_size",
"return_tensors",
"data_format",
"input_data_format",
]
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
size_divisor: int = 32,
resample: Optional[PILImageResampling] = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize an image.
Resizes the shorter side of the image to `size["shortest_edge"]` while preserving the aspect ratio. If the
longer side is larger than the max size `(int(size["shortest_edge"] * 1333 / 800))`, the longer side is then
resized to the max size while preserving the aspect ratio.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Controls the size of the output image. Should be of the form `{"shortest_edge": int}`.
size_divisor (`int`, defaults to 32):
The image is resized to a size that is a multiple of this value.
resample (`PILImageResampling` filter, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
Returns:
np.ndarray: Resized image.
Raises:
ValueError: If `size` dictionary does not contain the key `"shortest_edge"`.
"""
size = get_size_dict(size, default_to_square=False)
if "shortest_edge" not in size:
raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
shorter = size["shortest_edge"]
longer = int(1333 / 800 * shorter)
output_size = get_resize_output_image_size(
image, shorter=shorter, longer=longer, size_divisor=size_divisor, input_data_format=input_data_format
)
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def center_crop(
self,
image: np.ndarray,
size: Dict[str, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
any edge, the image is padded with 0's and then center cropped.
Args:
image (`np.ndarray`):
Image to center crop.
size (`Dict[str, int]`):
Size of the output image in the form `{"height": h, "width": w}`.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred from the input
image.
"""
output_size = size["shortest_edge"]
return center_crop(
image,
size=(output_size, output_size),
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def _pad_image(
self,
image: np.ndarray,
output_size: Tuple[int, int],
constant_values: Union[float, Iterable[float]] = 0,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Pad an image with zeros to the given size.
Args:
image (`np.ndarray`):
Input image to be padded.
output_size (`Tuple[int, int]`):
Desired output size of the image in format `(height, width)`.
constant_values (`Union[float, Iterable[float]]`, *optional*):
Value or sequence of values to pad the image with. Default is 0.
data_format (`ChannelDimension`, *optional*):
Format of the output image channel dimension. If not specified, defaults to `None`.
input_data_format (`Union[str, ChannelDimension]`, *optional*):
Format of the input image channel dimension. If not specified, defaults to `None`.
Returns:
np.ndarray:
Padded image of shape `(output_size[0], output_size[1], channels)`.
"""
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
output_height, output_width = output_size
pad_bottom = output_height - input_height
pad_right = output_width - input_width
padding = ((0, pad_bottom), (0, pad_right))
padded_image = pad(
image,
padding,
mode=PaddingMode.CONSTANT,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
return padded_image
def pad(
self,
images: List[np.ndarray],
constant_values: Union[float, Iterable[float]] = 0,
return_pixel_mask: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> BatchFeature:
"""
Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
in the batch and optionally returns their corresponding pixel mask.
Args:
images (`List[np.ndarray]`):
Batch of images to pad.
constant_values (`float` or `Iterable[float]`, *optional*):
The value to use for the padding if `mode` is `"constant"`.
return_pixel_mask (`bool`, *optional*, defaults to `True`):
Whether to return a pixel mask.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
pad_size = get_max_height_width(images, input_data_format=input_data_format)
padded_images = [
self._pad_image(
image,
pad_size,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
for image in images
]
data = {"pixel_values": padded_images}
if return_pixel_mask:
masks = [
make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
for image in images
]
data["pixel_mask"] = masks
return BatchFeature(data=data, tensor_type=return_tensors)
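A small, hedged sketch of `pad` on a ragged batch: both images are padded bottom/right to the batch maximum, and the pixel mask marks the valid region of each.

```python
# Illustrative only: pad two differently sized images and inspect the outputs.
import numpy as np

processor = BridgeTowerImageProcessor()
images = [np.zeros((3, 288, 320), dtype=np.float32), np.zeros((3, 256, 384), dtype=np.float32)]
batch = processor.pad(images, return_pixel_mask=True, return_tensors="np")
print(batch["pixel_values"].shape)  # (2, 3, 288, 384)
print(batch["pixel_mask"].shape)    # (2, 288, 384)
```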
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
size_divisor: Optional[int] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
do_center_crop: Optional[bool] = None,
crop_size: Dict[str, int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
.\models\bridgetower\modeling_bridgetower.py
"""PyTorch BridgeTower Model"""
import math
from collections import OrderedDict
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN, QuickGELUActivation
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
MaskedLMOutput,
ModelOutput,
SequenceClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_bridgetower import BridgeTowerConfig, BridgeTowerTextConfig, BridgeTowerVisionConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "BridgeTowerConfig"
_CHECKPOINT_FOR_DOC = "BridgeTower/bridgetower-base"
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"BridgeTower/bridgetower-base",
"BridgeTower/bridgetower-base-itm-mlm",
]
BRIDGETOWER_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ subclass. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`BridgeTowerConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
BRIDGETOWER_INPUTS_DOCSTRING = r"""
"""
@dataclass
class BridgeTowerModelOutput(ModelOutput):
"""
Output type of [`BridgeTowerModel`].
Represents the output of the BridgeTowerModel.
Inherits from ModelOutput defined in the modeling_outputs module.
"""
text_features: torch.FloatTensor = None
image_features: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class BridgeTowerContrastiveOutput(ModelOutput):
"""
Output type of ['BridgeTowerForContrastiveLearning']
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
Image-text contrastive loss.
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
text_embeds (`torch.FloatTensor`), *optional*, returned when model is initialized with `with_projection=True`):
The text embeddings obtained by applying the projection layer to the pooler_output.
image_embeds (`torch.FloatTensor`), *optional*, returned when model is initialized with `with_projection=True`):
The image embeddings obtained by applying the projection layer to the pooler_output.
cross_embeds (`torch.FloatTensor`), *optional*, returned when model is initialized with `with_projection=True`):
The text-image cross-modal embeddings obtained by applying the projection layer to the pooler_output.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
text_embeds: Optional[Tuple[torch.FloatTensor]] = None
image_embeds: Optional[Tuple[torch.FloatTensor]] = None
cross_embeds: Optional[Tuple[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class BridgeTowerResidualAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.attn = nn.MultiheadAttention(config.hidden_size, config.hidden_size // 64)
self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = nn.ModuleDict(
OrderedDict(
[
("c_fc", nn.Linear(config.hidden_size, config.hidden_size * 4)),
("gelu", QuickGELUActivation()),
("c_proj", nn.Linear(config.hidden_size * 4, config.hidden_size)),
]
)
)
self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attn_mask = None
def attention(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor):
if attention_mask is not None:
attention_mask = attention_mask.to(dtype=torch.bool, device=hidden_state.device)
self.attn_mask = (
self.attn_mask.to(dtype=hidden_state.dtype, device=hidden_state.device)
if self.attn_mask is not None
else None
)
return self.attn(
hidden_state,
hidden_state,
hidden_state,
need_weights=False,
attn_mask=self.attn_mask,
key_padding_mask=attention_mask,
)[0]
def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor = None):
residual_state = hidden_state + self.attention(self.ln_1(hidden_state), attention_mask)
hidden_state = self.ln_2(residual_state)
for _, layer in self.mlp.items():
hidden_state = layer(hidden_state)
hidden_state = residual_state + hidden_state
return hidden_state
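# Illustrative sketch (not part of the original file): BridgeTowerResidualAttention above follows the
# CLIP-style pre-LayerNorm residual pattern, x = x + Attn(LN(x)) followed by x = x + MLP(LN(x)).
# The sizes below are hypothetical; nn.MultiheadAttention defaults to batch_first=False, so inputs are
# (seq_len, batch, hidden_size).
import torch
import torch.nn as nn

hidden_size, seq_len, batch = 64, 5, 2
attn = nn.MultiheadAttention(hidden_size, num_heads=hidden_size // 64)
ln_1, ln_2 = nn.LayerNorm(hidden_size), nn.LayerNorm(hidden_size)
mlp = nn.Sequential(nn.Linear(hidden_size, 4 * hidden_size), nn.GELU(), nn.Linear(4 * hidden_size, hidden_size))

x = torch.randn(seq_len, batch, hidden_size)
normed = ln_1(x)
x = x + attn(normed, normed, normed, need_weights=False)[0]  # residual self-attention
x = x + mlp(ln_2(x))                                         # residual feed-forward
print(x.shape)  # torch.Size([5, 2, 64])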
class BridgeTowerVisionEmbeddings(nn.Module):
def __init__(self, config: BridgeTowerVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
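# Shape bookkeeping for the vision embeddings above, as a small sketch. The concrete numbers are only
# assumptions for illustration (image_size=288 and patch_size=16 are not read from the config here).
image_size, patch_size, batch_size = 288, 16, 2
num_patches = (image_size // patch_size) ** 2  # 18 * 18 = 324 patch tokens per image
num_positions = num_patches + 1                # +1 for the prepended class token
# pixel_values (batch, channels, H, W) -> Conv2d(stride=patch_size) -> (batch, hidden, 18, 18)
# -> flatten(2).transpose(1, 2) -> (batch, 324, hidden) -> cat class token -> (batch, 325, hidden),
# which matches the number of rows in the position embedding table.
print(num_patches, num_positions)  # 324 325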
class BridgeTowerVisionTransformer(nn.Module):
def __init__(self, config):
super().__init__()
self.embeddings = BridgeTowerVisionEmbeddings(config)
self.ln_pre = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.transformer = BridgeTowerTransformer(config)
self.ln_post = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.share_layernorm = config.share_layernorm
if not config.share_layernorm:
self.ln_separate = nn.ModuleList(
[nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) for _ in range(config.num_hidden_layers)]
)
def forward(self, pixel_values: torch.Tensor, attention_mask):
hidden_states = self.embeddings(pixel_values)
hidden_states = self.ln_pre(hidden_states)
hidden_states = hidden_states.permute(1, 0, 2)
hidden_states = self.transformer(hidden_states, attention_mask)
hidden_states = torch.stack(hidden_states, dim=0)
hidden_states = hidden_states.permute(0, 2, 1, 3)
if self.share_layernorm:
hidden_states = self.ln_post(hidden_states)
else:
hidden_states_stack = []
for hidden_states, ln in zip(hidden_states, self.ln_separate):
hidden_states = ln(hidden_states)
hidden_states_stack.append(hidden_states)
hidden_states = torch.stack(hidden_states_stack, dim=0)
return hidden_states
def forward_pre(self, pixel_values: torch.Tensor):
hidden_states = self.embeddings(pixel_values)
hidden_states = self.ln_pre(hidden_states)
hidden_states = hidden_states.permute(1, 0, 2)
return hidden_states
def forward_post(self, hidden_state: torch.Tensor):
visual_output_post = hidden_state.permute(1, 0, 2)
visual_output_post = self.ln_post(visual_output_post)
return visual_output_post
class BridgeTowerLinkTower(nn.Module):
def __init__(self, config):
super().__init__()
self.link_tower_type = config.link_tower_type
self.hidden_size = config.hidden_size
if config.link_tower_type in ["add", "scaled_add", "interpolate"]:
if config.link_tower_type == "scaled_add":
self.scaled_factor = nn.Parameter(torch.tensor(1.0))
elif config.link_tower_type == "interpolate":
self.beta = nn.Parameter(torch.tensor(0.5))
self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
else:
raise NotImplementedError(f"link_tower_type {config.link_tower_type} is not implemented")
def forward(self, hidden_states, cross_modal_hidden_states, attention_mask):
if self.link_tower_type == "add":
return self.LayerNorm(hidden_states + cross_modal_hidden_states)
elif self.link_tower_type == "scaled_add":
return self.LayerNorm(hidden_states * self.scaled_factor + cross_modal_hidden_states)
elif self.link_tower_type == "interpolate":
return self.LayerNorm(hidden_states * (1 - self.beta) + cross_modal_hidden_states * self.beta)
else:
raise NotImplementedError(f"link_tower_type {self.link_tower_type} is not implemented")
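# Minimal sketch of the three fusion rules implemented by BridgeTowerLinkTower, on toy tensors.
# The trailing LayerNorm of the module is omitted here for clarity.
import torch

uni_modal = torch.tensor([1.0, 2.0])     # hidden_states coming from the uni-modal encoder layer
cross_modal = torch.tensor([3.0, 4.0])   # cross_modal_hidden_states coming from the fusion layer
scaled_factor, beta = torch.tensor(1.0), torch.tensor(0.5)

add = uni_modal + cross_modal                              # "add"
scaled_add = uni_modal * scaled_factor + cross_modal       # "scaled_add": learned scale on the uni-modal branch
interpolate = uni_modal * (1 - beta) + cross_modal * beta  # "interpolate": learned mixing weight
print(add, scaled_add, interpolate)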
class BridgeTowerSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BridgeTowerIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BridgeTowerOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BridgeTowerPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BridgeTowerSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
pass
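# Shape walkthrough for transpose_for_scores above, with hypothetical sizes:
# (batch, seq_len, all_head_size) -> (batch, seq_len, num_heads, head_dim) -> (batch, num_heads, seq_len, head_dim).
import torch

batch, seq_len, num_heads, head_dim = 2, 7, 12, 64
x = torch.randn(batch, seq_len, num_heads * head_dim)
x = x.view(batch, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)
print(x.shape)  # torch.Size([2, 12, 7, 64])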
class BridgeTowerAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = BridgeTowerSelfAttention(config, position_embedding_type=position_embedding_type)
self.output = BridgeTowerSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
class BridgeTowerBertCrossLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = BridgeTowerAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
self.crossattention = BridgeTowerAttention(config)
self.intermediate = BridgeTowerIntermediate(config)
self.output = BridgeTowerOutput(config)
def forward(
self,
hidden_states,
encoder_hidden_states,
attention_mask=None,
head_mask=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
# Self-attention over the current modality first, then cross-attention to the other modality's hidden states.
self_attention_outputs = self.attention(
hidden_states,
attention_mask=attention_mask,
head_mask=None,
output_attentions=output_attentions,
past_key_value=None,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:]
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1]
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
return outputs
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
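# Sketch of what apply_chunking_to_forward does in the layer above: when chunk_size_feed_forward > 0,
# the sequence dimension (seq_len_dim=1) is split into chunks, feed_forward_chunk is applied per chunk,
# and the results are concatenated again, trading compute for a smaller peak activation memory.
import torch

def feed_forward(x):  # stand-in for intermediate -> output
    return x * 2.0

hidden = torch.randn(2, 8, 4)  # (batch, seq_len, hidden)
chunk_size, seq_len_dim = 4, 1
chunks = hidden.split(chunk_size, dim=seq_len_dim)
chunked = torch.cat([feed_forward(c) for c in chunks], dim=seq_len_dim)
assert torch.allclose(chunked, feed_forward(hidden))  # same result as the un-chunked call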
class BridgeTowerTextLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = BridgeTowerAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = BridgeTowerAttention(config, position_embedding_type="absolute")
self.intermediate = BridgeTowerIntermediate(config)
self.output = BridgeTowerOutput(config)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0]
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:]
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
" by setting `config.add_cross_attention=True`"
)
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1]
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class BridgeTowerTextEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([BridgeTowerTextLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
class BridgeTowerTextEmbeddings(nn.Module):
"""
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
"""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
self.padding_idx = config.pad_token_id
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
)
def forward(
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
if position_ids is None:
if input_ids is not None:
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
else:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if token_type_ids is None:
if hasattr(self, "token_type_ids"):
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
"""
We are provided embeddings directly; we cannot infer which tokens are padding, so just generate sequential position ids.
Args:
inputs_embeds: torch.Tensor
Returns: torch.Tensor
"""
input_shape = inputs_embeds.size()[:-1]
sequence_length = input_shape[1]
position_ids = torch.arange(
self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
)
return position_ids.unsqueeze(0).expand(input_shape)
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.
Args:
input_ids: torch.Tensor, input tensor containing symbol indices
padding_idx: int, padding symbol index
past_key_values_length: int, optional, length of past key values
Returns:
torch.Tensor, tensor containing position indices
"""
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
return incremental_indices.long() + padding_idx
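# Worked example for create_position_ids_from_input_ids: padding tokens keep position padding_idx,
# real tokens are numbered from padding_idx + 1 onwards (RoBERTa-style), here with past_key_values_length=0.
import torch

padding_idx = 1
input_ids = torch.tensor([[0, 5, 6, 2, 1, 1]])  # the last two tokens are padding
mask = input_ids.ne(padding_idx).int()          # [[1, 1, 1, 1, 0, 0]]
position_ids = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
print(position_ids)  # tensor([[2, 3, 4, 5, 1, 1]])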
class BridgeTowerPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = BridgeTowerConfig
base_model_prefix = "bridgetower"
supports_gradient_checkpointing = False
_no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
_skip_keys_device_placement = "past_key_values"
def _init_weights(self, module):
"""
Initialize weights of the given module based on its type.
Args:
module: nn.Module, module to initialize weights for
"""
if isinstance(module, BridgeTowerVisionModel):
proj_std = (module.visual.transformer.hidden_size**-0.5) * (
(2 * module.visual.transformer.num_hidden_layers) ** -0.5
)
attn_std = module.visual.transformer.hidden_size**-0.5
fc_std = (2 * module.visual.transformer.hidden_size) ** -0.5
for block in module.visual.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std * self.config.initializer_factor)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std * self.config.initializer_factor)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std * self.config.initializer_factor)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std * self.config.initializer_factor)
nn.init.normal_(module.visual.embeddings.class_embedding, std=attn_std * self.config.initializer_factor)
nn.init.normal_(
module.visual.embeddings.position_embedding.weight, std=attn_std * self.config.initializer_factor
)
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.05 * self.config.initializer_factor)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
class BridgeTowerVisionModel(BridgeTowerPreTrainedModel):
"""
Vision model class inheriting from BridgeTowerPreTrainedModel.
Attributes:
config_class: Class attribute specifying the configuration class for this model.
"""
config_class = BridgeTowerVisionConfig
def __init__(self, config):
"""
Initialize the vision model with the given configuration.
Args:
config: BridgeTowerVisionConfig, configuration instance for the model
"""
super().__init__(config)
self.visual = BridgeTowerVisionTransformer(config)
@property
def dtype(self):
return self.visual.embeddings.patch_embedding.weight.dtype
def forward(self, image, image_mask=None):
return self.visual(image.type(self.dtype), image_mask)
class BridgeTowerTextModel(BridgeTowerPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in *Attention is
all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
Kaiser and Illia Polosukhin.
To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
.. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
"""
config_class = BridgeTowerTextConfig
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
self.embeddings = BridgeTowerTextEmbeddings(config)
self.encoder = BridgeTowerTextEncoder(config)
self.pooler = BridgeTowerPooler(config) if add_pooling_layer else None
self.post_init()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
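# Hypothetical usage sketch for _prune_heads above: the mapping is {layer index: [head indices to prune]}.
# BridgeTowerTextModel is not exported at the top level of transformers, so the module import path below
# is an assumption; BridgeTowerTextConfig() uses its default sizes.
from transformers import BridgeTowerTextConfig
from transformers.models.bridgetower.modeling_bridgetower import BridgeTowerTextModel

text_model = BridgeTowerTextModel(BridgeTowerTextConfig())
text_model._prune_heads({0: [0, 2], 5: [7]})  # prune heads 0 and 2 in layer 0, head 7 in layer 5
print(text_model.encoder.layer[0].attention.self.num_attention_heads)  # two fewer heads than before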
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
"""
This method defines the forward pass for the BridgeTowerTextModel.
Args:
input_ids (Optional[torch.Tensor]): Indices of input tokens in the vocabulary.
attention_mask (Optional[torch.Tensor]): Mask to avoid performing attention on padding tokens.
token_type_ids (Optional[torch.Tensor]): Segment token indices to differentiate sentences.
position_ids (Optional[torch.Tensor]): Indices of positions of each input token in the sequence.
head_mask (Optional[torch.Tensor]): Mask to nullify selected heads of the self-attention modules.
inputs_embeds (Optional[torch.Tensor]): Optional tensor of embeddings to be used as input instead of
input_ids.
encoder_hidden_states (Optional[torch.Tensor]): Sequence of hidden states of the encoder.
encoder_attention_mask (Optional[torch.Tensor]): Mask to avoid performing attention on encoder padding tokens.
past_key_values (Optional[List[torch.FloatTensor]]): Cached outputs of the model to speed up sequential
decoding.
use_cache (Optional[bool]): Whether or not to use past_key_values to speed up decoding.
output_attentions (Optional[bool]): Whether to return attentions weights.
output_hidden_states (Optional[bool]): Whether to return hidden states.
return_dict (Optional[bool]): Whether to return a dict instead of a tuple.
Returns:
Various outputs depending on the configuration (return_dict or not).
"""
pass
class BridgeTowerModel(BridgeTowerPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config
vision_config = config.vision_config
text_config = config.text_config
if config.share_cross_modal_transformer_layers:
self.cross_modal_text_transform = nn.Linear(text_config.hidden_size, config.hidden_size)
self.cross_modal_image_transform = nn.Linear(vision_config.hidden_size, config.hidden_size)
else:
self.cross_modal_text_transform = nn.ModuleList(
[nn.Linear(text_config.hidden_size, config.hidden_size) for _ in range(config.num_hidden_layers)]
)
self.cross_modal_image_transform = nn.ModuleList(
[nn.Linear(vision_config.hidden_size, config.hidden_size) for _ in range(config.num_hidden_layers)]
)
self.token_type_embeddings = nn.Embedding(2, config.hidden_size)
self.vision_model = BridgeTowerVisionModel(vision_config)
self.text_model = BridgeTowerTextModel(text_config)
if not vision_config.share_layernorm and config.init_layernorm_from_vision_encoder:
for ln in self.vision_model.visual.cross_modal_ln_separate:
ln.weight.data = self.vision_model.visual.ln_post.weight.data
ln.bias.data = self.vision_model.visual.ln_post.bias.data
self.cross_modal_image_layers = nn.ModuleList(
[BridgeTowerBertCrossLayer(text_config) for _ in range(config.num_hidden_layers)]
)
self.cross_modal_text_layers = nn.ModuleList(
[BridgeTowerBertCrossLayer(text_config) for _ in range(config.num_hidden_layers)]
)
self.cross_modal_text_pooler = BridgeTowerPooler(config)
self.cross_modal_image_pooler = BridgeTowerPooler(config)
self.cross_modal_text_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.cross_modal_image_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
if config.share_link_tower_layers:
self.cross_modal_text_link_tower = BridgeTowerLinkTower(config)
self.cross_modal_image_link_tower = BridgeTowerLinkTower(config)
else:
self.cross_modal_text_link_tower = nn.ModuleList(
[BridgeTowerLinkTower(config) for _ in range(config.num_hidden_layers - 1)]
)
self.cross_modal_image_link_tower = nn.ModuleList(
[BridgeTowerLinkTower(config) for _ in range(config.num_hidden_layers - 1)]
)
self.post_init()
def get_input_embeddings(self):
return self.text_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.text_model.set_input_embeddings(value)
@add_start_docstrings_to_model_forward(BRIDGETOWER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BridgeTowerModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
image_embeds: Optional[torch.FloatTensor] = None,
image_token_type_idx: Optional[int] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
):
pass
def get_cls_features(self, text_features, image_features):
cls_features_text = self.cross_modal_text_pooler(text_features)
cls_features_image = self.cross_modal_image_pooler(image_features)
return torch.cat([cls_features_text, cls_features_image], dim=-1)
class BridgeTowerPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class BridgeTowerMLMHead(nn.Module):
def __init__(self, config, weight=None):
super().__init__()
self.config = config
self.transform = BridgeTowerPredictionHeadTransform(config)
self.decoder = nn.Linear(config.hidden_size, config.text_config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.text_config.vocab_size))
if weight is not None:
self.decoder.weight = weight
def forward(self, x):
mlm_score = self.transform(x)
mlm_score = self.decoder(mlm_score) + self.bias
return mlm_score
class BridgeTowerITMHead(nn.Module):
def __init__(self, hidden_size):
super().__init__()
self.fc = nn.Linear(hidden_size, 2)
def forward(self, x):
itm_score = self.fc(x)
return itm_score
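# Sketch of how the ITM head consumes get_cls_features above: the pooled text and image [CLS] features
# (hidden_size each) are concatenated, so the head input is hidden_size * 2 and the output is 2 logits
# (match / no-match). Sizes are hypothetical.
import torch
import torch.nn as nn

batch, hidden_size = 2, 768
cls_features = torch.cat([torch.randn(batch, hidden_size), torch.randn(batch, hidden_size)], dim=-1)
itm_head = nn.Linear(hidden_size * 2, 2)
print(itm_head(cls_features).shape)  # torch.Size([2, 2])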
@add_start_docstrings(
"""
BridgeTower Model with a language modeling head on top as done during pretraining.
""",
BRIDGETOWER_START_DOCSTRING,
)
class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel):
_tied_weights_keys = ["mlm_score.decoder.weight"]
def __init__(self, config):
super().__init__(config)
self.bridgetower = BridgeTowerModel(config)
self.mlm_score = BridgeTowerMLMHead(config)
self.post_init()
def get_output_embeddings(self):
return self.mlm_score.decoder
def set_output_embeddings(self, new_embeddings):
self.mlm_score.decoder = new_embeddings
@add_start_docstrings_to_model_forward(BRIDGETOWER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
image_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.bridgetower(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
pixel_values=pixel_values,
pixel_mask=pixel_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
image_embeds=image_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
mlm_logits = self.mlm_score(outputs.text_features if return_dict else outputs[0])
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
labels = labels.to(mlm_logits.device)
masked_lm_loss = loss_fct(mlm_logits.view(-1, self.config.text_config.vocab_size), labels.view(-1))
if not return_dict:
output = tuple(mlm_logits)
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
loss=masked_lm_loss,
logits=mlm_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
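# Hedged end-to-end sketch for BridgeTowerForMaskedLM. It assumes Hub access; the checkpoint name is taken
# from BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST above, and the image URL is only an example.
import requests
import torch
from PIL import Image
from transformers import BridgeTowerProcessor, BridgeTowerForMaskedLM

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
text = "a <mask> looking out of the window"

processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
model = BridgeTowerForMaskedLM.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")

inputs = processor(image, text, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
masked_index = (inputs.input_ids == processor.tokenizer.mask_token_id).nonzero()[0, 1]
predicted_id = logits[0, masked_index].argmax(dim=-1).item()
print(processor.decode([predicted_id]))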
@add_start_docstrings(
"""
BridgeTower Model transformer with a classifier head on top (a linear layer on top of the final hidden state of the
[CLS] token) for image-to-text matching.
""",
BRIDGETOWER_START_DOCSTRING,
)
class BridgeTowerForImageAndTextRetrieval(BridgeTowerPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.bridgetower = BridgeTowerModel(config)
self.itm_score = BridgeTowerITMHead(config.hidden_size * 2)
self.post_init()
@add_start_docstrings_to_model_forward(BRIDGETOWER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
image_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.bridgetower(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
pixel_values=pixel_values,
pixel_mask=pixel_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
image_embeds=image_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooler_output = outputs.pooler_output if return_dict else outputs[2]
logits = self.itm_score(pooler_output)
itm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
labels = labels.to(logits.device)
itm_loss = loss_fct(logits, labels)
if not return_dict:
output = tuple(logits)
return ((itm_loss,) + output) if itm_loss is not None else output
return SequenceClassifierOutput(
loss=itm_loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
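# Hedged usage sketch for BridgeTowerForImageAndTextRetrieval: the logits have shape (batch, 2); reading
# logits[:, 1] as the image-text match score is an assumption of this sketch.
import requests
import torch
from PIL import Image
from transformers import BridgeTowerProcessor, BridgeTowerForImageAndTextRetrieval

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
model = BridgeTowerForImageAndTextRetrieval.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")

scores = {}
for text in ["two cats sleeping on a couch", "a plane taking off"]:
    inputs = processor(image, text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    scores[text] = outputs.logits[0, 1].item()
print(scores)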
class BridgeTowerContrastiveHead(nn.Module):
def __init__(self, hidden_size, embed_size):
super().__init__()
self.fc = nn.Linear(hidden_size, embed_size)
def forward(self, x):
x = self.fc(x)
return x
@add_start_docstrings(
"""
BridgeTower Model with a image-text contrastive head on top computing image-text contrastive loss.
""",
BRIDGETOWER_START_DOCSTRING,
)
class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.bridgetower = BridgeTowerModel(config)
self.itc_text_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size)
self.itc_image_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size)
self.itc_cross_modal_head = BridgeTowerContrastiveHead(config.hidden_size * 2, config.contrastive_hidden_size)
self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
self.post_init()
@add_start_docstrings_to_model_forward(BRIDGETOWER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BridgeTowerContrastiveOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
image_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = True,
return_dict: Optional[bool] = None,
return_loss: Optional[bool] = None,
):
pass
.\models\bridgetower\processing_bridgetower.py
"""
Processor class for BridgeTower.
"""
from typing import List, Optional, Union
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType
class BridgeTowerProcessor(ProcessorMixin):
r"""
Constructs a BridgeTower processor which wraps a Roberta tokenizer and BridgeTower image processor into a single
processor.
[`BridgeTowerProcessor`] offers all the functionalities of [`BridgeTowerImageProcessor`] and
[`RobertaTokenizerFast`]. See the docstring of [`~BridgeTowerProcessor.__call__`] and
[`~BridgeTowerProcessor.decode`] for more information.
Args:
image_processor (`BridgeTowerImageProcessor`):
An instance of [`BridgeTowerImageProcessor`]. The image processor is a required input.
tokenizer (`RobertaTokenizerFast`):
An instance of [`RobertaTokenizerFast`]. The tokenizer is a required input.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "BridgeTowerImageProcessor"
tokenizer_class = ("RobertaTokenizer", "RobertaTokenizerFast")
def __init__(self, image_processor, tokenizer):
super().__init__(image_processor, tokenizer)
def __call__(
self,
images,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
) -> BatchEncoding:
"""
Process images and optionally text into model input.
Args:
images: Input images to be processed.
text: Optional text input, can be either TextInput or PreTokenizedInput format.
add_special_tokens: Whether to add special tokens (like [CLS], [SEP]) to the inputs.
padding: Padding strategy. Can be a bool, str, or PaddingStrategy enum.
truncation: Truncation strategy. Can be a bool, str, or TruncationStrategy enum.
max_length: Maximum length of the returned sequences.
stride: Stride to use when overflowing tokens.
pad_to_multiple_of: Pad to a multiple of specified value.
return_token_type_ids: Whether to return token type ids.
return_attention_mask: Whether to return attention mask.
return_overflowing_tokens: Whether to return overflowing tokens.
return_special_tokens_mask: Whether to return special tokens mask.
return_offsets_mapping: Whether to return offsets mapping.
return_length: Whether to return the lengths of processed inputs.
verbose: Whether to output detailed logs during processing.
return_tensors: Return tensors format (e.g., "pt" for PyTorch tensors).
**kwargs: Additional keyword arguments for processing.
Returns:
BatchEncoding: Processed inputs formatted as BatchEncoding.
Notes:
This method processes images and optionally text into a format suitable for model input,
handling tokenization, padding, truncation, and special token additions as specified. Images are
prepared with [`BridgeTowerImageProcessor.__call__`] and text with [`RobertaTokenizerFast.__call__`];
please refer to the docstrings of those two methods for more information.
"""
encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
**kwargs,
)
encoding_image_processor = self.image_processor(
images, return_tensors=return_tensors, do_normalize=True, do_center_crop=True, **kwargs
)
encoding.update(encoding_image_processor)
return encoding
def batch_decode(self, *args, **kwargs):
"""
Forward all arguments to RobertaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`] method.
Please refer to the docstring of that method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
Forward all arguments to RobertaTokenizerFast's [`~PreTrainedTokenizer.decode`] method.
Please refer to the docstring of that method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
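# Hedged usage sketch for BridgeTowerProcessor: a single call tokenizes the text and preprocesses the image,
# and both encodings are merged into one BatchEncoding. The exact key set printed below is an assumption.
import requests
from PIL import Image
from transformers import BridgeTowerProcessor

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base")
encoding = processor(image, "two cats sleeping on a couch", padding=True, return_tensors="pt")
print(sorted(encoding.keys()))  # e.g. ['attention_mask', 'input_ids', 'pixel_mask', 'pixel_values']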
.\models\bridgetower\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {
"configuration_bridgetower": [
"BRIDGETOWER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"BridgeTowerConfig",
"BridgeTowerTextConfig",
"BridgeTowerVisionConfig",
],
"processing_bridgetower": ["BridgeTowerProcessor"],
}
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_bridgetower"] = ["BridgeTowerImageProcessor"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_bridgetower"] = [
"BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST",
"BridgeTowerForContrastiveLearning",
"BridgeTowerForImageAndTextRetrieval",
"BridgeTowerForMaskedLM",
"BridgeTowerModel",
"BridgeTowerPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_bridgetower import (
BRIDGETOWER_PRETRAINED_CONFIG_ARCHIVE_MAP,
BridgeTowerConfig,
BridgeTowerTextConfig,
BridgeTowerVisionConfig,
)
from .processing_bridgetower import BridgeTowerProcessor
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_bridgetower import BridgeTowerImageProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_bridgetower import (
BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST,
BridgeTowerForContrastiveLearning,
BridgeTowerForImageAndTextRetrieval,
BridgeTowerForMaskedLM,
BridgeTowerModel,
BridgeTowerPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
.\models\bros\configuration_bros.py
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
BROS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"jinho8345/bros-base-uncased": "https://huggingface.co/jinho8345/bros-base-uncased/blob/main/config.json",
"jinho8345/bros-large-uncased": "https://huggingface.co/jinho8345/bros-large-uncased/blob/main/config.json",
}
class BrosConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`BrosModel`] or a [`TFBrosModel`]. It is used to
instantiate a Bros model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the Bros
[jinho8345/bros-base-uncased](https://huggingface.co/jinho8345/bros-base-uncased) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
"""
Args:
vocab_size (`int`, *optional*, defaults to 30522):
Bros 模型的词汇表大小,定义了在调用 `BrosModel` 或 `TFBrosModel` 时可以表示的不同 token 数量。
hidden_size (`int`, *optional*, defaults to 768):
编码器层和池化层的维度大小。
num_hidden_layers (`int`, *optional*, defaults to 12):
Transformer 编码器中的隐藏层数量。
num_attention_heads (`int`, *optional*, defaults to 12):
Transformer 编码器中每个注意力层的注意力头数量。
intermediate_size (`int`, *optional*, defaults to 3072):
Transformer 编码器中“中间层”(通常称为前馈层)的维度大小。
hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
编码器和池化器中的非线性激活函数(函数或字符串)。支持的字符串有 `"gelu"`, `"relu"`, `"silu"` 和 `"gelu_new"`。
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
嵌入层、编码器和池化器中所有全连接层的 dropout 概率。
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
注意力概率的 dropout 比率。
max_position_embeddings (`int`, *optional*, defaults to 512):
此模型可能使用的最大序列长度。通常设置为较大的值(例如 512、1024 或 2048)以防万一。
type_vocab_size (`int`, *optional*, defaults to 2):
在调用 `BrosModel` 或 `TFBrosModel` 时传递的 `token_type_ids` 的词汇表大小。
initializer_range (`float`, *optional*, defaults to 0.02):
用于初始化所有权重矩阵的截断正态初始化器的标准差。
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
层归一化层使用的 epsilon 值。
pad_token_id (`int`, *optional*, defaults to 0):
词汇表中填充 token 的索引。
dim_bbox (`int`, *optional*, defaults to 8):
边界框坐标的维度大小。 (x0, y1, x1, y0, x1, y1, x0, y1)
bbox_scale (`float`, *optional*, defaults to 100.0):
边界框坐标的缩放因子。
n_relations (`int`, *optional*, defaults to 1):
SpadeEE(实体提取)、SpadeEL(实体链接)头部的关系数量。
classifier_dropout_prob (`float`, *optional*, defaults to 0.1):
分类器头部的 dropout 比率。
Examples:
```
>>> from transformers import BrosConfig, BrosModel
>>>
>>> configuration = BrosConfig()
configuration = BrosConfig()
>>>
>>> model = BrosModel(configuration)
model = BrosModel(configuration)
>>>
>>> configuration = model.config
configuration = model.config
model_type = "bros"
def __init__(
self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
dim_bbox=8,
bbox_scale=100.0,
n_relations=1,
classifier_dropout_prob=0.1,
**kwargs,
):
model_type = "bros"
def __init__(
self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
dim_bbox=8,
bbox_scale=100.0,
n_relations=1,
classifier_dropout_prob=0.1,
**kwargs,
):
super().__init__(
vocab_size=vocab_size,
hidden_size=hidden_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
hidden_act=hidden_act,
hidden_dropout_prob=hidden_dropout_prob,
attention_probs_dropout_prob=attention_probs_dropout_prob,
max_position_embeddings=max_position_embeddings,
type_vocab_size=type_vocab_size,
initializer_range=initializer_range,
layer_norm_eps=layer_norm_eps,
pad_token_id=pad_token_id,
**kwargs,
)
self.dim_bbox = dim_bbox
self.bbox_scale = bbox_scale
self.n_relations = n_relations
self.dim_bbox_sinusoid_emb_2d = self.hidden_size // 4
self.dim_bbox_sinusoid_emb_1d = self.dim_bbox_sinusoid_emb_2d // self.dim_bbox
self.dim_bbox_projection = self.hidden_size // self.num_attention_heads
self.classifier_dropout_prob = classifier_dropout_prob
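# Worked example of the derived bounding-box dimensions above, using the default values
# hidden_size=768, num_attention_heads=12, dim_bbox=8:
hidden_size, num_attention_heads, dim_bbox = 768, 12, 8
dim_bbox_sinusoid_emb_2d = hidden_size // 4                       # 192
dim_bbox_sinusoid_emb_1d = dim_bbox_sinusoid_emb_2d // dim_bbox   # 24
dim_bbox_projection = hidden_size // num_attention_heads          # 64
print(dim_bbox_sinusoid_emb_2d, dim_bbox_sinusoid_emb_1d, dim_bbox_projection)  # 192 24 64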
.\models\bros\convert_bros_to_pytorch.py
"""将 Bros 检查点转换为 HuggingFace 模型格式"""
import argparse
import bros
import torch
from transformers import BrosConfig, BrosModel, BrosProcessor
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_configs(model_name):
"""获取指定模型的配置信息"""
bros_config = BrosConfig.from_pretrained(model_name)
return bros_config
def remove_ignore_keys_(state_dict):
"""移除指定的忽略键(如果存在)"""
ignore_keys = [
"embeddings.bbox_sinusoid_emb.inv_freq",
]
for k in ignore_keys:
state_dict.pop(k, None)
def rename_key(name):
"""根据约定重命名给定的键"""
if name == "embeddings.bbox_projection.weight":
name = "bbox_embeddings.bbox_projection.weight"
if name == "embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq":
name = "bbox_embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq"
if name == "embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq":
name = "bbox_embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq"
return name
def convert_state_dict(orig_state_dict, model):
"""将原始模型状态字典转换为适用于 HuggingFace 模型的格式"""
for key in orig_state_dict.copy().keys():
val = orig_state_dict.pop(key)
orig_state_dict[rename_key(key)] = val
remove_ignore_keys_(orig_state_dict)
return orig_state_dict
def convert_bros_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
"""将 Bros 模型检查点转换为 HuggingFace 模型格式"""
original_model = bros.BrosModel.from_pretrained(model_name).eval()
bros_config = get_configs(model_name)
model = BrosModel.from_pretrained(model_name, config=bros_config)
model.eval()
state_dict = original_model.state_dict()
new_state_dict = convert_state_dict(state_dict, model)
model.load_state_dict(new_state_dict)
bbox = torch.tensor(
[
[
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.4396, 0.6720, 0.4659, 0.6720, 0.4659, 0.6850, 0.4396, 0.6850],
[0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
[0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
[0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
[0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
]
]
)
processor = BrosProcessor.from_pretrained(model_name)
encoding = processor("His name is Rocco.", return_tensors="pt")
encoding["bbox"] = bbox
original_hidden_states = original_model(**encoding).last_hidden_state
last_hidden_states = model(**encoding).last_hidden_state
assert torch.allclose(original_hidden_states, last_hidden_states, atol=1e-4)
if pytorch_dump_folder_path is not None:
print(f"Saving model and processor to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
model.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
processor.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
default="jinho8345/bros-base-uncased",
required=False,
type=str,
help="Name of the original model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
required=False,
type=str,
help="Path to the output PyTorch model directory.",
)
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Whether or not to push the converted model and processor to the 🤗 hub.",
)
args = parser.parse_args()
convert_bros_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
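# Hedged usage sketch: the conversion can also be driven from Python rather than the CLI above
# (it requires the original `bros` package and access to the checkpoint; the output path is hypothetical).
# from convert_bros_to_pytorch import convert_bros_checkpoint
# convert_bros_checkpoint(
#     model_name="jinho8345/bros-base-uncased",
#     pytorch_dump_folder_path="./bros-base-uncased-hf",
#     push_to_hub=False,
# )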
.\models\bros\modeling_bros.py
""" PyTorch Bros model."""
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_bros import BrosConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "jinho8345/bros-base-uncased"
_CONFIG_FOR_DOC = "BrosConfig"
BROS_PRETRAINED_MODEL_ARCHIVE_LIST = [
"jinho8345/bros-base-uncased",
"jinho8345/bros-large-uncased",
]
BROS_START_DOCSTRING = r"""
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`BrosConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
BROS_INPUTS_DOCSTRING = r"""
"""
@dataclass
class BrosSpadeOutput(ModelOutput):
"""
Base class for outputs of token classification models.
This class inherits from `ModelOutput` in Hugging Face's library and serves as a base for outputs
from token classification models specific to the Bros model.
Attributes:
Inherits attributes from `ModelOutput`.
"""
"""
# loss 表示分类损失,默认为 None
loss: Optional[torch.FloatTensor] = None
# initial_token_logits 表示实体初始标记的分类分数,默认为 None
initial_token_logits: torch.FloatTensor = None
# subsequent_token_logits 表示实体序列标记的分类分数,默认为 None
subsequent_token_logits: torch.FloatTensor = None
# hidden_states 表示模型每层的隐藏状态的元组,默认为 None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
# attentions 表示模型每层的注意力权重的元组,默认为 None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class BrosPositionalEmbedding1D(nn.Module):
# Reference: https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py#L15
# Module definition for the 1D sinusoidal positional encoding
def __init__(self, config):
super(BrosPositionalEmbedding1D, self).__init__()
# Constructor, takes the configuration object config
self.dim_bbox_sinusoid_emb_1d = config.dim_bbox_sinusoid_emb_1d
# Read the dimensionality of the 1D positional encoding from the config
# Compute the inverse frequencies of the sinusoids used for the encoding
inv_freq = 1 / (
10000 ** (torch.arange(0.0, self.dim_bbox_sinusoid_emb_1d, 2.0) / self.dim_bbox_sinusoid_emb_1d)
)
# Register the inverse frequencies as a buffer of the module
self.register_buffer("inv_freq", inv_freq)
def forward(self, pos_seq: torch.Tensor) -> torch.Tensor:
# Forward pass: takes a position sequence and returns the positional-encoding tensor
seq_size = pos_seq.size()
b1, b2, b3 = seq_size
# Unpack the size of the position sequence
sinusoid_inp = pos_seq.view(b1, b2, b3, 1) * self.inv_freq.view(1, 1, 1, self.dim_bbox_sinusoid_emb_1d // 2)
# Multiply the position sequence by the inverse frequencies, broadcasting to the right shape
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
# Concatenate the sine and cosine parts to obtain the final positional-encoding tensor
return pos_emb
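A minimal sketch of the 1D sinusoidal encoding above, using a hypothetical embedding dimension of 8 in place of `config.dim_bbox_sinusoid_emb_1d`:

```python
import torch

dim = 8  # stands in for config.dim_bbox_sinusoid_emb_1d (hypothetical value)
inv_freq = 1 / (10000 ** (torch.arange(0.0, dim, 2.0) / dim))           # (dim // 2,)

pos_seq = torch.rand(2, 3, 3)                                           # (b1, b2, b3), e.g. pairwise bbox offsets
sinusoid_inp = pos_seq.unsqueeze(-1) * inv_freq.view(1, 1, 1, dim // 2)
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)   # sin/cos features per position
print(pos_emb.shape)  # torch.Size([2, 3, 3, 8])
```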
class BrosPositionalEmbedding2D(nn.Module):
# Module definition for the 2D positional encoding
def __init__(self, config):
super(BrosPositionalEmbedding2D, self).__init__()
# Constructor, takes the configuration object config
self.dim_bbox = config.dim_bbox
# Read the number of bounding-box coordinates from the config
# Create 1D positional-encoding modules for the X and Y directions
self.x_pos_emb = BrosPositionalEmbedding1D(config)
self.y_pos_emb = BrosPositionalEmbedding1D(config)
def forward(self, bbox: torch.Tensor) -> torch.Tensor:
# Forward pass: takes the bounding-box tensor and returns its positional encoding
stack = []
# List that collects the per-coordinate positional encodings
for i in range(self.dim_bbox):
# Iterate over the bounding-box coordinates
if i % 2 == 0:
stack.append(self.x_pos_emb(bbox[..., i]))
# Even indices use the X-direction positional-encoding module
else:
stack.append(self.y_pos_emb(bbox[..., i]))
# Odd indices use the Y-direction positional-encoding module
bbox_pos_emb = torch.cat(stack, dim=-1)
# Concatenate all encodings into the final bounding-box positional-encoding tensor
return bbox_pos_emb
class BrosBboxEmbeddings(nn.Module):
# Module definition for the bounding-box embeddings
def __init__(self, config):
super(BrosBboxEmbeddings, self).__init__()
# Constructor, takes the configuration object config
self.bbox_sinusoid_emb = BrosPositionalEmbedding2D(config)
# 2D sinusoidal positional-encoding module
self.bbox_projection = nn.Linear(config.dim_bbox_sinusoid_emb_2d, config.dim_bbox_projection, bias=False)
# Linear layer that projects the 2D sinusoidal encoding down to the bbox projection dimension
def forward(self, bbox: torch.Tensor):
# Forward pass: takes the bounding-box tensor and returns the projected bbox embeddings
bbox_t = bbox.transpose(0, 1)
# Transpose so that the sequence dimension comes first
bbox_pos = bbox_t[None, :, :, :] - bbox_t[:, None, :, :]
# Pairwise coordinate differences between boxes, computed via broadcasting
bbox_pos_emb = self.bbox_sinusoid_emb(bbox_pos)
# Encode the relative positions with the 2D sinusoidal module
bbox_pos_emb = self.bbox_projection(bbox_pos_emb)
# Project the encoding with the linear layer
return bbox_pos_emb
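A minimal sketch of the pairwise bbox-difference trick used above, with hypothetical tensor sizes:

```python
import torch

batch_size, seq_len, dim_bbox = 2, 4, 8                    # hypothetical sizes
bbox = torch.rand(batch_size, seq_len, dim_bbox)

bbox_t = bbox.transpose(0, 1)                               # (seq_len, batch, dim_bbox)
bbox_pos = bbox_t[None, :, :, :] - bbox_t[:, None, :, :]    # (seq_len, seq_len, batch, dim_bbox)
print(bbox_pos.shape)  # torch.Size([4, 4, 2, 8])
# bbox_pos[i, j] holds the coordinate offsets between token j and token i,
# which BrosPositionalEmbedding2D then turns into sinusoidal features.
```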
class BrosTextEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
# Module definition for the text embeddings
# Constructor, takes the configuration object config
def __init__(self, config):
# Call the parent constructor
super().__init__()
# Word-embedding layer: maps token indices to word vectors
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
# Position-embedding layer: maps position indices to position vectors
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
# Token-type-embedding layer: maps token-type indices to token-type vectors
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# LayerNorm layer used to normalize the hidden states
# The attribute name is not snake_case so that it stays compatible with the TensorFlow variable names
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# Dropout layer applied during training
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
# Add absolute or relative position embeddings depending on the config
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
# Register a persistent buffer position_ids holding consecutive position ids
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
# Register a non-persistent buffer token_type_ids filled with zeros for every position
self.register_buffer(
"token_type_ids",
torch.zeros(
self.position_ids.size(),
dtype=torch.long,
device=self.position_ids.device,
),
persistent=False,
)
# Forward pass
def forward(
self,
input_ids: Optional[torch.Tensor] = None,  # input token indices
token_type_ids: Optional[torch.Tensor] = None,  # token-type ids
position_ids: Optional[torch.Tensor] = None,  # position ids
inputs_embeds: Optional[torch.Tensor] = None,  # pre-computed input embeddings
past_key_values_length: int = 0,  # length of the cached past key/values
) -> torch.Tensor:  # returns a tensor
# If input token indices are given, use their shape; otherwise use the shape of the input embeddings
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]  # sequence length
# If no position ids are given, use a contiguous range of positions
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
# If no token-type ids are given, derive them from the registered buffer or fall back to zeros
if token_type_ids is None:
if hasattr(self, "token_type_ids"):
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
# If no input embeddings are given, look up the word embeddings of the input token indices
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
# Look up the token-type embeddings
token_type_embeddings = self.token_type_embeddings(token_type_ids)
# Sum the word embeddings and the token-type embeddings
embeddings = inputs_embeds + token_type_embeddings
# With absolute position embeddings, compute and add the position embeddings as well
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
# Normalize the summed embeddings with LayerNorm
embeddings = self.LayerNorm(embeddings)
# Apply dropout
embeddings = self.dropout(embeddings)
# Return the final embeddings
return embeddings
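A minimal sketch of running this embedding module on dummy token ids, assuming `BrosTextEmbeddings` is importable from `transformers.models.bros.modeling_bros` and using the default `BrosConfig`:

```python
import torch
from transformers import BrosConfig
from transformers.models.bros.modeling_bros import BrosTextEmbeddings  # assumed import path

config = BrosConfig()                                      # BERT-like defaults
embeddings = BrosTextEmbeddings(config)
input_ids = torch.randint(0, config.vocab_size, (2, 5))   # (batch, seq_len)
out = embeddings(input_ids=input_ids)
print(out.shape)                                           # (2, 5, config.hidden_size)
```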
class BrosSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
# Make sure the hidden size is a multiple of the number of attention heads (unless an embedding_size attribute exists)
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
# Query, key and value projection layers
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
# Dropout layer
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
# With relative_key or relative_key_query position embeddings, create the distance-embedding table
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
# Reshape the projected tensor so that attention scores can be computed per head
def transpose_for_scores(self, x: torch.Tensor):
new_x_shape = x.size()[:-1] + (
self.num_attention_heads,
self.attention_head_size,
)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
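A minimal sketch of the reshape performed by `transpose_for_scores`, with hypothetical sizes:

```python
import torch

batch, seq_len, num_heads, head_size = 2, 5, 12, 64          # hypothetical sizes
x = torch.rand(batch, seq_len, num_heads * head_size)        # (batch, seq, all_head_size)

x = x.view(batch, seq_len, num_heads, head_size)              # split the combined head dimension
x = x.permute(0, 2, 1, 3)                                     # (batch, num_heads, seq, head_size)
print(x.shape)  # torch.Size([2, 12, 5, 64])
```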
def forward(
self,
hidden_states: torch.Tensor,
bbox_pos_emb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[torch.Tensor] = False,
):
# Forward pass implementing the self-attention mechanism (body omitted in this walkthrough)
pass  # the full implementation adds the bbox positional information on top of standard self-attention
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Bros
class BrosSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
# Fully connected layer
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
# LayerNorm layer
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# Dropout layer
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# Forward pass
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
# Linear projection
hidden_states = self.dense(hidden_states)
# Dropout
hidden_states = self.dropout(hidden_states)
# Residual connection followed by LayerNorm
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BrosAttention(nn.Module):
def __init__(self, config):
super().__init__()
# Create the self-attention and output sub-modules
self.self = BrosSelfAttention(config)
self.output = BrosSelfOutput(config)
self.pruned_heads = set()  # keeps track of the attention heads that have been pruned
# Prune attention heads of the self-attention module
def prune_heads(self, heads):
# If the list of heads is empty, there is nothing to do
if len(heads) == 0:
return
# Find the prunable heads and the corresponding parameter indices
heads, index = find_pruneable_heads_and_indices(
heads,
self.self.num_attention_heads,
self.self.attention_head_size,
self.pruned_heads,
)
# Prune the query projection
self.self.query = prune_linear_layer(self.self.query, index)
# Prune the key projection
self.self.key = prune_linear_layer(self.self.key, index)
# Prune the value projection
self.self.value = prune_linear_layer(self.self.value, index)
# Prune the output projection along dim=1 (its input features)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update the hyper-parameters and record the pruned heads
# Subtract the number of pruned heads from the head count
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
# Recompute the total size across all remaining heads
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
# Add the pruned heads to the pruned_heads set
self.pruned_heads = self.pruned_heads.union(heads)
# Forward pass of the attention block
def forward(
self,
hidden_states: torch.Tensor,
bbox_pos_emb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
# Run the self-attention sub-module
self_outputs = self.self(
hidden_states=hidden_states,
bbox_pos_emb=bbox_pos_emb,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
# Feed the attention output together with the residual hidden_states to the output sub-module
attention_output = self.output(self_outputs[0], hidden_states)
# Append the attention weights to the outputs if they were returned
outputs = (attention_output,) + self_outputs[1:]
return outputs
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Bros
class BrosIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
# Linear layer mapping the input features from config.hidden_size to config.intermediate_size
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
# Pick the activation from ACT2FN when config.hidden_act is a string, otherwise use the given callable directly
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# Apply the linear projection
hidden_states = self.dense(hidden_states)
# Apply the intermediate activation function
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
# BrosOutput projects the intermediate representation back to the hidden size
class BrosOutput(nn.Module):
def __init__(self, config):
super().__init__()
# Linear layer mapping config.intermediate_size back to config.hidden_size
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
# LayerNorm layer; eps is the numerical-stability term
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# Dropout layer with probability config.hidden_dropout_prob
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
# Apply the linear projection
hidden_states = self.dense(hidden_states)
# Apply dropout
hidden_states = self.dropout(hidden_states)
# Add the residual input_tensor and normalize with LayerNorm
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
# BrosLayer: a single Transformer layer of the Bros encoder
class BrosLayer(nn.Module):
def __init__(self, config):
super().__init__()
# Chunk size used for the feed-forward pass
self.chunk_size_feed_forward = config.chunk_size_feed_forward
# Index of the sequence-length dimension used when chunking
self.seq_len_dim = 1
# Attention block
self.attention = BrosAttention(config)
# Whether this layer is used as a decoder
self.is_decoder = config.is_decoder
# Whether cross-attention is added
self.add_cross_attention = config.add_cross_attention
# Cross-attention is only valid for decoder layers
if self.add_cross_attention:
if not self.is_decoder:
raise Exception(f"{self} should be used as a decoder model if cross attention is added")
# Cross-attention block
self.crossattention = BrosAttention(config)
# Intermediate (feed-forward) block
self.intermediate = BrosIntermediate(config)
# Output block
self.output = BrosOutput(config)
def forward(
self,
hidden_states: torch.Tensor,
bbox_pos_emb: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
# Forward pass; the arguments are:
# hidden_states: input hidden-state tensor
# bbox_pos_emb: bounding-box positional-embedding tensor
# attention_mask: optional attention mask
# head_mask: optional head mask
# encoder_hidden_states: optional encoder hidden states
# encoder_attention_mask: optional encoder attention mask
# past_key_value: optional tuple of cached key/value states
# output_attentions: whether to return the attention weights, defaults to False
) -> Tuple[torch.Tensor]:
# If cached past key/values exist, the first two entries are the self-attention cache
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# Run self-attention with the hidden states, bbox positional embeddings, masks, etc.
self_attention_outputs = self.attention(
hidden_states,
bbox_pos_emb=bbox_pos_emb,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
# Self-attention output
attention_output = self_attention_outputs[0]
# For a decoder, the last output is the tuple with the self-attention cache
if self.is_decoder:
# Keep everything except the first element (the output) and the last one (the cache)
outputs = self_attention_outputs[1:-1]
# The present key/value cache is the last element
present_key_value = self_attention_outputs[-1]
else:
# Otherwise just keep the self-attention weights
outputs = self_attention_outputs[1:]
# Cross-attention cache, initialized to None
cross_attn_present_key_value = None
# For a decoder with encoder hidden states, run cross-attention
if self.is_decoder and encoder_hidden_states is not None:
# Raise if the layer was not instantiated with cross-attention layers
if not hasattr(self, "crossattention"):
raise Exception(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
)
# If cached past key/values exist, the last two entries are the cross-attention cache
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
# Run cross-attention with the self-attention output, masks and the encoder hidden states
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
# Cross-attention output
attention_output = cross_attention_outputs[0]
# Append the cross-attention weights to the outputs
outputs = outputs + cross_attention_outputs[1:-1]
# Append the cross-attention cache to the present key/value cache
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
# Apply chunking to the feed-forward part of the layer
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output,
)
# Prepend the layer output to the outputs tuple
outputs = (layer_output,) + outputs
# For a decoder, return the key/value cache as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
# Return all outputs
return outputs
# Chunked feed-forward: takes the attention output and returns the layer output
def feed_forward_chunk(self, attention_output):
# Intermediate (feed-forward) projection
intermediate_output = self.intermediate(attention_output)
# Output projection combined with the residual attention output
layer_output = self.output(intermediate_output, attention_output)
# Return the layer output
return layer_output
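A minimal sketch of what `apply_chunking_to_forward` does to the feed-forward input: it splits the tensor along the given dimension, runs the function chunk by chunk, and concatenates the results, which matches a single full-tensor call. The sizes below are hypothetical:

```python
import torch
from transformers.pytorch_utils import apply_chunking_to_forward

def feed_forward(hidden_states):
    return hidden_states * 2.0          # stand-in for intermediate + output

hidden_states = torch.rand(2, 8, 4)                                     # (batch, seq_len, hidden)
out_chunked = apply_chunking_to_forward(feed_forward, 4, 1, hidden_states)  # chunk_size=4 along dim 1
out_full = feed_forward(hidden_states)
print(torch.allclose(out_chunked, out_full))  # True
```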
# BrosEncoder: a stack of BrosLayer modules
class BrosEncoder(nn.Module):
def __init__(self, config):
super().__init__()
# Store the configuration
self.config = config
# Create the list of BrosLayer modules; the number of layers comes from the config
self.layer = nn.ModuleList([BrosLayer(config) for _ in range(config.num_hidden_layers)])
def forward(
self,
hidden_states: torch.Tensor,
bbox_pos_emb: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
):
pass  # encoder forward body omitted in this walkthrough (it loops over self.layer and collects the outputs)
# BrosPooler pools the hidden states of the model
class BrosPooler(nn.Module):
def __init__(self, config):
super().__init__()
# Linear layer mapping the hidden size to the hidden size
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
# Tanh activation
self.activation = nn.Tanh()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# "Pool" the model by simply taking the hidden state of the first token
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BrosRelationExtractor(nn.Module):
def __init__(self, config):
super().__init__()
# Configuration of the relation extractor
self.n_relations = config.n_relations
self.backbone_hidden_size = config.hidden_size
self.head_hidden_size = config.hidden_size
self.classifier_dropout_prob = config.classifier_dropout_prob
# Dropout layer with the configured probability
self.drop = nn.Dropout(self.classifier_dropout_prob)
# Query projection: maps the backbone hidden size to one head per relation
self.query = nn.Linear(self.backbone_hidden_size, self.n_relations * self.head_hidden_size)
# Key projection: maps the backbone hidden size to one head per relation
self.key = nn.Linear(self.backbone_hidden_size, self.n_relations * self.head_hidden_size)
# Dummy node appended to the keys, created as a zero-initialized nn.Parameter
self.dummy_node = nn.Parameter(torch.zeros(1, self.backbone_hidden_size))
def forward(self, query_layer: torch.Tensor, key_layer: torch.Tensor):
# Apply dropout and the query projection
query_layer = self.query(self.drop(query_layer))
# Create the dummy vector and append it to the key sequence
dummy_vec = self.dummy_node.unsqueeze(0).repeat(1, key_layer.size(1), 1)
key_layer = torch.cat([key_layer, dummy_vec], axis=0)
# Apply dropout and the key projection
key_layer = self.key(self.drop(key_layer))
# Reshape queries and keys into one head per relation
query_layer = query_layer.view(
query_layer.size(0), query_layer.size(1), self.n_relations, self.head_hidden_size
)
key_layer = key_layer.view(key_layer.size(0), key_layer.size(1), self.n_relations, self.head_hidden_size)
# Compute the relation scores between queries and keys with a matrix multiplication
relation_score = torch.matmul(
query_layer.permute(2, 1, 0, 3), key_layer.permute(2, 1, 3, 0)
)  # equivalent to torch.einsum("ibnd,jbnd->nbij", (query_layer, key_layer))
return relation_score
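A minimal sketch of the relation-score computation and its einsum equivalence, with hypothetical sizes (the extra key position stands for the dummy "no link" node):

```python
import torch

seq_len, batch, n_relations, head_size = 4, 2, 1, 8          # hypothetical sizes
query = torch.rand(seq_len, batch, n_relations, head_size)
key = torch.rand(seq_len + 1, batch, n_relations, head_size)  # +1 for the dummy node

score = torch.matmul(query.permute(2, 1, 0, 3), key.permute(2, 1, 3, 0))
score_einsum = torch.einsum("ibnd,jbnd->nbij", query, key)
print(score.shape)                           # torch.Size([1, 2, 4, 5]) -> (n_relations, batch, seq, seq + 1)
print(torch.allclose(score, score_einsum))   # True
```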
class BrosPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
# Use BrosConfig as the configuration class
config_class = BrosConfig
# Prefix used for the base model
base_model_prefix = "bros"
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
# Linear layers: initialize the weights from a normal distribution
# Slightly different from the TF version, which uses a truncated normal distribution
# cf. https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
# Initialize the bias, if present, to zeros
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
# Embedding layers: initialize the weights from a normal distribution
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
# If a padding_idx is set, zero out the corresponding row
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
# LayerNorm layers: zero the bias and fill the weight with ones
module.bias.data.zero_()
module.weight.data.fill_(1.0)
@add_start_docstrings(
"The bare Bros Model transformer outputting raw hidden-states without any specific head on top.",
BROS_START_DOCSTRING,
)
class BrosModel(BrosPreTrainedModel):
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
# Initialize the BrosModel
self.config = config
# Text embeddings, bounding-box embeddings and the encoder
self.embeddings = BrosTextEmbeddings(config)
self.bbox_embeddings = BrosBboxEmbeddings(config)
self.encoder = BrosEncoder(config)
# Add a pooling layer if requested
self.pooler = BrosPooler(config) if add_pooling_layer else None
# Initialize the model weights
self.init_weights()
def get_input_embeddings(self):
# Return the word-embedding layer
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
# Set the word-embedding layer
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
# Prune the given attention heads of the model
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
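A minimal sketch of head pruning from the public API (`PreTrainedModel.prune_heads` dispatches to `_prune_heads` above); the config values are hypothetical and the model is randomly initialized:

```python
from transformers import BrosConfig, BrosModel

config = BrosConfig(num_hidden_layers=2)          # small, hypothetical configuration
model = BrosModel(config)
model.prune_heads({0: [0, 1], 1: [2]})            # remove heads 0 and 1 of layer 0, head 2 of layer 1
print(model.encoder.layer[0].attention.self.num_attention_heads)  # config.num_attention_heads - 2
```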
@add_start_docstrings_to_model_forward(BROS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC)
# Forward pass of the model
def forward(
self,
input_ids: Optional[torch.Tensor] = None,  # input token ids
bbox: Optional[torch.Tensor] = None,  # bounding-box coordinates of the tokens
attention_mask: Optional[torch.Tensor] = None,  # mask restricting which positions attention may attend to
token_type_ids: Optional[torch.Tensor] = None,  # token-type ids used to distinguish segments
position_ids: Optional[torch.Tensor] = None,  # position ids of the input tokens
head_mask: Optional[torch.Tensor] = None,  # mask used to disable selected attention heads
inputs_embeds: Optional[torch.Tensor] = None,  # pre-computed input embeddings instead of input_ids
encoder_hidden_states: Optional[torch.Tensor] = None,  # hidden states of an encoder
encoder_attention_mask: Optional[torch.Tensor] = None,  # attention mask of the encoder
past_key_values: Optional[List[torch.FloatTensor]] = None,  # cached key/value states from previous steps
use_cache: Optional[bool] = None,  # whether to return and use the key/value cache to speed up decoding
output_attentions: Optional[bool] = None,  # whether to return the attention weights
output_hidden_states: Optional[bool] = None,  # whether to return the hidden states
return_dict: Optional[bool] = None,  # whether to return a ModelOutput instead of a tuple
):
pass  # model forward body omitted in this walkthrough
# BrosForTokenClassification adds a token-classification head (a linear layer on top of the hidden states), e.g. for named-entity recognition (NER)
@add_start_docstrings(
"""
Bros Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
BROS_START_DOCSTRING,
)
class BrosForTokenClassification(BrosPreTrainedModel):
# Keys to ignore on load: an unexpected "pooler" key is simply not loaded
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
# Call the parent constructor with the configuration object
super().__init__(config)
# Number of labels of the classifier
self.num_labels = config.num_labels
# Backbone BrosModel
self.bros = BrosModel(config)
# Classifier dropout: use config.classifier_dropout if present, otherwise fall back to the hidden dropout probability
classifier_dropout = (
config.classifier_dropout if hasattr(config, "classifier_dropout") else config.hidden_dropout_prob
)
# Dropout layer used before the classifier
self.dropout = nn.Dropout(classifier_dropout)
# Linear layer mapping the hidden states to the label space
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize the model weights
self.init_weights()
# Document the forward arguments following BROS_INPUTS_DOCSTRING
@add_start_docstrings_to_model_forward(BROS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
# Replace the return-value docstring: the output type is TokenClassifierOutput, configured by _CONFIG_FOR_DOC
@replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
bbox: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
bbox_first_token_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
Token classification model's forward method.
Args:
input_ids (torch.Tensor): Input token IDs.
bbox (torch.Tensor): Bounding box coordinates for tokens.
attention_mask (torch.Tensor, optional): Mask for attention mechanism.
token_type_ids (torch.Tensor, optional): Type IDs for tokens.
position_ids (torch.Tensor, optional): Positional embeddings.
head_mask (torch.Tensor, optional): Mask for attention heads.
inputs_embeds (torch.Tensor, optional): Embedded inputs.
output_attentions (bool, optional): Whether to output attentions.
output_hidden_states (bool, optional): Whether to output hidden states.
return_dict (bool, optional): Whether to return as a dictionary.
Returns:
Union[Tuple[torch.Tensor], TokenClassifierOutput]: Model outputs.
Examples:
```
>>> import torch
>>> from transformers import BrosProcessor, BrosForTokenClassification
>>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")
>>> model = BrosForTokenClassification.from_pretrained("jinho8345/bros-base-uncased")
>>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
>>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
>>> encoding["bbox"] = bbox
>>> outputs = model(**encoding)
```
"""
# Determine whether to use the return dictionary format or not
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Pass inputs to the model's token classification method
outputs = self.bros(
input_ids,
bbox=bbox,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
if bbox_first_token_mask is not None:
bbox_first_token_mask = bbox_first_token_mask.view(-1)
loss = loss_fct(
logits.view(-1, self.num_labels)[bbox_first_token_mask], labels.view(-1)[bbox_first_token_mask]
)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
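A minimal sketch of the `bbox_first_token_mask` loss masking used above: only the first token of every bounding box contributes to the cross-entropy loss. All sizes and values below are hypothetical:

```python
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, num_labels = 1, 4, 3                            # hypothetical sizes
logits = torch.rand(batch, seq_len, num_labels)
labels = torch.tensor([[1, 0, 2, 0]])
bbox_first_token_mask = torch.tensor([[True, False, True, False]])  # tokens 0 and 2 start a box

mask = bbox_first_token_mask.view(-1)
loss = CrossEntropyLoss()(logits.view(-1, num_labels)[mask], labels.view(-1)[mask])
print(loss)  # loss computed only over the first token of each box
```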
@add_start_docstrings(
"""
Bros Model with a token classification head on top (an initial-token classifier plus a subsequent-token classifier on
top of the hidden-states output), e.g. for Named-Entity-Recognition (NER) tasks. The initial_token_classifier predicts
the first token of each entity, and the subsequent_token_classifier predicts the remaining tokens within an entity.
""",
BROS_START_DOCSTRING,
)
class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
super().__init__(config)
self.config = config
self.num_labels = config.num_labels
self.n_relations = config.n_relations
self.backbone_hidden_size = config.hidden_size
self.bros = BrosModel(config)
classifier_dropout = (
config.classifier_dropout if hasattr(config, "classifier_dropout") else config.hidden_dropout_prob
)
self.initial_token_classifier = nn.Sequential(
nn.Dropout(classifier_dropout),
nn.Linear(config.hidden_size, config.hidden_size),
nn.Dropout(classifier_dropout),
nn.Linear(config.hidden_size, config.num_labels),
)
# Subsequent-token classifier implemented as a relation extractor
self.subsequent_token_classifier = BrosRelationExtractor(config)
self.init_weights()
@add_start_docstrings_to_model_forward(BROS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=BrosSpadeOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
bbox: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
bbox_first_token_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
initial_token_labels: Optional[torch.Tensor] = None,
subsequent_token_labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
pass  # entity-extraction forward body omitted in this walkthrough
@add_start_docstrings(
"""
Bros Model with a token classification head on top (an entity_linker layer on top of the hidden-states output), e.g.
for Entity-Linking. The entity_linker is used to predict intra-entity links (one entity to another entity).
""",
BROS_START_DOCSTRING,
)
class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
super().__init__(config)
self.config = config
self.num_labels = config.num_labels
self.n_relations = config.n_relations
self.backbone_hidden_size = config.hidden_size
self.bros = BrosModel(config)
# Note: the classifier dropout value is computed here but not used by this head
(config.classifier_dropout if hasattr(config, "classifier_dropout") else config.hidden_dropout_prob)
self.entity_linker = BrosRelationExtractor(config)
self.init_weights()
@add_start_docstrings_to_model_forward(BROS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
bbox: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
bbox_first_token_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
The return type is either a tuple of torch.Tensor or a TokenClassifierOutput object.
Returns:
The model's prediction outputs.
Examples:
The example below shows how to run the model and obtain its outputs.
```
>>> import torch
>>> from transformers import BrosProcessor, BrosSpadeELForTokenClassification
>>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")
>>> model = BrosSpadeELForTokenClassification.from_pretrained("jinho8345/bros-base-uncased")
>>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
>>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
>>> encoding["bbox"] = bbox
>>> outputs = model(**encoding)
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.bros(
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_states = outputs[0]
last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()
logits = self.entity_linker(last_hidden_states, last_hidden_states).squeeze(0)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
batch_size, max_seq_length = attention_mask.shape
device = attention_mask.device
self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
mask = bbox_first_token_mask.view(-1)
bbox_first_token_mask = torch.cat(
[
~bbox_first_token_mask,
torch.zeros([batch_size, 1], dtype=torch.bool).to(device),
],
axis=1,
)
logits = logits.masked_fill(bbox_first_token_mask[:, None, :], torch.finfo(logits.dtype).min)
logits = logits.masked_fill(self_token_mask[None, :, :], torch.finfo(logits.dtype).min)
loss = loss_fct(logits.view(-1, max_seq_length + 1)[mask], labels.view(-1)[mask])
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
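A minimal sketch of the logits masking applied in the entity-linking loss above: the diagonal (a token linking to itself) and the columns of tokens that do not start a bounding box are pushed to the minimum float value, so they never win the argmax or contribute meaningfully to the cross-entropy. All sizes below are hypothetical:

```python
import torch

batch, seq_len = 1, 4                                            # hypothetical sizes
logits = torch.rand(batch, seq_len, seq_len + 1)                 # last column is the dummy "no link" node
bbox_first_token_mask = torch.tensor([[True, False, True, True]])

self_token_mask = torch.eye(seq_len, seq_len + 1).bool()         # masks token i -> token i links
key_mask = torch.cat([~bbox_first_token_mask, torch.zeros(batch, 1, dtype=torch.bool)], dim=1)

logits = logits.masked_fill(key_mask[:, None, :], torch.finfo(logits.dtype).min)
logits = logits.masked_fill(self_token_mask[None, :, :], torch.finfo(logits.dtype).min)
print(logits[0, 0])  # column 0 (self link) and column 1 (non-first token) are now at float min
```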