Transformers-源码解析-七-

Transformers 源码解析(七)

.\kernels\deformable_detr\cpu\ms_deform_attn_cpu.cpp

/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

// 包含标准库向量头文件
#include <vector>

// 包含 ATen 库的头文件,提供张量操作
#include <ATen/ATen.h>
// 包含 CUDA 上下文头文件,用于处理 CUDA 相关操作
#include <ATen/cuda/CUDAContext.h>

// 定义 CPU 下的前向传播函数,返回 ATen 张量
at::Tensor
ms_deform_attn_cpu_forward(
    const at::Tensor &value,               // 输入张量 value
    const at::Tensor &spatial_shapes,      // 空间形状张量
    const at::Tensor &level_start_index,   // 层级起始索引张量
    const at::Tensor &sampling_loc,        // 采样位置张量
    const at::Tensor &attn_weight,         // 注意力权重张量
    const int im2col_step)                 // im2col 步长
{
    // 抛出错误,表明在 CPU 上未实现该函数
    AT_ERROR("Not implement on cpu");
}

// 定义 CPU 下的反向传播函数,返回 ATen 张量向量
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
    const at::Tensor &value,               // 输入张量 value
    const at::Tensor &spatial_shapes,      // 空间形状张量
    const at::Tensor &level_start_index,   // 层级起始索引张量
    const at::Tensor &sampling_loc,        // 采样位置张量
    const at::Tensor &attn_weight,         // 注意力权重张量
    const at::Tensor &grad_output,         // 梯度输出张量
    const int im2col_step)                 // im2col 步长
{
    // 抛出错误,表明在 CPU 上未实现该函数
    AT_ERROR("Not implement on cpu");
}

.\kernels\deformable_detr\cpu\ms_deform_attn_cpu.h

/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

// 预处理指令,确保头文件只被包含一次
#pragma once

// 包含 PyTorch C++ 扩展库头文件
#include <torch/extension.h>

// 前向传播函数声明,计算注意力机制的前向传播
at::Tensor
ms_deform_attn_cpu_forward(
    const at::Tensor &value,               // 输入的特征值张量
    const at::Tensor &spatial_shapes,      // 空间形状信息张量
    const at::Tensor &level_start_index,   // 层级起始索引张量
    const at::Tensor &sampling_loc,        // 采样位置张量
    const at::Tensor &attn_weight,         // 注意力权重张量
    const int im2col_step);                // im2col 步长

// 反向传播函数声明,计算注意力机制的反向传播
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
    const at::Tensor &value,               // 输入的特征值张量
    const at::Tensor &spatial_shapes,      // 空间形状信息张量
    const at::Tensor &level_start_index,   // 层级起始索引张量
    const at::Tensor &sampling_loc,        // 采样位置张量
    const at::Tensor &attn_weight,         // 注意力权重张量
    const at::Tensor &grad_output,         // 梯度输出张量
    const int im2col_step);                // im2col 步长

.\kernels\deformable_detr\cuda\ms_deform_attn_cuda.h

/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#pragma once

// 包含 Torch C++ 扩展库的头文件
#include <torch/extension.h>

// 声明 CUDA 前向函数,计算多尺度可变形注意力机制的前向传播
at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value,           // 输入张量:特征图
    const at::Tensor &spatial_shapes,  // 输入张量:空间形状
    const at::Tensor &level_start_index,  // 输入张量:每级起始索引
    const at::Tensor &sampling_loc,    // 输入张量:采样位置
    const at::Tensor &attn_weight,     // 输入张量:注意力权重
    const int im2col_step              // 输入整数:im2col 步骤
);

// 声明 CUDA BF16(BFloat16)前向函数,计算多尺度可变形注意力机制的前向传播
at::Tensor ms_deform_attn_cuda_forward_bf16(
    const at::Tensor &value,           // 输入张量:特征图
    const at::Tensor &spatial_shapes,  // 输入张量:空间形状
    const at::Tensor &level_start_index,  // 输入张量:每级起始索引
    const at::Tensor &sampling_loc,    // 输入张量:采样位置
    const at::Tensor &attn_weight,     // 输入张量:注意力权重
    const int im2col_step              // 输入整数:im2col 步骤
);

// 声明 CUDA 反向函数,计算多尺度可变形注意力机制的反向传播
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
    const at::Tensor &value,           // 输入张量:特征图
    const at::Tensor &spatial_shapes,  // 输入张量:空间形状
    const at::Tensor &level_start_index,  // 输入张量:每级起始索引
    const at::Tensor &sampling_loc,    // 输入张量:采样位置
    const at::Tensor &attn_weight,     // 输入张量:注意力权重
    const at::Tensor &grad_output,     // 输入张量:梯度输出
    const int im2col_step              // 输入整数:im2col 步骤
);

// 声明 CUDA BF16(BFloat16)反向函数,计算多尺度可变形注意力机制的反向传播
std::vector<at::Tensor> ms_deform_attn_cuda_backward_bf16(
    const at::Tensor &value,           // 输入张量:特征图
    const at::Tensor &spatial_shapes,  // 输入张量:空间形状
    const at::Tensor &level_start_index,  // 输入张量:每级起始索引
    const at::Tensor &sampling_loc,    // 输入张量:采样位置
    const at::Tensor &attn_weight,     // 输入张量:注意力权重
    const at::Tensor &grad_output,     // 输入张量:梯度输出
    const int im2col_step              // 输入整数:im2col 步骤
);

.\kernels\deformable_detr\ms_deform_attn.h

/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#pragma once

#include "cpu/ms_deform_attn_cpu.h"

#ifdef WITH_CUDA
#include "cuda/ms_deform_attn_cuda.h"
#endif

// 前向传播函数,处理注意力机制的计算
at::Tensor
ms_deform_attn_forward(
    const at::Tensor &value,                     // 输入张量,表示特征值
    const at::Tensor &spatial_shapes,            // 空间形状信息的张量
    const at::Tensor &level_start_index,         // 层级起始索引
    const at::Tensor &sampling_loc,              // 采样位置
    const at::Tensor &attn_weight,               // 注意力权重
    const int im2col_step)                       // im2col 步长
{
    // 如果输入张量在 GPU 上
    if (value.type().is_cuda())
    {
#ifdef WITH_CUDA
        // 调用 CUDA 实现的前向传播函数
        return ms_deform_attn_cuda_forward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
#else
        // 如果没有编译 GPU 支持,则抛出错误
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    // 如果输入张量在 CPU 上,则抛出未实现 CPU 上的错误
    AT_ERROR("Not implemented on the CPU");
}

// 反向传播函数,处理注意力机制的反向梯度计算
std::vector<at::Tensor>
ms_deform_attn_backward(
    const at::Tensor &value,                     // 输入张量,表示特征值
    const at::Tensor &spatial_shapes,            // 空间形状信息的张量
    const at::Tensor &level_start_index,         // 层级起始索引
    const at::Tensor &sampling_loc,              // 采样位置
    const at::Tensor &attn_weight,               // 注意力权重
    const at::Tensor &grad_output,               // 梯度输出
    const int im2col_step)                       // im2col 步长
{
    // 如果输入张量在 GPU 上
    if (value.type().is_cuda())
    {
#ifdef WITH_CUDA
        // 调用 CUDA 实现的反向传播函数
        return ms_deform_attn_cuda_backward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
#else
        // 如果没有编译 GPU 支持,则抛出错误
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    // 如果输入张量在 CPU 上,则抛出未实现 CPU 上的错误
    AT_ERROR("Not implemented on the CPU");
}

.\kernels\deformable_detr\vision.cpp

/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#include "ms_deform_attn.h"

// 使用 PYBIND11_MODULE 宏定义,将 C++ 函数绑定到 Python 中
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // 定义 Python 可调用函数 ms_deform_attn_forward,对应 C++ 中的 ms_deform_attn_forward 函数
  m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
  // 定义 Python 可调用函数 ms_deform_attn_backward,对应 C++ 中的 ms_deform_attn_backward 函数
  m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}

.\kernels\deta\cpu\ms_deform_attn_cpu.cpp

/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#include <vector>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

// 定义了一个函数,用于在 CPU 上执行 ms_deform_attn 的前向传播
at::Tensor
ms_deform_attn_cpu_forward(
    const at::Tensor &value,                    // 输入张量 value
    const at::Tensor &spatial_shapes,           // 空间形状信息的张量
    const at::Tensor &level_start_index,        // 级别起始索引的张量
    const at::Tensor &sampling_loc,             // 采样位置的张量
    const at::Tensor &attn_weight,              // 注意力权重的张量
    const int im2col_step)                      // im2col 步长参数
{
    // 抛出错误,表示在 CPU 上尚未实现该函数
    AT_ERROR("Not implement on cpu");
}

// 定义了一个函数,用于在 CPU 上执行 ms_deform_attn 的反向传播
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
    const at::Tensor &value,                    // 输入张量 value
    const at::Tensor &spatial_shapes,           // 空间形状信息的张量
    const at::Tensor &level_start_index,        // 级别起始索引的张量
    const at::Tensor &sampling_loc,             // 采样位置的张量
    const at::Tensor &attn_weight,              // 注意力权重的张量
    const at::Tensor &grad_output,              // 梯度输出的张量
    const int im2col_step)                      // im2col 步长参数
{
    // 抛出错误,表示在 CPU 上尚未实现该函数
    AT_ERROR("Not implement on cpu");
}

.\kernels\deta\cpu\ms_deform_attn_cpu.h

/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

// 预处理指令,确保头文件只被包含一次
#pragma once

// 包含 PyTorch C++ 扩展的头文件
#include <torch/extension.h>

// 前向推断函数声明,计算可变形注意力机制的前向传播
at::Tensor
ms_deform_attn_cpu_forward(
    const at::Tensor &value,            // 输入特征张量
    const at::Tensor &spatial_shapes,   // 空间形状信息
    const at::Tensor &level_start_index,// 级别起始索引
    const at::Tensor &sampling_loc,     // 采样位置
    const at::Tensor &attn_weight,      // 注意力权重
    const int im2col_step);             // im2col 步长

// 反向传播函数声明,计算可变形注意力机制的反向传播
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
    const at::Tensor &value,            // 输入特征张量
    const at::Tensor &spatial_shapes,   // 空间形状信息
    const at::Tensor &level_start_index,// 级别起始索引
    const at::Tensor &sampling_loc,     // 采样位置
    const at::Tensor &attn_weight,      // 注意力权重
    const at::Tensor &grad_output,      // 梯度输出
    const int im2col_step);             // im2col 步长


这段代码是一个C++头文件,声明了两个函数 `ms_deform_attn_cpu_forward` 和 `ms_deform_attn_cpu_backward`,用于实现可变形注意力机制的前向传播和反向传播。

.\kernels\deta\cuda\ms_deform_attn_cuda.h

/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

// 包含 Torch C++ 扩展的头文件
#pragma once
#include <torch/extension.h>

// 声明 CUDA 前向传播函数,接受多个张量和整数参数
at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value,                    // 输入特征值张量
    const at::Tensor &spatial_shapes,           // 空间形状信息张量
    const at::Tensor &level_start_index,        // 层级起始索引张量
    const at::Tensor &sampling_loc,             // 采样位置张量
    const at::Tensor &attn_weight,              // 注意力权重张量
    const int im2col_step);                     // im2col 步长整数参数

// 声明 CUDA 反向传播函数,接受多个张量和整数参数,并返回张量向量
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
    const at::Tensor &value,                    // 输入特征值张量
    const at::Tensor &spatial_shapes,           // 空间形状信息张量
    const at::Tensor &level_start_index,        // 层级起始索引张量
    const at::Tensor &sampling_loc,             // 采样位置张量
    const at::Tensor &attn_weight,              // 注意力权重张量
    const at::Tensor &grad_output,              // 梯度输出张量
    const int im2col_step);                     // im2col 步长整数参数

.\kernels\deta\ms_deform_attn.h

/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#pragma once

#include "cpu/ms_deform_attn_cpu.h"

#ifdef WITH_CUDA
#include "cuda/ms_deform_attn_cuda.h"
#endif

// 前向传播函数,用于实现可形变注意力机制的前向计算
at::Tensor
ms_deform_attn_forward(
    const at::Tensor &value,                  // 输入张量 value
    const at::Tensor &spatial_shapes,         // 空间形状信息张量
    const at::Tensor &level_start_index,      // 层级起始索引张量
    const at::Tensor &sampling_loc,           // 采样位置张量
    const at::Tensor &attn_weight,            // 注意力权重张量
    const int im2col_step)                    // im2col 步长参数
{
    // 如果输入张量在 CUDA 上,则调用 CUDA 实现的前向函数
    if (value.type().is_cuda())
    {
#ifdef WITH_CUDA
        return ms_deform_attn_cuda_forward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
#else
        // 如果没有编译 GPU 支持,则抛出错误信息
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    // 如果在 CPU 上调用该函数,则抛出错误信息,表明未实现 CPU 版本
    AT_ERROR("Not implemented on the CPU");
}

// 反向传播函数,用于实现可形变注意力机制的反向计算
std::vector<at::Tensor>
ms_deform_attn_backward(
    const at::Tensor &value,                  // 输入张量 value
    const at::Tensor &spatial_shapes,         // 空间形状信息张量
    const at::Tensor &level_start_index,      // 层级起始索引张量
    const at::Tensor &sampling_loc,           // 采样位置张量
    const at::Tensor &attn_weight,            // 注意力权重张量
    const at::Tensor &grad_output,            // 梯度输出张量
    const int im2col_step)                    // im2col 步长参数
{
    // 如果输入张量在 CUDA 上,则调用 CUDA 实现的反向函数
    if (value.type().is_cuda())
    {
#ifdef WITH_CUDA
        return ms_deform_attn_cuda_backward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
#else
        // 如果没有编译 GPU 支持,则抛出错误信息
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    // 如果在 CPU 上调用该函数,则抛出错误信息,表明未实现 CPU 版本
    AT_ERROR("Not implemented on the CPU");
}

.\kernels\deta\vision.cpp

/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#include "ms_deform_attn.h"

// 使用 Pybind11 构建一个 Python 模块,名字为 TORCH_EXTENSION_NAME
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // 定义 Python 接口函数 ms_deform_attn_forward,与 C++ 函数 ms_deform_attn_forward 绑定
  m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
  // 定义 Python 接口函数 ms_deform_attn_backward,与 C++ 函数 ms_deform_attn_backward 绑定
  m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}

.\kernels\mra\cuda_kernel.h

// 定义线程块大小为32
#define WARP_SIZE 32
// 定义全掩码为32位全1
#define FULL_MASK 0xffffffff
// 定义优化线程数为256
#define OPTIMAL_THREADS 256

// CUDA 核函数,计算每个批次中每个块中的最大值索引和最大值
__global__ void index_max_cuda_kernel(
  float *index_vals,       // [batch_size, 32, num_block]
  int   *indices,          // [batch_size, num_block]
  float *max_vals,         // [batch_size, A_num_block * 32]
  float *max_vals_scatter, // [batch_size, 32, num_block]
  long batch_size,         // 批次大小
  long A_num_block,        // A_num_block
  long B_num_block,        // B_num_block
  long num_block           // num_block
);

// CUDA 核函数,将稠密矩阵乘法结果转换为稀疏格式
__global__ void mm_to_sparse_cuda_kernel(
  float *dense_A,   // [batch_size, A_num_block, dim, 32]
  float *dense_B,   // [batch_size, B_num_block, dim, 32]
  int   *indices,   // [batch_size, num_block]
  float *sparse_C,  // [batch_size, num_block, 32, 32]
  long batch_size,  // 批次大小
  long A_num_block, // A_num_block
  long B_num_block, // B_num_block
  long dim,         // dim
  long num_block    // num_block
);

// CUDA 核函数,稀疏矩阵与稠密矩阵的乘法
__global__ void sparse_dense_mm_cuda_kernel(
  float *sparse_A,  // [batch_size, num_block, 32, 32]
  int   *indices,   // [batch_size, num_block]
  float *dense_B,   // [batch_size, B_num_block, dim, 32]
  float *dense_C,   // [batch_size, A_num_block, dim, 32]
  long batch_size,  // 批次大小
  long A_num_block, // A_num_block
  long B_num_block, // B_num_block
  long dim,         // dim
  long num_block    // num_block
);

// CUDA 核函数,计算稀疏矩阵在指定维度上的和
__global__ void reduce_sum_cuda_kernel(
  float *sparse_A,  // [batch_size, num_block, 32, 32]
  int   *indices,   // [batch_size, num_block]
  float *dense_C,   // [batch_size, A_num_block, 32]
  long batch_size,  // 批次大小
  long A_num_block, // A_num_block
  long B_num_block, // B_num_block
  long num_block    // num_block
);

// CUDA 核函数,将稠密矩阵按索引散布到稀疏矩阵中
__global__ void scatter_cuda_kernel(
  float *dense_A,   // [batch_size, A_num_block, 32]
  int   *indices,   // [batch_size, num_block]
  float *sparse_C,  // [batch_size, num_block, 32, 32]
  long batch_size,  // 批次大小
  long A_num_block, // A_num_block
  long B_num_block, // B_num_block
  long num_block    // num_block
);

.\kernels\mra\cuda_launch.h

# 包含 Torch C++ 扩展的头文件
#include <torch/extension.h>
# 包含 ATen 库的头文件,用于张量操作
#include <ATen/ATen.h>
# 包含 vector 标准库,用于定义和操作动态数组
#include <vector>

# 定义宏函数 min,返回两个数中较小的那个
#define min(a, b) ((a)<(b)?(a):(b))
# 定义宏函数 max,返回两个数中较大的那个
#define max(a, b) ((a)>(b)?(a):(b))

# 声明一个函数,该函数返回一个包含多个张量的 vector
std::vector<at::Tensor> index_max_kernel(
  at::Tensor index_vals,   # 输入参数:索引值的张量
  at::Tensor indices,      # 输入参数:索引的张量
  int A_num_block,         # 输入参数:A 矩阵的块数量
  int B_num_block          # 输入参数:B 矩阵的块数量
);

# 声明一个函数,该函数执行稠密矩阵乘法并返回一个稀疏张量
at::Tensor mm_to_sparse_kernel(
  at::Tensor dense_A,      # 输入参数:稠密矩阵 A
  at::Tensor dense_B,      # 输入参数:稠密矩阵 B
  at::Tensor indices       # 输入参数:索引张量
);

# 声明一个函数,该函数执行稀疏矩阵与稠密矩阵的乘法并返回结果张量
at::Tensor sparse_dense_mm_kernel(
  at::Tensor sparse_A,     # 输入参数:稀疏矩阵 A
  at::Tensor indices,      # 输入参数:索引张量
  at::Tensor dense_B,      # 输入参数:稠密矩阵 B
  int A_num_block          # 输入参数:A 矩阵的块数量
);

# 声明一个函数,该函数执行稀疏矩阵的元素求和操作并返回结果张量
at::Tensor reduce_sum_kernel(
  at::Tensor sparse_A,     # 输入参数:稀疏矩阵 A
  at::Tensor indices,      # 输入参数:索引张量
  int A_num_block,         # 输入参数:A 矩阵的块数量
  int B_num_block          # 输入参数:B 矩阵的块数量
);

# 声明一个函数,该函数执行稠密张量的散布(scatter)操作并返回结果张量
at::Tensor scatter_kernel(
  at::Tensor dense_A,      # 输入参数:稠密张量 A
  at::Tensor indices,      # 输入参数:索引张量
  int B_num_block          # 输入参数:B 矩阵的块数量
);

.\kernels\mra\torch_extension.cpp

#include <torch/extension.h>
#include <ATen/ATen.h>
#include "cuda_launch.h"  // 引入 CUDA 相关的头文件
#include <vector>  // 引入 vector 容器的头文件

std::vector<at::Tensor> index_max(  // 定义函数 index_max,返回一个 Tensor 向量
  at::Tensor index_vals,  // 输入参数 index_vals,类型为 Tensor
  at::Tensor indices,  // 输入参数 indices,类型为 Tensor
  int A_num_block,  // 输入参数 A_num_block,整型
  int B_num_block  // 输入参数 B_num_block,整型
) {
  return index_max_kernel(  // 调用 index_max_kernel 函数,返回其结果
    index_vals,  // 将 index_vals 作为参数传递给 index_max_kernel 函数
    indices,  // 将 indices 作为参数传递给 index_max_kernel 函数
    A_num_block,  // 将 A_num_block 作为参数传递给 index_max_kernel 函数
    B_num_block  // 将 B_num_block 作为参数传递给 index_max_kernel 函数
  );
}

at::Tensor mm_to_sparse(  // 定义函数 mm_to_sparse,返回一个 Tensor
  at::Tensor dense_A,  // 输入参数 dense_A,类型为 Tensor
  at::Tensor dense_B,  // 输入参数 dense_B,类型为 Tensor
  at::Tensor indices  // 输入参数 indices,类型为 Tensor
) {
  return mm_to_sparse_kernel(  // 调用 mm_to_sparse_kernel 函数,返回其结果
    dense_A,  // 将 dense_A 作为参数传递给 mm_to_sparse_kernel 函数
    dense_B,  // 将 dense_B 作为参数传递给 mm_to_sparse_kernel 函数
    indices  // 将 indices 作为参数传递给 mm_to_sparse_kernel 函数
  );
}

at::Tensor sparse_dense_mm(  // 定义函数 sparse_dense_mm,返回一个 Tensor
  at::Tensor sparse_A,  // 输入参数 sparse_A,类型为 Tensor
  at::Tensor indices,  // 输入参数 indices,类型为 Tensor
  at::Tensor dense_B,  // 输入参数 dense_B,类型为 Tensor
  int A_num_block  // 输入参数 A_num_block,整型
) {
  return sparse_dense_mm_kernel(  // 调用 sparse_dense_mm_kernel 函数,返回其结果
    sparse_A,  // 将 sparse_A 作为参数传递给 sparse_dense_mm_kernel 函数
    indices,  // 将 indices 作为参数传递给 sparse_dense_mm_kernel 函数
    dense_B,  // 将 dense_B 作为参数传递给 sparse_dense_mm_kernel 函数
    A_num_block  // 将 A_num_block 作为参数传递给 sparse_dense_mm_kernel 函数
  );
}

at::Tensor reduce_sum(  // 定义函数 reduce_sum,返回一个 Tensor
  at::Tensor sparse_A,  // 输入参数 sparse_A,类型为 Tensor
  at::Tensor indices,  // 输入参数 indices,类型为 Tensor
  int A_num_block,  // 输入参数 A_num_block,整型
  int B_num_block  // 输入参数 B_num_block,整型
) {
  return reduce_sum_kernel(  // 调用 reduce_sum_kernel 函数,返回其结果
    sparse_A,  // 将 sparse_A 作为参数传递给 reduce_sum_kernel 函数
    indices,  // 将 indices 作为参数传递给 reduce_sum_kernel 函数
    A_num_block,  // 将 A_num_block 作为参数传递给 reduce_sum_kernel 函数
    B_num_block  // 将 B_num_block 作为参数传递给 reduce_sum_kernel 函数
  );
}

at::Tensor scatter(  // 定义函数 scatter,返回一个 Tensor
  at::Tensor dense_A,  // 输入参数 dense_A,类型为 Tensor
  at::Tensor indices,  // 输入参数 indices,类型为 Tensor
  int B_num_block  // 输入参数 B_num_block,整型
) {
  return scatter_kernel(  // 调用 scatter_kernel 函数,返回其结果
    dense_A,  // 将 dense_A 作为参数传递给 scatter_kernel 函数
    indices,  // 将 indices 作为参数传递给 scatter_kernel 函数
    B_num_block  // 将 B_num_block 作为参数传递给 scatter_kernel 函数
  );
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {  // 定义 Python 扩展模块
  m.def("index_max", &index_max, "index_max (CUDA)");  // 将 index_max 函数绑定到 Python 中,并指定描述
  m.def("mm_to_sparse", &mm_to_sparse, "mm_to_sparse (CUDA)");  // 将 mm_to_sparse 函数绑定到 Python 中,并指定描述
  m.def("sparse_dense_mm", &sparse_dense_mm, "sparse_dense_mm (CUDA)");  // 将 sparse_dense_mm 函数绑定到 Python 中,并指定描述
  m.def("reduce_sum", &reduce_sum, "reduce_sum (CUDA)");  // 将 reduce_sum 函数绑定到 Python 中,并指定描述
  m.def("scatter", &scatter, "scatter (CUDA)");  // 将 scatter 函数绑定到 Python 中,并指定描述
}

.\kernels\rwkv\wkv_op.cpp

# 包含 Torch 扩展和 ATen 库的头文件
#include <torch/extension.h>
#include "ATen/ATen.h"

# 定义一个别名 bf16 代表 ATen 库中的 BFloat16 类型
typedef at::BFloat16 bf16;

# 声明 CUDA 前向传播函数,接受 float 类型数据
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);

# 声明 CUDA 前向传播函数,接受 BFloat16 类型数据
void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);

# 声明带状态的 CUDA 前向传播函数,接受 float 类型数据
void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s);

# 声明带状态的 CUDA 前向传播函数,接受 BFloat16 类型数据
void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s);

# 声明 CUDA 反向传播函数,接受 float 类型数据
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);

# 声明 CUDA 反向传播函数,接受 BFloat16 类型数据
void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);

# 定义无状态的前向传播函数,接受 Torch 张量参数
void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
    # 获取张量的批量大小 B,时间步长 T,通道数 C
    const int B = k.size(0);
    const int T = k.size(1);
    const int C = k.size(2);
    # 调用 CUDA 前向传播函数,传递 float 类型数据指针
    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}

