Transformers Source Code Analysis (64)
.\models\led\tokenization_led.py
"""LED 的分词类。"""
import json
import os
from functools import lru_cache
from typing import Dict, List, Optional, Tuple, Union
import regex as re
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...tokenization_utils_base import BatchEncoding, EncodedInput
from ...utils import PaddingStrategy, logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/vocab.json",
},
"merges_file": {
"allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/merges.txt",
},
"tokenizer_file": {
"allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/tokenizer.json",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"allenai/led-base-16384": 16384,
}
@lru_cache()
def bytes_to_unicode():
"""
返回 utf-8 字节列表和 Unicode 字符串的映射。避免映射到空白字符或控制字符,以免引起 bpe 代码错误。
可逆的 bpe 代码适用于 Unicode 字符串。这意味着如果要避免 UNK(未知)符号,词汇表中需要大量的 Unicode 字符。
当数据集达到约 100 亿个标记时,您需要大约 5000 个 Unicode 字符以获得良好的覆盖率。
这在普通的 32K bpe 词汇表中占有相当大的比例。为了避免这种情况,我们需要 utf-8 字节和 Unicode 字符串之间的查找表。
"""
bs = (
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
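A quick way to see what this mapping does (a minimal sketch; `bytes_to_unicode` is the function defined above): all 256 byte values are covered, printable ASCII maps to itself, and whitespace/control bytes are shifted into the U+0100 range, which is why a leading space shows up as `Ġ` in GPT-2-style vocabularies.

```python
byte_encoder = bytes_to_unicode()
assert len(byte_encoder) == 256        # every byte value gets a printable character
assert byte_encoder[ord("A")] == "A"   # printable ASCII maps to itself
assert byte_encoder[ord(" ")] == "Ġ"   # the space byte (0x20) is remapped to a visible character
```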
def get_pairs(word):
"""
返回单词中的符号对集合。
单词表示为符号元组(符号是长度可变的字符串)。
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
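For example, splitting the word "hello" into characters yields the following adjacent pairs (a minimal sketch using the helper above; sets are unordered, so the print order may differ):

```python
print(get_pairs(tuple("hello")))
# e.g. {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}
```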
class LEDTokenizer(PreTrainedTokenizer):
"""
Constructs a LED tokenizer, which is similar to the RoBERTa tokenizer, using byte-level Byte-Pair-Encoding.
This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
be encoded differently whether it is at the beginning of the sentence (without space) or not:
```
>>> from transformers import LEDTokenizer
>>> tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
>>> tokenizer("Hello world")["input_ids"]
[0, 31414, 232, 2]
>>> tokenizer(" Hello world")["input_ids"]
[0, 20920, 232, 2]
```
You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
<Tip>
When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).
</Tip>
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
this superclass for more information regarding those methods.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file,
merges_file,
errors="replace",
bos_token="<s>",
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
add_prefix_space=False,
**kwargs,
):
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
with open(merges_file, encoding="utf-8") as merges_handle:
bpe_merges = merges_handle.read().split("\n")[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
self.add_prefix_space = add_prefix_space
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
super().__init__(
errors=errors,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
add_prefix_space=add_prefix_space,
**kwargs,
)
@property
def vocab_size(self):
return len(self.encoder)
def get_vocab(self):
return dict(self.encoder, **self.added_tokens_encoder)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
except ValueError:
new_word.extend(word[i:])
break
else:
new_word.extend(word[i:j])
i = j
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
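The key step in the loop above is picking the bigram with the lowest merge rank, i.e. the merge learned earliest during BPE training. A standalone toy illustration of that selection with a hypothetical three-entry merge table (not the real LED vocabulary):

```python
bpe_ranks = {("h", "e"): 0, ("l", "l"): 1, ("he", "ll"): 2}  # hypothetical ranks
word = ("h", "e", "l", "l", "o")
pairs = {("h", "e"), ("e", "l"), ("l", "l"), ("l", "o")}
best = min(pairs, key=lambda pair: bpe_ranks.get(pair, float("inf")))
print(best)  # ('h', 'e') -- the pair with the lowest rank (earliest-learned merge) wins
```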
def _tokenize(self, text):
"""Tokenize a string."""
bpe_tokens = []
for token in re.findall(self.pat, text):
token = "".join(
self.byte_encoder[b] for b in token.encode("utf-8")
)
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
return bpe_tokens
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
return self.encoder.get(token, self.encoder.get(self.unk_token))
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.decoder.get(index)
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
text = "".join(tokens)
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
return text
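Decoding is the exact inverse of `_tokenize`'s byte encoding: each mapped character is turned back into its byte, and the byte string is decoded as UTF-8. A round-trip sketch using `bytes_to_unicode` from earlier in this file:

```python
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}
token_text = "".join(byte_encoder[b] for b in " Hello".encode("utf-8"))
print(token_text)                                                        # 'ĠHello'
print(bytearray([byte_decoder[c] for c in token_text]).decode("utf-8"))  # ' Hello'
```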
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
merge_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
)
with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write("#version: 0.2\n")
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning(
f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!"
)
index = token_index
writer.write(" ".join(bpe_tokens) + "\n")
index += 1
return vocab_file, merge_file
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequences for sequence classification tasks by adding special
tokens. A LED sequence has the following format:
- single sequence: `<s> X </s>`
- pair of sequences: `<s> A </s></s> B </s>`
Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of input IDs with the appropriate special tokens.
if token_ids_1 is None:
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
cls = [self.cls_token_id]
sep = [self.sep_token_id]
return cls + token_ids_0 + sep + sep + token_ids_1 + sep
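With the LED/BART convention of `<s>` = 0 and `</s>` = 2 (visible in the docstring example near the top of this file), the layout looks like this; the content ids here are made up for illustration:

```python
cls_id, sep_id = 0, 2             # <s> and </s> ids in the LED vocabulary
ids_a, ids_b = [100, 101], [200]  # hypothetical content token ids
print([cls_id] + ids_a + [sep_id])                             # [0, 100, 101, 2]            -> <s> A </s>
print([cls_id] + ids_a + [sep_id, sep_id] + ids_b + [sep_id])  # [0, 100, 101, 2, 2, 200, 2] -> <s> A </s></s> B </s>
```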
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)
if token_ids_1 is None:
return [1] + ([0] * len(token_ids_0)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
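For instance, for a pair with two and one content tokens respectively, the mask marks the leading `<s>`, the double `</s></s>` separator, and the final `</s>` (a small arithmetic sketch of the branch above, with hypothetical lengths):

```python
len_a, len_b = 2, 1  # hypothetical sequence lengths
print([1] + [0] * len_a + [1, 1] + [0] * len_b + [1])  # [1, 0, 0, 1, 1, 0, 1]
```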
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. LED does not
make use of token type ids, therefore a list of zeros is returned.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of zeros.
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
"""
Prepare text for tokenization by adding a space prefix if specified and not already present.
Args:
text (str): The input text to be tokenized.
is_split_into_words (bool, optional): Whether the text is already split into words.
**kwargs: Additional keyword arguments.
Returns:
tuple: A tuple containing the modified text and remaining keyword arguments.
"""
add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
text = " " + text
return (text, kwargs)
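A standalone sketch of that prefix-space rule (the helper name is hypothetical; the condition mirrors `prepare_for_tokenization` above): a space is prepended only when requested and only when the text does not already start with whitespace.

```python
def maybe_add_prefix_space(text: str, add_prefix_space: bool) -> str:
    # Mirrors the condition in prepare_for_tokenization above
    if add_prefix_space and len(text) > 0 and not text[0].isspace():
        text = " " + text
    return text

print(repr(maybe_add_prefix_space("Hello world", True)))      # ' Hello world'
print(repr(maybe_add_prefix_space(" already spaced", True)))  # ' already spaced' (unchanged)
```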
def _pad(
self,
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
encoded_inputs = super()._pad(
encoded_inputs=encoded_inputs,
max_length=max_length,
padding_strategy=padding_strategy,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
if return_attention_mask is None:
return_attention_mask = "attention_mask" in self.model_input_names
if return_attention_mask and "global_attention_mask" in encoded_inputs:
required_input = encoded_inputs[self.model_input_names[0]]
needs_to_be_padded = len(encoded_inputs["global_attention_mask"]) != len(required_input)
if needs_to_be_padded:
difference = len(required_input) - len(encoded_inputs["global_attention_mask"])
if self.padding_side == "right":
encoded_inputs["global_attention_mask"] = (
encoded_inputs["global_attention_mask"] + [-1] * difference
)
elif self.padding_side == "left":
encoded_inputs["global_attention_mask"] = [-1] * difference + encoded_inputs[
"global_attention_mask"
]
else:
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
return encoded_inputs
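A standalone sketch of the `global_attention_mask` padding above, with made-up lengths. `-1` is used for the padded positions because `0` already means "local attention" in LED, so it cannot double as a padding value:

```python
input_ids = [0, 31414, 232, 2, 1, 1]  # already padded to length 6 (1 = <pad> in LED)
global_attention_mask = [1, 0, 0, 0]  # only covers the unpadded tokens so far
difference = len(input_ids) - len(global_attention_mask)
global_attention_mask = global_attention_mask + [-1] * difference  # padding_side == "right"
print(global_attention_mask)  # [1, 0, 0, 0, -1, -1]
```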
.\models\led\tokenization_led_fast.py
import json
from typing import Dict, List, Optional, Tuple, Union
from tokenizers import pre_tokenizers, processors
from ...tokenization_utils_base import AddedToken, BatchEncoding, EncodedInput
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import PaddingStrategy, logging
from .tokenization_led import LEDTokenizer
logger = logging.get_logger(__name__)
class LEDTokenizerFast(PreTrainedTokenizerFast):
r"""
Construct a "fast" LED tokenizer (backed by HuggingFace's *tokenizers* library), derived from the GPT-2 tokenizer,
using byte-level Byte-Pair-Encoding.
This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
be encoded differently whether it is at the beginning of the sentence (without space) or not:
```
>>> from transformers import LEDTokenizerFast
>>> tokenizer = LEDTokenizerFast.from_pretrained("allenai/led-base-16384")
>>> tokenizer("Hello world")["input_ids"]
[0, 31414, 232, 2]
>>> tokenizer(" Hello world")["input_ids"]
[0, 20920, 232, 2]
```
You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
<Tip>
When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.
</Tip>
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods.
Args:
vocab_file (`str`):
Path to the vocabulary file.
merges_file (`str`):
Path to the merges file.
errors (`str`, *optional*, defaults to `"replace"`):
Paradigm to follow when decoding bytes to UTF-8. See
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
bos_token (`str`, *optional*, defaults to `"<s>"`):
The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
<Tip>
When building a sequence using special tokens, this is not the token that is used for the beginning of
sequence. The token used is the `cls_token`.
</Tip>
eos_token (`str`, *optional*, defaults to `"</s>"`):
The end of sequence token.
<Tip>
When building a sequence using special tokens, this is not the token that is used for the end of sequence.
The token used is the `sep_token`.
</Tip>
sep_token (`str`, *optional*, defaults to `"</s>"`):
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
sequence classification or for a text and a question for question answering. It is also used as the last
token of a sequence built with special tokens.
cls_token (`str`, *optional*, defaults to `"<s>"`):
The classifier token which is used when doing sequence classification (classification of the whole sequence
instead of per-token classification). It is the first token of the sequence when built with special tokens.
unk_token (`str`, *optional*, defaults to `"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
pad_token (`str`, *optional*, defaults to `"<pad>"`):
The token used for padding, for example when batching sequences of different lengths.
mask_token (`str`, *optional*, defaults to `"<mask>"`):
The token used for masking values. This is the token used when training this model with masked language
modeling. This is the token which the model will try to predict.
add_prefix_space (`bool`, *optional*, defaults to `False`):
Whether or not to add an initial space to the input. This allows to treat the leading word just as any
other word. (The LED tokenizer detects the beginning of words by the preceding space).
trim_offsets (`bool`, *optional*, defaults to `True`):
Whether the post processing step should trim offsets to avoid including whitespaces.
"""
# Class-level constants for the vocabulary files
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
# The maximum model input sizes equal the pretrained positional embedding sizes
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
# The corresponding slow tokenizer class is LEDTokenizer
slow_tokenizer_class = LEDTokenizer
# Names of the model inputs: input_ids and attention_mask
model_input_names = ["input_ids", "attention_mask"]
# The following is copied from transformers.models.bart.tokenization_bart_fast.BartTokenizerFast.__init__
# Initialization method
def __init__(
self,
vocab_file=None,
merges_file=None,
tokenizer_file=None,
errors="replace",
bos_token="<s>",
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
add_prefix_space=False,
trim_offsets=True,
**kwargs,
):
# If `mask_token` is a string, wrap it in an AddedToken flagged as special so it is handled as the special MASK token
mask_token = (
AddedToken(mask_token, lstrip=True, normalized=True, special=True)
if isinstance(mask_token, str)
else mask_token
)
# Call the parent class initializer to set up the LEDTokenizerFast object
super().__init__(
vocab_file,
merges_file,
tokenizer_file=tokenizer_file,
errors=errors,
bos_token=bos_token,
eos_token=eos_token,
sep_token=sep_token,
cls_token=cls_token,
unk_token=unk_token,
pad_token=pad_token,
mask_token=mask_token,
add_prefix_space=add_prefix_space,
trim_offsets=trim_offsets,
**kwargs,
)
# Read the current pre-tokenizer state and check whether its `add_prefix_space` attribute needs updating
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
# The pre-tokenizer's `add_prefix_space` does not match the requested setting, so rebuild it with the updated state
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
pre_tok_state["add_prefix_space"] = add_prefix_space
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
self.add_prefix_space = add_prefix_space
# Inspect the post-processor state and turn the `sep` and `cls` entries into tuples so they are compatible with LED's `post_processor`
tokenizer_component = "post_processor"
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
if tokenizer_component_instance:
state = json.loads(tokenizer_component_instance.__getstate__())
if "sep" in state:
state["sep"] = tuple(state["sep"])
if "cls" in state:
state["cls"] = tuple(state["cls"])
changes_to_apply = False
# Check whether the post-processor's `add_prefix_space` and `trim_offsets` attributes need updating
if state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
state["add_prefix_space"] = add_prefix_space
changes_to_apply = True
if state.get("trim_offsets", trim_offsets) != trim_offsets:
state["trim_offsets"] = trim_offsets
changes_to_apply = True
# If changes are needed, build a new post-processor instance and attach it to the backend tokenizer
if changes_to_apply:
component_class = getattr(processors, state.pop("type"))
new_value = component_class(**state)
setattr(self.backend_tokenizer, tokenizer_component, new_value)
@property
def mask_token(self) -> str:
"""
`str`: The mask token, used when training this model with masked language modeling. Logs an error if it has
not been set.
LED tokenizers have a special mask token to be usable in the fill-mask pipeline. The mask token will greedily
include the space before *<mask>*.
"""
# If the mask token has not been set, log an error and return None
if self._mask_token is None:
if self.verbose:
logger.error("Using mask_token, but it is not set yet.")
return None
# Return the string form of the mask token
return str(self._mask_token)
@mask_token.setter
def mask_token(self, value):
"""
设置掩码标记的默认行为,使其在之前包含空格。
这是为了与所有先前使用的基于 LED 的模型保持向后兼容所必需的。
"""
# 如果值是字符串类型,则创建 AddedToken 对象,并设置 lstrip=True,rstrip=False,使掩码标记行为类似普通词
value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value
self._mask_token = value
# Copied from transformers.models.bart.tokenization_bart_fast.BartTokenizerFast._batch_encode_plus
def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
is_split_into_words = kwargs.get("is_split_into_words", False)
# Raise a ValueError if the inputs are pre-tokenized but no prefix space is added
if is_split_into_words and not self.add_prefix_space:
raise ValueError(
f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
"to use it with pretokenized inputs."
)
# Delegate the batched encoding to the parent class's _batch_encode_plus
return super()._batch_encode_plus(*args, **kwargs)
# Copied from transformers.models.bart.tokenization_bart_fast.BartTokenizerFast._encode_plus
def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
is_split_into_words = kwargs.get("is_split_into_words", False)
# Raise a ValueError if the input is pre-tokenized but no prefix space is added
if is_split_into_words and not self.add_prefix_space:
raise ValueError(
f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
"to use it with pretokenized inputs."
)
# Delegate the encoding to the parent class's _encode_plus
return super()._encode_plus(*args, **kwargs)
# Copied from transformers.models.bart.tokenization_bart_fast.BartTokenizerFast.save_vocabulary
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
# Save the vocabulary to the target directory via the backend tokenizer's model.save
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
return tuple(files)
# Copied from transformers.models.bart.tokenization_bart_fast.BartTokenizerFast.build_inputs_with_special_tokens
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
# Build inputs with special tokens: prepend the beginning-of-sequence token and append the end-of-sequence token
output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
if token_ids_1 is None:
return output
# If there is a second sequence, append an end-of-sequence token, then the second sequence and its end-of-sequence token
return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
# Copied from BART -> LED; builds token type IDs from the input sequences
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
创建用于序列对分类任务的掩码。LED 不使用token类型ID,因此返回一个由零组成的列表。
Args:
token_ids_0 (`List[int]`):
第一个序列的ID列表。
token_ids_1 (`List[int]`, *optional*):
第二个序列的ID列表,用于序列对。
Returns:
`List[int]`: 全零列表。
"""
sep = [self.sep_token_id] # 分隔符的token ID列表
cls = [self.cls_token_id] # 类别标记的token ID列表
if token_ids_1 is None:
# 如果只有一个输入序列,则返回一个由零填充的列表
return len(cls + token_ids_0 + sep) * [0]
# 如果有两个输入序列,则返回一个由零填充的列表,包括两个分隔符
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
# Copied from transformers.models.led.tokenization_led.LEDTokenizer._pad
def _pad(
self,
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
encoded_inputs = super()._pad(
encoded_inputs=encoded_inputs,
max_length=max_length,
padding_strategy=padding_strategy,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
# The call above pads the inputs via `super()._pad` and returns the padded encoding dict
# Load from model defaults
if return_attention_mask is None:
return_attention_mask = "attention_mask" in self.model_input_names
# That is, when `return_attention_mask` is None, fall back to whether "attention_mask" is among the model input names
if return_attention_mask and "global_attention_mask" in encoded_inputs:
required_input = encoded_inputs[self.model_input_names[0]]
# `global_attention_mask` need to have the same length as other (sequential) inputs.
needs_to_be_padded = len(encoded_inputs["global_attention_mask"]) != len(required_input)
# The check above compares the length of "global_attention_mask" against the first model input to decide whether it needs padding
if needs_to_be_padded:
difference = len(required_input) - len(encoded_inputs["global_attention_mask"])
if self.padding_side == "right":
# Use `-1` since `0` in `global_attention_mask` means `local attention` instead of `not to attend`
encoded_inputs["global_attention_mask"] = (
encoded_inputs["global_attention_mask"] + [-1] * difference
)
elif self.padding_side == "left":
encoded_inputs["global_attention_mask"] = [-1] * difference + encoded_inputs[
"global_attention_mask"
]
else:
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
# The branch above computes the length difference and extends `global_attention_mask` with `-1` on the chosen side so it matches the other inputs
return encoded_inputs
# Finally, return the padded encoding dict `encoded_inputs`
.\models\led\__init__.py
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)
_import_structure = {
"configuration_led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig"],
"tokenization_led": ["LEDTokenizer"],
}
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_led_fast"] = ["LEDTokenizerFast"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_led"] = [
"LED_PRETRAINED_MODEL_ARCHIVE_LIST",
"LEDForConditionalGeneration",
"LEDForQuestionAnswering",
"LEDForSequenceClassification",
"LEDModel",
"LEDPreTrainedModel",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_led"] = ["TFLEDForConditionalGeneration", "TFLEDModel", "TFLEDPreTrainedModel"]
if TYPE_CHECKING:
from .configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig
from .tokenization_led import LEDTokenizer
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_led_fast import LEDTokenizerFast
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_led import (
LED_PRETRAINED_MODEL_ARCHIVE_LIST,
LEDForConditionalGeneration,
LEDForQuestionAnswering,
LEDForSequenceClassification,
LEDModel,
LEDPreTrainedModel,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_led import TFLEDForConditionalGeneration, TFLEDModel, TFLEDPreTrainedModel
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\levit\configuration_levit.py
"""
LeViT model configuration
"""
from collections import OrderedDict
from typing import Mapping
from packaging import version
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
logger = logging.get_logger(__name__)
LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/levit-128S": "https://huggingface.co/facebook/levit-128S/resolve/main/config.json",
}
class LevitConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`LevitModel`]. It is used to instantiate a LeViT
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the LeViT
[facebook/levit-128S](https://huggingface.co/facebook/levit-128S) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
"""
model_type = "levit"
def __init__(
self,
image_size=224,
num_channels=3,
kernel_size=3,
stride=2,
padding=1,
patch_size=16,
hidden_sizes=[128, 256, 384],
num_attention_heads=[4, 8, 12],
depths=[4, 4, 4],
key_dim=[16, 16, 16],
drop_path_rate=0,
mlp_ratio=[2, 2, 2],
attention_ratio=[2, 2, 2],
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
self.image_size = image_size
self.num_channels = num_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.hidden_sizes = hidden_sizes
self.num_attention_heads = num_attention_heads
self.depths = depths
self.key_dim = key_dim
self.drop_path_rate = drop_path_rate
self.patch_size = patch_size
self.attention_ratio = attention_ratio
self.mlp_ratio = mlp_ratio
self.initializer_range = initializer_range
self.down_ops = [
["Subsample", key_dim[0], hidden_sizes[0] // key_dim[0], 4, 2, 2],
["Subsample", key_dim[0], hidden_sizes[1] // key_dim[0], 4, 2, 2],
]
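The `down_ops` entries built above encode the two `Subsample` stages; with the default `key_dim` and `hidden_sizes` they evaluate as follows (a purely illustrative sketch of the arithmetic):

```python
key_dim = [16, 16, 16]
hidden_sizes = [128, 256, 384]
down_ops = [
    ["Subsample", key_dim[0], hidden_sizes[0] // key_dim[0], 4, 2, 2],
    ["Subsample", key_dim[0], hidden_sizes[1] // key_dim[0], 4, 2, 2],
]
print(down_ops)  # [['Subsample', 16, 8, 4, 2, 2], ['Subsample', 16, 16, 4, 2, 2]]
```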
class LevitOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
]
)
@property
def atol_for_validation(self) -> float:
return 1e-4
.\models\levit\convert_levit_timm_to_pytorch.py
"""从 timm 转换 LeViT 检查点。"""
import argparse
import json
from collections import OrderedDict
from functools import partial
from pathlib import Path
import timm
import torch
from huggingface_hub import hf_hub_download
from transformers import LevitConfig, LevitForImageClassificationWithTeacher, LevitImageProcessor
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger()
def convert_weight_and_push(
hidden_sizes: int, name: str, config: LevitConfig, save_directory: Path, push_to_hub: bool = True
):
print(f"Converting {name}...")
with torch.no_grad():
if hidden_sizes == 128:
if name[-1] == "S":
from_model = timm.create_model("levit_128s", pretrained=True)
else:
from_model = timm.create_model("levit_128", pretrained=True)
elif hidden_sizes == 192:
from_model = timm.create_model("levit_192", pretrained=True)
elif hidden_sizes == 256:
from_model = timm.create_model("levit_256", pretrained=True)
elif hidden_sizes == 384:
from_model = timm.create_model("levit_384", pretrained=True)
from_model.eval()
our_model = LevitForImageClassificationWithTeacher(config).eval()
huggingface_weights = OrderedDict()
weights = from_model.state_dict()
og_keys = list(from_model.state_dict().keys())
new_keys = list(our_model.state_dict().keys())
print(len(og_keys), len(new_keys))
for i in range(len(og_keys)):
huggingface_weights[new_keys[i]] = weights[og_keys[i]]
our_model.load_state_dict(huggingface_weights)
x = torch.randn((2, 3, 224, 224))
out1 = from_model(x)
out2 = our_model(x).logits
assert torch.allclose(out1, out2), "The model logits don't match the original one."
checkpoint_name = name
print(checkpoint_name)
if push_to_hub:
our_model.save_pretrained(save_directory / checkpoint_name)
image_processor = LevitImageProcessor()
image_processor.save_pretrained(save_directory / checkpoint_name)
print(f"Pushed {checkpoint_name}")
def convert_weights_and_push(save_directory: Path, model_name: str = None, push_to_hub: bool = True):
filename = "imagenet-1k-id2label.json"
num_labels = 1000
expected_shape = (1, num_labels)
repo_id = "huggingface/label-files"
num_labels = num_labels
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
id2label = id2label
label2id = {v: k for k, v in id2label.items()}
ImageNetPreTrainedConfig = partial(LevitConfig, num_labels=num_labels, id2label=id2label, label2id=label2id)
names_to_hidden_sizes = {
"levit-128S": 128,
"levit-128": 128,
"levit-192": 192,
"levit-256": 256,
"levit-384": 384,
}
names_to_config = {
"levit-128S": ImageNetPreTrainedConfig(
hidden_sizes=[128, 256, 384],
num_attention_heads=[4, 6, 8],
depths=[2, 3, 4],
key_dim=[16, 16, 16],
drop_path_rate=0,
),
"levit-128": ImageNetPreTrainedConfig(
hidden_sizes=[128, 256, 384],
num_attention_heads=[4, 8, 12],
depths=[4, 4, 4],
key_dim=[16, 16, 16],
drop_path_rate=0,
),
"levit-192": ImageNetPreTrainedConfig(
hidden_sizes=[192, 288, 384],
num_attention_heads=[3, 5, 6],
depths=[4, 4, 4],
key_dim=[32, 32, 32],
drop_path_rate=0,
),
"levit-256": ImageNetPreTrainedConfig(
hidden_sizes=[256, 384, 512],
num_attention_heads=[4, 6, 8],
depths=[4, 4, 4],
key_dim=[32, 32, 32],
drop_path_rate=0,
),
"levit-384": ImageNetPreTrainedConfig(
hidden_sizes=[384, 512, 768],
num_attention_heads=[6, 9, 12],
depths=[4, 4, 4],
key_dim=[32, 32, 32],
drop_path_rate=0.1,
),
}
if model_name:
convert_weight_and_push(
names_to_hidden_sizes[model_name], model_name, names_to_config[model_name], save_directory, push_to_hub
)
else:
for model_name, config in names_to_config.items():
convert_weight_and_push(names_to_hidden_sizes[model_name], model_name, config, save_directory, push_to_hub)
return config, expected_shape
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
default=None,
type=str,
help="The name of the model you wish to convert, it must be one of the supported Levit* architecture,",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default="levit-dump-folder/",
type=Path,
required=False,
help="Path to the output PyTorch model directory.",
)
parser.add_argument("--push_to_hub", action="store_true", help="Push model and image processor to the hub")
parser.add_argument(
"--no-push_to_hub",
dest="push_to_hub",
action="store_false",
help="Do not push model and image processor to the hub",
)
args = parser.parse_args()
pytorch_dump_folder_path: Path = args.pytorch_dump_folder_path
pytorch_dump_folder_path.mkdir(exist_ok=True, parents=True)
convert_weights_and_push(pytorch_dump_folder_path, args.model_name, args.push_to_hub)
.\models\levit\feature_extraction_levit.py
"""LeViT 的特征提取器类。"""
import warnings
from ...utils import logging
from .image_processing_levit import LevitImageProcessor
logger = logging.get_logger(__name__)
class LevitFeatureExtractor(LevitImageProcessor):
def __init__(self, *args, **kwargs) -> None:
warnings.warn(
"The class LevitFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
" use LevitImageProcessor instead.",
FutureWarning,
)
super().__init__(*args, **kwargs)
.\models\levit\image_processing_levit.py
from typing import Dict, Iterable, Optional, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
get_resize_output_image_size,
resize,
to_channel_dimension_format,
)
from ...image_utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
validate_kwargs,
validate_preprocess_arguments,
)
from ...utils import TensorType, logging
logger = logging.get_logger(__name__)
class LevitImageProcessor(BaseImageProcessor):
r"""
Constructs a LeViT image processor.
"""
Args:
do_resize (`bool`, *optional*, defaults to `True`):
是否调整输入图像的最短边至 int(256/224 * size),可以在 `preprocess` 方法中的 `do_resize` 参数中覆盖。
size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 224}`):
调整后的输出图像尺寸。如果 `size` 是一个包含 "width" 和 "height" 键的字典,图像将被调整至 `(size["height"], size["width"])`。如果 `size` 是一个包含 "shortest_edge" 键的字典,最短边的值 `c` 将被重新缩放为 `int(c * (256/224))`。图像的较小边将被匹配到此值,例如,如果 height > width,则图像将被缩放至 `(size["shortest_edge"] * height / width, size["shortest_edge"])`。可以在 `preprocess` 方法中的 `size` 参数中覆盖。
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
如果调整图像大小,使用的重采样滤波器。可以在 `preprocess` 方法中的 `resample` 参数中覆盖。
do_center_crop (`bool`, *optional*, defaults to `True`):
是否对输入图像进行中心裁剪至 `(crop_size["height"], crop_size["width"])`。可以在 `preprocess` 方法中的 `do_center_crop` 参数中覆盖。
crop_size (`Dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
`center_crop` 后的期望图像尺寸。可以在 `preprocess` 方法中的 `crop_size` 参数中覆盖。
do_rescale (`bool`, *optional*, defaults to `True`):
控制是否按指定的比例因子 `rescale_factor` 重新缩放图像。可以在 `preprocess` 方法中的 `do_rescale` 参数中覆盖。
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
如果重新缩放图像,要使用的缩放因子。可以在 `preprocess` 方法中的 `rescale_factor` 参数中覆盖。
do_normalize (`bool`, *optional*, defaults to `True`):
控制是否对图像进行归一化。可以在 `preprocess` 方法中的 `do_normalize` 参数中覆盖。
image_mean (`List[int]`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
如果归一化图像,要使用的均值。这是一个浮点数或与图像通道数相同长度的浮点数列表。可以在 `preprocess` 方法中的 `image_mean` 参数中覆盖。
image_std (`List[int]`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
如果归一化图像,要使用的标准差。这是一个浮点数或与图像通道数相同长度的浮点数列表。可以在 `preprocess` 方法中的 `image_std` 参数中覆盖。
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_center_crop: bool = True,
crop_size: Dict[str, int] = None,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, Iterable[float]]] = IMAGENET_DEFAULT_MEAN,
image_std: Optional[Union[float, Iterable[float]]] = IMAGENET_DEFAULT_STD,
**kwargs,
) -> None:
# Call the parent class initializer
super().__init__(**kwargs)
# If size is None, default the shortest edge to 224
size = size if size is not None else {"shortest_edge": 224}
# Build the size dict from the given size argument, making sure it does not default to a square
size = get_size_dict(size, default_to_square=False)
# If crop_size is None, default both height and width to 224
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
# Build the crop size dict from the given crop_size argument
crop_size = get_size_dict(crop_size, param_name="crop_size")
# Initialize the instance attributes
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
# List of valid processor keywords, covering the image processing parameters and the data format parameters
self._valid_processor_keys = [
"images",
"do_resize",
"size",
"resample",
"do_center_crop",
"crop_size",
"do_rescale",
"rescale_factor",
"do_normalize",
"image_mean",
"image_std",
"return_tensors",
"data_format",
"input_data_format",
]
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize an image.
If size is a dict with keys "width" and "height", the image will be resized to `(size["height"],
size["width"])`.
If size is a dict with key "shortest_edge", the shortest edge value `c` is rescaled to `int(c * (256/224))`.
The smaller edge of the image will be matched to this value, i.e., if height > width, then the image will be
rescaled to `(size["shortest_edge"] * height / width, size["shortest_edge"])`.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image after resizing. If size is a dict with keys "width" and "height", the image
will be resized to (height, width). If size is a dict with key "shortest_edge", the shortest edge value
`c` is rescaled to int(`c` * (256/224)). The smaller edge of the image will be matched to this value
i.e, if height > width, then image will be rescaled to (size * height / width, size).
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
# Determine the target size dictionary based on input size parameters
size_dict = get_size_dict(size, default_to_square=False)
# Check if 'shortest_edge' is specified in size dictionary
if "shortest_edge" in size:
# Calculate the length of the shortest edge based on the scaling factor (256/224)
shortest_edge = int((256 / 224) * size["shortest_edge"])
# Determine the output size after resizing based on the calculated shortest edge
output_size = get_resize_output_image_size(
image, size=shortest_edge, default_to_square=False, input_data_format=input_data_format
)
# Update size_dict to reflect the height and width after resizing
size_dict = {"height": output_size[0], "width": output_size[1]}
# Ensure that the size_dict contains both 'height' and 'width' keys
if "height" not in size_dict or "width" not in size_dict:
# Raise an error if the size_dict does not have the required keys
raise ValueError(
f"Size dict must have keys 'height' and 'width' or 'shortest_edge'. Got {size_dict.keys()}"
)
# Resize the image to the specified dimensions using the resize function
return resize(
image,
size=(size_dict["height"], size_dict["width"]),
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
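Worked example of the shortest-edge scaling above: with the default `size={"shortest_edge": 224}`, the shorter side of the image is first resized to `int(224 * 256 / 224) = 256`, leaving room for the subsequent 224x224 center crop:

```python
size = {"shortest_edge": 224}
shortest_edge = int((256 / 224) * size["shortest_edge"])
print(shortest_edge)  # 256
```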
# Preprocessing entry point for image data
def preprocess(
self,
images: ImageInput,  # a single image or a list of images
do_resize: Optional[bool] = None,  # whether to resize; defaults to the instance setting
size: Optional[Dict[str, int]] = None,  # target size dict for resizing, with width and height
resample: PILImageResampling = None,  # resampling filter used when resizing
do_center_crop: Optional[bool] = None,  # whether to center crop
crop_size: Optional[Dict[str, int]] = None,  # target size dict for center cropping, with width and height
do_rescale: Optional[bool] = None,  # whether to rescale pixel values
rescale_factor: Optional[float] = None,  # factor used when rescaling
do_normalize: Optional[bool] = None,  # whether to normalize
image_mean: Optional[Union[float, Iterable[float]]] = None,  # mean used for normalization
image_std: Optional[Union[float, Iterable[float]]] = None,  # standard deviation used for normalization
return_tensors: Optional[TensorType] = None,  # tensor type of the returned batch
data_format: ChannelDimension = ChannelDimension.FIRST,  # channel format of the output, channels-first by default
input_data_format: Optional[Union[str, ChannelDimension]] = None,  # channel format of the input
**kwargs,  # any additional keyword arguments
):
.\models\levit\modeling_levit.py
import itertools
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...modeling_outputs import (
BaseModelOutputWithNoAttention,
BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention,
ModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_levit import LevitConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LevitConfig"
_CHECKPOINT_FOR_DOC = "facebook/levit-128S"
_EXPECTED_OUTPUT_SHAPE = [1, 16, 384]
_IMAGE_CLASS_CHECKPOINT = "facebook/levit-128S"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/levit-128S",
]
@dataclass
class LevitForImageClassificationWithTeacherOutput(ModelOutput):
"""
[`LevitForImageClassificationWithTeacher`] 的输出类型。
"""
logits: torch.FloatTensor = None
cls_logits: torch.FloatTensor = None
distillation_logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class LevitConvEmbeddings(nn.Module):
"""
LeViT Conv Embeddings with Batch Norm, used in the initial patch embedding layer.
"""
def __init__(
self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, groups=1, bn_weight_init=1
):
super().__init__()
self.convolution = nn.Conv2d(
in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, groups=groups, bias=False
)
self.batch_norm = nn.BatchNorm2d(out_channels)
def forward(self, embeddings):
embeddings = self.convolution(embeddings)
embeddings = self.batch_norm(embeddings)
return embeddings
class LevitPatchEmbeddings(nn.Module):
"""
LeViT patch embeddings, for final embeddings to be passed to transformer blocks. It consists of multiple
`LevitConvEmbeddings`.
"""
def __init__(self, config):
super().__init__()
self.embedding_layer_1 = LevitConvEmbeddings(
config.num_channels, config.hidden_sizes[0] // 8, config.kernel_size, config.stride, config.padding
)
self.activation_layer_1 = nn.Hardswish()
self.embedding_layer_2 = LevitConvEmbeddings(
config.hidden_sizes[0] // 8, config.hidden_sizes[0] // 4, config.kernel_size, config.stride, config.padding
)
self.activation_layer_2 = nn.Hardswish()
self.embedding_layer_3 = LevitConvEmbeddings(
config.hidden_sizes[0] // 4, config.hidden_sizes[0] // 2, config.kernel_size, config.stride, config.padding
)
self.activation_layer_3 = nn.Hardswish()
self.embedding_layer_4 = LevitConvEmbeddings(
config.hidden_sizes[0] // 2, config.hidden_sizes[0], config.kernel_size, config.stride, config.padding
)
self.num_channels = config.num_channels
def forward(self, pixel_values):
num_channels = pixel_values.shape[1]
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
embeddings = self.embedding_layer_1(pixel_values)
embeddings = self.activation_layer_1(embeddings)
embeddings = self.embedding_layer_2(embeddings)
embeddings = self.activation_layer_2(embeddings)
embeddings = self.embedding_layer_3(embeddings)
embeddings = self.activation_layer_3(embeddings)
embeddings = self.embedding_layer_4(embeddings)
return embeddings.flatten(2).transpose(1, 2)
class MLPLayerWithBN(nn.Module):
"""
MLP layer with Batch Norm, used in the transformer blocks.
"""
def __init__(self, input_dim, output_dim, bn_weight_init=1):
super().__init__()
self.linear = nn.Linear(in_features=input_dim, out_features=output_dim, bias=False)
self.batch_norm = nn.BatchNorm1d(output_dim)
def forward(self, hidden_state):
hidden_state = self.linear(hidden_state)
hidden_state = self.batch_norm(hidden_state.flatten(0, 1)).reshape_as(hidden_state)
return hidden_state
class LevitSubsample(nn.Module):
def __init__(self, stride, resolution):
super().__init__()
self.stride = stride
self.resolution = resolution
def forward(self, hidden_state):
batch_size, _, channels = hidden_state.shape
hidden_state = hidden_state.view(batch_size, self.resolution, self.resolution, channels)[
:, :: self.stride, :: self.stride
].reshape(batch_size, -1, channels)
return hidden_state
class LevitAttention(nn.Module):
def __init__(self, hidden_sizes, key_dim, num_attention_heads, attention_ratio, resolution):
super().__init__()
self.num_attention_heads = num_attention_heads
self.scale = key_dim**-0.5
self.key_dim = key_dim
self.attention_ratio = attention_ratio
self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads * 2
self.out_dim_projection = attention_ratio * key_dim * num_attention_heads
self.queries_keys_values = MLPLayerWithBN(hidden_sizes, self.out_dim_keys_values)
self.activation = nn.Hardswish()
self.projection = MLPLayerWithBN(self.out_dim_projection, hidden_sizes, bn_weight_init=0)
points = list(itertools.product(range(resolution), range(resolution)))
len_points = len(points)
attention_offsets, indices = {}, []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
indices.append(attention_offsets[offset])
self.attention_bias_cache = {}
self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
self.register_buffer(
"attention_bias_idxs", torch.LongTensor(indices).view(len_points, len_points), persistent=False
)
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and self.attention_bias_cache:
self.attention_bias_cache = {}
def get_attention_biases(self, device):
if self.training:
return self.attention_biases[:, self.attention_bias_idxs]
else:
device_key = str(device)
if device_key not in self.attention_bias_cache:
self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
return self.attention_bias_cache[device_key]
def forward(self, hidden_state):
batch_size, seq_length, _ = hidden_state.shape
queries_keys_values = self.queries_keys_values(hidden_state)
query, key, value = queries_keys_values.view(batch_size, seq_length, self.num_attention_heads, -1).split(
[self.key_dim, self.key_dim, self.attention_ratio * self.key_dim], dim=3
)
query = query.permute(0, 2, 1, 3)
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
attention = query @ key.transpose(-2, -1) * self.scale + self.get_attention_biases(hidden_state.device)
attention = attention.softmax(dim=-1)
hidden_state = (attention @ value).transpose(1, 2).reshape(batch_size, seq_length, self.out_dim_projection)
hidden_state = self.projection(self.activation(hidden_state))
return hidden_state
class LevitAttentionSubsample(nn.Module):
def __init__(
self,
input_dim,
output_dim,
key_dim,
num_attention_heads,
attention_ratio,
stride,
resolution_in,
resolution_out,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.scale = key_dim**-0.5
self.key_dim = key_dim
self.attention_ratio = attention_ratio
self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads
self.out_dim_projection = attention_ratio * key_dim * num_attention_heads
self.resolution_out = resolution_out
self.keys_values = MLPLayerWithBN(input_dim, self.out_dim_keys_values)
self.queries_subsample = LevitSubsample(stride, resolution_in)
self.queries = MLPLayerWithBN(input_dim, key_dim * num_attention_heads)
self.activation = nn.Hardswish()
self.projection = MLPLayerWithBN(self.out_dim_projection, output_dim)
self.attention_bias_cache = {}
points = list(itertools.product(range(resolution_in), range(resolution_in)))
points_ = list(itertools.product(range(resolution_out), range(resolution_out)))
len_points, len_points_ = len(points), len(points_)
attention_offsets, indices = {}, []
for p1 in points_:
for p2 in points:
size = 1
offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2), abs(p1[1] * stride - p2[1] + (size - 1) / 2))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
indices.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
self.register_buffer(
"attention_bias_idxs", torch.LongTensor(indices).view(len_points_, len_points), persistent=False
)
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and self.attention_bias_cache:
self.attention_bias_cache = {}
def get_attention_biases(self, device):
if self.training:
return self.attention_biases[:, self.attention_bias_idxs]
else:
device_key = str(device)
if device_key not in self.attention_bias_cache:
self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
return self.attention_bias_cache[device_key]
def forward(self, hidden_state):
batch_size, seq_length, _ = hidden_state.shape
key, value = (
self.keys_values(hidden_state)
.view(batch_size, seq_length, self.num_attention_heads, -1)
.split([self.key_dim, self.attention_ratio * self.key_dim], dim=3)
)
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
query = self.queries(self.queries_subsample(hidden_state))
query = query.view(batch_size, self.resolution_out**2, self.num_attention_heads, self.key_dim).permute(
0, 2, 1, 3
)
attention = query @ key.transpose(-2, -1) * self.scale + self.get_attention_biases(hidden_state.device)
attention = attention.softmax(dim=-1)
hidden_state = (attention @ value).transpose(1, 2).reshape(batch_size, -1, self.out_dim_projection)
hidden_state = self.projection(self.activation(hidden_state))
return hidden_state
class LevitMLPLayer(nn.Module):
"""
MLP Layer with `2X` expansion in contrast to ViT with `4X`.
"""
def __init__(self, input_dim, hidden_dim):
super().__init__()
self.linear_up = MLPLayerWithBN(input_dim, hidden_dim)
self.activation = nn.Hardswish()
self.linear_down = MLPLayerWithBN(hidden_dim, input_dim)
def forward(self, hidden_state):
hidden_state = self.linear_up(hidden_state)
hidden_state = self.activation(hidden_state)
hidden_state = self.linear_down(hidden_state)
return hidden_state
class LevitResidualLayer(nn.Module):
"""
Residual Block for LeViT
"""
def __init__(self, module, drop_rate):
super().__init__()
self.module = module
self.drop_rate = drop_rate
def forward(self, hidden_state):
if self.training and self.drop_rate > 0:
rnd = torch.rand(hidden_state.size(0), 1, 1, device=hidden_state.device)
rnd = rnd.ge_(self.drop_rate).div(1 - self.drop_rate).detach()
hidden_state = hidden_state + self.module(hidden_state) * rnd
return hidden_state
else:
hidden_state = hidden_state + self.module(hidden_state)
return hidden_state
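The `rnd` trick above is stochastic depth (drop path): a per-sample Bernoulli mask either keeps or drops the whole residual branch, and surviving branches are scaled by `1 / (1 - drop_rate)` so the expected output matches evaluation mode. A minimal standalone sketch with made-up shapes:

```python
import torch

drop_rate = 0.5
hidden_state = torch.ones(4, 2, 3)           # (batch, seq_len, channels), toy values
branch_out = torch.ones_like(hidden_state)   # stands in for module(hidden_state)
rnd = torch.rand(hidden_state.size(0), 1, 1)
rnd = rnd.ge_(drop_rate).div(1 - drop_rate)  # per-sample 0.0 or 2.0 here
out = hidden_state + branch_out * rnd        # dropped samples keep only the skip path
print(out[:, 0, 0])                          # each entry is 1.0 (dropped) or 3.0 (kept)
```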
class LevitStage(nn.Module):
"""
LeViT Stage consisting of `LevitMLPLayer` and `LevitAttention` layers.
"""
def __init__(
self,
config,
idx,
hidden_sizes,
key_dim,
depths,
num_attention_heads,
attention_ratio,
mlp_ratio,
down_ops,
resolution_in,
):
super().__init__()
self.layers = []
self.config = config
self.resolution_in = resolution_in
for _ in range(depths):
self.layers.append(
LevitResidualLayer(
LevitAttention(hidden_sizes, key_dim, num_attention_heads, attention_ratio, resolution_in),
self.config.drop_path_rate,
)
)
if mlp_ratio > 0:
hidden_dim = hidden_sizes * mlp_ratio
self.layers.append(
LevitResidualLayer(LevitMLPLayer(hidden_sizes, hidden_dim), self.config.drop_path_rate)
)
if down_ops[0] == "Subsample":
self.resolution_out = (self.resolution_in - 1) // down_ops[5] + 1
self.layers.append(
LevitAttentionSubsample(
*self.config.hidden_sizes[idx : idx + 2],
key_dim=down_ops[1],
num_attention_heads=down_ops[2],
attention_ratio=down_ops[3],
stride=down_ops[5],
resolution_in=resolution_in,
resolution_out=self.resolution_out,
)
)
self.resolution_in = self.resolution_out
if down_ops[4] > 0:
hidden_dim = self.config.hidden_sizes[idx + 1] * down_ops[4]
self.layers.append(
LevitResidualLayer(
LevitMLPLayer(self.config.hidden_sizes[idx + 1], hidden_dim), self.config.drop_path_rate
)
)
self.layers = nn.ModuleList(self.layers)
def get_resolution(self):
return self.resolution_in
def forward(self, hidden_state):
for layer in self.layers:
hidden_state = layer(hidden_state)
return hidden_state
class LevitEncoder(nn.Module):
"""
LeViT Encoder consisting of multiple `LevitStage` stages.
"""
def __init__(self, config):
super().__init__()
self.config = config
resolution = self.config.image_size // self.config.patch_size
self.stages = []
self.config.down_ops.append([""])
for stage_idx in range(len(config.depths)):
stage = LevitStage(
config,
stage_idx,
config.hidden_sizes[stage_idx],
config.key_dim[stage_idx],
config.depths[stage_idx],
config.num_attention_heads[stage_idx],
config.attention_ratio[stage_idx],
config.mlp_ratio[stage_idx],
config.down_ops[stage_idx],
resolution,
)
resolution = stage.get_resolution()
self.stages.append(stage)
self.stages = nn.ModuleList(self.stages)
def forward(self, hidden_state, output_hidden_states=False, return_dict=True):
all_hidden_states = () if output_hidden_states else None
for stage in self.stages:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,)
hidden_state = stage(hidden_state)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,)
if not return_dict:
return tuple(v for v in [hidden_state, all_hidden_states] if v is not None)
return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=all_hidden_states)
class LevitClassificationLayer(nn.Module):
"""
LeViT Classification Layer
"""
def __init__(self, input_dim, output_dim):
super().__init__()
self.batch_norm = nn.BatchNorm1d(input_dim)
self.linear = nn.Linear(input_dim, output_dim)
def forward(self, hidden_state):
hidden_state = self.batch_norm(hidden_state)
logits = self.linear(hidden_state)
return logits
class LevitPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = LevitConfig
base_model_prefix = "levit"
main_input_name = "pixel_values"
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
@add_start_docstrings(
"The bare Levit model outputting raw features without any specific head on top.",
LEVIT_START_DOCSTRING,
)
class LevitModel(LevitPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config
self.patch_embeddings = LevitPatchEmbeddings(config)
self.encoder = LevitEncoder(config)
self.post_init()
@add_start_docstrings_to_model_forward(LEVIT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPoolingAndNoAttention,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
pixel_values: torch.FloatTensor = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
embeddings = self.patch_embeddings(pixel_values)
encoder_outputs = self.encoder(
embeddings,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
pooled_output = last_hidden_state.mean(dim=1)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndNoAttention(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
)
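Since the pooled output is just the mean of `last_hidden_state` over the token dimension, `LevitModel` can be used directly as a feature extractor. A minimal usage sketch (the `facebook/levit-128S` checkpoint and the local image path are assumptions for illustration):
```python
import torch
from PIL import Image
from transformers import LevitImageProcessor, LevitModel

processor = LevitImageProcessor.from_pretrained("facebook/levit-128S")
model = LevitModel.from_pretrained("facebook/levit-128S")

image = Image.open("example.jpg")  # hypothetical local image
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
# last_hidden_state: (batch, seq_len, hidden_size); pooler_output is its mean over seq_len
print(outputs.last_hidden_state.shape, outputs.pooler_output.shape)
```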
@add_start_docstrings(
"""
Levit Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
ImageNet.
""",
LEVIT_START_DOCSTRING,
)
class LevitForImageClassification(LevitPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config
self.num_labels = config.num_labels
self.levit = LevitModel(config)
self.classifier = (
LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
if config.num_labels > 0
else torch.nn.Identity()
)
self.post_init()
@add_start_docstrings_to_model_forward(LEVIT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=ImageClassifierOutputWithNoAttention,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(
self,
pixel_values: torch.FloatTensor = None,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.levit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
sequence_output = outputs[0]
sequence_output = sequence_output.mean(1)
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return ImageClassifierOutputWithNoAttention(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
)
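The `problem_type` branch above picks the loss from `num_labels` and the label dtype. A standalone sketch of the same decision rule (the helper name is illustrative, not part of the model code):
```python
import torch

def infer_problem_type(num_labels: int, labels: torch.Tensor) -> str:
    # Mirrors the branch in LevitForImageClassification.forward
    if num_labels == 1:
        return "regression"                   # -> MSELoss
    if labels.dtype in (torch.long, torch.int):
        return "single_label_classification"  # -> CrossEntropyLoss
    return "multi_label_classification"       # -> BCEWithLogitsLoss

print(infer_problem_type(1, torch.tensor([0.7])))              # regression
print(infer_problem_type(3, torch.tensor([2])))                # single_label_classification
print(infer_problem_type(3, torch.tensor([[1.0, 0.0, 1.0]])))  # multi_label_classification
```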
@add_start_docstrings(
"""
LeViT Model transformer with image classification heads on top (a linear layer on top of the final hidden state and
a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet.
.. warning::
This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
supported.
""",
LEVIT_START_DOCSTRING,
)
class LevitForImageClassificationWithTeacher(LevitPreTrainedModel):
"""
A LeViT-based image classifier with two classification heads (a linear layer on top of the final hidden
state and another linear layer on top of the final hidden state of the distillation token), suitable for
datasets such as ImageNet.
Note: this model supports inference only; fine-tuning with distillation (i.e. with a teacher) is not yet supported.
Attributes:
config (LevitConfig): Model configuration object holding all model hyperparameters.
num_labels (int): Number of labels in the classification task.
levit (LevitModel): The underlying LeViT model instance.
"""
def __init__(self, config):
"""
Initializes a new LevitForImageClassificationWithTeacher instance.
Args:
config (LevitConfig): Model configuration object holding all model hyperparameters.
"""
super().__init__(config)
self.config = config
self.num_labels = config.num_labels
self.levit = LevitModel(config)
self.classifier = (
LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
if config.num_labels > 0
else torch.nn.Identity()
)
self.classifier_distill = (
LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
if config.num_labels > 0
else torch.nn.Identity()
)
self.post_init()
@add_start_docstrings_to_model_forward(LEVIT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=LevitForImageClassificationWithTeacherOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(
self,
pixel_values: torch.FloatTensor = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, LevitForImageClassificationWithTeacherOutput]:
"""
Forward pass, running the model for inference.
Args:
pixel_values (torch.FloatTensor, optional): Input pixel values. Defaults to None.
output_hidden_states (bool, optional): Whether to return hidden states. Defaults to None.
return_dict (bool, optional): Whether to return the output as a dict. Defaults to None.
Returns:
Union[Tuple, LevitForImageClassificationWithTeacherOutput]: A plain tuple when return_dict is False,
otherwise a LevitForImageClassificationWithTeacherOutput object.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.levit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
sequence_output = outputs[0]
sequence_output = sequence_output.mean(1)
cls_logits, distill_logits = self.classifier(sequence_output), self.classifier_distill(sequence_output)
logits = (cls_logits + distill_logits) / 2
if not return_dict:
output = (logits, cls_logits, distill_logits) + outputs[2:]
return output
return LevitForImageClassificationWithTeacherOutput(
logits=logits,
cls_logits=cls_logits,
distillation_logits=distill_logits,
hidden_states=outputs.hidden_states,
)
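At inference time the two heads are simply averaged, so the teacher variant is called exactly like the plain classifier. A minimal sketch (the checkpoint name and image path are assumptions for illustration):
```python
import torch
from PIL import Image
from transformers import LevitImageProcessor, LevitForImageClassificationWithTeacher

processor = LevitImageProcessor.from_pretrained("facebook/levit-128S")
model = LevitForImageClassificationWithTeacher.from_pretrained("facebook/levit-128S")

inputs = processor(images=Image.open("example.jpg"), return_tensors="pt")  # hypothetical image
with torch.no_grad():
    out = model(**inputs)
# out.logits == (out.cls_logits + out.distillation_logits) / 2
print(model.config.id2label[out.logits.argmax(-1).item()])
```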
.\models\levit\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {"configuration_levit": ["LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LevitConfig", "LevitOnnxConfig"]}
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["feature_extraction_levit"] = ["LevitFeatureExtractor"]
_import_structure["image_processing_levit"] = ["LevitImageProcessor"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_levit"] = [
"LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
"LevitForImageClassification",
"LevitForImageClassificationWithTeacher",
"LevitModel",
"LevitPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_levit import LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, LevitConfig, LevitOnnxConfig
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .feature_extraction_levit import LevitFeatureExtractor
from .image_processing_levit import LevitImageProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_levit import (
LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
LevitForImageClassification,
LevitForImageClassificationWithTeacher,
LevitModel,
LevitPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\lilt\configuration_lilt.py
""" LiLT configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
LILT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"SCUT-DLVCLab/lilt-roberta-en-base": (
"https://huggingface.co/SCUT-DLVCLab/lilt-roberta-en-base/resolve/main/config.json"
),
}
class LiltConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`LiltModel`]. It is used to instantiate a LiLT
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the LiLT
[SCUT-DLVCLab/lilt-roberta-en-base](https://huggingface.co/SCUT-DLVCLab/lilt-roberta-en-base) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Examples:
```
>>> from transformers import LiltConfig, LiltModel
>>> # Initializing a LiLT SCUT-DLVCLab/lilt-roberta-en-base style configuration
>>> configuration = LiltConfig()
>>> # Randomly initializing a model from the SCUT-DLVCLab/lilt-roberta-en-base style configuration
>>> model = LiltModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "lilt"
def __init__(
self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
position_embedding_type="absolute",
classifier_dropout=None,
channel_shrink_ratio=4,
max_2d_position_embeddings=1024,
**kwargs,
):
"""
Initializes a new instance of LiltConfig with optional parameters to define the model architecture.
Parameters:
- vocab_size: The size of the vocabulary.
- hidden_size: The size of the hidden layers.
- num_hidden_layers: The number of hidden layers.
- num_attention_heads: The number of attention heads in the multi-head attention setups.
- intermediate_size: The size of the intermediate (i.e., feed-forward) layer in the transformer blocks.
- hidden_act: The activation function (e.g., "gelu").
- hidden_dropout_prob: The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
- attention_probs_dropout_prob: The dropout ratio for the attention probabilities.
- max_position_embeddings: The maximum length of the input sequences.
- type_vocab_size: The size of the token type vocab.
- initializer_range: The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- layer_norm_eps: The epsilon used by LayerNorm layers.
- pad_token_id: The ID of the padding token.
- position_embedding_type: The type of position embeddings.
- classifier_dropout: The dropout ratio for classifier.
- channel_shrink_ratio: The shrink ratio of channel.
- max_2d_position_embeddings: The maximum length of the 2D position embeddings.
- **kwargs: Additional keyword arguments.
"""
super().__init__(pad_token_id=pad_token_id, **kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self.classifier_dropout = classifier_dropout
self.channel_shrink_ratio = channel_shrink_ratio
self.max_2d_position_embeddings = max_2d_position_embeddings
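`channel_shrink_ratio` sets how much narrower the layout stream is than the text stream, and each of the box-coordinate embeddings used later takes `hidden_size // 6` channels. A quick arithmetic sketch over the defaults:
```python
from transformers import LiltConfig

config = LiltConfig()  # hidden_size=768, channel_shrink_ratio=4
print(config.hidden_size // config.channel_shrink_ratio)  # 192: width of the layout branch
print(config.hidden_size // 6)                            # 128: width of each x/y/h/w box embedding
```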
.\models\lilt\modeling_lilt.py
"""PyTorch LiLT model."""
import math
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_lilt import LiltConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LiltConfig"
LILT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"SCUT-DLVCLab/lilt-roberta-en-base",
]
class LiltTextEmbeddings(nn.Module):
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.padding_idx = config.pad_token_id
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
)
def forward(
self,
input_ids=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
):
if position_ids is None:
if input_ids is not None:
position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx).to(
input_ids.device
)
else:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings, position_ids
def create_position_ids_from_input_ids(self, input_ids, padding_idx):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1;
padding symbols are ignored. This is adapted from fairseq's `utils.make_positions`.
Args:
input_ids: torch.Tensor
padding_idx: int
Returns: torch.Tensor
"""
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask)) * mask
return incremental_indices.long() + padding_idx
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
"""
We are provided embeddings directly; we cannot infer which tokens are padding, so we simply generate
sequential position ids.
Args:
inputs_embeds: torch.Tensor
Returns: torch.Tensor
"""
input_shape = inputs_embeds.size()[:-1]
sequence_length = input_shape[1]
position_ids = torch.arange(
self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
)
return position_ids.unsqueeze(0).expand(input_shape)
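The padding-aware position ids can be checked in isolation. A standalone sketch of the same cumsum trick used by `create_position_ids_from_input_ids` above (the function name here is illustrative):
```python
import torch

def make_positions(input_ids: torch.Tensor, padding_idx: int) -> torch.Tensor:
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
    return incremental_indices.long() + padding_idx

ids = torch.tensor([[1, 5, 7, 1, 1]])      # 1 is the padding idx (RoBERTa convention)
print(make_positions(ids, padding_idx=1))  # tensor([[1, 2, 3, 1, 1]])
```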
class LiltLayoutEmbeddings(nn.Module):
def __init__(self, config):
super().__init__()
self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6)
self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6)
self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6)
self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6)
self.padding_idx = config.pad_token_id
self.box_position_embeddings = nn.Embedding(
config.max_position_embeddings,
config.hidden_size // config.channel_shrink_ratio,
padding_idx=self.padding_idx,
)
self.box_linear_embeddings = nn.Linear(
in_features=config.hidden_size, out_features=config.hidden_size // config.channel_shrink_ratio
)
self.LayerNorm = nn.LayerNorm(config.hidden_size // config.channel_shrink_ratio, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, bbox=None, position_ids=None):
try:
left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
except IndexError as e:
raise IndexError("The `bbox` coordinate values should be within 0-1000 range.") from e
h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1])
w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0])
spatial_position_embeddings = torch.cat(
[
left_position_embeddings,
upper_position_embeddings,
right_position_embeddings,
lower_position_embeddings,
h_position_embeddings,
w_position_embeddings,
],
dim=-1,
)
spatial_position_embeddings = self.box_linear_embeddings(spatial_position_embeddings)
box_position_embeddings = self.box_position_embeddings(position_ids)
spatial_position_embeddings = spatial_position_embeddings + box_position_embeddings
spatial_position_embeddings = self.LayerNorm(spatial_position_embeddings)
spatial_position_embeddings = self.dropout(spatial_position_embeddings)
return spatial_position_embeddings
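Each box coordinate (plus the derived height and width) gets its own `hidden_size // 6`-wide embedding; the concatenation is then projected down to the layout width. A quick shape check, assuming default config values:
```python
import torch
from transformers import LiltConfig
from transformers.models.lilt.modeling_lilt import LiltLayoutEmbeddings

config = LiltConfig()  # hidden_size=768, channel_shrink_ratio=4
layout_embeddings = LiltLayoutEmbeddings(config)
bbox = torch.tensor([[[10, 20, 110, 60]] * 8] * 2)  # (batch=2, seq_len=8, x0/y0/x1/y1), values in 0-1000
position_ids = torch.arange(8).unsqueeze(0).expand(2, -1)
out = layout_embeddings(bbox=bbox, position_ids=position_ids)
print(out.shape)  # torch.Size([2, 8, 192]) == hidden_size // channel_shrink_ratio
```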
class LiltSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.layout_query = nn.Linear(
config.hidden_size // config.channel_shrink_ratio, self.all_head_size // config.channel_shrink_ratio
)
self.layout_key = nn.Linear(
config.hidden_size // config.channel_shrink_ratio, self.all_head_size // config.channel_shrink_ratio
)
self.layout_value = nn.Linear(
config.hidden_size // config.channel_shrink_ratio, self.all_head_size // config.channel_shrink_ratio
)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.channel_shrink_ratio = config.channel_shrink_ratio
class LiltSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class LiltAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = LiltSelfAttention(config, position_embedding_type=position_embedding_type)
self.output = LiltSelfOutput(config)
self.pruned_heads = set()
ori_hidden_size = config.hidden_size
config.hidden_size = config.hidden_size // config.channel_shrink_ratio
self.layout_output = LiltSelfOutput(config)
config.hidden_size = ori_hidden_size
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
layout_inputs: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states,
layout_inputs,
attention_mask,
head_mask,
output_attentions,
)
attention_output = self.output(self_outputs[0][0], hidden_states)
layout_attention_output = self.layout_output(self_outputs[0][1], layout_inputs)
outputs = ((attention_output, layout_attention_output),) + self_outputs[1:]
return outputs
class LiltLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = LiltAttention(config)
self.intermediate = LiltIntermediate(config)
self.output = LiltOutput(config)
ori_hidden_size = config.hidden_size
ori_intermediate_size = config.intermediate_size
config.hidden_size = config.hidden_size // config.channel_shrink_ratio
config.intermediate_size = config.intermediate_size // config.channel_shrink_ratio
self.layout_intermediate = LiltIntermediate(config)
self.layout_output = LiltOutput(config)
config.hidden_size = ori_hidden_size
config.intermediate_size = ori_intermediate_size
def forward(
self,
hidden_states: torch.Tensor,
layout_inputs: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
self_attention_outputs = self.attention(
hidden_states,
layout_inputs,
attention_mask,
head_mask,
output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0][0]
layout_attention_output = self_attention_outputs[0][1]
outputs = self_attention_outputs[1:]
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
layout_layer_output = apply_chunking_to_forward(
self.layout_feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, layout_attention_output
)
outputs = ((layer_output, layout_layer_output),) + outputs
return outputs
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
def layout_feed_forward_chunk(self, attention_output):
intermediate_output = self.layout_intermediate(attention_output)
layer_output = self.layout_output(intermediate_output, attention_output)
return layer_output
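`apply_chunking_to_forward` trades peak memory for compute by splitting the sequence dimension into chunks and running the feed-forward on each slice. A simplified sketch of the idea (not the library's implementation):
```python
import torch

def chunked_forward(fn, chunk_size: int, seq_dim: int, x: torch.Tensor) -> torch.Tensor:
    if chunk_size == 0:
        return fn(x)  # a chunk size of 0 disables chunking
    return torch.cat([fn(c) for c in x.split(chunk_size, dim=seq_dim)], dim=seq_dim)

mlp = torch.nn.Linear(16, 16)
x = torch.randn(2, 64, 16)
assert torch.allclose(chunked_forward(mlp, 8, 1, x), mlp(x), atol=1e-6)
```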
class LiltEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([LiltLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
layout_inputs: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layout_inputs,
attention_mask,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states,
layout_inputs,
attention_mask,
layer_head_mask,
output_attentions,
)
hidden_states = layer_outputs[0][0]
layout_inputs = layer_outputs[0][1]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
all_hidden_states,
all_self_attentions,
]
if v is not None
)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class LiltPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class LiltPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = LiltConfig
base_model_prefix = "lilt"
supports_gradient_checkpointing = True
_no_split_modules = []
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
LILT_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`LiltConfig`]): Model configuration class with all the parameters of the
model. Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
LILT_INPUTS_DOCSTRING = r"""
"""
@add_start_docstrings(
"The bare LiLT Model transformer outputting raw hidden-states without any specific head on top.",
LILT_START_DOCSTRING,
)
class LiltModel(LiltPreTrainedModel):
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
self.embeddings = LiltTextEmbeddings(config)
self.layout_embeddings = LiltLayoutEmbeddings(config)
self.encoder = LiltEncoder(config)
self.pooler = LiltPooler(config) if add_pooling_layer else None
self.post_init()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(LILT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
bbox: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
"""
Performs the forward pass of the model.
Args:
input_ids (Optional[torch.Tensor], optional): Input tensors for the model.
bbox (Optional[torch.Tensor], optional): Bounding box tensors.
attention_mask (Optional[torch.Tensor], optional): Attention mask tensors.
token_type_ids (Optional[torch.Tensor], optional): Token type ID tensors.
position_ids (Optional[torch.Tensor], optional): Position ID tensors.
head_mask (Optional[torch.Tensor], optional): Head mask tensors.
inputs_embeds (Optional[torch.Tensor], optional): Embedded input tensors.
output_attentions (Optional[bool], optional): Whether to output attentions.
output_hidden_states (Optional[bool], optional): Whether to output hidden states.
return_dict (Optional[bool], optional): Whether to return as dictionary.
"""
pass
@add_start_docstrings(
"""
LiLT Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
output) e.g. for GLUE tasks.
""",
LILT_START_DOCSTRING,
)
class LiltForSequenceClassification(LiltPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config
self.lilt = LiltModel(config, add_pooling_layer=False)
self.classifier = LiltClassificationHead(config)
self.post_init()
@add_start_docstrings_to_model_forward(LILT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
bbox: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
@add_start_docstrings(
"""
Lilt Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
LILT_START_DOCSTRING,
)
class LiltForTokenClassification(LiltPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.lilt = LiltModel(config, add_pooling_layer=False)
classifier_dropout = (
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.post_init()
@add_start_docstrings_to_model_forward(LILT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
bbox: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
Returns:
If `return_dict=False`: a tuple containing `(logits, hidden_states, attentions)`, where `logits` is the
tensor of predicted class scores; `loss` is prepended when it is not None.
If `return_dict=True`: a `TokenClassifierOutput` object with `loss`, `logits`, `hidden_states` and
`attentions` attributes.
Examples:
```
>>> from transformers import AutoTokenizer, AutoModelForTokenClassification
>>> from datasets import load_dataset
>>> tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
>>> model = AutoModelForTokenClassification.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
>>> example = dataset[0]
>>> words = example["tokens"]
>>> boxes = example["bboxes"]
>>> encoding = tokenizer(words, boxes=boxes, return_tensors="pt")
>>> outputs = model(**encoding)
>>> predicted_class_indices = outputs.logits.argmax(-1)
```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.lilt(
input_ids,
bbox=bbox,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
labels = labels.to(logits.device)
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class LiltClassificationHead(nn.Module):
"""Head module for sentence-level classification tasks."""
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
classifier_dropout = (
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.dropout = nn.Dropout(classifier_dropout)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, features, **kwargs):
x = features[:, 0, :]
x = self.dropout(x)
x = self.dense(x)
x = torch.tanh(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
@add_start_docstrings(
"""
Lilt Model with a span classification head on top for extractive question-answering tasks (a linear layer on
top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
LILT_START_DOCSTRING,
)
class LiltForQuestionAnswering(LiltPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.lilt = LiltModel(config, add_pooling_layer=False)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
self.post_init()
@add_start_docstrings_to_model_forward(LILT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
bbox: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
.\models\lilt\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_lilt": ["LILT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LiltConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_lilt"] = [
"LILT_PRETRAINED_MODEL_ARCHIVE_LIST",
"LiltForQuestionAnswering",
"LiltForSequenceClassification",
"LiltForTokenClassification",
"LiltModel",
"LiltPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_lilt import LILT_PRETRAINED_CONFIG_ARCHIVE_MAP, LiltConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_lilt import (
LILT_PRETRAINED_MODEL_ARCHIVE_LIST,
LiltForQuestionAnswering,
LiltForSequenceClassification,
LiltForTokenClassification,
LiltModel,
LiltPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\llama\configuration_llama.py
""" LLaMA model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class LlamaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the LLaMA-7B.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
```
>>> from transformers import LlamaModel, LlamaConfig
>>> # Initializing a LLaMA llama-7b style configuration
>>> configuration = LlamaConfig()
>>> # Initializing a model from the llama-7b style configuration
>>> model = LlamaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "llama"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
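Given the validation above, a well-formed `rope_scaling` is a two-field dict with a `linear` or `dynamic` type and a float factor greater than 1. A short sketch of extending the context via linear scaling:
```python
from transformers import LlamaConfig

config = LlamaConfig(
    max_position_embeddings=2048,
    rope_scaling={"type": "linear", "factor": 4.0},  # factor must be a float > 1
)
print(config.rope_scaling)

# An invalid spec fails at construction time:
try:
    LlamaConfig(rope_scaling={"type": "exotic", "factor": 4.0})
except ValueError as err:
    print(err)
```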
.\models\llama\convert_llama_weights_to_hf.py
import argparse
import gc
import json
import os
import shutil
import warnings
import torch
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
try:
from transformers import LlamaTokenizerFast
except ImportError as e:
warnings.warn(e)
warnings.warn(
"The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
)
LlamaTokenizerFast = None
"""
Sample usage:
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
--input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
"""
NUM_SHARDS = {
"7B": 1,
"7Bf": 1,
"13B": 2,
"13Bf": 2,
"34B": 4,
"30B": 4,
"65B": 8,
"70B": 8,
"70Bf": 8,
}
def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
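For the 7B defaults this rounds `8n/3` up to the next multiple of 256. A quick check of the arithmetic (plain numbers, independent of the function above):
```python
n, multiple_of = 4096, 256
raw = int(1 * int(8 * n / 3))                             # 10922
rounded = multiple_of * ((raw + multiple_of - 1) // multiple_of)
print(raw, rounded)                                       # 10922 11008 (the LLaMA-7B intermediate size)
```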
def read_json(path):
with open(path, "r") as f:
return json.load(f)
def write_json(text, path):
with open(path, "w") as f:
json.dump(text, f)
def write_model(
model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True, llama_version=1
):
if not os.path.isfile(os.path.join(input_base_path, "params.json")):
input_base_path = os.path.join(input_base_path, model_size)
os.makedirs(model_path, exist_ok=True)
tmp_model_path = os.path.join(model_path, "tmp")
os.makedirs(tmp_model_path, exist_ok=True)
params = read_json(os.path.join(input_base_path, "params.json"))
num_shards = NUM_SHARDS[model_size]
params = params.get("model", params)
n_layers = params["n_layers"]
n_heads = params["n_heads"]
n_heads_per_shard = n_heads // num_shards
dim = params["dim"]
dims_per_head = dim // n_heads
base = params.get("rope_theta", 10000.0)
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
if base > 10000.0:
max_position_embeddings = 16384
else:
if llama_version == 1:
max_position_embeddings = 2048
elif llama_version == 2:
max_position_embeddings = 4096
else:
raise NotImplementedError(
f"Version {llama_version} of llama is not supported yet. "
"Current supported versions of llama are [1, 2]."
)
tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
if tokenizer_path is not None:
tokenizer = tokenizer_class(tokenizer_path)
tokenizer.save_pretrained(model_path)
vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000
if params.get("n_kv_heads", None) is not None:
num_key_value_heads = params["n_kv_heads"]
num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
key_value_dim = dim // num_key_value_heads
else:
num_key_value_heads = n_heads
num_local_key_value_heads = n_heads_per_shard
key_value_dim = dim
def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
if num_shards == 1:
loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
else:
loaded = [
torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
for i in range(num_shards)
]
param_count = 0
index_dict = {"weight_map": {}}
filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
if num_shards == 1:
state_dict = {
"model.embed_tokens.weight": loaded["tok_embeddings.weight"],
"model.norm.weight": loaded["norm.weight"],
"lm_head.weight": loaded["output.weight"],
}
else:
state_dict = {
"model.norm.weight": loaded[0]["norm.weight"],
"model.embed_tokens.weight": torch.cat(
[loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
),
"lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
}
for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(tmp_model_path, filename))
index_dict["metadata"] = {"total_size": param_count * 2}
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1
multiple_of = params["multiple_of"] if "multiple_of" in params else 256
config = LlamaConfig(
hidden_size=dim,
intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
num_attention_heads=params["n_heads"],
num_hidden_layers=params["n_layers"],
rms_norm_eps=params["norm_eps"],
num_key_value_heads=num_key_value_heads,
vocab_size=vocab_size,
rope_theta=base,
max_position_embeddings=max_position_embeddings,
)
config.save_pretrained(tmp_model_path)
del state_dict
del loaded
gc.collect()
print("Loading the checkpoint in a Llama model.")
model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
del model.config._name_or_path
model.config.torch_dtype = torch.float16
print("Saving in the Transformers format.")
model.save_pretrained(model_path, safe_serialization=safe_serialization)
shutil.rmtree(tmp_model_path)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir",
help="Location of LLaMA weights, which contains tokenizer.model and model folders",
)
parser.add_argument(
"--model_size",
choices=["7B", "7Bf", "13B", "13Bf", "30B", "34B", "65B", "70B", "70Bf", "tokenizer_only"],
help="'f' models correspond to the finetuned versions, and are specific to the Llama2 official release. For more details on Llama2, checkout the original repo: https://huggingface.co/meta-llama",
)
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
parser.add_argument(
"--llama_version",
choices=[1, 2],
default=1,
type=int,
help="Version of the Llama model to convert. Currently supports Llama1 and Llama2. Controls the context size",
)
args = parser.parse_args()
spm_path = os.path.join(args.input_dir, "tokenizer.model")
if args.model_size != "tokenizer_only":
write_model(
model_path=args.output_dir,
input_base_path=args.input_dir,
model_size=args.model_size,
safe_serialization=args.safe_serialization,
tokenizer_path=spm_path,
llama_version=args.llama_version,
)
else:
write_tokenizer(args.output_dir, spm_path)
if __name__ == "__main__":
main()
.\models\llama\modeling_flax_llama.py
from functools import partial
from typing import Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_llama import LlamaConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlamaConfig"
_CHECKPOINT_FOR_DOC = "afmck/testing-llama-tiny"
_REAL_CHECKPOINT_FOR_DOC = "openlm-research/open_llama_3b_v2"
LLAMA_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a Flax Linen
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`LlamaConfig`]): Model configuration class with all the parameters of the model. Initializing
with a config file does not load the weights associated with the model, only the configuration. Check out
the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16`, or
`jax.numpy.bfloat16`.
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
[`~FlaxPreTrainedModel.to_bf16`].
"""
# Build the sinusoidal position-encoding table that maps position indices to sine/cosine vectors
def create_sinusoidal_positions(num_pos, dim):
# Compute the inverse frequencies for the sinusoidal encoding
inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
# Outer product of the position indices and the inverse frequencies, in float32
freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32")
# Concatenate the frequency matrix with itself along the last dimension
emb = np.concatenate((freqs, freqs), axis=-1)
# Apply sine to `emb`, apply cosine to `emb`, and concatenate the two results
out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1)
# Keep the first num_pos entries of the last axis and return as a JAX array
return jnp.array(out[:, :, :num_pos])
# Rotate the second half of the input tensor's hidden dimensions to the front
def rotate_half(tensor):
# Split the tensor at the midpoint of its last dimension and re-concatenate (negating the second half)
rotate_half_tensor = jnp.concatenate(
(-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1
)
return rotate_half_tensor
# Apply the rotary position embeddings to the input tensor
def apply_rotary_pos_emb(tensor, sin_pos, cos_pos):
# Multiply the input by the cosine encoding and add the rotated-half input times the sine encoding
return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos)
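The rotation is easy to sanity-check on a tiny vector; a standalone sketch mirroring `rotate_half` above:
```python
import jax.numpy as jnp

def rotate_half_demo(t):
    return jnp.concatenate((-t[..., t.shape[-1] // 2 :], t[..., : t.shape[-1] // 2]), axis=-1)

x = jnp.array([0.0, 1.0, 2.0, 3.0])
print(rotate_half_demo(x))  # [-2. -3.  0.  1.]
```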
# FlaxLlamaRMSNorm: RMS layer normalization as a Flax module
class FlaxLlamaRMSNorm(nn.Module):
# Module configuration
config: LlamaConfig
dtype: jnp.dtype = jnp.float32
# setup() is called at instantiation to initialize weights and other parameters
def setup(self):
# epsilon is the small constant used by RMS normalization
self.epsilon = self.config.rms_norm_eps
# Initialize the weight vector with shape (hidden_size,), filled with ones
self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.hidden_size)
# Apply RMS normalization to the hidden states
def __call__(self, hidden_states):
# Cast the hidden states to float32 for the variance computation
variance = jnp.asarray(hidden_states, dtype=jnp.float32)
# Square the hidden states element-wise
variance = jnp.power(variance, 2)
# Take the mean over the last dimension, keeping it for broadcasting
variance = variance.mean(-1, keepdims=True)
# RMS normalization: divide the hidden states by sqrt(mean square + epsilon)
hidden_states = hidden_states / jnp.sqrt(variance + self.epsilon)
# Return the weighted hidden states, cast back to the module dtype
return self.weight * jnp.asarray(hidden_states, dtype=self.dtype)
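Unlike LayerNorm, RMSNorm rescales by the root mean square without centering. A quick numeric check of the formula used above (standalone function, illustrative):
```python
import jax.numpy as jnp

def rms_norm(h, weight, eps=1e-6):
    variance = jnp.power(h.astype(jnp.float32), 2).mean(-1, keepdims=True)
    return weight * (h / jnp.sqrt(variance + eps))

h = jnp.array([[3.0, 4.0]])      # RMS = sqrt((9 + 16) / 2) ~= 3.5355
print(rms_norm(h, jnp.ones(2)))  # ~[[0.8485, 1.1314]]
```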
# FlaxLlamaRotaryEmbedding: rotary position embeddings as a Flax module
class FlaxLlamaRotaryEmbedding(nn.Module):
# Module configuration
config: LlamaConfig
dtype: jnp.dtype = jnp.float32
# setup() precomputes the position-encoding table at instantiation
def setup(self):
# Dimension of each attention head
head_dim = self.config.hidden_size // self.config.num_attention_heads
# Precompute the sine/cosine position encodings
self.sincos = create_sinusoidal_positions(self.config.max_position_embeddings, head_dim)
# Apply the rotary encodings to the key and query at the given position ids
def __call__(self, key, query, position_ids):
# Look up the sine/cosine encodings for the requested position ids
sincos = self.sincos[position_ids]
sin_pos, cos_pos = jnp.split(sincos, 2, axis=-1)
# Apply the rotary position embeddings to key and query
key = apply_rotary_pos_emb(key, sin_pos, cos_pos)
query = apply_rotary_pos_emb(query, sin_pos, cos_pos)
# Cast key and query to the module dtype
key = jnp.asarray(key, dtype=self.dtype)
query = jnp.asarray(query, dtype=self.dtype)
# Return the processed key and query
return key, query
# FlaxLlamaAttention: multi-head attention with rotary embeddings as a Flax module
class FlaxLlamaAttention(nn.Module):
# Module configuration
config: LlamaConfig
dtype: jnp.dtype = jnp.float32
causal: bool = True
is_cross_attention: bool = False
# setup() initializes the attention parameters at instantiation
def setup(self):
# Pull hyperparameters from the configuration
config = self.config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.attention_softmax_in_fp32 = self.dtype is not jnp.float32
# Dense-layer factory shared by the projections (bias controlled by the config)
dense = partial(
nn.Dense,
self.embed_dim,
use_bias=config.attention_bias,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
# Initialize the query, key, value and output projections
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
self.o_proj = dense()
# Build the causal mask
self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
# Build the rotary embedding layer
self.rotary_emb = FlaxLlamaRotaryEmbedding(config, dtype=self.dtype)
# Split the hidden states into multiple attention heads
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
def _merge_heads(self, hidden_states):
# Reshape the hidden states back to (batch_size, sequence_length, self.embed_dim)
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
@nn.compact
# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention._concatenate_to_cache
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slightly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py
"""
# Detect whether we are initializing by checking for existing cache data
is_initialized = self.has_variable("cache", "cached_key")
# If uninitialized, create a zero tensor with the same shape/dtype as key for cached_key
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
# If uninitialized, create a zero tensor with the same shape/dtype as value for cached_value
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
# If uninitialized, create cache_index starting at 0
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
# Read off the shape of the current cache tensors
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
# Update the key/value caches with a new 1d spatial slice
cur_index = cache_index.value
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
key = lax.dynamic_update_slice(cached_key.value, key, indices)
value = lax.dynamic_update_slice(cached_value.value, value, indices)
cached_key.value = key
cached_value.value = value
# Advance cache_index by the number of cache vectors just written
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
# Causal mask for the cache: the single query position should only attend to key positions that have
# already been generated and cached, not to the remaining zero elements
pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
)
# Combine pad_mask with attention_mask
attention_mask = combine_masks(pad_mask, attention_mask)
return key, value, attention_mask
def __call__(
self,
hidden_states,
attention_mask,
position_ids,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
):
# Project the hidden states to query, key and value vectors
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
# Split each projection into multiple heads
query = self._split_heads(query)
key = self._split_heads(key)
value = self._split_heads(value)
# Apply rotary position embeddings to the key and query vectors
key, query = self.rotary_emb(key, query, position_ids)
# Query and key sequence lengths
query_length, key_length = query.shape[1], key.shape[1]
# Build the causal mask
if self.has_variable("cache", "cached_key"):
# With a cache present, dynamically slice the causal mask using the cache index and max decoder length
mask_shift = self.variables["cache"]["cache_index"]
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
causal_mask = lax.dynamic_slice(
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
)
else:
# Otherwise take a static slice of the causal mask
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
# Broadcast the causal mask over the batch dimension
batch_size = hidden_states.shape[0]
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
# Broadcast the attention mask to the causal-mask shape and combine the two
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
attention_mask = combine_masks(attention_mask, causal_mask)
# Set up the dropout RNG if needed
dropout_rng = None
if not deterministic and self.config.attention_dropout > 0.0:
dropout_rng = self.make_rng("dropout")
# During fast autoregressive decoding we feed one position at a time and incrementally cache keys and values
if self.has_variable("cache", "cached_key") or init_cache:
key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
# Convert the boolean mask into a floating-point attention bias
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
)
# Standard dot-product attention
attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype
attn_weights = dot_product_attention_weights(
query,
key,
bias=attention_bias,
dropout_rng=dropout_rng,
dropout_rate=self.config.attention_dropout,
deterministic=deterministic,
dtype=attention_dtype,
)
# Cast the attention weights back if the softmax ran in fp32
if self.attention_softmax_in_fp32:
attn_weights = attn_weights.astype(self.dtype)
# Weight the values by the attention weights, merge the heads and apply the output projection
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
attn_output = self._merge_heads(attn_output)
attn_output = self.o_proj(attn_output)
# Return the attention output, plus the attention weights if requested
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs
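To make the cache mechanics above concrete, here is a minimal self-contained sketch (my own illustration, outside Flax's variable system) of how `lax.dynamic_update_slice` writes one decoding step into a preallocated buffer and how the padding mask is derived:

import jax.numpy as jnp
from jax import lax

max_length, num_heads, head_dim = 8, 2, 4
cached = jnp.zeros((1, max_length, num_heads, head_dim))  # (batch, max_length, heads, depth)
new_key = jnp.ones((1, 1, num_heads, head_dim))           # a single decoding step
cur_index = 3
# Same call pattern as `lax.dynamic_update_slice(cached_key.value, key, indices)` above
cached = lax.dynamic_update_slice(cached, new_key, (0, cur_index, 0, 0))
# The query may only attend to positions that have already been written
pad_mask = jnp.arange(max_length) < cur_index + 1
print(pad_mask)  # [ True  True  True  True False False False False]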
class FlaxLlamaMLP(nn.Module):
config: LlamaConfig  # model configuration
dtype: jnp.dtype = jnp.float32  # computation dtype, defaults to float32
def setup(self):
embed_dim = self.config.hidden_size
# Inner dimension: use `intermediate_size` from the config if set, otherwise default to 4x the embedding size
inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim
# Normal initializer with the configured `initializer_range`
kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
# Activation function selected by `hidden_act` via the ACT2FN mapping
self.act = ACT2FN[self.config.hidden_act]
# Gate, down and up projections, all bias-free Dense layers
self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
self.up_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
def __call__(self, hidden_states):
# Gated MLP: down_proj(up_proj(x) * act(gate_proj(x)))
up_proj_states = self.up_proj(hidden_states)
gate_states = self.act(self.gate_proj(hidden_states))
hidden_states = self.down_proj(up_proj_states * gate_states)
return hidden_states
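The `__call__` above is the gated (SwiGLU-style) MLP, `down_proj(up_proj(x) * act(gate_proj(x)))`. A minimal numeric sketch with hand-rolled weight matrices in place of `nn.Dense` (shapes and the `silu` activation are illustrative assumptions):

import jax
import jax.numpy as jnp

embed_dim, inner_dim = 4, 8
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
w_gate = jax.random.normal(k1, (embed_dim, inner_dim))
w_up = jax.random.normal(k2, (embed_dim, inner_dim))
w_down = jax.random.normal(k3, (inner_dim, embed_dim))
x = jax.random.normal(k4, (2, embed_dim))  # (batch, hidden)

# Same dataflow as FlaxLlamaMLP.__call__: down(up(x) * silu(gate(x)))
out = (x @ w_up * jax.nn.silu(x @ w_gate)) @ w_down
print(out.shape)  # (2, 4)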
class FlaxLlamaDecoderLayer(nn.Module):
config: LlamaConfig  # model configuration
dtype: jnp.dtype = jnp.float32  # computation dtype, defaults to float32
def setup(self):
# Pre-attention RMSNorm, self-attention, post-attention RMSNorm and MLP sub-modules
self.input_layernorm = FlaxLlamaRMSNorm(self.config, dtype=self.dtype)
self.self_attn = FlaxLlamaAttention(self.config, dtype=self.dtype)
self.post_attention_layernorm = FlaxLlamaRMSNorm(self.config, dtype=self.dtype)
self.mlp = FlaxLlamaMLP(self.config, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask=None,
position_ids=None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
):
# First residual branch: normalize, run self-attention, then add the residual back
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
outputs = self.self_attn(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
)
attn_output = outputs[0]
hidden_states = residual + attn_output
# Second residual branch: normalize, run the MLP, then add the residual back
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
# Return the final hidden states together with any extra outputs (e.g. attention weights)
return (hidden_states,) + outputs[1:]
# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo->Llama, GPT_NEO->LLAMA, transformer->model
class FlaxLlamaPreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = LlamaConfig
base_model_prefix = "model"
module_class: nn.Module = None
def __init__(
self,
config: LlamaConfig,
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
# Build the Flax module from the given config and forward everything to the parent initializer
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# Dummy input tensors: zero input_ids, an all-ones attention mask and broadcast position ids
input_ids = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
# Split the incoming RNG into separate generators for parameters and dropout
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
# Initialize the module parameters with the dummy inputs
random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]
# If pretrained params were provided, fill any missing keys with the freshly initialized values
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def init_cache(self, batch_size, max_length):
r"""
Args:
batch_size (`int`):
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
max_length (`int`):
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
cache.
"""
# Dummy inputs used only to trace the module and materialize the cache variables
input_ids = jnp.ones((batch_size, max_length))
attention_mask = jnp.ones_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
# Run module.init with init_cache=True so the cache collection is created
init_variables = self.module.init(
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
)
return unfreeze(init_variables["cache"])
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def __call__(
self,
input_ids,
attention_mask=None,
position_ids=None,
params: dict = None,
past_key_values: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
# The docstring of `__call__` is attached via the `@add_start_docstrings_to_model_forward` decorator above
# Fall back to the config defaults when the output/return flags are not given explicitly
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
# Batch size and sequence length of the input
batch_size, sequence_length = input_ids.shape
# Derive default position ids when none are provided
if position_ids is None:
# When past_key_values are passed, position ids must be provided explicitly
if past_key_values is not None:
raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
# Default to an all-ones attention mask when none is provided
if attention_mask is None:
attention_mask = jnp.ones((batch_size, sequence_length))
# Handle any PRNG generators that need to be passed down
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
# Assemble the inputs dict from the given params (or the model's own)
inputs = {"params": params or self.params}
# When past_key_values are passed, hand them to the module as a mutable "cache" collection so it can be updated
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
# Run the module's forward pass with all inputs and flags
outputs = self.module.apply(
inputs,
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
False,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
mutable=mutable,
)
# With past_key_values and return_dict, attach the updated cache to the model output
if past_key_values is not None and return_dict:
outputs, past_key_values = outputs
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
return outputs
# With past_key_values but no return_dict, splice the updated cache into the output tuple
elif past_key_values is not None and not return_dict:
outputs, past_key_values = outputs
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
return outputs
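The `mutable=["cache"]` handling above is standard Flax: when a collection is marked mutable, `Module.apply` returns an `(outputs, mutated_variables)` pair. A minimal toy module (my own illustration) showing that round trip:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Counter(nn.Module):
    @nn.compact
    def __call__(self, x):
        # A "cache" variable, updated in place on every call
        count = self.variable("cache", "count", lambda: jnp.array(0))
        count.value = count.value + 1
        return x

mod = Counter()
variables = mod.init(jax.random.PRNGKey(0), jnp.ones(1))  # count -> 1 during init
y, mutated = mod.apply(variables, jnp.ones(1), mutable=["cache"])
print(mutated["cache"]["count"])  # 2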
class FlaxLlamaLayerCollection(nn.Module):
config: LlamaConfig  # model configuration
dtype: jnp.dtype = jnp.float32  # computation dtype, defaults to float32
def setup(self):
# Build one FlaxLlamaDecoderLayer per hidden layer
self.blocks = [
FlaxLlamaDecoderLayer(self.config, dtype=self.dtype, name=str(i))
for i in range(self.config.num_hidden_layers)
]
def __call__(
self,
hidden_states,
attention_mask=None,
position_ids=None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = False,
):
# Accumulators for per-layer attentions and hidden states, if requested
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
# Run each decoder layer in turn
for block in self.blocks:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = block(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
)
# The first element of the layer output is the new hidden states
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions += (layer_outputs[1],)
# Note: all_hidden_states and all_attentions may contain None values; FlaxLlamaModule filters them out
outputs = (hidden_states, all_hidden_states, all_attentions)
return outputs
class FlaxLlamaModule(nn.Module):
config: LlamaConfig  # model configuration
dtype: jnp.dtype = jnp.float32  # computation dtype, defaults to float32
def setup(self):
self.hidden_size = self.config.hidden_size
# Token embedding table, initialized from a normal distribution
embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range)
self.embed_tokens = nn.Embed(
self.config.vocab_size,
self.hidden_size,
embedding_init=embedding_init,
dtype=self.dtype,
)
# Decoder layer stack and final RMSNorm
self.layers = FlaxLlamaLayerCollection(self.config, dtype=self.dtype)
self.norm = FlaxLlamaRMSNorm(self.config, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask=None,
position_ids=None,
deterministic=True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Embed the input ids (cast to 32-bit integers)
input_embeds = self.embed_tokens(input_ids.astype("i4"))
# Run the embeddings through the decoder layer stack
outputs = self.layers(
input_embeds,
position_ids=position_ids,
attention_mask=attention_mask,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Apply the final RMSNorm to the last hidden state
hidden_states = outputs[0]
hidden_states = self.norm(hidden_states)
# If all hidden states are requested, append the normalized final state to the collection
if output_hidden_states:
all_hidden_states = outputs[1] + (hidden_states,)
outputs = (hidden_states, all_hidden_states) + outputs[2:]
else:
outputs = (hidden_states,) + outputs[1:]
# Without return_dict, return a tuple of all non-None outputs
if not return_dict:
return tuple(v for v in outputs if v is not None)
# Otherwise wrap the results in a FlaxBaseModelOutput
return FlaxBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=outputs[1],
attentions=outputs[-1],
)
# Attach the start docstring to FlaxLlamaModel: a bare Llama transformer outputting raw hidden states
# without any specific head on top, supplemented by LLAMA_START_DOCSTRING.
@add_start_docstrings(
"The bare Llama Model transformer outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
)
class FlaxLlamaModel(FlaxLlamaPreTrainedModel):
module_class = FlaxLlamaModule
# Attach a call example docstring to FlaxLlamaModel
append_call_sample_docstring(
FlaxLlamaModel,
_CHECKPOINT_FOR_DOC,
FlaxBaseModelOutput,
_CONFIG_FOR_DOC,
real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
)
# Module backing the causal language modeling task
class FlaxLlamaForCausalLMModule(nn.Module):
config: LlamaConfig  # model configuration
dtype: jnp.dtype = jnp.float32  # computation dtype, defaults to float32
def setup(self):
# Backbone Llama module
self.model = FlaxLlamaModule(self.config, dtype=self.dtype)
# Language modeling head: a bias-free Dense layer projecting to the vocabulary size
self.lm_head = nn.Dense(
self.config.vocab_size,
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
def __call__(
self,
input_ids,
attention_mask=None,
position_ids=None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# Run the backbone over the input sequence
outputs = self.model(
input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Project the final hidden states to vocabulary logits
hidden_states = outputs[0]
lm_logits = self.lm_head(hidden_states)
# Return a tuple when return_dict is False, otherwise a FlaxCausalLMOutput
if not return_dict:
return (lm_logits,) + outputs[1:]
return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
# Attach the start docstring: a Llama transformer with a language modeling head on top
@add_start_docstrings(
"""
The Llama Model transformer with a language modeling head (linear layer) on top.
""",
LLAMA_START_DOCSTRING,
)
# Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJForCausalLM with GPTJ->Llama
class FlaxLlamaForCausalLM(FlaxLlamaPreTrainedModel):
module_class = FlaxLlamaForCausalLMModule
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
# Initialize the decoding cache
batch_size, seq_length = input_ids.shape
past_key_values = self.init_cache(batch_size, max_length)
# Note: positions in attention_mask beyond input_ids.shape[-1] and before cache_length would usually
# need to be set to 0, but since Llama uses causal attention those positions are masked already.
# A single static attention_mask is therefore enough here, and it compiles more efficiently.
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
# If an attention_mask was passed in, derive position_ids from it and copy it into the static mask
if attention_mask is not None:
position_ids = attention_mask.cumsum(axis=-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
else:
# Otherwise broadcast sequential position ids over the batch
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
# Return everything generation needs
return {
"past_key_values": past_key_values,
"attention_mask": extended_attention_mask,
"position_ids": position_ids,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
# Carry the updated key/value cache forward and advance position_ids by one step
model_kwargs["past_key_values"] = model_outputs.past_key_values
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
return model_kwargs
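The `attention_mask.cumsum(axis=-1) - 1` trick in `prepare_inputs_for_generation` assigns increasing positions to real tokens while left-padding slots land at -1 (harmless, since those positions are masked out anyway). A quick standalone check:

import jax.numpy as jnp

attention_mask = jnp.array([[0, 0, 1, 1, 1]])  # left-padded prompt
position_ids = attention_mask.cumsum(axis=-1) - 1
print(position_ids)  # [[-1 -1  0  1  2]]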
# Attach a call example docstring to FlaxLlamaForCausalLM
append_call_sample_docstring(
FlaxLlamaForCausalLM,
_CHECKPOINT_FOR_DOC,
FlaxCausalLMOutput,
_CONFIG_FOR_DOC,
real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
)
.\models\llama\modeling_llama.py
"""PyTorch LLaMA model."""
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_llama import LlamaConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlamaConfig"
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
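A quick check of `_get_unpad_data` on a toy padding mask: row 0 has two real tokens and row 1 has three, so `cu_seqlens` is the exclusive prefix sum `[0, 2, 5]` and `indices` points at the flattened non-padding positions:

import torch

mask = torch.tensor([[1, 1, 0, 0],
                     [1, 1, 1, 0]])
indices, cu_seqlens, max_seqlen = _get_unpad_data(mask)
print(indices.tolist())     # [0, 1, 4, 5, 6]
print(cu_seqlens.tolist())  # [0, 2, 5]
print(max_seqlen)           # 3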
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
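A numeric sanity check of `LlamaRMSNorm` against its formula, y = w * x / sqrt(mean(x^2) + eps); unlike LayerNorm there is no mean subtraction and no bias:

import torch

norm = LlamaRMSNorm(hidden_size=4, eps=1e-6)
x = torch.randn(2, 3, 4)
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
print(torch.allclose(norm(x), manual, atol=1e-6))  # True (weight starts as all ones)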
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__()
self.scaling_factor = scaling_factor
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
t = t / self.scaling_factor
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
@property
def sin_cached(self):
logger.warning_once(
"The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
"the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
)
return self._sin_cached
@property
def cos_cached(self):
logger.warning_once(
"The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
"the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
)
return self._cos_cached
@torch.no_grad()
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding 扩展,添加了线性缩放。感谢 Reddit 用户 /u/kaiokendev 的贡献。"""
def forward(self, x, position_ids):
position_ids = position_ids.float() / self.scaling_factor
cos, sin = super().forward(x, position_ids)
return cos, sin
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def forward(self, x, position_ids):
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (
base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
)
self.register_buffer("inv_freq", inv_freq, persistent=False)
cos, sin = super().forward(x, position_ids)
return cos, sin
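The linear variant only rescales the positions before the base computation, so manually halving the positions fed to the base module reproduces it (a small equivalence check):

import torch

base = LlamaRotaryEmbedding(dim=8)
scaled = LlamaLinearScalingRotaryEmbedding(dim=8, scaling_factor=2.0)
x = torch.zeros(1, 1, 8, 8)  # only dtype/device are read from x
pos = torch.arange(8)[None, :]
cos_manual, _ = base(x, pos.float() / 2.0)  # halve positions by hand
cos_scaled, _ = scaled(x, pos)              # scaling_factor=2.0 does the same
print(torch.allclose(cos_manual, cos_scaled))  # True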
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
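Since RoPE rotates pairs of dimensions by position-dependent angles, it preserves the norm of every query and key vector. A small demo wiring `LlamaRotaryEmbedding` and `apply_rotary_pos_emb` together:

import torch

batch, heads, seq_len, head_dim = 1, 2, 5, 8
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)

rope = LlamaRotaryEmbedding(dim=head_dim)
position_ids = torch.arange(seq_len)[None, :]
cos, sin = rope(q, position_ids)  # each of shape (1, seq_len, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)

print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True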
class LlamaMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
if self.config.pretraining_tp > 1:
slice = self.intermediate_size // self.config.pretraining_tp
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
gate_proj = torch.cat(
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
)
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
down_proj = [
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
]
down_proj = sum(down_proj)
else:
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
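`repeat_kv` is how grouped-query attention broadcasts the (fewer) key/value heads to all attention heads; the expand-and-reshape is equivalent to `torch.repeat_interleave` along the head dimension:

import torch

kv = torch.randn(1, 2, 4, 8)  # (batch, num_key_value_heads, seq_len, head_dim)
out = repeat_kv(kv, n_rep=3)
print(out.shape)  # torch.Size([1, 6, 4, 8])
print(torch.equal(out, torch.repeat_interleave(kv, repeats=3, dim=1)))  # True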
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
self._init_rope()
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "dynamic":
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
):
...  # (forward body omitted in this excerpt)
class LlamaFlashAttention2(LlamaAttention):
"""
Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
):
...  # (forward body omitted in this excerpt)
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
causal = self.is_causal and query_length != 1
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)
return attn_output
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
)
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
class LlamaSdpaAttention(LlamaAttention):
"""
Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
):
"""
Forward pass of the LlamaSdpaAttention module.
Args:
hidden_states (torch.Tensor): The input hidden states.
attention_mask (Optional[torch.Tensor], optional): The attention mask. Defaults to None.
position_ids (Optional[torch.LongTensor], optional): The position ids. Defaults to None.
past_key_value (Optional[Cache], optional): The past key value cache. Defaults to None.
output_attentions (bool, optional): Whether to output attentions. Defaults to False.
use_cache (bool, optional): Whether to use caching. Defaults to False.
cache_position (Optional[torch.LongTensor], optional): The position for caching. Defaults to None.
Returns:
torch.Tensor: The output tensor from the attention layer.
"""
...  # (implementation omitted in this excerpt)
LLAMA_ATTENTION_CLASSES = {
"eager": LlamaAttention,
"flash_attention_2": LlamaFlashAttention2,
"sdpa": LlamaSdpaAttention,
}
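How the dispatch table is consumed: `LlamaDecoderLayer` (below) indexes it with the config's `_attn_implementation` string to pick the attention backend. A small lookup sketch (the direct attribute assignment is for illustration only; normally the value is resolved when the model is loaded):

from transformers import LlamaConfig

config = LlamaConfig()
config._attn_implementation = "sdpa"
attn_cls = LLAMA_ATTENTION_CLASSES[config._attn_implementation]
print(attn_cls.__name__)  # LlamaSdpaAttention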
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
"Document the inputs the LLAMA model accepts (`model_input_ids`, `attention_mask`, etc.) See the superclass "
"documentation for more details."
LLAMA_INPUTS_DOCSTRING,
)
(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
)
class LlamaForCausalLM(LlamaPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = LlamaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
):
...  # (forward body omitted in this excerpt)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
):
...  # (body omitted in this excerpt)
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
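A toy run of `_reorder_cache`: during beam search each layer's cached key/value tensors are gathered along the batch dimension with `beam_idx`:

import torch

layer_past = (torch.arange(3.0)[:, None], torch.arange(3.0)[:, None])  # (key, value), batch of 3
past_key_values = (layer_past,)
beam_idx = torch.tensor([2, 2, 0])  # beams 0 and 1 both continue from hypothesis 2
reordered = LlamaForCausalLM._reorder_cache(past_key_values, beam_idx)
print(reordered[0][0].flatten().tolist())  # [2.0, 2.0, 0.0]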
"""
The Llama Model transformer with a sequence classification head on top (linear layer).
[`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-2) do.
Since it does classification on the last token, it requires to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
"""
@add_start_docstrings(
"""
The Llama Model transformer with a span classification head on top for extractive question-answering tasks like
SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
LLAMA_START_DOCSTRING,
)
class LlamaForQuestionAnswering(LlamaPreTrainedModel):
base_model_prefix = "transformer"
def __init__(self, config):
super().__init__(config)
self.transformer = LlamaModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
self.post_init()
def get_input_embeddings(self):
return self.transformer.embed_tokens
def set_input_embeddings(self, value):
self.transformer.embed_tokens = value
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
"""
# 定义一个方法 `forward`,用于模型的前向传播
def forward(
# 输入序列的 token IDs,类型为长整型张量,可选参数
input_ids: Optional[torch.LongTensor] = None,
# 注意力遮罩,类型为单精度浮点张量,可选参数
attention_mask: Optional[torch.FloatTensor] = None,
# 位置编码 ID,类型为长整型张量,可选参数
position_ids: Optional[torch.LongTensor] = None,
# 过去的键值对,类型为浮点张量列表,可选参数
past_key_values: Optional[List[torch.FloatTensor]] = None,
# 输入的嵌入张量,类型为单精度浮点张量,可选参数
inputs_embeds: Optional[torch.FloatTensor] = None,
# 起始位置,类型为长整型张量,可选参数
start_positions: Optional[torch.LongTensor] = None,
# 结束位置,类型为长整型张量,可选参数
end_positions: Optional[torch.LongTensor] = None,
# 是否输出注意力张量,布尔类型,可选参数
output_attentions: Optional[bool] = None,
# 是否输出隐藏状态,布尔类型,可选参数
output_hidden_states: Optional[bool] = None,
# 是否返回字典格式的结果,布尔类型,可选参数
return_dict: Optional[bool] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
"""
# Fall back to the config default when return_dict is not provided
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Run the backbone transformer over the inputs
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Take the sequence output and project it to start/end logits with the QA head
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
# On multi-GPU setups, squeeze the extra dimension and move the labels to the logits' device
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1).to(start_logits.device)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1).to(end_logits.device)
# Clamp positions that fall outside the model input length to an ignored index
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
# Cross-entropy loss over start and end positions, ignoring the clamped index
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
# Without return_dict, return a plain tuple (prefixed with the loss if computed)
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
# Otherwise return a QuestionAnsweringModelOutput with the loss, start/end logits, hidden states and attentions
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
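A tiny smoke test of the QA head wiring with random weights (illustrative sizes; this assumes a `transformers` install that provides the `LlamaModel` backbone omitted from this excerpt):

import torch
from transformers import LlamaConfig, LlamaForQuestionAnswering

config = LlamaConfig(hidden_size=32, intermediate_size=64, num_hidden_layers=1,
                     num_attention_heads=4, num_key_value_heads=4, vocab_size=100)
model = LlamaForQuestionAnswering(config)
out = model(input_ids=torch.randint(0, 100, (1, 6)),
            start_positions=torch.tensor([1]), end_positions=torch.tensor([3]))
print(out.start_logits.shape, out.loss is not None)  # torch.Size([1, 6]) True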