Transformers-源码解析-十七-

Transformers 源码解析(十七)

.\models\bigbird_pegasus\__init__.py

# 版权声明及许可信息
#
# 版权所有 2021 年 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证 2.0 版本("许可证")许可;
# 除非符合许可证,否则不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,软件分发根据"按原样"的基础分发,
# 没有任何明示或暗示的担保或条件。
# 有关详细信息,请参阅许可证。

# 导入必要的类型检查模块
from typing import TYPE_CHECKING

# 导入可选依赖未找到异常和懒加载模块
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

# 定义模块的导入结构字典
_import_structure = {
    "configuration_bigbird_pegasus": [
        "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "BigBirdPegasusConfig",
        "BigBirdPegasusOnnxConfig",
    ],
}

# 检查是否存在 Torch 可用,如果不可用则抛出可选依赖未找到异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果 Torch 可用,则扩展导入结构字典
    _import_structure["modeling_bigbird_pegasus"] = [
        "BIGBIRD_PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST",
        "BigBirdPegasusForCausalLM",
        "BigBirdPegasusForConditionalGeneration",
        "BigBirdPegasusForQuestionAnswering",
        "BigBirdPegasusForSequenceClassification",
        "BigBirdPegasusModel",
        "BigBirdPegasusPreTrainedModel",
    ]

# 如果是类型检查阶段,导入配置和模型模块的具体内容
if TYPE_CHECKING:
    from .configuration_bigbird_pegasus import (
        BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BigBirdPegasusConfig,
        BigBirdPegasusOnnxConfig,
    )

    # 再次检查 Torch 是否可用,如果不可用则忽略
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 导入模型相关的内容
        from .modeling_bigbird_pegasus import (
            BIGBIRD_PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST,
            BigBirdPegasusForCausalLM,
            BigBirdPegasusForConditionalGeneration,
            BigBirdPegasusForQuestionAnswering,
            BigBirdPegasusForSequenceClassification,
            BigBirdPegasusModel,
            BigBirdPegasusPreTrainedModel,
        )

# 如果不是类型检查阶段,则注册懒加载模块
else:
    import sys

    # 将当前模块设置为懒加载模块
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

.\models\big_bird\configuration_big_bird.py

# coding=utf-8
# Copyright 2021 Google Research and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" BigBird model configuration"""

# 导入所需模块
from collections import OrderedDict
from typing import Mapping

# 从相对路径导入必要的配置和工具类
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging

# 获取全局日志记录器
logger = logging.get_logger(__name__)

# 定义预训练模型配置文件的映射,映射了模型名称到配置文件的 URL
BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "google/bigbird-roberta-base": "https://huggingface.co/google/bigbird-roberta-base/resolve/main/config.json",
    "google/bigbird-roberta-large": "https://huggingface.co/google/bigbird-roberta-large/resolve/main/config.json",
    "google/bigbird-base-trivia-itc": "https://huggingface.co/google/bigbird-base-trivia-itc/resolve/main/config.json",
    # 查看所有 BigBird 模型的列表:https://huggingface.co/models?filter=big_bird
}

