Transformers Source Code Walkthrough (108)
.\models\swin\modeling_tf_swin.py
""" TF 2.0 Swin Transformer 模型。"""
from __future__ import annotations
import collections.abc
import math
import warnings
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import tensorflow as tf
from ...activations_tf import ACT2FN
from ...modeling_tf_utils import (
TFPreTrainedModel,
TFSequenceClassificationLoss,
get_initializer,
keras,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_swin import SwinConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "SwinConfig"
_CHECKPOINT_FOR_DOC = "microsoft/swin-tiny-patch4-window7-224"
_EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
_IMAGE_CLASS_CHECKPOINT = "microsoft/swin-tiny-patch4-window7-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = [
"microsoft/swin-tiny-patch4-window7-224",
]
@dataclass
class TFSwinEncoderOutput(ModelOutput):
"""
Swin 编码器的输出,可能包括隐藏状态和注意力。
"""
Args:
last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
模型最后一层的隐藏状态序列的张量。
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
可选参数,当 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回,包含模型每一层的隐藏状态的元组。
每个张量的形状为 `(batch_size, sequence_length, hidden_size)`。
包括初始嵌入输出后每个层的模型隐藏状态。
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
可选参数,当 `output_attentions=True` 或 `config.output_attentions=True` 时返回,包含模型每个阶段的注意力权重的元组。
每个张量的形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
在注意力 softmax 后的注意力权重,用于计算自注意力头部的加权平均值。
reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
可选参数,当 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回,包含模型每一层的隐藏状态的元组。
每个张量的形状为 `(batch_size, hidden_size, height, width)`。
包括初始嵌入输出后每个层的模型隐藏状态,重塑以包括空间维度。
@dataclass
class TFSwinModelOutput(ModelOutput):
"""
Swin model's outputs that also contains a pooling of the last hidden states.
Args:
last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
Average pooling of the last layer hidden-state.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
`(batch_size, hidden_size, height, width)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
include the spatial dimensions.
"""
last_hidden_state: tf.Tensor = None
pooler_output: tf.Tensor | None = None
hidden_states: Tuple[tf.Tensor, ...] | None = None
attentions: Tuple[tf.Tensor, ...] | None = None
reshaped_hidden_states: Tuple[tf.Tensor, ...] | None = None
@dataclass
class TFSwinMaskedImageModelingOutput(ModelOutput):
"""
Swin masked image model outputs.
Args:
loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
Masked image modeling (MLM) loss.
reconstruction (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
Reconstructed pixel values.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
`(batch_size, hidden_size, height, width)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
include the spatial dimensions.
"""
# Attributes: loss, reconstructed pixel values, hidden states, attention weights and reshaped hidden states, all defaulting to None
loss: tf.Tensor | None = None
reconstruction: tf.Tensor = None
hidden_states: Tuple[tf.Tensor, ...] | None = None
attentions: Tuple[tf.Tensor, ...] | None = None
reshaped_hidden_states: Tuple[tf.Tensor, ...] | None = None
@property
def logits(self):
# Warn that the `logits` attribute is deprecated and will be removed in Transformers v5; use `reconstruction` instead
warnings.warn(
"logits attribute is deprecated and will be removed in version 5 of Transformers."
" Please use the reconstruction attribute to retrieve the final output instead.",
FutureWarning,
)
# Return the reconstruction as the output
return self.reconstruction
@dataclass
class TFSwinImageClassifierOutput(ModelOutput):
"""
Swin outputs for image classification.
Args:
loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
`(batch_size, hidden_size, height, width)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
include the spatial dimensions.
"""
loss: tf.Tensor | None = None  # classification (or regression if `config.num_labels==1`) loss, returned when `labels` is provided
logits: tf.Tensor = None  # classification (or regression) scores before SoftMax, shape `(batch_size, config.num_labels)`
hidden_states: Tuple[tf.Tensor, ...] | None = None  # hidden states of each layer plus the initial embedding output, shape `(batch_size, sequence_length, hidden_size)`
attentions: Tuple[tf.Tensor, ...] | None = None  # attention weights after the softmax, shape `(batch_size, num_heads, sequence_length, sequence_length)`
reshaped_hidden_states: Tuple[tf.Tensor, ...] | None = None  # hidden states reshaped to include the spatial dimensions, shape `(batch_size, hidden_size, height, width)`
def window_partition(input_feature: tf.Tensor, window_size: int) -> tf.Tensor:
"""
Partitions the given input into windows.
"""
batch_size, height, width, num_channels = shape_list(input_feature)  # shape of the input feature map
input_feature = tf.reshape(
input_feature,
(batch_size, height // window_size, window_size, width // window_size, window_size, num_channels),  # reshape into a grid of windows
)
windows = tf.transpose(input_feature, (0, 1, 3, 2, 4, 5))  # bring the two window axes next to each other
windows = tf.reshape(windows, (-1, window_size, window_size, num_channels))  # flatten the grid into a batch of windows
return windows
def window_reverse(windows: tf.Tensor, window_size: int, height: int, width: int) -> tf.Tensor:
"""
Merges windows to produce higher resolution features.
"""
x = tf.shape(windows)[0]  # total number of windows
y = tf.cast(height * width / (window_size * window_size), tf.int32)  # number of windows per image
batch_size = tf.math.floordiv(x, y)  # recover the batch size
# Reshape the windows back into a (batch, grid_h, grid_w, window, window, channels) layout
windows = tf.reshape(
windows, (batch_size, height // window_size, width // window_size, window_size, window_size, -1)
)
# Swap the window axes back to their spatial positions
windows = tf.transpose(windows, (0, 1, 3, 2, 4, 5))
# Merge the grid and window axes into the full spatial resolution
windows = tf.reshape(windows, (batch_size, height, width, -1))
# Return the merged feature map
return windows
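# --- Illustrative round trip (not part of the original file): partition a dummy
# 56x56 feature map into 7x7 windows and merge it back. The shapes below are
# assumptions; the round trip is lossless whenever the window size divides the
# height and width.
_demo_feature = tf.random.normal((2, 56, 56, 96))                  # (batch, height, width, channels)
_demo_windows = window_partition(_demo_feature, window_size=7)     # (2 * 8 * 8, 7, 7, 96)
_demo_restored = window_reverse(_demo_windows, window_size=7, height=56, width=56)
# _demo_restored has the same shape (and contents) as _demo_feature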
def drop_path(
input: tf.Tensor, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
) -> tf.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
# If drop_prob is 0 or we are not in training mode, return the input unchanged
if drop_prob == 0.0 or not training:
return input
# Keep probability
keep_prob = 1 - drop_prob
# Shape of the input tensor
input_shape = shape_list(input)
# Number of dimensions
ndim = len(input_shape)
# Build a random tensor with one entry per sample (works for tensors of any rank, not only 2D conv nets)
shape = [input_shape[0]] + [1] * (ndim - 1)
random_tensor = tf.random.uniform(shape)
# Entries less than or equal to keep_prob become 1.0, the rest become 0.0
random_tensor = tf.where(random_tensor <= keep_prob, 1.0, 0.0)
# Optionally rescale the surviving samples by the keep probability
if keep_prob > 0.0 and scale_by_keep:
random_tensor /= keep_prob
# Apply the per-sample drop mask to the input
return input * random_tensor
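# --- Quick check of the stochastic-depth helper above (not part of the original
# file). At inference the input is returned unchanged; in training roughly
# drop_prob of the samples are zeroed and the survivors are rescaled by
# 1 / keep_prob. The values below are illustrative.
_demo_x = tf.ones((4, 49, 96))
_demo_eval = drop_path(_demo_x, drop_prob=0.2, training=False)   # identical to _demo_x
_demo_train = drop_path(_demo_x, drop_prob=0.2, training=True)   # each sample is all zeros or all 1/0.8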
class TFSwinEmbeddings(keras.layers.Layer):
"""
Construct the patch and position embeddings. Optionally, also the mask token.
"""
def __init__(self, config: SwinConfig, use_mask_token: bool = False, **kwargs) -> None:
super().__init__(**kwargs)
# 初始化补丁和位置嵌入
self.patch_embeddings = TFSwinPatchEmbeddings(config, name="patch_embeddings")
# 获取补丁数量和网格大小
self.num_patches = self.patch_embeddings.num_patches
self.patch_grid = self.patch_embeddings.grid_size
self.embed_dim = config.embed_dim
self.use_mask_token = use_mask_token
self.use_absolute_embeddings = config.use_absolute_embeddings
# 层归一化
self.norm = keras.layers.LayerNormalization(name="norm", epsilon=1e-5)
# dropout
self.dropout = keras.layers.Dropout(config.hidden_dropout_prob, name="dropout")
self.config = config
def build(self, input_shape: tf.TensorShape) -> None:
# 如果需要使用掩码令牌,则添加掩码令牌的权重
if self.use_mask_token:
self.mask_token = self.add_weight(shape=(1, 1, self.embed_dim), initializer="zeros", name="mask_token")
else:
self.mask_token = None
# 如果使用绝对位置嵌入,则添加位置嵌入的权重
if self.use_absolute_embeddings:
self.position_embeddings = self.add_weight(
(1, self.num_patches + 1, self.embed_dim), initializer="zeros", name="positional_embeddings"
)
else:
self.position_embeddings = None
# 如果已经构建,则直接返回
if self.built:
return
self.built = True
# 构建补丁嵌入层、层归一化层和dropout层
if getattr(self, "patch_embeddings", None) is not None:
with tf.name_scope(self.patch_embeddings.name):
self.patch_embeddings.build(None)
if getattr(self, "norm", None) is not None:
with tf.name_scope(self.norm.name):
self.norm.build([None, None, self.config.embed_dim])
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
def call(
self, pixel_values: tf.Tensor, bool_masked_pos: bool = None, training: bool = False
) -> Tuple[tf.Tensor, Tuple[int, int]]:
# Compute the patch embeddings and the spatial output dimensions
embeddings, output_dimensions = self.patch_embeddings(pixel_values, training=training)
# Normalize the embeddings
embeddings = self.norm(embeddings, training=training)
# Shape of the embeddings
batch_size, seq_len, _ = shape_list(embeddings)
# If masked positions were provided
if bool_masked_pos is not None:
# Broadcast the mask token over the batch and sequence dimensions
mask_tokens = tf.repeat(self.mask_token, batch_size, 0)
mask_tokens = tf.repeat(mask_tokens, seq_len, 1)
# Replace the embeddings at the masked positions with the mask token
mask = tf.expand_dims(bool_masked_pos, -1)
mask = tf.cast(mask, mask_tokens.dtype)
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
# Add the absolute position embeddings if they are used
if self.position_embeddings is not None:
embeddings = embeddings + self.position_embeddings
# Apply dropout to the embeddings
embeddings = self.dropout(embeddings, training=training)
# Return the embeddings and the spatial output dimensions
return embeddings, output_dimensions
class TFSwinPatchEmbeddings(keras.layers.Layer):
"""
Image to Patch Embedding.
"""
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
# 从配置中获取图像大小和patch大小
image_size, patch_size = config.image_size, config.patch_size
# 获取通道数和嵌入维度
num_channels, hidden_size = config.num_channels, config.embed_dim
# 如果图像大小和patch大小不是可迭代对象,转换为元组
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
# 计算patch的数量
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
# 设置类属性
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
# 定义投影层,使用Conv2D将patch映射到隐藏维度空间
self.projection = keras.layers.Conv2D(
filters=hidden_size,
kernel_size=self.patch_size,
strides=self.patch_size,
padding="valid",
name="projection",
)
def maybe_pad(self, pixel_values: tf.Tensor, height: int, width: int) -> tf.Tensor:
# 如果宽度不是patch宽度的整数倍,进行填充
if width % self.patch_size[1] != 0:
pad_values = ((0, 0), (0, 0), (0, 0), (0, self.patch_size[1] - width % self.patch_size[1]))
pixel_values = tf.pad(pixel_values, pad_values)
# 如果高度不是patch高度的整数倍,进行填充
if height % self.patch_size[0] != 0:
pad_values = ((0, 0), (0, 0), (0, self.patch_size[0] - height % self.patch_size[0]), (0, 0))
pixel_values = tf.pad(pixel_values, pad_values)
return pixel_values
def call(self, pixel_values: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor, Tuple[int, int]]:
# 获取输入张量的形状信息
_, num_channels, height, width = shape_list(pixel_values)
# 在动态执行环境下,检查通道数是否与配置中设置的一致
if tf.executing_eagerly() and num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
# 如果需要,对输入进行填充,使其可以被self.patch_size整除
pixel_values = self.maybe_pad(pixel_values, height, width)
# 调整输入张量的维度顺序 B,C,H,W -> B,H,W,C
pixel_values = tf.transpose(pixel_values, (0, 2, 3, 1))
# 使用投影层将patch映射到隐藏维度空间
embeddings = self.projection(pixel_values, training=training)
# 调整输出张量的维度顺序 B,H,W,C -> B,C,H,W
embeddings = tf.transpose(embeddings, (0, 3, 1, 2))
# 获取输出张量的形状信息
batch_size, channels, height, width = shape_list(embeddings)
output_dimensions = (height, width)
# 将输出张量reshape为 B,N,C 的形式,其中N为patch的数量
embeddings = tf.reshape(embeddings, (batch_size, channels, -1))
embeddings = tf.transpose(embeddings, (0, 2, 1))
return embeddings, output_dimensions
# 定义一个方法用于构建模型,如果已经构建过则直接返回
def build(self, input_shape=None):
if self.built:
return
# 标记模型已经构建
self.built = True
# 检查是否存在投影层,并在 TensorFlow 的命名空间下构建投影层
if getattr(self, "projection", None) is not None:
with tf.name_scope(self.projection.name):
# 使用投影层的建模方法来构建投影层,传入特定维度的列表
self.projection.build([None, None, None, self.num_channels])
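# --- Illustrative sketch (not part of the original file): shapes produced by
# TFSwinPatchEmbeddings under the default SwinConfig (image_size=224,
# patch_size=4, embed_dim=96); the concrete numbers are assumptions taken from
# those defaults.
_demo_config = SwinConfig()
_demo_patch_embed = TFSwinPatchEmbeddings(_demo_config)
_demo_pixels = tf.random.normal((1, _demo_config.num_channels, 224, 224))  # B, C, H, W
_demo_embeddings, (_demo_h, _demo_w) = _demo_patch_embed(_demo_pixels)
# _demo_embeddings: (1, 3136, 96) since 224 / 4 = 56 and 56 * 56 = 3136; _demo_h == _demo_w == 56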
class TFSwinPatchMerging(keras.layers.Layer):
"""
Patch Merging Layer.
Args:
input_resolution (`Tuple[int]`):
Resolution of input feature.
dim (`int`):
Number of input channels.
norm_layer (`keras.layer.Layer`, *optional*, defaults to `keras.layers.LayerNormalization`):
Normalization layer class.
"""
def __init__(
self, input_resolution: Tuple[int, int], dim: int, norm_layer: Optional[Callable] = None, **kwargs
) -> None:
super().__init__(**kwargs)
self.input_resolution = input_resolution # 设置输入特征的分辨率
self.dim = dim # 设置输入通道数
self.reduction = keras.layers.Dense(2 * dim, use_bias=False, name="reduction") # 创建一个稠密层用于特征降维
if norm_layer is None:
# 如果未提供自定义的归一化层,则使用默认的层归一化层,设置标准化的epsilon值与PyTorch相同
self.norm = keras.layers.LayerNormalization(epsilon=1e-5, name="norm")
else:
self.norm = norm_layer(name="norm") # 使用提供的自定义归一化层
def maybe_pad(self, input_feature: tf.Tensor, height: int, width: int) -> tf.Tensor:
should_pad = (height % 2 == 1) or (width % 2 == 1)
if should_pad:
pad_values = ((0, 0), (0, height % 2), (0, width % 2), (0, 0)) # 计算需要填充的值
input_feature = tf.pad(input_feature, pad_values) # 对输入特征进行填充
return input_feature
def call(self, input_feature: tf.Tensor, input_dimensions: Tuple[int, int], training: bool = False) -> tf.Tensor:
height, width = input_dimensions
batch_size, _, num_channels = shape_list(input_feature) # 获取输入特征的形状信息
input_feature = tf.reshape(input_feature, (batch_size, height, width, num_channels)) # 将输入特征重塑为四维张量
input_feature = self.maybe_pad(input_feature, height, width) # 可能对输入特征进行填充,使其尺寸可以被宽度和高度整除
input_feature_0 = input_feature[:, 0::2, 0::2, :] # 提取输入特征的每隔一个像素点的子集
input_feature_1 = input_feature[:, 1::2, 0::2, :] # 提取输入特征的每隔一个像素点的子集
input_feature_2 = input_feature[:, 0::2, 1::2, :] # 提取输入特征的每隔一个像素点的子集
input_feature_3 = input_feature[:, 1::2, 1::2, :] # 提取输入特征的每隔一个像素点的子集
input_feature = tf.concat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1) # 合并这四个子集
input_feature = tf.reshape(
input_feature, (batch_size, -1, 4 * num_channels)
) # 将合并后的特征重塑为三维张量,以便进一步处理
input_feature = self.norm(input_feature, training=training) # 对特征进行归一化
input_feature = self.reduction(input_feature, training=training) # 对特征进行降维
return input_feature
# 定义 build 方法,用于构建模型,如果已经构建过,则直接返回
def build(self, input_shape=None):
# 检查是否已经构建过,如果是则返回,避免重复构建
if self.built:
return
# 将标志设置为已构建
self.built = True
# 如果有指定的 reduction 属性,则在名为 reduction 的命名空间下构建
if getattr(self, "reduction", None) is not None:
with tf.name_scope(self.reduction.name):
# 使用 4 * self.dim 的输入形状来构建 reduction 属性
self.reduction.build([None, None, 4 * self.dim])
# 如果有指定的 norm 属性,则在名为 norm 的命名空间下构建
if getattr(self, "norm", None) is not None:
with tf.name_scope(self.norm.name):
# 使用 4 * self.dim 的输入形状来构建 norm 属性
self.norm.build([None, None, 4 * self.dim])
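# --- Illustrative sketch (not part of the original file): the merging layer
# halves the spatial resolution and doubles the channel count. The shapes are
# assumptions matching the first Swin stage (56x56 tokens, 96 channels).
_demo_merge = TFSwinPatchMerging(input_resolution=(56, 56), dim=96)
_demo_tokens = tf.random.normal((1, 56 * 56, 96))
_demo_merged = _demo_merge(_demo_tokens, (56, 56))
# _demo_merged: (1, 784, 192) -- a 4C concatenation followed by a bias-free Dense(2C) reduction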
class TFSwinDropPath(keras.layers.Layer):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: float = None, scale_by_keep: bool = True, **kwargs) -> None:
super(TFSwinDropPath, self).__init__(**kwargs)
self.drop_prob = drop_prob # 初始化丢弃概率
self.scale_by_keep = scale_by_keep # 是否按保留比例缩放
def call(self, input: tf.Tensor, training: bool = False) -> tf.Tensor:
# 调用 drop_path 函数来应用丢弃路径操作
return drop_path(input, self.drop_prob, training, self.scale_by_keep)
class TFSwinSelfAttention(keras.layers.Layer):
def __init__(self, config: SwinConfig, dim: int, num_heads: int, **kwargs) -> None:
super().__init__(**kwargs)
if dim % num_heads != 0:
raise ValueError(
f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
)
self.num_attention_heads = num_heads # 设置注意力头数
self.attention_head_size = int(dim / num_heads) # 计算每个注意力头的大小
self.all_head_size = self.num_attention_heads * self.attention_head_size # 总的 QKV 大小
window_size = config.window_size
self.window_size = (
window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
) # 窗口大小
self.query = keras.layers.Dense(
self.all_head_size,
kernel_initializer=get_initializer(config.initializer_range),
use_bias=config.qkv_bias,
name="query",
) # 查询向量的全连接层
self.key = keras.layers.Dense(
self.all_head_size,
kernel_initializer=get_initializer(config.initializer_range),
use_bias=config.qkv_bias,
name="key",
) # 键向量的全连接层
self.value = keras.layers.Dense(
self.all_head_size,
kernel_initializer=get_initializer(config.initializer_range),
use_bias=config.qkv_bias,
name="value",
) # 值向量的全连接层
self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) # 注意力概率的 dropout 层
def build(self, input_shape: tf.TensorShape) -> None:
# 创建一个用于存储相对位置偏置表的权重变量
self.relative_position_bias_table = self.add_weight(
shape=(((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1)), self.num_attention_heads),
initializer="zeros",
name="relative_position_bias_table",
)
# 创建一个用于存储相对位置索引的权重变量,这些索引是窗口内每个标记的相对位置
self.relative_position_index = self.add_weight(
shape=(self.window_size[0] ** 2, self.window_size[1] ** 2),
trainable=False,
dtype=tf.int32,
name="relative_position_index",
)
# 获取窗口内每个标记的成对相对位置索引
coords_h = tf.range(self.window_size[0])
coords_w = tf.range(self.window_size[1])
coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing="ij"))
coords_flatten = tf.reshape(coords, (shape_list(coords)[0], -1))
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = tf.transpose(relative_coords, (1, 2, 0))
stack_0, stack_1 = tf.unstack(relative_coords, axis=2)
stack_0 += self.window_size[0] - 1
stack_0 *= 2 * self.window_size[1] - 1
stack_1 += self.window_size[1] - 1
relative_coords = tf.stack([stack_0, stack_1], axis=2)
# 计算相对位置索引的总和并分配给相对位置索引变量
self.relative_position_index.assign(tf.cast(tf.reduce_sum(relative_coords, axis=-1), tf.int32))
# 如果已经构建过,则直接返回
if self.built:
return
# 标记模型已经构建
self.built = True
# 如果存在查询、键、值变量,则构建它们的结构
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.all_head_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.all_head_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.all_head_size])
def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor:
# 调整张量的形状以便计算注意力分数
new_x_shape = shape_list(x)[:-1] + [self.num_attention_heads, self.attention_head_size]
x = tf.reshape(x, new_x_shape)
return tf.transpose(x, (0, 2, 1, 3))
def call(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor | None = None,
head_mask: tf.Tensor | None = None,
output_attentions: bool = False,
training: bool = False,
) -> Tuple[tf.Tensor, ...]:
# 获取隐藏状态的形状信息:批大小、维度等
batch_size, dim, _ = shape_list(hidden_states)
# 对隐藏状态进行查询操作,生成混合的查询层
mixed_query_layer = self.query(hidden_states)
# 使用self.key对隐藏状态进行键的转换,并调整形状以适应注意力得分计算
key_layer = self.transpose_for_scores(self.key(hidden_states))
# 使用self.value对隐藏状态进行值的转换,并调整形状以适应注意力得分计算
value_layer = self.transpose_for_scores(self.value(hidden_states))
# 对混合的查询层进行形状调整,以适应注意力得分计算
query_layer = self.transpose_for_scores(mixed_query_layer)
# 计算查询层与键层之间的点积,得到原始的注意力得分
attention_scores = tf.matmul(query_layer, tf.transpose(key_layer, (0, 1, 3, 2)))
# 对注意力得分进行缩放,以减少数值大小对 softmax 函数计算的影响
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# 根据相对位置索引从相对位置偏置表中获取相对位置偏置
relative_position_bias = tf.gather(
self.relative_position_bias_table, tf.reshape(self.relative_position_index, (-1,))
)
# 调整相对位置偏置的形状以匹配注意力得分的形状
relative_position_bias = tf.reshape(
relative_position_bias,
(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1),
)
# 转置相对位置偏置的维度顺序,以便与注意力得分相加
relative_position_bias = tf.transpose(relative_position_bias, (2, 0, 1))
attention_scores = attention_scores + tf.expand_dims(relative_position_bias, 0)
# 如果存在注意力掩码,则应用它
if attention_mask is not None:
# 获取注意力掩码的形状信息
mask_shape = shape_list(attention_mask)[0]
# 调整注意力得分的形状以匹配掩码的形状
attention_scores = tf.reshape(
attention_scores, (batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim)
)
# 扩展注意力掩码的维度以匹配注意力得分
attention_mask = tf.expand_dims(attention_mask, 1)
attention_mask = tf.expand_dims(attention_mask, 0)
# 将注意力掩码加到注意力得分上
attention_scores = attention_scores + attention_mask
# 重新调整注意力得分的形状
attention_scores = tf.reshape(attention_scores, (-1, self.num_attention_heads, dim, dim))
# 对注意力得分进行 softmax 归一化,得到注意力概率
attention_probs = tf.nn.softmax(attention_scores, axis=-1)
# 使用 dropout 进行注意力概率的随机失活,仅在训练时生效
attention_probs = self.dropout(attention_probs, training=training)
# 如果指定了头部掩码,则应用头部掩码
if head_mask is not None:
attention_probs = attention_probs * head_mask
# 计算上下文层,将注意力概率乘以值层
context_layer = tf.matmul(attention_probs, value_layer)
# 调整上下文层的维度顺序,以适应输出格式
context_layer = tf.transpose(context_layer, (0, 2, 1, 3))
# 调整上下文层的形状以匹配所有头部的输出大小
new_context_layer_shape = shape_list(context_layer)[:-2] + [
self.all_head_size,
]
context_layer = tf.reshape(context_layer, new_context_layer_shape)
# 输出结果,包括上下文层和可能的注意力概率
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
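# --- Back-of-the-envelope check of the relative-position machinery above for
# the default window_size=7 (illustrative only): each attention head owns a
# (2*7-1)*(2*7-1) = 169 entry bias table, and the 49x49 relative_position_index
# gathers one of those entries for every (query, key) pair inside a window.
_demo_window = 7
_demo_table_rows = (2 * _demo_window - 1) ** 2     # 169 learnable biases per head
_demo_pairs = (_demo_window * _demo_window) ** 2   # 2401 token pairs share them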
# 定义一个名为 TFSwinSelfOutput 的自定义层,继承自 Keras 的 Layer 类
class TFSwinSelfOutput(keras.layers.Layer):
# 初始化方法,接受 SwinConfig 对象、整数 dim 和额外的关键字参数
def __init__(self, config: SwinConfig, dim: int, **kwargs) -> None:
super().__init__(**kwargs)
# 创建一个 Dense 层,用于线性变换,输出维度为 dim
self.dense = keras.layers.Dense(dim, name="dense")
# 创建一个 Dropout 层,使用配置中的 dropout 概率
self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob, name="dropout")
self.dim = dim
# 前向传播方法,接受 hidden_states(输入张量)、input_tensor(输入张量)、training(布尔值,指示是否处于训练模式)
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
# 将输入通过 Dense 层进行线性变换
hidden_states = self.dense(hidden_states)
# 对线性变换后的结果进行 Dropout 操作
hidden_states = self.dropout(hidden_states, training=training)
return hidden_states
# 构建方法,用于构建层的内部结构
def build(self, input_shape=None):
# 如果层已经构建,则直接返回
if self.built:
return
self.built = True
# 如果存在 Dense 层,则构建该层
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.dim])
# 如果存在 Dropout 层,则构建该层
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
# 定义一个名为 TFSwinAttention 的自定义层,继承自 Keras 的 Layer 类
class TFSwinAttention(keras.layers.Layer):
# 初始化方法,接受 SwinConfig 对象、整数 dim、整数 num_heads 和额外的关键字参数
def __init__(self, config: SwinConfig, dim: int, num_heads: int, **kwargs) -> None:
super().__init__(**kwargs)
# 创建一个 TFSwinSelfAttention 层,用于处理注意力机制
self.self = TFSwinSelfAttention(config, dim, num_heads, name="self")
# 创建一个 TFSwinSelfOutput 层,用于处理自注意力输出
self.self_output = TFSwinSelfOutput(config, dim, name="output")
# 初始化一个空集合,用于存储要剪枝的注意力头
self.pruned_heads = set()
# 剪枝注意力头的方法,抛出未实现异常
def prune_heads(self, heads):
"""
Prunes heads of the model. See base class PreTrainedModel heads: dict of {layer_num: list of heads to prune in
this layer}
"""
raise NotImplementedError
# 前向传播方法,接受 hidden_states(输入张量)、attention_mask(注意力掩码张量)、head_mask(头部掩码张量)、
# output_attentions(布尔值,指示是否输出注意力矩阵)、training(布尔值,指示是否处于训练模式)
def call(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor | None = None,
head_mask: tf.Tensor | None = None,
output_attentions: bool = False,
training: bool = False,
) -> tf.Tensor:
# 使用 self 层处理输入的 hidden_states,得到自注意力输出 self_outputs
self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions, training=training)
# 使用 self_output 层处理 self_outputs 和原始 hidden_states,得到注意力输出 attention_output
attention_output = self.self_output(self_outputs[0], hidden_states, training=training)
# 构建输出元组 outputs,包括注意力输出和可能的注意力矩阵(如果有的话)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
# 构建方法,用于构建层的内部结构
def build(self, input_shape=None):
# 如果层已经构建,则直接返回
if self.built:
return
self.built = True
# 如果存在 self 层,则构建该层
if getattr(self, "self", None) is not None:
with tf.name_scope(self.self.name):
self.self.build(None)
# 如果存在 self_output 层,则构建该层
if getattr(self, "self_output", None) is not None:
with tf.name_scope(self.self_output.name):
self.self_output.build(None)
# 定义一个名为 TFSwinIntermediate 的自定义层,继承自 Keras 的 Layer 类
class TFSwinIntermediate(keras.layers.Layer):
# 初始化方法,用于创建一个新的实例
def __init__(self, config: SwinConfig, dim: int, **kwargs) -> None:
# 调用父类(tf.keras.layers.Layer)的初始化方法
super().__init__(**kwargs)
# 创建一个全连接层,输出维度为 config.mlp_ratio * dim,命名为 "dense"
self.dense = keras.layers.Dense(int(config.mlp_ratio * dim), name="dense")
# 根据配置文件中的 hidden_act 参数确定中间激活函数
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
# 将维度信息保存在实例变量 dim 中
self.dim = dim
# 调用方法,定义了该层的正向传播逻辑
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# 通过全连接层处理输入的 hidden_states,得到输出 hidden_states
hidden_states = self.dense(hidden_states)
# 使用中间激活函数处理输出 hidden_states
hidden_states = self.intermediate_act_fn(hidden_states)
# 返回处理后的 hidden_states
return hidden_states
# 构建方法,用于构建层的变量(如果尚未构建)
def build(self, input_shape=None):
# 如果已经构建过,直接返回
if self.built:
return
# 设置标志位,表明已经构建过
self.built = True
# 如果存在全连接层 dense,则根据输入形状构建该层
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
# 调用全连接层的 build 方法,指定输入形状 [None, None, self.dim]
self.dense.build([None, None, self.dim])
# 定义一个名为 TFSwinOutput 的自定义层,继承自 keras 的 Layer 类
class TFSwinOutput(keras.layers.Layer):
# 初始化方法,接受 SwinConfig 对象、维度 dim 和其他关键字参数
def __init__(self, config: SwinConfig, dim: int, **kwargs) -> None:
super().__init__(**kwargs)
# 创建一个全连接层 dense,输出维度为 dim,命名为 "dense"
self.dense = keras.layers.Dense(dim, name="dense")
# 创建一个 Dropout 层,使用 SwinConfig 中的隐藏层 Dropout 概率作为参数,命名为 "dropout"
self.dropout = keras.layers.Dropout(config.hidden_dropout_prob, name="dropout")
# 将传入的 SwinConfig 对象保存到 self.config 中
self.config = config
# 将传入的维度 dim 保存到 self.dim 中
self.dim = dim
# 定义 call 方法,接收隐藏状态 hidden_states 和训练标志 training
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
# 将隐藏状态输入到全连接层 dense 中,得到输出 hidden_states
hidden_states = self.dense(hidden_states)
# 对输出 hidden_states 应用 Dropout 操作,使用 training 参数控制是否训练模式
hidden_states = self.dropout(hidden_states, training=training)
# 返回经过全连接层和 Dropout 后的 hidden_states
return hidden_states
# 定义 build 方法,用于构建层的参数
def build(self, input_shape=None):
# 如果已经构建过,直接返回
if self.built:
return
# 标记为已构建
self.built = True
# 检查是否存在 self.dense 属性
if getattr(self, "dense", None) is not None:
# 在命名空间 self.dense.name 下,构建全连接层,输入形状为 [None, None, int(self.config.mlp_ratio * self.dim)]
with tf.name_scope(self.dense.name):
self.dense.build([None, None, int(self.config.mlp_ratio * self.dim)])
# 定义一个名为 TFSwinLayer 的自定义层,继承自 keras 的 Layer 类
class TFSwinLayer(keras.layers.Layer):
# 初始化方法,接受 config 对象、维度 dim、输入分辨率 input_resolution、注意力头数 num_heads 和其他关键字参数
def __init__(
self, config, dim, input_resolution: Tuple[int, int], num_heads: int, shift_size: int = 0, **kwargs
) -> None:
super().__init__(**kwargs)
# 设置前馈传输块的大小为 config 中的 chunk_size_feed_forward
self.chunk_size_feed_forward = config.chunk_size_feed_forward
# 计算输入分辨率的最小值
min_res = tf.reduce_min(input_resolution)
# 窗口大小为最小分辨率和 config 中的 window_size 的较小值
self.window_size = min_res if min_res <= config.window_size else config.window_size
# 如果最小分辨率小于等于窗口大小,则 shift_size 设为 0;否则使用传入的 shift_size
self.shift_size = 0 if min_res <= self.window_size else shift_size
# 保存输入分辨率到 self.input_resolution 中
self.input_resolution = input_resolution
# 创建 LayerNormalization 层,epsilon 使用 config 中的 layer_norm_eps,命名为 "layernorm_before"
self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
# 创建注意力机制层 TFSwinAttention,使用传入的 config、dim 和 num_heads,命名为 "attention"
self.attention = TFSwinAttention(config, dim, num_heads, name="attention")
# 如果 config 中的 drop_path_rate 大于 0.0,则创建 TFSwinDropPath 层,命名为 "drop_path",否则使用线性激活层
self.drop_path = (
TFSwinDropPath(config.drop_path_rate, name="drop_path")
if config.drop_path_rate > 0.0
else keras.layers.Activation("linear", name="drop_path")
)
# 创建 LayerNormalization 层,epsilon 使用 config 中的 layer_norm_eps,命名为 "layernorm_after"
self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")
# 创建 Swin 模型的中间层 TFSwinIntermediate,使用 config 和 dim,命名为 "intermediate"
self.intermediate = TFSwinIntermediate(config, dim, name="intermediate")
# 创建 Swin 模型的输出层 TFSwinOutput,使用 config 和 dim,命名为 "output"
self.swin_output = TFSwinOutput(config, dim, name="output")
# 保存维度 dim 到 self.dim 中
self.dim = dim
def get_attn_mask(self, height: int, width: int, window_size: int, shift_size: int) -> tf.Tensor | None:
# 创建一个全零的图像掩码,形状为(height, width)
img_mask = tf.zeros((height, width))
# 定义高度和宽度的切片范围,用于创建注意力掩码
height_slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, -1))
width_slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, -1))
# 计算 SW-MSA 的注意力掩码
if shift_size > 0:
count = 0
for height_slice in height_slices:
for width_slice in width_slices:
# 计算当前切片内的索引
height_inds = tf.range(height_slice[0] % height, height_slice[1] % height + 1)
width_inds = tf.range(width_slice[0] % width, width_slice[1] % width + 1)
indices = tf.reshape(tf.stack(tf.meshgrid(height_inds, width_inds), axis=-1), (-1, 2))
if len(indices) >= 1:
# 将更新值为 count 的掩码应用到图像掩码的对应位置
updates = tf.ones((len(indices),), dtype=img_mask.dtype) * count
img_mask = tf.tensor_scatter_nd_update(img_mask, indices, updates)
count += 1
# 将图像掩码扩展维度以适应后续计算要求
img_mask = tf.expand_dims(img_mask, -1)
img_mask = tf.expand_dims(img_mask, 0)
# 对图像掩码进行窗口划分,用于后续的注意力计算
mask_windows = window_partition(img_mask, window_size)
mask_windows = tf.reshape(mask_windows, (-1, window_size * window_size))
# Build the attention mask: pairs of positions from different regions get -100.0, pairs from the same region (including the diagonal) get 0.0
attn_mask = tf.expand_dims(mask_windows, 1) - tf.expand_dims(mask_windows, 2)
attn_mask = tf.where(attn_mask != 0, float(-100.0), attn_mask)
attn_mask = tf.where(attn_mask == 0, float(0.0), attn_mask)
return attn_mask
def maybe_pad(
self, hidden_states: tf.Tensor, window_size: int, height: int, width: int
) -> Tuple[tf.Tensor, tf.Tensor]:
# 计算需要在图像状态中填充的右边和底部的像素数
pad_right = (window_size - width % window_size) % window_size
pad_bottom = (window_size - height % window_size) % window_size
# 定义填充的数值,填充右边和底部,保持其他维度不变
pad_values = [[0, 0], [0, pad_bottom], [0, pad_right], [0, 0]]
# 在隐藏状态张量上应用填充
hidden_states = tf.pad(hidden_states, pad_values)
# 将填充值转换为一维张量返回
pad_values = tf.reshape(pad_values, (-1,))
return hidden_states, pad_values
def call(
self,
hidden_states: tf.Tensor,
input_dimensions: Tuple[int, int],
head_mask: tf.Tensor | None = None,
output_attentions: bool = False,
training: bool = False,
) -> tf.Tensor:
# 如果窗口大小大于输入分辨率,则不分割窗口
min_res = tf.reduce_min(input_dimensions) # 计算输入维度的最小值
shift_size = 0 if min_res <= self.window_size else self.shift_size # 如果最小分辨率小于等于窗口大小,则不进行移动;否则使用预设的移动大小
window_size = min_res if min_res <= self.window_size else self.window_size # 窗口大小取决于最小分辨率和设定的窗口大小
height, width = input_dimensions # 解包输入维度
batch_size, _, channels = shape_list(hidden_states) # 获取隐藏状态的批处理大小、高度、宽度和通道数
shortcut = hidden_states # 备份隐藏状态
hidden_states = self.layernorm_before(hidden_states, training=training) # 应用层归一化到隐藏状态之前
hidden_states = tf.reshape(hidden_states, (batch_size, height, width, channels)) # 重新调整隐藏状态的形状为(batch_size, height, width, channels)
hidden_states, pad_values = self.maybe_pad(hidden_states, window_size, height, width) # 可能对隐藏状态进行填充,使其成为窗口大小的倍数
_, height_pad, width_pad, _ = shape_list(hidden_states) # 获取调整后隐藏状态的形状
# 循环移位
if shift_size > 0:
shifted_hidden_states = tf.roll(hidden_states, shift=(-shift_size, -shift_size), axis=(1, 2)) # 在轴(1, 2)上执行负移位
else:
shifted_hidden_states = hidden_states # 否则不进行移位
# 分割窗口
hidden_states_windows = window_partition(shifted_hidden_states, window_size) # 将移位后的隐藏状态分割成窗口
hidden_states_windows = tf.reshape(hidden_states_windows, (-1, window_size * window_size, channels)) # 重新调整窗口的形状为(-1, window_size * window_size, channels)
attn_mask = self.get_attn_mask(
height=height_pad, width=width_pad, window_size=window_size, shift_size=shift_size
) # 获取注意力掩码
attention_outputs = self.attention(
hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions, training=training
) # 应用自注意力机制
attention_output = attention_outputs[0] # 提取注意力输出的第一个元素
attention_windows = tf.reshape(attention_output, (-1, window_size, window_size, channels)) # 重新调整注意力输出的形状为(-1, window_size, window_size, channels)
shifted_windows = window_reverse(attention_windows, window_size, height_pad, width_pad) # 反转窗口
# 反向循环移位
if shift_size > 0:
attention_windows = tf.roll(shifted_windows, shift=(shift_size, shift_size), axis=(1, 2)) # 在轴(1, 2)上执行正移位
else:
attention_windows = shifted_windows # 否则不进行移位
was_padded = pad_values[3] > 0 or pad_values[5] > 0 # 检查是否对隐藏状态进行了填充
if was_padded:
attention_windows = attention_windows[:, :height, :width, :] # 如果进行了填充,则截取有效部分
attention_windows = tf.reshape(attention_windows, (batch_size, height * width, channels)) # 重新调整注意力窗口的形状为(batch_size, height * width, channels)
hidden_states = shortcut + self.drop_path(attention_windows, training=training) # 添加残差连接和DropPath
layer_output = self.layernorm_after(hidden_states, training=training) # 应用层归一化到隐藏状态之后
layer_output = self.intermediate(layer_output) # 应用中间层变换
layer_output = hidden_states + self.swin_output(layer_output, training=training) # 添加Swin Transformer的输出
layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) # 构造输出元组
return layer_outputs # 返回层输出
# 构建模型的方法,用于设置层的输入形状并构建层的参数
def build(self, input_shape=None):
# 如果已经构建过,则直接返回,避免重复构建
if self.built:
return
# 将构建标志设置为已构建
self.built = True
# 如果存在layernorm_before属性,则构建layernorm_before层
if getattr(self, "layernorm_before", None) is not None:
# 使用layernorm_before层的名字作为命名空间
with tf.name_scope(self.layernorm_before.name):
# 构建layernorm_before层,设置输入形状为[None, None, self.dim]
self.layernorm_before.build([None, None, self.dim])
# 如果存在attention属性,则构建attention层
if getattr(self, "attention", None) is not None:
# 使用attention层的名字作为命名空间
with tf.name_scope(self.attention.name):
# 构建attention层,输入形状为None(表示不确定的形状)
self.attention.build(None)
# 如果存在drop_path属性,则构建drop_path层
if getattr(self, "drop_path", None) is not None:
# 使用drop_path层的名字作为命名空间
with tf.name_scope(self.drop_path.name):
# 构建drop_path层,输入形状为None
self.drop_path.build(None)
# 如果存在layernorm_after属性,则构建layernorm_after层
if getattr(self, "layernorm_after", None) is not None:
# 使用layernorm_after层的名字作为命名空间
with tf.name_scope(self.layernorm_after.name):
# 构建layernorm_after层,设置输入形状为[None, None, self.dim]
self.layernorm_after.build([None, None, self.dim])
# 如果存在intermediate属性,则构建intermediate层
if getattr(self, "intermediate", None) is not None:
# 使用intermediate层的名字作为命名空间
with tf.name_scope(self.intermediate.name):
# 构建intermediate层,输入形状为None
self.intermediate.build(None)
# 如果存在swin_output属性,则构建swin_output层
if getattr(self, "swin_output", None) is not None:
# 使用swin_output层的名字作为命名空间
with tf.name_scope(self.swin_output.name):
# 构建swin_output层,输入形状为None
self.swin_output.build(None)
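# --- Small sanity check of get_attn_mask (not part of the original file): with
# an 8x8 padded grid, window_size=4 and shift_size=2, the mask marks pairs of
# positions that came from different image regions with -100.0. The arguments
# below are illustrative assumptions.
_demo_layer = TFSwinLayer(SwinConfig(), dim=96, input_resolution=(56, 56), num_heads=3)
_demo_mask = _demo_layer.get_attn_mask(height=8, width=8, window_size=4, shift_size=2)
# _demo_mask: (4, 16, 16) -- one row per window, 16 tokens each, entries in {0.0, -100.0}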
class TFSwinStage(keras.layers.Layer):
# 定义一个名为 TFSwinStage 的自定义 Keras 层
def __init__(
self,
config: SwinConfig,
dim: int,
input_resolution: Tuple[int, int],
depth: int,
num_heads: int,
drop_path: List[float],
downsample: Optional[Callable],
**kwargs,
) -> None:
super().__init__(**kwargs)
# 初始化函数,接受多个参数,其中包括 Swin 模型的配置、维度、输入分辨率、深度、头数、路径丢弃率等
self.config = config
self.dim = dim
# 创建一个由 TFSwinLayer 实例组成的列表,每个实例代表一个层
self.blocks = [
TFSwinLayer(
config=config,
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
shift_size=0 if (i % 2 == 0) else config.window_size // 2,
name=f"blocks.{i}",
)
for i in range(depth)
]
# 如果存在下采样函数,创建下采样层
if downsample is not None:
self.downsample = downsample(
input_resolution,
dim=dim,
norm_layer=partial(keras.layers.LayerNormalization, epsilon=1e-5),
name="downsample",
)
else:
self.downsample = None
# 初始化指向(pointing)为 False
self.pointing = False
# 定义调用函数,处理输入并返回输出
def call(
self,
hidden_states: tf.Tensor,
input_dimensions: Tuple[int, int],
head_mask: tf.Tensor | None = None,
output_attentions: Optional[bool] = False,
training: bool = False,
) -> Tuple[tf.Tensor, ...]:
height, width = input_dimensions
# 遍历所有层,逐层处理隐藏状态
for i, layer_module in enumerate(self.blocks):
layer_head_mask = head_mask[i] if head_mask is not None else None
# 调用每个层的处理函数,获取层的输出
layer_outputs = layer_module(
hidden_states, input_dimensions, layer_head_mask, output_attentions, training=training
)
# 更新隐藏状态为当前层的输出
hidden_states = layer_outputs[0]
# 如果存在下采样层,对隐藏状态进行下采样操作
if self.downsample is not None:
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
output_dimensions = (height, width, height_downsampled, width_downsampled)
hidden_states = self.downsample(layer_outputs[0], input_dimensions, training=training)
else:
output_dimensions = (height, width, height, width)
# 组装阶段的输出,包括隐藏状态和输出尺寸
stage_outputs = (hidden_states, output_dimensions)
# 如果需要输出注意力权重,则将它们添加到阶段的输出中
if output_attentions:
stage_outputs += layer_outputs[1:]
return stage_outputs
# 定义构建函数,在第一次调用时构建层
def build(self, input_shape=None):
if self.built:
return
self.built = True
# 如果存在下采样层,构建该层
if getattr(self, "downsample", None) is not None:
with tf.name_scope(self.downsample.name):
self.downsample.build(None)
# 对每个层调用构建函数,构建所有的子层
if getattr(self, "blocks", None) is not None:
for layer in self.blocks:
with tf.name_scope(layer.name):
layer.build(None)
class TFSwinEncoder(keras.layers.Layer):
# 定义一个名为 TFSwinEncoder 的自定义 Keras 层
# 初始化函数,接受一个SwinTransformer的配置对象和一个网格大小的元组作为参数
def __init__(self, config: SwinConfig, grid_size: Tuple[int, int], **kwargs):
# 调用父类的初始化函数
super().__init__(**kwargs)
# 计算SwinTransformer模型的层数
self.num_layers = len(config.depths)
# 保存传入的配置对象
self.config = config
# 计算每一层的DropPath率,并转换为列表
dpr = list((tf.linspace(0, 1, sum(config.depths)) * config.drop_path_rate).numpy())
# 创建SwinTransformer的各个层
self.layers = [
TFSwinStage(
config=config,
# 计算当前层的维度
dim=int(config.embed_dim * 2**i_layer),
# 计算当前层的输入分辨率
input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
# 设置当前层的深度
depth=config.depths[i_layer],
# 设置当前层的头数
num_heads=config.num_heads[i_layer],
# 为当前层设置DropPath率
drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
# 如果当前层不是最后一层,设置下采样方法;否则为None
downsample=TFSwinPatchMerging if (i_layer < self.num_layers - 1) else None,
# 设置当前层的名称
name=f"layers.{i_layer}",
)
# 对每一层进行迭代
for i_layer in range(self.num_layers)
]
# 默认关闭梯度检查点
self.gradient_checkpointing = False
# 模型调用函数,接受隐藏状态张量、输入维度元组等多个参数
def call(
self,
hidden_states: tf.Tensor,
input_dimensions: Tuple[int, int],
head_mask: tf.Tensor | None = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
training: bool = False,
) -> Union[Tuple[tf.Tensor, ...], TFSwinEncoderOutput]:
# 定义函数签名及返回类型,输入为隐藏状态及其他参数,输出为元组或TFSwinEncoderOutput类型
all_input_dimensions = ()
# 初始化空元组,用于存储所有输入维度信息
all_hidden_states = () if output_hidden_states else None
# 如果需要输出隐藏状态,则初始化空元组,否则置为None
all_reshaped_hidden_states = () if output_hidden_states else None
# 如果需要输出隐藏状态,则初始化空元组,否则置为None
all_self_attentions = () if output_attentions else None
# 如果需要输出注意力权重,则初始化空元组,否则置为None
if output_hidden_states:
batch_size, _, hidden_size = shape_list(hidden_states)
# 获取隐藏状态的批量大小、高、宽、通道数信息
# 重排形状为 b (h w) c -> b c h w
reshaped_hidden_state = tf.reshape(hidden_states, (batch_size, *input_dimensions, hidden_size))
reshaped_hidden_state = tf.transpose(reshaped_hidden_state, (0, 3, 1, 2))
# 将形状调整为 b c h w,并进行转置以匹配预期的维度顺序
all_hidden_states += (hidden_states,)
all_reshaped_hidden_states += (reshaped_hidden_state,)
# 将隐藏状态及其重排后的形状信息添加到对应的元组中
for i, layer_module in enumerate(self.layers):
# 遍历self.layers中的每一层模块
layer_head_mask = head_mask[i] if head_mask is not None else None
# 获取当前层的注意力头遮罩,如果未提供则置为None
layer_outputs = layer_module(
hidden_states, input_dimensions, layer_head_mask, output_attentions, training=training
)
# 调用当前层模块的前向传播方法,计算层的输出结果
hidden_states = layer_outputs[0]
# 更新隐藏状态为当前层输出的第一个元素(通常是最终的隐藏状态)
output_dimensions = layer_outputs[1]
# 获取当前层输出的维度信息
input_dimensions = (output_dimensions[-2], output_dimensions[-1])
# 更新输入维度为当前层输出的高和宽信息
all_input_dimensions += (input_dimensions,)
# 将更新后的输入维度信息添加到all_input_dimensions中
if output_hidden_states:
batch_size, _, hidden_size = shape_list(hidden_states)
# 获取隐藏状态的批量大小、高、宽、通道数信息
# 重排形状为 b (h w) c -> b c h w
reshaped_hidden_state = tf.reshape(hidden_states, (batch_size, *input_dimensions, hidden_size))
reshaped_hidden_state = tf.transpose(reshaped_hidden_state, (0, 3, 1, 2))
# 将形状调整为 b c h w,并进行转置以匹配预期的维度顺序
all_hidden_states += (hidden_states,)
all_reshaped_hidden_states += (reshaped_hidden_state,)
# 将隐藏状态及其重排后的形状信息添加到对应的元组中
if output_attentions:
all_self_attentions += layer_outputs[2:]
# 如果需要输出注意力权重,则将当前层输出中的注意力权重信息添加到all_self_attentions中
if not return_dict:
# 如果不需要返回字典格式的输出结果
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
# 返回所有非空的结果组成的元组
return TFSwinEncoderOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
reshaped_hidden_states=all_reshaped_hidden_states,
)
# 返回以TFSwinEncoderOutput格式封装的输出结果
def build(self, input_shape=None):
# 定义build方法,用于构建模型层次结构
if self.built:
# 如果模型已构建完成,则直接返回
return
self.built = True
# 将模型标记为已构建
if getattr(self, "layers", None) is not None:
# 如果存在模型层列表
for layer in self.layers:
# 遍历每一层
with tf.name_scope(layer.name):
# 使用层的名称创建命名空间
layer.build(None)
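# --- Assumed stage layout under the default SwinConfig (embed_dim=96,
# depths=[2, 2, 6, 2], 56x56 patch grid); purely illustrative, mirroring the
# constructor arithmetic above.
_demo_config = SwinConfig()
_demo_dims = [int(_demo_config.embed_dim * 2**i) for i in range(len(_demo_config.depths))]
# _demo_dims == [96, 192, 384, 768]
_demo_res = [(56 // 2**i, 56 // 2**i) for i in range(len(_demo_config.depths))]
# _demo_res == [(56, 56), (28, 28), (14, 14), (7, 7)]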
# 调用层的build方法构建层次结构
class TFSwinPreTrainedModel(TFPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
# 使用 SwinConfig 类作为模型的配置类
config_class = SwinConfig
# 基础模型的前缀名为 "swin"
base_model_prefix = "swin"
# 主输入名称为 "pixel_values"
main_input_name = "pixel_values"
SWIN_START_DOCSTRING = r"""
This model is a Tensorflow
[keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) sub-class. Use it as a
regular Tensorflow Module and refer to the Tensorflow documentation for all matter related to general usage and
behavior.
Parameters:
config ([`SwinConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
SWIN_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
for details.
head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
def normalize_data_format(value: str) -> str:
"""
From tensorflow addons
https://github.com/tensorflow/addons/blob/8cec33fcaaf1cf90aec7bdd55a0fcdbb251ce5c2/tensorflow_addons/utils/keras_utils.py
"""
# 如果值为 None,则使用 keras 后端的图像数据格式作为值
if value is None:
value = keras.backend.image_data_format()
# 将值转换为小写
data_format = value.lower()
# 如果数据格式不是 "channels_first" 或 "channels_last",则引发 ValueError 异常
if data_format not in {"channels_first", "channels_last"}:
raise ValueError(
'The `data_format` argument must be one of "channels_first", "channels_last". Received: ' + str(value)
)
# 返回标准化后的数据格式
return data_format
class AdaptiveAveragePooling1D(keras.layers.Layer):
"""
Args:
"""
"""
Average 1D Pooling with adaptive kernel size.
output_size: An integer or tuple/list of a single integer, specifying pooled_features.
The new size of output channels.
data_format: A string,
one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape `(batch, steps, channels)` while `channels_first` corresponds
to inputs with shape `(batch, channels, steps)`.
Input shape:
- If `data_format='channels_last'`: 3D tensor with shape `(batch, steps, channels)`.
- If `data_format='channels_first'`: 3D tensor with shape `(batch, channels, steps)`.
Output shape:
- If `data_format='channels_last'`: 3D tensor with shape `(batch_size, pooled_steps, channels)`.
- If `data_format='channels_first'`: 3D tensor with shape `(batch_size, channels, pooled_steps)`.
Adapted from [tensorflow-addon's adaptive pooling.py](
https://github.com/tensorflow/addons/blob/8cec33fcaaf1cf90aec7bdd55a0fcdbb251ce5c2/tensorflow_addons/layers/adaptive_pooling.py#L90-L120
)
"""
# 定义一个平均池化层,支持自适应核大小
class AveragePooling1D(tf.keras.layers.Layer):
def __init__(
self,
output_size: Union[int, Iterable[int]], # 池化后的输出尺寸,可以是整数或整数组成的可迭代对象
reduce_function: Callable = tf.reduce_mean, # 池化使用的函数,默认为平均值池化
data_format: Optional[str] = None, # 数据格式,默认为 None
**kwargs, # 其他参数
) -> None:
self.data_format = normalize_data_format(data_format) # 标准化数据格式
self.reduce_function = reduce_function # 池化函数
self.output_size = (output_size,) if isinstance(output_size, int) else tuple(output_size) # 输出尺寸的元组形式
super().__init__(**kwargs) # 调用父类初始化方法
def call(self, inputs: tf.Tensor, *args) -> None:
bins = self.output_size[0] # 获取输出尺寸中的第一个值作为 bins
if self.data_format == "channels_last":
splits = tf.split(inputs, bins, axis=1) # 在通道维度上分割输入张量
splits = tf.stack(splits, axis=1) # 在第二个维度上堆叠分割后的张量
out_vect = self.reduce_function(splits, axis=2) # 沿着第三个维度对堆叠后的张量进行池化
else:
splits = tf.split(inputs, bins, axis=2) # 在时间步维度上分割输入张量
splits = tf.stack(splits, axis=2) # 在第三个维度上堆叠分割后的张量
out_vect = self.reduce_function(splits, axis=3) # 沿着第四个维度对堆叠后的张量进行池化
return out_vect # 返回池化后的张量
def compute_output_shape(self, input_shape: Iterable[int]) -> tf.TensorShape:
input_shape = tf.TensorShape(input_shape).as_list() # 将输入形状转换为列表形式
if self.data_format == "channels_last":
shape = tf.TensorShape([input_shape[0], self.output_size[0], input_shape[2]]) # 计算输出形状,通道在最后
else:
shape = tf.TensorShape([input_shape[0], input_shape[1], self.output_size[0]]) # 计算输出形状,通道在中间
return shape # 返回输出形状的张量形状对象
def get_config(self) -> Dict[str, Any]:
config = {
"output_size": self.output_size, # 输出尺寸配置
"data_format": self.data_format, # 数据格式配置
}
base_config = super().get_config() # 调用父类配置方法
return {**base_config, **config} # 返回合并后的配置字典
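# --- Illustrative use of AdaptiveAveragePooling1D (not part of the original
# file): collapse the 49 patch tokens of a (batch, steps, channels) tensor into
# a single pooled step, as the main layer's pooler does below. Shapes are assumptions.
_demo_pool = AdaptiveAveragePooling1D(output_size=1)
_demo_tokens = tf.random.normal((2, 49, 768))
_demo_pooled = _demo_pool(_demo_tokens)
# _demo_pooled: (2, 1, 768); TFSwinMainLayer later reshapes this to (2, 768)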
# 定义一个 Keras 自定义层 TFSwinMainLayer,并添加了 keras_serializable 装饰器,使其能够序列化
@keras_serializable
class TFSwinMainLayer(keras.layers.Layer):
# 设置配置类为 SwinConfig
config_class = SwinConfig
# 初始化函数,接受 SwinConfig 类型的 config 参数,以及其他可选参数
def __init__(
self, config: SwinConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs
) -> None:
# 调用父类的初始化方法
super().__init__(**kwargs)
# 将传入的配置参数 config 赋值给对象的 config 属性
self.config = config
# 计算层数,即配置的深度列表的长度
self.num_layers = len(config.depths)
# 计算特征数,为配置中的嵌入维度乘以 2 的 (层数 - 1) 次方
self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
# 创建 TFSwinEmbeddings 对象,并赋值给 embeddings 属性
self.embeddings = TFSwinEmbeddings(config, use_mask_token=use_mask_token, name="embeddings")
# 创建 TFSwinEncoder 对象,并传入 patch_grid 参数和名称 "encoder",赋值给 encoder 属性
self.encoder = TFSwinEncoder(config, self.embeddings.patch_grid, name="encoder")
# 创建 LayerNormalization 层,epsilon 参数为配置中的层归一化 epsilon 值,名称为 "layernorm",赋值给 layernorm 属性
self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
# 如果 add_pooling_layer 为 True,则创建 AdaptiveAveragePooling1D 层,输出大小为 (1,),赋值给 pooler 属性;否则 pooler 属性为 None
self.pooler = AdaptiveAveragePooling1D(output_size=(1,)) if add_pooling_layer else None
# 获取输入嵌入的方法,返回 embeddings 对象的 patch_embeddings 属性
def get_input_embeddings(self) -> TFSwinPatchEmbeddings:
return self.embeddings.patch_embeddings
# 模型头部修剪方法,接受 heads_to_prune 参数,用于剪枝模型中的注意力头
def _prune_heads(self, heads_to_prune: Dict[int, List]):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
# 遍历 heads_to_prune 字典中的每一层和对应要剪枝的注意力头列表
for layer, heads in heads_to_prune.items():
# 在编码器(self.encoder)的指定层(layer)的注意力部分(attention)进行头部剪枝操作
self.encoder.layer[layer].attention.prune_heads(heads)
# 获取头部掩码的方法,接受 head_mask 参数,如果非空则抛出未实现错误,否则返回与深度列表长度相同的 None 列表
def get_head_mask(self, head_mask: Optional[Any]) -> List:
if head_mask is not None:
raise NotImplementedError
return [None] * len(self.config.depths)
# 调用方法,接受多个参数并进行处理,包括像素值、掩码位置、头部掩码等
@unpack_inputs
def call(
self,
pixel_values: tf.Tensor | None = None,
bool_masked_pos: tf.Tensor | None = None,
head_mask: tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[TFSwinModelOutput, Tuple[tf.Tensor, ...]]:
# 如果未指定,则根据配置确定是否输出注意力权重
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# 如果未指定,则根据配置确定是否输出隐藏状态
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# 如果未指定,则根据配置确定是否返回字典格式的输出
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# 如果像素值为空,则抛出数值错误异常
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
# 准备头部掩码(如果需要)
# head_mask 中的 1.0 表示保留对应的注意力头部
# attention_probs 的形状为 bsz x n_heads x N x N
# 输入的 head_mask 形状为 [num_heads] 或者 [num_hidden_layers x num_heads]
# head_mask 被转换为形状 [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask)
# 将像素值传入嵌入层,并获取嵌入层的输出和输入维度
embedding_output, input_dimensions = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, training=training
)
# 将嵌入层的输出传入编码器,并返回编码器的输出
encoder_outputs = self.encoder(
embedding_output,
input_dimensions,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
# 获取编码器的序列输出,并进行 layer normalization
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output, training=training)
# 初始化池化输出为 None
pooled_output = None
# 如果池化器不为空,则对序列输出进行池化
if self.pooler is not None:
batch_size, _, num_features = shape_list(sequence_output)
pooled_output = self.pooler(sequence_output)
pooled_output = tf.reshape(pooled_output, (batch_size, num_features))
# 如果不需要返回字典,则返回输出元组
if not return_dict:
output = (sequence_output, pooled_output) + encoder_outputs[1:]
return output
# 如果需要返回字典格式的输出,则构建 TFSwinModelOutput 对象
return TFSwinModelOutput(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
)
def build(self, input_shape=None):
# 如果已经构建过,则直接返回
if self.built:
return
# 标记已经构建
self.built = True
# 如果存在嵌入层,则构建嵌入层
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
# 如果存在编码器,则构建编码器
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
# 如果存在层归一化,则构建层归一化
if getattr(self, "layernorm", None) is not None:
with tf.name_scope(self.layernorm.name):
self.layernorm.build([None, None, self.num_features])
# 使用装饰器为类添加文档字符串,描述其作为裸的 Swin 模型变换器,输出未经任何特定头部处理的原始隐藏状态
@add_start_docstrings(
"The bare Swin Model transformer outputting raw hidden-states without any specific head on top.",
SWIN_START_DOCSTRING,
)
# 定义 TFSwinModel 类,继承自 TFSwinPreTrainedModel
class TFSwinModel(TFSwinPreTrainedModel):
# 初始化方法
def __init__(
self, config: SwinConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs
) -> None:
# 调用父类的初始化方法
super().__init__(config, **kwargs)
# 保存配置信息到实例变量
self.config = config
# 创建 TFSwinMainLayer 的实例 swin,并命名为 "swin"
self.swin = TFSwinMainLayer(config, name="swin")
# 为 call 方法添加文档字符串,描述其作为模型前向传播的入口点,使用 SWIN_INPUTS_DOCSTRING 作为输入文档字符串
@add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
# 使用装饰器添加代码示例文档字符串,展示模型的使用示例
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFSwinModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
# Unpack the inputs so they are handled correctly
@unpack_inputs
# call takes several arguments and returns a TFSwinModelOutput or a tuple of tf.Tensor
def call(
self,
pixel_values: tf.Tensor | None = None,
bool_masked_pos: tf.Tensor | None = None,
head_mask: tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[TFSwinModelOutput, Tuple[tf.Tensor, ...]]:
r"""
bool_masked_pos (`tf.Tensor` of shape `(batch_size, num_patches)`, *optional*):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
"""
# Decide whether to output attention weights, defaulting to the config setting
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# Decide whether to output hidden states, defaulting to the config setting
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# Decide whether to return a dict-style output, defaulting to the config setting
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Raise a ValueError if no pixel values were provided
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
# Run the forward pass of self.swin with all arguments and collect the model outputs
swin_outputs = self.swin(
pixel_values=pixel_values,
bool_masked_pos=bool_masked_pos,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
# Return the model outputs
return swin_outputs
# build method, used to construct the model's layer hierarchy
def build(self, input_shape=None):
# Return immediately if already built
if self.built:
return
# Mark the model as built
self.built = True
# If self.swin exists, build it under its name scope
if getattr(self, "swin", None) is not None:
with tf.name_scope(self.swin.name):
self.swin.build(None)
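# Illustrative usage sketch (not part of the source file): loading TFSwinModel from the checkpoint
# referenced by _CHECKPOINT_FOR_DOC and inspecting the last hidden state. Assumes the `transformers`,
# `Pillow` and `requests` packages are installed; the URL is just a commonly used sample picture.
from PIL import Image
import requests
from transformers import AutoImageProcessor, TFSwinModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
model = TFSwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
inputs = processor(images=image, return_tensors="tf")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, 49, 768) for a 224x224 input, matching _EXPECTED_OUTPUT_SHAPE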
# TFSwinPixelShuffle, a keras.layers.Layer implementing the TensorFlow counterpart of torch.nn.PixelShuffle
class TFSwinPixelShuffle(keras.layers.Layer):
"""TF layer implementation of torch.nn.PixelShuffle"""
# Constructor
def __init__(self, upscale_factor: int, **kwargs) -> None:
# Call the parent constructor
super().__init__(**kwargs)
# Raise a ValueError if upscale_factor is not an integer >= 2
if not isinstance(upscale_factor, int) or upscale_factor < 2:
raise ValueError(f"upscale_factor must be an integer value >= 2 got {upscale_factor}")
# Store upscale_factor on the instance
self.upscale_factor = upscale_factor
# Takes a tensor x as input and returns a tensor
def call(self, x: tf.Tensor) -> tf.Tensor:
# Alias the input tensor as hidden_states
hidden_states = x
# Use shape_list to get the shape of hidden_states and unpack batch_size, _, _, num_input_channels
batch_size, _, _, num_input_channels = shape_list(hidden_states)
# Square of the block size
block_size_squared = self.upscale_factor**2
# Output depth: num_input_channels divided by the squared block size
output_depth = int(num_input_channels / block_size_squared)
# Constant tensor permutation holding the channel reordering indices
permutation = tf.constant(
# List comprehension producing one index per channel, ordered over blocks and output channels
[[i + j * block_size_squared for i in range(block_size_squared) for j in range(output_depth)]]
)
# Reorder the channels of hidden_states with tf.gather according to permutation
hidden_states = tf.gather(params=hidden_states, indices=tf.tile(permutation, [batch_size, 1]), batch_dims=-1)
# tf.nn.depth_to_space rearranges blocks from depth to space according to upscale_factor
hidden_states = tf.nn.depth_to_space(hidden_states, block_size=self.upscale_factor, data_format="NHWC")
# Return the processed hidden_states
return hidden_states
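# Illustrative sketch (not part of the source file): TFSwinPixelShuffle takes an NHWC tensor whose
# channel count is upscale_factor**2 * d and returns one whose height and width are multiplied by
# upscale_factor, with d channels, mirroring torch.nn.PixelShuffle.
import tensorflow as tf

x = tf.random.normal((1, 8, 8, 12))  # 12 channels = 2**2 * 3
shuffled = TFSwinPixelShuffle(upscale_factor=2)(x)
print(shuffled.shape)  # (1, 16, 16, 3)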
# Custom Keras layer implementing the decoder part of the TFSwin model
class TFSwinDecoder(keras.layers.Layer):
def __init__(self, config: SwinConfig, **kwargs):
super().__init__(**kwargs)
# 1x1 convolution used for feature transformation
self.conv2d = keras.layers.Conv2D(
filters=config.encoder_stride**2 * config.num_channels, kernel_size=1, strides=1, name="0"
)
# Pixel shuffle layer used to rearrange channels back into spatial pixels
self.pixel_shuffle = TFSwinPixelShuffle(config.encoder_stride, name="1")
# Store the Swin model configuration
self.config = config
def call(self, x: tf.Tensor) -> tf.Tensor:
# Transpose the input from B, C, H, W to B, H, W, C
hidden_states = x
hidden_states = tf.transpose(hidden_states, (0, 2, 3, 1))
# Apply the 1x1 convolution
hidden_states = self.conv2d(hidden_states)
# Apply the pixel shuffle layer
hidden_states = self.pixel_shuffle(hidden_states)
# Transpose the output from B, H, W, C back to B, C, H, W
hidden_states = tf.transpose(hidden_states, (0, 3, 1, 2))
return hidden_states
def build(self, input_shape=None):
# Return immediately if already built
if self.built:
return
self.built = True
# Build the convolution layer
if getattr(self, "conv2d", None) is not None:
with tf.name_scope(self.conv2d.name):
self.conv2d.build([None, None, None, self.config.hidden_size])
# Build the pixel shuffle layer
if getattr(self, "pixel_shuffle", None) is not None:
with tf.name_scope(self.pixel_shuffle.name):
self.pixel_shuffle.build(None)
# Swin variant with a decoder on top for masked image modeling, following the method proposed in SimMIM
@add_start_docstrings(
"Swin Model with a decoder on top for masked image modeling, as proposed in"
" [SimMIM](https://arxiv.org/abs/2111.09886).",
SWIN_START_DOCSTRING,
)
class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel):
def __init__(self, config: SwinConfig):
super().__init__(config)
# Swin main layer, without a pooling layer and using the mask token
self.swin = TFSwinMainLayer(config, add_pooling_layer=False, use_mask_token=True, name="swin")
# Swin decoder layer
self.decoder = TFSwinDecoder(config, name="decoder")
@add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSwinMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
@unpack_inputs
def call(
self,
pixel_values: tf.Tensor | None = None,
bool_masked_pos: tf.Tensor | None = None,
head_mask: tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
):
# (body omitted in this walkthrough)
pass
def build(self, input_shape=None):
# Return immediately if already built
if self.built:
return
self.built = True
# Build the Swin main layer
if getattr(self, "swin", None) is not None:
with tf.name_scope(self.swin.name):
self.swin.build(None)
# Build the Swin decoder layer
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)
# Swin variant for image classification, with a linear classification head on top of the final hidden state of the [CLS] token, e.g. for ImageNet
@add_start_docstrings(
"""
Swin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
the [CLS] token) e.g. for ImageNet.
""",
SWIN_START_DOCSTRING,
)
class TFSwinForImageClassification(TFSwinPreTrainedModel, TFSequenceClassificationLoss):
# (body omitted in this walkthrough)
pass
# Constructor taking a SwinConfig configuration object
def __init__(self, config: SwinConfig):
# Call the parent constructor
super().__init__(config)
# Store the number of labels for classification
self.num_labels = config.num_labels
# Create a TFSwinMainLayer instance named "swin"
self.swin = TFSwinMainLayer(config, name="swin")
# Classifier head
# If the configured number of labels is greater than 0, use a dense layer as the classifier,
# otherwise use a linear activation layer
self.classifier = (
keras.layers.Dense(config.num_labels, name="classifier")
if config.num_labels > 0
else keras.layers.Activation("linear", name="classifier")
)
# Forward pass; its docstrings are supplied by the decorators below
@add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=TFSwinImageClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
@unpack_inputs
def call(
self,
pixel_values: tf.Tensor | None = None,
head_mask: tf.Tensor | None = None,
labels: tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[Tuple[tf.Tensor, ...], TFSwinImageClassifierOutput]:
"""
labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
# Decide whether to return a dict-style output
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Run the forward pass of the Swin model
outputs = self.swin(
pixel_values,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
# Take the pooled output
pooled_output = outputs[1]
# Feed the pooled output to the classifier to obtain the logits
logits = self.classifier(pooled_output, training=training)
# Compute the loss if labels were provided
loss = None if labels is None else self.hf_compute_loss(labels, logits)
# If a dict is not requested, return the outputs as a tuple
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
# Otherwise return a TFSwinImageClassifierOutput
return TFSwinImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states,
)
# Build the model and its layers
def build(self, input_shape=None):
# Return immediately if already built
if self.built:
return
# Mark the model as built
self.built = True
# Build the Swin layer under its name scope if it exists
if getattr(self, "swin", None) is not None:
with tf.name_scope(self.swin.name):
self.swin.build(None)
# Build the classifier under its name scope if it exists, passing the Swin feature size as part of the input shape
if getattr(self, "classifier", None) is not None:
if hasattr(self.classifier, "name"):
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.swin.num_features])
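# Illustrative usage sketch (not part of the source file): image classification with the checkpoint
# referenced by _IMAGE_CLASS_CHECKPOINT. Assumes `transformers`, `Pillow` and `requests` are installed.
from PIL import Image
import requests
import tensorflow as tf
from transformers import AutoImageProcessor, TFSwinForImageClassification

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
model = TFSwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
inputs = processor(images=image, return_tensors="tf")
logits = model(**inputs).logits
predicted_class = int(tf.math.argmax(logits, axis=-1)[0])
print(model.config.id2label[predicted_class])  # expected "tabby, tabby cat", see _IMAGE_CLASS_EXPECTED_OUTPUT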
.\models\swin\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {"configuration_swin": ["SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwinConfig", "SwinOnnxConfig"]}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_swin"] = [
"SWIN_PRETRAINED_MODEL_ARCHIVE_LIST",
"SwinForImageClassification",
"SwinForMaskedImageModeling",
"SwinModel",
"SwinPreTrainedModel",
"SwinBackbone",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_swin"] = [
"TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSwinForImageClassification",
"TFSwinForMaskedImageModeling",
"TFSwinModel",
"TFSwinPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_swin import SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, SwinConfig, SwinOnnxConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_swin import (
SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
SwinBackbone,
SwinForImageClassification,
SwinForMaskedImageModeling,
SwinModel,
SwinPreTrainedModel,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_swin import (
TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSwinForImageClassification,
TFSwinForMaskedImageModeling,
TFSwinModel,
TFSwinPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\swin2sr\configuration_swin2sr.py
""" Swin2SR Transformer model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
SWIN2SR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"caidas/swin2sr-classicalsr-x2-64": (
"https://huggingface.co/caidas/swin2sr-classicalsr-x2-64/resolve/main/config.json"
),
}
class Swin2SRConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Swin2SRModel`]. It is used to instantiate a Swin
Transformer v2 model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the Swin Transformer v2
[caidas/swin2sr-classicalsr-x2-64](https://huggingface.co/caidas/swin2sr-classicalsr-x2-64) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```
>>> from transformers import Swin2SRConfig, Swin2SRModel
>>> # Initializing a Swin2SR caidas/swin2sr-classicalsr-x2-64 style configuration
>>> configuration = Swin2SRConfig()
>>> # Initializing a model (with random weights) from the caidas/swin2sr-classicalsr-x2-64 style configuration
>>> model = Swin2SRModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "swin2sr"
attribute_map = {
"hidden_size": "embed_dim",
"num_attention_heads": "num_heads",
"num_hidden_layers": "num_layers",
}
def __init__(
self,
image_size=64,
patch_size=1,
num_channels=3,
num_channels_out=None,
embed_dim=180,
depths=[6, 6, 6, 6, 6, 6],
num_heads=[6, 6, 6, 6, 6, 6],
window_size=8,
mlp_ratio=2.0,
qkv_bias=True,
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
drop_path_rate=0.1,
hidden_act="gelu",
use_absolute_embeddings=False,
initializer_range=0.02,
layer_norm_eps=1e-5,
upscale=2,
img_range=1.0,
resi_connection="1conv",
upsampler="pixelshuffle",
**kwargs,
):
super().__init__(**kwargs)
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_channels_out = num_channels if num_channels_out is None else num_channels_out
self.embed_dim = embed_dim
self.depths = depths
self.num_layers = len(depths)
self.num_heads = num_heads
self.window_size = window_size
self.mlp_ratio = mlp_ratio
self.qkv_bias = qkv_bias
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.drop_path_rate = drop_path_rate
self.hidden_act = hidden_act
self.use_absolute_embeddings = use_absolute_embeddings
self.layer_norm_eps = layer_norm_eps
self.initializer_range = initializer_range
self.upscale = upscale
self.img_range = img_range
self.resi_connection = resi_connection
self.upsampler = upsampler
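# Illustrative sketch (not part of the source file): because of `attribute_map`, the generic
# configuration names resolve to the Swin2SR-specific attributes set above.
config = Swin2SRConfig()
assert config.hidden_size == config.embed_dim == 180
assert config.num_hidden_layers == config.num_layers == len(config.depths) == 6
assert config.num_attention_heads == config.num_heads == [6, 6, 6, 6, 6, 6]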
.\models\swin2sr\convert_swin2sr_original_to_pytorch.py
"""Convert Swin2SR checkpoints from the original repository. URL: https://github.com/mv-lab/swin2sr"""
import argparse
import requests
import torch
from PIL import Image
from torchvision.transforms import Compose, Normalize, Resize, ToTensor
from transformers import Swin2SRConfig, Swin2SRForImageSuperResolution, Swin2SRImageProcessor
def get_config(checkpoint_url):
config = Swin2SRConfig()
if "Swin2SR_ClassicalSR_X4_64" in checkpoint_url:
config.upscale = 4
elif "Swin2SR_CompressedSR_X4_48" in checkpoint_url:
config.upscale = 4
config.image_size = 48
config.upsampler = "pixelshuffle_aux"
elif "Swin2SR_Lightweight_X2_64" in checkpoint_url:
config.depths = [6, 6, 6, 6]
config.embed_dim = 60
config.num_heads = [6, 6, 6, 6]
config.upsampler = "pixelshuffledirect"
elif "Swin2SR_RealworldSR_X4_64_BSRGAN_PSNR" in checkpoint_url:
config.upscale = 4
config.upsampler = "nearest+conv"
elif "Swin2SR_Jpeg_dynamic" in checkpoint_url:
config.num_channels = 1
config.upscale = 1
config.image_size = 126
config.window_size = 7
config.img_range = 255.0
config.upsampler = ""
return config
def rename_key(name, config):
if "patch_embed.proj" in name and "layers" not in name:
name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
if "patch_embed.norm" in name:
name = name.replace("patch_embed.norm", "embeddings.patch_embeddings.layernorm")
if "layers" in name:
name = name.replace("layers", "encoder.stages")
if "residual_group.blocks" in name:
name = name.replace("residual_group.blocks", "layers")
if "attn.proj" in name:
name = name.replace("attn.proj", "attention.output.dense")
if "attn" in name:
name = name.replace("attn", "attention.self")
if "norm1" in name:
name = name.replace("norm1", "layernorm_before")
if "norm2" in name:
name = name.replace("norm2", "layernorm_after")
if "mlp.fc1" in name:
name = name.replace("mlp.fc1", "intermediate.dense")
if "mlp.fc2" in name:
name = name.replace("mlp.fc2", "output.dense")
if "q_bias" in name:
name = name.replace("q_bias", "query.bias")
if "k_bias" in name:
name = name.replace("k_bias", "key.bias")
if "v_bias" in name:
name = name.replace("v_bias", "value.bias")
if "cpb_mlp" in name:
name = name.replace("cpb_mlp", "continuous_position_bias_mlp")
if "patch_embed.proj" in name:
name = name.replace("patch_embed.proj", "patch_embed.projection")
if name == "norm.weight":
name = "layernorm.weight"
if name == "norm.bias":
name = "layernorm.bias"
if "conv_first" in name:
name = name.replace("conv_first", "first_convolution")
if (
"upsample" in name
or "conv_before_upsample" in name
or "conv_bicubic" in name
or "conv_up" in name
or "conv_hr" in name
or "conv_last" in name
or "aux" in name
):
if "conv_last" in name:
name = name.replace("conv_last", "final_convolution")
if config.upsampler in ["pixelshuffle", "pixelshuffle_aux", "nearest+conv"]:
if "conv_before_upsample.0" in name:
name = name.replace("conv_before_upsample.0", "conv_before_upsample")
if "upsample.0" in name:
name = name.replace("upsample.0", "upsample.convolution_0")
if "upsample.2" in name:
name = name.replace("upsample.2", "upsample.convolution_1")
name = "upsample." + name
elif config.upsampler == "pixelshuffledirect":
name = name.replace("upsample.0.weight", "upsample.conv.weight")
name = name.replace("upsample.0.bias", "upsample.conv.bias")
else:
pass
else:
name = "swin2sr." + name
return name
def convert_state_dict(orig_state_dict, config):
for key in orig_state_dict.copy().keys():
val = orig_state_dict.pop(key)
if "qkv" in key:
key_split = key.split(".")
stage_num = int(key_split[1])
block_num = int(key_split[4])
dim = config.embed_dim
if "weight" in key:
orig_state_dict[
f"swin2sr.encoder.stages.{stage_num}.layers.{block_num}.attention.self.query.weight"
] = val[:dim, :]
orig_state_dict[
f"swin2sr.encoder.stages.{stage_num}.layers.{block_num}.attention.self.key.weight"
] = val[dim : dim * 2, :]
orig_state_dict[
f"swin2sr.encoder.stages.{stage_num}.layers.{block_num}.attention.self.value.weight"
] = val[-dim:, :]
else:
orig_state_dict[
f"swin2sr.encoder.stages.{stage_num}.layers.{block_num}.attention.self.query.bias"
] = val[:dim]
orig_state_dict[
f"swin2sr.encoder.stages.{stage_num}.layers.{block_num}.attention.self.key.bias"
] = val[dim : dim * 2]
orig_state_dict[
f"swin2sr.encoder.stages.{stage_num}.layers.{block_num}.attention.self.value.bias"
] = val[-dim:]
pass
else:
orig_state_dict[rename_key(key, config)] = val
return orig_state_dict
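# Illustrative sketch (not part of the source file): the fused `qkv` parameter has 3 * dim rows that
# are split row-wise into query, key and value, exactly as convert_state_dict does above.
dim = 4  # toy embedding size; the real checkpoints use config.embed_dim
qkv_weight = torch.arange(3 * dim * dim, dtype=torch.float32).reshape(3 * dim, dim)
query_w, key_w, value_w = qkv_weight[:dim, :], qkv_weight[dim : dim * 2, :], qkv_weight[-dim:, :]
assert torch.equal(torch.cat([query_w, key_w, value_w], dim=0), qkv_weight)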
def convert_swin2sr_checkpoint(checkpoint_url, pytorch_dump_folder_path, push_to_hub):
config = get_config(checkpoint_url)
model = Swin2SRForImageSuperResolution(config)
model.eval()
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
new_state_dict = convert_state_dict(state_dict, config)
missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
if len(missing_keys) > 0:
raise ValueError("Missing keys when converting: {}".format(missing_keys))
for key in unexpected_keys:
if not ("relative_position_index" in key or "relative_coords_table" in key or "self_mask" in key):
raise ValueError(f"Unexpected key {key} in state_dict")
url = "https://github.com/mv-lab/swin2sr/blob/main/testsets/real-inputs/shanghai.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
processor = Swin2SRImageProcessor()
image_size = 126 if "Jpeg" in checkpoint_url else 256
transforms = Compose(
[
Resize((image_size, image_size)),
ToTensor(),
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
pixel_values = transforms(image).unsqueeze(0)
if config.num_channels == 1:
pixel_values = pixel_values[:, 0, :, :].unsqueeze(1)
outputs = model(pixel_values)
if "Swin2SR_ClassicalSR_X2_64" in checkpoint_url:
expected_shape = torch.Size([1, 3, 512, 512])
expected_slice = torch.tensor(
[[-0.7087, -0.7138, -0.6721], [-0.8340, -0.8095, -0.7298], [-0.9149, -0.8414, -0.7940]]
)
elif "Swin2SR_ClassicalSR_X4_64" in checkpoint_url:
expected_shape = torch.Size([1, 3, 1024, 1024])
expected_slice = torch.tensor(
[[-0.7775, -0.8105, -0.8933], [-0.7764, -0.8356, -0.9225], [-0.7976, -0.8686, -0.9579]]
)
elif "Swin2SR_CompressedSR_X4_48" in checkpoint_url:
expected_shape = torch.Size([1, 3, 1024, 1024])
expected_slice = torch.tensor(
[[-0.8035, -0.7504, -0.7491], [-0.8538, -0.8124, -0.7782], [-0.8804, -0.8651, -0.8493]]
)
elif "Swin2SR_Lightweight_X2_64" in checkpoint_url:
expected_shape = torch.Size([1, 3, 512, 512])
expected_slice = torch.tensor(
[[-0.7669, -0.8662, -0.8767], [-0.8810, -0.9962, -0.9820], [-0.9340, -1.0322, -1.1149]]
)
elif "Swin2SR_RealworldSR_X4_64_BSRGAN_PSNR" in checkpoint_url:
expected_shape = torch.Size([1, 3, 1024, 1024])
expected_slice = torch.tensor(
[[-0.5238, -0.5557, -0.6321], [-0.6016, -0.5903, -0.6391], [-0.6244, -0.6334, -0.6889]]
)
assert (
outputs.reconstruction.shape == expected_shape
), f"Shape of reconstruction should be {expected_shape}, but is {outputs.reconstruction.shape}"
assert torch.allclose(outputs.reconstruction[0, 0, :3, :3], expected_slice, atol=1e-3)
print("Looks ok!")
url_to_name = {
"https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/Swin2SR_ClassicalSR_X2_64.pth": (
"swin2SR-classical-sr-x2-64"
),
"https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/Swin2SR_ClassicalSR_X4_64.pth": (
"swin2SR-classical-sr-x4-64"
),
"https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/Swin2SR_CompressedSR_X4_48.pth": (
"swin2SR-compressed-sr-x4-48"
),
"https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/Swin2SR_Lightweight_X2_64.pth": (
"swin2SR-lightweight-x2-64"
),
"https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/Swin2SR_RealworldSR_X4_64_BSRGAN_PSNR.pth": (
"swin2SR-realworld-sr-x4-64-bsrgan-psnr"
),
}
model_name = url_to_name[checkpoint_url]
if pytorch_dump_folder_path is not None:
print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"Saving image processor to {pytorch_dump_folder_path}")
processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
model.push_to_hub(f"caidas/{model_name}")
processor.push_to_hub(f"caidas/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpoint_url",
default="https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/Swin2SR_ClassicalSR_X2_64.pth",
type=str,
help="URL of the original Swin2SR checkpoint you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the converted model to the hub.")
args = parser.parse_args()
convert_swin2sr_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path, args.push_to_hub)
.\models\swin2sr\image_processing_swin2sr.py
"""Image processor class for Swin2SR."""
from typing import Optional, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_transforms import get_image_size, pad, to_channel_dimension_format
from ...image_utils import (
ChannelDimension,
ImageInput,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
validate_kwargs,
validate_preprocess_arguments,
)
from ...utils import TensorType, logging
logger = logging.get_logger(__name__)
class Swin2SRImageProcessor(BaseImageProcessor):
r"""
Constructs a Swin2SR image processor.
Args:
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
`preprocess` method.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_pad: bool = True,
pad_size: int = 8,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_pad = do_pad
self.pad_size = pad_size
self._valid_processor_keys = [
"images",
"do_rescale",
"rescale_factor",
"do_pad",
"pad_size",
"return_tensors",
"data_format",
"input_data_format",
]
def pad(
self,
image: np.ndarray,
size: int,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""
Pad an image to make the height and width divisible by `size`.
Args:
image (`np.ndarray`):
Image to pad.
size (`int`):
The size to make the height and width divisible by.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The padded image.
"""
old_height, old_width = get_image_size(image, input_data_format)
pad_height = (old_height // size + 1) * size - old_height
pad_width = (old_width // size + 1) * size - old_width
return pad(
image,
((0, pad_height), (0, pad_width)),
mode="symmetric",
data_format=data_format,
input_data_format=input_data_format,
)
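# Illustrative sketch (not part of the source file): with the formula above, padding always adds at
# least one pixel, so even a side that is already divisible by `size` grows by a full extra `size`.
old_height, size = 510, 8
pad_height = (old_height // size + 1) * size - old_height  # 2 -> padded height 512
old_height, size = 512, 8
pad_height = (old_height // size + 1) * size - old_height  # 8 -> padded height 520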
def preprocess(
self,
images: ImageInput,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_pad: Optional[bool] = None,
pad_size: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
.\models\swin2sr\modeling_swin2sr.py
""" PyTorch Swin2SR Transformer model."""
import collections.abc
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, ImageSuperResolutionOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_swin2sr import Swin2SRConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Swin2SRConfig"
_CHECKPOINT_FOR_DOC = "caidas/swin2SR-classical-sr-x2-64"
_EXPECTED_OUTPUT_SHAPE = [1, 180, 488, 648]
SWIN2SR_PRETRAINED_MODEL_ARCHIVE_LIST = [
"caidas/swin2SR-classical-sr-x2-64",
]
@dataclass
class Swin2SREncoderOutput(ModelOutput):
"""
Output of the Swin2SR encoder, potentially containing hidden states and attention weights.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of the hidden states of the model at the output of each layer, plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of the attention weights for each stage of the model.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
def window_partition(input_feature, window_size):
"""
Partitions the given input into windows.
"""
batch_size, height, width, num_channels = input_feature.shape
input_feature = input_feature.view(
batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
)
windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
return windows
def window_reverse(windows, window_size, height, width):
"""
Merges windows to produce higher resolution features.
"""
num_channels = windows.shape[-1]
windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
return windows
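# Illustrative sketch (not part of the source file): window_partition and window_reverse are exact
# inverses when height and width are multiples of the window size.
features = torch.randn(2, 8, 8, 16)  # (batch, height, width, channels)
windows = window_partition(features, window_size=4)
print(windows.shape)  # torch.Size([8, 4, 4, 16]): 2 images x 4 windows each
restored = window_reverse(windows, window_size=4, height=8, width=8)
assert torch.equal(restored, features)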
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
argument.
"""
if drop_prob == 0.0 or not training:
return input
keep_prob = 1 - drop_prob
shape = (input.shape[0],) + (1,) * (input.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
random_tensor.floor_()
output = input.div(keep_prob) * random_tensor
return output
class Swin2SRDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: Optional[float] = None) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return drop_path(hidden_states, self.drop_prob, self.training)
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
class Swin2SREmbeddings(nn.Module):
"""
Construct the patch and optional position embeddings.
"""
def __init__(self, config):
super().__init__()
self.patch_embeddings = Swin2SRPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
if config.use_absolute_embeddings:
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
else:
self.position_embeddings = None
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.window_size = config.window_size
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor]:
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
if self.position_embeddings is not None:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings, output_dimensions
class Swin2SRPatchEmbeddings(nn.Module):
def __init__(self, config, normalize_patches=True):
super().__init__()
num_channels = config.embed_dim
image_size, patch_size = config.image_size, config.patch_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
patches_resolution = [image_size[0] // patch_size[0], image_size[1] // patch_size[1]]
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.projection = nn.Conv2d(num_channels, config.embed_dim, kernel_size=patch_size, stride=patch_size)
self.layernorm = nn.LayerNorm(config.embed_dim) if normalize_patches else None
def forward(self, embeddings: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
embeddings = self.projection(embeddings)
_, _, height, width = embeddings.shape
output_dimensions = (height, width)
embeddings = embeddings.flatten(2).transpose(1, 2)
if self.layernorm is not None:
embeddings = self.layernorm(embeddings)
return embeddings, output_dimensions
class Swin2SRPatchUnEmbeddings(nn.Module):
r"""Image to Patch Unembedding"""
def __init__(self, config):
super().__init__()
self.embed_dim = config.embed_dim
def forward(self, embeddings, x_size):
batch_size, height_width, num_channels = embeddings.shape
embeddings = embeddings.transpose(1, 2).view(batch_size, self.embed_dim, x_size[0], x_size[1])
return embeddings
class Swin2SRPatchMerging(nn.Module):
"""
Patch Merging Layer.
Args:
input_resolution (`Tuple[int]`):
Resolution of input feature.
dim (`int`):
Number of input channels.
norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
Normalization layer class.
"""
def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(2 * dim)
def maybe_pad(self, input_feature, height, width):
should_pad = (height % 2 == 1) or (width % 2 == 1)
if should_pad:
pad_values = (0, 0, 0, width % 2, 0, height % 2)
input_feature = nn.functional.pad(input_feature, pad_values)
return input_feature
def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
height, width = input_dimensions
batch_size, dim, num_channels = input_feature.shape
input_feature = input_feature.view(batch_size, height, width, num_channels)
input_feature = self.maybe_pad(input_feature, height, width)
input_feature_0 = input_feature[:, 0::2, 0::2, :]
input_feature_1 = input_feature[:, 1::2, 0::2, :]
input_feature_2 = input_feature[:, 0::2, 1::2, :]
input_feature_3 = input_feature[:, 1::2, 1::2, :]
input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
input_feature = input_feature.view(batch_size, -1, 4 * num_channels)
input_feature = self.reduction(input_feature)
input_feature = self.norm(input_feature)
return input_feature
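# Illustrative sketch (not part of the source file): patch merging concatenates each 2x2 neighbourhood
# of tokens, halving the spatial resolution and reducing 4 * dim channels to 2 * dim with the linear layer.
merging = Swin2SRPatchMerging(input_resolution=(8, 8), dim=32)
tokens = torch.randn(1, 8 * 8, 32)
merged = merging(tokens, input_dimensions=(8, 8))
print(merged.shape)  # torch.Size([1, 16, 64])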
class Swin2SRSelfAttention(nn.Module):
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
attention_scores = nn.functional.normalize(query_layer, dim=-1) @ nn.functional.normalize(
key_layer, dim=-1
).transpose(-2, -1)
logit_scale = torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp()
attention_scores = attention_scores * logit_scale
relative_position_bias_table = self.continuous_position_bias_mlp(self.relative_coords_table).view(
-1, self.num_attention_heads
)
relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
if attention_mask is not None:
mask_shape = attention_mask.shape[0]
attention_scores = attention_scores.view(
batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
)
attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
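# Illustrative sketch (not part of the source file): the scaled cosine attention used above, shown in
# isolation. Queries and keys are L2-normalized so raw scores lie in [-1, 1], then a per-head
# logit_scale (clamped to at most log(1 / 0.01) before exponentiation) sharpens them.
query = torch.randn(1, 2, 16, 8)  # (batch, heads, tokens, head_dim)
key = torch.randn(1, 2, 16, 8)
scores = nn.functional.normalize(query, dim=-1) @ nn.functional.normalize(key, dim=-1).transpose(-2, -1)
logit_scale = torch.clamp(torch.zeros(2, 1, 1), max=math.log(1.0 / 0.01)).exp()  # stand-in for the learned parameter
scores = scores * logit_scale
probs = nn.functional.softmax(scores, dim=-1)  # rows sum to 1 per query token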
class Swin2SRSelfOutput(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(dim, dim)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class Swin2SRAttention(nn.Module):
def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=0):
super().__init__()
self.self = Swin2SRSelfAttention(
config=config,
dim=dim,
num_heads=num_heads,
window_size=window_size,
pretrained_window_size=pretrained_window_size
if isinstance(pretrained_window_size, collections.abc.Iterable)
else (pretrained_window_size, pretrained_window_size),
)
self.output = Swin2SRSelfOutput(config, dim)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
class Swin2SRIntermediate(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class Swin2SROutput(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class Swin2SRLayer(nn.Module):
def __init__(self, config, dim, input_resolution, num_heads, shift_size=0, pretrained_window_size=0):
super().__init__()
self.input_resolution = input_resolution
window_size, shift_size = self._compute_window_shift(
(config.window_size, config.window_size), (shift_size, shift_size)
)
self.window_size = window_size[0]
self.shift_size = shift_size[0]
self.attention = Swin2SRAttention(
config=config,
dim=dim,
num_heads=num_heads,
window_size=self.window_size,
pretrained_window_size=pretrained_window_size
if isinstance(pretrained_window_size, collections.abc.Iterable)
else (pretrained_window_size, pretrained_window_size),
)
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.drop_path = Swin2SRDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
self.intermediate = Swin2SRIntermediate(config, dim)
self.output = Swin2SROutput(config, dim)
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
def _compute_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
return window_size, shift_size
def get_attn_mask(self, height, width, dtype):
if self.shift_size > 0:
img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
height_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
width_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
count = 0
for height_slice in height_slices:
for width_slice in width_slices:
img_mask[:, height_slice, width_slice, :] = count
count += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
return attn_mask
def maybe_pad(self, hidden_states, height, width):
pad_right = (self.window_size - width % self.window_size) % self.window_size
pad_bottom = (self.window_size - height % self.window_size) % self.window_size
pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
hidden_states = nn.functional.pad(hidden_states, pad_values)
return hidden_states, pad_values
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
height, width = input_dimensions
batch_size, _, channels = hidden_states.size()
shortcut = hidden_states
hidden_states = hidden_states.view(batch_size, height, width, channels)
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
_, height_pad, width_pad, _ = hidden_states.shape
if self.shift_size > 0:
shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_hidden_states = hidden_states
hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
if attn_mask is not None:
attn_mask = attn_mask.to(hidden_states_windows.device)
attention_outputs = self.attention(
hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
)
attention_output = attention_outputs[0]
attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)
if self.shift_size > 0:
attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
attention_windows = shifted_windows
was_padded = pad_values[3] > 0 or pad_values[5] > 0
if was_padded:
attention_windows = attention_windows[:, :height, :width, :].contiguous()
attention_windows = attention_windows.view(batch_size, height * width, channels)
hidden_states = self.layernorm_before(attention_windows)
hidden_states = shortcut + self.drop_path(hidden_states)
layer_output = self.intermediate(hidden_states)
layer_output = self.output(layer_output)
layer_output = hidden_states + self.drop_path(self.layernorm_after(layer_output))
layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
return layer_outputs
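# Illustrative sketch (not part of the source file): what get_attn_mask computes for a toy 8x8 map
# with window_size=4 and shift_size=2. Positions coming from regions that were not adjacent before
# the cyclic shift get -100.0, so they cannot attend to each other after the softmax.
window_size, shift_size = 4, 2
img_mask = torch.zeros((1, 8, 8, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
count = 0
for height_slice in slices:
    for width_slice in slices:
        img_mask[:, height_slice, width_slice, :] = count
        count += 1
mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
print(attn_mask.shape)  # torch.Size([4, 16, 16]): one mask per 4x4 window, entries in {0.0, -100.0}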
class Swin2SRStage(nn.Module):
"""
This corresponds to the Residual Swin Transformer Block (RSTB) in the original implementation.
"""
def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, pretrained_window_size=0):
super().__init__()
self.config = config
self.dim = dim
self.layers = nn.ModuleList(
[
Swin2SRLayer(
config=config,
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
shift_size=0 if (i % 2 == 0) else config.window_size // 2,
pretrained_window_size=pretrained_window_size,
)
for i in range(depth)
]
)
if config.resi_connection == "1conv":
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
elif config.resi_connection == "3conv":
self.conv = nn.Sequential(
nn.Conv2d(dim, dim // 4, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim, 3, 1, 1),
)
self.patch_embed = Swin2SRPatchEmbeddings(config, normalize_patches=False)
self.patch_unembed = Swin2SRPatchUnEmbeddings(config)
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
residual = hidden_states
height, width = input_dimensions
for i, layer_module in enumerate(self.layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
output_dimensions = (height, width, height, width)
hidden_states = self.patch_unembed(hidden_states, input_dimensions)
hidden_states = self.conv(hidden_states)
hidden_states, _ = self.patch_embed(hidden_states)
hidden_states = hidden_states + residual
stage_outputs = (hidden_states, output_dimensions)
if output_attentions:
stage_outputs += layer_outputs[1:]
return stage_outputs
class Swin2SREncoder(nn.Module):
def __init__(self, config, grid_size):
super().__init__()
self.num_stages = len(config.depths)
self.config = config
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
self.stages = nn.ModuleList(
[
Swin2SRStage(
config=config,
dim=config.embed_dim,
input_resolution=(grid_size[0], grid_size[1]),
depth=config.depths[stage_idx],
num_heads=config.num_heads[stage_idx],
drop_path=dpr[sum(config.depths[:stage_idx]) : sum(config.depths[: stage_idx + 1])],
pretrained_window_size=0,
)
for stage_idx in range(self.num_stages)
]
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, Swin2SREncoderOutput]:
all_input_dimensions = ()
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
if output_hidden_states:
all_hidden_states += (hidden_states,)
for i, stage_module in enumerate(self.stages):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
stage_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
)
else:
layer_outputs = stage_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
output_dimensions = layer_outputs[1]
input_dimensions = (output_dimensions[-2], output_dimensions[-1])
all_input_dimensions += (input_dimensions,)
if output_hidden_states:
all_hidden_states += (hidden_states,)
if output_attentions:
all_self_attentions += layer_outputs[2:]
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return Swin2SREncoderOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class Swin2SRPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = Swin2SRConfig
base_model_prefix = "swin2sr"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
torch.nn.init.trunc_normal_(module.weight.data, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
SWIN2SR_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`Swin2SRConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
SWIN2SR_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
[`Swin2SRImageProcessor.__call__`] for details.
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare Swin2SR Model transformer outputting raw hidden-states without any specific head on top.",
SWIN2SR_START_DOCSTRING,
)
class Swin2SRModel(Swin2SRPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config
if config.num_channels == 3 and config.num_channels_out == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
else:
self.mean = torch.zeros(1, 1, 1, 1)
self.img_range = config.img_range
self.first_convolution = nn.Conv2d(config.num_channels, config.embed_dim, 3, 1, 1)
self.embeddings = Swin2SREmbeddings(config)
self.encoder = Swin2SREncoder(config, grid_size=self.embeddings.patch_embeddings.patches_resolution)
self.layernorm = nn.LayerNorm(config.embed_dim, eps=config.layer_norm_eps)
self.patch_unembed = Swin2SRPatchUnEmbeddings(config)
self.conv_after_body = nn.Conv2d(config.embed_dim, config.embed_dim, 3, 1, 1)
self.post_init()
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def pad_and_normalize(self, pixel_values):
_, _, height, width = pixel_values.size()
window_size = self.config.window_size
modulo_pad_height = (window_size - height % window_size) % window_size
modulo_pad_width = (window_size - width % window_size) % window_size
pixel_values = nn.functional.pad(pixel_values, (0, modulo_pad_width, 0, modulo_pad_height), "reflect")
self.mean = self.mean.type_as(pixel_values)
pixel_values = (pixel_values - self.mean) * self.img_range
return pixel_values
@add_start_docstrings_to_model_forward(SWIN2SR_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
pixel_values: torch.FloatTensor,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
_, _, height, width = pixel_values.shape
pixel_values = self.pad_and_normalize(pixel_values)
embeddings = self.first_convolution(pixel_values)
embedding_output, input_dimensions = self.embeddings(embeddings)
encoder_outputs = self.encoder(
embedding_output,
input_dimensions,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
sequence_output = self.patch_unembed(sequence_output, (height, width))
sequence_output = self.conv_after_body(sequence_output) + embeddings
if not return_dict:
output = (sequence_output,) + encoder_outputs[1:]
return output
return BaseModelOutput(
last_hidden_state=sequence_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
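# Illustrative usage sketch (not part of the source file): running the bare Swin2SRModel on the test
# image already used by the conversion script above. Assumes `transformers`, `Pillow` and `requests`
# are installed.
import requests
from PIL import Image
from transformers import Swin2SRImageProcessor, Swin2SRModel

url = "https://github.com/mv-lab/swin2sr/blob/main/testsets/real-inputs/shanghai.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
processor = Swin2SRImageProcessor()
model = Swin2SRModel.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
inputs = processor(image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
# The last hidden state keeps embed_dim=180 channels at the padded input resolution
# (compare _EXPECTED_OUTPUT_SHAPE above); the exact height and width depend on the image.
print(outputs.last_hidden_state.shape)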
class PixelShuffleUpsampler(nn.Module):
"""PixelShuffleUpsampler module.
This module performs upsampling using PixelShuffle.
Args:
config (`object`):
Configuration object containing parameters.
num_features (`int`):
Number of intermediate features.
Attributes:
conv_before_upsample (`nn.Conv2d`):
Convolutional layer before upsampling.
activation (`nn.LeakyReLU`):
LeakyReLU activation function.
upsample (`Upsample`):
Upsample module.
final_convolution (`nn.Conv2d`):
Final convolutional layer.
"""
def __init__(self, config, num_features):
super().__init__()
self.conv_before_upsample = nn.Conv2d(config.embed_dim, num_features, 3, 1, 1)
self.activation = nn.LeakyReLU(inplace=True)
self.upsample = Upsample(config.upscale, num_features)
self.final_convolution = nn.Conv2d(num_features, config.num_channels_out, 3, 1, 1)
def forward(self, sequence_output):
x = self.conv_before_upsample(sequence_output)
x = self.activation(x)
x = self.upsample(x)
x = self.final_convolution(x)
return x
class NearestConvUpsampler(nn.Module):
"""NearestConvUpsampler module.
This module performs upsampling using nearest-neighbor interpolation followed by convolution.
Args:
scale (`int`):
Scale factor for upsampling.
in_channels (`int`):
Number of input channels.
out_channels (`int`):
Number of output channels.
Attributes:
upsample (`nn.Upsample`):
Upsampling layer.
conv (`nn.Conv2d`):
Convolutional layer.
"""
def __init__(self, config, num_features):
super().__init__()
if config.upscale != 4:
raise ValueError("The nearest+conv upsampler only supports an upscale factor of 4 at the moment.")
self.conv_before_upsample = nn.Conv2d(config.embed_dim, num_features, 3, 1, 1)
self.activation = nn.LeakyReLU(inplace=True)
self.conv_up1 = nn.Conv2d(num_features, num_features, 3, 1, 1)
self.conv_up2 = nn.Conv2d(num_features, num_features, 3, 1, 1)
self.conv_hr = nn.Conv2d(num_features, num_features, 3, 1, 1)
self.final_convolution = nn.Conv2d(num_features, config.num_channels_out, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, sequence_output):
sequence_output = self.conv_before_upsample(sequence_output)
sequence_output = self.activation(sequence_output)
sequence_output = self.lrelu(
self.conv_up1(torch.nn.functional.interpolate(sequence_output, scale_factor=2, mode="nearest"))
)
sequence_output = self.lrelu(
self.conv_up2(torch.nn.functional.interpolate(sequence_output, scale_factor=2, mode="nearest"))
)
reconstruction = self.final_convolution(self.lrelu(self.conv_hr(sequence_output)))
return reconstruction
class PixelShuffleAuxUpsampler(nn.Module):
def __init__(self, config, num_features):
super().__init__()
self.upscale = config.upscale
self.conv_bicubic = nn.Conv2d(config.num_channels, num_features, 3, 1, 1)
self.conv_before_upsample = nn.Conv2d(config.embed_dim, num_features, 3, 1, 1)
self.activation = nn.LeakyReLU(inplace=True)
self.conv_aux = nn.Conv2d(num_features, config.num_channels, 3, 1, 1)
self.conv_after_aux = nn.Sequential(nn.Conv2d(3, num_features, 3, 1, 1), nn.LeakyReLU(inplace=True))
self.upsample = Upsample(config.upscale, num_features)
self.final_convolution = nn.Conv2d(num_features, config.num_channels_out, 3, 1, 1)
def forward(self, sequence_output, bicubic, height, width):
bicubic = self.conv_bicubic(bicubic)
sequence_output = self.conv_before_upsample(sequence_output)
sequence_output = self.activation(sequence_output)
aux = self.conv_aux(sequence_output)
sequence_output = self.conv_after_aux(aux)
sequence_output = (
self.upsample(sequence_output)[:, :, : height * self.upscale, : width * self.upscale]
+ bicubic[:, :, : height * self.upscale, : width * self.upscale]
)
reconstruction = self.final_convolution(sequence_output)
return reconstruction, aux
@add_start_docstrings(
"""
Swin2SR Model transformer with an upsampler head on top for image super resolution and restoration.
""",
SWIN2SR_START_DOCSTRING,
)
class Swin2SRForImageSuperResolution(Swin2SRPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.swin2sr = Swin2SRModel(config)
self.upsampler = config.upsampler
self.upscale = config.upscale
num_features = 64
if self.upsampler == "pixelshuffle":
self.upsample = PixelShuffleUpsampler(config, num_features)
elif self.upsampler == "pixelshuffle_aux":
self.upsample = PixelShuffleAuxUpsampler(config, num_features)
elif self.upsampler == "pixelshuffledirect":
self.upsample = UpsampleOneStep(config.upscale, config.embed_dim, config.num_channels_out)
elif self.upsampler == "nearest+conv":
self.upsample = NearestConvUpsampler(config, num_features)
else:
self.final_convolution = nn.Conv2d(config.embed_dim, config.num_channels_out, 3, 1, 1)
self.post_init()
@add_start_docstrings_to_model_forward(SWIN2SR_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=ImageSuperResolutionOutput, config_class=_CONFIG_FOR_DOC)
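The forward body is elided above; as an end-to-end sketch, the head can be exercised through the public API roughly as follows (the checkpoint name is an assumption; any Swin2SR super-resolution checkpoint should behave the same):

```
import torch
import requests
from PIL import Image
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution

# Assumed checkpoint for illustration.
checkpoint = "caidas/swin2SR-classical-sr-x2-64"
processor = AutoImageProcessor.from_pretrained(checkpoint)
model = Swin2SRForImageSuperResolution.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
# `reconstruction` holds the upscaled image of shape (batch, channels, height * upscale, width * upscale).
print(outputs.reconstruction.shape)
```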
.\models\swin2sr\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {
"configuration_swin2sr": ["SWIN2SR_PRETRAINED_CONFIG_ARCHIVE_MAP", "Swin2SRConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_swin2sr"] = [
"SWIN2SR_PRETRAINED_MODEL_ARCHIVE_LIST",
"Swin2SRForImageSuperResolution",
"Swin2SRModel",
"Swin2SRPreTrainedModel",
]
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_swin2sr"] = ["Swin2SRImageProcessor"]
if TYPE_CHECKING:
from .configuration_swin2sr import SWIN2SR_PRETRAINED_CONFIG_ARCHIVE_MAP, Swin2SRConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_swin2sr import (
SWIN2SR_PRETRAINED_MODEL_ARCHIVE_LIST,
Swin2SRForImageSuperResolution,
Swin2SRModel,
Swin2SRPreTrainedModel,
)
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_swin2sr import Swin2SRImageProcessor
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\swinv2\configuration_swinv2.py
""" Swinv2 Transformer model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
logger = logging.get_logger(__name__)
SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"microsoft/swinv2-tiny-patch4-window8-256": (
"https://huggingface.co/microsoft/swinv2-tiny-patch4-window8-256/resolve/main/config.json"
),
}
class Swinv2Config(BackboneConfigMixin, PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Swinv2Model`]. It is used to instantiate a Swin
Transformer v2 model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the Swin Transformer v2
[microsoft/swinv2-tiny-patch4-window8-256](https://huggingface.co/microsoft/swinv2-tiny-patch4-window8-256)
architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```
>>> from transformers import Swinv2Config, Swinv2Model
>>> # Initializing a Swinv2 microsoft/swinv2-tiny-patch4-window8-256 style configuration
>>> configuration = Swinv2Config()
>>> # Initializing a model (with random weights) from the microsoft/swinv2-tiny-patch4-window8-256 style configuration
>>> model = Swinv2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "swinv2"
attribute_map = {
"num_attention_heads": "num_heads",
"num_hidden_layers": "num_layers",
}
def __init__(
self,
image_size=224,
patch_size=4,
num_channels=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
pretrained_window_sizes=[0, 0, 0, 0],
mlp_ratio=4.0,
qkv_bias=True,
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
drop_path_rate=0.1,
hidden_act="gelu",
use_absolute_embeddings=False,
initializer_range=0.02,
layer_norm_eps=1e-5,
encoder_stride=32,
out_features=None,
out_indices=None,
**kwargs,
):
super().__init__(**kwargs)
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.embed_dim = embed_dim
self.depths = depths
self.num_layers = len(depths)
self.num_heads = num_heads
self.window_size = window_size
self.pretrained_window_sizes = pretrained_window_sizes
self.mlp_ratio = mlp_ratio
self.qkv_bias = qkv_bias
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.drop_path_rate = drop_path_rate
self.hidden_act = hidden_act
self.use_absolute_embeddings = use_absolute_embeddings
self.layer_norm_eps = layer_norm_eps
self.initializer_range = initializer_range
self.encoder_stride = encoder_stride
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
)
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
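The last line derives `hidden_size` from `embed_dim` and the number of stages: channels double at each of the `len(depths) - 1` patch-merging steps. A quick sketch confirming the arithmetic with the default values:

```
from transformers import Swinv2Config

config = Swinv2Config(embed_dim=96, depths=[2, 2, 6, 2])
# Channels double at each of the len(depths) - 1 patch-merging steps: 96 -> 192 -> 384 -> 768.
assert config.hidden_size == 96 * 2 ** (len(config.depths) - 1) == 768
assert config.num_layers == len(config.depths) == 4
```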
.\models\swinv2\convert_swinv2_timm_to_pytorch.py
"""Convert Swinv2 checkpoints from the timm library."""
import argparse
import json
from pathlib import Path
import requests
import timm
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import AutoImageProcessor, Swinv2Config, Swinv2ForImageClassification
def get_swinv2_config(swinv2_name):
config = Swinv2Config()
name_split = swinv2_name.split("_")
model_size = name_split[1]
if "to" in name_split[3]:
img_size = int(name_split[3][-3:])
else:
img_size = int(name_split[3])
if "to" in name_split[2]:
window_size = int(name_split[2][-2:])
else:
window_size = int(name_split[2][6:])
if model_size == "tiny":
embed_dim = 96
depths = (2, 2, 6, 2)
num_heads = (3, 6, 12, 24)
elif model_size == "small":
embed_dim = 96
depths = (2, 2, 18, 2)
num_heads = (3, 6, 12, 24)
elif model_size == "base":
embed_dim = 128
depths = (2, 2, 18, 2)
num_heads = (4, 8, 16, 32)
else:
embed_dim = 192
depths = (2, 2, 18, 2)
num_heads = (6, 12, 24, 48)
if "to" in swinv2_name:
config.pretrained_window_sizes = (12, 12, 12, 6)
if ("22k" in swinv2_name) and ("to" not in swinv2_name):
num_classes = 21841
repo_id = "huggingface/label-files"
filename = "imagenet-22k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
else:
num_classes = 1000
repo_id = "huggingface/label-files"
filename = "imagenet-1k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
config.image_size = img_size
config.num_labels = num_classes
config.embed_dim = embed_dim
config.depths = depths
config.num_heads = num_heads
config.window_size = window_size
return config
def rename_key(name):
if "patch_embed.proj" in name:
name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
if "patch_embed.norm" in name:
name = name.replace("patch_embed.norm", "embeddings.norm")
if "layers" in name:
name = "encoder." + name
if "attn.proj" in name:
name = name.replace("attn.proj", "attention.output.dense")
if "attn" in name:
name = name.replace("attn", "attention.self")
if "norm1" in name:
name = name.replace("norm1", "layernorm_before")
if "norm2" in name:
name = name.replace("norm2", "layernorm_after")
if "mlp.fc1" in name:
name = name.replace("mlp.fc1", "intermediate.dense")
if "mlp.fc2" in name:
name = name.replace("mlp.fc2", "output.dense")
if "q_bias" in name:
name = name.replace("q_bias", "query.bias")
if "k_bias" in name:
name = name.replace("k_bias", "key.bias")
if "v_bias" in name:
name = name.replace("v_bias", "value.bias")
if "cpb_mlp" in name:
name = name.replace("cpb_mlp", "continuous_position_bias_mlp")
if name == "norm.weight":
name = "layernorm.weight"
if name == "norm.bias":
name = "layernorm.bias"
if "head" in name:
name = name.replace("head", "classifier")
else:
name = "swinv2." + name
return name
def convert_state_dict(orig_state_dict, model):
for key in orig_state_dict.copy().keys():
val = orig_state_dict.pop(key)
if "mask" in key:
continue
elif "qkv" in key:
key_split = key.split(".")
layer_num = int(key_split[1])
block_num = int(key_split[3])
dim = model.swinv2.encoder.layers[layer_num].blocks[block_num].attention.self.all_head_size
if "weight" in key:
orig_state_dict[
f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.weight"
] = val[:dim, :]
orig_state_dict[
f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.weight"
] = val[dim : dim * 2, :]
orig_state_dict[
f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.weight"
] = val[-dim:, :]
else:
orig_state_dict[
f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.bias"
] = val[:dim]
orig_state_dict[
f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.bias"
] = val[dim : dim * 2]
orig_state_dict[
f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.bias"
] = val[-dim:]
else:
orig_state_dict[rename_key(key)] = val
return orig_state_dict
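The `qkv` branch above splits timm's fused projection, whose rows stack query, key and value weights, into the three separate matrices the HF attention module expects. A minimal standalone sketch of that row slicing (the `dim` value here is a made-up toy size):

```
import torch

dim = 4  # hypothetical head size for illustration
qkv_weight = torch.arange(3 * dim * dim, dtype=torch.float32).reshape(3 * dim, dim)

query_weight = qkv_weight[:dim, :]
key_weight = qkv_weight[dim : dim * 2, :]
value_weight = qkv_weight[-dim:, :]

# The three slices exactly tile the fused matrix.
assert torch.equal(torch.cat([query_weight, key_weight, value_weight], dim=0), qkv_weight)
```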
def convert_swinv2_checkpoint(swinv2_name, pytorch_dump_folder_path):
timm_model = timm.create_model(swinv2_name, pretrained=True)
timm_model.eval()
config = get_swinv2_config(swinv2_name)
model = Swinv2ForImageClassification(config)
model.eval()
new_state_dict = convert_state_dict(timm_model.state_dict(), model)
model.load_state_dict(new_state_dict)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image_processor = AutoImageProcessor.from_pretrained("microsoft/{}".format(swinv2_name.replace("_", "-")))
image = Image.open(requests.get(url, stream=True).raw)
inputs = image_processor(images=image, return_tensors="pt")
timm_outs = timm_model(inputs["pixel_values"])
hf_outs = model(**inputs).logits
assert torch.allclose(timm_outs, hf_outs, atol=1e-3)
print(f"Saving model {swinv2_name} to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"Saving image processor to {pytorch_dump_folder_path}")
image_processor.save_pretrained(pytorch_dump_folder_path)
model.push_to_hub(
repo_path_or_name=Path(pytorch_dump_folder_path, swinv2_name),
organization="nandwalritik",
commit_message="Add model",
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--swinv2_name",
default="swinv2_tiny_patch4_window8_256",
type=str,
help="Name of the Swinv2 timm model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
help="Path to the output PyTorch model directory."
)
args = parser.parse_args()
convert_swinv2_checkpoint(args.swinv2_name, args.pytorch_dump_folder_path)
.\models\swinv2\modeling_swinv2.py
""" PyTorch Swinv2 Transformer model."""
import collections.abc
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin
from .configuration_swinv2 import Swinv2Config
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Swinv2Config"
_CHECKPOINT_FOR_DOC = "microsoft/swinv2-tiny-patch4-window8-256"
_EXPECTED_OUTPUT_SHAPE = [1, 64, 768]
_IMAGE_CLASS_CHECKPOINT = "microsoft/swinv2-tiny-patch4-window8-256"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
"microsoft/swinv2-tiny-patch4-window8-256",
]
@dataclass
class Swinv2EncoderOutput(ModelOutput):
"""
Swinv2 编码器的输出,可能包含隐藏状态和注意力权重。
# 最后一层模型的隐藏状态,形状为(batch_size, sequence_length, hidden_size)
last_hidden_state: torch.FloatTensor = None
# 模型每一层的隐藏状态的元组,形状为(batch_size, sequence_length, hidden_size),可选项,当`output_hidden_states=True`时返回
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
# 注意力权重的元组,形状为(batch_size, num_heads, sequence_length, sequence_length),可选项,当`output_attentions=True`时返回
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# 模型每一层的隐藏状态的元组,形状为(batch_size, hidden_size, height, width),包括空间维度,可选项,当`output_hidden_states=True`且输出被重塑时返回
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
# Swinv2ModelOutput is defined with the @dataclass decorator and inherits from ModelOutput,
# the base output class provided by the transformers library.
@dataclass
# Copied from transformers.models.swin.modeling_swin.SwinModelOutput with Swin->Swinv2
class Swinv2ModelOutput(ModelOutput):
"""
Swinv2 model's outputs that also contain a pooling of the last hidden states.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
Average pooling of the last layer hidden-state.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of shape `(batch_size, hidden_size, height, width)`, reshaped to include the spatial dimensions.
"""
@dataclass
# Copied from transformers.models.swin.modeling_swin.SwinMaskedImageModelingOutput with Swin->Swinv2
class Swinv2MaskedImageModelingOutput(ModelOutput):
"""
Swinv2 masked image model outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
Masked image modeling (MLM) loss.
reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Reconstructed pixel values.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, hidden_size, height, width)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
include the spatial dimensions.
"""
loss: Optional[torch.FloatTensor] = None
reconstruction: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
@property
def logits(self):
warnings.warn(
"logits attribute is deprecated and will be removed in version 5 of Transformers."
" Please use the reconstruction attribute to retrieve the final output instead.",
FutureWarning,
)
return self.reconstruction
@dataclass
class Swinv2ImageClassifierOutput(ModelOutput):
"""
Swinv2 outputs for image classification.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if `config.num_labels==1`) loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if `config.num_labels==1`) scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, hidden_size, height, width)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs, reshaped to include the spatial dimensions.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
def window_partition(input_feature, window_size):
"""
Partitions the given input into windows.
"""
batch_size, height, width, num_channels = input_feature.shape
input_feature = input_feature.view(
batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
)
windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
return windows
def window_reverse(windows, window_size, height, width):
"""
Merges windows to produce higher resolution features.
"""
num_channels = windows.shape[-1]
windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous()
windows = windows.view(-1, height, width, num_channels)
return windows
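`window_partition` and `window_reverse` are exact inverses whenever the height and width are multiples of the window size, which the padding logic later guarantees. A quick round-trip check, assuming the two functions above are in scope:

```
import torch

batch_size, height, width, channels, window_size = 2, 8, 8, 3, 4
feature = torch.randn(batch_size, height, width, channels)

windows = window_partition(feature, window_size)
# (batch_size * (height // ws) * (width // ws), ws, ws, channels) = (8, 4, 4, 3)
print(windows.shape)

restored = window_reverse(windows, window_size, height, width)
assert torch.equal(restored, feature)
```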
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
argument.
"""
if drop_prob == 0.0 or not training:
return input
keep_prob = 1 - drop_prob
shape = (input.shape[0],) + (1,) * (input.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
random_tensor.floor_()
output = input.div(keep_prob) * random_tensor
return output
class Swinv2DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: Optional[float] = None) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return drop_path(hidden_states, self.drop_prob, self.training)
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
class Swinv2Embeddings(nn.Module):
"""
Construct the patch and position embeddings. Optionally, also the mask token.
"""
def __init__(self, config, use_mask_token=False):
super().__init__()
self.patch_embeddings = Swinv2PatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.patch_grid = self.patch_embeddings.grid_size
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
if config.use_absolute_embeddings:
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
else:
self.position_embeddings = None
self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(
self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
) -> Tuple[torch.Tensor]:
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
embeddings = self.norm(embeddings)
batch_size, seq_len, _ = embeddings.size()
if bool_masked_pos is not None:
mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
if self.position_embeddings is not None:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings, output_dimensions
class Swinv2PatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.embed_dim
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def maybe_pad(self, pixel_values, height, width):
if width % self.patch_size[1] != 0:
pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
pixel_values = nn.functional.pad(pixel_values, pad_values)
if height % self.patch_size[0] != 0:
pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
_, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
pixel_values = self.maybe_pad(pixel_values, height, width)
embeddings = self.projection(pixel_values)
_, _, height, width = embeddings.shape
embeddings = embeddings.flatten(2).transpose(1, 2)
return embeddings, (height, width)
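For the default 256x256 input with 4x4 patches, the projection convolution yields a 64x64 grid of 96-dimensional patch embeddings, i.e. a sequence of 4096 tokens. A shape sketch, assuming `Swinv2PatchEmbeddings` above is in scope:

```
import torch
from transformers import Swinv2Config

config = Swinv2Config(image_size=256, patch_size=4, num_channels=3, embed_dim=96)
patch_embed = Swinv2PatchEmbeddings(config)

pixel_values = torch.randn(1, 3, 256, 256)
embeddings, (height, width) = patch_embed(pixel_values)
print(embeddings.shape, (height, width))  # torch.Size([1, 4096, 96]) (64, 64)
```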
class Swinv2PatchMerging(nn.Module):
"""
Patch Merging Layer.
Args:
input_resolution (`Tuple[int]`):
Resolution of input feature.
dim (`int`):
Number of input channels.
norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
Normalization layer class.
"""
def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(2 * dim)
def maybe_pad(self, input_feature, height, width):
should_pad = (height % 2 == 1) or (width % 2 == 1)
if should_pad:
pad_values = (0, 0, 0, width % 2, 0, height % 2)
input_feature = nn.functional.pad(input_feature, pad_values)
return input_feature
def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
height, width = input_dimensions
batch_size, dim, num_channels = input_feature.shape
input_feature = input_feature.view(batch_size, height, width, num_channels)
input_feature = self.maybe_pad(input_feature, height, width)
input_feature_0 = input_feature[:, 0::2, 0::2, :]
input_feature_1 = input_feature[:, 1::2, 0::2, :]
input_feature_2 = input_feature[:, 0::2, 1::2, :]
input_feature_3 = input_feature[:, 1::2, 1::2, :]
input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
input_feature = input_feature.view(batch_size, -1, 4 * num_channels)
input_feature = self.reduction(input_feature)
input_feature = self.norm(input_feature)
return input_feature
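Patch merging concatenates every 2x2 neighborhood into a `4 * dim` vector and projects it down to `2 * dim`, so each stage halves the spatial grid while doubling the channel width. A shape sketch, assuming `Swinv2PatchMerging` above is in scope:

```
import torch

dim, height, width = 96, 64, 64
merging = Swinv2PatchMerging(input_resolution=(height, width), dim=dim)

tokens = torch.randn(1, height * width, dim)  # (1, 4096, 96)
merged = merging(tokens, (height, width))
print(merged.shape)  # torch.Size([1, 1024, 192])
```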
class Swinv2SelfAttention(nn.Module):
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
attention_scores = nn.functional.normalize(query_layer, dim=-1) @ nn.functional.normalize(
key_layer, dim=-1
).transpose(-2, -1)
logit_scale = torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp()
attention_scores = attention_scores * logit_scale
relative_position_bias_table = self.continuous_position_bias_mlp(self.relative_coords_table).view(
-1, self.num_attention_heads
)
relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
if attention_mask is not None:
mask_shape = attention_mask.shape[0]
attention_scores = attention_scores.view(
batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
) + attention_mask.unsqueeze(1).unsqueeze(0)
attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
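Note the Swin v2 change visible above: attention logits are cosine similarities between queries and keys, scaled by a learned per-head logit scale that is clamped to at most `log(1/0.01)` before exponentiation, instead of dot products divided by `sqrt(head_dim)`. A standalone sketch of this scaled cosine attention with made-up toy shapes:

```
import math
import torch
from torch import nn

num_heads, seq_len, head_dim = 2, 4, 8
query = torch.randn(1, num_heads, seq_len, head_dim)
key = torch.randn(1, num_heads, seq_len, head_dim)

# Cosine similarity between queries and keys ...
scores = nn.functional.normalize(query, dim=-1) @ nn.functional.normalize(key, dim=-1).transpose(-2, -1)
# ... scaled by a learned per-head temperature, clamped so it never exceeds 1 / 0.01 = 100.
logit_scale_param = torch.log(10 * torch.ones(num_heads, 1, 1))  # a learned parameter in the real module
logit_scale = torch.clamp(logit_scale_param, max=math.log(1.0 / 0.01)).exp()
scores = scores * logit_scale
probs = nn.functional.softmax(scores, dim=-1)
print(probs.shape)  # torch.Size([1, 2, 4, 4])
```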
class Swinv2SelfOutput(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(dim, dim)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class Swinv2Attention(nn.Module):
def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=0):
super().__init__()
self.self = Swinv2SelfAttention(
config=config,
dim=dim,
num_heads=num_heads,
window_size=window_size,
pretrained_window_size=pretrained_window_size
if isinstance(pretrained_window_size, collections.abc.Iterable)
else (pretrained_window_size, pretrained_window_size),
)
self.output = Swinv2SelfOutput(config, dim)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
class Swinv2Intermediate(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class Swinv2Output(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class Swinv2Layer(nn.Module):
def __init__(self, config, dim, input_resolution, num_heads, shift_size=0, pretrained_window_size=0):
super().__init__()
self.input_resolution = input_resolution
window_size, shift_size = self._compute_window_shift(
(config.window_size, config.window_size), (shift_size, shift_size)
)
self.window_size = window_size[0]
self.shift_size = shift_size[0]
self.attention = Swinv2Attention(
config=config,
dim=dim,
num_heads=num_heads,
window_size=self.window_size,
pretrained_window_size=pretrained_window_size
if isinstance(pretrained_window_size, collections.abc.Iterable)
else (pretrained_window_size, pretrained_window_size),
)
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.drop_path = Swinv2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
self.intermediate = Swinv2Intermediate(config, dim)
self.output = Swinv2Output(config, dim)
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
def _compute_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
return window_size, shift_size
def get_attn_mask(self, height, width, dtype):
if self.shift_size > 0:
img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
height_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
width_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
count = 0
for height_slice in height_slices:
for width_slice in width_slices:
img_mask[:, height_slice, width_slice, :] = count
count += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
return attn_mask
def maybe_pad(self, hidden_states, height, width):
pad_right = (self.window_size - width % self.window_size) % self.window_size
pad_bottom = (self.window_size - height % self.window_size) % self.window_size
pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
hidden_states = nn.functional.pad(hidden_states, pad_values)
return hidden_states, pad_values
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
height, width = input_dimensions
batch_size, _, channels = hidden_states.size()
shortcut = hidden_states
hidden_states = hidden_states.view(batch_size, height, width, channels)
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
_, height_pad, width_pad, _ = hidden_states.shape
if self.shift_size > 0:
shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_hidden_states = hidden_states
hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
if attn_mask is not None:
attn_mask = attn_mask.to(hidden_states_windows.device)
attention_outputs = self.attention(
hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
)
attention_output = attention_outputs[0]
attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)
if self.shift_size > 0:
attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
attention_windows = shifted_windows
was_padded = pad_values[3] > 0 or pad_values[5] > 0
if was_padded:
attention_windows = attention_windows[:, :height, :width, :].contiguous()
attention_windows = attention_windows.view(batch_size, height * width, channels)
hidden_states = self.layernorm_before(attention_windows)
hidden_states = shortcut + self.drop_path(hidden_states)
layer_output = self.intermediate(hidden_states)
layer_output = self.output(layer_output)
layer_output = hidden_states + self.drop_path(self.layernorm_after(layer_output))
layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
return layer_outputs
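Also note the residual layout: the layernorms are applied to the sub-layer outputs before the residual additions ("res-post-norm"), whereas Swin v1 normalizes the sub-layer inputs. A toy block illustrating the difference (all module names here are illustrative, not the real layer):

```
import torch
from torch import nn

class ResPostNormBlock(nn.Module):
    """Toy residual block using the Swin v2 res-post-norm layout: normalize the
    sub-layer output, then add the residual."""

    def __init__(self, dim):
        super().__init__()
        self.mixer = nn.Linear(dim, dim)  # stand-in for windowed attention
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        x = x + self.norm1(self.mixer(x))  # Swin v1 would instead do x + self.mixer(self.norm1(x))
        x = x + self.norm2(self.mlp(x))
        return x

print(ResPostNormBlock(96)(torch.randn(1, 49, 96)).shape)  # torch.Size([1, 49, 96])
```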
class Swinv2Stage(nn.Module):
def __init__(
self, config, dim, input_resolution, depth, num_heads, drop_path, downsample, pretrained_window_size=0
):
super().__init__()
self.config = config
self.dim = dim
blocks = []
for i in range(depth):
block = Swinv2Layer(
config=config,
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
shift_size=0 if (i % 2 == 0) else config.window_size // 2,
pretrained_window_size=pretrained_window_size,
)
blocks.append(block)
self.blocks = nn.ModuleList(blocks)
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
else:
self.downsample = None
self.pointing = False
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
height, width = input_dimensions
for i, layer_module in enumerate(self.blocks):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module(
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
)
hidden_states = layer_outputs[0]
hidden_states_before_downsampling = hidden_states
if self.downsample is not None:
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
output_dimensions = (height, width, height_downsampled, width_downsampled)
hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
else:
output_dimensions = (height, width, height, width)
stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)
if output_attentions:
stage_outputs += layer_outputs[1:]
return stage_outputs
class Swinv2Encoder(nn.Module):
def __init__(self, config, grid_size, pretrained_window_sizes=(0, 0, 0, 0)):
super().__init__()
self.num_layers = len(config.depths)
self.config = config
if self.config.pretrained_window_sizes is not None:
pretrained_window_sizes = config.pretrained_window_sizes
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
layers = []
for i_layer in range(self.num_layers):
stage = Swinv2Stage(
config=config,
dim=int(config.embed_dim * 2**i_layer),
input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
depth=config.depths[i_layer],
num_heads=config.num_heads[i_layer],
drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
downsample=Swinv2PatchMerging if (i_layer < self.num_layers - 1) else None,
pretrained_window_size=pretrained_window_sizes[i_layer],
)
layers.append(stage)
self.layers = nn.ModuleList(layers)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
output_hidden_states_before_downsampling: Optional[bool] = False,
return_dict: Optional[bool] = True,
class Swinv2PreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = Swinv2Config
base_model_prefix = "swinv2"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
SWINV2_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`Swinv2Config`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
SWINV2_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
for details.
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare Swinv2 Model transformer outputting raw hidden-states without any specific head on top.",
SWINV2_START_DOCSTRING,
)
class Swinv2Model(Swinv2PreTrainedModel):
def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
super().__init__(config)
self.config = config
self.num_layers = len(config.depths)
self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
self.embeddings = Swinv2Embeddings(config, use_mask_token=use_mask_token)
self.encoder = Swinv2Encoder(config, self.embeddings.patch_grid)
self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
self.post_init()
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the base class PreTrainedModel.
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Swinv2ModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Swinv2ModelOutput]:
r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
encoder_outputs = self.encoder(
embedding_output,
input_dimensions,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
pooled_output = None
if self.pooler is not None:
pooled_output = self.pooler(sequence_output.transpose(1, 2))
pooled_output = torch.flatten(pooled_output, 1)
if not return_dict:
output = (sequence_output, pooled_output) + encoder_outputs[1:]
return output
return Swinv2ModelOutput(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
)
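A usage sketch of the bare model with the checkpoint named in `_CHECKPOINT_FOR_DOC`; for a 256x256 input, the output shape matches `_EXPECTED_OUTPUT_SHAPE`:

```
import torch
from transformers import Swinv2Model

checkpoint = "microsoft/swinv2-tiny-patch4-window8-256"
model = Swinv2Model.from_pretrained(checkpoint)

# Dummy batch of one 256x256 RGB image; in practice, produce pixel_values with AutoImageProcessor.
pixel_values = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    outputs = model(pixel_values)

# 256 / 4 = 64 patches per side; three patch-merging stages leave an 8x8 grid of 768-dim tokens.
print(outputs.last_hidden_state.shape)  # torch.Size([1, 64, 768])
print(outputs.pooler_output.shape)      # torch.Size([1, 768])
```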
@add_start_docstrings(
"""
Swinv2 Model with a decoder on top for masked image modeling, as proposed in
[SimMIM](https://arxiv.org/abs/2111.09886).
<Tip>
Note that we provide a script to pre-train this model on custom data in our [examples
directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
</Tip>
""",
SWINV2_START_DOCSTRING,
)
class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.swinv2 = Swinv2Model(config, add_pooling_layer=False, use_mask_token=True)
num_features = int(config.embed_dim * 2 ** (config.num_layers - 1))
self.decoder = nn.Sequential(
nn.Conv2d(
in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1
),
nn.PixelShuffle(config.encoder_stride),
)
self.post_init()
@add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Swinv2MaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@add_start_docstrings(
"""
Swinv2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
of the [CLS] token) e.g. for ImageNet.
""",
SWINV2_START_DOCSTRING,
)
class Swinv2ForImageClassification(Swinv2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.swinv2 = Swinv2Model(config)
self.classifier = (
nn.Linear(self.swinv2.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
)
self.post_init()
@add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=Swinv2ImageClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Swinv2ImageClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.swinv2(
pixel_values,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return Swinv2ImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states,
)
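A classification usage sketch with the `_IMAGE_CLASS_CHECKPOINT` checkpoint; if `labels` were passed, the loss branch above would pick the problem type from `num_labels` and the label dtype:

```
import torch
import requests
from PIL import Image
from transformers import AutoImageProcessor, Swinv2ForImageClassification

checkpoint = "microsoft/swinv2-tiny-patch4-window8-256"
processor = AutoImageProcessor.from_pretrained(checkpoint)
model = Swinv2ForImageClassification.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(-1).item()])
```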
@add_start_docstrings(
"""
Swinv2 backbone, to be used with frameworks like DETR and MaskFormer.
""",
SWINV2_START_DOCSTRING,
)
class Swinv2Backbone(Swinv2PreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)
super()._init_backbone(config)
self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
self.embeddings = Swinv2Embeddings(config)
self.encoder = Swinv2Encoder(config, self.embeddings.patch_grid)
self.post_init()
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
@add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Tensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> BackboneOutput:
"""
Returns a `BackboneOutput` for the given arguments.
Args:
return_dict (`bool`, *optional*): whether to return a `ModelOutput`-style dict; defaults to the value set in the configuration.
output_hidden_states (`bool`, *optional*): whether to return hidden states; defaults to the value set in the configuration.
output_attentions (`bool`, *optional*): whether to return attention weights; defaults to the value set in the configuration.
Returns:
`BackboneOutput`: an object containing the feature maps, hidden states and attention weights.
Examples:
```
>>> from transformers import AutoImageProcessor, AutoBackbone
>>> import torch
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
>>> model = AutoBackbone.from_pretrained(
... "microsoft/swinv2-tiny-patch4-window8-256", out_features=["stage1", "stage2", "stage3", "stage4"]
... )
>>> inputs = processor(image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> feature_maps = outputs.feature_maps
>>> list(feature_maps[-1].shape)
[1, 2048, 7, 7]
```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
embedding_output, input_dimensions = self.embeddings(pixel_values)
outputs = self.encoder(
embedding_output,
input_dimensions,
head_mask=None,
output_attentions=output_attentions,
output_hidden_states=True,
output_hidden_states_before_downsampling=True,
return_dict=return_dict,
)
hidden_states = outputs.reshaped_hidden_states if return_dict else outputs[-1]
feature_maps = ()
for stage, hidden_state in zip(self.stage_names, hidden_states):
if stage in self.out_features:
feature_maps += (hidden_state,)
if not return_dict:
output = (feature_maps,)
if output_hidden_states:
output += (outputs[1],)
if output_attentions:
output += (outputs[2],)
return output
return BackboneOutput(
feature_maps=feature_maps,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)