# 定义无状态的前向传播函数,接受 BFloat16 类型的 Torch 张量参数
void forward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
    # 获取张量的批量大小 B,时间步长 T,通道数 C
    const int B = k.size(0);
    const int T = k.size(1);
    const int C = k.size(2);
    # 调用 CUDA 前向传播函数,传递 BFloat16 类型数据指针
    cuda_forward_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
}

# 定义带状态的前向传播函数,接受 Torch 张量参数及状态张量
void forward_with_state(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) {
    # 获取张量的批量大小 B,时间步长 T,通道数 C
    const int B = k.size(0);
    const int T = k.size(1);
    const int C = k.size(2);
    # 调用带状态的 CUDA 前向传播函数,传递 float 类型数据指针及状态张量数据指针
    cuda_forward_with_state(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), s.data_ptr<float>());
}

# 定义带状态的前向传播函数,接受 BFloat16 类型的 Torch 张量参数及状态张量
void forward_with_state_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) {
    # 获取张量的批量大小 B,时间步长 T,通道数 C
    const int B = k.size(0);
    const int T = k.size(1);
    const int C = k.size(2);
    # 调用带状态的 CUDA 前向传播函数,传递 BFloat16 类型数据指针及状态张量数据指针
    cuda_forward_with_state_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(), s.data_ptr<float>());
}

# 定义反向传播函数,接受 Torch 张量参数
void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
    # 获取张量的批量大小 B,时间步长 T,通道数 C
    const int B = k.size(0);
    const int T = k.size(1);
    const int C = k.size(2);
    # 调用 CUDA 反向传播函数,传递 float 类型数据指针
    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}

# 定义反向传播函数,接受 BFloat16 类型的 Torch 张量参数
void backward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
    # 获取张量的批量大小 B
    const int B = k.size(0);
    # 以下代码与 backward 函数相同,但接受 BFloat16 类型数据
    cuda_backward_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(), gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
}
    # 定义常量 T,表示 k 张量的第一个维度大小
    const int T = k.size(1);
    # 定义常量 C,表示 k 张量的第二个维度大小
    const int C = k.size(2);
    # 调用 CUDA 后向传播函数 cuda_backward_bf16,传递以下参数:
    # - B: 未明确提到,可能是某种批量大小或其它参数
    # - T: k 张量的第一个维度大小
    # - C: k 张量的第二个维度大小
    # - w.data_ptr<float>():权重张量 w 的 float 类型数据指针
    # - u.data_ptr<bf16>():输入张量 u 的 bf16 类型数据指针
    # - k.data_ptr<bf16>():内核张量 k 的 bf16 类型数据指针
    # - v.data_ptr<bf16>():中间变量 v 的 bf16 类型数据指针
    # - y.data_ptr<bf16>():输出张量 y 的 bf16 类型数据指针
    # - gy.data_ptr<bf16>():输出梯度 gy 的 bf16 类型数据指针
    # - gw.data_ptr<bf16>():权重梯度 gw 的 bf16 类型数据指针
    # - gu.data_ptr<bf16>():输入梯度 gu 的 bf16 类型数据指针
    # - gk.data_ptr<bf16>():内核梯度 gk 的 bf16 类型数据指针
    # - gv.data_ptr<bf16>():中间变量梯度 gv 的 bf16 类型数据指针
    cuda_backward_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
        gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // 定义 Python 绑定模块的函数 "forward",与 C++ 函数 &forward 绑定,描述为 "wkv forward"
    m.def("forward", &forward, "wkv forward");
    // 定义 Python 绑定模块的函数 "forward_bf16",与 C++ 函数 &forward_bf16 绑定,描述为 "wkv forward bf16"
    m.def("forward_bf16", &forward_bf16, "wkv forward bf16");
    // 定义 Python 绑定模块的函数 "forward_with_state",与 C++ 函数 &forward_with_state 绑定,描述为 "wkv forward with state"
    m.def("forward_with_state", &forward_with_state, "wkv forward with state");
    // 定义 Python 绑定模块的函数 "forward_with_state_bf16",与 C++ 函数 &forward_with_state_bf16 绑定,描述为 "wkv forward with state bf16"
    m.def("forward_with_state_bf16", &forward_with_state_bf16, "wkv forward with state bf16");
    // 定义 Python 绑定模块的函数 "backward",与 C++ 函数 &backward 绑定,描述为 "wkv backward"
    m.def("backward", &backward, "wkv backward");
    // 定义 Python 绑定模块的函数 "backward_bf16",与 C++ 函数 &backward_bf16 绑定,描述为 "wkv backward bf16"
    m.def("backward_bf16", &backward_bf16, "wkv backward bf16");
}

TORCH_LIBRARY(wkv, m) {
    // 在 Torch 的 wkv 库中注册函数 "forward",与 C++ 函数 forward 绑定
    m.def("forward", forward);
    // 在 Torch 的 wkv 库中注册函数 "forward_bf16",与 C++ 函数 forward_bf16 绑定
    m.def("forward_bf16", forward_bf16);
    // 在 Torch 的 wkv 库中注册函数 "forward_with_state",与 C++ 函数 forward_with_state 绑定
    m.def("forward_with_state", forward_with_state);
    // 在 Torch 的 wkv 库中注册函数 "forward_with_state_bf16",与 C++ 函数 forward_with_state_bf16 绑定
    m.def("forward_with_state_bf16", forward_with_state_bf16);
    // 在 Torch 的 wkv 库中注册函数 "backward",与 C++ 函数 backward 绑定
    m.def("backward", backward);
    // 在 Torch 的 wkv 库中注册函数 "backward_bf16",与 C++ 函数 backward_bf16 绑定
    m.def("backward_bf16", backward_bf16);
}

.\kernels\yoso\common.h

# 定义宏函数,返回两个数中较小的一个
#define min(a, b) ((a)<(b)?(a):(b))

# 定义宏函数,返回两个数中较大的一个
#define max(a, b) ((a)>(b)?(a):(b))

# 定义宏函数,对两个数进行向上取整的除法运算
#define ceil_divide(a, b) ((a)/(b)+((a)%(b)!=0))

# 定义宏函数,根据条件选择返回其中一个值
#define select(cond, a, b) ((cond)?(a):(b))

# 定义常数 PI,表示圆周率
#define PI 3.141592

# 定义常数 EPSILON,表示一个小的正数,通常用于浮点数比较的容差
#define EPSILON 1e-8

# 定义常数 MAX_VAL,表示一个较大的数值上限
#define MAX_VAL 1e12

# 定义常数 MIN_VAL,表示一个较小的数值下限
#define MIN_VAL -1e12

# 定义常数 EMPTY_VALUE,表示一个特定的空值或未初始化值
#define EMPTY_VALUE -1

.\kernels\yoso\common_cuda.h

# 定义每个线程块中的最大线程数
#define MAX_THREADS_PER_BLOCK 1024

# 定义优化后的推荐线程块中的线程数
#define OPTIMAL_THREADS_PER_BLOCK 256

# 定义线程束(warp)的大小
#define WARP_SIZE 32

# 定义在X方向上的最大线程块数量
#define MAX_NUM_BLOCK_X 2147483647

# 定义在Y方向上的最大线程块数量
#define MAX_NUM_BLOCK_Y 65535

# 定义在Z方向上的最大线程块数量
#define MAX_NUM_BLOCK_Z 65535

# 定义每个线程块可用的最大共享内存量
#define MAX_SHARED_MEM_PER_BLOCK 48000

# 定义一个掩码,包含所有位都是1
#define FULL_MASK 0xffffffff

.\kernels\yoso\common_cuda_device.h

# 包含公共头文件,这里假设 "common.h" 包含了项目中的通用定义和声明
#include "common.h"

# 定义一个模板函数 set_insert,用于向集合中插入元素
template<typename T>
__device__ int set_insert(T *set, int set_size, T value) {
  # 计算值在集合中的插入位置
  int slot = value % set_size;
  int start_slot = slot;
  # 循环尝试插入值,直到成功或者集合已满
  while (true) {
    # 使用原子操作 CAS(Compare and Swap),尝试在集合中插入值
    T prev = atomicCAS(&set[slot], EMPTY_VALUE, value);
    # 如果插入成功或者集合中已经存在相同值,则返回插入位置
    if (prev == EMPTY_VALUE || prev == value) {
      return slot;
    }
    # 如果插入失败,则尝试下一个位置
    slot = (slot + 1) % set_size;
    # 如果回到起始位置,表示集合已满
    if (slot == start_slot) {
      return -1;
    }
  }
  return -1;
}

# 定义一个模板函数 set_lookup,用于在集合中查找元素的位置
template<typename T>
__device__ int set_lookup(T *set, int set_size, T value) {
  # 计算值在集合中的起始位置
  int slot = value % set_size;
  int start_slot = slot;
  # 循环查找值,直到找到或者集合遍历完毕
  while (true) {
    # 如果当前位置的值等于要查找的值,则返回该位置
    if (set[slot] == value) {
      return slot;
    }
    # 否则尝试下一个位置
    slot = (slot + 1) % set_size;
    # 如果回到起始位置,表示值不在集合中
    if (slot == start_slot) {
      return -1;
    }
  }
  return -1;
}

# 定义一个模板函数 init_buffer,用于初始化缓冲区
template<typename T>
__device__ void init_buffer(T init_value, T *buffer, int buffer_size, int num_threads, int thread_id) {
  # 同步所有线程,确保前面的操作已完成
  __syncthreads();
  # 循环初始化缓冲区
  for (int i = 0; i < buffer_size; i = i + num_threads) {
    int offset_idx = i + thread_id;
    # 如果当前线程需要处理有效的索引,则初始化缓冲区值
    if (offset_idx < buffer_size) {
      buffer[offset_idx] = init_value;
    }
  }
  # 再次同步所有线程,确保所有初始化操作已完成
  __syncthreads();
}

# 定义一个模板函数 copy_data,用于从源地址复制数据到目标地址
template<typename T>
__device__ void copy_data(T *src_pt, T *dist_pt, int data_length, int num_threads, int thread_id) {
  # 同步所有线程,确保前面的操作已完成
  __syncthreads();
  # 循环复制数据
  for (int i = 0; i < data_length; i = i + num_threads) {
    int offset_idx = i + thread_id;
    # 如果当前线程需要处理有效的索引,则复制数据
    if (offset_idx < data_length) {
      dist_pt[offset_idx] = src_pt[offset_idx];
    }
  }
  # 再次同步所有线程,确保所有复制操作已完成
  __syncthreads();
}

# 定义一个模板函数 init_buffer_nonblocking,用于非阻塞方式初始化缓冲区
template<typename T>
__device__ void init_buffer_nonblocking(T init_value, T *buffer, int buffer_size, int num_threads, int thread_id) {
  # 循环初始化缓冲区,无需同步线程
  for (int i = 0; i < buffer_size; i = i + num_threads) {
    int offset_idx = i + thread_id;
    # 如果当前线程需要处理有效的索引,则初始化缓冲区值
    if (offset_idx < buffer_size) {
      buffer[offset_idx] = init_value;
    }
  }
}

# 定义一个模板函数 copy_data_nonblocking,用于非阻塞方式从源地址复制数据到目标地址
template<typename T>
__device__ void copy_data_nonblocking(T *src_pt, T *dist_pt, int data_length, int num_threads, int thread_id) {
  # 循环复制数据,无需同步线程
  for (int i = 0; i < data_length; i = i + num_threads) {
    int offset_idx = i + thread_id;
    # 如果当前线程需要处理有效的索引,则复制数据
    if (offset_idx < data_length) {
      dist_pt[offset_idx] = src_pt[offset_idx];
    }
  }
}

.\kernels\yoso\fast_lsh_cumulation.h

// 导入 PyTorch C++ 扩展头文件
#include <torch/extension.h>
// 导入 ATen 库的头文件
#include <ATen/ATen.h>
// 导入 STL 中的 vector 容器
#include <vector>

// 定义快速哈希(版本1)的核函数,返回多个张量作为结果
std::vector<at::Tensor> fast_hash_ver1_kernel(
  // 查询掩码张量
  at::Tensor query_mask,
  // 查询向量张量
  at::Tensor query_vector,
  // 关键字掩码张量
  at::Tensor key_mask,
  // 关键字向量张量
  at::Tensor key_vector,
  // 哈希函数数量
  int num_hash_f,
  // 哈希码长度
  int hash_code_len,
  // 是否使用 CUDA
  bool use_cuda
);

// 定义哈希累积(版本1)的核函数,返回张量作为结果
at::Tensor lsh_cumulation_ver1_kernel(
  // 查询掩码张量
  at::Tensor query_mask,
  // 查询哈希码张量
  at::Tensor query_hash_code,
  // 关键字掩码张量
  at::Tensor key_mask,
  // 关键字哈希码张量
  at::Tensor key_hash_code,
  // 值张量
  at::Tensor value,
  // 哈希表容量
  int hashtable_capacity,
  // 是否使用 CUDA
  bool use_cuda
);

// 定义加权哈希累积(版本1)的核函数,返回张量作为结果
at::Tensor lsh_weighted_cumulation_ver1_kernel(
  // 查询掩码张量
  at::Tensor query_mask,
  // 查询哈希码张量
  at::Tensor query_hash_code,
  // 查询权重张量
  at::Tensor query_weight,
  // 关键字掩码张量
  at::Tensor key_mask,
  // 关键字哈希码张量
  at::Tensor key_hash_code,
  // 关键字权重张量
  at::Tensor key_weight,
  // 值张量
  at::Tensor value,
  // 哈希表容量
  int hashtable_capacity,
  // 是否使用 CUDA
  bool use_cuda
);

// 定义加权哈希累积(版本2、3、4)的核函数,具体功能与版本1类似
// 只是版本号不同,参数及返回值的张量类型与数量相同,不再重复注释每个版本的功能
at::Tensor lsh_weighted_cumulation_ver2_kernel(
  at::Tensor query_mask,
  at::Tensor query_hash_code,
  at::Tensor query_weight,
  at::Tensor key_mask,
  at::Tensor key_hash_code,
  at::Tensor key_weight,
  at::Tensor value,
  int hashtable_capacity,
  bool use_cuda
);

at::Tensor lsh_weighted_cumulation_ver3_kernel(
  at::Tensor query_mask,
  at::Tensor query_hash_code,
  at::Tensor query_weight,
  at::Tensor key_mask,
  at::Tensor key_hash_code,
  at::Tensor key_weight,
  at::Tensor value,
  int hashtable_capacity,
  bool use_cuda
);

at::Tensor lsh_weighted_cumulation_ver4_kernel(
  at::Tensor query_mask,
  at::Tensor query_hash_code,
  at::Tensor query_weight,
  at::Tensor key_mask,
  at::Tensor key_hash_code,
  at::Tensor key_weight,
  at::Tensor value,
  int hashtable_capacity,
  bool use_cuda
);

.\kernels\yoso\fast_lsh_cumulation_cuda.h

__global__ void fast_hash_ver1_cuda_kernel(
  int *mask,        // [batch_size, num_vector],用于存储掩码数据的整数指针
  float *vector,    // [batch_size, num_vector, vector_dim],存储向量数据的浮点数指针
  int *Dmat,        // [3, num_part, vector_dim],存储分割矩阵数据的整数指针
  int *hash_code,   // [batch_size, num_vector, num_hash_f],存储哈希码数据的整数指针
  int batch_size,   // 批处理大小,整数参数
  int num_vector,   // 向量数量,整数参数
  int vector_dim,   // 向量维度,整数参数
  int num_part,     // 分割数,整数参数
  int num_hash_f,   // 哈希函数数量,整数参数
  int hash_code_len // 哈希码长度,整数参数
);

__global__ void lsh_cumulation_ver1_step1_cuda_kernel(
  int *key_mask,           // [batch_size, num_key],用于存储键掩码数据的整数指针
  int *key_hash_code,      // [batch_size, num_key, num_hash_f],存储键哈希码数据的整数指针
  float *value,            // [batch_size, num_key, value_dim],存储值数据的浮点数指针
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, value_dim],哈希表值的浮点数指针
  int batch_size,          // 批处理大小,整数参数
  int num_hash_f,          // 哈希函数数量,整数参数
  int hashtable_capacity,  // 哈希表容量,整数参数
  int num_key,             // 键数量,整数参数
  int value_dim,           // 值维度,整数参数
  int offset_warp          // 偏移量(warp),整数参数
);

__global__ void lsh_cumulation_ver1_step2_cuda_kernel(
  int *query_mask,         // [batch_size, num_query],用于存储查询掩码数据的整数指针
  int *query_hash_code,    // [batch_size, num_query, num_hash_f],存储查询哈希码数据的整数指针
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, value_dim],哈希表值的浮点数指针
  float *cumulation_value, // [batch_size, num_query, value_dim],累积值的浮点数指针
  int batch_size,          // 批处理大小,整数参数
  int num_hash_f,          // 哈希函数数量,整数参数
  int hashtable_capacity,  // 哈希表容量,整数参数
  int num_query,           // 查询数量,整数参数
  int value_dim,           // 值维度,整数参数
  int offset_warp          // 偏移量(warp),整数参数
);

__global__ void lsh_weighted_cumulation_ver1_step1_cuda_kernel(
  int *key_mask,            // [batch_size, num_key],用于存储键掩码数据的整数指针
  int *key_hash_code,       // [batch_size, num_key, num_hash_f],存储键哈希码数据的整数指针
  float *key_weight,        // [batch_size, num_key, weight_dim],存储键权重数据的浮点数指针
  float *value,             // [batch_size, num_key, value_dim],存储值数据的浮点数指针
  float *hashtable_value,   // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE],哈希表值的浮点数指针
  int batch_size,           // 批处理大小,整数参数
  int num_hash_f,           // 哈希函数数量,整数参数
  int hashtable_capacity,   // 哈希表容量,整数参数
  int num_key,              // 键数量,整数参数
  int value_dim,            // 值维度,整数参数
  int weight_dim,           // 权重维度,整数参数
  int offset_warp,          // 偏移量(warp),整数参数
  int weight_idx            // 权重索引,整数参数
);

__global__ void lsh_weighted_cumulation_ver1_step2_cuda_kernel(
  int *query_mask,          // [batch_size, num_query],用于存储查询掩码数据的整数指针
  int *query_hash_code,     // [batch_size, num_query, num_hash_f],存储查询哈希码数据的整数指针
  float *query_weight,      // [batch_size, num_query, weight_dim],存储查询权重数据的浮点数指针
  float *hashtable_value,   // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE],哈希表值的浮点数指针
  float *cumulation_value,  // [batch_size, num_query, value_dim],累积值的浮点数指针
  int batch_size,           // 批处理大小,整数参数
  int num_hash_f,           // 哈希函数数量,整数参数
  int hashtable_capacity,   // 哈希表容量,整数参数
  int num_query,            // 查询数量,整数参数
  int value_dim,            // 值维度,整数参数
  int weight_dim,           // 权重维度,整数参数
  int offset_warp,          // 偏移量(warp),整数参数
  int weight_idx            // 权重索引,整数参数
);

__global__ void count_sort_step1_cuda_kernel(
  int *key_mask,         // [batch_size, num_key],用于存储键掩码数据的整数指针
  int *key_hash_code,    // [batch_size, num_key, num_hash_f],存储键哈希码数据的整数指针
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity],计数排序表的整数指针
  int batch_size,        // 批处理大小,整数参数
  int num_hash_f,        // 哈希函数数量,整数参数
  int hashtable_capacity,// 哈希表容量,整数参数
  int num_key            // 键数量,整数参数
);

__global__ void count_sort_step2_cuda_kernel(
  int *count_sort_table,  // [batch_size, num_hash_f, hashtable_capacity],计数排序表的整数指针
  int batch_size,         // 批处理大小,整数参数
  int num_hash_f,         // 哈希函数数量,整数参数
  int hashtable_capacity  // 哈希表容量,整数参数
);
__global__ void count_sort_step3_cuda_kernel(
  int *key_mask,          // 输入:表示批次中每个关键字的掩码数组 [batch_size, num_key]
  int *key_hash_code,     // 输入:表示批次中每个关键字的哈希码数组 [batch_size, num_key, num_hash_f]
  int *count_sort_table,  // 输入/输出:计数排序表格,用于存储排序后的关键字索引 [batch_size, num_hash_f, hashtable_capacity]
  int *key_sorted_idxes,  // 输出:存储排序后的关键字索引 [batch_size, num_hash_f, num_key]
  int batch_size,         // 输入:批次大小
  int num_hash_f,         // 输入:哈希函数数量
  int hashtable_capacity, // 输入:哈希表容量
  int num_key             // 输入:每个批次中的关键字数量
);

__global__ void extract_query_info_cuda_kernel(
  int *query_mask,       // 输入:表示批次中每个查询的掩码数组 [batch_size, num_query]
  int *query_hash_code,  // 输入:表示批次中每个查询的哈希码数组 [batch_size, num_query, num_hash_f]
  int *count_sort_table, // 输入:计数排序表格,用于存储排序后的关键字索引 [batch_size, num_hash_f, hashtable_capacity]
  int *query_info,       // 输出:存储查询信息,包括关键字索引和哈希函数索引 [batch_size, num_query, 2, num_hash_f]
  int batch_size,        // 输入:批次大小
  int num_hash_f,        // 输入:哈希函数数量
  int hashtable_capacity,// 输入:哈希表容量
  int num_query          // 输入:每个批次中的查询数量
);

__global__ void lsh_weighted_cumulation_ver2_step2_cuda_kernel(
  int *query_mask,         // 输入:表示批次中每个查询的掩码数组 [batch_size, num_query]
  int *query_info,         // 输入:存储查询信息,包括关键字索引和哈希函数索引 [batch_size, num_query, 2, num_hash_f]
  int *key_sorted_idxes,   // 输入:存储排序后的关键字索引 [batch_size, num_hash_f, num_key]
  float *query_weight,     // 输入:查询的权重数组 [batch_size, num_query, weight_dim]
  float *key_weight,       // 输入:关键字的权重数组 [batch_size, num_key, weight_dim]
  float *value,            // 输入:关键字对应的值数组 [batch_size, num_key, value_dim]
  float *cumulation_value, // 输出:累积后的值数组 [batch_size, num_query, value_dim]
  int batch_size,          // 输入:批次大小
  int num_hash_f,          // 输入:哈希函数数量
  int num_query,           // 输入:每个批次中的查询数量
  int num_key,             // 输入:每个批次中的关键字数量
  int value_dim,           // 输入:值的维度
  int weight_dim           // 输入:权重的维度
);

__global__ void lsh_weighted_cumulation_ver3_step2_cuda_kernel(
  int *query_sorted_idxes,   // 输入:存储排序后的查询索引 [batch_size, num_hash_f, num_query]
  int *key_mask,             // 输入:表示批次中每个关键字的掩码数组 [batch_size, num_key]
  int *key_info,             // 输入:关键字的信息数组,包括索引和哈希函数索引 [batch_size, num_key, 2, num_hash_f]
  float *query_weight,       // 输入:查询的权重数组 [batch_size, num_query, weight_dim]
  float *key_weight,         // 输入:关键字的权重数组 [batch_size, num_key, weight_dim]
  float *value,              // 输入:关键字对应的值数组 [batch_size, num_key, value_dim]
  float *cumulation_value,   // 输出:累积后的值数组 [batch_size, num_query, value_dim]
  int batch_size,            // 输入:批次大小
  int num_hash_f,            // 输入:哈希函数数量
  int num_query,             // 输入:每个批次中的查询数量
  int num_key,               // 输入:每个批次中的关键字数量
  int value_dim,             // 输入:值的维度
  int weight_dim             // 输入:权重的维度
);

__global__ void lsh_weighted_cumulation_ver4_step2_cuda_kernel(
  int *query_sorted_idxes,   // 输入:存储排序后的查询索引 [batch_size, num_hash_f, num_query]
  int *key_mask,             // 输入:表示批次中每个关键字的掩码数组 [batch_size, num_key]
  int *key_info,             // 输入:关键字的信息数组,包括索引和哈希函数索引 [batch_size, num_key, 2, num_hash_f]
  float *query_weight,       // 输入:查询的权重数组 [batch_size, num_query, weight_dim]
  float *key_weight,         // 输入:关键字的权重数组 [batch_size, num_key, weight_dim]
  float *value,              // 输入:关键字对应的值数组 [batch_size, num_key, value_dim]
  float *cumulation_value,   // 输出:累积后的值数组 [batch_size, num_query, value_dim]
  int batch_size,            // 输入:批次大小
  int num_hash_f,            // 输入:哈希函数数量
  int num_query,             // 输入:每个批次中的查询数量
  int num_key,               // 输入:每个批次中的关键字数量
  int value_dim,             // 输入:值的维度
  int weight_dim             // 输入:权重的维度
);

.\kernels\yoso\fast_lsh_cumulation_torch.cpp

#include <torch/extension.h>
#include <ATen/ATen.h>
#include "fast_lsh_cumulation.h"  // 引入自定义的头文件,包含快速LSH累积相关的函数声明
#include "common_cuda.h"           // 引入自定义的头文件,包含通用的CUDA函数声明
#include <vector>                  // 引入标准库中的向量容器

// 快速哈希函数,调用指定版本的核函数处理哈希计算
std::vector<at::Tensor> fast_hash(
  at::Tensor query_mask,      // 查询掩码,形状为[batch_size, num_query]
  at::Tensor query_vector,    // 查询向量,形状为[batch_size, num_query, vector_dim]
  at::Tensor key_mask,        // 键掩码,形状为[batch_size, num_key]
  at::Tensor key_vector,      // 键向量,形状为[batch_size, num_key, vector_dim]
  int num_hash_f,             // 哈希函数数量
  int hash_code_len,          // 哈希码长度
  bool use_cuda,              // 是否使用CUDA加速
  int version                 // 函数版本号
) {
  return fast_hash_ver1_kernel(
    query_mask,
    query_vector,
    key_mask,
    key_vector,
    num_hash_f,
    hash_code_len,
    use_cuda
  );
}

