Transformers 源码解析(四十六)
.\models\ernie_m\tokenization_ernie_m.py
import io
import os
import unicodedata
from typing import Any, Dict, List, Optional, Tuple
import sentencepiece as spm
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
logger = logging.get_logger(__name__)
SPIECE_UNDERLINE = "▁"
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "sentencepiece_model_ckpt": "sentencepiece.bpe.model"}
RESOURCE_FILES_NAMES = {
"sentencepiece_model_file": "sentencepiece.bpe.model",
"vocab_file": "vocab.txt",
}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"ernie-m-base": "https://huggingface.co/susnato/ernie-m-base_pytorch/blob/main/vocab.txt",
"ernie-m-large": "https://huggingface.co/susnato/ernie-m-base_pytorch/blob/main/vocab.txt",
},
"sentencepiece_model_file": {
"ernie-m-base": "https://huggingface.co/susnato/ernie-m-base_pytorch/blob/main/sentencepiece.bpe.model",
"ernie-m-large": "https://huggingface.co/susnato/ernie-m-base_pytorch/blob/main/sentencepiece.bpe.model",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"ernie-m-base": 514,
"ernie-m-large": 514,
}
PRETRAINED_INIT_CONFIGURATION = {
"ernie-m-base": {"do_lower_case": False},
"ernie-m-large": {"do_lower_case": False},
}
class ErnieMTokenizer(PreTrainedTokenizer):
r"""
Constructs a Ernie-M tokenizer. It uses the `sentencepiece` tools to cut the words to sub-words.
"""
def __init__(
self,
vocab_file: Optional[str] = None,
sentencepiece_model_file: Optional[str] = None,
do_lower_case=False,
**kwargs
):
"""
:param vocab_file: 词汇文件的路径(可选)
:param sentencepiece_model_file: sentencepiece 模型文件的路径(可选)
:param do_lower_case: 是否将所有输入转换为小写(默认为 False)
"""
super().__init__(
vocab_file=vocab_file,
sentencepiece_model_file=sentencepiece_model_file,
do_lower_case=do_lower_case,
**kwargs
)
"""
Args:
sentencepiece_model_ckpt (`str`):
某句段语法模型检查点的路径, 用于序列到序列的编码和解码任务.
vocab_file (`str`, *optional*):
字典文件路径, 若未提供则继承默认词汇表.
do_lower_case (`str`, *optional*, defaults to `True`):
是否将输入文本转换为小写, 当在数据预处理阶段处理文本时启用.
encoding (`str`, *optional*, defaults to `utf8`):
编码方式, 默认使用UTF-8用于解析输入数据.
unk_token (`str`, *optional*, defaults to `"[UNK]"`):
未知词汇(外域词汇)的标记, 用于替换未在词汇表中的词汇.
sep_token (`str`, *optional*, defaults to `"[SEP]"`):
用于分隔不同句子在同一批文本序列中.
pad_token (`str`, *optional*, defaults to `"[PAD]"`):
用于填充序列, 使所有序列长度相等适用于批处理.
cls_token (`str`, *optional*, defaults to `"[CLS]"`):
分类器的标志符, 表示序列开始的典型符号.
mask_token (`str`, *optional*, defaults to `"[MASK]"`):
用于替换的标记符号, 该模型将其视为需要预测原始未掩码的令牌的例证.
sp_model_kwargs: `Optional[Dict[str, Any]]` = None:
用于初始化句段模型的可选参数字典.
kwargs:
其他可能的初始化参数, 用于扩展上述参数的功能.
"""
model_input_names: List[str] = ["input_ids"]
vocab_files_names = VOCAB_FILES_NAMES
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
resource_files_names = RESOURCE_FILES_NAMES
def __init__(
self,
sentencepiece_model_ckpt,
vocab_file=None,
do_lower_case=False,
encoding="utf8",
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
sp_model_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
    ) -> None:
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
self.do_lower_case = do_lower_case
self.sentencepiece_model_ckpt = sentencepiece_model_ckpt
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(sentencepiece_model_ckpt)
if vocab_file is not None:
self.vocab = self.load_vocab(filepath=vocab_file)
else:
self.vocab = {self.sp_model.id_to_piece(id): id for id in range(self.sp_model.get_piece_size())}
self.reverse_vocab = {v: k for k, v in self.vocab.items()}
super().__init__(
do_lower_case=do_lower_case,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
vocab_file=vocab_file,
encoding=encoding,
sp_model_kwargs=self.sp_model_kwargs,
**kwargs,
)
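# Usage sketch (standalone, not part of the class above): load the tokenizer from the
# checkpoint referenced in the pretrained maps at the top of this file. This is only an
# illustration; it requires network access and the `sentencepiece` package.
from transformers import ErnieMTokenizer

tokenizer = ErnieMTokenizer.from_pretrained("susnato/ernie-m-base_pytorch")
print(tokenizer.tokenize("Hello world"))      # SentencePiece sub-word pieces
print(tokenizer("Hello world")["input_ids"])  # ids with [CLS]/[SEP] added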
def get_offset_mapping(self, text):
if text is None:
return None
split_tokens = self.tokenize(text)
normalized_text, char_mapping = "", []
for i, ch in enumerate(text):
if ch in self.SP_CHAR_MAPPING:
ch = self.SP_CHAR_MAPPING.get(ch)
else:
ch = unicodedata.normalize("NFKC", ch)
if self.is_whitespace(ch):
continue
normalized_text += ch
char_mapping.extend([i] * len(ch))
text, token_mapping, offset = normalized_text, [], 0
if self.do_lower_case:
text = text.lower()
for token in split_tokens:
if token[:1] == "▁":
token = token[1:]
start = text[offset:].index(token) + offset
end = start + len(token)
token_mapping.append((char_mapping[start], char_mapping[end - 1] + 1))
offset = end
return token_mapping
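# Standalone sketch of the character-mapping step above (simplified: SP_CHAR_MAPPING and
# lowercasing are skipped). Each character is NFKC-normalised, whitespace is dropped, and
# the original index that produced every normalised character is remembered.
import unicodedata

def build_char_mapping(text):
    normalized_text, char_mapping = "", []
    for i, ch in enumerate(text):
        ch = unicodedata.normalize("NFKC", ch)
        if ch in (" ", "\t", "\n", "\r"):
            continue
        normalized_text += ch
        char_mapping.extend([i] * len(ch))
    return normalized_text, char_mapping

print(build_char_mapping("Hello  world"))  # ('Helloworld', [0, 1, 2, 3, 4, 7, 8, 9, 10, 11])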
@property
def vocab_size(self):
return len(self.vocab)
def get_vocab(self):
return dict(self.vocab, **self.added_tokens_encoder)
def __getstate__(self):
state = self.__dict__.copy()
state["sp_model"] = None
return state
def __setstate__(self, d):
self.__dict__ = d
if not hasattr(self, "sp_model_kwargs"):
self.sp_model_kwargs = {}
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(self.sentencepiece_model_ckpt)
def clean_text(self, text):
return "".join((self.SP_CHAR_MAPPING.get(c, c) for c in text))
def _tokenize(self, text, enable_sampling=False, nbest_size=64, alpha=0.1):
"""Tokenize a string."""
if self.sp_model_kwargs.get("enable_sampling") is True:
enable_sampling = True
if self.sp_model_kwargs.get("alpha") is not None:
alpha = self.sp_model_kwargs.get("alpha")
if self.sp_model_kwargs.get("nbest_size") is not None:
nbest_size = self.sp_model_kwargs.get("nbest_size")
if not enable_sampling:
pieces = self.sp_model.EncodeAsPieces(text)
else:
pieces = self.sp_model.SampleEncodeAsPieces(text, nbest_size, alpha)
new_pieces = []
for pi, piece in enumerate(pieces):
if piece == SPIECE_UNDERLINE:
if not pieces[pi + 1].startswith(SPIECE_UNDERLINE) and pi != 0:
new_pieces.append(SPIECE_UNDERLINE)
continue
else:
continue
lst_i = 0
for i, chunk in enumerate(piece):
if chunk == SPIECE_UNDERLINE:
continue
if self.is_ch_char(chunk) or self.is_punct(chunk):
if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE:
new_pieces.append(piece[lst_i:i])
new_pieces.append(chunk)
lst_i = i + 1
elif chunk.isdigit() and i > 0 and not piece[i - 1].isdigit():
if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE:
new_pieces.append(piece[lst_i:i])
lst_i = i
elif not chunk.isdigit() and i > 0 and piece[i - 1].isdigit():
if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE:
new_pieces.append(piece[lst_i:i])
lst_i = i
if len(piece) > lst_i:
new_pieces.append(piece[lst_i:])
return new_pieces
def convert_tokens_to_string(self, tokens):
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
return out_string
def convert_ids_to_string(self, ids):
tokens = self.convert_ids_to_tokens(ids)
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
return out_string
def _convert_token_to_id(self, token):
return self.vocab.get(token, self.vocab.get(self.unk_token))
def _convert_id_to_token(self, index):
"""
Converts an index (integer) into a token (str) using the vocabulary.
Args:
index (int): Index to convert into a token.
Returns:
str: The corresponding token if found in the vocabulary, otherwise returns the unknown token (self.unk_token).
"""
return self.reverse_vocab.get(index, self.unk_token)
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
"""
Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
adding special tokens. An ErnieM sequence has the following format:
- single sequence: `[CLS] X [SEP]`
- pair of sequences: `[CLS] A [SEP] [SEP] B [SEP]`
Args:
    token_ids_0 (List[int]): List of IDs to which the special tokens will be added.
    token_ids_1 (List[int], optional): Optional second list of IDs for sequence pairs.
Returns:
    List[int]: List of input IDs with the appropriate special tokens.
"""
if token_ids_1 is None:
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
_cls = [self.cls_token_id]
_sep = [self.sep_token_id]
return _cls + token_ids_0 + _sep + _sep + token_ids_1 + _sep
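# Standalone illustration of the special-token layouts built above, using made-up ids
# (cls=1, sep=2) rather than the real vocabulary.
cls_id, sep_id = 1, 2
ids_a, ids_b = [11, 12, 13], [21, 22]
print([cls_id] + ids_a + [sep_id])                             # [CLS] A [SEP]
print([cls_id] + ids_a + [sep_id, sep_id] + ids_b + [sep_id])  # [CLS] A [SEP] [SEP] B [SEP]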
def build_offset_mapping_with_special_tokens(self, offset_mapping_0, offset_mapping_1=None):
"""
Build offset mappings by concatenating and adding the offsets of special tokens. An Ernie-M offset mapping has the
following format:
- single sequence: `(0,0) X (0,0)`
- pair of sequences: `(0,0) A (0,0) (0,0) B (0,0)`
Args:
    offset_mapping_0 (List[tuple]): List of char offsets to which the special tokens will be added.
    offset_mapping_1 (List[tuple], optional): Optional second list of wordpiece offsets for offset mapping pairs.
Returns:
    List[tuple]: List of wordpiece offsets with the appropriate offsets of special tokens.
"""
if offset_mapping_1 is None:
return [(0, 0)] + offset_mapping_0 + [(0, 0)]
return [(0, 0)] + offset_mapping_0 + [(0, 0), (0, 0)] + offset_mapping_1 + [(0, 0)]
def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
r"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `encode` method.
Args:
token_ids_0 (`List[int]`):
List of ids of the first sequence.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`str`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
`List[int]`:
The list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formatted with special tokens for the model."
)
return [1 if x in [self.sep_token_id, self.cls_token_id] else 0 for x in token_ids_0]
if token_ids_1 is not None:
return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1]
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create the token type IDs corresponding to the sequences passed. [What are token type
IDs?](../glossary#token-type-ids) Should be overridden in a subclass if the model has a special way of
building: those.
Args:
token_ids_0 (`List[int]`):
The first tokenized sequence.
token_ids_1 (`List[int]`, *optional*):
The second tokenized sequence.
Returns:
`List[int]`: The token type ids.
"""
if token_ids_1 is None:
return (len(token_ids_0) + 2) * [0]
return [0] * (len(token_ids_0) + 1) + [1] * (len(token_ids_1) + 3)
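# Quick standalone length check: a pair input [CLS] A [SEP] [SEP] B [SEP] has
# len(A) + len(B) + 4 positions, which matches [0] * (len(A) + 1) + [1] * (len(B) + 3).
len_a, len_b = 3, 2
token_type_ids = [0] * (len_a + 1) + [1] * (len_b + 3)
assert len(token_type_ids) == len_a + len_b + 4
print(token_type_ids)  # [0, 0, 0, 0, 1, 1, 1, 1, 1]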
def is_ch_char(self, char):
"""
is_ch_char
"""
if "\u4e00" <= char <= "\u9fff":
return True
return False
def is_alpha(self, char):
"""
is_alpha
"""
if ("a" <= char <= "z") or ("A" <= char <= "Z"):
return True
return False
def is_punct(self, char):
"""
is_punct
"""
if char in ",;:.?!~,;:。?!《》【】":
return True
return False
def is_whitespace(self, char):
"""
is whitespace
"""
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
if len(char) == 1:
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def load_vocab(self, filepath):
token_to_idx = {}
with io.open(filepath, "r", encoding="utf-8") as f:
for index, line in enumerate(f):
token = line.rstrip("\n")
token_to_idx[token] = int(index)
return token_to_idx
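# Standalone sketch of the vocabulary format load_vocab expects: one token per line, the
# 0-based line number is the id. A throwaway temporary file stands in for a real vocab.txt.
import tempfile

tokens = ["[PAD]", "[CLS]", "[SEP]", "hello"]
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as f:
    f.write("\n".join(tokens))
    vocab_path = f.name

token_to_idx = {}
with open(vocab_path, encoding="utf-8") as f:
    for index, line in enumerate(f):
        token_to_idx[line.rstrip("\n")] = index
print(token_to_idx)  # {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, 'hello': 3}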
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
index = 0
if os.path.isdir(save_directory):
vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
else:
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning(
f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
" Please check that the vocabulary is not corrupted!"
)
index = token_index
writer.write(token + "\n")
index += 1
tokenizer_model_file = os.path.join(save_directory, "sentencepiece.bpe.model")
with open(tokenizer_model_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (vocab_file,)
.\models\ernie_m\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available, is_torch_available
_import_structure = {
"configuration_ernie_m": ["ERNIE_M_PRETRAINED_CONFIG_ARCHIVE_MAP", "ErnieMConfig"],
}
try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_ernie_m"] = ["ErnieMTokenizer"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_ernie_m"] = [
"ERNIE_M_PRETRAINED_MODEL_ARCHIVE_LIST",
"ErnieMForMultipleChoice",
"ErnieMForQuestionAnswering",
"ErnieMForSequenceClassification",
"ErnieMForTokenClassification",
"ErnieMModel",
"ErnieMPreTrainedModel",
"ErnieMForInformationExtraction",
]
if TYPE_CHECKING:
from .configuration_ernie_m import ERNIE_M_PRETRAINED_CONFIG_ARCHIVE_MAP, ErnieMConfig
try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_ernie_m import ErnieMTokenizer
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_ernie_m import (
ERNIE_M_PRETRAINED_MODEL_ARCHIVE_LIST,
ErnieMForInformationExtraction,
ErnieMForMultipleChoice,
ErnieMForQuestionAnswering,
ErnieMForSequenceClassification,
ErnieMForTokenClassification,
ErnieMModel,
ErnieMPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
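# Usage sketch (standalone): with the lazy module in place, the usual top-level imports
# keep working and only resolve the submodules on first access. Assumes torch and
# sentencepiece are installed.
from transformers import ErnieMConfig, ErnieMModel

config = ErnieMConfig()
model = ErnieMModel(config)
print(model.config.hidden_size)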
.\models\esm\configuration_esm.py
"""
# coding=utf-8
# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" ESM model configuration"""
# Import necessary modules
from dataclasses import asdict, dataclass
from typing import Optional
# Import configuration utilities
from ...configuration_utils import PretrainedConfig
from ...utils import logging
# Get logger for this module
logger = logging.get_logger(__name__)
# TODO Update this
# Mapping of pretrained model names to their configuration URLs
ESM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/esm-1b": "https://huggingface.co/facebook/esm-1b/resolve/main/config.json",
# See all ESM models at https://huggingface.co/models?filter=esm
}
# Configuration class for the ESM model, inheriting from PretrainedConfig
class EsmConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`ESMModel`]. It is used to instantiate a ESM model
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the ESM
[facebook/esm-1b](https://huggingface.co/facebook/esm-1b) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Examples:
```python
>>> from transformers import EsmModel, EsmConfig

>>> # Initializing a ESM facebook/esm-1b style configuration
>>> configuration = EsmConfig()

>>> # Initializing a model from the configuration
>>> model = EsmModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "esm"
def __init__(
self,
vocab_size=None,
mask_token_id=None,
pad_token_id=None,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=1026,
initializer_range=0.02,
layer_norm_eps=1e-12,
position_embedding_type="absolute",
use_cache=True,
emb_layer_norm_before=None,
token_dropout=False,
is_folding_model=False,
esmfold_config=None,
vocab_list=None,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs)
# 调用父类的初始化方法,传入特定的参数来初始化当前类
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.emb_layer_norm_before = emb_layer_norm_before
self.token_dropout = token_dropout
self.is_folding_model = is_folding_model
# 初始化多个模型配置的参数
if is_folding_model:
if esmfold_config is None:
logger.info("No esmfold_config supplied for folding model, using default values.")
# 如果没有提供 esmfold_config 参数,则使用默认配置并记录日志信息
esmfold_config = EsmFoldConfig()
elif isinstance(esmfold_config, dict):
esmfold_config = EsmFoldConfig(**esmfold_config)
# 如果 esmfold_config 是一个字典,则根据字典内容创建 EsmFoldConfig 对象
self.esmfold_config = esmfold_config
if vocab_list is None:
logger.warning("No vocab_list supplied for folding model, assuming the ESM-2 vocabulary!")
# 如果没有提供 vocab_list 参数,则假设使用 ESM-2 词汇表,并记录警告信息
self.vocab_list = get_default_vocab_list()
else:
self.vocab_list = vocab_list
# 否则,使用提供的 vocab_list 参数
else:
self.esmfold_config = None
self.vocab_list = None
# 如果不是折叠模型,则将 esmfold_config 和 vocab_list 设置为 None
if self.esmfold_config is not None and getattr(self.esmfold_config, "use_esm_attn_map", False):
raise ValueError("The HuggingFace port of ESMFold does not support use_esm_attn_map at this time!")
# 如果 esmfold_config 不为 None,且其属性 use_esm_attn_map 为 True,则抛出值错误异常
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = super().to_dict()
# 调用父类的 to_dict 方法,将父类的序列化结果添加到 output 字典中
if isinstance(self.esmfold_config, EsmFoldConfig):
output["esmfold_config"] = self.esmfold_config.to_dict()
# 如果 esmfold_config 是 EsmFoldConfig 类型的对象,则将其序列化为字典并加入 output 中
return output
# 返回包含当前实例所有属性的字典作为序列化结果
# 数据类 EsmFoldConfig,用于配置 ESM 折叠模型的参数
@dataclass
class EsmFoldConfig:
# ESM 类型,默认为 None
esm_type: str = None
# 是否使用 FP16 格式的 ESM
fp16_esm: bool = True
# 是否使用 ESM 注意力映射
use_esm_attn_map: bool = False
# 是否剔除 ESM 的成对序列
esm_ablate_pairwise: bool = False
# 是否剔除 ESM 的序列
esm_ablate_sequence: bool = False
# ESM 输入的 dropout 概率
esm_input_dropout: float = 0
# 是否嵌入氨基酸信息
embed_aa: bool = True
# 是否绕过语言模型
bypass_lm: bool = False
# LDDT 头部隐藏维度
lddt_head_hid_dim: int = 128
# EsmFoldConfig 的 trunk 配置,如果为 None 则使用默认配置
trunk: "TrunkConfig" = None
# 初始化方法,在对象创建后调用,处理 trunk 属性
def __post_init__(self):
# 如果 trunk 为 None,则使用默认的 TrunkConfig
if self.trunk is None:
self.trunk = TrunkConfig()
# 如果 trunk 是 dict 类型,则将其转换为 TrunkConfig 对象
elif isinstance(self.trunk, dict):
self.trunk = TrunkConfig(**self.trunk)
# 将当前实例序列化为 Python 字典的方法
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
# 将当前实例转换为字典
output = asdict(self)
# 将 trunk 属性也转换为字典
output["trunk"] = self.trunk.to_dict()
return output
# 数据类 TrunkConfig,用于配置 ESM 折叠模型的 trunk 参数
@dataclass
class TrunkConfig:
# trunk 的块数
num_blocks: int = 48
# 序列状态维度
sequence_state_dim: int = 1024
# 成对状态维度
pairwise_state_dim: int = 128
# 序列头部宽度
sequence_head_width: int = 32
# 成对头部宽度
pairwise_head_width: int = 32
# 位置分箱数
position_bins: int = 32
# dropout 概率
dropout: float = 0
# 层丢弃概率
layer_drop: float = 0
# 是否使用 CPU 梯度检查点
cpu_grad_checkpoint: bool = False
# 最大循环次数
max_recycles: int = 4
# 分块大小
chunk_size: Optional[int] = 128
# 结构模块配置
structure_module: "StructureModuleConfig" = None
# 初始化方法,在对象实例化后自动调用。确保配置的正确性和一致性。
def __post_init__(self):
# 如果结构模块未指定,则使用默认的结构模块配置
if self.structure_module is None:
self.structure_module = StructureModuleConfig()
# 如果结构模块是一个字典,则将其转换为结构模块配置对象
elif isinstance(self.structure_module, dict):
self.structure_module = StructureModuleConfig(**self.structure_module)
# 检查最大循环次数是否大于零,否则抛出数值错误异常
if self.max_recycles <= 0:
raise ValueError(f"`max_recycles` should be positive, got {self.max_recycles}.")
# 检查序列状态维度是否是其自身的倍数,否则抛出数值错误异常
if self.sequence_state_dim % self.sequence_state_dim != 0:
raise ValueError(
"`sequence_state_dim` should be a round multiple of `sequence_state_dim`, got"
f" {self.sequence_state_dim} and {self.sequence_state_dim}."
)
# 检查成对状态维度是否是其自身的倍数,否则抛出数值错误异常
if self.pairwise_state_dim % self.pairwise_state_dim != 0:
raise ValueError(
"`pairwise_state_dim` should be a round multiple of `pairwise_state_dim`, got"
f" {self.pairwise_state_dim} and {self.pairwise_state_dim}."
)
# 计算序列头的数量,确保序列状态维度与序列头宽度的乘积相等
sequence_num_heads = self.sequence_state_dim // self.sequence_head_width
if self.sequence_state_dim != sequence_num_heads * self.sequence_head_width:
raise ValueError(
"`sequence_state_dim` should be equal to `sequence_num_heads * sequence_head_width, got"
f" {self.sequence_state_dim} != {sequence_num_heads} * {self.sequence_head_width}."
)
# 计算成对头的数量,确保成对状态维度与成对头宽度的乘积相等
pairwise_num_heads = self.pairwise_state_dim // self.pairwise_head_width
if self.pairwise_state_dim != pairwise_num_heads * self.pairwise_head_width:
raise ValueError(
"`pairwise_state_dim` should be equal to `pairwise_num_heads * pairwise_head_width, got"
f" {self.pairwise_state_dim} != {pairwise_num_heads} * {self.pairwise_head_width}."
)
# 检查成对状态维度是否为偶数,否则抛出数值错误异常
if self.pairwise_state_dim % 2 != 0:
raise ValueError(f"`pairwise_state_dim` should be even, got {self.pairwise_state_dim}.")
# 检查丢弃率是否小于0.4,否则抛出数值错误异常
if self.dropout >= 0.4:
raise ValueError(f"`dropout` should not be greater than 0.4, got {self.dropout}.")
# 将当前实例序列化为Python字典的方法。覆盖默认的to_dict方法。
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
# 将对象的所有属性转换为字典
output = asdict(self)
# 将结构模块属性转换为其对应的字典表示
output["structure_module"] = self.structure_module.to_dict()
return output
@dataclass
class StructureModuleConfig:
"""
A dataclass holding the configuration parameters of the structure module.
Args:
    sequence_dim:
        Single representation channel dimension
    pairwise_dim:
        Pair representation channel dimension
    ipa_dim:
        IPA hidden channel dimension
    resnet_dim:
        Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
    num_heads_ipa:
        Number of IPA heads
    num_qk_points:
        Number of query/key points to generate during IPA
    num_v_points:
        Number of value points to generate during IPA
    dropout_rate:
        Dropout rate used throughout the layer
    num_blocks:
        Number of structure module blocks
    num_transition_layers:
        Number of layers in the single representation transition (Alg. 23 lines 8-9)
    num_resnet_blocks:
        Number of blocks in the angle resnet
    num_angles:
        Number of angles to generate in the angle resnet
    trans_scale_factor:
        Scale factor of the single representation transition hidden dimension
    epsilon:
        Small number used in angle resnet normalization
    inf:
        Large number used for attention masking
"""
sequence_dim: int = 384
pairwise_dim: int = 128
ipa_dim: int = 16
resnet_dim: int = 128
num_heads_ipa: int = 12
num_qk_points: int = 4
num_v_points: int = 8
dropout_rate: float = 0.1
num_blocks: int = 8
num_transition_layers: int = 1
num_resnet_blocks: int = 2
num_angles: int = 7
trans_scale_factor: int = 10
epsilon: float = 1e-8
inf: float = 1e5
def to_dict(self):
"""
Serializes this dataclass instance to a Python dictionary.
"""
return asdict(self)
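# Round-trip sketch (standalone, assumes an installed transformers with ESM support):
# dicts passed for the nested configs are promoted to dataclasses in __post_init__, and
# to_dict() serialises them back into plain nested dictionaries.
from transformers import EsmConfig

cfg = EsmConfig(
    vocab_size=33,
    is_folding_model=True,
    esmfold_config={"trunk": {"num_blocks": 4, "structure_module": {"num_blocks": 2}}},
)
d = cfg.to_dict()
print(d["esmfold_config"]["trunk"]["num_blocks"])                      # 4
print(d["esmfold_config"]["trunk"]["structure_module"]["num_blocks"])  # 2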
def get_default_vocab_list():
"""
Returns the default vocabulary list (the ESM-2 alphabet).
Returns:
    tuple: A tuple containing the default vocabulary tokens.
"""
return (
"<cls>",
"<pad>",
"<eos>",
"<unk>",
"L",
"A",
"G",
"V",
"S",
"E",
"R",
"T",
"I",
"D",
"P",
"K",
"Q",
"N",
"F",
"Y",
"M",
"H",
"W",
"C",
"X",
"B",
"U",
"Z",
"O",
".",
"-",
"<null_1>",
"<mask>",
)
.\models\esm\convert_esm.py
"""Convert ESM checkpoint."""
import argparse
import pathlib
from pathlib import Path
from tempfile import TemporaryDirectory
import esm as esm_module
import torch
from esm.esmfold.v1.misc import batch_encode_sequences as esmfold_encode_sequences
from esm.esmfold.v1.pretrained import esmfold_v1
from transformers.models.esm.configuration_esm import EsmConfig, EsmFoldConfig
from transformers.models.esm.modeling_esm import (
EsmForMaskedLM,
EsmForSequenceClassification,
EsmIntermediate,
EsmLayer,
EsmOutput,
EsmSelfAttention,
EsmSelfOutput,
)
from transformers.models.esm.modeling_esmfold import EsmForProteinFolding
from transformers.models.esm.tokenization_esm import EsmTokenizer
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
SAMPLE_DATA = [
(
"protein1",
"MNGTEGPNFYVPFSNATGVVRSPFEYPQYYLAEPWQFSMLAAYMFLLIVLGFPINFLTLYVTVQHKKLRTPLNYILLNLAVADLFMVLGGFTSTLYTSLHGYFVFGPTGCNLEGFFATLGGEIALWSLVVLAIERYVVVCKPMSNFRFGENHAIMGVAFTWVMALACAAPPLAGWSRYIPEGLQCSCGIDYYTLKPEVNNESFVIYMFVVHFTIPMIIIFFCYGQLVFTVKEAAAQQQESATTQKAEKEVTRMVIIMVIAFLICWVPYASVAFYIFTHQGSNFGPIFMTIPAFFAKSAAIYNPVIYIMMNKQFRNCMLTTICCGKNPLGDDEASATVSKTETSQVAPA",
),
("protein2", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA"),
("protein3", "MKTVRQERLKSI<mask>RILERSKEPVSGAQLAEELS<mask>SRQVIVQDIAYLRSLGYN<mask>VATPRGYVLAGG"),
("protein4", "MKTVRQERLKSI<mask>RILERSKEPVSGAQLAEELS<mask>SRQVIVQDIAYLRSLGYN<mask>VATPRGYVLA"),
]
MODEL_MAPPING = {
"esm1b_t33_650M_UR50S": esm_module.pretrained.esm1b_t33_650M_UR50S,
"esm1v_t33_650M_UR90S_1": esm_module.pretrained.esm1v_t33_650M_UR90S_1,
"esm1v_t33_650M_UR90S_2": esm_module.pretrained.esm1v_t33_650M_UR90S_2,
"esm1v_t33_650M_UR90S_3": esm_module.pretrained.esm1v_t33_650M_UR90S_3,
"esm1v_t33_650M_UR90S_4": esm_module.pretrained.esm1v_t33_650M_UR90S_4,
"esm1v_t33_650M_UR90S_5": esm_module.pretrained.esm1v_t33_650M_UR90S_5,
"esm2_t48_15B_UR50D": esm_module.pretrained.esm2_t48_15B_UR50D,
"esm2_t36_3B_UR50D": esm_module.pretrained.esm2_t36_3B_UR50D,
"esm2_t33_650M_UR50D": esm_module.pretrained.esm2_t33_650M_UR50D,
"esm2_t30_150M_UR50D": esm_module.pretrained.esm2_t30_150M_UR50D,
"esm2_t12_35M_UR50D": esm_module.pretrained.esm2_t12_35M_UR50D,
"esm2_t6_8M_UR50D": esm_module.pretrained.esm2_t6_8M_UR50D,
"esmfold_v1": esmfold_v1,
}
restypes = list("ARNDCQEGHILKMFPSTWYV")
restypes_with_x = restypes + ["X"]
restypes_with_extras = restypes_with_x + ["<pad>", "<mask>", "<cls>", "<sep>", "<eos>"]
def get_esmfold_tokenizer():
with TemporaryDirectory() as tempdir:
vocab = "\n".join(restypes_with_extras)
vocab_file = Path(tempdir) / "vocab.txt"
vocab_file.write_text(vocab)
hf_tokenizer = EsmTokenizer(vocab_file=str(vocab_file))
hf_tokenizer.pad_token_id = 0
return hf_tokenizer
def transfer_and_check_weights(original_module, our_module):
status = our_module.load_state_dict(original_module.state_dict())
if status.missing_keys:
raise ValueError(f"Missing keys: {status.missing_keys}")
if status.unexpected_keys:
raise ValueError(f"Unexpected keys: {status.unexpected_keys}")
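# Toy standalone illustration: load_state_dict returns a named tuple of
# (missing_keys, unexpected_keys); the helper above treats either being non-empty as an
# error so that silent partial weight transfers cannot slip through.
import torch.nn as nn

src, dst = nn.Linear(4, 4), nn.Linear(4, 4)
status = dst.load_state_dict(src.state_dict())
print(status.missing_keys, status.unexpected_keys)  # [] []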
def convert_esm_checkpoint_to_pytorch(
model: str, pytorch_dump_folder_path: str, classification_head: bool, push_to_repo: str, auth_token: str
):
"""
Copy/paste/tweak esm's weights into our BERT-style structure.
"""
if model.startswith("esmfold"):
esm = MODEL_MAPPING[model]()
else:
esm, alphabet = MODEL_MAPPING[model]()
esm.eval()
if model.startswith("esmfold"):
embed_dim = esm.esm.embed_dim
num_layers = esm.esm.num_layers
num_attention_heads = esm.esm.attention_heads
intermediate_size = 4 * embed_dim
token_dropout = esm.esm.token_dropout
emb_layer_norm_before = False
position_embedding_type = "rotary"
is_folding_model = True
esmfold_config = EsmFoldConfig()
for key, val in esm.cfg.items():
if hasattr(esmfold_config, key) and key != "trunk":
setattr(esmfold_config, key, val)
for key, val in esm.cfg.trunk.items():
if hasattr(esmfold_config.trunk, key) and key != "structure_module":
setattr(esmfold_config.trunk, key, val)
for key, val in esm.cfg.trunk.structure_module.items():
if hasattr(esmfold_config.trunk.structure_module, key):
setattr(esmfold_config.trunk.structure_module, key, val)
elif hasattr(esm, "args"):
embed_dim = esm.args.embed_dim
num_layers = esm.args.layers
num_attention_heads = esm.args.attention_heads
intermediate_size = esm.args.ffn_embed_dim
token_dropout = esm.args.token_dropout
emb_layer_norm_before = True if esm.emb_layer_norm_before else False
position_embedding_type = "absolute"
is_folding_model = False
esmfold_config = None
else:
embed_dim = esm.embed_dim
num_layers = esm.num_layers
num_attention_heads = esm.attention_heads
intermediate_size = 4 * embed_dim
token_dropout = esm.token_dropout
emb_layer_norm_before = False
position_embedding_type = "rotary"
is_folding_model = False
esmfold_config = None
if is_folding_model:
alphabet = esm.esm.alphabet
vocab_list = tuple(alphabet.all_toks)
mask_token_id = alphabet.mask_idx
pad_token_id = alphabet.padding_idx
if is_folding_model:
original_esm_model = esm.esm
else:
original_esm_model = esm
config = EsmConfig(
vocab_size=original_esm_model.embed_tokens.num_embeddings,
mask_token_id=mask_token_id,
hidden_size=embed_dim,
num_hidden_layers=num_layers,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
max_position_embeddings=1026,
layer_norm_eps=1e-5,
attention_probs_dropout_prob=0.0,
hidden_dropout_prob=0.0,
pad_token_id=pad_token_id,
emb_layer_norm_before=emb_layer_norm_before,
token_dropout=token_dropout,
position_embedding_type=position_embedding_type,
is_folding_model=is_folding_model,
esmfold_config=esmfold_config,
vocab_list=vocab_list,
)
if classification_head:
config.num_labels = esm.classification_heads["mnli"].out_proj.weight.shape[0]
print("Our ESM config:", config)
if model.startswith("esmfold"):
model_class = EsmForProteinFolding
elif classification_head:
model_class = EsmForSequenceClassification
else:
model_class = EsmForMaskedLM
model = model_class(config)
model.eval()
model.esm.embeddings.word_embeddings.weight = original_esm_model.embed_tokens.weight
if position_embedding_type == "absolute":
model.esm.embeddings.position_embeddings.weight = original_esm_model.embed_positions.weight
if config.emb_layer_norm_before:
model.esm.embeddings.layer_norm.weight = original_esm_model.emb_layer_norm_before.weight
model.esm.embeddings.layer_norm.bias = original_esm_model.emb_layer_norm_before.bias
model.esm.encoder.emb_layer_norm_after.weight = original_esm_model.emb_layer_norm_after.weight
model.esm.encoder.emb_layer_norm_after.bias = original_esm_model.emb_layer_norm_after.bias
if is_folding_model:
model.esm_s_combine.data = esm.esm_s_combine.data
model.af2_to_esm.data = esm.af2_to_esm.data
transfer_and_check_weights(esm.embedding, model.embedding)
transfer_and_check_weights(esm.esm_s_mlp, model.esm_s_mlp)
transfer_and_check_weights(esm.trunk, model.trunk)
transfer_and_check_weights(esm.distogram_head, model.distogram_head)
transfer_and_check_weights(esm.ptm_head, model.ptm_head)
transfer_and_check_weights(esm.lm_head, model.lm_head)
transfer_and_check_weights(esm.lddt_head, model.lddt_head)
elif classification_head:
model.classifier.dense.weight = esm.esm.classification_heads["mnli"].dense.weight
model.classifier.dense.bias = esm.classification_heads["mnli"].dense.bias
model.classifier.out_proj.weight = esm.classification_heads["mnli"].out_proj.weight
model.classifier.out_proj.bias = esm.classification_heads["mnli"].out_proj.bias
else:
model.lm_head.dense.weight = esm.lm_head.dense.weight
model.lm_head.dense.bias = esm.lm_head.dense.bias
model.lm_head.layer_norm.weight = esm.lm_head.layer_norm.weight
model.lm_head.layer_norm.bias = esm.lm_head.layer_norm.bias
model.lm_head.decoder.weight = esm.lm_head.weight
model.lm_head.bias = esm.lm_head.bias
transfer_and_check_weights(esm.contact_head, model.esm.contact_head)
if is_folding_model:
sample_data = SAMPLE_DATA[:2]
else:
sample_data = SAMPLE_DATA
if is_folding_model:
hf_tokenizer = get_esmfold_tokenizer()
hf_tokens = hf_tokenizer(
[row[1] for row in sample_data], return_tensors="pt", padding=True, add_special_tokens=False
)
esmfold_aas, esmfold_mask, _, _, _ = esmfold_encode_sequences([row[1] for row in sample_data])
success = torch.all(hf_tokens["input_ids"] == esmfold_aas) and torch.all(
hf_tokens["attention_mask"] == esmfold_mask
)
else:
batch_converter = alphabet.get_batch_converter()
batch_labels, batch_strs, batch_tokens = batch_converter(sample_data)
with TemporaryDirectory() as tempdir:
vocab = "\n".join(alphabet.all_toks)
vocab_file = Path(tempdir) / "vocab.txt"
vocab_file.write_text(vocab)
hf_tokenizer = EsmTokenizer(vocab_file=str(vocab_file))
hf_tokens = hf_tokenizer([row[1] for row in sample_data], return_tensors="pt", padding=True)
success = torch.all(hf_tokens["input_ids"] == batch_tokens)
print("Do both models tokenizers output the same tokens?", "🔥" if success else "💩")
if not success:
raise Exception("Tokenization does not match!")
with torch.no_grad():
if is_folding_model:
their_output = esm.cuda().infer([row[1] for row in sample_data])
our_output = model.cuda()(
input_ids=hf_tokens["input_ids"].cuda(), attention_mask=hf_tokens["attention_mask"].cuda()
)
else:
our_output = model(**hf_tokens, output_hidden_states=True)
our_output = our_output["logits"]
if classification_head:
their_output = esm.model.classification_heads["mnli"](esm.extract_features(batch_tokens))
else:
their_output = esm(hf_tokens["input_ids"], repr_layers=list(range(999)))
their_output = their_output["logits"]
if is_folding_model:
max_absolute_diff = torch.max(torch.abs(our_output["positions"] - their_output["positions"])).item()
success = torch.allclose(our_output["positions"], their_output["positions"], atol=1e-5)
else:
max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
success = torch.allclose(our_output, their_output, atol=1e-5)
print(f"max_absolute_diff = {max_absolute_diff}")
print("Do both models output the same tensors?", "🔥" if success else "💩")
if not success:
raise Exception("Something went wRoNg")
if not is_folding_model:
our_output = model.predict_contacts(hf_tokens["input_ids"], hf_tokens["attention_mask"])
their_output = esm.predict_contacts(hf_tokens["input_ids"])
max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
success = torch.allclose(our_output, their_output, atol=1e-5)
print("Contact prediction testing:")
print(f"max_absolute_diff = {max_absolute_diff}")
print("Do both models output the same tensors?", "🔥" if success else "💩")
if not success:
raise Exception("Something went wRoNg")
pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
print(f"Saving model to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
del esm
print(f"Saving tokenizer to {pytorch_dump_folder_path}")
hf_tokenizer.save_pretrained(pytorch_dump_folder_path)
if push_to_repo:
        model.push_to_hub(repo_id=push_to_repo, token=auth_token)
        hf_tokenizer.push_to_hub(repo_id=push_to_repo, token=auth_token)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--pytorch_dump_folder_path", type=str, required=True, help="Path to the output PyTorch model."
)
parser.add_argument(
"--classification_head", action="store_true", help="Whether to convert a final classification head."
)
parser.add_argument("--model", default=None, type=str, required=True, help="Name of model to convert.")
parser.add_argument("--push_to_repo", type=str, help="Repo to upload to (including username!).")
parser.add_argument("--auth_token", type=str, help="HuggingFace auth token.")
args = parser.parse_args()
convert_esm_checkpoint_to_pytorch(
args.model, args.pytorch_dump_folder_path, args.classification_head, args.push_to_repo, args.auth_token
)
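# Invocation sketch (illustrative paths; the model name must be a key of MODEL_MAPPING):
#
#     python convert_esm.py --model esm2_t33_650M_UR50D \
#         --pytorch_dump_folder_path ./esm2_t33_650M_UR50D_hf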
.\models\esm\modeling_esm.py
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
MaskedLMOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
from .configuration_esm import EsmConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "facebook/esm2_t6_8M_UR50D"
_CONFIG_FOR_DOC = "EsmConfig"
ESM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/esm2_t6_8M_UR50D",
"facebook/esm2_t12_35M_UR50D",
]
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(x, cos, sin):
cos = cos[:, :, : x.shape[-2], :]
sin = sin[:, :, : x.shape[-2], :]
return (x * cos) + (rotate_half(x) * sin)
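# Quick standalone check: rotate_half splits the last dimension in half and recombines it
# as (-second_half, first_half), which is what the rotation above relies on.
import torch

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
x1, x2 = x.chunk(2, dim=-1)
print(torch.cat((-x2, x1), dim=-1))  # tensor([[-3., -4.,  1.,  2.]])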
def gelu(x):
"""
This is the gelu implementation from the original ESM repo. Using F.gelu yields subtly wrong results.
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def symmetrize(x):
"Make layer symmetric in final two dimensions, used for contact prediction."
return x + x.transpose(-1, -2)
def average_product_correct(x):
"Perform average product correct, used for contact prediction."
a1 = x.sum(-1, keepdims=True)
a2 = x.sum(-2, keepdims=True)
a12 = x.sum((-1, -2), keepdims=True)
avg = a1 * a2
avg.div_(a12)
normalized = x - avg
return normalized
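# Standalone sanity check: symmetrize() makes the last two dimensions symmetric, which is
# what the contact head expects before applying average product correction.
import torch

m = torch.rand(1, 2, 5, 5)
s = m + m.transpose(-1, -2)
print(torch.allclose(s, s.transpose(-1, -2)))  # True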
class RotaryEmbedding(torch.nn.Module):
"""
Rotary position embeddings based on those in
[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
matrices which depend on their relative positions.
"""
def __init__(self, dim: int):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
inv_freq = inv_freq
self.register_buffer("inv_freq", inv_freq)
self._seq_len_cached = None
self._cos_cached = None
self._sin_cached = None
def _update_cos_sin_tables(self, x, seq_dimension=2):
seq_len = x.shape[seq_dimension]
if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
self._seq_len_cached = seq_len
t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self._cos_cached = emb.cos()[None, None, :, :]
self._sin_cached = emb.sin()[None, None, :, :]
return self._cos_cached, self._sin_cached
def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
return (
apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
)
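# Shape-level sketch (standalone, assumes an installed transformers with ESM support):
# rotary embeddings rotate q/k without changing their shapes.
import torch
from transformers.models.esm.modeling_esm import RotaryEmbedding

rot = RotaryEmbedding(dim=16)
q = torch.randn(1, 4, 10, 16)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 4, 10, 16)
q_rot, k_rot = rot(q, k)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 4, 10, 16]) torch.Size([1, 4, 10, 16])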
class EsmContactPredictionHead(nn.Module):
"""Performs symmetrization, apc, and computes a logistic regression on the output features"""
def __init__(
self,
in_features: int,
bias=True,
eos_idx: int = 2,
):
super().__init__()
self.in_features = in_features
self.eos_idx = eos_idx
self.regression = nn.Linear(in_features, 1, bias)
self.activation = nn.Sigmoid()
def forward(self, tokens, attentions):
eos_mask = tokens.ne(self.eos_idx).to(attentions)
eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
attentions = attentions * eos_mask[:, None, None, :, :]
attentions = attentions[..., :-1, :-1]
attentions = attentions[..., 1:, 1:]
batch_size, layers, heads, seqlen, _ = attentions.size()
attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
attentions = attentions.to(
self.regression.weight.device
)
attentions = average_product_correct(symmetrize(attentions))
attentions = attentions.permute(0, 2, 3, 1)
return self.activation(self.regression(attentions).squeeze(3))
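# Standalone shape walk-through (pure tensor ops): the stacked attentions
# (batch, layers, heads, L, L) are flattened into one feature vector per (i, j) pair,
# which the single-output regression layer then turns into a contact logit.
import torch

batch, layers, heads, L = 2, 6, 20, 7
attn = torch.rand(batch, layers, heads, L, L)
features = attn.view(batch, layers * heads, L, L).permute(0, 2, 3, 1)
print(features.shape)  # torch.Size([2, 7, 7, 120])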
class EsmEmbeddings(nn.Module):
"""
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
"""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
if config.emb_layer_norm_before:
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
else:
self.layer_norm = None
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.padding_idx = config.pad_token_id
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
)
self.token_dropout = config.token_dropout
self.mask_token_id = config.mask_token_id
def forward(
self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
if position_ids is None:
if input_ids is not None:
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
else:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
embeddings = inputs_embeds
if self.token_dropout:
embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
mask_ratio_train = 0.15 * 0.8
src_lengths = attention_mask.sum(-1)
mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
embeddings.dtype
)
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings = embeddings + position_embeddings
if self.layer_norm is not None:
embeddings = self.layer_norm(embeddings)
if attention_mask is not None:
embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
return embeddings
    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded, so just generate sequential position ids.
        """
        input_shape = inputs_embeds.size()[:-1]
sequence_length = input_shape[1]
position_ids = torch.arange(
self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
)
return position_ids.unsqueeze(0).expand(input_shape)
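# Standalone sketch of the padding-aware helper used above,
# create_position_ids_from_input_ids (defined elsewhere in the full file): positions count
# only non-padding tokens and are offset by padding_idx, so padding keeps padding_idx.
import torch

input_ids = torch.tensor([[5, 6, 7, 1, 1]])  # assume 1 is the padding id
padding_idx = 1
mask = input_ids.ne(padding_idx).int()
position_ids = (torch.cumsum(mask, dim=1) * mask).long() + padding_idx
print(position_ids)  # tensor([[2, 3, 4, 1, 1]])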
class EsmSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
elif self.position_embedding_type == "rotary":
self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # ... (the forward body is omitted in this excerpt)


class EsmAttention(nn.Module):
    def __init__(self, config):
super().__init__()
self.self = EsmSelfAttention(config)
self.output = EsmSelfOutput(config)
self.pruned_heads = set()
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
hidden_states_ln = self.LayerNorm(hidden_states)
self_outputs = self.self(
hidden_states_ln,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
class EsmIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = gelu(hidden_states)
return hidden_states
class EsmOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states + input_tensor
return hidden_states
class EsmLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = EsmAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = EsmAttention(config)
self.intermediate = EsmIntermediate(config)
self.output = EsmOutput(config)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
        # The cached decoder self-attention key/values, if any, are at positions 1 and 2 of past_key_value
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0]
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:]
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise AttributeError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated"
" with cross-attention layers by setting `config.add_cross_attention=True`"
)
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1]
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = self.feed_forward_chunk(attention_output)
outputs = (layer_output,) + outputs
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs
    def feed_forward_chunk(self, attention_output):
        attention_output_ln = self.LayerNorm(attention_output)
        intermediate_output = self.intermediate(attention_output_ln)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
class EsmEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=False,
output_hidden_states=False,
        return_dict=True,
    ):
        # ... (the forward body is omitted in this excerpt)


class EsmPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class EsmPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = EsmConfig
base_model_prefix = "esm"
supports_gradient_checkpointing = True
_no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"]
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
ESM_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`EsmConfig`]): Model configuration class with all the parameters of the model. Initializing with a
            config file does not load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


# The bare ESM model class, inheriting from EsmPreTrainedModel
@add_start_docstrings(
"The bare ESM Model transformer outputting raw hidden-states without any specific head on top.",
ESM_START_DOCSTRING,
)
class EsmModel(EsmPreTrainedModel):
"""
ESM 模型类,可以作为编码器(只包含自注意力)或解码器使用,后者则在自注意力层之间添加了一层交叉注意力,遵循了 Ashish Vaswani 等人在《Attention is all you need》中描述的架构。
"""
"""
To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
"""
# 根据给定的配置初始化模型,可选择添加一个池化层
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
# 初始化嵌入层
self.embeddings = EsmEmbeddings(config)
# 初始化编码器
self.encoder = EsmEncoder(config)
# 如果设置了添加池化层,则初始化池化层,否则为None
self.pooler = EsmPooler(config) if add_pooling_layer else None
# 初始化联系预测头部
self.contact_head = EsmContactPredictionHead(
in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
)
# 执行初始化后的权重和最终处理
self.post_init()
# 返回输入嵌入层
def get_input_embeddings(self):
return self.embeddings.word_embeddings
# 设置输入嵌入层
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
# 剪枝模型中的注意力头部
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
# 前向传播函数,接受多种输入和参数
@add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        # ... (the forward body is omitted in this excerpt)

    # Contact prediction: run the model with output_attentions=True and feed the attentions to the contact head
def predict_contacts(self, tokens, attention_mask):
        # Run the model and collect the attention maps from every layer
        attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions
        # Stack them to match the layout of the original ESM model
        attns = torch.stack(attns, dim=1)
        # In the original model, attention for padded tokens is fully zeroed out.
        # This usually has no effect, since other tokens do not attend to them, but it
        # matters for contact prediction, which takes the attentions as input, so we
        # mimic that behaviour here by masking with attention_mask.
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
        # Run the contact prediction head on the masked attentions
return self.contact_head(tokens, attns)
# 定义一个 EsmForMaskedLM 类,继承自 EsmPreTrainedModel 类,并添加了语言建模头部
@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING)
class EsmForMaskedLM(EsmPreTrainedModel):
# 定义了与 lm_head.decoder.weight 相关的权重绑定键
_tied_weights_keys = ["lm_head.decoder.weight"]
# 初始化方法,接收一个配置对象 config
def __init__(self, config):
# 调用父类的初始化方法
super().__init__(config)
# 如果配置中 is_decoder 为 True,则发出警告信息
if config.is_decoder:
logger.warning(
"If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for "
"bi-directional self-attention."
)
# 创建 EsmModel 对象,不添加池化层
self.esm = EsmModel(config, add_pooling_layer=False)
# 创建 EsmLMHead 对象
self.lm_head = EsmLMHead(config)
# 初始化模型权重
self.init_weights()
# 返回 lm_head.decoder 对象,用于输出嵌入
def get_output_embeddings(self):
return self.lm_head.decoder
# 设置 lm_head.decoder 的新嵌入
def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
# 前向传播方法,接收多个输入参数并返回输出
@add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
mask="<mask>",
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MaskedLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
kwargs (`Dict[str, any]`, optional, defaults to *{}*):
Used to hide legacy arguments that have been deprecated.
"""
# 根据参数 `return_dict` 确定是否返回字典类型的输出
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# 调用模型 `esm` 进行前向传播,传入各种输入参数
outputs = self.esm(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# 获取模型输出的序列输出
sequence_output = outputs[0]
# 对序列输出进行预测得到预测分数
prediction_scores = self.lm_head(sequence_output)
masked_lm_loss = None
# 如果存在标签,则计算掩码语言建模损失
if labels is not None:
loss_fct = CrossEntropyLoss()
# 将标签移动到与预测分数相同的设备上
labels = labels.to(prediction_scores.device)
# 计算掩码语言建模的损失
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
# 如果不返回字典类型的输出,则组织最终输出格式
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
# 返回掩码语言建模任务的输出,包括损失、预测分数、隐藏状态和注意力权重
return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def predict_contacts(self, tokens, attention_mask):
# Delegate contact prediction to the underlying `esm` model
return self.esm.predict_contacts(tokens, attention_mask=attention_mask)
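A hedged usage sketch for EsmForMaskedLM: the checkpoint name is the ESM-2 checkpoint referenced elsewhere in this model family, the protein sequence and masked position are arbitrary, and running it requires downloading the weights.

import torch
from transformers import AutoTokenizer, EsmForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
model.eval()

inputs = tokenizer("MKTVRQERLKSIVRILERSKEPVSGAQ", return_tensors="pt")  # arbitrary sequence
masked_ids = inputs.input_ids.clone()
masked_ids[0, 5] = tokenizer.mask_token_id                              # hide one residue

with torch.no_grad():
    logits = model(input_ids=masked_ids, attention_mask=inputs.attention_mask).logits

predicted_id = logits[0, 5].argmax(-1).item()
print(tokenizer.convert_ids_to_tokens(predicted_id))                    # most likely residue at the masked slot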
class EsmLMHead(nn.Module):
"""ESM Head for masked language modeling."""
def __init__(self, config):
super().__init__()
# 定义一个全连接层,将输入特征空间映射到隐藏大小的空间
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
# Layer normalization,对输入进行归一化处理
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# 用于输出,将隐藏大小映射回词汇表大小的线性层,无偏置
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# 定义一个偏置参数,长度为词汇表大小,用于模型输出的偏移
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
def forward(self, features, **kwargs):
# 前向传播函数
x = self.dense(features) # 全连接层映射
x = gelu(x) # 使用 GELU 激活函数
x = self.layer_norm(x) # Layer normalization 归一化处理
# 用线性层映射回词汇表大小,并加上偏置
x = self.decoder(x) + self.bias
return x
@add_start_docstrings(
"""
ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
output) e.g. for GLUE tasks.
""",
ESM_START_DOCSTRING,
)
class EsmForSequenceClassification(EsmPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config
# ESM 模型主体部分,不添加池化层
self.esm = EsmModel(config, add_pooling_layer=False)
# 分类头部,用于序列分类任务
self.classifier = EsmClassificationHead(config)
self.init_weights() # 初始化模型权重
@add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
# 如果 return_dict 为 None,则使用 self.config.use_return_dict 决定是否返回字典形式的输出
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# 使用 ESM 模型进行前向传播,获取模型的输出
outputs = self.esm(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# 从模型输出中获取序列输出
sequence_output = outputs[0]
# 将序列输出输入分类器,得到预测 logits
logits = self.classifier(sequence_output)
# 初始化损失为 None
loss = None
# 如果存在 labels,则计算损失
if labels is not None:
# 将 labels 移动到 logits 所在的设备上
labels = labels.to(logits.device)
# 根据问题类型确定问题类型("regression", "single_label_classification", "multi_label_classification")
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
# 根据问题类型计算相应的损失
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
# 如果不要求返回字典形式的输出,则按元组形式返回结果
if not return_dict:
output = (logits,) + outputs[2:] # 将 logits 和其他输出组成元组
return ((loss,) + output) if loss is not None else output # 如果有损失,则将损失与输出一起返回,否则只返回输出
# 返回 SequenceClassifierOutput 对象,包括损失、logits、隐藏状态和注意力权重
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
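A toy illustration, independent of the model itself, of the three loss branches selected by config.problem_type above:

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

logits = torch.randn(2, 3)                           # batch of 2, num_labels = 3
int_labels = torch.tensor([0, 2])                    # integer labels -> single_label_classification
multi_hot = torch.tensor([[1., 0., 1.],              # float multi-hot -> multi_label_classification
                          [0., 1., 0.]])
reg_logits = torch.randn(2, 1)                       # num_labels == 1 -> regression
reg_targets = torch.tensor([[0.5], [1.2]])

print(CrossEntropyLoss()(logits.view(-1, 3), int_labels.view(-1)))
print(BCEWithLogitsLoss()(logits, multi_hot))
print(MSELoss()(reg_logits.squeeze(), reg_targets.squeeze()))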
@add_start_docstrings(
"""
ESM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
ESM_START_DOCSTRING,
)
class EsmForTokenClassification(EsmPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
# 初始化 ESM 模型,不添加池化层
self.esm = EsmModel(config, add_pooling_layer=False)
# Dropout 层,用于防止过拟合
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# 分类器,将隐藏状态映射到标签数的线性层
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# 初始化模型权重
self.init_weights()
@add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
# 确定是否返回字典类型的输出
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# 获取 ESM 模型的输出
outputs = self.esm(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# 获取序列输出
sequence_output = outputs[0]
# 应用 Dropout 层
sequence_output = self.dropout(sequence_output)
# 使用分类器将序列输出映射到标签空间
logits = self.classifier(sequence_output)
# 计算损失
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# 将标签移到与 logits 相同的设备上
labels = labels.to(logits.device)
# 计算交叉熵损失
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
# 如果不返回字典,则以元组形式返回输出
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
# 返回 TokenClassifierOutput 对象
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class EsmClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
# 初始化函数,用于创建一个新的神经网络模型实例
def __init__(self, config):
# 调用父类的初始化方法
super().__init__()
# 创建一个全连接层,输入和输出维度都是 config.hidden_size
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
# 创建一个 Dropout 层,使用 config.hidden_dropout_prob 作为丢弃概率
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# 创建一个全连接层,输入维度是 config.hidden_size,输出维度是 config.num_labels
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
# 前向传播函数,定义了数据从输入到输出的流程
def forward(self, features, **kwargs):
# 取 features 的第一个位置的数据,通常表示起始 token(如 [CLS])
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
# 对取出的数据应用 Dropout,随机部分神经元失活,防止过拟合
x = self.dropout(x)
# 将数据通过全连接层 self.dense 进行线性变换
x = self.dense(x)
# 对变换后的数据应用双曲正切函数进行非线性变换
x = torch.tanh(x)
# 再次应用 Dropout 层,进一步随机失活神经元
x = self.dropout(x)
# 将数据通过全连接层 self.out_proj 进行线性变换,得到最终的输出结果
x = self.out_proj(x)
# 返回神经网络模型的输出结果
return x
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.
Args:
input_ids: torch.Tensor, input tensor containing token IDs
padding_idx: int, the index of padding tokens in input_ids
past_key_values_length: int, optional, length of past key values for incremental processing
Returns:
torch.Tensor, tensor of position IDs corresponding to input_ids
"""
# 创建一个掩码,标记非填充符号的位置为1,填充符号为0
mask = input_ids.ne(padding_idx).int()
# 计算每个非填充符号的位置编号,位置编号从 padding_idx+1 开始,乘以掩码以忽略填充符号
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
# 将位置编号转换为长整型并加上 padding_idx,得到最终的位置 ID
return incremental_indices.long() + padding_idx
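A worked example, assuming create_position_ids_from_input_ids above is in scope: with padding_idx=1, non-padding tokens receive positions starting at padding_idx + 1, and padding positions stay at padding_idx.

import torch

input_ids = torch.tensor([[5, 6, 7, 1, 1]])   # 1 is the padding index here
print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
# tensor([[2, 3, 4, 1, 1]])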
.\models\esm\modeling_esmfold.py
import math
import sys
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from torch.nn import LayerNorm
from ...integrations.deepspeed import is_deepspeed_available
from ...modeling_outputs import ModelOutput
from ...utils import (
ContextManagers,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_scipy_available,
logging,
replace_return_docstrings,
)
from .configuration_esm import EsmConfig
from .modeling_esm import ESM_START_DOCSTRING, EsmModel, EsmPreTrainedModel
from .openfold_utils import (
OFProtein,
Rigid,
Rotation,
atom14_to_atom37,
chunk_layer,
compute_predicted_aligned_error,
compute_tm,
frames_and_literature_positions_to_atom14_pos,
make_atom14_masks,
residue_constants,
to_pdb,
torsion_angles_to_frames,
)
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "facebook/esmfold_v1"
_CONFIG_FOR_DOC = "EsmConfig"
@dataclass
class EsmForProteinFoldingOutput(ModelOutput):
"""
[`EsmForProteinFoldingOutput`] 的输出类型。
"""
Args:
frames (`torch.FloatTensor`):
输出帧。
模型预测的帧输出。
sidechain_frames (`torch.FloatTensor`):
侧链帧。
模型预测的侧链帧输出。
unnormalized_angles (`torch.FloatTensor`):
预测的未归一化主链和侧链扭转角度。
模型预测的未归一化主链和侧链扭转角度。
angles (`torch.FloatTensor`):
预测的主链和侧链扭转角度。
模型预测的主链和侧链扭转角度。
positions (`torch.FloatTensor`):
预测的主链和侧链原子的位置。
模型预测的主链和侧链原子位置。
states (`torch.FloatTensor`):
蛋白质折叠主干的隐藏状态。
来自蛋白质折叠主干的隐藏状态。
s_s (`torch.FloatTensor`):
每个残基嵌入。
通过连接ESM-2 LM stem每层的隐藏状态得到的每个残基嵌入。
s_z (`torch.FloatTensor`):
成对残基嵌入。
成对残基嵌入。
distogram_logits (`torch.FloatTensor`):
距离直方图的输入对数。
用于计算残基距离的输入对数。
lm_logits (`torch.FloatTensor`):
ESM-2蛋白质语言模型主干的输出对数。
ESM-2蛋白质语言模型主干的输出对数。
aatype (`torch.FloatTensor`):
输入的氨基酸(AlphaFold2索引)。
输入的氨基酸(AlphaFold2索引)。
atom14_atom_exists (`torch.FloatTensor`):
每个原子在atom14表示中是否存在。
每个原子在atom14表示中是否存在。
residx_atom14_to_atom37 (`torch.FloatTensor`):
atom14到atom37表示之间的映射。
atom14到atom37表示之间的映射。
residx_atom37_to_atom14 (`torch.FloatTensor`):
atom37到atom14表示之间的映射。
atom37到atom14表示之间的映射。
atom37_atom_exists (`torch.FloatTensor`):
每个原子在atom37表示中是否存在。
每个原子在atom37表示中是否存在。
residue_index (`torch.FloatTensor`):
蛋白链中每个残基的索引。
蛋白链中每个残基的索引。
lddt_head (`torch.FloatTensor`):
lddt头部的原始输出。
用于计算plddt的lddt头部的原始输出。
plddt (`torch.FloatTensor`):
每个残基的置信度分数。
模型预测结构可能不确定或蛋白结构无序的区域可能表明低置信度的区域。
ptm_logits (`torch.FloatTensor`):
用于计算ptm的原始logits。
用于计算ptm的原始logits。
ptm (`torch.FloatTensor`):
TM-score输出,代表模型对整体结构的高级置信度。
TM-score输出,代表模型对整体结构的高级置信度。
aligned_confidence_probs (`torch.FloatTensor`):
对齐结构的每个残基置信度分数。
对齐结构的每个残基置信度分数。
predicted_aligned_error (`torch.FloatTensor`):
模型预测与真实值之间的预测误差。
模型预测与真实值之间的预测误差。
max_predicted_aligned_error (`torch.FloatTensor`):
每个样本的最大预测误差。
每个样本的最大预测误差。
"""
frames: torch.FloatTensor = None
sidechain_frames: torch.FloatTensor = None
unnormalized_angles: torch.FloatTensor = None
angles: torch.FloatTensor = None
# 定义一系列变量,每个变量的类型均为 torch.FloatTensor,初始赋值为 None
positions: torch.FloatTensor = None # 用于存储位置信息的张量
states: torch.FloatTensor = None # 用于存储状态信息的张量
s_s: torch.FloatTensor = None # 用于存储 s_s 信息的张量
s_z: torch.FloatTensor = None # 用于存储 s_z 信息的张量
distogram_logits: torch.FloatTensor = None # 用于存储距离直方图 logits 的张量
lm_logits: torch.FloatTensor = None # 用于存储语言模型 logits 的张量
aatype: torch.FloatTensor = None # 用于存储氨基酸类型的张量
atom14_atom_exists: torch.FloatTensor = None # 用于存储 atom14 是否存在的张量
residx_atom14_to_atom37: torch.FloatTensor = None # 用于存储 residue index 到 atom37 的映射的张量
residx_atom37_to_atom14: torch.FloatTensor = None # 用于存储 residue index 到 atom14 的映射的张量
atom37_atom_exists: torch.FloatTensor = None # 用于存储 atom37 是否存在的张量
residue_index: torch.FloatTensor = None # 用于存储残基索引的张量
lddt_head: torch.FloatTensor = None # 用于存储 lddt 头信息的张量
plddt: torch.FloatTensor = None # 用于存储 plddt 信息的张量
ptm_logits: torch.FloatTensor = None # 用于存储 ptm logits 的张量
ptm: torch.FloatTensor = None # 用于存储 ptm 信息的张量
aligned_confidence_probs: torch.FloatTensor = None # 用于存储对齐置信度概率的张量
predicted_aligned_error: torch.FloatTensor = None # 用于存储预测的对齐误差的张量
max_predicted_aligned_error: torch.FloatTensor = None # 用于存储最大预测对齐误差的张量
# Multi-line docstring constant describing the inputs accepted by the ESMFold model's forward pass
ESMFOLD_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary
masking_pattern (`torch.LongTensor` of shape `({0})`, *optional*):
Locations of tokens to mask during training as a form of regularization. Mask values selected in `[0, 1]`.
num_recycles (`int`, *optional*, defaults to `None`):
Number of times to recycle the input sequence. If `None`, defaults to `config.num_recycles`. "Recycling"
consists of passing the output of the folding trunk back in as input to the trunk. During training, the
number of recycles should vary with each batch, to ensure that the model learns to output valid predictions
after each recycle. During inference, num_recycles should be set to the highest value that the model was
trained with for maximum accuracy. Accordingly, when this value is set to `None`, config.max_recycles is
used.
"""
def is_fp16_enabled():
# 检查当前是否启用了 FP16 自动转换
fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
fp16_enabled = fp16_enabled and torch.is_autocast_enabled()
return fp16_enabled
def is_deepspeed_initialized():
# 检查是否初始化了 DeepSpeed,如果 DeepSpeed 可用但未初始化则返回 False
if is_deepspeed_available():
return False
else:
try:
import deepspeed
# 尝试调用 DeepSpeed 的初始化检查函数,部分版本可能不支持此功能
return deepspeed.utils.is_initialized()
except Exception:
# 捕获所有异常,返回 False 表示未初始化
return False
def collate_dense_tensors(samples: List[torch.Tensor], pad_v: float = 0) -> torch.Tensor:
"""
将一个张量列表堆叠并填充成一个单一张量,所有张量的维度必须一致。
参数:
samples: 包含多个张量的列表,每个张量的维度必须相同。
pad_v: 填充值,默认为 0。
返回:
堆叠并填充后的单一张量。
异常:
如果 samples 中张量的维度不一致,抛出 RuntimeError 异常。
"""
if len(samples) == 0:
return torch.Tensor() # 如果 samples 列表为空,则返回空张量
if len({x.dim() for x in samples}) != 1:
# 检查 samples 中张量的维度是否一致,不一致则抛出异常
raise RuntimeError(f"Samples has varying dimensions: {[x.dim() for x in samples]}")
# 从 samples 中获取设备信息,假设所有样本都在同一设备上
(device,) = tuple({x.device for x in samples})
# 计算 samples 中每个样本的最大形状的每个维度的最大值
max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
# 使用 torch.empty 创建一个与最大形状匹配的张量 result,长度为 len(samples),数据类型与 samples[0] 相同,设备与 samples 相同
result = torch.empty(len(samples), *max_shape, dtype=samples[0].dtype, device=device)
# 用 pad_v 填充 result 张量
result.fill_(pad_v)
# 遍历每个样本并将其复制到 result 张量的适当位置
for i in range(len(samples)):
result_i = result[i] # 获取 result 中的第 i 个子张量
t = samples[i] # 获取第 i 个样本张量 t
# 将样本张量 t 复制到 result_i 的正确位置
result_i[tuple(slice(0, k) for k in t.shape)] = t
# 返回填充后的 result 张量,其中包含了所有样本的数据
return result
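A quick check of collate_dense_tensors, assuming the function above is in scope: two 1-D tensors of different lengths are stacked into a (2, max_len) tensor, and the shorter one is right-padded with pad_v.

import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
print(collate_dense_tensors([a, b], pad_v=0))
# tensor([[1, 2, 3],
#         [4, 5, 0]])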
# 定义函数,用于将张量的最后几个维度展平成一个维度
def flatten_final_dims(t: torch.Tensor, no_dims: int):
return t.reshape(t.shape[:-no_dims] + (-1,))
# 定义函数,用于对张量的最后几个维度进行置换
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
# 计算最后几个维度的起始索引
zero_index = -1 * len(inds)
# 获取前面的维度索引列表
first_inds = list(range(len(tensor.shape[:zero_index])))
# 对张量进行置换操作
return tensor.permute(first_inds + [zero_index + i for i in inds])
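Shape-level sanity check for the two helpers above (toy tensor, assuming both functions are in scope):

import torch

t = torch.randn(2, 3, 4, 5)
print(permute_final_dims(t, (2, 0, 1)).shape)   # torch.Size([2, 5, 3, 4]): last three dims reordered
print(flatten_final_dims(t, 2).shape)           # torch.Size([2, 3, 20]): last two dims merged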
# 定义函数,对多个字典中相同键的值应用指定的函数
def dict_multimap(fn, dicts):
# 获取第一个字典
first = dicts[0]
new_dict = {}
# 遍历第一个字典的键值对
for k, v in first.items():
# 收集所有字典中相同键的值列表
all_v = [d[k] for d in dicts]
# 如果第一个字典中的值是字典类型,则递归调用dict_multimap函数
if isinstance(v, dict):
new_dict[k] = dict_multimap(fn, all_v)
else:
# 否则,对所有值应用给定的函数fn
new_dict[k] = fn(all_v)
# 返回应用函数后的新字典
return new_dict
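dict_multimap applies a function across matching (possibly nested) entries of several dicts; a small sketch, assuming the function above is in scope:

import torch

d1 = {"x": torch.ones(2), "nested": {"y": torch.zeros(1)}}
d2 = {"x": torch.ones(2) * 2, "nested": {"y": torch.ones(1)}}
stacked = dict_multimap(torch.stack, [d1, d2])
print(stacked["x"].shape, stacked["nested"]["y"].shape)  # torch.Size([2, 2]) torch.Size([2, 1])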
# 定义函数,使用截断正态分布初始化权重张量
def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
shape = weights.shape
# 计算缩放系数
scale = scale / max(1, shape[1])
# 检查是否存在SciPy库,如果不存在,则给出警告
if not is_scipy_available():
logger.warning(
"This init requires scipy, but scipy was not found, default to an approximation that might not be"
" equivalent."
)
# 使用近似值初始化权重张量
std = math.sqrt(scale)
torch.nn.init.normal_(weights, std=std).clamp(min=0.0, max=2.0 * std)
else:
from scipy.stats import truncnorm
# 使用SciPy的截断正态分布生成权重样本
std = math.sqrt(scale) / truncnorm.std(a=-2, b=2, loc=0, scale=1)
samples = truncnorm.rvs(a=-2, b=2, loc=0, scale=std, size=weights.numel())
samples = np.reshape(samples, shape)
# 将生成的样本复制到权重张量中
weights.copy_(torch.tensor(samples, device=weights.device))
# 定义函数,使用指定值初始化权重张量
def ipa_point_weights_init_(weights):
with torch.no_grad():
softplus_inverse_1 = 0.541324854612918
# 用给定值填充权重张量
weights.fill_(softplus_inverse_1)
# 定义类,继承自torch.nn.Linear,实现了自定义的初始化方法
class EsmFoldLinear(nn.Linear):
"""
A Linear layer with built-in nonstandard initializations. Called just like torch.nn.Linear.
Implements the initializers in 1.11.4, plus some additional ones found in the code.
"""
def __init__(
self,
in_dim: int,
out_dim: int,
bias: bool = True,
init: str = "default",
init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
):
"""
Args:
in_dim:
The final dimension of inputs to the layer
out_dim:
The final dimension of layer outputs
bias:
Whether to learn an additive bias, True by default
init:
The initializer to use. Choose from:
"default": LeCun fan-in truncated normal initialization
"relu": He initialization with truncated normal distribution
"glorot": Fan-average Glorot uniform initialization
"gating": Weights=0, Bias=1
"normal": Normal initialization with std=1/sqrt(fan_in)
"final": Weights=0, Bias=0
Overridden by init_fn if the latter is not None.
init_fn:
A custom initializer taking weight and bias as inputs. Overrides init if not None.
"""
# 调用父类构造函数,初始化输入维度、输出维度和是否有偏置
super().__init__(in_dim, out_dim, bias=bias)
# 如果有偏置,用0填充偏置项
if bias:
with torch.no_grad():
self.bias.fill_(0)
# 初始化器和自定义初始化器赋值
self.init = init
self.init_fn = init_fn
# 检查init参数是否合法
if init not in ["default", "relu", "glorot", "gating", "normal", "final"]:
raise ValueError("Invalid init string.")
class EsmFoldLayerNorm(nn.Module):
def __init__(self, c_in, eps=1e-5):
super().__init__()
self.c_in = (c_in,) # 输入通道数的元组,用于后续操作
self.eps = eps # Layer normalization 中的 epsilon 参数
self.weight = nn.Parameter(torch.ones(c_in)) # 可学习的权重参数,默认为全1
self.bias = nn.Parameter(torch.zeros(c_in)) # 可学习的偏置参数,默认为全0
def forward(self, x):
d = x.dtype # 获取输入张量 x 的数据类型
if d is torch.bfloat16 and not is_deepspeed_initialized(): # 如果输入是 bfloat16 并且没有启用深度速度优化
with torch.cuda.amp.autocast(enabled=False): # 禁用自动混合精度
out = nn.functional.layer_norm(x, self.c_in, self.weight.to(dtype=d), self.bias.to(dtype=d), self.eps) # 使用 layer normalization 进行归一化
else:
out = nn.functional.layer_norm(x, self.c_in, self.weight, self.bias, self.eps) # 使用 layer normalization 进行归一化
return out
@torch.jit.ignore
def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
"""
Softmax, but without automatic casting to fp32 when the input is of type bfloat16
"""
d = t.dtype # 获取输入张量 t 的数据类型
if d is torch.bfloat16 and not is_deepspeed_initialized(): # 如果输入是 bfloat16 并且没有启用深度速度优化
with torch.cuda.amp.autocast(enabled=False): # 禁用自动混合精度
s = torch.nn.functional.softmax(t, dim=dim) # 使用 softmax 计算张量 t 在指定维度上的概率分布
else:
s = torch.nn.functional.softmax(t, dim=dim) # 使用 softmax 计算张量 t 在指定维度上的概率分布
return s
class EsmFoldAttention(nn.Module):
"""
Standard multi-head attention using AlphaFold's default layer initialization. Allows multiple bias vectors.
"""
def __init__(
self,
c_q: int,
c_k: int,
c_v: int,
c_hidden: int,
no_heads: int,
gating: bool = True,
):
"""
Args:
c_q:
Input dimension of query data
c_k:
Input dimension of key data
c_v:
Input dimension of value data
c_hidden:
Per-head hidden dimension
no_heads:
Number of attention heads
gating:
Whether the output should be gated using query data
"""
super().__init__()
self.c_q = c_q # 查询数据的输入维度
self.c_k = c_k # 键数据的输入维度
self.c_v = c_v # 值数据的输入维度
self.c_hidden = c_hidden # 每个注意力头的隐藏层维度
self.no_heads = no_heads # 注意力头的数量
self.gating = gating # 是否使用查询数据对输出进行门控
# DISCREPANCY: c_hidden is not the per-head channel dimension, as
# stated in the supplement, but the overall channel dimension.
self.linear_q = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot") # 查询线性变换层
self.linear_k = EsmFoldLinear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot") # 键线性变换层
self.linear_v = EsmFoldLinear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot") # 值线性变换层
self.linear_o = EsmFoldLinear(self.c_hidden * self.no_heads, self.c_q, init="final") # 输出线性变换层
self.linear_g = None
if self.gating:
self.linear_g = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, init="gating") # 门控线性变换层
self.sigmoid = nn.Sigmoid() # Sigmoid 激活函数的实例化
# 准备 Q/K/V 查询、键、值的线性变换
def _prep_qkv(self, q_x: torch.Tensor, kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# 对查询向量 q_x 执行线性变换
q = self.linear_q(q_x)
# 对键向量 kv_x 执行线性变换
k = self.linear_k(kv_x)
# 对值向量 kv_x 执行线性变换
v = self.linear_v(kv_x)
# 重新塑形以适应多头注意力机制的输入格式
# [*, Q/K/V, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
k = k.view(k.shape[:-1] + (self.no_heads, -1))
v = v.view(v.shape[:-1] + (self.no_heads, -1))
# 将多头维度与注意力头部数交换位置,以便后续计算注意力权重
# [*, H, Q/K, C_hidden]
q = q.transpose(-2, -3)
k = k.transpose(-2, -3)
v = v.transpose(-2, -3)
# 缩放 Q 向量,以便在计算注意力权重时更稳定
q /= math.sqrt(self.c_hidden)
return q, k, v
# 处理输出结果 o,并应用可选的全局门控线性变换
def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
if self.linear_g is not None:
# 计算全局门控线性变换的输出,并应用 Sigmoid 激活函数
g = self.sigmoid(self.linear_g(q_x))
# 重新塑形以适应多头注意力机制的输入格式
# [*, Q, H, C_hidden]
g = g.view(g.shape[:-1] + (self.no_heads, -1))
o = o * g
# 将多头注意力机制的输出展平最后两个维度
# [*, Q, H * C_hidden]
o = flatten_final_dims(o, 2)
# 对最终的输出应用线性变换,将其映射到输出空间
# [*, Q, C_q]
o = self.linear_o(o)
return o
# 实现模型的前向传播
def forward(
self,
q_x: torch.Tensor,
kv_x: torch.Tensor,
biases: Optional[List[torch.Tensor]] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
lma_q_chunk_size: int = 1024,
lma_kv_chunk_size: int = 4096,
use_flash: bool = False,
flash_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
q_x:
[*, Q, C_q] query data # 输入的查询数据,形状为 [*, Q, C_q]
kv_x:
[*, K, C_k] key data # 输入的键数据,形状为 [*, K, C_k]
biases:
List of biases that broadcast to [*, H, Q, K] # 广播到 [*, H, Q, K] 的偏置列表
use_memory_efficient_kernel:
Whether to use a custom memory-efficient attention kernel. This should be the default choice for most.
If none of the "use_<...>" flags are True, a stock PyTorch implementation is used instead
是否使用自定义的内存高效注意力核。对于大多数情况,这应该是默认选择。
如果没有一个 "use_<...>" 标志为 True,则使用标准的 PyTorch 实现
use_lma:
Whether to use low-memory attention (Staats & Rabe 2021). If none of the "use_<...>" flags are True, a
stock PyTorch implementation is used instead
是否使用低内存注意力 (Staats & Rabe 2021)。
如果没有一个 "use_<...>" 标志为 True,则使用标准的 PyTorch 实现
lma_q_chunk_size:
Query chunk size (for LMA) # 查询分块大小(用于低内存注意力)
lma_kv_chunk_size:
Key/Value chunk size (for LMA) # 键/值分块大小(用于低内存注意力)
Returns
[*, Q, C_q] attention update # 注意力更新后的输出,形状为 [*, Q, C_q]
"""
if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
raise ValueError("If use_lma is specified, lma_q_chunk_size and lma_kv_chunk_size must be provided")
# 如果使用低内存注意力,并且没有提供查询或键/值的分块大小,则抛出数值错误异常
if use_flash and biases is not None:
raise ValueError("use_flash is incompatible with the bias option. For masking, use flash_mask instead")
# 如果同时使用闪存和偏置选项,则抛出数值错误异常。应使用 flash_mask 进行遮罩操作而非偏置。
attn_options = [use_memory_efficient_kernel, use_lma, use_flash]
if sum(attn_options) > 1:
raise ValueError("Choose at most one alternative attention algorithm")
# 如果选择了多个注意力算法选项,则抛出数值错误异常。只能选择最多一个备选注意力算法。
if biases is None:
biases = []
# [*, H, Q/K, C_hidden]
query, key, value = self._prep_qkv(q_x, kv_x)
key = permute_final_dims(key, (1, 0))
# 准备查询、键、值,形状为 [*, H, Q/K, C_hidden],并将键的最后两个维度进行置换
# [*, H, Q, K]
output = torch.matmul(query, key)
# 执行矩阵乘法得到注意力分数矩阵 [*, H, Q, K]
for b in biases:
output += b
# 添加偏置到输出
output = softmax_no_cast(output, -1)
# 在最后一个维度上执行 softmax 操作,得到注意力权重
# [*, H, Q, C_hidden]
output = torch.matmul(output, value)
# 使用注意力权重加权值,得到加权后的值矩阵,形状为 [*, H, Q, C_hidden]
output = output.transpose(-2, -3)
# 对输出进行维度转置,将倒数第二个和倒数第三个维度进行交换
output = self._wrap_up(output, q_x)
# 调用 _wrap_up 方法对输出进行包装处理,根据查询数据 q_x
return output
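A hedged shape sketch for EsmFoldAttention as defined above (toy dimensions, default attention path, no pretrained weights): the output always has the shape of q_x, and each bias must broadcast to [*, H, Q, K].

import torch

attn = EsmFoldAttention(c_q=16, c_k=16, c_v=16, c_hidden=8, no_heads=4, gating=True)
q_x = torch.randn(2, 10, 16)        # [batch, Q, C_q]
kv_x = torch.randn(2, 12, 16)       # [batch, K, C_k]
bias = torch.zeros(2, 4, 10, 12)    # broadcasts to [*, H, Q, K]

out = attn(q_x, kv_x, biases=[bias])
print(out.shape)                    # torch.Size([2, 10, 16]), same as q_x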
class EsmFoldTriangleAttention(nn.Module):
# 定义 EsmFoldTriangleAttention 类,继承自 nn.Module
def __init__(self, c_in, c_hidden, no_heads, starting=True, inf=1e9):
"""
Args:
c_in:
输入通道维度
c_hidden:
总体隐藏通道维度(非每个注意力头)
no_heads:
注意力头的数量
"""
super().__init__()
# 初始化类的属性
self.c_in = c_in
self.c_hidden = c_hidden
self.no_heads = no_heads
self.starting = starting
self.inf = inf
# 初始化层归一化对象
self.layer_norm = LayerNorm(self.c_in)
# 初始化线性层对象
self.linear = EsmFoldLinear(c_in, self.no_heads, bias=False, init="normal")
# 初始化自定义的注意力对象
self.mha = EsmFoldAttention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads)
@torch.jit.ignore
def _chunk(
self,
x: torch.Tensor,
biases: List[torch.Tensor],
chunk_size: int,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
"triangle! triangle!"
# 准备输入参数字典给多头注意力的 chunk_layer 方法
mha_inputs = {
"q_x": x,
"kv_x": x,
"biases": biases,
}
# 使用 chunk_layer 函数对注意力进行分块处理
return chunk_layer(
partial(self.mha, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma),
mha_inputs,
chunk_size=chunk_size,
no_batch_dims=len(x.shape[:-2]),
_out=x if inplace_safe else None,
)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
"""
Args:
x:
[*, I, J, C_in] input tensor (e.g. the pair representation)
Returns:
[*, I, J, C_in] output tensor
"""
# 如果没有提供掩码,则创建一个形状为 [*, I, J] 的新张量,所有元素为1
if mask is None:
mask = x.new_ones(
x.shape[:-1],
)
# 如果不是起始状态,交换输入张量的倒数第二和倒数第三个维度
if not self.starting:
x = x.transpose(-2, -3)
mask = mask.transpose(-1, -2)
# 对输入张量进行 layer normalization,形状保持不变 [*, I, J, C_in]
x = self.layer_norm(x)
# 创建一个形状为 [*, I, 1, 1, J] 的张量,其中 mask_bias 的计算基于 mask 张量
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
# 对线性层的输出进行维度变换,形状为 [*, H, I, J]
triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
# 在倒数第四个维度上扩展 triangle_bias,形状变为 [*, 1, H, I, J]
triangle_bias = triangle_bias.unsqueeze(-4)
# 将 mask_bias 和 triangle_bias 放入列表中作为偏置项
biases = [mask_bias, triangle_bias]
# 如果指定了 chunk_size,则调用 _chunk 方法处理输入 x 和 biases
if chunk_size is not None:
x = self._chunk(
x,
biases,
chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
else:
# 否则调用 self.mha 进行多头注意力计算,使用给定的 biases
x = self.mha(
q_x=x, kv_x=x, biases=biases, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma
)
# 如果不是起始状态,恢复 x 的倒数第二和倒数第三个维度的顺序
if not self.starting:
x = x.transpose(-2, -3)
# 返回处理后的张量 x
return x
"""
Implements Algorithms 11 and 12.
实现第 11 和第 12 算法。
"""
def __init__(self, config, _outgoing=True):
# 初始化函数,设置模型参数
super().__init__()
# 从配置中获取隐藏状态的维度
c_hidden = config.pairwise_state_dim
# 是否是外部输出
self._outgoing = _outgoing
# 定义线性层,用于算法中的计算
self.linear_a_p = EsmFoldLinear(c_hidden, c_hidden)
self.linear_a_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
self.linear_b_p = EsmFoldLinear(c_hidden, c_hidden)
self.linear_b_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
self.linear_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
self.linear_z = EsmFoldLinear(c_hidden, c_hidden, init="final")
# 初始化输入和输出的 LayerNorm
self.layer_norm_in = LayerNorm(c_hidden)
self.layer_norm_out = LayerNorm(c_hidden)
# 定义 Sigmoid 激活函数
self.sigmoid = nn.Sigmoid()
def _combine_projections(
self, a: torch.Tensor, b: torch.Tensor, _inplace_chunk_size: Optional[int] = None
) -> torch.Tensor:
# 组合投影函数,根据 _outgoing 参数确定维度顺序
if self._outgoing:
a = permute_final_dims(a, (2, 0, 1))
b = permute_final_dims(b, (2, 1, 0))
else:
a = permute_final_dims(a, (2, 1, 0))
b = permute_final_dims(b, (2, 0, 1))
# 如果指定了 _inplace_chunk_size,使用循环方式批量处理
if _inplace_chunk_size is not None:
# 待替换为 torch vmap 的部分
for i in range(0, a.shape[-3], _inplace_chunk_size):
a_chunk = a[..., i : i + _inplace_chunk_size, :, :]
b_chunk = b[..., i : i + _inplace_chunk_size, :, :]
a[..., i : i + _inplace_chunk_size, :, :] = torch.matmul(
a_chunk,
b_chunk,
)
p = a
else:
# 否则直接进行矩阵乘法运算
p = torch.matmul(a, b)
return permute_final_dims(p, (1, 2, 0))
def _inference_forward(
self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
inplace_chunk_size: Optional[int] = None,
with_add: bool = True,
):
# 推断过程的前向传播函数,包括处理 mask、是否进行 in-place 操作和是否添加额外计算
...
def forward(
self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
inplace_safe: bool = False,
_add_with_inplace: bool = False,
_inplace_chunk_size: Optional[int] = 256,
) -> torch.Tensor:
"""
Args:
x:
[*, N_res, N_res, C_z] input tensor 输入张量,形状为 [*, N_res, N_res, C_z]
mask:
[*, N_res, N_res] input mask 输入的遮罩,形状为 [*, N_res, N_res]
Returns:
[*, N_res, N_res, C_z] output tensor 输出张量,形状为 [*, N_res, N_res, C_z]
"""
if inplace_safe:
x = self._inference_forward(
z,
mask,
inplace_chunk_size=_inplace_chunk_size, # 设置原地操作的块大小
with_add=_add_with_inplace, # 原地操作时是否进行加法
)
return x # 返回处理后的张量
if mask is None:
mask = z.new_ones(z.shape[:-1]) # 使用输入 z 的形状创建全为 1 的遮罩
mask = mask.unsqueeze(-1) # 在最后一个维度上增加一个维度,形状变为 [*, N_res, N_res, 1]
z = self.layer_norm_in(z) # 输入 z 执行层归一化操作
a = mask # 将 mask 赋值给变量 a
a = a * self.sigmoid(self.linear_a_g(z)) # a 乘以线性变换后经过 sigmoid 函数的结果
a = a * self.linear_a_p(z) # a 乘以另一个线性变换的结果
b = mask # 将 mask 赋值给变量 b
b = b * self.sigmoid(self.linear_b_g(z)) # b 乘以线性变换后经过 sigmoid 函数的结果
b = b * self.linear_b_p(z) # b 乘以另一个线性变换的结果
if is_fp16_enabled(): # 如果启用了 FP16 计算
with torch.cuda.amp.autocast(enabled=False): # 关闭自动混合精度计算
x = self._combine_projections(a.float(), b.float()) # 使用浮点数进行投影组合
else:
x = self._combine_projections(a, b) # 使用原始数据类型进行投影组合
del a, b # 删除变量 a 和 b
x = self.layer_norm_out(x) # 对输出 x 进行层归一化操作
x = self.linear_z(x) # 对归一化后的 x 进行线性变换
g = self.sigmoid(self.linear_g(z)) # 对 z 执行线性变换后经过 sigmoid 函数的结果
x = x * g # 将 x 乘以 g
return x # 返回处理后的张量
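For the outgoing-edge case, the permute/matmul sequence inside _combine_projections is just an einsum over the shared index j; a quick equivalence check, assuming permute_final_dims from earlier in this file is in scope:

import torch

a = torch.randn(2, 5, 5, 4)   # [batch, I, J, C_hidden]
b = torch.randn(2, 5, 5, 4)

p = torch.matmul(permute_final_dims(a, (2, 0, 1)), permute_final_dims(b, (2, 1, 0)))
out = permute_final_dims(p, (1, 2, 0))
ref = torch.einsum("bijc,bkjc->bikc", a, b)
torch.testing.assert_close(out, ref)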
class EsmFoldPreTrainedModel(EsmPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
# Subclass `EsmPreTrainedModel` to handle special initialization of weights
def _init_weights(self, module):
"""Initialize the weights of the given module."""
# Check if the module is an instance of `EsmFoldLinear`
if isinstance(module, EsmFoldLinear):
# Apply weight initialization based on module's initialization method
with torch.no_grad():
if module.init_fn is not None:
module.init_fn(module.weight, module.bias)
elif module.init == "default":
trunc_normal_init_(module.weight, scale=1.0)
elif module.init == "relu":
trunc_normal_init_(module.weight, scale=2.0)
elif module.init == "glorot":
nn.init.xavier_uniform_(module.weight, gain=1)
elif module.init == "gating":
module.weight.fill_(0.0)
if module.bias:
module.bias.fill_(1.0)
elif module.init == "normal":
torch.nn.init.kaiming_normal_(module.weight, nonlinearity="linear")
elif module.init == "final":
module.weight.fill_(0.0)
elif isinstance(module, EsmFoldInvariantPointAttention):
ipa_point_weights_init_(module.head_weights)
elif isinstance(module, EsmFoldTriangularSelfAttentionBlock):
torch.nn.init.zeros_(module.tri_mul_in.linear_z.weight)
torch.nn.init.zeros_(module.tri_mul_in.linear_z.bias)
torch.nn.init.zeros_(module.tri_mul_out.linear_z.weight)
torch.nn.init.zeros_(module.tri_mul_out.linear_z.bias)
torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.weight)
torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.bias)
torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.weight)
torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.bias)
torch.nn.init.zeros_(module.sequence_to_pair.o_proj.weight)
torch.nn.init.zeros_(module.sequence_to_pair.o_proj.bias)
torch.nn.init.zeros_(module.pair_to_sequence.linear.weight)
torch.nn.init.zeros_(module.seq_attention.o_proj.weight)
torch.nn.init.zeros_(module.seq_attention.o_proj.bias)
torch.nn.init.zeros_(module.mlp_seq.mlp[-2].weight)
torch.nn.init.zeros_(module.mlp_seq.mlp[-2].bias)
torch.nn.init.zeros_(module.mlp_pair.mlp[-2].weight)
torch.nn.init.zeros_(module.mlp_pair.mlp[-2].bias)
else:
super()._init_weights(module)
class EsmFoldSelfAttention(nn.Module):
def __init__(self, embed_dim, num_heads, head_width, gated=False):
super().__init__()
assert embed_dim == num_heads * head_width
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_width = head_width
self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False)
self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.gated = gated
if gated:
self.g_proj = nn.Linear(embed_dim, embed_dim)
torch.nn.init.zeros_(self.g_proj.weight)
torch.nn.init.ones_(self.g_proj.bias)
self.rescale_factor = self.head_width**-0.5
torch.nn.init.zeros_(self.o_proj.bias)
def forward(self, x, mask=None, bias=None, indices=None):
"""
基础的自注意力机制,可选带掩码和外部的注意力偏置。用于处理不同长度的序列,使用掩码。
Inputs:
x: 输入序列的批量 (.. x L x C) mask: 批量的布尔掩码,其中 1=有效,0=填充位置 (.. x L_k) bias: 批量的标量注意力偏置 (.. x Lq x Lk x num_heads)
Outputs:
序列投影 (B x L x embed_dim), 注意力映射 (B x L x L x num_heads)
"""
t = self.proj(x).view(*x.shape[:2], self.num_heads, -1)
t = t.permute(0, 2, 1, 3)
q, k, v = t.chunk(3, dim=-1)
q = self.rescale_factor * q
a = torch.einsum("...qc,...kc->...qk", q, k)
if bias is not None:
a = a + bias.permute(0, 3, 1, 2)
if mask is not None:
mask = mask[:, None, None]
a = a.masked_fill(mask == False, -np.inf)
a = nn.functional.softmax(a, dim=-1)
y = torch.einsum("...hqk,...hkc->...qhc", a, v)
y = y.reshape(*y.shape[:2], -1)
if self.gated:
y = self.g_proj(x).sigmoid() * y
y = self.o_proj(y)
return y, a.permute(0, 3, 1, 2)
class EsmFoldDropout(nn.Module):
"""
Implementation of dropout with the ability to share the dropout mask along a particular dimension.
"""
def __init__(self, r: float, batch_dim: Union[int, List[int]]):
super().__init__()
self.r = r
if isinstance(batch_dim, int):
batch_dim = [batch_dim]
self.batch_dim = batch_dim
self.dropout = nn.Dropout(self.r)
def forward(self, x: torch.Tensor) -> torch.Tensor:
shape = list(x.shape)
if self.batch_dim is not None:
for bd in self.batch_dim:
shape[bd] = 1
return x * self.dropout(x.new_ones(shape))
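An illustration of the shared dropout mask, assuming EsmFoldDropout above is in scope: with batch_dim=1, every slice along dimension 1 (e.g. every row of a pair tensor) is zeroed with the same pattern.

import torch

torch.manual_seed(0)
drop = EsmFoldDropout(r=0.5, batch_dim=1)
drop.train()                                  # dropout only acts in training mode
x = torch.ones(1, 4, 3, 2)
y = drop(x)
assert torch.equal(y[:, 0], y[:, 1]) and torch.equal(y[:, 0], y[:, 3])  # same mask shared along dim 1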
class EsmFoldSequenceToPair(nn.Module):
def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim):
super().__init__()
self.layernorm = nn.LayerNorm(sequence_state_dim)
self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True)
self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True)
torch.nn.init.zeros_(self.proj.bias)
torch.nn.init.zeros_(self.o_proj.bias)
def forward(self, sequence_state):
"""
Inputs:
sequence_state: B x L x sequence_state_dim
Output:
pairwise_state: B x L x L x pairwise_state_dim
Intermediate state:
B x L x L x 2*inner_dim
"""
assert len(sequence_state.shape) == 3
s = self.layernorm(sequence_state)
s = self.proj(s)
q, k = s.chunk(2, dim=-1)
prod = q[:, None, :, :] * k[:, :, None, :]
diff = q[:, None, :, :] - k[:, :, None, :]
x = torch.cat([prod, diff], dim=-1)
x = self.o_proj(x)
return x
class EsmFoldPairToSequence(nn.Module):
def __init__(self, pairwise_state_dim, num_heads):
super().__init__()
self.layernorm = nn.LayerNorm(pairwise_state_dim)
self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False)
def forward(self, pairwise_state):
"""
Inputs:
pairwise_state: B x L x L x pairwise_state_dim
Output:
pairwise_bias: B x L x L x num_heads
"""
assert len(pairwise_state.shape) == 4
z = self.layernorm(pairwise_state)
pairwise_bias = self.linear(z)
return pairwise_bias
class EsmFoldResidueMLP(nn.Module):
def __init__(self, embed_dim, inner_dim, dropout=0):
super().__init__()
self.mlp = nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(embed_dim, inner_dim),
nn.ReLU(),
nn.Linear(inner_dim, embed_dim),
nn.Dropout(dropout),
)
def forward(self, x):
return x + self.mlp(x)
class EsmFoldTriangularSelfAttentionBlock(nn.Module):
"""
Placeholder for a module implementing a triangular self-attention block.
This class is not fully implemented in the provided code snippet.
"""
def __init__(self, config):
super().__init__()
self.config = config
sequence_state_dim = config.sequence_state_dim
pairwise_state_dim = config.pairwise_state_dim
sequence_num_heads = sequence_state_dim // config.sequence_head_width
pairwise_num_heads = pairwise_state_dim // config.pairwise_head_width
self.layernorm_1 = nn.LayerNorm(sequence_state_dim)
self.sequence_to_pair = EsmFoldSequenceToPair(sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim)
self.pair_to_sequence = EsmFoldPairToSequence(pairwise_state_dim, sequence_num_heads)
self.seq_attention = EsmFoldSelfAttention(
sequence_state_dim, sequence_num_heads, config.sequence_head_width, gated=True
)
self.tri_mul_out = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=True)
self.tri_mul_in = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=False)
self.tri_att_start = EsmFoldTriangleAttention(
pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=True
)
self.tri_att_end = EsmFoldTriangleAttention(
pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=False
)
self.mlp_seq = EsmFoldResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=config.dropout)
self.mlp_pair = EsmFoldResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=config.dropout)
self.drop = nn.Dropout(config.dropout)
self.row_drop = EsmFoldDropout(config.dropout * 2, 2)
self.col_drop = EsmFoldDropout(config.dropout * 2, 1)
class EsmCategoricalMixture:
def __init__(self, param, bins=50, start=0, end=1):
self.logits = param
bins = torch.linspace(start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype)
self.v_bins = (bins[:-1] + bins[1:]) / 2
def log_prob(self, true):
true_index = (true.unsqueeze(-1) - self.v_bins[[None] * true.ndim]).abs().argmin(-1)
nll = self.logits.log_softmax(-1)
return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1)
def mean(self):
return (self.logits.softmax(-1) @ self.v_bins.unsqueeze(1)).squeeze(-1)
def categorical_lddt(logits, bins=50):
return EsmCategoricalMixture(logits, bins=bins).mean()
def get_axial_mask(mask):
"""
Helper to convert B x L mask of valid positions to axial mask used in row column attentions.
Input:
mask: B x L tensor of booleans
Output:
mask: B x L x L tensor of booleans
"""
if mask is None:
return None
if len(mask.shape) != 2:
raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.")
batch_dim, seq_dim = mask.shape
m = mask.unsqueeze(1).expand(batch_dim, seq_dim, seq_dim)
m = m.reshape(batch_dim * seq_dim, seq_dim)
return m
class EsmFoldRelativePosition(nn.Module):
def __init__(self, config):
super().__init__()
self.bins = config.position_bins
self.embedding = torch.nn.Embedding(2 * self.bins + 2, config.pairwise_state_dim)
def forward(self, residue_index, mask=None):
"""
Input:
residue_index: B x L tensor of indices (dtype=torch.long)
mask: B x L tensor of booleans
Output:
pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings
"""
if residue_index.dtype != torch.long:
raise ValueError(f"`residue_index` has dtype {residue_index.dtype}, it should be `torch.long`.")
if mask is not None and residue_index.shape != mask.shape:
raise ValueError(
f"`residue_index` and `mask` have inconsistent shapes: {residue_index.shape} != {mask.shape}."
)
diff = residue_index[:, None, :] - residue_index[:, :, None]
diff = diff.clamp(-self.bins, self.bins)
diff = diff + self.bins + 1
if mask is not None:
mask = mask[:, None, :] * mask[:, :, None]
diff[mask == False] = 0
output = self.embedding(diff)
return output
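A standalone sketch of the binning arithmetic used above (bin count chosen arbitrarily): pairwise offsets are clamped to [-bins, bins] and shifted to positive embedding ids, leaving id 0 free for masked pairs, which is why the embedding table has 2 * bins + 2 rows.

import torch

bins = 3
residue_index = torch.tensor([[0, 1, 2, 10]])
diff = residue_index[:, None, :] - residue_index[:, :, None]
ids = diff.clamp(-bins, bins) + bins + 1
print(ids)          # values lie in [1, 2 * bins + 1]; 0 is reserved for masked pairs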
class EsmFoldAngleResnetBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.linear_1 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="relu")
self.linear_2 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="final")
self.relu = nn.ReLU()
def forward(self, a: torch.Tensor) -> torch.Tensor:
s_initial = a
a = self.relu(a)
a = self.linear_1(a)
a = self.relu(a)
a = self.linear_2(a)
return a + s_initial
class EsmFoldAngleResnet(nn.Module):
"""
Implements Algorithm 20, lines 11-14
"""
def __init__(self, config):
super().__init__()
self.config = config
self.linear_in = EsmFoldLinear(config.sequence_dim, config.resnet_dim)
self.linear_initial = EsmFoldLinear(config.sequence_dim, config.resnet_dim)
self.layers = nn.ModuleList()
for _ in range(config.num_resnet_blocks):
layer = EsmFoldAngleResnetBlock(config)
self.layers.append(layer)
self.linear_out = EsmFoldLinear(config.resnet_dim, config.num_angles * 2)
self.relu = nn.ReLU()
def forward(self, s: torch.Tensor, s_initial: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
s:
[*, C_hidden] 单个嵌入向量
s_initial:
[*, C_hidden] StructureModule 开始时的单个嵌入向量
Returns:
Tuple[torch.Tensor, torch.Tensor]:
[*, no_angles, 2] 预测的角度
"""
s_initial = self.relu(s_initial)
s_initial = self.linear_initial(s_initial)
s = self.relu(s)
s = self.linear_in(s)
s = s + s_initial
for l in self.layers:
s = l(s)
s = self.relu(s)
s = self.linear_out(s)
s = s.view(s.shape[:-1] + (-1, 2))
unnormalized_s = s
norm_denom = torch.sqrt(
torch.clamp(
torch.sum(s**2, dim=-1, keepdim=True),
min=self.config.epsilon,
)
)
s = s / norm_denom
return unnormalized_s, s
class EsmFoldInvariantPointAttention(nn.Module):
"""
Implements Algorithm 22.
"""
def __init__(self, config):
super().__init__()
self.config = config
c_s = config.sequence_dim
c_z = config.pairwise_dim
self.hidden_dim = config.ipa_dim
self.num_heads = config.num_heads_ipa
self.num_qk_points = config.num_qk_points
self.num_v_points = config.num_v_points
hc = config.ipa_dim * config.num_heads_ipa
self.linear_q = EsmFoldLinear(c_s, hc)
self.linear_kv = EsmFoldLinear(c_s, 2 * hc)
hpq = config.num_heads_ipa * config.num_qk_points * 3
self.linear_q_points = EsmFoldLinear(c_s, hpq)
hpkv = config.num_heads_ipa * (config.num_qk_points + config.num_v_points) * 3
self.linear_kv_points = EsmFoldLinear(c_s, hpkv)
self.linear_b = EsmFoldLinear(c_z, config.num_heads_ipa)
self.head_weights = nn.Parameter(torch.zeros((config.num_heads_ipa)))
concat_out_dim = config.num_heads_ipa * (c_z + config.ipa_dim + config.num_v_points * 4)
self.linear_out = EsmFoldLinear(concat_out_dim, c_s, init="final")
self.softmax = nn.Softmax(dim=-1)
self.softplus = nn.Softplus()
def forward(
self,
s: torch.Tensor,
z: Optional[torch.Tensor],
r: Rigid,
mask: torch.Tensor,
_offload_inference: bool = False,
_z_reference_list: Optional[Sequence[torch.Tensor]] = None,
):
pass  # body omitted in this walkthrough
class EsmFoldBackboneUpdate(nn.Module):
"""
Implements part of Algorithm 23.
"""
def __init__(self, config):
super().__init__()
self.linear = EsmFoldLinear(config.sequence_dim, 6, init="final")
def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
[*, N_res, C_s] single representation
Returns:
[*, N_res, 6] update vector
"""
update = self.linear(s)
return update
class EsmFoldStructureModuleTransitionLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.linear_1 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu")
self.linear_2 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu")
self.linear_3 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="final")
self.relu = nn.ReLU()
def forward(self, s):
s_initial = s
s = self.linear_1(s)
s = self.relu(s)
s = self.linear_2(s)
s = self.relu(s)
s = self.linear_3(s)
s = s + s_initial
return s
class EsmFoldStructureModuleTransition(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layers = nn.ModuleList()
for _ in range(config.num_transition_layers):
l = EsmFoldStructureModuleTransitionLayer(config)
self.layers.append(l)
self.dropout = nn.Dropout(config.dropout_rate)
self.layer_norm = LayerNorm(config.sequence_dim)
def forward(self, s):
for l in self.layers:
s = l(s)
s = self.dropout(s)
s = self.layer_norm(s)
return s
class EsmFoldStructureModule(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer_norm_s = LayerNorm(config.sequence_dim)
self.layer_norm_z = LayerNorm(config.pairwise_dim)
self.linear_in = EsmFoldLinear(config.sequence_dim, config.sequence_dim)
self.ipa = EsmFoldInvariantPointAttention(config)
self.ipa_dropout = nn.Dropout(config.dropout_rate)
self.layer_norm_ipa = LayerNorm(config.sequence_dim)
self.transition = EsmFoldStructureModuleTransition(config)
self.bb_update = EsmFoldBackboneUpdate(config)
self.angle_resnet = EsmFoldAngleResnet(config)
def forward(
self,
evoformer_output_dict,
aatype,
mask=None,
_offload_inference=False,
):
pass
def _init_residue_constants(self, float_dtype, device):
if not hasattr(self, "default_frames"):
self.register_buffer(
"default_frames",
torch.tensor(
residue_constants.restype_rigid_group_default_frame,
dtype=float_dtype,
device=device,
requires_grad=False,
),
persistent=False,
)
if not hasattr(self, "group_idx"):
self.register_buffer(
"group_idx",
torch.tensor(
residue_constants.restype_atom14_to_rigid_group,
device=device,
requires_grad=False,
),
persistent=False,
)
if not hasattr(self, "atom_mask"):
self.register_buffer(
"atom_mask",
torch.tensor(
residue_constants.restype_atom14_mask,
dtype=float_dtype,
device=device,
requires_grad=False,
),
persistent=False,
)
if not hasattr(self, "lit_positions"):
self.register_buffer(
"lit_positions",
torch.tensor(
residue_constants.restype_atom14_rigid_group_positions,
dtype=float_dtype,
device=device,
requires_grad=False,
),
persistent=False,
)
def torsion_angles_to_frames(self, r, alpha, f):
self._init_residue_constants(alpha.dtype, alpha.device)
return torsion_angles_to_frames(r, alpha, f, self.default_frames)
def frames_and_literature_positions_to_atom14_pos(self, r, f):
self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
return frames_and_literature_positions_to_atom14_pos(
r,
f,
self.default_frames,
self.group_idx,
self.atom_mask,
self.lit_positions,
)
class EsmFoldingTrunk(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
c_s = config.sequence_state_dim
c_z = config.pairwise_state_dim
self.pairwise_positional_embedding = EsmFoldRelativePosition(config)
self.blocks = nn.ModuleList([EsmFoldTriangularSelfAttentionBlock(config) for _ in range(config.num_blocks)])
self.recycle_bins = 15
self.recycle_s_norm = nn.LayerNorm(c_s)
self.recycle_z_norm = nn.LayerNorm(c_z)
self.recycle_disto = nn.Embedding(self.recycle_bins, c_z)
self.recycle_disto.weight[0].detach().zero_()
self.structure_module = EsmFoldStructureModule(config.structure_module)
self.trunk2sm_s = nn.Linear(c_s, config.structure_module.sequence_dim)
self.trunk2sm_z = nn.Linear(c_z, config.structure_module.pairwise_dim)
self.chunk_size = config.chunk_size
def set_chunk_size(self, chunk_size):
self.chunk_size = chunk_size
def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles):
"""
Inputs:
seq_feats: B x L x C tensor of sequence features
pair_feats: B x L x L x C tensor of pair features
residx: B x L long tensor giving the position in the sequence
mask: B x L boolean tensor indicating valid residues
Output:
predicted_structure: B x L x (num_atoms_per_residue * 3) tensor wrapped in a Coordinates object
"""
device = seq_feats.device
s_s_0 = seq_feats
s_z_0 = pair_feats
if no_recycles is None:
no_recycles = self.config.max_recycles
else:
if no_recycles < 0:
raise ValueError("Number of recycles must not be negative.")
no_recycles += 1
def trunk_iter(s, z, residx, mask):
z = z + self.pairwise_positional_embedding(residx, mask=mask)
for block in self.blocks:
s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size)
return s, z
s_s = s_s_0
s_z = s_z_0
recycle_s = torch.zeros_like(s_s)
recycle_z = torch.zeros_like(s_z)
recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64)
for recycle_idx in range(no_recycles):
with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]):
recycle_s = self.recycle_s_norm(recycle_s.detach()).to(device)
recycle_z = self.recycle_z_norm(recycle_z.detach()).to(device)
recycle_z += self.recycle_disto(recycle_bins.detach()).to(device)
s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)
structure = self.structure_module(
{"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)},
true_aa,
mask.float(),
)
recycle_s = s_s
recycle_z = s_z
recycle_bins = EsmFoldingTrunk.distogram(
structure["positions"][-1][:, :, :3],
3.375,
21.375,
self.recycle_bins,
)
structure["s_s"] = s_s
structure["s_z"] = s_z
return structure
@staticmethod
def distogram(coords, min_bin, max_bin, num_bins):
boundaries = torch.linspace(
min_bin,
max_bin,
num_bins - 1,
device=coords.device,
)
boundaries = boundaries**2
N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)]
b = CA - N
c = C - CA
a = b.cross(c, dim=-1)
CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True)
bins = torch.sum(dists > boundaries, dim=-1)
return bins
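A shape check for the static distogram helper above, with random backbone coordinates (no physical meaning): N/CA/C positions go in, integer bin indices come out.

import torch

coords = torch.randn(1, 8, 3, 3)                     # [batch, residues, (N, CA, C), xyz]
bins = EsmFoldingTrunk.distogram(coords, 3.375, 21.375, 15)
print(bins.shape, bins.dtype)                        # torch.Size([1, 8, 8]) torch.int64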
@add_start_docstrings(
"""
ESMForProteinFolding is the HuggingFace port of the original ESMFold model. It consists of an ESM-2 "stem" followed
by a protein folding "head", although unlike most other output heads, this "head" is similar in size and runtime to
the rest of the model combined! It outputs a dictionary containing predicted structural information about the input
protein(s).
""",
ESM_START_DOCSTRING,
)
class EsmForProteinFolding(EsmPreTrainedModel):
_no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"]
def __init__(self, config):
super().__init__(config)
self.config = config
self.distogram_bins = 64
self.esm = EsmModel(config, add_pooling_layer=False)
self.esm.requires_grad_(False)
if self.config.esmfold_config.fp16_esm:
self.esm.half()
self.esm_feats = self.config.hidden_size
self.esm_attns = self.config.num_hidden_layers * self.config.num_attention_heads
self.esm_layers = self.config.num_hidden_layers
self.register_buffer("af2_to_esm", self._af2_to_esm_from_vocab_list(config.vocab_list))
self.esm_s_combine = nn.Parameter(torch.zeros(self.esm_layers + 1))
trunk_config = self.config.esmfold_config.trunk
c_s = trunk_config.sequence_state_dim
c_z = trunk_config.pairwise_state_dim
self.esm_s_mlp = nn.Sequential(
LayerNorm(self.esm_feats),
nn.Linear(self.esm_feats, c_s),
nn.ReLU(),
nn.Linear(c_s, c_s),
)
self.n_tokens_embed = residue_constants.restype_num + 3
self.pad_idx = 0
self.unk_idx = self.n_tokens_embed - 2
self.mask_idx = self.n_tokens_embed - 1
self.esm_dict_cls_idx = self.config.vocab_list.index("<cls>")
self.esm_dict_mask_idx = self.config.vocab_list.index("<mask>")
self.esm_dict_eos_idx = self.config.vocab_list.index("<eos>")
self.esm_dict_padding_idx = self.config.vocab_list.index("<pad>")
if self.config.esmfold_config.embed_aa:
self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0)
self.trunk = EsmFoldingTrunk(trunk_config)
self.distogram_head = nn.Linear(c_z, self.distogram_bins)
self.ptm_head = nn.Linear(c_z, self.distogram_bins)
self.lm_head = nn.Linear(c_s, self.n_tokens_embed)
self.lddt_bins = 50
structure_module_config = trunk_config.structure_module
self.lddt_head = nn.Sequential(
nn.LayerNorm(structure_module_config.sequence_dim),
nn.Linear(structure_module_config.sequence_dim, self.config.esmfold_config.lddt_head_hid_dim),
nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, self.config.esmfold_config.lddt_head_hid_dim),
nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, 37 * self.lddt_bins),
)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
masking_pattern: Optional[torch.Tensor] = None,
num_recycles: Optional[int] = None,
):
pass
def af2_idx_to_esm_idx(self, aa, mask):
if self.af2_to_esm.device != aa.device:
self.af2_to_esm = self.af2_to_esm.to(aa.device)
aa = (aa + 1).masked_fill(mask != 1, 0)
return self.af2_to_esm[aa]
def compute_language_model_representations(self, esmaa: torch.Tensor) -> torch.Tensor:
device = next(self.parameters()).device
B, L = esmaa.shape
if self.config.esmfold_config.bypass_lm:
esm_s = torch.zeros(B, L, self.esm_s_combine.size[0], -1, self.esm_feats, device=device)
return esm_s
bosi, eosi = self.esm_dict_cls_idx, self.esm_dict_eos_idx
bos = esmaa.new_full((B, 1), bosi)
eos = esmaa.new_full((B, 1), self.esm_dict_padding_idx)
esmaa = torch.cat([bos, esmaa, eos], dim=1)
esmaa[range(B), (esmaa != 1).sum(1)] = eosi
esm_hidden_states = self.esm(esmaa, attention_mask=esmaa != 1, output_hidden_states=True)["hidden_states"]
esm_s = torch.stack(esm_hidden_states, dim=2)
esm_s = esm_s[:, 1:-1]
return esm_s
def bert_mask(self, aa, esmaa, mask, pattern):
new_aa = aa.clone()
target = aa.clone()
new_esmaa = esmaa.clone()
new_aa[pattern == 1] = self.mask_idx
target[pattern != 1] = 0
new_esmaa[pattern == 1] = self.esm_dict_mask_idx
return new_aa, new_esmaa, target
@torch.no_grad()
def infer(
self,
seqs: Union[str, List[str]],
position_ids=None,
):
if isinstance(seqs, str):
lst = [seqs]
else:
lst = seqs
device = next(self.parameters()).device
aatype = collate_dense_tensors(
[
torch.from_numpy(
residue_constants.sequence_to_onehot(
sequence=seq,
mapping=residue_constants.restype_order_with_x,
map_unknown_to_x=True,
)
)
.to(device)
.argmax(dim=1)
for seq in lst
]
)
mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst])
position_ids = (
torch.arange(aatype.shape[1], device=device).expand(len(lst), -1)
if position_ids is None
else position_ids.to(device)
)
if position_ids.ndim == 1:
position_ids = position_ids.unsqueeze(0)
return self.forward(
aatype,
mask,
position_ids=position_ids,
)
@staticmethod
def output_to_pdb(output: Dict) -> List[str]:
"""Returns the pdb (file) string from the model given the model output."""
output = {k: v.to("cpu").numpy() for k, v in output.items()}
pdbs = []
final_atom_positions = atom14_to_atom37(output["positions"][-1], output)
final_atom_mask = output["atom37_atom_exists"]
for i in range(output["aatype"].shape[0]):
aa = output["aatype"][i]
pred_pos = final_atom_positions[i]
mask = final_atom_mask[i]
resid = output["residue_index"][i] + 1
pred = OFProtein(
aatype=aa,
atom_positions=pred_pos,
atom_mask=mask,
residue_index=resid,
b_factors=output["plddt"][i],
)
pdbs.append(to_pdb(pred))
return pdbs
def infer_pdb(self, seqs, *args, **kwargs) -> str:
"""Returns the pdb (file) string from the model given an input sequence."""
assert isinstance(seqs, str)
output = self.infer(seqs, *args, **kwargs)
return self.output_to_pdb(output)[0]
def infer_pdbs(self, seqs: List[str], *args, **kwargs) -> List[str]:
"""Returns the pdb (file) string from the model given an input sequence."""
output = self.infer(seqs, *args, **kwargs)
return self.output_to_pdb(output)
.\models\esm\modeling_tf_esm.py
from __future__ import annotations
import os
from typing import Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_tf_outputs import (
TFBaseModelOutputWithPastAndCrossAttentions,
TFBaseModelOutputWithPoolingAndCrossAttentions,
TFMaskedLMOutput,
TFSequenceClassifierOutput,
TFTokenClassifierOutput,
)
from ...modeling_tf_utils import (
TFMaskedLanguageModelingLoss,
TFModelInputType,
TFPreTrainedModel,
TFSequenceClassificationLoss,
TFTokenClassificationLoss,
get_initializer,
keras,
shape_list,
unpack_inputs,
)
from ...tf_utils import check_embeddings_within_bounds, stable_softmax
from ...utils import logging
from .configuration_esm import EsmConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "facebook/esm2_t6_8M_UR50D"
_CONFIG_FOR_DOC = "EsmConfig"
TF_ESM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/esm2_t6_8M_UR50D",
"facebook/esm2_t12_35M_UR50D",
]
def rotate_half(x):
"""
将张量沿最后一个维度分割成两半,然后进行旋转操作。
Args:
x: 输入的张量
Returns:
tf.Tensor: 旋转后的张量
"""
x1, x2 = tf.split(x, 2, axis=-1)
return tf.concat((-x2, x1), axis=-1)
def apply_rotary_pos_emb(x, cos, sin):
"""
应用旋转位置嵌入到输入张量 x 中。
Args:
x: 输入的张量
cos: 余弦值张量
sin: 正弦值张量
Returns:
tf.Tensor: 应用旋转位置嵌入后的张量
"""
cos = cos[:, :, : tf.shape(x)[-2], :]
sin = sin[:, :, : tf.shape(x)[-2], :]
return (x * cos) + (rotate_half(x) * sin)
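A quick shape check of the two helpers above, assuming `rotate_half` and `apply_rotary_pos_emb` as defined in this file are in scope; the sizes are toy values.
import tensorflow as tf

dim, seq_len = 4, 3
x = tf.reshape(tf.range(seq_len * dim, dtype=tf.float32), (1, 1, seq_len, dim))  # (batch, heads, seq, dim)
inv_freq = 1.0 / (10000 ** (tf.range(0, dim, 2, dtype=tf.float32) / dim))
freqs = tf.einsum("i,j->ij", tf.range(seq_len, dtype=tf.float32), inv_freq)
emb = tf.concat((freqs, freqs), axis=-1)[None, None, :, :]
rotated = apply_rotary_pos_emb(x, tf.cos(emb), tf.sin(emb))
print(rotated.shape)   # (1, 1, 3, 4): the rotation changes values but preserves the shape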
def symmetrize(x):
"""
对最后两个维度进行转置操作,使层对称化,用于接触预测。
Args:
x: 输入张量
Returns:
tf.Tensor: 对称化后的张量
"""
return x + tf.linalg.matrix_transpose(x)
def average_product_correct(x):
"""
执行平均产品校正,用于接触预测。
Args:
x: 输入张量
Returns:
tf.Tensor: 校正后的张量
"""
a1 = tf.reduce_sum(x, -1, keepdims=True)
a2 = tf.reduce_sum(x, -2, keepdims=True)
a12 = tf.reduce_sum(x, (-1, -2), keepdims=True)
avg = a1 * a2
avg = avg / a12
normalized = x - avg
return normalized
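A small numeric check of the two contact-map helpers above, assuming `symmetrize` and `average_product_correct` as defined here are in scope.
import tensorflow as tf

x = tf.constant([[[0.2, 0.8],
                  [0.1, 0.4]]])        # one 2x2 attention map
sym = symmetrize(x)                    # add the transpose over the last two dims
apc = average_product_correct(sym)     # subtract the row * column / total "background"
print(sym.numpy(), apc.numpy(), sep="\n")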
class TFRotaryEmbedding(keras.layers.Layer):
"""
基于 RoFormer 中的旋转位置嵌入,对查询和键进行旋转矩阵变换,依赖它们的相对位置。
"""
def __init__(self, dim: int, name=None):
super().__init__(name=name)
self.dim = dim
def build(self, input_shape):
super().build(input_shape)
self.inv_freq = self.add_weight(
"inv_freq", shape=(self.dim // 2,), dtype=tf.float32, initializer=get_initializer(1.0), trainable=False
)
self.inv_freq.assign(
1.0 / (10000 ** (tf.range(start=0, limit=self.dim, delta=2, dtype=tf.float32) / self.dim))
)
def _compute_cos_sin(self, x, seq_dimension=2):
seq_len = tf.shape(x)[seq_dimension]
t = tf.range(seq_len, dtype=self.inv_freq.dtype)
freqs = tf.einsum("i, j -> ij", t, self.inv_freq)
emb = tf.concat((freqs, freqs), axis=-1)[None, None, :, :]
return tf.cos(emb), tf.sin(emb)
def call(self, q: tf.Tensor, k: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
cos_emb, sin_emb = self._compute_cos_sin(k, seq_dimension=-2)
return (
apply_rotary_pos_emb(q, cos_emb, sin_emb),
apply_rotary_pos_emb(k, cos_emb, sin_emb),
)
class TFEsmContactPredictionHead(keras.layers.Layer):
"""Performs symmetrization, apc, and computes a logistic regression on the output features"""
def __init__(
self,
in_features: int,
bias=True,
eos_idx: int = 2,
name=None,
):
super().__init__(name=name)
self.eos_idx = eos_idx
self.in_features = in_features
self.regression = keras.layers.Dense(1, use_bias=bias, activation="sigmoid", name="regression")
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "regression", None) is not None:
with tf.name_scope(self.regression.name):
self.regression.build((None, self.in_features))
def call(self, tokens, attentions):
eos_mask = tf.cast(tokens != self.eos_idx, attentions.dtype)
eos_mask = tf.expand_dims(eos_mask, 1) * tf.expand_dims(eos_mask, 2)
attentions = attentions * eos_mask[:, None, None, :, :]
attentions = attentions[..., :-1, :-1]
attentions = attentions[..., 1:, 1:]
batch_size, layers, heads, seqlen, _ = shape_list(attentions)
attentions = tf.reshape(attentions, (batch_size, layers * heads, seqlen, seqlen))
attentions = average_product_correct(symmetrize(attentions))
attentions = tf.transpose(attentions, perm=(0, 2, 3, 1))
return tf.squeeze(self.regression(attentions), 3)
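Shape sketch for the contact prediction head above, with hypothetical sizes; the BOS/EOS rows and columns are stripped before the per-pair logistic regression, so the output is two positions shorter than the input.
import tensorflow as tf

batch, layers, heads, seqlen = 2, 6, 20, 10
head = TFEsmContactPredictionHead(in_features=layers * heads)           # class defined above
tokens = tf.ones((batch, seqlen), dtype=tf.int32)                       # toy batch without EOS or padding
attentions = tf.random.uniform((batch, layers, heads, seqlen, seqlen))
contacts = head(tokens, attentions)
print(contacts.shape)   # (2, 8, 8): one contact probability per residue pair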
class TFEsmEmbeddings(keras.layers.Layer):
"""
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
"""
def __init__(self, config, name=None):
super().__init__(name=name)
self.word_embeddings = keras.layers.Embedding(
config.vocab_size,
config.hidden_size,
embeddings_initializer=get_initializer(config.initializer_range),
name="word_embeddings",
)
self.position_embeddings = keras.layers.Embedding(
config.max_position_embeddings,
config.hidden_size,
embeddings_initializer=get_initializer(config.initializer_range),
name="position_embeddings",
)
if config.emb_layer_norm_before:
self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
else:
self.layer_norm = None
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.position_ids = tf.range(config.max_position_embeddings)[None, :]
self.padding_idx = config.pad_token_id
self.token_dropout = config.token_dropout
self.mask_token_id = config.mask_token_id
self.config = config
    def call(
        self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
if position_ids is None:
if input_ids is not None:
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
else:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
if inputs_embeds is None:
check_embeddings_within_bounds(input_ids, self.config.vocab_size)
inputs_embeds = self.word_embeddings(input_ids)
embeddings = inputs_embeds
if self.token_dropout:
embeddings = tf.where((input_ids == self.mask_token_id)[:, :, None], 0.0, embeddings)
mask_ratio_train = 0.15 * 0.8
src_lengths = tf.cast(tf.reduce_sum(attention_mask, axis=-1), tf.float32)
masked_tokens = input_ids == self.mask_token_id
mask_ratio_observed = tf.math.count_nonzero(masked_tokens, dtype=tf.float32, axis=-1) / src_lengths
embeddings = embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
if self.layer_norm is not None:
embeddings = self.layer_norm(embeddings)
if attention_mask is not None:
embeddings = embeddings * tf.cast(tf.expand_dims(attention_mask, -1), embeddings.dtype)
return embeddings
    def build(self, input_shape=None):
        if self.built:
return
self.built = True
if getattr(self, "word_embeddings", None) is not None:
with tf.name_scope(self.word_embeddings.name):
self.word_embeddings.build(None)
if getattr(self, "position_embeddings", None) is not None:
with tf.name_scope(self.position_embeddings.name):
self.position_embeddings.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.hidden_size])
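Worked example of the token-dropout rescaling used in `call` above: roughly 12% of positions (0.15 * 0.8) end up zeroed during pre-training, so at inference the surviving embeddings are rescaled by (1 - 0.12) / (1 - observed mask ratio). The numbers below are toy values.
import tensorflow as tf

mask_ratio_train = 0.15 * 0.8                                        # 0.12, as in the layer above
attention_mask = tf.constant([[1, 1, 1, 1, 1]], dtype=tf.int32)
masked_tokens = tf.constant([[False, True, False, False, False]])    # one <mask> out of five tokens

src_lengths = tf.cast(tf.reduce_sum(attention_mask, axis=-1), tf.float32)
mask_ratio_observed = tf.math.count_nonzero(masked_tokens, dtype=tf.float32, axis=-1) / src_lengths
scale = (1 - mask_ratio_train) / (1 - mask_ratio_observed)
print(scale.numpy())   # ~[1.1]: every unmasked embedding is scaled up by about ten percent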
class TFEsmSelfAttention(keras.layers.Layer):
def __init__(self, config, position_embedding_type=None, name=None):
super().__init__(name=name)
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
)
self.key = keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
)
self.value = keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
)
self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
self.rotary_embeddings = None
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = keras.layers.Embedding(
2 * config.max_position_embeddings - 1,
self.attention_head_size,
embeddings_initializer=get_initializer(config.initializer_range),
)
elif self.position_embedding_type == "rotary":
self.rotary_embeddings = TFRotaryEmbedding(dim=self.attention_head_size, name="rotary_embeddings")
self.is_decoder = config.is_decoder
self.config = config
def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor:
new_x_shape = shape_list(x)[:-1] + [self.num_attention_heads, self.attention_head_size]
x = tf.reshape(x, new_x_shape)
return tf.transpose(x, perm=(0, 2, 1, 3))
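Shape sketch for `transpose_for_scores` above: the flat hidden dimension is split into heads and moved in front of the sequence axis so attention can be computed per head. Toy sizes below.
import tensorflow as tf

batch, seq_len, num_heads, head_size = 2, 5, 4, 8
x = tf.zeros((batch, seq_len, num_heads * head_size))       # (batch, seq, all_head_size)
x = tf.reshape(x, (batch, seq_len, num_heads, head_size))
x = tf.transpose(x, perm=(0, 2, 1, 3))
print(x.shape)   # (2, 4, 5, 8) = (batch, heads, seq, head_size)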
def call(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor | None = None,
head_mask: tf.Tensor | None = None,
encoder_hidden_states: tf.Tensor | None = None,
encoder_attention_mask: tf.Tensor | None = None,
past_key_value: Tuple[Tuple[tf.Tensor]] | None = None,
output_attentions: Optional[bool] = False,
training: bool = False,
**kwargs
) -> Tuple[tf.Tensor, Optional[Tuple[tf.Tensor]]]:
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
if getattr(self, "rotary_embeddings", None) is not None:
with tf.name_scope(self.rotary_embeddings.name):
self.rotary_embeddings.build(None)
class TFEsmSelfOutput(keras.layers.Layer):
def __init__(self, config, name=None):
super().__init__(name=name)
self.dense = keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states, input_tensor, training=False):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states += input_tensor
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFEsmAttention(keras.layers.Layer):
def __init__(self, config, name=None):
super().__init__(name=name)
self.self = TFEsmSelfAttention(config, name="self")
self.output_layer = TFEsmSelfOutput(config, name="output")
self.pruned_heads = set()
self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.config = config
def prune_heads(self, heads):
raise NotImplementedError
def call(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
training=False,
):
hidden_states_ln = self.LayerNorm(hidden_states)
self_outputs = self.self(
hidden_states_ln,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
training,
)
attention_output = self.output_layer(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self", None) is not None:
with tf.name_scope(self.self.name):
self.self.build(None)
if getattr(self, "output_layer", None) is not None:
with tf.name_scope(self.output_layer.name):
self.output_layer.build(None)
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFEsmIntermediate(keras.layers.Layer):
    def __init__(self, config: EsmConfig, **kwargs):
super().__init__(**kwargs)
self.dense = keras.layers.Dense(
units=config.intermediate_size,
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = tf.nn.gelu(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFEsmLayer(keras.layers.Layer):
def __init__(self, config, name=None):
super().__init__(name=name)
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = TFEsmAttention(config, name="attention")
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = TFEsmAttention(config)
self.intermediate = TFEsmIntermediate(config, name="intermediate")
self.output_layer = TFEsmOutput(config, name="output")
self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.config = config
def call(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
training=False,
):
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
training=training,
)
attention_output = self_attention_outputs[0]
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:]
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise AttributeError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated"
" with cross-attention layers by setting `config.add_cross_attention=True`"
)
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
training=training,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1]
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layernorm_output = self.LayerNorm(attention_output)
intermediate_output = self.intermediate(hidden_states=layernorm_output)
layer_output = self.output_layer(
hidden_states=intermediate_output, input_tensor=attention_output, training=training
)
outputs = (layer_output,) + outputs
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "output_layer", None) is not None:
with tf.name_scope(self.output_layer.name):
self.output_layer.build(None)
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFEsmEncoder(keras.layers.Layer):
def __init__(self, config, name=None):
super().__init__(name=name)
self.config = config
self.layer = [TFEsmLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
self.emb_layer_norm_after = keras.layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="emb_layer_norm_after"
)
def call(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
training=False,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
training,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if self.emb_layer_norm_after:
hidden_states = self.emb_layer_norm_after(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
    def build(self, input_shape=None):
        if self.built:
return
self.built = True
if getattr(self, "emb_layer_norm_after", None) is not None:
with tf.name_scope(self.emb_layer_norm_after.name):
self.emb_layer_norm_after.build([None, None, self.config.hidden_size])
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
"""
定义一个自定义的 Keras 层 TFEsmPooler,用于 ESM 模型的池化操作。
从 transformers.models.bert.modeling_tf_bert.TFBertPooler 复制并修改为 ESM。
Parameters:
config (EsmConfig): ESM 模型的配置对象,包含模型的各种参数。
Attributes:
dense (Dense): 密集连接层,用于处理隐藏状态向量。
config (EsmConfig): ESM 模型的配置对象。
Methods:
call(hidden_states: tf.Tensor) -> tf.Tensor:
对隐藏状态进行池化操作,只使用第一个 token 对应的隐藏状态。
build(input_shape=None):
构建层,初始化密集连接层。
"""
"""
TFEsmPreTrainedModel, the base class for pretrained ESM models.
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
Attributes:
    config_class (EsmConfig): the model configuration class, set to EsmConfig.
    base_model_prefix (str): prefix of the base model name, set to "esm".
Notes:
    This class provides the generic methods of pretrained models, such as initializing weights and downloading
    and loading pretrained checkpoints.
"""
"""
Docstring for the ESM model (ESM_START_DOCSTRING), describing the model's basic information and usage.
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a Keras [Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it as a
regular Keras model and refer to the TF/Keras documentation for all matters related to general usage and behavior.
Parameters:
config ([`EsmConfig`]): Model configuration class with all the parameters of the
model. Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
"""
Docstring for the ESM model inputs (ESM_INPUTS_DOCSTRING), describing each input argument in detail.
"""
Args:
input_ids (`tf.Tensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices.
position_ids (`tf.Tensor` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings.
head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules.
inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare ESM Model transformer outputting raw hidden-states without any specific head on top.",
ESM_START_DOCSTRING,
)
class TFEsmMainLayer(keras.layers.Layer):
"""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in [Attention is
all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
"""
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config, add_pooling_layer=True, name=None, **kwargs):
super().__init__(name=name, **kwargs)
self.config = config
        self.is_decoder = config.is_decoder  # decoder flag
        self.embeddings = TFEsmEmbeddings(config, name="embeddings")  # embedding layer
        self.encoder = TFEsmEncoder(config, name="encoder")  # encoder
        self.pooler = TFEsmPooler(config, name="pooler") if add_pooling_layer else None  # pooling layer (if requested)
        self.contact_head = TFEsmContactPredictionHead(
            in_features=self.config.num_hidden_layers * self.config.num_attention_heads, bias=True, name="contact_head"
        )  # contact prediction head
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)  # build the embedding layer
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)  # build the encoder
        if getattr(self, "pooler", None) is not None:
            with tf.name_scope(self.pooler.name):
                self.pooler.build(None)  # build the pooling layer
        if getattr(self, "contact_head", None) is not None:
            with tf.name_scope(self.contact_head.name):
                self.contact_head.build(None)  # build the contact prediction head
    def get_input_embeddings(self):
        return self.embeddings.word_embeddings  # return the word embeddings of the input embedding layer
    def set_input_embeddings(self, value: tf.Variable):
        self.embeddings.word_embeddings.weight = value  # set the weights of the input embedding layer
        self.embeddings.vocab_size = shape_list(value)[0]  # update the vocabulary size
    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError  # head pruning is not implemented
    # Forward method of the layer: takes the various inputs and returns the model outputs
def call(
self,
input_ids: TFModelInputType | None = None,
attention_mask: np.ndarray | tf.Tensor | None = None,
position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
inputs_embeds: np.ndarray | tf.Tensor | None = None,
encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
):
    # Predict the residue-residue contacts of the model.
    def predict_contacts(self, tokens, attention_mask):
        # Call the layer itself with tokens and attention_mask, requesting return_dict and
        # output_attentions so that the attention weights are returned.
        attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions
        # Stack the list of attention weights into one tensor, keeping the same dimension order as the original model.
        attns = tf.stack(attns, axis=1)
        # In the original model, attentions for padding tokens are completely zeroed out.
        # This usually makes little difference because other tokens do not attend to them,
        # but the contact prediction task takes the attentions as input, so that behaviour has to be mimicked here.
        # Therefore, zero out the attention weights at positions corresponding to padding tokens.
        attention_mask = tf.cast(attention_mask, attns.dtype)
        attns *= attention_mask[:, None, None, None]  # expand dims to broadcast over the attention tensor
        attns *= attention_mask[:, None, None, :, None]  # expand dims to broadcast over the attention tensor
        # Pass tokens and the processed attention weights to contact_head and return the contact predictions.
        return self.contact_head(tokens, attns)
    # Add a docstring to TFEsmModel describing it as the bare ESM transformer outputting raw hidden states
    # without any specific head on top.
@add_start_docstrings(
"The bare ESM Model transformer outputting raw hidden-states without any specific head on top.",
ESM_START_DOCSTRING,
)
class TFEsmModel(TFEsmPreTrainedModel):
def __init__(self, config: EsmConfig, add_pooling_layer=True, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
        # Initialize the ESM main layer with the given config and the optional pooling layer
        self.esm = TFEsmMainLayer(config, add_pooling_layer=add_pooling_layer, name="esm")
    # The decorators below document the inputs of the forward pass
@unpack_inputs
@add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFBaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids: TFModelInputType | None = None,
attention_mask: np.ndarray | tf.Tensor | None = None,
position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
inputs_embeds: np.ndarray | tf.Tensor | None = None,
encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
        # The remaining parameters and their roles are documented in the inputs docstring above
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
r"""
encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`)
contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*, defaults to `True`):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`). Set to `False` during training, `True` during generation
"""
outputs = self.esm(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
return outputs
def predict_contacts(self, tokens, attention_mask):
        # Delegate contact prediction to the main layer
        return self.esm.predict_contacts(tokens, attention_mask)
    def build(self, input_shape=None):
        if self.built:
            return
        # Mark the model as built
        self.built = True
        if getattr(self, "esm", None) is not None:
            with tf.name_scope(self.esm.name):
                # Build the submodule
                self.esm.build(None)
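A minimal usage sketch for TFEsmModel, assuming the public ESM-2 checkpoint named in `_CHECKPOINT_FOR_DOC` above and a toy amino-acid sequence.
from transformers import AutoTokenizer, TFEsmModel

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = TFEsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")

inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="tf")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)   # (1, sequence_length, hidden_size)

contacts = model.predict_contacts(inputs["input_ids"], inputs["attention_mask"])
print(contacts.shape)                    # per-residue-pair contact probabilities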
# Add a docstring describing the model as an ESM model with a language modeling head on top
@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING)
class TFEsmForMaskedLM(TFEsmPreTrainedModel, TFMaskedLanguageModelingLoss):
    # Keys missing from the checkpoint that can safely be ignored when loading
    _keys_to_ignore_on_load_missing = [r"position_ids"]
    # Unexpected keys in the checkpoint that can safely be ignored when loading
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    def __init__(self, config):
        super().__init__(config)
        # Warn if the configuration marks the model as a decoder
        if config.is_decoder:
            logger.warning(
                "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )
        # Initialize the ESM main layer without a pooling layer, named "esm"
        self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
        # Initialize the ESM language modeling head, named "lm_head"
        self.lm_head = TFEsmLMHead(config, name="lm_head")
        # If the word embeddings should be tied
        if config.tie_word_embeddings:
            # Make sure the word embeddings are built so that they can be tied
            with tf.name_scope(os.path.join(self._name_scope(), "esm", "embeddings", "word_embeddings")):
                self.esm.embeddings.word_embeddings.build((None, None))
            # Point the lm_head decoder at the ESM word-embedding weights
            self.lm_head.decoder = self.esm.embeddings.word_embeddings.weights[0]
    # Return the output embeddings
    def get_output_embeddings(self):
        return self.lm_head.decoder
    # Set the output embeddings
    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings
    # Return the language modeling head
    def get_lm_head(self):
        return self.lm_head
    # Forward method: unpacks the inputs and carries the forward-pass docstrings added by the decorators below
@unpack_inputs
@add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFMaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
mask="<mask>",
)
def call(
self,
input_ids: TFModelInputType | None = None,
attention_mask: np.ndarray | tf.Tensor | None = None,
position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
inputs_embeds: np.ndarray | tf.Tensor | None = None,
encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
labels: np.ndarray | tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
    ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
r"""
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
kwargs (`Dict[str, any]`, optional, defaults to *{}*):
Used to hide legacy arguments that have been deprecated.
"""
        # Decide whether to return a dict-style output; fall back to the config default when not provided
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Run the forward pass through the ESM model
outputs = self.esm(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
        # Take the sequence features from the model output
        sequence_output = outputs[0]
        # Produce prediction scores with the language modeling head
        prediction_scores = self.lm_head(sequence_output)
        masked_lm_loss = None
        # If labels are provided, compute the masked language modeling loss
        if labels is not None:
            masked_lm_loss = self.hf_compute_loss(labels=labels, logits=prediction_scores)
        # If a dict-style output was not requested
        if not return_dict:
            # Build the output tuple from the prediction scores and any extra outputs
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
        # Return a TFMaskedLMOutput with the loss, prediction scores, hidden states and attentions
return TFMaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def predict_contacts(self, tokens, attention_mask):
        # Delegate to the ESM model's contact-prediction helper
        return self.esm.predict_contacts(tokens, attention_mask)
    def build(self, input_shape=None):
        # If the model is already built, return immediately
        if self.built:
            return
        # Mark the model as built
        self.built = True
        # If the ESM model exists, build it under its own name scope
        if getattr(self, "esm", None) is not None:
            with tf.name_scope(self.esm.name):
                self.esm.build(None)
        # If the language modeling head exists, build it under its own name scope
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build(None)
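A hedged usage sketch for TFEsmForMaskedLM: score candidate tokens for a single `<mask>` position, using the checkpoint from `_CHECKPOINT_FOR_DOC` above and a toy sequence.
import tensorflow as tf
from transformers import AutoTokenizer, TFEsmForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = TFEsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")

inputs = tokenizer("MKTAYIAKQ<mask>QISFVKSHFSRQ", return_tensors="tf")
logits = model(**inputs).logits                                            # (1, seq_len, vocab_size)
mask_pos = int(tf.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0, 0])
top5 = tf.math.top_k(logits[0, mask_pos], k=5).indices
print(tokenizer.decode(top5))                                              # most likely residues at the masked site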
class TFEsmLMHead(keras.layers.Layer):
"""ESM Head for masked language modeling."""
def __init__(self, config, name=None):
super().__init__(name=name)
        # Dense layer mapping the input features to a hidden_size output space
        self.dense = keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        # LayerNormalization layer used to normalize the intermediate vectors
        self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
        # If tie_word_embeddings is set, decoder stays None; otherwise create a dense layer decoding to the vocab size
if config.tie_word_embeddings:
self.decoder = None
else:
self.decoder = keras.layers.Dense(
config.vocab_size,
kernel_initializer=get_initializer(config.initializer_range),
name="decoder",
use_bias=False,
)
self.config = config
def build(self, input_shape=None):
        # Separate the bias to match the PT model and allow weight cross-loading to work.
        # Keep it in build so that it gets the right name when added as a weight.
        if self.built:
            return
        self.built = True
        # Add a trainable weight named "bias" of shape (config.vocab_size,), initialized to zeros
        self.bias = self.add_weight("bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                # Build the dense layer with input shape [None, None, config.hidden_size]
                self.dense.build([None, None, self.config.hidden_size])
        if getattr(self, "layer_norm", None) is not None:
            with tf.name_scope(self.layer_norm.name):
                # Build the layer_norm layer with input shape [None, None, config.hidden_size]
                self.layer_norm.build([None, None, self.config.hidden_size])
        if getattr(self, "decoder", None) is not None and not self.config.tie_word_embeddings:
            with tf.name_scope(self.decoder.name):
                # Build the decoder layer with input shape [None, None, config.hidden_size]
self.decoder.build([None, None, self.config.hidden_size])
def get_bias(self):
return {"bias": self.bias}
def call(self, features):
        # Map the features through the dense layer
        x = self.dense(features)
        # Apply the gelu activation
        x = tf.nn.gelu(x)
        # Normalize the output with layer_norm
        x = self.layer_norm(x)
        # Depending on tie_word_embeddings, project x back to the vocabulary size and add the bias
if self.config.tie_word_embeddings:
x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias
else:
x = self.decoder(x) + self.bias
return x
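Small numeric sketch of the tied-embeddings branch above: projecting hidden states back to vocabulary logits is a matmul against the transposed embedding matrix plus the separately stored bias (toy sizes, not the real config).
import tensorflow as tf

hidden_size, vocab_size = 4, 6
x = tf.random.normal((1, 3, hidden_size))                   # (batch, seq, hidden) after dense + gelu + layer_norm
embedding_matrix = tf.random.normal((vocab_size, hidden_size))
bias = tf.zeros((vocab_size,))
logits = tf.matmul(x, embedding_matrix, transpose_b=True) + bias
print(logits.shape)                                         # (1, 3, 6): one score per vocabulary entry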
@add_start_docstrings(
"""
ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
output) e.g. for GLUE tasks.
""",
ESM_START_DOCSTRING,
)
class TFEsmForSequenceClassification(TFEsmPreTrainedModel, TFSequenceClassificationLoss):
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
        # Number of labels for the classification or regression task
        self.num_labels = config.num_labels
        self.config = config
        # Create the ESM main layer without a pooling layer, named "esm"
        self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
        # Create the classification head, named "classifier"
        self.classifier = TFEsmClassificationHead(config, name="classifier")
@unpack_inputs
@add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
    # The decorators above attach a code-sample docstring to this method, documenting its arguments and return type
def call(
self,
input_ids: TFModelInputType | None = None,
attention_mask: np.ndarray | tf.Tensor | None = None,
position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
inputs_embeds: np.ndarray | tf.Tensor | None = None,
labels: np.ndarray | tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
r"""
labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
        # Set return_dict, falling back to self.config.use_return_dict when not provided
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Run the forward pass of the sequence encoder via self.esm
outputs = self.esm(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
        # Take the sequence output from the model outputs
        sequence_output = outputs[0]
        # Pass the sequence output to the classifier to produce the classification logits
        logits = self.classifier(sequence_output)
        # Compute the loss from labels and logits when labels are provided
        loss = None if labels is None else self.hf_compute_loss(labels, logits)
        # If return_dict is False, build a plain output tuple
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
        # If return_dict is True, build a TFSequenceClassifierOutput as the output
return TFSequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
    # Build the model, initializing each of its components
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # If self.esm exists, build it under the self.esm.name scope
        if getattr(self, "esm", None) is not None:
            with tf.name_scope(self.esm.name):
                self.esm.build(None)
        # If self.classifier exists, build it under the self.classifier.name scope
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build(None)
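A hedged fine-tuning-style sketch for TFEsmForSequenceClassification; `num_labels=2` and the toy sequences are assumptions for illustration, not values from the source.
import tensorflow as tf
from transformers import AutoTokenizer, TFEsmForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = TFEsmForSequenceClassification.from_pretrained("facebook/esm2_t6_8M_UR50D", num_labels=2)

inputs = tokenizer(["MKTAYIAKQR", "GGGSSGGSSG"], return_tensors="tf", padding=True)
labels = tf.constant([0, 1])
outputs = model(**inputs, labels=labels)
print(outputs.loss.numpy(), outputs.logits.shape)   # loss value(s) and (2, 2) logits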
@add_start_docstrings(
"""
ESM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
ESM_START_DOCSTRING,
)
class TFEsmForTokenClassification(TFEsmPreTrainedModel, TFTokenClassificationLoss):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
        # Number of classification labels
        self.num_labels = config.num_labels
        # Create the ESM main layer without a pooling layer
        self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
        # Dropout layer used to reduce overfitting
        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
        # Classifier layer turning the hidden states into per-token predictions
        self.classifier = keras.layers.Dense(config.num_labels, name="classifier")
        # Keep the configuration
        self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFTokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids: TFModelInputType | None = None,
attention_mask: np.ndarray | tf.Tensor | None = None,
position_ids: np.ndarray | tf.Tensor | None = None,
head_mask: np.ndarray | tf.Tensor | None = None,
inputs_embeds: np.ndarray | tf.Tensor | None = None,
labels: np.ndarray | tf.Tensor | None = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
r"""
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
        # Decide whether to return a dict-style output
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Run the forward pass through the ESM main layer
outputs = self.esm(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
        # Take the sequence output
        sequence_output = outputs[0]
        # Apply dropout during training to reduce overfitting
        sequence_output = self.dropout(sequence_output, training=training)
        # Produce the per-token classification logits with the classifier layer
        logits = self.classifier(sequence_output)
        # Skip the loss when no labels are provided
        loss = None if labels is None else self.hf_compute_loss(labels, logits)
        # Arrange the output depending on whether a dict is requested
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
        # Return the result as a TFTokenClassifierOutput
return TFTokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
    def build(self, input_shape=None):
        # If the model has already been built, return immediately instead of rebuilding
        if self.built:
            return
        # Mark the model as built
        self.built = True
        # If an "esm" attribute exists and is not None, build it under its own name scope
        if getattr(self, "esm", None) is not None:
            with tf.name_scope(self.esm.name):
                # Call the esm build method with None as the input shape
                self.esm.build(None)
        # If a "classifier" attribute exists and is not None, build it under its own name scope
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                # Call the classifier build method with [None, None, self.config.hidden_size] as the input shape
self.classifier.build([None, None, self.config.hidden_size])
class TFEsmClassificationHead(keras.layers.Layer):
"""Head for sentence-level classification tasks."""
def __init__(self, config, name=None):
super().__init__(name=name)
        # Dense layer producing a hidden_size output with a tanh activation
        self.dense = keras.layers.Dense(
            config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="tanh",
            name="dense",
        )
        # Dropout layer that randomly drops part of the input during training to reduce overfitting
        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
        # Dense layer producing num_labels outputs with a linear (identity) activation
self.out_proj = keras.layers.Dense(
config.num_labels,
kernel_initializer=get_initializer(config.initializer_range),
activation="linear",
name="out_proj",
)
self.config = config
def call(self, features, training=False):
        # Take the vector at the first position of features (the <s> token, equivalent to [CLS]) as the input x
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        # Apply dropout during training to reduce overfitting
        x = self.dropout(x, training=training)
        # Transform x with the dense layer and its tanh activation
        x = self.dense(x)
        # Apply dropout again during training
        x = self.dropout(x, training=training)
        # Project the processed vector through out_proj to obtain the final classification output
x = self.out_proj(x)
return x
def build(self, input_shape=None):
if self.built:
return
self.built = True
        # If the dense layer is defined, build its internal weights
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                # Build the dense weights with input shape [None, None, hidden_size]
                self.dense.build([None, None, self.config.hidden_size])
        # If the out_proj layer is defined, build its internal weights
        if getattr(self, "out_proj", None) is not None:
            with tf.name_scope(self.out_proj.name):
                # Build the out_proj weights with input shape [None, None, hidden_size]
self.out_proj.build([None, None, self.config.hidden_size])
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.
Args:
        input_ids: integer tensor containing the input symbol sequence
        padding_idx: index of the padding symbol
        past_key_values_length: length of the past key values, used to offset the incremental indices
    Returns:
        tf.Tensor: tensor of position ids in which non-padding symbols are replaced by their position numbers
    """
    # Build a mask marking the positions that are not padding symbols
    mask = tf.cast(input_ids != padding_idx, tf.int64)
    # Compute incremental indices for each position, skipping padding; numbering starts at padding_idx + 1
    incremental_indices = (tf.cumsum(mask, axis=1) + past_key_values_length) * mask
    # Add padding_idx to the incremental indices to obtain the final position-id tensor
return incremental_indices + padding_idx
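A quick check of `create_position_ids_from_input_ids` above: padding positions keep `padding_idx`, and real tokens are numbered from `padding_idx + 1` onward.
import tensorflow as tf

input_ids = tf.constant([[5, 8, 9, 1, 1]])   # 1 is the padding index in this toy batch
print(create_position_ids_from_input_ids(input_ids, padding_idx=1).numpy())
# [[2 3 4 1 1]]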