# 定义 BigBirdConfig 类,继承自 PretrainedConfig
class BigBirdConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`BigBirdModel`]. It is used to instantiate an
    BigBird model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the BigBird
    [google/bigbird-roberta-base](https://huggingface.co/google/bigbird-roberta-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Example:

    ```
    >>> from transformers import BigBirdConfig, BigBirdModel

    >>> # Initializing a BigBird google/bigbird-roberta-base style configuration
    >>> configuration = BigBirdConfig()

    >>> # Initializing a model (with random weights) from the google/bigbird-roberta-base style configuration
    >>> model = BigBirdModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    # 定义模型类型为 "big_bird"
    model_type = "big_bird"
    # 初始化函数,用于初始化一个 Transformer 模型对象
    def __init__(
        self,
        vocab_size=50358,  # 设置词汇表大小,默认为50358
        hidden_size=768,  # 设置隐藏层大小,默认为768
        num_hidden_layers=12,  # 设置隐藏层数,默认为12
        num_attention_heads=12,  # 设置注意力头数,默认为12
        intermediate_size=3072,  # 设置中间层大小,默认为3072
        hidden_act="gelu_new",  # 设置隐藏层激活函数,默认为"gelu_new"
        hidden_dropout_prob=0.1,  # 设置隐藏层的dropout概率,默认为0.1
        attention_probs_dropout_prob=0.1,  # 设置注意力概率dropout概率,默认为0.1
        max_position_embeddings=4096,  # 设置最大位置嵌入数,默认为4096
        type_vocab_size=2,  # 设置类型词汇表大小,默认为2
        initializer_range=0.02,  # 设置初始化范围,默认为0.02
        layer_norm_eps=1e-12,  # 设置层归一化epsilon,默认为1e-12
        use_cache=True,  # 是否使用缓存,默认为True
        pad_token_id=0,  # 设置填充标记的ID,默认为0
        bos_token_id=1,  # 设置开始标记的ID,默认为1
        eos_token_id=2,  # 设置结束标记的ID,默认为2
        sep_token_id=66,  # 设置分隔标记的ID,默认为66
        attention_type="block_sparse",  # 设置注意力类型,默认为"block_sparse"
        use_bias=True,  # 是否使用偏置,默认为True
        rescale_embeddings=False,  # 是否重新缩放嵌入,默认为False
        block_size=64,  # 设置块大小,默认为64
        num_random_blocks=3,  # 设置随机块数,默认为3
        classifier_dropout=None,  # 分类器的dropout率,默认为None
        **kwargs,  # 其他关键字参数
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            sep_token_id=sep_token_id,
            **kwargs,  # 调用父类的初始化函数,并传递相应的参数
        )

        self.vocab_size = vocab_size  # 初始化模型的词汇表大小
        self.max_position_embeddings = max_position_embeddings  # 初始化最大位置嵌入数
        self.hidden_size = hidden_size  # 初始化隐藏层大小
        self.num_hidden_layers = num_hidden_layers  # 初始化隐藏层数
        self.num_attention_heads = num_attention_heads  # 初始化注意力头数
        self.intermediate_size = intermediate_size  # 初始化中间层大小
        self.hidden_act = hidden_act  # 初始化隐藏层激活函数
        self.hidden_dropout_prob = hidden_dropout_prob  # 初始化隐藏层的dropout概率
        self.attention_probs_dropout_prob = attention_probs_dropout_prob  # 初始化注意力概率dropout概率
        self.initializer_range = initializer_range  # 初始化初始化范围
        self.type_vocab_size = type_vocab_size  # 初始化类型词汇表大小
        self.layer_norm_eps = layer_norm_eps  # 初始化层归一化epsilon
        self.use_cache = use_cache  # 初始化是否使用缓存

        self.rescale_embeddings = rescale_embeddings  # 初始化是否重新缩放嵌入
        self.attention_type = attention_type  # 初始化注意力类型
        self.use_bias = use_bias  # 初始化是否使用偏置
        self.block_size = block_size  # 初始化块大小
        self.num_random_blocks = num_random_blocks  # 初始化随机块数
        self.classifier_dropout = classifier_dropout  # 初始化分类器的dropout率
# 定义一个 BigBirdOnnxConfig 类,继承自 OnnxConfig 类
class BigBirdOnnxConfig(OnnxConfig):
    
    # 定义 inputs 属性,返回一个映射结构,其键为字符串,值为映射到字符串的字典
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # 如果任务类型是多选,则动态轴包含 batch、choice 和 sequence
        if self.task == "multiple-choice":
            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
        else:
            # 否则动态轴只包含 batch 和 sequence
            dynamic_axis = {0: "batch", 1: "sequence"}
        
        # 返回一个有序字典,包含两个键值对
        return OrderedDict(
            [
                ("input_ids", dynamic_axis),      # 键为 "input_ids",值为 dynamic_axis
                ("attention_mask", dynamic_axis), # 键为 "attention_mask",值为 dynamic_axis
            ]
        )

.\models\big_bird\convert_bigbird_original_tf_checkpoint_to_pytorch.py

# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BigBird checkpoint."""

# 引入处理命令行参数的库
import argparse

# 引入 BigBird 相关的模型配置和加载权重的方法
from transformers import BigBirdConfig, BigBirdForPreTraining, BigBirdForQuestionAnswering, load_tf_weights_in_big_bird
from transformers.utils import logging

# 设置日志的输出级别为信息级别
logging.set_verbosity_info()

# 定义函数,用于将 TensorFlow 的 checkpoint 转换为 PyTorch 模型
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, big_bird_config_file, pytorch_dump_path, is_trivia_qa):
    # 从 JSON 文件中读取 BigBird 的配置
    config = BigBirdConfig.from_json_file(big_bird_config_file)
    # 打印正在根据配置构建 PyTorch 模型
    print(f"Building PyTorch model from configuration: {config}")

    # 根据是否是 TriviaQA 模型选择相应的 BigBird 模型
    if is_trivia_qa:
        model = BigBirdForQuestionAnswering(config)
    else:
        model = BigBirdForPreTraining(config)

    # 加载 TensorFlow checkpoint 中的权重到 PyTorch 模型
    load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=is_trivia_qa)

    # 保存 PyTorch 模型
    print(f"Save PyTorch model to {pytorch_dump_path}")
    model.save_pretrained(pytorch_dump_path)


if __name__ == "__main__":
    # 解析命令行参数
    parser = argparse.ArgumentParser()
    # 必须参数:TensorFlow checkpoint 的路径
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    # 必须参数:BigBird 模型的配置文件路径
    parser.add_argument(
        "--big_bird_config_file",
        default=None,
        type=str,
        required=True,
        help=(
            "The config json file corresponding to the pre-trained BERT model. \n"
            "This specifies the model architecture."
        ),
    )
    # 必须参数:输出的 PyTorch 模型路径
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    # 可选参数:是否包含 TriviaQA 头部
    parser.add_argument(
        "--is_trivia_qa", action="store_true", help="Whether to convert a model with a trivia_qa head."
    )
    # 解析参数
    args = parser.parse_args()
    # 调用函数,执行 TensorFlow 到 PyTorch 模型的转换
    convert_tf_checkpoint_to_pytorch(
        args.tf_checkpoint_path, args.big_bird_config_file, args.pytorch_dump_path, args.is_trivia_qa
    )

.\models\big_bird\modeling_big_bird.py

# 导入必要的库和模块
import math  # 导入数学库,用于数学运算
import os  # 导入操作系统库,用于操作文件路径等操作
from dataclasses import dataclass  # 导入dataclass模块,用于创建数据类
from typing import Optional, Tuple, Union  # 导入类型提示模块,用于类型声明

import numpy as np  # 导入NumPy库,用于数值计算
import torch  # 导入PyTorch库,用于构建和训练神经网络模型
import torch.utils.checkpoint  # 导入PyTorch的checkpoint模块,用于中间结果的保存和恢复
from torch import nn  # 导入PyTorch的神经网络模块,用于构建神经网络层
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss  # 导入PyTorch的损失函数

from ...activations import ACT2FN  # 导入激活函数,用于神经网络的非线性变换
from ...modeling_outputs import (  # 导入模型输出类,定义了不同任务的输出格式
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel  # 导入预训练模型基类,用于所有预训练模型的基本功能实现
from ...pytorch_utils import apply_chunking_to_forward  # 导入用于分块处理前向传播的工具函数
from ...utils import (  # 导入工具函数,用于日志记录、返回值替换等辅助功能
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_big_bird import BigBirdConfig  # 导入BigBird模型的配置类,用于配置模型参数

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器

_CHECKPOINT_FOR_DOC = "google/bigbird-roberta-base"  # 预训练模型的检查点地址,用于文档示例
_CONFIG_FOR_DOC = "BigBirdConfig"  # BigBird模型的配置信息,用于文档示例

BIG_BIRD_PRETRAINED_MODEL_ARCHIVE_LIST = [  # BigBird预训练模型的地址列表
    "google/bigbird-roberta-base",
    "google/bigbird-roberta-large",
    "google/bigbird-base-trivia-itc",
    # 查看所有BigBird模型地址:https://huggingface.co/models?filter=big_bird
]

_TRIVIA_QA_MAPPING = {  # TriviaQA数据集的映射关系,将TensorFlow模型权重映射到PyTorch模型
    "big_bird_attention": "attention/self",
    "output_layer_norm": "output/LayerNorm",
    "attention_output": "attention/output/dense",
    "output": "output/dense",
    "self_attention_layer_norm": "attention/output/LayerNorm",
    "intermediate": "intermediate/dense",
    "word_embeddings": "bert/embeddings/word_embeddings",
    "position_embedding": "bert/embeddings/position_embeddings",
    "type_embeddings": "bert/embeddings/token_type_embeddings",
    "embeddings": "bert/embeddings",
    "layer_normalization": "output/LayerNorm",
    "layer_norm": "LayerNorm",
    "trivia_qa_head": "qa_classifier",
    "dense": "intermediate/dense",
    "dense_1": "qa_outputs",
}


def load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=False):
    """Load tf checkpoints in a pytorch model."""
    def load_tf_weights_bert(init_vars, tf_path):
        names = []  # 用于存储变量名的列表
        tf_weights = {}  # 用于存储 TensorFlow 权重的字典

        for name, shape in init_vars:  # 遍历初始化变量的名称和形状
            array = tf.train.load_variable(tf_path, name)  # 加载 TensorFlow 模型中的变量值
            name = name.replace("bert/encoder/LayerNorm", "bert/embeddings/LayerNorm")  # 替换变量名中的特定字符串
            logger.info(f"Loading TF weight {name} with shape {shape}")  # 记录日志,显示加载的 TensorFlow 权重的名称和形状
            names.append(name)  # 将变量名添加到列表中
            tf_weights[name] = array  # 将变量名和对应的数组存储到字典中

        return names, tf_weights  # 返回变量名列表和 TensorFlow 权重字典

    def load_tf_weights_trivia_qa(init_vars):
        names = []  # 用于存储变量名的列表
        tf_weights = {}  # 用于存储 TensorFlow 权重的字典

        for i, var in enumerate(init_vars):  # 遍历初始化变量列表
            name_items = var.name.split("/")  # 使用斜杠分割变量名

            if "transformer_scaffold" in name_items[0]:  # 如果变量名中包含特定字符串
                layer_name_items = name_items[0].split("_")  # 使用下划线分割层名
                if len(layer_name_items) < 3:
                    layer_name_items += [0]  # 如果层名项少于3个,补充一个零

                name_items[0] = f"bert/encoder/layer_{layer_name_items[2]}"  # 格式化为特定的层名格式

            name = "/".join([_TRIVIA_QA_MAPPING[x] if x in _TRIVIA_QA_MAPPING else x for x in name_items])[:-2]
            # 根据映射替换变量名中的部分子串,并删除末尾的":0"

            if "self/attention/output" in name:  # 如果变量名中包含特定子串
                name = name.replace("self/attention/output", "output")  # 替换为指定的新子串

            if i >= len(init_vars) - 2:  # 如果索引超出初始化变量列表长度减2
                name = name.replace("intermediate", "output")  # 替换变量名中的特定子串为另一个

            logger.info(f"Loading TF weight {name} with shape {var.shape}")  # 记录日志,显示加载的 TensorFlow 权重的名称和形状
            array = var.value().numpy()  # 将 TensorFlow 变量的值转换为 NumPy 数组
            names.append(name)  # 将变量名添加到列表中
            tf_weights[name] = array  # 将变量名和对应的数组存储到字典中

        return names, tf_weights  # 返回变量名列表和 TensorFlow 权重字典

    try:
        import re  # 导入正则表达式模块

        import numpy as np  # 导入 NumPy 库
        import tensorflow as tf  # 导入 TensorFlow 库
    except ImportError:  # 如果导入错误
        logger.error(  # 记录错误日志
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise  # 抛出导入错误异常

    tf_path = os.path.abspath(tf_checkpoint_path)  # 获取 TensorFlow 检查点路径的绝对路径
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")  # 记录信息日志,显示正在转换的 TensorFlow 检查点路径

    # Load weights from TF model
    init_vars = tf.saved_model.load(tf_path).variables if is_trivia_qa else tf.train.list_variables(tf_path)
    # 根据条件加载 TensorFlow 模型的变量列表或者变量名列表

    if len(init_vars) <= 0:  # 如果初始化变量列表长度小于等于0
        raise ValueError("Loaded trained variables cannot be empty.")  # 抛出数值错误异常,提示加载的训练变量不能为空

    pt_names = list(model.state_dict().keys())  # 获取 PyTorch 模型状态字典的键列表

    if is_trivia_qa:  # 如果是 TriviaQA 数据集
        names, tf_weights = load_tf_weights_trivia_qa(init_vars)  # 调用加载 TriviaQA 数据集权重的函数
    else:  # 否则
        names, tf_weights = load_tf_weights_bert(init_vars, tf_path)  # 调用加载 BERT 模型权重的函数

    logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.")
    # 记录信息日志,显示未复制到 PyTorch 模型的权重名称列表

    logger.info(f"Weights not initialized in PyTorch model: {', '.join(pt_names)}.")
    # 记录信息日志,显示未在 PyTorch 模型中初始化的权重名称列表

    return model  # 返回模型
class BigBirdEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
    def __init__(self, config):
        super().__init__()
        # 创建一个词嵌入层,用于将词的索引映射为词嵌入向量
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        # 创建一个位置嵌入层,用于将位置索引映射为位置嵌入向量
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        # 创建一个token类型嵌入层,用于将token类型索引映射为token类型嵌入向量
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        # 创建一个LayerNorm层,用于归一化输入向量
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # 创建一个Dropout层,用于在训练过程中随机丢弃部分输入向量,防止过拟合
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        # 根据配置创建position_embedding_type,用于指定位置嵌入类型
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        # 注册一个持久化的position_ids缓冲区,包含从0到config.max_position_embeddings-1的位置索引
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        # 注册一个持久化的token_type_ids缓冲区,其形状与position_ids相同,元素为零,用于token类型嵌入
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )
        # End copy

        # 是否对嵌入向量进行重新缩放
        self.rescale_embeddings = config.rescale_embeddings
        # 隐藏层的大小
        self.hidden_size = config.hidden_size

    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
        ):
            # 如果传入的 input_ids 不为空,获取其尺寸作为 input_shape
            if input_ids is not None:
                input_shape = input_ids.size()
            else:
                # 否则,获取 inputs_embeds 的尺寸除最后一维之外的部分作为 input_shape
                input_shape = inputs_embeds.size()[:-1]

            # 获取序列的长度
            seq_length = input_shape[1]

            # 如果没有提供 position_ids,则使用预定义的位置 ID,从 self.position_ids 中截取相应长度的部分
            if position_ids is None:
                position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

            # 设置 token_type_ids 为在构造函数中注册的缓冲区,通常情况下是全零,这对于在不传递 token_type_ids 的情况下追踪模型很有帮助,解决了问题 #5664
            if token_type_ids is None:
                if hasattr(self, "token_type_ids"):
                    # 使用已注册的 token_type_ids 缓冲区的部分来填充 token_type_ids,扩展以匹配 input_shape 的第一个维度和 seq_length
                    buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                    buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                    token_type_ids = buffered_token_type_ids_expanded
                else:
                    # 如果未注册 token_type_ids,则创建全零的 tensor 作为 token_type_ids,类型为 long,放置在 self.position_ids 设备上
                    token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

            # 如果 inputs_embeds 为空,则使用 word_embeddings 层来获取对应的 embeddings
            if inputs_embeds is None:
                inputs_embeds = self.word_embeddings(input_ids)

            # 如果需要对 embeddings 进行重新缩放,则乘以 sqrt(hidden_size)
            if self.rescale_embeddings:
                inputs_embeds = inputs_embeds * (self.hidden_size**0.5)

            # 根据 token_type_ids 获取 token_type_embeddings
            token_type_embeddings = self.token_type_embeddings(token_type_ids)

            # 将 inputs_embeds 和 token_type_embeddings 相加得到 embeddings
            embeddings = inputs_embeds + token_type_embeddings

            # 根据 position_ids 获取 position_embeddings,并将其加到 embeddings 中
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings

            # 对 embeddings 应用 dropout
            embeddings = self.dropout(embeddings)
            
            # 对 embeddings 应用 LayerNorm
            embeddings = self.LayerNorm(embeddings)
            
            # 返回最终的 embeddings
            return embeddings
class BigBirdSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 检查隐藏层大小是否能被注意力头数整除,并且如果没有嵌入大小的属性,抛出错误
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        # 设置注意力头数和每个头的大小
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # 创建查询、键、值的线性层,用于注意力机制
        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias)

        # 设置 dropout 层
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        # 标记是否为解码器
        self.is_decoder = config.is_decoder

    # 将输入转换为注意力分数的形状
    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    # 前向传播函数
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,



class BigBirdBlockSparseAttention(nn.Module):
    def __init__(self, config, seed=None):
        super().__init__()

        # 设置最大序列长度和随机数种子
        self.max_seqlen = config.max_position_embeddings
        self.seed = seed

        # 检查隐藏层大小是否能被注意力头数整除,如果不能抛出错误
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        # 设置注意力头数、随机块数和块大小
        self.num_attention_heads = config.num_attention_heads
        self.num_random_blocks = config.num_random_blocks
        self.block_size = config.block_size

        # 设置注意力头大小和总大小
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # 创建查询、键、值的线性层,用于注意力机制
        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias)

    # 将输入转换为注意力分数的形状
    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    # 前向传播函数
    def forward(
        self,
        hidden_states,
        band_mask=None,
        from_mask=None,
        to_mask=None,
        from_blocked_mask=None,
        to_blocked_mask=None,
        output_attentions=None,
    ):
        # 目前此类无法在解码器中使用。

        # 获取隐藏状态的批量大小、序列长度和最后一个维度的信息
        batch_size, seqlen, _ = hidden_states.size()
        to_seq_length = from_seq_length = seqlen
        from_block_size = to_block_size = self.block_size

        # 检查查询侧序列长度是否是块大小的倍数
        if from_seq_length % from_block_size != 0:
            raise ValueError("Query sided sequence length must be multiple of block size")

        # 检查键/值侧序列长度是否是块大小的倍数
        if to_seq_length % to_block_size != 0:
            raise ValueError("Key/Value sided sequence length must be multiple of block size")

        # 对查询、键和值进行变换以适应注意力矩阵的计算
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # 调用大鸟模型的稀疏块注意力计算函数
        context_layer, attention_probs = self.bigbird_block_sparse_attention(
            query_layer,
            key_layer,
            value_layer,
            band_mask,
            from_mask,
            to_mask,
            from_blocked_mask,
            to_blocked_mask,
            self.num_attention_heads,
            self.num_random_blocks,
            self.attention_head_size,
            from_block_size,
            to_block_size,
            batch_size,
            from_seq_length,
            to_seq_length,
            seed=self.seed,
            plan_from_length=None,
            plan_num_rand_blocks=None,
            output_attentions=output_attentions,
        )

        # 将上下文层展开并重塑形状以匹配输入张量的预期形状
        context_layer = context_layer.contiguous().view(batch_size, from_seq_length, -1)

        # 根据需要返回注意力权重
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        return outputs

    @staticmethod
    def torch_bmm_nd(inp_1, inp_2, ndim=None):
        """快速的多维矩阵乘法"""
        # 使用 torch.bmm 替代 torch.einsum 进行更快的矩阵乘法计算 ("bhqk,bhkd->bhqd")
        return torch.bmm(inp_1.reshape((-1,) + inp_1.shape[-2:]), inp_2.reshape((-1,) + inp_2.shape[-2:])).view(
            inp_1.shape[: ndim - 2] + (inp_1.shape[ndim - 2], inp_2.shape[ndim - 1])
        )

    @staticmethod
    def torch_bmm_nd_transpose(inp_1, inp_2, ndim=None):
        """带转置的快速多维矩阵乘法"""
        # 使用 torch.bmm 替代 torch.einsum 进行更快的矩阵乘法计算 ("bhqd,bhkd->bhqk")
        return torch.bmm(
            inp_1.reshape((-1,) + inp_1.shape[-2:]), inp_2.reshape((-1,) + inp_2.shape[-2:]).transpose(1, 2)
        ).view(inp_1.shape[: ndim - 2] + (inp_1.shape[ndim - 2], inp_2.shape[ndim - 2]))

    def bigbird_block_sparse_attention(
        self,
        query_layer,
        key_layer,
        value_layer,
        band_mask,
        from_mask,
        to_mask,
        from_blocked_mask,
        to_blocked_mask,
        n_heads,
        n_rand_blocks,
        attention_head_size,
        from_block_size,
        to_block_size,
        batch_size,
        from_seq_len,
        to_seq_len,
        seed,
        plan_from_length,
        plan_num_rand_blocks,
        output_attentions,
    ):
        # 大鸟模型的稀疏块注意力计算函数,详细实现略过
        pass
    def torch_gather_b2(params, indices):
        # this operation is equivalent to tf.gather when batch_dims=2

        # 检查 params 和 indices 的前两个维度是否相同
        if params.shape[:2] != indices.shape[:2]:
            raise ValueError(
                "Make sure that the first two dimensions of params and indices are identical, "
                f"but they are params: {params.shape[:2]} vs. indices: {indices.shape[:2]}"
            )

        # 计算需要收集的索引数量
        num_indices_to_gather = indices.shape[-2] * indices.shape[-1]
        # 获取 params 中可选择的索引数量
        num_indices_to_pick_from = params.shape[2]

        # 创建偏移量,以便在展平的 indices 上进行选择
        shift = torch.arange(indices.shape[0] * indices.shape[1] * num_indices_to_gather, device=indices.device)
        indices_shift = torch.div(shift, num_indices_to_gather, rounding_mode="floor") * num_indices_to_pick_from

        # 将 indices 展平并添加偏移量,以便在 params 中选择对应数据
        flattened_indices = indices.view(-1) + indices_shift
        flattened_params = params.reshape(-1, params.shape[-2], params.shape[-1])

        # 使用展平后的 indices 在 params 中进行选择
        out_flattened = flattened_params.index_select(0, flattened_indices)

        # 将结果重新形状为原始形状,包括收集的索引数量维度
        out = out_flattened.reshape(params.shape[:2] + (num_indices_to_gather,) + params.shape[3:])
        return out

    @staticmethod
    def _create_rand_mask_from_inputs(
        from_blocked_mask,
        to_blocked_mask,
        rand_attn,
        num_attention_heads,
        num_rand_blocks,
        batch_size,
        from_seq_length,
        from_block_size,
    ):
        """
        Create 3D attention mask from a 2D tensor mask.

        Args:
            from_blocked_mask: 2D Tensor of shape [batch_size, from_seq_length//from_block_size, from_block_size].
            to_blocked_mask: int32 Tensor of shape [batch_size, to_seq_length//to_block_size, to_block_size].
            rand_attn: [batch_size, num_attention_heads, from_seq_length//from_block_size-2, num_rand_blocks]
            num_attention_heads: int. Number of attention heads.
            num_rand_blocks: int. Number of random chunks per row.
            batch_size: int. Batch size for computation.
            from_seq_length: int. length of from sequence.
            from_block_size: int. size of block in from sequence.

        Returns:
            float Tensor of shape [batch_size, num_attention_heads, from_seq_length//from_block_size-2,
            from_block_size, num_rand_blocks*to_block_size].
        """
        # 计算窗口的数量
        num_windows = from_seq_length // from_block_size - 2
        # 从 to_blocked_mask 和 rand_attn 创建随机掩码
        rand_mask = torch.stack([p1[i1.flatten()] for p1, i1 in zip(to_blocked_mask, rand_attn)])
        # 将随机掩码重塑为所需的形状
        rand_mask = rand_mask.view(batch_size, num_attention_heads, num_windows, num_rand_blocks * from_block_size)
        # 使用 einsum 创建最终的随机掩码
        rand_mask = torch.einsum("blq,bhlk->bhlqk", from_blocked_mask[:, 1:-1], rand_mask)
        return rand_mask
    # 定义一个函数,用于生成随机注意力的分布计划
    def _get_rand_attn_plan(from_seq_length, from_block_size, num_rand_blocks):
        """
        Gives the plan of where to put random attention.

        Args:
            from_seq_length: int. length of from sequence. 输入序列的长度
            from_block_size: int. size of block in from sequence. 输入序列中每个块的大小
            num_rand_blocks: int. Number of random chunks per row. 每行的随机块数

        Returns:
            plan_from_length: ending location of from block plan_num_rand_blocks: number of random ending location for
            each block 返回计划的输入序列块的结束位置和每个块的随机结束位置数量
        """

        # 初始化存储计划信息的列表
        plan_from_length = []
        plan_num_rand_blocks = []

        # 根据条件生成计划信息
        if (2 * num_rand_blocks + 5) < (from_seq_length // from_block_size):
            # 如果满足条件,计划从块的长度为 (2 * num_rand_blocks + 5) * from_block_size
            plan_from_length.append(int((2 * num_rand_blocks + 5) * from_block_size))
            plan_num_rand_blocks.append(num_rand_blocks)
            # 输入序列长度为 from_seq_length
            plan_from_length.append(from_seq_length)
            # 随机块数为 0
            plan_num_rand_blocks.append(0)
        elif (num_rand_blocks + 5) < (from_seq_length // from_block_size):
            # 否则,如果满足条件,计划从块的长度为 (num_rand_blocks + 5) * from_block_size
            plan_from_length.append(int((num_rand_blocks + 5) * from_block_size))
            # 前一半块的随机块数为 num_rand_blocks // 2
            plan_num_rand_blocks.append(num_rand_blocks // 2)
            # 输入序列长度为 from_seq_length
            plan_from_length.append(from_seq_length)
            # 后一半块的随机块数为 num_rand_blocks - (num_rand_blocks // 2)
            plan_num_rand_blocks.append(num_rand_blocks - (num_rand_blocks // 2))
        else:
            # 否则,输入序列长度为 from_seq_length
            plan_from_length.append(from_seq_length)
            # 随机块数为 num_rand_blocks
            plan_num_rand_blocks.append(num_rand_blocks)

        # 返回生成的计划信息
        return plan_from_length, plan_num_rand_blocks

    # 定义一个函数,用于生成 BigBird 模型的随机掩码
    def _bigbird_block_rand_mask(
        self, from_seq_length, to_seq_length, from_block_size, to_block_size, num_rand_blocks, last_idx=-1
        ):
    # 创建随机注意力的邻接列表。
    
    Args:
        from_seq_length: int. 来源序列的长度。
        to_seq_length: int. 目标序列的长度。
        from_block_size: int. 来源序列中的块大小。
        to_block_size: int. 目标序列中的块大小。
        num_rand_blocks: int. 每行随机块的数量。
        last_idx: int. 如果为-1,则从目标序列中任意选择 num_rand_blocks 个块;
                  如果为正数,则只选择到 last_idx 为止的 num_rand_blocks 个块。
    
    Returns:
        邻接列表,大小为 from_seq_length//from_block_size-2 行,num_rand_blocks 列。
        表示每个源序列块与随机选择的目标序列块之间的注意力关系。
    def _bigbird_block_rand_mask_with_head(
        self,
        from_seq_length,
        to_seq_length,
        from_block_size,
        to_block_size,
        num_heads,
        plan_from_length,
        plan_num_rand_blocks,
        window_block_left=1,
        window_block_right=1,
        global_block_top=1,
        global_block_bottom=1,
        global_block_left=1,
        global_block_right=1,
    ):
        """
        Generates a random mask for BigBird attention with head information.

        Args:
            from_seq_length: int. Length of the source sequence.
            to_seq_length: int. Length of the target sequence.
            from_block_size: int. Block size of the source sequence.
            to_block_size: int. Block size of the target sequence.
            num_heads: int. Number of attention heads.
            plan_from_length: int. Planned length of the source sequence.
            plan_num_rand_blocks: int. Planned number of random blocks for attention.
            window_block_left: int. Number of blocks in the window to the left of a block.
            window_block_right: int. Number of blocks in the window to the right of a block.
            global_block_top: int. Number of global blocks used at the top.
            global_block_bottom: int. Number of global blocks used at the bottom.
            global_block_left: int. Number of global blocks used to the left.
            global_block_right: int. Number of global blocks used to the right.

        Returns:
            A randomly masked attention matrix with head information.
        """
        # Implementation details for generating a random mask with BigBird constraints
        pass

    @staticmethod
    def _get_single_block_row_attention(
        block_id,
        to_start_block_id,
        to_end_block_id,
        num_rand_blocks,
        window_block_left=1,
        window_block_right=1,
        global_block_left=1,
        global_block_right=1,
    ):
        """
        For a single row block, get random row attention.

        Args:
            block_id: int. Block ID of the row.
            to_start_block_id: int. Start ID of the attention column.
            to_end_block_id: int. End ID of the attention column.
            num_rand_blocks: int. Number of random blocks to select.
            window_block_left: int. Number of blocks in the window to the left of a block.
            window_block_right: int. Number of blocks in the window to the right of a block.
            global_block_left: int. Number of global blocks used to the left.
            global_block_right: int. Number of global blocks used to the right.

        Returns:
            Array containing the random attention vector of size num_rand_blocks.
        """
        # List of to_blocks from which to choose random attention
        to_block_list = np.arange(to_start_block_id, to_end_block_id, dtype=np.int32)
        # Permute the blocks
        perm_block = np.random.permutation(to_block_list)

        # Illegal blocks for the current block_id, using window
        illegal_blocks = list(range(block_id - window_block_left, block_id + window_block_right + 1))

        # Add blocks at the start and at the end
        illegal_blocks.extend(list(range(global_block_left)))
        illegal_blocks.extend(list(range(to_end_block_id - global_block_right, to_end_block_id)))

        # The second from_block cannot choose random attention on second last to_block
        if block_id == 1:
            illegal_blocks.append(to_end_block_id - 2)

        # The second last from_block cannot choose random attention on second to_block
        if block_id == to_end_block_id - 2:
            illegal_blocks.append(1)

        selected_random_blocks = []

        for i in range(to_end_block_id - to_start_block_id):
            if perm_block[i] not in illegal_blocks:
                selected_random_blocks.append(perm_block[i])
            if len(selected_random_blocks) == num_rand_blocks:
                break
        return np.array(selected_random_blocks, dtype=np.int32)
# 从transformers.models.bert.modeling_bert.BertSelfOutput中复制代码,并将Bert改为BigBird
class BigBirdSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 创建一个全连接层,将输入特征大小映射为输出特征大小
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # 创建一个LayerNorm层,用于归一化输入数据
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # 创建一个Dropout层,用于随机失活神经元,防止过拟合
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # 将输入的隐藏状态通过全连接层映射到新的空间
        hidden_states = self.dense(hidden_states)
        # 对映射后的结果进行随机失活
        hidden_states = self.dropout(hidden_states)
        # 将映射结果与输入张量进行残差连接,并通过LayerNorm进行归一化处理
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        # 返回归一化后的结果张量
        return hidden_states


class BigBirdAttention(nn.Module):
    def __init__(self, config, seed=None):
        super().__init__()
        # 初始化注意力类型
        self.attention_type = config.attention_type
        # 存储配置信息
        self.config = config
        # 存储随机种子信息
        self.seed = seed

        # 根据配置选择不同的注意力类型
        if self.config.attention_type == "original_full":
            # 如果是原始全注意力类型,则使用BigBirdSelfAttention
            self.self = BigBirdSelfAttention(config)
        elif self.config.attention_type == "block_sparse":
            # 如果是块稀疏注意力类型,则使用BigBirdBlockSparseAttention
            self.self = BigBirdBlockSparseAttention(config, seed)
        else:
            # 如果配置的注意力类型不在支持范围内,则抛出错误
            raise ValueError(
                f"attention_type can either be original_full or block_sparse, but is {self.config.attention_type}"
            )

        # 创建自定义的输出层
        self.output = BigBirdSelfOutput(config)

    def set_attention_type(self, value: str):
        # 如果设置的注意力类型不在支持的范围内,则抛出错误
        if value not in ["original_full", "block_sparse"]:
            raise ValueError(
                f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}"
            )
        # 如果设置的注意力类型与当前类型一致,则直接返回
        if value == self.attention_type:
            return

        # 更新当前的注意力类型
        self.attention_type = value
        # 根据新的注意力类型重新设置self.self
        if value == "original_full":
            # 复制所有权重到新的全注意力类
            attn_weights = BigBirdSelfAttention(self.config)
        else:
            # 复制所有权重到新的稀疏注意力类
            attn_weights = BigBirdBlockSparseAttention(self.config, self.seed)

        # 将当前的查询、键、值权重复制到新的注意力对象中
        attn_weights.query = self.self.query
        attn_weights.value = self.self.value
        attn_weights.key = self.self.key
        # 更新self.self为新的注意力对象
        self.self = attn_weights
        # 更新注意力类型
        self.attention_type = value
        # 如果不在训练状态下,评估更新后的self.self
        if not self.training:
            self.self.eval()

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        # 块稀疏配置
        band_mask=None,
        from_mask=None,
        to_mask=None,
        from_blocked_mask=None,
        to_blocked_mask=None,
        # fp16 compatibility
        # 如果使用了 fp16,需要确保 band_mask、from_mask、to_mask 的数据类型与 hidden_states 相匹配
        if band_mask is not None:
            band_mask = band_mask.to(hidden_states.dtype)
        if from_mask is not None:
            from_mask = from_mask.to(hidden_states.dtype)
        if to_mask is not None:
            to_mask = to_mask.to(hidden_states.dtype)

        # 根据不同的 attention_type 选择不同的 self-attention 计算方式
        if self.attention_type == "original_full":
            # 使用全连接注意力机制进行自注意力计算
            self_outputs = self.self(
                hidden_states,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
                output_attentions,
            )
        else:
            # 如果是 BigBird 模型,且作为解码器使用时抛出错误
            if encoder_hidden_states is not None:
                raise ValueError("BigBird cannot be used as a decoder when config.attention_type != 'original_full'")
            # 使用 BigBird 特有的部分连接注意力机制进行自注意力计算
            self_outputs = self.self(
                hidden_states, band_mask, from_mask, to_mask, from_blocked_mask, to_blocked_mask, output_attentions
            )

        # 将 self-attention 的输出作为输入,经过输出层处理得到最终的注意力输出
        attention_output = self.output(self_outputs[0], hidden_states)

        # 如果需要输出注意力权重,将它们添加到输出元组中
        outputs = (attention_output,) + self_outputs[1:]  # 如果输出了注意力权重,则添加到输出中
        return outputs
# 从 transformers.models.bert.modeling_bert.BertIntermediate 复制并修改为 BigBirdIntermediate 类
class BigBirdIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 创建一个全连接层,将输入的隐藏状态大小调整为中间状态大小
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # 根据配置选择隐藏层激活函数,如果是字符串则从预定义映射中选择对应函数,否则直接使用配置中的函数
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 将隐藏状态输入全连接层进行线性变换
        hidden_states = self.dense(hidden_states)
        # 应用中间层激活函数到变换后的隐藏状态
        hidden_states = self.intermediate_act_fn(hidden_states)
        # 返回经过线性变换和激活函数后的隐藏状态
        return hidden_states


# 从 transformers.models.bert.modeling_bert.BertOutput 复制并修改为 BigBirdOutput 类
class BigBirdOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 创建一个全连接层,将中间状态大小调整为隐藏状态大小
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        # 创建一个 LayerNorm 层,对隐藏状态进行归一化
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # 创建一个 Dropout 层,用于随机置零隐藏状态中的部分元素
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # 将隐藏状态输入全连接层进行线性变换
        hidden_states = self.dense(hidden_states)
        # 对变换后的隐藏状态应用 Dropout 层
        hidden_states = self.dropout(hidden_states)
        # 对加和后的结果进行 LayerNorm 归一化处理
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        # 返回经过线性变换、Dropout 和 LayerNorm 后的隐藏状态
        return hidden_states


# BigBirdLayer 类,定义 BigBird 模型的一个层
class BigBirdLayer(nn.Module):
    def __init__(self, config, seed=None):
        super().__init__()
        self.config = config
        self.attention_type = config.attention_type
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        # 创建 BigBirdAttention 层,用于处理注意力机制
        self.attention = BigBirdAttention(config, seed=seed)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        # 如果设置了交叉注意力,确保模型是解码器模型,否则抛出错误
        if self.add_cross_attention:
            if not self.is_decoder:
                raise TypeError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = BigBirdAttention(config)
        # 创建 BigBirdIntermediate 层,用于处理中间层操作
        self.intermediate = BigBirdIntermediate(config)
        # 创建 BigBirdOutput 层,用于处理输出层操作
        self.output = BigBirdOutput(config)

    def set_attention_type(self, value: str):
        # 如果给定的注意力类型不是 'original_full' 或 'block_sparse',抛出 ValueError
        if value not in ["original_full", "block_sparse"]:
            raise ValueError(
                f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}"
            )
        # 如果当前注意力类型已经正确设置,则直接返回
        if value == self.attention_type:
            return
        # 否则更新注意力类型,并将新类型应用到注意力层和交叉注意力层(如果存在)
        self.attention_type = value
        self.attention.set_attention_type(value)

        if self.add_cross_attention:
            self.crossattention.set_attention_type(value)
    # 定义神经网络的前向传播函数,用于推断或训练过程中的前向计算
    def forward(
        self,
        hidden_states,                    # 输入的隐藏状态张量,通常是模型的输出或前一层的输出
        attention_mask=None,              # 注意力掩码,用于指定哪些位置需要屏蔽,通常用于处理变长序列
        head_mask=None,                   # 头部掩码,用于指定哪些注意力头部需要屏蔽
        encoder_hidden_states=None,       # 编码器的隐藏状态,用于跨层注意力等任务
        encoder_attention_mask=None,      # 编码器的注意力掩码,指定哪些编码器位置需要屏蔽
        band_mask=None,                   # 带状掩码,用于指定注意力矩阵中的带状结构
        from_mask=None,                   # 起始位置掩码,指定从哪些位置开始注意力计算
        to_mask=None,                     # 终止位置掩码,指定到哪些位置结束注意力计算
        blocked_encoder_mask=None,        # 阻塞编码器掩码,用于指定哪些编码器隐藏状态应被屏蔽
        past_key_value=None,              # 过去的键值对,用于缓存前向传播中的注意力权重等信息
        output_attentions=False,          # 是否输出注意力权重信息,默认为不输出

        # 定义神经网络的前向传播函数,用于推断或训练过程中的前向计算
        def forward(
            self,
            hidden_states,                    # 输入的隐藏状态张量,通常是模型的输出或前一层的输出
            attention_mask=None,              # 注意力掩码,用于指定哪些位置需要屏蔽,通常用于处理变长序列
            head_mask=None,                   # 头部掩码,用于指定哪些注意力头部需要屏蔽
            encoder_hidden_states=None,       # 编码器的隐藏状态,用于跨层注意力等任务
            encoder_attention_mask=None,      # 编码器的注意力掩码,指定哪些编码器位置需要屏蔽
            band_mask=None,                   # 带状掩码,用于指定注意力矩阵中的带状结构
            from_mask=None,                   # 起始位置掩码,指定从哪些位置开始注意力计算
            to_mask=None,                     # 终止位置掩码,指定到哪些位置结束注意力计算
            blocked_encoder_mask=None,        # 阻塞编码器掩码,用于指定哪些编码器隐藏状态应被屏蔽
            past_key_value=None,              # 过去的键值对,用于缓存前向传播中的注意力权重等信息
            output_attentions=False,          # 是否输出注意力权重信息,默认为不输出
    ):
        # 如果过去的键/值对存在,则只保留自注意力的缓存的前两个位置(1和2)
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # 执行自注意力计算
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_value=self_attn_past_key_value,
            output_attentions=output_attentions,
            band_mask=band_mask,
            from_mask=from_mask,
            to_mask=to_mask,
            from_blocked_mask=blocked_encoder_mask,
            to_blocked_mask=blocked_encoder_mask,
        )
        # 获取自注意力计算的输出
        attention_output = self_attention_outputs[0]

        # 如果模型是解码器,最后一个输出是自注意力缓存的元组
        if self.is_decoder:
            # 排除最后一个元素,它是自注意力缓存的元组
            outputs = self_attention_outputs[1:-1]
            # 获取自注意力计算的当前键/值对
            present_key_value = self_attention_outputs[-1]
        else:
            # 排除第一个元素,因为我们输出注意力权重时需要添加自注意力
            outputs = self_attention_outputs[1:]

        cross_attn_present_key_value = None
        # 如果是解码器并且存在编码器的隐藏状态
        if self.is_decoder and encoder_hidden_states is not None:
            # 如果模型没有交叉注意力层,则引发错误
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with                    "
                    " cross-attention layers by setting `config.add_cross_attention=True`"
                )

            # 如果过去的键/值对存在,则只保留交叉注意力缓存的后两个位置(3和4)
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            # 执行交叉注意力计算
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            # 获取交叉注意力计算的输出
            attention_output = cross_attention_outputs[0]
            # 添加交叉注意力计算的输出(排除最后一个元素,它是交叉注意力缓存的元组)
            outputs = outputs + cross_attention_outputs[1:-1]

            # 将交叉注意力缓存添加到当前键/值对的第三和第四位置
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        # 将注意力输出应用于前向网络的分块处理
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )

        # 组装最终的输出元组
        outputs = (layer_output,) + outputs

        # 如果是解码器,将注意力的键/值对作为最后一个输出返回
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs
    # 定义一个方法,用于执行神经网络的前向传播,处理给定的注意力输出
    def feed_forward_chunk(self, attention_output):
        # 调用 self.intermediate 方法,对注意力输出进行中间处理
        intermediate_output = self.intermediate(attention_output)
        # 调用 self.output 方法,将中间处理后的结果和注意力输出作为参数,生成最终的层输出
        layer_output = self.output(intermediate_output, attention_output)
        # 返回最终的层输出作为结果
        return layer_output
# 定义 BigBirdEncoder 类,继承自 nn.Module
class BigBirdEncoder(nn.Module):
    # 初始化方法,接受 config 参数
    def __init__(self, config):
        super().__init__()
        self.config = config  # 将传入的 config 参数保存到实例变量中
        self.attention_type = config.attention_type  # 从 config 中获取 attention_type 参数

        # 创建一个包含多个 BigBirdLayer 实例的列表,每个实例都使用不同的 seed
        self.layer = nn.ModuleList(
            [BigBirdLayer(config, seed=layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

        self.gradient_checkpointing = False  # 初始化 gradient_checkpointing 标志为 False

    # 设置 attention_type 的方法,接受一个字符串参数 value
    def set_attention_type(self, value: str):
        # 如果 value 不是 "original_full" 或 "block_sparse",抛出 ValueError 异常
        if value not in ["original_full", "block_sparse"]:
            raise ValueError(
                f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}"
            )
        # 如果当前 attention_type 已经是要设置的值,则直接返回,不进行更改
        if value == self.attention_type:
            return
        # 更新 attention_type 为新的值
        self.attention_type = value
        # 遍历所有层并设置它们的 attention_type
        for layer in self.layer:
            layer.set_attention_type(value)

    # 前向传播方法定义
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        band_mask=None,
        from_mask=None,
        to_mask=None,
        blocked_encoder_mask=None,
        return_dict=True,
    ):
        # 这里是 BigBirdEncoder 的前向传播逻辑,具体实现在这里面
        pass  # 在这里应该填写实际的前向传播逻辑,暂时为空



# 定义 BigBirdPredictionHeadTransform 类,继承自 nn.Module
class BigBirdPredictionHeadTransform(nn.Module):
    # 初始化方法,接受 config 参数
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)  # 创建一个线性层
        # 根据 config 中的 hidden_act 字符串或函数设置 transform_act_fn
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # 创建 LayerNorm 层

    # 前向传播方法定义,接受输入 hidden_states,并返回输出的 torch.Tensor
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)  # 将输入通过线性层 dense
        hidden_states = self.transform_act_fn(hidden_states)  # 经过激活函数变换
        hidden_states = self.LayerNorm(hidden_states)  # 应用 LayerNorm
        return hidden_states  # 返回变换后的 hidden_states



# 定义 BigBirdLMPredictionHead 类,继承自 nn.Module
class BigBirdLMPredictionHead(nn.Module):
    # 初始化方法,接受 config 参数
    def __init__(self, config):
        super().__init__()
        self.transform = BigBirdPredictionHeadTransform(config)  # 创建 BigBirdPredictionHeadTransform 实例
        # 输出层是一个线性层,将隐藏状态映射到词汇表大小的向量,没有偏置
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))  # 创建偏置参数
        self.decoder.bias = self.bias  # 将偏置参数与 decoder 层关联,以便与 resize_token_embeddings 正确调整大小

    # 前向传播方法定义,接受输入 hidden_states,并返回输出的 torch.Tensor
    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)  # 将输入经过 transform 处理
        hidden_states = self.decoder(hidden_states)  # 使用 decoder 进行线性变换
        return hidden_states  # 返回变换后的 hidden_states
# 从 transformers.models.bert.modeling_bert.BertOnlyMLMHead 复制并将 Bert 改为 BigBird
class BigBirdOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 初始化 MLM 头部预测层
        self.predictions = BigBirdLMPredictionHead(config)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        # 调用预测层进行预测
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


# 从 transformers.models.bert.modeling_bert.BertOnlyNSPHead 复制并将 Bert 改为 BigBird
class BigBirdOnlyNSPHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 初始化 NSP 头部的线性层
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        # 计算序列关系分数
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score


# 从 transformers.models.bert.modeling_bert.BertPreTrainingHeads 复制并将 Bert 改为 BigBird
class BigBirdPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 初始化 MLM 预测头部和 NSP 头部
        self.predictions = BigBirdLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        # 分别调用预测层和序列关系层进行计算
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


class BigBirdPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BigBirdConfig
    load_tf_weights = load_tf_weights_in_big_bird
    base_model_prefix = "bert"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # 使用正态分布初始化线性层的权重
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                # 如果存在偏置,则将偏置初始化为零
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            # 使用正态分布初始化嵌入层的权重
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # 如果有填充索引,则将填充索引处的权重初始化为零
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # 将 LayerNorm 层的偏置初始化为零,权重初始化为 1.0
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


BIG_BIRD_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.
"""
    Parameters:
        config ([`BigBirdConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

BIG_BIRD_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@dataclass
class BigBirdForPreTrainingOutput(ModelOutput):
    """
    Output type of [`BigBirdForPreTraining`].
    
    This class defines the output structure for the BigBird model during pre-training.
    """
    Args:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            总损失,由掩码语言建模损失和下一个序列预测(分类)损失之和组成。
        prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            语言建模头部的预测分数(SoftMax之前的每个词汇标记的分数)。
        seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
            下一个序列预测(分类)头部的预测分数(SoftMax之前的True/False延续的分数)。
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            模型在每层输出的隐藏状态的元组,包括初始嵌入输出。
            每个元素是 `torch.FloatTensor`,形状为 `(batch_size, sequence_length, hidden_size)`。
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            注意力权重的元组,用于计算自注意力头部中的加权平均值。
            每个元素是 `torch.FloatTensor`,形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_logits: torch.FloatTensor = None
    seq_relationship_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
# 定义一个数据类,用于存储问题回答模型的输出结果
@dataclass
class BigBirdForQuestionAnsweringModelOutput(ModelOutput):
    """
    Base class for outputs of question answering models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
        start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Span-start scores (before SoftMax).
        end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Span-end scores (before SoftMax).
        pooler_output (`torch.FloatTensor` of shape `(batch_size, 1)`):
            pooler output from BigBirdModel
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # 损失值,如果提供了`labels`则返回,表示总的跨度提取损失,由起始和结束位置的交叉熵之和组成
    loss: Optional[torch.FloatTensor] = None
    # 跨度起始得分(SoftMax之前),形状为(batch_size, sequence_length)
    start_logits: torch.FloatTensor = None
    # 跨度结束得分(SoftMax之前),形状为(batch_size, sequence_length)
    end_logits: torch.FloatTensor = None
    # BigBirdModel的汇聚输出,形状为(batch_size, 1)
    pooler_output: torch.FloatTensor = None
    # 隐藏状态,如果`output_hidden_states=True`则返回,是一个元组,包含了每一层的输出,形状为(batch_size, sequence_length, hidden_size)
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # 注意力权重,如果`output_attentions=True`则返回,是一个元组,包含了每一层的注意力权重,形状为(batch_size, num_heads, sequence_length, sequence_length)
    attentions: Optional[Tuple[torch.FloatTensor]] = None


# BigBird模型类,继承自BigBirdPreTrainedModel
@add_start_docstrings(
    "The bare BigBird Model transformer outputting raw hidden-states without any specific head on top.",
    BIG_BIRD_START_DOCSTRING,
)
class BigBirdModel(BigBirdPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """
    # 初始化函数,用于初始化 BigBirdForCausalLM 类的实例
    def __init__(self, config, add_pooling_layer=True):
        # 调用父类的初始化函数
        super().__init__(config)
        # 设置注意力机制类型为配置文件中指定的类型
        self.attention_type = self.config.attention_type
        # 保存传入的配置参数
        self.config = config

        # 设置模型的块大小为配置中指定的块大小
        self.block_size = self.config.block_size

        # 初始化嵌入层,使用 BigBirdEmbeddings 类
        self.embeddings = BigBirdEmbeddings(config)
        # 初始化编码器,使用 BigBirdEncoder 类
        self.encoder = BigBirdEncoder(config)

        # 如果需要添加池化层
        if add_pooling_layer:
            # 创建一个线性层用于池化
            self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
            # 激活函数为双曲正切函数
            self.activation = nn.Tanh()
        else:
            # 如果不需要池化层,设置为 None
            self.pooler = None
            self.activation = None

        # 如果注意力类型不是 "original_full" 且配置要求添加交叉注意力
        if self.attention_type != "original_full" and config.add_cross_attention:
            # 发出警告并强制将 attention_type 设为 "original_full"
            logger.warning(
                "When using `BigBirdForCausalLM` as decoder, then `attention_type` must be `original_full`. Setting"
                " `attention_type=original_full`"
            )
            self.set_attention_type("original_full")

        # 初始化权重并应用最终处理
        self.post_init()

    # 获取输入嵌入层的方法
    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    # 设置输入嵌入层的方法
    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    # 设置注意力类型的方法
    def set_attention_type(self, value: str):
        # 如果值不是 "original_full" 或 "block_sparse",抛出异常
        if value not in ["original_full", "block_sparse"]:
            raise ValueError(
                f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}"
            )
        # 如果当前的 attention_type 已经正确设置,则直接返回
        if value == self.attention_type:
            return
        # 否则更新 attention_type,并通知编码器更新注意力类型
        self.attention_type = value
        self.encoder.set_attention_type(value)

    # 前向传播函数,实现了 BigBirdForCausalLM 类的前向计算过程
    @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # 此处定义了前向传播的参数和返回值类型,详细文档请参考相应的注释和文档字符串
        pass  # 此处仅为占位符,实际前向传播功能未在此展示
    def create_masks_for_block_sparse_attn(attention_mask: torch.Tensor, block_size: int):
        # 获取批次大小和序列长度
        batch_size, seq_length = attention_mask.size()
        
        # 检查序列长度是否是块大小的整数倍,如果不是则引发异常
        if seq_length % block_size != 0:
            raise ValueError(
                f"Sequence length must be multiple of block size, but sequence length is {seq_length}, while block"
                f" size is {block_size}."
            )

        def create_band_mask_from_inputs(from_blocked_mask, to_blocked_mask):
            """
            从2D张量掩码创建3D注意力掩码。

            Args:
                from_blocked_mask: 形状为[batch_size, from_seq_length//from_block_size, from_block_size]的2D张量。
                to_blocked_mask: 形状为[batch_size, to_seq_length//to_block_size, to_block_size]的int32张量。

            Returns:
                形状为[batch_size, 1, from_seq_length//from_block_size-4, from_block_size, 3*to_block_size]的浮点张量。
            """
            # 从输入的块掩码创建带状掩码
            exp_blocked_to_pad = torch.cat(
                [to_blocked_mask[:, 1:-3], to_blocked_mask[:, 2:-2], to_blocked_mask[:, 3:-1]], dim=2
            )
            band_mask = torch.einsum("blq,blk->blqk", from_blocked_mask[:, 2:-2], exp_blocked_to_pad)
            band_mask.unsqueeze_(1)
            return band_mask

        # 将注意力掩码重塑为块表示形式
        blocked_encoder_mask = attention_mask.view(batch_size, seq_length // block_size, block_size)
        
        # 使用块掩码创建带状掩码
        band_mask = create_band_mask_from_inputs(blocked_encoder_mask, blocked_encoder_mask)

        # 为源掩码和目标掩码创建需要的形状
        from_mask = attention_mask.view(batch_size, 1, seq_length, 1)
        to_mask = attention_mask.view(batch_size, 1, 1, seq_length)

        # 返回块掩码、带状掩码、源掩码和目标掩码
        return blocked_encoder_mask, band_mask, from_mask, to_mask

    def _pad_to_block_size(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        pad_token_id: int,
        """
        A helper function to pad tokens and mask to work with implementation of BigBird block-sparse attention.
        """
        # padding
        block_size = self.config.block_size  # 从模型配置中获取块大小

        input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape  # 获取输入的形状
        batch_size, seq_len = input_shape[:2]  # 获取批次大小和序列长度

        padding_len = (block_size - seq_len % block_size) % block_size  # 计算需要填充的长度
        if padding_len > 0:
            logger.warning_once(
                f"Input ids are automatically padded from {seq_len} to {seq_len + padding_len} to be a multiple of "
                f"`config.block_size`: {block_size}"
            )  # 发出警告,说明输入的 ids 自动填充以确保长度是 `config.block_size` 的倍数

            if input_ids is not None:
                input_ids = nn.functional.pad(input_ids, (0, padding_len), value=pad_token_id)  # 对输入 ids 进行填充
            if position_ids is not None:
                # 使用 pad_token_id 填充 position_ids,与 modeling_bigbird.BigBirdEmbeddings 中保持一致
                position_ids = nn.functional.pad(position_ids, (0, padding_len), value=pad_token_id)
            if inputs_embeds is not None:
                # 创建新的输入 ids 填充,并使用模型的 embeddings 生成对应的嵌入向量
                input_ids_padding = inputs_embeds.new_full(
                    (batch_size, padding_len),
                    self.config.pad_token_id,
                    dtype=torch.long,
                )
                inputs_embeds_padding = self.embeddings(input_ids_padding)
                inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2)  # 将填充后的嵌入向量拼接到原始嵌入向量中

            attention_mask = nn.functional.pad(
                attention_mask, (0, padding_len), value=False
            )  # 对注意力掩码进行填充,填充部分不计入注意力
            token_type_ids = nn.functional.pad(token_type_ids, (0, padding_len), value=0)  # 使用 token_type_id = 0 进行填充

        return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds
# 定义 BigBirdForMaskedLM 类,继承自 BigBirdPreTrainedModel 类
class BigBirdForMaskedLM(BigBirdPreTrainedModel):
    # 定义权重共享的键值对列表
    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

    # 初始化函数,接收一个配置参数 config
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__(config)

        # 如果配置要求是 decoder,则发出警告信息
        if config.is_decoder:
            logger.warning(
                "If you want to use `BigBirdForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        # 使用配置初始化 BigBirdModel,并设置为 self.bert 属性
        self.bert = BigBirdModel(config)
        # 使用配置初始化 BigBirdOnlyMLMHead,并设置为 self.cls 属性
        self.cls = BigBirdOnlyMLMHead(config)

        # 初始化权重并进行最终处理
        self.post_init()

    # 返回预测解码器的输出嵌入
    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    # 设置预测解码器的输出嵌入为新的嵌入
    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    # 覆盖模型的 forward 方法,接受多个输入参数,并带有相关的文档字符串和返回值替换
    @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.FloatTensor] = None,
        next_sentence_label: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # 函数体暂时省略
        pass
    # 定义一个方法用于模型的前向传播
    def forward(
        self,
        input_ids: torch.LongTensor = None,  # 输入的token ID序列,数据类型为LongTensor
        attention_mask: Optional[torch.FloatTensor] = None,  # 可选的注意力遮罩,数据类型为FloatTensor
        token_type_ids: Optional[torch.LongTensor] = None,  # 可选的token类型ID序列,数据类型为LongTensor
        position_ids: Optional[torch.LongTensor] = None,  # 可选的位置ID序列,数据类型为LongTensor
        head_mask: Optional[torch.FloatTensor] = None,  # 可选的头部遮罩,数据类型为FloatTensor
        inputs_embeds: Optional[torch.FloatTensor] = None,  # 可选的嵌入向量输入,数据类型为FloatTensor
        encoder_hidden_states: Optional[torch.FloatTensor] = None,  # 可选的编码器隐藏状态,数据类型为FloatTensor
        encoder_attention_mask: Optional[torch.FloatTensor] = None,  # 可选的编码器注意力遮罩,数据类型为FloatTensor
        labels: Optional[torch.LongTensor] = None,  # 可选的标签,数据类型为LongTensor
        output_attentions: Optional[bool] = None,  # 是否输出注意力权重,数据类型为bool
        output_hidden_states: Optional[bool] = None,  # 是否输出隐藏状态,数据类型为bool
        return_dict: Optional[bool] = None,  # 是否以字典形式返回结果,数据类型为bool
    ):
        # 定义用于生成输入的方法,支持模型生成任务
        def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
            # 获取输入token ID序列的形状信息
            input_shape = input_ids.shape
            # 获取有效的批次大小
            effective_batch_size = input_shape[0]

            # 如果配置中未定义PAD token ID,则抛出数值错误
            if self.config.pad_token_id is None:
                raise ValueError("The PAD token should be defined for generation")
            # 在注意力遮罩的末尾添加一个虚拟token
            attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
            # 创建一个全为PAD token ID的虚拟token张量
            dummy_token = torch.full(
                (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
            )
            # 将虚拟token添加到输入token ID序列的末尾
            input_ids = torch.cat([input_ids, dummy_token], dim=1)

            # 返回输入字典,包括输入token ID序列和更新后的注意力遮罩
            return {"input_ids": input_ids, "attention_mask": attention_mask}
# 使用装饰器添加文档字符串,说明这是一个 BigBird 模型,用于语言建模任务的微调
@add_start_docstrings(
    """BigBird Model with a `language modeling` head on top for CLM fine-tuning.""", BIG_BIRD_START_DOCSTRING
)
# 定义 BigBirdForCausalLM 类,继承自 BigBirdPreTrainedModel
class BigBirdForCausalLM(BigBirdPreTrainedModel):
    # 指定共享权重的键名列表,用于多个权重共享
    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

    # 初始化方法,接收配置参数 config
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__(config)

        # 如果不是解码器模式,记录警告信息
        if not config.is_decoder:
            logger.warning("If you want to use `BigBirdForCausalLM` as a standalone, add `is_decoder=True.`")

        # 初始化 BigBirdModel 和 BigBirdOnlyMLMHead
        self.bert = BigBirdModel(config)
        self.cls = BigBirdOnlyMLMHead(config)

        # 初始化权重并应用最终处理
        self.post_init()

    # 返回输出嵌入层的方法
    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    # 设置输出嵌入层的方法
    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    # 前向传播方法,接收多个输入参数,返回模型预测结果
    @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # 为生成准备输入的方法,接收输入的 ID、过去的键值对、注意力掩码等参数
        def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
            # 获取输入的形状
            input_shape = input_ids.shape

            # 如果没有提供注意力掩码,创建全为1的注意力掩码
            if attention_mask is None:
                attention_mask = input_ids.new_ones(input_shape)

            # 如果使用过去的键值对,截取输入的 ID
            if past_key_values is not None:
                past_length = past_key_values[0][0].shape[2]

                # 某些生成方法已经只传递了最后一个输入 ID
                if input_ids.shape[1] > past_length:
                    remove_prefix_length = past_length
                else:
                    # 默认旧的行为:仅保留最后一个 ID
                    remove_prefix_length = input_ids.shape[1] - 1

                input_ids = input_ids[:, remove_prefix_length:]

            # 返回输入字典,包含输入 ID、注意力掩码和过去的键值对
            return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
    # 重新排序缓存中的过去键值对,以适应束搜索中的索引重排
    def _reorder_cache(self, past_key_values, beam_idx):
        # 初始化重新排序后的过去状态元组
        reordered_past = ()
        # 遍历每一层的过去状态
        for layer_past in past_key_values:
            # 对每一层的过去状态的前两个元素(通常是键和值)进行重新排序
            reordered_past += (
                # 使用beam_idx重新排序,保证与束搜索顺序一致,转移到相同设备上
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
                # 将第三个元素及以后的元素保持不变,通常是额外的状态信息
                + layer_past[2:],
            )
        # 返回重新排序后的过去状态元组
        return reordered_past
class BigBirdForMultipleChoice(BigBirdPreTrainedModel):
    """BigBird Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks."""

    def __init__(self, config):
        super().__init__(config)

        # 使用 BigBirdModel 初始化 BERT 部分
        self.bert = BigBirdModel(config)
        # 使用给定的隐藏层丢弃概率初始化 Dropout 层
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # 使用线性层初始化分类器,输出维度为1,用于多选任务
        self.classifier = nn.Linear(config.hidden_size, 1)

        # 初始化权重并应用最终处理
        self.post_init()

    @add_start_docstrings_to_model_forward(
        BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
    )
    @replace_return_docstrings(output_type=MultipleChoiceModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        前向传播方法,执行 BigBird 多选分类任务。

        参数:
            input_ids: 输入 token IDs 张量,形状为 (batch_size, num_choices, sequence_length)
            attention_mask: 注意力掩码张量,形状为 (batch_size, num_choices, sequence_length)
            token_type_ids: token 类型 IDs 张量,形状为 (batch_size, num_choices, sequence_length)
            position_ids: 位置 IDs 张量,形状为 (batch_size, num_choices, sequence_length)
            head_mask: 头部掩码张量,形状为 (num_heads,) 或者 (num_layers, num_heads)
            inputs_embeds: 输入嵌入张量,形状为 (batch_size, num_choices, sequence_length, hidden_size)
            labels: 标签张量,形状为 (batch_size,),每个值为 0 或 1
            output_attentions: 是否输出注意力权重
            output_hidden_states: 是否输出隐藏状态
            return_dict: 是否返回字典形式的输出

        返回:
            MultipleChoiceModelOutput: 包含多选模型输出的对象
        """
        # 确保返回字典的选项不为 None
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 获取 BERT 模型的输出
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 提取池化输出(<s> token 对应的输出)
        pooled_output = outputs[1]

        # 应用 Dropout
        pooled_output = self.dropout(pooled_output)

        # 通过分类器获取 logits
        logits = self.classifier(pooled_output)

        if labels is not None:
            # 计算交叉熵损失
            loss_fct = nn.CrossEntropyLoss()
            # 多选分类任务需要将 logits 和 labels 转置
            logits = logits.view(-1, self.num_labels)
            labels = labels.view(-1)
            # 计算损失
            loss = loss_fct(logits, labels)
            # 输出损失和 logits
            return MultipleChoiceModelOutput(loss=loss, logits=logits)

        # 返回 logits
        return logits
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )


    # 使用指定的参数注释添加代码示例的文档字符串
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[MultipleChoiceModelOutput, Tuple[torch.FloatTensor]]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        # 初始化返回字典,根据配置确定是否使用返回字典
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # 获取输入数据的选择数目
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # 将输入张量展平为二维张量,以便于后续处理
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        # 调用BERT模型处理输入数据
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取BERT模型的池化输出
        pooled_output = outputs[1]

        # 对池化输出应用Dropout操作
        pooled_output = self.dropout(pooled_output)
        # 使用分类器计算logits
        logits = self.classifier(pooled_output)
        # 重新整形logits张量以匹配多选择的形状
        reshaped_logits = logits.view(-1, num_choices)

        # 初始化损失为None
        loss = None
        # 如果有提供标签,则计算交叉熵损失
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        # 如果不需要返回字典,则返回扁平化的输出元组
        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # 否则,返回带有多选择模型输出格式的字典
        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# 使用装饰器添加文档字符串,描述了这是一个在BigBird模型基础上进行标记分类的模型,例如用于命名实体识别(NER)任务
@add_start_docstrings(
    """
    BigBird Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    BIG_BIRD_START_DOCSTRING,
)
# 定义 BigBirdForTokenClassification 类,继承自 BigBirdPreTrainedModel 类
class BigBirdForTokenClassification(BigBirdPreTrainedModel):
    # 初始化方法,接受一个配置对象 config
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__(config)
        # 设置类别数为配置对象中的类别数
        self.num_labels = config.num_labels

        # 初始化 BigBirdModel 模型
        self.bert = BigBirdModel(config)
        
        # 根据配置设置分类器的 dropout 概率,如果未设置,则使用隐藏层 dropout 概率
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        # 使用 dropout 模块
        self.dropout = nn.Dropout(classifier_dropout)
        # 分类器层,将隐藏状态映射到类别数上
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # 初始化权重并进行最终处理
        self.post_init()

    # 使用装饰器添加文档字符串到 forward 方法,描述输入参数的格式
    @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    # 添加代码示例的文档字符串,指定了检查点、输出类型和配置类
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    # 前向传播方法,接受多个输入参数,输出预测结果或损失
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,

        # 函数参数说明:
        # - input_ids: 输入的 token IDs
        # - attention_mask: 注意力掩码,指示哪些位置是填充的
        # - token_type_ids: token 类型 IDs,用于区分不同句子的 token
        # - position_ids: 位置 IDs,标识 token 的位置
        # - head_mask: 头部掩码,用于指定哪些注意力头是有效的
        # - inputs_embeds: 嵌入输入,替代输入的 token IDs
        # - labels: 标签,用于训练时的真实类别
        # - output_attentions: 是否输出注意力权重
        # - output_hidden_states: 是否输出隐藏状态
        # - return_dict: 是否以字典形式返回输出
        
        # 返回 TokenClassifierOutput 类型的对象,包含模型的输出结果
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    ) -> Union[TokenClassifierOutput, Tuple[torch.FloatTensor]]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        # 初始化返回字典,如果未指定则使用配置中的返回设置
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 使用BERT模型处理输入数据,获取输出结果
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从BERT模型的输出中获取序列输出(通常是最后一层的隐藏状态)
        sequence_output = outputs[0]

        # 对序列输出进行Dropout处理
        sequence_output = self.dropout(sequence_output)
        
        # 使用分类器对处理后的序列输出进行分类得到logits
        logits = self.classifier(sequence_output)

        # 初始化损失为None
        loss = None
        # 如果提供了标签,计算分类损失
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        # 如果不要求返回字典格式,则返回一个元组
        if not return_dict:
            output = (logits,) + outputs[2:]  # 只返回logits和额外的输出(隐藏状态等)
            return ((loss,) + output) if loss is not None else output

        # 返回TokenClassifierOutput格式的对象,包括损失、logits、隐藏状态和注意力权重
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class BigBirdForQuestionAnsweringHead(nn.Module):
    """Head for question answering tasks."""

    def __init__(self, config):
        super().__init__()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.intermediate = BigBirdIntermediate(config)  # 初始化中间层对象
        self.output = BigBirdOutput(config)  # 初始化输出层对象
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)  # 初始化线性层

    def forward(self, encoder_output):
        hidden_states = self.dropout(encoder_output)  # 应用 dropout 到编码器输出
        hidden_states = self.intermediate(hidden_states)  # 经过中间层处理
        hidden_states = self.output(hidden_states, encoder_output)  # 经过输出层处理,传入编码器输出
        hidden_states = self.qa_outputs(hidden_states)  # 通过线性层计算最终输出
        return hidden_states