// LSH累积函数,调用指定版本的核函数执行LSH累积操作
at::Tensor lsh_cumulation(
  at::Tensor query_mask,         // 查询掩码,形状为[batch_size, num_query]
  at::Tensor query_hash_code,    // 查询哈希码,形状为[batch_size, num_query, num_hash_f]
  at::Tensor key_mask,           // 键掩码,形状为[batch_size, num_key]
  at::Tensor key_hash_code,      // 键哈希码,形状为[batch_size, num_key, num_hash_f]
  at::Tensor value,              // 值,形状为[batch_size, num_key, value_dim]
  int hashtable_capacity,        // 哈希表容量
  bool use_cuda,                 // 是否使用CUDA加速
  int version                    // 函数版本号
) {
  return lsh_cumulation_ver1_kernel(
    query_mask,
    query_hash_code,
    key_mask,
    key_hash_code,
    value,
    hashtable_capacity,
    use_cuda
  );
}

// 加权LSH累积函数,根据版本号调用不同的核函数执行不同版本的加权LSH累积操作
at::Tensor lsh_weighted_cumulation(
  at::Tensor query_mask,         // 查询掩码,形状为[batch_size, num_query]
  at::Tensor query_hash_code,    // 查询哈希码,形状为[batch_size, num_query, num_hash_f]
  at::Tensor query_weight,       // 查询权重,形状为[batch_size, num_query, weight_dim]
  at::Tensor key_mask,           // 键掩码,形状为[batch_size, num_key]
  at::Tensor key_hash_code,      // 键哈希码,形状为[batch_size, num_key, num_hash_f]
  at::Tensor key_weight,         // 键权重,形状为[batch_size, num_key, weight_dim]
  at::Tensor value,              // 值,形状为[batch_size, num_key, value_dim]
  int hashtable_capacity,        // 哈希表容量
  bool use_cuda,                 // 是否使用CUDA加速
  int version                    // 函数版本号
) {
  if (version == 1) {
    return lsh_weighted_cumulation_ver1_kernel(
      query_mask,
      query_hash_code,
      query_weight,
      key_mask,
      key_hash_code,
      key_weight,
      value,
      hashtable_capacity,
      use_cuda
    );
  } else if (version == 2) {
    return lsh_weighted_cumulation_ver2_kernel(
      query_mask,
      query_hash_code,
      query_weight,
      key_mask,
      key_hash_code,
      key_weight,
      value,
      hashtable_capacity,
      use_cuda
    );
  } else if (version == 3) {
    return lsh_weighted_cumulation_ver3_kernel(
      query_mask,
      query_hash_code,
      query_weight,
      key_mask,
      key_hash_code,
      key_weight,
      value,
      hashtable_capacity,
      use_cuda
    );
  } else if (version == 4) {
    return lsh_weighted_cumulation_ver4_kernel(
      query_mask,
      query_hash_code,
      query_weight,
      key_mask,
      key_hash_code,
      key_weight,
      value,
      hashtable_capacity,
      use_cuda
    );
  } else {
    // 默认情况下使用第三个版本的核函数
    return lsh_weighted_cumulation_ver3_kernel(
      query_mask,
      query_hash_code,
      query_weight,
      key_mask,
      key_hash_code,
      key_weight,
      value,
      hashtable_capacity,
      use_cuda
    );
  }
}
# 使用 PYBIND11_MODULE 宏定义一个 Python 模块
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  # 将 fast_hash 函数绑定到 Python 模块中,并命名为 "Fast Hash (CUDA)"
  m.def("fast_hash", &fast_hash, "Fast Hash (CUDA)");
  # 将 lsh_cumulation 函数绑定到 Python 模块中,并命名为 "LSH Cumulation (CUDA)"
  m.def("lsh_cumulation", &lsh_cumulation, "LSH Cumulation (CUDA)");
  # 将 lsh_weighted_cumulation 函数绑定到 Python 模块中,并命名为 "LSH Weighted Cumulation (CUDA)"
  m.def("lsh_weighted_cumulation", &lsh_weighted_cumulation, "LSH Weighted Cumulation (CUDA)");
}

.\modelcard.py

# 导入所需的模块和库
import copy  # 导入深拷贝函数
import json  # 导入处理 JSON 的库
import os  # 导入操作系统相关的功能
import warnings  # 导入警告处理模块
from dataclasses import dataclass  # 导入 dataclass 用于创建数据类
from pathlib import Path  # 导入 Path 类用于处理文件路径
from typing import Any, Dict, List, Optional, Union  # 导入类型提示相关功能

import requests  # 导入处理 HTTP 请求的库
import yaml  # 导入处理 YAML 文件的库
from huggingface_hub import model_info  # 导入 Hugging Face Hub 的模型信息功能
from huggingface_hub.utils import HFValidationError  # 导入 Hugging Face Hub 的验证错误处理

from . import __version__  # 导入当前包的版本信息
from .models.auto.modeling_auto import (  # 导入自动生成模型的相关映射名称
    MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_CTC_MAPPING_NAMES,
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
    MODEL_FOR_MASKED_LM_MAPPING_NAMES,
    MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
    MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
)
from .training_args import ParallelMode  # 导入并行模式参数
from .utils import (  # 导入工具函数和常量
    MODEL_CARD_NAME,
    cached_file,
    is_datasets_available,
    is_offline_mode,
    is_tf_available,
    is_tokenizers_available,
    is_torch_available,
    logging,
)


