Transformers 源码解析(九十三)
.\models\reformer\tokenization_reformer_fast.py
import os
from shutil import copyfile
from typing import Optional, Tuple
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import is_sentencepiece_available, logging
# The slow tokenizer is only importable when `sentencepiece` is installed; otherwise `ReformerTokenizer` is None
# and is still assigned below as `slow_tokenizer_class`.
if is_sentencepiece_available():
    from .tokenization_reformer import ReformerTokenizer
else:
    ReformerTokenizer = None

logger = logging.get_logger(__name__)

# The U+2581 ("▁") character; not referenced elsewhere in this module as shown.
SPIECE_UNDERLINE = "▁"

# File names under which the vocabulary files of this tokenizer are saved / looked up.
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}

# Hub download URLs of the vocabulary files, per known pretrained checkpoint.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "google/reformer-crime-and-punishment": (
            "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/spiece.model"
        )
    },
    "tokenizer_file": {
        "google/reformer-crime-and-punishment": (
            "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/tokenizer.json"
        )
    },
}

# Maximum model input size per known checkpoint (used as `max_model_input_sizes` by the tokenizer class).
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "google/reformer-crime-and-punishment": 524288,
}
class ReformerTokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a "fast" Reformer tokenizer (backed by HuggingFace's *tokenizers* library). Based on a Unigram model.

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`, *optional*):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer. Required to save a slow tokenizer.
        tokenizer_file (`str`, *optional*):
            Path to a serialized *tokenizers* file (`tokenizer.json`).
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*):
            The token used for padding, for example when batching sequences of different lengths. Not a named
            parameter here; it is forwarded to the base class through `**kwargs`.
        additional_special_tokens (`List[str]`, *optional*):
            Additional special tokens used by the tokenizer. Defaults to an empty list.
    """

    # File names used when saving / loading the vocabulary.
    vocab_files_names = VOCAB_FILES_NAMES
    # Download URLs of the vocabulary files of known pretrained checkpoints.
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    # Maximum input size per known pretrained checkpoint.
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    # Names of the model inputs produced by this tokenizer.
    model_input_names = ["input_ids", "attention_mask"]
    # Slow counterpart (None when `sentencepiece` is not installed).
    slow_tokenizer_class = ReformerTokenizer

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        eos_token="</s>",
        unk_token="<unk>",
        additional_special_tokens=None,
        **kwargs,
    ):
        # Build a fresh list per instance: a `[]` default in the signature would be one list object shared by
        # every tokenizer created with the default.
        if additional_special_tokens is None:
            additional_special_tokens = []
        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            eos_token=eos_token,
            unk_token=unk_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        # Kept so `save_vocabulary` can copy the SentencePiece model file later.
        self.vocab_file = vocab_file

    @property
    def can_save_slow_tokenizer(self) -> bool:
        """Whether a slow tokenizer can be saved, i.e. `vocab_file` is set and points to an existing file."""
        return os.path.isfile(self.vocab_file) if self.vocab_file else False

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Optional[Tuple[str]]:
        """
        Copy the SentencePiece vocabulary file into `save_directory`.

        Args:
            save_directory (`str`):
                Existing directory in which to write the file.
            filename_prefix (`str`, *optional*):
                Prefix added (followed by `-`) to the file name `spiece.model`.

        Returns:
            `Tuple[str]` with the path of the written file, or `None` (after logging an error) when
            `save_directory` is not a directory.

        Raises:
            ValueError: if this tokenizer has no usable `vocab_file` (see `can_save_slow_tokenizer`).
        """
        if not self.can_save_slow_tokenizer:
            raise ValueError(
                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
                "tokenizer."
            )

        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        # Skip the copy when saving in place (source and destination are the same file).
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)
.\models\reformer\__init__.py
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_sentencepiece_available,
is_tokenizers_available,
is_torch_available,
)
# Public names exported by each submodule. The configuration is always exported; tokenizer and modeling entries
# are added below only when their optional backend is installed.
_import_structure = {"configuration_reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"]}

# Slow tokenizer: exported only when `sentencepiece` is available.
try:
    if not is_sentencepiece_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tokenization_reformer"] = ["ReformerTokenizer"]

# Fast tokenizer: exported only when `tokenizers` is available.
try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tokenization_reformer_fast"] = ["ReformerTokenizerFast"]

# PyTorch modeling classes: exported only when `torch` is available.
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_reformer"] = [
        "REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
        "ReformerAttention",
        "ReformerForMaskedLM",
        "ReformerForQuestionAnswering",
        "ReformerForSequenceClassification",
        "ReformerLayer",
        "ReformerModel",
        "ReformerModelWithLMHead",
        "ReformerPreTrainedModel",
    ]


if TYPE_CHECKING:
    # Eager imports seen only by static type checkers (TYPE_CHECKING is False at runtime). Keep them in sync
    # with `_import_structure` above.
    from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig

    try:
        if not is_sentencepiece_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_reformer import ReformerTokenizer

    try:
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_reformer_fast import ReformerTokenizerFast

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_reformer import (
            REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
            ReformerAttention,
            ReformerForMaskedLM,
            ReformerForQuestionAnswering,
            ReformerForSequenceClassification,
            ReformerLayer,
            ReformerModel,
            ReformerModelWithLMHead,
            ReformerPreTrainedModel,
        )

else:
    import sys

    # At runtime, replace this module in `sys.modules` with a `_LazyModule` built from `_import_structure`
    # (presumably importing each submodule on first attribute access -- see `_LazyModule`).
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\regnet\configuration_regnet.py
from ...configuration_utils import PretrainedConfig
from ...utils import logging
# Module-level logger for the RegNet configuration module.
logger = logging.get_logger(__name__)
# Checkpoint name -> URL of its `config.json` on the Hub. `resolve/main` serves the raw file (as the other archive
# maps do); a `blob/main` URL would return the Hub's HTML viewer page instead of JSON.
REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "facebook/regnet-y-040": "https://huggingface.co/facebook/regnet-y-040/resolve/main/config.json",
}
class RegNetConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`RegNetModel`]. It is used to instantiate a RegNet
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the RegNet
    [facebook/regnet-y-040](https://huggingface.co/facebook/regnet-y-040) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        embedding_size (`int`, *optional*, defaults to 32):
            Dimensionality (hidden size) for the embedding layer.
        hidden_sizes (`List[int]`, *optional*, defaults to `[128, 192, 512, 1088]`):
            Dimensionality (hidden size) at each stage.
        depths (`List[int]`, *optional*, defaults to `[2, 6, 12, 2]`):
            Depth (number of layers) for each stage.
        groups_width (`int`, *optional*, defaults to 64):
            Group width of the grouped convolutions; the number of groups of a layer is derived from its number of
            channels divided by this value.
        layer_type (`str`, *optional*, defaults to `"y"`):
            The layer to use, it can be either `"x"` or `"y"`. An `x` layer is a ResNet's BottleNeck layer with
            `reduction` fixed to `1`. While a `y` layer is a `x` but with squeeze and excitation. Please refer to the
            paper for a detailed explanation of how these layers were constructed.
        hidden_act (`str`, *optional*, defaults to `"relu"`):
            The non-linear activation function in each block. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"`
            are supported.
        kwargs:
            Forwarded to [`PretrainedConfig`].

    Attributes:
        downsample_in_first_stage (`bool`):
            Always `True`: the first stage always downsamples its input. It is not a constructor argument; a value
            passed through `kwargs` is overwritten.

    Raises:
        ValueError: if `layer_type` is not one of `"x"` or `"y"`.

    Example:

    ```python
    >>> from transformers import RegNetConfig, RegNetModel

    >>> # Initializing a RegNet regnet-y-40 style configuration
    >>> configuration = RegNetConfig()
    >>> # Initializing a model from the regnet-y-40 style configuration
    >>> model = RegNetModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "regnet"
    # Supported values of `layer_type`.
    layer_types = ["x", "y"]

    def __init__(
        self,
        num_channels=3,
        embedding_size=32,
        hidden_sizes=[128, 192, 512, 1088],
        depths=[2, 6, 12, 2],
        groups_width=64,
        layer_type="y",
        hidden_act="relu",
        **kwargs,
    ):
        super().__init__(**kwargs)
        if layer_type not in self.layer_types:
            raise ValueError(f"layer_type={layer_type} is not one of {','.join(self.layer_types)}")
        self.num_channels = num_channels
        self.embedding_size = embedding_size
        # Copy the list arguments: the signature defaults are single list objects shared by every call, so storing
        # them directly would let a mutation of one config leak into all later default configs.
        self.hidden_sizes = list(hidden_sizes)
        self.depths = list(depths)
        self.groups_width = groups_width
        self.layer_type = layer_type
        self.hidden_act = hidden_act
        # The first stage always downsamples.
        self.downsample_in_first_stage = True
.\models\regnet\convert_regnet_seer_10b_to_pytorch.py
"""转换 RegNet 10B 检查点为 vissl 格式。"""
import argparse
import json
import os
import re
from collections import OrderedDict
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from pprint import pprint
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from classy_vision.models.regnet import RegNet, RegNetParams
from huggingface_hub import cached_download, hf_hub_url
from torch import Tensor
from vissl.models.model_helpers import get_trunk_forward_outputs
from transformers import AutoImageProcessor, RegNetConfig, RegNetForImageClassification, RegNetModel
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
# Emit INFO-level messages from the transformers logger so the conversion progress is visible.
logging.set_verbosity_info()
logger = logging.get_logger()
@dataclass
class Tracker:
    """
    Record the leaf modules executed during one forward pass of `module`.

    A module counts as a leaf when it has no submodules, or when it is an `nn.Conv2d` / `nn.BatchNorm2d`. Leaves are
    appended to `traced` in execution order and stored in `name2module` under their `named_modules()` name.
    """

    module: nn.Module
    traced: List[nn.Module] = field(default_factory=list)
    handles: list = field(default_factory=list)
    name2module: Dict[str, nn.Module] = field(default_factory=OrderedDict)

    def _forward_hook(self, m, inputs: Tensor, outputs: Tensor, name: str):
        """Forward hook: record `m` under `name` if it is a leaf module."""
        has_not_submodules = len(list(m.modules())) == 1 or isinstance(m, (nn.Conv2d, nn.BatchNorm2d))
        if has_not_submodules:
            self.traced.append(m)
            self.name2module[name] = m

    def __call__(self, x: Tensor):
        """
        Run `self.module(x)` with a forward hook on every submodule, then detach all hooks.

        Returns:
            `self`, so results can be read fluently, e.g. `Tracker(model)(x).parametrized`.
        """
        for name, m in self.module.named_modules():
            self.handles.append(m.register_forward_hook(partial(self._forward_hook, name=name)))
        try:
            self.module(x)
        finally:
            # Detach the hooks even if the forward pass raises, so the traced module is left unmodified, and drop
            # the spent handles so a second call does not try to remove them again.
            for handle in self.handles:
                handle.remove()
            self.handles.clear()
        return self

    @property
    def parametrized(self):
        """Recorded modules (name -> module) whose `state_dict` is non-empty, i.e. that hold parameters/buffers."""
        return {k: v for k, v in self.name2module.items() if len(v.state_dict()) > 0}
class FakeRegNetVisslWrapper(nn.Module):
    """
    Present a Classy Vision RegNet the way vissl's trunk sees it, without needing a vissl config file.

    The stem is registered as `conv1`, followed by every `block*` child of `model.trunk_output` as `res2`, `res3`, ...
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        stages: List[Tuple[str, nn.Module]] = [("conv1", model.stem)]
        # Residual stages are numbered from 2 because `conv1` takes the first slot.
        for stage_number, (child_name, child) in enumerate(model.trunk_output.named_children(), start=2):
            assert child_name.startswith("block"), f"Unexpected layer name {child_name}"
            stages.append((f"res{stage_number}", child))
        self._feature_blocks = nn.ModuleDict(stages)

    def forward(self, x: Tensor):
        # Delegate to vissl's trunk runner over the registered stages.
        return get_trunk_forward_outputs(x, out_feat_keys=None, feature_blocks=self._feature_blocks)
class FakeRegNetParams(RegNetParams):
    """
    Classy Vision RegNet parameters with the stage layout of the 10B model but tiny widths, so that the model can be
    built and traced in memory.
    """

    def get_expanded_params(self):
        # One tuple per stage; only the third entry varies and it matches the 10B stage depths [2, 7, 17, 1].
        stage_depths = (2, 7, 17, 1)
        return [(8, 2, depth, 8, 1.0) for depth in stage_depths]
def get_from_to_our_keys(model_name: str) -> Dict[str, str]:
    """
    Build the mapping from the original model's state-dict keys to our implementation's keys.

    A tiny Classy Vision RegNet (10B depths, see `FakeRegNetParams`) and a tiny `RegNetConfig` model are traced on the
    same random input. Their parametrized leaf modules are paired in execution order, entry by entry.

    Args:
        model_name (`str`): Checkpoint name; when it contains `"in1k"` the classification head keys are mapped too.

    Returns:
        `Dict[str, str]`: original key -> our key.

    Raises:
        ValueError: if the two traced models expose a different number of state-dict entries. The pairing is
            positional, so a mismatch would otherwise silently assign weights to the wrong parameters.
    """
    our_config = RegNetConfig(depths=[2, 7, 17, 1], hidden_sizes=[8, 8, 8, 8], groups_width=8)
    if "in1k" in model_name:
        our_model = RegNetForImageClassification(our_config)
    else:
        our_model = RegNetModel(our_config)

    from_model = FakeRegNetVisslWrapper(
        RegNet(FakeRegNetParams(depth=27, group_width=1010, w_0=1744, w_a=620.83, w_m=2.52))
    )

    with torch.no_grad():
        from_model = from_model.eval()
        our_model = our_model.eval()

        x = torch.randn((1, 3, 32, 32))
        # Trace both models on the same input.
        dest_tracker = Tracker(our_model)
        dest_traced = dest_tracker(x).parametrized

        pprint(dest_tracker.name2module)
        src_tracker = Tracker(from_model)
        src_traced = src_tracker(x).parametrized

    def to_params_dict(dict_with_modules):
        """Flatten `{module_name: module}` into `{f"{module_name}.{param_name}": tensor}` in insertion order."""
        params_dict = OrderedDict()
        for name, module in dict_with_modules.items():
            for param_name, param in module.state_dict().items():
                params_dict[f"{name}.{param_name}"] = param
        return params_dict

    src_state_dict = to_params_dict(src_traced)
    dst_state_dict = to_params_dict(dest_traced)
    if len(src_state_dict) != len(dst_state_dict):
        raise ValueError(
            f"Cannot pair state dicts: the original model has {len(src_state_dict)} traced entries while ours has"
            f" {len(dst_state_dict)}."
        )

    from_to_ours_keys = {}
    for src_key, dest_key in zip(src_state_dict, dst_state_dict):
        from_to_ours_keys[src_key] = dest_key
        logger.info(f"{src_key} -> {dest_key}")

    # The vissl classification head is not part of the traced trunk, so map it explicitly.
    if "in1k" in model_name:
        from_to_ours_keys["0.clf.0.weight"] = "classifier.1.weight"
        from_to_ours_keys["0.clf.0.bias"] = "classifier.1.bias"

    return from_to_ours_keys
def convert_weights_and_push(save_directory: Path, model_name: str = None, push_to_hub: bool = True):
    """
    Convert a SEER RegNet-10B checkpoint from vissl / Classy Vision and optionally push it to the Hub.

    The converted weights are cached as `save_directory / f"{model_name}.pth"`. When that file already exists, the
    download and conversion step is skipped.

    Args:
        save_directory (`Path`):
            Directory for the downloaded checkpoint, the converted `.pth` file and the local repo used for pushing.
        model_name (`str`, *optional*):
            `"regnet-y-10b-seer"` or `"regnet-y-10b-seer-in1k"`. When `None`, every supported model is converted.
        push_to_hub (`bool`, *optional*, defaults to `True`):
            Whether to load the converted weights into our model and push it, together with an image processor.

    Raises:
        ValueError: if `model_name` is not supported, or if some original keys have no counterpart in our model.
    """
    filename = "imagenet-1k-id2label.json"
    num_labels = 1000
    repo_id = "huggingface/label-files"
    label_path = cached_download(hf_hub_url(repo_id, filename, repo_type="dataset"))
    # Close the label file deterministically instead of leaking the handle.
    with open(label_path, "r") as label_file:
        id2label = {int(k): v for k, v in json.load(label_file).items()}
    label2id = {v: k for k, v in id2label.items()}

    ImageNetPreTrainedConfig = partial(RegNetConfig, num_labels=num_labels, id2label=id2label, label2id=label2id)

    names_to_config = {
        "regnet-y-10b-seer": ImageNetPreTrainedConfig(
            depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010
        ),
        # fine-tuned on ImageNet-1k
        "regnet-y-10b-seer-in1k": ImageNetPreTrainedConfig(
            depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010
        ),
    }

    if model_name is None:
        # As the CLI help promises: without a name, convert every supported model.
        for name in names_to_config:
            convert_weights_and_push(save_directory, name, push_to_hub)
        return
    if model_name not in names_to_config:
        raise ValueError(f"Unsupported model_name={model_name!r}; expected one of: {', '.join(names_to_config)}.")

    def load_using_classy_vision(checkpoint_url: str) -> Tuple[Dict, Dict]:
        """Download a Classy Vision checkpoint and return its `(trunk, heads)` state dicts."""
        files = torch.hub.load_state_dict_from_url(checkpoint_url, model_dir=str(save_directory), map_location="cpu")
        model_state_dict = files["classy_state_dict"]["base_model"]["model"]
        return model_state_dict["trunk"], model_state_dict["heads"]

    names_to_from_model = {
        "regnet-y-10b-seer": partial(
            load_using_classy_vision,
            "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet10B/model_iteration124500_conso.torch",
        ),
        "regnet-y-10b-seer-in1k": partial(
            load_using_classy_vision,
            "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_10b_finetuned_in1k_model_phase28_conso.torch",
        ),
    }

    from_to_ours_keys = get_from_to_our_keys(model_name)
    converted_path = save_directory / f"{model_name}.pth"

    if not converted_path.exists():
        logger.info("Loading original state_dict.")
        from_state_dict_trunk, from_state_dict_head = names_to_from_model[model_name]()
        from_state_dict = from_state_dict_trunk
        if "in1k" in model_name:
            # The fine-tuned checkpoint also carries the classification head.
            from_state_dict = {**from_state_dict_trunk, **from_state_dict_head}
        logger.info("Done!")

        converted_state_dict = {}
        not_used_keys = []
        # Remove a `.block?-part?` segment (each `?` is one arbitrary character) so checkpoint keys match the keys
        # produced by tracing the wrapped model.
        regex = r"\.block.-part."
        for key, value in from_state_dict.items():
            dest_key = from_to_ours_keys.get(re.sub(regex, "", key))
            if dest_key is None:
                not_used_keys.append(key)
            else:
                converted_state_dict[dest_key] = value
        if not_used_keys:
            raise ValueError(f"Some keys were not used: {','.join(not_used_keys)}")

        torch.save(converted_state_dict, converted_path)
        # Free the converted weights before our (large) model is instantiated below.
        del converted_state_dict
    else:
        logger.info("The state_dict was already stored on disk.")

    if push_to_hub:
        # Report only whether a token is configured -- never write the secret itself to the logs.
        token_status = "set" if os.environ.get("HF_TOKEN") else "not set"
        logger.info(f"HF_TOKEN is {token_status}.")
        logger.info("Loading our model.")
        our_config = names_to_config[model_name]
        our_model_func = RegNetForImageClassification if "in1k" in model_name else RegNetModel
        our_model = our_model_func(our_config)
        # Move the randomly initialized parameters to the meta device; the real weights are loaded next.
        our_model.to(torch.device("meta"))
        logger.info("Loading state_dict in our model.")
        state_dict_keys = our_model.state_dict().keys()
        PreTrainedModel._load_pretrained_model_low_mem(our_model, state_dict_keys, [converted_path])
        logger.info("Finally, pushing!")
        our_model.push_to_hub(
            repo_path_or_name=save_directory / model_name,
            commit_message="Add model",
            output_dir=save_directory / model_name,
        )
        size = 384
        logger.info("we can use the convnext one")
        image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224-22k-1k", size=size)
        image_processor.push_to_hub(
            repo_path_or_name=save_directory / model_name,
            commit_message="Add image processor",
            output_dir=save_directory / model_name,
        )
if __name__ == "__main__":

    def _str_to_bool(value: str) -> bool:
        """Parse a command-line boolean; `type=bool` would turn any non-empty string, even "False", into True."""
        normalized = value.strip().lower()
        if normalized in ("1", "true", "t", "yes", "y"):
            return True
        if normalized in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        default=None,
        type=str,
        help=(
            "The name of the model you wish to convert, it must be one of the supported regnet* architecture,"
            " currently: regnet-y-10b-seer, regnet-y-10b-seer-in1k. If `None`, all of them will be converted."
        ),
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=Path,
        required=True,
        help="Path to the output PyTorch model directory.",
    )
    parser.add_argument(
        "--push_to_hub",
        default=True,
        type=_str_to_bool,
        required=False,
        help="If True, push model and image processor to the hub.",
    )

    args = parser.parse_args()

    pytorch_dump_folder_path: Path = args.pytorch_dump_folder_path
    pytorch_dump_folder_path.mkdir(exist_ok=True, parents=True)
    convert_weights_and_push(pytorch_dump_folder_path, args.model_name, args.push_to_hub)
.\models\regnet\convert_regnet_to_pytorch.py
"""Convert RegNet checkpoints from timm and vissl."""
import argparse
import json
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Callable, Dict, List, Tuple
import timm
import torch
import torch.nn as nn
from classy_vision.models.regnet import RegNet, RegNetParams, RegNetY32gf, RegNetY64gf, RegNetY128gf
from huggingface_hub import cached_download, hf_hub_url
from torch import Tensor
from vissl.models.model_helpers import get_trunk_forward_outputs
from transformers import AutoImageProcessor, RegNetConfig, RegNetForImageClassification, RegNetModel
from transformers.utils import logging
# Emit INFO-level messages from the transformers logger so the conversion progress is visible.
logging.set_verbosity_info()
logger = logging.get_logger()
@dataclass
class Tracker:
    """
    Record, in execution order, the leaf modules run during one forward pass of `module`.

    A module counts as a leaf when it has no submodules, or when it is an `nn.Conv2d` / `nn.BatchNorm2d`.
    """

    module: nn.Module
    traced: List[nn.Module] = field(default_factory=list)
    handles: list = field(default_factory=list)

    def _forward_hook(self, m, inputs: Tensor, outputs: Tensor):
        """Forward hook: append `m` to `traced` if it is a leaf module."""
        has_not_submodules = len(list(m.modules())) == 1 or isinstance(m, (nn.Conv2d, nn.BatchNorm2d))
        if has_not_submodules:
            self.traced.append(m)

    def __call__(self, x: Tensor):
        """
        Run `self.module(x)` with a forward hook on every submodule, then detach all hooks.

        Returns:
            `self`, so results can be read fluently, e.g. `Tracker(model)(x).parametrized`.
        """
        for m in self.module.modules():
            self.handles.append(m.register_forward_hook(self._forward_hook))
        try:
            self.module(x)
        finally:
            # Detach the hooks even if the forward pass raises, so the traced module is left unmodified, and drop
            # the spent handles so a second call does not try to remove them again.
            for handle in self.handles:
                handle.remove()
            self.handles.clear()
        return self

    @property
    def parametrized(self):
        """Traced modules whose `state_dict` is non-empty, i.e. that hold parameters or buffers."""
        return [m for m in self.traced if len(m.state_dict()) > 0]
@dataclass
class ModuleTransfer:
    """
    Copy weights from `src` into `dest` by pairing their parametrized leaf modules in forward-execution order.

    Modules whose type appears in `src_skip` / `dest_skip` are left out of the pairing. When the remaining counts
    differ, an exception is raised if `raise_if_mismatch` is set; otherwise only the common prefix is copied.
    """

    src: nn.Module
    dest: nn.Module
    verbose: int = 1
    src_skip: List = field(default_factory=list)
    dest_skip: List = field(default_factory=list)
    raise_if_mismatch: bool = True

    def __call__(self, x: Tensor):
        """
        Transfer the weights of `self.src` to `self.dest` by performing a forward pass using `x` as input. Under the
        hood we tracked all the operations in both modules.
        """
        dest_traced = Tracker(self.dest)(x).parametrized
        src_traced = Tracker(self.src)(x).parametrized

        src_kept = [module for module in src_traced if type(module) not in self.src_skip]
        dest_kept = [module for module in dest_traced if type(module) not in self.dest_skip]

        counts_differ = len(dest_kept) != len(src_kept)
        if counts_differ and self.raise_if_mismatch:
            raise Exception(
                f"Numbers of operations are different. Source module has {len(src_kept)} operations while"
                f" destination module has {len(dest_kept)}."
            )

        for dest_module, src_module in zip(dest_kept, src_kept):
            dest_module.load_state_dict(src_module.state_dict())
            if self.verbose == 1:
                print(f"Transfered from={src_module} to={dest_module}")
class FakeRegNetVisslWrapper(nn.Module):
    """
    Stand-in for vissl's RegNet trunk wrapper, built directly from a Classy Vision model instead of a config file.

    Registered stages: the stem as `conv1`, then each `block*` child of `model.trunk_output` as `res2`, `res3`, ...
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        stages: List[Tuple[str, nn.Module]] = [("conv1", model.stem)]
        # Numbering starts at 2 because `conv1` already occupies the first position.
        for stage_number, (child_name, child) in enumerate(model.trunk_output.named_children(), start=2):
            assert child_name.startswith("block"), f"Unexpected layer name {child_name}"
            stages.append((f"res{stage_number}", child))
        self._feature_blocks = nn.ModuleDict(stages)

    def forward(self, x: Tensor):
        # Let vissl run the registered stages.
        return get_trunk_forward_outputs(x, out_feat_keys=None, feature_blocks=self._feature_blocks)
class NameToFromModelFuncMap(dict):
    """
    Dict of factories creating the original model, with a timm fallback for unregistered names.

    Names stored in the dict return their stored factory. Any other name is converted to timm's naming scheme and
    mapped to a factory returning `(pretrained timm model in eval mode, None)`.
    """

    def convert_name_to_timm(self, x: str) -> str:
        """Convert a dashed name such as `"regnet-x-002"` to timm's `"regnetx_002"`."""
        x_split = x.split("-")
        return x_split[0] + x_split[1] + "_" + "".join(x_split[2:])

    def __getitem__(self, x: str) -> Callable[[], Tuple[nn.Module, Dict]]:
        if x in self:
            return super().__getitem__(x)

        timm_name = self.convert_name_to_timm(x)

        # A plain closure: the original wrapped a zero-argument lambda in a `partial` that bound nothing.
        def create_timm_model() -> Tuple[nn.Module, None]:
            # timm models come without a separate head state dict, hence the `None`.
            return timm.create_model(timm_name, pretrained=True).eval(), None

        return create_timm_model
class NameToOurModelFuncMap(dict):
    """
    Mapping-style helper that picks which of our RegNet classes to build for a checkpoint name.

    Names containing `"seer"` but not `"in1k"` map to `RegNetModel`; every other name maps to
    `RegNetForImageClassification`.
    """

    def __getitem__(self, x: str) -> Callable[[], nn.Module]:
        wants_bare_model = "seer" in x and "in1k" not in x
        return RegNetModel if wants_bare_model else RegNetForImageClassification
def manually_copy_vissl_head(from_state_dict, to_state_dict, keys: List[Tuple[str, str]]):
    """
    Copy selected tensors from `from_state_dict` into `to_state_dict` under new names.

    Args:
        from_state_dict: Source mapping of tensors.
        to_state_dict: Destination mapping; updated in place and also returned.
        keys: `(source_key, destination_key)` pairs to copy. Each tensor is cloned.

    Returns:
        `to_state_dict`.
    """
    for pair in keys:
        source_key, target_key = pair
        to_state_dict[target_key] = from_state_dict[source_key].clone()
        print(f"Copied key={source_key} to={target_key}")
    return to_state_dict
def convert_weight_and_push(
    name: str,
    from_model_func: Callable[[], nn.Module],
    our_model_func: Callable[[], nn.Module],
    config: RegNetConfig,
    save_directory: Path,
    push_to_hub: bool = True,
):
    """
    Transfer the weights of one original model into ours, check the outputs match, and optionally push.

    Args:
        name: Checkpoint name; `"seer"`/`"in1k"` in it select the head handling and the image size.
        from_model_func: Factory returning `(original_model, head_state_dict_or_None)`.
        our_model_func: Our model class (called with `config`).
        config: Configuration for our model.
        save_directory: Base directory; the Hub repo path is `save_directory / name`.
        push_to_hub: Whether to push the model and an image processor.

    Raises:
        AssertionError: if the outputs of the two models differ (`torch.allclose`).
    """
    print(f"Converting {name}...")
    is_finetuned_seer = "seer" in name and "in1k" in name

    with torch.no_grad():
        from_model, from_state_dict = from_model_func()
        our_model = our_model_func(config).eval()
        x = torch.randn((1, 3, 224, 224))
        ModuleTransfer(src=from_model, dest=our_model, raise_if_mismatch=False)(x)

    if from_state_dict is not None:
        # The wrapped vissl model only contains the trunk; the head weights come back separately, so copy them by
        # key (only for the ImageNet fine-tuned SEER checkpoints).
        head_keys = (
            [("0.clf.0.weight", "classifier.1.weight"), ("0.clf.0.bias", "classifier.1.bias")]
            if is_finetuned_seer
            else []
        )
        our_model.load_state_dict(manually_copy_vissl_head(from_state_dict, our_model.state_dict(), head_keys))

    our_outputs = our_model(x, output_hidden_states=True)
    if is_finetuned_seer:
        # The original model has no head here, so compare against our last hidden state instead of the logits.
        our_output = our_outputs.hidden_states[-1]
    elif isinstance(our_model, RegNetForImageClassification):
        our_output = our_outputs.logits
    else:
        our_output = our_outputs.last_hidden_state

    from_output = from_model(x)
    if isinstance(from_output, list):
        from_output = from_output[-1]

    assert torch.allclose(from_output, our_output), "The model logits don't match the original one."

    if push_to_hub:
        repo_path = save_directory / name
        our_model.push_to_hub(
            repo_path_or_name=repo_path,
            commit_message="Add model",
            use_temp_dir=True,
        )

        # Reuse the ConvNeXt image processor; SEER checkpoints use a larger input size.
        size = 384 if "seer" in name else 224
        image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224-22k-1k", size=size)
        image_processor.push_to_hub(
            repo_path_or_name=repo_path,
            commit_message="Add image processor",
            use_temp_dir=True,
        )

        print(f"Pushed {name}")
def convert_weights_and_push(save_directory: Path, model_name: str = None, push_to_hub: bool = True):
    """
    Convert one RegNet checkpoint (or all of them) from timm / vissl and optionally push it to the Hub.

    Args:
        save_directory (`Path`):
            Directory used for downloaded vissl checkpoints and as the base of the Hub repo paths.
        model_name (`str`, *optional*):
            Name of the model to convert. When falsy, every entry of `names_to_config` is converted.
        push_to_hub (`bool`, *optional*, defaults to `True`):
            Whether to push each converted model and its image processor.

    Returns:
        `(config, expected_shape)`: the config of the converted model (the last one when converting all, `None` if
        nothing was converted) and `(1, num_labels)`.
    """
    filename = "imagenet-1k-id2label.json"
    num_labels = 1000
    expected_shape = (1, num_labels)

    repo_id = "huggingface/label-files"
    label_path = cached_download(hf_hub_url(repo_id, filename, repo_type="dataset"))
    # Close the label file deterministically instead of leaking the handle.
    with open(label_path, "r") as label_file:
        id2label = {int(k): v for k, v in json.load(label_file).items()}
    label2id = {v: k for k, v in id2label.items()}

    ImageNetPreTrainedConfig = partial(RegNetConfig, num_labels=num_labels, id2label=id2label, label2id=label2id)

    # NOTE(review): `names_to_config` (checkpoint name -> `ImageNetPreTrainedConfig(...)`) is used below but is not
    # defined anywhere in this file as shown. The per-model config table appears to have been lost and must be
    # restored (e.g. from the upstream conversion script) before this function can run.

    names_to_ours_model_map = NameToOurModelFuncMap()
    names_to_from_model_map = NameToFromModelFuncMap()

    def load_using_classy_vision(checkpoint_url: str, model_func: Callable[[], nn.Module]) -> Tuple[nn.Module, Dict]:
        """Download a Classy Vision checkpoint, load its trunk into `model_func()` and return it with the heads."""
        files = torch.hub.load_state_dict_from_url(checkpoint_url, model_dir=str(save_directory), map_location="cpu")
        model = model_func()
        model_state_dict = files["classy_state_dict"]["base_model"]["model"]
        state_dict = model_state_dict["trunk"]
        model.load_state_dict(state_dict)
        return model.eval(), model_state_dict["heads"]

    # Pretrained SEER trunks
    names_to_from_model_map["regnet-y-320-seer"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet32d/seer_regnet32gf_model_iteration244000.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY32gf()),
    )
    names_to_from_model_map["regnet-y-640-seer"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet64/seer_regnet64gf_model_final_checkpoint_phase0.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY64gf()),
    )
    names_to_from_model_map["regnet-y-1280-seer"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/swav_ig1b_regnet128Gf_cnstant_bs32_node16_sinkhorn10_proto16k_syncBN64_warmup8k/model_final_checkpoint_phase0.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY128gf()),
    )
    names_to_from_model_map["regnet-y-10b-seer"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet10B/model_iteration124500_conso.torch",
        lambda: FakeRegNetVisslWrapper(
            RegNet(RegNetParams(depth=27, group_width=1010, w_0=1744, w_a=620.83, w_m=2.52))
        ),
    )

    # SEER models fine-tuned on ImageNet-1k
    names_to_from_model_map["regnet-y-320-seer-in1k"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet32_finetuned_in1k_model_final_checkpoint_phase78.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY32gf()),
    )
    names_to_from_model_map["regnet-y-640-seer-in1k"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet64_finetuned_in1k_model_final_checkpoint_phase78.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY64gf()),
    )
    names_to_from_model_map["regnet-y-1280-seer-in1k"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet128_finetuned_in1k_model_final_checkpoint_phase78.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY128gf()),
    )
    names_to_from_model_map["regnet-y-10b-seer-in1k"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_10b_finetuned_in1k_model_phase28_conso.torch",
        lambda: FakeRegNetVisslWrapper(
            RegNet(RegNetParams(depth=27, group_width=1010, w_0=1744, w_a=620.83, w_m=2.52))
        ),
    )

    # Bound up front so the final `return` works on every path; previously it raised UnboundLocalError whenever a
    # single `model_name` was given, because `config` was only assigned by the loop.
    config = None
    if model_name:
        config = names_to_config[model_name]
        convert_weight_and_push(
            model_name,
            names_to_from_model_map[model_name],
            names_to_ours_model_map[model_name],
            config,
            save_directory,
            push_to_hub,
        )
    else:
        for name, config in names_to_config.items():
            convert_weight_and_push(
                name,
                names_to_from_model_map[name],
                names_to_ours_model_map[name],
                config,
                save_directory,
                push_to_hub,
            )
    return config, expected_shape
if __name__ == "__main__":

    def _str_to_bool(value: str) -> bool:
        """Parse a command-line boolean; `type=bool` would turn any non-empty string, even "False", into True."""
        normalized = value.strip().lower()
        if normalized in ("1", "true", "t", "yes", "y"):
            return True
        if normalized in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        default=None,
        type=str,
        help=(
            "The name of the model you wish to convert, it must be one of the supported regnet* architecture,"
            " currently: regnetx-*, regnety-*. If `None`, all of them will the converted."
        ),
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=Path,
        required=True,
        help="Path to the output PyTorch model directory.",
    )
    parser.add_argument(
        "--push_to_hub",
        default=True,
        type=_str_to_bool,
        required=False,
        help="If True, push model and image processor to the hub.",
    )

    args = parser.parse_args()

    pytorch_dump_folder_path: Path = args.pytorch_dump_folder_path
    pytorch_dump_folder_path.mkdir(exist_ok=True, parents=True)
    convert_weights_and_push(pytorch_dump_folder_path, args.model_name, args.push_to_hub)
.\models\regnet\modeling_flax_regnet.py
from functools import partial
from typing import Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers import RegNetConfig
from transformers.modeling_flax_outputs import (
FlaxBaseModelOutputWithNoAttention,
FlaxBaseModelOutputWithPooling,
FlaxBaseModelOutputWithPoolingAndNoAttention,
FlaxImageClassifierOutputWithNoAttention,
)
from transformers.modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
# Class-level documentation text shared by the Flax RegNet models. A stray `"""` used to follow this literal, which
# swallowed the next assignment into a string and turned the inputs docstring below into invalid code.
REGNET_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading, saving and converting weights from PyTorch models)

    This model is also a
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
    behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
"""

# Documentation of the inputs accepted by the models' `__call__`.
REGNET_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`RegNetImageProcessor.__call__`] for details.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Copied from transformers.models.resnet.modeling_flax_resnet.Identity
class Identity(nn.Module):
    """Identity function."""

    @nn.compact
    def __call__(self, x, **kwargs):
        # Return the input unchanged. Extra keyword arguments (e.g. `deterministic`) are accepted and ignored, so this
        # can replace a shortcut module that takes them.
        return x
