Transformers Source Code Analysis (109)
.\models\swinv2\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_swinv2": ["SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Swinv2Config"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_swinv2"] = [
"SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Swinv2ForImageClassification",
"Swinv2ForMaskedImageModeling",
"Swinv2Model",
"Swinv2PreTrainedModel",
"Swinv2Backbone",
]
if TYPE_CHECKING:
from .configuration_swinv2 import SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP, Swinv2Config
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_swinv2 import (
SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST,
Swinv2Backbone,
Swinv2ForImageClassification,
Swinv2ForMaskedImageModeling,
Swinv2Model,
Swinv2PreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
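A minimal usage sketch (not part of the file above, assuming `torch` is installed): because the module is registered through `_LazyModule`, the heavy `modeling_swinv2` import only happens when one of its attributes is first accessed.

from transformers import Swinv2Config, Swinv2Model  # resolved lazily via _import_structure

config = Swinv2Config()      # default Swin V2 configuration
print(config.image_size)     # 224 by default
model = Swinv2Model(config)  # constructing the model is what triggers the torch-backed import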
.\models\switch_transformers\configuration_switch_transformers.py
""" Switch Transformers model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
SWITCH_TRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google/switch-base-8": "https://huggingface.co/google/switch-base-8/blob/main/config.json",
}
class SwitchTransformersConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`SwitchTransformersModel`]. It is used to
instantiate a SwitchTransformers model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the
    SwitchTransformers [google/switch-base-8](https://huggingface.co/google/switch-base-8) architecture.
    """

    def __init__(
        self,
        ...
    ):
        # Initialize the Transformer parameters
        self.vocab_size = vocab_size  # vocabulary size
        self.d_model = d_model  # dimensionality of the Transformer model
        self.d_kv = d_kv  # dimensionality of the keys and values
        self.d_ff = d_ff  # hidden size of the feed-forward layers
        self.num_sparse_encoder_layers = num_sparse_encoder_layers  # number of sparse layers in the encoder
        self.num_layers = num_layers  # total number of layers
        self.num_decoder_layers = (
            num_decoder_layers if num_decoder_layers is not None else self.num_layers
        )  # number of decoder layers, defaults to the total number of layers
        self.num_sparse_decoder_layers = num_sparse_decoder_layers  # number of sparse layers in the decoder
        # How often a sparse layer occurs in the encoder
        if self.num_sparse_encoder_layers > 0:
            self.encoder_sparse_step = self.num_layers // self.num_sparse_encoder_layers
        else:
            self.encoder_sparse_step = self.num_layers  # if there are no sparse layers, the step equals the total number of layers, which creates 0 sparse layers
        # How often a sparse layer occurs in the decoder
        if self.num_sparse_decoder_layers > 0:
            self.decoder_sparse_step = self.num_decoder_layers // self.num_sparse_decoder_layers
        else:
            self.decoder_sparse_step = self.num_decoder_layers  # if there are no sparse layers, the step equals the total number of layers, which creates 0 sparse layers
        self.num_heads = num_heads  # number of attention heads
        self.num_experts = num_experts  # number of experts
        self.expert_capacity = expert_capacity  # capacity of each expert
        self.router_bias = router_bias  # whether the router uses a bias
        self.router_jitter_noise = router_jitter_noise  # jitter noise applied by the router
        if router_dtype not in ["float32", "float16", "bfloat16"]:
            raise ValueError(f"`router_dtype` must be one of 'float32', 'float16' or 'bfloat16', got {router_dtype}")
        self.router_dtype = router_dtype  # dtype used by the router
        self.router_ignore_padding_tokens = router_ignore_padding_tokens  # whether the router ignores padding tokens
        self.relative_attention_num_buckets = relative_attention_num_buckets  # number of buckets for relative attention
        self.relative_attention_max_distance = relative_attention_max_distance  # maximum distance for relative attention
        self.dropout_rate = dropout_rate  # dropout rate
        self.layer_norm_epsilon = layer_norm_epsilon  # epsilon used by layer normalization
        self.initializer_factor = initializer_factor  # initializer factor
        self.use_cache = use_cache  # whether to use the key/value cache
        self.add_router_probs = add_router_probs  # whether to return the router probabilities
        self.router_z_loss_coef = router_z_loss_coef  # coefficient of the router z-loss
        self.router_aux_loss_coef = router_aux_loss_coef  # coefficient of the router auxiliary (load balancing) loss
        self.dense_act_fn = dense_act_fn  # activation function of the dense layers
        super().__init__(
            pad_token_id=pad_token_id,  # id of the padding token
            eos_token_id=eos_token_id,  # id of the end-of-sequence token
            is_encoder_decoder=is_encoder_decoder,  # whether this is an encoder-decoder model
            **kwargs,  # remaining keyword arguments
        )
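A quick sanity check (not part of the configuration file above): the sparse step is simply the integer ratio of total layers to sparse layers, so a 12-layer encoder with 3 sparse layers places a sparse MLP every 4th layer.

from transformers import SwitchTransformersConfig

cfg = SwitchTransformersConfig(num_layers=12, num_sparse_encoder_layers=3)
print(cfg.encoder_sparse_step)  # 12 // 3 == 4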
.\models\switch_transformers\convert_big_switch.py
import argparse
import json
import os
import tensorstore as ts
import torch
from flax import serialization
from flax.traverse_util import flatten_dict, unflatten_dict
from tensorflow.io import gfile
from transformers.modeling_utils import dtype_byte_size
from transformers.models.switch_transformers.convert_switch_transformers_original_flax_checkpoint_to_pytorch import (
rename_keys,
)
from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME
from transformers.utils.hub import convert_file_size_to_int
def rename_base_flax_keys(flax_key_tuple, flax_tensor):
"""
    Rename the basic JAX/Flax keys so that they match the PyTorch naming.
"""
if flax_key_tuple[-1] == "kernel" and flax_tensor.ndim == 3:
flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
flax_tensor = torch.permute(flax_tensor, (0, 2, 1))
elif flax_key_tuple[-1] == "kernel" and ".".join(flax_key_tuple):
flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
flax_tensor = flax_tensor.T
elif flax_key_tuple[-1] in ["scale", "embedding"]:
flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
return flax_key_tuple, flax_tensor
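An illustrative call (not part of the original script), relying only on the function above: a 2-D Flax `kernel` is renamed to a PyTorch `weight` and transposed, while `scale`/`embedding` entries would only be renamed.

import torch

key, tensor = rename_base_flax_keys(("encoder", "dense", "kernel"), torch.ones(4, 3))
print(key, tuple(tensor.shape))  # ('encoder', 'dense', 'weight') (3, 4)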
def get_key_and_tensorstore_dict(layer, checkpoint_info, switch_checkpoint_path):
"""
    Get the real layer name, the split key and the TensorStore spec content for a checkpoint entry.
"""
if "metadata" in layer:
split_layer = layer.split("metadata")
curr_real_layer_name = "".join(split_layer[0])[:-1]
split_layer = [tuple(("metadata" + split_layer[1]).split("/"))]
elif "kvstore" in layer:
split_layer = layer.split("kvstore")
curr_real_layer_name = "".join(split_layer[0])[:-1]
split_layer = [tuple(("kvstore" + split_layer[1]).split("/"))]
else:
split_layer = layer.split("/")
curr_real_layer_name = "/".join(split_layer[:-1])
split_layer[-1] = (split_layer[-1],)
if "kvstore/path" in layer:
content = f"{switch_checkpoint_path}/{checkpoint_info[layer]}"
elif "kvstore/driver" in layer:
content = "file"
else:
content = checkpoint_info[layer]
return curr_real_layer_name, split_layer, content
def rename_and_save_block(current_block, save_path):
"""
    Rename the keys of the current block and save it to disk.
"""
current_block = rename_keys(current_block)
new_current_block = {}
for k, v in current_block.items():
new_current_block[k.replace("/", ".")] = v
current_block = new_current_block
torch.save(current_block, save_path)
def shard_on_the_fly(switch_checkpoint_path, dump_path, max_shard_size, dtype, weights_name: str = WEIGHTS_NAME):
"""
    Shard the checkpoint on the fly.
"""
max_shard_size = convert_file_size_to_int(max_shard_size)
sharded_state_dicts = []
current_block = {}
current_block_size = 0
total_size = 0
os.makedirs(dump_path, exist_ok=True)
with gfile.GFile(switch_checkpoint_path + "/checkpoint", "rb") as fp:
checkpoint_info = serialization.msgpack_restore(fp.read())["optimizer"]["target"]
checkpoint_info = flatten_dict(checkpoint_info, sep="/")
all_layers = {}
for layer in checkpoint_info.keys():
curr_real_layer_name, split_layer, content = get_key_and_tensorstore_dict(
layer, checkpoint_info, switch_checkpoint_path
)
if curr_real_layer_name in all_layers:
all_layers[curr_real_layer_name][split_layer[-1]] = content
else:
all_layers[curr_real_layer_name] = {split_layer[-1]: content}
for key in all_layers.keys():
raw_weights = ts.open(unflatten_dict(all_layers[key])).result().read().result()
raw_weights = torch.tensor(raw_weights)
weight_size = raw_weights.numel() * dtype_byte_size(raw_weights.dtype)
key, raw_weights = rename_base_flax_keys(tuple(key.split("/")), raw_weights)
key = "/".join(key)
if current_block_size + weight_size > max_shard_size:
save_path = os.path.join(
dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts)+1:05d}-of-???.bin")
)
rename_and_save_block(current_block, save_path)
sharded_state_dicts.append(current_block.keys())
del current_block
current_block = {}
current_block_size = 0
current_block[key] = raw_weights.to(getattr(torch, dtype))
current_block_size += weight_size
total_size += weight_size
save_path = os.path.join(dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts)+1:05d}-of-???.bin"))
rename_and_save_block(current_block, save_path)
sharded_state_dicts.append(current_block.keys())
if len(sharded_state_dicts) == 1:
return {weights_name: sharded_state_dicts[0]}, None
weight_map = {}
shards = {}
for idx, shard in enumerate(sharded_state_dicts):
shard_file = weights_name.replace(
".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin"
)
temp_filename = os.path.join(dump_path, weights_name.replace(".bin", f"-{idx+1:05d}-of-???.bin"))
os.rename(temp_filename, os.path.join(dump_path, shard_file))
shards[shard_file] = shard
for key in shard:
weight_map[key] = shard_file
metadata = {"total_size": total_size}
index = {"metadata": metadata, "weight_map": weight_map}
with open(os.path.join(dump_path, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
return metadata, index
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--switch_t5x_checkpoint_path",
default="/mnt/disks/disk_switch/original_checkpoints/switch-xxl-128/checkpoint_634600",
type=str,
required=False,
help="Path to a directory containing a folder per layer. Follows the original Google format.",
)
parser.add_argument("--max_shard_size", default="10GB", required=False, help="Max shard size")
parser.add_argument("--dtype", default="bfloat16", type=str, required=False, help="dtype of the saved model")
parser.add_argument(
"--pytorch_dump_folder_path",
default="/mnt/disks/disk_switch/original_checkpoints/switch-xxl-128-converted",
type=str,
required=False,
help="Path to the output pytorch model.",
)
args = parser.parse_args()
shard_on_the_fly(
args.switch_t5x_checkpoint_path,
args.pytorch_dump_folder_path,
args.max_shard_size,
args.dtype,
)
def sanity_check():
from transformers import SwitchTransformersConfig, SwitchTransformersForConditionalGeneration, T5Tokenizer
config = SwitchTransformersConfig.from_pretrained("google/switch-base-8")
config.save_pretrained("/home/arthur_huggingface_co/transformers/switch_converted")
model = SwitchTransformersForConditionalGeneration.from_pretrained(
"/home/arthur_huggingface_co/transformers/switch_converted", device_map="auto"
)
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
text = "A <extra_id_0> walks into a bar a orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
input_ids = tokenizer(text, return_tensors="pt").input_ids
out = model.generate(input_ids, decoder_start_token_id=0)
print(tokenizer.decode(out[0]))
.\models\switch_transformers\convert_switch_transformers_original_flax_checkpoint_to_pytorch.py
"""将 SwitchTransformersX 仓库的检查点转换为 JAX/FLAX 模型。"""
import argparse
import re
from flax.traverse_util import flatten_dict, unflatten_dict
from t5x import checkpoints
from transformers import SwitchTransformersConfig, SwitchTransformersForConditionalGeneration
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
from transformers.utils import logging
logging.set_verbosity_info()
MOE_LAYER_NAME_MAPPING = {
"/attention/": "/0/SelfAttention/",
"/self_attention/": "/0/SelfAttention/",
"/encoder_decoder_attention/": "/1/EncDecAttention/",
"value": "v",
"query": "q",
"key": "k",
"out": "o",
"pre_self_attention_layer_norm": "0/layer_norm",
"pre_cross_attention_layer_norm": "1/layer_norm",
"pre_attention_layer_norm": "0/layer_norm",
"token_embedder": "shared",
"encoder_norm": "final_layer_norm",
"decoder_norm": "final_layer_norm",
"relpos_bias/rel_embedding": "block/0/layer/0/SelfAttention/relative_attention_bias/weight",
"router/router_weights/w/": "router/classifier/",
"roer/roer_weights/w/": "router/classifier/",
"logits_dense": "lm_head",
}
def rename_keys(s_dict):
keys = list(s_dict.keys())
for key in keys:
layer_to_block_of_layer = r".*/layers_(\d+)"
new_key = key
if re.match(layer_to_block_of_layer, key):
new_key = re.sub(r"layers_(\d+)", r"block/\1/layer", new_key)
layer_to_block_of_layer = r"(encoder|decoder)\/"
if re.match(layer_to_block_of_layer, key):
groups = re.match(layer_to_block_of_layer, new_key).groups()
if groups[0] == "encoder":
new_key = re.sub(r"/mlp/", r"/1/mlp/", new_key)
new_key = re.sub(r"/pre_mlp_layer_norm/", r"/1/layer_norm/", new_key)
elif groups[0] == "decoder":
new_key = re.sub(r"/mlp/", r"/2/mlp/", new_key)
new_key = re.sub(r"/pre_mlp_layer_norm/", r"/2/layer_norm/", new_key)
for old_key, temp_key in MOE_LAYER_NAME_MAPPING.items():
if old_key in new_key:
new_key = new_key.replace(old_key, temp_key)
print(f"{key} -> {new_key}")
s_dict[new_key] = s_dict.pop(key)
if "encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" in s_dict:
s_dict["encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"] = s_dict[
"encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"
].T
if "decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" in s_dict:
s_dict["decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"] = s_dict[
"decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"
].T
for key in list(s_dict.keys()):
if "expert" in key:
num_experts = s_dict[key].shape[0]
expert_weights = s_dict[key]
for idx in range(num_experts):
s_dict[key.replace("expert/", f"experts/expert_{idx}/")] = expert_weights[idx]
print(f"{key} -> {key.replace('expert/', f'experts/expert_{idx}/')}")
s_dict.pop(key)
return s_dict
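An illustrative call of `rename_keys` (not part of the original script): a single T5X-style parameter name is remapped to the Transformers naming scheme.

sample = {"encoder/layers_0/attention/key/kernel": None}
print(rename_keys(sample))
# {'encoder/block/0/layer/0/SelfAttention/k/kernel': None}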
GIN_TO_CONFIG_MAPPING = {
"NUM_ENCODER_LAYERS": "num_layers",
"NUM_DECODER_LAYERS": "num_decoder_layers",
"NUM_HEADS": "num_heads",
"HEAD_DIM": "d_kv",
"EMBED_DIM": "d_model",
"MLP_DIM": "d_ff",
"NUM_SELECTED_EXPERTS": "num_selected_experts",
"NUM_ENCODER_SPARSE_LAYERS": "num_sparse_encoder_layers",
"NUM_DECODER_SPARSE_LAYERS": "num_sparse_decoder_layers",
"dense.MlpBlock.activations": "feed_forward_proj",
}
def convert_gin_to_config(gin_file, num_experts):
import regex as re
with open(gin_file, "r") as f:
raw_gin = f.read()
regex_match = re.findall(r"(.*) = ([0-9.]*)", raw_gin)
args = {}
for param, value in regex_match:
if param in GIN_TO_CONFIG_MAPPING and value != "":
args[GIN_TO_CONFIG_MAPPING[param]] = float(value) if "." in value else int(value)
activation = re.findall(r"(.*activations) = \(\'(.*)\',\)", raw_gin)[0]
args[GIN_TO_CONFIG_MAPPING[activation[0]]] = str(activation[1])
args["num_experts"] = num_experts
config = SwitchTransformersConfig(**args)
return config
def convert_flax_checkpoint_to_pytorch(
flax_checkpoint_path, config_file, gin_file=None, pytorch_dump_path="./", num_experts=8
):
print(f"Loading flax weights from : {flax_checkpoint_path}")
flax_params = checkpoints.load_t5x_checkpoint(flax_checkpoint_path)
if gin_file is not None:
config = convert_gin_to_config(gin_file, num_experts)
else:
config = SwitchTransformersConfig.from_pretrained(config_file)
pt_model = SwitchTransformersForConditionalGeneration(config)
flax_params = flax_params["target"]
flax_params = flatten_dict(flax_params, sep="/")
flax_params = rename_keys(flax_params)
flax_params = unflatten_dict(flax_params, sep="/")
load_flax_weights_in_pytorch_model(pt_model, flax_params)
print(f"Save PyTorch model to {pytorch_dump_path}")
pt_model.save_pretrained(pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--switch_t5x_checkpoint_path",
default=None,
type=str,
required=True,
help=(
"The config json file corresponding to the pre-trained SwitchTransformers model. \nThis specifies the"
" model architecture. If not provided, a `gin_file` has to be provided."
),
)
parser.add_argument(
"--gin_file",
default=None,
type=str,
required=False,
help="Path to the gin config file. If not provided, a `config_file` has to be passed ",
)
parser.add_argument(
"--config_name", default=None, type=str, required=False, help="Config name of SwitchTransformers model."
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output pytorch model."
)
parser.add_argument("--num_experts", default=8, type=int, required=False, help="Number of experts")
args = parser.parse_args()
convert_flax_checkpoint_to_pytorch(
args.switch_t5x_checkpoint_path,
args.config_name,
args.gin_file,
args.pytorch_dump_folder_path,
args.num_experts,
)
.\models\switch_transformers\modeling_switch_transformers.py
def router_z_loss_func(router_logits: torch.Tensor) -> float:
    r"""
    Compute the router z-loss, implemented in PyTorch.

    The z-loss encourages the router logits to remain small in order to improve training stability.

    Args:
        router_logits (`torch.Tensor`):
            Input logits of shape [batch_size, sequence_length, num_experts].

    Returns:
        float: the scalar router z-loss.
    """
    num_groups, tokens_per_group, _ = router_logits.shape
    log_z = torch.logsumexp(router_logits, dim=-1)
    z_loss = log_z**2
    return torch.sum(z_loss) / (num_groups * tokens_per_group)


def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float:
    r"""
    Compute the auxiliary load balancing loss used by the Switch Transformers, implemented in PyTorch.

    The loss pushes the router towards a uniform assignment of tokens to experts so that no expert is overloaded.

    Args:
        router_probs (`torch.Tensor`):
            Probability tensor of shape [batch_size, sequence_length, num_experts]: the probability of selecting
            each expert for each token.
        expert_indices (`torch.Tensor`):
            Integer tensor of shape [batch_size, sequence_length]: the expert selected for each token.

    Returns:
        float: the scalar load balancing loss.
    """
    num_experts = router_probs.shape[-1]
    # Cast the expert indices to int64, otherwise one-hot encoding fails
    if expert_indices.dtype != torch.int64:
        expert_indices = expert_indices.to(torch.int64)
    if len(expert_indices.shape) == 2:
        expert_indices = expert_indices.unsqueeze(2)
    expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts)
    # For a given token, determine whether it was routed to a given expert
    expert_mask = torch.max(expert_mask, axis=-2).values
    # Cast to float32, otherwise the mean fails
    expert_mask = expert_mask.to(torch.float32)
    tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2)
    router_prob_per_group_and_expert = torch.mean(router_probs, axis=-2)
    return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2)
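A toy check (not part of the modeling file): both auxiliary losses evaluated on random router outputs for 2 groups of 4 tokens and 3 experts.

import torch

router_logits = torch.randn(2, 4, 3)
router_probs = torch.softmax(router_logits, dim=-1)
expert_indices = torch.argmax(router_probs, dim=-1)
print(router_z_loss_func(router_logits), load_balancing_loss_func(router_probs, expert_indices))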
class SwitchTransformersTop1Router(nn.Module):
"""
Router using tokens choose top-1 experts assignment.
This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE
(https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then
routed to their choice of expert until the expert's expert_capacity is reached. **There is no guarantee that each
token is processed by an expert**, or that each expert receives at least one token.
"""
def __init__(self, config: SwitchTransformersConfig):
super().__init__()
self.num_experts = config.num_experts
self.expert_capacity = config.expert_capacity
self.classifier = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias)
self.jitter_noise = config.router_jitter_noise
self.ignore_padding_tokens = config.router_ignore_padding_tokens
self.dtype = getattr(torch, config.router_dtype)
def _compute_router_probabilities(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Computes router probabilities from input hidden states.
Args:
hidden_states (`torch.Tensor`):
(batch_size, sequence_length, hidden_dim) from which router probabilities are computed.
Returns:
router_probabilities (`torch.Tensor`):
Tensor of shape (batch_size, sequence_length, num_experts) corresponding to the probabilities for each
token and expert. Used for routing tokens to experts.
router_logits (`torch.Tensor`):
Logits tensor of shape (batch_size, sequence_length, num_experts) corresponding to raw router logits.
This is used later for computing router z-loss.
"""
self.input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(self.dtype)
if self.training and self.jitter_noise > 0:
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
self._cast_classifier()
router_logits = self.classifier(hidden_states)
router_probabilities = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to(self.input_dtype)
return router_probabilities, router_logits
def _cast_classifier(self):
r"""
        `bitsandbytes` `Linear8bitLt` layers do not support manual casting. Therefore we need to check whether the
        classifier is an instance of the `Linear8bitLt` class by checking its special attributes.
"""
if not (hasattr(self.classifier, "SCB") or hasattr(self.classifier, "CB")):
self.classifier = self.classifier.to(self.dtype)
def forward(self, hidden_states: torch.Tensor) -> Tuple:
r"""
        Generic forward function for every Router class. Each Router expects the same input hidden states
        (`hidden_states`), one vector per token, and an `expert_capacity` corresponding to the maximum number of
        tokens the Router will send to each expert; some Routers can send a few tokens to each expert, others only
        one. Each Router computes `router_probs` and `router_logits` from the router weights, which assign each token
        a raw probability of being routed to each expert. Each Router class then has to define its own
        `_compute_routing_instructions`.
Args:
hidden_states (`torch.Tensor`) :
[num_groups, tokens_per_group, hidden_dim] inputs to send to experts.
Returns:
Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`] Tuple containing the expert index, the router probs
and the router logits. The router probabilities and logits are required to compute the loss.
"""
router_probs, router_logits = self._compute_router_probabilities(hidden_states)
expert_index = torch.argmax(router_probs, dim=-1)
expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts)
token_priority = torch.cumsum(expert_index, dim=-2)
expert_capacity_mask = token_priority <= self.expert_capacity
expert_index = expert_index * expert_capacity_mask
router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1)
return expert_index, router_probs, router_logits
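A minimal sketch (not part of the modeling file): routing 2 sequences of 5 tokens through a top-1 router with 4 experts, each accepting at most `expert_capacity` tokens; the shapes below follow from the forward pass above.

import torch
from transformers import SwitchTransformersConfig, SwitchTransformersTop1Router

cfg = SwitchTransformersConfig(d_model=16, num_experts=4, expert_capacity=2)
router = SwitchTransformersTop1Router(cfg)
expert_index, router_probs, router_logits = router(torch.randn(2, 5, 16))
print(expert_index.shape, router_probs.shape, router_logits.shape)
# torch.Size([2, 5, 4]) torch.Size([2, 5, 1]) torch.Size([2, 5, 4])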
class SwitchTransformersLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
        Construct a layer normalization module in the SwitchTransformers style: no bias and no subtraction of the mean.
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
"""
        Run the forward pass, applying SwitchTransformers-style layer normalization to `hidden_states`.
        The mean is never subtracted; the activations are scaled by their root mean square (RMS norm).
"""
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
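A short check (not part of the modeling file) that the class above is an RMS norm: the activations are scaled by the reciprocal root mean square, with no mean subtraction and no bias.

import torch

x = torch.randn(2, 3, 8)
norm = SwitchTransformersLayerNorm(8)
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
print(torch.allclose(norm(x), manual))  # True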
class SwitchTransformersDenseActDense(nn.Module):
def __init__(self, config: SwitchTransformersConfig):
"""
        Initialize the dense-activation-dense block: an input projection, an activation function, dropout and an output projection.
"""
super().__init__()
self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
self.act = ACT2FN[config.dense_act_fn]
def forward(self, hidden_states):
"""
        Run the forward pass: input projection, activation, dropout, then output projection.
        The hidden states are cast to the dtype of the output weight when the dtypes differ (skipped for quantized int8 weights).
"""
hidden_states = self.wi(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.dropout(hidden_states)
if (
isinstance(self.wo.weight, torch.Tensor)
and hidden_states.dtype != self.wo.weight.dtype
and self.wo.weight.dtype != torch.int8
):
hidden_states = hidden_states.to(self.wo.weight.dtype)
hidden_states = self.wo(hidden_states)
return hidden_states
class SwitchTransformersSparseMLP(nn.Module):
"""
    Implements the Switch Transformers sparse MLP (Mixture of Experts) module, made of a router and a set of experts.
"""
def __init__(self, config: SwitchTransformersConfig, expert_class: nn.Module = SwitchTransformersDenseActDense):
"""
        Build the sparse MLP from the config: a top-1 router and `num_experts` expert modules of type `expert_class`.
"""
super().__init__()
self.router = SwitchTransformersTop1Router(config)
self.experts = nn.ModuleDict()
for idx in range(config.num_experts):
expert_name = f"expert_{idx}"
self.experts[expert_name] = expert_class(config)
def forward(self, hidden_states):
r"""
        Hold on, this will be slightly tricky to understand. In the correct order, a MoE layer does the following:

        1. Gets the `router_mask` from the router. The shape of the mask is `(batch_size, sequence_length, num_experts)`
           and corresponds to the argmax of the `router_probs`. The probabilities are needed in the computation of the
           hidden states: they are broadcast to the hidden state values (they can be interpreted as a scaling factor).

        2. Dispatches the tokens to their associated experts. We do a classic for loop over the experts and assign the
           corresponding hidden states to each expert.
"""
router_mask, router_probs, router_logits = self.router(hidden_states)
expert_index = torch.argmax(router_mask, dim=-1)
next_states = hidden_states.clone()
for idx, expert in enumerate(self.experts.values()):
token_indices = router_mask[:, :, idx].bool()
next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype)
hidden_states = router_probs * next_states
return hidden_states, (router_logits, expert_index)
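A minimal sketch (not part of the modeling file): one sparse MLP forward pass; the second return value carries the router logits and the expert chosen for each token.

import torch
from transformers import SwitchTransformersConfig, SwitchTransformersSparseMLP

cfg = SwitchTransformersConfig(d_model=16, d_ff=32, num_experts=4, expert_capacity=8)
mlp = SwitchTransformersSparseMLP(cfg)
out, (router_logits, expert_index) = mlp(torch.randn(2, 5, 16))
print(out.shape, router_logits.shape, expert_index.shape)
# torch.Size([2, 5, 16]) torch.Size([2, 5, 4]) torch.Size([2, 5])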
class SwitchTransformersAttention(nn.Module):
"""
Switch Transformers Attention module, based on PyTorch's nn.Module.
This module is responsible for handling the attention mechanism within Switch Transformers.
"""
def __init__(self, config: SwitchTransformersConfig, has_relative_attention_bias=False):
super().__init__()
self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias
self.relative_attention_num_buckets = config.relative_attention_num_buckets
self.relative_attention_max_distance = config.relative_attention_max_distance
self.d_model = config.d_model
self.key_value_proj_dim = config.d_kv
self.n_heads = config.num_heads
self.dropout = config.dropout_rate
self.inner_dim = self.n_heads * self.key_value_proj_dim
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
self.pruned_heads = set()
self.gradient_checkpointing = False
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
)
self.q = prune_linear_layer(self.q, index)
self.k = prune_linear_layer(self.k, index)
self.v = prune_linear_layer(self.v, index)
self.o = prune_linear_layer(self.o, index, dim=1)
self.n_heads = self.n_heads - len(heads)
self.inner_dim = self.key_value_proj_dim * self.n_heads
self.pruned_heads = self.pruned_heads.union(heads)
    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
"""
Adapted from Mesh Tensorflow:
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
Translate relative position to a bucket number for relative attention. The relative position is defined as
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the model has been trained on
Args:
relative_position: an int32 Tensor - the relative position between memory and query positions
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer - number of buckets to categorize relative positions
max_distance: an integer - maximum distance considered for bucketing
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
max_exact = num_buckets // 2
is_small = relative_position < max_exact
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_position_if_large = torch.min(
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
def compute_bias(self, query_length, key_length, device=None):
"""Compute binned relative position bias"""
if device is None:
device = self.relative_attention_bias.weight.device
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position
relative_position_bucket = self._relative_position_bucket(
relative_position,
bidirectional=(not self.is_decoder),
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(relative_position_bucket)
values = values.permute([2, 0, 1]).unsqueeze(0)
return values
def forward(
self,
hidden_states,
mask=None,
key_value_states=None,
position_bias=None,
past_key_value=None,
layer_head_mask=None,
query_length=None,
use_cache=False,
        output_attentions=False,
    ):
        ...
class SwitchTransformersLayerSelfAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.SelfAttention = SwitchTransformersAttention(
config, has_relative_attention_bias=has_relative_attention_bias
)
self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
hidden_states,
attention_mask=None,
position_bias=None,
layer_head_mask=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
):
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention(
normed_hidden_states,
mask=attention_mask,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = hidden_states + self.dropout(attention_output[0])
outputs = (hidden_states,) + attention_output[1:]
return outputs
class SwitchTransformersLayerCrossAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.EncDecAttention = SwitchTransformersAttention(config, has_relative_attention_bias=False)
self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
hidden_states,
key_value_states,
attention_mask=None,
position_bias=None,
layer_head_mask=None,
past_key_value=None,
use_cache=False,
query_length=None,
output_attentions=False,
):
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention(
normed_hidden_states,
mask=attention_mask,
key_value_states=key_value_states,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
query_length=query_length,
output_attentions=output_attentions,
)
layer_output = hidden_states + self.dropout(attention_output[0])
outputs = (layer_output,) + attention_output[1:]
return outputs
class SwitchTransformersBlock(nn.Module):
def __init__(self, config, has_relative_attention_bias=False, is_sparse=False):
super().__init__()
self.is_decoder = config.is_decoder
self.is_sparse = is_sparse
self.layer = nn.ModuleList()
self.layer.append(
SwitchTransformersLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)
)
if self.is_decoder:
self.layer.append(SwitchTransformersLayerCrossAttention(config))
self.layer.append(SwitchTransformersLayerFF(config, is_sparse=self.is_sparse))
def forward(
self,
hidden_states,
attention_mask=None,
position_bias=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
layer_head_mask=None,
cross_attn_layer_head_mask=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
output_router_logits=True,
        return_dict=True,
    ):
        ...
class SwitchTransformersPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = SwitchTransformersConfig
base_model_prefix = "switch_transformers"
supports_gradient_checkpointing = True
_no_split_modules = ["SwitchTransformersBlock"]
@property
def dummy_inputs(self):
input_ids = torch.tensor(DUMMY_INPUTS)
input_mask = torch.tensor(DUMMY_MASK)
dummy_inputs = {
"decoder_input_ids": input_ids,
"input_ids": input_ids,
"decoder_attention_mask": input_mask,
}
return dummy_inputs
def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id
pad_token_id = self.config.pad_token_id
if decoder_start_token_id is None:
raise ValueError(
"self.model.config.decoder_start_token_id has to be defined. In SwitchTransformers it is usually set"
" to the pad_token_id. See SwitchTransformers docs for more information"
)
if is_torch_fx_proxy(input_ids):
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
else:
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError("self.model.config.pad_token_id has to be defined.")
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
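A hand-rolled illustration of what `_shift_right` produces (not part of the modeling file), assuming `decoder_start_token_id == pad_token_id == 0` as in the released checkpoints.

import torch

labels = torch.tensor([[42, 43, 44, -100]])
shifted = labels.new_zeros(labels.shape)
shifted[..., 1:] = labels[..., :-1].clone()
shifted[..., 0] = 0                        # decoder_start_token_id
shifted.masked_fill_(shifted == -100, 0)   # replace masked label positions with pad_token_id
print(shifted)                             # tensor([[ 0, 42, 43, 44]])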
class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
def __init__(self, config, embed_tokens=None):
super().__init__(config)
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
self.is_decoder = config.is_decoder
sparse_step = config.decoder_sparse_step if self.is_decoder else config.encoder_sparse_step
config.num_layers = config.num_decoder_layers if self.is_decoder else config.num_layers
self.block = nn.ModuleList()
for i in range(config.num_layers):
is_sparse = (i % sparse_step == 1 or sparse_step == 1) if sparse_step > 0 else False
self.block.append(
SwitchTransformersBlock(config, has_relative_attention_bias=bool(i == 0), is_sparse=is_sparse)
)
self.final_layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
self.post_init()
self.device_map = None
self.gradient_checkpointing = False
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, new_embeddings):
self.embed_tokens = new_embeddings
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
inputs_embeds=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
output_router_logits=True,
        return_dict=None,
    ):
        ...
SWITCH_TRANSFORMERS_START_DOCSTRING = r"""
The SWITCH_TRANSFORMERS model was proposed in [Switch Transformers: Scaling to Trillion Parameter Models with
Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) by [William
Fedus](https://arxiv.org/search/cs?searchtype=author&query=Fedus%2C+W), [Barret
Zoph](https://arxiv.org/search/cs?searchtype=author&query=Zoph%2C+B), and [Noam
Shazeer](https://arxiv.org/search/cs?searchtype=author&query=Shazeer%2C+N). It's an encoder-decoder T5-like model
with sparse Feed Forward that stands for Mixture of Experts (MoE) architecture.
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`SwitchTransformersConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
SWITCH_TRANSFORMERS_INPUTS_DOCSTRING = r"""
"""
SWITCH_TRANSFORMERS_ENCODER_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. SWITCH_TRANSFORMERS is a model with relative position
            embeddings so you should be able to pad the inputs on both the right and the left.
            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
            To know more on how to prepare `input_ids` for pretraining, take a look at [SWITCH_TRANSFORMERS
            Training](./switch_transformers#training).
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attention tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        output_router_logits (`bool`, *optional*):
            Whether or not to return the logits of all the routers. They are useful for computing the router loss and
            should not be returned during inference.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""
@add_start_docstrings(
"The bare SWITCH_TRANSFORMERS Model transformer outputting raw hidden-states without any specific head on top.",
SWITCH_TRANSFORMERS_START_DOCSTRING,
)
class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: SwitchTransformersConfig):
super().__init__(config)
self.shared = nn.Embedding(config.vocab_size, config.d_model)
encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
self.encoder = SwitchTransformersStack(encoder_config, self.shared)
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
decoder_config.is_encoder_decoder = False
self.decoder = SwitchTransformersStack(decoder_config, self.shared)
self.post_init()
self.device_map = None
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(SWITCH_TRANSFORMERS_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqMoEModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
decoder_inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        ...
@add_start_docstrings(
"""SWITCH_TRANSFORMERS Model with a `language modeling` head on top.""", SWITCH_TRANSFORMERS_START_DOCSTRING
)
class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel):
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: SwitchTransformersConfig):
super().__init__(config)
self.model_dim = config.d_model
self.shared = nn.Embedding(config.vocab_size, config.d_model)
encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
self.encoder = SwitchTransformersStack(encoder_config, self.shared)
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
decoder_config.is_encoder_decoder = False
decoder_config.num_layers = config.num_decoder_layers
self.decoder = SwitchTransformersStack(decoder_config, self.shared)
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
self.router_z_loss_coef = config.router_z_loss_coef
self.router_aux_loss_coef = config.router_aux_loss_coef
self.post_init()
self.device_map = None
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def get_output_embeddings(self):
return self.lm_head
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
@add_start_docstrings_to_model_forward(SWITCH_TRANSFORMERS_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqMoEOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = True,
return_dict: Optional[bool] = None,
):
...
def _unpack_router_logits(self, router_outputs):
total_router_logits = []
total_expert_indexes = []
for router_output in router_outputs:
if len(router_output[0].shape) > 1:
router_logits, expert_indexes = router_output
total_router_logits.append(router_logits)
total_expert_indexes.append(expert_indexes)
return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
return {
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
}
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return self._shift_right(labels)
def _reorder_cache(self, past_key_values, beam_idx):
if past_key_values is None:
logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
return past_key_values
reordered_decoder_past = ()
for layer_past_states in past_key_values:
reordered_layer_past_states = ()
for layer_past_state in layer_past_states:
reordered_layer_past_states = reordered_layer_past_states + (
layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
)
if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
raise ValueError(
"expected reordered_layer_past_states to have the same shape than layer_past_states, "
f"but got {reordered_layer_past_states[0].shape} and {layer_past_states[0].shape}"
)
if len(reordered_layer_past_states) != len(layer_past_states):
raise ValueError(
"expected layer_past_states to have the same length as reordered_layer_past_states, "
f"but got {len(layer_past_states)} and {len(reordered_layer_past_states)}"
)
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
return reordered_decoder_past
@add_start_docstrings(
"The bare SWITCH_TRANSFORMERS Model transformer outputting encoder's raw hidden-states without any specific head"
" on top.",
SWITCH_TRANSFORMERS_START_DOCSTRING,
)
class SwitchTransformersEncoderModel(SwitchTransformersPreTrainedModel):
_tied_weights_keys = ["encoder.embed_tokens.weight"]
def __init__(self, config: SwitchTransformersConfig):
super().__init__(config)
self.shared = nn.Embedding(config.vocab_size, config.d_model)
encoder_config = copy.deepcopy(config)
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
self.encoder = SwitchTransformersStack(encoder_config, self.shared)
self.post_init()
self.device_map = None
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
def get_encoder(self):
return self.encoder
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
@add_start_docstrings_to_model_forward(SWITCH_TRANSFORMERS_ENCODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MoEModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = True,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], MoEModelOutput]:
r"""
Returns the encoder outputs based on the given inputs and optional configurations.
Example:
```
>>> from transformers import AutoTokenizer, SwitchTransformersEncoderModel
>>> tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
>>> model = SwitchTransformersEncoderModel.from_pretrained("google/switch-base-8")
>>> input_ids = tokenizer(
... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
... ).input_ids # Batch size 1
>>> outputs = model(input_ids=input_ids)
>>> last_hidden_states = outputs.last_hidden_state
```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
return_dict=return_dict,
)
return encoder_outputs
.\models\switch_transformers\__init__.py
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)
_import_structure = {
"configuration_switch_transformers": [
"SWITCH_TRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP",
"SwitchTransformersConfig",
"SwitchTransformersOnnxConfig",
]
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_switch_transformers"] = [
"SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST",
"SwitchTransformersEncoderModel",
"SwitchTransformersForConditionalGeneration",
"SwitchTransformersModel",
"SwitchTransformersPreTrainedModel",
"SwitchTransformersTop1Router",
"SwitchTransformersSparseMLP",
]
if TYPE_CHECKING:
from .configuration_switch_transformers import (
SWITCH_TRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP,
SwitchTransformersConfig,
SwitchTransformersOnnxConfig,
)
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_switch_transformers import (
SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST,
SwitchTransformersEncoderModel,
SwitchTransformersForConditionalGeneration,
SwitchTransformersModel,
SwitchTransformersPreTrainedModel,
SwitchTransformersSparseMLP,
SwitchTransformersTop1Router,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\t5\configuration_t5.py
""" T5 模型配置 """
from typing import Mapping
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxSeq2SeqConfigWithPast
from ...utils import logging
logger = logging.get_logger(__name__)
T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google-t5/t5-small": "https://huggingface.co/google-t5/t5-small/resolve/main/config.json",
"google-t5/t5-base": "https://huggingface.co/google-t5/t5-base/resolve/main/config.json",
"google-t5/t5-large": "https://huggingface.co/google-t5/t5-large/resolve/main/config.json",
"google-t5/t5-3b": "https://huggingface.co/google-t5/t5-3b/resolve/main/config.json",
"google-t5/t5-11b": "https://huggingface.co/google-t5/t5-11b/resolve/main/config.json",
}
class T5Config(PretrainedConfig):
r"""
    This is the configuration class to store the configuration of a [`T5Model`] or a [`TFT5Model`]. It is used to
    instantiate a T5 model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the T5
    [google-t5/t5-small](https://huggingface.co/google-t5/t5-small) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    """

    # The model type is set to "t5"
    model_type = "t5"
    # Keys ignored at inference time, here "past_key_values"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Attribute map from the generic attribute names to the model-specific ones
    attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
    # Initialization method for the various parameters and settings of the Transformer model
    def __init__(
        self,
        vocab_size=32128,  # vocabulary size, defaults to 32128
        d_model=512,  # model dimensionality, defaults to 512
        d_kv=64,  # dimensionality of the keys and values, defaults to 64
        d_ff=2048,  # dimensionality of the feed-forward layer, defaults to 2048
        num_layers=6,  # number of Transformer layers, defaults to 6
        num_decoder_layers=None,  # number of decoder layers, defaults to num_layers for symmetry
        num_heads=8,  # number of attention heads, defaults to 8
        relative_attention_num_buckets=32,  # number of buckets for relative attention, defaults to 32
        relative_attention_max_distance=128,  # maximum distance for relative attention, defaults to 128
        dropout_rate=0.1,  # dropout rate, defaults to 0.1
        layer_norm_epsilon=1e-6,  # epsilon used by layer normalization, defaults to 1e-6
        initializer_factor=1.0,  # initializer factor, defaults to 1.0
        feed_forward_proj="relu",  # activation of the feed-forward projection, defaults to 'relu'
        is_encoder_decoder=True,  # whether this is an encoder-decoder model, defaults to True
        use_cache=True,  # whether to use the key/value cache, defaults to True
        pad_token_id=0,  # id of the padding token, defaults to 0
        eos_token_id=1,  # id of the end-of-sequence token, defaults to 1
        classifier_dropout=0.0,  # dropout rate of the classifier, defaults to 0.0
        **kwargs,  # remaining keyword arguments
    ):
        self.vocab_size = vocab_size  # vocabulary size
        self.d_model = d_model  # model dimensionality
        self.d_kv = d_kv  # dimensionality of the keys and values
        self.d_ff = d_ff  # dimensionality of the feed-forward layer
        self.num_layers = num_layers  # number of Transformer layers
        self.num_decoder_layers = (
            num_decoder_layers if num_decoder_layers is not None else self.num_layers
        )  # number of decoder layers, defaults to num_layers for symmetry
        self.num_heads = num_heads  # number of attention heads
        self.relative_attention_num_buckets = relative_attention_num_buckets  # number of buckets for relative attention
        self.relative_attention_max_distance = relative_attention_max_distance  # maximum distance for relative attention
        self.dropout_rate = dropout_rate  # dropout rate
        self.classifier_dropout = classifier_dropout  # dropout rate of the classifier
        self.layer_norm_epsilon = layer_norm_epsilon  # epsilon used by layer normalization
        self.initializer_factor = initializer_factor  # initializer factor
        self.feed_forward_proj = feed_forward_proj  # activation of the feed-forward projection
        self.use_cache = use_cache  # whether to use the key/value cache
        act_info = self.feed_forward_proj.split("-")  # split the feed-forward activation specification
        self.dense_act_fn = act_info[-1]  # activation function of the dense layer
        self.is_gated_act = act_info[0] == "gated"  # whether the activation is gated
        if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
            # Raise if the activation string is neither `{ACT_FN}` nor `gated-{ACT_FN}`
            raise ValueError(
                f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
                "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
                "'gated-gelu' or 'relu'"
            )
        # For backwards compatibility
        if feed_forward_proj == "gated-gelu":
            self.dense_act_fn = "gelu_new"  # map the legacy 'gated-gelu' spelling to 'gelu_new'
        super().__init__(
            pad_token_id=pad_token_id,  # id of the padding token
            eos_token_id=eos_token_id,  # id of the end-of-sequence token
            is_encoder_decoder=is_encoder_decoder,  # whether this is an encoder-decoder model
            **kwargs,  # remaining keyword arguments
        )
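A quick check (not part of the configuration file above) of how `feed_forward_proj` is split into the dense activation and the gating flag, including the legacy `gated-gelu` spelling.

from transformers import T5Config

cfg = T5Config(feed_forward_proj="gated-gelu")
print(cfg.dense_act_fn, cfg.is_gated_act)  # gelu_new True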
class T5OnnxConfig(OnnxSeq2SeqConfigWithPast):
# ONNX configuration for T5, based on the seq2seq ONNX config with past-key-value support
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
# Returns a mapping from input names to the dynamic axes of each input
# Common inputs: encoder input IDs and attention mask
common_inputs = {
"input_ids": {0: "batch", 1: "encoder_sequence"},
"attention_mask": {0: "batch", 1: "encoder_sequence"},
}
# When past key values are used
if self.use_past:
# the attention mask also covers the past encoder sequence
common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
# the decoder only receives the newest token per step
common_inputs["decoder_input_ids"] = {0: "batch"}
# the decoder attention mask covers the past decoder sequence plus the current tokens
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
else:
# default decoder input ID axes
common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
# default decoder attention mask axes
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
# When past key values are used, add the corresponding per-layer inputs
if self.use_past:
self.fill_with_past_key_values_(common_inputs, direction="inputs")
# Return the final input specification
return common_inputs
@property
def default_onnx_opset(self) -> int:
# The default ONNX opset version is 13
return 13
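A small sketch of inspecting the dynamic-axis specification produced by this config; the import path follows the file shown here, and the printed values follow the `inputs` property above (assumes `transformers` is installed):
```python
from transformers import T5Config
from transformers.models.t5.configuration_t5 import T5OnnxConfig

onnx_config = T5OnnxConfig(T5Config())
print(onnx_config.inputs["input_ids"])  # {0: 'batch', 1: 'encoder_sequence'}
print(onnx_config.default_onnx_opset)   # 13
```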
.\models\t5\convert_t5x_checkpoint_to_flax.py
"""Convert T5X checkpoints from the original repository to JAX/FLAX model."""
import argparse
from t5x import checkpoints
from transformers import FlaxT5ForConditionalGeneration, T5Config
def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path):
config = T5Config.from_pretrained(config_name)
flax_model = FlaxT5ForConditionalGeneration(config=config)
t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
split_mlp_wi = "wi_0" in t5x_model["target"]["encoder"]["layers_0"]["mlp"]
for layer_index in range(config.num_layers):
layer_name = f"layers_{str(layer_index)}"
t5x_attention_key = t5x_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"]
t5x_attention_out = t5x_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"]
t5x_attention_query = t5x_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"]
t5x_attention_value = t5x_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"]
t5x_attention_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"]
if split_mlp_wi:
t5x_mlp_wi_0 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"]
t5x_mlp_wi_1 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"]
else:
t5x_mlp_wi = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"]
t5x_mlp_wo = t5x_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"]
t5x_mlp_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"][
"kernel"
] = t5x_attention_key
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"][
"kernel"
] = t5x_attention_out
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"][
"kernel"
] = t5x_attention_query
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"][
"kernel"
] = t5x_attention_value
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"][
"weight"
] = t5x_attention_layer_norm
if split_mlp_wi:
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_0"][
"kernel"
] = t5x_mlp_wi_0
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_1"][
"kernel"
] = t5x_mlp_wi_1
else:
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi"][
"kernel"
] = t5x_mlp_wi
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wo"][
"kernel"
] = t5x_mlp_wo
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"][
"weight"
] = t5x_mlp_layer_norm
t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T
flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][
"embedding"
] = t5x_encoder_rel_embedding
t5x_encoder_norm = t5x_model["target"]["encoder"]["encoder_norm"]["scale"]
flax_model.params["encoder"]["final_layer_norm"]["weight"] = t5x_encoder_norm
tx5_decoder_norm = t5x_model["target"]["decoder"]["decoder_norm"]["scale"]
flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm
t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T
flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][
"embedding"
] = t5x_decoder_rel_embedding
tx5_token_embeddings = t5x_model["target"]["token_embedder"]["embedding"]
flax_model.params["shared"]["embedding"] = tx5_token_embeddings
if "logits_dense" in t5x_model["target"]["decoder"]:
flax_model.params["lm_head"]["kernel"] = t5x_model["target"]["decoder"]["logits_dense"]["kernel"]
flax_model.save_pretrained(flax_dump_folder_path)
print("T5X Model was sucessfully converted!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--t5x_checkpoint_path", default=None, type=str, required=True, help="Path the TX5 checkpoint."
)
parser.add_argument("--config_name", default=None, type=str, required=True, help="Config name of T5 model.")
parser.add_argument(
"--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model."
)
args = parser.parse_args()
convert_t5x_checkpoint_to_flax(args.t5x_checkpoint_path, args.config_name, args.flax_dump_folder_path)
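After running the script, the dump folder can be loaded back like any other Flax checkpoint. A quick sanity-check sketch (the dump path and the tokenizer checkpoint are placeholders, not part of the script above):
```python
from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
model = FlaxT5ForConditionalGeneration.from_pretrained("/path/to/flax_dump_folder")

inputs = tokenizer("translate English to German: Hello, how are you?", return_tensors="np")
generated_ids = model.generate(inputs["input_ids"]).sequences
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
```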
.\models\t5\convert_t5x_checkpoint_to_pytorch.py
"""
Convert a T5X checkpoint to the PyTorch format.
Steps:
- Install gsutil according to https://cloud.google.com/storage/docs/gsutil_install
- Get a T5X checkpoint at https://github.com/google-research/t5x/blob/main/docs/models.md#t5-11-checkpoints  Example:
`gsutil -m cp -r gs://t5-data/pretrained_models/t5x/t5_1_1_small $HOME/`
- Create or download the corresponding config for the model. E.g. for T5 v1.1 small, you can use
https://huggingface.co/google/t5-v1_1-small/blob/main/config.json
- Convert:
```
python3 convert_t5x_checkpoint_to_pytorch.py --t5x_checkpoint_path=$HOME/t5_1_1_small --config_file=config.json\
--pytorch_dump_path=$HOME/t5_1_1_small_pt
```
"""
import argparse
import collections
import torch
from flax import traverse_util
from t5x import checkpoints
from transformers import T5Config, T5EncoderModel, T5ForConditionalGeneration
from transformers.utils import logging
logging.set_verbosity_info()
def t5x_attention_lookup(params, i, prefix, layer_name="attention"):
"""返回(self-)attention的KOQV参数,不进行转置。"""
k = params[f"{prefix}/layers_{i}/{layer_name}/key/kernel"]
o = params[f"{prefix}/layers_{i}/{layer_name}/out/kernel"]
q = params[f"{prefix}/layers_{i}/{layer_name}/query/kernel"]
v = params[f"{prefix}/layers_{i}/{layer_name}/value/kernel"]
return k, o, q, v
def t5x_mlp_lookup(params, i, prefix, split_mlp_wi=False):
"""返回层的MLP参数,不进行转置。"""
if split_mlp_wi:
wi_0 = params[f"{prefix}/layers_{i}/mlp/wi_0/kernel"]
wi_1 = params[f"{prefix}/layers_{i}/mlp/wi_1/kernel"]
wi = (wi_0, wi_1)
else:
wi = params[f"{prefix}/layers_{i}/mlp/wi/kernel"]
wo = params[f"{prefix}/layers_{i}/mlp/wo/kernel"]
return wi, wo
def t5x_layer_norm_lookup(params, i, prefix, layer_name):
"""返回层的层归一化参数。"""
return params[f"{prefix}/layers_{i}/{layer_name}/scale"]
def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, num_decoder_layers: int, is_encoder_only: bool):
"""将T5X-Flax的参数转换为Transformers-PyTorch格式。"""
old = traverse_util.flatten_dict(variables["target"])
old = {"/".join(k): v for k, v in old.items()}
split_mlp_wi = "encoder/layers_0/mlp/wi_0/kernel" in old
print("Split MLP:", split_mlp_wi)
new = collections.OrderedDict()
new["shared.weight"] = old["token_embedder/embedding"]
for i in range(num_layers):
layer_norm = t5x_layer_norm_lookup(old, i, "encoder", "pre_attention_layer_norm")
k, o, q, v = t5x_attention_lookup(old, i, "encoder", "attention")
new[f"encoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm
new[f"encoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T
new[f"encoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T
new[f"encoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T
new[f"encoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T
layer_norm = t5x_layer_norm_lookup(old, i, "encoder", "pre_mlp_layer_norm")
wi, wo = t5x_mlp_lookup(old, i, "encoder", split_mlp_wi)
new[f"encoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm
if split_mlp_wi:
new[f"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight"] = wi[0].T
new[f"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight"] = wi[1].T
else:
new[f"encoder.block.{i}.layer.1.DenseReluDense.wi.weight"] = wi.T
new[f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight"] = wo.T
new["encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[
"encoder/relpos_bias/rel_embedding"
].T
new["encoder.final_layer_norm.weight"] = old["encoder/encoder_norm/scale"]
if not is_encoder_only:
for i in range(num_decoder_layers):
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm")
k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention")
new[f"decoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm
new[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T
new[f"decoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T
new[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T
new[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_cross_attention_layer_norm")
k, o, q, v = t5x_attention_lookup(old, i, "decoder", "encoder_decoder_attention")
new[f"decoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm
new[f"decoder.block.{i}.layer.1.EncDecAttention.k.weight"] = k.T
new[f"decoder.block.{i}.layer.1.EncDecAttention.o.weight"] = o.T
new[f"decoder.block.{i}.layer.1.EncDecAttention.q.weight"] = q.T
new[f"decoder.block.{i}.layer.1.EncDecAttention.v.weight"] = v.T
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm")
wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi)
new[f"decoder.block.{i}.layer.2.layer_norm.weight"] = layer_norm
if split_mlp_wi:
new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"] = wi[0].T
new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"] = wi[1].T
else:
new[f"decoder.block.{i}.layer.2.DenseReluDense.wi.weight"] = wi.T
new[f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight"] = wo.T
new["decoder.final_layer_norm.weight"] = old["decoder/decoder_norm/scale"]
new["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[
"decoder/relpos_bias/rel_embedding"
].T
if "decoder/logits_dense/kernel" in old:
new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T
return new
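The `.T` transposes above reflect the differing weight layouts: Flax/T5X stores Dense kernels as `(in_features, out_features)`, while `torch.nn.Linear` stores weights as `(out_features, in_features)`. A small, self-contained sketch of that equivalence (the shapes are illustrative only):
```python
import numpy as np
import torch

flax_kernel = np.random.randn(512, 2048).astype(np.float32)  # Flax layout: (d_model, d_ff)
linear = torch.nn.Linear(512, 2048, bias=False)
with torch.no_grad():
    linear.weight.copy_(torch.from_numpy(flax_kernel.T))      # PyTorch layout: (d_ff, d_model)

x = np.random.randn(1, 512).astype(np.float32)
out_flax = x @ flax_kernel                                    # what the Flax Dense layer computes
out_torch = linear(torch.from_numpy(x)).detach().numpy()
print(np.allclose(out_flax, out_torch, atol=1e-5))            # True
```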
def make_state_dict(converted_params, is_encoder_only: bool):
"""Prepares a state dict for the PyTorch model."""
state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()])
if "encoder.embed_tokens.weight" not in state_dict:
state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"]
if not is_encoder_only:
if "decoder.embed_tokens.weight" not in state_dict:
state_dict["decoder.embed_tokens.weight"] = state_dict["shared.weight"]
if "lm_head.weight" not in state_dict:
print("Using shared word embeddings as lm_head.")
state_dict["lm_head.weight"] = state_dict["shared.weight"]
return state_dict
def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only):
"""Replaces the params in model with the T5X converted params."""
variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
converted = convert_t5x_to_pytorch(
variables,
num_layers=config.num_layers,
num_decoder_layers=config.num_decoder_layers,
is_encoder_only=is_encoder_only,
)
state_dict = make_state_dict(converted, is_encoder_only)
model.load_state_dict(state_dict, strict=True)
def convert_t5x_checkpoint_to_pytorch(
t5x_checkpoint_path, config_file, pytorch_dump_path, is_encoder_only: bool = False
):
"""Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint."""
config = T5Config.from_json_file(config_file)
print(f"Building PyTorch model from configuration: {config}")
if is_encoder_only:
model = T5EncoderModel(config)
else:
model = T5ForConditionalGeneration(config)
load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only)
print(f"Save PyTorch model to {pytorch_dump_path}")
model.save_pretrained(pytorch_dump_path)
model.from_pretrained(pytorch_dump_path)
print("Done")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Converts a native T5X checkpoint into a PyTorch checkpoint.")
parser.add_argument(
"--t5x_checkpoint_path", default=None, type=str, required=True, help="Path to the T5X checkpoint."
)
parser.add_argument(
"--config_file",
default=None,
type=str,
required=True,
help="The config json file corresponding to the pre-trained T5 model.\nThis specifies the model architecture.",
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
parser.add_argument(
"--is_encoder_only", action="store_true", help="Whether the checkpoint is an encoder-only model", default=False
)
args = parser.parse_args()
convert_t5x_checkpoint_to_pytorch(
args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path, args.is_encoder_only
)
This snippet parses the command-line arguments and calls the function that performs the T5X-to-PyTorch model conversion.
.\models\t5\convert_t5_original_tf_checkpoint_to_pytorch.py
"""Convert T5 checkpoint."""
import argparse
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
from transformers.utils import logging
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
config = T5Config.from_json_file(config_file)
print(f"Building PyTorch model from configuration: {config}")
model = T5ForConditionalGeneration(config)
load_tf_weights_in_t5(model, config, tf_checkpoint_path)
print(f"Save PyTorch model to {pytorch_dump_path}")
model.save_pretrained(pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True,
help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--config_file",
default=None,
type=str,
required=True,
help=(
"The config json file corresponding to the pre-trained T5 model. \n"
"This specifies the model architecture."
),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True,
help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
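For completeness, the converter defined above can also be called programmatically; a sketch with placeholder paths (not part of the original script):
```python
# Placeholder paths; substitute real checkpoint/config locations.
convert_tf_checkpoint_to_pytorch(
    tf_checkpoint_path="/path/to/t5/tf_checkpoint",
    config_file="/path/to/t5/config.json",
    pytorch_dump_path="/path/to/t5_pytorch_dump",
)

# The dump folder can then be loaded like any other PyTorch T5 checkpoint.
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("/path/to/t5_pytorch_dump")
```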
.\models\t5\modeling_flax_t5.py
""" Flax T5 model."""
import copy
from typing import Callable, Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxCausalLMOutputWithCrossAttentions,
FlaxSeq2SeqLMOutput,
FlaxSeq2SeqModelOutput,
)
from ...modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_call_sample_docstring,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_t5 import T5Config
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "google-t5/t5-small"
_CONFIG_FOR_DOC = "T5Config"
remat = nn_partitioning.remat
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
shifted_input_ids = jnp.zeros_like(input_ids)
shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
return shifted_input_ids
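A quick illustration of `shift_tokens_right` (the label values are arbitrary; `-100` is the usual ignore index in labels):
```python
import jax.numpy as jnp

labels = jnp.array([[5, 6, 7, -100]])
decoder_input_ids = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=0)
print(decoder_input_ids)  # [[0 5 6 7]]
```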
class FlaxT5LayerNorm(nn.Module):
hidden_size: int
dtype: jnp.dtype = jnp.float32
eps: float = 1e-6
weight_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
def setup(self):
self.weight = self.param("weight", self.weight_init, (self.hidden_size,))
def __call__(self, hidden_states):
"""
Construct a layernorm module in the T5 style; No bias and no subtraction of mean.
"""
variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True)
hidden_states = hidden_states / jnp.sqrt(variance + self.eps)
return self.weight * hidden_states
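This is an RMS-style norm: it only rescales by the root mean square, with no mean subtraction and no bias. A small numerical check using the module defined above (sizes are arbitrary):
```python
import jax
import numpy as np

x = np.random.randn(2, 4, 8).astype(np.float32)
layer_norm = FlaxT5LayerNorm(hidden_size=8)
params = layer_norm.init(jax.random.PRNGKey(0), x)
out = layer_norm.apply(params, x)

rms = np.sqrt((x**2).mean(axis=-1, keepdims=True) + 1e-6)
print(np.allclose(out, x / rms, atol=1e-5))  # True, since the weight is initialized to ones
```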
class FlaxT5DenseActDense(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32
def setup(self):
wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)
self.wi = nn.Dense(
self.config.d_ff,
use_bias=False,
kernel_init=jax.nn.initializers.normal(wi_init_std),
dtype=self.dtype,
)
self.wo = nn.Dense(
self.config.d_model,
use_bias=False,
kernel_init=jax.nn.initializers.normal(wo_init_std),
dtype=self.dtype,
)
self.dropout = nn.Dropout(self.config.dropout_rate)
self.act = ACT2FN[self.config.dense_act_fn]
def __call__(self, hidden_states, deterministic=True):
hidden_states = self.wi(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.wo(hidden_states)
return hidden_states
class FlaxT5DenseGatedActDense(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32
def setup(self):
wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)
self.wi_0 = nn.Dense(
self.config.d_ff,
use_bias=False,
kernel_init=jax.nn.initializers.normal(wi_init_std),
dtype=self.dtype,
)
self.wi_1 = nn.Dense(
self.config.d_ff,
use_bias=False,
kernel_init=jax.nn.initializers.normal(wi_init_std),
dtype=self.dtype,
)
self.wo = nn.Dense(
self.config.d_model,
use_bias=False,
kernel_init=jax.nn.initializers.normal(wo_init_std),
dtype=self.dtype,
)
self.dropout = nn.Dropout(self.config.dropout_rate)
self.act = ACT2FN[self.config.dense_act_fn]
def __call__(self, hidden_states, deterministic):
hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.wo(hidden_states)
return hidden_states
class FlaxT5LayerFF(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32
def setup(self):
if self.config.is_gated_act:
self.DenseReluDense = FlaxT5DenseGatedActDense(self.config, dtype=self.dtype)
else:
self.DenseReluDense = FlaxT5DenseActDense(self.config, dtype=self.dtype)
self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype)
self.dropout = nn.Dropout(self.config.dropout_rate)
def __call__(self, hidden_states, deterministic=True):
forwarded_states = self.layer_norm(hidden_states)
forwarded_states = self.DenseReluDense(forwarded_states, deterministic=deterministic)
hidden_states = hidden_states + self.dropout(forwarded_states, deterministic=deterministic)
return hidden_states
class FlaxT5Attention(nn.Module):
config: T5Config
has_relative_attention_bias: bool = False
causal: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
self.relative_attention_max_distance = self.config.relative_attention_max_distance
self.d_model = self.config.d_model
self.key_value_proj_dim = self.config.d_kv
self.n_heads = self.config.num_heads
self.dropout = self.config.dropout_rate
self.inner_dim = self.n_heads * self.key_value_proj_dim
q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
self.q = nn.Dense(
self.inner_dim,
use_bias=False,
kernel_init=jax.nn.initializers.normal(q_init_std),
dtype=self.dtype,
)
self.k = nn.Dense(
self.inner_dim,
use_bias=False,
kernel_init=jax.nn.initializers.normal(kv_init_std),
dtype=self.dtype,
)
self.v = nn.Dense(
self.inner_dim,
use_bias=False,
kernel_init=jax.nn.initializers.normal(kv_init_std),
dtype=self.dtype,
)
self.o = nn.Dense(
self.d_model,
use_bias=False,
kernel_init=jax.nn.initializers.normal(o_init_std),
dtype=self.dtype,
)
if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embed(
self.relative_attention_num_buckets,
self.n_heads,
embedding_init=jax.nn.initializers.normal(kv_init_std),
dtype=self.dtype,
)
@staticmethod
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
"""
Adapted from Mesh Tensorflow:
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
Translate relative position to a bucket number for relative attention. The relative position is defined as
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the model has been trained on
"""
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0) * num_buckets
relative_position = jnp.abs(relative_position)
else:
relative_position = -jnp.clip(relative_position, a_max=0)
max_exact = num_buckets // 2
is_small = relative_position < max_exact
relative_position_if_large = max_exact + (
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
)
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
return relative_buckets.astype("i4")
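A short sketch of what the bucketing does: nearby offsets each get their own bucket, larger offsets share logarithmically spaced buckets, and everything beyond `max_distance` collapses into the last bucket (arguments mirror the defaults above):
```python
import jax.numpy as jnp

relative_positions = jnp.array([-200, -20, -1, 0, 1, 20, 200])
buckets = FlaxT5Attention._relative_position_bucket(
    relative_positions, bidirectional=True, num_buckets=32, max_distance=128
)
print(buckets)  # small offsets map to distinct buckets, large offsets saturate
```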
def compute_bias(self, query_length, key_length):
"""Compute binned relative position bias"""
context_position = jnp.arange(query_length, dtype="i4")[:, None]
memory_position = jnp.arange(key_length, dtype="i4")[None, :]
relative_position = memory_position - context_position
relative_position_bucket = self._relative_position_bucket(
relative_position,
bidirectional=(not self.causal),
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(relative_position_bucket)
values = values.transpose((2, 0, 1))[None, :, :, :]
return values
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,))
@nn.compact
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slightly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
"""
is_initialized = self.has_variable("cache", "cached_key")
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
cur_index = cache_index.value
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
key = jax.lax.dynamic_update_slice(cached_key.value, key, indices)
value = jax.lax.dynamic_update_slice(cached_value.value, value, indices)
cached_key.value = key
cached_value.value = value
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
)
attention_mask = combine_masks(pad_mask, attention_mask)
return key, value, attention_mask
def _create_position_bias(
self, key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift
):
cache_is_filled = self.causal and self.has_variable("cache", "cached_key") and (not init_cache)
key_length = key_states.shape[1]
query_length = key_length if cache_is_filled else query_states.shape[1]
if self.has_relative_attention_bias:
position_bias = self.compute_bias(query_length, key_length)
elif attention_mask is not None:
position_bias = jnp.zeros_like(attention_mask)
else:
position_bias = jnp.zeros((1, self.n_heads, query_length, key_length), dtype=self.dtype)
if cache_is_filled:
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
position_bias = jax.lax.dynamic_slice(
position_bias,
(0, 0, causal_attention_mask_shift, 0),
(1, self.n_heads, seq_length, max_decoder_length),
)
return position_bias
def __call__(
self,
hidden_states,
attention_mask=None,
key_value_states=None,
position_bias=None,
use_cache=False,
output_attentions=False,
deterministic=True,
init_cache=False,
class FlaxT5LayerSelfAttention(nn.Module):
config: T5Config
has_relative_attention_bias: bool = False
dtype: jnp.dtype = jnp.float32  # the dtype of the computation
def setup(self):
self.SelfAttention = FlaxT5Attention(
self.config,
has_relative_attention_bias=self.has_relative_attention_bias,
causal=self.config.causal,
dtype=self.dtype,
)
self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype)
self.dropout = nn.Dropout(self.config.dropout_rate)
def __call__(
self,
hidden_states,
attention_mask=None,
position_bias=None,
output_attentions=False,
deterministic=True,
init_cache=False,
):
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention(
normed_hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
output_attentions=output_attentions,
deterministic=deterministic,
init_cache=init_cache,
)
hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
outputs = (hidden_states,) + attention_output[1:]
return outputs
class FlaxT5Block(nn.Module):
config: T5Config
has_relative_attention_bias: bool = False
dtype: jnp.dtype = jnp.float32  # the dtype of the computation
def setup(self):
self.causal = self.config.causal
self.layer = (
FlaxT5LayerSelfAttention(
self.config,
has_relative_attention_bias=self.has_relative_attention_bias,
name=str(0),
dtype=self.dtype,
),
)
feed_forward_index = 1
if self.causal:
self.layer += (FlaxT5LayerCrossAttention(self.config, name=str(1), dtype=self.dtype),)
feed_forward_index += 1
self.layer += (FlaxT5LayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),)
def __call__(
self,
hidden_states,
attention_mask=None,
position_bias=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
output_attentions=False,
return_dict=True,
deterministic=True,
init_cache=False,
):
self_attention_outputs = self.layer[0](
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
output_attentions=output_attentions,
deterministic=deterministic,
init_cache=init_cache,
)
hidden_states = self_attention_outputs[0]
attention_outputs = self_attention_outputs[1:]
do_cross_attention = self.causal and encoder_hidden_states is not None
if do_cross_attention:
cross_attention_outputs = self.layer[1](
hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias,
output_attentions=output_attentions,
deterministic=deterministic,
)
hidden_states = cross_attention_outputs[0]
attention_outputs = attention_outputs + cross_attention_outputs[1:]
hidden_states = self.layer[-1](hidden_states, deterministic=deterministic)
outputs = (hidden_states,)
outputs = outputs + attention_outputs
return outputs
class FlaxT5LayerCollection(nn.Module):
config: T5Config
has_relative_attention_bias: bool
dtype: jnp.dtype = jnp.float32
def setup(self):
self.layer = FlaxT5Block(
self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype
)
def __call__(
self,
hidden_states,
attention_mask=None,
position_bias=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
output_attentions=False,
deterministic=True,
init_cache=False,
):
return self.layer(
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
output_attentions=output_attentions,
deterministic=deterministic,
init_cache=init_cache,
)
class FlaxT5BlockCollection(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.causal = self.config.causal
if self.gradient_checkpointing:
FlaxT5CheckpointLayer = remat(FlaxT5LayerCollection, static_argnums=(6, 7, 8))
self.blocks = [
FlaxT5CheckpointLayer(
self.config,
has_relative_attention_bias=(i == 0),
dtype=self.dtype,
name=str(i),
)
for i in range(self.config.num_layers)
]
else:
self.blocks = [
FlaxT5LayerCollection(
self.config,
has_relative_attention_bias=(i == 0),
dtype=self.dtype,
name=str(i),
)
for i in range(self.config.num_layers)
]
def __call__(
self,
hidden_states=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions: bool = False,
output_hidden_states: bool = False,
deterministic: bool = True,
init_cache: bool = False,
):
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_cross_attentions = () if (output_attentions and self.causal) else None
position_bias = None
encoder_decoder_position_bias = None
for i, layer_module in enumerate(self.blocks):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
hidden_states,
attention_mask,
position_bias,
encoder_hidden_states,
encoder_attention_mask,
encoder_decoder_position_bias,
output_attentions,
deterministic,
init_cache,
)
hidden_states = layer_outputs[0]
position_bias = layer_outputs[1]
if self.causal and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[2],)
if self.causal:
all_cross_attentions = all_cross_attentions + (layer_outputs[4],)
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
cross_attentions=all_cross_attentions,
)
class FlaxT5Stack(nn.Module):
config: T5Config
embed_tokens: nn.Embed
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
def setup(self):
self.causal = self.config.causal
self.block = FlaxT5BlockCollection(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.final_layer_norm = FlaxT5LayerNorm(
self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
)
self.dropout = nn.Dropout(self.config.dropout_rate)
def __call__(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
init_cache: bool = False,
):
hidden_states = self.embed_tokens(input_ids)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
outputs = self.block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
deterministic=deterministic,
init_cache=init_cache,
)
hidden_states = outputs[0]
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
all_hidden_states = None
if output_hidden_states:
all_hidden_states = outputs.hidden_states
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
if output_hidden_states:
return (
hidden_states,
all_hidden_states,
) + outputs[2:]
return (hidden_states,) + outputs[1:]
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
T5_ENCODE_INPUTS_DOCSTRING = r"""
# Inputs:
# - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
# Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings, so the
# inputs can be padded on both the left and the right.
# Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
# To learn more about how `input_ids` are prepared for pretraining, take a look at [T5 Training](./t5#training).
# - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
# Mask to avoid performing attention on padding token indices. Mask values are selected in `[0, 1]`:
# - 1 for tokens that are **not masked**,
# - 0 for tokens that are **masked**.
# [What are attention masks?](../glossary#attention-mask)
# - output_attentions (`bool`, *optional*):
# Whether or not to return the attention tensors of all attention layers. See `attentions` under returned tensors for more detail.
# - output_hidden_states (`bool`, *optional*):
# Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail.
# - return_dict (`bool`, *optional*):
# Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Docstring constant describing the decoder-side inputs of the T5 model
T5_DECODE_INPUTS_DOCSTRING = r"""
Args:
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
[What are decoder input IDs?](../glossary
For training, `decoder_input_ids` should be provided.
encoder_outputs (`tuple(tuple(jnp.ndarray)`):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, is a sequence of hidden states at the output of the last layer of the encoder.
Used in the cross-attention of the decoder.
encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values are selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. A causal mask will also be used by default.
If you want to change padding behavior, you should modify it to your needs. See [diagram 1 in the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
Dictionary of pre-computed hidden states (keys and values in the attention blocks) that can be used for fast auto-regressive decoding. Pre-computed key and value hidden states have shape *[batch_size, max_length]*.
output_attentions (`bool`, *optional*):
Whether or not to return the attention tensors of all attention layers. See `attentions` under returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
config_class = T5Config
base_model_prefix = "transformer"
module_class: nn.Module = None
# Constructor: creates a new model instance
def __init__(
self,
config: T5Config,
input_shape: Tuple[int] = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
gradient_checkpointing: bool = False,
**kwargs,
):
# Instantiate the module object with the given config and arguments
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
# Call the parent constructor with the config, module object and the remaining arguments
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
# Enable gradient checkpointing
def enable_gradient_checkpointing(self):
# Re-create the module object with gradient checkpointing turned on
self._module = self.module_class(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=True,
)
# Initialize the model weights
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# Initialize the input tensor input_ids with zeros
input_ids = jnp.zeros(input_shape, dtype="i4")
# Create an attention_mask of ones with the same shape as input_ids
attention_mask = jnp.ones_like(input_ids)
args = [input_ids, attention_mask]
# If the module class is not FlaxT5EncoderModule, also build the decoder inputs
if self.module_class not in [FlaxT5EncoderModule]:
decoder_input_ids = jnp.ones_like(input_ids)
decoder_attention_mask = jnp.ones_like(input_ids)
args.extend([decoder_input_ids, decoder_attention_mask])
# Split the RNG for parameters and dropout
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
# Initialize the module parameters from the RNGs and input tensors; returns randomly initialized parameters
random_params = self.module.init(
rngs,
*args,
)["params"]
# If existing params were passed in, fill any missing keys with the randomly initialized values
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
# Return the completed parameters as a frozen dict
return freeze(unflatten_dict(params))
else:
# Otherwise return the randomly initialized parameters directly
return random_params
# Override the parent __call__ to define the model's forward pass
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
def __call__(
self,
input_ids: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
decoder_input_ids: jnp.ndarray = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
# If `output_attentions` is not given, fall back to the value from the config
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# If `output_hidden_states` is not given, fall back to the value from the config
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# If `return_dict` is not given, fall back to the value from the config
return_dict = return_dict if return_dict is not None else self.config.return_dict
# Raise a ValueError if `decoder_input_ids` is missing
if decoder_input_ids is None:
raise ValueError(
"Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed"
" here."
)
# Prepare the encoder attention mask
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
# Prepare the decoder attention mask
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
# Handle any PRNGs that may be needed
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
# Apply the module with the parameters and inputs
return self.module.apply(
{"params": params or self.params},
input_ids=jnp.array(input_ids, dtype="i4"),
attention_mask=jnp.array(attention_mask, dtype="i4"),
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
)
# init_cache: initializes the cache used for fast auto-regressive decoding
def init_cache(self, batch_size, max_length, encoder_outputs):
r"""
Args:
batch_size (`int`):
batch size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
max_length (`int`):
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized cache.
encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
`encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, is a sequence of
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
"""
# Input variables used to initialize (and later retrieve) the cache
decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
# Get the decoder module
decoder_module = module._get_decoder_module()
# Run a forward pass of the decoder module
return decoder_module(
decoder_input_ids,
decoder_attention_mask,
**kwargs,
)
# Initialize using the dedicated method; only the decoder needs to be called to initialize the cache
init_variables = self.module.init(
jax.random.PRNGKey(0),
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
init_cache=True,
method=_decoder_forward,
)
# Return the unfrozen cache variables
return unfreeze(init_variables["cache"])
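A minimal sketch of the encode/decode split exposed by this class and where `init_cache` fits in (assumes the `google-t5/t5-small` checkpoint can be downloaded; `generate` calls `init_cache` internally for fast autoregressive decoding):
```python
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

inputs = tokenizer("summarize: studies show owning a dog is good for you", return_tensors="np")
encoder_outputs = model.encode(**inputs)

# One uncached decoder step from the decoder start token.
decoder_input_ids = jnp.full((1, 1), model.config.decoder_start_token_id, dtype="i4")
outputs = model.decode(decoder_input_ids, encoder_outputs)
print(outputs.logits.shape)  # (1, 1, vocab_size)
```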
@add_start_docstrings(T5_ENCODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=T5Config)
def encode(
self,
input_ids: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
r"""
Returns:
Example:
```
>>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
>>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
>>> text = "My friends are cool but they eat too many carbs."
>>> inputs = tokenizer(text, return_tensors="np")
>>> encoder_outputs = model.encode(**inputs)
```"""
# Decide whether to output attentions, from the argument or the config default
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# Decide whether to output hidden states, from the argument or the config default
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# Decide whether to return a dict-style output, from the argument or the config default
return_dict = return_dict if return_dict is not None else self.config.return_dict
# If no attention mask is provided, create an all-ones mask with the same shape as the inputs
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
# Handle any PRNGs if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
# Forward function for the encoder
def _encoder_forward(module, input_ids, attention_mask, **kwargs):
encode_module = module._get_encoder_module()
return encode_module(input_ids, attention_mask, **kwargs)
# Apply the module with the parameters and run the encoder forward pass
return self.module.apply(
{"params": params or self.params},  # use the given parameters or the default ones
input_ids=jnp.array(input_ids, dtype="i4"),  # convert the input IDs to a JAX array
attention_mask=jnp.array(attention_mask, dtype="i4"),  # convert the attention mask to a JAX array
output_attentions=output_attentions,  # whether to output attentions
output_hidden_states=output_hidden_states,  # whether to output hidden states
return_dict=return_dict,  # whether to return a dict-style output
deterministic=not train,  # deterministic computation, i.e. not in training mode
rngs=rngs,  # PRNGs to use
method=_encoder_forward,  # the method to call, i.e. the encoder forward function
)
@add_start_docstrings(T5_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=T5Config)
def decode(
self,
decoder_input_ids,
encoder_outputs,
encoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
"""
The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder-decoder transformer pre-trained in a
text-to-text denoising generative setting.
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a Flax Linen
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`T5Config`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
[`~FlaxPreTrainedModel.to_bf16`].
"""
@add_start_docstrings(
"The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
T5_START_DOCSTRING,
)
class FlaxT5Module(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def _get_encoder_module(self):
"""
Retrieve the encoder module of the T5 model.
"""
return self.encoder
def _get_decoder_module(self):
"""
Retrieve the decoder module of the T5 model.
"""
return self.decoder
# Set up the model parameters: the shared embedding layer and the encoder/decoder configurations
def setup(self):
# Shared embedding layer, initialized from a normal distribution with the given vocab size and model dimension
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
dtype=self.dtype,
)
# Copy the config for the encoder and make it non-causal (causal=False)
encoder_config = copy.deepcopy(self.config)
encoder_config.causal = False
# Build the encoder stack with the shared embedding layer and the copied config
self.encoder = FlaxT5Stack(
encoder_config,
embed_tokens=self.shared,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
# Copy the config for the decoder, make it causal (causal=True) and set the number of decoder layers
decoder_config = copy.deepcopy(self.config)
decoder_config.causal = True
decoder_config.num_layers = self.config.num_decoder_layers
# Build the decoder stack with the shared embedding layer and the copied config
self.decoder = FlaxT5Stack(
decoder_config,
embed_tokens=self.shared,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
# Forward call: runs the encoder and then the decoder
def __call__(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
deterministic: bool = True,
):
# Decide whether to return a dict-style output
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Encode (training and the first prediction pass)
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],  # feed the encoder hidden states to the decoder
encoder_attention_mask=attention_mask,  # reuse the encoder attention mask
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
# If dict-style output is not requested, return the concatenated decoder and encoder outputs
if not return_dict:
return decoder_outputs + encoder_outputs
# Return a dict-style output containing the encoder and decoder hidden states and attentions
return FlaxSeq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
# FlaxT5Model is defined on top of FlaxT5PreTrainedModel
class FlaxT5Model(FlaxT5PreTrainedModel):
# module_class points at FlaxT5Module
module_class = FlaxT5Module
# append_call_sample_docstring adds a usage example to FlaxT5Model's docstring
append_call_sample_docstring(FlaxT5Model, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
# FLAX_T5_MODEL_DOCSTRING holds the return-value and example documentation
FLAX_T5_MODEL_DOCSTRING = """
Returns:
Example:
```
>>> from transformers import AutoTokenizer, FlaxT5Model
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
>>> model = FlaxT5Model.from_pretrained("google-t5/t5-small")
>>> input_ids = tokenizer(
... "Studies have been shown that owning a dog is good for you", return_tensors="np"
... ).input_ids
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids
>>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
>>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
>>> decoder_input_ids = model._shift_right(decoder_input_ids)
>>>
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
>>> last_hidden_states = outputs.last_hidden_state
```
"""
# overwrite_call_docstring replaces the call docstring of FlaxT5Model
overwrite_call_docstring(FlaxT5Model, T5_INPUTS_DOCSTRING + FLAX_T5_MODEL_DOCSTRING)
# append_replace_return_docstrings adds the return-value documentation to FlaxT5Model
append_replace_return_docstrings(FlaxT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
# add_start_docstrings prepends the class description and start docstring to FlaxT5EncoderModule
@add_start_docstrings(
"The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
T5_START_DOCSTRING,
)
# FlaxT5EncoderModule: the encoder-only module, a subclass of nn.Module
class FlaxT5EncoderModule(nn.Module):
# Holds a T5Config `config` attribute and a jnp.float32 `dtype` attribute
config: T5Config
dtype: jnp.dtype = jnp.float32  # the dtype of the computation
gradient_checkpointing: bool = False
# setup: builds the shared embedding and the encoder
def setup(self):
# Shared embedding layer, initialized from a normal distribution
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
dtype=self.dtype,
)
# Copy the config for the encoder and adjust the relevant attributes
encoder_config = copy.deepcopy(self.config)
encoder_config.is_decoder = False
encoder_config.is_encoder_decoder = False
encoder_config.causal = False
# Build the FlaxT5Stack encoder
self.encoder = FlaxT5Stack(
encoder_config,
embed_tokens=self.shared,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
# __call__: the forward pass of the module
def __call__(
self,
input_ids=None,
attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict: bool = True,
deterministic: bool = True,
):
# Encode (training or the first prediction pass) by calling the encoder
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
# Return the encoder outputs
return encoder_outputs
class FlaxT5EncoderModel(FlaxT5PreTrainedModel):
# The module class is set to FlaxT5EncoderModule
module_class = FlaxT5EncoderModule
# The decorator adds the T5_ENCODE_INPUTS_DOCSTRING to the __call__ method
@add_start_docstrings_to_model_forward(T5_ENCODE_INPUTS_DOCSTRING)
def __call__(
self,
input_ids: jnp.ndarray,  # input token IDs as a JAX array
attention_mask: Optional[jnp.ndarray] = None,  # optional attention mask; defaults to all ones if None
output_attentions: Optional[bool] = None,  # whether to return attentions; falls back to self.config.output_attentions if None
output_hidden_states: Optional[bool] = None,  # whether to return hidden states; falls back to self.config.output_hidden_states if None
return_dict: Optional[bool] = None,  # whether to return a dict-style output; falls back to self.config.return_dict if None
train: bool = False,  # whether the model is in training mode
params: dict = None,  # parameter dict, defaults to None
dropout_rng: PRNGKey = None,  # PRNG key for dropout; None means dropout is not used
):
# Decide whether to output attentions
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# Decide whether to output hidden states
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# Decide whether to return a dict-style output
return_dict = return_dict if return_dict is not None else self.config.return_dict
# Prepare the encoder inputs
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)  # if no attention mask is given, use an all-ones array
# Handle any PRNGs that may be needed
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
# Apply self.module to encode the inputs
return self.module.apply(
{"params": params or self.params},  # model parameter dict
input_ids=jnp.array(input_ids, dtype="i4"),  # token IDs as a 32-bit integer JAX array
attention_mask=jnp.array(attention_mask, dtype="i4"),  # attention mask as a 32-bit integer JAX array
output_attentions=output_attentions,  # whether to output attentions
output_hidden_states=output_hidden_states,  # whether to output hidden states
return_dict=return_dict,  # whether to return a dict-style output
deterministic=not train,  # deterministic run, i.e. not training
rngs=rngs,  # dict of PRNGs
)
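A short usage sketch for the encoder-only model defined here (assumes the `google-t5/t5-small` checkpoint is available; the decoder weights of that checkpoint are simply not loaded):
```python
from transformers import AutoTokenizer, FlaxT5EncoderModel

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
model = FlaxT5EncoderModel.from_pretrained("google-t5/t5-small")

inputs = tokenizer("Studies have shown that owning a dog is good for you", return_tensors="np")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, sequence_length, 512) for t5-small
```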
@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
class FlaxT5ForConditionalGenerationModule(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
def _get_encoder_module(self):
# Return the encoder module
return self.encoder
def _get_decoder_module(self):
# Return the decoder module
return self.decoder
def setup(self):
self.model_dim = self.config.d_model
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor),
dtype=self.dtype,
)
encoder_config = copy.deepcopy(self.config)
encoder_config.causal = False
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
# Build the encoder stack
self.encoder = FlaxT5Stack(
encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
decoder_config = copy.deepcopy(self.config)
decoder_config.causal = True
decoder_config.is_encoder_decoder = False
decoder_config.num_layers = self.config.num_decoder_layers
# Build the decoder stack
self.decoder = FlaxT5Stack(
decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
# Build the language-modeling head
self.lm_head = nn.Dense(
self.config.vocab_size,
use_bias=False,
kernel_init=jax.nn.initializers.normal(self.config.initializer_factor),
dtype=self.dtype,
)
def __call__(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
deterministic: bool = True,
# Forward call of the T5 conditional-generation module
# Arguments:
# input_ids: token IDs of the input sequence
# attention_mask: attention mask indicating which positions are padding
# decoder_input_ids: token IDs fed to the decoder
# decoder_attention_mask: attention mask of the decoder
# encoder_outputs: outputs of the encoder
# output_attentions: whether to return attention weights
# output_hidden_states: whether to return hidden states
# return_dict: whether to return a dict-style output
# deterministic: whether to run deterministic (no-dropout) inference
# The body runs the conditional-generation forward pass and produces the logits
):
# If return_dict is None, fall back to self.config.use_return_dict
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Encode
encoder_outputs = self.encoder(
input_ids=input_ids,  # input token IDs
attention_mask=attention_mask,  # attention mask indicating which positions are valid
output_attentions=output_attentions,  # whether to output attentions
output_hidden_states=output_hidden_states,  # whether to output hidden states
return_dict=return_dict,  # whether to return a dict-style output
deterministic=deterministic,  # deterministic computation
)
hidden_states = encoder_outputs[0]  # hidden states of the encoder
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,  # decoder input token IDs
attention_mask=decoder_attention_mask,  # decoder attention mask
encoder_hidden_states=hidden_states,  # encoder hidden states used as decoder input
encoder_attention_mask=attention_mask,  # encoder attention mask used by the decoder cross-attention
output_attentions=output_attentions,  # whether to output attentions
output_hidden_states=output_hidden_states,  # whether to output hidden states
return_dict=return_dict,  # whether to return a dict-style output
deterministic=deterministic,  # deterministic computation
)
sequence_output = decoder_outputs[0]  # output sequence of the decoder
if self.config.tie_word_embeddings:
# When word embeddings are tied, rescale the output before projecting onto the vocabulary
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
sequence_output = sequence_output * (self.model_dim**-0.5)
if self.config.tie_word_embeddings:
# When word embeddings are tied, fetch the shared embedding parameters and use them as the lm_head kernel
shared_embedding = self.shared.variables["params"]["embedding"]
lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
else:
lm_logits = self.lm_head(sequence_output)  # otherwise feed the sequence output straight into lm_head
if not return_dict:
# If dict-style output is not requested, return a plain tuple
return (lm_logits,) + decoder_outputs[1:] + encoder_outputs
# Return a FlaxSeq2SeqLMOutput containing the detailed outputs
return FlaxSeq2SeqLMOutput(
logits=lm_logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel):
module_class = FlaxT5ForConditionalGenerationModule
@add_start_docstrings(T5_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=T5Config)
# decode method used to produce the model outputs during generation
def decode(
self,
decoder_input_ids,
encoder_outputs,
encoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
    # Prepare the inputs used for generation: returns the cache, encoder context, and attention masks needed for decoding
def prepare_inputs_for_generation(
self,
decoder_input_ids,
max_length,
attention_mask: Optional[jax.Array] = None,
decoder_attention_mask: Optional[jax.Array] = None,
encoder_outputs=None,
**kwargs,
):
        # Initialize the cache holding the past key/value states needed for generation
batch_size, seq_length = decoder_input_ids.shape
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
        # Build an extended attention mask that covers all positions up to max_length
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if decoder_attention_mask is not None:
extended_attention_mask = jax.lax.dynamic_update_slice(
extended_attention_mask, decoder_attention_mask, (0, 0)
)
return {
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"encoder_attention_mask": attention_mask,
"decoder_attention_mask": extended_attention_mask,
}
    # Update the generation inputs: copy the past key/value states from the model outputs back into the kwargs
def update_inputs_for_generation(self, model_outputs, model_kwargs):
model_kwargs["past_key_values"] = model_outputs.past_key_values
return model_kwargs
FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING = """
    Returns:
        The outputs of the generation model.
    Example:
```
>>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
>>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
>>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs."
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors="np")
>>>
>>> summary_ids = model.generate(inputs["input_ids"]).sequences
>>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False))
```
"""
overwrite_call_docstring(
    # Overwrite the call docstring of FlaxT5ForConditionalGeneration with the merged
    # T5_INPUTS_DOCSTRING and FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING constants
    FlaxT5ForConditionalGeneration, T5_INPUTS_DOCSTRING + FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING
)
# Append the return-value docstring to the class's call method, replacing an existing one if present
append_replace_return_docstrings(
    FlaxT5ForConditionalGeneration,  # the class whose method docstring is updated
    output_type=FlaxSeq2SeqLMOutput,  # declare the output type as FlaxSeq2SeqLMOutput
    config_class=_CONFIG_FOR_DOC,  # use the _CONFIG_FOR_DOC config class
)
.\models\t5\modeling_t5.py
""" PyTorch T5 model."""
import copy
import math
import os
import warnings
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
Seq2SeqQuestionAnsweringModelOutput,
Seq2SeqSequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
DUMMY_INPUTS,
DUMMY_MASK,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_fx_proxy,
logging,
replace_return_docstrings,
)
from ...utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_t5 import T5Config
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "T5Config"
_CHECKPOINT_FOR_DOC = "google-t5/t5-small"
T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
"google-t5/t5-small",
"google-t5/t5-base",
"google-t5/t5-large",
"google-t5/t5-3b",
"google-t5/t5-11b",
]
def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
"""Load tf checkpoints in a pytorch model."""
try:
import re
import numpy as np
import tensorflow as tf
except ImportError:
logger.error(
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
tf_path = os.path.abspath(tf_checkpoint_path)
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
init_vars = tf.train.list_variables(tf_path)
names = []
tf_weights = {}
for name, shape in init_vars:
logger.info(f"Loading TF weight {name} with shape {shape}")
array = tf.train.load_variable(tf_path, name)
names.append(name)
tf_weights[name] = array
logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.")
return model
PARALLELIZE_DOCSTRING = r"""
    This is an experimental feature and is subject to change at a moment's notice.
Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
it will evenly distribute blocks across all devices.
Args:
device_map (`Dict[int, list]`, optional, defaults to None):
A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
automatically mapped to the first device (for esoteric reasons). That means that the first device should
have fewer attention modules mapped to it than other devices. For reference, the t5 models have the
following number of attention modules:
- google-t5/t5-small: 6
- google-t5/t5-base: 12
- google-t5/t5-large: 24
- google-t5/t5-3b: 24
- google-t5/t5-11b: 24
Example:
```
# Here is an example of a device map on a machine with 4 GPUs using google-t5/t5-3b, which has a total of 24 attention modules:
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-3b")
device_map = {
0: [0, 1, 2],
1: [3, 4, 5, 6, 7, 8, 9],
2: [10, 11, 12, 13, 14, 15, 16],
3: [17, 18, 19, 20, 21, 22, 23],
}
model.parallelize(device_map)
```
"""
DEPARALLELIZE_DOCSTRING = r"""
Moves the model to cpu from a model parallel state.
Example:
```
# On a 4 GPU machine with google-t5/t5-3b:
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-3b")
device_map = {
0: [0, 1, 2],
1: [3, 4, 5, 6, 7, 8, 9],
2: [10, 11, 12, 13, 14, 15, 16],
3: [17, 18, 19, 20, 21, 22, 23],
}
model.parallelize(device_map) # Splits the model across several devices
model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
```
"""
class T5LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
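# Illustrative sketch (not part of modeling_t5.py): T5LayerNorm is an RMSNorm, i.e. it only rescales
# by the root mean square of the activations; unlike nn.LayerNorm it neither subtracts the mean nor
# adds a bias. With the weight still at its initial value of ones, the module reduces to:
import torch
x = torch.randn(2, 4, 8)  # (batch, seq_len, hidden), made-up shapes
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(T5LayerNorm(8)(x), manual, atol=1e-6)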
try:
from apex.normalization import FusedRMSNorm
T5LayerNorm = FusedRMSNorm
logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm")
except ImportError:
pass
except Exception:
logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
pass
ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
class T5DenseActDense(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
self.act = ACT2FN[config.dense_act_fn]
def forward(self, hidden_states):
hidden_states = self.wi(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.dropout(hidden_states)
if (
isinstance(self.wo.weight, torch.Tensor)
and hidden_states.dtype != self.wo.weight.dtype
and self.wo.weight.dtype != torch.int8
):
hidden_states = hidden_states.to(self.wo.weight.dtype)
hidden_states = self.wo(hidden_states)
return hidden_states
class T5DenseGatedActDense(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
self.act = ACT2FN[config.dense_act_fn]
def forward(self, hidden_states):
hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states)
if (
isinstance(self.wo.weight, torch.Tensor)
and hidden_states.dtype != self.wo.weight.dtype
and self.wo.weight.dtype != torch.int8
):
hidden_states = hidden_states.to(self.wo.weight.dtype)
hidden_states = self.wo(hidden_states)
return hidden_states
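# Illustrative sketch (not part of modeling_t5.py): the gated variant multiplies an activated
# projection by a plain linear projection (a GEGLU-style gate) before projecting back to d_model.
# Here gelu stands in for ACT2FN[config.dense_act_fn], and the sizes are made up.
import torch
x = torch.randn(1, 3, 8)                              # (batch, seq_len, d_model)
wi_0 = torch.nn.Linear(8, 16, bias=False)             # activated branch
wi_1 = torch.nn.Linear(8, 16, bias=False)             # linear "gate" branch
wo = torch.nn.Linear(16, 8, bias=False)               # projection back to d_model
y = wo(torch.nn.functional.gelu(wi_0(x)) * wi_1(x))   # same dataflow as T5DenseGatedActDense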
class T5LayerFF(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
if config.is_gated_act:
self.DenseReluDense = T5DenseGatedActDense(config)
else:
self.DenseReluDense = T5DenseActDense(config)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, hidden_states):
forwarded_states = self.layer_norm(hidden_states)
forwarded_states = self.DenseReluDense(forwarded_states)
hidden_states = hidden_states + self.dropout(forwarded_states)
return hidden_states
class T5Attention(nn.Module):
def __init__(self, config: T5Config, has_relative_attention_bias=False):
super().__init__()
self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias
self.relative_attention_num_buckets = config.relative_attention_num_buckets
self.relative_attention_max_distance = config.relative_attention_max_distance
self.d_model = config.d_model
self.key_value_proj_dim = config.d_kv
self.n_heads = config.num_heads
self.dropout = config.dropout_rate
self.inner_dim = self.n_heads * self.key_value_proj_dim
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
self.pruned_heads = set()
self.gradient_checkpointing = False
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
)
self.q = prune_linear_layer(self.q, index)
self.k = prune_linear_layer(self.k, index)
self.v = prune_linear_layer(self.v, index)
self.o = prune_linear_layer(self.o, index, dim=1)
self.n_heads = self.n_heads - len(heads)
self.inner_dim = self.key_value_proj_dim * self.n_heads
self.pruned_heads = self.pruned_heads.union(heads)
@staticmethod
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
"""
Adapted from Mesh Tensorflow:
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
Translate relative position to a bucket number for relative attention. The relative position is defined as
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the model has been trained on
Args:
relative_position: an int32 Tensor - the difference in positions between memory and query
bidirectional: a boolean - whether the attention is bidirectional or not
num_buckets: an integer - number of buckets to categorize relative positions into
max_distance: an integer - maximum distance for categorizing relative positions
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
max_exact = num_buckets // 2
is_small = relative_position < max_exact
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_position_if_large = torch.min(
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
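# Illustrative sketch (not part of modeling_t5.py): bucketing a few relative positions with the
# defaults used on the encoder side (bidirectional=True, 32 buckets, max_distance=128). Nearby
# offsets each get their own bucket, distant offsets share logarithmically spaced buckets, and the
# sign of the offset selects the lower or upper half of the bucket range.
import torch
rel_pos = torch.tensor([[-64, -3, -1, 0, 1, 3, 64]])
buckets = T5Attention._relative_position_bucket(rel_pos, bidirectional=True, num_buckets=32, max_distance=128)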
def compute_bias(self, query_length, key_length, device=None):
"""Compute binned relative position bias"""
if device is None:
device = self.relative_attention_bias.weight.device
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position
relative_position_bucket = self._relative_position_bucket(
relative_position,
bidirectional=(not self.is_decoder),
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(relative_position_bucket)
values = values.permute([2, 0, 1]).unsqueeze(0)
return values
def forward(
self,
hidden_states,
mask=None,
key_value_states=None,
position_bias=None,
past_key_value=None,
layer_head_mask=None,
query_length=None,
use_cache=False,
output_attentions=False,
class T5LayerSelfAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
hidden_states,
attention_mask=None,
position_bias=None,
layer_head_mask=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
):
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention(
normed_hidden_states,
mask=attention_mask,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = hidden_states + self.dropout(attention_output[0])
outputs = (hidden_states,) + attention_output[1:]
return outputs
class T5LayerCrossAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
hidden_states,
key_value_states,
attention_mask=None,
position_bias=None,
layer_head_mask=None,
past_key_value=None,
use_cache=False,
query_length=None,
output_attentions=False,
):
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention(
normed_hidden_states,
mask=attention_mask,
key_value_states=key_value_states,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
query_length=query_length,
output_attentions=output_attentions,
)
layer_output = hidden_states + self.dropout(attention_output[0])
outputs = (layer_output,) + attention_output[1:]
return outputs
class T5Block(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.is_decoder = config.is_decoder
self.layer = nn.ModuleList()
self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
if self.is_decoder:
self.layer.append(T5LayerCrossAttention(config))
self.layer.append(T5LayerFF(config))
class T5ClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(self, config: T5Config):
super().__init__()
self.dense = nn.Linear(config.d_model, config.d_model)
self.dropout = nn.Dropout(p=config.classifier_dropout)
self.out_proj = nn.Linear(config.d_model, config.num_labels)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dropout(hidden_states)
hidden_states = self.dense(hidden_states)
hidden_states = torch.tanh(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.out_proj(hidden_states)
return hidden_states
class T5PreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = T5Config
load_tf_weights = load_tf_weights_in_t5
base_model_prefix = "transformer"
is_parallelizable = True
supports_gradient_checkpointing = True
_no_split_modules = ["T5Block"]
_keep_in_fp32_modules = ["wo"]
@property
def dummy_inputs(self):
input_ids = torch.tensor(DUMMY_INPUTS)
input_mask = torch.tensor(DUMMY_MASK)
dummy_inputs = {
"decoder_input_ids": input_ids,
"input_ids": input_ids,
"decoder_attention_mask": input_mask,
}
return dummy_inputs
def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id
pad_token_id = self.config.pad_token_id
if decoder_start_token_id is None:
raise ValueError(
"self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. "
"See T5 docs for more information."
)
if is_torch_fx_proxy(input_ids):
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
else:
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError("self.model.config.pad_token_id has to be defined.")
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
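# Illustrative sketch (not part of modeling_t5.py): _shift_right builds decoder inputs from labels by
# prepending decoder_start_token_id (the pad token, id 0, for the released T5 checkpoints), dropping
# the last label, and replacing any remaining -100 ignore-index entries with the pad token. The
# non-FX branch above boils down to:
import torch
labels = torch.tensor([[37, 423, 52, 1, -100]])           # made-up label ids
decoder_start_token_id = pad_token_id = 0
shifted = labels.new_zeros(labels.shape)
shifted[..., 1:] = labels[..., :-1].clone()
shifted[..., 0] = decoder_start_token_id
shifted.masked_fill_(shifted == -100, pad_token_id)        # -> tensor([[0, 37, 423, 52, 1]])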
class T5Stack(T5PreTrainedModel):
def __init__(self, config, embed_tokens=None):
super().__init__(config)
self.embed_tokens = embed_tokens
self.is_decoder = config.is_decoder
self.block = nn.ModuleList(
[T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
)
self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
self.post_init()
self.model_parallel = False
self.device_map = None
self.gradient_checkpointing = False
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
warnings.warn(
"`T5Stack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
" with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
" `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
" 'block.1': 1, ...}",
FutureWarning,
)
self.device_map = (
get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
)
assert_device_map(self.device_map, len(self.block))
self.model_parallel = True
self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
self.last_device = "cuda:" + str(max(self.device_map.keys()))
for k, v in self.device_map.items():
for layer in v:
cuda_device = "cuda:" + str(k)
self.block[layer] = self.block[layer].to(cuda_device)
self.embed_tokens = self.embed_tokens.to(self.first_device)
self.final_layer_norm = self.final_layer_norm.to(self.last_device)
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
warnings.warn(
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
FutureWarning,
)
self.model_parallel = False
self.device_map = None
self.first_device = "cpu"
self.last_device = "cpu"
for i in range(len(self.block)):
self.block[i] = self.block[i].to("cpu")
self.embed_tokens = self.embed_tokens.to("cpu")
self.final_layer_norm = self.final_layer_norm.to("cpu")
torch.cuda.empty_cache()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, new_embeddings):
self.embed_tokens = new_embeddings
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
inputs_embeds=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
T5_START_DOCSTRING = r"""
The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a
text-to-text denoising generative setting.
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`T5Config`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
T5_INPUTS_DOCSTRING = r"""
"""
T5_ENCODER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        output_attentions (`bool`, *optional*):
        output_hidden_states (`bool`, *optional*):
        return_dict (`bool`, *optional*):
"""
# Warning message: the `head_mask` argument was split into two arguments, `head_mask` and `decoder_head_mask`
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""
@add_start_docstrings(
"The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
T5_START_DOCSTRING,
)
class T5Model(T5PreTrainedModel):
    # Unexpected keys to ignore when loading the model
_keys_to_ignore_on_load_unexpected = [
"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]
    # Keys whose weights are tied
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: T5Config):
super().__init__(config)
        # Shared embedding layer mapping the vocabulary to the model dimension
        self.shared = nn.Embedding(config.vocab_size, config.d_model)
        # Copy the config and set up the encoder
        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)
        # Copy the config and set up the decoder
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)
        # Initialize weights and apply final processing
        self.post_init()
        # Model-parallelism flags
        self.model_parallel = False
        self.device_map = None
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
        # Warn that this method is deprecated and will be removed
warnings.warn(
"`T5Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
" with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
" `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':"
" 0, 'encoder.block.1': 1, ...}",
FutureWarning,
)
        # Use the provided device map, or build a balanced one automatically
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        # Validate the device map
        assert_device_map(self.device_map, len(self.encoder.block))
        # Parallelize the encoder and the decoder
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.model_parallel = True
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    # Emit a deprecation warning: this feature will be removed in Transformers v5
def deparallelize(self):
warnings.warn(
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
FutureWarning,
)
        # Undo the encoder's parallelization
        self.encoder.deparallelize()
        # Undo the decoder's parallelization
        self.decoder.deparallelize()
        # Move the encoder back to the CPU
        self.encoder = self.encoder.to("cpu")
        # Move the decoder back to the CPU
        self.decoder = self.decoder.to("cpu")
        # Disable model parallelism
        self.model_parallel = False
        # Clear the device map
        self.device_map = None
        # Free cached CUDA memory
        torch.cuda.empty_cache()
    # Return the input embedding layer
    def get_input_embeddings(self):
        return self.shared
    # Set a new input embedding layer
    def set_input_embeddings(self, new_embeddings):
        # Update the shared embedding layer
        self.shared = new_embeddings
        # Update the encoder's input embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        # Update the decoder's input embeddings
        self.decoder.set_input_embeddings(new_embeddings)
    # Internal helper that ties weights when the config requests it
    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            # Tie (or clone) the encoder's word embeddings to the shared embedding layer
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
            # Tie (or clone) the decoder's word embeddings to the shared embedding layer
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
    # Return the encoder
    def get_encoder(self):
        return self.encoder
    # Return the decoder
    def get_decoder(self):
        return self.decoder
    # Internal helper that prunes attention heads of the model
    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # Iterate over the layers to prune and the heads selected in each layer
        for layer, heads in heads_to_prune.items():
            # Prune the requested heads from the attention module of that encoder layer
            self.encoder.layer[layer].attention.prune_heads(heads)
    # Forward pass of the model
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
decoder_inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
class T5ForConditionalGeneration(T5PreTrainedModel):
    # Unexpected keys to ignore when loading
_keys_to_ignore_on_load_unexpected = [
"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]
    # Keys whose weights are tied
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: T5Config):
super().__init__(config)
        # Model dimension
        self.model_dim = config.d_model
        # Shared embedding layer over the vocabulary with the model dimension
        self.shared = nn.Embedding(config.vocab_size, config.d_model)
        # Copy the encoder config and mark it as a non-decoder stack
        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        # Build the encoder
        self.encoder = T5Stack(encoder_config, self.shared)
        # Copy the decoder config and mark it as a decoder stack
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        # Build the decoder
        self.decoder = T5Stack(decoder_config, self.shared)
        # Language-modeling head: a bias-free linear layer from d_model to the vocabulary size
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()
        # Model parallelism
        self.model_parallel = False
        self.device_map = None
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
        # Warn that this method is deprecated
warnings.warn(
"`T5ForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you"
" should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also"
" provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance"
" {'encoder.block.0': 0, 'encoder.block.1': 1, ...}",
FutureWarning,
)
        # Use the provided device map, or build a balanced one if none is given
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        # Validate the device map
        assert_device_map(self.device_map, len(self.encoder.block))
        # Parallelize the encoder
        self.encoder.parallelize(self.device_map)
        # Parallelize the decoder
        self.decoder.parallelize(self.device_map)
        # Move the language-modeling head to the decoder's first device
        self.lm_head = self.lm_head.to(self.decoder.first_device)
        # Flag the model as running in model-parallel mode
        self.model_parallel = True
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
        # Warn that this method is deprecated
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        # Undo the encoder's parallelization
        self.encoder.deparallelize()
        # Undo the decoder's parallelization
        self.decoder.deparallelize()
        # Move the encoder back to the CPU
        self.encoder = self.encoder.to("cpu")
        # Move the decoder back to the CPU
        self.decoder = self.decoder.to("cpu")
        # Move the language-modeling head back to the CPU
        self.lm_head = self.lm_head.to("cpu")
        # Disable model parallelism
        self.model_parallel = False
        self.device_map = None
        # Free cached CUDA memory
        torch.cuda.empty_cache()
    def get_input_embeddings(self):
        # Return the shared embedding layer
        return self.shared
    # Set the model's input word embeddings
    def set_input_embeddings(self, new_embeddings):
        # Assign the new embeddings to the shared embedding layer
        self.shared = new_embeddings
        # Update the encoder's input embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        # Update the decoder's input embeddings
        self.decoder.set_input_embeddings(new_embeddings)
    # Tie (or clone) weights so the encoder and decoder share the same word embeddings
    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            # Tie or clone the encoder's embeddings to the shared embedding layer
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
            # Tie or clone the decoder's embeddings to the shared embedding layer
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
    # Set the model's output word embeddings
    def set_output_embeddings(self, new_embeddings):
        # Replace the language-modeling head with the new embeddings
        self.lm_head = new_embeddings
    # Return the model's output word embeddings
    def get_output_embeddings(self):
        return self.lm_head
    # Return the encoder
    def get_encoder(self):
        return self.encoder
    # Return the decoder
    def get_decoder(self):
        return self.decoder
    # Forward pass: maps the T5 model's inputs to its outputs
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
        # The body of the T5 forward pass is omitted in this excerpt
    # Prepare the inputs used for generation (mainly for text generation)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
decoder_attention_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
        # If past key/value states (past_key_values) are provided, crop decoder_input_ids accordingly
        if past_key_values is not None:
            # Length of the cached past key/value states
            past_length = past_key_values[0][0].shape[2]
            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                # If input_ids is longer than the cached past, strip the already-processed prefix
                remove_prefix_length = past_length
            else:
                # Default to the old behavior: keep only the final ID
                remove_prefix_length = input_ids.shape[1] - 1
            input_ids = input_ids[:, remove_prefix_length:]
        return {
            "decoder_input_ids": input_ids,  # the (possibly cropped) decoder input IDs
            "past_key_values": past_key_values,  # cached key/value states
            "encoder_outputs": encoder_outputs,  # encoder outputs
            "attention_mask": attention_mask,  # attention mask
            "head_mask": head_mask,  # head mask
            "decoder_head_mask": decoder_head_mask,  # decoder head mask
            "decoder_attention_mask": decoder_attention_mask,  # decoder attention mask
            "cross_attn_head_mask": cross_attn_head_mask,  # cross-attention head mask
            "use_cache": use_cache,  # whether to use the cache
        }
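# Illustrative sketch (not part of modeling_t5.py): once past_key_values holds 4 decoded positions,
# only the newest token has to be fed back to the decoder on the next generation step.
import torch
input_ids = torch.tensor([[0, 37, 423, 52, 1]])   # every decoder token produced so far (made up)
past_length = 4                                    # length already cached in past_key_values
input_ids = input_ids[:, past_length:]             # -> tensor([[1]]), only the newest token remains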
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        # Build the decoder input IDs from the labels by shifting them one position to the right
return self._shift_right(labels)
    def _reorder_cache(self, past_key_values, beam_idx):
        # If the decoder past states are not included in the outputs,
        # fast decoding is disabled and there is nothing to reorder
        if past_key_values is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past_key_values
        reordered_decoder_past = ()
        for layer_past_states in past_key_values:
            # Get the correct batch index from the cached states of this layer;
            # the batch dimension of `past` is at the 2nd position
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # Set the correct `past` for each of the four key/value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                )
if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
raise ValueError(
f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
)
if len(reordered_layer_past_states) != len(layer_past_states):
raise ValueError(
f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
)
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
return reordered_decoder_past
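# Illustrative sketch (not part of modeling_t5.py): during beam search the cached key/value tensors
# have to follow the surviving beams, which is exactly the index_select over the batch dimension
# performed above. Shapes and indices here are made up.
import torch
cached_key = torch.arange(6, dtype=torch.float).view(3, 1, 2, 1)  # (beams, heads, seq_len, head_dim)
beam_idx = torch.tensor([2, 0, 0])                                 # beam 2 survives, beam 0 is duplicated
reordered = cached_key.index_select(0, beam_idx)                   # rows picked in the new beam order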
# Class docstring: the bare T5 encoder model, outputting the encoder's raw hidden states without any task-specific head on top
@add_start_docstrings(
"The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
T5_START_DOCSTRING,
)
class T5EncoderModel(T5PreTrainedModel):
    # Keys whose weights are tied when loading the model
    _tied_weights_keys = ["encoder.embed_tokens.weight"]
    # Unexpected keys to ignore when loading; anything containing "decoder" is dropped
    _keys_to_ignore_on_load_unexpected = [r"decoder"]
    def __init__(self, config: T5Config):
        super().__init__(config)
        # Shared embedding layer of size config.vocab_size x config.d_model
        self.shared = nn.Embedding(config.vocab_size, config.d_model)
        # Copy the config for the encoder and adjust a few attributes
        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        # Build the T5 encoder stack on top of the shared embeddings
        self.encoder = T5Stack(encoder_config, self.shared)
        # Initialize weights and apply final processing
        self.post_init()
        # Model-parallelism attributes
        self.model_parallel = False
        self.device_map = None
    # Model-parallelization docstring; warns that this method will be removed in a later version
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
warnings.warn(
"`T5EncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
" your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
" `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
" 'block.1': 1, ...}",
FutureWarning,
)
        # Use the provided device map, or default to a balanced one if none is given
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        # Validate the device map
        assert_device_map(self.device_map, len(self.encoder.block))
        # Parallelize the encoder and flag the model as model-parallel
        self.encoder.parallelize(self.device_map)
        self.model_parallel = True
    # De-parallelization docstring; warns that this method will be removed in a later version
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
warnings.warn(
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
FutureWarning,
)
        # Deparallelize the encoder, move it to the CPU, and turn off model parallelism
        self.encoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        # Free cached CUDA memory
        torch.cuda.empty_cache()
    # Return the shared embedding layer
    def get_input_embeddings(self):
        return self.shared
    # Set a new input embedding layer and propagate it to the encoder
    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
    # Private helper: if the config requests tied word embeddings, tie the encoder's embeddings to the shared layer
    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
    # Return the encoder
    def get_encoder(self):
        return self.encoder
    # Prune attention heads; heads_to_prune maps each layer number to the list of heads to prune in that layer
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
            # Prune the selected heads in the self-attention module of that encoder layer
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
    # Attach the T5 encoder inputs docstring to the forward method
    @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
    # Replace the return-value docstring with output type BaseModelOutput, using the _CONFIG_FOR_DOC config class
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
    # Forward pass: accepts several optional tensor inputs and returns either a tuple of torch.FloatTensor or a BaseModelOutput
def forward(
self,
        input_ids: Optional[torch.LongTensor] = None,  # input token IDs
        attention_mask: Optional[torch.FloatTensor] = None,  # attention mask
        head_mask: Optional[torch.FloatTensor] = None,  # head mask
        inputs_embeds: Optional[torch.FloatTensor] = None,  # pre-computed input embeddings
        output_attentions: Optional[bool] = None,  # whether to return attention weights
        output_hidden_states: Optional[bool] = None,  # whether to return hidden states
        return_dict: Optional[bool] = None,  # whether to return a dict-style output
) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
r"""
        Returns:
            The encoder outputs; `return_dict` falls back to `self.config.use_return_dict` when not given.
        Example:
```
>>> from transformers import AutoTokenizer, T5EncoderModel
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
>>> model = T5EncoderModel.from_pretrained("google-t5/t5-small")
>>> input_ids = tokenizer(
... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
... ).input_ids
>>> outputs = model(input_ids=input_ids)
>>> last_hidden_states = outputs.last_hidden_state
```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Run the encoder over the inputs and collect its outputs
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
        # Return the encoder outputs
return encoder_outputs
"""
T5 model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
tasks.
"""
# T5 model with a sequence-classification head on top (a linear layer over the pooled output), e.g. for GLUE tasks
@add_start_docstrings(
"""
T5 Encoder Model with a token classification head on top (a linear layer on top of the hidden-states output)
e.g. for Named-Entity-Recognition (NER) tasks.
""",
T5_START_DOCSTRING,
)
class T5ForTokenClassification(T5PreTrainedModel):
    # Keys whose weights are tied
    _tied_weights_keys = ["transformer.encoder.embed_tokens.weight"]
    def __init__(self, config: T5Config):
        super().__init__(config)
        # Number of classification labels from the T5 config
        self.num_labels = config.num_labels
        # Underlying T5 encoder model
        self.transformer = T5EncoderModel(config)
        # Dropout applied before the classifier
        self.dropout = nn.Dropout(config.classifier_dropout)
        # Linear token-classification head
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
        input_ids: Optional[torch.Tensor] = None,  # input token IDs (optional)
        attention_mask: Optional[torch.Tensor] = None,  # attention mask (optional)
        head_mask: Optional[torch.Tensor] = None,  # head mask (optional)
        inputs_embeds: Optional[torch.Tensor] = None,  # pre-computed input embeddings (optional)
        labels: Optional[torch.Tensor] = None,  # labels used to compute the token-classification loss (optional)
        output_attentions: Optional[bool] = None,  # whether to return attention weights (optional)
        output_hidden_states: Optional[bool] = None,  # whether to return hidden states (optional)
        return_dict: Optional[bool] = None,  # whether to return a dict-style output (optional)
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
Returns:
"""
        # Decide whether to use the return-dict setting from the config
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Run the inputs through the underlying transformer (encoder) and collect its outputs
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
        # Take the hidden states from the outputs and apply dropout
        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)
        # Feed the hidden states to the classifier to obtain the per-token logits
        logits = self.classifier(hidden_states)
        # If labels are provided, compute the loss
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        # Return the outputs in the format selected by return_dict
        if not return_dict:
            output = (logits, outputs[2:-1])  # tuple output used when a dict is not requested
            return ((loss,) + output) if loss is not None else output
        # Return a TokenClassifierOutput containing the loss, logits, hidden states, and attentions
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
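# Minimal usage sketch (an assumption, not taken from the library docs): any T5 checkpoint can be
# loaded into T5ForTokenClassification, but the classification head is newly initialized, so the
# logits below are untrained and only their shape is meaningful.
from transformers import AutoTokenizer, T5ForTokenClassification
import torch
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
model = T5ForTokenClassification.from_pretrained("google-t5/t5-small", num_labels=3)
inputs = tokenizer("HuggingFace is based in New York City", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # (batch_size, sequence_length, num_labels)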
# Model docstring: a T5 model with a span-classification head for extractive question answering such as SQuAD (linear layers on top of the hidden states compute the span start and end logits)
@add_start_docstrings(
"""
T5 Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers
on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
T5_START_DOCSTRING,
)
class T5ForQuestionAnswering(T5PreTrainedModel):
    # Unexpected keys to ignore when loading
    _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
    # Keys whose weights are tied
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
    # Constructor, taking a T5Config object
    def __init__(self, config: T5Config):
        super().__init__(config)
        # Model dimension taken from the config's d_model
        self.model_dim = config.d_model
        # Shared word embedding layer built from vocab_size and d_model
        self.shared = nn.Embedding(config.vocab_size, config.d_model)
        # Copy the config to build the encoder
        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        # Build the encoder from a T5Stack and the shared embeddings
        self.encoder = T5Stack(encoder_config, self.shared)
        # Copy the config to build the decoder
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        # Build the decoder from a T5Stack and the shared embeddings
        self.decoder = T5Stack(decoder_config, self.shared)
        # Number of output labels from the config
        self.num_labels = config.num_labels
        # Linear layer mapping the hidden size to the number of labels
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()
        # Model parallelism is disabled by default
        self.model_parallel = False
    # Return the shared word embedding layer
    def get_input_embeddings(self):
        return self.shared
    # Set a new input embedding layer
    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        # Propagate the new embeddings to both the encoder and the decoder
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)
    # Tie weights when tie_word_embeddings is set in the config
    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            # Tie or clone the weights so the encoder and decoder share the embedding layer
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
    # Return the encoder
    def get_encoder(self):
        return self.encoder
    # Return the decoder
    def get_decoder(self):
        return self.decoder
    # Decorators for the forward method: attach the T5 inputs docstring
    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    # and replace the return docstring with Seq2SeqQuestionAnsweringModelOutput, using _CONFIG_FOR_DOC
    @replace_return_docstrings(output_type=Seq2SeqQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
    # Forward pass of the question-answering model; all arguments are optional PyTorch tensors or booleans
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,  # input token IDs
        attention_mask: Optional[torch.FloatTensor] = None,  # attention mask for the inputs
        decoder_input_ids: Optional[torch.LongTensor] = None,  # decoder input token IDs
        decoder_attention_mask: Optional[torch.BoolTensor] = None,  # decoder attention mask
        head_mask: Optional[torch.FloatTensor] = None,  # head mask for the encoder's attention heads
        decoder_head_mask: Optional[torch.FloatTensor] = None,  # head mask for the decoder's attention heads
        cross_attn_head_mask: Optional[torch.Tensor] = None,  # head mask for the cross-attention
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,  # precomputed encoder outputs
        start_positions: Optional[torch.LongTensor] = None,  # gold start positions used for the loss
        end_positions: Optional[torch.LongTensor] = None,  # gold end positions used for the loss
        inputs_embeds: Optional[torch.FloatTensor] = None,  # pre-computed input embeddings
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,  # pre-computed decoder input embeddings
        use_cache: Optional[bool] = None,  # whether to use the decoder cache
        output_attentions: Optional[bool] = None,  # whether to return attention weights
        output_hidden_states: Optional[bool] = None,  # whether to return hidden states
        return_dict: Optional[bool] = None,  # whether to return a dict-style output