TASK_MAPPING = {  # 定义任务与模型映射关系的字典
    "text-generation": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
    "image-segmentation": MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
    "fill-mask": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
    "object-detection": MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
    "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
    "text2text-generation": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
    "text-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
    "table-question-answering": MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
    "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
    "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
    "automatic-speech-recognition": {**MODEL_FOR_CTC_MAPPING_NAMES, **MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES},
    "zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
}

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器
    # 定义结构化的模型卡片类。存储模型卡片以及加载/下载/保存模型卡片的方法。

    # 请阅读以下论文以获取关于各部分的详细信息和解释:“Model Cards for Model Reporting” 作者包括 Margaret Mitchell, Simone Wu, Andrew Zaldivar, 等人提出了模型卡片的建议。链接:https://arxiv.org/abs/1810.03993

    # 注意:可以加载和保存模型卡片到磁盘上。
    """
    
    # 初始化方法,用于创建模型卡片对象
    def __init__(self, **kwargs):
        # 发出警告,表示该类 `ModelCard` 已被弃用,并将在 Transformers 的第五版中移除
        warnings.warn(
            "The class `ModelCard` is deprecated and will be removed in version 5 of Transformers", FutureWarning
        )
        # 推荐的属性来源于 https://arxiv.org/abs/1810.03993(见论文)
        # 设置模型细节
        self.model_details = kwargs.pop("model_details", {})
        # 设置预期使用
        self.intended_use = kwargs.pop("intended_use", {})
        # 设置因素
        self.factors = kwargs.pop("factors", {})
        # 设置度量
        self.metrics = kwargs.pop("metrics", {})
        # 设置评估数据
        self.evaluation_data = kwargs.pop("evaluation_data", {})
        # 设置训练数据
        self.training_data = kwargs.pop("training_data", {})
        # 设置定量分析
        self.quantitative_analyses = kwargs.pop("quantitative_analyses", {})
        # 设置伦理考虑
        self.ethical_considerations = kwargs.pop("ethical_considerations", {})
        # 设置注意事项和建议
        self.caveats_and_recommendations = kwargs.pop("caveats_and_recommendations", {})

        # 打开额外的属性
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                # 如果无法设置属性,则记录错误信息并抛出异常
                logger.error(f"Can't set {key} with value {value} for {self}")
                raise err

    # 将模型卡片对象保存到指定的目录或文件
    def save_pretrained(self, save_directory_or_file):
        """Save a model card object to the directory or file `save_directory_or_file`."""
        # 如果保存目录存在,则使用预定义的文件名保存,方便使用 `from_pretrained` 加载
        if os.path.isdir(save_directory_or_file):
            output_model_card_file = os.path.join(save_directory_or_file, MODEL_CARD_NAME)
        else:
            output_model_card_file = save_directory_or_file

        # 将模型卡片对象保存为 JSON 文件
        self.to_json_file(output_model_card_file)
        logger.info(f"Model card saved in {output_model_card_file}")

    # 从 Python 字典中构造一个 `ModelCard` 对象的类方法
    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `ModelCard` from a Python dictionary of parameters."""
        return cls(**json_object)

    # 从 JSON 文件中构造一个 `ModelCard` 对象的类方法
    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `ModelCard` from a json file of parameters."""
        # 读取 JSON 文件内容
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        # 解析 JSON 文本为 Python 字典对象
        dict_obj = json.loads(text)
        # 使用字典对象构造一个新的 `ModelCard` 对象
        return cls(**dict_obj)

    # 判断两个 `ModelCard` 对象是否相等的特殊方法
    def __eq__(self, other):
        return self.__dict__ == other.__dict__

    # 返回 `ModelCard` 对象的字符串表示形式的特殊方法
    def __repr__(self):
        return str(self.to_json_string())
    # 将当前对象实例序列化为一个 Python 字典
    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        # 深拷贝当前对象的所有属性到 output 字典中
        output = copy.deepcopy(self.__dict__)
        return output

    # 将当前对象实例序列化为 JSON 字符串
    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        # 调用 to_dict 方法获取对象的字典表示,转换为带缩进和排序键的 JSON 字符串,并添加换行符
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    # 将当前对象实例保存到一个 JSON 文件中
    def to_json_file(self, json_file_path):
        """Save this instance to a json file."""
        # 打开指定路径的 JSON 文件,使用 UTF-8 编码写入对象的 JSON 字符串表示
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string())
AUTOGENERATED_TRAINER_COMMENT = """
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->
"""

AUTOGENERATED_KERAS_COMMENT = """
<!-- This model card has been generated automatically according to the information Keras had access to. You should
probably proofread and complete it, then remove this comment. -->
"""


TASK_TAG_TO_NAME_MAPPING = {
    "fill-mask": "Masked Language Modeling",  # 映射任务标签 "fill-mask" 到任务名称 "Masked Language Modeling"
    "image-classification": "Image Classification",  # 映射任务标签 "image-classification" 到任务名称 "Image Classification"
    "image-segmentation": "Image Segmentation",  # 映射任务标签 "image-segmentation" 到任务名称 "Image Segmentation"
    "multiple-choice": "Multiple Choice",  # 映射任务标签 "multiple-choice" 到任务名称 "Multiple Choice"
    "object-detection": "Object Detection",  # 映射任务标签 "object-detection" 到任务名称 "Object Detection"
    "question-answering": "Question Answering",  # 映射任务标签 "question-answering" 到任务名称 "Question Answering"
    "summarization": "Summarization",  # 映射任务标签 "summarization" 到任务名称 "Summarization"
    "table-question-answering": "Table Question Answering",  # 映射任务标签 "table-question-answering" 到任务名称 "Table Question Answering"
    "text-classification": "Text Classification",  # 映射任务标签 "text-classification" 到任务名称 "Text Classification"
    "text-generation": "Causal Language Modeling",  # 映射任务标签 "text-generation" 到任务名称 "Causal Language Modeling"
    "text2text-generation": "Sequence-to-sequence Language Modeling",  # 映射任务标签 "text2text-generation" 到任务名称 "Sequence-to-sequence Language Modeling"
    "token-classification": "Token Classification",  # 映射任务标签 "token-classification" 到任务名称 "Token Classification"
    "translation": "Translation",  # 映射任务标签 "translation" 到任务名称 "Translation"
    "zero-shot-classification": "Zero Shot Classification",  # 映射任务标签 "zero-shot-classification" 到任务名称 "Zero Shot Classification"
    "automatic-speech-recognition": "Automatic Speech Recognition",  # 映射任务标签 "automatic-speech-recognition" 到任务名称 "Automatic Speech Recognition"
    "audio-classification": "Audio Classification",  # 映射任务标签 "audio-classification" 到任务名称 "Audio Classification"
}


METRIC_TAGS = [
    "accuracy",  # 表示度量标签 "accuracy",用于评估模型准确性
    "bleu",  # 表示度量标签 "bleu",用于评估机器翻译质量
    "f1",  # 表示度量标签 "f1",用于评估分类和信息检索等任务的准确性
    "matthews_correlation",  # 表示度量标签 "matthews_correlation",用于评估二分类问题中的相关性
    "pearsonr",  # 表示度量标签 "pearsonr",用于评估两个变量之间的线性相关性
    "precision",  # 表示度量标签 "precision",用于评估分类模型中的精确性
    "recall",  # 表示度量标签 "recall",用于评估分类模型中的召回率
    "rouge",  # 表示度量标签 "rouge",用于评估文本摘要生成模型的质量
    "sacrebleu",  # 表示度量标签 "sacrebleu",用于机器翻译任务中的 BLEU 得分
    "spearmanr",  # 表示度量标签 "spearmanr",用于评估两个变量的非线性相关性
    "wer",  # 表示度量标签 "wer",用于评估自动语音识别中的词错误率
]


def _listify(obj):
    if obj is None:
        return []  # 如果对象为 None,则返回空列表
    elif isinstance(obj, str):
        return [obj]  # 如果对象为字符串,则返回包含该字符串的列表
    else:
        return obj  # 否则返回原始对象


def _insert_values_as_list(metadata, name, values):
    if values is None:
        return metadata  # 如果值为 None,则返回元数据本身
    if isinstance(values, str):
        values = [values]  # 如果值为字符串,则转换成单元素列表
    values = [v for v in values if v is not None]  # 过滤掉值中的 None 元素
    if len(values) == 0:
        return metadata  # 如果列表为空,则返回元数据本身
    metadata[name] = values  # 将处理后的列表赋给元数据对应的名称
    return metadata  # 返回更新后的元数据


def infer_metric_tags_from_eval_results(eval_results):
    if eval_results is None:
        return {}  # 如果评估结果为 None,则返回空字典
    result = {}  # 初始化结果字典
    for key in eval_results.keys():
        if key.lower().replace(" ", "_") in METRIC_TAGS:
            result[key.lower().replace(" ", "_")] = key  # 将符合度量标签的键添加到结果字典中
        elif key.lower() == "rouge1":
            result["rouge"] = key  # 特别处理 "rouge1",将其映射为 "rouge"
    return result  # 返回最终的结果字典


def _insert_value(metadata, name, value):
    if value is None:
        return metadata  # 如果值为 None,则返回元数据本身
    metadata[name] = value  # 将值插入到元数据中对应的名称
    return metadata  # 返回更新后的元数据


def is_hf_dataset(dataset):
    if not is_datasets_available():
        return False  # 如果 datasets 库不可用,则返回 False

    from datasets import Dataset, IterableDataset

    return isinstance(dataset, (Dataset, IterableDataset))  # 判断 dataset 是否是 Dataset 或 IterableDataset 类的实例


def _get_mapping_values(mapping):
    result = []  # 初始化结果列表
    for v in mapping.values():
        if isinstance(v, (tuple, list)):
            result += list(v)  # 如果值是元组或列表,则将其展开并添加到结果列表中
        else:
            result.append(v)  # 否则直接添加到结果列表中
    return result  # 返回所有映射值组成的列表


@dataclass
class TrainingSummary:
    model_name: str  # 模型名称
    language: Optional[Union[str, List[str]]] = None  # 语言属性,可以是字符串或字符串列表,默认为 None
    license: Optional[str] = None  # 许可证信息,默认为 None
"""
    # 标签,可以是字符串或字符串列表,用于标识模型的类别或特征
    tags: Optional[Union[str, List[str]]] = None
    # 微调自哪个模型而来的信息
    finetuned_from: Optional[str] = None
    # 任务,可以是字符串或字符串列表,描述模型训练的任务类型
    tasks: Optional[Union[str, List[str]]] = None
    # 数据集,可以是字符串或字符串列表,指定用于训练的数据集名称或描述
    dataset: Optional[Union[str, List[str]]] = None
    # 数据集标签,可以是字符串或字符串列表,用于描述数据集的特征或类别
    dataset_tags: Optional[Union[str, List[str]]] = None
    # 数据集参数,可以是字符串或字符串列表,指定数据集的详细参数
    dataset_args: Optional[Union[str, List[str]]] = None
    # 数据集元数据,是一个字典,包含关于数据集的其他信息
    dataset_metadata: Optional[Dict[str, Any]] = None
    # 评估结果,是一个字典,包含模型评估的指标和结果
    eval_results: Optional[Dict[str, float]] = None
    # 评估结果的行信息,是一个字符串列表,记录评估结果的详细信息
    eval_lines: Optional[List[str]] = None
    # 超参数,是一个字典,包含模型训练时使用的超参数信息
    hyperparameters: Optional[Dict[str, Any]] = None
    # 模型的来源,通常为字符串 "trainer"
    source: Optional[str] = "trainer"

    def __post_init__(self):
        # 根据微调自的模型信息推断默认许可证
        if (
            self.license is None  # 如果许可证为空
            and not is_offline_mode()  # 并且不是离线模式
            and self.finetuned_from is not None  # 并且有微调自的模型信息
            and len(self.finetuned_from) > 0  # 并且微调自的模型信息不为空字符串
        ):
            try:
                # 获取微调自模型的信息
                info = model_info(self.finetuned_from)
                # 遍历模型信息的标签
                for tag in info.tags:
                    # 如果标签以 "license:" 开头
                    if tag.startswith("license:"):
                        # 设置许可证为标签中 "license:" 后的内容
                        self.license = tag[8:]
            except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError, HFValidationError):
                # 处理可能的网络请求错误或验证错误
                pass
    def create_model_index(self, metric_mapping):
        # 初始化模型索引,包含模型名称
        model_index = {"name": self.model_name}

        # 将数据集相关信息转换为列表形式
        dataset_names = _listify(self.dataset)
        dataset_tags = _listify(self.dataset_tags)
        dataset_args = _listify(self.dataset_args)
        dataset_metadata = _listify(self.dataset_metadata)

        # 如果参数数量不足,则用 None 补齐
        if len(dataset_args) < len(dataset_tags):
            dataset_args = dataset_args + [None] * (len(dataset_tags) - len(dataset_args))

        # 创建数据集映射字典,将标签映射到名称
        dataset_mapping = dict(zip(dataset_tags, dataset_names))
        dataset_arg_mapping = dict(zip(dataset_tags, dataset_args))
        dataset_metadata_mapping = dict(zip(dataset_tags, dataset_metadata))

        # 创建任务映射字典,将任务标签映射到任务名称
        task_mapping = {
            task: TASK_TAG_TO_NAME_MAPPING[task] for task in _listify(self.tasks) if task in TASK_TAG_TO_NAME_MAPPING
        }

        # 初始化结果列表
        model_index["results"] = []

        # 如果任务映射和数据集映射都为空,则返回只包含模型名称的列表
        if len(task_mapping) == 0 and len(dataset_mapping) == 0:
            return [model_index]

        # 如果任务映射为空,则将其设置为包含 None 的字典
        if len(task_mapping) == 0:
            task_mapping = {None: None}

        # 如果数据集映射为空,则将其设置为包含 None 的字典
        if len(dataset_mapping) == 0:
            dataset_mapping = {None: None}

        # 遍历所有可能的任务和数据集组合
        all_possibilities = [(task_tag, ds_tag) for task_tag in task_mapping for ds_tag in dataset_mapping]
        for task_tag, ds_tag in all_possibilities:
            result = {}

            # 如果任务标签不为空,则设置任务名称和类型
            if task_tag is not None:
                result["task"] = {"name": task_mapping[task_tag], "type": task_tag}

            # 如果数据集标签不为空,则设置数据集名称、类型以及元数据
            if ds_tag is not None:
                metadata = dataset_metadata_mapping.get(ds_tag, {})
                result["dataset"] = {
                    "name": dataset_mapping[ds_tag],
                    "type": ds_tag,
                    **metadata,
                }
                # 如果数据集参数不为空,则设置参数
                if dataset_arg_mapping[ds_tag] is not None:
                    result["dataset"]["args"] = dataset_arg_mapping[ds_tag]

            # 如果度量映射不为空,则设置度量结果
            if len(metric_mapping) > 0:
                result["metrics"] = []
                for metric_tag, metric_name in metric_mapping.items():
                    result["metrics"].append(
                        {
                            "name": metric_name,
                            "type": metric_tag,
                            "value": self.eval_results[metric_name],
                        }
                    )

            # 如果结果中包含任务、数据集和度量,则将结果添加到模型索引中
            if "task" in result and "dataset" in result and "metrics" in result:
                model_index["results"].append(result)
            else:
                # 否则,记录日志并丢弃结果以避免模型卡片被拒绝
                logger.info(f"Dropping the following result as it does not have all the necessary fields:\n{result}")

        # 返回包含模型索引的列表
        return [model_index]
    # 创建元数据的方法,用于生成模型相关的元数据信息
    def create_metadata(self):
        # 从评估结果推断度量标签的映射关系
        metric_mapping = infer_metric_tags_from_eval_results(self.eval_results)

        # 初始化一个空的元数据字典
        metadata = {}

        # 将语言信息插入元数据字典,作为列表形式存储
        metadata = _insert_values_as_list(metadata, "language", self.language)
        
        # 将许可证信息插入元数据字典,作为单一值存储
        metadata = _insert_value(metadata, "license", self.license)
        
        # 如果模型是从某个基础模型微调而来,且基础模型为非空字符串,则插入基础模型信息
        if self.finetuned_from is not None and isinstance(self.finetuned_from, str) and len(self.finetuned_from) > 0:
            metadata = _insert_value(metadata, "base_model", self.finetuned_from)
        
        # 将标签信息插入元数据字典,作为列表形式存储
        metadata = _insert_values_as_list(metadata, "tags", self.tags)
        
        # 将数据集标签信息插入元数据字典,作为列表形式存储
        metadata = _insert_values_as_list(metadata, "datasets", self.dataset_tags)
        
        # 将度量标签映射中的键(度量名称)作为列表插入元数据字典
        metadata = _insert_values_as_list(metadata, "metrics", list(metric_mapping.keys()))
        
        # 创建模型索引并插入元数据字典中
        metadata["model-index"] = self.create_model_index(metric_mapping)

        # 返回生成的元数据字典
        return metadata

    @classmethod
    def from_trainer(
        cls,
        trainer,
        language=None,
        license=None,
        tags=None,
        model_name=None,
        finetuned_from=None,
        tasks=None,
        dataset_tags=None,
        dataset_metadata=None,
        dataset=None,
        dataset_args=None,
    ):
        # 推断默认数据集
        one_dataset = trainer.eval_dataset if trainer.eval_dataset is not None else trainer.train_dataset
        # 如果数据集来自 HF 数据集且缺少标签、参数或元数据,则推断默认标签
        if is_hf_dataset(one_dataset) and (dataset_tags is None or dataset_args is None or dataset_metadata is None):
            default_tag = one_dataset.builder_name
            # 排除不是来自 Hub 的虚构数据集
            if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
                # 如果缺少元数据,则创建包含配置名和分割信息的元数据列表
                if dataset_metadata is None:
                    dataset_metadata = [{"config": one_dataset.config_name, "split": str(one_dataset.split)}]
                # 如果缺少标签,则使用默认标签
                if dataset_tags is None:
                    dataset_tags = [default_tag]
                # 如果缺少参数,则使用配置名作为参数
                if dataset_args is None:
                    dataset_args = [one_dataset.config_name]

        # 如果未指定数据集但指定了数据集标签,则将数据集设为数据集标签
        if dataset is None and dataset_tags is not None:
            dataset = dataset_tags

        # 推断默认微调自
        if (
            finetuned_from is None
            and hasattr(trainer.model.config, "_name_or_path")
            and not os.path.isdir(trainer.model.config._name_or_path)
        ):
            # 使用模型配置的名称或路径作为微调来源
            finetuned_from = trainer.model.config._name_or_path

        # 推断默认任务标签
        if tasks is None:
            model_class_name = trainer.model.__class__.__name__
            # 遍历任务映射表,根据模型类名获取任务标签
            for task, mapping in TASK_MAPPING.items():
                if model_class_name in _get_mapping_values(mapping):
                    tasks = task

        # 如果未指定模型名称,则使用输出目录的名称作为模型名称
        if model_name is None:
            model_name = Path(trainer.args.output_dir).name
        # 如果模型名称为空字符串,则使用微调来源作为模型名称
        if len(model_name) == 0:
            model_name = finetuned_from

        # 将 `generated_from_trainer` 添加到标签中
        if tags is None:
            tags = ["generated_from_trainer"]
        elif isinstance(tags, str) and tags != "generated_from_trainer":
            tags = [tags, "generated_from_trainer"]
        elif "generated_from_trainer" not in tags:
            tags.append("generated_from_trainer")

        # 解析训练状态日志历史,获取日志行和评估结果
        _, eval_lines, eval_results = parse_log_history(trainer.state.log_history)
        # 从训练器中提取超参数
        hyperparameters = extract_hyperparameters_from_trainer(trainer)

        # 返回构造的类对象,初始化各个参数
        return cls(
            language=language,
            license=license,
            tags=tags,
            model_name=model_name,
            finetuned_from=finetuned_from,
            tasks=tasks,
            dataset=dataset,
            dataset_tags=dataset_tags,
            dataset_args=dataset_args,
            dataset_metadata=dataset_metadata,
            eval_results=eval_results,
            eval_lines=eval_lines,
            hyperparameters=hyperparameters,
        )

    @classmethod
    def from_keras(
        cls,
        model,
        model_name,
        keras_history=None,
        language=None,
        license=None,
        tags=None,
        finetuned_from=None,
        tasks=None,
        dataset_tags=None,
        dataset=None,
        dataset_args=None,
        # 接受以下参数并返回新的 HFModelArguments 对象
        ):
        # 如果给定了 dataset 参数:
        if dataset is not None:
            # 如果 dataset 是 HF dataset 并且 dataset_tags 或 dataset_args 为 None:
            if is_hf_dataset(dataset) and (dataset_tags is None or dataset_args is None):
                # 使用 dataset 的构建器名称作为默认标签
                default_tag = dataset.builder_name
                # 排除不是来自 Hub 的虚构数据集
                if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
                    # 如果 dataset_tags 为 None,则设为默认标签列表
                    if dataset_tags is None:
                        dataset_tags = [default_tag]
                    # 如果 dataset_args 为 None,则设为 dataset 的配置名称列表
                    if dataset_args is None:
                        dataset_args = [dataset.config_name]

        # 如果 dataset 为 None 而 dataset_tags 不为 None,则将 dataset 设置为 dataset_tags
        if dataset is None and dataset_tags is not None:
            dataset = dataset_tags

        # 推断默认的 finetuned_from
        if (
            finetuned_from is None
            and hasattr(model.config, "_name_or_path")
            and not os.path.isdir(model.config._name_or_path)
        ):
            # 使用 model.config 的 _name_or_path 属性作为 finetuned_from
            finetuned_from = model.config._name_or_path

        # 推断默认的任务标签:
        if tasks is None:
            # 获取 model 的类名
            model_class_name = model.__class__.__name__
            # 遍历 TASK_MAPPING 中的任务映射
            for task, mapping in TASK_MAPPING.items():
                # 如果 model_class_name 在映射值中
                if model_class_name in _get_mapping_values(mapping):
                    # 设置任务为当前的 task

                    Add ` generated_from_keras_callback to
# 解析 `logs` 参数,该参数可以是 `model.fit()` 返回的 `keras.History` 对象,也可以是传递给 `PushToHubCallback` 的累积日志字典
def parse_keras_history(logs):
    if hasattr(logs, "history"):
        # 如果 `logs` 对象有 `history` 属性,则看起来像是一个 `History` 对象
        if not hasattr(logs, "epoch"):
            # 如果 `logs` 对象没有 `epoch` 属性,表示历史记录为空,返回空结果
            return None, [], {}
        # 将 `epoch` 属性添加到 `logs.history` 字典中
        logs.history["epoch"] = logs.epoch
        # 使用 `logs.history` 替换 `logs`,统一处理为字典格式
        logs = logs.history
    else:
        # 如果 `logs` 不是 `History` 对象,则假设它是一个包含字典列表的训练日志,我们将其转换为字典的列表格式,以匹配 `History` 对象的结构
        logs = {log_key: [single_dict[log_key] for single_dict in logs] for log_key in logs[0]}

    # 初始化空列表 `lines`,用于存储解析后的日志信息
    lines = []
    # 遍历 `epoch` 列表的长度,即遍历每个周期的日志
    for i in range(len(logs["epoch"])):
        # 创建当前周期的字典 `epoch_dict`,将每个日志键值对应到当前周期的值
        epoch_dict = {log_key: log_value_list[i] for log_key, log_value_list in logs.items()}
        # 初始化空字典 `values`,用于存储当前周期的解析后的键值对
        values = {}
        # 遍历 `epoch_dict` 中的每个键值对
        for k, v in epoch_dict.items():
            if k.startswith("val_"):
                # 如果键以 "val_" 开头,将其改为以 "validation_" 开头
                k = "validation_" + k[4:]
            elif k != "epoch":
                # 如果键不是 "epoch",则将其改为以 "train_" 开头
                k = "train_" + k
            # 将键名按照下划线分割后,每个部分首字母大写,形成更友好的名称
            splits = k.split("_")
            name = " ".join([part.capitalize() for part in splits])
            # 将处理后的键值对加入到 `values` 字典中
            values[name] = v
        # 将当前周期解析后的字典 `values` 添加到 `lines` 列表中
        lines.append(values)

    # 提取评估结果,即最后一个周期解析后的结果
    eval_results = lines[-1]

    # 返回原始日志字典、解析后的周期信息列表 `lines` 和评估结果 `eval_results`
    return logs, lines, eval_results


# 解析 `log_history` 参数,获取 `Trainer` 的中间和最终评估结果
def parse_log_history(log_history):
    # 初始化索引 `idx`,从头开始查找直到找到包含 "train_runtime" 的日志条目
    idx = 0
    while idx < len(log_history) and "train_runtime" not in log_history[idx]:
        idx += 1

    # 如果没有训练日志
    if idx == len(log_history):
        # 将索引减一,从最后一个日志向前查找包含 "eval_loss" 的日志条目
        idx -= 1
        while idx >= 0 and "eval_loss" not in log_history[idx]:
            idx -= 1

        # 如果找到了包含 "eval_loss" 的日志条目,则返回 `None`、`None` 和该日志条目
        if idx >= 0:
            return None, None, log_history[idx]
        else:
            # 如果没有找到包含 "eval_loss" 的日志条目,则返回三个 `None`
            return None, None, None

    # 现在我们可以假设存在训练日志
    # 获取训练日志 `train_log`,即包含 "train_runtime" 的日志条目
    train_log = log_history[idx]
    # 初始化空列表 `lines`,用于存储解析后的日志信息
    lines = []
    # 初始化训练损失为 "No log"
    training_loss = "No log"
    # 遍历日志历史记录中索引范围内的每一个索引 i
    for i in range(idx):
        # 如果当前索引 i 的日志记录包含 "loss" 键
        if "loss" in log_history[i]:
            # 将训练损失记录下来
            training_loss = log_history[i]["loss"]
        
        # 如果当前索引 i 的日志记录包含 "eval_loss" 键
        if "eval_loss" in log_history[i]:
            # 复制当前日志记录中的所有项到 metrics 字典中
            metrics = log_history[i].copy()
            # 移除不需要的项目
            _ = metrics.pop("total_flos", None)
            epoch = metrics.pop("epoch", None)
            step = metrics.pop("step", None)
            _ = metrics.pop("eval_runtime", None)
            _ = metrics.pop("eval_samples_per_second", None)
            _ = metrics.pop("eval_steps_per_second", None)
            _ = metrics.pop("eval_jit_compilation_time", None)
            
            # 初始化一个空字典 values,用于存储需要的指标
            values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step}
            
            # 遍历 metrics 字典中的每一项
            for k, v in metrics.items():
                # 如果当前项的键是 "eval_loss"
                if k == "eval_loss":
                    # 将其值存入 values 字典中作为 "Validation Loss"
                    values["Validation Loss"] = v
                else:
                    # 如果键不是 "eval_loss",将键按照下划线分割为列表
                    splits = k.split("_")
                    # 将分割后的每个部分的首字母大写,并连接起来作为指标名称
                    name = " ".join([part.capitalize() for part in splits[1:]])
                    # 将该指标的值存入 values 字典中
                    values[name] = v
            
            # 将 values 字典存入 lines 列表中
            lines.append(values)

    # 将 idx 设置为日志历史记录的长度减一
    idx = len(log_history) - 1
    
    # 当 idx 大于等于 0 且日志历史记录中索引为 idx 的项不包含 "eval_loss" 键时循环
    while idx >= 0 and "eval_loss" not in log_history[idx]:
        # 减小 idx 的值
        idx -= 1

    # 如果 idx 大于 0
    if idx > 0:
        # 初始化一个空字典 eval_results,用于存储评估结果
        eval_results = {}
        
        # 遍历日志历史记录中索引为 idx 的项的每一个键值对
        for key, value in log_history[idx].items():
            # 如果键以 "eval_" 开头,去除开头的 "eval_"
            if key.startswith("eval_"):
                key = key[5:]
            # 如果键不是 ["runtime", "samples_per_second", "steps_per_second", "epoch", "step"] 中的一员
            if key not in ["runtime", "samples_per_second", "steps_per_second", "epoch", "step"]:
                # 将键按照下划线分割为列表,每个部分首字母大写,并连接起来作为新键
                camel_cased_key = " ".join([part.capitalize() for part in key.split("_")])
                # 将该键及其对应的值存入 eval_results 字典中
                eval_results[camel_cased_key] = value
        
        # 返回训练日志 train_log,行列表 lines,以及评估结果 eval_results
        return train_log, lines, eval_results
    else:
        # 如果 idx 不大于 0,则返回训练日志 train_log,行列表 lines,以及空的评估结果
        return train_log, lines, None
def extract_hyperparameters_from_keras(model):
    # 导入 keras 模块中的函数和类
    from .modeling_tf_utils import keras

    # 创建一个空字典用于存储超参数
    hyperparameters = {}

    # 检查模型是否具有优化器,并且获取其配置信息
    if hasattr(model, "optimizer") and model.optimizer is not None:
        hyperparameters["optimizer"] = model.optimizer.get_config()
    else:
        hyperparameters["optimizer"] = None

    # 获取全局训练精度策略的名称
    hyperparameters["training_precision"] = keras.mixed_precision.global_policy().name

    # 返回提取的超参数字典
    return hyperparameters


def _maybe_round(v, decimals=4):
    # 如果 v 是浮点数且有小数部分超过指定的小数位数,则返回按小数位数四舍五入后的字符串
    if isinstance(v, float) and len(str(v).split(".")) > 1 and len(str(v).split(".")[1]) > decimals:
        return f"{v:.{decimals}f}"
    # 否则返回 v 的字符串形式
    return str(v)


def _regular_table_line(values, col_widths):
    # 生成 Markdown 表格的一行,包括表格的普通行格式
    values_with_space = [f"| {v}" + " " * (w - len(v) + 1) for v, w in zip(values, col_widths)]
    return "".join(values_with_space) + "|\n"


def _second_table_line(col_widths):
    # 生成 Markdown 表格的第二行,包括表头和数据之间的分隔线格式
    values = ["|:" + "-" * w + ":" for w in col_widths]
    return "".join(values) + "|\n"


def make_markdown_table(lines):
    """
    Create a nice Markdown table from the results in `lines`.
    """
    # 如果 lines 为空或者 None,则返回空字符串
    if lines is None or len(lines) == 0:
        return ""

    # 初始化列宽字典,计算每列的最大宽度
    col_widths = {key: len(str(key)) for key in lines[0].keys()}
    for line in lines:
        for key, value in line.items():
            if col_widths[key] < len(_maybe_round(value)):
                col_widths[key] = len(_maybe_round(value))

    # 构建 Markdown 表格的内容
    table = _regular_table_line(list(lines[0].keys()), list(col_widths.values()))
    table += _second_table_line(list(col_widths.values()))
    for line in lines:
        table += _regular_table_line([_maybe_round(v) for v in line.values()], list(col_widths.values()))
    return table


_TRAINING_ARGS_KEYS = [
    "learning_rate",
    "train_batch_size",
    "eval_batch_size",
    "seed",
]


def extract_hyperparameters_from_trainer(trainer):
    # 从训练器对象中提取超参数,使用预定义的训练参数键
    hyperparameters = {k: getattr(trainer.args, k) for k in _TRAINING_ARGS_KEYS}

    # 如果并行模式不是单GPU模式或非分布式模式,则添加分布式类型
    if trainer.args.parallel_mode not in [ParallelMode.NOT_PARALLEL, ParallelMode.NOT_DISTRIBUTED]:
        hyperparameters["distributed_type"] = (
            "multi-GPU" if trainer.args.parallel_mode == ParallelMode.DISTRIBUTED else trainer.args.parallel_mode.value
        )

    # 如果使用多个设备进行训练,则添加设备数量
    if trainer.args.world_size > 1:
        hyperparameters["num_devices"] = trainer.args.world_size

    # 如果梯度累积步数大于1,则添加梯度累积步数
    if trainer.args.gradient_accumulation_steps > 1:
        hyperparameters["gradient_accumulation_steps"] = trainer.args.gradient_accumulation_steps

    # 计算总的训练批次大小,如果不等于预定义的训练批次大小,则添加总训练批次大小
    total_train_batch_size = (
        trainer.args.train_batch_size * trainer.args.world_size * trainer.args.gradient_accumulation_steps
    )
    if total_train_batch_size != hyperparameters["train_batch_size"]:
        hyperparameters["total_train_batch_size"] = total_train_batch_size

    # 计算总的评估批次大小,如果不等于预定义的评估批次大小,则添加总评估批次大小
    total_eval_batch_size = trainer.args.eval_batch_size * trainer.args.world_size
    if total_eval_batch_size != hyperparameters["eval_batch_size"]:
        hyperparameters["total_eval_batch_size"] = total_eval_batch_size
    # 如果训练器的参数中指定了使用 Adafactor 优化器
    if trainer.args.adafactor:
        # 将超参数中的优化器设置为 Adafactor
        hyperparameters["optimizer"] = "Adafactor"
    else:
        # 否则,使用带有指定参数的 Adam 优化器
        hyperparameters["optimizer"] = (
            f"Adam with betas=({trainer.args.adam_beta1},{trainer.args.adam_beta2}) and"
            f" epsilon={trainer.args.adam_epsilon}"
        )

    # 设置学习率调度器的类型为训练器参数中指定的值
    hyperparameters["lr_scheduler_type"] = trainer.args.lr_scheduler_type.value
    
    # 如果训练器参数中指定了非零的预热比例
    if trainer.args.warmup_ratio != 0.0:
        # 将预热比例设置到学习率调度器的预热比例参数中
        hyperparameters["lr_scheduler_warmup_ratio"] = trainer.args.warmup_ratio
    
    # 如果训练器参数中指定了非零的预热步数
    if trainer.args.warmup_steps != 0.0:
        # 将预热步数设置到学习率调度器的预热步数参数中
        hyperparameters["lr_scheduler_warmup_steps"] = trainer.args.warmup_steps
    
    # 如果训练器参数中指定了最大步数不等于 -1
    if trainer.args.max_steps != -1:
        # 将最大步数设置到超参数的训练步数参数中
        hyperparameters["training_steps"] = trainer.args.max_steps
    else:
        # 否则,将训练轮数设置到超参数的训练轮数参数中
        hyperparameters["num_epochs"] = trainer.args.num_train_epochs

    # 如果训练器参数中指定了使用混合精度训练
    if trainer.args.fp16:
        # 如果使用了 Apex 框架
        if trainer.use_apex:
            # 将混合精度训练设置为 Apex,并包括指定的优化级别
            hyperparameters["mixed_precision_training"] = f"Apex, opt level {trainer.args.fp16_opt_level}"
        else:
            # 否则,将混合精度训练设置为本地 AMP 支持
            hyperparameters["mixed_precision_training"] = "Native AMP"

    # 如果训练器参数中指定了标签平滑因子不等于 0.0
    if trainer.args.label_smoothing_factor != 0.0:
        # 将标签平滑因子设置到超参数的标签平滑因子参数中
        hyperparameters["label_smoothing_factor"] = trainer.args.label_smoothing_factor

    # 返回设置好的超参数字典
    return hyperparameters

.\modeling_attn_mask_utils.py

# 引入必要的库
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

# 导入PyTorch库
import torch

# 定义一个数据类,用于处理注意力掩码转换的实用工具
@dataclass
class AttentionMaskConverter:
    """
    A utility attention mask class that allows one to:
        - Create a causal 4d mask
        - Create a causal 4d mask with slided window
        - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
          key_value_length) that can be multiplied with attention scores

    Examples:

    ```
    >>> import torch
    >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter

    >>> converter = AttentionMaskConverter(True)
    >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
    tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
            [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
            [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
            [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00, -3.4028e+38],
            [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00,  0.0000e+00]]]])
    ```

    Parameters:
        is_causal (`bool`):
            Whether the attention mask should be a uni-directional (causal) or bi-directional mask.

        sliding_window (`int`, *optional*):
            Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
    """

    # 定义类的属性,控制注意力掩码的类型和参数
    is_causal: bool
    sliding_window: int

    # 类的初始化方法,设定初始参数并验证滑动窗口参数的有效性
    def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
        self.is_causal = is_causal
        self.sliding_window = sliding_window

        # 如果定义了滑动窗口参数,确保其为正整数
        if self.sliding_window is not None and self.sliding_window <= 0:
            raise ValueError(
                f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
            )

    # 方法:生成单向(因果)的四维注意力掩码
    def to_causal_4d(
        self,
        batch_size: int,
        query_length: int,
        key_value_length: int,
        dtype: torch.dtype,
        device: Union[torch.device, "str"] = "cpu",
    ) -> Optional[torch.Tensor]:
        """
        Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
        bias to upper right hand triangular matrix (causal mask).
        """
        # 如果不是因果关系,抛出数值错误异常
        if not self.is_causal:
            raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")

        # 如果形状未被缓存,创建一个新的因果遮罩并缓存它
        input_shape = (batch_size, query_length)
        past_key_values_length = key_value_length - query_length

        # 创建因果遮罩
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        causal_4d_mask = None
        if input_shape[-1] > 1 or self.sliding_window is not None:
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                device=device,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )

        return causal_4d_mask

    def to_4d(
        self,
        attention_mask_2d: torch.Tensor,
        query_length: int,
        dtype: torch.dtype,
        key_value_length: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
        key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
        causal, a causal mask will be added.
        """
        # 计算输入形状,即(batch_size, query_length)
        input_shape = (attention_mask_2d.shape[0], query_length)

        # 创建因果(mask)
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        causal_4d_mask = None
        if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
            if key_value_length is None:
                raise ValueError(
                    "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
                )

            # 计算过去键值长度
            past_key_values_length = key_value_length - query_length
            # 生成因果(mask)
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                device=attention_mask_2d.device,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )
        elif self.sliding_window is not None:
            # 抛出未实现错误,滑动窗口目前仅支持因果掩蔽
            raise NotImplementedError("Sliding window is currently only implemented for causal masking")

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        # 扩展注意力(mask)
        expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
            attention_mask_2d.device
        )

        # 如果存在因果(mask),则用大负数填充未注意的位置
        if causal_4d_mask is not None:
            expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)

        # 扩展后的注意力(mask)可能会导致溢出
        expanded_4d_mask = expanded_attn_mask

        return expanded_4d_mask

    @staticmethod
    def _make_causal_mask(
        input_ids_shape: torch.Size,
        dtype: torch.dtype,
        device: torch.device,
        past_key_values_length: int = 0,
        sliding_window: Optional[int] = None,
        """
        Make causal mask used for bi-directional self-attention.
        """
        # 获取输入的张量形状信息,包括批大小和目标序列长度
        bsz, tgt_len = input_ids_shape

        # 创建一个与目标长度相同的方形矩阵,用极小的浮点数填充,设备为指定的设备
        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)

        # 创建一个条件张量,范围是目标长度的整数序列
        mask_cond = torch.arange(mask.size(-1), device=device)

        # 使用条件张量来生成一个下三角矩阵,将其对角线上的元素保持为0
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

        # 将掩码转换为指定的数据类型
        mask = mask.to(dtype)

        # 如果过去键值长度大于0,将0填充的过去键值长度张量连接到现有掩码之前
        if past_key_values_length > 0:
            mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)

        # 如果需要,添加下三角滑动窗口掩码
        if sliding_window is not None:
            # 计算对角线的值,用于下三角滑动窗口掩码
            diagonal = past_key_values_length - sliding_window - 1

            # 创建一个下三角矩阵,掩盖大于对角线值的元素
            context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)

            # 使用极小的浮点数填充掩码中被标记的位置
            mask.masked_fill_(context_mask, torch.finfo(dtype).min)

        # 将掩码扩展为四维张量:[bsz, 1, tgt_len, tgt_len + past_key_values_length]
        return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

    @staticmethod
    def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
        """
        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
        """
        # 获取输入掩码的形状信息,包括批大小和源序列长度
        bsz, src_len = mask.size()

        # 如果未提供目标序列长度,则使用源序列长度作为目标序列长度
        tgt_len = tgt_len if tgt_len is not None else src_len

        # 将二维掩码扩展为四维张量,增加两个额外的维度,将其转换为指定的数据类型
        expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

        # 创建一个反掩码,用1减去扩展的掩码
        inverted_mask = 1.0 - expanded_mask

        # 使用极小的浮点数填充反掩码中被标记的位置
        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

    @staticmethod
    def _unmask_unattended(
        expanded_mask: torch.FloatTensor,
        min_dtype: float,
        device: Optional[torch.device] = None
    ):
        """
        Unmasks the unattended positions in the attention matrix.
        """
        # fmt: off
        """
        Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
        using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        Details: https://github.com/pytorch/pytorch/issues/110213

        `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
        `attention_mask` is [bsz, src_seq_len].

        The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.

        For example, if `expanded_mask` is (e.g. here left-padding case)
        ```
        [[[[0, 0, 0],
           [0, 0, 0],
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[0, 0, 0],
           [0, 1, 0],
           [0, 1, 1]]]]
        ```
        then the modified `expanded_mask` will be
        ```
        [[[[1, 1, 1],   <-- modified
           [1, 1, 1],   <-- modified
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[1, 1, 1],   <-- modified
           [0, 1, 0],
           [0, 1, 1]]]]
        """
        # fmt: on
        # 检查 expanded_mask 的数据类型是否为 torch.bool,若是则抛出 ValueError
        if expanded_mask.dtype == torch.bool:
            raise ValueError(
                "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
            )

        # 返回一个修改过的 expanded_mask,其中未被完全掩盖的行将保持不变,其他行将被置为零
        return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
    # 创建一个用于 SDPA(scaled_dot_product_attention)的 4D 因果注意力掩码,形状为 `(batch_size, 1, query_length, key_value_length)`
    # 从形状为 `(batch_size, key_value_length)` 的 2D 注意力掩码创建
def _prepare_4d_causal_attention_mask(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`

    Args:
        attention_mask (`torch.Tensor` or `None`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
            The input shape should be a tuple that defines `(batch_size, query_length)`.
        inputs_embeds (`torch.Tensor`):
            The embedded inputs as a torch Tensor.
        past_key_values_length (`int`):
            The length of the key value cache.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.
    """
    # 创建一个 AttentionMaskConverter 对象,用于生成因果关系的注意力掩码
    attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    # 计算 key_value_length,这里是输入形状的最后一个维度加上过去的键值长度
    key_value_length = input_shape[-1] + past_key_values_length

    # 如果输入的 attention_mask 不为空且是 2D 的
    if attention_mask is not None and len(attention_mask.shape) == 2:
        # 将 2D 的 attention_mask 转换成 4D 的
        attention_mask = attn_mask_converter.to_4d(
            attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
        )
    # 如果 attention_mask 不为空且是 4D 的
    elif attention_mask is not None and len(attention_mask.shape) == 4:
        # 检查 attention_mask 的形状是否符合预期
        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
        if tuple(attention_mask.shape) != expected_shape:
            # 如果不符合预期,抛出 ValueError
            raise ValueError(
                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
            )
        else:
            # 如果 4D 的 attention_mask 形状正确,则反转它并用负无穷填充
            inverted_mask = 1.0 - attention_mask
            attention_mask = inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
            )
    else:
        # 如果 attention_mask 为空或者不是 2D 或 4D 的,则使用 AttentionMaskConverter 生成因果 4D 注意力掩码
        attention_mask = attn_mask_converter.to_causal_4d(
            input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )

    # 返回生成的 attention_mask
    return attention_mask


# Adapted from _prepare_4d_causal_attention_mask
# 根据 _prepare_4d_causal_attention_mask 适配
def _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
    """
    Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.

    In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
    """
    # 创建一个注意力掩码转换器对象,设定为因果关系模式,并指定是否使用滑动窗口
    attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    # 计算键值对的长度,包括输入形状的最后一个维度和过去键值对的长度
    key_value_length = input_shape[-1] + past_key_values_length
    # 获取输入形状中的批处理大小和查询长度
    batch_size, query_length = input_shape

    # 检查是否处于追踪状态,如果是,则需要使用SDPA的`attn_mask`参数而不是自定义的`attention_mask`
    is_tracing = (
        torch.jit.is_tracing()
        or isinstance(inputs_embeds, torch.fx.Proxy)
        or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
    )

    # 如果传入了注意力掩码
    if attention_mask is not None:
        # 如果注意力掩码是4维的
        if len(attention_mask.shape) == 4:
            # 验证注意力掩码的形状是否符合预期
            expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
            if tuple(attention_mask.shape) != expected_shape:
                # 抛出值错误,指出注意力掩码的形状不正确
                raise ValueError(
                    f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
                )
            else:
                # 如果4维掩码形状正确,反转掩码并用负无穷填充
                inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
                attention_mask = inverted_mask.masked_fill(
                    inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
                )
                return attention_mask

        # 如果不处于追踪状态,并且所有的注意力掩码值都是1
        elif not is_tracing and torch.all(attention_mask == 1):
            if query_length == 1:
                # 当查询长度为1时,因果关注和双向关注是相同的
                attention_mask = None
            elif key_value_length == query_length:
                # 当键值对的长度等于查询长度时,不需要注意力掩码
                attention_mask = None
            else:
                # 对于查询长度大于1且键值长度不等于查询长度的情况,无法忽略注意力掩码,需要特别处理
                pass

    # 如果没有传入注意力掩码,并且查询长度大于1且键值长度不等于查询长度
    elif query_length > 1 and key_value_length != query_length:
        # 将注意力掩码设为True,以便在后续控制流中转到`to_causal_4d`
        attention_mask = True
    # 如果正在进行跟踪(tracing),且未提供注意力掩码(attention_mask),则抛出值错误异常
    elif is_tracing:
        raise ValueError(
            'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
        )

    # 如果没有提供 attention_mask,则将 expanded_4d_mask 设置为 None
    if attention_mask is None:
        expanded_4d_mask = None
    # 如果 attention_mask 设置为 True,则通过 attn_mask_converter.to_causal_4d 函数生成扩展后的 4D 注意力掩码
    elif attention_mask is True:
        expanded_4d_mask = attn_mask_converter.to_causal_4d(
            input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )
    # 否则,根据给定的 attention_mask 使用 attn_mask_converter.to_4d 函数生成扩展后的 4D 注意力掩码
    else:
        expanded_4d_mask = attn_mask_converter.to_4d(
            attention_mask,
            input_shape[-1],
            dtype=inputs_embeds.dtype,
            key_value_length=key_value_length,
        )

        # 如果不是在跟踪模式下,并且 expanded_4d_mask 存在且在 CUDA 设备上,
        # 则调用 AttentionMaskConverter._unmask_unattended 函数,处理未注意的部分
        if not is_tracing and expanded_4d_mask.device.type == "cuda":
            expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
                expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
            )

    # 返回生成或处理后的 expanded_4d_mask 注意力掩码
    return expanded_4d_mask
# 创建一个非因果关系的四维注意力掩码,其形状为 `(batch_size, 1, query_length, key_value_length)`,从形状为 `(batch_size, key_value_length)` 的二维掩码创建
def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`

    Args:
        mask (`torch.Tensor` or `None`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
        tgt_len (`int`):
            The target length or query length the created mask shall have.
    """
    # 调用内部函数 `_expand_mask` 扩展掩码至四维并返回
    return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)


