Transformers Source Code Walkthrough (Part 120)
.\models\vit_msn\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {"configuration_vit_msn": ["VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMSNConfig"]}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_vit_msn"] = [
"VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST",
"ViTMSNModel",
"ViTMSNForImageClassification",
"ViTMSNPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_vit_msn import VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMSNConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_vit_msn import (
VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST,
ViTMSNForImageClassification,
ViTMSNModel,
ViTMSNPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
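Thanks to `_LazyModule`, importing `transformers` stays cheap: the torch-backed modeling symbols are only resolved on first attribute access. A minimal usage sketch of the pattern (assuming `torch` is installed):
from transformers import ViTMSNConfig  # lightweight: configs have no torch dependency
config = ViTMSNConfig()
# Touching a modeling symbol is what actually triggers the lazy import of
# modeling_vit_msn (and therefore torch):
from transformers import ViTMSNForImageClassification
model = ViTMSNForImageClassification(config)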
.\models\vivit\configuration_vivit.py
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
VIVIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google/vivit-b-16x2-kinetics400": (
"https://huggingface.co/google/vivit-b-16x2-kinetics400/resolve/main/config.json"
),
}
class VivitConfig(PretrainedConfig):
r"""
    This is the configuration class to store the configuration of a [`VivitModel`]. It is used to instantiate a ViViT
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the ViViT
    [google/vivit-b-16x2-kinetics400](https://huggingface.co/google/vivit-b-16x2-kinetics400) architecture.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    """
model_type = "vivit"
def __init__(
self,
image_size=224,
num_frames=32,
tubelet_size=[2, 16, 16],
num_channels=3,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu_fast",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
layer_norm_eps=1e-06,
qkv_bias=True,
**kwargs,
    ):
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.image_size = image_size
self.num_frames = num_frames
self.tubelet_size = tubelet_size
self.num_channels = num_channels
self.qkv_bias = qkv_bias
super().__init__(**kwargs)
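Mirroring the example given in the `Wav2Vec2Config` docstring later in this section, instantiating the default configuration reproduces the google/vivit-b-16x2-kinetics400 architecture:
from transformers import VivitConfig, VivitModel
# Initializing a ViViT google/vivit-b-16x2-kinetics400 style configuration
configuration = VivitConfig()
# Initializing a model (with random weights) from that configuration
model = VivitModel(configuration)
# Accessing the model configuration
configuration = model.config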
.\models\vivit\convert_vivit_flax_to_pytorch.py
"""Convert Flax ViViT checkpoints from the original repository to PyTorch. URL:
https://github.com/google-research/scenic/tree/main/scenic/projects/vivit
"""
import argparse
import json
import os.path
from collections import OrderedDict
import numpy as np
import requests
import torch
from flax.training.checkpoints import restore_checkpoint
from huggingface_hub import hf_hub_download
from transformers import VivitConfig, VivitForVideoClassification, VivitImageProcessor
from transformers.image_utils import PILImageResampling
def download_checkpoint(path):
url = "https://storage.googleapis.com/scenic-bucket/vivit/kinetics_400/vivit_base_16x2_unfactorized/checkpoint"
with open(path, "wb") as f:
with requests.get(url, stream=True) as req:
for chunk in req.iter_content(chunk_size=2048):
f.write(chunk)
def get_vivit_config() -> VivitConfig:
config = VivitConfig()
config.num_labels = 400
repo_id = "huggingface/label-files"
filename = "kinetics400-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
return config
def prepare_video():
file = hf_hub_download(
repo_id="hf-internal-testing/spaghetti-video", filename="eating_spaghetti_32_frames.npy", repo_type="dataset"
)
video = np.load(file)
return list(video)
def transform_attention(current: np.ndarray):
if np.ndim(current) == 2:
return transform_attention_bias(current)
elif np.ndim(current) == 3:
return transform_attention_kernel(current)
else:
        raise Exception(f"Invalid number of dimensions: {np.ndim(current)}")
def transform_attention_bias(current: np.ndarray):
return current.flatten()
def transform_attention_kernel(current: np.ndarray):
return np.reshape(current, (current.shape[0], current.shape[1] * current.shape[2])).T
def transform_attention_output_weight(current: np.ndarray):
return np.reshape(current, (current.shape[0] * current.shape[1], current.shape[2])).T
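The Flax checkpoint stores the Q/K/V kernels per attention head, e.g. `(hidden_size, num_heads, head_dim)`, and the output projection as `(num_heads, head_dim, hidden_size)`, while a PyTorch `nn.Linear` weight is 2D `(out_features, in_features)`. A quick shape check of these helpers, using ViViT-Base dimensions (768 hidden, 12 heads of 64) as an assumed example:
q_kernel = np.zeros((768, 12, 64))    # Flax query kernel: (hidden, heads, head_dim)
assert transform_attention_kernel(q_kernel).shape == (768, 768)
out_kernel = np.zeros((12, 64, 768))  # Flax output kernel: (heads, head_dim, hidden)
assert transform_attention_output_weight(out_kernel).shape == (768, 768)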
def transform_state_encoder_block(state_dict, i):
state = state_dict["optimizer"]["target"]["Transformer"][f"encoderblock_{i}"]
prefix = f"encoder.layer.{i}."
new_state = {
prefix + "intermediate.dense.bias": state["MlpBlock_0"]["Dense_0"]["bias"],
prefix + "intermediate.dense.weight": np.transpose(state["MlpBlock_0"]["Dense_0"]["kernel"]),
prefix + "output.dense.bias": state["MlpBlock_0"]["Dense_1"]["bias"],
prefix + "output.dense.weight": np.transpose(state["MlpBlock_0"]["Dense_1"]["kernel"]),
prefix + "layernorm_before.bias": state["LayerNorm_0"]["bias"],
prefix + "layernorm_before.weight": state["LayerNorm_0"]["scale"],
prefix + "layernorm_after.bias": state["LayerNorm_1"]["bias"],
prefix + "layernorm_after.weight": state["LayerNorm_1"]["scale"],
prefix + "attention.attention.query.bias": transform_attention(
state["MultiHeadDotProductAttention_0"]["query"]["bias"]
),
prefix + "attention.attention.query.weight": transform_attention(
state["MultiHeadDotProductAttention_0"]["query"]["kernel"]
),
prefix + "attention.attention.key.bias": transform_attention(
state["MultiHeadDotProductAttention_0"]["key"]["bias"]
),
prefix + "attention.attention.key.weight": transform_attention(
state["MultiHeadDotProductAttention_0"]["key"]["kernel"]
),
prefix + "attention.attention.value.bias": transform_attention(
state["MultiHeadDotProductAttention_0"]["value"]["bias"]
),
prefix + "attention.attention.value.weight": transform_attention(
state["MultiHeadDotProductAttention_0"]["value"]["kernel"]
),
prefix + "attention.output.dense.bias": state["MultiHeadDotProductAttention_0"]["out"]["bias"],
prefix + "attention.output.dense.weight": transform_attention_output_weight(
state["MultiHeadDotProductAttention_0"]["out"]["kernel"]
),
}
return new_state
def get_n_layers(state_dict):
return sum([1 if "encoderblock_" in k else 0 for k in state_dict["optimizer"]["target"]["Transformer"].keys()])
def transform_state(state_dict, classification_head=False):
transformer_layers = get_n_layers(state_dict)
new_state = OrderedDict()
new_state["layernorm.bias"] = state_dict["optimizer"]["target"]["Transformer"]["encoder_norm"]["bias"]
new_state["layernorm.weight"] = state_dict["optimizer"]["target"]["Transformer"]["encoder_norm"]["scale"]
new_state["embeddings.patch_embeddings.projection.weight"] = np.transpose(
state_dict["optimizer"]["target"]["embedding"]["kernel"], (4, 3, 0, 1, 2)
)
new_state["embeddings.patch_embeddings.projection.bias"] = state_dict["optimizer"]["target"]["embedding"]["bias"]
new_state["embeddings.cls_token"] = state_dict["optimizer"]["target"]["cls"]
new_state["embeddings.position_embeddings"] = state_dict["optimizer"]["target"]["Transformer"]["posembed_input"][
"pos_embedding"
]
for i in range(transformer_layers):
new_state.update(transform_state_encoder_block(state_dict, i))
if classification_head:
new_state = {"vivit." + k: v for k, v in new_state.items()}
new_state["classifier.weight"] = np.transpose(state_dict["optimizer"]["target"]["output_projection"]["kernel"])
new_state["classifier.bias"] = np.transpose(state_dict["optimizer"]["target"]["output_projection"]["bias"])
return {k: torch.tensor(v) for k, v in new_state.items()}
def get_processor() -> VivitImageProcessor:
extractor = VivitImageProcessor()
assert extractor.do_resize is True
assert extractor.size == {"shortest_edge": 256}
assert extractor.do_center_crop is True
assert extractor.crop_size == {"width": 224, "height": 224}
assert extractor.resample == PILImageResampling.BILINEAR
assert extractor.do_normalize is False
assert extractor.do_rescale is True
    assert extractor.rescale_factor == 1 / 127.5
    assert extractor.offset is True
return extractor
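With the defaults asserted above, the processor resizes the short side to 256, center-crops to 224x224, and offset-rescales pixel values into [-1, 1] without mean/std normalization. A hedged usage sketch (random frames; shapes assumed from those defaults):
import numpy as np
processor = get_processor()
video = [np.random.randint(0, 256, (360, 640, 3), dtype=np.uint8) for _ in range(32)]
inputs = processor(video, return_tensors="pt")
print(inputs["pixel_values"].shape)  # torch.Size([1, 32, 3, 224, 224])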
def convert(output_path: str):
flax_model_path = "checkpoint"
if not os.path.exists(flax_model_path):
download_checkpoint(flax_model_path)
state_dict = restore_checkpoint(flax_model_path, None)
new_state = transform_state(state_dict, classification_head=True)
config = get_vivit_config()
assert config.image_size == 224
assert config.num_frames == 32
model = VivitForVideoClassification(config)
model.load_state_dict(new_state)
model.eval()
extractor = get_processor()
video = prepare_video()
inputs = extractor(video, return_tensors="pt")
outputs = model(**inputs)
expected_shape = torch.Size([1, 400])
expected_slice = torch.tensor([-1.0543, 2.0764, -0.2104, 0.4439, -0.9658])
assert outputs.logits.shape == expected_shape
assert torch.allclose(outputs.logits[0, :5], expected_slice, atol=1e-4), outputs.logits[0, :5]
model.save_pretrained(output_path)
extractor.save_pretrained(output_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--output_model_name", "-o", type=str, help="输出转换后的 HuggingFace 模型的路径")
args = parser.parse_args()
convert(args.output_model_name)
.\models\vivit\image_processing_vivit.py
"""Vivit 的图像处理类。"""
from typing import Dict, List, Optional, Union
import numpy as np
from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
get_resize_output_image_size,
rescale,
resize,
to_channel_dimension_format,
)
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
is_valid_image,
to_numpy_array,
valid_images,
validate_kwargs,
validate_preprocess_arguments,
)
from ...utils import logging
if is_vision_available():
import PIL
logger = logging.get_logger(__name__)
def make_batched(videos) -> List[List[ImageInput]]:
"""将视频列表批量化为 Vivit 需要的格式。
Args:
videos: 输入的视频数据,可以是单个视频或嵌套列表/元组的视频集合。
Returns:
List[List[ImageInput]]: 批量化后的视频列表,每个元素为一个视频帧列表。
Raises:
ValueError: 如果无法从给定的视频数据创建批量化视频。
"""
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
return videos
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
return [videos]
elif is_valid_image(videos):
return [[videos]]
raise ValueError(f"Could not make batched video from {videos}")
class VivitImageProcessor(BaseImageProcessor):
r"""
构建 Vivit 图像处理器。
继承自 BaseImageProcessor 类。
"""
def __init__(self):
"""初始化 Vivit 图像处理器。"""
super().__init__()
Args:
do_resize (`bool`, *optional*, defaults to `True`):
是否调整图像的高度和宽度尺寸到指定的 `size`。可以在 `preprocess` 方法中的 `do_resize` 参数中被覆盖。
size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 256}`):
调整后的输出图像尺寸。图像的最短边将调整为 `size["shortest_edge"]`,同时保持原始图像的纵横比。可以在 `preprocess` 方法中的 `size` 参数中被覆盖。
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
调整图像尺寸时使用的重采样滤波器。可以在 `preprocess` 方法中的 `resample` 参数中被覆盖。
do_center_crop (`bool`, *optional*, defaults to `True`):
是否对图像进行中心裁剪到指定的 `crop_size`。可以在 `preprocess` 方法中的 `do_center_crop` 参数中被覆盖。
crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
应用中心裁剪后的图像尺寸。可以在 `preprocess` 方法中的 `crop_size` 参数中被覆盖。
do_rescale (`bool`, *optional*, defaults to `True`):
是否按照指定的缩放因子 `rescale_factor` 进行图像缩放。可以在 `preprocess` 方法中的 `do_rescale` 参数中被覆盖。
rescale_factor (`int` or `float`, *optional*, defaults to `1/127.5`):
如果进行图像缩放,定义要使用的缩放因子。可以在 `preprocess` 方法中的 `rescale_factor` 参数中被覆盖。
offset (`bool`, *optional*, defaults to `True`):
是否在正负方向同时进行图像缩放。可以在 `preprocess` 方法中的 `offset` 参数中被覆盖。
do_normalize (`bool`, *optional*, defaults to `True`):
是否对图像进行归一化。可以在 `preprocess` 方法中的 `do_normalize` 参数中被覆盖。
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
如果归一化图像,定义要使用的均值。这是一个浮点数或长度等于图像通道数的浮点数列表。可以在 `preprocess` 方法中的 `image_mean` 参数中被覆盖。
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
如果归一化图像,定义要使用的标准差。这是一个浮点数或长度等于图像通道数的浮点数列表。可以在 `preprocess` 方法中的 `image_std` 参数中被覆盖。
"""
model_input_names = ["pixel_values"]
    # Initializer: stores all image preprocessing parameters on the instance.
    def __init__(
        self,
        do_resize: bool = True,  # whether to resize the image
        size: Dict[str, int] = None,  # target size dict, e.g. {"shortest_edge": 256}
        resample: PILImageResampling = PILImageResampling.BILINEAR,  # resampling filter used when resizing
        do_center_crop: bool = True,  # whether to center crop
        crop_size: Dict[str, int] = None,  # crop size dict with "height" and "width"
        do_rescale: bool = True,  # whether to rescale pixel values
        rescale_factor: Union[int, float] = 1 / 127.5,  # scale factor used for rescaling
        offset: bool = True,  # whether to also shift values into the negative range
        do_normalize: bool = True,  # whether to normalize pixel values
        image_mean: Optional[Union[float, List[float]]] = None,  # normalization mean
        image_std: Optional[Union[float, List[float]]] = None,  # normalization std
        **kwargs,  # extra keyword arguments forwarded to the base class
    ) -> None:
        # Call the parent initializer with any extra keyword arguments.
        super().__init__(**kwargs)
        # Default size: resize the shortest edge to 256 (non-square outputs allowed).
        size = size if size is not None else {"shortest_edge": 256}
        size = get_size_dict(size, default_to_square=False)
        # Default crop size: a 224x224 center crop.
        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
        crop_size = get_size_dict(crop_size, param_name="crop_size")
        # Store the parameters on the instance.
        self.do_resize = do_resize
        self.size = size
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.offset = offset
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
        # Keys this processor accepts in `preprocess`; used to validate caller kwargs.
self._valid_processor_keys = [
"videos",
"do_resize",
"size",
"resample",
"do_center_crop",
"crop_size",
"do_rescale",
"rescale_factor",
"offset",
"do_normalize",
"image_mean",
"image_std",
"return_tensors",
"data_format",
"input_data_format",
]
    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """Resize an image to the given `size`."""
        # Normalize `size` into an explicit dict; non-square outputs are allowed.
        size = get_size_dict(size, default_to_square=False)
        # If "shortest_edge" is given, scale so the shortest edge matches while keeping the aspect ratio.
        if "shortest_edge" in size:
            output_size = get_resize_output_image_size(
                image, size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
            )
        # If explicit height and width are given, use them directly.
        elif "height" in size and "width" in size:
            output_size = (size["height"], size["width"])
        else:
            raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
        # Delegate the actual resizing to the shared `resize` transform.
        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
    def rescale(
        self,
        image: np.ndarray,
        scale: Union[int, float],
        offset: bool = True,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
"""
Rescale an image by a scale factor.
If `offset` is `True`, the image has its values rescaled by `scale` and then offset by 1. If `scale` is
1/127.5, the image is rescaled between [-1, 1].
image = image * scale - 1
If `offset` is `False`, and `scale` is 1/255, the image is rescaled between [0, 1].
image = image * scale
Args:
image (`np.ndarray`):
Image to rescale.
scale (`int` or `float`):
Scale to apply to the image.
offset (`bool`, *optional*):
Whether to scale the image in both negative and positive directions.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
        # First rescale by `scale` using the shared transform.
        rescaled_image = rescale(
            image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs
        )
        # With offset, shift the result by -1 so e.g. scale=1/127.5 maps [0, 255] into [-1, 1].
        if offset:
            rescaled_image = rescaled_image - 1
        return rescaled_image
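    # Numeric check of the two modes documented above:
    #   offset=True,  scale=1/127.5: [0, 127.5, 255] -> [-1, 0, 1]
    #   offset=False, scale=1/255:   [0, 127.5, 255] -> [0, 0.5, 1]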
def _preprocess_image(
self,
image: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_rescale: bool = None,
rescale_factor: float = None,
offset: bool = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""Preprocesses a single image."""
        # Validate that the preprocessing arguments are consistent with one another.
        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )
        # Offsetting only makes sense as part of rescaling.
        if offset and not do_rescale:
            raise ValueError("For offset, do_rescale must also be set to True.")
        # All transformations below expect numpy arrays.
        image = to_numpy_array(image)
        if is_scaled_image(image) and do_rescale:
            # Warn if the input already looks rescaled (values in [0, 1]).
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )
        if input_data_format is None:
            # Infer the channel dimension format of the input.
            input_data_format = infer_channel_dimension_format(image)
        if do_resize:
            image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
        if do_center_crop:
            image = self.center_crop(image, size=crop_size, input_data_format=input_data_format)
        if do_rescale:
            image = self.rescale(image=image, scale=rescale_factor, offset=offset, input_data_format=input_data_format)
        if do_normalize:
            image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
        # Convert to the requested channel dimension format.
        image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
return image
def preprocess(
self,
videos: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_rescale: bool = None,
rescale_factor: float = None,
offset: bool = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
.\models\vivit\modeling_vivit.py
"""PyTorch ViViT model."""
import math
from typing import Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_vivit import VivitConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "VivitConfig"
class VivitTubeletEmbeddings(nn.Module):
"""
Construct Vivit Tubelet embeddings.
This module turns a batch of videos of shape (batch_size, num_frames, num_channels, height, width) into a tensor of
shape (batch_size, seq_len, hidden_size) to be consumed by a Transformer encoder.
The seq_len (the number of patches) equals (number of frames // tubelet_size[0]) * (height // tubelet_size[1]) *
(width // tubelet_size[2]).
"""
def __init__(self, config):
super().__init__()
self.num_frames = config.num_frames
self.image_size = config.image_size
self.patch_size = config.tubelet_size
self.num_patches = (
(self.image_size // self.patch_size[2])
* (self.image_size // self.patch_size[1])
* (self.num_frames // self.patch_size[0])
)
self.embed_dim = config.hidden_size
self.projection = nn.Conv3d(
config.num_channels,
config.hidden_size,
kernel_size=config.tubelet_size,
stride=config.tubelet_size
)
def forward(self, pixel_values):
batch_size, num_frames, num_channels, height, width = pixel_values.shape
if height != self.image_size or width != self.image_size:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
)
        # (batch, frames, channels, height, width) -> (batch, channels, frames, height, width) for Conv3d.
        pixel_values = pixel_values.permute(0, 2, 1, 3, 4)
        # Project to (batch, hidden, t, h, w), then flatten the spatio-temporal grid into a sequence.
        x = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return x
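For the default configuration (32 frames, 224x224 frames, tubelet_size [2, 16, 16]) the sequence length is (32 // 2) * (224 // 16) * (224 // 16) = 16 * 14 * 14 = 3136. A small shape check:
config = VivitConfig()
embeddings = VivitTubeletEmbeddings(config)
video = torch.zeros(1, 32, 3, 224, 224)  # (batch, frames, channels, height, width)
print(embeddings(video).shape)           # torch.Size([1, 3136, 768])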
class VivitEmbeddings(nn.Module):
"""
Vivit Embeddings.
Creates embeddings from a video using VivitTubeletEmbeddings, adds CLS token and positional embeddings.
"""
def __init__(self, config):
super().__init__()
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.patch_embeddings = VivitTubeletEmbeddings(config)
self.position_embeddings = nn.Parameter(
torch.zeros(1, self.patch_embeddings.num_patches + 1, config.hidden_size)
)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.config = config
def forward(self, pixel_values):
batch_size = pixel_values.shape[0]
embeddings = self.patch_embeddings(pixel_values)
cls_tokens = self.cls_token.repeat([batch_size, 1, 1])
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
class VivitSelfAttention(nn.Module):
def __init__(self, config: VivitConfig) -> None:
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
f"heads {config.num_attention_heads}."
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
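    # Shape walkthrough for ViViT-Base (768 hidden, 12 heads of 64), as an illustrative
    # example: an input of (1, 3137, 768) -- 3136 tubelets plus the CLS token -- comes
    # out of transpose_for_scores as (1, 12, 3137, 64), i.e. (batch, heads, seq, head_size).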
    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
class VivitSelfOutput(nn.Module):
"""
The residual connection is defined in VivitLayer instead of here (as is the case with other models), due to the
layernorm applied before each block.
"""
def __init__(self, config: VivitConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class VivitAttention(nn.Module):
def __init__(self, config: VivitConfig) -> None:
super().__init__()
self.attention = VivitSelfAttention(config)
self.output = VivitSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads: Set[int]) -> None:
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
)
self.attention.query = prune_linear_layer(self.attention.query, index)
self.attention.key = prune_linear_layer(self.attention.key, index)
self.attention.value = prune_linear_layer(self.attention.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
class VivitIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class VivitOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states + input_tensor
return hidden_states
class VivitLayer(nn.Module):
"""This corresponds to the EncoderBlock class in the scenic/vivit implementation."""
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = VivitAttention(config)
self.intermediate = VivitIntermediate(config)
self.output = VivitOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    def forward(self, hidden_states, head_mask=None, output_attentions=False):
        # In ViViT, layernorm is applied *before* self-attention (pre-norm).
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        # Keep the attention weights if output_attentions was requested.
        outputs = self_attention_outputs[1:]
        # First residual connection.
        hidden_states = attention_output + hidden_states
        # Layernorm is also applied before the MLP block.
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)
        # The second residual connection is applied inside VivitOutput.
        layer_output = self.output(layer_output, hidden_states)
        outputs = (layer_output,) + outputs
        return outputs
class VivitEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([VivitLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
"""
# VivitPooler 类,用于汇集模型隐藏状态的第一个令牌的隐藏状态
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = VivitConfig
base_model_prefix = "vivit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""
# 初始化模型权重
Initialize the weights
"""
if isinstance(module, (nn.Linear, nn.Conv3d)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Parameter):
module.data.normal_(mean=0.0, std=self.config.initializer_range)
VIVIT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.
    Parameters:
        config ([`VivitConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
VIVIT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`VivitImageProcessor`]. See
            [`VivitImageProcessor.preprocess`] for details.
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    "The bare ViViT Transformer model outputting raw hidden-states without any specific head on top.",
    VIVIT_START_DOCSTRING,
)
class VivitModel(VivitPreTrainedModel):
"""
ViViT Transformer model for raw hidden-states.
Args:
config: ViViT model configuration instance.
add_pooling_layer: Whether to add a pooling layer on top of the encoder.
Attributes:
embeddings: ViViT embeddings module.
encoder: ViViT encoder module.
layernorm: Layer normalization module.
pooler: Optional pooling layer for final representation.
Methods:
get_input_embeddings(): Retrieve the patch embeddings from the model.
_prune_heads(heads_to_prune): Prune attention heads in the model.
forward(...): Forward pass of the ViViT model.
"""
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
self.embeddings = VivitEmbeddings(config)
self.encoder = VivitEncoder(config)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.pooler = VivitPooler(config) if add_pooling_layer else None
self.post_init()
def get_input_embeddings(self):
"""
Retrieve the patch embeddings from the ViViT model.
Returns:
embeddings: Patch embeddings used by the model.
"""
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune):
"""
Prunes specific attention heads in the ViViT model.
Args:
heads_to_prune (dict): Dictionary mapping layer numbers to lists of heads to prune.
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(VIVIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
"""
Forward pass of the ViViT model.
Args:
pixel_values: Pixel values of the input video frames.
head_mask: Mask to exclude certain attention heads.
output_attentions: Whether to output attentions weights.
output_hidden_states: Whether to output hidden states.
return_dict: Whether to return a dictionary.
Returns:
BaseModelOutputWithPooling or torch.Tensor: Model outputs.
"""
pass
"""
@add_start_docstrings(
"""ViViT Transformer model with a video classification head on top (a linear layer on top of the final hidden state of the
[CLS] token) e.g. for Kinetics-400.""",
VIVIT_START_DOCSTRING,
)
"""
class VivitForVideoClassification(VivitPreTrainedModel):
"""
ViViT Transformer model with a video classification head.
Args:
config: ViViT model configuration instance.
Attributes:
num_labels: Number of classification labels.
vivit: ViViT base model.
classifier: Linear classification layer.
Methods:
forward(...): Forward pass of the model for video classification.
"""
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.vivit = VivitModel(config, add_pooling_layer=False)
self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
self.post_init()
@add_start_docstrings_to_model_forward(VIVIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
"""
Forward pass of the ViViT model for video classification.
Args:
pixel_values: Pixel values of the input video frames.
head_mask: Mask to exclude certain attention heads.
labels: Labels for classification.
output_attentions: Whether to output attentions weights.
output_hidden_states: Whether to output hidden states.
return_dict: Whether to return a dictionary.
Returns:
ImageClassifierOutput or torch.Tensor: Model outputs.
"""
pass
.\models\vivit\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {
"configuration_vivit": ["VIVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "VivitConfig"],
}
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_vivit"] = ["VivitImageProcessor"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_vivit"] = [
"VIVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
"VivitModel",
"VivitPreTrainedModel",
"VivitForVideoClassification",
]
if TYPE_CHECKING:
from .configuration_vivit import VIVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, VivitConfig
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_vivit import VivitImageProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_vivit import (
VIVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
VivitForVideoClassification,
VivitModel,
VivitPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\wav2vec2\configuration_wav2vec2.py
"""
Wav2Vec2 model configuration
"""
import functools
import operator
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json",
}
class Wav2Vec2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Wav2Vec2Model`]. It is used to instantiate an
Wav2Vec2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Wav2Vec2
[facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```
>>> from transformers import Wav2Vec2Config, Wav2Vec2Model
>>> # Initializing a Wav2Vec2 facebook/wav2vec2-base-960h style configuration
>>> configuration = Wav2Vec2Config()
>>> # Initializing a model (with random weights) from the facebook/wav2vec2-base-960h style configuration
>>> model = Wav2Vec2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "wav2vec2"
def __init__(
self,
vocab_size=32,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout=0.1,
activation_dropout=0.1,
attention_dropout=0.1,
feat_proj_dropout=0.0,
feat_quantizer_dropout=0.0,
final_dropout=0.1,
layerdrop=0.1,
initializer_range=0.02,
layer_norm_eps=1e-5,
feat_extract_norm="group",
feat_extract_activation="gelu",
conv_dim=(512, 512, 512, 512, 512, 512, 512),
conv_stride=(5, 2, 2, 2, 2, 2, 2),
conv_kernel=(10, 3, 3, 3, 3, 2, 2),
conv_bias=False,
num_conv_pos_embeddings=128,
num_conv_pos_embedding_groups=16,
do_stable_layer_norm=False,
apply_spec_augment=True,
mask_time_prob=0.05,
mask_time_length=10,
mask_time_min_masks=2,
mask_feature_prob=0.0,
mask_feature_length=10,
mask_feature_min_masks=0,
num_codevectors_per_group=320,
num_codevector_groups=2,
contrastive_logits_temperature=0.1,
num_negatives=100,
codevector_dim=256,
proj_codevector_dim=256,
diversity_loss_weight=0.1,
ctc_loss_reduction="sum",
ctc_zero_infinity=False,
use_weighted_layer_sum=False,
classifier_proj_size=256,
tdnn_dim=(512, 512, 512, 512, 1500),
tdnn_kernel=(5, 3, 3, 1, 1),
tdnn_dilation=(1, 2, 3, 1, 1),
xvector_output_dim=512,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
add_adapter=False,
adapter_kernel_size=3,
adapter_stride=2,
num_adapter_layers=3,
output_hidden_size=None,
adapter_attn_dim=None,
**kwargs,
):
@property
def inputs_to_logits_ratio(self):
return functools.reduce(operator.mul, self.conv_stride, 1)
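With the default `conv_stride` of (5, 2, 2, 2, 2, 2, 2) the feature encoder downsamples the waveform by 5 * 2**6 = 320, so at a 16 kHz sampling rate each logit frame covers 20 ms of audio:
from transformers import Wav2Vec2Config
config = Wav2Vec2Config()
print(config.inputs_to_logits_ratio)          # 320
print(config.inputs_to_logits_ratio / 16000)  # 0.02 -> 20 ms of audio per frame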
.\models\wav2vec2\convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py
"""Convert Wav2Vec2 checkpoint."""
import argparse
import json
import os
import fairseq
import torch
from fairseq.data import Dictionary
from transformers import (
Wav2Vec2Config,
Wav2Vec2CTCTokenizer,
Wav2Vec2FeatureExtractor,
Wav2Vec2ForCTC,
Wav2Vec2ForPreTraining,
Wav2Vec2Processor,
logging,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2ForSequenceClassification
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
MAPPING = {
"post_extract_proj": "feature_projection.projection",
"encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
"self_attn.k_proj": "encoder.layers.*.attention.k_proj",
"self_attn.v_proj": "encoder.layers.*.attention.v_proj",
"self_attn.q_proj": "encoder.layers.*.attention.q_proj",
"self_attn.out_proj": "encoder.layers.*.attention.out_proj",
"self_attn_layer_norm": "encoder.layers.*.layer_norm",
"fc1": "encoder.layers.*.feed_forward.intermediate_dense",
"fc2": "encoder.layers.*.feed_forward.output_dense",
"final_layer_norm": "encoder.layers.*.final_layer_norm",
"encoder.layer_norm": "encoder.layer_norm",
"adapter_layer": "encoder.layers.*.adapter_layer",
"w2v_model.layer_norm": "feature_projection.layer_norm",
"quantizer.weight_proj": "quantizer.weight_proj",
"quantizer.vars": "quantizer.codevectors",
"project_q": "project_q",
"final_proj": "project_hid",
"w2v_encoder.proj": "lm_head",
"mask_emb": "masked_spec_embed",
"pooling_layer.linear": "projector",
"pooling_layer.projection": "classifier",
}
TOP_LEVEL_KEYS = [
"lm_head",
"quantizer.weight_proj",
"quantizer.codevectors",
"project_q",
"project_hid",
"projector",
"classifier",
]
def read_txt_into_dict(filename):
result = {}
with open(filename, "r") as file:
for line_number, line in enumerate(file):
line = line.strip()
if line:
words = line.split()
key = line_number
value = words[0]
result[key] = value
return result
def set_recursively(key, value, full_name, weight_type, hf_pointer):
for attribute in key.split("."):
hf_pointer = getattr(hf_pointer, attribute)
hf_param_name = None
for param_key in PARAM_MAPPING.keys():
if full_name.endswith(param_key):
hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]]
weight_type = "param"
if weight_type is not None and weight_type != "param":
hf_shape = getattr(hf_pointer, weight_type).shape
elif weight_type is not None and weight_type == "param":
shape_pointer = hf_pointer
for attribute in hf_param_name.split("."):
shape_pointer = getattr(shape_pointer, attribute)
hf_shape = shape_pointer.shape
value = value[0]
else:
hf_shape = hf_pointer.shape
if hf_shape != value.shape:
raise ValueError(
f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
f" {value.shape} for {full_name}"
)
if weight_type == "weight":
hf_pointer.weight.data = value
elif weight_type == "weight_g":
hf_pointer.weight_g.data = value
elif weight_type == "weight_v":
hf_pointer.weight_v.data = value
elif weight_type == "bias":
hf_pointer.bias.data = value
elif weight_type == "param":
for attribute in hf_param_name.split("."):
hf_pointer = getattr(hf_pointer, attribute)
hf_pointer.data = value
else:
hf_pointer.data = value
logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
def rename_dict(key, value, full_name, weight_type, hf_dict):
hf_param_name = None
for param_key in PARAM_MAPPING.keys():
if full_name.endswith(param_key):
hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]]
weight_type = "param"
if weight_type is not None and weight_type != "param":
full_key = ".".join([key, weight_type])
elif weight_type is not None and weight_type == "param":
full_key = ".".join([key, hf_param_name])
else:
full_key = key
hf_dict[full_key] = value if "lm_head" in full_key else value[0]
PARAM_MAPPING = {
"W_a": "linear_1.weight",
"W_b": "linear_2.weight",
"b_a": "linear_1.bias",
"b_b": "linear_2.bias",
"ln_W": "norm.weight",
"ln_b": "norm.bias",
}
def load_wav2vec2_layer(name, value, hf_model=None, hf_dict=None):
is_used = False
for key, mapped_key in MAPPING.items():
mapped_key = "wav2vec2." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key
if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
is_used = True
if "*" in mapped_key:
layer_index = name.split(key)[0].split(".")[-2]
mapped_key = mapped_key.replace("*", layer_index)
if "weight_g" in name:
weight_type = "weight_g"
elif "weight_v" in name:
weight_type = "weight_v"
elif "bias" in name:
weight_type = "bias"
elif "weight" in name:
weight_type = "weight"
else:
weight_type = None
if hf_dict is not None:
rename_dict(mapped_key, value, name, weight_type, hf_dict)
else:
set_recursively(mapped_key, value, name, weight_type, hf_model)
return is_used
return is_used
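To see how the `*` placeholder in MAPPING is resolved, consider a hypothetical fairseq weight name; the layer index is recovered from the prefix before the matched key:
name = "encoder.layers.3.self_attn.k_proj.weight"   # hypothetical checkpoint key
key = "self_attn.k_proj"
mapped_key = "wav2vec2." + "encoder.layers.*.attention.k_proj"
layer_index = name.split(key)[0].split(".")[-2]     # -> "3"
print(mapped_key.replace("*", layer_index))
# wav2vec2.encoder.layers.3.attention.k_proj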
def recursively_load_weights(fairseq_model, hf_model, is_headless):
unused_weights = []
fairseq_dict = fairseq_model.state_dict()
feature_extractor = hf_model.wav2vec2.feature_extractor
for name, value in fairseq_dict.items():
is_used = False
if "conv_layers" in name:
load_conv_layer(
name,
value,
feature_extractor,
unused_weights,
hf_model.config.feat_extract_norm == "group",
)
is_used = True
else:
is_used = load_wav2vec2_layer(name, value, hf_model)
if not is_used:
unused_weights.append(name)
logger.warning(f"Unused weights: {unused_weights}")
def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
name = full_name.split("conv_layers.")[-1]
items = name.split(".")
layer_id = int(items[0])
type_id = int(items[1])
if type_id == 0:
if "bias" in name:
if value.shape != feature_extractor.conv_layers[layer_id].conv.bias.data.shape:
raise ValueError(
f"{full_name} has size {value.shape}, but"
f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
if value.shape != feature_extractor.conv_layers[layer_id].conv.weight.data.shape:
raise ValueError(
f"{full_name} has size {value.shape}, but"
f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape:
raise ValueError(
f"{full_name} has size {value.shape}, but"
f" {feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape:
raise ValueError(
f"{full_name} has size {value.shape}, but"
f" {feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else:
unused_weights.append(full_name)
@torch.no_grad()
def convert_wav2vec2_checkpoint(
checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True, is_seq_class=False
):
"""
Copy/paste/tweak model's weights to transformers design.
"""
if config_path is not None:
config = Wav2Vec2Config.from_pretrained(config_path)
else:
config = Wav2Vec2Config()
if is_seq_class:
id2label = read_txt_into_dict(dict_path)
config.id2label = id2label
hf_wav2vec = Wav2Vec2ForSequenceClassification(config)
feature_extractor = Wav2Vec2FeatureExtractor(
feature_size=1,
sampling_rate=16000,
padding_value=0,
do_normalize=True,
return_attention_mask=True,
)
feature_extractor.save_pretrained(pytorch_dump_folder_path)
elif is_finetuned:
if dict_path:
target_dict = Dictionary.load(dict_path)
            # Intentionally swapped: fairseq's CTC blank symbol is <pad>, not <s>,
            # so the bos/pad ids are exchanged relative to the fairseq dictionary.
            config.bos_token_id = target_dict.pad_index
            config.pad_token_id = target_dict.bos_index
config.eos_token_id = target_dict.eos_index
config.vocab_size = len(target_dict.symbols)
vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
if not os.path.isdir(pytorch_dump_folder_path):
logger.error("--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path))
return
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
vocab_dict = target_dict.indices
vocab_dict["<pad>"] = 0
vocab_dict["<s>"] = 1
with open(vocab_path, "w", encoding="utf-8") as vocab_handle:
json.dump(vocab_dict, vocab_handle)
tokenizer = Wav2Vec2CTCTokenizer(
vocab_path,
unk_token=target_dict.unk_word,
pad_token=target_dict.pad_word,
bos_token=target_dict.bos_word,
eos_token=target_dict.eos_word,
word_delimiter_token="|",
do_lower_case=False,
)
return_attention_mask = True if config.feat_extract_norm == "layer" else False
feature_extractor = Wav2Vec2FeatureExtractor(
feature_size=1,
sampling_rate=16000,
padding_value=0,
do_normalize=True,
return_attention_mask=return_attention_mask,
)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
processor.save_pretrained(pytorch_dump_folder_path)
hf_wav2vec = Wav2Vec2ForCTC(config)
else:
hf_wav2vec = Wav2Vec2ForPreTraining(config)
if is_finetuned or is_seq_class:
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
)
else:
task_arg = argparse.Namespace(task="audio_pretraining")
task = fairseq.tasks.setup_task(task_arg)
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path], task=task)
model = model[0].eval()
recursively_load_weights(model, hf_wav2vec, not is_finetuned)
hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
parser.add_argument(
"--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
)
parser.add_argument(
"--is_seq_class",
action="store_true",
help="Whether the model to convert is a fine-tuned sequence classification model or not",
)
args = parser.parse_args()
is_finetuned = not args.not_finetuned and not args.is_seq_class
convert_wav2vec2_checkpoint(
args.checkpoint_path,
args.pytorch_dump_folder_path,
args.config_path,
args.dict_path,
is_finetuned,
args.is_seq_class,
)
.\models\wav2vec2\convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py
import argparse
import torch
from transformers import (
Wav2Vec2Config,
Wav2Vec2FeatureExtractor,
Wav2Vec2ForAudioFrameClassification,
Wav2Vec2ForSequenceClassification,
Wav2Vec2ForXVector,
logging,
)
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def convert_classification(base_model_name, hf_config, downstream_dict):
model = Wav2Vec2ForSequenceClassification.from_pretrained(base_model_name, config=hf_config)
model.projector.weight.data = downstream_dict["projector.weight"]
model.projector.bias.data = downstream_dict["projector.bias"]
model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"]
model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"]
return model
def convert_diarization(base_model_name, hf_config, downstream_dict):
model = Wav2Vec2ForAudioFrameClassification.from_pretrained(base_model_name, config=hf_config)
model.classifier.weight.data = downstream_dict["model.linear.weight"]
model.classifier.bias.data = downstream_dict["model.linear.bias"]
return model
def convert_xvector(base_model_name, hf_config, downstream_dict):
model = Wav2Vec2ForXVector.from_pretrained(base_model_name, config=hf_config)
model.projector.weight.data = downstream_dict["connector.weight"]
model.projector.bias.data = downstream_dict["connector.bias"]
for i, kernel_size in enumerate(hf_config.tdnn_kernel):
model.tdnn[i].kernel.weight.data = downstream_dict[
f"model.framelevel_feature_extractor.module.{i}.kernel.weight"
]
model.tdnn[i].kernel.bias.data = downstream_dict[f"model.framelevel_feature_extractor.module.{i}.kernel.bias"]
model.feature_extractor.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.weight"]
model.feature_extractor.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.bias"]
model.classifier.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.weight"]
model.classifier.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.bias"]
model.objective.weight.data = downstream_dict["objective.W"]
return model
@torch.no_grad()
def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path):
"""
    Copy/paste/tweak model's weights to transformers design.
"""
checkpoint = torch.load(checkpoint_path, map_location="cpu")
downstream_dict = checkpoint["Downstream"]
hf_config = Wav2Vec2Config.from_pretrained(config_path)
hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
base_model_name, return_attention_mask=True, do_normalize=False
)
arch = hf_config.architectures[0]
if arch.endswith("ForSequenceClassification"):
hf_model = convert_classification(base_model_name, hf_config, downstream_dict)
elif arch.endswith("ForAudioFrameClassification"):
hf_model = convert_diarization(base_model_name, hf_config, downstream_dict)
elif arch.endswith("ForXVector"):
hf_model = convert_xvector(base_model_name, hf_config, downstream_dict)
else:
raise NotImplementedError(f"S3PRL weights conversion is not supported for {arch}")
if hf_config.use_weighted_layer_sum:
hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"]
hf_feature_extractor.save_pretrained(model_dump_path)
hf_model.save_pretrained(model_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--base_model_name", default=None, type=str, help="Name of the huggingface pretrained base model."
)
parser.add_argument("--config_path", default=None, type=str, help="Path to the huggingface classifier config.")
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the s3prl checkpoint.")
parser.add_argument("--model_dump_path", default=None, type=str, help="Path to the final converted model.")
args = parser.parse_args()
convert_s3prl_checkpoint(args.base_model_name, args.config_path, args.checkpoint_path, args.model_dump_path)
.\models\wav2vec2\feature_extraction_wav2vec2.py
"""
Wav2Vec2 的特征提取器类
"""
from typing import List, Optional, Union
import numpy as np
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import PaddingStrategy, TensorType, logging
logger = logging.get_logger(__name__)
class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
r"""
构建一个 Wav2Vec2 特征提取器。
此特征提取器继承自 [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`],其中包含大部分主要方法。
用户应参考此超类以获取关于这些方法的更多信息。
Args:
feature_size (`int`, 默认为 1):
提取特征的特征维度。
sampling_rate (`int`, 默认为 16000):
音频文件的数字化采样率,以赫兹(Hz)表示。
padding_value (`float`, 默认为 0.0):
用于填充填充值的值。
do_normalize (`bool`, *可选*, 默认为 `True`):
是否对输入进行零均值单位方差归一化。归一化可以显著提高某些模型的性能,
例如 [wav2vec2-lv60](https://huggingface.co/models?search=lv60)。
return_attention_mask (`bool`, *可选*, 默认为 `False`):
是否 [`~Wav2Vec2FeatureExtractor.__call__`] 应返回 `attention_mask`。
<Tip>
对于设置了 `config.feat_extract_norm == "group"` 的 Wav2Vec2 模型,例如
[wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h),**没有** 使用
`attention_mask` 进行训练。对于这样的模型,`input_values` 应仅用 0 填充,不应传递 `attention_mask`。
对于设置了 `config.feat_extract_norm == "layer"` 的 Wav2Vec2 模型,例如
[wav2vec2-lv60](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self),应传递 `attention_mask`
以进行批处理推断。
</Tip>
"""
model_input_names = ["input_values", "attention_mask"]
def __init__(
self,
feature_size=1,
sampling_rate=16000,
padding_value=0.0,
return_attention_mask=False,
do_normalize=True,
**kwargs,
):
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
self.return_attention_mask = return_attention_mask
self.do_normalize = do_normalize
@staticmethod
def zero_mean_unit_var_norm(
input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0
) -> List[np.ndarray]:
"""
每个数组都被归一化为零均值和单位方差
"""
if attention_mask is not None:
attention_mask = np.array(attention_mask, np.int32)
normed_input_values = []
for vector, length in zip(input_values, attention_mask.sum(-1)):
normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
if length < normed_slice.shape[0]:
normed_slice[length:] = padding_value
normed_input_values.append(normed_slice)
else:
normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
return normed_input_values
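    # Example (two clips, the second padded from length 2 to 3): with
    # attention_mask=[[1, 1, 1], [1, 1, 0]], each clip is normalized using only its
    # first `length` samples, and the padded tail is overwritten with `padding_value`.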
def __call__(
self,
raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
padding: Union[bool, str, PaddingStrategy] = False,
max_length: Optional[int] = None,
truncation: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
sampling_rate: Optional[int] = None,
**kwargs,
.\models\wav2vec2\modeling_flax_wav2vec2.py
"""
Flax Wav2Vec2 model.
"""
from functools import partial
from typing import Optional, Tuple, Union
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from ...modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_wav2vec2 import Wav2Vec2Config
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class FlaxWav2Vec2BaseModelOutput(ModelOutput):
last_hidden_state: jnp.ndarray = None
extract_features: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxWav2Vec2ForPreTrainingOutput(ModelOutput):
"""
Output type of [`FlaxWav2Vec2ForPreTrainingOutput`], with potential hidden states and attentions.
Args:
loss (*optional*, returned when model is in train mode, `jnp.ndarray` of shape `(1,)`):
Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the
[official paper](https://arxiv.org/pdf/2006.11477.pdf).
projected_states (`jnp.ndarray` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
projected quantized states.
projected_quantized_states (`jnp.ndarray` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
target vectors for contrastive loss.
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
projected_states: jnp.ndarray = None
projected_quantized_states: jnp.ndarray = None
codevector_perplexity: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[np.ndarray] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
"""
Args:
shape: the shape for which to compute masks.
should be of size 2 where first element is batch size and 2nd is timesteps
mask_prob:
probability for each token to be chosen as start of the span to be masked. this will be multiplied by
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
mask_length: size of the mask
min_masks: minimum number of masked spans
"""
# Unpack the shape arguments: batch_size is the batch size, sequence_length the number of timesteps
batch_size, sequence_length = shape
# Raise a ValueError if mask_length is smaller than 1
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
# Raise a ValueError if mask_length is larger than sequence_length
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and"
f" `sequence_length`: {sequence_length}`"
)
# Compute the number of masked spans per batch
num_masked_spans = int(mask_prob * sequence_length / mask_length + np.random.rand(1).item())
# Make sure num_masked_spans is at least min_masks
num_masked_spans = max(num_masked_spans, min_masks)
# Make sure the number of masked indices does not exceed sequence_length
if num_masked_spans * mask_length > sequence_length:
num_masked_spans = sequence_length // mask_length
# Initialize a boolean mask array of shape (batch_size, sequence_length)
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
# Randomly sample the start indices of the spans to mask
spec_aug_mask_idxs = np.array(
[
np.random.choice(np.arange(sequence_length - (mask_length - 1)), num_masked_spans, replace=False)
for _ in range(batch_size)
]
)
# Expand the sampled start indices to full masked spans
spec_aug_mask_idxs = np.broadcast_to(spec_aug_mask_idxs[:, :, None], (batch_size, num_masked_spans, mask_length))
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, num_masked_spans * mask_length)
# Create an offset array used to expand the spans
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, num_masked_spans, mask_length)).reshape(
batch_size, num_masked_spans * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# Scatter the masked indices into the mask array
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
# If an attention_mask is given, make sure padded input ids cannot be masked
if attention_mask is not None:
spec_aug_mask = np.where(attention_mask, spec_aug_mask, False)
# Return the computed mask
return spec_aug_mask
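A minimal usage sketch of the function above (the shape and probabilities are illustrative):
import numpy as np
np.random.seed(0)
# Mask spans of length 5 so that roughly 10% of 100 timesteps are covered, for a batch of 2
mask = _compute_mask_indices(shape=(2, 100), mask_prob=0.1, mask_length=5)
print(mask.shape)         # (2, 100)
print(mask.sum(axis=-1))  # about mask_prob * 100 masked positions per row (overlaps reduce the count)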
def _sample_negative_indices(features_shape: Tuple, num_negatives: int, attention_mask: Optional[np.ndarray] = None):
"""
Sample `num_negatives` vectors from feature vectors.
"""
# Unpack the shape of the input features
batch_size, sequence_length, hidden_size = features_shape
# Raise an error if the sequence length is not larger than 1
if sequence_length <= 1:
raise ValueError(
"`features` should have `sequence_length` > 1, but are of shape "
f"(batch_size, sequence_length, hidden_size) = ({batch_size, sequence_length, hidden_size})."
)
# Randomly sample `num_negatives` vector indices from the same utterance
sampled_negative_indices = []
for batch_idx in range(batch_size):
# Use the attention mask, if given, to bound the usable indices; otherwise use the sequence length
high = attention_mask[batch_idx].sum() - 1 if attention_mask is not None else sequence_length - 1
# Sample `num_negatives * sequence_length` indices
sampled_indices_slice = np.random.randint(0, high, size=(num_negatives * sequence_length,))
sampled_negative_indices.append(sampled_indices_slice)
sampled_negative_indices = np.asarray(sampled_negative_indices, dtype=np.int32)
# Generate indices of the positive vectors themselves, repeated `num_negatives` times
feature_indices = np.broadcast_to(np.arange(sequence_length)[:, None], (sequence_length, num_negatives)).flatten()
# Avoid sampling the index of the positive vector itself while keeping the distribution uniform
sampled_negative_indices[sampled_negative_indices >= feature_indices] += 1
# Correct the indices for the batch size
for batch_idx in range(1, batch_size):
sampled_negative_indices[batch_idx] += batch_idx * sequence_length
return sampled_negative_indices
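A minimal sketch of the sampler above (shapes only; the sampled values differ per run):
import numpy as np
np.random.seed(0)
neg_idx = _sample_negative_indices(features_shape=(2, 4, 8), num_negatives=3)
print(neg_idx.shape)  # (2, 12): 3 negatives for each of the 4 timesteps, flattened per batch entry
# The second row is offset by sequence_length so the indices address the flattened (batch * time) axis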
WAV_2_VEC_2_START_DOCSTRING = r"""
Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
Auli.
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a Flax Linen
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
"""
Parameters:
config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
`jax.numpy.bfloat16` (on TPUs).
This specifies the data type used for computations, allowing for mixed-precision training or
half-precision inference on GPUs or TPUs. If specified, all computations within the model will be
performed with the specified `dtype`.
**Note that this setting affects only the computation dtype and not the dtype of model parameters.**
To change the dtype of model parameters, refer to [`~FlaxPreTrainedModel.to_fp16`] and
[`~FlaxPreTrainedModel.to_bf16`].
"""
定义一个类 `FlaxWav2Vec2LayerNormConvLayer`,继承自 `nn.Module`,用于实现基于 Flax 的 Wav2Vec2 模型的一层。
"""
class FlaxWav2Vec2LayerNormConvLayer(nn.Module):
# Class attribute `config` holds the Wav2Vec2Config used to configure the model
config: Wav2Vec2Config
# Class attribute `layer_id` identifies the current layer, defaulting to 0
layer_id: int = 0
# Class attribute `dtype` sets the computation dtype, defaulting to 32-bit floats
dtype: jnp.dtype = jnp.float32
# Setup method that initializes the layer parameters
def setup(self):
# If this is not the first layer, the input conv dimension comes from the configured conv-dim list; otherwise it is 1
self.in_conv_dim = self.config.conv_dim[self.layer_id] if self.layer_id > 0 else 1
# The output conv dimension always comes from the configured conv-dim list
self.out_conv_dim = self.config.conv_dim[self.layer_id]
# Initialize the convolution layer
self.conv = nn.Conv(
features=self.config.conv_dim[self.layer_id],  # output feature dimension
kernel_size=(self.config.conv_kernel[self.layer_id],),  # kernel size
strides=(self.config.conv_stride[self.layer_id],),  # stride
use_bias=self.config.conv_bias,  # whether to use a bias
kernel_init=jax.nn.initializers.he_normal(),  # kernel initializer
padding="VALID",  # padding mode
dtype=self.dtype,  # computation dtype
)
# Initialize the layer normalization
self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
# Pick the activation function specified in the config
self.activation = ACT2FN[self.config.feat_extract_activation]
# Forward pass
def __call__(self, hidden_states):
# Convolution over the input hidden states
hidden_states = self.conv(hidden_states)
# Layer-normalize the convolution output
hidden_states = self.layer_norm(hidden_states)
# Apply the activation function
hidden_states = self.activation(hidden_states)
# Return the processed hidden states
return hidden_states
# A custom Flax module implementing a convolution with weight normalization
class FlaxConvWithWeightNorm(nn.Module):
# Configuration object of type Wav2Vec2Config
config: Wav2Vec2Config
# Computation dtype, defaulting to jnp.float32
dtype: jnp.dtype = jnp.float32
# Setup method that initializes the submodules and parameters
def setup(self):
# Convolution with hidden_size features and a kernel of num_conv_pos_embeddings taps
self.conv = nn.Conv(
features=self.config.hidden_size,
kernel_size=(self.config.num_conv_pos_embeddings,),
kernel_init=jax.nn.initializers.he_normal(),
padding="VALID",
feature_group_count=self.config.num_conv_pos_embedding_groups,
dtype=self.dtype,
)
# The weight shape follows from the number of features and groups of the convolution
weight_shape = (
self.conv.features,
self.conv.features // self.conv.feature_group_count,
self.conv.kernel_size[0],
)
# Define the direction weight `v` as a parameter, initialized with he_normal
self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(), weight_shape)
# Define the magnitude weight `g` as a parameter, initialized to the L2 norm of `v`
self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :])
# Bias parameter with as many entries as convolution features
self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,))
# Amount of padding added before and after the input
self.prev_padding = self.conv.kernel_size[0] // 2
# Internal helper that returns the weight-normalized kernel
def _get_normed_weights(self):
# Normalize the direction weight `v` by its L2 norm
weight_v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]
normed_weight_v = jnp.divide(self.weight_v, weight_v_norm)
# Scale by the magnitude weight `g` to obtain the normalized kernel
normed_kernel = jnp.multiply(normed_weight_v, self.weight_g)
return normed_kernel
# Forward pass: apply the weight-normalized convolution
def __call__(self, hidden_states):
# Fetch the weight-normalized kernel
kernel = self._get_normed_weights()
# Pad the input so that the convolution output has the same length as the input
hidden_states = jnp.pad(hidden_states, ((0, 0), (self.prev_padding, self.prev_padding), (0, 0)))
# Apply the convolution with the normalized kernel and the bias
hidden_states = self.conv.apply({"params": {"kernel": kernel.T, "bias": self.bias}}, hidden_states)
return hidden_states
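Weight normalization reparameterizes the kernel as `g * v / ||v||`. A minimal NumPy check of the identity used in `_get_normed_weights` (hypothetical shapes; note that `weight_g` is initialized to `||v||`, so the initial kernel reproduces `v`):
import numpy as np
v = np.random.randn(3, 1, 8)                        # same layout as weight_v: (kernel, in/groups, features)
g = np.linalg.norm(v, axis=(0, 1))[None, None, :]   # per-feature norm, mirroring weight_g's initializer
kernel = (v / np.linalg.norm(v, axis=(0, 1))[None, None, :]) * g
np.testing.assert_allclose(kernel, v, rtol=1e-6)    # with g == ||v||, the normalized kernel equals v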
# A Flax module implementing the positional convolutional embedding
class FlaxWav2Vec2PositionalConvEmbedding(nn.Module):
# Configuration object of type Wav2Vec2Config
config: Wav2Vec2Config
# Computation dtype, defaulting to jnp.float32
dtype: jnp.dtype = jnp.float32
# Setup method that initializes the submodules
def setup(self):
# Weight-normalized convolution module
self.conv = FlaxConvWithWeightNorm(self.config, dtype=self.dtype)
# Activation function specified in the config
self.activation = ACT2FN[self.config.feat_extract_activation]
# Number of padding elements to remove, depending on the kernel size parity
self.num_pad_remove = 1 if self.config.num_conv_pos_embeddings % 2 == 0 else 0
# Forward pass: compute the positional convolutional embeddings
def __call__(self, hidden_states):
# Transpose the input tensor (a no-op here, since the axes are already in channels-last order)
hidden_states = hidden_states.transpose((0, 1, 2))
# Apply the weight-normalized convolution
hidden_states = self.conv(hidden_states)
# Trim the extra padding from the convolution output if necessary
if self.num_pad_remove > 0:
hidden_states = hidden_states[:, : -self.num_pad_remove, :]
# Apply the activation function
hidden_states = self.activation(hidden_states)
# Transpose back (again a no-op) and return the result
hidden_states = hidden_states.transpose((0, 1, 2))
return hidden_states
# A Flax module holding the collection of convolutional feature-extraction layers
class FlaxConvLayersCollection(nn.Module):
# Configuration object of type Wav2Vec2Config
config: Wav2Vec2Config
# Computation dtype, defaulting to jnp.float32
dtype: jnp.dtype = jnp.float32
# Setup method that builds the list of layers
def setup(self):
# If the config requests "layer" feature-extraction normalization
if self.config.feat_extract_norm == "layer":
# Build one FlaxWav2Vec2LayerNormConvLayer per feature-extraction layer
self.layers = [
FlaxWav2Vec2LayerNormConvLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
for i in range(self.config.num_feat_extract_layers)
]
# "group" normalization is not implemented for the Flax port
elif self.config.feat_extract_norm == "group":
raise NotImplementedError("At the moment only ``config.feat_extact_norm == 'layer'`` is supported")
# Any other value for feat_extract_norm is invalid
else:
raise ValueError(
f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group',"
" 'layer']"
)
# Forward pass: run the input through every convolutional layer in order
def __call__(self, hidden_states):
for i, conv_layer in enumerate(self.layers):
hidden_states = conv_layer(hidden_states)
return hidden_states
class FlaxWav2Vec2FeatureEncoder(nn.Module):
"""从原始音频波形中构建特征"""
config: Wav2Vec2Config # 引用Wav2Vec2Config配置对象
dtype: jnp.dtype = jnp.float32 # 计算时使用的数据类型,默认为单精度浮点数
def setup(self):
self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype)
# 初始化卷积层集合,使用配置对象和指定数据类型
def __call__(self, input_values, freeze_feature_encoder=False):
hidden_states = input_values[:, :, None]
# 在最后添加一个维度,将形状从[batch_size, seq_len]变为[batch_size, seq_len, 1]
hidden_states = self.conv_layers(hidden_states)
# 经过卷积层处理,处理后形状为[batch_size, seq_len, hidden_size]
if freeze_feature_encoder:
hidden_states = jax.lax.stop_gradient(hidden_states)
# 如果需要冻结特征编码器,则停止梯度传播
return hidden_states
class FlaxWav2Vec2FeatureProjection(nn.Module):
config: Wav2Vec2Config  # the Wav2Vec2Config configuration object
dtype: jnp.dtype = jnp.float32  # computation dtype, defaulting to float32
def setup(self):
# Layer normalization with the configured epsilon and dtype
self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
# Dense projection to hidden_size, with normal-distribution kernel initialization
self.projection = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
# Dropout with the configured feature-projection dropout rate
self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout)
def __call__(self, hidden_states, deterministic=True):
# Layer-normalize the hidden states
norm_hidden_states = self.layer_norm(hidden_states)
# Apply the projection
hidden_states = self.projection(norm_hidden_states)
# Apply dropout (skipped when deterministic=True)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
return hidden_states, norm_hidden_states
class FlaxWav2Vec2Attention(nn.Module):
config: Wav2Vec2Config  # the Wav2Vec2Config configuration object
embed_dim: int  # embedding dimension
num_heads: int  # number of attention heads
dropout: float = 0.0  # dropout rate, defaulting to 0.0
bias: bool = True  # whether to use a bias, defaulting to True
dtype: jnp.dtype = jnp.float32  # computation dtype, defaulting to float32
def setup(self) -> None:
# Dimension of each attention head
self.head_dim = self.embed_dim // self.num_heads
# embed_dim must be divisible by num_heads, otherwise raise an error
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
# Partially applied Dense-layer factory with shared arguments
dense = partial(
nn.Dense,
self.embed_dim,
use_bias=self.bias,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
# Query, key and value projections
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
# Output projection
self.out_proj = dense()
# Dropout layer with the configured rate
self.dropout_layer = nn.Dropout(rate=self.dropout)
# Split the hidden states into multiple heads
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
# Merge the per-head hidden states back into a single dimension
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
def __call__(
self,
hidden_states: jnp.ndarray,
key_value_states: Optional[jnp.ndarray] = None,
attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = True,
# the attention layer takes hidden states, optional key/value states, an attention mask and a determinism flag
) -> Tuple[jnp.ndarray]:
"""Input shape: Batch x Time x Channel"""
# Project the queries
query_states = self.q_proj(hidden_states)
# Project the keys and values
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Split the query, key and value projections into heads
query_states = self._split_heads(query_states)
key_states = self._split_heads(key_states)
value_states = self._split_heads(value_states)
# If an attention mask is given, expand its dimensions to match the attention tensor
if attention_mask is not None:
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
# Convert the boolean attention mask into an additive attention bias
if attention_mask is not None:
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
)
else:
attention_bias = None
# Create a dropout RNG when running stochastically with a non-zero dropout rate
dropout_rng = None
if not deterministic and self.dropout > 0.0:
dropout_rng = self.make_rng("dropout")
# Compute the attention weights
attn_weights = dot_product_attention_weights(
query_states,
key_states,
bias=attention_bias,
dropout_rng=dropout_rng,
dropout_rate=self.dropout,
broadcast_dropout=True,
deterministic=deterministic,
dtype=self.dtype,
precision=None,
)
# Compute the attention output via a batched matrix multiplication with einsum
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
attn_output = self._merge_heads(attn_output)  # merge the attention heads
attn_output = self.out_proj(attn_output)  # output projection
return attn_output, attn_weights
# The feed-forward block of the Flax Wav2Vec2 encoder
class FlaxWav2Vec2FeedForward(nn.Module):
# Configuration object of type Wav2Vec2Config
config: Wav2Vec2Config
# Computation dtype, defaulting to jnp.float32
dtype: jnp.dtype = jnp.float32
# Setup method that builds the network
def setup(self):
# Dropout for the intermediate layer, using the configured activation dropout rate
self.intermediate_dropout = nn.Dropout(rate=self.config.activation_dropout)
# Intermediate dense layer of size intermediate_size, with normal kernel initialization
self.intermediate_dense = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
# Pick the activation function: look it up by name if given as a string, otherwise use it directly
if isinstance(self.config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[self.config.hidden_act]
else:
self.intermediate_act_fn = self.config.hidden_act
# Output dense layer of size hidden_size, with normal kernel initialization
self.output_dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
# Dropout for the output layer, using the configured hidden dropout rate
self.output_dropout = nn.Dropout(rate=self.config.hidden_dropout)
# Forward pass: dense -> activation -> dropout -> dense -> dropout
def __call__(self, hidden_states, deterministic=True):
hidden_states = self.intermediate_dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = self.intermediate_dropout(hidden_states, deterministic=deterministic)
hidden_states = self.output_dense(hidden_states)
hidden_states = self.output_dropout(hidden_states, deterministic=deterministic)
return hidden_states
# A single encoder layer with "stable" (pre-attention) layer normalization
class FlaxWav2Vec2EncoderLayerStableLayerNorm(nn.Module):
# Configuration object of type Wav2Vec2Config
config: Wav2Vec2Config
# Computation dtype, defaulting to jnp.float32
dtype: jnp.dtype = jnp.float32
# Setup method that builds the sublayers
def setup(self):
# Self-attention layer
self.attention = FlaxWav2Vec2Attention(
config=self.config,
embed_dim=self.config.hidden_size,
num_heads=self.config.num_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
# Dropout applied to the attention output
self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
# Layer normalization applied before the attention, using the configured epsilon
self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
# Feed-forward block
self.feed_forward = FlaxWav2Vec2FeedForward(self.config, dtype=self.dtype)
# Layer normalization applied before the feed-forward block
self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
# Forward pass with optional attention mask, determinism flag and attention-weight output
def __call__(self, hidden_states, attention_mask=None, deterministic=True, output_attentions=False):
# Keep the residual for the attention block
attn_residual = hidden_states
# Pre-attention layer normalization
hidden_states = self.layer_norm(hidden_states)
# Self-attention forward pass
hidden_states, attn_weights = self.attention(
hidden_states, attention_mask=attention_mask, deterministic=deterministic
)
# Dropout on the attention output
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
# Add the attention residual
hidden_states = attn_residual + hidden_states
# Feed-forward block with its own pre-layer-norm and residual connection
hidden_states = hidden_states + self.feed_forward(
self.final_layer_norm(hidden_states), deterministic=deterministic
)
# Assemble the outputs
outputs = (hidden_states,)
# Optionally append the attention weights
if output_attentions:
outputs += (attn_weights,)
return outputs
# The collection of stable-layer-norm encoder layers
class FlaxWav2Vec2EncoderLayerStableLayerNormCollection(nn.Module):
# Configuration object of type Wav2Vec2Config
config: Wav2Vec2Config
# Computation dtype, defaulting to jnp.float32
dtype: jnp.dtype = jnp.float32
# Setup method: create num_hidden_layers encoder layers, named '0' to str(num_hidden_layers - 1)
def setup(self):
self.layers = [
FlaxWav2Vec2EncoderLayerStableLayerNorm(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.num_hidden_layers)
]
# Forward pass over all layers, optionally collecting hidden states and attentions
def __call__(
self,
hidden_states,  # input hidden-state tensor
attention_mask=None,  # optional attention mask, defaulting to None
deterministic: bool = True,  # whether to run deterministically, defaulting to True
output_attentions: bool = False,  # whether to return attention tensors, defaulting to False
output_hidden_states: bool = False,  # whether to return all hidden states, defaulting to False
return_dict: bool = True,  # whether to return the results as a dict, defaulting to True
):
# Initialize the accumulators depending on what should be returned
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
# Run the hidden states through each encoder layer in turn
for i, layer in enumerate(self.layers):
if output_hidden_states:
# Store the current hidden states before the layer is applied
all_hidden_states += (hidden_states,)
# Call the layer with the hidden states and attention mask
layer_outputs = layer(
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
)
# The first element of the layer output is the new hidden states
hidden_states = layer_outputs[0]
if output_attentions:
# Store the layer's attention weights
all_attentions += (layer_outputs[1],)
if output_hidden_states:
# Store the final hidden states
all_hidden_states += (hidden_states,)
# Build the output tuple
outputs = (hidden_states, all_hidden_states, all_attentions)
if not return_dict:
# Return a tuple with the None entries removed
return tuple(v for v in outputs if v is not None)
# Otherwise return a FlaxBaseModelOutput with the final hidden states, all hidden states and all attentions
return FlaxBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
class FlaxWav2Vec2StableLayerNormEncoder(nn.Module):
# Configuration object of type Wav2Vec2Config
config: Wav2Vec2Config
# Computation dtype, defaulting to 32-bit floats
dtype: jnp.dtype = jnp.float32
# Setup method that builds the submodules
def setup(self):
# Positional convolutional embedding, built from the config and dtype
self.pos_conv_embed = FlaxWav2Vec2PositionalConvEmbedding(self.config, dtype=self.dtype)
# Layer normalization with the configured epsilon and dtype
self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
# Dropout with the configured rate
self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
# Collection of encoder layers, built from the config and dtype
self.layers = FlaxWav2Vec2EncoderLayerStableLayerNormCollection(self.config, dtype=self.dtype)
# Forward pass of the encoder
def __call__(
self,
hidden_states,
attention_mask=None,
deterministic=True,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
# If an attention mask is given, make sure padded tokens are not attended to
if attention_mask is not None:
hidden_states = jnp.where(
# broadcast the attention mask to the shape of hidden_states and zero out the padded positions
jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape), hidden_states, 0
)
# Compute the positional embeddings
position_embeddings = self.pos_conv_embed(hidden_states)
# Add the positional embeddings to the hidden states
hidden_states = hidden_states + position_embeddings
# Apply dropout to the position-augmented hidden states
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
# Run the encoder layer collection
outputs = self.layers(
hidden_states,
attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Layer-normalize the last hidden state of the encoder output
last_hidden_state = self.layer_norm(outputs[0])
# If hidden states are collected, update the last entry with the normalized hidden state
hidden_states = None
if output_hidden_states:
hidden_states = outputs[1]
hidden_states = hidden_states[:-1] + (last_hidden_state,)
# If no dict is requested, flatten the outputs and drop the None values
if not return_dict:
outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
return tuple(v for v in outputs if v is not None)
# Return a FlaxBaseModelOutput with the last hidden state, the hidden-state history and the attentions
return FlaxBaseModelOutput(
last_hidden_state=last_hidden_state, hidden_states=hidden_states, attentions=outputs.attentions
)
class FlaxWav2Vec2GumbelVectorQuantizer(nn.Module):
"""
Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
"""
# Configuration object of type Wav2Vec2Config
config: Wav2Vec2Config
# Computation dtype, defaulting to 32-bit floats
dtype: jnp.dtype = jnp.float32
# Setup method that initializes the quantizer attributes
def setup(self):
# Copy the relevant config values onto the instance
self.num_groups = self.config.num_codevector_groups
self.num_vars = self.config.num_codevectors_per_group
# Check that codevector_dim divides evenly across the groups
if self.config.codevector_dim % self.num_groups != 0:
# If it does not divide evenly, raise a ValueError
raise ValueError(
f"`config.codevector_dim {self.config.codevector_dim} must be divisible by"
f" `config.num_codevector_groups` {self.num_groups} for concatenation"
)
# Storage for the codebook variables (codewords)
self.codevectors = self.param(
"codevectors",
jax.nn.initializers.uniform(),
(1, self.num_groups * self.num_vars, self.config.codevector_dim // self.num_groups),
)
# Projection onto the codebook logits
self.weight_proj = nn.Dense(
self.num_groups * self.num_vars,
kernel_init=jax.nn.initializers.normal(1.0),
dtype=self.dtype,
)
# Static method: compute the perplexity of the codevector distribution
@staticmethod
def _compute_perplexity(probs, mask=None):
# If a mask is given, broadcast it and apply it to the probability matrix
if mask is not None:
mask_extended = jnp.broadcast_to(mask.flatten()[:, None, None], probs.shape)
probs = jnp.where(mask_extended, probs, jnp.zeros_like(probs))
marginal_probs = probs.sum(axis=0) / mask.sum()
else:
# Otherwise average the probability matrix over the batch axis
marginal_probs = probs.mean(axis=0)
# Compute the perplexity from the marginal probabilities
perplexity = jnp.exp(-jnp.sum(marginal_probs * jnp.log(marginal_probs + 1e-7), axis=-1)).sum()
return perplexity
def __call__(self, hidden_states, mask_time_indices=None, deterministic=True, temperature=1):
batch_size, sequence_length, hidden_size = hidden_states.shape
# Project the hidden states onto the codevector logits
hidden_states = self.weight_proj(hidden_states)
hidden_states = hidden_states.reshape(batch_size * sequence_length * self.num_groups, -1)
if not deterministic:
# Sample codevector probabilities in a differentiable way via the Gumbel distribution
gumbel_rng = self.make_rng("gumbel")
gumbels = jax.random.gumbel(gumbel_rng, hidden_states.shape)
codevector_probs = nn.softmax((hidden_states + gumbels) / temperature)
# Compute the perplexity
codevector_soft_dist = nn.softmax(
hidden_states.reshape(batch_size * sequence_length, self.num_groups, -1), axis=-1
)
perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
else:
# Take the argmax in a non-differentiable way
# and compute a hard (one-hot) codevector distribution
codevector_idx = hidden_states.argmax(axis=-1)
codevector_probs = jax.nn.one_hot(codevector_idx, hidden_states.shape[-1]) * 1.0
codevector_probs = codevector_probs.reshape(batch_size * sequence_length, self.num_groups, -1)
perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
codevector_probs = codevector_probs.reshape(batch_size * sequence_length, -1)
# Retrieve the codevectors using the probabilities
codevectors_per_group = jnp.expand_dims(codevector_probs, axis=-1) * self.codevectors
codevectors = codevectors_per_group.reshape(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
codevectors = codevectors.sum(-2).reshape(batch_size, sequence_length, -1)
return codevectors, perplexity
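A minimal standalone sketch of the Gumbel-softmax sampling used in the non-deterministic branch above (the logits and temperature are made up):
import jax
import jax.numpy as jnp
logits = jnp.array([[2.0, 0.5, -1.0]])  # unnormalized scores over 3 codevectors
gumbels = jax.random.gumbel(jax.random.PRNGKey(0), logits.shape)
soft_one_hot = jax.nn.softmax((logits + gumbels) / 2.0)  # differentiable "sample" at temperature 2.0
hard_idx = soft_one_hot.argmax(axis=-1)                  # discrete codevector choice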
class FlaxWav2Vec2Adapter(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
def setup(self):
# If the feature dimensions don't match, the hidden states require a down-projection
if self.config.output_hidden_size != self.config.hidden_size:
# Initialize a Dense layer for the projection with normal-distribution initialization
self.proj = nn.Dense(
self.config.output_hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
# Layer normalization for the projection layer
self.proj_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
else:
self.proj = self.proj_layer_norm = None
# Initialize the collection of adapter layers
self.layers = FlaxWav2Vec2AdapterLayersCollection(self.config, dtype=self.dtype)
def __call__(self, hidden_states, deterministic=True):
# Down-project the hidden states if required
if self.proj is not None and self.proj_layer_norm is not None:
hidden_states = self.proj(hidden_states)
hidden_states = self.proj_layer_norm(hidden_states)
# Pass the hidden states through the adapter layers
hidden_states = self.layers(hidden_states)
return hidden_states
class FlaxWav2Vec2AdapterLayer(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
def setup(self):
# Initialize the convolution of the adapter layer
self.conv = nn.Conv(
features=2 * self.config.output_hidden_size,
kernel_size=(self.config.adapter_kernel_size,),
strides=(self.config.adapter_stride,),
padding=((1, 1),),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
def __call__(self, hidden_states):
# Apply the convolution to the hidden states
hidden_states = self.conv(hidden_states)
# Apply the gated linear unit (GLU) activation along axis 2
hidden_states = nn.glu(hidden_states, axis=2)
return hidden_states
class FlaxWav2Vec2AdapterLayersCollection(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
def setup(self):
# Initialize the list of adapter layers
self.layers = [
FlaxWav2Vec2AdapterLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.num_adapter_layers)
]
def __call__(self, hidden_states):
# Apply each adapter layer to the hidden states in turn
for conv_layer in self.layers:
hidden_states = conv_layer(hidden_states)
return hidden_states
class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = Wav2Vec2Config
base_model_prefix: str = "wav2vec2"
main_input_name = "input_values"
module_class: nn.Module = None
def __init__(
self,
config: Wav2Vec2Config,
input_shape: Tuple = (1, 1024),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
# Instantiate the module with the config and dtype
module = self.module_class(config=config, dtype=dtype, **kwargs)
# Call the parent initializer with the config, module, input shape, seed, dtype and init flag
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# Initialize the input tensor
input_values = jnp.zeros(input_shape, dtype="i4")
# Create an all-ones attention mask with the same shape as the input values
attention_mask = jnp.ones_like(input_values)
# Split the RNG into one stream for the parameters and one for dropout
params_rng, dropout_rng = jax.random.split(rng, 2)
rngs = {"params": params_rng, "dropout": dropout_rng}
# Initialize the parameters with the module's init method, returning the parameter dict
random_params = self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"]
if params is not None:
# If parameters were passed in, merge them with the randomly initialized ones
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
# Otherwise simply return the randomly initialized parameters
return random_params
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
def __call__(
self,
input_values,
attention_mask=None,
mask_time_indices=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
freeze_feature_encoder: bool = False,
return_dict: Optional[bool] = None,
):
# Fall back to the config value if output_attentions is not explicitly given
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# Fall back to the config value if output_hidden_states is not explicitly given
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# Fall back to the config value if return_dict is not explicitly given
return_dict = return_dict if return_dict is not None else self.config.return_dict
# Get the batch size and sequence length of the input
batch_size, sequence_length = input_values.shape
# If no attention mask is given, create an all-ones attention mask
if attention_mask is None:
attention_mask = jnp.ones((batch_size, sequence_length))
# Handle any PRNGs that might be required
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
# Build the input dict; fall back to self.params if no params are given
inputs = {"params": params or self.params}
# Apply the module, running the model's forward pass
return self.module.apply(
inputs,
jnp.array(input_values, dtype="f4"),
jnp.array(attention_mask, dtype="i4"),
mask_time_indices,
not train,
output_attentions,
output_hidden_states,
freeze_feature_encoder,
return_dict,
rngs=rngs,
)
def _get_feat_extract_output_lengths(
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
):
# Delegate to the module's helper to compute the feature-extractor output lengths
return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
# The core Flax module implementing the bare Wav2Vec2 model
class FlaxWav2Vec2Module(nn.Module):
# Configuration object of type Wav2Vec2Config
config: Wav2Vec2Config
# Computation dtype, defaulting to jnp.float32
dtype: jnp.dtype = jnp.float32
# Setup method that builds the submodules
def setup(self):
# Feature encoder, built from the config and dtype
self.feature_extractor = FlaxWav2Vec2FeatureEncoder(self.config, dtype=self.dtype)
# Feature projection, built from the config and dtype
self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype)
# Learned embedding for masked spectrogram positions, of shape (hidden_size,)
self.masked_spec_embed = self.param(
"masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,)
)
# Only the stable-layer-norm encoder is supported in the Flax port
if self.config.do_stable_layer_norm:
self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype)
else:
raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")
# Build the adapter if the config requests one
self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None
# Forward pass of the model
def __call__(
self,
input_values,
attention_mask=None,
mask_time_indices=None,
deterministic=True,
output_attentions=None,
output_hidden_states=None,
freeze_feature_encoder=False,
return_dict=None,
):
# Extract the feature vectors from the raw input
extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)
# If an attention mask is given
if attention_mask is not None:
# Compute the reduced attention mask that corresponds to the extracted feature vectors
attention_mask = self._get_feature_vector_attention_mask(
extract_features.shape[1], attention_mask, add_adapter=False
)
# Project the features
hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
# If mask indices along the time axis are given
if mask_time_indices is not None:
# Apply SpecAugment along the time axis with the given indices
hidden_states = jnp.where(
jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
hidden_states,
)
# Run the encoder
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# The encoder's hidden states
hidden_states = encoder_outputs[0]
# Apply the adapter if one was configured
if self.adapter is not None:
hidden_states = self.adapter(hidden_states)
# If no dict is requested
if not return_dict:
# Return a tuple: (hidden states, extracted features) plus the rest of the encoder outputs
return (hidden_states, extract_features) + encoder_outputs[1:]
# Return a FlaxWav2Vec2BaseModelOutput with the last hidden state, extracted features, hidden states and attentions
return FlaxWav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
# Helper: compute the output lengths of the feature extractor
def _get_feat_extract_output_lengths(
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
):
"""
Computes the output length of the convolutional layers
"""
add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
if add_adapter:
for _ in range(self.config.num_adapter_layers):
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
return input_lengths
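A minimal sketch of this length computation, assuming the default base configuration (`conv_kernel=(10, 3, 3, 3, 3, 2, 2)`, `conv_stride=(5, 2, 2, 2, 2, 2, 2)`):
def conv_out(length, kernel, stride):
return (length - kernel) // stride + 1

length = 16000  # one second of 16 kHz audio
for k, s in zip((10, 3, 3, 3, 3, 2, 2), (5, 2, 2, 2, 2, 2, 2)):
length = conv_out(length, k, s)
print(length)  # 49 frames, i.e. roughly one frame every 20 ms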
def _get_feature_vector_attention_mask(
self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
):
# Effectively attention_mask.sum(-1), but not in-place, so it can run in inference mode
non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
batch_size = attention_mask.shape[0]
attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
# These two operations make sure that all values
# before the output-length indices are attended to
attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
return attention_mask
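A minimal NumPy sketch of the flip-cumsum-flip trick used above (illustrative lengths):
import numpy as np
output_lengths = np.array([3, 5])
mask = np.zeros((2, 6), dtype=np.int32)
mask[np.arange(2), output_lengths - 1] = 1                     # mark the last attended index per row
mask = np.flip(np.flip(mask, -1).cumsum(-1), -1).astype(bool)  # everything up to that index becomes True
# mask -> [[ True  True  True False False False]
#          [ True  True  True  True  True False]]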
# Decorator adding the docstring that describes this class as the bare Wav2Vec2 transformer without any specific head on top
@add_start_docstrings(
"The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.",
WAV_2_VEC_2_START_DOCSTRING,
)
# FlaxWav2Vec2Model inherits from FlaxWav2Vec2PreTrainedModel
class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
module_class = FlaxWav2Vec2Module  # the underlying module class is FlaxWav2Vec2Module
# FLAX_WAV2VEC2_MODEL_DOCSTRING documents the return value and gives a usage example
FLAX_WAV2VEC2_MODEL_DOCSTRING = """
Returns:
Example:
```
>>> from transformers import AutoProcessor, FlaxWav2Vec2Model
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-large-lv60")
>>> model = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-large-lv60")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = processor(
... ds["speech"][0], sampling_rate=16_000, return_tensors="np"
... ).input_values
>>> hidden_states = model(input_values).last_hidden_state
```
"""
# Overwrite the call docstring of FlaxWav2Vec2Model with the combined input and model docstrings
overwrite_call_docstring(
FlaxWav2Vec2Model,
WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_MODEL_DOCSTRING,
)
# Append/replace the return-value docstring of FlaxWav2Vec2Model
append_replace_return_docstrings(
FlaxWav2Vec2Model, output_type=FlaxWav2Vec2BaseModelOutput, config_class=Wav2Vec2Config
)
# FlaxWav2Vec2ForCTCModule inherits from nn.Module
class FlaxWav2Vec2ForCTCModule(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
# Setup method that builds the module and its members
def setup(self):
self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)  # the wav2vec2 backbone
self.dropout = nn.Dropout(rate=self.config.final_dropout)  # the final dropout layer
self.lm_head = nn.Dense(  # the language-modeling head
self.config.vocab_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
# Forward pass of the CTC model
def __call__(
self,
input_values,
attention_mask=None,
mask_time_indices=None,
deterministic=True,
output_attentions=None,
output_hidden_states=None,
freeze_feature_encoder=False,
return_dict=None,
):
# Run the wav2vec2 backbone and collect its outputs
outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
mask_time_indices=mask_time_indices,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
freeze_feature_encoder=freeze_feature_encoder,
return_dict=return_dict,
)
hidden_states = outputs[0]  # the last hidden states
hidden_states = self.dropout(hidden_states, deterministic=deterministic)  # apply dropout
logits = self.lm_head(hidden_states)  # compute the logits
if not return_dict:
return (logits,) + outputs[2:]  # return the logits plus the remaining outputs
# Return a FlaxCausalLMOutput containing the logits, hidden states and attentions
return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
def _get_feat_extract_output_lengths(
self,
input_lengths: Union[jnp.ndarray, int],
add_adapter: Optional[bool] = None,
):
"""
Computes the output length of the convolutional layers
"""
# Fall back to the config value if add_adapter is not given
add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken from
# https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
# Apply the formula layer by layer for every kernel size and stride
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
# If adapter layers are added, account for them using the configured count and stride
if add_adapter:
for _ in range(self.config.num_adapter_layers):
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
# Return the final output length
return input_lengths
# Decorator adding the docstring that describes FlaxWav2Vec2ForCTC as a Wav2Vec2 model with a language-modeling head for CTC
@add_start_docstrings(
"Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).",
WAV_2_VEC_2_START_DOCSTRING,
)
# FlaxWav2Vec2ForCTC inherits from FlaxWav2Vec2PreTrainedModel
class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel):
# The underlying module class is FlaxWav2Vec2ForCTCModule
module_class = FlaxWav2Vec2ForCTCModule
# FLAX_WAV2VEC2_FOR_CTC_DOCSTRING is a long string describing the return value and a usage example of FlaxWav2Vec2ForCTC
# Overwrite the call docstring of FlaxWav2Vec2ForCTC with the combined input docstring and FLAX_WAV2VEC2_FOR_CTC_DOCSTRING
overwrite_call_docstring(
FlaxWav2Vec2ForCTC,
WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_CTC_DOCSTRING,
)
# Append/replace the return-value docstring of FlaxWav2Vec2ForCTC with the given output_type and config_class
append_replace_return_docstrings(FlaxWav2Vec2ForCTC, output_type=FlaxCausalLMOutput, config_class=Wav2Vec2Config)
# FlaxWav2Vec2ForPreTrainingModule inherits from nn.Module
class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
# config is a Wav2Vec2Config; dtype defaults to jnp.float32
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
# Setup method that initializes the module
def setup(self):
# Instantiate the FlaxWav2Vec2Module backbone and store it in self.wav2vec2
self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
# Dropout over the extracted features, using config.feat_quantizer_dropout
self.dropout_features = nn.Dropout(self.config.feat_quantizer_dropout)
# Gumbel vector quantizer
self.quantizer = FlaxWav2Vec2GumbelVectorQuantizer(self.config, dtype=self.dtype)
# Projection of the quantized features to config.proj_codevector_dim
self.project_q = nn.Dense(
self.config.proj_codevector_dim,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
# Projection of the transformer features to config.proj_codevector_dim
self.project_hid = nn.Dense(
self.config.proj_codevector_dim,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
# Forward pass; the parameters are documented in the docstring below
def __call__(
self,
input_values,
attention_mask=None,
mask_time_indices=None,
gumbel_temperature: int = 1,
deterministic: bool = True,
output_attentions=None,
output_hidden_states=None,
freeze_feature_encoder=False,
return_dict=None,
**kwargs,
):
r"""
Returns:
Example:
```
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Run the wav2vec2 backbone with the given arguments and collect its outputs
outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
mask_time_indices=mask_time_indices,
deterministic=deterministic,
freeze_feature_encoder=freeze_feature_encoder,
return_dict=return_dict,
)
# Project all transformed features (including the masked ones) to the final vector-quantization dimension
transformer_features = self.project_hid(outputs[0])
# Quantize all (unmasked) extracted features and project them to the final vector-quantization dimension
extract_features = self.dropout_features(outputs[1], deterministic=deterministic)
quantized_features, codevector_perplexity = self.quantizer(
extract_features, mask_time_indices, deterministic=deterministic, temperature=gumbel_temperature
)
quantized_features = self.project_q(quantized_features)
# If no dict is requested, return the outputs as a tuple
if not return_dict:
return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
# Wrap the outputs in a FlaxWav2Vec2ForPreTrainingOutput, including all relevant information
return FlaxWav2Vec2ForPreTrainingOutput(
projected_states=transformer_features,
projected_quantized_states=quantized_features,
codevector_perplexity=codevector_perplexity,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def _get_feat_extract_output_lengths(
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
):
"""
Computes the output length of the convolutional layers
"""
add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula from
# https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
# Apply the formula layer by layer for the configured kernel sizes and strides
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
# If adapter layers are added, account for their output lengths too
if add_adapter:
for _ in range(self.config.num_adapter_layers):
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
return input_lengths
@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top.""", WAV_2_VEC_2_START_DOCSTRING)
class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel):
module_class = FlaxWav2Vec2ForPreTrainingModule
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
# Overrides the original definition to add the `gumbel_temperature` input argument
def __call__(
self,
input_values,
attention_mask=None,
mask_time_indices=None,
gumbel_temperature: int = 1,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
gumbel_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
freeze_feature_encoder: bool = False,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
batch_size, sequence_length = input_values.shape
# If no attention mask is given, create an all-ones attention mask
if attention_mask is None:
attention_mask = jnp.ones((batch_size, sequence_length))
# Handle any PRNGs that might be required
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
if gumbel_rng is not None:
rngs["gumbel"] = gumbel_rng
# Prepare the model inputs
inputs = {"params": params or self.params}
# Apply the module's forward method
return self.module.apply(
inputs,
jnp.array(input_values, dtype="f4"),
jnp.array(attention_mask, dtype="i4"),
mask_time_indices,
gumbel_temperature,
not train,
output_attentions,
output_hidden_states,
freeze_feature_encoder,
return_dict,
rngs=rngs,
)
FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING = """
Returns:
Example:
```
>>> import optax
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from transformers import AutoFeatureExtractor, FlaxWav2Vec2ForPreTraining
>>> from transformers.models.wav2vec2.modeling_flax_wav2vec2 import _compute_mask_indices
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-large-lv60")
>>> model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = feature_extractor(ds["speech"][0], return_tensors="np").input_values
>>>
>>> batch_size, raw_sequence_length = input_values.shape
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
>>> outputs = model(input_values, mask_time_indices=mask_time_indices)
>>>
>>> cosine_sim = optax.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states)
>>>
>>> assert np.asarray(cosine_sim)[mask_time_indices].mean() > 0.5
"""
# Overwrite the call docstring of `FlaxWav2Vec2ForPreTraining`,
# combining `WAV_2_VEC_2_INPUTS_DOCSTRING` with `FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING`.
overwrite_call_docstring(
FlaxWav2Vec2ForPreTraining,
WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING,
)
"""
为 `FlaxWav2Vec2ForPreTraining` 类附加和替换返回值文档字符串,
使用 `FlaxWav2Vec2ForPreTrainingOutput` 作为输出类型,`Wav2Vec2Config` 作为配置类。
"""
append_replace_return_docstrings(
FlaxWav2Vec2ForPreTraining, output_type=FlaxWav2Vec2ForPreTrainingOutput, config_class=Wav2Vec2Config
)
.\models\wav2vec2\modeling_tf_wav2vec2.py
""" TensorFlow Wav2Vec2 模型。"""
from __future__ import annotations
import warnings
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput, TFSequenceClassifierOutput
from ...modeling_tf_utils import (
TFPreTrainedModel,
get_initializer,
keras,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_wav2vec2 import Wav2Vec2Config
logger = logging.get_logger(__name__)
_HIDDEN_STATES_START_POSITION = 2
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
_CONFIG_FOR_DOC = "Wav2Vec2Config"
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/wav2vec2-base-960h",
"facebook/wav2vec2-large-960h",
"facebook/wav2vec2-large-960h-lv60",
"facebook/wav2vec2-large-960h-lv60-self",
]
LARGE_NEGATIVE = -1e8
@dataclass
class TFWav2Vec2BaseModelOutput(ModelOutput):
"""
[`TFWav2Vec2BaseModelOutput`] 的输出类型,包含潜在的隐藏状态和注意力。
继承自 ModelOutput 类。
"""
"""
Args:
last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
模型最后一层输出的隐藏状态序列。
extract_features (`tf.Tensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
模型最后一个卷积层提取的特征向量序列。
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
包含模型每一层输出的隐藏状态的元组。形状为 `(batch_size, sequence_length, hidden_size)`。
模型每一层的隐藏状态,以及初始嵌入输出。
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
包含注意力权重的元组。形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
经过注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
"""
last_hidden_state: tf.Tensor = None
extract_features: tf.Tensor = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None
def _sample_without_replacement(distribution, num_samples):
"""
Categorical sampling without replacement is currently not implemented. The gumbel-max trick will do for now - see
https://github.com/tensorflow/tensorflow/issues/9260 for more info
"""
z = -tf.math.log(tf.random.uniform(shape_list(distribution), 0, 1))
_, indices = tf.nn.top_k(distribution + z, num_samples)
return indices
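A minimal sketch of the Gumbel-max trick implemented above (assuming `_sample_without_replacement` is in scope; the distribution is made up):
import tensorflow as tf
logits = tf.math.log(tf.constant([[0.5, 0.3, 0.2]]))  # log-probabilities over 3 items
picked = _sample_without_replacement(logits, num_samples=2)
# `picked` holds 2 distinct column indices per row, sampled proportionally to the distribution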
def _scatter_values_on_batch_indices(values, batch_indices, output_shape):
"""
Scatter function as in PyTorch with indices in format (batch_dim, indices)
"""
indices_shape = shape_list(batch_indices)
broad_casted_batch_dims = tf.reshape(
tf.broadcast_to(tf.expand_dims(tf.range(indices_shape[0]), axis=-1), indices_shape), [1, -1]
)
pair_indices = tf.transpose(tf.concat([broad_casted_batch_dims, tf.reshape(batch_indices, [1, -1])], 0))
return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), output_shape)
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
min_masks: int = 0,
) -> tf.Tensor:
"""
Computes random mask spans for a given shape
Args:
shape: the shape for which to compute masks.
should be of size 2 where the first element is the batch size and the second is the number of timesteps
mask_prob:
probability for each token to be chosen as the start of a span to be masked. This will be multiplied by
the number of timesteps divided by the length of the mask span to mask approximately this percentage of
all elements. Due to overlaps, however, the actual number will be smaller (unless no_overlap is True).
mask_length: size of the mask
min_masks: minimum number of masked spans
Adapted from [fairseq's
data_utils.py](https://github.com/pytorch/fairseq/blob/e0788f7007a8473a76db573985031f3c94201e79/fairseq/data/data_utils.py#L376).
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
tf.debugging.assert_less(
mask_length,
sequence_length,
message=(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and"
f" `sequence_length`: {sequence_length}`"
),
)
num_masked_spans = mask_prob * tf.cast(sequence_length, tf.float32) / mask_length + tf.random.uniform((1,))
num_masked_spans = tf.maximum(num_masked_spans, min_masks)
num_masked_spans = tf.cast(num_masked_spans, tf.int32)
num_masked_spans = tf.math.minimum(sequence_length // mask_length, num_masked_spans)
num_masked_spans = tf.squeeze(num_masked_spans)
spec_aug_mask = tf.zeros((batch_size, sequence_length), dtype=tf.int32)
uniform_dist = tf.ones((batch_size, sequence_length - (mask_length - 1)))
spec_aug_mask_idxs = _sample_without_replacement(uniform_dist, num_masked_spans)
spec_aug_mask_idxs = tf.expand_dims(spec_aug_mask_idxs, -1)
spec_aug_mask_idxs = tf.tile(spec_aug_mask_idxs, (1, 1, mask_length))
spec_aug_mask_idxs = tf.reshape(spec_aug_mask_idxs, (batch_size, num_masked_spans * mask_length))
offsets = tf.range(mask_length)[tf.newaxis, tf.newaxis, :]
offsets = tf.tile(offsets, (batch_size, num_masked_spans, 1))
offsets = tf.reshape(offsets, (batch_size, num_masked_spans * mask_length))
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
spec_aug_mask = _scatter_values_on_batch_indices(
tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, tf.shape(spec_aug_mask)
)
return spec_aug_mask
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
src_len = shape_list(mask)[1]
tgt_len = tgt_len if tgt_len is not None else src_len
one_cst = tf.constant(1.0)
mask = tf.cast(mask, dtype=one_cst.dtype)
expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
return (one_cst - expanded_mask) * LARGE_NEGATIVE
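Because the expanded mask is applied additively before the softmax, padded positions receive a large negative bias while real positions get zero. A standalone sketch, assuming the module-level constant `LARGE_NEGATIVE = -1e8`:

```python
import tensorflow as tf

LARGE_NEGATIVE = -1e8  # assumption: matches the module-level constant

mask = tf.constant([[1.0, 1.0, 0.0]])  # third token is padding
expanded = (1.0 - tf.tile(mask[:, None, None, :], (1, 1, 3, 1))) * LARGE_NEGATIVE
print(expanded.shape)  # (1, 1, 3, 3): broadcastable over attention heads
print(expanded[0, 0])  # last column is -1e8 in every row, the rest are 0
```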
class TFWav2Vec2GroupNorm(keras.layers.Layer):
"""
From tensorflow-addons https://www.tensorflow.org/addons/api_docs/python/tfa/layers/GroupNormalization
"""
def __init__(
self,
groups: int = 32,
axis: int = -1,
epsilon: float = 1e-3,
center: bool = True,
scale: bool = True,
beta_initializer: keras.initializers.Initializer = "zeros",
gamma_initializer: keras.initializers.Initializer = "ones",
beta_regularizer: keras.regularizers.Regularizer = None,
gamma_regularizer: keras.regularizers.Regularizer = None,
beta_constraint: keras.constraints.Constraint = None,
gamma_constraint: keras.constraints.Constraint = None,
**kwargs,
):
super().__init__(**kwargs)
self.supports_masking = True
self.groups = groups
self.axis = axis
self.epsilon = epsilon
self.center = center
self.scale = scale
self.beta_initializer = keras.initializers.get(beta_initializer)
self.gamma_initializer = keras.initializers.get(gamma_initializer)
self.beta_regularizer = keras.regularizers.get(beta_regularizer)
self.gamma_regularizer = keras.regularizers.get(gamma_regularizer)
self.beta_constraint = keras.constraints.get(beta_constraint)
self.gamma_constraint = keras.constraints.get(gamma_constraint)
self._check_axis()
def build(self, input_shape):
self._check_if_input_shape_is_none(input_shape)
self._set_number_of_groups_for_instance_norm(input_shape)
self._check_size_of_dimensions(input_shape)
self._create_input_spec(input_shape)
self._add_gamma_weight(input_shape)
self._add_beta_weight(input_shape)
self.built = True
super().build(input_shape)
def call(self, inputs):
input_shape = keras.backend.int_shape(inputs)
tensor_input_shape = tf.shape(inputs)
reshaped_inputs, group_shape = self._reshape_into_groups(inputs, input_shape, tensor_input_shape)
normalized_inputs = self._apply_normalization(reshaped_inputs, input_shape)
is_instance_norm = (input_shape[self.axis] // self.groups) == 1
if not is_instance_norm:
outputs = tf.reshape(normalized_inputs, tensor_input_shape)
else:
outputs = normalized_inputs
return outputs
def get_config(self):
config = {
"groups": self.groups,
"axis": self.axis,
"epsilon": self.epsilon,
"center": self.center,
"scale": self.scale,
"beta_initializer": keras.initializers.serialize(self.beta_initializer),
"gamma_initializer": keras.initializers.serialize(self.gamma_initializer),
"beta_regularizer": keras.regularizers.serialize(self.beta_regularizer),
"gamma_regularizer": keras.regularizers.serialize(self.gamma_regularizer),
"beta_constraint": keras.constraints.serialize(self.beta_constraint),
"gamma_constraint": keras.constraints.serialize(self.gamma_constraint),
}
base_config = super().get_config()
return {**base_config, **config}
def compute_output_shape(self, input_shape):
return input_shape
def _reshape_into_groups(self, inputs, input_shape, tensor_input_shape):
group_shape = [tensor_input_shape[i] for i in range(len(input_shape))]
is_instance_norm = (input_shape[self.axis] // self.groups) == 1
if not is_instance_norm:
group_shape[self.axis] = input_shape[self.axis] // self.groups
group_shape.insert(self.axis, self.groups)
group_shape = tf.stack(group_shape)
reshaped_inputs = tf.reshape(inputs, group_shape)
return reshaped_inputs, group_shape
else:
return inputs, group_shape
def _apply_normalization(self, reshaped_inputs, input_shape):
group_shape = keras.backend.int_shape(reshaped_inputs)
group_reduction_axes = list(range(1, len(group_shape)))
is_instance_norm = (input_shape[self.axis] // self.groups) == 1
if not is_instance_norm:
axis = -2 if self.axis == -1 else self.axis - 1
else:
axis = -1 if self.axis == -1 else self.axis - 1
group_reduction_axes.pop(axis)
mean, variance = tf.nn.moments(reshaped_inputs, group_reduction_axes, keepdims=True)
gamma, beta = self._get_reshaped_weights(input_shape)
normalized_inputs = tf.nn.batch_normalization(
reshaped_inputs,
mean=mean,
variance=variance,
scale=gamma,
offset=beta,
variance_epsilon=self.epsilon,
)
return normalized_inputs
def _get_reshaped_weights(self, input_shape):
broadcast_shape = self._create_broadcast_shape(input_shape)
gamma = None
beta = None
if self.scale:
gamma = tf.reshape(self.gamma, broadcast_shape)
if self.center:
beta = tf.reshape(self.beta, broadcast_shape)
return gamma, beta
def _check_if_input_shape_is_none(self, input_shape):
dim = input_shape[self.axis]
if dim is None:
raise ValueError(
"Axis "
+ str(self.axis)
+ " of input tensor should have a defined dimension but the layer received an input with shape "
+ str(input_shape)
+ "."
)
def _set_number_of_groups_for_instance_norm(self, input_shape):
dim = input_shape[self.axis]
if self.groups == -1:
self.groups = dim
def _check_size_of_dimensions(self, input_shape):
dim = input_shape[self.axis]
if dim < self.groups:
raise ValueError(
"Number of groups ("
+ str(self.groups)
+ ") cannot be more than the number of channels ("
+ str(dim)
+ ")."
)
if dim % self.groups != 0:
raise ValueError(
"Number of groups ("
+ str(self.groups)
+ ") must be a multiple of the number of channels ("
+ str(dim)
+ ")."
)
def _check_axis(self):
if self.axis == 0:
raise ValueError(
"You are trying to normalize your batch axis. Do you want to use tf.layer.batch_normalization instead"
)
def _create_input_spec(self, input_shape):
dim = input_shape[self.axis]
self.input_spec = keras.layers.InputSpec(ndim=len(input_shape), axes={self.axis: dim})
def _add_gamma_weight(self, input_shape):
dim = input_shape[self.axis]
shape = (dim,)
if self.scale:
self.gamma = self.add_weight(
shape=shape,
name="gamma",
initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
constraint=self.gamma_constraint,
)
else:
self.gamma = None
def _add_beta_weight(self, input_shape):
dim = input_shape[self.axis]
shape = (dim,)
if self.center:
self.beta = self.add_weight(
shape=shape,
name="beta",
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint,
)
else:
self.beta = None
def _create_broadcast_shape(self, input_shape):
broadcast_shape = [1] * len(input_shape)
is_instance_norm = (input_shape[self.axis] // self.groups) == 1
if not is_instance_norm:
broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
broadcast_shape.insert(self.axis, self.groups)
else:
broadcast_shape[self.axis] = self.groups
return broadcast_shape
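The reshaping above is easiest to verify numerically: splitting the channel axis into groups and normalizing over the time axis plus the within-group channels is exactly what `_apply_normalization` reduces over (axes 1 and 3 after the reshape). A minimal check with 8 channels in 4 groups:

```python
import tensorflow as tf

x = tf.random.normal((2, 5, 8))
grouped = tf.reshape(x, (2, 5, 4, 2))  # (batch, time, groups, channels_per_group)
mean, var = tf.nn.moments(grouped, axes=[1, 3], keepdims=True)
normed = (grouped - mean) / tf.sqrt(var + 1e-3)
# Per-(batch, group) mean of the normalized values is ~0, as group norm requires.
print(tf.reduce_max(tf.abs(tf.reduce_mean(normed, axis=[1, 3]))))  # ~1e-7
```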
class TFWav2Vec2WeightNormConv1D(keras.layers.Conv1D):
"""Adapted from https://www.tensorflow.org/probability/api_docs/python/tfp/layers/weight_norm/WeightNorm"""
def __init__(self, filters, kernel_size, groups, explicit_padding, **kwargs):
super().__init__(
filters=filters,
kernel_size=kernel_size,
groups=groups,
padding="valid",
use_bias=True,
bias_initializer="he_normal",
**kwargs,
)
self.explicit_padding = explicit_padding
self.filter_axis = 2
self.kernel_norm_axes = tf.constant([0, 1])
def _init_norm(self):
"""Set the norm of the weight vector."""
kernel_norm = tf.sqrt(tf.reduce_sum(tf.square(self.weight_v), axis=self.kernel_norm_axes))
self.weight_g.assign(kernel_norm[:, tf.newaxis, tf.newaxis])
def _normalize_kernel(self):
"""Generate normalized weights."""
kernel = tf.nn.l2_normalize(self.weight_v, axis=self.kernel_norm_axes) * tf.transpose(self.weight_g)
self.kernel = tf.transpose(kernel)
def build(self, input_shape):
if not self.built:
super().build(input_shape)
self.kernel = tf.Variable(tf.transpose(self.kernel), name="weight_v", trainable=True)
self.weight_v = self.kernel
self.weight_g = self.add_weight(
name="weight_g",
shape=(int(self.weight_v.shape[self.filter_axis]), 1, 1),
initializer="ones",
dtype=self.weight_v.dtype,
trainable=True,
)
self._init_norm()
self.bias = self.add_weight(name="bias", shape=(self.filters,), initializer="zeros", trainable=True)
def call(self, inputs):
self._normalize_kernel()
padded_inputs = tf.pad(inputs, ((0, 0), (self.explicit_padding, self.explicit_padding), (0, 0)))
output = super().call(padded_inputs)
return output
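Weight normalization reparameterizes the kernel as a per-filter magnitude `g` times a unit-norm direction `v / ||v||`; `_init_norm` sets `g = ||v||` so the effective kernel is unchanged at initialization. A standalone sketch of the same math, independent of the layer's internal transposes:

```python
import tensorflow as tf

v = tf.random.normal((3, 4, 16))  # (kernel_size, in_channels, filters)
g = tf.sqrt(tf.reduce_sum(tf.square(v), axis=[0, 1]))  # per-filter norm, as in _init_norm
kernel = tf.nn.l2_normalize(v, axis=[0, 1]) * g        # direction * magnitude
print(tf.reduce_max(tf.abs(kernel - v)))               # ~0: kernel == v at init
```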
class TFWav2Vec2NoLayerNormConvLayer(keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, layer_id: int = 0, **kwargs: Any) -> None:
super().__init__(**kwargs)
        self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = keras.layers.Conv1D(
filters=self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
strides=config.conv_stride[layer_id],
use_bias=config.conv_bias,
name="conv",
)
self.activation = get_tf_activation(config.feat_extract_activation)
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.conv(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv", None) is not None:
with tf.name_scope(self.conv.name):
self.conv.build([None, None, self.in_conv_dim])
class TFWav2Vec2LayerNormConvLayer(keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, layer_id: int = 0, **kwargs: Any) -> None:
super().__init__(**kwargs)
        self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = keras.layers.Conv1D(
filters=self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
strides=config.conv_stride[layer_id],
use_bias=config.conv_bias,
name="conv",
)
self.layer_norm = keras.layers.LayerNormalization(name="layer_norm", epsilon=config.layer_norm_eps)
self.activation = get_tf_activation(config.feat_extract_activation)
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.conv(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv", None) is not None:
with tf.name_scope(self.conv.name):
self.conv.build([None, None, self.in_conv_dim])
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.out_conv_dim])
class TFWav2Vec2GroupNormConvLayer(keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, layer_id: int = 0, **kwargs: Any) -> None:
super().__init__(**kwargs)
        self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = keras.layers.Conv1D(
filters=self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
strides=config.conv_stride[layer_id],
use_bias=config.conv_bias,
name="conv",
)
self.activation = get_tf_activation(config.feat_extract_activation)
self.layer_norm = TFWav2Vec2GroupNorm(
groups=self.out_conv_dim, epsilon=config.layer_norm_eps, name="layer_norm"
)
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.conv(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv", None) is not None:
with tf.name_scope(self.conv.name):
self.conv.build([None, None, self.in_conv_dim])
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.out_conv_dim])
class TFWav2Vec2PositionalConvEmbedding(keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.conv = TFWav2Vec2WeightNormConv1D(
filters=config.hidden_size,
kernel_size=config.num_conv_pos_embeddings,
groups=config.num_conv_pos_embedding_groups,
explicit_padding=config.num_conv_pos_embeddings // 2,
name="conv",
)
self.padding = TFWav2Vec2SamePadLayer(config.num_conv_pos_embeddings)
self.activation = get_tf_activation(config.feat_extract_activation)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.conv(hidden_states)
hidden_states = self.padding(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv", None) is not None:
with tf.name_scope(self.conv.name):
self.conv.build([None, None, self.config.hidden_size])
class TFWav2Vec2SamePadLayer(keras.layers.Layer):
def __init__(self, num_conv_pos_embeddings, **kwargs):
super().__init__(**kwargs)
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
def call(self, hidden_states):
if self.num_pad_remove > 0:
hidden_states = hidden_states[:, : -self.num_pad_remove, :]
return hidden_states
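The trim size follows from the arithmetic of symmetric padding: a stride-1 convolution padded by `k // 2` on both sides produces `T + 2*(k // 2) - k + 1` frames, which is `T + 1` for even `k` and exactly `T` for odd `k`, so one trailing frame is dropped only in the even case:

```python
# Output length of a stride-1 conv with symmetric padding k // 2.
for k, T in [(128, 100), (129, 100)]:
    out_len = T + 2 * (k // 2) - k + 1
    print(f"kernel={k}: {out_len} frames")  # 128 -> 101 (trim 1), 129 -> 100 (trim 0)
```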
class TFWav2Vec2FeatureEncoder(keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, **kwargs: Any) -> None:
super().__init__(**kwargs)
if config.feat_extract_norm == "group":
conv_layers = [TFWav2Vec2GroupNormConvLayer(config, layer_id=0, name=f"conv_layers.{0}")] + [
TFWav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1, name=f"conv_layers.{i+1}")
for i in range(config.num_feat_extract_layers - 1)
]
elif config.feat_extract_norm == "layer":
conv_layers = [
TFWav2Vec2LayerNormConvLayer(config, layer_id=i, name=f"conv_layers.{i}")
for i in range(config.num_feat_extract_layers)
]
else:
raise ValueError(
f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
)
self.conv_layers = conv_layers
def call(self, input_values):
hidden_states = tf.expand_dims(input_values, -1)
for conv_layer in self.conv_layers:
hidden_states = conv_layer(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv_layers", None) is not None:
for conv_layer in self.conv_layers:
with tf.name_scope(conv_layer.name):
conv_layer.build(None)
class TFWav2Vec2FeatureExtractor(TFWav2Vec2FeatureEncoder):
def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
warnings.warn(
f"The class `{self.__class__.__name__}` has been depreciated "
"and will be removed in Transformers v5. "
f"Use `{self.__class__.__bases__[0].__name__}` instead.",
FutureWarning,
)
class TFWav2Vec2FeatureProjection(keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, **kwargs):
super().__init__(**kwargs)
self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
self.projection = keras.layers.Dense(
units=config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
bias_initializer="zeros",
name="projection",
)
self.dropout = keras.layers.Dropout(rate=config.feat_proj_dropout)
self.config = config
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
norm_hidden_states = self.layer_norm(hidden_states)
hidden_states = self.projection(norm_hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
return hidden_states, norm_hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.conv_dim[-1]])
if getattr(self, "projection", None) is not None:
with tf.name_scope(self.projection.name):
self.projection.build([None, None, self.config.conv_dim[-1]])
class TFWav2Vec2Attention(keras.layers.Layer):
"""Multi-headed attention from "Attention Is All You Need"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = keras.layers.Dropout(dropout)
self.head_dim = embed_dim // num_heads
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")
def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))
def call(
self,
hidden_states: tf.Tensor,
key_value_states: tf.Tensor | None = None,
past_key_value: Tuple[Tuple[tf.Tensor]] | None = None,
attention_mask: tf.Tensor | None = None,
layer_head_mask: tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Tuple[tf.Tensor, tf.Tensor | None]:
        # Projections, scaled dot-product scores, masking, softmax, dropout and the
        # output projection follow here; the body is elided in this excerpt.
        ...
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build([None, None, self.embed_dim])
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build([None, None, self.embed_dim])
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build([None, None, self.embed_dim])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.embed_dim])
class TFWav2Vec2EncoderLayer(keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, **kwargs):
super().__init__(**kwargs)
self.attention = TFWav2Vec2Attention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
name="attention",
)
self.dropout = keras.layers.Dropout(config.hidden_dropout)
self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
self.feed_forward = TFWav2Vec2FeedForward(config, name="feed_forward")
self.final_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="final_layer_norm")
self.config = config
def call(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor | None = None,
output_attentions: Optional[bool] = False,
training: bool = False,
) -> Tuple[tf.Tensor]:
attn_residual = hidden_states
hidden_states, attn_weights, _ = self.attention(
hidden_states, attention_mask=attention_mask, training=training
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = attn_residual + hidden_states
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states + self.feed_forward(hidden_states)
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.hidden_size])
if getattr(self, "feed_forward", None) is not None:
with tf.name_scope(self.feed_forward.name):
self.feed_forward.build(None)
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.config.hidden_size])
class TFWav2Vec2EncoderLayerStableLayerNorm(keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, **kwargs):
super().__init__(**kwargs)
self.attention = TFWav2Vec2Attention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
name="attention",
)
self.dropout = keras.layers.Dropout(config.hidden_dropout)
self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
self.feed_forward = TFWav2Vec2FeedForward(config, name="feed_forward")
self.final_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="final_layer_norm")
self.config = config
def call(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor | None = None,
output_attentions: Optional[bool] = False,
training: bool = False,
) -> Tuple[tf.Tensor]:
attn_residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.attention(
hidden_states, attention_mask=attention_mask, training=training
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = attn_residual + hidden_states
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.hidden_size])
if getattr(self, "feed_forward", None) is not None:
with tf.name_scope(self.feed_forward.name):
self.feed_forward.build(None)
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.config.hidden_size])
class TFWav2Vec2Encoder(keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, **kwargs):
super().__init__(**kwargs)
self.config = config
self.pos_conv_embed = TFWav2Vec2PositionalConvEmbedding(config, name="pos_conv_embed")
self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
self.dropout = keras.layers.Dropout(config.hidden_dropout)
self.layer = [TFWav2Vec2EncoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)]
def call(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor | None = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
training: Optional[bool] = False,
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
if attention_mask is not None:
hidden_states = hidden_states * tf.expand_dims(attention_mask, -1)
attention_mask = _expand_mask(attention_mask)
else:
attention_mask = None
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
dropout_probability = np.random.uniform(0, 1)
if training and (dropout_probability < self.config.layerdrop):
continue
layer_outputs = layer_module(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
training=training,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return TFBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "pos_conv_embed", None) is not None:
with tf.name_scope(self.pos_conv_embed.name):
self.pos_conv_embed.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.hidden_size])
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
class TFWav2Vec2EncoderStableLayerNorm(keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, **kwargs):
super().__init__(**kwargs)
self.config = config
self.pos_conv_embed = TFWav2Vec2PositionalConvEmbedding(config, name="pos_conv_embed")
self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
self.dropout = keras.layers.Dropout(config.hidden_dropout)
self.layer = [
TFWav2Vec2EncoderLayerStableLayerNorm(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)
]
def call(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor | None = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
training: Optional[bool] = False,
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
if attention_mask is not None:
hidden_states = hidden_states * tf.expand_dims(attention_mask, -1)
attention_mask = _expand_mask(attention_mask)
else:
attention_mask = None
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
hidden_states = self.dropout(hidden_states, training=training)
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
dropout_probability = np.random.uniform(0, 1)
if training and (dropout_probability < self.config.layerdrop):
continue
layer_outputs = layer_module(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
training=training,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
hidden_states = self.layer_norm(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return TFBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "pos_conv_embed", None) is not None:
with tf.name_scope(self.pos_conv_embed.name):
self.pos_conv_embed.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.hidden_size])
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFWav2Vec2MainLayer(keras.layers.Layer):
config_class = Wav2Vec2Config
def __init__(self, config: Wav2Vec2Config, **kwargs):
super().__init__(**kwargs)
self.config = config
self.feature_extractor = TFWav2Vec2FeatureEncoder(config, name="feature_extractor")
self.feature_projection = TFWav2Vec2FeatureProjection(config, name="feature_projection")
if config.do_stable_layer_norm:
self.encoder = TFWav2Vec2EncoderStableLayerNorm(config, name="encoder")
else:
self.encoder = TFWav2Vec2Encoder(config, name="encoder")
def build(self, input_shape=None):
if self.built:
return
self.built = True
if self.config.mask_time_prob > 0.0 or self.config.mask_feature_prob > 0.0:
self.masked_spec_embed = self.add_weight(
shape=(self.config.hidden_size,),
initializer="uniform",
trainable=True,
name="masked_spec_embed"
)
if getattr(self, "feature_extractor", None) is not None:
with tf.name_scope(self.feature_extractor.name):
self.feature_extractor.build(None)
if getattr(self, "feature_projection", None) is not None:
with tf.name_scope(self.feature_projection.name):
self.feature_projection.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
def _get_feat_extract_output_lengths(self, input_lengths: tf.Tensor):
"""
Computes the output length of the convolutional layers
"""
def _conv_out_length(input_length, kernel_size, stride):
return (input_length - kernel_size) // stride + 1
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
return input_lengths
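A worked example, assuming the base wav2vec2 defaults `conv_kernel = (10, 3, 3, 3, 3, 2, 2)` and `conv_stride = (5, 2, 2, 2, 2, 2, 2)`: one second of 16 kHz audio shrinks to 49 feature frames, i.e. one frame per roughly 20 ms.

```python
length = 16000  # one second of 16 kHz audio
for k, s in zip((10, 3, 3, 3, 3, 2, 2), (5, 2, 2, 2, 2, 2, 2)):
    length = (length - k) // s + 1  # same formula as _conv_out_length above
print(length)  # 49
```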
def _mask_hidden_states(self, hidden_states: tf.Tensor, mask_time_indices: tf.Tensor | None = None):
"""
Masks extracted features along time axis and/or along feature axis according to
[SpecAugment](https://arxiv.org/abs/1904.08779).
"""
batch_size, sequence_length, hidden_size = shape_list(hidden_states)
if not getattr(self.config, "apply_spec_augment", True):
return hidden_states
if mask_time_indices is not None:
hidden_states = tf.where(
tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool),
self.masked_spec_embed[tf.newaxis, tf.newaxis, :],
hidden_states,
)
elif self.config.mask_time_prob > 0:
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
mask_prob=self.config.mask_time_prob,
mask_length=self.config.mask_time_length,
min_masks=2,
)
hidden_states = tf.where(
tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool),
self.masked_spec_embed[tf.newaxis, tf.newaxis, :],
hidden_states,
)
if self.config.mask_feature_prob > 0:
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length,
)
            hidden_states = tf.where(
                tf.cast(mask_feature_indices[:, tf.newaxis, :], tf.bool), 0.0, hidden_states
            )
return hidden_states
@unpack_inputs
def call(
self,
input_values: tf.Tensor,
attention_mask: tf.Tensor | None = None,
token_type_ids: tf.Tensor | None = None,
position_ids: tf.Tensor | None = None,
head_mask: tf.Tensor | None = None,
inputs_embeds: tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
**kwargs: Any,
):
extract_features = self.feature_extractor(tf.cast(input_values, tf.float32), training=training)
if attention_mask is not None:
output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, -1))
attention_mask = tf.sequence_mask(
output_lengths, maxlen=shape_list(extract_features)[1], dtype=extract_features.dtype
)
hidden_states, extract_features = self.feature_projection(extract_features, training=training)
mask_time_indices = kwargs.get("mask_time_indices", None)
if training:
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
hidden_states = encoder_outputs[0]
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
return TFWav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = Wav2Vec2Config
base_model_prefix = "wav2vec2"
main_input_name = "input_values"
@property
def input_signature(self):
return {
"input_values": tf.TensorSpec((None, None), tf.float32, name="input_values"),
"attention_mask": tf.TensorSpec((None, None), tf.float32, name="attention_mask"),
}
@property
def dummy_inputs(self):
return {
"input_values": tf.random.uniform(shape=(1, 500), dtype=tf.float32),
"attention_mask": tf.ones(shape=(1, 500), dtype=tf.float32),
}
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
logger.warning(
f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish "
"to train/fine-tune this model, you need a GPU or a TPU"
)
def _get_feat_extract_output_lengths(self, input_lengths, add_adapter=None):
"""
Computes the output length of the convolutional layers
"""
add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
def _conv_out_length(input_length, kernel_size, stride):
return tf.math.floordiv(input_length - kernel_size, stride) + 1
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
if add_adapter:
for _ in range(self.config.num_adapter_layers):
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
return input_lengths
def _get_feature_vector_attention_mask(
self, feature_vector_length: int, attention_mask: tf.Tensor, add_adapter=None
    ):
non_padded_lengths = tf.math.cumsum(attention_mask, axis=-1)[:, -1]
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
output_lengths = tf.cast(output_lengths, tf.int32)
batch_size = tf.shape(attention_mask)[0]
attention_mask = tf.zeros(
(batch_size, feature_vector_length), dtype=attention_mask.dtype, name="attention_mask"
)
attention_mask = tf.tensor_scatter_nd_update(
attention_mask,
indices=tf.stack([tf.range(batch_size), output_lengths - 1], axis=1),
updates=tf.ones([batch_size], dtype=attention_mask.dtype),
)
attention_mask = tf.reverse(attention_mask, axis=[-1])
attention_mask = tf.cumsum(attention_mask, axis=-1)
attention_mask = tf.reverse(attention_mask, axis=[-1])
attention_mask = tf.cast(attention_mask, tf.bool)
return attention_mask
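The reverse/cumsum/reverse idiom above turns a single scattered 1 at index `output_length - 1` into a contiguous run of ones over the first `output_length` positions, which is what makes the scatter-one-marker approach work:

```python
import tensorflow as tf

marker = tf.constant([[0.0, 0.0, 0.0, 1.0, 0.0, 0.0]])  # output_length == 4
mask = tf.reverse(tf.cumsum(tf.reverse(marker, axis=[-1]), axis=-1), axis=[-1])
print(tf.cast(mask, tf.bool))  # [[ True  True  True  True False False]]
```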
"""
This model inherits from `TFPreTrainedModel`. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a `keras.Model` subclass. Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0
documentation for all matters related to general usage and behavior.
<Tip>
TensorFlow models and layers in `transformers` accept two formats as input:
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional argument.
The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
positional argument:
- a single Tensor with `input_values` only and nothing else: `model(input_values)`
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
`model([input_values, attention_mask])` or `model([input_values, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
`model({"input_values": input_values, "token_type_ids": token_type_ids})`
Note that when creating models and layers with
[subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
about any of this, as you can just pass inputs like you would to any other Python function!
</Tip>
Args:
config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
WAV_2_VEC_2_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
behavior.
<Tip>
TensorFlow models and layers in `transformers` accept two formats as input:
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional argument.
The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
positional argument:
- a single Tensor with `input_values` only and nothing else: `model(input_values)`
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
`model([input_values, attention_mask])` or `model([input_values, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
`model({"input_values": input_values, "token_type_ids": token_type_ids})`
Note that when creating models and layers with
[subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
about any of this, as you can just pass inputs like you would to any other Python function!
</Tip>
Args:
config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
"""
"""
@add_start_docstrings(
"The bare TFWav2Vec2 Model transformer outputing raw hidden-states without any specific head on top.",
WAV_2_VEC_2_START_DOCSTRING,
)
class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
"""
The bare TFWav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.
    See `WAV_2_VEC_2_START_DOCSTRING` above for the usage notes shared by all Wav2Vec2 models.
Args:
config (Wav2Vec2Config): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the `PreTrainedModel.from_pretrained` method to load the model weights.
"""
def __init__(self, config: Wav2Vec2Config, *inputs, **kwargs):
"""
Initializes a TFWav2Vec2Model instance.
Args:
config (Wav2Vec2Config): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the PreTrainedModel.from_pretrained method to load the model weights.
*inputs: Additional positional arguments to be passed.
**kwargs: Additional keyword arguments to be passed.
"""
super().__init__(config, *inputs, **kwargs)
self.config = config
self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2")
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC)
@unpack_inputs
def call(
self,
input_values: tf.Tensor,
attention_mask: tf.Tensor | None = None,
token_type_ids: tf.Tensor | None = None,
position_ids: tf.Tensor | None = None,
head_mask: tf.Tensor | None = None,
inputs_embeds: tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
"""
Returns:
Example:
```
>>> from transformers import AutoProcessor, TFWav2Vec2Model
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
>>> model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = processor(ds["speech"][0], return_tensors="tf").input_values # Batch size 1
>>> hidden_states = model(input_values).last_hidden_state
```"""
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        return_dict = return_dict if return_dict is not None else self.config.return_dict
outputs = self.wav2vec2(
input_values=input_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "wav2vec2", None) is not None:
with tf.name_scope(self.wav2vec2.name):
self.wav2vec2.build(None)
@add_start_docstrings(
"""TFWav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
WAV_2_VEC_2_START_DOCSTRING,
)
class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
def __init__(self, config: Wav2Vec2Config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2")
self.dropout = keras.layers.Dropout(config.final_dropout)
self.lm_head = keras.layers.Dense(config.vocab_size, name="lm_head")
self.output_hidden_size = (
config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
)
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
not be updated during training.
"""
warnings.warn(
"The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
"Please use the equivalent `freeze_feature_encoder` method instead.",
FutureWarning,
)
self.freeze_feature_encoder()
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
not be updated during training.
"""
self.wav2vec2.feature_extractor.trainable = False
@unpack_inputs
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
input_values: tf.Tensor,
attention_mask: tf.Tensor | None = None,
token_type_ids: tf.Tensor | None = None,
position_ids: tf.Tensor | None = None,
head_mask: tf.Tensor | None = None,
inputs_embeds: tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
labels: tf.Tensor | None = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
):
"""
Call function to perform forward pass of the model. This function integrates with the `transformers` library's
`add_start_docstrings_to_model_forward` decorator to provide structured documentation for inputs and outputs.
"""
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "wav2vec2", None) is not None:
with tf.name_scope(self.wav2vec2.name):
self.wav2vec2.build(None)
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build([None, None, self.output_hidden_size])
class TFWav2Vec2ForSequenceClassification(TFWav2Vec2PreTrainedModel):
    def __init__(self, config):
super().__init__(config)
self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2")
self.num_layers = config.num_hidden_layers + 1
with tf.name_scope(self._name_scope()):
if config.use_weighted_layer_sum:
self.layer_weights = self.add_weight(
shape=(self.num_layers,), initializer="ones", trainable=True, name="layer_weights"
)
self.config = config
self.projector = keras.layers.Dense(units=config.classifier_proj_size, name="projector")
self.classifier = keras.layers.Dense(units=config.num_labels, activation=None, name="classifier")
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
not be updated during training.
"""
warnings.warn(
"The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
"Please use the equivalent `freeze_feature_encoder` method instead.",
FutureWarning,
)
self.freeze_feature_encoder()
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
not be updated during training.
"""
self.wav2vec2.feature_extractor.trainable = False
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for layer in self.wav2vec2.layers:
layer.trainable = False
@unpack_inputs
def call(
self,
input_values: tf.Tensor,
attention_mask: tf.Tensor | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
labels: tf.Tensor | None = None,
training: bool = False,
) -> TFSequenceClassifierOutput | Tuple[tf.Tensor]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
if self.config.use_weighted_layer_sum:
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
hidden_states = tf.stack(hidden_states, axis=1)
norm_weights = tf.nn.softmax(self.layer_weights, axis=-1)
hidden_states = tf.reduce_sum(hidden_states * tf.reshape(norm_weights, [-1, 1, 1]), axis=1)
else:
hidden_states = outputs[0]
hidden_states = self.projector(hidden_states)
if attention_mask is None:
pooled_output = tf.reduce_mean(hidden_states, axis=1)
else:
padding_mask = self._get_feature_vector_attention_mask(shape_list(hidden_states)[1], attention_mask)
padding_mask_float = tf.cast(padding_mask, hidden_states.dtype)
hidden_states = tf.multiply(hidden_states, tf.expand_dims(padding_mask_float, axis=-1))
pooled_output = tf.divide(
tf.reduce_sum(hidden_states, axis=1), tf.expand_dims(tf.reduce_sum(padding_mask_float, axis=1), axis=1)
)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss = loss_fn(tf.reshape(labels, [-1]), tf.reshape(logits, [-1, self.config.num_labels]))
if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return ((loss,) + output) if loss is not None else output
return TFSequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
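The pooling branch above is a padding-aware mean: padded frames are zeroed before summation, and the divisor counts only real frames, so padding never biases the pooled representation. A standalone sketch of the same arithmetic:

```python
import tensorflow as tf

hidden = tf.constant([[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [100.0, 100.0]]])
mask = tf.constant([[1.0, 1.0, 1.0, 0.0]])  # the 100s are padding frames
masked = hidden * mask[:, :, None]
pooled = tf.reduce_sum(masked, axis=1) / tf.reduce_sum(mask, axis=1, keepdims=True)
print(pooled)  # [[2. 2.]] -- the padded frames do not affect the mean
```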