Transformers Source Code Walkthrough (Part 86)
.\models\owlvit\processing_owlvit.py
"""
Image/Text processor class for OWL-ViT
"""
import warnings
from typing import List
import numpy as np
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...utils import is_flax_available, is_tf_available, is_torch_available
class OwlViTProcessor(ProcessorMixin):
r"""
Constructs an OWL-ViT processor which wraps [`OwlViTImageProcessor`] and [`CLIPTokenizer`]/[`CLIPTokenizerFast`]
into a single processor that inherits both the image processor and tokenizer functionalities. See the
[`~OwlViTProcessor.__call__`] and [`~OwlViTProcessor.decode`] for more information.
Args:
image_processor ([`OwlViTImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`], *optional*):
The tokenizer is a required input.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "OwlViTImageProcessor"
tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")
def __init__(self, image_processor=None, tokenizer=None, **kwargs):
feature_extractor = None
if "feature_extractor" in kwargs:
warnings.warn(
"The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
" instead.",
FutureWarning,
)
feature_extractor = kwargs.pop("feature_extractor")
image_processor = image_processor if image_processor is not None else feature_extractor
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
if tokenizer is None:
raise ValueError("You need to specify a `tokenizer`.")
super().__init__(image_processor, tokenizer)
def post_process(self, *args, **kwargs):
"""
This method forwards all its arguments to [`OwlViTImageProcessor.post_process`]. Please refer to the docstring
of this method for more information.
"""
return self.image_processor.post_process(*args, **kwargs)
def post_process_object_detection(self, *args, **kwargs):
"""
        This method forwards all its arguments to [`OwlViTImageProcessor.post_process_object_detection`]. Please
        refer to the docstring of this method for more information.
"""
return self.image_processor.post_process_object_detection(*args, **kwargs)
def post_process_image_guided_detection(self, *args, **kwargs):
"""
        This method forwards all its arguments to [`OwlViTImageProcessor.post_process_image_guided_detection`].
        Please refer to the docstring of this method for more information.
"""
return self.image_processor.post_process_image_guided_detection(*args, **kwargs)
def batch_decode(self, *args, **kwargs):
"""
        This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
        This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer
        to the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
def feature_extractor_class(self):
"""
        Warning: `feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class`
        instead. Returns `image_processor_class`.
"""
warnings.warn(
"`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
FutureWarning,
)
return self.image_processor_class
@property
def feature_extractor(self):
"""
        Warning: `feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.
        Returns `image_processor`.
"""
warnings.warn(
"`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.",
FutureWarning,
)
return self.image_processor
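Before moving on to the package `__init__`, here is a minimal usage sketch of the processor. It shows a single `OwlViTProcessor` call tokenizing text queries with the CLIP tokenizer and preprocessing the image with `OwlViTImageProcessor`; the checkpoint name, query strings, and placeholder image are only illustrative.

```
from PIL import Image
from transformers import OwlViTProcessor

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
image = Image.new("RGB", (640, 480))  # stand-in for a real image
# One call handles both modalities for zero-shot object detection queries.
inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
# `inputs` is a BatchEncoding containing `input_ids`, `attention_mask` and `pixel_values`.
```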
.\models\owlvit\__init__.py
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tf_available,
is_tokenizers_available,
is_torch_available,
is_vision_available,
)
_import_structure = {
"configuration_owlvit": [
"OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"OwlViTConfig",
"OwlViTOnnxConfig",
"OwlViTTextConfig",
"OwlViTVisionConfig",
],
"processing_owlvit": ["OwlViTProcessor"],
}
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["feature_extraction_owlvit"] = ["OwlViTFeatureExtractor"]
_import_structure["image_processing_owlvit"] = ["OwlViTImageProcessor"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_owlvit"] = [
"OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
"OwlViTModel",
"OwlViTPreTrainedModel",
"OwlViTTextModel",
"OwlViTVisionModel",
"OwlViTForObjectDetection",
]
if TYPE_CHECKING:
from .configuration_owlvit import (
OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
OwlViTConfig,
OwlViTOnnxConfig,
OwlViTTextConfig,
OwlViTVisionConfig,
)
from .processing_owlvit import OwlViTProcessor
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .feature_extraction_owlvit import OwlViTFeatureExtractor
from .image_processing_owlvit import OwlViTImageProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_owlvit import (
OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
OwlViTForObjectDetection,
OwlViTModel,
OwlViTPreTrainedModel,
OwlViTTextModel,
OwlViTVisionModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
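Because the module object is replaced with a `_LazyModule`, the heavy submodules are only imported when one of their attributes is first accessed. A small sketch of the effect, using the attribute names listed in `_import_structure` above:

```
import transformers.models.owlvit as owlvit

# Nothing from processing_owlvit/modeling_owlvit has been imported yet.
processor_cls = owlvit.OwlViTProcessor  # attribute access triggers the import of processing_owlvit
```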
.\models\patchtsmixer\configuration_patchtsmixer.py
from typing import List, Optional, Union
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
PATCHTSMIXER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"ibm/patchtsmixer-etth1-pretrain": "https://huggingface.co/ibm/patchtsmixer-etth1-pretrain/resolve/main/config.json",
}
class PatchTSMixerConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`PatchTSMixerModel`]. It is used to instantiate a
PatchTSMixer model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the PatchTSMixer
[ibm/patchtsmixer-etth1-pretrain](https://huggingface.co/ibm/patchtsmixer-etth1-pretrain) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```
>>> from transformers import PatchTSMixerConfig, PatchTSMixerModel
>>> # Initializing a default PatchTSMixer configuration
>>> configuration = PatchTSMixerConfig()
>>> # Randomly initializing a model (with random weights) from the configuration
>>> model = PatchTSMixerModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "patchtsmixer"
attribute_map = {
"hidden_size": "d_model",
"num_hidden_layers": "num_layers",
}
def __init__(
self,
context_length: int = 32,
patch_length: int = 8,
num_input_channels: int = 1,
patch_stride: int = 8,
num_parallel_samples: int = 100,
d_model: int = 8,
expansion_factor: int = 2,
num_layers: int = 3,
dropout: float = 0.2,
mode: str = "common_channel",
gated_attn: bool = True,
norm_mlp: str = "LayerNorm",
self_attn: bool = False,
self_attn_heads: int = 1,
use_positional_encoding: bool = False,
positional_encoding_type: str = "sincos",
scaling: Optional[Union[str, bool]] = "std",
loss: str = "mse",
init_std: float = 0.02,
post_init: bool = False,
norm_eps: float = 1e-5,
mask_type: str = "random",
random_mask_ratio: float = 0.5,
num_forecast_mask_patches: Optional[Union[List[int], int]] = [2],
mask_value: int = 0,
masked_loss: bool = True,
channel_consistent_masking: bool = True,
unmasked_channel_indices: Optional[List[int]] = None,
head_dropout: float = 0.2,
distribution_output: str = "student_t",
prediction_length: int = 16,
prediction_channel_indices: list = None,
num_targets: int = 3,
output_range: list = None,
head_aggregation: str = "max_pool",
**kwargs,
):
self.num_input_channels = num_input_channels
self.context_length = context_length
self.patch_length = patch_length
self.patch_stride = patch_stride
self.d_model = d_model
self.expansion_factor = expansion_factor
self.num_layers = num_layers
self.dropout = dropout
self.mode = mode
self.gated_attn = gated_attn
self.norm_mlp = norm_mlp
self.scaling = scaling
self.head_dropout = head_dropout
self.num_patches = (max(context_length, patch_length) - patch_length) // patch_stride + 1
self.mask_type = mask_type
self.random_mask_ratio = random_mask_ratio
self.num_forecast_mask_patches = num_forecast_mask_patches
self.mask_value = mask_value
self.channel_consistent_masking = channel_consistent_masking
self.masked_loss = masked_loss
self.patch_last = True
self.use_positional_encoding = use_positional_encoding
self.positional_encoding_type = positional_encoding_type
self.prediction_length = prediction_length
self.prediction_channel_indices = prediction_channel_indices
self.num_targets = num_targets
self.output_range = output_range
self.head_aggregation = head_aggregation
self.self_attn = self_attn
self.self_attn_heads = self_attn_heads
self.init_std = init_std
self.post_init = post_init
self.distribution_output = distribution_output
self.loss = loss
self.num_parallel_samples = num_parallel_samples
self.unmasked_channel_indices = unmasked_channel_indices
self.norm_eps = norm_eps
super().__init__(**kwargs)
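Note how `num_patches` is derived from the context length, patch length, and stride in `__init__`. With the default values the arithmetic works out as follows:

```
context_length, patch_length, patch_stride = 32, 8, 8
num_patches = (max(context_length, patch_length) - patch_length) // patch_stride + 1
# (32 - 8) // 8 + 1 = 4 non-overlapping patches of length 8
```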
.\models\patchtsmixer\modeling_patchtsmixer.py
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_patchtsmixer import PatchTSMixerConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "PatchTSMixerConfig"
PATCHTSMIXER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"ibm/patchtsmixer-etth1-pretrain",
]
PATCHTSMIXER_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`PatchTSMixerConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
mask_input (`bool`, *optional*, defaults to `False`):
If True, Masking will be enabled. False otherwise.
"""
PATCHTSMIXER_INPUTS_DOCSTRING = r"""
    Args:
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.
            For univariate time series, the `num_input_channels` dimension should be 1. For multivariate time
            series, it is greater than 1.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class PatchTSMixerGatedAttention(nn.Module):
"""
Module that applies gated attention to input data.
Args:
in_size (`int`): The input size.
out_size (`int`): The output size.
"""
def __init__(self, in_size: int, out_size: int):
super().__init__()
self.attn_layer = nn.Linear(in_size, out_size)
self.attn_softmax = nn.Softmax(dim=-1)
def forward(self, inputs):
attn_weight = self.attn_softmax(self.attn_layer(inputs))
inputs = inputs * attn_weight
return inputs
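# A minimal sketch of the gating (illustrative values, not part of the original file):
#   >>> layer = PatchTSMixerGatedAttention(in_size=8, out_size=8)
#   >>> x = torch.randn(2, 3, 4, 8)    # (batch, channels, patches, d_model)
#   >>> layer(x).shape                 # element-wise reweighting, shape preserved
#   torch.Size([2, 3, 4, 8])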
class PatchTSMixerBatchNorm(nn.Module):
"""
Compute batch normalization over the sequence length (time) dimension.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.batchnorm = nn.BatchNorm1d(config.d_model, eps=config.norm_eps)
def forward(self, inputs: torch.Tensor):
"""
Parameters:
inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`):
input for Batch norm calculation
Returns:
`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`
"""
output = inputs.transpose(1, 2)
output = self.batchnorm(output)
return output.transpose(1, 2)
class PatchTSMixerPositionalEncoding(nn.Module):
"""
Class for positional encoding
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
if config.use_positional_encoding:
self.position_enc = self._init_pe(config)
else:
self.position_enc = nn.Parameter(torch.zeros(config.num_patches, config.d_model))
@staticmethod
def _init_pe(config: PatchTSMixerConfig) -> nn.Parameter:
if config.positional_encoding_type == "random":
position_enc = nn.Parameter(torch.randn(config.num_patches, config.d_model), requires_grad=True)
elif config.positional_encoding_type == "sincos":
position_enc = torch.zeros(config.num_patches, config.d_model)
position = torch.arange(0, config.num_patches).unsqueeze(1)
div_term = torch.exp(torch.arange(0, config.d_model, 2) * -(math.log(10000.0) / config.d_model))
position_enc[:, 0::2] = torch.sin(position * div_term)
position_enc[:, 1::2] = torch.cos(position * div_term)
position_enc = position_enc - position_enc.mean()
position_enc = position_enc / (position_enc.std() * 10)
position_enc = nn.Parameter(position_enc, requires_grad=False)
else:
raise ValueError(
f"{config.positional_encoding_type} is not a valid positional encoder. Available types are 'random' and 'sincos'."
)
return position_enc
def forward(self, patch_input: torch.Tensor):
hidden_state = patch_input + self.position_enc
return hidden_state
class PatchTSMixerNormLayer(nn.Module):
"""Normalization block
Args:
config (`PatchTSMixerConfig`, *required*):
Configuration.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.norm_mlp = config.norm_mlp
if "batch" in config.norm_mlp.lower():
self.norm = PatchTSMixerBatchNorm(config)
else:
self.norm = nn.LayerNorm(config.d_model, eps=config.norm_eps)
def forward(self, inputs: torch.Tensor):
"""
Args:
inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`):
Input to the normalization layer.
Returns:
`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`
"""
if "batch" in self.norm_mlp.lower():
inputs_reshaped = torch.reshape(
inputs,
(
inputs.shape[0] * inputs.shape[1],
inputs.shape[2],
inputs.shape[3],
),
)
inputs_reshaped = self.norm(inputs_reshaped)
inputs = torch.reshape(inputs_reshaped, inputs.shape)
else:
inputs = self.norm(inputs)
return inputs
class PatchTSMixerMLP(nn.Module):
def __init__(self, in_features, out_features, config):
super().__init__()
num_hidden = in_features * config.expansion_factor
self.fc1 = nn.Linear(in_features, num_hidden)
self.dropout1 = nn.Dropout(config.dropout)
self.fc2 = nn.Linear(num_hidden, out_features)
self.dropout2 = nn.Dropout(config.dropout)
def forward(self, inputs: torch.Tensor):
"""
Args:
inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`):
Input to the MLP layer.
Returns:
`torch.Tensor` of the same shape as `inputs`
"""
inputs = self.dropout1(nn.functional.gelu(self.fc1(inputs)))
inputs = self.fc2(inputs)
inputs = self.dropout2(inputs)
return inputs
class PatchTSMixerChannelFeatureMixerBlock(nn.Module):
"""This module mixes the features in the channel dimension.
Args:
config (`PatchTSMixerConfig`, *required*):
Configuration.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.norm = PatchTSMixerNormLayer(config)
self.gated_attn = config.gated_attn
self.mlp = PatchTSMixerMLP(
in_features=config.num_input_channels,
out_features=config.num_input_channels,
config=config,
)
if config.gated_attn:
self.gating_block = PatchTSMixerGatedAttention(
in_size=config.num_input_channels, out_size=config.num_input_channels
)
def forward(self, inputs: torch.Tensor):
"""
Args:
inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`):
input to the MLP layer
Returns:
`torch.Tensor` of the same shape as `inputs`
"""
residual = inputs
inputs = self.norm(inputs)
inputs = inputs.permute(0, 3, 2, 1)
if self.gated_attn:
inputs = self.gating_block(inputs)
inputs = self.mlp(inputs)
inputs = inputs.permute(0, 3, 2, 1)
out = inputs + residual
return out
class PatchTSMixerAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[PatchTSMixerConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
    ):
        # Standard multi-head attention forward pass (body omitted).
        ...
class PatchMixerBlock(nn.Module):
"""This module mixes the patch dimension.
Args:
config (`PatchTSMixerConfig`, *required*):
Configuration.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.norm = PatchTSMixerNormLayer(config)
self.self_attn = config.self_attn
self.gated_attn = config.gated_attn
self.mlp = PatchTSMixerMLP(
in_features=config.num_patches,
out_features=config.num_patches,
config=config,
)
if config.gated_attn:
self.gating_block = PatchTSMixerGatedAttention(in_size=config.num_patches, out_size=config.num_patches)
if config.self_attn:
self.self_attn_layer = PatchTSMixerAttention(
embed_dim=config.d_model,
num_heads=config.self_attn_heads,
dropout=config.dropout,
)
self.norm_attn = PatchTSMixerNormLayer(config)
def forward(self, hidden_state):
"""
Args:
hidden_state (`torch.Tensor`): Input tensor.
Returns:
`torch.Tensor`: Transformed tensor.
"""
residual = hidden_state
hidden_state = self.norm(hidden_state)
if self.self_attn:
batch_size, n_vars, num_patches, d_model = hidden_state.shape
hidden_state_reshaped = hidden_state.reshape(batch_size * n_vars, num_patches, d_model)
x_attn, _, _ = self.self_attn_layer(hidden_state_reshaped, output_attentions=False)
x_attn = x_attn.reshape(batch_size, n_vars, num_patches, d_model)
hidden_state = hidden_state.transpose(2, 3)
hidden_state = self.mlp(hidden_state)
if self.gated_attn:
hidden_state = self.gating_block(hidden_state)
hidden_state = hidden_state.transpose(2, 3)
if self.self_attn:
hidden_state = self.norm_attn(hidden_state + x_attn)
out = hidden_state + residual
return out
class FeatureMixerBlock(nn.Module):
"""This module mixes the hidden feature dimension.
Args:
        config (`PatchTSMixerConfig`, *required*):
Configuration.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.norm = PatchTSMixerNormLayer(config)
self.gated_attn = config.gated_attn
self.mlp = PatchTSMixerMLP(
in_features=config.d_model,
out_features=config.d_model,
config=config,
)
if config.gated_attn:
self.gating_block = PatchTSMixerGatedAttention(in_size=config.d_model, out_size=config.d_model)
def forward(self, hidden: torch.Tensor):
"""
Args:
hidden (`torch.Tensor` of shape `(batch_size, num_patches, d_model)`):
Input tensor to the layer.
Returns:
`torch.Tensor`: Transformed tensor.
"""
residual = hidden
hidden = self.norm(hidden)
hidden = self.mlp(hidden)
if self.gated_attn:
hidden = self.gating_block(hidden)
out = hidden + residual
return out
class PatchTSMixerLayer(nn.Module):
"""
The `PatchTSMixer` layer that does all three kinds of mixing.
Args:
        config (`PatchTSMixerConfig`, *required*):
Configuration.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.patch_mixer = PatchMixerBlock(config=config)
self.feature_mixer = FeatureMixerBlock(config=config)
self.mode = config.mode
if config.mode == "mix_channel":
self.channel_feature_mixer = PatchTSMixerChannelFeatureMixerBlock(config=config)
def forward(self, hidden: torch.Tensor):
"""
Args:
hidden (`torch.Tensor` of shape `(batch_size, num_patches, d_model)`):
Input tensor to the layer.
Returns:
`torch.Tensor`: Transformed tensor.
"""
if self.mode == "mix_channel":
hidden = self.channel_feature_mixer(hidden)
hidden = self.patch_mixer(hidden)
hidden = self.feature_mixer(hidden)
return hidden
class PatchTSMixerBlock(nn.Module):
"""The main computing framework of the `PatchTSMixer` model.
Args:
        config (`PatchTSMixerConfig`, *required*):
Configuration.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
num_layers = config.num_layers
self.mixers = nn.ModuleList([PatchTSMixerLayer(config=config) for _ in range(num_layers)])
def forward(self, hidden_state, output_hidden_states: bool = False):
"""
Args:
            hidden_state (`torch.Tensor`): The input tensor.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether to output all hidden states as well.
        Returns:
            `torch.Tensor`: The embedding. `list`: List of all hidden states if `output_hidden_states` is set to
            `True`.
"""
all_hidden_states = []
embedding = hidden_state
for mod in self.mixers:
embedding = mod(embedding)
if output_hidden_states:
all_hidden_states.append(embedding)
if output_hidden_states:
return embedding, all_hidden_states
else:
return embedding, None
class PatchTSMixerForPredictionHead(nn.Module):
"""Prediction Head for Forecasting
Args:
config (`PatchTSMixerConfig`, *required*): Configuration.
"""
def __init__(self, config: PatchTSMixerConfig, distribution_output=None):
super().__init__()
self.prediction_channel_indices = config.prediction_channel_indices
if self.prediction_channel_indices is not None:
self.prediction_channel_indices.sort()
self.dropout_layer = nn.Dropout(config.head_dropout)
if distribution_output is None:
self.base_forecast_block = nn.Linear((config.num_patches * config.d_model), config.prediction_length)
else:
self.base_forecast_block = distribution_output.get_parameter_projection(
config.num_patches * config.d_model
)
self.flatten = nn.Flatten(start_dim=-2)
def forward(self, hidden_features):
"""
Args:
hidden_features (`torch.Tensor` of shape `(batch_size, num_patch, d_model)` in `flatten` mode
or `(batch_size, n_vars, num_patch, d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
features.
Returns:
`torch.Tensor` of shape `(batch_size, prediction_length, nvars)`.
"""
hidden_features = self.flatten(hidden_features)
hidden_features = self.dropout_layer(hidden_features)
forecast = self.base_forecast_block(hidden_features)
if isinstance(forecast, tuple):
forecast = tuple(z.transpose(-1, -2) for z in forecast)
else:
forecast = forecast.transpose(-1, -2)
if self.prediction_channel_indices is not None:
if isinstance(forecast, tuple):
forecast = tuple(z[..., self.prediction_channel_indices] for z in forecast)
else:
forecast = forecast[..., self.prediction_channel_indices]
return forecast
class PatchTSMixerLinearHead(nn.Module):
"""Linear head for Classification and Regression.
Args:
config (`PatchTSMixerConfig`, *required*): Configuration.
"""
def __init__(self, config: PatchTSMixerConfig, distribution_output=None):
super().__init__()
self.head_aggregation = config.head_aggregation
self.output_range = config.output_range
if config.head_aggregation is None:
mul_factor = config.num_patches
else:
mul_factor = 1
self.distribution_output = distribution_output
if distribution_output is None:
self.projection = nn.Linear(
config.d_model * config.num_input_channels * mul_factor,
config.num_targets,
)
else:
self.projection = distribution_output.get_parameter_projection(
config.d_model * config.num_input_channels * mul_factor
)
if config.head_aggregation is None:
self.flatten = nn.Flatten(start_dim=-3)
else:
self.flatten = nn.Flatten(start_dim=-2)
self.dropout = nn.Dropout(config.head_dropout)
def forward(self, hidden_features):
"""
Args:
hidden_features (`torch.Tensor` of shape `(batch_size x num_patch x d_model)` in `flatten` mode
or `(batch_size x n_vars x num_patch x d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
features.
Returns:
`torch.Tensor` of shape `(batch_size x num_targets)`.
"""
hidden_features = hidden_features.transpose(-1, -2)
if self.head_aggregation == "use_last":
hidden_features = hidden_features[..., -1]
elif self.head_aggregation == "max_pool":
hidden_features = hidden_features.max(dim=-1).values
elif self.head_aggregation == "avg_pool":
hidden_features = hidden_features.mean(dim=-1)
if self.flatten:
hidden_features = self.flatten(hidden_features)
hidden_features = self.dropout(hidden_features)
hidden_features = self.projection(hidden_features)
if (self.distribution_output is None) and (self.output_range is not None):
hidden_features = (
torch.sigmoid(hidden_features) * (self.output_range[1] - self.output_range[0]) + self.output_range[0]
)
return hidden_features
class PatchTSMixerPreTrainedModel(PreTrainedModel):
config_class = PatchTSMixerConfig
base_model_prefix = "model"
main_input_name = "past_values"
supports_gradient_checkpointing = False
def _init_weights(self, module):
"""Initialize weights"""
if isinstance(module, PatchTSMixerPositionalEncoding):
if self.config.positional_encoding_type == "random":
nn.init.normal_(module.position_enc, mean=0.0, std=0.1)
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, PatchTSMixerBatchNorm):
module.batchnorm.bias.data.zero_()
module.batchnorm.weight.data.fill_(1.0)
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.init_std)
if module.bias is not None:
module.bias.data.zero_()
class PatchTSMixerPretrainHead(nn.Module):
"""Pretraining head.
Args:
config (`PatchTSMixerConfig`, *required*):
Configuration.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.dropout_layer = nn.Dropout(config.head_dropout)
self.base_pt_block = nn.Linear(config.d_model, config.patch_length)
def forward(self, hidden_features):
"""
Args:
hidden_features (`torch.Tensor` of shape `(batch_size x num_patch x d_model)` in `flatten` mode
or `(batch_size x n_vars x num_patch x d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
features.
Returns:
`torch.Tensor` of shape `(batch_size x n_vars x num_patch x patch_length)`.
"""
hidden_features = self.dropout_layer(hidden_features)
forecast = self.base_pt_block(hidden_features)
return forecast
def random_masking(
inputs: torch.Tensor,
mask_ratio: float,
unmasked_channel_indices: list = None,
channel_consistent_masking: bool = False,
mask_value: int = 0,
):
"""random_masking: Mask the input considering the control variables.
Args:
inputs (torch.Tensor): Input tensor to be masked.
mask_ratio (float): Ratio of elements to be masked.
unmasked_channel_indices (list, optional): List of unmasked channel indices.
channel_consistent_masking (bool, optional): Whether to mask consistently across channels.
mask_value (int, optional): Value to fill in for masked elements.
Returns:
torch.Tensor: Masked input tensor.
"""
if mask_ratio < 0 or mask_ratio >= 1:
raise ValueError(f"Mask ratio {mask_ratio} has to be between 0 and 1.")
batch_size, num_channels, sequence_length, num_features = inputs.shape
device = inputs.device
len_keep = int(sequence_length * (1 - mask_ratio))
if channel_consistent_masking:
noise = torch.rand(batch_size, 1, sequence_length, device=device)
noise = noise.repeat(1, num_channels, 1)
else:
noise = torch.rand(batch_size, num_channels, sequence_length, device=device)
mask = torch.ones(batch_size, num_channels, sequence_length, device=device)
mask[:, :, :len_keep] = 0
ids_shuffle = torch.argsort(noise, dim=-1)
ids_restore = torch.argsort(ids_shuffle, dim=-1)
mask = torch.gather(mask, dim=-1, index=ids_restore)
mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features)
if unmasked_channel_indices is not None:
mask[:, unmasked_channel_indices, :, :] = 0
inputs_mask = inputs.masked_fill(mask.bool(), mask_value)
return inputs_mask, mask[..., 0]
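# Illustrative usage sketch (arbitrary shapes, not part of the original file):
#   >>> inputs = torch.randn(2, 3, 4, 8)   # (batch, channels, num_patches, patch_length)
#   >>> masked, mask = random_masking(inputs, mask_ratio=0.5)
#   >>> mask.shape                         # one mask entry per channel and patch
#   torch.Size([2, 3, 4])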
def forecast_masking(
inputs: torch.Tensor,
num_forecast_mask_patches: Union[list, int],
unmasked_channel_indices: list = None,
mask_value: int = 0,
):
"""Forecast masking that masks the last K patches where K is from the num_forecast_mask_patches.
If num_forecast_mask_patches is a list, samples in the batch will be randomly masked by numbers defined in the list.
Parameters:
inputs (`torch.Tensor`):
Input of shape `(bs, num_channels, num_patch, patch_length)`
num_forecast_mask_patches (`list`):
Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
unmasked_channel_indices (`list`, *optional*):
Indices of channels that are not masked.
mask_value (`int`, *optional*, defaults to 0):
Values in the masked patches will be filled by `mask_value`.
Returns:
`tuple(torch.Tensor)`: inputs_mask, masked input, same shape as inputs Tensor and Mask tensor of shape `(bs,
num_channels , num_patch)` or `(bs, tsg1, tsg2, num_channels, num_patch)`
"""
if isinstance(num_forecast_mask_patches, int):
num_forecast_mask_patches = [num_forecast_mask_patches]
forecast_mask_ratios = [1 for _ in num_forecast_mask_patches]
batch_size, num_channels, sequence_length, num_features = inputs.shape
mask = torch.zeros(batch_size, num_channels, sequence_length, device=inputs.device)
t_list = []
total_length = 0
total_ratio = sum(forecast_mask_ratios)
for patch_length, ratio in zip(num_forecast_mask_patches, forecast_mask_ratios):
if patch_length <= 0 or patch_length >= sequence_length:
raise ValueError(
f"num_forecast_mask_patches {patch_length} should be greater than 0 and less than total patches."
)
temp_len = int(batch_size * ratio / total_ratio)
t_list.append([patch_length, ratio, temp_len])
total_length += temp_len
t_list = sorted(t_list, key=lambda x: x[2])
if total_length < batch_size:
t_list[0][2] = t_list[0][2] + (batch_size - total_length)
elif total_length > batch_size:
t_list[-1][2] = t_list[-1][2] + (total_length - batch_size)
batch1 = 0
for patch_len, _, temp_len in t_list:
batch2 = batch1 + temp_len
mask[batch1:batch2, :, -patch_len:] = 1
batch1 = batch2
perm = torch.randperm(mask.shape[0])
mask = mask[perm]
mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features)
if unmasked_channel_indices is not None:
mask[:, unmasked_channel_indices, :, :] = 0
inputs_mask = inputs.masked_fill(mask.bool(), mask_value)
return inputs_mask, mask[..., 0]
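# Illustrative sketch: mask the last 2 of 4 patches in every sample (values are arbitrary):
#   >>> inputs = torch.randn(2, 3, 4, 8)
#   >>> masked, mask = forecast_masking(inputs, num_forecast_mask_patches=2)
#   >>> mask[0, 0]
#   tensor([0., 0., 1., 1.])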
class PatchTSMixerPatchify(nn.Module):
"""
A class to patchify the time series sequence into different patches
Returns:
`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.sequence_length = config.context_length
self.patch_length = config.patch_length
self.patch_stride = config.patch_stride
if self.sequence_length <= self.patch_length:
raise ValueError(
f"Sequence length ({self.sequence_length}) has to be greater than the patch length ({self.patch_length})"
)
self.num_patches = (max(self.sequence_length, self.patch_length) - self.patch_length) // self.patch_stride + 1
new_sequence_length = self.patch_length + self.patch_stride * (self.num_patches - 1)
self.sequence_start = self.sequence_length - new_sequence_length
def forward(self, past_values: torch.Tensor):
"""
Parameters:
past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*):
Input for patchification
Returns:
`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
"""
sequence_length = past_values.shape[-2]
if sequence_length != self.sequence_length:
raise ValueError(
f"Input sequence length ({sequence_length}) doesn't match model configuration ({self.sequence_length})."
)
output = past_values[:, self.sequence_start :, :]
output = output.unfold(dimension=-2, size=self.patch_length, step=self.patch_stride)
output = output.transpose(-2, -3).contiguous()
return output
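# Shape walk-through with the default configuration (context_length=32, patch_length=8, patch_stride=8):
#   >>> past_values = torch.randn(2, 32, 3)                # (batch, sequence_length, num_channels)
#   >>> patches = past_values.unfold(dimension=-2, size=8, step=8)
#   >>> patches.shape                                      # unfold appends the patch dimension last
#   torch.Size([2, 4, 3, 8])
#   >>> patches.transpose(-2, -3).shape                    # (batch, num_channels, num_patches, patch_length)
#   torch.Size([2, 3, 4, 8])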
class PatchTSMixerMasking(nn.Module):
"""
Class to perform random or forecast masking.
Parameters:
config (`PatchTSMixerConfig`): model config
Returns:
x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
Masked patched input
mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
Bool tensor indicating True on masked points
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.random_mask_ratio = config.random_mask_ratio
self.channel_consistent_masking = config.channel_consistent_masking
self.mask_type = config.mask_type
self.num_forecast_mask_patches = config.num_forecast_mask_patches
self.unmasked_channel_indices = config.unmasked_channel_indices
self.mask_value = config.mask_value
if self.unmasked_channel_indices is not None:
self.unmasked_channel_indices = sorted(self.unmasked_channel_indices)
def forward(self, patch_input: torch.Tensor):
"""
Parameters:
patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
Patch input
Return:
masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
Masked patched input
mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
Bool tensor indicating True on masked points
"""
if self.mask_type == "random":
masked_input, mask = random_masking(
inputs=patch_input,
mask_ratio=self.random_mask_ratio,
unmasked_channel_indices=self.unmasked_channel_indices,
channel_consistent_masking=self.channel_consistent_masking,
mask_value=self.mask_value,
)
elif self.mask_type == "forecast":
masked_input, mask = forecast_masking(
inputs=patch_input,
num_forecast_mask_patches=self.num_forecast_mask_patches,
unmasked_channel_indices=self.unmasked_channel_indices,
mask_value=self.mask_value,
)
else:
raise ValueError(f"Invalid mask type {self.mask_type}.")
mask = mask.bool()
return masked_input, mask
class PatchTSMixerStdScaler(nn.Module):
"""
    Standardize features by computing the mean and scaling along the first dimension, and then normalize them by
    subtracting the mean and dividing by the standard deviation.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-5
def forward(
self, data: torch.Tensor, observed_indicator: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Input for the scale calculation.
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`, `(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
"""
denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim)
denominator = denominator.clamp_min(1.0)
loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator
variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator
scale = torch.sqrt(variance + self.minimum_scale)
return (data - loc) / scale, loc, scale
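# Shape sketch (illustrative): per-channel statistics are computed over the time dimension (dim=1):
#   >>> data = torch.randn(2, 32, 3)                       # (batch, sequence_length, num_input_channels)
#   >>> observed = torch.ones_like(data)                   # everything observed
#   >>> scaled, loc, scale = PatchTSMixerStdScaler(PatchTSMixerConfig(num_input_channels=3))(data, observed)
#   >>> loc.shape, scale.shape
#   (torch.Size([2, 1, 3]), torch.Size([2, 1, 3]))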
class PatchTSMixerMeanScaler(nn.Module):
"""
    Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data
    accordingly.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10
self.default_scale = config.default_scale if hasattr(config, "default_scale") else None
def forward(
        self, data: torch.Tensor, observed_indicator: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Input for the scale calculation.
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
Returns:
tuple of `torch.Tensor` of shapes
(`(batch_size, sequence_length, num_input_channels)`, `(batch_size, 1, num_input_channels)`,
`(batch_size, 1, num_input_channels)`)
"""
ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True)
num_observed = observed_indicator.sum(self.dim, keepdim=True)
scale = ts_sum / torch.clamp(num_observed, min=1)
if self.default_scale is None:
batch_sum = ts_sum.sum(dim=0)
batch_observations = torch.clamp(num_observed.sum(0), min=1)
default_scale = torch.squeeze(batch_sum / batch_observations)
else:
default_scale = self.default_scale * torch.ones_like(scale)
scale = torch.where(num_observed > 0, scale, default_scale)
scale = torch.clamp(scale, min=self.minimum_scale)
scaled_data = data / scale
if not self.keepdim:
scale = scale.squeeze(dim=self.dim)
return scaled_data, torch.zeros_like(scale), scale
class PatchTSMixerNOPScaler(nn.Module):
"""
Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__()
self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
def forward(
self, data: torch.Tensor, observed_indicator: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Parameters:
data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
input for Batch norm calculation
Returns:
tuple of `torch.Tensor` of shapes
(`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
`(batch_size, 1, num_input_channels)`)
"""
scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
return data, loc, scale
@dataclass
class PatchTSMixerEncoderOutput(ModelOutput):
"""
Base class for `PatchTSMixerEncoderOutput`, with potential hidden states.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
Hidden-state at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Hidden-states of the model at the output of each layer.
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class PatchTSMixerEncoder(PatchTSMixerPreTrainedModel):
"""
Encoder for PatchTSMixer which inputs patched time-series and outputs patched embeddings.
Args:
config (`PatchTSMixerConfig`, *required*):
Configuration.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__(config)
self.use_return_dict = config.use_return_dict
self.patcher = nn.Linear(config.patch_length, config.d_model)
if config.use_positional_encoding:
self.positional_encoder = PatchTSMixerPositionalEncoding(config=config)
else:
self.positional_encoder = None
self.mlp_mixer_encoder = PatchTSMixerBlock(config=config)
if config.post_init:
self.post_init()
@replace_return_docstrings(output_type=PatchTSMixerEncoderOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
past_values: torch.Tensor,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, PatchTSMixerEncoderOutput]:
r"""
Args:
past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
Context values of the time series. For a pretraining task, this denotes the input time series to
predict the masked portion. For a forecasting task, this denotes the history/past time series values.
Similarly, for classification or regression tasks, it denotes the appropriate context values of the
time series.
For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series,
it is greater than 1.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
Returns:
`torch.FloatTensor` of shape `(batch_size, n_vars, num_patches, d_model)`
"""
return_dict = return_dict if return_dict is not None else self.use_return_dict
patches = self.patcher(past_values)
if self.positional_encoder is not None:
patches = self.positional_encoder(patches)
last_hidden_state, hidden_states = self.mlp_mixer_encoder(patches, output_hidden_states=output_hidden_states)
if not return_dict:
return tuple(
v
for v in [
last_hidden_state,
hidden_states,
]
)
return PatchTSMixerEncoderOutput(last_hidden_state=last_hidden_state, hidden_states=hidden_states)
@dataclass
class PatchTSMixerModelOutput(ModelOutput):
"""
Base class for model's outputs, with potential hidden states.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
Hidden-state at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Hidden-states of the model at the output of each layer.
patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
Patched input data to the model.
mask: (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`,*optional*):
Bool Tensor indicating True in masked patches and False otherwise.
loc: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`,*optional*):
Gives the mean of the context window per channel. Used for revin denorm outside the model, if revin
enabled.
scale: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`,*optional*):
Gives the std dev of the context window per channel. Used for revin denorm outside the model, if revin
enabled.
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
patch_input: torch.FloatTensor = None
mask: Optional[torch.FloatTensor] = None
loc: Optional[torch.FloatTensor] = None
scale: Optional[torch.FloatTensor] = None
@add_start_docstrings(
"The PatchTSMixer Model for time-series forecasting.",
PATCHTSMIXER_START_DOCSTRING,
)
class PatchTSMixerModel(PatchTSMixerPreTrainedModel):
def __init__(self, config: PatchTSMixerConfig, mask_input: bool = False):
super().__init__(config)
self.use_return_dict = config.use_return_dict
self.encoder = PatchTSMixerEncoder(config)
self.patching = PatchTSMixerPatchify(config)
if mask_input is True:
self.masking = PatchTSMixerMasking(config)
else:
self.masking = None
if config.scaling == "mean":
self.scaler = PatchTSMixerMeanScaler(config)
elif config.scaling == "std" or config.scaling is True:
self.scaler = PatchTSMixerStdScaler(config)
else:
self.scaler = PatchTSMixerNOPScaler(config)
if config.post_init:
self.post_init()
@add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=PatchTSMixerModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
past_values: torch.Tensor,
observed_mask: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = None,
) -> PatchTSMixerModelOutput:
r"""
observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
in `[0, 1]`:
- 1 for values that are **observed**,
- 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
Returns:
`PatchTSMixerModelOutput`: An object containing encoder outputs and other processed inputs.
"""
return_dict = return_dict if return_dict is not None else self.use_return_dict
mask = None
if observed_mask is None:
observed_mask = torch.ones_like(past_values)
scaled_past_values, loc, scale = self.scaler(past_values, observed_mask)
patched_x = self.patching(scaled_past_values)
enc_input = patched_x
if self.masking is not None:
enc_input, mask = self.masking(patched_x)
encoder_output = self.encoder(
enc_input,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if isinstance(encoder_output, tuple):
encoder_output = PatchTSMixerEncoderOutput(*encoder_output)
if not return_dict:
return tuple(
v
for v in [
encoder_output.last_hidden_state,
encoder_output.hidden_states,
patched_x,
mask,
loc,
scale,
]
)
return PatchTSMixerModelOutput(
last_hidden_state=encoder_output.last_hidden_state,
hidden_states=encoder_output.hidden_states,
patch_input=patched_x,
mask=mask,
loc=loc,
scale=scale,
)
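# End-to-end shape sketch with default settings (context_length=32, patch_length=8, d_model=8 -> 4 patches):
#   >>> model = PatchTSMixerModel(PatchTSMixerConfig())
#   >>> out = model(torch.randn(4, 32, 1))                 # (batch, context_length, num_input_channels)
#   >>> out.last_hidden_state.shape                        # (batch, num_channels, num_patches, d_model)
#   torch.Size([4, 1, 4, 8])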
@dataclass
class PatchTSMixerForPreTrainingOutput(ModelOutput):
"""
Output type of [`PatchTSMixerForPreTrainingOutput`].
Args:
prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, patch_length)`):
Prediction output from the pretrain head.
hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Hidden-states of the model at the output of each layer.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
Backbone embeddings before passing through the head.
loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
Total loss
"""
loss: Optional[torch.FloatTensor] = None
prediction_outputs: torch.FloatTensor = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class PatchTSMixerForPretraining(PatchTSMixerPreTrainedModel):
r"""
`PatchTSMixer` for mask pretraining.
Args:
config (`PatchTSMixerConfig`, *required*):
Configuration.
Returns:
`None`.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__(config)
self.model = PatchTSMixerModel(config, mask_input=True)
self.head = PatchTSMixerPretrainHead(config=config)
self.masked_loss = config.masked_loss
self.use_return_dict = config.use_return_dict
if config.post_init:
self.post_init()
@add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=PatchTSMixerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
past_values: torch.Tensor,
observed_mask: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = False,
return_loss: bool = True,
return_dict: Optional[bool] = None,
) -> PatchTSMixerForPreTrainingOutput:
r"""
observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
in `[0, 1]`:
- 1 for values that are **observed**,
- 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
return_loss (`bool`, *optional*):
Whether to return the loss in the `forward` call.
Returns:
PatchTSMixerForPreTrainingOutput: An instance of the output class containing various outputs based on the model's forward pass.
"""
return_dict = return_dict if return_dict is not None else self.use_return_dict
if self.masked_loss is True:
loss = torch.nn.MSELoss(reduction="none")
else:
loss = torch.nn.MSELoss(reduction="mean")
model_output = self.model(
past_values,
observed_mask=observed_mask,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if isinstance(model_output, tuple):
model_output = PatchTSMixerModelOutput(*model_output)
x_hat = self.head(model_output.last_hidden_state)
if return_loss is True:
loss_val = loss(x_hat, model_output.patch_input)
else:
loss_val = None
if self.masked_loss is True and loss_val is not None:
loss_val = (loss_val.mean(dim=-1) * model_output.mask).sum() / (model_output.mask.sum() + 1e-10)
if not return_dict:
return tuple(
v
for v in [
loss_val,
x_hat,
model_output.last_hidden_state,
model_output.hidden_states,
]
)
return PatchTSMixerForPreTrainingOutput(
loss=loss_val,
prediction_outputs=x_hat,
last_hidden_state=model_output.last_hidden_state,
hidden_states=model_output.hidden_states,
)
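# Illustrative sketch: masked patches are reconstructed, and MSE is averaged over masked positions only:
#   >>> model = PatchTSMixerForPretraining(PatchTSMixerConfig())
#   >>> out = model(torch.randn(4, 32, 1))
#   >>> out.prediction_outputs.shape                       # (batch, num_channels, num_patches, patch_length)
#   torch.Size([4, 1, 4, 8])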
@dataclass
class PatchTSMixerForPredictionOutput(ModelOutput):
"""
Output type of [`PatchTSMixerForPredictionOutput`].
Args:
prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_input_channels)`):
Prediction output from the forecast head.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
Backbone embeddings before passing through the head.
hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
Total loss.
loc (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`):
Input mean
scale (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`):
Input std dev
"""
loss: Optional[torch.FloatTensor] = None
prediction_outputs: torch.FloatTensor = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
loc: torch.FloatTensor = None
scale: torch.FloatTensor = None
@dataclass
class SamplePatchTSMixerPredictionOutput(ModelOutput):
"""
Base class for time series model's predictions outputs that contains the sampled values from the chosen
distribution.
Args:
sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, number_channels)`):
Sampled values from the chosen distribution.
"""
sequences: torch.FloatTensor = None
@dataclass
class SamplePatchTSMixerRegressionOutput(ModelOutput):
"""
Base class for time series model's predictions outputs that contains the sampled values from the chosen
distribution.
Args:
        sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, num_targets)`):
Sampled values from the chosen distribution.
"""
sequences: torch.FloatTensor = None
def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor:
"""
Computes the negative log likelihood loss from input distribution with respect to target.
"""
return -input.log_prob(target)
def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor:
"""
Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.
"""
if weights is not None:
weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor))
sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0)
return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights
else:
return input_tensor.mean(dim=dim)
class PatchTSMixerForPrediction(PatchTSMixerPreTrainedModel):
r"""
`PatchTSMixer` for forecasting application.
Args:
config (`PatchTSMixerConfig`, *required*):
Configuration.
Returns:
`None`.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__(config)
self.loss = config.loss
self.use_return_dict = config.use_return_dict
self.prediction_channel_indices = config.prediction_channel_indices
self.num_parallel_samples = config.num_parallel_samples
if config.loss == "mse":
self.distribution_output = None
else:
dim = config.prediction_length
distribution_output_map = {
"student_t": StudentTOutput,
"normal": NormalOutput,
"negative_binomial": NegativeBinomialOutput,
}
output_class = distribution_output_map.get(config.distribution_output, None)
if output_class is not None:
self.distribution_output = output_class(dim=dim)
else:
raise ValueError(f"Unknown distribution output {config.distribution_output}")
self.model = PatchTSMixerModel(config)
self.head = PatchTSMixerForPredictionHead(
config=config,
distribution_output=self.distribution_output,
)
if config.post_init:
self.post_init()
@add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=PatchTSMixerForPredictionOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
past_values: torch.Tensor,
observed_mask: Optional[torch.Tensor] = None,
future_values: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = False,
return_loss: bool = True,
return_dict: Optional[bool] = None,
):
...
def generate(
self,
past_values: torch.Tensor,
observed_mask: Optional[torch.Tensor] = None,
) -> SamplePatchTSMixerPredictionOutput:
"""
Generate sequences of sample predictions from a model with a probability distribution head.
Args:
past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
Past values of the time series that serves as context in order to predict the future.
observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
in `[0, 1]`:
- 1 for values that are **observed**,
- 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
Return:
[`SamplePatchTSMixerPredictionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
number of samples, prediction_length, num_input_channels)`.
"""
num_parallel_samples = self.num_parallel_samples
outputs = self(
past_values=past_values,
future_values=None,
observed_mask=observed_mask,
output_hidden_states=False,
)
distribution = self.distribution_output.distribution(
outputs.prediction_outputs, loc=outputs.loc, scale=outputs.scale
)
samples = [distribution.sample() for _ in range(num_parallel_samples)]
samples = torch.stack(samples, dim=1)
return SamplePatchTSMixerPredictionOutput(sequences=samples)
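# Illustrative sketch of probabilistic forecasting; a distribution head requires `loss="nll"`:
#   >>> config = PatchTSMixerConfig(loss="nll", distribution_output="student_t", num_parallel_samples=20)
#   >>> model = PatchTSMixerForPrediction(config)
#   >>> out = model.generate(past_values=torch.randn(2, config.context_length, config.num_input_channels))
#   >>> out.sequences.shape                                # (batch, num_samples, prediction_length, num_channels)
#   torch.Size([2, 20, 16, 1])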
@dataclass
class PatchTSMixerForTimeSeriesClassificationOutput(ModelOutput):
"""
Output type of [`PatchTSMixerForTimeSeriesClassificationOutput`].
Args:
prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
Prediction output from the classification head.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
Backbone embeddings before passing through the head.
hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
Total loss.
"""
loss: Optional[torch.FloatTensor] = None
prediction_outputs: torch.FloatTensor = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class PatchTSMixerForTimeSeriesClassification(PatchTSMixerPreTrainedModel):
r"""
`PatchTSMixer` for classification application.
Args:
config (`PatchTSMixerConfig`, *required*):
Configuration.
Returns:
`None`.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__(config)
self.model = PatchTSMixerModel(config)
self.head = PatchTSMixerLinearHead(
config=config,
)
self.use_return_dict = config.use_return_dict
if config.scaling in ["std", "mean", True]:
self.inject_scale = InjectScalerStatistics4D(d_model=config.d_model, num_patches=config.num_patches)
else:
self.inject_scale = None
if config.post_init:
self.post_init()
@add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=PatchTSMixerForTimeSeriesClassificationOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
past_values: torch.Tensor,
target_values: torch.Tensor = None,
output_hidden_states: Optional[bool] = False,
return_loss: bool = True,
return_dict: Optional[bool] = None,
):
"""
Perform forward pass of the PatchTSMixerForTimeSeriesClassification model.
Args:
past_values (`torch.Tensor`):
Tensor of past input values.
target_values (`torch.Tensor`, *optional*):
Tensor of target values for training.
output_hidden_states (`bool`, *optional*):
Whether to output hidden states.
return_loss (`bool`, *optional*):
Whether to return the loss.
return_dict (`bool`, *optional*):
Whether to return a dictionary of outputs.
Returns:
Depending on the `return_dict` setting, returns either a dictionary of outputs or directly the outputs.
"""
r"""
target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
`(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target
values of the time series, that serve as labels for the model. The `target_values` is what the
Transformer needs during training to learn to output, given the `past_values`. Note that this is NOT
required for a pretraining task.
For a forecasting task, the shape is `(batch_size, target_len, num_input_channels)`. Even if we want
to forecast only specific channels by setting the indices in the `prediction_channel_indices` parameter,
pass the target data with all channels, as channel filtering for both prediction and target will be
manually applied before the loss computation.
For a classification task, it has a shape of `(batch_size,)`.
For a regression task, it has a shape of `(batch_size, num_targets)`.
return_loss (`bool`, *optional*):
Whether to return the loss in the `forward` call.
Returns:
"""
loss = torch.nn.CrossEntropyLoss()
return_dict = return_dict if return_dict is not None else self.use_return_dict
model_output = self.model(
past_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if isinstance(model_output, tuple):
model_output = PatchTSMixerModelOutput(*model_output)
if self.inject_scale is not None:
model_output.last_hidden_state = self.inject_scale(
model_output.last_hidden_state,
loc=model_output.loc,
scale=model_output.scale,
)
y_hat = self.head(model_output.last_hidden_state)
if target_values is not None and return_loss is True:
loss_val = loss(y_hat, target_values)
else:
loss_val = None
if not return_dict:
return tuple(
v
for v in [
loss_val,
y_hat,
model_output.last_hidden_state,
model_output.hidden_states,
]
)
return PatchTSMixerForTimeSeriesClassificationOutput(
loss=loss_val,
prediction_outputs=y_hat,
last_hidden_state=model_output.last_hidden_state,
hidden_states=model_output.hidden_states,
)
@dataclass
class PatchTSMixerForRegressionOutput(ModelOutput):
"""
Output type of [`PatchTSMixerForRegression`].
Args:
regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
Prediction output from the regression head.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
Backbone embeddings before passing through the head.
hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
Total loss.
"""
loss: Optional[torch.FloatTensor] = None
regression_outputs: torch.FloatTensor = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class InjectScalerStatistics4D(nn.Module):
def __init__(self, d_model: int, num_patches: int, expansion: int = 2):
super().__init__()
self.inverse_trans_expansion = nn.Linear(d_model + 2, expansion * d_model)
self.inverse_trans_compression = nn.Linear(expansion * d_model, d_model)
self.map_scale_expansion = nn.Linear(2, 2 * expansion)
self.map_scale_compression = nn.Linear(2 * expansion, 2)
self.num_patches = num_patches
def forward(self, inputs: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor):
"""
Args:
inputs (`torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`)
loc (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`)
scale (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`)
Returns:
`torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`
"""
mean = loc.transpose(-1, -2)
mean = mean.unsqueeze(-2)
mean = mean.repeat(1, 1, self.num_patches, 1)
stdev = scale.transpose(-1, -2)
stdev = stdev.unsqueeze(-2)
stdev = stdev.repeat(1, 1, self.num_patches, 1)
concat_stats = torch.cat([mean, stdev], dim=-1)
concat_stats = self.map_scale_expansion(concat_stats)
concat_stats = self.map_scale_compression(concat_stats)
inputs = torch.cat([inputs, concat_stats], dim=-1)
inputs = self.inverse_trans_expansion(inputs)
inputs = self.inverse_trans_compression(inputs)
return inputs
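A quick shape check of `InjectScalerStatistics4D` with dummy sizes (illustrative only): `loc` and `scale` arrive as `(batch_size, 1, num_input_channels)`, are broadcast to every patch, and the result keeps the shape of the input embeddings.

```python
import torch
from transformers.models.patchtsmixer.modeling_patchtsmixer import InjectScalerStatistics4D

batch_size, num_channels, num_patches, d_model = 2, 3, 5, 16
inject = InjectScalerStatistics4D(d_model=d_model, num_patches=num_patches)

hidden = torch.randn(batch_size, num_channels, num_patches, d_model)
loc = torch.zeros(batch_size, 1, num_channels)    # per-channel mean from the scaler
scale = torch.ones(batch_size, 1, num_channels)   # per-channel std from the scaler

out = inject(hidden, loc=loc, scale=scale)
print(out.shape)  # torch.Size([2, 3, 5, 16]) -- same shape as the input embeddings
```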
class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
r"""
`PatchTSMixer` for regression application.
Args:
config (`PatchTSMixerConfig`, *required*):
Configuration.
Returns:
`None`.
"""
def __init__(self, config: PatchTSMixerConfig):
super().__init__(config)
self.model = PatchTSMixerModel(config)
self.loss = config.loss
self.distribution_output = config.distribution_output
self.use_return_dict = config.use_return_dict
self.num_parallel_samples = config.num_parallel_samples
if config.loss == "mse":
self.distribution_output = None
else:
distribution_output_map = {
"student_t": StudentTOutput,
"normal": NormalOutput,
"negative_binomial": NegativeBinomialOutput,
}
output_class = distribution_output_map.get(config.distribution_output)
if output_class is not None:
self.distribution_output = output_class(dim=config.num_targets)
else:
raise ValueError(f"Unknown distribution output {config.distribution_output}")
if config.scaling in ["std", "mean", True]:
self.inject_scale = InjectScalerStatistics4D(d_model=config.d_model, num_patches=config.num_patches)
else:
self.inject_scale = None
self.head = PatchTSMixerLinearHead(
config=config,
distribution_output=self.distribution_output,
)
if config.post_init:
self.post_init()
@add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=PatchTSMixerForRegressionOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
past_values: torch.Tensor,
target_values: torch.Tensor = None,
output_hidden_states: Optional[bool] = False,
return_loss: bool = True,
return_dict: Optional[bool] = None,
):
"""
Perform the forward pass of the PatchTSMixerForRegression model.
Args:
past_values (`torch.Tensor`):
Past values of the time series, used as model input.
target_values (`torch.Tensor`, *optional*):
Target values used to compute the loss. Defaults to None.
output_hidden_states (`bool`, *optional*):
Whether to output hidden states. Defaults to False.
return_loss (`bool`):
Whether to return the loss. Defaults to True.
return_dict (`bool`, *optional*):
Whether to return a dictionary of outputs. Defaults to None.
Returns:
The outputs, in the form determined by `return_dict`.
"""
...
def generate(
self,
past_values: torch.Tensor,
) -> SamplePatchTSMixerRegressionOutput:
"""
Generate sequences of sample predictions from a model with a probability distribution head.
Args:
past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
Past values of the time series that serves as context in order to predict the target values.
Return:
[`SamplePatchTSMixerRegressionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
number of samples, num_targets)`.
"""
num_parallel_samples = self.num_parallel_samples
outputs = self(
past_values=past_values,
target_values=None,
output_hidden_states=False,
)
distribution = self.distribution_output.distribution(outputs.regression_outputs)
samples = [
distribution.sample() for _ in range(num_parallel_samples)
]
samples = torch.stack(samples, dim=1).view(-1, num_parallel_samples, self.config.num_targets)
return SamplePatchTSMixerRegressionOutput(sequences=samples)
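The regression variant follows the same sampling pattern; a minimal sketch with illustrative values (again a distributional loss is required so the head returns distribution parameters):

```python
import torch
from transformers import PatchTSMixerConfig, PatchTSMixerForRegression

config = PatchTSMixerConfig(
    context_length=32,
    num_input_channels=3,
    num_targets=2,
    loss="nll",
    distribution_output="normal",
    num_parallel_samples=10,
)
model = PatchTSMixerForRegression(config)

past_values = torch.randn(4, config.context_length, config.num_input_channels)
with torch.no_grad():
    out = model.generate(past_values=past_values)
print(out.sequences.shape)  # torch.Size([4, 10, 2]) -- (batch, samples, num_targets)
```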
.\models\patchtsmixer\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_patchtsmixer": [
"PATCHTSMIXER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"PatchTSMixerConfig",
]
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_patchtsmixer"] = [
"PATCHTSMIXER_PRETRAINED_MODEL_ARCHIVE_LIST",
"PatchTSMixerPreTrainedModel",
"PatchTSMixerModel",
"PatchTSMixerForPretraining",
"PatchTSMixerForPrediction",
"PatchTSMixerForTimeSeriesClassification",
"PatchTSMixerForRegression",
]
if TYPE_CHECKING:
from .configuration_patchtsmixer import (
PATCHTSMIXER_PRETRAINED_CONFIG_ARCHIVE_MAP,
PatchTSMixerConfig,
)
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_patchtsmixer import (
PATCHTSMIXER_PRETRAINED_MODEL_ARCHIVE_LIST,
PatchTSMixerForPrediction,
PatchTSMixerForPretraining,
PatchTSMixerForRegression,
PatchTSMixerForTimeSeriesClassification,
PatchTSMixerModel,
PatchTSMixerPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\patchtst\configuration_patchtst.py
"""PatchTST model configuration"""
from typing import List, Optional, Union
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
PATCHTST_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"ibm/patchtst-base": "https://huggingface.co/ibm/patchtst-base/resolve/main/config.json",
}
class PatchTSTConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`PatchTSTModel`]. It is used to instantiate a
PatchTST model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the
[ibm/patchtst](https://huggingface.co/ibm/patchtst) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
```
>>> from transformers import PatchTSTConfig, PatchTSTModel
>>> # Initializing a PatchTST configuration with 12 time steps for prediction
>>> configuration = PatchTSTConfig(prediction_length=12)
>>> # Randomly initializing a model (with random weights) from the configuration
>>> model = PatchTSTModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "patchtst"
attribute_map = {
"hidden_size": "d_model",
"num_attention_heads": "num_attention_heads",
"num_hidden_layers": "num_hidden_layers",
}
def __init__(
self,
num_input_channels: int = 1,
context_length: int = 32,
distribution_output: str = "student_t",
loss: str = "mse",
patch_length: int = 1,
patch_stride: int = 1,
num_hidden_layers: int = 3,
d_model: int = 128,
num_attention_heads: int = 4,
share_embedding: bool = True,
channel_attention: bool = False,
ffn_dim: int = 512,
norm_type: str = "batchnorm",
norm_eps: float = 1e-05,
attention_dropout: float = 0.0,
dropout: float = 0.0,
positional_dropout: float = 0.0,
path_dropout: float = 0.0,
ff_dropout: float = 0.0,
bias: bool = True,
activation_function: str = "gelu",
pre_norm: bool = True,
positional_encoding_type: str = "sincos",
use_cls_token: bool = False,
init_std: float = 0.02,
share_projection: bool = True,
scaling: Optional[Union[str, bool]] = "std",
do_mask_input: Optional[bool] = None,
mask_type: str = "random",
random_mask_ratio: float = 0.5,
num_forecast_mask_patches: Optional[Union[List[int], int]] = [2],
channel_consistent_masking: Optional[bool] = False,
unmasked_channel_indices: Optional[List[int]] = None,
mask_value: int = 0,
pooling_type: str = "mean",
head_dropout: float = 0.0,
prediction_length: int = 24,
num_targets: int = 1,
output_range: Optional[List] = None,
num_parallel_samples: int = 100,
**kwargs,
):
self.context_length = context_length
self.num_input_channels = num_input_channels
self.loss = loss
self.distribution_output = distribution_output
self.num_parallel_samples = num_parallel_samples
self.d_model = d_model
self.num_attention_heads = num_attention_heads
self.ffn_dim = ffn_dim
self.num_hidden_layers = num_hidden_layers
self.dropout = dropout
self.attention_dropout = attention_dropout
self.share_embedding = share_embedding
self.channel_attention = channel_attention
self.norm_type = norm_type
self.norm_eps = norm_eps
self.positional_dropout = positional_dropout
self.path_dropout = path_dropout
self.ff_dropout = ff_dropout
self.bias = bias
self.activation_function = activation_function
self.pre_norm = pre_norm
self.positional_encoding_type = positional_encoding_type
self.use_cls_token = use_cls_token
self.init_std = init_std
self.scaling = scaling
self.patch_length = patch_length
self.patch_stride = patch_stride
self.do_mask_input = do_mask_input
self.mask_type = mask_type
self.random_mask_ratio = random_mask_ratio
self.num_forecast_mask_patches = num_forecast_mask_patches
self.channel_consistent_masking = channel_consistent_masking
self.unmasked_channel_indices = unmasked_channel_indices
self.mask_value = mask_value
self.pooling_type = pooling_type
self.head_dropout = head_dropout
self.share_projection = share_projection
self.prediction_length = prediction_length
self.num_parallel_samples = num_parallel_samples
self.num_targets = num_targets
self.output_range = output_range
super().__init__(**kwargs)
.\models\patchtst\modeling_patchtst.py
""" PyTorch PatchTST model."""
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
from torch import nn
from ...activations import ACT2CLS
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
from ...utils import ModelOutput, add_start_docstrings, logging
from .configuration_patchtst import PatchTSTConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "PatchTSTConfig"
PATCHTST_PRETRAINED_MODEL_ARCHIVE_LIST = [
"ibm/patchtst-etth1-pretrain",
]
class PatchTSTAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[PatchTSTConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
class PatchTSTBatchNorm(nn.Module):
"""
Compute batch normalization over the sequence length (time) dimension.
"""
def __init__(self, config: PatchTSTConfig):
super().__init__()
self.batchnorm = nn.BatchNorm1d(config.d_model, eps=config.norm_eps)
def forward(self, inputs: torch.Tensor):
"""
Parameters:
inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`):
input for Batch norm calculation
Returns:
`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`
"""
output = inputs.transpose(1, 2)
output = self.batchnorm(output)
return output.transpose(1, 2)
def random_masking(
inputs: torch.Tensor,
mask_ratio: float,
unmasked_channel_indices: list = None,
channel_consistent_masking: bool = False,
mask_value: int = 0,
):
"""random_masking: Mask the input considering the control variables.
Args:
inputs (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, num_features)`):
The input tensor to mask.
mask_ratio (`float`):
Masking ratio applied to mask the input data during random pretraining. It is a number between 0 and 1.
unmasked_channel_indices (list, *optional*):
Indices of channels that will not be masked.
channel_consistent_masking (bool, *optional*, defaults to `False`):
When true, masking will be the same across all channels of a timeseries. Otherwise, masking positions will vary
across channels.
mask_value (int, *optional*, defaults to 0):
Define the value of masked patches for pretraining.
Returns:
`tuple(torch.Tensor)`: inputs_mask, masked input, same shape as input Tensor and mask tensor of shape [bs x c x
n]
"""
if mask_ratio < 0 or mask_ratio >= 1:
raise ValueError(f"Mask ratio {mask_ratio} has to be between 0 and 1.")
batch_size, num_channels, sequence_length, num_features = inputs.shape
device = inputs.device
len_keep = int(sequence_length * (1 - mask_ratio))
if channel_consistent_masking:
noise = torch.rand(batch_size, 1, sequence_length, device=device)
noise = noise.repeat(1, num_channels, 1)
else:
noise = torch.rand(batch_size, num_channels, sequence_length, device=device)
mask = torch.ones(batch_size, num_channels, sequence_length, device=device)
mask[:, :, :len_keep] = 0
ids_shuffle = torch.argsort(noise, dim=-1)
ids_restore = torch.argsort(ids_shuffle, dim=-1)
mask = torch.gather(mask, dim=-1, index=ids_restore)
mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features)
if unmasked_channel_indices is not None:
mask[:, unmasked_channel_indices, :, :] = 0
inputs_mask = inputs.masked_fill(mask.bool(), mask_value)
return inputs_mask, mask[..., 0]
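A quick illustration of `random_masking` on dummy data: with `mask_ratio=0.4` and 10 patches, exactly `10 - int(10 * 0.6) = 4` patches per channel are masked.

```python
import torch
from transformers.models.patchtst.modeling_patchtst import random_masking

patches = torch.randn(2, 3, 10, 8)  # (batch, channels, num_patches, patch_length)
masked, mask = random_masking(patches, mask_ratio=0.4)
print(masked.shape, mask.shape)  # torch.Size([2, 3, 10, 8]) torch.Size([2, 3, 10])
print(mask.sum(dim=-1))          # 4 masked patches for every (batch, channel) pair
```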
def forecast_masking(
inputs: torch.Tensor,
num_forecast_mask_patches: Union[list, int],
unmasked_channel_indices: list = None,
mask_value: int = 0,
):
"""Forecast masking that masks the last K patches where K is from the num_forecast_mask_patches.
If num_forecast_mask_patches is a list, samples in the batch will be randomly masked by numbers defined in the list.
Parameters:
inputs (`torch.Tensor`):
Input of shape `(bs, num_channels, num_patch, patch_length)`
num_forecast_mask_patches (`list`):
Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
unmasked_channel_indices (`list`, *optional*):
Indices of channels that are not masked.
mask_value (`int`, *optional*, defaults to 0):
Values in the masked patches will be filled by `mask_value`.
Returns:
`tuple(torch.Tensor)`: inputs_mask, masked input, same shape as inputs Tensor and Mask tensor of shape `(bs,
num_channels , num_patch)` or `(bs, tsg1, tsg2, num_channels, num_patch)`
"""
if isinstance(num_forecast_mask_patches, int):
num_forecast_mask_patches = [num_forecast_mask_patches]
forecast_mask_ratios = [1 for _ in num_forecast_mask_patches]
batch_size, num_channels, sequence_length, num_features = inputs.shape
mask = torch.zeros(batch_size, num_channels, sequence_length, device=inputs.device)
t_list = []
total_length = 0
total_ratio = sum(forecast_mask_ratios)
for patch_length, ratio in zip(num_forecast_mask_patches, forecast_mask_ratios):
if patch_length <= 0 or patch_length >= sequence_length:
raise ValueError(
f"num_forecast_mask_patches {patch_length} should be greater than 0 and less than total patches."
)
temp_len = int(batch_size * ratio / total_ratio)
t_list.append([patch_length, ratio, temp_len])
total_length += temp_len
t_list = sorted(t_list, key=lambda x: x[2])
if total_length < batch_size:
t_list[0][2] = t_list[0][2] + (batch_size - total_length)
elif total_length > batch_size:
t_list[-1][2] = t_list[-1][2] + (total_length - batch_size)
batch1 = 0
for patch_len, _, temp_len in t_list:
batch2 = batch1 + temp_len
mask[batch1:batch2, :, -patch_len:] = 1
batch1 = batch2
perm = torch.randperm(mask.shape[0])
mask = mask[perm]
mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features)
if unmasked_channel_indices is not None:
mask[:, unmasked_channel_indices, :, :] = 0
inputs_mask = inputs.masked_fill(mask.bool(), mask_value)
return inputs_mask, mask[..., 0]
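And a matching sketch for `forecast_masking` with dummy data: an integer argument masks the same number of trailing patches for every sample.

```python
import torch
from transformers.models.patchtst.modeling_patchtst import forecast_masking

patches = torch.randn(2, 3, 10, 8)  # (batch, channels, num_patches, patch_length)
masked, mask = forecast_masking(patches, num_forecast_mask_patches=3)
print(mask[0, 0])  # tensor([0., 0., 0., 0., 0., 0., 0., 1., 1., 1.]) -- last 3 patches masked
```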
class PatchTSTPatchify(nn.Module):
"""
Patchify the time series sequence into patches of length `patch_length`, taken every `patch_stride` steps.
"""
def __init__(self, config: PatchTSTConfig):
super().__init__()
self.sequence_length = config.context_length
self.patch_length = config.patch_length
self.patch_stride = config.patch_stride
if self.sequence_length <= self.patch_length:
raise ValueError(
f"Sequence length ({self.sequence_length}) has to be greater than the patch length ({self.patch_length})"
)
self.num_patches = (max(self.sequence_length, self.patch_length) - self.patch_length) // self.patch_stride + 1
new_sequence_length = self.patch_length + self.patch_stride * (self.num_patches - 1)
self.sequence_start = self.sequence_length - new_sequence_length
def forward(self, past_values: torch.Tensor):
"""
Parameters:
past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*):
Input for patchification
Returns:
`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
"""
sequence_length = past_values.shape[-2]
if sequence_length != self.sequence_length:
raise ValueError(
f"Input sequence length ({sequence_length}) doesn't match model configuration ({self.sequence_length})."
)
output = past_values[:, self.sequence_start :, :]
output = output.unfold(dimension=-2, size=self.patch_length, step=self.patch_stride)
output = output.transpose(-2, -3).contiguous()
return output
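A worked example of the patch-count formula used above, `num_patches = (max(context_length, patch_length) - patch_length) // patch_stride + 1`, with illustrative sizes:

```python
import torch
from transformers import PatchTSTConfig
from transformers.models.patchtst.modeling_patchtst import PatchTSTPatchify

config = PatchTSTConfig(context_length=32, patch_length=8, patch_stride=4)
patchify = PatchTSTPatchify(config)
print(patchify.num_patches)  # (32 - 8) // 4 + 1 = 7

past_values = torch.randn(2, config.context_length, 3)  # (batch, seq_len, channels)
print(patchify(past_values).shape)  # torch.Size([2, 3, 7, 8])
```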
class PatchTSTMasking(nn.Module):
"""
Class to perform random or forecast masking.
Parameters:
config (`PatchTSTConfig`): model config
Returns:
x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
Masked patched input
mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
Bool tensor indicating True on masked points
"""
def __init__(self, config: PatchTSTConfig):
super().__init__()
self.random_mask_ratio = config.random_mask_ratio
self.channel_consistent_masking = config.channel_consistent_masking
self.mask_type = config.mask_type
self.num_forecast_mask_patches = config.num_forecast_mask_patches
self.unmasked_channel_indices = config.unmasked_channel_indices
self.mask_value = config.mask_value
if self.unmasked_channel_indices is not None:
self.unmasked_channel_indices = sorted(self.unmasked_channel_indices)
def forward(self, patch_input: torch.Tensor):
"""
Parameters:
patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
Patch input
Return:
masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
Masked patched input
mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
Bool tensor indicating True on masked points
"""
if self.mask_type == "random":
masked_input, mask = random_masking(
inputs=patch_input,
mask_ratio=self.random_mask_ratio,
unmasked_channel_indices=self.unmasked_channel_indices,
channel_consistent_masking=self.channel_consistent_masking,
mask_value=self.mask_value,
)
elif self.mask_type == "forecast":
masked_input, mask = forecast_masking(
inputs=patch_input,
num_forecast_mask_patches=self.num_forecast_mask_patches,
unmasked_channel_indices=self.unmasked_channel_indices,
mask_value=self.mask_value,
)
else:
raise ValueError(f"Invalid mask type {self.mask_type}.")
mask = mask.bool()
return masked_input, mask
class PatchTSTEncoderLayer(nn.Module):
"""
PatchTST encoder layer
"""
def __init__(self, config: PatchTSTConfig):
super().__init__()
self.channel_attention = config.channel_attention
self.self_attn = PatchTSTAttention(
embed_dim=config.d_model,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
)
self.dropout_path1 = nn.Dropout(config.path_dropout) if config.path_dropout > 0 else nn.Identity()
if config.norm_type == "batchnorm":
self.norm_sublayer1 = PatchTSTBatchNorm(config)
elif config.norm_type == "layernorm":
self.norm_sublayer1 = nn.LayerNorm(config.d_model, eps=config.norm_eps)
else:
raise ValueError(f"{config.norm_type} is not a supported norm layer type.")
if self.channel_attention:
self.dropout_path2 = nn.Dropout(config.path_dropout) if config.path_dropout > 0 else nn.Identity()
if config.norm_type == "batchnorm":
self.norm_sublayer2 = PatchTSTBatchNorm(config)
elif config.norm_type == "layernorm":
self.norm_sublayer2 = nn.LayerNorm(config.d_model, eps=config.norm_eps)
else:
raise ValueError(f"{config.norm_type} is not a supported norm layer type.")
self.ff = nn.Sequential(
nn.Linear(config.d_model, config.ffn_dim, bias=config.bias),
ACT2CLS[config.activation_function](),
nn.Dropout(config.ff_dropout) if config.ff_dropout > 0 else nn.Identity(),
nn.Linear(config.ffn_dim, config.d_model, bias=config.bias),
)
self.dropout_path3 = nn.Dropout(config.path_dropout) if config.path_dropout > 0 else nn.Identity()
if config.norm_type == "batchnorm":
self.norm_sublayer3 = PatchTSTBatchNorm(config)
elif config.norm_type == "layernorm":
self.norm_sublayer3 = nn.LayerNorm(config.d_model, eps=config.norm_eps)
else:
raise ValueError(f"{config.norm_type} is not a supported norm layer type.")
self.pre_norm = config.pre_norm
class PatchTSTPreTrainedModel(PreTrainedModel):
config_class = PatchTSTConfig
base_model_prefix = "model"
main_input_name = "past_values"
supports_gradient_checkpointing = False
def _init_weights(self, module):
"""
Initialize weights
"""
if isinstance(module, PatchTSTPositionalEncoding):
if self.config.use_cls_token:
nn.init.normal_(module.cls_token, std=0.02)
if self.config.positional_encoding_type == "random":
nn.init.normal_(module.position_enc, mean=0.0, std=0.1)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, PatchTSTBatchNorm):
module.batchnorm.bias.data.zero_()
module.batchnorm.weight.data.fill_(1.0)
elif isinstance(module, (nn.Linear, nn.Conv1d)):
module.weight.data.normal_(mean=0.0, std=self.config.init_std)
if module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (PatchTSTEncoder)):
module.gradient_checkpointing = value
class PatchTSTEmbedding(nn.Module):
def __init__(self, config: PatchTSTConfig):
super().__init__()
self.num_input_channels = config.num_input_channels
self.share_embedding = config.share_embedding
if self.share_embedding:
self.input_embedding = nn.Linear(config.patch_length, config.d_model)
else:
self.input_embedding = nn.ModuleList()
for _ in range(config.num_input_channels):
self.input_embedding.append(nn.Linear(config.patch_length, config.d_model))
def forward(self, patch_input: torch.Tensor):
"""
Parameters:
patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
Patch input for embedding
return:
`torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`
"""
num_input_channels = patch_input.shape[1]
if num_input_channels != self.num_input_channels:
raise ValueError(
f"The defined number of input channels ({self.num_input_channels}) in the config "
f"has to be the same as the number of channels in the batch input ({num_input_channels})"
)
if self.share_embedding:
embeddings = self.input_embedding(patch_input)
else:
embeddings = [self.input_embedding[i](patch_input[:, i, :, :]) for i in range(num_input_channels)]
embeddings = torch.stack(embeddings, dim=1)
return embeddings
class PatchTSTPositionalEncoding(nn.Module):
"""
Class for positional encoding
"""
def __init__(self, config: PatchTSTConfig, num_patches: int):
super().__init__()
self.use_cls_token = config.use_cls_token
self.num_input_channels = config.num_input_channels
if config.use_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, 1, config.d_model))
num_patches += 1
self.position_enc = self._init_pe(config, num_patches)
self.positional_dropout = (
nn.Dropout(config.positional_dropout) if config.positional_dropout > 0 else nn.Identity()
)
@staticmethod
def _init_pe(config: PatchTSTConfig, num_patches: int) -> nn.Parameter:
if config.positional_encoding_type == "random":
position_enc = nn.Parameter(torch.randn(num_patches, config.d_model), requires_grad=True)
elif config.positional_encoding_type == "sincos":
position_enc = torch.zeros(num_patches, config.d_model)
position = torch.arange(0, num_patches).unsqueeze(1)
div_term = torch.exp(torch.arange(0, config.d_model, 2) * -(math.log(10000.0) / config.d_model))
position_enc[:, 0::2] = torch.sin(position * div_term)
position_enc[:, 1::2] = torch.cos(position * div_term)
position_enc = position_enc - position_enc.mean()
position_enc = position_enc / (position_enc.std() * 10)
position_enc = nn.Parameter(position_enc, requires_grad=False)
else:
raise ValueError(
f"{config.positional_encoding_type} is not a valid positional encoder. Available types are 'random' and 'sincos'."
)
return position_enc
def forward(self, patch_input: torch.Tensor):
if self.use_cls_token:
patch_input = self.positional_dropout(patch_input + self.position_enc[1:, :])
cls_token = self.cls_token + self.position_enc[:1, :]
cls_tokens = cls_token.expand(patch_input.shape[0], self.num_input_channels, -1, -1)
hidden_state = torch.cat((cls_tokens, patch_input), dim=2)
else:
hidden_state = self.positional_dropout(patch_input + self.position_enc)
return hidden_state
class PatchTSTEncoder(PatchTSTPreTrainedModel):
"""
PatchTST Encoder
"""
def __init__(self, config: PatchTSTConfig, num_patches: int):
super().__init__(config)
self.gradient_checkpointing = False
self.embedder = PatchTSTEmbedding(config)
self.positional_encoder = PatchTSTPositionalEncoding(config, num_patches)
self.layers = nn.ModuleList([PatchTSTEncoderLayer(config) for i in range(config.num_hidden_layers)])
self.post_init()
def forward(
self,
patch_input: torch.Tensor,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
) -> BaseModelOutput:
"""
Parameters:
patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
Past values of the time series
output_hidden_states (bool, optional): Indicates if hidden states should be outputted.
output_attentions (bool, optional): Indicates if attentions should be outputted.
return:
`BaseModelOutput`
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
patch_input = self.embedder(patch_input)
hidden_state = self.positional_encoder(patch_input)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states = encoder_states + (hidden_state,)
layer_outputs = encoder_layer(hidden_state=hidden_state, output_attentions=output_attentions)
hidden_state = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
return BaseModelOutput(last_hidden_state=hidden_state, hidden_states=encoder_states, attentions=all_attentions)
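To see the encoder end to end on already-patchified input, a small sketch with dummy sizes (values are illustrative; `PatchTSTEncoder` is not part of the public API, so it is imported directly from the modeling module):

```python
import torch
from transformers import PatchTSTConfig
from transformers.models.patchtst.modeling_patchtst import PatchTSTEncoder

config = PatchTSTConfig(num_input_channels=3, patch_length=8, d_model=64)
encoder = PatchTSTEncoder(config, num_patches=7)

patches = torch.randn(2, 3, 7, 8)  # (batch, channels, num_patches, patch_length)
out = encoder(patches)
print(out.last_hidden_state.shape)  # torch.Size([2, 3, 7, 64]) -- patches embedded to d_model
```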
PATCHTST_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`PatchTSTConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@dataclass
class PatchTSTModelOutput(ModelOutput):
"""
Base class for model's outputs, with potential hidden states.
Parameters:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of
the model at the output of each layer plus the optional initial embedding outputs.
mask: (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`, *optional*)
Bool masked tensor indicating which patches are masked
loc: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*)
Mean of the input data (batch_size, sequence_length, num_channels) over the sequence_length
scale: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*)
Std of the input data (batch_size, sequence_length, num_channels) over the sequence_length
patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
Patched input to the Transformer
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
mask: torch.FloatTensor = None
loc: torch.FloatTensor = None
scale: torch.FloatTensor = None
patch_input: torch.FloatTensor = None
@dataclass
class PatchTSTForPretrainingOutput(ModelOutput):
"""
Output type of [`PatchTSTForPretraining`].
This class defines the output structure specifically for the `PatchTSTForPretraining` model, but does not contain any additional fields.
It inherits from `ModelOutput`, which is a base class providing basic fields like `last_hidden_state`, `hidden_states`, etc.
"""
Parameters:
loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
MSE loss.
MSE(均方误差)损失值,仅在提供了`labels`时返回,类型为`torch.FloatTensor`,形状为`(1,)`。
prediction_outputs (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction outputs of the time series modeling heads.
时间序列建模头部的预测输出,类型为`torch.FloatTensor`,形状为`(batch_size, sequence_length, config.vocab_size)`。
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
模型每一层输出的隐藏状态,以及初始嵌入输出的元组,类型为`tuple(torch.FloatTensor)`,形状为`(batch_size, sequence_length, hidden_size)`。
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
注意力权重,在经过注意力 softmax 后得到,用于计算自注意力头部的加权平均,类型为`tuple(torch.FloatTensor)`,形状为`(batch_size, num_heads, sequence_length, sequence_length)`。
@dataclass
class PatchTSTForRegressionOutput(ModelOutput):
"""
Output type of [`PatchTSTForRegression`].
Parameters:
loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
MSE loss.
regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
Regression outputs of the time series modeling heads.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
regression_outputs: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class PatchTSTForPredictionOutput(ModelOutput):
"""
Output type of [`PatchTSTForPrediction`].
"""
Parameters:
loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
MSE loss.
MSE 损失(均方误差损失),当提供 `labels` 时返回,类型为 `torch.FloatTensor`,形状为 `(1,)`。
prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, -1)`):
Prediction outputs of the time series modeling heads.
时间序列建模头的预测输出,类型为 `torch.FloatTensor`,形状为 `(batch_size, prediction_length, -1)`。
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
模型在每层输出的隐藏状态,以及初始嵌入输出的元组。返回条件包括传递 `output_hidden_states=True` 或 `config.output_hidden_states=True`。
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
注意力权重,经过注意力 softmax 后的结果,用于计算自注意力头部的加权平均值。返回条件包括传递 `output_attentions=True` 或 `config.output_attentions=True`。
loc: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*)
Mean of the input data (batch_size, sequence_length, num_channels) over the sequence_length
输入数据的均值(在序列长度上)。类型为 `torch.FloatTensor`,形状为 `(batch_size, 1, num_channels)`。
scale: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*)
Std of the input data (batch_size, sequence_length, num_channels) over the sequence_length
输入数据的标准差(在序列长度上)。类型为 `torch.FloatTensor`,形状为 `(batch_size, 1, num_channels)`。
@dataclass
class PatchTSTForClassificationOutput(ModelOutput):
"""
Output type of [`PatchTSTForClassification`].
Parameters:
loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
Total loss as the sum of the masked language modeling loss and the next sequence prediction
(classification) loss.
prediction_logits (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
Prediction scores of the PatchTST modeling head (scores before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
prediction_logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class SamplePatchTSTOutput(ModelOutput):
"""
Base class for time series model's predictions outputs that contains the sampled values from the chosen
distribution.
Parameters:
sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, num_targets)`):
Sampled values from the chosen distribution.
"""
sequences: torch.FloatTensor = None
def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor:
"""
Computes the negative log likelihood loss from input distribution with respect to target.
"""
return -input.log_prob(target)
def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor:
"""
Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.
"""
if weights is not None:
weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor))
sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0)
return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights
else:
return input_tensor.mean(dim=dim)
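A small numeric check of `weighted_average` with dummy values: positions with zero weight are ignored rather than propagating NaNs.

```python
import torch
from transformers.models.patchtst.modeling_patchtst import weighted_average

values = torch.tensor([[1.0, float("nan"), 3.0]])
weights = torch.tensor([[1.0, 0.0, 1.0]])
print(weighted_average(values, weights=weights, dim=1))  # tensor([2.]) -- (1 + 3) / 2
```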
class PatchTSTStdScaler(nn.Module):
"""
Standardizes features by computing the mean and standard deviation along the first dimension, then normalizes the
data by subtracting the mean and dividing by the standard deviation.
"""
def __init__(self, config: PatchTSTConfig):
super().__init__()
self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-5
def forward(
self, data: torch.Tensor, observed_indicator: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Parameters:
data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
input for Batch norm calculation
observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
Calculating the scale on the observed indicator.
Returns:
tuple of `torch.Tensor` of shapes
(`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
`(batch_size, 1, num_input_channels)`)
"""
denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim)
denominator = denominator.clamp_min(1.0)
loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator
variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator
scale = torch.sqrt(variance + self.minimum_scale)
return (data - loc) / scale, loc, scale
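Shape check for the standard scaler on dummy data: each channel is standardized over the time dimension, and the per-channel `loc`/`scale` keep a singleton time axis.

```python
import torch
from transformers import PatchTSTConfig
from transformers.models.patchtst.modeling_patchtst import PatchTSTStdScaler

scaler = PatchTSTStdScaler(PatchTSTConfig())
data = torch.randn(2, 32, 3)                 # (batch, sequence_length, num_input_channels)
observed = torch.ones_like(data)             # everything observed
scaled, loc, scale = scaler(data, observed)
print(scaled.shape, loc.shape, scale.shape)  # (2, 32, 3) (2, 1, 3) (2, 1, 3)
```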
class PatchTSTMeanScaler(nn.Module):
"""
Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data
accordingly.
"""
def __init__(self, config: PatchTSTConfig):
super().__init__()
self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10
self.default_scale = config.default_scale if hasattr(config, "default_scale") else None
def forward(
self, data: torch.Tensor, observed_indicator: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Parameters:
data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
input for Batch norm calculation
observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
Calculating the scale on the observed indicator.
Returns:
tuple of `torch.Tensor` of shapes
(`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
`(batch_size, 1, num_input_channels)`)
"""
ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True)
num_observed = observed_indicator.sum(self.dim, keepdim=True)
scale = ts_sum / torch.clamp(num_observed, min=1)
if self.default_scale is None:
batch_sum = ts_sum.sum(dim=0)
batch_observations = torch.clamp(num_observed.sum(0), min=1)
default_scale = torch.squeeze(batch_sum / batch_observations)
else:
default_scale = self.default_scale * torch.ones_like(scale)
scale = torch.where(num_observed > 0, scale, default_scale)
scale = torch.clamp(scale, min=self.minimum_scale)
scaled_data = data / scale
if not self.keepdim:
scale = scale.squeeze(dim=self.dim)
return scaled_data, torch.zeros_like(scale), scale
class PatchTSTNOPScaler(nn.Module):
"""
Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data.
"""
def __init__(self, config: PatchTSTConfig):
super().__init__()
self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
def forward(
self, data: torch.Tensor, observed_indicator: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Parameters:
data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
input for Batch norm calculation
Returns:
tuple of `torch.Tensor` of shapes
(`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
`(batch_size, 1, num_input_channels)`)
"""
scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
return data, loc, scale
class PatchTSTScaler(nn.Module):
def __init__(self, config: PatchTSTConfig):
super().__init__()
if config.scaling == "mean" or config.scaling is True:
self.scaler = PatchTSTMeanScaler(config)
elif config.scaling == "std":
self.scaler = PatchTSTStdScaler(config)
else:
self.scaler = PatchTSTNOPScaler(config)
def forward(
self, data: torch.Tensor, observed_indicator: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Parameters:
data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
Input for scaler calculation
observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
Calculating the scale on the observed indicator.
Returns:
tuple of `torch.Tensor` of shapes
(`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
`(batch_size, 1, num_input_channels)`)
"""
data, loc, scale = self.scaler(data, observed_indicator)
return data, loc, scale
@add_start_docstrings(
"The bare PatchTST Model outputting raw hidden-states without any specific head.",
PATCHTST_START_DOCSTRING,
)
class PatchTSTModel(PatchTSTPreTrainedModel):
def __init__(self, config: PatchTSTConfig):
super().__init__(config)
self.scaler = PatchTSTScaler(config)
self.patchifier = PatchTSTPatchify(config)
num_patches = self.patchifier.num_patches
self.do_mask_input = config.do_mask_input
if self.do_mask_input:
self.masking = PatchTSTMasking(config)
else:
self.masking = nn.Identity()
self.encoder = PatchTSTEncoder(config, num_patches=num_patches)
self.post_init()
def forward(
self,
past_values: torch.Tensor,
past_observed_mask: Optional[torch.Tensor] = None,
future_values: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
class PatchTSTMaskPretrainHead(nn.Module):
"""
Pretraining head for mask modelling
"""
def __init__(self, config: PatchTSTConfig):
super().__init__()
self.dropout = nn.Dropout(config.dropout)
self.linear = nn.Linear(config.d_model, config.patch_length)
self.use_cls_token = config.use_cls_token
def forward(self, embedding: torch.Tensor) -> torch.Tensor:
"""
Parameters:
embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
`(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
Embedding from the model
Returns:
`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
`(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True
"""
embedding = self.linear(self.dropout(embedding))
if self.use_cls_token:
embedding = embedding[:, :, 1:, :]
return embedding
@add_start_docstrings(
"The PatchTST for pretrain model.",
PATCHTST_START_DOCSTRING,
)
class PatchTSTForPretraining(PatchTSTPreTrainedModel):
def __init__(self, config: PatchTSTConfig):
super().__init__(config)
config.do_mask_input = True
self.model = PatchTSTModel(config=config)
self.head = PatchTSTMaskPretrainHead(config)
self.post_init()
def forward(
self,
past_values: torch.Tensor,
past_observed_mask: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Dict[str, torch.Tensor]:
"""
Parameters:
past_values (`torch.Tensor`): Tensor containing past values.
past_observed_mask (`Optional[torch.Tensor]`, optional): Mask tensor for observed values.
output_hidden_states (`Optional[bool]`, optional): Whether to output hidden states.
output_attentions (`Optional[bool]`, optional): Whether to output attention weights.
return_dict (`Optional[bool]`, optional): Whether to return a dictionary.
Returns:
Dictionary containing output tensors depending on the model configuration.
"""
class PatchTSTClassificationHead(nn.Module):
def __init__(self, config: PatchTSTConfig):
super().__init__()
self.use_cls_token = config.use_cls_token
self.pooling_type = config.pooling_type
self.flatten = nn.Flatten(start_dim=1)
self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()
self.linear = nn.Linear(config.num_input_channels * config.d_model, config.num_targets)
def forward(self, embedding: torch.Tensor):
"""
Parameters:
embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
`(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
Embedding from the model
Returns:
`torch.Tensor` of shape `(bs, num_targets)`
"""
if self.use_cls_token:
pooled_embedding = embedding[:, :, 0, :]
elif self.pooling_type == "mean":
pooled_embedding = embedding.mean(dim=2)
elif self.pooling_type == "max":
pooled_embedding = embedding.max(dim=2).values
else:
raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet")
pooled_embedding = self.flatten(pooled_embedding)
output = self.linear(self.dropout(pooled_embedding))
return output
@add_start_docstrings(
"The PatchTST for classification model.",
PATCHTST_START_DOCSTRING,
)
class PatchTSTForClassification(PatchTSTPreTrainedModel):
def __init__(self, config: PatchTSTConfig):
super().__init__(config)
if config.do_mask_input:
logger.warning("Setting `do_mask_input` parameter to False.")
config.do_mask_input = False
self.model = PatchTSTModel(config)
self.head = PatchTSTClassificationHead(config)
self.post_init()
def forward(
self,
past_values: torch.Tensor,
target_values: torch.Tensor = None,
past_observed_mask: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
"""
Forward pass of the PatchTSTForClassification model.
Parameters:
- past_values: Tensor of past input values.
- target_values: Optional tensor of target values.
- past_observed_mask: Optional boolean mask for observed values.
- output_hidden_states: Optional boolean to output hidden states.
- output_attentions: Optional boolean to output attentions.
- return_dict: Optional boolean to return a dictionary.
Returns:
- Depending on configurations, returns classification predictions.
"""
pass
class PatchTSTPredictionHead(nn.Module):
def __init__(self, config: PatchTSTConfig, num_patches, distribution_output=None):
super().__init__()
self.share_projection = config.share_projection
self.num_input_channels = config.num_input_channels
self.use_cls_token = config.use_cls_token
self.pooling_type = config.pooling_type
if self.pooling_type or self.use_cls_token:
head_dim = config.d_model
else:
head_dim = config.d_model * num_patches
if not self.share_projection:
self.projections = nn.ModuleList()
self.dropouts = nn.ModuleList()
self.flattens = nn.ModuleList()
for i in range(self.num_input_channels):
self.flattens.append(nn.Flatten(start_dim=2))
if distribution_output is None:
self.projections.append(nn.Linear(head_dim, config.prediction_length))
else:
self.projections.append(distribution_output.get_parameter_projection(head_dim))
self.dropouts.append(nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity())
else:
self.flatten = nn.Flatten(start_dim=2)
if distribution_output is None:
self.projection = nn.Linear(head_dim, config.prediction_length)
else:
self.projection = distribution_output.get_parameter_projection(head_dim)
self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()
def forward(self, embedding: torch.Tensor):
"""
Parameters:
embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
`(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
Embedding from the model
Returns:
`torch.Tensor` of shape `(bs, forecast_len, num_channels)`
"""
if self.use_cls_token:
pooled_embedding = embedding[:, :, 0, :]
else:
if self.pooling_type == "mean":
pooled_embedding = embedding.mean(dim=2)
elif self.pooling_type == "max":
pooled_embedding = embedding.max(dim=2).values
else:
pooled_embedding = embedding
if not self.share_projection:
output = []
for i in range(self.num_input_channels):
pooled_embedding = self.flattens[i](pooled_embedding[:, i, :])
pooled_embedding = self.dropouts[i](pooled_embedding)
pooled_embedding = self.projections[i](pooled_embedding)
output.append(pooled_embedding)
output = torch.stack(output, dim=1)
else:
pooled_embedding = self.flatten(pooled_embedding)
pooled_embedding = self.dropout(pooled_embedding)
output = self.projection(pooled_embedding)
if isinstance(output, tuple):
output = tuple(z.transpose(2, 1) for z in output)
else:
output = output.transpose(2, 1)
return output
@add_start_docstrings(
"The PatchTST for prediction model.",
PATCHTST_START_DOCSTRING,
)
class PatchTSTForPrediction(PatchTSTPreTrainedModel):
def __init__(self, config: PatchTSTConfig):
super().__init__(config)
if config.do_mask_input:
logger.warning("Setting `do_mask_input` parameter to False.")
config.do_mask_input = False
self.model = PatchTSTModel(config)
if config.loss == "mse":
self.distribution_output = None
else:
if config.distribution_output == "student_t":
self.distribution_output = StudentTOutput(dim=config.prediction_length)
elif config.distribution_output == "normal":
self.distribution_output = NormalOutput(dim=config.prediction_length)
elif config.distribution_output == "negative_binomial":
self.distribution_output = NegativeBinomialOutput(dim=config.prediction_length)
else:
raise ValueError(f"Unknown distribution output {config.distribution_output}")
self.head = PatchTSTPredictionHead(
config, self.model.patchifier.num_patches, distribution_output=self.distribution_output
)
self.post_init()
def forward(
self,
past_values: torch.Tensor,
past_observed_mask: Optional[torch.Tensor] = None,
future_values: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
...
def generate(
self,
past_values: torch.Tensor,
past_observed_mask: Optional[torch.Tensor] = None,
...
) -> SamplePatchTSTOutput:
"""
Generate sequences of sample predictions from a model with a probability distribution head.
Parameters:
past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
Past values of the time series that serves as context in order to predict the future.
past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
in `[0, 1]`:
- 1 for values that are **observed**,
- 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
Return:
[`SamplePatchTSTOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of
samples, prediction_length, 1)` or `(batch_size, number of samples, prediction_length, num_input_channels)`
for multivariate predictions.
"""
num_parallel_samples = self.config.num_parallel_samples
outputs = self(
past_values=past_values,
future_values=None,
past_observed_mask=past_observed_mask,
output_hidden_states=False,
)
if self.distribution_output:
distribution = self.distribution_output.distribution(
outputs.prediction_outputs, loc=outputs.loc, scale=outputs.scale
)
samples = [distribution.sample() for _ in range(num_parallel_samples)]
samples = torch.stack(samples, dim=1)
else:
samples = outputs.prediction_outputs.unsqueeze(1)
return SamplePatchTSTOutput(sequences=samples)
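A hedged usage sketch of the probabilistic `generate` path above; the configuration values are arbitrary and an untrained model is used purely to show the output shape:

```python
import torch
from transformers import PatchTSTConfig, PatchTSTForPrediction

config = PatchTSTConfig(
    num_input_channels=3,
    context_length=32,
    patch_length=8,
    prediction_length=16,
    loss="nll",                          # anything other than "mse" selects a distribution head
    distribution_output="student_t",
    num_parallel_samples=50,
)
model = PatchTSTForPrediction(config).eval()

past_values = torch.randn(4, config.context_length, config.num_input_channels)
samples = model.generate(past_values=past_values).sequences
print(samples.shape)  # (batch_size, num_parallel_samples, prediction_length, num_input_channels)
```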
class PatchTSTRegressionHead(nn.Module):
"""
Regression head
"""
def __init__(self, config: PatchTSTConfig, distribution_output=None):
super().__init__()
self.y_range = config.output_range
self.use_cls_token = config.use_cls_token
self.pooling_type = config.pooling_type
self.distribution_output = distribution_output
head_dim = config.num_input_channels * config.d_model
self.flatten = nn.Flatten(start_dim=1)
self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()
if distribution_output is None:
self.projection = nn.Linear(head_dim, config.num_targets)
else:
self.projection = distribution_output.get_parameter_projection(head_dim)
def forward(self, embedding: torch.Tensor):
"""
Parameters:
embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
`(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
Embedding from the model
Returns:
`torch.Tensor` of shape `(bs, output_dim)`
"""
if self.use_cls_token:
pooled_embedding = embedding[:, :, 0, :]
elif self.pooling_type == "mean":
pooled_embedding = embedding.mean(dim=2)
elif self.pooling_type == "max":
pooled_embedding = embedding.max(dim=2).values
else:
raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet")
pooled_embedding = self.dropout(self.flatten(pooled_embedding))
output = self.projection(pooled_embedding)
        # When no distribution head is used and an output range is configured,
        # squash the logits into that range with a sigmoid.
        if (self.distribution_output is None) & (self.y_range is not None):
            output = torch.sigmoid(output) * (self.y_range[1] - self.y_range[0]) + self.y_range[0]
return output
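The `y_range` branch above squashes unbounded logits into `config.output_range` with a sigmoid; a one-line check with an arbitrary range:

```python
import torch

y_range = (-1.0, 1.0)  # stand-in for config.output_range
logits = torch.tensor([-5.0, 0.0, 5.0])
scaled = torch.sigmoid(logits) * (y_range[1] - y_range[0]) + y_range[0]
print(scaled)  # tensor([-0.9866, 0.0000, 0.9866]), always inside (-1, 1)
```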
@add_start_docstrings(
"The PatchTST for regression model.",
PATCHTST_START_DOCSTRING,
)
class PatchTSTForRegression(PatchTSTPreTrainedModel):
"""
PatchTST for regression model.
Inherits from PatchTSTPreTrainedModel.
"""
def __init__(self, config: PatchTSTConfig):
super().__init__(config)
if config.do_mask_input:
logger.warning("Setting `do_mask_input` parameter to False.")
config.do_mask_input = False
self.model = PatchTSTModel(config)
if config.loss == "mse":
self.distribution_output = None
else:
if config.distribution_output == "student_t":
self.distribution_output = StudentTOutput(dim=config.num_targets)
elif config.distribution_output == "normal":
self.distribution_output = NormalOutput(dim=config.num_targets)
elif config.distribution_output == "negative_binomial":
self.distribution_output = NegativeBinomialOutput(dim=config.num_targets)
else:
raise ValueError(f"Unknown distribution output {config.distribution_output}")
self.head = PatchTSTRegressionHead(config, self.distribution_output)
self.post_init()
def forward(
self,
past_values: torch.Tensor,
target_values: torch.Tensor = None,
past_observed_mask: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
    ):
        ...
def generate(
self,
past_values: torch.Tensor,
past_observed_mask: Optional[torch.Tensor] = None,
) -> SamplePatchTSTOutput:
"""
从具有概率分布输出头的模型生成样本预测序列。
Parameters:
past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
时间序列的过去值,用作上下文以预测未来。
past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
布尔掩码,指示哪些 `past_values` 是观察到的,哪些是缺失的。掩码的取值范围为 `[0, 1]`:
- 1 表示 **观察到** 的值,
- 0 表示 **缺失** 的值(即被 NaN 替换为零)。
Return:
[`SamplePatchTSTOutput`],输出的 `sequences` 张量形状为 `(batch_size, number of samples, num_targets)`。
"""
num_parallel_samples = self.config.num_parallel_samples
outputs = self(
past_values=past_values,
target_values=None,
past_observed_mask=past_observed_mask,
output_hidden_states=False,
)
distribution = self.distribution_output.distribution(outputs.regression_outputs)
samples = [distribution.sample() for _ in range(num_parallel_samples)]
samples = torch.stack(samples, dim=1).view(-1, num_parallel_samples, self.config.num_targets)
return SamplePatchTSTOutput(sequences=samples)
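The same pattern for the regression model, again with arbitrary configuration values; note the flatter output shape `(batch_size, num_parallel_samples, num_targets)`:

```python
import torch
from transformers import PatchTSTConfig, PatchTSTForRegression

config = PatchTSTConfig(
    num_input_channels=3,
    context_length=32,
    patch_length=8,
    num_targets=2,
    loss="nll",
    distribution_output="student_t",
    num_parallel_samples=20,
)
model = PatchTSTForRegression(config).eval()
past_values = torch.randn(4, config.context_length, config.num_input_channels)
print(model.generate(past_values=past_values).sequences.shape)  # torch.Size([4, 20, 2])
```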
.\models\patchtst\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_patchtst": [
"PATCHTST_PRETRAINED_CONFIG_ARCHIVE_MAP",
"PatchTSTConfig",
],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_patchtst"] = [
"PATCHTST_PRETRAINED_MODEL_ARCHIVE_LIST",
"PatchTSTModel",
"PatchTSTPreTrainedModel",
"PatchTSTForPrediction",
"PatchTSTForPretraining",
"PatchTSTForRegression",
"PatchTSTForClassification",
]
if TYPE_CHECKING:
from .configuration_patchtst import PATCHTST_PRETRAINED_CONFIG_ARCHIVE_MAP, PatchTSTConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_patchtst import (
PATCHTST_PRETRAINED_MODEL_ARCHIVE_LIST,
PatchTSTForClassification,
PatchTSTForPrediction,
PatchTSTForPretraining,
PatchTSTForRegression,
PatchTSTModel,
PatchTSTPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
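Because of the `_LazyModule` indirection, importing the package is cheap and submodules are only materialized on first attribute access; a quick way to observe this (assuming `torch` is installed so the model classes can resolve):

```python
import transformers.models.patchtst as patchtst

# The configuration resolves without loading the torch-heavy modeling file.
config = patchtst.PatchTSTConfig()

# The modeling module is imported lazily, on first attribute access.
print(patchtst.PatchTSTForPrediction.__name__)
```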
.\models\pegasus\configuration_pegasus.py
""" PEGASUS model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google/pegasus-large": "https://huggingface.co/google/pegasus-large/resolve/main/config.json",
}
class PegasusConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`PegasusModel`]. It is used to instantiate an
PEGASUS model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the PEGASUS
[google/pegasus-large](https://huggingface.co/google/pegasus-large) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```
>>> from transformers import PegasusConfig, PegasusModel
>>> # Initializing a PEGASUS google/pegasus-large style configuration
>>> configuration = PegasusConfig()
>>> # Initializing a model (with random weights) from the google/pegasus-large style configuration
>>> model = PegasusModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "pegasus"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
def __init__(
self,
vocab_size=50265,
max_position_embeddings=1024,
encoder_layers=12,
encoder_ffn_dim=4096,
encoder_attention_heads=16,
decoder_layers=12,
decoder_ffn_dim=4096,
decoder_attention_heads=16,
encoder_layerdrop=0.0,
decoder_layerdrop=0.0,
use_cache=True,
is_encoder_decoder=True,
activation_function="gelu",
d_model=1024,
dropout=0.1,
attention_dropout=0.0,
activation_dropout=0.0,
init_std=0.02,
decoder_start_token_id=0,
scale_embedding=False,
pad_token_id=0,
eos_token_id=1,
forced_eos_token_id=1,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.decoder_ffn_dim = decoder_ffn_dim
self.decoder_layers = decoder_layers
self.decoder_attention_heads = decoder_attention_heads
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.activation_function = activation_function
self.init_std = init_std
self.encoder_layerdrop = encoder_layerdrop
self.decoder_layerdrop = decoder_layerdrop
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.scale_embedding = scale_embedding
super().__init__(
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
decoder_start_token_id=decoder_start_token_id,
forced_eos_token_id=forced_eos_token_id,
**kwargs,
)
@property
def num_attention_heads(self) -> int:
return self.encoder_attention_heads
@property
def hidden_size(self) -> int:
return self.d_model
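Thanks to `attribute_map` and the two properties above, the generic names simply alias the BART-style fields; with the defaults from `__init__`:

```python
from transformers import PegasusConfig

config = PegasusConfig()
print(config.hidden_size, config.d_model)                           # 1024 1024
print(config.num_attention_heads, config.encoder_attention_heads)  # 16 16
```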
.\models\pegasus\convert_pegasus_tf_to_pytorch.py
PATTERNS = [
["memory_attention", "encoder_attn"],
["attention", "attn"],
["/", "."],
[".LayerNorm.gamma", "_layer_norm.weight"],
[".LayerNorm.beta", "_layer_norm.bias"],
["r.layer_", "r.layers."],
["output_proj", "out_proj"],
["ffn.dense_1.", "fc2."],
["ffn.dense.", "fc1."],
["ffn_layer_norm", "final_layer_norm"],
["kernel", "weight"],
["encoder_layer_norm.", "encoder.layer_norm."],
["decoder_layer_norm.", "decoder.layer_norm."],
["embeddings.weights", "shared.weight"],
]
def rename_state_dict_key(k):
for pegasus_name, hf_name in PATTERNS:
k = k.replace(pegasus_name, hf_name)
return k
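A toy application of the rename patterns, assuming the `rename_state_dict_key` defined above is in scope (the TF-style key is made up for illustration):

```python
tf_key = "decoder/layer_3/memory_attention/output_proj/kernel"
print(rename_state_dict_key(tf_key))
# decoder.layers.3.encoder_attn.out_proj.weight
```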
def convert_pegasus(tf_weights: dict, cfg_updates: dict) -> PegasusForConditionalGeneration:
cfg_kwargs = DEFAULTS.copy()
cfg_kwargs.update(cfg_updates)
cfg = PegasusConfig(**cfg_kwargs)
torch_model = PegasusForConditionalGeneration(cfg)
sd = torch_model.model.state_dict()
mapping = {}
for k, v in tf_weights.items():
new_k = rename_state_dict_key(k)
if new_k not in sd:
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
if "dense" in k or "proj" in new_k:
v = v.T
mapping[new_k] = torch.tensor(v, dtype=sd[new_k].dtype)
assert v.shape == sd[new_k].shape, f"{new_k}, {k}, {v.shape}, {sd[new_k].shape}"
mapping["shared.weight"][cfg.pad_token_id] = torch.zeros_like(mapping["shared.weight"][cfg.pad_token_id + 1])
mapping["encoder.embed_tokens.weight"] = mapping["shared.weight"]
mapping["decoder.embed_tokens.weight"] = mapping["shared.weight"]
empty_biases = {k: torch.zeros_like(v) for k, v in sd.items() if k.endswith("bias") and k not in mapping}
mapping.update(**empty_biases)
missing, extra = torch_model.model.load_state_dict(mapping, strict=False)
unexpected_missing = [
k for k in missing if k not in ["encoder.embed_positions.weight", "decoder.embed_positions.weight"]
]
assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
assert extra == [], f"no matches found for the following tf keys {extra}"
return torch_model
import argparse
import os
from pathlib import Path
from typing import Dict
import tensorflow as tf
import torch
from tqdm import tqdm
from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
from transformers.models.pegasus.configuration_pegasus import DEFAULTS, task_specific_params
def get_tf_weights_as_numpy(path="./ckpt/aeslc/model.ckpt-32000") -> Dict:
init_vars = tf.train.list_variables(path)
tf_weights = {}
ignore_name = ["Adafactor", "global_step"]
for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"):
skip_key = any(pat in name for pat in ignore_name)
if skip_key:
continue
array = tf.train.load_variable(path, name)
tf_weights[name] = array
return tf_weights
def convert_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str):
dataset = Path(ckpt_path).parent.name
desired_max_model_length = task_specific_params[f"summarization_{dataset}"]["max_position_embeddings"]
tok = PegasusTokenizer.from_pretrained("sshleifer/pegasus", model_max_length=desired_max_model_length)
assert tok.model_max_length == desired_max_model_length
tok.save_pretrained(save_dir)
tf_weights = get_tf_weights_as_numpy(ckpt_path)
cfg_updates = task_specific_params[f"summarization_{dataset}"]
if dataset == "large":
cfg_updates["task_specific_params"] = task_specific_params
torch_model = convert_pegasus(tf_weights, cfg_updates)
torch_model.save_pretrained(save_dir)
sd = torch_model.state_dict()
sd.pop("model.decoder.embed_positions.weight")
sd.pop("model.encoder.embed_positions.weight")
torch.save(sd, Path(save_dir) / "pytorch_model.bin")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("tf_ckpt_path", type=str, help="passed to tf.train.list_variables")
parser.add_argument("save_dir", default=None, type=str, help="Path to the output PyTorch model.")
args = parser.parse_args()
if args.save_dir is None:
dataset = Path(args.tf_ckpt_path).parent.name
args.save_dir = os.path.join("pegasus", dataset)
convert_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir)
.\models\pegasus\modeling_flax_pegasus.py
import math
import random
from functools import partial
from typing import Callable, Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxCausalLMOutputWithCrossAttentions,
FlaxSeq2SeqLMOutput,
FlaxSeq2SeqModelOutput,
)
from ...modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
add_start_docstrings_to_model_forward,
append_call_sample_docstring,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from ...utils import add_start_docstrings, logging, replace_return_docstrings
from .configuration_pegasus import PegasusConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "google/pegasus-large"
_CONFIG_FOR_DOC = "PegasusConfig"
PEGASUS_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a Flax Linen
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
"""
Parameters:
config ([`PegasusConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
[`~FlaxPreTrainedModel.to_bf16`].
"""
PEGASUS_INPUTS_DOCSTRING = r"""
Args:
input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
            [What are input IDs?](../glossary#input-ids)
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
            [What are attention masks?](../glossary#attention-mask)
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
            [What are decoder input IDs?](../glossary#decoder-input-ids)
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the
paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
PEGASUS_ENCODE_INPUTS_DOCSTRING = r"""
Args:
input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens to be encoded. These indices are obtained using a tokenizer, typically
from a list of input strings. Each index corresponds to a token in the vocabulary.
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. This mask tensor has shape `(batch_size,
sequence_length)`, where each value is either 1 (token is not masked) or 0 (token is masked, typically
because it's padding).
position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence token in the position embeddings matrix. These indices range
from 0 to `config.max_position_embeddings - 1`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. If `True`, the returned output
will include attention tensors from all layers of the model.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. If `True`, the returned output will include
hidden states from all layers of the model.
return_dict (`bool`, *optional*):
Whether or not to return a `utils.ModelOutput` object instead of a plain tuple. If `True`, the output
will be encapsulated in a structured object that includes additional metadata.
"""
# PEGASUS_DECODE_INPUTS_DOCSTRING is a raw string documenting the inputs of the Pegasus decode function.
PEGASUS_DECODE_INPUTS_DOCSTRING = r"""
Args:
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`):
            Indices of decoder input sequence tokens in the vocabulary.
            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
            [What are decoder input IDs?](../glossary#decoder-input-ids)
        encoder_outputs (`tuple(tuple(jnp.ndarray)`):
            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            [What are attention masks?](../glossary#attention-mask)
        decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
            If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the
            paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
        decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence token in the position embeddings. Selected in the
            range `[0, config.max_position_embeddings - 1]`.
        past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
            Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for
            fast auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Shift the input token IDs one position to the right
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """
    Shift input ids one token to the right.
    """
    # Create an all-zero array with the same shape as input_ids
    shifted_input_ids = jnp.zeros_like(input_ids)
    # Copy every row of input_ids, minus its last column, into columns 1..end of shifted_input_ids
    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
    # Set the first column of every row to decoder_start_token_id
    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)
    # Replace any -100 positions (label-ignore markers) with pad_token_id
    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids
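A tiny check of the shift, assuming the `shift_tokens_right` above is in scope; any `-100` label markers remaining after the shift would be replaced by the pad id (none survive here because the marker sat in the last column):

```python
import jax.numpy as jnp

input_ids = jnp.array([[5, 6, 7, 1], [8, 9, 1, -100]])
print(shift_tokens_right(input_ids, pad_token_id=0, decoder_start_token_id=0))
# [[0 5 6 7]
#  [0 8 9 1]]
```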
# Copied from transformers.models.marian.modeling_flax_marian.create_sinusoidal_positions
# Build a sinusoidal position-encoding matrix
def create_sinusoidal_positions(n_pos, dim):
    # Compute the raw angle table from position and dimension indices
    position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
    sentinel = dim // 2 + dim % 2
    out = np.zeros_like(position_enc)
    # First half of each row: sine of the even-indexed angle columns
    out[:, 0:sentinel] = np.sin(position_enc[:, 0::2])
    # Second half of each row: cosine of the odd-indexed angle columns
    out[:, sentinel:] = np.cos(position_enc[:, 1::2])
    return jnp.array(out)
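Shape-wise, the table has one row per position and splits each row into a sine half and a cosine half; assuming the function above is in scope:

```python
pos_table = create_sinusoidal_positions(n_pos=64, dim=16)
print(pos_table.shape)   # (64, 16)
print(pos_table[0, :8])  # sin terms for position 0 (all zeros)
print(pos_table[0, 8:])  # cos terms for position 0 (all ones)
```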
# 从 transformers.models.bart.modeling_flax_bart.FlaxBartAttention 复制而来,将 Bart 替换为 Pegasus
# 定义 Pegasus 注意力机制的模块
class FlaxPegasusAttention(nn.Module):
config: PegasusConfig
embed_dim: int
num_heads: int
dropout: float = 0.0
causal: bool = False
bias: bool = True
dtype: jnp.dtype = jnp.float32 # 计算时的数据类型
def setup(self) -> None:
# 计算每个头部的维度
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {self.num_heads})."
)
# 定义用于计算的全连接层,初始化方式为正态分布
dense = partial(
nn.Dense,
self.embed_dim,
use_bias=self.bias,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
# 分别为查询、键、值和输出定义全连接层
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
self.out_proj = dense()
# 定义 dropout 层
self.dropout_layer = nn.Dropout(rate=self.dropout)
# 如果 causal 为 True,则创建一个因果遮罩
if self.causal:
self.causal_mask = make_causal_mask(
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
)
# 将隐藏状态按头部数和头部维度进行分割
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
# 将隐藏状态的头部合并回原始形状
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
@nn.compact
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slightly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
"""
# 检测是否初始化,通过检查缓存数据是否存在来判断
is_initialized = self.has_variable("cache", "cached_key")
# 获取或初始化缓存的键(key)和值(value)
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
# 获取或初始化缓存索引,用于追踪当前缓存的位置
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
# 获取当前缓存的形状信息,包括批次维度、最大长度、头数和每个头的深度
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
# 使用新的一维空间切片更新键(key)和值(value)缓存
cur_index = cache_index.value
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
key = lax.dynamic_update_slice(cached_key.value, key, indices)
value = lax.dynamic_update_slice(cached_value.value, value, indices)
cached_key.value = key
cached_value.value = value
# 更新缓存索引,增加已更新的缓存向量数量
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
# 为缓存的解码器自注意力生成因果遮罩:当前查询位置只能注意到已生成和缓存的键位置,而不是剩余的零元素。
pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
)
attention_mask = combine_masks(pad_mask, attention_mask)
# 返回更新后的键(key)、值(value)和注意力遮罩
return key, value, attention_mask
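A standalone illustration of the cache write performed above: one new key vector is placed into a preallocated buffer at the current index with `lax.dynamic_update_slice` (toy shapes, independent of the Flax module state):

```python
import jax.numpy as jnp
from jax import lax

max_length, num_heads, head_dim = 8, 2, 4
cached_key = jnp.zeros((1, max_length, num_heads, head_dim))  # preallocated cache buffer
new_key = jnp.ones((1, 1, num_heads, head_dim))               # key for the current token
cur_index = 3                                                  # tokens already cached

updated = lax.dynamic_update_slice(cached_key, new_key, (0, cur_index, 0, 0))
print(updated[0, :, 0, 0])  # [0. 0. 0. 1. 0. 0. 0. 0.]
```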
# 从transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayer复制而来,将MBart改为Pegasus
class FlaxPegasusEncoderLayer(nn.Module):
# Pegasus模型的配置
config: PegasusConfig
# 计算中使用的数据类型,默认为32位浮点数
dtype: jnp.dtype = jnp.float32
# 设置方法,初始化编码器层的组件
def setup(self) -> None:
# 嵌入维度等于模型配置中的d_model参数
self.embed_dim = self.config.d_model
# Pegasus自注意力机制
self.self_attn = FlaxPegasusAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.encoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
# 层归一化层,用于自注意力输出
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
# 用于自注意力输出的Dropout层
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
# 激活函数,根据配置中的激活函数选择对应的激活函数
self.activation_fn = ACT2FN[self.config.activation_function]
# 激活函数后的Dropout层
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
# 第一个全连接层,用于前馈神经网络
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
# 第二个全连接层,用于前馈神经网络的输出
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
# 最终的层归一化层,用于前馈神经网络的输出
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
# 调用方法,执行编码器层的前向计算
def __call__(
self,
hidden_states: jnp.ndarray,
attention_mask: jnp.ndarray,
output_attentions: bool = True,
deterministic: bool = True,
) -> Tuple[jnp.ndarray]:
# 保存残差连接的输入
residual = hidden_states
# 对输入进行自注意力输出的层归一化处理
hidden_states = self.self_attn_layer_norm(hidden_states)
# 执行自注意力机制计算,并返回计算后的隐藏状态及注意力权重
hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask)
# 应用Dropout层,以减少过拟合风险
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
# 添加残差连接
hidden_states = residual + hidden_states
# 保存残差连接的输入
residual = hidden_states
# 对输入进行最终输出的层归一化处理
hidden_states = self.final_layer_norm(hidden_states)
# 使用激活函数处理第一个全连接层的输出
hidden_states = self.activation_fn(self.fc1(hidden_states))
# 应用激活函数后的Dropout层
hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
# 执行第二个全连接层的计算
hidden_states = self.fc2(hidden_states)
# 应用Dropout层,以减少过拟合风险
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
# 添加残差连接
hidden_states = residual + hidden_states
# 构建输出元组,包含最终的隐藏状态
outputs = (hidden_states,)
# 如果需要输出注意力权重,则添加到输出元组中
if output_attentions:
outputs += (attn_weights,)
# 返回最终的输出元组
return outputs
# 从transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayerCollection复制而来,将Bart改为Pegasus
class FlaxPegasusEncoderLayerCollection(nn.Module):
# Pegasus模型的配置
config: PegasusConfig
# 计算中使用的数据类型,默认为32位浮点数
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
# 设置方法,初始化编码器层的集合
def setup(self):
# 创建编码器层的列表,每一层为FlaxPegasusEncoderLayer的实例
self.layers = [
FlaxPegasusEncoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.encoder_layers)
]
# 编码器层的Dropout概率,由配置文件中的encoder_layerdrop定义
self.layerdrop = self.config.encoder_layerdrop
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# 如果输出注意力权重,则初始化一个空元组用于存储所有注意力权重
all_attentions = () if output_attentions else None
# 如果输出隐藏状态,则初始化一个空元组用于存储所有隐藏状态
all_hidden_states = () if output_hidden_states else None
# 遍历每个编码器层
for encoder_layer in self.layers:
# 如果需要输出隐藏状态,则将当前隐藏状态添加到 all_hidden_states 中
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# 使用 LayerDrop 方法来控制是否跳过当前层
dropout_probability = random.uniform(0, 1)
if not deterministic and (dropout_probability < self.layerdrop): # 如果随机数小于 layerdrop,跳过当前层
layer_outputs = (None, None)
else:
# 调用当前编码器层的前向传播方法
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions,
deterministic,
)
# 更新隐藏状态为当前层的输出的第一个元素
hidden_states = layer_outputs[0]
# 如果需要输出注意力权重,则将当前层的注意力权重添加到 all_attentions 中
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
# 如果需要输出隐藏状态,则将最终隐藏状态添加到 all_hidden_states 中
if output_hidden_states:
all_hidden_states += (hidden_states,)
# 构建模型输出对象,根据 return_dict 决定返回格式
outputs = (hidden_states, all_hidden_states, all_attentions)
# 如果 return_dict 为 False,则以元组形式返回 outputs 中非空的部分
if not return_dict:
return tuple(v for v in outputs if v is not None)
# 如果 return_dict 为 True,则构建 FlaxBaseModelOutput 对象返回
return FlaxBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
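The loop above implements LayerDrop: during training, each encoder layer is skipped with probability `encoder_layerdrop`. A stripped-down sketch of that control flow in plain Python:

```python
import random

def run_layers(hidden, layers, layerdrop=0.1, deterministic=False):
    for layer in layers:
        # With probability `layerdrop`, drop the whole layer for this forward pass.
        if not deterministic and random.uniform(0, 1) < layerdrop:
            continue
        hidden = layer(hidden)
    return hidden

# Toy layers that each add 1; on average ~10% of them are skipped while training.
print(run_layers(0, [lambda x: x + 1] * 10, layerdrop=0.1, deterministic=False))
```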
# 从transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer复制到Pegasus,定义了FlaxPegasusDecoderLayer类
class FlaxPegasusDecoderLayer(nn.Module):
# 类变量:使用PegasusConfig配置
config: PegasusConfig
# 类变量:数据类型为jnp.float32
dtype: jnp.dtype = jnp.float32
# 初始化方法,无返回值
def setup(self) -> None:
# 设置self.embed_dim为配置中的d_model值,即模型的维度大小
self.embed_dim = self.config.d_model
# 初始化self.self_attn为FlaxPegasusAttention对象,用于自注意力机制
self.self_attn = FlaxPegasusAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
causal=True,
dtype=self.dtype,
)
# 初始化self.dropout_layer为Dropout层,用于随机失活
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
# 设置激活函数为配置中指定的激活函数,并初始化激活函数的随机失活层
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
# 初始化self.self_attn_layer_norm为LayerNorm层,用于自注意力机制的归一化
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
# 初始化self.encoder_attn为FlaxPegasusAttention对象,用于编码器注意力机制
self.encoder_attn = FlaxPegasusAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
# 初始化self.encoder_attn_layer_norm为LayerNorm层,用于编码器注意力机制的归一化
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
# 初始化self.fc1为全连接层,用于第一个前馈神经网络
self.fc1 = nn.Dense(
self.config.decoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
# 初始化self.fc2为全连接层,用于第二个前馈神经网络
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
# 初始化self.final_layer_norm为LayerNorm层,用于最终的归一化
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
    # __call__ defines the forward pass of the decoder layer
    def __call__(
        self,
        hidden_states: jnp.ndarray,  # input hidden states
        attention_mask: jnp.ndarray,  # attention mask
        encoder_hidden_states: Optional[jnp.ndarray] = None,  # encoder hidden states (optional)
        encoder_attention_mask: Optional[jnp.ndarray] = None,  # encoder attention mask (optional)
        init_cache: bool = False,  # whether to initialize the autoregressive cache
        output_attentions: bool = True,  # whether to return attention weights
        deterministic: bool = True,  # whether to disable dropout (inference mode)
    ) -> Tuple[jnp.ndarray]:
# 将输入的 hidden_states 复制给 residual,用于后续的残差连接
residual = hidden_states
# 对 hidden_states 进行 Layer Normalization
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
# 使用 self_attn 层处理 hidden_states,包括注意力计算和可能的缓存初始化
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
)
# 应用 dropout 层
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
# 添加残差连接
hidden_states = residual + hidden_states
# Cross-Attention Block
cross_attn_weights = None
# 如果提供了 encoder_hidden_states,则执行以下操作
if encoder_hidden_states is not None:
# 将输入的 hidden_states 复制给 residual,用于后续的残差连接
residual = hidden_states
# 对 hidden_states 进行 Layer Normalization
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# 使用 encoder_attn 层处理 hidden_states 和 encoder_hidden_states
# 包括注意力计算和可能的 attention_mask 应用
hidden_states, cross_attn_weights = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
)
# 应用 dropout 层
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
# 添加残差连接
hidden_states = residual + hidden_states
# Fully Connected
# 将输入的 hidden_states 复制给 residual,用于后续的残差连接
residual = hidden_states
# 对 hidden_states 进行 Layer Normalization
hidden_states = self.final_layer_norm(hidden_states)
# 应用激活函数 activation_fn
hidden_states = self.activation_fn(self.fc1(hidden_states))
# 应用 activation_dropout_layer
hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
# 应用全连接层 fc2
hidden_states = self.fc2(hidden_states)
# 应用 dropout 层
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
# 添加残差连接
hidden_states = residual + hidden_states
# 准备输出,初始化为包含 hidden_states 的元组 outputs
outputs = (hidden_states,)
# 如果需要输出 attention weights,则将 self_attn_weights 和 cross_attn_weights 添加到 outputs 中
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
# 返回 outputs
return outputs
# 从transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayerCollection复制的代码,将Bart更改为Pegasus
class FlaxPegasusDecoderLayerCollection(nn.Module):
# Pegasus模型的配置
config: PegasusConfig
# 计算的数据类型
dtype: jnp.dtype = jnp.float32 # 计算的数据类型为浮点数(32位)
def setup(self):
# 创建Pegasus解码器层的集合
self.layers = [
FlaxPegasusDecoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.decoder_layers)
]
# 解码器层的LayerDrop概率
self.layerdrop = self.config.decoder_layerdrop
def __call__(
self,
hidden_states,
attention_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# 存储所有隐藏状态(如果需要返回)
all_hidden_states = () if output_hidden_states else None
# 存储所有自注意力权重(如果需要返回)
all_self_attns = () if output_attentions else None
# 存储所有跨注意力权重(如果需要返回),仅在同时输出注意力并且存在编码器隐藏状态时才存储
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
# 对每个解码器层进行迭代
for decoder_layer in self.layers:
if output_hidden_states:
# 如果需要输出隐藏状态,则将当前隐藏状态添加到存储中
all_hidden_states += (hidden_states,)
# 添加LayerDrop(参见https://arxiv.org/abs/1909.11556进行描述)
# 随机采样一个Dropout概率
dropout_probability = random.uniform(0, 1)
# 如果是非确定性计算并且随机概率小于LayerDrop阈值,则将输出设为None
if not deterministic and (dropout_probability < self.layerdrop):
layer_outputs = (None, None, None)
else:
# 否则,调用当前解码器层进行前向传播计算
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
output_attentions=output_attentions,
deterministic=deterministic,
)
# 更新隐藏状态为当前解码器层的输出的第一个元素
hidden_states = layer_outputs[0]
# 如果需要输出注意力权重,则将当前解码器层的自注意力权重添加到存储中
if output_attentions:
all_self_attns += (layer_outputs[1],)
# 如果存在编码器隐藏状态,则将当前解码器层的跨注意力权重添加到存储中
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
# 如果需要输出最终隐藏状态,则将最终隐藏状态添加到存储中
if output_hidden_states:
all_hidden_states += (hidden_states,)
# 将所有需要输出的结果存储在outputs列表中
outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]
# 如果不需要以字典形式返回结果,则返回输出列表中的非None元素
if not return_dict:
return tuple(v for v in outputs if v is not None)
# 以包含过去和跨注意力权重的形式返回FlaxBaseModelOutputWithPastAndCrossAttentions对象
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
class FlaxPegasusEncoder(nn.Module):
# Pegasus模型的配置
config: PegasusConfig
# 嵌入令牌的层
embed_tokens: nn.Embed
# 计算的数据类型
dtype: jnp.dtype = jnp.float32 # 计算的数据类型为浮点数(32位)
# 在类初始化方法中设置dropout层,根据配置中的dropout率创建实例
def setup(self):
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
# 从配置中获取词嵌入的维度
embed_dim = self.config.d_model
# 从配置中获取填充标记的索引
self.padding_idx = self.config.pad_token_id
# 从配置中获取源序列的最大位置数
self.max_source_positions = self.config.max_position_embeddings
# 如果配置中设置了缩放词嵌入,则计算缩放因子为词嵌入维度的平方根,否则为1.0
self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
# 创建正弦位置编码,并赋值给self.embed_positions
self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
# 创建FlaxPegasusEncoderLayerCollection实例,用于后续编码器层的处理
self.layers = FlaxPegasusEncoderLayerCollection(self.config, self.dtype)
# 创建LayerNorm实例,用于对隐藏状态进行归一化处理
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
# 在实例调用时处理输入数据,执行编码器操作
def __call__(
self,
input_ids,
attention_mask,
position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
# 获取输入数据的形状信息
input_shape = input_ids.shape
# 将input_ids重新reshape为二维张量
input_ids = input_ids.reshape(-1, input_shape[-1])
# 使用嵌入词表将input_ids转换为嵌入向量,并乘以缩放因子
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
# 嵌入位置信息
embed_pos = jnp.take(self.embed_positions, position_ids, axis=0)
# 显式地将位置信息embed_pos转换为与inputs_embeds相同的数据类型
embed_pos = embed_pos.astype(inputs_embeds.dtype)
# 将输入嵌入向量与位置嵌入向量相加形成隐藏状态
hidden_states = inputs_embeds + embed_pos
# 对隐藏状态应用dropout操作
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
# 将隐藏状态传递给编码器层进行处理
outputs = self.layers(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# 获取编码器层输出中的最后一个隐藏状态
last_hidden_state = outputs[0]
# 对最后一个隐藏状态应用LayerNorm归一化
last_hidden_state = self.layer_norm(last_hidden_state)
# 如果需要输出所有隐藏状态,则更新outputs中的hidden_states
hidden_states = None
if output_hidden_states:
hidden_states = outputs[1]
hidden_states = hidden_states[:-1] + (last_hidden_state,)
# 如果不返回字典,则根据需要重新组织输出的元组
if not return_dict:
outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
return tuple(v for v in outputs if v is not None)
# 返回FlaxBaseModelOutput对象,包括最后一个隐藏状态、所有隐藏状态和注意力权重(如果有)
return FlaxBaseModelOutput(
last_hidden_state=last_hidden_state,
hidden_states=hidden_states,
attentions=outputs.attentions,
)
class FlaxPegasusDecoder(nn.Module):
config: PegasusConfig # Pegasus model configuration
embed_tokens: nn.Embed # Embedding tokens for input sequence
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.dropout_layer = nn.Dropout(rate=self.config.dropout) # Dropout layer initialization
embed_dim = self.config.d_model # Dimension of the embedding
self.padding_idx = self.config.pad_token_id # Padding token index from configuration
self.max_target_positions = self.config.max_position_embeddings # Maximum target positions
self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0 # Embedding scale factor
self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
# Create sinusoidal positional embeddings
self.layers = FlaxPegasusDecoderLayerCollection(self.config, self.dtype)
# Layers of the Pegasus decoder
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
# Layer normalization initialization
def __call__(
self,
input_ids,
attention_mask,
position_ids,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
input_shape = input_ids.shape
input_ids = input_ids.reshape(-1, input_shape[-1])
# Reshape input_ids to flatten the sequence dimensions
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
# Embed input tokens and scale embeddings
# embed positions
positions = jnp.take(self.embed_positions, position_ids, axis=0)
# Retrieve positional embeddings based on position_ids
positions = positions.astype(inputs_embeds.dtype)
# Explicitly cast positions to match inputs_embeds dtype
hidden_states = inputs_embeds + positions
# Combine token embeddings with positional embeddings
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
# Apply dropout to hidden states
outputs = self.layers(
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Pass hidden states through decoder layers
last_hidden_state = outputs[0]
# Retrieve the last hidden state from the outputs
last_hidden_state = self.layer_norm(last_hidden_state)
# Apply layer normalization to the last hidden state
hidden_states = None
if output_hidden_states:
hidden_states = outputs[1]
hidden_states = hidden_states[:-1] + (last_hidden_state,)
# Concatenate previous hidden states with the current one
if not return_dict:
outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
return tuple(v for v in outputs if v is not None)
# Return outputs as a tuple without the return_dict format
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=last_hidden_state,
hidden_states=hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
# Return outputs as FlaxBaseModelOutputWithPastAndCrossAttentions object
# 从 transformers.models.bart.modeling_flax_bart.FlaxBartModule 复制并修改为 Pegasus
class FlaxPegasusModule(nn.Module):
config: PegasusConfig # Pegasus 模型的配置对象
dtype: jnp.dtype = jnp.float32 # 计算时使用的数据类型
def setup(self):
# 创建共享的嵌入层,用于编码器和解码器
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std), # 使用正态分布初始化嵌入层
dtype=self.dtype,
)
# 初始化编码器和解码器
self.encoder = FlaxPegasusEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
self.decoder = FlaxPegasusDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
def _get_encoder_module(self):
return self.encoder # 返回编码器模块
def _get_decoder_module(self):
return self.decoder # 返回解码器模块
def __call__(
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
position_ids,
decoder_position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
# 调用编码器得到编码器的输出
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
# 调用解码器得到解码器的输出,传入编码器的隐藏状态和注意力掩码
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
if not return_dict:
return decoder_outputs + encoder_outputs # 如果不返回字典,则返回所有输出
# 返回序列到序列模型的输出
return FlaxSeq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel):
config_class = PegasusConfig # Pegasus 预训练模型的配置类
base_model_prefix: str = "model" # 基础模型的前缀名称
module_class: nn.Module = None
def __init__(
self,
config: PegasusConfig,
input_shape: Tuple[int] = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
# 使用配置和数据类型初始化模块对象
module = self.module_class(config=config, dtype=dtype, **kwargs)
# 调用父类的构造函数,传入配置、模块对象、输入形状、种子、数据类型和初始化标志
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# 初始化输入张量
input_ids = jnp.zeros(input_shape, dtype="i4")
# 创建与 input_ids 形状相同的全 1 张量作为注意力掩码
attention_mask = jnp.ones_like(input_ids)
# 将 decoder_input_ids 初始化为与 input_ids 相同的张量
decoder_input_ids = input_ids
# 创建与 input_ids 形状相同的全 1 张量作为解码器的注意力掩码
decoder_attention_mask = jnp.ones_like(input_ids)
# 获取 batch_size 和 sequence_length
batch_size, sequence_length = input_ids.shape
# 创建位置编码,将序列长度广播到每个样本的每个位置
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
# 解码器的位置编码与编码器相同
decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
# 分割随机数生成器 rng 为 params_rng 和 dropout_rng
params_rng, dropout_rng = jax.random.split(rng)
# 创建随机数字典 rngs,包含 params_rng 和 dropout_rng
rngs = {"params": params_rng, "dropout": dropout_rng}
# 使用模块对象的初始化方法初始化模型参数
random_params = self.module.init(
rngs,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
position_ids,
decoder_position_ids,
)["params"]
# 如果提供了初始参数 params,则使用它替换随机初始化的部分参数
if params is not None:
# 展平并解冻参数字典
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
# 将缺失的参数键从随机参数复制到提供的参数中
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set() # 清空缺失键集合
# 冻结并返回更新后的参数字典
return freeze(unflatten_dict(params))
else:
# 如果未提供初始参数,则直接返回随机初始化的参数字典
return random_params
def init_cache(self, batch_size, max_length, encoder_outputs):
r"""
Args:
batch_size (`int`):
fast auto-regressive decoding 使用的 batch_size。定义了初始化缓存时的批处理大小。
max_length (`int`):
自动回归解码的最大可能长度。定义了初始化缓存时的序列长度。
encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
`encoder_outputs` 包含 (`last_hidden_state`, *可选*: `hidden_states`, *可选*: `attentions`)。
`last_hidden_state` 的形状为 `(batch_size, sequence_length, hidden_size)`,*可选*: 编码器最后一层输出的隐藏状态。
在解码器的交叉注意力中使用。
初始化缓存函数,用于预先设置解码器的缓存状态。
"""
# 初始化解码器的输入标识,全部为1的数组,形状为 (batch_size, max_length)
decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
# 初始化解码器的注意力掩码,与 decoder_input_ids 形状相同的全1数组
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
# 初始化解码器的位置标识,将一个广播数组设置为与 decoder_input_ids 形状相同的位置标识
decoder_position_ids = jnp.broadcast_to(
jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
)
# 定义内部函数 _decoder_forward,用于调用解码器模块并返回结果
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
decoder_module = module._get_decoder_module()
return decoder_module(
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
**kwargs,
)
# 使用模型的初始化方法初始化变量,并设置解码器相关参数
init_variables = self.module.init(
jax.random.PRNGKey(0),
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
init_cache=True,
method=_decoder_forward, # 仅需调用解码器以初始化缓存
)
# 返回解除冻结后的变量中的缓存部分
return unfreeze(init_variables["cache"])
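A hedged sketch of how `init_cache` fits together with `encode` and generation; the checkpoint follows the example docstrings in this file, and the high-level `generate` call (from `FlaxGenerationMixin`) drives the cached `decode` loop internally:

```python
from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration

model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")
tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")

inputs = tokenizer("My friends are cool but they eat too many carbs.", return_tensors="np")
encoder_outputs = model.encode(**inputs)

# Preallocate key/value buffers for up to 32 auto-regressively generated tokens.
past_key_values = model.init_cache(inputs["input_ids"].shape[0], 32, encoder_outputs)
# `past_key_values` is then threaded through successive `decode` calls (together with
# `decoder_position_ids`); `model.generate` performs exactly this loop under the hood:
summary_ids = model.generate(inputs["input_ids"], max_length=32).sequences
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
```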
    @add_start_docstrings(PEGASUS_ENCODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=PegasusConfig)
    def encode(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
r"""
Returns:
Example:
```
>>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration
>>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")
>>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
>>> text = "My friends are cool but they eat too many carbs."
>>> inputs = tokenizer(text, max_length=1024, return_tensors="np")
>>> encoder_outputs = model.encode(**inputs)
```"""
# 根据传入的参数设置输出注意力机制
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# 根据传入的参数设置输出隐藏状态
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# 根据传入的参数设置返回字典类型
return_dict = return_dict if return_dict is not None else self.config.return_dict
# 如果 attention_mask 为 None,则创建一个全为1的掩码与 input_ids 形状相同
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
# 如果 position_ids 为 None,则根据 input_ids 形状创建对应的位置ID张量
if position_ids is None:
batch_size, sequence_length = input_ids.shape
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
# 如果 dropout_rng 不为 None,则将其作为 "dropout" 的随机数生成器
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
# 定义一个内部函数 _encoder_forward,用于调用编码器模块的前向方法
def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
encode_module = module._get_encoder_module()
return encode_module(input_ids, attention_mask, position_ids, **kwargs)
# 调用 self.module.apply 方法执行模型的前向传播
return self.module.apply(
{"params": params or self.params}, # 使用传入的参数或者默认参数来执行前向传播
input_ids=jnp.array(input_ids, dtype="i4"), # 将 input_ids 转换为 JAX 数组
attention_mask=jnp.array(attention_mask, dtype="i4"), # 将 attention_mask 转换为 JAX 数组
position_ids=jnp.array(position_ids, dtype="i4"), # 将 position_ids 转换为 JAX 数组
output_attentions=output_attentions, # 控制是否输出注意力机制
output_hidden_states=output_hidden_states, # 控制是否输出隐藏状态
return_dict=return_dict, # 控制是否以字典形式返回结果
deterministic=not train, # 是否处于训练模式
rngs=rngs, # 随机数生成器的字典
method=_encoder_forward, # 指定执行的方法
)
@add_start_docstrings(PEGASUS_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=PegasusConfig)
def decode(
self,
decoder_input_ids,
encoder_outputs,
encoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
    ):
        ...
    @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
    # __call__ runs the full encoder-decoder forward pass of the pretrained model
    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# 如果 output_attentions 不为 None,则使用该值;否则使用 self.config.output_attentions 的值
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# 如果 output_hidden_states 不为 None,则使用该值;否则使用 self.config.output_hidden_states 的值
return_dict = return_dict if return_dict is not None else self.config.return_dict
# 如果 return_dict 不为 None,则使用该值;否则使用 self.config.return_dict 的值
# 准备编码器的输入
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
# 如果 attention_mask 为 None,则创建一个与 input_ids 形状相同的全为 1 的注意力掩码
if position_ids is None:
batch_size, sequence_length = input_ids.shape
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
# 如果 position_ids 为 None,则根据 input_ids 的形状创建位置编码序列
# 准备解码器的输入
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(
input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
)
# 如果 decoder_input_ids 为 None,则将 input_ids 向右移动一个位置,并使用配置中的特殊标记进行填充
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
# 如果 decoder_attention_mask 为 None,则创建一个与 decoder_input_ids 形状相同的全为 1 的注意力掩码
if decoder_position_ids is None:
batch_size, sequence_length = decoder_input_ids.shape
decoder_position_ids = jnp.broadcast_to(
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
)
# 如果 decoder_position_ids 为 None,则根据 decoder_input_ids 的形状创建位置编码序列
# 处理可能需要的任何随机数生成器
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
# 如果 dropout_rng 不为 None,则创建一个包含 dropout_rng 的随机数生成器字典;否则创建一个空字典
return self.module.apply(
{"params": params or self.params},
input_ids=jnp.array(input_ids, dtype="i4"),
attention_mask=jnp.array(attention_mask, dtype="i4"),
position_ids=jnp.array(position_ids, dtype="i4"),
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
)
# 应用 self.module 中的函数:
# - 使用 params 或 self.params 中的参数
# - 将输入转换为 jnp.ndarray 格式并传递给相应参数
# - 设置是否输出注意力权重和隐藏状态
# - 设置是否以字典格式返回结果
# - 设置是否处于确定性计算模式
# - 传递随机数生成器字典 rngs
# 使用装饰器为 FlaxPegasusModel 类添加文档字符串,描述其作为 Pegasus 模型的基本转换器,输出原始隐藏状态而无顶部特定头部。
# PEGASUS_START_DOCSTRING 中包含 Pegasus 模型的起始文档字符串。
@add_start_docstrings(
"The bare Pegasus Model transformer outputting raw hidden-states without any specific head on top.",
PEGASUS_START_DOCSTRING,
)
# 定义 FlaxPegasusModel 类,继承自 FlaxPegasusPreTrainedModel,具有 PegasusConfig 类型的配置参数。
class FlaxPegasusModel(FlaxPegasusPreTrainedModel):
config: PegasusConfig
# 计算中使用的数据类型,默认为 jnp.float32
dtype: jnp.dtype = jnp.float32
# 模型类别设置为 FlaxPegasusModule
module_class = FlaxPegasusModule
# 调用函数,为 FlaxPegasusModel 类附加示例调用文档字符串,使用 _CHECKPOINT_FOR_DOC 和 FlaxSeq2SeqModelOutput。
append_call_sample_docstring(FlaxPegasusModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
# 从 transformers.models.bart.modeling_flax_bart.FlaxBartForConditionalGenerationModule 复制代码,修改为 Pegasus 模型
class FlaxPegasusForConditionalGenerationModule(nn.Module):
config: PegasusConfig
dtype: jnp.dtype = jnp.float32
# 偏置初始化器,使用 jax.nn.initializers.zeros 初始化
bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros
# Module setup: build the backbone and the LM head
def setup(self):
# Instantiate the FlaxPegasusModule backbone with this config and dtype
self.model = FlaxPegasusModule(config=self.config, dtype=self.dtype)
# The LM head is an nn.Dense projecting to the vocabulary size (self.model.shared.num_embeddings)
self.lm_head = nn.Dense(
self.model.shared.num_embeddings,
use_bias=False,
dtype=self.dtype,
# Initialize the kernel with a normal distribution of std self.config.init_std
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
# final_logits_bias is a learned parameter of shape (1, vocab_size), initialized with bias_init
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
# Return the encoder submodule
def _get_encoder_module(self):
return self.model.encoder
# Return the decoder submodule
def _get_decoder_module(self):
return self.model.decoder
# Forward pass of the conditional generation module
def __call__(
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
position_ids,
decoder_position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
# Run the seq2seq backbone to obtain the encoder/decoder outputs
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
position_ids=position_ids,
decoder_position_ids=decoder_position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
# The decoder's last hidden states
hidden_states = outputs[0]
# If word embeddings are tied, reuse the shared embedding matrix as the LM head kernel
if self.config.tie_word_embeddings:
shared_embedding = self.model.variables["params"]["shared"]["embedding"]
# Apply the transposed shared embedding as the lm_head kernel to obtain the language-modeling logits
lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
else:
# Otherwise compute the logits with the standalone lm_head
lm_logits = self.lm_head(hidden_states)
# Add the final logits bias (kept out of the gradient via stop_gradient)
lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype))
# If a dict-like output is not requested, return a plain tuple
if not return_dict:
output = (lm_logits,) + outputs[1:]
return output
# Otherwise return a FlaxSeq2SeqLMOutput with the full set of outputs
return FlaxSeq2SeqLMOutput(
logits=lm_logits,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
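When config.tie_word_embeddings is set, the module above feeds the transposed shared embedding matrix into lm_head.apply instead of using the head's own kernel; numerically this is just hidden_states @ embedding.T. A tiny self-contained sketch with invented shapes:

# Self-contained illustration (not library code) of the tied-embedding LM head:
# applying nn.Dense with the transposed embedding matrix as its kernel equals a
# plain matmul against embedding.T.
import jax
import jax.numpy as jnp
import flax.linen as nn

vocab_size, d_model = 11, 4
embedding = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, d_model))  # stands in for the shared embedding
hidden_states = jax.random.normal(jax.random.PRNGKey(1), (2, 3, d_model))

lm_head = nn.Dense(vocab_size, use_bias=False)
logits_tied = lm_head.apply({"params": {"kernel": embedding.T}}, hidden_states)
logits_matmul = hidden_states @ embedding.T

print(jnp.allclose(logits_tied, logits_matmul, atol=1e-5))  # True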
@add_start_docstrings(
"The PEGASUS Model with a language modeling head. Can be used for summarization.", PEGASUS_START_DOCSTRING
)
class FlaxPegasusForConditionalGeneration(FlaxPegasusPreTrainedModel):
module_class = FlaxPegasusForConditionalGenerationModule
dtype: jnp.dtype = jnp.float32
@add_start_docstrings(PEGASUS_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=PegasusConfig)
def decode(
self,
decoder_input_ids,
encoder_outputs,
encoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
deterministic: bool = True,
params: dict = None,
dropout_rng: PRNGKey = None,
):
"""
Decode function for PEGASUS model, generating outputs based on decoder inputs and encoder outputs.
Args:
decoder_input_ids: Input IDs for the decoder.
encoder_outputs: Outputs from the encoder.
encoder_attention_mask: Optional attention mask for encoder outputs.
decoder_attention_mask: Optional attention mask for decoder inputs.
decoder_position_ids: Optional position IDs for the decoder inputs.
past_key_values: Cached key values from previous decoding steps.
output_attentions: Whether to output attention weights.
output_hidden_states: Whether to output hidden states.
return_dict: Whether to return outputs as a dictionary.
deterministic: Whether to use deterministic behavior.
params: Optional parameters for decoding.
dropout_rng: Random number generator for dropout.
Returns:
A [`FlaxCausalLMOutputWithCrossAttentions`] (or a plain tuple when `return_dict=False`), including the updated `past_key_values` when a cache is used.
"""
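In the upstream source, decode roughly mirrors the defaulting logic of __call__ above and applies the module through a private _decoder_forward that reuses lm_head, carrying past_key_values through a mutable "cache" collection. The sketch below shows, under those assumptions, how encode, init_cache and decode can be combined for step-wise greedy decoding; the checkpoint name, max_length and the loop itself are illustrative rather than taken from this file.

# Hedged sketch of cached, step-wise decoding with the public Flax seq2seq helpers.
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
# add from_pt=True if the checkpoint only ships PyTorch weights
model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")

inputs = tokenizer("My friends are cool but they eat too many carbs.", return_tensors="np")
encoder_outputs = model.encode(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)

max_length = 8
batch_size = inputs.input_ids.shape[0]
past_key_values = model.init_cache(batch_size, max_length, encoder_outputs)
# static all-ones decoder mask, as built in prepare_inputs_for_generation below
extended_mask = jnp.ones((batch_size, max_length), dtype="i4")

decoder_input_ids = jnp.full((batch_size, 1), model.config.decoder_start_token_id, dtype="i4")
generated = []
for step in range(max_length):
    position_ids = jnp.full((batch_size, 1), step, dtype="i4")
    outputs = model.decode(
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask=inputs.attention_mask,
        decoder_attention_mask=extended_mask,
        decoder_position_ids=position_ids,
        past_key_values=past_key_values,
    )
    past_key_values = outputs.past_key_values
    next_token = jnp.argmax(outputs.logits[:, -1, :], axis=-1)
    generated.append(int(next_token[0]))
    decoder_input_ids = next_token[:, None].astype("i4")

print(tokenizer.decode(generated, skip_special_tokens=True))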
def prepare_inputs_for_generation(
self,
decoder_input_ids,
max_length,
attention_mask: Optional[jax.Array] = None,
decoder_attention_mask: Optional[jax.Array] = None,
encoder_outputs=None,
**kwargs,
):
"""
Prepare inputs for generation based on decoder inputs and optional masks.
Args:
decoder_input_ids: Input IDs for the decoder.
max_length: Maximum length of generated outputs.
attention_mask: Optional attention mask for encoder outputs.
decoder_attention_mask: Optional attention mask for decoder inputs.
encoder_outputs: Optional outputs from the encoder.
**kwargs: Additional keyword arguments.
Returns:
Dictionary of inputs formatted for generation.
"""
# initializing the cache
batch_size, seq_length = decoder_input_ids.shape
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
# But since the decoder uses a causal mask, those positions are masked anyway, so a single static mask is enough here.
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if decoder_attention_mask is not None:
position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
else:
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
return {
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"encoder_attention_mask": attention_mask,
"decoder_attention_mask": extended_attention_mask,
"decoder_position_ids": position_ids,
}
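The mask handling above can be seen concretely with toy values: a caller-provided decoder_attention_mask is copied into the top-left corner of a static all-ones mask of length max_length, and position ids are its cumulative sum minus one, so padded positions simply repeat the last valid id. A standalone illustration with invented values:

# Standalone illustration (invented values) of the static mask and position-id
# construction used in prepare_inputs_for_generation.
import jax.numpy as jnp
from jax import lax

max_length = 6
decoder_attention_mask = jnp.array([[1, 1, 1, 0]], dtype="i4")  # 3 real tokens, 1 pad

extended_attention_mask = jnp.ones((1, max_length), dtype="i4")
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
position_ids = decoder_attention_mask.cumsum(axis=-1) - 1

print(extended_attention_mask)  # [[1 1 1 0 1 1]]
print(position_ids)             # [[0 1 2 2]]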
def update_inputs_for_generation(self, model_outputs, model_kwargs):
"""
Update inputs for generation based on model outputs and current model arguments.
Args:
model_outputs: Outputs from the model.
model_kwargs: Current model keyword arguments.
Returns:
Updated model keyword arguments for generation.
"""
model_kwargs["past_key_values"] = model_outputs.past_key_values
model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
return model_kwargs
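prepare_inputs_for_generation runs once before the first step, while update_inputs_for_generation runs after every step: the fresh cache replaces past_key_values and the last decoder position id is advanced by one. A schematic sketch of that bookkeeping (no real model involved, values invented):

# Schematic sketch of how the two generation hooks interact across steps:
# positions advance by one per step while the cache is carried forward.
import jax.numpy as jnp

model_kwargs = {"decoder_position_ids": jnp.array([[0, 1, 2]], dtype="i4"), "past_key_values": None}

for step in range(3):
    # ... the model forward would run here and return a fresh cache ...
    fake_cache = {"layer_0": f"cache after step {step}"}  # stand-in for the real past_key_values
    model_kwargs["past_key_values"] = fake_cache
    model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
    print(step, model_kwargs["decoder_position_ids"])
# step 0 -> [[3]], step 1 -> [[4]], step 2 -> [[5]]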
FLAX_PEGASUS_CONDITIONAL_GENERATION_DOCSTRING = """
Returns:

Summarization example:

>>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration

>>> model = FlaxPegasusForConditionalGeneration.from_pretrained('google/pegasus-large')
>>> tokenizer = AutoTokenizer.from_pretrained('google/pegasus-large')

>>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='np')

>>> # Generate Summary
>>> summary_ids = model.generate(inputs['input_ids']).sequences
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))

Mask filling example:

>>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration

>>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
>>> TXT = "My friends are <mask> but they eat too many carbs."

>>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")
>>> input_ids = tokenizer([TXT], return_tensors="np")["input_ids"]
>>> logits = model(input_ids).logits

>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
>>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
>>> values, predictions = jax.lax.top_k(probs, k=5)

>>> tokenizer.decode(predictions).split()
"""
# Overwrite the __call__ docstring of FlaxPegasusForConditionalGeneration with
# PEGASUS_INPUTS_DOCSTRING followed by FLAX_PEGASUS_CONDITIONAL_GENERATION_DOCSTRING.
overwrite_call_docstring(
FlaxPegasusForConditionalGeneration, PEGASUS_INPUTS_DOCSTRING + FLAX_PEGASUS_CONDITIONAL_GENERATION_DOCSTRING
)
"""
为FlaxPegasusForConditionalGeneration类的返回文档字符串追加内容
使用 FlaxSeq2SeqLMOutput 作为输出类型,使用 _CONFIG_FOR_DOC 作为配置类
"""
append_replace_return_docstrings(
FlaxPegasusForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)