@add_start_docstrings(
    """
    BigBird Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    BIG_BIRD_START_DOCSTRING,
)
class BigBirdForQuestionAnswering(BigBirdPreTrainedModel):
    def __init__(self, config, add_pooling_layer=False):
        super().__init__(config)

        config.num_labels = 2  # 设置类别数量为2
        self.num_labels = config.num_labels
        self.sep_token_id = config.sep_token_id  # 分隔符 token 的 id

        self.bert = BigBirdModel(config, add_pooling_layer=add_pooling_layer)  # 初始化 BigBird 模型
        self.qa_classifier = BigBirdForQuestionAnsweringHead(config)  # 初始化问题回答头部

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=BigBirdForQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        question_lengths: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # 前向传播函数,详细参数见注释中的说明

    @staticmethod
    def prepare_question_mask(q_lengths: torch.Tensor, maxlen: int):
        # 准备问题掩码,根据问题长度和最大长度生成掩码
        mask = torch.arange(0, maxlen).to(q_lengths.device)
        mask.unsqueeze_(0)  # 增加维度
        mask = torch.where(mask < q_lengths, 1, 0)  # 根据长度生成掩码
        return mask

.\models\big_bird\modeling_flax_big_bird.py

# 导入必要的模块和类
from typing import Callable, Optional, Tuple  # 导入类型提示相关模块

import flax  # 导入Flax框架
import flax.linen as nn  # 导入Flax的linen模块,用于定义神经网络模型
import jax  # 导入JAX,用于自动求导和并行计算
import jax.numpy as jnp  # 导入JAX的NumPy接口,命名为jnp,用于数组操作

# 从Flax的core.frozen_dict模块中导入FrozenDict、freeze、unfreeze等相关函数和类
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze  

# 从Flax的linen模块中导入combine_masks、make_causal_mask等函数和类,用于处理神经网络模型
from flax.linen import combine_masks, make_causal_mask  
from flax.linen import partitioning as nn_partitioning  # 导入linen.partitioning模块,用于模型分区
from flax.linen.attention import dot_product_attention_weights  # 导入dot_product_attention_weights函数,用于注意力机制权重计算
from flax.traverse_util import flatten_dict, unflatten_dict  # 导入flatten_dict和unflatten_dict函数,用于字典扁平化和反扁平化
from jax import lax  # 导入lax模块,用于定义JAX的低级API

# 导入特定的模型输出类和工具函数
from ...modeling_flax_outputs import (
    FlaxBaseModelOutputWithPastAndCrossAttentions,
    FlaxBaseModelOutputWithPooling,
    FlaxBaseModelOutputWithPoolingAndCrossAttentions,
    FlaxCausalLMOutputWithCrossAttentions,
    FlaxMaskedLMOutput,
    FlaxMultipleChoiceModelOutput,
    FlaxSequenceClassifierOutput,
    FlaxTokenClassifierOutput,
)

# 导入特定的模型基类和工具函数
from ...modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_call_sample_docstring,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)

# 导入通用工具函数和类
from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging

# 从当前目录下的configuration_big_bird.py文件中导入BigBirdConfig类
from .configuration_big_bird import BigBirdConfig  

# 获取logger对象,用于记录日志信息
logger = logging.get_logger(__name__)

# 定义用于文档的检查点和配置变量
_CHECKPOINT_FOR_DOC = "google/bigbird-roberta-base"
_CONFIG_FOR_DOC = "BigBirdConfig"

# 定义并装饰remat函数,用于对神经网络模型进行分区重组
remat = nn_partitioning.remat

# 定义FlaxBigBirdForPreTrainingOutput类,继承自ModelOutput,用于BigBird预训练模型的输出类型
@flax.struct.dataclass
class FlaxBigBirdForPreTrainingOutput(ModelOutput):
    """
    Output type of [`BigBirdForPreTraining`].
    """
    # `prediction_logits` 是一个形状为 `(batch_size, sequence_length, config.vocab_size)` 的 NumPy 数组,
    # 包含语言建模头部的预测分数(在 SoftMax 之前的每个词汇标记的分数)。
    prediction_logits: jnp.ndarray = None
    
    # `seq_relationship_logits` 是一个形状为 `(batch_size, 2)` 的 NumPy 数组,
    # 包含下一个序列预测(分类)头部的预测分数(在 SoftMax 之前的 True/False 继续的分数)。
    seq_relationship_logits: jnp.ndarray = None
    
    # `hidden_states` 是一个可选的元组,当传递 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回。
    # 其中包含多个 `jnp.ndarray`(一个用于嵌入的输出 + 每个层的输出),
    # 形状为 `(batch_size, sequence_length, hidden_size)`。
    # 这些是模型在每个层输出的隐藏状态以及初始嵌入输出。
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    
    # `attentions` 是一个可选的元组,当传递 `output_attentions=True` 或 `config.output_attentions=True` 时返回。
    # 其中包含多个 `jnp.ndarray`(每个层一个),形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
    # 这些是经过注意力 SoftMax 后的注意力权重,用于计算自注意力头中的加权平均值。
    attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxBigBirdForQuestionAnsweringModelOutput(ModelOutput):
    """
    Base class for outputs of question answering models.

    Args:
        start_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            Span-start scores (before SoftMax).
        end_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            Span-end scores (before SoftMax).
        pooled_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
            pooled_output returned by FlaxBigBirdModel.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    start_logits: jnp.ndarray = None  # Span-start scores (before SoftMax) for question answering.
    end_logits: jnp.ndarray = None  # Span-end scores (before SoftMax) for question answering.
    pooled_output: jnp.ndarray = None  # Output pooled by FlaxBigBirdModel.
    hidden_states: Optional[Tuple[jnp.ndarray]] = None  # Hidden states of model layers and embeddings.
    attentions: Optional[Tuple[jnp.ndarray]] = None  # Attention weights for self-attention heads.


BIG_BIRD_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading, saving and converting weights from PyTorch models)

    This model is also a
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
    behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