# 为 SDPA(Scaled Dot-Product Attention)创建非因果四维注意力掩码,形状为 `(batch_size, 1, query_length, key_value_length)`,从形状为 `(batch_size, key_value_length)` 的二维掩码创建
def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`

    Args:
        mask (`torch.Tensor` or `None`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
        tgt_len (`int`):
            The target length or query length the created mask shall have.
    """
    # 获取 batch_size 和 key_value_length 的尺寸
    batch_size, key_value_length = mask.shape
    # 如果未提供 tgt_len,则默认为 key_value_length
    tgt_len = tgt_len if tgt_len is not None else key_value_length

    # 检查是否处于追踪模式
    is_tracing = (
        torch.jit.is_tracing()
        or isinstance(mask, torch.fx.Proxy)
        or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
    )

    # 如果掩码中所有元素均为 1
    if torch.all(mask == 1):
        if is_tracing:
            pass  # 如果处于追踪模式,不做任何操作
        elif tgt_len == 1:
            # 对于 query_length == 1,因果和双向注意力相同,返回 None
            return None
        elif key_value_length == tgt_len:
            # 如果 key_value_length 等于 tgt_len,返回 None
            return None
        else:
            # 对于 query_length > 1 且 key_value_length != query_length 的情况,
            # 我们不能忽略注意力掩码,因为 SDPA 因果掩码的生成可能会出错,
            # 我们在 SDPA 中将 is_causal=False,并依赖于 Transformers 的 attention_mask,因此在这里不设置为 None。
            # 参考: https://github.com/pytorch/pytorch/issues/108108
            return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
    else:
        # 对于其他情况,调用 `_expand_mask` 扩展掩码并返回
        return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
    # device: torch.device,
    # 定义一个参数 `device`,类型为 `torch.device`,用于指定张量运算的设备(如CPU或GPU)

    # past_key_values_length: int = 0,
    # 定义一个参数 `past_key_values_length`,类型为 `int`,默认值为 `0`,用于指定过去的键值对长度

    # sliding_window: Optional[int] = None,
    # 定义一个参数 `sliding_window`,类型为可选的 `int`,默认值为 `None`,用于指定滑动窗口的大小
# 创建一个形状为 `(batch_size, 1, query_length, key_value_length)` 的因果性四维掩码

def create_causal_mask(
    input_shape: Union[tuple[int], list[int], torch.Size],
    dtype: torch.dtype,
    device: torch.device,
    sliding_window: Optional[int] = None
) -> Optional[torch.Tensor]:
    """
    创建一个形状为 `(batch_size, 1, query_length, key_value_length)` 的因果性四维掩码

    Args:
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
            输入形状应为定义 `(batch_size, query_length)` 的元组。
        dtype (`torch.dtype`):
            所创建掩码的 torch 数据类型。
        device (`torch.device`):
            所创建掩码的 torch 设备。
        sliding_window (`int`, *optional*):
            如果模型使用窗口化注意力,应传入一个滑动窗口大小。
    """
    # 创建一个注意力掩码转换器,设置为因果性,根据是否提供滑动窗口参数决定
    attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    # 计算 key_value_length,包括过去键值长度和输入形状的最后一个维度
    key_value_length = past_key_values_length + input_shape[-1]

    # 使用掩码转换器生成四维因果性掩码
    attention_mask = attn_mask_converter.to_causal_4d(
        input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
    )

    # 返回生成的注意力掩码
    return attention_mask

.\modeling_flax_outputs.py

# 导入必要的模块和类
from typing import Dict, Optional, Tuple  # 导入类型提示相关模块

import flax  # 导入Flax库,用于结构化数据类
import jax.numpy as jnp  # 导入JAX的NumPy接口

from .utils import ModelOutput  # 从当前目录下的utils模块导入ModelOutput类


@flax.struct.dataclass
class FlaxBaseModelOutput(ModelOutput):
    """
    模型输出的基础类,包含可能的隐藏状态和注意力机制。

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            模型最后一层的隐藏状态序列输出。
        hidden_states (`tuple(jnp.ndarray)`, *optional*, 当 `output_hidden_states=True` 被传递或 `config.output_hidden_states=True` 时返回):
            形状为 `(batch_size, sequence_length, hidden_size)` 的 `jnp.ndarray` 元组。

            模型每一层的隐藏状态加上初始嵌入输出。
        attentions (`tuple(jnp.ndarray)`, *optional*, 当 `output_attentions=True` 被传递或 `config.output_attentions=True` 时返回):
            形状为 `(batch_size, num_heads, sequence_length, sequence_length)` 的 `jnp.ndarray` 元组。

            注意力机制softmax后的注意力权重,用于计算自注意力头中的加权平均值。
    """

    last_hidden_state: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None


@flax.struct.dataclass
class FlaxBaseModelOutputWithNoAttention(ModelOutput):
    """
    模型输出的基础类,包含可能的隐藏状态,但不包含注意力机制。

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
            模型最后一层的隐藏状态序列输出。
        hidden_states (`tuple(jnp.ndarray)`, *optional*, 当 `output_hidden_states=True` 被传递或 `config.output_hidden_states=True` 时返回):
            形状为 `(batch_size, num_channels, height, width)` 的 `jnp.ndarray` 元组。

            模型每一层的隐藏状态加上可选的初始嵌入输出。
    """

    last_hidden_state: jnp.ndarray = None
    # 定义一个可选的变量 hidden_states,类型为包含 jnp.ndarray 的元组,初始值为 None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
# 使用 @flax.struct.dataclass 装饰器声明一个数据类,表示带有池化和无注意力机制的模型输出
@flax.struct.dataclass
class FlaxBaseModelOutputWithPoolingAndNoAttention(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state after a pooling operation on the spatial dimensions.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
            for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the
            model at the output of each layer plus the optional initial embedding outputs.
    """

    # 声明类的属性及其类型注解
    last_hidden_state: jnp.ndarray = None
    pooler_output: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None


# 使用 @flax.struct.dataclass 装饰器声明一个数据类,表示不带注意力机制的图像分类模型输出
@flax.struct.dataclass
class FlaxImageClassifierOutputWithNoAttention(ModelOutput):
    """
    Base class for outputs of image classification models.

    Args:
        logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when
        `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
            for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
            called feature maps) of the model at the output of each stage.
    """

    # 声明类的属性及其类型注解
    logits: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None


# 使用 @flax.struct.dataclass 装饰器声明一个数据类,表示带有过去状态的模型输出
@flax.struct.dataclass
class FlaxBaseModelOutputWithPast(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states and attentions.
    """
    # `last_hidden_state` 是模型最后一层的隐藏状态序列,形状为 `(batch_size, sequence_length, hidden_size)`
    # 这里使用了 JAX 数组 `jnp.ndarray`,表示 JAX 程序的数组结构
    last_hidden_state: jnp.ndarray = None

    # `past_key_values` 是一个字典,包含预先计算的隐藏状态(在注意力块中的键和值),用于快速自回归解码
    # 键和值的隐藏状态的形状为 `[batch_size, max_length]`
    past_key_values: Optional[Dict[str, jnp.ndarray]] = None

    # `hidden_states` 是一个元组,包含了模型每一层的隐藏状态
    # 第一个元素是嵌入层的输出,后续元素是每一层的输出,形状为 `(batch_size, sequence_length, hidden_size)`
    # 只有在传递参数 `output_hidden_states=True` 或者配置 `config.output_hidden_states=True` 时才返回
    hidden_states: Optional[Tuple[jnp.ndarray]] = None

    # `attentions` 是一个元组,包含了每一层的注意力权重
    # 每个元素是一个形状为 `(batch_size, num_heads, sequence_length, sequence_length)` 的 JAX 数组
    # 只有在传递参数 `output_attentions=True` 或者配置 `config.output_attentions=True` 时才返回
    attentions: Optional[Tuple[jnp.ndarray]] = None
# 使用 `flax.struct.dataclass` 装饰器定义一个数据类,该类继承自 `ModelOutput` 类,用于表示模型输出并包含最后隐藏状态的池化结果。
@flax.struct.dataclass
class FlaxBaseModelOutputWithPooling(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token) further processed by a
            Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
            prediction (classification) objective during pretraining.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # 定义类的属性,表示模型输出中的最后隐藏状态、池化输出、隐藏状态以及注意力权重
    last_hidden_state: jnp.ndarray = None
    pooler_output: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None


# 使用 `flax.struct.dataclass` 装饰器定义另一个数据类,表示模型输出并包含最后隐藏状态的池化结果以及交叉注意力权重。
@flax.struct.dataclass
class FlaxBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.
    """

    # 该类继承自 `ModelOutput` 类,与 `FlaxBaseModelOutputWithPooling` 类似,但这里还包括交叉注意力权重的定义。
        Args:
            last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
                Sequence of hidden-states at the output of the last layer of the model.
            pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
                Last layer hidden-state of the first token of the sequence (classification token) after further processing
                through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
                the classification token after processing through a linear layer and a tanh activation function. The linear
                layer weights are trained from the next sentence prediction (classification) objective during pretraining.
            hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
                Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
                for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

                Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
            attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
                Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
                sequence_length)`.

                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
                heads.
            cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
                Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
                sequence_length)`.

                Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
                weighted average in the cross-attention heads.
            past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
                `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
                `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
                encoder_sequence_length, embed_size_per_head)`.

                Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
                `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
                input) to speed up sequential decoding.
        """

        last_hidden_state: jnp.ndarray = None
    # 定义一个变量 `pooler_output`,类型为 `jnp.ndarray`,初始值为 None
    pooler_output: jnp.ndarray = None
    # 定义一个变量 `hidden_states`,类型为 `Optional[Tuple[jnp.ndarray]]`,初始值为 None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # 定义一个变量 `past_key_values`,类型为 `Optional[Tuple[Tuple[jnp.ndarray]]]`,初始值为 None
    past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
    # 定义一个变量 `attentions`,类型为 `Optional[Tuple[jnp.ndarray]]`,初始值为 None
    attentions: Optional[Tuple[jnp.ndarray]] = None
    # 定义一个变量 `cross_attentions`,类型为 `Optional[Tuple[jnp.ndarray]]`,初始值为 None
    cross_attentions: Optional[Tuple[jnp.ndarray]] = None
# 使用 @flax.struct.dataclass 装饰器声明一个数据类,该类继承自 ModelOutput 类
@flax.struct.dataclass
class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.

            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
        past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
    """

    # 定义类的属性,每个属性都有一个默认值为 None
    last_hidden_state: jnp.ndarray = None
    past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
    cross_attentions: Optional[Tuple[jnp.ndarray]] = None
# 定义基于Flax的数据类,表示序列到序列模型的输出,继承自ModelOutput
@flax.struct.dataclass
class FlaxSeq2SeqModelOutput(ModelOutput):
    """
    Base class for model encoder's outputs that also contains pre-computed hidden states that can speed up sequential decoding.
    """

    # 最后一个隐藏状态,类型为jnp.ndarray,默认为None
    last_hidden_state: jnp.ndarray = None
    # 过去的键值对,类型为可选的元组,包含元组的元组,每个元组包含jnp.ndarray,默认为None
    past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
    # 解码器的隐藏状态,类型为可选的元组,包含jnp.ndarray,默认为None
    decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # 解码器的注意力权重,类型为可选的元组,包含jnp.ndarray,默认为None
    decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
    # 交叉注意力的权重,类型为可选的元组,包含jnp.ndarray,默认为None
    cross_attentions: Optional[Tuple[jnp.ndarray]] = None
    # 编码器最后一个隐藏状态,类型为可选的jnp.ndarray,默认为None
    encoder_last_hidden_state: Optional[jnp.ndarray] = None
    # 编码器的隐藏状态,类型为可选的元组,包含jnp.ndarray,默认为None
    encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # 编码器的注意力权重,类型为可选的元组,包含jnp.ndarray,默认为None
    encoder_attentions: Optional[Tuple[jnp.ndarray]] = None


# 定义基于Flax的数据类,表示带有交叉注意力的因果语言模型输出,继承自ModelOutput
@flax.struct.dataclass
class FlaxCausalLMOutputWithCrossAttentions(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.
    """

    # 预测的logits,形状为(batch_size, sequence_length, config.vocab_size)的jnp.ndarray
    logits: jnp.ndarray
    # 隐藏状态的元组,包含embedding输出和每层输出的jnp.ndarray,形状为(batch_size, sequence_length, hidden_size)
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # 注意力权重的元组,每层一个jnp.ndarray,形状为(batch_size, num_heads, sequence_length, sequence_length)
    attentions: Optional[Tuple[jnp.ndarray]] = None
    # 交叉注意力权重的元组,每层一个jnp.ndarray,形状为(batch_size, num_heads, sequence_length, sequence_length)
    cross_attentions: Optional[Tuple[jnp.ndarray]] = None
    # 过去的键值对的元组,每层一个jnp.ndarray元组,长度为config.n_layers,仅在使用缓存时有效,用于编码器-解码器设置
    past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
    # 定义变量 logits,用于存储一个 NumPy 数组,初始值为 None
    logits: jnp.ndarray = None
    # 定义变量 past_key_values,类型为 Optional[Tuple[Tuple[jnp.ndarray]]],可选的三重嵌套元组结构,初始值为 None
    past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
    # 定义变量 hidden_states,类型为 Optional[Tuple[jnp.ndarray]],可选的元组结构,初始值为 None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # 定义变量 attentions,类型为 Optional[Tuple[jnp.ndarray]],可选的元组结构,初始值为 None
    attentions: Optional[Tuple[jnp.ndarray]] = None
    # 定义变量 cross_attentions,类型为 Optional[Tuple[jnp.ndarray]],可选的元组结构,初始值为 None
    cross_attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxMaskedLMOutput(ModelOutput):
    """
    Masked语言模型输出的基类。

    Args:
        logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
            语言建模头的预测分数(SoftMax之前的每个词汇标记的分数)。
        hidden_states (`tuple(jnp.ndarray)`, *optional*, 当 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回):
            形状为 `(batch_size, sequence_length, hidden_size)` 的 `jnp.ndarray` 元组。

            模型在每一层输出的隐藏状态加上初始嵌入输出。
        attentions (`tuple(jnp.ndarray)`, *optional*, 当 `output_attentions=True` 或 `config.output_attentions=True` 时返回):
            形状为 `(batch_size, num_heads, sequence_length, sequence_length)` 的 `jnp.ndarray` 元组。

            经过注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
    """

    logits: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None


FlaxCausalLMOutput = FlaxMaskedLMOutput


@flax.struct.dataclass
class FlaxSeq2SeqLMOutput(ModelOutput):
    """
    序列到序列语言模型输出的基类。

    """

    logits: jnp.ndarray = None
    past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
    decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
    decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
    cross_attentions: Optional[Tuple[jnp.ndarray]] = None
    encoder_last_hidden_state: Optional[jnp.ndarray] = None
    encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
    encoder_attentions: Optional[Tuple[jnp.ndarray]] = None


