Transformers Source Code Analysis (7)
.\kernels\deformable_detr\cpu\ms_deform_attn_cpu.cpp
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
// Standard library vector header
#include <vector>
// ATen header, providing tensor operations
#include <ATen/ATen.h>
// CUDA context header, for CUDA-related handling
#include <ATen/cuda/CUDAContext.h>

// CPU forward pass, returning an ATen tensor
at::Tensor
ms_deform_attn_cpu_forward(
    const at::Tensor &value,             // input value tensor
    const at::Tensor &spatial_shapes,    // spatial shapes tensor
    const at::Tensor &level_start_index, // per-level start index tensor
    const at::Tensor &sampling_loc,      // sampling locations tensor
    const at::Tensor &attn_weight,       // attention weight tensor
    const int im2col_step)               // im2col step
{
    // Raise an error: this function is not implemented on CPU
    AT_ERROR("Not implement on cpu");
}
// CPU backward pass, returning a vector of ATen tensors
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
    const at::Tensor &value,             // input value tensor
    const at::Tensor &spatial_shapes,    // spatial shapes tensor
    const at::Tensor &level_start_index, // per-level start index tensor
    const at::Tensor &sampling_loc,      // sampling locations tensor
    const at::Tensor &attn_weight,       // attention weight tensor
    const at::Tensor &grad_output,       // output gradient tensor
    const int im2col_step)               // im2col step
{
    // Raise an error: this function is not implemented on CPU
    AT_ERROR("Not implement on cpu");
}
.\kernels\deformable_detr\cpu\ms_deform_attn_cpu.h
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
// Include guard: ensure this header is included only once
#pragma once
// PyTorch C++ extension header
#include <torch/extension.h>

// Forward declaration: forward pass of the attention mechanism
at::Tensor
ms_deform_attn_cpu_forward(
    const at::Tensor &value,             // input feature tensor
    const at::Tensor &spatial_shapes,    // spatial shape information tensor
    const at::Tensor &level_start_index, // per-level start index tensor
    const at::Tensor &sampling_loc,      // sampling locations tensor
    const at::Tensor &attn_weight,       // attention weight tensor
    const int im2col_step);              // im2col step

// Backward declaration: backward pass of the attention mechanism
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
    const at::Tensor &value,             // input feature tensor
    const at::Tensor &spatial_shapes,    // spatial shape information tensor
    const at::Tensor &level_start_index, // per-level start index tensor
    const at::Tensor &sampling_loc,      // sampling locations tensor
    const at::Tensor &attn_weight,       // attention weight tensor
    const at::Tensor &grad_output,       // output gradient tensor
    const int im2col_step);              // im2col step
.\kernels\deformable_detr\cuda\ms_deform_attn_cuda.h
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
// Torch C++ extension header
#include <torch/extension.h>

// CUDA forward declaration: forward pass of multi-scale deformable attention
at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value,             // input tensor: feature maps
    const at::Tensor &spatial_shapes,    // input tensor: spatial shapes
    const at::Tensor &level_start_index, // input tensor: per-level start indices
    const at::Tensor &sampling_loc,      // input tensor: sampling locations
    const at::Tensor &attn_weight,       // input tensor: attention weights
    const int im2col_step                // input integer: im2col step
    );

// CUDA BF16 (BFloat16) forward declaration: forward pass of multi-scale deformable attention
at::Tensor ms_deform_attn_cuda_forward_bf16(
    const at::Tensor &value,             // input tensor: feature maps
    const at::Tensor &spatial_shapes,    // input tensor: spatial shapes
    const at::Tensor &level_start_index, // input tensor: per-level start indices
    const at::Tensor &sampling_loc,      // input tensor: sampling locations
    const at::Tensor &attn_weight,       // input tensor: attention weights
    const int im2col_step                // input integer: im2col step
    );

// CUDA backward declaration: backward pass of multi-scale deformable attention
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
    const at::Tensor &value,             // input tensor: feature maps
    const at::Tensor &spatial_shapes,    // input tensor: spatial shapes
    const at::Tensor &level_start_index, // input tensor: per-level start indices
    const at::Tensor &sampling_loc,      // input tensor: sampling locations
    const at::Tensor &attn_weight,       // input tensor: attention weights
    const at::Tensor &grad_output,       // input tensor: output gradient
    const int im2col_step                // input integer: im2col step
    );

// CUDA BF16 (BFloat16) backward declaration: backward pass of multi-scale deformable attention
std::vector<at::Tensor> ms_deform_attn_cuda_backward_bf16(
    const at::Tensor &value,             // input tensor: feature maps
    const at::Tensor &spatial_shapes,    // input tensor: spatial shapes
    const at::Tensor &level_start_index, // input tensor: per-level start indices
    const at::Tensor &sampling_loc,      // input tensor: sampling locations
    const at::Tensor &attn_weight,       // input tensor: attention weights
    const at::Tensor &grad_output,       // input tensor: output gradient
    const int im2col_step                // input integer: im2col step
    );
.\kernels\deformable_detr\ms_deform_attn.h
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
// Forward pass: dispatches the attention computation to the right device implementation
at::Tensor
ms_deform_attn_forward(
    const at::Tensor &value,             // input feature tensor
    const at::Tensor &spatial_shapes,    // spatial shape information tensor
    const at::Tensor &level_start_index, // per-level start indices
    const at::Tensor &sampling_loc,      // sampling locations
    const at::Tensor &attn_weight,       // attention weights
    const int im2col_step)               // im2col step
{
    // If the input tensor lives on the GPU
    if (value.type().is_cuda())
    {
#ifdef WITH_CUDA
        // Dispatch to the CUDA forward implementation
        return ms_deform_attn_cuda_forward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
#else
        // Raise an error if the extension was built without GPU support
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    // On CPU, raise an error: no CPU implementation exists
    AT_ERROR("Not implemented on the CPU");
}

// Backward pass: dispatches the gradient computation to the right device implementation
std::vector<at::Tensor>
ms_deform_attn_backward(
    const at::Tensor &value,             // input feature tensor
    const at::Tensor &spatial_shapes,    // spatial shape information tensor
    const at::Tensor &level_start_index, // per-level start indices
    const at::Tensor &sampling_loc,      // sampling locations
    const at::Tensor &attn_weight,       // attention weights
    const at::Tensor &grad_output,       // output gradient
    const int im2col_step)               // im2col step
{
    // If the input tensor lives on the GPU
    if (value.type().is_cuda())
    {
#ifdef WITH_CUDA
        // Dispatch to the CUDA backward implementation
        return ms_deform_attn_cuda_backward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
#else
        // Raise an error if the extension was built without GPU support
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    // On CPU, raise an error: no CPU implementation exists
    AT_ERROR("Not implemented on the CPU");
}
.\kernels\deformable_detr\vision.cpp
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
// Header declaring the functions bound below
#include "ms_deform_attn.h"

// Bind the C++ functions into a Python module with the PYBIND11_MODULE macro
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Expose ms_deform_attn_forward to Python under the same name
    m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
    // Expose ms_deform_attn_backward to Python under the same name
    m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}
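Together, vision.cpp and the CPU/CUDA translation units above form a standard PyTorch C++ extension. As a reading aid, here is a minimal Python sketch of building and calling it with torch.utils.cpp_extension.load; the source list, tensor shapes, and im2col_step value are illustrative assumptions, not something fixed by the files themselves.

import torch
from torch.utils.cpp_extension import load

# Hypothetical build step: compile the files walked through above into one module.
MSDA = load(
    name="MultiScaleDeformableAttention",
    sources=[
        "kernels/deformable_detr/vision.cpp",
        "kernels/deformable_detr/cpu/ms_deform_attn_cpu.cpp",
        "kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu",  # assumed CUDA source file
    ],
    with_cuda=True,
)

bs, n_heads, d_head = 2, 8, 32
n_levels, n_queries, n_points = 2, 100, 4
# Two feature levels, e.g. 32x32 and 16x16
spatial_shapes = torch.as_tensor([[32, 32], [16, 16]], dtype=torch.long, device="cuda")
level_start_index = torch.cat(
    (spatial_shapes.new_zeros(1), spatial_shapes.prod(1).cumsum(0)[:-1])
)
n_keys = int(spatial_shapes.prod(1).sum())  # 1280 flattened keys over both levels

value = torch.rand(bs, n_keys, n_heads, d_head, device="cuda")
sampling_loc = torch.rand(bs, n_queries, n_heads, n_levels, n_points, 2, device="cuda")
attn_weight = torch.rand(bs, n_queries, n_heads, n_levels, n_points, device="cuda")
attn_weight = attn_weight / attn_weight.sum((-1, -2), keepdim=True)  # normalize over levels/points

out = MSDA.ms_deform_attn_forward(
    value, spatial_shapes, level_start_index, sampling_loc, attn_weight, 64  # im2col_step
)
print(out.shape)  # expected: [bs, n_queries, n_heads * d_head]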
.\kernels\deta\cpu\ms_deform_attn_cpu.cpp
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
// CPU forward pass of ms_deform_attn
at::Tensor
ms_deform_attn_cpu_forward(
    const at::Tensor &value,             // input value tensor
    const at::Tensor &spatial_shapes,    // spatial shape information tensor
    const at::Tensor &level_start_index, // per-level start index tensor
    const at::Tensor &sampling_loc,      // sampling locations tensor
    const at::Tensor &attn_weight,       // attention weight tensor
    const int im2col_step)               // im2col step parameter
{
    // Raise an error: not yet implemented on CPU
    AT_ERROR("Not implement on cpu");
}

// CPU backward pass of ms_deform_attn
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
    const at::Tensor &value,             // input value tensor
    const at::Tensor &spatial_shapes,    // spatial shape information tensor
    const at::Tensor &level_start_index, // per-level start index tensor
    const at::Tensor &sampling_loc,      // sampling locations tensor
    const at::Tensor &attn_weight,       // attention weight tensor
    const at::Tensor &grad_output,       // output gradient tensor
    const int im2col_step)               // im2col step parameter
{
    // Raise an error: not yet implemented on CPU
    AT_ERROR("Not implement on cpu");
}
.\kernels\deta\cpu\ms_deform_attn_cpu.h
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
// Include guard: ensure this header is included only once
#pragma once
// PyTorch C++ extension header
#include <torch/extension.h>

// Forward declaration: forward pass of the deformable attention mechanism
at::Tensor
ms_deform_attn_cpu_forward(
    const at::Tensor &value,             // input feature tensor
    const at::Tensor &spatial_shapes,    // spatial shape information
    const at::Tensor &level_start_index, // per-level start indices
    const at::Tensor &sampling_loc,      // sampling locations
    const at::Tensor &attn_weight,       // attention weights
    const int im2col_step);              // im2col step

// Backward declaration: backward pass of the deformable attention mechanism
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
    const at::Tensor &value,             // input feature tensor
    const at::Tensor &spatial_shapes,    // spatial shape information
    const at::Tensor &level_start_index, // per-level start indices
    const at::Tensor &sampling_loc,      // sampling locations
    const at::Tensor &attn_weight,       // attention weights
    const at::Tensor &grad_output,       // output gradient
    const int im2col_step);              // im2col step

This C++ header declares the two functions `ms_deform_attn_cpu_forward` and `ms_deform_attn_cpu_backward`, the forward and backward passes of the deformable attention mechanism.
.\kernels\deta\cuda\ms_deform_attn_cuda.h
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
// Torch C++ extension header
#include <torch/extension.h>

// CUDA forward declaration: takes several tensors and an integer parameter
at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value,             // input feature tensor
    const at::Tensor &spatial_shapes,    // spatial shape information tensor
    const at::Tensor &level_start_index, // per-level start index tensor
    const at::Tensor &sampling_loc,      // sampling locations tensor
    const at::Tensor &attn_weight,       // attention weight tensor
    const int im2col_step);              // im2col step (integer)

// CUDA backward declaration: takes several tensors and an integer parameter, returns a vector of tensors
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
    const at::Tensor &value,             // input feature tensor
    const at::Tensor &spatial_shapes,    // spatial shape information tensor
    const at::Tensor &level_start_index, // per-level start index tensor
    const at::Tensor &sampling_loc,      // sampling locations tensor
    const at::Tensor &attn_weight,       // attention weight tensor
    const at::Tensor &grad_output,       // output gradient tensor
    const int im2col_step);              // im2col step (integer)
.\kernels\deta\ms_deform_attn.h
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
// Forward pass of the deformable attention mechanism
at::Tensor
ms_deform_attn_forward(
    const at::Tensor &value,             // input value tensor
    const at::Tensor &spatial_shapes,    // spatial shape information tensor
    const at::Tensor &level_start_index, // per-level start index tensor
    const at::Tensor &sampling_loc,      // sampling locations tensor
    const at::Tensor &attn_weight,       // attention weight tensor
    const int im2col_step)               // im2col step parameter
{
    // If the input tensor is on CUDA, call the CUDA forward implementation
    if (value.type().is_cuda())
    {
#ifdef WITH_CUDA
        return ms_deform_attn_cuda_forward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
#else
        // Raise an error if the extension was built without GPU support
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    // Called on CPU: raise an error, as no CPU version is implemented
    AT_ERROR("Not implemented on the CPU");
}

// Backward pass of the deformable attention mechanism
std::vector<at::Tensor>
ms_deform_attn_backward(
    const at::Tensor &value,             // input value tensor
    const at::Tensor &spatial_shapes,    // spatial shape information tensor
    const at::Tensor &level_start_index, // per-level start index tensor
    const at::Tensor &sampling_loc,      // sampling locations tensor
    const at::Tensor &attn_weight,       // attention weight tensor
    const at::Tensor &grad_output,       // output gradient tensor
    const int im2col_step)               // im2col step parameter
{
    // If the input tensor is on CUDA, call the CUDA backward implementation
    if (value.type().is_cuda())
    {
#ifdef WITH_CUDA
        return ms_deform_attn_cuda_backward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
#else
        // Raise an error if the extension was built without GPU support
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    // Called on CPU: raise an error, as no CPU version is implemented
    AT_ERROR("Not implemented on the CPU");
}
.\kernels\deta\vision.cpp
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
// Build a Python module named TORCH_EXTENSION_NAME with pybind11
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Bind the Python function ms_deform_attn_forward to the C++ function of the same name
    m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
    // Bind the Python function ms_deform_attn_backward to the C++ function of the same name
    m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}
.\kernels\mra\cuda_kernel.h
#define WARP_SIZE 32          // thread-block (warp) size of 32
#define FULL_MASK 0xffffffff  // full mask: 32 bits, all ones
#define OPTIMAL_THREADS 256   // optimized thread count of 256

// CUDA kernel: for each batch, compute the per-block maximum values and their indices
__global__ void index_max_cuda_kernel(
    float *index_vals,       // [batch_size, 32, num_block]
    int   *indices,          // [batch_size, num_block]
    float *max_vals,         // [batch_size, A_num_block * 32]
    float *max_vals_scatter, // [batch_size, 32, num_block]
    long batch_size,         // batch size
    long A_num_block,        // number of blocks in A
    long B_num_block,        // number of blocks in B
    long num_block           // number of selected blocks
);

// CUDA kernel: compute selected 32x32 blocks of a dense matrix product, stored in sparse form
__global__ void mm_to_sparse_cuda_kernel(
    float *dense_A,  // [batch_size, A_num_block, dim, 32]
    float *dense_B,  // [batch_size, B_num_block, dim, 32]
    int   *indices,  // [batch_size, num_block]
    float *sparse_C, // [batch_size, num_block, 32, 32]
    long batch_size, // batch size
    long A_num_block,// number of blocks in A
    long B_num_block,// number of blocks in B
    long dim,        // feature dimension
    long num_block   // number of selected blocks
);

// CUDA kernel: multiply a block-sparse matrix by a dense matrix
__global__ void sparse_dense_mm_cuda_kernel(
    float *sparse_A, // [batch_size, num_block, 32, 32]
    int   *indices,  // [batch_size, num_block]
    float *dense_B,  // [batch_size, B_num_block, dim, 32]
    float *dense_C,  // [batch_size, A_num_block, dim, 32]
    long batch_size, // batch size
    long A_num_block,// number of blocks in A
    long B_num_block,// number of blocks in B
    long dim,        // feature dimension
    long num_block   // number of selected blocks
);

// CUDA kernel: sum a block-sparse matrix along the specified dimension
__global__ void reduce_sum_cuda_kernel(
    float *sparse_A, // [batch_size, num_block, 32, 32]
    int   *indices,  // [batch_size, num_block]
    float *dense_C,  // [batch_size, A_num_block, 32]
    long batch_size, // batch size
    long A_num_block,// number of blocks in A
    long B_num_block,// number of blocks in B
    long num_block   // number of selected blocks
);

// CUDA kernel: scatter a dense matrix into block-sparse form according to the indices
__global__ void scatter_cuda_kernel(
    float *dense_A,  // [batch_size, A_num_block, 32]
    int   *indices,  // [batch_size, num_block]
    float *sparse_C, // [batch_size, num_block, 32, 32]
    long batch_size, // batch size
    long A_num_block,// number of blocks in A
    long B_num_block,// number of blocks in B
    long num_block   // number of selected blocks
);
.\kernels\mra\cuda_launch.h
std::vector<at::Tensor> index_max_kernel(
at::Tensor index_vals,
at::Tensor indices,
int A_num_block,
int B_num_block
);
at::Tensor mm_to_sparse_kernel(
at::Tensor dense_A,
at::Tensor dense_B,
at::Tensor indices
);
at::Tensor sparse_dense_mm_kernel(
at::Tensor sparse_A,
at::Tensor indices,
at::Tensor dense_B,
int A_num_block
);
at::Tensor reduce_sum_kernel(
at::Tensor sparse_A,
at::Tensor indices,
int A_num_block,
int B_num_block
);
at::Tensor scatter_kernel(
at::Tensor dense_A,
at::Tensor indices,
int B_num_block
);
.\kernels\mra\torch_extension.cpp
// index_max: wrapper returning a vector of tensors; forwards its arguments to index_max_kernel
std::vector<at::Tensor> index_max(
    at::Tensor index_vals, // input tensor
    at::Tensor indices,    // input tensor
    int A_num_block,       // number of blocks in A
    int B_num_block        // number of blocks in B
) {
    return index_max_kernel(index_vals, indices, A_num_block, B_num_block);
}

// mm_to_sparse: wrapper returning a tensor; forwards its arguments to mm_to_sparse_kernel
at::Tensor mm_to_sparse(
    at::Tensor dense_A, // input tensor
    at::Tensor dense_B, // input tensor
    at::Tensor indices  // input tensor
) {
    return mm_to_sparse_kernel(dense_A, dense_B, indices);
}

// sparse_dense_mm: wrapper returning a tensor; forwards its arguments to sparse_dense_mm_kernel
at::Tensor sparse_dense_mm(
    at::Tensor sparse_A, // input tensor
    at::Tensor indices,  // input tensor
    at::Tensor dense_B,  // input tensor
    int A_num_block      // number of blocks in A
) {
    return sparse_dense_mm_kernel(sparse_A, indices, dense_B, A_num_block);
}

// reduce_sum: wrapper returning a tensor; forwards its arguments to reduce_sum_kernel
at::Tensor reduce_sum(
    at::Tensor sparse_A, // input tensor
    at::Tensor indices,  // input tensor
    int A_num_block,     // number of blocks in A
    int B_num_block      // number of blocks in B
) {
    return reduce_sum_kernel(sparse_A, indices, A_num_block, B_num_block);
}

// scatter: wrapper returning a tensor; forwards its arguments to scatter_kernel
at::Tensor scatter(
    at::Tensor dense_A, // input tensor
    at::Tensor indices, // input tensor
    int B_num_block     // number of blocks in B
) {
    return scatter_kernel(dense_A, indices, B_num_block);
}

// Define the Python extension module and bind each wrapper with a short description
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("index_max", &index_max, "index_max (CUDA)");
    m.def("mm_to_sparse", &mm_to_sparse, "mm_to_sparse (CUDA)");
    m.def("sparse_dense_mm", &sparse_dense_mm, "sparse_dense_mm (CUDA)");
    m.def("reduce_sum", &reduce_sum, "reduce_sum (CUDA)");
    m.def("scatter", &scatter, "scatter (CUDA)");
}
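The shape comments in cuda_kernel.h pin down the block layout: both operands are stored as [batch_size, num_block, dim, 32] panels, and `indices` selects which 32x32 output blocks to materialize. Below is a dense PyTorch reference for what mm_to_sparse computes, under the assumed (not stated in the header) encoding of each index as a_idx * B_num_block + b_idx:

import torch

def mm_to_sparse_reference(dense_A, dense_B, indices, B_num_block):
    """Dense reference: sparse_C[b, n] is the [32, 32] product of two [dim, 32]
    panels selected by indices[b, n]. The index encoding is an assumption made
    for illustration, not taken from the header."""
    batch_size, num_block = indices.shape
    sparse_C = dense_A.new_zeros(batch_size, num_block, 32, 32)
    for b in range(batch_size):
        for n in range(num_block):
            ab = int(indices[b, n])
            a_idx, b_idx = ab // B_num_block, ab % B_num_block
            # each output block contracts the dim axis of the two panels
            sparse_C[b, n] = dense_A[b, a_idx].transpose(0, 1) @ dense_B[b, b_idx]
    return sparse_C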
.\kernels\rwkv\wkv_op.cpp
typedef at::BFloat16 bf16;
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s);
void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);
void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
const int B = k.size(0);
const int T = k.size(1);
const int C = k.size(2);
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
void forward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
const int B = k.size(0);
const int T = k.size(1);
const int C = k.size(2);
cuda_forward_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void forward_with_state(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) {
const int B = k.size(0);
const int T = k.size(1);
const int C = k.size(2);
cuda_forward_with_state(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), s.data_ptr<float>());
}
void forward_with_state_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) {
const int B = k.size(0);
const int T = k.size(1);
const int C = k.size(2);
cuda_forward_with_state_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(), s.data_ptr<float>());
}
void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
const int B = k.size(0);
const int T = k.size(1);
const int C = k.size(2);
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}
void backward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
    const int B = k.size(0);
    const int T = k.size(1);
    const int C = k.size(2);
    cuda_backward_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
                       gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
}
// Expose each wrapper to Python via pybind11, with a short description
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "wkv forward");
    m.def("forward_bf16", &forward_bf16, "wkv forward bf16");
    m.def("forward_with_state", &forward_with_state, "wkv forward with state");
    m.def("forward_with_state_bf16", &forward_with_state_bf16, "wkv forward with state bf16");
    m.def("backward", &backward, "wkv backward");
    m.def("backward_bf16", &backward_bf16, "wkv backward bf16");
}
// Also register the same functions as custom ops in Torch's `wkv` library
TORCH_LIBRARY(wkv, m) {
    m.def("forward", forward);
    m.def("forward_bf16", forward_bf16);
    m.def("forward_with_state", forward_with_state);
    m.def("forward_with_state_bf16", forward_with_state_bf16);
    m.def("backward", backward);
    m.def("backward_bf16", backward_bf16);
}
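Because the wrappers are registered both through pybind11 and through TORCH_LIBRARY, the compiled extension can be driven either as an imported module or via torch.ops.wkv.*. A minimal call sketch, assuming the extension is already built and loaded, and that w/u are per-channel [C] tensors while k/v/y are [B, T, C], consistent with the k.size(0..2) extraction above:

import torch

B, T, C = 1, 8, 16
w = torch.randn(C, device="cuda")   # per-channel time decay (float32)
u = torch.randn(C, device="cuda")   # per-channel bonus (float32 in the fp32 path)
k = torch.randn(B, T, C, device="cuda")
v = torch.randn(B, T, C, device="cuda")
y = torch.empty_like(v)             # output buffer, written in place

# The op mutates `y` rather than returning it, matching the void C++ signature.
torch.ops.wkv.forward(w, u, k, v, y)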
.\kernels\yoso\common.h
.\kernels\yoso\common_cuda.h
.\kernels\yoso\common_cuda_device.h
// Lock-free open-addressing hash set: insert `value` with linear probing.
// Returns the slot index on success (or if the value is already present),
// or -1 once probing has wrapped around a full table.
template<typename T>
__device__ int set_insert(T *set, int set_size, T value) {
    int slot = value % set_size;
    int start_slot = slot;
    while (true) {
        T prev = atomicCAS(&set[slot], EMPTY_VALUE, value);
        if (prev == EMPTY_VALUE || prev == value) {
            return slot; // claimed an empty slot, or the value was already there
        }
        slot = (slot + 1) % set_size;
        if (slot == start_slot) {
            return -1; // table is full
        }
    }
    return -1;
}

// Look up `value` with the same linear-probing scheme; returns its slot or -1.
template<typename T>
__device__ int set_lookup(T *set, int set_size, T value) {
    int slot = value % set_size;
    int start_slot = slot;
    while (true) {
        if (set[slot] == value) {
            return slot;
        }
        slot = (slot + 1) % set_size;
        if (slot == start_slot) {
            return -1; // wrapped around without finding the value
        }
    }
    return -1;
}
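set_insert and set_lookup implement a lock-free open-addressing hash set: atomicCAS either claims an empty slot or reveals the value already stored there, and probing advances linearly with wrap-around. A single-threaded Python sketch of the same probing logic, with a plain compare-and-set standing in for atomicCAS and an assumed EMPTY_VALUE sentinel:

EMPTY_VALUE = -1  # assumed sentinel; the CUDA headers define their own constant

def set_insert(table, value):
    """Linear-probe insert; returns the slot index, or -1 if the table is full."""
    size = len(table)
    slot = value % size
    start = slot
    while True:
        if table[slot] in (EMPTY_VALUE, value):  # single-threaded stand-in for atomicCAS
            table[slot] = value
            return slot
        slot = (slot + 1) % size
        if slot == start:
            return -1

def set_lookup(table, value):
    """Linear-probe lookup; returns the slot index, or -1 if the value is absent."""
    size = len(table)
    slot = value % size
    start = slot
    while True:
        if table[slot] == value:
            return slot
        slot = (slot + 1) % size
        if slot == start:
            return -1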
// Cooperatively fill `buffer` with `init_value`: threads stride by `num_threads`,
// with barriers before and after so every thread sees a fully initialized buffer.
template<typename T>
__device__ void init_buffer(T init_value, T *buffer, int buffer_size, int num_threads, int thread_id) {
    __syncthreads();
    for (int i = 0; i < buffer_size; i = i + num_threads) {
        int offset_idx = i + thread_id;
        if (offset_idx < buffer_size) {
            buffer[offset_idx] = init_value;
        }
    }
    __syncthreads();
}

// Cooperative strided copy of `data_length` elements from src_pt to dist_pt,
// with barriers before and after the copy.
template<typename T>
__device__ void copy_data(T *src_pt, T *dist_pt, int data_length, int num_threads, int thread_id) {
    __syncthreads();
    for (int i = 0; i < data_length; i = i + num_threads) {
        int offset_idx = i + thread_id;
        if (offset_idx < data_length) {
            dist_pt[offset_idx] = src_pt[offset_idx];
        }
    }
    __syncthreads();
}

// Same as init_buffer, but without __syncthreads(); the caller synchronizes.
template<typename T>
__device__ void init_buffer_nonblocking(T init_value, T *buffer, int buffer_size, int num_threads, int thread_id) {
    for (int i = 0; i < buffer_size; i = i + num_threads) {
        int offset_idx = i + thread_id;
        if (offset_idx < buffer_size) {
            buffer[offset_idx] = init_value;
        }
    }
}

// Same as copy_data, but without __syncthreads(); the caller synchronizes.
template<typename T>
__device__ void copy_data_nonblocking(T *src_pt, T *dist_pt, int data_length, int num_threads, int thread_id) {
    for (int i = 0; i < data_length; i = i + num_threads) {
        int offset_idx = i + thread_id;
        if (offset_idx < data_length) {
            dist_pt[offset_idx] = src_pt[offset_idx];
        }
    }
}
.\kernels\yoso\fast_lsh_cumulation.h
// PyTorch C++ extension header
#include <torch/extension.h>
// ATen header
#include <ATen/ATen.h>
// STL vector container
#include <vector>

// Fast-hash kernel (version 1); returns multiple tensors
std::vector<at::Tensor> fast_hash_ver1_kernel(
    at::Tensor query_mask,   // query mask tensor
    at::Tensor query_vector, // query vector tensor
    at::Tensor key_mask,     // key mask tensor
    at::Tensor key_vector,   // key vector tensor
    int num_hash_f,          // number of hash functions
    int hash_code_len,       // hash code length
    bool use_cuda            // whether to use CUDA
);

// LSH cumulation kernel (version 1); returns a tensor
at::Tensor lsh_cumulation_ver1_kernel(
    at::Tensor query_mask,      // query mask tensor
    at::Tensor query_hash_code, // query hash code tensor
    at::Tensor key_mask,        // key mask tensor
    at::Tensor key_hash_code,   // key hash code tensor
    at::Tensor value,           // value tensor
    int hashtable_capacity,     // hash table capacity
    bool use_cuda               // whether to use CUDA
);

// Weighted LSH cumulation kernel (version 1); returns a tensor
at::Tensor lsh_weighted_cumulation_ver1_kernel(
    at::Tensor query_mask,      // query mask tensor
    at::Tensor query_hash_code, // query hash code tensor
    at::Tensor query_weight,    // query weight tensor
    at::Tensor key_mask,        // key mask tensor
    at::Tensor key_hash_code,   // key hash code tensor
    at::Tensor key_weight,      // key weight tensor
    at::Tensor value,           // value tensor
    int hashtable_capacity,     // hash table capacity
    bool use_cuda               // whether to use CUDA
);

// Versions 2, 3, and 4 of the weighted LSH cumulation kernel take the same
// parameters and return the same tensor type as version 1, so the
// per-parameter comments are not repeated.
at::Tensor lsh_weighted_cumulation_ver2_kernel(
at::Tensor query_mask,
at::Tensor query_hash_code,
at::Tensor query_weight,
at::Tensor key_mask,
at::Tensor key_hash_code,
at::Tensor key_weight,
at::Tensor value,
int hashtable_capacity,
bool use_cuda
);
at::Tensor lsh_weighted_cumulation_ver3_kernel(
at::Tensor query_mask,
at::Tensor query_hash_code,
at::Tensor query_weight,
at::Tensor key_mask,
at::Tensor key_hash_code,
at::Tensor key_weight,
at::Tensor value,
int hashtable_capacity,
bool use_cuda
);
at::Tensor lsh_weighted_cumulation_ver4_kernel(
at::Tensor query_mask,
at::Tensor query_hash_code,
at::Tensor query_weight,
at::Tensor key_mask,
at::Tensor key_hash_code,
at::Tensor key_weight,
at::Tensor value,
int hashtable_capacity,
bool use_cuda
);
.\kernels\yoso\fast_lsh_cumulation_cuda.h
__global__ void fast_hash_ver1_cuda_kernel(
    int *mask,         // [batch_size, num_vector] vector masks
    float *vector,     // [batch_size, num_vector, vector_dim] vector data
    int *Dmat,         // [3, num_part, vector_dim] partition matrices
    int *hash_code,    // [batch_size, num_vector, num_hash_f] hash codes
    int batch_size,    // batch size
    int num_vector,    // number of vectors
    int vector_dim,    // vector dimension
    int num_part,      // number of partitions
    int num_hash_f,    // number of hash functions
    int hash_code_len  // hash code length
);

__global__ void lsh_cumulation_ver1_step1_cuda_kernel(
    int *key_mask,           // [batch_size, num_key] key masks
    int *key_hash_code,      // [batch_size, num_key, num_hash_f] key hash codes
    float *value,            // [batch_size, num_key, value_dim] values
    float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, value_dim] hash-table values
    int batch_size,          // batch size
    int num_hash_f,          // number of hash functions
    int hashtable_capacity,  // hash table capacity
    int num_key,             // number of keys
    int value_dim,           // value dimension
    int offset_warp          // warp offset
);

__global__ void lsh_cumulation_ver1_step2_cuda_kernel(
    int *query_mask,         // [batch_size, num_query] query masks
    int *query_hash_code,    // [batch_size, num_query, num_hash_f] query hash codes
    float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, value_dim] hash-table values
    float *cumulation_value, // [batch_size, num_query, value_dim] accumulated values
    int batch_size,          // batch size
    int num_hash_f,          // number of hash functions
    int hashtable_capacity,  // hash table capacity
    int num_query,           // number of queries
    int value_dim,           // value dimension
    int offset_warp          // warp offset
);

__global__ void lsh_weighted_cumulation_ver1_step1_cuda_kernel(
    int *key_mask,           // [batch_size, num_key] key masks
    int *key_hash_code,      // [batch_size, num_key, num_hash_f] key hash codes
    float *key_weight,       // [batch_size, num_key, weight_dim] key weights
    float *value,            // [batch_size, num_key, value_dim] values
    float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE] hash-table values
    int batch_size,          // batch size
    int num_hash_f,          // number of hash functions
    int hashtable_capacity,  // hash table capacity
    int num_key,             // number of keys
    int value_dim,           // value dimension
    int weight_dim,          // weight dimension
    int offset_warp,         // warp offset
    int weight_idx           // weight index
);

__global__ void lsh_weighted_cumulation_ver1_step2_cuda_kernel(
    int *query_mask,         // [batch_size, num_query] query masks
    int *query_hash_code,    // [batch_size, num_query, num_hash_f] query hash codes
    float *query_weight,     // [batch_size, num_query, weight_dim] query weights
    float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE] hash-table values
    float *cumulation_value, // [batch_size, num_query, value_dim] accumulated values
    int batch_size,          // batch size
    int num_hash_f,          // number of hash functions
    int hashtable_capacity,  // hash table capacity
    int num_query,           // number of queries
    int value_dim,           // value dimension
    int weight_dim,          // weight dimension
    int offset_warp,         // warp offset
    int weight_idx           // weight index
);

__global__ void count_sort_step1_cuda_kernel(
    int *key_mask,          // [batch_size, num_key] key masks
    int *key_hash_code,     // [batch_size, num_key, num_hash_f] key hash codes
    int *count_sort_table,  // [batch_size, num_hash_f, hashtable_capacity] counting-sort table
    int batch_size,         // batch size
    int num_hash_f,         // number of hash functions
    int hashtable_capacity, // hash table capacity
    int num_key             // number of keys
);

__global__ void count_sort_step2_cuda_kernel(
    int *count_sort_table,  // [batch_size, num_hash_f, hashtable_capacity] counting-sort table
    int batch_size,         // batch size
    int num_hash_f,         // number of hash functions
    int hashtable_capacity  // hash table capacity
);

__global__ void count_sort_step3_cuda_kernel(
    int *key_mask,          // input: [batch_size, num_key] per-key masks
    int *key_hash_code,     // input: [batch_size, num_key, num_hash_f] per-key hash codes
    int *count_sort_table,  // input/output: [batch_size, num_hash_f, hashtable_capacity] counting-sort table
    int *key_sorted_idxes,  // output: [batch_size, num_hash_f, num_key] sorted key indices
    int batch_size,         // input: batch size
    int num_hash_f,         // input: number of hash functions
    int hashtable_capacity, // input: hash table capacity
    int num_key             // input: number of keys per batch
);

__global__ void extract_query_info_cuda_kernel(
    int *query_mask,        // input: [batch_size, num_query] per-query masks
    int *query_hash_code,   // input: [batch_size, num_query, num_hash_f] per-query hash codes
    int *count_sort_table,  // input: [batch_size, num_hash_f, hashtable_capacity] counting-sort table
    int *query_info,        // output: [batch_size, num_query, 2, num_hash_f] query info (key indices and hash-function indices)
    int batch_size,         // input: batch size
    int num_hash_f,         // input: number of hash functions
    int hashtable_capacity, // input: hash table capacity
    int num_query           // input: number of queries per batch
);

__global__ void lsh_weighted_cumulation_ver2_step2_cuda_kernel(
    int *query_mask,         // input: [batch_size, num_query] per-query masks
    int *query_info,         // input: [batch_size, num_query, 2, num_hash_f] query info (key indices and hash-function indices)
    int *key_sorted_idxes,   // input: [batch_size, num_hash_f, num_key] sorted key indices
    float *query_weight,     // input: [batch_size, num_query, weight_dim] query weights
    float *key_weight,       // input: [batch_size, num_key, weight_dim] key weights
    float *value,            // input: [batch_size, num_key, value_dim] key values
    float *cumulation_value, // output: [batch_size, num_query, value_dim] accumulated values
    int batch_size,          // input: batch size
    int num_hash_f,          // input: number of hash functions
    int num_query,           // input: number of queries per batch
    int num_key,             // input: number of keys per batch
    int value_dim,           // input: value dimension
    int weight_dim           // input: weight dimension
);

__global__ void lsh_weighted_cumulation_ver3_step2_cuda_kernel(
    int *query_sorted_idxes, // input: [batch_size, num_hash_f, num_query] sorted query indices
    int *key_mask,           // input: [batch_size, num_key] per-key masks
    int *key_info,           // input: [batch_size, num_key, 2, num_hash_f] key info (indices and hash-function indices)
    float *query_weight,     // input: [batch_size, num_query, weight_dim] query weights
    float *key_weight,       // input: [batch_size, num_key, weight_dim] key weights
    float *value,            // input: [batch_size, num_key, value_dim] key values
    float *cumulation_value, // output: [batch_size, num_query, value_dim] accumulated values
    int batch_size,          // input: batch size
    int num_hash_f,          // input: number of hash functions
    int num_query,           // input: number of queries per batch
    int num_key,             // input: number of keys per batch
    int value_dim,           // input: value dimension
    int weight_dim           // input: weight dimension
);

__global__ void lsh_weighted_cumulation_ver4_step2_cuda_kernel(
    int *query_sorted_idxes, // input: [batch_size, num_hash_f, num_query] sorted query indices
    int *key_mask,           // input: [batch_size, num_key] per-key masks
    int *key_info,           // input: [batch_size, num_key, 2, num_hash_f] key info (indices and hash-function indices)
    float *query_weight,     // input: [batch_size, num_query, weight_dim] query weights
    float *key_weight,       // input: [batch_size, num_key, weight_dim] key weights
    float *value,            // input: [batch_size, num_key, value_dim] key values
    float *cumulation_value, // output: [batch_size, num_query, value_dim] accumulated values
    int batch_size,          // input: batch size
    int num_hash_f,          // input: number of hash functions
    int num_query,           // input: number of queries per batch
    int num_key,             // input: number of keys per batch
    int value_dim,           // input: value dimension
    int weight_dim           // input: weight dimension
);
.\kernels\yoso\fast_lsh_cumulation_torch.cpp
// fast_hash: dispatches to the hashing kernel. Only version 1 exists, so the
// `version` argument is currently ignored and the ver1 kernel is always called.
std::vector<at::Tensor> fast_hash(
    at::Tensor query_mask,   // query masks, shape [batch_size, num_query]
    at::Tensor query_vector, // query vectors, shape [batch_size, num_query, vector_dim]
    at::Tensor key_mask,     // key masks, shape [batch_size, num_key]
    at::Tensor key_vector,   // key vectors, shape [batch_size, num_key, vector_dim]
    int num_hash_f,          // number of hash functions
    int hash_code_len,       // hash code length
    bool use_cuda,           // whether to use CUDA acceleration
    int version              // kernel version (unused)
) {
return fast_hash_ver1_kernel(
query_mask,
query_vector,
key_mask,
key_vector,
num_hash_f,
hash_code_len,
use_cuda
);
}
// lsh_cumulation: performs the LSH cumulation. Only version 1 exists, so the
// `version` argument is currently ignored and the ver1 kernel is always called.
at::Tensor lsh_cumulation(
    at::Tensor query_mask,      // query masks, shape [batch_size, num_query]
    at::Tensor query_hash_code, // query hash codes, shape [batch_size, num_query, num_hash_f]
    at::Tensor key_mask,        // key masks, shape [batch_size, num_key]
    at::Tensor key_hash_code,   // key hash codes, shape [batch_size, num_key, num_hash_f]
    at::Tensor value,           // values, shape [batch_size, num_key, value_dim]
    int hashtable_capacity,     // hash table capacity
    bool use_cuda,              // whether to use CUDA acceleration
    int version                 // kernel version (unused)
) {
return lsh_cumulation_ver1_kernel(
query_mask,
query_hash_code,
key_mask,
key_hash_code,
value,
hashtable_capacity,
use_cuda
);
}
// lsh_weighted_cumulation: dispatches to the weighted LSH cumulation kernel matching `version`
at::Tensor lsh_weighted_cumulation(
    at::Tensor query_mask,      // query masks, shape [batch_size, num_query]
    at::Tensor query_hash_code, // query hash codes, shape [batch_size, num_query, num_hash_f]
    at::Tensor query_weight,    // query weights, shape [batch_size, num_query, weight_dim]
    at::Tensor key_mask,        // key masks, shape [batch_size, num_key]
    at::Tensor key_hash_code,   // key hash codes, shape [batch_size, num_key, num_hash_f]
    at::Tensor key_weight,      // key weights, shape [batch_size, num_key, weight_dim]
    at::Tensor value,           // values, shape [batch_size, num_key, value_dim]
    int hashtable_capacity,     // hash table capacity
    bool use_cuda,              // whether to use CUDA acceleration
    int version                 // kernel version
) {
if (version == 1) {
return lsh_weighted_cumulation_ver1_kernel(
query_mask,
query_hash_code,
query_weight,
key_mask,
key_hash_code,
key_weight,
value,
hashtable_capacity,
use_cuda
);
} else if (version == 2) {
return lsh_weighted_cumulation_ver2_kernel(
query_mask,
query_hash_code,
query_weight,
key_mask,
key_hash_code,
key_weight,
value,
hashtable_capacity,
use_cuda
);
} else if (version == 3) {
return lsh_weighted_cumulation_ver3_kernel(
query_mask,
query_hash_code,
query_weight,
key_mask,
key_hash_code,
key_weight,
value,
hashtable_capacity,
use_cuda
);
} else if (version == 4) {
return lsh_weighted_cumulation_ver4_kernel(
query_mask,
query_hash_code,
query_weight,
key_mask,
key_hash_code,
key_weight,
value,
hashtable_capacity,
use_cuda
);
    } else {
        // Any other version number falls back to the version-3 kernel
return lsh_weighted_cumulation_ver3_kernel(
query_mask,
query_hash_code,
query_weight,
key_mask,
key_hash_code,
key_weight,
value,
hashtable_capacity,
use_cuda
);
}
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fast_hash", &fast_hash, "Fast Hash (CUDA)");
m.def("lsh_cumulation", &lsh_cumulation, "LSH Cumulation (CUDA)");
m.def("lsh_weighted_cumulation", &lsh_weighted_cumulation, "LSH Weighted Cumulation (CUDA)");
}
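Reading the two ver1 kernels together: step 1 scatters every key's value into the hash-table bucket selected by each of its hash codes, and step 2 gathers the buckets addressed by the query's codes, so a query accumulates the values of exactly the keys it collides with. A dense PyTorch sketch of that collision-counting view follows; the averaging over hash functions and the masking details are assumptions about the kernels, kept only for illustration:

import torch

def lsh_cumulation_reference(query_mask, query_hash_code, key_mask, key_hash_code, value):
    batch_size, num_query, num_hash_f = query_hash_code.shape
    out = value.new_zeros(batch_size, num_query, value.shape[-1])
    for f in range(num_hash_f):
        # collision[b, q, k] is True when query q and key k share hash code f
        collision = query_hash_code[:, :, None, f] == key_hash_code[:, None, :, f]
        collision = collision & (key_mask[:, None, :] != 0)
        out += collision.float() @ value
    out = out / num_hash_f  # assumed normalization over hash functions
    return out * (query_mask[:, :, None] != 0).float()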
.\modelcard.py
import copy
import json
import os
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import requests
import yaml
from huggingface_hub import model_info
from huggingface_hub.utils import HFValidationError
from . import __version__
from .models.auto.modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_CTC_MAPPING_NAMES,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
)
from .training_args import ParallelMode
from .utils import (
MODEL_CARD_NAME,
cached_file,
is_datasets_available,
is_offline_mode,
is_tf_available,
is_tokenizers_available,
is_torch_available,
logging,
)
TASK_MAPPING = {
"text-generation": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
"image-segmentation": MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
"fill-mask": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
"object-detection": MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
"question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
"text2text-generation": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
"text-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
"table-question-answering": MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
"audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
"automatic-speech-recognition": {**MODEL_FOR_CTC_MAPPING_NAMES, **MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES},
"zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
}
logger = logging.get_logger(__name__)
"""
# 初始化方法,用于创建模型卡片对象
def __init__(self, **kwargs):
# 发出警告,表示该类 `ModelCard` 已被弃用,并将在 Transformers 的第五版中移除
warnings.warn(
"The class `ModelCard` is deprecated and will be removed in version 5 of Transformers", FutureWarning
)
# 推荐的属性来源于 https://arxiv.org/abs/1810.03993(见论文)
# 设置模型细节
self.model_details = kwargs.pop("model_details", {})
# 设置预期使用
self.intended_use = kwargs.pop("intended_use", {})
# 设置因素
self.factors = kwargs.pop("factors", {})
# 设置度量
self.metrics = kwargs.pop("metrics", {})
# 设置评估数据
self.evaluation_data = kwargs.pop("evaluation_data", {})
# 设置训练数据
self.training_data = kwargs.pop("training_data", {})
# 设置定量分析
self.quantitative_analyses = kwargs.pop("quantitative_analyses", {})
# 设置伦理考虑
self.ethical_considerations = kwargs.pop("ethical_considerations", {})
# 设置注意事项和建议
self.caveats_and_recommendations = kwargs.pop("caveats_and_recommendations", {})
# 打开额外的属性
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
# 如果无法设置属性,则记录错误信息并抛出异常
logger.error(f"Can't set {key} with value {value} for {self}")
raise err
    # Save the model card object to a directory or file
    def save_pretrained(self, save_directory_or_file):
        """Save a model card object to the directory or file `save_directory_or_file`."""
        # If a directory is given, use the canonical filename so `from_pretrained` can find it
        if os.path.isdir(save_directory_or_file):
            output_model_card_file = os.path.join(save_directory_or_file, MODEL_CARD_NAME)
        else:
            output_model_card_file = save_directory_or_file
        # Serialize the model card to a JSON file
        self.to_json_file(output_model_card_file)
        logger.info(f"Model card saved in {output_model_card_file}")

    # Build a `ModelCard` from a Python dictionary
    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `ModelCard` from a Python dictionary of parameters."""
        return cls(**json_object)

    # Build a `ModelCard` from a JSON file
    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `ModelCard` from a json file of parameters."""
        # Read the JSON file contents
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        # Parse the JSON text into a Python dictionary
        dict_obj = json.loads(text)
        # Build a new `ModelCard` from the dictionary
        return cls(**dict_obj)

    # Equality: compare the underlying attribute dictionaries
    def __eq__(self, other):
        return self.__dict__ == other.__dict__

    # String representation: the JSON serialization
    def __repr__(self):
        return str(self.to_json_string())

    # Serialize this instance to a Python dictionary
    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        # Deep-copy all attributes into the output dictionary
        output = copy.deepcopy(self.__dict__)
        return output

    # Serialize this instance to a JSON string
    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        # Dump the dictionary as indented, key-sorted JSON with a trailing newline
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    # Save this instance to a JSON file
    def to_json_file(self, json_file_path):
        """Save this instance to a json file."""
        # Write the JSON string with UTF-8 encoding
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string())
AUTOGENERATED_TRAINER_COMMENT = """
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->
"""
AUTOGENERATED_KERAS_COMMENT = """
<!-- This model card has been generated automatically according to the information Keras had access to. You should
probably proofread and complete it, then remove this comment. -->
"""
# Map task tags to human-readable task names
TASK_TAG_TO_NAME_MAPPING = {
    "fill-mask": "Masked Language Modeling",
    "image-classification": "Image Classification",
    "image-segmentation": "Image Segmentation",
    "multiple-choice": "Multiple Choice",
    "object-detection": "Object Detection",
    "question-answering": "Question Answering",
    "summarization": "Summarization",
    "table-question-answering": "Table Question Answering",
    "text-classification": "Text Classification",
    "text-generation": "Causal Language Modeling",
    "text2text-generation": "Sequence-to-sequence Language Modeling",
    "token-classification": "Token Classification",
    "translation": "Translation",
    "zero-shot-classification": "Zero Shot Classification",
    "automatic-speech-recognition": "Automatic Speech Recognition",
    "audio-classification": "Audio Classification",
}
METRIC_TAGS = [
    "accuracy",              # model accuracy
    "bleu",                  # machine translation quality
    "f1",                    # accuracy for classification and retrieval tasks
    "matthews_correlation",  # correlation measure for binary classification
    "pearsonr",              # linear correlation between two variables
    "precision",             # precision of a classification model
    "recall",                # recall of a classification model
    "rouge",                 # quality of text summarization models
    "sacrebleu",             # BLEU score for machine translation tasks
    "spearmanr",             # rank (monotonic) correlation between two variables
    "wer",                   # word error rate for automatic speech recognition
]
def _listify(obj):
    if obj is None:
        return []  # None becomes an empty list
    elif isinstance(obj, str):
        return [obj]  # a single string becomes a one-element list
    else:
        return obj  # anything else is returned unchanged

def _insert_values_as_list(metadata, name, values):
    if values is None:
        return metadata  # nothing to insert
    if isinstance(values, str):
        values = [values]  # wrap a single string in a list
    values = [v for v in values if v is not None]  # drop None entries
    if len(values) == 0:
        return metadata  # nothing left to insert
    metadata[name] = values  # store the cleaned list under `name`
    return metadata  # return the updated metadata

def infer_metric_tags_from_eval_results(eval_results):
    if eval_results is None:
        return {}  # no evaluation results, no metric tags
    result = {}
    for key in eval_results.keys():
        if key.lower().replace(" ", "_") in METRIC_TAGS:
            result[key.lower().replace(" ", "_")] = key  # keep keys matching known metric tags
        elif key.lower() == "rouge1":
            result["rouge"] = key  # special-case "rouge1", mapping it to "rouge"
    return result  # return the tag-to-key mapping
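For example, evaluation keys that match a known metric tag after lower-casing and underscore-normalization are kept, and "rouge1" is special-cased:

>>> infer_metric_tags_from_eval_results({"Accuracy": 0.91, "rouge1": 0.44, "loss": 0.2})
{'accuracy': 'Accuracy', 'rouge': 'rouge1'}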
def _insert_value(metadata, name, value):
    if value is None:
        return metadata  # nothing to insert
    metadata[name] = value  # store the value under `name`
    return metadata  # return the updated metadata

def is_hf_dataset(dataset):
    if not is_datasets_available():
        return False  # the datasets library is not installed
    from datasets import Dataset, IterableDataset
    return isinstance(dataset, (Dataset, IterableDataset))  # True for Dataset or IterableDataset instances

def _get_mapping_values(mapping):
    result = []
    for v in mapping.values():
        if isinstance(v, (tuple, list)):
            result += list(v)  # flatten tuples and lists into the result
        else:
            result.append(v)  # append scalar values directly
    return result  # all mapping values as a flat list
@dataclass
class TrainingSummary:
    model_name: str                                   # model name
    language: Optional[Union[str, List[str]]] = None  # language(s): string or list of strings
    license: Optional[str] = None                     # license information
tags: Optional[Union[str, List[str]]] = None
finetuned_from: Optional[str] = None
tasks: Optional[Union[str, List[str]]] = None
dataset: Optional[Union[str, List[str]]] = None
dataset_tags: Optional[Union[str, List[str]]] = None
dataset_args: Optional[Union[str, List[str]]] = None
dataset_metadata: Optional[Dict[str, Any]] = None
eval_results: Optional[Dict[str, float]] = None
eval_lines: Optional[List[str]] = None
hyperparameters: Optional[Dict[str, Any]] = None
source: Optional[str] = "trainer"
def __post_init__(self):
if (
self.license is None
and not is_offline_mode()
and self.finetuned_from is not None
and len(self.finetuned_from) > 0
):
try:
info = model_info(self.finetuned_from)
for tag in info.tags:
if tag.startswith("license:"):
self.license = tag[8:]
except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError, HFValidationError):
pass
def create_model_index(self, metric_mapping):
model_index = {"name": self.model_name}
dataset_names = _listify(self.dataset)
dataset_tags = _listify(self.dataset_tags)
dataset_args = _listify(self.dataset_args)
dataset_metadata = _listify(self.dataset_metadata)
if len(dataset_args) < len(dataset_tags):
dataset_args = dataset_args + [None] * (len(dataset_tags) - len(dataset_args))
dataset_mapping = dict(zip(dataset_tags, dataset_names))
dataset_arg_mapping = dict(zip(dataset_tags, dataset_args))
dataset_metadata_mapping = dict(zip(dataset_tags, dataset_metadata))
task_mapping = {
task: TASK_TAG_TO_NAME_MAPPING[task] for task in _listify(self.tasks) if task in TASK_TAG_TO_NAME_MAPPING
}
model_index["results"] = []
if len(task_mapping) == 0 and len(dataset_mapping) == 0:
return [model_index]
if len(task_mapping) == 0:
task_mapping = {None: None}
if len(dataset_mapping) == 0:
dataset_mapping = {None: None}
all_possibilities = [(task_tag, ds_tag) for task_tag in task_mapping for ds_tag in dataset_mapping]
for task_tag, ds_tag in all_possibilities:
result = {}
if task_tag is not None:
result["task"] = {"name": task_mapping[task_tag], "type": task_tag}
if ds_tag is not None:
metadata = dataset_metadata_mapping.get(ds_tag, {})
result["dataset"] = {
"name": dataset_mapping[ds_tag],
"type": ds_tag,
**metadata,
}
if dataset_arg_mapping[ds_tag] is not None:
result["dataset"]["args"] = dataset_arg_mapping[ds_tag]
if len(metric_mapping) > 0:
result["metrics"] = []
for metric_tag, metric_name in metric_mapping.items():
result["metrics"].append(
{
"name": metric_name,
"type": metric_tag,
"value": self.eval_results[metric_name],
}
)
if "task" in result and "dataset" in result and "metrics" in result:
model_index["results"].append(result)
else:
logger.info(f"Dropping the following result as it does not have all the necessary fields:\n{result}")
return [model_index]
def create_metadata(self):
metric_mapping = infer_metric_tags_from_eval_results(self.eval_results)
metadata = {}
metadata = _insert_values_as_list(metadata, "language", self.language)
metadata = _insert_value(metadata, "license", self.license)
if self.finetuned_from is not None and isinstance(self.finetuned_from, str) and len(self.finetuned_from) > 0:
metadata = _insert_value(metadata, "base_model", self.finetuned_from)
metadata = _insert_values_as_list(metadata, "tags", self.tags)
metadata = _insert_values_as_list(metadata, "datasets", self.dataset_tags)
metadata = _insert_values_as_list(metadata, "metrics", list(metric_mapping.keys()))
metadata["model-index"] = self.create_model_index(metric_mapping)
return metadata
@classmethod
def from_trainer(
cls,
trainer,
language=None,
license=None,
tags=None,
model_name=None,
finetuned_from=None,
tasks=None,
dataset_tags=None,
dataset_metadata=None,
dataset=None,
dataset_args=None,
):
one_dataset = trainer.eval_dataset if trainer.eval_dataset is not None else trainer.train_dataset
if is_hf_dataset(one_dataset) and (dataset_tags is None or dataset_args is None or dataset_metadata is None):
default_tag = one_dataset.builder_name
if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
if dataset_metadata is None:
dataset_metadata = [{"config": one_dataset.config_name, "split": str(one_dataset.split)}]
if dataset_tags is None:
dataset_tags = [default_tag]
if dataset_args is None:
dataset_args = [one_dataset.config_name]
if dataset is None and dataset_tags is not None:
dataset = dataset_tags
if (
finetuned_from is None
and hasattr(trainer.model.config, "_name_or_path")
and not os.path.isdir(trainer.model.config._name_or_path)
):
finetuned_from = trainer.model.config._name_or_path
if tasks is None:
model_class_name = trainer.model.__class__.__name__
for task, mapping in TASK_MAPPING.items():
if model_class_name in _get_mapping_values(mapping):
tasks = task
if model_name is None:
model_name = Path(trainer.args.output_dir).name
if len(model_name) == 0:
model_name = finetuned_from
if tags is None:
tags = ["generated_from_trainer"]
elif isinstance(tags, str) and tags != "generated_from_trainer":
tags = [tags, "generated_from_trainer"]
elif "generated_from_trainer" not in tags:
tags.append("generated_from_trainer")
_, eval_lines, eval_results = parse_log_history(trainer.state.log_history)
hyperparameters = extract_hyperparameters_from_trainer(trainer)
return cls(
language=language,
license=license,
tags=tags,
model_name=model_name,
finetuned_from=finetuned_from,
tasks=tasks,
dataset=dataset,
dataset_tags=dataset_tags,
dataset_args=dataset_args,
dataset_metadata=dataset_metadata,
eval_results=eval_results,
eval_lines=eval_lines,
hyperparameters=hyperparameters,
)
@classmethod
def from_keras(
cls,
model,
model_name,
keras_history=None,
language=None,
license=None,
tags=None,
finetuned_from=None,
tasks=None,
dataset_tags=None,
dataset=None,
dataset_args=None,
):
if dataset is not None:
if is_hf_dataset(dataset) and (dataset_tags is None or dataset_args is None):
default_tag = dataset.builder_name
if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
if dataset_tags is None:
dataset_tags = [default_tag]
if dataset_args is None:
dataset_args = [dataset.config_name]
if dataset is None and dataset_tags is not None:
dataset = dataset_tags
if (
finetuned_from is None
and hasattr(model.config, "_name_or_path")
and not os.path.isdir(model.config._name_or_path)
):
finetuned_from = model.config._name_or_path
if tasks is None:
model_class_name = model.__class__.__name__
for task, mapping in TASK_MAPPING.items():
                if model_class_name in _get_mapping_values(mapping):
                    tasks = task
        # Add `generated_from_keras_callback` to the tags
        if tags is None:
            tags = ["generated_from_keras_callback"]
        elif isinstance(tags, str) and tags != "generated_from_keras_callback":
            tags = [tags, "generated_from_keras_callback"]
        elif "generated_from_keras_callback" not in tags:
            tags.append("generated_from_keras_callback")
def parse_keras_history(logs):
if hasattr(logs, "history"):
if not hasattr(logs, "epoch"):
return None, [], {}
logs.history["epoch"] = logs.epoch
logs = logs.history
else:
logs = {log_key: [single_dict[log_key] for single_dict in logs] for log_key in logs[0]}
lines = []
for i in range(len(logs["epoch"])):
epoch_dict = {log_key: log_value_list[i] for log_key, log_value_list in logs.items()}
values = {}
for k, v in epoch_dict.items():
if k.startswith("val_"):
k = "validation_" + k[4:]
elif k != "epoch":
k = "train_" + k
splits = k.split("_")
name = " ".join([part.capitalize() for part in splits])
values[name] = v
lines.append(values)
eval_results = lines[-1]
return logs, lines, eval_results
def parse_log_history(log_history):
idx = 0
while idx < len(log_history) and "train_runtime" not in log_history[idx]:
idx += 1
if idx == len(log_history):
idx -= 1
while idx >= 0 and "eval_loss" not in log_history[idx]:
idx -= 1
if idx >= 0:
return None, None, log_history[idx]
else:
return None, None, None
train_log = log_history[idx]
lines = []
training_loss = "No log"
for i in range(idx):
if "loss" in log_history[i]:
training_loss = log_history[i]["loss"]
if "eval_loss" in log_history[i]:
metrics = log_history[i].copy()
_ = metrics.pop("total_flos", None)
epoch = metrics.pop("epoch", None)
step = metrics.pop("step", None)
_ = metrics.pop("eval_runtime", None)
_ = metrics.pop("eval_samples_per_second", None)
_ = metrics.pop("eval_steps_per_second", None)
_ = metrics.pop("eval_jit_compilation_time", None)
values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step}
for k, v in metrics.items():
if k == "eval_loss":
values["Validation Loss"] = v
else:
splits = k.split("_")
name = " ".join([part.capitalize() for part in splits[1:]])
values[name] = v
lines.append(values)
idx = len(log_history) - 1
while idx >= 0 and "eval_loss" not in log_history[idx]:
idx -= 1
if idx > 0:
eval_results = {}
for key, value in log_history[idx].items():
if key.startswith("eval_"):
key = key[5:]
if key not in ["runtime", "samples_per_second", "steps_per_second", "epoch", "step"]:
camel_cased_key = " ".join([part.capitalize() for part in key.split("_")])
eval_results[camel_cased_key] = value
return train_log, lines, eval_results
else:
return train_log, lines, None
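# Editor's illustrative sketch (hypothetical log data): `parse_log_history` pairs each
# eval entry with the most recent training loss and strips runtime bookkeeping keys.
example_history = [
    {"loss": 0.8, "epoch": 1.0, "step": 10},
    {"eval_loss": 0.7, "eval_accuracy": 0.9, "eval_runtime": 1.5, "epoch": 1.0, "step": 10},
    {"train_runtime": 12.3, "total_flos": 1e15},
]
train_log, table_lines, final_eval = parse_log_history(example_history)
# table_lines[0] == {"Training Loss": 0.8, "Epoch": 1.0, "Step": 10,
#                    "Validation Loss": 0.7, "Accuracy": 0.9}
# final_eval == {"Loss": 0.7, "Accuracy": 0.9}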
def extract_hyperparameters_from_keras(model):
from .modeling_tf_utils import keras
hyperparameters = {}
if hasattr(model, "optimizer") and model.optimizer is not None:
hyperparameters["optimizer"] = model.optimizer.get_config()
else:
hyperparameters["optimizer"] = None
hyperparameters["training_precision"] = keras.mixed_precision.global_policy().name
return hyperparameters
def _maybe_round(v, decimals=4):
if isinstance(v, float) and len(str(v).split(".")) > 1 and len(str(v).split(".")[1]) > decimals:
return f"{v:.{decimals}f}"
return str(v)
def _regular_table_line(values, col_widths):
values_with_space = [f"| {v}" + " " * (w - len(v) + 1) for v, w in zip(values, col_widths)]
return "".join(values_with_space) + "|\n"
def _second_table_line(col_widths):
values = ["|:" + "-" * w + ":" for w in col_widths]
return "".join(values) + "|\n"
def make_markdown_table(lines):
"""
Create a nice Markdown table from the results in `lines`.
"""
if lines is None or len(lines) == 0:
return ""
col_widths = {key: len(str(key)) for key in lines[0].keys()}
for line in lines:
for key, value in line.items():
if col_widths[key] < len(_maybe_round(value)):
col_widths[key] = len(_maybe_round(value))
table = _regular_table_line(list(lines[0].keys()), list(col_widths.values()))
table += _second_table_line(list(col_widths.values()))
for line in lines:
table += _regular_table_line([_maybe_round(v) for v in line.values()], list(col_widths.values()))
return table
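# Editor's illustrative sketch: one result row rendered as a centered Markdown table.
example_rows = [{"Epoch": 1, "Validation Loss": 0.71234567}]
print(make_markdown_table(example_rows))
# | Epoch | Validation Loss |
# |:-----:|:---------------:|
# | 1     | 0.7123          |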
_TRAINING_ARGS_KEYS = [
"learning_rate",
"train_batch_size",
"eval_batch_size",
"seed",
]
def extract_hyperparameters_from_trainer(trainer):
hyperparameters = {k: getattr(trainer.args, k) for k in _TRAINING_ARGS_KEYS}
if trainer.args.parallel_mode not in [ParallelMode.NOT_PARALLEL, ParallelMode.NOT_DISTRIBUTED]:
hyperparameters["distributed_type"] = (
"multi-GPU" if trainer.args.parallel_mode == ParallelMode.DISTRIBUTED else trainer.args.parallel_mode.value
)
if trainer.args.world_size > 1:
hyperparameters["num_devices"] = trainer.args.world_size
if trainer.args.gradient_accumulation_steps > 1:
hyperparameters["gradient_accumulation_steps"] = trainer.args.gradient_accumulation_steps
total_train_batch_size = (
trainer.args.train_batch_size * trainer.args.world_size * trainer.args.gradient_accumulation_steps
)
if total_train_batch_size != hyperparameters["train_batch_size"]:
hyperparameters["total_train_batch_size"] = total_train_batch_size
total_eval_batch_size = trainer.args.eval_batch_size * trainer.args.world_size
if total_eval_batch_size != hyperparameters["eval_batch_size"]:
hyperparameters["total_eval_batch_size"] = total_eval_batch_size
if trainer.args.adafactor:
hyperparameters["optimizer"] = "Adafactor"
else:
hyperparameters["optimizer"] = (
f"Adam with betas=({trainer.args.adam_beta1},{trainer.args.adam_beta2}) and"
f" epsilon={trainer.args.adam_epsilon}"
)
hyperparameters["lr_scheduler_type"] = trainer.args.lr_scheduler_type.value
if trainer.args.warmup_ratio != 0.0:
hyperparameters["lr_scheduler_warmup_ratio"] = trainer.args.warmup_ratio
if trainer.args.warmup_steps != 0.0:
hyperparameters["lr_scheduler_warmup_steps"] = trainer.args.warmup_steps
if trainer.args.max_steps != -1:
hyperparameters["training_steps"] = trainer.args.max_steps
else:
hyperparameters["num_epochs"] = trainer.args.num_train_epochs
if trainer.args.fp16:
if trainer.use_apex:
hyperparameters["mixed_precision_training"] = f"Apex, opt level {trainer.args.fp16_opt_level}"
else:
hyperparameters["mixed_precision_training"] = "Native AMP"
if trainer.args.label_smoothing_factor != 0.0:
hyperparameters["label_smoothing_factor"] = trainer.args.label_smoothing_factor
return hyperparameters
.\modeling_attn_mask_utils.py
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
@dataclass
class AttentionMaskConverter:
"""
A utility attention mask class that allows one to:
- Create a causal 4d mask
    - Create a causal 4d mask with sliding window
- Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
key_value_length) that can be multiplied with attention scores
Examples:
```
>>> import torch
>>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
>>> converter = AttentionMaskConverter(True)
>>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
```
Parameters:
is_causal (`bool`):
Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
sliding_window (`int`, *optional*):
Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
"""
is_causal: bool
sliding_window: int
def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
self.is_causal = is_causal
self.sliding_window = sliding_window
if self.sliding_window is not None and self.sliding_window <= 0:
raise ValueError(
f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
)
def to_causal_4d(
self,
batch_size: int,
query_length: int,
key_value_length: int,
dtype: torch.dtype,
device: Union[torch.device, "str"] = "cpu",
) -> Optional[torch.Tensor]:
"""
Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
bias to upper right hand triangular matrix (causal mask).
"""
if not self.is_causal:
raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
input_shape = (batch_size, query_length)
past_key_values_length = key_value_length - query_length
causal_4d_mask = None
if input_shape[-1] > 1 or self.sliding_window is not None:
causal_4d_mask = self._make_causal_mask(
input_shape,
dtype,
device=device,
past_key_values_length=past_key_values_length,
sliding_window=self.sliding_window,
)
return causal_4d_mask
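    # Editor's sketch (not in the original source): for example,
    #     converter = AttentionMaskConverter(is_causal=True)
    #     bias = converter.to_causal_4d(1, 3, 3, dtype=torch.float32)
    # yields a (1, 1, 3, 3) bias whose entries above the diagonal are
    # torch.finfo(torch.float32).min and 0 elsewhere.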
def to_4d(
self,
attention_mask_2d: torch.Tensor,
query_length: int,
dtype: torch.dtype,
key_value_length: Optional[int] = None,
) -> torch.Tensor:
"""
Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
causal, a causal mask will be added.
"""
input_shape = (attention_mask_2d.shape[0], query_length)
causal_4d_mask = None
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
if key_value_length is None:
raise ValueError(
"This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
)
past_key_values_length = key_value_length - query_length
causal_4d_mask = self._make_causal_mask(
input_shape,
dtype,
device=attention_mask_2d.device,
past_key_values_length=past_key_values_length,
sliding_window=self.sliding_window,
)
elif self.sliding_window is not None:
raise NotImplementedError("Sliding window is currently only implemented for causal masking")
expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
attention_mask_2d.device
)
if causal_4d_mask is not None:
expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
expanded_4d_mask = expanded_attn_mask
return expanded_4d_mask
@staticmethod
def _make_causal_mask(
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
        sliding_window: Optional[int] = None,
    ):
        """
        Make the causal mask used for uni-directional (causal) self-attention.
        """
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
if sliding_window is not None:
diagonal = past_key_values_length - sliding_window - 1
context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
mask.masked_fill_(context_mask, torch.finfo(dtype).min)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
@staticmethod
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
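    # Editor's sketch (not in the original source): for a padding mask
    #     bias = AttentionMaskConverter._expand_mask(torch.tensor([[1, 1, 0]]), torch.float32)
    # the result has shape (1, 1, 3, 3) with torch.finfo(torch.float32).min in the last
    # column (the padded key position) and 0 elsewhere.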
@staticmethod
def _unmask_unattended(
expanded_mask: torch.FloatTensor,
min_dtype: float,
device: Optional[torch.device] = None
):
"""
Unmasks the unattended positions in the attention matrix.
"""
"""
Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
Details: https://github.com/pytorch/pytorch/issues/110213
`expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
`attention_mask` is [bsz, src_seq_len].
The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
For example, if `expanded_mask` is (e.g. here left-padding case)
```
[[[[0, 0, 0],
[0, 0, 0],
[0, 0, 1]]],
[[[1, 0, 0],
[1, 1, 0],
[1, 1, 1]]],
[[[0, 0, 0],
[0, 1, 0],
[0, 1, 1]]]]
```
then the modified `expanded_mask` will be
```
[[[[1, 1, 1], <-- modified
[1, 1, 1], <-- modified
[0, 0, 1]]],
[[[1, 0, 0],
[1, 1, 0],
[1, 1, 1]]],
[[[1, 1, 1], <-- modified
[0, 1, 0],
[0, 1, 1]]]]
"""
if expanded_mask.dtype == torch.bool:
raise ValueError(
"AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
)
return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
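# Editor's illustrative sketch (not part of the original file): a fully-masked row is
# zeroed out so SDPA's memory-efficient kernel never receives an all-minimum row.
example_min = torch.finfo(torch.float32).min
example_mask = torch.full((1, 1, 2, 2), example_min)
example_mask[0, 0, 1, 1] = 0.0
unmasked = AttentionMaskConverter._unmask_unattended(example_mask, example_min)
# Row 0 was entirely masked -> multiplied by 0 -> all zeros (attends everywhere);
# row 1 is left untouched.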
def _prepare_4d_causal_attention_mask(
attention_mask: Optional[torch.Tensor],
input_shape: Union[torch.Size, Tuple, List],
inputs_embeds: torch.Tensor,
past_key_values_length: int,
sliding_window: Optional[int] = None,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`
Args:
attention_mask (`torch.Tensor` or `None`):
A 2D attention mask of shape `(batch_size, key_value_length)`
input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
The input shape should be a tuple that defines `(batch_size, query_length)`.
inputs_embeds (`torch.Tensor`):
The embedded inputs as a torch Tensor.
past_key_values_length (`int`):
The length of the key value cache.
sliding_window (`int`, *optional*):
If the model uses windowed attention, a sliding window should be passed.
"""
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
key_value_length = input_shape[-1] + past_key_values_length
if attention_mask is not None and len(attention_mask.shape) == 2:
attention_mask = attn_mask_converter.to_4d(
attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
)
elif attention_mask is not None and len(attention_mask.shape) == 4:
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
if tuple(attention_mask.shape) != expected_shape:
raise ValueError(
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
)
else:
inverted_mask = 1.0 - attention_mask
attention_mask = inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
)
else:
attention_mask = attn_mask_converter.to_causal_4d(
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
return attention_mask
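# Editor's illustrative sketch (hypothetical shapes): turning a 2D padding mask into
# the 4D causal bias consumed by decoder layers.
example_embeds = torch.zeros(2, 4, 8)
example_mask2d = torch.ones(2, 4, dtype=torch.long)
example_mask4d = _prepare_4d_causal_attention_mask(
    example_mask2d, (2, 4), example_embeds, past_key_values_length=0
)
# example_mask4d.shape == (2, 1, 4, 4): the causal bias merged with the padding mask.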
def _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask: Optional[torch.Tensor],
input_shape: Union[torch.Size, Tuple, List],
inputs_embeds: torch.Tensor,
past_key_values_length: int,
sliding_window: Optional[int] = None,
):
"""
Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
    In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases
    `query_length == 1` and `key_value_length == query_length`, and rely instead on SDPA's `is_causal` argument to
    apply causal masking.
"""
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
key_value_length = input_shape[-1] + past_key_values_length
batch_size, query_length = input_shape
is_tracing = (
torch.jit.is_tracing()
or isinstance(inputs_embeds, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
if attention_mask is not None:
if len(attention_mask.shape) == 4:
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
if tuple(attention_mask.shape) != expected_shape:
raise ValueError(
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
)
else:
inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
attention_mask = inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
)
return attention_mask
elif not is_tracing and torch.all(attention_mask == 1):
if query_length == 1:
attention_mask = None
elif key_value_length == query_length:
attention_mask = None
else:
pass
elif query_length > 1 and key_value_length != query_length:
attention_mask = True
elif is_tracing:
raise ValueError(
'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
)
if attention_mask is None:
expanded_4d_mask = None
elif attention_mask is True:
expanded_4d_mask = attn_mask_converter.to_causal_4d(
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
else:
expanded_4d_mask = attn_mask_converter.to_4d(
attention_mask,
input_shape[-1],
dtype=inputs_embeds.dtype,
key_value_length=key_value_length,
)
if not is_tracing and expanded_4d_mask.device.type == "cuda":
expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
)
return expanded_4d_mask
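# Editor's illustrative sketch: with an all-ones mask and query_length == key_value_length,
# the SDPA variant returns None so the kernel's own `is_causal=True` path can be used instead.
example_sdpa_embeds = torch.zeros(2, 4, 8)
example_sdpa_mask = torch.ones(2, 4, dtype=torch.long)
assert _prepare_4d_causal_attention_mask_for_sdpa(
    example_sdpa_mask, (2, 4), example_sdpa_embeds, past_key_values_length=0
) is None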
def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`
Args:
mask (`torch.Tensor` or `None`):
A 2D attention mask of shape `(batch_size, key_value_length)`
dtype (`torch.dtype`):
The torch dtype the created mask shall have.
tgt_len (`int`):
The target length or query length the created mask shall have.
"""
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`
Args:
mask (`torch.Tensor` or `None`):
A 2D attention mask of shape `(batch_size, key_value_length)`
dtype (`torch.dtype`):
The torch dtype the created mask shall have.
tgt_len (`int`):
The target length or query length the created mask shall have.
"""
batch_size, key_value_length = mask.shape
tgt_len = tgt_len if tgt_len is not None else key_value_length
is_tracing = (
torch.jit.is_tracing()
or isinstance(mask, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
if torch.all(mask == 1):
if is_tracing:
pass
elif tgt_len == 1:
return None
elif key_value_length == tgt_len:
return None
else:
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
else:
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
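# Editor's illustrative sketch: an all-ones (nothing masked) 2D mask collapses to None
# for SDPA, while a mask with padding expands to the usual 4D additive bias.
assert _prepare_4d_attention_mask_for_sdpa(torch.ones(1, 3, dtype=torch.long), torch.float32) is None
padded_bias = _prepare_4d_attention_mask_for_sdpa(torch.tensor([[1, 1, 0]]), torch.float32)
# padded_bias.shape == (1, 1, 3, 3)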
def create_causal_mask(
    input_shape: Union[tuple[int], list[int], torch.Size],
    dtype: torch.dtype,
    device: torch.device,
    past_key_values_length: int = 0,
    sliding_window: Optional[int] = None,
) -> Optional[torch.Tensor]:
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`.
    Args:
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
            The input shape should be a tuple that defines `(batch_size, query_length)`.
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
        device (`torch.device`):
            The torch device the created mask shall have.
        past_key_values_length (`int`, *optional*, defaults to 0):
            The length of the key value cache.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.
    """
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
key_value_length = past_key_values_length + input_shape[-1]
attention_mask = attn_mask_converter.to_causal_4d(
input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
)
return attention_mask
.\modeling_flax_outputs.py
from typing import Dict, Optional, Tuple
import flax
import jax.numpy as jnp
from .utils import ModelOutput
@flax.struct.dataclass
class FlaxBaseModelOutput(ModelOutput):
"""
    Base class for model's outputs, with potential hidden states and attentions.
    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
"""
last_hidden_state: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxBaseModelOutputWithNoAttention(ModelOutput):
"""
    Base class for model's outputs, with potential hidden states, but no attentions.
    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
            for the output of each layer) of shape `(batch_size, num_channels, height, width)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
last_hidden_state: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxBaseModelOutputWithPoolingAndNoAttention(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
Args:
last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
Last layer hidden-state after a pooling operation on the spatial dimensions.
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the
model at the output of each layer plus the optional initial embedding outputs.
"""
last_hidden_state: jnp.ndarray = None
pooler_output: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxImageClassifierOutputWithNoAttention(ModelOutput):
"""
Base class for outputs of image classification models.
Args:
logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when
`config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
called feature maps) of the model at the output of each stage.
"""
logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxBaseModelOutputWithPast(ModelOutput):
"""
Base class for model's outputs, with potential hidden states and attentions.
"""
last_hidden_state: jnp.ndarray = None
past_key_values: Optional[Dict[str, jnp.ndarray]] = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxBaseModelOutputWithPooling(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
Args:
last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token) further processed by a
Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
prediction (classification) objective during pretraining.
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: jnp.ndarray = None
pooler_output: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
"""
Args:
last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token) after further processing
through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
the classification token after processing through a linear layer and a tanh activation function. The linear
layer weights are trained from the next sentence prediction (classification) objective during pretraining.
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
input) to speed up sequential decoding.
"""
last_hidden_state: jnp.ndarray = None
    # Remaining output fields, all defaulting to None
    pooler_output: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
    cross_attentions: Optional[Tuple[jnp.ndarray]] = None
# Dataclass declared with @flax.struct.dataclass, inheriting from ModelOutput
@flax.struct.dataclass
class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
"""
Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
Args:
last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
hidden_size)` is output.
past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
input) to speed up sequential decoding.
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
"""
    # Output fields, each defaulting to None
last_hidden_state: jnp.ndarray = None
past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
cross_attentions: Optional[Tuple[jnp.ndarray]] = None
# Flax dataclass for sequence-to-sequence model outputs, inheriting from ModelOutput
@flax.struct.dataclass
class FlaxSeq2SeqModelOutput(ModelOutput):
"""
Base class for model encoder's outputs that also contains pre-computed hidden states that can speed up sequential decoding.
"""
    # Last hidden state of the decoder
    last_hidden_state: jnp.ndarray = None
    # Pre-computed key/value states that can speed up sequential decoding
    past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
    # Decoder hidden states, one per layer (plus embeddings)
    decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # Decoder self-attention weights, one per layer
    decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
    # Cross-attention weights, one per layer
    cross_attentions: Optional[Tuple[jnp.ndarray]] = None
    # Last hidden state of the encoder
    encoder_last_hidden_state: Optional[jnp.ndarray] = None
    # Encoder hidden states, one per layer (plus embeddings)
    encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # Encoder self-attention weights, one per layer
    encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
# Flax dataclass for causal language model outputs with cross-attentions, inheriting from ModelOutput
@flax.struct.dataclass
class FlaxCausalLMOutputWithCrossAttentions(ModelOutput):
"""
Base class for causal language model (or autoregressive) outputs.
"""
    # Prediction logits of shape (batch_size, sequence_length, config.vocab_size)
    logits: jnp.ndarray = None
    # Pre-computed key/value states (one tuple per layer), only relevant when caching is used
    past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
    # Hidden states: embedding output plus one entry per layer, each of shape (batch_size, sequence_length, hidden_size)
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # Self-attention weights per layer, shape (batch_size, num_heads, sequence_length, sequence_length)
    attentions: Optional[Tuple[jnp.ndarray]] = None
    # Cross-attention weights per layer, shape (batch_size, num_heads, sequence_length, sequence_length)
    cross_attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxMaskedLMOutput(ModelOutput):
"""
    Base class for masked language model outputs.
    Args:
        logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
"""
logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
FlaxCausalLMOutput = FlaxMaskedLMOutput
@flax.struct.dataclass
class FlaxSeq2SeqLMOutput(ModelOutput):
"""
    Base class for sequence-to-sequence language model outputs.
"""
logits: jnp.ndarray = None
past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
cross_attentions: Optional[Tuple[jnp.ndarray]] = None
encoder_last_hidden_state: Optional[jnp.ndarray] = None
encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxNextSentencePredictorOutput(ModelOutput):
"""
    Base class for outputs of models predicting if two sentences are consecutive or not.
    Args:
logits (`jnp.ndarray` of shape `(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
before SoftMax).
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxSequenceClassifierOutput(ModelOutput):
"""
Base class for outputs of sentence classification models.
Args:
logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxSeq2SeqSequenceClassifierOutput(ModelOutput):
"""
Base class for outputs of sequence-to-sequence sentence classification models.
"""
logits: jnp.ndarray = None
past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
cross_attentions: Optional[Tuple[jnp.ndarray]] = None
encoder_last_hidden_state: Optional[jnp.ndarray] = None
encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxMultipleChoiceModelOutput(ModelOutput):
    """
    Base class for outputs of multiple choice models.
Args:
logits (`jnp.ndarray` of shape `(batch_size, num_choices)`):
            Classification scores (before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
"""
logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxTokenClassifierOutput(ModelOutput):
"""
    Base class for outputs of token classification models.
    Args:
        logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
"""
logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxQuestionAnsweringModelOutput(ModelOutput):
"""
    Base class for outputs of question answering models.
    Args:
        start_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            Span-start scores (before SoftMax).
        end_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            Span-end scores (before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
"""
start_logits: jnp.ndarray = None
end_logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
"""
    Base class for outputs of sequence-to-sequence question answering models.
"""
start_logits: jnp.ndarray = None
end_logits: jnp.ndarray = None
past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
cross_attentions: Optional[Tuple[jnp.ndarray]] = None
encoder_last_hidden_state: Optional[jnp.ndarray] = None
encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
.\modeling_flax_pytorch_utils.py
""" PyTorch - Flax general utilities."""
import os
from pickle import UnpicklingError
from typing import Dict, Tuple
import jax
import jax.numpy as jnp
import numpy as np
from flax.serialization import from_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
import transformers
from . import is_safetensors_available, is_torch_available
from .utils import logging
if is_torch_available():
import torch
if is_safetensors_available():
from safetensors import safe_open
from safetensors.flax import load_file as safe_load_file
logger = logging.get_logger(__name__)
def load_pytorch_checkpoint_in_flax_state_dict(
flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
):
"""Load pytorch checkpoints in a flax model"""
if not is_sharded:
pt_path = os.path.abspath(pytorch_checkpoint_path)
logger.info(f"Loading PyTorch weights from {pt_path}")
if pt_path.endswith(".safetensors"):
pt_state_dict = {}
with safe_open(pt_path, framework="flax") as f:
for k in f.keys():
pt_state_dict[k] = f.get_tensor(k)
else:
try:
import torch
from .pytorch_utils import is_torch_greater_or_equal_than_1_13
except (ImportError, ModuleNotFoundError):
logger.error(
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
" instructions."
)
raise
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
pt_state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
else:
flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model)
return flax_state_dict
def rename_key_and_reshape_tensor(
pt_tuple_key: Tuple[str],
pt_tensor: np.ndarray,
random_flax_state_dict: Dict[str, jnp.ndarray],
model_prefix: str,
) -> (Tuple[str], np.ndarray):
"""Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool:
"""Checks if `key` of `(prefix,) + key` is in random_flax_state_dict"""
return len(set(random_flax_state_dict) & {key, (model_prefix,) + key}) > 0
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
if pt_tuple_key[-1] in ["weight", "gamma"] and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
return renamed_pt_tuple_key, pt_tensor
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("mean",)
if pt_tuple_key[-1] == "running_mean" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
return renamed_pt_tuple_key, pt_tensor
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("var",)
if pt_tuple_key[-1] == "running_var" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
return renamed_pt_tuple_key, pt_tensor
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
if pt_tuple_key[-1] == "weight" and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
return renamed_pt_tuple_key, pt_tensor
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and not is_key_or_prefix_key_in_dict(pt_tuple_key):
pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
return renamed_pt_tuple_key, pt_tensor
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
if pt_tuple_key[-1] == "weight" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
pt_tensor = pt_tensor.T
return renamed_pt_tuple_key, pt_tensor
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
if pt_tuple_key[-1] == "gamma":
return renamed_pt_tuple_key, pt_tensor
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
if pt_tuple_key[-1] == "beta":
return renamed_pt_tuple_key, pt_tensor
name = None
if pt_tuple_key[-3::2] == ("parametrizations", "original0"):
name = pt_tuple_key[-2] + "_g"
elif pt_tuple_key[-3::2] == ("parametrizations", "original1"):
name = pt_tuple_key[-2] + "_v"
if name is not None:
renamed_pt_tuple_key = pt_tuple_key[:-3] + (name,)
return renamed_pt_tuple_key, pt_tensor
return pt_tuple_key, pt_tensor
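# Editor's illustrative sketch (hypothetical keys): a 4D PyTorch conv weight in OIHW
# layout is renamed to `kernel` and transposed to Flax's HWIO layout.
example_pt_key = ("conv", "weight")
example_pt_tensor = np.zeros((8, 3, 5, 5))  # (out_ch, in_ch, kh, kw)
example_flax_key, example_flax_tensor = rename_key_and_reshape_tensor(
    example_pt_key, example_pt_tensor, {("conv", "kernel"): None}, "model"
)
# example_flax_key == ("conv", "kernel"); example_flax_tensor.shape == (5, 5, 3, 8)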
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
    # Detect whether the state dict holds live torch tensors (.bin) or already-loaded arrays (.safetensors)
    from_bin = is_torch_available() and isinstance(next(iter(pt_state_dict.values())), torch.Tensor)
    bfloat16 = torch.bfloat16 if from_bin else "bfloat16"
weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
if from_bin:
for k, v in pt_state_dict.items():
if v.dtype == bfloat16:
v = v.float()
pt_state_dict[k] = v.numpy()
model_prefix = flax_model.base_model_prefix
if "params" in flax_model.params:
flax_model_params = flax_model.params["params"]
else:
flax_model_params = flax_model.params
random_flax_state_dict = flatten_dict(flax_model_params)
if "batch_stats" in flax_model.params:
flax_batch_stats = flatten_dict(flax_model.params["batch_stats"])
random_flax_state_dict.update(flax_batch_stats)
flax_state_dict = {}
load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and (
model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
)
load_base_model_into_model_with_head = (model_prefix in flax_model_params) and (
model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
)
for pt_key, pt_tensor in pt_state_dict.items():
pt_tuple_key = tuple(pt_key.split("."))
is_bfloat_16 = weight_dtypes[pt_key] == bfloat16
has_base_model_prefix = pt_tuple_key[0] == model_prefix
if load_model_with_head_into_base_model and has_base_model_prefix:
pt_tuple_key = pt_tuple_key[1:]
flax_key, flax_tensor = rename_key_and_reshape_tensor(
pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
)
require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
if load_base_model_into_model_with_head and require_base_model_prefix:
flax_key = (model_prefix,) + flax_key
if flax_key in random_flax_state_dict:
if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
raise ValueError(
f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
)
if "batch_stats" in flax_model.params:
if "mean" in flax_key[-1] or "var" in flax_key[-1]:
flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
continue
if "num_batches_tracked" in flax_key[-1]:
flax_state_dict.pop(flax_key, None)
continue
flax_state_dict[("params",) + flax_key] = (
jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
)
else:
flax_state_dict[flax_key] = (
jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
)
return unflatten_dict(flax_state_dict)
def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
import torch
from .pytorch_utils import is_torch_greater_or_equal_than_1_13
flax_state_dict = {}
return unflatten_dict(flax_state_dict)
def load_flax_checkpoint_in_pytorch_model(model, flax_checkpoint_path):
"""Load flax checkpoints in a PyTorch model"""
flax_checkpoint_path = os.path.abspath(flax_checkpoint_path)
logger.info(f"Loading Flax weights from {flax_checkpoint_path}")
flax_cls = getattr(transformers, "Flax" + model.__class__.__name__)
if flax_checkpoint_path.endswith(".safetensors"):
flax_state_dict = safe_load_file(flax_checkpoint_path)
flax_state_dict = unflatten_dict(flax_state_dict, sep=".")
else:
with open(flax_checkpoint_path, "rb") as state_f:
try:
flax_state_dict = from_bytes(flax_cls, state_f.read())
except UnpicklingError:
raise EnvironmentError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. ")
return load_flax_weights_in_pytorch_model(model, flax_state_dict)
def load_flax_weights_in_pytorch_model(pt_model, flax_state):
"""Load flax checkpoints in a PyTorch model"""
try:
import torch
except (ImportError, ModuleNotFoundError):
logger.error(
"Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
" instructions."
)
raise
is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
if any(is_type_bf16):
logger.warning(
"Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
"before loading those in PyTorch model."
)
flax_state = jax.tree_util.tree_map(
lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
)
flax_state_dict = flatten_dict(flax_state)
pt_model_dict = pt_model.state_dict()
load_model_with_head_into_base_model = (pt_model.base_model_prefix in flax_state) and (
pt_model.base_model_prefix not in {k.split(".")[0] for k in pt_model_dict.keys()}
)
load_base_model_into_model_with_head = (pt_model.base_model_prefix not in flax_state) and (
pt_model.base_model_prefix in {k.split(".")[0] for k in pt_model_dict.keys()}
)
unexpected_keys = []
missing_keys = set(pt_model_dict.keys())
pt_model.load_state_dict(pt_model_dict)
missing_keys = list(missing_keys)
if len(unexpected_keys) > 0:
logger.warning(
"Some weights of the Flax model were not used when initializing the PyTorch model"
f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
" (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
" to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
" FlaxBertForSequenceClassification model)."
)
else:
logger.warning(f"All Flax model weights were used when initializing {pt_model.__class__.__name__}.\n")
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
" use it for predictions and inference."
)
else:
logger.warning(
f"All the weights of {pt_model.__class__.__name__} were initialized from the Flax model.\n"
"If your task is similar to the task the model of the checkpoint was trained on, "
f"you can already use {pt_model.__class__.__name__} for predictions without further training."
)
return pt_model
.\modeling_flax_utils.py
import gc
import json
import os
import re
import warnings
from functools import partial
from pickle import UnpicklingError
from typing import Any, Dict, Optional, Set, Tuple, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
import msgpack.exceptions
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import FlaxGenerationMixin, GenerationConfig
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from .utils import (
FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME, WEIGHTS_NAME,
PushToHubMixin,
add_code_sample_docstrings,
add_start_docstrings_to_model_forward,
cached_file,
copy_func,
download_url,
has_file,
is_offline_mode,
is_remote_url,
logging,
replace_return_docstrings,
)
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from .utils.import_utils import is_safetensors_available
if is_safetensors_available():
from safetensors import safe_open
from safetensors.flax import load_file as safe_load_file
from safetensors.flax import save_file as safe_save_file
logger = logging.get_logger(__name__)
def quick_gelu(x):
"""
    Quick GELU activation, x * sigmoid(1.702 * x), implemented with JAX.
"""
return x * jax.nn.sigmoid(1.702 * x)
ACT2FN = {
"gelu": partial(nn.gelu, approximate=False),
"relu": nn.relu,
"silu": nn.swish,
"swish": nn.swish,
"gelu_new": partial(nn.gelu, approximate=True),
"quick_gelu": quick_gelu,
"gelu_pytorch_tanh": partial(nn.gelu, approximate=True),
}
def dtype_byte_size(dtype):
"""
    Returns the number of bytes occupied by one parameter of type `dtype`. Example:
```
>>> dtype_byte_size(np.float32)
4
```
"""
if dtype == bool:
return 1 / 8
bit_search = re.search(r"[^\d](\d+)$", dtype.name)
if bit_search is None:
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0])
return bit_size // 8
def flax_shard_checkpoint(params, max_shard_size="10GB"):
"""
    Splits the model parameters into smaller sub-checkpoints so that each file stays below a maximum size.
    Args:
        params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters.
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            The maximum size of each sub-checkpoint. If expressed as a string, it needs to be digits followed by a
            unit (like `"5MB"`).
    """
    # Convert `max_shard_size` to an integer number of bytes
    max_shard_size = convert_file_size_to_int(max_shard_size)
    # List that will hold the sharded state dicts
    sharded_state_dicts = []
    # Dict for the current block
    current_block = {}
    # Size of the current block
    current_block_size = 0
    # Total size
    total_size = 0
    # Flatten the parameters into key/value pairs
    weights = flatten_dict(params, sep="/")
    for item in weights:
        # Compute the size of this weight
        weight_size = weights[item].size * dtype_byte_size(weights[item].dtype)
        # If adding this weight would push the current block over the maximum shard size, start a new block
        if current_block_size + weight_size > max_shard_size:
            sharded_state_dicts.append(current_block)
            current_block = {}
            current_block_size = 0
        # Add the weight to the current block
        current_block[item] = weights[item]
        current_block_size += weight_size
        total_size += weight_size
    # Add the last block
    sharded_state_dicts.append(current_block)
    # If there is only one shard, return it directly
    if len(sharded_state_dicts) == 1:
        return {FLAX_WEIGHTS_NAME: sharded_state_dicts[0]}, None
    # Otherwise, build the weight map and shard file names
    weight_map = {}
    shards = {}
    for idx, shard in enumerate(sharded_state_dicts):
        shard_file = FLAX_WEIGHTS_NAME.replace(".msgpack", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.msgpack")
        shards[shard_file] = shard
        for weight_name in shard.keys():
            weight_map[weight_name] = shard_file
    # Add the metadata
    metadata = {"total_size": total_size}
    index = {"metadata": metadata, "weight_map": weight_map}
    return shards, index
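# Editor's illustrative sketch (toy parameters): a ~4MB tree under a 10MB limit fits in
# a single shard, so no index is produced.
example_params = {"layer": {"kernel": jnp.ones((1024, 1024), dtype=jnp.float32)}}
example_shards, example_index = flax_shard_checkpoint(example_params, max_shard_size="10MB")
# example_index is None; example_shards == {FLAX_WEIGHTS_NAME: {"layer/kernel": ...}}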
# FlaxPreTrainedModel: base class for all models, inheriting from PushToHubMixin and FlaxGenerationMixin
class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
r"""
Base class for all models.
[`FlaxPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
downloading and saving models.
Class attributes (overridden by derived classes):
- **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
for this model architecture.
- **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
classes of the same architecture adding modules on top of the base model.
- **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
models, `pixel_values` for vision models and `input_values` for speech models).
"""
    # Configuration class for the model, None by default
    config_class = None
    # Prefix of the base model, empty string by default
    base_model_prefix = ""
    # Name of the main model input, "input_ids" by default
    main_input_name = "input_ids"
    # Auto class
    _auto_class = None
    # Set of missing keys
    _missing_keys = set()
    # Model initializer
def __init__(
self,
config: PretrainedConfig,
module: nn.Module,
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
):
        # Raise a ValueError if config is None
        if config is None:
            raise ValueError("config cannot be None")
        # Raise a ValueError if module is None
        if module is None:
            raise ValueError("module cannot be None")
        # These attributes are private because they are exposed as typed properties on derived classes.
        # Store the configuration object
        self._config = config
        # Store the module object
        self._module = module
        # The following attributes are public and common to every derived class.
        # Initialize the random number generator key
        self.key = PRNGKey(seed)
        # Data type, jnp.float32 by default
        self.dtype = dtype
        # Input shape, (1, 1) by default
        self.input_shape = input_shape
        # Build a generation config from the model config, when the model can generate
        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
        # Flag indicating whether the model has been initialized
        self._is_initialized = _do_init
        # When _do_init is True, initialize the parameters randomly
        if _do_init:
            # Randomly initialize the model parameters
            random_params = self.init_weights(self.key, input_shape)
            # Compute the shape tree of the parameters
            params_shape_tree = jax.eval_shape(lambda params: params, random_params)
        else:
            # When _do_init is False, only compute the parameter shapes
            init_fn = partial(self.init_weights, input_shape=input_shape)
            params_shape_tree = jax.eval_shape(init_fn, self.key)
            # Log a reminder that the weights are not initialized
            logger.info(
                "Model weights are not initialized as `_do_init` is set to `False`. "
                f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
            )
        # Store the parameter shape tree
        self._params_shape_tree = params_shape_tree
        # Save the required parameter keys as a set
        self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
        # When _do_init is True, set the model parameters
        if _do_init:
            self.params = random_params
    # Abstract method that initializes the model's weights. Subclasses must implement it.
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
        raise NotImplementedError(f"init method has to be implemented for {self}")
    # Abstract method that enables gradient checkpointing. Subclasses must implement it.
    def enable_gradient_checkpointing(self):
        raise NotImplementedError(f"gradient checkpointing method has to be implemented for {self}")
    # Class method that instantiates the class from a given config and extra kwargs.
@classmethod
def _from_config(cls, config, **kwargs):
"""
All context managers that the model should be initialized under go here.
"""
return cls(config, **kwargs)
    # Returns a string identifying this as a Flax model.
@property
def framework(self) -> str:
"""
:str: Identifies that this is a Flax model.
"""
return "flax"
    # Returns the model's configuration.
@property
def config(self) -> PretrainedConfig:
return self._config
    # Returns the model's inner Flax module.
@property
def module(self) -> nn.Module:
return self._module
    # Returns the model parameters, either a plain dict or a FrozenDict.
@property
def params(self) -> Union[Dict, FrozenDict]:
if not self._is_initialized:
raise ValueError(
"`params` cannot be accessed from model when the model is created with `_do_init=False`. "
"You must call `init_weights` manually and store the params outside of the model and "
"pass it explicitly where needed."
)
return self._params
    # Returns the set of parameter keys the model requires.
@property
def required_params(self) -> Set:
return self._required_params
# Returns the shape tree of the model parameters.
@property
def params_shape_tree(self) -> Dict:
return self._params_shape_tree
# Setter for the model parameters; raises if the model was not initialized.
@params.setter
def params(self, params: Union[Dict, FrozenDict]):
# Refuse to set parameters if the model was not initialized.
if not self._is_initialized:
raise ValueError(
"`params` cannot be set from model when the model is created with `_do_init=False`. "
"Store the params outside of the model."
)
# If the parameters are a FrozenDict, unfreeze them into a plain dict.
if isinstance(params, FrozenDict):
params = unfreeze(params)
# Check that the parameters contain all required keys.
param_keys = set(flatten_dict(params).keys())
if len(self.required_params - param_keys) > 0:
raise ValueError(
"Some parameters are missing. Make sure that `params` include the following "
f"parameters {self.required_params - param_keys}"
)
# Set the model parameters.
self._params = params
def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
"""
Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
"""
# Borrowed from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
# Conditional cast: convert a leaf to the requested dtype only if it is a floating-point array
def conditional_cast(param):
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
param = param.astype(dtype)
return param
# If no mask is given, apply the cast to every leaf of params via tree_map
if mask is None:
return jax.tree_util.tree_map(conditional_cast, params)
# Flatten params into a flat dict
flat_params = flatten_dict(params)
# Flatten the mask as well, keeping only its leaves
flat_mask, _ = jax.tree_util.tree_flatten(mask)
# Walk the mask leaves alongside the sorted parameter keys and cast only the masked entries
for masked, key in zip(flat_mask, sorted(flat_params.keys())):
if masked:
flat_params[key] = conditional_cast(flat_params[key])
# Return the unflattened version of the cast params
return unflatten_dict(flat_params)
def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
the `params` in place.
This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
Arguments:
params (`Union[Dict, FrozenDict]`):
A `PyTree` of model parameters.
mask (`Union[Dict, FrozenDict]`):
A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
you want to cast, and should be `False` for those you want to skip.
Examples:
```
>>> from transformers import FlaxBertModel
>>> # load model
>>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
>>> # by default, the model parameters will be in fp32; to cast them to bfloat16
>>> model.params = model.to_bf16(model.params)
>>> # if you don't want to cast certain parameters (for example layer norm bias and scale),
>>> # then pass a mask as follows
>>> from flax import traverse_util
>>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
>>> flat_params = traverse_util.flatten_dict(model.params)
>>> mask = {
... path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
... for path in flat_params
... }
>>> mask = traverse_util.unflatten_dict(mask)
>>> model.params = model.to_bf16(model.params, mask)
```
"""
return self._cast_floating_to(params, jnp.bfloat16, mask)
def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.
Arguments:
params (`Union[Dict, FrozenDict]`):
A `PyTree` of model parameters.
mask (`Union[Dict, FrozenDict]`):
A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
you want to cast, and should be `False` for those you want to skip
Examples:
```
>>> from transformers import FlaxBertModel
>>> # download model and configuration from huggingface.co
>>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
>>> # by default, the model params will be in fp32; to illustrate the use of this method,
>>> # we'll first cast to fp16 and back to fp32
>>> model.params = model.to_fp16(model.params)
>>> # now cast back to fp32
>>> model.params = model.to_fp32(model.params)
```
"""
return self._cast_floating_to(params, jnp.float32, mask)
def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
return self._cast_floating_to(params, jnp.float16, mask)
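# `to_fp16` mirrors `to_bf16`/`to_fp32` but targets `jax.numpy.float16`. A minimal usage
# sketch, not part of the source (mask semantics as in the `to_bf16` example above):
from flax import traverse_util
from transformers import FlaxBertModel
model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
flat_params = traverse_util.flatten_dict(model.params)
mask = traverse_util.unflatten_dict(
{path: path[-2:] != ("LayerNorm", "scale") for path in flat_params}  # False leaves stay in fp32
)
model.params = model.to_fp16(model.params, mask)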
@classmethod
def load_flax_weights(cls, resolved_archive_file):
try:
if resolved_archive_file.endswith(".safetensors"):
state = safe_load_file(resolved_archive_file)
state = unflatten_dict(state, sep=".")
else:
with open(resolved_archive_file, "rb") as state_f:
state = from_bytes(cls, state_f.read())
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
try:
with open(resolved_archive_file) as f:
if f.read().startswith("version"):
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please"
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
" folder you cloned."
)
else:
raise ValueError from e
except (UnicodeDecodeError, ValueError):
raise EnvironmentError(f"Unable to convert {resolved_archive_file} to Flax deserializable object. ")
return state
@classmethod
def load_flax_sharded_weights(cls, shard_files):
"""
This is the same as [`flax.serialization.from_bytes`](https://flax.readthedocs.io/en/latest/_modules/flax/serialization.html#from_bytes) but for a sharded checkpoint.
This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
loaded in the model.
Args:
shard_files (`List[str]`):
The list of shard files to load.
Returns:
`Dict`: A nested dictionary of the model parameters, in the expected format for flax models : `{'model':
{'params': {'...'}}}`.
"""
state_sharded_dict = {}
for shard_file in shard_files:
try:
with open(shard_file, "rb") as state_f:
state = from_bytes(cls, state_f.read())
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
with open(shard_file) as f:
if f.read().startswith("version"):
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please"
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
" folder you cloned."
)
else:
raise ValueError from e
except (UnicodeDecodeError, ValueError):
raise EnvironmentError(f"Unable to convert {shard_file} to Flax deserializable object. ")
state = flatten_dict(state, sep="/")
state_sharded_dict.update(state)
del state
gc.collect()
return unflatten_dict(state_sharded_dict, sep="/")
@classmethod
def can_generate(cls) -> bool:
"""
Returns whether this model can generate sequences with `.generate()`.
Returns:
`bool`: Whether this model can generate sequences with `.generate()`.
"""
if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
return False
return True
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
dtype: jnp.dtype = jnp.float32,
*model_args,
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
ignore_mismatched_sizes: bool = False,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs,
):
"""
从预训练模型加载模型参数和配置。
<Tip warning={true}>
当前 API 处于实验阶段,未来版本可能会有一些轻微的更改。
</Tip>
Args:
pretrained_model_name_or_path (`str` or `os.PathLike`):
预训练模型的名称或路径。
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
指定加载参数时使用的数据类型,默认为 jnp.float32。
*model_args:
其余位置参数,传递给具体模型加载函数。
config (`PretrainedConfig`, `str`, `os.PathLike`, *optional*, defaults to `None`):
预训练模型的配置对象或其路径,可选参数。
cache_dir (`str` or `os.PathLike`, *optional*, defaults to `None`):
缓存目录的路径,可选参数。
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
是否忽略加载参数时大小不匹配的情况,默认为 False。
force_download (`bool`, *optional*, defaults to `False`):
是否强制重新下载模型,默认为 False。
local_files_only (`bool`, *optional*, defaults to `False`):
是否仅使用本地文件加载模型,默认为 False。
token (`str` or `bool`, *optional*, defaults to `None`):
token 用于验证下载的模型,可选参数。
revision (`str`, *optional*, defaults to `"main"`):
模型的版本号,默认为 "main"。
**kwargs:
其余关键字参数,传递给具体模型加载函数。
"""
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
params=None,
push_to_hub=False,
max_shard_size="10GB",
token: Optional[Union[str, bool]] = None,
safe_serialization: bool = False,
**kwargs,
):
"""
将当前模型保存到指定目录。
Args:
save_directory (`str` or `os.PathLike`):
保存模型的目录路径。
params:
要保存的模型参数,默认为 None。
push_to_hub (`bool`, *optional*, defaults to `False`):
是否将模型推送到模型 Hub,默认为 False。
max_shard_size (`str`, *optional*, defaults to `"10GB"`):
最大的分片大小,默认为 "10GB"。
token (`str` or `bool`, *optional*, defaults to `None`):
token 用于验证保存操作,可选参数。
safe_serialization (`bool`, *optional*, defaults to `False`):
是否进行安全序列化,默认为 False。
**kwargs:
其余关键字参数,传递给具体保存函数。
"""
@classmethod
def register_for_auto_class(cls, auto_class="FlaxAutoModel"):
"""
注册当前模型类到指定的自动加载类。仅用于自定义模型,因为库中的模型已经与自动加载类映射。
<Tip warning={true}>
当前 API 处于实验阶段,未来版本可能会有一些轻微的更改。
</Tip>
Args:
auto_class (`str` or `type`, *optional*, defaults to `"FlaxAutoModel"`):
要注册新模型的自动加载类名称或类型。
"""
if not isinstance(auto_class, str):
auto_class = auto_class.__name__
import transformers.models.auto as auto_module
if not hasattr(auto_module, auto_class):
raise ValueError(f"{auto_class} is not a valid auto class.")
cls._auto_class = auto_class
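# A minimal usage sketch, not part of the source; `MyFlaxModel` is a hypothetical custom subclass:
class MyFlaxModel(FlaxPreTrainedModel): ...
MyFlaxModel.register_for_auto_class("FlaxAutoModel")  # now discoverable through FlaxAutoModel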
FlaxPreTrainedModel.push_to_hub = copy_func(FlaxPreTrainedModel.push_to_hub)
if FlaxPreTrainedModel.push_to_hub.__doc__ is not None:
FlaxPreTrainedModel.push_to_hub.__doc__ = FlaxPreTrainedModel.push_to_hub.__doc__.format(
object="model", object_class="FlaxAutoModel", object_files="model checkpoint"
)
def overwrite_call_docstring(model_class, docstring):
model_class.__call__ = copy_func(model_class.__call__)
model_class.__call__.__doc__ = None
model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)
def append_call_sample_docstring(
model_class, checkpoint, output_type, config_class, mask=None, revision=None, real_checkpoint=None
):
model_class.__call__ = copy_func(model_class.__call__)
model_class.__call__ = add_code_sample_docstrings(
checkpoint=checkpoint,
output_type=output_type,
config_class=config_class,
model_cls=model_class.__name__,
revision=revision,
real_checkpoint=real_checkpoint,
)(model_class.__call__)
def append_replace_return_docstrings(model_class, output_type, config_class):
model_class.__call__ = copy_func(model_class.__call__)
model_class.__call__ = replace_return_docstrings(
output_type=output_type,
config_class=config_class,
)(model_class.__call__)
.\modeling_outputs.py
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from .utils import ModelOutput
@dataclass
class BaseModelOutput(ModelOutput):
"""
模型输出的基类,可能包含隐藏状态和注意力。
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
模型最后一层的隐藏状态序列。
hidden_states (`tuple(torch.FloatTensor)`, *optional*, 当 `output_hidden_states=True` 时返回或当 `config.output_hidden_states=True` 时返回):
包含每层输出的元组 `torch.FloatTensor`(如果模型有嵌入层,则包含嵌入层输出),
形状为 `(batch_size, sequence_length, hidden_size)`。
模型在每层的隐藏状态以及可选的初始嵌入输出。
attentions (`tuple(torch.FloatTensor)`, *optional*, 当 `output_attentions=True` 时返回或当 `config.output_attentions=True` 时返回):
包含每层注意力权重的元组 `torch.FloatTensor`,
形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
注意力权重经过 softmax 后的结果,在自注意力头中用于计算加权平均值。
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
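# A minimal sketch, not part of the source: ModelOutput subclasses support both attribute
# and tuple-style access, and `to_tuple()` drops fields that are None:
import torch
from transformers.modeling_outputs import BaseModelOutput
out = BaseModelOutput(last_hidden_state=torch.zeros(1, 4, 8))
assert out[0] is out.last_hidden_state  # index access maps onto the non-None fields
assert len(out.to_tuple()) == 1  # fields left as None are dropped from the tuple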
@dataclass
class BaseModelOutputWithNoAttention(ModelOutput):
"""
模型输出的基类,仅包含潜在的隐藏状态。
继承自 ModelOutput。
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class BaseModelOutputWithPooling(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token) after further processing
through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
the classification token after processing through a linear layer and a tanh activation function. The linear
layer weights are trained from the next sentence prediction (classification) objective during pretraining.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class BaseModelOutputWithPoolingAndNoAttention(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
"""
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class BaseModelOutputWithPast(ModelOutput):
"""
Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
hidden_size)` is output.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class BaseModelOutputWithCrossAttentions(ModelOutput):
"""
Base class for model's outputs, with potential hidden states and attentions.
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
"""
last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
"""
Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
"""
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class MoECausalLMOutputWithPast(ModelOutput):
"""
Base class for causal language model (or autoregressive) outputs as well as Mixture of Expert's router hidden
states terms, to train a MoE model.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
z_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
z_loss for the sparse modules.
aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
aux_loss for the sparse modules.
router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse
modules.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
z_loss: torch.FloatTensor = None
aux_loss: torch.FloatTensor = None
router_logits: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MoEModelOutput(ModelOutput):
"""
Base class for model's outputs, with potential hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary
loss and the z_loss for Mixture of Experts models.
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
router_probs: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MoeModelOutputWithPast(ModelOutput):
"""
Base class for model's outputs, with potential hidden states and attentions.
"""
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
router_logits: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MoeCausalLMOutputWithPast(ModelOutput):
"""
Base class for causal language model (or autoregressive) with mixture of experts outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
Auxiliary loss for the sparse modules.
router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
Raw router logits computed by the MoE routers, used to compute the auxiliary loss for Mixture of Experts models.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, each tuple containing 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`.
Pre-computed hidden states (keys and values in self-attention blocks) for speeding up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for embedding layer output if present, plus one for each layer) of shape
`(batch_size, sequence_length, hidden_size)`.
Hidden states of the model at each layer's output, including optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
Attention weights after the softmax operation, used for computing weighted averages in self-attention heads.
"""
loss: Optional[torch.FloatTensor] = None
aux_loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
router_logits: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MoEModelOutputWithPastAndCrossAttentions(ModelOutput):
"""
Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding) as well as
Mixture of Expert's router hidden states terms, to train a MoE model.
"""
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
router_probs: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class Seq2SeqModelOutput(ModelOutput):
"""
Base class for model encoder's outputs that also contains: pre-computed hidden states that can speed up sequential
decoding.
"""
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class Seq2SeqMoEModelOutput(ModelOutput):
"""
Base class for model encoder's outputs that also contains: pre-computed hidden states that can speed up sequential
decoding.
"""
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
encoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class CausalLMOutput(ModelOutput):
"""
Base class for causal language model (or autoregressive) outputs.
"""
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
语言建模损失(用于下一个标记预测),是一个形状为 `(1,)` 的 `torch.FloatTensor`,当提供 `labels` 时返回。
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
语言建模头部的预测分数(在 SoftMax 之前每个词汇标记的得分),形状为 `(batch_size, sequence_length, config.vocab_size)` 的 `torch.FloatTensor`。
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
模型在每层输出的隐藏状态,以及可选的初始嵌入输出,形状为 `(batch_size, sequence_length, hidden_size)` 的 `torch.FloatTensor` 元组。
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
注意力 softmax 后的注意力权重,用于计算自注意力头部中的加权平均,形状为 `(batch_size, num_heads, sequence_length, sequence_length)` 的 `torch.FloatTensor` 元组。
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Data class for the outputs of a causal language model (or autoregressive model); inherits from ModelOutput.
@dataclass
class CausalLMOutputWithPast(ModelOutput):
"""
Base class for causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
"""
# Field definitions
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
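# A minimal sketch, not part of the source, of how `past_key_values` speeds up decoding
# (the GPT-2 checkpoint name is illustrative):
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")
out = model(**inputs, use_cache=True)  # returns a CausalLMOutputWithPast
next_token = out.logits[:, -1:].argmax(-1)
# later steps feed only the new token together with the cached key/value states
out = model(input_ids=next_token, past_key_values=out.past_key_values, use_cache=True)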
# Data class for the outputs of a causal language model (or autoregressive model) with cross attentions; inherits from ModelOutput.
@dataclass
class CausalLMOutputWithCrossAttentions(ModelOutput):
"""
Base class for causal language model (or autoregressive) outputs, with cross attentions.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Cross attentions weights after the attention softmax, used to compute the weighted average in the
cross-attention heads.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `torch.FloatTensor` tuples of length `config.n_layers`, with each tuple containing the cached key,
value states of the self-attention and the cross-attention layers if model is used in encoder-decoder
setting. Only relevant if `config.is_decoder = True`.
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class SequenceClassifierOutputWithPast(ModelOutput):
"""
Base class for outputs of sequence classification models that also include past key values,
hidden states, and attentions.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class MaskedLMOutput(ModelOutput):
"""
Base class for outputs of masked language models.
This class inherits `ModelOutput`, indicating it provides standard output for models.
"""
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Masked language modeling (MLM) loss.
掩码语言建模(MLM)损失,当提供`labels`时返回此值。
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
语言建模头的预测分数(SoftMax之前的每个词汇标记的分数)。
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer,
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
模型在每一层输出的隐藏状态,以及可选的初始嵌入输出。
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
注意力权重,经过注意力SoftMax后的值,用于计算自注意力头中的加权平均值。
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class Seq2SeqLMOutput(ModelOutput):
"""
Base class for sequence-to-sequence language models outputs.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class Seq2SeqMoEOutput(ModelOutput):
"""
Base class for sequence-to-sequence language models outputs.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
encoder_z_loss: torch.FloatTensor = None
decoder_z_loss: torch.FloatTensor = None
encoder_aux_loss: torch.FloatTensor = None
decoder_aux_loss: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
encoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class NextSentencePredictorOutput(ModelOutput):
"""
Base class for outputs of models predicting if two sentences are consecutive or not.
# 定义 loss 变量,用于存储下一个序列预测(分类)的损失值,类型为 torch.FloatTensor,可选参数,当提供 `next_sentence_label` 时返回。
loss: Optional[torch.FloatTensor] = None
# 定义 logits 变量,用于存储下一个序列预测(分类)头部的预测分数,形状为 `(batch_size, 2)` 的 torch.FloatTensor。
logits: torch.FloatTensor = None
# 定义 hidden_states 变量,用于存储模型每一层的隐藏状态输出,类型为元组 `Tuple[torch.FloatTensor, ...]`,可选参数,当 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回。
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
# 定义 attentions 变量,用于存储注意力权重输出,类型为元组 `Tuple[torch.FloatTensor, ...]`,可选参数,当 `output_attentions=True` 或 `config.output_attentions=True` 时返回。
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Data class for the outputs of sequence classification models
@dataclass
class SequenceClassifierOutput(ModelOutput):
"""
Base class for outputs of sentence classification models.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Data class for the outputs of sequence-to-sequence sentence classification models
@dataclass
class Seq2SeqSequenceClassifierOutput(ModelOutput):
"""
Base class for outputs of sequence-to-sequence sentence classification models.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Data class for the outputs of multiple choice models
@dataclass
class MultipleChoiceModelOutput(ModelOutput):
"""
Base class for outputs of multiple choice models.
"""
"""
Args:
loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided):
分类损失值。
如果提供了`labels`,则返回此损失值。
logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
`num_choices` 是输入张量的第二个维度。
分类分数(SoftMax 之前的值)。
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
一个元组,包含 `torch.FloatTensor` 的张量。
第一个张量是模型嵌入层的输出(如果存在),每一层输出的张量的形状为 `(batch_size, sequence_length, hidden_size)`。
模型在每一层输出的隐藏状态,加上可选的初始嵌入输出。
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
一个元组,包含 `torch.FloatTensor` 的张量。
每一层的注意力权重张量的形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
经过注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Base data class describing the outputs of token classification models
@dataclass
class TokenClassifierOutput(ModelOutput):
"""
Base class for outputs of token classification models.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification loss.
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class QuestionAnsweringModelOutput(ModelOutput):
"""
Base class for outputs of question answering models.
"""
# Optional loss tensor, returned when start/end position labels are provided
loss: Optional[torch.FloatTensor] = None
# Span-start scores of shape `(batch_size, sequence_length)` (before SoftMax)
start_logits: torch.FloatTensor = None
# Span-end scores of shape `(batch_size, sequence_length)` (before SoftMax)
end_logits: torch.FloatTensor = None
# Optional tuple of hidden states, one per layer (plus the initial embedding output, if any),
# each of shape `(batch_size, sequence_length, hidden_size)`
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
# Optional tuple of attention weights, one per layer, each of shape
# `(batch_size, num_heads, sequence_length, sequence_length)`; the post-softmax attention
# weights used to compute the weighted average in the self-attention heads
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
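# A minimal sketch, not part of the source, of turning start/end logits into an answer span
# (the SQuAD checkpoint name is illustrative):
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased-distilled-squad")
model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-cased-distilled-squad")
inputs = tokenizer("Who wrote it?", "It was written by Ada.", return_tensors="pt")
out = model(**inputs)  # a QuestionAnsweringModelOutput
start = out.start_logits.argmax(-1).item()
end = out.end_logits.argmax(-1).item()
answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])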
# Data class storing the outputs of sequence-to-sequence question answering models
@dataclass
class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
"""
Base class for outputs of sequence-to-sequence question answering models.
"""
# Loss value, if present, as a torch.FloatTensor
loss: Optional[torch.FloatTensor] = None
# Span-start prediction logits, as a torch.FloatTensor
start_logits: torch.FloatTensor = None
# Span-end prediction logits, as a torch.FloatTensor
end_logits: torch.FloatTensor = None
# Past key/values, an optional nested tuple of torch.FloatTensor
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
# Decoder hidden states, an optional tuple of torch.FloatTensor
decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
# Decoder attention weights, an optional tuple of torch.FloatTensor
decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Cross-attention weights, an optional tuple of torch.FloatTensor
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Encoder's last hidden state, an optional torch.FloatTensor
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
# Encoder hidden states, an optional tuple of torch.FloatTensor
encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
# Encoder attention weights, an optional tuple of torch.FloatTensor
encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Data class storing the outputs of semantic segmentation models
@dataclass
class SemanticSegmenterOutput(ModelOutput):
"""
Base class for outputs of semantic segmentation models.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
Classification scores for each pixel.
<Tip warning={true}>
The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the
original image size as post-processing. You should always check your logits shape and resize as needed.
</Tip>
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, patch_size, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
# Loss value, if present, as a torch.FloatTensor
loss: Optional[torch.FloatTensor] = None
# Classification logits, a torch.FloatTensor of shape (batch_size, config.num_labels, logits_height, logits_width)
logits: torch.FloatTensor = None
# Hidden states, an optional tuple of torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
# Attention weights, an optional tuple of torch.FloatTensor
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
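# A minimal sketch, not part of the source, of the resizing step the Tip above recommends
# (`out` is a SemanticSegmenterOutput; `pixel_values` is the model input of shape (batch, 3, H, W)):
import torch.nn.functional as F
upsampled = F.interpolate(out.logits, size=pixel_values.shape[-2:], mode="bilinear", align_corners=False)
segmentation = upsampled.argmax(dim=1)  # (batch, H, W), one class index per pixel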
# Data class storing the outputs of image classification models
@dataclass
class ImageClassifierOutput(ModelOutput):
"""
Base class for outputs of image classification models.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if `config.num_labels==1`) loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if `config.num_labels==1`) scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` holding the hidden states (also called feature maps) of the model at the output
of each stage, of shape `(batch_size, sequence_length, hidden_size)`. If the model has an embedding layer,
the first tensor is the output of the embeddings.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` holding the attention weights of shape `(batch_size, num_heads, patch_size,
sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
the self-attention heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class ImageClassifierOutputWithNoAttention(ModelOutput):
"""
Base class for outputs of image classification models.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if `config.num_labels==1`) loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if `config.num_labels==1`) scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Hidden states (also called feature maps) of the model at the output of each stage, of shape `(batch_size, num_channels, height, width)`.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class DepthEstimatorOutput(ModelOutput):
"""
Base class for outputs of depth estimation models.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Depth estimation loss.
predicted_depth (`torch.FloatTensor` of shape `(batch_size, height, width)`):
Predicted depth for each pixel.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Hidden states (also called feature maps) of the model at the output of each layer, of shape
`(batch_size, num_channels, height, width)`, plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Attention weights of each layer, of shape `(batch_size, num_heads, patch_size, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
"""
loss: Optional[torch.FloatTensor] = None
predicted_depth: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class ImageSuperResolutionOutput(ModelOutput):
"""
Base class for outputs of image super resolution models.
"""
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
重建损失,当提供`labels`时返回。
reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
重建的图像,可能是上采样后的结果。
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
一个元组,包含`torch.FloatTensor`类型的张量:
- 如果模型有嵌入层,则为形状为`(batch_size, sequence_length, hidden_size)`的张量;
- 每个阶段输出的隐藏状态(也称为特征图)。
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
一个元组,包含`torch.FloatTensor`类型的张量:
- 每层的注意力权重,形状为`(batch_size, num_heads, patch_size, sequence_length)`。
注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
loss: Optional[torch.FloatTensor] = None
reconstruction: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class Wav2Vec2BaseModelOutput(ModelOutput):
"""
Base class for models that have been trained with the Wav2Vec2 loss objective.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
Sequence of extracted feature vectors of the last convolutional layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: torch.FloatTensor = None
extract_features: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class XVectorOutput(ModelOutput):
"""
Output type of [`Wav2Vec2ForXVector`].
"""
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
分类损失。
如果提供了 `labels`,则返回分类损失。
logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
AMSoftmax 前的分类隐藏状态。
用于 AMSoftmax 前的分类隐藏状态。
embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
用于基于向量相似性检索的话语嵌入。
用于基于向量相似性检索的话语嵌入。
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
模型每层输出的隐藏状态。
当传递 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回。
元组包含了每层的 `torch.FloatTensor`,形状为 `(batch_size, sequence_length, hidden_size)`。
包括每层的隐藏状态以及初始嵌入输出。
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
自注意力权重。
当传递 `output_attentions=True` 或 `config.output_attentions=True` 时返回。
元组包含了每层的 `torch.FloatTensor`,形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
在注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
embeddings: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
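# A minimal sketch, not part of the source: comparing two utterances via their x-vector
# embeddings (`out_a`/`out_b` are XVectorOutput instances; 0.7 is an illustrative threshold):
import torch
cosine_sim = torch.nn.CosineSimilarity(dim=-1)
similarity = cosine_sim(out_a.embeddings, out_b.embeddings)
same_speaker = similarity >= 0.7  # tune the threshold per model and dataset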
# Define a data class named `BackboneOutput` that inherits from `ModelOutput`
@dataclass
class BackboneOutput(ModelOutput):
"""
Base class for outputs of backbones.
Args:
feature_maps (`tuple(torch.FloatTensor)` of shape `(batch_size, num_channels, height, width)`):
Feature maps of the stages.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, num_channels, height, width)`,
depending on the backbone.
Hidden-states of the model at the output of each stage plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Only applicable if the backbone uses attention.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
# Feature maps of the stages, as a tuple
feature_maps: Tuple[torch.FloatTensor] = None
# Hidden states of each stage, as an optional tuple
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
# Attention weights of each layer, as an optional tuple
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# 使用 dataclass 装饰器定义一个名为 `BaseModelOutputWithPoolingAndProjection` 的数据类,它继承自 `ModelOutput` 类
@dataclass
class BaseModelOutputWithPoolingAndProjection(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
"""
# 定义函数参数和它们的类型注释,描述了函数所接收的不同类型的输入数据
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
模型最后一层输出的隐藏状态序列。
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
经过附加预训练任务处理后的序列第一个标记(分类标记)的最后一层隐藏状态。
例如,在BERT系列模型中,这是经过线性层和tanh激活函数处理后的分类标记。
线性层的权重是从预训练过程中的下一句预测(分类)目标中训练得到的。
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
模型每一层的隐藏状态序列的元组。
每个元素的形状为 `(batch_size, sequence_length, hidden_size)`,包括可选的初始嵌入层输出。
当 `output_hidden_states=True` 传递给模型或者 `config.output_hidden_states=True` 时返回。
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
模型每一层的注意力权重的元组。
每个元素的形状为 `(batch_size, num_heads, sequence_length, sequence_length)`,
用于计算自注意力头中的加权平均值。
当 `output_attentions=True` 传递给模型或者 `config.output_attentions=True` 时返回。
projection_state (`tuple(torch.FloatTensor)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
投影层之前的文本嵌入的元组。
每个元素的形状为 `(batch_size, config.project_dim)`,
用于模拟教师编码器的最后隐藏状态。
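The pooling step described for `pooler_output` is a small, self-contained pattern. A sketch of it (the `SimplePooler` name is made up for illustration and is not the library's internal module):
import torch
import torch.nn as nn

class SimplePooler(nn.Module):  # illustrative, not the library's internal class
    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, last_hidden_state: torch.Tensor) -> torch.Tensor:
        first_token = last_hidden_state[:, 0]  # (batch_size, hidden_size): the classification token
        return self.activation(self.dense(first_token))

pooler = SimplePooler(hidden_size=768)
pooled = pooler(torch.randn(2, 10, 768))
print(pooled.shape)  # torch.Size([2, 768])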
@dataclass
class Seq2SeqSpectrogramOutput(ModelOutput):
"""
Base class for sequence-to-sequence spectrogram outputs.
"""
loss: Optional[torch.FloatTensor] = None  # spectrogram generation loss, when labels are provided
spectrogram: torch.FloatTensor = None  # the predicted spectrogram
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # pre-computed key/value states that can speed up sequential decoding
decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # decoder hidden states, one per layer
decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # decoder self-attention weights, one per layer
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # decoder cross-attention weights, one per layer
encoder_last_hidden_state: Optional[torch.FloatTensor] = None  # last hidden state of the encoder
encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # encoder hidden states, one per layer
encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # encoder self-attention weights, one per layer
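A model that returns this class from its forward pass is SpeechT5's TTS head; at inference time its `generate_speech` helper returns the same kind of spectrogram tensor that the `spectrogram` field holds during a `forward` call with `labels`. A sketch based on the SpeechT5 docs (the checkpoint name and the zero speaker embedding are illustrative assumptions):
import torch
from transformers import SpeechT5ForTextToSpeech, SpeechT5Processor

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")  # illustrative checkpoint
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")

inputs = processor(text="Hello world", return_tensors="pt")
speaker_embeddings = torch.zeros(1, 512)  # stand-in; a real x-vector embedding would go here

spectrogram = model.generate_speech(inputs["input_ids"], speaker_embeddings)
print(spectrogram.shape)  # (sequence_length, config.num_mel_bins)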
@dataclass
class Seq2SeqTSModelOutput(ModelOutput):
"""
Base class for time series model's encoder outputs that also contains pre-computed hidden states that can speed up
sequential decoding.
"""
last_hidden_state: torch.FloatTensor = None  # hidden states at the output of the last layer of the model's decoder
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # pre-computed key/value states that can speed up sequential decoding
decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # decoder hidden states, one per layer
decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # decoder self-attention weights, one per layer
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # decoder cross-attention weights, one per layer
encoder_last_hidden_state: Optional[torch.FloatTensor] = None  # last hidden state of the encoder
encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # encoder hidden states, one per layer
encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # encoder self-attention weights, one per layer
loc: Optional[torch.FloatTensor] = None  # shift values used to normalize the input time series
scale: Optional[torch.FloatTensor] = None  # scaling values used to normalize the input time series
static_features: Optional[torch.FloatTensor] = None  # static (time-independent) features of each time series
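`loc` and `scale` record how the input series was normalized before entering the model, so consumers can map model-space values back to the data's original units. A minimal sketch of that de-normalization (all shapes are illustrative):
import torch

batch_size, prediction_length, input_size = 4, 24, 1
normalized = torch.randn(batch_size, prediction_length, input_size)  # model-space values
loc = torch.full((batch_size, 1, input_size), 100.0)   # per-series shift
scale = torch.full((batch_size, 1, input_size), 12.5)  # per-series scale

original_units = normalized * scale + loc  # broadcast back to the data's units
print(original_units.shape)  # torch.Size([4, 24, 1])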
@dataclass
class Seq2SeqTSPredictionOutput(ModelOutput):
"""
Base class for time series model's decoder outputs that also contain the loss as well as the parameters of the
chosen distribution.
"""
loss: Optional[torch.FloatTensor] = None  # distributional loss of the model output
params: Optional[Tuple[torch.FloatTensor]] = None  # parameters of the chosen distribution
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # pre-computed key/value states that can speed up sequential decoding
decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # decoder hidden states, one per layer
decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # decoder self-attention weights, one per layer
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # decoder cross-attention weights, one per layer
encoder_last_hidden_state: Optional[torch.FloatTensor] = None  # last hidden state of the encoder
encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # encoder hidden states, one per layer
encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # encoder self-attention weights, one per layer
loc: Optional[torch.FloatTensor] = None  # shift values used to normalize the input time series
scale: Optional[torch.FloatTensor] = None  # scaling values used to normalize the input time series
static_features: Optional[torch.FloatTensor] = None  # static (time-independent) features of each time series
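`params` holds the parameters of the chosen output distribution, so a consumer can rebuild the distribution and sample from it. A sketch assuming a Student-t head (which distribution is configured depends on the model, so the three-parameter unpacking is an assumption):
import torch
from torch.distributions import StudentT

# Stand-ins for the tensors a model would place in out.params under a Student-t head.
df, loc, scale = torch.tensor(3.0), torch.tensor(0.0), torch.tensor(1.0)
distribution = StudentT(df, loc, scale)

samples = distribution.sample((100,))  # Monte Carlo draws
print(samples.mean().item(), samples.std().item())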
@dataclass
class SampleTSPredictionOutput(ModelOutput):
"""
Base class for time series model's predictions outputs that contains the sampled values from the chosen
distribution.
Args:
sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length)` or `(batch_size, num_samples, prediction_length, input_size)`):
Sampled values from the chosen distribution.
"""
# Sampled values drawn from the chosen distribution, one trajectory per sample
sequences: torch.FloatTensor = None
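Because `sequences` stacks `num_samples` Monte Carlo trajectories per series, point forecasts and uncertainty bands come from reducing over the sample dimension. A minimal sketch (shapes are illustrative):
import torch

batch_size, num_samples, prediction_length = 4, 100, 24
sequences = torch.randn(batch_size, num_samples, prediction_length)  # stand-in for out.sequences

median = sequences.quantile(0.5, dim=1)                          # point forecast
p10, p90 = sequences.quantile(torch.tensor([0.1, 0.9]), dim=1)   # uncertainty band
print(median.shape, p10.shape)  # torch.Size([4, 24]) twice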
# Use the @dataclass decorator to define `MaskedImageModelingOutput`, which wraps the outputs of masked image completion / in-painting models
@dataclass
class MaskedImageModelingOutput(ModelOutput):
"""
Base class for outputs of masked image completion / in-painting models.
掩码图像完成/修补模型输出结果的基类。
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
Reconstruction loss.
重建损失,当提供 `bool_masked_pos` 时返回。
reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Reconstructed / completed images.
重建/完成的图像。
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
(also called feature maps) of the model at the output of each stage.
隐藏状态,模型在每个阶段输出的隐藏状态(特征图)元组。
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when
`config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
the self-attention heads.
注意力权重,经过注意力 softmax 后的权重,用于计算自注意力头中的加权平均值。
"""
# Reconstruction loss, returned when `bool_masked_pos` is provided
loss: Optional[torch.FloatTensor] = None
# Reconstructed / completed images
reconstruction: torch.FloatTensor = None
# Optional tuple containing the hidden states of each stage
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
# Optional tuple containing the attention weights of each layer
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# The `logits` property is kept only as a deprecated alias for the final output
@property
def logits(self):
# Warn that `logits` will be removed in Transformers v5; use `reconstruction` instead
warnings.warn(
"logits attribute is deprecated and will be removed in version 5 of Transformers."
" Please use the reconstruction attribute to retrieve the final output instead.",
FutureWarning,
)
# Return `reconstruction` as the final output
return self.reconstruction
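This warn-and-alias pattern is reusable whenever an output field is renamed without breaking old callers. A generic sketch (the class and values are made up for illustration):
import warnings
from dataclasses import dataclass

@dataclass
class ExampleOutput:  # hypothetical class, not part of the library
    reconstruction: float = 0.0

    @property
    def logits(self) -> float:
        warnings.warn(
            "logits is deprecated; use reconstruction instead.",
            FutureWarning,
        )
        return self.reconstruction

out = ExampleOutput(reconstruction=1.5)
print(out.logits)  # emits a FutureWarning, then prints 1.5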