class FlaxRegNetShortCut(nn.Module):
    """
    RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
    downsample the input using `stride=2`.
    """

    # Flax linen modules are dataclasses: hyper-parameters must be declared as fields. The previous hand-written
    # `__init__` bypassed Flax's generated constructor (no `parent`/`name` handling).
    out_channels: int
    stride: int = 2
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # 1x1 projection without bias (the following batch norm provides the shift).
        self.convolution = nn.Conv(
            self.out_channels,
            kernel_size=(1, 1),
            strides=self.stride,
            use_bias=False,
            kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"),
            dtype=self.dtype,
        )
        self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype)

    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        hidden_state = self.convolution(x)
        # When deterministic, batch norm uses its running statistics instead of batch statistics.
        hidden_state = self.normalization(hidden_state, use_running_average=deterministic)
        return hidden_state


# Backward-compatible alias for the name this module was previously defined under. The rest of the file refers to
# `FlaxRegNetShortCut`.
RegNetShortcut = FlaxRegNetShortCut
class FlaxRegNetSELayerCollection(nn.Module):
    """
    Gate computation of the squeeze-and-excitation block.

    1x1 conv down to `reduced_channels` -> ReLU -> 1x1 conv back to `in_channels` -> sigmoid. The two convolutions
    are named `"0"` and `"2"`, which fixes their parameter paths (presumably mirroring PyTorch `nn.Sequential`
    indices for weight conversion -- keep the names).
    """

    in_channels: int
    reduced_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        kernel_initializer = nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal")
        self.conv_1 = nn.Conv(
            self.reduced_channels,
            kernel_size=(1, 1),
            kernel_init=kernel_initializer,
            dtype=self.dtype,
            name="0",
        )
        self.conv_2 = nn.Conv(
            self.in_channels,
            kernel_size=(1, 1),
            kernel_init=kernel_initializer,
            dtype=self.dtype,
            name="2",
        )

    def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray:
        squeezed = nn.relu(self.conv_1(hidden_state))
        return nn.sigmoid(self.conv_2(squeezed))