@flax.struct.dataclass
class FlaxNextSentencePredictorOutput(ModelOutput):
    """
    预测两个句子是否连续的模型输出的基类。

    """
    Args:
        logits (`jnp.ndarray` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """
    # 定义函数参数及其类型注释,描述函数接受的输入参数及其形状和类型信息

    logits: jnp.ndarray = None
    # 定义变量 logits,用于存储形状为(batch_size, 2)的 jnp.ndarray,表示下一个序列预测分类头部的预测分数(经过 SoftMax 之前的 True/False 连续性得分)。

    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # 定义变量 hidden_states,可选的元组类型,包含 jnp.ndarray(当传递 output_hidden_states=True 或 config.output_hidden_states=True 时返回)。
    # 元组中的每个数组形状为(batch_size, sequence_length, hidden_size),表示模型在每一层输出的隐藏状态以及初始嵌入输出。

    attentions: Optional[Tuple[jnp.ndarray]] = None
    # 定义变量 attentions,可选的元组类型,包含 jnp.ndarray(当传递 output_attentions=True 或 config.output_attentions=True 时返回)。
    # 元组中的每个数组形状为(batch_size, num_heads, sequence_length, sequence_length),
    # 表示注意力 softmax 后的注意力权重,用于计算自注意力头部中的加权平均值。
# 使用 @flax.struct.dataclass 装饰器声明一个数据类,用于表示序列分类模型的输出。
class FlaxSequenceClassifierOutput(ModelOutput):
    """
    Base class for outputs of sentence classification models.

    Args:
        logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """
    
    # 分类或回归得分(未经 SoftMax 处理前)的 logits,形状为 `(batch_size, config.num_labels)`
    logits: jnp.ndarray = None
    # 模型每一层的输出的隐藏状态的元组,形状为 `(batch_size, sequence_length, hidden_size)`,当 `output_hidden_states=True` 时返回
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # 自注意力机制注意力权重的元组,形状为 `(batch_size, num_heads, sequence_length, sequence_length)`,当 `output_attentions=True` 时返回
    attentions: Optional[Tuple[jnp.ndarray]] = None


# 使用 @flax.struct.dataclass 装饰器声明一个数据类,用于表示序列到序列的句子分类模型的输出。
class FlaxSeq2SeqSequenceClassifierOutput(ModelOutput):
    """
    Base class for outputs of sequence-to-sequence sentence classification models.

    """
    
    # 分类得分 logits,形状为 `(batch_size, config.num_labels)`
    logits: jnp.ndarray = None
    # 用于存储过去键值的元组,形状为 `(batch_size, num_layers, 2, batch_size, num_heads, sequence_length, head_dim)`,当 `output_past_key_values=True` 时返回
    past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
    # 解码器每一层的隐藏状态的元组,形状为 `(batch_size, sequence_length, hidden_size)`,当 `output_hidden_states=True` 时返回
    decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # 解码器每一层的注意力权重的元组,形状为 `(batch_size, num_heads, sequence_length, sequence_length)`,当 `output_attentions=True` 时返回
    decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
    # 编码器解码器之间注意力权重的元组,形状为 `(batch_size, num_heads, sequence_length, sequence_length)`,当 `output_attentions=True` 时返回
    cross_attentions: Optional[Tuple[jnp.ndarray]] = None
    # 编码器最后一层的隐藏状态,形状为 `(batch_size, sequence_length, hidden_size)`
    encoder_last_hidden_state: Optional[jnp.ndarray] = None
    # 编码器每一层的隐藏状态的元组,形状为 `(batch_size, sequence_length, hidden_size)`,当 `output_hidden_states=True` 时返回
    encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # 编码器每一层的注意力权重的元组,形状为 `(batch_size, num_heads, sequence_length, sequence_length)`,当 `output_attentions=True` 时返回
    encoder_attentions: Optional[Tuple[jnp.ndarray]] = None


# 使用 @flax.struct.dataclass 装饰器声明一个数据类,用于表示多项选择模型的输出。
class FlaxMultipleChoiceModelOutput(ModelOutput):
    """
    Base class for outputs of multiple choice models.
    """
    
    # 此类暂未定义任何属性,因此无需添加额外的注释。
    """
    Args:
        logits (`jnp.ndarray` of shape `(batch_size, num_choices)`):
            分类器的输出分数(SoftMax 之前)。

        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            包含每层输出的元组,每个 `jnp.ndarray` 的形状为 `(batch_size, sequence_length, hidden_size)`。
            模型在每一层的隐藏状态以及初始嵌入输出。

        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            包含每层注意力权重的元组,每个 `jnp.ndarray` 的形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
            注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
    """

    logits: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
# 使用 @flax.struct.dataclass 装饰器定义一个数据类,表示序列标注模型的输出
@flax.struct.dataclass
class FlaxTokenClassifierOutput(ModelOutput):
    """
    序列标注模型输出的基类。

    Args:
        logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.num_labels)`):
            分类得分(SoftMax 之前)。
        hidden_states (`tuple(jnp.ndarray)`, *optional*, 当 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回):
            包含多个 `jnp.ndarray` 的元组,形状为 `(batch_size, sequence_length, hidden_size)`。

            模型在每一层输出的隐藏状态,以及初始嵌入输出。
        attentions (`tuple(jnp.ndarray)`, *optional*, 当 `output_attentions=True` 或 `config.output_attentions=True` 时返回):
            包含多个 `jnp.ndarray` 的元组,形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。

            经过注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
    """

    logits: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None


# 使用 @flax.struct.dataclass 装饰器定义一个数据类,表示问答模型的输出
@flax.struct.dataclass
class FlaxQuestionAnsweringModelOutput(ModelOutput):
    """
    问答模型输出的基类。

    Args:
        start_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            起始位置的得分(SoftMax 之前)。
        end_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            终止位置的得分(SoftMax 之前)。
        hidden_states (`tuple(jnp.ndarray)`, *optional*, 当 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回):
            包含多个 `jnp.ndarray` 的元组,形状为 `(batch_size, sequence_length, hidden_size)`。

            模型在每一层输出的隐藏状态,以及初始嵌入输出。
        attentions (`tuple(jnp.ndarray)`, *optional*, 当 `output_attentions=True` 或 `config.output_attentions=True` 时返回):
            包含多个 `jnp.ndarray` 的元组,形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。

            经过注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
    """

    start_logits: jnp.ndarray = None
    end_logits: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None


# 使用 @flax.struct.dataclass 装饰器定义一个数据类,表示序列到序列问答模型的输出
@flax.struct.dataclass
class FlaxSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
    """
    序列到序列问答模型输出的基类。

    """

    start_logits: jnp.ndarray = None
    end_logits: jnp.ndarray = None
    # 初始化变量,用于存储模型解码器的相关状态和注意力权重
    past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
    # 初始化变量,用于存储模型解码器的隐藏状态
    decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # 初始化变量,用于存储模型解码器的注意力权重
    decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
    # 初始化变量,用于存储模型交叉注意力的权重
    cross_attentions: Optional[Tuple[jnp.ndarray]] = None
    # 初始化变量,用于存储模型编码器的最后隐藏状态
    encoder_last_hidden_state: Optional[jnp.ndarray] = None
    # 初始化变量,用于存储模型编码器的隐藏状态的序列
    encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # 初始化变量,用于存储模型编码器的注意力权重的序列
    encoder_attentions: Optional[Tuple[jnp.ndarray]] = None

.\modeling_flax_pytorch_utils.py

# 设置编码为 UTF-8,确保文件能正确处理非 ASCII 字符
# 版权声明及许可协议,指定本文件的版权归属及使用许可
# 注意事项:根据 Apache 许可证 2.0 版本,除非符合许可协议,否则禁止使用本文件
# 获取许可协议的详细信息,请访问指定的 URL
# 如果适用法律要求或书面同意,软件将按“原样”分发,不提供任何明示或暗示的担保或条件
# 查看许可协议以了解具体语言和限制条件

""" PyTorch - Flax general utilities."""
# 导入所需的标准库和模块

import os  # 导入操作系统功能
from pickle import UnpicklingError  # 导入反序列化错误异常
from typing import Dict, Tuple  # 导入类型提示

import jax  # 导入 JAX 库
import jax.numpy as jnp  # 导入 JAX 的 NumPy 接口
import numpy as np  # 导入 NumPy 库
from flax.serialization import from_bytes  # 从字节流中反序列化对象
from flax.traverse_util import flatten_dict, unflatten_dict  # 对字典进行扁平化和展开操作

import transformers  # 导入 Transformers 库

from . import is_safetensors_available, is_torch_available  # 导入本地模块
from .utils import logging  # 从本地工具模块中导入日志功能


if is_torch_available():  # 如果 Torch 可用
    import torch  # 导入 Torch 库

if is_safetensors_available():  # 如果 SafeTensors 可用
    from safetensors import safe_open  # 导入安全打开文件函数
    from safetensors.flax import load_file as safe_load_file  # 导入安全加载文件函数


logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器


#####################
# PyTorch => Flax #
#####################

# 定义一个函数,将 PyTorch 的检查点加载到 Flax 模型的状态字典中
def load_pytorch_checkpoint_in_flax_state_dict(
    flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
):
    """Load pytorch checkpoints in a flax model"""

    if not is_sharded:  # 如果不是分片加载
        pt_path = os.path.abspath(pytorch_checkpoint_path)  # 获取 PyTorch 检查点的绝对路径
        logger.info(f"Loading PyTorch weights from {pt_path}")  # 记录日志,显示正在加载的 PyTorch 权重文件路径

        if pt_path.endswith(".safetensors"):  # 如果文件路径以 ".safetensors" 结尾
            pt_state_dict = {}  # 初始化一个空字典,用于存储 PyTorch 的状态字典
            with safe_open(pt_path, framework="flax") as f:  # 使用安全方式打开文件
                for k in f.keys():  # 遍历文件中的键
                    pt_state_dict[k] = f.get_tensor(k)  # 将文件中的张量数据存储到状态字典中
        else:  # 如果文件路径不以 ".safetensors" 结尾
            try:
                import torch  # 尝试导入 Torch 库
                from .pytorch_utils import is_torch_greater_or_equal_than_1_13  # 导入版本比较工具函数
            except (ImportError, ModuleNotFoundError):  # 处理导入错误或模块未找到异常
                logger.error(
                    "Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
                    " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
                    " instructions."
                )  # 记录加载错误信息,指出安装 PyTorch 和 Flax 的必要性
                raise  # 抛出异常

            # 根据不同的 Torch 版本加载权重数据
            weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
            pt_state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)  # 使用 Torch 加载权重数据到状态字典
            logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")  # 记录日志,显示加载的参数数量

        flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)  # 将 PyTorch 的状态字典转换为 Flax 的状态字典
    else:
        # 如果模型是分片的,并且 pytorch_checkpoint_path 已经包含了 .pt 分片文件的列表
        # 则将使用 convert_pytorch_sharded_state_dict_to_flax 函数将其转换为 Flax 模型的状态字典
        flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model)
    # 返回转换后的 Flax 模型状态字典
    return flax_state_dict
# 将 PyTorch 权重名称重命名为对应的 Flax 权重名称并在必要时重塑张量
def rename_key_and_reshape_tensor(
    pt_tuple_key: Tuple[str],
    pt_tensor: np.ndarray,
    random_flax_state_dict: Dict[str, jnp.ndarray],
    model_prefix: str,
) -> (Tuple[str], np.ndarray):
    """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""

    def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool:
        """Checks if `key` of `(prefix,) + key` is in random_flax_state_dict"""
        return len(set(random_flax_state_dict) & {key, (model_prefix,) + key}) > 0

    # layer norm
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
    if pt_tuple_key[-1] in ["weight", "gamma"] and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
        return renamed_pt_tuple_key, pt_tensor

    # batch norm layer mean
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("mean",)
    if pt_tuple_key[-1] == "running_mean" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
        return renamed_pt_tuple_key, pt_tensor

    # batch norm layer var
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("var",)
    if pt_tuple_key[-1] == "running_var" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
        return renamed_pt_tuple_key, pt_tensor

    # embedding
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
    if pt_tuple_key[-1] == "weight" and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
        return renamed_pt_tuple_key, pt_tensor

    # conv layer
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and not is_key_or_prefix_key_in_dict(pt_tuple_key):
        pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
        return renamed_pt_tuple_key, pt_tensor

    # linear layer
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    if pt_tuple_key[-1] == "weight" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
        pt_tensor = pt_tensor.T
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm weight
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
    if pt_tuple_key[-1] == "gamma":
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm bias
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
    if pt_tuple_key[-1] == "beta":
        return renamed_pt_tuple_key, pt_tensor

    # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
    name = None
    if pt_tuple_key[-3::2] == ("parametrizations", "original0"):
        name = pt_tuple_key[-2] + "_g"
    elif pt_tuple_key[-3::2] == ("parametrizations", "original1"):
        name = pt_tuple_key[-2] + "_v"
    if name is not None:
        renamed_pt_tuple_key = pt_tuple_key[:-3] + (name,)
        return renamed_pt_tuple_key, pt_tensor

    # 默认情况,返回原始的 tuple key 和张量
    return pt_tuple_key, pt_tensor
    # 根据条件确定要使用的 bfloat16 类型,如果 from_bin 是 True 则使用 torch.bfloat16,否则使用字符串 "bfloat16"
    bfloat16 = torch.bfloat16 if from_bin else "bfloat16"

    # 创建一个字典,将 PyTorch 模型状态字典中每个键对应的数据类型收集起来
    weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}

    # 如果 from_bin 是 True,则需要将 PyTorch 模型状态字典中的 bfloat16 类型转换为 float32,以避免精度损失问题
    if from_bin:
        for k, v in pt_state_dict.items():
            # 当前 numpy 不支持 bfloat16,因此在此情况下需要将其转换为 float32 类型
            if v.dtype == bfloat16:
                v = v.float()
            pt_state_dict[k] = v.numpy()

    # 获取 Flax 模型的基础模型前缀
    model_prefix = flax_model.base_model_prefix

    # 如果模型中包含批归一化层,则使用 params 字典
    if "params" in flax_model.params:
        flax_model_params = flax_model.params["params"]
    else:
        flax_model_params = flax_model.params

    # 将 Flax 模型参数展平为字典
    random_flax_state_dict = flatten_dict(flax_model_params)

    # 如果模型参数中包含 batch_stats,则将其展平并加入随机状态字典中
    if "batch_stats" in flax_model.params:
        flax_batch_stats = flatten_dict(flax_model.params["batch_stats"])
        random_flax_state_dict.update(flax_batch_stats)

    # 初始化一个空的 Flax 状态字典
    flax_state_dict = {}

    # 根据条件判断是将头部模型加载到基础模型中,还是将基础模型加载到带头部模型中
    load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and (
        model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
    )
    load_base_model_into_model_with_head = (model_prefix in flax_model_params) and (
        model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
    )

    # 需要修改一些参数名称以匹配 Flax 模型的命名规范
    # 此处的注释指出了需要进行参数名称匹配的必要性,但没有具体说明如何实现
    # 遍历 PyTorch 状态字典中的每个键值对
    for pt_key, pt_tensor in pt_state_dict.items():
        # 将点分割的键名转换为元组形式
        pt_tuple_key = tuple(pt_key.split("."))
        # 检查当前权重数据类型是否为 bfloat16
        is_bfloat_16 = weight_dtypes[pt_key] == bfloat16

        # 如果需要,移除基础模型前缀
        has_base_model_prefix = pt_tuple_key[0] == model_prefix
        if load_model_with_head_into_base_model and has_base_model_prefix:
            pt_tuple_key = pt_tuple_key[1:]

        # 使用指定函数重命名键名并调整张量形状
        flax_key, flax_tensor = rename_key_and_reshape_tensor(
            pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
        )

        # 如果需要,添加模型前缀
        require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
        if load_base_model_into_model_with_head and require_base_model_prefix:
            flax_key = (model_prefix,) + flax_key

        # 检查重命名后的键是否存在于随机化的 Flax 状态字典中
        if flax_key in random_flax_state_dict:
            # 检查张量形状是否与期望的 Flax 模型权重形状一致,否则抛出 ValueError
            if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
                )

        # 如果 Flax 模型包含批次归一化层,添加批次统计信息
        if "batch_stats" in flax_model.params:
            if "mean" in flax_key[-1] or "var" in flax_key[-1]:
                # 将 Flax 张量转换为 JAX 数组,并存储在新的位置
                flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
                continue
            # 移除 num_batches_tracked 键
            if "num_batches_tracked" in flax_key[-1]:
                flax_state_dict.pop(flax_key, None)
                continue

            # 否则,将权重添加到 params 键下
            flax_state_dict[("params",) + flax_key] = (
                jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
            )
        else:
            # 如果模型不包含批次归一化层,也将权重添加到状态字典中
            flax_state_dict[flax_key] = (
                jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
            )

    # 返回经过重新构造的 Flax 状态字典
    return unflatten_dict(flax_state_dict)
############################
# Sharded Pytorch => Flax #
############################

# 将分片的 PyTorch 状态字典转换为 Flax 格式
def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
    import torch

    from .pytorch_utils import is_torch_greater_or_equal_than_1_13

    # Load the index
    flax_state_dict = {}
    # 调用函数 unflatten_dict 将 Flax 状态字典展开并返回
    return unflatten_dict(flax_state_dict)


#####################
# Flax => PyTorch #
#####################

# 在 PyTorch 模型中加载 Flax 检查点
def load_flax_checkpoint_in_pytorch_model(model, flax_checkpoint_path):
    """Load flax checkpoints in a PyTorch model"""
    flax_checkpoint_path = os.path.abspath(flax_checkpoint_path)
    logger.info(f"Loading Flax weights from {flax_checkpoint_path}")

    # import correct flax class
    flax_cls = getattr(transformers, "Flax" + model.__class__.__name__)

    # load flax weight dict
    if flax_checkpoint_path.endswith(".safetensors"):
        # 使用 safe_load_file 函数加载安全张量文件
        flax_state_dict = safe_load_file(flax_checkpoint_path)
        # 使用分隔符 "." 对 Flax 状态字典进行展开
        flax_state_dict = unflatten_dict(flax_state_dict, sep=".")
    else:
        with open(flax_checkpoint_path, "rb") as state_f:
            try:
                # 尝试从文件中读取并解析 Flax 序列化对象
                flax_state_dict = from_bytes(flax_cls, state_f.read())
            except UnpicklingError:
                raise EnvironmentError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. ")

    return load_flax_weights_in_pytorch_model(model, flax_state_dict)


# 在 PyTorch 模型中加载 Flax 权重
def load_flax_weights_in_pytorch_model(pt_model, flax_state):
    """Load flax checkpoints in a PyTorch model"""

    try:
        import torch  # noqa: F401
    except (ImportError, ModuleNotFoundError):
        logger.error(
            "Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see"
            " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
            " instructions."
        )
        raise

    # check if we have bf16 weights
    # 检查是否存在 bf16 类型的权重,并转换为 fp32 类型以便 PyTorch 加载
    is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
    if any(is_type_bf16):
        # 如果发现 Flax 模型中包含 bf16 类型的权重,则警告并将它们转换为 fp32 类型
        logger.warning(
            "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
            "before loading those in PyTorch model."
        )
        flax_state = jax.tree_util.tree_map(
            lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
        )

    # 将 Flax 状态字典展开为一维字典
    flax_state_dict = flatten_dict(flax_state)
    # 获取 PyTorch 模型的状态字典
    pt_model_dict = pt_model.state_dict()

    # 判断是否需要将模型头加载到基础模型中
    load_model_with_head_into_base_model = (pt_model.base_model_prefix in flax_state) and (
        pt_model.base_model_prefix not in {k.split(".")[0] for k in pt_model_dict.keys()}
    )
    # 检查是否需要将基础模型加载到带头部的模型中
    load_base_model_into_model_with_head = (pt_model.base_model_prefix not in flax_state) and (
        pt_model.base_model_prefix in {k.split(".")[0] for k in pt_model_dict.keys()}
    )

    # 用于跟踪未预期和丢失的键
    unexpected_keys = []
    missing_keys = set(pt_model_dict.keys())

    # 加载 PyTorch 模型的状态字典
    pt_model.load_state_dict(pt_model_dict)

    # 将缺失的键重新转换为列表
    missing_keys = list(missing_keys)

    # 如果存在未预期的键,则发出警告
    if len(unexpected_keys) > 0:
        logger.warning(
            "Some weights of the Flax model were not used when initializing the PyTorch model"
            f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
            f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
            " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
            f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
            " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
            " FlaxBertForSequenceClassification model)."
        )
    else:
        # 所有 Flax 模型的权重均已用于初始化 PyTorch 模型
        logger.warning(f"All Flax model weights were used when initializing {pt_model.__class__.__name__}.\n")

    # 如果存在丢失的键,则发出警告
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
            f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
            " use it for predictions and inference."
        )
    else:
        # 所有的权重已从 Flax 模型初始化
        logger.warning(
            f"All the weights of {pt_model.__class__.__name__} were initialized from the Flax model.\n"
            "If your task is similar to the task the model of the checkpoint was trained on, "
            f"you can already use {pt_model.__class__.__name__} for predictions without further training."
        )

    # 返回加载后的 PyTorch 模型
    return pt_model

.\modeling_flax_utils.py

# coding=utf-8
# 代码文件声明使用 UTF-8 编码

# 导入标准库和第三方库
import gc  # 垃圾回收模块
import json  # JSON 数据格式处理模块
import os  # 系统操作模块
import re  # 正则表达式模块
import warnings  # 警告模块
from functools import partial  # 偏函数功能
from pickle import UnpicklingError  # 反序列化错误异常
from typing import Any, Dict, Optional, Set, Tuple, Union  # 类型提示模块

# 导入 Flax 和 JAX 库
import flax.linen as nn  # Flax 的线性模块
import jax  # JAX 数值计算库
import jax.numpy as jnp  # JAX 的 NumPy 接口
import msgpack.exceptions  # MsgPack 序列化异常模块
from flax.core.frozen_dict import FrozenDict, unfreeze  # 冻结字典和解冻功能
from flax.serialization import from_bytes, to_bytes  # 对象序列化和反序列化
from flax.traverse_util import flatten_dict, unflatten_dict  # 字典扁平化和反扁平化
from jax.random import PRNGKey  # JAX 随机数生成模块

# 导入本地的配置和工具函数
from .configuration_utils import PretrainedConfig  # 预训练模型配置类
from .dynamic_module_utils import custom_object_save  # 自定义对象保存函数
from .generation import FlaxGenerationMixin, GenerationConfig  # 生成相关模块
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict  # 加载 PyTorch 检查点到 Flax 状态字典
from .utils import (
    FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME,  # 各种常量定义
    WEIGHTS_INDEX_NAME, WEIGHTS_NAME,  # 权重文件名和索引名
    PushToHubMixin,  # 推送到 Hub 的混合类
    add_code_sample_docstrings,  # 添加代码示例文档字符串
    add_start_docstrings_to_model_forward,  # 添加模型前向方法的文档字符串
    cached_file,  # 缓存文件函数
    copy_func,  # 复制函数对象
    download_url,  # 下载 URL 资源函数
    has_file,  # 检查文件是否存在函数
    is_offline_mode,  # 检查是否处于离线模式函数
    is_remote_url,  # 检查 URL 是否远程地址函数
    logging,  # 日志模块
    replace_return_docstrings,  # 替换返回值的文档字符串函数
)
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files  # Hub 工具函数
from .utils.import_utils import is_safetensors_available  # 检查是否安装安全张量库


if is_safetensors_available():
    from safetensors import safe_open  # 安全打开文件函数
    from safetensors.flax import load_file as safe_load_file  # 安全加载文件函数
    from safetensors.flax import save_file as safe_save_file  # 安全保存文件函数

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器


def quick_gelu(x):
    """
    快速 GELU 激活函数的定义,使用 JAX 实现
    """
    return x * jax.nn.sigmoid(1.702 * x)


ACT2FN = {
    "gelu": partial(nn.gelu, approximate=False),  # 使用 Flax 提供的精确 GELU 激活函数
    "relu": nn.relu,  # 使用 Flax 提供的 ReLU 激活函数
    "silu": nn.swish,  # 使用 Flax 提供的 SiLU(Swish)激活函数
    "swish": nn.swish,  # 使用 Flax 提供的 Swish 激活函数
    "gelu_new": partial(nn.gelu, approximate=True),  # 使用 Flax 提供的近似 GELU 激活函数
    "quick_gelu": quick_gelu,  # 使用定义的快速 GELU 激活函数
    "gelu_pytorch_tanh": partial(nn.gelu, approximate=True),  # 使用 Flax 提供的近似 GELU 激活函数
}


def dtype_byte_size(dtype):
    """
    根据数据类型 `dtype` 返回一个参数占用的字节数。例如:
    ```
    >>> dtype_byte_size(np.float32)
    4
    ```
    """
    if dtype == bool:
        return 1 / 8  # 布尔类型占用 1 位,即 1/8 字节
    bit_search = re.search(r"[^\d](\d+)$", dtype.name)
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")  # 若数据类型不合法,抛出异常
    bit_size = int(bit_search.groups()[0])  # 获取数据类型的位数大小
    return bit_size // 8  # 返回字节数


def flax_shard_checkpoint(params, max_shard_size="10GB"):
    """
    将参数 `params` 拆分为多个小的检查点文件,以便于存储和传输。
    """
    # 将模型状态字典拆分为子检查点,使得每个子检查点的最终大小不超过给定的大小限制。
    # 子检查点的确定是通过按照状态字典的键的顺序迭代进行的,因此不会优化使每个子检查点尽可能接近传递的最大大小。
    # 例如,如果限制是10GB,并且我们有大小为[6GB, 6GB, 2GB, 6GB, 2GB, 2GB]的权重,则它们将被分割为[6GB]、[6+2GB]、[6+2+2GB],而不是[6+2+2GB]、[6+2GB]、[6GB]。
    # <Tip warning={true}>
    # 如果模型中的某个权重大于`max_shard_size`,它将单独存在于其自己的子检查点中,其大小将大于`max_shard_size`。
    # </Tip>
    
    Args:
        params (`Union[Dict, FrozenDict]`): 模型参数的`PyTree`表示。
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            每个子检查点的最大大小。如果表示为字符串,则需要是数字后跟单位(例如`"5MB"`)。
    """
    # 将`max_shard_size`转换为整数表示
    max_shard_size = convert_file_size_to_int(max_shard_size)
    
    # 初始化用于存储分片状态字典的列表
    sharded_state_dicts = []
    # 当前分块的字典
    current_block = {}
    # 当前分块的大小
    current_block_size = 0
    # 总大小
    total_size = 0
    
    # 将参数展平为键值对
    weights = flatten_dict(params, sep="/")
    for item in weights:
        # 计算权重项的大小
        weight_size = weights[item].size * dtype_byte_size(weights[item].dtype)
    
        # 如果当前分块加上当前权重项的大小超过了最大分块大小,进行分块
        if current_block_size + weight_size > max_shard_size:
            sharded_state_dicts.append(current_block)
            current_block = {}
            current_block_size = 0
    
        # 将权重项添加到当前分块中
        current_block[item] = weights[item]
        current_block_size += weight_size
        total_size += weight_size
    
    # 添加最后一个分块
    sharded_state_dicts.append(current_block)
    
    # 如果只有一个分片,直接返回
    if len(sharded_state_dicts) == 1:
        return {FLAX_WEIGHTS_NAME: sharded_state_dicts[0]}, None
    
    # 否则,构建权重映射和分片文件名
    weight_map = {}
    shards = {}
    for idx, shard in enumerate(sharded_state_dicts):
        shard_file = FLAX_WEIGHTS_NAME.replace(".msgpack", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.msgpack")
        shards[shard_file] = shard
        for weight_name in shard.keys():
            weight_map[weight_name] = shard_file
    
    # 添加元数据
    metadata = {"total_size": total_size}
    index = {"metadata": metadata, "weight_map": weight_map}
    return shards, index
# FlaxPreTrainedModel 类,继承自 PushToHubMixin 和 FlaxGenerationMixin
class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
    # 所有模型的基类。
    r"""
    Base class for all models.

    [`FlaxPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
    downloading and saving models.

    Class attributes (overridden by derived classes):

        - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
          for this model architecture.
        - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
          classes of the same architecture adding modules on top of the base model.
        - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
          models, `pixel_values` for vision models and `input_values` for speech models).
    """

    # 模型配置类,默认为 None
    config_class = None
    # 基模型前缀,默认为空字符串
    base_model_prefix = ""
    # 主要输入名称,默认为 "input_ids"
    main_input_name = "input_ids"
    # 自动类
    _auto_class = None
    # 缺失的键集合
    _missing_keys = set()

    # 模型初始化方法
    def __init__(
        self,
        config: PretrainedConfig,
        module: nn.Module,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
    ):
        # 如果 config 为 None,则抛出 ValueError
        if config is None:
            raise ValueError("config cannot be None")

        # 如果 module 为 None,则抛出 ValueError
        if module is None:
            raise ValueError("module cannot be None")

        # 下面的属性用于在派生类中作为类型化属性暴露,因此为私有属性。
        # 存储配置对象
        self._config = config
        # 存储模块对象
        self._module = module

        # 下面的属性为每个派生类通用的公共属性。
        # 初始化随机数生成器的 key
        self.key = PRNGKey(seed)
        # 数据类型,默认为 jnp.float32
        self.dtype = dtype
        # 输入形状,默认为 (1, 1)
        self.input_shape = input_shape
        # 生成配置对象,基于模型配置生成,如果可以生成的话
        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None

        # 标志模型是否已初始化
        self._is_initialized = _do_init

        # 如果 _do_init 为 True,则随机初始化参数
        if _do_init:
            # 随机初始化模型参数
            random_params = self.init_weights(self.key, input_shape)
            # 计算参数的形状树
            params_shape_tree = jax.eval_shape(lambda params: params, random_params)
        else:
            # 如果 _do_init 为 False,则部分初始化模型参数
            init_fn = partial(self.init_weights, input_shape=input_shape)
            params_shape_tree = jax.eval_shape(init_fn, self.key)

            # 日志记录,提示模型权重未初始化
            logger.info(
                "Model weights are not initialized as `_do_init` is set to `False`. "
                f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
            )

        # 存储参数形状树
        self._params_shape_tree = params_shape_tree

        # 将必需参数保存为集合
        self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())

        # 如果 _do_init 为 True,则设置模型参数
        if _do_init:
            self.params = random_params
    # 定义一个抽象方法,用于初始化模型的权重。子类必须实现这个方法。
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
        raise NotImplementedError(f"init method has to be implemented for {self}")

    # 定义一个抽象方法,用于启用梯度检查点功能。子类必须实现这个方法。
    def enable_gradient_checkpointing(self):
        raise NotImplementedError(f"gradient checkpointing method has to be implemented for {self}")

    # 类方法,用于根据给定的配置和其他参数创建类的实例。
    @classmethod
    def _from_config(cls, config, **kwargs):
        """
        All context managers that the model should be initialized under go here.
        """
        return cls(config, **kwargs)

    # 返回字符串标识,指示这是一个 Flax 模型。
    @property
    def framework(self) -> str:
        """
        :str: Identifies that this is a Flax model.
        """
        return "flax"

    # 返回模型的配置信息。
    @property
    def config(self) -> PretrainedConfig:
        return self._config

    # 返回模型的内部模块。
    @property
    def module(self) -> nn.Module:
        return self._module

    # 返回模型的参数,可以是普通字典或者冻结字典。
    @property
    def params(self) -> Union[Dict, FrozenDict]:
        if not self._is_initialized:
            raise ValueError(
                "`params` cannot be accessed from model when the model is created with `_do_init=False`. "
                "You must call `init_weights` manually and store the params outside of the model and "
                "pass it explicitly where needed."
            )
        return self._params

    # 返回模型所需的参数集合。
    @property
    def required_params(self) -> Set:
        return self._required_params

    # 返回模型参数的形状树。
    @property
    def params_shape_tree(self) -> Dict:
        return self._params_shape_tree

    # 设置模型的参数,如果模型未初始化则抛出异常。
    @params.setter
    def params(self, params: Union[Dict, FrozenDict]):
        # 如果模型未初始化,则不设置参数。
        if not self._is_initialized:
            raise ValueError(
                "`params` cannot be set from model when the model is created with `_do_init=False`. "
                "You store the params outside of the model."
            )

        # 如果参数是冻结字典,则解冻成普通字典。
        if isinstance(params, FrozenDict):
            params = unfreeze(params)
        
        # 检查参数是否包含所有必需的参数键。
        param_keys = set(flatten_dict(params).keys())
        if len(self.required_params - param_keys) > 0:
            raise ValueError(
                "Some parameters are missing. Make sure that `params` include the following "
                f"parameters {self.required_params - param_keys}"
            )
        
        # 设置模型的参数。
        self._params = params
    def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
        """
        Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
        """

        # 从 https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 中借用
        # 定义条件转换函数,用于将参数中的浮点值转换为指定的 dtype
        def conditional_cast(param):
            if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
                param = param.astype(dtype)
            return param

        # 如果 mask 为 None,则直接对 params 应用 tree_map 转换
        if mask is None:
            return jax.tree_util.tree_map(conditional_cast, params)

        # 将 params 展平为字典
        flat_params = flatten_dict(params)
        # 将 mask 也展平并获取其结构
        flat_mask, _ = jax.tree_util.tree_flatten(mask)

        # 遍历展平后的 mask 和 params 的键值对,并根据 mask 的值进行条件转换
        for masked, key in zip(flat_mask, sorted(flat_params.keys())):
            if masked:
                flat_params[key] = conditional_cast(flat_params[key])

        # 返回转换后的 params 的非展平版本
        return unflatten_dict(flat_params)

    def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
        the `params` in place.

        This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
        half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
                you want to cast, and should be `False` for those you want to skip.

        Examples:

        ```
        >>> from transformers import FlaxBertModel

        >>> # load model
        >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
        >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
        >>> model.params = model.to_bf16(model.params)
        >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale)
        >>> # then pass the mask as follows
        >>> from flax import traverse_util

        >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
        >>> flat_params = traverse_util.flatten_dict(model.params)
        >>> mask = {
        ...     path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
        ...     for path in flat_params
        ... }
        >>> mask = traverse_util.unflatten_dict(mask)
        >>> model.params = model.to_bf16(model.params, mask)
        ```
        """
        return self._cast_floating_to(params, jnp.bfloat16, mask)
    def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Cast the floating-point `parmas` to `jax.numpy.float32`. This method can be used to explicitly convert the
        model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
                you want to cast, and should be `False` for those you want to skip

        Examples:

        ```
        >>> from transformers import FlaxBertModel

        >>> # Download model and configuration from huggingface.co
        >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
        >>> # By default, the model params will be in fp32, to illustrate the use of this method,
        >>> # we'll first cast to fp16 and back to fp32
        >>> model.params = model.to_f16(model.params)
        >>> # now cast back to fp32
        >>> model.params = model.to_fp32(model.params)
        ```
        
        # 使用 jax 库中的 numpy 模块将浮点型参数 `params` 转换为单精度浮点数(float32)
        return self._cast_floating_to(params, jnp.float32, mask)
    # 将浮点数参数 `params` 转换为 `jax.numpy.float16` 类型。返回一个新的 `params` 树,不会原地修改 `params`。
    #
    # 在 GPU 上可以使用此方法显式地将模型参数转换为 float16 精度,以进行全半精度训练,或者将权重保存为 float16 以节省内存并提高速度。
    #
    # 参数:
    #     params (`Union[Dict, FrozenDict]`):
    #         模型参数的 PyTree 结构。
    #     mask (`Union[Dict, FrozenDict]`, 可选):
    #         与 `params` 结构相同的 PyTree。叶子节点应为布尔值,`True` 表示要转换的参数,`False` 表示要跳过的参数。
    #
    # 示例:
    #
    # ```
    # >>> from transformers import FlaxBertModel
    # >>>
    # >>> # 加载模型
    # >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
    # >>>
    # >>> # 默认情况下,模型参数将是 fp32 类型,要将其转换为 float16 类型
    # >>> model.params = model.to_fp16(model.params)
    # >>>
    # >>> # 如果不想转换某些参数(例如层归一化的偏置和缩放)
    # >>> # 则按以下方式传递 mask
    # >>> from flax import traverse_util
    # >>>
    # >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
    # >>> flat_params = traverse_util.flatten_dict(model.params)
    # >>> mask = {
    # ...     path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
    # ...     for path in flat_params
    # ... }
    # >>> mask = traverse_util.unflatten_dict(mask)
    # >>> model.params = model.to_fp16(model.params, mask)
    # ```
    def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
        # 调用内部方法 `_cast_floating_to` 将 `params` 中的浮点数类型转换为 `jnp.float16` 类型
        return self._cast_floating_to(params, jnp.float16, mask)
    # 定义一个类方法,用于加载 Flax 模型的权重数据
    def load_flax_weights(cls, resolved_archive_file):
        try:
            # 如果文件名以 ".safetensors" 结尾,使用 safe_load_file 加载状态
            if resolved_archive_file.endswith(".safetensors"):
                state = safe_load_file(resolved_archive_file)
                # 使用特定分隔符将状态字典展开
                state = unflatten_dict(state, sep=".")
            else:
                # 否则,使用二进制方式读取文件并将其反序列化为对象状态
                with open(resolved_archive_file, "rb") as state_f:
                    state = from_bytes(cls, state_f.read())
        # 捕获反序列化过程可能出现的异常
        except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
            try:
                # 尝试以文本模式打开文件,检查其内容以确定错误类型
                with open(resolved_archive_file) as f:
                    # 如果文件内容以 "version" 开头,可能是由于缺少 git-lfs 导致的错误
                    if f.read().startswith("version"):
                        raise OSError(
                            "You seem to have cloned a repository without having git-lfs installed. Please"
                            " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                            " folder you cloned."
                        )
                    else:
                        # 否则,抛出 ValueError 并将原始异常作为其原因
                        raise ValueError from e
            # 捕获可能的解码或数值错误
            except (UnicodeDecodeError, ValueError):
                # 抛出环境错误,指示无法将文件转换为 Flax 可反序列化对象
                raise EnvironmentError(f"Unable to convert {resolved_archive_file} to Flax deserializable object. ")

        # 返回加载的状态对象
        return state

    @classmethod
    def load_flax_sharded_weights(cls, shard_files):
        """
        This is the same as [`flax.serialization.from_bytes`](https://flax.readthedocs.io/en/latest/_modules/flax/serialization.html#from_bytes) but for a sharded checkpoint.

        This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
        loaded in the model.

        Args:
            shard_files (`List[str]`):
                The list of shard files to load.

        Returns:
            `Dict`: A nested dictionary of the model parameters, in the expected format for flax models : `{'model':
            {'params': {'...'}}}`.
        """

        # Load the index
        state_sharded_dict = {}

        for shard_file in shard_files:
            # load using msgpack utils
            try:
                with open(shard_file, "rb") as state_f:
                    state = from_bytes(cls, state_f.read())
            except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
                # Handle specific error cases
                with open(shard_file) as f:
                    if f.read().startswith("version"):
                        raise OSError(
                            "You seem to have cloned a repository without having git-lfs installed. Please"
                            " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                            " folder you cloned."
                        )
                    else:
                        raise ValueError from e
            except (UnicodeDecodeError, ValueError):
                # Raise an environment error if conversion fails
                raise EnvironmentError(f"Unable to convert {shard_file} to Flax deserializable object. ")

            # Flatten the state dictionary using '/' separator
            state = flatten_dict(state, sep="/")
            # Update the main dictionary with the flattened state
            state_sharded_dict.update(state)
            # Clean up the `state` variable from memory
            del state
            # Perform garbage collection to free up memory
            gc.collect()

        # Unflatten the state_sharded_dict to match the format of model.params
        return unflatten_dict(state_sharded_dict, sep="/")

    @classmethod
    def can_generate(cls) -> bool:
        """
        Returns whether this model can generate sequences with `.generate()`.

        Returns:
            `bool`: Whether this model can generate sequences with `.generate()`.
        """
        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
        # Alternatively, the model can also have a custom `generate` function.
        if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
            # If both conditions are met, return False indicating generation capability is not supported
            return False
        # If not, return True indicating generation capability is supported
        return True
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],  # 接受预训练模型名称或路径作为输入参数
        dtype: jnp.dtype = jnp.float32,  # 指定数据类型,默认为 jnp.float32
        *model_args,  # 其余位置参数
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,  # 预训练配置对象或其路径,可选参数
        cache_dir: Optional[Union[str, os.PathLike]] = None,  # 缓存目录的路径,可选参数
        ignore_mismatched_sizes: bool = False,  # 是否忽略大小不匹配的情况,默认为 False
        force_download: bool = False,  # 是否强制下载,默认为 False
        local_files_only: bool = False,  # 是否仅使用本地文件,默认为 False
        token: Optional[Union[str, bool]] = None,  # token 用于验证,可选参数
        revision: str = "main",  # 版本号,默认为 "main"
        **kwargs,  # 其余关键字参数
    ):
        """
        从预训练模型加载模型参数和配置。

        <Tip warning={true}>
        当前 API 处于实验阶段,未来版本可能会有一些轻微的更改。
        </Tip>

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                预训练模型的名称或路径。
            dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
                指定加载参数时使用的数据类型,默认为 jnp.float32。
            *model_args:
                其余位置参数,传递给具体模型加载函数。
            config (`PretrainedConfig`, `str`, `os.PathLike`, *optional*, defaults to `None`):
                预训练模型的配置对象或其路径,可选参数。
            cache_dir (`str` or `os.PathLike`, *optional*, defaults to `None`):
                缓存目录的路径,可选参数。
            ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
                是否忽略加载参数时大小不匹配的情况,默认为 False。
            force_download (`bool`, *optional*, defaults to `False`):
                是否强制重新下载模型,默认为 False。
            local_files_only (`bool`, *optional*, defaults to `False`):
                是否仅使用本地文件加载模型,默认为 False。
            token (`str` or `bool`, *optional*, defaults to `None`):
                token 用于验证下载的模型,可选参数。
            revision (`str`, *optional*, defaults to `"main"`):
                模型的版本号,默认为 "main"。
            **kwargs:
                其余关键字参数,传递给具体模型加载函数。
        """

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],  # 保存模型的目录路径
        params=None,  # 要保存的模型参数,默认为 None
        push_to_hub=False,  # 是否推送到模型 Hub,默认为 False
        max_shard_size="10GB",  # 最大的分片大小,默认为 "10GB"
        token: Optional[Union[str, bool]] = None,  # token 用于验证,可选参数
        safe_serialization: bool = False,  # 是否进行安全序列化,默认为 False
        **kwargs,  # 其余关键字参数
    ):
        """
        将当前模型保存到指定目录。

        Args:
            save_directory (`str` or `os.PathLike`):
                保存模型的目录路径。
            params:
                要保存的模型参数,默认为 None。
            push_to_hub (`bool`, *optional*, defaults to `False`):
                是否将模型推送到模型 Hub,默认为 False。
            max_shard_size (`str`, *optional*, defaults to `"10GB"`):
                最大的分片大小,默认为 "10GB"。
            token (`str` or `bool`, *optional*, defaults to `None`):
                token 用于验证保存操作,可选参数。
            safe_serialization (`bool`, *optional*, defaults to `False`):
                是否进行安全序列化,默认为 False。
            **kwargs:
                其余关键字参数,传递给具体保存函数。
        """

    @classmethod
    def register_for_auto_class(cls, auto_class="FlaxAutoModel"):
        """
        注册当前模型类到指定的自动加载类。仅用于自定义模型,因为库中的模型已经与自动加载类映射。

        <Tip warning={true}>
        当前 API 处于实验阶段,未来版本可能会有一些轻微的更改。
        </Tip>

        Args:
            auto_class (`str` or `type`, *optional*, defaults to `"FlaxAutoModel"`):
                要注册新模型的自动加载类名称或类型。
        """

        if not isinstance(auto_class, str):
            auto_class = auto_class.__name__

        import transformers.models.auto as auto_module

        if not hasattr(auto_module, auto_class):
            raise ValueError(f"{auto_class} is not a valid auto class.")

        cls._auto_class = auto_class
