/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
// 包含标准库向量头文件
#include <vector>
// 包含 ATen 库的头文件,提供张量操作
#include <ATen/ATen.h>
// 包含 CUDA 上下文头文件,用于处理 CUDA 相关操作
#include <ATen/cuda/CUDAContext.h>
/**
 * CPU entry point for the ms_deform_attn forward pass.
 *
 * There is no CPU kernel: the body unconditionally raises through AT_ERROR, so the
 * declared at::Tensor is never actually returned and every argument goes unused.
 */
at::Tensor ms_deform_attn_cpu_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    AT_ERROR("Not implement on cpu");
}
/**
 * CPU entry point for the ms_deform_attn backward pass.
 *
 * There is no CPU kernel: the body unconditionally raises through AT_ERROR, so the
 * declared vector of tensors is never actually returned and every argument goes unused.
 */
std::vector<at::Tensor> ms_deform_attn_cpu_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{
    AT_ERROR("Not implement on cpu");
}
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
// 预处理指令,确保头文件只被包含一次
#pragma once
// 包含 PyTorch C++ 扩展库头文件
#include <torch/extension.h>
// Declaration of the CPU forward entry point of ms_deform_attn.
// Declaration only; returns an at::Tensor.
at::Tensor
ms_deform_attn_cpu_forward(
const at::Tensor &value, // value tensor
const at::Tensor &spatial_shapes, // spatial-shapes tensor
const at::Tensor &level_start_index, // level start-index tensor
const at::Tensor &sampling_loc, // sampling-location tensor
const at::Tensor &attn_weight, // attention-weight tensor
const int im2col_step); // im2col step
// Declaration of the CPU backward entry point of ms_deform_attn.
// Declaration only; returns a vector of at::Tensor.
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
const at::Tensor &value, // value tensor
const at::Tensor &spatial_shapes, // spatial-shapes tensor
const at::Tensor &level_start_index, // level start-index tensor
const at::Tensor &sampling_loc, // sampling-location tensor
const at::Tensor &attn_weight, // attention-weight tensor
const at::Tensor &grad_output, // gradient of the output
const int im2col_step); // im2col step
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
// 包含 Torch C++ 扩展库的头文件
#include <torch/extension.h>
// Declaration of the CUDA forward entry point of ms_deform_attn.
// Declaration only; returns an at::Tensor.
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value, // value tensor
const at::Tensor &spatial_shapes, // spatial-shapes tensor
const at::Tensor &level_start_index, // level start-index tensor
const at::Tensor &sampling_loc, // sampling-location tensor
const at::Tensor &attn_weight, // attention-weight tensor
const int im2col_step // im2col step
);
// Declaration of the "_bf16" variant of the CUDA forward entry point.
// Same parameter list as ms_deform_attn_cuda_forward; presumably operates on BFloat16 data — confirm in the definition.
at::Tensor ms_deform_attn_cuda_forward_bf16(
const at::Tensor &value, // value tensor
const at::Tensor &spatial_shapes, // spatial-shapes tensor
const at::Tensor &level_start_index, // level start-index tensor
const at::Tensor &sampling_loc, // sampling-location tensor
const at::Tensor &attn_weight, // attention-weight tensor
const int im2col_step // im2col step
);
// Declaration of the CUDA backward entry point of ms_deform_attn.
// Declaration only; returns a vector of at::Tensor.
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value, // value tensor
const at::Tensor &spatial_shapes, // spatial-shapes tensor
const at::Tensor &level_start_index, // level start-index tensor
const at::Tensor &sampling_loc, // sampling-location tensor
const at::Tensor &attn_weight, // attention-weight tensor
const at::Tensor &grad_output, // gradient of the output
const int im2col_step // im2col step
);
// Declaration of the "_bf16" variant of the CUDA backward entry point.
// Same parameter list as ms_deform_attn_cuda_backward; presumably operates on BFloat16 data — confirm in the definition.
std::vector<at::Tensor> ms_deform_attn_cuda_backward_bf16(
const at::Tensor &value, // value tensor
const at::Tensor &spatial_shapes, // spatial-shapes tensor
const at::Tensor &level_start_index, // level start-index tensor
const at::Tensor &sampling_loc, // sampling-location tensor
const at::Tensor &attn_weight, // attention-weight tensor
const at::Tensor &grad_output, // gradient of the output
const int im2col_step // im2col step
);
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include "cpu/ms_deform_attn_cpu.h"
#ifdef WITH_CUDA
#include "cuda/ms_deform_attn_cuda.h"
#endif
/**
 * Device dispatcher for the ms_deform_attn forward pass.
 *
 * When `value` is a CUDA tensor and the extension was built with WITH_CUDA, the call is
 * forwarded unchanged to ms_deform_attn_cuda_forward and its result is returned.
 * Every other situation raises through AT_ERROR:
 *   - CUDA tensor but WITH_CUDA not defined: "Not compiled with GPU support";
 *   - non-CUDA tensor: "Not implemented on the CPU".
 *
 * Only `value` is inspected to pick the backend; the other arguments are passed through.
 */
at::Tensor
ms_deform_attn_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    // Query the tensor directly; the value.type() accessor is deprecated.
    if (value.is_cuda())
    {
#ifdef WITH_CUDA
        return ms_deform_attn_cuda_forward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
#else
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    AT_ERROR("Not implemented on the CPU");
}
/**
 * Device dispatcher for the ms_deform_attn backward pass.
 *
 * When `value` is a CUDA tensor and the extension was built with WITH_CUDA, the call is
 * forwarded unchanged to ms_deform_attn_cuda_backward and its result is returned.
 * Every other situation raises through AT_ERROR:
 *   - CUDA tensor but WITH_CUDA not defined: "Not compiled with GPU support";
 *   - non-CUDA tensor: "Not implemented on the CPU".
 *
 * Only `value` is inspected to pick the backend; the other arguments are passed through.
 */
std::vector<at::Tensor>
ms_deform_attn_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{
    // Query the tensor directly; the value.type() accessor is deprecated.
    if (value.is_cuda())
    {
#ifdef WITH_CUDA
        return ms_deform_attn_cuda_backward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
#else
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    AT_ERROR("Not implemented on the CPU");
}
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include "ms_deform_attn.h"
// Python bindings: both dispatchers are exported under their C++ names, and each
// binding's docstring is simply the function name again.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, py_module)
{
    py_module.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
    py_module.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <vector>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
/**
 * CPU entry point for the ms_deform_attn forward pass.
 *
 * There is no CPU kernel: the body unconditionally raises through AT_ERROR, so the
 * declared at::Tensor is never actually returned and every argument goes unused.
 */
at::Tensor ms_deform_attn_cpu_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    AT_ERROR("Not implement on cpu");
}
/**
 * CPU entry point for the ms_deform_attn backward pass.
 *
 * There is no CPU kernel: the body unconditionally raises through AT_ERROR, so the
 * declared vector of tensors is never actually returned and every argument goes unused.
 */
std::vector<at::Tensor> ms_deform_attn_cpu_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{
    AT_ERROR("Not implement on cpu");
}
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
// 预处理指令,确保头文件只被包含一次
#pragma once
// 包含 PyTorch C++ 扩展的头文件
#include <torch/extension.h>
// Declaration of the CPU forward entry point of ms_deform_attn.
// Declaration only; returns an at::Tensor.
at::Tensor
ms_deform_attn_cpu_forward(
const at::Tensor &value, // value tensor
const at::Tensor &spatial_shapes, // spatial-shapes tensor
const at::Tensor &level_start_index,// level start-index tensor
const at::Tensor &sampling_loc, // sampling-location tensor
const at::Tensor &attn_weight, // attention-weight tensor
const int im2col_step); // im2col step
// Declaration of the CPU backward entry point of ms_deform_attn.
// Declaration only; returns a vector of at::Tensor.
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
const at::Tensor &value, // value tensor
const at::Tensor &spatial_shapes, // spatial-shapes tensor
const at::Tensor &level_start_index,// level start-index tensor
const at::Tensor &sampling_loc, // sampling-location tensor
const at::Tensor &attn_weight, // attention-weight tensor
const at::Tensor &grad_output, // gradient of the output
const int im2col_step); // im2col step
// This C++ header declares two functions, `ms_deform_attn_cpu_forward` and `ms_deform_attn_cpu_backward`, the CPU entry points for the forward and backward passes of deformable attention.
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
// 包含 Torch C++ 扩展的头文件
#pragma once
#include <torch/extension.h>
// Declaration of the CUDA forward entry point of ms_deform_attn.
// Declaration only; returns an at::Tensor.
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value, // value tensor
const at::Tensor &spatial_shapes, // spatial-shapes tensor
const at::Tensor &level_start_index, // level start-index tensor
const at::Tensor &sampling_loc, // sampling-location tensor
const at::Tensor &attn_weight, // attention-weight tensor
const int im2col_step); // im2col step
// Declaration of the CUDA backward entry point of ms_deform_attn.
// Declaration only; returns a vector of at::Tensor.
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value, // value tensor
const at::Tensor &spatial_shapes, // spatial-shapes tensor
const at::Tensor &level_start_index, // level start-index tensor
const at::Tensor &sampling_loc, // sampling-location tensor
const at::Tensor &attn_weight, // attention-weight tensor
const at::Tensor &grad_output, // gradient of the output
const int im2col_step); // im2col step
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include "cpu/ms_deform_attn_cpu.h"
#ifdef WITH_CUDA
#include "cuda/ms_deform_attn_cuda.h"
#endif
/**
 * Device dispatcher for the ms_deform_attn forward pass.
 *
 * When `value` is a CUDA tensor and the extension was built with WITH_CUDA, the call is
 * forwarded unchanged to ms_deform_attn_cuda_forward and its result is returned.
 * Every other situation raises through AT_ERROR:
 *   - CUDA tensor but WITH_CUDA not defined: "Not compiled with GPU support";
 *   - non-CUDA tensor: "Not implemented on the CPU".
 *
 * Only `value` is inspected to pick the backend; the other arguments are passed through.
 */
at::Tensor
ms_deform_attn_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    // Query the tensor directly; the value.type() accessor is deprecated.
    if (value.is_cuda())
    {
#ifdef WITH_CUDA
        return ms_deform_attn_cuda_forward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
#else
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    AT_ERROR("Not implemented on the CPU");
}
/**
 * Device dispatcher for the ms_deform_attn backward pass.
 *
 * When `value` is a CUDA tensor and the extension was built with WITH_CUDA, the call is
 * forwarded unchanged to ms_deform_attn_cuda_backward and its result is returned.
 * Every other situation raises through AT_ERROR:
 *   - CUDA tensor but WITH_CUDA not defined: "Not compiled with GPU support";
 *   - non-CUDA tensor: "Not implemented on the CPU".
 *
 * Only `value` is inspected to pick the backend; the other arguments are passed through.
 */
std::vector<at::Tensor>
ms_deform_attn_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{
    // Query the tensor directly; the value.type() accessor is deprecated.
    if (value.is_cuda())
    {
#ifdef WITH_CUDA
        return ms_deform_attn_cuda_backward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
#else
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    AT_ERROR("Not implemented on the CPU");
}
// ---- file: .\kernels\deta\vision.cpp ----
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include "ms_deform_attn.h"
// Python bindings: both dispatchers are exported under their C++ names, and each
// binding's docstring is simply the function name again.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, py_module)
{
    py_module.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
    py_module.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}
// ---- file: .\kernels\mra\cuda_kernel.h ----
// Warp size constant (32).
#define WARP_SIZE 32
// 32-bit mask with every bit set.
#define FULL_MASK 0xffffffff
// Thread-count constant (256); presumably the preferred threads-per-block for launches — confirm at the launch sites.
#define OPTIMAL_THREADS 256
// Declaration of the index_max CUDA kernel (definition not visible here).
// The original annotation describes it as computing, per batch and block, the max value and its index — confirm in the definition.
// Shape notes on the pointer parameters are carried over from the original annotations.
__global__ void index_max_cuda_kernel(
float *index_vals, // [batch_size, 32, num_block]
int *indices, // [batch_size, num_block]
float *max_vals, // [batch_size, A_num_block * 32]
float *max_vals_scatter, // [batch_size, 32, num_block]
long batch_size, // batch size
long A_num_block, // A_num_block
long B_num_block, // B_num_block
long num_block // num_block
);
// Declaration of the mm_to_sparse CUDA kernel (definition not visible here).
// The original annotation describes it as writing a dense matrix product into sparse block format — confirm in the definition.
// Shape notes on the pointer parameters are carried over from the original annotations.
__global__ void mm_to_sparse_cuda_kernel(
float *dense_A, // [batch_size, A_num_block, dim, 32]
float *dense_B, // [batch_size, B_num_block, dim, 32]
int *indices, // [batch_size, num_block]
float *sparse_C, // [batch_size, num_block, 32, 32]
long batch_size, // batch size
long A_num_block, // A_num_block
long B_num_block, // B_num_block
long dim, // dim
long num_block // num_block
);
// Declaration of the sparse_dense_mm CUDA kernel (definition not visible here).
// The original annotation describes it as a sparse-by-dense matrix multiplication — confirm in the definition.
// Shape notes on the pointer parameters are carried over from the original annotations.
__global__ void sparse_dense_mm_cuda_kernel(
float *sparse_A, // [batch_size, num_block, 32, 32]
int *indices, // [batch_size, num_block]
float *dense_B, // [batch_size, B_num_block, dim, 32]
float *dense_C, // [batch_size, A_num_block, dim, 32]
long batch_size, // batch size
long A_num_block, // A_num_block
long B_num_block, // B_num_block
long dim, // dim
long num_block // num_block
);
// Declaration of the reduce_sum CUDA kernel (definition not visible here).
// The original annotation describes it as summing the sparse matrix along one dimension — confirm in the definition.
// Shape notes on the pointer parameters are carried over from the original annotations.
__global__ void reduce_sum_cuda_kernel(
float *sparse_A, // [batch_size, num_block, 32, 32]
int *indices, // [batch_size, num_block]
float *dense_C, // [batch_size, A_num_block, 32]
long batch_size, // batch size
long A_num_block, // A_num_block
long B_num_block, // B_num_block
long num_block // num_block
);
// Declaration of the scatter CUDA kernel (definition not visible here).
// The original annotation describes it as scattering a dense matrix into the sparse layout by index — confirm in the definition.
// Shape notes on the pointer parameters are carried over from the original annotations.
__global__ void scatter_cuda_kernel(
float *dense_A, // [batch_size, A_num_block, 32]
int *indices, // [batch_size, num_block]
float *sparse_C, // [batch_size, num_block, 32, 32]
long batch_size, // batch size
long A_num_block, // A_num_block
long B_num_block, // B_num_block
long num_block // num_block
);
// ---- file: .\kernels\mra\cuda_launch.h ----
// Torch C++ extension header.
#include <torch/extension.h>
// ATen header, for tensor operations.
#include <ATen/ATen.h>
// std::vector, used for the multi-tensor return type below.
#include <vector>
# 定义宏函数 min,返回两个数中较小的那个
#define min(a, b) ((a)<(b)?(a):(b))
# 定义宏函数 max,返回两个数中较大的那个
#define max(a, b) ((a)>(b)?(a):(b))
# 声明一个函数,该函数返回一个包含多个张量的 vector
std::vector<at::Tensor> index_max_kernel(
at::Tensor index_vals, # 输入参数:索引值的张量
at::Tensor indices, # 输入参数:索引的张量
int A_num_block, # 输入参数:A 矩阵的块数量
int B_num_block # 输入参数:B 矩阵的块数量
);
// Declaration of the host-side mm_to_sparse launcher (definition not visible here).
// Returns a single tensor; presumably the block-sparse product of dense_A and dense_B — confirm in the definition.
at::Tensor mm_to_sparse_kernel(
    at::Tensor dense_A,  // dense matrix A
    at::Tensor dense_B,  // dense matrix B
    at::Tensor indices   // indices tensor
);
// Declaration of the host-side sparse_dense_mm launcher (definition not visible here).
// Returns a single tensor; presumably the product of sparse_A with dense_B — confirm in the definition.
at::Tensor sparse_dense_mm_kernel(
    at::Tensor sparse_A,  // sparse matrix A
    at::Tensor indices,   // indices tensor
    at::Tensor dense_B,   // dense matrix B
    int A_num_block       // number of blocks for A
);
// Declaration of the host-side reduce_sum launcher (definition not visible here).
// Returns a single tensor; presumably a sum reduction over sparse_A — confirm in the definition.
at::Tensor reduce_sum_kernel(
    at::Tensor sparse_A,  // sparse matrix A
    at::Tensor indices,   // indices tensor
    int A_num_block,      // number of blocks for A
    int B_num_block       // number of blocks for B
);
// Declaration of the host-side scatter launcher (definition not visible here).
// Returns a single tensor; presumably dense_A scattered according to indices — confirm in the definition.
at::Tensor scatter_kernel(
    at::Tensor dense_A,  // dense tensor A
    at::Tensor indices,  // indices tensor
    int B_num_block      // number of blocks for B
);
// ---- file: .\kernels\mra\torch_extension.cpp ----
#include <torch/extension.h>
#include <ATen/ATen.h>
#include "cuda_launch.h" // 引入 CUDA 相关的头文件
#include <vector> // 引入 vector 容器的头文件
/// Thin wrapper around index_max_kernel: passes all four arguments through in order
/// and returns the kernel launcher's result untouched.
std::vector<at::Tensor> index_max(
    at::Tensor index_vals,
    at::Tensor indices,
    int A_num_block,
    int B_num_block)
{
    return index_max_kernel(index_vals, indices, A_num_block, B_num_block);
}
/// Thin wrapper around mm_to_sparse_kernel: passes the three tensors through in order
/// and returns the kernel launcher's result untouched.
at::Tensor mm_to_sparse(
    at::Tensor dense_A,
    at::Tensor dense_B,
    at::Tensor indices)
{
    return mm_to_sparse_kernel(dense_A, dense_B, indices);
}
/// Thin wrapper around sparse_dense_mm_kernel: passes all four arguments through in order
/// and returns the kernel launcher's result untouched.
at::Tensor sparse_dense_mm(
    at::Tensor sparse_A,
    at::Tensor indices,
    at::Tensor dense_B,
    int A_num_block)
{
    return sparse_dense_mm_kernel(sparse_A, indices, dense_B, A_num_block);
}
/// Thin wrapper around reduce_sum_kernel: passes all four arguments through in order
/// and returns the kernel launcher's result untouched.
at::Tensor reduce_sum(
    at::Tensor sparse_A,
    at::Tensor indices,
    int A_num_block,
    int B_num_block)
{
    return reduce_sum_kernel(sparse_A, indices, A_num_block, B_num_block);
}
/// Thin wrapper around scatter_kernel: passes all three arguments through in order
/// and returns the kernel launcher's result untouched.
at::Tensor scatter(
    at::Tensor dense_A,
    at::Tensor indices,
    int B_num_block)
{
    return scatter_kernel(dense_A, indices, B_num_block);
}
// Python bindings: each wrapper is exported under its C++ name with a "<name> (CUDA)" docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, py_module)
{
    py_module.def("index_max", &index_max, "index_max (CUDA)");
    py_module.def("mm_to_sparse", &mm_to_sparse, "mm_to_sparse (CUDA)");
    py_module.def("sparse_dense_mm", &sparse_dense_mm, "sparse_dense_mm (CUDA)");
    py_module.def("reduce_sum", &reduce_sum, "reduce_sum (CUDA)");
    py_module.def("scatter", &scatter, "scatter (CUDA)");
}
// ---- file: .\kernels\rwkv\wkv_op.cpp ----
// Torch extension and ATen headers.
#include <torch/extension.h>
#include "ATen/ATen.h"
// Short alias for ATen's BFloat16 type.
using bf16 = at::BFloat16;
// Launchers defined outside this file. Each takes the three sizes B, T, C followed by raw
// data pointers; the *_bf16 variants take bf16 pointers for everything except `w` (and the
// state `s`), which stay float.

// Forward pass, float data.
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
// Forward pass, bf16 data.
void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
// Forward pass with an extra state buffer `s`, float data.
void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s);
// Forward pass with an extra state buffer `s`, bf16 data.
void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s);
// Backward pass, float data; gw/gu/gk/gv are the gradient buffers.
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
// Backward pass, bf16 data; gw/gu/gk/gv are the gradient buffers.
void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);
// Stateless forward wrapper for float tensors.
// Reads B, T and C from the first three dimensions of `k`, then hands the raw float
// pointers of w, u, k, v and y to cuda_forward. Nothing is returned; whatever the
// launcher produces is written through the `y` pointer.
void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
    // size() yields int64_t while the launcher takes int, so narrow explicitly.
    const int B = static_cast<int>(k.size(0));
    const int T = static_cast<int>(k.size(1));
    const int C = static_cast<int>(k.size(2));
    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
// Stateless forward wrapper for bf16 tensors.
// Reads B, T and C from the first three dimensions of `k`, then hands raw pointers to
// cuda_forward_bf16: `w` as float, and u, k, v, y as bf16.
void forward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
    // size() yields int64_t while the launcher takes int, so narrow explicitly.
    const int B = static_cast<int>(k.size(0));
    const int T = static_cast<int>(k.size(1));
    const int C = static_cast<int>(k.size(2));
    cuda_forward_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
}
// Forward wrapper with a state tensor, float data.
// Reads B, T and C from the first three dimensions of `k`, then hands the raw float
// pointers of w, u, k, v, y and the state `s` to cuda_forward_with_state.
void forward_with_state(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) {
    // size() yields int64_t while the launcher takes int, so narrow explicitly.
    const int B = static_cast<int>(k.size(0));
    const int T = static_cast<int>(k.size(1));
    const int C = static_cast<int>(k.size(2));
    cuda_forward_with_state(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), s.data_ptr<float>());
}
// Forward wrapper with a state tensor, bf16 data.
// Reads B, T and C from the first three dimensions of `k`, then hands raw pointers to
// cuda_forward_with_state_bf16: `w` and the state `s` as float, and u, k, v, y as bf16.
void forward_with_state_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) {
    // size() yields int64_t while the launcher takes int, so narrow explicitly.
    const int B = static_cast<int>(k.size(0));
    const int T = static_cast<int>(k.size(1));
    const int C = static_cast<int>(k.size(2));
    cuda_forward_with_state_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(), s.data_ptr<float>());
}
# 定义反向传播函数,接受 Torch 张量参数
void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
# 获取张量的批量大小 B,时间步长 T,通道数 C
const int B = k.size(0);
const int T = k.size(1);
const int C = k.size(2);
# 调用 CUDA 反向传播函数,传递 float 类型数据指针
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}
# 定义反向传播函数,接受 BFloat16 类型的 Torch 张量参数
void backward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
# 获取张量的批量大小 B
const int B = k.size(0);
# 以下代码与 backward 函数相同,但接受 BFloat16 类型数据
cuda_backward_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(), gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
}
# 定义常量 T,表示 k 张量的第一个维度大小
const int T = k.size(1);
# 定义常量 C,表示 k 张量的第二个维度大小
const int C = k.size(2);
# 调用 CUDA 后向传播函数 cuda_backward_bf16,传递以下参数:
# - B: 未明确提到,可能是某种批量大小或其它参数
# - T: k 张量的第一个维度大小
# - C: k 张量的第二个维度大小
# - w.data_ptr<float>():权重张量 w 的 float 类型数据指针
# - u.data_ptr<bf16>():输入张量 u 的 bf16 类型数据指针
# - k.data_ptr<bf16>():内核张量 k 的 bf16 类型数据指针
# - v.data_ptr<bf16>():中间变量 v 的 bf16 类型数据指针
# - y.data_ptr<bf16>():输出张量 y 的 bf16 类型数据指针
# - gy.data_ptr<bf16>():输出梯度 gy 的 bf16 类型数据指针
# - gw.data_ptr<bf16>():权重梯度 gw 的 bf16 类型数据指针
# - gu.data_ptr<bf16>():输入梯度 gu 的 bf16 类型数据指针
# - gk.data_ptr<bf16>():内核梯度 gk 的 bf16 类型数据指针
# - gv.data_ptr<bf16>():中间变量梯度 gv 的 bf16 类型数据指针
cuda_backward_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
}
// Python bindings: every wkv wrapper is exported under its C++ name with a short
// "wkv ..." docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, py_module)
{
    // Forward variants.
    py_module.def("forward", &forward, "wkv forward");
    py_module.def("forward_bf16", &forward_bf16, "wkv forward bf16");
    py_module.def("forward_with_state", &forward_with_state, "wkv forward with state");
    py_module.def("forward_with_state_bf16", &forward_with_state_bf16, "wkv forward with state bf16");
    // Backward variants.
    py_module.def("backward", &backward, "wkv backward");
    py_module.def("backward_bf16", &backward_bf16, "wkv backward bf16");
}
// Registers the same six wrappers in the Torch library namespace "wkv".
TORCH_LIBRARY(wkv, lib)
{
    // Forward variants.
    lib.def("forward", forward);
    lib.def("forward_bf16", forward_bf16);
    lib.def("forward_with_state", forward_with_state);
    lib.def("forward_with_state_bf16", forward_with_state_bf16);
    // Backward variants.
    lib.def("backward", backward);
    lib.def("backward_bf16", backward_bf16);
}
// ---- file: .\kernels\yoso\common.h ----
# 定义宏函数,返回两个数中较小的一个
#define min(a, b) ((a)<(b)?(a):(b))
# 定义宏函数,返回两个数中较大的一个
#define max(a, b) ((a)>(b)?(a):(b))
# 定义宏函数,对两个数进行向上取整的除法运算
#define ceil_divide(a, b) ((a)/(b)+((a)%(b)!=0))
# 定义宏函数,根据条件选择返回其中一个值
#define select(cond, a, b) ((cond)?(a):(b))
# 定义常数 PI,表示圆周率
#define PI 3.141592
# 定义常数 EPSILON,表示一个小的正数,通常用于浮点数比较的容差
#define EPSILON 1e-8
# 定义常数 MAX_VAL,表示一个较大的数值上限
#define MAX_VAL 1e12
# 定义常数 MIN_VAL,表示一个较小的数值下限
#define MIN_VAL -1e12
# 定义常数 EMPTY_VALUE,表示一个特定的空值或未初始化值
#define EMPTY_VALUE -1
// ---- file: .\kernels\yoso\common_cuda.h ----
// Upper bound on threads per block.
#define MAX_THREADS_PER_BLOCK 1024
// Preferred number of threads per block.
#define OPTIMAL_THREADS_PER_BLOCK 256
// Warp size constant (32).
#define WARP_SIZE 32
// Upper bound on the number of blocks along X.
#define MAX_NUM_BLOCK_X 2147483647
// Upper bound on the number of blocks along Y.
#define MAX_NUM_BLOCK_Y 65535
// Upper bound on the number of blocks along Z.
#define MAX_NUM_BLOCK_Z 65535
// Upper bound on shared memory per block (presumably in bytes — confirm at the use sites).
#define MAX_SHARED_MEM_PER_BLOCK 48000
// 32-bit mask with every bit set.
#define FULL_MASK 0xffffffff
// ---- file: .\kernels\yoso\common_cuda_device.h ----
// Shared project header; presumably the source of the EMPTY_VALUE sentinel used below — confirm.
#include "common.h"
// Inserts `value` into the hash set stored in `set` (set_size slots), probing linearly.
//
// The home slot is value % set_size. Each probe tries to claim the slot with atomicCAS
// against EMPTY_VALUE, then advances by one slot with wrap-around.
// Returns the slot that holds `value` (freshly claimed, or already containing it), or -1
// once probing has wrapped back to the home slot without success.
// NOTE(review): a negative `value` would give a negative home slot; callers presumably
// pass non-negative values — confirm.
template<typename T>
__device__ int set_insert(T *set, int set_size, T value) {
    const int start_slot = value % set_size;
    int slot = start_slot;
    do {
        // atomicCAS returns the previous content: EMPTY_VALUE means this call claimed the
        // slot, `value` means it was already stored there.
        const T prev = atomicCAS(&set[slot], EMPTY_VALUE, value);
        if (prev == EMPTY_VALUE || prev == value) {
            return slot;
        }
        slot = (slot + 1) % set_size;
    } while (slot != start_slot);
    // Every slot was visited and none was usable.
    return -1;
}
// Looks up `value` in the hash set stored in `set` (set_size slots), probing linearly.
//
// Probing starts at value % set_size and advances by one slot with wrap-around.
// Returns the slot whose content equals `value`, or -1 once probing has wrapped back to
// the starting slot. Empty slots do not stop the scan, so a miss visits every slot.
// NOTE(review): a negative `value` would give a negative starting slot; callers
// presumably pass non-negative values — confirm.
template<typename T>
__device__ int set_lookup(T *set, int set_size, T value) {
    const int start_slot = value % set_size;
    int slot = start_slot;
    do {
        if (set[slot] == value) {
            return slot;
        }
        slot = (slot + 1) % set_size;
    } while (slot != start_slot);
    // Every slot was visited without a match.
    return -1;
}
// Fills buffer[0 .. buffer_size) with init_value, with the work split across threads.
//
// The calling thread writes indices thread_id, thread_id + num_threads, ... below
// buffer_size. __syncthreads() is called before and after the fill, so every thread of
// the block must reach this call.
// Assumes 0 <= thread_id < num_threads — confirm against callers.
template<typename T>
__device__ void init_buffer(T init_value, T *buffer, int buffer_size, int num_threads, int thread_id) {
    __syncthreads();
    for (int idx = thread_id; idx < buffer_size; idx += num_threads) {
        buffer[idx] = init_value;
    }
    __syncthreads();
}
// Copies data_length elements from src_pt to dist_pt, with the work split across threads.
//
// The calling thread copies indices thread_id, thread_id + num_threads, ... below
// data_length. __syncthreads() is called before and after the copy, so every thread of
// the block must reach this call.
// Assumes 0 <= thread_id < num_threads — confirm against callers.
template<typename T>
__device__ void copy_data(T *src_pt, T *dist_pt, int data_length, int num_threads, int thread_id) {
    __syncthreads();
    for (int idx = thread_id; idx < data_length; idx += num_threads) {
        dist_pt[idx] = src_pt[idx];
    }
    __syncthreads();
}
// Fills buffer[0 .. buffer_size) with init_value, with the work split across threads,
// without any synchronisation.
//
// The calling thread writes indices thread_id, thread_id + num_threads, ... below
// buffer_size. No __syncthreads() is issued here; callers must synchronise themselves
// before reading entries written by other threads.
// Assumes 0 <= thread_id < num_threads — confirm against callers.
template<typename T>
__device__ void init_buffer_nonblocking(T init_value, T *buffer, int buffer_size, int num_threads, int thread_id) {
    for (int idx = thread_id; idx < buffer_size; idx += num_threads) {
        buffer[idx] = init_value;
    }
}
// Copies data_length elements from src_pt to dist_pt, with the work split across threads,
// without any synchronisation.
//
// The calling thread copies indices thread_id, thread_id + num_threads, ... below
// data_length. No __syncthreads() is issued here; callers must synchronise themselves
// before reading entries written by other threads.
// Assumes 0 <= thread_id < num_threads — confirm against callers.
template<typename T>
__device__ void copy_data_nonblocking(T *src_pt, T *dist_pt, int data_length, int num_threads, int thread_id) {
    for (int idx = thread_id; idx < data_length; idx += num_threads) {
        dist_pt[idx] = src_pt[idx];
    }
}
// ---- file: .\kernels\yoso\fast_lsh_cumulation.h ----
// 导入 PyTorch C++ 扩展头文件
#include <torch/extension.h>
// 导入 ATen 库的头文件
#include <ATen/ATen.h>
// 导入 STL 中的 vector 容器
#include <vector>
// Declaration of the version-1 fast-hash entry point; it returns a vector of tensors.
// Only the prototype is given here; the definition lives elsewhere.
std::vector<at::Tensor> fast_hash_ver1_kernel(
// query mask tensor
at::Tensor query_mask,
// query vector tensor
at::Tensor query_vector,
// key mask tensor
at::Tensor key_mask,
// key vector tensor
at::Tensor key_vector,
// number of hash functions
int num_hash_f,
// hash code length
int hash_code_len,
// whether to use CUDA
bool use_cuda
);
// Declaration of the version-1 LSH cumulation entry point; it returns a single tensor.
// Only the prototype is given here; the definition lives elsewhere.
at::Tensor lsh_cumulation_ver1_kernel(
// query mask tensor
at::Tensor query_mask,
// query hash code tensor
at::Tensor query_hash_code,
// key mask tensor
at::Tensor key_mask,
// key hash code tensor
at::Tensor key_hash_code,
// value tensor
at::Tensor value,
// hash table capacity
int hashtable_capacity,
// whether to use CUDA
bool use_cuda
);
// Declaration of the version-1 weighted LSH cumulation entry point; it returns a single tensor.
// Only the prototype is given here; the definition lives elsewhere.
at::Tensor lsh_weighted_cumulation_ver1_kernel(
// query mask tensor
at::Tensor query_mask,
// query hash code tensor
at::Tensor query_hash_code,
// query weight tensor
at::Tensor query_weight,
// key mask tensor
at::Tensor key_mask,
// key hash code tensor
at::Tensor key_hash_code,
// key weight tensor
at::Tensor key_weight,
// value tensor
at::Tensor value,
// hash table capacity
int hashtable_capacity,
// whether to use CUDA
bool use_cuda
);
// Declaration of the version-2 weighted LSH cumulation entry point.
// The parameter list and return type are the same as for version 1; how the versions
// differ internally cannot be seen from this header.
at::Tensor lsh_weighted_cumulation_ver2_kernel(
at::Tensor query_mask,
at::Tensor query_hash_code,
at::Tensor query_weight,
at::Tensor key_mask,
at::Tensor key_hash_code,
at::Tensor key_weight,
at::Tensor value,
int hashtable_capacity,
bool use_cuda
);
// Declaration of the version-3 weighted LSH cumulation entry point
// (same parameter list and return type as versions 1 and 2).
at::Tensor lsh_weighted_cumulation_ver3_kernel(
at::Tensor query_mask,
at::Tensor query_hash_code,
at::Tensor query_weight,
at::Tensor key_mask,
at::Tensor key_hash_code,
at::Tensor key_weight,
at::Tensor value,
int hashtable_capacity,
bool use_cuda
);
// Declaration of the version-4 weighted LSH cumulation entry point
// (same parameter list and return type as versions 1 to 3).
at::Tensor lsh_weighted_cumulation_ver4_kernel(
at::Tensor query_mask,
at::Tensor query_hash_code,
at::Tensor query_weight,
at::Tensor key_mask,
at::Tensor key_hash_code,
at::Tensor key_weight,
at::Tensor value,
int hashtable_capacity,
bool use_cuda
);
.\kernels\yoso\fast_lsh_cumulation_cuda.h
// CUDA kernel declaration (definition not shown here).
// Bracketed shapes are the buffer layouts noted alongside each pointer parameter.
__global__ void fast_hash_ver1_cuda_kernel(
int *mask, // [batch_size, num_vector] mask values
float *vector, // [batch_size, num_vector, vector_dim] input vectors
int *Dmat, // [3, num_part, vector_dim]
int *hash_code, // [batch_size, num_vector, num_hash_f] hash codes
int batch_size, // batch size
int num_vector, // number of vectors
int vector_dim, // vector dimension
int num_part, // number of parts
int num_hash_f, // number of hash functions
int hash_code_len // hash code length
);
// CUDA kernel declaration (definition not shown here).
// Bracketed shapes are the buffer layouts noted alongside each pointer parameter.
__global__ void lsh_cumulation_ver1_step1_cuda_kernel(
int *key_mask, // [batch_size, num_key] key mask
int *key_hash_code, // [batch_size, num_key, num_hash_f] key hash codes
float *value, // [batch_size, num_key, value_dim] values
float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, value_dim] hash table values
int batch_size, // batch size
int num_hash_f, // number of hash functions
int hashtable_capacity, // hash table capacity
int num_key, // number of keys
int value_dim, // value dimension
int offset_warp // warp offset
);
// CUDA kernel declaration (definition not shown here).
// Bracketed shapes are the buffer layouts noted alongside each pointer parameter.
__global__ void lsh_cumulation_ver1_step2_cuda_kernel(
int *query_mask, // [batch_size, num_query] query mask
int *query_hash_code, // [batch_size, num_query, num_hash_f] query hash codes
float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, value_dim] hash table values
float *cumulation_value, // [batch_size, num_query, value_dim] cumulated values
int batch_size, // batch size
int num_hash_f, // number of hash functions
int hashtable_capacity, // hash table capacity
int num_query, // number of queries
int value_dim, // value dimension
int offset_warp // warp offset
);
// CUDA kernel declaration (definition not shown here).
// Bracketed shapes are the buffer layouts noted alongside each pointer parameter.
__global__ void lsh_weighted_cumulation_ver1_step1_cuda_kernel(
int *key_mask, // [batch_size, num_key] key mask
int *key_hash_code, // [batch_size, num_key, num_hash_f] key hash codes
float *key_weight, // [batch_size, num_key, weight_dim] key weights
float *value, // [batch_size, num_key, value_dim] values
float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE] hash table values
int batch_size, // batch size
int num_hash_f, // number of hash functions
int hashtable_capacity, // hash table capacity
int num_key, // number of keys
int value_dim, // value dimension
int weight_dim, // weight dimension
int offset_warp, // warp offset
int weight_idx // weight index
);
// CUDA kernel declaration (definition not shown here).
// Bracketed shapes are the buffer layouts noted alongside each pointer parameter.
__global__ void lsh_weighted_cumulation_ver1_step2_cuda_kernel(
int *query_mask, // [batch_size, num_query] query mask
int *query_hash_code, // [batch_size, num_query, num_hash_f] query hash codes
float *query_weight, // [batch_size, num_query, weight_dim] query weights
float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE] hash table values
float *cumulation_value, // [batch_size, num_query, value_dim] cumulated values
int batch_size, // batch size
int num_hash_f, // number of hash functions
int hashtable_capacity, // hash table capacity
int num_query, // number of queries
int value_dim, // value dimension
int weight_dim, // weight dimension
int offset_warp, // warp offset
int weight_idx // weight index
);
// CUDA kernel declaration (definition not shown here).
// Bracketed shapes are the buffer layouts noted alongside each pointer parameter.
__global__ void count_sort_step1_cuda_kernel(
int *key_mask, // [batch_size, num_key] key mask
int *key_hash_code, // [batch_size, num_key, num_hash_f] key hash codes
int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity] count-sort table
int batch_size, // batch size
int num_hash_f, // number of hash functions
int hashtable_capacity,// hash table capacity
int num_key // number of keys
);
// CUDA kernel declaration (definition not shown here).
// Bracketed shapes are the buffer layouts noted alongside each pointer parameter.
__global__ void count_sort_step2_cuda_kernel(
int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity] count-sort table
int batch_size, // batch size
int num_hash_f, // number of hash functions
int hashtable_capacity // hash table capacity
);
// CUDA kernel declaration (definition not shown here).
// The input/output roles and bracketed shapes below are carried over from the original
// parameter annotations.
__global__ void count_sort_step3_cuda_kernel(
int *key_mask, // in: key mask [batch_size, num_key]
int *key_hash_code, // in: key hash codes [batch_size, num_key, num_hash_f]
int *count_sort_table, // in/out: count-sort table [batch_size, num_hash_f, hashtable_capacity]
int *key_sorted_idxes, // out: sorted key indices [batch_size, num_hash_f, num_key]
int batch_size, // in: batch size
int num_hash_f, // in: number of hash functions
int hashtable_capacity, // in: hash table capacity
int num_key // in: number of keys per batch entry
);
// CUDA kernel declaration (definition not shown here).
// The input/output roles and bracketed shapes below are carried over from the original
// parameter annotations.
__global__ void extract_query_info_cuda_kernel(
int *query_mask, // in: query mask [batch_size, num_query]
int *query_hash_code, // in: query hash codes [batch_size, num_query, num_hash_f]
int *count_sort_table, // in: count-sort table [batch_size, num_hash_f, hashtable_capacity]
int *query_info, // out: per-query info [batch_size, num_query, 2, num_hash_f]
int batch_size, // in: batch size
int num_hash_f, // in: number of hash functions
int hashtable_capacity,// in: hash table capacity
int num_query // in: number of queries per batch entry
);
// CUDA kernel declaration (definition not shown here).
// The input/output roles and bracketed shapes below are carried over from the original
// parameter annotations.
__global__ void lsh_weighted_cumulation_ver2_step2_cuda_kernel(
int *query_mask, // in: query mask [batch_size, num_query]
int *query_info, // in: per-query info [batch_size, num_query, 2, num_hash_f]
int *key_sorted_idxes, // in: sorted key indices [batch_size, num_hash_f, num_key]
float *query_weight, // in: query weights [batch_size, num_query, weight_dim]
float *key_weight, // in: key weights [batch_size, num_key, weight_dim]
float *value, // in: values [batch_size, num_key, value_dim]
float *cumulation_value, // out: cumulated values [batch_size, num_query, value_dim]
int batch_size, // in: batch size
int num_hash_f, // in: number of hash functions
int num_query, // in: number of queries per batch entry
int num_key, // in: number of keys per batch entry
int value_dim, // in: value dimension
int weight_dim // in: weight dimension
);
// CUDA kernel declaration (definition not shown here).
// The input/output roles and bracketed shapes below are carried over from the original
// parameter annotations.
__global__ void lsh_weighted_cumulation_ver3_step2_cuda_kernel(
int *query_sorted_idxes, // in: sorted query indices [batch_size, num_hash_f, num_query]
int *key_mask, // in: key mask [batch_size, num_key]
int *key_info, // in: per-key info [batch_size, num_key, 2, num_hash_f]
float *query_weight, // in: query weights [batch_size, num_query, weight_dim]
float *key_weight, // in: key weights [batch_size, num_key, weight_dim]
float *value, // in: values [batch_size, num_key, value_dim]
float *cumulation_value, // out: cumulated values [batch_size, num_query, value_dim]
int batch_size, // in: batch size
int num_hash_f, // in: number of hash functions
int num_query, // in: number of queries per batch entry
int num_key, // in: number of keys per batch entry
int value_dim, // in: value dimension
int weight_dim // in: weight dimension
);
// CUDA kernel declaration (definition not shown here).
// The input/output roles and bracketed shapes below are carried over from the original
// parameter annotations.
__global__ void lsh_weighted_cumulation_ver4_step2_cuda_kernel(
int *query_sorted_idxes, // in: sorted query indices [batch_size, num_hash_f, num_query]
int *key_mask, // in: key mask [batch_size, num_key]
int *key_info, // in: per-key info [batch_size, num_key, 2, num_hash_f]
float *query_weight, // in: query weights [batch_size, num_query, weight_dim]
float *key_weight, // in: key weights [batch_size, num_key, weight_dim]
float *value, // in: values [batch_size, num_key, value_dim]
float *cumulation_value, // out: cumulated values [batch_size, num_query, value_dim]
int batch_size, // in: batch size
int num_hash_f, // in: number of hash functions
int num_query, // in: number of queries per batch entry
int num_key, // in: number of keys per batch entry
int value_dim, // in: value dimension
int weight_dim // in: weight dimension
);
.\kernels\yoso\fast_lsh_cumulation_torch.cpp
#include <torch/extension.h>
#include <ATen/ATen.h>
#include "fast_lsh_cumulation.h" // 引入自定义的头文件,包含快速LSH累积相关的函数声明
#include "common_cuda.h" // 引入自定义的头文件,包含通用的CUDA函数声明
#include <vector> // 引入标准库中的向量容器
// Hashing entry point exposed to Python.
// Every argument except `version` is forwarded unchanged to fast_hash_ver1_kernel;
// `version` is accepted but never consulted, so the version-1 kernel is always used.
std::vector<at::Tensor> fast_hash(
  at::Tensor query_mask,    // [batch_size, num_query]
  at::Tensor query_vector,  // [batch_size, num_query, vector_dim]
  at::Tensor key_mask,      // [batch_size, num_key]
  at::Tensor key_vector,    // [batch_size, num_key, vector_dim]
  int num_hash_f,           // number of hash functions
  int hash_code_len,        // hash code length
  bool use_cuda,            // forwarded to the kernel entry point
  int version               // unused
) {
  std::vector<at::Tensor> hash_codes = fast_hash_ver1_kernel(
    query_mask, query_vector, key_mask, key_vector, num_hash_f, hash_code_len, use_cuda);
  return hash_codes;
}
// LSH cumulation entry point exposed to Python.
// Every argument except `version` is forwarded unchanged to lsh_cumulation_ver1_kernel;
// `version` is accepted but never consulted, so the version-1 kernel is always used.
at::Tensor lsh_cumulation(
  at::Tensor query_mask,       // [batch_size, num_query]
  at::Tensor query_hash_code,  // [batch_size, num_query, num_hash_f]
  at::Tensor key_mask,         // [batch_size, num_key]
  at::Tensor key_hash_code,    // [batch_size, num_key, num_hash_f]
  at::Tensor value,            // [batch_size, num_key, value_dim]
  int hashtable_capacity,      // hash table capacity
  bool use_cuda,               // forwarded to the kernel entry point
  int version                  // unused
) {
  at::Tensor cumulated = lsh_cumulation_ver1_kernel(
    query_mask, query_hash_code, key_mask, key_hash_code, value, hashtable_capacity, use_cuda);
  return cumulated;
}
// Weighted LSH cumulation entry point exposed to Python.
// `version` selects which lsh_weighted_cumulation_ver{1,2,3,4}_kernel receives the
// (otherwise unchanged) arguments; any other value falls back to the version-3 kernel.
at::Tensor lsh_weighted_cumulation(
  at::Tensor query_mask,       // [batch_size, num_query]
  at::Tensor query_hash_code,  // [batch_size, num_query, num_hash_f]
  at::Tensor query_weight,     // [batch_size, num_query, weight_dim]
  at::Tensor key_mask,         // [batch_size, num_key]
  at::Tensor key_hash_code,    // [batch_size, num_key, num_hash_f]
  at::Tensor key_weight,       // [batch_size, num_key, weight_dim]
  at::Tensor value,            // [batch_size, num_key, value_dim]
  int hashtable_capacity,      // hash table capacity
  bool use_cuda,               // forwarded to the kernel entry point
  int version                  // kernel version selector
) {
  switch (version) {
    case 1:
      return lsh_weighted_cumulation_ver1_kernel(
        query_mask, query_hash_code, query_weight,
        key_mask, key_hash_code, key_weight,
        value, hashtable_capacity, use_cuda);
    case 2:
      return lsh_weighted_cumulation_ver2_kernel(
        query_mask, query_hash_code, query_weight,
        key_mask, key_hash_code, key_weight,
        value, hashtable_capacity, use_cuda);
    case 4:
      return lsh_weighted_cumulation_ver4_kernel(
        query_mask, query_hash_code, query_weight,
        key_mask, key_hash_code, key_weight,
        value, hashtable_capacity, use_cuda);
    case 3:
    default:
      // Version 3 doubles as the fallback for unrecognised version numbers.
      return lsh_weighted_cumulation_ver3_kernel(
        query_mask, query_hash_code, query_weight,
        key_mask, key_hash_code, key_weight,
        value, hashtable_capacity, use_cuda);
  }
}
// Python bindings for the extension module named by TORCH_EXTENSION_NAME.
// Each m.def call registers a C++ function under the given Python name; the third
// argument is the docstring attached to the binding.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // fast_hash -> Python `fast_hash`
  m.def("fast_hash", &fast_hash, "Fast Hash (CUDA)");
  // lsh_cumulation -> Python `lsh_cumulation`
  m.def("lsh_cumulation", &lsh_cumulation, "LSH Cumulation (CUDA)");
  // lsh_weighted_cumulation -> Python `lsh_weighted_cumulation`
  m.def("lsh_weighted_cumulation", &lsh_weighted_cumulation, "LSH Weighted Cumulation (CUDA)");
}
.\modelcard.py
# 导入所需的模块和库
import copy # 导入深拷贝函数
import json # 导入处理 JSON 的库
import os # 导入操作系统相关的功能
import warnings # 导入警告处理模块
from dataclasses import dataclass # 导入 dataclass 用于创建数据类
from pathlib import Path # 导入 Path 类用于处理文件路径
from typing import Any, Dict, List, Optional, Union # 导入类型提示相关功能
import requests # 导入处理 HTTP 请求的库
import yaml # 导入处理 YAML 文件的库
from huggingface_hub import model_info # 导入 Hugging Face Hub 的模型信息功能
from huggingface_hub.utils import HFValidationError # 导入 Hugging Face Hub 的验证错误处理
from . import __version__ # 导入当前包的版本信息
from .models.auto.modeling_auto import ( # 导入自动生成模型的相关映射名称
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_CTC_MAPPING_NAMES,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
)
from .training_args import ParallelMode # 导入并行模式参数
from .utils import ( # 导入工具函数和常量
MODEL_CARD_NAME,
cached_file,
is_datasets_available,
is_offline_mode,
is_tf_available,
is_tokenizers_available,
is_torch_available,
logging,
)
# Maps each task tag to the auto-model name mapping(s) for that task. `from_trainer` and
# `from_keras` look a model's class name up in these mappings to infer a default task.
TASK_MAPPING = {
"text-generation": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
"image-segmentation": MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
"fill-mask": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
"object-detection": MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
"question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
"text2text-generation": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
"text-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
"table-question-answering": MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
"audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
"automatic-speech-recognition": {**MODEL_FOR_CTC_MAPPING_NAMES, **MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES},
"zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
}
# Logger for this module.
logger = logging.get_logger(__name__)
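# Structured model card class: stores a model card together with the methods to load, download and save it.
# For details and explanations of each section, see the paper "Model Cards for Model Reporting" by Margaret Mitchell, Simone Wu, Andrew Zaldivar, et al., which proposed model cards: https://arxiv.org/abs/1810.03993
# Note: a model card can be loaded from and saved to disk.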
"""
# 初始化方法,用于创建模型卡片对象
def __init__(self, **kwargs):
# 发出警告,表示该类 `ModelCard` 已被弃用,并将在 Transformers 的第五版中移除
warnings.warn(
"The class `ModelCard` is deprecated and will be removed in version 5 of Transformers", FutureWarning
)
# 推荐的属性来源于 https://arxiv.org/abs/1810.03993(见论文)
# 设置模型细节
self.model_details = kwargs.pop("model_details", {})
# 设置预期使用
self.intended_use = kwargs.pop("intended_use", {})
# 设置因素
self.factors = kwargs.pop("factors", {})
# 设置度量
self.metrics = kwargs.pop("metrics", {})
# 设置评估数据
self.evaluation_data = kwargs.pop("evaluation_data", {})
# 设置训练数据
self.training_data = kwargs.pop("training_data", {})
# 设置定量分析
self.quantitative_analyses = kwargs.pop("quantitative_analyses", {})
# 设置伦理考虑
self.ethical_considerations = kwargs.pop("ethical_considerations", {})
# 设置注意事项和建议
self.caveats_and_recommendations = kwargs.pop("caveats_and_recommendations", {})
# 打开额外的属性
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
# 如果无法设置属性,则记录错误信息并抛出异常
logger.error(f"Can't set {key} with value {value} for {self}")
raise err
# Writes the model card to a directory or to an explicit file path.
def save_pretrained(self, save_directory_or_file):
    """Save a model card object to the directory or file `save_directory_or_file`."""
    # Inside an existing directory the predefined file name is used; otherwise the argument
    # itself is treated as the output file path.
    points_to_directory = os.path.isdir(save_directory_or_file)
    output_model_card_file = (
        os.path.join(save_directory_or_file, MODEL_CARD_NAME) if points_to_directory else save_directory_or_file
    )

    self.to_json_file(output_model_card_file)
    logger.info(f"Model card saved in {output_model_card_file}")
# Alternate constructor taking a dictionary of parameters.
@classmethod
def from_dict(cls, json_object):
    """Constructs a `ModelCard` from a Python dictionary of parameters."""
    card = cls(**json_object)
    return card
# Alternate constructor reading the parameters from a JSON file.
@classmethod
def from_json_file(cls, json_file):
    """Constructs a `ModelCard` from a json file of parameters."""
    # The file is read as UTF-8 and its top-level JSON object supplies the keyword arguments.
    with open(json_file, "r", encoding="utf-8") as reader:
        parameters = json.load(reader)
    return cls(**parameters)
# 判断两个 `ModelCard` 对象是否相等的特殊方法
def __eq__(self, other):
return self.__dict__ == other.__dict__
# 返回 `ModelCard` 对象的字符串表示形式的特殊方法
def __repr__(self):
return str(self.to_json_string())
# Dictionary serialization.
def to_dict(self):
    """Serializes this instance to a Python dictionary."""
    # A deep copy keeps later changes to the result from reaching the instance's attributes.
    attributes = self.__dict__
    return copy.deepcopy(attributes)
# JSON string serialization.
def to_json_string(self):
    """Serializes this instance to a JSON string."""
    # Keys are sorted and indented by two spaces; the text ends with a newline.
    payload = self.to_dict()
    serialized = json.dumps(payload, indent=2, sort_keys=True)
    return serialized + "\n"
# JSON file serialization.
def to_json_file(self, json_file_path):
    """Save this instance to a json file."""
    # The target is opened (and truncated) first, then the UTF-8 JSON text is written to it.
    with open(json_file_path, mode="w", encoding="utf-8") as handle:
        serialized = self.to_json_string()
        handle.write(serialized)
# HTML comment text stating that a model card was generated automatically from the Trainer's information.
AUTOGENERATED_TRAINER_COMMENT = """
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->
"""
# HTML comment text stating that a model card was generated automatically from Keras' information.
AUTOGENERATED_KERAS_COMMENT = """
<!-- This model card has been generated automatically according to the information Keras had access to. You should
probably proofread and complete it, then remove this comment. -->
"""
# Maps a task tag to its human-readable task name. `create_model_index` uses it to fill the
# "name" of a task entry and ignores task tags that are not listed here.
TASK_TAG_TO_NAME_MAPPING = {
"fill-mask": "Masked Language Modeling",
"image-classification": "Image Classification",
"image-segmentation": "Image Segmentation",
"multiple-choice": "Multiple Choice",
"object-detection": "Object Detection",
"question-answering": "Question Answering",
"summarization": "Summarization",
"table-question-answering": "Table Question Answering",
"text-classification": "Text Classification",
"text-generation": "Causal Language Modeling",
"text2text-generation": "Sequence-to-sequence Language Modeling",
"token-classification": "Token Classification",
"translation": "Translation",
"zero-shot-classification": "Zero Shot Classification",
"automatic-speech-recognition": "Automatic Speech Recognition",
"audio-classification": "Audio Classification",
}
# Metric tags recognised by `infer_metric_tags_from_eval_results`: an evaluation-result key is
# matched against this list after lower-casing it and replacing spaces with underscores.
METRIC_TAGS = [
"accuracy",
"bleu",
"f1",
"matthews_correlation",
"pearsonr",
"precision",
"recall",
"rouge",
"sacrebleu",
"spearmanr",
"wer",
]
def _listify(obj):
if obj is None:
return [] # 如果对象为 None,则返回空列表
elif isinstance(obj, str):
return [obj] # 如果对象为字符串,则返回包含该字符串的列表
else:
return obj # 否则返回原始对象
def _insert_values_as_list(metadata, name, values):
if values is None:
return metadata # 如果值为 None,则返回元数据本身
if isinstance(values, str):
values = [values] # 如果值为字符串,则转换成单元素列表
values = [v for v in values if v is not None] # 过滤掉值中的 None 元素
if len(values) == 0:
return metadata # 如果列表为空,则返回元数据本身
metadata[name] = values # 将处理后的列表赋给元数据对应的名称
return metadata # 返回更新后的元数据
def infer_metric_tags_from_eval_results(eval_results):
    """
    Map metric tags to the evaluation-result keys they were inferred from.

    A key matches when its lower-cased, space-to-underscore form is in `METRIC_TAGS`; the key
    "rouge1" (case-insensitive) is additionally mapped to the "rouge" tag. Returns `{}` when
    `eval_results` is `None`.
    """
    if eval_results is None:
        return {}
    tag_to_key = {}
    for key in eval_results.keys():
        normalized = key.lower().replace(" ", "_")
        if normalized in METRIC_TAGS:
            tag_to_key[normalized] = key
        elif key.lower() == "rouge1":
            tag_to_key["rouge"] = key
    return tag_to_key
def _insert_value(metadata, name, value):
if value is None:
return metadata # 如果值为 None,则返回元数据本身
metadata[name] = value # 将值插入到元数据中对应的名称
return metadata # 返回更新后的元数据
def is_hf_dataset(dataset):
    """
    Tell whether `dataset` is a `datasets.Dataset` or `datasets.IterableDataset`.

    Returns `False` without importing anything when `is_datasets_available()` is false.
    """
    if is_datasets_available():
        # Imported lazily so this module does not require `datasets` at import time.
        from datasets import Dataset, IterableDataset

        return isinstance(dataset, (Dataset, IterableDataset))
    return False
def _get_mapping_values(mapping):
result = [] # 初始化结果列表
for v in mapping.values():
if isinstance(v, (tuple, list)):
result += list(v) # 如果值是元组或列表,则将其展开并添加到结果列表中
else:
result.append(v) # 否则直接添加到结果列表中
return result # 返回所有映射值组成的列表
@dataclass
class TrainingSummary:
    """Fields describing a trained model, used to build model-card metadata."""

    # Name of the model (the only required field).
    model_name: str
    # Language(s) of the model: a string or a list of strings.
    language: Optional[Union[str, List[str]]] = None
    # License identifier.
    license: Optional[str] = None
    # Tag(s) attached to the model: a string or a list of strings.
    tags: Optional[Union[str, List[str]]] = None
    # Identifier of the model this one was fine-tuned from.
    finetuned_from: Optional[str] = None
    # Task tag(s): a string or a list of strings.
    tasks: Optional[Union[str, List[str]]] = None
    # Dataset name(s): a string or a list of strings.
    dataset: Optional[Union[str, List[str]]] = None
    # Dataset tag(s): a string or a list of strings.
    dataset_tags: Optional[Union[str, List[str]]] = None
    # Dataset argument(s): a string or a list of strings.
    dataset_args: Optional[Union[str, List[str]]] = None
    # Additional dataset metadata.
    dataset_metadata: Optional[Dict[str, Any]] = None
    # Evaluation results, keyed by metric name.
    eval_results: Optional[Dict[str, float]] = None
    # Evaluation result lines.
    eval_lines: Optional[List[str]] = None
    # Hyperparameters used for training.
    hyperparameters: Optional[Dict[str, Any]] = None
    # Origin of the summary; defaults to "trainer".
    source: Optional[str] = "trainer"
def __post_init__(self):
# 根据微调自的模型信息推断默认许可证
if (
self.license is None # 如果许可证为空
and not is_offline_mode() # 并且不是离线模式
and self.finetuned_from is not None # 并且有微调自的模型信息
and len(self.finetuned_from) > 0 # 并且微调自的模型信息不为空字符串
):
try:
# 获取微调自模型的信息
info = model_info(self.finetuned_from)
# 遍历模型信息的标签
for tag in info.tags:
# 如果标签以 "license:" 开头
if tag.startswith("license:"):
# 设置许可证为标签中 "license:" 后的内容
self.license = tag[8:]
except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError, HFValidationError):
# 处理可能的网络请求错误或验证错误
pass
def create_model_index(self, metric_mapping):
    """
    Build the `model-index` entry for this summary.

    Args:
        metric_mapping: dict mapping a metric tag to the key under which that metric is stored in
            `self.eval_results`.

    Returns:
        A one-element list holding a dict with the model name and a "results" list. One result is
        produced per (task, dataset) combination, and only results that have a task, a dataset and
        metrics are kept; the others are logged and dropped.
    """
    model_index = {"name": self.model_name}

    # Normalise the dataset-related fields to lists.
    dataset_names = _listify(self.dataset)
    dataset_tags = _listify(self.dataset_tags)
    dataset_args = _listify(self.dataset_args)
    # A single metadata dict is wrapped in a list; zipping it directly would pair the dataset
    # tags with the dict's keys instead of with the metadata itself.
    dataset_metadata = self.dataset_metadata
    if isinstance(dataset_metadata, dict):
        dataset_metadata = [dataset_metadata]
    dataset_metadata = _listify(dataset_metadata)
    # Pad the args with None so that every dataset tag has an entry.
    if len(dataset_args) < len(dataset_tags):
        dataset_args = dataset_args + [None] * (len(dataset_tags) - len(dataset_args))
    dataset_mapping = dict(zip(dataset_tags, dataset_names))
    dataset_arg_mapping = dict(zip(dataset_tags, dataset_args))
    dataset_metadata_mapping = dict(zip(dataset_tags, dataset_metadata))

    # Keep only the tasks that have a known human-readable name.
    task_mapping = {
        task: TASK_TAG_TO_NAME_MAPPING[task] for task in _listify(self.tasks) if task in TASK_TAG_TO_NAME_MAPPING
    }

    model_index["results"] = []

    if len(task_mapping) == 0 and len(dataset_mapping) == 0:
        return [model_index]
    # Use a None placeholder so the cross product below still yields one combination.
    if len(task_mapping) == 0:
        task_mapping = {None: None}
    if len(dataset_mapping) == 0:
        dataset_mapping = {None: None}

    # One result per (task, dataset) combination.
    all_possibilities = [(task_tag, ds_tag) for task_tag in task_mapping for ds_tag in dataset_mapping]
    for task_tag, ds_tag in all_possibilities:
        result = {}
        if task_tag is not None:
            result["task"] = {"name": task_mapping[task_tag], "type": task_tag}

        if ds_tag is not None:
            # A missing or None metadata entry contributes no extra keys.
            metadata = dataset_metadata_mapping.get(ds_tag) or {}
            result["dataset"] = {
                "name": dataset_mapping[ds_tag],
                "type": ds_tag,
                **metadata,
            }
            if dataset_arg_mapping[ds_tag] is not None:
                result["dataset"]["args"] = dataset_arg_mapping[ds_tag]

        if len(metric_mapping) > 0:
            result["metrics"] = []
            for metric_tag, metric_name in metric_mapping.items():
                result["metrics"].append(
                    {
                        "name": metric_name,
                        "type": metric_tag,
                        "value": self.eval_results[metric_name],
                    }
                )

        # Partial results are dropped to avoid the model card being rejected.
        if "task" in result and "dataset" in result and "metrics" in result:
            model_index["results"].append(result)
        else:
            logger.info(f"Dropping the following result as it does not have all the necessary fields:\n{result}")

    return [model_index]
# Builds the metadata dictionary of the model card.
def create_metadata(self):
    """
    Assemble the metadata from this summary.

    Keys are inserted in this order: "language", "license", "base_model", "tags", "datasets",
    "metrics" (each only when a value is available) and finally "model-index", which is always set.
    """
    metric_mapping = infer_metric_tags_from_eval_results(self.eval_results)

    metadata = {}
    # The insert helpers modify `metadata` in place and hand back the same dict.
    _insert_values_as_list(metadata, "language", self.language)
    _insert_value(metadata, "license", self.license)
    base_model = self.finetuned_from
    if isinstance(base_model, str) and len(base_model) > 0:
        _insert_value(metadata, "base_model", base_model)
    _insert_values_as_list(metadata, "tags", self.tags)
    _insert_values_as_list(metadata, "datasets", self.dataset_tags)
    _insert_values_as_list(metadata, "metrics", list(metric_mapping.keys()))
    metadata["model-index"] = self.create_model_index(metric_mapping)

    return metadata
@classmethod
def from_trainer(
    cls,
    trainer,
    language=None,
    license=None,
    tags=None,
    model_name=None,
    finetuned_from=None,
    tasks=None,
    dataset_tags=None,
    dataset_metadata=None,
    dataset=None,
    dataset_args=None,
):
    """
    Build a summary from `trainer`, inferring every field that was not passed explicitly.

    Defaults are derived from the trainer's evaluation (or else training) dataset, the model
    config's `_name_or_path`, the model class name, the output directory name and the trainer's log
    history. The tag "generated_from_trainer" is always included in the resulting tags; the `tags`
    argument itself is never modified.
    """
    # Infer the default dataset.
    one_dataset = trainer.eval_dataset if trainer.eval_dataset is not None else trainer.train_dataset
    if is_hf_dataset(one_dataset) and (dataset_tags is None or dataset_args is None or dataset_metadata is None):
        default_tag = one_dataset.builder_name
        # These builder names are excluded from tag inference (per the original comment, they are
        # not datasets from the Hub).
        if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
            if dataset_metadata is None:
                dataset_metadata = [{"config": one_dataset.config_name, "split": str(one_dataset.split)}]
            if dataset_tags is None:
                dataset_tags = [default_tag]
            if dataset_args is None:
                dataset_args = [one_dataset.config_name]

    if dataset is None and dataset_tags is not None:
        dataset = dataset_tags

    # Infer the default `finetuned_from` (skipped when the config points at a local directory).
    if (
        finetuned_from is None
        and hasattr(trainer.model.config, "_name_or_path")
        and not os.path.isdir(trainer.model.config._name_or_path)
    ):
        finetuned_from = trainer.model.config._name_or_path

    # Infer the default task tag from the model class name.
    if tasks is None:
        model_class_name = trainer.model.__class__.__name__
        for task, mapping in TASK_MAPPING.items():
            if model_class_name in _get_mapping_values(mapping):
                tasks = task

    if model_name is None:
        model_name = Path(trainer.args.output_dir).name
    if len(model_name) == 0:
        model_name = finetuned_from

    # Add `generated_from_trainer` to the tags.
    if tags is None:
        tags = ["generated_from_trainer"]
    elif isinstance(tags, str) and tags != "generated_from_trainer":
        tags = [tags, "generated_from_trainer"]
    elif "generated_from_trainer" not in tags:
        # Build a new list rather than appending in place: the caller's sequence stays
        # untouched, and sequences without `append` (e.g. tuples) are accepted too.
        tags = [*tags, "generated_from_trainer"]

    _, eval_lines, eval_results = parse_log_history(trainer.state.log_history)
    hyperparameters = extract_hyperparameters_from_trainer(trainer)

    return cls(
        language=language,
        license=license,
        tags=tags,
        model_name=model_name,
        finetuned_from=finetuned_from,
        tasks=tasks,
        dataset=dataset,
        dataset_tags=dataset_tags,
        dataset_args=dataset_args,
        dataset_metadata=dataset_metadata,
        eval_results=eval_results,
        eval_lines=eval_lines,
        hyperparameters=hyperparameters,
    )
@classmethod
def from_keras(
cls,
model,
model_name,
keras_history=None,
language=None,
license=None,
tags=None,
finetuned_from=None,
tasks=None,
dataset_tags=None,
dataset=None,
dataset_args=None,
):
# Alternate constructor for Keras models.
# NOTE(review): only the beginning of this method is present in this listing; it stops inside
# the task-inference loop below, so the return value cannot be documented from here.
# Infer the default dataset tags/args from `dataset` when one is given.
if dataset is not None:
# Only Hugging Face datasets with a missing tag or argument list are inspected.
if is_hf_dataset(dataset) and (dataset_tags is None or dataset_args is None):
# The dataset's builder name serves as the default tag.
default_tag = dataset.builder_name
# These builder names are excluded (per the original comment, they are not datasets from the Hub).
if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
# Default tags: a one-element list with the builder name.
if dataset_tags is None:
dataset_tags = [default_tag]
# Default args: a one-element list with the dataset's config name.
if dataset_args is None:
dataset_args = [dataset.config_name]
# Without an explicit dataset, fall back to the dataset tags.
if dataset is None and dataset_tags is not None:
dataset = dataset_tags
# Infer the default `finetuned_from` (skipped when the config points at a local directory).
if (
finetuned_from is None
and hasattr(model.config, "_name_or_path")
and not os.path.isdir(model.config._name_or_path)
):
# Use the `_name_or_path` recorded in the model config.
finetuned_from = model.config._name_or_path
# Infer the default task tag.
if tasks is None:
# Class name of the model.
model_class_name = model.__class__.__name__
# Look the class name up in every task's model mapping.
for task, mapping in TASK_MAPPING.items():
# A match means the model belongs to this task.
if model_class_name in _get_mapping_values(mapping):
# (the statement that records the matching task is missing from this listing)
# Add `generated_from_keras_callback` to the tags (the code that does so is missing from this listing)
# Parses Keras training logs into per-epoch lines and final evaluation results.
def parse_keras_history(logs):
    """
    Parse `logs`, which is either a `keras.History`-like object (with `history` and `epoch`
    attributes) as returned by `model.fit()`, or a list of per-epoch log dicts as accumulated for
    `PushToHubCallback`.

    Returns:
        A tuple `(logs, lines, eval_results)`: the logs as a dict of lists, one dict of
        display-named values per epoch, and the values of the last epoch. When there is nothing to
        parse (no `epoch` attribute, an empty list, or zero recorded epochs) the result is
        `(None, [], {})`.
    """
    if hasattr(logs, "history"):
        # This looks like a `History` object.
        if not hasattr(logs, "epoch"):
            # The history is empty.
            return None, [], {}
        # Note: this stores the epoch list inside the caller's `history` dict.
        logs.history["epoch"] = logs.epoch
        logs = logs.history
    else:
        # An empty list has no first entry to take the keys from.
        if len(logs) == 0:
            return None, [], {}
        # Turn the list of dicts into a dict of lists, matching the layout of `History.history`.
        logs = {log_key: [single_dict[log_key] for single_dict in logs] for log_key in logs[0]}

    lines = []
    for i in range(len(logs["epoch"])):
        epoch_dict = {log_key: log_value_list[i] for log_key, log_value_list in logs.items()}
        values = {}
        for k, v in epoch_dict.items():
            if k.startswith("val_"):
                k = "validation_" + k[4:]
            elif k != "epoch":
                k = "train_" + k
            # "validation_loss" -> "Validation Loss"
            splits = k.split("_")
            name = " ".join([part.capitalize() for part in splits])
            values[name] = v
        lines.append(values)

    # No epoch was recorded, so there is no last line to report.
    if not lines:
        return None, [], {}

    eval_results = lines[-1]

    return logs, lines, eval_results
# 解析 `log_history` 参数,获取 `Trainer` 的中间和最终评估结果
def parse_log_history(log_history):
    """
    Parse a list of log dicts to extract the intermediate and final evaluation results.

    Args:
        log_history: List of dicts. The entry containing `"train_runtime"` is treated as the training summary,
            entries containing `"loss"` as training logs and entries containing `"eval_loss"` as evaluation logs.

    Returns:
        A tuple `(train_log, lines, eval_results)`:

        - If no entry contains `"train_runtime"`: `(None, None, last_eval_entry)` where `last_eval_entry` is the
          last entry containing `"eval_loss"` (returned unmodified), or `(None, None, None)` if there is none.
        - Otherwise `train_log` is the training summary entry, `lines` holds one row per evaluation entry found
          before the summary, and `eval_results` holds the prettified metrics of the last evaluation entry in the
          whole history (`None` if there is no evaluation entry at all).
    """
    idx = 0
    while idx < len(log_history) and "train_runtime" not in log_history[idx]:
        idx += 1

    # If there are no training logs, fall back to the last evaluation entry.
    if idx == len(log_history):
        idx -= 1
        while idx >= 0 and "eval_loss" not in log_history[idx]:
            idx -= 1

        if idx >= 0:
            return None, None, log_history[idx]
        return None, None, None

    # From now on we can assume we have training logs.
    train_log = log_history[idx]
    lines = []
    training_loss = "No log"
    for i in range(idx):
        entry = log_history[i]
        if "loss" in entry:
            # Remember the most recent training loss so it can be reported next to the following evaluation.
            training_loss = entry["loss"]
        if "eval_loss" in entry:
            metrics = entry.copy()
            # Bookkeeping values that should not show up as table columns.
            for unused_key in (
                "total_flos",
                "eval_runtime",
                "eval_samples_per_second",
                "eval_steps_per_second",
                "eval_jit_compilation_time",
            ):
                metrics.pop(unused_key, None)
            epoch = metrics.pop("epoch", None)
            step = metrics.pop("step", None)

            values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step}
            for k, v in metrics.items():
                if k == "eval_loss":
                    values["Validation Loss"] = v
                else:
                    # Drop the first "_"-separated part (the "eval" prefix) and title-case the rest.
                    splits = k.split("_")
                    name = " ".join([part.capitalize() for part in splits[1:]])
                    values[name] = v
            lines.append(values)

    # Look for the last evaluation entry anywhere in the history.
    idx = len(log_history) - 1
    while idx >= 0 and "eval_loss" not in log_history[idx]:
        idx -= 1

    # `idx == 0` is a valid hit (a single evaluation logged before the training summary); only `-1` means
    # that no evaluation entry exists.
    if idx >= 0:
        eval_results = {}
        for key, value in log_history[idx].items():
            if key.startswith("eval_"):
                key = key[5:]
            if key not in ["runtime", "samples_per_second", "steps_per_second", "epoch", "step"]:
                camel_cased_key = " ".join([part.capitalize() for part in key.split("_")])
                eval_results[camel_cased_key] = value
        return train_log, lines, eval_results
    else:
        return train_log, lines, None
def extract_hyperparameters_from_keras(model):
    """
    Collect the hyperparameters of a Keras model.

    Returns a dict with two entries: `"optimizer"` (the result of `model.optimizer.get_config()`, or `None` when
    the model has no optimizer) and `"training_precision"` (the name of the global Keras mixed-precision policy).
    """
    # Imported lazily, inside the function, as in the original code.
    from .modeling_tf_utils import keras

    optimizer = getattr(model, "optimizer", None)
    optimizer_config = optimizer.get_config() if optimizer is not None else None

    return {
        "optimizer": optimizer_config,
        "training_precision": keras.mixed_precision.global_policy().name,
    }
def _maybe_round(v, decimals=4):
# 如果 v 是浮点数且有小数部分超过指定的小数位数,则返回按小数位数四舍五入后的字符串
if isinstance(v, float) and len(str(v).split(".")) > 1 and len(str(v).split(".")[1]) > decimals:
return f"{v:.{decimals}f}"
# 否则返回 v 的字符串形式
return str(v)
def _regular_table_line(values, col_widths):
# 生成 Markdown 表格的一行,包括表格的普通行格式
values_with_space = [f"| {v}" + " " * (w - len(v) + 1) for v, w in zip(values, col_widths)]
return "".join(values_with_space) + "|\n"
def _second_table_line(col_widths):
# 生成 Markdown 表格的第二行,包括表头和数据之间的分隔线格式
values = ["|:" + "-" * w + ":" for w in col_widths]
return "".join(values) + "|\n"
def make_markdown_table(lines):
    """
    Create a nice Markdown table from the results in `lines`.

    Args:
        lines: List of dicts mapping column names (strings) to values, one dict per table row. `None` or an
            empty list yields an empty string.

    Returns:
        The Markdown table as a single string (header row, separator row, then one row per element of `lines`).
        Columns are the union of the keys of all rows, in order of first appearance; a row that lacks a column
        gets an empty cell, so rows with differing metrics no longer raise `KeyError` or shift cells.
    """
    if lines is None or len(lines) == 0:
        return ""

    # Column widths, keyed by column name. Dict insertion order gives the column order.
    col_widths = {}
    for line in lines:
        for key, value in line.items():
            current = col_widths.get(key, len(str(key)))
            col_widths[key] = max(current, len(_maybe_round(value)))

    headers = list(col_widths.keys())
    widths = list(col_widths.values())

    table = _regular_table_line(headers, widths)
    table += _second_table_line(widths)
    for line in lines:
        # Look cells up by header so every value lands under its own column.
        cells = [_maybe_round(line[key]) if key in line else "" for key in headers]
        table += _regular_table_line(cells, widths)
    return table
_TRAINING_ARGS_KEYS = [
"learning_rate",
"train_batch_size",
"eval_batch_size",
"seed",
]
def extract_hyperparameters_from_trainer(trainer):
    """
    Build a dict of the hyperparameters worth reporting for a training run.

    The values are read from `trainer.args` (plus `trainer.use_apex` when `fp16` is enabled). Entries that only
    make sense in some configurations (distributed type, number of devices, gradient accumulation, total batch
    sizes, warmup, mixed precision, label smoothing) are added only when they differ from the neutral setting.
    """
    args = trainer.args
    hyperparameters = {name: getattr(args, name) for name in _TRAINING_ARGS_KEYS}

    # Distributed setup.
    mode = args.parallel_mode
    if mode not in [ParallelMode.NOT_PARALLEL, ParallelMode.NOT_DISTRIBUTED]:
        hyperparameters["distributed_type"] = "multi-GPU" if mode == ParallelMode.DISTRIBUTED else mode.value
    if args.world_size > 1:
        hyperparameters["num_devices"] = args.world_size
    if args.gradient_accumulation_steps > 1:
        hyperparameters["gradient_accumulation_steps"] = args.gradient_accumulation_steps

    # Effective batch sizes, reported only when they differ from the configured ones.
    total_train_batch_size = args.train_batch_size * args.world_size * args.gradient_accumulation_steps
    if total_train_batch_size != hyperparameters["train_batch_size"]:
        hyperparameters["total_train_batch_size"] = total_train_batch_size
    total_eval_batch_size = args.eval_batch_size * args.world_size
    if total_eval_batch_size != hyperparameters["eval_batch_size"]:
        hyperparameters["total_eval_batch_size"] = total_eval_batch_size

    # Optimizer and learning-rate schedule.
    hyperparameters["optimizer"] = (
        "Adafactor"
        if args.adafactor
        else f"Adam with betas=({args.adam_beta1},{args.adam_beta2}) and epsilon={args.adam_epsilon}"
    )
    hyperparameters["lr_scheduler_type"] = args.lr_scheduler_type.value
    if args.warmup_ratio != 0.0:
        hyperparameters["lr_scheduler_warmup_ratio"] = args.warmup_ratio
    if args.warmup_steps != 0.0:
        hyperparameters["lr_scheduler_warmup_steps"] = args.warmup_steps

    # Training length: a `max_steps` other than -1 is reported instead of the number of epochs.
    if args.max_steps == -1:
        hyperparameters["num_epochs"] = args.num_train_epochs
    else:
        hyperparameters["training_steps"] = args.max_steps

    # Mixed precision.
    if args.fp16:
        hyperparameters["mixed_precision_training"] = (
            f"Apex, opt level {args.fp16_opt_level}" if trainer.use_apex else "Native AMP"
        )

    if args.label_smoothing_factor != 0.0:
        hyperparameters["label_smoothing_factor"] = args.label_smoothing_factor

    return hyperparameters
.\modeling_attn_mask_utils.py
# 引入必要的库
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
# 导入PyTorch库
import torch
# 定义一个数据类,用于处理注意力掩码转换的实用工具
@dataclass
class AttentionMaskConverter:
    """
    A utility attention mask class that allows one to:
        - Create a causal 4d mask
        - Create a causal 4d mask with sliding window
        - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
          key_value_length) that can be multiplied with attention scores

    Examples:

    ```
    >>> import torch
    >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter

    >>> converter = AttentionMaskConverter(True)
    >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
    tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
            [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
            [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
            [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00, -3.4028e+38],
            [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00,  0.0000e+00]]]])
    ```

    Parameters:
        is_causal (`bool`):
            Whether the attention mask should be a uni-directional (causal) or bi-directional mask.

        sliding_window (`int`, *optional*):
            Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
    """

    is_causal: bool
    # `None` means "no sliding window"; see `__init__`.
    sliding_window: Optional[int]

    def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
        """Store the configuration and reject a non-positive `sliding_window` with a `ValueError`."""
        self.is_causal = is_causal
        self.sliding_window = sliding_window

        if self.sliding_window is not None and self.sliding_window <= 0:
            raise ValueError(
                f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
            )

    def to_causal_4d(
        self,
        batch_size: int,
        query_length: int,
        key_value_length: int,
        dtype: torch.dtype,
        device: Union[torch.device, "str"] = "cpu",
    ) -> Optional[torch.Tensor]:
        """
        Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
        bias to upper right hand triangular matrix (causal mask).

        Returns `None` when `query_length` is 1 and no sliding window is configured, since no mask is built in
        that case. Raises `ValueError` if the converter is not causal.
        """
        if not self.is_causal:
            raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")

        input_shape = (batch_size, query_length)
        past_key_values_length = key_value_length - query_length

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        causal_4d_mask = None
        if input_shape[-1] > 1 or self.sliding_window is not None:
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                device=device,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )

        return causal_4d_mask

    def to_4d(
        self,
        attention_mask_2d: torch.Tensor,
        query_length: int,
        dtype: torch.dtype,
        key_value_length: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
        key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
        causal, a causal mask will be added.

        Raises `ValueError` when a causal mask is needed but `key_value_length` is not given, and
        `NotImplementedError` when a sliding window is configured without a causal mask being built.
        """
        input_shape = (attention_mask_2d.shape[0], query_length)

        # Create the causal mask.
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        causal_4d_mask = None
        if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
            if key_value_length is None:
                raise ValueError(
                    "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
                )

            past_key_values_length = key_value_length - query_length
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                device=attention_mask_2d.device,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )
        elif self.sliding_window is not None:
            raise NotImplementedError("Sliding window is currently only implemented for causal masking")

        # Expand the padding mask.
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
            attention_mask_2d.device
        )

        if causal_4d_mask is not None:
            # Combine by filling rather than adding the two masks (two `finfo.min` values would overflow).
            expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)

        return expanded_attn_mask

    @staticmethod
    def _make_causal_mask(
        input_ids_shape: torch.Size,
        dtype: torch.dtype,
        device: torch.device,
        past_key_values_length: int = 0,
        sliding_window: Optional[int] = None,
    ):
        """
        Make the causal mask of shape `[bsz, 1, tgt_len, tgt_len + past_key_values_length]`: 0 at positions that
        may be attended to and `torch.finfo(dtype).min` elsewhere.
        """
        bsz, tgt_len = input_ids_shape
        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
        mask_cond = torch.arange(mask.size(-1), device=device)
        # Zero out the lower triangle, diagonal included: column j <= row i.
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

        mask = mask.to(dtype)

        if past_key_values_length > 0:
            # Past positions are prepended as columns of zeros.
            mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)

        # Add the lower triangular sliding window mask if necessary.
        if sliding_window is not None:
            diagonal = past_key_values_length - sliding_window - 1

            context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
            mask.masked_fill_(context_mask, torch.finfo(dtype).min)

        return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

    @staticmethod
    def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
        """
        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.

        `tgt_len` defaults to the source length. Non-zero entries of `1 - mask` are replaced by
        `torch.finfo(dtype).min`.
        """
        bsz, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len

        expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

        inverted_mask = 1.0 - expanded_mask

        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

    @staticmethod
    def _unmask_unattended(
        expanded_mask: torch.FloatTensor,
        min_dtype: float,
        device: Optional[torch.device] = None,
    ):
        # fmt: off
        """
        Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
        using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        Details: https://github.com/pytorch/pytorch/issues/110213

        `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
        `min_dtype` is the value that marks a masked position in `expanded_mask`.
        `device` is accepted but not used by this implementation.

        The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.

        Rows whose entries all equal `min_dtype` are multiplied by 0; every other row is left unchanged.

        For example (shown with 1 = attended, 0 = masked), if `expanded_mask` is (e.g. here left-padding case)
        ```
        [[[[0, 0, 0],
           [0, 0, 0],
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[0, 0, 0],
           [0, 1, 0],
           [0, 1, 1]]]]
        ```
        then the modified `expanded_mask` will be
        ```
        [[[[1, 1, 1],   <-- modified
           [1, 1, 1],   <-- modified
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[1, 1, 1],   <-- modified
           [0, 1, 0],
           [0, 1, 1]]]]
        ```

        Raises `ValueError` if `expanded_mask` is a boolean tensor.
        """
        # fmt: on
        if expanded_mask.dtype == torch.bool:
            raise ValueError(
                "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
            )

        return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
# 创建一个用于 SDPA(scaled_dot_product_attention)的 4D 因果注意力掩码,形状为 `(batch_size, 1, query_length, key_value_length)`
# 从形状为 `(batch_size, key_value_length)` 的 2D 注意力掩码创建
def _prepare_4d_causal_attention_mask(
attention_mask: Optional[torch.Tensor],
input_shape: Union[torch.Size, Tuple, List],
inputs_embeds: torch.Tensor,
past_key_values_length: int,
sliding_window: Optional[int] = None,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`
Args:
attention_mask (`torch.Tensor` or `None`):
A 2D attention mask of shape `(batch_size, key_value_length)`
input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
The input shape should be a tuple that defines `(batch_size, query_length)`.
inputs_embeds (`torch.Tensor`):
The embedded inputs as a torch Tensor.
past_key_values_length (`int`):
The length of the key value cache.
sliding_window (`int`, *optional*):
If the model uses windowed attention, a sliding window should be passed.
"""
# 创建一个 AttentionMaskConverter 对象,用于生成因果关系的注意力掩码
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
# 计算 key_value_length,这里是输入形状的最后一个维度加上过去的键值长度
key_value_length = input_shape[-1] + past_key_values_length
# 如果输入的 attention_mask 不为空且是 2D 的
if attention_mask is not None and len(attention_mask.shape) == 2:
# 将 2D 的 attention_mask 转换成 4D 的
attention_mask = attn_mask_converter.to_4d(
attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
)
# 如果 attention_mask 不为空且是 4D 的
elif attention_mask is not None and len(attention_mask.shape) == 4:
# 检查 attention_mask 的形状是否符合预期
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
if tuple(attention_mask.shape) != expected_shape:
# 如果不符合预期,抛出 ValueError
raise ValueError(
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
)
else:
# 如果 4D 的 attention_mask 形状正确,则反转它并用负无穷填充
inverted_mask = 1.0 - attention_mask
attention_mask = inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
)
else:
# 如果 attention_mask 为空或者不是 2D 或 4D 的,则使用 AttentionMaskConverter 生成因果 4D 注意力掩码
attention_mask = attn_mask_converter.to_causal_4d(
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
# 返回生成的 attention_mask
return attention_mask
# Adapted from _prepare_4d_causal_attention_mask
# 根据 _prepare_4d_causal_attention_mask 适配
def _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask: Optional[torch.Tensor],
input_shape: Union[torch.Size, Tuple, List],
inputs_embeds: torch.Tensor,
past_key_values_length: int,
sliding_window: Optional[int] = None,
):
"""
Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
"""
# 创建一个注意力掩码转换器对象,设定为因果关系模式,并指定是否使用滑动窗口
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
# 计算键值对的长度,包括输入形状的最后一个维度和过去键值对的长度
key_value_length = input_shape[-1] + past_key_values_length
# 获取输入形状中的批处理大小和查询长度
batch_size, query_length = input_shape
# 检查是否处于追踪状态,如果是,则需要使用SDPA的`attn_mask`参数而不是自定义的`attention_mask`
is_tracing = (
torch.jit.is_tracing()
or isinstance(inputs_embeds, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
# 如果传入了注意力掩码
if attention_mask is not None:
# 如果注意力掩码是4维的
if len(attention_mask.shape) == 4:
# 验证注意力掩码的形状是否符合预期
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
if tuple(attention_mask.shape) != expected_shape:
# 抛出值错误,指出注意力掩码的形状不正确
raise ValueError(
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
)
else:
# 如果4维掩码形状正确,反转掩码并用负无穷填充
inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
attention_mask = inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
)
return attention_mask
# 如果不处于追踪状态,并且所有的注意力掩码值都是1
elif not is_tracing and torch.all(attention_mask == 1):
if query_length == 1:
# 当查询长度为1时,因果关注和双向关注是相同的
attention_mask = None
elif key_value_length == query_length:
# 当键值对的长度等于查询长度时,不需要注意力掩码
attention_mask = None
else:
# 对于查询长度大于1且键值长度不等于查询长度的情况,无法忽略注意力掩码,需要特别处理
pass
# 如果没有传入注意力掩码,并且查询长度大于1且键值长度不等于查询长度
elif query_length > 1 and key_value_length != query_length:
# 将注意力掩码设为True,以便在后续控制流中转到`to_causal_4d`
attention_mask = True
# 如果正在进行跟踪(tracing),且未提供注意力掩码(attention_mask),则抛出值错误异常
elif is_tracing:
raise ValueError(
'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
)
# 如果没有提供 attention_mask,则将 expanded_4d_mask 设置为 None
if attention_mask is None:
expanded_4d_mask = None
# 如果 attention_mask 设置为 True,则通过 attn_mask_converter.to_causal_4d 函数生成扩展后的 4D 注意力掩码
elif attention_mask is True:
expanded_4d_mask = attn_mask_converter.to_causal_4d(
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
# 否则,根据给定的 attention_mask 使用 attn_mask_converter.to_4d 函数生成扩展后的 4D 注意力掩码
else:
expanded_4d_mask = attn_mask_converter.to_4d(
attention_mask,
input_shape[-1],
dtype=inputs_embeds.dtype,
key_value_length=key_value_length,
)
# 如果不是在跟踪模式下,并且 expanded_4d_mask 存在且在 CUDA 设备上,
# 则调用 AttentionMaskConverter._unmask_unattended 函数,处理未注意的部分
if not is_tracing and expanded_4d_mask.device.type == "cuda":
expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
)
# 返回生成或处理后的 expanded_4d_mask 注意力掩码
return expanded_4d_mask
# 创建一个非因果关系的四维注意力掩码,其形状为 `(batch_size, 1, query_length, key_value_length)`,从形状为 `(batch_size, key_value_length)` 的二维掩码创建
def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`
Args:
mask (`torch.Tensor` or `None`):
A 2D attention mask of shape `(batch_size, key_value_length)`
dtype (`torch.dtype`):
The torch dtype the created mask shall have.
tgt_len (`int`):
The target length or query length the created mask shall have.
"""
# 调用内部函数 `_expand_mask` 扩展掩码至四维并返回
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
# 为 SDPA(Scaled Dot-Product Attention)创建非因果四维注意力掩码,形状为 `(batch_size, 1, query_length, key_value_length)`,从形状为 `(batch_size, key_value_length)` 的二维掩码创建
def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`
Args:
mask (`torch.Tensor` or `None`):
A 2D attention mask of shape `(batch_size, key_value_length)`
dtype (`torch.dtype`):
The torch dtype the created mask shall have.
tgt_len (`int`):
The target length or query length the created mask shall have.
"""
# 获取 batch_size 和 key_value_length 的尺寸
batch_size, key_value_length = mask.shape
# 如果未提供 tgt_len,则默认为 key_value_length
tgt_len = tgt_len if tgt_len is not None else key_value_length
# 检查是否处于追踪模式
is_tracing = (
torch.jit.is_tracing()
or isinstance(mask, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
# 如果掩码中所有元素均为 1
if torch.all(mask == 1):
if is_tracing:
pass # 如果处于追踪模式,不做任何操作
elif tgt_len == 1:
# 对于 query_length == 1,因果和双向注意力相同,返回 None
return None
elif key_value_length == tgt_len:
# 如果 key_value_length 等于 tgt_len,返回 None
return None
else:
# 对于 query_length > 1 且 key_value_length != query_length 的情况,
# 我们不能忽略注意力掩码,因为 SDPA 因果掩码的生成可能会出错,
# 我们在 SDPA 中将 is_causal=False,并依赖于 Transformers 的 attention_mask,因此在这里不设置为 None。
# 参考: https://github.com/pytorch/pytorch/issues/108108
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
else:
# 对于其他情况,调用 `_expand_mask` 扩展掩码并返回
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
# Parameter notes for the causal-mask helper defined below:
#   device: torch.device
#       Device on which the tensor operations run (for example CPU or GPU).
#   past_key_values_length: int = 0
#       Length of the past (cached) keys/values; defaults to 0.
#   sliding_window: Optional[int] = None
#       Size of the sliding window, if any; defaults to None.
# Create a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`.
def create_causal_mask(
    input_shape: Union[Tuple[int], List[int], torch.Size],
    dtype: torch.dtype,
    device: torch.device,
    sliding_window: Optional[int] = None,
    past_key_values_length: int = 0,
) -> Optional[torch.Tensor]:
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`.

    Args:
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
            The input shape should be a tuple that defines `(batch_size, query_length)`.
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
        device (`torch.device`):
            The torch device the created mask shall have.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.
        past_key_values_length (`int`, *optional*, defaults to 0):
            The length of the key value cache. The body previously read this name without it being defined
            anywhere in the function, which raised `NameError`; it is appended as the last parameter so
            existing positional calls keep their meaning.

    Returns:
        The result of `AttentionMaskConverter.to_causal_4d` for these arguments.
    """
    attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    # Keys/values cover the cached positions plus the current query positions.
    key_value_length = past_key_values_length + input_shape[-1]
    attention_mask = attn_mask_converter.to_causal_4d(
        input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
    )

    return attention_mask
.\modeling_flax_outputs.py
# 导入必要的模块和类
from typing import Dict, Optional, Tuple # 导入类型提示相关模块
import flax # 导入Flax库,用于结构化数据类
import jax.numpy as jnp # 导入JAX的NumPy接口
from .utils import ModelOutput # 从当前目录下的utils模块导入ModelOutput类
@flax.struct.dataclass
class FlaxBaseModelOutput(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` of shape `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    last_hidden_state: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxBaseModelOutputWithNoAttention(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states but without attentions.

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` of shape `(batch_size, num_channels, height, width)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    last_hidden_state: jnp.ndarray = None
    # Optional tuple of `jnp.ndarray`; defaults to None.
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
# 使用 @flax.struct.dataclass 装饰器声明一个数据类,表示带有池化和无注意力机制的模型输出
@flax.struct.dataclass
class FlaxBaseModelOutputWithPoolingAndNoAttention(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state after a pooling operation on the spatial dimensions.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
            for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the
            model at the output of each layer plus the optional initial embedding outputs.
    """

    # Output fields; all default to None.
    last_hidden_state: jnp.ndarray = None
    pooler_output: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
# 使用 @flax.struct.dataclass 装饰器声明一个数据类,表示不带注意力机制的图像分类模型输出
@flax.struct.dataclass
class FlaxImageClassifierOutputWithNoAttention(ModelOutput):
    """
    Base class for outputs of image classification models.

    Args:
        logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when
        `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
            for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
            called feature maps) of the model at the output of each stage.
    """

    # Output fields; all default to None.
    logits: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
# 使用 @flax.struct.dataclass 装饰器声明一个数据类,表示带有过去状态的模型输出
@flax.struct.dataclass
class FlaxBaseModelOutputWithPast(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        past_key_values (`Dict[str, jnp.ndarray]`):
            Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for
            fast auto-regressive decoding. Pre-computed key and value hidden-states are of shape
            *[batch_size, max_length]*.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
    """

    # Hidden-state sequence at the output of the last layer.
    last_hidden_state: jnp.ndarray = None
    # Pre-computed key/value hidden states, keyed by name.
    past_key_values: Optional[Dict[str, jnp.ndarray]] = None
    # Per-layer hidden states; only returned when hidden-state output is requested.
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # Per-layer attention weights; only returned when attention output is requested.
    attentions: Optional[Tuple[jnp.ndarray]] = None
# 使用 `flax.struct.dataclass` 装饰器定义一个数据类,该类继承自 `ModelOutput` 类,用于表示模型输出并包含最后隐藏状态的池化结果。
@flax.struct.dataclass
class FlaxBaseModelOutputWithPooling(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token) further processed by a
            Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
            prediction (classification) objective during pretraining.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Output fields: last hidden state, pooled output, hidden states and attention weights; all default to None.
    last_hidden_state: jnp.ndarray = None
    pooler_output: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
# Flax dataclass (subclass of `ModelOutput`): pooled base output extended with cached key/values and
# cross-attention weights.
@flax.struct.dataclass
class FlaxBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token) after further processing
            through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
            the classification token after processing through a linear layer and a tanh activation function. The linear
            layer weights are trained from the next sentence prediction (classification) objective during pretraining.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
            for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
    """

    # Declaration order defines the dataclass field order (positional construction), so it is kept as-is:
    # `past_key_values` sits between `hidden_states` and `attentions`.
    last_hidden_state: jnp.ndarray = None
    pooler_output: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
    cross_attentions: Optional[Tuple[jnp.ndarray]] = None
# Flax dataclass (subclass of `ModelOutput`) for outputs that may carry cached key/values and cross-attentions.
@flax.struct.dataclass
class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.

            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
        past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
    """

    # Every field has a default of None.
    last_hidden_state: jnp.ndarray = None
    past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
    cross_attentions: Optional[Tuple[jnp.ndarray]] = None
# Flax dataclass (subclass of `ModelOutput`) describing the output of a sequence-to-sequence model.
@flax.struct.dataclass
class FlaxSeq2SeqModelOutput(ModelOutput):
    """
    Base class for model encoder's outputs that also contains pre-computed hidden states that can speed up sequential decoding.
    """

    # Last hidden state; a `jnp.ndarray`, defaults to None.
    last_hidden_state: jnp.ndarray = None
    # Cached key/values; optional tuple of tuples of `jnp.ndarray`, defaults to None.
    past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
    # Decoder hidden states; optional tuple of `jnp.ndarray`, defaults to None.
    decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # Decoder attention weights; optional tuple of `jnp.ndarray`, defaults to None.
    decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
    # Cross-attention weights; optional tuple of `jnp.ndarray`, defaults to None.
    cross_attentions: Optional[Tuple[jnp.ndarray]] = None
    # Last hidden state of the encoder; optional `jnp.ndarray`, defaults to None.
    encoder_last_hidden_state: Optional[jnp.ndarray] = None
    # Encoder hidden states; optional tuple of `jnp.ndarray`, defaults to None.
    encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # Encoder attention weights; optional tuple of `jnp.ndarray`, defaults to None.
    encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
# Flax dataclass (subclass of `ModelOutput`) for causal language model outputs that include cross-attentions.
@flax.struct.dataclass
class FlaxCausalLMOutputWithCrossAttentions(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.

    Args:
        logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head.
        past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*):
            One tuple of `jnp.ndarray` per layer (length `config.n_layers`); only meaningful when caching is used.
        hidden_states (`tuple(jnp.ndarray)`, *optional*):
            Tuple of `jnp.ndarray` (embedding output + one per layer) of shape
            `(batch_size, sequence_length, hidden_size)`.
        attentions (`tuple(jnp.ndarray)`, *optional*):
            Tuple of `jnp.ndarray` (one per layer) of shape
            `(batch_size, num_heads, sequence_length, sequence_length)`.
        cross_attentions (`tuple(jnp.ndarray)`, *optional*):
            Tuple of `jnp.ndarray` (one per layer) of shape
            `(batch_size, num_heads, sequence_length, sequence_length)`.
    """

    # Each field is declared exactly once. A dataclass takes its field order from the first time a name is
    # annotated in the class body, so a duplicated declaration block would silently move `past_key_values`
    # to the end and change positional/tuple ordering.
    logits: jnp.ndarray = None
    past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
    cross_attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxMaskedLMOutput(ModelOutput):
    """
    Base class for masked language model outputs.

    Args:
        logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` of shape `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    logits: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None


# `FlaxCausalLMOutput` is a plain alias: it is the very same class object as `FlaxMaskedLMOutput`.
FlaxCausalLMOutput = FlaxMaskedLMOutput
@flax.struct.dataclass
class FlaxSeq2SeqLMOutput(ModelOutput):
    """
    Base class for sequence-to-sequence language model outputs.
    """

    # Every field defaults to None. The decoder_* / encoder_* / cross_* prefixes indicate which part of the
    # encoder-decoder stack a value comes from (inferred from the names — confirm against the model code).
    logits: jnp.ndarray = None
    past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
    decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
    decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
    cross_attentions: Optional[Tuple[jnp.ndarray]] = None
    encoder_last_hidden_state: Optional[jnp.ndarray] = None
    encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
    encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxNextSentencePredictorOutput(ModelOutput):
    """
    Base class for outputs of models predicting if two sentences are consecutive or not.

    Args:
        logits (`jnp.ndarray` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # The whole description lives in the single class docstring above; the fields all default to None.
    logits: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
# Flax dataclass (subclass of `ModelOutput`) for the output of sentence classification models.
# The decorator is required: without it the annotated names are not dataclass fields.
@flax.struct.dataclass
class FlaxSequenceClassifierOutput(ModelOutput):
    """
    Base class for outputs of sentence classification models.

    Args:
        logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Classification (or regression) scores before SoftMax.
    logits: jnp.ndarray = None
    # Per-layer hidden states; only populated when hidden states are requested.
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # Per-layer self-attention weights; only populated when attentions are requested.
    attentions: Optional[Tuple[jnp.ndarray]] = None
# Flax dataclass (subclass of `ModelOutput`) for the output of sequence-to-sequence sentence classifiers.
# The decorator is required: without it the annotated names are not dataclass fields.
@flax.struct.dataclass
class FlaxSeq2SeqSequenceClassifierOutput(ModelOutput):
    """
    Base class for outputs of sequence-to-sequence sentence classification models.
    """

    # Classification scores; presumably of shape `(batch_size, config.num_labels)` — TODO confirm.
    logits: jnp.ndarray = None
    # Cached key/value tensors as a tuple of tuples of arrays.
    # NOTE(review): exact tensor shapes and the flag that enables this cache are not visible here — confirm.
    past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
    # Decoder hidden states, one array per entry of the tuple.
    decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # Decoder attention weights, one array per entry of the tuple.
    decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
    # Attention weights between decoder and encoder (cross-attention).
    cross_attentions: Optional[Tuple[jnp.ndarray]] = None
    # Hidden state at the output of the encoder's last layer.
    encoder_last_hidden_state: Optional[jnp.ndarray] = None
    # Encoder hidden states, one array per entry of the tuple.
    encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # Encoder attention weights, one array per entry of the tuple.
    encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
# Flax dataclass (subclass of `ModelOutput`) for the output of multiple choice models.
# The decorator is required: without it the annotated names are not dataclass fields.
@flax.struct.dataclass
class FlaxMultipleChoiceModelOutput(ModelOutput):
    """
    Base class for outputs of multiple choice models.

    Args:
        logits (`jnp.ndarray` of shape `(batch_size, num_choices)`):
            Classification scores (before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple with one `jnp.ndarray` per layer output, each of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple with one `jnp.ndarray` of attention weights per layer, each of shape
            `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    logits: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
# Flax dataclass (subclass of `ModelOutput`) for the output of token classification models.
@flax.struct.dataclass
class FlaxTokenClassifierOutput(ModelOutput):
    """
    Base class for outputs of token classification models.

    Args:
        logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` of shape `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    logits: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
# Flax dataclass (subclass of `ModelOutput`) for the output of question answering models.
@flax.struct.dataclass
class FlaxQuestionAnsweringModelOutput(ModelOutput):
    """
    Base class for outputs of question answering models.

    Args:
        start_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            Span-start scores (before SoftMax).
        end_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            Span-end scores (before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` of shape `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    start_logits: jnp.ndarray = None
    end_logits: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
# Flax dataclass (subclass of `ModelOutput`) for the output of sequence-to-sequence question answering models.
@flax.struct.dataclass
class FlaxSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
    """
    Base class for outputs of sequence-to-sequence question answering models.
    """

    # Scores for the start and the end of the answer span (meaning inferred from the names — TODO confirm).
    start_logits: jnp.ndarray = None
    end_logits: jnp.ndarray = None
    # Cached key/value tensors, stored as a tuple of tuples of arrays.
    past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
    # Hidden states of the decoder.
    decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # Attention weights of the decoder.
    decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
    # Cross-attention weights.
    cross_attentions: Optional[Tuple[jnp.ndarray]] = None
    # Last hidden state of the encoder.
    encoder_last_hidden_state: Optional[jnp.ndarray] = None
    # Hidden states of the encoder.
    encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # Attention weights of the encoder.
    encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
.\modeling_flax_pytorch_utils.py
# 设置编码为 UTF-8,确保文件能正确处理非 ASCII 字符
# 版权声明及许可协议,指定本文件的版权归属及使用许可
# 注意事项:根据 Apache 许可证 2.0 版本,除非符合许可协议,否则禁止使用本文件
# 获取许可协议的详细信息,请访问指定的 URL
# 如果适用法律要求或书面同意,软件将按“原样”分发,不提供任何明示或暗示的担保或条件
# 查看许可协议以了解具体语言和限制条件
""" PyTorch - Flax general utilities."""
# 导入所需的标准库和模块
import os # 导入操作系统功能
from pickle import UnpicklingError # 导入反序列化错误异常
from typing import Dict, Tuple # 导入类型提示
import jax # 导入 JAX 库
import jax.numpy as jnp # 导入 JAX 的 NumPy 接口
import numpy as np # 导入 NumPy 库
from flax.serialization import from_bytes # 从字节流中反序列化对象
from flax.traverse_util import flatten_dict, unflatten_dict # 对字典进行扁平化和展开操作
import transformers # 导入 Transformers 库
from . import is_safetensors_available, is_torch_available # 导入本地模块
from .utils import logging # 从本地工具模块中导入日志功能
if is_torch_available(): # 如果 Torch 可用
import torch # 导入 Torch 库
if is_safetensors_available(): # 如果 SafeTensors 可用
from safetensors import safe_open # 导入安全打开文件函数
from safetensors.flax import load_file as safe_load_file # 导入安全加载文件函数
logger = logging.get_logger(__name__) # 获取当前模块的日志记录器
#####################
# PyTorch => Flax #
#####################
# Entry point that turns a PyTorch checkpoint (single file or shard list) into a Flax state dict.
def load_pytorch_checkpoint_in_flax_state_dict(
    flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
):
    """Load pytorch checkpoints in a flax model

    Args:
        flax_model: Flax model the converted weights are matched against.
        pytorch_checkpoint_path: Path of a single checkpoint file, or (when `is_sharded`) whatever
            `convert_pytorch_sharded_state_dict_to_flax` expects as its shard argument.
        is_sharded: Whether the checkpoint is split into shards.
        allow_missing_keys: Accepted for interface compatibility; not read by this function.

    Returns:
        The state dict produced by the matching `convert_pytorch_*_state_dict_to_flax` helper.
    """
    # Sharded checkpoints are delegated wholesale to the dedicated converter.
    if is_sharded:
        return convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model)

    pt_path = os.path.abspath(pytorch_checkpoint_path)
    logger.info(f"Loading PyTorch weights from {pt_path}")

    if not pt_path.endswith(".safetensors"):
        # Pickle-based checkpoint: PyTorch itself is needed to read it.
        try:
            import torch

            from .pytorch_utils import is_torch_greater_or_equal_than_1_13
        except (ImportError, ModuleNotFoundError):
            logger.error(
                "Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
                " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
                " instructions."
            )
            raise

        # `weights_only` is only passed when the installed torch is recent enough to accept it.
        extra_load_kwargs = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
        pt_state_dict = torch.load(pt_path, map_location="cpu", **extra_load_kwargs)
        logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
    else:
        # safetensors file: read every tensor through the Flax framework adapter.
        with safe_open(pt_path, framework="flax") as f:
            pt_state_dict = {tensor_name: f.get_tensor(tensor_name) for tensor_name in f.keys()}

    return convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
# Map one PyTorch parameter name (as a tuple) to its Flax name, reshaping the tensor when the layouts differ.
def rename_key_and_reshape_tensor(
    pt_tuple_key: Tuple[str],
    pt_tensor: np.ndarray,
    random_flax_state_dict: Dict[str, jnp.ndarray],
    model_prefix: str,
) -> Tuple[Tuple[str], np.ndarray]:
    """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary

    Args:
        pt_tuple_key: PyTorch parameter name split into a tuple of path components.
        pt_tensor: The parameter value.
        random_flax_state_dict: Flattened Flax state dict (tuple keys) used to decide which renaming applies.
        model_prefix: Base-model prefix; a key counts as present if it exists with or without this prefix.

    Returns:
        `(flax_tuple_key, tensor)`; the tensor is transposed for conv (4-D) and linear kernels, otherwise
        returned untouched. Keys matching no rule are returned unchanged.
    """

    def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool:
        """Checks if `key` of `(prefix,) + key` is in random_flax_state_dict"""
        # Two direct dict lookups instead of materialising a set of every key on each call.
        return key in random_flax_state_dict or (model_prefix,) + key in random_flax_state_dict

    parent = pt_tuple_key[:-1]
    leaf = pt_tuple_key[-1]

    # layer norm
    if leaf in ["weight", "gamma"] and is_key_or_prefix_key_in_dict(parent + ("scale",)):
        return parent + ("scale",), pt_tensor

    # batch norm layer mean
    if leaf == "running_mean" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
        return parent + ("mean",), pt_tensor

    # batch norm layer var
    if leaf == "running_var" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
        return parent + ("var",), pt_tensor

    # embedding
    if leaf == "weight" and is_key_or_prefix_key_in_dict(parent + ("embedding",)):
        return parent + ("embedding",), pt_tensor

    # conv layer: reorder axes (0, 1, 2, 3) -> (2, 3, 1, 0)
    if leaf == "weight" and pt_tensor.ndim == 4 and not is_key_or_prefix_key_in_dict(pt_tuple_key):
        return parent + ("kernel",), pt_tensor.transpose(2, 3, 1, 0)

    # linear layer: kernel is the transpose of the PyTorch weight
    if leaf == "weight" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
        return parent + ("kernel",), pt_tensor.T

    # old PyTorch layer norm weight
    if leaf == "gamma":
        return parent + ("weight",), pt_tensor

    # old PyTorch layer norm bias
    if leaf == "beta":
        return parent + ("bias",), pt_tensor

    # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
    # `[-3::2]` picks the third-from-last and the last component, skipping the parameter name in between.
    name = None
    if pt_tuple_key[-3::2] == ("parametrizations", "original0"):
        name = pt_tuple_key[-2] + "_g"
    elif pt_tuple_key[-3::2] == ("parametrizations", "original1"):
        name = pt_tuple_key[-2] + "_v"
    if name is not None:
        return pt_tuple_key[:-3] + (name,), pt_tensor

    # Default: nothing to rename or reshape.
    return pt_tuple_key, pt_tensor
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
    """Convert a flat PyTorch state dict into a nested Flax state dict matching `flax_model`.

    Args:
        pt_state_dict: Mapping from dotted PyTorch parameter names to tensors/arrays. When the values are
            `torch.Tensor`s they are replaced in place by NumPy arrays (bfloat16 goes through float32).
        flax_model: Flax model providing `base_model_prefix` and `params`; its parameters are used to decide
            key renaming and to validate shapes.

    Returns:
        Nested dict built with `unflatten_dict` from the converted entries.

    Raises:
        ValueError: If a converted tensor's shape differs from the shape of the matching Flax parameter.
    """
    # `from_bin` is True when the checkpoint holds torch tensors (the `.numpy()` conversion below needs that).
    # NOTE(review): only the first value is inspected; an empty state dict is treated as "not from bin".
    first_value = next(iter(pt_state_dict.values()), None)
    from_bin = is_torch_available() and isinstance(first_value, torch.Tensor)
    bfloat16 = torch.bfloat16 if from_bin else "bfloat16"

    # Remember every entry's original dtype before any cast, so bfloat16 can be restored on the Flax side.
    weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}

    if from_bin:
        for k, v in pt_state_dict.items():
            # numpy currently does not support bfloat16, so go through float32 in that case
            if v.dtype == bfloat16:
                v = v.float()
            pt_state_dict[k] = v.numpy()

    model_prefix = flax_model.base_model_prefix

    # Models with batch norm keep their weights under a "params" sub-dict.
    if "params" in flax_model.params:
        flax_model_params = flax_model.params["params"]
    else:
        flax_model_params = flax_model.params
    random_flax_state_dict = flatten_dict(flax_model_params)

    # Batch statistics, when present, are flattened into the same lookup table.
    if "batch_stats" in flax_model.params:
        flax_batch_stats = flatten_dict(flax_model.params["batch_stats"])
        random_flax_state_dict.update(flax_batch_stats)

    flax_state_dict = {}

    # Decide whether the base-model prefix has to be stripped from, or added to, the PyTorch keys.
    pt_top_level_names = {k.split(".")[0] for k in pt_state_dict.keys()}
    load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and (
        model_prefix in pt_top_level_names
    )
    load_base_model_into_model_with_head = (model_prefix in flax_model_params) and (
        model_prefix not in pt_top_level_names
    )

    # Need to change some parameters name to match Flax names
    for pt_key, pt_tensor in pt_state_dict.items():
        pt_tuple_key = tuple(pt_key.split("."))
        is_bfloat_16 = weight_dtypes[pt_key] == bfloat16

        # remove base model prefix if necessary
        has_base_model_prefix = pt_tuple_key[0] == model_prefix
        if load_model_with_head_into_base_model and has_base_model_prefix:
            pt_tuple_key = pt_tuple_key[1:]

        # Correctly rename weight parameters
        flax_key, flax_tensor = rename_key_and_reshape_tensor(
            pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
        )

        # add model prefix if necessary
        require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
        if load_base_model_into_model_with_head and require_base_model_prefix:
            flax_key = (model_prefix,) + flax_key

        # Known key: the converted tensor must have exactly the shape the Flax model expects.
        if flax_key in random_flax_state_dict:
            if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
                )

        if "batch_stats" in flax_model.params:
            # Running mean/var go under "batch_stats" rather than "params".
            if "mean" in flax_key[-1] or "var" in flax_key[-1]:
                flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
                continue
            # remove num_batches_tracked key
            if "num_batches_tracked" in flax_key[-1]:
                flax_state_dict.pop(flax_key, None)
                continue

            # Everything else (including keys the model does not know) is stored under "params".
            flax_state_dict[("params",) + flax_key] = (
                jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
            )
        else:
            # No batch norm: store the weight at the top level (unknown keys included).
            flax_state_dict[flax_key] = (
                jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
            )

    return unflatten_dict(flax_state_dict)
############################
# Sharded Pytorch => Flax #
############################
# Convert a sharded PyTorch state dict to the Flax format.
def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
    """Convert a sharded PyTorch checkpoint into a nested Flax state dict.

    NOTE(review): in the code visible here neither `shard_filenames` nor `flax_model` is read, and
    `flax_state_dict` is never filled, so the function returns `unflatten_dict({})`. The per-shard loading
    loop appears to be missing — confirm against the upstream implementation before relying on this.
    """
    # Imported locally; an ImportError propagates to the caller if torch is unavailable.
    import torch

    from .pytorch_utils import is_torch_greater_or_equal_than_1_13

    # Load the index
    flax_state_dict = {}
    # Turn the flat (tuple-keyed) dict back into a nested dict and return it.
    return unflatten_dict(flax_state_dict)
#####################
# Flax => PyTorch #
#####################
# Read a Flax checkpoint from disk and push its weights into a PyTorch model.
def load_flax_checkpoint_in_pytorch_model(model, flax_checkpoint_path):
    """Load flax checkpoints in a PyTorch model

    Args:
        model: PyTorch model receiving the weights; its class name selects the `Flax<ClassName>` class.
        flax_checkpoint_path: Path to a `.safetensors` file or to a Flax-serialized (bytes) checkpoint.

    Returns:
        Whatever `load_flax_weights_in_pytorch_model` returns for `model` and the loaded state.

    Raises:
        EnvironmentError: If the non-safetensors file cannot be deserialized (`UnpicklingError`).
    """
    flax_checkpoint_path = os.path.abspath(flax_checkpoint_path)
    logger.info(f"Loading Flax weights from {flax_checkpoint_path}")

    # Resolve the Flax counterpart class; it serves as the deserialization target below.
    flax_cls = getattr(transformers, "Flax" + model.__class__.__name__)

    if not flax_checkpoint_path.endswith(".safetensors"):
        # Flax-serialized checkpoint.
        with open(flax_checkpoint_path, "rb") as state_f:
            try:
                flax_state_dict = from_bytes(flax_cls, state_f.read())
            except UnpicklingError:
                raise EnvironmentError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. ")
    else:
        # safetensors stores flat, dot-separated names; rebuild the nested structure.
        flat_state = safe_load_file(flax_checkpoint_path)
        flax_state_dict = unflatten_dict(flat_state, sep=".")

    return load_flax_weights_in_pytorch_model(model, flax_state_dict)
# Load Flax weights into a PyTorch model.
def load_flax_weights_in_pytorch_model(pt_model, flax_state):
    """Load flax checkpoints in a PyTorch model

    Args:
        pt_model: PyTorch model to load into; must expose `state_dict`, `load_state_dict` and
            `base_model_prefix`.
        flax_state: Nested Flax state (pytree of arrays).

    Returns:
        `pt_model`, after `load_state_dict` has been called on it.

    NOTE(review): in the code visible here `flax_state_dict` and the two `load_*` flags are computed but
    never used, `unexpected_keys` is never appended to, and `missing_keys` is never pruned; the model's own
    state dict is loaded back unchanged. The loop that copies Flax tensors into `pt_model_dict` appears to
    be missing — confirm against the upstream implementation.
    """
    # Fail early, with a helpful log message, when PyTorch is not importable.
    try:
        import torch  # noqa: F401
    except (ImportError, ModuleNotFoundError):
        logger.error(
            "Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see"
            " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
            " instructions."
        )
        raise

    # check if we have bf16 weights
    is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
    if any(is_type_bf16):
        # Warn, then cast every bfloat16 leaf to float32; other leaves are left untouched.
        logger.warning(
            "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
            "before loading those in PyTorch model."
        )
        flax_state = jax.tree_util.tree_map(
            lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
        )

    # Flatten the nested Flax state into a single-level dict.
    flax_state_dict = flatten_dict(flax_state)
    # Current PyTorch state dict of the target model.
    pt_model_dict = pt_model.state_dict()

    # True when the Flax state has the base-model prefix at top level but the PyTorch keys do not.
    load_model_with_head_into_base_model = (pt_model.base_model_prefix in flax_state) and (
        pt_model.base_model_prefix not in {k.split(".")[0] for k in pt_model_dict.keys()}
    )
    # True in the opposite situation: the PyTorch keys carry the prefix but the Flax state does not.
    load_base_model_into_model_with_head = (pt_model.base_model_prefix not in flax_state) and (
        pt_model.base_model_prefix in {k.split(".")[0] for k in pt_model_dict.keys()}
    )

    # Bookkeeping for unexpected and missing keys.
    unexpected_keys = []
    missing_keys = set(pt_model_dict.keys())

    # Load the state dict into the PyTorch model.
    pt_model.load_state_dict(pt_model_dict)

    # Convert the set of missing keys back to a list.
    missing_keys = list(missing_keys)

    # Report Flax weights that were not used.
    if len(unexpected_keys) > 0:
        logger.warning(
            "Some weights of the Flax model were not used when initializing the PyTorch model"
            f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
            f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
            " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
            f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
            " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
            " FlaxBertForSequenceClassification model)."
        )
    else:
        # Every Flax weight was consumed.
        logger.warning(f"All Flax model weights were used when initializing {pt_model.__class__.__name__}.\n")
    # Report PyTorch weights that did not receive a value from the Flax state.
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
            f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
            " use it for predictions and inference."
        )
    else:
        # Every PyTorch weight was initialized from the Flax model.
        logger.warning(
            f"All the weights of {pt_model.__class__.__name__} were initialized from the Flax model.\n"
            "If your task is similar to the task the model of the checkpoint was trained on, "
            f"you can already use {pt_model.__class__.__name__} for predictions without further training."
        )

    # Hand the model back to the caller.
    return pt_model
.\modeling_flax_utils.py
# coding=utf-8
# 代码文件声明使用 UTF-8 编码
# 导入标准库和第三方库
import gc # 垃圾回收模块
import json # JSON 数据格式处理模块
import os # 系统操作模块
import re # 正则表达式模块
import warnings # 警告模块
from functools import partial # 偏函数功能
from pickle import UnpicklingError # 反序列化错误异常
from typing import Any, Dict, Optional, Set, Tuple, Union # 类型提示模块
# 导入 Flax 和 JAX 库
import flax.linen as nn # Flax 的线性模块
import jax # JAX 数值计算库
import jax.numpy as jnp # JAX 的 NumPy 接口
import msgpack.exceptions # MsgPack 序列化异常模块
from flax.core.frozen_dict import FrozenDict, unfreeze # 冻结字典和解冻功能
from flax.serialization import from_bytes, to_bytes # 对象序列化和反序列化
from flax.traverse_util import flatten_dict, unflatten_dict # 字典扁平化和反扁平化
from jax.random import PRNGKey # JAX 随机数生成模块
# 导入本地的配置和工具函数
from .configuration_utils import PretrainedConfig # 预训练模型配置类
from .dynamic_module_utils import custom_object_save # 自定义对象保存函数
from .generation import FlaxGenerationMixin, GenerationConfig # 生成相关模块
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict # 加载 PyTorch 检查点到 Flax 状态字典
from .utils import (
FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, # 各种常量定义
WEIGHTS_INDEX_NAME, WEIGHTS_NAME, # 权重文件名和索引名
PushToHubMixin, # 推送到 Hub 的混合类
add_code_sample_docstrings, # 添加代码示例文档字符串
add_start_docstrings_to_model_forward, # 添加模型前向方法的文档字符串
cached_file, # 缓存文件函数
copy_func, # 复制函数对象
download_url, # 下载 URL 资源函数
has_file, # 检查文件是否存在函数
is_offline_mode, # 检查是否处于离线模式函数
is_remote_url, # 检查 URL 是否远程地址函数
logging, # 日志模块
replace_return_docstrings, # 替换返回值的文档字符串函数
)
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files # Hub 工具函数
from .utils.import_utils import is_safetensors_available # 检查是否安装安全张量库
if is_safetensors_available():
from safetensors import safe_open # 安全打开文件函数
from safetensors.flax import load_file as safe_load_file # 安全加载文件函数
from safetensors.flax import save_file as safe_save_file # 安全保存文件函数
logger = logging.get_logger(__name__) # 获取当前模块的日志记录器
def quick_gelu(x):
    """
    Sigmoid-gated GELU approximation: multiplies `x` by `sigmoid(1.702 * x)`, computed with JAX.
    """
    gate = jax.nn.sigmoid(1.702 * x)
    return x * gate
# Lookup table from activation-name strings to the callables implementing them.
ACT2FN = {
    "gelu": partial(nn.gelu, approximate=False),  # `nn.gelu` with the approximation switched off
    "relu": nn.relu,
    "silu": nn.swish,  # "silu" and "swish" resolve to the same callable
    "swish": nn.swish,
    "gelu_new": partial(nn.gelu, approximate=True),  # `nn.gelu` with the approximation switched on
    "quick_gelu": quick_gelu,  # sigmoid-based variant defined in this module
    "gelu_pytorch_tanh": partial(nn.gelu, approximate=True),  # same callable configuration as "gelu_new"
}
def dtype_byte_size(dtype):
    """
    Returns the size (in bytes) occupied by one parameter of type `dtype`. Example:

    ```py
    >>> dtype_byte_size(np.float32)
    4
    ```

    Args:
        dtype: A dtype instance (anything exposing `.name`), or any object `jnp.dtype` can convert into one, such as
            a scalar type class like `np.float32` or the builtin `bool`.

    Returns:
        `1 / 8` for booleans; otherwise the bit width parsed from the end of the dtype name, integer-divided by 8.

    Raises:
        ValueError: If `dtype` cannot be converted to a dtype, or if its name does not end in a bit width.
    """
    # Scalar type classes (e.g. `np.float32`) carry no `.name`, so the documented call above would fail on the
    # attribute access below; convert such inputs to a dtype instance first. Objects that already expose `.name`
    # are left untouched.
    if not hasattr(dtype, "name"):
        try:
            dtype = jnp.dtype(dtype)
        except TypeError as err:
            raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") from err
    if dtype == bool:
        # Booleans are counted as a single bit.
        return 1 / 8
    # The bit width is the run of digits at the very end of the name, preceded by a non-digit ("float32" -> 32).
    bit_search = re.search(r"[^\d](\d+)$", dtype.name)
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.groups()[0])
    return bit_size // 8
def flax_shard_checkpoint(params, max_shard_size="10GB"):
    """
    Splits a model parameter tree into sub-checkpoints so that the size of each sub-checkpoint does not exceed
    `max_shard_size`.

    The sub-checkpoints are filled by iterating over the flattened weights in key order, so there is no
    optimization to bring each shard as close as possible to the maximum size. For example, with a 10GB limit and
    weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB], the result is [6GB], [6+2GB], [6+2+2GB] and not
    [6+2+2GB], [6+2GB], [6GB].

    <Tip warning={true}>

    If one of the model's weights is bigger than `max_shard_size`, it ends up alone in its own sub-checkpoint, whose
    size is then greater than `max_shard_size`.

    </Tip>

    Args:
        params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters.
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
            (like `"5MB"`).

    Returns:
        `(shards, index)`. When everything fits in one shard, `shards` is `{FLAX_WEIGHTS_NAME: weights}` and `index`
        is `None`. Otherwise `shards` maps each shard file name to its dict of flattened (`"/"`-joined) weights and
        `index` is `{"metadata": {"total_size": ...}, "weight_map": {weight_name: shard_file}}`.
    """
    max_shard_size = convert_file_size_to_int(max_shard_size)

    sharded_state_dicts = []
    current_block = {}
    current_block_size = 0
    total_size = 0

    # Flatten the tree so each weight is addressed by a single "/"-joined key.
    weights = flatten_dict(params, sep="/")
    for item, weight in weights.items():
        weight_size = weight.size * dtype_byte_size(weight.dtype)

        # Close the current block when this weight would push it over the limit. The emptiness guard matters for a
        # weight that is larger than the limit on its own: without it an empty shard would be appended first.
        if current_block and current_block_size + weight_size > max_shard_size:
            sharded_state_dicts.append(current_block)
            current_block = {}
            current_block_size = 0

        current_block[item] = weight
        current_block_size += weight_size
        total_size += weight_size

    # Add the last block.
    sharded_state_dicts.append(current_block)

    # A single shard needs no index and keeps the plain weights file name.
    if len(sharded_state_dicts) == 1:
        return {FLAX_WEIGHTS_NAME: sharded_state_dicts[0]}, None

    # Otherwise name each shard "...-XXXXX-of-YYYYY.msgpack" and record which file holds each weight.
    weight_map = {}
    shards = {}
    for idx, shard in enumerate(sharded_state_dicts):
        shard_file = FLAX_WEIGHTS_NAME.replace(".msgpack", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.msgpack")
        shards[shard_file] = shard
        for weight_name in shard.keys():
            weight_map[weight_name] = shard_file

    metadata = {"total_size": total_size}
    index = {"metadata": metadata, "weight_map": weight_map}
    return shards, index
# FlaxPreTrainedModel: derives from PushToHubMixin and FlaxGenerationMixin.
class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
    r"""
    Base class for all models.

    [`FlaxPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
    downloading and saving models.

    Class attributes (overridden by derived classes):

        - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
          for this model architecture.
        - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
          classes of the same architecture adding modules on top of the base model.
        - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
          models, `pixel_values` for vision models and `input_values` for speech models).
    """

    # Configuration class of the architecture; `None` on this base class.
    config_class = None
    # Prefix naming the base-model attribute; empty string on this base class.
    base_model_prefix = ""
    # Name of the principal model input.
    main_input_name = "input_ids"
    # Auto-class name; `None` until `register_for_auto_class` assigns it.
    _auto_class = None
    # Set of missing keys; starts empty.
    _missing_keys = set()
# Constructor: validates its inputs, records state, and derives the parameter shape tree.
def __init__(
    self,
    config: PretrainedConfig,
    module: nn.Module,
    input_shape: Tuple = (1, 1),
    seed: int = 0,
    dtype: jnp.dtype = jnp.float32,
    _do_init: bool = True,
):
    """
    Set up the model wrapper.

    Args:
        config: Model configuration; must not be `None`.
        module: The wrapped `nn.Module`; must not be `None`.
        input_shape: Shape forwarded to `init_weights`.
        seed: Seed used to build the `PRNGKey` stored in `self.key`.
        dtype: Stored as `self.dtype`.
        _do_init: When `True`, `init_weights` is called and its result stored in `self.params`. When `False`, only
            the shapes are evaluated through `jax.eval_shape` and no parameters are stored.

    Raises:
        ValueError: If `config` or `module` is `None`.
    """
    if config is None:
        raise ValueError("config cannot be None")
    if module is None:
        raise ValueError("module cannot be None")

    # Kept private; exposed through the read-only `config` / `module` properties.
    self._config = config
    self._module = module

    # Public attributes shared by every derived class.
    self.key = PRNGKey(seed)
    self.dtype = dtype
    self.input_shape = input_shape

    # A generation config is only built for models that report they can generate.
    if self.can_generate():
        self.generation_config = GenerationConfig.from_model_config(config)
    else:
        self.generation_config = None

    # Gates access to the `params` property.
    self._is_initialized = _do_init

    if not _do_init:
        # Evaluate shapes only: `init_weights` is traced by `jax.eval_shape` instead of being called directly.
        shape_only_init = partial(self.init_weights, input_shape=input_shape)
        shape_tree = jax.eval_shape(shape_only_init, self.key)

        logger.info(
            "Model weights are not initialized as `_do_init` is set to `False`. "
            f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
        )
    else:
        # Create the parameters, then read their shapes off the resulting tree.
        random_params = self.init_weights(self.key, input_shape)
        shape_tree = jax.eval_shape(lambda params: params, random_params)

    self._params_shape_tree = shape_tree

    # Flattened key set that any assigned `params` must contain.
    self._required_params = set(flatten_dict(unfreeze(shape_tree)).keys())

    # Assign through the property setter so the key check runs.
    if _do_init:
        self.params = random_params
# Weight-initialization hook. This base implementation always raises; the error message states that the method
# has to be implemented.
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
    raise NotImplementedError(f"init method has to be implemented for {self}")
# Gradient-checkpointing hook. This base implementation always raises; the error message states that the method
# has to be implemented.
def enable_gradient_checkpointing(self):
    raise NotImplementedError(f"gradient checkpointing method has to be implemented for {self}")
# Class method that builds an instance from a configuration plus extra keyword arguments.
@classmethod
def _from_config(cls, config, **kwargs):
    """
    All context managers that the model should be initialized under go here.

    Returns `cls(config, **kwargs)`.
    """
    return cls(config, **kwargs)
# Framework identifier: always the string "flax".
@property
def framework(self) -> str:
    """
    :str: Identifies that this is a Flax model.
    """
    return "flax"
# Read-only access to the configuration object stored in `_config`.
@property
def config(self) -> PretrainedConfig:
    return self._config
# Read-only access to the wrapped module stored in `_module`.
@property
def module(self) -> nn.Module:
    return self._module
# Model parameters (plain dict or FrozenDict). Raises ValueError while `_is_initialized` is false.
@property
def params(self) -> Union[Dict, FrozenDict]:
    # Nothing is stored on the model in this state, so refuse the read instead of failing on `_params`.
    if not self._is_initialized:
        raise ValueError(
            "`params` cannot be accessed from model when the model is created with `_do_init=False`. "
            "You must call `init_weights` manually and store the params outside of the model and "
            "pass it explicitly where needed."
        )
    return self._params
# Set of flattened parameter keys the model requires, as stored in `_required_params`.
@property
def required_params(self) -> Set:
    return self._required_params
# Shape tree of the model parameters, as stored in `_params_shape_tree`.
@property
def params_shape_tree(self) -> Dict:
    return self._params_shape_tree
# Setter for `params`: rejects the assignment when the model is uninitialized or required keys are missing.
@params.setter
def params(self, params: Union[Dict, FrozenDict]):
    """
    Store `params` on the model.

    A `FrozenDict` is unfrozen into a plain dict before it is stored.

    Raises:
        ValueError: If `_is_initialized` is false, or if any key in `required_params` is absent from the flattened
            `params`.
    """
    if not self._is_initialized:
        raise ValueError(
            "`params` cannot be set from model when the model is created with `_do_init=False`. "
            "You store the params outside of the model."
        )

    if isinstance(params, FrozenDict):
        params = unfreeze(params)

    # Compare flattened key sets; anything left over is a required key the caller did not supply.
    supplied_keys = set(flatten_dict(params).keys())
    absent_keys = self.required_params - supplied_keys
    if absent_keys:
        raise ValueError(
            "Some parameters are missing. Make sure that `params` include the following "
            f"parameters {absent_keys}"
        )

    self._params = params
def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
    """
    Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.

    When `mask` is given, its flattened leaves are paired in order with the sorted flattened keys of `params`,
    and only entries whose mask leaf is truthy are cast.
    """

    # Adapted from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
    def cast_if_floating(leaf):
        # Leaves that are not floating-point `jnp.ndarray`s pass through unchanged.
        is_float_array = isinstance(leaf, jnp.ndarray) and jnp.issubdtype(leaf.dtype, jnp.floating)
        return leaf.astype(dtype) if is_float_array else leaf

    if mask is not None:
        flat_params = flatten_dict(params)
        mask_leaves = jax.tree_util.tree_flatten(mask)[0]

        for should_cast, path in zip(mask_leaves, sorted(flat_params.keys())):
            if should_cast:
                flat_params[path] = cast_if_floating(flat_params[path])

        return unflatten_dict(flat_params)

    # No mask: every leaf is a candidate for the cast.
    return jax.tree_util.tree_map(cast_if_floating, params)
def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
    r"""
    Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
    the `params` in place.

    This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
    half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.

    Arguments:
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters.
        mask (`Union[Dict, FrozenDict]`):
            A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
            you want to cast, and should be `False` for those you want to skip.

    Examples:

    ```
    >>> from transformers import FlaxBertModel

    >>> # load model
    >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
    >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
    >>> model.params = model.to_bf16(model.params)
    >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
    >>> # then pass the mask as follows
    >>> from flax import traverse_util

    >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
    >>> flat_params = traverse_util.flatten_dict(model.params)
    >>> mask = {
    ...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
    ...     for path in flat_params
    ... }
    >>> mask = traverse_util.unflatten_dict(mask)
    >>> model.params = model.to_bf16(model.params, mask)
    ```
    """
    return self._cast_floating_to(params, jnp.bfloat16, mask)
def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
    r"""
    Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
    model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.

    Arguments:
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters.
        mask (`Union[Dict, FrozenDict]`):
            A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
            you want to cast, and should be `False` for those you want to skip.

    Examples:

    ```
    >>> from transformers import FlaxBertModel

    >>> # Download model and configuration from huggingface.co
    >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
    >>> # By default, the model params will be in fp32, to illustrate the use of this method,
    >>> # we'll first cast to fp16 and back to fp32
    >>> model.params = model.to_fp16(model.params)
    >>> # now cast back to fp32
    >>> model.params = model.to_fp32(model.params)
    ```
    """
    # Delegate to the shared casting helper with float32 as the target dtype.
    return self._cast_floating_to(params, jnp.float32, mask)
def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
    r"""
    Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
    `params` in place.

    This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full
    half-precision training or to save weights in float16 for inference in order to save memory and improve speed.

    Arguments:
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters.
        mask (`Union[Dict, FrozenDict]`, *optional*):
            A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
            you want to cast, and should be `False` for those you want to skip.

    Examples:

    ```
    >>> from transformers import FlaxBertModel

    >>> # load model
    >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")

    >>> # By default, the model params will be in fp32, to cast these to float16
    >>> model.params = model.to_fp16(model.params)

    >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
    >>> # then pass the mask as follows
    >>> from flax import traverse_util

    >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
    >>> flat_params = traverse_util.flatten_dict(model.params)
    >>> mask = {
    ...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
    ...     for path in flat_params
    ... }
    >>> mask = traverse_util.unflatten_dict(mask)
    >>> model.params = model.to_fp16(model.params, mask)
    ```
    """
    # Delegate to the shared casting helper with float16 as the target dtype.
    return self._cast_floating_to(params, jnp.float16, mask)
# Class method that loads Flax weights from a single checkpoint file.
@classmethod
def load_flax_weights(cls, resolved_archive_file):
    """
    Load a model state from one checkpoint file.

    Args:
        resolved_archive_file (`str`):
            Path of the checkpoint. A name ending in `.safetensors` is read with `safe_load_file` and unflattened
            on `"."`; anything else is read as bytes and deserialized with `from_bytes`.

    Returns:
        The deserialized state.

    Raises:
        OSError: If deserialization fails and the file, read as text, starts with `"version"` (the message points
            at a missing git-lfs install).
        EnvironmentError: If deserialization fails for any other reason caught below.
    """
    try:
        if resolved_archive_file.endswith(".safetensors"):
            state = safe_load_file(resolved_archive_file)
            state = unflatten_dict(state, sep=".")
        else:
            with open(resolved_archive_file, "rb") as state_f:
                state = from_bytes(cls, state_f.read())
    except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
        try:
            # Re-open as text to tell a git-lfs pointer file apart from a genuinely corrupt checkpoint.
            with open(resolved_archive_file) as f:
                if f.read().startswith("version"):
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please"
                        " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                        " folder you cloned."
                    )
                else:
                    raise ValueError from e
        except (UnicodeDecodeError, ValueError):
            # Covers both the ValueError raised just above and a file that cannot be decoded as text.
            raise EnvironmentError(f"Unable to convert {resolved_archive_file} to Flax deserializable object. ")

    return state
@classmethod
def load_flax_sharded_weights(cls, shard_files):
    """
    This is the same as [`flax.serialization.from_bytes`]
    (https://flax.readthedocs.io/en/latest/_modules/flax/serialization.html#from_bytes) but for a sharded checkpoint.

    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
    loaded in the model.

    Args:
        shard_files (`List[str]`):
            The list of shard files to load.

    Returns:
        `Dict`: A nested dictionary of the model parameters, in the expected format for flax models : `{'model':
        {'params': {'...'}}}`.

    Raises:
        OSError: If a shard fails to deserialize and, read as text, starts with `"version"` (the message points at
            a missing git-lfs install).
        EnvironmentError: If a shard cannot be converted to a Flax deserializable object.
    """
    # Accumulates the flattened ("/"-joined) weights of every shard.
    state_sharded_dict = {}

    for shard_file in shard_files:
        # Load using msgpack utils.
        try:
            with open(shard_file, "rb") as state_f:
                state = from_bytes(cls, state_f.read())
        except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
            # The probe gets its own try block, as in `load_flax_weights`, so that a shard which cannot be decoded
            # as text is reported as EnvironmentError instead of leaking a raw UnicodeDecodeError.
            try:
                with open(shard_file) as f:
                    if f.read().startswith("version"):
                        raise OSError(
                            "You seem to have cloned a repository without having git-lfs installed. Please"
                            " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                            " folder you cloned."
                        )
                    else:
                        raise ValueError from e
            except (UnicodeDecodeError, ValueError):
                raise EnvironmentError(f"Unable to convert {shard_file} to Flax deserializable object. ")
        except (UnicodeDecodeError, ValueError):
            # Errors raised directly by the read/deserialize step above.
            raise EnvironmentError(f"Unable to convert {shard_file} to Flax deserializable object. ")

        state = flatten_dict(state, sep="/")
        state_sharded_dict.update(state)
        # Drop the per-shard state before loading the next one.
        del state
        gc.collect()

    # The state dict is unflattened to match the format of model.params.
    return unflatten_dict(state_sharded_dict, sep="/")
@classmethod
def can_generate(cls) -> bool:
    """
    Returns whether this model can generate sequences with `.generate()`.

    Returns:
        `bool`: Whether this model can generate sequences with `.generate()`.
    """
    # A model is treated as unable to generate only when the string forms of both
    # `prepare_inputs_for_generation` and `generate` still mention "GenerationMixin", i.e. neither was overridden.
    mixin_marker = "GenerationMixin"
    if mixin_marker not in str(cls.prepare_inputs_for_generation):
        return True
    return mixin_marker not in str(cls.generate)
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: Union[str, os.PathLike],
    dtype: jnp.dtype = jnp.float32,
    *model_args,
    config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    ignore_mismatched_sizes: bool = False,
    force_download: bool = False,
    local_files_only: bool = False,
    token: Optional[Union[str, bool]] = None,
    revision: str = "main",
    **kwargs,
):
    """
    Load model parameters and configuration from a pretrained model.

    <Tip warning={true}>

    This API is experimental and may have some slight breaking changes in the next releases.

    </Tip>

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            Name or path of the pretrained model.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            Data type to use when loading the parameters.
        *model_args:
            Remaining positional arguments, passed on to the model-specific loading function.
        config (`PretrainedConfig`, `str`, `os.PathLike`, *optional*, defaults to `None`):
            Configuration object of the pretrained model, or a path to it.
        cache_dir (`str` or `os.PathLike`, *optional*, defaults to `None`):
            Path of the cache directory.
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
            Whether to ignore size mismatches when loading parameters.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether to force the model to be downloaded again.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether to load the model from local files only.
        token (`str` or `bool`, *optional*, defaults to `None`):
            Token used for authentication.
        revision (`str`, *optional*, defaults to `"main"`):
            Version of the model.
        **kwargs:
            Remaining keyword arguments, passed on to the model-specific loading function.
    """
    # NOTE(review): no statements follow the docstring in this copy of the file, so a call returns `None`.
    # Confirm against the full implementation before relying on this method.
def save_pretrained(
    self,
    save_directory: Union[str, os.PathLike],  # directory the model is saved to
    params=None,  # model parameters to save
    push_to_hub=False,  # whether to push to the model Hub
    max_shard_size="10GB",  # maximum size of a shard
    token: Optional[Union[str, bool]] = None,  # authentication token
    safe_serialization: bool = False,  # whether to use safe serialization
    **kwargs,  # remaining keyword arguments
):
    """
    Save the current model to the given directory.

    Args:
        save_directory (`str` or `os.PathLike`):
            Path of the directory to save the model to.
        params:
            Model parameters to save, defaults to `None`.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the model to the model Hub.
        max_shard_size (`str`, *optional*, defaults to `"10GB"`):
            Maximum size of a shard.
        token (`str` or `bool`, *optional*, defaults to `None`):
            Token used for authentication.
        safe_serialization (`bool`, *optional*, defaults to `False`):
            Whether to use safe serialization.
        **kwargs:
            Remaining keyword arguments, passed on to the underlying save function.
    """
    # NOTE(review): no statements follow the docstring in this copy of the file — confirm against the full
    # implementation.
@classmethod
def register_for_auto_class(cls, auto_class="FlaxAutoModel"):
    """
    Register this class with a given auto class. This should only be used for custom models as the ones in the
    library are already mapped with an auto class.

    <Tip warning={true}>

    This API is experimental and may have some slight breaking changes in the next releases.

    </Tip>

    Args:
        auto_class (`str` or `type`, *optional*, defaults to `"FlaxAutoModel"`):
            The auto class to register this new model with.

    Raises:
        ValueError: If `transformers.models.auto` has no attribute with the resolved name.
    """
    # Accept either the class itself or its name; from here on only the name is used.
    auto_class = auto_class if isinstance(auto_class, str) else auto_class.__name__

    import transformers.models.auto as auto_module

    is_known_auto_class = hasattr(auto_module, auto_class)
    if not is_known_auto_class:
        raise ValueError(f"{auto_class} is not a valid auto class.")

    cls._auto_class = auto_class
# Rebind `push_to_hub` to a copy of itself, so the docstring edit below lands on the copy.
FlaxPreTrainedModel.push_to_hub = copy_func(FlaxPreTrainedModel.push_to_hub)
# When a docstring exists, fill its format placeholders with the object type, auto-class name and file kind.
if FlaxPreTrainedModel.push_to_hub.__doc__ is not None:
    FlaxPreTrainedModel.push_to_hub.__doc__ = FlaxPreTrainedModel.push_to_hub.__doc__.format(
        object="model", object_class="FlaxAutoModel", object_files="model checkpoint"
    )
def overwrite_call_docstring(model_class, docstring):
    """
    Replace the docstring of `model_class.__call__` with one built from `docstring`.

    `__call__` is first rebound to a copy, so only that copy is modified.
    """
    # Work on a copy of `__call__` and clear whatever docstring it carried.
    call_copy = copy_func(model_class.__call__)
    model_class.__call__ = call_copy
    call_copy.__doc__ = None
    # Wrap the cleared copy with the decorator that attaches the new docstring.
    attach_docstring = add_start_docstrings_to_model_forward(docstring)
    model_class.__call__ = attach_docstring(model_class.__call__)
def append_call_sample_docstring(
    model_class, checkpoint, output_type, config_class, mask=None, revision=None, real_checkpoint=None
):
    """
    Decorate `model_class.__call__` with `add_code_sample_docstrings`.

    `__call__` is first rebound to a copy, so only that copy is decorated. `mask` is accepted but is not
    forwarded to `add_code_sample_docstrings`.
    """
    model_class.__call__ = copy_func(model_class.__call__)
    add_sample = add_code_sample_docstrings(
        checkpoint=checkpoint,
        output_type=output_type,
        config_class=config_class,
        model_cls=model_class.__name__,
        revision=revision,
        real_checkpoint=real_checkpoint,
    )
    model_class.__call__ = add_sample(model_class.__call__)
def append_replace_return_docstrings(model_class, output_type, config_class):
    """
    Decorate `model_class.__call__` with `replace_return_docstrings`, built from `output_type` and `config_class`.

    `__call__` is first rebound to a copy, so only that copy is decorated.
    """
    model_class.__call__ = copy_func(model_class.__call__)
    rewrite_returns = replace_return_docstrings(output_type=output_type, config_class=config_class)
    model_class.__call__ = rewrite_returns(model_class.__call__)
.\modeling_outputs.py
# 导入警告模块,用于可能的警告消息
import warnings
# 导入 dataclass 模块,用于定义数据类
from dataclasses import dataclass
# 导入 Optional 和 Tuple 类型提示
from typing import Optional, Tuple
# 导入 PyTorch 模块
import torch
# 从当前包中导入 ModelOutput 类
from .utils import ModelOutput
# BaseModelOutput dataclass, derived from ModelOutput.
@dataclass
class BaseModelOutput(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Last hidden state; defaults to None.
    last_hidden_state: torch.FloatTensor = None
    # Optional tuple of hidden states; defaults to None.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional tuple of attention weights; defaults to None.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# BaseModelOutputWithNoAttention dataclass, derived from ModelOutput.
@dataclass
class BaseModelOutputWithNoAttention(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, num_channels, height, width)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    # Defaults to None, so the field may be omitted.
    last_hidden_state: torch.FloatTensor = None
    # Optional tuple of per-layer hidden states; defaults to None.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
# Dataclass for a base model output that also carries a pooled version of the last hidden state.
@dataclass
class BaseModelOutputWithPooling(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token) after further processing
            through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
            the classification token after processing through a linear layer and a tanh activation function. The linear
            layer weights are trained from the next sentence prediction (classification) objective during pretraining.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Tensor holding the last hidden state.
    last_hidden_state: torch.FloatTensor = None
    # Pooled output.
    pooler_output: torch.FloatTensor = None
    # Tuple of hidden states (one entry per layer output).
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Tuple of attention weights (one entry per layer).
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Dataclass for a base model output with a pooled last hidden state and no attention weights.
@dataclass
class BaseModelOutputWithPoolingAndNoAttention(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state after a pooling operation on the spatial dimensions.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, num_channels, height, width)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    # Declared first, matching the order of the documented arguments and of the sibling output classes.
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
# Dataclass for a model output that may also carry past key/values.
@dataclass
class BaseModelOutputWithPast(ModelOutput):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.

            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Fields: last hidden state, past key/values, hidden states, attention weights. All default to None.
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Dataclass for a model output that also carries cross-attention weights.
@dataclass
class BaseModelOutputWithCrossAttentions(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Hidden states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Tuple with the hidden states of every layer (including the initial embedding output if the model has an
            embedding layer), each of shape `(batch_size, sequence_length, hidden_size)`.
        attentions (`tuple(torch.FloatTensor)`, *optional*):
            Tuple with the attention weights of every layer, each of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`, used to compute the weighted average in the self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*):
            Tuple with the attention weights of the decoder's cross-attention layers, each of shape `(batch_size,
            num_heads, sequence_length, sequence_length)`, used to compute the weighted average in the
            cross-attention heads.
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Base model output with pooling and cross-attentions, derived from ModelOutput.
@dataclass
class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.
    """

    # Last hidden state.
    last_hidden_state: torch.FloatTensor = None
    # Pooled output.
    pooler_output: torch.FloatTensor = None
    # Optional tuple of hidden states.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional tuple of past key/value tuples.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Optional tuple of attention weights.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional tuple of cross-attention weights.
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Hidden states at the output of the last layer of the model.

            If `past_key_values` is used, only the last hidden state of the sequences, of shape `(batch_size, 1,
            hidden_size)`, is output.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Tuple of length `config.n_layers` holding pre-computed hidden states (keys and values of the
            self-attention blocks), each inner tuple containing two tensors. If `config.is_encoder_decoder=True`,
            two additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`
            are included. Used to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Tuple with the hidden states of every layer. If the model has an embedding layer, the first tensor is
            the embedding output. Each tensor has shape `(batch_size, sequence_length, hidden_size)`.
        attentions (`tuple(torch.FloatTensor)`, *optional*):
            Tuple with the attention weights of every layer, each of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*):
            Tuple with the attention weights of the decoder's cross-attention layers, each of shape `(batch_size,
            num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            cross-attention heads.
    """

    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class MoECausalLMOutputWithPast(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs as well as Mixture of Expert's router hidden
    states terms, to train a MoE model.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        z_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            z_loss for the sparse modules.
        aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            aux_loss for the sparse modules.
        router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.

            Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse
            modules.
    """

    # Each field is declared exactly once. The original block repeated
    # `attentions`, `z_loss`, `aux_loss` and `router_logits` with conflicting
    # annotations; the field order is unchanged by removing the repeats.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Both sparse-module losses default to None, hence Optional.
    z_loss: Optional[torch.FloatTensor] = None
    aux_loss: Optional[torch.FloatTensor] = None
    router_logits: Optional[Tuple[torch.FloatTensor]] = None
# Output container for MoE models; subclasses ModelOutput.
@dataclass
class MoEModelOutput(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.

            Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary
            loss and the z_loss for Mixture of Experts models.
    """

    last_hidden_state: torch.FloatTensor = None  # output of the last layer
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # per-layer hidden states
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # per-layer attention weights
    router_probs: Optional[Tuple[torch.FloatTensor]] = None  # per-layer raw router probabilities
# Output container for MoE models that also carries cached key/values,
# hidden states and attentions; subclasses ModelOutput.
@dataclass
class MoeModelOutputWithPast(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Hidden states at the output of the last layer of the model.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Pre-computed hidden states (keys and values in the self-attention blocks and, if
            `config.is_encoder_decoder=True`, in the cross-attention blocks).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Hidden states of the model at the output of each layer, plus the optional initial embedding output.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed):
            Raw router logits (post-softmax) computed by the MoE routers, used to compute the auxiliary loss
            for Mixture of Experts models.
    """

    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    router_logits: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MoeCausalLMOutputWithPast(ModelOutput):
    """
    Base class for causal language model (or autoregressive) with mixture of experts outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            Auxiliary loss for the sparse modules.
        router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.

            Raw router logits (post-softmax) computed by MoE routers, used for computing auxiliary loss in Mixture
            of Experts models.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, each tuple containing 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Pre-computed hidden states (keys and values in self-attention blocks) for speeding up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for embedding layer output if present, plus one for each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden states of the model at each layer's output, including optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the softmax operation, used for computing weighted averages in self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    aux_loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # The docstring documents `router_logits`, but the field itself was never
    # declared. It is appended last so the positions of existing fields are
    # unchanged.
    router_logits: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MoEModelOutputWithPastAndCrossAttentions(ModelOutput):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding) as well as
    Mixture of Expert's router hidden states terms, to train a MoE model.
    """

    # Last hidden state tensor; defaults to None.
    last_hidden_state: torch.FloatTensor = None
    # Optional tuple of past key/value pairs, used to speed up sequential decoding.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Optional tuple of hidden states.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional tuple of attention distributions.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional tuple of cross-attention distributions.
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional tuple of router probabilities.
    router_probs: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class Seq2SeqModelOutput(ModelOutput):
    """
    Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential
    decoding.
    """

    # Last hidden state tensor; defaults to None.
    last_hidden_state: torch.FloatTensor = None
    # Optional tuple of past key/value pairs, used to speed up sequential decoding.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Optional tuple of decoder hidden states.
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional tuple of decoder attention distributions.
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional tuple of cross-attention distributions.
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional last hidden state tensor of the encoder.
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # Optional tuple of encoder hidden states.
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional tuple of encoder attention distributions.
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class Seq2SeqMoEModelOutput(ModelOutput):
    """
    Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential
    decoding.
    """

    # Last hidden state tensor; defaults to None.
    last_hidden_state: torch.FloatTensor = None
    # Optional tuple of past key/value pairs, used to speed up sequential decoding.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Optional tuple of decoder hidden states.
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional tuple of decoder attention distributions.
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional tuple of decoder router logits.
    decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None
    # Optional tuple of cross-attention distributions.
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional last hidden state tensor of the encoder.
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # Optional tuple of encoder hidden states.
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional tuple of encoder attention distributions.
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional tuple of encoder router logits.
    encoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class CausalLMOutput(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # The `Args` section used to sit outside the docstring, which made the
    # class body unparsable. It is now part of the docstring above.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Dataclass describing the output of a causal (autoregressive) language model
# that may also return cached key/values; subclasses ModelOutput.
@dataclass
class CausalLMOutputWithPast(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Pre-computed hidden states (keys and values in the self-attention blocks) that can be used to speed
            up sequential decoding. Tuple of length `config.n_layers`, each inner tuple holding 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of the hidden-state tensors at the output of each layer (including the initial embedding
            output if the model has an embedding layer), of shape `(batch_size, sequence_length, hidden_size)`.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple (one per layer) of attention weights after the attention softmax in the self-attention heads,
            of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Dataclass describing the output of a causal (autoregressive) language model
# with cross-attentions; subclasses ModelOutput.
@dataclass
class CausalLMOutputWithCrossAttentions(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs, with cross-attentions.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Cross attentions weights after the attention softmax, used to compute the weighted average in the
            cross-attention heads.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `torch.FloatTensor` tuples of length `config.n_layers`, with each tuple containing the cached key,
            value states of the self-attention and the cross-attention layers if model is used in encoder-decoder
            setting. Only relevant if `config.is_decoder = True`.

            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
    """

    # The original block had no field declarations and an unterminated `Args`
    # section. The fields below are the ones the docstring documents.
    # NOTE(review): order follows `CausalLMOutputWithPast` with
    # `cross_attentions` appended; confirm against callers that index
    # positionally.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class SequenceClassifierOutputWithPast(ModelOutput):
    """
    Base class for outputs of sequence classification models that also include past key values,
    hidden states, and attentions.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None  # classification / regression loss
    logits: torch.FloatTensor = None  # scores before SoftMax
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # cached keys and values
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # per-layer hidden states
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # per-layer attention weights
@dataclass
class MaskedLMOutput(ModelOutput):
    """
    Base class for outputs of masked language models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Masked language modeling (MLM) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # The `Args` section used to sit outside the docstring, which made the
    # class body unparsable. It is now part of the docstring above.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class Seq2SeqLMOutput(ModelOutput):
    """
    Base class for sequence-to-sequence language models outputs.
    """

    loss: Optional[torch.FloatTensor] = None  # optional loss tensor of the model
    logits: torch.FloatTensor = None  # scores produced by the model
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # optional nested tuple of cached keys and values
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # optional decoder hidden states
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # optional decoder attention weights
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # optional encoder-decoder cross-attention weights
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None  # optional last hidden state of the encoder
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # optional encoder hidden states
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # optional encoder attention weights
@dataclass
class Seq2SeqMoEOutput(ModelOutput):
    """
    Base class for sequence-to-sequence language models outputs.
    """

    loss: Optional[torch.FloatTensor] = None  # optional loss tensor of the model
    logits: torch.FloatTensor = None  # scores produced by the model
    encoder_z_loss: torch.FloatTensor = None  # z-loss of the encoder
    decoder_z_loss: torch.FloatTensor = None  # z-loss of the decoder
    encoder_aux_loss: torch.FloatTensor = None  # auxiliary loss of the encoder
    decoder_aux_loss: torch.FloatTensor = None  # auxiliary loss of the decoder
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # optional nested tuple of cached keys and values
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # optional decoder hidden states
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # optional decoder attention weights
    decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None  # optional decoder router logits
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # optional encoder-decoder cross-attention weights
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None  # optional last hidden state of the encoder
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # optional encoder hidden states
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # optional encoder attention weights
    encoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None  # optional encoder router logits
@dataclass
class NextSentencePredictorOutput(ModelOutput):
    """
    Base class for outputs of models predicting if two sentences are consecutive or not.

    Args:
        loss (`torch.FloatTensor`, *optional*, returned when `next_sentence_label` is provided):
            Next sequence prediction (classification) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Hidden states of the model at the output of each layer.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Attention weights.
    """

    # The docstring was never terminated, so the declarations below were
    # swallowed into the string literal; the closing quotes above fix that.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Dataclass holding the output of a sequence classification model.
@dataclass
class SequenceClassifierOutput(ModelOutput):
    """
    Base class for outputs of sentence classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None  # classification / regression loss
    logits: torch.FloatTensor = None  # scores before SoftMax
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # per-layer hidden states
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # per-layer attention weights
# Dataclass holding the output of a sequence-to-sequence sentence
# classification model.
@dataclass
class Seq2SeqSequenceClassifierOutput(ModelOutput):
    """
    Base class for outputs of sequence-to-sequence sentence classification models.
    """

    # Loss and classification scores.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None

    # Cached keys and values.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None

    # Decoder-side outputs.
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None

    # Encoder-side outputs.
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Dataclass holding the output of a multiple choice model.
@dataclass
class MultipleChoiceModelOutput(ModelOutput):
    """
    Base class for outputs of multiple choice models.

    Args:
        loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided):
            Classification loss.
        logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
            *num_choices* is the second dimension of the input tensors.

            Classification scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # The `Args` section used to be a second, separate string literal, so it
    # never reached `__doc__`. It is merged into the single docstring above.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Base class describing the output of token classification models.
@dataclass
class TokenClassifierOutput(ModelOutput):
    """
    Base class for outputs of token classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) :
            Classification loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # The original block documented these four fields but declared none of
    # them. Names, types and order mirror the sibling `SequenceClassifierOutput`.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class QuestionAnsweringModelOutput(ModelOutput):
    """
    Base class for outputs of question answering models.

    Args:
        loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            Loss tensor.
        start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Scores for the start position.
        end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Scores for the end position.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of the tensors output by each layer, of shape `(batch_size, sequence_length, hidden_size)`;
            also includes the initial embedding output if the model has an embedding layer.
        attentions (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of per-layer attention weights of shape
            `(batch_size, num_heads, sequence_length, sequence_length)`; the weights after the attention
            softmax, used to compute the weighted average in the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    start_logits: torch.FloatTensor = None
    end_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Dataclass storing the output of a sequence-to-sequence question answering
# model.
@dataclass
class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
    """
    Base class for outputs of sequence-to-sequence question answering models.
    """

    loss: Optional[torch.FloatTensor] = None  # loss value, if present
    start_logits: torch.FloatTensor = None  # predicted logits for the start position
    end_logits: torch.FloatTensor = None  # predicted logits for the end position
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # optional cached keys and values
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # optional decoder hidden states
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # optional decoder attention weights
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # optional cross-attention weights
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None  # optional last hidden state of the encoder
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # optional encoder hidden states
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # optional encoder attention weights
# Dataclass storing the output of a semantic segmentation model.
@dataclass
class SemanticSegmenterOutput(ModelOutput):
    """
    Base class for outputs of semantic segmentation models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
            Classification scores for each pixel.

            <Tip warning={true}>

            The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
            to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the
            original image size as post-processing. You should always check your logits shape and resize as needed.

            </Tip>

        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, patch_size, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None  # loss value, if present
    logits: torch.FloatTensor = None  # per-pixel classification scores
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # optional per-layer hidden states
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # optional per-layer attention weights
# Data class holding the outputs of image classification models.
@dataclass
class ImageClassifierOutput(ModelOutput):
    """
    Base class for outputs of image classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` holding the hidden-states (feature maps) of the model at the output of
            each stage, of shape `(batch_size, sequence_length, hidden_size)`. If the model has an embedding
            layer, the first tensor is the output of the embeddings.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` holding the attention weights of the model, of shape `(batch_size,
            num_heads, patch_size, sequence_length)`. These are the weights after the attention softmax, used
            to compute the weighted average in the self-attention heads.
    """

    # The docstring above was previously left unterminated and these fields were missing, which made the
    # class (and the definition following it) unparsable. The fields mirror the documented `Args` and follow
    # the same typing/default conventions as the sibling output classes in this file.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class ImageClassifierOutputWithNoAttention(ModelOutput):
    """
    Output container for image classification models that do not return attention weights.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Loss of the classification model (available when `labels` is provided).
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Output scores of the classification model (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Hidden-states (also called feature maps) of the model at each stage, of shape `(batch_size,
            num_channels, height, width)`.
    """

    # Field order is part of the interface (positional construction / tuple indexing); keep it stable.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class DepthEstimatorOutput(ModelOutput):
    """
    Output container for depth estimation models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Loss of the depth estimation model (available when `labels` is provided).
        predicted_depth (`torch.FloatTensor` of shape `(batch_size, height, width)`):
            Predicted depth for each pixel.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Hidden-states (also called feature maps) of every layer of the model, of shape `(batch_size,
            num_channels, height, width)`.

            Outputs of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Attention weights of every layer, of shape `(batch_size, num_heads, patch_size, sequence_length)`.

            Weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Field order is part of the interface (positional construction / tuple indexing); keep it stable.
    loss: Optional[torch.FloatTensor] = None
    predicted_depth: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class ImageSuperResolutionOutput(ModelOutput):
    """
    Base class for outputs of image super resolution models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Reconstruction loss, returned when `labels` is provided.
        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Reconstructed images, possibly upscaled.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor`:

            - one for the output of the embeddings, if the model has an embedding layer, of shape `(batch_size,
              sequence_length, hidden_size)`;
            - the hidden-states (also called feature maps) at the output of each stage.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor`:

            - the attention weights of each layer, of shape `(batch_size, num_heads, patch_size,
              sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    # The `Args` section used to sit outside the docstring as bare text, which is a syntax error; it now
    # lives inside the docstring. Field names, order, annotations and defaults are unchanged.
    loss: Optional[torch.FloatTensor] = None
    reconstruction: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Data class representing the base output of Wav2Vec2 models.
@dataclass
class Wav2Vec2BaseModelOutput(ModelOutput):
    """
    Output container for models trained with the Wav2Vec2 loss objective.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Hidden-state sequence at the output of the model's last layer.
        extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
            Feature-vector sequence extracted by the model's last convolutional layer.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`: one for the
            output of the embeddings plus one for the output of each layer.

            Hidden-states of the model at the output of each layer, plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one per layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights taken after the attention softmax; they are used to compute the weighted average
            in the self-attention heads.
    """

    # Field order is part of the interface (positional construction / tuple indexing); keep it stable.
    last_hidden_state: torch.FloatTensor = None
    extract_features: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Data class representing the output type of `Wav2Vec2ForXVector`.
@dataclass
class XVectorOutput(ModelOutput):
    """
    Output type of [`Wav2Vec2ForXVector`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
            Classification hidden states before AMSoftmax.
        embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
            Utterance embeddings used for vector similarity-based retrieval.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer, plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    # The docstring used to be closed before `Args`, leaving the section as bare text followed by a stray
    # `"""`; the whole description now sits in one docstring. Fields are unchanged.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    embeddings: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Data class `BackboneOutput`, derived from `ModelOutput`.
@dataclass
class BackboneOutput(ModelOutput):
    """
    Output container for backbones.

    Args:
        feature_maps (`tuple(torch.FloatTensor)` of shape `(batch_size, num_channels, height, width)`):
            Feature maps produced by the stages.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings plus one for the output of each
            layer). Depending on the backbone, the shape is either `(batch_size, sequence_length, hidden_size)`
            or `(batch_size, num_channels, height, width)`.

            Hidden-states of the model at the output of each stage, plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one per layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Only applicable when the backbone uses attention.

            Attention weights taken after the attention softmax; they are used to compute the weighted average
            in the self-attention heads.
    """

    # Field order is part of the interface (positional construction / tuple indexing); keep it stable.
    feature_maps: Tuple[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Data class `BaseModelOutputWithPoolingAndProjection`, derived from `ModelOutput`.
@dataclass
class BaseModelOutputWithPoolingAndProjection(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token) after further
            processing for the auxiliary pretraining task. E.g. for BERT-family models, this is the
            classification token after processing through a linear layer and a tanh activation function. The
            linear layer weights are trained from the next sentence prediction (classification) objective
            during pretraining.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple with the hidden-state sequence of each layer of the model. Every element has shape
            `(batch_size, sequence_length, hidden_size)`; the optional initial embedding output is included.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple with the attention weights of each layer of the model. Every element has shape
            `(batch_size, num_heads, sequence_length, sequence_length)`; the weights are used to compute the
            weighted average in the self-attention heads.
        projection_state (`tuple(torch.FloatTensor)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple with the text embeddings before the projection layer. Every element has shape
            `(batch_size, config.project_dim)`; they are used to mimic the last hidden state of the teacher
            encoder.
    """

    # The docstring was previously closed before `Args` (leaving bare text in the class body) and no fields
    # were declared. The fields below mirror the documented `Args`, using the same typing/default
    # conventions as the sibling output classes in this file.
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    projection_state: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class Seq2SeqSpectrogramOutput(ModelOutput):
    """
    Output container for sequence-to-sequence spectrogram models.
    """

    # Loss value of the model output.
    loss: Optional[torch.FloatTensor] = None
    # Spectrogram data.
    spectrogram: torch.FloatTensor = None
    # Past key/value tensors.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Hidden states of the decoder.
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Attention weights of the decoder.
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Cross-attention weights.
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Last hidden state of the encoder.
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # Hidden states of the encoder.
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Attention weights of the encoder.
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class Seq2SeqTSModelOutput(ModelOutput):
    """
    Output container for a time series model's encoder outputs; it also carries pre-computed hidden states that
    can speed up sequential decoding.
    """

    # Last hidden state.
    last_hidden_state: torch.FloatTensor = None
    # Past key/value tensors (the pre-computed states mentioned in the class docstring).
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Hidden states of the decoder.
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Attention weights of the decoder.
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Cross-attention weights.
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Last hidden state of the encoder.
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # Hidden states of the encoder.
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Attention weights of the encoder.
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Location parameter.
    loc: Optional[torch.FloatTensor] = None
    # Scale parameter.
    scale: Optional[torch.FloatTensor] = None
    # Static features.
    static_features: Optional[torch.FloatTensor] = None
@dataclass
class Seq2SeqTSPredictionOutput(ModelOutput):
    """
    Output container for a time series model's decoder outputs; it also carries the loss and the parameters of
    the chosen distribution.
    """

    # Loss value of the model output.
    loss: Optional[torch.FloatTensor] = None
    # Parameters of the chosen distribution.
    params: Optional[Tuple[torch.FloatTensor]] = None
    # Past key/value tensors.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Hidden states of the decoder.
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Attention weights of the decoder.
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Cross-attention weights.
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Last hidden state of the encoder.
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # Hidden states of the encoder.
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Attention weights of the encoder.
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Location parameter.
    loc: Optional[torch.FloatTensor] = None
    # Scale parameter.
    scale: Optional[torch.FloatTensor] = None
    # Static features.
    static_features: Optional[torch.FloatTensor] = None
@dataclass
class SampleTSPredictionOutput(ModelOutput):
    """
    Output container for a time series model's predictions, holding the values sampled from the chosen
    distribution.

    Args:
        sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length)` or `(batch_size, num_samples, prediction_length, input_size)`):
            Values sampled from the chosen distribution.
    """

    # Single field; defaults to None.
    sequences: torch.FloatTensor = None
# Data class `MaskedImageModelingOutput`, wrapping the outputs of masked image completion / in-painting models.
@dataclass
class MaskedImageModelingOutput(ModelOutput):
    """
    Output container for masked image completion / in-painting models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
            Reconstruction loss.
        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Reconstructed / completed images.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
        when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`: one for the
            output of the embeddings (if the model has an embedding layer) plus one for the output of each
            stage. These are the hidden-states (also called feature maps) of the model at the output of each
            stage.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when
        `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one per layer) of shape `(batch_size, num_heads, patch_size,
            sequence_length)`. Attention weights taken after the attention softmax; they are used to compute
            the weighted average in the self-attention heads.
    """

    # Field order is part of the interface (positional construction / tuple indexing); keep it stable.
    loss: Optional[torch.FloatTensor] = None
    reconstruction: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None

    @property
    def logits(self):
        """Deprecated accessor: emits a `FutureWarning` and returns `self.reconstruction`."""
        deprecation_message = (
            "logits attribute is deprecated and will be removed in version 5 of Transformers."
            " Please use the reconstruction attribute to retrieve the final output instead."
        )
        warnings.warn(deprecation_message, FutureWarning)
        return self.reconstruction