class FlaxRegNetSELayer(nn.Module):
    """
    Squeeze and Excitation layer (SE) proposed in [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507).
    """

    in_channels: int
    reduced_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.pooler = partial(nn.avg_pool, padding=((0, 0), (0, 0)))
        self.attention = FlaxRegNetSELayerCollection(self.in_channels, self.reduced_channels, dtype=self.dtype)

    def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray:
        # Pool over the whole extent of axes 1 and 2 (assumed spatial: channels-last input -- confirm), giving one
        # value per channel, then rescale the input by the resulting gates.
        spatial_extent = (hidden_state.shape[1], hidden_state.shape[2])
        squeezed = self.pooler(hidden_state, window_shape=spatial_extent, strides=spatial_extent)
        gates = self.attention(squeezed)
        return hidden_state * gates
class FlaxRegNetXLayerCollection(nn.Module):
    """
    Main (non-residual) branch of a RegNet X layer.

    It applies a 1x1 conv, then a grouped and optionally strided conv (default kernel size of FlaxRegNetConvLayer),
    then a 1x1 conv without activation. The explicit names "0".."2" mirror the indices of the PyTorch `nn.Sequential`.
    """

    config: RegNetConfig
    out_channels: int
    stride: int = 1
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        act = self.config.hidden_act
        groups = max(1, self.out_channels // self.config.groups_width)

        first_conv = FlaxRegNetConvLayer(
            self.out_channels, kernel_size=1, activation=act, dtype=self.dtype, name="0"
        )
        grouped_conv = FlaxRegNetConvLayer(
            self.out_channels, stride=self.stride, groups=groups, activation=act, dtype=self.dtype, name="1"
        )
        last_conv = FlaxRegNetConvLayer(
            self.out_channels, kernel_size=1, activation=None, dtype=self.dtype, name="2"
        )
        self.layer = [first_conv, grouped_conv, last_conv]

    def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        """Run the three convolutions in order, forwarding `deterministic` to each."""
        for conv in self.layer:
            hidden_state = conv(hidden_state, deterministic=deterministic)
        return hidden_state
class FlaxRegNetXLayer(nn.Module):
    """
    RegNet's X layer: a 1x1 conv, a grouped (optionally strided) conv and a 1x1 conv, i.e. a ResNet bottleneck
    layer with reduction = 1. A residual connection is followed by the configured activation.
    """

    config: RegNetConfig
    in_channels: int
    out_channels: int
    stride: int = 1
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The residual must be projected whenever it cannot be added as-is: the channel count changes or
        # the main branch downsamples with a stride.
        should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1
        self.shortcut = (
            FlaxRegNetShortCut(
                self.out_channels,
                stride=self.stride,
                dtype=self.dtype,
            )
            if should_apply_shortcut
            else Identity()
        )
        # FlaxRegNetXLayerCollection declares no `in_channels` field, so passing it would raise a TypeError
        # when the submodule is constructed.
        self.layer = FlaxRegNetXLayerCollection(
            self.config,
            out_channels=self.out_channels,
            stride=self.stride,
            dtype=self.dtype,
        )
        self.activation_func = ACT2FN[self.config.hidden_act]

    def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        """
        Args:
            hidden_state: input feature map.
            deterministic: eval mode when True, training mode when False. It is forwarded to both the main
                branch and the shortcut so their normalization layers agree.
        """
        residual = hidden_state
        # Previously `deterministic` was dropped here, so the main branch always ran in eval mode.
        hidden_state = self.layer(hidden_state, deterministic=deterministic)
        residual = self.shortcut(residual, deterministic=deterministic)
        hidden_state += residual
        hidden_state = self.activation_func(hidden_state)
        return hidden_state
class FlaxRegNetYLayerCollection(nn.Module):
    """
    Main (non-residual) branch of a RegNet Y layer: a 1x1 conv, a grouped (optionally strided) conv, a
    squeeze-and-excitation layer, and a 1x1 conv without activation. The explicit names "0".."3" mirror the
    indices of the PyTorch `nn.Sequential`.
    """

    config: RegNetConfig
    in_channels: int
    out_channels: int
    stride: int = 1
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        groups = max(1, self.out_channels // self.config.groups_width)
        self.layer = [
            FlaxRegNetConvLayer(
                self.out_channels,
                kernel_size=1,
                activation=self.config.hidden_act,
                dtype=self.dtype,
                name="0",
            ),
            FlaxRegNetConvLayer(
                self.out_channels,
                stride=self.stride,
                groups=groups,
                activation=self.config.hidden_act,
                dtype=self.dtype,
                name="1",
            ),
            FlaxRegNetSELayer(
                self.out_channels,
                reduced_channels=int(round(self.in_channels / 4)),
                dtype=self.dtype,
                name="2",
            ),
            FlaxRegNetConvLayer(
                self.out_channels,
                kernel_size=1,
                activation=None,
                dtype=self.dtype,
                name="3",
            ),
        ]

    def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        """
        Args:
            hidden_state: input feature map.
            deterministic: forwarded to the convolution layers (eval mode when True). The SE layer takes no such
                argument. Defaults to True, so existing callers that omit it behave as before.
        """
        first_conv, grouped_conv, squeeze_excitation, last_conv = self.layer
        hidden_state = first_conv(hidden_state, deterministic=deterministic)
        hidden_state = grouped_conv(hidden_state, deterministic=deterministic)
        hidden_state = squeeze_excitation(hidden_state)
        hidden_state = last_conv(hidden_state, deterministic=deterministic)
        return hidden_state


class FlaxRegNetYLayer(nn.Module):
    """
    RegNet's Y layer: an X layer with Squeeze and Excitation, plus a residual connection followed by the
    configured activation.
    """

    config: RegNetConfig
    in_channels: int
    out_channels: int
    stride: int = 1
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Project the residual when the channel count changes or the main branch downsamples.
        should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1
        self.shortcut = (
            FlaxRegNetShortCut(
                self.out_channels,
                stride=self.stride,
                dtype=self.dtype,
            )
            if should_apply_shortcut
            else Identity()
        )
        self.layer = FlaxRegNetYLayerCollection(
            self.config,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            stride=self.stride,
            dtype=self.dtype,
        )
        self.activation_func = ACT2FN[self.config.hidden_act]

    def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        residual = hidden_state
        # Forward `deterministic` so the main branch follows train/eval mode like the shortcut does.
        hidden_state = self.layer(hidden_state, deterministic=deterministic)
        residual = self.shortcut(residual, deterministic=deterministic)
        hidden_state += residual
        hidden_state = self.activation_func(hidden_state)
        return hidden_state
class FlaxRegNetStageLayersCollection(nn.Module):
    """
    The stacked layers of one RegNet stage.

    Only the first layer changes the channel count and applies the stage stride. The remaining `depth - 1` layers
    map `out_channels` to `out_channels` with the default stride.
    """

    config: RegNetConfig
    in_channels: int
    out_channels: int
    stride: int = 2
    depth: int = 2
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        block_cls = FlaxRegNetXLayer if self.config.layer_type == "x" else FlaxRegNetYLayer
        head = block_cls(
            self.config,
            self.in_channels,
            self.out_channels,
            stride=self.stride,
            dtype=self.dtype,
            name="0",
        )
        tail = [
            block_cls(self.config, self.out_channels, self.out_channels, dtype=self.dtype, name=str(index))
            for index in range(1, self.depth)
        ]
        self.layers = [head] + tail

    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        """Apply every layer of the stage in order."""
        out = x
        for block in self.layers:
            out = block(out, deterministic=deterministic)
        return out
class FlaxRegNetStage(nn.Module):
    """
    A RegNet stage composed by stacked layers. It is a thin wrapper whose single submodule is named `layers`.
    """

    config: RegNetConfig
    in_channels: int
    out_channels: int
    stride: int = 2
    depth: int = 2
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layers = FlaxRegNetStageLayersCollection(
            config=self.config,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            stride=self.stride,
            depth=self.depth,
            dtype=self.dtype,
        )

    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        stage_output = self.layers(x, deterministic=deterministic)
        return stage_output
class FlaxRegNetStageCollection(nn.Module):
    """
    All RegNet stages.

    Stage "0" maps `embedding_size` to `hidden_sizes[0]` and is strided only if `downsample_in_first_stage`.
    Each following stage `i` maps `hidden_sizes[i-1]` to `hidden_sizes[i]` with `depths[i]` layers.
    """

    config: RegNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        in_out_channels = zip(self.config.hidden_sizes, self.config.hidden_sizes[1:])
        stages = [
            FlaxRegNetStage(
                self.config,
                self.config.embedding_size,
                self.config.hidden_sizes[0],
                stride=2 if self.config.downsample_in_first_stage else 1,
                depth=self.config.depths[0],
                dtype=self.dtype,
                name="0",
            )
        ]
        for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, self.config.depths[1:])):
            stages.append(
                FlaxRegNetStage(self.config, in_channels, out_channels, depth=depth, dtype=self.dtype, name=str(i + 1))
            )
        self.stages = stages

    def __call__(
        self,
        hidden_state: jnp.ndarray,
        output_hidden_states: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray, Optional[Tuple[jnp.ndarray, ...]]]:
        """
        Run all stages.

        Returns:
            A plain tuple `(last_hidden_state, hidden_states)`. `hidden_states` is None unless
            `output_hidden_states` is set. It then holds the *input* of every stage, transposed from channels-last to
            channels-first; the final output is not included.
        """
        hidden_states = () if output_hidden_states else None
        for stage_module in self.stages:
            if output_hidden_states:
                hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),)
            hidden_state = stage_module(hidden_state, deterministic=deterministic)
        return hidden_state, hidden_states