# 复制 FlaxPreTrainedModel 类中的 push_to_hub 方法,以确保我们修改的是其副本而非原始方法
FlaxPreTrainedModel.push_to_hub = copy_func(FlaxPreTrainedModel.push_to_hub)

# 如果 push_to_hub 方法已有文档字符串,则使用格式化字符串来更新其文档字符串,将对象类型、对象类名和对象文件类型作为参数插入
if FlaxPreTrainedModel.push_to_hub.__doc__ is not None:
    FlaxPreTrainedModel.push_to_hub.__doc__ = FlaxPreTrainedModel.push_to_hub.__doc__.format(
        object="model", object_class="FlaxAutoModel", object_files="model checkpoint"
    )


def overwrite_call_docstring(model_class, docstring):
    # 复制 model_class 的 __call__ 方法,以确保仅修改该函数的文档字符串
    model_class.__call__ = copy_func(model_class.__call__)
    # 删除现有的 __call__ 方法的文档字符串
    model_class.__call__.__doc__ = None
    # 设置正确的 __call__ 方法文档字符串,使用指定的 docstring
    model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)


def append_call_sample_docstring(
    model_class, checkpoint, output_type, config_class, mask=None, revision=None, real_checkpoint=None
):
    # 复制 model_class 的 __call__ 方法,以确保仅修改该函数的文档字符串
    model_class.__call__ = copy_func(model_class.__call__)
    # 使用 add_code_sample_docstrings 函数为 __call__ 方法添加代码示例的文档字符串,传入相关参数
    model_class.__call__ = add_code_sample_docstrings(
        checkpoint=checkpoint,
        output_type=output_type,
        config_class=config_class,
        model_cls=model_class.__name__,
        revision=revision,
        real_checkpoint=real_checkpoint,
    )(model_class.__call__)


def append_replace_return_docstrings(model_class, output_type, config_class):
    # 复制 model_class 的 __call__ 方法,以确保仅修改该函数的文档字符串
    model_class.__call__ = copy_func(model_class.__call__)
    # 使用 replace_return_docstrings 函数替换 __call__ 方法的返回值相关文档字符串,传入输出类型和配置类参数
    model_class.__call__ = replace_return_docstrings(
        output_type=output_type,
        config_class=config_class,
    )(model_class.__call__)

.\modeling_outputs.py

# 导入警告模块,用于可能的警告消息
import warnings
# 导入 dataclass 模块,用于定义数据类
from dataclasses import dataclass
# 导入 Optional 和 Tuple 类型提示
from typing import Optional, Tuple

# 导入 PyTorch 模块
import torch

# 从当前包中导入 ModelOutput 类
from .utils import ModelOutput

# 定义 BaseModelOutput 数据类,继承自 ModelOutput
@dataclass
class BaseModelOutput(ModelOutput):
    """
    模型输出的基类,可能包含隐藏状态和注意力。
    
    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            模型最后一层的隐藏状态序列。
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, 当 `output_hidden_states=True` 时返回或当 `config.output_hidden_states=True` 时返回):
            包含每层输出的元组 `torch.FloatTensor`(如果模型有嵌入层,则包含嵌入层输出),
            形状为 `(batch_size, sequence_length, hidden_size)`。

            模型在每层的隐藏状态以及可选的初始嵌入输出。
        attentions (`tuple(torch.FloatTensor)`, *optional*, 当 `output_attentions=True` 时返回或当 `config.output_attentions=True` 时返回):
            包含每层注意力权重的元组 `torch.FloatTensor`,
            形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。

            注意力权重经过 softmax 后的结果,在自注意力头中用于计算加权平均值。
    """

    # 最后的隐藏状态,默认为 None
    last_hidden_state: torch.FloatTensor = None
    # 隐藏状态的元组,可选,默认为 None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 注意力权重的元组,可选,默认为 None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


# 定义 BaseModelOutputWithNoAttention 数据类,继承自 ModelOutput
@dataclass
class BaseModelOutputWithNoAttention(ModelOutput):
    """
    模型输出的基类,仅包含潜在的隐藏状态。

    继承自 ModelOutput。
    """
    # 定义函数的参数说明和类型注解。`last_hidden_state`参数是一个 `torch.FloatTensor` 类型的张量,
    # 其形状为 `(batch_size, num_channels, height, width)`,表示模型最后一层的隐藏状态序列。
    # `hidden_states`参数是一个可选的元组类型,包含 `torch.FloatTensor` 类型的张量,
    # 形状为 `(batch_size, num_channels, height, width)`。这个元组用于存储模型每一层的隐藏状态,
    # 如果模型具有嵌入层,则还包括初始嵌入层的输出。
    
    last_hidden_state: torch.FloatTensor = None
    # 初始化 `last_hidden_state` 变量为 `None`,表示这个参数可以不提供具体值。
    
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 初始化 `hidden_states` 变量为 `None`,表示这个参数也可以不提供具体值。
    # 它是一个可选的元组类型,元组中的每个元素都是 `torch.FloatTensor` 类型的张量。
    # 如果提供了 `output_hidden_states=True` 或 `config.output_hidden_states=True`,
    # 这个元组会包含模型每一层的隐藏状态和可能的初始嵌入层的输出。
# 使用 dataclass 装饰器定义一个基础模型输出类,包含池化后的最后隐藏状态等信息
@dataclass
class BaseModelOutputWithPooling(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token) after further processing
            through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
            the classification token after processing through a linear layer and a tanh activation function. The linear
            layer weights are trained from the next sentence prediction (classification) objective during pretraining.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # 定义类变量,存储最后隐藏状态的张量
    last_hidden_state: torch.FloatTensor = None
    # 定义类变量,存储经过池化处理后的分类池化输出
    pooler_output: torch.FloatTensor = None
    # 定义类变量,存储模型隐藏状态的元组(每层的输出)
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 定义类变量,存储注意力权重的元组(每层的注意力权重)
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


# 使用 dataclass 装饰器定义一个基础模型输出类,包含池化后的最后隐藏状态,但不包含注意力权重信息
@dataclass
class BaseModelOutputWithPoolingAndNoAttention(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.
    """
    # `last_hidden_state`是模型最后一层的隐藏状态,形状为(batch_size, num_channels, height, width)
    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Sequence of hidden-states at the output of the last layer of the model.
    
    # `pooler_output`是经过空间维度池化操作后的最后一层隐藏状态,形状为(batch_size, hidden_size)
    pooler_output: torch.FloatTensor = None
    
    # `hidden_states`是一个元组,包含了模型每一层的隐藏状态输出,如果模型有嵌入层,还包括初始嵌入输出。
    # 每个张量的形状为(batch_size, num_channels, height, width)。
    # 可选的返回值,当`output_hidden_states=True`被传递或者`config.output_hidden_states=True`时返回。
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
# 使用 dataclass 装饰器声明一个数据类,表示带有过去键/值的模型输出
@dataclass
class BaseModelOutputWithPast(ModelOutput):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.

            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # 定义类成员变量,表示模型输出中的最后隐藏状态、过去键/值、隐藏状态、注意力权重
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


# 使用 dataclass 装饰器声明一个数据类,表示带有交叉注意力的模型输出
@dataclass
class BaseModelOutputWithCrossAttentions(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states and attentions.
    """
    # `last_hidden_state`是模型最后一层的输出隐藏状态,形状为(batch_size, sequence_length, hidden_size)
    last_hidden_state: torch.FloatTensor = None
    # `hidden_states`是一个元组,包含模型每一层的隐藏状态(如果模型有嵌入层,则包含初始嵌入输出),
    # 其形状为(batch_size, sequence_length, hidden_size)
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # `attentions`是一个元组,包含每一层的注意力权重张量,形状为(batch_size, num_heads, sequence_length, sequence_length),
    # 用于计算自注意力头中的加权平均值
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # `cross_attentions`是一个元组,包含解码器交叉注意力层的注意力权重张量,形状为(batch_size, num_heads, sequence_length, sequence_length),
    # 用于计算交叉注意力头中的加权平均值
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# 带有池化和交叉注意力的基础模型输出类,继承自ModelOutput
@dataclass
class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.
    模型输出的基础类,还包含最后隐藏状态的池化。

    """

    # 最后隐藏状态,类型为 torch.FloatTensor
    last_hidden_state: torch.FloatTensor = None
    # 池化输出,类型为 torch.FloatTensor
    pooler_output: torch.FloatTensor = None
    # 隐藏状态的元组,可选类型为 torch.FloatTensor
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 过去键/值的元组,可选类型为 Tuple[Tuple[torch.FloatTensor]]
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # 注意力权重的元组,可选类型为 torch.FloatTensor
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 交叉注意力权重的元组,可选类型为 torch.FloatTensor
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
    模型输出的基础类,可能还包含过去的键/值(用于加速顺序解码)。

    """
    # 定义函数参数 `last_hidden_state`,表示模型最后一层的隐藏状态,是一个形状为 `(batch_size, sequence_length, hidden_size)` 的张量。
    # 如果使用了 `past_key_values`,则输出的是形状为 `(batch_size, 1, hidden_size)` 的序列最后隐藏状态。
    last_hidden_state: torch.FloatTensor = None
    
    # 定义函数可选参数 `past_key_values`,是一个元组,包含了预计算的隐藏状态键值对(在自注意力块中的键和值),长度为 `config.n_layers`,每个元组包含两个张量。
    # 如果 `config.is_encoder_decoder=True`,还包括两个形状为 `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)` 的张量。
    # 用于加速顺序解码过程。
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    
    # 定义函数可选参数 `hidden_states`,是一个元组,包含了模型每一层的隐藏状态。
    # 如果模型有嵌入层,第一个张量是嵌入层的输出;每个张量的形状为 `(batch_size, sequence_length, hidden_size)`。
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    
    # 定义函数可选参数 `attentions`,是一个元组,包含了每一层的注意力权重。
    # 每个张量的形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
    # 这些权重在自注意力头中的注意力 softmax 之后,用于计算加权平均值。
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    
    # 定义函数可选参数 `cross_attentions`,是一个元组,包含了解码器交叉注意力层的注意力权重。
    # 每个张量的形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
    # 这些权重在交叉注意力头中的注意力 softmax 之后,用于计算加权平均值。
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class MoECausalLMOutputWithPast(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs as well as Mixture of Expert's router hidden
    states terms, to train a MoE model.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        z_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            z_loss for the sparse modules.
        aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            aux_loss for the sparse modules.
        router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.

            Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse
            modules.
    """

    # loss 属性,表示语言模型损失,用于下一个标记的预测
    loss: Optional[torch.FloatTensor] = None

    # logits 属性,表示语言模型头部的预测分数(SoftMax 前的每个词汇标记的分数)
    logits: torch.FloatTensor = None

    # past_key_values 属性,存储预先计算的隐藏状态(注意力机制的键值对),用于加速顺序解码
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None

    # hidden_states 属性,存储每一层的模型隐藏状态,包括可选的初始嵌入输出
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None

    # attentions 属性,存储每一层的注意力权重,用于计算自注意力头部的加权平均值
    attentions: Optional[Tuple[torch.FloatTensor]] = None

    # z_loss 属性,用于稀疏模块的 z_loss
    z_loss: Optional[torch.FloatTensor] = None

    # aux_loss 属性,用于稀疏模块的 aux_loss
    aux_loss: Optional[torch.FloatTensor] = None

    # router_logits 属性,存储编码器模型的路由器 logits,用于计算稀疏模块的辅助损失和 z_loss
    router_logits: Optional[Tuple[torch.FloatTensor]] = None
    # 定义一个可选的元组 attentions,元组中包含了多个 torch.FloatTensor 类型的张量,初始值为 None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    
    # 定义一个 torch.FloatTensor 类型的张量 z_loss,初始值为 None
    z_loss: torch.FloatTensor = None
    
    # 定义一个 torch.FloatTensor 类型的张量 aux_loss,初始值为 None
    aux_loss: torch.FloatTensor = None
    
    # 定义一个可选的单元素元组 router_logits,元组中包含了一个 torch.FloatTensor 类型的张量,初始值为 None
    router_logits: Optional[Tuple[torch.FloatTensor]] = None
# 用于定义模型输出的数据类,继承自ModelOutput类
@dataclass
class MoEModelOutput(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
            模型最后一层的输出隐藏状态序列。

        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
            每一层模型的隐藏状态输出,包括初始嵌入层的输出(如果存在)。

        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
            经过注意力softmax后的注意力权重,用于计算自注意力头中的加权平均值。

        router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.

            Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary
            loss and the z_loss for Mixture of Experts models.
            由MoE路由器计算得到的原始路由器概率,用于计算混合专家模型的辅助损失和z_loss。
    """

    # 定义最后一个隐藏状态,默认为None
    last_hidden_state: torch.FloatTensor = None
    # 定义隐藏状态的元组,可选,当output_hidden_states=True或config.output_hidden_states=True时返回
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 定义注意力的元组,可选,当output_attentions=True或config.output_attentions=True时返回
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 定义路由器概率的元组,可选,当output_router_probs=True和config.add_router_probs=True或config.output_router_probs=True时返回
    router_probs: Optional[Tuple[torch.FloatTensor]] = None


# 用于定义模型输出的数据类,继承自ModelOutput类,并且包含过去隐藏状态和注意力
@dataclass
class MoeModelOutputWithPast(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states and attentions.
    """
    # 定义输入参数 `last_hidden_state`,类型为 `torch.FloatTensor`,形状为 `(batch_size, sequence_length, hidden_size)`
    last_hidden_state: torch.FloatTensor = None
    
    # 定义输入参数 `past_key_values`,类型为 `Optional[Tuple[Tuple[torch.FloatTensor]]]`,可选参数,
    # 当使用 `use_cache=True` 或 `config.use_cache=True` 时返回,包含预先计算的隐藏状态(在自注意力块中的键和值)
    # 如果 `config.is_encoder_decoder=True`,还包含交叉注意力块中的隐藏状态
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    
    # 定义输入参数 `hidden_states`,类型为 `Optional[Tuple[torch.FloatTensor, ...]]`,可选参数,
    # 当使用 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回,
    # 包含模型每一层的隐藏状态,以及可能的初始嵌入层输出
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    
    # 定义输入参数 `attentions`,类型为 `Optional[Tuple[torch.FloatTensor, ...]]`,可选参数,
    # 当使用 `output_attentions=True` 或 `config.output_attentions=True` 时返回,
    # 包含注意力 softmax 后的注意力权重,用于在自注意力头中计算加权平均值
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    
    # 定义输入参数 `router_logits`,类型为 `Optional[Tuple[torch.FloatTensor]]`,可选参数,
    # 当使用 `output_router_probs=True` 和 `config.add_router_probs=True` 时返回,
    # 包含 MoE 路由器计算的原始路由器 logits(经 softmax 处理后),用于混合专家模型的辅助损失计算
    router_logits: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MoeCausalLMOutputWithPast(ModelOutput):
    """
    Base class for causal language model (or autoregressive) with mixture of experts outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).

        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).

        aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            Auxiliary loss for the sparse modules.

        router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.

            Raw router logits (post-softmax) computed by MoE routers, used for computing auxiliary loss in Mixture of Experts models.

        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, each tuple containing 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Pre-computed hidden states (keys and values in self-attention blocks) for speeding up sequential decoding.

        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for embedding layer output if present, plus one for each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden states of the model at each layer's output, including optional initial embedding outputs.

        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the softmax operation, used for computing weighted averages in self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    aux_loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 定义了一个变量 router_logits,类型为 Optional[Tuple[torch.FloatTensor]],初始值为 None
@dataclass
class MoEModelOutputWithPastAndCrossAttentions(ModelOutput):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding) as well as
    Mixture of Expert's router hidden states terms, to train a MoE model.

    """

    last_hidden_state: torch.FloatTensor = None  # 最后一个隐藏状态的张量,默认为None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # 可选的过去键/值对的元组,用于加速序列解码
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # 可选的隐藏状态的元组
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # 可选的注意力分布的元组
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # 可选的交叉注意力分布的元组
    router_probs: Optional[Tuple[torch.FloatTensor]] = None  # 可选的路由器概率的元组


@dataclass
class Seq2SeqModelOutput(ModelOutput):
    """
    Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential
    decoding.

    """

    last_hidden_state: torch.FloatTensor = None  # 最后一个隐藏状态的张量,默认为None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # 可选的过去键/值对的元组,用于加速序列解码
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # 可选的解码器隐藏状态的元组
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # 可选的解码器注意力分布的元组
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # 可选的交叉注意力分布的元组
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None  # 可选的编码器最后一个隐藏状态的张量
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # 可选的编码器隐藏状态的元组
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # 可选的编码器注意力分布的元组


@dataclass
class Seq2SeqMoEModelOutput(ModelOutput):
    """
    Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential
    decoding.

    """

    last_hidden_state: torch.FloatTensor = None  # 最后一个隐藏状态的张量,默认为None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # 可选的过去键/值对的元组,用于加速序列解码
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # 可选的解码器隐藏状态的元组
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # 可选的解码器注意力分布的元组
    decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None  # 可选的解码器路由器逻辑概率的元组
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # 可选的交叉注意力分布的元组
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None  # 可选的编码器最后一个隐藏状态的张量
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # 可选的编码器隐藏状态的元组
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # 可选的编码器注意力分布的元组
    encoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None  # 可选的编码器路由器逻辑概率的元组


@dataclass
class CausalLMOutput(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.
    """

    # 这里没有定义任何字段,但作为基类提供了一个用于因果语言模型输出的基础类
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
            语言建模损失(用于下一个标记预测),是一个形状为 `(1,)` 的 `torch.FloatTensor`,当提供 `labels` 时返回。
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            语言建模头部的预测分数(在 SoftMax 之前每个词汇标记的得分),形状为 `(batch_size, sequence_length, config.vocab_size)` 的 `torch.FloatTensor`。
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
            模型在每层输出的隐藏状态,以及可选的初始嵌入输出,形状为 `(batch_size, sequence_length, hidden_size)` 的 `torch.FloatTensor` 元组。
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
            注意力 softmax 后的注意力权重,用于计算自注意力头部中的加权平均,形状为 `(batch_size, num_heads, sequence_length, sequence_length)` 的 `torch.FloatTensor` 元组。
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# 使用 dataclass 装饰器定义一个类,用于表示因果语言模型(或自回归模型)的输出结果,继承自 ModelOutput 类。
@dataclass
class CausalLMOutputWithPast(ModelOutput):
    """
    因果语言模型(或自回归模型)输出的基类。

    Args:
        loss (`torch.FloatTensor`,形状为 `(1,)`,*可选*,当提供 `labels` 参数时返回):
            语言建模的损失(用于下一个标记的预测)。
        logits (`torch.FloatTensor`,形状为 `(batch_size, sequence_length, config.vocab_size)`):
            语言建模头部的预测分数(每个词汇标记的分数,在 SoftMax 之前)。
        past_key_values (`tuple(tuple(torch.FloatTensor))`,*可选*,当传递 `use_cache=True` 或 `config.use_cache=True` 时返回):
            包含预先计算的隐藏状态(自注意力块中的键和值),可用于加速顺序解码。
            是一个长度为 `config.n_layers` 的元组,每个元组包含 2 个形状为 `(batch_size, num_heads, sequence_length, embed_size_per_head)` 的张量。
        hidden_states (`tuple(torch.FloatTensor)`,*可选*,当传递 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回):
            包含模型在每一层输出的隐藏状态张量的元组(如果模型有嵌入层,则包含初始嵌入输出),
            形状为 `(batch_size, sequence_length, hidden_size)`。
        attentions (`tuple(torch.FloatTensor)`,*可选*,当传递 `output_attentions=True` 或 `config.output_attentions=True` 时返回):
            自注意力头部中注意力 softmax 后的注意力权重张量的元组(每层一个),形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
    """

    # 以下是类的字段定义
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


