Transformers 源码解析(六十)
.\models\instructblip\processing_instructblip.py
"""
Processor class for InstructBLIP. Largely a copy of Blip2Processor, with the addition of a tokenizer for the Q-Former.
"""
import os
from typing import List, Optional, Union
from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType
from ..auto import AutoTokenizer
class InstructBlipProcessor(ProcessorMixin):
r"""
Constructs an InstructBLIP processor which wraps a BLIP image processor and a LLaMa/T5 tokenizer into a single
processor.
[`InstructBlipProcessor`] offers all the functionalities of [`BlipImageProcessor`] and [`AutoTokenizer`]. See the
docstring of [`~BlipProcessor.__call__`] and [`~BlipProcessor.decode`] for more information.
Args:
image_processor (`BlipImageProcessor`):
An instance of [`BlipImageProcessor`]. The image processor is a required input.
tokenizer (`AutoTokenizer`):
            An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
        qformer_tokenizer (`AutoTokenizer`):
            An instance of [`PreTrainedTokenizer`]. The Q-Former tokenizer is a required input.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "BlipImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor, tokenizer, qformer_tokenizer):
super().__init__(image_processor, tokenizer)
self.qformer_tokenizer = qformer_tokenizer
def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_token_type_ids: bool = False,
return_length: bool = False,
verbose: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
) -> BatchFeature:
"""
        This method uses [`BlipImageProcessor.__call__`] to prepare image(s) for the model, and
        [`BertTokenizerFast.__call__`] to prepare text for the model.
        Please refer to the docstrings of those two methods for more information.
"""
if images is None and text is None:
raise ValueError("You have to specify at least images or text.")
encoding = BatchFeature()
if text is not None:
text_encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_token_type_ids=return_token_type_ids,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
**kwargs,
)
encoding.update(text_encoding)
qformer_text_encoding = self.qformer_tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_token_type_ids=return_token_type_ids,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
**kwargs,
)
encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids")
encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask")
if images is not None:
image_encoding = self.image_processor(images, return_tensors=return_tensors)
encoding.update(image_encoding)
return encoding
def batch_decode(self, *args, **kwargs):
"""
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
def save_pretrained(self, save_directory, **kwargs):
if os.path.isfile(save_directory):
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)
qformer_tokenizer_path = os.path.join(save_directory, "qformer_tokenizer")
self.qformer_tokenizer.save_pretrained(qformer_tokenizer_path)
return super().save_pretrained(save_directory, **kwargs)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
qformer_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="qformer_tokenizer")
args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
args.append(qformer_tokenizer)
return cls(*args)
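To make the double tokenization above concrete, here is a minimal usage sketch. The checkpoint name and image URL are illustrative assumptions (any InstructBLIP checkpoint on the Hub follows the same pattern); the point is that a single processor call produces both the language-model inputs and the Q-Former inputs.

```python
from PIL import Image
import requests

from transformers import InstructBlipProcessor

# Illustrative checkpoint and image; swap in whatever you actually use.
processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xl")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, text="What is unusual about this image?", return_tensors="pt")
# The LLM tokenizer and the Q-Former tokenizer each contribute their own tensors:
print(sorted(inputs.keys()))
# ['attention_mask', 'input_ids', 'pixel_values', 'qformer_attention_mask', 'qformer_input_ids']
```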
.\models\instructblip\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_instructblip": [
"INSTRUCTBLIP_PRETRAINED_CONFIG_ARCHIVE_MAP",
"InstructBlipConfig",
"InstructBlipQFormerConfig",
"InstructBlipVisionConfig",
],
"processing_instructblip": ["InstructBlipProcessor"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_instructblip"] = [
"INSTRUCTBLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
"InstructBlipQFormerModel",
"InstructBlipPreTrainedModel",
"InstructBlipForConditionalGeneration",
"InstructBlipVisionModel",
]
if TYPE_CHECKING:
from .configuration_instructblip import (
INSTRUCTBLIP_PRETRAINED_CONFIG_ARCHIVE_MAP,
InstructBlipConfig,
InstructBlipQFormerConfig,
InstructBlipVisionConfig,
)
from .processing_instructblip import InstructBlipProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_instructblip import (
INSTRUCTBLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
InstructBlipForConditionalGeneration,
InstructBlipPreTrainedModel,
InstructBlipQFormerModel,
InstructBlipVisionModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
.\models\jukebox\configuration_jukebox.py
""" Jukebox 配置 """
import os
from typing import List, Union
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"openai/jukebox-5b-lyrics": "https://huggingface.co/openai/jukebox-5b-lyrics/blob/main/config.json",
"openai/jukebox-1b-lyrics": "https://huggingface.co/openai/jukebox-1b-lyrics/blob/main/config.json",
}
_LARGE_ATTENTION = [
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"cross_attention",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"cross_attention",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"cross_attention",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"cross_attention",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"cross_attention",
]
_RawColumnPreviousRowAttention = ["block_attn", "transpose_block_attn", "prev_block_attn"]
_FullDenseAttention = ["dense_attention"]
_PrimePrimeDenseAttention = ["prime_attn", "prime_attn", "dense_attn"]
def full_dense_attention(layer):
return _FullDenseAttention[0]
def raw_column_previous_row_attention(layer):
return _RawColumnPreviousRowAttention[layer % 3]
def large_separated_enc_dec_w_lyrics(layer):
return _LARGE_ATTENTION[layer % 79]
def enc_dec_with_lyrics(layer):
if layer % 16 == 15:
return _PrimePrimeDenseAttention[layer % 3]
return _RawColumnPreviousRowAttention[layer % 3]
ATTENTION_PATTERNS = {
"full_dense_attention": full_dense_attention,
"raw_column_previous_row_attention": raw_column_previous_row_attention,
"large_separated_enc_dec_w_lyrics": large_separated_enc_dec_w_lyrics,
"enc_dec_with_lyrics": enc_dec_with_lyrics,
}
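The four helpers above all map a layer index to an attention type; `ATTENTION_PATTERNS` simply exposes them by name. A small sketch of what `enc_dec_with_lyrics` produces for the first layers (illustration only, using the functions defined above):

```python
for layer in range(17):
    print(layer, enc_dec_with_lyrics(layer))
# Layers cycle through block_attn -> transpose_block_attn -> prev_block_attn,
# except every 16th layer (layer % 16 == 15), which switches to the lyric
# attention variants: layer 15 -> "prime_attn", layer 47 -> "dense_attn".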
class JukeboxPriorConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of a [`JukeboxPrior`]. It is used to instantiate a
`JukeboxPrior` according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the top level prior from the
    [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
"""
model_type = "jukebox_prior"
attribute_map = {
"max_position_embeddings": "n_positions",
"num_attention_heads": "n_head",
}
def __init__(
self,
act_fn="quick_gelu",
level=0,
alignment_head=2,
alignment_layer=68,
attention_multiplier=0.25,
attention_pattern="enc_dec_with_lyrics",
attn_dropout=0,
attn_res_scale=False,
blocks=64,
conv_res_scale=None,
num_layers=72,
emb_dropout=0,
encoder_config=None,
encoder_loss_fraction=0.4,
hidden_size=2048,
init_scale=0.2,
is_encoder_decoder=True,
lyric_vocab_size=80,
mask=False,
max_duration=600,
max_nb_genres=1,
merged_decoder=True,
metadata_conditioning=True,
metadata_dims=[604, 7898],
min_duration=0,
mlp_multiplier=1.0,
music_vocab_size=2048,
n_ctx=6144,
n_heads=2,
nb_relevant_lyric_tokens=384,
res_conv_depth=3,
res_conv_width=128,
res_convolution_multiplier=1,
res_dilation_cycle=None,
res_dilation_growth_rate=1,
res_downs_t=[3, 2, 2],
res_strides_t=[2, 2, 2],
resid_dropout=0,
sampling_rate=44100,
spread=None,
timing_dims=64,
zero_out=False,
**kwargs,
    ):
self.act_fn = act_fn
self.alignment_head = alignment_head
self.alignment_layer = alignment_layer
self.attention_multiplier = attention_multiplier
self.attention_pattern = attention_pattern
self.attn_dropout = attn_dropout
self.attn_res_scale = attn_res_scale
self.blocks = blocks
self.conv_res_scale = conv_res_scale
self.num_layers = num_layers
self.emb_dropout = emb_dropout
self.music_vocab_size = music_vocab_size
if encoder_config is not None:
self.encoder_config = JukeboxPriorConfig(**encoder_config)
else:
self.encoder_config = None
self.encoder_loss_fraction = encoder_loss_fraction
self.init_scale = init_scale
self.is_encoder_decoder = is_encoder_decoder
self.lyric_vocab_size = lyric_vocab_size
self.level = level
self.mask = mask
self.max_duration = max_duration
self.max_nb_genres = max_nb_genres
self.merged_decoder = merged_decoder
self.metadata_conditioning = metadata_conditioning
self.metadata_dims = metadata_dims
self.min_duration = min_duration
self.mlp_multiplier = mlp_multiplier
self.n_ctx = n_ctx
self.n_heads = n_heads
self.nb_relevant_lyric_tokens = nb_relevant_lyric_tokens
self.res_conv_depth = res_conv_depth
self.res_conv_width = res_conv_width
self.res_convolution_multiplier = res_convolution_multiplier
self.res_dilation_cycle = res_dilation_cycle
self.res_dilation_growth_rate = res_dilation_growth_rate
self.res_downs_t = res_downs_t
self.res_strides_t = res_strides_t
self.resid_dropout = resid_dropout
self.sampling_rate = sampling_rate
self.spread = spread
self.timing_dims = timing_dims
self.hidden_size = hidden_size
self.zero_out = zero_out
@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], level=0, **kwargs
) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
if config_dict.get("model_type") == "jukebox":
config_dict = config_dict[f"prior_{level}"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
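Because the released checkpoints ship one composite `jukebox` config, this `from_pretrained` picks out the nested `prior_{level}` sub-config. A usage sketch, assuming `JukeboxPriorConfig` is importable from the top-level `transformers` namespace and that the hosted config stores the priors under `prior_0` / `prior_1` / `prior_2` keys, as this classmethod expects:

```python
from transformers import JukeboxPriorConfig

# Load the configuration of the top-level (level 2) prior out of the composite config.
# Assumes the hosted config.json nests the priors under "prior_{level}" keys.
top_prior_config = JukeboxPriorConfig.from_pretrained("openai/jukebox-1b-lyrics", level=2)
print(top_prior_config.model_type)  # "jukebox_prior"
```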
class JukeboxVQVAEConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of a [`JukeboxVQVAE`]. It is used to instantiate a
`JukeboxVQVAE` according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the VQVAE from
[openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
    Args:
        act_fn (`str`, *optional*, defaults to `"relu"`):
            Activation function of the model.
        nb_discrete_codes (`int`, *optional*, defaults to 2048):
            Number of discrete codes of the VQVAE.
        commit (`float`, *optional*, defaults to 0.02):
            Commit loss multiplier.
        conv_input_shape (`int`, *optional*, defaults to 1):
            Number of audio channels.
        conv_res_scale (`bool`, *optional*, defaults to `False`):
            Whether or not to scale the residuals of the `JukeboxResConv1DBlock`.
        embed_dim (`int`, *optional*, defaults to 64):
            Embedding dimension of the codebook vectors.
        hop_fraction (`List[float]`, *optional*, defaults to `[0.125, 0.5, 0.5]`):
            Fractions of non-intersecting windows used when continuing the sampling process.
        levels (`int`, *optional*, defaults to 3):
            Number of hierarchical levels used in the VQVAE.
        lmu (`float`, *optional*, defaults to 0.99):
            Exponential moving average coefficient used for the codebook update.
        multipliers (`List[int]`, *optional*, defaults to `[2, 1, 1]`):
            Depth and width multipliers used for each level.
        res_conv_depth (`int`, *optional*, defaults to 4):
            Depth of the encoder and decoder blocks.
        res_conv_width (`int`, *optional*, defaults to 32):
            Width of the encoder and decoder blocks.
        res_convolution_multiplier (`int`, *optional*, defaults to 1):
            Scaling factor of the hidden dimension used in the `JukeboxResConv1DBlock`.
        res_dilation_cycle (`int`, *optional*):
            Dilation cycle value used in the `JukeboxResnet`.
        res_dilation_growth_rate (`int`, *optional*, defaults to 3):
            Resnet dilation growth rate used in the VQVAE.
        res_downs_t (`List[int]`, *optional*, defaults to `[3, 2, 2]`):
            Downsampling rate for each level of the hierarchical VQ-VAE.
        res_strides_t (`List[int]`, *optional*, defaults to `[2, 2, 2]`):
            Stride used for each level of the hierarchical VQ-VAE.
        sample_length (`int`, *optional*, defaults to 1058304):
            Maximum input length of the VQVAE.
        init_scale (`float`, *optional*, defaults to 0.2):
            Initialization scale.
        zero_out (`bool`, *optional*, defaults to `False`):
            Whether or not to zero out convolution weights when initializing.
    """

    model_type = "jukebox_vqvae"
    # Initialization method of the configuration, accepting all VQVAE hyperparameters
    def __init__(
        self,
        act_fn="relu",  # activation function
        nb_discrete_codes=2048,  # number of discrete codes in the codebook
        commit=0.02,  # commit loss multiplier
        conv_input_shape=1,  # number of audio channels
        conv_res_scale=False,  # whether to scale the residuals of JukeboxResConv1DBlock
        embed_dim=64,  # embedding dimension of the codebook vectors
        hop_fraction=[0.125, 0.5, 0.5],  # non-intersecting window fractions used when sampling
        levels=3,  # number of hierarchical levels
        lmu=0.99,  # EMA coefficient for the codebook update
        multipliers=[2, 1, 1],  # depth/width multipliers per level
        res_conv_depth=4,  # depth of the encoder/decoder blocks
        res_conv_width=32,  # width of the encoder/decoder blocks
        res_convolution_multiplier=1,  # hidden-dimension scaling factor
        res_dilation_cycle=None,  # dilation cycle
        res_dilation_growth_rate=3,  # dilation growth rate
        res_downs_t=[3, 2, 2],  # downsampling rates per level
        res_strides_t=[2, 2, 2],  # strides per level
        sample_length=1058304,  # maximum input length
        init_scale=0.2,  # initialization scale
        zero_out=False,  # whether to zero out convolution weights at init
        **kwargs,  # remaining keyword arguments
    ):
        self.hop_fraction = hop_fraction
        self.conv_input_shape = conv_input_shape
        self.sample_length = sample_length

        # VQVAE parameters (all of them are used)
self.levels = levels
self.embed_dim = embed_dim
self.nb_discrete_codes = nb_discrete_codes
self.res_conv_width = res_conv_width
self.res_conv_depth = res_conv_depth
self.res_convolution_multiplier = res_convolution_multiplier
self.res_dilation_growth_rate = res_dilation_growth_rate
self.res_dilation_cycle = res_dilation_cycle
self.multipliers = multipliers
self.res_downs_t = res_downs_t
self.res_strides_t = res_strides_t
self.lmu = lmu
self.commit = commit
self.conv_res_scale = conv_res_scale
self.act_fn = act_fn
self.init_scale = init_scale
self.zero_out = zero_out
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        # If we are loading from a composite JukeboxConfig, extract the nested VQVAE config dict
if config_dict.get("model_type") == "jukebox":
config_dict = config_dict["vqvae_config"]
        # Warn if the model type in the config dict does not match this class's model type
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
        # Instantiate the class from the config dict and the remaining keyword arguments
return cls.from_dict(config_dict, **kwargs)
class JukeboxConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of a [`JukeboxModel`].
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information. Instantiating a configuration with the defaults will
yield a similar configuration to that of
[openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture.
    The downsampling and stride values determine how the input sequence is downsampled. For example,
    `downsampling = (5, 3)` and `strides = (2, 2)` will downsample the audio by 2**5 = 32 to get the first level of
    codes, and by 2**(5 + 3) = 256 to get the second-level codes. This is mostly relevant for training the top level
    prior and the upsamplers.
Args:
vqvae_config (`JukeboxVQVAEConfig`, *optional*):
Configuration for the `JukeboxVQVAE` model.
prior_config_list (`List[JukeboxPriorConfig]`, *optional*):
List of the configs for each of the `JukeboxPrior` of the model. The original architecture uses 3 priors.
nb_priors (`int`, *optional*, defaults to 3):
Number of prior models that will sequentially sample tokens. Each prior is conditional auto regressive
(decoder) model, apart from the top prior, which can include a lyric encoder. The available models were
trained using a top prior and 2 upsampler priors.
sampling_rate (`int`, *optional*, defaults to 44100):
Sampling rate of the raw audio.
timing_dims (`int`, *optional*, defaults to 64):
Dimensions of the JukeboxRangeEmbedding layer which is equivalent to traditional positional embedding
layer. The timing embedding layer converts the absolute and relative position in the currently sampled
audio to a tensor of length `timing_dims` that will be added to the music tokens.
min_duration (`int`, *optional*, defaults to 0):
Minimum duration of the audios to generate
max_duration (`float`, *optional*, defaults to 600.0):
Maximum duration of the audios to generate
max_nb_genres (`int`, *optional*, defaults to 5):
Maximum number of genres that can be used to condition a single sample.
metadata_conditioning (`bool`, *optional*, defaults to `True`):
Whether or not to use metadata conditioning, corresponding to the artist, the genre and the min/maximum
duration.
    Example:

    ```
    >>> from transformers import JukeboxModel, JukeboxConfig

    >>> # Initializing a Jukebox configuration
    >>> configuration = JukeboxConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = JukeboxModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
"""
    # Model type identifier for this configuration class
    model_type = "jukebox"

    # Initialization method used to instantiate a JukeboxConfig object
def __init__(
self,
vqvae_config=None,
prior_config_list=None,
nb_priors=3,
sampling_rate=44100,
timing_dims=64,
min_duration=0,
max_duration=600.0,
max_nb_genres=5,
metadata_conditioning=True,
**kwargs,
):
        # If vqvae_config is None, fall back to an empty dict and log it
        if vqvae_config is None:
            vqvae_config = {}
            logger.info("vqvae_config is None. initializing the JukeboxVQVAE with default values.")

        # Build the JukeboxVQVAEConfig from the provided dict
        self.vqvae_config = JukeboxVQVAEConfig(**vqvae_config)

        # If a prior_config_list is given, instantiate a JukeboxPriorConfig for each entry
        if prior_config_list is not None:
            self.prior_configs = [JukeboxPriorConfig(**prior_config) for prior_config in prior_config_list]
        else:
            # Otherwise start from an empty list and look for per-prior configs in kwargs
            self.prior_configs = []
            for prior_idx in range(nb_priors):
                prior_config = kwargs.pop(f"prior_{prior_idx}", None)
                if prior_config is None:
                    prior_config = {}
                    logger.info(
                        f"prior_{prior_idx}'s config is None. Initializing the JukeboxPriorConfig list with default"
                        " values."
                    )
                # Build the JukeboxPriorConfig and append it to the list
                self.prior_configs.append(JukeboxPriorConfig(**prior_config))

        # Mirror the VQVAE hop_fraction on the top-level config
        self.hop_fraction = self.vqvae_config.hop_fraction

        # Metadata and sampling related attributes
        self.nb_priors = nb_priors
        self.max_nb_genres = max_nb_genres
        self.sampling_rate = sampling_rate
        self.timing_dims = timing_dims
        self.min_duration = min_duration
        self.max_duration = max_duration
        self.metadata_conditioning = metadata_conditioning

        # Forward the remaining kwargs to the parent PretrainedConfig
        super().__init__(**kwargs)
@classmethod
def from_configs(cls, prior_configs: List[JukeboxPriorConfig], vqvae_config: JukeboxVQVAEConfig, **kwargs):
r"""
        Instantiate a [`JukeboxConfig`] (or a derived class) from a list of prior configurations and a VQVAE model
        configuration.
Returns:
[`JukeboxConfig`]: An instance of a configuration object
"""
        # Convert each prior config object in `prior_configs` to its dictionary form
        prior_config_list = [config.to_dict() for config in prior_configs]
        # Instantiate the class from the serialized prior configs, the VQVAE config dict and the remaining kwargs
return cls(prior_config_list=prior_config_list, vqvae_config_dict=vqvae_config.to_dict(), **kwargs)
def to_dict(self):
        # Override of the parent `to_dict`: serialize the configuration to a dictionary
        result = super().to_dict()
        # Serialize each prior config and store the list under the "prior_config_list" key
result["prior_config_list"] = [config.to_dict() for config in result.pop("prior_configs")]
return result
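The downsampling arithmetic described in the `JukeboxConfig` docstring can be made concrete with a short sketch. It mirrors the `downsamples` / `hop_lengths` computation that `JukeboxVQVAE` performs later from `res_downs_t` and `res_strides_t` (default values assumed):

```python
import numpy as np

downs_t = [3, 2, 2]    # number of downsampling convolutions per level (defaults)
strides_t = [2, 2, 2]  # stride of each downsampling convolution

downsamples = [stride**down for stride, down in zip(strides_t, downs_t)]
hop_lengths = np.cumprod(downsamples)
print(downsamples, hop_lengths.tolist())  # [8, 4, 4] [8, 32, 128]
# One code covers 8 raw-audio samples at the bottom level, 32 at the middle
# level and 128 at the top level.
```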
.\models\jukebox\convert_jukebox.py
"""Convert Jukebox checkpoints"""
import argparse
import json
import os
from pathlib import Path
import requests
import torch
from transformers import JukeboxConfig, JukeboxModel
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
PREFIX = "https://openaipublic.azureedge.net/jukebox/models/"
MODEL_MAPPING = {
"jukebox-1b-lyrics": [
"5b/vqvae.pth.tar",
"5b/prior_level_0.pth.tar",
"5b/prior_level_1.pth.tar",
"1b_lyrics/prior_level_2.pth.tar",
],
"jukebox-5b-lyrics": [
"5b/vqvae.pth.tar",
"5b/prior_level_0.pth.tar",
"5b/prior_level_1.pth.tar",
"5b_lyrics/prior_level_2.pth.tar",
],
}
def replace_key(key):
if key.endswith(".model.1.bias") and len(key.split(".")) > 10:
key = key.replace(".model.1.bias", ".conv1d_1.bias")
elif key.endswith(".model.1.weight") and len(key.split(".")) > 10:
key = key.replace(".model.1.weight", ".conv1d_1.weight")
elif key.endswith(".model.3.bias") and len(key.split(".")) > 10:
key = key.replace(".model.3.bias", ".conv1d_2.bias")
elif key.endswith(".model.3.weight") and len(key.split(".")) > 10:
key = key.replace(".model.3.weight", ".conv1d_2.weight")
if "conditioner_blocks.0." in key:
key = key.replace("conditioner_blocks.0", "conditioner_blocks")
if "prime_prior" in key:
key = key.replace("prime_prior", "encoder")
if ".emb." in key and "total" not in key and "absolute" not in key and "relative" not in key:
key = key.replace(".emb.", ".")
if key.endswith("k"):
return key.replace(".k", ".codebook")
if "y_emb." in key:
return key.replace("y_emb.", "metadata_embedding.")
if "x_emb.emb." in key:
key = key.replace("0.x_emb.emb", "embed_tokens")
if "prime_state_ln" in key:
return key.replace("prime_state_ln", "encoder.final_layer_norm")
if ".ln" in key:
return key.replace(".ln", ".layer_norm")
if "_ln" in key:
return key.replace("_ln", "_layer_norm")
if "prime_state_proj" in key:
return key.replace("prime_state_proj", "encoder.proj_in")
if "prime_x_out" in key:
return key.replace("prime_x_out", "encoder.lm_head")
if "prior.x_out" in key:
return key.replace("x_out", "fc_proj_out")
if "x_emb" in key:
return key.replace("x_emb", "embed_tokens")
return key
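Tracing `replace_key` on a few original-style checkpoint keys shows the renaming it performs (the example keys follow the original naming scheme but are chosen for illustration):

```python
print(replace_key("prime_state_ln.weight"))  # encoder.final_layer_norm.weight
print(replace_key("prime_x_out.weight"))     # encoder.lm_head.weight
print(replace_key("prior.x_out.weight"))     # prior.fc_proj_out.weight
```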
def fix_jukebox_keys(state_dict, model_state_dict, key_prefix, mapping):
new_dict = {}
import re
re_encoder_block_conv_in = re.compile(r"encoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).(bias|weight)")
re_encoder_block_resnet = re.compile(
r"encoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).model.(\d*).model.(\d*).(bias|weight)"
)
re_encoder_block_proj_out = re.compile(r"encoders.(\d*).level_blocks.(\d*).model.(\d*).(bias|weight)")
re_decoder_block_conv_out = re.compile(r"decoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).(bias|weight)")
re_decoder_block_resnet = re.compile(
r"decoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).model.(\d*).model.(\d*).(bias|weight)"
)
re_decoder_block_proj_in = re.compile(r"decoders.(\d*).level_blocks.(\d*).model.(\d*).(bias|weight)")
re_prior_cond_conv_out = re.compile(r"conditioner_blocks.(\d*).cond.model.(\d*).(\d).(bias|weight)")
re_prior_cond_resnet = re.compile(
r"conditioner_blocks.(\d*).cond.model.(\d*).(\d).model.(\d*).model.(\d*).(bias|weight)"
)
re_prior_cond_proj_in = re.compile(r"conditioner_blocks.(\d*).cond.model.(\d*).(bias|weight)")
return new_dict
@torch.no_grad()
def convert_openai_checkpoint(model_name=None, pytorch_dump_folder_path=None):
"""
Copy/paste/tweak model's weights to our Jukebox structure.
"""
for file in MODEL_MAPPING[model_name]:
if not os.path.isfile(f"{pytorch_dump_folder_path}/{file.split('/')[-1]}"):
r = requests.get(f"{PREFIX}{file}", allow_redirects=True)
os.makedirs(f"{pytorch_dump_folder_path}/", exist_ok=True)
open(f"{pytorch_dump_folder_path}/{file.split('/')[-1]}", "wb").write(r.content)
model_to_convert = MODEL_MAPPING[model_name.split("/")[-1]]
config = JukeboxConfig.from_pretrained(model_name)
model = JukeboxModel(config)
weight_dict = []
mapping = {}
for i, dict_name in enumerate(model_to_convert):
old_dic = torch.load(f"{pytorch_dump_folder_path}/{dict_name.split('/')[-1]}")["model"]
new_dic = {}
for k in old_dic.keys():
if k.endswith(".b"):
new_dic[k.replace("b", "bias")] = old_dic[k]
elif k.endswith(".w"):
new_dic[k.replace("w", "weight")] = old_dic[k]
elif "level_2" not in dict_name and "cond.model." in k:
new_dic[k.replace(".blocks.", ".model.")] = old_dic[k]
else:
new_dic[k] = old_dic[k]
key_prefix = "vqvae" if i == 0 else f"priors.{3 - i}"
new_dic = fix_jukebox_keys(new_dic, model.state_dict(), key_prefix, mapping)
weight_dict.append(new_dic)
vqvae_state_dict = weight_dict.pop(0)
model.vqvae.load_state_dict(vqvae_state_dict)
for i in range(len(weight_dict)):
model.priors[i].load_state_dict(weight_dict[2 - i])
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
with open(f"{pytorch_dump_folder_path}/mapping.json", "w") as txtfile:
json.dump(mapping, txtfile)
print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
return weight_dict
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
default="jukebox-5b-lyrics",
type=str,
help="Name of the model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default="jukebox-5b-lyrics-converted",
type=str,
help="Path to the output PyTorch model directory.",
)
args = parser.parse_args()
convert_openai_checkpoint(args.model_name, args.pytorch_dump_folder_path)
.\models\jukebox\modeling_jukebox.py
"""PyTorch Jukebox model."""
import math
import os
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import LayerNorm as FusedLayerNorm
from ...activations import ACT2FN
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, logging
from ...utils.logging import tqdm
from .configuration_jukebox import ATTENTION_PATTERNS, JukeboxConfig, JukeboxPriorConfig, JukeboxVQVAEConfig
logger = logging.get_logger(__name__)
JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST = [
"openai/jukebox-1b-lyrics",
"openai/jukebox-5b-lyrics",
]
def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
"""
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits (`torch.Tensor`):
logits distribution shape (vocabulary size)
        top_k (`int`, *optional*, defaults to 0):
            When `top_k > 0`, keep only the top k tokens with the highest probability (top-k filtering).
        top_p (`float`, *optional*, defaults to 0.0):
            When `top_p > 0.0`, keep the top tokens with cumulative probability >= `top_p` (nucleus filtering).
"""
logits = logits.clone()
top_k = min(top_k, logits.size(-1))
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1:]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
dim=-1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
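A quick numerical check of `filter_logits` on a tiny, hand-picked distribution (values chosen purely for illustration):

```python
import torch

logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])

print(filter_logits(logits, top_k=2))
# tensor([[-inf, -inf, 3., 4.]])  -> only the two highest logits survive

print(filter_logits(logits, top_p=0.9))
# tensor([[-inf, 2., 3., 4.]])    -> softmax probs are ~[0.03, 0.09, 0.24, 0.64];
#                                    the smallest set reaching 0.9 keeps the top three tokens
```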
def get_relevant_lyric_tokens(full_tokens, max_n_lyric_tokens, total_length, offset, duration):
"""
Extract only the relevant tokens based on the character position. A total of `max_n_lyric_tokens` tokens will be
returned. If the provided token sequence is smaller, it will be padded, otherwise, only characters ranging from the
midpoint - `max_n_lyric_tokens//2` to the midpoint + `max_n_lyric_tokens//2` will be returned. This *focuses* on
the most relevant tokens (in time) for the sequence.
Args:
full_tokens (`List[int]`):
List containing the token ids of the entire lyrics.
max_n_lyric_tokens (`int`):
Maximum number of lyric tokens to return.
total_length (`int`):
Total expected length of the music (not all of it is generated, see duration), in samples.
        offset (`int`):
            Starting sample in the music. If the offset is greater than 0, the lyrics will be shifted to take that
            into account.
        duration (`int`):
            Expected duration of the generated music, in samples. The duration has to be smaller than the total
            length, which represents the overall length of the signal.
"""
full_tokens = full_tokens[0]
if len(full_tokens) < max_n_lyric_tokens:
tokens = torch.cat(
[torch.zeros(max_n_lyric_tokens - len(full_tokens), dtype=torch.long).to(full_tokens.device), full_tokens]
)
indices = [-1] * (max_n_lyric_tokens - len(full_tokens)) + list(range(0, len(full_tokens)))
else:
midpoint = int(len(full_tokens) * (offset + duration / 2.0) / total_length)
midpoint = min(max(midpoint, max_n_lyric_tokens // 2), len(full_tokens) - max_n_lyric_tokens // 2)
tokens = full_tokens[midpoint - max_n_lyric_tokens // 2 : midpoint + max_n_lyric_tokens // 2]
indices = list(range(midpoint - max_n_lyric_tokens // 2, midpoint + max_n_lyric_tokens // 2))
return tokens.unsqueeze(dim=0), indices
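The padding branch of `get_relevant_lyric_tokens` is easy to check in isolation: when the lyric sequence is shorter than `max_n_lyric_tokens`, it is left-padded with zeros and the padded positions get index -1 (illustrative values):

```python
import torch

lyrics = torch.tensor([[5, 6, 7]])  # a batch of one short lyric sequence
tokens, indices = get_relevant_lyric_tokens(lyrics, max_n_lyric_tokens=6, total_length=0, offset=0, duration=0)
print(tokens)   # tensor([[0, 0, 0, 5, 6, 7]])
print(indices)  # [-1, -1, -1, 0, 1, 2]
```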
def get_starts(total_length, n_ctx, hop_length):
starts = []
for start in range(0, total_length - n_ctx + hop_length, hop_length):
if start + n_ctx >= total_length:
start = total_length - n_ctx
starts.append(start)
return starts
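`get_starts` slices a long token sequence into overlapping windows of `n_ctx` tokens spaced `hop_length` apart, clamping the last window so it ends exactly at `total_length`:

```python
print(get_starts(total_length=20, n_ctx=8, hop_length=4))
# [0, 4, 8, 12]  -> the final window covers tokens [12, 20)
```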
def get_alignment(music_tokens, labels, prior, config):
level = prior.levels - 1
n_ctx = prior.n_ctx
tokens = music_tokens[level]
batch_size, total_length = tokens.shape[0], tokens.shape[1]
if total_length < n_ctx:
padding_length = n_ctx - total_length
tokens = torch.cat(
[tokens, torch.zeros(batch_size, n_ctx - total_length, dtype=tokens.dtype, device=tokens.device)], dim=1
)
total_length = tokens.shape[1]
else:
padding_length = 0
hop_length = int(config.hop_fraction[-level - 1] * prior.n_ctx)
alignment_head, alignment_layer = config.prior_alignment_head[0], config.prior_alignment_layer[0]
attn_layers = {alignment_layer}
alignment_hops = {}
indices_hops = {}
for start in tqdm(get_starts(total_length, n_ctx, hop_length), desc="Computing lyric to music alignment "):
end = start + n_ctx
metadata, indices_hop = prior.get_metadata(labels, start, config.sample_length, get_indices=True, offset=0)
tokens_bs = torch.chunk(tokens, batch_size, dim=0)
metadata_bs = torch.chunk(metadata, batch_size, dim=0)
w_hops = []
for tokens_i, metadata_i in zip(tokens_bs, metadata_bs):
w_hop = prior.forward_tokens(tokens_i[:, start:end], [], metadata_i, get_attn_weights=attn_layers)
w_hops.append(w_hop[0][:, alignment_head])
del w_hop
weights = torch.cat(w_hops, dim=0)
del w_hops
alignment_hop = weights.float().cpu().numpy()
del weights
indices_hops[start] = indices_hop
alignment_hops[start] = alignment_hop
alignments = []
for item in range(batch_size):
full_tokens = labels[0, 3:]
alignment = np.zeros((total_length, len(full_tokens) + 1))
for start in reversed(get_starts(total_length, n_ctx, hop_length)):
end = start + n_ctx
alignment_hop = alignment_hops[start][item]
indices = indices_hops[start][item]
alignment[start:end, indices] = alignment_hop
alignment = alignment[: total_length - padding_length, :-1]
alignments.append(alignment)
return alignments
def save_temp_audio(fname, lvl, metas, aud):
aud = torch.clamp(aud, -1, 1).cpu().numpy()
for i in list(range(aud.shape[0])):
if metas is not None:
artists, genres, lyrics = list(metas)[i].values()
path = f"{fname}/lvl_{lvl}-{artists}-{genres}-{lyrics[:5]}-{i}"
np.save(path, aud[i])
else:
np.save(f"{fname}/lvl_{lvl}-sample-{i}", aud[i])
def get_mask(mask, query_length, key_value_length, blocks, spread, device, sample, sample_t):
if mask is None or query_length == 1:
return None
offset = sample_t - query_length if sample else max(key_value_length - query_length, 0)
if mask == "autoregressive":
mask = torch.ones(query_length, key_value_length, device=device).tril(offset)
elif mask == "summary":
mask = torch.ones(query_length, query_length, device=device).tril()
mask = mask.view(query_length, blocks, query_length // blocks)[:, :-1, -key_value_length // blocks :]
mask = (
torch.nn.functional.pad(
mask,
(0, 0, 1, 0),
value=1,
)
.contiguous()
.view(query_length, key_value_length)
)
elif mask == "prime":
mask = torch.ones(query_length, key_value_length, device=device).tril(offset)
return mask.view(1, 1, query_length, key_value_length)
class JukeboxConv1D(nn.Module):
def __init__(self, input_width, output_width):
super().__init__()
self.input_width = input_width
self.output_width = output_width
weight = torch.empty(input_width, output_width)
bias = torch.zeros(output_width)
self.weight = nn.Parameter(weight)
self.bias = nn.Parameter(bias)
def forward(self, hidden_states):
size_out = (*hidden_states.size()[:-1], self.output_width)
hidden_states = torch.addmm(
self.bias.type_as(hidden_states),
hidden_states.view(-1, hidden_states.size(-1)),
self.weight.type_as(hidden_states),
)
hidden_states = hidden_states.view(*size_out)
return hidden_states
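Despite its name, `JukeboxConv1D` is a plain affine map `y = x @ W + b` with the weight stored as `(input_width, output_width)`, i.e. the transpose of `nn.Linear`'s layout (the same convention as GPT-2's `Conv1D`). A small equivalence sketch, with an explicit weight init since the module leaves `weight` uninitialized:

```python
import torch
from torch import nn

layer = JukeboxConv1D(4, 8)
nn.init.normal_(layer.weight)  # the module itself does not initialize the weight

x = torch.randn(2, 3, 4)
reference = x @ layer.weight + layer.bias
print(torch.allclose(layer(x), reference, atol=1e-6))  # True
```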
class JukeboxResConv1DBlock(nn.Module):
def __init__(self, config, conv_width, depth=1, res_scale=1.0):
super().__init__()
hidden_dim = config.res_convolution_multiplier * conv_width
dilation = config.res_dilation_growth_rate**depth
padding = dilation
self.res_scale = res_scale
self.activation = nn.ReLU()
self.conv1d_1 = nn.Conv1d(conv_width, hidden_dim, 3, 1, padding, dilation)
self.conv1d_2 = nn.Conv1d(hidden_dim, conv_width, 1, 1, 0)
def forward(self, hidden_states):
residuals = hidden_states
hidden_states = self.activation(hidden_states)
hidden_states = self.conv1d_1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.conv1d_2(hidden_states)
return residuals + self.res_scale * hidden_states
class JukeboxResnet1D(nn.Module):
def __init__(self, config, conv_width, n_depth, reverse_dilation=False):
super().__init__()
self.dilation_cycle = config.res_dilation_cycle
res_scale = 1.0 if not config.conv_res_scale else 1.0 / math.sqrt(n_depth)
blocks = []
for depth in range(n_depth):
block_depth = depth if self.dilation_cycle is None else depth % self.dilation_cycle
blocks.append(JukeboxResConv1DBlock(config, conv_width, block_depth, res_scale))
if reverse_dilation:
blocks = blocks[::-1]
self.resnet_block = nn.ModuleList(blocks)
def forward(self, hidden_states):
for block in self.resnet_block:
hidden_states = block(hidden_states)
return hidden_states
class JukeboxEncoderConvBlock(nn.Module):
def __init__(self, config, embed_dim, hidden_dim, depth, down_t, stride_t):
super().__init__()
blocks = []
filter_t = stride_t * 2
pad_t = stride_t // 2
if down_t > 0:
for i in range(down_t):
blocks.append(nn.Conv1d(embed_dim if i == 0 else hidden_dim, hidden_dim, filter_t, stride_t, pad_t))
blocks.append(JukeboxResnet1D(config, hidden_dim, depth))
self.proj_out = nn.Conv1d(hidden_dim, config.embed_dim, 3, 1, 1)
self.downsample_block = nn.ModuleList(blocks)
def forward(self, hidden_states):
for block in self.downsample_block:
hidden_states = block(hidden_states)
hidden_states = self.proj_out(hidden_states)
return hidden_states
class JukeboxEncoder(nn.Module):
def __init__(self, config, width, depth, levels, downs_t, strides_t):
super().__init__()
self.levels = levels
self.level_blocks = nn.ModuleList()
iterator = zip(list(range(self.levels)), downs_t, strides_t)
for i, down_t, stride_t in iterator:
self.level_blocks.append(
JukeboxEncoderConvBlock(
config, config.conv_input_shape if i == 0 else config.embed_dim, width, depth, down_t, stride_t
)
)
def forward(self, hidden_states):
all_hidden_states = []
for level in range(self.levels):
level_block = self.level_blocks[level]
hidden_states = level_block(hidden_states)
all_hidden_states.append(hidden_states)
return all_hidden_states
class JukeboxDecoderConvBock(nn.Module):
def __init__(self, config, embed_dim, hidden_dim, depth, down_t, stride_t, reverse_dilation=True):
self.embed_dim = embed_dim
self.hidden_dim = hidden_dim
super().__init__()
blocks = []
if down_t > 0:
filter_t = stride_t * 2
pad_t = stride_t // 2
self.proj_in = nn.Conv1d(embed_dim, hidden_dim, 3, 1, 1)
for i in range(down_t):
blocks.append(JukeboxResnet1D(config, hidden_dim, depth, reverse_dilation))
blocks.append(
nn.ConvTranspose1d(
hidden_dim, hidden_dim if i < down_t - 1 else embed_dim, filter_t, stride_t, pad_t
)
)
self.upsample_block = nn.ModuleList(blocks)
def forward(self, hidden_states):
hidden_states = self.proj_in(hidden_states)
for block in self.upsample_block:
hidden_states = block(hidden_states)
return hidden_states
class JukeboxDecoder(nn.Module):
def __init__(self, config, hidden_dim, depth, levels, downs_t, strides_t):
super().__init__()
self.levels = levels
self.level_blocks = nn.ModuleList()
for level, down_t, stride_t in zip(list(range(self.levels)), downs_t, strides_t):
self.level_blocks.append(
JukeboxDecoderConvBock(config, config.embed_dim, hidden_dim, depth, down_t, stride_t)
)
self.out = nn.Conv1d(config.embed_dim, config.conv_input_shape, 3, 1, 1)
def forward(self, hidden_states, all_levels=True):
hidden_state = hidden_states[-1]
for level in reversed(range(self.levels)):
level_block = self.level_blocks[level]
hidden_state = level_block(hidden_state)
if level != 0 and all_levels:
hidden_state = hidden_state + hidden_states[level - 1]
hidden_state = self.out(hidden_state)
return hidden_state
class JukeboxBottleneckBlock(nn.Module):
def __init__(self, config: JukeboxVQVAEConfig):
super().__init__()
self.nb_discrete_codes = config.nb_discrete_codes
self.codebook_width = config.embed_dim
self.mu = config.lmu
self.threshold = 1.0
self.init = False
self.codebook_sum = None
self.codebook_elem = None
self.register_buffer("codebook", torch.zeros(self.nb_discrete_codes, self.codebook_width))
def _tile(self, hidden_states):
dim, embed_width = hidden_states.shape
if dim < self.nb_discrete_codes:
n_repeats = (self.nb_discrete_codes + dim - 1) // dim
std = 0.01 / np.sqrt(embed_width)
hidden_states = hidden_states.repeat(n_repeats, 1)
hidden_states = hidden_states + torch.randn_like(hidden_states) * std
return hidden_states
def init_codebook(self, hidden_states):
nb_discrete_codes = self.nb_discrete_codes
self.init = True
codes = self._tile(hidden_states)
self.codebook = codes[torch.randperm(codes.shape[0])][:nb_discrete_codes]
self.codebook_sum = self.codebook
self.codebook_elem = torch.ones(nb_discrete_codes, device=self.codebook.device)
def update_codebook(self, hidden_states, latent_states):
mu, codebook_width, nb_discrete_codes = self.mu, self.codebook_width, self.nb_discrete_codes
with torch.no_grad():
latent_states_onehot = torch.zeros(nb_discrete_codes, hidden_states.shape[0], device=hidden_states.device)
latent_states_onehot.scatter_(0, latent_states.view(1, hidden_states.shape[0]), 1)
_codebook_sum = torch.matmul(latent_states_onehot, hidden_states)
_codebook_elem = latent_states_onehot.sum(dim=-1)
codes = self._tile(hidden_states)
_random_codebook = codes[torch.randperm(codes.shape[0])][:nb_discrete_codes]
old_codebook = self.codebook
self.codebook_sum = mu * self.codebook_sum + (1.0 - mu) * _codebook_sum
self.codebook_elem = mu * self.codebook_elem + (1.0 - mu) * _codebook_elem
usage = (self.codebook_elem.view(nb_discrete_codes, 1) >= self.threshold).float()
norm_code = self.codebook_sum.view(nb_discrete_codes, codebook_width) / self.codebook_elem.view(
nb_discrete_codes, 1
)
self.codebook = usage * (norm_code) + (1 - usage) * _random_codebook
_codebook_prob = _codebook_elem / torch.sum(_codebook_elem)
entropy = -torch.sum(_codebook_prob * torch.log(_codebook_prob + 1e-8))
used_curr = (_codebook_elem >= self.threshold).sum()
usage = torch.sum(usage)
dk = torch.norm(self.codebook - old_codebook) / np.sqrt(np.prod(old_codebook.shape))
return {"entropy": entropy, "used_curr": used_curr, "usage": usage, "dk": dk}
def preprocess(self, hidden_states):
hidden_states = hidden_states.permute(0, 2, 1).contiguous()
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
if hidden_states.shape[-1] == self.codebook_width:
prenorm = torch.norm(hidden_states - torch.mean(hidden_states)) / np.sqrt(np.prod(hidden_states.shape))
elif hidden_states.shape[-1] == 2 * self.codebook_width:
x1, x2 = hidden_states[..., : self.codebook_width], hidden_states[..., self.codebook_width :]
prenorm = (torch.norm(x1 - torch.mean(x1)) / np.sqrt(np.prod(x1.shape))) + (
torch.norm(x2 - torch.mean(x2)) / np.sqrt(np.prod(x2.shape))
)
hidden_states = x1 + x2
return hidden_states, prenorm
def postprocess(self, latent_states, dequantised_states, x_shape):
batch_size, time = x_shape
dequantised_states = dequantised_states.view(batch_size, time, -1).permute(0, 2, 1).contiguous()
latent_states = latent_states.view(batch_size, time)
return latent_states, dequantised_states
def quantise(self, latent_states):
codebook_weights = self.codebook.t()
distance = (
torch.sum(latent_states**2, dim=-1, keepdim=True)
- 2 * torch.matmul(latent_states, codebook_weights)
+ torch.sum(codebook_weights**2, dim=0, keepdim=True)
)
min_distance, music_tokens = torch.min(distance, dim=-1)
fit = torch.mean(min_distance)
return music_tokens, fit
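The distance expansion in `quantise` is the familiar `||z||² - 2·z·e + ||e||²`, i.e. the squared Euclidean distance from each latent vector to every codebook vector; taking the argmin selects the nearest code. A quick standalone check (random tensors, illustration only):

```python
import torch

z = torch.randn(5, 64)          # 5 latent vectors of width codebook_width
codebook = torch.randn(16, 64)  # 16 codebook vectors

dist = (z**2).sum(-1, keepdim=True) - 2 * z @ codebook.t() + (codebook**2).sum(-1)
print(torch.equal(dist.argmin(-1), torch.cdist(z, codebook).argmin(-1)))  # True (barring exact ties)
```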
def dequantise(self, music_tokens):
dequantised_states = F.embedding(music_tokens, self.codebook)
return dequantised_states
def encode(self, latent_states):
samples, _, seq_len = latent_states.shape
latent_states, _ = self.preprocess(latent_states)
music_tokens, _ = self.quantise(latent_states)
music_tokens = music_tokens.view(samples, seq_len)
return music_tokens
def decode(self, music_tokens):
samples, seq_len = music_tokens.shape
dequantised_states = self.dequantise(music_tokens)
dequantised_states = (
dequantised_states.view(samples, seq_len, self.codebook_width).permute(0, 2, 1).contiguous()
)
return dequantised_states
def forward(self, hidden_states, update_codebook=True):
samples, _, seq_len = hidden_states.shape
hidden_states, prenorm = self.preprocess(hidden_states)
if update_codebook and not self.init:
self.init_codebook(hidden_states)
music_tokens, fit = self.quantise(hidden_states)
dequantised_states = self.dequantise(music_tokens)
if update_codebook:
update_metrics = self.update_codebook(hidden_states, music_tokens)
else:
update_metrics = {}
commit_loss = torch.norm(dequantised_states.detach() - hidden_states) ** 2 / np.prod(hidden_states.shape)
dequantised_states = hidden_states + (dequantised_states - hidden_states).detach()
music_tokens, dequantised_states = self.postprocess(music_tokens, dequantised_states, (samples, seq_len))
return music_tokens, dequantised_states, commit_loss, dict(fit=fit, pn=prenorm, **update_metrics)
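The line `hidden_states + (dequantised_states - hidden_states).detach()` in the forward above is the straight-through estimator used by VQ-VAEs: the forward value is the quantised vector, while gradients flow back to the encoder output as if quantisation were the identity. A tiny check:

```python
import torch

h = torch.tensor([1.0, 2.0], requires_grad=True)  # stand-in for the encoder output
q = torch.tensor([0.5, 2.5])                      # stand-in for the dequantised states

out = h + (q - h).detach()
print(out)           # values [0.5, 2.5] -> the forward pass uses the quantised values
out.sum().backward()
print(h.grad)        # tensor([1., 1.])  -> the gradient passes straight through to h
```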
class JukeboxBottleneck(nn.Module):
def __init__(self, config, levels):
super().__init__()
self.levels = levels
self.level_blocks = nn.ModuleList()
for level in range(self.levels):
self.level_blocks.append(JukeboxBottleneckBlock(config))
def encode(self, raw_audio):
music_tokens = [
level_block.encode(hidden_states) for (level_block, hidden_states) in zip(self.level_blocks, raw_audio)
]
return music_tokens
def decode(self, music_tokens, start_level=0, end_level=None):
if end_level is None:
end_level = self.levels
quantised_audio = [
level_block.decode(z) for (level_block, z) in zip(self.level_blocks[start_level:end_level], music_tokens)
]
return quantised_audio
def forward(self, input_audio):
music_tokens, quantised_states, commit_losses, metrics = [], [], [], []
for level in range(self.levels):
level_block = self.level_blocks[-level - 1]
hidden_states = input_audio[level]
sampled_tokens, quantised_state, commit_loss, metric = level_block(
hidden_states, update_codebook=self.training
)
music_tokens.append(sampled_tokens)
if not self.training:
quantised_state = quantised_state.detach()
quantised_states.append(quantised_state)
commit_losses.append(commit_loss)
if self.training:
metrics.append(metric)
return music_tokens, quantised_states, commit_losses, metrics
JUKEBOX_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config (`JukeboxConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"""The Hierarchical VQ-VAE model used in Jukebox. This model follows the Hierarchical VQVAE paper from [Will Williams, Sam
Ringer, Tom Ash, John Hughes, David MacLeod, Jamie Dougherty](https://arxiv.org/abs/2002.08111).
""",
JUKEBOX_START_DOCSTRING,
)
class JukeboxVQVAE(PreTrainedModel):
config_class = JukeboxVQVAEConfig
base_model_prefix = "vqvae"
def _init_weights(self, module):
if isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=0.02 * self.config.init_scale)
elif isinstance(module, JukeboxConv1D):
if self.config.zero_out:
module.weight.data.zero_()
else:
module.weight.data.normal_(mean=0.0, std=0.02 * self.config.init_scale)
elif isinstance(module, JukeboxResConv1DBlock) and self.config.zero_out:
module.conv1d_2.weight.data.zero_()
module.conv1d_2.bias.data.zero_()
if isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def __init__(self, config: JukeboxVQVAEConfig):
super().__init__(config)
downs_t = config.res_downs_t
strides_t = config.res_strides_t
if not config.sample_length:
downsamples = [stride**down for stride, down in zip(strides_t, downs_t)]
top_raw_to_tokens = np.prod(downsamples)
config.sample_length = (
config.sample_length_in_seconds * config.sampling_rate // top_raw_to_tokens
) * top_raw_to_tokens
config.sample_length = config.sample_length.astype(int)
self.nb_discrete_codes = config.nb_discrete_codes
self.commit = config.commit
self.sample_length = config.sample_length
self.downsamples = [stride**down for stride, down in zip(strides_t, downs_t)]
self.hop_lengths = np.cumprod(self.downsamples)
self.levels = levels = config.levels
self.music_tokens_shapes = [
(int(self.sample_length // self.hop_lengths[-level - 1])) for level in range(levels)
]
self.multipliers = config.multipliers if config.multipliers is not None else [1] * levels
self.encoders = nn.ModuleList()
self.decoders = nn.ModuleList()
for level in range(levels):
width = config.res_conv_width * self.multipliers[level]
depth = config.res_conv_depth * self.multipliers[level]
self.encoders.append(
JukeboxEncoder(config, width, depth, level + 1, downs_t[: level + 1], strides_t[: level + 1])
)
self.decoders.append(
JukeboxDecoder(config, width, depth, level + 1, downs_t[: level + 1], strides_t[: level + 1])
)
self.bottleneck = JukeboxBottleneck(config, levels)
def _decode(self, music_tokens, start_level=0, end_level=None):
if end_level is None:
end_level = self.levels
latent_states = self.bottleneck.decode(music_tokens, start_level=start_level, end_level=end_level)
decoder, dequantised_state = self.decoders[start_level], latent_states[0:1]
dequantised_state = decoder(dequantised_state, all_levels=False)
dequantised_state = dequantised_state.permute(0, 2, 1)
return dequantised_state
def decode(self, music_tokens, start_level=0, end_level=None, bs_chunks=1) -> torch.Tensor:
"""
        Transforms the input `music_tokens` to their `raw_audio` representation.

        Args:
            music_tokens (`torch.LongTensor`):
                Tensor of music tokens which will be decoded to raw audio by using the codebook. Each music token
                should be an index to a corresponding `code` vector in the codebook.
            start_level (`int`, *optional*):
                Level at which the decoding process will start. Defaults to 0.
            end_level (`int`, *optional*):
                Level at which the decoding process will end. Defaults to None.
            bs_chunks (`int`, *optional*):
                Number of chunks to process at the same time.

        Returns:
            `torch.Tensor`: the decoded raw audio.
"""
token_chunks = [torch.chunk(token, bs_chunks, dim=0) for token in music_tokens]
dequantised_states = []
for i in range(bs_chunks):
music_tokens_i = [chunks[i] for chunks in token_chunks]
dequantised_state = self._decode(music_tokens_i, start_level=start_level, end_level=end_level)
dequantised_states.append(dequantised_state)
return torch.cat(dequantised_states, dim=0)
def _encode(self, raw_audio, start_level=0, end_level=None):
if end_level is None:
end_level = self.levels
input_audio = raw_audio.permute(0, 2, 1).float()
latent_states = []
for level in range(self.levels):
encoder = self.encoders[level]
latent_state = encoder(input_audio)
latent_states.append(latent_state[-1])
music_tokens = self.bottleneck.encode(latent_states)
return music_tokens[start_level:end_level]
    def encode(self, input_audio, start_level=0, end_level=None, bs_chunks=1):
        audio_chunks = torch.chunk(input_audio, bs_chunks, dim=0)
music_tokens_list = []
for chunk_i in audio_chunks:
music_tokens_i = self._encode(chunk_i, start_level=start_level, end_level=end_level)
music_tokens_list.append(music_tokens_i)
music_tokens = [torch.cat(music_tokens_level, dim=0) for music_tokens_level in zip(*music_tokens_list)]
return music_tokens
def sample(self, n_samples):
music_tokens = [
torch.randint(0, self.nb_discrete_codes, size=(n_samples, *music_tokens_shape), device="cpu")
for music_tokens_shape in self.music_tokens_shapes
]
return self.decode(music_tokens)
    def forward(self, raw_audio):
        input_audio = raw_audio.permute(0, 2, 1).float()
latent_states = []
for level in range(self.levels):
encoder = self.encoders[level]
latent_state = encoder(input_audio)
latent_states.append(latent_state[-1])
_, music_tokens, commit_losses, _ = self.bottleneck(latent_states)
dequantised_states = []
for level in range(self.levels):
decoder = self.decoders[level]
dequantised_state = decoder(music_tokens[level : level + 1], all_levels=False)
dequantised_states.append(dequantised_state.permute(0, 2, 1))
commit_loss = sum(commit_losses)
loss = self.commit * commit_loss
return dequantised_states, loss
class JukeboxMLP(nn.Module):
def __init__(self, config):
super().__init__()
embed_dim = config.hidden_size
hidden_dim = int(config.mlp_multiplier * embed_dim)
self.c_fc = JukeboxConv1D(embed_dim, hidden_dim)
self.c_proj = JukeboxConv1D(hidden_dim, embed_dim)
self.act = ACT2FN[config.act_fn]
self.dropout = nn.Dropout(config.resid_dropout)
def forward(self, hidden_states):
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class JukeboxLayerNorm(FusedLayerNorm):
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
super().__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
self.width = np.prod(normalized_shape)
self.max_numel = 65535 * self.width
def forward(self, input):
if input.numel() > self.max_numel:
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps).type_as(input)
else:
return super().forward(input).type_as(input)
class JukeboxAttention(nn.Module):
def __init__(self, config, n_ctx, attn_func="dense_attn"):
super().__init__()
self.embed_dim = config.hidden_size
self.n_heads = config.n_heads
self.dropout = config.attn_dropout
hidden_dim = int(config.attention_multiplier * self.embed_dim)
self.head_dim = hidden_dim // config.n_heads
self.n_ctx = n_ctx
self.hidden_dim = hidden_dim
self.scale = self.head_dim**-0.25
self.mask = config.mask
if attn_func == "cross_attention":
self.c_attn = JukeboxConv1D(self.embed_dim, hidden_dim)
self.c_enc_kv = JukeboxConv1D(self.embed_dim, hidden_dim * 2)
else:
self.c_attn = JukeboxConv1D(self.embed_dim, hidden_dim * 3)
self.c_proj = JukeboxConv1D(hidden_dim, self.embed_dim)
self.attn_dropout = nn.Dropout(config.attn_dropout)
self.resid_dropout = nn.Dropout(config.resid_dropout)
self.attn_func = attn_func
if attn_func == "cross_attention":
self.qkv = self.decode_qkv
elif attn_func == "prime_attn":
self.qkv = self.prime_qkv
else:
self.qkv = self.factored_qkv
ATTENTION_MAP = {
"dense_attn": (self.dense_attn, "autoregressive"),
"block_attn": (self.block_attn, "autoregressive"),
"transpose_block_attn": (self.transpose_block_attn, "autoregressive"),
"prev_block_attn": (self.prev_block_attn, None),
"summary_attn": (self.summary_attn, "summary"),
"summary_spread_attn": (self.summary_spread_attn, "summary"),
"cross_attention": (self.dense_attn, None),
"prime_attn": (self.prime_attn, "prime"),
}
self.attn, self.attn_mask = ATTENTION_MAP[attn_func]
self.blocks = config.blocks
self.spread = config.spread
if self.blocks is not None:
self.block_ctx = self.n_ctx // self.blocks
self.sample_t = 0
self.cache = {}
self.encoder_len = config.nb_relevant_lyric_tokens
self.record_attn = False
def _attn(self, query_states, key_states, value_states, sample):
scale = self.scale
if self.training:
attention_weight = torch.matmul(query_states * scale, key_states * scale)
else:
attention_weight = torch.matmul(query_states, key_states)
attention_weight.mul_(scale * scale)
attn_weight_type = attention_weight.dtype
attention_weight = attention_weight.float()
if self.mask:
mask = get_mask(
self.attn_mask,
query_states.size(-2),
key_states.size(-1),
self.blocks,
self.spread,
attention_weight.device,
sample,
self.sample_t,
)
if mask is not None:
attention_weight = attention_weight * mask + -1e9 * (1 - mask)
attention_prob = F.softmax(attention_weight, dim=-1).type(attn_weight_type)
if self.record_attn:
self.attention_prob = attention_prob
if self.attn_func == "prime_attn":
self.attention_prob = self.attention_prob[:, :, self.encoder_len :, : self.encoder_len]
attention_prob = self.attn_dropout(attention_prob)
context_states = torch.matmul(attention_prob, value_states)
return context_states
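Note the unusual `scale = head_dim**-0.25`: scaling both the query and the key by it is numerically equivalent to the standard `1/sqrt(head_dim)` attention scaling, while keeping the intermediate products smaller (which is gentler in reduced precision). A quick check:

```python
import math
import torch

d = 64
q = torch.randn(2, 8, 10, d)
k = torch.randn(2, 8, d, 10)
s = d**-0.25

print(torch.allclose((q * s) @ (k * s), (q @ k) / math.sqrt(d), atol=1e-5))  # True
```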
def merge_heads(self, hidden_states):
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
new_hidden_states_shape = (*hidden_states.size()[:-2], hidden_states.size(-2) * hidden_states.size(-1))
return hidden_states.view(*new_hidden_states_shape)
def split_heads(self, hidden_states, is_key=False):
new_hidden_states_shape = (
*hidden_states.size()[:-1],
self.n_heads,
hidden_states.size(-1) // self.n_heads,
)
hidden_states = hidden_states.view(*new_hidden_states_shape)
if is_key:
return hidden_states.permute(0, 2, 3, 1)
else:
return hidden_states.permute(0, 2, 1, 3)
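To make the reshaping in `split_heads`/`merge_heads` concrete, here is a small standalone sketch with assumed sizes (batch=2, seq=4, n_heads=2, head_dim=3); keys are permuted so their last two axes are (head_dim, seq), ready for `query @ key`:
```
import torch

batch, seq_len, n_heads, head_dim = 2, 4, 2, 3
hidden = torch.randn(batch, seq_len, n_heads * head_dim)

# split_heads (query/value layout): (batch, seq, hidden) -> (batch, heads, seq, head_dim)
qv = hidden.view(batch, seq_len, n_heads, head_dim).permute(0, 2, 1, 3)
# split_heads with is_key=True: (batch, heads, head_dim, seq)
k = hidden.view(batch, seq_len, n_heads, head_dim).permute(0, 2, 3, 1)
print(qv.shape, k.shape)  # torch.Size([2, 2, 4, 3]) torch.Size([2, 2, 3, 4])

# merge_heads: inverse of the query/value split
merged = qv.permute(0, 2, 1, 3).contiguous().view(batch, seq_len, n_heads * head_dim)
print(merged.shape)  # torch.Size([2, 4, 6])
```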
def dense_attn(self, query, key, value, sample):
query = self.split_heads(query)
key = self.split_heads(key, is_key=True)
value = self.split_heads(value)
context_states = self._attn(query, key, value, sample)
context_states = self.merge_heads(context_states)
return context_states
def block_attn(self, query, key, value, sample):
block_ctx = self.block_ctx
batch_size, seq_len, embed_dim = value.shape
if sample:
return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim)
else:
query_length = query.shape[1]
query = query.view(batch_size * query_length // block_ctx, block_ctx, embed_dim)
if query_length < seq_len:
seq_len = query_length
key = key[:, -seq_len:].contiguous()
value = value[:, -seq_len:].contiguous()
key = key.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim)
value = value.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim)
return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim)
def transpose_block_attn(self, query, key, value, sample):
block_ctx = self.block_ctx
batch_size, seq_len, embed_dim = value.shape
if sample:
block_len = (seq_len - 1) % block_ctx
key = key[:, block_len::block_ctx, :]
value = value[:, block_len::block_ctx, :]
return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim)
else:
query_length = query.shape[1]
query = query.view(batch_size, query_length // block_ctx, block_ctx, embed_dim)
query = query.transpose(1, 2).contiguous()
query = query.view(batch_size * block_ctx, query_length // block_ctx, embed_dim)
key = key.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)
key = key.transpose(1, 2).contiguous()
key = key.view(batch_size * block_ctx, seq_len // block_ctx, embed_dim)
value = value.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)
value = value.transpose(1, 2).contiguous()
value = value.view(batch_size * block_ctx, seq_len // block_ctx, embed_dim)
block_attn = self.dense_attn(query, key, value, sample)
block_attn = block_attn.view(batch_size, block_ctx, query_length // block_ctx, embed_dim)
block_attn = block_attn.transpose(1, 2).contiguous()
block_attn = block_attn.view(batch_size, query_length, embed_dim)
return block_attn
def prev_block_attn(self, query, key, value, sample):
block_ctx = self.block_ctx
batch_size, seq_len, embed_dim = value.shape
if sample:
block = (seq_len - 1) // block_ctx
prev_l = (block - 1) * block_ctx
if block > 0:
key = key[:, prev_l : prev_l + block_ctx, :]
value = value[:, prev_l : prev_l + block_ctx, :]
else:
key = torch.zeros(batch_size, block_ctx, embed_dim, device=query.device, dtype=query.dtype)
value = torch.zeros(batch_size, block_ctx, embed_dim, device=query.device, dtype=query.dtype)
return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim)
else:
query_length = query.shape[1]
query = query.view(batch_size * query_length // block_ctx, block_ctx, embed_dim)
key = key.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)[:, :-1, :, :]
key = torch.nn.functional.pad(key, (0, 0, 0, 0, 1, 0))
key = key.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim)
value = value.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)[:, :-1, :, :]
value = torch.nn.functional.pad(value, (0, 0, 0, 0, 1, 0))
value = value.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim)
if query_length < seq_len:
nb_query_blocks = query_length // block_ctx
nb_key_blocks = seq_len // block_ctx
seq_len = query_length
key = key.view(batch_size, nb_key_blocks, block_ctx, embed_dim)[:, -nb_query_blocks:]
key = key.contiguous().view(batch_size * nb_query_blocks, block_ctx, embed_dim)
value = value.view(batch_size, nb_key_blocks, block_ctx, embed_dim)[:, -nb_query_blocks:]
value = value.contiguous().view(batch_size * nb_query_blocks, block_ctx, embed_dim)
return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim)
def summary_attn(self, query, key, value, sample):
blocks = self.blocks
block_ctx = self.block_ctx
batch_size, seq_len, embed_dim = value.shape
if sample:
raise NotImplementedError
else:
key = key.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -1, :]
key = torch.nn.functional.pad(key, (0, 0, 1, 0))
value = value.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -1, :]
value = torch.nn.functional.pad(value, (0, 0, 1, 0))
return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim)
def summary_spread_attn(self, query, key, value, sample):
blocks = self.blocks
spread = self.spread
batch_size, seq_len, embed_dim = value.shape
if sample:
raise NotImplementedError
else:
key = key.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -spread:, :]
key = torch.nn.functional.pad(key, (0, 0, 0, 0, 1, 0)).contiguous()
value = value.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -spread:, :]
value = torch.nn.functional.pad(value, (0, 0, 0, 0, 1, 0)).contiguous()
return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim)
def prime_attn(self, query, key, value, sample):
encoder_len = self._encoder_len
key = key[:, :encoder_len]
value = value[:, :encoder_len]
return self.dense_attn(query, key, value, sample)
def factored_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False):
curr_ctx = hidden_states.shape[1]
if last_encoder_hidden_states is not None:
raise TypeError("last_encoder_hidden_states should be None")
query, key, value = hidden_states.chunk(3, dim=2)
if sample:
self.sample_t += curr_ctx
key, value = self._append_cache(key, value)
l_cache = self._suff_cache_len()
if self._cache_len() > l_cache:
self._slice_cache(-l_cache)
if curr_ctx > 1:
if self.attn_func != "dense_attn":
query = self._pad_to_block_ctx(query, query=True)
key = self._pad_to_block_ctx(key)
value = self._pad_to_block_ctx(value)
sample = False
else:
key = self.cache["key"]
value = self.cache["value"]
return query, key, value, sample
def prime_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False):
curr_ctx = hidden_states.shape[1]
if last_encoder_hidden_states is not None:
raise TypeError("last_encoder_hidden_states should be None")
query, key, value = hidden_states.chunk(3, dim=2)
if sample:
if self._cache_len() < self._encoder_len:
self._append_cache(key, value)
if self._cache_len() > self._encoder_len:
self._slice_cache(0, self._encoder_len)
key, value = self.cache["key"], self.cache["value"]
self.sample_t += curr_ctx
return query, key, value, sample
def decode_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False):
curr_ctx = hidden_states.shape[1]
query = hidden_states
if sample:
if self.sample_t == 0:
self.cache["key"], self.cache["value"] = self.c_enc_kv(
last_encoder_hidden_states.type_as(hidden_states)
).chunk(2, dim=2)
key, value = self.cache["key"], self.cache["value"]
self.sample_t += curr_ctx
else:
key, value = self.c_enc_kv(last_encoder_hidden_states.type_as(hidden_states)).chunk(2, dim=2)
return query, key, value, sample
def forward(self, hidden_states, last_encoder_hidden_states=None, sample=False):
curr_ctx = hidden_states.shape[1]
hidden_states = self.c_attn(hidden_states)
query, key, value, sample = self.qkv(
hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=sample
)
attention_scores = self.attn(query, key, value, sample)
if attention_scores.shape[1] != curr_ctx:
offset = self._offset(curr_ctx)
attention_scores = attention_scores[:, offset : offset + curr_ctx, :].contiguous()
attention_scores = self.c_proj(attention_scores)
return self.resid_dropout(attention_scores)
@property
def _encoder_len(self):
encoder_len = self.encoder_len
encoder_blocks = (encoder_len // self.blocks) + 1
return encoder_blocks * self.blocks
def _offset(self, curr_ctx):
if self.attn_func == "dense_attn":
return 0
return (self.sample_t - curr_ctx) % self.block_ctx
def _pad_to_block_ctx(self, hidden_states, query=False):
seq_len = hidden_states.shape[1]
offset = self._offset(seq_len) if query else 0
n_blocks = (seq_len + offset + self.block_ctx - 1) // self.block_ctx
pad = n_blocks * self.block_ctx - seq_len - offset
if pad == 0 and offset == 0:
return hidden_states
else:
return F.pad(hidden_states, (0, 0, offset, pad))
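A quick numeric illustration of the padding arithmetic in `_pad_to_block_ctx`, using hypothetical values (`block_ctx=8`, a query of length 5 arriving at offset 3): the sequence is left-padded by the offset and right-padded so it fills whole blocks.
```
import torch
import torch.nn.functional as F

block_ctx, seq_len, offset = 8, 5, 3
hidden = torch.randn(1, seq_len, 4)

n_blocks = (seq_len + offset + block_ctx - 1) // block_ctx   # 1
pad = n_blocks * block_ctx - seq_len - offset                # 0
padded = F.pad(hidden, (0, 0, offset, pad))
print(padded.shape)  # torch.Size([1, 8, 4]) -> exactly one block of length block_ctx
```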
def _cache_len(self):
return 0 if "key" not in self.cache else self.cache["key"].shape[1]
def _suff_cache_len(self):
"""
Precondition: the key and value have already been appended with the current context, and self.sample_t reflects
the 1-indexed sample location within the context.
"""
previous_block_length = (self.sample_t - 1) % self.block_ctx + 1 + self.block_ctx
REQUIRED_CACHE_LEN = {
"dense_attn": self.sample_t,
"block_attn": (self.sample_t - 1) % self.block_ctx + 1,
"transpose_block_attn": self.sample_t,
"prev_block_attn": self.sample_t if self.sample_t <= self.block_ctx else previous_block_length,
"cross_attn": self.encoder_len,
"prime_attn": min(self.sample_t, self._encoder_len),
}
return REQUIRED_CACHE_LEN[self.attn_func]
def _slice_cache(self, start, end=None):
self.cache["key"] = self.cache["key"][:, start:end]
self.cache["value"] = self.cache["value"][:, start:end]
def _append_cache(self, key, value):
if "key" not in self.cache:
self.cache["key"] = key
self.cache["value"] = value
else:
old_key, old_value = key, value
key = torch.cat([self.cache["key"], old_key], dim=1)
value = torch.cat([self.cache["value"], old_value], dim=1)
del self.cache["key"]
del self.cache["value"]
del old_key
del old_value
self.cache["key"] = key
self.cache["value"] = value
return self.cache["key"], self.cache["value"]
def del_cache(self):
self.sample_t = 0
if "key" in self.cache:
del self.cache["key"]
if "value" in self.cache:
del self.cache["value"]
self.cache = {}
class JukeboxBlock(nn.Module):
def __init__(self, config, n_ctx, attn_func="dense_attn"):
super().__init__()
self.width = config.hidden_size
self.attn = JukeboxAttention(config, n_ctx, attn_func=attn_func)
self.layer_norm_0 = JukeboxLayerNorm(config.hidden_size)
self.mlp = JukeboxMLP(config)
self.layer_norm_1 = JukeboxLayerNorm(config.hidden_size)
self.res_scale = 1.0 / config.num_layers if config.attn_res_scale else 1.0
self.attn_func = attn_func
def forward(self, hidden_states, last_encoder_hidden_states, sample=False):
residuals = hidden_states
hidden_states = self.layer_norm_0(hidden_states)
hidden_states = self.attn(hidden_states, last_encoder_hidden_states, sample)
output_states = self.layer_norm_1(residuals + hidden_states)
output_states = self.mlp(output_states)
if self.res_scale == 1.0:
output = residuals + hidden_states + output_states
else:
output = residuals + self.res_scale * (hidden_states + output_states)
return output
class JukeboxLayerStack(nn.Module):
def __init__(self, config, n_ctx):
super().__init__()
self.n_ctx = n_ctx
self.width = config.hidden_size
self.num_layers = config.num_layers
self.blocks = config.blocks
self.attention_pattern = config.attention_pattern
if self.blocks is not None:
self.block_ctx = n_ctx // self.blocks
self.encoder_len = config.nb_relevant_lyric_tokens
self.n_heads = config.n_heads
attention_pattern = ATTENTION_PATTERNS[self.attention_pattern]
self._attn_mods = nn.ModuleList()
for depth in range(self.num_layers):
self._attn_mods.append(JukeboxBlock(config, n_ctx, attn_func=attention_pattern(depth)))
self.saved_attn_weights = []
def set_record_attn(self, record_attn):
"""
Sets whether to record the attention softmaxes into self.saved_attn_weights.
Args:
record_attn (`Union[bool,set]`):
If a set, it indicates which layers' attention softmaxes should be recorded; if a bool, whether to record all of them.
"""
def _should_record_attn(layer_idx):
if isinstance(record_attn, bool):
return record_attn
return layer_idx in record_attn
for i, layer in enumerate(self._attn_mods):
layer.attn.record_attn = _should_record_attn(i)
if not record_attn:
self.saved_attn_weights = []
def forward(self, hidden_states, last_encoder_hidden_states=None, sample=False):
for i, attn_layer in enumerate(self._attn_mods):
if attn_layer.attn_func == "cross_attention":
hidden_states = attn_layer(
hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=sample
)
else:
hidden_states = attn_layer(hidden_states, last_encoder_hidden_states=None, sample=sample)
if attn_layer.attn.record_attn:
self.saved_attn_weights.append(attn_layer.attn.c_attn.weight)
return hidden_states
def del_cache(self):
for attn_layer in self._attn_mods:
attn_layer.attn.del_cache()
class JukeboxPositionalEmbedding(nn.Module):
def __init__(self, embed_dim, width):
super().__init__()
self.pos_emb = nn.Parameter(torch.empty((embed_dim, width)))
def forward(self):
pos_emb = self.pos_emb
return pos_emb
class JukeboxConditionalAutoregressive(nn.Module):
def __init__(
self,
config,
n_ctx=None,
embed_dim=None,
audio_conditioning=False,
metadata_conditioning=False,
is_encoder=False,
):
"""
Autoregressive model on either lyric tokens or music tokens, or both. The attention pattern should be properly
set for each configuration.
Args:
config (`JukeboxPriorConfig`):
Model configuration class with all the parameters of the model. Initializing with a config file does
not load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
n_ctx (`int`, *optional*):
Number of tokens or lyrics tokens provided in a single pass.
embed_dim (`int`, *optional*):
Either equals the dimension of the codebook, or the sum of n_vocab (lyrics) and the codebook dimension,
if the model combines lyrics and music tokens, or simply n_vocab if the model is a separate encoder.
audio_conditioning (`bool`, *optional*, defaults to `False`):
Whether or not the prior supports conditioning on audio.
metadata_conditioning (`bool`, *optional*, defaults to `False`):
Whether or not the prior supports conditioning on artist, genres, lyrics, and timing.
is_encoder (`bool`, *optional*, defaults to `False`):
Whether the model is an encoder only model.
"""
super().__init__()
self.width = config.hidden_size
self.num_layers = config.num_layers
self.n_ctx = n_ctx if n_ctx is not None else config.n_ctx
self.embed_dim = embed_dim if embed_dim is not None else config.music_vocab_size
self.embed_tokens = nn.Embedding(self.embed_dim, config.hidden_size)
self.embed_tokens_dropout = nn.Dropout(config.emb_dropout)
self.metadata_conditioning = metadata_conditioning
self.audio_conditioning = audio_conditioning
if not metadata_conditioning:
self.start_token = nn.Parameter(torch.empty((1, config.hidden_size)))
self.pos_emb = JukeboxPositionalEmbedding(self.n_ctx, config.hidden_size)
self.pos_emb_dropout = nn.Dropout(config.emb_dropout)
self.transformer = JukeboxLayerStack(config, n_ctx=self.n_ctx)
self.is_encoder = is_encoder
self.encoder_len = config.nb_relevant_lyric_tokens
if config.merged_decoder:
self.add_cond_after_transformer = False
self.share_embed_tokens_fc_proj_out = False
else:
self.add_cond_after_transformer = True
self.share_embed_tokens_fc_proj_out = True
if not is_encoder:
self.fc_proj_out = nn.Linear(config.hidden_size, self.embed_dim, bias=False)
if self.share_embed_tokens_fc_proj_out:
self.fc_proj_out.weight = self.embed_tokens.weight
self.loss = torch.nn.CrossEntropyLoss()
def forward(
self,
tokens,
audio_conditioning=None,
metadata_conditioning=None,
last_encoder_hidden_states=None,
get_preds=False,
get_acts=False,
get_sep_loss=False,
):
"""
Args:
tokens (`torch.tensor`):
Can represent music tokens, lyrics tokens or both, depending on the configuration.
"""
batch_size = tokens.shape[0]
with torch.no_grad():
tokens = tokens.view(batch_size, -1).long()
if not self.audio_conditioning:
audio_conditioning = torch.zeros(
(batch_size, 1, self.width),
device=tokens.device,
dtype=self.transformer._attn_mods[0].mlp.c_fc.weight.dtype,
)
target = tokens
hidden_states = self.embed_tokens(tokens)
hidden_states = torch.cat((hidden_states[:, -1:], hidden_states[:, :-1]), dim=1)
if self.metadata_conditioning:
hidden_states[:, 0] = metadata_conditioning.view(batch_size, self.width)
else:
hidden_states[:, 0] = self.start_token
hidden_states = (
self.embed_tokens_dropout(hidden_states) + self.pos_emb_dropout(self.pos_emb()) + audio_conditioning
)
hidden_states = self.transformer(
hidden_states, last_encoder_hidden_states=last_encoder_hidden_states
)
if self.add_cond_after_transformer:
hidden_states = hidden_states + audio_conditioning
activations = hidden_states
if self.is_encoder:
return hidden_states
hidden_states = self.fc_proj_out(hidden_states)
loss_fn = nn.CrossEntropyLoss()
if get_sep_loss:
lyric_hidden_states = hidden_states[:, : self.encoder_len].reshape(-1, self.embed_dim)
token_hidden_states = hidden_states[:, self.encoder_len :].reshape(-1, self.embed_dim)
lyric_loss = loss_fn(lyric_hidden_states, target[:, : self.encoder_len].reshape(-1)) / np.log(2.0)
music_token_loss = loss_fn(token_hidden_states, target[:, self.encoder_len :].reshape(-1)) / np.log(2.0)
loss = (lyric_loss, music_token_loss)
else:
loss = loss_fn(hidden_states.view(-1, self.embed_dim), target.view(-1)) / np.log(2.0)
if get_preds:
return loss, hidden_states
elif get_acts:
return loss, activations
else:
return loss, None
def get_emb(self, sample_t, n_samples, tokens, audio_conditioning, metadata_conditioning):
if sample_t == 0:
hidden_states = torch.empty(n_samples, 1, self.width, dtype=self.embed_tokens.weight.dtype).to(
self.embed_tokens.weight.device
)
if self.metadata_conditioning:
hidden_states[:, 0] = metadata_conditioning.view(n_samples, self.width)
else:
hidden_states[:, 0] = self.start_token
else:
hidden_states = self.embed_tokens(tokens)
if audio_conditioning.shape == (n_samples, self.n_ctx, self.width):
cond = audio_conditioning[:, sample_t : sample_t + 1, :]
else:
cond = audio_conditioning
hidden_states = hidden_states + self.pos_emb()[sample_t : sample_t + 1] + cond
return hidden_states, cond
def sample(
self,
n_samples,
audio_conditioning=None,
metadata_conditioning=None,
last_encoder_hidden_states=None,
temp=1.0,
top_k=0,
top_p=0.0,
get_preds=False,
sample_tokens=None,
):
if sample_tokens is None:
sample_tokens = self.n_ctx
if not self.audio_conditioning:
audio_conditioning = torch.zeros(
(n_samples, 1, self.width), dtype=self.transformer._attn_mods[0].mlp.c_fc.weight.dtype
).to(self.fc_proj_out.device)
with torch.no_grad():
sampled_tokens = []
tokens = None
if get_preds:
preds = []
iter = tqdm(range(0, sample_tokens), leave=False)
for sample_t in iter:
iter.set_description(f"Ancestral sampling {sample_tokens} music tokens", refresh=True)
hidden_states, cond = self.get_emb(
sample_t, n_samples, tokens, audio_conditioning, metadata_conditioning
)
hidden_states = self.transformer(
hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=True
)
if self.add_cond_after_transformer:
hidden_states = hidden_states + cond
hidden_states = self.fc_proj_out(hidden_states)
if get_preds:
preds.append(hidden_states.clone())
hidden_states = hidden_states / temp
hidden_states = filter_logits(hidden_states, top_k=top_k, top_p=top_p)
tokens = torch.distributions.Categorical(logits=hidden_states).sample()
sampled_tokens.append(tokens.clone())
del tokens
self.transformer.del_cache()
tokens = torch.cat(sampled_tokens, dim=1)
if get_preds:
preds = torch.cat(preds, dim=1)
if get_preds:
return tokens, preds
else:
return tokens
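`filter_logits`, used in the sampling loop above, is defined elsewhere in this file; purely as an illustration of the idea, a generic top-k filter (hypothetical helper, not the library's exact implementation) keeps only the `top_k` largest logits per row and pushes the rest to minus infinity before sampling:
```
import torch

def top_k_filter(logits, top_k):
    """Keep only the top_k largest logits per row; set the rest to -inf (generic sketch)."""
    if top_k <= 0:
        return logits
    kth_best = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < kth_best, float("-inf"))

logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
print(top_k_filter(logits, 2))  # tensor([[-inf, 3., 2., -inf]])
```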
def split_chunks(self, length, chunk_size):
n_passes = (length + chunk_size - 1) // chunk_size
chunk_sizes = [*[chunk_size] * (n_passes - 1), (length - 1) % chunk_size + 1]
return chunk_sizes
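For reference, `split_chunks` always emits full chunks first and puts the (never empty) remainder last; a quick check with toy values, using a standalone copy of the method above:
```
def split_chunks(length, chunk_size):
    # Standalone copy of the method above, for illustration.
    n_passes = (length + chunk_size - 1) // chunk_size
    return [*[chunk_size] * (n_passes - 1), (length - 1) % chunk_size + 1]

print(split_chunks(10, 4))  # [4, 4, 2]
print(split_chunks(8, 4))   # [4, 4] -- the last chunk is never 0
```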
def primed_sample(
self,
n_samples,
lyric_and_music_tokens,
audio_conditioning=None,
metadata_conditioning=None,
last_encoder_hidden_states=None,
temp=1.0,
top_k=0,
top_p=0.0,
get_preds=False,
chunk_size=None,
sample_tokens=None,
):
class JukeboxMusicTokenConditioner(nn.Module):
"""
The `JukeboxMusicTokenConditioner` takes music tokens as an input (corresponding to the codes of the VQVAE's
codebook) and upsamples them using a single layer of decoder convolution block (the same as is used in the VQVAE).
"""
def __init__(self, config, level):
super().__init__()
self.embed_tokens = nn.Embedding(config.music_vocab_size, config.hidden_size)
config.embed_dim = config.music_vocab_size
self.upsampler = JukeboxDecoderConvBock(
config,
config.hidden_size,
config.res_conv_width,
config.res_conv_depth,
config.res_downs_t[level],
config.res_strides_t[level],
reverse_dilation=False,
)
self.layer_norm = JukeboxLayerNorm(config.hidden_size)
def forward(self, music_tokens, raw_audio_conditionning=None):
"""
Args:
music_tokens (`torch.LongTensor`):
Music tokens from the upper level, in range(nb_discrete_codes)
raw_audio_conditionning (`torch.LongTensor`, *optional*):
Audio used when primed sampling, raw audio information that conditions the generation
"""
if raw_audio_conditionning is None:
raw_audio_conditionning = 0.0
music_tokens = music_tokens.long()
hidden_states = self.embed_tokens(music_tokens)
hidden_states = hidden_states + raw_audio_conditionning
hidden_states = hidden_states.permute(0, 2, 1)
hidden_states = self.upsampler(hidden_states)
hidden_states = hidden_states.permute(0, 2, 1)
hidden_states = self.layer_norm(hidden_states)
return hidden_states
class JukeboxRangeEmbedding(nn.Module):
"""
The `JukeboxRangeEmbedding` interpolates the given [pos_start, pos_end] to obtain an equivalent of time positional
embedding of length `n_ctx`.
Binning process: for each pos in the position tensor, find its bin. [start, end) is mapped to [0, 1, ..., bins-1]
via [start, end) -> [0, 1) -> [0, bins) -> floor -> [0, ..., bins-1]. NOTE: the interval is open-ended on the
right, so start <= pos < end, not pos <= end.
"""
def __init__(self, n_time, embed_dim, range, out_width, clamp=False):
super().__init__()
self.emb = nn.Embedding(embed_dim, out_width)
self.n_time = n_time
self.embed_dim = embed_dim
self.pos_min, self.pos_max = range
self.clamp = clamp
def forward(self, pos_start, pos_end=None):
if not len(pos_start.shape) == 2:
raise TypeError(f"Expected shape with 2 dims, got {pos_start.shape}")
if not ((self.pos_min <= pos_start).all() and (pos_start < self.pos_max).all()):
raise TypeError(f"Range is [{self.pos_min},{self.pos_max}), got {pos_start}")
pos_start = pos_start.float()
if pos_end is not None:
if self.clamp:
pos_end = pos_end.clamp(self.pos_min, self.pos_max)
pos_end = pos_end.float()
n_time = self.n_time
if n_time != 1:
interpolation = (
torch.arange(0, n_time, dtype=torch.float, device=pos_start.device).view(1, n_time) / n_time
)
position = pos_start + (pos_end - pos_start) * interpolation
else:
position = pos_start
normalised_position = (position - self.pos_min) / (self.pos_max - self.pos_min)
bins_ = (self.embed_dim * normalised_position).floor().long().detach()
return self.emb(bins_)
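A small standalone sketch of the binning described in the class docstring, with made-up values (range [0, 100) and 16 bins): positions are normalised to [0, 1) and floored into one of `embed_dim` bins.
```
import torch

pos_min, pos_max, embed_dim = 0.0, 100.0, 16
position = torch.tensor([[0.0, 25.0, 50.0, 99.9]])

normalised = (position - pos_min) / (pos_max - pos_min)   # values in [0, 1)
bins_ = (embed_dim * normalised).floor().long()
print(bins_)  # tensor([[ 0,  4,  8, 15]])
```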
class JukeboxLabelConditioner(nn.Module):
def __init__(self, config, include_time_signal):
super().__init__()
embed_dim = config.hidden_size
timing_dims = config.timing_dims
sampling_rate = config.sampling_rate
nb_genres, nb_artists = config.metadata_dims
music_tokens_shape = config.n_ctx
self.max_nb_genres = config.max_nb_genres
self.bow_genre_emb = nn.Embedding(nb_genres, embed_dim)
self.artist_emb = nn.Embedding(nb_artists, embed_dim)
self.include_time_signal = include_time_signal
if self.include_time_signal:
total_length_range = (config.min_duration * sampling_rate, config.max_duration * sampling_rate)
absolute_pos_range = (0.0, config.max_duration * sampling_rate)
relative_pos_range = (0.0, 1.0)
self.total_length_emb = JukeboxRangeEmbedding(1, timing_dims, total_length_range, embed_dim)
self.absolute_pos_emb = JukeboxRangeEmbedding(
music_tokens_shape, timing_dims, absolute_pos_range, embed_dim
)
self.relative_pos_emb = JukeboxRangeEmbedding(
music_tokens_shape, timing_dims, relative_pos_range, embed_dim, clamp=True
)
def forward(self, metadata):
total_length = metadata[:, 0:1]
offset = metadata[:, 1:2]
length = metadata[:, 2:3]
artist = metadata[:, 3:4]
genre = metadata[:, 4:]
artist_emb = self.artist_emb(artist)
mask = (genre >= 0).float().unsqueeze(2)
genre_emb = (self.bow_genre_emb(genre.clamp(0)) * mask).sum(dim=1, keepdim=True)
start_emb = genre_emb + artist_emb
if self.include_time_signal:
start, end = offset, offset + length
total_length = total_length.float()
start = start.float()
end = end.float()
pos_emb = (
self.total_length_emb(total_length)
+ self.absolute_pos_emb(start, end)
+ self.relative_pos_emb(start / total_length, end / total_length)
)
else:
pos_emb = None
return start_emb, pos_emb
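The slicing in `forward` above assumes each metadata row is laid out as `[total_length, offset, sample_length, artist_id, genre_id_0, ..., genre_id_{n-1}]`, with unused genre slots set to -1 so that the mask removes them from the bag-of-words genre sum. A hypothetical row for illustration:
```
import torch

# total_length, offset, sample_length, artist_id, then up to max_nb_genres genre ids (-1 = padding)
metadata = torch.tensor([[4_410_000, 0, 1_048_576, 10, 32, -1, -1, -1, -1]])

genre = metadata[:, 4:]
mask = (genre >= 0).float().unsqueeze(2)   # only the first genre contributes to the sum
print(mask.squeeze(-1))  # tensor([[1., 0., 0., 0., 0.]])
```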
class JukeboxPrior(PreTrainedModel):
config_class = JukeboxPriorConfig
def _init_weights(self, module):
init_scale = self.config.init_scale
if isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=0.02 * init_scale)
elif isinstance(module, JukeboxConv1D):
if self.config.zero_out:
module.weight.data.zero_()
else:
module.weight.data.normal_(mean=0.0, std=0.02 * init_scale)
elif isinstance(module, JukeboxPositionalEmbedding):
module.pos_emb.data.normal_(mean=0.0, std=0.01 * init_scale)
elif isinstance(module, JukeboxRangeEmbedding):
module.emb.weight.data.normal_(mean=0.0, std=0.01 * init_scale)
elif isinstance(module, JukeboxConditionalAutoregressive) and hasattr(module, "lm_head"):
module.lm_head.weight.data.normal_(mean=0.0, std=0.02 * init_scale)
elif isinstance(module, JukeboxConditionalAutoregressive) and hasattr(module, "start_token"):
module.start_token.data.normal_(mean=0.0, std=0.01 * init_scale)
elif isinstance(module, JukeboxResConv1DBlock) and self.config.zero_out:
module.conv1d_2.weight.data.zero_()
module.conv1d_2.bias.data.zero_()
if isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def get_metadata(self, labels, start, total_length, offset, get_indices=False):
metadata = labels.clone()
metadata[:, 0] = total_length
metadata[:, 2] = int(self.sample_length)
metadata[:, 1:2] = int(offset * self.raw_to_tokens) + int(start * self.raw_to_tokens)
metadata, indices = self.set_metadata_lyric_tokens(metadata)
if get_indices:
return metadata, indices
else:
return metadata
def set_metadata_lyric_tokens(self, labels):
"""
Processes the full labels, extracting only the relevant lyric tokens while keeping the metadata conditioning tokens.
"""
if self.nb_relevant_lyric_tokens > 0:
tokens_list = torch.zeros(
(labels.shape[0], self.nb_relevant_lyric_tokens), dtype=torch.long, device=labels.device
)
indices_list = []
for idx in range(labels.shape[0]):
full_tokens = labels.clone()[:, 4 + self.metadata_embedding.max_nb_genres :]
total_length, offset, duration = labels[idx, 0], labels[idx, 1], labels[idx, 2]
tokens, indices = get_relevant_lyric_tokens(
full_tokens, self.nb_relevant_lyric_tokens, total_length, offset, duration
)
tokens_list[idx, :] = tokens
indices_list.append(indices)
return (
torch.cat((labels[:, : 4 + self.metadata_embedding.max_nb_genres], tokens_list), dim=-1),
indices_list,
)
else:
return labels, None
def get_music_tokens_conds(self, music_tokens, start, end):
"""
Extracts the conditioning music tokens for the current level.
"""
if self.level != 0:
music_tokens_cond = music_tokens[self.level - 1]
music_tokens = music_tokens_cond[:, start // self.cond_downsample : end // self.cond_downsample]
missing_cond_len = self.n_ctx // self.cond_downsample - music_tokens_cond[-1].shape[-1]
if missing_cond_len > 0:
init_cond = torch.zeros(1, missing_cond_len).to(music_tokens_cond.device)
music_tokens_cond = torch.cat((music_tokens_cond, init_cond), dim=-1).long()
music_tokens_conds = [music_tokens_cond]
else:
music_tokens_conds = None
return music_tokens_conds
def prior_preprocess(self, tokens, conds):
"""
Shifts the input tokens to account for the dictionary merge. The embed_dim_shift gives by how much the music
tokens should be shifted. It is equal to `lyric_vocab_size`.
"""
batch_size = tokens[0].shape[0]
for i in range(len(tokens)):
tokens[i] = (tokens[i] + int(self.embed_dim_shift[i])).view(batch_size, -1)
for i in range(len(conds)):
if conds[i] is None:
conds[i] = torch.zeros(
(batch_size, self.input_shapes[i], self.width), dtype=tokens[0].dtype, device=tokens[0].device
)
return torch.cat(tokens, dim=1), torch.cat(conds, dim=1)
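A toy sketch of the dictionary-merge shift performed by `prior_preprocess`, assuming a hypothetical lyric vocabulary of size 80: lyric ids keep their range and music ids are shifted up by `lyric_vocab_size`, so both can share a single embedding table.
```
import torch

lyric_vocab_size = 80                      # hypothetical size, for illustration only
embed_dim_shift = [0, lyric_vocab_size]

lyric_tokens = torch.tensor([[3, 17, 42]])
music_tokens = torch.tensor([[0, 5, 2047]])

merged = torch.cat(
    [lyric_tokens + embed_dim_shift[0], music_tokens + embed_dim_shift[1]], dim=1
)
print(merged)  # tensor([[   3,   17,   42,   80,   85, 2127]])
```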
def prior_postprocess(self, tokens):
"""
Shifts back the input tokens if the model uses an encoder decoder architecture. As the embedding layer is
shared, `prior_embed_dim_shift` shifts the music token ids by `lyric_vocab_size`. Only returns the music
tokens.
"""
batch_size = tokens.shape[0]
dims = (self.input_shapes[0], tokens.shape[1] - self.input_shapes[0])
tokens = list(torch.split(tokens, dims, dim=1))
for i in range(len(tokens)):
bins_shift = int(self.embed_dim_shift[i])
tokens[i] = (tokens[i] - bins_shift).view(batch_size, -1)
tokens[i] = torch.clamp(tokens[i], min=0)
return tokens[-1]
def embed_tokens(self, music_tokens_conds):
"""
Embeds the upper level music tokens and upsamples them to provide as audio conditioning.
"""
music_tokens_conds = music_tokens_conds[: self.cond_level + 1]
audio_conditioning = None
for music_tokens_cond, conditioner_block in reversed(list(zip(music_tokens_conds, [self.conditioner_blocks]))):
audio_conditioning = conditioner_block(music_tokens_cond, audio_conditioning)
return audio_conditioning
def encode(self, hidden_states, start_level=None, end_level=None, bs_chunks=1):
"""
Encodes the hidden states (raw audio) using the VQVAE's encoder. Returns latent_states.
"""
if start_level is None:
start_level = self.level
if end_level is None:
end_level = self.levels
with torch.no_grad():
latent_states = self.vqvae_encoder(
hidden_states, start_level=start_level, end_level=end_level, bs_chunks=bs_chunks
)
return latent_states
def decode(self, music_tokens, start_level=None, end_level=None, bs_chunks=1):
"""
Upsamples the sequence of codebook vectors to raw audio.
"""
if start_level is None:
start_level = self.level
if end_level is None:
end_level = self.levels
with torch.no_grad():
output = self.vqvae_decoder(
music_tokens, start_level=start_level, end_level=end_level, bs_chunks=bs_chunks
)
return output
def get_cond(self, music_tokens_conds, metadata):
"""
Converts the input tokens to input_embeddings. Splits the lyrics from the rest of the metadata. Lyric tokens
can be None.
"""
if metadata is not None:
n_labels = metadata.shape[1] - self.nb_relevant_lyric_tokens
metadata, lyric_tokens = metadata[:, :n_labels], metadata[:, n_labels:]
else:
metadata, lyric_tokens = None, None
metadata_conditioning, metadata_pos = (
self.metadata_embedding(metadata) if self.metadata_conditioning else (None, None)
)
audio_conditioning = self.embed_tokens(music_tokens_conds) if self.audio_conditioning else metadata_pos
return audio_conditioning, metadata_conditioning, lyric_tokens
def sample(
self,
n_samples,
music_tokens=None,
music_tokens_conds=None,
metadata=None,
temp=1.0,
top_k=0,
top_p=0.0,
chunk_size=None,
sample_tokens=None,
):
def get_encoder_states(self, lyric_tokens, sample=False):
"""
Retrieves the last hidden_states of the lyric encoder that will be attended to by the decoder. Forwards through
the lyric encoder.
"""
if self.nb_relevant_lyric_tokens != 0 and self.lyric_conditioning:
if sample:
self.encoder = self.encoder.to(lyric_tokens.device)
lyric_acts = self.encoder(lyric_tokens, None, None, None)
lyric_acts = self.encoder.proj_in(lyric_acts)
last_encoder_hidden_states = self.encoder.final_layer_norm(lyric_acts)
else:
last_encoder_hidden_states = None
return last_encoder_hidden_states
def get_encoder_loss(self, last_encoder_hidden_states, target_lyrics):
"""
Computes the loss for the lyric encoder: next lyric token prediction.
"""
if self.lyric_conditioning:
last_encoder_hidden_states = self.encoder.lm_head(last_encoder_hidden_states)
encoder_loss = nn.functional.cross_entropy(
last_encoder_hidden_states.view(-1, self.encoder_dim), target_lyrics.view(-1)
) / np.log(2.0)
else:
encoder_loss = torch.tensor(0.0, device=last_encoder_hidden_states.device)
return encoder_loss
def forward_tokens(
self, music_tokens, music_tokens_conds=[], metadata=None, get_preds=False, get_attn_weights=False
):
"""
Applies a forward pass using the conditioning tokens. Different from the classic forward as it does not use the
vqvae's encoding layers.
"""
if get_attn_weights:
self.prior.transformer.set_record_attn(get_attn_weights)
audio_conditioning, metadata_conditioning, lyric_tokens = self.get_cond(music_tokens_conds, metadata)
if self.is_encoder_decoder:
tokens, audio_conditioning = self.prior_preprocess(
[lyric_tokens, music_tokens], [None, audio_conditioning]
)
(encoder_loss, next_token_prediction_loss), preds = self.prior(
tokens, audio_conditioning, metadata_conditioning, get_sep_loss=True, get_preds=get_preds
)
else:
last_encoder_hidden_states = self.get_encoder_states(lyric_tokens)
encoder_loss = self.get_encoder_loss(last_encoder_hidden_states, lyric_tokens)
next_token_prediction_loss, preds = self.prior(
music_tokens,
audio_conditioning,
metadata_conditioning,
last_encoder_hidden_states,
get_preds=get_preds,
)
loss = self.encoder_loss_fraction * encoder_loss * self.nb_relevant_lyric_tokens / self.total_loss_dims
loss += next_token_prediction_loss * self.next_token_prediction_loss_dims / self.total_loss_dims
metrics = {
"bpd": next_token_prediction_loss.clone().detach(),
"encoder_loss": encoder_loss.clone().detach(),
"next_token_prediction_loss": next_token_prediction_loss.clone().detach(),
}
if get_preds:
metrics["preds"] = preds.clone().detach()
if get_attn_weights:
saved_attn_weights = self.prior.transformer.saved_attn_weights
self.prior.transformer.set_record_attn(False)
return saved_attn_weights
else:
return loss, metrics
def forward(
self,
hidden_states: torch.Tensor,
metadata: Optional[List[torch.LongTensor]] = None,
decode: bool = False,
get_preds: bool = False,
) -> List[torch.Tensor]:
"""
Encode the hidden states using the `vqvae` encoder, and then predicts the next token in the `forward_tokens`
function. The loss is the sum of the `encoder` loss and the `decoder` loss.
Args:
hidden_states (`torch.Tensor`):
Hidden states which should be raw audio
metadata (`List[torch.LongTensor]`, *optional*):
List containing the metadata conditioning tensor with the lyric and the metadata tokens.
decode (`bool`, *optional*, defaults to `False`):
Whether or not to decode the encoded to tokens.
get_preds (`bool`, *optional*, defaults to `False`):
Whether or not to return the actual predictions of the model.
"""
batch_size = hidden_states.shape[0]
music_tokens, *music_tokens_conds = self.encode(hidden_states, bs_chunks=batch_size)
loss, metrics = self.forward_tokens(
music_tokens=music_tokens,
music_tokens_conds=music_tokens_conds,
metadata=metadata,
get_preds=get_preds,
)
if decode:
dequantised_states = self.decode([music_tokens, *music_tokens_conds])
else:
dequantised_states = None
return dequantised_states, loss, metrics
class JukeboxPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = JukeboxConfig
base_model_prefix = "jukebox"
supports_gradient_checkpointing = False
def _init_weights(self, module):
if isinstance(module, JukeboxPrior) or isinstance(module, JukeboxVQVAE):
module.apply(module._init_weights)
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
JUKEBOX_SAMPLING_INPUT_DOCSTRING = r"""
labels (`List[torch.LongTensor]` of length `n_sample`, and shape `(self.levels, self.config.max_nb_genre + lyric_sequence_length)` :
List of metadata such as `artist_id`, `genre_id` and the full list of lyric tokens which are used to
condition the generation.
sampling_kwargs (`Dict[Any]`):
Various additional sampling arguments that are used by the `_sample` function. A detail list of the
arguments can be seen in the [`_sample`] function documentation.
"""
@add_start_docstrings(
"""The bare JUKEBOX Model used for music generation. 4 sampling techniques are supported : `primed_sample`, `upsample`,
`continue_sample` and `ancestral_sample`. It does not have a `forward` method as the training is not end to end. If
you want to fine-tune the model, it is recommended to use the `JukeboxPrior` class and train each prior
individually.
""",
JUKEBOX_START_DOCSTRING,
)
class JukeboxModel(JukeboxPreTrainedModel):
_no_split_modules = ["JukeboxBlock"]
def __init__(self, config):
super().__init__(config)
vqvae_config = config.vqvae_config
self.vqvae = JukeboxVQVAE(vqvae_config)
self.set_shared_params(config)
self.priors = nn.ModuleList(
[JukeboxPrior(config.prior_configs[level], level) for level in range(config.nb_priors)]
)
def set_shared_params(self, model_config):
"""
Initialises the parameters that are shared. This has to be done here because the list of `JukeboxPriorConfig`
is nested, and is thus unreachable in the `from_dict` function.
"""
for config in model_config.prior_configs:
config.sampling_rate = model_config.sampling_rate
config.timing_dims = model_config.timing_dims
config.min_duration = model_config.min_duration
config.max_duration = model_config.max_duration
config.max_nb_genres = model_config.max_nb_genres
config.metadata_conditioning = model_config.metadata_conditioning
def decode(self, music_tokens, start_level=0, end_level=None, bs_chunks=1):
return self.vqvae.decode(music_tokens, start_level, end_level, bs_chunks)
def encode(self, input_audio, start_level=0, end_level=None, bs_chunks=1):
return self.vqvae.encode(input_audio, start_level, end_level, bs_chunks)
def split_batch(self, obj, n_samples, split_size):
n_passes = (n_samples + split_size - 1) // split_size
if isinstance(obj, torch.Tensor):
return torch.split(obj, split_size, dim=0)
elif isinstance(obj, list):
return list(zip(*[torch.split(item, split_size, dim=0) for item in obj]))
elif obj is None:
return [None] * n_passes
else:
raise TypeError("Unknown input type")
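A quick illustration of how `split_batch` behaves for the three supported input types, using assumed sizes (a batch of 5 split into passes of at most 2); the helper below is a standalone copy of the method above:
```
import torch

def split_batch(obj, n_samples, split_size):
    # Standalone copy of the method above, for illustration.
    n_passes = (n_samples + split_size - 1) // split_size
    if isinstance(obj, torch.Tensor):
        return torch.split(obj, split_size, dim=0)
    elif isinstance(obj, list):
        return list(zip(*[torch.split(item, split_size, dim=0) for item in obj]))
    elif obj is None:
        return [None] * n_passes
    raise TypeError("Unknown input type")

tokens = torch.arange(5).unsqueeze(1)                     # batch of 5 samples
print([t.shape[0] for t in split_batch(tokens, 5, 2)])    # [2, 2, 1]
print(split_batch(None, 5, 2))                            # [None, None, None]
```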
def sample_partial_window(
self, music_tokens, labels, offset, sampling_kwargs, level, tokens_to_sample, max_batch_size
):
prior = self.priors[level]
sampled_tokens = music_tokens[level]
n_ctx = prior.n_ctx
nb_sampled_tokens = sampled_tokens.shape[1]
if nb_sampled_tokens < n_ctx - tokens_to_sample:
sampling_kwargs["sample_tokens"] = nb_sampled_tokens + tokens_to_sample
start = 0
else:
sampling_kwargs["sample_tokens"] = n_ctx
start = nb_sampled_tokens - n_ctx + tokens_to_sample
return self.sample_single_window(music_tokens, labels, offset, sampling_kwargs, level, start, max_batch_size)
def sample_single_window(self, music_tokens, labels, offset, sampling_kwargs, level, start, max_batch_size):
prior = self.priors[level]
n_samples = music_tokens[0].shape[0]
n_ctx = prior.n_ctx
end = start + n_ctx
previous_sampled_tokens = music_tokens[level][:, start:end]
sample_tokens = sampling_kwargs.get("sample_tokens", None)
if "sample_tokens" in sampling_kwargs:
sample_tokens = end - start
conditioning_tokens = previous_sampled_tokens.shape[1]
new_tokens = sample_tokens - previous_sampled_tokens.shape[1]
logger.info(
f"Sampling {sample_tokens} tokens for [{start},{start+sample_tokens}]. Conditioning on"
f" {conditioning_tokens} tokens"
)
if new_tokens <= 0:
return music_tokens
music_tokens_conds = prior.get_music_tokens_conds(music_tokens, start, end)
metadata = prior.get_metadata(labels, start, self.total_length, offset)
music_tokens_list = self.split_batch(previous_sampled_tokens, n_samples, max_batch_size)
music_tokens_conds_list = self.split_batch(music_tokens_conds, n_samples, max_batch_size)
metadata_list = self.split_batch(metadata, n_samples, max_batch_size)
tokens = []
iterator = tqdm(zip(music_tokens_list, music_tokens_conds_list, metadata_list), leave=False)
for music_tokens_i, music_tokens_conds_i, metadata_i in iterator:
name = ["Ancestral", "Primed"][music_tokens_i.shape[1] == 0]
iterator.set_description(
f"[prior level {level}] {name} Sampling {sample_tokens} tokens out of"
f" {self.total_length//prior.raw_to_tokens}",
refresh=True,
)
tokens_i = prior.sample(
n_samples=music_tokens_i.shape[0],
music_tokens=music_tokens_i,
music_tokens_conds=music_tokens_conds_i,
metadata=metadata_i,
**sampling_kwargs,
)
tokens.append(tokens_i)
sampled_tokens = torch.cat(tokens, dim=0)
music_tokens_new = sampled_tokens[:, -new_tokens:]
music_tokens[level] = torch.cat([music_tokens[level], music_tokens_new], dim=1)
return music_tokens
def sample_level(
self, music_tokens, labels, offset, sampling_kwargs, level, total_length, hop_length, max_batch_size
):
if total_length >= self.priors[level].n_ctx:
iterator = get_starts(total_length, self.priors[level].n_ctx, hop_length)
for start in iterator:
music_tokens = self.sample_single_window(
music_tokens, labels, offset, sampling_kwargs, level, start, max_batch_size
)
else:
music_tokens = self.sample_partial_window(
music_tokens, labels, offset, sampling_kwargs, level, total_length, max_batch_size
)
return music_tokens
@torch.no_grad()
def _sample(
self,
music_tokens,
labels,
sample_levels,
metas=None,
chunk_size=32,
sampling_temperature=0.98,
lower_batch_size=16,
max_batch_size=16,
sample_length_in_seconds=24,
compute_alignments=False,
sample_tokens=None,
offset=0,
save_results=True,
sample_length=None,
):
@add_start_docstrings(
"""
Generates music tokens based on the provided `labels`. Will start at the desired prior level and automatically
upsample the sequence. If you want to create the audio, you should call `model.decode(tokens)`, which will use
the VQ-VAE decoder to convert the music tokens to raw audio.
Args:
labels (`List[torch.LongTensor]`) :
List of length `n_sample`, and shape `(self.levels, 4 + self.config.max_nb_genre +
lyric_sequence_length)` metadata such as `artist_id`, `genre_id` and the full list of lyric tokens
which are used to condition the generation.
n_samples (`int`, *optional*, default to 1) :
Number of samples to be generated in parallel.
""",
)
def ancestral_sample(self, labels, n_samples=1, **sampling_kwargs) -> List[torch.LongTensor]:
"""
Example:
```
>>> from transformers import AutoTokenizer, JukeboxModel, set_seed
>>> model = JukeboxModel.from_pretrained("openai/jukebox-1b-lyrics", min_duration=0).eval()
>>> tokenizer = AutoTokenizer.from_pretrained("openai/jukebox-1b-lyrics")
>>> lyrics = "Hey, are you awake? Can you talk to me?"
>>> artist = "Zac Brown Band"
>>> genre = "Country"
>>> metas = tokenizer(artist=artist, genres=genre, lyrics=lyrics)
>>> set_seed(0)
>>> music_tokens = model.ancestral_sample(metas.input_ids, sample_length=400)
>>> with torch.no_grad():
... model.decode(music_tokens)[:, :10].squeeze(-1)
tensor([[-0.0219, -0.0679, -0.1050, -0.1203, -0.1271, -0.0936, -0.0396, -0.0405,
-0.0818, -0.0697]])
```
"""
sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors))))
music_tokens = [
torch.zeros(n_samples, 0, dtype=torch.long, device=labels[0].device) for _ in range(len(self.priors))
]
music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
return music_tokens
@add_start_docstrings(
"""Generates a continuation of the previously generated tokens.
Args:
music_tokens (`List[torch.LongTensor]` of length `self.levels` ) :
A sequence of music tokens which will be used as context to continue the sampling process. Should have
`self.levels` tensors, each corresponding to the generation at a certain level.
""",
JUKEBOX_SAMPLING_INPUT_DOCSTRING,
)
def continue_sample(self, music_tokens, labels, **sampling_kwargs) -> List[torch.LongTensor]:
sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors))))
music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
return music_tokens
@add_start_docstrings(
"""Upsamples a sequence of music tokens using the prior at level `level`.
Args:
music_tokens (`List[torch.LongTensor]` of length `self.levels` ) :
A sequence of music tokens which will be used as context to continue the sampling process. Should have
`self.levels` tensors, each corresponding to the generation at a certain level.
""",
JUKEBOX_SAMPLING_INPUT_DOCSTRING,
)
def upsample(self, music_tokens, labels, **sampling_kwargs) -> List[torch.LongTensor]:
sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors) - 1)))
music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
return music_tokens
@add_start_docstrings(
"""Generate a raw audio conditioned on the provided `raw_audio` which is used as conditioning at each of the
generation levels. The audio is encoded to music tokens using the 3 levels of the VQ-VAE. These tokens are
used: as conditioning for each level, which means that no ancestral sampling is required.
Args:
raw_audio (`List[torch.Tensor]` of length `n_samples` ) :
A list of raw audio that will be used as conditioning information for each sample that will be
generated.
""",
JUKEBOX_SAMPLING_INPUT_DOCSTRING,
)
# The decorator above attaches the sampling docstring to `primed_sample`, describing its purpose, parameters
# and return value.
def primed_sample(self, raw_audio, labels, **sampling_kwargs) -> List[torch.LongTensor]:
# Generates music tokens conditioned on the provided `raw_audio`. Returns a list where each element is a
# torch LongTensor of music tokens.
sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors))))
# Read `sample_levels` from `sampling_kwargs`, defaulting to every prior level.
self.vqvae.to(raw_audio.device).float()
# Move the VQ-VAE to the device of `raw_audio` and cast it to float.
with torch.no_grad():
# Disable gradient tracking: no gradients are needed for sampling.
music_tokens = self.vqvae.encode(
raw_audio, start_level=0, end_level=len(self.priors), bs_chunks=raw_audio.shape[0]
)
# Encode `raw_audio` into music tokens with the VQ-VAE, covering levels 0 to len(self.priors) and chunking
# over the batch dimension.
music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
# Sample with `_sample`, passing the music tokens, labels, sample levels and any extra sampling kwargs.
return music_tokens
.\models\jukebox\tokenization_jukebox.py
import json
import os
import re
import unicodedata
from json.encoder import INFINITY
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import regex
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...tokenization_utils_base import BatchEncoding
from ...utils import TensorType, is_flax_available, is_tf_available, is_torch_available, logging
from ...utils.generic import _is_jax, _is_numpy
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {
"artists_file": "artists.json",
"lyrics_file": "lyrics.json",
"genres_file": "genres.json",
}
PRETRAINED_VOCAB_FILES_MAP = {
"artists_file": {
"jukebox": "https://huggingface.co/ArthurZ/jukebox/blob/main/artists.json",
},
"genres_file": {
"jukebox": "https://huggingface.co/ArthurZ/jukebox/blob/main/genres.json",
},
"lyrics_file": {
"jukebox": "https://huggingface.co/ArthurZ/jukebox/blob/main/lyrics.json",
},
}
PRETRAINED_LYRIC_TOKENS_SIZES = {
"jukebox": 512,
}
class JukeboxTokenizer(PreTrainedTokenizer):
"""
Constructs a Jukebox tokenizer. Jukebox can be conditioned on 3 different inputs:
- Artists: a unique id associated with each artist is stored in the provided dictionary.
- Genres: a unique id associated with each genre is stored in the provided dictionary.
- Lyrics: character-based tokenization. Must be initialized with the list of characters contained in the vocabulary.
This tokenizer does not require training. It should be able to process a different number of inputs,
as the conditioning of the model can be done on the three different queries. If None is provided, default values
will be used, depending on the number of genres on which the model should be conditioned (`n_genres`).
Args:
- PreTrainedTokenizer: the constructor inherited from the parent class PreTrainedTokenizer.
Example usage:
```
>>> from transformers import JukeboxTokenizer
>>> tokenizer = JukeboxTokenizer.from_pretrained("openai/jukebox-1b-lyrics")
>>> tokenizer("Alan Jackson", "Country Rock", "old town road")["input_ids"]
[tensor([[ 0, 0, 0, 6785, 546, 41, 38, 30, 76, 46, 41, 49,
40, 76, 44, 41, 27, 30]]), tensor([[ 0, 0, 0, 145, 0]]), tensor([[ 0, 0, 0, 145, 0]])]
```
# You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when
# calling it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
# Tips
# If nothing is provided, the genres and the artist will either be selected randomly or set to None.
# This tokenizer inherits from [`PreTrainedTokenizer`], which contains most of the main methods. Users should refer
# to that superclass for more information regarding those methods.
# However, the code does not allow that and only supports composing from various genres.
# Args:
# artists_file (`str`):
# Path to the vocabulary file containing the mapping between artists and their ids. The default file supports
# both "v2" and "v3".
# genres_file (`str`):
# Path to the vocabulary file containing the mapping between genres and their ids.
# lyrics_file (`str`):
# Path to the vocabulary file containing the accepted characters for the lyrics tokenization.
# version (`List[str]`, *optional*, defaults to `["v3", "v2", "v2"]`):
# List of the tokenizer versions. The `5b-lyrics` top-level prior model was trained using `v3` instead of `v2`.
# n_genres (`int`, *optional*, defaults to 5):
# Maximum number of genres to use for composition.
# max_n_lyric_tokens (`int`, *optional*, defaults to 512):
# Maximum number of lyric tokens to keep.
# unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
# The unknown token. A token that is not in the vocabulary cannot be converted to an id and is set to this
# token instead.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_lyric_input_size = PRETRAINED_LYRIC_TOKENS_SIZES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
artists_file,
genres_file,
lyrics_file,
version=["v3", "v2", "v2"],
max_n_lyric_tokens=512,
n_genres=5,
unk_token="<|endoftext|>",
**kwargs,
):
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
self.version = version
self.max_n_lyric_tokens = max_n_lyric_tokens
self.n_genres = n_genres
self._added_tokens_decoder = {0: unk_token}
with open(artists_file, encoding="utf-8") as vocab_handle:
self.artists_encoder = json.load(vocab_handle)
with open(genres_file, encoding="utf-8") as vocab_handle:
self.genres_encoder = json.load(vocab_handle)
with open(lyrics_file, encoding="utf-8") as vocab_handle:
self.lyrics_encoder = json.load(vocab_handle)
oov = r"[^A-Za-z0-9.,:;!?\-'\"()\[\] \t\n]+"
if len(self.lyrics_encoder) == 79:
oov = oov.replace(r"\-'", r"\-+'")
self.out_of_vocab = regex.compile(oov)
self.artists_decoder = {v: k for k, v in self.artists_encoder.items()}
self.genres_decoder = {v: k for k, v in self.genres_encoder.items()}
self.lyrics_decoder = {v: k for k, v in self.lyrics_encoder.items()}
super().__init__(
unk_token=unk_token,
n_genres=n_genres,
version=version,
max_n_lyric_tokens=max_n_lyric_tokens,
**kwargs,
)
@property
def vocab_size(self):
return len(self.artists_encoder) + len(self.genres_encoder) + len(self.lyrics_encoder)
def get_vocab(self):
return {
"artists_encoder": self.artists_encoder,
"genres_encoder": self.genres_encoder,
"lyrics_encoder": self.lyrics_encoder,
}
def _convert_token_to_id(self, list_artists, list_genres, list_lyrics):
"""Converts the artist, genre and lyrics tokens to their index using the vocabulary.
The total_length, offset and duration have to be provided in order to select relevant lyrics and add padding to
the lyrics token sequence.
"""
artists_id = [self.artists_encoder.get(artist, 0) for artist in list_artists]
for genres in range(len(list_genres)):
list_genres[genres] = [self.genres_encoder.get(genre, 0) for genre in list_genres[genres]]
list_genres[genres] = list_genres[genres] + [-1] * (self.n_genres - len(list_genres[genres]))
lyric_ids = [[self.lyrics_encoder.get(character, 0) for character in list_lyrics[0]], [], []]
return artists_id, list_genres, lyric_ids
def _tokenize(self, lyrics):
"""
Converts a string into a sequence of tokens (string), using the tokenizer. Split in words for word-based
vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
Do NOT take care of added tokens. Only the lyrics are split into character for the character-based vocabulary.
"""
return list(lyrics)
def tokenize(self, artist, genre, lyrics, **kwargs):
"""
Converts three strings in a 3 sequence of tokens using the tokenizer
"""
artist, genre, lyrics = self.prepare_for_tokenization(artist, genre, lyrics)
lyrics = self._tokenize(lyrics)
return artist, genre, lyrics
def prepare_for_tokenization(
self, artists: str, genres: str, lyrics: str, is_split_into_words: bool = False
) -> Tuple[str, str, str, Dict[str, Any]]:
"""
Performs any necessary transformations before tokenization.
Args:
artist (`str`):
The artist name to prepare. This will mostly lower the string
genres (`str`):
The genre name to prepare. This will mostly lower the string.
lyrics (`str`):
The lyrics to prepare.
is_split_into_words (`bool`, *optional*, defaults to `False`):
Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
which it will tokenize. This is useful for NER or token classification.
"""
for idx in range(len(self.version)):
if self.version[idx] == "v3":
artists[idx] = artists[idx].lower()
genres[idx] = [genres[idx].lower()]
else:
artists[idx] = self._normalize(artists[idx]) + ".v2"
genres[idx] = [
self._normalize(genre) + ".v2" for genre in genres[idx].split("_")
]
if self.version[0] == "v2":
self.out_of_vocab = regex.compile(r"[^A-Za-z0-9.,:;!?\-'\"()\[\] \t\n]+")
vocab = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789.,:;!?-+'\"()[] \t\n"
self.vocab = {vocab[index]: index + 1 for index in range(len(vocab))}
self.vocab["<unk>"] = 0
self.n_vocab = len(vocab) + 1
self.lyrics_encoder = self.vocab
self.lyrics_decoder = {v: k for k, v in self.vocab.items()}
self.lyrics_decoder[0] = ""
else:
self.out_of_vocab = regex.compile(r"[^A-Za-z0-9.,:;!?\-+'\"()\[\] \t\n]+")
lyrics = self._run_strip_accents(lyrics)
lyrics = lyrics.replace("\\", "\n")
lyrics = self.out_of_vocab.sub("", lyrics), [], []
return artists, genres, lyrics
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _normalize(self, text: str) -> str:
"""
Normalizes the input text. This process is for the genres and the artist
Args:
text (`str`):
Artist or Genre string to normalize
"""
accepted = (
[chr(i) for i in range(ord("a"), ord("z") + 1)]
+ [chr(i) for i in range(ord("A"), ord("Z") + 1)]
+ [chr(i) for i in range(ord("0"), ord("9") + 1)]
+ ["."]
)
accepted = frozenset(accepted)
pattern = re.compile(r"_+")
text = "".join([c if c in accepted else "_" for c in text.lower()])
text = pattern.sub("_", text).strip("_")
return text
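For example, `_normalize` lower-cases the string, replaces every non-accepted character with `_`, and collapses runs of underscores; a standalone re-run of the same steps on a hypothetical artist name:
```
import re

accepted = frozenset("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.")
text = "Zac Brown Band"
text = "".join(c if c in accepted else "_" for c in text.lower())
print(re.compile(r"_+").sub("_", text).strip("_"))  # zac_brown_band
```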
def convert_lyric_tokens_to_string(self, lyrics: List[str]) -> str:
return " ".join(lyrics)
"""
Convert the inner content to tensors.
Args:
tensor_type (`str` or [`~utils.TensorType`], *optional*):
The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
unset, no modification is done.
prepend_batch_axis (`bool`, *optional*, defaults to `False`):
Whether or not to add the batch dimension during the conversion.
"""
if not isinstance(tensor_type, TensorType):
tensor_type = TensorType(tensor_type)
if tensor_type == TensorType.TENSORFLOW:
if not is_tf_available():
raise ImportError(
"Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
)
import tensorflow as tf
as_tensor = tf.constant
is_tensor = tf.is_tensor
elif tensor_type == TensorType.PYTORCH:
if not is_torch_available():
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
import torch
as_tensor = torch.tensor
is_tensor = torch.is_tensor
elif tensor_type == TensorType.JAX:
if not is_flax_available():
raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
import jax.numpy as jnp
as_tensor = jnp.array
is_tensor = _is_jax
else:
as_tensor = np.asarray
is_tensor = _is_numpy
try:
if prepend_batch_axis:
inputs = [inputs]
if not is_tensor(inputs):
inputs = as_tensor(inputs)
except:
raise ValueError(
"Unable to create tensor, you should probably activate truncation and/or padding "
"with 'padding=True' 'truncation=True' to have batched tensors with the same length."
)
return inputs
def __call__(self, artist, genres, lyrics="", return_tensors="pt") -> BatchEncoding:
"""Convert the raw string to a list of token ids
Args:
artist (`str`):
Name of the artist.
genres (`str`):
List of genres that will be mixed to condition the audio
lyrics (`str`, *optional*, defaults to `""`):
Lyrics used to condition the generation
"""
input_ids = [0, 0, 0]
artist = [artist] * len(self.version)
genres = [genres] * len(self.version)
artists_tokens, genres_tokens, lyrics_tokens = self.tokenize(artist, genres, lyrics)
artists_id, genres_ids, full_tokens = self._convert_token_to_id(artists_tokens, genres_tokens, lyrics_tokens)
attention_masks = [-INFINITY] * len(full_tokens[-1])
input_ids = [
self.convert_to_tensors(
[input_ids + [artists_id[i]] + genres_ids[i] + full_tokens[i]], tensor_type=return_tensors
)
for i in range(len(self.version))
]
return BatchEncoding({"input_ids": input_ids, "attention_masks": attention_masks})
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
"""
Saves the tokenizer's vocabulary dictionary to the provided save_directory.
Args:
save_directory (`str`):
A path to the directory where the vocabulary will be saved. It will be created if it doesn't exist.
filename_prefix (`Optional[str]`, *optional*):
A prefix to add to the names of the files saved by the tokenizer.
"""
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
artists_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["artists_file"]
)
with open(artists_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.artists_encoder, ensure_ascii=False))
genres_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["genres_file"]
)
with open(genres_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.genres_encoder, ensure_ascii=False))
lyrics_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["lyrics_file"]
)
with open(lyrics_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.lyrics_encoder, ensure_ascii=False))
return (artists_file, genres_file, lyrics_file)
def _convert_id_to_token(self, artists_index, genres_index, lyric_index):
"""
Converts an index (integer) in a token (str) using the vocab.
Args:
artists_index (`int`):
Index of the artist in its corresponding dictionary.
genres_index (`Union[List[int], int]`):
Index of the genre in its corresponding dictionary. Can be a single index or a list of indices.
lyric_index (`List[int]`):
List of character indices, each corresponding to a character.
Returns:
artist (`Optional[str]`):
Decoded artist name corresponding to artists_index.
genres (`List[Optional[str]]`):
List of decoded genre names corresponding to genres_index.
lyrics (`List[Optional[str]]`):
List of decoded characters corresponding to lyric_index.
"""
artist = self.artists_decoder.get(artists_index)
genres = [self.genres_decoder.get(genre) for genre in genres_index]
lyrics = [self.lyrics_decoder.get(character) for character in lyric_index]
return artist, genres, lyrics
.\models\jukebox\__init__.py
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_jukebox": [
"JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP",
"JukeboxConfig",
"JukeboxPriorConfig",
"JukeboxVQVAEConfig",
],
"tokenization_jukebox": ["JukeboxTokenizer"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_jukebox"] = [
"JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST",
"JukeboxModel",
"JukeboxPreTrainedModel",
"JukeboxVQVAE",
"JukeboxPrior",
]
if TYPE_CHECKING:
from .configuration_jukebox import (
JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP,
JukeboxConfig,
JukeboxPriorConfig,
JukeboxVQVAEConfig,
)
from .tokenization_jukebox import JukeboxTokenizer
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_jukebox import (
JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST,
JukeboxModel,
JukeboxPreTrainedModel,
JukeboxPrior,
JukeboxVQVAE,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
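As a quick illustration of the lazy-import pattern above (a sketch that only assumes `transformers` is installed): the package object is a `_LazyModule`, and the heavy sub-modules are imported only when one of their attributes is first accessed.

```python
import transformers.models.jukebox as jukebox

print(type(jukebox).__name__)       # "_LazyModule": sys.modules was replaced as in the snippet above
config_cls = jukebox.JukeboxConfig  # this attribute access triggers the import of configuration_jukebox
print(config_cls.model_type)        # "jukebox"
```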
.\models\kosmos2\configuration_kosmos2.py
""" KOSMOS-2 模型配置"""
import os
from typing import Union
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
KOSMOS2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"microsoft/kosmos-2-patch14-224": (
"https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/config.json"
),
}
class Kosmos2TextConfig(PretrainedConfig):
r"""
    This is the configuration class to store the configuration of a [`Kosmos2TextModel`]. It is used to instantiate a
    KOSMOS-2 text decoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a configuration similar to that of the text decoder of the KOSMOS-2
    [microsoft/kosmos-2-patch14-224](https://huggingface.co/microsoft/kosmos-2-patch14-224) architecture.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    """
    # Model type identifier for the KOSMOS-2 text model
    model_type = "kosmos_2_text_model"
    # Keys that are ignored at inference time
    keys_to_ignore_at_inference = ["past_key_values"]
    # Map the common attribute names to the names used by this configuration
    attribute_map = {
        "num_attention_heads": "attention_heads",  # number of attention heads
        "hidden_size": "embed_dim",  # hidden size / embedding dimension
        "num_hidden_layers": "layers",  # number of Transformer decoder layers
    }
    # Constructor for the text-decoder configuration
    def __init__(
        self,
        vocab_size=65037,  # vocabulary size
        max_position_embeddings=2048,  # maximum number of position embeddings
        embed_dim=2048,  # embedding dimension
        layers=24,  # number of decoder layers
        ffn_dim=8192,  # feed-forward network dimension
        attention_heads=32,  # number of attention heads
        activation_function="gelu",  # activation function
        dropout=0.1,  # dropout probability
        attention_dropout=0.1,  # attention dropout probability
        activation_dropout=0.0,  # activation dropout probability
        layerdrop=0.0,  # layer-drop probability
        layer_norm_eps=1e-5,  # layer-norm epsilon
        init_std=0.02,  # initializer standard deviation
        scale_embedding=True,  # whether to scale embeddings by sqrt(embed_dim)
        use_cache=True,  # whether to use the key/value cache
        pad_token_id=1,  # padding token id
        bos_token_id=0,  # beginning-of-sequence token id
        eos_token_id=2,  # end-of-sequence token id
        **kwargs,  # additional keyword arguments
    ):
        # Pass the special token ids (and any extra kwargs) to the parent constructor
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
        # Store the configuration attributes
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.embed_dim = embed_dim
        self.layers = layers
        self.ffn_dim = ffn_dim
        self.attention_heads = attention_heads
        self.activation_function = activation_function
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.layerdrop = layerdrop
        self.layer_norm_eps = layer_norm_eps
        self.init_std = init_std
        self.scale_embedding = scale_embedding
        self.use_cache = use_cache
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        # Add the token-related arguments to kwargs
        cls._set_token_in_kwargs(kwargs)
        # Fetch the configuration dict and the updated kwargs
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        # If loading from a composite Kosmos2Config, pick out its text config dict
        if config_dict.get("model_type") == "kosmos-2":
            config_dict = config_dict["text_config"]
        # Warn when the config dict declares a model type different from this class
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )
        # Create the class instance from the config dict and kwargs
return cls.from_dict(config_dict, **kwargs)
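A short sketch of what the `model_type == "kosmos-2"` branch enables (assuming access to the Hub checkpoint): the text sub-config can be loaded directly from the composite repository.

```python
from transformers.models.kosmos2.configuration_kosmos2 import Kosmos2TextConfig

# The repo stores a composite "kosmos-2" config; from_pretrained extracts its "text_config" section.
text_config = Kosmos2TextConfig.from_pretrained("microsoft/kosmos-2-patch14-224")
print(text_config.model_type)  # "kosmos_2_text_model"
print(text_config.embed_dim)   # 2048, matching the default above
```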
# The `Kosmos2VisionConfig` class below stores the configuration of a `Kosmos2VisionModel`.
# It inherits from `PretrainedConfig`, which controls the model outputs; see the `PretrainedConfig` docs for details.
class Kosmos2VisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Kosmos2VisionModel`]. It is used to instantiate a
KOSMOS-2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the vision encoder of the KOSMOS-2
[microsoft/kosmos-2-patch14-224](https://huggingface.co/microsoft/kosmos-2-patch14-224) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 1024):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 14):
The size (resolution) of each patch.
hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
initializer_factor (`float`, *optional*, defaults to 1):
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
testing).
"""
    model_type = "kosmos_2_vision_model"
    def __init__(
        self,
        hidden_size=1024,
        intermediate_size=4096,
        num_hidden_layers=24,
        num_attention_heads=16,
        num_channels=3,
        image_size=224,
        patch_size=14,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        **kwargs,
    ):
        # Forward any extra keyword arguments to the parent constructor
        super().__init__(**kwargs)
        # Store the vision-encoder configuration attributes
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        # Add the token-related arguments to kwargs
        cls._set_token_in_kwargs(kwargs)
        # Fetch the configuration dict and the updated kwargs
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        # If loading from a composite Kosmos2Config, pick out its vision config dict
        if config_dict.get("model_type") == "kosmos-2":
            config_dict = config_dict["vision_config"]
        # Warn when the config dict declares a model type different from this class
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )
        # Create the class instance from the config dict and return it
return cls.from_dict(config_dict, **kwargs)
class Kosmos2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Kosmos2Model`]. It is used to instantiate a
KOSMOS-2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the KOSMOS-2
[microsoft/kosmos-2-patch14-224](https://huggingface.co/microsoft/kosmos-2-patch14-224) architecture.
Args:
text_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`Kosmos2TextConfig`].
vision_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`Kosmos2VisionConfig`].
latent_query_num (`int`, *optional*, defaults to 64):
The number of latent query tokens that represent the image features used in the text decoder component.
kwargs (*optional*):
Dictionary of keyword arguments.
Example:
```
>>> from transformers import Kosmos2Config, Kosmos2Model
>>>
>>> configuration = Kosmos2Config()
>>>
>>> model = Kosmos2Model(configuration)
>>>
>>> configuration = model.config
```"""
    # Model type identifier
    model_type = "kosmos-2"
    # This configuration is a composition of sub-configurations
    is_composition = True
def __init__(
self,
text_config=None,
vision_config=None,
latent_query_num=64,
**kwargs,
):
        # Forward any extra keyword arguments to the parent constructor
        super().__init__(**kwargs)
        # Fall back to an empty dict (i.e. default values) when no text config is given, and log it
        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `Kosmos2TextConfig` with default values.")
        # Fall back to an empty dict (i.e. default values) when no vision config is given, and log it
        if vision_config is None:
            vision_config = {}
            logger.info("`vision_config` is `None`. Initializing the `Kosmos2VisionConfig` with default values.")
        # Build the `Kosmos2TextConfig` from the provided text config dict
        self.text_config = Kosmos2TextConfig(**text_config)
        # Build the `Kosmos2VisionConfig` from the provided vision config dict
        self.vision_config = Kosmos2VisionConfig(**vision_config)
        # Number of latent query tokens that represent the image features in the text decoder component
self.latent_query_num = latent_query_num
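A quick sketch of how the composite config is assembled from plain dicts (the tiny values below are hypothetical and only meant to show the wiring; omitted fields fall back to the defaults documented above):

```python
from transformers import Kosmos2Config

config = Kosmos2Config(
    text_config={"layers": 2, "embed_dim": 256, "attention_heads": 4, "ffn_dim": 512},
    vision_config={"num_hidden_layers": 2, "hidden_size": 64, "num_attention_heads": 4, "intermediate_size": 128},
    latent_query_num=16,
)
print(config.text_config.layers)         # 2
print(config.vision_config.hidden_size)  # 64
print(config.latent_query_num)           # 16
```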
.\models\kosmos2\convert_kosmos2_original_pytorch_checkpoint_to_pytorch.py
import argparse
from fairseq.checkpoint_utils import load_checkpoint_to_cpu
from transformers import Kosmos2Config, Kosmos2ForConditionalGeneration
KEYS_TO_MODIFY_MAPPING = {
"gpt_model.decoder.output_projection": "text_model.lm_head",
"gpt_model.decoder": "text_model.model",
"img_connector": "image_to_text_projection",
"img_model.visual.class_embedding": "vision_model.model.embeddings.class_embedding",
"img_model.visual.positional_embedding": "vision_model.model.embeddings.position_embedding.weight",
"img_model.visual.conv1": "vision_model.model.embeddings.patch_embedding",
"img_model.visual": "vision_model.model",
"ln_pre": "pre_layrnorm",
"ln_post": "post_layernorm",
"transformer.resblocks": "encoder.layers",
"ts_attn": "self_attn",
"ln_1": "layer_norm1",
"ln_2": "layer_norm2",
"c_fc": "fc1",
"c_proj": "fc2",
}
KEYS_TO_IGNORE = [
"gpt_model.decoder.embed_positions._float_tensor",
"gpt_model.decoder.self_attn_sope.scale",
]
def rename_key(key):
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)
return key
def convert_kosmos2_checkpoint_to_pytorch(checkpoint_path, pytorch_dump_folder_path):
state = load_checkpoint_to_cpu(checkpoint_path)
state_dict = state["model"]
state_dict_keys = list(state_dict.keys())
config = Kosmos2Config()
config.text_config.no_repeat_ngram_size = 3
model = Kosmos2ForConditionalGeneration(config)
converted_state_dict = {}
for key in state_dict_keys:
if key in KEYS_TO_IGNORE:
continue
renamed_key = rename_key(key)
converted_state_dict[renamed_key] = state_dict[key]
model.load_state_dict(converted_state_dict, strict=True)
model.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--kosmos2_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_kosmos2_checkpoint_to_pytorch(args.kosmos2_checkpoint_path, args.pytorch_dump_folder_path)
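For completeness, the script would typically be invoked as shown below (a sketch; the checkpoint path and output folder are placeholders, and `fairseq` must be installed for `load_checkpoint_to_cpu`):

```python
# python convert_kosmos2_original_pytorch_checkpoint_to_pytorch.py \
#     --kosmos2_checkpoint_path /path/to/kosmos2.pt \
#     --pytorch_dump_folder_path ./kosmos2-hf
#
# The converted folder can then be reloaded like any other checkpoint:
from transformers import Kosmos2ForConditionalGeneration

model = Kosmos2ForConditionalGeneration.from_pretrained("./kosmos2-hf")  # hypothetical output folder from the run above
```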
.\models\kosmos2\modeling_kosmos2.py
""" PyTorch KOSMOS-2 model."""
import math
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPooling,
CausalLMOutputWithCrossAttentions,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_kosmos2 import Kosmos2Config, Kosmos2TextConfig, Kosmos2VisionConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = Kosmos2Config
KOSMOS2_PRETRAINED_MODEL_ARCHIVE_LIST = [
"microsoft/kosmos-2-patch14-224",
]
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
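To make the two mask helpers concrete, a small sketch (it imports the private helpers from the module version walked through here; 0 means "attend", a large negative value means "blocked"):

```python
import torch
from transformers.models.kosmos2.modeling_kosmos2 import _expand_mask, _make_causal_mask

# Causal mask for a batch of 2 sequences of length 4, with no cached keys.
causal = _make_causal_mask(torch.Size([2, 4]), torch.float32, device=torch.device("cpu"))
print(causal.shape)  # torch.Size([2, 1, 4, 4]); the upper triangle holds torch.finfo(torch.float32).min

# Padding mask expansion: 1 = real token, 0 = padding.
padding = _expand_mask(torch.tensor([[1, 1, 1, 0]]), torch.float32)
print(padding.shape)  # torch.Size([1, 1, 4, 4]); the last source column is blocked for every query position
```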
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.
Args:
        input_ids (`torch.Tensor`): the input token IDs
        padding_idx (`int`): index of the padding symbol
        past_key_values_length (`int`, *optional*): length of the cached past key values, used for incremental indexing
    Returns:
        `torch.Tensor`: the computed position IDs
"""
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
return incremental_indices.long() + padding_idx
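And a tiny sketch of `create_position_ids_from_input_ids` with `padding_idx=1` (the default `pad_token_id` of `Kosmos2TextConfig`): padded positions keep the padding index, real tokens count up from `padding_idx + 1`.

```python
import torch
from transformers.models.kosmos2.modeling_kosmos2 import create_position_ids_from_input_ids

input_ids = torch.tensor([[5, 6, 7, 1, 1]])  # trailing 1s are padding
print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
# tensor([[2, 3, 4, 1, 1]])
```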
KOSMOS2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`Kosmos2Config`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
KOSMOS2_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
[`CLIPImageProcessor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
KOSMOS2_TEXT_INPUTS_DOCSTRING = r"""
Args:
"""
KOSMOS2_INPUTS_DOCSTRING = r"""
Args:
"""
@dataclass
class Kosmos2ModelOutput(ModelOutput):
"""
Base class for text model's outputs that also contains a pooling of the last hidden states.
"""
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_embeds: Optional[torch.FloatTensor] = None
projection_attentions: Optional[Tuple[torch.FloatTensor]] = None
vision_model_output: BaseModelOutputWithPooling = None
def to_tuple(self) -> Tuple[Any]:
return tuple(
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
for k in self.keys()
)
@dataclass
class Kosmos2ForConditionalGenerationModelOutput(ModelOutput):
"""
Model output class for `Kosmos2ForConditionalGeneration`.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
projection_attentions (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute
the weighted average in the self-attention heads.
vision_model_output(`BaseModelOutputWithPooling`, *optional*):
The output of the [`Kosmos2VisionModel`].
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
input) to speed up sequential decoding.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_embeds: Optional[torch.FloatTensor] = None
projection_attentions: Optional[Tuple[torch.FloatTensor]] = None
vision_model_output: BaseModelOutputWithPooling = None
def to_tuple(self) -> Tuple[Any]:
return tuple(
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
for k in self.keys()
)
class Kosmos2VisionEmbeddings(nn.Module):
def __init__(self, config: Kosmos2VisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
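With the defaults documented earlier (224x224 images, 14x14 patches, hidden size 1024), the embedding module produces (224 // 14) ** 2 = 256 patch tokens plus one class token. A small shape check (a sketch using random pixel values):

```python
import torch
from transformers.models.kosmos2.configuration_kosmos2 import Kosmos2VisionConfig
from transformers.models.kosmos2.modeling_kosmos2 import Kosmos2VisionEmbeddings

embeddings = Kosmos2VisionEmbeddings(Kosmos2VisionConfig())
pixel_values = torch.randn(2, 3, 224, 224)
print(embeddings(pixel_values).shape)  # torch.Size([2, 257, 1024]): 1 class token + 256 patch tokens
```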
class Kosmos2VisionAttention(nn.Module):
"""来自 'Attention Is All You Need' 论文的多头注意力机制"""
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: "
f"{self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
class Kosmos2VisionMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class Kosmos2VisionEncoderLayer(nn.Module):
def __init__(self, config: Kosmos2VisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = Kosmos2VisionAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Kosmos2VisionMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
causal_attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
causal_attention_mask (`torch.FloatTensor`): mask indicating the causal nature of attention
output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class Kosmos2VisionEncoder(nn.Module):
"""
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Kosmos2VisionEncoderLayer`].
    """
def __init__(self, config: Kosmos2VisionConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList([Kosmos2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
class Kosmos2VisionTransformer(nn.Module):
def __init__(self, config: Kosmos2VisionConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = Kosmos2VisionEmbeddings(config)
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.encoder = Kosmos2VisionEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layrnorm(hidden_states)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
pooled_output = last_hidden_state[:, 0, :]
pooled_output = self.post_layernorm(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class Kosmos2TextSinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length."""
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
super().__init__()
self.offset = 2
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
if hasattr(self, "weights"):
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
self.register_buffer("weights", emb_weights, persistent=False)
@staticmethod
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
"""
        Build sinusoidal positional embeddings.
        This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
        "Attention Is All You Need".
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
if embedding_dim % 2 == 1:
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
if padding_idx is not None:
emb[padding_idx, :] = 0
return emb.to(torch.get_default_dtype())
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor = None,
inputs_embeds: torch.Tensor = None,
past_key_values_length: int = 0,
position_ids: torch.Tensor = None,
):
if input_ids is not None:
bsz, seq_len = input_ids.size()
if position_ids is None:
position_ids = create_position_ids_from_input_ids(
input_ids, self.padding_idx, past_key_values_length
).to(input_ids.device)
else:
bsz, seq_len = inputs_embeds.size()[:-1]
if position_ids is None:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length)
max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
if max_pos > self.weights.size(0):
self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach()
def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length):
"""
        We are provided embeddings directly, so we cannot infer which tokens are padding; sequential position ids are generated instead.
Args:
inputs_embeds: torch.Tensor
Returns: torch.Tensor
"""
input_shape = inputs_embeds.size()[:-1]
sequence_length = input_shape[1]
position_ids = torch.arange(
self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
)
return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length
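A quick numerical check of `get_embedding`: for an even `embedding_dim = d`, the first `d // 2` columns are sines and the last `d // 2` are cosines of `pos * 10000^{-i/(d/2 - 1)}`, and the padding row is zeroed out (a sketch):

```python
import torch
from transformers.models.kosmos2.modeling_kosmos2 import Kosmos2TextSinusoidalPositionalEmbedding

weights = Kosmos2TextSinusoidalPositionalEmbedding.get_embedding(num_embeddings=6, embedding_dim=8, padding_idx=1)
print(weights.shape)                                  # torch.Size([6, 8]): 4 sine columns then 4 cosine columns
print(torch.all(weights[1] == 0).item())              # True: the padding row is zeroed
print(torch.allclose(weights[0, 4:], torch.ones(4)))  # True: cos(0) = 1 for position 0
```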
class KosmosTextAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
add_inner_attn_layernorm: bool = False,
bias: bool = True,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.inner_attn_ln = None
if add_inner_attn_layernorm:
self.inner_attn_ln = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
def _shape(self, projection: torch.Tensor) -> torch.Tensor:
new_projection_shape = projection.size()[:-1] + (self.num_heads, self.head_dim)
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
return new_projection
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
class Kosmos2TextFFN(nn.Module):
    def forward(self, hidden_states):
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.ffn_layernorm(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
return hidden_states
class Kosmos2TextBlock(nn.Module):
def __init__(self, config: Kosmos2TextConfig):
super().__init__()
self.embed_dim = config.embed_dim
self.self_attn = KosmosTextAttention(
config,
embed_dim=self.embed_dim,
num_heads=config.attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
add_inner_attn_layernorm=True,
)
self.dropout = config.dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
if config.add_cross_attention:
self.encoder_attn = KosmosTextAttention(
config,
embed_dim=self.embed_dim,
num_heads=config.attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
add_inner_attn_layernorm=False,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.ffn = Kosmos2TextFFN(config)
self.final_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None:
if not hasattr(self, "encoder_attn"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
" by setting `config.add_cross_attention=True`"
)
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
present_key_value = present_key_value + cross_attn_present_key_value
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.ffn(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
if use_cache:
outputs += (present_key_value,)
return outputs
"""
Transformer decoder consisting of `config.layers` layers. Each layer is a [`Kosmos2TextBlock`].
Args:
config: Kosmos2TextConfig
"""
def __init__(self, config: Kosmos2TextConfig):
super().__init__()
self.config = config
self.dropout = config.dropout
self.layerdrop = config.layerdrop
self.embed_scale = math.sqrt(config.embed_dim) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim, padding_idx=config.pad_token_id)
self.embed_positions = Kosmos2TextSinusoidalPositionalEmbedding(
num_positions=config.max_position_embeddings,
embedding_dim=config.embed_dim,
padding_idx=config.pad_token_id,
)
self.layers = nn.ModuleList([Kosmos2TextBlock(config) for _ in range(config.layers)])
self.layer_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps)
self.gradient_checkpointing = False
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward_embedding(
self,
input_ids,
inputs_embeds: torch.Tensor = None,
image_embeds: torch.Tensor = None,
img_input_mask: torch.Tensor = None,
past_key_values_length: int = 0,
        position_ids: torch.Tensor = None,
    ):
        if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if image_embeds is not None:
inputs_embeds[img_input_mask.to(dtype=torch.bool)] = image_embeds.to(inputs_embeds.device).view(
-1, image_embeds.size(-1)
)
inputs_embeds = inputs_embeds * self.embed_scale
positions = self.embed_positions(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
position_ids=position_ids,
)
positions = positions.to(inputs_embeds.device)
hidden_states = inputs_embeds + positions
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
return hidden_states
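The key step in `forward_embedding` is the boolean-mask assignment that overwrites the token embeddings at image positions with the projected image features. A standalone sketch of just that assignment (shapes are illustrative):

```python
import torch

embed_dim = 4
inputs_embeds = torch.zeros(1, 6, embed_dim)         # 6 text positions
image_embeds = torch.ones(1, 2, embed_dim)           # 2 image-feature slots
img_input_mask = torch.tensor([[0, 1, 1, 0, 0, 0]])  # positions 1 and 2 hold image features

inputs_embeds[img_input_mask.to(dtype=torch.bool)] = image_embeds.view(-1, embed_dim)
print(inputs_embeds[0, :, 0])  # tensor([0., 1., 1., 0., 0., 0.])
```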
class Kosmos2PreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = Kosmos2Config
supports_gradient_checkpointing = True
_no_split_modules = ["Kosmos2VisionEncoderLayer", "Kosmos2TextBlock"]
class Kosmos2VisionModel(Kosmos2PreTrainedModel):
config_class = Kosmos2VisionConfig
main_input_name = "pixel_values"
def __init__(self, config: Kosmos2VisionConfig):
super().__init__(config)
self.model = Kosmos2VisionTransformer(config)
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.model.embeddings.patch_embedding
@add_start_docstrings_to_model_forward(KOSMOS2_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Kosmos2VisionConfig)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
"""
        Forward pass: takes pixel values as input and optionally returns attentions, hidden states, and a dict-style output.
        Returns:
            The model output, either a tuple or a [`BaseModelOutputWithPooling`].
"""
return self.model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
class Kosmos2TextModel(Kosmos2PreTrainedModel):
config_class = Kosmos2TextConfig
def __init__(self, config: Kosmos2TextConfig):
super().__init__(config)
self.model = Kosmos2TextTransformer(config)
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(KOSMOS2_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPastAndCrossAttentions, config_class=Kosmos2TextConfig)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_embeds: Optional[torch.Tensor] = None,
image_embeds_position_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
r"""
        Forwards the arguments to the underlying `Kosmos2TextTransformer` and returns its output.
        Parameters:
            - input_ids (Optional[torch.Tensor]): input token ID sequence, defaults to None.
            - attention_mask (Optional[torch.Tensor]): attention mask tensor, defaults to None.
            - image_embeds (Optional[torch.Tensor]): image embedding tensor, defaults to None.
            - image_embeds_position_mask (Optional[torch.Tensor]): position mask for the image embeddings, defaults to None.
            - encoder_hidden_states (Optional[torch.Tensor]): encoder hidden states, defaults to None.
            - encoder_attention_mask (Optional[torch.Tensor]): encoder attention mask, defaults to None.
            - head_mask (Optional[torch.Tensor]): head mask tensor, defaults to None.
            - cross_attn_head_mask (Optional[torch.Tensor]): cross-attention head mask, defaults to None.
            - past_key_values (Optional[List[torch.FloatTensor]]): cached past key/value pairs, defaults to None.
            - inputs_embeds (Optional[torch.Tensor]): input embedding tensor, defaults to None.
            - position_ids (Optional[torch.Tensor]): position ID tensor, defaults to None.
            - use_cache (Optional[bool]): whether to use the cache, defaults to None.
            - output_attentions (Optional[bool]): whether to return attention weights, defaults to None.
            - output_hidden_states (Optional[bool]): whether to return hidden states, defaults to None.
            - return_dict (Optional[bool]): whether to return a dict-style output, defaults to None.
        Returns:
            - Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: the model output, either a tuple or a dedicated output class instance.
"""
return self.model(
input_ids=input_ids,
attention_mask=attention_mask,
image_embeds=image_embeds,
image_embeds_position_mask=image_embeds_position_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
position_ids=position_ids,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
"""
The text model from KOSMOS-2 with a language modeling head on top (linear layer with weights tied to the input
embeddings).
"""
@add_start_docstrings(
KOSMOS2_START_DOCSTRING,
)
class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel):
config_class = Kosmos2TextConfig
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: Kosmos2TextConfig):
super().__init__(config)
self.model = Kosmos2TextTransformer(config)
self.lm_head = nn.Linear(in_features=config.embed_dim, out_features=config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self) -> nn.Module:
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
@add_start_docstrings_to_model_forward(KOSMOS2_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithCrossAttentions,
config_class=Kosmos2TextConfig
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_embeds: Optional[torch.Tensor] = None,
image_embeds_position_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
Returns:
Depending on `return_dict`, either a tuple with `loss` and various outputs or an instance of
`CausalLMOutputWithCrossAttentions` containing `loss`, `logits`, and other relevant model outputs.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
image_embeds=image_embeds,
image_embeds_position_mask=image_embeds_position_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
position_ids=position_ids,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
lm_logits = self.lm_head(outputs[0])
loss = None
if labels is not None:
labels = labels.to(lm_logits.device)
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
)
if not return_dict:
output = (lm_logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
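The loss block above is standard next-token prediction: the logits at position `t` are scored against the label at position `t + 1`. A tiny sketch of the shift with hypothetical shapes:

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_length, vocab_size = 1, 4, 10
lm_logits = torch.randn(batch_size, seq_length, vocab_size)
labels = torch.tensor([[3, 5, 7, 9]])

shift_logits = lm_logits[..., :-1, :].contiguous()  # predictions made at positions 0..2
shift_labels = labels[..., 1:].contiguous()         # targets 5, 7, 9
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss)  # scalar loss averaged over the 3 predicted positions
```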
    def prepare_inputs_for_generation(
        self,
        input_ids,
        image_embeds=None,
        image_embeds_position_mask=None,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        **model_kwargs,
    ):
input_shape = input_ids.shape
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
position_ids = None
if past_key_values is not None:
position_ids = create_position_ids_from_input_ids(
input_ids,
padding_idx=self.config.pad_token_id,
past_key_values_length=0,
)[:, -1:]
input_ids = input_ids[:, -1:]
image_embeds = None
image_embeds_position_mask = None
elif image_embeds_position_mask is not None:
batch_size, seq_len = input_ids.size()
mask_len = image_embeds_position_mask.size()[-1]
image_embeds_position_mask = torch.cat(
(
image_embeds_position_mask,
torch.zeros(size=(batch_size, seq_len - mask_len), dtype=torch.bool, device=input_ids.device),
),
dim=1,
)
return {
"input_ids": input_ids,
"image_embeds": image_embeds,
"image_embeds_position_mask": image_embeds_position_mask,
"past_key_values": past_key_values,
"attention_mask": attention_mask,
"position_ids": position_ids,
"use_cache": use_cache,
}
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
class Kosmos2ImageToTextProjection(nn.Module):
"""The layer that transforms the image model's output to part of the text model's input (namely, image features)"""
def __init__(self, config: Kosmos2Config):
super().__init__()
self.dense = nn.Linear(config.vision_config.hidden_size, config.text_config.embed_dim)
self.latent_query = nn.Parameter(torch.randn(config.latent_query_num, config.text_config.embed_dim))
self.x_attn = KosmosTextAttention(
config.text_config,
config.text_config.embed_dim,
config.text_config.attention_heads,
dropout=config.text_config.attention_dropout,
is_decoder=False,
add_inner_attn_layernorm=False,
)
def forward(self, features):
hidden_states = self.dense(features)
latent_query = self.latent_query.unsqueeze(0).expand(hidden_states.size(0), -1, -1)
key_value_states = torch.cat([hidden_states, latent_query], dim=1)
hidden_states, attn_weights, _ = self.x_attn(
hidden_states=latent_query,
encoder_hidden_states=key_value_states,
past_key_value=None,
attention_mask=None,
output_attentions=None,
)
return hidden_states, attn_weights
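To see what the projection does shape-wise, a small sketch with tiny hypothetical dimensions (only meant for the shape check, not a real configuration):

```python
import torch
from transformers import Kosmos2Config
from transformers.models.kosmos2.modeling_kosmos2 import Kosmos2ImageToTextProjection

config = Kosmos2Config(
    text_config={"layers": 1, "embed_dim": 32, "attention_heads": 4, "ffn_dim": 64},
    vision_config={"num_hidden_layers": 1, "hidden_size": 16, "num_attention_heads": 4, "intermediate_size": 32},
    latent_query_num=8,
)
projection = Kosmos2ImageToTextProjection(config)
vision_features = torch.randn(2, 50, 16)  # (batch, vision tokens, vision hidden size)
image_embeds, attn_weights = projection(vision_features)
print(image_embeds.shape)                 # torch.Size([2, 8, 32]): latent_query_num x text embed_dim
```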
@add_start_docstrings(
"""
KOSMOS-2 Model for generating text and image features. The model consists of a vision encoder and a language model.
""",
KOSMOS2_START_DOCSTRING,
)
class Kosmos2Model(Kosmos2PreTrainedModel):
config_class = Kosmos2Config
main_input_name = "pixel_values"
def __init__(self, config: Kosmos2Config):
super().__init__(config)
self.text_model = Kosmos2TextModel(config.text_config)
self.vision_model = Kosmos2VisionModel(config.vision_config)
self.image_to_text_projection = Kosmos2ImageToTextProjection(config)
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.text_model.model.embed_tokens
def set_input_embeddings(self, value):
self.text_model.model.embed_tokens = value
@add_start_docstrings_to_model_forward(KOSMOS2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Kosmos2ModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
input_ids: Optional[torch.Tensor] = None,
image_embeds_position_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
image_embeds: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
"""
KOSMOS-2 Model for generating text and bounding boxes given an image. The model consists of a vision encoder and a
language model.
"""
@add_start_docstrings(
"""
KOSMOS-2 Model for generating text and bounding boxes given an image. The model consists of a vision encoder and a
language model.
""",
KOSMOS2_START_DOCSTRING,
)
class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel):
config_class = Kosmos2Config
main_input_name = "pixel_values"
_tied_weights_keys = ["text_model.lm_head.weight"]
def __init__(self, config: Kosmos2Config):
super().__init__(config)
self.text_model = Kosmos2TextForCausalLM(config.text_config)
self.vision_model = Kosmos2VisionModel(config.vision_config)
self.image_to_text_projection = Kosmos2ImageToTextProjection(config)
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.text_model.model.embed_tokens
def set_input_embeddings(self, value):
self.text_model.model.embed_tokens = value
def get_output_embeddings(self) -> nn.Module:
return self.text_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.text_model.set_output_embeddings(new_embeddings)
@add_start_docstrings_to_model_forward(KOSMOS2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Kosmos2ForConditionalGenerationModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
input_ids: Optional[torch.Tensor] = None,
image_embeds_position_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
image_embeds: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
pass
def generate(
self,
pixel_values: Optional[torch.Tensor] = None,
image_embeds_position_mask: Optional[torch.Tensor] = None,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_embeds: Optional[torch.Tensor] = None,
**kwargs,
    ):
        # `inputs` is accepted as an alias for `pixel_values` (mirroring the `inputs` argument of `GenerationMixin.generate`)
        inputs = kwargs.pop("inputs", None)
if pixel_values is not None and inputs is not None:
raise ValueError(
f"`inputs`: {inputs} were passed alongside `pixel_values` which is not allowed."
f"Make sure to either pass `inputs` or pixel_values=..."
)
if pixel_values is None and inputs is not None:
pixel_values = inputs
if image_embeds is None:
vision_model_output = self.vision_model(pixel_values)
image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
image_embeds = nn.functional.normalize(image_embeds, dim=-1)
image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)
output = self.text_model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
image_embeds=image_embeds,
image_embeds_position_mask=image_embeds_position_mask,
**kwargs,
)
return output
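Putting the pieces together, end-to-end generation with the released checkpoint typically looks like the sketch below (based on the `microsoft/kosmos-2-patch14-224` model card; it downloads weights and fetches a sample image):

```python
import requests
from PIL import Image
from transformers import AutoProcessor, Kosmos2ForConditionalGeneration

model = Kosmos2ForConditionalGeneration.from_pretrained("microsoft/kosmos-2-patch14-224")
processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")

url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.png"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text="<grounding>An image of", images=image, return_tensors="pt")

generated_ids = model.generate(
    pixel_values=inputs["pixel_values"],
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    image_embeds_position_mask=inputs["image_embeds_position_mask"],
    max_new_tokens=64,
)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```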