class FlaxRegNetEncoder(nn.Module):
    """
    RegNet encoder: runs every stage and optionally collects the hidden states.

    The collected hidden states are transposed from channels-last to channels-first.
    """

    config: RegNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.stages = FlaxRegNetStageCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_state: jnp.ndarray,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ) -> FlaxBaseModelOutputWithNoAttention:
        last_state, all_states = self.stages(
            hidden_state, output_hidden_states=output_hidden_states, deterministic=deterministic
        )
        if output_hidden_states:
            # The stage collection records each stage's input only; append the final output as well.
            all_states = all_states + (last_state.transpose(0, 3, 1, 2),)

        if return_dict:
            return FlaxBaseModelOutputWithNoAttention(last_hidden_state=last_state, hidden_states=all_states)
        return tuple(value for value in [last_state, all_states] if value is not None)
class FlaxRegNetPreTrainedModel(FlaxPreTrainedModel):
    """
    Abstract base for the Flax RegNet models. It handles weight initialization and provides a simple interface for
    downloading and loading pretrained models. Subclasses set `module_class`.
    """

    config_class = RegNetConfig
    base_model_prefix = "regnet"
    main_input_name = "pixel_values"
    module_class: nn.Module = None

    def __init__(
        self,
        config: RegNetConfig,
        input_shape=(1, 224, 224, 3),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        # An explicit `input_shape=None` falls back to one channels-last image sized from the config.
        if input_shape is None:
            input_shape = (1, config.image_size, config.image_size, config.num_channels)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # Initialize every variable by tracing the module on an all-zeros input.
        dummy_pixels = jnp.zeros(input_shape, dtype=self.dtype)
        fresh_variables = self.module.init({"params": rng}, dummy_pixels, return_dict=False)
        if params is None:
            return fresh_variables

        # Fill only the keys reported missing from the provided params with freshly initialized values.
        fresh_flat = flatten_dict(unfreeze(fresh_variables))
        given_flat = flatten_dict(unfreeze(params))
        for key in self._missing_keys:
            given_flat[key] = fresh_flat[key]
        self._missing_keys = set()
        return freeze(unflatten_dict(given_flat))

    @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING)
    def __call__(
        self,
        pixel_values,
        params: dict = None,
        train: bool = False,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.return_dict

        # Inputs arrive channels-first; the Flax module works channels-last.
        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        source = self.params if params is None else params
        variables = {"params": source["params"], "batch_stats": source["batch_stats"]}

        # In training mode the batch statistics are mutable, so `apply` also returns the updated `batch_stats`.
        return self.module.apply(
            variables,
            jnp.array(pixel_values, dtype=jnp.float32),
            not train,
            output_hidden_states,
            return_dict,
            rngs={},
            mutable=["batch_stats"] if train else False,
        )
class FlaxRegNetModule(nn.Module):
    """
    Bare RegNet network: stem (embedder), encoder, and global average pooling.

    Outputs are transposed from channels-last back to channels-first.
    """

    config: RegNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.embedder = FlaxRegNetEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxRegNetEncoder(self.config, dtype=self.dtype)
        self.pooler = partial(
            nn.avg_pool,
            padding=((0, 0), (0, 0)),
        )

    def __call__(
        self,
        pixel_values,
        deterministic: bool = True,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> FlaxBaseModelOutputWithPoolingAndNoAttention:
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        embedded = self.embedder(pixel_values, deterministic=deterministic)
        encoder_outputs = self.encoder(
            embedded,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        features = encoder_outputs[0]
        # Pool over the whole spatial extent (axes 1 and 2), giving a single spatial position per channel.
        window = (features.shape[1], features.shape[2])
        pooled = self.pooler(features, window_shape=window, strides=window).transpose(0, 3, 1, 2)
        features = features.transpose(0, 3, 1, 2)

        if return_dict:
            return FlaxBaseModelOutputWithPoolingAndNoAttention(
                last_hidden_state=features,
                pooler_output=pooled,
                hidden_states=encoder_outputs.hidden_states,
            )
        return (features, pooled) + encoder_outputs[1:]
@add_start_docstrings(
    "The bare RegNet model outputting raw features without any specific head on top.",
    REGNET_START_DOCSTRING,
)
# Public Flax RegNet backbone: FlaxRegNetPreTrainedModel's loading/init plumbing around FlaxRegNetModule.
# The class docstring is supplied by the decorator above, so no inline docstring is added here.
class FlaxRegNetModel(FlaxRegNetPreTrainedModel):
    module_class = FlaxRegNetModule
# Usage example attached to `FlaxRegNetModel.__call__` (the text is runtime data and kept verbatim).
FLAX_VISION_MODEL_DOCSTRING = """
Returns:
Examples:
```
>>> from transformers import AutoImageProcessor, FlaxRegNetModel
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/regnet-y-040")
>>> model = FlaxRegNetModel.from_pretrained("facebook/regnet-y-040")
>>> inputs = image_processor(images=image, return_tensors="np")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```
"""
# Presumably (as the name suggests) replaces the `__call__` docstring of FlaxRegNetModel with the text above.
overwrite_call_docstring(FlaxRegNetModel, FLAX_VISION_MODEL_DOCSTRING)
class FlaxRegNetClassifierCollection(nn.Module):
    """
    Dense classification head producing `config.num_labels` logits.

    The dense layer is named "1", matching index 1 of the PyTorch `nn.Sequential(nn.Flatten(), nn.Linear(...))`
    classifier (the flatten at index 0 has no parameters).
    """

    config: RegNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        num_labels = self.config.num_labels
        self.classifier = nn.Dense(num_labels, dtype=self.dtype, name="1")

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        logits = self.classifier(x)
        return logits
class FlaxRegNetForImageClassificationModule(nn.Module):
    """
    RegNet backbone plus a dense head applied to the pooled features.

    When `config.num_labels` is 0 or less, the head is an identity and the pooled features are returned as logits.
    """

    config: RegNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.regnet = FlaxRegNetModule(config=self.config, dtype=self.dtype)
        self.classifier = (
            FlaxRegNetClassifierCollection(self.config, dtype=self.dtype)
            if self.config.num_labels > 0
            else Identity()
        )

    def __call__(
        self,
        pixel_values=None,
        deterministic: bool = True,
        output_hidden_states=None,
        return_dict=None,
    ):
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.regnet(
            pixel_values,
            deterministic=deterministic,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        pooled = outputs.pooler_output if return_dict else outputs[1]
        # The pooled output has two trailing singleton spatial axes; drop them before the dense head.
        logits = self.classifier(pooled[:, :, 0, 0])

        if return_dict:
            return FlaxImageClassifierOutputWithNoAttention(logits=logits, hidden_states=outputs.hidden_states)
        return (logits,) + outputs[2:]
# The decorator text below is runtime data and kept verbatim. In English: "RegNet model with an image
# classification head on top (a linear layer on top of the pooled features), e.g. for ImageNet."
@add_start_docstrings(
"""
用于在RegNet模型顶部添加图像分类头的模型,例如在ImageNet上使用线性层对池化特征进行分类。
""",
    REGNET_START_DOCSTRING,
)
# Public Flax RegNet image classifier wrapping FlaxRegNetForImageClassificationModule.
class FlaxRegNetForImageClassification(FlaxRegNetPreTrainedModel):
    module_class = FlaxRegNetForImageClassificationModule
# Usage example attached to `FlaxRegNetForImageClassification.__call__`.
FLAX_VISION_CLASSIF_DOCSTRING = """
Returns:
Example:
```
>>> from transformers import AutoImageProcessor, FlaxRegNetForImageClassification
>>> from PIL import Image
>>> import jax
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/regnet-y-040")
>>> model = FlaxRegNetForImageClassification.from_pretrained("facebook/regnet-y-040")
>>> inputs = image_processor(images=image, return_tensors="np")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
>>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
```
"""

# Presumably (as the names suggest): install the example above as the `__call__` documentation, then add the
# "Returns" section describing FlaxImageClassifierOutputWithNoAttention.
overwrite_call_docstring(FlaxRegNetForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING)
append_replace_return_docstrings(
    FlaxRegNetForImageClassification,
    output_type=FlaxImageClassifierOutputWithNoAttention,
    config_class=RegNetConfig,
)
.\models\regnet\modeling_regnet.py
""" PyTorch RegNet model."""
from typing import Optional
import torch
import torch.utils.checkpoint
from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import (
BaseModelOutputWithNoAttention,
BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention,
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_regnet import RegNetConfig
logger = logging.get_logger(__name__)

# Values interpolated into the generated forward() docstrings by `add_code_sample_docstrings`:
# config class name, checkpoint and expected output shape for the bare model example ...
_CONFIG_FOR_DOC = "RegNetConfig"
_CHECKPOINT_FOR_DOC = "facebook/regnet-y-040"
_EXPECTED_OUTPUT_SHAPE = [1, 1088, 7, 7]
# ... and checkpoint and expected predicted label for the image-classification example.
_IMAGE_CLASS_CHECKPOINT = "facebook/regnet-y-040"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
# Checkpoint identifiers for this architecture (usage is not visible in this file section).
REGNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/regnet-y-040",
]
class RegNetConvLayer(nn.Module):
    """
    Convolution -> batch norm -> activation.

    The convolution has no bias and uses `kernel_size // 2` zero padding, so odd kernel sizes preserve the spatial
    size at stride 1. Passing `activation=None` makes the activation an identity.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        groups: int = 1,
        activation: Optional[str] = "relu",
    ):
        super().__init__()
        same_padding = kernel_size // 2
        self.convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=same_padding,
            groups=groups,
            bias=False,
        )
        self.normalization = nn.BatchNorm2d(out_channels)
        if activation is None:
            self.activation = nn.Identity()
        else:
            self.activation = ACT2FN[activation]

    def forward(self, hidden_state):
        normalized = self.normalization(self.convolution(hidden_state))
        return self.activation(normalized)
class RegNetEmbeddings(nn.Module):
    """
    RegNet embeddings (stem): a single aggressive 3x3, stride-2 convolution layer.
    """

    def __init__(self, config: RegNetConfig):
        super().__init__()
        self.embedder = RegNetConvLayer(
            config.num_channels,
            config.embedding_size,
            kernel_size=3,
            stride=2,
            activation=config.hidden_act,
        )
        self.num_channels = config.num_channels

    def forward(self, pixel_values):
        # Axis 1 is the channel axis (channels-first input).
        if pixel_values.shape[1] != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        return self.embedder(pixel_values)
class RegNetShortCut(nn.Module):
    """
    RegNet shortcut: projects the residual features to the correct number of channels with a bias-free 1x1
    convolution followed by batch norm. If needed, it also downsamples the input via `stride` (2 by default).
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
        super().__init__()
        self.convolution = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=stride, bias=False
        )
        self.normalization = nn.BatchNorm2d(out_channels)

    def forward(self, input: Tensor) -> Tensor:
        projected = self.convolution(input)
        return self.normalization(projected)
class RegNetSELayer(nn.Module):
    """
    Squeeze and Excitation layer (SE), proposed in [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507).

    The input is average-pooled to 1x1 per channel. Two 1x1 convolutions (ReLU, then sigmoid) turn the pooled values
    into per-channel gates, which rescale the input.
    """

    def __init__(self, in_channels: int, reduced_channels: int):
        super().__init__()
        self.pooler = nn.AdaptiveAvgPool2d((1, 1))
        excitation = [
            nn.Conv2d(in_channels, reduced_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(reduced_channels, in_channels, kernel_size=1),
            nn.Sigmoid(),
        ]
        self.attention = nn.Sequential(*excitation)

    def forward(self, hidden_state):
        channel_gates = self.attention(self.pooler(hidden_state))
        return hidden_state * channel_gates
class RegNetXLayer(nn.Module):
    """
    RegNet's X layer: a 1x1 conv, a grouped 3x3 conv (optionally strided), and a 1x1 conv without activation,
    i.e. a ResNet bottleneck layer with reduction = 1. A residual connection is followed by the activation.
    """

    def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()
        needs_projection = stride != 1 or in_channels != out_channels
        groups = max(1, out_channels // config.groups_width)
        act = config.hidden_act
        self.shortcut = RegNetShortCut(in_channels, out_channels, stride=stride) if needs_projection else nn.Identity()
        self.layer = nn.Sequential(
            RegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=act),
            RegNetConvLayer(out_channels, out_channels, stride=stride, groups=groups, activation=act),
            RegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None),
        )
        self.activation = ACT2FN[act]

    def forward(self, hidden_state):
        branch = self.layer(hidden_state)
        # In-place add onto the main-branch output (never onto the layer input).
        branch += self.shortcut(hidden_state)
        return self.activation(branch)
class RegNetYLayer(nn.Module):
    """
    RegNet's Y layer: an X layer with a Squeeze-and-Excitation block inserted before the last 1x1 conv.
    A residual connection is followed by the activation.
    """

    def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()
        needs_projection = stride != 1 or in_channels != out_channels
        groups = max(1, out_channels // config.groups_width)
        act = config.hidden_act
        self.shortcut = RegNetShortCut(in_channels, out_channels, stride=stride) if needs_projection else nn.Identity()
        self.layer = nn.Sequential(
            RegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=act),
            RegNetConvLayer(out_channels, out_channels, stride=stride, groups=groups, activation=act),
            # The SE bottleneck width is derived from the layer's *input* channels.
            RegNetSELayer(out_channels, reduced_channels=int(round(in_channels / 4))),
            RegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None),
        )
        self.activation = ACT2FN[act]

    def forward(self, hidden_state):
        branch = self.layer(hidden_state)
        branch += self.shortcut(hidden_state)
        return self.activation(branch)
class RegNetStage(nn.Module):
    """
    A RegNet stage composed by stacked layers (X or Y layers depending on `config.layer_type`).

    Only the first layer changes the channel count and applies `stride`; the other `depth - 1` layers keep
    `out_channels`.
    """

    def __init__(
        self,
        config: RegNetConfig,
        in_channels: int,
        out_channels: int,
        stride: int = 2,
        depth: int = 2,
    ):
        super().__init__()
        block_cls = RegNetXLayer if config.layer_type == "x" else RegNetYLayer
        first = block_cls(config, in_channels, out_channels, stride=stride)
        remaining = (block_cls(config, out_channels, out_channels) for _ in range(depth - 1))
        self.layers = nn.Sequential(first, *remaining)

    def forward(self, hidden_state):
        return self.layers(hidden_state)
class RegNetEncoder(nn.Module):
    """
    Sequence of RegNet stages.

    The first stage maps `embedding_size` to `hidden_sizes[0]`; each later stage maps consecutive `hidden_sizes`
    with the matching entry of `depths`.
    """

    def __init__(self, config: RegNetConfig):
        super().__init__()
        stages = [
            RegNetStage(
                config,
                config.embedding_size,
                config.hidden_sizes[0],
                stride=2 if config.downsample_in_first_stage else 1,
                depth=config.depths[0],
            )
        ]
        channel_pairs = zip(config.hidden_sizes, config.hidden_sizes[1:])
        stages.extend(
            RegNetStage(config, c_in, c_out, depth=stage_depth)
            for (c_in, c_out), stage_depth in zip(channel_pairs, config.depths[1:])
        )
        self.stages = nn.ModuleList(stages)

    def forward(
        self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True
    ) -> BaseModelOutputWithNoAttention:
        # When requested, collect every stage's input plus the final output.
        collected = [] if output_hidden_states else None
        for stage in self.stages:
            if output_hidden_states:
                collected.append(hidden_state)
            hidden_state = stage(hidden_state)
        if output_hidden_states:
            collected.append(hidden_state)
            collected = tuple(collected)

        if return_dict:
            return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=collected)
        return tuple(value for value in [hidden_state, collected] if value is not None)
class RegNetPreTrainedModel(PreTrainedModel):
    """
    Abstract base for the PyTorch RegNet models. It handles weight initialization and provides a simple interface
    for downloading and loading pretrained models.
    """

    config_class = RegNetConfig
    base_model_prefix = "regnet"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        """
        Initialize one submodule.

        Conv2d weights get Kaiming-normal init (fan_out, ReLU gain). BatchNorm2d/GroupNorm get weight 1 and bias 0.
        Other modules are left untouched.
        """
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            return
        if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)