"""
    # Parameters: 定义函数参数和其作用
    # config ([`BigBirdConfig`]): 模型配置类,包含模型的所有参数
    #     初始化配置文件不会加载与模型相关的权重,仅加载配置。
    #     若要加载模型权重,请查看 [`~FlaxPreTrainedModel.from_pretrained`] 方法。
    # dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
    #     计算的数据类型。可以是 `jax.numpy.float32`, `jax.numpy.float16` (在GPU上), `jax.numpy.bfloat16` (在TPU上) 之一。
    #
    #     可用于启用混合精度训练或在GPU或TPU上进行半精度推断。如果指定,则所有计算将使用给定的 `dtype` 进行。
    #
    #     **注意,这只是指定计算的数据类型,不影响模型参数的数据类型。**
    #
    #     如果希望更改模型参数的数据类型,请参阅 [`~FlaxPreTrainedModel.to_fp16`] 和 [`~FlaxPreTrainedModel.to_bf16`]。
"""

BIG_BIRD_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        head_mask (`numpy.ndarray` of shape `({0})`, `optional):
            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

"""


class FlaxBigBirdEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings.setup

    # 此处定义了一个FlaxBigBirdEmbeddings类,用于构建从词嵌入、位置嵌入和token_type嵌入构成的嵌入向量。
    # 初始化模型的各种嵌入层和正则化层
    def setup(self):
        # 初始化词嵌入层,将词汇表大小、隐藏层大小等作为参数传入
        self.word_embeddings = nn.Embed(
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        # 初始化位置嵌入层,将最大位置嵌入数、隐藏层大小等作为参数传入
        self.position_embeddings = nn.Embed(
            self.config.max_position_embeddings,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        # 初始化类型嵌入层,将类型词汇表大小、隐藏层大小等作为参数传入
        self.token_type_embeddings = nn.Embed(
            self.config.type_vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        # 初始化 Layer Normalization 层,使用指定的 epsilon 参数
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        # 初始化 Dropout 层,使用指定的 dropout 率
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    # 模型的调用方法,将输入的各种嵌入 ID 进行嵌入,并返回处理后的隐藏状态
    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
        # 嵌入输入的词 ID,将其转换为整数类型并传递给词嵌入层
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        # 嵌入位置 ID,将其转换为整数类型并传递给位置嵌入层
        position_embeds = self.position_embeddings(position_ids.astype("i4"))
        # 嵌入类型 ID,将其转换为整数类型并传递给类型嵌入层
        token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))

        # 如果配置中指定需要重新缩放嵌入层的权重
        if self.config.rescale_embeddings:
            # 对输入嵌入层的值进行按比例缩放
            inputs_embeds *= self.config.hidden_size**0.5

        # 将所有嵌入层的结果相加,形成隐藏状态的初始表示
        hidden_states = inputs_embeds + token_type_embeddings + position_embeds

        # 应用 Dropout 进行正则化,根据 deterministic 参数决定是否使用确定性模式
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        # 应用 Layer Normalization 进行正则化,将结果传递给 LayerNorm 层
        hidden_states = self.LayerNorm(hidden_states)
        # 返回最终的隐藏状态作为模型的输出
        return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->BigBird
class FlaxBigBirdSelfAttention(nn.Module):
    # 定义类属性config,表示BigBird模型的配置
    config: BigBirdConfig
    # 是否使用因果注意力,默认为False
    causal: bool = False
    # 计算时使用的数据类型,默认为jnp.float32
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # 计算每个注意力头的维度
        self.head_dim = self.config.hidden_size // self.config.num_attention_heads
        # 检查隐藏大小是否能够被注意力头数整除
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            raise ValueError(
                # 如果不能整除,抛出错误提示
                "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
                "                   : {self.config.num_attention_heads}"
            )

        # 初始化查询层
        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        # 初始化键层
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        # 初始化值层
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        # 如果使用因果注意力,创建因果掩码
        if self.causal:
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
            )

    # 将隐藏状态分割为多个注意力头
    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))

    # 合并多个注意力头为一个隐藏状态
    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))

    @nn.compact
    # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # 检测是否通过检查"cache"变量中"cached_key"来初始化
        is_initialized = self.has_variable("cache", "cached_key")
        # 初始化或获取缓存的key和value,如果不存在则创建全零数组
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        # 获取或创建缓存索引,初始为0
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            # 获取缓存key的维度信息,从而更新缓存的状态
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # 根据当前缓存索引更新key和value的缓存状态
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            # 更新缓存索引,增加已更新的缓存向量数目
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # 为缓存的解码器自注意力创建因果掩码:我们的单个查询位置只应该关注已生成和缓存的key位置,而不是剩余的零元素。
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        
        # 返回更新后的key、value和注意力掩码
        return key, value, attention_mask
    # 定义一个名为 FlaxBigBirdBlockSparseAttention 的类,继承自 nn.Module
    class FlaxBigBirdBlockSparseAttention(nn.Module):
        # 类变量:BigBirdConfig 类型的 config 对象,block_sparse_seed 和 dtype 为可选参数
        config: BigBirdConfig
        block_sparse_seed: int = None
        dtype: jnp.dtype = jnp.float32

        # 初始化方法,设置网络的各个组件
        def setup(self):
            # 创建一个 Dense 层作为查询网络,输出维度为 config.hidden_size
            self.query = nn.Dense(
                self.config.hidden_size,
                dtype=self.dtype,
                use_bias=self.config.use_bias,
                # 使用正态分布初始化权重,标准差为 config.initializer_range
                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            )
            # 创建一个 Dense 层作为键网络,输出维度为 config.hidden_size
            self.key = nn.Dense(
                self.config.hidden_size,
                dtype=self.dtype,
                use_bias=self.config.use_bias,
                # 使用正态分布初始化权重,标准差为 config.initializer_range
                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            )
            # 创建一个 Dense 层作为值网络,输出维度为 config.hidden_size
            self.value = nn.Dense(
                self.config.hidden_size,
                dtype=self.dtype,
                use_bias=self.config.use_bias,
                # 使用正态分布初始化权重,标准差为 config.initializer_range
                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            )

        # 静态方法:将输入 x 转置为 scores 矩阵的形状
        @staticmethod
        def transpose_for_scores(x, n_heads, head_size):
            # 新的形状为 x 的最后一维除去最后一个元素,加上 (n_heads, head_size)
            new_x_shape = x.shape[:-1] + (n_heads, head_size)
            x = x.reshape(*new_x_shape)
            # 交换指定的维度顺序:第一维和第三维互换位置
            return jnp.transpose(x, axes=(0, 2, 1, 3))

        # 实例方法:处理输入的 hidden_states 和 attention_mask,执行注意力计算
        def __call__(
            self,
            hidden_states,
            attention_mask,
            deterministic=True,
            output_attentions=False,
        ):
            # 提取配置中的注意力头数和头大小
            n_heads = self.config.num_attention_heads
            head_size = self.config.hidden_size // n_heads

            # 创建用于块稀疏注意力的掩码
            blocked_encoder_mask, band_mask, from_mask, to_mask = self.create_masks_for_block_sparse_attn(
                attention_mask, self.config.block_size
            )

            # 对查询、键和值进行维度变换,以备进行注意力计算
            query_layer = self.transpose_for_scores(self.query(hidden_states), n_heads, head_size)
            key_layer = self.transpose_for_scores(self.key(hidden_states), n_heads, head_size)
            value_layer = self.transpose_for_scores(self.value(hidden_states), n_heads, head_size)

            # 如果需要非确定性操作,则创建随机数生成器密钥
            indices_prng_key = None
            if not deterministic:
                indices_prng_key = self.make_rng("indices")

            # 执行 BigBird 块稀疏注意力机制
            attn_output, attn_weights = self.bigbird_block_sparse_attention(
                query_layer,
                key_layer,
                value_layer,
                band_mask,
                from_mask,
                to_mask,
                blocked_encoder_mask,
                blocked_encoder_mask,
                n_heads,
                head_size,
                indices_prng_key=indices_prng_key,
                deterministic=deterministic,
                plan_from_length=None,
                plan_num_rand_blocks=None,
                output_attentions=output_attentions,
            )

            # 根据需要返回注意力输出和注意力权重
            outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
            return outputs

        # 静态方法:
        @staticmethod
    def create_masks_for_block_sparse_attn(attention_mask, block_size: int):
        # 获取输入的注意力掩码的批次大小和序列长度
        batch_size, seq_length = attention_mask.shape
        # 检查序列长度是否是块大小的倍数,否则引发数值错误
        if seq_length % block_size != 0:
            raise ValueError(
                f"Sequence length must be multiple of block size, but sequence length is {seq_length}, while block"
                f" size is {block_size}."
            )

        def create_band_mask_from_inputs(from_blocked_mask, to_blocked_mask):
            """
            从二维张量掩码创建三维注意力掩码。

            Args:
                from_blocked_mask: 形状为 [batch_size, from_seq_length//from_block_size, from_block_size] 的二维张量掩码。
                to_blocked_mask: 形状为 [batch_size, to_seq_length//to_block_size, to_block_size] 的整数32位张量掩码。

            Returns:
                形状为 [batch_size, 1, from_seq_length//from_block_size-4, from_block_size, 3*to_block_size] 的浮点张量。
            """
            # 扩展并拼接来自被阻塞的掩码以进行填充
            exp_blocked_to_pad = jnp.concatenate(
                [to_blocked_mask[:, 1:-3], to_blocked_mask[:, 2:-2], to_blocked_mask[:, 3:-1]], axis=2
            )
            # 使用爱因斯坦求和符号计算带状掩码
            band_mask = jnp.einsum("blq,blk->blqk", from_blocked_mask[:, 2:-2], exp_blocked_to_pad)
            band_mask = jnp.expand_dims(band_mask, 1)
            return band_mask

        # 将注意力掩码重新形状为块形式的编码器掩码
        blocked_encoder_mask = attention_mask.reshape(batch_size, seq_length // block_size, block_size)
        # 创建带状掩码
        band_mask = create_band_mask_from_inputs(blocked_encoder_mask, blocked_encoder_mask)

        # 重新形状创建来自掩码和去掩码
        from_mask = attention_mask.reshape(batch_size, 1, seq_length, 1)
        to_mask = attention_mask.reshape(batch_size, 1, 1, seq_length)

        return blocked_encoder_mask, band_mask, from_mask, to_mask

    @staticmethod
    def jax_gather(params, indices, batch_dims=2):
        """
        正确地从参数中聚集指数(相当于tf.gather但有修改)。

        Args:
            params: 形状为 (bsz, n_heads, num_blocks, block_size, head_dim) 的参数。
            indices: 形状为 (<num_blocks, 1) 的索引。

        Returns:
            聚集后的张量,形状为 params.shape[:batch_dims] + indices.shape + params.shape[batch_dims+1:]。
        """

        def _jax_gather(params, indices):
            return params[indices]

        # 使用jax.vmap逐批次维度进行映射
        for _ in range(batch_dims):
            _jax_gather = jax.vmap(_jax_gather, in_axes=(0, 0))

        return _jax_gather(params, indices)  # 返回聚集结果
    def _create_rand_mask_from_inputs(
        self,
        from_blocked_mask,
        to_blocked_mask,
        broadcasted_rand_attn,
        num_attention_heads,
        num_random_blocks,
        batch_size,
        from_seq_length,
        from_block_size,
    ):
        """
        Create 3D attention mask from a 2D tensor mask.

        Args:
            from_blocked_mask: 2D Tensor of shape [batch_size, from_seq_length//from_block_size, from_block_size].
                Mask for the 'from' sequence, divided into blocks.
            to_blocked_mask: int32 Tensor of shape [batch_size, to_seq_length//to_block_size, to_block_size].
                Mask for the 'to' sequence, divided into blocks.
            broadcasted_rand_attn:
                [batch_size, num_attention_heads, from_seq_length//from_block_size-2, num_rand_blocks]
                Random attention distribution broadcasted across heads and sequence blocks.
            num_attention_heads: int. Number of attention heads.
            num_random_blocks: int. Number of random chunks per row.
            batch_size: int. Batch size for computation.
            from_seq_length: int. Length of 'from' sequence.
            from_block_size: int. Size of block in 'from' sequence.

        Returns:
            float Tensor of shape [batch_size, num_attention_heads, from_seq_length//from_block_size-2,
            from_block_size, num_rand_blocks*to_block_size].
            3D attention mask combining information from 'from' and 'to' sequences.
        """
        # Calculate the number of windows in the 'from' sequence
        num_windows = from_seq_length // from_block_size - 2
        
        # Gather the random attention mask using JAX gather operation
        rand_mask = self.jax_gather(to_blocked_mask, broadcasted_rand_attn, batch_dims=1)
        
        # Reshape the random mask to match the required output shape
        rand_mask = rand_mask.reshape(
            batch_size, num_attention_heads, num_windows, num_random_blocks * from_block_size
        )
        
        # Perform Einstein summation to combine 'from' block mask with random attention
        rand_mask = jnp.einsum("blq,bhlk->bhlqk", from_blocked_mask[:, 1:-1], rand_mask)
        
        # Return the final random attention mask
        return rand_mask

    @staticmethod
    def _get_rand_attn_plan(from_seq_length, from_block_size, num_rand_blocks):
        """
        根据给定的参数生成随机注意力的分布计划。

        Args:
            from_seq_length: int. 源序列的长度。
            from_block_size: int. 源序列中的块大小。
            num_rand_blocks: int. 每行随机块的数量。

        Returns:
            plan_from_length: list. 源块的结束位置计划。
            plan_num_rand_blocks: list. 每个块中随机结束位置的数量。
        """

        plan_from_length = []  # 初始化源块的结束位置列表
        plan_num_rand_blocks = []  # 初始化每个块中随机结束位置的数量列表

        # 根据条件生成不同的分布计划
        if (2 * num_rand_blocks + 5) < (from_seq_length // from_block_size):
            plan_from_length.append(int((2 * num_rand_blocks + 5) * from_block_size))
            plan_num_rand_blocks.append(num_rand_blocks)
            plan_from_length.append(from_seq_length)
            plan_num_rand_blocks.append(0)
        elif (num_rand_blocks + 5) < (from_seq_length // from_block_size):
            plan_from_length.append(int((num_rand_blocks + 5) * from_block_size))
            plan_num_rand_blocks.append(num_rand_blocks // 2)
            plan_from_length.append(from_seq_length)
            plan_num_rand_blocks.append(num_rand_blocks - (num_rand_blocks // 2))
        else:
            plan_from_length.append(from_seq_length)
            plan_num_rand_blocks.append(num_rand_blocks)

        return plan_from_length, plan_num_rand_blocks

    @staticmethod
    def _bigbird_block_rand_mask(
        from_seq_length,
        to_seq_length,
        from_block_size,
        to_block_size,
        num_rand_blocks,
        indices_prng_key: Optional[jax.random.PRNGKey] = None,
        deterministic: Optional[bool] = True,
        last_idx: Optional[int] = -1,
    ):
        """
        生成BigBird模型中块随机掩码。

        Args:
            from_seq_length: int. 源序列的长度。
            to_seq_length: int. 目标序列的长度。
            from_block_size: int. 源序列中的块大小。
            to_block_size: int. 目标序列中的块大小。
            num_rand_blocks: int. 每行随机块的数量。
            indices_prng_key: Optional[jax.random.PRNGKey]. 随机数生成器密钥。
            deterministic: Optional[bool]. 是否确定性生成随机数。
            last_idx: Optional[int]. 最后一个索引位置,默认为-1。

        Returns:
            返回生成的块随机掩码。
        """

    def _bigbird_block_rand_mask_with_head(
        self,
        from_seq_length,
        to_seq_length,
        from_block_size,
        to_block_size,
        num_heads,
        plan_from_length,
        plan_num_rand_blocks,
        indices_prng_key: Optional[jax.random.PRNGKey] = None,
        deterministic: Optional[bool] = True,
        window_block_left=1,
        window_block_right=1,
        global_block_top=1,
        global_block_bottom=1,
        global_block_left=1,
        global_block_right=1,
    ):
        """
        生成带有头信息的BigBird模型中的块随机掩码。

        Args:
            from_seq_length: int. 源序列的长度。
            to_seq_length: int. 目标序列的长度。
            from_block_size: int. 源序列中的块大小。
            to_block_size: int. 目标序列中的块大小。
            num_heads: int. 头的数量。
            plan_from_length: list. 源块的结束位置计划。
            plan_num_rand_blocks: list. 每个块中随机结束位置的数量。
            indices_prng_key: Optional[jax.random.PRNGKey]. 随机数生成器密钥。
            deterministic: Optional[bool]. 是否确定性生成随机数。
            window_block_left: int. 左侧窗口块大小,默认为1。
            window_block_right: int. 右侧窗口块大小,默认为1。
            global_block_top: int. 顶部全局块大小,默认为1。
            global_block_bottom: int. 底部全局块大小,默认为1。
            global_block_left: int. 左侧全局块大小,默认为1。
            global_block_right: int. 右侧全局块大小,默认为1.

        Returns:
            返回生成的带有头信息的块随机掩码。
        """

    @staticmethod
    def _get_single_block_row_attention(
        block_id,
        to_start_block_id,
        to_end_block_id,
        num_rand_blocks,
        indices_prng_key: Optional[jax.random.PRNGKey] = None,
        window_block_left=1,
        window_block_right=1,
        global_block_left=1,
        global_block_right=1,
    ):
        """
        获取单个块行注意力的实现。

        Args:
            block_id: int. 块的ID。
            to_start_block_id: int. 目标序列起始块的ID。
            to_end_block_id: int. 目标序列结束块的ID。
            num_rand_blocks: int. 每行随机块的数量。
            indices_prng_key: Optional[jax.random.PRNGKey]. 随机数生成器密钥。
            window_block_left: int. 左侧窗口块大小,默认为1。
            window_block_right: int. 右侧窗口块大小,默认为1。
            global_block_left: int. 左侧全局块大小,默认为1。
            global_block_right: int. 右侧全局块大小,默认为1。
        """
    ):
        """
        For a single row block get random row attention.

        Args:
            block_id: int. block id of row.
                表示行块的块标识号。
            to_start_block_id: int. random attention column start id.
                随机注意力列开始的块标识号。
            to_end_block_id: int. random attention column end id.
                随机注意力列结束的块标识号。
            num_rand_blocks: int. number of random blocks to be selected.
                要选择的随机块的数量。
            indices_prng_key: jax.random.PRNGKey. PRNG key that is used to perform random jax operations
                用于执行随机 JAX 操作的 PRNG 密钥。
            window_block_left: int. number of blocks of window to left of a block.
                在一个块左边的窗口中的块数。
            window_block_right: int. number of blocks of window to right of a block.
                在一个块右边的窗口中的块数。
            global_block_left: int. Number of blocks globally used to the left.
                左侧全局使用的块数。
            global_block_right: int. Number of blocks globally used to the right.
                右侧全局使用的块数。

        Returns:
            row containing the random attention vector of size num_rand_blocks.
            包含大小为 num_rand_blocks 的随机注意力向量的行。
        """
        # list of to_blocks from which to choose random attention
        to_block_list = jnp.arange(to_start_block_id, to_end_block_id, dtype=jnp.int32)
        # permute the blocks
        perm_block = jax.random.permutation(indices_prng_key, to_block_list)

        # illegal blocks for the current block id, using window
        illegal_blocks = list(range(block_id - window_block_left, block_id + window_block_right + 1))

        # Add blocks at the start and at the end
        illegal_blocks.extend(list(range(global_block_left)))
        illegal_blocks.extend(list(range(to_end_block_id - global_block_right, to_end_block_id)))

        # The second from_block cannot choose random attention on second last to_block
        if block_id == 1:
            illegal_blocks.append(to_end_block_id - 2)

        # The second last from_block cannot choose random attention on second to_block
        if block_id == to_end_block_id - 2:
            illegal_blocks.append(1)

        selected_random_blocks = []

        for i in range(to_end_block_id - to_start_block_id):
            if perm_block[i] not in illegal_blocks:
                selected_random_blocks.append(perm_block[i])
            if len(selected_random_blocks) == num_rand_blocks:
                break
        return jnp.array(selected_random_blocks, dtype=jnp.int32)
# 从 `transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput` 复制并修改为 BigBird
class FlaxBigBirdSelfOutput(nn.Module):
    # BigBird 的配置信息
    config: BigBirdConfig
    # 计算的数据类型
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    # 初始化函数,设置层的结构
    def setup(self):
        # 全连接层,将输入的隐藏状态转换为指定大小的输出
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        # LayerNorm 层,用于规范化输入数据
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        # Dropout 层,用于随机失活,防止过拟合
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    # 对象调用函数,执行层的前向计算
    def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
        # 全连接层计算
        hidden_states = self.dense(hidden_states)
        # Dropout 计算
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        # LayerNorm 计算,将残差连接后的结果规范化
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        # 返回处理后的隐藏状态
        return hidden_states


# BigBird 注意力机制层
class FlaxBigBirdAttention(nn.Module):
    # BigBird 的配置信息
    config: BigBirdConfig
    # 层的编号
    layer_id: int = None
    # 是否使用因果注意力
    causal: bool = False
    # 计算的数据类型
    dtype: jnp.dtype = jnp.float32

    # 初始化函数,设置层的结构
    def setup(self):
        # 根据配置选择不同类型的注意力机制
        if self.config.attention_type == "original_full":
            # 使用原始的全注意力机制
            self.self = FlaxBigBirdSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
        elif self.config.attention_type == "block_sparse":
            # 使用块稀疏注意力机制
            self.self = FlaxBigBirdBlockSparseAttention(self.config, block_sparse_seed=self.layer_id, dtype=self.dtype)
        else:
            # 抛出错误,如果配置不匹配
            raise ValueError(
                f"Your `config.attention_type` is {self.config.attention_type} but it can either be `original_full` or"
                " `block_sparse`"
            )

        # 输出层,用于处理自注意力的输出结果
        self.output = FlaxBigBirdSelfOutput(self.config, dtype=self.dtype)

    # 对象调用函数,执行注意力计算
    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        key_value_states=None,
        init_cache=False,
        deterministic=True,
        output_attentions: bool = False,
        # 如果 attention_mask 的形状为 (*batch_sizes, kv_length),FLAX 要求形状为 (*batch_sizes, 1, 1, kv_length),以便广播匹配 attn_weights 的形状为 (*batch_sizes, num_heads, q_length, kv_length)
        # 当 self.config.attention_type == "original_full" 时,使用带有额外参数的 self.self 方法进行注意力计算
        attn_outputs = self.self(
            hidden_states,
            attention_mask,
            layer_head_mask=layer_head_mask,
            key_value_states=key_value_states,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        # 否则使用默认参数调用 self.self 方法进行注意力计算
        else:
            attn_outputs = self.self(
                hidden_states,
                attention_mask,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )
        # 获取注意力输出的第一个元素
        attn_output = attn_outputs[0]
        # 通过 self.output 方法计算最终的输出 hidden_states
        hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)

        # 构建输出元组,至少包含 hidden_states
        outputs = (hidden_states,)

        # 如果需要输出注意力信息,则在输出元组中添加 attn_outputs 的第二个元素
        if output_attentions:
            outputs += (attn_outputs[1],)

        # 返回最终输出元组
        return outputs
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->BigBird
class FlaxBigBirdIntermediate(nn.Module):
    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # 定义一个全连接层,输出大小为中间层大小,使用正态分布初始化权重
        self.dense = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        # 根据配置选择激活函数
        self.activation = ACT2FN[self.config.hidden_act]

    def __call__(self, hidden_states):
        # 将输入的隐藏状态通过全连接层处理
        hidden_states = self.dense(hidden_states)
        # 应用激活函数到处理后的隐藏状态
        hidden_states = self.activation(hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->BigBird
class FlaxBigBirdOutput(nn.Module):
    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # 定义一个全连接层,输出大小为隐藏大小,使用正态分布初始化权重
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        # 定义一个 dropout 层,用于隐藏层的输出
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        # 定义一个 LayerNorm 层,用于归一化隐藏层的输出
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states, attention_output, deterministic: bool = True):
        # 将隐藏状态输入全连接层进行处理
        hidden_states = self.dense(hidden_states)
        # 对处理后的隐藏状态应用 dropout 操作
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        # 将 dropout 后的结果与注意力输出进行残差连接,并通过 LayerNorm 层进行归一化
        hidden_states = self.LayerNorm(hidden_states + attention_output)
        return hidden_states


class FlaxBigBirdLayer(nn.Module):
    config: BigBirdConfig
    layer_id: int = None
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # 定义 BigBird 层的注意力机制
        self.attention = FlaxBigBirdAttention(
            self.config, layer_id=self.layer_id, causal=self.config.is_decoder, dtype=self.dtype
        )
        # 定义 BigBird 层的中间层
        self.intermediate = FlaxBigBirdIntermediate(self.config, dtype=self.dtype)
        # 定义 BigBird 层的输出层
        self.output = FlaxBigBirdOutput(self.config, dtype=self.dtype)
        # 如果配置中包含跨注意力机制,定义 BigBird 层的跨注意力机制
        if self.config.add_cross_attention:
            self.crossattention = FlaxBigBirdAttention(self.config, causal=False, dtype=self.dtype)

    # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer.__call__ with Bert->BigBird
    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,

        # BigBird 层的调用方法,接收隐藏状态、注意力掩码、层头掩码等输入参数
        # 如果需要初始化缓存,则传入 True
        # deterministic 参数指定是否使用确定性计算,默认为 True
        # output_attentions 参数指定是否输出注意力权重,默认为 False
        # Self Attention
        # 使用 self.attention 方法对输入的 hidden_states 进行自注意力计算
        attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            layer_head_mask=layer_head_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        # 获取自注意力计算后的输出
        attention_output = attention_outputs[0]

        # Cross-Attention Block
        # 如果存在 encoder_hidden_states,则执行交叉注意力计算
        if encoder_hidden_states is not None:
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask=encoder_attention_mask,
                layer_head_mask=layer_head_mask,
                key_value_states=encoder_hidden_states,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )
            # 获取交叉注意力计算后的输出
            attention_output = cross_attention_outputs[0]

        # 经过注意力计算后的输出再经过 intermediate 层处理
        hidden_states = self.intermediate(attention_output)
        # 经过输出层处理,得到最终的 hidden_states
        hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)

        # 输出结果为 hidden_states 组成的元组
        outputs = (hidden_states,)

        # 如果需要输出 attentions,则将 attentions 添加到输出结果中
        if output_attentions:
            outputs += (attention_outputs[1],)
            # 如果存在 encoder_hidden_states,则将交叉注意力也添加到输出结果中
            if encoder_hidden_states is not None:
                outputs += (cross_attention_outputs[1],)
        
        # 返回最终的输出结果
        return outputs
# 定义一个名为 FlaxBigBirdLayerCollection 的类,继承自 nn.Module
class FlaxBigBirdLayerCollection(nn.Module):
    # config 属性,类型为 BigBirdConfig,用于存储 BigBird 的配置信息
    config: BigBirdConfig
    # dtype 属性,默认为 jnp.float32,用于定义计算的数据类型
    dtype: jnp.dtype = jnp.float32  # 计算的数据类型
    # gradient_checkpointing 属性,默认为 False,表示是否开启梯度检查点
    gradient_checkpointing: bool = False

    # 定义类的初始化方法 setup
    def setup(self):
        # 如果开启了梯度检查点
        if self.gradient_checkpointing:
            # 定义一个经过 remat 处理的 FlaxBigBirdCheckpointLayer 类,用于梯度检查点
            FlaxBigBirdCheckpointLayer = remat(FlaxBigBirdLayer, static_argnums=(5, 6, 7))
            # 初始化 self.layers,创建包含多个 FlaxBigBirdCheckpointLayer 实例的列表
            self.layers = [
                FlaxBigBirdCheckpointLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_hidden_layers)
            ]
        else:
            # 如果未开启梯度检查点,初始化 self.layers,创建包含多个 FlaxBigBirdLayer 实例的列表
            self.layers = [
                FlaxBigBirdLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_hidden_layers)
            ]

    # 定义类的调用方法 __call__,用于执行实例化对象时的操作
    # 该方法功能与 transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection.__call__ 相似,替换了 Bert 为 BigBird
    def __call__(
        self,
        hidden_states,  # 输入参数,表示隐藏状态
        attention_mask,  # 输入参数,表示注意力掩码
        head_mask,  # 输入参数,表示头掩码
        encoder_hidden_states: Optional[jnp.ndarray] = None,  # 可选输入参数,编码器隐藏状态
        encoder_attention_mask: Optional[jnp.ndarray] = None,  # 可选输入参数,编码器注意力掩码
        init_cache: bool = False,  # 是否初始化缓存,默认为 False
        deterministic: bool = True,  # 是否确定性计算,默认为 True
        output_attentions: bool = False,  # 是否输出注意力,默认为 False
        output_hidden_states: bool = False,  # 是否输出隐藏状态,默认为 False
        return_dict: bool = True,  # 是否返回字典,默认为 True
        # 返回值:根据参数执行 BigBird 相关计算并返回相应结果
        ):
            all_attentions = () if output_attentions else None
            all_hidden_states = () if output_hidden_states else None
            all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

            # 检查是否需要为每个层级指定正确数量的头部掩码
            if head_mask is not None:
                if head_mask.shape[0] != (len(self.layers)):
                    raise ValueError(
                        f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.shape[0]}."
                    )

            # 遍历每一层的 Transformer 层进行处理
            for i, layer in enumerate(self.layers):
                # 如果需要输出隐藏状态,则将当前隐藏状态添加到 all_hidden_states 中
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                # 调用当前层的 Transformer 层进行前向传播
                layer_outputs = layer(
                    hidden_states,
                    attention_mask,
                    head_mask[i] if head_mask is not None else None,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    init_cache,
                    deterministic,
                    output_attentions,
                )

                # 更新隐藏状态为当前层输出的第一个元素
                hidden_states = layer_outputs[0]

                # 如果需要输出注意力权重,则将当前层的注意力权重添加到 all_attentions 中
                if output_attentions:
                    all_attentions += (layer_outputs[1],)

                    # 如果有编码器的隐藏状态,将当前层的交叉注意力权重添加到 all_cross_attentions 中
                    if encoder_hidden_states is not None:
                        all_cross_attentions += (layer_outputs[2],)

            # 如果需要输出隐藏状态,则将最终的隐藏状态添加到 all_hidden_states 中
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # 构建最终的输出元组
            outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)

            # 如果不需要返回字典形式的结果,则以元组形式返回所有非空结果
            if not return_dict:
                return tuple(v for v in outputs if v is not None)

            # 如果需要返回字典形式的结果,则构建特定格式的输出对象并返回
            return FlaxBaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=hidden_states,
                hidden_states=all_hidden_states,
                attentions=all_attentions,
                cross_attentions=all_cross_attentions,
            )
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->BigBird
class FlaxBigBirdEncoder(nn.Module):
    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32  # 计算过程中使用的数据类型
    gradient_checkpointing: bool = False  # 梯度检查点是否启用,默认为 False

    def setup(self):
        # 初始化 BigBird 编码器层集合,配置包括数据类型和梯度检查点设置
        self.layer = FlaxBigBirdLayerCollection(
            self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )

    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 调用 BigBird 编码器层集合来处理输入
        return self.layer(
            hidden_states,
            attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPredictionHeadTransform with Bert->BigBird
class FlaxBigBirdPredictionHeadTransform(nn.Module):
    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # 初始化 BigBird 预测头转换层,包括稠密层、激活函数和 LayerNorm 层
        self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        self.activation = ACT2FN[self.config.hidden_act]
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states):
        # 通过稠密层、激活函数和 LayerNorm 层处理隐藏状态
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        return self.LayerNorm(hidden_states)


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLMPredictionHead with Bert->BigBird, np.ndarray->jnp.ndarray
class FlaxBigBirdLMPredictionHead(nn.Module):
    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32
    bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        # 初始化 BigBird 语言模型预测头,包括预测头转换和输出稠密层
        self.transform = FlaxBigBirdPredictionHeadTransform(self.config, dtype=self.dtype)
        self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False)
        self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
    # 定义一个特殊方法 __call__,使得对象可以像函数一样被调用
    def __call__(self, hidden_states, shared_embedding=None):
        # 调用 transform 方法对隐藏状态进行变换处理
        hidden_states = self.transform(hidden_states)

        # 如果提供了共享的嵌入矩阵,则使用 decoder 对象应用该共享嵌入
        if shared_embedding is not None:
            hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
        else:
            # 否则,直接使用 decoder 对象处理隐藏状态
            hidden_states = self.decoder(hidden_states)

        # 将 bias 转换为与当前数据类型相匹配的 JAX 数组
        bias = jnp.asarray(self.bias, self.dtype)
        # 将隐藏状态加上偏置项
        hidden_states += bias
        # 返回处理后的隐藏状态
        return hidden_states
# 从 transformers.models.bert.modeling_flax_bert.FlaxBertOnlyMLMHead 复制并修改为 BigBird
class FlaxBigBirdOnlyMLMHead(nn.Module):
    # 使用 BigBirdConfig 配置类初始化模块
    config: BigBirdConfig
    # 默认数据类型为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # 使用 BigBirdLMPredictionHead 初始化预测头部
        self.predictions = FlaxBigBirdLMPredictionHead(self.config, dtype=self.dtype)

    def __call__(self, hidden_states, shared_embedding=None):
        # 使用预测头部处理隐藏状态并返回结果
        hidden_states = self.predictions(hidden_states, shared_embedding=shared_embedding)
        return hidden_states


# 从 transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainingHeads 复制并修改为 BigBird
class FlaxBigBirdPreTrainingHeads(nn.Module):
    # 使用 BigBirdConfig 配置类初始化模块
    config: BigBirdConfig
    # 默认数据类型为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # 使用 BigBirdLMPredictionHead 初始化预测头部
        self.predictions = FlaxBigBirdLMPredictionHead(self.config, dtype=self.dtype)
        # 使用 Dense 层初始化序列关系预测
        self.seq_relationship = nn.Dense(2, dtype=self.dtype)

    def __call__(self, hidden_states, pooled_output, shared_embedding=None):
        # 使用预测头部处理隐藏状态并返回预测分数
        prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding)
        # 使用序列关系预测处理池化输出并返回结果
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
    """
    一个抽象类,处理权重初始化以及下载和加载预训练模型的简单接口。
    """

    # 使用 BigBirdConfig 配置类作为配置类
    config_class = BigBirdConfig
    # 基础模型前缀为 "bert"
    base_model_prefix = "bert"
    # 模块类默认为空
    module_class: nn.Module = None

    def __init__(
        self,
        config: BigBirdConfig,
        input_shape: Optional[tuple] = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        gradient_checkpointing: bool = False,
        **kwargs,
    ):
        # 使用模块类初始化模块,根据配置和其他参数设置输入形状等
        module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
        # 根据注意力类型和输入形状设置默认输入形状
        if config.attention_type == "block_sparse" and input_shape is None:
            input_shape = (1, 12 * config.block_size)
        elif input_shape is None:
            input_shape = (1, 1)

        # 调用父类初始化方法
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    # 从 transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing 复制
    def enable_gradient_checkpointing(self):
        # 使用模块类初始化模块,并启用梯度检查点
        self._module = self.module_class(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=True,
        )
    # 初始化模型权重的函数,使用给定的随机数种子和输入形状初始化模型参数
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # 初始化输入张量
        input_ids = jnp.zeros(input_shape, dtype="i4")
        # 根据输入张量创建与其相同形状的 token 类型张量,初始化为零
        token_type_ids = jnp.zeros_like(input_ids)
        # 创建位置张量,广播到与 input_ids 相同的形状
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        # 创建注意力掩码张量,初始化为全 1
        attention_mask = jnp.ones_like(input_ids)
        # 创建头掩码张量,形状为 (层数, 注意力头数),初始化为全 1
        head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

        # 使用随机数种子 rng 拆分出三个新的随机数种子
        params_rng, dropout_rng, indices_rng = jax.random.split(rng, num=3)
        # 将拆分后的随机数种子保存在字典中
        rngs = {"params": params_rng, "dropout": dropout_rng, "indices": indices_rng}

        # 如果配置中包含跨注意力机制,则初始化编码器隐藏状态和编码器注意力掩码
        if self.config.add_cross_attention:
            encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
            encoder_attention_mask = attention_mask
            # 使用模块的初始化方法初始化模型,返回结果不作为字典返回
            module_init_outputs = self.module.init(
                rngs,
                input_ids,
                attention_mask,
                token_type_ids,
                position_ids,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                return_dict=False,
            )
        else:
            # 使用模块的初始化方法初始化模型,返回结果不作为字典返回
            module_init_outputs = self.module.init(
                rngs,
                input_ids,
                attention_mask,
                token_type_ids,
                position_ids,
                head_mask,
                return_dict=False,
            )

        # 获取初始化后的随机参数
        random_params = module_init_outputs["params"]

        # 如果给定了预训练参数,则将随机参数展平并添加到参数中
        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            # 冻结参数并返回
            return freeze(unflatten_dict(params))
        else:
            # 否则直接返回随机初始化的参数
            return random_params

    # 从 transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache 复制而来
    # 初始化缓存的函数,用于快速自回归解码
    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                用于快速自回归解码的批量大小,定义了初始化缓存的批量大小。
            max_length (`int`):
                自回归解码的最大可能长度,定义了初始化缓存的序列长度。
        """
        # 初始化用于检索缓存的输入变量
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        # 创建与 input_ids 相同形状的注意力掩码张量,初始化为全 1
        attention_mask = jnp.ones_like(input_ids, dtype="i4")
        # 广播位置张量到与 input_ids 相同的形状
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        # 使用模块的初始化方法初始化模型,返回结果不作为字典返回,并标记为初始化缓存
        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        # 返回解冻后的缓存变量
        return unfreeze(init_variables["cache"])
    # 将模型调用函数装饰为使用指定的文档字符串格式
    @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    # 定义模型调用函数,接受多个输入参数
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        params: dict = None,
        dropout_rng: Optional[jax.random.PRNGKey] = None,
        indices_rng: Optional[jax.random.PRNGKey] = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        past_key_values: dict = None,
# 定义了一个 FlaxBigBirdModule 类,继承自 nn.Module
class FlaxBigBirdModule(nn.Module):
    # 类属性,存储 BigBirdConfig 配置对象
    config: BigBirdConfig
    # 计算时使用的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    # 是否添加池化层的标志,默认为 True
    add_pooling_layer: bool = True
    # 是否使用梯度检查点的标志,默认为 False
    gradient_checkpointing: bool = False

    # 模块初始化方法
    def setup(self):
        # 初始化 embeddings 属性,调用 FlaxBigBirdEmbeddings 构造方法
        self.embeddings = FlaxBigBirdEmbeddings(self.config, dtype=self.dtype)
        # 初始化 encoder 属性,调用 FlaxBigBirdEncoder 构造方法
        self.encoder = FlaxBigBirdEncoder(
            self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        # 初始化 pooler 属性,调用 nn.Dense 构造方法
        self.pooler = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    # 对象调用方法,实现模块的前向计算
    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 调用 embeddings 属性的方法,获取输入序列的嵌入表示
        hidden_states = self.embeddings(
            input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
        )
        # 调用 encoder 属性的方法,对输入的隐藏状态进行编码
        outputs = self.encoder(
            hidden_states,
            attention_mask,
            head_mask=head_mask,
            deterministic=deterministic,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # 从 encoder 输出中获取隐藏状态
        hidden_states = outputs[0]

        # 如果设置了添加池化层的标志,则对隐藏状态进行池化操作
        pooled = nn.tanh(self.pooler(hidden_states[:, 0, :])) if self.add_pooling_layer else None

        # 如果 return_dict 为 False,则根据 pooled 是否为 None 返回不同的输出
        if not return_dict:
            if pooled is None:
                return (hidden_states,) + outputs[1:]
            return (hidden_states, pooled) + outputs[1:]

        # 构建返回的输出对象,包括最终的隐藏状态和池化输出
        return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


# 添加了关于 BigBird 模型的文档字符串
@add_start_docstrings(
    "The bare BigBird Model transformer outputting raw hidden-states without any specific head on top.",
    BIG_BIRD_START_DOCSTRING,
)
# 从 FlaxBigBirdPreTrainedModel 继承,并将 module_class 设置为 FlaxBigBirdModule
class FlaxBigBirdModel(FlaxBigBirdPreTrainedModel):
    module_class = FlaxBigBirdModule


# 复制自 transformers.models.bert.modeling_flax_bert.FlaxBertModel,将其中的 Bert 替换为 BigBird
# 添加了对 FlaxBigBirdModel 的调用样例文档字符串
append_call_sample_docstring(FlaxBigBirdModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC)


# 复制自 transformers.models.bert.modeling_flax_bert.FlaxBertForPreTrainingModule,将其中的 Bert 替换为 BigBird
class FlaxBigBirdForPreTrainingModule(nn.Module):
    # 定义类的属性,BigBirdConfig 类型的 config,默认数据类型为 jnp.float32 的 dtype,是否开启梯度检查点的 gradient_checkpointing
    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    # 类的初始化方法
    def setup(self):
        # 初始化 FlaxBigBirdModule 类对象 self.bert,传入配置 config、数据类型 dtype、梯度检查点设置 gradient_checkpointing
        self.bert = FlaxBigBirdModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        # 初始化 FlaxBigBirdPreTrainingHeads 类对象 self.cls,传入配置 config、数据类型 dtype
        self.cls = FlaxBigBirdPreTrainingHeads(config=self.config, dtype=self.dtype)

    # 类的调用方法,接收多个参数,包括输入的各种 IDs、掩码、位置 IDs、头掩码,以及一些控制参数
    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 调用 self.bert 对象进行模型前向传播,传入所有参数,并指定返回的数据类型是字典(return_dict=True)
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 根据配置决定是否共享词嵌入矩阵
        if self.config.tie_word_embeddings:
            # 如果要求共享词嵌入矩阵,则获取 self.bert 对象中的共享词嵌入
            shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        # 从模型输出中获取隐藏状态和池化输出
        hidden_states = outputs[0]
        pooled_output = outputs[1]

        # 调用 self.cls 对象进行预测头部预训练任务的预测,传入隐藏状态、池化输出以及可能的共享词嵌入
        prediction_scores, seq_relationship_score = self.cls(
            hidden_states, pooled_output, shared_embedding=shared_embedding
        )

        # 根据 return_dict 的值确定返回的数据结构
        if not return_dict:
            # 如果 return_dict=False,则返回元组形式的输出,包括预测得分、序列关系得分以及额外的隐藏状态和注意力权重
            return (prediction_scores, seq_relationship_score) + outputs[2:]

        # 如果 return_dict=True,则返回 FlaxBigBirdForPreTrainingOutput 类的实例,包含预测得分、序列关系得分、隐藏状态和注意力权重
        return FlaxBigBirdForPreTrainingOutput(
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings(
    """
    BigBird Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    """,
    BIG_BIRD_START_DOCSTRING,
)
# 定义一个BigBird模型,包含预训练过程中的两个头部:掩码语言建模头部和下一个句子预测头部
# 这段注释是为了说明该类是从FlaxBigBirdPreTrainedModel继承而来的,并设置了模块类为FlaxBigBirdForPreTrainingModule
class FlaxBigBirdForPreTraining(FlaxBigBirdPreTrainedModel):
    module_class = FlaxBigBirdForPreTrainingModule


FLAX_BIG_BIRD_FOR_PRETRAINING_DOCSTRING = """
    Returns:

    Example:

    ```
    >>> from transformers import AutoTokenizer, FlaxBigBirdForPreTraining

    >>> tokenizer = AutoTokenizer.from_pretrained("google/bigbird-roberta-base")
    >>> model = FlaxBigBirdForPreTraining.from_pretrained("google/bigbird-roberta-base")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
    >>> outputs = model(**inputs)

    >>> prediction_logits = outputs.prediction_logits
    >>> seq_relationship_logits = outputs.seq_relationship_logits
    ```
"""
# 更新FlaxBigBirdForPreTraining类的文档字符串,包含了输入说明和示例
overwrite_call_docstring(
    FlaxBigBirdForPreTraining,
    BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BIG_BIRD_FOR_PRETRAINING_DOCSTRING,
)
# 向FlaxBigBirdForPreTraining类中追加或替换返回文档字符串,指定了输出类型为FlaxBigBirdForPreTrainingOutput,配置类为_CONFIG_FOR_DOC


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMaskedLMModule with Bert->BigBird
# 从transformers.models.bert.modeling_flax_bert.FlaxBertForMaskedLMModule复制过来,将Bert更换为BigBird
class FlaxBigBirdForMaskedLMModule(nn.Module):
    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        # 初始化BigBird模块,不添加池化层
        self.bert = FlaxBigBirdModule(
            config=self.config,
            add_pooling_layer=False,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        # 初始化BigBird模型的MLM头部
        self.cls = FlaxBigBirdOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        # 定义了FlaxBigBirdForMaskedLMModule的调用方法,接受多个输入参数
        # 调用 BERT 模型进行前向传播,获取模型输出
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从模型输出中提取隐藏状态
        hidden_states = outputs[0]

        # 如果配置要求共享词嵌入,则获取共享的词嵌入向量
        if self.config.tie_word_embeddings:
            shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        # 使用分类头部模型计算预测分数
        logits = self.cls(hidden_states, shared_embedding=shared_embedding)

        # 如果不返回字典形式的输出,则返回元组
        if not return_dict:
            return (logits,) + outputs[1:]

        # 返回 FlaxMaskedLMOutput 类的实例作为字典形式的输出
        return FlaxMaskedLMOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings("""BigBird Model with a `language modeling` head on top.""", BIG_BIRD_START_DOCSTRING)
# 添加起始文档字符串,说明这是在 BigBird 模型基础上加上语言建模头部的类
# 从 transformers.models.bert.modeling_flax_bert.FlaxBertForMaskedLM 复制并将 Bert 改为 BigBird
class FlaxBigBirdForMaskedLM(FlaxBigBirdPreTrainedModel):
    module_class = FlaxBigBirdForMaskedLMModule

# 添加调用示例文档字符串,描述如何在 FlaxBigBirdForMaskedLM 类上附加检查点的说明
append_call_sample_docstring(FlaxBigBirdForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC)

# BigBird 分类头部,用于句子级别分类任务
class FlaxBigBirdClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        # 设置分类器的 dropout,如果未提供特定的分类器 dropout,则使用隐藏层 dropout
        classifier_dropout = (
            self.config.classifier_dropout
            if self.config.classifier_dropout is not None
            else self.config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.out_proj = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(self, features, deterministic=True):
        x = features[:, 0, :]  # 取 <s> token(相当于 [CLS])
        x = self.dropout(x, deterministic=deterministic)
        x = self.dense(x)
        x = ACT2FN[self.config.hidden_act](x)  # 使用指定的激活函数处理隐藏层输出
        x = self.dropout(x, deterministic=deterministic)
        x = self.out_proj(x)
        return x

# BigBird 序列分类模块
class FlaxBigBirdForSequenceClassificationModule(nn.Module):
    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        # 设置 BigBird 模块作为 BERT
        self.bert = FlaxBigBirdModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        self.classifier = FlaxBigBirdClassificationHead(self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 模型计算
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]  # 获取序列输出
        logits = self.classifier(sequence_output, deterministic=deterministic)  # 使用分类头部进行分类

        if not return_dict:
            return (logits,) + outputs[2:]

        # 返回序列分类器输出对象
        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

@add_start_docstrings(
    """
    BigBird Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    """
)
# 添加起始文档字符串,说明这是在 BigBird 模型基础上加上序列分类/回归头部的类(线性层在顶部)
    pooled output) e.g. for GLUE tasks.
    ```
    这部分代码是一个多行字符串,描述了`BigBirdForSequenceClassification`类的用途和功能,特别是在GLUE任务中如何使用汇集输出(pooled output)。
    ```
    BIG_BIRD_START_DOCSTRING,
    ```
    这里调用了`BIG_BIRD_START_DOCSTRING`,它可能是一个预定义的常量或函数,用于指示文档字符串的开始位置。
    ```
# 从transformers.models.bert.modeling_flax_bert.FlaxBertForSequenceClassification复制代码,将Bert改为BigBird
class FlaxBigBirdForSequenceClassification(FlaxBigBirdPreTrainedModel):
    # 将模块类指定为FlaxBigBirdForSequenceClassificationModule
    module_class = FlaxBigBirdForSequenceClassificationModule


# 将样本调用文档字符串附加到FlaxBigBirdForSequenceClassification类上
append_call_sample_docstring(
    FlaxBigBirdForSequenceClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxSequenceClassifierOutput,
    _CONFIG_FOR_DOC,
)


# 从transformers.models.bert.modeling_flax_bert.FlaxBertForMultipleChoiceModule复制代码,将Bert改为BigBird
class FlaxBigBirdForMultipleChoiceModule(nn.Module):
    # BigBird配置
    config: BigBirdConfig
    # 数据类型,默认为32位浮点数
    dtype: jnp.dtype = jnp.float32
    # 梯度检查点,默认关闭
    gradient_checkpointing: bool = False

    def setup(self):
        # 初始化BigBird模块
        self.bert = FlaxBigBirdModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        # Dropout层,使用隐藏层dropout比率
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        # 分类器,输出为1,使用指定数据类型
        self.classifier = nn.Dense(1, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 获取选项数量
        num_choices = input_ids.shape[1]
        # 重新整形输入数据,用于模型输入
        input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
        attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
        token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
        position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None

        # 模型前向传播
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 提取池化输出
        pooled_output = outputs[1]
        # 应用dropout到池化输出
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        # 应用分类器得到logits
        logits = self.classifier(pooled_output)

        # 重新整形logits,以匹配选项数量
        reshaped_logits = logits.reshape(-1, num_choices)

        # 如果不返回字典,则返回重整后的logits和额外的输出
        if not return_dict:
            return (reshaped_logits,) + outputs[2:]

        # 返回多选模型输出对象
        return FlaxMultipleChoiceModelOutput(
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    BigBird模型,顶部带有多选分类头(池化输出上的线性层和softmax),例如用于RocStories/SWAG任务。
    """,
    BIG_BIRD_START_DOCSTRING,
)
class FlaxBigBirdForMultipleChoice(FlaxBigBirdPreTrainedModel):
    # 将模块类指定为FlaxBigBirdForMultipleChoiceModule
    module_class = FlaxBigBirdForMultipleChoiceModule
    # 初始化函数,用于创建一个 BigBirdLayer 的实例
    def __init__(
        self,
        config: BigBirdConfig,  # 参数:BigBird 模型的配置对象
        input_shape: Optional[tuple] = None,  # 参数:输入数据的形状,可选,默认为 None
        seed: int = 0,  # 参数:随机种子,默认为 0
        dtype: jnp.dtype = jnp.float32,  # 参数:数据类型,默认为 jnp.float32
        _do_init: bool = True,  # 参数:是否执行初始化,默认为 True
        **kwargs,  # 其他关键字参数
    ):
        # 如果配置的注意力类型是 "block_sparse" 并且输入形状是 None
        if config.attention_type == "block_sparse" and input_shape is None:
            # 设置输入形状为 (1, 1, 12 * config.block_size)
            input_shape = (1, 1, 12 * config.block_size)
        # 如果输入形状仍然是 None
        elif input_shape is None:
            # 设置输入形状为 (1, 1)
            input_shape = (1, 1)
        
        # 调用父类的初始化方法,传递配置对象、输入形状、随机种子、数据类型、是否执行初始化标志位
        super().__init__(config, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
# 调用函数 overwrite_call_docstring,为 FlaxBigBirdForMultipleChoice 类重写文档字符串
overwrite_call_docstring(
    FlaxBigBirdForMultipleChoice, BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)

# 调用函数 append_call_sample_docstring,为 FlaxBigBirdForMultipleChoice 类附加示例文档字符串
append_call_sample_docstring(
    FlaxBigBirdForMultipleChoice,
    _CHECKPOINT_FOR_DOC,
    FlaxMultipleChoiceModelOutput,
    _CONFIG_FOR_DOC,
)


# 从 transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassificationModule 复制,并将 Bert 替换为 BigBird
class FlaxBigBirdForTokenClassificationModule(nn.Module):
    config: BigBirdConfig  # 定义配置项为 BigBirdConfig 类型
    dtype: jnp.dtype = jnp.float32  # 数据类型设置为 jnp.float32,默认为浮点数
    gradient_checkpointing: bool = False  # 梯度检查点设置为 False,默认不启用

    def setup(self):
        # 初始化 self.bert,使用 FlaxBigBirdModule 构建 BigBird 模型,设置一些参数
        self.bert = FlaxBigBirdModule(
            config=self.config,
            dtype=self.dtype,
            add_pooling_layer=False,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        # 设置 dropout 层,使用配置中的 classifier_dropout,若未指定则使用 hidden_dropout_prob
        classifier_dropout = (
            self.config.classifier_dropout
            if self.config.classifier_dropout is not None
            else self.config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(rate=classifier_dropout)  # 设置 dropout 层
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)  # 设置分类器层

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 调用 BigBird 模型 self.bert 进行前向传播
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]  # 获取模型输出的隐藏状态
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)  # 应用 dropout
        logits = self.classifier(hidden_states)  # 使用分类器得到 logits

        if not return_dict:
            return (logits,) + outputs[1:]  # 返回 logits 和其它输出

        # 返回 FlaxTokenClassifierOutput 类型的对象,包含 logits、隐藏状态和注意力分布
        return FlaxTokenClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    BigBird 模型添加了一个 token 分类头部(线性层在隐藏状态输出之上),例如用于命名实体识别(NER)任务。
    """,
    BIG_BIRD_START_DOCSTRING,
)
# 从 transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassification 复制,并将 Bert 替换为 BigBird
class FlaxBigBirdForTokenClassification(FlaxBigBirdPreTrainedModel):
    module_class = FlaxBigBirdForTokenClassificationModule  # 指定模型类为 FlaxBigBirdForTokenClassificationModule


# 附加文档字符串示例到 FlaxBigBirdForTokenClassification 类
append_call_sample_docstring(
    FlaxBigBirdForTokenClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxTokenClassifierOutput,
    _CONFIG_FOR_DOC,
)


# 为问答任务头部定义类 FlaxBigBirdForQuestionAnsweringHead
class FlaxBigBirdForQuestionAnsweringHead(nn.Module):
    config: BigBirdConfig  # 定义配置项为 BigBirdConfig 类型
    dtype: jnp.dtype = jnp.float32  # 数据类型设置为 jnp.float32,默认为浮点数
    # 在模型设置过程中初始化 dropout 层,使用给定的隐藏层dropout概率
    def setup(self):
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        # 初始化一个中间层对象,用于处理 BigBird 模型的中间输出
        self.intermediate = FlaxBigBirdIntermediate(self.config, dtype=self.dtype)
        # 初始化一个输出层对象,用于处理 BigBird 模型的最终输出
        self.output = FlaxBigBirdOutput(self.config, dtype=self.dtype)
        # 初始化一个全连接层,用于执行问题回答任务的最终输出
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

    # 模型调用方法,接收编码器的输出和一个确定性标志
    def __call__(self, encoder_output, deterministic=True):
        # 对编码器输出应用 dropout 层,根据确定性标志确定是否随机失活
        hidden_states = self.dropout(encoder_output, deterministic=deterministic)
        # 将 dropout 处理后的隐藏状态传递给中间层对象处理
        hidden_states = self.intermediate(hidden_states)
        # 将中间层处理后的输出传递给输出层对象处理,并结合编码器的原始输出
        hidden_states = self.output(hidden_states, encoder_output)
        # 将输出层处理后的结果传递给问题回答的全连接层,生成最终的模型输出
        hidden_states = self.qa_outputs(hidden_states)
        # 返回问题回答任务的最终输出
        return hidden_states
class FlaxBigBirdForQuestionAnsweringModule(nn.Module):
    # 定义模型配置
    config: BigBirdConfig
    # 定义数据类型,默认为32位浮点数
    dtype: jnp.dtype = jnp.float32
    # 是否添加池化层,默认为False
    add_pooling_layer: bool = False
    # 是否使用梯度检查点,默认为False
    gradient_checkpointing: bool = False

    def setup(self):
        # 设置模型的类别数为2
        self.config.num_labels = 2
        # 初始化 BigBird 模型
        self.bert = FlaxBigBirdModule(
            self.config,
            dtype=self.dtype,
            add_pooling_layer=self.add_pooling_layer,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        # 初始化用于问答的分类器
        self.qa_classifier = FlaxBigBirdForQuestionAnsweringHead(self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        logits_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 调用模型计算
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 提取模型输出的隐藏状态
        hidden_states = outputs[0]
        # 如果启用池化层,则提取池化后的输出
        pooled_output = outputs[1] if self.add_pooling_layer else None
        # 使用问答分类器计算 logits
        logits = self.qa_classifier(hidden_states, deterministic=deterministic)

        if logits_mask is not None:
            # 如果提供了 logits_mask,则在竞赛中移除问题标记
            logits = logits - logits_mask * 1e6

        # 将 logits 分割为起始位置和结束位置的预测
        start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        if not return_dict:
            # 如果不要求返回字典,则返回元组形式的结果
            return (start_logits, end_logits) + outputs[1:]

        # 返回问答模型的输出,包括起始和结束 logits,以及其它可选输出
        return FlaxBigBirdForQuestionAnsweringModelOutput(
            start_logits=start_logits,
            end_logits=end_logits,
            pooled_output=pooled_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    BigBird 模型,顶部带有用于抽取式问答任务(如 SQuAD)的跨度分类头部(线性层在隐藏状态输出之上计算 'span start logits' 和 'span end logits')。
    """,
    BIG_BIRD_START_DOCSTRING,
)
class FlaxBigBirdForQuestionAnswering(FlaxBigBirdPreTrainedModel):
    module_class = FlaxBigBirdForQuestionAnsweringModule

    @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    # 定义一个静态方法,用于静态调用或实例调用
    @staticmethod
    # 以下是 __call__ 方法的定义,用于模型类实例的调用
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        question_lengths=None,
        params: dict = None,
        dropout_rng: Optional[jax.random.PRNGKey] = None,
        indices_rng: Optional[jax.random.PRNGKey] = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # 根据需求设置是否输出注意力权重
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # 根据需求设置是否输出隐藏状态
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 根据需求设置是否返回字典格式的输出结果
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # 如果未提供位置编码,使用输入张量形状的广播操作生成位置编码
        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        # 如果未提供注意力掩码,使用与输入张量形状相同的全 1 张量作为注意力掩码
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # 如果未提供头部掩码,使用形状为 (层数, 注意力头数) 的全 1 张量作为头部掩码
        if head_mask is None:
            head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

        # 如果未提供问题长度并且输入不为空,则计算问题长度
        if question_lengths is None and input_ids is not None:
            # 假设输入格式为:<cls> <question> <sep> context <sep>
            question_lengths = jnp.argmax((input_ids == self.config.sep_token_id).astype("i4"), axis=-1) + 1
            question_lengths = jnp.expand_dims(question_lengths, axis=1)

        # 计算输入张量的序列长度
        seqlen = input_ids.shape[1]

        # 初始化 logits_mask 为 None
        logits_mask = None
        # 如果存在问题长度,则准备问题掩码
        if question_lengths is not None:
            # 将长度为问题的 logits 设置为 `-inf`
            logits_mask = self.prepare_question_mask(question_lengths, seqlen)
            # 如果未提供 token_type_ids,则使用 logits_mask 的反向值
            if token_type_ids is None:
                token_type_ids = (~logits_mask).astype("i4")
            logits_mask = jnp.expand_dims(logits_mask, axis=2)
            logits_mask = logits_mask.at[:, 0].set(False)

        # 如果未提供 token_type_ids,则初始化为与 input_ids 形状相同的全 0 张量
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        # 如果需要处理任何伪随机数生成器(PRNG)
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        if indices_rng is not None:
            rngs["indices"] = indices_rng

        # 调用 self.module 的 apply 方法,传递各种输入参数
        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            token_type_ids,
            jnp.array(position_ids, dtype="i4"),
            jnp.array(head_mask, dtype="i4"),
            logits_mask,
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )
    # 定义函数 prepare_question_mask,准备问题的掩码
    def prepare_question_mask(q_lengths, maxlen: int):
        # q_lengths -> (bz, 1)
        # 创建一个长度为 maxlen 的数组 mask,其中包含从 0 到 maxlen-1 的整数
        mask = jnp.arange(0, maxlen)
        # 将 mask 扩展为二维数组,与 q_lengths 比较,生成布尔型掩码
        mask = jnp.expand_dims(mask, axis=0) < q_lengths
        # 返回生成的掩码
        return mask
# 将示例文档字符串添加到指定的模型类中
append_call_sample_docstring(
    FlaxBigBirdForQuestionAnswering,  # 要添加文档字符串的模型类
    _CHECKPOINT_FOR_DOC,  # 用于文档的检查点
    FlaxBigBirdForQuestionAnsweringModelOutput,  # 模型输出类
    _CONFIG_FOR_DOC,  # 用于文档的配置
)


# 定义一个用于语言建模的 BigBird 模型类
class FlaxBigBirdForCausalLMModule(nn.Module):
    config: BigBirdConfig  # BigBird 模型的配置类
    dtype: jnp.dtype = jnp.float32  # 数据类型,默认为 jnp.float32
    gradient_checkpointing: bool = False  # 是否使用梯度检查点,默认为 False

    def setup(self):
        # 初始化 BigBird 模型,不添加池化层
        self.bert = FlaxBigBirdModule(
            config=self.config,
            add_pooling_layer=False,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        # 初始化仅包含 MLM 头部的 BigBird 模型
        self.cls = FlaxBigBirdOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        token_type_ids: Optional[jnp.ndarray] = None,
        head_mask: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 模型前向传播
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.tie_word_embeddings:
            # 如果需要共享词嵌入,则获取共享的词嵌入
            shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        # 计算预测分数
        logits = self.cls(hidden_states, shared_embedding=shared_embedding)

        if not return_dict:
            return (logits,) + outputs[1:]

        # 返回带有交叉注意力的 FlaxCausalLMOutputWithCrossAttentions 类的输出
        return FlaxCausalLMOutputWithCrossAttentions(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


@add_start_docstrings(
    """
    在 BigBird 模型顶部添加一个语言建模头部的模型(在隐藏状态输出的顶部添加一个线性层),例如用于自回归任务。
    """,
    BIG_BIRD_START_DOCSTRING,  # BigBird 模型的起始文档字符串
)
# 从 transformers.models.bert.modeling_flax_bert.FlaxBertForCausalLM 复制并修改为 BigBird
class FlaxBigBirdForCausalLM(FlaxBigBirdPreTrainedModel):
    module_class = FlaxBigBirdForCausalLMModule  # 使用的模块类
    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        # initializing the cache
        # 获取输入的批次大小和序列长度
        batch_size, seq_length = input_ids.shape

        # 使用 self.init_cache 方法初始化过去的键值对
        past_key_values = self.init_cache(batch_size, max_length)

        # 注意:通常需要为 attention_mask 中大于 input_ids.shape[-1] 和小于 cache_length 的位置放置 0
        # 但由于解码器使用因果 mask,这些位置已经被屏蔽了。
        # 因此,我们可以在这里创建一个静态的 attention_mask,这样更有效率。
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        
        # 如果提供了 attention_mask,则根据其累积求和计算 position_ids
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            # 使用 lax.dynamic_update_slice 更新 extended_attention_mask 的部分值
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            # 否则,使用广播方式创建 position_ids
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        # 返回准备好的输入参数字典
        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        # 更新 model_kwargs 中的 past_key_values 和 position_ids
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs
# 调用函数append_call_sample_docstring,将以下参数传递给它:
# - FlaxBigBirdForCausalLM: 类型,表示为生成样例文档字符串时用到的模型类
# - _CHECKPOINT_FOR_DOC: 常量,表示为生成样例文档字符串时用到的检查点名称
# - FlaxCausalLMOutputWithCrossAttentions: 类型,表示为生成样例文档字符串时用到的模型输出类
# - _CONFIG_FOR_DOC: 常量,表示为生成样例文档字符串时用到的配置信息名称
append_call_sample_docstring(
    FlaxBigBirdForCausalLM,
    _CHECKPOINT_FOR_DOC,
    FlaxCausalLMOutputWithCrossAttentions,
    _CONFIG_FOR_DOC,
)

.\models\big_bird\tokenization_big_bird.py

# coding=utf-8
# Copyright 2021 Google Research and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for BigBird."""


import os
import re
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple

import sentencepiece as spm  # 导入 sentencepiece 库

from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging  # 导入 logging 模块


logger = logging.get_logger(__name__)  # 获取当前模块的 logger

VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}  # 定义词汇文件的名称映射

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "google/bigbird-roberta-base": "https://huggingface.co/google/bigbird-roberta-base/resolve/main/spiece.model",
        "google/bigbird-roberta-large": (
            "https://huggingface.co/google/bigbird-roberta-large/resolve/main/spiece.model"
        ),
        "google/bigbird-base-trivia-itc": (
            "https://huggingface.co/google/bigbird-base-trivia-itc/resolve/main/spiece.model"
        ),
    }
}  # 预训练词汇文件的映射,包含模型名称及其对应的远程路径

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "google/bigbird-roberta-base": 4096,
    "google/bigbird-roberta-large": 4096,
    "google/bigbird-base-trivia-itc": 4096,
}  # 预训练位置嵌入的尺寸映射,包含模型名称及其对应的位置嵌入大小


class BigBirdTokenizer(PreTrainedTokenizer):
    """
    Construct a BigBird tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.
    """
    # BigBirdTokenizer 类,继承自 PreTrainedTokenizer,用于构建 BigBird 分词器,基于 SentencePiece
    pass  # 占位符,暂未实现额外的方法或属性,仅作为类声明的结尾
    # vocab_file 参数:指定 SentencePiece 文件的路径,该文件包含用于实例化分词器的词汇表
    # unk_token 参数(可选,默认为 "<unk>"):未知标记,表示词汇表中不存在的词汇将被设置为此标记
    # bos_token 参数(可选,默认为 "<s>"):序列开始标记
    # eos_token 参数(可选,默认为 "</s>"):序列结束标记
    # pad_token 参数(可选,默认为 "<pad>"):用于填充的标记,在处理不同长度的序列时使用
    # sep_token 参数(可选,默认为 "[SEP]"):分隔符标记,用于构建多个序列的时候使用
    # mask_token 参数(可选,默认为 "[MASK]"):掩码标记,在掩码语言建模(Masked Language Modeling)中使用,模型会尝试预测这些标记
    # cls_token 参数(可选,默认为 "[CLS]"):分类器标记,用于序列分类任务中,表示序列的开始
    # sp_model_kwargs 参数(可选):将传递给 SentencePieceProcessor.__init__() 方法的参数字典,
    # 可以用于配置 SentencePiece 的各种参数,例如启用子词正则化、设置采样参数等

    vocab_files_names = VOCAB_FILES_NAMES
    # vocab_files_names 变量:包含了预训练模型所需的词汇文件名的列表

    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    # pretrained_vocab_files_map 变量:包含了预训练模型对应的词汇文件路径的映射字典

    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    # max_model_input_sizes 变量:包含了预训练位置嵌入的最大输入尺寸的字典

    model_input_names = ["input_ids", "attention_mask"]
    # model_input_names 变量:包含了模型输入的名称列表,用于对应模型的输入要求

    prefix_tokens: List[int] = []
    # prefix_tokens 变量:用于存储特殊前缀标记的列表,初始化为空列表
    # 初始化函数,接受多个参数和关键字参数来配置词汇表和特殊标记
    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token="<pad>",
        sep_token="[SEP]",
        mask_token="[MASK]",
        cls_token="[CLS]",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        # 如果特殊标记是字符串,则将其转换为 AddedToken 对象,保留其空白字符处理设置
        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
        cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
        sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token

        # Mask token 被视为普通单词,即包括其前面的空格
        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

        # 如果未提供 sp_model_kwargs,则设为空字典
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        # 保存词汇表文件路径
        self.vocab_file = vocab_file

        # 创建 SentencePieceProcessor 对象,并加载词汇表文件
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)

        # 调用父类的初始化方法,传递特殊标记和其它关键字参数
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            sep_token=sep_token,
            mask_token=mask_token,
            cls_token=cls_token,
            sp_model_kwargs=self.sp_model_kwargs,
            **kwargs,
        )

    @property
    # 返回词汇表大小,由 SentencePieceProcessor 对象提供
    def vocab_size(self):
        return self.sp_model.get_piece_size()

    # 返回包含所有词汇及其对应 id 的字典,包括添加的特殊标记
    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    # 返回对象的状态,用于序列化
    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None  # 移除 sp_model 对象,以免被序列化
        return state

    # 设置对象的状态,用于反序列化
    def __setstate__(self, d):
        self.__dict__ = d

        # 为了向后兼容性,如果缺少 sp_model_kwargs 属性,则设为空字典
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        # 重新创建 SentencePieceProcessor 对象并加载词汇表文件
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)

    # 对文本进行分词处理,返回由字符串组成的列表(标记)
    def _tokenize(self, text: str) -> List[str]:
        """Take as input a string and return a list of strings (tokens) for words/sub-words"""
        return self.sp_model.encode(text, out_type=str)

    # 将标记(字符串)转换为其对应的 id
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)
    # 使用给定的索引在词汇表中将索引转换为对应的标记字符串
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        # 使用 sentencepiece 模型将索引转换为对应的标记字符串
        token = self.sp_model.IdToPiece(index)
        return token

    # 从 transformers.models.albert.tokenization_albert.AlbertTokenizer.convert_tokens_to_string 复制而来
    # 将一系列标记字符串转换为单个字符串
    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        current_sub_tokens = []  # 当前正在处理的子标记列表
        out_string = ""  # 输出的合并后的字符串
        prev_is_special = False  # 上一个标记是否为特殊标记
        for token in tokens:
            # 确保特殊标记不会使用 sentencepiece 模型进行解码
            if token in self.all_special_tokens:
                if not prev_is_special:
                    out_string += " "  # 添加空格来分隔特殊标记
                # 使用 sentencepiece 模型解码当前子标记列表,并加上当前特殊标记
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []  # 重置当前子标记列表
            else:
                current_sub_tokens.append(token)  # 将当前标记添加到当前子标记列表中
                prev_is_special = False
        # 将剩余的子标记列表使用 sentencepiece 模型解码,并添加到输出字符串中
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string.strip()  # 返回去掉两侧空格的输出字符串

    def _decode(
        self,
        token_ids: List[int],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool = None,
        spaces_between_special_tokens: bool = True,
        **kwargs,
    ):
        ) -> str:
        # 从 kwargs 中获取 use_source_tokenizer 参数,并设置到实例变量中
        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)

        # 转换 token_ids 到 tokens 列表,跳过特殊标记(如果需要)
        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)

        # 为了避免混合字节级别和 unicode 字符(例如字节级别 BPT),需要分别构建字符串以处理添加的标记和字节级别的 tokens
        # 参考:https://github.com/huggingface/transformers/issues/1133
        sub_texts = []
        current_sub_text = []
        for token in filtered_tokens:
            if skip_special_tokens and token in self.all_special_ids:
                continue
            # 如果 token 是添加的特殊标记
            if token in self.added_tokens_encoder:
                # 如果当前子文本不为空,则将其转换为字符串并添加到 sub_texts 中
                if current_sub_text:
                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
                    current_sub_text = []
                sub_texts.append(token)
            else:
                current_sub_text.append(token)
        # 将剩余的 current_sub_text 转换为字符串并添加到 sub_texts 中
        if current_sub_text:
            sub_texts.append(self.convert_tokens_to_string(current_sub_text))

        # 模仿 Rust 分词器的行为:
        # 在 [MASK] 和 [SEP] 前不添加空格
        if spaces_between_special_tokens:
            # 使用正则表达式去除特殊标记前的空格
            text = re.sub(r" (\[(MASK|SEP)\])", r"\1", " ".join(sub_texts))
        else:
            text = "".join(sub_texts)

        # 根据 clean_up_tokenization_spaces 参数清理分词后的空格
        clean_up_tokenization_spaces = (
            clean_up_tokenization_spaces
            if clean_up_tokenization_spaces is not None
            else self.clean_up_tokenization_spaces
        )
        if clean_up_tokenization_spaces:
            # 清理分词后的空格
            clean_text = self.clean_up_tokenization(text)
            return clean_text
        else:
            return text

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # 如果保存目录不存在,则记录错误并返回
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        
        # 构建输出的词汇表文件路径
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        # 如果当前词汇表文件路径和目标文件路径不同且当前词汇表文件存在,则复制当前词汇表文件到目标文件路径
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        # 如果当前词汇表文件不存在,则将当前的词汇表内容写入目标文件路径
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        # 返回保存的词汇表文件路径的元组
        return (out_vocab_file,)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A Big Bird sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        # Check if only one sequence is provided
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        # Define special tokens for the start and separation
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        # Concatenate tokens for a pair of sequences
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        # If the tokens already have special tokens, delegate to the superclass
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # Calculate special tokens mask for a single sequence
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        
        # Calculate special tokens mask for a pair of sequences
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create token type IDs tensor from two sequences or a single sequence. Token type IDs are binary tensors where
        0 indicates the first sequence and 1 indicates the second sequence.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs corresponding to the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of token type IDs.
        """
        # Initialize token type IDs for the first sequence
        token_type_ids = [0] * len(token_ids_0)
        # If token_ids_1 is provided, extend token type IDs to cover both sequences
        if token_ids_1 is not None:
            token_type_ids += [1] * len(token_ids_1)
        return token_type_ids
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
        pair mask has the following format: :: 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 | first sequence | second
        sequence | If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        # Separator token IDs for separating sequences
        sep = [self.sep_token_id]
        # Classification token ID indicating the start of a classification task
        cls = [self.cls_token_id]
        
        # If only one sequence (`token_ids_1` is `None`), return mask for the first sequence
        if token_ids_1 is None:
            # Return a list of zeros representing the mask for the first sequence
            return len(cls + token_ids_0 + sep) * [0]
        
        # If there are two sequences, return a combined mask for both sequences
        # Concatenate the length of cls + token_ids_0 + sep with zeros, then add the length of token_ids_1 + sep with ones
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

.\models\big_bird\tokenization_big_bird_fast.py

# 设定文件编码为 UTF-8
# 版权声明及许可信息
# 根据 Apache License 2.0 许可使用代码
# 如果不符合许可条件,则不能使用本文件
# 获取许可副本地址:http://www.apache.org/licenses/LICENSE-2.0
#
# 除非法律要求或书面同意,本软件是按“原样”基础分发的,不提供任何明示或暗示的担保或条件。
# 请查阅许可证以了解具体的法律权限和限制。
""" Big Bird 模型的 Tokenization 类 """

# 导入标准库和模块
import os
from shutil import copyfile
from typing import List, Optional, Tuple

# 导入依赖的工具和函数
from ...tokenization_utils import AddedToken
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import is_sentencepiece_available, logging

# 如果 SentencePiece 可用,则导入 BigBirdTokenizer 类,否则置为 None
if is_sentencepiece_available():
    from .tokenization_big_bird import BigBirdTokenizer
else:
    BigBirdTokenizer = None

# 获取日志记录器
logger = logging.get_logger(__name__)

# 定义词汇文件名映射
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}

# 定义预训练模型的词汇文件映射
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "google/bigbird-roberta-base": "https://huggingface.co/google/bigbird-roberta-base/resolve/main/spiece.model",
        "google/bigbird-roberta-large": (
            "https://huggingface.co/google/bigbird-roberta-large/resolve/main/spiece.model"
        ),
        "google/bigbird-base-trivia-itc": (
            "https://huggingface.co/google/bigbird-base-trivia-itc/resolve/main/spiece.model"
        ),
    },
    "tokenizer_file": {
        "google/bigbird-roberta-base": (
            "https://huggingface.co/google/bigbird-roberta-base/resolve/main/tokenizer.json"
        ),
        "google/bigbird-roberta-large": (
            "https://huggingface.co/google/bigbird-roberta-large/resolve/main/tokenizer.json"
        ),
        "google/bigbird-base-trivia-itc": (
            "https://huggingface.co/google/bigbird-base-trivia-itc/resolve/main/tokenizer.json"
        ),
    },
}

# 定义预训练模型的位置编码嵌入大小
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "google/bigbird-roberta-base": 4096,
    "google/bigbird-roberta-large": 4096,
    "google/bigbird-base-trivia-itc": 4096,
}

# 定义 SentencePiece 中的特殊字符
SPIECE_UNDERLINE = "▁"

# BigBirdTokenizerFast 类继承自 PreTrainedTokenizerFast 类
class BigBirdTokenizerFast(PreTrainedTokenizerFast):
    """
    构建一个“快速”的 BigBird 分词器(由 HuggingFace 的 tokenizers 库支持)。基于 Unigram 模型。
    该分词器继承自 `PreTrainedTokenizerFast`,包含大多数主要方法。用户应参考其超类以获取更多关于这些方法的信息。
    """
    """
    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token. .. note:: When building a sequence using special tokens, this is not the token
            that is used for the end of sequence. The token used is the `sep_token`.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
    """
    # 定义一些预先设置好的常量和类,用于初始化 tokenizer
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    slow_tokenizer_class = BigBirdTokenizer
    model_input_names = ["input_ids", "attention_mask"]
    prefix_tokens: List[int] = []

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token="<pad>",
        sep_token="[SEP]",
        mask_token="[MASK]",
        cls_token="[CLS]",
        **kwargs,
    ):
        """
        构造函数,初始化一个新的 tokenizer 对象。

        Args:
            vocab_file (str, optional): SentencePiece 文件的路径,包含了实例化 tokenizer 所需的词汇表。
            tokenizer_file (str, optional): tokenizer 文件的路径,如果提供了,将会加载现有的 tokenizer。
            unk_token (str, optional): 未知 token,当词汇表中没有某个词时,将使用该 token。
            bos_token (str, optional): 序列的开头 token,用于序列分类或者特殊 token 序列的起始。
            eos_token (str, optional): 序列的结尾 token,用于特殊 token 序列的结束。
            pad_token (str, optional): 填充 token,在批处理不同长度序列时使用。
            sep_token (str, optional): 分隔 token,用于构建来自多个序列的单一序列。
            cls_token (str, optional): 分类器 token,用于序列分类任务中整个序列的分类。
            mask_token (str, optional): 掩码 token,用于预测被 mask 的词语。
            **kwargs: 其他关键字参数,用于额外配置 tokenizer。
        """
    ):
        # 如果 bos_token 是字符串类型,则创建一个 AddedToken 对象,保持左右两端的空白不变
        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
        # 如果 eos_token 是字符串类型,则创建一个 AddedToken 对象,保持左右两端的空白不变
        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
        # 如果 unk_token 是字符串类型,则创建一个 AddedToken 对象,保持左右两端的空白不变
        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
        # 如果 pad_token 是字符串类型,则创建一个 AddedToken 对象,保持左右两端的空白不变
        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
        # 如果 cls_token 是字符串类型,则创建一个 AddedToken 对象,保持左右两端的空白不变
        cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
        # 如果 sep_token 是字符串类型,则创建一个 AddedToken 对象,去除左侧空白,保持右侧空白
        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

        # 调用父类的初始化方法,传入参数进行初始化
        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs,
        )

        # 设置实例的 vocab_file 属性
        self.vocab_file = vocab_file

    @property
    def can_save_slow_tokenizer(self) -> bool:
        # 检查 vocab_file 是否存在,如果存在返回 True,否则返回 False
        return os.path.isfile(self.vocab_file) if self.vocab_file else False

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. An BigBird sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        # 如果 token_ids_1 为空,则返回 `[CLS] + token_ids_0 + [SEP]`
        if token_ids_1 is None:
            return cls + token_ids_0 + sep
        # 否则返回 `[CLS] + token_ids_0 + [SEP] + token_ids_1 + [SEP]`
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ):
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
        sequence pair mask has the following format:
    
        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```
    
        if token_ids_1 is None, only returns the first portion of the mask (0s).
    
        Args:
            token_ids_0 (`List[int]`):
                List of ids.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
    
        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        # 定义分隔和类别标记
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
    
        # 如果没有第二个序列,则返回只包含第一个序列和分隔符的长度的 0 组成的列表
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        
        # 否则返回一个列表,其中包含第一个序列、分隔符以及第二个序列和分隔符的长度的 0 和 1 组成的列表
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
    # 定义一个保存词汇表的方法,接受一个保存目录和可选的文件名前缀作为参数,并返回一个包含文件路径的元组
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # 检查是否能够保存慢速分词器的词汇表,否则抛出数值错误
        if not self.can_save_slow_tokenizer:
            raise ValueError(
                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
                "tokenizer."
            )

        # 如果保存目录不存在,则记录错误并返回
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        
        # 构建输出词汇表文件的路径,如果提供了前缀则加在文件名前面,否则直接使用默认文件名
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        # 如果当前词汇表文件的绝对路径不等于输出路径的绝对路径,则复制当前词汇表文件到输出路径
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        # 返回保存的词汇表文件路径的元组
        return (out_vocab_file,)
posted @ 2024-06-30 15:34  绝不原创的飞龙  阅读(0)  评论(0编辑  收藏  举报