# 使用 dataclass 装饰器定义另一个类,用于表示具有交叉注意力的因果语言模型(或自回归模型)的输出结果,继承自 ModelOutput 类。
@dataclass
class CausalLMOutputWithCrossAttentions(ModelOutput):
    """
    因果语言模型(或自回归模型)输出的基类,具有交叉注意力。

    这个类继承自 ModelOutput。
    """

    # 注意:这里的类定义未完全提供,根据文档字符串需要添加额外的字段和解释。
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Cross attentions weights after the attention softmax, used to compute the weighted average in the
            cross-attention heads.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `torch.FloatTensor` tuples of length `config.n_layers`, with each tuple containing the cached key,
            value states of the self-attention and the cross-attention layers if model is used in encoder-decoder
            setting. Only relevant if `config.is_decoder = True`.

            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
@dataclass
class SequenceClassifierOutputWithPast(ModelOutput):
    """
    Base class for outputs of sequence classification models that also include past key values,
    hidden states, and attentions.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class MaskedLMOutput(ModelOutput):
    """
    Base class for outputs of masked language models.

    This class inherits `ModelOutput`, indicating it provides standard output for models.

    """
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Masked language modeling (MLM) loss.
            掩码语言建模(MLM)损失,当提供`labels`时返回此值。
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            语言建模头的预测分数(SoftMax之前的每个词汇标记的分数)。
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer,
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
            模型在每一层输出的隐藏状态,以及可选的初始嵌入输出。
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
            注意力权重,经过注意力SoftMax后的值,用于计算自注意力头中的加权平均值。
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class Seq2SeqLMOutput(ModelOutput):
    """
    Base class for sequence-to-sequence language models outputs.

    """

    # 损失值,可选的浮点张量,用于表示模型的损失
    loss: Optional[torch.FloatTensor] = None
    # 输出的对数概率值,用于表示模型生成的对数概率
    logits: torch.FloatTensor = None
    # 过去的键值,可选的嵌套元组,用于存储过去的键值
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # 解码器隐藏状态,可选的浮点张量元组,表示解码器的隐藏状态
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 解码器注意力权重,可选的浮点张量元组,表示解码器的注意力权重
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 交叉注意力权重,可选的浮点张量元组,表示编码器-解码器的交叉注意力权重
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 编码器最后的隐藏状态,可选的浮点张量,表示编码器的最后隐藏状态
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # 编码器隐藏状态,可选的浮点张量元组,表示编码器的隐藏状态
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 编码器注意力权重,可选的浮点张量元组,表示编码器的注意力权重
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class Seq2SeqMoEOutput(ModelOutput):
    """
    Base class for sequence-to-sequence language models outputs.

    """

    # 损失值,可选的浮点张量,用于表示模型的损失
    loss: Optional[torch.FloatTensor] = None
    # 输出的对数概率值,用于表示模型生成的对数概率
    logits: torch.FloatTensor = None
    # 编码器 Z 损失,用于表示编码器的 Z 损失
    encoder_z_loss: torch.FloatTensor = None
    # 解码器 Z 损失,用于表示解码器的 Z 损失
    decoder_z_loss: torch.FloatTensor = None
    # 编码器辅助损失,用于表示编码器的辅助损失
    encoder_aux_loss: torch.FloatTensor = None
    # 解码器辅助损失,用于表示解码器的辅助损失
    decoder_aux_loss: torch.FloatTensor = None
    # 过去的键值,可选的嵌套元组,用于存储过去的键值
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # 解码器隐藏状态,可选的浮点张量元组,表示解码器的隐藏状态
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 解码器注意力权重,可选的浮点张量元组,表示解码器的注意力权重
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 解码器路由器对数概率,可选的浮点张量,表示解码器的路由器对数概率
    decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None
    # 交叉注意力权重,可选的浮点张量元组,表示编码器-解码器的交叉注意力权重
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 编码器最后的隐藏状态,可选的浮点张量,表示编码器的最后隐藏状态
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # 编码器隐藏状态,可选的浮点张量元组,表示编码器的隐藏状态
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 编码器注意力权重,可选的浮点张量元组,表示编码器的注意力权重
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 编码器路由器对数概率,可选的浮点张量,表示编码器的路由器对数概率
    encoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class NextSentencePredictorOutput(ModelOutput):
    """
    Base class for outputs of models predicting if two sentences are consecutive or not.
    # 定义 loss 变量,用于存储下一个序列预测(分类)的损失值,类型为 torch.FloatTensor,可选参数,当提供 `next_sentence_label` 时返回。
    loss: Optional[torch.FloatTensor] = None
    # 定义 logits 变量,用于存储下一个序列预测(分类)头部的预测分数,形状为 `(batch_size, 2)` 的 torch.FloatTensor。
    logits: torch.FloatTensor = None
    # 定义 hidden_states 变量,用于存储模型每一层的隐藏状态输出,类型为元组 `Tuple[torch.FloatTensor, ...]`,可选参数,当 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回。
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 定义 attentions 变量,用于存储注意力权重输出,类型为元组 `Tuple[torch.FloatTensor, ...]`,可选参数,当 `output_attentions=True` 或 `config.output_attentions=True` 时返回。
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# 定义一个数据类,用于表示序列分类器模型的输出结果
@dataclass
class SequenceClassifierOutput(ModelOutput):
    """
    Base class for outputs of sentence classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
            分类模型的损失值(如果提供了`labels`):一个形状为`(1,)`的`torch.FloatTensor`,在提供`labels`时返回。
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
            分类(或回归,如果`config.num_labels==1`)得分(SoftMax之前)的`torch.FloatTensor`,形状为`(batch_size, config.num_labels)`。
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
            模型每一层的输出的隐藏状态,以及可选的初始嵌入输出。形状为`(batch_size, sequence_length, hidden_size)`。
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
            注意力权重(经过注意力SoftMax后的)的元组,用于计算自注意力头中的加权平均值。形状为`(batch_size, num_heads, sequence_length, sequence_length)`。
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


# 定义一个数据类,用于表示序列到序列的句子分类器模型的输出结果
@dataclass
class Seq2SeqSequenceClassifierOutput(ModelOutput):
    """
    Base class for outputs of sequence-to-sequence sentence classification models.

    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


# 定义一个数据类,用于表示多选模型的输出结果
@dataclass
class MultipleChoiceModelOutput(ModelOutput):
    """
    Base class for outputs of multiple choice models.

    """
    """
    Args:
        loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided):
            分类损失值。
            如果提供了`labels`,则返回此损失值。

        logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
            `num_choices` 是输入张量的第二个维度。
            分类分数(SoftMax 之前的值)。

        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            一个元组,包含 `torch.FloatTensor` 的张量。
            第一个张量是模型嵌入层的输出(如果存在),每一层输出的张量的形状为 `(batch_size, sequence_length, hidden_size)`。

            模型在每一层输出的隐藏状态,加上可选的初始嵌入输出。

        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            一个元组,包含 `torch.FloatTensor` 的张量。
            每一层的注意力权重张量的形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。

            经过注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# 用于描述令牌分类模型输出的基础类
@dataclass
class TokenClassifierOutput(ModelOutput):
    """
    Base class for outputs of token classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) :
            Classification loss.
            分类损失,当提供 `labels` 参数时返回。
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
            分类分数(SoftMax 之前的结果)。
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
            模型在每一层输出的隐藏状态,以及可选的初始嵌入层输出。
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
            注意力权重,经过注意力 SoftMax 后的结果,用于计算自注意力头中的加权平均值。
    """

@dataclass
class QuestionAnsweringModelOutput(ModelOutput):
    """
    Base class for outputs of question answering models.
    用于描述问答模型输出的基础类。

    This class is currently empty but can be extended with specific outputs of QA models.
    该类目前为空,但可以通过扩展以包含问答模型的特定输出。
    """
    # 定义函数参数和返回值的文档字符串,描述了函数的输入和输出

    loss: Optional[torch.FloatTensor] = None
    # 可选的损失张量,当提供了 `labels` 参数时返回

    start_logits: torch.FloatTensor = None
    # 开始位置的得分张量,形状为 `(batch_size, sequence_length)`

    end_logits: torch.FloatTensor = None
    # 结束位置的得分张量,形状为 `(batch_size, sequence_length)`

    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 可选的隐藏状态元组,包含每层输出的张量,形状为 `(batch_size, sequence_length, hidden_size)`
    # 如果模型有嵌入层,则还包含初始嵌入输出

    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 可选的注意力权重元组,包含每层注意力权重的张量
    # 形状为 `(batch_size, num_heads, sequence_length, sequence_length)`
    # 用于计算自注意力头中的加权平均值的注意力 softmax 后的注意力权重
# 定义了一个数据类,用于存储序列到序列问答模型的输出结果
@dataclass
class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
    """
    Base class for outputs of sequence-to-sequence question answering models.

    """

    # 损失值,如果存在的话,类型为 torch.FloatTensor
    loss: Optional[torch.FloatTensor] = None
    # 开始位置的预测 logits,类型为 torch.FloatTensor
    start_logits: torch.FloatTensor = None
    # 结束位置的预测 logits,类型为 torch.FloatTensor
    end_logits: torch.FloatTensor = None
    # 过去的键值,类型为可选的元组,包含了一系列 torch.FloatTensor
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # 解码器的隐藏状态,类型为可选的元组,包含了一系列 torch.FloatTensor
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 解码器的注意力权重,类型为可选的元组,包含了一系列 torch.FloatTensor
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 交叉注意力权重,类型为可选的元组,包含了一系列 torch.FloatTensor
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 编码器最后的隐藏状态,类型为可选的 torch.FloatTensor
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # 编码器的隐藏状态,类型为可选的元组,包含了一系列 torch.FloatTensor
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 编码器的注意力权重,类型为可选的元组,包含了一系列 torch.FloatTensor
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


# 定义了一个数据类,用于存储语义分割模型的输出结果
@dataclass
class SemanticSegmenterOutput(ModelOutput):
    """
    Base class for outputs of semantic segmentation models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
            Classification scores for each pixel.

            <Tip warning={true}>

            The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
            to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the
            original image size as post-processing. You should always check your logits shape and resize as needed.

            </Tip>

        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, patch_size, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # 损失值,如果存在的话,类型为 torch.FloatTensor
    loss: Optional[torch.FloatTensor] = None
    # 分类得分 logits,类型为 torch.FloatTensor,形状为 (batch_size, config.num_labels, logits_height, logits_width)
    logits: torch.FloatTensor = None
    # 隐藏状态,类型为可选的元组,包含了一系列 torch.FloatTensor
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 注意力权重,类型为可选的元组,包含了一系列 torch.FloatTensor
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


# 定义了一个数据类,用于存储图像分类模型的输出结果
@dataclass
class ImageClassifierOutput(ModelOutput):
    """
    Base class for outputs of image classification models.
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            分类(或者回归,如果 `config.num_labels==1`)的损失值。
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            分类(或者回归,如果 `config.num_labels==1`)的分数(SoftMax 之前)。
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            元组类型的 `torch.FloatTensor`,包含模型在每个阶段输出的隐藏状态(特征映射),形状为 `(batch_size, sequence_length, hidden_size)`。如果模型包含嵌入层,则第一个张量表示嵌入的输出。
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            元组类型的 `torch.FloatTensor`,包含模型的注意力权重,形状为 `(batch_size, num_heads, patch_size, sequence_length)`。这些权重经过注意力 SoftMax 后得到,用于计算自注意力头中的加权平均值。
@dataclass
class ImageClassifierOutputWithNoAttention(ModelOutput):
    """
    Base class for outputs of image classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            分类模型的损失值(如果提供了`labels`参数)。
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            分类模型的输出分数(在经过 SoftMax 之前)。
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            模型每个阶段的隐藏状态(也称为特征图),形状为 `(batch_size, num_channels, height, width)`。

    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class DepthEstimatorOutput(ModelOutput):
    """
    Base class for outputs of depth estimation models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            深度估计模型的损失值(如果提供了`labels`参数)。
        predicted_depth (`torch.FloatTensor` of shape `(batch_size, height, width)`):
            每个像素预测的深度值。

        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            模型每个层的隐藏状态(也称为特征图),形状为 `(batch_size, num_channels, height, width)`。

            每个层的输出以及可选的初始嵌入输出。
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            每个层的注意力权重,形状为 `(batch_size, num_heads, patch_size, sequence_length)`。

            注意力softmax后的权重,用于计算自注意力头中的加权平均值。
    """

    loss: Optional[torch.FloatTensor] = None
    predicted_depth: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class ImageSuperResolutionOutput(ModelOutput):
    """
    Base class for outputs of image super resolution models.
    """
    # 定义函数的参数和返回值的类型注解
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            重建损失,当提供`labels`时返回。
        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            重建的图像,可能是上采样后的结果。
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            一个元组,包含`torch.FloatTensor`类型的张量:
            - 如果模型有嵌入层,则为形状为`(batch_size, sequence_length, hidden_size)`的张量;
            - 每个阶段输出的隐藏状态(也称为特征图)。
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            一个元组,包含`torch.FloatTensor`类型的张量:
            - 每层的注意力权重,形状为`(batch_size, num_heads, patch_size, sequence_length)`。
            
            注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
    
    loss: Optional[torch.FloatTensor] = None
        # 默认值为 None 的可选项,类型为 torch.FloatTensor
    reconstruction: torch.FloatTensor = None
        # 默认值为 None 的 torch.FloatTensor 类型的变量
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
        # 默认值为 None 的可选项,类型为元组,包含 torch.FloatTensor 类型的张量
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
        # 默认值为 None 的可选项,类型为元组,包含 torch.FloatTensor 类型的张量
# 使用 `dataclass` 装饰器定义一个数据类,用于表示 Wav2Vec2 模型的基础输出。
@dataclass
class Wav2Vec2BaseModelOutput(ModelOutput):
    """
    Base class for models that have been trained with the Wav2Vec2 loss objective.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
            Sequence of extracted feature vectors of the last convolutional layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # 定义类变量 `last_hidden_state`,表示模型输出的最后一层隐藏状态
    last_hidden_state: torch.FloatTensor = None
    # 定义类变量 `extract_features`,表示模型输出的最后一个卷积层的特征向量序列
    extract_features: torch.FloatTensor = None
    # 定义类变量 `hidden_states`,表示模型每一层的隐藏状态的元组,可选返回项
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 定义类变量 `attentions`,表示注意力权重的元组,可选返回项,用于自注意力头中的加权平均计算
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


# 使用 `dataclass` 装饰器定义一个数据类,表示 `Wav2Vec2ForXVector` 的输出类型。
@dataclass
class XVectorOutput(ModelOutput):
    """
    Output type of [`Wav2Vec2ForXVector`].
    """

    # 此数据类暂未定义任何具体的输出内容,保留空的定义。
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            分类损失。
            如果提供了 `labels`,则返回分类损失。
        logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
            AMSoftmax 前的分类隐藏状态。
            用于 AMSoftmax 前的分类隐藏状态。
        embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
            用于基于向量相似性检索的话语嵌入。
            用于基于向量相似性检索的话语嵌入。
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            模型每层输出的隐藏状态。
            当传递 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回。
            元组包含了每层的 `torch.FloatTensor`,形状为 `(batch_size, sequence_length, hidden_size)`。
            包括每层的隐藏状态以及初始嵌入输出。
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            自注意力权重。
            当传递 `output_attentions=True` 或 `config.output_attentions=True` 时返回。
            元组包含了每层的 `torch.FloatTensor`,形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
            在注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    embeddings: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# 使用 dataclass 装饰器定义一个名为 `BackboneOutput` 的数据类,它继承自 `ModelOutput` 类
@dataclass
class BackboneOutput(ModelOutput):
    """
    Base class for outputs of backbones.

    Args:
        feature_maps (`tuple(torch.FloatTensor)` of shape `(batch_size, num_channels, height, width)`):
            Feature maps of the stages.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, num_channels, height, width)`,
            depending on the backbone.

            Hidden-states of the model at the output of each stage plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Only applicable if the backbone uses attention.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # 定义特征图的属性,类型为元组,包含了每个阶段的特征图
    feature_maps: Tuple[torch.FloatTensor] = None
    # 定义隐藏状态的属性,类型为可选的元组,包含了每个阶段的隐藏状态
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # 定义注意力权重的属性,类型为可选的元组,包含了每个层的注意力权重
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


# 使用 dataclass 装饰器定义一个名为 `BaseModelOutputWithPoolingAndProjection` 的数据类,它继承自 `ModelOutput` 类
@dataclass
class BaseModelOutputWithPoolingAndProjection(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.
    """
    # 定义函数参数和它们的类型注释,描述了函数所接收的不同类型的输入数据
    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            模型最后一层输出的隐藏状态序列。
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            经过附加预训练任务处理后的序列第一个标记(分类标记)的最后一层隐藏状态。
            例如,在BERT系列模型中,这是经过线性层和tanh激活函数处理后的分类标记。
            线性层的权重是从预训练过程中的下一句预测(分类)目标中训练得到的。
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            模型每一层的隐藏状态序列的元组。
            每个元素的形状为 `(batch_size, sequence_length, hidden_size)`,包括可选的初始嵌入层输出。
            当 `output_hidden_states=True` 传递给模型或者 `config.output_hidden_states=True` 时返回。
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            模型每一层的注意力权重的元组。
            每个元素的形状为 `(batch_size, num_heads, sequence_length, sequence_length)`,
            用于计算自注意力头中的加权平均值。
            当 `output_attentions=True` 传递给模型或者 `config.output_attentions=True` 时返回。
        projection_state (`tuple(torch.FloatTensor)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            投影层之前的文本嵌入的元组。
            每个元素的形状为 `(batch_size, config.project_dim)`,
            用于模拟教师编码器的最后隐藏状态。
@dataclass
class Seq2SeqSpectrogramOutput(ModelOutput):
    """
    Base class for sequence-to-sequence spectrogram outputs.

    """

    loss: Optional[torch.FloatTensor] = None  # 损失值,用于存储模型输出的损失
    spectrogram: torch.FloatTensor = None  # 频谱图数据,存储模型生成的频谱图
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # 过去的键值对,用于存储可加速顺序解码的隐藏状态
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # 解码器的隐藏状态列表
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # 解码器的注意力权重列表
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # 交叉注意力权重列表
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None  # 编码器的最后隐藏状态
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # 编码器的隐藏状态列表
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # 编码器的注意力权重列表


@dataclass
class Seq2SeqTSModelOutput(ModelOutput):
    """
    Base class for time series model's encoder outputs that also contains pre-computed hidden states that can speed up
    sequential decoding.

    """

    last_hidden_state: torch.FloatTensor = None  # 最后的隐藏状态,存储编码器最后的隐藏状态
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # 过去的键值对,用于存储可加速顺序解码的隐藏状态
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # 解码器的隐藏状态列表
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # 解码器的注意力权重列表
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # 交叉注意力权重列表
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None  # 编码器的最后隐藏状态
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # 编码器的隐藏状态列表
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # 编码器的注意力权重列表
    loc: Optional[torch.FloatTensor] = None  # 位置参数,用于存储预测分布的位置参数
    scale: Optional[torch.FloatTensor] = None  # 尺度参数,用于存储预测分布的尺度参数
    static_features: Optional[torch.FloatTensor] = None  # 静态特征,用于存储与时间序列模型相关的静态特征


@dataclass
class Seq2SeqTSPredictionOutput(ModelOutput):
    """
    Base class for time series model's decoder outputs that also contain the loss as well as the parameters of the
    chosen distribution.

    """

    loss: Optional[torch.FloatTensor] = None  # 损失值,用于存储模型输出的损失
    params: Optional[Tuple[torch.FloatTensor]] = None  # 参数,用于存储所选分布的参数
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # 过去的键值对,用于存储可加速顺序解码的隐藏状态
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # 解码器的隐藏状态列表
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # 解码器的注意力权重列表
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # 交叉注意力权重列表
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None  # 编码器的最后隐藏状态
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # 编码器的隐藏状态列表
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # 编码器的注意力权重列表
    loc: Optional[torch.FloatTensor] = None  # 位置参数,用于存储预测分布的位置参数
    scale: Optional[torch.FloatTensor] = None  # 尺度参数,用于存储预测分布的尺度参数
    static_features: Optional[torch.FloatTensor] = None  # 静态特征,用于存储与时间序列模型相关的静态特征


@dataclass
class SampleTSPredictionOutput(ModelOutput):
    """
    Base class for time series model's predictions outputs that contains the sampled values from the chosen
    distribution.

    Args:
        sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length)` or `(batch_size, num_samples, prediction_length, input_size)`):
            Sampled values from the chosen distribution.

    """

    # 该类用于存储时间序列模型的预测输出,包括从所选分布中采样得到的值
    # 声明一个变量 sequences,类型为 torch 的 FloatTensor,初始值为 None
    sequences: torch.FloatTensor = None
# 使用 dataclass 装饰器定义 MaskedImageModelingOutput 类,用于封装掩码图像完成/修补模型的输出结果
@dataclass
class MaskedImageModelingOutput(ModelOutput):
    """
    Base class for outputs of masked image completion / in-painting models.
    掩码图像完成/修补模型输出结果的基类。

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
            Reconstruction loss.
            重建损失,当提供 `bool_masked_pos` 时返回。
        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
           Reconstructed / completed images.
           重建/完成的图像。

        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
        when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
            (also called feature maps) of the model at the output of each stage.
            隐藏状态,模型在每个阶段输出的隐藏状态(特征图)元组。

        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when
        `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
            注意力权重,经过注意力 softmax 后的权重,用于计算自注意力头中的加权平均值。
    """

    # 定义 loss 属性,类型为 torch.FloatTensor,可选,表示重建损失,默认为 None
    loss: Optional[torch.FloatTensor] = None

    # 定义 reconstruction 属性,类型为 torch.FloatTensor,表示重建/完成的图像
    reconstruction: torch.FloatTensor = None

    # 定义 hidden_states 属性,类型为 tuple(torch.FloatTensor),可选,表示隐藏状态的元组
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None

    # 定义 attentions 属性,类型为 tuple(torch.FloatTensor),可选,表示注意力权重的元组
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None

    # @property 装饰器,定义 logits 属性,用于获取输出的最终结果
    @property
    def logits(self):
        # 发出警告,提醒 logits 属性在 Transformers 版本 5 中将被移除,请使用 reconstruction 属性获取最终输出
        warnings.warn(
            "logits attribute is deprecated and will be removed in version 5 of Transformers."
            " Please use the reconstruction attribute to retrieve the final output instead.",
            FutureWarning,
        )
        # 返回 reconstruction 属性作为最终输出
        return self.reconstruction
posted @ 2024-06-29 17:00  绝不原创的飞龙  阅读(40)  评论(0编辑  收藏  举报