# Class-level documentation shared by the PyTorch RegNet models (passed to `add_start_docstrings`).
REGNET_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
behavior.
Parameters:
config ([`RegNetConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Documentation of the `forward` inputs (passed to `add_start_docstrings_to_model_forward`).
REGNET_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
[`ConvNextImageProcessor.__call__`] for details.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    "The bare RegNet model outputting raw features without any specific head on top.",
    REGNET_START_DOCSTRING,
)
# Backbone: stem (embedder) -> stages (encoder) -> global average pooling. The class docstring comes from the
# decorator above, so none is added inline.
class RegNetModel(RegNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embedder = RegNetEmbeddings(config)
        self.encoder = RegNetEncoder(config)
        self.pooler = nn.AdaptiveAvgPool2d((1, 1))
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndNoAttention,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
    ) -> BaseModelOutputWithPoolingAndNoAttention:
        # The arguments are documented by the decorators above. Previously this method body was only `pass`,
        # followed by a stray signature fragment (a SyntaxError).
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        embedding_output = self.embedder(pixel_values)
        encoder_outputs = self.encoder(
            embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict
        )
        last_hidden_state = encoder_outputs[0]
        # Adaptive average pooling to a 1x1 spatial size per channel.
        pooled_output = self.pooler(last_hidden_state)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )
@add_start_docstrings(
"""
RegNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
ImageNet.
""",
    REGNET_START_DOCSTRING,
)
class RegNetForImageClassification(RegNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.regnet = RegNetModel(config)
        # Flatten the pooled (batch, channels, 1, 1) features, then project to logits. With num_labels <= 0 the
        # head is an identity and the flattened features are returned as logits.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity(),
        )
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> ImageClassifierOutputWithNoAttention:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.regnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
        pooled_output = outputs.pooler_output if return_dict else outputs[1]
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Infer the problem type once from the label count/dtype and cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return (loss,) + output if loss is not None else output

        return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
.\models\regnet\modeling_tf_regnet.py
""" TensorFlow RegNet 模型."""
from typing import Optional, Tuple, Union
import tensorflow as tf
from ...activations_tf import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_tf_outputs import (
TFBaseModelOutputWithNoAttention,
TFBaseModelOutputWithPoolingAndNoAttention,
TFSequenceClassifierOutput,
)
from ...modeling_tf_utils import (
TFPreTrainedModel,
TFSequenceClassificationLoss,
keras,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...utils import logging
from .configuration_regnet import RegNetConfig
logger = logging.get_logger(__name__)

# Values meant for the generated TF model docstrings: config class name, checkpoint and expected output shape
# for the bare model example ...
_CONFIG_FOR_DOC = "RegNetConfig"
_CHECKPOINT_FOR_DOC = "facebook/regnet-y-040"
_EXPECTED_OUTPUT_SHAPE = [1, 1088, 7, 7]
# ... and checkpoint and expected predicted label for the image-classification example.
_IMAGE_CLASS_CHECKPOINT = "facebook/regnet-y-040"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
# Checkpoint identifiers for the TF architecture (usage is not visible in this file section).
TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/regnet-y-040",
]
class TFRegNetConvLayer(keras.layers.Layer):
    """
    Zero padding (`kernel_size // 2`) -> bias-free convolution ("VALID") -> batch norm -> activation.

    Passing `activation=None` makes the activation `tf.identity`. The channel counts are stored so `build` can
    create the weights without seeing an input.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        groups: int = 1,
        activation: Optional[str] = "relu",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Explicit padding plus a VALID convolution reproduces the PyTorch `padding=kernel_size // 2` behavior.
        self.padding = keras.layers.ZeroPadding2D(padding=kernel_size // 2)
        self.convolution = keras.layers.Conv2D(
            filters=out_channels,
            kernel_size=kernel_size,
            strides=stride,
            padding="VALID",
            groups=groups,
            use_bias=False,
            name="convolution",
        )
        self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")
        self.activation = tf.identity if activation is None else ACT2FN[activation]
        self.in_channels = in_channels
        self.out_channels = out_channels

    def call(self, hidden_state):
        padded = self.padding(hidden_state)
        normalized = self.normalization(self.convolution(padded))
        return self.activation(normalized)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # Build each sublayer inside its own name scope so weight names match the checkpoint layout.
        sublayers = (
            (getattr(self, "convolution", None), self.in_channels),
            (getattr(self, "normalization", None), self.out_channels),
        )
        for sublayer, channels in sublayers:
            if sublayer is not None:
                with tf.name_scope(sublayer.name):
                    sublayer.build([None, None, None, channels])
class TFRegNetEmbeddings(keras.layers.Layer):
    """
    RegNet embeddings (stem): a single aggressive 3x3, stride-2 convolution layer.

    Input pixel values are expected channels-first and are transposed to channels-last before the convolution.
    """

    def __init__(self, config: RegNetConfig, **kwargs):
        super().__init__(**kwargs)
        self.num_channels = config.num_channels
        self.embedder = TFRegNetConvLayer(
            in_channels=config.num_channels,
            out_channels=config.embedding_size,
            kernel_size=3,
            stride=2,
            activation=config.hidden_act,
            name="embedder",
        )

    def call(self, pixel_values):
        channels = shape_list(pixel_values)[1]
        # The channel check only runs eagerly (presumably because the dimension may be symbolic in graph mode).
        if tf.executing_eagerly() and channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        channels_last = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
        return self.embedder(channels_last)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        embedder = getattr(self, "embedder", None)
        if embedder is not None:
            with tf.name_scope(embedder.name):
                embedder.build(None)
class TFRegNetShortCut(keras.layers.Layer):
    """
    RegNet shortcut: a bias-free 1x1 convolution plus batch norm that projects the residual features to the correct
    number of channels. If needed, it also downsamples the input via `stride` (2 by default).
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 2, **kwargs):
        super().__init__(**kwargs)
        self.convolution = keras.layers.Conv2D(
            filters=out_channels, kernel_size=1, strides=stride, use_bias=False, name="convolution"
        )
        self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")
        self.in_channels = in_channels
        self.out_channels = out_channels

    def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
        projected = self.convolution(inputs)
        return self.normalization(projected, training=training)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        sublayers = (
            (getattr(self, "convolution", None), self.in_channels),
            (getattr(self, "normalization", None), self.out_channels),
        )
        for sublayer, channels in sublayers:
            if sublayer is not None:
                with tf.name_scope(sublayer.name):
                    sublayer.build([None, None, None, channels])
class TFRegNetSELayer(keras.layers.Layer):
    """
    Squeeze and Excitation layer (SE) proposed in [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507).

    The input is globally average-pooled to one value per channel. It then passes through a 1x1 convolution down to
    `reduced_channels` (ReLU) and a 1x1 convolution back to `in_channels` (sigmoid). The resulting per-channel
    weights rescale the input feature map.

    Args:
        in_channels (`int`): number of channels of the (channels-last) input feature map.
        reduced_channels (`int`): number of channels of the bottleneck.
    """

    def __init__(self, in_channels: int, reduced_channels: int, **kwargs):
        super().__init__(**kwargs)
        # keepdims=True keeps a (batch, 1, 1, channels) result that broadcasts against the input.
        self.pooler = keras.layers.GlobalAveragePooling2D(keepdims=True, name="pooler")
        # Two 1x1 convolutions producing the channel attention weights.
        self.attention = [
            keras.layers.Conv2D(filters=reduced_channels, kernel_size=1, activation="relu", name="attention.0"),
            keras.layers.Conv2D(filters=in_channels, kernel_size=1, activation="sigmoid", name="attention.2"),
        ]
        self.in_channels = in_channels
        self.reduced_channels = reduced_channels

    def call(self, hidden_state):
        # [batch_size, h, w, num_channels] -> [batch_size, 1, 1, num_channels]
        pooled = self.pooler(hidden_state)
        for layer_module in self.attention:
            pooled = layer_module(pooled)
        # Rescale every channel of the input by its attention weight.
        hidden_state = hidden_state * pooled
        return hidden_state

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "pooler", None) is not None:
            with tf.name_scope(self.pooler.name):
                self.pooler.build((None, None, None, None))
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention[0].name):
                self.attention[0].build([None, None, None, self.in_channels])
            with tf.name_scope(self.attention[1].name):
                self.attention[1].build([None, None, None, self.reduced_channels])
class TFRegNetXLayer(keras.layers.Layer):
    """
    RegNet's X layer: a residual bottleneck, similar to a ResNet bottleneck layer with reduction = 1.

    It consists of a 1x1 convolution, a grouped convolution (the only one that applies `stride`) and a final 1x1
    convolution without activation. The projected residual is added before the final activation.
    """

    def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1, **kwargs):
        super().__init__(**kwargs)
        # A projection shortcut is only needed when the residual's channel count or spatial size changes.
        should_apply_shortcut = in_channels != out_channels or stride != 1
        # Number of groups for the grouped convolution (at least one), same rule as TFRegNetYLayer.
        groups = max(1, out_channels // config.groups_width)
        self.shortcut = (
            TFRegNetShortCut(in_channels, out_channels, stride=stride, name="shortcut")
            if should_apply_shortcut
            else keras.layers.Activation("linear", name="shortcut")
        )
        self.layers = [
            TFRegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act, name="layer.0"),
            # Uses TFRegNetConvLayer's default kernel size (3x3 upstream -- confirm against its definition).
            TFRegNetConvLayer(
                out_channels, out_channels, stride=stride, groups=groups, activation=config.hidden_act, name="layer.1"
            ),
            TFRegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None, name="layer.2"),
        ]
        self.activation = ACT2FN[config.hidden_act]

    def call(self, hidden_state):
        residual = hidden_state
        for layer_module in self.layers:
            hidden_state = layer_module(hidden_state)
        residual = self.shortcut(residual)
        hidden_state += residual
        hidden_state = self.activation(hidden_state)
        return hidden_state

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "shortcut", None) is not None:
            with tf.name_scope(self.shortcut.name):
                self.shortcut.build(None)
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build(None)
class TFRegNetYLayer(keras.layers.Layer):
    """
    RegNet's Y layer: an X layer with a Squeeze-and-Excitation block inserted after the grouped convolution.
    """

    def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1, **kwargs):
        super().__init__(**kwargs)
        needs_projection = in_channels != out_channels or stride != 1
        group_count = max(1, out_channels // config.groups_width)
        if needs_projection:
            self.shortcut = TFRegNetShortCut(in_channels, out_channels, stride=stride, name="shortcut")
        else:
            # Identity shortcut, expressed as a linear activation layer.
            self.shortcut = keras.layers.Activation("linear", name="shortcut")
        self.layers = [
            TFRegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act, name="layer.0"),
            TFRegNetConvLayer(
                out_channels, out_channels, stride=stride, groups=group_count, activation=config.hidden_act, name="layer.1"
            ),
            # The SE bottleneck width is derived from the block's *input* channel count.
            TFRegNetSELayer(out_channels, reduced_channels=int(round(in_channels / 4)), name="layer.2"),
            TFRegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None, name="layer.3"),
        ]
        self.activation = ACT2FN[config.hidden_act]

    def call(self, hidden_state):
        block_input = hidden_state
        for layer_module in self.layers:
            hidden_state = layer_module(hidden_state)
        # The main branch runs first, then the shortcut; their sum is activated.
        return self.activation(hidden_state + self.shortcut(block_input))

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        shortcut = getattr(self, "shortcut", None)
        if shortcut is not None:
            with tf.name_scope(shortcut.name):
                shortcut.build(None)
        sublayers = getattr(self, "layers", None)
        if sublayers is not None:
            for sublayer in sublayers:
                with tf.name_scope(sublayer.name):
                    sublayer.build(None)
class TFRegNetStage(keras.layers.Layer):
    """
    A RegNet stage: `depth` stacked X or Y layers (per `config.layer_type`). Only the first layer changes the channel
    count and applies `stride`.
    """

    def __init__(
        self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 2, depth: int = 2, **kwargs
    ):
        super().__init__(**kwargs)
        layer_cls = TFRegNetXLayer if config.layer_type == "x" else TFRegNetYLayer
        blocks = [layer_cls(config, in_channels, out_channels, stride=stride, name="layers.0")]
        blocks += [layer_cls(config, out_channels, out_channels, name=f"layers.{idx}") for idx in range(1, depth)]
        self.layers = blocks

    def call(self, hidden_state):
        for block in self.layers:
            hidden_state = block(hidden_state)
        return hidden_state

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        blocks = getattr(self, "layers", None)
        if blocks is None:
            return
        for block in blocks:
            with tf.name_scope(block.name):
                block.build(None)
class TFRegNetEncoder(keras.layers.Layer):
    """Sequence of RegNet stages; the first stage maps `embedding_size` to `hidden_sizes[0]` channels."""

    def __init__(self, config: RegNetConfig, **kwargs):
        super().__init__(**kwargs)
        # Whether the first stage downsamples is configurable; the remaining stages use the default stride.
        first_stride = 2 if config.downsample_in_first_stage else 1
        stages = [
            TFRegNetStage(
                config,
                config.embedding_size,
                config.hidden_sizes[0],
                stride=first_stride,
                depth=config.depths[0],
                name="stages.0",
            )
        ]
        channel_pairs = zip(config.hidden_sizes, config.hidden_sizes[1:])
        for idx, ((c_in, c_out), stage_depth) in enumerate(zip(channel_pairs, config.depths[1:]), start=1):
            stages.append(TFRegNetStage(config, c_in, c_out, depth=stage_depth, name=f"stages.{idx}"))
        self.stages = stages

    def call(
        self, hidden_state: tf.Tensor, output_hidden_states: bool = False, return_dict: bool = True
    ) -> TFBaseModelOutputWithNoAttention:
        # Collects the input of every stage plus the final output when requested.
        collected = () if output_hidden_states else None
        for stage in self.stages:
            if output_hidden_states:
                collected = collected + (hidden_state,)
            hidden_state = stage(hidden_state)
        if output_hidden_states:
            collected = collected + (hidden_state,)
        if return_dict:
            return TFBaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=collected)
        return tuple(v for v in [hidden_state, collected] if v is not None)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        for stage in self.stages:
            with tf.name_scope(stage.name):
                stage.build(None)
class TFRegNetMainLayer(keras.layers.Layer):
    """
    RegNet backbone shared by the TF model classes: stem (embedder), encoder stages and global average pooling.

    Inputs are channels-first. The encoder runs channels-last, and every returned tensor (last hidden state,
    pooled output and, if requested, all hidden states) is converted back to channels-first (NCHW). The tuple
    and the dict outputs therefore agree.
    """

    config_class = RegNetConfig

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.embedder = TFRegNetEmbeddings(config, name="embedder")
        self.encoder = TFRegNetEncoder(config, name="encoder")
        self.pooler = keras.layers.GlobalAveragePooling2D(keepdims=True, name="pooler")

    @unpack_inputs
    def call(
        self,
        pixel_values: tf.Tensor,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> TFBaseModelOutputWithPoolingAndNoAttention:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        embedding_output = self.embedder(pixel_values, training=training)
        encoder_outputs = self.encoder(
            embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
        )
        last_hidden_state = encoder_outputs[0]
        pooled_output = self.pooler(last_hidden_state)

        # Convert outputs to NCHW for uniformity with the channels-first inputs.
        pooled_output = tf.transpose(pooled_output, perm=(0, 3, 1, 2))
        last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2))
        hidden_states = None
        if output_hidden_states:
            hidden_states = tuple(tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1])

        if not return_dict:
            # Return the NCHW hidden states (not the raw NHWC encoder ones) so tuple and dict outputs match.
            extra = (hidden_states,) if output_hidden_states else ()
            return (last_hidden_state, pooled_output) + extra

        return TFBaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=hidden_states,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embedder", None) is not None:
            with tf.name_scope(self.embedder.name):
                self.embedder.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "pooler", None) is not None:
            with tf.name_scope(self.pooler.name):
                self.pooler.build((None, None, None, None))
class TFRegNetPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = RegNetConfig
    base_model_prefix = "regnet"
    main_input_name = "pixel_values"

    @property
    def input_signature(self):
        # Channels-first float32 images of size 224x224 with a dynamic batch dimension.
        pixel_spec = tf.TensorSpec(shape=(None, self.config.num_channels, 224, 224), dtype=tf.float32)
        return {"pixel_values": pixel_spec}
# Shared class-docstring text for the TF RegNet models (inserted via `add_start_docstrings`).
REGNET_START_DOCSTRING = r"""
    This model is a Tensorflow
    [keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) sub-class. Use it as a
    regular Tensorflow Module and refer to the Tensorflow documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`RegNetConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Shared description of the `call` arguments (inserted via `add_start_docstrings_to_model_forward`).
REGNET_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`ConvNextImageProcessor.__call__`] for details.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# The class docstring is assembled by `add_start_docstrings` from the summary below and REGNET_START_DOCSTRING.
@add_start_docstrings(
    "The bare RegNet model outputting raw features without any specific head on top.",
    REGNET_START_DOCSTRING,
)
class TFRegNetModel(TFRegNetPreTrainedModel):
    def __init__(self, config: RegNetConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.regnet = TFRegNetMainLayer(config, name="regnet")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutputWithPoolingAndNoAttention,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def call(
        self,
        pixel_values: tf.Tensor,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPoolingAndNoAttention, Tuple[tf.Tensor]]:
        # Fall back to the config defaults for unspecified output options.
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.regnet(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        if not return_dict:
            return (outputs[0],) + outputs[1:]

        return TFBaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=outputs.last_hidden_state,
            pooler_output=outputs.pooler_output,
            hidden_states=outputs.hidden_states,
        )

    def build(self, input_shape=None):
        # Build the wrapped main layer once, under its own name scope.
        if self.built:
            return
        self.built = True
        if getattr(self, "regnet", None) is not None:
            with tf.name_scope(self.regnet.name):
                self.regnet.build(None)
@add_start_docstrings(
    """
    RegNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """,
    REGNET_START_DOCSTRING,
)
class TFRegNetForImageClassification(TFRegNetPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config: RegNetConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels
        self.regnet = TFRegNetMainLayer(config, name="regnet")
        # Classification head: flatten the pooled (batch, C, 1, 1) features, then a dense layer.
        # With num_labels == 0 the second entry is `tf.identity`, so the flattened features become the logits.
        self.classifier = [
            keras.layers.Flatten(),
            keras.layers.Dense(config.num_labels, name="classifier.1") if config.num_labels > 0 else tf.identity,
        ]

    @unpack_inputs
    @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def call(
        self,
        pixel_values: Optional[tf.Tensor] = None,
        labels: Optional[tf.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.regnet(
            pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
        )

        pooled_output = outputs.pooler_output if return_dict else outputs[1]
        flattened_output = self.classifier[0](pooled_output)
        logits = self.classifier[1](flattened_output)

        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "regnet", None) is not None:
            with tf.name_scope(self.regnet.name):
                self.regnet.build(None)
        if getattr(self, "classifier", None) is not None:
            head = self.classifier[1]
            # When num_labels == 0 the head is `tf.identity`, a plain function without `.name`/`.build`.
            if isinstance(head, keras.layers.Layer):
                with tf.name_scope(head.name):
                    head.build([None, None, None, self.config.hidden_sizes[-1]])
.\models\regnet\__init__.py
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tf_available,
is_torch_available,
)
# Lazy-import table: maps each submodule of this package to the public names it provides.
# The configuration module is always registered; modeling modules are added only if their backend is installed.
_import_structure = {"configuration_regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"]}

# PyTorch models: registered only when torch is available.
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_regnet"] = [
        "REGNET_PRETRAINED_MODEL_ARCHIVE_LIST",
        "RegNetForImageClassification",
        "RegNetModel",
        "RegNetPreTrainedModel",
    ]

# TensorFlow models: registered only when TensorFlow is available.
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_tf_regnet"] = [
        "TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST",
        "TFRegNetForImageClassification",
        "TFRegNetModel",
        "TFRegNetPreTrainedModel",
    ]

# Flax models: registered only when Flax is available.
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_flax_regnet"] = [
        "FlaxRegNetForImageClassification",
        "FlaxRegNetModel",
        "FlaxRegNetPreTrainedModel",
    ]


if TYPE_CHECKING:
    # Static type checkers see real imports, guarded by the same backend checks as above.
    from .configuration_regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_regnet import (
            REGNET_PRETRAINED_MODEL_ARCHIVE_LIST,
            RegNetForImageClassification,
            RegNetModel,
            RegNetPreTrainedModel,
        )

    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_tf_regnet import (
            TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFRegNetForImageClassification,
            TFRegNetModel,
            TFRegNetPreTrainedModel,
        )

    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_flax_regnet import (
            FlaxRegNetForImageClassification,
            FlaxRegNetModel,
            FlaxRegNetPreTrainedModel,
        )


else:
    import sys

    # At runtime this module object is replaced by a _LazyModule built from `_import_structure`
    # (presumably deferring the submodule imports until a name is first accessed).
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
.\models\rembert\configuration_rembert.py
""" RemBERT model configuration"""
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
logger = logging.get_logger(__name__)

# Maps pretrained RemBERT checkpoint names to the URL of their hosted `config.json`.
REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "google/rembert": "https://huggingface.co/google/rembert/resolve/main/config.json",
}
class RemBertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`RemBertModel`]. It is used to instantiate an
    RemBERT model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the RemBERT
    [google/rembert](https://huggingface.co/google/rembert) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 250300):
            Vocabulary size of the RemBERT model (number of rows of the word-embedding table).
        hidden_size (`int`, *optional*, defaults to 1152):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 18):
            Number of attention heads for each attention layer in the Transformer encoder.
        input_embedding_size (`int`, *optional*, defaults to 256):
            Width of the input embeddings (word, position and token-type tables).
        output_embedding_size (`int`, *optional*, defaults to 1664):
            Width of the output embeddings (used by the output/LM head).
        intermediate_size (`int`, *optional*, defaults to 4608):
            Dimensionality of the intermediate (feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            Activation of the intermediate layer; a string is looked up in `ACT2FN`.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            Dropout probability for the embeddings and the fully connected output layers.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            Dropout ratio applied to the attention probabilities.
        classifier_dropout_prob (`float`, *optional*, defaults to 0.1):
            Dropout ratio for the classification heads.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            Maximum sequence length the position-embedding table supports.
        type_vocab_size (`int`, *optional*, defaults to 2):
            Vocabulary size of the `token_type_ids`.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation used when initializing weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            Epsilon used by the layer normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (only relevant for decoders).
        pad_token_id (`int`, *optional*, defaults to 0), bos_token_id (`int`, *optional*, defaults to 312),
        eos_token_id (`int`, *optional*, defaults to 313):
            Special token ids, forwarded to [`PretrainedConfig`].

    Example:

    ```python
    >>> from transformers import RemBertModel, RemBertConfig

    >>> # Initializing a RemBERT rembert style configuration
    >>> configuration = RemBertConfig()

    >>> # Initializing a model from the rembert style configuration
    >>> model = RemBertModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "rembert"

    def __init__(
        self,
        vocab_size=250300,
        hidden_size=1152,
        num_hidden_layers=32,
        num_attention_heads=18,
        input_embedding_size=256,
        output_embedding_size=1664,
        intermediate_size=4608,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        classifier_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=312,
        eos_token_id=313,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.input_embedding_size = input_embedding_size
        self.output_embedding_size = output_embedding_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.classifier_dropout_prob = classifier_dropout_prob
        self.initializer_range = initializer_range
        self.type_vocab_size = type_vocab_size
        self.layer_norm_eps = layer_norm_eps
        self.use_cache = use_cache
        # Input and output embeddings have different widths, so weight tying is always disabled
        # (this overrides any value passed through kwargs).
        self.tie_word_embeddings = False
class RemBertOnnxConfig(OnnxConfig):
    """ONNX export configuration for RemBERT: declares the dynamic axes of the three text inputs."""

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # Multiple-choice inputs carry an extra "choice" axis between batch and sequence.
        axes = (
            {0: "batch", 1: "choice", 2: "sequence"}
            if self.task == "multiple-choice"
            else {0: "batch", 1: "sequence"}
        )
        return OrderedDict((input_name, axes) for input_name in ("input_ids", "attention_mask", "token_type_ids"))

    @property
    def atol_for_validation(self) -> float:
        # Absolute tolerance used when validating the exported model's outputs.
        return 1e-4
.\models\rembert\convert_rembert_tf_checkpoint_to_pytorch.py
"""Convert RemBERT checkpoint."""
import argparse
import torch
from transformers import RemBertConfig, RemBertModel, load_tf_weights_in_rembert
from transformers.utils import logging
logging.set_verbosity_info()
def convert_rembert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    """Convert a TensorFlow RemBERT checkpoint into a PyTorch weights file.

    Builds a `RemBertModel` from the JSON config at `bert_config_file` and loads the TF checkpoint at
    `tf_checkpoint_path` into it. The model's `state_dict` is then written to `pytorch_dump_path` with `torch.save`.
    """
    rembert_config = RemBertConfig.from_json_file(bert_config_file)
    print(f"Building PyTorch model from configuration: {str(rembert_config)}")
    pt_model = RemBertModel(rembert_config)

    load_tf_weights_in_rembert(pt_model, rembert_config, tf_checkpoint_path)

    print(f"Save PyTorch model to {pytorch_dump_path}")
    torch.save(pt_model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
    # Command-line entry point; every flag is a required string path.
    arg_parser = argparse.ArgumentParser()
    arg_specs = (
        ("--tf_checkpoint_path", "Path to the TensorFlow checkpoint path."),
        (
            "--rembert_config_file",
            (
                "The config json file corresponding to the pre-trained RemBERT model. \n"
                "This specifies the model architecture."
            ),
        ),
        ("--pytorch_dump_path", "Path to the output PyTorch model."),
    )
    for flag, help_text in arg_specs:
        arg_parser.add_argument(flag, default=None, type=str, required=True, help=help_text)
    cli_args = arg_parser.parse_args()
    convert_rembert_tf_checkpoint_to_pytorch(
        tf_checkpoint_path=cli_args.tf_checkpoint_path,
        bert_config_file=cli_args.rembert_config_file,
        pytorch_dump_path=cli_args.pytorch_dump_path,
    )
.\models\rembert\modeling_rembert.py
""" PyTorch RemBERT 模型。"""
import math
import os
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
MaskedLMOutput,
MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_rembert import RemBertConfig
logger = logging.get_logger(__name__)

# Default config class / checkpoint names; presumably consumed by the docstring helpers later in this file.
_CONFIG_FOR_DOC = "RemBertConfig"
_CHECKPOINT_FOR_DOC = "google/rembert"

# Pretrained RemBERT checkpoint identifiers.
REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google/rembert",
]
def load_tf_weights_in_rembert(model, config, tf_checkpoint_path):
    """Load weights from a TensorFlow checkpoint into a PyTorch RemBERT model (in place).

    Args:
        model: PyTorch module whose parameters are overwritten with the checkpoint values.
        config: model configuration; not read by this function (kept for API compatibility).
        tf_checkpoint_path: path to the TensorFlow checkpoint.

    Returns:
        The same `model` object.

    Raises:
        ImportError: if TensorFlow or NumPy is not installed.
        ValueError: if a checkpoint array's shape does not match its target parameter.
    """
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "在 PyTorch 中加载 TensorFlow 模型需要安装 TensorFlow。请访问 "
            "https://www.tensorflow.org/install/ 获取安装指南。"
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"从 {tf_path} 转换 TensorFlow checkpoints")

    # Read every variable except optimizer slots, the output-embedding table and "cls" variables.
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        if any(deny in name for deny in ("adam_v", "adam_m", "output_embedding", "cls")):
            continue
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        # TF scopes are rooted at "bert/"; rename them to "rembert/" before walking the module tree.
        name = name.replace("bert/", "rembert/")
        name = name.split("/")
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        pointer = model
        for m_name in name:
            # Scopes such as "layer_3" address element 3 of an indexable submodule.
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    # NOTE: this `continue` skips only this path component, not the whole variable. Traversal goes
                    # on from the current pointer. This presumably lets the leading "rembert" scope be dropped when
                    # `model` is a bare RemBertModel -- do not turn it into a whole-variable skip without checking.
                    logger.info("Skipping {}".format("/".join(name)))
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            # TF dense kernels are stored (in, out); nn.Linear.weight is (out, in).
            array = np.transpose(array)
        # The message already reports both shapes. The former `except AssertionError` wrapper could never
        # catch this ValueError, so it was dead code.
        if pointer.shape != array.shape:
            raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model
class RemBertEmbeddings(nn.Module):
    """Sum word, token-type and position embeddings, then apply LayerNorm and dropout.

    All three tables have width `config.input_embedding_size`.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.input_embedding_size, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.input_embedding_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.input_embedding_size)
        # Attribute keeps the CamelCase name `LayerNorm`; the TF loader resolves checkpoint scope names via getattr.
        self.LayerNorm = nn.LayerNorm(config.input_embedding_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Default position ids 0..max_position_embeddings-1, shape (1, max_len); not saved in the state dict.
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        batch_shape = input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
        seq_length = batch_shape[1]

        if position_ids is None:
            # Positions continue after any cached prefix.
            start = past_key_values_length
            position_ids = self.position_ids[:, start : start + seq_length]
        if token_type_ids is None:
            token_type_ids = torch.zeros(batch_shape, dtype=torch.long, device=self.position_ids.device)
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        summed = inputs_embeds + self.token_type_embeddings(token_type_ids)
        summed = summed + self.position_embeddings(position_ids)
        return self.dropout(self.LayerNorm(summed))
class RemBertPooler(nn.Module):
    """Pool a sequence by passing its first token's hidden state through a dense layer and tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states[:, 0] selects the first position of every sequence in the batch.
        return self.activation(self.dense(hidden_states[:, 0]))
class RemBertSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Works as self-attention, or as cross-attention when `encoder_hidden_states` is given. When `config.is_decoder`
    is set, the (key, value) pair is also returned so it can be reused as `past_key_value` on the next step.
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x):
        # (batch, seq, all_head_size) -> (batch, heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Tuple[Tuple[torch.FloatTensor]] = None,
        output_attentions: bool = False,
    ) -> Tuple:
        """Return `(context,)`, plus the attention probabilities if `output_attentions`, plus the
        `(key, value)` cache if this is a decoder."""
        mixed_query_layer = self.query(hidden_states)

        # For cross-attention, keys/values come from the encoder and the encoder's mask applies.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # Reuse cached cross-attention keys/values.
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Append the new positions to the cached self-attention keys/values (sequence axis is dim 2).
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            past_key_value = (key_layer, value_layer)

        # Raw scores: (batch, heads, query_len, key_len), scaled by sqrt(head_size).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # The mask is additive (large negative values at masked positions).
            attention_scores = attention_scores + attention_mask

        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # (batch, heads, seq, head_size) -> (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
class RemBertSelfOutput(nn.Module):
    """Project the attention context, apply dropout, add the residual and layer-normalize."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class RemBertAttention(nn.Module):
    """Attention sub-block: multi-head attention (`self.self`) followed by the output projection with residual
    LayerNorm (`self.output`). Also supports pruning attention heads."""

    def __init__(self, config):
        super().__init__()
        self.self = RemBertSelfAttention(config)
        self.output = RemBertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads
        )

        # Query/key/value are pruned with the default `dim`; the output projection along `dim=1`.
        attn.query = prune_linear_layer(attn.query, index)
        attn.key = prune_linear_layer(attn.key, index)
        attn.value = prune_linear_layer(attn.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the bookkeeping consistent with the smaller projections.
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        attn_results = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        context = attn_results[0]
        # Residual connection uses the block input; any extra outputs (probs, cache) pass through unchanged.
        return (self.output(context, hidden_states),) + attn_results[1:]
class RemBertIntermediate(nn.Module):
    """Feed-forward expansion: ``hidden_size -> intermediate_size`` followed by the activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        act = config.hidden_act
        # ``hidden_act`` may be an activation name (looked up in ACT2FN) or a callable used directly.
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
class RemBertOutput(nn.Module):
    """Feed-forward contraction ``intermediate_size -> hidden_size``, then dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        residual = input_tensor
        contracted = self.dense(hidden_states)
        contracted = self.dropout(contracted)
        return self.LayerNorm(contracted + residual)
class RemBertLayer(nn.Module):
    """A single RemBERT transformer layer.

    Self-attention, then (decoder only, when ``encoder_hidden_states`` is given) cross-attention,
    then the feed-forward sub-block run through ``apply_chunking_to_forward``.
    """

    def __init__(self, config):
        super().__init__()
        # Passed to apply_chunking_to_forward together with seq_len_dim (axis 1) for the feed-forward step.
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = RemBertAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            # Cross-attention is only allowed in decoder configurations.
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = RemBertAttention(config)
        self.intermediate = RemBertIntermediate(config)
        self.output = RemBertOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run the layer.

        Args:
            past_key_value: per-layer cache; its first two entries feed self-attention and its
                last two entries feed cross-attention.

        Returns:
            ``(layer_output, *attention_extras)``, followed by the present key/value cache when
            ``self.is_decoder`` is true. The extras are whatever the attention sub-modules return
            beyond their first output (and, in a decoder, before their last output).

        Raises:
            ValueError: if ``encoder_hidden_states`` is passed to a decoder layer built without
                cross-attention.
        """
        # Self-attention uses the first two cached tensors.
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # In a decoder the last self-attention output is the present key/value cache; keep it apart.
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # Cross-attention uses the last two cached tensors.
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]

            # The cross-attention cache is appended after the self-attention cache.
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        # Feed-forward; apply_chunking_to_forward may split the work along seq_len_dim.
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        """Feed-forward sub-block: expansion + activation, then contraction with residual and LayerNorm."""
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
class RemBertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` RemBertLayer blocks.

    Also holds ``embedding_hidden_mapping_in``, a linear map from ``config.input_embedding_size``
    to ``config.hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Linear map from the input-embedding width to the transformer hidden width.
        self.embedding_hidden_mapping_in = nn.Linear(config.input_embedding_size, config.hidden_size)
        self.layer = nn.ModuleList([RemBertLayer(config) for _ in range(config.num_hidden_layers)])
        # Off by default; presumably switched on by the pretrained base class's gradient-checkpointing
        # support (supports_gradient_checkpointing) — confirm.
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Run the layer stack.

        NOTE(review): the body is a ``pass`` stub in this listing, so this returns None. Restore
        the upstream implementation before relying on it.
        """
        pass
class RemBertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm, all at ``hidden_size`` width."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        act = config.hidden_act
        # Accept either an activation name (looked up in ACT2FN) or a callable.
        self.transform_act_fn = ACT2FN[act] if isinstance(act, str) else act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for step in (self.dense, self.transform_act_fn, self.LayerNorm):
            hidden_states = step(hidden_states)
        return hidden_states
class RemBertLMPredictionHead(nn.Module):
    """Language-modeling head producing vocabulary logits.

    ``hidden_size -> output_embedding_size`` (dense), then activation and LayerNorm, then
    ``output_embedding_size -> vocab_size`` (``decoder``).
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.output_embedding_size)
        self.decoder = nn.Linear(config.output_embedding_size, config.vocab_size)
        # Like RemBertIntermediate / RemBertPredictionHeadTransform, accept either an activation
        # name (looked up in ACT2FN) or a ready-made callable; indexing ACT2FN with a callable fails.
        if isinstance(config.hidden_act, str):
            self.activation = ACT2FN[config.hidden_act]
        else:
            self.activation = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.output_embedding_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states
class RemBertOnlyMLMHead(nn.Module):
    """Masked-LM head: exposes a ``RemBertLMPredictionHead`` under the attribute ``predictions``."""

    def __init__(self, config):
        super().__init__()
        self.predictions = RemBertLMPredictionHead(config)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        # Delegate straight to the prediction head.
        return self.predictions(sequence_output)
class RemBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Configuration class for RemBERT models.
    config_class = RemBertConfig
    # Loader used for original TensorFlow checkpoints.
    load_tf_weights = load_tf_weights_in_rembert
    # Attribute name under which the task heads store the base model (``self.rembert``).
    base_model_prefix = "rembert"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Normal(0, initializer_range) weights, zero bias.
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # The padding row is zeroed so the padding token embeds to the zero vector.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # Identity affine transform: zero bias, unit weight.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
# Shared class-level documentation prepended to every RemBERT model class via ``add_start_docstrings``.
REMBERT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`RemBertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Shared ``forward`` argument documentation. ``{0}`` is filled with the input shape via
# ``.format(...)`` (e.g. "batch_size, sequence_length"), so the text must contain no other braces.
REMBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    "The bare RemBERT Model transformer outputting raw hidden-states without any specific head on top.",
    REMBERT_START_DOCSTRING,
)
class RemBertModel(RemBertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = RemBertEmbeddings(config)
        self.encoder = RemBertEncoder(config)
        # Heads that only need per-token outputs (masked/causal LM, token classification) pass
        # add_pooling_layer=False; sequence-level heads read the pooled output instead.
        self.pooler = RemBertPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the word-embedding module of ``self.embeddings``."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module of ``self.embeddings``."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="google/rembert",
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Forward pass for the RemBERT model.

        Args:
            input_ids: Indices of input sequence tokens in the vocabulary.
            attention_mask: Mask to avoid performing attention on padding token indices.
            token_type_ids: Segment token indices to indicate first and second portions of the inputs.
            position_ids: Indices of positions of each input sequence tokens in the position embeddings.
            head_mask: Mask to nullify selected heads of the attention modules.
            inputs_embeds: Overrides the model's base input word embeddings if provided.
            encoder_hidden_states: Hidden states of the encoder to feed into the cross-attention layer.
            encoder_attention_mask: Mask to avoid performing attention on encoder hidden states.
            past_key_values: Cached key-value pairs for fast autoregressive decoding.
            use_cache: Whether or not to use the past key-value caches.
            output_attentions: Whether or not to return attentions weights.
            output_hidden_states: Whether or not to return hidden states.
            return_dict: Whether or not to return a dictionary as output.

        Returns:
            BaseModelOutputWithPastAndCrossAttentions: Model output.

        Notes:
            Args above are based on REMBERT_INPUTS_DOCSTRING for batch size and sequence length.
        """
        # NOTE(review): the forward body is elided in this listing, so as written this returns None.
        # Every task head in this file indexes the result (outputs[0], outputs[1],
        # outputs.hidden_states), so restore the upstream implementation before use.
        pass
@add_start_docstrings("""RemBERT Model with a `language modeling` head on top.""", REMBERT_START_DOCSTRING)
class RemBertForMaskedLM(RemBertPreTrainedModel):
    # Parameter key(s) the pretrained-model base class treats as tied weights.
    _tied_weights_keys = ["cls.predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            # Masked LM needs bidirectional attention; a decoder config would make it causal.
            logger.warning(
                "If you want to use `RemBertForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.rembert = RemBertModel(config, add_pooling_layer=False)
        self.cls = RemBertOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the vocabulary projection of the MLM head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the vocabulary projection of the MLM head."""
        self.cls.predictions.decoder = new_embeddings

    @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="google/rembert",
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.rembert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            # CrossEntropyLoss ignores index -100 by default, which implements the label masking.
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
        """Append one PAD "dummy" token (masked out) to each sequence for generation.

        Args:
            input_ids: ``(batch, seq)`` token ids.
            attention_mask: optional ``(batch, seq)`` mask; when None, every existing position is
                treated as attended (all ones).

        Returns:
            dict with the extended ``input_ids`` and ``attention_mask`` (one column longer).
        """
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
        # attention_mask defaults to None; without this the concatenation below would fail.
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)
        # The appended dummy position is masked out (0) so it is never attended to.
        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
        dummy_token = torch.full(
            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
        )
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {"input_ids": input_ids, "attention_mask": attention_mask}
@add_start_docstrings(
    """RemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", REMBERT_START_DOCSTRING
)
class RemBertForCausalLM(RemBertPreTrainedModel):
    # Parameter key(s) the pretrained-model base class treats as tied weights.
    _tied_weights_keys = ["cls.predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)

        if not config.is_decoder:
            logger.warning("If you want to use `RemBertForCausalLM` as a standalone, add `is_decoder=True.`")

        self.rembert = RemBertModel(config, add_pooling_layer=False)
        self.cls = RemBertOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the vocabulary projection of the LM head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the vocabulary projection of the LM head."""
        self.cls.predictions.decoder = new_embeddings

    @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Per-layer cached attention keys and values (self-attention pair first, then the cross-attention pair when
            present). Can be used to speed up decoding.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding.

        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.rembert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        lm_loss = None
        if labels is not None:
            # Next-token prediction: the score at position t is compared with the label at t + 1.
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
        """Build generation inputs, trimming tokens already covered by ``past_key_values``.

        When no ``attention_mask`` is given, an all-ones mask of ``input_ids``' shape is used.
        """
        input_shape = input_ids.shape

        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        if past_key_values is not None:
            # Cached length is the sequence axis (dim 2) of the first layer's key tensor.
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods pass only the new token(s); otherwise drop the cached prefix.
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default: keep only the final token.
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder the first two cached tensors of every layer along the batch axis to follow ``beam_idx``.

        Entries after the first two are passed through unchanged.
        """
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
                + layer_past[2:],
            )
        return reordered_past
@add_start_docstrings(
    """
    RemBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    REMBERT_START_DOCSTRING,
)
class RemBertForSequenceClassification(RemBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # Keeps the pooling layer: the classifier reads the pooled output (outputs[1]).
        self.rembert = RemBertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="google/rembert",
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.rembert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Infer (and remember on the config) the problem type from num_labels and the label dtype.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings(
    """
    RemBERT Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    REMBERT_START_DOCSTRING,
)
class RemBertForMultipleChoice(RemBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Keeps the pooling layer: each choice is scored from its pooled output.
        self.rembert = RemBertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        # One score per (example, choice) pair.
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="google/rembert",
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)

        Raises:
            ValueError: if neither `input_ids` nor `inputs_embeds` is given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if input_ids is None and inputs_embeds is None:
            # Otherwise the num_choices lookup below fails with an opaque AttributeError on None.
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Fold the choice axis into the batch axis: (batch, num_choices, ...) -> (batch * num_choices, ...).
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.rembert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # Unfold back to one row of choice scores per example.
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings(
    """
    RemBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    REMBERT_START_DOCSTRING,
)
class RemBertForTokenClassification(RemBertPreTrainedModel):
    """Token classification head: a linear classifier over each token's hidden state."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Per-token outputs only, so the pooling layer is not built.
        self.rembert = RemBertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        # Maps each hidden state to num_labels logits.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="google/rembert",
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.rembert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # Flatten (batch, seq) so every token is one classification example.
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings(
    """
    RemBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    REMBERT_START_DOCSTRING,
)
class RemBertForQuestionAnswering(RemBertPreTrainedModel):
    # Extractive QA head: a RemBERT encoder built without a pooling layer, followed by one linear
    # layer mapping each token's hidden state to `config.num_labels` scores. `forward` splits those
    # scores into exactly two tensors (start / end logits), so the unpacking there only succeeds when
    # `config.num_labels == 2`.
    # (Documented with comments rather than a class docstring so the decorator-built __doc__ is unchanged.)

    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels

        # Scores are computed per token from the sequence output, so the pooler is not built.
        self.rembert = RemBertModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing.
        self.post_init()

    @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="google/rembert",
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        # Fall back to the config setting when the caller does not choose an output format.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.rembert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # First encoder output: per-token hidden states.
        sequence_output = outputs[0]

        # Project to per-token scores, then split the last dimension into start and end logits
        # (unpacking into two names requires the projection to have exactly 2 output features).
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Labels may arrive with an extra trailing dimension (e.g. gathered across devices);
            # squeeze it so they match the 1-D target layout CrossEntropyLoss expects here.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)

            # Positions >= sequence length are clamped to `ignored_index` (== sequence length),
            # which the loss below ignores; negative positions are clamped to 0.
            # Out-of-place `clamp` so the caller's label tensors are left untouched.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            # Average the start and end span losses.
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            # Tuple form: (loss?, start_logits, end_logits, *extra encoder outputs from index 2 on).
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )